Module:FunList/Iterators

From Melvor Idle

Documentation for this module may be created at Module:FunList/Iterators/doc

-- FORWARD DECLARED FUNCTIONS --
-- Declared up front so the class methods below can capture them as upvalues;
-- they are assigned near the end of the file, after every class is defined.

-- Checks whether the provided object's metatable chain has a metatable whose
-- __index is the provided class table.
-- (object, class). Returns true on a match (a falsy value otherwise).
local isType

-- Returns an Enumerator for the provided object. Plain tables are wrapped in
-- a new TableEnumerator that is prepared with the given isArray flag.
-- (object, isArray). Returns Enumerator
local getTableEnumerator

-- Attempts to add an object to the provided table, used as a hash set.
-- (table, object). Returns true if the object was not already present,
-- false otherwise.
local addToSet

-- CLASS DEFINITIONS --
-- BASE ENUMERATOR CLASS --
-- Abstract base class for the lazy enumerators in this module. Derived
-- classes override moveNext() and clone(); the base provides construction,
-- hand-out via getEnumerator(), finalization and the pairs()/ipairs() hooks.
--
-- `state` values assigned by this base class:
--   -1  constructed but never handed out by getEnumerator()
--    0  handed out by getEnumerator(); iteration not started yet
--   -4  finalized (see finalize())
-- Derived classes use their own values while iterating.
---@class Enumerator
---@field current any
---@field index any
---@field state integer
---@field isArray boolean
local Enumerator = {}
local Enumerator_mt = {}
Enumerator_mt.__index = Enumerator
-- Route the pairs()/ipairs() metamethods to the manual overrides below.
Enumerator_mt.__pairs = function(t) return t:getPairs() end
Enumerator_mt.__ipairs = function(t) return t:getiPairs() end

-- Creates a bare enumerator in the "never handed out" state (-1).
-- By default it is not treated as a simple array.
---@return Enumerator
function Enumerator.new()
    -- `current` and `index` start out nil, so they are simply left unset.
    return setmetatable({ state = -1, isArray = false }, Enumerator_mt)
end

-- Advances to the next element, updating `current` and `index`.
-- Abstract: derived classes must override it.
---@return boolean
function Enumerator:moveNext()
    error('Abstract function must be overridden in derived class.')
end

-- Returns an enumerator ready for a fresh iteration (state 0).
-- The object itself is reused only while it has never been handed out
-- (state -1); in any other state an independent copy is made with clone().
---@param isArray boolean|nil  true for ipairs-style (array) iteration
---@return Enumerator
function Enumerator:getEnumerator(isArray)
    local instance = self
    if self.state ~= -1 then
        instance = self:clone()
    end
    instance.isArray = isArray
    instance.state = 0
    return instance
end

-- Creates a fresh, not-yet-handed-out copy of this enumerator.
-- Abstract: derived classes must override it.
---@return Enumerator
function Enumerator:clone()
    error('Abstract function must be overridden in derived class.')
end

-- Marks the enumerator as finalized (state -4).
function Enumerator:finalize()
    self.state = -4
end

-- Builds a generic-for triple (iterator, state, initial key) that drives
-- moveNext(). A startIndex of 0 selects array (ipairs) mode; nil selects
-- hash (pairs) mode.
local function overridePairs(enum, startIndex)
    -- Get or create a clean enumerator; getEnumerator guarantees state 0.
    local fresh = enum:getEnumerator(startIndex == 0)
    fresh.current = nil
    fresh.index = startIndex

    local function step()
        if fresh:moveNext() ~= true then
            return nil, nil
        end
        return fresh.index, fresh.current
    end

    return step, fresh, fresh.index
end

-- Manual override for iterating over the Enumerator using pairs()
function Enumerator:getPairs()
    return overridePairs(self, nil)
end

-- Manual override for iterating over the Enumerator using ipairs()
function Enumerator:getiPairs()
    return overridePairs(self, 0)
end

-- TABLE ENUMERATOR CLASS --
-- Wraps a plain Lua table, which offers no state-machine style iteration out
-- of the box. In array mode (isArray == true) it walks the integer keys
-- 1, 2, 3, ... until the first nil value; otherwise it walks every key with
-- next().
---@class TableEnumerator : Enumerator
---@field state integer
---@field tbl table
local TableEnumerator = setmetatable({}, { __index = Enumerator })
TableEnumerator.__index = TableEnumerator
TableEnumerator.__pairs = Enumerator_mt.__pairs
TableEnumerator.__ipairs = Enumerator_mt.__ipairs

-- Creates an enumerator over `tbl`; a nil table yields an empty enumerable.
---@param tbl table|nil
---@return TableEnumerator
function TableEnumerator.new(tbl)
    local self = setmetatable(Enumerator.new(), TableEnumerator)
    self.tbl = tbl or {} -- Allow creation of empty enumerable
    return self
end

-- Advances to the next entry of the wrapped table, setting `index` to its
-- key and `current` to its value.
-- Returns true when an entry was produced, false once the table is exhausted.
---@return boolean
function TableEnumerator:moveNext()
    -- Once finalized the enumerator stays exhausted. Without this check a
    -- hash-mode enumerator (whose index is nil after the last key) would
    -- restart from the first key on the next call.
    if self.state == -4 then
        return false
    end

    if self.state == 0 then
        self.state = 1
        self.index = self.isArray and 0 or nil
    end

    if self.isArray == true then
        -- Array mode: the first index produced is 1.
        self.index = self.index + 1
        self.current = self.tbl[self.index]
        if self.current ~= nil then
            return true
        end
    else
        -- Hash mode: continue from the previously produced key.
        local key = next(self.tbl, self.index)
        self.index = key
        if key ~= nil then
            self.current = self.tbl[key]
            return true
        end
    end

    -- Exhausted: finalize so any further calls keep returning false.
    self:finalize()
    return false
end

-- Returns a new, not-yet-handed-out enumerator over the same table.
---@return TableEnumerator
function TableEnumerator:clone()
    return TableEnumerator.new(self.tbl)
end

-- SELECT (MAP) ENUMERATOR --
-- Lazily projects each element of `source` through `selector`.
-- The selector is called as selector(element, position), where position is
-- the 1-based count of elements produced so far. The source index is kept.
---@class MapEnumerator : Enumerator
---@field state integer
---@field source Enumerator
---@field selector function
---@field enumerator Enumerator
---@field position integer
local MapEnumerator = setmetatable({}, { __index = Enumerator })
MapEnumerator.__index = MapEnumerator
MapEnumerator.__pairs = Enumerator_mt.__pairs
MapEnumerator.__ipairs = Enumerator_mt.__ipairs

---@param source Enumerator
---@param selector function
---@return MapEnumerator
function MapEnumerator.new(source, selector)
    assert(source, 'Source cannot be nil')
    assert(selector, 'Selector cannot be nil')
    local self = setmetatable(Enumerator.new(), MapEnumerator)
    self.source = source
    self.selector = selector
    self.enumerator = nil
    self.position = 0
    return self
end

-- Advances to the next projected element.
-- Raises an error if the selector returns nil.
---@return boolean
function MapEnumerator:moveNext()
    if self.state == 0 then
        self.state = 1
        self.position = 0
        self.enumerator = self.source:getEnumerator(self.isArray)
    end

    if self.state == 1 then
        local enumerator = self.enumerator
        if enumerator:moveNext() == true then
            self.index = enumerator.index -- Preserve index
            self.position = self.position + 1
            self.current = self.selector(enumerator.current, self.position)
            -- Only nil is rejected; `false` is a legitimate mapped value.
            assert(self.current ~= nil, 'Selected value must be non-nil')
            return true
        end

        self:finalize()
    end
    return false
end

-- Finalizes the underlying source enumerator as well as this one.
function MapEnumerator:finalize()
    if self.enumerator then
        self.enumerator:finalize()
    end
    Enumerator.finalize(self)
end

---@return MapEnumerator
function MapEnumerator:clone()
    return MapEnumerator.new(self.source, self.selector)
end

-- WHERE ENUMERATOR --
-- Lazily yields the elements of `source` for which
-- predicate(element, index) returns exactly `true`.
---@class WhereEnumerator : Enumerator
---@field source Enumerator
---@field predicate function
---@field enumerator Enumerator
local WhereEnumerator = setmetatable({}, { __index = Enumerator })
WhereEnumerator.__index = WhereEnumerator
WhereEnumerator.__pairs = Enumerator_mt.__pairs
WhereEnumerator.__ipairs = Enumerator_mt.__ipairs

---@param source Enumerator
---@param predicate function
---@return WhereEnumerator
function WhereEnumerator.new(source, predicate)
    assert(source, 'Source cannot be nil')
    assert(predicate, 'Predicate cannot be nil')
    local self = setmetatable(Enumerator.new(), WhereEnumerator)
    self.source, self.predicate = source, predicate
    self.enumerator = nil
    return self
end

-- Advances to the next element that satisfies the predicate.
---@return boolean
function WhereEnumerator:moveNext()
    if self.state == 0 then
        self.state = 1
        self.position = 0 -- not read by moveNext
        self.enumerator = self.source:getEnumerator(self.isArray)
    end

    if self.state ~= 1 then
        return false
    end

    local inner = self.enumerator
    while inner:moveNext() == true do
        local value, key = inner.current, inner.index
        if self.predicate(value, key) == true then
            self.index, self.current = key, value
            return true
        end
    end

    -- Source exhausted without another match.
    self:finalize()
    return false
end

-- Finalizes the underlying source enumerator as well as this one.
function WhereEnumerator:finalize()
    local inner = self.enumerator
    if inner then
        inner:finalize()
    end
    Enumerator.finalize(self)
end

---@return WhereEnumerator
function WhereEnumerator:clone()
    return WhereEnumerator.new(self.source, self.predicate)
end

-- FLATMAP (SELECTMANY) ENUMERATOR --
-- For each element of `source`, calls selector(element, position) and yields
-- every entry of the returned table (or enumerable) in turn.
---@class FlatMapEnumerator : Enumerator
---@field source Enumerator
---@field selector function
---@field position integer
---@field enumerator Enumerator
---@field sourceEnumerator Enumerator
local FlatMapEnumerator = setmetatable({}, { __index = Enumerator })
FlatMapEnumerator.__index = FlatMapEnumerator
FlatMapEnumerator.__pairs = Enumerator_mt.__pairs
FlatMapEnumerator.__ipairs = Enumerator_mt.__ipairs

---@param source Enumerator
---@param selector function
---@return FlatMapEnumerator
function FlatMapEnumerator.new(source, selector)
    assert(source, 'Source cannot be nil')
    assert(selector, 'Selector cannot be nil')
    local self = setmetatable(Enumerator.new(), FlatMapEnumerator)
    self.source = source
    self.selector = selector
    self.position = 0
    self.enumerator = nil        -- Enumerator of the source Enumerable
    self.sourceEnumerator = nil  -- Enumerator of the nested Enumerable
    return self
end

-- State machine:
--   0  not started: obtain the outer enumerator, then go to 3
--   3  obtain the next nested enumerator from the outer one, then go to 4
--   4  yield the next item of the current nested enumerator
---@return boolean
function FlatMapEnumerator:moveNext()
    if self.state == 0 then
        self.position = 0
        self.enumerator = self.source:getEnumerator(self.isArray)
        self.state = 3 -- signal to get (first) nested enumerator
    end

    -- The loop below only makes progress in states 3 and 4. Any other state
    -- (-4 once finalized, or -1 when moveNext() is called on an enumerator
    -- that getEnumerator() never prepared) has nothing to yield and would
    -- otherwise loop forever.
    if self.state ~= 3 and self.state ~= 4 then
        return false
    end

    while true do
        -- Grab next value from the current nested enumerator.
        if self.state == 4 then
            local nested = self.sourceEnumerator
            if nested:moveNext() then
                self.current = nested.current
                self.index = nested.index
                return true
            end
            -- Nested enumerator exhausted: clean it up and fetch the next one.
            nested:finalize()
            self.state = 3
        end

        -- Grab the next nested enumerator (state is always 3 at this point).
        if self.enumerator:moveNext() then
            self.position = self.position + 1
            local nestedSource = self.selector(self.enumerator.current, self.position)
            -- Nested tables are never treated as arrays.
            self.sourceEnumerator = getTableEnumerator(nestedSource, false)
            self.state = 4
        else
            -- The outer enumerator has no more elements.
            self:finalize()
            return false
        end
    end
end

-- Finalizes the outer and the current nested enumerator, then this one.
function FlatMapEnumerator:finalize()
    if self.enumerator then
        self.enumerator:finalize()
    end
    if self.sourceEnumerator then
        self.sourceEnumerator:finalize()
    end

    Enumerator.finalize(self)
end

---@return FlatMapEnumerator
function FlatMapEnumerator:clone()
    return FlatMapEnumerator.new(self.source, self.selector)
end

-- CONCAT ENUMERATOR --
-- Yields every element of `first`, followed by every element of `second`.
---@class ConcatEnumerator : Enumerator
---@field first Enumerator
---@field second Enumerator|table
local ConcatEnumerator = setmetatable({}, { __index = Enumerator })
ConcatEnumerator.__index = ConcatEnumerator
ConcatEnumerator.__pairs = Enumerator_mt.__pairs
ConcatEnumerator.__ipairs = Enumerator_mt.__ipairs

---@param first Enumerator
---@param second Enumerator|table
---@return ConcatEnumerator
function ConcatEnumerator.new(first, second)
    assert(first, 'First cannot be nil')
    assert(second, 'Second cannot be nil')
    local self = setmetatable(Enumerator.new(), ConcatEnumerator)
    self.first, self.second = first, second
    self.enumerator = nil
    return self
end

-- Returns the enumerable at the 0-based position `index`, or nil past the
-- end. Only two are supported today, but the lookup keeps moveNext()
-- independent of how many enumerables are chained.
function ConcatEnumerator:getEnumerable(index)
    if index == 0 then return self.first end
    if index == 1 then return self.second end
    return nil
end

-- Advances through the active enumerable, switching to the next one when
-- it runs out. `state - 1` is the position of the active enumerable.
---@return boolean
function ConcatEnumerator:moveNext()
    if self.state == -4 then
        return false
    end

    if self.state == 0 then
        local head = self:getEnumerable(0)
        self.enumerator = head:getEnumerator(self.isArray)
        self.state = 1
    end

    if self.state <= 0 then
        return false
    end

    while true do
        local active = self.enumerator
        if active:moveNext() == true then
            self.index = active.index
            self.current = active.current
            return true
        end

        -- The active enumerable is exhausted; move on to the following one.
        self.state = self.state + 1
        local following = self:getEnumerable(self.state - 1)
        if following == nil then
            self:finalize()
            return false
        end
        -- Cleanup previous enumerator
        active:finalize()
        self.enumerator = getTableEnumerator(following, self.isArray)
    end
end

-- Finalizes the active enumerator as well as this one.
function ConcatEnumerator:finalize()
    local active = self.enumerator
    if active then
        active:finalize()
    end
    Enumerator.finalize(self)
end

---@return ConcatEnumerator
function ConcatEnumerator:clone()
    return ConcatEnumerator.new(self.first, self.second)
end

-- APPEND ENUMERATOR --
-- Yields the elements of `source` with one extra item added after them
-- (append == true, the default) or before them (append == false).
---@class AppendEnumerator : Enumerator
---@field source Enumerator
---@field item any
---@field itemIndex any
---@field append boolean
---@field enumerator Enumerator
local AppendEnumerator = setmetatable({}, { __index = Enumerator })
AppendEnumerator.__index = AppendEnumerator
AppendEnumerator.__pairs = Enumerator_mt.__pairs
AppendEnumerator.__ipairs = Enumerator_mt.__ipairs

---@param source Enumerator
---@param item any
---@param itemIndex any   index reported for the item; defaults to the item itself
---@param append boolean|nil  false to prepend instead; defaults to true
---@return AppendEnumerator
function AppendEnumerator.new(source, item, itemIndex, append)
    assert(source, 'Source cannot be nil')
    assert(item, 'Item cannot be nil')
    local self = setmetatable(Enumerator.new(), AppendEnumerator)
    self.source = source
    self.item = item
    -- Index needs *some* value, otherwise iteration stops.
    self.itemIndex = itemIndex
    if self.itemIndex == nil then
        self.itemIndex = item
    end
    self.append = append
    if self.append == nil then
        self.append = true
    end
    self.enumerator = nil
    return self
end

-- Publishes the extra item as the current element.
local function yieldExtraItem(self)
    self.current = self.item
    self.index = self.itemIndex
    return true
end

-- States: 0 not started, 1 about to open the source, 2 reading the source.
---@return boolean
function AppendEnumerator:moveNext()
    if self.state == 0 then
        self.state = 1
        -- Put the item in front, as a prepend.
        if self.append == false then
            return yieldExtraItem(self)
        end
    end

    -- Grab the source enumerator
    if self.state == 1 then
        self.enumerator = self.source:getEnumerator(self.isArray)
        self.state = 2
    end

    if self.state ~= 2 then
        return false
    end

    local inner = self.enumerator
    if inner:moveNext() then
        self.current = inner.current
        self.index = inner.index
        return true
    end

    -- Source exhausted: finalize, then emit the trailing item if appending.
    self:finalize()
    if self.append == true then
        return yieldExtraItem(self)
    end
    return false
end

-- Finalizes the source enumerator as well as this one.
function AppendEnumerator:finalize()
    local inner = self.enumerator
    if inner then
        inner:finalize()
    end
    Enumerator.finalize(self)
end

---@return AppendEnumerator
function AppendEnumerator:clone()
    return AppendEnumerator.new(self.source, self.item, self.itemIndex, self.append)
end

-- UNIQUE (DISTINCT) ENUMERATOR --
-- Yields the elements of `source` whose key has not been seen before. The
-- key is the element itself, or selector(element, index) when a selector is
-- given (DistinctBy).
---@class UniqueEnumerator : Enumerator
---@field source Enumerator
---@field selector function|nil
---@field enumerator Enumerator
---@field set table
local UniqueEnumerator = setmetatable({}, { __index = Enumerator })
UniqueEnumerator.__index = UniqueEnumerator
UniqueEnumerator.__pairs = Enumerator_mt.__pairs
UniqueEnumerator.__ipairs = Enumerator_mt.__ipairs

---@param source Enumerator
---@param selector function|nil
---@return UniqueEnumerator
function UniqueEnumerator.new(source, selector)
    assert(source, 'Source cannot be nil')
    local self = setmetatable(Enumerator.new(), UniqueEnumerator)
    self.source = source
    self.selector = selector
    self.enumerator = nil
    self.set = nil
    return self
end

-- States: 0 not started, 1 source element pending a uniqueness check,
-- 2 need to fetch the next source element.
---@return boolean
function UniqueEnumerator:moveNext()
    if self.state == 0 then
        self.enumerator = self.source:getEnumerator(self.isArray)
        -- If we have any items, create a hashtable. Otherwise abort.
        if self.enumerator:moveNext() ~= true then
            self:finalize()
            return false
        end
        self.set = {}
        self.state = 1
    end

    -- The loop below only makes progress in states 1 and 2. Any other state
    -- (-4 once finalized, or -1 when moveNext() is called on an enumerator
    -- that getEnumerator() never prepared) would otherwise loop forever.
    if self.state ~= 1 and self.state ~= 2 then
        return false
    end

    while true do
        if self.state == 1 then
            self.state = 2
            local key = self.enumerator.current
            -- Manipulate the item if we have a selector (DistinctBy)
            if self.selector then
                key = self.selector(key, self.enumerator.index)
            end

            -- NOTE: keys are used as table keys in the set, so they are
            -- assumed to be non-nil.
            if addToSet(self.set, key) then
                self.current = self.enumerator.current
                self.index = self.enumerator.index
                return true
            end
        end

        -- Try to grab a new item (state is always 2 here).
        if self.enumerator:moveNext() == true then
            self.state = 1
        else
            self:finalize()
            return false
        end
    end
end

-- Finalizes the source enumerator as well as this one.
function UniqueEnumerator:finalize()
    if self.enumerator then
        self.enumerator:finalize()
    end

    Enumerator.finalize(self)
end

---@return UniqueEnumerator
function UniqueEnumerator:clone()
    return UniqueEnumerator.new(self.source, self.selector)
end

-- EXCEPT ENUMERATOR --
-- Yields the distinct elements of `first` whose key is not among the values
-- of `second`. The key is the element itself, or selector(element, index)
-- when a selector is given. The selector is not applied to `second`.
---@class DifferenceEnumerator : Enumerator
---@field first Enumerator
---@field second any
---@field selector function|nil
---@field enumerator Enumerator
---@field set table
local DifferenceEnumerator = setmetatable({}, { __index = Enumerator })
DifferenceEnumerator.__index = DifferenceEnumerator
DifferenceEnumerator.__pairs = Enumerator_mt.__pairs
DifferenceEnumerator.__ipairs = Enumerator_mt.__ipairs

---@param first Enumerator
---@param second Enumerator|table
---@param selector function|nil
---@return DifferenceEnumerator
function DifferenceEnumerator.new(first, second, selector)
    assert(first, 'First table cannot be nil')
    assert(second, 'Second table cannot be nil')
    local self = setmetatable(Enumerator.new(), DifferenceEnumerator)
    self.first = first
    self.second = second
    self.selector = selector
    self.enumerator = nil
    self.set = nil
    return self
end

-- Builds a lookup set from the values (not the keys) of `other`, which may
-- be an Enumerator or a plain table.
local function createSet(other)
    local set = {}
    if isType(other, Enumerator) then
        for _, v in other:getPairs() do
            set[v] = true
        end
    else
        assert(type(other) == 'table', 'Source must be a table.')
        for _, v in pairs(other) do
            set[v] = true
        end
    end

    return set
end

-- Advances to the next element of `first` whose key is neither in `second`
-- nor already produced.
---@return boolean
function DifferenceEnumerator:moveNext()
    if self.state == 0 then
        self.enumerator = self.first:getEnumerator(self.isArray)
        self.set = createSet(self.second)
        self.state = 1
    end

    -- -4 (finalized) and -1 (never prepared by getEnumerator, so there is no
    -- source enumerator yet) have nothing to yield.
    if self.state ~= 1 then
        return false
    end

    local inner = self.enumerator
    while inner:moveNext() do
        local key = inner.current
        if self.selector ~= nil then
            key = self.selector(key, inner.index)
        end
        -- Keys from `second` and keys already produced are both in the set.
        if addToSet(self.set, key) then
            self.current = inner.current
            self.index = inner.index
            return true
        end
    end

    self:finalize()
    return false
end

-- Finalizes the source enumerator and marks this one finalized (state -4),
-- consistent with every other enumerator in this module.
function DifferenceEnumerator:finalize()
    if self.enumerator then
        self.enumerator:finalize()
    end
    Enumerator.finalize(self)
end

---@return DifferenceEnumerator
function DifferenceEnumerator:clone()
    return DifferenceEnumerator.new(self.first, self.second, self.selector)
end

-- UNION ENUMERATOR --
-- Yields the distinct elements of `first` followed by the elements of
-- `second` that were not already produced. The key used for distinctness is
-- the element itself, or selector(element, index) when a selector is given.
---@class UnionEnumerator : Enumerator
---@field first Enumerator
---@field second any
---@field selector function|nil
---@field enumerator Enumerator
---@field set table
local UnionEnumerator = setmetatable({}, { __index = Enumerator })
UnionEnumerator.__index = UnionEnumerator
UnionEnumerator.__pairs = Enumerator_mt.__pairs
UnionEnumerator.__ipairs = Enumerator_mt.__ipairs

---@param first Enumerator
---@param second Enumerator|table
---@param selector function|nil
---@return UnionEnumerator
function UnionEnumerator.new(first, second, selector)
    assert(first, 'First table cannot be nil')
    assert(second, 'Second table cannot be nil')
    local self = setmetatable(Enumerator.new(), UnionEnumerator)
    self.first = first
    self.second = second
    self.selector = selector
    self.enumerator = nil
    self.set = nil
    return self
end

-- Pulls elements from self.enumerator until one whose key is not yet in
-- self.set. On success, stores it in current/index, sets self.state to
-- `state` and returns true. Returns false once the enumerator is exhausted.
local function enumerateSource(self, state)
    while self.enumerator:moveNext() do
        local current = self.enumerator.current
        local index = self.enumerator.index
        if self.selector ~= nil then
            current = self.selector(current, index)
        end
        if addToSet(self.set, current) == true then
            self.current = self.enumerator.current
            self.index = self.enumerator.index
            self.state = state
            return true
        end
    end
    return false
end

-- States: 0 not started, 1 reading `first`, 2 reading `second`,
-- 3 done (followed by finalization to -4).
---@return boolean
function UnionEnumerator:moveNext()
    if self.state == -4 then
        return false
    end

    if self.state == 0 then
        self.enumerator = self.first:getEnumerator(self.isArray)
        self.set = {}
        self.state = 1
    end

    -- Process first
    if self.state == 1 then
        if enumerateSource(self, 1) == true then
            return true
        end

        -- End first enumeration and switch to the second source.
        self.state = 2
        self.enumerator:finalize()
        self.enumerator = getTableEnumerator(self.second, self.isArray)
    end

    -- Process second
    if self.state == 2 then
        if enumerateSource(self, 2) == true then
            return true
        end

        -- End second enumeration
        self.state = 3
    end

    self:finalize()
    return false
end

-- Finalizes the active source enumerator as well as this one.
function UnionEnumerator:finalize()
    if self.enumerator then
        self.enumerator:finalize()
    end

    Enumerator.finalize(self)
end

---@return UnionEnumerator
function UnionEnumerator:clone()
    return UnionEnumerator.new(self.first, self.second, self.selector)
end

--TODO:
-- INTERSECT
-- GROUPBY
-- ZIP


-- ORDER/SORTING???

--Optional:
---Take
---Skip

-- Define forward declared functions

-- Returns true when some metatable in `obj`'s metatable chain has `class` as
-- its __index, and false otherwise (including for values with no metatable).
isType = function(obj, class)
    local mt = getmetatable(obj)
    -- getmetatable returns the __metatable field when one is set, which need
    -- not be a table; indexing e.g. a number would raise, so stop walking.
    while type(mt) == 'table' do
        if mt.__index == class then
            return true
        end
        mt = getmetatable(mt)
    end
    return false
end

-- Returns an Enumerator for `sourceTable` that is ready for a fresh
-- iteration (state 0) in the requested mode.
--   * Enumerators are prepared through their own getEnumerator(), which
--     reuses the object if it was never handed out and clones it otherwise.
--     Handing the object out as-is would skip that step: its state would
--     still be -1 (or -4 after an earlier pass) and isArray would be ignored.
--   * Any other value is wrapped in a new TableEnumerator.
getTableEnumerator = function(sourceTable, isArray)
    if isType(sourceTable, Enumerator) then
        return sourceTable:getEnumerator(isArray)
    end
    return TableEnumerator.new(sourceTable):getEnumerator(isArray)
end

-- Inserts `item` into the hash set `set`.
-- Returns true if it was newly added, false if it was already present.
addToSet = function(set, item)
    if set[item] ~= nil then
        return false
    end
    set[item] = true
    return true
end



-- Public interface of this module: the base Enumerator and every derived
-- enumerator class, keyed by class name.
local exports = {
    Enumerator = Enumerator,
    TableEnumerator = TableEnumerator,
    MapEnumerator = MapEnumerator,
    WhereEnumerator = WhereEnumerator,
    FlatMapEnumerator = FlatMapEnumerator,
    ConcatEnumerator = ConcatEnumerator,
    AppendEnumerator = AppendEnumerator,
    UniqueEnumerator = UniqueEnumerator,
    DifferenceEnumerator = DifferenceEnumerator,
    UnionEnumerator = UnionEnumerator,
}

return exports