Module:FunList/Iterators: Difference between revisions

From Melvor Idle
(Add ConcatEnumerator)
(bugfix)
 
(5 intermediate revisions by the same user not shown)
Line 1: Line 1:
-- Helper Functions --
local Enumerable = require('Module:FunList/Enumerable')
local function isType(obj, class)
local Lookup = require('Module:FunList/Lookup')
     local mt = getmetatable(obj)
local TableEnumerator = require('Module:FunList/TableEnumerator')
    while mt do
 
         if mt == class then
---Creates a TableIterator Enumerable or returns the Enumerable if it already is one.
             return true
---We set the propper starting index here since this only applies for tables!
---@param source any
---@param isArray any
---@return Enumerator
local function getTableEnumerator(source, isArray)
    if source.getEnumerator == nil then
        assert(type(source) == 'table', 'sourceTable must be either an Enumerable or table.')
        return TableEnumerator.create(source, isArray)
    else
        return source:getEnumerator(isArray)
    end
end
 
-- Helper functions to create and manage a hashset structure.
-- These are used to distinguish elements for Union/Unique/Difference etc operations.
---@param source any
---@return table
local function sourceToSet(source)
     local set = {}
    if Enumerable.isEnumerable(source) then
        local enum = source:getEnumerator()
        while enum:moveNext() do
            set[enum.current] = true
        end
    else
         assert(type(source) == 'table', 'Source must be a table.')
        for _, v in pairs(source) do
             set[v] = true
         end
         end
         mt = getmetatable(mt)
    end
    return set
end
 
---@param set table
---@param item any
---@return boolean
local function setAdd(set, item)
    if set[item] == nil then
         set[item] = true
        return true
     end
     end
     return false
     return false
end
end


-- BASE ENUMERATOR CLASS --
---@param set table
--enumerable = {}
---@param item any
local Enumerator = {}
---@return boolean
local Enumerator_mt = {
local function setRemove(set, item)
__index = Enumerator,
    if set[item] == nil then
__pairs = function(t) return t:overridePairs() end,
        return false
__ipairs = function(t) return t:overridePairs(0) end
    end
    set[item] = nil
    return true
end
 
-- Defines the base class for all Iterator state machines. This is an Enumerable and Enumerator in one.
---@class Iterator : Enumerable, Enumerator
---@field current any
---@field index any
---@field _state integer
---@field _isArray boolean
local Iterator = setmetatable({}, { __index = Enumerable })
local Iterator_mt = {
__index = Iterator,
__pairs = Enumerable.getPairs,
__ipairs = Enumerable.getiPairs
}
}


function Enumerator.new()
---@return Iterator
local self = setmetatable({}, Enumerator_mt)
function Iterator.new()
local self = setmetatable({}, Iterator_mt)
self.current = nil
self.current = nil
self.index = nil
self.index = nil
    self._state = -1
-- Assume by default we are not dealing with a simple array
-- Assume by default we are not dealing with a simple array
self.isArray = false
self._isArray = false
return self
return self
end
end


function Enumerator:moveNext()
-- This is normally implemented by the Enumerator abstract class.
error('Not implemented in base class.')
-- But to prevent double inheritance, we cheat a bit and implement it outselves.
---@return boolean
function Iterator:moveNext()
    error('Abstract function must be overridden in derived class.')
end
 
---Returns this Enumerator or creates a new one if it has been used.
---@param isArray? boolean
---@return Enumerator
function Iterator:getEnumerator(isArray)
    -- The default _state is -1 which signifies a Iterator isn't used.
    local instance = (self._state == -1) and self or self:clone()
    instance._isArray = isArray or self._isArray
    instance._state = 0
    return instance
end
end


function Enumerator:getEnumerator(isArray)
---@return Iterator
error('Not implemented in base class.')
function Iterator:clone()
    error('Abstract function must be overridden in derived class.')
end
end


-- Hooks the moveNext function into the Lua 'pairs' function
function Iterator:finalize()
function Enumerator:overridePairs(startIndex)
    -- Signals invalid _state.
-- Get or create clean enumerator. This ensures the state is 0.
    self._state = -4
local enum = self:getEnumerator(isArray)
enum.current = nil
enum.index = startIndex
local function iterator(t, k)
if enum:moveNext() == true then
return enum.index, enum.current
end
return nil, nil
end
return iterator, enum, enum.index
end
end


-- TABLE ENUMERATOR CLASS --
function Iterator.createIteratorSubclass()
-- This is essentially a wrapper for the table object,
    local derivedClass = setmetatable({}, { __index = Iterator })
-- since it provides no state machine esque iteration out of the box
    derivedClass.__index = derivedClass
local TableEnumerator = setmetatable({}, { __index = Enumerator })
    derivedClass.__pairs = Enumerable.getPairs
TableEnumerator.__index = TableEnumerator
    derivedClass.__ipairs = Enumerable.getiPairs
TableEnumerator.__pairs = Enumerator_mt.__pairs
    return derivedClass
TableEnumerator.__ipairs = Enumerator_mt.__ipairs
end


function TableEnumerator.new(tbl)
--Projects each element of a sequence into a new form.
    local self = setmetatable(Enumerator.new(), TableEnumerator)
---@class MapIterator : Iterator
    self.tbl = tbl or {} -- Allow creation of empty enumerable
---@field _source Enumerable
    self.state = 0
---@field _selector fun(value: any, index: any): any
---@field _position integer
---@field _enumerator Enumerator
local MapIterator = Iterator.createIteratorSubclass()


---@param source Enumerable
---@param selector fun(value: any, index: any): any
function MapIterator.new(source, selector)
assert(source, 'Source cannot be nil')
assert(selector, 'Selector cannot be nil')
    local self = setmetatable(Iterator.new(), MapIterator)
    self._source = source
    self._selector = selector
    self._enumerator = nil
    self._position = 0
     return self
     return self
end
end


function TableEnumerator:moveNext()
function MapIterator:moveNext()
     if self.state == 0 then
     local state = self._state
         self.state = 1
    if state ~= 0 then
self.index = self.isArray and 0 or nil
        if state ~= 1 then
            return false
        end
    else
         self._state = 1 state = 1
    self._position = 0
self._enumerator = self._source:getEnumerator(self._isArray)
     end
     end


     if self.isArray == true then
     local _enumerator = self._enumerator
         -- Iterate using ipairs, starting from index 1
    if _enumerator:moveNext() == true then
         self.index = self.index + 1
         -- 1 based index because Lua
         self.current = self.tbl[self.index]
         local pos = self._position + 1
         return self.current ~= nil
         local index = _enumerator.index
    else
         local current = self._selector(_enumerator.current, pos)
        -- Iterate using pairs
         assert(current, 'Selected value must be non-nil')
        local key = self.index
         self.index = index
         key = next(self.tbl, key)
         self.current = current
         self.index = key
        self._position = pos
         if key ~= nil then
        return true
            self.current = self.tbl[key]
            return true
        end
     end
     end
    self:finalize()


    return false
return false
end
 
function MapIterator:finalize()
    if self._enumerator then
        self._enumerator:finalize()
    end
    Iterator.finalize(self)
end
end


-- startIndex is used to determine if the underlying table should be treated
function MapIterator:clone()
-- as an array or as a mixed table. It is ignored in the other enumerators as
     return MapIterator.new(self._source, self._selector)
-- they just call moveNext on the enumerator instead.
function TableEnumerator:getEnumerator(isArray)
     local instance = (self.state == 0) and self or TableEnumerator.new(self.tbl)
    instance.isArray = isArray
    return instance
end
end


-- SELECT ENUMERATOR --
--Filters a sequence of values based on a predicate.
local SelectEnumerator = setmetatable({}, { __index = Enumerator })
---@class WhereIterator : Iterator
SelectEnumerator.__index = SelectEnumerator
---@field _source Enumerable
SelectEnumerator.__pairs = Enumerator_mt.__pairs
---@field _predicate fun(value: any, index: any): boolean
SelectEnumerator.__ipairs = Enumerator_mt.__ipairs
---@field _enumerator Enumerator
local WhereIterator = Iterator.createIteratorSubclass()


function SelectEnumerator.new(source, selector)
---@param source Enumerable
---@param predicate fun(value: any, index: any): boolean
function WhereIterator.new(source, predicate)
assert(source, 'Source cannot be nil')
assert(source, 'Source cannot be nil')
assert(selector, 'Selector cannot be nil')
assert(predicate, 'Predicate cannot be nil')
     local self = setmetatable(Enumerator.new(), SelectEnumerator)
     local self = setmetatable(Iterator.new(), WhereIterator)
     self.state = 0
     self._source = source
    self.source = source
     self._predicate = predicate
     self.selector = selector
     self._enumerator = nil
     self.enumerator = nil
     return self
    self.position = 0
     return self
end
end


function SelectEnumerator:moveNext()
function WhereIterator:moveNext()
if self.state == 0 then
    local state = self._state
self.state = 1
    if state ~= 0 then
self.position = 0
        if state ~= 1 then
self.enumerator = self.source:getEnumerator(self.isArray)
            return false
end
        end
    else
if self.state == 1 then
self._state = 1 state = 1
local enumerator = self.enumerator
self._enumerator = self._source:getEnumerator(self._isArray)
if enumerator:moveNext() == true then
    end
local sourceElement = enumerator.current
self.index = enumerator.index -- Preserve index
self.position = self.position + 1
self.current = self.selector(sourceElement, self.position)
assert(self.current, 'Selected value must be non-nil')
return true
end
end
return false
end


function SelectEnumerator:getEnumerator(isArray)
    -- Cache some class lookups since these are likely
     local instance = (self.state == 0) and self or SelectEnumerator.new(self.source, self.selector)
    -- to be accessed more than once per moveNext
    instance.isArray = isArray
     local enum = self._enumerator
    return instance
    local nextFunc = enum.moveNext
end
    local predicate = self._predicate
    while nextFunc(enum) == true do
        local sourceElement = enum.current
        local sourceIndex = enum.index
        if predicate(sourceElement, sourceIndex) == true then
            self.index = sourceIndex
            self.current = sourceElement
            return true
        end
    end
    self:finalize()


-- WHERE ENUMERATOR --
return false
local WhereEnumerator = setmetatable({}, { __index = Enumerator })
WhereEnumerator.__index = WhereEnumerator
WhereEnumerator.__pairs =  Enumerator_mt.__pairs
WhereEnumerator.__ipairs = Enumerator_mt.__ipairs
 
function WhereEnumerator.new(source, predicate)
assert(source, 'Source cannot be nil')
assert(predicate, 'Predicate cannot be nil')
    local self = setmetatable(Enumerator.new(), WhereEnumerator)
    self.state = 0
    self.source = source
    self.predicate = predicate
    self.enumerator = nil
    return self
end
end


function WhereEnumerator:moveNext()
function WhereIterator:finalize()
if self.state == 0 then
    if self._enumerator then
self.state = 1
        self._enumerator:finalize()
self.position = 0
    end
self.enumerator = self.source:getEnumerator(self.isArray)
    Iterator.finalize(self)
end
if self.state == 1 then
local enumerator = self.enumerator
while enumerator:moveNext() == true do
local sourceElement = enumerator.current
if self.predicate(sourceElement) == true then
self.index = enumerator.index
self.current = sourceElement
return true
end
end
end
return false
end
end


function WhereEnumerator:getEnumerator()
function WhereIterator:clone()
     local instance = (self.state == 0) and self or WhereEnumerator.new(self.source, self.predicate)
     return WhereIterator.new(self._source, self._predicate)
    instance.isArray = isArray
    return instance
end
end


-- SELECTMANY ENUMERATOR --
--Projects each element of a sequence to an Enumerable and flattens the resulting sequences into one sequence.
local SelectManyEnumerator = setmetatable({}, { __index = Enumerator })
---@class FlatMapIterator : Iterator
SelectManyEnumerator.__index = SelectManyEnumerator
---@field _source Enumerable
SelectManyEnumerator.__pairs = Enumerator_mt.__pairs
---@field _selector fun(value: any, index: integer): any
SelectManyEnumerator.__ipairs = Enumerator_mt.__ipairs
---@field _position integer
---@field _enumerator Enumerator
---@field _sourceEnumerator Enumerator
local FlatMapIterator = Iterator.createIteratorSubclass()


function SelectManyEnumerator.new(source, selector)
---@param source Enumerable
---@param selector fun(value: any, index: integer): any
function FlatMapIterator.new(source, selector)
assert(source, 'Source cannot be nil')
assert(source, 'Source cannot be nil')
assert(selector, 'Selector cannot be nil')
assert(selector, 'Selector cannot be nil')
     local self = setmetatable(Enumerator.new(), SelectManyEnumerator)
     local self = setmetatable(Iterator.new(), FlatMapIterator)
     self.state = 0
     self._source = source
    self.source = source
     self._selector = selector
     self.selector = selector
     self._position = 0
     self.position = 0
     self._enumerator = nil -- Iterator of the _source Enumerable
     self.enumerator = nil -- Enumerator of the source Enumerable
     self._sourceEnumerator = nil  -- Iterator of the nested Enumerable
     self.sourceEnumerator = nil  -- Enumerator of the nested Enumerable
     return self
     return self
end
end


function SelectManyEnumerator:moveNext()
function FlatMapIterator:moveNext()
    local state = self._state
    if state == -4 then
        return false
    end
 
    -- Setup _state
    if state == 0 then
        self._position = 0
        self._enumerator = self._source:getEnumerator(self._isArray)
        self._state = 3 state = 3 -- signal to get (_first) nested _enumerator
    end
     while true do
     while true do
        -- Setup state
         -- Grab next value from nested _enumerator
        if self.state == 0 then
         if state == 4 then
            self.position = 0
             local sourceEnum = self._sourceEnumerator
            self.enumerator = self.source:getEnumerator(self.isArray)
            if sourceEnum:moveNext() then
            self.state = -3 -- signal to get (first) nested enumerator
                 self.current = sourceEnum.current
        end
                 self.index = sourceEnum.index
 
                 self._state, state = 4, 4 -- signal to get next item
         -- Grab next value from nested enumerator
         if self.state == -4 then
             if self.sourceEnumerator:moveNext() then
                 self.current = self.sourceEnumerator.current
                 self.index = self.sourceEnumerator.index
                 self.state = -4 -- signal to get next item
                 return true
                 return true
             else
             else
                 self.state = -3 -- signal to get next enumerator
                -- Cleanup nested _enumerator
                self._sourceEnumerator:finalize()
                 self._state = 3 state = 3 -- signal to get next _enumerator
             end
             end
         end
         end
-- Grab nest nested enumerator
        if self.state == -3 then
            if self.enumerator:moveNext() then
                local current = self.enumerator.current
                self.position = self.position + 1


                 local sourceTable = self.selector(current, self.position)
-- Grab nest nested _enumerator
                 if not isType(sourceTable, Enumerator) then
        if state == 3 then
                -- We need to turn the nested table into an enumerator
            if self._enumerator:moveNext() then
                self.sourceEnumerator = TableEnumerator.new(sourceTable)
                 local current = self._enumerator.current
                :getEnumerator(self.isArray)
                local pos = self._position + 1
                 else
 
                self.sourceEnumerator = sourceTable
                 local sourceTable = self._selector(current, pos)
                end
                -- Nested tables are never treated as arrays.
                 self.state = -4 -- signal to get next item
                self._sourceEnumerator = getTableEnumerator(sourceTable, false)
                 self._position = pos
                 self._state = 4 state = 4 -- signal to get next item
             else
             else
             -- enumerator doesn't have any more nested enumerators.
             -- _enumerator doesn't have any more nested enumerators.
                self:finalize()
                 return false
                 return false
             end
             end
Line 256: Line 311:
end
end


function FlatMapIterator:finalize()
    if self._enumerator then
        self._enumerator:finalize()
    end
    if self._sourceEnumerator then
        self._sourceEnumerator:finalize()
    end


function SelectManyEnumerator:getEnumerator(isArray)
    Iterator.finalize(self)
local instance = (self.state == 0) and self or SelectManyEnumerator.new(self.source, self.selector)
end
    instance.isArray = isArray
 
    return instance
function FlatMapIterator:clone()
return FlatMapIterator.new(self._source, self._selector)
end
end


-- CONCAT ENUMERATOR --
--Concatenates two sequences.
local ConcatEnumerator = setmetatable({}, { __index = Enumerator })
---@class ConcatIterator : Iterator
ConcatEnumerator.__index = ConcatEnumerator
---@field _first Enumerable
ConcatEnumerator.__pairs = Enumerator_mt.__pairs
---@field _second any
ConcatEnumerator.__ipairs = Enumerator_mt.__ipairs
---@field _enumerator Enumerator
local ConcatIterator = Iterator.createIteratorSubclass()


function ConcatEnumerator.new(first, second)
---@param first Enumerable
---@param second any
function ConcatIterator.new(first, second)
assert(first, 'First cannot be nil')
assert(first, 'First cannot be nil')
assert(second, 'Second cannot be nil')
assert(second, 'Second cannot be nil')
     local self = setmetatable(Enumerator.new(), ConcatEnumerator)
     local self = setmetatable(Iterator.new(), ConcatIterator)
     self.state = 0
     self._first = first
    self.first = first
     self._second = second
     self.second = second
     self._enumerator = nil
     self.enumerator = nil
     return self
     return self
end
end


-- Function to grab enumerators in order. We currently only have two so this
-- Function to grab enumerators in order. We currently only have two so this
-- seems a bit redundant. But it would allow to add N-amount of enumerables.
-- seems a bit redundant. But it would allow to add N-amount of enumerables.
function ConcatEnumerator:getEnumerable(index)
function ConcatIterator:getEnumerable(index)
if index == 0 then  
if index == 0 then
return self.first
return self._first
elseif index == 1 then
elseif index == 1 then
return self.second
return self._second
else
else
return nil
return nil
Line 292: Line 357:
end
end


function ConcatEnumerator:moveNext()
function ConcatIterator:moveNext()
if self.state == 0 then
    local state = self._state
self.enumerator = self:getEnumerable(self.state)
    if state == -4 then
:getEnumerator(self.isArray)
        return false
self.state = 1
    end
 
if state == 0 then
self._enumerator = self:getEnumerable(state)
:getEnumerator(self._isArray)
self._state = 1 state = 1
end
end
 
if self.state > 0 then
if state > 0 then
while true do
while true do
if self.enumerator:moveNext() == true then
            local enum = self._enumerator
self.index = self.enumerator.index
if enum:moveNext() == true then
self.current = self.enumerator.current
self.index = enum.index
self.current = enum.current
return true
return true
end
end


self.state = self.state + 1
            state = state + 1
local next = self:getEnumerable(self.state - 1)
            self._state = state
local next = self:getEnumerable(state - 1)
if next ~= nil then
if next ~= nil then
self.enumerator = next:getEnumerator(self.isArray)
                -- Cleanup previous _enumerator
                self._enumerator:finalize()
self._enumerator = getTableEnumerator(next, self._isArray)
else
else
                self:finalize()
return false
return false
end
end
Line 320: Line 395:
end
end


function ConcatEnumerator:getEnumerator(isArray)
function ConcatIterator:finalize()
     local instance = (self.state == 0) and self or ConcatEnumerator.new(self.first, self.second)
    if self._enumerator then
     instance.isArray = isArray
        self._enumerator:finalize()
     return instance
    end
 
    Iterator.finalize(self)
end
 
function ConcatIterator:clone()
    return ConcatIterator.new(self._first, self._second)
end
 
--Appends a value to the end or start of the sequence.
---@class AppendIterator : Iterator
---@field _source Enumerable
---@field _item any
---@field _itemIndex any
---@field _append boolean
---@field _enumerator Enumerator
local AppendIterator = Iterator.createIteratorSubclass()
 
---@param source Enumerable
---@param item any
---@param itemIndex? any
---@param append? boolean
function AppendIterator.new(source, item, itemIndex, append)
assert(source, 'Source cannot be nil')
assert(item, 'Item cannot be nil')
    local self = setmetatable(Iterator.new(), AppendIterator)
    self._source = source
    self._item = item
    -- Index needs *some* value, otherwise iteration stops.
    if itemIndex == nil then self._itemIndex = item else self._itemIndex = itemIndex end
    if append == nil then self._append = true else self._append = append end
    self._enumerator = nil
    return self
end
 
function AppendIterator:moveNext()
    local state = self._state
    if state == 0 then
        self._state = 1 state = 1
        -- Put the item in front, as a prepend.
        if self._append == false then
            self.current = self._item
            self.index = self._itemIndex
            return true
        end
    end
 
    -- Grab the _source _enumerator
    if state == 1 then
        self._enumerator = self._source:getEnumerator(self._isArray)
        self._state = 2 state = 2
    end
 
    if state == 2 then
        local enum = self._enumerator
        if enum:moveNext() then
            self.current = enum.current
            self.index = enum.index
            return true
        else
            self:finalize()
        end
 
        if self._append == true then
            self.current = self._item
            self.index = self._itemIndex
            return true
        end
    end
 
    return false
end
 
function AppendIterator:finalize()
    if self._enumerator then
        self._enumerator:finalize()
    end
 
    Iterator.finalize(self)
end
 
function AppendIterator:clone()
    return AppendIterator.new(self._source, self._item, self._itemIndex, self._append)
end
 
--Returns unique (distinct) elements from a sequence according to a specified key selector function.
---@class UniqueIterator : Iterator
---@field _source Enumerable
---@field _keySelector fun(value: any, index: any): any
---@field _enumerator Enumerator
---@field _set table
local UniqueIterator = Iterator.createIteratorSubclass()
 
---@param source Enumerable
---@param keySelector? fun(value: any, index: any): any
function UniqueIterator.new(source, keySelector)
assert(source, 'Source cannot be nil')
    local self = setmetatable(Iterator.new(), UniqueIterator)
    self._source = source
    self._keySelector = keySelector or function(x) return x end
    self._enumerator = nil
    self._set = nil
    return self
end
 
function UniqueIterator:moveNext()
    local state = self._state
    if state == -4 then
        return false
    end
 
    if state == 0 then
        self._enumerator = self._source:getEnumerator(self._isArray)
        -- If we have any items, create a hashtable. Otherwise abort.
        if self._enumerator:moveNext() == true then
            self._set = {}
            self._state = 1 state = 1
        else
            self:finalize()
            return false
        end
    end
 
    local enum = self._enumerator
    local set = self._set
    while true do
        if state == 1 then
            self._state = 2 state = 2
            local current = enum.current
            local index = enum.index
            local key = self._keySelector(current, index)
            if setAdd(key) == true then
                self.current = current
                self.index = index
                return true
            end
        end
 
        -- Try to grab a new item.
        if state == 2 then
            if enum:moveNext() == true then
                self._state = 1 state = 1
            else
                self:finalize()
                return false
            end
        end
    end
end
 
function UniqueIterator:finalize()
    if self._enumerator then
        self._enumerator:finalize()
    end
 
    Iterator.finalize(self)
end
 
function UniqueIterator:clone()
    return UniqueIterator.new(self._source, self._keySelector)
end
 
--Produces the set difference of two sequences according to a specified key selector function.
---@class DifferenceIterator : Iterator
---@field _first Enumerable
---@field _second any
---@field _keySelector fun(value: any, index: any): any
---@field _enumerator Enumerator
---@field _set table
local DifferenceIterator = Iterator.createIteratorSubclass()
 
---@param first Enumerable
---@param second any
---@param keySelector? fun(value: any, index: any): any
function DifferenceIterator.new(first, second, keySelector)
assert(first, 'First table cannot be nil')
    assert(second, 'Second table cannot be nil')
    local self = setmetatable(Iterator.new(), DifferenceIterator)
    self._first = first
    self._second = second
    self._keySelector = keySelector or function(x) return x end
    self._enumerator = nil
    self._set = nil
    return self
end
 
function DifferenceIterator:moveNext()
     local state = self._state
    if state ~= 0 then
        if state ~= 1 then
            return false
        end
    else
        self._enumerator = self._first:getEnumerator(self._isArray)
        self._set = sourceToSet(self._second)
        self._state = 1 state = 1
    end
 
    local enum = self._enumerator
    local nextFunc = enum.moveNext
    local keySelector = self._keySelector
    local set = self._set
    while nextFunc(enum) == true do
        local current = enum.current
        local index = enum.index
        local key = keySelector(current, index)
        if setAdd(key) == true then
            self.current = current
            self.index = index
            return true
        end
    end
 
    self:finalize()
    return false
end
 
function DifferenceIterator:finalize()
    if self._enumerator then
        self._enumerator:finalize()
    end
end
 
function DifferenceIterator:clone()
    return DifferenceIterator.new(self._first, self._second, self._keySelector)
end
 
--Produces the set union of two sequences according to a specified key selector function.
---@class UnionIterator : Iterator
---@field _first Enumerable
---@field _second any
---@field _keySelector fun(value: any, index: any): any
---@field _enumerator Enumerator
---@field _set table
local UnionIterator = Iterator.createIteratorSubclass()
 
---@param first Enumerable
---@param second any
---@param keySelector? fun(value: any, index: any): any
function UnionIterator.new(first, second, keySelector)
assert(first, 'First table cannot be nil')
    assert(second, 'Second table cannot be nil')
    local self = setmetatable(Iterator.new(), UnionIterator)
    self._first = first
    self._second = second
    self._keySelector = keySelector or function(x) return x end
    self._enumerator = nil
    self._set = nil
    return self
end
 
local function enumerateSource(self, _state)
    while self._enumerator:moveNext() do
        local current = self._enumerator.current
        local index = self._enumerator.index
        if self._set:add(self._keySelector(current, index)) == true then
            self.current = self._enumerator.current
            self.index = self._enumerator.index
            self._state = _state
            return true
        end
    end
end
 
function UnionIterator:moveNext()
    if self._state == -4 then
        return false
    end
 
    if self._state == 0 then
        self._enumerator = self._first:getEnumerator(self._isArray)
        self._set = {}
        self._state = 1
    end
 
    -- Process _first
    if self._state == 1 then
        if enumerateSource(self, 1) == true then
            return true
        end
 
        -- End _first enumeration
        self._state = 2
        self._enumerator:finalize()
        self._enumerator = getTableEnumerator(self._second, self._isArray)
    end
 
    -- Process _second
    if self._state == 2 then
        if enumerateSource(self, 2) == true then
            return true
        end
 
        -- End _first enumeration
        self._state = 3
    end
 
    self:finalize()
    return false
end
 
function UnionIterator:finalize()
    if self._enumerator then
        self._enumerator:finalize()
    end
 
    Iterator.finalize(self)
end
 
function UnionIterator:clone()
    return UnionIterator.new(self._first, self._second, self._keySelector)
end
 
--Produces the set intersection of two sequences according to a specified key selector function.
---@class IntersectIterator : Iterator
---@field _first Enumerable
---@field _second any
---@field _keySelector fun(value: any, index: any): any
---@field _enumerator Enumerator
---@field _set table
local IntersectIterator = Iterator.createIteratorSubclass()
 
---@param first Enumerable
---@param second any
---@param keySelector? fun(value: any, index: any): any
function IntersectIterator.new(first, second, keySelector)
assert(first, 'First table cannot be nil')
    assert(second, 'Second table cannot be nil')
    local self = setmetatable(Iterator.new(), IntersectIterator)
    self._first = first
    self._second = second
    self._keySelector = keySelector or function(x) return x end
    self._enumerator = nil
    self._set = nil
    return self
end
 
function IntersectIterator:moveNext()
    if self._state ~= 0 then
        if self._state ~= 1 then
            return false
        end
    else
        self._enumerator = self._first:getEnumerator(self._isArray)
        self._set = sourceToSet(self._second)
        self._state = 1
    end
    while self._enumerator:moveNext() do
        local current = self._enumerator.current
        local index = self._enumerator.index
        local key = self._keySelector(current, index)
        if setRemove(self._set, key) == true then
            self.current = self._enumerator.current
            self.index = self._enumerator.index
            self._state = 1
            return true
        end
    end
 
    self:finalize()
    return false
end
 
function IntersectIterator:finalize()
    if self._enumerator then
        self._enumerator:finalize()
    end
 
    Iterator.finalize(self)
end
 
function IntersectIterator:clone()
    return IntersectIterator.new(self._first, self._second, self._keySelector)
end
 
--Applies a specified function to the corresponding elements of two sequences, producing a sequence of the results.
---@class ZipIterator : Iterator
---@field _first Enumerable
---@field _second any
---@field _resultSelector fun(left: any, right: any): any
---@field _enumerator1 Enumerator
---@field _enumerator2 Enumerator
local ZipIterator = Iterator.createIteratorSubclass()
 
---@param first Enumerable
---@param second any
---@param resultSelector? fun(left: any, right: any): any
function ZipIterator.new(first, second, resultSelector)
assert(first, 'First table cannot be nil')
    assert(second, 'Second table cannot be nil')
    local self = setmetatable(Iterator.new(), ZipIterator)
    self._first = first
    self._second = second
    self._resultSelector = resultSelector or (function(a, b) return {a, b} end)
    self._enumerator1 = nil
    self._enumerator2 = nil
    return self
end
 
function ZipIterator:moveNext()
    if self._state ~= 0 then
        if self._state ~= 1 then
            return false
        end
    else
        self._enumerator1 = self._first:getEnumerator(self._isArray)
        self._enumerator2 = getTableEnumerator(self._second, self._isArray)
        self._state = -4
    end
    if self._enumerator1:moveNext() == true and self._enumerator2:moveNext() == true then
        self.current = self._resultSelector(self._enumerator1.current, self._enumerator2.current)
        self.index = self._enumerator1.index
        self._state = 1
        return true
    end
 
    self:finalize()
    return false
end
 
function ZipIterator:finalize()
    if self._enumerator1 then
        self._enumerator1:finalize()
    end
    if self._enumerator2 then
        self._enumerator2:finalize()
    end
 
    Iterator.finalize(self)
end
 
function ZipIterator:clone()
    return ZipIterator.new(self._first, self._second, self._resultSelector)
end
 
--Groups the elements of a sequence.
---@class GroupByIterator : Iterator
---@field _source Enumerable
---@field _keySelector fun(param: any): any
---@field _elementSelector fun(param: any): any
---@field _enumerator Enumerator
---@field _lookup Lookup
local GroupByIterator = Iterator.createIteratorSubclass()
 
---@param source Enumerable
---@param keySelector fun(param: any): any
---@param elementSelector? fun(param: any): any
function GroupByIterator.new(source, keySelector, elementSelector)
assert(source, 'source cannot be nil')
    assert(keySelector, 'keySelector cannot be nil')
    local self = setmetatable(Iterator.new(), GroupByIterator)
    self._source = source
    self._keySelector = keySelector
    self._elementSelector = elementSelector or function(x) return x end
    self._enumerator = nil
    self._lookup = nil
    return self
end
 
function GroupByIterator:moveNext()
    if self._state ~= 0 then
        if self._state ~= 1 then
            return false
        end
    else
        self._lookup = Lookup.new(self._source, self._keySelector, self._elementSelector, self._isArray)
        self._enumerator = self._lookup:getEnumerator()
        self._state = 1
    end
 
    local enum = self._enumerator
    if enum:moveNext() == true then
        self.current = enum.current
        self.index = enum.index
        self._state = 1
        return true
    end
 
    self:finalize()
    return false
end
 
function GroupByIterator:finalize()
    if self._enumerator then
        self._enumerator:finalize()
    end
 
    Iterator.finalize(self)
end
 
function GroupByIterator:clone()
    return GroupByIterator.new(self._source, self._keySelector, self._elementSelector)
end
 
--Groups the elements of a sequence according to a specified key selector function and creates a result value from each group and its key. The elements of each group are projected by using a specified function.
---@class GroupByResultIterator : Iterator
---@field _source Enumerable
---@field _keySelector fun(param: any): any
---@field _elementSelector fun(param: any): any
---@field _resultSelector fun(key: any, grouping: Grouping): any
---@field _enumerator Enumerator
---@field _lookup Lookup
local GroupByResultIterator = Iterator.createIteratorSubclass()
 
---@param source Enumerable
---@param keySelector fun(param: any): any
---@param elementSelector? fun(param: any): any
---@param resultSelector fun(key: any, grouping: Grouping): any
function GroupByResultIterator.new(source, keySelector, elementSelector, resultSelector)
assert(source, 'source cannot be nil')
    assert(keySelector, 'keySelector cannot be nil')
    assert(resultSelector, 'keySelector cannot be nil')
    local self = setmetatable(Iterator.new(), GroupByResultIterator)
    self._source = source
    self._keySelector = keySelector
    self._elementSelector = elementSelector or function(x) return x end
    self._resultSelector = resultSelector
    self._enumerator = nil
    self._lookup = nil
    return self
end
 
function GroupByResultIterator:moveNext()
    if self._state ~= 0 then
        if self._state ~= 1 then
            return false
        end
    else
        self._lookup = Lookup.new(self._source, self._keySelector, self._elementSelector, self._isArray)
        self._enumerator = self._lookup:getEnumerator()
        self._state = 1
    end
 
    local enumerator = self._enumerator
    if enumerator:moveNext() == true then
        -- Transform Grouping element for enumeration
        local gKey = enumerator.current.key
        local gElements = enumerator.current
        self.current = self._resultSelector(gKey, gElements)
        self.index = enumerator.index
        self._state = 1
        return true
    end
 
    self:finalize()
    return false
end
 
function GroupByResultIterator:finalize()
    if self._enumerator then
        self._enumerator:finalize()
    end
 
    Iterator.finalize(self)
end
 
function GroupByResultIterator:clone()
    return GroupByResultIterator.new(self._source, self._keySelector, self._elementSelector, self._resultSelector)
end
 
 
--Groups the elements of a sequence according to a specified key selector function and creates a result value from each group and its key. The elements of each group are projected by using a specified function.
---@class SortableIterator : Iterator
---@field _source Enumerable
---@field _keySelector fun(value: any, index: any): any|?
---@field _descendingSort boolean
---@field _parentIterator SortableIterator
---@field _enumerator Enumerator
local SortableIterator = Iterator.createIteratorSubclass()
 
---@param source Enumerable
---@param keySelector fun(value: any, index: any): any
---@param descendingSort? boolean
---@param parent? SortableIterator
---@return SortableIterator
function SortableIterator.new(source, keySelector, descendingSort, parent)
assert(source, 'source cannot be nil')
    assert(keySelector, 'keySelector cannot be nil')
    assert(parent == nil or parent.getEnumerableSorters ~= nil, 'parent must be of type SortableIterator')
    if descendingSort == nil then
        descendingSort = false
    end
 
    local self = setmetatable(Iterator.new(), SortableIterator)
    self._source = source
    self._keySelector = keySelector
    self._descendingSort = descendingSort
    self._parentIterator = parent
    self._enumerator = nil
---@diagnostic disable-next-line: return-type-mismatch
    return self
end
 
local function createSortFunc(sorters)
    local sortCount = #sorters
    if sortCount > 1 then
        return function(a, b)
            for _, sorter in ipairs(sorters) do
                local selector = sorter._keySelector
                local descendingSort = sorter._descendingSort
 
                local aVal = selector(a)
                local bVal = selector(b)
                if aVal ~= bVal then
                    if descendingSort == true then
                        return aVal > bVal
                    else
                        return aVal < bVal
                    end
                end
            end
        end
    end
 
    -- Faster sort function if there's no nested sorts.
    local sorter = sorters[1]
    local selector = sorter._keySelector
    if sorter._descendingSort == true then
        return function(a, b) return selector(a) > selector(b) end
    else
        return function(a, b) return selector(a) < selector(b) end
    end
end
 
function SortableIterator:moveNext()
    if self._state ~= 0 then
        if self._state ~= 1 then
            return false
        end
    else
        -- Collect sorter meta information.
        local sorters = {}
        self:getEnumerableSorters(sorters)
        local sortFunc = createSortFunc(sorters)
 
        local enumerator = self._source:getEnumerator(self._isArray)
        local buffer = {}
        -- Process and sort all previous enumerators.
        while enumerator:moveNext() == true do
            table.insert(buffer, enumerator.current)
        end
        table.sort(buffer, sortFunc)
        self._enumerator = TableEnumerator.createForArray(buffer)
        self._state = 1
    end
 
    local enumerator = self._enumerator
    if enumerator:moveNext() == true then
        self.current = enumerator.current
        self.index = enumerator.index
        self._state = 1
        return true
    end
 
    self:finalize()
    return false
end
 
---Collects all sorter functions down the line.
---@param sorters table
---@return table
function SortableIterator:getEnumerableSorters(sorters)
    -- Get parent sorter first.
    if self._parentIterator ~= nil then
        self._parentIterator:getEnumerableSorters(sorters)
    end
    table.insert(sorters, self)
 
    return sorters
end
 
---@param keySelector fun(value: any, index: any): any
---@param descendingSort? boolean
---@return SortableIterator
function SortableIterator:createSortableIterator(keySelector, descendingSort)
    return SortableIterator.new(self._source, keySelector, descendingSort, self)
end
 
function SortableIterator:finalize()
     if self._enumerator then
        self._enumerator:finalize()
        self._enumerator = nil
    end
    Iterator.finalize(self)
end
 
function SortableIterator:clone()
     return SortableIterator.new(self._source, self._keySelector, self._descendingSort, self._parentIterator)
end
end


return {
return {
     Enumerator = Enumerator,
     MapIterator = MapIterator,
     TableEnumerator = TableEnumerator,
     WhereIterator = WhereIterator,
     SelectEnumerator = SelectEnumerator,
     FlatMapIterator = FlatMapIterator,
     WhereEnumerator = WhereEnumerator,
     ConcatIterator = ConcatIterator,
     SelectManyEnumerator = SelectManyEnumerator,
     AppendIterator = AppendIterator,
     ConcatEnumerator = ConcatEnumerator
     UniqueIterator = UniqueIterator,
    DifferenceIterator = DifferenceIterator,
    UnionIterator = UnionIterator,
    IntersectIterator = IntersectIterator,
    ZipIterator = ZipIterator,
    GroupByIterator = GroupByIterator,
    GroupByResultIterator = GroupByResultIterator,
    SortableIterator = SortableIterator,
}
}

Latest revision as of 20:46, 26 July 2024

Documentation for this module may be created at Module:FunList/Iterators/doc

local Enumerable = require('Module:FunList/Enumerable')
local Lookup = require('Module:FunList/Lookup')
local TableEnumerator = require('Module:FunList/TableEnumerator')

---Creates a TableIterator Enumerable or returns the Enumerable if it already is one.
---We set the propper starting index here since this only applies for tables!
---@param source any
---@param isArray any
---@return Enumerator
local function getTableEnumerator(source, isArray)
    if source.getEnumerator == nil then
        assert(type(source) == 'table', 'sourceTable must be either an Enumerable or table.')
        return TableEnumerator.create(source, isArray)
    else
        return source:getEnumerator(isArray)
    end
end

-- Helper functions to create and manage a hashset structure.
-- These are used to distinguish elements for Union/Unique/Difference etc operations.
---@param source any
---@return table
local function sourceToSet(source)
    local set = {}
    if Enumerable.isEnumerable(source) then
        local enum = source:getEnumerator()
        while enum:moveNext() do
            set[enum.current] = true
        end
    else
        assert(type(source) == 'table', 'Source must be a table.')
        for _, v in pairs(source) do
            set[v] = true
        end
    end
    return set
end

---@param set table
---@param item any
---@return boolean
local function setAdd(set, item)
    if set[item] == nil then
        set[item] = true
        return true
    end
    return false
end

---@param set table
---@param item any
---@return boolean
local function setRemove(set, item)
    if set[item] == nil then
        return false
    end
    set[item] = nil
    return true
end

-- Defines the base class for all Iterator state machines. This is an Enumerable and Enumerator in one.
---@class Iterator : Enumerable, Enumerator
---@field current any
---@field index any
---@field _state integer
---@field _isArray boolean
local Iterator = setmetatable({}, { __index = Enumerable })
local Iterator_mt = {
	__index = Iterator,
	__pairs = Enumerable.getPairs,
	__ipairs = Enumerable.getiPairs
}

---@return Iterator
function Iterator.new()
	local self = setmetatable({}, Iterator_mt)
	self.current = nil
	self.index = nil
    self._state = -1
	-- Assume by default we are not dealing with a simple array
	self._isArray = false
	return self
end

-- This is normally implemented by the Enumerator abstract class.
-- But to prevent double inheritance, we cheat a bit and implement it outselves.
---@return boolean
function Iterator:moveNext()
    error('Abstract function must be overridden in derived class.')
end

---Returns this Enumerator or creates a new one if it has been used.
---@param isArray? boolean
---@return Enumerator
function Iterator:getEnumerator(isArray)
    -- The default _state is -1 which signifies a Iterator isn't used.
    local instance = (self._state == -1) and self or self:clone()
    instance._isArray = isArray or self._isArray
    instance._state = 0
    return instance
end

---@return Iterator
function Iterator:clone()
    error('Abstract function must be overridden in derived class.')
end

function Iterator:finalize()
    -- Signals invalid _state.
    self._state = -4
end

function Iterator.createIteratorSubclass()
    local derivedClass = setmetatable({}, { __index = Iterator })
    derivedClass.__index = derivedClass
    derivedClass.__pairs  = Enumerable.getPairs
    derivedClass.__ipairs = Enumerable.getiPairs
    return derivedClass
end

--Projects each element of a sequence into a new form.
---@class MapIterator : Iterator
---@field _source Enumerable
---@field _selector fun(value: any, index: any): any
---@field _position integer
---@field _enumerator Enumerator
local MapIterator = Iterator.createIteratorSubclass()

---@param source Enumerable
---@param selector fun(value: any, index: any): any
function MapIterator.new(source, selector)
	assert(source, 'Source cannot be nil')
	assert(selector, 'Selector cannot be nil')
    local self = setmetatable(Iterator.new(), MapIterator)
    self._source = source
    self._selector = selector
    self._enumerator = nil
    self._position = 0
    return self
end

function MapIterator:moveNext()
    local state = self._state
    if state ~= 0 then
        if state ~= 1 then
            return false
        end
    else
        self._state = 1 state = 1
     	self._position = 0
		self._enumerator = self._source:getEnumerator(self._isArray)
    end

    local _enumerator = self._enumerator
    if _enumerator:moveNext() == true then
        -- 1 based index because Lua
        local pos = self._position + 1
        local index = _enumerator.index
        local current = self._selector(_enumerator.current, pos)
        assert(current, 'Selected value must be non-nil')
        self.index = index
        self.current = current
        self._position = pos
        return true
    end
    self:finalize()

	return false
end

function MapIterator:finalize()
    if self._enumerator then
        self._enumerator:finalize()
    end
    Iterator.finalize(self)
end

function MapIterator:clone()
    return MapIterator.new(self._source, self._selector)
end

--Filters a sequence of values based on a predicate.
---@class WhereIterator : Iterator
---@field _source Enumerable
---@field _predicate fun(value: any, index: any): boolean
---@field _enumerator Enumerator
local WhereIterator = Iterator.createIteratorSubclass()

---@param source Enumerable
---@param predicate fun(value: any, index: any): boolean
function WhereIterator.new(source, predicate)
	assert(source, 'Source cannot be nil')
	assert(predicate, 'Predicate cannot be nil')
    local self = setmetatable(Iterator.new(), WhereIterator)
    self._source = source
    self._predicate = predicate
    self._enumerator = nil
    return self
end

function WhereIterator:moveNext()
    local state = self._state
    if state ~= 0 then
        if state ~= 1 then
            return false
        end
    else
		self._state = 1 state = 1
		self._enumerator = self._source:getEnumerator(self._isArray)
    end

    -- Cache some class lookups since these are likely
    -- to be accessed more than once per moveNext
    local enum = self._enumerator
    local nextFunc = enum.moveNext
    local predicate = self._predicate
    while nextFunc(enum) == true do
        local sourceElement = enum.current
        local sourceIndex = enum.index
        if predicate(sourceElement, sourceIndex) == true then
            self.index = sourceIndex
            self.current = sourceElement
            return true
        end
    end
    self:finalize()

	return false
end

function WhereIterator:finalize()
    if self._enumerator then
        self._enumerator:finalize()
    end
    Iterator.finalize(self)
end

function WhereIterator:clone()
    return WhereIterator.new(self._source, self._predicate)
end

--Projects each element of a sequence to an Enumerable and flattens the resulting sequences into one sequence.
---@class FlatMapIterator : Iterator
---@field _source Enumerable
---@field _selector fun(value: any, index: integer): any
---@field _position integer
---@field _enumerator Enumerator
---@field _sourceEnumerator Enumerator
local FlatMapIterator = Iterator.createIteratorSubclass()

---@param source Enumerable
---@param selector fun(value: any, index: integer): any
function FlatMapIterator.new(source, selector)
	assert(source, 'Source cannot be nil')
	assert(selector, 'Selector cannot be nil')
    local self = setmetatable(Iterator.new(), FlatMapIterator)
    self._source = source
    self._selector = selector
    self._position = 0
    self._enumerator = nil		 -- Iterator of the _source Enumerable
    self._sourceEnumerator = nil  -- Iterator of the nested Enumerable
    return self
end

function FlatMapIterator:moveNext()
    local state = self._state
    if state == -4 then
        return false
    end

    -- Setup _state
    if state == 0 then
        self._position = 0
        self._enumerator = self._source:getEnumerator(self._isArray)
        self._state = 3 state = 3 -- signal to get (_first) nested _enumerator
    end
    while true do
        -- Grab next value from nested _enumerator		
        if state == 4 then
            local sourceEnum = self._sourceEnumerator
            if sourceEnum:moveNext() then
                self.current = sourceEnum.current
                self.index = sourceEnum.index
                self._state, state = 4, 4 -- signal to get next item
                return true
            else
                -- Cleanup nested _enumerator
                self._sourceEnumerator:finalize()
                self._state = 3 state = 3 -- signal to get next _enumerator
            end
        end

		-- Grab nest nested _enumerator
        if state == 3 then
            if self._enumerator:moveNext() then
                local current = self._enumerator.current
                local pos = self._position + 1

                local sourceTable = self._selector(current, pos)
                -- Nested tables are never treated as arrays.
                self._sourceEnumerator = getTableEnumerator(sourceTable, false)
                self._position = pos
                self._state = 4 state = 4 -- signal to get next item
            else
            	-- _enumerator doesn't have any more nested enumerators.
                self:finalize()
                return false
            end
        end
    end
end

function FlatMapIterator:finalize()
    if self._enumerator then
        self._enumerator:finalize()
    end
    if self._sourceEnumerator then
        self._sourceEnumerator:finalize()
    end

    Iterator.finalize(self)
end

function FlatMapIterator:clone()
	return FlatMapIterator.new(self._source, self._selector)
end

--Concatenates two sequences.
---@class ConcatIterator : Iterator
---@field _first Enumerable
---@field _second any
---@field _enumerator Enumerator
local ConcatIterator = Iterator.createIteratorSubclass()

---@param first Enumerable
---@param second any
function ConcatIterator.new(first, second)
	assert(first, 'First cannot be nil')
	assert(second, 'Second cannot be nil')
    local self = setmetatable(Iterator.new(), ConcatIterator)
    self._first = first
    self._second = second
    self._enumerator = nil
    return self
end

-- Function to grab enumerators in order. We currently only have two so this
-- seems a bit redundant. But it would allow to add N-amount of enumerables.
function ConcatIterator:getEnumerable(index)
	if index == 0 then
		return self._first
	elseif index == 1 then
		return self._second
	else
		return nil
	end
end

function ConcatIterator:moveNext()
    local state = self._state
    if state == -4 then
        return false
    end

	if state == 0 then
		self._enumerator = self:getEnumerable(state)
			:getEnumerator(self._isArray)
		self._state = 1 state = 1
	end

	if state > 0 then
		while true do
            local enum = self._enumerator
			if enum:moveNext() == true then
				self.index = enum.index
				self.current = enum.current
				return true
			end

            state = state + 1
            self._state = state
			local next = self:getEnumerable(state - 1)
			if next ~= nil then
                -- Cleanup previous _enumerator
                self._enumerator:finalize()
				self._enumerator = getTableEnumerator(next, self._isArray)
			else
                self:finalize()
				return false
			end
		end
	end

	return false
end

function ConcatIterator:finalize()
    if self._enumerator then
        self._enumerator:finalize()
    end

    Iterator.finalize(self)
end

function ConcatIterator:clone()
    return ConcatIterator.new(self._first, self._second)
end

--Appends a value to the end or start of the sequence.
---@class AppendIterator : Iterator
---@field _source Enumerable
---@field _item any
---@field _itemIndex any
---@field _append boolean
---@field _enumerator Enumerator
local AppendIterator = Iterator.createIteratorSubclass()

---@param source Enumerable
---@param item any
---@param itemIndex? any
---@param append? boolean
function AppendIterator.new(source, item, itemIndex, append)
	assert(source, 'Source cannot be nil')
	assert(item, 'Item cannot be nil')
    local self = setmetatable(Iterator.new(), AppendIterator)
    self._source = source
    self._item = item
    -- Index needs *some* value, otherwise iteration stops.
    if itemIndex == nil then self._itemIndex = item else self._itemIndex = itemIndex end
    if append == nil then self._append = true else self._append = append end
    self._enumerator = nil
    return self
end

function AppendIterator:moveNext()
    local state = self._state
    if state == 0 then
        self._state = 1 state = 1
        -- Put the item in front, as a prepend.
        if self._append == false then
            self.current = self._item
            self.index = self._itemIndex
            return true
        end
    end

    -- Grab the _source _enumerator
    if state == 1 then
        self._enumerator = self._source:getEnumerator(self._isArray)
        self._state = 2 state = 2
    end

    if state == 2 then
        local enum = self._enumerator
        if enum:moveNext() then
            self.current = enum.current
            self.index = enum.index
            return true
        else
            self:finalize()
        end

        if self._append == true then
            self.current = self._item
            self.index = self._itemIndex
            return true
        end
    end

    return false
end

function AppendIterator:finalize()
    if self._enumerator then
        self._enumerator:finalize()
    end

    Iterator.finalize(self)
end

function AppendIterator:clone()
    return AppendIterator.new(self._source, self._item, self._itemIndex, self._append)
end

--Returns unique (distinct) elements from a sequence according to a specified key selector function. 
---@class UniqueIterator : Iterator
---@field _source Enumerable
---@field _keySelector fun(value: any, index: any): any
---@field _enumerator Enumerator
---@field _set table
local UniqueIterator = Iterator.createIteratorSubclass()

---@param source Enumerable
---@param keySelector? fun(value: any, index: any): any
function UniqueIterator.new(source, keySelector)
	assert(source, 'Source cannot be nil')
    local self = setmetatable(Iterator.new(), UniqueIterator)
    self._source = source
    self._keySelector = keySelector or function(x) return x end
    self._enumerator = nil
    self._set = nil
    return self
end

function UniqueIterator:moveNext()
    local state = self._state
    if state == -4 then
        return false
    end

    if state == 0 then
        self._enumerator = self._source:getEnumerator(self._isArray)
        -- If we have any items, create a hashtable. Otherwise abort.
        if self._enumerator:moveNext() == true then
            self._set = {}
            self._state = 1 state = 1
        else
            self:finalize()
            return false
        end
    end

    local enum = self._enumerator
    local set = self._set
    while true do
        if state == 1 then
            self._state = 2 state = 2
            local current = enum.current
            local index = enum.index
            local key = self._keySelector(current, index)
            if setAdd(key) == true then
                self.current = current
                self.index = index
                return true
            end
        end

        -- Try to grab a new item.
        if state == 2 then
            if enum:moveNext() == true then
                self._state = 1 state = 1
            else
                self:finalize()
                return false
            end
        end
    end
end

function UniqueIterator:finalize()
    if self._enumerator then
        self._enumerator:finalize()
    end

    Iterator.finalize(self)
end

function UniqueIterator:clone()
    return UniqueIterator.new(self._source, self._keySelector)
end

--Produces the set difference of two sequences according to a specified key selector function.
---@class DifferenceIterator : Iterator
---@field _first Enumerable
---@field _second any
---@field _keySelector fun(value: any, index: any): any
---@field _enumerator Enumerator
---@field _set table
local DifferenceIterator = Iterator.createIteratorSubclass()

---@param first Enumerable
---@param second any
---@param keySelector? fun(value: any, index: any): any
function DifferenceIterator.new(first, second, keySelector)
	assert(first, 'First table cannot be nil')
    assert(second, 'Second table cannot be nil')
    local self = setmetatable(Iterator.new(), DifferenceIterator)
    self._first = first
    self._second = second
    self._keySelector = keySelector or function(x) return x end
    self._enumerator = nil
    self._set = nil
    return self
end

function DifferenceIterator:moveNext()
    local state = self._state
    if state ~= 0 then
        if state ~= 1 then
            return false
        end
    else
        self._enumerator = self._first:getEnumerator(self._isArray)
        self._set = sourceToSet(self._second)
        self._state = 1 state = 1
    end

    local enum = self._enumerator
    local nextFunc = enum.moveNext
    local keySelector = self._keySelector
    local set = self._set
    while nextFunc(enum) == true do
        local current = enum.current
        local index = enum.index
        local key = keySelector(current, index)
        if setAdd(key) == true then
            self.current = current
            self.index = index
            return true
        end
    end

    self:finalize()
    return false
end

function DifferenceIterator:finalize()
    if self._enumerator then
        self._enumerator:finalize()
    end
end

function DifferenceIterator:clone()
    return DifferenceIterator.new(self._first, self._second, self._keySelector)
end

--Produces the set union of two sequences according to a specified key selector function.
---@class UnionIterator : Iterator
---@field _first Enumerable
---@field _second any
---@field _keySelector fun(value: any, index: any): any
---@field _enumerator Enumerator
---@field _set table
local UnionIterator = Iterator.createIteratorSubclass()

---@param first Enumerable
---@param second any
---@param keySelector? fun(value: any, index: any): any
function UnionIterator.new(first, second, keySelector)
	assert(first, 'First table cannot be nil')
    assert(second, 'Second table cannot be nil')
    local self = setmetatable(Iterator.new(), UnionIterator)
    self._first = first
    self._second = second
    self._keySelector = keySelector or function(x) return x end
    self._enumerator = nil
    self._set = nil
    return self
end

local function enumerateSource(self, _state)
    while self._enumerator:moveNext() do
        local current = self._enumerator.current
        local index = self._enumerator.index
        if self._set:add(self._keySelector(current, index)) == true then
            self.current = self._enumerator.current
            self.index = self._enumerator.index
            self._state = _state
            return true
        end
    end
end

function UnionIterator:moveNext()
    if self._state == -4 then
        return false
    end

    if self._state == 0 then
        self._enumerator = self._first:getEnumerator(self._isArray)
        self._set = {}
        self._state = 1
    end

    -- Process _first
    if self._state == 1 then
        if enumerateSource(self, 1) == true then
            return true
        end

        -- End _first enumeration
        self._state = 2
        self._enumerator:finalize()
        self._enumerator = getTableEnumerator(self._second, self._isArray)
    end

    -- Process _second
    if self._state == 2 then
        if enumerateSource(self, 2) == true then
            return true
        end

        -- End _first enumeration
        self._state = 3
    end

    self:finalize()
    return false
end

function UnionIterator:finalize()
    if self._enumerator then
        self._enumerator:finalize()
    end

    Iterator.finalize(self)
end

function UnionIterator:clone()
    return UnionIterator.new(self._first, self._second, self._keySelector)
end

--Produces the set intersection of two sequences according to a specified key selector function.
---@class IntersectIterator : Iterator
---@field _first Enumerable
---@field _second any
---@field _keySelector fun(value: any, index: any): any
---@field _enumerator Enumerator
---@field _set table
local IntersectIterator = Iterator.createIteratorSubclass()

---@param first Enumerable
---@param second any
---@param keySelector? fun(value: any, index: any): any
function IntersectIterator.new(first, second, keySelector)
	assert(first, 'First table cannot be nil')
    assert(second, 'Second table cannot be nil')
    local self = setmetatable(Iterator.new(), IntersectIterator)
    self._first = first
    self._second = second
    self._keySelector = keySelector or function(x) return x end
    self._enumerator = nil
    self._set = nil
    return self
end

function IntersectIterator:moveNext()
    if self._state ~= 0 then
        if self._state ~= 1 then
            return false
        end
    else
        self._enumerator = self._first:getEnumerator(self._isArray)
        self._set = sourceToSet(self._second)
        self._state = 1
    end
    while self._enumerator:moveNext() do
        local current = self._enumerator.current
        local index = self._enumerator.index
        local key = self._keySelector(current, index)
        if setRemove(self._set, key) == true then
            self.current = self._enumerator.current
            self.index = self._enumerator.index
            self._state = 1
            return true
        end
    end

    self:finalize()
    return false
end

function IntersectIterator:finalize()
    if self._enumerator then
        self._enumerator:finalize()
    end

    Iterator.finalize(self)
end

function IntersectIterator:clone()
    return IntersectIterator.new(self._first, self._second, self._keySelector)
end

--Applies a specified function to the corresponding elements of two sequences, producing a sequence of the results.
---@class ZipIterator : Iterator
---@field _first Enumerable
---@field _second any
---@field _resultSelector fun(left: any, right: any): any
---@field _enumerator1 Enumerator
---@field _enumerator2 Enumerator
local ZipIterator = Iterator.createIteratorSubclass()

---@param first Enumerable
---@param second any
---@param resultSelector? fun(left: any, right: any): any
function ZipIterator.new(first, second, resultSelector)
	assert(first, 'First table cannot be nil')
    assert(second, 'Second table cannot be nil')
    local self = setmetatable(Iterator.new(), ZipIterator)
    self._first = first
    self._second = second
    self._resultSelector = resultSelector or (function(a, b) return {a, b} end)
    self._enumerator1 = nil
    self._enumerator2 = nil
    return self
end

function ZipIterator:moveNext()
    if self._state ~= 0 then
        if self._state ~= 1 then
            return false
        end
    else
        self._enumerator1 = self._first:getEnumerator(self._isArray)
        self._enumerator2 = getTableEnumerator(self._second, self._isArray)
        self._state = -4
    end
    if self._enumerator1:moveNext() == true and self._enumerator2:moveNext() == true then
        self.current = self._resultSelector(self._enumerator1.current, self._enumerator2.current)
        self.index = self._enumerator1.index
        self._state = 1
        return true
    end

    self:finalize()
    return false
end

function ZipIterator:finalize()
    if self._enumerator1 then
        self._enumerator1:finalize()
    end
    if self._enumerator2 then
        self._enumerator2:finalize()
    end

    Iterator.finalize(self)
end

function ZipIterator:clone()
    return ZipIterator.new(self._first, self._second, self._resultSelector)
end

--Groups the elements of a sequence.
---@class GroupByIterator : Iterator
---@field _source Enumerable
---@field _keySelector fun(param: any): any
---@field _elementSelector fun(param: any): any
---@field _enumerator Enumerator
---@field _lookup Lookup
local GroupByIterator = Iterator.createIteratorSubclass()

---@param source Enumerable
---@param keySelector fun(param: any): any
---@param elementSelector? fun(param: any): any
function GroupByIterator.new(source, keySelector, elementSelector)
	assert(source, 'source cannot be nil')
    assert(keySelector, 'keySelector cannot be nil')
    local self = setmetatable(Iterator.new(), GroupByIterator)
    self._source = source
    self._keySelector = keySelector
    self._elementSelector = elementSelector or function(x) return x end
    self._enumerator = nil
    self._lookup = nil
    return self
end

function GroupByIterator:moveNext()
    if self._state ~= 0 then
        if self._state ~= 1 then
            return false
        end
    else
        self._lookup = Lookup.new(self._source, self._keySelector, self._elementSelector, self._isArray)
        self._enumerator = self._lookup:getEnumerator()
        self._state = 1
    end

    local enum = self._enumerator
    if enum:moveNext() == true then
        self.current = enum.current
        self.index = enum.index
        self._state = 1
        return true
    end

    self:finalize()
    return false
end

function GroupByIterator:finalize()
    if self._enumerator then
        self._enumerator:finalize()
    end

    Iterator.finalize(self)
end

function GroupByIterator:clone()
    return GroupByIterator.new(self._source, self._keySelector, self._elementSelector)
end

--Groups the elements of a sequence according to a specified key selector function and creates a result value from each group and its key. The elements of each group are projected by using a specified function.
---@class GroupByResultIterator : Iterator
---@field _source Enumerable
---@field _keySelector fun(param: any): any
---@field _elementSelector fun(param: any): any
---@field _resultSelector fun(key: any, grouping: Grouping): any
---@field _enumerator Enumerator
---@field _lookup Lookup
local GroupByResultIterator = Iterator.createIteratorSubclass()

---@param source Enumerable
---@param keySelector fun(param: any): any
---@param elementSelector? fun(param: any): any
---@param resultSelector fun(key: any, grouping: Grouping): any
function GroupByResultIterator.new(source, keySelector, elementSelector, resultSelector)
	assert(source, 'source cannot be nil')
    assert(keySelector, 'keySelector cannot be nil')
    assert(resultSelector, 'keySelector cannot be nil')
    local self = setmetatable(Iterator.new(), GroupByResultIterator)
    self._source = source
    self._keySelector = keySelector
    self._elementSelector = elementSelector or function(x) return x end
    self._resultSelector = resultSelector
    self._enumerator = nil
    self._lookup = nil
    return self
end

function GroupByResultIterator:moveNext()
    if self._state ~= 0 then
        if self._state ~= 1 then
            return false
        end
    else
        self._lookup = Lookup.new(self._source, self._keySelector, self._elementSelector, self._isArray)
        self._enumerator = self._lookup:getEnumerator()
        self._state = 1
    end

    local enumerator = self._enumerator
    if enumerator:moveNext() == true then
        -- Transform Grouping element for enumeration
        local gKey = enumerator.current.key
        local gElements = enumerator.current
        self.current = self._resultSelector(gKey, gElements)
        self.index = enumerator.index
        self._state = 1
        return true
    end

    self:finalize()
    return false
end

function GroupByResultIterator:finalize()
    if self._enumerator then
        self._enumerator:finalize()
    end

    Iterator.finalize(self)
end

function GroupByResultIterator:clone()
    return GroupByResultIterator.new(self._source, self._keySelector, self._elementSelector, self._resultSelector)
end


--Groups the elements of a sequence according to a specified key selector function and creates a result value from each group and its key. The elements of each group are projected by using a specified function.
---@class SortableIterator : Iterator
---@field _source Enumerable
---@field _keySelector fun(value: any, index: any): any|?
---@field _descendingSort boolean
---@field _parentIterator SortableIterator
---@field _enumerator Enumerator
local SortableIterator = Iterator.createIteratorSubclass()

---@param source Enumerable
---@param keySelector fun(value: any, index: any): any
---@param descendingSort? boolean
---@param parent? SortableIterator
---@return SortableIterator
function SortableIterator.new(source, keySelector, descendingSort, parent)
	assert(source, 'source cannot be nil')
    assert(keySelector, 'keySelector cannot be nil')
    assert(parent == nil or parent.getEnumerableSorters ~= nil, 'parent must be of type SortableIterator')
    if descendingSort == nil then
        descendingSort = false
    end

    local self = setmetatable(Iterator.new(), SortableIterator)
    self._source = source
    self._keySelector = keySelector
    self._descendingSort = descendingSort
    self._parentIterator = parent
    self._enumerator = nil
---@diagnostic disable-next-line: return-type-mismatch
    return self
end

local function createSortFunc(sorters)
    local sortCount = #sorters
    if sortCount > 1 then
        return function(a, b)
            for _, sorter in ipairs(sorters) do
                local selector = sorter._keySelector
                local descendingSort = sorter._descendingSort

                local aVal = selector(a)
                local bVal = selector(b)
                if aVal ~= bVal then
                    if descendingSort == true then
                        return aVal > bVal
                    else
                        return aVal < bVal
                    end
                end
            end
        end
    end

    -- Faster sort function if there's no nested sorts.
    local sorter = sorters[1]
    local selector = sorter._keySelector
    if sorter._descendingSort == true then
        return function(a, b) return selector(a) > selector(b) end
    else
        return function(a, b) return selector(a) < selector(b) end
    end
end

function SortableIterator:moveNext()
    if self._state ~= 0 then
        if self._state ~= 1 then
            return false
        end
    else
        -- Collect sorter meta information.
        local sorters = {}
        self:getEnumerableSorters(sorters)
        local sortFunc = createSortFunc(sorters)

        local enumerator = self._source:getEnumerator(self._isArray)
        local buffer = {}
        -- Process and sort all previous enumerators.
        while enumerator:moveNext() == true do
            table.insert(buffer, enumerator.current)
        end
        table.sort(buffer, sortFunc)
        self._enumerator = TableEnumerator.createForArray(buffer)
        self._state = 1
    end

    local enumerator = self._enumerator
    if enumerator:moveNext() == true then
        self.current = enumerator.current
        self.index = enumerator.index
        self._state = 1
        return true
    end

    self:finalize()
    return false
end

---Collects all sorter functions down the line.
---@param sorters table
---@return table
function SortableIterator:getEnumerableSorters(sorters)
    -- Get parent sorter first.
    if self._parentIterator ~= nil then
        self._parentIterator:getEnumerableSorters(sorters)
    end
    table.insert(sorters, self)

    return sorters
end

---@param keySelector fun(value: any, index: any): any
---@param descendingSort? boolean
---@return SortableIterator
function SortableIterator:createSortableIterator(keySelector, descendingSort)
    return SortableIterator.new(self._source, keySelector, descendingSort, self)
end

function SortableIterator:finalize()
    if self._enumerator then
        self._enumerator:finalize()
        self._enumerator = nil
    end
    Iterator.finalize(self)
end

function SortableIterator:clone()
    return SortableIterator.new(self._source, self._keySelector, self._descendingSort, self._parentIterator)
end

return {
    MapIterator = MapIterator,
    WhereIterator = WhereIterator,
    FlatMapIterator = FlatMapIterator,
    ConcatIterator = ConcatIterator,
    AppendIterator = AppendIterator,
    UniqueIterator = UniqueIterator,
    DifferenceIterator = DifferenceIterator,
    UnionIterator = UnionIterator,
    IntersectIterator = IntersectIterator,
    ZipIterator = ZipIterator,
    GroupByIterator = GroupByIterator,
    GroupByResultIterator = GroupByResultIterator,
    SortableIterator = SortableIterator,
}