-- Module:FunList/Iterators
-- From Melvor Idle
-- Documentation for this module may be created at Module:FunList/Iterators/doc
-- FORWARD DECLARED FUNCTIONS --
-- The helpers below are assigned near the end of this module. Declaring the
-- locals up front lets the enumerator classes reference them as upvalues.
--
-- isType(object, class)
--   Returns true when some metatable in object's metatable chain has an
--   __index equal to class (falsy otherwise).
-- getTableEnumerator(object, isArray) -> Enumerator
--   Wraps a plain table in a new TableEnumerator prepared with
--   getEnumerator(isArray); objects that already are Enumerators are reused
--   rather than wrapped.
-- addToSet(set, item) -> boolean
--   Records item as a key of the hashset table `set`. Returns true when the
--   item was not already present, false otherwise.
local isType, getTableEnumerator, addToSet
-- CLASS DEFINITIONS --
-- BASE ENUMERATOR CLASS --
-- Abstract, state-machine driven iterator. Derived classes implement
-- moveNext() and clone(); after moveNext() returns true, `index` and
-- `current` hold the key/value pair of the current element.
-- State values used by this module:
--   -1  freshly constructed, not yet prepared by getEnumerator()
--    0  prepared by getEnumerator(), iteration not yet started
--   >0  iteration in progress (meaning defined by each derived class)
--   -4  finalized, no further elements
---@class Enumerator
---@field current any
---@field index any
---@field state integer
---@field isArray boolean
local Enumerator = {}
local Enumerator_mt = {
	__index = Enumerator,
	__pairs = function(t) return t:getPairs() end,
	__ipairs = function(t) return t:getiPairs() end
}

-- Creates a bare, unprepared enumerator (state -1).
---@return Enumerator
function Enumerator.new()
	local self = setmetatable({}, Enumerator_mt)
	self.current = nil
	self.index = nil
	self.state = -1
	-- Assume by default we are not dealing with a simple array
	self.isArray = false
	return self
end

-- Advances to the next element; returns true when one is available.
-- Must be overridden by derived classes.
---@return boolean
function Enumerator:moveNext()
	error('Abstract function must be overridden in derived class.')
end

-- Returns an enumerator in state 0 that is ready for iteration. The
-- instance itself is reused while it is still unused (state -1); otherwise
-- a fresh clone is returned, so an in-progress or finished iteration is
-- left untouched.
---@param isArray boolean|nil true to iterate sequentially from index 1
---@return Enumerator
function Enumerator:getEnumerator(isArray)
	local instance = (self.state == -1) and self or self:clone()
	-- Normalise to a real boolean so the field matches its annotation;
	-- derived classes test it with both `== true` and plain truthiness.
	instance.isArray = isArray == true
	instance.state = 0
	return instance
end

-- Returns a fresh, unused copy of this enumerator. Must be overridden.
---@return Enumerator
function Enumerator:clone()
	error('Abstract function must be overridden in derived class.')
end

-- Marks the enumerator as finished. Derived classes extend this to also
-- finalize the inner enumerators they hold.
function Enumerator:finalize()
	-- Signals invalid state.
	self.state = -4
end

-- Hooks the moveNext function into Lua's generic `for` loop.
-- startIndex is nil for pairs-style iteration and 0 for ipairs-style
-- (array) iteration.
local function overridePairs(enum, startIndex)
	-- Get or create clean enumerator. This ensures the state is 0.
	local new = enum:getEnumerator(startIndex == 0)
	new.current = nil
	new.index = startIndex
	-- The generic-for state/control arguments are not needed:
	-- all iteration state lives in `new`.
	local function iterator()
		if new:moveNext() == true then
			return new.index, new.current
		end
		return nil, nil
	end
	return iterator, new, new.index
end

-- Manual override for iterating over the Enumerator using pairs()
function Enumerator:getPairs()
	return overridePairs(self, nil)
end

-- Manual override for iterating over the Enumerator using ipairs()
function Enumerator:getiPairs()
	return overridePairs(self, 0)
end
-- TABLE ENUMERATOR CLASS --
-- This is essentially a wrapper for the table object,
-- since it provides no state machine esque iteration out of the box
---@class TableEnumerator : Enumerator
---@field state integer
---@field tbl table
local TableEnumerator = setmetatable({}, { __index = Enumerator })
TableEnumerator.__index = TableEnumerator
TableEnumerator.__pairs = Enumerator_mt.__pairs
TableEnumerator.__ipairs = Enumerator_mt.__ipairs

-- Creates an enumerator over `tbl`; a nil table yields an empty enumerator.
---@param tbl table|nil
---@return TableEnumerator
function TableEnumerator.new(tbl)
	local self = setmetatable(Enumerator.new(), TableEnumerator)
	self.tbl = tbl or {} -- Allow creation of empty enumerable
	return self
end

-- Advances to the next entry of the wrapped table. In array mode the keys
-- 1, 2, ... are visited until the first nil value; otherwise `next` is used.
-- Once exhausted, the enumerator finalizes itself and keeps returning false.
---@return boolean
function TableEnumerator:moveNext()
	if self.state == -4 then
		-- Already exhausted/finalized. Without this guard a finished hash-mode
		-- enumeration (index == nil) would restart from the first key.
		return false
	end
	if self.state == 0 then
		self.state = 1
		self.index = self.isArray and 0 or nil
	end
	if self.isArray == true then
		-- Iterate like ipairs, starting from index 1
		local index = (self.index or 0) + 1
		self.index = index
		self.current = self.tbl[index]
		if self.current ~= nil then
			return true
		end
	else
		-- Iterate like pairs
		local key, value = next(self.tbl, self.index)
		self.index = key
		if key ~= nil then
			self.current = value
			return true
		end
	end
	self:finalize()
	return false
end

-- Returns a new, unused enumerator over the same table.
function TableEnumerator:clone()
	return TableEnumerator.new(self.tbl)
end
-- SELECT ENUMERATOR --
-- Lazily projects every element of `source` through
-- selector(element, position), where position counts yielded elements from 1.
-- The source index of each element is preserved.
---@class MapEnumerator : Enumerator
---@field state integer
---@field source Enumerator
---@field selector function
---@field enumerator Enumerator|nil
---@field position integer
local MapEnumerator = setmetatable({}, { __index = Enumerator })
MapEnumerator.__index = MapEnumerator
MapEnumerator.__pairs = Enumerator_mt.__pairs
MapEnumerator.__ipairs = Enumerator_mt.__ipairs

function MapEnumerator.new(source, selector)
	assert(source, 'Source cannot be nil')
	assert(selector, 'Selector cannot be nil')
	local self = setmetatable(Enumerator.new(), MapEnumerator)
	self.source = source
	self.selector = selector
	self.enumerator = nil
	self.position = 0
	return self
end

function MapEnumerator:moveNext()
	if self.state == 0 then
		self.state = 1
		self.position = 0
		self.enumerator = self.source:getEnumerator(self.isArray)
	end
	if self.state == 1 then
		local enumerator = self.enumerator
		if enumerator:moveNext() == true then
			self.index = enumerator.index -- Preserve index
			self.position = self.position + 1
			self.current = self.selector(enumerator.current, self.position)
			-- Only nil is rejected; false is a legitimate selected value.
			assert(self.current ~= nil, 'Selected value must be non-nil')
			return true
		end
		self:finalize()
	end
	return false
end

-- Finalizes the inner source enumerator as well as this one.
function MapEnumerator:finalize()
	if self.enumerator then
		self.enumerator:finalize()
	end
	Enumerator.finalize(self)
end

function MapEnumerator:clone()
	return MapEnumerator.new(self.source, self.selector)
end
-- WHERE ENUMERATOR --
-- Lazily yields the elements of `source` for which
-- predicate(element, index) returns exactly `true`; the source index of
-- each yielded element is kept.
---@class WhereEnumerator : Enumerator
---@field source Enumerator
---@field predicate function
local WhereEnumerator = setmetatable({}, { __index = Enumerator })
WhereEnumerator.__index = WhereEnumerator
WhereEnumerator.__pairs = Enumerator_mt.__pairs
WhereEnumerator.__ipairs = Enumerator_mt.__ipairs

function WhereEnumerator.new(source, predicate)
	assert(source, 'Source cannot be nil')
	assert(predicate, 'Predicate cannot be nil')
	local filter = setmetatable(Enumerator.new(), WhereEnumerator)
	filter.source = source
	filter.predicate = predicate
	filter.enumerator = nil
	return filter
end

function WhereEnumerator:moveNext()
	if self.state == 0 then
		self.state = 1
		self.position = 0
		self.enumerator = self.source:getEnumerator(self.isArray)
	end
	-- Anything other than the active state (1) has nothing to produce.
	if self.state ~= 1 then
		return false
	end
	local inner = self.enumerator
	while true do
		if inner:moveNext() ~= true then
			self:finalize()
			return false
		end
		local element, key = inner.current, inner.index
		if self.predicate(element, key) == true then
			self.index = key
			self.current = element
			return true
		end
	end
end

-- Finalizes the inner source enumerator as well as this one.
function WhereEnumerator:finalize()
	local inner = self.enumerator
	if inner then
		inner:finalize()
	end
	Enumerator.finalize(self)
end

function WhereEnumerator:clone()
	return WhereEnumerator.new(self.source, self.predicate)
end
-- FLATMAP (SELECTMANY) ENUMERATOR --
-- For each element of `source`, calls selector(element, position) and then
-- yields every entry of the returned table (or enumerator) in turn.
---@class FlatMapEnumerator : Enumerator
---@field source Enumerator
---@field selector function
---@field position integer
local FlatMapEnumerator = setmetatable({}, { __index = Enumerator })
FlatMapEnumerator.__index = FlatMapEnumerator
FlatMapEnumerator.__pairs = Enumerator_mt.__pairs
FlatMapEnumerator.__ipairs = Enumerator_mt.__ipairs

function FlatMapEnumerator.new(source, selector)
	assert(source, 'Source cannot be nil')
	assert(selector, 'Selector cannot be nil')
	local self = setmetatable(Enumerator.new(), FlatMapEnumerator)
	self.source = source
	self.selector = selector
	self.position = 0
	self.enumerator = nil -- Enumerator of the source Enumerable
	self.sourceEnumerator = nil -- Enumerator of the nested Enumerable
	return self
end

-- States: 0 = not started, 3 = fetch the next nested enumerator,
-- 4 = enumerating the current nested enumerator, -4 = finished.
function FlatMapEnumerator:moveNext()
	-- Setup state
	if self.state == 0 then
		self.position = 0
		self.enumerator = self.source:getEnumerator(self.isArray)
		self.state = 3 -- signal to get (first) nested enumerator
	end
	-- Only states 3 and 4 are active. This also covers -4 (finalized) and
	-- -1 (never prepared by getEnumerator), which previously made the loop
	-- below spin forever.
	if self.state ~= 3 and self.state ~= 4 then
		return false
	end
	while true do
		-- Grab next value from nested enumerator
		if self.state == 4 then
			local nested = self.sourceEnumerator
			if nested:moveNext() then
				self.current = nested.current
				self.index = nested.index
				return true
			end
			-- Cleanup nested enumerator
			nested:finalize()
			self.state = 3 -- signal to get next enumerator
		end
		-- State 3: grab the next nested enumerator
		if self.enumerator:moveNext() then
			self.position = self.position + 1
			local sourceTable = self.selector(self.enumerator.current, self.position)
			-- Nested tables are never treated as arrays.
			self.sourceEnumerator = getTableEnumerator(sourceTable, false)
			self.state = 4 -- signal to get next item
		else
			-- enumerator doesn't have any more nested enumerators.
			self:finalize()
			return false
		end
	end
end

-- Finalizes both the outer source enumerator and the current nested one.
function FlatMapEnumerator:finalize()
	if self.enumerator then
		self.enumerator:finalize()
	end
	if self.sourceEnumerator then
		self.sourceEnumerator:finalize()
	end
	Enumerator.finalize(self)
end

function FlatMapEnumerator:clone()
	return FlatMapEnumerator.new(self.source, self.selector)
end
-- CONCAT ENUMERATOR --
-- Yields every element of `first`, then every element of `second`.
---@class ConcatEnumerator : Enumerator
---@field first Enumerator
---@field second Enumerator
local ConcatEnumerator = setmetatable({}, { __index = Enumerator })
ConcatEnumerator.__index = ConcatEnumerator
ConcatEnumerator.__pairs = Enumerator_mt.__pairs
ConcatEnumerator.__ipairs = Enumerator_mt.__ipairs

function ConcatEnumerator.new(first, second)
	assert(first, 'First cannot be nil')
	assert(second, 'Second cannot be nil')
	local concat = setmetatable(Enumerator.new(), ConcatEnumerator)
	concat.first = first
	concat.second = second
	concat.enumerator = nil
	return concat
end

-- Returns the enumerable at the 0-based position `index`, or nil past the
-- end. Only two exist today, but this allows N enumerables in the future.
function ConcatEnumerator:getEnumerable(index)
	if index == 0 then return self.first end
	if index == 1 then return self.second end
	return nil
end

-- While active, self.state - 1 is the position of the enumerable
-- currently being walked.
function ConcatEnumerator:moveNext()
	local state = self.state
	if state == -4 then
		return false
	end
	if state == 0 then
		local firstEnumerable = self:getEnumerable(0)
		self.enumerator = firstEnumerable:getEnumerator(self.isArray)
		self.state = 1
	elseif state < 0 then
		return false
	end
	while true do
		local active = self.enumerator
		if active:moveNext() == true then
			self.index = active.index
			self.current = active.current
			return true
		end
		local upcoming = self:getEnumerable(self.state)
		self.state = self.state + 1
		if upcoming == nil then
			self:finalize()
			return false
		end
		-- Cleanup previous enumerator before switching to the next one
		active:finalize()
		self.enumerator = getTableEnumerator(upcoming, self.isArray)
	end
end

-- Finalizes the enumerator currently in use as well as this one.
function ConcatEnumerator:finalize()
	local active = self.enumerator
	if active then
		active:finalize()
	end
	Enumerator.finalize(self)
end

function ConcatEnumerator:clone()
	return ConcatEnumerator.new(self.first, self.second)
end
-- APPEND ENUMERATOR --
-- Yields all elements of `source` plus one extra item, either at the end
-- (append == true, the default) or at the front (append == false).
---@class AppendEnumerator : Enumerator
---@field source Enumerator
---@field item any
---@field itemIndex any
---@field append boolean
---@field enumerator Enumerator
local AppendEnumerator = setmetatable({}, { __index = Enumerator })
AppendEnumerator.__index = AppendEnumerator
AppendEnumerator.__pairs = Enumerator_mt.__pairs
AppendEnumerator.__ipairs = Enumerator_mt.__ipairs

-- item must be non-nil (false is allowed). itemIndex is the key reported for
-- the extra item and defaults to the item itself; append defaults to true.
function AppendEnumerator.new(source, item, itemIndex, append)
	assert(source, 'Source cannot be nil')
	-- Only nil is rejected; `false` is a valid item to append/prepend.
	assert(item ~= nil, 'Item cannot be nil')
	local self = setmetatable(Enumerator.new(), AppendEnumerator)
	self.source = source
	self.item = item
	-- Index needs *some* value, otherwise iteration stops.
	if itemIndex == nil then self.itemIndex = item else self.itemIndex = itemIndex end
	if append == nil then self.append = true else self.append = append end
	self.enumerator = nil
	return self
end

-- States: 0 = not started, 1 = (prepended item done) fetch source,
-- 2 = enumerating source, -4 = finished.
function AppendEnumerator:moveNext()
	if self.state == 0 then
		self.state = 1
		-- Put the item in front, as a prepend.
		if self.append == false then
			self.current = self.item
			self.index = self.itemIndex
			return true
		end
	end
	-- Grab the source enumerator
	if self.state == 1 then
		self.enumerator = self.source:getEnumerator(self.isArray)
		self.state = 2
	end
	if self.state == 2 then
		if self.enumerator:moveNext() then
			self.current = self.enumerator.current
			self.index = self.enumerator.index
			return true
		end
		-- Source exhausted: finalize now (state -4) so the appended item
		-- below is produced exactly once.
		self:finalize()
		if self.append == true then
			self.current = self.item
			self.index = self.itemIndex
			return true
		end
	end
	return false
end

-- Finalizes the inner source enumerator as well as this one.
function AppendEnumerator:finalize()
	if self.enumerator then
		self.enumerator:finalize()
	end
	Enumerator.finalize(self)
end

function AppendEnumerator:clone()
	return AppendEnumerator.new(self.source, self.item, self.itemIndex, self.append)
end
-- UNIQUE (DISTINCT) ENUMERATOR --
-- Yields each element of `source` whose key has not been seen before. The
-- key is the element itself, or selector(element, index) when a selector is
-- given (DistinctBy); the original element is what gets yielded.
---@class UniqueEnumerator : Enumerator
---@field source Enumerator
---@field selector function
---@field enumerator Enumerator
---@field set table
local UniqueEnumerator = setmetatable({}, { __index = Enumerator })
UniqueEnumerator.__index = UniqueEnumerator
UniqueEnumerator.__pairs = Enumerator_mt.__pairs
UniqueEnumerator.__ipairs = Enumerator_mt.__ipairs

function UniqueEnumerator.new(source, selector)
	assert(source, 'Source cannot be nil')
	local self = setmetatable(Enumerator.new(), UniqueEnumerator)
	self.source = source
	self.selector = selector
	self.enumerator = nil
	self.set = nil
	return self
end

function UniqueEnumerator:moveNext()
	if self.state == 0 then
		self.enumerator = self.source:getEnumerator(self.isArray)
		self.set = {}
		self.state = 1
	end
	-- Only state 1 is active. -4 means finished; -1 means the instance was
	-- never prepared by getEnumerator (the previous loop spun forever then).
	if self.state ~= 1 then
		return false
	end
	local enumerator = self.enumerator
	while enumerator:moveNext() == true do
		local key = enumerator.current
		-- Manipulate the item if we have a selector (DistinctBy)
		if self.selector then
			key = self.selector(key, enumerator.index)
		end
		if addToSet(self.set, key) then
			self.current = enumerator.current
			self.index = enumerator.index
			return true
		end
	end
	self:finalize()
	return false
end

-- Finalizes the inner source enumerator as well as this one.
function UniqueEnumerator:finalize()
	if self.enumerator then
		self.enumerator:finalize()
	end
	Enumerator.finalize(self)
end

function UniqueEnumerator:clone()
	return UniqueEnumerator.new(self.source, self.selector)
end
-- EXCEPT ENUMERATOR --
-- Yields the distinct elements of `first` whose key does not occur among the
-- values of `second`. The key is the element itself, or
-- selector(element, index) when a selector is given (ExceptBy); `second`'s
-- values are used as keys directly.
---@class DifferenceEnumerator : Enumerator
---@field first Enumerator
---@field second any
---@field selector function
---@field enumerator Enumerator
---@field set table
local DifferenceEnumerator = setmetatable({}, { __index = Enumerator })
DifferenceEnumerator.__index = DifferenceEnumerator
DifferenceEnumerator.__pairs = Enumerator_mt.__pairs
DifferenceEnumerator.__ipairs = Enumerator_mt.__ipairs

function DifferenceEnumerator.new(first, second, selector)
	assert(first, 'First table cannot be nil')
	assert(second, 'Second table cannot be nil')
	local self = setmetatable(Enumerator.new(), DifferenceEnumerator)
	self.first = first
	self.second = second
	self.selector = selector
	self.enumerator = nil
	self.set = nil
	return self
end

-- Builds a hashset whose keys are the values of `other`, which may be an
-- Enumerator or a plain table.
local function createSet(other)
	local set = {}
	if isType(other, Enumerator) then
		for _, v in other:getPairs() do
			set[v] = true
		end
	else
		assert(type(other) == 'table', 'Source must be a table.')
		for _, v in pairs(other) do
			set[v] = true
		end
	end
	return set
end

function DifferenceEnumerator:moveNext()
	if self.state == 0 then
		self.enumerator = self.first:getEnumerator(self.isArray)
		self.set = createSet(self.second)
		self.state = 1
	end
	-- Only state 1 is active: -1 (never prepared) would otherwise index a nil
	-- enumerator, and -4 means the enumeration already finished.
	if self.state ~= 1 then
		return false
	end
	local enumerator = self.enumerator
	while enumerator:moveNext() do
		local key = enumerator.current
		if self.selector ~= nil then
			key = self.selector(key, enumerator.index)
		end
		-- addToSet also records yielded keys, so duplicates in `first` are skipped.
		if addToSet(self.set, key) then
			self.current = enumerator.current
			self.index = enumerator.index
			return true
		end
	end
	self:finalize()
	return false
end

-- Finalizes the inner enumerator and marks this enumerator as finished.
function DifferenceEnumerator:finalize()
	if self.enumerator then
		self.enumerator:finalize()
	end
	-- Previously missing, which left the state at 1 after completion.
	Enumerator.finalize(self)
end

function DifferenceEnumerator:clone()
	return DifferenceEnumerator.new(self.first, self.second, self.selector)
end
-- UNION ENUMERATOR --
-- Yields the distinct elements of `first`, followed by the elements of
-- `second` whose key has not been produced yet. The key is the element, or
-- selector(element, index) when a selector is given; the original element
-- is what gets yielded.
---@class UnionEnumerator : Enumerator
---@field first Enumerator
---@field second any
---@field selector function
---@field enumerator Enumerator
---@field set table
local UnionEnumerator = setmetatable({}, { __index = Enumerator })
UnionEnumerator.__index = UnionEnumerator
UnionEnumerator.__pairs = Enumerator_mt.__pairs
UnionEnumerator.__ipairs = Enumerator_mt.__ipairs

function UnionEnumerator.new(first, second, selector)
	assert(first, 'First table cannot be nil')
	assert(second, 'Second table cannot be nil')
	local union = setmetatable(Enumerator.new(), UnionEnumerator)
	union.first = first
	union.second = second
	union.selector = selector
	union.enumerator = nil
	union.set = nil
	return union
end

-- Advances self.enumerator until an element with an unseen key appears.
-- On success it stores that element in self.current/self.index, sets
-- self.state to `state` and returns true; returns nil once exhausted.
local function enumerateSource(self, state)
	local source = self.enumerator
	while source:moveNext() do
		local key = source.current
		if self.selector ~= nil then
			key = self.selector(key, source.index)
		end
		if addToSet(self.set, key) == true then
			self.current = source.current
			self.index = source.index
			self.state = state
			return true
		end
	end
end

function UnionEnumerator:moveNext()
	if self.state == -4 then
		return false
	end
	if self.state == 0 then
		self.set = {}
		self.enumerator = self.first:getEnumerator(self.isArray)
		self.state = 1
	end
	-- Phase 1: distinct elements of `first`.
	if self.state == 1 then
		if enumerateSource(self, 1) then
			return true
		end
		-- `first` is exhausted: switch over to `second`.
		self.state = 2
		self.enumerator:finalize()
		self.enumerator = getTableEnumerator(self.second, self.isArray)
	end
	-- Phase 2: elements of `second` whose key was not produced yet.
	if self.state == 2 then
		if enumerateSource(self, 2) then
			return true
		end
		-- End second enumeration
		self.state = 3
	end
	self:finalize()
	return false
end

-- Finalizes the enumerator currently in use as well as this one.
function UnionEnumerator:finalize()
	local source = self.enumerator
	if source then
		source:finalize()
	end
	Enumerator.finalize(self)
end

function UnionEnumerator:clone()
	return UnionEnumerator.new(self.first, self.second, self.selector)
end
--TODO:
-- INTERSECT
-- GROUPBY
-- ZIP
-- ORDER/SORTING???
--Optional:
---Take
---Skip
-- Define forward declared functions
-- Walks the metatable chain of obj and returns true as soon as a metatable's
-- __index equals `class` (obj is an instance of class or of a subclass).
-- Returns false otherwise, including for values without a metatable and for
-- cyclic chains such as setmetatable(C, C), which previously never ended.
isType = function(obj, class)
	local visited = {}
	local mt = getmetatable(obj)
	while type(mt) == 'table' and not visited[mt] do
		if mt.__index == class then
			return true
		end
		visited[mt] = true
		mt = getmetatable(mt)
	end
	return false
end
-- Returns an enumerator for sourceTable that is ready for iteration (state 0)
-- in the requested isArray mode. Plain tables (nil counts as empty) are
-- wrapped in a new TableEnumerator. Objects that already are Enumerators are
-- prepared through their own getEnumerator, which reuses an unused instance
-- or clones a used one.
getTableEnumerator = function(sourceTable, isArray)
	if isType(sourceTable, Enumerator) then
		-- Returning the object as-is left it in state -1 (or mid-iteration),
		-- which derived moveNext() implementations do not treat as active,
		-- and ignored the requested isArray mode.
		return sourceTable:getEnumerator(isArray)
	end
	return TableEnumerator.new(sourceTable):getEnumerator(isArray)
end
-- Records `item` as a key of the hashset `set`.
-- Returns true when it was newly added, false when a value already existed.
addToSet = function(set, item)
	if set[item] ~= nil then
		return false
	end
	set[item] = true
	return true
end
-- Public interface of this module: the base class and every concrete enumerator.
local exports = {
	Enumerator = Enumerator,
	TableEnumerator = TableEnumerator,
	MapEnumerator = MapEnumerator,
	WhereEnumerator = WhereEnumerator,
	FlatMapEnumerator = FlatMapEnumerator,
	ConcatEnumerator = ConcatEnumerator,
	AppendEnumerator = AppendEnumerator,
	UniqueEnumerator = UniqueEnumerator,
	DifferenceEnumerator = DifferenceEnumerator,
	UnionEnumerator = UnionEnumerator,
}
return exports