Module:FunList/Iterators: Difference between revisions

From Melvor Idle
(Add ConcatEnumerator)
No edit summary
Line 12: Line 12:


-- BASE ENUMERATOR CLASS --
-- BASE ENUMERATOR CLASS --
--enumerable = {}
local Enumerator = {}
local Enumerator = {}
local Enumerator_mt = {
local Enumerator_mt = {
Line 40: Line 39:
function Enumerator:overridePairs(startIndex)
function Enumerator:overridePairs(startIndex)
-- Get or create clean enumerator. This ensures the state is 0.
-- Get or create clean enumerator. This ensures the state is 0.
local enum = self:getEnumerator(isArray)
local enum = self:getEnumerator(startIndex == 0)
enum.current = nil
enum.current = nil
enum.index = startIndex
enum.index = startIndex

Revision as of 00:17, 22 July 2024

Documentation for this module may be created at Module:FunList/Iterators/doc

-- Helper Functions --
local function isType(obj, class)
    local mt = getmetatable(obj)
    while mt do
        if mt == class then
            return true
        end
        mt = getmetatable(mt)
    end
    return false
end

-- BASE ENUMERATOR CLASS --
local Enumerator = {}
local Enumerator_mt = {
	__index = Enumerator,
	__pairs = function(t) return t:overridePairs() end,
	__ipairs = function(t) return t:overridePairs(0) end
}

function Enumerator.new()
	local self = setmetatable({}, Enumerator_mt)
	self.current = nil
	self.index = nil
	-- Assume by default we are not dealing with a simple array
	self.isArray = false
	return self
end

function Enumerator:moveNext()
	error('Not implemented in base class.')
end

function Enumerator:getEnumerator(isArray)
	error('Not implemented in base class.')
end

-- Hooks the moveNext function into the Lua 'pairs' function
function Enumerator:overridePairs(startIndex)
	-- Get or create clean enumerator. This ensures the state is 0.
	local enum = self:getEnumerator(startIndex == 0)
	enum.current = nil
	enum.index = startIndex
	local function iterator(t, k)
		if enum:moveNext() == true then
			return enum.index, enum.current
		end
		return nil, nil
	end
	
	return iterator, enum, enum.index
end

-- TABLE ENUMERATOR CLASS --
-- This is essentially a wrapper for the table object, 
-- since it provides no state machine esque iteration out of the box
local TableEnumerator = setmetatable({}, { __index = Enumerator })
TableEnumerator.__index = TableEnumerator
TableEnumerator.__pairs =  Enumerator_mt.__pairs
TableEnumerator.__ipairs = Enumerator_mt.__ipairs

function TableEnumerator.new(tbl)
    local self = setmetatable(Enumerator.new(), TableEnumerator)
    self.tbl = tbl or {} -- Allow creation of empty enumerable
    self.state = 0

    return self
end

function TableEnumerator:moveNext()
    if self.state == 0 then
        self.state = 1
		self.index = self.isArray and 0 or nil
    end

    if self.isArray == true then
        -- Iterate using ipairs, starting from index 1
        self.index = self.index + 1
        self.current = self.tbl[self.index]
        return self.current ~= nil
    else
        -- Iterate using pairs
        local key = self.index
        key = next(self.tbl, key)
        self.index = key
        if key ~= nil then
            self.current = self.tbl[key]
            return true
        end
    end

    return false
end

-- startIndex is used to determine if the underlying table should be treated
-- as an array or as a mixed table. It is ignored in the other enumerators as
-- they just call moveNext on the enumerator instead.
function TableEnumerator:getEnumerator(isArray)
    local instance = (self.state == 0) and self or TableEnumerator.new(self.tbl)
    instance.isArray = isArray
    return instance
end

-- SELECT ENUMERATOR --
local SelectEnumerator = setmetatable({}, { __index = Enumerator })
SelectEnumerator.__index = SelectEnumerator
SelectEnumerator.__pairs = Enumerator_mt.__pairs
SelectEnumerator.__ipairs = Enumerator_mt.__ipairs

function SelectEnumerator.new(source, selector)
	assert(source, 'Source cannot be nil')
	assert(selector, 'Selector cannot be nil')
    local self = setmetatable(Enumerator.new(), SelectEnumerator)
    self.state = 0
    self.source = source
    self.selector = selector
    self.enumerator = nil
    self.position = 0
    return self	
end

function SelectEnumerator:moveNext()
	if self.state == 0 then
		self.state = 1
		self.position = 0
		self.enumerator = self.source:getEnumerator(self.isArray)
	end
	
	if self.state == 1 then
		local enumerator = self.enumerator
		if enumerator:moveNext() == true then
			local sourceElement = enumerator.current
			self.index = enumerator.index -- Preserve index
			self.position = self.position + 1
			self.current = self.selector(sourceElement, self.position)
			assert(self.current, 'Selected value must be non-nil')
			return true
		end
	end
	return false
end

function SelectEnumerator:getEnumerator(isArray)
    local instance = (self.state == 0) and self or SelectEnumerator.new(self.source, self.selector)
    instance.isArray = isArray
    return instance
end

-- WHERE ENUMERATOR --
local WhereEnumerator = setmetatable({}, { __index = Enumerator })
WhereEnumerator.__index = WhereEnumerator
WhereEnumerator.__pairs =  Enumerator_mt.__pairs
WhereEnumerator.__ipairs = Enumerator_mt.__ipairs

function WhereEnumerator.new(source, predicate)
	assert(source, 'Source cannot be nil')
	assert(predicate, 'Predicate cannot be nil')
    local self = setmetatable(Enumerator.new(), WhereEnumerator)
    self.state = 0
    self.source = source
    self.predicate = predicate
    self.enumerator = nil
    return self	
end

function WhereEnumerator:moveNext()
	if self.state == 0 then
		self.state = 1
		self.position = 0
		self.enumerator = self.source:getEnumerator(self.isArray)
	end
	
	if self.state == 1 then
		local enumerator = self.enumerator
		while enumerator:moveNext() == true do
			local sourceElement = enumerator.current
			if self.predicate(sourceElement) == true then
				self.index = enumerator.index
				self.current = sourceElement
				return true
			end
		end
	end
	
	return false
end

function WhereEnumerator:getEnumerator()
    local instance = (self.state == 0) and self or WhereEnumerator.new(self.source, self.predicate)
    instance.isArray = isArray
    return instance
end

-- SELECTMANY ENUMERATOR --
local SelectManyEnumerator = setmetatable({}, { __index = Enumerator })
SelectManyEnumerator.__index = SelectManyEnumerator
SelectManyEnumerator.__pairs = Enumerator_mt.__pairs
SelectManyEnumerator.__ipairs = Enumerator_mt.__ipairs

function SelectManyEnumerator.new(source, selector)
	assert(source, 'Source cannot be nil')
	assert(selector, 'Selector cannot be nil')
    local self = setmetatable(Enumerator.new(), SelectManyEnumerator)
    self.state = 0
    self.source = source
    self.selector = selector
    self.position = 0
    self.enumerator = nil		 -- Enumerator of the source Enumerable
    self.sourceEnumerator = nil  -- Enumerator of the nested Enumerable
    return self	
end

function SelectManyEnumerator:moveNext()
    while true do
        -- Setup state
        if self.state == 0 then
            self.position = 0
            self.enumerator = self.source:getEnumerator(self.isArray)
            self.state = -3 -- signal to get (first) nested enumerator
        end

        -- Grab next value from nested enumerator		
        if self.state == -4 then
            if self.sourceEnumerator:moveNext() then
                self.current = self.sourceEnumerator.current
                self.index = self.sourceEnumerator.index
                self.state = -4 -- signal to get next item
                return true
            else
                self.state = -3 -- signal to get next enumerator
            end
        end
		
		-- Grab nest nested enumerator
        if self.state == -3 then
            if self.enumerator:moveNext() then
                local current = self.enumerator.current
                self.position = self.position + 1

                local sourceTable = self.selector(current, self.position)
                if not isType(sourceTable, Enumerator) then
                	-- We need to turn the nested table into an enumerator
                	self.sourceEnumerator = TableEnumerator.new(sourceTable)
                		:getEnumerator(self.isArray)
                else
                	self.sourceEnumerator = sourceTable
                end
                self.state = -4 -- signal to get next item
            else
            	-- enumerator doesn't have any more nested enumerators.
                return false
            end
        end
    end
end


function SelectManyEnumerator:getEnumerator(isArray)
	local instance = (self.state == 0) and self or SelectManyEnumerator.new(self.source, self.selector)
    instance.isArray = isArray
    return instance
end

-- CONCAT ENUMERATOR --
local ConcatEnumerator = setmetatable({}, { __index = Enumerator })
ConcatEnumerator.__index = ConcatEnumerator
ConcatEnumerator.__pairs = Enumerator_mt.__pairs
ConcatEnumerator.__ipairs = Enumerator_mt.__ipairs

function ConcatEnumerator.new(first, second)
	assert(first, 'First cannot be nil')
	assert(second, 'Second cannot be nil')
    local self = setmetatable(Enumerator.new(), ConcatEnumerator)
    self.state = 0
    self.first = first
    self.second = second
    self.enumerator = nil
    return self	
end

-- Function to grab enumerators in order. We currently only have two so this
-- seems a bit redundant. But it would allow to add N-amount of enumerables.
function ConcatEnumerator:getEnumerable(index)
	if index == 0 then 
		return self.first
	elseif index == 1 then
		return self.second
	else
		return nil
	end
end

function ConcatEnumerator:moveNext()
	if self.state == 0 then
		self.enumerator = self:getEnumerable(self.state)
			:getEnumerator(self.isArray)
		self.state = 1
	end
	
	if self.state > 0 then
		while true do
			if self.enumerator:moveNext() == true then
				self.index = self.enumerator.index
				self.current = self.enumerator.current
				return true
			end

			self.state = self.state + 1
			local next = self:getEnumerable(self.state - 1)
			if next ~= nil then
				self.enumerator = next:getEnumerator(self.isArray)
			else
				return false
			end
		end
	end

	return false
end

function ConcatEnumerator:getEnumerator(isArray)
    local instance = (self.state == 0) and self or ConcatEnumerator.new(self.first, self.second)
    instance.isArray = isArray
    return instance
end

return {
    Enumerator = Enumerator,
    TableEnumerator = TableEnumerator,
    SelectEnumerator = SelectEnumerator,
    WhereEnumerator = WhereEnumerator,
    SelectManyEnumerator = SelectManyEnumerator,
    ConcatEnumerator = ConcatEnumerator
}