Module:FunList/Iterators: Difference between revisions
From Melvor Idle
(Add SelectEnumerator) |
(bugfix) |
||
(34 intermediate revisions by the same user not shown) | |||
Line 1: | Line 1: | ||
-- | local Enumerable = require('Module:FunList/Enumerable') | ||
-- | local Lookup = require('Module:FunList/Lookup') | ||
local | local TableEnumerator = require('Module:FunList/TableEnumerator') | ||
local | |||
---Creates a TableIterator Enumerable or returns the Enumerable if it already is one. | |||
---We set the propper starting index here since this only applies for tables! | |||
---@param source any | |||
---@param isArray any | |||
---@return Enumerator | |||
local function getTableEnumerator(source, isArray) | |||
if source.getEnumerator == nil then | |||
assert(type(source) == 'table', 'sourceTable must be either an Enumerable or table.') | |||
return TableEnumerator.create(source, isArray) | |||
else | |||
return source:getEnumerator(isArray) | |||
end | |||
end | |||
-- Helper functions to create and manage a hashset structure. | |||
-- These are used to distinguish elements for Union/Unique/Difference etc operations. | |||
---@param source any | |||
---@return table | |||
local function sourceToSet(source) | |||
local set = {} | |||
if Enumerable.isEnumerable(source) then | |||
local enum = source:getEnumerator() | |||
while enum:moveNext() do | |||
set[enum.current] = true | |||
end | |||
else | |||
assert(type(source) == 'table', 'Source must be a table.') | |||
for _, v in pairs(source) do | |||
set[v] = true | |||
end | |||
end | |||
return set | |||
end | |||
---@param set table | |||
---@param item any | |||
---@return boolean | |||
local function setAdd(set, item) | |||
if set[item] == nil then | |||
set[item] = true | |||
return true | |||
end | |||
return false | |||
end | |||
---@param set table | |||
---@param item any | |||
---@return boolean | |||
local function setRemove(set, item) | |||
if set[item] == nil then | |||
return false | |||
end | |||
set[item] = nil | |||
return true | |||
end | |||
-- Defines the base class for all Iterator state machines. This is an Enumerable and Enumerator in one. | |||
---@class Iterator : Enumerable, Enumerator | |||
---@field current any | |||
---@field index any | |||
---@field _state integer | |||
---@field _isArray boolean | |||
local Iterator = setmetatable({}, { __index = Enumerable }) | |||
local Iterator_mt = { | |||
__index = Iterator, | |||
__pairs = Enumerable.getPairs, | |||
__ipairs = Enumerable.getiPairs | |||
} | } | ||
function | ---@return Iterator | ||
local self = setmetatable({}, | function Iterator.new() | ||
local self = setmetatable({}, Iterator_mt) | |||
self.current = nil | self.current = nil | ||
self.index = nil | self.index = nil | ||
self._state = -1 | |||
-- Assume by default we are not dealing with a simple array | |||
self._isArray = false | |||
return self | return self | ||
end | end | ||
function Enumerator: | -- This is normally implemented by the Enumerator abstract class. | ||
-- But to prevent double inheritance, we cheat a bit and implement it outselves. | |||
---@return boolean | |||
function Iterator:moveNext() | |||
error('Abstract function must be overridden in derived class.') | |||
end | |||
---Returns this Enumerator or creates a new one if it has been used. | |||
---@param isArray? boolean | |||
---@return Enumerator | |||
function Iterator:getEnumerator(isArray) | |||
-- The default _state is -1 which signifies a Iterator isn't used. | |||
local instance = (self._state == -1) and self or self:clone() | |||
instance._isArray = isArray or self._isArray | |||
instance._state = 0 | |||
return instance | |||
end | |||
---@return Iterator | |||
function Iterator:clone() | |||
error('Abstract function must be overridden in derived class.') | |||
end | |||
function Iterator:finalize() | |||
-- Signals invalid _state. | |||
self._state = -4 | |||
end | |||
function Iterator.createIteratorSubclass() | |||
local derivedClass = setmetatable({}, { __index = Iterator }) | |||
derivedClass.__index = derivedClass | |||
derivedClass.__pairs = Enumerable.getPairs | |||
derivedClass.__ipairs = Enumerable.getiPairs | |||
return derivedClass | |||
end | |||
--Projects each element of a sequence into a new form. | |||
---@class MapIterator : Iterator | |||
---@field _source Enumerable | |||
---@field _selector fun(value: any, index: any): any | |||
---@field _position integer | |||
---@field _enumerator Enumerator | |||
local MapIterator = Iterator.createIteratorSubclass() | |||
---@param source Enumerable | |||
---@param selector fun(value: any, index: any): any | |||
function MapIterator.new(source, selector) | |||
assert(source, 'Source cannot be nil') | |||
assert(selector, 'Selector cannot be nil') | |||
local self = setmetatable(Iterator.new(), MapIterator) | |||
self._source = source | |||
self._selector = selector | |||
self._enumerator = nil | |||
self._position = 0 | |||
return self | |||
end | |||
function MapIterator:moveNext() | |||
local state = self._state | |||
if state ~= 0 then | |||
if state ~= 1 then | |||
return false | |||
end | |||
else | |||
self._state = 1 state = 1 | |||
self._position = 0 | |||
self._enumerator = self._source:getEnumerator(self._isArray) | |||
end | |||
local _enumerator = self._enumerator | |||
if _enumerator:moveNext() == true then | |||
-- 1 based index because Lua | |||
local pos = self._position + 1 | |||
local index = _enumerator.index | |||
local current = self._selector(_enumerator.current, pos) | |||
assert(current, 'Selected value must be non-nil') | |||
self.index = index | |||
self.current = current | |||
self._position = pos | |||
return true | |||
end | |||
self:finalize() | |||
return false | |||
end | end | ||
function | function MapIterator:finalize() | ||
if self._enumerator then | |||
self._enumerator:finalize() | |||
end | |||
Iterator.finalize(self) | |||
end | end | ||
function MapIterator:clone() | |||
return MapIterator.new(self._source, self._selector) | |||
end | end | ||
-- | --Filters a sequence of values based on a predicate. | ||
-- | ---@class WhereIterator : Iterator | ||
-- | ---@field _source Enumerable | ||
---@field _predicate fun(value: any, index: any): boolean | |||
---@field _enumerator Enumerator | |||
local WhereIterator = Iterator.createIteratorSubclass() | |||
function | ---@param source Enumerable | ||
local self = setmetatable( | ---@param predicate fun(value: any, index: any): boolean | ||
self. | function WhereIterator.new(source, predicate) | ||
assert(source, 'Source cannot be nil') | |||
assert(predicate, 'Predicate cannot be nil') | |||
local self = setmetatable(Iterator.new(), WhereIterator) | |||
self._source = source | |||
self._predicate = predicate | |||
self._enumerator = nil | |||
return self | return self | ||
end | end | ||
function | function WhereIterator:moveNext() | ||
local state = self._state | |||
self.state = 1 | if state ~= 0 then | ||
if state ~= 1 then | |||
return false | |||
end | |||
else | |||
self._state = 1 state = 1 | |||
self._enumerator = self._source:getEnumerator(self._isArray) | |||
end | |||
-- Cache some class lookups since these are likely | |||
-- to be accessed more than once per moveNext | |||
local enum = self._enumerator | |||
local nextFunc = enum.moveNext | |||
local predicate = self._predicate | |||
while nextFunc(enum) == true do | |||
local sourceElement = enum.current | |||
local sourceIndex = enum.index | |||
if predicate(sourceElement, sourceIndex) == true then | |||
self.index = sourceIndex | |||
self.current = sourceElement | |||
return true | |||
end | |||
end | |||
self:finalize() | |||
return false | return false | ||
end | end | ||
function | function WhereIterator:finalize() | ||
if self.state == 0 then | if self._enumerator then | ||
return self | self._enumerator:finalize() | ||
end | |||
Iterator.finalize(self) | |||
end | |||
function WhereIterator:clone() | |||
return WhereIterator.new(self._source, self._predicate) | |||
end | |||
--Projects each element of a sequence to an Enumerable and flattens the resulting sequences into one sequence. | |||
---@class FlatMapIterator : Iterator | |||
---@field _source Enumerable | |||
---@field _selector fun(value: any, index: integer): any | |||
---@field _position integer | |||
---@field _enumerator Enumerator | |||
---@field _sourceEnumerator Enumerator | |||
local FlatMapIterator = Iterator.createIteratorSubclass() | |||
---@param source Enumerable | |||
---@param selector fun(value: any, index: integer): any | |||
function FlatMapIterator.new(source, selector) | |||
assert(source, 'Source cannot be nil') | |||
assert(selector, 'Selector cannot be nil') | |||
local self = setmetatable(Iterator.new(), FlatMapIterator) | |||
self._source = source | |||
self._selector = selector | |||
self._position = 0 | |||
self._enumerator = nil -- Iterator of the _source Enumerable | |||
self._sourceEnumerator = nil -- Iterator of the nested Enumerable | |||
return self | |||
end | |||
function FlatMapIterator:moveNext() | |||
local state = self._state | |||
if state == -4 then | |||
return false | |||
end | |||
-- Setup _state | |||
if state == 0 then | |||
self._position = 0 | |||
self._enumerator = self._source:getEnumerator(self._isArray) | |||
self._state = 3 state = 3 -- signal to get (_first) nested _enumerator | |||
end | |||
while true do | |||
-- Grab next value from nested _enumerator | |||
if state == 4 then | |||
local sourceEnum = self._sourceEnumerator | |||
if sourceEnum:moveNext() then | |||
self.current = sourceEnum.current | |||
self.index = sourceEnum.index | |||
self._state, state = 4, 4 -- signal to get next item | |||
return true | |||
else | |||
-- Cleanup nested _enumerator | |||
self._sourceEnumerator:finalize() | |||
self._state = 3 state = 3 -- signal to get next _enumerator | |||
end | |||
end | |||
-- Grab nest nested _enumerator | |||
if state == 3 then | |||
if self._enumerator:moveNext() then | |||
local current = self._enumerator.current | |||
local pos = self._position + 1 | |||
local sourceTable = self._selector(current, pos) | |||
-- Nested tables are never treated as arrays. | |||
self._sourceEnumerator = getTableEnumerator(sourceTable, false) | |||
self._position = pos | |||
self._state = 4 state = 4 -- signal to get next item | |||
else | |||
-- _enumerator doesn't have any more nested enumerators. | |||
self:finalize() | |||
return false | |||
end | |||
end | |||
end | |||
end | |||
function FlatMapIterator:finalize() | |||
if self._enumerator then | |||
self._enumerator:finalize() | |||
end | |||
if self._sourceEnumerator then | |||
self._sourceEnumerator:finalize() | |||
end | |||
Iterator.finalize(self) | |||
end | |||
function FlatMapIterator:clone() | |||
return FlatMapIterator.new(self._source, self._selector) | |||
end | |||
--Concatenates two sequences. | |||
---@class ConcatIterator : Iterator | |||
---@field _first Enumerable | |||
---@field _second any | |||
---@field _enumerator Enumerator | |||
local ConcatIterator = Iterator.createIteratorSubclass() | |||
---@param first Enumerable | |||
---@param second any | |||
function ConcatIterator.new(first, second) | |||
assert(first, 'First cannot be nil') | |||
assert(second, 'Second cannot be nil') | |||
local self = setmetatable(Iterator.new(), ConcatIterator) | |||
self._first = first | |||
self._second = second | |||
self._enumerator = nil | |||
return self | |||
end | |||
-- Function to grab enumerators in order. We currently only have two so this | |||
-- seems a bit redundant. But it would allow to add N-amount of enumerables. | |||
function ConcatIterator:getEnumerable(index) | |||
if index == 0 then | |||
return self._first | |||
elseif index == 1 then | |||
return self._second | |||
else | else | ||
return | return nil | ||
end | end | ||
end | end | ||
function ConcatIterator:moveNext() | |||
local | local state = self._state | ||
if state == -4 then | |||
return false | |||
end | |||
if state == 0 then | |||
self._enumerator = self:getEnumerable(state) | |||
:getEnumerator(self._isArray) | |||
self._state = 1 state = 1 | |||
end | |||
if state > 0 then | |||
while true do | |||
end | local enum = self._enumerator | ||
if enum:moveNext() == true then | |||
self.index = enum.index | |||
self.current = enum.current | |||
return true | |||
end | |||
state = state + 1 | |||
self._state = state | |||
local next = self:getEnumerable(state - 1) | |||
if next ~= nil then | |||
-- Cleanup previous _enumerator | |||
self._enumerator:finalize() | |||
self._enumerator = getTableEnumerator(next, self._isArray) | |||
else | |||
self:finalize() | |||
return false | |||
end | |||
self | |||
end | end | ||
end | end | ||
return false | return false | ||
end | end | ||
function | function ConcatIterator:finalize() | ||
if self.state == 0 then | if self._enumerator then | ||
self._enumerator:finalize() | |||
else | end | ||
end | Iterator.finalize(self) | ||
end | |||
function ConcatIterator:clone() | |||
return ConcatIterator.new(self._first, self._second) | |||
end | |||
--Appends a value to the end or start of the sequence. | |||
---@class AppendIterator : Iterator | |||
---@field _source Enumerable | |||
---@field _item any | |||
---@field _itemIndex any | |||
---@field _append boolean | |||
---@field _enumerator Enumerator | |||
local AppendIterator = Iterator.createIteratorSubclass() | |||
---@param source Enumerable | |||
---@param item any | |||
---@param itemIndex? any | |||
---@param append? boolean | |||
function AppendIterator.new(source, item, itemIndex, append) | |||
assert(source, 'Source cannot be nil') | |||
assert(item, 'Item cannot be nil') | |||
local self = setmetatable(Iterator.new(), AppendIterator) | |||
self._source = source | |||
self._item = item | |||
-- Index needs *some* value, otherwise iteration stops. | |||
if itemIndex == nil then self._itemIndex = item else self._itemIndex = itemIndex end | |||
if append == nil then self._append = true else self._append = append end | |||
self._enumerator = nil | |||
return self | |||
end | |||
function AppendIterator:moveNext() | |||
local state = self._state | |||
if state == 0 then | |||
self._state = 1 state = 1 | |||
-- Put the item in front, as a prepend. | |||
if self._append == false then | |||
self.current = self._item | |||
self.index = self._itemIndex | |||
return true | |||
end | |||
end | |||
-- Grab the _source _enumerator | |||
if state == 1 then | |||
self._enumerator = self._source:getEnumerator(self._isArray) | |||
self._state = 2 state = 2 | |||
end | |||
if state == 2 then | |||
local enum = self._enumerator | |||
if enum:moveNext() then | |||
self.current = enum.current | |||
self.index = enum.index | |||
return true | |||
else | |||
self:finalize() | |||
end | |||
if self._append == true then | |||
self.current = self._item | |||
self.index = self._itemIndex | |||
return true | |||
end | |||
end | |||
return false | |||
end | |||
function AppendIterator:finalize() | |||
if self._enumerator then | |||
self._enumerator:finalize() | |||
end | |||
Iterator.finalize(self) | |||
end | |||
function AppendIterator:clone() | |||
return AppendIterator.new(self._source, self._item, self._itemIndex, self._append) | |||
end | |||
--Returns unique (distinct) elements from a sequence according to a specified key selector function. | |||
---@class UniqueIterator : Iterator | |||
---@field _source Enumerable | |||
---@field _keySelector fun(value: any, index: any): any | |||
---@field _enumerator Enumerator | |||
---@field _set table | |||
local UniqueIterator = Iterator.createIteratorSubclass() | |||
---@param source Enumerable | |||
---@param keySelector? fun(value: any, index: any): any | |||
function UniqueIterator.new(source, keySelector) | |||
assert(source, 'Source cannot be nil') | |||
local self = setmetatable(Iterator.new(), UniqueIterator) | |||
self._source = source | |||
self._keySelector = keySelector or function(x) return x end | |||
self._enumerator = nil | |||
self._set = nil | |||
return self | |||
end | |||
function UniqueIterator:moveNext() | |||
local state = self._state | |||
if state == -4 then | |||
return false | |||
end | |||
if state == 0 then | |||
self._enumerator = self._source:getEnumerator(self._isArray) | |||
-- If we have any items, create a hashtable. Otherwise abort. | |||
if self._enumerator:moveNext() == true then | |||
self._set = {} | |||
self._state = 1 state = 1 | |||
else | |||
self:finalize() | |||
return false | |||
end | |||
end | |||
local enum = self._enumerator | |||
local set = self._set | |||
while true do | |||
if state == 1 then | |||
self._state = 2 state = 2 | |||
local current = enum.current | |||
local index = enum.index | |||
local key = self._keySelector(current, index) | |||
if setAdd(key) == true then | |||
self.current = current | |||
self.index = index | |||
return true | |||
end | |||
end | |||
-- Try to grab a new item. | |||
if state == 2 then | |||
if enum:moveNext() == true then | |||
self._state = 1 state = 1 | |||
else | |||
self:finalize() | |||
return false | |||
end | |||
end | |||
end | |||
end | |||
function UniqueIterator:finalize() | |||
if self._enumerator then | |||
self._enumerator:finalize() | |||
end | |||
Iterator.finalize(self) | |||
end | |||
function UniqueIterator:clone() | |||
return UniqueIterator.new(self._source, self._keySelector) | |||
end | |||
--Produces the set difference of two sequences according to a specified key selector function. | |||
---@class DifferenceIterator : Iterator | |||
---@field _first Enumerable | |||
---@field _second any | |||
---@field _keySelector fun(value: any, index: any): any | |||
---@field _enumerator Enumerator | |||
---@field _set table | |||
local DifferenceIterator = Iterator.createIteratorSubclass() | |||
---@param first Enumerable | |||
---@param second any | |||
---@param keySelector? fun(value: any, index: any): any | |||
function DifferenceIterator.new(first, second, keySelector) | |||
assert(first, 'First table cannot be nil') | |||
assert(second, 'Second table cannot be nil') | |||
local self = setmetatable(Iterator.new(), DifferenceIterator) | |||
self._first = first | |||
self._second = second | |||
self._keySelector = keySelector or function(x) return x end | |||
self._enumerator = nil | |||
self._set = nil | |||
return self | |||
end | |||
function DifferenceIterator:moveNext() | |||
local state = self._state | |||
if state ~= 0 then | |||
if state ~= 1 then | |||
return false | |||
end | |||
else | |||
self._enumerator = self._first:getEnumerator(self._isArray) | |||
self._set = sourceToSet(self._second) | |||
self._state = 1 state = 1 | |||
end | |||
local enum = self._enumerator | |||
local nextFunc = enum.moveNext | |||
local keySelector = self._keySelector | |||
local set = self._set | |||
while nextFunc(enum) == true do | |||
local current = enum.current | |||
local index = enum.index | |||
local key = keySelector(current, index) | |||
if setAdd(key) == true then | |||
self.current = current | |||
self.index = index | |||
return true | |||
end | |||
end | |||
self:finalize() | |||
return false | |||
end | |||
function DifferenceIterator:finalize() | |||
if self._enumerator then | |||
self._enumerator:finalize() | |||
end | |||
end | |||
function DifferenceIterator:clone() | |||
return DifferenceIterator.new(self._first, self._second, self._keySelector) | |||
end | |||
--Produces the set union of two sequences according to a specified key selector function. | |||
---@class UnionIterator : Iterator | |||
---@field _first Enumerable | |||
---@field _second any | |||
---@field _keySelector fun(value: any, index: any): any | |||
---@field _enumerator Enumerator | |||
---@field _set table | |||
local UnionIterator = Iterator.createIteratorSubclass() | |||
---@param first Enumerable | |||
---@param second any | |||
---@param keySelector? fun(value: any, index: any): any | |||
function UnionIterator.new(first, second, keySelector) | |||
assert(first, 'First table cannot be nil') | |||
assert(second, 'Second table cannot be nil') | |||
local self = setmetatable(Iterator.new(), UnionIterator) | |||
self._first = first | |||
self._second = second | |||
self._keySelector = keySelector or function(x) return x end | |||
self._enumerator = nil | |||
self._set = nil | |||
return self | |||
end | |||
local function enumerateSource(self, _state) | |||
while self._enumerator:moveNext() do | |||
local current = self._enumerator.current | |||
local index = self._enumerator.index | |||
if self._set:add(self._keySelector(current, index)) == true then | |||
self.current = self._enumerator.current | |||
self.index = self._enumerator.index | |||
self._state = _state | |||
return true | |||
end | |||
end | |||
end | |||
function UnionIterator:moveNext() | |||
if self._state == -4 then | |||
return false | |||
end | |||
if self._state == 0 then | |||
self._enumerator = self._first:getEnumerator(self._isArray) | |||
self._set = {} | |||
self._state = 1 | |||
end | |||
-- Process _first | |||
if self._state == 1 then | |||
if enumerateSource(self, 1) == true then | |||
return true | |||
end | |||
-- End _first enumeration | |||
self._state = 2 | |||
self._enumerator:finalize() | |||
self._enumerator = getTableEnumerator(self._second, self._isArray) | |||
end | |||
-- Process _second | |||
if self._state == 2 then | |||
if enumerateSource(self, 2) == true then | |||
return true | |||
end | |||
-- End _first enumeration | |||
self._state = 3 | |||
end | |||
self:finalize() | |||
return false | |||
end | |||
function UnionIterator:finalize() | |||
if self._enumerator then | |||
self._enumerator:finalize() | |||
end | |||
Iterator.finalize(self) | |||
end | |||
function UnionIterator:clone() | |||
return UnionIterator.new(self._first, self._second, self._keySelector) | |||
end | |||
--Produces the set intersection of two sequences according to a specified key selector function. | |||
---@class IntersectIterator : Iterator | |||
---@field _first Enumerable | |||
---@field _second any | |||
---@field _keySelector fun(value: any, index: any): any | |||
---@field _enumerator Enumerator | |||
---@field _set table | |||
local IntersectIterator = Iterator.createIteratorSubclass() | |||
---@param first Enumerable | |||
---@param second any | |||
---@param keySelector? fun(value: any, index: any): any | |||
function IntersectIterator.new(first, second, keySelector) | |||
assert(first, 'First table cannot be nil') | |||
assert(second, 'Second table cannot be nil') | |||
local self = setmetatable(Iterator.new(), IntersectIterator) | |||
self._first = first | |||
self._second = second | |||
self._keySelector = keySelector or function(x) return x end | |||
self._enumerator = nil | |||
self._set = nil | |||
return self | |||
end | |||
function IntersectIterator:moveNext() | |||
if self._state ~= 0 then | |||
if self._state ~= 1 then | |||
return false | |||
end | |||
else | |||
self._enumerator = self._first:getEnumerator(self._isArray) | |||
self._set = sourceToSet(self._second) | |||
self._state = 1 | |||
end | |||
while self._enumerator:moveNext() do | |||
local current = self._enumerator.current | |||
local index = self._enumerator.index | |||
local key = self._keySelector(current, index) | |||
if setRemove(self._set, key) == true then | |||
self.current = self._enumerator.current | |||
self.index = self._enumerator.index | |||
self._state = 1 | |||
return true | |||
end | |||
end | |||
self:finalize() | |||
return false | |||
end | |||
function IntersectIterator:finalize() | |||
if self._enumerator then | |||
self._enumerator:finalize() | |||
end | |||
Iterator.finalize(self) | |||
end | |||
function IntersectIterator:clone() | |||
return IntersectIterator.new(self._first, self._second, self._keySelector) | |||
end | |||
--Applies a specified function to the corresponding elements of two sequences, producing a sequence of the results. | |||
---@class ZipIterator : Iterator | |||
---@field _first Enumerable | |||
---@field _second any | |||
---@field _resultSelector fun(left: any, right: any): any | |||
---@field _enumerator1 Enumerator | |||
---@field _enumerator2 Enumerator | |||
local ZipIterator = Iterator.createIteratorSubclass() | |||
---@param first Enumerable | |||
---@param second any | |||
---@param resultSelector? fun(left: any, right: any): any | |||
function ZipIterator.new(first, second, resultSelector) | |||
assert(first, 'First table cannot be nil') | |||
assert(second, 'Second table cannot be nil') | |||
local self = setmetatable(Iterator.new(), ZipIterator) | |||
self._first = first | |||
self._second = second | |||
self._resultSelector = resultSelector or (function(a, b) return {a, b} end) | |||
self._enumerator1 = nil | |||
self._enumerator2 = nil | |||
return self | |||
end | |||
function ZipIterator:moveNext() | |||
if self._state ~= 0 then | |||
if self._state ~= 1 then | |||
return false | |||
end | |||
else | |||
self._enumerator1 = self._first:getEnumerator(self._isArray) | |||
self._enumerator2 = getTableEnumerator(self._second, self._isArray) | |||
self._state = -4 | |||
end | |||
if self._enumerator1:moveNext() == true and self._enumerator2:moveNext() == true then | |||
self.current = self._resultSelector(self._enumerator1.current, self._enumerator2.current) | |||
self.index = self._enumerator1.index | |||
self._state = 1 | |||
return true | |||
end | |||
self:finalize() | |||
return false | |||
end | |||
function ZipIterator:finalize() | |||
if self._enumerator1 then | |||
self._enumerator1:finalize() | |||
end | |||
if self._enumerator2 then | |||
self._enumerator2:finalize() | |||
end | |||
Iterator.finalize(self) | |||
end | |||
function ZipIterator:clone() | |||
return ZipIterator.new(self._first, self._second, self._resultSelector) | |||
end | |||
--Groups the elements of a sequence. | |||
---@class GroupByIterator : Iterator | |||
---@field _source Enumerable | |||
---@field _keySelector fun(param: any): any | |||
---@field _elementSelector fun(param: any): any | |||
---@field _enumerator Enumerator | |||
---@field _lookup Lookup | |||
local GroupByIterator = Iterator.createIteratorSubclass() | |||
---@param source Enumerable | |||
---@param keySelector fun(param: any): any | |||
---@param elementSelector? fun(param: any): any | |||
function GroupByIterator.new(source, keySelector, elementSelector) | |||
assert(source, 'source cannot be nil') | |||
assert(keySelector, 'keySelector cannot be nil') | |||
local self = setmetatable(Iterator.new(), GroupByIterator) | |||
self._source = source | |||
self._keySelector = keySelector | |||
self._elementSelector = elementSelector or function(x) return x end | |||
self._enumerator = nil | |||
self._lookup = nil | |||
return self | |||
end | |||
function GroupByIterator:moveNext() | |||
if self._state ~= 0 then | |||
if self._state ~= 1 then | |||
return false | |||
end | |||
else | |||
self._lookup = Lookup.new(self._source, self._keySelector, self._elementSelector, self._isArray) | |||
self._enumerator = self._lookup:getEnumerator() | |||
self._state = 1 | |||
end | |||
local enum = self._enumerator | |||
if enum:moveNext() == true then | |||
self.current = enum.current | |||
self.index = enum.index | |||
self._state = 1 | |||
return true | |||
end | |||
self:finalize() | |||
return false | |||
end | |||
function GroupByIterator:finalize() | |||
if self._enumerator then | |||
self._enumerator:finalize() | |||
end | |||
Iterator.finalize(self) | |||
end | |||
function GroupByIterator:clone() | |||
return GroupByIterator.new(self._source, self._keySelector, self._elementSelector) | |||
end | |||
--Groups the elements of a sequence according to a specified key selector function and creates a result value from each group and its key. The elements of each group are projected by using a specified function. | |||
---@class GroupByResultIterator : Iterator | |||
---@field _source Enumerable | |||
---@field _keySelector fun(param: any): any | |||
---@field _elementSelector fun(param: any): any | |||
---@field _resultSelector fun(key: any, grouping: Grouping): any | |||
---@field _enumerator Enumerator | |||
---@field _lookup Lookup | |||
local GroupByResultIterator = Iterator.createIteratorSubclass() | |||
---@param source Enumerable | |||
---@param keySelector fun(param: any): any | |||
---@param elementSelector? fun(param: any): any | |||
---@param resultSelector fun(key: any, grouping: Grouping): any | |||
function GroupByResultIterator.new(source, keySelector, elementSelector, resultSelector) | |||
assert(source, 'source cannot be nil') | |||
assert(keySelector, 'keySelector cannot be nil') | |||
assert(resultSelector, 'keySelector cannot be nil') | |||
local self = setmetatable(Iterator.new(), GroupByResultIterator) | |||
self._source = source | |||
self._keySelector = keySelector | |||
self._elementSelector = elementSelector or function(x) return x end | |||
self._resultSelector = resultSelector | |||
self._enumerator = nil | |||
self._lookup = nil | |||
return self | |||
end | |||
function GroupByResultIterator:moveNext() | |||
if self._state ~= 0 then | |||
if self._state ~= 1 then | |||
return false | |||
end | |||
else | |||
self._lookup = Lookup.new(self._source, self._keySelector, self._elementSelector, self._isArray) | |||
self._enumerator = self._lookup:getEnumerator() | |||
self._state = 1 | |||
end | |||
local enumerator = self._enumerator | |||
if enumerator:moveNext() == true then | |||
-- Transform Grouping element for enumeration | |||
local gKey = enumerator.current.key | |||
local gElements = enumerator.current | |||
self.current = self._resultSelector(gKey, gElements) | |||
self.index = enumerator.index | |||
self._state = 1 | |||
return true | |||
end | |||
self:finalize() | |||
return false | |||
end | |||
function GroupByResultIterator:finalize() | |||
if self._enumerator then | |||
self._enumerator:finalize() | |||
end | |||
Iterator.finalize(self) | |||
end | |||
function GroupByResultIterator:clone() | |||
return GroupByResultIterator.new(self._source, self._keySelector, self._elementSelector, self._resultSelector) | |||
end | |||
--Groups the elements of a sequence according to a specified key selector function and creates a result value from each group and its key. The elements of each group are projected by using a specified function. | |||
---@class SortableIterator : Iterator | |||
---@field _source Enumerable | |||
---@field _keySelector fun(value: any, index: any): any|? | |||
---@field _descendingSort boolean | |||
---@field _parentIterator SortableIterator | |||
---@field _enumerator Enumerator | |||
local SortableIterator = Iterator.createIteratorSubclass() | |||
---@param source Enumerable | |||
---@param keySelector fun(value: any, index: any): any | |||
---@param descendingSort? boolean | |||
---@param parent? SortableIterator | |||
---@return SortableIterator | |||
function SortableIterator.new(source, keySelector, descendingSort, parent) | |||
assert(source, 'source cannot be nil') | |||
assert(keySelector, 'keySelector cannot be nil') | |||
assert(parent == nil or parent.getEnumerableSorters ~= nil, 'parent must be of type SortableIterator') | |||
if descendingSort == nil then | |||
descendingSort = false | |||
end | |||
local self = setmetatable(Iterator.new(), SortableIterator) | |||
self._source = source | |||
self._keySelector = keySelector | |||
self._descendingSort = descendingSort | |||
self._parentIterator = parent | |||
self._enumerator = nil | |||
---@diagnostic disable-next-line: return-type-mismatch | |||
return self | |||
end | |||
local function createSortFunc(sorters) | |||
local sortCount = #sorters | |||
if sortCount > 1 then | |||
return function(a, b) | |||
for _, sorter in ipairs(sorters) do | |||
local selector = sorter._keySelector | |||
local descendingSort = sorter._descendingSort | |||
local aVal = selector(a) | |||
local bVal = selector(b) | |||
if aVal ~= bVal then | |||
if descendingSort == true then | |||
return aVal > bVal | |||
else | |||
return aVal < bVal | |||
end | |||
end | |||
end | |||
end | |||
end | |||
-- Faster sort function if there's no nested sorts. | |||
local sorter = sorters[1] | |||
local selector = sorter._keySelector | |||
if sorter._descendingSort == true then | |||
return function(a, b) return selector(a) > selector(b) end | |||
else | |||
return function(a, b) return selector(a) < selector(b) end | |||
end | |||
end | |||
function SortableIterator:moveNext() | |||
if self._state ~= 0 then | |||
if self._state ~= 1 then | |||
return false | |||
end | |||
else | |||
-- Collect sorter meta information. | |||
local sorters = {} | |||
self:getEnumerableSorters(sorters) | |||
local sortFunc = createSortFunc(sorters) | |||
local enumerator = self._source:getEnumerator(self._isArray) | |||
local buffer = {} | |||
-- Process and sort all previous enumerators. | |||
while enumerator:moveNext() == true do | |||
table.insert(buffer, enumerator.current) | |||
end | |||
table.sort(buffer, sortFunc) | |||
self._enumerator = TableEnumerator.createForArray(buffer) | |||
self._state = 1 | |||
end | |||
local enumerator = self._enumerator | |||
if enumerator:moveNext() == true then | |||
self.current = enumerator.current | |||
self.index = enumerator.index | |||
self._state = 1 | |||
return true | |||
end | |||
self:finalize() | |||
return false | |||
end | |||
---Collects all sorter functions down the line. | |||
---@param sorters table | |||
---@return table | |||
function SortableIterator:getEnumerableSorters(sorters) | |||
-- Get parent sorter first. | |||
if self._parentIterator ~= nil then | |||
self._parentIterator:getEnumerableSorters(sorters) | |||
end | |||
table.insert(sorters, self) | |||
return sorters | |||
end | |||
---@param keySelector fun(value: any, index: any): any | |||
---@param descendingSort? boolean | |||
---@return SortableIterator | |||
function SortableIterator:createSortableIterator(keySelector, descendingSort) | |||
return SortableIterator.new(self._source, keySelector, descendingSort, self) | |||
end | |||
function SortableIterator:finalize() | |||
if self._enumerator then | |||
self._enumerator:finalize() | |||
self._enumerator = nil | |||
end | |||
Iterator.finalize(self) | |||
end | |||
function SortableIterator:clone() | |||
return SortableIterator.new(self._source, self._keySelector, self._descendingSort, self._parentIterator) | |||
end | end | ||
return { | return { | ||
MapIterator = MapIterator, | |||
WhereIterator = WhereIterator, | |||
FlatMapIterator = FlatMapIterator, | |||
ConcatIterator = ConcatIterator, | |||
AppendIterator = AppendIterator, | |||
UniqueIterator = UniqueIterator, | |||
DifferenceIterator = DifferenceIterator, | |||
UnionIterator = UnionIterator, | |||
IntersectIterator = IntersectIterator, | |||
ZipIterator = ZipIterator, | |||
GroupByIterator = GroupByIterator, | |||
GroupByResultIterator = GroupByResultIterator, | |||
SortableIterator = SortableIterator, | |||
} | } |
Latest revision as of 20:46, 26 July 2024
Documentation for this module may be created at Module:FunList/Iterators/doc
local Enumerable = require('Module:FunList/Enumerable')
local Lookup = require('Module:FunList/Lookup')
local TableEnumerator = require('Module:FunList/TableEnumerator')
---Creates a TableIterator Enumerable or returns the Enumerable if it already is one.
---We set the propper starting index here since this only applies for tables!
---@param source any
---@param isArray any
---@return Enumerator
local function getTableEnumerator(source, isArray)
if source.getEnumerator == nil then
assert(type(source) == 'table', 'sourceTable must be either an Enumerable or table.')
return TableEnumerator.create(source, isArray)
else
return source:getEnumerator(isArray)
end
end
-- Helper functions to create and manage a hashset structure.
-- These are used to distinguish elements for Union/Unique/Difference etc operations.
---@param source any
---@return table
local function sourceToSet(source)
local set = {}
if Enumerable.isEnumerable(source) then
local enum = source:getEnumerator()
while enum:moveNext() do
set[enum.current] = true
end
else
assert(type(source) == 'table', 'Source must be a table.')
for _, v in pairs(source) do
set[v] = true
end
end
return set
end
---@param set table
---@param item any
---@return boolean
local function setAdd(set, item)
if set[item] == nil then
set[item] = true
return true
end
return false
end
---@param set table
---@param item any
---@return boolean
local function setRemove(set, item)
if set[item] == nil then
return false
end
set[item] = nil
return true
end
-- Defines the base class for all Iterator state machines. This is an Enumerable and Enumerator in one.
---@class Iterator : Enumerable, Enumerator
---@field current any
---@field index any
---@field _state integer
---@field _isArray boolean
local Iterator = setmetatable({}, { __index = Enumerable })
local Iterator_mt = {
__index = Iterator,
__pairs = Enumerable.getPairs,
__ipairs = Enumerable.getiPairs
}
---@return Iterator
function Iterator.new()
local self = setmetatable({}, Iterator_mt)
self.current = nil
self.index = nil
self._state = -1
-- Assume by default we are not dealing with a simple array
self._isArray = false
return self
end
-- This is normally implemented by the Enumerator abstract class.
-- But to prevent double inheritance, we cheat a bit and implement it outselves.
---@return boolean
function Iterator:moveNext()
error('Abstract function must be overridden in derived class.')
end
---Returns this Enumerator or creates a new one if it has been used.
---@param isArray? boolean
---@return Enumerator
function Iterator:getEnumerator(isArray)
-- The default _state is -1 which signifies a Iterator isn't used.
local instance = (self._state == -1) and self or self:clone()
instance._isArray = isArray or self._isArray
instance._state = 0
return instance
end
---@return Iterator
function Iterator:clone()
error('Abstract function must be overridden in derived class.')
end
function Iterator:finalize()
-- Signals invalid _state.
self._state = -4
end
function Iterator.createIteratorSubclass()
local derivedClass = setmetatable({}, { __index = Iterator })
derivedClass.__index = derivedClass
derivedClass.__pairs = Enumerable.getPairs
derivedClass.__ipairs = Enumerable.getiPairs
return derivedClass
end
--Projects each element of a sequence into a new form.
---@class MapIterator : Iterator
---@field _source Enumerable
---@field _selector fun(value: any, index: any): any
---@field _position integer
---@field _enumerator Enumerator
local MapIterator = Iterator.createIteratorSubclass()
---@param source Enumerable
---@param selector fun(value: any, index: any): any
function MapIterator.new(source, selector)
assert(source, 'Source cannot be nil')
assert(selector, 'Selector cannot be nil')
local self = setmetatable(Iterator.new(), MapIterator)
self._source = source
self._selector = selector
self._enumerator = nil
self._position = 0
return self
end
function MapIterator:moveNext()
local state = self._state
if state ~= 0 then
if state ~= 1 then
return false
end
else
self._state = 1 state = 1
self._position = 0
self._enumerator = self._source:getEnumerator(self._isArray)
end
local _enumerator = self._enumerator
if _enumerator:moveNext() == true then
-- 1 based index because Lua
local pos = self._position + 1
local index = _enumerator.index
local current = self._selector(_enumerator.current, pos)
assert(current, 'Selected value must be non-nil')
self.index = index
self.current = current
self._position = pos
return true
end
self:finalize()
return false
end
function MapIterator:finalize()
if self._enumerator then
self._enumerator:finalize()
end
Iterator.finalize(self)
end
function MapIterator:clone()
return MapIterator.new(self._source, self._selector)
end
--Filters a sequence of values based on a predicate.
---@class WhereIterator : Iterator
---@field _source Enumerable
---@field _predicate fun(value: any, index: any): boolean
---@field _enumerator Enumerator
local WhereIterator = Iterator.createIteratorSubclass()
---@param source Enumerable
---@param predicate fun(value: any, index: any): boolean
function WhereIterator.new(source, predicate)
assert(source, 'Source cannot be nil')
assert(predicate, 'Predicate cannot be nil')
local self = setmetatable(Iterator.new(), WhereIterator)
self._source = source
self._predicate = predicate
self._enumerator = nil
return self
end
function WhereIterator:moveNext()
local state = self._state
if state ~= 0 then
if state ~= 1 then
return false
end
else
self._state = 1 state = 1
self._enumerator = self._source:getEnumerator(self._isArray)
end
-- Cache some class lookups since these are likely
-- to be accessed more than once per moveNext
local enum = self._enumerator
local nextFunc = enum.moveNext
local predicate = self._predicate
while nextFunc(enum) == true do
local sourceElement = enum.current
local sourceIndex = enum.index
if predicate(sourceElement, sourceIndex) == true then
self.index = sourceIndex
self.current = sourceElement
return true
end
end
self:finalize()
return false
end
function WhereIterator:finalize()
if self._enumerator then
self._enumerator:finalize()
end
Iterator.finalize(self)
end
function WhereIterator:clone()
return WhereIterator.new(self._source, self._predicate)
end
--Projects each element of a sequence to an Enumerable and flattens the resulting sequences into one sequence.
---@class FlatMapIterator : Iterator
---@field _source Enumerable
---@field _selector fun(value: any, index: integer): any
---@field _position integer
---@field _enumerator Enumerator
---@field _sourceEnumerator Enumerator
local FlatMapIterator = Iterator.createIteratorSubclass()
---@param source Enumerable
---@param selector fun(value: any, index: integer): any
function FlatMapIterator.new(source, selector)
assert(source, 'Source cannot be nil')
assert(selector, 'Selector cannot be nil')
local self = setmetatable(Iterator.new(), FlatMapIterator)
self._source = source
self._selector = selector
self._position = 0
self._enumerator = nil -- Iterator of the _source Enumerable
self._sourceEnumerator = nil -- Iterator of the nested Enumerable
return self
end
function FlatMapIterator:moveNext()
local state = self._state
if state == -4 then
return false
end
-- Setup _state
if state == 0 then
self._position = 0
self._enumerator = self._source:getEnumerator(self._isArray)
self._state = 3 state = 3 -- signal to get (_first) nested _enumerator
end
while true do
-- Grab next value from nested _enumerator
if state == 4 then
local sourceEnum = self._sourceEnumerator
if sourceEnum:moveNext() then
self.current = sourceEnum.current
self.index = sourceEnum.index
self._state, state = 4, 4 -- signal to get next item
return true
else
-- Cleanup nested _enumerator
self._sourceEnumerator:finalize()
self._state = 3 state = 3 -- signal to get next _enumerator
end
end
-- Grab nest nested _enumerator
if state == 3 then
if self._enumerator:moveNext() then
local current = self._enumerator.current
local pos = self._position + 1
local sourceTable = self._selector(current, pos)
-- Nested tables are never treated as arrays.
self._sourceEnumerator = getTableEnumerator(sourceTable, false)
self._position = pos
self._state = 4 state = 4 -- signal to get next item
else
-- _enumerator doesn't have any more nested enumerators.
self:finalize()
return false
end
end
end
end
function FlatMapIterator:finalize()
if self._enumerator then
self._enumerator:finalize()
end
if self._sourceEnumerator then
self._sourceEnumerator:finalize()
end
Iterator.finalize(self)
end
function FlatMapIterator:clone()
return FlatMapIterator.new(self._source, self._selector)
end
--Concatenates two sequences.
---@class ConcatIterator : Iterator
---@field _first Enumerable
---@field _second any
---@field _enumerator Enumerator
local ConcatIterator = Iterator.createIteratorSubclass()
---@param first Enumerable
---@param second any
function ConcatIterator.new(first, second)
assert(first, 'First cannot be nil')
assert(second, 'Second cannot be nil')
local self = setmetatable(Iterator.new(), ConcatIterator)
self._first = first
self._second = second
self._enumerator = nil
return self
end
-- Function to grab enumerators in order. We currently only have two so this
-- seems a bit redundant. But it would allow to add N-amount of enumerables.
function ConcatIterator:getEnumerable(index)
if index == 0 then
return self._first
elseif index == 1 then
return self._second
else
return nil
end
end
function ConcatIterator:moveNext()
local state = self._state
if state == -4 then
return false
end
if state == 0 then
self._enumerator = self:getEnumerable(state)
:getEnumerator(self._isArray)
self._state = 1 state = 1
end
if state > 0 then
while true do
local enum = self._enumerator
if enum:moveNext() == true then
self.index = enum.index
self.current = enum.current
return true
end
state = state + 1
self._state = state
local next = self:getEnumerable(state - 1)
if next ~= nil then
-- Cleanup previous _enumerator
self._enumerator:finalize()
self._enumerator = getTableEnumerator(next, self._isArray)
else
self:finalize()
return false
end
end
end
return false
end
function ConcatIterator:finalize()
if self._enumerator then
self._enumerator:finalize()
end
Iterator.finalize(self)
end
function ConcatIterator:clone()
return ConcatIterator.new(self._first, self._second)
end
--Appends a value to the end or start of the sequence.
---@class AppendIterator : Iterator
---@field _source Enumerable
---@field _item any
---@field _itemIndex any
---@field _append boolean
---@field _enumerator Enumerator
local AppendIterator = Iterator.createIteratorSubclass()
---@param source Enumerable
---@param item any
---@param itemIndex? any
---@param append? boolean
function AppendIterator.new(source, item, itemIndex, append)
assert(source, 'Source cannot be nil')
assert(item, 'Item cannot be nil')
local self = setmetatable(Iterator.new(), AppendIterator)
self._source = source
self._item = item
-- Index needs *some* value, otherwise iteration stops.
if itemIndex == nil then self._itemIndex = item else self._itemIndex = itemIndex end
if append == nil then self._append = true else self._append = append end
self._enumerator = nil
return self
end
function AppendIterator:moveNext()
local state = self._state
if state == 0 then
self._state = 1 state = 1
-- Put the item in front, as a prepend.
if self._append == false then
self.current = self._item
self.index = self._itemIndex
return true
end
end
-- Grab the _source _enumerator
if state == 1 then
self._enumerator = self._source:getEnumerator(self._isArray)
self._state = 2 state = 2
end
if state == 2 then
local enum = self._enumerator
if enum:moveNext() then
self.current = enum.current
self.index = enum.index
return true
else
self:finalize()
end
if self._append == true then
self.current = self._item
self.index = self._itemIndex
return true
end
end
return false
end
function AppendIterator:finalize()
if self._enumerator then
self._enumerator:finalize()
end
Iterator.finalize(self)
end
function AppendIterator:clone()
return AppendIterator.new(self._source, self._item, self._itemIndex, self._append)
end
--Returns unique (distinct) elements from a sequence according to a specified key selector function.
---@class UniqueIterator : Iterator
---@field _source Enumerable
---@field _keySelector fun(value: any, index: any): any
---@field _enumerator Enumerator
---@field _set table
local UniqueIterator = Iterator.createIteratorSubclass()
---@param source Enumerable
---@param keySelector? fun(value: any, index: any): any
function UniqueIterator.new(source, keySelector)
assert(source, 'Source cannot be nil')
local self = setmetatable(Iterator.new(), UniqueIterator)
self._source = source
self._keySelector = keySelector or function(x) return x end
self._enumerator = nil
self._set = nil
return self
end
function UniqueIterator:moveNext()
local state = self._state
if state == -4 then
return false
end
if state == 0 then
self._enumerator = self._source:getEnumerator(self._isArray)
-- If we have any items, create a hashtable. Otherwise abort.
if self._enumerator:moveNext() == true then
self._set = {}
self._state = 1 state = 1
else
self:finalize()
return false
end
end
local enum = self._enumerator
local set = self._set
while true do
if state == 1 then
self._state = 2 state = 2
local current = enum.current
local index = enum.index
local key = self._keySelector(current, index)
if setAdd(key) == true then
self.current = current
self.index = index
return true
end
end
-- Try to grab a new item.
if state == 2 then
if enum:moveNext() == true then
self._state = 1 state = 1
else
self:finalize()
return false
end
end
end
end
function UniqueIterator:finalize()
if self._enumerator then
self._enumerator:finalize()
end
Iterator.finalize(self)
end
function UniqueIterator:clone()
return UniqueIterator.new(self._source, self._keySelector)
end
--Produces the set difference of two sequences according to a specified key selector function.
---@class DifferenceIterator : Iterator
---@field _first Enumerable
---@field _second any
---@field _keySelector fun(value: any, index: any): any
---@field _enumerator Enumerator
---@field _set table
local DifferenceIterator = Iterator.createIteratorSubclass()
---@param first Enumerable
---@param second any
---@param keySelector? fun(value: any, index: any): any
function DifferenceIterator.new(first, second, keySelector)
assert(first, 'First table cannot be nil')
assert(second, 'Second table cannot be nil')
local self = setmetatable(Iterator.new(), DifferenceIterator)
self._first = first
self._second = second
self._keySelector = keySelector or function(x) return x end
self._enumerator = nil
self._set = nil
return self
end
function DifferenceIterator:moveNext()
local state = self._state
if state ~= 0 then
if state ~= 1 then
return false
end
else
self._enumerator = self._first:getEnumerator(self._isArray)
self._set = sourceToSet(self._second)
self._state = 1 state = 1
end
local enum = self._enumerator
local nextFunc = enum.moveNext
local keySelector = self._keySelector
local set = self._set
while nextFunc(enum) == true do
local current = enum.current
local index = enum.index
local key = keySelector(current, index)
if setAdd(key) == true then
self.current = current
self.index = index
return true
end
end
self:finalize()
return false
end
function DifferenceIterator:finalize()
if self._enumerator then
self._enumerator:finalize()
end
end
function DifferenceIterator:clone()
return DifferenceIterator.new(self._first, self._second, self._keySelector)
end
--Produces the set union of two sequences according to a specified key selector function.
---@class UnionIterator : Iterator
---@field _first Enumerable
---@field _second any
---@field _keySelector fun(value: any, index: any): any
---@field _enumerator Enumerator
---@field _set table
local UnionIterator = Iterator.createIteratorSubclass()
---@param first Enumerable
---@param second any
---@param keySelector? fun(value: any, index: any): any
function UnionIterator.new(first, second, keySelector)
assert(first, 'First table cannot be nil')
assert(second, 'Second table cannot be nil')
local self = setmetatable(Iterator.new(), UnionIterator)
self._first = first
self._second = second
self._keySelector = keySelector or function(x) return x end
self._enumerator = nil
self._set = nil
return self
end
local function enumerateSource(self, _state)
while self._enumerator:moveNext() do
local current = self._enumerator.current
local index = self._enumerator.index
if self._set:add(self._keySelector(current, index)) == true then
self.current = self._enumerator.current
self.index = self._enumerator.index
self._state = _state
return true
end
end
end
function UnionIterator:moveNext()
if self._state == -4 then
return false
end
if self._state == 0 then
self._enumerator = self._first:getEnumerator(self._isArray)
self._set = {}
self._state = 1
end
-- Process _first
if self._state == 1 then
if enumerateSource(self, 1) == true then
return true
end
-- End _first enumeration
self._state = 2
self._enumerator:finalize()
self._enumerator = getTableEnumerator(self._second, self._isArray)
end
-- Process _second
if self._state == 2 then
if enumerateSource(self, 2) == true then
return true
end
-- End _first enumeration
self._state = 3
end
self:finalize()
return false
end
function UnionIterator:finalize()
if self._enumerator then
self._enumerator:finalize()
end
Iterator.finalize(self)
end
function UnionIterator:clone()
return UnionIterator.new(self._first, self._second, self._keySelector)
end
--Produces the set intersection of two sequences according to a specified key selector function.
---@class IntersectIterator : Iterator
---@field _first Enumerable
---@field _second any
---@field _keySelector fun(value: any, index: any): any
---@field _enumerator Enumerator
---@field _set table
local IntersectIterator = Iterator.createIteratorSubclass()
---@param first Enumerable
---@param second any
---@param keySelector? fun(value: any, index: any): any
function IntersectIterator.new(first, second, keySelector)
assert(first, 'First table cannot be nil')
assert(second, 'Second table cannot be nil')
local self = setmetatable(Iterator.new(), IntersectIterator)
self._first = first
self._second = second
self._keySelector = keySelector or function(x) return x end
self._enumerator = nil
self._set = nil
return self
end
function IntersectIterator:moveNext()
if self._state ~= 0 then
if self._state ~= 1 then
return false
end
else
self._enumerator = self._first:getEnumerator(self._isArray)
self._set = sourceToSet(self._second)
self._state = 1
end
while self._enumerator:moveNext() do
local current = self._enumerator.current
local index = self._enumerator.index
local key = self._keySelector(current, index)
if setRemove(self._set, key) == true then
self.current = self._enumerator.current
self.index = self._enumerator.index
self._state = 1
return true
end
end
self:finalize()
return false
end
function IntersectIterator:finalize()
if self._enumerator then
self._enumerator:finalize()
end
Iterator.finalize(self)
end
function IntersectIterator:clone()
return IntersectIterator.new(self._first, self._second, self._keySelector)
end
--Applies a specified function to the corresponding elements of two sequences, producing a sequence of the results.
---@class ZipIterator : Iterator
---@field _first Enumerable
---@field _second any
---@field _resultSelector fun(left: any, right: any): any
---@field _enumerator1 Enumerator
---@field _enumerator2 Enumerator
local ZipIterator = Iterator.createIteratorSubclass()
---@param first Enumerable
---@param second any
---@param resultSelector? fun(left: any, right: any): any
function ZipIterator.new(first, second, resultSelector)
assert(first, 'First table cannot be nil')
assert(second, 'Second table cannot be nil')
local self = setmetatable(Iterator.new(), ZipIterator)
self._first = first
self._second = second
self._resultSelector = resultSelector or (function(a, b) return {a, b} end)
self._enumerator1 = nil
self._enumerator2 = nil
return self
end
function ZipIterator:moveNext()
if self._state ~= 0 then
if self._state ~= 1 then
return false
end
else
self._enumerator1 = self._first:getEnumerator(self._isArray)
self._enumerator2 = getTableEnumerator(self._second, self._isArray)
self._state = -4
end
if self._enumerator1:moveNext() == true and self._enumerator2:moveNext() == true then
self.current = self._resultSelector(self._enumerator1.current, self._enumerator2.current)
self.index = self._enumerator1.index
self._state = 1
return true
end
self:finalize()
return false
end
function ZipIterator:finalize()
if self._enumerator1 then
self._enumerator1:finalize()
end
if self._enumerator2 then
self._enumerator2:finalize()
end
Iterator.finalize(self)
end
function ZipIterator:clone()
return ZipIterator.new(self._first, self._second, self._resultSelector)
end
--Groups the elements of a sequence.
---@class GroupByIterator : Iterator
---@field _source Enumerable
---@field _keySelector fun(param: any): any
---@field _elementSelector fun(param: any): any
---@field _enumerator Enumerator
---@field _lookup Lookup
local GroupByIterator = Iterator.createIteratorSubclass()
---@param source Enumerable
---@param keySelector fun(param: any): any
---@param elementSelector? fun(param: any): any
function GroupByIterator.new(source, keySelector, elementSelector)
assert(source, 'source cannot be nil')
assert(keySelector, 'keySelector cannot be nil')
local self = setmetatable(Iterator.new(), GroupByIterator)
self._source = source
self._keySelector = keySelector
self._elementSelector = elementSelector or function(x) return x end
self._enumerator = nil
self._lookup = nil
return self
end
function GroupByIterator:moveNext()
if self._state ~= 0 then
if self._state ~= 1 then
return false
end
else
self._lookup = Lookup.new(self._source, self._keySelector, self._elementSelector, self._isArray)
self._enumerator = self._lookup:getEnumerator()
self._state = 1
end
local enum = self._enumerator
if enum:moveNext() == true then
self.current = enum.current
self.index = enum.index
self._state = 1
return true
end
self:finalize()
return false
end
function GroupByIterator:finalize()
if self._enumerator then
self._enumerator:finalize()
end
Iterator.finalize(self)
end
function GroupByIterator:clone()
return GroupByIterator.new(self._source, self._keySelector, self._elementSelector)
end
--Groups the elements of a sequence according to a specified key selector function and creates a result value from each group and its key. The elements of each group are projected by using a specified function.
---@class GroupByResultIterator : Iterator
---@field _source Enumerable
---@field _keySelector fun(param: any): any
---@field _elementSelector fun(param: any): any
---@field _resultSelector fun(key: any, grouping: Grouping): any
---@field _enumerator Enumerator
---@field _lookup Lookup
local GroupByResultIterator = Iterator.createIteratorSubclass()
---@param source Enumerable
---@param keySelector fun(param: any): any
---@param elementSelector? fun(param: any): any
---@param resultSelector fun(key: any, grouping: Grouping): any
function GroupByResultIterator.new(source, keySelector, elementSelector, resultSelector)
assert(source, 'source cannot be nil')
assert(keySelector, 'keySelector cannot be nil')
assert(resultSelector, 'keySelector cannot be nil')
local self = setmetatable(Iterator.new(), GroupByResultIterator)
self._source = source
self._keySelector = keySelector
self._elementSelector = elementSelector or function(x) return x end
self._resultSelector = resultSelector
self._enumerator = nil
self._lookup = nil
return self
end
function GroupByResultIterator:moveNext()
if self._state ~= 0 then
if self._state ~= 1 then
return false
end
else
self._lookup = Lookup.new(self._source, self._keySelector, self._elementSelector, self._isArray)
self._enumerator = self._lookup:getEnumerator()
self._state = 1
end
local enumerator = self._enumerator
if enumerator:moveNext() == true then
-- Transform Grouping element for enumeration
local gKey = enumerator.current.key
local gElements = enumerator.current
self.current = self._resultSelector(gKey, gElements)
self.index = enumerator.index
self._state = 1
return true
end
self:finalize()
return false
end
function GroupByResultIterator:finalize()
if self._enumerator then
self._enumerator:finalize()
end
Iterator.finalize(self)
end
function GroupByResultIterator:clone()
return GroupByResultIterator.new(self._source, self._keySelector, self._elementSelector, self._resultSelector)
end
--Groups the elements of a sequence according to a specified key selector function and creates a result value from each group and its key. The elements of each group are projected by using a specified function.
---@class SortableIterator : Iterator
---@field _source Enumerable
---@field _keySelector fun(value: any, index: any): any|?
---@field _descendingSort boolean
---@field _parentIterator SortableIterator
---@field _enumerator Enumerator
local SortableIterator = Iterator.createIteratorSubclass()
---@param source Enumerable
---@param keySelector fun(value: any, index: any): any
---@param descendingSort? boolean
---@param parent? SortableIterator
---@return SortableIterator
function SortableIterator.new(source, keySelector, descendingSort, parent)
assert(source, 'source cannot be nil')
assert(keySelector, 'keySelector cannot be nil')
assert(parent == nil or parent.getEnumerableSorters ~= nil, 'parent must be of type SortableIterator')
if descendingSort == nil then
descendingSort = false
end
local self = setmetatable(Iterator.new(), SortableIterator)
self._source = source
self._keySelector = keySelector
self._descendingSort = descendingSort
self._parentIterator = parent
self._enumerator = nil
---@diagnostic disable-next-line: return-type-mismatch
return self
end
local function createSortFunc(sorters)
local sortCount = #sorters
if sortCount > 1 then
return function(a, b)
for _, sorter in ipairs(sorters) do
local selector = sorter._keySelector
local descendingSort = sorter._descendingSort
local aVal = selector(a)
local bVal = selector(b)
if aVal ~= bVal then
if descendingSort == true then
return aVal > bVal
else
return aVal < bVal
end
end
end
end
end
-- Faster sort function if there's no nested sorts.
local sorter = sorters[1]
local selector = sorter._keySelector
if sorter._descendingSort == true then
return function(a, b) return selector(a) > selector(b) end
else
return function(a, b) return selector(a) < selector(b) end
end
end
function SortableIterator:moveNext()
if self._state ~= 0 then
if self._state ~= 1 then
return false
end
else
-- Collect sorter meta information.
local sorters = {}
self:getEnumerableSorters(sorters)
local sortFunc = createSortFunc(sorters)
local enumerator = self._source:getEnumerator(self._isArray)
local buffer = {}
-- Process and sort all previous enumerators.
while enumerator:moveNext() == true do
table.insert(buffer, enumerator.current)
end
table.sort(buffer, sortFunc)
self._enumerator = TableEnumerator.createForArray(buffer)
self._state = 1
end
local enumerator = self._enumerator
if enumerator:moveNext() == true then
self.current = enumerator.current
self.index = enumerator.index
self._state = 1
return true
end
self:finalize()
return false
end
---Collects all sorter functions down the line.
---@param sorters table
---@return table
function SortableIterator:getEnumerableSorters(sorters)
-- Get parent sorter first.
if self._parentIterator ~= nil then
self._parentIterator:getEnumerableSorters(sorters)
end
table.insert(sorters, self)
return sorters
end
---@param keySelector fun(value: any, index: any): any
---@param descendingSort? boolean
---@return SortableIterator
function SortableIterator:createSortableIterator(keySelector, descendingSort)
return SortableIterator.new(self._source, keySelector, descendingSort, self)
end
function SortableIterator:finalize()
if self._enumerator then
self._enumerator:finalize()
self._enumerator = nil
end
Iterator.finalize(self)
end
function SortableIterator:clone()
return SortableIterator.new(self._source, self._keySelector, self._descendingSort, self._parentIterator)
end
return {
MapIterator = MapIterator,
WhereIterator = WhereIterator,
FlatMapIterator = FlatMapIterator,
ConcatIterator = ConcatIterator,
AppendIterator = AppendIterator,
UniqueIterator = UniqueIterator,
DifferenceIterator = DifferenceIterator,
UnionIterator = UnionIterator,
IntersectIterator = IntersectIterator,
ZipIterator = ZipIterator,
GroupByIterator = GroupByIterator,
GroupByResultIterator = GroupByResultIterator,
SortableIterator = SortableIterator,
}