Module:FunList/Iterators: Difference between revisions

From Melvor Idle
No edit summary
No edit summary
Line 36: Line 36:


-- Hooks the moveNext function into the Lua 'pairs' function
-- Hooks the moveNext function into the Lua 'pairs' function
function Enumerator:overridePairs()
function Enumerator:overridePairs(startIndex)
self.index = nil
-- Get or create clean enumerator
self.current = nil
local enum = self:getEnumerator()
enum.index = startIndex
enum.current = nil
local function iterator(t, k)
local function iterator(t, k)
if self:moveNext() == true then
if enum:moveNext() == true then
return self.index, self.current
return enum.index, enum.current
end
end
return nil, nil
return nil, nil
end
end
return iterator, self, nil
return iterator, enum, startIndex
end
end



Revision as of 22:13, 21 July 2024

Documentation for this module may be created at Module:FunList/Iterators/doc

-- Helper Functions --
local function isType(obj, class)
    local mt = getmetatable(obj)
    while mt do
        if mt == class then
            return true
        end
        mt = getmetatable(mt)
    end
    return false
end

-- BASE ENUMERATOR CLASS --
--enumerable = {}
local Enumerator = {}
local Enumerator_mt = {
	__index = Enumerator,
	__pairs = function(t) return t:overridePairs() end,
	--__ipairs = function(t) return t:overrideiPairs() end
}

function Enumerator.new()
	local self = setmetatable({}, Enumerator_mt)
	self.current = nil
	self.index = nil
	return self
end

function Enumerator:moveNext()
	error('Not implemented in base class.')
end

function Enumerator:getEnumerator()
	error('Not implemented in base class.')
end

-- Hooks the moveNext function into the Lua 'pairs' function
function Enumerator:overridePairs(startIndex)
	-- Get or create clean enumerator
	local enum = self:getEnumerator()
	enum.index = startIndex
	enum.current = nil
	local function iterator(t, k)
		if enum:moveNext() == true then
			return enum.index, enum.current
		end
		return nil, nil
	end
	
	return iterator, enum, startIndex
end

-- TABLE ENUMERATOR CLASS --
-- This is essentially a wrapper for the table object, 
-- since it provides no state machine esque iteration out of the box
local TableEnumerator = setmetatable({}, { __index = Enumerator })
TableEnumerator.__index = TableEnumerator
TableEnumerator.__pairs = Enumerator_mt.__pairs

function TableEnumerator.new(tbl)
    local self = setmetatable(Enumerator.new(), TableEnumerator)
    self.tbl = tbl
    return self
end

function TableEnumerator:moveNext()
	if self.state == 0 then
		self.state = 1
	end
	
	-- Grab the next index for the internal table.
	local key = self.index
	key = next(self.tbl, key)
	self.index = key
	
	-- If the index exist, we have succesfuly moved to the next.
	if key ~= nil then
		self.current = self.tbl[key]
		return true
	end
	
	return false
end

function TableEnumerator:getEnumerator()
	if self.state == 0 then
		return self
	else
		return TableEnumerator.new(self.tbl)
	end
end

-- SELECT ENUMERATOR --
local SelectEnumerator = setmetatable({}, { __index = Enumerator })
SelectEnumerator.__index = SelectEnumerator
SelectEnumerator.__pairs = Enumerator_mt.__pairs

function SelectEnumerator.new(source, selector)
    local self = setmetatable(Enumerator.new(), SelectEnumerator)
    self.state = 0
    self.source = source
    self.selector = selector
    self.enumerator = nil
    self.position = 0
    return self	
end

function SelectEnumerator:moveNext()
	if self.state == 0 then
		self.state = 1
		self.position = 0
		self.enumerator = self.source:getEnumerator()
	end
	
	if self.state == 1 then
		local enumerator = self.enumerator
		if enumerator:moveNext() == true then
			local sourceElement = enumerator.current
			self.index = enumerator.index -- Preserve index
			self.position = self.position + 1
			self.current = self.selector(sourceElement, self.position)
			assert(self.current, 'Selected value must be non-nil')
			return true
		end
	end
	return false
end

function SelectEnumerator:getEnumerator()
	if self.state == 0 then
		return self
	else
		return SelectEnumerator.new(self.source, self.selector)
	end
end

-- WHERE ENUMERATOR --
local WhereEnumerator = setmetatable({}, { __index = Enumerator })
WhereEnumerator.__index = WhereEnumerator
WhereEnumerator.__pairs = Enumerator_mt.__pairs

function WhereEnumerator.new(source, predicate)
    local self = setmetatable(Enumerator.new(), WhereEnumerator)
    self.state = 0
    self.source = source
    self.predicate = predicate
    self.enumerator = nil
    return self	
end

function WhereEnumerator:moveNext()
	if self.state == 0 then
		self.state = 1
		self.position = 0
		self.enumerator = self.source:getEnumerator()
	end
	
	if self.state == 1 then
		local enumerator = self.enumerator
		while enumerator:moveNext() == true do
			local sourceElement = enumerator.current
			if self.predicate(sourceElement) == true then
				self.index = enumerator.index
				self.current = sourceElement
				return true
			end
		end
	end
	
	return false
end

function WhereEnumerator:getEnumerator()
	if self.state == 0 then
		return self
	else
		return WhereEnumerator.new(self.source, self.predicate)
	end
end

-- WHERE ENUMERATOR --
local SelectManyEnumerator = setmetatable({}, { __index = Enumerator })
SelectManyEnumerator.__index = SelectManyEnumerator
SelectManyEnumerator.__pairs = Enumerator_mt.__pairs

function SelectManyEnumerator.new(source, selector)
    local self = setmetatable(Enumerator.new(), SelectManyEnumerator)
    self.state = 0
    self.source = source
    self.selector = selector
    self.position = 0
    self.enumerator = nil		 -- Enumerator of the source Enumerable
    self.sourceEnumerator = nil  -- Enumerator of the nested Enumerable
    return self	
end

function SelectManyEnumerator:moveNext()
    while true do
        -- Setup state
        if self.state == 0 then
            self.position = 0
            self.enumerator = self.source:getEnumerator()
            self.state = -3 -- signal to get (first) nested enumerator
        end

        -- Grab next value from nested enumerator		
        if self.state == -4 then
            if self.sourceEnumerator:moveNext() then
                self.current = self.sourceEnumerator.current
                self.index = self.sourceEnumerator.index
                self.state = -4 -- signal to get next item
                return true
            else
                self.state = -3 -- signal to get next enumerator
            end
        end
		
		-- Grab nest nested enumerator
        if self.state == -3 then
            if self.enumerator:moveNext() then
                local current = self.enumerator.current
                self.position = self.position + 1

                local sourceTable = self.selector(current, self.position)
                if not isType(sourceTable, Enumerator) then
                	-- We need to turn the nested table into an enumerator
                	self.sourceEnumerator = TableEnumerator.new(sourceTable):getEnumerator()
                else
                	self.sourceEnumerator = sourceTable
                end
                self.state = -4 -- signal to get next item
            else
            	-- enumerator doesn't have any more nested enumerators.
                return false
            end
        end
    end
end


function SelectManyEnumerator:getEnumerator()
	if self.state == 0 then
		return self
	else
		return SelectManyEnumerator.new(self.source, self.predicate)
	end
end

return {
    Enumerator = Enumerator,
    TableEnumerator = TableEnumerator,
    SelectEnumerator = SelectEnumerator,
    WhereEnumerator = WhereEnumerator,
    SelectManyEnumerator = SelectManyEnumerator
}