Module:FunList/Iterators: Difference between revisions

From Melvor Idle
No edit summary
No edit summary
Line 1: Line 1:
-- Helper Functions --
-- FORWARD DECLARED FUNCTIONS --
local function isType(obj, class)
 
    local mt = getmetatable(obj)
-- Checks if the provided object matches the provided type.
    while mt do
-- Returns true if object is of the provided type.
        if mt == class then
-- (object, type). Returns boolean
            return true
local isType
        end
 
        mt = getmetatable(mt)
-- Returns a TableEnumerator from the provided object or
    end
-- creates a new one if the object is not of type TableEnumerator
    return false
-- (object). Returns TableEnumerator
end
local getTableEnumerator
 
-- Attempts to add an object to the provided table as a hashset.
-- Returns True if the object was not already present.
-- (table, object). Returns True
local addToSet


-- CLASS DEFINITIONS --
-- BASE ENUMERATOR CLASS --
-- BASE ENUMERATOR CLASS --
---@class Enumerator
---@field current any
---@field index any
---@field state integer
---@field isArray boolean
local Enumerator = {}
local Enumerator = {}
local Enumerator_mt = {
local Enumerator_mt = {
__index = Enumerator,
__index = Enumerator,
__pairs = function(t) return t:overridePairs() end,
__pairs = function(t) return t:getPairs() end,
__ipairs = function(t) return t:overridePairs(0) end
__ipairs = function(t) return t:getiPairs()
    end
}
}


---@return Enumerator
function Enumerator.new()
function Enumerator.new()
local self = setmetatable({}, Enumerator_mt)
local self = setmetatable({}, Enumerator_mt)
self.current = nil
self.current = nil
self.index = nil
self.index = nil
    self.state = -1
-- Assume by default we are not dealing with a simple array
-- Assume by default we are not dealing with a simple array
self.isArray = false
self.isArray = false
Line 28: Line 42:
end
end


---@return boolean
function Enumerator:moveNext()
function Enumerator:moveNext()
error('Not implemented in base class.')
    error('Abstract function must be overridden in derived class.')
end
end


---@return Enumerator
function Enumerator:getEnumerator(isArray)
function Enumerator:getEnumerator(isArray)
error('Not implemented in base class.')
    -- The default state is -1 which signifies a Enumerator isn't used.
    local instance = (self.state == -1) and self or self:clone()
    instance.isArray = isArray
    instance.state = 0
    return instance
end
 
---@return Enumerator
function Enumerator:clone()
    error('Abstract function must be overridden in derived class.')
end
 
function Enumerator:finalize()
    -- Signals invalid state.
    self.state = -4
end
end


-- Hooks the moveNext function into the Lua 'pairs' function
-- Hooks the moveNext function into the Lua 'pairs' function
function Enumerator:overridePairs(startIndex)
local function overridePairs(enum, startIndex)
-- Get or create clean enumerator. This ensures the state is 0.
-- Get or create clean enumerator. This ensures the state is 0.
local enum = self:getEnumerator(startIndex == 0)
local new = enum:getEnumerator(startIndex == 0)
enum.current = nil
new.current = nil
enum.index = startIndex
new.index = startIndex
local function iterator(t, k)
local function iterator(t, k)
if enum:moveNext() == true then
if new:moveNext() == true then
return enum.index, enum.current
return new.index, new.current
end
end
return nil, nil
return nil, nil
end
end
 
return iterator, enum, enum.index
return iterator, new, new.index
end
 
-- Manual override for iterating over the Enumerator using pairs()
function Enumerator:getPairs()
    return overridePairs(self, nil)
end
 
-- Manual override for iterating over the Enumerator using ipairs()
function Enumerator:getiPairs()
    return overridePairs(self, 0)
end
end


Line 55: Line 95:
-- This is essentially a wrapper for the table object,  
-- This is essentially a wrapper for the table object,  
-- since it provides no state machine esque iteration out of the box
-- since it provides no state machine esque iteration out of the box
---@class TableEnumerator : Enumerator
---@field state integer
---@field tbl table
local TableEnumerator = setmetatable({}, { __index = Enumerator })
local TableEnumerator = setmetatable({}, { __index = Enumerator })
TableEnumerator.__index = TableEnumerator
TableEnumerator.__index = TableEnumerator
Line 63: Line 106:
     local self = setmetatable(Enumerator.new(), TableEnumerator)
     local self = setmetatable(Enumerator.new(), TableEnumerator)
     self.tbl = tbl or {} -- Allow creation of empty enumerable
     self.tbl = tbl or {} -- Allow creation of empty enumerable
    self.state = 0


     return self
     return self
Line 93: Line 135:
end
end


-- startIndex is used to determine if the underlying table should be treated
function TableEnumerator:clone()
-- as an array or as a mixed table. It is ignored in the other enumerators as
     return TableEnumerator.new(self.tbl)
-- they just call moveNext on the enumerator instead.
function TableEnumerator:getEnumerator(isArray)
     local instance = (self.state == 0) and self or TableEnumerator.new(self.tbl)
    instance.isArray = isArray
    return instance
end
end


-- SELECT ENUMERATOR --
-- SELECT ENUMERATOR --
local SelectEnumerator = setmetatable({}, { __index = Enumerator })
---@class MapEnumerator : Enumerator
SelectEnumerator.__index = SelectEnumerator
---@field state integer
SelectEnumerator.__pairs = Enumerator_mt.__pairs
---@field source Enumerator
SelectEnumerator.__ipairs = Enumerator_mt.__ipairs
---@field selector Enumerator
local MapEnumerator = setmetatable({}, { __index = Enumerator })
MapEnumerator.__index = MapEnumerator
MapEnumerator.__pairs = Enumerator_mt.__pairs
MapEnumerator.__ipairs = Enumerator_mt.__ipairs


function SelectEnumerator.new(source, selector)
function MapEnumerator.new(source, selector)
assert(source, 'Source cannot be nil')
assert(source, 'Source cannot be nil')
assert(selector, 'Selector cannot be nil')
assert(selector, 'Selector cannot be nil')
     local self = setmetatable(Enumerator.new(), SelectEnumerator)
     local self = setmetatable(Enumerator.new(), MapEnumerator)
    self.state = 0
     self.source = source
     self.source = source
     self.selector = selector
     self.selector = selector
     self.enumerator = nil
     self.enumerator = nil
     self.position = 0
     self.position = 0
     return self
     return self
end
end


function SelectEnumerator:moveNext()
function MapEnumerator:moveNext()
if self.state == 0 then
if self.state == 0 then
self.state = 1
self.state = 1
Line 126: Line 166:
self.enumerator = self.source:getEnumerator(self.isArray)
self.enumerator = self.source:getEnumerator(self.isArray)
end
end
 
if self.state == 1 then
if self.state == 1 then
local enumerator = self.enumerator
local enumerator = self.enumerator
Line 136: Line 176:
assert(self.current, 'Selected value must be non-nil')
assert(self.current, 'Selected value must be non-nil')
return true
return true
end
        end
 
        self:finalize()
end
end
return false
return false
end
end


function SelectEnumerator:getEnumerator(isArray)
function MapEnumerator:finalize()
     local instance = (self.state == 0) and self or SelectEnumerator.new(self.source, self.selector)
     if self.enumerator then
    instance.isArray = isArray
        self.enumerator:finalize()
    return instance
    end
    Enumerator.finalize(self)
end
 
function MapEnumerator:clone()
    return MapEnumerator.new(self.source, self.selector)
end
end


-- WHERE ENUMERATOR --
-- WHERE ENUMERATOR --
---@class WhereEnumerator : Enumerator
---@field source Enumerator
---@field predicate function
local WhereEnumerator = setmetatable({}, { __index = Enumerator })
local WhereEnumerator = setmetatable({}, { __index = Enumerator })
WhereEnumerator.__index = WhereEnumerator
WhereEnumerator.__index = WhereEnumerator
Line 157: Line 207:
assert(predicate, 'Predicate cannot be nil')
assert(predicate, 'Predicate cannot be nil')
     local self = setmetatable(Enumerator.new(), WhereEnumerator)
     local self = setmetatable(Enumerator.new(), WhereEnumerator)
    self.state = 0
     self.source = source
     self.source = source
     self.predicate = predicate
     self.predicate = predicate
     self.enumerator = nil
     self.enumerator = nil
     return self
     return self
end
end


Line 170: Line 219:
self.enumerator = self.source:getEnumerator(self.isArray)
self.enumerator = self.source:getEnumerator(self.isArray)
end
end
 
if self.state == 1 then
if self.state == 1 then
local enumerator = self.enumerator
local enumerator = self.enumerator
while enumerator:moveNext() == true do
while enumerator:moveNext() == true do
local sourceElement = enumerator.current
local sourceElement = enumerator.current
if self.predicate(sourceElement) == true then
            local sourceIndex = enumerator.index
self.index = enumerator.index
if self.predicate(sourceElement, sourceIndex) == true then
self.index = sourceIndex
self.current = sourceElement
self.current = sourceElement
return true
return true
end
end
end
end
        self:finalize()
end
end
 
return false
return false
end
end


function WhereEnumerator:getEnumerator()
function WhereEnumerator:finalize()
     local instance = (self.state == 0) and self or WhereEnumerator.new(self.source, self.predicate)
     if self.enumerator then
    instance.isArray = isArray
        self.enumerator:finalize()
    return instance
    end
    Enumerator.finalize(self)
end
 
function WhereEnumerator:clone()
    return WhereEnumerator.new(self.source, self.predicate)
end
end


-- SELECTMANY ENUMERATOR --
-- FLATMAP (SELECTMANY) ENUMERATOR --
local SelectManyEnumerator = setmetatable({}, { __index = Enumerator })
---@class FlatMapEnumerator : Enumerator
SelectManyEnumerator.__index = SelectManyEnumerator
---@field source Enumerator
SelectManyEnumerator.__pairs = Enumerator_mt.__pairs
---@field selector function
SelectManyEnumerator.__ipairs = Enumerator_mt.__ipairs
---@field position integer
local FlatMapEnumerator = setmetatable({}, { __index = Enumerator })
FlatMapEnumerator.__index = FlatMapEnumerator
FlatMapEnumerator.__pairs = Enumerator_mt.__pairs
FlatMapEnumerator.__ipairs = Enumerator_mt.__ipairs


function SelectManyEnumerator.new(source, selector)
function FlatMapEnumerator.new(source, selector)
assert(source, 'Source cannot be nil')
assert(source, 'Source cannot be nil')
assert(selector, 'Selector cannot be nil')
assert(selector, 'Selector cannot be nil')
     local self = setmetatable(Enumerator.new(), SelectManyEnumerator)
     local self = setmetatable(Enumerator.new(), FlatMapEnumerator)
    self.state = 0
     self.source = source
     self.source = source
     self.selector = selector
     self.selector = selector
Line 208: Line 268:
     self.enumerator = nil -- Enumerator of the source Enumerable
     self.enumerator = nil -- Enumerator of the source Enumerable
     self.sourceEnumerator = nil  -- Enumerator of the nested Enumerable
     self.sourceEnumerator = nil  -- Enumerator of the nested Enumerable
     return self
     return self
end
end


function SelectManyEnumerator:moveNext()
function FlatMapEnumerator:moveNext()
    if self.state == -4 then
        return false
    end
 
    -- Setup state
    if self.state == 0 then
        self.position = 0
        self.enumerator = self.source:getEnumerator(self.isArray)
        self.state = 3 -- signal to get (first) nested enumerator
    end
     while true do
     while true do
        -- Setup state
        if self.state == 0 then
            self.position = 0
            self.enumerator = self.source:getEnumerator(self.isArray)
            self.state = -3 -- signal to get (first) nested enumerator
        end
         -- Grab next value from nested enumerator
         -- Grab next value from nested enumerator
         if self.state == -4 then
         if self.state == 4 then
             if self.sourceEnumerator:moveNext() then
             if self.sourceEnumerator:moveNext() then
                 self.current = self.sourceEnumerator.current
                 self.current = self.sourceEnumerator.current
                 self.index = self.sourceEnumerator.index
                 self.index = self.sourceEnumerator.index
                 self.state = -4 -- signal to get next item
                 self.state = 4 -- signal to get next item
                 return true
                 return true
             else
             else
                 self.state = -3 -- signal to get next enumerator
                -- Cleanup nested enumerator
                self.sourceEnumerator:finalize()
                 self.state = 3 -- signal to get next enumerator
             end
             end
         end
         end
 
-- Grab nest nested enumerator
-- Grab nest nested enumerator
         if self.state == -3 then
         if self.state == 3 then
             if self.enumerator:moveNext() then
             if self.enumerator:moveNext() then
                 local current = self.enumerator.current
                 local current = self.enumerator.current
Line 239: Line 304:


                 local sourceTable = self.selector(current, self.position)
                 local sourceTable = self.selector(current, self.position)
                 if not isType(sourceTable, Enumerator) then
                 -- Nested tables are never treated as arrays.
                -- We need to turn the nested table into an enumerator
                self.sourceEnumerator = getTableEnumerator(sourceTable, false)
                self.sourceEnumerator = TableEnumerator.new(sourceTable)
                 self.state = 4 -- signal to get next item
                :getEnumerator(self.isArray)
                else
                self.sourceEnumerator = sourceTable
                end
                 self.state = -4 -- signal to get next item
             else
             else
             -- enumerator doesn't have any more nested enumerators.
             -- enumerator doesn't have any more nested enumerators.
                self:finalize()
                 return false
                 return false
             end
             end
Line 255: Line 316:
end
end


function FlatMapEnumerator:finalize()
    if self.enumerator then
        self.enumerator:finalize()
    end
    if self.sourceEnumerator then
        self.sourceEnumerator:finalize()
    end


function SelectManyEnumerator:getEnumerator(isArray)
    Enumerator.finalize(self)
local instance = (self.state == 0) and self or SelectManyEnumerator.new(self.source, self.selector)
end
    instance.isArray = isArray
 
    return instance
function FlatMapEnumerator:clone()
return FlatMapEnumerator.new(self.source, self.selector)
end
end


-- CONCAT ENUMERATOR --
-- CONCAT ENUMERATOR --
---@class ConcatEnumerator : Enumerator
---@field first Enumerator
---@field second Enumerator
local ConcatEnumerator = setmetatable({}, { __index = Enumerator })
local ConcatEnumerator = setmetatable({}, { __index = Enumerator })
ConcatEnumerator.__index = ConcatEnumerator
ConcatEnumerator.__index = ConcatEnumerator
Line 272: Line 344:
assert(second, 'Second cannot be nil')
assert(second, 'Second cannot be nil')
     local self = setmetatable(Enumerator.new(), ConcatEnumerator)
     local self = setmetatable(Enumerator.new(), ConcatEnumerator)
    self.state = 0
     self.first = first
     self.first = first
     self.second = second
     self.second = second
     self.enumerator = nil
     self.enumerator = nil
     return self
     return self
end
end


Line 282: Line 353:
-- seems a bit redundant. But it would allow to add N-amount of enumerables.
-- seems a bit redundant. But it would allow to add N-amount of enumerables.
function ConcatEnumerator:getEnumerable(index)
function ConcatEnumerator:getEnumerable(index)
if index == 0 then  
if index == 0 then
return self.first
return self.first
elseif index == 1 then
elseif index == 1 then
Line 292: Line 363:


function ConcatEnumerator:moveNext()
function ConcatEnumerator:moveNext()
    if self.state == -4 then
        return false
    end
if self.state == 0 then
if self.state == 0 then
self.enumerator = self:getEnumerable(self.state)
self.enumerator = self:getEnumerable(self.state)
Line 297: Line 372:
self.state = 1
self.state = 1
end
end
 
if self.state > 0 then
if self.state > 0 then
while true do
while true do
Line 309: Line 384:
local next = self:getEnumerable(self.state - 1)
local next = self:getEnumerable(self.state - 1)
if next ~= nil then
if next ~= nil then
self.enumerator = next:getEnumerator(self.isArray)
                -- Cleanup previous enumerator
                self.enumerator:finalize()
self.enumerator = getTableEnumerator(next, self.isArray)
else
else
                self:finalize()
return false
return false
end
end
Line 319: Line 397:
end
end


function ConcatEnumerator:getEnumerator(isArray)
function ConcatEnumerator:finalize()
     local instance = (self.state == 0) and self or ConcatEnumerator.new(self.first, self.second)
    if self.enumerator then
     instance.isArray = isArray
        self.enumerator:finalize()
     return instance
    end
 
    Enumerator.finalize(self)
end
 
function ConcatEnumerator:clone()
    return ConcatEnumerator.new(self.first, self.second)
end
 
-- APPEND ENUMERATOR --
---@class AppendEnumerator : Enumerator
---@field source Enumerator
---@field item any
---@field itemIndex any
---@field append boolean
---@field enumerator Enumerator
local AppendEnumerator = setmetatable({}, { __index = Enumerator })
AppendEnumerator.__index = AppendEnumerator
AppendEnumerator.__pairs = Enumerator_mt.__pairs
AppendEnumerator.__ipairs = Enumerator_mt.__ipairs
 
function AppendEnumerator.new(source, item, itemIndex, append)
assert(source, 'Source cannot be nil')
assert(item, 'Item cannot be nil')
    local self = setmetatable(Enumerator.new(), AppendEnumerator)
    self.source = source
    self.item = item
    -- Index needs *some* value, otherwise iteration stops.
    if itemIndex == nil then self.itemIndex = item else self.itemIndex = itemIndex end
    if append == nil then self.append = true else self.append = append end
    self.enumerator = nil
    return self
end
 
function AppendEnumerator:moveNext()
    if self.state == 0 then
        self.state = 1
        -- Put the item in front, as a prepend.
        if self.append == false then
            self.current = self.item
            self.index = self.itemIndex
            return true
        end
    end
 
    -- Grab the source enumerator
    if self.state == 1 then
        self.enumerator = self.source:getEnumerator(self.isArray)
        self.state = 2
    end
 
    if self.state == 2 then
        if self.enumerator:moveNext() then
            self.current = self.enumerator.current
            self.index = self.enumerator.index
            return true
        else
            self:finalize()
        end
 
        if self.append == true then
            self.current = self.item
            self.index = self.itemIndex
            return true
        end
    end
 
    return false
end
 
function AppendEnumerator:finalize()
    if self.enumerator then
        self.enumerator:finalize()
    end
 
    Enumerator.finalize(self)
end
 
function AppendEnumerator:clone()
    return AppendEnumerator.new(self.source, self.item, self.itemIndex, self.append)
end
 
-- UNIQUE (DISTINCT) ENUMERATOR --
---@class UniqueEnumerator : Enumerator
---@field source Enumerator
---@field selector function
---@field enumerator Enumerator
---@field set table
local UniqueEnumerator = setmetatable({}, { __index = Enumerator })
UniqueEnumerator.__index = UniqueEnumerator
UniqueEnumerator.__pairs = Enumerator_mt.__pairs
UniqueEnumerator.__ipairs = Enumerator_mt.__ipairs
 
function UniqueEnumerator.new(source, selector)
assert(source, 'Source cannot be nil')
    local self = setmetatable(Enumerator.new(), UniqueEnumerator)
    self.source = source
    self.selector = selector
    self.enumerator = nil
    self.set = nil
    return self
end
 
function UniqueEnumerator:moveNext()
    if self.state == -4 then
        return false
    end
 
    if self.state == 0 then
        self.enumerator = self.source:getEnumerator(self.isArray)
        -- If we have any items, create a hashtable. Otherwise abort.
        if self.enumerator:moveNext() == true then
            self.set = {}
            self.state = 1
        else
            self:finalize()
            return false
        end
    end
 
    while true do
        if self.state == 1 then
            self.state = 2
            local current = self.enumerator.current
            local index = self.enumerator.index
            -- Manipulate the item if we have a selector (DistinctBy)
            if self.selector then
                current = self.selector(current, index)
            end
 
            if addToSet(self.set, current) then
                self.current = self.enumerator.current
                self.index = self.enumerator.index
                return true
            end
        end
 
        -- Try to grab a new item.
        if self.state == 2 then
            if self.enumerator:moveNext() == true then
                self.state = 1
            else
                self:finalize()
                return false
            end
        end
    end
end
 
function UniqueEnumerator:finalize()
    if self.enumerator then
        self.enumerator:finalize()
    end
 
    Enumerator.finalize(self)
end
 
function UniqueEnumerator:clone()
    return UniqueEnumerator.new(self.source, self.selector)
end
 
-- EXCEPT ENUMERATOR --
---@class DifferenceEnumerator : Enumerator
---@field first Enumerator
---@field second any
---@field selector function
---@field enumerator Enumerator
---@field set table
local DifferenceEnumerator = setmetatable({}, { __index = Enumerator })
DifferenceEnumerator.__index = DifferenceEnumerator
DifferenceEnumerator.__pairs = Enumerator_mt.__pairs
DifferenceEnumerator.__ipairs = Enumerator_mt.__ipairs
 
function DifferenceEnumerator.new(first, second, selector)
assert(first, 'First table cannot be nil')
    assert(second, 'Second table cannot be nil')
    local self = setmetatable(Enumerator.new(), DifferenceEnumerator)
    self.first = first
    self.second = second
    self.selector = selector
    self.enumerator = nil
    self.set = nil
    return self
end
 
local function createSet(other)
     local set = {}
    if isType(other, Enumerator) then
        for _, v in other:getPairs() do
            set[v] = true
        end
    else
        assert(type(other) == 'table', 'Source must be a table.')
        for _, v in pairs(other) do
            set[v] = true
        end
    end
 
    return set
end
 
function DifferenceEnumerator:moveNext()
    if self.state == -4 then
        return false
    end
 
    if self.state == 0 then
        self.enumerator = self.first:getEnumerator(self.isArray)
        self.set = createSet(self.second)
        self.state = 1
    end
 
    while self.enumerator:moveNext() do
        local current = self.enumerator.current
        local index = self.enumerator.index
        if self.selector ~= nil then
            current = self.selector(current, index)
        end
        if addToSet(self.set, current) then
            self.current = self.enumerator.current
            self.index = self.enumerator.index
            self.state = 1
            return true
        end
    end
 
    self:finalize()
    return false
end
 
function DifferenceEnumerator:finalize()
    if self.enumerator then
        self.enumerator:finalize()
    end
end
 
function DifferenceEnumerator:clone()
    return DifferenceEnumerator.new(self.first, self.second, self.selector)
end
 
-- UNION ENUMERATOR --
---@class UnionEnumerator : Enumerator
---@field first Enumerator
---@field second any
---@field selector function
---@field enumerator Enumerator
---@field set table
local UnionEnumerator = setmetatable({}, { __index = Enumerator })
UnionEnumerator.__index = UnionEnumerator
UnionEnumerator.__pairs = Enumerator_mt.__pairs
UnionEnumerator.__ipairs = Enumerator_mt.__ipairs
 
function UnionEnumerator.new(first, second, selector)
assert(first, 'First table cannot be nil')
    assert(second, 'Second table cannot be nil')
    local self = setmetatable(Enumerator.new(), UnionEnumerator)
    self.first = first
    self.second = second
    self.selector = selector
    self.enumerator = nil
    self.set = nil
    return self
end
 
local function enumerateSource(self, state)
    while self.enumerator:moveNext() do
        local current = self.enumerator.current
        local index = self.enumerator.index
        if self.selector ~= nil then
            current = self.selector(current, index)
        end
        if addToSet(self.set, current) == true then
            self.current = self.enumerator.current
            self.index = self.enumerator.index
            self.state = state
            return true
        end
    end
end
 
function UnionEnumerator:moveNext()
     if self.state == -4 then
        return false
    end
 
    if self.state == 0 then
        self.enumerator = self.first:getEnumerator(self.isArray)
        self.set = {}
        self.state = 1
    end
 
    -- Process first
    if self.state == 1 then
        if enumerateSource(self, 1) == true then
            return true
        end
 
        -- End first enumeration
        self.state = 2
        self.enumerator:finalize()
        self.enumerator = getTableEnumerator(self.second, self.isArray)
    end
 
    -- Process second
    if self.state == 2 then
        if enumerateSource(self, 2) == true then
            return true
        end
 
        -- End first enumeration
        self.state = 3
    end
 
    self:finalize()
    return false
end
 
function UnionEnumerator:finalize()
    if self.enumerator then
        self.enumerator:finalize()
    end
 
    Enumerator.finalize(self)
end
 
function UnionEnumerator:clone()
     return UnionEnumerator.new(self.first, self.second, self.selector)
end
 
--TODO:
-- INTERSECT
-- GROUPBY
-- ZIP
 
 
-- ORDER/SORTING???
 
--Optional:
---Take
---Skip
 
-- Define forward declared functions
isType = function(obj, class)
    local mt = getmetatable(obj)
    while mt do
        if mt.__index == class then
            return true
        end
        mt = getmetatable(mt)
    end
end
 
getTableEnumerator = function(sourceTable, isArray)
    if not isType(sourceTable, Enumerator) then
        return TableEnumerator.new(sourceTable):getEnumerator(isArray)
    else
        return sourceTable
    end
end
 
addToSet = function(set, item)
    if set[item] == nil then
        set[item] = true
        return true
    else
        return false
    end
end
end


return {
return {
     Enumerator = Enumerator,
     Enumerator = Enumerator,
     TableEnumerator = TableEnumerator,
     TableEnumerator = TableEnumerator,
     SelectEnumerator = SelectEnumerator,
     MapEnumerator = MapEnumerator,
     WhereEnumerator = WhereEnumerator,
     WhereEnumerator = WhereEnumerator,
     SelectManyEnumerator = SelectManyEnumerator,
     FlatMapEnumerator = FlatMapEnumerator,
     ConcatEnumerator = ConcatEnumerator
     ConcatEnumerator = ConcatEnumerator,
    AppendEnumerator = AppendEnumerator,
    UniqueEnumerator = UniqueEnumerator,
    DifferenceEnumerator = DifferenceEnumerator,
    UnionEnumerator = UnionEnumerator
}
}

Revision as of 20:37, 23 July 2024

Documentation for this module may be created at Module:FunList/Iterators/doc

-- FORWARD DECLARED FUNCTIONS --

-- Checks if the provided object matches the provided type.
-- Returns true if object is of the provided type.
-- (object, type). Returns boolean
local isType

-- Returns a TableEnumerator from the provided object or
-- creates a new one if the object is not of type TableEnumerator
-- (object). Returns TableEnumerator
local getTableEnumerator

-- Attempts to add an object to the provided table as a hashset.
-- Returns True if the object was not already present.
-- (table, object). Returns True
local addToSet

-- CLASS DEFINITIONS --
-- BASE ENUMERATOR CLASS --
---@class Enumerator
---@field current any
---@field index any
---@field state integer
---@field isArray boolean
local Enumerator = {}
local Enumerator_mt = {
	__index = Enumerator,
	__pairs = function(t) return t:getPairs() end,
	__ipairs = function(t) return t:getiPairs()
    end
}

---@return Enumerator
function Enumerator.new()
	local self = setmetatable({}, Enumerator_mt)
	self.current = nil
	self.index = nil
    self.state = -1
	-- Assume by default we are not dealing with a simple array
	self.isArray = false
	return self
end

---@return boolean
function Enumerator:moveNext()
    error('Abstract function must be overridden in derived class.')
end

---@return Enumerator
function Enumerator:getEnumerator(isArray)
    -- The default state is -1 which signifies a Enumerator isn't used.
    local instance = (self.state == -1) and self or self:clone()
    instance.isArray = isArray
    instance.state = 0
    return instance
end

---@return Enumerator
function Enumerator:clone()
    error('Abstract function must be overridden in derived class.')
end

function Enumerator:finalize()
    -- Signals invalid state.
    self.state = -4
end

-- Hooks the moveNext function into the Lua 'pairs' function
local function overridePairs(enum, startIndex)
	-- Get or create clean enumerator. This ensures the state is 0.
	local new = enum:getEnumerator(startIndex == 0)
	new.current = nil
	new.index = startIndex
	local function iterator(t, k)
		if new:moveNext() == true then
			return new.index, new.current
		end
		return nil, nil
	end

	return iterator, new, new.index
end

-- Manual override for iterating over the Enumerator using pairs()
function Enumerator:getPairs()
    return overridePairs(self, nil)
end

-- Manual override for iterating over the Enumerator using ipairs()
function Enumerator:getiPairs()
    return overridePairs(self, 0)
end

-- TABLE ENUMERATOR CLASS --
-- This is essentially a wrapper for the table object, 
-- since it provides no state machine esque iteration out of the box
---@class TableEnumerator : Enumerator
---@field state integer
---@field tbl table
local TableEnumerator = setmetatable({}, { __index = Enumerator })
TableEnumerator.__index = TableEnumerator
TableEnumerator.__pairs =  Enumerator_mt.__pairs
TableEnumerator.__ipairs = Enumerator_mt.__ipairs

function TableEnumerator.new(tbl)
    local self = setmetatable(Enumerator.new(), TableEnumerator)
    self.tbl = tbl or {} -- Allow creation of empty enumerable

    return self
end

function TableEnumerator:moveNext()
    if self.state == 0 then
        self.state = 1
		self.index = self.isArray and 0 or nil
    end

    if self.isArray == true then
        -- Iterate using ipairs, starting from index 1
        self.index = self.index + 1
        self.current = self.tbl[self.index]
        return self.current ~= nil
    else
        -- Iterate using pairs
        local key = self.index
        key = next(self.tbl, key)
        self.index = key
        if key ~= nil then
            self.current = self.tbl[key]
            return true
        end
    end

    return false
end

function TableEnumerator:clone()
    return TableEnumerator.new(self.tbl)
end

-- SELECT ENUMERATOR --
---@class MapEnumerator : Enumerator
---@field state integer
---@field source Enumerator
---@field selector Enumerator
local MapEnumerator = setmetatable({}, { __index = Enumerator })
MapEnumerator.__index = MapEnumerator
MapEnumerator.__pairs = Enumerator_mt.__pairs
MapEnumerator.__ipairs = Enumerator_mt.__ipairs

function MapEnumerator.new(source, selector)
	assert(source, 'Source cannot be nil')
	assert(selector, 'Selector cannot be nil')
    local self = setmetatable(Enumerator.new(), MapEnumerator)
    self.source = source
    self.selector = selector
    self.enumerator = nil
    self.position = 0
    return self
end

function MapEnumerator:moveNext()
	if self.state == 0 then
		self.state = 1
		self.position = 0
		self.enumerator = self.source:getEnumerator(self.isArray)
	end

	if self.state == 1 then
		local enumerator = self.enumerator
		if enumerator:moveNext() == true then
			local sourceElement = enumerator.current
			self.index = enumerator.index -- Preserve index
			self.position = self.position + 1
			self.current = self.selector(sourceElement, self.position)
			assert(self.current, 'Selected value must be non-nil')
			return true
        end

        self:finalize()
	end
	return false
end

function MapEnumerator:finalize()
    if self.enumerator then
        self.enumerator:finalize()
    end
    Enumerator.finalize(self)
end

function MapEnumerator:clone()
    return MapEnumerator.new(self.source, self.selector)
end

-- WHERE ENUMERATOR --
---@class WhereEnumerator : Enumerator
---@field source Enumerator
---@field predicate function
local WhereEnumerator = setmetatable({}, { __index = Enumerator })
WhereEnumerator.__index = WhereEnumerator
WhereEnumerator.__pairs =  Enumerator_mt.__pairs
WhereEnumerator.__ipairs = Enumerator_mt.__ipairs

function WhereEnumerator.new(source, predicate)
	assert(source, 'Source cannot be nil')
	assert(predicate, 'Predicate cannot be nil')
    local self = setmetatable(Enumerator.new(), WhereEnumerator)
    self.source = source
    self.predicate = predicate
    self.enumerator = nil
    return self
end

function WhereEnumerator:moveNext()
	if self.state == 0 then
		self.state = 1
		self.position = 0
		self.enumerator = self.source:getEnumerator(self.isArray)
	end

	if self.state == 1 then
		local enumerator = self.enumerator
		while enumerator:moveNext() == true do
			local sourceElement = enumerator.current
            local sourceIndex = enumerator.index
			if self.predicate(sourceElement, sourceIndex) == true then
				self.index = sourceIndex
				self.current = sourceElement
				return true
			end
		end

        self:finalize()
	end

	return false
end

function WhereEnumerator:finalize()
    if self.enumerator then
        self.enumerator:finalize()
    end
    Enumerator.finalize(self)
end

function WhereEnumerator:clone()
    return WhereEnumerator.new(self.source, self.predicate)
end

-- FLATMAP (SELECTMANY) ENUMERATOR --
---@class FlatMapEnumerator : Enumerator
---@field source Enumerator
---@field selector function
---@field position integer
local FlatMapEnumerator = setmetatable({}, { __index = Enumerator })
FlatMapEnumerator.__index = FlatMapEnumerator
FlatMapEnumerator.__pairs = Enumerator_mt.__pairs
FlatMapEnumerator.__ipairs = Enumerator_mt.__ipairs

function FlatMapEnumerator.new(source, selector)
	assert(source, 'Source cannot be nil')
	assert(selector, 'Selector cannot be nil')
    local self = setmetatable(Enumerator.new(), FlatMapEnumerator)
    self.source = source
    self.selector = selector
    self.position = 0
    self.enumerator = nil		 -- Enumerator of the source Enumerable
    self.sourceEnumerator = nil  -- Enumerator of the nested Enumerable
    return self
end

function FlatMapEnumerator:moveNext()
    if self.state == -4 then
        return false
    end

    -- Setup state
    if self.state == 0 then
        self.position = 0
        self.enumerator = self.source:getEnumerator(self.isArray)
        self.state = 3 -- signal to get (first) nested enumerator
    end
    while true do
        -- Grab next value from nested enumerator		
        if self.state == 4 then
            if self.sourceEnumerator:moveNext() then
                self.current = self.sourceEnumerator.current
                self.index = self.sourceEnumerator.index
                self.state = 4 -- signal to get next item
                return true
            else
                -- Cleanup nested enumerator
                self.sourceEnumerator:finalize()
                self.state = 3 -- signal to get next enumerator
            end
        end

		-- Grab nest nested enumerator
        if self.state == 3 then
            if self.enumerator:moveNext() then
                local current = self.enumerator.current
                self.position = self.position + 1

                local sourceTable = self.selector(current, self.position)
                -- Nested tables are never treated as arrays.
                self.sourceEnumerator = getTableEnumerator(sourceTable, false)
                self.state = 4 -- signal to get next item
            else
            	-- enumerator doesn't have any more nested enumerators.
                self:finalize()
                return false
            end
        end
    end
end

function FlatMapEnumerator:finalize()
    if self.enumerator then
        self.enumerator:finalize()
    end
    if self.sourceEnumerator then
        self.sourceEnumerator:finalize()
    end

    Enumerator.finalize(self)
end

function FlatMapEnumerator:clone()
	return FlatMapEnumerator.new(self.source, self.selector)
end

-- CONCAT ENUMERATOR --
---@class ConcatEnumerator : Enumerator
---@field first Enumerator
---@field second Enumerator
local ConcatEnumerator = setmetatable({}, { __index = Enumerator })
ConcatEnumerator.__index = ConcatEnumerator
ConcatEnumerator.__pairs = Enumerator_mt.__pairs
ConcatEnumerator.__ipairs = Enumerator_mt.__ipairs

function ConcatEnumerator.new(first, second)
	assert(first, 'First cannot be nil')
	assert(second, 'Second cannot be nil')
    local self = setmetatable(Enumerator.new(), ConcatEnumerator)
    self.first = first
    self.second = second
    self.enumerator = nil
    return self
end

-- Function to grab enumerators in order. We currently only have two so this
-- seems a bit redundant. But it would allow to add N-amount of enumerables.
function ConcatEnumerator:getEnumerable(index)
	if index == 0 then
		return self.first
	elseif index == 1 then
		return self.second
	else
		return nil
	end
end

function ConcatEnumerator:moveNext()
    if self.state == -4 then
        return false
    end

	if self.state == 0 then
		self.enumerator = self:getEnumerable(self.state)
			:getEnumerator(self.isArray)
		self.state = 1
	end

	if self.state > 0 then
		while true do
			if self.enumerator:moveNext() == true then
				self.index = self.enumerator.index
				self.current = self.enumerator.current
				return true
			end

			self.state = self.state + 1
			local next = self:getEnumerable(self.state - 1)
			if next ~= nil then
                -- Cleanup previous enumerator
                self.enumerator:finalize()
				self.enumerator = getTableEnumerator(next, self.isArray)
			else
                self:finalize()
				return false
			end
		end
	end

	return false
end

function ConcatEnumerator:finalize()
    if self.enumerator then
        self.enumerator:finalize()
    end

    Enumerator.finalize(self)
end

function ConcatEnumerator:clone()
    return ConcatEnumerator.new(self.first, self.second)
end

-- APPEND ENUMERATOR --
---@class AppendEnumerator : Enumerator
---@field source Enumerator
---@field item any
---@field itemIndex any
---@field append boolean
---@field enumerator Enumerator
local AppendEnumerator = setmetatable({}, { __index = Enumerator })
AppendEnumerator.__index = AppendEnumerator
AppendEnumerator.__pairs = Enumerator_mt.__pairs
AppendEnumerator.__ipairs = Enumerator_mt.__ipairs

function AppendEnumerator.new(source, item, itemIndex, append)
	assert(source, 'Source cannot be nil')
	assert(item, 'Item cannot be nil')
    local self = setmetatable(Enumerator.new(), AppendEnumerator)
    self.source = source
    self.item = item
    -- Index needs *some* value, otherwise iteration stops.
    if itemIndex == nil then self.itemIndex = item else self.itemIndex = itemIndex end
    if append == nil then self.append = true else self.append = append end
    self.enumerator = nil
    return self
end

function AppendEnumerator:moveNext()
    if self.state == 0 then
        self.state = 1
        -- Put the item in front, as a prepend.
        if self.append == false then
            self.current = self.item
            self.index = self.itemIndex
            return true
        end
    end

    -- Grab the source enumerator
    if self.state == 1 then
        self.enumerator = self.source:getEnumerator(self.isArray)
        self.state = 2
    end

    if self.state == 2 then
        if self.enumerator:moveNext() then
            self.current = self.enumerator.current
            self.index = self.enumerator.index
            return true
        else
            self:finalize()
        end

        if self.append == true then
            self.current = self.item
            self.index = self.itemIndex
            return true
        end
    end

    return false
end

function AppendEnumerator:finalize()
    if self.enumerator then
        self.enumerator:finalize()
    end

    Enumerator.finalize(self)
end

function AppendEnumerator:clone()
    return AppendEnumerator.new(self.source, self.item, self.itemIndex, self.append)
end

-- UNIQUE (DISTINCT) ENUMERATOR --
---@class UniqueEnumerator : Enumerator
---@field source Enumerator
---@field selector function
---@field enumerator Enumerator
---@field set table
local UniqueEnumerator = setmetatable({}, { __index = Enumerator })
UniqueEnumerator.__index = UniqueEnumerator
UniqueEnumerator.__pairs = Enumerator_mt.__pairs
UniqueEnumerator.__ipairs = Enumerator_mt.__ipairs

function UniqueEnumerator.new(source, selector)
	assert(source, 'Source cannot be nil')
    local self = setmetatable(Enumerator.new(), UniqueEnumerator)
    self.source = source
    self.selector = selector
    self.enumerator = nil
    self.set = nil
    return self
end

function UniqueEnumerator:moveNext()
    if self.state == -4 then
        return false
    end

    if self.state == 0 then
        self.enumerator = self.source:getEnumerator(self.isArray)
        -- If we have any items, create a hashtable. Otherwise abort.
        if self.enumerator:moveNext() == true then
            self.set = {}
            self.state = 1
        else
            self:finalize()
            return false
        end
    end

    while true do
        if self.state == 1 then
            self.state = 2
            local current = self.enumerator.current
            local index = self.enumerator.index
            -- Manipulate the item if we have a selector (DistinctBy)
            if self.selector then
                current = self.selector(current, index)
            end

            if addToSet(self.set, current) then
                self.current = self.enumerator.current
                self.index = self.enumerator.index
                return true
            end
        end

        -- Try to grab a new item.
        if self.state == 2 then
            if self.enumerator:moveNext() == true then
                self.state = 1
            else
                self:finalize()
                return false
            end
        end
    end
end

function UniqueEnumerator:finalize()
    if self.enumerator then
        self.enumerator:finalize()
    end

    Enumerator.finalize(self)
end

function UniqueEnumerator:clone()
    return UniqueEnumerator.new(self.source, self.selector)
end

-- EXCEPT ENUMERATOR --
---@class DifferenceEnumerator : Enumerator
---@field first Enumerator
---@field second any
---@field selector function
---@field enumerator Enumerator
---@field set table
local DifferenceEnumerator = setmetatable({}, { __index = Enumerator })
DifferenceEnumerator.__index = DifferenceEnumerator
DifferenceEnumerator.__pairs = Enumerator_mt.__pairs
DifferenceEnumerator.__ipairs = Enumerator_mt.__ipairs

function DifferenceEnumerator.new(first, second, selector)
	assert(first, 'First table cannot be nil')
    assert(second, 'Second table cannot be nil')
    local self = setmetatable(Enumerator.new(), DifferenceEnumerator)
    self.first = first
    self.second = second
    self.selector = selector
    self.enumerator = nil
    self.set = nil
    return self
end

local function createSet(other)
    local set = {}
    if isType(other, Enumerator) then
        for _, v in other:getPairs() do
            set[v] = true
        end
    else
        assert(type(other) == 'table', 'Source must be a table.')
        for _, v in pairs(other) do
            set[v] = true
        end
    end

    return set
end

function DifferenceEnumerator:moveNext()
    if self.state == -4 then
        return false
    end

    if self.state == 0 then
        self.enumerator = self.first:getEnumerator(self.isArray)
        self.set = createSet(self.second)
        self.state = 1
    end

    while self.enumerator:moveNext() do
        local current = self.enumerator.current
        local index = self.enumerator.index
        if self.selector ~= nil then 
            current = self.selector(current, index)
        end
        if addToSet(self.set, current) then
            self.current = self.enumerator.current
            self.index = self.enumerator.index
            self.state = 1
            return true
        end
    end

    self:finalize()
    return false
end

function DifferenceEnumerator:finalize()
    if self.enumerator then
        self.enumerator:finalize()
    end
end

function DifferenceEnumerator:clone()
    return DifferenceEnumerator.new(self.first, self.second, self.selector)
end

-- UNION ENUMERATOR --
---@class UnionEnumerator : Enumerator
---@field first Enumerator
---@field second any
---@field selector function
---@field enumerator Enumerator
---@field set table
local UnionEnumerator = setmetatable({}, { __index = Enumerator })
UnionEnumerator.__index = UnionEnumerator
UnionEnumerator.__pairs = Enumerator_mt.__pairs
UnionEnumerator.__ipairs = Enumerator_mt.__ipairs

function UnionEnumerator.new(first, second, selector)
	assert(first, 'First table cannot be nil')
    assert(second, 'Second table cannot be nil')
    local self = setmetatable(Enumerator.new(), UnionEnumerator)
    self.first = first
    self.second = second
    self.selector = selector
    self.enumerator = nil
    self.set = nil
    return self
end

local function enumerateSource(self, state)
    while self.enumerator:moveNext() do
        local current = self.enumerator.current
        local index = self.enumerator.index
        if self.selector ~= nil then
            current = self.selector(current, index)
        end
        if addToSet(self.set, current) == true then
            self.current = self.enumerator.current
            self.index = self.enumerator.index
            self.state = state
            return true
        end
    end
end

function UnionEnumerator:moveNext()
    if self.state == -4 then
        return false
    end

    if self.state == 0 then
        self.enumerator = self.first:getEnumerator(self.isArray)
        self.set = {}
        self.state = 1
    end

    -- Process first
    if self.state == 1 then
        if enumerateSource(self, 1) == true then
            return true
        end

        -- End first enumeration
        self.state = 2
        self.enumerator:finalize()
        self.enumerator = getTableEnumerator(self.second, self.isArray)
    end

    -- Process second
    if self.state == 2 then
        if enumerateSource(self, 2) == true then
            return true
        end

        -- End first enumeration
        self.state = 3
    end

    self:finalize()
    return false
end

function UnionEnumerator:finalize()
    if self.enumerator then
        self.enumerator:finalize()
    end

    Enumerator.finalize(self)
end

function UnionEnumerator:clone()
    return UnionEnumerator.new(self.first, self.second, self.selector)
end

--TODO:
-- INTERSECT
-- GROUPBY
-- ZIP


-- ORDER/SORTING???

--Optional:
---Take
---Skip

-- Define forward declared functions
isType = function(obj, class)
    local mt = getmetatable(obj)
    while mt do
        if mt.__index == class then
            return true
        end
        mt = getmetatable(mt)
    end
end

getTableEnumerator = function(sourceTable, isArray)
    if not isType(sourceTable, Enumerator) then
        return TableEnumerator.new(sourceTable):getEnumerator(isArray)
    else
        return sourceTable
    end
end

addToSet = function(set, item)
    if set[item] == nil then
        set[item] = true
        return true
    else
        return false
    end
end



return {
    Enumerator = Enumerator,
    TableEnumerator = TableEnumerator,
    MapEnumerator = MapEnumerator,
    WhereEnumerator = WhereEnumerator,
    FlatMapEnumerator = FlatMapEnumerator,
    ConcatEnumerator = ConcatEnumerator,
    AppendEnumerator = AppendEnumerator,
    UniqueEnumerator = UniqueEnumerator,
    DifferenceEnumerator = DifferenceEnumerator,
    UnionEnumerator = UnionEnumerator
}