Module:FunList/Sandbox: Difference between revisions

From Melvor Idle
No edit summary
(Add function to specify iterator used for iterating table)
 
(6 intermediate revisions by the same user not shown)
Line 1: Line 1:
local p = {}
--------------------------------------------------------------------------------
-- Helper functions for wrapping and unwrapping the context (state) object.
--------------------------------------------------------------------------------
-- We define the most common wrapContext functions, since at lower parameter count
-- manual passing of the additional parameters is faster than using table.unpack


local Debug = require('module:Debug')
-- Wraps a context object in a factory function with one additional parameter.
local FL = require('module:FunList/Enumerators')
---@param enumerator Enumerator3
---@return fun():table
local wrapContext0 = function(enumerator)
    local nextFunc, context = enumerator.nextFunc, enumerator.contextFactory
    return function() return { 0, nextFunc, context } end
end
 
-- Wraps a context object in a factory function with one additional parameter.
---@param enumerator Enumerator3
---@param param1 any
---@return fun():table
local wrapContext1 = function(enumerator, param1)
    local nextFunc, context = enumerator.nextFunc, enumerator.contextFactory
    return function() return { 0, nextFunc, context, param1 } end
end
 
-- Wraps a context object in a factory function with two additional parameters.
---@param enumerator Enumerator3
---@param param1 any
---@param param2 any
---@return fun():table
local wrapContext2 = function(enumerator, param1, param2)
    local nextFunc, context = enumerator.nextFunc, enumerator.contextFactory
    return function() return { 0, nextFunc, context, param1, param2 } end
end
 
-- Wraps a context object in a factory function with N additional parameters.
---@param enumerator Enumerator3
---@param ... unknown
---@return fun():table
local wrapContextN = function(enumerator, ...)
    local nextFunc, context = enumerator.nextFunc, enumerator.contextFactory
    local varargs = {...}
    return function() return { 0, nextFunc, context, table.unpack(varargs) } end
end


function p.test()
local ipairsFunc = ipairs({0})
local tbl = {}
local pairsFunc = pairs({'a'})


for i = 1, 10 do
local getIterator
    tbl[i] = string.char(string.byte('a') + i - 1) -- Calculate the character
 
-- Sets the state to the context, and returns the state
---@param context table
---@param state integer
---@return integer
local setState = function(context, state)
    context[1] = state
    return state
end
 
-- Replaces the context factory with the context returned from the factory.
---@param context table
---@return table?
local initialiseContext = function(context)
    local contextFactory = context[3]
    if contextFactory == nil then
        return nil
    else
        -- Call the factory method and replace the factory method with the
        -- unpacked context so its state is available for mutation.
        local newContext = contextFactory()
        context[3] = newContext
        return newContext
    end
end
 
-- Grabs the next item from a raw iterator and stores replaces the index with the next index.
-- This is used as a shortcut for advancing nested iterators / tables.
---@param iterator table
---@return any
---@return any
local moveIterator = function(iterator)
    local iteratorFunction, context, index = iterator[1], iterator[2], iterator[3]
    index, context = iteratorFunction(context, index)
    iterator[3] = index
    return index, context
end
 
--------------------------------------------------------------------------------
-- Helper functions for hashset behaviour
--------------------------------------------------------------------------------
---@param set table
---@param item any
---@return boolean
local addToSet = function(set, item)
    if set[item] == nil then
        set[item] = true
        return true
    end
    return false
end
 
---@param set table
---@param item any
---@return boolean
local removeFromSet = function(set, item)
    if set[item] == nil then
        return false
    end
    set[item] = nil
    return true
end
 
--------------------------------------------------------------------------------
-- Enumerator object definition.
--------------------------------------------------------------------------------
--- @class Enumerator3
--- @field nextFunc fun(context: any, index: any):any, any
--- @field index any
--- @field contextFactory fun(): table
local Enumerator3 = {
    __call = function(self, param, state)
        return self.nextFunc(self.contextFactory(), state)
    end;
    __pairs = function(self)
        return self.nextFunc, self.contextFactory(), self.index
    end;
    __ipairs = function(self)
        return self.nextFunc, self.contextFactory(), self.index
    end;
}
Enumerator3.__index = Enumerator3
 
---@param nextFunc fun(context: any, index: any):any, any
---@param index any
---@param contextFactory fun(): any
---@return Enumerator3
function Enumerator3.new(nextFunc, index, contextFactory)
    return setmetatable({
        nextFunc = nextFunc,
        index = index,
        contextFactory = contextFactory
    }, Enumerator3)
end
 
---@param obj any
---@return Enumerator3
function Enumerator3.create(obj)
    local func, current, index = getIterator(obj)
    return Enumerator3.new(func, index, function() return current end)
end
 
function Enumerator3.createWith(obj, iterator)
    local func, current, index = iterator(obj)
    return Enumerator3.new(func, index, function() return current end)
end
 
--------------------------------------------------------------------------------
-- Helper function for setting up iterator.
--------------------------------------------------------------------------------
getIterator = function(obj, context, index)
    assert(obj ~= nil, 'invalid iterator')
    if type(obj) == 'table' then
        local mt = getmetatable(obj);
        if mt ~= nil then
            if mt == Enumerator3 then
                return obj.nextFunc, obj.contextFactory(), obj.index
            elseif mt.__ipairs ~= nil then
                return mt.__ipairs(obj)
            elseif mt.__pairs ~= nil then
                return mt.__pairs(obj)
            end
        end
        if obj._isIterator == true then
            return obj[1], obj[2](), obj[3]
        end
        if #obj > 0 then
            return ipairs(obj)
        else
            return pairs(obj)
        end
    elseif (type(obj) == 'function') then
        -- This is likely a contextFactory. Unwrap it if it's the case.
        if type(context) == 'function' then context = context() end
        return obj, context, index
    end
 
    error(string.format('object %s of type "%s" is not iterable', obj, type(obj)))
end
 
--------------------------------------------------------------------------------
-- Chainable functions
--------------------------------------------------------------------------------
 
-- WHERE
local whereFunc = function(context, index)
    local state = context[1]
    local nextFunc = context[2]
    local nextContext = context[3]
 
    if state ~= 0 then
        if state ~= 1 then
            return nil, nil
        end
    else
        -- Unpack context
        nextContext = initialiseContext(context)
        state = setState(context, 1)
    end
 
    local predicate = context[4]
    local current
 
    while true do
        index, current = nextFunc(nextContext, index)
        if index == nil then
            setState(context, -4)
            return nil, nil
        end
        if predicate(current, index) == true then
            return index, current
        end
    end
end
 
---Filters a sequence of values based on a predicate.
---@param predicate fun(item: any, index: any): boolean
---@return Enumerator3
function Enumerator3:where(predicate)
    assert(predicate, 'Predicate cannot be nil')
    return Enumerator3.new(whereFunc, self.index, wrapContext1(self, predicate))
end
 
 
-- MAP / SELECT
local mapFunc = function(context, index)
    local state = context[1]
    local nextFunc = context[2]
    local nextContext = context[3]
 
    if state ~= 0 then
        if state ~= 1 then
            return nil, nil
        end
    else
        -- Initialise
        nextContext = initialiseContext(context)
        state = setState(context, 1)
    end
 
    local selector = context[4]
    local position = context[5] + 1
    local current
 
    index, current = nextFunc(nextContext, index)
    if index ~= nil then
        current = selector(current, position)
        assert(current, 'Selected value must be non-nil')
        context[5] = position
        return index, current
    end
 
    setState(context, -4)
end
 
---Projects each element of a sequence into a new form.
---@param selector fun(item: any, position: integer): any
---@return Enumerator3
function Enumerator3:map(selector)
    assert(selector, 'Selector cannot be nil')
    return Enumerator3.new(mapFunc, self.index, wrapContext2(self, selector, 0))
end
 
 
-- FLATMAP / SELECTMANY
local flatMapFunc = function(context, index)
    local state = context[1]
    local nextFunc = context[2]
    local nextContext = context[3]
    -- 4: Predicate
    -- 5: Position
    -- 6: Parent Index
    -- 7: Nested Iterator
    local current
 
    if state == -4 then
        return nil, nil
    end
 
    -- Initialise
    if state == 0 then
        nextContext = initialiseContext(context)
        state = setState(context, 3) -- signal to get (first) nested iterator
        context[5] = 1 -- set position to 1
    end
 
    while true do
        -- Grab next value from nested enumerator
        if state == 4 then
            local childIterator = context[7]
            index, current = moveIterator(childIterator)
            if index ~= nil then
                return index, current
            else
                state = setState(context, 3)
            end
        end
 
        if state == 3 then
            local selector = context[4]
            local position = context[5]
            index = context[6]
            -- Get the next item from the iterator
            index, current = nextFunc(nextContext, index)
            if index ~= nil then
                -- Grab the child iterator via the selector function.
                local nestedObject = selector(current, position)
                assert(type(nestedObject) == 'table', 'Selected object must be a table.')
                context[5] = position + 1
                context[6] = index
                -- Setup the child iterator
                context[7] = { getIterator(nestedObject) }
                state = setState(context, 4)
            else
            -- iterator doesn't have any more nested iterators.
                state = setState(context, -4)
                return nil, nil
            end
        end
    end
end
 
---Projects each element of a sequence and flattens the resulting sequences into one sequence.
---@param selector fun(item: any, position: integer): table
---@return Enumerator3
function Enumerator3:flatMap(selector)
    assert(selector, 'Selector cannot be nil')
    return Enumerator3.new(flatMapFunc, self.index, wrapContext2(self, selector, 0))
end
 
 
-- CONCAT
local concatFunc = function(context, index)
    local state = context[1]
    local nextFunc = context[2]
    local nextContext = context[3]
    -- 4: Second (collection/table)
 
    if state == -4 then
        return false
    end
 
if state == 0 then
        nextContext = { getIterator(nextFunc, nextContext, index) }
state = setState(context, 1)
end
end
tbl['TextOK'] = 999
 
    while true do
local enum = FL.TableEnumerator.new(tbl)
        local current
local enum2 = FL.TableEnumerator.new({1,2,3,4,5,6})
        index, current = moveIterator(nextContext)
        if index ~= nil then
local concatEnum = FL.ConcatEnumerator.new(enum, enum2)
            context[3] = nextContext
local selectEnum = FL.SelectEnumerator.new(concatEnum, function(x) return x .. x end)
            return index, current
        end
 
for k, v in ipairs(selectEnum) do
        -- Incremental state can be used to chain together infinite iterators instead of just 2.
Debug.log(v)
        -- Though, currently only two are supported.
        state = setState(context, state + 1)
        if state > 2 then
            setState(context, -4)
            return nil, nil
        end
        nextContext = { getIterator(context[4]) }
    end
end
 
---Concatenates two sequences.
---@param second table
---@return Enumerator3
function Enumerator3:concat(second)
assert(second, 'Second cannot be nil')
    return Enumerator3.new(concatFunc, self.index, wrapContext1(self, second))
end
 
 
-- APPEND
local appendFunc = function(context, index)
    local state = context[1]
    local nextFunc = context[2]
    local nextContext = context[3]
    -- 4: Item
    -- 5: ItemIndex
    -- 6: Append / Prepend
 
    -- This state is used for returning the prepend item as well as
    -- initialising the iterator.
    if state == 0 then
        -- Iterator gets wrapped so the index is stored locally.
        nextContext = { getIterator(nextFunc, nextContext, index) }
        context[3] = nextContext
        state = setState(context, 1)
 
        local append = context[6]
        if append == false then
            local item, itemIndex = context[4], context[5]
            return itemIndex, item
        end
    end
 
    if state == 1 then
        local current
        index, current = moveIterator(nextContext)
        if index ~= nil then
            return index, current
        else
            state = setState(context, -4)
        end
 
        local item, itemIndex = context[4], context[5]
        local append = context[6]
        if append == true then
            return itemIndex, item
        end
    end
 
    return nil, nil
end
 
---Appends a value to the end of the sequence.
---@param item any
---@param index? any
---@return Enumerator3
function Enumerator3:append(item, index)
    if index == nil then index = 'nil' end -- Index needs *some* value
 
    return Enumerator3.new(appendFunc, self.index, wrapContextN(self, item, index, true))
end
 
---Prepends a value to the start of the sequence.
---@param item any
---@param index? any
---@return Enumerator3
function Enumerator3:prepend(item, index)
    if index == nil then index = 'nil' end -- Index needs *some* value
 
    return Enumerator3.new(appendFunc, self.index, wrapContextN(self, item, index, false))
end
 
 
-- UNIQUE / DISTINCT
local uniqueFunc = function(context, index)
    local state = context[1]
    local nextFunc = context[2]
    local nextContext = context[3]
    local keySelector = context[4]
    -- 5: Hashset for uniqueness
    local current
 
    if state == -4 then
        return false
    end
 
    if state == 0 then
        nextContext = initialiseContext(context)
        -- If we have any items, create a hashtable. Otherwise abort.
        -- We check this by moving to the next item in the iterator and checking
        -- if it's not nil. We then set the state to 1 which immediately uses the current/index
        -- without having to move to a new item.
        index, current = nextFunc(nextContext, index)
        if index ~= nil then
            context[5] = {}
            state = 1
        else
            setState(context, -4)
            return nil, nil
        end
    end
 
    -- State is only updated locally, because it is potentially set multiple times during one iteration.
    -- Persist state only when returned.
    local set = context[5]
    while true do
        -- Try to grab a new item.
        if state == 2 then
            index, current = nextFunc(nextContext, index)
            if index ~= nil then
                state = 1
            else
                setState(context, -4)
                return nil, nil
            end
        end
 
        if state == 1 then
            state = 2
            local key = keySelector and keySelector(current, index) or current
            if addToSet(set, key) == true then
                setState(context, state)
                return index, current
            end
        end
    end
end
 
---Returns unique (distinct) elements from a sequence according to a specified key selector function.
---@param keySelector? fun(current: any, index: any): any
---@return Enumerator3
function Enumerator3:unique(keySelector)
    return Enumerator3.new(uniqueFunc, self.index, wrapContext1(self, keySelector))
end
 
 
-- DIFFERENCE
local differenceFunc = function(context, index)
    local state = context[1]
    local nextFunc = context[2]
    local nextContext = context[3]
    -- 4: Second collection
    local keySelector = context[5]
    -- 6: Hashset for uniqueness
    local current
 
    if state ~= 0 then
        if state ~= 1 then
            return nil, nil
        end
    else
        nextContext = initialiseContext(context)
        context[6] = Enumerator3.toSet(context[4], keySelector)
        setState(context, 1)
    end
 
    local set = context[6]
    while true do
        index, current = nextFunc(nextContext, index)
        if index ~= nil then
            local key = keySelector and keySelector(current, index) or current
            if addToSet(set, key) == true then
                return index, current
            end
        else
            setState(context, -4)
            return nil, nil
        end
    end
end
 
---Produces the set difference of two sequences according to a specified key selector function.
---@param second table
---@param keySelector? fun(current: any, index: any): any
---@return Enumerator3
function Enumerator3:difference(second, keySelector)
    assert(second, 'Second collection cannot be nil.')
 
    return Enumerator3.new(differenceFunc, self.index, wrapContext2(self, second, keySelector))
end
 
 
-- UNION
local unionFunc = function (context, index)
    local state = context[1]
    local nextFunc = context[2]
    local nextContext = context[3]
    -- 4: Second collection
    local keySelector = context[5]
    -- 6: Hashset for uniqueness
    local current
 
    if state ~= 0 then
        if state ~= 1 then
            return nil, nil
        end
    else
        nextContext = { getIterator(nextFunc, nextContext, index) }
        context[3] = nextContext
        context[6] = {}
        state = setState(context, 1)
    end
 
    local set = context[6]
    while state > 0 do
        -- Iteration state (exhaust the current iterator)
        while true do
            index, current = moveIterator(nextContext)
            if index == nil then break end
 
            local key = keySelector and keySelector(current, index) or current
            if addToSet(set, key) == true then
                return index, current
            end
        end
 
        -- If the state is 2 at this point, the second iterator has
        -- already been initialised and exhausted.
        if state == 2 then
            setState(context, -4)
            return nil, nil
        end
 
        -- Grab second iterator, first one is empty
        nextContext = { getIterator(context[4]) }
        context[3] = nextContext
        state = setState(context, 2)
    end
end
 
---Produces the set union of two sequences according to a specified key selector function.
---@param second table
---@param keySelector? fun(current: any, index: any): any
---@return Enumerator3
function Enumerator3:union(second, keySelector)
    assert(second, 'Second collection cannot be nil.')
 
    return Enumerator3.new(unionFunc, self.index, wrapContext2(self, second, keySelector))
end
 
-- INTERSECT
local intersectFunc = function(context, index)
    local state = context[1]
    local nextFunc = context[2]
    local nextContext = context[3]
    -- 4: Second collection
    local keySelector = context[5]
    -- 6: Hashset for uniqueness
    local current
 
    if state ~= 0 then
        if state ~= 1 then
            return nil, nil
        end
    else
        nextContext = initialiseContext(context)
        context[6] = Enumerator3.toSet(context[4], keySelector)
        state = setState(context, 1)
    end
 
    local set = context[6]
    while true do
        index, current = nextFunc(nextContext, index)
        if index == nil then
            setState(context, -4)
            return nil, nil
        end
 
        local key = keySelector and keySelector(current, index) or current
        if removeFromSet(set, key) == true then
            return index, current
        end
    end
end
 
---Produces the set intersection of two sequences according to a specified key selector function.
---@param second table
---@param keySelector? fun(current: any, index: any): any
function Enumerator3:intersect(second, keySelector)
    assert(second, 'Second collection cannot be nil.')
 
    return Enumerator3.new(intersectFunc, self.index, wrapContext2(self, second, keySelector))
end
 
 
-- ZIP
local zipFunc = function(context, index)
    local state = context[1]
    local nextFunc = context[2]
    local nextContext = context[3]
    -- 4: Second collection
    local resultSelector = context[5]
 
    if state ~= 0 then
        if state ~= 1 then
            return nil, nil
        end
    else
        -- First iterator gets stored in context, second in context[4]
        nextContext = initialiseContext(context)
        context[4] = { getIterator(context[4]) }
        state = setState(context, 1)
    end
 
    local second = context[4]
    local indexA, currentA = nextFunc(nextContext, index)
    local indexB, currentB = moveIterator(second)
 
    if indexA ~= nil and indexB ~= nil then
        local result = resultSelector and resultSelector(currentA, currentB) or { currentA, currentB }
        return indexA, result
    end
 
    setState(context, -4)
    return nil, nil
end
 
---Applies a specified function to the corresponding elements of two sequences, producing a sequence of the results.
---@param second table
---@param resultSelector? fun(itema: any, itemb: any): any
---@return Enumerator3
function Enumerator3:zip(second, resultSelector)
    assert(second, 'Second collection cannot be nil.')
 
    return Enumerator3.new(zipFunc, self.index, wrapContext2(self, second, resultSelector))
end
 
 
-- GROUPBY / GROUPBYRESULT
local groupByFunc = function(context, index)
    local state = context[1]
    local nextFunc = context[2]
    local nextContext = context[3]
    -- 4: keySelector
    -- 5: elementSelector
    local resultSelector = context[6]
 
    if state ~= 0 then
        if state ~= 1 then
            return nil, nil
        end
    else
        -- Pack iterator info into a structure getIterator recognises
        local iterator = { nextFunc, nextContext, index, _isIterator = true }
 
        local keySelector = context[4]
        local elementSelector = context[5]
        local lookup = Enumerator3.toLookup(iterator, keySelector, elementSelector)
 
        -- Overwrite iterator with new iterator.
        nextFunc, nextContext, index = getIterator(lookup)
        context[2] = nextFunc
        context[3] = nextContext
        state = setState(context, 1)
    end
 
    local key, elements = nextFunc(nextContext, index)
    if key ~= nil then
        local result = resultSelector and resultSelector(key, elements) or elements
        return key, result
    end
 
    setState(context, -4)
    return nil, nil
end
 
---Groups the elements of a sequence.
---@param keySelector fun(param: any): any
---@param elementSelector? fun(param: any): any
---@return Enumerator3
function Enumerator3:groupBy(keySelector, elementSelector)
    assert(keySelector, 'Key Selector cannot be nil.')
 
    return Enumerator3.new(groupByFunc, self.index, wrapContext2(self, keySelector, elementSelector))
end
 
function Enumerator3:groupByResult(keySelector, elementSelector, resultSelector)
    assert(keySelector, 'Key Selector cannot be nil.')
    assert(resultSelector, 'Result Selctor cannot be nil.')
 
    return Enumerator3.new(groupByFunc, self.index, wrapContextN(self, keySelector, elementSelector, resultSelector))
end
 
-- SORTING
-- This function is verbose on purpose. It is beneficial to spend more time returning a sort function
-- than create a sort function that has to make additional checks besides sorting.
local getSortFunction = function(sortSelectors)
    -- Special case where sort isn't chained.
    if #sortSelectors == 1 then
        local sortSelector = sortSelectors[1]
        local acendingSort = sortSelector[1]
        local sortSelector = sortSelector[2]
        if sortSelector ~= nil then
            if acendingSort == true then
                return function(a, b) return sortSelector(a) < sortSelector(b) end
            else
                return function(a, b) return sortSelector(a) > sortSelector(b) end
            end
        else
            if acendingSort == true then
                return function(a, b) return a < b end
            else
                return function(a, b) return a > b end
            end
        end
    end
 
    -- Return chained sorting.
    return function(a, b)
for _, sortSelector in pairs(sortSelectors) do
local selector = sortSelector[2]
local sortAscending = sortSelector[1]
 
local leftVal = selector(a)
local rightVal = selector(b)
if leftVal ~= rightVal then
if sortAscending then
return leftVal < rightVal
else
return leftVal > rightVal
end
end
end
end
end
--Debug.log(enum)
end
end


return p
local sortFunc = function(context, index)
    local state = context[1]
    local nextFunc = context[2]
    local nextContext = context[3]
    -- 4: sort selectors
    local current
 
    if state ~= 0 then
        if state ~= 1 then
            return nil, nil
        end
    else
        -- Create and sort buffer.
        nextContext = initialiseContext(context)
        local buffer = {}
        for _, value in nextFunc, nextContext, index do
            table.insert(buffer, value)
        end
 
        local sortFunc = getSortFunction(context[4])
        table.sort(buffer, sortFunc)
 
        nextFunc, nextContext, index = getIterator(buffer)
        context[2] = nextFunc
        context[3] = nextContext
        setState(context, 1)
    end
 
    index, current = nextFunc(nextContext, index)
    if index ~= nil then
        return index, current
    end
 
    setState(context, -4)
    return nil, nil
end
 
---Boilerplate function to setup various sorting Enumerator3
---@param self Enumerator3
---@param isDecending boolean
---@param selector? fun(current: any): any
---@return Enumerator3
local function getSorter(self, isDecending, selector)
    local selectors = { { isDecending, selector } }
    return Enumerator3.new(sortFunc, self.index, wrapContext1(self, selectors))
end
 
---Boilerplate function to chain one sort to the previous
---@param self table
---@param isDecending any
---@param selector any
local function getChainedSorter(self, isDecending, selector)
    assert(self.nextFunc == sortFunc, 'Previous function must be a "sort" or "order" variant.')
 
    -- Since the context isn't accessible for the child, we unwrap the context first by calling the factory
    -- we can then access the context to add a new sortselector to it.
    local selectors = self.contextFactory()[4]
    table.insert(selectors, { isDecending, selector })
 
    -- Wrap context into a factory again.
    self.contextFactory = wrapContext1(self, selectors)
    return self
end
 
function Enumerator3:sort()
    return getSorter(self, true)
end
 
function Enumerator3:sortDescending()
    return getSorter(self, false)
end
 
function Enumerator3:sortBy(selector)
    return getSorter(self, true, selector)
end
 
function Enumerator3:sortByDescending(selector)
    return getSorter(self, false, selector)
end
 
function Enumerator3:thenBy(selector)
    return getChainedSorter(self, true, selector)
end
 
function Enumerator3:thenByDecending(selector)
    return getChainedSorter(self, false, selector)
end
 
--------------------------------------------------------------------------------
-- Equality functions
--------------------------------------------------------------------------------
local equals = function(a, b)
    -- Start with a quick Lua equals comparison
    local result = a == b
    if result then
        return true
    end
 
    -- If this is false, we check for additional equality, such as sequence equality.
    local atype = type(a)
    local btype = type(b)
    if atype ~= btype then
        return false
    end
 
    if atype == 'table' then
        return Enumerator3.sequenceEquals(a, b)
    else
        return result
    end
end
 
---comment
---@param self Enumerator3
---@param item any
---@param comparer? fun(a: any, b: any): boolean
---@return integer
function Enumerator3.indexOf(self, item, comparer)
    local index = 1
    local cmp = comparer or equals
    for _, v in getIterator(self) do
        if cmp(v, item) == true then
            return index
        end
        index = index + 1
    end
    return -1
end
 
---Determines whether the sequence contains the provided item according to an equality comparer.
---@param self table
---@param item any
---@param comparer? fun(a: any, b: any): boolean
---@return boolean
function Enumerator3.contains(self, item, comparer)
    local cmp = comparer or equals
    for _, v in getIterator(self) do
        if cmp(v, item) == true then
            return true
        end
    end
    return false
end
 
---Determines whether two sequences are equal according to an equality comparer.
---@param self table
---@param other table
---@param comparer? fun(a: any, b: any): boolean
---@return boolean
function Enumerator3.sequenceEquals(self, other, comparer)
    local cmp = comparer or equals
 
    local next, context, index, current = getIterator(other)
 
    for _, v in getIterator(self) do
        index, current = next(context, index)
        if cmp(v, current) == false then
            return false
        end
    end
 
    -- We have exhausted the first iterator. For the sequences to be equal, the second iterator also
    -- has to be exhausted. We know this is the case if the last returned index is also nil.
    return index == nil
end
 
--------------------------------------------------------------------------------
-- Sequence search functions
--------------------------------------------------------------------------------
local function getFirstItem(self, defaultValue)
    local firstItem = nil
 
    local nextFunc, context, index = getIterator(self)
    -- Shortcut for arrays
    if nextFunc == ipairsFunc then
        firstItem = context[1]
    else
        _, firstItem = nextFunc(context, index)
    end
 
    return firstItem or defaultValue
end
 
local function getLastItem(self, defaultValue)
    local nextFunc, context, index = getIterator(self)
    local lastItem = nil
 
    if nextFunc == ipairsFunc then
        lastItem = context[#context]
    else
        -- Run iterator to the end
        local k, v = index, _
        repeat
            k, v = nextFunc(context, k)
        until k == nil
        lastItem = v
    end
 
    return lastItem or defaultValue
end
 
---Returns the first item in the sequence, or a default value if there are no items in the sequence.
---@param self table
---@param predicate? fun(value: any, key: any): boolean
---@param defaultValue? any
---@return any
function Enumerator3.firstOrDefault(self, predicate, defaultValue)
    if predicate == nil then
        return getFirstItem(self, defaultValue)
    end
 
    local nextFunc, context, index = getIterator(self)
    for k, v in nextFunc, context, index do
        if predicate(v, k) == true then
            return v
        end
    end
 
    return defaultValue
end
 
---Returns the first item in the sequence, or an error if there are no items in the sequence.
---@param self table
---@param predicate? fun(value: any, key: any): boolean
---@return any
function Enumerator3.first(self, predicate)
    local result = Enumerator3.firstOrDefault(self, predicate, nil)
    if result == nil then
        error('Sequence contains no (matching) elements.')
    else
        return result
    end
end
 
---Returns the last item in the sequence, or a default value if there are no items in the sequence.
---@param self table
---@param predicate? fun(value: any, key: any): boolean
---@param defaultValue? any
---@return any
function Enumerator3.lastOrDefault(self, predicate, defaultValue)
    if predicate == nil then
        return getLastItem(self, defaultValue)
    end
 
    local nextFunc, context, index = getIterator(self)
    local lastMatch = nil
    for k, v in nextFunc, context, index do
        if predicate(v, k) == true then
            lastMatch = v
        end
    end
 
    return lastMatch or defaultValue
end
 
---Returns the last item in the sequence, or an error if there are no items in the sequence.
---@param self table
---@param predicate? fun(value: any, key: any): boolean
---@return any
function Enumerator3.last(self, predicate)
    local result = Enumerator3.lastOrDefault(self, predicate, nil)
    if result == nil then
        error('Sequence contains no (matching) elements.')
    else
        return result
    end
end
 
---Returns the length of this collection.
---@param self table
---@return integer
function Enumerator3.count(self)
    local nextFunc, context, index = getIterator(self)
    -- Shortcut for arrays
    if nextFunc == ipairsFunc then
        return #context
    end
 
    local len = 0
    for _, _ in nextFunc, context, index do
        len = len + 1
    end
    return len
end
 
---Returns the number of elements in a sequence that match a predicate.
---@param self Enumerator3
---@param predicate fun(value: any, key: any): any
---@return integer
function Enumerator3.countBy(self, predicate)
    assert(predicate, 'Predicate cannot be nil.')
    local count = 0
    for k, v in getIterator(self) do
        if predicate(v, k) == true then
            count = count + 1
        end
    end
    return count
end
 
--- Determines whether all elements of a sequence satisfy a condition.
---@param self Enumerator3
---@param predicate fun(current: any, index: any): boolean
---@return boolean
function Enumerator3.all(self, predicate)
    assert(predicate)
    for key, value in getIterator(self) do
        if predicate(value, key) == false then
            return false
        end
    end
    return true
end
 
---Determines whether any element of a sequence exists or satisfies a condition.
---@param self Enumerator3
---@param predicate? fun(current: any, index: any): boolean
---@return boolean
function Enumerator3.any(self, predicate)
    local nextFunc, context, index = getIterator(self)
    -- Check without predicate (just need a first item to exist)
    if predicate == nil then
        -- Array shortcut.
        if nextFunc == ipairsFunc then
            return #context > 0
        else
            -- Check if we can iterate once.
            local inx, _ = nextFunc(context, index)
            return inx ~= nil
        end
    -- Check with a predicate. We need one item that matches the predicate.
    else
        for key, value in nextFunc, context, index do
            if predicate(value, key) == true then
                return true
            end
        end
    end
 
    return false
end
 
--------------------------------------------------------------------------------
-- Iteration via function
--------------------------------------------------------------------------------
function Enumerator3.foreach(self, func)
    assert(type(func) == 'function')
 
    for _, current in getIterator(self) do
        func(current)
    end
    return nil
end
 
function Enumerator3.ipairs(self)
    return getIterator(self)
end
 
function Enumerator3.pairs(self)
    return getIterator(self)
end
 
--------------------------------------------------------------------------------
-- Reduction
--------------------------------------------------------------------------------
 
---Applies an accumulator function over a sequence. The specified seed value is used as the initial accumulator value, and the specified function is used to select the result value.
---@param self table
---@param func fun(accumulate: any, next: any): any
---@param seed? any
---@param resultSelector? fun(result: any): any
---@return any
function Enumerator3.aggregate(self, func, seed, resultSelector)
    assert(func, 'Aggregate function cannot be nil.')
 
    local result = seed
    local next, context, index = getIterator(self)
 
    -- We need a starting element.
    if result == nil then
        index, result = next(context, index)
        if index == nil then
            error('Sequence contains no elements')
        end
    end
 
    for _, current in next, context, index do
        result = func(result, current)
    end
 
    if resultSelector ~= nil then
        result = resultSelector(result)
    end
 
    return result
end
 
--------------------------------------------------------------------------------
-- Arithmetic functions
--------------------------------------------------------------------------------
local maxCmp = function(a, b)
    if a < b then return true end
end
 
local minCmp = function(a, b)
    if a > b then return true end
end
 
local minmaxBy = function(self, selector, comparison)
    local nextFunc, context, index = getIterator(self)
 
    -- Grab the first item
    local key, value = nextFunc(context, index)
    local maxVal = selector(value, key)
    local maxItem = value
 
    local curr = nil
    for k, v in nextFunc, context, key do
        curr = selector(v, k)
        if comparison(maxVal, curr) == true then
            maxVal = curr
            maxItem = v
        end
    end
 
    return maxItem
end
 
local minmax = function(self, comparison)
    local nextFunc, context, index = getIterator(self)
 
    -- Grab the first item
    local key, value = nextFunc(context, index)
    local maxVal = value
 
    for _, v in nextFunc, context, key do
        if comparison(maxVal, v) == true then
            maxVal = v
        end
    end
 
    return maxVal
end
 
---Returns the maximum value in a sequence of values.
---@param self Enumerator3
---@return number
function Enumerator3.max(self)
    return minmax(self, maxCmp)
end
 
---Returns the item with the highest value in a sequence.
---@param self Enumerator3
---@param selector fun(item: any, index: any): number
---@return unknown
function Enumerator3.maxBy(self, selector)
    assert(selector)
    return minmaxBy(self, selector, maxCmp)
end
 
---Returns the minimum value in a sequence of values.
---@param self Enumerator3
---@return number
function Enumerator3.min(self)
    return minmax(self, minCmp)
end
 
---Returns the item with the lowest value in a sequence.
---@param self Enumerator3
---@param selector fun(item: any, index: any): number
---@return unknown
function Enumerator3.minBy(self, selector)
    assert(selector)
    return minmaxBy(self, selector, minCmp)
end
 
---Returns the sum value of a sequence of values.
---@param self Enumerator3
---@return number
function Enumerator3.sum(self)
    local total = 0
    for _, v in getIterator(self) do
        total = total + v
    end
    return total
end
 
--------------------------------------------------------------------------------
-- Enumerator functions that result in a new collection.
--------------------------------------------------------------------------------
 
---Creates a set from a table|Enumerator3 with a specified key for uniqueness.
---@param self table
---@param keySelector? fun(current:any, index: any): any
---@return table
function Enumerator3.toSet(self, keySelector)
    local set = {}
    for inx, current in getIterator(self) do
        if keySelector == nil then
            set[current] = true
        else
            local key = keySelector(current, inx)
            set[key] = true
        end
    end
    return set
end
 
---Creates an array from a table|Enumerator3
---@param self table
---@return table
function Enumerator3.toTable(self)
    local tab = {}
    for _, value in getIterator(self) do
        table.insert(tab, value)
    end
    return tab
end
 
---Creates a lookup from a table|Enumerator3
---This is a { key, { elements } } structure
---@param self table
---@return table
function Enumerator3.toLookup(self, keySelector, elementSelector)
    local lookup = {}
 
    for inx, value in getIterator(self) do
        -- Retrieve grouping by key
        local key = keySelector(value, inx)
        assert(key, 'Key cannot be nil.')
        local grouping = lookup[key]
 
        -- Create new grouping if it doesn't exist.
        if grouping == nil then
            grouping = {}
            lookup[key] = grouping
        end
 
        -- Add element to grouping
        local element = elementSelector and elementSelector(value, inx) or value
        table.insert(grouping, element)
    end
 
    return lookup
end
 
---Creates a dictionary (key, value) structure based on a keyselector and an elementselector
---@param self Enumerator3
---@param keySelector? fun(value: any, key: any): any
---@param elementSelector? fun(value: any, key: any): any
---@return table
function Enumerator3.toDictionary(self, keySelector, elementSelector)
    local dictionary = {}
 
    for key, value in getIterator(self) do
        if keySelector ~= nil then
            key = keySelector(value, key)
        end
        if elementSelector ~= nil then
            value = elementSelector(value, key)
        end
 
        dictionary[key] = value
    end
    return dictionary
end
 
---Creates a deep copy of the provided object.
---@param self any
---@return any
function Enumerator3.deepCopy(self)
    local orig_type = type(self)
    local copy
    if orig_type == 'table' then
        copy = {}
        for k, v in getIterator(self) do
            copy[Enumerator3.deepCopy(k)] = Enumerator3.deepCopy(v)
        end
    else
        copy = self
    end
    return copy
end
 
--------------------------------------------------------------------------------
-- Function aliasses
--------------------------------------------------------------------------------
Enumerator3.select = Enumerator3.map
Enumerator3.selectMany = Enumerator3.flatMap
Enumerator3.distinct = Enumerator3.unique
Enumerator3.except = Enumerator3.difference
Enumerator3.order = Enumerator3.sort
Enumerator3.orderDecending = Enumerator3.sortDescending
Enumerator3.orderBy = Enumerator3.sortBy
Enumerator3.orderByDecending = Enumerator3.sortByDescending
 
Enumerator3.each = Enumerator3.foreach
Enumerator3.for_each = Enumerator3.foreach
Enumerator3.length = Enumerator3.count
Enumerator3.lengthBy = Enumerator3.countBy
Enumerator3.has = Enumerator3.contains
Enumerator3.fold = Enumerator3.aggregate
Enumerator3.reduce = Enumerator3.aggregate
 
return Enumerator3
 
--------------------------------------------------------------------------------
-- Explanation on how this module functions / how it can be extended
--------------------------------------------------------------------------------
 
----------------------------------------------------
-- Chaining functions and the Enumerator3 class
----------------------------------------------------
--[[
The Enumerator3 class holds metadata for chaining functions and the context for the operation. Specifically, it holds three pieces of information:
- The function to be executed next in the sequence.
- Index of the Iterator Function: (which is 0 for ipairs and nil for pairs)
- The Context Factory (A function that generates a fresh context)
 
The Enumerator3 class ONLY serves as a storage for metadata with the sole purpose of chaining functions. This is also the reason why all chainable functions are added to Enumerator3. When a chainable function gets called, a new Enumerator3 instance gets created. The current (self) Enumerator3 metadata is passed to the newly created instance to form a chain (via the context object). The Enumerator3 class is ONLY used when functions are chained, NOT when executed. This is also why this module is as fast as it is. There are no calls made to retrieve data from the Enumerator3 instances when iterating. The context object is exclusively used for this purpose.
 
When the chain of functions gets evaluated, the following happens.
1. The next function in the chain is retrieved.
2. The context object for the next function is created via the factory.
3. The index for the next function is retrieved.
4. The next function is invoked using the context and index. (Usually in a loop, until the index returns nil)
 
The above 4 steps happen for every function in the chain.
]]
 
----------------------------------------------------
-- The context object and context factory
----------------------------------------------------
--[[
Since some functions require data to be remembered for consecutive "nextFunc" calls, a context object is required. This can be seen, for example, for the map function. Where the map function stores an index (position) and passes it to the selector function.
 
However, if the Enumerator3 object is reused or chained with a different function, the context from the previous iteration persists, leading to incorrect behavior. For example, in the map function, the index remains at the last iteration's position instead of restarting at 0. To prevent this, the context isn't passed directly to Enumerator3. Instead, it utilizes a function, known as a context factory, to generate a fresh context each time the function chain is evaluated.
 
This is achieved via the wrapContext function. The context always contains the same fields in the same order:
[1] State: Represents the current state of the function.
[2] Next Function: The subsequent function in the chain.
[3] Next Context (Factory): The factory function for creating the next context.
 
Depending on the specific function used, additional fields may be added to the context object. For example, the map function appends:
- A selector function as the 4th field.
- An indexer as the 5th field.
The sequence in which fields are added is crucial because it determines how they must be retrieved later. The fields added by a chainable function must align precisely with the fields expected by its implementation function.
 
For example in the map (chainable) function we see:
"Enumerator3.new(mapFunc, self.index, wrapContext2(self, selector, 0))"
 
The wrapContext function adds the initial three fields (state, next function, next context factory), with the selector as the 4th field and the indexer (defaulting to 0) as the 5th field.
 
The implementing function of map retrieves the fields in the exact order they were added:
Example: The implementation of map retrieves all of these fields in the same order when called:
    local state = context[1]
    local nextFunc = context[2]
    local nextContext = context[3]
    local selector = context[4]
    local position = context[5]
 
 
Maintaining this order is critical for ensuring that the context is correctly interpreted and that each function receives the data it expects.
]]
 
 
----------------------------------------------------
-- The state of the implementing function
----------------------------------------------------
--[[
When an implementing function is invoked for the first time, the context object must be created using the provided factory function. During this initial call, the following occurs:
 
- Check Initial State: The state is checked to see if it is 0.
- Replace Context Factory: If the state is 0, the context factory function is replaced with the context itself via the initialiseContext function.
- Set to Active State: The state is then set to 1, signifying that the function is now in its active iteration phase.
 
On subsequent calls to the implementing function, the initialization step (state 0) is bypassed, and the function proceeds directly with the iteration step. Some implementing functions may require multiple states to handle more complex behaviors.
 
When the state becomes -4, it signifies that the iteration has reached its end. At this point, the function returns nil to indicate to the caller that there are no further elements to process.
]]
 
----------------------------------------------------
-- Expanding upon the module
----------------------------------------------------
--[[
If a new chainable function gets added (a function that returns an Enumerator3 instance) the following should be ensured:
1. A chainable function and an implementation function pair must be implemented.
2. The chainable function must be added to the Enumerator3 class to allow chaining.
3. The chainable function must pass a context factory to the Enumerator3 constructor, with the parameters in the same order as the implementing function uses them.
4. The implementing function must retrieve the required fields in the same order as the context factory specifies.
5. The implementing function must create a new context via the factory function and store the context for future iterations.
6. The state field can be used to track the state of the implementing function and change its behaviour. For instance, having an initialisation state and an iteration state.
 
Simple functions such as the map/mapFunc and where/whereFunc can be used as a reference.
 
 
 
If a new result function gets added (a function that iterates over the function chain and returns a direct result) the following should be ensured:
1. Retrieve the nextFunc, contextFactory and index from the current object (self).
2. Create the context by calling the contextFactory.
3. Iterate over the chain by calling nextFunc using the context and index.
 
If the index returned from nextFunc is nil, it means that the iteration has come to an end.
]]

Latest revision as of 16:06, 15 August 2024

This module allows for various complex querying and evaluating of a table or otherwise collection or object that is iterable. This means every Lua type that provides an ipairs or pairs iterator.

The functions under Chainable Functions can be chained back to back to create a complex query. The underlying collection isn't evaluated unless some function is used that evaluates the function chain. This can be one of the many functions not listed under Chainable Functions, or simply iterating over the function chain with ipairs/pairs.

To start create a new Enumerator3 object via Enumerator3.create. Functions can then be appended to this with the semicolon (:). Example:

local tbl = {}
local count = Enumerator3.create(tbl)
    :select(function(x) return x.number end)
    :countBy(function(x) return x % 2 == 0 end)

The above snippet counts all numbers in the table under the 'number' field that are divisible by two. The countBy function evaluates all functions in the chain and returns the count.

All functions that have a "self" parameter can be used with any collection and doesn't need to be chained. For example, the maxBy function can be used to get the person with the highest age:

local tbl = {}
local oldestPerson = Enumerator3.maxBy(tbl, function(person) return person.age end)

Enumerator3 Initialisation

create

---@param obj any
---@return Enumerator3
function Enumerator3.create(obj)
---@param obj any
---@param iterator fun (such as ipairs / pairs)
---@return Enumerator3
function Enumerator3.createWith(obj, iterator)

ipairs

---@return fun():table
function Enumerator3:ipairs()

Chainable Functions

where

---Filters a sequence of values based on a predicate.
---@param predicate fun(item: any, index: any): boolean
---@return Enumerator3
function Enumerator3:where(predicate)

map

---Projects each element of a sequence into a new form.
---@param selector fun(item: any, position: integer): any
---@return Enumerator3
function Enumerator3:map(selector)

flatMap

---Projects each element of a sequence and flattens the resulting sequences into one sequence.
---@param selector fun(item: any, position: integer): table
---@return Enumerator3
function Enumerator3:flatMap(selector)

concat

---Concatenates two sequences.
---@param second table
---@return Enumerator3
function Enumerator3:concat(second)

append

---Appends a value to the end of the sequence.
---@param item any
---@param index? any
---@return Enumerator3
function Enumerator3:append(item, index)

prepend

---Prepends a value to the start of the sequence.
---@param item any
---@param index? any
---@return Enumerator3
function Enumerator3:prepend(item, index)

unique

---Returns unique (distinct) elements from a sequence according to a specified key selector function.
---@param keySelector? fun(current: any, index: any): any
---@return Enumerator3 
function Enumerator3:unique(keySelector)

difference

---Produces the set difference of two sequences according to a specified key selector function.
---@param second table
---@param keySelector? fun(current: any, index: any): any
---@return Enumerator3
function Enumerator3:difference(second, keySelector)

union

---Produces the set union of two sequences according to a specified key selector function.
---@param second table
---@param keySelector? fun(current: any, index: any): any
---@return Enumerator3
function Enumerator3:union(second, keySelector)

intersect

---Produces the set intersection of two sequences according to a specified key selector function.
---@param second table
---@param keySelector? fun(current: any, index: any): any
function Enumerator3:intersect(second, keySelector)

zip

---Applies a specified function to the corresponding elements of two sequences, producing a sequence of the results.
---@param second table
---@param resultSelector? fun(itema: any, itemb: any): any
---@return Enumerator3
function Enumerator3:zip(second, resultSelector)

groupBy

---Groups the elements of a sequence.
---@param keySelector fun(param: any): any
---@param elementSelector? fun(param: any): any
---@return Enumerator3
function Enumerator3:groupBy(keySelector, elementSelector)

groupByResult

---Groups the elements of a sequence.
---@param keySelector fun(param: any): any
---@param elementSelector? fun(param: any): any
---@param resultSelector fun(key: any, elements: any): any
---@return Enumerator3
function Enumerator3:groupByResult(keySelector, elementSelector, resultSelector)

sort

function Enumerator3:sort()

sortDescending

function Enumerator3:sortDescending()

sortBy

function Enumerator3:sortBy(selector)

sortByDescending

function Enumerator3:sortByDescending(selector)

thenBy

function Enumerator3:thenBy(selector)

thenByDecending

function Enumerator3:thenByDecending(selector)

Equality checks

indexOf

---Determines the index of the first occurrence of a specified item.
---@param self Enumerator3
---@param item any
---@param comparer? fun(a: any, b: any): boolean
---@return integer
function Enumerator3.indexOf(self, item, comparer)

contains

---Determines whether the sequence contains the provided item according to an equality comparer.
---@param self table
---@param item any
---@param comparer? fun(a: any, b: any): boolean
---@return boolean
function Enumerator3.contains(self, item, comparer)

sequenceEquals

---Determines whether two sequences are equal according to an equality comparer.
---@param self table
---@param other table
---@param comparer? fun(a: any, b: any): boolean
---@return boolean
function Enumerator3.sequenceEquals(self, other, comparer)

Sequence Search Functions

firstOrDefault

---Returns the first item in the sequence, or a default value if there are no items in the sequence.
---@param self table
---@param predicate? fun(value: any, key: any): boolean
---@param defaultValue? any
---@return any
function Enumerator3.firstOrDefault(self, predicate, defaultValue)

first

---Returns the first item in the sequence, or an error if there are no items in the sequence.
---@param self table
---@param predicate? fun(value: any, key: any): boolean
---@return any
function Enumerator3.first(self, predicate)

lastOrDefault

---Returns the last item in the sequence, or a default value if there are no items in the sequence.
---@param self table
---@param predicate? fun(value: any, key: any): boolean
---@param defaultValue? any
---@return any
function Enumerator3.lastOrDefault(self, predicate, defaultValue)

last

---Returns the last item in the sequence, or an error if there are no items in the sequence.
---@param self table
---@param predicate? fun(value: any, key: any): boolean
---@return any
function Enumerator3.last(self, predicate)

count

---Returns the length of this collection.
---@param self table
---@return integer
function Enumerator3.count(self)

countBy

---Returns the number of elements in a sequence that match a predicate.
---@param self Enumerator3
---@param predicate fun(value: any, key: any): any
---@return integer
function Enumerator3.countBy(self, predicate)

all

---Determines whether all elements of a sequence satisfy a condition.
---@param self Enumerator3
---@param predicate fun(current: any, index: any): boolean
---@return boolean
function Enumerator3.all(self, oredicate)

any

---Determines whether any element of a sequence exists or satisfies a condition.
---@param self Enumerator3
---@param predicate? fun(current: any, index: any): boolean
---@return boolean
function Enumerator3.any(self, predicate)

Iteration

foreach

---Iterates over the sequence, applying a specified function to each element.
---@param self Enumerator3
---@param func fun(current: any): any
---@return nil
function Enumerator3.foreach(self, func)

Reduction

aggregate

---Applies an accumulator function over a sequence. The specified seed value is used as the initial accumulator value, and the specified function is used to select the result value.
---@param self table
---@param func fun(accumulate: any, next: any): any
---@param seed? any
---@param resultSelector? fun(result: any): any
---@return any
function Enumerator3.aggregate(self, func, seed, resultSelector)

Arithmetic Functions

max

---Returns the maximum value in a sequence of values.
---@param self Enumerator3
---@return number
function Enumerator3.max(self)

maxBy

---Returns the item with the highest value in a sequence.
---@param self Enumerator3
---@param selector fun(item: any, index: any): number
---@return unknown
function Enumerator3.maxBy(self, selector)

min

---Returns the minimum value in a sequence of values.
---@param self Enumerator3
---@return number
function Enumerator3.min(self)

minBy

---Returns the item with the lowest value in a sequence.
---@param self Enumerator3
---@param selector fun(item: any, index: any): number
---@return unknown
function Enumerator3.minBy(self, selector)

sum

---Returns the sum value of a sequence of values.
---@param self Enumerator3
---@return number
function Enumerator3.sum(self)

Collection transform functions

toSet

---Creates a set from a table|Enumerator3 with a specified key for uniqueness.
---@param self table
---@param keySelector? fun(current: any, index: any): any
---@return table
function Enumerator3.toSet(self, keySelector)

toTable

---Creates an array from a table|Enumerator3
---@param self table
---@return table
function Enumerator3.toTable(self)

toLookup

---Creates a lookup from a table|Enumerator3
---This is a { key, { elements } } structure
---@param self table
---@param keySelector fun(value: any, index: any): any
---@param elementSelector? fun(value: any, index: any): any
---@return table
function Enumerator3.toLookup(self, keySelector, elementSelector)

toDictionary

---Creates a dictionary (key, value) structure based on a keyselector and an elementselector
---@param self Enumerator3
---@param keySelector? fun(value: any, key: any): any
---@param elementSelector? fun(value: any, key: any): any
---@return table
function Enumerator3.toDictionary(self, keySelector, elementSelector)

deepCopy

---Creates a deep copy of the provided object.
---@param self any
---@return any
function Enumerator3.deepCopy(self)

Function aliases

These aliases behave the same as the functions listed above. They only have a different name.

select

Alias for Enumerator3.map

selectMany

Alias for Enumerator3.flatMap

distinct

Alias for Enumerator3.unique

except

Alias for Enumerator3.difference

order

Alias for Enumerator3.sort

orderDecending

Alias for Enumerator3.sortDescending

orderBy

Alias for Enumerator3.sortBy

orderByDecending

Alias for Enumerator3.sortByDescending

each

Alias for Enumerator3.foreach

for_each

Alias for Enumerator3.foreach

length

Alias for Enumerator3.count

lengthBy

Alias for Enumerator3.countBy

has

Alias for Enumerator3.contains

fold

Alias for Enumerator3.aggregate

reduce

Alias for Enumerator3.aggregate

Module Explanation

--------------------------------------------------------------------------------
-- Explanation on how this module functions / how it can be extended
--------------------------------------------------------------------------------

----------------------------------------------------
-- Chaining functions and the Enumerator3 class
----------------------------------------------------
--[[
The Enumerator3 class holds metadata for chaining functions and the context for the operation. Specifically, it holds three pieces of information:
- The function to be executed next in the sequence.
- Index of the Iterator Function: (which is 0 for ipairs and nil for pairs)
- The Context Factory (A function that generates a fresh context)

The Enumerator3 class ONLY serves as a storage for metadata with the sole purpose of chaining functions. This is also the reason why all chainable functions are added to Enumerator3. When a chainable function gets called, a new Enumerator3 instance gets created. The current (self) Enumerator3 metadata is passed to the newly created instance to form a chain (via the context object). The Enumerator3 class is ONLY used when functions are chained, NOT when executed. This is also why this module is as fast as it is. There are no calls made to retrieve data from the Enumerator3 instances when iterating. The context object is exclusively used for this purpose.

When the chain of functions gets evaluated, the following happens.
1. The next function in the chain is retrieved.
2. The context object for the next function is created via the factory.
3. The index for the next function is retrieved.
4. The next function is invoked using the context and index. (Usually in a loop, until the index returns nil)

The above 4 steps happen for every function in the chain.
]]

----------------------------------------------------
-- The context object and context factory
----------------------------------------------------
--[[
Since some functions require data to be remembered for consecutive "nextFunc" calls, a context object is required. This can be seen, for example, for the map function. Where the map function stores an index (position) and passes it to the selector function. 

However, if the Enumerator3 object is reused or chained with a different function, the context from the previous iteration persists, leading to incorrect behavior. For example, in the map function, the index remains at the last iteration's position instead of restarting at 0. To prevent this, the context isn't passed directly to Enumerator3. Instead, it utilizes a function, known as a context factory, to generate a fresh context each time the function chain is evaluated.

This is achieved via the wrapContext function. The context always contains the same fields in the same order:
[1] State: Represents the current state of the function.
[2] Next Function: The subsequent function in the chain.
[3] Next Context (Factory): The factory function for creating the next context.

Depending on the specific function used, additional fields may be added to the context object. For example, the map function appends:
- A selector function as the 4th field.
- An indexer as the 5th field.
The sequence in which fields are added is crucial because it determines how they must be retrieved later. The fields added by a chainable function must align precisely with the fields expected by its implementation function.

For example in the map (chainable) function we see:
"Enumerator3.new(mapFunc, self.index, wrapContext2(self, selector, 0))"

The wrapContext function adds the initial three fields (state, next function, next context factory), with the selector as the 4th field and the indexer (defaulting to 0) as the 5th field.

The implementing function of map retrieves the fields in the exact order they were added:
Example: The implementation of map retrieves all of these fields in the same order when called:
    local state = context[1]
    local nextFunc = context[2]
    local nextContext = context[3]
    local selector = context[4]
    local position = context[5]


Maintaining this order is critical for ensuring that the context is correctly interpreted and that each function receives the data it expects.
]]


----------------------------------------------------
-- The state of the implementing function
----------------------------------------------------
--[[
When an implementing function is invoked for the first time, the context object must be created using the provided factory function. During this initial call, the following occurs:

- Check Initial State: The state is checked to see if it is 0.
- Replace Context Factory: If the state is 0, the context factory function is replaced with the context itself via the initialiseContext function.
- Set to Active State: The state is then set to 1, signifying that the function is now in its active iteration phase.

On subsequent calls to the implementing function, the initialization step (state 0) is bypassed, and the function proceeds directly with the iteration step. Some implementing functions may require multiple states to handle more complex behaviors. 

When the state becomes -4, it signifies that the iteration has reached its end. At this point, the function returns nil to indicate to the caller that there are no further elements to process. 
]]

----------------------------------------------------
-- Expanding upon the module
----------------------------------------------------
--[[
If a new chainable function gets added (a function that returns an Enumerator3 instance) the following should be ensured:
1. A chainable function and an implementation function pair must be implemented.
2. The chainable function must be added to the Enumerator3 class to allow chaining.
3. The chainable function must pass a context factory to the Enumerator3 constructor, with the parameters in the same order as the implementing function uses them.
4. The implementing function must retrieve the required fields in the same order as the context factory specifies.
5. The implementing function must create a new context via the factory function and store the context for future iterations.
6. The state field can be used to track the state of the implementing function and change its behaviour. For instance, having an initialisation state and an iteration state.

Simple functions such as the map/mapFunc and where/whereFunc can be used as a reference.



If a new result function gets added (a function that iterates over the function chain and returns a direct result) the following should be ensured:
1. Retrieve the nextFunc, contextFactory and index from the current object (self).
2. Create the context by calling the contextFactory.
3. Iterate over the chain by calling nextFunc using the context and index.

If the index returned from nextFunc is nil, it means that the iteration has come to an end.
]]

--------------------------------------------------------------------------------
-- Helper functions for wrapping and unwrapping the context (state) object.
--------------------------------------------------------------------------------
-- We define the most common wrapContext functions, since at lower parameter count
-- manual passing of the additional parameters is faster than using table.unpack

-- Wraps a context object in a factory function with one additional parameter.
---@param enumerator Enumerator3
---@return fun():table
local wrapContext0 = function(enumerator)
    local nextFunc, context = enumerator.nextFunc, enumerator.contextFactory
    return function() return { 0, nextFunc, context } end
end

-- Wraps a context object in a factory function with one additional parameter.
---@param enumerator Enumerator3
---@param param1 any
---@return fun():table
local wrapContext1 = function(enumerator, param1)
    local nextFunc, context = enumerator.nextFunc, enumerator.contextFactory
    return function() return { 0, nextFunc, context, param1 } end
end

-- Wraps a context object in a factory function with two additional parameters.
---@param enumerator Enumerator3
---@param param1 any
---@param param2 any
---@return fun():table
local wrapContext2 = function(enumerator, param1, param2)
    local nextFunc, context = enumerator.nextFunc, enumerator.contextFactory
    return function() return { 0, nextFunc, context, param1, param2 } end
end

-- Wraps a context object in a factory function with N additional parameters.
---@param enumerator Enumerator3
---@param ... unknown
---@return fun():table
local wrapContextN = function(enumerator, ...)
    local nextFunc, context = enumerator.nextFunc, enumerator.contextFactory
    local varargs = {...}
    return function() return { 0, nextFunc, context, table.unpack(varargs) } end
end

local ipairsFunc = ipairs({0})
local pairsFunc = pairs({'a'})

local getIterator

-- Sets the state to the context, and returns the state
---@param context table
---@param state integer
---@return integer
local setState = function(context, state)
    context[1] = state
    return state
end

-- Replaces the context factory with the context returned from the factory.
---@param context table
---@return table?
local initialiseContext = function(context)
    local contextFactory = context[3]
    if contextFactory == nil then
        return nil
    else
        -- Call the factory method and replace the factory method with the
        -- unpacked context so its state is available for mutation.
        local newContext = contextFactory()
        context[3] = newContext
        return newContext
    end
end

-- Grabs the next item from a raw iterator and stores replaces the index with the next index.
-- This is used as a shortcut for advancing nested iterators / tables.
---@param iterator table
---@return any
---@return any
local moveIterator = function(iterator)
    local iteratorFunction, context, index = iterator[1], iterator[2], iterator[3]
    index, context = iteratorFunction(context, index)
    iterator[3] = index
    return index, context
end

--------------------------------------------------------------------------------
-- Helper functions for hashset behaviour
--------------------------------------------------------------------------------
---@param set table
---@param item any
---@return boolean
local addToSet = function(set, item)
    if set[item] == nil then
        set[item] = true
        return true
    end
    return false
end

---@param set table
---@param item any
---@return boolean
local removeFromSet = function(set, item)
    if set[item] == nil then
        return false
    end
    set[item] = nil
    return true
end

--------------------------------------------------------------------------------
-- Enumerator object definition.
--------------------------------------------------------------------------------
--- @class Enumerator3
--- @field nextFunc fun(context: any, index: any):any, any
--- @field index any
--- @field contextFactory fun(): table
local Enumerator3 = {
    __call = function(self, param, state)
        return self.nextFunc(self.contextFactory(), state)
    end;
    __pairs = function(self)
        return self.nextFunc, self.contextFactory(), self.index
    end;
    __ipairs = function(self)
        return self.nextFunc, self.contextFactory(), self.index
    end;
}
Enumerator3.__index = Enumerator3

---@param nextFunc fun(context: any, index: any):any, any
---@param index any
---@param contextFactory fun(): any
---@return Enumerator3
function Enumerator3.new(nextFunc, index, contextFactory)
    return setmetatable({
        nextFunc = nextFunc,
        index = index,
        contextFactory = contextFactory
    }, Enumerator3)
end

---@param obj any
---@return Enumerator3
function Enumerator3.create(obj)
    local func, current, index = getIterator(obj)
    return Enumerator3.new(func, index, function() return current end)
end

function Enumerator3.createWith(obj, iterator)
    local func, current, index = iterator(obj)
    return Enumerator3.new(func, index, function() return current end)
end

--------------------------------------------------------------------------------
-- Helper function for setting up iterator.
--------------------------------------------------------------------------------
getIterator = function(obj, context, index)
    assert(obj ~= nil, 'invalid iterator')
    if type(obj) == 'table' then
        local mt = getmetatable(obj);
        if mt ~= nil then
            if mt == Enumerator3 then
                return obj.nextFunc, obj.contextFactory(), obj.index
            elseif mt.__ipairs ~= nil then
                return mt.__ipairs(obj)
            elseif mt.__pairs ~= nil then
                return mt.__pairs(obj)
            end
        end
        if obj._isIterator == true then
            return obj[1], obj[2](), obj[3]
        end
        if #obj > 0 then
            return ipairs(obj)
        else
            return pairs(obj)
        end
    elseif (type(obj) == 'function') then
        -- This is likely a contextFactory. Unwrap it if it's the case.
        if type(context) == 'function' then context = context() end
        return obj, context, index
    end

    error(string.format('object %s of type "%s" is not iterable', obj, type(obj)))
end

--------------------------------------------------------------------------------
-- Chainable functions
--------------------------------------------------------------------------------

-- WHERE
local whereFunc = function(context, index)
    local state = context[1]
    local nextFunc = context[2]
    local nextContext = context[3]

    if state ~= 0 then
        if state ~= 1 then
            return nil, nil
        end
    else
        -- Unpack context
        nextContext = initialiseContext(context)
        state = setState(context, 1)
    end

    local predicate = context[4]
    local current

    while true do
        index, current = nextFunc(nextContext, index)
        if index == nil then
            setState(context, -4)
            return nil, nil
        end
        if predicate(current, index) == true then
            return index, current
        end
    end
end

---Filters a sequence of values based on a predicate.
---@param predicate fun(item: any, index: any): boolean
---@return Enumerator3
function Enumerator3:where(predicate)
    assert(predicate, 'Predicate cannot be nil')
    return Enumerator3.new(whereFunc, self.index, wrapContext1(self, predicate))
end


-- MAP / SELECT
local mapFunc = function(context, index)
    local state = context[1]
    local nextFunc = context[2]
    local nextContext = context[3]

    if state ~= 0 then
        if state ~= 1 then
            return nil, nil
        end
    else
        -- Initialise
        nextContext = initialiseContext(context)
        state = setState(context, 1)
    end

    local selector = context[4]
    local position = context[5] + 1
    local current

    index, current = nextFunc(nextContext, index)
    if index ~= nil then
        current = selector(current, position)
        assert(current, 'Selected value must be non-nil')
        context[5] = position
        return index, current
    end

    setState(context, -4)
end

---Projects each element of a sequence into a new form.
---@param selector fun(item: any, position: integer): any
---@return Enumerator3
function Enumerator3:map(selector)
    assert(selector, 'Selector cannot be nil')
    return Enumerator3.new(mapFunc, self.index, wrapContext2(self, selector, 0))
end


-- FLATMAP / SELECTMANY
local flatMapFunc = function(context, index)
    local state = context[1]
    local nextFunc = context[2]
    local nextContext = context[3]
    -- 4: Predicate
    -- 5: Position
    -- 6: Parent Index
    -- 7: Nested Iterator
    local current

    if state == -4 then
        return nil, nil
    end

    -- Initialise
    if state == 0 then
        nextContext = initialiseContext(context)
        state = setState(context, 3) -- signal to get (first) nested iterator
        context[5] = 1 -- set position to 1
    end

    while true do
        -- Grab next value from nested enumerator		
        if state == 4 then
            local childIterator = context[7]
            index, current = moveIterator(childIterator)
            if index ~= nil then
                return index, current
            else
                state = setState(context, 3)
            end
        end

        if state == 3 then
            local selector = context[4]
            local position = context[5]
            index = context[6]
            -- Get the next item from the iterator
            index, current = nextFunc(nextContext, index)
            if index ~= nil then
                -- Grab the child iterator via the selector function.
                local nestedObject = selector(current, position)
                assert(type(nestedObject) == 'table', 'Selected object must be a table.')
                context[5] = position + 1
                context[6] = index
                -- Setup the child iterator
                context[7] = { getIterator(nestedObject) }
                state = setState(context, 4)
            else
            	-- iterator doesn't have any more nested iterators.
                state = setState(context, -4)
                return nil, nil
            end
        end
    end
end

---Projects each element of a sequence and flattens the resulting sequences into one sequence.
---@param selector fun(item: any, position: integer): table
---@return Enumerator3
function Enumerator3:flatMap(selector)
    assert(selector, 'Selector cannot be nil')
    return Enumerator3.new(flatMapFunc, self.index, wrapContext2(self, selector, 0))
end


-- CONCAT
local concatFunc = function(context, index)
    local state = context[1]
    local nextFunc = context[2]
    local nextContext = context[3]
    -- 4: Second (collection/table)

    if state == -4 then
        return false
    end

	if state == 0 then
        nextContext = { getIterator(nextFunc, nextContext, index) }
		state = setState(context, 1)
	end

    while true do
        local current
        index, current = moveIterator(nextContext)
        if index ~= nil then
            context[3] = nextContext
            return index, current
        end

        -- Incremental state can be used to chain together infinite iterators instead of just 2.
        -- Though, currently only two are supported.
        state = setState(context, state + 1)
        if state > 2 then
            setState(context, -4)
            return nil, nil
        end
        nextContext = { getIterator(context[4]) }
    end
end

---Concatenates two sequences.
---@param second table
---@return Enumerator3
function Enumerator3:concat(second)
	assert(second, 'Second cannot be nil')
    return Enumerator3.new(concatFunc, self.index, wrapContext1(self, second))
end


-- APPEND
local appendFunc = function(context, index)
    local state = context[1]
    local nextFunc = context[2]
    local nextContext = context[3]
    -- 4: Item
    -- 5: ItemIndex
    -- 6: Append / Prepend

    -- This state is used for returning the prepend item as well as
    -- initialising the iterator.
    if state == 0 then
        -- Iterator gets wrapped so the index is stored locally.
        nextContext = { getIterator(nextFunc, nextContext, index) }
        context[3] = nextContext
        state = setState(context, 1)

        local append = context[6]
        if append == false then
            local item, itemIndex = context[4], context[5]
            return itemIndex, item
        end
    end

    if state == 1 then
        local current
        index, current = moveIterator(nextContext)
        if index ~= nil then
            return index, current
        else
            state = setState(context, -4)
        end

        local item, itemIndex = context[4], context[5]
        local append = context[6]
        if append == true then
            return itemIndex, item
        end
    end

    return nil, nil
end

---Appends a value to the end of the sequence.
---@param item any
---@param index? any
---@return Enumerator3
function Enumerator3:append(item, index)
    if index == nil then index = 'nil' end -- Index needs *some* value

    return Enumerator3.new(appendFunc, self.index, wrapContextN(self, item, index, true))
end

---Prepends a value to the start of the sequence.
---@param item any
---@param index? any
---@return Enumerator3
function Enumerator3:prepend(item, index)
    if index == nil then index = 'nil' end -- Index needs *some* value

    return Enumerator3.new(appendFunc, self.index, wrapContextN(self, item, index, false))
end


-- UNIQUE / DISTINCT
local uniqueFunc = function(context, index)
    local state = context[1]
    local nextFunc = context[2]
    local nextContext = context[3]
    local keySelector = context[4]
    -- 5: Hashset for uniqueness
    local current

    if state == -4 then
        return false
    end

    if state == 0 then
        nextContext = initialiseContext(context)
        -- If we have any items, create a hashtable. Otherwise abort.
        -- We check this by moving to the next item in the iterator and checking
        -- if it's not nil. We then set the state to 1 which immediately uses the current/index
        -- without having to move to a new item.
        index, current = nextFunc(nextContext, index)
        if index ~= nil then
            context[5] = {}
            state = 1
        else
            setState(context, -4)
            return nil, nil
        end
    end

    -- State is only updated locally, because it is potentially set multiple times during one iteration.
    -- Persist state only when returned.
    local set = context[5]
    while true do
        -- Try to grab a new item.
        if state == 2 then
            index, current = nextFunc(nextContext, index)
            if index ~= nil then
                state = 1
            else
                setState(context, -4)
                return nil, nil
            end
        end

        if state == 1 then
            state = 2
            local key = keySelector and keySelector(current, index) or current
            if addToSet(set, key) == true then
                setState(context, state)
                return index, current
            end
        end
    end
end

---Returns unique (distinct) elements from a sequence according to a specified key selector function.
---@param keySelector? fun(current: any, index: any): any
---@return Enumerator3 
function Enumerator3:unique(keySelector)
    return Enumerator3.new(uniqueFunc, self.index, wrapContext1(self, keySelector))
end


-- DIFFERENCE
local differenceFunc = function(context, index)
    local state = context[1]
    local nextFunc = context[2]
    local nextContext = context[3]
    -- 4: Second collection
    local keySelector = context[5]
    -- 6: Hashset for uniqueness
    local current

    if state ~= 0 then
        if state ~= 1 then
            return nil, nil
        end
    else
        nextContext = initialiseContext(context)
        context[6] = Enumerator3.toSet(context[4], keySelector)
        setState(context, 1)
    end

    local set = context[6]
    while true do
        index, current = nextFunc(nextContext, index)
        if index ~= nil then
            local key = keySelector and keySelector(current, index) or current
            if addToSet(set, key) == true then
                return index, current
            end
        else
            setState(context, -4)
            return nil, nil
        end
    end
end

---Produces the set difference of two sequences according to a specified key selector function.
---@param second table
---@param keySelector? fun(current: any, index: any): any
---@return Enumerator3 
function Enumerator3:difference(second, keySelector)
    assert(second, 'Second collection cannot be nil.')

    return Enumerator3.new(differenceFunc, self.index, wrapContext2(self, second, keySelector))
end


-- UNION
local unionFunc = function (context, index)
    local state = context[1]
    local nextFunc = context[2]
    local nextContext = context[3]
    -- 4: Second collection
    local keySelector = context[5]
    -- 6: Hashset for uniqueness
    local current

    if state ~= 0 then
        if state ~= 1 then
            return nil, nil
        end
    else
        nextContext = { getIterator(nextFunc, nextContext, index) }
        context[3] = nextContext
        context[6] = {}
        state = setState(context, 1)
    end

    local set = context[6]
    while state > 0 do
        -- Iteration state (exhaust the current iterator)
        while true do
            index, current = moveIterator(nextContext)
            if index == nil then break end

            local key = keySelector and keySelector(current, index) or current
            if addToSet(set, key) == true then
                return index, current
            end
        end

        -- If the state is 2 at this point, the second iterator has 
        -- already been initialised and exhausted.
        if state == 2 then
            setState(context, -4)
            return nil, nil
        end

        -- Grab second iterator, first one is empty
        nextContext = { getIterator(context[4]) }
        context[3] = nextContext
        state = setState(context, 2)
    end
end

---Produces the set union of two sequences according to a specified key selector function.
---@param second table
---@param keySelector? fun(current: any, index: any): any
---@return Enumerator3
function Enumerator3:union(second, keySelector)
    assert(second, 'Second collection cannot be nil.')

    return Enumerator3.new(unionFunc, self.index, wrapContext2(self, second, keySelector))
end

-- INTERSECT
local intersectFunc = function(context, index)
    local state = context[1]
    local nextFunc = context[2]
    local nextContext = context[3]
    -- 4: Second collection
    local keySelector = context[5]
    -- 6: Hashset for uniqueness
    local current

    if state ~= 0 then
        if state ~= 1 then
            return nil, nil
        end
    else
        nextContext = initialiseContext(context)
        context[6] = Enumerator3.toSet(context[4], keySelector)
        state = setState(context, 1)
    end

    local set = context[6]
    while true do
        index, current = nextFunc(nextContext, index)
        if index == nil then
            setState(context, -4)
            return nil, nil
        end

        local key = keySelector and keySelector(current, index) or current
        if removeFromSet(set, key) == true then
            return index, current
        end
    end
end

---Produces the set intersection of two sequences according to a specified key selector function.
---@param second table
---@param keySelector? fun(current: any, index: any): any
function Enumerator3:intersect(second, keySelector)
    assert(second, 'Second collection cannot be nil.')

    return Enumerator3.new(intersectFunc, self.index, wrapContext2(self, second, keySelector))
end


-- ZIP
local zipFunc = function(context, index)
    local state = context[1]
    local nextFunc = context[2]
    local nextContext = context[3]
    -- 4: Second collection
    local resultSelector = context[5]

    if state ~= 0 then
        if state ~= 1 then
            return nil, nil
        end
    else
        -- First iterator gets stored in context, second in context[4]
        nextContext = initialiseContext(context)
        context[4] = { getIterator(context[4]) }
        state = setState(context, 1)
    end

    local second = context[4]
    local indexA, currentA = nextFunc(nextContext, index)
    local indexB, currentB = moveIterator(second)

    if indexA ~= nil and indexB ~= nil then
        local result = resultSelector and resultSelector(currentA, currentB) or { currentA, currentB }
        return indexA, result
    end

    setState(context, -4)
    return nil, nil
end

---Applies a specified function to the corresponding elements of two sequences, producing a sequence of the results.
---@param second table
---@param resultSelector? fun(itema: any, itemb: any): any
---@return Enumerator3
function Enumerator3:zip(second, resultSelector)
    assert(second, 'Second collection cannot be nil.')

    return Enumerator3.new(zipFunc, self.index, wrapContext2(self, second, resultSelector))
end


-- GROUPBY / GROUPBYRESULT
local groupByFunc = function(context, index)
    local state = context[1]
    local nextFunc = context[2]
    local nextContext = context[3]
    -- 4: keySelector
    -- 5: elementSelector
    local resultSelector = context[6]

    if state ~= 0 then
        if state ~= 1 then
            return nil, nil
        end
    else
        -- Pack iterator info into a structure getIterator recognises
        local iterator = { nextFunc, nextContext, index, _isIterator = true }

        local keySelector = context[4]
        local elementSelector = context[5]
        local lookup = Enumerator3.toLookup(iterator, keySelector, elementSelector)

        -- Overwrite iterator with new iterator.
        nextFunc, nextContext, index = getIterator(lookup)
        context[2] = nextFunc
        context[3] = nextContext
        state = setState(context, 1)
    end

    local key, elements = nextFunc(nextContext, index)
    if key ~= nil then
        local result = resultSelector and resultSelector(key, elements) or elements
        return key, result
    end

    setState(context, -4)
    return nil, nil
end

---Groups the elements of a sequence.
---@param keySelector fun(param: any): any
---@param elementSelector? fun(param: any): any
---@return Enumerator3
function Enumerator3:groupBy(keySelector, elementSelector)
    assert(keySelector, 'Key Selector cannot be nil.')

    return Enumerator3.new(groupByFunc, self.index, wrapContext2(self, keySelector, elementSelector))
end

function Enumerator3:groupByResult(keySelector, elementSelector, resultSelector)
    assert(keySelector, 'Key Selector cannot be nil.')
    assert(resultSelector, 'Result Selctor cannot be nil.')

    return Enumerator3.new(groupByFunc, self.index, wrapContextN(self, keySelector, elementSelector, resultSelector))
end

-- SORTING
-- This function is verbose on purpose. It is beneficial to spend more time returning a sort function
-- than create a sort function that has to make additional checks besides sorting.
local getSortFunction = function(sortSelectors)
    -- Special case where sort isn't chained.
    if #sortSelectors == 1 then
        local sortSelector = sortSelectors[1]
        local acendingSort = sortSelector[1]
        local sortSelector = sortSelector[2]
        if sortSelector ~= nil then
            if acendingSort == true then
                return function(a, b) return sortSelector(a) < sortSelector(b) end
            else
                return function(a, b) return sortSelector(a) > sortSelector(b) end
            end
        else
            if acendingSort == true then
                return function(a, b) return a < b end
            else
                return function(a, b) return a > b end
            end
        end
    end

    -- Return chained sorting.
    return function(a, b)
		for _, sortSelector in pairs(sortSelectors) do
			local selector = sortSelector[2]
			local sortAscending = sortSelector[1]

			local leftVal = selector(a)
			local rightVal = selector(b)
			if leftVal ~= rightVal then
				if sortAscending then
					return leftVal < rightVal
				else
					return leftVal > rightVal
				end
			end
		end
	end
end

local sortFunc = function(context, index)
    local state = context[1]
    local nextFunc = context[2]
    local nextContext = context[3]
    -- 4: sort selectors
    local current

    if state ~= 0 then
        if state ~= 1 then
            return nil, nil
        end
    else
        -- Create and sort buffer.
        nextContext = initialiseContext(context)
        local buffer = {}
        for _, value in nextFunc, nextContext, index do
            table.insert(buffer, value)
        end

        local sortFunc = getSortFunction(context[4])
        table.sort(buffer, sortFunc)

        nextFunc, nextContext, index = getIterator(buffer)
        context[2] = nextFunc
        context[3] = nextContext
        setState(context, 1)
    end

    index, current = nextFunc(nextContext, index)
    if index ~= nil then
        return index, current
    end

    setState(context, -4)
    return nil, nil
end

---Boilerplate function to setup various sorting Enumerator3
---@param self Enumerator3
---@param isDecending boolean
---@param selector? fun(current: any): any
---@return Enumerator3
local function getSorter(self, isDecending, selector)
    local selectors = { { isDecending, selector } }
    return Enumerator3.new(sortFunc, self.index, wrapContext1(self, selectors))
end

---Boilerplate function to chain one sort to the previous
---@param self table
---@param isDecending any
---@param selector any
local function getChainedSorter(self, isDecending, selector)
    assert(self.nextFunc == sortFunc, 'Previous function must be a "sort" or "order" variant.')

    -- Since the context isn't accessible for the child, we unwrap the context first by calling the factory
    -- we can then access the context to add a new sortselector to it.
    local selectors = self.contextFactory()[4]
    table.insert(selectors, { isDecending, selector })

    -- Wrap context into a factory again.
    self.contextFactory = wrapContext1(self, selectors)
    return self
end

function Enumerator3:sort()
    return getSorter(self, true)
end

function Enumerator3:sortDescending()
    return getSorter(self, false)
end

function Enumerator3:sortBy(selector)
    return getSorter(self, true, selector)
end

function Enumerator3:sortByDescending(selector)
    return getSorter(self, false, selector)
end

function Enumerator3:thenBy(selector)
    return getChainedSorter(self, true, selector)
end

function Enumerator3:thenByDecending(selector)
    return getChainedSorter(self, false, selector)
end

--------------------------------------------------------------------------------
-- Equality functions
--------------------------------------------------------------------------------
local equals = function(a, b)
    -- Start with a quick Lua equals comparison
    local result = a == b
    if result then
        return true
    end

    -- If this is false, we check for additional equality, such as sequence equality.
    local atype = type(a)
    local btype = type(b)
    if atype ~= btype then
        return false
    end

    if atype == 'table' then
        return Enumerator3.sequenceEquals(a, b)
    else
        return result
    end
end

---comment
---@param self Enumerator3
---@param item any
---@param comparer? fun(a: any, b: any): boolean
---@return integer
function Enumerator3.indexOf(self, item, comparer)
    local index = 1
    local cmp = comparer or equals
    for _, v in getIterator(self) do
        if cmp(v, item) == true then
            return index
        end
        index = index + 1
    end
    return -1
end

---Determines whether the sequence contains the provided item according to an equality comparer.
---@param self table
---@param item any
---@param comparer? fun(a: any, b: any): boolean
---@return boolean
function Enumerator3.contains(self, item, comparer)
    local cmp = comparer or equals
    for _, v in getIterator(self) do
        if cmp(v, item) == true then
            return true
        end
    end
    return false
end

---Determines whether two sequences are equal according to an equality comparer.
---@param self table
---@param other table
---@param comparer? fun(a: any, b: any): boolean
---@return boolean
function Enumerator3.sequenceEquals(self, other, comparer)
    local cmp = comparer or equals

    local next, context, index, current = getIterator(other)

    for _, v in getIterator(self) do
        index, current = next(context, index)
        if cmp(v, current) == false then
            return false
        end
    end

    -- We have exhausted the first iterator. For the sequences to be equal, the second iterator also
    -- has to be exhausted. We know this is the case if the last returned index is also nil.
    return index == nil
end

--------------------------------------------------------------------------------
-- Sequence search functions
--------------------------------------------------------------------------------
local function getFirstItem(self, defaultValue)
    local firstItem = nil

    local nextFunc, context, index = getIterator(self)
    -- Shortcut for arrays
    if nextFunc == ipairsFunc then
        firstItem = context[1]
    else
        _, firstItem = nextFunc(context, index)
    end

    return firstItem or defaultValue
end

local function getLastItem(self, defaultValue)
    local nextFunc, context, index = getIterator(self)
    local lastItem = nil

    if nextFunc == ipairsFunc then
        lastItem = context[#context]
    else
        -- Run iterator to the end
        local k, v = index, _
        repeat
            k, v = nextFunc(context, k)
        until k == nil
        lastItem = v
    end

    return lastItem or defaultValue
end

---Returns the first item in the sequence, or a default value if there are no items in the sequence.
---@param self table
---@param predicate? fun(value: any, key: any): boolean
---@param defaultValue? any
---@return any
function Enumerator3.firstOrDefault(self, predicate, defaultValue)
    if predicate == nil then
        return getFirstItem(self, defaultValue)
    end

    local nextFunc, context, index = getIterator(self)
    for k, v in nextFunc, context, index do
        if predicate(v, k) == true then
            return v
        end
    end

    return defaultValue
end

---Returns the first item in the sequence, or an error if there are no items in the sequence.
---@param self table
---@param predicate? fun(value: any, key: any): boolean
---@return any
function Enumerator3.first(self, predicate)
    local result = Enumerator3.firstOrDefault(self, predicate, nil)
    if result == nil then
        error('Sequence contains no (matching) elements.')
    else
        return result
    end
end

---Returns the last item in the sequence, or a default value if there are no items in the sequence.
---@param self table
---@param predicate? fun(value: any, key: any): boolean
---@param defaultValue? any
---@return any
function Enumerator3.lastOrDefault(self, predicate, defaultValue)
    if predicate == nil then
        return getLastItem(self, defaultValue)
    end

    local nextFunc, context, index = getIterator(self)
    local lastMatch = nil
    for k, v in nextFunc, context, index do
        if predicate(v, k) == true then
            lastMatch = v
        end
    end

    return lastMatch or defaultValue
end

---Returns the last item in the sequence, or an error if there are no items in the sequence.
---@param self table
---@param predicate? fun(value: any, key: any): boolean
---@return any
function Enumerator3.last(self, predicate)
    local result = Enumerator3.lastOrDefault(self, predicate, nil)
    if result == nil then
        error('Sequence contains no (matching) elements.')
    else
        return result
    end
end

---Returns the length of this collection.
---@param self table
---@return integer
function Enumerator3.count(self)
    local nextFunc, context, index = getIterator(self)
    -- Shortcut for arrays
    if nextFunc == ipairsFunc then
        return #context
    end

    local len = 0
    for _, _ in nextFunc, context, index do
        len = len + 1
    end
    return len
end

---Returns the number of elements in a sequence that match a predicate.
---@param self Enumerator3
---@param predicate fun(value: any, key: any): any
---@return integer
function Enumerator3.countBy(self, predicate)
    assert(predicate, 'Predicate cannot be nil.')
    local count = 0
    for k, v in getIterator(self) do
        if predicate(v, k) == true then
            count = count + 1
        end
    end
    return count
end

--- Determines whether all elements of a sequence satisfy a condition.
---@param self Enumerator3
---@param predicate fun(current: any, index: any): boolean
---@return boolean
function Enumerator3.all(self, predicate)
    assert(predicate)
    for key, value in getIterator(self) do
        if predicate(value, key) == false then
            return false
        end
    end
    return true
end

---Determines whether any element of a sequence exists or satisfies a condition.
---@param self Enumerator3
---@param predicate? fun(current: any, index: any): boolean
---@return boolean
function Enumerator3.any(self, predicate)
    local nextFunc, context, index = getIterator(self)
    -- Check without predicate (just need a first item to exist)
    if predicate == nil then
        -- Array shortcut.
        if nextFunc == ipairsFunc then
            return #context > 0
        else
            -- Check if we can iterate once.
            local inx, _ = nextFunc(context, index)
            return inx ~= nil
        end
    -- Check with a predicate. We need one item that matches the predicate.
    else
        for key, value in nextFunc, context, index do
            if predicate(value, key) == true then
                return true
            end
        end
    end

    return false
end

--------------------------------------------------------------------------------
-- Iteration via function
--------------------------------------------------------------------------------
function Enumerator3.foreach(self, func)
    assert(type(func) == 'function')

    for _, current in getIterator(self) do
        func(current)
    end
    return nil
end

function Enumerator3.ipairs(self)
    return getIterator(self)
end

function Enumerator3.pairs(self)
    return getIterator(self)
end

--------------------------------------------------------------------------------
-- Reduction
--------------------------------------------------------------------------------

---Applies an accumulator function over a sequence. The specified seed value is used as the initial accumulator value, and the specified function is used to select the result value.
---@param self table
---@param func fun(accumulate: any, next: any): any
---@param seed? any
---@param resultSelector? fun(result: any): any
---@return any
function Enumerator3.aggregate(self, func, seed, resultSelector)
    assert(func, 'Aggregate function cannot be nil.')

    local result = seed
    local next, context, index = getIterator(self)

    -- We need a starting element.
    if result == nil then
        index, result = next(context, index)
        if index == nil then
            error('Sequence contains no elements')
        end
    end

    for _, current in next, context, index do
        result = func(result, current)
    end

    if resultSelector ~= nil then
        result = resultSelector(result)
    end

    return result
end

--------------------------------------------------------------------------------
-- Arithmetic functions
--------------------------------------------------------------------------------
local maxCmp = function(a, b)
    if a < b then return true end
end

local minCmp = function(a, b)
    if a > b then return true end
end

local minmaxBy = function(self, selector, comparison)
    local nextFunc, context, index = getIterator(self)

    -- Grab the first item
    local key, value = nextFunc(context, index)
    local maxVal = selector(value, key)
    local maxItem = value

    local curr = nil
    for k, v in nextFunc, context, key do
        curr = selector(v, k)
        if comparison(maxVal, curr) == true then
            maxVal = curr
            maxItem = v
        end
    end

    return maxItem
end

local minmax = function(self, comparison)
    local nextFunc, context, index = getIterator(self)

    -- Grab the first item
    local key, value = nextFunc(context, index)
    local maxVal = value

    for _, v in nextFunc, context, key do
        if comparison(maxVal, v) == true then
            maxVal = v
        end
    end

    return maxVal
end

---Returns the maximum value in a sequence of values.
---@param self Enumerator3
---@return number
function Enumerator3.max(self)
    return minmax(self, maxCmp)
end

---Returns the item with the highest value in a sequence.
---@param self Enumerator3
---@param selector fun(item: any, index: any): number
---@return unknown
function Enumerator3.maxBy(self, selector)
    assert(selector)
    return minmaxBy(self, selector, maxCmp)
end

---Returns the minimum value in a sequence of values.
---@param self Enumerator3
---@return number
function Enumerator3.min(self)
    return minmax(self, minCmp)
end

---Returns the item with the lowest value in a sequence.
---@param self Enumerator3
---@param selector fun(item: any, index: any): number
---@return unknown
function Enumerator3.minBy(self, selector)
    assert(selector)
    return minmaxBy(self, selector, minCmp)
end

---Returns the sum value of a sequence of values.
---@param self Enumerator3
---@return number
function Enumerator3.sum(self)
    local total = 0
    for _, v in getIterator(self) do
        total = total + v
    end
    return total
end

--------------------------------------------------------------------------------
-- Enumerator functions that result in a new collection.
--------------------------------------------------------------------------------

---Creates a set from a table|Enumerator3 with a specified key for uniqueness.
---@param self table
---@param keySelector? fun(current:any, index: any): any
---@return table
function Enumerator3.toSet(self, keySelector)
    local set = {}
    for inx, current in getIterator(self) do
        if keySelector == nil then
            set[current] = true
        else
            local key = keySelector(current, inx)
            set[key] = true
        end
    end
    return set
end

---Creates an array from a table|Enumerator3
---@param self table
---@return table
function Enumerator3.toTable(self)
    local tab = {}
    for _, value in getIterator(self) do
        table.insert(tab, value)
    end
    return tab
end

---Creates a lookup from a table|Enumerator3
---This is a { key, { elements } } structure
---@param self table
---@return table
function Enumerator3.toLookup(self, keySelector, elementSelector)
    local lookup = {}

    for inx, value in getIterator(self) do
        -- Retrieve grouping by key
        local key = keySelector(value, inx)
        assert(key, 'Key cannot be nil.')
        local grouping = lookup[key]

        -- Create new grouping if it doesn't exist.
        if grouping == nil then
            grouping = {}
            lookup[key] = grouping
        end

        -- Add element to grouping
        local element = elementSelector and elementSelector(value, inx) or value
        table.insert(grouping, element)
    end

    return lookup
end

---Creates a dictionary (key, value) structure based on a keyselector and an elementselector
---@param self Enumerator3
---@param keySelector? fun(value: any, key: any): any
---@param elementSelector? fun(value: any, key: any): any
---@return table
function Enumerator3.toDictionary(self, keySelector, elementSelector)
    local dictionary = {}

    for key, value in getIterator(self) do
        if keySelector ~= nil then
            key = keySelector(value, key)
        end
        if elementSelector ~= nil then
            value = elementSelector(value, key)
        end

        dictionary[key] = value
    end
    return dictionary
end

---Creates a deep copy of the provided object.
---@param self any
---@return any
function Enumerator3.deepCopy(self)
    local orig_type = type(self)
    local copy
    if orig_type == 'table' then
        copy = {}
        for k, v in getIterator(self) do
            copy[Enumerator3.deepCopy(k)] = Enumerator3.deepCopy(v)
        end
    else
        copy = self
    end
    return copy
end

--------------------------------------------------------------------------------
-- Function aliasses
--------------------------------------------------------------------------------
Enumerator3.select = Enumerator3.map
Enumerator3.selectMany = Enumerator3.flatMap
Enumerator3.distinct = Enumerator3.unique
Enumerator3.except = Enumerator3.difference
Enumerator3.order = Enumerator3.sort
Enumerator3.orderDecending = Enumerator3.sortDescending
Enumerator3.orderBy = Enumerator3.sortBy
Enumerator3.orderByDecending = Enumerator3.sortByDescending

Enumerator3.each = Enumerator3.foreach
Enumerator3.for_each = Enumerator3.foreach
Enumerator3.length = Enumerator3.count
Enumerator3.lengthBy = Enumerator3.countBy
Enumerator3.has = Enumerator3.contains
Enumerator3.fold = Enumerator3.aggregate
Enumerator3.reduce = Enumerator3.aggregate

return Enumerator3

--------------------------------------------------------------------------------
-- Explanation on how this module functions / how it can be extended
--------------------------------------------------------------------------------

----------------------------------------------------
-- Chaining functions and the Enumerator3 class
----------------------------------------------------
--[[
The Enumerator3 class holds metadata for chaining functions and the context for the operation. Specifically, it holds three pieces of information:
- The function to be executed next in the sequence.
- Index of the Iterator Function: (which is 0 for ipairs and nil for pairs)
- The Context Factory (A function that generates a fresh context)

The Enumerator3 class ONLY serves as a storage for metadata with the sole purpose of chaining functions. This is also the reason why all chainable functions are added to Enumerator3. When a chainable function gets called, a new Enumerator3 instance gets created. The current (self) Enumerator3 metadata is passed to the newly created instance to form a chain (via the context object). The Enumerator3 class is ONLY used when functions are chained, NOT when executed. This is also why this module is as fast as it is. There are no calls made to retrieve data from the Enumerator3 instances when iterating. The context object is exclusively used for this purpose.

When the chain of functions gets evaluated, the following happens.
1. The next function in the chain is retrieved.
2. The context object for the next function is created via the factory.
3. The index for the next function is retrieved.
4. The next function is invoked using the context and index. (Usually in a loop, until the index returns nil)

The above 4 steps happen for every function in the chain.
]]

----------------------------------------------------
-- The context object and context factory
----------------------------------------------------
--[[
Since some functions require data to be remembered for consecutive "nextFunc" calls, a context object is required. This can be seen, for example, for the map function. Where the map function stores an index (position) and passes it to the selector function. 

However, if the Enumerator3 object is reused or chained with a different function, the context from the previous iteration persists, leading to incorrect behavior. For example, in the map function, the index remains at the last iteration's position instead of restarting at 0. To prevent this, the context isn't passed directly to Enumerator3. Instead, it utilizes a function, known as a context factory, to generate a fresh context each time the function chain is evaluated.

This is achieved via the wrapContext function. The context always contains the same fields in the same order:
[1] State: Represents the current state of the function.
[2] Next Function: The subsequent function in the chain.
[3] Next Context (Factory): The factory function for creating the next context.

Depending on the specific function used, additional fields may be added to the context object. For example, the map function appends:
- A selector function as the 4th field.
- An indexer as the 5th field.
The sequence in which fields are added is crucial because it determines how they must be retrieved later. The fields added by a chainable function must align precisely with the fields expected by its implementation function.

For example in the map (chainable) function we see:
"Enumerator3.new(mapFunc, self.index, wrapContext2(self, selector, 0))"

The wrapContext function adds the initial three fields (state, next function, next context factory), with the selector as the 4th field and the indexer (defaulting to 0) as the 5th field.

The implementing function of map retrieves the fields in the exact order they were added:
Example: The implementation of map retrieves all of these fields in the same order when called:
    local state = context[1]
    local nextFunc = context[2]
    local nextContext = context[3]
    local selector = context[4]
    local position = context[5]


Maintaining this order is critical for ensuring that the context is correctly interpreted and that each function receives the data it expects.
]]


----------------------------------------------------
-- The state of the implementing function
----------------------------------------------------
--[[
When an implementing function is invoked for the first time, the context object must be created using the provided factory function. During this initial call, the following occurs:

- Check Initial State: The state is checked to see if it is 0.
- Replace Context Factory: If the state is 0, the context factory function is replaced with the context itself via the initialiseContext function.
- Set to Active State: The state is then set to 1, signifying that the function is now in its active iteration phase.

On subsequent calls to the implementing function, the initialization step (state 0) is bypassed, and the function proceeds directly with the iteration step. Some implementing functions may require multiple states to handle more complex behaviors. 

When the state becomes -4, it signifies that the iteration has reached its end. At this point, the function returns nil to indicate to the caller that there are no further elements to process. 
]]

----------------------------------------------------
-- Expanding upon the module
----------------------------------------------------
--[[
If a new chainable function gets added (a function that returns an Enumerator3 instance) the following should be ensured:
1. A chainable function and an implementation function pair must be implemented.
2. The chainable function must be added to the Enumerator3 class to allow chaining.
3. The chainable function must pass a context factory to the Enumerator3 constructor, with the parameters in the same order as the implementing function uses them.
4. The implementing function must retrieve the required fields in the same order as the context factory specifies.
5. The implementing function must create a new context via the factory function and store the context for future iterations.
6. The state field can be used to track the state of the implementing function and change its behaviour. For instance, having an initialisation state and an iteration state.

Simple functions such as the map/mapFunc and where/whereFunc can be used as a reference.



If a new result function gets added (a function that iterates over the function chain and returns a direct result) the following should be ensured:
1. Retrieve the nextFunc, contextFactory and index from the current object (self).
2. Create the context by calling the contextFactory.
3. Iterate over the chain by calling nextFunc using the context and index.

If the index returned from nextFunc is nil, it means that the iteration has come to an end.
]]