Module:FunList/Iterators: Difference between revisions

From Melvor Idle
No edit summary
(bugfix)
 
(9 intermediate revisions by the same user not shown)
Line 1: Line 1:
-- Helper Functions --
local Enumerable = require('Module:FunList/Enumerable')
local function isType(obj, class)
local Lookup = require('Module:FunList/Lookup')
     local mt = getmetatable(obj)
local TableEnumerator = require('Module:FunList/TableEnumerator')
    while mt do
 
         if mt == class then
---Creates a TableIterator Enumerable or returns the Enumerable if it already is one.
             return true
---We set the propper starting index here since this only applies for tables!
---@param source any
---@param isArray any
---@return Enumerator
local function getTableEnumerator(source, isArray)
    if source.getEnumerator == nil then
        assert(type(source) == 'table', 'sourceTable must be either an Enumerable or table.')
        return TableEnumerator.create(source, isArray)
    else
        return source:getEnumerator(isArray)
    end
end
 
-- Helper functions to create and manage a hashset structure.
-- These are used to distinguish elements for Union/Unique/Difference etc operations.
---@param source any
---@return table
local function sourceToSet(source)
     local set = {}
    if Enumerable.isEnumerable(source) then
        local enum = source:getEnumerator()
        while enum:moveNext() do
            set[enum.current] = true
        end
    else
         assert(type(source) == 'table', 'Source must be a table.')
        for _, v in pairs(source) do
             set[v] = true
         end
         end
         mt = getmetatable(mt)
    end
    return set
end
 
---@param set table
---@param item any
---@return boolean
local function setAdd(set, item)
    if set[item] == nil then
         set[item] = true
        return true
     end
     end
     return false
     return false
end
end


-- BASE ENUMERATOR CLASS --
---@param set table
--enumerable = {}
---@param item any
local Enumerator = {}
---@return boolean
local Enumerator_mt = {
local function setRemove(set, item)
__index = Enumerator,
    if set[item] == nil then
__pairs = function(t) return t:overridePairs() end,
        return false
__ipairs = function(t) return t:overridePairs(0) end
    end
    set[item] = nil
    return true
end
 
-- Defines the base class for all Iterator state machines. This is an Enumerable and Enumerator in one.
---@class Iterator : Enumerable, Enumerator
---@field current any
---@field index any
---@field _state integer
---@field _isArray boolean
local Iterator = setmetatable({}, { __index = Enumerable })
local Iterator_mt = {
__index = Iterator,
__pairs = Enumerable.getPairs,
__ipairs = Enumerable.getiPairs
}
}


function Enumerator.new()
---@return Iterator
local self = setmetatable({}, Enumerator_mt)
function Iterator.new()
local self = setmetatable({}, Iterator_mt)
self.current = nil
self.current = nil
self.index = nil
self.index = nil
    self._state = -1
-- Assume by default we are not dealing with a simple array
-- Assume by default we are not dealing with a simple array
self.isArray = false
self._isArray = false
return self
return self
end
end


function Enumerator:moveNext()
-- This is normally implemented by the Enumerator abstract class.
error('Not implemented in base class.')
-- But to prevent double inheritance, we cheat a bit and implement it outselves.
---@return boolean
function Iterator:moveNext()
    error('Abstract function must be overridden in derived class.')
end
end


function Enumerator:getEnumerator(isArray)
---Returns this Enumerator or creates a new one if it has been used.
error('Not implemented in base class.')
---@param isArray? boolean
---@return Enumerator
function Iterator:getEnumerator(isArray)
    -- The default _state is -1 which signifies a Iterator isn't used.
    local instance = (self._state == -1) and self or self:clone()
    instance._isArray = isArray or self._isArray
    instance._state = 0
    return instance
end
end


-- Hooks the moveNext function into the Lua 'pairs' function
---@return Iterator
function Enumerator:overridePairs(startIndex)
function Iterator:clone()
-- Get or create clean enumerator. This ensures the state is 0.
    error('Abstract function must be overridden in derived class.')
local enum = self:getEnumerator(startIndex == 0)
end
enum.current = nil
 
enum.index = startIndex
function Iterator:finalize()
local function iterator(t, k)
    -- Signals invalid _state.
if enum:moveNext() == true then
    self._state = -4
return enum.index, enum.current
end
return nil, nil
end
return iterator, enum, enum.index
end
end


-- TABLE ENUMERATOR CLASS --
function Iterator.createIteratorSubclass()
-- This is essentially a wrapper for the table object,
    local derivedClass = setmetatable({}, { __index = Iterator })
-- since it provides no state machine esque iteration out of the box
    derivedClass.__index = derivedClass
local TableEnumerator = setmetatable({}, { __index = Enumerator })
    derivedClass.__pairs = Enumerable.getPairs
TableEnumerator.__index = TableEnumerator
    derivedClass.__ipairs = Enumerable.getiPairs
TableEnumerator.__pairs = Enumerator_mt.__pairs
    return derivedClass
TableEnumerator.__ipairs = Enumerator_mt.__ipairs
end


function TableEnumerator.new(tbl)
--Projects each element of a sequence into a new form.
    local self = setmetatable(Enumerator.new(), TableEnumerator)
---@class MapIterator : Iterator
    self.tbl = tbl or {} -- Allow creation of empty enumerable
---@field _source Enumerable
    self.state = 0
---@field _selector fun(value: any, index: any): any
---@field _position integer
---@field _enumerator Enumerator
local MapIterator = Iterator.createIteratorSubclass()


---@param source Enumerable
---@param selector fun(value: any, index: any): any
function MapIterator.new(source, selector)
assert(source, 'Source cannot be nil')
assert(selector, 'Selector cannot be nil')
    local self = setmetatable(Iterator.new(), MapIterator)
    self._source = source
    self._selector = selector
    self._enumerator = nil
    self._position = 0
     return self
     return self
end
end


function TableEnumerator:moveNext()
function MapIterator:moveNext()
     if self.state == 0 then
     local state = self._state
         self.state = 1
    if state ~= 0 then
        if state ~= 1 then
            return false
        end
    else
        self._state = 1 state = 1
    self._position = 0
self._enumerator = self._source:getEnumerator(self._isArray)
    end
 
    local _enumerator = self._enumerator
    if _enumerator:moveNext() == true then
        -- 1 based index because Lua
        local pos = self._position + 1
        local index = _enumerator.index
        local current = self._selector(_enumerator.current, pos)
        assert(current, 'Selected value must be non-nil')
        self.index = index
        self.current = current
         self._position = pos
        return true
    end
    self:finalize()
 
return false
end
 
function MapIterator:finalize()
    if self._enumerator then
        self._enumerator:finalize()
     end
     end
    Iterator.finalize(self)
end


     if self.isArray == true then
function MapIterator:clone()
        -- Iterate using ipairs, starting from index 1
     return MapIterator.new(self._source, self._selector)
        self.index = self.index + 1
end
        self.current = self.tbl[self.index]
 
        return self.current ~= nil
--Filters a sequence of values based on a predicate.
---@class WhereIterator : Iterator
---@field _source Enumerable
---@field _predicate fun(value: any, index: any): boolean
---@field _enumerator Enumerator
local WhereIterator = Iterator.createIteratorSubclass()
 
---@param source Enumerable
---@param predicate fun(value: any, index: any): boolean
function WhereIterator.new(source, predicate)
assert(source, 'Source cannot be nil')
assert(predicate, 'Predicate cannot be nil')
    local self = setmetatable(Iterator.new(), WhereIterator)
    self._source = source
    self._predicate = predicate
    self._enumerator = nil
    return self
end
 
function WhereIterator:moveNext()
    local state = self._state
    if state ~= 0 then
        if state ~= 1 then
            return false
        end
     else
     else
        -- Iterate using pairs
self._state = 1 state = 1
        local key = self.index
self._enumerator = self._source:getEnumerator(self._isArray)
        key = next(self.tbl, key)
    end
         self.index = key
 
         if key ~= nil then
    -- Cache some class lookups since these are likely
             self.current = self.tbl[key]
    -- to be accessed more than once per moveNext
    local enum = self._enumerator
    local nextFunc = enum.moveNext
    local predicate = self._predicate
    while nextFunc(enum) == true do
         local sourceElement = enum.current
        local sourceIndex = enum.index
         if predicate(sourceElement, sourceIndex) == true then
             self.index = sourceIndex
            self.current = sourceElement
             return true
             return true
         end
         end
     end
     end
    self:finalize()
return false
end
function WhereIterator:finalize()
    if self._enumerator then
        self._enumerator:finalize()
    end
    Iterator.finalize(self)
end


     return false
function WhereIterator:clone()
     return WhereIterator.new(self._source, self._predicate)
end
 
--Projects each element of a sequence to an Enumerable and flattens the resulting sequences into one sequence.
---@class FlatMapIterator : Iterator
---@field _source Enumerable
---@field _selector fun(value: any, index: integer): any
---@field _position integer
---@field _enumerator Enumerator
---@field _sourceEnumerator Enumerator
local FlatMapIterator = Iterator.createIteratorSubclass()
 
---@param source Enumerable
---@param selector fun(value: any, index: integer): any
function FlatMapIterator.new(source, selector)
assert(source, 'Source cannot be nil')
assert(selector, 'Selector cannot be nil')
    local self = setmetatable(Iterator.new(), FlatMapIterator)
    self._source = source
    self._selector = selector
    self._position = 0
    self._enumerator = nil -- Iterator of the _source Enumerable
    self._sourceEnumerator = nil  -- Iterator of the nested Enumerable
    return self
end
end


-- startIndex is used to determine if the underlying table should be treated
function FlatMapIterator:moveNext()
-- as an array or as a mixed table. It is ignored in the other enumerators as
    local state = self._state
-- they just call moveNext on the enumerator instead.
    if state == -4 then
function TableEnumerator:getEnumerator(isArray)
        return false
local instance = nil
    end
if self.state == 0 then
 
instance = self
    -- Setup _state
else
    if state == 0 then
instance = TableEnumerator.new(self.tbl)
        self._position = 0
end
        self._enumerator = self._source:getEnumerator(self._isArray)
instance.isArray = isArray
        self._state = 3 state = 3 -- signal to get (_first) nested _enumerator
return instance
    end
    while true do
        -- Grab next value from nested _enumerator
        if state == 4 then
            local sourceEnum = self._sourceEnumerator
            if sourceEnum:moveNext() then
                self.current = sourceEnum.current
                self.index = sourceEnum.index
                self._state, state = 4, 4 -- signal to get next item
                return true
            else
                -- Cleanup nested _enumerator
                self._sourceEnumerator:finalize()
                self._state = 3 state = 3 -- signal to get next _enumerator
            end
        end
 
-- Grab nest nested _enumerator
        if state == 3 then
            if self._enumerator:moveNext() then
                local current = self._enumerator.current
                local pos = self._position + 1
 
                local sourceTable = self._selector(current, pos)
                -- Nested tables are never treated as arrays.
                self._sourceEnumerator = getTableEnumerator(sourceTable, false)
                self._position = pos
                self._state = 4 state = 4 -- signal to get next item
            else
            -- _enumerator doesn't have any more nested enumerators.
                self:finalize()
                return false
            end
        end
    end
end
end


-- SELECT ENUMERATOR --
function FlatMapIterator:finalize()
local SelectEnumerator = setmetatable({}, { __index = Enumerator })
    if self._enumerator then
SelectEnumerator.__index = SelectEnumerator
        self._enumerator:finalize()
SelectEnumerator.__pairs = Enumerator_mt.__pairs
    end
SelectEnumerator.__ipairs = Enumerator_mt.__ipairs
    if self._sourceEnumerator then
        self._sourceEnumerator:finalize()
    end


function SelectEnumerator.new(source, selector)
    Iterator.finalize(self)
    local self = setmetatable(Enumerator.new(), SelectEnumerator)
    self.state = 0
    self.source = source
    self.selector = selector
    self.enumerator = nil
    self.position = 0
    return self
end
end


function SelectEnumerator:moveNext()
function FlatMapIterator:clone()
if self.state == 0 then
return FlatMapIterator.new(self._source, self._selector)
self.state = 1
end
self.position = 0
 
self.enumerator = self.source:getEnumerator(self.isArray)
--Concatenates two sequences.
end
---@class ConcatIterator : Iterator
---@field _first Enumerable
if self.state == 1 then
---@field _second any
local enumerator = self.enumerator
---@field _enumerator Enumerator
if enumerator:moveNext() == true then
local ConcatIterator = Iterator.createIteratorSubclass()
local sourceElement = enumerator.current
 
self.index = enumerator.index -- Preserve index
---@param first Enumerable
self.position = self.position + 1
---@param second any
self.current = self.selector(sourceElement, self.position)
function ConcatIterator.new(first, second)
assert(self.current, 'Selected value must be non-nil')
assert(first, 'First cannot be nil')
return true
assert(second, 'Second cannot be nil')
end
    local self = setmetatable(Iterator.new(), ConcatIterator)
end
    self._first = first
return false
    self._second = second
    self._enumerator = nil
    return self
end
end


function SelectEnumerator:getEnumerator()
-- Function to grab enumerators in order. We currently only have two so this
if self.state == 0 then
-- seems a bit redundant. But it would allow to add N-amount of enumerables.
return self
function ConcatIterator:getEnumerable(index)
if index == 0 then
return self._first
elseif index == 1 then
return self._second
else
else
return SelectEnumerator.new(self.source, self.selector)
return nil
end
end
end
end


-- WHERE ENUMERATOR --
function ConcatIterator:moveNext()
local WhereEnumerator = setmetatable({}, { __index = Enumerator })
    local state = self._state
WhereEnumerator.__index = WhereEnumerator
    if state == -4 then
WhereEnumerator.__pairs = Enumerator_mt.__pairs
        return false
WhereEnumerator.__ipairs = Enumerator_mt.__ipairs
    end


function WhereEnumerator.new(source, predicate)
if state == 0 then
    local self = setmetatable(Enumerator.new(), WhereEnumerator)
self._enumerator = self:getEnumerable(state)
    self.state = 0
:getEnumerator(self._isArray)
    self.source = source
self._state = 1 state = 1
    self.predicate = predicate
end
    self.enumerator = nil
    return self
end


function WhereEnumerator:moveNext()
if state > 0 then
if self.state == 0 then
while true do
self.state = 1
            local enum = self._enumerator
self.position = 0
if enum:moveNext() == true then
self.enumerator = self.source:getEnumerator(self.isArray)
self.index = enum.index
end
self.current = enum.current
if self.state == 1 then
local enumerator = self.enumerator
while enumerator:moveNext() == true do
local sourceElement = enumerator.current
if self.predicate(sourceElement) == true then
self.index = enumerator.index
self.current = sourceElement
return true
return true
end
            state = state + 1
            self._state = state
local next = self:getEnumerable(state - 1)
if next ~= nil then
                -- Cleanup previous _enumerator
                self._enumerator:finalize()
self._enumerator = getTableEnumerator(next, self._isArray)
else
                self:finalize()
return false
end
end
end
end
end
end
 
return false
return false
end
end


function WhereEnumerator:getEnumerator()
function ConcatIterator:finalize()
if self.state == 0 then
    if self._enumerator then
return self
        self._enumerator:finalize()
else
    end
return WhereEnumerator.new(self.source, self.predicate)
 
end
    Iterator.finalize(self)
end
 
function ConcatIterator:clone()
    return ConcatIterator.new(self._first, self._second)
end
 
--Appends a value to the end or start of the sequence.
---@class AppendIterator : Iterator
---@field _source Enumerable
---@field _item any
---@field _itemIndex any
---@field _append boolean
---@field _enumerator Enumerator
local AppendIterator = Iterator.createIteratorSubclass()
 
---@param source Enumerable
---@param item any
---@param itemIndex? any
---@param append? boolean
function AppendIterator.new(source, item, itemIndex, append)
assert(source, 'Source cannot be nil')
assert(item, 'Item cannot be nil')
    local self = setmetatable(Iterator.new(), AppendIterator)
    self._source = source
    self._item = item
    -- Index needs *some* value, otherwise iteration stops.
    if itemIndex == nil then self._itemIndex = item else self._itemIndex = itemIndex end
    if append == nil then self._append = true else self._append = append end
    self._enumerator = nil
    return self
end
 
function AppendIterator:moveNext()
    local state = self._state
    if state == 0 then
        self._state = 1 state = 1
        -- Put the item in front, as a prepend.
        if self._append == false then
            self.current = self._item
            self.index = self._itemIndex
            return true
        end
    end
 
    -- Grab the _source _enumerator
    if state == 1 then
        self._enumerator = self._source:getEnumerator(self._isArray)
        self._state = 2 state = 2
    end
 
    if state == 2 then
        local enum = self._enumerator
        if enum:moveNext() then
            self.current = enum.current
            self.index = enum.index
            return true
        else
            self:finalize()
        end
 
        if self._append == true then
            self.current = self._item
            self.index = self._itemIndex
            return true
        end
    end
 
    return false
end
 
function AppendIterator:finalize()
    if self._enumerator then
        self._enumerator:finalize()
    end
 
    Iterator.finalize(self)
end
 
function AppendIterator:clone()
    return AppendIterator.new(self._source, self._item, self._itemIndex, self._append)
end
end


-- WHERE ENUMERATOR --
--Returns unique (distinct) elements from a sequence according to a specified key selector function.
local SelectManyEnumerator = setmetatable({}, { __index = Enumerator })
---@class UniqueIterator : Iterator
SelectManyEnumerator.__index = SelectManyEnumerator
---@field _source Enumerable
SelectManyEnumerator.__pairs = Enumerator_mt.__pairs
---@field _keySelector fun(value: any, index: any): any
SelectManyEnumerator.__ipairs = Enumerator_mt.__ipairs
---@field _enumerator Enumerator
---@field _set table
local UniqueIterator = Iterator.createIteratorSubclass()


function SelectManyEnumerator.new(source, selector)
---@param source Enumerable
     local self = setmetatable(Enumerator.new(), SelectManyEnumerator)
---@param keySelector? fun(value: any, index: any): any
     self.state = 0
function UniqueIterator.new(source, keySelector)
    self.source = source
assert(source, 'Source cannot be nil')
     self.selector = selector
     local self = setmetatable(Iterator.new(), UniqueIterator)
    self.position = 0
     self._source = source
     self.enumerator = nil -- Enumerator of the source Enumerable
     self._keySelector = keySelector or function(x) return x end
     self.sourceEnumerator = nil -- Enumerator of the nested Enumerable
     self._enumerator = nil
     return self
     self._set = nil
     return self
end
end


function SelectManyEnumerator:moveNext()
function UniqueIterator:moveNext()
    local state = self._state
    if state == -4 then
        return false
    end
 
    if state == 0 then
        self._enumerator = self._source:getEnumerator(self._isArray)
        -- If we have any items, create a hashtable. Otherwise abort.
        if self._enumerator:moveNext() == true then
            self._set = {}
            self._state = 1 state = 1
        else
            self:finalize()
            return false
        end
    end
 
    local enum = self._enumerator
    local set = self._set
     while true do
     while true do
        -- Setup state
         if state == 1 then
         if self.state == 0 then
             self._state = 2 state = 2
             self.position = 0
             local current = enum.current
             self.enumerator = self.source:getEnumerator(self.isArray)
            local index = enum.index
            self.state = -3 -- signal to get (first) nested enumerator
            local key = self._keySelector(current, index)
            if setAdd(key) == true then
                self.current = current
                self.index = index
                return true
            end
         end
         end


         -- Grab next value from nested enumerator
         -- Try to grab a new item.
         if self.state == -4 then
         if state == 2 then
             if self.sourceEnumerator:moveNext() then
             if enum:moveNext() == true then
                 self.current = self.sourceEnumerator.current
                 self._state = 1 state = 1
                self.index = self.sourceEnumerator.index
                self.state = -4 -- signal to get next item
                return true
             else
             else
                 self.state = -3 -- signal to get next enumerator
                 self:finalize()
                return false
             end
             end
         end
         end
    end
-- Grab nest nested enumerator
end
         if self.state == -3 then
 
             if self.enumerator:moveNext() then
function UniqueIterator:finalize()
                local current = self.enumerator.current
    if self._enumerator then
                self.position = self.position + 1
        self._enumerator:finalize()
    end
 
    Iterator.finalize(self)
end
 
function UniqueIterator:clone()
    return UniqueIterator.new(self._source, self._keySelector)
end
 
--Produces the set difference of two sequences according to a specified key selector function.
---@class DifferenceIterator : Iterator
---@field _first Enumerable
---@field _second any
---@field _keySelector fun(value: any, index: any): any
---@field _enumerator Enumerator
---@field _set table
local DifferenceIterator = Iterator.createIteratorSubclass()
 
---@param first Enumerable
---@param second any
---@param keySelector? fun(value: any, index: any): any
function DifferenceIterator.new(first, second, keySelector)
assert(first, 'First table cannot be nil')
    assert(second, 'Second table cannot be nil')
    local self = setmetatable(Iterator.new(), DifferenceIterator)
    self._first = first
    self._second = second
    self._keySelector = keySelector or function(x) return x end
    self._enumerator = nil
    self._set = nil
    return self
end
 
function DifferenceIterator:moveNext()
    local state = self._state
    if state ~= 0 then
         if state ~= 1 then
            return false
        end
    else
        self._enumerator = self._first:getEnumerator(self._isArray)
        self._set = sourceToSet(self._second)
        self._state = 1 state = 1
    end
 
    local enum = self._enumerator
    local nextFunc = enum.moveNext
    local keySelector = self._keySelector
    local set = self._set
    while nextFunc(enum) == true do
        local current = enum.current
        local index = enum.index
        local key = keySelector(current, index)
        if setAdd(key) == true then
            self.current = current
            self.index = index
            return true
        end
    end
 
    self:finalize()
    return false
end
 
function DifferenceIterator:finalize()
    if self._enumerator then
        self._enumerator:finalize()
    end
end
 
function DifferenceIterator:clone()
    return DifferenceIterator.new(self._first, self._second, self._keySelector)
end
 
--Produces the set union of two sequences according to a specified key selector function.
---@class UnionIterator : Iterator
---@field _first Enumerable
---@field _second any
---@field _keySelector fun(value: any, index: any): any
---@field _enumerator Enumerator
---@field _set table
local UnionIterator = Iterator.createIteratorSubclass()
 
---@param first Enumerable
---@param second any
---@param keySelector? fun(value: any, index: any): any
function UnionIterator.new(first, second, keySelector)
assert(first, 'First table cannot be nil')
    assert(second, 'Second table cannot be nil')
    local self = setmetatable(Iterator.new(), UnionIterator)
    self._first = first
    self._second = second
    self._keySelector = keySelector or function(x) return x end
    self._enumerator = nil
    self._set = nil
    return self
end
 
local function enumerateSource(self, _state)
    while self._enumerator:moveNext() do
        local current = self._enumerator.current
        local index = self._enumerator.index
        if self._set:add(self._keySelector(current, index)) == true then
            self.current = self._enumerator.current
            self.index = self._enumerator.index
            self._state = _state
            return true
        end
    end
end
 
function UnionIterator:moveNext()
    if self._state == -4 then
        return false
    end
 
    if self._state == 0 then
        self._enumerator = self._first:getEnumerator(self._isArray)
        self._set = {}
        self._state = 1
    end
 
    -- Process _first
    if self._state == 1 then
        if enumerateSource(self, 1) == true then
            return true
        end
 
        -- End _first enumeration
        self._state = 2
        self._enumerator:finalize()
        self._enumerator = getTableEnumerator(self._second, self._isArray)
    end
 
    -- Process _second
    if self._state == 2 then
        if enumerateSource(self, 2) == true then
            return true
        end
 
        -- End _first enumeration
        self._state = 3
    end
 
    self:finalize()
    return false
end
 
function UnionIterator:finalize()
    if self._enumerator then
        self._enumerator:finalize()
    end
 
    Iterator.finalize(self)
end
 
function UnionIterator:clone()
    return UnionIterator.new(self._first, self._second, self._keySelector)
end
 
--Produces the set intersection of two sequences according to a specified key selector function.
---@class IntersectIterator : Iterator
---@field _first Enumerable
---@field _second any
---@field _keySelector fun(value: any, index: any): any
---@field _enumerator Enumerator
---@field _set table
local IntersectIterator = Iterator.createIteratorSubclass()
 
---@param first Enumerable
---@param second any
---@param keySelector? fun(value: any, index: any): any
function IntersectIterator.new(first, second, keySelector)
assert(first, 'First table cannot be nil')
    assert(second, 'Second table cannot be nil')
    local self = setmetatable(Iterator.new(), IntersectIterator)
    self._first = first
    self._second = second
    self._keySelector = keySelector or function(x) return x end
    self._enumerator = nil
    self._set = nil
    return self
end
 
function IntersectIterator:moveNext()
    if self._state ~= 0 then
        if self._state ~= 1 then
             return false
        end
    else
        self._enumerator = self._first:getEnumerator(self._isArray)
        self._set = sourceToSet(self._second)
        self._state = 1
    end
    while self._enumerator:moveNext() do
        local current = self._enumerator.current
        local index = self._enumerator.index
        local key = self._keySelector(current, index)
        if setRemove(self._set, key) == true then
            self.current = self._enumerator.current
            self.index = self._enumerator.index
            self._state = 1
            return true
        end
    end
 
    self:finalize()
    return false
end
 
function IntersectIterator:finalize()
    if self._enumerator then
        self._enumerator:finalize()
    end
 
    Iterator.finalize(self)
end
 
function IntersectIterator:clone()
    return IntersectIterator.new(self._first, self._second, self._keySelector)
end
 
--Applies a specified function to the corresponding elements of two sequences, producing a sequence of the results.
---@class ZipIterator : Iterator
---@field _first Enumerable
---@field _second any
---@field _resultSelector fun(left: any, right: any): any
---@field _enumerator1 Enumerator
---@field _enumerator2 Enumerator
local ZipIterator = Iterator.createIteratorSubclass()
 
---@param first Enumerable
---@param second any
---@param resultSelector? fun(left: any, right: any): any
function ZipIterator.new(first, second, resultSelector)
assert(first, 'First table cannot be nil')
    assert(second, 'Second table cannot be nil')
    local self = setmetatable(Iterator.new(), ZipIterator)
    self._first = first
    self._second = second
    self._resultSelector = resultSelector or (function(a, b) return {a, b} end)
    self._enumerator1 = nil
    self._enumerator2 = nil
    return self
end
 
function ZipIterator:moveNext()
    if self._state ~= 0 then
        if self._state ~= 1 then
            return false
        end
    else
        self._enumerator1 = self._first:getEnumerator(self._isArray)
        self._enumerator2 = getTableEnumerator(self._second, self._isArray)
        self._state = -4
    end
    if self._enumerator1:moveNext() == true and self._enumerator2:moveNext() == true then
        self.current = self._resultSelector(self._enumerator1.current, self._enumerator2.current)
        self.index = self._enumerator1.index
        self._state = 1
        return true
    end
 
    self:finalize()
    return false
end
 
function ZipIterator:finalize()
    if self._enumerator1 then
        self._enumerator1:finalize()
    end
    if self._enumerator2 then
        self._enumerator2:finalize()
    end
 
    Iterator.finalize(self)
end
 
function ZipIterator:clone()
    return ZipIterator.new(self._first, self._second, self._resultSelector)
end
 
--Groups the elements of a sequence.
---@class GroupByIterator : Iterator
---@field _source Enumerable
---@field _keySelector fun(param: any): any
---@field _elementSelector fun(param: any): any
---@field _enumerator Enumerator
---@field _lookup Lookup
local GroupByIterator = Iterator.createIteratorSubclass()
 
---@param source Enumerable
---@param keySelector fun(param: any): any
---@param elementSelector? fun(param: any): any
function GroupByIterator.new(source, keySelector, elementSelector)
assert(source, 'source cannot be nil')
    assert(keySelector, 'keySelector cannot be nil')
    local self = setmetatable(Iterator.new(), GroupByIterator)
    self._source = source
    self._keySelector = keySelector
    self._elementSelector = elementSelector or function(x) return x end
    self._enumerator = nil
    self._lookup = nil
    return self
end
 
function GroupByIterator:moveNext()
    if self._state ~= 0 then
        if self._state ~= 1 then
            return false
        end
    else
        self._lookup = Lookup.new(self._source, self._keySelector, self._elementSelector, self._isArray)
        self._enumerator = self._lookup:getEnumerator()
        self._state = 1
    end
 
    local enum = self._enumerator
    if enum:moveNext() == true then
        self.current = enum.current
        self.index = enum.index
        self._state = 1
        return true
    end
 
    self:finalize()
    return false
end
 
function GroupByIterator:finalize()
    if self._enumerator then
        self._enumerator:finalize()
    end
 
    Iterator.finalize(self)
end
 
function GroupByIterator:clone()
    return GroupByIterator.new(self._source, self._keySelector, self._elementSelector)
end
 
--Groups the elements of a sequence according to a specified key selector function and creates a result value from each group and its key. The elements of each group are projected by using a specified function.
---@class GroupByResultIterator : Iterator
---@field _source Enumerable
---@field _keySelector fun(param: any): any
---@field _elementSelector fun(param: any): any
---@field _resultSelector fun(key: any, grouping: Grouping): any
---@field _enumerator Enumerator
---@field _lookup Lookup
local GroupByResultIterator = Iterator.createIteratorSubclass()
 
---@param source Enumerable
---@param keySelector fun(param: any): any
---@param elementSelector? fun(param: any): any
---@param resultSelector fun(key: any, grouping: Grouping): any
function GroupByResultIterator.new(source, keySelector, elementSelector, resultSelector)
assert(source, 'source cannot be nil')
    assert(keySelector, 'keySelector cannot be nil')
    assert(resultSelector, 'keySelector cannot be nil')
    local self = setmetatable(Iterator.new(), GroupByResultIterator)
    self._source = source
    self._keySelector = keySelector
    self._elementSelector = elementSelector or function(x) return x end
    self._resultSelector = resultSelector
    self._enumerator = nil
    self._lookup = nil
    return self
end
 
function GroupByResultIterator:moveNext()
    if self._state ~= 0 then
        if self._state ~= 1 then
            return false
        end
    else
        self._lookup = Lookup.new(self._source, self._keySelector, self._elementSelector, self._isArray)
        self._enumerator = self._lookup:getEnumerator()
        self._state = 1
    end
 
    local enumerator = self._enumerator
    if enumerator:moveNext() == true then
        -- Transform Grouping element for enumeration
        local gKey = enumerator.current.key
        local gElements = enumerator.current
        self.current = self._resultSelector(gKey, gElements)
        self.index = enumerator.index
        self._state = 1
        return true
    end
 
    self:finalize()
    return false
end
 
function GroupByResultIterator:finalize()
    if self._enumerator then
        self._enumerator:finalize()
    end
 
    Iterator.finalize(self)
end
 
function GroupByResultIterator:clone()
    return GroupByResultIterator.new(self._source, self._keySelector, self._elementSelector, self._resultSelector)
end
 
 
--Groups the elements of a sequence according to a specified key selector function and creates a result value from each group and its key. The elements of each group are projected by using a specified function.
---@class SortableIterator : Iterator
---@field _source Enumerable
---@field _keySelector fun(value: any, index: any): any|?
---@field _descendingSort boolean
---@field _parentIterator SortableIterator
---@field _enumerator Enumerator
local SortableIterator = Iterator.createIteratorSubclass()
 
---@param source Enumerable
---@param keySelector fun(value: any, index: any): any
---@param descendingSort? boolean
---@param parent? SortableIterator
---@return SortableIterator
function SortableIterator.new(source, keySelector, descendingSort, parent)
assert(source, 'source cannot be nil')
    assert(keySelector, 'keySelector cannot be nil')
    assert(parent == nil or parent.getEnumerableSorters ~= nil, 'parent must be of type SortableIterator')
    if descendingSort == nil then
        descendingSort = false
    end


                local sourceTable = self.selector(current, self.position)
    local self = setmetatable(Iterator.new(), SortableIterator)
                if not isType(sourceTable, Enumerator) then
    self._source = source
                -- We need to turn the nested table into an enumerator
    self._keySelector = keySelector
                self.sourceEnumerator = TableEnumerator.new(sourceTable)
    self._descendingSort = descendingSort
                :getEnumerator(self.isArray)
    self._parentIterator = parent
                 else
    self._enumerator = nil
                self.sourceEnumerator = sourceTable
---@diagnostic disable-next-line: return-type-mismatch
    return self
end
 
local function createSortFunc(sorters)
    local sortCount = #sorters
    if sortCount > 1 then
        return function(a, b)
            for _, sorter in ipairs(sorters) do
                local selector = sorter._keySelector
                local descendingSort = sorter._descendingSort
 
                local aVal = selector(a)
                local bVal = selector(b)
                 if aVal ~= bVal then
                    if descendingSort == true then
                        return aVal > bVal
                    else
                        return aVal < bVal
                    end
                 end
                 end
                self.state = -4 -- signal to get next item
            else
            -- enumerator doesn't have any more nested enumerators.
                return false
             end
             end
         end
         end
    end
    -- Faster sort function if there's no nested sorts.
    local sorter = sorters[1]
    local selector = sorter._keySelector
    if sorter._descendingSort == true then
        return function(a, b) return selector(a) > selector(b) end
    else
        return function(a, b) return selector(a) < selector(b) end
     end
     end
end
end


function SortableIterator:moveNext()
    if self._state ~= 0 then
        if self._state ~= 1 then
            return false
        end
    else
        -- Collect sorter meta information.
        local sorters = {}
        self:getEnumerableSorters(sorters)
        local sortFunc = createSortFunc(sorters)


function SelectManyEnumerator:getEnumerator()
        local enumerator = self._source:getEnumerator(self._isArray)
if self.state == 0 then
        local buffer = {}
return self
        -- Process and sort all previous enumerators.
else
        while enumerator:moveNext() == true do
return SelectManyEnumerator.new(self.source, self.predicate)
            table.insert(buffer, enumerator.current)
end
        end
        table.sort(buffer, sortFunc)
        self._enumerator = TableEnumerator.createForArray(buffer)
        self._state = 1
    end
 
    local enumerator = self._enumerator
    if enumerator:moveNext() == true then
        self.current = enumerator.current
        self.index = enumerator.index
        self._state = 1
        return true
    end
 
    self:finalize()
    return false
end
 
---Collects all sorter functions down the line.
---@param sorters table
---@return table
function SortableIterator:getEnumerableSorters(sorters)
    -- Get parent sorter first.
    if self._parentIterator ~= nil then
        self._parentIterator:getEnumerableSorters(sorters)
    end
    table.insert(sorters, self)
 
    return sorters
end
 
---@param keySelector fun(value: any, index: any): any
---@param descendingSort? boolean
---@return SortableIterator
function SortableIterator:createSortableIterator(keySelector, descendingSort)
    return SortableIterator.new(self._source, keySelector, descendingSort, self)
end
 
function SortableIterator:finalize()
    if self._enumerator then
        self._enumerator:finalize()
        self._enumerator = nil
    end
    Iterator.finalize(self)
end
 
function SortableIterator:clone()
    return SortableIterator.new(self._source, self._keySelector, self._descendingSort, self._parentIterator)
end
end


return {
return {
     Enumerator = Enumerator,
     MapIterator = MapIterator,
     TableEnumerator = TableEnumerator,
     WhereIterator = WhereIterator,
     SelectEnumerator = SelectEnumerator,
     FlatMapIterator = FlatMapIterator,
     WhereEnumerator = WhereEnumerator,
     ConcatIterator = ConcatIterator,
     SelectManyEnumerator = SelectManyEnumerator
     AppendIterator = AppendIterator,
    UniqueIterator = UniqueIterator,
    DifferenceIterator = DifferenceIterator,
    UnionIterator = UnionIterator,
    IntersectIterator = IntersectIterator,
    ZipIterator = ZipIterator,
    GroupByIterator = GroupByIterator,
    GroupByResultIterator = GroupByResultIterator,
    SortableIterator = SortableIterator,
}
}

Latest revision as of 20:46, 26 July 2024

Documentation for this module may be created at Module:FunList/Iterators/doc

local Enumerable = require('Module:FunList/Enumerable')
local Lookup = require('Module:FunList/Lookup')
local TableEnumerator = require('Module:FunList/TableEnumerator')

---Creates a TableIterator Enumerable or returns the Enumerable if it already is one.
---We set the propper starting index here since this only applies for tables!
---@param source any
---@param isArray any
---@return Enumerator
local function getTableEnumerator(source, isArray)
    if source.getEnumerator == nil then
        assert(type(source) == 'table', 'sourceTable must be either an Enumerable or table.')
        return TableEnumerator.create(source, isArray)
    else
        return source:getEnumerator(isArray)
    end
end

-- Helper functions to create and manage a hashset structure.
-- These are used to distinguish elements for Union/Unique/Difference etc operations.
---@param source any
---@return table
local function sourceToSet(source)
    local set = {}
    if Enumerable.isEnumerable(source) then
        local enum = source:getEnumerator()
        while enum:moveNext() do
            set[enum.current] = true
        end
    else
        assert(type(source) == 'table', 'Source must be a table.')
        for _, v in pairs(source) do
            set[v] = true
        end
    end
    return set
end

---@param set table
---@param item any
---@return boolean
local function setAdd(set, item)
    if set[item] == nil then
        set[item] = true
        return true
    end
    return false
end

---@param set table
---@param item any
---@return boolean
local function setRemove(set, item)
    if set[item] == nil then
        return false
    end
    set[item] = nil
    return true
end

-- Defines the base class for all Iterator state machines. This is an Enumerable and Enumerator in one.
---@class Iterator : Enumerable, Enumerator
---@field current any
---@field index any
---@field _state integer
---@field _isArray boolean
local Iterator = setmetatable({}, { __index = Enumerable })
local Iterator_mt = {
	__index = Iterator,
	__pairs = Enumerable.getPairs,
	__ipairs = Enumerable.getiPairs
}

---@return Iterator
function Iterator.new()
	local self = setmetatable({}, Iterator_mt)
	self.current = nil
	self.index = nil
    self._state = -1
	-- Assume by default we are not dealing with a simple array
	self._isArray = false
	return self
end

-- This is normally implemented by the Enumerator abstract class.
-- But to prevent double inheritance, we cheat a bit and implement it outselves.
---@return boolean
function Iterator:moveNext()
    error('Abstract function must be overridden in derived class.')
end

---Returns this Enumerator or creates a new one if it has been used.
---@param isArray? boolean
---@return Enumerator
function Iterator:getEnumerator(isArray)
    -- The default _state is -1 which signifies a Iterator isn't used.
    local instance = (self._state == -1) and self or self:clone()
    instance._isArray = isArray or self._isArray
    instance._state = 0
    return instance
end

---@return Iterator
function Iterator:clone()
    error('Abstract function must be overridden in derived class.')
end

function Iterator:finalize()
    -- Signals invalid _state.
    self._state = -4
end

function Iterator.createIteratorSubclass()
    local derivedClass = setmetatable({}, { __index = Iterator })
    derivedClass.__index = derivedClass
    derivedClass.__pairs  = Enumerable.getPairs
    derivedClass.__ipairs = Enumerable.getiPairs
    return derivedClass
end

--Projects each element of a sequence into a new form.
---@class MapIterator : Iterator
---@field _source Enumerable
---@field _selector fun(value: any, index: any): any
---@field _position integer
---@field _enumerator Enumerator
local MapIterator = Iterator.createIteratorSubclass()

---@param source Enumerable
---@param selector fun(value: any, index: any): any
function MapIterator.new(source, selector)
	assert(source, 'Source cannot be nil')
	assert(selector, 'Selector cannot be nil')
    local self = setmetatable(Iterator.new(), MapIterator)
    self._source = source
    self._selector = selector
    self._enumerator = nil
    self._position = 0
    return self
end

function MapIterator:moveNext()
    local state = self._state
    if state ~= 0 then
        if state ~= 1 then
            return false
        end
    else
        self._state = 1 state = 1
     	self._position = 0
		self._enumerator = self._source:getEnumerator(self._isArray)
    end

    local _enumerator = self._enumerator
    if _enumerator:moveNext() == true then
        -- 1 based index because Lua
        local pos = self._position + 1
        local index = _enumerator.index
        local current = self._selector(_enumerator.current, pos)
        assert(current, 'Selected value must be non-nil')
        self.index = index
        self.current = current
        self._position = pos
        return true
    end
    self:finalize()

	return false
end

function MapIterator:finalize()
    if self._enumerator then
        self._enumerator:finalize()
    end
    Iterator.finalize(self)
end

function MapIterator:clone()
    return MapIterator.new(self._source, self._selector)
end

--Filters a sequence of values based on a predicate.
---@class WhereIterator : Iterator
---@field _source Enumerable
---@field _predicate fun(value: any, index: any): boolean
---@field _enumerator Enumerator
local WhereIterator = Iterator.createIteratorSubclass()

---@param source Enumerable
---@param predicate fun(value: any, index: any): boolean
function WhereIterator.new(source, predicate)
	assert(source, 'Source cannot be nil')
	assert(predicate, 'Predicate cannot be nil')
    local self = setmetatable(Iterator.new(), WhereIterator)
    self._source = source
    self._predicate = predicate
    self._enumerator = nil
    return self
end

function WhereIterator:moveNext()
    local state = self._state
    if state ~= 0 then
        if state ~= 1 then
            return false
        end
    else
		self._state = 1 state = 1
		self._enumerator = self._source:getEnumerator(self._isArray)
    end

    -- Cache some class lookups since these are likely
    -- to be accessed more than once per moveNext
    local enum = self._enumerator
    local nextFunc = enum.moveNext
    local predicate = self._predicate
    while nextFunc(enum) == true do
        local sourceElement = enum.current
        local sourceIndex = enum.index
        if predicate(sourceElement, sourceIndex) == true then
            self.index = sourceIndex
            self.current = sourceElement
            return true
        end
    end
    self:finalize()

	return false
end

function WhereIterator:finalize()
    if self._enumerator then
        self._enumerator:finalize()
    end
    Iterator.finalize(self)
end

function WhereIterator:clone()
    return WhereIterator.new(self._source, self._predicate)
end

--Projects each element of a sequence to an Enumerable and flattens the resulting sequences into one sequence.
---@class FlatMapIterator : Iterator
---@field _source Enumerable
---@field _selector fun(value: any, index: integer): any
---@field _position integer
---@field _enumerator Enumerator
---@field _sourceEnumerator Enumerator
local FlatMapIterator = Iterator.createIteratorSubclass()

---@param source Enumerable
---@param selector fun(value: any, index: integer): any
function FlatMapIterator.new(source, selector)
	assert(source, 'Source cannot be nil')
	assert(selector, 'Selector cannot be nil')
    local self = setmetatable(Iterator.new(), FlatMapIterator)
    self._source = source
    self._selector = selector
    self._position = 0
    self._enumerator = nil		 -- Iterator of the _source Enumerable
    self._sourceEnumerator = nil  -- Iterator of the nested Enumerable
    return self
end

function FlatMapIterator:moveNext()
    local state = self._state
    if state == -4 then
        return false
    end

    -- Setup _state
    if state == 0 then
        self._position = 0
        self._enumerator = self._source:getEnumerator(self._isArray)
        self._state = 3 state = 3 -- signal to get (_first) nested _enumerator
    end
    while true do
        -- Grab next value from nested _enumerator		
        if state == 4 then
            local sourceEnum = self._sourceEnumerator
            if sourceEnum:moveNext() then
                self.current = sourceEnum.current
                self.index = sourceEnum.index
                self._state, state = 4, 4 -- signal to get next item
                return true
            else
                -- Cleanup nested _enumerator
                self._sourceEnumerator:finalize()
                self._state = 3 state = 3 -- signal to get next _enumerator
            end
        end

		-- Grab nest nested _enumerator
        if state == 3 then
            if self._enumerator:moveNext() then
                local current = self._enumerator.current
                local pos = self._position + 1

                local sourceTable = self._selector(current, pos)
                -- Nested tables are never treated as arrays.
                self._sourceEnumerator = getTableEnumerator(sourceTable, false)
                self._position = pos
                self._state = 4 state = 4 -- signal to get next item
            else
            	-- _enumerator doesn't have any more nested enumerators.
                self:finalize()
                return false
            end
        end
    end
end

function FlatMapIterator:finalize()
    if self._enumerator then
        self._enumerator:finalize()
    end
    if self._sourceEnumerator then
        self._sourceEnumerator:finalize()
    end

    Iterator.finalize(self)
end

function FlatMapIterator:clone()
	return FlatMapIterator.new(self._source, self._selector)
end

--Concatenates two sequences.
---@class ConcatIterator : Iterator
---@field _first Enumerable
---@field _second any
---@field _enumerator Enumerator
local ConcatIterator = Iterator.createIteratorSubclass()

---@param first Enumerable
---@param second any
function ConcatIterator.new(first, second)
	assert(first, 'First cannot be nil')
	assert(second, 'Second cannot be nil')
    local self = setmetatable(Iterator.new(), ConcatIterator)
    self._first = first
    self._second = second
    self._enumerator = nil
    return self
end

-- Function to grab enumerators in order. We currently only have two so this
-- seems a bit redundant. But it would allow to add N-amount of enumerables.
function ConcatIterator:getEnumerable(index)
	if index == 0 then
		return self._first
	elseif index == 1 then
		return self._second
	else
		return nil
	end
end

function ConcatIterator:moveNext()
    local state = self._state
    if state == -4 then
        return false
    end

	if state == 0 then
		self._enumerator = self:getEnumerable(state)
			:getEnumerator(self._isArray)
		self._state = 1 state = 1
	end

	if state > 0 then
		while true do
            local enum = self._enumerator
			if enum:moveNext() == true then
				self.index = enum.index
				self.current = enum.current
				return true
			end

            state = state + 1
            self._state = state
			local next = self:getEnumerable(state - 1)
			if next ~= nil then
                -- Cleanup previous _enumerator
                self._enumerator:finalize()
				self._enumerator = getTableEnumerator(next, self._isArray)
			else
                self:finalize()
				return false
			end
		end
	end

	return false
end

function ConcatIterator:finalize()
    if self._enumerator then
        self._enumerator:finalize()
    end

    Iterator.finalize(self)
end

function ConcatIterator:clone()
    return ConcatIterator.new(self._first, self._second)
end

--Appends a value to the end or start of the sequence.
---@class AppendIterator : Iterator
---@field _source Enumerable
---@field _item any
---@field _itemIndex any
---@field _append boolean
---@field _enumerator Enumerator
local AppendIterator = Iterator.createIteratorSubclass()

---@param source Enumerable
---@param item any
---@param itemIndex? any
---@param append? boolean
function AppendIterator.new(source, item, itemIndex, append)
	assert(source, 'Source cannot be nil')
	assert(item, 'Item cannot be nil')
    local self = setmetatable(Iterator.new(), AppendIterator)
    self._source = source
    self._item = item
    -- Index needs *some* value, otherwise iteration stops.
    if itemIndex == nil then self._itemIndex = item else self._itemIndex = itemIndex end
    if append == nil then self._append = true else self._append = append end
    self._enumerator = nil
    return self
end

function AppendIterator:moveNext()
    local state = self._state
    if state == 0 then
        self._state = 1 state = 1
        -- Put the item in front, as a prepend.
        if self._append == false then
            self.current = self._item
            self.index = self._itemIndex
            return true
        end
    end

    -- Grab the _source _enumerator
    if state == 1 then
        self._enumerator = self._source:getEnumerator(self._isArray)
        self._state = 2 state = 2
    end

    if state == 2 then
        local enum = self._enumerator
        if enum:moveNext() then
            self.current = enum.current
            self.index = enum.index
            return true
        else
            self:finalize()
        end

        if self._append == true then
            self.current = self._item
            self.index = self._itemIndex
            return true
        end
    end

    return false
end

function AppendIterator:finalize()
    if self._enumerator then
        self._enumerator:finalize()
    end

    Iterator.finalize(self)
end

function AppendIterator:clone()
    return AppendIterator.new(self._source, self._item, self._itemIndex, self._append)
end

--Returns unique (distinct) elements from a sequence according to a specified key selector function. 
---@class UniqueIterator : Iterator
---@field _source Enumerable
---@field _keySelector fun(value: any, index: any): any
---@field _enumerator Enumerator
---@field _set table
local UniqueIterator = Iterator.createIteratorSubclass()

---@param source Enumerable
---@param keySelector? fun(value: any, index: any): any
function UniqueIterator.new(source, keySelector)
	assert(source, 'Source cannot be nil')
    local self = setmetatable(Iterator.new(), UniqueIterator)
    self._source = source
    self._keySelector = keySelector or function(x) return x end
    self._enumerator = nil
    self._set = nil
    return self
end

function UniqueIterator:moveNext()
    local state = self._state
    if state == -4 then
        return false
    end

    if state == 0 then
        self._enumerator = self._source:getEnumerator(self._isArray)
        -- If we have any items, create a hashtable. Otherwise abort.
        if self._enumerator:moveNext() == true then
            self._set = {}
            self._state = 1 state = 1
        else
            self:finalize()
            return false
        end
    end

    local enum = self._enumerator
    local set = self._set
    while true do
        if state == 1 then
            self._state = 2 state = 2
            local current = enum.current
            local index = enum.index
            local key = self._keySelector(current, index)
            if setAdd(key) == true then
                self.current = current
                self.index = index
                return true
            end
        end

        -- Try to grab a new item.
        if state == 2 then
            if enum:moveNext() == true then
                self._state = 1 state = 1
            else
                self:finalize()
                return false
            end
        end
    end
end

function UniqueIterator:finalize()
    if self._enumerator then
        self._enumerator:finalize()
    end

    Iterator.finalize(self)
end

function UniqueIterator:clone()
    return UniqueIterator.new(self._source, self._keySelector)
end

--Produces the set difference of two sequences according to a specified key selector function.
---@class DifferenceIterator : Iterator
---@field _first Enumerable
---@field _second any
---@field _keySelector fun(value: any, index: any): any
---@field _enumerator Enumerator
---@field _set table
local DifferenceIterator = Iterator.createIteratorSubclass()

---@param first Enumerable
---@param second any
---@param keySelector? fun(value: any, index: any): any
function DifferenceIterator.new(first, second, keySelector)
	assert(first, 'First table cannot be nil')
    assert(second, 'Second table cannot be nil')
    local self = setmetatable(Iterator.new(), DifferenceIterator)
    self._first = first
    self._second = second
    self._keySelector = keySelector or function(x) return x end
    self._enumerator = nil
    self._set = nil
    return self
end

function DifferenceIterator:moveNext()
    local state = self._state
    if state ~= 0 then
        if state ~= 1 then
            return false
        end
    else
        self._enumerator = self._first:getEnumerator(self._isArray)
        self._set = sourceToSet(self._second)
        self._state = 1 state = 1
    end

    local enum = self._enumerator
    local nextFunc = enum.moveNext
    local keySelector = self._keySelector
    local set = self._set
    while nextFunc(enum) == true do
        local current = enum.current
        local index = enum.index
        local key = keySelector(current, index)
        if setAdd(key) == true then
            self.current = current
            self.index = index
            return true
        end
    end

    self:finalize()
    return false
end

function DifferenceIterator:finalize()
    if self._enumerator then
        self._enumerator:finalize()
    end
end

function DifferenceIterator:clone()
    return DifferenceIterator.new(self._first, self._second, self._keySelector)
end

--Produces the set union of two sequences according to a specified key selector function.
---@class UnionIterator : Iterator
---@field _first Enumerable
---@field _second any
---@field _keySelector fun(value: any, index: any): any
---@field _enumerator Enumerator
---@field _set table
local UnionIterator = Iterator.createIteratorSubclass()

---@param first Enumerable
---@param second any
---@param keySelector? fun(value: any, index: any): any
function UnionIterator.new(first, second, keySelector)
	assert(first, 'First table cannot be nil')
    assert(second, 'Second table cannot be nil')
    local self = setmetatable(Iterator.new(), UnionIterator)
    self._first = first
    self._second = second
    self._keySelector = keySelector or function(x) return x end
    self._enumerator = nil
    self._set = nil
    return self
end

local function enumerateSource(self, _state)
    while self._enumerator:moveNext() do
        local current = self._enumerator.current
        local index = self._enumerator.index
        if self._set:add(self._keySelector(current, index)) == true then
            self.current = self._enumerator.current
            self.index = self._enumerator.index
            self._state = _state
            return true
        end
    end
end

function UnionIterator:moveNext()
    if self._state == -4 then
        return false
    end

    if self._state == 0 then
        self._enumerator = self._first:getEnumerator(self._isArray)
        self._set = {}
        self._state = 1
    end

    -- Process _first
    if self._state == 1 then
        if enumerateSource(self, 1) == true then
            return true
        end

        -- End _first enumeration
        self._state = 2
        self._enumerator:finalize()
        self._enumerator = getTableEnumerator(self._second, self._isArray)
    end

    -- Process _second
    if self._state == 2 then
        if enumerateSource(self, 2) == true then
            return true
        end

        -- End _first enumeration
        self._state = 3
    end

    self:finalize()
    return false
end

function UnionIterator:finalize()
    if self._enumerator then
        self._enumerator:finalize()
    end

    Iterator.finalize(self)
end

function UnionIterator:clone()
    return UnionIterator.new(self._first, self._second, self._keySelector)
end

--Produces the set intersection of two sequences according to a specified key selector function.
---@class IntersectIterator : Iterator
---@field _first Enumerable
---@field _second any
---@field _keySelector fun(value: any, index: any): any
---@field _enumerator Enumerator
---@field _set table
local IntersectIterator = Iterator.createIteratorSubclass()

---@param first Enumerable
---@param second any
---@param keySelector? fun(value: any, index: any): any
function IntersectIterator.new(first, second, keySelector)
	assert(first, 'First table cannot be nil')
    assert(second, 'Second table cannot be nil')
    local self = setmetatable(Iterator.new(), IntersectIterator)
    self._first = first
    self._second = second
    self._keySelector = keySelector or function(x) return x end
    self._enumerator = nil
    self._set = nil
    return self
end

function IntersectIterator:moveNext()
    if self._state ~= 0 then
        if self._state ~= 1 then
            return false
        end
    else
        self._enumerator = self._first:getEnumerator(self._isArray)
        self._set = sourceToSet(self._second)
        self._state = 1
    end
    while self._enumerator:moveNext() do
        local current = self._enumerator.current
        local index = self._enumerator.index
        local key = self._keySelector(current, index)
        if setRemove(self._set, key) == true then
            self.current = self._enumerator.current
            self.index = self._enumerator.index
            self._state = 1
            return true
        end
    end

    self:finalize()
    return false
end

function IntersectIterator:finalize()
    if self._enumerator then
        self._enumerator:finalize()
    end

    Iterator.finalize(self)
end

function IntersectIterator:clone()
    return IntersectIterator.new(self._first, self._second, self._keySelector)
end

--Applies a specified function to the corresponding elements of two sequences, producing a sequence of the results.
---@class ZipIterator : Iterator
---@field _first Enumerable
---@field _second any
---@field _resultSelector fun(left: any, right: any): any
---@field _enumerator1 Enumerator
---@field _enumerator2 Enumerator
local ZipIterator = Iterator.createIteratorSubclass()

---@param first Enumerable
---@param second any
---@param resultSelector? fun(left: any, right: any): any
function ZipIterator.new(first, second, resultSelector)
	assert(first, 'First table cannot be nil')
    assert(second, 'Second table cannot be nil')
    local self = setmetatable(Iterator.new(), ZipIterator)
    self._first = first
    self._second = second
    self._resultSelector = resultSelector or (function(a, b) return {a, b} end)
    self._enumerator1 = nil
    self._enumerator2 = nil
    return self
end

function ZipIterator:moveNext()
    if self._state ~= 0 then
        if self._state ~= 1 then
            return false
        end
    else
        self._enumerator1 = self._first:getEnumerator(self._isArray)
        self._enumerator2 = getTableEnumerator(self._second, self._isArray)
        self._state = -4
    end
    if self._enumerator1:moveNext() == true and self._enumerator2:moveNext() == true then
        self.current = self._resultSelector(self._enumerator1.current, self._enumerator2.current)
        self.index = self._enumerator1.index
        self._state = 1
        return true
    end

    self:finalize()
    return false
end

function ZipIterator:finalize()
    if self._enumerator1 then
        self._enumerator1:finalize()
    end
    if self._enumerator2 then
        self._enumerator2:finalize()
    end

    Iterator.finalize(self)
end

function ZipIterator:clone()
    return ZipIterator.new(self._first, self._second, self._resultSelector)
end

--Groups the elements of a sequence.
---@class GroupByIterator : Iterator
---@field _source Enumerable
---@field _keySelector fun(param: any): any
---@field _elementSelector fun(param: any): any
---@field _enumerator Enumerator
---@field _lookup Lookup
local GroupByIterator = Iterator.createIteratorSubclass()

---@param source Enumerable
---@param keySelector fun(param: any): any
---@param elementSelector? fun(param: any): any
function GroupByIterator.new(source, keySelector, elementSelector)
	assert(source, 'source cannot be nil')
    assert(keySelector, 'keySelector cannot be nil')
    local self = setmetatable(Iterator.new(), GroupByIterator)
    self._source = source
    self._keySelector = keySelector
    self._elementSelector = elementSelector or function(x) return x end
    self._enumerator = nil
    self._lookup = nil
    return self
end

function GroupByIterator:moveNext()
    if self._state ~= 0 then
        if self._state ~= 1 then
            return false
        end
    else
        self._lookup = Lookup.new(self._source, self._keySelector, self._elementSelector, self._isArray)
        self._enumerator = self._lookup:getEnumerator()
        self._state = 1
    end

    local enum = self._enumerator
    if enum:moveNext() == true then
        self.current = enum.current
        self.index = enum.index
        self._state = 1
        return true
    end

    self:finalize()
    return false
end

function GroupByIterator:finalize()
    if self._enumerator then
        self._enumerator:finalize()
    end

    Iterator.finalize(self)
end

function GroupByIterator:clone()
    return GroupByIterator.new(self._source, self._keySelector, self._elementSelector)
end

--Groups the elements of a sequence according to a specified key selector function and creates a result value from each group and its key. The elements of each group are projected by using a specified function.
---@class GroupByResultIterator : Iterator
---@field _source Enumerable
---@field _keySelector fun(param: any): any
---@field _elementSelector fun(param: any): any
---@field _resultSelector fun(key: any, grouping: Grouping): any
---@field _enumerator Enumerator
---@field _lookup Lookup
local GroupByResultIterator = Iterator.createIteratorSubclass()

---@param source Enumerable
---@param keySelector fun(param: any): any
---@param elementSelector? fun(param: any): any
---@param resultSelector fun(key: any, grouping: Grouping): any
function GroupByResultIterator.new(source, keySelector, elementSelector, resultSelector)
	assert(source, 'source cannot be nil')
    assert(keySelector, 'keySelector cannot be nil')
    assert(resultSelector, 'keySelector cannot be nil')
    local self = setmetatable(Iterator.new(), GroupByResultIterator)
    self._source = source
    self._keySelector = keySelector
    self._elementSelector = elementSelector or function(x) return x end
    self._resultSelector = resultSelector
    self._enumerator = nil
    self._lookup = nil
    return self
end

function GroupByResultIterator:moveNext()
    if self._state ~= 0 then
        if self._state ~= 1 then
            return false
        end
    else
        self._lookup = Lookup.new(self._source, self._keySelector, self._elementSelector, self._isArray)
        self._enumerator = self._lookup:getEnumerator()
        self._state = 1
    end

    local enumerator = self._enumerator
    if enumerator:moveNext() == true then
        -- Transform Grouping element for enumeration
        local gKey = enumerator.current.key
        local gElements = enumerator.current
        self.current = self._resultSelector(gKey, gElements)
        self.index = enumerator.index
        self._state = 1
        return true
    end

    self:finalize()
    return false
end

function GroupByResultIterator:finalize()
    if self._enumerator then
        self._enumerator:finalize()
    end

    Iterator.finalize(self)
end

function GroupByResultIterator:clone()
    return GroupByResultIterator.new(self._source, self._keySelector, self._elementSelector, self._resultSelector)
end


--Groups the elements of a sequence according to a specified key selector function and creates a result value from each group and its key. The elements of each group are projected by using a specified function.
---@class SortableIterator : Iterator
---@field _source Enumerable
---@field _keySelector fun(value: any, index: any): any|?
---@field _descendingSort boolean
---@field _parentIterator SortableIterator
---@field _enumerator Enumerator
local SortableIterator = Iterator.createIteratorSubclass()

---@param source Enumerable
---@param keySelector fun(value: any, index: any): any
---@param descendingSort? boolean
---@param parent? SortableIterator
---@return SortableIterator
function SortableIterator.new(source, keySelector, descendingSort, parent)
	assert(source, 'source cannot be nil')
    assert(keySelector, 'keySelector cannot be nil')
    assert(parent == nil or parent.getEnumerableSorters ~= nil, 'parent must be of type SortableIterator')
    if descendingSort == nil then
        descendingSort = false
    end

    local self = setmetatable(Iterator.new(), SortableIterator)
    self._source = source
    self._keySelector = keySelector
    self._descendingSort = descendingSort
    self._parentIterator = parent
    self._enumerator = nil
---@diagnostic disable-next-line: return-type-mismatch
    return self
end

local function createSortFunc(sorters)
    local sortCount = #sorters
    if sortCount > 1 then
        return function(a, b)
            for _, sorter in ipairs(sorters) do
                local selector = sorter._keySelector
                local descendingSort = sorter._descendingSort

                local aVal = selector(a)
                local bVal = selector(b)
                if aVal ~= bVal then
                    if descendingSort == true then
                        return aVal > bVal
                    else
                        return aVal < bVal
                    end
                end
            end
        end
    end

    -- Faster sort function if there's no nested sorts.
    local sorter = sorters[1]
    local selector = sorter._keySelector
    if sorter._descendingSort == true then
        return function(a, b) return selector(a) > selector(b) end
    else
        return function(a, b) return selector(a) < selector(b) end
    end
end

function SortableIterator:moveNext()
    if self._state ~= 0 then
        if self._state ~= 1 then
            return false
        end
    else
        -- Collect sorter meta information.
        local sorters = {}
        self:getEnumerableSorters(sorters)
        local sortFunc = createSortFunc(sorters)

        local enumerator = self._source:getEnumerator(self._isArray)
        local buffer = {}
        -- Process and sort all previous enumerators.
        while enumerator:moveNext() == true do
            table.insert(buffer, enumerator.current)
        end
        table.sort(buffer, sortFunc)
        self._enumerator = TableEnumerator.createForArray(buffer)
        self._state = 1
    end

    local enumerator = self._enumerator
    if enumerator:moveNext() == true then
        self.current = enumerator.current
        self.index = enumerator.index
        self._state = 1
        return true
    end

    self:finalize()
    return false
end

---Collects all sorter functions down the line.
---@param sorters table
---@return table
function SortableIterator:getEnumerableSorters(sorters)
    -- Get parent sorter first.
    if self._parentIterator ~= nil then
        self._parentIterator:getEnumerableSorters(sorters)
    end
    table.insert(sorters, self)

    return sorters
end

---@param keySelector fun(value: any, index: any): any
---@param descendingSort? boolean
---@return SortableIterator
function SortableIterator:createSortableIterator(keySelector, descendingSort)
    return SortableIterator.new(self._source, keySelector, descendingSort, self)
end

function SortableIterator:finalize()
    if self._enumerator then
        self._enumerator:finalize()
        self._enumerator = nil
    end
    Iterator.finalize(self)
end

function SortableIterator:clone()
    return SortableIterator.new(self._source, self._keySelector, self._descendingSort, self._parentIterator)
end

return {
    MapIterator = MapIterator,
    WhereIterator = WhereIterator,
    FlatMapIterator = FlatMapIterator,
    ConcatIterator = ConcatIterator,
    AppendIterator = AppendIterator,
    UniqueIterator = UniqueIterator,
    DifferenceIterator = DifferenceIterator,
    UnionIterator = UnionIterator,
    IntersectIterator = IntersectIterator,
    ZipIterator = ZipIterator,
    GroupByIterator = GroupByIterator,
    GroupByResultIterator = GroupByResultIterator,
    SortableIterator = SortableIterator,
}