Module:FunList/Iterators: Difference between revisions

From Melvor Idle
No edit summary
(Add WhereEnumerator)
Line 25: Line 25:
-- Hooks the moveNext function into the Lua 'pairs' function
-- Hooks the moveNext function into the Lua 'pairs' function
function Enumerator:overridePairs()
function Enumerator:overridePairs()
mw.log('overriding pairs')
self.index = nil
self.index = nil
self.current = nil
self.current = nil
Line 95: Line 94:


function SelectEnumerator:moveNext()
function SelectEnumerator:moveNext()
mw.log('select movenext')
if self.state == 0 then
if self.state == 0 then
self.state = 1
self.state = 1
Line 121: Line 119:
else
else
return SelectEnumerator.new(self.source, self.selector)
return SelectEnumerator.new(self.source, self.selector)
end
end
-- WHERE ENUMERATOR --
local WhereEnumerator = setmetatable({}, { __index = Enumerator })
WhereEnumerator.__index = WhereEnumerator
WhereEnumerator.__pairs = Enumerator_mt.__pairs
function WhereEnumerator.new(source, predicate)
    local self = setmetatable(Enumerator.new(), WhereEnumerator)
    self.state = 0
    self.source = source
    self.predicate = predicate
    self.enumerator = nil
    return self
end
function WhereEnumerator:moveNext()
if self.state == 0 then
self.state = 1
self.position = 0
self.enumerator = self.source:getEnumerator()
end
if self.state == 1 then
local enumerator = self.enumerator
while enumerator:moveNext() == true do
local sourceElement = enumerator.current
if self.predicate(sourceElement) == true then
self.index = enumerator.index
self.current = sourceElement
return true
end
end
end
return false
end
function WhereEnumerator:getEnumerator()
if self.state == 0 then
return self
else
return WhereEnumerator.new(self.source, self.selector)
end
end
end
end
Line 127: Line 169:
     Enumerator = Enumerator,
     Enumerator = Enumerator,
     TableEnumerator = TableEnumerator,
     TableEnumerator = TableEnumerator,
     SelectEnumerator = SelectEnumerator
     SelectEnumerator = SelectEnumerator,
    WhereEnumerator = WhereEnumerator
}
}

Revision as of 18:34, 21 July 2024

Documentation for this module may be created at Module:FunList/Iterators/doc

-- BASE ENUMERATOR CLASS --
--enumerable = {}
local Enumerator = {}
local Enumerator_mt = {
	__index = Enumerator,
	__pairs = function(t) return t:overridePairs() end,
	--__ipairs = function(t) return t:overrideiPairs() end
}

function Enumerator.new()
	local self = setmetatable({}, Enumerator_mt)
	self.current = nil
	self.index = nil
	return self
end

function Enumerator:moveNext()
	error('Not implemented in base class.')
end

function Enumerator:getEnumerator()
	error('Not implemented in base class.')
end

-- Hooks the moveNext function into the Lua 'pairs' function
function Enumerator:overridePairs()
	self.index = nil
	self.current = nil
	local function iterator(t, k)
		if self:moveNext() == true then
			return self.index, self.current
		end
		return nil, nil
	end
	
	return iterator, self, nil
end

-- TABLE ENUMERATOR CLASS --
-- This is essentially a wrapper for the table object, 
-- since it provides no state machine esque iteration out of the box
local TableEnumerator = setmetatable({}, { __index = Enumerator })
TableEnumerator.__index = TableEnumerator
TableEnumerator.__pairs = Enumerator_mt.__pairs

function TableEnumerator.new(tbl)
    local self = setmetatable(Enumerator.new(), TableEnumerator)
    self.tbl = tbl
    return self
end

function TableEnumerator:moveNext()
	if self.state == 0 then
		self.state = 1
	end
	
	-- Grab the next index for the internal table.
	local key = self.index
	key = next(self.tbl, key)
	self.index = key
	
	-- If the index exist, we have succesfuly moved to the next.
	if key ~= nil then
		self.current = self.tbl[key]
		return true
	end
	
	return false
end

function TableEnumerator:getEnumerator()
	if self.state == 0 then
		return self
	else
		return TableEnumerator.new(self.tbl)
	end
		
end

-- SELECT ENUMERATOR --
local SelectEnumerator = setmetatable({}, { __index = Enumerator })
SelectEnumerator.__index = SelectEnumerator
SelectEnumerator.__pairs = Enumerator_mt.__pairs

function SelectEnumerator.new(source, selector)
    local self = setmetatable(Enumerator.new(), SelectEnumerator)
    self.state = 0
    self.source = source
    self.selector = selector
    self.enumerator = nil
    self.position = 0
    return self	
end

function SelectEnumerator:moveNext()
	if self.state == 0 then
		self.state = 1
		self.position = 0
		self.enumerator = self.source:getEnumerator()
	end
	
	if self.state == 1 then
		local enumerator = self.enumerator
		if enumerator:moveNext() == true then
			local sourceElement = enumerator.current
			self.index = enumerator.index -- Preserve index
			self.position = self.position + 1
			self.current = self.selector(sourceElement, self.position)
			assert(self.current, 'Selected value must be non-nil')
			return true
		end
	end
	return false
end

function SelectEnumerator:getEnumerator()
	if self.state == 0 then
		return self
	else
		return SelectEnumerator.new(self.source, self.selector)
	end
end

-- WHERE ENUMERATOR --
local WhereEnumerator = setmetatable({}, { __index = Enumerator })
WhereEnumerator.__index = WhereEnumerator
WhereEnumerator.__pairs = Enumerator_mt.__pairs

function WhereEnumerator.new(source, predicate)
    local self = setmetatable(Enumerator.new(), WhereEnumerator)
    self.state = 0
    self.source = source
    self.predicate = predicate
    self.enumerator = nil
    return self	
end

function WhereEnumerator:moveNext()
	if self.state == 0 then
		self.state = 1
		self.position = 0
		self.enumerator = self.source:getEnumerator()
	end
	
	if self.state == 1 then
		local enumerator = self.enumerator
		while enumerator:moveNext() == true do
			local sourceElement = enumerator.current
			if self.predicate(sourceElement) == true then
				self.index = enumerator.index
				self.current = sourceElement
				return true
			end
		end
	end
	
	return false
end

function WhereEnumerator:getEnumerator()
	if self.state == 0 then
		return self
	else
		return WhereEnumerator.new(self.source, self.selector)
	end
end

return {
    Enumerator = Enumerator,
    TableEnumerator = TableEnumerator,
    SelectEnumerator = SelectEnumerator,
    WhereEnumerator = WhereEnumerator
}