Module:FunList/Iterators

< Module:FunList
Revision as of 20:46, 26 July 2024 by Ricewind (talk | contribs) (bugfix)
(diff) ← Older revision | Latest revision (diff) | Newer revision → (diff)

Documentation for this module may be created at Module:FunList/Iterators/doc

local Enumerable = require('Module:FunList/Enumerable')
local Lookup = require('Module:FunList/Lookup')
local TableEnumerator = require('Module:FunList/TableEnumerator')

---Returns an Enumerator for `source`.
---If `source` already provides `getEnumerator` (an Enumerable/Iterator) that method is
---used; otherwise the plain table is wrapped with TableEnumerator.create.
---`isArray` is forwarded unchanged to whichever factory is used; it is how the proper
---starting index is chosen for plain tables.
---@param source any Enumerable or plain table
---@param isArray any forwarded to the enumerator factory
---@return Enumerator
local function getTableEnumerator(source, isArray)
    -- Validate before indexing: reading `source.getEnumerator` on nil, a number or a
    -- boolean would raise a raw "attempt to index" error instead of this message.
    assert(type(source) == 'table', 'sourceTable must be either an Enumerable or table.')
    if source.getEnumerator ~= nil then
        return source:getEnumerator(isArray)
    end
    return TableEnumerator.create(source, isArray)
end

-- Helpers for a minimal hash-set: a plain table whose keys are the members and whose
-- values are `true`. They are used to tell elements apart in the set-style operations
-- (Union/Unique/Difference/Intersect and similar).

---Builds a set from the *values* (not the keys) of `source`.
---An Enumerable is drained through its own enumerator; any other value must be a plain
---table, whose values are collected with `pairs`.
---@param source any Enumerable or plain table
---@return table set mapping each distinct value to true
local function sourceToSet(source)
    local members = {}
    if not Enumerable.isEnumerable(source) then
        assert(type(source) == 'table', 'Source must be a table.')
        for _, value in pairs(source) do
            members[value] = true
        end
        return members
    end

    local enumerator = source:getEnumerator()
    while enumerator:moveNext() do
        members[enumerator.current] = true
    end
    return members
end

---Adds `item` to `set` if it is not already present.
---@param set table hash-set (member -> true)
---@param item any
---@return boolean added true when `item` was newly inserted, false if it already existed
local function setAdd(set, item)
    local isNew = set[item] == nil
    if isNew then
        set[item] = true
    end
    return isNew
end

---Removes `item` from `set` if it is present.
---@param set table hash-set (member -> true)
---@param item any
---@return boolean removed true when `item` was a member and has been removed
local function setRemove(set, item)
    local present = set[item] ~= nil
    if present then
        set[item] = nil
    end
    return present
end

-- Base class for all iterator state machines in this module. An Iterator is both an
-- Enumerable (it can hand out enumerators via getEnumerator) and an Enumerator (it
-- exposes moveNext/current/index itself).
--
-- `_state` values used by this base class:
--   -1  freshly constructed, never handed out by getEnumerator
--    0  handed out by getEnumerator, moveNext not yet called
--   -4  finalized; moveNext should report no further elements
-- Subclasses use positive values for their own intermediate states.
---@class Iterator : Enumerable, Enumerator
---@field current any
---@field index any
---@field _state integer
---@field _isArray boolean
local Iterator = setmetatable({}, { __index = Enumerable })
local Iterator_mt = {
    __index = Iterator,
    __pairs = Enumerable.getPairs,
    __ipairs = Enumerable.getiPairs,
}

---Creates an unused base instance. Subclass constructors replace its metatable with
---their own class table.
---@return Iterator
function Iterator.new()
    local instance = {
        current = nil,
        index = nil,
        _state = -1,
        -- Assume by default we are not dealing with a simple array.
        _isArray = false,
    }
    return setmetatable(instance, Iterator_mt)
end

---Advances to the next element; must be overridden by each subclass.
---(This would normally come from the Enumerator abstract class, but it is declared here
---to avoid inheriting from two classes.)
---@return boolean
function Iterator:moveNext()
    error('Abstract function must be overridden in derived class.')
end

---Returns an enumerator over this iterator's sequence: the iterator itself when it has
---never been handed out (state -1), otherwise a fresh clone so that separate
---enumerations do not share state.
---@param isArray? boolean when truthy it overrides the array setting; otherwise this iterator's own setting is kept
---@return Enumerator
function Iterator:getEnumerator(isArray)
    local enumerator = self
    if self._state ~= -1 then
        enumerator = self:clone()
    end
    enumerator._isArray = isArray or self._isArray
    enumerator._state = 0
    return enumerator
end

---Creates an unused copy with the same configuration; must be overridden by each subclass.
---@return Iterator
function Iterator:clone()
    error('Abstract function must be overridden in derived class.')
end

---Marks the iterator as finished (state -4). Subclasses extend this to also finalize
---the enumerators they hold.
function Iterator:finalize()
    self._state = -4
end

---Creates a subclass table of Iterator that can be used directly as an instance
---metatable: it is its own __index and routes pairs/ipairs through Enumerable.
---@return table
function Iterator.createIteratorSubclass()
    local subclass = setmetatable({
        __pairs = Enumerable.getPairs,
        __ipairs = Enumerable.getiPairs,
    }, { __index = Iterator })
    subclass.__index = subclass
    return subclass
end

--Projects each element of a sequence into a new form.
---@class MapIterator : Iterator
---@field _source Enumerable
---@field _selector fun(value: any, position: integer): any
---@field _position integer
---@field _enumerator Enumerator
local MapIterator = Iterator.createIteratorSubclass()

---@param source Enumerable
---@param selector fun(value: any, position: integer): any receives each value and its 1-based position
function MapIterator.new(source, selector)
    assert(source, 'Source cannot be nil')
    assert(selector, 'Selector cannot be nil')
    local self = setmetatable(Iterator.new(), MapIterator)
    self._source = source
    self._selector = selector
    self._enumerator = nil
    self._position = 0
    return self
end

---Advances to the next projected element.
---The selector is called with the source value and its 1-based position in the
---sequence (not the source index); the source index is exposed as `self.index`.
---@return boolean
function MapIterator:moveNext()
    local state = self._state
    if state == 0 then
        -- First call after getEnumerator: open the source enumerator.
        self._state = 1
        self._position = 0
        self._enumerator = self._source:getEnumerator(self._isArray)
    elseif state ~= 1 then
        -- Never handed out (-1) or already finalized (-4).
        return false
    end

    local enum = self._enumerator
    if enum:moveNext() == true then
        -- 1 based position because Lua
        local pos = self._position + 1
        local current = self._selector(enum.current, pos)
        -- Only nil is rejected, as the message says; `false` is a legitimate
        -- projection (e.g. a boolean map) and must not abort enumeration.
        assert(current ~= nil, 'Selected value must be non-nil')
        self.index = enum.index
        self.current = current
        self._position = pos
        return true
    end

    self:finalize()
    return false
end

---Finalizes the source enumerator (if one was opened) and marks this iterator finished.
function MapIterator:finalize()
    if self._enumerator then
        self._enumerator:finalize()
    end
    Iterator.finalize(self)
end

---@return MapIterator an unused copy over the same source and selector
function MapIterator:clone()
    return MapIterator.new(self._source, self._selector)
end

--Filters a sequence of values based on a predicate.
---@class WhereIterator : Iterator
---@field _source Enumerable
---@field _predicate fun(value: any, index: any): boolean
---@field _enumerator Enumerator
local WhereIterator = Iterator.createIteratorSubclass()

---Creates a filtering iterator over `source`.
---@param source Enumerable
---@param predicate fun(value: any, index: any): boolean called with each value and its source index
function WhereIterator.new(source, predicate)
    assert(source, 'Source cannot be nil')
    assert(predicate, 'Predicate cannot be nil')
    local instance = setmetatable(Iterator.new(), WhereIterator)
    instance._source = source
    instance._predicate = predicate
    instance._enumerator = nil
    return instance
end

---Advances to the next source element for which the predicate returns exactly `true`
---(any other result, including other truthy values, rejects the element).
---The source index is preserved in `self.index`.
---@return boolean
function WhereIterator:moveNext()
    local state = self._state
    if state == 0 then
        -- First call after getEnumerator: open the source enumerator.
        self._state = 1
        self._enumerator = self._source:getEnumerator(self._isArray)
    elseif state ~= 1 then
        return false
    end

    -- Hoist lookups out of the loop; several source elements may be skipped per call.
    local source = self._enumerator
    local advance = source.moveNext
    local accept = self._predicate
    while advance(source) == true do
        local value, key = source.current, source.index
        if accept(value, key) == true then
            self.current, self.index = value, key
            return true
        end
    end

    self:finalize()
    return false
end

---Finalizes the source enumerator (if one was opened) and marks this iterator finished.
function WhereIterator:finalize()
    local source = self._enumerator
    if source then
        source:finalize()
    end
    Iterator.finalize(self)
end

---@return WhereIterator an unused copy over the same source and predicate
function WhereIterator:clone()
    return WhereIterator.new(self._source, self._predicate)
end

--Projects each element of a sequence to an Enumerable and flattens the resulting sequences into one sequence.
---@class FlatMapIterator : Iterator
---@field _source Enumerable
---@field _selector fun(value: any, position: integer): any
---@field _position integer
---@field _enumerator Enumerator
---@field _sourceEnumerator Enumerator
local FlatMapIterator = Iterator.createIteratorSubclass()

---@param source Enumerable
---@param selector fun(value: any, position: integer): any returns an Enumerable or plain table for each value
function FlatMapIterator.new(source, selector)
    assert(source, 'Source cannot be nil')
    assert(selector, 'Selector cannot be nil')
    local self = setmetatable(Iterator.new(), FlatMapIterator)
    self._source = source
    self._selector = selector
    self._position = 0
    self._enumerator = nil        -- Enumerator of the _source Enumerable
    self._sourceEnumerator = nil  -- Enumerator of the current nested Enumerable
    return self
end

---Advances to the next element of the flattened sequence.
---States: 3 = fetch the next nested sequence from the outer source,
---         4 = read the next element from the current nested sequence.
---@return boolean
function FlatMapIterator:moveNext()
    local state = self._state
    if state == 0 then
        self._position = 0
        self._enumerator = self._source:getEnumerator(self._isArray)
        self._state = 3
        state = 3
    elseif state ~= 3 and state ~= 4 then
        -- Never handed out (-1) or finalized (-4). Without this guard the loop below
        -- would spin forever, since it only handles states 3 and 4.
        return false
    end

    while true do
        if state == 4 then
            local inner = self._sourceEnumerator
            if inner:moveNext() then
                self.current = inner.current
                self.index = inner.index
                return true
            end
            -- Nested sequence exhausted: release it. Dropping the reference keeps
            -- finalize() from finalizing it a second time.
            inner:finalize()
            self._sourceEnumerator = nil
            self._state = 3
            state = 3
        end

        -- state == 3: pull the next nested sequence from the outer source.
        if not self._enumerator:moveNext() then
            self:finalize()
            return false
        end
        local pos = self._position + 1
        local nested = self._selector(self._enumerator.current, pos)
        -- Nested tables are never treated as arrays.
        self._sourceEnumerator = getTableEnumerator(nested, false)
        self._position = pos
        self._state = 4
        state = 4
    end
end

---Finalizes the outer and any still-open nested enumerator and marks this iterator finished.
function FlatMapIterator:finalize()
    if self._enumerator then
        self._enumerator:finalize()
    end
    if self._sourceEnumerator then
        self._sourceEnumerator:finalize()
    end

    Iterator.finalize(self)
end

---@return FlatMapIterator an unused copy over the same source and selector
function FlatMapIterator:clone()
    return FlatMapIterator.new(self._source, self._selector)
end

--Concatenates two sequences.
---@class ConcatIterator : Iterator
---@field _first Enumerable
---@field _second any
---@field _enumerator Enumerator
local ConcatIterator = Iterator.createIteratorSubclass()

---@param first Enumerable sequence yielded first (must provide getEnumerator)
---@param second any Enumerable or plain table yielded afterwards
function ConcatIterator.new(first, second)
    assert(first, 'First cannot be nil')
    assert(second, 'Second cannot be nil')
    local instance = setmetatable(Iterator.new(), ConcatIterator)
    instance._first = first
    instance._second = second
    instance._enumerator = nil
    return instance
end

---Returns the source at the given 0-based position in the concatenation, or nil past
---the end. Only two sources exist today, but moveNext walks sources through this
---accessor so that more could be supported.
---@param index integer
---@return any
function ConcatIterator:getEnumerable(index)
    if index == 0 then
        return self._first
    end
    if index == 1 then
        return self._second
    end
    return nil
end

---Yields every element of the first source, then every element of the second.
---While active, `_state - 1` is the position of the source currently being read.
---@return boolean
function ConcatIterator:moveNext()
    local state = self._state
    if state == -4 then
        return false
    end

    if state == 0 then
        local firstSource = self:getEnumerable(0)
        self._enumerator = firstSource:getEnumerator(self._isArray)
        state = 1
        self._state = state
    end

    if state <= 0 then
        return false
    end

    while true do
        local active = self._enumerator
        if active:moveNext() == true then
            self.index = active.index
            self.current = active.current
            return true
        end

        state = state + 1
        self._state = state
        local following = self:getEnumerable(state - 1)
        if following == nil then
            self:finalize()
            return false
        end
        -- Release the exhausted enumerator before switching to the next source.
        active:finalize()
        self._enumerator = getTableEnumerator(following, self._isArray)
    end
end

---Finalizes the current enumerator (if any) and marks this iterator finished.
function ConcatIterator:finalize()
    local active = self._enumerator
    if active then
        active:finalize()
    end

    Iterator.finalize(self)
end

---@return ConcatIterator an unused copy over the same two sources
function ConcatIterator:clone()
    return ConcatIterator.new(self._first, self._second)
end

--Appends a value to the end or start of the sequence.
---@class AppendIterator : Iterator
---@field _source Enumerable
---@field _item any
---@field _itemIndex any
---@field _append boolean
---@field _enumerator Enumerator
local AppendIterator = Iterator.createIteratorSubclass()

---@param source Enumerable
---@param item any the extra element; any non-nil value, including false
---@param itemIndex? any index reported for `item`; defaults to `item` itself
---@param append? boolean true (default) emits `item` after the source, false emits it before
function AppendIterator.new(source, item, itemIndex, append)
    assert(source, 'Source cannot be nil')
    -- Reject only nil, as the message says; `false` is a valid element to add.
    assert(item ~= nil, 'Item cannot be nil')
    local self = setmetatable(Iterator.new(), AppendIterator)
    self._source = source
    self._item = item
    -- Index needs *some* value, otherwise iteration stops.
    if itemIndex == nil then self._itemIndex = item else self._itemIndex = itemIndex end
    if append == nil then self._append = true else self._append = append end
    self._enumerator = nil
    return self
end

---Advances through (optional prepended item), source elements, (optional appended item).
---States: 1 = open the source, 2 = reading the source.
---@return boolean
function AppendIterator:moveNext()
    local state = self._state
    if state == 0 then
        self._state = 1
        state = 1
        -- Prepend mode: the extra item comes before the source.
        if self._append == false then
            self.current = self._item
            self.index = self._itemIndex
            return true
        end
    end

    if state == 1 then
        self._enumerator = self._source:getEnumerator(self._isArray)
        self._state = 2
        state = 2
    end

    if state ~= 2 then
        return false
    end

    local enum = self._enumerator
    if enum:moveNext() then
        self.current = enum.current
        self.index = enum.index
        return true
    end

    -- Source exhausted. finalize() moves the state to -4, so the appended item below is
    -- produced exactly once and the next call returns false.
    self:finalize()
    if self._append == true then
        self.current = self._item
        self.index = self._itemIndex
        return true
    end
    return false
end

---Finalizes the source enumerator (if one was opened) and marks this iterator finished.
function AppendIterator:finalize()
    if self._enumerator then
        self._enumerator:finalize()
    end

    Iterator.finalize(self)
end

---@return AppendIterator an unused copy with the same source, item, index and mode
function AppendIterator:clone()
    return AppendIterator.new(self._source, self._item, self._itemIndex, self._append)
end

--Returns unique (distinct) elements from a sequence according to a specified key selector function.
---@class UniqueIterator : Iterator
---@field _source Enumerable
---@field _keySelector fun(value: any, index: any): any
---@field _enumerator Enumerator
---@field _set table
local UniqueIterator = Iterator.createIteratorSubclass()

---@param source Enumerable
---@param keySelector? fun(value: any, index: any): any defaults to the value itself
function UniqueIterator.new(source, keySelector)
    assert(source, 'Source cannot be nil')
    local self = setmetatable(Iterator.new(), UniqueIterator)
    self._source = source
    self._keySelector = keySelector or function(x) return x end
    self._enumerator = nil
    self._set = nil
    return self
end

---Advances to the next source element whose key has not been produced before.
---The first element for each key wins; its source index is kept.
---@return boolean
function UniqueIterator:moveNext()
    local state = self._state
    if state == 0 then
        self._enumerator = self._source:getEnumerator(self._isArray)
        self._set = {}
        self._state = 1
    elseif state ~= 1 then
        -- Never handed out (-1) or finalized (-4): nothing to enumerate.
        return false
    end

    local enum = self._enumerator
    local set = self._set
    local keySelector = self._keySelector
    while enum:moveNext() == true do
        local current = enum.current
        local index = enum.index
        -- setAdd returns true only the first time a key is seen.
        if setAdd(set, keySelector(current, index)) == true then
            self.current = current
            self.index = index
            return true
        end
    end

    self:finalize()
    return false
end

---Finalizes the source enumerator (if one was opened) and marks this iterator finished.
function UniqueIterator:finalize()
    if self._enumerator then
        self._enumerator:finalize()
    end

    Iterator.finalize(self)
end

---@return UniqueIterator an unused copy with the same source and key selector
function UniqueIterator:clone()
    return UniqueIterator.new(self._source, self._keySelector)
end

--Produces the set difference of two sequences according to a specified key selector function.
---@class DifferenceIterator : Iterator
---@field _first Enumerable
---@field _second any
---@field _keySelector fun(value: any, index: any): any
---@field _enumerator Enumerator
---@field _set table
local DifferenceIterator = Iterator.createIteratorSubclass()

---@param first Enumerable elements to yield
---@param second any Enumerable or plain table whose values are excluded
---@param keySelector? fun(value: any, index: any): any applied to `first` only; defaults to the value itself
function DifferenceIterator.new(first, second, keySelector)
    assert(first, 'First table cannot be nil')
    assert(second, 'Second table cannot be nil')
    local self = setmetatable(Iterator.new(), DifferenceIterator)
    self._first = first
    self._second = second
    self._keySelector = keySelector or function(x) return x end
    self._enumerator = nil
    self._set = nil
    return self
end

---Advances to the next element of `first` whose key is not among the values of
---`second` and has not been produced already (so each key is yielded at most once).
---Note: `second`'s values are used as keys as-is; keySelector is not applied to them.
---@return boolean
function DifferenceIterator:moveNext()
    local state = self._state
    if state == 0 then
        self._enumerator = self._first:getEnumerator(self._isArray)
        -- Seed the set with the excluded values.
        self._set = sourceToSet(self._second)
        self._state = 1
    elseif state ~= 1 then
        return false
    end

    local enum = self._enumerator
    local nextFunc = enum.moveNext
    local keySelector = self._keySelector
    local set = self._set
    while nextFunc(enum) == true do
        local current = enum.current
        local index = enum.index
        -- setAdd fails both for excluded keys and for keys already yielded.
        if setAdd(set, keySelector(current, index)) == true then
            self.current = current
            self.index = index
            return true
        end
    end

    self:finalize()
    return false
end

---Finalizes the source enumerator (if one was opened) and marks this iterator
---finished, so later moveNext calls return false without polling it again.
function DifferenceIterator:finalize()
    if self._enumerator then
        self._enumerator:finalize()
    end

    Iterator.finalize(self)
end

---@return DifferenceIterator an unused copy with the same inputs and key selector
function DifferenceIterator:clone()
    return DifferenceIterator.new(self._first, self._second, self._keySelector)
end

--Produces the set union of two sequences according to a specified key selector function.
---@class UnionIterator : Iterator
---@field _first Enumerable
---@field _second any
---@field _keySelector fun(value: any, index: any): any
---@field _enumerator Enumerator
---@field _set table
local UnionIterator = Iterator.createIteratorSubclass()

---@param first Enumerable
---@param second any Enumerable or plain table
---@param keySelector? fun(value: any, index: any): any applied to both sequences; defaults to the value itself
function UnionIterator.new(first, second, keySelector)
    assert(first, 'First table cannot be nil')
    assert(second, 'Second table cannot be nil')
    local self = setmetatable(Iterator.new(), UnionIterator)
    self._first = first
    self._second = second
    self._keySelector = keySelector or function(x) return x end
    self._enumerator = nil
    self._set = nil
    return self
end

---Advances `self._enumerator` until it produces an element whose key has not been
---seen yet. That element becomes current and `self._state` is set to `resumeState`.
---@param self UnionIterator
---@param resumeState integer state to store when an element is produced
---@return boolean found false when the enumerator is exhausted
local function enumerateSource(self, resumeState)
    local enum = self._enumerator
    local set = self._set
    local keySelector = self._keySelector
    while enum:moveNext() do
        local current = enum.current
        local index = enum.index
        -- _set is a plain table, so use the module's setAdd helper.
        if setAdd(set, keySelector(current, index)) == true then
            self.current = current
            self.index = index
            self._state = resumeState
            return true
        end
    end
    return false
end

---Yields the distinct (by key) elements of `first`, then those of `second` whose key
---was not already produced.
---States: 1 = reading `first`, 2 = reading `second`, 3 = done.
---@return boolean
function UnionIterator:moveNext()
    if self._state == -4 then
        return false
    end

    if self._state == 0 then
        self._enumerator = self._first:getEnumerator(self._isArray)
        self._set = {}
        self._state = 1
    end

    -- Process _first
    if self._state == 1 then
        if enumerateSource(self, 1) == true then
            return true
        end

        -- End _first enumeration and switch to _second (which may be a plain table).
        self._state = 2
        self._enumerator:finalize()
        self._enumerator = getTableEnumerator(self._second, self._isArray)
    end

    -- Process _second
    if self._state == 2 then
        if enumerateSource(self, 2) == true then
            return true
        end

        -- End _second enumeration
        self._state = 3
    end

    self:finalize()
    return false
end

---Finalizes the current enumerator (if any) and marks this iterator finished.
function UnionIterator:finalize()
    if self._enumerator then
        self._enumerator:finalize()
    end

    Iterator.finalize(self)
end

---@return UnionIterator an unused copy with the same inputs and key selector
function UnionIterator:clone()
    return UnionIterator.new(self._first, self._second, self._keySelector)
end

--Produces the set intersection of two sequences according to a specified key selector function.
---@class IntersectIterator : Iterator
---@field _first Enumerable
---@field _second any
---@field _keySelector fun(value: any, index: any): any
---@field _enumerator Enumerator
---@field _set table
local IntersectIterator = Iterator.createIteratorSubclass()

---@param first Enumerable elements to yield
---@param second any Enumerable or plain table whose values are the accepted keys
---@param keySelector? fun(value: any, index: any): any applied to `first` only; defaults to the value itself
function IntersectIterator.new(first, second, keySelector)
    assert(first, 'First table cannot be nil')
    assert(second, 'Second table cannot be nil')
    local instance = setmetatable(Iterator.new(), IntersectIterator)
    instance._first = first
    instance._second = second
    instance._keySelector = keySelector or function(x) return x end
    instance._enumerator = nil
    instance._set = nil
    return instance
end

---Advances to the next element of `first` whose key is one of `second`'s values.
---A matched key is removed from the set, so each key is produced at most once.
---@return boolean
function IntersectIterator:moveNext()
    local state = self._state
    if state == 0 then
        self._enumerator = self._first:getEnumerator(self._isArray)
        self._set = sourceToSet(self._second)
        self._state = 1
    elseif state ~= 1 then
        return false
    end

    local source = self._enumerator
    local remaining = self._set
    local keyOf = self._keySelector
    while source:moveNext() do
        local value, position = source.current, source.index
        if setRemove(remaining, keyOf(value, position)) == true then
            self.current = value
            self.index = position
            self._state = 1
            return true
        end
    end

    self:finalize()
    return false
end

---Finalizes the source enumerator (if one was opened) and marks this iterator finished.
function IntersectIterator:finalize()
    local source = self._enumerator
    if source then
        source:finalize()
    end

    Iterator.finalize(self)
end

---@return IntersectIterator an unused copy with the same inputs and key selector
function IntersectIterator:clone()
    return IntersectIterator.new(self._first, self._second, self._keySelector)
end

--Applies a specified function to the corresponding elements of two sequences, producing a sequence of the results.
---@class ZipIterator : Iterator
---@field _first Enumerable
---@field _second any
---@field _resultSelector fun(left: any, right: any): any
---@field _enumerator1 Enumerator
---@field _enumerator2 Enumerator
local ZipIterator = Iterator.createIteratorSubclass()

---@param first Enumerable
---@param second any Enumerable or plain table
---@param resultSelector? fun(left: any, right: any): any defaults to building the pair `{left, right}`
function ZipIterator.new(first, second, resultSelector)
    assert(first, 'First table cannot be nil')
    assert(second, 'Second table cannot be nil')
    local instance = setmetatable(Iterator.new(), ZipIterator)
    instance._first = first
    instance._second = second
    instance._resultSelector = resultSelector or (function(a, b) return {a, b} end)
    instance._enumerator1 = nil
    instance._enumerator2 = nil
    return instance
end

---Advances both sequences in lockstep and yields resultSelector(left, right).
---Stops as soon as either sequence runs out; the index comes from `first`.
---@return boolean
function ZipIterator:moveNext()
    local state = self._state
    if state == 0 then
        self._enumerator1 = self._first:getEnumerator(self._isArray)
        self._enumerator2 = getTableEnumerator(self._second, self._isArray)
        -- Provisionally marked finished; a successful pair below resets it to 1.
        self._state = -4
    elseif state ~= 1 then
        return false
    end

    local left, right = self._enumerator1, self._enumerator2
    -- `right` is only advanced when `left` produced a value.
    if left:moveNext() == true and right:moveNext() == true then
        self.current = self._resultSelector(left.current, right.current)
        self.index = left.index
        self._state = 1
        return true
    end

    self:finalize()
    return false
end

---Finalizes both enumerators (if opened) and marks this iterator finished.
function ZipIterator:finalize()
    local left, right = self._enumerator1, self._enumerator2
    if left then
        left:finalize()
    end
    if right then
        right:finalize()
    end

    Iterator.finalize(self)
end

---@return ZipIterator an unused copy with the same inputs and result selector
function ZipIterator:clone()
    return ZipIterator.new(self._first, self._second, self._resultSelector)
end

--Groups the elements of a sequence.
---@class GroupByIterator : Iterator
---@field _source Enumerable
---@field _keySelector fun(param: any): any
---@field _elementSelector fun(param: any): any
---@field _enumerator Enumerator
---@field _lookup Lookup
local GroupByIterator = Iterator.createIteratorSubclass()

---@param source Enumerable
---@param keySelector fun(param: any): any
---@param elementSelector? fun(param: any): any defaults to the element itself
function GroupByIterator.new(source, keySelector, elementSelector)
    assert(source, 'source cannot be nil')
    assert(keySelector, 'keySelector cannot be nil')
    local instance = setmetatable(Iterator.new(), GroupByIterator)
    instance._source = source
    instance._keySelector = keySelector
    instance._elementSelector = elementSelector or function(x) return x end
    instance._enumerator = nil
    instance._lookup = nil
    return instance
end

---Advances to the next group produced by the Lookup built from the source.
---@return boolean
function GroupByIterator:moveNext()
    local state = self._state
    if state == 0 then
        -- First call: build the lookup from the source, then walk its groups.
        self._lookup = Lookup.new(self._source, self._keySelector, self._elementSelector, self._isArray)
        self._enumerator = self._lookup:getEnumerator()
        self._state = 1
    elseif state ~= 1 then
        return false
    end

    local groups = self._enumerator
    if groups:moveNext() ~= true then
        self:finalize()
        return false
    end

    self.current = groups.current
    self.index = groups.index
    self._state = 1
    return true
end

---Finalizes the group enumerator (if opened) and marks this iterator finished.
function GroupByIterator:finalize()
    local groups = self._enumerator
    if groups then
        groups:finalize()
    end

    Iterator.finalize(self)
end

---@return GroupByIterator an unused copy with the same source and selectors
function GroupByIterator:clone()
    return GroupByIterator.new(self._source, self._keySelector, self._elementSelector)
end

--Groups the elements of a sequence according to a specified key selector function and creates a result value from each group and its key. The elements of each group are projected by using a specified function.
---@class GroupByResultIterator : Iterator
---@field _source Enumerable
---@field _keySelector fun(param: any): any
---@field _elementSelector fun(param: any): any
---@field _resultSelector fun(key: any, grouping: Grouping): any
---@field _enumerator Enumerator
---@field _lookup Lookup
local GroupByResultIterator = Iterator.createIteratorSubclass()

---@param source Enumerable
---@param keySelector fun(param: any): any
---@param elementSelector? fun(param: any): any defaults to the element itself
---@param resultSelector fun(key: any, grouping: Grouping): any builds the produced value from each group
function GroupByResultIterator.new(source, keySelector, elementSelector, resultSelector)
    assert(source, 'source cannot be nil')
    assert(keySelector, 'keySelector cannot be nil')
    -- The message now names the argument actually being checked.
    assert(resultSelector, 'resultSelector cannot be nil')
    local self = setmetatable(Iterator.new(), GroupByResultIterator)
    self._source = source
    self._keySelector = keySelector
    self._elementSelector = elementSelector or function(x) return x end
    self._resultSelector = resultSelector
    self._enumerator = nil
    self._lookup = nil
    return self
end

---Advances to the next group and yields resultSelector(group.key, group).
---@return boolean
function GroupByResultIterator:moveNext()
    local state = self._state
    if state == 0 then
        self._lookup = Lookup.new(self._source, self._keySelector, self._elementSelector, self._isArray)
        self._enumerator = self._lookup:getEnumerator()
        self._state = 1
    elseif state ~= 1 then
        return false
    end

    local enumerator = self._enumerator
    if enumerator:moveNext() == true then
        -- Transform the Grouping element into the caller's result value.
        local grouping = enumerator.current
        self.current = self._resultSelector(grouping.key, grouping)
        self.index = enumerator.index
        self._state = 1
        return true
    end

    self:finalize()
    return false
end

---Finalizes the group enumerator (if opened) and marks this iterator finished.
function GroupByResultIterator:finalize()
    if self._enumerator then
        self._enumerator:finalize()
    end

    Iterator.finalize(self)
end

---@return GroupByResultIterator an unused copy with the same source and selectors
function GroupByResultIterator:clone()
    return GroupByResultIterator.new(self._source, self._keySelector, self._elementSelector, self._resultSelector)
end


--Sorts the elements of a sequence by a key selector. Additional sort keys can be chained
--with createSortableIterator: keys of earlier (parent) iterators take precedence and later
--keys only break ties between elements the earlier keys consider equal.
---@class SortableIterator : Iterator
---@field _source Enumerable
---@field _keySelector fun(value: any): any
---@field _descendingSort boolean
---@field _parentIterator SortableIterator
---@field _enumerator Enumerator
local SortableIterator = Iterator.createIteratorSubclass()

---@param source Enumerable the unsorted sequence
---@param keySelector fun(value: any): any called with each element only (no index is passed)
---@param descendingSort? boolean sort this key from high to low; defaults to false
---@param parent? SortableIterator higher-priority sorter that this one refines
---@return SortableIterator
function SortableIterator.new(source, keySelector, descendingSort, parent)
    assert(source, 'source cannot be nil')
    assert(keySelector, 'keySelector cannot be nil')
    assert(parent == nil or parent.getEnumerableSorters ~= nil, 'parent must be of type SortableIterator')
    if descendingSort == nil then
        descendingSort = false
    end

    local self = setmetatable(Iterator.new(), SortableIterator)
    self._source = source
    self._keySelector = keySelector
    self._descendingSort = descendingSort
    self._parentIterator = parent
    self._enumerator = nil
---@diagnostic disable-next-line: return-type-mismatch
    return self
end

---Builds a strict "less than" comparator for table.sort from sorters ordered from
---highest to lowest priority.
---@param sorters SortableIterator[]
---@return fun(a: any, b: any): boolean
local function createSortFunc(sorters)
    if #sorters > 1 then
        return function(a, b)
            for _, sorter in ipairs(sorters) do
                local selector = sorter._keySelector
                local aVal = selector(a)
                local bVal = selector(b)
                if aVal ~= bVal then
                    if sorter._descendingSort == true then
                        return aVal > bVal
                    end
                    return aVal < bVal
                end
            end
            -- All keys equal: neither element precedes the other (keeps the order strict).
            return false
        end
    end

    -- Faster comparator when there is only a single sort key.
    local sorter = sorters[1]
    local selector = sorter._keySelector
    if sorter._descendingSort == true then
        return function(a, b) return selector(a) > selector(b) end
    end
    return function(a, b) return selector(a) < selector(b) end
end

---On the first call, drains the source, sorts it with every chained key, and then
---yields the sorted elements one by one.
---@return boolean
function SortableIterator:moveNext()
    local state = self._state
    if state == 0 then
        -- Collect this sorter and its parents, highest-priority (root) first.
        local sortFunc = createSortFunc(self:getEnumerableSorters({}))

        -- Sorting needs the whole sequence, so buffer every source element.
        local source = self._source:getEnumerator(self._isArray)
        local buffer = {}
        while source:moveNext() == true do
            buffer[#buffer + 1] = source.current
        end
        -- Release the exhausted source enumerator, as the other iterators do.
        source:finalize()

        table.sort(buffer, sortFunc)
        self._enumerator = TableEnumerator.createForArray(buffer)
        self._state = 1
    elseif state ~= 1 then
        return false
    end

    local enumerator = self._enumerator
    if enumerator:moveNext() == true then
        self.current = enumerator.current
        self.index = enumerator.index
        self._state = 1
        return true
    end

    self:finalize()
    return false
end

---Collects this sorter and all of its parents into `sorters`, parents first.
---@param sorters table list to append to
---@return table sorters the same list, for convenience
function SortableIterator:getEnumerableSorters(sorters)
    -- Get parent sorter first so it has the higher priority.
    if self._parentIterator ~= nil then
        self._parentIterator:getEnumerableSorters(sorters)
    end
    table.insert(sorters, self)

    return sorters
end

---Chains an additional, lower-priority sort key over the same source.
---@param keySelector fun(value: any): any
---@param descendingSort? boolean
---@return SortableIterator
function SortableIterator:createSortableIterator(keySelector, descendingSort)
    return SortableIterator.new(self._source, keySelector, descendingSort, self)
end

---Finalizes the sorted-buffer enumerator (if any) and marks this iterator finished.
function SortableIterator:finalize()
    if self._enumerator then
        self._enumerator:finalize()
        self._enumerator = nil
    end
    Iterator.finalize(self)
end

---@return SortableIterator an unused copy with the same source, key, direction and parent chain
function SortableIterator:clone()
    return SortableIterator.new(self._source, self._keySelector, self._descendingSort, self._parentIterator)
end

-- Public surface of this module: each concrete iterator class, keyed by its own name.
-- The Iterator base class itself is not exported.
local exports = {
    MapIterator = MapIterator,
    WhereIterator = WhereIterator,
    FlatMapIterator = FlatMapIterator,
    ConcatIterator = ConcatIterator,
    AppendIterator = AppendIterator,
    UniqueIterator = UniqueIterator,
    DifferenceIterator = DifferenceIterator,
    UnionIterator = UnionIterator,
    IntersectIterator = IntersectIterator,
    ZipIterator = ZipIterator,
    GroupByIterator = GroupByIterator,
    GroupByResultIterator = GroupByResultIterator,
    SortableIterator = SortableIterator,
}

return exports