Module:FunList/Iterators: Difference between revisions

no edit summary
m (Ricewind moved page Module:FunList/Enumerators to Module:FunList/Iterators without leaving a redirect)
No edit summary
Line 1: Line 1:
-- FORWARD DECLARED FUNCTIONS --
local Enumerable = require('Module:Enumerable')
local Lookup = require('Module:Lookup')
local TableEnumerator = require('Module:TableEnumerator')


-- Checks if the provided object matches the provided type.
---Creates a TableIterator Enumerable or returns the Enumerable if it already is one.
-- Returns true if object is of the provided type.
---We set the proper starting index here, since this only applies to tables!
-- (object, type). Returns boolean
---@param source any
local isType
---@param isArray any
---@return Enumerator
local function getTableEnumerator(source, isArray)
    -- Check the type before touching `source`: indexing nil, a number, a boolean
    -- or a function raises a generic "attempt to index" error, which would hide
    -- the descriptive message below.
    assert(type(source) == 'table', 'sourceTable must be either an Enumerable or table.')

    if source.getEnumerator ~= nil then
        -- The source already knows how to enumerate itself (Enumerable/Iterator).
        return source:getEnumerator(isArray)
    end

    -- Plain Lua table: wrap it in a TableEnumerator.
    return TableEnumerator.create(source, isArray)
end
 
-- Helper functions to create and manage a hashset structure.
-- These are used to distinguish elements for Union/Unique/Difference etc operations.
---Collects the values of `source` into a hashset (a table mapping each value to `true`).
---Accepts either an Enumerable, which is drained through its enumerator, or a plain
---table, whose values (not keys) are gathered with `pairs`.
---@param source any
---@return table
local function sourceToSet(source)
    local result = {}

    if not Enumerable.isEnumerable(source) then
        assert(type(source) == 'table', 'Source must be a table.')
        for _, value in pairs(source) do
            result[value] = true
        end
        return result
    end

    local enumerator = source:getEnumerator()
    while enumerator:moveNext() do
        result[enumerator.current] = true
    end
    return result
end


-- Returns a TableEnumerator from the provided object or
---@param set table
-- creates a new one if the object is not of type TableEnumerator
---@param item any
-- (object). Returns TableEnumerator
---@return boolean
local getTableEnumerator
local function setAdd(set, item)
    -- Hashset insert: returns true only when `item` was absent and has now been
    -- added; an existing entry (any non-nil value) is left untouched.
    if set[item] ~= nil then
        return false
    end
    set[item] = true
    return true
end


-- Attempts to add an object to the provided table as a hashset.
---@param set table
-- Returns True if the object was not already present.
---@param item any
-- (table, object). Returns True
---@return boolean
local addToSet
local function setRemove(set, item)
    -- Hashset delete: returns true when `item` had a (non-nil) entry that has
    -- now been cleared, false when there was nothing to remove.
    local present = set[item] ~= nil
    if present then
        set[item] = nil
    end
    return present
end


-- CLASS DEFINITIONS --
-- Defines the base class for all Iterator state machines. This is an Enumerable and Enumerator in one.
-- BASE ENUMERATOR CLASS --
---@class Iterator : Enumerable, Enumerator
---@class Enumerator
---@field current any
---@field current any
---@field index any
---@field index any
---@field state integer
---@field _state integer
---@field isArray boolean
---@field _isArray boolean
local Enumerator = {}
local Iterator = setmetatable({}, { __index = Enumerable })
local Enumerator_mt = {
local Iterator_mt = {
__index = Enumerator,
__index = Iterator,
__pairs = function(t) return t:getPairs() end,
__pairs = Enumerable.getPairs,
__ipairs = function(t) return t:getiPairs()
__ipairs = Enumerable.getiPairs
    end
}
}


---@return Enumerator
---@return Iterator
function Enumerator.new()
function Iterator.new()
local self = setmetatable({}, Enumerator_mt)
local self = setmetatable({}, Iterator_mt)
self.current = nil
self.current = nil
self.index = nil
self.index = nil
     self.state = -1
     self._state = -1
-- Assume by default we are not dealing with a simple array
-- Assume by default we are not dealing with a simple array
self.isArray = false
self._isArray = false
return self
return self
end
end


-- This is normally implemented by the Enumerator abstract class.
-- But to avoid double inheritance, we cheat a bit and implement it ourselves.
---@return boolean
---@return boolean
function Enumerator:moveNext()
function Iterator:moveNext()
     error('Abstract function must be overridden in derived class.')
     error('Abstract function must be overridden in derived class.')
end
end


---Returns this Enumerator or creates a new one if it has been used.
---@param isArray? boolean
---@return Enumerator
---@return Enumerator
function Enumerator:getEnumerator(isArray)
function Iterator:getEnumerator(isArray)
     -- The default state is -1 which signifies a Enumerator isn't used.
     -- The default _state is -1 which signifies a Iterator isn't used.
     local instance = (self.state == -1) and self or self:clone()
     local instance = (self._state == -1) and self or self:clone()
     instance.isArray = isArray
     instance._isArray = isArray or self._isArray
     instance.state = 0
     instance._state = 0
     return instance
     return instance
end
end


---@return Enumerator
---@return Iterator
function Enumerator:clone()
function Iterator:clone()
     error('Abstract function must be overridden in derived class.')
     error('Abstract function must be overridden in derived class.')
end
end


function Enumerator:finalize()
function Iterator:finalize()
     -- Signals invalid state.
     -- Signals invalid _state.
     self.state = -4
     self._state = -4
end
end


-- Hooks the moveNext function into the Lua 'pairs' function
function Iterator.createIteratorSubclass()
local function overridePairs(enum, startIndex)
    local derivedClass = setmetatable({}, { __index = Iterator })
-- Get or create clean enumerator. This ensures the state is 0.
    derivedClass.__index = derivedClass
local new = enum:getEnumerator(startIndex == 0)
    derivedClass.__pairs  = Enumerable.getPairs
new.current = nil
    derivedClass.__ipairs = Enumerable.getiPairs
new.index = startIndex
    return derivedClass
local function iterator(t, k)
if new:moveNext() == true then
return new.index, new.current
end
return nil, nil
end
 
return iterator, new, new.index
end
end


-- Manual override for iterating over the Enumerator using pairs()
--Projects each element of a sequence into a new form.
function Enumerator:getPairs()
---@class MapIterator : Iterator
    return overridePairs(self, nil)
---@field _source Enumerable
end
---@field _selector fun(value: any, index: any): any
 
---@field _position integer
-- Manual override for iterating over the Enumerator using ipairs()
---@field _enumerator Enumerator
function Enumerator:getiPairs()
local MapIterator = Iterator.createIteratorSubclass()
    return overridePairs(self, 0)
end
 
-- TABLE ENUMERATOR CLASS --
-- This is essentially a wrapper for the table object,
-- since it provides no state machine esque iteration out of the box
---@class TableEnumerator : Enumerator
---@field state integer
---@field tbl table
local TableEnumerator = setmetatable({}, { __index = Enumerator })
TableEnumerator.__index = TableEnumerator
TableEnumerator.__pairs =  Enumerator_mt.__pairs
TableEnumerator.__ipairs = Enumerator_mt.__ipairs
 
function TableEnumerator.new(tbl)
    local self = setmetatable(Enumerator.new(), TableEnumerator)
    self.tbl = tbl or {} -- Allow creation of empty enumerable


---@param source Enumerable
---@param selector fun(value: any, index: any): any
function MapIterator.new(source, selector)
assert(source, 'Source cannot be nil')
assert(selector, 'Selector cannot be nil')
    local self = setmetatable(Iterator.new(), MapIterator)
    self._source = source
    self._selector = selector
    self._enumerator = nil
    self._position = 0
     return self
     return self
end
end


function TableEnumerator:moveNext()
function MapIterator:moveNext()
     if self.state == 0 then
     local state = self._state
         self.state = 1
    if state ~= 0 then
self.index = self.isArray and 0 or nil
        if state ~= 1 then
            return false
        end
    else
         self._state = 1 state = 1
    self._position = 0
self._enumerator = self._source:getEnumerator(self._isArray)
     end
     end


     if self.isArray == true then
     local _enumerator = self._enumerator
         -- Iterate using ipairs, starting from index 1
    if _enumerator:moveNext() == true then
         self.index = self.index + 1
         -- 1 based index because Lua
         self.current = self.tbl[self.index]
         local pos = self._position + 1
         return self.current ~= nil
         local index = _enumerator.index
    else
         local current = self._selector(_enumerator.current, pos)
        -- Iterate using pairs
         assert(current, 'Selected value must be non-nil')
        local key = self.index
         self.index = index
         key = next(self.tbl, key)
         self.current = current
         self.index = key
        self._position = pos
         if key ~= nil then
        return true
            self.current = self.tbl[key]
            return true
        end
     end
     end
    self:finalize()


    return false
end
function TableEnumerator:clone()
    return TableEnumerator.new(self.tbl)
end
-- SELECT ENUMERATOR --
---@class MapEnumerator : Enumerator
---@field state integer
---@field source Enumerator
---@field selector Enumerator
local MapEnumerator = setmetatable({}, { __index = Enumerator })
MapEnumerator.__index = MapEnumerator
MapEnumerator.__pairs = Enumerator_mt.__pairs
MapEnumerator.__ipairs = Enumerator_mt.__ipairs
function MapEnumerator.new(source, selector)
assert(source, 'Source cannot be nil')
assert(selector, 'Selector cannot be nil')
    local self = setmetatable(Enumerator.new(), MapEnumerator)
    self.source = source
    self.selector = selector
    self.enumerator = nil
    self.position = 0
    return self
end
function MapEnumerator:moveNext()
if self.state == 0 then
self.state = 1
self.position = 0
self.enumerator = self.source:getEnumerator(self.isArray)
end
if self.state == 1 then
local enumerator = self.enumerator
if enumerator:moveNext() == true then
local sourceElement = enumerator.current
self.index = enumerator.index -- Preserve index
self.position = self.position + 1
self.current = self.selector(sourceElement, self.position)
assert(self.current, 'Selected value must be non-nil')
return true
        end
        self:finalize()
end
return false
return false
end
end


function MapEnumerator:finalize()
function MapIterator:finalize()
     if self.enumerator then
     if self._enumerator then
         self.enumerator:finalize()
         self._enumerator:finalize()
     end
     end
     Enumerator.finalize(self)
     Iterator.finalize(self)
end
end


function MapEnumerator:clone()
function MapIterator:clone()
     return MapEnumerator.new(self.source, self.selector)
     return MapIterator.new(self._source, self._selector)
end
end


-- WHERE ENUMERATOR --
--Filters a sequence of values based on a predicate.
---@class WhereEnumerator : Enumerator
---@class WhereIterator : Iterator
---@field source Enumerator
---@field _source Enumerable
---@field predicate function
---@field _predicate fun(value: any, index: any): boolean
local WhereEnumerator = setmetatable({}, { __index = Enumerator })
---@field _enumerator Enumerator
WhereEnumerator.__index = WhereEnumerator
local WhereIterator = Iterator.createIteratorSubclass()
WhereEnumerator.__pairs = Enumerator_mt.__pairs
WhereEnumerator.__ipairs = Enumerator_mt.__ipairs


function WhereEnumerator.new(source, predicate)
---@param source Enumerable
---@param predicate fun(value: any, index: any): boolean
function WhereIterator.new(source, predicate)
assert(source, 'Source cannot be nil')
assert(source, 'Source cannot be nil')
assert(predicate, 'Predicate cannot be nil')
assert(predicate, 'Predicate cannot be nil')
     local self = setmetatable(Enumerator.new(), WhereEnumerator)
     local self = setmetatable(Iterator.new(), WhereIterator)
     self.source = source
     self._source = source
     self.predicate = predicate
     self._predicate = predicate
     self.enumerator = nil
     self._enumerator = nil
     return self
     return self
end
end


function WhereEnumerator:moveNext()
function WhereIterator:moveNext()
if self.state == 0 then
    local state = self._state
self.state = 1
    if state ~= 0 then
self.position = 0
        if state ~= 1 then
self.enumerator = self.source:getEnumerator(self.isArray)
            return false
end
        end
    else
self._state = 1 state = 1
self._enumerator = self._source:getEnumerator(self._isArray)
    end


if self.state == 1 then
    -- Cache some class lookups since these are likely
local enumerator = self.enumerator
    -- to be accessed more than once per moveNext
while enumerator:moveNext() == true do
    local enum = self._enumerator
local sourceElement = enumerator.current
    local nextFunc = enum.moveNext
            local sourceIndex = enumerator.index
    local predicate = self._predicate
if self.predicate(sourceElement, sourceIndex) == true then
    while nextFunc(enum) == true do
self.index = sourceIndex
        local sourceElement = enum.current
self.current = sourceElement
        local sourceIndex = enum.index
return true
        if predicate(sourceElement, sourceIndex) == true then
end
            self.index = sourceIndex
end
            self.current = sourceElement
 
            return true
        self:finalize()
        end
end
    end
    self:finalize()


return false
return false
end
end


function WhereEnumerator:finalize()
function WhereIterator:finalize()
     if self.enumerator then
     if self._enumerator then
         self.enumerator:finalize()
         self._enumerator:finalize()
     end
     end
     Enumerator.finalize(self)
     Iterator.finalize(self)
end
end


function WhereEnumerator:clone()
function WhereIterator:clone()
     return WhereEnumerator.new(self.source, self.predicate)
     return WhereIterator.new(self._source, self._predicate)
end
end


-- FLATMAP (SELECTMANY) ENUMERATOR --
--Projects each element of a sequence to an Enumerable and flattens the resulting sequences into one sequence.
---@class FlatMapEnumerator : Enumerator
---@class FlatMapIterator : Iterator
---@field source Enumerator
---@field _source Enumerable
---@field selector function
---@field _selector fun(value: any, index: integer): any
---@field position integer
---@field _position integer
local FlatMapEnumerator = setmetatable({}, { __index = Enumerator })
---@field _enumerator Enumerator
FlatMapEnumerator.__index = FlatMapEnumerator
---@field _sourceEnumerator Enumerator
FlatMapEnumerator.__pairs = Enumerator_mt.__pairs
local FlatMapIterator = Iterator.createIteratorSubclass()
FlatMapEnumerator.__ipairs = Enumerator_mt.__ipairs


function FlatMapEnumerator.new(source, selector)
---@param source Enumerable
---@param selector fun(value: any, index: integer): any
function FlatMapIterator.new(source, selector)
assert(source, 'Source cannot be nil')
assert(source, 'Source cannot be nil')
assert(selector, 'Selector cannot be nil')
assert(selector, 'Selector cannot be nil')
     local self = setmetatable(Enumerator.new(), FlatMapEnumerator)
     local self = setmetatable(Iterator.new(), FlatMapIterator)
     self.source = source
     self._source = source
     self.selector = selector
     self._selector = selector
     self.position = 0
     self._position = 0
     self.enumerator = nil -- Enumerator of the source Enumerable
     self._enumerator = nil -- Iterator of the _source Enumerable
     self.sourceEnumerator = nil  -- Enumerator of the nested Enumerable
     self._sourceEnumerator = nil  -- Iterator of the nested Enumerable
     return self
     return self
end
end


function FlatMapEnumerator:moveNext()
function FlatMapIterator:moveNext()
     if self.state == -4 then
     local state = self._state
    if state == -4 then
         return false
         return false
     end
     end


     -- Setup state
     -- Setup _state
     if self.state == 0 then
     if state == 0 then
         self.position = 0
         self._position = 0
         self.enumerator = self.source:getEnumerator(self.isArray)
         self._enumerator = self._source:getEnumerator(self._isArray)
         self.state = 3 -- signal to get (first) nested enumerator
         self._state = 3 state = 3 -- signal to get (_first) nested _enumerator
     end
     end
     while true do
     while true do
         -- Grab next value from nested enumerator
         -- Grab next value from nested _enumerator
         if self.state == 4 then
         if state == 4 then
             if self.sourceEnumerator:moveNext() then
             local sourceEnum = self._sourceEnumerator
                 self.current = self.sourceEnumerator.current
            if sourceEnum:moveNext() then
                 self.index = self.sourceEnumerator.index
                 self.current = sourceEnum.current
                 self.state = 4 -- signal to get next item
                 self.index = sourceEnum.index
                 self._state, state = 4, 4 -- signal to get next item
                 return true
                 return true
             else
             else
                 -- Cleanup nested enumerator
                 -- Cleanup nested _enumerator
                 self.sourceEnumerator:finalize()
                 self._sourceEnumerator:finalize()
                 self.state = 3 -- signal to get next enumerator
                 self._state = 3 state = 3 -- signal to get next _enumerator
             end
             end
         end
         end


-- Grab nest nested enumerator
-- Grab nest nested _enumerator
         if self.state == 3 then
         if state == 3 then
             if self.enumerator:moveNext() then
             if self._enumerator:moveNext() then
                 local current = self.enumerator.current
                 local current = self._enumerator.current
                 self.position = self.position + 1
                 local pos = self._position + 1


                 local sourceTable = self.selector(current, self.position)
                 local sourceTable = self._selector(current, pos)
                 -- Nested tables are never treated as arrays.
                 -- Nested tables are never treated as arrays.
                 self.sourceEnumerator = getTableEnumerator(sourceTable, false)
                 self._sourceEnumerator = getTableEnumerator(sourceTable, false)
                 self.state = 4 -- signal to get next item
                 self._position = pos
                self._state = 4 state = 4 -- signal to get next item
             else
             else
             -- enumerator doesn't have any more nested enumerators.
             -- _enumerator doesn't have any more nested enumerators.
                 self:finalize()
                 self:finalize()
                 return false
                 return false
Line 316: Line 311:
end
end


function FlatMapEnumerator:finalize()
function FlatMapIterator:finalize()
     if self.enumerator then
     if self._enumerator then
         self.enumerator:finalize()
         self._enumerator:finalize()
     end
     end
     if self.sourceEnumerator then
     if self._sourceEnumerator then
         self.sourceEnumerator:finalize()
         self._sourceEnumerator:finalize()
     end
     end


     Enumerator.finalize(self)
     Iterator.finalize(self)
end
end


function FlatMapEnumerator:clone()
function FlatMapIterator:clone()
return FlatMapEnumerator.new(self.source, self.selector)
return FlatMapIterator.new(self._source, self._selector)
end
end


-- CONCAT ENUMERATOR --
--Concatenates two sequences.
---@class ConcatEnumerator : Enumerator
---@class ConcatIterator : Iterator
---@field first Enumerator
---@field _first Enumerable
---@field second Enumerator
---@field _second any
local ConcatEnumerator = setmetatable({}, { __index = Enumerator })
---@field _enumerator Enumerator
ConcatEnumerator.__index = ConcatEnumerator
local ConcatIterator = Iterator.createIteratorSubclass()
ConcatEnumerator.__pairs = Enumerator_mt.__pairs
ConcatEnumerator.__ipairs = Enumerator_mt.__ipairs


function ConcatEnumerator.new(first, second)
---@param first Enumerable
---@param second any
function ConcatIterator.new(first, second)
assert(first, 'First cannot be nil')
assert(first, 'First cannot be nil')
assert(second, 'Second cannot be nil')
assert(second, 'Second cannot be nil')
     local self = setmetatable(Enumerator.new(), ConcatEnumerator)
     local self = setmetatable(Iterator.new(), ConcatIterator)
     self.first = first
     self._first = first
     self.second = second
     self._second = second
     self.enumerator = nil
     self._enumerator = nil
     return self
     return self
end
end
Line 352: Line 347:
-- Function to grab enumerators in order. We currently only have two so this
-- Function to grab enumerators in order. We currently only have two so this
-- seems a bit redundant. But it would allow to add N-amount of enumerables.
-- seems a bit redundant. But it would allow to add N-amount of enumerables.
function ConcatEnumerator:getEnumerable(index)
function ConcatIterator:getEnumerable(index)
if index == 0 then
if index == 0 then
return self.first
return self._first
elseif index == 1 then
elseif index == 1 then
return self.second
return self._second
else
else
return nil
return nil
Line 362: Line 357:
end
end


function ConcatEnumerator:moveNext()
function ConcatIterator:moveNext()
     if self.state == -4 then
     local state = self._state
    if state == -4 then
         return false
         return false
     end
     end


if self.state == 0 then
if state == 0 then
self.enumerator = self:getEnumerable(self.state)
self._enumerator = self:getEnumerable(state)
:getEnumerator(self.isArray)
:getEnumerator(self._isArray)
self.state = 1
self._state = 1 state = 1
end
end


if self.state > 0 then
if state > 0 then
while true do
while true do
if self.enumerator:moveNext() == true then
            local enum = self._enumerator
self.index = self.enumerator.index
if enum:moveNext() == true then
self.current = self.enumerator.current
self.index = enum.index
self.current = enum.current
return true
return true
end
end


self.state = self.state + 1
            state = state + 1
local next = self:getEnumerable(self.state - 1)
            self._state = state
local next = self:getEnumerable(state - 1)
if next ~= nil then
if next ~= nil then
                 -- Cleanup previous enumerator
                 -- Cleanup previous _enumerator
                 self.enumerator:finalize()
                 self._enumerator:finalize()
self.enumerator = getTableEnumerator(next, self.isArray)
self._enumerator = getTableEnumerator(next, self._isArray)
else
else
                 self:finalize()
                 self:finalize()
Line 397: Line 395:
end
end


function ConcatEnumerator:finalize()
function ConcatIterator:finalize()
     if self.enumerator then
     if self._enumerator then
         self.enumerator:finalize()
         self._enumerator:finalize()
     end
     end


     Enumerator.finalize(self)
     Iterator.finalize(self)
end
end


function ConcatEnumerator:clone()
function ConcatIterator:clone()
     return ConcatEnumerator.new(self.first, self.second)
     return ConcatIterator.new(self._first, self._second)
end
end


-- APPEND ENUMERATOR --
--Appends a value to the end or start of the sequence.
---@class AppendEnumerator : Enumerator
---@class AppendIterator : Iterator
---@field source Enumerator
---@field _source Enumerable
---@field item any
---@field _item any
---@field itemIndex any
---@field _itemIndex any
---@field append boolean
---@field _append boolean
---@field enumerator Enumerator
---@field _enumerator Enumerator
local AppendEnumerator = setmetatable({}, { __index = Enumerator })
local AppendIterator = Iterator.createIteratorSubclass()
AppendEnumerator.__index = AppendEnumerator
AppendEnumerator.__pairs = Enumerator_mt.__pairs
AppendEnumerator.__ipairs = Enumerator_mt.__ipairs


function AppendEnumerator.new(source, item, itemIndex, append)
---@param source Enumerable
---@param item any
---@param itemIndex? any
---@param append? boolean
function AppendIterator.new(source, item, itemIndex, append)
assert(source, 'Source cannot be nil')
assert(source, 'Source cannot be nil')
assert(item, 'Item cannot be nil')
assert(item, 'Item cannot be nil')
     local self = setmetatable(Enumerator.new(), AppendEnumerator)
     local self = setmetatable(Iterator.new(), AppendIterator)
     self.source = source
     self._source = source
     self.item = item
     self._item = item
     -- Index needs *some* value, otherwise iteration stops.
     -- Index needs *some* value, otherwise iteration stops.
     if itemIndex == nil then self.itemIndex = item else self.itemIndex = itemIndex end
     if itemIndex == nil then self._itemIndex = item else self._itemIndex = itemIndex end
     if append == nil then self.append = true else self.append = append end
     if append == nil then self._append = true else self._append = append end
     self.enumerator = nil
     self._enumerator = nil
     return self
     return self
end
end


function AppendEnumerator:moveNext()
function AppendIterator:moveNext()
     if self.state == 0 then
     local state = self._state
         self.state = 1
    if state == 0 then
         self._state = 1 state = 1
         -- Put the item in front, as a prepend.
         -- Put the item in front, as a prepend.
         if self.append == false then
         if self._append == false then
             self.current = self.item
             self.current = self._item
             self.index = self.itemIndex
             self.index = self._itemIndex
             return true
             return true
         end
         end
     end
     end


     -- Grab the source enumerator
     -- Grab the _source _enumerator
     if self.state == 1 then
     if state == 1 then
         self.enumerator = self.source:getEnumerator(self.isArray)
         self._enumerator = self._source:getEnumerator(self._isArray)
         self.state = 2
         self._state = 2 state = 2
     end
     end


     if self.state == 2 then
     if state == 2 then
         if self.enumerator:moveNext() then
         local enum = self._enumerator
             self.current = self.enumerator.current
        if enum:moveNext() then
             self.index = self.enumerator.index
             self.current = enum.current
             self.index = enum.index
             return true
             return true
         else
         else
Line 460: Line 461:
         end
         end


         if self.append == true then
         if self._append == true then
             self.current = self.item
             self.current = self._item
             self.index = self.itemIndex
             self.index = self._itemIndex
             return true
             return true
         end
         end
Line 470: Line 471:
end
end


function AppendEnumerator:finalize()
function AppendIterator:finalize()
     if self.enumerator then
     if self._enumerator then
         self.enumerator:finalize()
         self._enumerator:finalize()
     end
     end


     Enumerator.finalize(self)
     Iterator.finalize(self)
end
end


function AppendEnumerator:clone()
function AppendIterator:clone()
     return AppendEnumerator.new(self.source, self.item, self.itemIndex, self.append)
     return AppendIterator.new(self._source, self._item, self._itemIndex, self._append)
end
end


-- UNIQUE (DISTINCT) ENUMERATOR --
--Returns unique (distinct) elements from a sequence according to a specified key selector function.
---@class UniqueEnumerator : Enumerator
---@class UniqueIterator : Iterator
---@field source Enumerator
---@field _source Enumerable
---@field selector function
---@field _keySelector fun(value: any, index: any): any
---@field enumerator Enumerator
---@field _enumerator Enumerator
---@field set table
---@field _set table
local UniqueEnumerator = setmetatable({}, { __index = Enumerator })
local UniqueIterator = Iterator.createIteratorSubclass()
UniqueEnumerator.__index = UniqueEnumerator
UniqueEnumerator.__pairs = Enumerator_mt.__pairs
UniqueEnumerator.__ipairs = Enumerator_mt.__ipairs


function UniqueEnumerator.new(source, selector)
---@param source Enumerable
---@param keySelector? fun(value: any, index: any): any
function UniqueIterator.new(source, keySelector)
assert(source, 'Source cannot be nil')
assert(source, 'Source cannot be nil')
     local self = setmetatable(Enumerator.new(), UniqueEnumerator)
     local self = setmetatable(Iterator.new(), UniqueIterator)
     self.source = source
     self._source = source
     self.selector = selector
     self._keySelector = keySelector or function(x) return x end
     self.enumerator = nil
     self._enumerator = nil
     self.set = nil
     self._set = nil
     return self
     return self
end
end


function UniqueEnumerator:moveNext()
function UniqueIterator:moveNext()
     if self.state == -4 then
     local state = self._state
    if state == -4 then
         return false
         return false
     end
     end


     if self.state == 0 then
     if state == 0 then
         self.enumerator = self.source:getEnumerator(self.isArray)
         self._enumerator = self._source:getEnumerator(self._isArray)
         -- If we have any items, create a hashtable. Otherwise abort.
         -- If we have any items, create a hashtable. Otherwise abort.
         if self.enumerator:moveNext() == true then
         if self._enumerator:moveNext() == true then
             self.set = {}
             self._set = {}
             self.state = 1
             self._state = 1 state = 1
         else
         else
             self:finalize()
             self:finalize()
Line 520: Line 521:
     end
     end


    local enum = self._enumerator
    local set = self._set
     while true do
     while true do
         if self.state == 1 then
         if state == 1 then
             self.state = 2
             self._state = 2 state = 2
             local current = self.enumerator.current
             local current = enum.current
             local index = self.enumerator.index
             local index = enum.index
             -- Manipulate the item if we have a selector (DistinctBy)
             local key = self._keySelector(current, index)
            if self.selector then
             if setAdd(key) == true then
                current = self.selector(current, index)
                 self.current = current
            end
                 self.index = index
 
             if addToSet(self.set, current) then
                 self.current = self.enumerator.current
                 self.index = self.enumerator.index
                 return true
                 return true
             end
             end
Line 538: Line 537:


         -- Try to grab a new item.
         -- Try to grab a new item.
         if self.state == 2 then
         if state == 2 then
             if self.enumerator:moveNext() == true then
             if enum:moveNext() == true then
                 self.state = 1
                 self._state = 1 state = 1
             else
             else
                 self:finalize()
                 self:finalize()
Line 549: Line 548:
end
end


function UniqueEnumerator:finalize()
function UniqueIterator:finalize()
     if self.enumerator then
     if self._enumerator then
         self.enumerator:finalize()
         self._enumerator:finalize()
     end
     end


     Enumerator.finalize(self)
     Iterator.finalize(self)
end
end


function UniqueEnumerator:clone()
function UniqueIterator:clone()
     return UniqueEnumerator.new(self.source, self.selector)
     return UniqueIterator.new(self._source, self._keySelector)
end
end


-- EXCEPT ENUMERATOR --
--Produces the set difference of two sequences according to a specified key selector function.
---@class DifferenceEnumerator : Enumerator
---@class DifferenceIterator : Iterator
---@field first Enumerator
---@field _first Enumerable
---@field second any
---@field _second any
---@field selector function
---@field _keySelector fun(value: any, index: any): any
---@field enumerator Enumerator
---@field _enumerator Enumerator
---@field set table
---@field _set table
local DifferenceEnumerator = setmetatable({}, { __index = Enumerator })
local DifferenceIterator = Iterator.createIteratorSubclass()
DifferenceEnumerator.__index = DifferenceEnumerator
DifferenceEnumerator.__pairs = Enumerator_mt.__pairs
DifferenceEnumerator.__ipairs = Enumerator_mt.__ipairs


function DifferenceEnumerator.new(first, second, selector)
---@param first Enumerable
---@param second any
---@param keySelector? fun(value: any, index: any): any
function DifferenceIterator.new(first, second, keySelector)
assert(first, 'First table cannot be nil')
assert(first, 'First table cannot be nil')
     assert(second, 'Second table cannot be nil')
     assert(second, 'Second table cannot be nil')
     local self = setmetatable(Enumerator.new(), DifferenceEnumerator)
     local self = setmetatable(Iterator.new(), DifferenceIterator)
     self.first = first
     self._first = first
     self.second = second
     self._second = second
     self.selector = selector
     self._keySelector = keySelector or function(x) return x end
     self.enumerator = nil
     self._enumerator = nil
     self.set = nil
     self._set = nil
     return self
     return self
end
end


local function createSet(other)
function DifferenceIterator:moveNext()
     local set = {}
     local state = self._state
     if isType(other, Enumerator) then
     if state ~= 0 then
         for _, v in other:getPairs() do
         if state ~= 1 then
             set[v] = true
             return false
         end
         end
     else
     else
         assert(type(other) == 'table', 'Source must be a table.')
         self._enumerator = self._first:getEnumerator(self._isArray)
         for _, v in pairs(other) do
        self._set = sourceToSet(self._second)
             set[v] = true
        self._state = 1 state = 1
    end
 
    local enum = self._enumerator
    local nextFunc = enum.moveNext
    local keySelector = self._keySelector
    local set = self._set
    while nextFunc(enum) == true do
        local current = enum.current
         local index = enum.index
        local key = keySelector(current, index)
        if setAdd(key) == true then
            self.current = current
             self.index = index
            return true
         end
         end
     end
     end


     return set
    self:finalize()
     return false
end
 
---Finalizes the enumerator over the first sequence (if iteration has started)
---and then runs the base Iterator cleanup.
---Every other iterator in this module calls Iterator.finalize(self); it was missing
---here. Presumably the base call moves `_state` out of 1 so that later moveNext()
---calls return false instead of advancing an already-finalized enumerator
---(Iterator.finalize is not visible in this chunk — confirm).
function DifferenceIterator:finalize()
    if self._enumerator then
        self._enumerator:finalize()
    end
    Iterator.finalize(self)
end
 
---Builds a fresh, unstarted DifferenceIterator with the same inputs and key selector.
---@return DifferenceIterator
function DifferenceIterator:clone()
    local first, second = self._first, self._second
    local keySelector = self._keySelector
    return DifferenceIterator.new(first, second, keySelector)
end
 
--Produces the set union of two sequences according to a specified key selector function.
---@class UnionIterator : Iterator
---@field _first Enumerable
---@field _second any
---@field _keySelector fun(value: any, index: any): any
---@field _enumerator Enumerator
---@field _set table
local UnionIterator = Iterator.createIteratorSubclass()

---Creates a UnionIterator. No enumeration happens until the first moveNext() call.
---@param first Enumerable sequence whose distinct elements are yielded first
---@param second any Enumerable or table whose not-yet-seen elements are yielded afterwards
---@param keySelector? fun(value: any, index: any): any defaults to the identity function
---@return UnionIterator
function UnionIterator.new(first, second, keySelector)
    assert(first, 'First table cannot be nil')
    assert(second, 'Second table cannot be nil')

    local iterator = setmetatable(Iterator.new(), UnionIterator)
    iterator._first = first
    iterator._second = second
    if not keySelector then
        keySelector = function(x) return x end
    end
    iterator._keySelector = keySelector
    -- Both are created lazily by the first moveNext().
    iterator._enumerator = nil
    iterator._set = nil
    return iterator
end
 
---Advances the union's current source enumerator until it finds an element whose
---key has not been seen yet, then publishes that element as the iterator's
---current/index and records `_state` on the iterator.
---@param self UnionIterator
---@param _state integer state stored on the iterator when an element is yielded
---@return boolean found true when a new distinct element was produced, false once the source is exhausted
local function enumerateSource(self, _state)
    local enum = self._enumerator
    local keySelector = self._keySelector
    local set = self._set
    while enum:moveNext() do
        local current = enum.current
        local index = enum.index
        -- `_set` is a plain table (UnionIterator:moveNext initializes it to {}), so it
        -- has no `add` method; record keys through the file's setAdd helper instead.
        if setAdd(set, keySelector(current, index)) == true then
            self.current = current
            self.index = index
            self._state = _state
            return true
        end
    end
    return false
end
end


function DifferenceEnumerator:moveNext()
function UnionIterator:moveNext()
     if self.state == -4 then
     if self._state == -4 then
         return false
         return false
     end
     end


     if self.state == 0 then
     if self._state == 0 then
         self.enumerator = self.first:getEnumerator(self.isArray)
         self._enumerator = self._first:getEnumerator(self._isArray)
         self.set = createSet(self.second)
         self._set = {}
         self.state = 1
         self._state = 1
     end
     end


     while self.enumerator:moveNext() do
     -- Process _first
        local current = self.enumerator.current
    if self._state == 1 then
        local index = self.enumerator.index
         if enumerateSource(self, 1) == true then
         if self.selector ~= nil then  
             return true
             current = self.selector(current, index)
         end
         end
         if addToSet(self.set, current) then
 
            self.current = self.enumerator.current
        -- End _first enumeration
            self.index = self.enumerator.index
        self._state = 2
            self.state = 1
         self._enumerator:finalize()
        self._enumerator = getTableEnumerator(self._second, self._isArray)
    end
 
    -- Process _second
    if self._state == 2 then
        if enumerateSource(self, 2) == true then
             return true
             return true
         end
         end
        -- End _first enumeration
        self._state = 3
     end
     end


Line 630: Line 699:
end
end


function DifferenceEnumerator:finalize()
function UnionIterator:finalize()
     if self.enumerator then
     if self._enumerator then
         self.enumerator:finalize()
         self._enumerator:finalize()
     end
     end
    Iterator.finalize(self)
end
end


function DifferenceEnumerator:clone()
function UnionIterator:clone()
     return DifferenceEnumerator.new(self.first, self.second, self.selector)
     return UnionIterator.new(self._first, self._second, self._keySelector)
end
end


-- UNION ENUMERATOR --
--Produces the set intersection of two sequences according to a specified key selector function.
---@class UnionEnumerator : Enumerator
---@class IntersectIterator : Iterator
---@field first Enumerator
---@field _first Enumerable
---@field second any
---@field _second any
---@field selector function
---@field _keySelector fun(value: any, index: any): any
---@field enumerator Enumerator
---@field _enumerator Enumerator
---@field set table
---@field _set table
local UnionEnumerator = setmetatable({}, { __index = Enumerator })
local IntersectIterator = Iterator.createIteratorSubclass()
UnionEnumerator.__index = UnionEnumerator
UnionEnumerator.__pairs = Enumerator_mt.__pairs
UnionEnumerator.__ipairs = Enumerator_mt.__ipairs


function UnionEnumerator.new(first, second, selector)
---@param first Enumerable
---@param second any
---@param keySelector? fun(value: any, index: any): any
function IntersectIterator.new(first, second, keySelector)
assert(first, 'First table cannot be nil')
assert(first, 'First table cannot be nil')
     assert(second, 'Second table cannot be nil')
     assert(second, 'Second table cannot be nil')
     local self = setmetatable(Enumerator.new(), UnionEnumerator)
     local self = setmetatable(Iterator.new(), IntersectIterator)
     self.first = first
     self._first = first
     self.second = second
     self._second = second
     self.selector = selector
     self._keySelector = keySelector or function(x) return x end
     self.enumerator = nil
     self._enumerator = nil
     self.set = nil
     self._set = nil
     return self
     return self
end
end


local function enumerateSource(self, state)
function IntersectIterator:moveNext()
    while self.enumerator:moveNext() do
    if self._state ~= 0 then
        local current = self.enumerator.current
         if self._state ~= 1 then
        local index = self.enumerator.index
             return false
         if self.selector ~= nil then
             current = self.selector(current, index)
         end
         end
         if addToSet(self.set, current) == true then
    else
             self.current = self.enumerator.current
        self._enumerator = self._first:getEnumerator(self._isArray)
             self.index = self.enumerator.index
        self._set = sourceToSet(self._second)
             self.state = state
        self._state = 1
    end
    while self._enumerator:moveNext() do
        local current = self._enumerator.current
        local index = self._enumerator.index
        local key = self._keySelector(current, index)
         if setRemove(self._set, key) == true then
             self.current = self._enumerator.current
             self.index = self._enumerator.index
             self._state = 1
             return true
             return true
         end
         end
     end
     end
    self:finalize()
    return false
end
---Finalizes the enumerator over the first sequence (when one was opened) and
---then performs the base Iterator cleanup.
function IntersectIterator:finalize()
    local enumerator = self._enumerator
    if enumerator then
        enumerator:finalize()
    end
    Iterator.finalize(self)
end
end


function UnionEnumerator:moveNext()
function IntersectIterator:clone()
     if self.state == -4 then
    return IntersectIterator.new(self._first, self._second, self._keySelector)
         return false
end
 
--Applies a specified function to the corresponding elements of two sequences, producing a sequence of the results.
---@class ZipIterator : Iterator
---@field _first Enumerable
---@field _second any
---@field _resultSelector fun(left: any, right: any): any
---@field _enumerator1 Enumerator
---@field _enumerator2 Enumerator
local ZipIterator = Iterator.createIteratorSubclass()

---Creates a ZipIterator that combines elements of `first` and `second` in lockstep.
---@param first Enumerable
---@param second any Enumerable or table
---@param resultSelector? fun(left: any, right: any): any defaults to building the table {left, right}
---@return ZipIterator
function ZipIterator.new(first, second, resultSelector)
    assert(first, 'First table cannot be nil')
    assert(second, 'Second table cannot be nil')

    local iterator = setmetatable(Iterator.new(), ZipIterator)
    if not resultSelector then
        resultSelector = function(a, b) return {a, b} end
    end
    iterator._first = first
    iterator._second = second
    iterator._resultSelector = resultSelector
    -- Both enumerators are opened by the first moveNext().
    iterator._enumerator1 = nil
    iterator._enumerator2 = nil
    return iterator
end
 
function ZipIterator:moveNext()
     if self._state ~= 0 then
        if self._state ~= 1 then
            return false
        end
    else
        self._enumerator1 = self._first:getEnumerator(self._isArray)
        self._enumerator2 = getTableEnumerator(self._second, self._isArray)
        self._state = -4
    end
    if self._enumerator1:moveNext() == true and self._enumerator2:moveNext() == true then
        self.current = self._resultSelector(self._enumerator1.current, self._enumerator2.current)
        self.index = self._enumerator1.index
        self._state = 1
         return true
     end
     end


     if self.state == 0 then
    self:finalize()
         self.enumerator = self.first:getEnumerator(self.isArray)
    return false
        self.set = {}
end
         self.state = 1
 
function ZipIterator:finalize()
     if self._enumerator1 then
         self._enumerator1:finalize()
    end
    if self._enumerator2 then
         self._enumerator2:finalize()
     end
     end


     -- Process first
     Iterator.finalize(self)
     if self.state == 1 then
end
         if enumerateSource(self, 1) == true then
 
             return true
---Builds a new, unstarted ZipIterator over the same inputs and result selector.
---@return ZipIterator
function ZipIterator:clone()
    local first, second, resultSelector = self._first, self._second, self._resultSelector
    return ZipIterator.new(first, second, resultSelector)
end
 
--Groups the elements of a sequence.
---@class GroupByIterator : Iterator
---@field _source Enumerable
---@field _keySelector fun(param: any): any
---@field _elementSelector fun(param: any): any
---@field _enumerator Enumerator
---@field _lookup Lookup
local GroupByIterator = Iterator.createIteratorSubclass()

---Creates a GroupByIterator. The backing Lookup is only built by the first moveNext().
---@param source Enumerable
---@param keySelector fun(param: any): any
---@param elementSelector? fun(param: any): any defaults to the identity function
---@return GroupByIterator
function GroupByIterator.new(source, keySelector, elementSelector)
    assert(source, 'source cannot be nil')
    assert(keySelector, 'keySelector cannot be nil')

    local iterator = setmetatable(Iterator.new(), GroupByIterator)
    iterator._source, iterator._keySelector = source, keySelector
    iterator._elementSelector = elementSelector or function(x) return x end
    iterator._enumerator, iterator._lookup = nil, nil
    return iterator
end
 
function GroupByIterator:moveNext()
     if self._state ~= 0 then
         if self._state ~= 1 then
             return false
         end
         end
    else
        self._lookup = Lookup.new(self._source, self._keySelector, self._elementSelector, self._isArray)
        self._enumerator = self._lookup:getEnumerator()
        self._state = 1
    end


        -- End first enumeration
    local enum = self._enumerator
        self.state = 2
    if enum:moveNext() == true then
         self.enumerator:finalize()
         self.current = enum.current
         self.enumerator = getTableEnumerator(self.second, self.isArray)
         self.index = enum.index
        self._state = 1
        return true
     end
     end


     -- Process second
     self:finalize()
     if self.state == 2 then
    return false
         if enumerateSource(self, 2) == true then
end
             return true
 
---Finalizes the Lookup enumerator if iteration has started, then runs the base
---Iterator cleanup.
function GroupByIterator:finalize()
    local lookupEnumerator = self._enumerator
    if lookupEnumerator then
        lookupEnumerator:finalize()
    end
    Iterator.finalize(self)
end
 
---Builds a new, unstarted GroupByIterator with the same source and selectors.
---@return GroupByIterator
function GroupByIterator:clone()
    local source = self._source
    local keySelector, elementSelector = self._keySelector, self._elementSelector
    return GroupByIterator.new(source, keySelector, elementSelector)
end
 
--Groups the elements of a sequence according to a specified key selector function and creates a result value from each group and its key. The elements of each group are projected by using a specified function.
---@class GroupByResultIterator : Iterator
---@field _source Enumerable
---@field _keySelector fun(param: any): any
---@field _elementSelector fun(param: any): any
---@field _resultSelector fun(key: any, grouping: Grouping): any
---@field _enumerator Enumerator
---@field _lookup Lookup
local GroupByResultIterator = Iterator.createIteratorSubclass()

---Creates a GroupByResultIterator. The backing Lookup is only built by the first moveNext().
---@param source Enumerable
---@param keySelector fun(param: any): any
---@param elementSelector? fun(param: any): any defaults to the identity function
---@param resultSelector fun(key: any, grouping: Grouping): any called once per group to produce each yielded value
---@return GroupByResultIterator
function GroupByResultIterator.new(source, keySelector, elementSelector, resultSelector)
    assert(source, 'source cannot be nil')
    assert(keySelector, 'keySelector cannot be nil')
    -- This message used to say 'keySelector', pointing callers at the wrong argument.
    assert(resultSelector, 'resultSelector cannot be nil')
    local self = setmetatable(Iterator.new(), GroupByResultIterator)
    self._source = source
    self._keySelector = keySelector
    self._elementSelector = elementSelector or function(x) return x end
    self._resultSelector = resultSelector
    self._enumerator = nil
    self._lookup = nil
    return self
end
 
function GroupByResultIterator:moveNext()
     if self._state ~= 0 then
         if self._state ~= 1 then
             return false
         end
         end
    else
        self._lookup = Lookup.new(self._source, self._keySelector, self._elementSelector, self._isArray)
        self._enumerator = self._lookup:getEnumerator()
        self._state = 1
    end


         -- End first enumeration
    local enumerator = self._enumerator
         self.state = 3
    if enumerator:moveNext() == true then
         -- Transform Grouping element for enumeration
        local gKey = enumerator.current.key
        local gElements = enumerator.current
         self.current = self._resultSelector(gKey, gElements)
        self.index = enumerator.index
        self._state = 1
        return true
     end
     end


Line 717: Line 946:
end
end


function UnionEnumerator:finalize()
function GroupByResultIterator:finalize()
     if self.enumerator then
     if self._enumerator then
         self.enumerator:finalize()
         self._enumerator:finalize()
     end
     end


     Enumerator.finalize(self)
     Iterator.finalize(self)
end
end


function UnionEnumerator:clone()
function GroupByResultIterator:clone()
     return UnionEnumerator.new(self.first, self.second, self.selector)
     return GroupByResultIterator.new(self._source, self._keySelector, self._elementSelector, self._resultSelector)
end
end


--TODO:
-- INTERSECT
-- GROUPBY
-- ZIP


--Sorts a source sequence by a key selector. Sorting is deferred until the first
--moveNext() call, which buffers the whole source and sorts it once. Iterators
--created through createSortableIterator keep a link to their parent in
--`_parentIterator`; the parent's key is compared first and each child's key only
--breaks ties.
---@class SortableIterator : Iterator
---@field _source Enumerable
---@field _keySelector fun(value: any, index: any): any|?
---@field _descendingSort boolean
---@field _parentIterator SortableIterator
---@field _enumerator Enumerator
local SortableIterator = Iterator.createIteratorSubclass()
---@param source Enumerable
---@param keySelector fun(value: any, index: any): any
---@param descendingSort? boolean sort by descending key; nil is treated as false
---@param parent? SortableIterator previous sorter in a chained sort
---@return SortableIterator
function SortableIterator.new(source, keySelector, descendingSort, parent)
    assert(source, 'source cannot be nil')
    assert(keySelector, 'keySelector cannot be nil')
    -- Duck-typed check: any object exposing getEnumerableSorters is accepted as a parent.
    assert(parent == nil or parent.getEnumerableSorters ~= nil, 'parent must be of type SortableIterator')
    if descendingSort == nil then
        descendingSort = false
    end

    local self = setmetatable(Iterator.new(), SortableIterator)
    self._source = source
    self._keySelector = keySelector
    self._descendingSort = descendingSort
    self._parentIterator = parent
    self._enumerator = nil
---@diagnostic disable-next-line: return-type-mismatch
    return self
end


--Optional:
local function createSortFunc(sorters)
---Take
    local sortCount = #sorters
---Skip
    if sortCount > 1 then
        return function(a, b)
            for _, sorter in ipairs(sorters) do
                local selector = sorter._keySelector
                local descendingSort = sorter._descendingSort


-- Define forward declared functions
                local aVal = selector(a)
isType = function(obj, class)
                local bVal = selector(b)
    local mt = getmetatable(obj)
                if aVal ~= bVal then
    while mt do
                    if descendingSort == true then
        if mt.__index == class then
                        return aVal > bVal
             return true
                    else
                        return aVal < bVal
                    end
                end
             end
         end
         end
         mt = getmetatable(mt)
    end
 
    -- Faster sort function if there's no nested sorts.
    local sorter = sorters[1]
    local selector = sorter._keySelector
    if sorter._descendingSort == true then
        return function(a, b) return selector(a) > selector(b) end
    else
         return function(a, b) return selector(a) < selector(b) end
     end
     end
end
end


getTableEnumerator = function(sourceTable, isArray)
function SortableIterator:moveNext()
     if not isType(sourceTable, Enumerator) then
     if self._state ~= 0 then
         return TableEnumerator.new(sourceTable):getEnumerator(isArray)
         if self._state ~= 1 then
            return false
        end
     else
     else
         return sourceTable
        -- Collect sorter meta information.
        local sorters = {}
        self:getEnumerableSorters(sorters)
        local sortFunc = createSortFunc(sorters)
 
        local enumerator = self._source:getEnumerator(self._isArray)
        local buffer = {}
        -- Process and sort all previous enumerators.
        while enumerator:moveNext() == true do
            table.insert(buffer, enumerator.current)
        end
        table.sort(buffer, sortFunc)
        self._enumerator = TableEnumerator.createForArray(buffer)
        self._state = 1
    end
 
    local enumerator = self._enumerator
    if enumerator:moveNext() == true then
        self.current = enumerator.current
        self.index = enumerator.current
        self._state = 1
         return true
     end
     end
    self:finalize()
    return false
end
end


addToSet = function(set, item)
---Collects all sorter functions down the line.
     if set[item] == nil then
---@param sorters table
         set[item] = true
---@return table
        return true
function SortableIterator:getEnumerableSorters(sorters)
    else
    -- Get parent sorter first.
        return false
     if self._parentIterator ~= nil then
         self._parentIterator:getEnumerableSorters(sorters)
     end
     end
    table.insert(sorters, self)
    return sorters
end
---Creates a child SortableIterator over the same source. This iterator becomes
---the child's parent, so its key is compared first and `keySelector` only breaks ties.
---@param keySelector fun(value: any, index: any): any
---@param descendingSort? boolean
---@return SortableIterator
function SortableIterator:createSortableIterator(keySelector, descendingSort)
    local source, parent = self._source, self
    return SortableIterator.new(source, keySelector, descendingSort, parent)
end
end


---Finalizes the enumerator over the sorted buffer (if one exists), clears the
---`_enumerator` field, and then runs the base Iterator cleanup.
function SortableIterator:finalize()
    local buffered = self._enumerator
    if buffered then
        buffered:finalize()
        self._enumerator = nil
    end
    Iterator.finalize(self)
end


---Builds a new, unsorted SortableIterator with the same source, key, direction
---and parent chain.
---@return SortableIterator
function SortableIterator:clone()
    local source, keySelector = self._source, self._keySelector
    local descending, parent = self._descendingSort, self._parentIterator
    return SortableIterator.new(source, keySelector, descending, parent)
end


return {
return {
     Enumerator = Enumerator,
     MapIterator = MapIterator,
     TableEnumerator = TableEnumerator,
     WhereIterator = WhereIterator,
     MapEnumerator = MapEnumerator,
     FlatMapIterator = FlatMapIterator,
     WhereEnumerator = WhereEnumerator,
     ConcatIterator = ConcatIterator,
     FlatMapEnumerator = FlatMapEnumerator,
     AppendIterator = AppendIterator,
     ConcatEnumerator = ConcatEnumerator,
     UniqueIterator = UniqueIterator,
     AppendEnumerator = AppendEnumerator,
     DifferenceIterator = DifferenceIterator,
     UniqueEnumerator = UniqueEnumerator,
     UnionIterator = UnionIterator,
     DifferenceEnumerator = DifferenceEnumerator,
     IntersectIterator = IntersectIterator,
     UnionEnumerator = UnionEnumerator
     ZipIterator = ZipIterator,
    GroupByIterator = GroupByIterator,
    GroupByResultIterator = GroupByResultIterator,
    SortableIterator = SortableIterator,
}
}
2,875

edits