Skip to content

Instantly share code, notes, and snippets.

@TheGreatSageEqualToHeaven
Created April 24, 2026 21:35
Show Gist options
  • Select an option

  • Save TheGreatSageEqualToHeaven/2725b1ab9ffb456d3499a97342324636 to your computer and use it in GitHub Desktop.

Select an option

Save TheGreatSageEqualToHeaven/2725b1ab9ffb456d3499a97342324636 to your computer and use it in GitHub Desktop.
evil test
--!nonstrict
-- This file is an intentionally small Luau-shaped mirror of the Ast and Compiler
-- projects. It keeps familiar public structures while emitting Luau bytecode
-- directly in Luau, so it can run in Studio without Lune or native compiler
-- bindings. The CLI self-test can optionally use Lune to load the produced
-- bytecode, but bytecode emission itself is pure Luau.
-- Root module table; the public classes below register themselves on it
-- (e.g. `Luau.Position = Position`).
local Luau = {}
Luau._VERSION = "lcompile.luau"
-- Forward declaration only; it is not assigned anywhere in this section.
-- NOTE(review): the purpose of LBF cannot be determined from here — confirm.
local LBF
-- Returns a new table with the same top-level key/value pairs as `source`.
-- Nested tables are shared rather than copied; a nil `source` yields {}.
local function shallowCopy(source)
	local copy = {}
	if not source then
		return copy
	end
	for key, value in pairs(source) do
		copy[key] = value
	end
	return copy
end
-- Appends `value` to the end of sequence `list` and hands `value` back so the
-- call can be used inline.
local function append(list, value)
	local slot = #list + 1
	list[slot] = value
	return value
end
-- True when `value` begins with `prefix` (plain byte comparison, no patterns).
local function startsWith(value, prefix)
	local head = string.sub(value, 1, #prefix)
	return head == prefix
end
-- Keeps only the first line of `message`, dropping any appended traceback.
local function trimStackTrace(message)
	local firstBreak = string.find(message, "\n", 1, true)
	if firstBreak == nil then
		return message
	end
	return string.sub(message, 1, firstBreak - 1)
end
-- Byte value at `index` (default 1) of `value`; 0 when the index is out of range.
local function byte(value, index)
	local code = string.byte(value, index or 1)
	if code == nil then
		return 0
	end
	return code
end
-- ASCII character-class predicates over byte codes (as produced by `byte`).
-- Bytes >= 128 never belong to any of these classes.

-- 'A'..'Z' or 'a'..'z'.
local function isAlphaCode(code)
	local upper = code >= 65 and code <= 90
	local lower = code >= 97 and code <= 122
	return upper or lower
end

-- '0'..'9'.
local function isDigitCode(code)
	return 48 <= code and code <= 57
end

-- '0'..'9', 'A'..'F' or 'a'..'f'.
local function isHexDigitCode(code)
	if isDigitCode(code) then
		return true
	end
	return (code >= 65 and code <= 70) or (code >= 97 and code <= 102)
end

-- Identifier start: a letter or '_'.
local function isNameStartCode(code)
	return isAlphaCode(code) or code == 95
end

-- Identifier continuation: a digit, a letter or '_'.
local function isNameContinueCode(code)
	return isDigitCode(code) or isNameStartCode(code)
end

-- Space, tab, LF, VT, FF or CR.
local function isSpaceCode(code)
	return code == 32 or code == 9 or code == 10 or code == 11 or code == 12 or code == 13
end
-- Clamps a numeric `value` into the byte range [0, 255].
local function clampByte(value)
	if value > 255 then
		return 255
	end
	if value < 0 then
		return 0
	end
	return value
end
-- Protected `require`: returns the module when it loads, otherwise nil plus
-- the error raised while loading.
local function tryRequire(name)
	local loaded, moduleOrError = pcall(require, name)
	if not loaded then
		return nil, moduleOrError
	end
	return moduleOrError
end
-- Positions and locations intentionally use the same zero-based line/column
-- convention as the C++ parser.
local Position = {}
Position.__index = Position

-- Creates a position; a missing line or column defaults to 0.
function Position.new(line, column)
	local instance = setmetatable({}, Position)
	instance.line = line or 0
	instance.column = column or 0
	return instance
end

-- Returns an independent copy of this position.
function Position.clone(self)
	local line, column = self.line, self.column
	return Position.new(line, column)
end

-- Returns a new position on the same line, shifted by `columns` columns.
function Position.offset(self, columns)
	local shiftedColumn = self.column + columns
	return Position.new(self.line, shiftedColumn)
end
Luau.Position = Position
local Location = {}
Location.__index = Location

-- Creates a source span. `endPosition` defaults to `beginPosition`; when both
-- are omitted each end gets its own fresh zero Position.
function Location.new(beginPosition, endPosition)
	local first = beginPosition or Position.new()
	local last = endPosition or beginPosition or Position.new()
	return setmetatable({ begin = first, ["end"] = last }, Location)
end

-- A zero-width span at line 0, column 0 (both ends share one Position).
function Location.empty()
	local origin = Position.new()
	return Location.new(origin, origin)
end

-- Span from this location's start to `other`'s end.
function Location.extend(self, other)
	local last = other["end"]
	return Location.new(self.begin, last)
end

-- True when `position` lies inside this span, both ends inclusive.
function Location.contains(self, position)
	local first, last = self.begin, self["end"]
	local line, column = position.line, position.column
	if line < first.line or line > last.line then
		return false
	end
	local beforeStart = line == first.line and column < first.column
	local afterEnd = line == last.line and column > last.column
	return not (beforeStart or afterEnd)
end
Luau.Location = Location
-- Number of lines in `source` (an empty string counts as one line).
-- "\r\n", a lone "\r" and a lone "\n" each count as a single line break. This
-- matches positionFromOffset and Lexer.consumeAny, so a parse result's `lines`
-- agrees with the end line of its root location even for CR-only sources.
local function lineCount(source)
	-- Fold every "\r\n" and lone "\r" into "\n", then count the "\n"s.
	local normalized = string.gsub(source, "\r\n?", "\n")
	local _, breaks = string.gsub(normalized, "\n", "")
	return breaks + 1
end
-- Converts a 1-based byte offset into a zero-based Position by walking the
-- bytes before it. "\r\n", a lone "\r" and "\n" each advance one line and
-- reset the column; any other byte advances the column by one.
local function positionFromOffset(source, targetOffset)
	local line, column = 0, 0
	local cursor = 1
	local limit = #source
	while cursor < targetOffset and cursor <= limit do
		local code = byte(source, cursor)
		if code == 10 or code == 13 then
			-- Treat a CR LF pair as one break by stepping over the LF too.
			if code == 13 and byte(source, cursor + 1) == 10 then
				cursor += 1
			end
			line += 1
			column = 0
		else
			column += 1
		end
		cursor += 1
	end
	return Position.new(line, column)
end
local CompileOptions = {}
CompileOptions.__index = CompileOptions
-- Defaults applied by CompileOptions.new. The nil entries only document the
-- optional fields; `pairs` never visits them.
local defaultCompileOptions = {
	optimizationLevel = 1,
	debugLevel = 1,
	typeInfoLevel = 0,
	coverageLevel = 0,
	vectorLib = nil,
	vectorCtor = nil,
	vectorType = nil,
	mutableGlobals = nil,
	userdataTypes = nil,
	librariesWithKnownMembers = nil,
	libraryMemberTypeCb = nil,
	libraryMemberConstantCb = nil,
	disabledBuiltins = nil,
}

-- Creates compile options: the defaults above, overridden by every field of
-- `options`. The short names `optimization`, `debug` and `coverage` are
-- accepted as aliases. Each one fills in the matching *Level field only when
-- the caller did not set that field explicitly.
function CompileOptions.new(options)
	local result = shallowCopy(defaultCompileOptions)
	if options then
		for key, value in pairs(options) do
			result[key] = value
		end
		local aliases = {
			optimization = "optimizationLevel",
			debug = "debugLevel",
			coverage = "coverageLevel",
		}
		for alias, canonical in pairs(aliases) do
			-- Inspect the caller's table, not `result`. `result` already holds
			-- the defaults, so its *Level fields are never nil and an alias
			-- tested against it could never take effect.
			if options[alias] ~= nil and options[canonical] == nil then
				result[canonical] = options[alias]
			end
		end
	end
	return setmetatable(result, CompileOptions)
end

-- Returns a plain table of the settings passed to the bytecode backend. The
-- four *Level fields are always present. The optional fields below are copied
-- only when set. The callback fields (libraryMemberTypeCb,
-- libraryMemberConstantCb) are not forwarded.
function CompileOptions:toBackendOptions()
	local result = {
		optimizationLevel = self.optimizationLevel,
		debugLevel = self.debugLevel,
		typeInfoLevel = self.typeInfoLevel,
		coverageLevel = self.coverageLevel,
	}
	local optionalKeys = {
		"vectorLib",
		"vectorCtor",
		"vectorType",
		"mutableGlobals",
		"userdataTypes",
		"librariesWithKnownMembers",
		"disabledBuiltins",
	}
	for _, key in ipairs(optionalKeys) do
		local value = self[key]
		if value ~= nil then
			result[key] = value
		end
	end
	return result
end
Luau.CompileOptions = CompileOptions
local ParseOptions = {}
ParseOptions.__index = ParseOptions
-- Baseline parse settings used by ParseOptions.new.
local defaultParseOptions = {
	captureComments = true,
	allowDeclarationSyntax = true,
	storeCstData = false,
}

-- Builds parse options from the defaults, with every field present in
-- `options` taking precedence.
function ParseOptions.new(options)
	local merged = shallowCopy(defaultParseOptions)
	for key, value in pairs(options or {}) do
		merged[key] = value
	end
	return setmetatable(merged, ParseOptions)
end
Luau.ParseOptions = ParseOptions
local ParseError = {}
ParseError.__index = ParseError

-- Creates a parse error that carries its source `location` and `message`.
function ParseError.new(location, message)
	local err = setmetatable({}, ParseError)
	err.location = location
	err.message = message
	return err
end

-- Returns the error message (named after the C++ `what()` accessor).
function ParseError.what(self)
	return self.message
end

-- Returns the source location the error refers to.
function ParseError.getLocation(self)
	return self.location
end
Luau.ParseError = ParseError
local CompileError = {}
CompileError.__index = CompileError

-- Creates a compile error that carries its source `location` and `message`.
function CompileError.new(location, message)
	local err = setmetatable({}, CompileError)
	err.location = location
	err.message = message
	return err
end

-- Returns the error message (named after the C++ `what()` accessor).
function CompileError.what(self)
	return self.message
end

-- Returns the source location the error refers to.
function CompileError.getLocation(self)
	return self.location
end
Luau.CompileError = CompileError
local Allocator = {}
Allocator.__index = Allocator

-- Creates an allocator that records every node passed to `alloc`.
function Allocator.new()
	local allocator = { nodes = {} }
	return setmetatable(allocator, Allocator)
end

-- Records `node` in `self.nodes` and returns it unchanged.
function Allocator.alloc(self, node)
	local nodes = self.nodes
	nodes[#nodes + 1] = node
	return node
end
Luau.Allocator = Allocator
local Lexeme = {}
Lexeme.__index = Lexeme
-- Numeric lexeme types. Values 1..255 (below Char_END) are single-character
-- tokens whose type is the character's byte value. Reserved_BEGIN and
-- Reserved_END bracket the keyword types; Reserved_BEGIN deliberately shares
-- its value with ReservedAnd.
Lexeme.Type = {
	Eof = 0,
	Char_END = 256,
	Equal = 257,
	LessEqual = 258,
	GreaterEqual = 259,
	NotEqual = 260,
	Dot2 = 261,
	Dot3 = 262,
	SkinnyArrow = 263,
	DoubleColon = 264,
	FloorDiv = 265,
	InterpStringBegin = 266,
	InterpStringMid = 267,
	InterpStringEnd = 268,
	InterpStringSimple = 269,
	AddAssign = 270,
	SubAssign = 271,
	MulAssign = 272,
	DivAssign = 273,
	FloorDivAssign = 274,
	ModAssign = 275,
	PowAssign = 276,
	ConcatAssign = 277,
	RawString = 278,
	QuotedString = 279,
	Number = 280,
	Name = 281,
	Comment = 282,
	BlockComment = 283,
	Attribute = 284,
	AttributeOpen = 285,
	BrokenString = 286,
	BrokenComment = 287,
	BrokenUnicode = 288,
	BrokenInterpDoubleBrace = 289,
	Error = 290,
	Reserved_BEGIN = 291,
	ReservedAnd = 291,
	ReservedBreak = 292,
	ReservedDo = 293,
	ReservedElse = 294,
	ReservedElseif = 295,
	ReservedEnd = 296,
	ReservedFalse = 297,
	ReservedFor = 298,
	ReservedFunction = 299,
	ReservedIf = 300,
	ReservedIn = 301,
	ReservedLocal = 302,
	ReservedNil = 303,
	ReservedNot = 304,
	ReservedOr = 305,
	ReservedRepeat = 306,
	ReservedReturn = 307,
	ReservedThen = 308,
	ReservedTrue = 309,
	ReservedUntil = 310,
	ReservedWhile = 311,
	Reserved_END = 312,
}
-- Reverse lookup from numeric type to its name, used by Lexeme:toString.
-- Reserved_BEGIN and ReservedAnd share a value, and `pairs` order is
-- unspecified. A plain inversion could therefore name the `and` keyword
-- "Reserved_BEGIN". Sentinel names (ending in _BEGIN/_END) are used only when
-- no real type has claimed their value.
Lexeme.TypeName = {}
for name, value in pairs(Lexeme.Type) do
	local isSentinel = string.sub(name, -6) == "_BEGIN" or string.sub(name, -4) == "_END"
	if not isSentinel or Lexeme.TypeName[value] == nil then
		Lexeme.TypeName[value] = name
	end
end
-- Maps each reserved keyword to its Lexeme.Type. Every keyword's type is named
-- "Reserved" plus the keyword with its first letter upper-cased
-- (e.g. "elseif" -> ReservedElseif).
local reservedTypes = {}
for _, keyword in ipairs({
	"and", "break", "do", "else", "elseif", "end", "false",
	"for", "function", "if", "in", "local", "nil", "not",
	"or", "repeat", "return", "then", "true", "until", "while",
}) do
	local typeName = "Reserved" .. string.upper(string.sub(keyword, 1, 1)) .. string.sub(keyword, 2)
	reservedTypes[keyword] = Lexeme.Type[typeName]
end
-- Creates a lexeme. `data` (token text or decoded string contents) is mirrored
-- into `name`. `length` defaults to #data for string data and 0 otherwise.
function Lexeme.new(location, lexemeType, data, length)
	local resolvedLength = length
	if not resolvedLength then
		resolvedLength = if type(data) == "string" then #data else 0
	end
	local lexeme = {
		type = lexemeType,
		location = location,
		data = data,
		name = data,
		length = resolvedLength,
	}
	return setmetatable(lexeme, Lexeme)
end

-- Stored length, or 0 when none was recorded.
function Lexeme.getLength(self)
	return self.length or 0
end

-- Printable form for diagnostics: "<eof>", the character itself for
-- single-character tokens, the text for names/numbers/strings, and the type
-- name (or "<lexeme>") for everything else.
function Lexeme.toString(self)
	local kind = self.type
	local Type = Lexeme.Type
	if kind == Type.Eof then
		return "<eof>"
	elseif kind > 0 and kind < Type.Char_END then
		return string.char(kind)
	elseif kind == Type.Name then
		return self.name or self.data or "<name>"
	elseif kind == Type.Number then
		return self.data or "<number>"
	elseif kind == Type.QuotedString or kind == Type.RawString then
		return self.data or "<string>"
	end
	return Lexeme.TypeName[kind] or "<lexeme>"
end
Luau.Lexeme = Lexeme
local AstNameTable = {}
AstNameTable.__index = AstNameTable

-- Creates a name table pre-seeded with every reserved keyword.
function AstNameTable.new()
	local nameTable = setmetatable({ data = {}, reserved = reservedTypes }, AstNameTable)
	for keyword, keywordType in pairs(reservedTypes) do
		nameTable:addStatic(keyword, keywordType)
	end
	return nameTable
end

-- Registers `name` with an explicit lexeme type (Name when omitted) and
-- returns the name.
function AstNameTable.addStatic(self, name, lexemeType)
	self.data[name] = { value = name, type = lexemeType or Lexeme.Type.Name }
	return name
end

-- Interns the first `length` bytes of `name` (all of it when `length` is nil).
-- Returns the interned name and its lexeme type (a Reserved* type for
-- keywords, Name otherwise).
function AstNameTable.getOrAddWithType(self, name, length)
	local key = string.sub(name, 1, length or #name)
	local entry = self.data[key]
	if entry == nil then
		entry = { value = key, type = reservedTypes[key] or Lexeme.Type.Name }
		self.data[key] = entry
	end
	return entry.value, entry.type
end

-- Looks up the first `length` bytes of `name` without interning. Returns
-- nil, Name when it is unknown.
function AstNameTable.getWithType(self, name, length)
	local entry = self.data[string.sub(name, 1, length or #name)]
	if not entry then
		return nil, Lexeme.Type.Name
	end
	return entry.value, entry.type
end

-- Interns `name` and returns only the interned name.
function AstNameTable.getOrAdd(self, name, length)
	return (self:getOrAddWithType(name, length))
end

-- Returns the interned name for `name`, or nil when it was never added.
function AstNameTable.get(self, name)
	local entry = self.data[name]
	if entry then
		return entry.value
	end
	return nil
end
Luau.AstNameTable = AstNameTable
local Lexer = {}
Lexer.__index = Lexer

-- Creates a lexer over `buffer`, truncated to `bufferSize` bytes when given.
-- `offset` is a 1-based byte index. `lineOffset` is chosen so that
-- `offset - lineOffset` is the zero-based column, seeded from `startPosition`
-- (default line 0, column 0).
function Lexer.new(buffer, bufferSize, names, startPosition)
	local source = buffer
	if bufferSize then
		source = string.sub(buffer, 1, bufferSize)
	end
	local start = startPosition or Position.new()
	local lexer = {
		buffer = source,
		bufferSize = #source,
		offset = 1,
		line = start.line,
		lineOffset = 1 - start.column,
		lexeme = Lexeme.new(Location.empty(), Lexeme.Type.Eof),
		prevLocation = Location.empty(),
		names = names or AstNameTable.new(),
		skipComments = true,
		readNames = true,
		braceStack = {},
	}
	return setmetatable(lexer, Lexer)
end
-- Sets whether `next` skips comment lexemes when its argument is omitted.
function Lexer:setSkipComments(skip)
	self.skipComments = skip
end

-- Stores the read-names flag.
function Lexer:setReadNames(read)
	self.readNames = read
end

-- The lexeme most recently produced by `next`.
function Lexer:current()
	return self.lexeme
end

-- Location of the lexeme that was current before the last `next`.
function Lexer:previousLocation()
	return self.prevLocation
end

-- 1-based byte offset of the cursor.
function Lexer:getOffset()
	return self.offset
end

-- Zero-based line/column of the cursor.
function Lexer:position()
	local column = self.offset - self.lineOffset
	return Position.new(self.line, column)
end
-- Character `lookahead` bytes past the cursor (default 0) as a one-character
-- string, or "" beyond the end of the buffer.
function Lexer:peekch(lookahead)
	local index = self.offset + (lookahead or 0)
	if index <= self.bufferSize then
		return string.sub(self.buffer, index, index)
	end
	return ""
end

-- Byte code `lookahead` bytes past the cursor (default 0), or 0 out of range.
function Lexer:peekCode(lookahead)
	local index = self.offset + (lookahead or 0)
	return byte(self.buffer, index)
end
-- Advances exactly one byte without line bookkeeping. Callers use it only for
-- bytes they already know are not line breaks.
function Lexer.consume(self)
	self.offset += 1
end

-- Advances past one logical character and keeps line/column tracking correct.
-- "\r\n" is consumed as a single line break, and "\n" or a lone "\r" also start
-- a new line. Does nothing once the cursor is past the end of the buffer.
function Lexer.consumeAny(self)
	-- Check the bounds explicitly rather than treating code 0 as end of input.
	-- `byte` also reports 0 for a literal NUL byte in the source, and refusing
	-- to step over it made every consumeAny-driven loop (long strings, block
	-- comments, nextline, interpolation expressions) spin forever.
	if self.offset > self.bufferSize then
		return
	end
	local code = self:peekCode()
	self.offset += 1
	if code == 13 then
		if self:peekCode() == 10 then
			self.offset += 1
		end
		self.line += 1
		self.lineOffset = self.offset
	elseif code == 10 then
		self.line += 1
		self.lineOffset = self.offset
	end
end
-- Consumes input through the next line break (inclusive) or to the end of the
-- buffer, whichever comes first.
function Lexer:nextline()
	local crossedBreak = false
	while not crossedBreak and self.offset <= self.bufferSize do
		local code = self:peekCode()
		self:consumeAny()
		crossedBreak = code == 10 or code == 13
	end
end
-- Span from `startPosition` to the current cursor position.
function Lexer:makeLocation(startPosition)
	local here = self:position()
	return Location.new(startPosition, here)
end
-- Checks for a long-bracket delimiter at the cursor: "[" or "]", zero or more
-- "=", then the same bracket again. Returns the number of "=" on a match and
-- -1 otherwise. Despite the name, the cursor is not moved.
function Lexer:skipLongSeparator()
	local bracket = self:peekch()
	if bracket ~= "[" and bracket ~= "]" then
		return -1
	end
	local equals = 0
	while self:peekch(equals + 1) == "=" do
		equals += 1
	end
	if self:peekch(equals + 1) ~= bracket then
		return -1
	end
	return equals
end
-- Reads a long-bracket string or comment body. The cursor must be on the
-- opening "[", and `separator` is the number of "=" between the brackets (as
-- returned by skipLongSeparator). A single line break directly after the
-- opening bracket is dropped. Returns an `okType` lexeme holding the text
-- between the brackets. If no matching closing bracket exists, returns a
-- `brokenType` lexeme holding the rest of the buffer.
function Lexer.readLongString(self, startPosition, separator, okType, brokenType)
-- "[" + separator "=" + "["; the closing bracket has the same length.
local openLength = separator + 2
for _ = 1, openLength do
self:consumeAny()
end
-- consumeAny treats "\r\n" as one step, so this drops a whole CR LF pair.
if self:peekCode() == 10 or self:peekCode() == 13 then
self:consumeAny()
end
local dataStart = self.offset
while self.offset <= self.bufferSize do
if self:peekch() == "]" then
-- Candidate close: "]" followed by exactly `separator` "=" and another "]".
local found = true
for index = 1, separator do
if self:peekch(index) ~= "=" then
found = false
break
end
end
if found and self:peekch(separator + 1) == "]" then
local dataEnd = self.offset - 1
for _ = 1, openLength do
self:consumeAny()
end
local data = string.sub(self.buffer, dataStart, dataEnd)
return Lexeme.new(self:makeLocation(startPosition), okType, data, #data)
end
end
self:consumeAny()
end
-- Reached end of input without a closing bracket.
local data = string.sub(self.buffer, dataStart)
return Lexeme.new(self:makeLocation(startPosition), brokenType, data, #data)
end
-- Consumes one backslash escape inside a quoted or interpolated string. The
-- cursor must be on the backslash. Only the extent of the escape is lexed here.
-- For quoted strings the escape is decoded afterwards by
-- Lexer.fixupQuotedString.
function Lexer.readBackslashInString(self)
	self:consumeAny() -- the backslash itself
	if self.offset > self.bufferSize then
		return
	end
	local code = self:peekCode()
	if code == 13 or code == 10 then
		-- Escaped line break; consumeAny keeps line tracking accurate.
		self:consumeAny()
	elseif code == byte("z") then
		-- \z skips the following run of whitespace, line breaks included.
		self:consume()
		while self.offset <= self.bufferSize and isSpaceCode(self:peekCode()) do
			self:consumeAny()
		end
	elseif code == byte("u") and self:peekch(1) == "{" then
		-- Consume only "u{" so an interpolated string does not mistake the brace
		-- for the start of an expression. The hex digits and closing "}" are
		-- ordinary string bytes left to the caller's loop. Scanning ahead for "}"
		-- here would run past the closing quote (and across line breaks without
		-- line tracking) when the escape is unterminated, e.g. "\u{41".
		self:consume()
		self:consume()
	else
		self:consume()
	end
end
-- Single-character escapes understood by Lexer.fixupQuotedString. Each key is
-- the character after the backslash; each value is the decoded character.
local escapeMap = {
	["a"] = "\a", -- bell
	["b"] = "\b", -- backspace
	["f"] = "\f", -- form feed
	["n"] = "\n", -- line feed
	["r"] = "\r", -- carriage return
	["t"] = "\t", -- horizontal tab
	["v"] = "\v", -- vertical tab
	["\\"] = "\\", -- backslash
	["\""] = "\"", -- double quote
	["'"] = "'", -- single quote
}
-- Encodes a non-negative integer code point as a UTF-8 byte string (1 to 4
-- bytes). Continuation bytes are produced from the low end upward, six bits
-- at a time, and whatever remains goes into the lead byte.
local function utf8Char(codepoint)
	if codepoint < 0x80 then
		return string.char(codepoint)
	end
	local lead, trailing
	if codepoint < 0x800 then
		lead, trailing = 0xC0, 1
	elseif codepoint < 0x10000 then
		lead, trailing = 0xE0, 2
	else
		lead, trailing = 0xF0, 3
	end
	local pieces = {}
	local rest = codepoint
	for slot = trailing + 1, 2, -1 do
		pieces[slot] = string.char(0x80 + rest % 0x40)
		rest = math.floor(rest / 0x40)
	end
	pieces[1] = string.char(lead + rest)
	return table.concat(pieces)
end
-- Decodes the escape sequences in the raw contents of a quoted string:
-- single-character escapes (escapeMap), \z whitespace skipping, decimal \ddd
-- (up to three digits, clamped to 255), \xXX, \u{XXXX} (UTF-8 encoded, up to
-- U+10FFFF) and escaped line breaks. Malformed escapes degrade to literal text
-- instead of raising. A trailing lone backslash is kept as-is.
function Lexer.fixupQuotedString(data)
	local output = {}
	local index = 1
	while index <= #data do
		local ch = string.sub(data, index, index)
		if ch ~= "\\" then
			append(output, ch)
			index += 1
		else
			local escape = string.sub(data, index + 1, index + 1)
			if escape == "" then
				append(output, "\\")
				index += 1
			elseif escapeMap[escape] then
				append(output, escapeMap[escape])
				index += 2
			elseif escape == "\r" then
				-- A backslash followed by CR LF, or by a lone CR, is one escaped
				-- line break and decodes to a single "\n" (as in the reference
				-- Luau lexer). The raw CR must not be copied through. A backslash
				-- followed by LF is handled by the final branch, which copies the LF.
				append(output, "\n")
				index += 2
				if string.sub(data, index, index) == "\n" then
					index += 1
				end
			elseif escape == "z" then
				index += 2
				while index <= #data and isSpaceCode(byte(data, index)) do
					index += 1
				end
			elseif isDigitCode(byte(escape)) then
				local digits = escape
				local digitIndex = index + 2
				while #digits < 3 and isDigitCode(byte(data, digitIndex)) do
					digits ..= string.sub(data, digitIndex, digitIndex)
					digitIndex += 1
				end
				append(output, string.char(clampByte(tonumber(digits) or 0)))
				index = digitIndex
			elseif escape == "x" and isHexDigitCode(byte(data, index + 2)) and isHexDigitCode(byte(data, index + 3)) then
				local hex = string.sub(data, index + 2, index + 3)
				append(output, string.char(tonumber(hex, 16) or 0))
				index += 4
			elseif escape == "u" and string.sub(data, index + 2, index + 2) == "{" then
				local close = string.find(data, "}", index + 3, true)
				if close then
					local hex = string.sub(data, index + 3, close - 1)
					local codepoint = tonumber(hex, 16)
					if codepoint and codepoint <= 0x10FFFF then
						append(output, utf8Char(codepoint))
						index = close + 1
					else
						append(output, "u")
						index += 2
					end
				else
					append(output, "u")
					index += 2
				end
			else
				append(output, escape)
				index += 2
			end
		end
	end
	return table.concat(output)
end
-- Strips one leading line break ("\r\n", "\n" or "\r") from long-string
-- contents. Anything else is returned unchanged.
function Lexer.fixupMultilineString(data)
	local skip = 0
	if startsWith(data, "\r\n") then
		skip = 2
	elseif startsWith(data, "\n") or startsWith(data, "\r") then
		skip = 1
	end
	if skip == 0 then
		return data
	end
	return string.sub(data, skip + 1)
end
-- Reads a '...' or "..." string starting at the opening quote. Escapes are
-- delimited by readBackslashInString, and on success the raw text is decoded
-- by fixupQuotedString, so a QuotedString lexeme holds the decoded contents.
-- An unescaped line break or end of input yields a BrokenString lexeme that
-- holds the raw, undecoded text read so far.
function Lexer.readQuotedString(self)
local quote = self:peekch()
local startPosition = self:position()
self:consume()
local dataStart = self.offset
-- Raw text accumulated in segments split around each escape sequence.
local pieces = {}
while self.offset <= self.bufferSize do
local ch = self:peekch()
local code = self:peekCode()
if ch == quote then
append(pieces, string.sub(self.buffer, dataStart, self.offset - 1))
self:consume()
local raw = table.concat(pieces)
local fixed = Lexer.fixupQuotedString(raw)
return Lexeme.new(self:makeLocation(startPosition), Lexeme.Type.QuotedString, fixed, #fixed)
elseif ch == "\\" then
-- Flush text before the escape, then copy the escape's raw bytes verbatim.
append(pieces, string.sub(self.buffer, dataStart, self.offset - 1))
local escapeStart = self.offset
self:readBackslashInString()
append(pieces, string.sub(self.buffer, escapeStart, self.offset - 1))
dataStart = self.offset
elseif code == 10 or code == 13 then
-- Unescaped line break: the string is unterminated on this line.
append(pieces, string.sub(self.buffer, dataStart, self.offset - 1))
local raw = table.concat(pieces)
return Lexeme.new(self:makeLocation(startPosition), Lexeme.Type.BrokenString, raw, #raw)
else
self:consume()
end
end
-- End of input before the closing quote.
append(pieces, string.sub(self.buffer, dataStart))
local raw = table.concat(pieces)
return Lexeme.new(self:makeLocation(startPosition), Lexeme.Type.BrokenString, raw, #raw)
end
-- Scans a backtick-delimited interpolated string as one lexeme, starting at
-- the opening backtick. The data is the raw text between the backticks;
-- neither escapes nor {expression} segments are decoded here. The type is
-- InterpStringSimple when no "{" was seen and InterpStringEnd when at least one
-- was. When the closing backtick is never found, the type is BrokenString and
-- the data is the rest of the buffer.
function Lexer.readInterpolatedString(self)
local startPosition = self:position()
-- First byte after the opening backtick.
local dataStart = self.offset + 1
local hasExpression = false
self:consume()
-- Forward-declared because skipBody and skipExpression are mutually recursive.
local skipExpression
-- Consumes string text up to and including the closing backtick and returns
-- true, or reaches the end of the buffer and returns false. Each "{" starts an
-- expression segment that is handed to skipExpression.
local function skipBody()
while self.offset <= self.bufferSize do
local ch = self:peekch()
local code = self:peekCode()
if ch == "\\" then
self:readBackslashInString()
elseif ch == "`" then
self:consume()
return true
elseif ch == "{" then
hasExpression = true
self:consume()
skipExpression()
elseif code == 10 or code == 13 then
self:consumeAny()
else
self:consume()
end
end
return false
end
-- Consumes an expression segment whose opening "{" was already consumed,
-- stopping after the matching "}". Quoted strings are skipped whole so braces
-- inside them do not count, and nested backtick strings recurse into skipBody.
skipExpression = function()
local depth = 1
while self.offset <= self.bufferSize and depth > 0 do
local ch = self:peekch()
if ch == "'" or ch == '"' then
self:readQuotedString()
elseif ch == "`" then
self:consume()
skipBody()
elseif ch == "{" then
depth += 1
self:consume()
elseif ch == "}" then
depth -= 1
self:consume()
else
self:consumeAny()
end
end
end
if skipBody() then
-- The cursor is one past the closing backtick, so offset - 2 is the last
-- byte of the string body.
local data = string.sub(self.buffer, dataStart, self.offset - 2)
return Lexeme.new(self:makeLocation(startPosition), hasExpression and Lexeme.Type.InterpStringEnd or Lexeme.Type.InterpStringSimple, data, #data)
end
local data = string.sub(self.buffer, dataStart)
return Lexeme.new(self:makeLocation(startPosition), Lexeme.Type.BrokenString, data, #data)
end
-- Reads an identifier or keyword at the cursor and interns it. Returns the
-- interned name and its lexeme type (Name or a Reserved* keyword type).
function Lexer:readName()
	local first = self.offset
	while isNameContinueCode(self:peekCode()) do
		self:consume()
	end
	local identifier = string.sub(self.buffer, first, self.offset - 1)
	return self.names:getOrAddWithType(identifier, #identifier)
end
-- Reads a numeric literal beginning at `startOffset` (the cursor). Handles
-- hex ("0x": hex digits, "." and "_", plus an optional "p" exponent), binary
-- ("0b": "0"/"1"/"_") and decimal (digits, "_", fraction, one "e" exponent).
-- Trailing identifier characters are then swallowed into the same lexeme, so
-- malformed input such as "12abc" stays one token. Returns a Number lexeme
-- whose data is the raw source text; no numeric value is computed here.
function Lexer.readNumber(self, startPosition, startOffset)
local sawExponent = false
if self:peekch() == "0" and (self:peekch(1) == "x" or self:peekch(1) == "X") then
self:consume()
self:consume()
while isHexDigitCode(self:peekCode()) or self:peekch() == "." or self:peekch() == "_" do
self:consume()
end
if self:peekch() == "p" or self:peekch() == "P" then
self:consume()
if self:peekch() == "+" or self:peekch() == "-" then
self:consume()
end
while isDigitCode(self:peekCode()) or self:peekch() == "_" do
self:consume()
end
end
elseif self:peekch() == "0" and (self:peekch(1) == "b" or self:peekch(1) == "B") then
self:consume()
self:consume()
while self:peekch() == "0" or self:peekch() == "1" or self:peekch() == "_" do
self:consume()
end
else
while self.offset <= self.bufferSize do
local ch = self:peekch()
local code = self:peekCode()
if isDigitCode(code) or ch == "_" then
self:consume()
-- A "." followed by another "." is the ".." operator, so "1..2" stops here.
elseif ch == "." and self:peekch(1) ~= "." then
self:consume()
elseif (ch == "e" or ch == "E") and not sawExponent then
sawExponent = true
self:consume()
if self:peekch() == "+" or self:peekch() == "-" then
self:consume()
end
else
break
end
end
end
-- Absorb any trailing identifier characters into the (possibly malformed) literal.
while isNameContinueCode(self:peekCode()) do
self:consume()
end
local text = string.sub(self.buffer, startOffset, self.offset - 1)
return Lexeme.new(self:makeLocation(startPosition), Lexeme.Type.Number, text, #text)
end
-- Reads a comment beginning at "--". A long-bracket opener makes it a
-- BlockComment (BrokenComment when unterminated). Otherwise it is a Comment
-- whose data runs to the end of the line, excluding the line break.
function Lexer:readCommentBody()
	local startPosition = self:position()
	self:consume() -- first "-"
	self:consume() -- second "-"
	local separator = if self:peekch() == "[" then self:skipLongSeparator() else -1
	if separator >= 0 then
		return self:readLongString(startPosition, separator, Lexeme.Type.BlockComment, Lexeme.Type.BrokenComment)
	end
	local bodyStart = self.offset
	while self.offset <= self.bufferSize do
		local code = self:peekCode()
		if code == 10 or code == 13 then
			break
		end
		self:consume()
	end
	local body = string.sub(self.buffer, bodyStart, self.offset - 1)
	return Lexeme.new(self:makeLocation(startPosition), Lexeme.Type.Comment, body, #body)
end
do
	-- Two-character operators mapped to their lexeme types. This table lives in
	-- a do-block upvalue, so it is built once at load time rather than for every
	-- token, and no module-level local is added. The three-character operators
	-- ("...", "..=", "//=") are matched explicitly before this lookup.
	local compoundLexemeTypes = {
		["=="] = Lexeme.Type.Equal,
		["<="] = Lexeme.Type.LessEqual,
		[">="] = Lexeme.Type.GreaterEqual,
		["~="] = Lexeme.Type.NotEqual,
		[".."] = Lexeme.Type.Dot2,
		["->"] = Lexeme.Type.SkinnyArrow,
		["::"] = Lexeme.Type.DoubleColon,
		["//"] = Lexeme.Type.FloorDiv,
		["+="] = Lexeme.Type.AddAssign,
		["-="] = Lexeme.Type.SubAssign,
		["*="] = Lexeme.Type.MulAssign,
		["/="] = Lexeme.Type.DivAssign,
		["%="] = Lexeme.Type.ModAssign,
		["^="] = Lexeme.Type.PowAssign,
	}

	-- Reads the next raw lexeme, comments included, after skipping leading
	-- whitespace. Returns Eof at the end of input. Dispatch order:
	-- names/keywords, numbers, quoted, interpolated and long strings, comments,
	-- multi-character operators, "@", then any other single byte (whose type
	-- is its byte value).
	function Lexer.readNext(self)
		while self.offset <= self.bufferSize do
			local code = self:peekCode()
			if isSpaceCode(code) then
				self:consumeAny()
			else
				break
			end
		end
		local startPosition = self:position()
		if self.offset > self.bufferSize then
			return Lexeme.new(Location.new(startPosition, startPosition), Lexeme.Type.Eof)
		end
		local ch = self:peekch()
		local nextCh = self:peekch(1)
		local startOffset = self.offset
		if isNameStartCode(self:peekCode()) then
			local name, lexemeType = self:readName()
			return Lexeme.new(self:makeLocation(startPosition), lexemeType, name, #name)
		end
		if isDigitCode(self:peekCode()) or (ch == "." and isDigitCode(self:peekCode(1))) then
			return self:readNumber(startPosition, startOffset)
		end
		if ch == "'" or ch == '"' then
			return self:readQuotedString()
		end
		if ch == "`" then
			return self:readInterpolatedString()
		end
		if ch == "[" then
			local separator = self:skipLongSeparator()
			if separator >= 0 then
				return self:readLongString(startPosition, separator, Lexeme.Type.RawString, Lexeme.Type.BrokenString)
			end
		end
		if ch == "-" and nextCh == "-" then
			return self:readCommentBody()
		end
		local two = ch .. nextCh
		local three = two .. self:peekch(2)
		if three == "..." then
			self:consume()
			self:consume()
			self:consume()
			return Lexeme.new(self:makeLocation(startPosition), Lexeme.Type.Dot3, three, 3)
		end
		if three == "..=" then
			self:consume()
			self:consume()
			self:consume()
			return Lexeme.new(self:makeLocation(startPosition), Lexeme.Type.ConcatAssign, three, 3)
		end
		if two == "//" and self:peekch(2) == "=" then
			self:consume()
			self:consume()
			self:consume()
			return Lexeme.new(self:makeLocation(startPosition), Lexeme.Type.FloorDivAssign, "//=", 3)
		end
		local compound = compoundLexemeTypes[two]
		if compound then
			self:consume()
			self:consume()
			return Lexeme.new(self:makeLocation(startPosition), compound, two, 2)
		end
		if ch == "@" then
			self:consume()
			return Lexeme.new(self:makeLocation(startPosition), Lexeme.Type.AttributeOpen, ch, 1)
		end
		self:consume()
		return Lexeme.new(self:makeLocation(startPosition), byte(ch), ch, 1)
	end
end
-- Advances to the next lexeme, makes it current and returns it.
-- skipComments: overrides self.skipComments when non-nil; when it is in effect,
--   Comment and BlockComment lexemes are skipped.
-- updatePrevLocation: unless explicitly false, the outgoing lexeme's location
--   is saved as prevLocation.
function Lexer:next(skipComments, updatePrevLocation)
	local skip = skipComments
	if skip == nil then
		skip = self.skipComments
	end
	local lexeme
	repeat
		lexeme = self:readNext()
		local isComment = lexeme.type == Lexeme.Type.Comment or lexeme.type == Lexeme.Type.BlockComment
	until not (skip and isComment)
	if updatePrevLocation ~= false then
		self.prevLocation = self.lexeme.location
	end
	self.lexeme = lexeme
	return lexeme
end
-- Returns the lexeme after the current one without consuming it. The cursor,
-- line tracking, current lexeme and prevLocation are saved around a call to
-- `next` and then restored.
function Lexer:lookahead()
	local savedOffset, savedLine, savedLineOffset = self.offset, self.line, self.lineOffset
	local savedLexeme, savedPrevLocation = self.lexeme, self.prevLocation
	local upcoming = self:next()
	self.offset, self.line, self.lineOffset = savedOffset, savedLine, savedLineOffset
	self.lexeme, self.prevLocation = savedLexeme, savedPrevLocation
	return upcoming
end
-- True when `word` is a reserved keyword (static; takes no lexer).
function Lexer.isReserved(word)
	if reservedTypes[word] == nil then
		return false
	end
	return true
end

-- Top entry of the brace-tracking stack, or nil when it is empty.
function Lexer:peekBraceStackTop()
	local stack = self.braceStack
	return stack[#stack]
end
Luau.Lexer = Lexer
local Ast = {}
-- Last class index handed out; each newly seen node kind gets the next integer.
local rttiCounter = 0

-- Registers `name` as a node kind in Ast and returns its new class index.
local function astKind(name)
	rttiCounter += 1
	Ast[name] = rttiCounter
	return rttiCounter
end

-- Stamps `fields` (or a fresh table) as an AST node of `kind`. The kind is
-- registered on first use. A missing location becomes an empty Location.
local function astNode(kind, location, fields)
	local node = fields or {}
	local classIndex = Ast[kind]
	if not classIndex then
		classIndex = astKind(kind)
	end
	node.classIndex = classIndex
	node.kind = kind
	node.location = location or Location.empty()
	return node
end

-- Ast.Kind aliases Ast itself, so registered kind indices are reachable as
-- Ast.Kind.<KindName>.
Ast.Kind = Ast
-- Wraps a name string as { value = ... }.
function Ast.Name(value)
	local name = {}
	name.value = value
	return name
end

-- Wraps a Lua sequence as { data, size }; a missing `values` gives an empty array.
function Ast.Array(values)
	if values then
		return { data = values, size = #values }
	end
	return { data = {}, size = 0 }
end

-- Describes a local variable binding. Depths default to 0 and `isConst` is
-- normalised to a boolean (true only when exactly true).
function Ast.Local(name, location, shadow, functionDepth, loopDepth, annotation, isConst)
	local localInfo = {}
	localInfo.name = name
	localInfo.location = location or Location.empty()
	localInfo.shadow = shadow
	localInfo.functionDepth = functionDepth or 0
	localInfo.loopDepth = loopDepth or 0
	localInfo.isConst = isConst == true
	localInfo.annotation = annotation
	return localInfo
end
-- Block statement holding a list of statements (empty when `body` is nil).
function Ast.StatBlock(location, body)
	local fields = { body = body or {} }
	return astNode("AstStatBlock", location, fields)
end

-- Statement that evaluates a single expression.
function Ast.StatExpr(location, expression)
	local fields = { expression = expression }
	return astNode("AstStatExpr", location, fields)
end

-- `local` declaration: variables and their initial values (each defaults to {}).
function Ast.StatLocal(location, vars, values)
	local fields = {
		vars = vars or {},
		values = values or {},
	}
	return astNode("AstStatLocal", location, fields)
end

-- `return` statement with its expression list (empty when `list` is nil).
function Ast.StatReturn(location, list)
	local fields = { list = list or {} }
	return astNode("AstStatReturn", location, fields)
end
-- `nil` literal.
function Ast.ExprConstantNil(location)
	return astNode("AstExprConstantNil", location, nil)
end

-- Boolean literal; `value` is normalised to a boolean (true only when exactly true).
function Ast.ExprConstantBool(location, value)
	local flag = value == true
	return astNode("AstExprConstantBool", location, { value = flag })
end

-- Numeric literal: the converted `value` plus its original source `text`.
function Ast.ExprConstantNumber(location, value, text)
	local fields = { value = value, text = text }
	return astNode("AstExprConstantNumber", location, fields)
end

-- String literal: decoded `value` plus an optional quote-style tag.
function Ast.ExprConstantString(location, value, quoteStyle)
	local fields = { value = value, quoteStyle = quoteStyle }
	return astNode("AstExprConstantString", location, fields)
end

-- Reference to a local binding (stored as `local_`, since `local` is a keyword).
function Ast.ExprLocal(location, local_)
	local fields = { local_ = local_ }
	return astNode("AstExprLocal", location, fields)
end

-- Reference to a global by name.
function Ast.ExprGlobal(location, name)
	local fields = { name = name }
	return astNode("AstExprGlobal", location, fields)
end
Luau.Ast = Ast
-- Converts one lexeme into a literal or global-reference expression node.
-- Returns nil for lexemes with no standalone expression form. For numbers,
-- "_" digit separators are removed before conversion; if conversion fails,
-- the node's value is nil and its text is kept.
local function makeLiteralFromLexeme(lexeme)
	local kind = lexeme.type
	local location = lexeme.location
	if kind == Lexeme.Type.ReservedNil then
		return Ast.ExprConstantNil(location)
	elseif kind == Lexeme.Type.ReservedTrue then
		return Ast.ExprConstantBool(location, true)
	elseif kind == Lexeme.Type.ReservedFalse then
		return Ast.ExprConstantBool(location, false)
	elseif kind == Lexeme.Type.Number then
		local cleaned = string.gsub(lexeme.data or "", "_", "")
		local prefix = string.sub(cleaned, 1, 2)
		local value
		if (prefix == "0b" or prefix == "0B") and #cleaned > 2 then
			-- Binary literals are converted explicitly in base 2. `tonumber` is
			-- not guaranteed to accept the "0b" prefix (standard Lua's does not),
			-- which previously left binary literals with a nil value.
			value = tonumber(string.sub(cleaned, 3), 2)
		else
			value = tonumber(cleaned)
		end
		return Ast.ExprConstantNumber(location, value, lexeme.data)
	elseif kind == Lexeme.Type.QuotedString or kind == Lexeme.Type.RawString or kind == Lexeme.Type.InterpStringSimple then
		return Ast.ExprConstantString(location, lexeme.data or "", nil)
	elseif kind == Lexeme.Type.Name then
		return Ast.ExprGlobal(location, lexeme.name)
	end
	return nil
end
local Parser = {}
Parser.__index = Parser

-- Creates a parser over `buffer`, truncated to `bufferSize` bytes when given.
-- A missing name table or allocator gets a fresh one, and `options` is
-- normalised through ParseOptions.new.
function Parser.new(buffer, bufferSize, names, allocator, options)
	local source = buffer
	if bufferSize then
		source = string.sub(buffer, 1, bufferSize)
	end
	local parser = {
		buffer = source,
		names = names or AstNameTable.new(),
		allocator = allocator or Allocator.new(),
		options = ParseOptions.new(options),
	}
	return setmetatable(parser, Parser)
end
-- Lexes the whole buffer and builds a flat root block. Every literal-like
-- lexeme (nil/true/false, number, string, name) becomes an expression
-- statement; other tokens produce no AST. Returns:
--   root     - the allocated AstStatBlock covering the whole buffer
--   tokens   - every lexeme read, comments included, ending with Eof
--   comments - { type, location, content } records, collected only when
--              options.captureComments is not false
function Parser:parseChunk()
	local lexer = Lexer.new(self.buffer, #self.buffer, self.names)
	lexer:setSkipComments(false)
	-- Honour ParseOptions.captureComments. Only an explicit false disables it,
	-- so parsers whose options lack the field keep capturing as before.
	local options = self.options
	local captureComments = not options or options.captureComments ~= false
	local body = {}
	local comments = {}
	local tokens = {}
	while true do
		local lexeme = lexer:next(false, true)
		append(tokens, lexeme)
		if lexeme.type == Lexeme.Type.Eof then
			break
		elseif lexeme.type == Lexeme.Type.Comment or lexeme.type == Lexeme.Type.BlockComment or lexeme.type == Lexeme.Type.BrokenComment then
			if captureComments then
				append(comments, {
					type = lexeme.type,
					location = lexeme.location,
					content = lexeme.data or "",
				})
			end
		else
			local expression = makeLiteralFromLexeme(lexeme)
			if expression then
				append(body, Ast.StatExpr(lexeme.location, expression))
			end
		end
	end
	local rootLocation = Location.new(Position.new(0, 0), positionFromOffset(self.buffer, #self.buffer + 1))
	local root = self.allocator:alloc(Ast.StatBlock(rootLocation, body))
	return root, tokens, comments
end
-- Parses `buffer` and returns a parse-result table with the root block, line
-- count, token list and comment records. `hotcomments`, `errors` and
-- `cstNodeMap` are always empty tables here.
function Parser.parse(buffer, bufferSize, names, allocator, options)
	local parser = Parser.new(buffer, bufferSize, names, allocator, options)
	local root, tokens, comments = parser:parseChunk()
	local result = {}
	result.root = root
	result.lines = lineCount(parser.buffer)
	result.tokens = tokens
	result.hotcomments = {}
	result.errors = {}
	result.commentLocations = comments
	result.cstNodeMap = {}
	return result
end
-- Parses `buffer` and reports the expression of its first statement as `root`
-- (nil when there is none). The other result fields are copied from
-- Parser.parse; `tokens` is not included.
function Parser.parseExpr(buffer, bufferSize, names, allocator, options)
	local parsed = Parser.parse(buffer, bufferSize, names, allocator, options)
	local rootBlock = parsed.root
	local firstStatement = rootBlock and rootBlock.body and rootBlock.body[1]
	local expression = nil
	if firstStatement then
		expression = firstStatement.expression
	end
	return {
		root = expression,
		lines = parsed.lines,
		hotcomments = parsed.hotcomments,
		errors = parsed.errors,
		commentLocations = parsed.commentLocations,
		cstNodeMap = parsed.cstNodeMap,
	}
end
Luau.Parser = Parser
local BytecodeBuilder = {}
BytecodeBuilder.__index = BytecodeBuilder
-- Bit flags selecting which sections a dump includes.
BytecodeBuilder.DumpFlags = {
	Dump_Code = 1,
	Dump_Lines = 2,
	Dump_Source = 4,
	Dump_Locals = 8,
	Dump_Remarks = 16,
	Dump_Types = 32,
	Dump_Constants = 64,
}

-- Creates an empty builder: no bytecode, no functions, strings or constants,
-- main function id 0 and no dump flags.
function BytecodeBuilder.new()
	local builder = setmetatable({}, BytecodeBuilder)
	builder.bytecode = ""
	builder.functions = {}
	builder.strings = {}
	builder.constants = {}
	builder.mainFunction = 0
	builder.dumpFlags = 0
	builder.source = nil
	return builder
end
-- Returns the stored bytecode blob.
function BytecodeBuilder.getBytecode(self)
	return self.bytecode
end

-- Replaces the stored bytecode blob.
function BytecodeBuilder.setBytecode(self, bytecodeBlob)
	self.bytecode = bytecodeBlob
end

-- Records the id of the chunk's main function.
function BytecodeBuilder.setMainFunction(self, functionId)
	self.mainFunction = functionId
end

-- Records the DumpFlags bitmask.
function BytecodeBuilder.setDumpFlags(self, flags)
	self.dumpFlags = flags
end

-- Records the source text to use when dumping.
function BytecodeBuilder.setDumpSource(self, source)
	self.source = source
end

-- Stores `bytecodeBlob` when it is non-nil, then returns the stored blob.
function BytecodeBuilder.finalize(self, bytecodeBlob)
	if bytecodeBlob == nil then
		return self.bytecode
	end
	self.bytecode = bytecodeBlob
	return self.bytecode
end
-- Encodes a compile error as a blob: a leading 0 byte followed by the message.
function BytecodeBuilder.getError(message)
	return "\0" .. message
end

-- True when `bytecodeBlob` is non-empty and starts with the 0 error marker.
function BytecodeBuilder.isError(bytecodeBlob)
	if #bytecodeBlob == 0 then
		return false
	end
	return byte(bytecodeBlob, 1) == 0
end

-- First byte of the blob (the bytecode version), or nil for a nil/empty blob.
function BytecodeBuilder.getVersion(bytecodeBlob)
	if not bytecodeBlob or #bytecodeBlob == 0 then
		return nil
	end
	return byte(bytecodeBlob, 1)
end

-- Second byte of the blob (the type-encoding version), or nil when the blob is
-- nil or shorter than two bytes.
function BytecodeBuilder.getTypeEncodingVersion(bytecodeBlob)
	if not bytecodeBlob or #bytecodeBlob < 2 then
		return nil
	end
	return byte(bytecodeBlob, 2)
end
Luau.BytecodeBuilder = BytecodeBuilder
local BytecodeReader = {}

-- Summarises a bytecode blob: version and size. A version of 0 marks an error
-- blob, whose message is reported as `error`. Any other version also reports
-- `typeVersion`.
function BytecodeReader.summary(bytecodeBlob)
	local version = BytecodeBuilder.getVersion(bytecodeBlob)
	local info = { version = version, size = #bytecodeBlob }
	if version == 0 then
		info.error = string.sub(bytecodeBlob, 2)
	else
		info.typeVersion = BytecodeBuilder.getTypeEncodingVersion(bytecodeBlob)
	end
	return info
end
Luau.BytecodeReader = BytecodeReader
-- Luau opcode numbers. Each name's opcode is its zero-based position in the list below
-- (ten names per row, so row r starts at opcode 10 * r).
local LOP = {}
for position, opName in ipairs({
	"NOP", "BREAK", "LOADNIL", "LOADB", "LOADN", "LOADK", "MOVE", "GETGLOBAL", "SETGLOBAL", "GETUPVAL",
	"SETUPVAL", "CLOSEUPVALS", "GETIMPORT", "GETTABLE", "SETTABLE", "GETTABLEKS", "SETTABLEKS", "GETTABLEN", "SETTABLEN", "NEWCLOSURE",
	"NAMECALL", "CALL", "RETURN", "JUMP", "JUMPBACK", "JUMPIF", "JUMPIFNOT", "JUMPIFEQ", "JUMPIFLE", "JUMPIFLT",
	"JUMPIFNOTEQ", "JUMPIFNOTLE", "JUMPIFNOTLT", "ADD", "SUB", "MUL", "DIV", "MOD", "POW", "ADDK",
	"SUBK", "MULK", "DIVK", "MODK", "POWK", "AND", "OR", "ANDK", "ORK", "CONCAT",
	"NOT", "MINUS", "LENGTH", "NEWTABLE", "DUPTABLE", "SETLIST", "FORNPREP", "FORNLOOP", "FORGLOOP", "FORGPREP_INEXT",
	"FASTCALL3", "FORGPREP_NEXT", "NATIVECALL", "GETVARARGS", "DUPCLOSURE", "PREPVARARGS", "LOADKX", "JUMPX", "FASTCALL", "COVERAGE",
	"CAPTURE", "SUBRK", "DIVRK", "FASTCALL1", "FASTCALL2", "FASTCALL2K", "FORGPREP", "JUMPXEQKNIL", "JUMPXEQKB", "JUMPXEQKN",
	"JUMPXEQKS", "IDIV", "IDIVK", "GETUDATAKS", "SETUDATAKS", "NAMECALLUDATA",
}) do
	LOP[opName] = position - 1
end
Luau.OpCode = LOP
-- Constant-pool tags: the byte written before each constant's payload.
local LBC_CONSTANT_NIL, LBC_CONSTANT_BOOLEAN, LBC_CONSTANT_NUMBER, LBC_CONSTANT_STRING = 0, 1, 2, 3
local LBC_CONSTANT_IMPORT, LBC_CONSTANT_TABLE, LBC_CONSTANT_CLOSURE, LBC_CONSTANT_VECTOR = 4, 5, 6, 7
local LBC_CONSTANT_TABLE_WITH_CONSTANTS, LBC_CONSTANT_INTEGER = 8, 9
-- Bytecode type tags; LBC_TYPE_OPTIONAL_BIT is added to a tag to mark it optional.
local _LBC_TYPE_NIL, LBC_TYPE_BOOLEAN, LBC_TYPE_NUMBER, LBC_TYPE_STRING, LBC_TYPE_TABLE = 0, 1, 2, 3, 4
local _LBC_TYPE_FUNCTION, LBC_TYPE_THREAD, LBC_TYPE_USERDATA, LBC_TYPE_VECTOR, LBC_TYPE_BUFFER = 5, 6, 7, 8, 9
local LBC_TYPE_INTEGER, LBC_TYPE_ANY, LBC_TYPE_OPTIONAL_BIT = 10, 15, 128
-- Type names that map directly to a single bytecode type tag.
local simpleTypeNameToBytecode = {
	["nil"] = _LBC_TYPE_NIL,
	any = LBC_TYPE_ANY,
	boolean = LBC_TYPE_BOOLEAN,
	buffer = LBC_TYPE_BUFFER,
	integer = LBC_TYPE_INTEGER,
	number = LBC_TYPE_NUMBER,
	string = LBC_TYPE_STRING,
	table = LBC_TYPE_TABLE,
	thread = LBC_TYPE_THREAD,
	userdata = LBC_TYPE_USERDATA,
	vector = LBC_TYPE_VECTOR,
}
-- Powers of two used by the integer and float encoders.
local TWO_16, TWO_23, _TWO_24, TWO_32, TWO_52 = 2 ^ 16, 2 ^ 23, 2 ^ 24, 2 ^ 32, 2 ^ 52
-- Smallest normal and subnormal magnitudes of IEEE-754 single and double precision.
local MIN_NORMAL_FLOAT, MIN_SUBNORMAL_FLOAT = 2 ^ -126, 2 ^ -149
local MIN_NORMAL_DOUBLE, MIN_SUBNORMAL_DOUBLE = 2 ^ -1022, 2 ^ -1074
-- Appends a single byte to the output chunk list; `value` is reduced modulo 256 first.
local function writeByteValue(out, value)
	local octet = value % 256
	append(out, string.char(octet))
end
-- Appends `value` as a 4-byte little-endian unsigned integer. The value is wrapped modulo 2^32
-- first, so negative inputs come out as their two's-complement bit pattern.
local function writeInt32(out, value)
	local remaining = value % TWO_32
	local octets = {}
	for position = 1, 4 do
		local low = remaining % 256
		octets[position] = low
		remaining = (remaining - low) / 256
	end
	append(out, string.char(octets[1], octets[2], octets[3], octets[4]))
end
--- Appends `value` to `out` as an unsigned varint: 7 bits per byte, least-significant group first,
-- with the high bit (128) set on every byte except the last. Zero encodes as a single 0 byte.
-- Raises an error for negative, NaN or infinite input. Repeated division by 128 never reaches
-- exactly 0 for those values (e.g. -1 % 128 is 127 and the quotient stays -1), so the loop below
-- would otherwise spin forever.
local function writeVarIntValue(out, value)
	if not (value >= 0 and value < math.huge) then
		error(string.format("cannot encode %s as an unsigned varint", tostring(value)), 2)
	end
	repeat
		local byteValue = value % 128
		value = (value - byteValue) / 128
		if value > 0 then
			-- continuation bit: more 7-bit groups follow
			byteValue += 128
		end
		writeByteValue(out, byteValue)
	until value == 0
end
-- Appends a 64-bit unsigned value, given as 32-bit halves `hi` and `lo`, as a varint
-- (7 bits per byte, least-significant group first, high bit set while more bytes follow).
local function writeVarInt64(out, hi, lo)
	repeat
		local chunk = lo % 128
		-- shift the (hi, lo) pair right by 7 bits: hi's low 7 bits move into the top of lo (2^25 = 33554432)
		local carried = hi % 128
		lo = (lo - chunk) / 128 + carried * 33554432
		hi = (hi - carried) / 128
		if hi > 0 or lo > 0 then
			chunk += 128
		end
		writeByteValue(out, chunk)
	until hi == 0 and lo == 0
end
-- Rounds to the nearest integer; an exact .5 remainder goes to whichever neighbour is even.
local function roundToEven(value)
	local lower = math.floor(value)
	local remainder = value - lower
	if remainder == 0.5 then
		if lower % 2 == 0 then
			return lower
		end
		return lower + 1
	end
	if remainder > 0.5 then
		return lower + 1
	end
	return lower
end
--- Pure-Luau stand-in for C `frexp`, used when `math.frexp` is unavailable.
-- Returns a mantissa m and an integer exponent e with value == m * 2^e and 0.5 <= |m| < 1.
-- Zero returns 0, 0. NaN and infinities are returned unchanged with exponent 0, as C frexp does.
local function frexpFallback(value)
	if value == 0 then
		return 0, 0
	end
	if value ~= value or value == math.huge or value == -math.huge then
		return value, 0
	end
	local sign = 1
	if value < 0 then
		sign = -1
		value = -value
	end
	-- log2 estimate; floating-point error can leave it off by one, which the loops below correct.
	local exponent = math.floor(math.log(value) / math.log(2)) + 1
	-- Divide by 2^exponent in two halves. A single 2^exponent overflows to inf for values near the
	-- largest double (exponent 1024), which makes the mantissa 0 and the first loop spin forever.
	-- Each half stays within about 2^±537, so both factors are finite normal numbers and the
	-- power-of-two divisions are exact.
	local half = math.floor(exponent / 2)
	local mantissa = value / 2 ^ half / 2 ^ (exponent - half)
	while mantissa < 0.5 do
		mantissa *= 2
		exponent -= 1
	end
	while mantissa >= 1 do
		mantissa /= 2
		exponent += 1
	end
	return sign * mantissa, exponent
end
--- Encodes `value` as a 4-byte little-endian IEEE-754 single-precision float.
-- Rounds to nearest with ties to even. Magnitudes whose biased exponent reaches 255 become
-- infinity, magnitudes below 2^-126 become subnormals, and -0 keeps its sign bit. Every NaN is
-- written as 0x7fc00000 with the sign bit clear.
local function packFloat(value)
	local negative = value < 0 or value == 0 and 1 / value == -math.huge
	local magnitude = value
	if negative then
		magnitude = -value
	end
	local bits = 0
	if magnitude ~= magnitude then
		bits = 0x7fc00000
	elseif magnitude == math.huge then
		bits = 0x7f800000
	elseif magnitude == 0 then
		bits = 0
	elseif magnitude < MIN_NORMAL_FLOAT then
		-- Subnormal: count whole multiples of 2^-149. A count that rounds up to 2^23 is exactly the smallest normal.
		bits = math.min(roundToEven(magnitude / MIN_SUBNORMAL_FLOAT), TWO_23)
	else
		local exponent = math.floor(math.log(magnitude) / math.log(2))
		local fraction = roundToEven((magnitude / 2 ^ exponent - 1) * TWO_23)
		if fraction >= TWO_23 then
			-- significand rounded up to 2.0: move to the next exponent
			exponent += 1
			fraction = 0
		end
		local biasedExponent = exponent + 127
		if biasedExponent >= 255 then
			bits = 0x7f800000
		else
			bits = biasedExponent * TWO_23 + fraction
		end
	end
	local signBit = 0
	if negative then
		signBit = 128
	end
	return string.char(
		bits % 256,
		math.floor(bits / 256) % 256,
		math.floor(bits / 65536) % 256,
		math.floor(bits / 16777216) % 128 + signBit
	)
end
--- Encodes `value` as an 8-byte little-endian IEEE-754 double (low 32-bit word first).
-- Handles -0 (sign bit set), subnormals and infinity. Every NaN is written as high word
-- 0x7ff80000 with the sign bit clear. Uses math.frexp when present, otherwise frexpFallback.
local function packDouble(value)
	local negative = value < 0 or value == 0 and 1 / value == -math.huge
	if negative then
		value = -value
	end
	local high, low = 0, 0
	if value ~= value then
		high = 0x7ff80000
	elseif value == math.huge then
		high = 0x7ff00000
	elseif value ~= 0 then
		if value < MIN_NORMAL_DOUBLE then
			-- subnormal: the 52-bit field counts multiples of 2^-1074
			local fraction = math.floor(value / MIN_SUBNORMAL_DOUBLE + 0.5)
			if fraction >= TWO_52 then
				high = 0x00100000
			else
				low = fraction % TWO_32
				high = math.floor(fraction / TWO_32)
			end
		else
			local frexp = math.frexp or frexpFallback
			local mantissa, exponent = frexp(value)
			-- frexp gives mantissa in [0.5, 1); IEEE stores (2 * mantissa - 1) with exponent bias 1023
			local biased = exponent + 1022
			local fraction = math.floor((mantissa * 2 - 1) * TWO_52 + 0.5)
			if fraction >= TWO_52 then
				fraction = 0
				biased += 1
			end
			low = fraction % TWO_32
			high = math.floor(fraction / TWO_32) + biased * 1048576
		end
	end
	if negative then
		high += 2147483648
	end
	local buffer = {}
	writeInt32(buffer, low)
	writeInt32(buffer, high)
	return table.concat(buffer)
end
-- Packs an opcode and operands A, B, C into a 32-bit instruction word
-- (op in bits 0-7, A in 8-15, B in 16-23, C in 24-31).
local function encodeABC(op, a, b, c)
	return ((c * 256 + b) * 256 + a) * 256 + op
end
-- Packs an opcode, operand A and a 16-bit D field (op in bits 0-7, A in 8-15, D in 16-31).
-- Negative D is stored in two's-complement form.
local function encodeAD(op, a, d)
	local unsignedD = d
	if unsignedD < 0 then
		unsignedD += TWO_16
	end
	return op + 256 * a + TWO_16 * (unsignedD % TWO_16)
end
-- Hashes a string into 32 bits: start from the length, then fold in bytes from last to first
-- with h = h xor (h * 32 + floor(h / 4) + byte), reduced modulo 2^32 each step.
local function stringHash(value)
	local length = #value
	local h = length
	local position = length
	while position >= 1 do
		local mixed = h * 32 + math.floor(h / 4) + byte(value, position)
		h = bit32.bxor(h, mixed) % TWO_32
		position -= 1
	end
	return h
end
-- Encodes a table size hint: 0 when value <= 0, otherwise (smallest k with 2^k >= value) + 1,
-- i.e. ceil(log2(value)) + 1 for positive integers.
local function encodeTableSize(value)
	if value <= 0 then
		return 0
	end
	local bitsNeeded = 0
	while 2 ^ bitsNeeded < value do
		bitsNeeded += 1
	end
	return bitsNeeded + 1
end
-- Returns floor(log2(value)) by repeated halving; any value below 2 yields 0.
local function floorLog2(value)
	local exponent = 0
	local remaining = value
	while remaining >= 2 do
		remaining = math.floor(remaining / 2)
		exponent += 1
	end
	return exponent
end
local PureBytecodeBuilder = {}
PureBytecodeBuilder.__index = PureBytecodeBuilder
--- Creates a builder with empty string and function tables, main function id 0, and options
-- normalised through CompileOptions.new.
function PureBytecodeBuilder.new(options)
	local builder = {
		options = CompileOptions.new(options),
		functions = {},
		mainFunction = 0,
		strings = {},
		stringMap = {},
		hasIntegerConstant = false,
	}
	return setmetatable(builder, PureBytecodeBuilder)
end
--- Interns `value` in the string table and returns its 1-based id, reusing the existing id for a
-- string that was already added.
function PureBytecodeBuilder:addString(value)
	local id = self.stringMap[value]
	if id == nil then
		id = #self.strings + 1
		self.strings[id] = value
		self.stringMap[value] = id
	end
	return id
end
--- Returns a fresh, empty function prototype record. The stack size starts at the parameter count
-- (default 0), and `isvararg` is true only when exactly `true` is passed.
function PureBytecodeBuilder:createFunction(numparams, isvararg)
	local paramCount = numparams or 0
	local proto = {
		numparams = paramCount,
		maxstacksize = paramCount,
		numupvalues = 0,
		isvararg = isvararg == true,
		flags = 0,
		insns = {},
		lines = {},
		constants = {},
		constantMap = {},
		protos = {},
		protoMap = {},
		debuglinedefined = 1,
		debugname = 0,
		debugLocals = {},
		debugUpvals = {},
		typeInfo = "",
		typedUpvals = {},
		typedLocals = {},
	}
	return proto
end
--- Registers a function prototype and returns its zero-based function id.
function PureBytecodeBuilder:addFunction(func)
	local registry = self.functions
	local id = #registry
	registry[id + 1] = func
	return id
end
--- Records `childId` in `parent`'s child-proto list (once) and returns its zero-based slot there.
function PureBytecodeBuilder:addChildFunction(parent, childId)
	local slot = parent.protoMap[childId]
	if slot == nil then
		slot = #parent.protos
		parent.protos[slot + 1] = childId
		parent.protoMap[childId] = slot
	end
	return slot
end
-- Builds the constant-dedup key: tag, value and extra converted with tostring and joined by NUL.
local function constantKey(tag, value, extra)
	return table.concat({ tostring(tag), tostring(value), tostring(extra) }, "\0")
end
--- Adds a (tag, value, extra) constant to `func`'s constant pool, deduplicated by constantKey,
-- and returns its zero-based constant index.
function PureBytecodeBuilder:addConstant(func, tag, value, extra)
	local dedupKey = constantKey(tag, value, extra)
	local known = func.constantMap[dedupKey]
	if known ~= nil then
		return known
	end
	local id = #func.constants
	func.constants[id + 1] = { tag = tag, value = value, extra = extra }
	func.constantMap[dedupKey] = id
	return id
end
--- Adds (or reuses) the nil constant; returns its constant index.
function PureBytecodeBuilder:addConstantNil(func)
	local id = self:addConstant(func, LBC_CONSTANT_NIL, 0)
	return id
end
--- Adds (or reuses) a boolean constant, stored as 1 for truthy and 0 otherwise; returns its index.
function PureBytecodeBuilder:addConstantBoolean(func, value)
	local encoded = 0
	if value then
		encoded = 1
	end
	local id = self:addConstant(func, LBC_CONSTANT_BOOLEAN, encoded)
	return id
end
--- Adds (or reuses) a number constant; returns its constant index.
function PureBytecodeBuilder:addConstantNumber(func, value)
	local id = self:addConstant(func, LBC_CONSTANT_NUMBER, value)
	return id
end
--- Adds (or reuses) a vector constant. The x, y, z and w components are packed as four float32
-- values, and the packed bytes serve both as the dedup key and as the serialised payload.
-- Returns the zero-based constant index.
function PureBytecodeBuilder:addConstantVector(func, value)
	local packed = table.concat({ packFloat(value.x), packFloat(value.y), packFloat(value.z), packFloat(value.w) })
	local dedupKey = constantKey(LBC_CONSTANT_VECTOR, packed, "")
	local known = func.constantMap[dedupKey]
	if known ~= nil then
		return known
	end
	local id = #func.constants
	func.constants[id + 1] = { tag = LBC_CONSTANT_VECTOR, value = value, packed = packed }
	func.constantMap[dedupKey] = id
	return id
end
--- Adds (or reuses) a 64-bit integer constant, deduplicated by `value.key`, and flags the builder
-- as containing integer constants. Returns the zero-based constant index.
function PureBytecodeBuilder:addConstantInteger(func, value)
	self.hasIntegerConstant = true
	local dedupKey = constantKey(LBC_CONSTANT_INTEGER, value.key, "")
	local known = func.constantMap[dedupKey]
	if known == nil then
		known = #func.constants
		func.constants[known + 1] = { tag = LBC_CONSTANT_INTEGER, value = value }
		func.constantMap[dedupKey] = known
	end
	return known
end
--- Interns `value` in the string table, then adds (or reuses) a string constant that refers to
-- it by string id. Returns the zero-based constant index.
function PureBytecodeBuilder:addConstantString(func, value)
	return self:addConstant(func, LBC_CONSTANT_STRING, self:addString(value))
end
--- Adds (or reuses) a table-shape constant built from the key constant ids in `keys`.
-- When `hasConstants` is truthy, each key is paired with its entry in `valueConstants` (-1 when
-- missing) and the constant is tagged LBC_CONSTANT_TABLE_WITH_CONSTANTS. Returns the index.
function PureBytecodeBuilder:addConstantTable(func, keys, valueConstants, hasConstants)
	local shapeKeys = {}
	local parts = { hasConstants and "tablec" or "table" }
	for position, keyConstant in ipairs(keys) do
		shapeKeys[position] = keyConstant
		parts[#parts + 1] = tostring(keyConstant)
		if hasConstants then
			parts[#parts + 1] = tostring(valueConstants[position] or -1)
		end
	end
	local dedupKey = table.concat(parts, "\0")
	local known = func.constantMap[dedupKey]
	if known ~= nil then
		return known
	end
	local entry
	if hasConstants then
		local shapeConstants = {}
		for position = 1, #keys do
			shapeConstants[position] = valueConstants[position] or -1
		end
		entry = {
			tag = LBC_CONSTANT_TABLE_WITH_CONSTANTS,
			value = { keys = shapeKeys, constants = shapeConstants },
		}
	else
		entry = { tag = LBC_CONSTANT_TABLE, value = shapeKeys }
	end
	local id = #func.constants
	func.constants[id + 1] = entry
	func.constantMap[dedupKey] = id
	return id
end
--- Adds (or reuses) a closure constant that refers to function id `childId`; returns its index.
function PureBytecodeBuilder:addConstantClosure(func, childId)
	local id = self:addConstant(func, LBC_CONSTANT_CLOSURE, childId)
	return id
end
--- Packs an import path into a 32-bit import id.
-- `path` holds one to three 1-based ids (string constant index + 1). Bits 30-31 hold the path
-- length, and each component's zero-based constant index fills a 10-bit field: the first at bits
-- 20-29, the second at 10-19, the third at 0-9.
-- Raises an error when the path is not 1-3 components long or an index does not fit in 10 bits.
-- Such values would otherwise spill into the neighbouring fields and silently encode a different
-- import.
function PureBytecodeBuilder:getImportId(path)
	local count = #path
	if count < 1 or count > 3 then
		error(string.format("import path must have 1 to 3 components, got %d", count), 2)
	end
	local importId = bit32.lshift(count, 30)
	for position = 1, count do
		local constantIndex = path[position] - 1
		if constantIndex < 0 or constantIndex >= 1024 then
			error(string.format("import component constant index %s does not fit in 10 bits", tostring(constantIndex)), 2)
		end
		-- position 1 -> shift 20, 2 -> 10, 3 -> 0
		importId = bit32.bor(importId, bit32.lshift(constantIndex, 30 - 10 * position))
	end
	return importId
end
--- Adds (or reuses) an import constant for a dotted global path (e.g. {"math", "floor"}).
-- Each name becomes a string constant, and the path is packed with getImportId.
-- Returns the import constant's index.
function PureBytecodeBuilder:addConstantImport(func, path)
	local componentIds = {}
	for position, name in ipairs(path) do
		componentIds[position] = self:addConstantString(func, name) + 1
	end
	return self:addConstant(func, LBC_CONSTANT_IMPORT, self:getImportId(componentIds))
end
--- Appends a raw 32-bit word (wrapped modulo 2^32) to `func` together with its source line
-- (at least 1; defaults to 1).
function PureBytecodeBuilder:emit(func, word, line)
	local insns, lines = func.insns, func.lines
	insns[#insns + 1] = word % TWO_32
	lines[#lines + 1] = math.max(1, line or 1)
end
--- Emits an ABC-format instruction; raises if any operand is nil.
function PureBytecodeBuilder:emitABC(func, op, a, b, c, line)
	local missingOperand = op == nil or a == nil or b == nil or c == nil
	if missingOperand then
		error(string.format("invalid ABC instruction op=%s a=%s b=%s c=%s", tostring(op), tostring(a), tostring(b), tostring(c)), 2)
	end
	self:emit(func, encodeABC(op, a, b, c), line)
end
--- Emits an AD-format instruction (D is a signed or unsigned 16-bit field); raises if any operand is nil.
function PureBytecodeBuilder:emitAD(func, op, a, d, line)
	local missingOperand = op == nil or a == nil or d == nil
	if missingOperand then
		error(string.format("invalid AD instruction op=%s a=%s d=%s", tostring(op), tostring(a), tostring(d)), 2)
	end
	self:emit(func, encodeAD(op, a, d), line)
end
--- Emits the auxiliary data word that follows an instruction that takes one.
function PureBytecodeBuilder:emitAux(func, aux, line)
	self:emit(func, aux, line)
end
--- Returns the zero-based index the next emitted word will occupy.
function PureBytecodeBuilder:label(func)
	local count = #func.insns
	return count
end
--- Rewrites the D field of the instruction at zero-based index `at` so it jumps to `target`.
-- Offsets are relative to the word after the jump; raises when the offset does not fit in a
-- signed 16-bit field.
function PureBytecodeBuilder:patchJump(func, at, target)
	local offset = target - at - 1
	if offset > 32767 or offset < -32768 then
		error("jump offset out of range", 2)
	end
	local encoded = offset
	if encoded < 0 then
		encoded += TWO_16
	end
	local slot = at + 1
	func.insns[slot] = func.insns[slot] % TWO_16 + encoded * TWO_16
end
--- Records a debug-local entry (name, register, live pc range, nesting depth, declaration order).
-- Only done at debugLevel >= 2 (default level 1) and when `reg` is known. When `debugOrder` is
-- omitted, the entry's position in the list is used.
function PureBytecodeBuilder:pushDebugLocal(func, name, reg, startpc, endpc, debugDepth, debugOrder)
	local level = self.options.debugLevel or 1
	if level < 2 or reg == nil then
		return
	end
	local locals = func.debugLocals
	local nameId = self:addString(name)
	locals[#locals + 1] = {
		name = nameId,
		reg = reg,
		startpc = startpc,
		endpc = endpc,
		debugDepth = debugDepth or 0,
		debugOrder = debugOrder or #locals,
	}
end
--- Records an upvalue name for debug info; only done at debugLevel >= 2 (default level 1).
function PureBytecodeBuilder:pushDebugUpval(func, name)
	local level = self.options.debugLevel or 1
	if level < 2 then
		return
	end
	local upvals = func.debugUpvals
	upvals[#upvals + 1] = { name = self:addString(name) }
end
--- Records a typed-local entry. Skipped when the type is nil or `any`, or when the register or
-- either end of the pc range is missing.
function PureBytecodeBuilder:pushLocalTypeInfo(func, typeId, reg, startpc, endpc)
	local untyped = typeId == nil or typeId == LBC_TYPE_ANY
	local incomplete = reg == nil or startpc == nil or endpc == nil
	if untyped or incomplete then
		return
	end
	local typed = func.typedLocals
	typed[#typed + 1] = { type = typeId, reg = reg, startpc = startpc, endpc = endpc }
end
--- Writes the compressed per-instruction line table for `func` to `out`.
-- Layout: one byte log2(span); then one byte per instruction holding (line - span baseline),
-- written as a difference from the previous instruction's value (reduced mod 256 by
-- writeByteValue); then one int32 per span holding that span's baseline (minimum line) as a
-- difference from the previous baseline.
-- Expects #func.lines > 0; writeFunction only calls this in that case.
function PureBytecodeBuilder:writeLineInfo(out, func)
local lineCount = #func.lines
-- Pass 1: choose `span`, a power of two starting at 2^24, so that the lines inside every window of
-- `span` instructions differ by at most 255 (each per-instruction delta must fit in a byte).
local span = _TWO_24
local offset = 1
while offset <= lineCount do
local nextIndex = offset
local minLine = func.lines[offset]
local maxLine = func.lines[offset]
while nextIndex <= lineCount and nextIndex < offset + span do
local line = func.lines[nextIndex]
minLine = math.min(minLine, line)
maxLine = math.max(maxLine, line)
if maxLine - minLine > 255 then
break
end
nextIndex += 1
end
-- The window stopped early because the range exceeded 255: shrink span to the largest power of two
-- that still fits the lines accepted so far. Later windows are measured with the smaller span.
if nextIndex <= lineCount and nextIndex - offset < span then
span = 2 ^ floorLog2(nextIndex - offset)
end
offset += span
end
-- Pass 2: the baseline of each span is the minimum line number inside it.
local baselines = {}
offset = 1
while offset <= lineCount do
local minLine = func.lines[offset]
local nextIndex = offset + 1
while nextIndex <= lineCount and nextIndex < offset + span do
minLine = math.min(minLine, func.lines[nextIndex])
nextIndex += 1
end
append(baselines, minLine)
offset += span
end
-- Pass 3: emit log2(span), the per-instruction deltas and the delta-encoded baselines.
local logspan = floorLog2(span)
writeByteValue(out, logspan)
local lastOffset = 0
for index, line in ipairs(func.lines) do
-- baselines is 1-based; instruction `index` belongs to span floor((index - 1) / span)
local baseline = baselines[math.floor((index - 1) / span) + 1]
local delta = line - baseline
writeByteValue(out, delta - lastOffset)
lastOffset = delta
end
local lastLine = 0
for _, baseline in ipairs(baselines) do
writeInt32(out, baseline - lastLine)
lastLine = baseline
end
end
-- Opcodes whose signed 16-bit D field is a jump offset relative to the following word
-- (target = pc + 1 + D); foldJumps may retarget these.
local foldableJumpOps = {
	[LOP.JUMP] = true,
	[LOP.JUMPIF] = true,
	[LOP.JUMPIFNOT] = true,
	[LOP.JUMPIFEQ] = true,
	[LOP.JUMPIFLE] = true,
	[LOP.JUMPIFLT] = true,
	[LOP.JUMPIFNOTEQ] = true,
	[LOP.JUMPIFNOTLE] = true,
	[LOP.JUMPIFNOTLT] = true,
	[LOP.FORNPREP] = true,
	[LOP.FORNLOOP] = true,
	[LOP.FORGLOOP] = true,
	[LOP.FORGPREP_INEXT] = true,
	[LOP.FORGPREP_NEXT] = true,
	[LOP.FORGPREP] = true,
	[LOP.JUMPBACK] = true,
	[LOP.JUMPXEQKNIL] = true,
	[LOP.JUMPXEQKB] = true,
	[LOP.JUMPXEQKN] = true,
	[LOP.JUMPXEQKS] = true,
}
-- Extracts the signed 16-bit D field (bits 16-31) of an instruction word.
local function insnD(word)
	local d = math.floor(word / TWO_16) % TWO_16
	if d >= 32768 then
		d -= TWO_16
	end
	return d
end
do
	-- Opcodes followed by an auxiliary 32-bit word. That word is operand data (a constant index,
	-- import id, register, count, ...), not an instruction, so a walk over func.insns must step over
	-- it rather than decode it as an opcode.
	local opsWithAuxWord = {
		[LOP.GETGLOBAL] = true,
		[LOP.SETGLOBAL] = true,
		[LOP.GETIMPORT] = true,
		[LOP.GETTABLEKS] = true,
		[LOP.SETTABLEKS] = true,
		[LOP.NAMECALL] = true,
		[LOP.JUMPIFEQ] = true,
		[LOP.JUMPIFLE] = true,
		[LOP.JUMPIFLT] = true,
		[LOP.JUMPIFNOTEQ] = true,
		[LOP.JUMPIFNOTLE] = true,
		[LOP.JUMPIFNOTLT] = true,
		[LOP.NEWTABLE] = true,
		[LOP.SETLIST] = true,
		[LOP.FORGLOOP] = true,
		[LOP.LOADKX] = true,
		[LOP.FASTCALL2] = true,
		[LOP.FASTCALL2K] = true,
		[LOP.FASTCALL3] = true,
		[LOP.JUMPXEQKNIL] = true,
		[LOP.JUMPXEQKB] = true,
		[LOP.JUMPXEQKN] = true,
		[LOP.JUMPXEQKS] = true,
		[LOP.GETUDATAKS] = true,
		[LOP.SETUDATAKS] = true,
		[LOP.NAMECALLUDATA] = true,
	}

	--- Peephole pass over func.insns that shortens jump chains in place.
	-- For each jump instruction, forward unconditional JUMPs at its target are followed to their
	-- final destination. Only forward hops are followed, so the walk terminates. The jump's D field
	-- is then rewritten to point at that destination when the offset fits in 16 bits. An
	-- unconditional JUMP whose final target is a RETURN is replaced by a copy of that RETURN.
	-- Instructions are walked by length so AUX words are skipped. An AUX word's low byte can equal a
	-- jump opcode (e.g. constant index 23 == LOP.JUMP), and treating it as a jump would rewrite the
	-- operand or overwrite it with a RETURN instruction.
	function PureBytecodeBuilder:foldJumps(func)
		local insns = func.insns
		local index = 1
		while index <= #insns do
			local word = insns[index]
			local op = word % 256
			if foldableJumpOps[op] then
				local jumpLabel = index - 1
				local targetLabel = jumpLabel + 1 + insnD(word)
				local targetInsn = insns[targetLabel + 1]
				while targetInsn and targetInsn % 256 == LOP.JUMP and insnD(targetInsn) >= 0 do
					targetLabel += 1 + insnD(targetInsn)
					targetInsn = insns[targetLabel + 1]
				end
				if targetInsn then
					local offset = targetLabel - jumpLabel - 1
					if op == LOP.JUMP and targetInsn % 256 == LOP.RETURN then
						insns[index] = targetInsn
					elseif offset >= -32768 and offset <= 32767 then
						insns[index] = word % TWO_16 + (offset < 0 and offset + TWO_16 or offset) * TWO_16
					end
				end
			end
			if opsWithAuxWord[op] then
				index += 2
			else
				index += 1
			end
		end
	end
end
do
	-- Writes the type-info section: a varint byte length followed by the blob (type string, upvalue
	-- type bytes, typed-local records), or a single 0 when the function has no type info.
	local function writeTypeInfoSection(out, func)
		if func.typeInfo == "" and #func.typedUpvals == 0 and #func.typedLocals == 0 then
			writeVarIntValue(out, 0)
			return
		end
		local typeOut = {}
		writeVarIntValue(typeOut, #func.typeInfo)
		writeVarIntValue(typeOut, #func.typedUpvals)
		writeVarIntValue(typeOut, #func.typedLocals)
		append(typeOut, func.typeInfo)
		for _, upvalue in ipairs(func.typedUpvals) do
			writeByteValue(typeOut, upvalue.type)
		end
		for _, localInfo in ipairs(func.typedLocals) do
			writeByteValue(typeOut, localInfo.type)
			writeByteValue(typeOut, localInfo.reg)
			-- pc ranges are stored as (start, length)
			writeVarIntValue(typeOut, localInfo.startpc)
			writeVarIntValue(typeOut, localInfo.endpc - localInfo.startpc)
		end
		local typeBlob = table.concat(typeOut)
		writeVarIntValue(out, #typeBlob)
		append(out, typeBlob)
	end

	-- Writes one constant-pool entry: the tag byte followed by a tag-specific payload (nil has none).
	local function writeConstantEntry(out, constant)
		local tag = constant.tag
		writeByteValue(out, tag)
		if tag == LBC_CONSTANT_BOOLEAN then
			writeByteValue(out, constant.value)
		elseif tag == LBC_CONSTANT_NUMBER then
			append(out, packDouble(constant.value))
		elseif tag == LBC_CONSTANT_VECTOR then
			append(out, constant.packed)
		elseif tag == LBC_CONSTANT_INTEGER then
			writeByteValue(out, constant.value.negative and 1 or 0)
			writeVarInt64(out, constant.value.hi, constant.value.lo)
		elseif tag == LBC_CONSTANT_STRING then
			writeVarIntValue(out, constant.value)
		elseif tag == LBC_CONSTANT_IMPORT then
			writeInt32(out, constant.value)
		elseif tag == LBC_CONSTANT_TABLE then
			writeVarIntValue(out, #constant.value)
			for _, key in ipairs(constant.value) do
				writeVarIntValue(out, key)
			end
		elseif tag == LBC_CONSTANT_TABLE_WITH_CONSTANTS then
			local shape = constant.value
			writeVarIntValue(out, #shape.keys)
			for index, key in ipairs(shape.keys) do
				writeVarIntValue(out, key)
				writeInt32(out, shape.constants[index] or -1)
			end
		elseif tag == LBC_CONSTANT_CLOSURE then
			writeVarIntValue(out, constant.value)
		end
	end

	-- Writes the debug section: a 0 byte when there are no debug locals or upvalues, otherwise a 1 byte,
	-- the local records (name, startpc, endpc, register) and the upvalue names.
	local function writeDebugSection(out, func)
		if #func.debugLocals == 0 and #func.debugUpvals == 0 then
			writeByteValue(out, 0)
			return
		end
		writeByteValue(out, 1)
		writeVarIntValue(out, #func.debugLocals)
		for _, localInfo in ipairs(func.debugLocals) do
			writeVarIntValue(out, localInfo.name)
			writeVarIntValue(out, localInfo.startpc)
			writeVarIntValue(out, localInfo.endpc)
			writeByteValue(out, localInfo.reg)
		end
		writeVarIntValue(out, #func.debugUpvals)
		for _, upvalue in ipairs(func.debugUpvals) do
			writeVarIntValue(out, upvalue.name)
		end
	end

	--- Serializes one function prototype to `out`. Order: header bytes (max stack, params,
	-- upvalues, vararg, flags), type info, instructions, constants, child proto ids, line defined,
	-- debug name, optional line table and optional debug locals/upvalues.
	-- A missing options.debugLevel means the default level 1 for both the debug name and the line
	-- table, matching the default that pushDebugLocal and pushDebugUpval apply.
	function PureBytecodeBuilder:writeFunction(out, func)
		writeByteValue(out, func.maxstacksize)
		writeByteValue(out, func.numparams)
		writeByteValue(out, func.numupvalues)
		writeByteValue(out, func.isvararg and 1 or 0)
		writeByteValue(out, func.flags or 0)
		writeTypeInfoSection(out, func)
		writeVarIntValue(out, #func.insns)
		for _, word in ipairs(func.insns) do
			writeInt32(out, word)
		end
		writeVarIntValue(out, #func.constants)
		for _, constant in ipairs(func.constants) do
			writeConstantEntry(out, constant)
		end
		writeVarIntValue(out, #func.protos)
		for _, proto in ipairs(func.protos) do
			writeVarIntValue(out, proto)
		end
		writeVarIntValue(out, func.debuglinedefined or 1)
		local debugLevel = self.options.debugLevel or 1
		if debugLevel >= 1 then
			writeVarIntValue(out, func.debugname or 0)
		else
			writeVarIntValue(out, 0)
		end
		-- Line info is only written when every instruction word has a line entry.
		local lineCount = #func.lines
		if debugLevel > 0 and lineCount > 0 and lineCount == #func.insns then
			writeByteValue(out, 1)
			self:writeLineInfo(out, func)
		else
			writeByteValue(out, 0)
		end
		writeDebugSection(out, func)
	end
end
--- Renumbers the string table just before serialization. New ids are assigned in first-use order,
-- and every stored string id (string constants, function debug names, debug local and upvalue
-- names) is rewritten to the new numbering. Strings that nothing references keep a slot, appended
-- after the used ones. Old id 0 or nil means "no string" and always maps to 0.
-- At debugLevel >= 2, first-use order comes from a walk that interleaves debug-local names with
-- the string operands of each instruction (presumably to reproduce the reference Luau compiler's
-- string order; TODO confirm). At lower levels the order is, per function: string constants,
-- debug name, local names, upvalue names.
function PureBytecodeBuilder:remapStringsForSerialization()
local oldStrings = self.strings
local oldToNew = {}
local newStrings = {}
-- Returns the new id for old string id `oldId`, allocating the next new id on first use.
local function useString(oldId)
if not oldId or oldId == 0 then
return 0
end
local mapped = oldToNew[oldId]
if mapped then
return mapped
end
mapped = #newStrings + 1
oldToNew[oldId] = mapped
newStrings[mapped] = oldStrings[oldId]
return mapped
end
-- Reorders func.debugLocals in place. Each run of consecutive entries that share an endpc (other
-- than runs ending at the last instruction) is sorted by debugDepth descending, then debugOrder
-- ascending, then startpc. Single entries and runs ending at or past the last instruction keep
-- their order.
local function normalizeDebugLocalOrder(func)
if #func.debugLocals < 2 then
return
end
local instructionCount = #func.insns
local reordered = {}
local index = 1
while index <= #func.debugLocals do
local endpc = func.debugLocals[index].endpc
local nextIndex = index + 1
while nextIndex <= #func.debugLocals and func.debugLocals[nextIndex].endpc == endpc do
nextIndex += 1
end
local count = nextIndex - index
if count > 1 and endpc < instructionCount then
local group = {}
for groupIndex = index, nextIndex - 1 do
local localInfo = func.debugLocals[groupIndex]
append(group, {
info = localInfo,
order = groupIndex,
})
end
table.sort(group, function(left, right)
local leftDepth = left.info.debugDepth or 0
local rightDepth = right.info.debugDepth or 0
if leftDepth ~= rightDepth then
return leftDepth > rightDepth
end
local leftOrder = left.info.debugOrder or left.order
local rightOrder = right.info.debugOrder or right.order
if leftOrder ~= rightOrder then
return leftOrder < rightOrder
end
if left.info.startpc == right.info.startpc then
return left.order < right.order
end
return left.info.startpc < right.info.startpc
end)
for _, item in ipairs(group) do
append(reordered, item.info)
end
else
for groupIndex = index, nextIndex - 1 do
append(reordered, func.debugLocals[groupIndex])
end
end
index = nextIndex
end
func.debugLocals = reordered
end
if (self.options.debugLevel or 1) >= 2 then
for _, func in ipairs(self.functions) do
normalizeDebugLocalOrder(func)
end
-- Marks (in first-use order) the strings reachable from constant `constantId` (zero-based):
-- a string constant itself, or the key/value constants of a table-shape constant.
local function useConstantString(func, constantId)
local constant = func.constants[constantId + 1]
if not constant then
return
end
if constant.tag == LBC_CONSTANT_STRING then
useString(constant.value)
elseif constant.tag == LBC_CONSTANT_TABLE then
for _, keyConstantId in ipairs(constant.value) do
useConstantString(func, keyConstantId)
end
elseif constant.tag == LBC_CONSTANT_TABLE_WITH_CONSTANTS then
for index, keyConstantId in ipairs(constant.value.keys) do
useConstantString(func, keyConstantId)
local valueConstantId = constant.value.constants[index]
if valueConstantId and valueConstantId >= 0 then
useConstantString(func, valueConstantId)
end
end
end
end
-- Decodes an import id (count in bits 30-31, three 10-bit constant indices) and marks the
-- string constant of each present component.
local function useImportStrings(func, aux)
local count = math.floor(aux / 1073741824)
local ids = {
math.floor(aux / 1048576) % 1024,
math.floor(aux / 1024) % 1024,
aux % 1024,
}
for index = 1, count do
useConstantString(func, ids[index])
end
end
-- Marks the string constants referenced by the instruction at 1-based `pc`. Returns true when the
-- instruction is followed by an AUX word, so the caller can skip that word.
-- NOTE(review): GETUDATAKS / SETUDATAKS / NAMECALLUDATA are not listed here, so if they are ever
-- emitted with an AUX word, that word would be decoded as an instruction — confirm they are unused.
local function useInstructionStrings(func, pc)
local word = func.insns[pc]
if not word then
return false
end
local op = word % 256
local c = math.floor(word / 16777216) % 256
local d = math.floor(word / 65536) % 65536
if d >= 32768 then
d -= 65536
end
if op == LOP.LOADK then
useConstantString(func, d)
elseif op == LOP.LOADKX then
useConstantString(func, func.insns[pc + 1] or 0)
return true
elseif op == LOP.GETGLOBAL or op == LOP.SETGLOBAL or op == LOP.GETTABLEKS or op == LOP.SETTABLEKS or op == LOP.NAMECALL then
useConstantString(func, func.insns[pc + 1] or 0)
return true
elseif op == LOP.GETIMPORT then
useImportStrings(func, func.insns[pc + 1] or 0)
return true
elseif op == LOP.DUPTABLE then
useConstantString(func, d)
elseif op == LOP.ANDK or op == LOP.ORK then
useConstantString(func, c)
elseif op == LOP.FASTCALL2K then
useConstantString(func, func.insns[pc + 1] or 0)
return true
elseif op == LOP.JUMPXEQKS then
-- the low 24 bits of the AUX word hold the constant index
local constantId = (func.insns[pc + 1] or 0) % 16777216
useConstantString(func, constantId)
return true
elseif op == LOP.JUMPXEQKNIL or op == LOP.JUMPXEQKB or op == LOP.JUMPXEQKN then
return true
elseif op == LOP.JUMPIFEQ or op == LOP.JUMPIFLE or op == LOP.JUMPIFLT or op == LOP.JUMPIFNOTEQ or op == LOP.JUMPIFNOTLE or op == LOP.JUMPIFNOTLT or op == LOP.NEWTABLE or op == LOP.SETLIST or op == LOP.FORGLOOP or op == LOP.FASTCALL2 or op == LOP.FASTCALL3 then
return true
end
return false
end
-- Marking pass: per function, interleave local names (bucketed by endpc) with instruction
-- operands, then the debug name, upvalue names, remaining locals and any unmarked string constants.
for _, func in ipairs(self.functions) do
local localsByEnd = {}
local finalLocals = {}
local instructionCount = #func.insns
for _, localInfo in ipairs(func.debugLocals) do
if localInfo.endpc < instructionCount then
local bucket = localsByEnd[localInfo.endpc]
if not bucket then
bucket = {}
localsByEnd[localInfo.endpc] = bucket
end
append(bucket, localInfo)
else
append(finalLocals, localInfo)
end
end
local pc = 1
while pc <= instructionCount do
-- names of locals whose range ends just before this instruction come first
local earlyLocals = localsByEnd[pc - 1]
if earlyLocals then
for _, localInfo in ipairs(earlyLocals) do
useString(localInfo.name)
end
end
local hasAux = useInstructionStrings(func, pc)
pc += hasAux and 2 or 1
end
-- NOTE(review): localsByEnd only holds endpc < instructionCount (see above), so this lookup
-- appears to always be nil; locals ending at instructionCount are in finalLocals instead.
local earlyLocals = localsByEnd[instructionCount]
if earlyLocals then
for _, localInfo in ipairs(earlyLocals) do
useString(localInfo.name)
end
end
useString(func.debugname)
for _, upvalue in ipairs(func.debugUpvals) do
useString(upvalue.name)
end
for _, localInfo in ipairs(finalLocals) do
useString(localInfo.name)
end
for _, constant in ipairs(func.constants) do
if constant.tag == LBC_CONSTANT_STRING then
useString(constant.value)
end
end
end
-- Rewrite pass: replace every stored old id with its new id. Anything not reached above
-- (e.g. a local bucketed at a skipped AUX pc) gets an id here, in this order.
for _, func in ipairs(self.functions) do
for _, constant in ipairs(func.constants) do
if constant.tag == LBC_CONSTANT_STRING then
constant.value = useString(constant.value)
end
end
func.debugname = useString(func.debugname)
for _, localInfo in ipairs(func.debugLocals) do
localInfo.name = useString(localInfo.name)
end
for _, upvalue in ipairs(func.debugUpvals) do
upvalue.name = useString(upvalue.name)
end
end
else
-- Lower debug levels: assign and rewrite in one pass (constants, debug name, locals, upvalues).
for _, func in ipairs(self.functions) do
for _, constant in ipairs(func.constants) do
if constant.tag == LBC_CONSTANT_STRING then
constant.value = useString(constant.value)
end
end
func.debugname = useString(func.debugname)
for _, localInfo in ipairs(func.debugLocals) do
localInfo.name = useString(localInfo.name)
end
for _, upvalue in ipairs(func.debugUpvals) do
upvalue.name = useString(upvalue.name)
end
end
end
-- Keep strings that nothing referenced, appended after the used ones in original order.
for oldId = 1, #oldStrings do
useString(oldId)
end
self.strings = newStrings
self.stringMap = {}
for index, value in ipairs(newStrings) do
self.stringMap[value] = index
end
end
--- Serializes all functions into a complete bytecode blob and returns it as a string.
-- First folds jumps when optimizationLevel (default 1) is at least 1, then renumbers the string
-- table. Layout: version byte 9, type-encoding version byte 3, the string table (count, then
-- length-prefixed strings), a 0 byte, the function count and each function, and finally the
-- main function id.
function PureBytecodeBuilder:finalize()
	local shouldFold = (self.options.optimizationLevel or 1) >= 1
	if shouldFold then
		for _, func in ipairs(self.functions) do
			self:foldJumps(func)
		end
	end
	self:remapStringsForSerialization()
	local chunks = {}
	writeByteValue(chunks, 9)
	writeByteValue(chunks, 3)
	local strings = self.strings
	writeVarIntValue(chunks, #strings)
	for _, text in ipairs(strings) do
		writeVarIntValue(chunks, #text)
		append(chunks, text)
	end
	-- NOTE(review): presumably terminates an empty userdata type-name list of type encoding 3 — confirm.
	writeByteValue(chunks, 0)
	local functions = self.functions
	writeVarIntValue(chunks, #functions)
	for _, func in ipairs(functions) do
		self:writeFunction(chunks, func)
	end
	writeVarIntValue(chunks, self.mainFunction)
	return table.concat(chunks)
end
local PureParser = {}
PureParser.__index = PureParser
-- Token type of a single-character token: the byte code of its (first) character.
local function char(value)
	return byte(value)
end
--- Creates a parser over `source` with a comment-skipping lexer and loads the first token.
function PureParser.new(source)
	local nameTable = AstNameTable.new()
	local tokenStream = Lexer.new(source, #source, nameTable)
	tokenStream:setSkipComments(true)
	local parser = setmetatable({
		source = source,
		lexer = tokenStream,
		current = nil,
		typeAliases = {},
		typeGenerics = nil,
	}, PureParser)
	parser:advance()
	return parser
end
--- Returns the 1-based source line of the current token (1 when there is no current token).
function PureParser:line()
	local token = self.current
	if token then
		return token.location.begin.line + 1
	end
	return 1
end
--- Raises a parse error ":<line>: <message>" without position prefixing (error level 0).
function PureParser:error(message)
	local text = string.format(":%d: %s", self:line(), message)
	error(text, 0)
end
--- Moves to the next token from the lexer and returns it.
function PureParser:advance()
	local token = self.lexer:next(true, true)
	self.current = token
	return token
end
--- Returns whether the current token has type `tokenType`.
function PureParser:check(tokenType)
	return tokenType == self.current.type
end
--- Consumes and returns the current token if it has type `tokenType`; otherwise returns nil.
function PureParser:accept(tokenType)
	if not self:check(tokenType) then
		return nil
	end
	local consumed = self.current
	self:advance()
	return consumed
end
--- Like accept, but raises a parse error (custom `message` or "expected <type>") on mismatch.
function PureParser:expect(tokenType, message)
	local token = self:accept(tokenType)
	if token == nil then
		self:error(message or ("expected " .. tostring(tokenType)))
	end
	return token
end
-- Builds (or decorates) an AST node table: sets `kind` and `line` (default 1), and fills in
-- `endLine` from `line` when it is not already set. Mutates and returns `fields` when given.
local function node(kind, fields, line)
	local result = fields or {}
	result.kind = kind
	result.line = line or 1
	result.endLine = result.endLine or result.line
	return result
end
-- Feature switch for `i`-suffixed 64-bit integer literals.
local ENABLE_INTEGER_LITERALS = true
-- Dedup key for an integer literal: sign mark ("-" or "+"), then "hi:lo" of the magnitude.
local function integerKey(negative, hi, lo)
	local signMark = negative and "-" or "+"
	return string.format("%s%s:%s", signMark, tostring(hi), tostring(lo))
end
-- Parses `digits` in `base` (2, 10 or 16) into an unsigned 64-bit value split into 32-bit halves.
-- Returns hi, lo, or nil for an empty string, an invalid digit, or a value of 2^64 or more.
local function parseUnsignedInteger64(digits, base)
	if digits == "" then
		return nil
	end
	local hi, lo = 0, 0
	for position = 1, #digits do
		local code = byte(digits, position)
		local digit
		if code >= 48 and code <= 57 then -- '0'..'9'
			digit = code - 48
		elseif code >= 97 and code <= 102 then -- 'a'..'f'
			digit = code - 87
		elseif code >= 65 and code <= 70 then -- 'A'..'F'
			digit = code - 55
		else
			return nil
		end
		if digit >= base then
			return nil
		end
		-- (hi, lo) = (hi, lo) * base + digit, carrying overflow of the low word into the high word
		local scaledLo = lo * base + digit
		lo = scaledLo % TWO_32
		local scaledHi = hi * base + (scaledLo - lo) / TWO_32
		if scaledHi >= TWO_32 then
			return nil
		end
		hi = scaledHi
	end
	return hi, lo
end
-- Interprets an unsigned 64-bit (hi, lo) pair as a signed literal value {negative, hi, lo, key}
-- where (hi, lo) is the magnitude. Hex/binary literals with the top bit set are taken as two's
-- complement negatives; decimal literals with the top bit set are out of range and return nil.
local function unsignedToSignedInteger(hi, lo, base)
	local topBitSet = hi >= 0x80000000
	if topBitSet and base == 10 then
		return nil
	end
	local negative = false
	if topBitSet then
		negative = true
		-- magnitude = two's-complement negation: invert both words, then add 1 with carry
		hi = TWO_32 - 1 - hi
		lo = TWO_32 - 1 - lo + 1
		if lo >= TWO_32 then
			lo -= TWO_32
			hi += 1
		end
	end
	return {
		negative = negative,
		hi = hi,
		lo = lo,
		key = integerKey(negative, hi, lo),
	}
end
-- Returns a new integer-literal value with the sign flipped. Zero stays a non-negative zero, and
-- the minimum value (-2^63) negates to itself, as in two's-complement arithmetic.
local function negateIntegerLiteralValue(value)
	local hi, lo = value.hi, value.lo
	if hi == 0 and lo == 0 then
		return { negative = false, hi = 0, lo = 0, key = integerKey(false, 0, 0) }
	end
	if value.negative and hi == 0x80000000 and lo == 0 then
		return { negative = true, hi = hi, lo = lo, key = value.key }
	end
	local flipped = not value.negative
	return { negative = flipped, hi = hi, lo = lo, key = integerKey(flipped, hi, lo) }
end
-- Parses an `i`-suffixed integer literal (underscores ignored; optional 0x/0X or 0b/0B prefix)
-- into a signed integer value, or returns nil when the text is not a valid in-range literal.
local function parseIntegerLiteral(text)
	local cleaned = (string.gsub(text or "0", "_", ""))
	if string.sub(cleaned, -1) ~= "i" then
		return nil
	end
	local body = string.sub(cleaned, 1, -2)
	local prefix = string.lower(string.sub(body, 1, 2))
	local base = 10
	if prefix == "0x" then
		base = 16
		body = string.sub(body, 3)
	elseif prefix == "0b" then
		base = 2
		body = string.sub(body, 3)
	end
	local hi, lo = parseUnsignedInteger64(body, base)
	if hi == nil then
		return nil
	end
	return unsignedToSignedInteger(hi, lo, base)
end
-- Parses a numeric literal (underscores ignored) into {kind = "Integer" | "Number", value}.
-- `i`-suffixed literals become Integer values (when enabled); 0b/0B literals are read as binary;
-- everything else goes through tonumber. Returns nil for malformed text.
local function parseNumberLiteral(text)
	local cleaned = (string.gsub(text or "0", "_", ""))
	if string.sub(cleaned, -1) == "i" then
		if not ENABLE_INTEGER_LITERALS then
			return nil
		end
		local integerValue = parseIntegerLiteral(cleaned)
		if integerValue == nil then
			return nil
		end
		return { kind = "Integer", value = integerValue }
	end
	if string.lower(string.sub(cleaned, 1, 2)) == "0b" then
		if #cleaned <= 2 then
			return nil
		end
		local accumulated = 0
		for position = 3, #cleaned do
			local digitChar = string.sub(cleaned, position, position)
			if digitChar == "1" then
				accumulated = accumulated * 2 + 1
			elseif digitChar == "0" then
				accumulated = accumulated * 2
			else
				return nil
			end
		end
		return { kind = "Number", value = accumulated }
	end
	local numeric = tonumber(cleaned)
	if numeric == nil then
		return nil
	end
	return { kind = "Number", value = numeric }
end
-- Binary operators keyed by token type. Each entry is { source text, precedence, right-associative }.
-- Presumably a higher precedence binds tighter (or = 1 ... ^ = 8); `..` and `^` are right-associative.
local binaryInfo
do
	local function row(symbol, precedence, rightAssociative)
		return { symbol, precedence, rightAssociative }
	end
	binaryInfo = {
		[Lexeme.Type.ReservedOr] = row("or", 1, false),
		[Lexeme.Type.ReservedAnd] = row("and", 2, false),
		[Lexeme.Type.Equal] = row("==", 3, false),
		[Lexeme.Type.NotEqual] = row("~=", 3, false),
		[Lexeme.Type.LessEqual] = row("<=", 3, false),
		[Lexeme.Type.GreaterEqual] = row(">=", 3, false),
		[char("<")] = row("<", 3, false),
		[char(">")] = row(">", 3, false),
		[Lexeme.Type.Dot2] = row("..", 4, true),
		[char("+")] = row("+", 5, false),
		[char("-")] = row("-", 5, false),
		[char("*")] = row("*", 6, false),
		[char("/")] = row("/", 6, false),
		[Lexeme.Type.FloorDiv] = row("//", 6, false),
		[char("%")] = row("%", 6, false),
		[char("^")] = row("^", 8, true),
	}
end
--- Describes `token` (default: the current token) for error messages: "<eof>", a quoted
-- character for single-character tokens, the identifier text for names (or "<name>"), and the
-- token's own toString() otherwise.
function PureParser:tokenName(token)
	local subject = token or self.current
	local kind = subject.type
	if kind == Lexeme.Type.Eof then
		return "<eof>"
	elseif kind > 0 and kind < Lexeme.Type.Char_END then
		return string.format("'%s'", string.char(kind))
	elseif kind == Lexeme.Type.Name then
		return subject.name or subject.data or "<name>"
	end
	return subject:toString()
end
function PureParser:parseName(message)
if not self:check(Lexeme.Type.Name) then
self:error((message or "Expected identifier") .. ", got " .. self:tokenName())
end
local token = self.current
self:advance()
return token.name or token.data
end
-- True when the current token is an identifier spelled `value`; used for
-- contextual keywords such as `type`. Does not consume the token.
function PureParser:isName(value)
	local isIdentifier = self:check(Lexeme.Type.Name)
	if not isIdentifier then
		return isIdentifier
	end
	local current = self.current
	return current.name == value or current.data == value
end
-- Skips a delimited region: expects `openType`, then consumes tokens up to and
-- including the matching `closeType`, honouring nesting of the same pair.
-- Reaching EOF first is a parse error.
function PureParser:consumeBalanced(openType, closeType)
	self:expect(openType)
	local nesting = 1
	repeat
		if self:check(Lexeme.Type.Eof) then
			self:error("Expected closing delimiter in type annotation")
		elseif self:check(openType) then
			nesting += 1
		elseif self:check(closeType) then
			nesting -= 1
		end
		self:advance()
	until nesting == 0
end
-- Consumes a `<...>` generic parameter list if one starts at the current token.
-- Returns nil when there is none; otherwise returns the top-level parameter
-- names in order (for example {"T", "U"} for `<T, U...>`). Anything nested in
-- inner angle brackets, and any default after a name, is skipped.
function PureParser:skipGenericTypeList()
	if not self:accept(char("<")) then
		return nil
	end
	local collected = {}
	local nesting = 1
	local awaitingName = true
	while nesting > 0 do
		local atTopLevel = nesting == 1
		if self:check(Lexeme.Type.Eof) then
			self:error("Expected '>' after generic parameters")
		elseif self:check(char("<")) then
			nesting += 1
			self:advance()
		elseif self:check(char(">")) then
			nesting -= 1
			awaitingName = false
			self:advance()
		elseif atTopLevel and awaitingName and self:check(Lexeme.Type.Name) then
			append(collected, self.current.name or self.current.data)
			awaitingName = false
			self:advance()
		elseif atTopLevel and self:check(char(",")) then
			-- the next top-level identifier starts a new parameter
			awaitingName = true
			self:advance()
		else
			self:advance()
		end
	end
	return collected
end
-- Marks a bytecode type id as optional (`T?`) by adding LBC_TYPE_OPTIONAL_BIT.
-- These are returned unchanged:
--   * nil
--   * the nil type
--   * `any`
--   * ids that already carry the optional bit
local function addOptionalBytecodeType(typeId)
	if typeId == nil or typeId == _LBC_TYPE_NIL or typeId == LBC_TYPE_ANY then
		return typeId
	end
	if typeId >= LBC_TYPE_OPTIONAL_BIT then
		return typeId
	end
	return typeId + LBC_TYPE_OPTIONAL_BIT
end
-- Inverse of addOptionalBytecodeType: returns (baseTypeId, isOptional).
-- nil and non-optional ids come back unchanged with isOptional = false.
local function splitOptionalBytecodeType(typeId)
	local isOptional = typeId ~= nil and typeId >= LBC_TYPE_OPTIONAL_BIT
	if isOptional then
		return typeId - LBC_TYPE_OPTIONAL_BIT, true
	end
	return typeId, false
end
-- Combines two bytecode type ids for a `A | B` union:
--   * a missing side (nil) yields the other side;
--   * `T | nil` (either order) yields optional T;
--   * identical bases keep that base, optional if either side was;
--   * any other mix collapses to `any`.
local function unionBytecodeTypes(leftType, rightType)
	if leftType == nil then
		return rightType
	end
	if rightType == nil then
		return leftType
	end

	local leftBase, leftOptional = splitOptionalBytecodeType(leftType)
	local rightBase, rightOptional = splitOptionalBytecodeType(rightType)
	local leftIsNil = leftBase == _LBC_TYPE_NIL
	local rightIsNil = rightBase == _LBC_TYPE_NIL

	if not leftIsNil and not rightIsNil and leftBase ~= rightBase then
		return LBC_TYPE_ANY
	end

	local base = leftBase
	if leftIsNil then
		base = rightBase
	end
	if leftOptional or rightOptional or leftIsNil or rightIsNil then
		return addOptionalBytecodeType(base)
	end
	return base
end
-- Parses one type "atom" from a type annotation and returns the coarse bytecode
-- type id (LBC_TYPE_*) used for type info. The annotation itself is otherwise
-- discarded. Mapping:
--   * singleton true/false -> boolean; string singletons -> string;
--     numbers -> number;
--   * `nil` and simple names are looked up in the current generic scope (-> any),
--     then self.typeAliases, then simpleTypeNameToBytecode (default userdata);
--   * qualified names (`Mod.Type`) and `typeof(...)` -> any;
--   * `{ ... }` -> table;
--   * `( ... )` is either a grouping (single slot, no arrow) or a function
--     signature (`-> R`) -> function;
--   * `<T...>( ... ) -> R` generic function types -> function;
--   * `function(...)` forms -> function.
-- Any other token is consumed and reported as `any`.
function PureParser:parseTypeAtom()
	if self:check(Lexeme.Type.Name)
		or self:check(Lexeme.Type.ReservedNil)
		or self:check(Lexeme.Type.ReservedTrue)
		or self:check(Lexeme.Type.ReservedFalse)
		or self:check(Lexeme.Type.QuotedString)
		or self:check(Lexeme.Type.RawString)
		or self:check(Lexeme.Type.Number)
	then
		local name = nil
		local typeId = nil
		if self:check(Lexeme.Type.ReservedNil) then
			name = "nil"
		elseif self:check(Lexeme.Type.ReservedTrue) or self:check(Lexeme.Type.ReservedFalse) then
			typeId = LBC_TYPE_BOOLEAN
		elseif self:check(Lexeme.Type.QuotedString) or self:check(Lexeme.Type.RawString) then
			typeId = LBC_TYPE_STRING
		elseif self:check(Lexeme.Type.Number) then
			typeId = LBC_TYPE_NUMBER
		else
			name = self.current.name or self.current.data
		end
		self:advance()
		-- Qualified names (`Module.Type`) refer to types that cannot be resolved here.
		local hasPrefix = false
		while self:accept(char(".")) do
			hasPrefix = true
			self:parseName()
		end
		if self:check(char("<")) then
			self:skipGenericTypeList()
		end
		if name == "typeof" and self:check(char("(")) then
			self:consumeBalanced(char("("), char(")"))
			return LBC_TYPE_ANY
		end
		if typeId ~= nil then
			return typeId
		end
		if hasPrefix then
			return LBC_TYPE_ANY
		end
		if self.typeGenerics and self.typeGenerics[name] then
			return LBC_TYPE_ANY
		end
		if self.typeAliases[name] ~= nil then
			return self.typeAliases[name]
		end
		return simpleTypeNameToBytecode[name] or LBC_TYPE_USERDATA
	elseif self:check(char("{")) then
		self:consumeBalanced(char("{"), char("}"))
		return LBC_TYPE_TABLE
	elseif self:check(char("(")) then
		self:advance()
		local result = LBC_TYPE_ANY
		local multiple = false
		-- One slot of a parenthesised list: `name: T`, `...T`, `...: T`, or plain `T`.
		local function parseFunctionTypeSlot()
			if self:check(Lexeme.Type.Name) and self.lexer:lookahead().type == char(":") then
				self:advance()
				self:advance()
				return self:parseType()
			elseif self:accept(Lexeme.Type.Dot3) then
				if self:accept(char(":")) then
					return self:parseType()
				elseif self:check(Lexeme.Type.Name) or self:check(char("(")) or self:check(char("{")) then
					return self:parseTypeAtom()
				end
				return LBC_TYPE_ANY
			end
			return self:parseType()
		end
		if not self:check(char(")")) then
			result = parseFunctionTypeSlot()
			while self:accept(char(",")) do
				multiple = true
				parseFunctionTypeSlot()
			end
		end
		self:expect(char(")"), "Expected ')' in type annotation")
		if self:accept(Lexeme.Type.SkinnyArrow) then
			self:parseType()
			return _LBC_TYPE_FUNCTION
		end
		return multiple and LBC_TYPE_ANY or result
	elseif self:check(char("<")) then
		-- Generic function type such as `<T>(T) -> T`. Without this branch the
		-- `<` would hit the catch-all below, which consumes only that token and
		-- leaves the rest of the signature to derail the statement parser.
		-- The declared names are scoped to the signature so they resolve to `any`.
		local outerGenerics = self.typeGenerics
		local scopedGenerics = {}
		if outerGenerics then
			for genericName in pairs(outerGenerics) do
				scopedGenerics[genericName] = true
			end
		end
		for _, genericName in ipairs(self:skipGenericTypeList() or {}) do
			scopedGenerics[genericName] = true
		end
		self.typeGenerics = scopedGenerics
		local signatureType = self:parseTypeAtom()
		self.typeGenerics = outerGenerics
		return signatureType
	elseif self:accept(Lexeme.Type.Dot3) then
		if self:check(Lexeme.Type.Name) or self:check(char("(")) or self:check(char("{")) then
			return self:parseTypeAtom()
		end
	elseif self:accept(Lexeme.Type.ReservedFunction) then
		if self:check(char("(")) then
			self:consumeBalanced(char("("), char(")"))
		end
		if self:accept(Lexeme.Type.SkinnyArrow) or self:accept(char(":")) then
			self:parseType()
		end
		return _LBC_TYPE_FUNCTION
	else
		self:advance()
	end
	return LBC_TYPE_ANY
end
-- Parses a full type annotation and returns its coarse bytecode type id.
-- Suffix and infix forms are folded left to right after the first atom:
--   `T?`      -> optional T
--   `T...`    -> variadic pack suffix (type unchanged)
--   `A | B`   -> unionBytecodeTypes(A, B)
--   `A & B`   -> any
--   `A -> B`  -> function
-- Luau allows a leading separator on multi-line unions and intersections
-- (`type T = | A | B`). The separator carries no meaning, so it is consumed
-- before the first atom. Without this, it was swallowed as an unknown atom and
-- the real first member was left behind to break statement parsing. Because
-- union members recurse through this function, a doubled `| |` is also
-- tolerated, in keeping with this parser's lenient handling of types.
function PureParser:parseType()
	if not self:accept(char("|")) then
		self:accept(char("&"))
	end
	local typeId = self:parseTypeAtom()
	while true do
		if self:accept(char("?")) then
			typeId = addOptionalBytecodeType(typeId)
		elseif self:accept(Lexeme.Type.Dot3) then
			-- variadic type pack suffix
		elseif self:accept(char("|")) then
			typeId = unionBytecodeTypes(typeId, self:parseType())
		elseif self:accept(char("&")) then
			self:parseType()
			typeId = LBC_TYPE_ANY
		elseif self:accept(Lexeme.Type.SkinnyArrow) then
			self:parseType()
			typeId = _LBC_TYPE_FUNCTION
		else
			break
		end
	end
	return typeId
end
-- Consumes a single type atom, discarding the bytecode type id it maps to.
function PureParser:skipTypeAtom()
	local _ = self:parseTypeAtom()
end
-- Consumes a full type annotation, discarding the bytecode type id it maps to.
function PureParser:skipType()
	local _ = self:parseType()
end
-- Parses an optional `: Type` annotation. Returns its bytecode type id, or nil
-- when no `:` follows.
function PureParser:skipTypeAnnotation()
	if not self:accept(char(":")) then
		return nil
	end
	return self:parseType()
end
-- Consumes any run of attribute tokens (Lexeme.Type.Attribute or
-- Lexeme.Type.AttributeOpen) in front of a function or statement.
-- Returns nil when no attribute is present. Otherwise returns
--   { line = <line of the first attribute>, native = <true if "native" was seen> }.
-- Only "native" is recorded; any other attribute name is consumed and ignored.
-- NOTE(review): after advancing past the attribute token this expects either a
-- `[`-delimited list or a separate Name token, i.e. it assumes the lexer emits
-- the attribute name as its own token. Upstream Luau's lexer appears to fold
-- the name into the Attribute lexeme (and `@[` into AttributeOpen); confirm
-- this port's lexer produces the shape assumed here.
function PureParser:skipAttributes()
	local attributes = nil
	while self:check(Lexeme.Type.AttributeOpen) or self:check(Lexeme.Type.Attribute) do
		local line = self:line()
		if not attributes then
			-- Only the first attribute's line is kept; callers use it as the node line.
			attributes = {
				line = line,
				native = false,
			}
		end
		self:advance()
		if self:accept(char("[")) then
			-- Bracketed attribute list: skip to the matching `]`, noting `native`.
			local depth = 1
			while depth > 0 do
				if self:check(Lexeme.Type.Eof) then
					self:error("Expected ']' after attribute list")
				elseif self:check(char("[")) then
					depth += 1
					self:advance()
				elseif self:check(char("]")) then
					depth -= 1
					self:advance()
				else
					if self:check(Lexeme.Type.Name) and (self.current.name or self.current.data) == "native" then
						attributes.native = true
					end
					self:advance()
				end
			end
		else
			-- Single attribute: its name is read as a separate identifier token.
			local name = self:parseName("Expected attribute name")
			if name == "native" then
				attributes.native = true
			end
		end
	end
	return attributes
end
-- Skips the body of a `type function ... end` declaration without parsing it,
-- consuming everything through the type function's own closing `end`.
-- Block nesting is tracked by keyword:
--   * `function`, `do`, `if`, and `repeat` open a block;
--   * `end` and `until` close one.
-- `while` and `for` are deliberately NOT counted. Their body is opened by the
-- `do` that follows the loop header, and that `do` already increments the
-- depth. Counting both would demand two `end`s per loop and swallow the rest
-- of the file until EOF.
-- Depth starts at 1, standing for the type function's own `end`.
function PureParser:skipTypeFunctionBody()
	local depth = 1
	while depth > 0 do
		if self:check(Lexeme.Type.Eof) then
			self:error("Expected 'end' after type function")
		elseif self:check(Lexeme.Type.ReservedFunction)
			or self:check(Lexeme.Type.ReservedDo)
			or self:check(Lexeme.Type.ReservedIf)
			or self:check(Lexeme.Type.ReservedRepeat)
		then
			depth += 1
			self:advance()
		elseif self:check(Lexeme.Type.ReservedEnd) or self:check(Lexeme.Type.ReservedUntil) then
			depth -= 1
			self:advance()
		else
			self:advance()
		end
	end
end
-- Parses the rest of a `type` / `export type` statement, after the leading
-- keyword(s). Type declarations emit no bytecode, so this always returns a
-- Nop node at `line`. Two forms are handled:
--   type function Name(...) ... end   (body skipped token-by-token)
--   type Name<...> = T [;]
function PureParser:parseTypeDeclaration(line)
	local isTypeFunction = self:accept(Lexeme.Type.ReservedFunction)
	if not isTypeFunction then
		self:parseName("Expected type name")
		if self:check(char("<")) then
			self:skipGenericTypeList()
		end
		self:expect(char("="), "Expected '=' after type name")
		self:skipType()
		self:accept(char(";"))
		return node("Nop", nil, line)
	end

	self:parseName("Expected type function name")
	if self:check(char("(")) then
		self:consumeBalanced(char("("), char(")"))
	end
	-- Legacy/experimental forms can put an optional `=` before the body.
	self:accept(char("="))
	self:skipTypeFunctionBody()
	return node("Nop", nil, line)
end
-- Skips an explicit type instantiation `<<T, U>>` (as in `f<<number>>(x)`).
-- Returns false without consuming anything unless the next two tokens are
-- `<` `<`; otherwise consumes through the closing `>` `>` and returns true.
-- Single angle brackets inside the argument list (e.g. `Array<number>`) are
-- counted in `openAngles`, so only a `>>` at nesting zero closes the list.
-- Counting only `<<`/`>>` pairs, as before, mis-closed `f<<Array<number>>>()`
-- on the inner `>>` and left a stray `>` that then parsed as a comparison.
function PureParser:skipExplicitTypeArguments()
	if not (self:check(char("<")) and self.lexer:lookahead().type == char("<")) then
		return false
	end
	self:advance()
	self:advance()
	local openAngles = 0
	while true do
		if self:check(Lexeme.Type.Eof) then
			self:error("Expected '>>' after type arguments")
		elseif self:check(char("<")) then
			openAngles += 1
			self:advance()
		elseif self:check(char(">")) then
			if openAngles > 0 then
				openAngles -= 1
				self:advance()
			elseif self.lexer:lookahead().type == char(">") then
				self:advance()
				self:advance()
				return true
			else
				-- unmatched single `>`: skipped, as any other token would be
				self:advance()
			end
		else
			self:advance()
		end
	end
end
-- Parses a function body starting just after `function` (and its name, if any).
-- The body consists of:
--   * optional generic parameters `<T...>`;
--   * the parameter list, with optional `: Type` annotations and a trailing `...`;
--   * an optional return type annotation;
--   * the statement block and the closing `end`.
-- Generic names declared here are added to self.typeGenerics (so annotations
-- resolve them to `any`) while parsing the body, then the outer set is restored.
-- Returns a "Function" node. `native` is taken from the attribute set, if given.
function PureParser:parseFunctionBody(line, attributes)
	local enclosingGenerics = self.typeGenerics
	if self:check(char("<")) then
		local declared = self:skipGenericTypeList()
		local visible = {}
		for genericName in pairs(enclosingGenerics or {}) do
			visible[genericName] = true
		end
		for _, genericName in ipairs(declared or {}) do
			visible[genericName] = true
		end
		self.typeGenerics = visible
	end

	self:expect(char("("), "Expected '(' when parsing function")
	local params, paramTypes = {}, {}
	local isvararg = false
	if not self:check(char(")")) then
		repeat
			if self:accept(Lexeme.Type.Dot3) then
				-- `...` must be last; its annotation is parsed and dropped.
				isvararg = true
				self:skipTypeAnnotation()
				break
			end
			append(params, self:parseName())
			-- paramTypes is indexed by parameter position and may be sparse.
			paramTypes[#params] = self:skipTypeAnnotation()
		until not self:accept(char(","))
	end
	self:expect(char(")"), "Expected ')' after function parameters")

	local returnType = self:skipTypeAnnotation()
	local body = self:parseBlock({ [Lexeme.Type.ReservedEnd] = true })
	local endLine = self:line()
	self:expect(Lexeme.Type.ReservedEnd, "Expected 'end' after function body")
	self.typeGenerics = enclosingGenerics

	return node("Function", {
		params = params,
		paramTypes = paramTypes,
		returnType = returnType,
		isvararg = isvararg,
		body = body,
		endLine = endLine,
		native = attributes and attributes.native or false,
	}, line)
end
-- Parses the rest of an if-expression after the `if` keyword:
--   cond then value {elseif cond then value} else value
-- The `else` branch is mandatory. Returns an "IfExpr" node whose `clauses` are
-- { condition, value } pairs in source order.
function PureParser:parseIfExpression(line)
	local clauses = {}
	local thenMessage = "Expected 'then' after if expression condition"
	repeat
		local condition = self:parseExpression()
		self:expect(Lexeme.Type.ReservedThen, thenMessage)
		local value = self:parseExpression()
		append(clauses, { condition = condition, value = value })
		thenMessage = "Expected 'then' after elseif expression condition"
	until not self:accept(Lexeme.Type.ReservedElseif)

	self:expect(Lexeme.Type.ReservedElse, "Expected 'else' after if expression")
	local elseValue = self:parseExpression()
	return node("IfExpr", { clauses = clauses, elseValue = elseValue }, line)
end
-- Forces `line` onto `expr` and, recursively, onto every sub-expression it
-- contains. A nil `line` leaves existing lines untouched.
-- Function expressions only get their own `line`; their bodies are statements
-- and are not visited.
-- Non-table values are ignored. This matters for table constructors: a
-- `{name = v}` entry stores its key as a plain string (see parsePrimary).
-- Recursing into that key previously tried to assign `.line` on a string and
-- raised an error.
function PureParser:setExprLine(expr, line)
	if type(expr) ~= "table" then
		return
	end
	expr.line = line or expr.line
	local kind = expr.kind
	if kind == "Field" then
		self:setExprLine(expr.object, line)
	elseif kind == "Index" then
		self:setExprLine(expr.object, line)
		self:setExprLine(expr.index, line)
	elseif kind == "Call" then
		self:setExprLine(expr.callee, line)
		for _, arg in ipairs(expr.args) do
			self:setExprLine(arg, line)
		end
	elseif kind == "MethodCall" then
		self:setExprLine(expr.object, line)
		for _, arg in ipairs(expr.args) do
			self:setExprLine(arg, line)
		end
	elseif kind == "Table" then
		for _, entry in ipairs(expr.entries) do
			-- string keys of `field` entries are skipped by the type guard
			self:setExprLine(entry.key, line)
			self:setExprLine(entry.value, line)
		end
	elseif kind == "Un" or kind == "SingleResult" or kind == "Instantiate" then
		-- all three wrap a single operand in `.expr`
		self:setExprLine(expr.expr, line)
	elseif kind == "Bin" then
		self:setExprLine(expr.left, line)
		self:setExprLine(expr.right, line)
	elseif kind == "IfExpr" then
		for _, clause in ipairs(expr.clauses) do
			self:setExprLine(clause.condition, line)
			self:setExprLine(clause.value, line)
		end
		self:setExprLine(expr.elseValue, line)
	elseif kind == "InterpString" then
		for _, value in ipairs(expr.expressions) do
			self:setExprLine(value, line)
		end
	end
end
-- Shifts every line number recorded on `expr` (line, endLine, opLine) by
-- `delta`, recursing into sub-expressions. parseExpressionFromSource uses this
-- to relocate an interpolation expression parsed from a substring back to its
-- real position in the file.
-- SingleResult (`(f())`, `(...)`) and Instantiate (`f<<T>>`) wrappers are
-- descended into, so parenthesised calls inside `{...}` get shifted too.
-- Non-table values are ignored; string keys of `{name = v}` entries are one case.
-- NOTE(review): Function expressions only have their own line/endLine shifted;
-- statements inside their bodies are not. Confirm whether that matters for
-- functions written inside interpolations.
function PureParser:offsetExprLine(expr, delta)
	if type(expr) ~= "table" or delta == 0 then
		return
	end
	if expr.line then
		expr.line += delta
	end
	if expr.endLine then
		expr.endLine += delta
	end
	if expr.opLine then
		expr.opLine += delta
	end
	local kind = expr.kind
	if kind == "Field" then
		self:offsetExprLine(expr.object, delta)
	elseif kind == "Index" then
		self:offsetExprLine(expr.object, delta)
		self:offsetExprLine(expr.index, delta)
	elseif kind == "Call" then
		self:offsetExprLine(expr.callee, delta)
		for _, arg in ipairs(expr.args) do
			self:offsetExprLine(arg, delta)
		end
	elseif kind == "MethodCall" then
		self:offsetExprLine(expr.object, delta)
		for _, arg in ipairs(expr.args) do
			self:offsetExprLine(arg, delta)
		end
	elseif kind == "Table" then
		for _, entry in ipairs(expr.entries) do
			self:offsetExprLine(entry.key, delta)
			self:offsetExprLine(entry.value, delta)
		end
	elseif kind == "Un" or kind == "SingleResult" or kind == "Instantiate" then
		-- all three wrap a single operand in `.expr`
		self:offsetExprLine(expr.expr, delta)
	elseif kind == "Bin" then
		self:offsetExprLine(expr.left, delta)
		self:offsetExprLine(expr.right, delta)
	elseif kind == "IfExpr" then
		for _, clause in ipairs(expr.clauses) do
			self:offsetExprLine(clause.condition, delta)
			self:offsetExprLine(clause.value, delta)
		end
		self:offsetExprLine(expr.elseValue, delta)
	elseif kind == "InterpString" then
		for _, value in ipairs(expr.expressions) do
			self:offsetExprLine(value, delta)
		end
	end
end
-- Parses `source` (the text inside one `{...}` of an interpolated string) as a
-- standalone expression with a fresh parser, which must consume all input.
-- Line numbers are then shifted so that line 1 of `source` maps to `line`
-- (default 1).
function PureParser:parseExpressionFromSource(source, line)
	local subParser = PureParser.new(source)
	local parsed = subParser:parseExpression()
	subParser:expect(Lexeme.Type.Eof, "Expected end of interpolation expression")
	local firstLine = line or 1
	subParser:offsetExprLine(parsed, firstLine - 1)
	return parsed
end
-- Splits the raw contents of an interpolated string (`data`, taken from an
-- InterpStringEnd token) into literal segments and embedded `{...}` expressions.
-- Returns an "InterpString" node:
--   strings     - literal text around each `{...}`, passed through
--                 Lexer.fixupQuotedString; always #expressions + 1 entries
--   expressions - the parsed embedded expressions, in order
-- A backslash skips the following character, so `\{` does not open an expression.
function PureParser:parseInterpolatedString(data, line)
	local strings = {}
	local expressions = {}
	local index = 1
	local textStart = 1
	local scanBracedExpression
	-- Scans a quoted string whose opening quote precedes `startIndex`; returns the
	-- index just past the closing quote (or past the end of `data`). Nested
	-- backtick strings have their own `{...}` holes scanned recursively.
	local function scanQuoted(startIndex, quote)
		local scanIndex = startIndex
		while scanIndex <= #data do
			local ch = string.sub(data, scanIndex, scanIndex)
			if ch == "\\" then
				scanIndex += 2
			elseif quote == "`" and ch == "{" then
				scanIndex = scanBracedExpression(scanIndex + 1)
			elseif ch == quote then
				return scanIndex + 1
			else
				scanIndex += 1
			end
		end
		return scanIndex
	end
	-- Scans from just after a `{` to its matching `}`, skipping nested braces and
	-- quoted strings; returns the index just past that `}`.
	scanBracedExpression = function(startIndex)
		local scanIndex = startIndex
		local depth = 1
		while scanIndex <= #data and depth > 0 do
			local ch = string.sub(data, scanIndex, scanIndex)
			if ch == "\\" then
				scanIndex += 2
			elseif ch == "'" or ch == '"' or ch == "`" then
				scanIndex = scanQuoted(scanIndex + 1, ch)
			elseif ch == "{" then
				depth += 1
				scanIndex += 1
			elseif ch == "}" then
				depth -= 1
				scanIndex += 1
			else
				scanIndex += 1
			end
		end
		return scanIndex
	end
	-- Emits data[textStart..untilIndex] as a literal segment (possibly empty).
	local function appendText(untilIndex)
		local text = untilIndex >= textStart and string.sub(data, textStart, untilIndex) or ""
		append(strings, Lexer.fixupQuotedString(text))
	end
	while index <= #data do
		local ch = string.sub(data, index, index)
		if ch == "\\" then
			index += 2
		elseif ch == "{" then
			appendText(index - 1)
			local exprStart = index + 1
			index = scanBracedExpression(exprStart)
			-- `index` is one past the closing `}`, so `index - 2` excludes it.
			local exprSource = string.sub(data, exprStart, index - 2)
			-- Newlines before the expression give its starting line within the file.
			local _, prefixNewlines = string.gsub(string.sub(data, 1, exprStart - 1), "\n", "")
			append(expressions, self:parseExpressionFromSource(exprSource, line + prefixNewlines))
			textStart = index
		else
			index += 1
		end
	end
	appendText(#data)
	return node("InterpString", {
		strings = strings,
		expressions = expressions,
	}, line)
end
-- Parses a primary expression and returns its AST node. Handled forms:
--   * literals: nil, true/false, numbers, strings, interpolated strings;
--   * identifiers and parenthesised expressions;
--   * table constructors;
--   * anonymous functions (optionally preceded by attributes);
--   * if-expressions and `...`.
-- Suffixes (field access, indexing, calls) are handled by parsePrefix.
function PureParser:parsePrimary()
	-- Captured before consuming anything: literal branches read token.data.
	local token = self.current
	local line = self:line()
	local attributes = self:skipAttributes()
	if attributes then
		-- Attributes may only prefix a function expression.
		line = attributes.line or line
		if self:accept(Lexeme.Type.ReservedFunction) then
			return self:parseFunctionBody(line, attributes)
		end
		self:error("Expected 'function' declaration after attribute")
	end
	if self:accept(Lexeme.Type.ReservedNil) then
		return node("Nil", nil, line)
	elseif self:accept(Lexeme.Type.ReservedTrue) then
		return node("Bool", { value = true }, line)
	elseif self:accept(Lexeme.Type.ReservedFalse) then
		return node("Bool", { value = false }, line)
	elseif self:accept(Lexeme.Type.Number) then
		-- Node kind is "Number" or "Integer", as decided by parseNumberLiteral.
		local parsed = parseNumberLiteral(token.data)
		if parsed == nil then
			self:error("Malformed number")
		end
		return node(parsed.kind, {
			value = parsed.value,
			text = token.data,
		}, line)
	elseif self:accept(Lexeme.Type.QuotedString) or self:accept(Lexeme.Type.RawString) then
		return node("String", { value = token.data or "" }, line)
	elseif self:accept(Lexeme.Type.InterpStringSimple) then
		-- An interpolated string with no `{...}` holes is just a string constant.
		return node("String", { value = Lexer.fixupQuotedString(token.data or "") }, line)
	elseif self:accept(Lexeme.Type.InterpStringEnd) then
		return self:parseInterpolatedString(token.data or "", line)
	elseif self:accept(Lexeme.Type.Name) then
		return node("Name", { name = token.name or token.data }, line)
	elseif self:accept(char("(")) then
		local expr = self:parseExpression()
		self:expect(char(")"), "Expected ')' after expression")
		-- Parentheses truncate multi-value expressions to exactly one value.
		if expr.kind == "Call" or expr.kind == "MethodCall" or expr.kind == "Vararg" then
			return node("SingleResult", {
				expr = expr,
			}, line)
		end
		expr.parenthesized = true
		return expr
	elseif self:accept(char("{")) then
		-- Table constructor. Entry kinds:
		--   "field" - `name = v`; key is the identifier as a plain string
		--   "index" - `[k] = v`; key is an expression node
		--   "array" - positional value, no key
		local entries = {}
		if not self:check(char("}")) then
			repeat
				if self:check(Lexeme.Type.Name) and self.lexer:lookahead().type == char("=") then
					local key = self:parseName()
					self:expect(char("="), "Expected '=' after table key")
					append(entries, {
						kind = "field",
						key = key,
						value = self:parseExpression(),
					})
				elseif self:accept(char("[")) then
					local key = self:parseExpression()
					self:expect(char("]"), "Expected ']' after table key")
					self:expect(char("="), "Expected '=' after table key")
					append(entries, {
						kind = "index",
						key = key,
						value = self:parseExpression(),
					})
				else
					append(entries, {
						kind = "array",
						value = self:parseExpression(),
					})
				end
				-- `,` or `;` separate entries; a trailing separator before `}` is allowed.
			until not (self:accept(char(",")) or self:accept(char(";"))) or self:check(char("}"))
		end
		self:expect(char("}"), "Expected '}' after table constructor")
		return node("Table", { entries = entries }, line)
	elseif self:accept(Lexeme.Type.ReservedFunction) then
		return self:parseFunctionBody(line)
	elseif self:accept(Lexeme.Type.ReservedIf) then
		return self:parseIfExpression(line)
	elseif self:accept(Lexeme.Type.Dot3) then
		return node("Vararg", nil, line)
	end
	self:error("Expected expression")
	return nil
end
-- Parses call arguments in one of the three Lua call forms:
--   (a, b, ...)     - parenthesised, possibly empty, list
--   "str" / `str`   - a single string literal (plain or interpolated)
--   { ... }         - a single table constructor
-- Returns the argument expressions as a list.
function PureParser:parseArgs()
	local args = {}
	if self:accept(char("(")) then
		if not self:check(char(")")) then
			repeat
				append(args, self:parseExpression())
			until not self:accept(char(","))
		end
		self:expect(char(")"), "Expected ')' after call arguments")
		return args
	end

	local isStringArgument = self:check(Lexeme.Type.QuotedString)
		or self:check(Lexeme.Type.RawString)
		or self:check(Lexeme.Type.InterpStringSimple)
		or self:check(Lexeme.Type.InterpStringEnd)
	if isStringArgument or self:check(char("{")) then
		append(args, self:parsePrimary())
	else
		self:error("Expected call arguments")
	end
	return args
end
-- Parses a primary expression followed by any chain of suffixes:
--   .name        -> "Field"
--   [expr]       -> "Index"
--   :name(args)  -> "MethodCall" (explicit `<<T>>` type arguments are skipped)
--   <<T>>        -> "Instantiate"
--   args         -> "Call" (parenthesised, string, or table argument)
-- Only names and parenthesised expressions may take suffixes. Other primaries
-- (literals, tables, functions, ...) are returned as-is.
-- Each suffix node records the line of the suffix's first token.
function PureParser:parsePrefix()
	local canHaveSuffix = self:check(Lexeme.Type.Name) or self:check(char("("))
	local expr = self:parsePrimary()
	if not canHaveSuffix then
		return expr
	end
	while true do
		local line = self:line()
		if self:accept(char(".")) then
			expr = node("Field", {
				object = expr,
				field = self:parseName(),
			}, line)
		elseif self:accept(char("[")) then
			local index = self:parseExpression()
			self:expect(char("]"), "Expected ']' after index")
			expr = node("Index", {
				object = expr,
				index = index,
			}, line)
		elseif self:accept(char(":")) then
			local method = self:parseName()
			self:skipExplicitTypeArguments()
			expr = node("MethodCall", {
				object = expr,
				method = method,
				args = self:parseArgs(),
			}, line)
		elseif self:skipExplicitTypeArguments() then
			-- Type arguments are discarded; the node just wraps the instantiated expression.
			expr = node("Instantiate", {
				expr = expr,
			}, line)
		elseif self:check(char("(")) or self:check(Lexeme.Type.QuotedString) or self:check(Lexeme.Type.RawString) or self:check(Lexeme.Type.InterpStringSimple) or self:check(Lexeme.Type.InterpStringEnd) or self:check(char("{")) then
			expr = node("Call", {
				callee = expr,
				args = self:parseArgs(),
			}, line)
		else
			break
		end
	end
	return expr
end
-- Parses a unary expression (`not e`, `-e`, `#e`) or, when no prefix operator
-- is present, a prefix expression.
-- The operand is parsed at precedence 7, so only `^` (precedence 8) binds
-- tighter: `-x ^ 2` is `-(x ^ 2)` while `-x * 2` is `(-x) * 2`.
function PureParser:parseUnary()
	local line = self:line()
	local op = nil
	if self:accept(Lexeme.Type.ReservedNot) then
		op = "not"
	elseif self:accept(char("-")) then
		op = "-"
	elseif self:accept(char("#")) then
		op = "#"
	end
	if op == nil then
		return self:parsePrefix()
	end
	local operand = self:parseSubExpression(7)
	return node("Un", { op = op, expr = operand, endLine = operand.endLine or operand.line }, line)
end
-- Precedence-climbing parser for binary expressions. It parses a unary operand,
-- then folds in binary operators (see binaryInfo) whose precedence is >= `limit`.
-- Left-associative operators parse their right operand at precedence + 1.
-- Right-associative ones (`..`, `^`) use the same precedence, so chains nest
-- to the right.
-- Each "Bin" node takes its line from the left operand, with the operator's
-- own line kept in `opLine`.
function PureParser:parseSubExpression(limit)
	local left = self:parseUnary()
	local info = binaryInfo[self.current.type]
	while info and info[2] >= limit do
		local op, precedence, rightAssoc = info[1], info[2], info[3]
		local line = left.line or self:line()
		local opLine = self:line()
		self:advance()
		local operandLimit = precedence + 1
		if rightAssoc then
			operandLimit = precedence
		end
		local right = self:parseSubExpression(operandLimit)
		left = node("Bin", {
			op = op,
			left = left,
			right = right,
			opLine = opLine,
			endLine = right.endLine or right.line,
		}, line)
		info = binaryInfo[self.current.type]
	end
	return left
end
-- Parses a complete expression. Precedence 1 is the loosest binary level
-- (`or`), so every binary operator is accepted.
function PureParser:parseExpression()
	local loosestPrecedence = 1
	return self:parseSubExpression(loosestPrecedence)
end
-- Parses one or more comma-separated expressions and returns them as a list.
function PureParser:parseExpressionList()
	local values = {}
	repeat
		append(values, self:parseExpression())
	until not self:accept(char(","))
	return values
end
-- Parses the name of a `function` statement: `a`, `a.b.c`, or `a.b:c`.
-- Returns the target expression (a Name/Field chain) and, for the method form,
-- the method name; the caller then prepends an implicit `self` parameter.
-- Each node's line is captured before its token is consumed. For Field nodes
-- that is the line of the `.`/`:` separator, matching parsePrefix.
-- Previously `self:line()` was evaluated as a call argument after `parseName()`
-- had already advanced, so it reported the line of the following token.
function PureParser:parseFunctionName()
	local line = self:line()
	local expr = node("Name", { name = self:parseName() }, line)
	while true do
		line = self:line()
		if not self:accept(char(".")) then
			break
		end
		expr = node("Field", {
			object = expr,
			field = self:parseName(),
		}, line)
	end
	local selfName = nil
	line = self:line()
	if self:accept(char(":")) then
		selfName = self:parseName()
		expr = node("Field", {
			object = expr,
			field = selfName,
		}, line)
	end
	return expr, selfName
end
-- Parses a single statement and returns its AST node.
-- Keyword statements are dispatched on the leading token. Anything else is
-- parsed as a prefix expression and becomes one of:
--   * a call statement;
--   * a plain or compound assignment;
--   * a contextual-keyword statement: `type ...`, `export type ...`, `continue`.
-- The leading attributes (if any) supply the statement line and are forwarded
-- to `local function` / `function` statements.
-- NOTE(review): attributes in front of any other statement form are silently
-- dropped here rather than reported; confirm that is intended.
function PureParser:parseStatement()
	local attributes = self:skipAttributes()
	local line = attributes and attributes.line or self:line()
	if self:accept(char(";")) then
		return node("Nop", nil, line)
	elseif self:accept(Lexeme.Type.ReservedLocal) then
		if self:accept(Lexeme.Type.ReservedFunction) then
			local name = self:parseName()
			local func = self:parseFunctionBody(line, attributes)
			return node("LocalFunction", {
				name = name,
				value = func,
			}, line)
		end
		-- `local a: T, b: U = ...`; per-variable annotations are parsed and dropped.
		local names = {
			self:parseName("Expected identifier when parsing variable name"),
		}
		self:skipTypeAnnotation()
		while self:accept(char(",")) do
			append(names, self:parseName("Expected identifier when parsing variable name"))
			self:skipTypeAnnotation()
		end
		local values = {}
		if self:accept(char("=")) then
			values = self:parseExpressionList()
		end
		return node("Local", {
			names = names,
			values = values,
		}, line)
	elseif self:accept(Lexeme.Type.ReservedReturn) then
		-- A bare `return` is allowed before a block terminator, EOF, or `;`.
		local values = {}
		if not self:check(Lexeme.Type.Eof) and not self:check(Lexeme.Type.ReservedEnd) and not self:check(Lexeme.Type.ReservedElse) and not self:check(Lexeme.Type.ReservedElseif) and not self:check(Lexeme.Type.ReservedUntil) and not self:check(char(";")) then
			values = self:parseExpressionList()
		end
		self:accept(char(";"))
		return node("Return", { values = values }, line)
	elseif self:accept(Lexeme.Type.ReservedDo) then
		local body = self:parseBlock({
			[Lexeme.Type.ReservedEnd] = true,
		})
		self:expect(Lexeme.Type.ReservedEnd, "Expected 'end' after block")
		return node("Do", { body = body }, line)
	elseif self:accept(Lexeme.Type.ReservedIf) then
		-- clauses: the `if` clause followed by each `elseif`; elseBody is optional.
		local clauses = {}
		local condition = self:parseExpression()
		self:expect(Lexeme.Type.ReservedThen, "Expected 'then' after if condition")
		append(clauses, {
			condition = condition,
			body = self:parseBlock({
				[Lexeme.Type.ReservedElseif] = true,
				[Lexeme.Type.ReservedElse] = true,
				[Lexeme.Type.ReservedEnd] = true,
			}),
		})
		while self:accept(Lexeme.Type.ReservedElseif) do
			condition = self:parseExpression()
			self:expect(Lexeme.Type.ReservedThen, "Expected 'then' after elseif condition")
			append(clauses, {
				condition = condition,
				body = self:parseBlock({
					[Lexeme.Type.ReservedElseif] = true,
					[Lexeme.Type.ReservedElse] = true,
					[Lexeme.Type.ReservedEnd] = true,
				}),
			})
		end
		local elseBody = nil
		if self:accept(Lexeme.Type.ReservedElse) then
			elseBody = self:parseBlock({
				[Lexeme.Type.ReservedEnd] = true,
			})
		end
		self:expect(Lexeme.Type.ReservedEnd, "Expected 'end' after if statement")
		return node("If", {
			clauses = clauses,
			elseBody = elseBody,
		}, line)
	elseif self:accept(Lexeme.Type.ReservedWhile) then
		local condition = self:parseExpression()
		self:expect(Lexeme.Type.ReservedDo, "Expected 'do' after while condition")
		local body = self:parseBlock({
			[Lexeme.Type.ReservedEnd] = true,
		})
		self:expect(Lexeme.Type.ReservedEnd, "Expected 'end' after while body")
		return node("While", {
			condition = condition,
			body = body,
		}, line)
	elseif self:accept(Lexeme.Type.ReservedFor) then
		-- `for name = from, to [, step] do` (numeric) or `for a, b in exprs do` (generic).
		local names = {
			self:parseName("Expected identifier when parsing variable name"),
		}
		if self:accept(char("=")) then
			local from = self:parseExpression()
			self:expect(char(","), "Expected ',' after for loop start")
			local to = self:parseExpression()
			local step = nil
			if self:accept(char(",")) then
				step = self:parseExpression()
			end
			self:expect(Lexeme.Type.ReservedDo, "Expected 'do' after for loop range")
			local body = self:parseBlock({
				[Lexeme.Type.ReservedEnd] = true,
			})
			self:expect(Lexeme.Type.ReservedEnd, "Expected 'end' after for loop body")
			return node("ForNumeric", {
				name = names[1],
				from = from,
				to = to,
				step = step,
				body = body,
			}, line)
		end
		if self:accept(char(",")) then
			repeat
				append(names, self:parseName("Expected identifier when parsing variable name"))
			until not self:accept(char(","))
		end
		self:expect(Lexeme.Type.ReservedIn, "Expected 'in' after for loop variables")
		local values = self:parseExpressionList()
		self:expect(Lexeme.Type.ReservedDo, "Expected 'do' after for loop iterator")
		local body = self:parseBlock({
			[Lexeme.Type.ReservedEnd] = true,
		})
		self:expect(Lexeme.Type.ReservedEnd, "Expected 'end' after for loop body")
		return node("ForIn", {
			names = names,
			values = values,
			body = body,
		}, line)
	elseif self:accept(Lexeme.Type.ReservedRepeat) then
		local body = self:parseBlock({
			[Lexeme.Type.ReservedUntil] = true,
		})
		self:expect(Lexeme.Type.ReservedUntil, "Expected 'until' after repeat body")
		return node("Repeat", {
			body = body,
			condition = self:parseExpression(),
		}, line)
	elseif self:accept(Lexeme.Type.ReservedFunction) then
		local target, selfName = self:parseFunctionName()
		local func = self:parseFunctionBody(line, attributes)
		if selfName then
			-- Method form `a:b()`: prepend the implicit `self` parameter.
			-- NOTE(review): func.paramTypes is keyed by the pre-insertion
			-- positions and is not shifted here, so paramTypes[1] still
			-- describes the first declared parameter, not `self`. Confirm the
			-- compiler accounts for this offset.
			table.insert(func.params, 1, "self")
		end
		return node("FunctionStat", {
			target = target,
			value = func,
		}, line)
	elseif self:accept(Lexeme.Type.ReservedBreak) then
		return node("Break", nil, line)
	elseif self:accept(Lexeme.Type.DoubleColon) then
		-- `::label::` is parsed and discarded (no goto support).
		self:parseName("Expected label name")
		self:expect(Lexeme.Type.DoubleColon, "Expected '::' after label")
		return node("Nop", nil, line)
	end
	-- Expression statement: a call, or the target list of an assignment.
	local first = self:parsePrefix()
	if first.kind == "Call" or first.kind == "MethodCall" then
		return node("CallStat", { expr = first }, line)
	end
	local targets = { first }
	while self:accept(char(",")) do
		append(targets, self:parsePrefix())
	end
	-- assignOp is "=" for plain assignment, otherwise the compound operator's
	-- binary op text (e.g. "+" for `+=`).
	local assignOp = nil
	if self:accept(char("=")) then
		assignOp = "="
	elseif self:accept(Lexeme.Type.AddAssign) then
		assignOp = "+"
	elseif self:accept(Lexeme.Type.SubAssign) then
		assignOp = "-"
	elseif self:accept(Lexeme.Type.MulAssign) then
		assignOp = "*"
	elseif self:accept(Lexeme.Type.DivAssign) then
		assignOp = "/"
	elseif self:accept(Lexeme.Type.ModAssign) then
		assignOp = "%"
	elseif self:accept(Lexeme.Type.PowAssign) then
		assignOp = "^"
	elseif self:accept(Lexeme.Type.FloorDivAssign) then
		assignOp = "//"
	elseif self:accept(Lexeme.Type.ConcatAssign) then
		assignOp = ".."
	else
		-- Contextual keywords: lexed as plain names, so they are recognised
		-- only when the statement is not an assignment or call.
		if first.kind == "Name" and first.name == "type" then
			return self:parseTypeDeclaration(line)
		elseif first.kind == "Name" and first.name == "export" and self:isName("type") then
			self:advance()
			return self:parseTypeDeclaration(line)
		elseif first.kind == "Name" and first.name == "continue" then
			return node("Continue", nil, line)
		end
		self:error("Expected assignment or function call")
	end
	return node("Assign", {
		targets = targets,
		values = self:parseExpressionList(),
		op = assignOp,
	}, line)
end
-- Parses statements until EOF or until the current token's type is a key in
-- `stop`. The stop token is left for the caller to consume.
-- Returns a "Block" node whose line is the line where the block ended.
function PureParser:parseBlock(stop)
	local statements = {}
	while true do
		if self:check(Lexeme.Type.Eof) or stop[self.current.type] then
			break
		end
		statements[#statements + 1] = self:parseStatement()
	end
	return node("Block", { body = statements }, self:line())
end
-- Parses the whole source as one chunk. Any trailing tokens are an error.
-- Returns the top-level "Block" node.
function PureParser:parse()
	local chunk = self:parseBlock({})
	self:expect(Lexeme.Type.Eof, "Expected end of file")
	return chunk
end
-- Compiler class: instances come from PureCompiler.new. Per-function state
-- lives in contexts created by newContext.
local PureCompiler = {}
PureCompiler.__index = PureCompiler

-- Proto flag bits (LPF_* naming).
local LPF_NATIVE_COLD = 2
local LPF_NATIVE_FUNCTION = 4

-- Arithmetic operator text -> opcode for the register/register form
-- (arithmeticOps) and the register/constant "K" form (arithmeticKOps).
local arithmeticOps = {}
local arithmeticKOps = {}
for _, spec in ipairs({
	{ "+", LOP.ADD, LOP.ADDK },
	{ "-", LOP.SUB, LOP.SUBK },
	{ "*", LOP.MUL, LOP.MULK },
	{ "/", LOP.DIV, LOP.DIVK },
	{ "%", LOP.MOD, LOP.MODK },
	{ "^", LOP.POW, LOP.POWK },
	{ "//", LOP.IDIV, LOP.IDIVK },
}) do
	arithmeticOps[spec[1]] = spec[2]
	arithmeticKOps[spec[1]] = spec[3]
end

-- Set of comparison operator texts.
local compareOps = {}
for _, symbol in ipairs({ "==", "~=", "<", "<=", ">", ">=" }) do
	compareOps[symbol] = true
end
-- Creates a compiler that emits into `builder`. Fields:
--   getfenvUsed / setfenvUsed - whether getfenv/setfenv were seen
--   hasNativeFunction         - whether a native function was seen
--   compilingFunctionStack    - stack of functions currently being compiled
function PureCompiler.new(builder)
	local compiler = {
		builder = builder,
		getfenvUsed = false,
		setfenvUsed = false,
		hasNativeFunction = false,
		compilingFunctionStack = {},
	}
	return setmetatable(compiler, PureCompiler)
end
-- Creates the per-function compilation context for `func`, the function
-- record whose `maxstacksize` useReg updates. `parent` is the enclosing
-- context, or nil for the outermost function.
-- globalWrites is one table shared with the parent chain. functionDepth is
-- 0 at the top and grows by one per nesting level. Every other table is fresh.
function PureCompiler:newContext(func, parent)
	local sharedGlobalWrites = parent and parent.globalWrites
	if not sharedGlobalWrites then
		sharedGlobalWrites = {}
	end
	local depth = 0
	if parent then
		depth = parent.functionDepth + 1
	end
	return {
		-- back-references
		compiler = self,
		builder = self.builder,
		func = func,
		parent = parent,
		-- variables and captures
		locals = {},
		localList = {},
		upvalues = {},
		upvalueMap = {},
		writeNames = {},
		readNames = {},
		globalWrites = sharedGlobalWrites,
		functionDepth = depth,
		-- loops
		loopDepth = 0,
		loopBreaks = {},
		loopContinues = {},
		hasLoops = false,
		loopCloseStarts = {},
		loopKinds = {},
		-- registers
		tableArrayHints = {},
		nextReg = 0,
		maxReg = 0,
		protectedTop = 0,
		-- debug info
		debugScopeDepth = 0,
		debugLocalOrder = 0,
		debugConstantUpvalueMap = {},
		debugConstantUpvalues = {},
	}
end
-- Records that register `reg` is in use: raises ctx.maxReg and the function's
-- maxstacksize to cover it. Raises "register limit exceeded" (reported at the
-- caller) once maxstacksize would exceed 254.
-- NOTE(review): upstream Luau's register limit (kMaxRegisterCount) is 255;
-- confirm the one-register headroom here is intentional.
function PureCompiler:useReg(ctx, reg)
	local needed = reg + 1
	if needed > ctx.maxReg then
		ctx.maxReg = needed
	end
	local func = ctx.func
	if ctx.maxReg > func.maxstacksize then
		func.maxstacksize = ctx.maxReg
	end
	if func.maxstacksize > 254 then
		error("register limit exceeded", 2)
	end
end
-- Allocates `count` consecutive registers at the top of the register stack and
-- returns the first one.
function PureCompiler:reserve(ctx, count)
	local first = ctx.nextReg
	local top = first + count
	ctx.nextReg = top
	self:useReg(ctx, top - 1)
	return first
end
-- Declares a new local `name` in the next free register.
-- The local is marked as written (reassigned) when `written` is true or when
-- ctx.writeNames flags the name. writeNames is presumably filled by
-- collectWriteNames; verify against callers.
function PureCompiler:addLocal(ctx, name, written)
	local register = self:reserve(ctx, 1)
	local isWritten = written == true
	if not isWritten and ctx.writeNames then
		isWritten = ctx.writeNames[name] == true
	end
	return self:declareLocal(ctx, {
		name = name,
		reg = register,
		written = isWritten,
	})
end
-- Like addLocal, but binds `name` to an already-chosen register `reg`.
-- The register is marked as used via useReg; nextReg is not advanced.
function PureCompiler:addLocalAt(ctx, name, reg, written)
	local isWritten = written == true
	if not isWritten and ctx.writeNames then
		isWritten = ctx.writeNames[name] == true
	end
	local localInfo = {
		name = name,
		reg = reg,
		written = isWritten,
	}
	self:useReg(ctx, reg)
	return self:declareLocal(ctx, localInfo)
end
-- Returns the entry for `name` in ctx.locals, or nil when there is none.
function PureCompiler:findLocal(ctx, name)
	local scope = ctx.locals
	return scope[name]
end
-- Walks an expression and records, in the `writes` set, every name assigned
-- inside it. Only function expressions can contain assignments; their bodies
-- are handed to collectWriteNames. Every other node is searched recursively.
-- "Instantiate" (`f<<T>>`) wrappers are now descended into like "SingleResult".
-- Before, an assignment hidden behind one, e.g. inside
-- `g(function() x = 1 end).f<<T>>()`, was missed, so `x` looked never written.
-- Non-table values, such as the string keys of `{name = v}` entries, are ignored.
function PureCompiler:collectWriteNamesFromExpr(expr, writes)
	if type(expr) ~= "table" then
		return
	end
	local kind = expr.kind
	if kind == "Field" then
		self:collectWriteNamesFromExpr(expr.object, writes)
	elseif kind == "Index" then
		self:collectWriteNamesFromExpr(expr.object, writes)
		self:collectWriteNamesFromExpr(expr.index, writes)
	elseif kind == "Call" then
		self:collectWriteNamesFromExpr(expr.callee, writes)
		for _, arg in ipairs(expr.args) do
			self:collectWriteNamesFromExpr(arg, writes)
		end
	elseif kind == "MethodCall" then
		self:collectWriteNamesFromExpr(expr.object, writes)
		for _, arg in ipairs(expr.args) do
			self:collectWriteNamesFromExpr(arg, writes)
		end
	elseif kind == "Table" then
		for _, entry in ipairs(expr.entries) do
			self:collectWriteNamesFromExpr(entry.key, writes)
			self:collectWriteNamesFromExpr(entry.value, writes)
		end
	elseif kind == "Un" or kind == "SingleResult" or kind == "Instantiate" then
		-- all three wrap a single operand in `.expr`
		self:collectWriteNamesFromExpr(expr.expr, writes)
	elseif kind == "Bin" then
		self:collectWriteNamesFromExpr(expr.left, writes)
		self:collectWriteNamesFromExpr(expr.right, writes)
	elseif kind == "InterpString" then
		for _, value in ipairs(expr.expressions) do
			self:collectWriteNamesFromExpr(value, writes)
		end
	elseif kind == "Function" then
		self:collectWriteNames(expr.body, writes)
	elseif kind == "IfExpr" then
		for _, clause in ipairs(expr.clauses) do
			self:collectWriteNamesFromExpr(clause.condition, writes)
			self:collectWriteNamesFromExpr(clause.value, writes)
		end
		self:collectWriteNamesFromExpr(expr.elseValue, writes)
	end
end
--- Collect every name assigned anywhere in a block.
-- This covers plain `name = ...` and `function name() ... end`, including
-- nested blocks and nested function bodies. The analysis is purely name-based:
-- shadowing is not tracked, so a write to any variable with the same name is
-- reported.
-- @param block AST block with a `body` statement list
-- @param writes optional set to extend; a new table is created when nil
-- @return the set of written names (name -> true)
function PureCompiler:collectWriteNames(block, writes)
	local names = writes or {}

	-- scan a list of expressions for nested writes
	local function scanExprList(list)
		for _, item in ipairs(list) do
			self:collectWriteNamesFromExpr(item, names)
		end
	end

	-- a bare Name target is itself a write; other targets may contain nested writes
	local function noteTarget(target)
		if target.kind == "Name" then
			names[target.name] = true
		else
			self:collectWriteNamesFromExpr(target, names)
		end
	end

	for _, stat in ipairs(block.body) do
		local kind = stat.kind
		if kind == "Local" or kind == "Return" then
			scanExprList(stat.values)
		elseif kind == "LocalFunction" then
			self:collectWriteNames(stat.value.body, names)
		elseif kind == "Assign" then
			for _, target in ipairs(stat.targets) do
				noteTarget(target)
			end
			scanExprList(stat.values)
		elseif kind == "CallStat" then
			self:collectWriteNamesFromExpr(stat.expr, names)
		elseif kind == "Do" then
			self:collectWriteNames(stat.body, names)
		elseif kind == "If" then
			for _, clause in ipairs(stat.clauses) do
				self:collectWriteNamesFromExpr(clause.condition, names)
				self:collectWriteNames(clause.body, names)
			end
			if stat.elseBody then
				self:collectWriteNames(stat.elseBody, names)
			end
		elseif kind == "While" then
			self:collectWriteNamesFromExpr(stat.condition, names)
			self:collectWriteNames(stat.body, names)
		elseif kind == "Repeat" then
			self:collectWriteNames(stat.body, names)
			self:collectWriteNamesFromExpr(stat.condition, names)
		elseif kind == "ForNumeric" then
			self:collectWriteNamesFromExpr(stat.from, names)
			self:collectWriteNamesFromExpr(stat.to, names)
			if stat.step then
				self:collectWriteNamesFromExpr(stat.step, names)
			end
			self:collectWriteNames(stat.body, names)
		elseif kind == "ForIn" then
			scanExprList(stat.values)
			self:collectWriteNames(stat.body, names)
		elseif kind == "FunctionStat" then
			noteTarget(stat.target)
			self:collectWriteNames(stat.value.body, names)
		end
	end
	return names
end
--- Create a lexical scope record for the write-analysis pass.
-- `functionDepth` and `loopDepth` default to the parent's values, or to 0 when
-- there is no parent or the parent lacks them.
-- @param parent enclosing scope, or nil for the root
-- @return a scope `{ parent, locals, functionDepth, loopDepth }`
function PureCompiler:newWriteAnalysisScope(parent, functionDepth, loopDepth)
	local inheritedFunctionDepth = 0
	local inheritedLoopDepth = 0
	if parent then
		inheritedFunctionDepth = parent.functionDepth or 0
		inheritedLoopDepth = parent.loopDepth or 0
	end
	return {
		parent = parent,
		locals = {},
		functionDepth = functionDepth or inheritedFunctionDepth,
		loopDepth = loopDepth or inheritedLoopDepth,
	}
end
--- Find the innermost symbol named `name` visible from `scope`.
-- Searches `scope.locals`, then each parent scope in turn.
-- @return the symbol record, or nil when the name is not a local (i.e. a global)
function PureCompiler:resolveWriteSymbol(scope, name)
	if not scope then
		return nil
	end
	return scope.locals[name] or self:resolveWriteSymbol(scope.parent, name)
end
--- Declare a new, not-yet-written symbol in `scope`.
-- The symbol records the scope's function and loop depth (0 when unset) and
-- shadows any earlier symbol of the same name in that scope.
-- @return the new symbol record
function PureCompiler:declareWriteSymbol(scope, name)
	local fresh = {}
	fresh.name = name
	fresh.written = false
	fresh.functionDepth = scope.functionDepth or 0
	fresh.loopDepth = scope.loopDepth or 0
	scope.locals[name] = fresh
	return fresh
end
--- Record an assignment to `name`.
-- A resolved local is flagged written; otherwise the name goes into `globals`.
function PureCompiler:markWriteName(scope, globals, name)
	local symbol = self:resolveWriteSymbol(scope, name)
	if not symbol then
		globals[name] = true
		return
	end
	symbol.written = true
end
--- Analyse an assignment target.
-- A Name target is resolved (stored in `target.localSymbol`) and marked written.
-- For Field/Index targets, the object (and the index, for Index) are analysed
-- as ordinary expressions.
function PureCompiler:analyzeWriteTarget(target, scope, globals)
	local kind = target.kind
	if kind == "Name" then
		target.localSymbol = self:resolveWriteSymbol(scope, target.name)
		self:markWriteName(scope, globals, target.name)
		return
	end
	if kind == "Field" or kind == "Index" then
		self:analyzeWriteExpr(target.object, scope, globals)
	end
	if kind == "Index" then
		self:analyzeWriteExpr(target.index, scope, globals)
	end
end
--- Annotate an expression with write-analysis information.
-- Each Name node gets `localSymbol`: the resolved local, or nil for a global.
-- Nested function expressions are analysed in their own scope via
-- analyzeWriteFunction. All other nodes are simply traversed.
function PureCompiler:analyzeWriteExpr(expr, scope, globals)
	local kind = expr.kind
	if kind == "Name" then
		expr.localSymbol = self:resolveWriteSymbol(scope, expr.name)
	elseif kind == "Field" then
		self:analyzeWriteExpr(expr.object, scope, globals)
	elseif kind == "Index" then
		self:analyzeWriteExpr(expr.object, scope, globals)
		self:analyzeWriteExpr(expr.index, scope, globals)
	elseif kind == "Call" or kind == "MethodCall" then
		local head = expr.object
		if kind == "Call" then
			head = expr.callee
		end
		self:analyzeWriteExpr(head, scope, globals)
		for _, argument in ipairs(expr.args) do
			self:analyzeWriteExpr(argument, scope, globals)
		end
	elseif kind == "Table" then
		for _, item in ipairs(expr.entries) do
			if item.key then
				self:analyzeWriteExpr(item.key, scope, globals)
			end
			if item.value then
				self:analyzeWriteExpr(item.value, scope, globals)
			end
		end
	elseif kind == "Un" or kind == "SingleResult" or kind == "Instantiate" then
		self:analyzeWriteExpr(expr.expr, scope, globals)
	elseif kind == "Bin" then
		self:analyzeWriteExpr(expr.left, scope, globals)
		self:analyzeWriteExpr(expr.right, scope, globals)
	elseif kind == "InterpString" then
		for _, part in ipairs(expr.expressions) do
			self:analyzeWriteExpr(part, scope, globals)
		end
	elseif kind == "Function" then
		self:analyzeWriteFunction(expr, scope, globals)
	elseif kind == "IfExpr" then
		for _, branch in ipairs(expr.clauses) do
			self:analyzeWriteExpr(branch.condition, scope, globals)
			self:analyzeWriteExpr(branch.value, scope, globals)
		end
		self:analyzeWriteExpr(expr.elseValue, scope, globals)
	end
end
--- Analyse a function expression in a fresh scope.
-- Sets `expr.functionDepth` to the parent's depth + 1 (the parent counts as
-- depth 0 when absent). Loop depth is reset to 0 inside the function.
-- Parameters are declared as symbols and stored in `expr.paramSymbols`.
function PureCompiler:analyzeWriteFunction(expr, parentScope, globals)
	local depth = 1
	if parentScope then
		depth = depth + (parentScope.functionDepth or 0)
	end
	expr.functionDepth = depth
	local bodyScope = self:newWriteAnalysisScope(parentScope, depth, 0)
	local params = {}
	expr.paramSymbols = params
	for position, paramName in ipairs(expr.params) do
		params[position] = self:declareWriteSymbol(bodyScope, paramName)
	end
	self:analyzeWriteBlock(expr.body, bodyScope, globals)
end
--- Annotate a block with lexical write-analysis information.
-- Statements are walked in order, and locals are declared in `scope` as they
-- come into view:
--   * `local` names become visible after their initializers are analysed;
--   * a `local function` name is visible inside its own body.
-- Resolved symbols are stored on the AST (`localSymbols` / `localSymbol`).
-- Assigned locals are marked written. Assignments to names that resolve to no
-- local are recorded in `globals`. Nested blocks get child scopes, and loop
-- bodies get loopDepth + 1.
-- @param block AST block with a `body` statement list
-- @param scope current write-analysis scope
-- @param globals set (name -> true) of assigned globals, filled in place
function PureCompiler:analyzeWriteBlock(block, scope, globals)
	for _, stat in ipairs(block.body) do
		if stat.kind == "Local" then
			for _, value in ipairs(stat.values) do
				self:analyzeWriteExpr(value, scope, globals)
			end
			stat.localSymbols = {}
			for index, name in ipairs(stat.names) do
				local symbol = self:declareWriteSymbol(scope, name)
				stat.localSymbols[index] = symbol
			end
		elseif stat.kind == "LocalFunction" then
			local symbol = self:declareWriteSymbol(scope, stat.name)
			stat.localSymbol = symbol
			self:analyzeWriteFunction(stat.value, scope, globals)
		elseif stat.kind == "Assign" then
			for _, target in ipairs(stat.targets) do
				self:analyzeWriteTarget(target, scope, globals)
			end
			for _, value in ipairs(stat.values) do
				self:analyzeWriteExpr(value, scope, globals)
			end
		elseif stat.kind == "CallStat" then
			self:analyzeWriteExpr(stat.expr, scope, globals)
		elseif stat.kind == "Return" then
			for _, value in ipairs(stat.values) do
				self:analyzeWriteExpr(value, scope, globals)
			end
		elseif stat.kind == "Do" then
			self:analyzeWriteBlock(stat.body, self:newWriteAnalysisScope(scope), globals)
		elseif stat.kind == "If" then
			for _, clause in ipairs(stat.clauses) do
				self:analyzeWriteExpr(clause.condition, scope, globals)
				self:analyzeWriteBlock(clause.body, self:newWriteAnalysisScope(scope), globals)
			end
			if stat.elseBody then
				self:analyzeWriteBlock(stat.elseBody, self:newWriteAnalysisScope(scope), globals)
			end
		elseif stat.kind == "While" then
			self:analyzeWriteExpr(stat.condition, scope, globals)
			self:analyzeWriteBlock(stat.body, self:newWriteAnalysisScope(scope, nil, (scope.loopDepth or 0) + 1), globals)
		elseif stat.kind == "Repeat" then
			-- The `until` condition lies inside the body's scope. It may refer to
			-- locals declared in the body (`repeat local done = step() until done`),
			-- so resolve it against the body scope. Resolving it in the enclosing
			-- scope would treat such names as globals.
			local bodyScope = self:newWriteAnalysisScope(scope, nil, (scope.loopDepth or 0) + 1)
			self:analyzeWriteBlock(stat.body, bodyScope, globals)
			self:analyzeWriteExpr(stat.condition, bodyScope, globals)
		elseif stat.kind == "ForNumeric" then
			self:analyzeWriteExpr(stat.from, scope, globals)
			self:analyzeWriteExpr(stat.to, scope, globals)
			if stat.step then
				self:analyzeWriteExpr(stat.step, scope, globals)
			end
			local loopScope = self:newWriteAnalysisScope(scope, nil, (scope.loopDepth or 0) + 1)
			stat.localSymbol = self:declareWriteSymbol(loopScope, stat.name)
			self:analyzeWriteBlock(stat.body, loopScope, globals)
		elseif stat.kind == "ForIn" then
			for _, value in ipairs(stat.values) do
				self:analyzeWriteExpr(value, scope, globals)
			end
			local loopScope = self:newWriteAnalysisScope(scope, nil, (scope.loopDepth or 0) + 1)
			stat.localSymbols = {}
			for index, name in ipairs(stat.names) do
				stat.localSymbols[index] = self:declareWriteSymbol(loopScope, name)
			end
			self:analyzeWriteBlock(stat.body, loopScope, globals)
		elseif stat.kind == "FunctionStat" then
			self:analyzeWriteTarget(stat.target, scope, globals)
			self:analyzeWriteFunction(stat.value, scope, globals)
		end
	end
end
--- Run write analysis over a whole chunk starting from a root scope.
-- @param block the chunk's top-level block
-- @return set (name -> true) of global names assigned anywhere in the chunk
function PureCompiler:annotateWriteSymbols(block)
	local rootScope = self:newWriteAnalysisScope(nil)
	local writtenGlobals = {}
	self:analyzeWriteBlock(block, rootScope, writtenGlobals)
	return writtenGlobals
end
--- Record every variable name read by an expression into `reads`.
-- Name nodes are recorded directly; nested function bodies are scanned with
-- collectReadNames. The analysis is purely name-based and does not
-- distinguish locals from globals.
-- @param expr AST expression node
-- @param reads set (name -> true), filled in place
function PureCompiler:collectReadNamesFromExpr(expr, reads)
	local kind = expr.kind
	if kind == "Name" then
		reads[expr.name] = true
	elseif kind == "Field" then
		self:collectReadNamesFromExpr(expr.object, reads)
	elseif kind == "Index" then
		self:collectReadNamesFromExpr(expr.object, reads)
		self:collectReadNamesFromExpr(expr.index, reads)
	elseif kind == "Call" or kind == "MethodCall" then
		local head = expr.object
		if kind == "Call" then
			head = expr.callee
		end
		self:collectReadNamesFromExpr(head, reads)
		for _, argument in ipairs(expr.args) do
			self:collectReadNamesFromExpr(argument, reads)
		end
	elseif kind == "Table" then
		for _, item in ipairs(expr.entries) do
			if item.key then
				self:collectReadNamesFromExpr(item.key, reads)
			end
			if item.value then
				self:collectReadNamesFromExpr(item.value, reads)
			end
		end
	elseif kind == "Un" or kind == "SingleResult" or kind == "Instantiate" then
		self:collectReadNamesFromExpr(expr.expr, reads)
	elseif kind == "Bin" then
		self:collectReadNamesFromExpr(expr.left, reads)
		self:collectReadNamesFromExpr(expr.right, reads)
	elseif kind == "InterpString" then
		for _, part in ipairs(expr.expressions) do
			self:collectReadNamesFromExpr(part, reads)
		end
	elseif kind == "Function" then
		self:collectReadNames(expr.body, reads)
	elseif kind == "IfExpr" then
		for _, branch in ipairs(expr.clauses) do
			self:collectReadNamesFromExpr(branch.condition, reads)
			self:collectReadNamesFromExpr(branch.value, reads)
		end
		self:collectReadNamesFromExpr(expr.elseValue, reads)
	end
end
--- Collect every variable name read anywhere in a block.
-- Nested blocks and function bodies are included. A bare Name target of an
-- assignment is a write and is not recorded. A FunctionStat target is scanned
-- as an expression, so a plain `function name()` records `name`.
-- @param block AST block with a `body` statement list
-- @param reads optional set to extend; a new table is created when nil
-- @return the set of read names (name -> true)
function PureCompiler:collectReadNames(block, reads)
	local seen = reads or {}

	-- scan a list of expressions
	local function readAll(list)
		for _, item in ipairs(list) do
			self:collectReadNamesFromExpr(item, seen)
		end
	end

	for _, stat in ipairs(block.body) do
		local kind = stat.kind
		if kind == "Local" or kind == "Return" then
			readAll(stat.values)
		elseif kind == "LocalFunction" then
			self:collectReadNames(stat.value.body, seen)
		elseif kind == "Assign" then
			for _, target in ipairs(stat.targets) do
				if target.kind ~= "Name" then
					self:collectReadNamesFromExpr(target, seen)
				end
			end
			readAll(stat.values)
		elseif kind == "CallStat" then
			self:collectReadNamesFromExpr(stat.expr, seen)
		elseif kind == "Do" or kind == "While" or kind == "Repeat" then
			if stat.condition then
				self:collectReadNamesFromExpr(stat.condition, seen)
			end
			self:collectReadNames(stat.body, seen)
		elseif kind == "If" then
			for _, branch in ipairs(stat.clauses) do
				self:collectReadNamesFromExpr(branch.condition, seen)
				self:collectReadNames(branch.body, seen)
			end
			if stat.elseBody then
				self:collectReadNames(stat.elseBody, seen)
			end
		elseif kind == "ForNumeric" then
			self:collectReadNamesFromExpr(stat.from, seen)
			self:collectReadNamesFromExpr(stat.to, seen)
			if stat.step then
				self:collectReadNamesFromExpr(stat.step, seen)
			end
			self:collectReadNames(stat.body, seen)
		elseif kind == "ForIn" then
			readAll(stat.values)
			self:collectReadNames(stat.body, seen)
		elseif kind == "FunctionStat" then
			self:collectReadNamesFromExpr(stat.target, seen)
			self:collectReadNames(stat.value.body, seen)
		end
	end
	return seen
end
--- Detect references to the global `getfenv` / `setfenv` inside an expression.
-- Sets `self.getfenvUsed` / `self.setfenvUsed` when a Name node with no
-- resolved `localSymbol` (a global) uses either name. This depends on
-- annotateWriteSymbols having run first to fill `localSymbol`. Nested function
-- bodies are scanned with collectFenvUses.
-- @param expr AST expression node
function PureCompiler:collectFenvUsesFromExpr(expr)
	local kind = expr.kind
	if kind == "Name" then
		if expr.localSymbol == nil then
			if expr.name == "getfenv" then
				self.getfenvUsed = true
			elseif expr.name == "setfenv" then
				self.setfenvUsed = true
			end
		end
	elseif kind == "Field" then
		self:collectFenvUsesFromExpr(expr.object)
	elseif kind == "Index" then
		self:collectFenvUsesFromExpr(expr.object)
		self:collectFenvUsesFromExpr(expr.index)
	elseif kind == "Call" then
		self:collectFenvUsesFromExpr(expr.callee)
		for _, arg in ipairs(expr.args) do
			self:collectFenvUsesFromExpr(arg)
		end
	elseif kind == "MethodCall" then
		self:collectFenvUsesFromExpr(expr.object)
		for _, arg in ipairs(expr.args) do
			self:collectFenvUsesFromExpr(arg)
		end
	elseif kind == "Table" then
		for _, entry in ipairs(expr.entries) do
			if entry.key then
				self:collectFenvUsesFromExpr(entry.key)
			end
			if entry.value then
				self:collectFenvUsesFromExpr(entry.value)
			end
		end
	elseif kind == "Un" or kind == "SingleResult" or kind == "Instantiate" then
		-- Instantiate wraps a single expression. Without this case an explicitly
		-- instantiated call such as `getfenv<<T>>(1)` would go undetected.
		self:collectFenvUsesFromExpr(expr.expr)
	elseif kind == "Bin" then
		self:collectFenvUsesFromExpr(expr.left)
		self:collectFenvUsesFromExpr(expr.right)
	elseif kind == "InterpString" then
		for _, value in ipairs(expr.expressions) do
			self:collectFenvUsesFromExpr(value)
		end
	elseif kind == "Function" then
		self:collectFenvUses(expr.body)
	elseif kind == "IfExpr" then
		for _, clause in ipairs(expr.clauses) do
			self:collectFenvUsesFromExpr(clause.condition)
			self:collectFenvUsesFromExpr(clause.value)
		end
		self:collectFenvUsesFromExpr(expr.elseValue)
	end
end
--- Scan a block (including nested blocks and functions) for global getfenv/setfenv uses.
-- Every expression is delegated to collectFenvUsesFromExpr. Bare Name targets
-- of assignments and of `function name()` statements are not scanned.
-- @param block AST block with a `body` statement list
function PureCompiler:collectFenvUses(block)
	-- scan a list of expressions
	local function scanAll(list)
		for _, item in ipairs(list) do
			self:collectFenvUsesFromExpr(item)
		end
	end

	for _, stat in ipairs(block.body) do
		local kind = stat.kind
		if kind == "Local" or kind == "Return" then
			scanAll(stat.values)
		elseif kind == "LocalFunction" then
			self:collectFenvUses(stat.value.body)
		elseif kind == "Assign" then
			for _, target in ipairs(stat.targets) do
				if target.kind ~= "Name" then
					self:collectFenvUsesFromExpr(target)
				end
			end
			scanAll(stat.values)
		elseif kind == "CallStat" then
			self:collectFenvUsesFromExpr(stat.expr)
		elseif kind == "Do" or kind == "While" or kind == "Repeat" then
			if stat.condition then
				self:collectFenvUsesFromExpr(stat.condition)
			end
			self:collectFenvUses(stat.body)
		elseif kind == "If" then
			for _, branch in ipairs(stat.clauses) do
				self:collectFenvUsesFromExpr(branch.condition)
				self:collectFenvUses(branch.body)
			end
			if stat.elseBody then
				self:collectFenvUses(stat.elseBody)
			end
		elseif kind == "ForNumeric" then
			self:collectFenvUsesFromExpr(stat.from)
			self:collectFenvUsesFromExpr(stat.to)
			if stat.step then
				self:collectFenvUsesFromExpr(stat.step)
			end
			self:collectFenvUses(stat.body)
		elseif kind == "ForIn" then
			scanAll(stat.values)
			self:collectFenvUses(stat.body)
		elseif kind == "FunctionStat" then
			if stat.target.kind ~= "Name" then
				self:collectFenvUsesFromExpr(stat.target)
			end
			self:collectFenvUses(stat.value.body)
		end
	end
end
--- Predict array/hash sizes for empty table literals that are filled by later assignments.
-- Which tables are tracked: a table literal with no entries that is the single
-- value of a single-name `local`, written as `{}` or `setmetatable({}, x)`
-- with the global setmetatable.
-- How assignments grow the prediction:
--   * each distinct `t.name = ...` field grows the hash size once;
--   * `t[k] = ...` with a numeric constant k equal to the next array slot
--     grows the array size;
--   * `t[i] = ...` inside `for i = 1, K` (constant K in 1..16, no step) seeds
--     the array size with floor(K) when it is still 0.
-- Predictions are written to `predictedArrayCount` / `predictedHashCount` on
-- the table node. Each nested function body gets a fresh state, so
-- assignments made inside closures are not attributed to outer tables.
-- Requires annotateWriteSymbols to have run (uses `localSymbol` / `localSymbols`).
-- @param block AST block
-- @param state optional analysis state shared by the nested blocks of one function
-- @return an empty table
function PureCompiler:collectTableArrayHints(block, state)
	state = state or {
		tables = {}, -- local symbol -> tracked table literal
		fields = {}, -- table literal -> set of field names already counted
		loops = {}, -- numeric-for variable symbol -> constant upper bound
	}
	local function freshState()
		return {
			tables = {},
			fields = {},
			loops = {},
		}
	end
	local function getShape(tableExpr)
		local shape = state.shapes and state.shapes[tableExpr] or nil
		if not shape then
			shape = {
				arraySize = 0,
				hashSize = 0,
			}
			state.shapes = state.shapes or {}
			state.shapes[tableExpr] = shape
		end
		return shape
	end
	local function setShape(tableExpr, shape)
		tableExpr.predictedArrayCount = shape.arraySize
		tableExpr.predictedHashCount = shape.hashSize
	end
	local function getTableHint(expr)
		if expr.kind == "Table" then
			return expr
		end
		if expr.kind == "Call" and #expr.args == 2 and expr.callee.kind == "Name" and expr.callee.name == "setmetatable" and expr.callee.localSymbol == nil and expr.args[1].kind == "Table" then
			return expr.args[1]
		end
		return nil
	end
	local function findTrackedTable(expr)
		if expr.kind ~= "Name" or expr.localSymbol == nil then
			return nil
		end
		return state.tables[expr.localSymbol]
	end
	local function noteField(object, field)
		local tableExpr = findTrackedTable(object)
		if not tableExpr then
			return
		end
		-- Key the dedup set by the table node itself rather than by a
		-- tostring(tableExpr)-built string, which assumed tostring yields a unique
		-- address for every AST table.
		local counted = state.fields[tableExpr]
		if not counted then
			counted = {}
			state.fields[tableExpr] = counted
		end
		if not counted[field] then
			counted[field] = true
			local shape = getShape(tableExpr)
			shape.hashSize += 1
			setShape(tableExpr, shape)
		end
	end
	local function noteIndex(object, index)
		local tableExpr = findTrackedTable(object)
		if not tableExpr then
			return
		end
		if index.kind == "Number" then
			local shape = getShape(tableExpr)
			if index.value == shape.arraySize + 1 then
				shape.arraySize += 1
				setShape(tableExpr, shape)
			end
		elseif index.kind == "Name" and index.localSymbol ~= nil then
			local bound = state.loops[index.localSymbol]
			if bound then
				local shape = getShape(tableExpr)
				if shape.arraySize == 0 then
					shape.arraySize = bound
					setShape(tableExpr, shape)
				end
			end
		end
	end
	local function visitExpr(expr)
		if expr.kind == "Field" then
			visitExpr(expr.object)
		elseif expr.kind == "Index" then
			visitExpr(expr.object)
			visitExpr(expr.index)
		elseif expr.kind == "Call" then
			visitExpr(expr.callee)
			for _, arg in ipairs(expr.args) do
				visitExpr(arg)
			end
		elseif expr.kind == "MethodCall" then
			visitExpr(expr.object)
			for _, arg in ipairs(expr.args) do
				visitExpr(arg)
			end
		elseif expr.kind == "Table" then
			for _, entry in ipairs(expr.entries) do
				if entry.key then
					visitExpr(entry.key)
				end
				if entry.value then
					visitExpr(entry.value)
				end
			end
		elseif expr.kind == "Un" then
			visitExpr(expr.expr)
		elseif expr.kind == "Bin" then
			visitExpr(expr.left)
			visitExpr(expr.right)
		elseif expr.kind == "SingleResult" or expr.kind == "Instantiate" then
			-- Instantiate wraps a single expression; descend so nested functions are analysed
			visitExpr(expr.expr)
		elseif expr.kind == "InterpString" then
			for _, value in ipairs(expr.expressions) do
				visitExpr(value)
			end
		elseif expr.kind == "Function" then
			self:collectTableArrayHints(expr.body, freshState())
		elseif expr.kind == "IfExpr" then
			for _, clause in ipairs(expr.clauses) do
				visitExpr(clause.condition)
				visitExpr(clause.value)
			end
			visitExpr(expr.elseValue)
		end
	end
	for _, stat in ipairs(block.body) do
		if stat.kind == "Local" then
			if #stat.names == 1 and #stat.values == 1 and stat.localSymbols and stat.localSymbols[1] then
				local tableExpr = getTableHint(stat.values[1])
				if tableExpr and #tableExpr.entries == 0 then
					state.tables[stat.localSymbols[1]] = tableExpr
					getShape(tableExpr)
				end
			end
			for _, value in ipairs(stat.values) do
				visitExpr(value)
			end
		elseif stat.kind == "Assign" then
			for _, target in ipairs(stat.targets) do
				if target.kind == "Field" then
					noteField(target.object, target.field)
				elseif target.kind == "Index" then
					noteIndex(target.object, target.index)
				end
			end
			for _, value in ipairs(stat.values) do
				visitExpr(value)
			end
		elseif stat.kind == "LocalFunction" then
			self:collectTableArrayHints(stat.value.body, freshState())
		elseif stat.kind == "CallStat" then
			visitExpr(stat.expr)
		elseif stat.kind == "Return" then
			for _, value in ipairs(stat.values) do
				visitExpr(value)
			end
		elseif stat.kind == "Do" or stat.kind == "While" or stat.kind == "Repeat" then
			if stat.condition then
				visitExpr(stat.condition)
			end
			self:collectTableArrayHints(stat.body, state)
		elseif stat.kind == "If" then
			for _, clause in ipairs(stat.clauses) do
				visitExpr(clause.condition)
				self:collectTableArrayHints(clause.body, state)
			end
			if stat.elseBody then
				self:collectTableArrayHints(stat.elseBody, state)
			end
		elseif stat.kind == "ForNumeric" then
			visitExpr(stat.from)
			visitExpr(stat.to)
			if stat.step then
				visitExpr(stat.step)
			end
			if stat.localSymbol and stat.from.kind == "Number" and stat.to.kind == "Number" and stat.from.value == 1 and stat.to.value >= 1 and stat.to.value <= 16 and stat.step == nil then
				-- A fractional limit such as `for i = 1, 2.5` only runs i = 1, 2, so
				-- store the truncated bound (Luau's compiler stores unsigned(to)).
				-- An untruncated bound would become a fractional array size.
				state.loops[stat.localSymbol] = math.floor(stat.to.value)
			end
			self:collectTableArrayHints(stat.body, state)
		elseif stat.kind == "ForIn" then
			for _, value in ipairs(stat.values) do
				visitExpr(value)
			end
			self:collectTableArrayHints(stat.body, state)
		elseif stat.kind == "FunctionStat" then
			if stat.target.kind == "Field" then
				noteField(stat.target.object, stat.target.field)
			elseif stat.target.kind == "Index" then
				noteIndex(stat.target.object, stat.target.index)
			end
			self:collectTableArrayHints(stat.value.body, freshState())
		end
	end
	return {}
end
--- Find the nearest declaration of `name` in an enclosing function and return it if constant.
-- The search starts at `ctx.parent`. The first declaration found wins: when it
-- is not a compile-time constant (`constKind` unset), it shadows any outer
-- constant and nil is returned.
-- @return the constant local info, or nil
function PureCompiler:findParentConstant(ctx, name)
	local owner = ctx.parent
	local info = nil
	while owner and not info do
		info = owner.locals[name]
		owner = owner.parent
	end
	if info and info.constKind then
		return info
	end
	return nil
end
--- Return the `constPath` of the nearest enclosing-function declaration of `name`.
-- The search starts at `ctx.parent`.
-- @return that local's `constPath` (possibly nil), or nil when no declaration exists
function PureCompiler:findParentConstPath(ctx, name)
	local owner = ctx.parent
	local info = nil
	while owner and not info do
		info = owner.locals[name]
		owner = owner.parent
	end
	if info then
		return info.constPath
	end
	return nil
end
--- Report whether any enclosing function (starting at `ctx.parent`) declares a local `name`.
-- @return true or false
function PureCompiler:hasParentLocal(ctx, name)
	local owner = ctx.parent
	while owner and not owner.locals[name] do
		owner = owner.parent
	end
	return not not owner
end
--- Resolve `name` as an upvalue of the function compiled in `ctx`, creating it on first use.
-- Lookup order:
--   1. a local of the immediate parent (skipped for compile-time constants
--      unless debugLevel >= 2);
--   2. at debugLevel >= 2, the nearest ancestor declaration if it is a constant;
--   3. recursively, an upvalue of the parent.
-- A new upvalue is appended to `ctx.upvalues` as `{ name, sourceKind, source }`,
-- and `ctx.func.numupvalues` is updated. When `ctx.writeNames` contains the
-- name, the captured local at the end of the chain is marked written.
-- @return the 0-based upvalue index, or nil when no enclosing function provides `name`
function PureCompiler:getUpvalue(ctx, name)
	local existing = ctx.upvalueMap[name]
	if existing ~= nil then
		return existing
	end
	if not ctx.parent then
		return nil
	end
	local parentLocal = ctx.parent.locals[name]
	local sourceKind = nil
	local source = nil
	if parentLocal and (not parentLocal.constKind or (self.builder.options.debugLevel or 1) >= 2) then
		sourceKind = "local"
		source = parentLocal
		parentLocal.captured = true
		if ctx.writeNames and ctx.writeNames[name] then
			parentLocal.written = true
		end
	else
		local parentConstant = (self.builder.options.debugLevel or 1) >= 2 and self:findParentConstant(ctx, name) or nil
		if parentConstant then
			sourceKind = "constant"
			source = parentConstant
		else
			local parentUpvalue = self:getUpvalue(ctx.parent, name)
			if parentUpvalue ~= nil then
				sourceKind = "upvalue"
				source = parentUpvalue
				if ctx.writeNames and ctx.writeNames[name] then
					-- resolveCapturedLocal takes an upvalue *record* whose `source` indexes
					-- the upvalues of its first argument. `parentUpvalue` is only a
					-- number, so describe the upvalue this context is about to own.
					-- Passing the bare index made resolveCapturedLocal index a number
					-- and raise an error.
					local capturedLocal = self:resolveCapturedLocal(ctx.parent, {
						sourceKind = "upvalue",
						source = parentUpvalue,
					})
					if capturedLocal then
						capturedLocal.written = true
					end
				end
			else
				return nil
			end
		end
	end
	local index = #ctx.upvalues
	ctx.upvalueMap[name] = index
	append(ctx.upvalues, {
		name = name,
		sourceKind = sourceKind,
		source = source,
	})
	ctx.func.numupvalues = #ctx.upvalues
	return index
end
--- Follow a chain of upvalue records back to the local variable they capture.
-- `upvalue` is a record `{ sourceKind, source }`. When its `sourceKind` is
-- "upvalue", `source` is a 0-based index into `ctx.upvalues`, and the walk
-- continues one function outward at a time.
-- @return the captured local info, or nil if the chain ends in a constant or cannot be followed
function PureCompiler:resolveCapturedLocal(ctx, upvalue)
	local kind = upvalue.sourceKind
	if kind == "local" then
		return upvalue.source
	end
	if kind ~= "upvalue" or not ctx then
		return nil
	end
	local nextLink = ctx.upvalues[upvalue.source + 1]
	if not nextLink then
		return nil
	end
	return self:resolveCapturedLocal(ctx.parent, nextLink)
end
--- Return one past the highest register held by any local in `ctx.localList` (0 if none).
-- Locals without a register are ignored.
function PureCompiler:getLocalTop(ctx)
	local highest = 0
	for _, info in ipairs(ctx.localList) do
		local reg = info.reg
		if reg ~= nil and reg + 1 > highest then
			highest = reg + 1
		end
	end
	return highest
end
--- Decide whether a local may be compiled as its folded constant value.
-- The local must have a `constKind`. A local that also owns a register keeps
-- using it when debugLevel >= 2.
-- @return `localInfo` itself when it is falsy, otherwise a boolean
function PureCompiler:canUseLocalConstant(localInfo)
	if not localInfo then
		return localInfo
	end
	if localInfo.constKind == nil then
		return false
	end
	if localInfo.reg == nil then
		return true
	end
	return not ((self.builder.options.debugLevel or 1) >= 2)
end
--- Pick a scratch register above `target`.
-- The result is at least `target + 1`, `ctx.protectedTop`, and `ctx.nextReg`.
-- When locals occupy registers beyond `target + 1`, it is also at least the
-- local top.
function PureCompiler:getTempReg(ctx, target)
	local minimum = target + 1
	local candidate = math.max(minimum, ctx.protectedTop or 0, ctx.nextReg)
	local localTop = self:getLocalTop(ctx)
	if localTop > minimum and localTop > candidate then
		candidate = localTop
	end
	return candidate
end
--- Compile `expr` into `target`, treating `target + 1` as the allocation top.
-- `ctx.nextReg` is temporarily set to `target + 1` and restored afterwards.
-- The expression is compiled with the targetTemp flag set (the 4th argument is `true`).
function PureCompiler:compileExprTempTop(ctx, expr, target)
	local savedNextReg = ctx.nextReg
	ctx.nextReg = target + 1
	self:compileExpr(ctx, expr, target, true)
	ctx.nextReg = savedNextReg
end
-- Constant kinds that can be materialized without a constant-table register.
local REGISTERLESS_CONSTANT_KINDS = {
	Nil = true,
	Bool = true,
	Number = true,
	Integer = true,
}
--- Report whether `expr` is absent or a nil/boolean/number/integer constant node.
-- @return true or false
function PureCompiler:isRegisterlessConstant(expr)
	if expr == nil then
		return true
	end
	return REGISTERLESS_CONSTANT_KINDS[expr.kind] == true
end
--- Lua truthiness of a folded constant: only nil and `false` are falsy.
-- @param value constant record `{ kind, value }`
-- @return boolean
function PureCompiler:constantTruth(value)
	local kind = value.kind
	if kind == "Nil" then
		return false
	end
	if kind == "Bool" then
		return value.value ~= false
	end
	return true
end
--- Escape literal text for use in the string.format pattern built for interpolation.
-- Each `%` is doubled so it is not read as a format directive.
-- @param text literal piece of an interpolated string
-- @return the escaped string only. string.gsub also returns a substitution
--   count; it is dropped so callers that pass this result as a final argument
--   (e.g. to table.insert) do not receive an extra value.
function PureCompiler:escapeInterpFormat(text)
	local escaped = string.gsub(text, "%%", "%%%%")
	return escaped
end
--- Lower an interpolated string into a constant or a format string plus arguments.
-- Literal pieces are copied verbatim into the raw text and `%`-escaped into
-- the format. At optimizationLevel >= 1, holes whose value folds to a string
-- constant are inlined the same way. Every other hole becomes a `%*`
-- placeholder whose argument is wrapped in a SingleResult node.
-- @return `{ constant = string }` when every hole was folded,
--   otherwise `{ format = string, args = list }`
function PureCompiler:lowerInterpString(ctx, expr)
	local rawParts = {}
	local formatParts = {}
	local dynamicArgs = {}
	local foldConstants = (self.builder.options.optimizationLevel or 1) >= 1
	for position, literal in ipairs(expr.strings) do
		append(rawParts, literal)
		append(formatParts, self:escapeInterpFormat(literal))
		local hole = expr.expressions[position]
		if hole then
			local folded = nil
			if foldConstants then
				folded = self:getConstant(ctx, hole)
			end
			if folded and folded.kind == "String" then
				self:captureParentConstantsForDebug(ctx, hole)
				append(rawParts, folded.value)
				append(formatParts, self:escapeInterpFormat(folded.value))
			else
				append(formatParts, "%*")
				append(dynamicArgs, node("SingleResult", {
					expr = hole,
				}, hole.line or expr.line))
			end
		end
	end
	if #dynamicArgs > 0 then
		return {
			format = table.concat(formatParts),
			args = dynamicArgs,
		}
	end
	return {
		constant = table.concat(rawParts),
	}
end
--- Extract the payload of a Number constant record.
-- @param value constant record or nil
-- @return the number, or nil when `value` is absent or not a Number constant
local function constantNumberValue(value)
	if not value or value.kind ~= "Number" then
		return nil
	end
	return value.value
end
--- Extract the payload of a String constant record.
-- @param value constant record or nil
-- @return the string, or nil when `value` is absent or not a String constant
local function constantStringValue(value)
	if not value or value.kind ~= "String" then
		return nil
	end
	return value.value
end
--- Extract the component table of a Vector constant record.
-- @param value constant record or nil
-- @return the `{ x, y, z, w }` table, or nil when `value` is absent or not a Vector constant
local function constantVectorValue(value)
	if not value or value.kind ~= "Vector" then
		return nil
	end
	return value.value
end
--- Build a Vector constant record.
-- @param x, y, z components
-- @param w optional fourth component; defaults to 0 when falsy
-- @return `{ kind = "Vector", value = { x, y, z, w } }`
local function makeVectorConstant(x, y, z, w)
	local components = { x = x, y = y, z = z }
	components.w = w or 0
	return { kind = "Vector", value = components }
end
-- Component order used when comparing vector constants.
local VECTOR_COMPONENT_ORDER = { "x", "y", "z", "w" }
--- Component-wise `==` comparison of two vector component tables.
-- NaN components compare unequal, as with `==`.
-- @return true or false
local function vectorConstantsEqual(left, right)
	for _, axis in ipairs(VECTOR_COMPONENT_ORDER) do
		if left[axis] ~= right[axis] then
			return false
		end
	end
	return true
end
--- Decide whether a folded vector result may be used.
-- A non-zero w component is only acceptable when an input already had one.
-- @return true when `resultW == 0`, otherwise `hadW` unchanged
local function vectorResultAllowed(hadW, resultW)
	if resultW == 0 then
		return true
	end
	return hadW
end
-- Numeric constants available to builtin constant folding.
-- K_RAD_DEG is the degrees-to-radians factor (pi / 180): math.rad folds to
-- a * K_RAD_DEG and math.deg to a / K_RAD_DEG in foldO2BuiltinConstant.
local K_PI = 3.14159265358979323846
local K_RAD_DEG = K_PI / 180.0
-- NaN, produced by 0 / 0.
local K_NAN = 0 / 0
-- Euler's number, the golden ratio, sqrt(2) and tau (2 * pi). They are not
-- referenced in this part of the file; presumably used by folding code
-- elsewhere (TODO confirm).
local K_E = 2.71828182845904523536
local K_PHI = 1.61803398874989484820
local K_SQRT2 = 1.41421356237309504880
local K_TAU = 6.28318530717958647692
--- Round to the nearest integer, with halves rounded away from zero.
-- NaN, infinities and zeros are returned unchanged. Negative values that round
-- to zero yield -0.0. The fractional part is compared directly, rather than
-- via floor(x + 0.5), so inputs just below .5 are not rounded up by
-- floating-point addition error.
local function roundNumber(value)
	if value ~= value or value == math.huge or value == -math.huge or value == 0 then
		return value
	end
	if value < 0 then
		local towardZero = math.ceil(value)
		if value - towardZero <= -0.5 then
			return towardZero - 1
		end
		if towardZero == 0 then
			return -0.0
		end
		return towardZero
	end
	local towardZero = math.floor(value)
	if value - towardZero >= 0.5 then
		return towardZero + 1
	end
	return towardZero
end
--- Truncate toward zero: floor for non-negative values, ceil for negative ones.
local function truncNumber(value)
	if value < 0 then
		return math.ceil(value)
	end
	return math.floor(value)
end
--- Convert a folded number to the unsigned 32-bit value used for bit32 folding.
-- The value is truncated toward zero, then masked with 0xffffffff via bit32.band.
local function foldBit32(value)
	local truncated = truncNumber(value)
	return bit32.band(truncated, 0xffffffff)
end
--- Return a mask of `width` low bits (2^width - 1), saturating at 0xffffffff for width >= 32.
local function foldBit32Mask(width)
	local saturated = width >= 32
	return saturated and 0xffffffff or (2 ^ width - 1)
end
local function foldO2BuiltinConstant(builtinId, args)
local count = #args
local a = constantNumberValue(args[1])
local b = constantNumberValue(args[2])
local c = constantNumberValue(args[3])
if builtinId == LBF.MATH_ABS and count == 1 and a ~= nil then
return { kind = "Number", value = math.abs(a) }
elseif builtinId == LBF.MATH_ACOS and count == 1 and a ~= nil then
return { kind = "Number", value = math.acos(a) }
elseif builtinId == LBF.MATH_ASIN and count == 1 and a ~= nil then
return { kind = "Number", value = math.asin(a) }
elseif builtinId == LBF.MATH_ATAN2 and count == 2 and a ~= nil and b ~= nil then
local atan2 = (math :: any).atan2
local value
if atan2 then
value = atan2(a, b)
elseif b > 0 then
value = math.atan(a / b)
elseif b < 0 and a >= 0 then
value = math.atan(a / b) + math.pi
elseif b < 0 then
value = math.atan(a / b) - math.pi
elseif a > 0 then
value = math.pi / 2
elseif a < 0 then
value = -math.pi / 2
else
value = 0
end
return { kind = "Number", value = value }
elseif builtinId == LBF.MATH_ATAN and count == 1 and a ~= nil then
return { kind = "Number", value = math.atan(a) }
elseif builtinId == LBF.MATH_CEIL and count == 1 and a ~= nil then
return { kind = "Number", value = math.ceil(a) }
elseif builtinId == LBF.MATH_COSH and count == 1 and a ~= nil then
return { kind = "Number", value = math.cosh(a) }
elseif builtinId == LBF.MATH_COS and count == 1 and a ~= nil then
return { kind = "Number", value = math.cos(a) }
elseif builtinId == LBF.MATH_DEG and count == 1 and a ~= nil then
return { kind = "Number", value = a / K_RAD_DEG }
elseif builtinId == LBF.MATH_EXP and count == 1 and a ~= nil then
return { kind = "Number", value = math.exp(a) }
elseif builtinId == LBF.MATH_FLOOR and count == 1 and a ~= nil then
return { kind = "Number", value = math.floor(a) }
elseif builtinId == LBF.MATH_FMOD and count == 2 and a ~= nil and b ~= nil then
return { kind = "Number", value = math.fmod(a, b) }
elseif builtinId == LBF.MATH_LDEXP and count == 2 and a ~= nil and b ~= nil then
return { kind = "Number", value = a * 2 ^ truncNumber(b) }
elseif builtinId == LBF.MATH_LOG10 and count == 1 and a ~= nil then
return { kind = "Number", value = math.log(a, 10) }
elseif builtinId == LBF.MATH_LOG and a ~= nil then
if count == 1 then
return { kind = "Number", value = math.log(a) }
elseif count == 2 and b ~= nil then
if b == 2 then
return { kind = "Number", value = math.log(a, 2) }
elseif b == 10 then
return { kind = "Number", value = math.log(a, 10) }
else
return { kind = "Number", value = math.log(a) / math.log(b) }
end
end
elseif builtinId == LBF.MATH_MAX and count >= 1 and a ~= nil then
local result = a
for index = 2, count do
local value = constantNumberValue(args[index])
if value == nil then
return nil
end
if value > result then
result = value
end
end
return { kind = "Number", value = result }
elseif builtinId == LBF.MATH_MIN and count >= 1 and a ~= nil then
local result = a
for index = 2, count do
local value = constantNumberValue(args[index])
if value == nil then
return nil
end
if value < result then
result = value
end
end
return { kind = "Number", value = result }
elseif builtinId == LBF.MATH_POW and count == 2 and a ~= nil and b ~= nil then
return { kind = "Number", value = a ^ b }
elseif builtinId == LBF.MATH_RAD and count == 1 and a ~= nil then
return { kind = "Number", value = a * K_RAD_DEG }
elseif builtinId == LBF.MATH_SINH and count == 1 and a ~= nil then
return { kind = "Number", value = math.sinh(a) }
elseif builtinId == LBF.MATH_SIN and count == 1 and a ~= nil then
return { kind = "Number", value = math.sin(a) }
elseif builtinId == LBF.MATH_SQRT and count == 1 and a ~= nil then
return { kind = "Number", value = math.sqrt(a) }
elseif builtinId == LBF.MATH_TANH and count == 1 and a ~= nil then
return { kind = "Number", value = math.tanh(a) }
elseif builtinId == LBF.MATH_TAN and count == 1 and a ~= nil then
return { kind = "Number", value = math.tan(a) }
elseif builtinId == LBF.BIT32_ARSHIFT and count == 2 and a ~= nil and b ~= nil then
local shift = truncNumber(b)
if shift >= 0 and shift < 32 then
return { kind = "Number", value = bit32.band(bit32.arshift(foldBit32(a), shift), 0xffffffff) }
end
elseif builtinId == LBF.BIT32_BAND and count >= 1 and a ~= nil then
local result = foldBit32(a)
for index = 2, count do
local value = constantNumberValue(args[index])
if value == nil then
return nil
end
result = bit32.band(result, foldBit32(value))
end
return { kind = "Number", value = result }
elseif builtinId == LBF.BIT32_BNOT and count == 1 and a ~= nil then
return { kind = "Number", value = bit32.band(bit32.bnot(foldBit32(a)), 0xffffffff) }
elseif builtinId == LBF.BIT32_BOR and count >= 1 and a ~= nil then
local result = foldBit32(a)
for index = 2, count do
local value = constantNumberValue(args[index])
if value == nil then
return nil
end
result = bit32.bor(result, foldBit32(value))
end
return { kind = "Number", value = result }
elseif builtinId == LBF.BIT32_BXOR and count >= 1 and a ~= nil then
local result = foldBit32(a)
for index = 2, count do
local value = constantNumberValue(args[index])
if value == nil then
return nil
end
result = bit32.bxor(result, foldBit32(value))
end
return { kind = "Number", value = result }
elseif builtinId == LBF.BIT32_BTEST and count >= 1 and a ~= nil then
local result = foldBit32(a)
for index = 2, count do
local value = constantNumberValue(args[index])
if value == nil then
return nil
end
result = bit32.band(result, foldBit32(value))
end
return { kind = "Bool", value = result ~= 0 }
elseif builtinId == LBF.BIT32_EXTRACT and count >= 2 and a ~= nil and b ~= nil then
local field = truncNumber(b)
local width = 1
if count >= 3 then
if c == nil then
return nil
end
width = truncNumber(c)
end
if field >= 0 and width > 0 and field + width <= 32 then
local mask = foldBit32Mask(width)
return { kind = "Number", value = bit32.band(bit32.rshift(foldBit32(a), field), mask) }
end
elseif builtinId == LBF.BIT32_LROTATE and count == 2 and a ~= nil and b ~= nil then
return { kind = "Number", value = bit32.band(bit32.lrotate(foldBit32(a), truncNumber(b)), 0xffffffff) }
elseif builtinId == LBF.BIT32_LSHIFT and count == 2 and a ~= nil and b ~= nil then
local shift = truncNumber(b)
if shift >= 0 and shift < 32 then
return { kind = "Number", value = bit32.band(bit32.lshift(foldBit32(a), shift), 0xffffffff) }
end
elseif builtinId == LBF.BIT32_REPLACE and count >= 3 and a ~= nil and b ~= nil and c ~= nil then
local field = truncNumber(c)
local width = 1
if count >= 4 then
local d = constantNumberValue(args[4])
if d == nil then
return nil
end
width = truncNumber(d)
end
if field >= 0 and width > 0 and field + width <= 32 then
local mask = foldBit32Mask(width)
local shiftedMask = bit32.lshift(mask, field)
local result = bit32.bor(bit32.band(foldBit32(a), bit32.bnot(shiftedMask)), bit32.lshift(bit32.band(foldBit32(b), mask), field))
return { kind = "Number", value = bit32.band(result, 0xffffffff) }
end
elseif builtinId == LBF.BIT32_RROTATE and count == 2 and a ~= nil and b ~= nil then
return { kind = "Number", value = bit32.band(bit32.rrotate(foldBit32(a), truncNumber(b)), 0xffffffff) }
elseif builtinId == LBF.BIT32_RSHIFT and count == 2 and a ~= nil and b ~= nil then
local shift = truncNumber(b)
if shift >= 0 and shift < 32 then
return { kind = "Number", value = bit32.rshift(foldBit32(a), shift) }
end
elseif builtinId == LBF.STRING_BYTE then
local str = constantStringValue(args[1])
if count == 1 and str ~= nil then
if #str > 0 then
return { kind = "Number", value = string.byte(str, 1) }
end
elseif count == 2 and str ~= nil and b ~= nil then
local index = truncNumber(b)
if index > 0 and index <= #str then
return { kind = "Number", value = string.byte(str, index) }
end
end
elseif builtinId == LBF.STRING_CHAR then
if count < 128 then
local bytes = {}
for index = 1, count do
local value = constantNumberValue(args[index])
if value == nil then
return nil
end
local byte = truncNumber(value)
if byte < 0 or byte > 255 then
return nil
end
bytes[index] = string.char(byte)
end
return { kind = "String", value = table.concat(bytes) }
end
elseif builtinId == LBF.STRING_LEN then
local str = constantStringValue(args[1])
if count == 1 and str ~= nil then
return { kind = "Number", value = #str }
end
elseif builtinId == LBF.STRING_SUB then
local str = constantStringValue(args[1])
if count >= 2 and str ~= nil and b ~= nil then
local len = #str
local start = truncNumber(b)
local finish = len
if count >= 3 then
if c == nil then
return nil
end
finish = truncNumber(c)
end
if start < 0 then
start += len + 1
end
if finish < 0 then
finish += len + 1
end
if finish < 1 then
return { kind = "String", value = "" }
end
if start < 1 then
start = 1
end
if finish > len then
finish = len
end
if start <= finish then
return { kind = "String", value = string.sub(str, start, finish) }
end
return { kind = "String", value = "" }
end
elseif builtinId == LBF.MATH_CLAMP and count == 3 and a ~= nil and b ~= nil and c ~= nil and b <= c then
local value = a
value = value < b and b or value
value = value > c and c or value
return { kind = "Number", value = value }
elseif builtinId == LBF.MATH_SIGN and count == 1 and a ~= nil then
return { kind = "Number", value = a > 0 and 1 or a < 0 and -1 or 0 }
elseif builtinId == LBF.MATH_ROUND and count == 1 and a ~= nil then
return { kind = "Number", value = roundNumber(a) }
elseif builtinId == LBF.MATH_LERP and count == 3 and a ~= nil and b ~= nil and c ~= nil then
return { kind = "Number", value = c == 1 and b or a + (b - a) * c }
elseif builtinId == LBF.MATH_ISNAN and count == 1 and a ~= nil then
return { kind = "Bool", value = a ~= a }
elseif builtinId == LBF.MATH_ISINF and count == 1 and a ~= nil then
return { kind = "Bool", value = a == math.huge or a == -math.huge }
elseif builtinId == LBF.MATH_ISFINITE and count == 1 and a ~= nil then
return { kind = "Bool", value = a == a and a ~= math.huge and a ~= -math.huge }
elseif builtinId == LBF.VECTOR and count >= 2 and a ~= nil and b ~= nil then
if count == 2 then
return makeVectorConstant(a, b, 0, 0)
elseif count == 3 and c ~= nil then
return makeVectorConstant(a, b, c, 0)
elseif count == 4 and c ~= nil then
local d = constantNumberValue(args[4])
if d ~= nil then
return makeVectorConstant(a, b, c, d)
end
end
end
return nil
end
--- Folds `expr` into a compile-time constant when possible.
-- Returns a constant record `{ kind = <kind>, value = <value> }`, or nil when the
-- expression's value is not known at compile time. Kinds built directly here are
-- "Nil", "Bool", "Number", "Integer" and "String"; vector constants come from
-- makeVectorConstant and builtin calls are folded by foldO2BuiltinConstant.
-- A missing expression (nil) folds to the Nil constant.
function PureCompiler:getConstant(ctx, expr)
	if expr == nil or expr.kind == "Nil" then
		return { kind = "Nil" }
	elseif expr.kind == "Bool" then
		return { kind = "Bool", value = expr.value }
	elseif expr.kind == "Number" then
		return { kind = "Number", value = expr.value }
	elseif expr.kind == "Integer" then
		return { kind = "Integer", value = expr.value }
	elseif expr.kind == "String" then
		return { kind = "String", value = expr.value }
	elseif expr.kind == "SingleResult" then
		-- Wrapper nodes fold exactly like the expression they wrap.
		return self:getConstant(ctx, expr.expr)
	elseif expr.kind == "Instantiate" then
		return self:getConstant(ctx, expr.expr)
	elseif expr.kind == "InterpString" then
		local lowered = self:lowerInterpString(ctx, expr)
		if lowered.constant ~= nil then
			return { kind = "String", value = lowered.constant }
		end
	elseif expr.kind == "IfExpr" then
		-- The first clause whose condition folds to a truthy constant selects the
		-- result; a non-constant condition makes the whole expression non-constant.
		for _, clause in ipairs(expr.clauses) do
			local condition = self:getConstant(ctx, clause.condition)
			if not condition then
				return nil
			end
			if self:constantTruth(condition) then
				return self:getConstant(ctx, clause.value)
			end
		end
		return self:getConstant(ctx, expr.elseValue)
	elseif expr.kind == "Field" then
		-- Library constants are only folded at -O2 and when getfenv/setfenv are
		-- unused (presumably so `math` is guaranteed to be the builtin library).
		if (self.builder.options.optimizationLevel or 1) >= 2 and not self.getfenvUsed and not self.setfenvUsed then
			local path = self:getImportPath(ctx, expr)
			if path and #path == 2 and path[1] == "math" then
				if path[2] == "pi" then
					return { kind = "Number", value = K_PI }
				elseif path[2] == "huge" then
					return { kind = "Number", value = math.huge }
				elseif path[2] == "nan" then
					return { kind = "Number", value = K_NAN }
				elseif path[2] == "e" then
					return { kind = "Number", value = K_E }
				elseif path[2] == "phi" then
					return { kind = "Number", value = K_PHI }
				elseif path[2] == "sqrt2" then
					return { kind = "Number", value = K_SQRT2 }
				elseif path[2] == "tau" then
					return { kind = "Number", value = K_TAU }
				end
			end
		end
		-- Component access on a constant vector (either case).
		local object = self:getConstant(ctx, expr.object)
		if object and object.kind == "Vector" then
			if expr.field == "x" or expr.field == "X" then
				return { kind = "Number", value = object.value.x }
			elseif expr.field == "y" or expr.field == "Y" then
				return { kind = "Number", value = object.value.y }
			elseif expr.field == "z" or expr.field == "Z" then
				return { kind = "Number", value = object.value.z }
			end
		end
	elseif expr.kind == "Call" then
		if (self.builder.options.optimizationLevel or 1) >= 2 and not self.getfenvUsed and not self.setfenvUsed then
			local builtinId = self:getBuiltinIdForCallee(ctx, expr.callee)
			-- type()/typeof() of a constant argument folds to its type name;
			-- "vector" is only produced for type(), not typeof().
			if #expr.args == 1 and (builtinId == LBF.TYPE or builtinId == LBF.TYPEOF) then
				local argument = self:getConstant(ctx, expr.args[1])
				local typeName = nil
				if argument then
					if argument.kind == "Nil" then
						typeName = "nil"
					elseif argument.kind == "Bool" then
						typeName = "boolean"
					elseif argument.kind == "Number" then
						typeName = "number"
					elseif argument.kind == "Integer" then
						typeName = "integer"
					elseif argument.kind == "String" then
						typeName = "string"
					elseif argument.kind == "Vector" and builtinId == LBF.TYPE then
						typeName = "vector"
					end
				end
				if typeName then
					return { kind = "String", value = typeName }
				end
			end
			-- Other builtins fold only when every argument is constant.
			if builtinId then
				local args = {}
				local allConstant = true
				for index, arg in ipairs(expr.args) do
					local constant = self:getConstant(ctx, arg)
					if constant == nil then
						allConstant = false
						break
					end
					args[index] = constant
				end
				if allConstant then
					return foldO2BuiltinConstant(builtinId, args)
				end
			end
		end
	elseif expr.kind == "Name" then
		-- Constant locals of this function first, then constants visible from
		-- enclosing functions.
		local localInfo = self:findLocal(ctx, expr.name)
		if not localInfo then
			localInfo = self:findParentConstant(ctx, expr.name)
		end
		if localInfo and localInfo.constKind then
			return {
				kind = localInfo.constKind,
				value = localInfo.constValue,
			}
		end
	elseif expr.kind == "Un" then
		local value = self:getConstant(ctx, expr.expr)
		if value then
			if expr.op == "not" then
				return { kind = "Bool", value = not self:constantTruth(value) }
			elseif expr.op == "-" and value.kind == "Number" then
				return { kind = "Number", value = -value.value }
			elseif expr.op == "-" and value.kind == "Vector" then
				return makeVectorConstant(-value.value.x, -value.value.y, -value.value.z, -value.value.w)
			elseif expr.op == "#" and value.kind == "String" then
				return { kind = "Number", value = #value.value }
			end
		end
	elseif expr.kind == "Bin" then
		local left = self:getConstant(ctx, expr.left)
		-- `and`/`or` only need a constant left operand to decide which side
		-- is the result; the selected right side may still be non-constant.
		if expr.op == "and" then
			if left and not self:constantTruth(left) then
				return left
			end
			local right = self:getConstant(ctx, expr.right)
			return left and right or nil
		elseif expr.op == "or" then
			if left and self:constantTruth(left) then
				return left
			end
			local right = self:getConstant(ctx, expr.right)
			return left and right or nil
		end
		local right = self:getConstant(ctx, expr.right)
		if not left or not right then
			return nil
		end
		if expr.op == "==" or expr.op == "~=" then
			-- Constants of different kinds are never equal; Integer constants
			-- compare by their `.key`, vectors component-wise.
			local equal = false
			if left.kind == right.kind then
				if left.kind == "Integer" then
					equal = left.value.key == right.value.key
				elseif left.kind == "Vector" then
					equal = vectorConstantsEqual(left.value, right.value)
				else
					equal = left.value == right.value
				end
			end
			local result = equal
			if expr.op == "~=" then
				result = not equal
			end
			return { kind = "Bool", value = result }
		elseif expr.op == "<" or expr.op == "<=" or expr.op == ">" or expr.op == ">=" then
			-- Ordering comparisons are folded for numbers only.
			if left.kind == right.kind and left.kind == "Number" then
				local result
				if expr.op == "<" then
					result = left.value < right.value
				elseif expr.op == "<=" then
					result = left.value <= right.value
				elseif expr.op == ">" then
					result = left.value > right.value
				else
					result = left.value >= right.value
				end
				return { kind = "Bool", value = result }
			end
		elseif expr.op == ".." and left.kind == "String" and right.kind == "String" then
			return { kind = "String", value = left.value .. right.value }
		elseif left.kind == "Number" and right.kind == "Number" then
			if expr.op == "+" then
				return { kind = "Number", value = left.value + right.value }
			elseif expr.op == "-" then
				return { kind = "Number", value = left.value - right.value }
			elseif expr.op == "*" then
				return { kind = "Number", value = left.value * right.value }
			elseif expr.op == "/" then
				return { kind = "Number", value = left.value / right.value }
			elseif expr.op == "//" then
				-- Floor division folds with the runtime semantics floor(a / b),
				-- matching the vector `//` folding below.
				return { kind = "Number", value = math.floor(left.value / right.value) }
			elseif expr.op == "%" then
				return { kind = "Number", value = left.value % right.value }
			elseif expr.op == "^" then
				return { kind = "Number", value = left.value ^ right.value }
			end
		else
			-- Vector arithmetic: component-wise for vector-vector, scalar
			-- broadcast for vector-number / number-vector. Results whose w
			-- component is rejected by vectorResultAllowed are not folded.
			local leftVector = constantVectorValue(left)
			local rightVector = constantVectorValue(right)
			if leftVector and rightVector then
				local hadW = leftVector.w ~= 0 or rightVector.w ~= 0
				if expr.op == "+" then
					return makeVectorConstant(leftVector.x + rightVector.x, leftVector.y + rightVector.y, leftVector.z + rightVector.z, leftVector.w + rightVector.w)
				elseif expr.op == "-" then
					return makeVectorConstant(leftVector.x - rightVector.x, leftVector.y - rightVector.y, leftVector.z - rightVector.z, leftVector.w - rightVector.w)
				elseif expr.op == "*" then
					local w = leftVector.w * rightVector.w
					if vectorResultAllowed(hadW, w) then
						return makeVectorConstant(leftVector.x * rightVector.x, leftVector.y * rightVector.y, leftVector.z * rightVector.z, w)
					end
				elseif expr.op == "/" then
					local w = leftVector.w / rightVector.w
					if vectorResultAllowed(hadW, w) then
						return makeVectorConstant(leftVector.x / rightVector.x, leftVector.y / rightVector.y, leftVector.z / rightVector.z, w)
					end
				elseif expr.op == "//" then
					local w = math.floor(leftVector.w / rightVector.w)
					if vectorResultAllowed(hadW, w) then
						return makeVectorConstant(
							math.floor(leftVector.x / rightVector.x),
							math.floor(leftVector.y / rightVector.y),
							math.floor(leftVector.z / rightVector.z),
							w
						)
					end
				end
			elseif leftVector and right.kind == "Number" then
				local hadW = leftVector.w ~= 0
				if expr.op == "*" then
					local w = leftVector.w * right.value
					if vectorResultAllowed(hadW, w) then
						return makeVectorConstant(leftVector.x * right.value, leftVector.y * right.value, leftVector.z * right.value, w)
					end
				elseif expr.op == "/" then
					local w = leftVector.w / right.value
					if vectorResultAllowed(hadW, w) then
						return makeVectorConstant(leftVector.x / right.value, leftVector.y / right.value, leftVector.z / right.value, w)
					end
				elseif expr.op == "//" then
					local w = math.floor(leftVector.w / right.value)
					if vectorResultAllowed(hadW, w) then
						return makeVectorConstant(
							math.floor(leftVector.x / right.value),
							math.floor(leftVector.y / right.value),
							math.floor(leftVector.z / right.value),
							w
						)
					end
				end
			elseif left.kind == "Number" and rightVector then
				local hadW = rightVector.w ~= 0
				if expr.op == "*" then
					local w = left.value * rightVector.w
					if vectorResultAllowed(hadW, w) then
						return makeVectorConstant(left.value * rightVector.x, left.value * rightVector.y, left.value * rightVector.z, w)
					end
				elseif expr.op == "/" then
					local w = left.value / rightVector.w
					if vectorResultAllowed(hadW, w) then
						return makeVectorConstant(left.value / rightVector.x, left.value / rightVector.y, left.value / rightVector.z, w)
					end
				elseif expr.op == "//" then
					local w = math.floor(left.value / rightVector.w)
					if vectorResultAllowed(hadW, w) then
						return makeVectorConstant(
							math.floor(left.value / rightVector.x),
							math.floor(left.value / rightVector.y),
							math.floor(left.value / rightVector.z),
							w
						)
					end
				end
			end
		end
	end
	return nil
end
-- Forward declaration: getBytecodeType (below) reads this table when it is
-- called; the table itself is assigned further down, after the LBF ids exist.
local builtinReturnTypes
--- Returns the LBC_TYPE_* tag statically known for `expr`, or nil when no type
-- can be determined (or the value is nil).
-- Literals map directly to their type; locals report their recorded typeId;
-- calls use builtinReturnTypes or the returnType of a known function; `#x`
-- and numeric unary/binary arithmetic are numbers. Anything these rules do
-- not settle is retried through constant folding.
function PureCompiler:getBytecodeType(ctx, expr)
	if expr == nil then
		return nil
	end

	local kind = expr.kind

	-- Literal and wrapper kinds answer immediately.
	if kind == "Nil" or kind == "Table" then
		return nil
	elseif kind == "Bool" then
		return LBC_TYPE_BOOLEAN
	elseif kind == "Number" then
		return LBC_TYPE_NUMBER
	elseif kind == "Integer" then
		return LBC_TYPE_INTEGER
	elseif kind == "String" then
		return LBC_TYPE_STRING
	elseif kind == "SingleResult" or kind == "Instantiate" then
		return self:getBytecodeType(ctx, expr.expr)
	end

	-- Kinds that may resolve here; otherwise they fall through to folding.
	if kind == "Name" then
		local localInfo = self:findLocal(ctx, expr.name)
		if localInfo then
			local recordedType = localInfo.typeId
			if recordedType ~= nil then
				return recordedType
			end
		end
	elseif kind == "Call" then
		local builtinId = self:getBuiltinIdForCallee(ctx, expr.callee)
		if builtinId ~= nil then
			local builtinType = builtinReturnTypes[builtinId]
			if builtinType ~= nil then
				return builtinType
			end
		end
		local funcExpr = self:getO2FunctionExprForCallee(ctx, expr.callee)
		if funcExpr and funcExpr.returnType ~= nil then
			return funcExpr.returnType
		end
	elseif kind == "Un" then
		local op = expr.op
		if op == "#" then
			return LBC_TYPE_NUMBER
		end
		if op == "-" and self:getBytecodeType(ctx, expr.expr) == LBC_TYPE_NUMBER then
			return LBC_TYPE_NUMBER
		end
	elseif kind == "Bin" then
		local numericOperands = arithmeticOps[expr.op]
			and self:getBytecodeType(ctx, expr.left) == LBC_TYPE_NUMBER
			and self:getBytecodeType(ctx, expr.right) == LBC_TYPE_NUMBER
		if numericOperands then
			return LBC_TYPE_NUMBER
		end
	end

	local constant = self:getConstant(ctx, expr)
	if not constant then
		return nil
	end

	local constantKind = constant.kind
	if constantKind == "Bool" then
		return LBC_TYPE_BOOLEAN
	elseif constantKind == "Number" then
		return LBC_TYPE_NUMBER
	elseif constantKind == "Integer" then
		return LBC_TYPE_INTEGER
	elseif constantKind == "String" then
		return LBC_TYPE_STRING
	elseif constantKind == "Vector" then
		return LBC_TYPE_VECTOR
	end
	-- Nil constants and unrecognised constant kinds carry no bytecode type.
	return nil
end
--- Bytecode type of the initializer at position `index` in a multiple
-- assignment with `nameCount` names and the expression list `values`.
-- Returns nil when there is no value at `index`, or when that value is the
-- last one and is a Call/MethodCall/Vararg that also has to fill the
-- remaining names. Otherwise the type comes from getBytecodeType.
function PureCompiler:getInitializerBytecodeType(ctx, values, index, nameCount)
	local value = values[index]
	if value == nil then
		return nil
	end

	local valueCount = #values
	local fillsRemainingNames = index == valueCount and nameCount > valueCount
	if fillsRemainingNames then
		local valueKind = value.kind
		local isMultiResult = valueKind == "Call" or valueKind == "MethodCall" or valueKind == "Vararg"
		if isMultiResult then
			return nil
		end
	end

	return self:getBytecodeType(ctx, value)
end
--- Replaces the recorded bytecode type of a local that already has one.
-- A local whose typeId is nil is left untouched; otherwise typeId is
-- overwritten with the given value, which may itself be nil.
function PureCompiler:updateExistingLocalBytecodeType(localInfo, typeId)
	if localInfo.typeId == nil then
		return
	end
	localInfo.typeId = typeId
end
--- At optimization level 2 and above, records a local type-info entry for
-- register `reg` when the inferred type of `expr` is known and differs from
-- `expectedType`. The entry covers the pc range
-- [label - instLength, label], where label is what builder:label reports now.
-- Name expressions and missing expressions are skipped.
function PureCompiler:hintTemporaryExprRegType(ctx, expr, reg, expectedType, instLength)
	local optimizationLevel = ctx.builder.options.optimizationLevel or 1
	if optimizationLevel < 2 then
		return
	end
	if expr == nil or expr.kind == "Name" then
		return
	end

	local inferredType = self:getBytecodeType(ctx, expr)
	if inferredType == nil or inferredType == expectedType then
		return
	end

	local endPc = ctx.builder:label(ctx.func)
	local startPc = endPc - instLength
	ctx.builder:pushLocalTypeInfo(ctx.func, inferredType, reg, startPc, endPc)
end
do
	-- Expression fields to descend into, in visiting order, for node kinds
	-- whose children are plain expression fields.
	local childFieldsByKind = {
		Field = { "object" },
		Index = { "object", "index" },
		Un = { "expr" },
		Bin = { "left", "right" },
		SingleResult = { "expr" },
	}

	--- When debugLevel is at least 2, walks `expr` and records (via
	-- rememberDebugConstantUpvalue) every Name that is not a local of the
	-- current function but resolves through findParentConstant.
	-- Children are visited left to right; node kinds not handled here are not
	-- descended into.
	function PureCompiler:captureParentConstantsForDebug(ctx, expr)
		if (self.builder.options.debugLevel or 1) < 2 or expr == nil then
			return
		end

		local kind = expr.kind
		if kind == "Name" then
			local name = expr.name
			if not self:findLocal(ctx, name) and self:findParentConstant(ctx, name) then
				self:rememberDebugConstantUpvalue(ctx, name)
			end
			return
		end

		local childFields = childFieldsByKind[kind]
		if childFields ~= nil then
			for _, field in ipairs(childFields) do
				self:captureParentConstantsForDebug(ctx, expr[field])
			end
		elseif kind == "IfExpr" then
			for _, clause in ipairs(expr.clauses) do
				self:captureParentConstantsForDebug(ctx, clause.condition)
				self:captureParentConstantsForDebug(ctx, clause.value)
			end
			self:captureParentConstantsForDebug(ctx, expr.elseValue)
		elseif kind == "InterpString" then
			for _, part in ipairs(expr.expressions) do
				self:captureParentConstantsForDebug(ctx, part)
			end
		end
	end
end
--- Emits bytecode that loads the folded `constant` into register `target`.
-- Number/String go through emitLoadConstant; Integer/Vector are added to the
-- function's constant table and loaded by index via emitLoadKIndex. Nil and
-- Bool use LOADNIL/LOADB and mark `target` used here. Any other kind raises
-- "invalid constant", reported at the caller's level.
function PureCompiler:emitConstant(ctx, constant, target, line)
	local kind = constant.kind

	if kind == "Number" or kind == "String" then
		self:emitLoadConstant(ctx, target, constant.value, line)
		return
	end

	if kind == "Integer" or kind == "Vector" then
		local builder = ctx.builder
		local cid
		if kind == "Integer" then
			cid = builder:addConstantInteger(ctx.func, constant.value)
		else
			cid = builder:addConstantVector(ctx.func, constant.value)
		end
		self:emitLoadKIndex(ctx, target, cid, line)
		return
	end

	if kind == "Nil" then
		ctx.builder:emitABC(ctx.func, LOP.LOADNIL, target, 0, 0, line)
	elseif kind == "Bool" then
		local flag = constant.value and 1 or 0
		ctx.builder:emitABC(ctx.func, LOP.LOADB, target, flag, 0, line)
	else
		error("invalid constant", 2)
	end
	self:useReg(ctx, target)
end
--- Emits bytecode that loads the constant recorded on `localInfo`
-- (constKind/constValue) into register `target`.
-- Number/String go through emitLoadConstant; Integer/Vector are added to the
-- constant table and loaded via emitLoadKIndex. Nil and Bool use
-- LOADNIL/LOADB and mark `target` used here. Any other kind raises
-- "invalid constant local", reported at the caller's level.
function PureCompiler:emitLocalConstant(ctx, localInfo, target, line)
	local kind = localInfo.constKind
	local value = localInfo.constValue

	if kind == "Number" or kind == "String" then
		self:emitLoadConstant(ctx, target, value, line)
		return
	end

	if kind == "Integer" or kind == "Vector" then
		local builder = ctx.builder
		local cid
		if kind == "Integer" then
			cid = builder:addConstantInteger(ctx.func, value)
		else
			cid = builder:addConstantVector(ctx.func, value)
		end
		self:emitLoadKIndex(ctx, target, cid, line)
		return
	end

	if kind == "Nil" then
		ctx.builder:emitABC(ctx.func, LOP.LOADNIL, target, 0, 0, line)
	elseif kind == "Bool" then
		local flag = value and 1 or 0
		ctx.builder:emitABC(ctx.func, LOP.LOADB, target, flag, 0, line)
	else
		error("invalid constant local", 2)
	end
	self:useReg(ctx, target)
end
--- Builds a constant-local record `{ name, constKind, constValue }` for
-- `name` when its initializer `expr` folds to a constant.
-- Returns nil when folding is disabled (optimization level below 1 or debug
-- level above 1), when the local is written (`written == true` or
-- ctx.writeNames[name] == true), or when `expr` does not fold.
function PureCompiler:makeConstantLocal(ctx, name, expr, written)
	local options = self.builder.options
	local foldingDisabled = (options.optimizationLevel or 1) < 1 or (options.debugLevel or 1) > 1
	if foldingDisabled then
		return nil
	end

	local writeNames = ctx.writeNames
	local isWritten = written == true or writeNames and writeNames[name] == true
	if isWritten then
		return nil
	end

	local constant = self:getConstant(ctx, expr)
	if not constant then
		return nil
	end

	return {
		name = name,
		constKind = constant.kind,
		constValue = constant.value,
	}
end
-- Builtin function ids. Ids are consecutive from 1 in the order of the list
-- below, so a name's position is its id (presumably mirroring Luau's
-- LuauBuiltinFunction / LBF_* enum). Append new names at the end only.
do
	local namesInIdOrder = {
		-- 1
		"ASSERT",
		-- 2..27: math
		"MATH_ABS",
		"MATH_ACOS",
		"MATH_ASIN",
		"MATH_ATAN2",
		"MATH_ATAN",
		"MATH_CEIL",
		"MATH_COSH",
		"MATH_COS",
		"MATH_DEG",
		"MATH_EXP",
		"MATH_FLOOR",
		"MATH_FMOD",
		"MATH_FREXP",
		"MATH_LDEXP",
		"MATH_LOG10",
		"MATH_LOG",
		"MATH_MAX",
		"MATH_MIN",
		"MATH_MODF",
		"MATH_POW",
		"MATH_RAD",
		"MATH_SINH",
		"MATH_SIN",
		"MATH_SQRT",
		"MATH_TANH",
		"MATH_TAN",
		-- 28..39: bit32
		"BIT32_ARSHIFT",
		"BIT32_BAND",
		"BIT32_BNOT",
		"BIT32_BOR",
		"BIT32_BXOR",
		"BIT32_BTEST",
		"BIT32_EXTRACT",
		"BIT32_LROTATE",
		"BIT32_LSHIFT",
		"BIT32_REPLACE",
		"BIT32_RROTATE",
		"BIT32_RSHIFT",
		-- 40..48
		"TYPE",
		"STRING_BYTE",
		"STRING_CHAR",
		"STRING_LEN",
		"TYPEOF",
		"STRING_SUB",
		"MATH_CLAMP",
		"MATH_SIGN",
		"MATH_ROUND",
		-- 49..64
		"RAWSET",
		"RAWGET",
		"RAWEQUAL",
		"TABLE_INSERT",
		"TABLE_UNPACK",
		"VECTOR",
		"BIT32_COUNTLZ",
		"BIT32_COUNTRZ",
		"SELECT_VARARG",
		"RAWLEN",
		"BIT32_EXTRACTK",
		"GETMETATABLE",
		"SETMETATABLE",
		"TONUMBER",
		"TOSTRING",
		"BIT32_BYTESWAP",
		-- 65..77: buffer
		"BUFFER_READI8",
		"BUFFER_READU8",
		"BUFFER_WRITEU8",
		"BUFFER_READI16",
		"BUFFER_READU16",
		"BUFFER_WRITEU16",
		"BUFFER_READI32",
		"BUFFER_READU32",
		"BUFFER_WRITEU32",
		"BUFFER_READF32",
		"BUFFER_WRITEF32",
		"BUFFER_READF64",
		"BUFFER_WRITEF64",
		-- 78..88: vector
		"VECTOR_MAGNITUDE",
		"VECTOR_NORMALIZE",
		"VECTOR_CROSS",
		"VECTOR_DOT",
		"VECTOR_FLOOR",
		"VECTOR_CEIL",
		"VECTOR_ABS",
		"VECTOR_SIGN",
		"VECTOR_CLAMP",
		"VECTOR_MIN",
		"VECTOR_MAX",
		-- 89..93
		"MATH_LERP",
		"VECTOR_LERP",
		"MATH_ISNAN",
		"MATH_ISINF",
		"MATH_ISFINITE",
		-- 94..130: integer
		"INTEGER_CREATE",
		"INTEGER_TONUMBER",
		"INTEGER_NEG",
		"INTEGER_ADD",
		"INTEGER_SUB",
		"INTEGER_MUL",
		"INTEGER_DIV",
		"INTEGER_MIN",
		"INTEGER_MAX",
		"INTEGER_REM",
		"INTEGER_IDIV",
		"INTEGER_UDIV",
		"INTEGER_UREM",
		"INTEGER_MOD",
		"INTEGER_CLAMP",
		"INTEGER_BAND",
		"INTEGER_BOR",
		"INTEGER_BNOT",
		"INTEGER_BXOR",
		"INTEGER_LT",
		"INTEGER_LE",
		"INTEGER_ULT",
		"INTEGER_ULE",
		"INTEGER_GT",
		"INTEGER_GE",
		"INTEGER_UGT",
		"INTEGER_UGE",
		"INTEGER_LSHIFT",
		"INTEGER_RSHIFT",
		"INTEGER_ARSHIFT",
		"INTEGER_LROTATE",
		"INTEGER_RROTATE",
		"INTEGER_EXTRACT",
		"INTEGER_BTEST",
		"INTEGER_COUNTRZ",
		"INTEGER_COUNTLZ",
		"INTEGER_BSWAP",
		-- 131..132
		"BUFFER_READINTEGER",
		"BUFFER_WRITEINTEGER",
	}

	local ids = {}
	for id, name in ipairs(namesInIdOrder) do
		ids[name] = id
	end
	LBF = ids
end
-- Global functions recognised as builtins, keyed by global name and mapped to
-- their LBF id (kept in alphabetical order).
local globalBuiltinIds = {
	["assert"] = LBF.ASSERT,
	["getmetatable"] = LBF.GETMETATABLE,
	["rawequal"] = LBF.RAWEQUAL,
	["rawget"] = LBF.RAWGET,
	["rawlen"] = LBF.RAWLEN,
	["rawset"] = LBF.RAWSET,
	["select"] = LBF.SELECT_VARARG,
	["setmetatable"] = LBF.SETMETATABLE,
	["tonumber"] = LBF.TONUMBER,
	["tostring"] = LBF.TOSTRING,
	["type"] = LBF.TYPE,
	["typeof"] = LBF.TYPEOF,
	["unpack"] = LBF.TABLE_UNPACK,
}
-- Library members recognised as builtins: memberBuiltinIds[library][member]
-- is the LBF id used for a call to `library.member`.
local memberBuiltinIds = {}
do
	-- Members whose id follows the LBF naming scheme <LIBRARY>_<MEMBER>
	-- (upper-cased), e.g. math.abs -> LBF.MATH_ABS. Listed as an ordered
	-- array so construction order is fixed.
	local conventionalMembers = {
		{
			"math",
			{
				"abs", "acos", "asin", "atan2", "atan", "ceil", "cosh", "cos", "deg", "exp", "floor",
				"fmod", "frexp", "ldexp", "log10", "log", "max", "min", "modf", "pow", "rad", "sinh",
				"sin", "sqrt", "tanh", "tan", "clamp", "sign", "round", "lerp", "isnan", "isinf",
				"isfinite",
			},
		},
		{
			"bit32",
			{
				"arshift", "band", "bnot", "bor", "bxor", "btest", "extract", "lrotate", "lshift",
				"replace", "rrotate", "rshift", "countlz", "countrz", "byteswap",
			},
		},
		{ "string", { "byte", "char", "len", "sub" } },
		{ "table", { "insert", "unpack" } },
		{
			"buffer",
			{
				"readi8", "readu8", "writeu8", "readi16", "readu16", "writeu16", "readi32", "readu32",
				"writeu32", "readf32", "writef32", "readf64", "writef64", "readinteger", "writeinteger",
			},
		},
		{
			"vector",
			{
				"magnitude", "normalize", "cross", "dot", "floor", "ceil", "abs", "sign", "clamp",
				"min", "max", "lerp",
			},
		},
		{
			"integer",
			{
				"create", "tonumber", "neg", "add", "sub", "mul", "div", "min", "max", "rem", "idiv",
				"udiv", "urem", "mod", "clamp", "band", "bor", "bnot", "bxor", "lt", "le", "ult",
				"ule", "gt", "ge", "ugt", "uge", "lshift", "rshift", "arshift", "lrotate", "rrotate",
				"extract", "btest", "countrz", "countlz", "bswap",
			},
		},
	}

	for _, entry in ipairs(conventionalMembers) do
		local library, members = entry[1], entry[2]
		local ids = {}
		for _, member in ipairs(members) do
			ids[member] = LBF[string.upper(library .. "_" .. member)]
		end
		memberBuiltinIds[library] = ids
	end

	-- Members that break the naming scheme: the signed buffer writes share the
	-- unsigned write builtins, and vector.create is the VECTOR builtin.
	local bufferIds = memberBuiltinIds.buffer
	bufferIds.writei8 = LBF.BUFFER_WRITEU8
	bufferIds.writei16 = LBF.BUFFER_WRITEU16
	bufferIds.writei32 = LBF.BUFFER_WRITEU32
	memberBuiltinIds.vector.create = LBF.VECTOR
end
-- Number of values each builtin call produces, keyed by LBF id.
local builtinResultCounts = {}
do
	-- Builtins deliberately given no entry (presumably because their result
	-- count is not fixed).
	local withoutFixedCount = {
		[LBF.ASSERT] = true,
		[LBF.STRING_BYTE] = true,
		[LBF.TABLE_UNPACK] = true,
		[LBF.SELECT_VARARG] = true,
	}

	-- Builtins that do not produce exactly one value.
	local nonSingleCounts = {
		[LBF.TABLE_INSERT] = 0,
		[LBF.BUFFER_WRITEU8] = 0,
		[LBF.BUFFER_WRITEU16] = 0,
		[LBF.BUFFER_WRITEU32] = 0,
		[LBF.BUFFER_WRITEF32] = 0,
		[LBF.BUFFER_WRITEF64] = 0,
		[LBF.BUFFER_WRITEINTEGER] = 0,
		[LBF.MATH_FREXP] = 2,
		[LBF.MATH_MODF] = 2,
	}

	-- Every other LBF id produces exactly one value. NOTE: a newly added LBF
	-- entry therefore defaults to one result; list it above if that is wrong.
	for _, id in pairs(LBF) do
		if not withoutFixedCount[id] then
			builtinResultCounts[id] = nonSingleCounts[id] or 1
		end
	end
end
-- Statically known result type (LBC_TYPE_*) of each builtin call, keyed by
-- LBF id and grouped by type below. Builtins with no entry: assert,
-- rawset/rawget, table.insert/unpack, select, get/setmetatable and the buffer
-- write functions.
-- NOTE(review): bit32.btest yields a boolean at runtime but is grouped with the
-- number-typed builtins, as in the original table; confirm against upstream
-- Luau before changing it.
builtinReturnTypes = {}
do
	local groups = {
		{
			typeId = LBC_TYPE_NUMBER,
			ids = {
				LBF.MATH_ABS, LBF.MATH_ACOS, LBF.MATH_ASIN, LBF.MATH_ATAN2, LBF.MATH_ATAN,
				LBF.MATH_CEIL, LBF.MATH_COSH, LBF.MATH_COS, LBF.MATH_DEG, LBF.MATH_EXP,
				LBF.MATH_FLOOR, LBF.MATH_FMOD, LBF.MATH_FREXP, LBF.MATH_LDEXP, LBF.MATH_LOG10,
				LBF.MATH_LOG, LBF.MATH_MAX, LBF.MATH_MIN, LBF.MATH_MODF, LBF.MATH_POW,
				LBF.MATH_RAD, LBF.MATH_SINH, LBF.MATH_SIN, LBF.MATH_SQRT, LBF.MATH_TANH,
				LBF.MATH_TAN, LBF.MATH_CLAMP, LBF.MATH_SIGN, LBF.MATH_ROUND, LBF.MATH_LERP,
				LBF.BIT32_ARSHIFT, LBF.BIT32_BAND, LBF.BIT32_BNOT, LBF.BIT32_BOR, LBF.BIT32_BXOR,
				LBF.BIT32_BTEST, LBF.BIT32_EXTRACT, LBF.BIT32_LROTATE, LBF.BIT32_LSHIFT,
				LBF.BIT32_REPLACE, LBF.BIT32_RROTATE, LBF.BIT32_RSHIFT, LBF.BIT32_COUNTLZ,
				LBF.BIT32_COUNTRZ, LBF.BIT32_EXTRACTK, LBF.BIT32_BYTESWAP,
				LBF.STRING_BYTE, LBF.STRING_LEN, LBF.RAWLEN, LBF.TONUMBER,
				LBF.BUFFER_READI8, LBF.BUFFER_READU8, LBF.BUFFER_READI16, LBF.BUFFER_READU16,
				LBF.BUFFER_READI32, LBF.BUFFER_READU32, LBF.BUFFER_READF32, LBF.BUFFER_READF64,
				LBF.VECTOR_MAGNITUDE, LBF.VECTOR_DOT, LBF.INTEGER_TONUMBER,
			},
		},
		{
			typeId = LBC_TYPE_STRING,
			ids = { LBF.TYPE, LBF.TYPEOF, LBF.STRING_CHAR, LBF.STRING_SUB, LBF.TOSTRING },
		},
		{
			typeId = LBC_TYPE_BOOLEAN,
			ids = {
				LBF.RAWEQUAL, LBF.MATH_ISNAN, LBF.MATH_ISINF, LBF.MATH_ISFINITE,
				LBF.INTEGER_LT, LBF.INTEGER_LE, LBF.INTEGER_ULT, LBF.INTEGER_ULE,
				LBF.INTEGER_GT, LBF.INTEGER_GE, LBF.INTEGER_UGT, LBF.INTEGER_UGE,
				LBF.INTEGER_BTEST,
			},
		},
		{
			typeId = LBC_TYPE_VECTOR,
			ids = {
				LBF.VECTOR, LBF.VECTOR_NORMALIZE, LBF.VECTOR_CROSS, LBF.VECTOR_FLOOR,
				LBF.VECTOR_CEIL, LBF.VECTOR_ABS, LBF.VECTOR_SIGN, LBF.VECTOR_CLAMP,
				LBF.VECTOR_MIN, LBF.VECTOR_MAX, LBF.VECTOR_LERP,
			},
		},
		{
			typeId = LBC_TYPE_INTEGER,
			ids = {
				LBF.INTEGER_CREATE, LBF.INTEGER_NEG, LBF.INTEGER_ADD, LBF.INTEGER_SUB,
				LBF.INTEGER_MUL, LBF.INTEGER_DIV, LBF.INTEGER_MIN, LBF.INTEGER_MAX,
				LBF.INTEGER_REM, LBF.INTEGER_IDIV, LBF.INTEGER_UDIV, LBF.INTEGER_UREM,
				LBF.INTEGER_MOD, LBF.INTEGER_CLAMP, LBF.INTEGER_BAND, LBF.INTEGER_BOR,
				LBF.INTEGER_BNOT, LBF.INTEGER_BXOR, LBF.INTEGER_LSHIFT, LBF.INTEGER_RSHIFT,
				LBF.INTEGER_ARSHIFT, LBF.INTEGER_LROTATE, LBF.INTEGER_RROTATE,
				LBF.INTEGER_EXTRACT, LBF.INTEGER_COUNTRZ, LBF.INTEGER_COUNTLZ,
				LBF.INTEGER_BSWAP, LBF.BUFFER_READINTEGER,
			},
		},
	}

	for _, group in ipairs(groups) do
		for _, id in ipairs(group.ids) do
			builtinReturnTypes[id] = group.typeId
		end
	end
end
-- Fixed parameter count for each builtin id that accepts a fixed number of arguments.
-- Builtin ids that have a result type in the preceding LBC_TYPE_* table but no entry here
-- (for example BIT32_BXOR, BIT32_EXTRACT, INTEGER_MIN, VECTOR_MIN) presumably take a
-- variable number of arguments.
-- NOTE(review): the consumer of this table is outside this chunk; presumably it mirrors
-- Luau's BuiltinInfo::params and gates fixed-arity builtin call specialization — confirm.
local builtinParamCounts = {
[LBF.MATH_ABS] = 1,
[LBF.MATH_ACOS] = 1,
[LBF.MATH_ASIN] = 1,
[LBF.MATH_ATAN2] = 2,
[LBF.MATH_ATAN] = 1,
[LBF.MATH_CEIL] = 1,
[LBF.MATH_COSH] = 1,
[LBF.MATH_COS] = 1,
[LBF.MATH_DEG] = 1,
[LBF.MATH_EXP] = 1,
[LBF.MATH_FLOOR] = 1,
[LBF.MATH_FMOD] = 2,
[LBF.MATH_FREXP] = 1,
[LBF.MATH_LDEXP] = 2,
[LBF.MATH_LOG10] = 1,
[LBF.MATH_MODF] = 1,
[LBF.MATH_POW] = 2,
[LBF.MATH_RAD] = 1,
[LBF.MATH_SINH] = 1,
[LBF.MATH_SIN] = 1,
[LBF.MATH_SQRT] = 1,
[LBF.MATH_TANH] = 1,
[LBF.MATH_TAN] = 1,
[LBF.BIT32_ARSHIFT] = 2,
[LBF.BIT32_BNOT] = 1,
[LBF.BIT32_LROTATE] = 2,
[LBF.BIT32_LSHIFT] = 2,
[LBF.BIT32_RROTATE] = 2,
[LBF.BIT32_RSHIFT] = 2,
[LBF.TYPE] = 1,
[LBF.STRING_LEN] = 1,
[LBF.TYPEOF] = 1,
[LBF.MATH_CLAMP] = 3,
[LBF.MATH_SIGN] = 1,
[LBF.MATH_ROUND] = 1,
[LBF.RAWSET] = 3,
[LBF.RAWGET] = 2,
[LBF.RAWEQUAL] = 2,
[LBF.BIT32_COUNTLZ] = 1,
[LBF.BIT32_COUNTRZ] = 1,
[LBF.RAWLEN] = 1,
[LBF.BIT32_EXTRACTK] = 3,
[LBF.GETMETATABLE] = 1,
[LBF.SETMETATABLE] = 2,
[LBF.TOSTRING] = 1,
[LBF.BIT32_BYTESWAP] = 1,
[LBF.BUFFER_READI8] = 2,
[LBF.BUFFER_READU8] = 2,
[LBF.BUFFER_WRITEU8] = 3,
[LBF.BUFFER_READI16] = 2,
[LBF.BUFFER_READU16] = 2,
[LBF.BUFFER_WRITEU16] = 3,
[LBF.BUFFER_READI32] = 2,
[LBF.BUFFER_READU32] = 2,
[LBF.BUFFER_WRITEU32] = 3,
[LBF.BUFFER_READF32] = 2,
[LBF.BUFFER_WRITEF32] = 3,
[LBF.BUFFER_READF64] = 2,
[LBF.BUFFER_WRITEF64] = 3,
[LBF.VECTOR_MAGNITUDE] = 1,
[LBF.VECTOR_NORMALIZE] = 1,
[LBF.VECTOR_CROSS] = 2,
[LBF.VECTOR_DOT] = 2,
[LBF.VECTOR_FLOOR] = 1,
[LBF.VECTOR_CEIL] = 1,
[LBF.VECTOR_ABS] = 1,
[LBF.VECTOR_SIGN] = 1,
[LBF.VECTOR_CLAMP] = 3,
[LBF.MATH_LERP] = 3,
[LBF.VECTOR_LERP] = 3,
[LBF.MATH_ISNAN] = 1,
[LBF.MATH_ISINF] = 1,
[LBF.MATH_ISFINITE] = 1,
[LBF.INTEGER_CREATE] = 1,
[LBF.INTEGER_TONUMBER] = 1,
[LBF.INTEGER_NEG] = 1,
[LBF.INTEGER_ADD] = 2,
[LBF.INTEGER_SUB] = 2,
[LBF.INTEGER_MUL] = 2,
[LBF.INTEGER_DIV] = 2,
[LBF.INTEGER_REM] = 2,
[LBF.INTEGER_IDIV] = 2,
[LBF.INTEGER_UDIV] = 2,
[LBF.INTEGER_UREM] = 2,
[LBF.INTEGER_MOD] = 2,
[LBF.INTEGER_CLAMP] = 3,
[LBF.INTEGER_BNOT] = 1,
[LBF.INTEGER_LT] = 2,
[LBF.INTEGER_LE] = 2,
[LBF.INTEGER_ULT] = 2,
[LBF.INTEGER_ULE] = 2,
[LBF.INTEGER_GT] = 2,
[LBF.INTEGER_GE] = 2,
[LBF.INTEGER_UGT] = 2,
[LBF.INTEGER_UGE] = 2,
[LBF.INTEGER_LSHIFT] = 2,
[LBF.INTEGER_RSHIFT] = 2,
[LBF.INTEGER_ARSHIFT] = 2,
[LBF.INTEGER_LROTATE] = 2,
[LBF.INTEGER_RROTATE] = 2,
[LBF.INTEGER_COUNTRZ] = 1,
[LBF.INTEGER_COUNTLZ] = 1,
[LBF.INTEGER_BSWAP] = 1,
[LBF.BUFFER_READINTEGER] = 2,
[LBF.BUFFER_WRITEINTEGER] = 3,
}
-- Set of builtin ids flagged as "none-safe" (value is always true).
-- Every id listed here also has an entry in builtinParamCounts. The ids with a parameter
-- count that are NOT listed here are TYPE, TYPEOF, RAWSET, RAWGET, RAWEQUAL, GETMETATABLE,
-- SETMETATABLE and TOSTRING.
-- NOTE(review): presumably mirrors Luau's BuiltinInfo::Flag_NoneSafe (the builtin treats a
-- missing argument like nil); the consumer is outside this chunk — confirm the semantics.
local builtinNoneSafe = {
[LBF.MATH_ABS] = true,
[LBF.MATH_ACOS] = true,
[LBF.MATH_ASIN] = true,
[LBF.MATH_ATAN2] = true,
[LBF.MATH_ATAN] = true,
[LBF.MATH_CEIL] = true,
[LBF.MATH_COSH] = true,
[LBF.MATH_COS] = true,
[LBF.MATH_DEG] = true,
[LBF.MATH_EXP] = true,
[LBF.MATH_FLOOR] = true,
[LBF.MATH_FMOD] = true,
[LBF.MATH_FREXP] = true,
[LBF.MATH_LDEXP] = true,
[LBF.MATH_LOG10] = true,
[LBF.MATH_MODF] = true,
[LBF.MATH_POW] = true,
[LBF.MATH_RAD] = true,
[LBF.MATH_SINH] = true,
[LBF.MATH_SIN] = true,
[LBF.MATH_SQRT] = true,
[LBF.MATH_TANH] = true,
[LBF.MATH_TAN] = true,
[LBF.BIT32_ARSHIFT] = true,
[LBF.BIT32_BNOT] = true,
[LBF.BIT32_LROTATE] = true,
[LBF.BIT32_LSHIFT] = true,
[LBF.BIT32_RROTATE] = true,
[LBF.BIT32_RSHIFT] = true,
[LBF.STRING_LEN] = true,
[LBF.MATH_CLAMP] = true,
[LBF.MATH_SIGN] = true,
[LBF.MATH_ROUND] = true,
[LBF.BIT32_COUNTLZ] = true,
[LBF.BIT32_COUNTRZ] = true,
[LBF.RAWLEN] = true,
[LBF.BIT32_EXTRACTK] = true,
[LBF.BIT32_BYTESWAP] = true,
[LBF.BUFFER_READI8] = true,
[LBF.BUFFER_READU8] = true,
[LBF.BUFFER_WRITEU8] = true,
[LBF.BUFFER_READI16] = true,
[LBF.BUFFER_READU16] = true,
[LBF.BUFFER_WRITEU16] = true,
[LBF.BUFFER_READI32] = true,
[LBF.BUFFER_READU32] = true,
[LBF.BUFFER_WRITEU32] = true,
[LBF.BUFFER_READF32] = true,
[LBF.BUFFER_WRITEF32] = true,
[LBF.BUFFER_READF64] = true,
[LBF.BUFFER_WRITEF64] = true,
[LBF.BUFFER_READINTEGER] = true,
[LBF.BUFFER_WRITEINTEGER] = true,
[LBF.VECTOR_MAGNITUDE] = true,
[LBF.VECTOR_NORMALIZE] = true,
[LBF.VECTOR_CROSS] = true,
[LBF.VECTOR_DOT] = true,
[LBF.VECTOR_FLOOR] = true,
[LBF.VECTOR_CEIL] = true,
[LBF.VECTOR_ABS] = true,
[LBF.VECTOR_SIGN] = true,
[LBF.VECTOR_CLAMP] = true,
[LBF.VECTOR_LERP] = true,
[LBF.MATH_LERP] = true,
[LBF.MATH_ISNAN] = true,
[LBF.MATH_ISINF] = true,
[LBF.MATH_ISFINITE] = true,
[LBF.INTEGER_CREATE] = true,
[LBF.INTEGER_TONUMBER] = true,
[LBF.INTEGER_NEG] = true,
[LBF.INTEGER_ADD] = true,
[LBF.INTEGER_SUB] = true,
[LBF.INTEGER_MUL] = true,
[LBF.INTEGER_DIV] = true,
[LBF.INTEGER_REM] = true,
[LBF.INTEGER_IDIV] = true,
[LBF.INTEGER_UDIV] = true,
[LBF.INTEGER_UREM] = true,
[LBF.INTEGER_MOD] = true,
[LBF.INTEGER_CLAMP] = true,
[LBF.INTEGER_BNOT] = true,
[LBF.INTEGER_LT] = true,
[LBF.INTEGER_LE] = true,
[LBF.INTEGER_ULT] = true,
[LBF.INTEGER_ULE] = true,
[LBF.INTEGER_GT] = true,
[LBF.INTEGER_GE] = true,
[LBF.INTEGER_UGT] = true,
[LBF.INTEGER_UGE] = true,
[LBF.INTEGER_LSHIFT] = true,
[LBF.INTEGER_RSHIFT] = true,
[LBF.INTEGER_ARSHIFT] = true,
[LBF.INTEGER_LROTATE] = true,
[LBF.INTEGER_RROTATE] = true,
[LBF.INTEGER_COUNTRZ] = true,
[LBF.INTEGER_COUNTLZ] = true,
[LBF.INTEGER_BSWAP] = true,
}
--- Renders an import path such as { "math", "abs" } as the dotted string "math.abs".
local function joinPath(path)
	local separator = "."
	local joined = table.concat(path, separator)
	return joined
end
--- Returns true when the builtin reached through `path` (resolved to builtin id `bfid`)
-- is listed in `options.disabledBuiltins`.
-- Entries may be given as an array of names ({ "math.abs" }) or as a set
-- ({ ["math.abs"] = true }), the same two shapes isMutableGlobalName accepts for
-- mutableGlobals; entries that are not strings are ignored instead of raising.
-- A name matches when it equals the dotted path text, or when it resolves through the
-- global/member builtin id tables to the same (non-nil) builtin id.
-- @param path array of name segments, e.g. { "math", "abs" }
-- @param bfid builtin id the path resolved to
-- @return boolean
function PureCompiler:isBuiltinDisabled(path, bfid)
	local disabled = self.builder.options.disabledBuiltins
	if disabled == nil then
		return false
	end
	local pathText = joinPath(path)
	for key, value in pairs(disabled) do
		-- Array form keeps the name in the value; set form keeps it in the key.
		local name = nil
		if type(value) == "string" then
			name = value
		elseif type(key) == "string" and value then
			name = key
		end
		if name ~= nil then
			-- For single-segment paths pathText is path[1], so this covers bare globals too.
			if name == pathText then
				return true
			end
			local disabledId = nil
			local dot = string.find(name, ".", 1, true)
			if dot then
				local members = memberBuiltinIds[string.sub(name, 1, dot - 1)]
				disabledId = members and members[string.sub(name, dot + 1)] or nil
			else
				disabledId = globalBuiltinIds[name]
			end
			-- Guard against nil == nil: an unknown name must not match a nil bfid.
			if disabledId ~= nil and disabledId == bfid then
				return true
			end
		end
	end
	return false
end
--- Resolves an import path ({ "print" } or { "math", "abs" }) to a builtin function id.
-- The configured vector constructor maps to LBF.VECTOR: as a bare global when only
-- options.vectorCtor is set, or as `vectorLib.vectorCtor` when options.vectorLib is set too.
-- Returns nil when optimizationLevel < 1, the path is not a known builtin, or the builtin
-- is disabled through options.disabledBuiltins.
function PureCompiler:getBuiltinIdFromPath(path)
	local options = self.builder.options
	if (options.optimizationLevel or 1) < 1 then
		return nil
	end
	local segmentCount = #path
	local id = nil
	if segmentCount == 1 then
		id = globalBuiltinIds[path[1]]
		local isVectorGlobal = options.vectorCtor and not options.vectorLib and path[1] == options.vectorCtor
		if not id and isVectorGlobal then
			id = LBF.VECTOR
		end
	elseif segmentCount == 2 then
		local members = memberBuiltinIds[path[1]]
		id = members and members[path[2]] or nil
		local isVectorMember = options.vectorCtor
			and options.vectorLib
			and path[1] == options.vectorLib
			and path[2] == options.vectorCtor
		if not id and isVectorMember then
			id = LBF.VECTOR
		end
	end
	if id and self:isBuiltinDisabled(path, id) then
		return nil
	end
	return id
end
--- Snapshots the local-stack depth and register top so leaveScope can restore them.
function PureCompiler:enterScope(ctx)
	local mark = {}
	mark.localCount = #ctx.localList
	mark.nextReg = ctx.nextReg
	return mark
end
--- Brings `localInfo` into scope under `localInfo.name` and returns it.
-- Fills in any missing depth fields from ctx, assigns a debug ordering number if none
-- was given, records the debug start pc for register-backed locals, and links the
-- shadowed binding (if any) through `previous` so popLocals can restore it.
function PureCompiler:declareLocal(ctx, localInfo)
	local name = localInfo.name
	localInfo.previous = ctx.locals[name]
	-- Depths default to the declaring context's current depths.
	localInfo.functionDepth = localInfo.functionDepth or ctx.functionDepth
	localInfo.loopDepth = localInfo.loopDepth or ctx.loopDepth
	localInfo.debugDepth = localInfo.debugDepth or ctx.debugScopeDepth
	if localInfo.debugOrder == nil then
		local order = ctx.debugLocalOrder
		localInfo.debugOrder = order
		ctx.debugLocalOrder = order + 1
	end
	local needsDebugStart = localInfo.reg ~= nil and localInfo.debugStart == nil
	if needsDebugStart then
		localInfo.debugStart = ctx.builder:label(ctx.func)
	end
	ctx.locals[name] = localInfo
	local list = ctx.localList
	list[#list + 1] = localInfo
	return localInfo
end
--- Removes every local above index `localCount` from scope.
-- Register-backed locals with a recorded start get a debug-local record ending at the
-- current pc (emitted in declaration order); bindings are then unwound newest-first so
-- each name is restored to the binding it shadowed.
function PureCompiler:popLocals(ctx, localCount)
	local builder = ctx.builder
	local func = ctx.func
	local list = ctx.localList
	local endpc = builder:label(func)
	for position = localCount + 1, #list do
		local info = list[position]
		if info.reg ~= nil and info.debugStart ~= nil then
			builder:pushDebugLocal(func, info.name, info.reg, info.debugStart, endpc, info.debugDepth, info.debugOrder)
		end
	end
	local position = #list
	while position > localCount do
		local info = list[position]
		ctx.locals[info.name] = info.previous
		list[position] = nil
		position -= 1
	end
end
--- Undoes everything declared since `mark` (from enterScope): pops the locals, then
-- resets the register top.
function PureCompiler:leaveScope(ctx, mark)
	local savedLocalCount = mark.localCount
	self:popLocals(ctx, savedLocalCount)
	local savedTop = mark.nextReg
	ctx.nextReg = savedTop
end
--- Loads the number or string `value` into register `target`.
-- With optimizationLevel >= 1, integral numbers in [-32768, 32767] other than -0 are
-- encoded inline with LOADN. Anything else goes through the constant table: LOADK when the
-- constant index is at most 32767, otherwise LOADKX followed by an AUX word with the index.
function PureCompiler:emitLoadConstant(ctx, target, value, line)
	local builder = ctx.builder
	local func = ctx.func
	local fitsLoadN = false
	if type(value) == "number" and (builder.options.optimizationLevel or 1) >= 1 then
		-- 1 / -0 is -inf; -0 is excluded so its sign is not lost.
		local isNegativeZero = value == 0 and 1 / value == -math.huge
		fitsLoadN = not isNegativeZero and value % 1 == 0 and value >= -32768 and value <= 32767
	end
	if fitsLoadN then
		builder:emitAD(func, LOP.LOADN, target, value, line)
	else
		local constantIndex
		if type(value) == "string" then
			constantIndex = builder:addConstantString(func, value)
		else
			constantIndex = builder:addConstantNumber(func, value)
		end
		if constantIndex > 32767 then
			builder:emitABC(func, LOP.LOADKX, target, 0, 0, line)
			builder:emitAux(func, constantIndex, line)
		else
			builder:emitAD(func, LOP.LOADK, target, constantIndex, line)
		end
	end
	self:useReg(ctx, target)
end
--- Loads global `name` into register `target`.
-- With optimizationLevel >= 1 and no recorded write to `name` (ctx.globalWrites), emits
-- GETIMPORT with the encoded import in the AUX word; otherwise emits GETGLOBAL with
-- C = stringHash(name) % 256 and the name's string-constant index in the AUX word.
function PureCompiler:emitGetGlobal(ctx, target, name, line)
	local builder = ctx.builder
	local func = ctx.func
	local isWritten = ctx.globalWrites and ctx.globalWrites[name]
	local useImport = (self.builder.options.optimizationLevel or 1) >= 1 and not isWritten
	if useImport then
		local importIndex = builder:addConstantImport(func, { name })
		builder:emitAD(func, LOP.GETIMPORT, target, importIndex, line)
		builder:emitAux(func, func.constants[importIndex + 1].value, line)
	else
		local nameIndex = builder:addConstantString(func, name)
		builder:emitABC(func, LOP.GETGLOBAL, target, 0, stringHash(name) % 256, line)
		builder:emitAux(func, nameIndex, line)
	end
	self:useReg(ctx, target)
end
--- Stores register `source` into global `name` via SETGLOBAL.
-- C is stringHash(name) % 256; the AUX word holds the name's string-constant index.
function PureCompiler:emitSetGlobal(ctx, source, name, line)
	local builder, func = ctx.builder, ctx.func
	local nameIndex = builder:addConstantString(func, name)
	local hashSlot = stringHash(name) % 256
	builder:emitABC(func, LOP.SETGLOBAL, source, 0, hashSlot, line)
	builder:emitAux(func, nameIndex, line)
end
--- Emits CLOSEUPVALS for the lowest register among locals (from `startIndex`, default 1)
-- that are captured, written and register-backed.
-- @return true if an instruction was emitted, false if no such local exists
function PureCompiler:emitCloseUpvals(ctx, line, startIndex)
	local list = ctx.localList
	local lowestReg = nil
	for position = startIndex or 1, #list do
		local info = list[position]
		if info.captured and info.written and info.reg ~= nil then
			if lowestReg == nil or info.reg < lowestReg then
				lowestReg = info.reg
			end
		end
	end
	if lowestReg == nil then
		return false
	end
	ctx.builder:emitABC(ctx.func, LOP.CLOSEUPVALS, lowestReg, 0, 0, line)
	return true
end
--- Returns true if any local from `startIndex` (default 1) is captured, written and
-- register-backed, i.e. emitCloseUpvals would emit an instruction.
function PureCompiler:hasCloseUpvals(ctx, startIndex)
	local list = ctx.localList
	local position = startIndex or 1
	while position <= #list do
		local info = list[position]
		if info.captured and info.written and info.reg ~= nil then
			return true
		end
		position += 1
	end
	return false
end
--- Returns true if any register-backed local from `startIndex` (default 1) is captured
-- by a closure, whether or not it is written.
function PureCompiler:hasCapturedLocals(ctx, startIndex)
	local list = ctx.localList
	local position = startIndex or 1
	while position <= #list do
		local info = list[position]
		if info.captured and info.reg ~= nil then
			return true
		end
		position += 1
	end
	return false
end
--- With debugLevel >= 2, records `name` (once) as a parent constant referenced from this
-- function, so appendDebugConstantUpvalues can later expose it as an upvalue.
-- Names that do not resolve to a parent constant are ignored.
function PureCompiler:rememberDebugConstantUpvalue(ctx, name)
	local debugEnabled = (self.builder.options.debugLevel or 1) >= 2
	if debugEnabled and not ctx.debugConstantUpvalueMap[name] and self:findParentConstant(ctx, name) then
		ctx.debugConstantUpvalueMap[name] = true
		append(ctx.debugConstantUpvalues, name)
	end
end
--- With debugLevel >= 2, appends an upvalue entry (sourceKind "constant") for each
-- remembered parent-constant name that is not already an upvalue, keeping
-- ctx.func.numupvalues in sync. Upvalue indices in ctx.upvalueMap are 0-based.
function PureCompiler:appendDebugConstantUpvalues(ctx)
	if (self.builder.options.debugLevel or 1) < 2 then
		return
	end
	for _, name in ipairs(ctx.debugConstantUpvalues) do
		local parentConstant = nil
		if ctx.upvalueMap[name] == nil then
			parentConstant = self:findParentConstant(ctx, name)
		end
		if parentConstant then
			ctx.upvalueMap[name] = #ctx.upvalues
			local entry = {
				name = name,
				sourceKind = "constant",
				source = parentConstant,
			}
			ctx.upvalues[#ctx.upvalues + 1] = entry
			ctx.func.numupvalues = #ctx.upvalues
		end
	end
end
--- Emits a COVERAGE instruction (all operands zero) attributed to `line`.
function PureCompiler:emitCoverage(ctx, line)
	local builder, func = ctx.builder, ctx.func
	builder:emitABC(func, LOP.COVERAGE, 0, 0, 0, line)
end
--- With coverageLevel >= 2, emits a COVERAGE instruction for a Name expression that is
-- folded to a constant (a constant-usable local or a parent constant). Such expressions
-- may otherwise emit no instruction of their own at that source line.
function PureCompiler:emitSourceCoverageIfNeeded(ctx, expr, line)
	local coverageLevel = ctx.builder.options.coverageLevel or 0
	if coverageLevel >= 2 and expr.kind == "Name" then
		local name = expr.name
		local localInfo = self:findLocal(ctx, name)
		local parentConstant = nil
		if localInfo == nil then
			parentConstant = self:findParentConstant(ctx, name)
		end
		if self:canUseLocalConstant(localInfo) or parentConstant then
			self:emitCoverage(ctx, expr.line or line)
		end
	end
end
--- Builds the type-info blob for a function prototype.
-- Returns "" unless optimizationLevel >= 2 and at least one parameter has a type other
-- than LBC_TYPE_ANY. Otherwise the blob is: one byte for the function type tag, one byte
-- with the parameter count, then one byte per parameter (LBC_TYPE_ANY when untyped).
-- NOTE(review): the tag is read from `_LBC_TYPE_FUNCTION` (leading underscore), unlike the
-- other LBC_TYPE_* names used here — confirm that this name is defined in the file.
function PureCompiler:buildFunctionTypeInfo(expr)
	local paramTypes = expr.paramTypes
	if (self.builder.options.optimizationLevel or 1) < 2 or not paramTypes then
		return ""
	end
	local paramCount = #expr.params
	local sawTypedParam = false
	for position = 1, paramCount do
		local typeId = paramTypes[position]
		if typeId ~= nil and typeId ~= LBC_TYPE_ANY then
			sawTypedParam = true
			break
		end
	end
	if not sawTypedParam then
		return ""
	end
	local bytes = { string.char(_LBC_TYPE_FUNCTION), string.char(paramCount) }
	for position = 1, paramCount do
		bytes[#bytes + 1] = string.char(paramTypes[position] or LBC_TYPE_ANY)
	end
	return table.concat(bytes)
end
--- Compiles function expression `expr` into a new child prototype, at most once.
-- The prototype id and its ordered upvalue names are cached on `expr`
-- (compiledFunctionId / compiledUpvalueNames); later calls return the cached pair and
-- emit nothing.
-- @param ctx enclosing compile context (parent of the child context)
-- @param expr function expression node (params, isvararg, body, paramTypes, native, line, endLine)
-- @param debugName optional name stored as the prototype's debugname when debugLevel >= 1
-- @return childId, upvalueNames
function PureCompiler:ensureFunctionExprCompiled(ctx, expr, debugName)
if expr.compiledFunctionId ~= nil then
return expr.compiledFunctionId, expr.compiledUpvalueNames
end
-- Track the function being compiled; popped again just before returning.
append(self.compilingFunctionStack, expr)
local func = self.builder:createFunction(#expr.params, expr.isvararg)
func.typeInfo = self:buildFunctionTypeInfo(expr)
func.debuglinedefined = expr.line or 1
if expr.native then
func.flags = (func.flags or 0) + LPF_NATIVE_FUNCTION
self.hasNativeFunction = true
end
local childCtx = self:newContext(func, ctx)
childCtx.writeNames = {}
childCtx.readNames = self:collectReadNames(expr.body)
childCtx.globalWrites = ctx.globalWrites
childCtx.tableArrayHints = (self.builder.options.optimizationLevel or 1) >= 1 and self:collectTableArrayHints(expr.body) or {}
-- Parameters occupy registers 0 .. #params - 1; temporaries start right after them.
childCtx.nextReg = #expr.params
childCtx.maxReg = #expr.params
func.maxstacksize = #expr.params
if expr.isvararg then
self.builder:emitABC(func, LOP.PREPVARARGS, #expr.params, 0, 0, expr.line)
end
for index, name in ipairs(expr.params) do
self:declareLocal(childCtx, {
name = name,
reg = index - 1,
written = expr.paramSymbols and expr.paramSymbols[index] and expr.paramSymbols[index].written == true or false,
typeId = expr.paramTypes and expr.paramTypes[index] or nil,
})
end
local terminated = self:compileBlock(childCtx, expr.body, false, (self.builder.options.optimizationLevel or 1) >= 1)
-- A body that can fall off the end gets an implicit empty RETURN, closing any
-- captured-and-written locals first.
if not terminated then
local returnLine = expr.endLine or expr.line
self:emitCloseUpvals(childCtx, returnLine)
self.builder:emitABC(func, LOP.RETURN, 0, 1, 0, returnLine)
end
if debugName and (self.builder.options.debugLevel or 1) >= 1 then
func.debugname = self.builder:addString(debugName)
end
-- With debugLevel >= 2 this may add upvalue entries for remembered parent constants.
self:appendDebugConstantUpvalues(childCtx)
local upvalueNames = {}
for _, upvalue in ipairs(childCtx.upvalues) do
append(upvalueNames, upvalue.name)
if (self.builder.options.debugLevel or 1) >= 2 then
self.builder:pushDebugUpval(func, upvalue.name)
end
end
-- Close the parameter scope so their debug-local ranges are recorded.
self:popLocals(childCtx, 0)
local childId = self.builder:addFunction(func)
expr.compiledFunctionId = childId
expr.compiledUpvalueNames = upvalueNames
self.compilingFunctionStack[#self.compilingFunctionStack] = nil
return childId, upvalueNames
end
--- Emits code that creates the closure for function expression `expr` in register `target`.
-- The child prototype is compiled once and cached on `expr` by ensureFunctionExprCompiled;
-- this function previously carried a verbatim copy of that logic.
-- Each upvalue is captured as one of:
--   * a constant-usable local or parent constant, loaded into a scratch register and
--     captured with kind 0;
--   * a register-backed local: kind 1 when the local is written, otherwise kind 0;
--   * an upvalue of the enclosing function: kind 2.
-- The closure is emitted with DUPCLOSURE when optimizationLevel >= 1, setfenv is unused
-- and every capture passes the shareability checks; otherwise with NEWCLOSURE.
-- @param selfLocal optional local that this expression initializes; its initFunction and
--   shareableClosure fields are updated (presumably for `local function f` — confirm).
function PureCompiler:compileFunctionExpr(ctx, expr, target, debugName, selfLocal)
	local childId, upvalueNames = self:ensureFunctionExprCompiled(ctx, expr, debugName)
	local protoIndex = self.builder:addChildFunction(ctx.func, childId)
	local captures = {}
	local shareable = true
	local nextCaptureTemp = nil
	for _, name in ipairs(upvalueNames or {}) do
		local capturedLocal = self:findLocal(ctx, name)
		local captureType = nil
		local captureData = nil
		if self:canUseLocalConstant(capturedLocal) then
			-- No live register for a constant local: materialize it into a scratch register.
			local temp = nextCaptureTemp or self:getTempReg(ctx, target)
			nextCaptureTemp = temp + 1
			self:emitLocalConstant(ctx, capturedLocal, temp, expr.line or 1)
			captureType = 0
			captureData = temp
		elseif capturedLocal and capturedLocal.reg ~= nil then
			captureType = capturedLocal.written and 1 or 0
			captureData = capturedLocal.reg
		else
			local parentConstant = self:findParentConstant(ctx, name)
			if parentConstant then
				capturedLocal = parentConstant
				local temp = nextCaptureTemp or self:getTempReg(ctx, target)
				nextCaptureTemp = temp + 1
				self:emitLocalConstant(ctx, parentConstant, temp, expr.line or 1)
				captureType = 0
				captureData = temp
			else
				local upvalueIndex = self:getUpvalue(ctx, name)
				if upvalueIndex ~= nil then
					local upvalue = ctx.upvalues[upvalueIndex + 1]
					capturedLocal = self:resolveCapturedLocal(ctx, upvalue)
					captureType = 2
					captureData = upvalueIndex
				end
			end
		end
		if captureType == nil then
			shareable = false
		end
		append(captures, {
			kind = captureType,
			data = captureData,
		})
		-- Sharing requires a known, never-written capture; captures declared inside another
		-- function or loop are only acceptable when they are (or hold) shareable closures.
		local isSelfCapture = selfLocal ~= nil and capturedLocal == selfLocal
		local initFunction = capturedLocal and (capturedLocal.initFunction or capturedLocal.inlineFunction) or nil
		local depthBlocked = not isSelfCapture
			and ((capturedLocal and capturedLocal.functionDepth or 0) ~= 0 or (capturedLocal and capturedLocal.loopDepth or 0) ~= 0)
		if
			not capturedLocal
			or capturedLocal.written
			or (depthBlocked and (not initFunction or (initFunction ~= expr and not capturedLocal.shareableClosure)))
		then
			shareable = false
		end
	end
	if selfLocal then
		selfLocal.initFunction = expr
		selfLocal.shareableClosure = shareable
	end
	if (self.builder.options.optimizationLevel or 1) >= 1 and shareable and not self.setfenvUsed then
		local closureConstant = self.builder:addConstantClosure(ctx.func, childId)
		self.builder:emitAD(ctx.func, LOP.DUPCLOSURE, target, closureConstant, expr.line)
	else
		self.builder:emitAD(ctx.func, LOP.NEWCLOSURE, target, protoIndex, expr.line)
	end
	for _, capture in ipairs(captures) do
		if capture.kind ~= nil then
			self.builder:emitABC(ctx.func, LOP.CAPTURE, capture.kind, capture.data, 0, expr.line)
		end
	end
	self:useReg(ctx, target)
end
--- Returns true if global `name` is treated as mutable: always for "_G", otherwise when
-- options.mutableGlobals lists it, as a set key ({ foo = true }) or an array value ({ "foo" }).
function PureCompiler:isMutableGlobalName(name)
	if name == "_G" then
		return true
	end
	local configured = self.builder.options.mutableGlobals
	if type(configured) ~= "table" then
		return false
	end
	if configured[name] then
		return true
	end
	for _, entry in pairs(configured) do
		if entry == name then
			return true
		end
	end
	return false
end
--- Returns the GETIMPORT-able name chain for `expr` (e.g. { "math", "floor" }), or nil.
-- Only a global Name, or a Field chain rooted in one, qualifies. The name must not be a
-- local, parent constant or parent local, and must not be written as a global. Chains
-- are at most three segments and may not extend a mutable global. Requires
-- optimizationLevel >= 1. `seen` guards against revisiting a node.
function PureCompiler:getImportPath(ctx, expr, seen)
	if (self.builder.options.optimizationLevel or 1) < 1 then
		return nil
	end
	local visited = seen or {}
	if visited[expr] then
		return nil
	end
	visited[expr] = true
	local kind = expr.kind
	if kind == "Name" then
		local name = expr.name
		local shadowed = self:findLocal(ctx, name) or self:findParentConstant(ctx, name) or self:hasParentLocal(ctx, name)
		if shadowed then
			return nil
		end
		local reassigned = ctx.globalWrites and ctx.globalWrites[name]
		if reassigned then
			return nil
		end
		return { name }
	end
	if kind == "Field" then
		local prefix = self:getImportPath(ctx, expr.object, visited)
		if prefix and #prefix < 3 and not self:isMutableGlobalName(prefix[1]) then
			prefix[#prefix + 1] = expr.field
			return prefix
		end
	end
	return nil
end
--- Emits GETIMPORT loading `path` into `target`; the AUX word is the encoded import value
-- stored on the import constant.
function PureCompiler:compileImportPath(ctx, path, target, line)
	local builder, func = ctx.builder, ctx.func
	local importIndex = builder:addConstantImport(func, path)
	builder:emitAD(func, LOP.GETIMPORT, target, importIndex, line)
	local encodedImport = func.constants[importIndex + 1].value
	builder:emitAux(func, encodedImport, line)
	self:useReg(ctx, target)
end
--- Returns a register holding the value of `expr`, emitting code into `target` only when
-- needed.
-- A register-backed local is returned directly (always with debugLevel >= 2, so the real
-- variable stays visible), unless it can be folded to a constant. Parent constants and
-- upvalues are loaded into `target`; any other expression is compiled into `target`.
function PureCompiler:compileExprAsSource(ctx, expr, target)
	if expr.kind ~= "Name" then
		self:compileExpr(ctx, expr, target)
		return target
	end
	local name = expr.name
	local localInfo = self:findLocal(ctx, name)
	if localInfo then
		local keepRegister = localInfo.reg ~= nil and (self.builder.options.debugLevel or 1) >= 2
		if not keepRegister and self:canUseLocalConstant(localInfo) then
			self:emitLocalConstant(ctx, localInfo, target, expr.line or 1)
			return target
		end
		return localInfo.reg
	end
	local parentConstant = self:findParentConstant(ctx, name)
	if parentConstant then
		if (self.builder.options.debugLevel or 1) >= 2 then
			self:rememberDebugConstantUpvalue(ctx, name)
		end
		self:emitLocalConstant(ctx, parentConstant, target, expr.line or 1)
		return target
	end
	local upvalueIndex = self:getUpvalue(ctx, name)
	if upvalueIndex ~= nil then
		ctx.builder:emitABC(ctx.func, LOP.GETUPVAL, target, upvalueIndex, 0, expr.line or 1)
		self:useReg(ctx, target)
		return target
	end
	self:compileExpr(ctx, expr, target)
	return target
end
--- Like compileExprAsSource, but calls are compiled in place at or above the register top
-- (no trailing MOVE); the register returned by compileCall is passed through.
-- With optimizationLevel >= 1, calls that fold to a constant are loaded into `target`.
function PureCompiler:compileExprAsSourceNoMove(ctx, expr, target)
	local isCall = expr.kind == "Call" or expr.kind == "MethodCall"
	if not isCall then
		return self:compileExprAsSource(ctx, expr, target)
	end
	local folded = nil
	if (ctx.builder.options.optimizationLevel or 1) >= 1 then
		folded = self:getConstant(ctx, expr)
	end
	if folded then
		self:captureParentConstantsForDebug(ctx, expr)
		self:emitConstant(ctx, folded, target, expr.line or 1)
		return target
	end
	local callBase = math.max(target, ctx.nextReg)
	return self:compileCall(ctx, expr, callBase, 1, nil, true)
end
--- Returns the operand used to compare against folded `constant`.
-- Nil -> 0, Bool -> 1/0 (inline), Number/Integer/Vector/String -> constant-table index.
-- Raises "invalid comparison constant" (blamed on the caller) for any other kind.
function PureCompiler:addConstantValue(ctx, constant)
	local kind = constant.kind
	if kind == "Nil" then
		return 0
	end
	if kind == "Bool" then
		if constant.value then
			return 1
		end
		return 0
	end
	local builder, func = ctx.builder, ctx.func
	if kind == "Number" then
		return builder:addConstantNumber(func, constant.value)
	end
	if kind == "Integer" then
		return builder:addConstantInteger(func, constant.value)
	end
	if kind == "Vector" then
		return builder:addConstantVector(func, constant.value)
	end
	if kind == "String" then
		return builder:addConstantString(func, constant.value)
	end
	error("invalid comparison constant", 2)
end
--- Adds folded `constant` to the function's constant table and returns its index, or nil
-- when the constant kind has no constant-table representation here.
function PureCompiler:addConstantIndex(ctx, constant)
	local builder, func = ctx.builder, ctx.func
	local kind = constant.kind
	if kind == "Nil" then
		return builder:addConstantNil(func)
	end
	local adderName = ({
		Bool = "addConstantBoolean",
		Number = "addConstantNumber",
		Integer = "addConstantInteger",
		Vector = "addConstantVector",
		String = "addConstantString",
	})[kind]
	if adderName == nil then
		return nil
	end
	return builder[adderName](builder, func, constant.value)
end
--- Folds `expr` and returns its constant-table index, or nil if it does not fold.
function PureCompiler:getConstantIndex(ctx, expr)
	local folded = self:getConstant(ctx, expr)
	if folded then
		return self:addConstantIndex(ctx, folded)
	end
	return nil
end
--- Loads constant-table entry `cid` into `target`: LOADK for indices up to 32767,
-- otherwise LOADKX followed by an AUX word carrying the index.
function PureCompiler:emitLoadKIndex(ctx, target, cid, line)
	local builder, func = ctx.builder, ctx.func
	if cid > 32767 then
		builder:emitABC(func, LOP.LOADKX, target, 0, 0, line)
		builder:emitAux(func, cid, line)
	else
		builder:emitAD(func, LOP.LOADK, target, cid, line)
	end
	self:useReg(ctx, target)
end
--- Number of instruction words emitLoadKIndex emits for constant index `cid`
-- (1 for LOADK, 2 for LOADKX + AUX).
function PureCompiler:getLoadKInstructionCount(cid)
	if cid > 32767 then
		return 2
	end
	return 1
end
--- Returns the register of `expr` when it is a Name bound to a register-backed local,
-- otherwise nil.
function PureCompiler:getExprLocalReg(ctx, expr)
	if expr.kind == "Name" then
		local localInfo = self:findLocal(ctx, expr.name)
		return localInfo and localInfo.reg or nil
	end
	return nil
end
--- Returns true if `expr` reads the register-backed local stored in register `reg`.
-- Walks Field, Index, Call, MethodCall, Table, Un, Bin, SingleResult, InterpString and
-- IfExpr nodes; any other node kind is treated as not reading the register.
function PureCompiler:exprUsesLocalReg(ctx, expr, reg)
	local kind = expr.kind
	if kind == "Name" then
		return self:getExprLocalReg(ctx, expr) == reg
	end
	local function reads(node)
		return self:exprUsesLocalReg(ctx, node, reg)
	end
	local function anyReads(nodes)
		for _, node in ipairs(nodes) do
			if reads(node) then
				return true
			end
		end
		return false
	end
	if kind == "Field" then
		return reads(expr.object)
	elseif kind == "Index" then
		return reads(expr.object) or reads(expr.index)
	elseif kind == "Call" then
		return reads(expr.callee) or anyReads(expr.args)
	elseif kind == "MethodCall" then
		return reads(expr.object) or anyReads(expr.args)
	elseif kind == "Table" then
		for _, entry in ipairs(expr.entries) do
			if (entry.key and reads(entry.key)) or (entry.value and reads(entry.value)) then
				return true
			end
		end
		return false
	elseif kind == "Un" or kind == "SingleResult" then
		return reads(expr.expr)
	elseif kind == "Bin" then
		return reads(expr.left) or reads(expr.right)
	elseif kind == "InterpString" then
		return anyReads(expr.expressions)
	elseif kind == "IfExpr" then
		for _, clause in ipairs(expr.clauses) do
			if reads(clause.condition) or reads(clause.value) then
				return true
			end
		end
		return reads(expr.elseValue)
	end
	return false
end
--- Returns false only when `expr` is a Name bound to a register-backed local (usable as a
-- source register directly); every other expression is assumed to emit code.
function PureCompiler:exprAsSourceEmits(ctx, expr)
	if expr.kind ~= "Name" then
		return true
	end
	local localInfo = self:findLocal(ctx, expr.name)
	return not (localInfo and localInfo.reg ~= nil)
end
--- Compiles `a, b, ... = v1, v2, ...` when every target is an existing local, evaluating
-- values directly into the locals' registers where that is safe.
-- Returns false without emitting code if there are no targets or values, or if a target
-- is not a local. Locals examined before the failing target may already have had
-- constPath/constKind cleared and a register reserved.
-- When a value reads a local that an earlier target writes, that target's value is first
-- computed into a scratch "conflict" register and MOVEd into place at the end, so every
-- value observes the pre-assignment state. Surplus values are compiled for side effects
-- only; when values are short, the last value fills the remaining targets.
-- @return true when the statement was compiled here
function PureCompiler:tryCompileDirectLocalAssign(ctx, stat, line)
	local targets = stat.targets
	local values = stat.values
	if #targets == 0 or #values == 0 then
		return false
	end
	local assignments = {}
	for index, target in ipairs(targets) do
		if target.kind ~= "Name" then
			return false
		end
		local localInfo = self:findLocal(ctx, target.name)
		if not localInfo then
			return false
		end
		-- The local is reassigned, so any folded-constant view of it is stale.
		localInfo.constPath = nil
		localInfo.constKind = nil
		if localInfo.reg == nil then
			localInfo.reg = self:reserve(ctx, 1)
		end
		assignments[index] = {
			localInfo = localInfo,
			targetReg = localInfo.reg,
			valueReg = localInfo.reg,
			typeId = self:getInitializerBytecodeType(ctx, values, index, #targets),
		}
	end

	-- Snapshot the register top only after the reservations above; an earlier snapshot
	-- could hand out a register just reserved for a target local as a conflict temporary.
	local nextTemp = ctx.nextReg
	local assignedRegs = {}

	-- If `value` reads a register that an earlier target writes, redirect that target's
	-- value into its own conflict register so `value` still sees the old contents.
	local function redirectConflicts(value)
		for assignedReg in pairs(assignedRegs) do
			if self:exprUsesLocalReg(ctx, value, assignedReg) then
				for _, prior in ipairs(assignments) do
					if prior.targetReg == assignedReg then
						if prior.conflictReg == nil then
							prior.conflictReg = nextTemp
							nextTemp += 1
							self:useReg(ctx, prior.conflictReg)
						end
						prior.valueReg = prior.conflictReg
						break
					end
				end
			end
		end
	end

	for index, assignment in ipairs(assignments) do
		local value = values[index]
		if value then
			redirectConflicts(value)
		end
		assignedRegs[assignment.targetReg] = true
	end
	-- Surplus values are evaluated after all main values, so they see every target.
	for index = #targets + 1, #values do
		redirectConflicts(values[index])
	end

	-- Conflict registers hold pending values until the final MOVEs. Raise the register top
	-- over them so scratch registers used while compiling later values, and the
	-- multi-result block below, cannot overwrite them.
	local savedNextReg = ctx.nextReg
	local raisedTop = nextTemp > savedNextReg
	if raisedTop then
		ctx.nextReg = nextTemp
	end

	for index, assignment in ipairs(assignments) do
		local value = values[index]
		if value == nil then
			break
		end
		if index == #values and #targets > #values then
			-- The last value expands to cover this and all remaining targets.
			local rest = #targets - #values + 1
			local temp = ctx.nextReg
			self:compileExprList(ctx, { value }, temp, rest)
			for targetIndex = index, #targets do
				assignments[targetIndex].valueReg = temp + targetIndex - index
			end
			break
		end
		self:compileExpr(ctx, value, assignment.valueReg, false)
	end

	if #values > #targets then
		local sideTarget = nextTemp
		local oldProtectedTop = ctx.protectedTop or 0
		ctx.protectedTop = math.max(oldProtectedTop, sideTarget)
		for index = #targets + 1, #values do
			self:compileExprSide(ctx, values[index], sideTarget)
		end
		ctx.protectedTop = oldProtectedTop
	end

	local moveLine = ctx.func.lines[#ctx.func.lines] or line
	for _, assignment in ipairs(assignments) do
		if assignment.valueReg ~= assignment.targetReg then
			ctx.builder:emitABC(ctx.func, LOP.MOVE, assignment.targetReg, assignment.valueReg, 0, moveLine)
		end
		self:updateExistingLocalBytecodeType(assignment.localInfo, assignment.typeId)
	end

	if raisedTop then
		ctx.nextReg = savedNextReg
	end
	return true
end
--- Compiles `n1, n2, ... = v1, v2, ...` where every target is a Name, #values == #targets,
-- and the targets mix locals with non-locals (stored through compileAssignTarget).
-- Returns false without emitting code when those conditions do not hold. Local targets
-- examined before bailing out may already have had constPath/constKind cleared and a
-- register reserved.
-- Strategy:
--   * local values are compiled straight into the local's register, unless some value
--     reads that register (a conflict), in which case they go into a temporary;
--   * a non-local whose value is itself a register-backed local uses that register
--     directly; other non-local values are compiled into temporaries;
--   * non-local stores are emitted first, then MOVEs for redirected locals.
-- @param baseReg lowest register the temporaries may use (they start at max(baseReg, ctx.nextReg))
-- @return true when the statement was compiled here
function PureCompiler:tryCompileMixedNameAssign(ctx, stat, line, baseReg)
if #stat.targets == 0 or #stat.values ~= #stat.targets then
return false
end
local assignments = {}
local hasLocal = false
local hasNonLocal = false
for index, target in ipairs(stat.targets) do
if target.kind ~= "Name" then
return false
end
local localInfo = self:findLocal(ctx, target.name)
if localInfo then
-- The local is reassigned, so any folded-constant view of it is stale.
localInfo.constPath = nil
localInfo.constKind = nil
if localInfo.reg == nil then
localInfo.reg = self:reserve(ctx, 1)
end
assignments[index] = {
target = target,
localInfo = localInfo,
targetReg = localInfo.reg,
valueReg = localInfo.reg,
isLocal = true,
typeId = self:getInitializerBytecodeType(ctx, stat.values, index, #stat.targets),
}
hasLocal = true
else
assignments[index] = {
target = target,
isLocal = false,
}
hasNonLocal = true
end
end
if not hasLocal or not hasNonLocal then
return false
end
-- Conflict detection: a local value reading an EARLIER local target's register, or any
-- non-local value reading ANY local target's register, forces that target through a temp.
local assignedRegs = {}
local conflictRegs = {}
for index, assignment in ipairs(assignments) do
if assignment.isLocal then
local value = stat.values[index]
for assignedReg in pairs(assignedRegs) do
if self:exprUsesLocalReg(ctx, value, assignedReg) then
conflictRegs[assignedReg] = true
end
end
assignedRegs[assignment.targetReg] = true
end
end
for index, assignment in ipairs(assignments) do
if not assignment.isLocal then
local value = stat.values[index]
for assignedReg in pairs(assignedRegs) do
if self:exprUsesLocalReg(ctx, value, assignedReg) then
conflictRegs[assignedReg] = true
end
end
end
end
-- Hand out temporaries: first for conflicting locals, then for non-locals that need one.
local tempReg = math.max(baseReg, ctx.nextReg)
for _, assignment in ipairs(assignments) do
if assignment.isLocal and conflictRegs[assignment.targetReg] then
assignment.valueReg = tempReg
tempReg += 1
end
end
for index, assignment in ipairs(assignments) do
if not assignment.isLocal then
local valueReg = self:getExprLocalReg(ctx, stat.values[index])
if valueReg ~= nil then
assignment.valueReg = valueReg
else
assignment.valueReg = tempReg
tempReg += 1
end
end
end
local beforeExpr = ctx.builder:label(ctx.func)
-- liveTop tracks the highest temporary already filled; protectedTop is raised to it while
-- compiling each value so later values do not clobber earlier results.
local liveTop = math.max(baseReg, ctx.nextReg)
for index, assignment in ipairs(assignments) do
local oldProtectedTop = ctx.protectedTop or 0
ctx.protectedTop = math.max(oldProtectedTop, liveTop)
if assignment.isLocal then
self:compileExpr(ctx, stat.values[index], assignment.valueReg, false)
elseif self:getExprLocalReg(ctx, stat.values[index]) == nil then
self:compileExpr(ctx, stat.values[index], assignment.valueReg)
end
ctx.protectedTop = oldProtectedTop
if assignment.valueReg and assignment.valueReg >= liveTop then
liveTop = assignment.valueReg + 1
end
end
-- Use the statement line when no value emitted code; otherwise the last emitted line.
local assignLine = ctx.builder:label(ctx.func) == beforeExpr and line or ctx.func.lines[#ctx.func.lines] or line
for _, assignment in ipairs(assignments) do
if not assignment.isLocal then
self:compileAssignTarget(ctx, assignment.target, assignment.valueReg, assignment.target.line or assignLine)
end
end
for _, assignment in ipairs(assignments) do
if assignment.isLocal and assignment.valueReg ~= assignment.targetReg then
ctx.builder:emitABC(ctx.func, LOP.MOVE, assignment.targetReg, assignment.valueReg, 0, assignLine)
end
if assignment.isLocal then
self:updateExistingLocalBytecodeType(assignment.localInfo, assignment.typeId)
end
end
return true
end
--- Compiles a multiple assignment `n1, n2, ... = v1, v2, ...` in which every
-- target is a plain name that does not resolve to a local (so each target is
-- an upvalue or a global). All values are evaluated into registers first and
-- only then stored, so no value expression observes a store from this statement.
-- @param ctx compilation context (builder, func, register cursors)
-- @param stat "Assign" statement node with `targets` and `values` lists
-- @param line fallback debug line used for the stores
-- @param baseReg first register this statement may use for temporaries
-- @return true when compiled here; false (before any instruction is emitted by
--   this function) when the statement does not fit this shape
function PureCompiler:tryCompileNonLocalNameAssign(ctx, stat, line, baseReg)
	-- Only multi-target assignments that have at least one value qualify.
	if #stat.targets <= 1 or #stat.values == 0 then
		return false
	end
	local assignments = {}
	local nextTemp = baseReg
	for index, target in ipairs(stat.targets) do
		-- Any non-name target, or a name bound to a local, disqualifies this path.
		if target.kind ~= "Name" or self:findLocal(ctx, target.name) ~= nil then
			return false
		end
		local lvalue
		lvalue, nextTemp = self:compileLValue(ctx, target, nextTemp)
		-- Defensive: findLocal already returned nil above, so this should not trigger.
		if lvalue.kind == "Local" then
			return false
		end
		assignments[index] = {
			target = target,
			lvalue = lvalue,
			typeId = self:getInitializerBytecodeType(ctx, stat.values, index, #stat.targets),
		}
	end
	-- Temporarily raise the register cursor above anything compileLValue used;
	-- it is restored at the end of the statement.
	local oldNextReg = ctx.nextReg
	ctx.nextReg = math.max(ctx.nextReg, nextTemp)
	local beforeExpr = ctx.builder:label(ctx.func)
	local limit = math.min(#stat.targets, #stat.values)
	for index = 1, limit do
		local value = stat.values[index]
		-- Protect registers already holding earlier values while this one is compiled.
		local oldProtectedTop = ctx.protectedTop or 0
		ctx.protectedTop = math.max(oldProtectedTop, ctx.nextReg)
		if index == #stat.values and #stat.targets > #stat.values then
			-- Last value with more targets than values: compile it as a list that
			-- fills `rest` consecutive registers, one per remaining target.
			local rest = #stat.targets - #stat.values + 1
			local temp = ctx.nextReg
			ctx.nextReg = temp + rest
			self:useReg(ctx, temp + rest - 1)
			self:compileExprList(ctx, { value }, temp, rest)
			for fill = index, #stat.targets do
				assignments[fill].valueReg = temp + fill - index
			end
			ctx.protectedTop = oldProtectedTop
			break
		else
			-- The value may already live in a register (e.g. a local); reuse it if so.
			local valueReg = self:compileExprAsSource(ctx, value, ctx.nextReg)
			assignments[index].valueReg = valueReg
			ctx.nextReg = self:bumpTempReg(ctx.nextReg, valueReg)
		end
		ctx.protectedTop = oldProtectedTop
	end
	-- Surplus values (more values than targets) are still evaluated for their side effects.
	for index = #stat.targets + 1, #stat.values do
		local oldProtectedTop = ctx.protectedTop or 0
		ctx.protectedTop = math.max(oldProtectedTop, ctx.nextReg)
		self:compileExprSide(ctx, stat.values[index], ctx.nextReg)
		ctx.protectedTop = oldProtectedTop
	end
	-- If evaluating the values emitted nothing, attribute the stores to the statement
	-- line; otherwise reuse the line of the last emitted instruction.
	local assignLine = ctx.builder:label(ctx.func) == beforeExpr and line or ctx.func.lines[#ctx.func.lines] or line
	for _, assignment in ipairs(assignments) do
		self:compileAssignLValue(ctx, assignment.lvalue, assignment.valueReg, assignment.target.line or assignLine)
	end
	ctx.nextReg = oldNextReg
	return true
end
--- Returns the next free temporary register after `reg` has been used.
-- The cursor only moves forward when `reg` is at or past it; a register below
-- the cursor (for example an existing local) leaves the cursor unchanged.
-- @param nextTemp current next-free temporary register
-- @param reg register that was just produced
-- @return updated next-free temporary register
function PureCompiler:bumpTempReg(nextTemp, reg)
	local advanced = nextTemp
	if reg >= nextTemp then
		advanced = reg + 1
	end
	return advanced
end
--- Resolves an assignment target into an lvalue descriptor.
-- Names resolve in the order local, then upvalue, then global. For "Field" and
-- "Index" targets, the object expression (and, when needed, the index expression)
-- is evaluated into registers starting at `nextTemp`.
-- @param ctx compilation context
-- @param target "Name", "Field" or "Index" expression node
-- @param nextTemp first free temporary register
-- @return lvalue descriptor table, updated next-free temporary register
-- @raise "invalid assignment target" for any other target kind
function PureCompiler:compileLValue(ctx, target, nextTemp)
	local kind = target.kind

	if kind == "Name" then
		local name = target.name
		local info = self:findLocal(ctx, name)
		if info then
			-- The local is being reassigned: drop cached constant/import information.
			info.constPath = nil
			info.constKind = nil
			if info.reg == nil then
				info.reg = self:reserve(ctx, 1)
			end
			return { kind = "Local", reg = info.reg, localInfo = info }, nextTemp
		end
		local upvalueIndex = self:getUpvalue(ctx, name)
		if upvalueIndex == nil then
			return { kind = "Global", name = name }, nextTemp
		end
		return { kind = "Upvalue", upvalue = upvalueIndex }, nextTemp
	end

	if kind ~= "Field" and kind ~= "Index" then
		error("invalid assignment target", 2)
	end

	-- Both table-store forms need the object in a register first.
	local objectExpr = target.object
	local objectReg = self:compileExprAsSource(ctx, objectExpr, nextTemp)
	nextTemp = self:bumpTempReg(nextTemp, objectReg)

	if kind == "Field" then
		return { kind = "Field", objectReg = objectReg, objectExpr = objectExpr, field = target.field }, nextTemp
	end

	-- Constant keys are only recognised when optimizing (level >= 1).
	local key = nil
	if (ctx.builder.options.optimizationLevel or 1) >= 1 then
		key = self:getConstant(ctx, target.index)
	end

	if key and key.kind == "String" then
		return { kind = "IndexString", objectReg = objectReg, objectExpr = objectExpr, key = key.value }, nextTemp
	end

	if key and key.kind == "Number" then
		local n = key.value
		-- Integer keys 1..256 fit the SETTABLEN/GETTABLEN immediate operand (stored as n - 1).
		if n % 1 == 0 and n >= 1 and n <= 256 then
			return { kind = "IndexNumber", objectReg = objectReg, objectExpr = objectExpr, index = n }, nextTemp
		end
	end

	local indexExpr = target.index
	local indexReg = self:compileExprAsSource(ctx, indexExpr, nextTemp)
	nextTemp = self:bumpTempReg(nextTemp, indexReg)
	return {
		kind = "Index",
		objectReg = objectReg,
		objectExpr = objectExpr,
		indexReg = indexReg,
		indexExpr = indexExpr,
	}, nextTemp
end
--- Emits the store of register `source` into a descriptor produced by compileLValue.
-- @param ctx compilation context
-- @param lvalue descriptor from compileLValue
-- @param source register holding the value to store
-- @param line debug line attached to the emitted instructions
-- @raise "invalid assignment lvalue" for an unknown descriptor kind
function PureCompiler:compileAssignLValue(ctx, lvalue, source, line)
	local builder = ctx.builder
	local func = ctx.func
	local kind = lvalue.kind

	if kind == "Local" then
		-- Skip the self-move when the value was computed directly into the local.
		if lvalue.reg ~= source then
			builder:emitABC(func, LOP.MOVE, lvalue.reg, source, 0, line)
		end
	elseif kind == "Upvalue" then
		builder:emitABC(func, LOP.SETUPVAL, source, lvalue.upvalue, 0, line)
	elseif kind == "Global" then
		self:emitSetGlobal(ctx, source, lvalue.name, line)
	elseif kind == "Field" or kind == "IndexString" then
		-- Both forms store under a constant string key: SETTABLEKS + AUX constant.
		local key = lvalue.field
		if kind == "IndexString" then
			key = lvalue.key
		end
		local constant = builder:addConstantString(func, key)
		local slot = stringHash(key) % 256
		builder:emitABC(func, LOP.SETTABLEKS, source, lvalue.objectReg, slot, line)
		builder:emitAux(func, constant, line)
		self:hintTemporaryExprRegType(ctx, lvalue.objectExpr, lvalue.objectReg, LBC_TYPE_TABLE, 2)
	elseif kind == "IndexNumber" then
		-- The immediate operand is zero-based.
		builder:emitABC(func, LOP.SETTABLEN, source, lvalue.objectReg, lvalue.index - 1, line)
		self:hintTemporaryExprRegType(ctx, lvalue.objectExpr, lvalue.objectReg, LBC_TYPE_TABLE, 1)
	elseif kind == "Index" then
		builder:emitABC(func, LOP.SETTABLE, source, lvalue.objectReg, lvalue.indexReg, line)
		self:hintTemporaryExprRegType(ctx, lvalue.objectExpr, lvalue.objectReg, LBC_TYPE_TABLE, 1)
		self:hintTemporaryExprRegType(ctx, lvalue.indexExpr, lvalue.indexReg, LBC_TYPE_NUMBER, 1)
	else
		error("invalid assignment lvalue", 2)
	end
end
--- Emits a load of the current value of an lvalue descriptor into register `target`.
-- @param ctx compilation context
-- @param lvalue descriptor from compileLValue
-- @param target destination register
-- @param line debug line attached to the emitted instructions
-- @raise "invalid assignment lvalue" for an unknown descriptor kind
function PureCompiler:compileLValueUse(ctx, lvalue, target, line)
	local builder = ctx.builder
	local func = ctx.func
	local kind = lvalue.kind

	if kind == "Local" then
		if lvalue.reg ~= target then
			builder:emitABC(func, LOP.MOVE, target, lvalue.reg, 0, line)
		end
		self:useReg(ctx, target)
	elseif kind == "Upvalue" then
		builder:emitABC(func, LOP.GETUPVAL, target, lvalue.upvalue, 0, line)
		self:useReg(ctx, target)
	elseif kind == "Global" then
		self:emitGetGlobal(ctx, target, lvalue.name, line)
	elseif kind == "Field" or kind == "IndexString" then
		-- Both forms read a constant string key: GETTABLEKS + AUX constant.
		local key = lvalue.field
		if kind == "IndexString" then
			key = lvalue.key
		end
		local constant = builder:addConstantString(func, key)
		local slot = stringHash(key) % 256
		builder:emitABC(func, LOP.GETTABLEKS, target, lvalue.objectReg, slot, line)
		builder:emitAux(func, constant, line)
		self:useReg(ctx, target)
		self:hintTemporaryExprRegType(ctx, lvalue.objectExpr, lvalue.objectReg, LBC_TYPE_TABLE, 2)
	elseif kind == "IndexNumber" then
		-- The immediate operand is zero-based.
		builder:emitABC(func, LOP.GETTABLEN, target, lvalue.objectReg, lvalue.index - 1, line)
		self:useReg(ctx, target)
		self:hintTemporaryExprRegType(ctx, lvalue.objectExpr, lvalue.objectReg, LBC_TYPE_TABLE, 1)
	elseif kind == "Index" then
		builder:emitABC(func, LOP.GETTABLE, target, lvalue.objectReg, lvalue.indexReg, line)
		self:useReg(ctx, target)
		self:hintTemporaryExprRegType(ctx, lvalue.objectExpr, lvalue.objectReg, LBC_TYPE_TABLE, 1)
		self:hintTemporaryExprRegType(ctx, lvalue.indexExpr, lvalue.indexReg, LBC_TYPE_NUMBER, 1)
	else
		error("invalid assignment lvalue", 2)
	end
end
--- Reports whether evaluating the store into `lvalue` reads register `reg`
-- (as its table object or as its dynamic index).
-- @param lvalue descriptor from compileLValue
-- @param reg register to test
-- @return boolean
function PureCompiler:lvalueUsesLocalReg(lvalue, reg)
	local kind = lvalue.kind
	if kind == "Index" and lvalue.indexReg == reg then
		return true
	end
	local readsObject = kind == "Field" or kind == "IndexString" or kind == "IndexNumber" or kind == "Index"
	if readsObject then
		return lvalue.objectReg == reg
	end
	-- Local/Upvalue/Global stores do not read any register besides the source.
	return false
end
--- Compiles a multiple assignment that has at least one non-name target
-- (e.g. `t.x, a = a, t.x`).
-- Lvalue subexpressions are evaluated first, then all values, then the stores.
-- Locals whose old value is still needed by some value or lvalue ("conflicts")
-- receive their new value in a separate register and are moved into place last.
-- @param ctx compilation context
-- @param stat "Assign" statement node with `targets` and `values` lists
-- @param line fallback debug line
-- @param baseReg first register this statement may use for temporaries
-- @return true when compiled here; false when there is a single target or all targets are names
function PureCompiler:tryCompileGeneralAssign(ctx, stat, line, baseReg)
	if #stat.targets <= 1 then
		return false
	end
	-- All-name assignments are left to other paths.
	local hasComplexTarget = false
	for _, target in ipairs(stat.targets) do
		if target.kind ~= "Name" then
			hasComplexTarget = true
			break
		end
	end
	if not hasComplexTarget then
		return false
	end
	-- Phase 1: resolve every target (evaluating object/index subexpressions).
	local assignments = {}
	local nextTemp = baseReg
	for index, target in ipairs(stat.targets) do
		local lvalue
		lvalue, nextTemp = self:compileLValue(ctx, target, nextTemp)
		assignments[index] = {
			target = target,
			lvalue = lvalue,
			typeId = self:getInitializerBytecodeType(ctx, stat.values, index, #stat.targets),
		}
	end
	-- Phase 2: find local registers that would be overwritten before another part
	-- of the statement has read their old value.
	local assignedRegs = {}
	local conflictRegs = {}
	for index, assignment in ipairs(assignments) do
		local lvalue = assignment.lvalue
		if lvalue.kind == "Local" then
			-- A local's value is computed in target order, so only locals assigned
			-- earlier (already in assignedRegs) can be clobbered here.
			local value = stat.values[index]
			if value then
				for assignedReg in pairs(assignedRegs) do
					if self:exprUsesLocalReg(ctx, value, assignedReg) then
						conflictRegs[assignedReg] = true
					end
				end
			end
			assignedRegs[lvalue.reg] = true
		end
	end
	-- Values for non-local targets are evaluated after all locals are assigned.
	for index, assignment in ipairs(assignments) do
		if assignment.lvalue.kind ~= "Local" then
			local value = stat.values[index]
			if value then
				for assignedReg in pairs(assignedRegs) do
					if self:exprUsesLocalReg(ctx, value, assignedReg) then
						conflictRegs[assignedReg] = true
					end
				end
			end
		end
	end
	-- Surplus values are evaluated for side effects and may read assigned locals too.
	for index = #stat.targets + 1, #stat.values do
		local value = stat.values[index]
		for assignedReg in pairs(assignedRegs) do
			if self:exprUsesLocalReg(ctx, value, assignedReg) then
				conflictRegs[assignedReg] = true
			end
		end
	end
	-- Table stores read their object/index registers after locals are written.
	for _, assignment in ipairs(assignments) do
		for assignedReg in pairs(assignedRegs) do
			if self:lvalueUsesLocalReg(assignment.lvalue, assignedReg) then
				conflictRegs[assignedReg] = true
			end
		end
	end
	-- Give each conflicting local a fresh register to hold its new value.
	for _, assignment in ipairs(assignments) do
		local lvalue = assignment.lvalue
		if lvalue.kind == "Local" and conflictRegs[lvalue.reg] then
			assignment.conflictReg = nextTemp
			nextTemp += 1
			self:useReg(ctx, assignment.conflictReg)
		end
	end
	-- Phase 3: evaluate values.
	local limit = math.min(#stat.targets, #stat.values)
	for index = 1, limit do
		local assignment = assignments[index]
		local value = stat.values[index]
		local oldProtectedTop = ctx.protectedTop or 0
		ctx.protectedTop = math.max(oldProtectedTop, nextTemp)
		if index == #stat.values and #stat.targets > #stat.values then
			-- Last value fills all remaining targets through `rest` consecutive registers.
			local rest = #stat.targets - #stat.values + 1
			local temp = nextTemp
			nextTemp += rest
			self:useReg(ctx, nextTemp - 1)
			self:compileExprList(ctx, { value }, temp, rest)
			for fill = index, #stat.targets do
				assignments[fill].valueReg = temp + fill - index
			end
			ctx.protectedTop = oldProtectedTop
			break
		elseif assignment.lvalue.kind == "Local" then
			-- Non-conflicting locals receive their value directly.
			assignment.valueReg = assignment.conflictReg or assignment.lvalue.reg
			self:compileExpr(ctx, value, assignment.valueReg, false)
		else
			local valueReg = self:compileExprAsSource(ctx, value, nextTemp)
			assignment.valueReg = valueReg
			nextTemp = self:bumpTempReg(nextTemp, valueReg)
		end
		ctx.protectedTop = oldProtectedTop
	end
	for index = #stat.targets + 1, #stat.values do
		local oldProtectedTop = ctx.protectedTop or 0
		ctx.protectedTop = math.max(oldProtectedTop, nextTemp)
		self:compileExprSide(ctx, stat.values[index], nextTemp)
		ctx.protectedTop = oldProtectedTop
	end
	-- NOTE(review): unlike tryCompileNonLocalNameAssign, this does not fall back to
	-- `line` when nothing was emitted for the values, so it may pick up the previous
	-- instruction's line in that case. Confirm against reference output before changing.
	local assignLine = ctx.func.lines[#ctx.func.lines] or line
	-- Phase 4: table/global/upvalue stores first, then the deferred local moves.
	for _, assignment in ipairs(assignments) do
		if assignment.lvalue.kind ~= "Local" then
			self:compileAssignLValue(ctx, assignment.lvalue, assignment.valueReg, assignment.target.line or assignLine)
		end
	end
	for _, assignment in ipairs(assignments) do
		if assignment.lvalue.kind == "Local" and assignment.valueReg ~= assignment.lvalue.reg then
			self:compileAssignLValue(ctx, assignment.lvalue, assignment.valueReg, assignLine)
		end
		if assignment.lvalue.kind == "Local" then
			self:updateExistingLocalBytecodeType(assignment.lvalue.localInfo, assignment.typeId)
		end
	end
	return true
end
--- Reports whether `expr` can be compiled directly as a conditional jump.
-- That is the case when it folds to a constant (only while optimizing), or when
-- it is an `and`/`or`/comparison binary expression.
-- @param ctx compilation context
-- @param expr expression node
-- @return boolean
function PureCompiler:isConditionFast(ctx, expr)
	local optimizing = (self.builder.options.optimizationLevel or 1) >= 1
	if optimizing and self:getConstant(ctx, expr) then
		return true
	end
	if expr.kind ~= "Bin" then
		return false
	end
	local op = expr.op
	return op == "and" or op == "or" or compareOps[op] == true
end
--- Maps a comparison operator to its conditional-jump opcode.
-- `>` and `>=` share opcodes with `<` and `<=`; the caller swaps the operands.
-- @param op one of "==", "~=", "<", "<=", ">", ">="
-- @param inverted when truthy, return the opcode that jumps when the comparison is false
-- @return opcode from LOP
-- @raise "invalid comparison operator" for any other operator
function PureCompiler:getJumpOpCompare(op, inverted)
	local direct, negated
	if op == "==" then
		direct, negated = LOP.JUMPIFEQ, LOP.JUMPIFNOTEQ
	elseif op == "~=" then
		direct, negated = LOP.JUMPIFNOTEQ, LOP.JUMPIFEQ
	elseif op == "<" or op == ">" then
		direct, negated = LOP.JUMPIFLT, LOP.JUMPIFNOTLT
	elseif op == "<=" or op == ">=" then
		direct, negated = LOP.JUMPIFLE, LOP.JUMPIFNOTLE
	else
		error("invalid comparison operator", 2)
	end
	if inverted then
		return negated
	end
	return direct
end
--- Emits a conditional jump for a comparison expression (`==`, `~=`, `<`, `<=`, `>`, `>=`).
-- The jump is emitted with a zero offset; its label is returned (presumably so the
-- caller can patch the offset later).
-- @param ctx compilation context
-- @param expr "Bin" expression whose `op` is a comparison operator
-- @param jumpIfTrue when truthy, jump when the comparison holds; otherwise jump when it fails
-- @param target scratch register for operands that are not already in a register
-- @param lineOverride optional debug line to use instead of `expr.line`
-- @return label of the emitted jump instruction
function PureCompiler:compileCompareJump(ctx, expr, jumpIfTrue, target, lineOverride)
	local builder = ctx.builder
	local func = ctx.func
	local line = lineOverride or expr.line or 1
	local isEq = expr.op == "==" or expr.op == "~="
	local left = expr.left
	local right = expr.right
	local rightConstant = self:getConstant(ctx, right)
	if isEq and not rightConstant and (builder.options.optimizationLevel or 1) >= 1 then
		-- Equality is symmetric: move a constant operand to the right so the
		-- JUMPXEQK* forms can be used.
		local leftConstant = self:getConstant(ctx, left)
		if leftConstant then
			left, right = right, left
			rightConstant = leftConstant
		end
	end
	local beforeLeft = builder:label(func)
	local leftReg = self:compileExprAsSource(ctx, left, target)
	local leftEmitted = builder:label(func) ~= beforeLeft
	local jumpLine = leftEmitted and (func.lines[#func.lines] or line) or (lineOverride or line)
	local jump
	if
		isEq
		and rightConstant
		and (rightConstant.kind == "Nil" or rightConstant.kind == "Bool" or rightConstant.kind == "Number" or rightConstant.kind == "String")
		and (builder.options.optimizationLevel or 1) >= 1
	then
		local op
		local aux
		local inverted = not jumpIfTrue
		-- The high AUX bit is the NOT flag that flips the comparison result.
		local flip = (expr.op == "==") == inverted
		if rightConstant.kind == "Nil" then
			-- JUMPXEQKNIL has no constant operand; only the NOT flag is meaningful.
			op = LOP.JUMPXEQKNIL
			aux = 0
		elseif rightConstant.kind == "Bool" then
			-- JUMPXEQKB encodes the boolean value itself in the AUX low bit,
			-- not a constant-table index.
			op = LOP.JUMPXEQKB
			aux = self:constantTruth(rightConstant) and 1 or 0
		elseif rightConstant.kind == "Number" then
			op = LOP.JUMPXEQKN
			aux = self:addConstantValue(ctx, rightConstant)
		elseif rightConstant.kind == "String" then
			op = LOP.JUMPXEQKS
			aux = self:addConstantValue(ctx, rightConstant)
		else
			error("invalid comparison constant", 2)
		end
		jump = builder:label(func)
		builder:emitAD(func, op, leftReg, 0, jumpLine)
		builder:emitAux(func, aux + (flip and 0x80000000 or 0), jumpLine)
	else
		-- Avoid compiling the right operand over a left operand that landed in `target`.
		local rightTarget = leftReg == target and target + 1 or target
		local beforeRight = builder:label(func)
		local rightReg = self:compileExprAsSource(ctx, right, rightTarget)
		local inverted = not jumpIfTrue
		local op = self:getJumpOpCompare(expr.op, inverted)
		if builder:label(func) ~= beforeRight then
			jumpLine = func.lines[#func.lines] or jumpLine
		end
		jump = builder:label(func)
		if expr.op == ">" or expr.op == ">=" then
			-- a > b  is  b < a ; a >= b  is  b <= a
			builder:emitAD(func, op, rightReg, 0, jumpLine)
			builder:emitAux(func, leftReg, jumpLine)
		else
			builder:emitAD(func, op, leftReg, 0, jumpLine)
			builder:emitAux(func, rightReg, jumpLine)
		end
	end
	return jump
end
--- Determines the builtin id (if any) that a callee expression refers to.
-- It tries, in order: the expression's own import path; for names, the import
-- path cached on a local; then a constant path from an enclosing function.
-- @param ctx compilation context
-- @param expr callee expression node
-- @return builtin id or nil
function PureCompiler:getBuiltinIdForCallee(ctx, expr)
	local importPath = self:getImportPath(ctx, expr)
	if importPath then
		return self:getBuiltinIdFromPath(importPath)
	end
	if expr.kind ~= "Name" then
		return nil
	end
	local name = expr.name
	local info = self:findLocal(ctx, name)
	if info and info.constPath then
		return self:getBuiltinIdFromPath(info.constPath)
	end
	local inheritedPath = self:findParentConstPath(ctx, name)
	if not inheritedPath then
		return nil
	end
	return self:getBuiltinIdFromPath(inheritedPath)
end
--- At optimization level 2, reports whether a call is known to produce exactly one result.
-- This holds either because the callee is a single-result builtin, or because it
-- resolves to a function literal that always returns exactly one value.
-- @param ctx compilation context
-- @param expr expression node
-- @return boolean
function PureCompiler:callHasOneResultO2(ctx, expr)
	local level = self.builder.options.optimizationLevel or 1
	if level < 2 or expr.kind ~= "Call" then
		return false
	end
	local callee = expr.callee
	local builtinId = self:getBuiltinIdForCallee(ctx, callee)
	if builtinId ~= nil and builtinResultCounts[builtinId] == 1 then
		return true
	end
	local funcExpr = self:getO2FunctionExprForCallee(ctx, callee)
	if not funcExpr then
		return false
	end
	return self:functionReturnsOneO2(ctx, funcExpr) and true or false
end
--- At optimization level 2, folds `math.floor(<numeric constant>)` to a constant.
-- @param ctx compilation context
-- @param expr expression node
-- @return `{ kind = "Number", value = ... }` or nil when the call does not fold
function PureCompiler:getDiscardableO2BuiltinCallResult(ctx, expr)
	local eligible = (self.builder.options.optimizationLevel or 1) >= 2
		and expr.kind == "Call"
		and #expr.args == 1
		and self:getBuiltinIdForCallee(ctx, expr.callee) == LBF.MATH_FLOOR
	if not eligible then
		return nil
	end
	local argument = self:getConstant(ctx, expr.args[1])
	if not argument or argument.kind ~= "Number" then
		return nil
	end
	return { kind = "Number", value = math.floor(argument.value) }
end
--- Conservatively reports whether evaluating `expr` has no side effects.
-- Accepts folded constants, literals, locals, parent constants, and unary/binary
-- expressions built only from those. Anything else is treated as effectful.
-- @param ctx compilation context
-- @param expr expression node
-- @return boolean
function PureCompiler:isSideEffectFreeO2Expr(ctx, expr)
	if (self.builder.options.optimizationLevel or 1) >= 1 and self:getConstant(ctx, expr) then
		return true
	end
	local kind = expr.kind
	if kind == "Bin" then
		return self:isSideEffectFreeO2Expr(ctx, expr.left) and self:isSideEffectFreeO2Expr(ctx, expr.right)
	elseif kind == "Un" then
		return self:isSideEffectFreeO2Expr(ctx, expr.expr)
	elseif kind == "Name" then
		local name = expr.name
		return self:findLocal(ctx, name) ~= nil or self:findParentConstant(ctx, name) ~= nil
	end
	return kind == "Nil" or kind == "Bool" or kind == "Number" or kind == "Integer" or kind == "String"
end
--- Reports whether evaluating `expr` may perform a call or method call.
-- Nested function literals are not descended into, because creating a closure
-- does not run its body.
-- @param expr expression node
-- @return boolean
function PureCompiler:exprContainsCallO2(expr)
	local kind = expr.kind
	if kind == "Call" or kind == "MethodCall" then
		return true
	end
	if kind == "Field" then
		return self:exprContainsCallO2(expr.object)
	end
	if kind == "Index" then
		return self:exprContainsCallO2(expr.object) or self:exprContainsCallO2(expr.index)
	end
	if kind == "Un" or kind == "SingleResult" or kind == "Instantiate" then
		return self:exprContainsCallO2(expr.expr)
	end
	if kind == "Bin" then
		return self:exprContainsCallO2(expr.left) or self:exprContainsCallO2(expr.right)
	end
	if kind == "Table" then
		for _, entry in ipairs(expr.entries) do
			local entryKey, entryValue = entry.key, entry.value
			if (entryKey and self:exprContainsCallO2(entryKey)) or (entryValue and self:exprContainsCallO2(entryValue)) then
				return true
			end
		end
		return false
	end
	if kind == "InterpString" then
		for _, part in ipairs(expr.expressions) do
			if self:exprContainsCallO2(part) then
				return true
			end
		end
		return false
	end
	if kind == "IfExpr" then
		for _, clause in ipairs(expr.clauses) do
			if self:exprContainsCallO2(clause.condition) or self:exprContainsCallO2(clause.value) then
				return true
			end
		end
		return self:exprContainsCallO2(expr.elseValue)
	end
	-- "Function" and all leaf kinds.
	return false
end
--- Reports whether any statement in `block` (recursively) may perform a call.
-- Statement kinds not listed below are treated as call-free.
-- @param block block node with a `body` list of statements
-- @return boolean
function PureCompiler:blockContainsCallO2(block)
	-- True if any expression in `list` may perform a call.
	local function anyExprHasCall(list)
		for _, item in ipairs(list) do
			if self:exprContainsCallO2(item) then
				return true
			end
		end
		return false
	end

	for _, stat in ipairs(block.body) do
		local kind = stat.kind
		local found = false
		if kind == "CallStat" then
			found = true
		elseif kind == "Local" or kind == "Return" then
			found = anyExprHasCall(stat.values)
		elseif kind == "Assign" then
			found = anyExprHasCall(stat.targets) or anyExprHasCall(stat.values)
		elseif kind == "Do" then
			found = self:blockContainsCallO2(stat.body)
		elseif kind == "If" then
			for _, clause in ipairs(stat.clauses) do
				if self:exprContainsCallO2(clause.condition) or self:blockContainsCallO2(clause.body) then
					found = true
					break
				end
			end
			if not found and stat.elseBody then
				found = self:blockContainsCallO2(stat.elseBody)
			end
		elseif kind == "While" or kind == "Repeat" then
			found = self:exprContainsCallO2(stat.condition) or self:blockContainsCallO2(stat.body)
		elseif kind == "ForNumeric" then
			found = self:exprContainsCallO2(stat.from)
				or self:exprContainsCallO2(stat.to)
				or (stat.step and self:exprContainsCallO2(stat.step))
				or self:blockContainsCallO2(stat.body)
		elseif kind == "ForIn" then
			found = anyExprHasCall(stat.values) or self:blockContainsCallO2(stat.body)
		end
		if found then
			return true
		end
	end
	return false
end
--- Extension point consulted by getInlineO2Function before inlining an
-- immediately-invoked function literal. It currently accepts every function;
-- `funcExpr` is unused.
-- @param funcExpr the "Function" expression node being considered
-- @return true
function PureCompiler:canInlineDirectO2Function(funcExpr)
	return true
end
--- Returns the function literal a call can be inlined from at optimization level 2, or nil.
-- Candidates are:
--   * a name bound (in this or an enclosing context) to a never-written local with
--     an `inlineFunction`;
--   * a function literal called directly, but only in a context with no parent.
-- Vararg functions, functions currently being compiled (recursion), and chunks
-- that use getfenv/setfenv are rejected.
-- @param ctx compilation context
-- @param call expression node
-- @return "Function" expression node or nil
function PureCompiler:getInlineO2Function(ctx, call)
	if (self.builder.options.optimizationLevel or 1) < 2 or call.kind ~= "Call" then
		return nil
	end
	if self.getfenvUsed or self.setfenvUsed then
		return nil
	end

	local callee = call.callee
	local candidate = nil
	if callee.kind == "Name" then
		local name = callee.name
		local info = self:findLocal(ctx, name)
		-- Fall back to the locals of enclosing function contexts.
		local scope = ctx.parent
		while not info and scope do
			info = scope.locals[name]
			scope = scope.parent
		end
		if info and not info.written then
			candidate = info.inlineFunction
		end
	elseif callee.kind == "Function" then
		if ctx.parent ~= nil then
			return nil
		end
		candidate = callee
		if not self:canInlineDirectO2Function(candidate) then
			return nil
		end
		self:ensureFunctionExprCompiled(ctx, candidate, nil)
	end

	if not candidate or candidate.isvararg then
		return nil
	end
	-- Never inline a function into its own body.
	for _, active in ipairs(self.compilingFunctionStack) do
		if active == candidate then
			return nil
		end
	end
	return candidate
end
--- Resolves a callee expression to the function literal it is known to refer to.
-- Names resolve through never-written locals in this or an enclosing context.
-- SingleResult/Instantiate wrappers are looked through. Function literals map to
-- themselves.
-- @param ctx compilation context
-- @param expr callee expression node
-- @return "Function" expression node or nil
function PureCompiler:getO2FunctionExprForCallee(ctx, expr)
	local kind = expr.kind
	if kind == "Function" then
		return expr
	end
	if kind == "SingleResult" or kind == "Instantiate" then
		return self:getO2FunctionExprForCallee(ctx, expr.expr)
	end
	if kind ~= "Name" then
		return nil
	end
	local name = expr.name
	local info = self:findLocal(ctx, name)
	local scope = ctx.parent
	while not info and scope do
		info = scope.locals[name]
		scope = scope.parent
	end
	if not info or info.written then
		return nil
	end
	return info.initFunction or info.inlineFunction
end
--- Reports whether `stat` always transfers control away (return/break/continue).
-- `if` statements count only when every reachable branch terminates. Branches
-- whose condition folds to a false constant are skipped; a true constant
-- condition makes later branches unreachable.
-- @param ctx compilation context (used for constant folding)
-- @param stat statement node
-- @return boolean
function PureCompiler:statAlwaysTerminatesO2(ctx, stat)
	local kind = stat.kind
	if kind == "Return" or kind == "Break" or kind == "Continue" then
		return true
	end
	if kind == "Do" then
		return self:blockAlwaysTerminatesO2(ctx, stat.body)
	end
	if kind ~= "If" then
		return false
	end
	for _, clause in ipairs(stat.clauses) do
		local condition = self:getConstant(ctx, clause.condition)
		local knownFalse = condition and not self:constantTruth(condition)
		if not knownFalse then
			if not self:blockAlwaysTerminatesO2(ctx, clause.body) then
				return false
			end
			if condition and self:constantTruth(condition) then
				return true
			end
		end
	end
	return stat.elseBody ~= nil and self:blockAlwaysTerminatesO2(ctx, stat.elseBody)
end
--- Reports whether some statement of `block` always terminates control flow.
-- @param ctx compilation context
-- @param block block node with a `body` list
-- @return boolean
function PureCompiler:blockAlwaysTerminatesO2(ctx, block)
	local statements = block.body
	for position = 1, #statements do
		if self:statAlwaysTerminatesO2(ctx, statements[position]) then
			return true
		end
	end
	return false
end
--- Conservatively reports whether `expr`, in a multi-value position, may yield a
-- number of values other than one.
-- @param ctx compilation context
-- @param expr expression node
-- @param seen optional recursion guard forwarded to functionReturnsOneO2
-- @return boolean
function PureCompiler:exprMayReturnMultipleO2(ctx, expr, seen)
	local kind = expr.kind
	if kind == "Vararg" then
		return true
	end
	if kind ~= "Call" and kind ~= "MethodCall" then
		-- SingleResult/Instantiate and all non-call expressions yield exactly one value.
		return false
	end
	if kind == "Call" and (self.builder.options.optimizationLevel or 1) >= 2 and self:getConstant(ctx, expr) then
		return false
	end
	-- For method calls the whole call expression (not a callee) is looked up.
	local lookupExpr = expr
	if kind == "Call" then
		lookupExpr = expr.callee
	end
	local builtinId = self:getBuiltinIdForCallee(ctx, lookupExpr)
	if builtinId ~= nil then
		return builtinResultCounts[builtinId] ~= 1
	end
	if kind == "MethodCall" then
		return true
	end
	local funcExpr = self:getO2FunctionExprForCallee(ctx, expr.callee)
	if funcExpr and self:functionReturnsOneO2(ctx, funcExpr, seen) then
		return false
	end
	return true
end
--- Checks that every `return` reachable inside `block` (including nested do/if/loop
-- bodies) returns exactly one value.
-- @param ctx compilation context
-- @param block block node
-- @param seen recursion guard forwarded to exprMayReturnMultipleO2
-- @return boolean
function PureCompiler:returnsOneVisitorO2(ctx, block, seen)
	for _, stat in ipairs(block.body) do
		local kind = stat.kind
		local ok = true
		if kind == "Return" then
			ok = #stat.values == 1 and not self:exprMayReturnMultipleO2(ctx, stat.values[1], seen)
		elseif kind == "Do" or kind == "While" or kind == "Repeat" or kind == "ForNumeric" or kind == "ForIn" then
			ok = self:returnsOneVisitorO2(ctx, stat.body, seen)
		elseif kind == "If" then
			for _, clause in ipairs(stat.clauses) do
				if not self:returnsOneVisitorO2(ctx, clause.body, seen) then
					ok = false
					break
				end
			end
			if ok and stat.elseBody then
				ok = self:returnsOneVisitorO2(ctx, stat.elseBody, seen)
			end
		end
		if not ok then
			return false
		end
	end
	return true
end
--- At optimization level 2, reports whether a non-vararg function literal always
-- returns exactly one value. The result is cached on `funcExpr.returnsOneO2`.
-- Recursive cycles (tracked via `seen`) are treated as "not provably one".
-- @param ctx compilation context
-- @param funcExpr "Function" expression node
-- @param seen optional set of function nodes currently being analysed
-- @return boolean
function PureCompiler:functionReturnsOneO2(ctx, funcExpr, seen)
	local level = self.builder.options.optimizationLevel or 1
	if level < 2 or funcExpr.isvararg or self.getfenvUsed or self.setfenvUsed then
		return false
	end
	local cached = funcExpr.returnsOneO2
	if cached ~= nil then
		return cached
	end
	local visiting = seen or {}
	if visiting[funcExpr] then
		return false
	end
	visiting[funcExpr] = true
	local body = funcExpr.body
	local result = self:blockAlwaysTerminatesO2(ctx, body) and self:returnsOneVisitorO2(ctx, body, visiting)
	funcExpr.returnsOneO2 = result
	visiting[funcExpr] = nil
	return result
end
--- Reports whether `block` contains a `return` statement anywhere, including
-- nested do/if/loop bodies (but not nested function literals).
-- @param block block node
-- @return boolean
function PureCompiler:blockHasInlineO2Return(block)
	for _, stat in ipairs(block.body) do
		local kind = stat.kind
		if kind == "Return" then
			return true
		end
		if kind == "Do" or kind == "While" or kind == "Repeat" or kind == "ForNumeric" or kind == "ForIn" then
			if self:blockHasInlineO2Return(stat.body) then
				return true
			end
		elseif kind == "If" then
			for _, clause in ipairs(stat.clauses) do
				if self:blockHasInlineO2Return(clause.body) then
					return true
				end
			end
			local elseBody = stat.elseBody
			if elseBody and self:blockHasInlineO2Return(elseBody) then
				return true
			end
		end
	end
	return false
end
--- Flattens a function body into an inlining summary, or returns nil when the
-- body uses a statement shape the inliner does not support.
-- Summary fields:
--   effects    - statements to compile before producing the result
--   returns    - value expressions of the first `return` reached (empty if none)
--   terminated - true once a `return` has been reached
-- `if` statements whose conditions all fold to constants are replaced by the
-- selected branch; dynamic `if`s are kept as effects only if no branch returns.
-- @param ctx compilation context used for constant folding, or nil to treat every condition as dynamic
-- @param block block node
-- @return summary table or nil
function PureCompiler:summarizeInlineO2Block(ctx, block)
	local summary = {
		effects = {},
		returns = {},
		terminated = false,
	}
	for _, stat in ipairs(block.body) do
		if stat.kind == "Nop" then
			-- skip
		elseif stat.kind == "Assign" and #stat.targets == 1 and #stat.values == 1 and stat.targets[1].kind == "Name" then
			-- Only single `name = value` assignments are supported.
			append(summary.effects, stat)
		elseif stat.kind == "Local" then
			append(summary.effects, stat)
		elseif stat.kind == "LocalFunction" then
			append(summary.effects, stat)
		elseif stat.kind == "CallStat" then
			append(summary.effects, stat)
		elseif stat.kind == "ForNumeric" then
			append(summary.effects, stat)
		elseif stat.kind == "ForIn" then
			append(summary.effects, stat)
		elseif stat.kind == "Return" then
			-- Statements after a return are unreachable and ignored.
			summary.returns = stat.values
			summary.terminated = true
			return summary
		elseif stat.kind == "If" then
			-- Walk clauses in order: stop at the first non-constant condition
			-- (dynamic) or the first constant-true condition (selected).
			local selectedBody = nil
			local dynamicCondition = false
			for _, clause in ipairs(stat.clauses) do
				local condition = ctx ~= nil and self:getConstant(ctx, clause.condition) or nil
				if not condition then
					dynamicCondition = true
					break
				end
				if self:constantTruth(condition) then
					selectedBody = clause.body
					break
				end
			end
			if dynamicCondition then
				-- A runtime branch is kept as an effect only if none of its bodies return.
				-- (The first-clause check is subsumed by the loop below.)
				if self:blockHasInlineO2Return(stat.clauses[1].body) then
					return nil
				end
				local canInline = true
				for _, clause in ipairs(stat.clauses) do
					if self:blockHasInlineO2Return(clause.body) then
						canInline = false
						break
					end
				end
				if canInline and stat.elseBody and self:blockHasInlineO2Return(stat.elseBody) then
					canInline = false
				end
				if not canInline then
					return nil
				end
				append(summary.effects, stat)
			else
				-- All conditions were constant: no true clause means the else body (if any) runs.
				if selectedBody == nil then
					selectedBody = stat.elseBody
				end
				if selectedBody then
					local branchSummary = self:summarizeInlineO2Block(ctx, selectedBody)
					if not branchSummary then
						return nil
					end
					for _, effect in ipairs(branchSummary.effects) do
						append(summary.effects, effect)
					end
					if branchSummary.terminated then
						summary.returns = branchSummary.returns
						summary.terminated = true
						return summary
					end
				end
			end
		else
			-- Any other statement kind makes the body non-inlinable.
			return nil
		end
	end
	return summary
end
--- Summarizes a function literal's body for inlining (see summarizeInlineO2Block).
-- Note the argument order: `funcExpr` comes first and `ctx` is optional. With a nil
-- ctx, every `if` condition is treated as dynamic.
-- @param funcExpr "Function" expression node
-- @param ctx compilation context or nil
-- @return summary table or nil
function PureCompiler:summarizeInlineO2Function(funcExpr, ctx)
	return self:summarizeInlineO2Block(ctx, funcExpr.body)
end
--- Returns the fixed number of values an inlinable call would produce, or nil.
-- Returns nil when the call is not inlinable, its body cannot be summarized, or
-- its last return value may expand to a variable number of values.
-- NOTE(review): the summary is built without ctx, so callee `if` conditions are
-- treated as dynamic. This looks deliberate, since the callee's parameters are not
-- bound in the caller's scope at this point. Confirm before passing ctx.
-- @param ctx compilation context
-- @param call "Call" expression node
-- @return integer or nil
function PureCompiler:getInlineO2ReturnCount(ctx, call)
	local funcExpr = self:getInlineO2Function(ctx, call)
	if not funcExpr then
		return nil
	end
	local summary = self:summarizeInlineO2Function(funcExpr)
	if not summary then
		return nil
	end
	local returnValues = summary.returns
	local count = #returnValues
	local tail = returnValues[count]
	if tail and (tail.kind == "Call" or tail.kind == "MethodCall" or tail.kind == "Vararg") then
		return nil
	end
	return count
end
--- Reports whether the parameters of an inlined call can all be bound without
-- emitting any instructions. Each parameter must be one of:
--   * missing and never written (it binds to a nil constant);
--   * passed a constant and never written (folded into the body);
--   * passed a never-written local that already has a register (aliased).
-- Extra arguments beyond the parameter list must be names or constants.
-- @param ctx compilation context
-- @param funcExpr "Function" expression node being inlined
-- @param call "Call" expression node
-- @return boolean
function PureCompiler:canBindInlineO2ArgumentsWithoutCode(ctx, funcExpr, call)
	-- A trailing multi-value argument that would have to fill several parameters needs code.
	if #call.args > 0 and #funcExpr.params > #call.args and self:isInlineO2MultRetArg(ctx, call.args[#call.args]) then
		return false
	end
	for index, name in ipairs(funcExpr.params) do
		local arg = call.args[index]
		local written = funcExpr.paramSymbols and funcExpr.paramSymbols[index] and funcExpr.paramSymbols[index].written == true or false
		if arg == nil then
			if written then
				return false
			end
		elseif not written and self:getConstant(ctx, arg) then
			-- constants are folded into the inlined body
		elseif not written and arg.kind == "Name" then
			local sourceLocal = self:findLocal(ctx, arg.name)
			if not sourceLocal or sourceLocal.reg == nil or sourceLocal.written then
				return false
			end
		else
			return false
		end
		-- Defensive: ipairs never yields a nil value, so this does not trigger.
		if name == nil then
			return false
		end
	end
	for index = #funcExpr.params + 1, #call.args do
		local arg = call.args[index]
		if not (arg.kind == "Name" or (self.builder.options.optimizationLevel or 1) >= 1 and self:getConstant(ctx, arg)) then
			return false
		end
	end
	return true
end
--- Reports whether an argument expression may expand to a variable number of values.
-- @param ctx compilation context
-- @param expr argument expression node, or nil
-- @return boolean (false for nil)
function PureCompiler:isInlineO2MultRetArg(ctx, expr)
	if expr == nil then
		return false
	end
	return self:exprMayReturnMultipleO2(ctx, expr)
end
--- Binds the parameters of an inlined function to the call's arguments and
-- declares them as locals in `ctx`.
-- Each parameter is bound in one of four ways:
--   * as a folded constant (constant argument, parameter never written);
--   * as an alias of an existing local's register (never-written local, parameter never written);
--   * as nil (missing argument, parameter never written);
--   * otherwise by evaluating the argument into a fresh register (LOADNIL for a
--     missing argument of a written parameter).
-- Extra arguments are evaluated for side effects only.
-- @param ctx compilation context
-- @param funcExpr "Function" expression node being inlined
-- @param call "Call" expression node
-- @param target first register that receives the inlined call's results
-- @param wantedResults number of results requested; values <= 0 reserve no result registers
-- @param depth optional inlining nesting depth (default 0), used for debug scope depth
-- @return mark table for leaveInlineO2Arguments, or nil when a trailing multi-value
--   argument would have to fill several parameters
function PureCompiler:compileInlineO2Arguments(ctx, funcExpr, call, target, wantedResults, depth)
	depth = depth or 0
	if #call.args > 0 and #funcExpr.params > #call.args and self:isInlineO2MultRetArg(ctx, call.args[#call.args]) then
		return nil
	end
	-- Saved state restored later by leaveInlineO2Arguments.
	local localCount = #ctx.localList
	local oldNextReg = ctx.nextReg
	local oldProtectedTop = ctx.protectedTop or 0
	-- Keep argument temporaries above the result registers and any live locals.
	local targetTop = wantedResults > 0 and target + wantedResults or target
	ctx.nextReg = math.max(ctx.nextReg, oldProtectedTop, targetTop)
	local nextTemp = math.max(ctx.nextReg, self:getLocalTop(ctx), targetTop)
	local inlineDebugDepth = (ctx.debugScopeDepth or 0) + depth + 1
	local paramLocals = {}
	ctx.protectedTop = math.max(oldProtectedTop, nextTemp)
	for index, name in ipairs(funcExpr.params) do
		local arg = call.args[index]
		local paramSymbol = funcExpr.paramSymbols and funcExpr.paramSymbols[index] or nil
		local written = paramSymbol and paramSymbol.written == true or false
		local symbolFunctionDepth = paramSymbol and paramSymbol.functionDepth or nil
		local symbolLoopDepth = paramSymbol and paramSymbol.loopDepth or nil
		local paramType = funcExpr.paramTypes and funcExpr.paramTypes[index] or nil
		local localInfo = nil
		if arg == nil then
			if written then
				-- A written parameter needs a real register initialised to nil.
				local reg = nextTemp
				ctx.builder:emitABC(ctx.func, LOP.LOADNIL, reg, 0, 0, call.line or 1)
				self:useReg(ctx, reg)
				nextTemp += 1
				localInfo = {
					name = name,
					reg = reg,
					written = true,
					typeId = paramType,
					functionDepth = symbolFunctionDepth,
					loopDepth = symbolLoopDepth,
				}
			else
				localInfo = {
					name = name,
					constKind = "Nil",
					constValue = nil,
					written = false,
					typeId = paramType,
					functionDepth = symbolFunctionDepth,
					loopDepth = symbolLoopDepth,
				}
			end
		elseif not written then
			local constant = self:getConstant(ctx, arg)
			if constant then
				self:captureParentConstantsForDebug(ctx, arg)
				localInfo = {
					name = name,
					constKind = constant.kind,
					constValue = constant.value,
					written = false,
					typeId = paramType,
					functionDepth = symbolFunctionDepth,
					loopDepth = symbolLoopDepth,
				}
			elseif arg.kind == "Name" then
				-- Alias an existing never-written local: share its register and carry over
				-- its import path and known function literal.
				local sourceLocal = self:findLocal(ctx, arg.name)
				if sourceLocal and sourceLocal.reg ~= nil and not sourceLocal.written then
					localInfo = {
						name = name,
						reg = sourceLocal.reg,
						written = false,
						typeId = paramType or sourceLocal.typeId,
						functionDepth = symbolFunctionDepth,
						loopDepth = symbolLoopDepth,
					}
					if sourceLocal.constPath then
						localInfo.constPath = shallowCopy(sourceLocal.constPath)
					end
					local initFunction = sourceLocal.initFunction or sourceLocal.inlineFunction
					if initFunction then
						localInfo.initFunction = initFunction
						localInfo.inlineFunction = initFunction
					end
				end
			end
		end
		if not localInfo then
			-- General case: evaluate the argument into its own register.
			local reg = nextTemp
			local importPath = not written and arg and self:getImportPath(ctx, arg) or nil
			self:compileExpr(ctx, arg, reg)
			nextTemp += 1
			localInfo = {
				name = name,
				reg = reg,
				written = written,
				typeId = paramType,
				functionDepth = symbolFunctionDepth,
				loopDepth = symbolLoopDepth,
			}
			if importPath then
				localInfo.constPath = importPath
			end
			if not written and arg and arg.kind == "Function" then
				localInfo.initFunction = arg
				localInfo.inlineFunction = arg
			end
		end
		localInfo.debugDepth = localInfo.debugDepth or inlineDebugDepth
		append(paramLocals, localInfo)
	end
	-- Extra arguments are still evaluated, in order, for their side effects.
	for index = #funcExpr.params + 1, #call.args do
		self:compileExprSide(ctx, call.args[index], nextTemp)
	end
	-- Parameters become visible only after all arguments have been evaluated.
	local debugStart = ctx.builder:label(ctx.func)
	for _, localInfo in ipairs(paramLocals) do
		localInfo.debugStart = debugStart
		self:declareLocal(ctx, localInfo)
	end
	ctx.nextReg = math.max(nextTemp, self:getLocalTop(ctx), targetTop)
	return {
		localCount = localCount,
		nextReg = oldNextReg,
		protectedTop = oldProtectedTop,
	}
end
-- Undo compileInlineO2Arguments: drop the parameter locals it declared and
-- restore the register allocation state saved in `mark`.
function PureCompiler:leaveInlineO2Arguments(ctx, mark)
	local savedNextReg, savedProtectedTop = mark.nextReg, mark.protectedTop
	self:popLocals(ctx, mark.localCount)
	ctx.protectedTop = savedProtectedTop
	ctx.nextReg = savedNextReg
end
-- Close an inlined call frame opened by compileInlineO2Arguments: pop every
-- local declared since `mark` was taken and restore the saved register state.
function PureCompiler:leaveInlineO2Frame(ctx, mark)
	local savedNextReg, savedProtectedTop = mark.nextReg, mark.protectedTop
	self:popLocals(ctx, mark.localCount)
	ctx.protectedTop = savedProtectedTop
	ctx.nextReg = savedNextReg
end
-- Compile the side-effect statements (`summary.effects`) of a function that is
-- being inlined at O2, directly into the caller's frame.
-- `depth` is the current inlining depth (nil means 0). It is forwarded as
-- `depth + 1` to nested inline attempts and added to the debug scope depth of
-- locals introduced here.
-- Returns true once every effect has been compiled, or false as soon as a
-- statement shape is not supported by this path.
-- NOTE(review): returning false does not roll back instructions already
-- emitted for earlier effects; compileInlineO2Call only pops locals and
-- restores register state on failure. Presumably summarizeInlineO2Function
-- rules out the false paths - confirm.
function PureCompiler:compileInlineO2Effects(ctx, summary, depth)
depth = depth or 0
for _, stat in ipairs(summary.effects) do
if stat.kind == "LocalFunction" then
-- Declare the local, remember its function expression (initFunction /
-- inlineFunction) and compile the closure into the local's register.
local localInfo = self:addLocal(ctx, stat.name, stat.localSymbol and stat.localSymbol.written == true or false)
localInfo.inlineFunction = stat.value
localInfo.initFunction = stat.value
self:compileFunctionExpr(ctx, stat.value, localInfo.reg, stat.name, localInfo)
localInfo.debugStart = ctx.builder:label(ctx.func)
elseif stat.kind == "CallStat" then
-- Call statement: try to inline it one level deeper, otherwise emit a
-- regular call that keeps no results.
local baseReg = ctx.nextReg
if stat.expr.kind ~= "Call" or not self:compileInlineO2Call(ctx, stat.expr, baseReg, 0, depth + 1) then
self:compileCall(ctx, stat.expr, baseReg, 0)
end
elseif stat.kind == "Local" then
-- More initializers than names is not handled on this path.
if #stat.values > #stat.names then
return false
end
local lastValue = stat.values[#stat.values]
-- Fewer initializers than names with a multi-value tail (call, method call or
-- vararg): use the regular statement compiler, temporarily raising the debug
-- scope depth by `depth + 1` for the locals it declares.
if #stat.values < #stat.names and lastValue and (lastValue.kind == "Call" or lastValue.kind == "MethodCall" or lastValue.kind == "Vararg") then
local oldDebugScopeDepth = ctx.debugScopeDepth
ctx.debugScopeDepth = (oldDebugScopeDepth or 0) + depth + 1
self:compileStatement(ctx, stat, false)
ctx.debugScopeDepth = oldDebugScopeDepth
continue
end
-- First attempt: describe every declared name as a constant local via
-- makeConstantLocal. No register is reserved here on this path.
local constantLocals = {}
local canFoldAllLocals = true
for index, name in ipairs(stat.names) do
local value = stat.values[index]
local written = stat.localSymbols and stat.localSymbols[index] and stat.localSymbols[index].written == true or false
local localInfo = self:makeConstantLocal(ctx, name, value, written)
if not localInfo then
canFoldAllLocals = false
break
end
localInfo.debugDepth = (ctx.debugScopeDepth or 0) + depth + 1
append(constantLocals, localInfo)
end
if canFoldAllLocals then
-- All names folded: declare them together, sharing one debug start label.
local debugStart = ctx.builder:label(ctx.func)
for _, localInfo in ipairs(constantLocals) do
localInfo.debugStart = debugStart
self:declareLocal(ctx, localInfo)
end
continue
end
-- Fallback: reserve one register per name and evaluate the initializers.
local locals = {}
for index, name in ipairs(stat.names) do
local value = stat.values[index]
local written = stat.localSymbols and stat.localSymbols[index] and stat.localSymbols[index].written == true or false
local localInfo = {
name = name,
reg = self:reserve(ctx, 1),
written = written,
debugDepth = (ctx.debugScopeDepth or 0) + depth + 1,
typeId = self:getInitializerBytecodeType(ctx, stat.values, index, #stat.names),
}
-- Remember function initializers so calls through this local can be inlined.
if value and value.kind == "Function" then
localInfo.initFunction = value
localInfo.inlineFunction = value
end
append(locals, localInfo)
end
if #locals > 0 then
if #stat.values == 0 then
-- `local a, b` without initializers: explicitly nil each register.
for _, localInfo in ipairs(locals) do
ctx.builder:emitABC(ctx.func, LOP.LOADNIL, localInfo.reg, 0, 0, stat.line or 1)
self:useReg(ctx, localInfo.reg)
end
else
self:compileExprList(ctx, stat.values, locals[1].reg, #locals)
end
-- At optimization level >= 1, record a constant value on locals that are
-- never written (neither via their symbol nor in ctx.writeNames). The
-- register has already been initialized above.
for index, localInfo in ipairs(locals) do
local value = stat.values[index]
-- True when this name has no initializer of its own and is known to
-- receive nil: either there are no initializers at all, or it lies past
-- the last initializer and that initializer is not multi-valued.
local missingValueIsNil = value == nil
and (#stat.values == 0
or index > #stat.values
and #stat.values > 0
and stat.values[#stat.values].kind ~= "Call"
and stat.values[#stat.values].kind ~= "MethodCall"
and stat.values[#stat.values].kind ~= "Vararg")
if (ctx.builder.options.optimizationLevel or 1) >= 1
and not localInfo.written
and not (ctx.writeNames and ctx.writeNames[localInfo.name] == true)
and (value ~= nil or missingValueIsNil)
then
local constant = self:getConstant(ctx, value)
if constant then
localInfo.constKind = constant.kind
localInfo.constValue = constant.value
end
end
end
end
local debugStart = ctx.builder:label(ctx.func)
for _, localInfo in ipairs(locals) do
localInfo.debugStart = debugStart
self:declareLocal(ctx, localInfo)
end
elseif stat.kind == "ForNumeric" then
-- Fully unroll when getNumericForLoopUnroll permits it. Each iteration gets
-- its own scope in which the loop variable is the Number constant
-- `from + iteration * step`. Otherwise compile the loop normally.
local unroll = self:getNumericForLoopUnroll(ctx, stat)
if unroll then
for iteration = 0, unroll.iterations - 1 do
local mark = self:enterScope(ctx)
self:declareLocal(ctx, {
name = stat.name,
constKind = "Number",
constValue = unroll.from + iteration * unroll.step,
written = false,
debugStart = ctx.builder:label(ctx.func),
})
self:compileBlock(ctx, stat.body, false)
self:leaveScope(ctx, mark)
end
else
self:compileStatement(ctx, stat, false)
end
elseif stat.kind == "ForIn" then
self:compileStatement(ctx, stat, false)
elseif stat.kind == "If" then
self:compileStatement(ctx, stat, false)
else
-- Any other effect is treated as an assignment to a register local; only
-- stat.targets[1] and stat.values[1] are used.
-- NOTE(review): multi-target assignments would be reduced to their first
-- target here - confirm the summary never produces them.
local target = stat.targets[1]
local localInfo = self:findLocal(ctx, target.name)
if not localInfo or localInfo.reg == nil then
return false
end
-- The local is being reassigned, so its import path and constant value are
-- no longer valid.
-- NOTE(review): this happens, and coverage is emitted, before the operator
-- is validated below, so the `..`/unknown-operator bail-outs leave these
-- side effects behind.
localInfo.constPath = nil
localInfo.constKind = nil
if (self.builder.options.coverageLevel or 0) >= 1 then
self:emitCoverage(ctx, stat.line or 1)
end
if stat.op == "=" then
self:compileExpr(ctx, stat.values[1], localInfo.reg, false)
elseif stat.op ~= ".." then
-- Compound arithmetic assignment (`+=`, `-=`, ...).
local op = arithmeticOps[stat.op]
if not op then
return false
end
local line = stat.line or 1
local value = stat.values[1]
local rightConstant = (self.builder.options.optimizationLevel or 1) >= 1 and self:getConstant(ctx, value) or nil
if rightConstant and rightConstant.kind == "Number" and arithmeticKOps[stat.op] then
-- Use the constant-operand opcode when the constant index fits in 8
-- bits; otherwise load the right operand into a scratch register.
local constant = ctx.builder:addConstantNumber(ctx.func, rightConstant.value)
if constant <= 255 then
ctx.builder:emitABC(ctx.func, arithmeticKOps[stat.op], localInfo.reg, localInfo.reg, constant, line)
else
local baseReg = ctx.nextReg
self:compileExpr(ctx, value, baseReg)
ctx.builder:emitABC(ctx.func, op, localInfo.reg, localInfo.reg, baseReg, line)
end
else
local baseReg = ctx.nextReg
local rightReg = self:compileExprAsSource(ctx, value or node("Nil", nil, line), baseReg)
ctx.builder:emitABC(ctx.func, op, localInfo.reg, localInfo.reg, rightReg, line)
end
else
-- `..=` is not supported on this path.
return false
end
end
end
return true
end
-- Tuning knobs for the O2 inliner:
--   O2_INLINE_THRESHOLD           base cost budget for an inlined body
--   O2_INLINE_THRESHOLD_MAX_BOOST upper bound (in percent) on the profit boost
--   O2_INLINE_DEPTH_LIMIT         maximum nesting of inlined calls
--   O2_INLINE_COST_LIMIT          saturation value for all cost arithmetic
local O2_INLINE_THRESHOLD = 25
local O2_INLINE_THRESHOLD_MAX_BOOST = 300
local O2_INLINE_DEPTH_LIMIT = 5
local O2_INLINE_COST_LIMIT = 127

-- Sum of two inline costs, clamped to O2_INLINE_COST_LIMIT.
local function addInlineCost(a, b)
	return math.min(a + b, O2_INLINE_COST_LIMIT)
end

-- Product of two inline costs (e.g. body cost times trip count), clamped to
-- O2_INLINE_COST_LIMIT.
local function mulInlineCost(a, b)
	return math.min(a * b, O2_INLINE_COST_LIMIT)
end
-- Fold `expr` to a compile-time constant for inline cost estimation.
-- Unlike getConstant, names of the function being costed are resolved first
-- through `constants` (call-site values of its parameters). Any other name
-- present in `locals` is treated as unknown.
-- Returns a constant descriptor `{ kind = ..., value = ... }` or nil when the
-- expression cannot be folded. A missing (nil) expression folds to Nil.
function PureCompiler:getInlineO2CostConstant(ctx, expr, locals, constants)
	if expr == nil then
		return { kind = "Nil" }
	elseif expr.kind == "Name" then
		if constants and constants[expr.name] then
			return constants[expr.name]
		end
		if locals and locals[expr.name] then
			-- A local of the inlinee whose value is not known here.
			return nil
		end
		return self:getConstant(ctx, expr)
	elseif expr.kind == "SingleResult" then
		return self:getInlineO2CostConstant(ctx, expr.expr, locals, constants)
	elseif expr.kind == "Nil" then
		return { kind = "Nil" }
	elseif expr.kind == "Bool" then
		return { kind = "Bool", value = expr.value }
	elseif expr.kind == "Number" then
		return { kind = "Number", value = expr.value }
	elseif expr.kind == "Integer" then
		return { kind = "Integer", value = expr.value }
	elseif expr.kind == "String" then
		return { kind = "String", value = expr.value }
	elseif expr.kind == "Un" then
		local value = self:getInlineO2CostConstant(ctx, expr.expr, locals, constants)
		if value then
			if expr.op == "not" then
				return { kind = "Bool", value = not self:constantTruth(value) }
			elseif expr.op == "-" and value.kind == "Number" then
				return { kind = "Number", value = -value.value }
			elseif expr.op == "#" and value.kind == "String" then
				return { kind = "Number", value = #value.value }
			end
		end
	elseif expr.kind == "Bin" then
		local left = self:getInlineO2CostConstant(ctx, expr.left, locals, constants)
		if expr.op == "and" then
			-- A known falsy left operand is the result regardless of the right side.
			if left and not self:constantTruth(left) then
				return left
			end
			local right = self:getInlineO2CostConstant(ctx, expr.right, locals, constants)
			return left and right or nil
		elseif expr.op == "or" then
			-- A known truthy left operand is the result regardless of the right side.
			if left and self:constantTruth(left) then
				return left
			end
			local right = self:getInlineO2CostConstant(ctx, expr.right, locals, constants)
			return left and right or nil
		end
		local right = self:getInlineO2CostConstant(ctx, expr.right, locals, constants)
		if not left or not right then
			return nil
		end
		if expr.op == "==" or expr.op == "~=" then
			local equal = false
			if left.kind == right.kind then
				if left.kind == "Integer" then
					-- Integer constants carry a table value; compare their `key` fields.
					equal = left.value.key == right.value.key or left.value == right.value
				else
					equal = left.value == right.value
				end
			end
			-- Explicit branch: the former `op == "~=" and not equal or equal` form
			-- evaluated to true for `~=` on equal operands, because `not equal` is
			-- false and the `or` then fell through to `equal`.
			if expr.op == "~=" then
				equal = not equal
			end
			return { kind = "Bool", value = equal }
		elseif expr.op == "<" or expr.op == "<=" or expr.op == ">" or expr.op == ">=" then
			if left.kind == "Number" and right.kind == "Number" or left.kind == "String" and right.kind == "String" then
				local result
				if expr.op == "<" then
					result = left.value < right.value
				elseif expr.op == "<=" then
					result = left.value <= right.value
				elseif expr.op == ">" then
					result = left.value > right.value
				else
					result = left.value >= right.value
				end
				return { kind = "Bool", value = result }
			end
		elseif expr.op == ".." and left.kind == "String" and right.kind == "String" then
			return { kind = "String", value = left.value .. right.value }
		elseif left.kind == "Number" and right.kind == "Number" then
			if expr.op == "+" then
				return { kind = "Number", value = left.value + right.value }
			elseif expr.op == "-" then
				return { kind = "Number", value = left.value - right.value }
			elseif expr.op == "*" then
				return { kind = "Number", value = left.value * right.value }
			elseif expr.op == "/" then
				return { kind = "Number", value = left.value / right.value }
			elseif expr.op == "%" then
				return { kind = "Number", value = left.value % right.value }
			elseif expr.op == "^" then
				return { kind = "Number", value = left.value ^ right.value }
			end
		end
	elseif expr.kind == "Call" then
		-- Builtin calls whose arguments all fold can themselves be folded.
		local builtinId = self:getBuiltinIdForCallee(ctx, expr.callee)
		if builtinId then
			local args = {}
			local allConstant = true
			for index, arg in ipairs(expr.args) do
				local constant = self:getInlineO2CostConstant(ctx, arg, locals, constants)
				if not constant then
					allConstant = false
					break
				end
				args[index] = constant
			end
			if allConstant then
				if #args == 1 and (builtinId == LBF.TYPE or builtinId == LBF.TYPEOF) then
					local argument = args[1]
					local typeName = nil
					if argument.kind == "Nil" then
						typeName = "nil"
					elseif argument.kind == "Bool" then
						typeName = "boolean"
					elseif argument.kind == "Number" then
						typeName = "number"
					elseif argument.kind == "Integer" then
						typeName = "integer"
					elseif argument.kind == "String" then
						typeName = "string"
					end
					if typeName then
						return { kind = "String", value = typeName }
					end
				end
				return foldO2BuiltinConstant(builtinId, args)
			end
		end
	else
		return self:getConstant(ctx, expr)
	end
	return nil
end
-- Estimate the cost (roughly, instructions) of evaluating `expr` inside an
-- inlined body. Anything getInlineO2CostConstant can fold costs 0. All sums
-- saturate through addInlineCost / mulInlineCost.
-- `locals` is the set of names declared by the inlinee (parameters and locals
-- seen so far); `constants` maps parameter names to call-site constants.
-- A nil expression costs 0; unrecognized expression kinds cost 1.
function PureCompiler:costInlineO2Expr(ctx, expr, locals, constants)
if expr == nil then
return 0
end
if self:getInlineO2CostConstant(ctx, expr, locals, constants) then
return 0
end
if expr.kind == "Nil" or expr.kind == "Bool" or expr.kind == "Number" or expr.kind == "Integer" or expr.kind == "String" then
return 0
elseif expr.kind == "SingleResult" or expr.kind == "Instantiate" then
return self:costInlineO2Expr(ctx, expr.expr, locals, constants)
elseif expr.kind == "Name" then
-- Inlinee locals, caller locals and parent constants are free; any other name
-- costs 1 (presumably a global access).
if locals[expr.name] or self:findLocal(ctx, expr.name) or self:findParentConstant(ctx, expr.name) then
return 0
end
return 1
elseif expr.kind == "Vararg" then
return 3
elseif expr.kind == "Call" then
-- Builtin calls have base cost 2; other calls cost 3 plus the callee. A
-- zero-cost argument still costs 1, except for builtins with at most two
-- arguments.
local builtinId = self:getBuiltinIdForCallee(ctx, expr.callee)
local builtin = builtinId ~= nil
local builtinShort = builtin and #expr.args <= 2
local cost = builtin and 2 or addInlineCost(3, self:costInlineO2Expr(ctx, expr.callee, locals, constants))
for _, arg in ipairs(expr.args) do
local argCost = self:costInlineO2Expr(ctx, arg, locals, constants)
if argCost == 0 and not builtinShort then
cost = addInlineCost(cost, 1)
else
cost = addInlineCost(cost, argCost)
end
end
return cost
elseif expr.kind == "MethodCall" then
local cost = addInlineCost(4, self:costInlineO2Expr(ctx, expr.object, locals, constants))
for _, arg in ipairs(expr.args) do
cost = addInlineCost(cost, self:costInlineO2Expr(ctx, arg, locals, constants))
end
return cost
elseif expr.kind == "Field" then
return addInlineCost(self:costInlineO2Expr(ctx, expr.object, locals, constants), 1)
elseif expr.kind == "Index" then
return addInlineCost(addInlineCost(self:costInlineO2Expr(ctx, expr.object, locals, constants), self:costInlineO2Expr(ctx, expr.index, locals, constants)), 1)
elseif expr.kind == "Function" then
return 10
elseif expr.kind == "Table" then
-- Base 10, plus each entry's key and value, plus 1 per entry.
local cost = 10
for _, entry in ipairs(expr.entries) do
if entry.key then
cost = addInlineCost(cost, self:costInlineO2Expr(ctx, entry.key, locals, constants))
end
cost = addInlineCost(cost, self:costInlineO2Expr(ctx, entry.value, locals, constants))
cost = addInlineCost(cost, 1)
end
return cost
elseif expr.kind == "Un" then
return self:costInlineO2Expr(ctx, expr.expr, locals, constants)
elseif expr.kind == "Bin" then
return addInlineCost(addInlineCost(self:costInlineO2Expr(ctx, expr.left, locals, constants), self:costInlineO2Expr(ctx, expr.right, locals, constants)), 1)
elseif expr.kind == "IfExpr" then
-- NOTE(review): only the first clause's condition and the else value are
-- costed. Clause values and any `elseif` clauses are ignored, which
-- underestimates the cost of if-expressions. Confirm against the reference
-- cost model before relying on it.
return addInlineCost(
addInlineCost(self:costInlineO2Expr(ctx, expr.clauses[1] and expr.clauses[1].condition, locals, constants), self:costInlineO2Expr(ctx, expr.elseValue, locals, constants)),
2
)
elseif expr.kind == "InterpString" then
local cost = 3
for _, value in ipairs(expr.expressions) do
cost = addInlineCost(cost, self:costInlineO2Expr(ctx, value, locals, constants))
end
return cost
end
return 1
end
-- Estimate the cost of executing `block` inside an inlined body by summing its
-- statements (saturating via addInlineCost). Counting stops after the first
-- Return.
-- `locals` is mutated in place: names declared by Local / LocalFunction in this
-- block are added to it. Nested blocks receive shallow copies instead.
-- `constants` maps parameter names to call-site constants and lets If
-- conditions and numeric-for bounds fold.
function PureCompiler:costInlineO2Block(ctx, block, locals, constants)
local cost = 0
for _, stat in ipairs(block.body) do
if stat.kind == "Nop" then
-- skip
elseif stat.kind == "Local" then
for _, value in ipairs(stat.values) do
cost = addInlineCost(cost, self:costInlineO2Expr(ctx, value, locals, constants))
end
for _, name in ipairs(stat.names) do
locals[name] = true
end
elseif stat.kind == "LocalFunction" then
cost = addInlineCost(cost, 10)
locals[stat.name] = true
elseif stat.kind == "Assign" then
-- An assignment costs at least 1 even when every operand is free.
local assignCost = 0
for _, target in ipairs(stat.targets) do
assignCost = addInlineCost(assignCost, self:costInlineO2Expr(ctx, target, locals, constants))
end
for _, value in ipairs(stat.values) do
assignCost = addInlineCost(assignCost, self:costInlineO2Expr(ctx, value, locals, constants))
end
cost = addInlineCost(cost, assignCost == 0 and 1 or assignCost)
elseif stat.kind == "CallStat" then
cost = addInlineCost(cost, self:costInlineO2Expr(ctx, stat.expr, locals, constants))
elseif stat.kind == "Return" then
for _, value in ipairs(stat.values) do
cost = addInlineCost(cost, self:costInlineO2Expr(ctx, value, locals, constants))
end
break
elseif stat.kind == "Do" then
cost = addInlineCost(cost, self:costInlineO2Block(ctx, stat.body, shallowCopy(locals), constants))
elseif stat.kind == "If" then
-- Clauses whose condition folds are resolved statically: the first truthy
-- one selects its body and stops the scan, and falsy ones are dropped.
-- Clauses with unknown conditions pay for both condition and body.
local selectedBody = nil
local knownCondition = false
for _, clause in ipairs(stat.clauses) do
local condition = self:getInlineO2CostConstant(ctx, clause.condition, locals, constants)
if condition then
knownCondition = true
if self:constantTruth(condition) then
selectedBody = clause.body
break
end
else
cost = addInlineCost(cost, self:costInlineO2Expr(ctx, clause.condition, locals, constants))
cost = addInlineCost(cost, self:costInlineO2Block(ctx, clause.body, shallowCopy(locals), constants))
end
end
if knownCondition then
if selectedBody then
cost = addInlineCost(cost, self:costInlineO2Block(ctx, selectedBody, shallowCopy(locals), constants))
elseif stat.elseBody then
cost = addInlineCost(cost, self:costInlineO2Block(ctx, stat.elseBody, shallowCopy(locals), constants))
end
else
-- No condition folded: add 1, plus 1 more when there is an else branch,
-- then the else body itself.
cost = addInlineCost(cost, 1 + (stat.elseBody and 1 or 0))
if stat.elseBody then
cost = addInlineCost(cost, self:costInlineO2Block(ctx, stat.elseBody, shallowCopy(locals), constants))
end
end
elseif stat.kind == "ForNumeric" then
-- Loop header cost plus (body + 1) times the trip count. The trip count is
-- exact when from/to/step all fold to Numbers, otherwise it is assumed to be 3.
local fromCost = self:costInlineO2Expr(ctx, stat.from, locals, constants)
local toCost = self:costInlineO2Expr(ctx, stat.to, locals, constants)
local stepCost = stat.step and self:costInlineO2Expr(ctx, stat.step, locals, constants) or 0
local loopCost = addInlineCost(addInlineCost(fromCost, toCost), stepCost)
local factor = 3
local fromConstant = self:getInlineO2CostConstant(ctx, stat.from, locals, constants)
local toConstant = self:getInlineO2CostConstant(ctx, stat.to, locals, constants)
local stepConstant = stat.step and self:getInlineO2CostConstant(ctx, stat.step, locals, constants) or { kind = "Number", value = 1 }
if fromConstant and toConstant and stepConstant and fromConstant.kind == "Number" and toConstant.kind == "Number" and stepConstant.kind == "Number" then
factor = self:getNumericForTripCount(fromConstant.value, toConstant.value, stepConstant.value) or factor
end
local nested = shallowCopy(locals)
nested[stat.name] = true
cost = addInlineCost(cost, addInlineCost(loopCost, mulInlineCost(addInlineCost(self:costInlineO2Block(ctx, stat.body, nested, constants), 1), factor)))
elseif stat.kind == "ForIn" then
-- Iterator expressions plus (body + 1) times an assumed 3 iterations.
local valuesCost = 0
for _, value in ipairs(stat.values) do
valuesCost = addInlineCost(valuesCost, self:costInlineO2Expr(ctx, value, locals, constants))
end
local nested = shallowCopy(locals)
for _, name in ipairs(stat.names) do
nested[name] = true
end
cost = addInlineCost(cost, addInlineCost(valuesCost, mulInlineCost(addInlineCost(self:costInlineO2Block(ctx, stat.body, nested, constants), 1), 3)))
elseif stat.kind == "While" or stat.kind == "Repeat" then
-- (body + condition) times an assumed 3 iterations.
local bodyCost = self:costInlineO2Block(ctx, stat.body, shallowCopy(locals), constants)
local conditionCost = self:costInlineO2Expr(ctx, stat.condition, locals, constants)
cost = addInlineCost(cost, mulInlineCost(addInlineCost(bodyCost, conditionCost), 3))
else
cost = addInlineCost(cost, 1)
end
end
return cost
end
-- Decide whether inlining `call` (whose callee resolves to `funcExpr`) pays off
-- at O2.
-- Every parameter is registered as an inlinee local. Parameters whose argument
-- is missing (nil) or folds via getConstant are also seeded as constants, so
-- the cost model can fold through them. The body is accepted when its cost
-- stays within O2_INLINE_THRESHOLD scaled by a profit percentage:
-- 100 * (cost + 3) / cost, capped at O2_INLINE_THRESHOLD_MAX_BOOST, and equal
-- to that cap for a zero-cost body. The extra 3 presumably models the call
-- that inlining removes.
function PureCompiler:shouldInlineO2Call(ctx, call, funcExpr, depth)
	if depth >= O2_INLINE_DEPTH_LIMIT then
		return false
	end
	local knownLocals = {}
	local knownConstants = {}
	for position, paramName in ipairs(funcExpr.params) do
		knownLocals[paramName] = true
		local argument = call.args[position]
		if argument == nil then
			knownConstants[paramName] = { kind = "Nil" }
		else
			local folded = self:getConstant(ctx, argument)
			if folded then
				knownConstants[paramName] = folded
			end
		end
	end
	local bodyCost = self:costInlineO2Block(ctx, funcExpr.body, knownLocals, knownConstants)
	local profitPercent
	if bodyCost == 0 then
		profitPercent = O2_INLINE_THRESHOLD_MAX_BOOST
	else
		local callCost = addInlineCost(bodyCost, 3)
		profitPercent = math.min(O2_INLINE_THRESHOLD_MAX_BOOST, math.floor(100 * callCost / bodyCost))
	end
	local budget = math.floor(O2_INLINE_THRESHOLD * profitPercent / 100)
	return bodyCost <= budget
end
-- Compile a `return` inside an inlined body.
-- The first `frame.wantedResults` return values are evaluated into the
-- registers starting at `frame.target`; missing ones are set to nil. Surplus
-- values are still evaluated for their side effects, using `baseReg` as
-- scratch. Upvalues of locals above `frame.localOffset` are then closed and a
-- placeholder JUMP is emitted. Its label is appended to `frame.returnJumps`
-- so the caller can patch it to the end of the inlined code.
function PureCompiler:compileInlineO2Return(ctx, stat, frame, baseReg, line)
	local builder, func = ctx.builder, ctx.func
	local values = stat.values
	local resultCount = frame.wantedResults
	local firstSurplus = 1
	if resultCount > 0 then
		firstSurplus = resultCount + 1
	end
	for slot = 0, resultCount - 1 do
		local reg = frame.target + slot
		local value = values[slot + 1]
		if value == nil then
			builder:emitABC(func, LOP.LOADNIL, reg, 0, 0, line)
			self:useReg(ctx, reg)
		else
			self:compileExpr(ctx, value, reg)
		end
	end
	for surplus = firstSurplus, #values do
		self:compileExprSide(ctx, values[surplus], baseReg)
	end
	-- Attribute the jump to the line of the last emitted instruction, if any.
	local jumpLine = func.lines[#func.lines] or line
	self:emitCloseUpvals(ctx, jumpLine, (frame.localOffset or 0) + 1)
	local jumpLabel = builder:label(func)
	builder:emitAD(func, LOP.JUMP, 0, 0, jumpLine)
	append(frame.returnJumps, jumpLabel)
end
-- Inline `call` by compiling the whole body of `funcExpr` in the caller's frame.
-- This is used when the function cannot be summarized into effects + returns.
-- While the body compiles, ctx.inlineReturnFrame describes where results go.
-- Presumably `return` statements compiled under it route through
-- compileInlineO2Return, which appends a JUMP to frame.returnJumps.
-- Results land in `target .. target + wantedResults - 1`. A negative
-- `wantedResults` is rejected (returns false), as is a failure to set up the
-- arguments. Returns true when the call was inlined.
function PureCompiler:compileInlineO2CallGeneric(ctx, call, funcExpr, target, wantedResults, depth)
if wantedResults < 0 then
return false
end
local argMark = self:compileInlineO2Arguments(ctx, funcExpr, call, target, wantedResults, depth)
if not argMark then
return false
end
local effectLocalCount = #ctx.localList
local oldInlineReturnFrame = ctx.inlineReturnFrame
local frame = {
target = target,
wantedResults = wantedResults,
localOffset = effectLocalCount,
returnJumps = {},
}
ctx.inlineReturnFrame = frame
-- Mark the function as active so it is not inlined into itself.
append(self.inlineO2Stack, funcExpr)
local terminated = self:compileBlock(ctx, funcExpr.body, false, false)
self.inlineO2Stack[#self.inlineO2Stack] = nil
ctx.inlineReturnFrame = oldInlineReturnFrame
if terminated then
-- If the very last instruction is the final return JUMP, drop it: execution
-- falls through to the return label anyway.
local lastJump = frame.returnJumps[#frame.returnJumps]
if lastJump ~= nil and lastJump == ctx.builder:label(ctx.func) - 1 then
local lastInsn = ctx.func.insns[#ctx.func.insns]
if lastInsn and lastInsn % 256 == LOP.JUMP then
table.remove(ctx.func.insns)
table.remove(ctx.func.lines)
frame.returnJumps[#frame.returnJumps] = nil
end
end
elseif wantedResults > 0 then
-- The body can fall off its end without returning: the results are nil.
for index = 1, wantedResults do
local reg = target + index - 1
ctx.builder:emitABC(ctx.func, LOP.LOADNIL, reg, 0, 0, call.line or 1)
self:useReg(ctx, reg)
end
end
-- Point every remaining return JUMP at the code following the inlined body.
local returnLabel = ctx.builder:label(ctx.func)
for _, jump in ipairs(frame.returnJumps) do
ctx.builder:patchJump(ctx.func, jump, returnLabel)
end
self:leaveInlineO2Frame(ctx, argMark)
return true
end
-- Try to inline `call` at O2, writing its results into
-- `target .. target + wantedResults - 1`.
-- Returns true when the call was compiled inline, or false when the caller
-- must emit a regular call instead.
-- Two strategies are used. Functions that summarizeInlineO2Function reduces to
-- effects plus returned expressions are compiled from that summary. Otherwise
-- compileInlineO2CallGeneric compiles the whole body.
-- `depth` (nil means 0) is bounded by O2_INLINE_DEPTH_LIMIT. A function already
-- on self.inlineO2Stack is never inlined into itself.
function PureCompiler:compileInlineO2Call(ctx, call, target, wantedResults, depth)
	depth = depth or 0
	if depth >= O2_INLINE_DEPTH_LIMIT then
		return false
	end
	local funcExpr = self:getInlineO2Function(ctx, call)
	if not funcExpr then
		return false
	end
	self.inlineO2Stack = self.inlineO2Stack or {}
	for _, active in ipairs(self.inlineO2Stack) do
		if active == funcExpr then
			return false
		end
	end
	-- A negative wantedResults (presumably the multret convention;
	-- compileInlineO2CallGeneric rejects it too) cannot be satisfied by inlined
	-- code, which only fills a fixed set of registers. Without this check the
	-- summary path below reported success while dropping every return value.
	if wantedResults < 0 then
		return false
	end
	if not self:shouldInlineO2Call(ctx, call, funcExpr, depth) then
		return false
	end
	local initialSummary = self:summarizeInlineO2Function(funcExpr, nil)
	if not initialSummary then
		return self:compileInlineO2CallGeneric(ctx, call, funcExpr, target, wantedResults, depth)
	end
	local argMark = self:compileInlineO2Arguments(ctx, funcExpr, call, target, wantedResults, depth)
	if not argMark then
		return false
	end
	-- Re-summarize now that the arguments are bound in ctx.
	local summary = self:summarizeInlineO2Function(funcExpr, ctx)
	if not summary then
		self:leaveInlineO2Arguments(ctx, argMark)
		return false
	end
	local effectLocalCount = #ctx.localList
	append(self.inlineO2Stack, funcExpr)
	if not self:compileInlineO2Effects(ctx, summary, depth) then
		self.inlineO2Stack[#self.inlineO2Stack] = nil
		self:popLocals(ctx, effectLocalCount)
		self:leaveInlineO2Arguments(ctx, argMark)
		return false
	end
	if wantedResults == 0 then
		-- No results wanted: evaluate returned expressions for side effects only.
		for _, value in ipairs(summary.returns) do
			self:compileExprSide(ctx, value, target)
		end
		self.inlineO2Stack[#self.inlineO2Stack] = nil
		self:leaveInlineO2Frame(ctx, argMark)
		return true
	end
	-- `return g(...)`: try to inline the tail call straight into the target.
	if #summary.returns == 1 and summary.returns[1].kind == "Call" then
		if self:compileInlineO2Call(ctx, summary.returns[1], target, wantedResults, depth + 1) then
			self.inlineO2Stack[#self.inlineO2Stack] = nil
			self:leaveInlineO2Frame(ctx, argMark)
			return true
		end
	end
	for index = 1, wantedResults do
		local value = summary.returns[index]
		local reg = target + index - 1
		if value then
			self:compileExpr(ctx, value, reg)
		else
			-- Missing return value: nil-fill and record the register use, as every
			-- other LOADNIL emission in the inliner does.
			ctx.builder:emitABC(ctx.func, LOP.LOADNIL, reg, 0, 0, call.line or 1)
			self:useReg(ctx, reg)
		end
	end
	self.inlineO2Stack[#self.inlineO2Stack] = nil
	self:leaveInlineO2Frame(ctx, argMark)
	return true
end
-- Number of instructions needed to load the callee `expr`. This is one load,
-- plus one more when it resolves to an import path, plus one more when
-- coverageLevel >= 2.
function PureCompiler:getCalleeInstructionCount(ctx, expr)
	local count = 1
	if self:getImportPath(ctx, expr) then
		count += 1
	end
	if (self.builder.options.coverageLevel or 0) >= 2 then
		count += 1
	end
	return count
end
-- Append the operands of a `..` chain to `out`, left to right. Nested `..`
-- nodes on either side are flattened recursively; any other expression is
-- appended as a single operand.
function PureCompiler:flattenConcat(expr, out)
	if expr.kind ~= "Bin" or expr.op ~= ".." then
		append(out, expr)
		return
	end
	self:flattenConcat(expr.left, out)
	self:flattenConcat(expr.right, out)
end
-- Report whether evaluating `expr` reads a register-allocated local of the
-- enclosing (parent) function. That is a `Name` found in `ctx.locals` with a
-- non-nil `reg` and not shadowed by an entry in `locals`.
-- `locals` is the set of names declared inside the code being inspected; it is
-- only read here. A nil `expr` yields false. Nested function expressions are
-- not inspected and report false.
function PureCompiler:exprReferencesParentRegister(ctx, expr, locals)
	if expr == nil then
		return false
	end
	local kind = expr.kind
	if kind == "Name" then
		if locals[expr.name] then
			return false
		end
		local localInfo = ctx.locals[expr.name]
		return localInfo ~= nil and localInfo.reg ~= nil
	elseif kind == "SingleResult" or kind == "Instantiate" then
		-- Wrapper nodes evaluate their inner expression. Previously these fell
		-- through to `false`, hiding wrapped references.
		return self:exprReferencesParentRegister(ctx, expr.expr, locals)
	elseif kind == "Field" then
		return self:exprReferencesParentRegister(ctx, expr.object, locals)
	elseif kind == "Index" then
		return self:exprReferencesParentRegister(ctx, expr.object, locals) or self:exprReferencesParentRegister(ctx, expr.index, locals)
	elseif kind == "Call" then
		if self:exprReferencesParentRegister(ctx, expr.callee, locals) then
			return true
		end
		for _, arg in ipairs(expr.args) do
			if self:exprReferencesParentRegister(ctx, arg, locals) then
				return true
			end
		end
	elseif kind == "MethodCall" then
		if self:exprReferencesParentRegister(ctx, expr.object, locals) then
			return true
		end
		for _, arg in ipairs(expr.args) do
			if self:exprReferencesParentRegister(ctx, arg, locals) then
				return true
			end
		end
	elseif kind == "Function" then
		return false
	elseif kind == "Table" then
		for _, entry in ipairs(expr.entries) do
			if entry.key and self:exprReferencesParentRegister(ctx, entry.key, locals) then
				return true
			end
			if entry.value and self:exprReferencesParentRegister(ctx, entry.value, locals) then
				return true
			end
		end
	elseif kind == "Un" then
		return self:exprReferencesParentRegister(ctx, expr.expr, locals)
	elseif kind == "Bin" then
		return self:exprReferencesParentRegister(ctx, expr.left, locals) or self:exprReferencesParentRegister(ctx, expr.right, locals)
	elseif kind == "IfExpr" then
		for _, clause in ipairs(expr.clauses) do
			if self:exprReferencesParentRegister(ctx, clause.condition, locals) or self:exprReferencesParentRegister(ctx, clause.value, locals) then
				return true
			end
		end
		return self:exprReferencesParentRegister(ctx, expr.elseValue, locals)
	elseif kind == "InterpString" then
		-- Interpolated expressions are evaluated in the current frame and may read
		-- parent registers.
		for _, value in ipairs(expr.expressions) do
			if self:exprReferencesParentRegister(ctx, value, locals) then
				return true
			end
		end
	end
	return false
end
-- Report whether any statement in `block` reads a register-allocated parent
-- local, as decided by exprReferencesParentRegister.
-- `locals` is mutated in place: names declared by Local / LocalFunction are
-- added as they are encountered, so later statements see them as shadowing.
-- Nested blocks get shallow copies. The one exception is Repeat, whose
-- condition is checked against the body's copy so body locals stay in scope
-- for `until`.
-- LocalFunction bodies are not inspected, and for FunctionStat only the
-- target expression is checked.
function PureCompiler:blockReferencesParentRegister(ctx, block, locals)
for _, stat in ipairs(block.body) do
if stat.kind == "Local" then
-- Initializers are checked before the new names come into scope.
for _, value in ipairs(stat.values) do
if self:exprReferencesParentRegister(ctx, value, locals) then
return true
end
end
for _, name in ipairs(stat.names) do
locals[name] = true
end
elseif stat.kind == "LocalFunction" then
locals[stat.name] = true
elseif stat.kind == "Assign" then
for _, target in ipairs(stat.targets) do
if self:exprReferencesParentRegister(ctx, target, locals) then
return true
end
end
for _, value in ipairs(stat.values) do
if self:exprReferencesParentRegister(ctx, value, locals) then
return true
end
end
elseif stat.kind == "CallStat" then
if self:exprReferencesParentRegister(ctx, stat.expr, locals) then
return true
end
elseif stat.kind == "Return" then
for _, value in ipairs(stat.values) do
if self:exprReferencesParentRegister(ctx, value, locals) then
return true
end
end
elseif stat.kind == "Do" then
local nested = shallowCopy(locals)
if self:blockReferencesParentRegister(ctx, stat.body, nested) then
return true
end
elseif stat.kind == "If" then
for _, clause in ipairs(stat.clauses) do
if self:exprReferencesParentRegister(ctx, clause.condition, locals) then
return true
end
local nested = shallowCopy(locals)
if self:blockReferencesParentRegister(ctx, clause.body, nested) then
return true
end
end
if stat.elseBody then
local nested = shallowCopy(locals)
if self:blockReferencesParentRegister(ctx, stat.elseBody, nested) then
return true
end
end
elseif stat.kind == "While" then
if self:exprReferencesParentRegister(ctx, stat.condition, locals) then
return true
end
local nested = shallowCopy(locals)
if self:blockReferencesParentRegister(ctx, stat.body, nested) then
return true
end
elseif stat.kind == "Repeat" then
-- `nested` collects the body's locals (in-place mutation), which the
-- `until` condition can see.
local nested = shallowCopy(locals)
if self:blockReferencesParentRegister(ctx, stat.body, nested) or self:exprReferencesParentRegister(ctx, stat.condition, nested) then
return true
end
elseif stat.kind == "ForNumeric" then
if self:exprReferencesParentRegister(ctx, stat.from, locals) or self:exprReferencesParentRegister(ctx, stat.to, locals) or stat.step and self:exprReferencesParentRegister(ctx, stat.step, locals) then
return true
end
local nested = shallowCopy(locals)
nested[stat.name] = true
if self:blockReferencesParentRegister(ctx, stat.body, nested) then
return true
end
elseif stat.kind == "ForIn" then
for _, value in ipairs(stat.values) do
if self:exprReferencesParentRegister(ctx, value, locals) then
return true
end
end
local nested = shallowCopy(locals)
for _, name in ipairs(stat.names) do
nested[name] = true
end
if self:blockReferencesParentRegister(ctx, stat.body, nested) then
return true
end
elseif stat.kind == "FunctionStat" then
if self:exprReferencesParentRegister(ctx, stat.target, locals) then
return true
end
end
end
return false
end
-- Report whether the body of function expression `expr` reads a
-- register-allocated local of the enclosing function. The function's own
-- parameters shadow parent names of the same name.
function PureCompiler:functionReferencesParentRegister(ctx, expr)
	local shadowed = {}
	local params = expr.params
	for position = 1, #params do
		shadowed[params[position]] = true
	end
	return self:blockReferencesParentRegister(ctx, expr.body, shadowed)
end
-- Compiles a single-valued expression `expr` so that its value ends up in register `target`.
--
-- ctx        : per-function compile state; this method reads ctx.builder, ctx.func,
--              ctx.nextReg, ctx.protectedTop and ctx.initializingLocalsByReg, and
--              temporarily raises (then restores) ctx.protectedTop in the ".." branch.
-- expr       : AST expression node; dispatch is on expr.kind.
-- target     : destination register.
-- targetTemp : anything other than an explicit `false` is treated as true. When false,
--              the Field, Table, Call/MethodCall and and/or branches avoid using `target`
--              as scratch while evaluating sub-expressions and MOVE the result in at the end.
--              NOTE(review): presumably because `target` may still be read by the
--              expression itself - confirm with callers.
-- Raises (error level 2) on an unsupported expression kind or binary operator.
-- Returns nothing; the result is only observable through the emitted bytecode.
function PureCompiler:compileExpr(ctx, expr, target, targetTemp)
local builder = ctx.builder
local func = ctx.func
local line = expr.line or 1
-- nil (argument omitted) counts as true; only an explicit false disables it.
targetTemp = targetTemp ~= false
-- Coverage level >= 2 records a coverage point for every expression.
if (builder.options.coverageLevel or 0) >= 2 then
self:emitCoverage(ctx, line)
end
-- With optimizations on, an expression that folds to a constant is emitted as that constant.
if (builder.options.optimizationLevel or 1) >= 1 then
local constant = self:getConstant(ctx, expr)
if constant then
self:captureParentConstantsForDebug(ctx, expr)
self:emitConstant(ctx, constant, target, line)
return
end
end
-- Import paths with more than one component (presumably dotted globals such as
-- `math.floor`) are emitted by compileImportPath.
local importPath = self:getImportPath(ctx, expr)
if importPath and #importPath > 1 then
self:compileImportPath(ctx, importPath, target, line)
return
end
if expr.kind == "Nil" then
builder:emitABC(func, LOP.LOADNIL, target, 0, 0, line)
self:useReg(ctx, target)
elseif expr.kind == "Bool" then
builder:emitABC(func, LOP.LOADB, target, expr.value and 1 or 0, 0, line)
self:useReg(ctx, target)
elseif expr.kind == "Number" then
self:emitLoadConstant(ctx, target, expr.value, line)
elseif expr.kind == "Integer" then
local cid = builder:addConstantInteger(func, expr.value)
self:emitLoadKIndex(ctx, target, cid, line)
elseif expr.kind == "String" then
self:emitLoadConstant(ctx, target, expr.value, line)
elseif expr.kind == "Instantiate" then
-- Instantiation wrappers are transparent: compile the wrapped expression.
self:compileExpr(ctx, expr.expr, target, targetTemp)
elseif expr.kind == "InterpString" then
-- Interpolated strings are lowered to either a constant or a format string plus arguments.
local lowered = self:lowerInterpString(ctx, expr)
if lowered.constant ~= nil then
self:emitLoadConstant(ctx, target, lowered.constant, line)
else
-- Emitted as `formatString:format(args...)`. The call frame goes above every live or
-- protected register, and its single result is moved down to `target` afterwards.
local resultTarget = math.max(target + 1, ctx.protectedTop or 0)
local callTarget = resultTarget
local stackTop = math.max(ctx.nextReg, ctx.protectedTop or 0, self:getLocalTop(ctx))
-- Relocate the frame to the stack top unless it already sits in the topmost register.
if stackTop > callTarget and callTarget ~= stackTop - 1 then
callTarget = stackTop
end
self:emitLoadConstant(ctx, callTarget, lowered.format, line)
-- Arguments start at callTarget + 2: NAMECALL writes the method to callTarget and
-- the `self` object (the format string) to callTarget + 1.
self:compileCallArguments(ctx, { args = lowered.args }, callTarget + 2, false, line)
local callLine = func.lines[#func.lines] or line
local constant = builder:addConstantString(func, "format")
-- NAMECALL's C operand is the method name hash reduced to 8 bits; the aux word holds
-- the name's constant index.
local slot = stringHash("format") % 256
builder:emitABC(func, LOP.NAMECALL, callTarget, callTarget, slot, callLine)
builder:emitAux(func, constant, callLine)
-- CALL B = argument count (self + #args) + 1, C = 2 (exactly one result).
builder:emitABC(func, LOP.CALL, callTarget, #lowered.args + 2, 2, callLine)
self:useReg(ctx, callTarget + #lowered.args + 1)
if resultTarget ~= callTarget then
builder:emitABC(func, LOP.MOVE, resultTarget, callTarget, 0, callLine)
self:useReg(ctx, resultTarget)
end
if resultTarget ~= target then
builder:emitABC(func, LOP.MOVE, target, resultTarget, 0, callLine)
self:useReg(ctx, target)
end
end
elseif expr.kind == "SingleResult" then
-- A multi-value expression truncated to one value.
if expr.expr.kind == "Call" or expr.expr.kind == "MethodCall" then
self:compileCall(ctx, expr.expr, target, 1)
elseif expr.expr.kind == "Vararg" then
-- GETVARARGS B = wanted count + 1, so 2 loads exactly one value.
builder:emitABC(func, LOP.GETVARARGS, target, 2, 0, line)
self:useReg(ctx, target)
else
self:compileExpr(ctx, expr.expr, target)
end
elseif expr.kind == "Vararg" then
builder:emitABC(func, LOP.GETVARARGS, target, 2, 0, line)
self:useReg(ctx, target)
elseif expr.kind == "Name" then
-- Resolution order: local, constant-valued parent local, upvalue, global.
local localInfo = self:findLocal(ctx, expr.name)
if localInfo then
if self:canUseLocalConstant(localInfo) then
self:emitLocalConstant(ctx, localInfo, target, line)
elseif (builder.options.optimizationLevel or 1) == 0 or localInfo.reg ~= target then
-- At -O0 the MOVE is emitted even when source and target coincide.
builder:emitABC(func, LOP.MOVE, target, localInfo.reg, 0, line)
self:useReg(ctx, target)
else
self:useReg(ctx, target)
end
else
local parentConstant = self:findParentConstant(ctx, expr.name)
if parentConstant then
-- NOTE(review): this reads self.builder, while the rest of the method uses
-- builder (= ctx.builder); confirm both refer to the same options table.
if (self.builder.options.debugLevel or 1) >= 2 then
self:rememberDebugConstantUpvalue(ctx, expr.name)
end
self:emitLocalConstant(ctx, parentConstant, target, line)
else
local upvalue = self:getUpvalue(ctx, expr.name)
if upvalue ~= nil then
builder:emitABC(func, LOP.GETUPVAL, target, upvalue, 0, line)
self:useReg(ctx, target)
else
self:emitGetGlobal(ctx, target, expr.name, line)
end
end
end
elseif expr.kind == "Field" then
-- `obj.name` -> GETTABLEKS with an 8-bit hash slot in C and the key constant in aux.
-- If target must not be scratch, the object is evaluated above it instead.
local objectTarget = targetTemp and target or math.max(target + 1, ctx.protectedTop or 0)
local objectReg = self:compileExprAsSource(ctx, expr.object, objectTarget)
local constant = builder:addConstantString(func, expr.field)
local slot = stringHash(expr.field) % 256
builder:emitABC(func, LOP.GETTABLEKS, target, objectReg, slot, line)
builder:emitAux(func, constant, line)
self:useReg(ctx, target)
self:hintTemporaryExprRegType(ctx, expr.object, objectReg, LBC_TYPE_TABLE, 2)
elseif expr.kind == "Index" then
-- `obj[key]`: a constant string key uses GETTABLEKS, an integer key 1..256 uses
-- GETTABLEN (C is 0-based), and anything else uses GETTABLE with a key register.
local objectTarget = math.max(target + 1, ctx.protectedTop or 0)
local objectReg = self:compileExprAsSource(ctx, expr.object, objectTarget)
local indexConstant = (builder.options.optimizationLevel or 1) >= 1 and self:getConstant(ctx, expr.index) or nil
if indexConstant and indexConstant.kind == "String" then
local constant = builder:addConstantString(func, indexConstant.value)
local slot = stringHash(indexConstant.value) % 256
builder:emitABC(func, LOP.GETTABLEKS, target, objectReg, slot, line)
builder:emitAux(func, constant, line)
self:useReg(ctx, target)
self:hintTemporaryExprRegType(ctx, expr.object, objectReg, LBC_TYPE_TABLE, 2)
elseif indexConstant and indexConstant.kind == "Number" and indexConstant.value % 1 == 0 and indexConstant.value >= 1 and indexConstant.value <= 256 then
builder:emitABC(func, LOP.GETTABLEN, target, objectReg, indexConstant.value - 1, line)
self:useReg(ctx, target)
self:hintTemporaryExprRegType(ctx, expr.object, objectReg, LBC_TYPE_TABLE, 1)
else
-- Place the key in the next register only if the object actually occupied objectTarget.
local indexTarget = objectReg == objectTarget and objectTarget + 1 or objectTarget
local indexReg = self:compileExprAsSource(ctx, expr.index, indexTarget)
builder:emitABC(func, LOP.GETTABLE, target, objectReg, indexReg, line)
self:useReg(ctx, target)
self:hintTemporaryExprRegType(ctx, expr.object, objectReg, LBC_TYPE_TABLE, 1)
self:hintTemporaryExprRegType(ctx, expr.index, indexReg, LBC_TYPE_NUMBER, 1)
end
elseif expr.kind == "Call" or expr.kind == "MethodCall" then
-- Plain calls first try compileInlineO2Call (presumably O2 inlining); otherwise a
-- one-result call.
if expr.kind == "Call" and self:compileInlineO2Call(ctx, expr, target, 1) then
return
end
self:compileCall(ctx, expr, target, 1, targetTemp)
elseif expr.kind == "Function" then
-- If target belongs to a local currently being initialized, that local is passed along
-- (presumably so the closure can refer to itself).
local selfLocal = ctx.initializingLocalsByReg and ctx.initializingLocalsByReg[target] or nil
self:compileFunctionExpr(ctx, expr, target, nil, selfLocal)
elseif expr.kind == "IfExpr" then
local optimizationLevel = builder.options.optimizationLevel or 1
-- Skip leading clauses whose condition is a known-falsy constant. A known-truthy
-- constant condition means only that clause's value is compiled.
local firstClauseIndex = 1
while firstClauseIndex <= #expr.clauses do
local clause = expr.clauses[firstClauseIndex]
local condition = optimizationLevel >= 1 and self:getConstant(ctx, clause.condition) or nil
if not condition then
break
end
if self:constantTruth(condition) then
self:compileExpr(ctx, clause.value, target, targetTemp)
return
end
firstClauseIndex += 1
end
-- Every clause was statically false: only the else value remains.
if firstClauseIndex > #expr.clauses then
self:compileExpr(ctx, expr.elseValue, target, targetTemp)
return
end
-- Single-clause patterns whose condition is a local register:
--   `if c then c else x` == `c or x`  -> OR / ORK
--   `if c then x else c` == `c and x` -> AND / ANDK
-- The other operand must already be a local register or a constant.
if firstClauseIndex == 1 and #expr.clauses == 1 then
local clause = expr.clauses[firstClauseIndex]
local conditionReg = self:getExprLocalReg(ctx, clause.condition)
if conditionReg ~= nil then
local trueReg = self:getExprLocalReg(ctx, clause.value)
local falseReg = self:getExprLocalReg(ctx, expr.elseValue)
if conditionReg == trueReg and falseReg ~= nil then
builder:emitABC(func, LOP.OR, target, conditionReg, falseReg, line)
self:useReg(ctx, target)
return
elseif conditionReg == trueReg then
local constant = self:getConstant(ctx, expr.elseValue)
local constantIndex = constant and self:addConstantIndex(ctx, constant) or nil
if constantIndex and constantIndex <= 255 then
builder:emitABC(func, LOP.ORK, target, conditionReg, constantIndex, line)
self:useReg(ctx, target)
return
elseif constant then
-- Constant index does not fit in C: load the constant into a temp and use OR.
local otherReg = self:compileExprAsSource(ctx, expr.elseValue, self:getTempReg(ctx, target))
builder:emitABC(func, LOP.OR, target, conditionReg, otherReg, line)
self:useReg(ctx, target)
return
end
elseif conditionReg == falseReg and trueReg ~= nil then
builder:emitABC(func, LOP.AND, target, conditionReg, trueReg, line)
self:useReg(ctx, target)
return
elseif conditionReg == falseReg then
local constant = self:getConstant(ctx, clause.value)
local constantIndex = constant and self:addConstantIndex(ctx, constant) or nil
if constantIndex and constantIndex <= 255 then
builder:emitABC(func, LOP.ANDK, target, conditionReg, constantIndex, line)
self:useReg(ctx, target)
return
elseif constant then
local otherReg = self:compileExprAsSource(ctx, clause.value, self:getTempReg(ctx, target))
builder:emitABC(func, LOP.AND, target, conditionReg, otherReg, line)
self:useReg(ctx, target)
return
end
end
end
end
-- General form: each clause jumps past its value when the condition is false. After its
-- value, an unconditional JUMP (patched to `exit`) skips the remaining clauses.
local endJumps = {}
for clauseIndex = firstClauseIndex, #expr.clauses do
local clause = expr.clauses[clauseIndex]
local falseJumps = self:compileCondJumps(ctx, clause.condition, false, target + 1)
self:compileExpr(ctx, clause.value, target)
local jump = builder:label(func)
builder:emitAD(func, LOP.JUMP, 0, 0, func.lines[#func.lines] or clause.value.line or line)
append(endJumps, jump)
local nextLabel = builder:label(func)
for _, falseJump in ipairs(falseJumps) do
builder:patchJump(func, falseJump, nextLabel)
end
end
self:compileExpr(ctx, expr.elseValue, target)
local exit = builder:label(func)
for _, jump in ipairs(endJumps) do
builder:patchJump(func, jump, exit)
end
elseif expr.kind == "Table" then
-- If target must not be scratch and there are entries to evaluate, build the table in
-- a temp register and MOVE it to resultTarget at the end.
local resultTarget = target
if not targetTemp and #expr.entries > 0 then
target = self:getTempReg(ctx, target)
end
-- Classify entries: "array" (positional), "field" (`name = v`), anything else (`[k] = v`).
local arrayCount = 0
local hashCount = 0
local recordCount = 0
local indexCount = 0
for _, entry in ipairs(expr.entries) do
if entry.kind == "array" then
arrayCount += 1
else
hashCount += 1
if entry.kind == "field" then
recordCount += 1
end
end
end
-- With no positional entries, constant keys [1], [2], ... that appear in order are
-- counted towards the array part. This only applies if every non-field entry is such
-- a key; otherwise the count is discarded.
if (builder.options.optimizationLevel or 1) >= 1 and arrayCount == 0 and hashCount > 0 then
for _, entry in ipairs(expr.entries) do
local keyConstant = nil
if entry.kind == "index" then
keyConstant = self:getConstant(ctx, entry.key)
end
if keyConstant and keyConstant.kind == "Number" and keyConstant.value == indexCount + 1 then
indexCount += 1
end
end
if hashCount == recordCount + indexCount then
hashCount = recordCount
else
indexCount = 0
end
end
-- A trailing `...` has an unknown count, so it is not included in the array preallocation.
local lastEntry = expr.entries[#expr.entries]
local trailingVarargs = lastEntry and lastEntry.kind == "array" and lastEntry.value.kind == "Vararg"
local arrayAllocation = arrayCount - (trailingVarargs and 1 or 0) + indexCount
-- An empty constructor may be presized from expr.predictedHashCount when optimizing.
local hashAllocation = #expr.entries == 0 and math.max(hashCount, (builder.options.optimizationLevel or 1) >= 1 and (expr.predictedHashCount or 0) or 0) or hashCount
-- Record-only constructors with 1..32 fields use a DUPTABLE template. Each template slot
-- holds the field value's constant index, or -1 when the value is not a constant (or
-- optimizations are off).
local tableTemplateKeys = nil
local tableTemplateValues = nil
local tableTemplateValueByKey = nil
local tableTemplateHasConstants = false
if arrayCount == 0 and indexCount == 0 and hashCount == recordCount and recordCount >= 1 and recordCount <= 32 then
tableTemplateKeys = {}
tableTemplateValues = {}
tableTemplateValueByKey = {}
local tableTemplateIndexByKey = {}
for _, entry in ipairs(expr.entries) do
-- Defensive: the guard above already implies every entry is a field.
if entry.kind ~= "field" then
tableTemplateKeys = nil
tableTemplateValues = nil
tableTemplateValueByKey = nil
break
end
local keyConstant = builder:addConstantString(func, entry.key)
local valueConstant = -1
if (builder.options.optimizationLevel or 1) >= 1 then
valueConstant = self:getConstantIndex(ctx, entry.value) or -1
end
local keyIndex = tableTemplateIndexByKey[keyConstant]
if keyIndex then
-- Duplicate key: once any occurrence is non-constant (-1) the slot stays -1, so every
-- occurrence is stored at runtime. Otherwise the later value wins.
if tableTemplateValues[keyIndex] ~= -1 then
tableTemplateValues[keyIndex] = valueConstant
end
else
append(tableTemplateKeys, keyConstant)
append(tableTemplateValues, valueConstant)
tableTemplateIndexByKey[keyConstant] = #tableTemplateKeys
end
end
if tableTemplateValues then
for index, valueConstant in ipairs(tableTemplateValues) do
if valueConstant >= 0 then
tableTemplateHasConstants = true
end
tableTemplateValueByKey[tableTemplateKeys[index]] = valueConstant
end
end
end
if tableTemplateKeys then
local tableConstant = builder:addConstantTable(func, tableTemplateKeys, tableTemplateValues, tableTemplateHasConstants)
builder:emitAD(func, LOP.DUPTABLE, target, tableConstant, line)
else
-- NEWTABLE: B = encoded hash size, aux = array size hint.
builder:emitABC(func, LOP.NEWTABLE, target, encodeTableSize(hashAllocation), 0, line)
builder:emitAux(func, math.max(arrayAllocation, expr.predictedArrayCount or 0), line)
end
self:useReg(ctx, target)
-- Positional values are evaluated into consecutive registers from arrayChunkReg and stored
-- by SETLIST in batches of at most 16. Keyed entries use scratch registers from tempBase
-- upwards, above the array chunk.
local arrayIndex = 1
local arrayBatch = {}
local batchStart = 1
local arrayChunkSize = math.min(16, arrayCount)
local arrayChunkReg = math.max(target + 1, ctx.protectedTop or 0, ctx.nextReg)
local tempBase = arrayChunkReg + arrayChunkSize
-- Emits the pending array batch, if any. When `precoverLine` is given, its coverage point
-- is emitted after the batch values and before SETLIST. Returns true only if a non-empty
-- batch was flushed with a precoverLine, so the caller can skip emitting that coverage again.
local function flushArrayBatch(precoverLine)
if #arrayBatch == 0 then
return false
end
local multRet = false
for index, item in ipairs(arrayBatch) do
if (builder.options.coverageLevel or 0) >= 2 and not item.precovered then
self:emitCoverage(ctx, item.value.line or line)
end
if item.multRet and index == #arrayBatch then
if item.value.kind == "Vararg" then
builder:emitABC(func, LOP.GETVARARGS, arrayChunkReg + index - 1, 0, 0, item.value.line or line)
self:useReg(ctx, arrayChunkReg + index - 1)
else
self:compileCall(ctx, item.value, arrayChunkReg + index - 1, -1)
end
multRet = true
else
self:compileExpr(ctx, item.value, arrayChunkReg + index - 1)
end
end
if precoverLine ~= nil and (builder.options.coverageLevel or 0) >= 2 then
self:emitCoverage(ctx, precoverLine)
end
-- SETLIST C = value count + 1, or 0 when the last value is multi-result;
-- aux = 1-based array index of the first value in the batch.
local batchLine = precoverLine or arrayBatch[#arrayBatch].value.line or line
builder:emitABC(func, LOP.SETLIST, target, arrayChunkReg, multRet and 0 or #arrayBatch + 1, batchLine)
builder:emitAux(func, batchStart, batchLine)
arrayBatch = {}
return precoverLine ~= nil
end
for entryIndex, entry in ipairs(expr.entries) do
if entry.kind == "array" then
local precovered = false
-- Batch full: flush it first. The flush emits this entry's coverage point.
if #arrayBatch == 16 then
flushArrayBatch(entry.value.line or line)
precovered = true
end
if #arrayBatch == 0 then
batchStart = arrayIndex
end
-- Only the last constructor entry may expand to multiple values (`...` or a call),
-- unless callHasOneResultO2/getInlineO2ReturnCount report a single result for the call.
local multRet = entryIndex == #expr.entries and (entry.value.kind == "Vararg" or entry.value.kind == "Call" or entry.value.kind == "MethodCall")
if multRet and entry.value.kind == "Call" then
if self:callHasOneResultO2(ctx, entry.value) or self:getInlineO2ReturnCount(ctx, entry.value) == 1 then
multRet = false
end
end
append(arrayBatch, {
value = entry.value,
precovered = precovered,
multRet = multRet,
})
arrayIndex += 1
elseif entry.kind == "field" then
-- Pending array values are stored first to preserve evaluation order.
local precovered = flushArrayBatch(entry.value.line or line)
if (builder.options.coverageLevel or 0) >= 2 and not precovered then
self:emitCoverage(ctx, entry.value.line or line)
end
local skipPackedConstant = false
if tableTemplateValueByKey then
local keyConstant = builder:addConstantString(func, entry.key)
local valueConstant = tableTemplateValueByKey[keyConstant]
skipPackedConstant = valueConstant ~= nil and valueConstant >= 0
end
if skipPackedConstant then
-- Constant record fields are pre-filled by DUPTABLE when optimization is enabled.
elseif (builder.options.optimizationLevel or 1) >= 1 then
-- Function values also receive the field name (presumably used as the debug name).
local sourceReg
if entry.value.kind == "Function" then
self:compileFunctionExpr(ctx, entry.value, tempBase, entry.key)
sourceReg = tempBase
else
sourceReg = self:compileExprAsSource(ctx, entry.value, tempBase)
end
local constant = builder:addConstantString(func, entry.key)
local slot = stringHash(entry.key) % 256
builder:emitABC(func, LOP.SETTABLEKS, sourceReg, target, slot, entry.value.line)
builder:emitAux(func, constant, entry.value.line)
else
-- -O0: load the key string into a register and use generic SETTABLE.
self:emitLoadConstant(ctx, tempBase, entry.key, entry.value.line or line)
local valueReg = tempBase + 1
local sourceReg
if entry.value.kind == "Function" then
self:compileFunctionExpr(ctx, entry.value, valueReg, entry.key)
sourceReg = valueReg
else
sourceReg = self:compileExprAsSource(ctx, entry.value, valueReg)
end
builder:emitABC(func, LOP.SETTABLE, sourceReg, target, tempBase, entry.value.line)
end
else
-- `[key] = value`: constant string key -> SETTABLEKS, integer key 1..256 -> SETTABLEN
-- (0-based C), otherwise generic SETTABLE with the key in a register.
local precovered = flushArrayBatch(entry.value.line or line)
if (builder.options.coverageLevel or 0) >= 2 and not precovered then
self:emitCoverage(ctx, entry.value.line or line)
end
local keyConstant = (builder.options.optimizationLevel or 1) >= 1 and self:getConstant(ctx, entry.key) or nil
if keyConstant and keyConstant.kind == "String" then
local sourceReg
if entry.value.kind == "Function" then
self:compileFunctionExpr(ctx, entry.value, tempBase, keyConstant.value)
sourceReg = tempBase
else
sourceReg = self:compileExprAsSource(ctx, entry.value, tempBase)
end
local constant = builder:addConstantString(func, keyConstant.value)
local slot = stringHash(keyConstant.value) % 256
builder:emitABC(func, LOP.SETTABLEKS, sourceReg, target, slot, entry.value.line)
builder:emitAux(func, constant, entry.value.line)
elseif keyConstant and keyConstant.kind == "Number" and keyConstant.value % 1 == 0 and keyConstant.value >= 1 and keyConstant.value <= 256 then
local sourceReg = self:compileExprAsSource(ctx, entry.value, tempBase)
builder:emitABC(func, LOP.SETTABLEN, sourceReg, target, keyConstant.value - 1, entry.value.line)
else
local keyReg = self:compileExprAsSource(ctx, entry.key, tempBase)
local valueTarget = keyReg == tempBase and tempBase + 1 or tempBase
local sourceReg = self:compileExprAsSource(ctx, entry.value, valueTarget)
builder:emitABC(func, LOP.SETTABLE, sourceReg, target, keyReg, entry.value.line)
end
end
end
flushArrayBatch()
if target ~= resultTarget then
builder:emitABC(func, LOP.MOVE, resultTarget, target, 0, line)
self:useReg(ctx, resultTarget)
end
elseif expr.kind == "Un" then
-- `-<integer literal>` (not parenthesized) loads the negated integer constant directly.
if expr.op == "-" and expr.expr.kind == "Integer" and not expr.expr.parenthesized then
local cid = builder:addConstantInteger(func, negateIntegerLiteralValue(expr.expr.value))
self:emitLoadKIndex(ctx, target, cid, line)
return
end
local sourceReg = self:compileExprAsSource(ctx, expr.expr, self:getTempReg(ctx, target))
-- Any operator other than `not` and `-` is treated as `#` (LENGTH).
local op = expr.op == "not" and LOP.NOT or expr.op == "-" and LOP.MINUS or LOP.LENGTH
builder:emitABC(func, op, target, sourceReg, 0, line)
self:useReg(ctx, target)
elseif expr.kind == "Bin" and compareOps[expr.op] then
-- Comparison as a value: the compare jumps to trueLabel when true. Otherwise
-- `LOADB target 0 1` stores false and skips the following `LOADB target 1 0`.
local trueJump = self:compileCompareJump(ctx, expr, true, self:getTempReg(ctx, target))
local boolLine = func.lines[#func.lines] or line
builder:emitABC(func, LOP.LOADB, target, 0, 1, boolLine)
local trueLabel = builder:label(func)
builder:patchJump(func, trueJump, trueLabel)
builder:emitABC(func, LOP.LOADB, target, 1, 0, boolLine)
self:useReg(ctx, target)
elseif expr.kind == "Bin" and (expr.op == "and" or expr.op == "or") then
local andOp = expr.op == "and"
-- Constant left operand: statically choose one side
-- (and+truthy -> right, and+falsy -> left, or+truthy -> left, or+falsy -> right).
if (builder.options.optimizationLevel or 1) >= 1 then
local leftConstant = self:getConstant(ctx, expr.left)
if leftConstant then
self:compileExpr(ctx, (andOp == self:constantTruth(leftConstant)) and expr.right or expr.left, target, targetTemp)
return
end
end
-- When the left side is not a "fast" condition, try AND/OR (right operand already in a
-- local register) or ANDK/ORK (right operand a constant with index <= 255). No code needs
-- to be emitted for the right operand in those forms.
if not self:isConditionFast(ctx, expr.left) then
local rightReg = self:getExprLocalReg(ctx, expr.right)
if rightReg ~= nil then
local leftReg = self:compileExprAsSource(ctx, expr.left, target + 1)
builder:emitABC(func, andOp and LOP.AND or LOP.OR, target, leftReg, rightReg, line)
self:useReg(ctx, target)
return
end
local rightConstant = (builder.options.optimizationLevel or 1) >= 1 and self:getConstant(ctx, expr.right) or nil
local constant = rightConstant and self:addConstantIndex(ctx, rightConstant) or nil
if constant and constant <= 255 then
local leftReg = self:compileExprAsSource(ctx, expr.left, target + 1)
builder:emitABC(func, andOp and LOP.ANDK or LOP.ORK, target, leftReg, constant, line)
self:useReg(ctx, target)
return
end
end
-- General short-circuit: compile the left value into resultReg with skipJumps that bypass
-- the right operand; when not skipped, the right operand overwrites resultReg.
local resultReg = targetTemp and target or self:getTempReg(ctx, target)
self:useReg(ctx, resultReg)
local skipJumps = {}
self:compileConditionValue(ctx, expr.left, resultReg, skipJumps, not andOp)
self:compileExpr(ctx, expr.right, resultReg, true)
local moveLabel = builder:label(func)
for _, jump in ipairs(skipJumps) do
builder:patchJump(func, jump, moveLabel)
end
if target ~= resultReg then
builder:emitABC(func, LOP.MOVE, target, resultReg, 0, line)
end
self:useReg(ctx, target)
elseif expr.kind == "Bin" then
if expr.op == ".." then
-- `a .. b .. c` is flattened into consecutive registers and emitted as one CONCAT.
-- While compiling all but the last part, protectedTop is raised past the concat range
-- so those parts cannot use it as scratch. The last part, and everything after the
-- loop, uses the original protectedTop.
local parts = {}
self:flattenConcat(expr, parts)
local oldProtectedTop = ctx.protectedTop or 0
local firstReg = self:getTempReg(ctx, target)
local concatTop = firstReg + #parts - 1
for index, part in ipairs(parts) do
if index < #parts then
ctx.protectedTop = math.max(oldProtectedTop, concatTop + 1)
else
ctx.protectedTop = oldProtectedTop
end
self:compileExpr(ctx, part, firstReg + index - 1)
end
ctx.protectedTop = oldProtectedTop
local concatLine = func.lines[#func.lines] or line
builder:emitABC(func, LOP.CONCAT, target, firstReg, concatTop, concatLine)
self:useReg(ctx, target)
else
local op = arithmeticOps[expr.op]
if not op then
error("unsupported binary operator " .. tostring(expr.op), 2)
end
local useKOps = (builder.options.optimizationLevel or 1) >= 1
local rightConstant = useKOps and self:getConstant(ctx, expr.right) or nil
local leftConstant = useKOps and self:getConstant(ctx, expr.left) or nil
-- emitLine: if operand compilation emitted instructions, reuse the last one's line;
-- otherwise use the expression's own line.
local lineStart = #func.lines
if useKOps and rightConstant and rightConstant.kind == "Number" and arithmeticKOps[expr.op] then
-- Numeric right constant: use `<op>K target, left, K` when the index fits in 8 bits,
-- otherwise load the constant into a register and use the register form.
local constant = builder:addConstantNumber(func, rightConstant.value)
local leftTemp = self:getTempReg(ctx, target)
local leftReg = self:compileExprAsSourceNoMove(ctx, expr.left, leftTemp)
if constant <= 255 then
local emitLine = #func.lines > lineStart and func.lines[#func.lines] or line
builder:emitABC(func, arithmeticKOps[expr.op], target, leftReg, constant, emitLine)
self:hintTemporaryExprRegType(ctx, expr.left, leftReg, LBC_TYPE_NUMBER, 1)
else
local rightTarget = leftReg == leftTemp and leftTemp + 1 or leftTemp
local rightReg = rightTarget
self:compileExpr(ctx, expr.right, rightReg)
local emitLine = #func.lines > lineStart and func.lines[#func.lines] or line
builder:emitABC(func, op, target, leftReg, rightReg, emitLine)
self:hintTemporaryExprRegType(ctx, expr.left, leftReg, LBC_TYPE_NUMBER, 1)
self:hintTemporaryExprRegType(ctx, expr.right, rightReg, LBC_TYPE_NUMBER, 1)
end
elseif (builder.options.optimizationLevel or 1) >= 2 and leftConstant and leftConstant.kind == "Number" and (expr.op == "+" or expr.op == "*") and self:getBytecodeType(ctx, expr) == LBC_TYPE_NUMBER then
-- O2, commutative + / * with a numeric left constant: emit the K form with the operands
-- swapped. This is only done when the expression's type is known to be number
-- (presumably so no metamethod can observe the operand order).
local constant = builder:addConstantNumber(func, leftConstant.value)
if constant <= 255 then
local rightTemp = self:getTempReg(ctx, target)
local rightReg = self:compileExprAsSourceNoMove(ctx, expr.right, rightTemp)
local emitLine = #func.lines > lineStart and func.lines[#func.lines] or line
builder:emitABC(func, arithmeticKOps[expr.op], target, rightReg, constant, emitLine)
self:hintTemporaryExprRegType(ctx, expr.right, rightReg, LBC_TYPE_NUMBER, 1)
else
local leftTarget = self:getTempReg(ctx, target)
local leftReg = self:compileExprAsSourceNoMove(ctx, expr.left, leftTarget)
local rightTarget = leftReg == leftTarget and leftTarget + 1 or leftTarget
local rightReg = self:compileExprAsSourceNoMove(ctx, expr.right, rightTarget)
local emitLine = #func.lines > lineStart and func.lines[#func.lines] or line
builder:emitABC(func, op, target, leftReg, rightReg, emitLine)
self:hintTemporaryExprRegType(ctx, expr.left, leftReg, LBC_TYPE_NUMBER, 1)
self:hintTemporaryExprRegType(ctx, expr.right, rightReg, LBC_TYPE_NUMBER, 1)
end
elseif useKOps and leftConstant and leftConstant.kind == "Number" and (expr.op == "-" or expr.op == "/") then
-- Numeric left constant with - or /: SUBRK/DIVRK take the constant index in B and the
-- right operand register in C.
local constant = builder:addConstantNumber(func, leftConstant.value)
local rightTemp = self:getTempReg(ctx, target)
local rightReg = self:compileExprAsSourceNoMove(ctx, expr.right, rightTemp)
if constant <= 255 then
local emitLine = #func.lines > lineStart and func.lines[#func.lines] or line
builder:emitABC(func, expr.op == "-" and LOP.SUBRK or LOP.DIVRK, target, constant, rightReg, emitLine)
self:hintTemporaryExprRegType(ctx, expr.right, rightReg, LBC_TYPE_NUMBER, 1)
else
local leftTarget = rightReg == rightTemp and rightTemp + 1 or rightTemp
local leftReg = leftTarget
self:compileExpr(ctx, expr.left, leftReg)
local emitLine = #func.lines > lineStart and func.lines[#func.lines] or line
builder:emitABC(func, op, target, leftReg, rightReg, emitLine)
self:hintTemporaryExprRegType(ctx, expr.left, leftReg, LBC_TYPE_NUMBER, 1)
self:hintTemporaryExprRegType(ctx, expr.right, rightReg, LBC_TYPE_NUMBER, 1)
end
else
-- Register/register form.
local leftTarget = self:getTempReg(ctx, target)
local leftReg = self:compileExprAsSourceNoMove(ctx, expr.left, leftTarget)
local rightTarget = leftReg == leftTarget and leftTarget + 1 or leftTarget
local rightReg = self:compileExprAsSourceNoMove(ctx, expr.right, rightTarget)
local emitLine = #func.lines > lineStart and func.lines[#func.lines] or line
builder:emitABC(func, op, target, leftReg, rightReg, emitLine)
self:hintTemporaryExprRegType(ctx, expr.left, leftReg, LBC_TYPE_NUMBER, 1)
self:hintTemporaryExprRegType(ctx, expr.right, rightReg, LBC_TYPE_NUMBER, 1)
end
self:useReg(ctx, target)
end
else
error("unsupported expression " .. tostring(expr.kind), 2)
end
end
-- Rounds `value` toward zero: negatives are ceiled, everything else (including -0,
-- +inf and NaN, which fail the `< 0` test) is floored. This mirrors C's (int) cast
-- as used when decoding bit32.extract field/width constants.
local function truncateToInteger(value)
	if value < 0 then
		return math.ceil(value)
	end
	return math.floor(value)
end
-- Tries to emit a FASTCALL1 / FASTCALL2 / FASTCALL2K / FASTCALL3 sequence for a builtin call.
-- On success, it emits: any argument values that are not already in locals, the fastcall
-- instruction (plus an aux word for the 2/3-operand forms), the argument setup
-- (LOADK-style loads and MOVEs into argBase slots), and finally the callee into `target`.
-- The caller then emits the CALL that serves as the slow-path fallback.
--
-- expr          : Call node (reads expr.args and expr.callee).
-- target        : register the callee is loaded into.
-- wantedResults : only used here to decide whether a FASTCALL1 emits a coverage point.
-- argBase       : register where the first argument must sit for the fallback CALL.
-- argCount      : number of arguments (callers pass #expr.args).
-- bfid          : builtin function id placed in the fastcall's A operand.
-- multArgs      : true when the last argument is multi-valued; disables the fastcall.
-- line          : source line of the call.
-- bfK           : optional constant index for the aux word (e.g. BIT32_EXTRACTK); forces
--                 FASTCALL2K with arguments 2.. taken as constants.
-- Returns true if the sequence was emitted. Returns false when the call does not fit; the
-- early exits emit nothing.
function PureCompiler:compileFastcall(ctx, expr, target, wantedResults, argBase, argCount, bfid, multArgs, line, bfK)
if argCount < 1 or multArgs then
return false
end
-- Three register arguments are only attempted when at least one of them is already a
-- local register, or when a precomputed bfK is supplied.
local maxFastcallArgs = 2
if argCount == 3 then
for _, arg in ipairs(expr.args) do
if self:getExprLocalReg(ctx, arg) ~= nil then
maxFastcallArgs = 3
break
end
end
end
if bfK ~= nil then
maxFastcallArgs = 3
end
if argCount > maxFastcallArgs then
return false
end
-- Pick the opcode. useK means arguments after the first are passed as constant indices.
local op = LOP.FASTCALL1
local useK = false
local secondConstant = nil
if argCount == 1 then
op = LOP.FASTCALL1
elseif bfK ~= nil then
op = LOP.FASTCALL2K
useK = true
elseif argCount == 2 then
secondConstant = self:getConstant(ctx, expr.args[2])
if secondConstant ~= nil then
op = LOP.FASTCALL2K
useK = true
else
op = LOP.FASTCALL2
end
else
op = LOP.FASTCALL3
end
-- args[i] is either the register holding argument i (an existing local, or its argBase
-- slot after compilation) or, for constant arguments, the constant index.
-- precompiled[i] marks arguments whose code was emitted here.
local args = {}
local precompiled = {}
for index, arg in ipairs(expr.args) do
if useK and index > 1 then
local constant = index == 2 and secondConstant or self:getConstant(ctx, arg)
-- Defensive: both useK paths have already established these are constants. If this
-- ever triggers, argument 1 may already have been compiled.
if constant == nil then
return false
end
args[index] = self:addConstantIndex(ctx, constant)
else
local reg = self:getExprLocalReg(ctx, arg)
if reg ~= nil then
args[index] = reg
else
reg = argBase + index - 1
self:compileExpr(ctx, arg, reg)
args[index] = reg
precompiled[index] = true
end
end
end
-- skipCount (C operand) counts the instructions emitted between the fastcall and the CALL:
-- the aux word (non-FASTCALL1 forms), constant loads, MOVEs for arguments not already in
-- their argBase slot, and the callee load. It must match the emission order below.
local skipCount = op == LOP.FASTCALL1 and 0 or 1
for index = 1, argCount do
local desiredReg = argBase + index - 1
if useK and index > 1 then
skipCount += self:getLoadKInstructionCount(args[index])
elseif args[index] ~= desiredReg then
skipCount += 1
end
end
skipCount += self:getCalleeInstructionCount(ctx, expr.callee)
if op == LOP.FASTCALL1 and (ctx.builder.options.coverageLevel or 0) >= 2 and wantedResults == 1 and not precompiled[1] then
self:emitCoverage(ctx, expr.args[1].line or line)
end
-- If any argument was compiled above, the fastcall reuses the line of the last emitted instruction.
local fastcallLine = line
for index = 1, argCount do
if precompiled[index] then
fastcallLine = ctx.func.lines[#ctx.func.lines] or line
end
end
-- The ABC word is identical for every form; only the non-FASTCALL1 forms add an aux word.
if op == LOP.FASTCALL1 then
ctx.builder:emitABC(ctx.func, op, bfid, args[1], skipCount, fastcallLine)
else
ctx.builder:emitABC(ctx.func, op, bfid, args[1], skipCount, fastcallLine)
if op == LOP.FASTCALL3 then
-- FASTCALL3 aux packs the registers of arguments 2 and 3 into its low two bytes.
ctx.builder:emitAux(ctx.func, args[2] + args[3] * 256, fastcallLine)
else
ctx.builder:emitAux(ctx.func, bfK or args[2], fastcallLine)
end
end
-- Fallback setup for the CALL: put every argument in its argBase slot, then load the callee.
for index = 1, argCount do
local desiredReg = argBase + index - 1
if useK and index > 1 then
self:emitLoadKIndex(ctx, desiredReg, args[index], fastcallLine)
elseif args[index] ~= desiredReg then
ctx.builder:emitABC(ctx.func, LOP.MOVE, desiredReg, args[index], 0, fastcallLine)
self:useReg(ctx, desiredReg)
end
end
self:compileExpr(ctx, expr.callee, target)
return true
end
-- Emits the FASTCALL1 SELECT_VARARG sequence for `select(n, ...)` used for one result.
-- Emission order: the selector value (only if it is not already a local), FASTCALL1, the
-- callee into `target`, an optional MOVE of the selector into argBase, then
-- GETVARARGS (all varargs) at argBase + 1. The caller emits the fallback CALL.
-- Returns false, emitting nothing, when the call shape or wantedResults does not qualify.
function PureCompiler:compileSelectVarargFastcall(ctx, expr, target, wantedResults, argBase, line)
	local args = expr.args
	local eligible = wantedResults == 1 and #args == 2 and args[2].kind == "Vararg"
	if not eligible then
		return false
	end
	local builder, func = ctx.builder, ctx.func
	local selector = args[1]
	local selectorReg = self:getExprLocalReg(ctx, selector)
	if selectorReg == nil then
		selectorReg = argBase
		self:compileExpr(ctx, selector, selectorReg)
	end
	local needsMove = selectorReg ~= argBase
	-- Instructions skipped to reach the CALL: callee load, optional MOVE, and GETVARARGS.
	local skip = self:getCalleeInstructionCount(ctx, expr.callee) + (needsMove and 1 or 0) + 1
	builder:emitABC(func, LOP.FASTCALL1, LBF.SELECT_VARARG, selectorReg, skip, line)
	self:compileExpr(ctx, expr.callee, target)
	if needsMove then
		builder:emitABC(func, LOP.MOVE, argBase, selectorReg, 0, selector.line or line)
		self:useReg(ctx, argBase)
	end
	local varargReg = argBase + 1
	builder:emitABC(func, LOP.GETVARARGS, varargReg, 0, 0, args[2].line or line)
	self:useReg(ctx, varargReg)
	return true
end
-- Compiles expr.args into consecutive registers starting at argBase.
-- Every argument except the last goes through compileExprTempTop. The last one is
-- handled according to `multArgs`:
--   * false: compiled as a single value, with ctx.protectedTop temporarily raised to at
--     least ctx.nextReg;
--   * true and `...`: GETVARARGS with B = 0 (all values);
--   * true and a call: compiled for all results, with ctx.nextReg temporarily set to the
--     argument's slot and ctx.protectedTop capped at that slot.
-- ctx.nextReg and ctx.protectedTop are restored afterwards; a nil protectedTop comes back as 0.
-- line: fallback source line for emitted instructions.
function PureCompiler:compileCallArguments(ctx, expr, argBase, multArgs, line)
	local builder = ctx.builder
	local args = expr.args
	local lastIndex = #args
	for index, arg in ipairs(args) do
		local reg = argBase + index - 1
		if index ~= lastIndex then
			self:compileExprTempTop(ctx, arg, reg)
		elseif not multArgs then
			local savedProtectedTop = ctx.protectedTop or 0
			ctx.protectedTop = math.max(savedProtectedTop, ctx.nextReg)
			self:compileExpr(ctx, arg, reg)
			ctx.protectedTop = savedProtectedTop
		elseif arg.kind == "Vararg" then
			builder:emitABC(ctx.func, LOP.GETVARARGS, reg, 0, 0, arg.line or line)
			self:useReg(ctx, reg)
		else
			local savedNextReg = ctx.nextReg
			local savedProtectedTop = ctx.protectedTop or 0
			ctx.nextReg = reg
			ctx.protectedTop = math.min(savedProtectedTop, reg)
			self:compileCall(ctx, arg, reg, -1)
			ctx.protectedTop = savedProtectedTop
			ctx.nextReg = savedNextReg
		end
	end
end
-- Compiles a Call or MethodCall expression. The call frame has the callee at the
-- (possibly relocated) register `target` and the arguments above it, and ends with
-- CALL requesting `wantedResults` results (-1 = all results).
--
-- target             : requested register for the callee / first result.
-- wantedResults      : number of results wanted; -1 means all (CALL C = 0).
-- targetTemp         : when false and wantedResults == 1, the frame is built in a temp
--                      register and the result is MOVEd back to the requested target.
-- suppressResultMove : when truthy, skip that final MOVE.
-- Returns the register the call frame actually used, which may differ from the
-- requested target.
function PureCompiler:compileCall(ctx, expr, target, wantedResults, targetTemp, suppressResultMove)
local builder = ctx.builder
local func = ctx.func
local line = expr.line or 1
local resultTarget = target
-- Relocate the frame for 0/1-result calls. With one result it moves to the stack top
-- unless target already is the topmost register; with zero results it moves whenever
-- the stack top is above target. NOTE(review): presumably so the arguments do not
-- overwrite live registers - confirm.
if wantedResults == 1 and targetTemp == false then
target = self:getTempReg(ctx, target)
elseif wantedResults == 1 then
local protectedTop = ctx.protectedTop or 0
local localTop = self:getLocalTop(ctx)
local stackTop = math.max(ctx.nextReg, protectedTop, localTop)
if stackTop > target and target ~= stackTop - 1 then
target = stackTop
end
elseif wantedResults == 0 then
local protectedTop = ctx.protectedTop or 0
local localTop = self:getLocalTop(ctx)
local stackTop = math.max(ctx.nextReg, protectedTop, localTop)
if stackTop > target then
target = stackTop
end
end
local argBase
local argCount = #expr.args
local fastcall = nil
local fastcallK = nil
local multArgs = false
local fastcallMultArgs = false
local callMultArgs = false
local methodObjectReg = nil
-- Method calls reserve target + 1 for `self` (written by NAMECALL) and count it as an
-- argument. Only plain calls are candidates for builtin fastcalls.
if expr.kind == "MethodCall" then
methodObjectReg = self:compileExprAsSource(ctx, expr.object, target)
argBase = target + 2
argCount += 1
else
argBase = target + 1
fastcall = self:getBuiltinIdForCallee(ctx, expr.callee)
end
-- For multi-result calls, raise protectedTop to cover target + wantedResults while the
-- arguments are compiled (restored before CALL is emitted).
local oldProtectedTop = ctx.protectedTop or 0
if wantedResults > 1 and target + wantedResults ~= ctx.nextReg then
ctx.protectedTop = math.max(oldProtectedTop, target + wantedResults)
end
-- A trailing `...` or call expands to all its values (CALL B = 0) unless it is known to
-- return exactly one value.
if #expr.args > 0 then
local lastArg = expr.args[#expr.args]
multArgs = lastArg.kind == "Vararg" or lastArg.kind == "Call" or lastArg.kind == "MethodCall"
if lastArg.kind == "Call" and self:callHasOneResultO2(ctx, lastArg) then
multArgs = false
elseif lastArg.kind == "Call" and self:getInlineO2ReturnCount(ctx, lastArg) == 1 then
multArgs = false
end
fastcallMultArgs = multArgs
-- O2: a builtin whose fixed parameter count equals the argument count and which is
-- marked in builtinNoneSafe may still use a fastcall with a multi-value last argument.
if (builder.options.optimizationLevel or 1) >= 2 and fastcall ~= nil then
local fixedParams = builtinParamCounts[fastcall]
if fixedParams == #expr.args and builtinNoneSafe[fastcall] then
fastcallMultArgs = false
end
end
end
-- SELECT_VARARG only applies to `select(n, ...)` with exactly one result.
local selectVarargFastcall = fastcall == LBF.SELECT_VARARG and #expr.args == 2 and expr.args[2].kind == "Vararg"
if fastcall == LBF.SELECT_VARARG and not selectVarargFastcall then
fastcall = nil
elseif fastcall == LBF.SELECT_VARARG and wantedResults ~= 1 then
fastcall = nil
end
if selectVarargFastcall and self:compileSelectVarargFastcall(ctx, expr, target, wantedResults, argBase, line) then
callMultArgs = multArgs
-- fallback setup was emitted by compileSelectVarargFastcall
else
-- bit32.extract with constant field/width (both in range) becomes BIT32_EXTRACTK, with
-- field + (width - 1) * 32 stored as a number constant for the aux word.
if fastcall == LBF.BIT32_EXTRACT and #expr.args == 3 then
local fieldConstant = self:getConstant(ctx, expr.args[2])
local widthConstant = self:getConstant(ctx, expr.args[3])
if fieldConstant and fieldConstant.kind == "Number" and widthConstant and widthConstant.kind == "Number" then
local field = truncateToInteger(fieldConstant.value)
local width = truncateToInteger(widthConstant.value)
if field >= 0 and width > 0 and field + width <= 32 then
fastcall = LBF.BIT32_EXTRACTK
fastcallK = builder:addConstantNumber(func, field + (width - 1) * 32)
end
end
end
if fastcall and self:compileFastcall(ctx, expr, target, wantedResults, argBase, #expr.args, fastcall, fastcallMultArgs, line, fastcallK) then
callMultArgs = fastcallMultArgs
-- fallback setup was emitted by compileFastcall
elseif fastcall and expr.kind ~= "MethodCall" then
-- Generic FASTCALL: the arguments come first, then FASTCALL (C = number of callee-load
-- instructions to skip), then the callee.
callMultArgs = multArgs
self:compileCallArguments(ctx, expr, argBase, multArgs, line)
builder:emitABC(func, LOP.FASTCALL, fastcall, 0, self:getCalleeInstructionCount(ctx, expr.callee), line)
self:compileExpr(ctx, expr.callee, target)
else
-- Regular call: callee (or NAMECALL for methods) plus arguments.
callMultArgs = multArgs
if expr.kind ~= "MethodCall" then
if (builder.options.coverageLevel or 0) >= 2 and (builder.options.optimizationLevel or 1) < 1 and wantedResults == 1 and expr.callee.kind == "Name" and self:findLocal(ctx, expr.callee.name) and #expr.args == 1 then
self:emitCoverage(ctx, line)
end
self:compileExprTempTop(ctx, expr.callee, target)
end
self:compileCallArguments(ctx, expr, argBase, multArgs, line)
if expr.kind == "MethodCall" then
local constant = builder:addConstantString(func, expr.method)
local slot = stringHash(expr.method) % 256
builder:emitABC(func, LOP.NAMECALL, target, methodObjectReg, slot, line)
builder:emitAux(func, constant, line)
self:hintTemporaryExprRegType(ctx, expr.object, methodObjectReg, LBC_TYPE_TABLE, 2)
end
end
end
ctx.protectedTop = oldProtectedTop
-- CALL B = argument count + 1 (0 = up to stack top), C = results + 1 (0 = all results).
local b = callMultArgs and 0 or argCount + 1
local c = wantedResults == -1 and 0 or wantedResults + 1
builder:emitABC(func, LOP.CALL, target, b, c, line)
local resultTop = wantedResults == -1 and 0 or math.max(wantedResults - 1, 0)
self:useReg(ctx, target + math.max(argCount, resultTop))
if resultTarget ~= target and wantedResults == 1 and not suppressResultMove then
builder:emitABC(func, LOP.MOVE, resultTarget, target, 0, line)
self:useReg(ctx, resultTarget)
end
return target
end
-- Compiles `expr` purely for its side effects; the produced value is discarded.
-- Names, `...`, function expressions and (when optimizing) foldable constants emit
-- nothing. Calls are placed in a scratch register above every live/protected
-- register and asked for a single result. Everything else is compiled into `target`.
function PureCompiler:compileExprSide(ctx, expr, target)
	-- Explicit type instantiations only wrap another expression.
	local node = expr
	while node.kind == "Instantiate" do
		node = node.expr
	end

	local kind = node.kind
	if kind == "Name" or kind == "Vararg" or kind == "Function" then
		return
	end

	local optimizing = (ctx.builder.options.optimizationLevel or 1) >= 1
	if optimizing and self:getConstant(ctx, node) then
		return
	end

	if kind ~= "Call" and kind ~= "MethodCall" then
		self:compileExpr(ctx, node, target, true)
		return
	end

	-- Pick a register that cannot clobber the target, locals, or protected slots.
	local scratch = math.max(target, ctx.nextReg, self:getLocalTop(ctx), ctx.protectedTop or 0)
	local inlined = kind == "Call" and self:compileInlineO2Call(ctx, node, scratch, 1)
	if not inlined then
		self:compileCall(ctx, node, scratch, 1, true, true)
	end
end
-- Compiles the expression list `values` into `count` consecutive registers starting
-- at `target`. Slots with no value receive LOADNIL. If the last value is a call or
-- `...` and registers remain, it is expanded to produce every remaining result.
-- Values beyond `count` are compiled only for their side effects, and only when
-- `compileExtras` is true (nil is treated as true).
function PureCompiler:compileExprList(ctx, values, target, count, compileExtras)
if compileExtras == nil then
compileExtras = true
end
-- Empty list: every destination register is simply nil.
if #values == 0 then
for index = 0, count - 1 do
ctx.builder:emitABC(ctx.func, LOP.LOADNIL, target + index, 0, 0, 1)
end
self:useReg(ctx, target + count - 1)
return
end
-- Temporarily raise nextReg over the destination window (restored after the loop).
-- NOTE(review): presumably so temporaries allocated while compiling the values land
-- above the destination registers; confirm against the allocator.
local oldNextReg = ctx.nextReg
if #values <= count and target + count > ctx.nextReg then
ctx.nextReg = target + count
end
for index = 1, count do
local expr = values[index]
if expr then
if index == #values and #values < count and expr.kind == "Vararg" then
-- Trailing `...` fills all remaining registers with one GETVARARGS
-- (B operand is wantedResults + 1).
local wantedResults = count - index + 1
ctx.builder:emitABC(ctx.func, LOP.GETVARARGS, target + index - 1, wantedResults + 1, 0, expr.line or 1)
self:useReg(ctx, target + count - 1)
break
elseif index == #values and #values < count and (expr.kind == "Call" or expr.kind == "MethodCall") then
-- Trailing call: request exactly as many results as registers remain.
local wantedResults = count - index + 1
-- NOTE(review): these single-argument calls to a local appear to emit their
-- own coverage marker downstream (compileCall does so when unoptimized);
-- skip the marker here to avoid a duplicate. Confirm the builtin case.
local localSingleArgumentCall = expr.kind == "Call" and #expr.args == 1 and expr.callee.kind == "Name" and self:findLocal(ctx, expr.callee.name) ~= nil
local callEmitsOwnCoverage = localSingleArgumentCall and ((ctx.builder.options.optimizationLevel or 1) < 1 or self:getBuiltinIdForCallee(ctx, expr.callee) ~= nil)
if (ctx.builder.options.coverageLevel or 0) >= 2 and wantedResults == 1 and not callEmitsOwnCoverage then
self:emitCoverage(ctx, expr.line or 1)
end
if expr.kind == "Call" and self:compileInlineO2Call(ctx, expr, target + index - 1, wantedResults) then
break
end
self:compileCall(ctx, expr, target + index - 1, wantedResults)
break
else
-- Single-valued slot. For the last explicit value in a short list, raise
-- protectedTop over the whole destination window while compiling it.
local protectTop = index == #values and #values < count
local oldProtectedTop = ctx.protectedTop or 0
if protectTop then
ctx.protectedTop = math.max(oldProtectedTop, target + count)
end
self:compileExpr(ctx, expr, target + index - 1)
if protectTop then
ctx.protectedTop = oldProtectedTop
end
end
else
-- Past the end of a list whose last value is not multi-result: pad with nil.
ctx.builder:emitABC(ctx.func, LOP.LOADNIL, target + index - 1, 0, 0, values[#values].line or 1)
self:useReg(ctx, target + index - 1)
end
end
ctx.nextReg = oldNextReg
-- Surplus values are evaluated for side effects only, in scratch registers starting
-- right after the destination window, which is protected meanwhile.
if compileExtras and #values > count then
local sideTarget = target + count
local oldProtectedTop = ctx.protectedTop or 0
ctx.protectedTop = math.max(oldProtectedTop, sideTarget)
for index = count + 1, #values do
self:compileExprSide(ctx, values[index], sideTarget)
end
ctx.protectedTop = oldProtectedTop
end
end
-- Compiles `expr` as a condition and appends to `skipJumps` the labels of jumps that
-- are taken when the truthiness of `expr` equals `onlyTruth`; otherwise control falls
-- through. When `target` is non-nil, the value is also materialized there before each
-- such jump: a LOADB of `onlyTruth` for comparisons, the expression value otherwise.
-- `loadLine` (defaults to expr.line) is the line recorded on emitted loads/jumps.
-- `tempTarget` is the scratch register used when `target` is nil; it defaults to
-- max(nextReg, protectedTop).
function PureCompiler:compileConditionValue(ctx, expr, target, skipJumps, onlyTruth, loadLine, tempTarget)
	local builder = ctx.builder
	local func = ctx.func
	local line = expr.line or 1
	loadLine = loadLine or line
	if (builder.options.optimizationLevel or 1) >= 1 then
		local constant = self:getConstant(ctx, expr)
		if constant then
			-- Constant condition: it either always jumps (materializing the value if
			-- requested) or never does, in which case nothing is emitted.
			if self:constantTruth(constant) == onlyTruth then
				if target ~= nil then
					self:compileExpr(ctx, expr, target, true)
				end
				local jump = builder:label(func)
				builder:emitAD(func, LOP.JUMP, 0, 0, loadLine)
				append(skipJumps, jump)
			end
			return
		end
	end
	if expr.kind == "Bin" and (expr.op == "and" or expr.op == "or") then
		local andOp = expr.op == "and"
		local opLine = expr.opLine or loadLine
		if onlyTruth == andOp then
			-- Both operands must agree with onlyTruth (e.g. `a and b` jumping on true).
			-- A left operand with the opposite truth skips the right operand entirely.
			local elseJumps = {}
			local beforeLeft = builder:label(func)
			self:compileConditionValue(ctx, expr.left, nil, elseJumps, not onlyTruth, loadLine, target ~= nil and self:getTempReg(ctx, target) or tempTarget)
			-- Reuse the line of the last emitted instruction if the left side emitted any.
			local rightLoadLine = builder:label(func) ~= beforeLeft and (func.lines[#func.lines] or loadLine) or loadLine
			self:compileConditionValue(ctx, expr.right, target, skipJumps, onlyTruth, rightLoadLine or opLine, tempTarget)
			local elseLabel = builder:label(func)
			for _, jump in ipairs(elseJumps) do
				builder:patchJump(func, jump, elseLabel)
			end
		else
			-- Either operand agreeing with onlyTruth takes the jump.
			local beforeLeft = builder:label(func)
			self:compileConditionValue(ctx, expr.left, target, skipJumps, onlyTruth, loadLine, tempTarget)
			local rightLoadLine = builder:label(func) ~= beforeLeft and (func.lines[#func.lines] or loadLine) or loadLine
			self:compileConditionValue(ctx, expr.right, target, skipJumps, onlyTruth, rightLoadLine, tempTarget)
		end
		return
	end
	if expr.kind == "Bin" and compareOps[expr.op] then
		-- A comparison's value on the jump path is the boolean `onlyTruth`.
		if target ~= nil then
			builder:emitABC(func, LOP.LOADB, target, onlyTruth and 1 or 0, 0, loadLine)
			self:useReg(ctx, target)
		end
		append(skipJumps, self:compileCompareJump(ctx, expr, onlyTruth, target ~= nil and self:getTempReg(ctx, target) or tempTarget or math.max(ctx.nextReg, ctx.protectedTop or 0), loadLine))
		return
	end
	if target == nil and expr.kind == "Un" and expr.op == "not" then
		-- `not x` with no value to materialize: test x with the opposite polarity.
		-- Forward tempTarget like every other recursive call here, so the operand uses
		-- the caller-chosen scratch register rather than the nextReg/protectedTop default.
		self:compileConditionValue(ctx, expr.expr, nil, skipJumps, not onlyTruth, loadLine, tempTarget)
		return
	end
	-- General case: evaluate the value, then test it with JUMPIF/JUMPIFNOT.
	local reg
	local beforeExpr = builder:label(func)
	if target ~= nil then
		reg = target
		self:compileExpr(ctx, expr, reg, true)
	else
		reg = self:compileExprAsSource(ctx, expr, tempTarget or math.max(ctx.nextReg, ctx.protectedTop or 0))
	end
	local jumpLine = builder:label(func) ~= beforeExpr and (func.lines[#func.lines] or loadLine) or loadLine
	local jump = builder:label(func)
	builder:emitAD(func, onlyTruth and LOP.JUMPIF or LOP.JUMPIFNOT, reg, 0, jumpLine)
	append(skipJumps, jump)
end
-- Compiles `expr` as a branch condition. Returns a list of jump labels (positions of
-- the emitted jump instructions, for the caller to patch) that are taken when the
-- truthiness of `expr` equals `jumpWhen`; otherwise control falls through.
-- `target` is the scratch register used to materialize the value (defaults to
-- max(nextReg, protectedTop)). The line resolved here (`lineOverride` or expr.line)
-- is recorded on the emitted jumps and passed down to every sub-condition.
function PureCompiler:compileCondJumps(ctx, expr, jumpWhen, target, lineOverride)
local builder = ctx.builder
local func = ctx.func
local line = lineOverride or expr.line or 1
target = target or math.max(ctx.nextReg, ctx.protectedTop or 0)
if (builder.options.optimizationLevel or 1) >= 1 then
local constant = self:getConstant(ctx, expr)
if constant then
-- Constant condition: an unconditional JUMP if it matches `jumpWhen`,
-- otherwise no code and no jumps.
if self:constantTruth(constant) == jumpWhen then
local jump = builder:label(func)
builder:emitAD(func, LOP.JUMP, 0, 0, line)
return { jump }
end
return {}
end
end
-- `not x`: same jumps as x with inverted polarity.
if expr.kind == "Un" and expr.op == "not" then
return self:compileCondJumps(ctx, expr.expr, not jumpWhen, target, line)
end
-- `a and b` jumping when false: either operand being false takes the jump.
if expr.kind == "Bin" and expr.op == "and" and jumpWhen == false then
local jumps = self:compileCondJumps(ctx, expr.left, false, target, line)
local more = self:compileCondJumps(ctx, expr.right, false, target, line)
for _, jump in ipairs(more) do
append(jumps, jump)
end
return jumps
end
-- `a and b` jumping when true: a false `a` skips past `b` (patched locally);
-- only `b`'s true-jumps escape to the caller.
if expr.kind == "Bin" and expr.op == "and" and jumpWhen == true then
local falseJumps = self:compileCondJumps(ctx, expr.left, false, target, line)
local trueJumps = self:compileCondJumps(ctx, expr.right, true, target, line)
local endLabel = builder:label(func)
for _, jump in ipairs(falseJumps) do
builder:patchJump(func, jump, endLabel)
end
return trueJumps
end
-- `a or b` jumping when true: either operand being true takes the jump.
if expr.kind == "Bin" and expr.op == "or" and jumpWhen == true then
local jumps = self:compileCondJumps(ctx, expr.left, true, target, line)
local more = self:compileCondJumps(ctx, expr.right, true, target, line)
for _, jump in ipairs(more) do
append(jumps, jump)
end
return jumps
end
-- `a or b` jumping when false: a true `a` skips past `b` to the fall-through
-- point (patched locally); only `b`'s false-jumps escape to the caller.
if expr.kind == "Bin" and expr.op == "or" and jumpWhen == false then
local trueJumps = self:compileCondJumps(ctx, expr.left, true, target, line)
local falseJumps = self:compileCondJumps(ctx, expr.right, false, target, line)
local bodyLabel = builder:label(func)
for _, jump in ipairs(trueJumps) do
builder:patchJump(func, jump, bodyLabel)
end
return falseJumps
end
-- Comparisons compile to a single conditional compare-and-jump.
if expr.kind == "Bin" and compareOps[expr.op] then
return { self:compileCompareJump(ctx, expr, jumpWhen, target, line) }
end
-- A local that already lives in a register is tested in place, with no MOVE.
if expr.kind == "Name" then
local localInfo = self:findLocal(ctx, expr.name)
if localInfo and localInfo.reg ~= nil then
local jump = builder:label(func)
builder:emitAD(func, jumpWhen and LOP.JUMPIF or LOP.JUMPIFNOT, localInfo.reg, 0, line)
return { jump }
end
end
-- General case: evaluate into `target` and test it. The jump takes the line of
-- the last emitted instruction if evaluating the expression emitted any.
local beforeExpr = builder:label(func)
self:compileExpr(ctx, expr, target)
local jumpLine = builder:label(func) ~= beforeExpr and (func.lines[#func.lines] or line) or line
local jump = builder:label(func)
builder:emitAD(func, jumpWhen and LOP.JUMPIF or LOP.JUMPIFNOT, target, 0, jumpLine)
return { jump }
end
-- Emits code that stores the value held in register `source` into the assignment
-- target `target`: a Name (local, upvalue or global), a Field (`obj.name`) or an
-- Index (`obj[key]`). Raises "invalid assignment target" for any other kind.
-- `tempBase` is the first scratch register used to evaluate the target's object and
-- key (defaults to `source + 1`).
function PureCompiler:compileAssignTarget(ctx, target, source, line, tempBase)
local builder = ctx.builder
local func = ctx.func
if target.kind == "Name" then
local localInfo = self:findLocal(ctx, target.name)
if localInfo then
-- The local is overwritten, so any folded constant/import path it carried
-- no longer describes its value.
localInfo.constPath = nil
localInfo.constKind = nil
-- A register-less (constant-folded) local gets a real register on first write.
if localInfo.reg == nil then
localInfo.reg = self:reserve(ctx, 1)
end
if localInfo.reg ~= source then
builder:emitABC(func, LOP.MOVE, localInfo.reg, source, 0, line)
end
else
-- Not a local in this function: an upvalue if one resolves, else a global.
local upvalue = self:getUpvalue(ctx, target.name)
if upvalue ~= nil then
builder:emitABC(func, LOP.SETUPVAL, source, upvalue, 0, line)
else
self:emitSetGlobal(ctx, source, target.name, line)
end
end
elseif target.kind == "Field" then
-- SETTABLEKS: C is the key's string hash reduced to 8 bits; the AUX word
-- carries the constant-table index of the key string.
local objectReg = self:compileExprAsSource(ctx, target.object, tempBase or source + 1)
local constant = builder:addConstantString(func, target.field)
local slot = stringHash(target.field) % 256
builder:emitABC(func, LOP.SETTABLEKS, source, objectReg, slot, line)
builder:emitAux(func, constant, line)
self:hintTemporaryExprRegType(ctx, target.object, objectReg, LBC_TYPE_TABLE, 2)
elseif target.kind == "Index" then
local keyBase = tempBase or source + 1
local objectReg = self:compileExprAsSource(ctx, target.object, keyBase)
-- When optimizing, constant keys get specialized stores:
-- string -> SETTABLEKS, integer 1..256 -> SETTABLEN (C = key - 1).
local indexConstant = (builder.options.optimizationLevel or 1) >= 1 and self:getConstant(ctx, target.index) or nil
if indexConstant and indexConstant.kind == "String" then
local constant = builder:addConstantString(func, indexConstant.value)
local slot = stringHash(indexConstant.value) % 256
builder:emitABC(func, LOP.SETTABLEKS, source, objectReg, slot, line)
builder:emitAux(func, constant, line)
self:hintTemporaryExprRegType(ctx, target.object, objectReg, LBC_TYPE_TABLE, 2)
elseif indexConstant and indexConstant.kind == "Number" and indexConstant.value % 1 == 0 and indexConstant.value >= 1 and indexConstant.value <= 256 then
builder:emitABC(func, LOP.SETTABLEN, source, objectReg, indexConstant.value - 1, line)
self:hintTemporaryExprRegType(ctx, target.object, objectReg, LBC_TYPE_TABLE, 1)
else
-- Generic key: evaluate it into a register distinct from the object's.
local indexTarget = objectReg == keyBase and keyBase + 1 or keyBase
local indexReg = self:compileExprAsSource(ctx, target.index, indexTarget)
builder:emitABC(func, LOP.SETTABLE, source, objectReg, indexReg, line)
self:hintTemporaryExprRegType(ctx, target.object, objectReg, LBC_TYPE_TABLE, 1)
self:hintTemporaryExprRegType(ctx, target.index, indexReg, LBC_TYPE_NUMBER, 1)
end
else
error("invalid assignment target", 2)
end
end
-- Compiles `return <if-expression>`, leaving the selected value in `target`.
-- At optimization level 0 every clause jumps to one shared RETURN after the else
-- value. At higher levels each clause ends in its own RETURN, so no join jumps are
-- needed. `target + 1` is the scratch register for evaluating clause conditions.
-- Each RETURN is preceded by emitCloseUpvals and uses B = 2 (one returned value).
function PureCompiler:compileReturnIfExpr(ctx, expr, line, target)
local builder = ctx.builder
local func = ctx.func
local optimizationLevel = tonumber(builder.options.optimizationLevel) or 1
if optimizationLevel < 1 then
local endJumps = {}
for _, clause in ipairs(expr.clauses) do
-- False condition -> fall to the next clause; true -> value, then jump to exit.
local falseJumps = self:compileCondJumps(ctx, clause.condition, false, target + 1)
self:compileExpr(ctx, clause.value, target)
local jump = builder:label(func)
builder:emitAD(func, LOP.JUMP, 0, 0, clause.value.line or line)
append(endJumps, jump)
local nextLabel = builder:label(func)
for _, falseJump in ipairs(falseJumps) do
builder:patchJump(func, falseJump, nextLabel)
end
end
self:compileExpr(ctx, expr.elseValue, target)
local exit = builder:label(func)
for _, jump in ipairs(endJumps) do
builder:patchJump(func, jump, exit)
end
-- The shared RETURN takes the line of the last emitted instruction.
local returnLine = func.lines[#func.lines] or line
self:emitCloseUpvals(ctx, returnLine)
builder:emitABC(func, LOP.RETURN, target, 2, 0, returnLine)
return
end
-- Optimized form: each clause returns directly.
for _, clause in ipairs(expr.clauses) do
local falseJumps = self:compileCondJumps(ctx, clause.condition, false, target + 1)
self:compileExpr(ctx, clause.value, target)
local returnLine = clause.value.line or line
self:emitCloseUpvals(ctx, returnLine)
builder:emitABC(func, LOP.RETURN, target, 2, 0, returnLine)
local nextLabel = builder:label(func)
for _, falseJump in ipairs(falseJumps) do
builder:patchJump(func, falseJump, nextLabel)
end
end
self:compileExpr(ctx, expr.elseValue, target)
local returnLine = expr.elseValue.line or line
self:emitCloseUpvals(ctx, returnLine)
builder:emitABC(func, LOP.RETURN, target, 2, 0, returnLine)
end
-- Reports whether the first recorded continue jump sits exactly at `label` and the
-- instruction there has the JUMP opcode (low 8 bits of the instruction word).
function PureCompiler:hasLeadingContinueJump(func, label, continues)
	if not continues then
		return false
	end
	local firstContinue = continues[1]
	if firstContinue ~= label then
		return false
	end
	-- Labels are offset by one when indexing `func.insns`.
	local insn = func.insns[label + 1]
	if insn == nil then
		return false
	end
	return insn % 256 == LOP.JUMP
end
-- Reports whether the first recorded break jump sits exactly at `label` and the
-- instruction there has the JUMP opcode (low 8 bits of the instruction word).
function PureCompiler:hasLeadingBreakJump(func, label, breaks)
	local leadsWithBreak = breaks and breaks[1] == label
	if not leadsWithBreak then
		return false
	end
	local insn = func.insns[label + 1]
	if insn == nil then
		return false
	end
	return insn % 256 == LOP.JUMP
end
-- Returns true when `label` occurs in the jump-label list `jumps`.
-- A nil list contains nothing.
function PureCompiler:containsJumpLabel(jumps, label)
	if jumps then
		-- Scan the sequence up to its first nil entry, as ipairs would.
		local position = 1
		local candidate = jumps[position]
		while candidate ~= nil do
			if candidate == label then
				return true
			end
			position += 1
			candidate = jumps[position]
		end
	end
	return false
end
-- Chooses the label a loop's back-edge should target.
-- Unoptimized builds always use `loopStart`. Otherwise:
--   * if the instruction at `loopStart` is a JUMP whose D operand (bits 16-31 of the
--     word) is zero, and no break jump is recorded at `loopStart`, skip over it;
--   * else if the first continue jump sits at `loopStart`, use `continueTarget`;
--   * else use `loopStart`.
function PureCompiler:getLoopBackTarget(func, loopStart, breakJumps, continueJumps, continueTarget)
	local level = self.builder.options.optimizationLevel or 1
	if level < 1 then
		return loopStart
	end
	local headInsn = func.insns[loopStart + 1]
	local headIsZeroJump = headInsn
		and headInsn % 256 == LOP.JUMP
		and math.floor(headInsn / 65536) % 65536 == 0
	if headIsZeroJump and not self:containsJumpLabel(breakJumps, loopStart) then
		return loopStart + 1
	end
	if self:hasLeadingContinueJump(func, loopStart, continueJumps) then
		return continueTarget
	end
	return loopStart
end
-- Returns true when `expr` contains a construct that prevents unrolling a numeric
-- for loop. The only leaf barrier is a function expression; every composite
-- expression kind is searched recursively. Explicit type instantiations are
-- transparent wrappers around `expr.expr` (as in exprHasRuntimeLocalForUnroll), so
-- they are traversed too. Previously a barrier hidden inside one was missed.
function PureCompiler:exprHasUnrollBarrier(expr)
	local kind = expr.kind
	if kind == "Function" then
		return true
	elseif kind == "Field" then
		return self:exprHasUnrollBarrier(expr.object)
	elseif kind == "Index" then
		return self:exprHasUnrollBarrier(expr.object) or self:exprHasUnrollBarrier(expr.index)
	elseif kind == "Call" or kind == "MethodCall" then
		-- The callee (Call) or receiver (MethodCall) is checked before the arguments.
		local head = expr.callee
		if kind == "MethodCall" then
			head = expr.object
		end
		if self:exprHasUnrollBarrier(head) then
			return true
		end
		for _, arg in ipairs(expr.args) do
			if self:exprHasUnrollBarrier(arg) then
				return true
			end
		end
	elseif kind == "Table" then
		for _, entry in ipairs(expr.entries) do
			if entry.key and self:exprHasUnrollBarrier(entry.key) then
				return true
			end
			if entry.value and self:exprHasUnrollBarrier(entry.value) then
				return true
			end
		end
	elseif kind == "Un" or kind == "SingleResult" or kind == "Instantiate" then
		return self:exprHasUnrollBarrier(expr.expr)
	elseif kind == "Bin" then
		return self:exprHasUnrollBarrier(expr.left) or self:exprHasUnrollBarrier(expr.right)
	elseif kind == "InterpString" then
		for _, value in ipairs(expr.expressions) do
			if self:exprHasUnrollBarrier(value) then
				return true
			end
		end
	elseif kind == "IfExpr" then
		for _, clause in ipairs(expr.clauses) do
			if self:exprHasUnrollBarrier(clause.condition) or self:exprHasUnrollBarrier(clause.value) then
				return true
			end
		end
		return self:exprHasUnrollBarrier(expr.elseValue)
	end
	return false
end
-- Returns true when any statement in `block` forbids loop unrolling.
-- Control-transfer statements, function declarations, nested loops and IfExprReturn
-- are barriers by themselves. Local/Assign/CallStat are barriers when one of their
-- expressions is (see exprHasUnrollBarrier), and Do/If are searched recursively.
-- Nop and unlisted statement kinds are not barriers.
function PureCompiler:blockHasUnrollBarrier(block)
	local function anyExprBarrier(list)
		for _, item in ipairs(list) do
			if self:exprHasUnrollBarrier(item) then
				return true
			end
		end
		return false
	end

	for _, stat in ipairs(block.body) do
		local kind = stat.kind
		local blocked = false
		if kind == "Break" or kind == "Continue" or kind == "Return"
			or kind == "FunctionStat" or kind == "LocalFunction"
			or kind == "While" or kind == "Repeat" or kind == "ForNumeric" or kind == "ForIn"
			or kind == "IfExprReturn" then
			blocked = true
		elseif kind == "Local" then
			blocked = anyExprBarrier(stat.values)
		elseif kind == "Assign" then
			blocked = anyExprBarrier(stat.targets) or anyExprBarrier(stat.values)
		elseif kind == "CallStat" then
			blocked = self:exprHasUnrollBarrier(stat.expr)
		elseif kind == "Do" then
			blocked = self:blockHasUnrollBarrier(stat.body)
		elseif kind == "If" then
			for _, clause in ipairs(stat.clauses) do
				if self:exprHasUnrollBarrier(clause.condition) or self:blockHasUnrollBarrier(clause.body) then
					blocked = true
					break
				end
			end
			if not blocked and stat.elseBody then
				blocked = self:blockHasUnrollBarrier(stat.elseBody)
			end
		end
		if blocked then
			return true
		end
	end
	return false
end
-- Trip count used for unrolling decisions. Every bound must be an integral value
-- within [-32767, 32767] and the step must be non-zero; otherwise nil is returned.
-- A loop whose direction moves away from `toValue` runs 0 times.
function PureCompiler:getLoopUnrollTripCount(fromValue, toValue, stepValue)
	local limit = 32767
	local function asSmallInteger(value)
		if value >= -limit and value <= limit and value % 1 == 0 then
			return value
		end
		return nil
	end

	local first = asSmallInteger(fromValue)
	local last = asSmallInteger(toValue)
	local step = asSmallInteger(stepValue)
	if first == nil or last == nil or step == nil or step == 0 then
		return nil
	end

	local movesAway = (step > 0 and last < first) or (step < 0 and last > first)
	if movesAway then
		return 0
	end
	return math.floor((last - first) / step) + 1
end
-- Resolves `name` first with findLocal and then through the enclosing function
-- contexts' `locals` tables. Returns true only when the nearest match is backed by
-- a register; unknown names and register-less (folded) locals yield false.
function PureCompiler:findRuntimeLocalForUnroll(ctx, name)
	local info = self:findLocal(ctx, name)
	local scope = ctx.parent
	while not info and scope do
		info = scope.locals[name]
		scope = scope.parent
	end
	if info then
		return info.reg ~= nil
	end
	return false
end
-- Returns true when `expr` (which may be nil) references a register-backed local,
-- looking through Un/SingleResult/Instantiate wrappers and both Bin operands.
-- Any other expression kind is treated as not referencing one.
function PureCompiler:exprHasRuntimeLocalForUnroll(ctx, expr)
	if expr == nil then
		return false
	end
	local kind = expr.kind
	if kind == "Name" then
		return self:findRuntimeLocalForUnroll(ctx, expr.name)
	end
	if kind == "Bin" then
		return self:exprHasRuntimeLocalForUnroll(ctx, expr.left) or self:exprHasRuntimeLocalForUnroll(ctx, expr.right)
	end
	local isWrapper = kind == "Un" or kind == "SingleResult" or kind == "Instantiate"
	if isWrapper then
		return self:exprHasRuntimeLocalForUnroll(ctx, expr.expr)
	end
	return false
end
-- Decides whether the numeric for loop `stat` should be fully unrolled (only at
-- optimization level 2 and above). Returns nil to keep the loop, or a plan
-- `{ iterations, from, step }`, where iterations may be 0 for an empty loop.
-- Requirements: from/to/step fold to Number constants and none of them reads a
-- register-backed local; the trip count is known and at most O2_INLINE_THRESHOLD;
-- the body never writes the loop variable and contains no unroll barrier; and the
-- estimated unrolled cost fits a threshold scaled by the estimated savings.
function PureCompiler:getNumericForLoopUnroll(ctx, stat)
	if (self.builder.options.optimizationLevel or 1) < 2 then
		return nil
	end
	if self:exprHasRuntimeLocalForUnroll(ctx, stat.from) or self:exprHasRuntimeLocalForUnroll(ctx, stat.to) or self:exprHasRuntimeLocalForUnroll(ctx, stat.step) then
		return nil
	end
	local fromConstant = self:getConstant(ctx, stat.from)
	local toConstant = self:getConstant(ctx, stat.to)
	-- An omitted step means 1, but an explicit step must itself fold to a constant.
	-- Do not write this as `stat.step and getConstant(...) or default`: a
	-- non-constant explicit step (a global, a call, ...) would then silently become
	-- 1, unrolling with the wrong stride and never evaluating the step expression.
	local stepConstant
	if stat.step then
		stepConstant = self:getConstant(ctx, stat.step)
	else
		stepConstant = { kind = "Number", value = 1 }
	end
	if not fromConstant or not toConstant or not stepConstant or fromConstant.kind ~= "Number" or toConstant.kind ~= "Number" or stepConstant.kind ~= "Number" then
		return nil
	end
	local tripCount = self:getLoopUnrollTripCount(fromConstant.value, toConstant.value, stepConstant.value)
	if tripCount == nil or tripCount < 0 or tripCount > O2_INLINE_THRESHOLD then
		return nil
	end
	-- Unrolling substitutes constants for the loop variable, so it must be read-only.
	local loopVarWritten = stat.localSymbol and stat.localSymbol.written == true or self:blockWritesName(stat.body, stat.name)
	if loopVarWritten or self:blockHasUnrollBarrier(stat.body) then
		return nil
	end
	if tripCount == 0 then
		return {
			iterations = 0,
			from = fromConstant.value,
			step = stepConstant.value,
		}
	end
	local locals = {
		[stat.name] = true,
	}
	-- Baseline: body cost per iteration plus 1 (loop overhead), times trip count.
	local baselineCost = mulInlineCost(addInlineCost(self:costInlineO2Block(ctx, stat.body, shallowCopy(locals), {}), 1), tripCount)
	-- Unrolled: body cost with the loop variable known as a constant, times trip count.
	local constants = {
		[stat.name] = { kind = "Number", value = fromConstant.value },
	}
	local unrolledCost = mulInlineCost(self:costInlineO2Block(ctx, stat.body, shallowCopy(locals), constants), tripCount)
	-- Profit (percent) raises the threshold, capped at O2_INLINE_THRESHOLD_MAX_BOOST.
	local unrollProfit = unrolledCost == 0 and O2_INLINE_THRESHOLD_MAX_BOOST or math.min(O2_INLINE_THRESHOLD_MAX_BOOST, math.floor(100 * baselineCost / unrolledCost))
	local threshold = math.floor(O2_INLINE_THRESHOLD * unrollProfit / 100)
	if unrolledCost > threshold then
		return nil
	end
	return {
		iterations = tripCount,
		from = fromConstant.value,
		step = stepConstant.value,
	}
end
-- True when the assignment target is the bare name `name` (not a field or index).
function PureCompiler:targetWritesName(target, name)
	if target.kind ~= "Name" then
		return false
	end
	return target.name == name
end
-- Returns true when some Assign statement in `block` (searching into Do blocks and
-- every If branch) assigns directly to `name`. Other statement kinds, including
-- nested loops, are not inspected.
function PureCompiler:blockWritesName(block, name)
	for _, stat in ipairs(block.body) do
		local kind = stat.kind
		local writes = false
		if kind == "Assign" then
			for _, target in ipairs(stat.targets) do
				if self:targetWritesName(target, name) then
					writes = true
					break
				end
			end
		elseif kind == "Do" then
			writes = self:blockWritesName(stat.body, name)
		elseif kind == "If" then
			for _, clause in ipairs(stat.clauses) do
				if self:blockWritesName(clause.body, name) then
					writes = true
					break
				end
			end
			if not writes and stat.elseBody then
				writes = self:blockWritesName(stat.elseBody, name)
			end
		end
		if writes then
			return true
		end
	end
	return false
end
-- Returns how many iterations a numeric for loop with constant `fromValue`,
-- `toValue` and `stepValue` performs, or nil when the count is unknown or
-- unbounded: a NaN operand, a zero step that never terminates, or a count that is
-- infinite or exceeds 2147483647.
-- Entry rule mirrored by the direction checks below: with a positive step the loop
-- is skipped when from > to; with a zero or negative step it is skipped only when
-- from < to.
function PureCompiler:getNumericForTripCount(fromValue, toValue, stepValue)
	-- NaN is the only value unequal to itself.
	local fromNan = fromValue ~= fromValue
	local toNan = toValue ~= toValue
	local stepNan = stepValue ~= stepValue
	if fromNan or toNan or stepNan then
		return nil
	end
	if stepValue == 0 then
		-- The index never advances. As with negative steps, the loop is skipped only
		-- when from < to. For from >= to (including from == to) it is entered and
		-- never exits, so equality must not be reported as an empty loop.
		if fromValue < toValue then
			return 0
		end
		return nil
	end
	if stepValue > 0 then
		if fromValue > toValue then
			return 0
		end
	elseif fromValue < toValue then
		return 0
	end
	local span = (toValue - fromValue) / stepValue
	if span ~= span or span == math.huge or span == -math.huge then
		return nil
	end
	local count = math.floor(span) + 1
	if count < 0 then
		return 0
	end
	if count == math.huge or count > 2147483647 then
		return nil
	end
	return count
end
-- True for a number literal, any chain of unary minus applied to such an
-- expression, or a binary expression whose operands both qualify.
-- NOTE(review): Bin accepts any operator here (comparisons and `..` included), not
-- only arithmetic; confirm callers rely on that.
function PureCompiler:isStaticNumericConstantExpr(expr)
	local node = expr
	while node.kind == "Un" and node.op == "-" do
		node = node.expr
	end
	if node.kind == "Number" then
		return true
	end
	if node.kind == "Bin" then
		return self:isStaticNumericConstantExpr(node.left) and self:isStaticNumericConstantExpr(node.right)
	end
	return false
end
function PureCompiler:compileStatement(ctx, stat, tailReturn)
local builder = ctx.builder
local func = ctx.func
local line = stat.line or 1
local baseReg = ctx.nextReg
tailReturn = tailReturn == true and (builder.options.optimizationLevel or 1) >= 1
if (builder.options.coverageLevel or 0) >= 1 and stat.kind ~= "Nop" then
self:emitCoverage(ctx, line)
end
if stat.kind == "Nop" then
return false
elseif stat.kind == "Local" then
if #stat.names == 1 and #stat.values == 1 and not (ctx.readNames and ctx.readNames[stat.names[1]]) then
local foldedCall = self:getDiscardableO2BuiltinCallResult(ctx, stat.values[1])
if foldedCall then
if (builder.options.debugLevel or 1) >= 2 then
local localInfo = {
name = stat.names[1],
reg = self:reserve(ctx, 1),
constKind = foldedCall.kind,
constValue = foldedCall.value,
written = false,
}
if (builder.options.coverageLevel or 0) >= 2 then
self:emitCoverage(ctx, stat.values[1].line or line)
end
self:emitConstant(ctx, foldedCall, localInfo.reg, line)
localInfo.debugStart = builder:label(func)
self:declareLocal(ctx, localInfo)
end
return false
end
end
if (builder.options.optimizationLevel or 1) >= 1 and #stat.names == 1 and #stat.values == 1 and stat.values[1].kind == "Name" then
local sourceLocal = self:findLocal(ctx, stat.values[1].name)
local written = stat.localSymbols and stat.localSymbols[1] and stat.localSymbols[1].written == true or false
if sourceLocal and sourceLocal.reg ~= nil and not sourceLocal.written and not written and not (ctx.writeNames and ctx.writeNames[stat.names[1]]) then
self:declareLocal(ctx, {
name = stat.names[1],
reg = sourceLocal.reg,
written = false,
typeId = sourceLocal.typeId,
debugStart = builder:label(func),
})
return false
end
end
local locals = {}
local initializerPaths = {}
local canUseRegisterlessConstants = #stat.values == 0 or #stat.values == #stat.names
for index, value in ipairs(stat.values) do
initializerPaths[index] = self:getImportPath(ctx, value)
local name = stat.names[index]
local shape = name and ctx.tableArrayHints[name] or nil
if (builder.options.optimizationLevel or 1) >= 1 and value.kind == "Table" and #value.entries == 0 and shape then
if shape == true then
value.predictedArrayCount = 1
else
value.predictedArrayCount = shape.arraySize or 0
value.predictedHashCount = shape.hashSize or 0
end
end
end
if #stat.values <= #stat.names then
local constantLocals = {}
local canFoldAllLocals = true
for index, name in ipairs(stat.names) do
local value = stat.values[index]
local written = stat.localSymbols and stat.localSymbols[index] and stat.localSymbols[index].written == true or false
local missingValueIsNil = value == nil
and (#stat.values == 0
or index > #stat.values
and #stat.values > 0
and stat.values[#stat.values].kind ~= "Call"
and stat.values[#stat.values].kind ~= "MethodCall"
and stat.values[#stat.values].kind ~= "Vararg")
local localInfo = (value ~= nil or missingValueIsNil) and self:makeConstantLocal(ctx, name, value, written) or nil
if not localInfo then
canFoldAllLocals = false
break
end
append(constantLocals, localInfo)
end
if canFoldAllLocals then
local debugStart = builder:label(func)
for index, localInfo in ipairs(constantLocals) do
localInfo.constPath = initializerPaths[index]
localInfo.typeId = self:getInitializerBytecodeType(ctx, stat.values, index, #stat.names)
localInfo.debugStart = debugStart
self:declareLocal(ctx, localInfo)
end
return false
end
end
for index, name in ipairs(stat.names) do
local value = stat.values[index]
local written = stat.localSymbols and stat.localSymbols[index] and stat.localSymbols[index].written == true or false
local localInfo = nil
if not localInfo then
local reg = self:reserve(ctx, 1)
localInfo = {
name = name,
reg = reg,
written = written,
typeId = self:getInitializerBytecodeType(ctx, stat.values, index, #stat.names),
}
if value and value.kind == "Function" then
localInfo.initFunction = value
localInfo.inlineFunction = value
end
local missingValueIsNil = value == nil
and (#stat.values == 0
or index > #stat.values
and #stat.values > 0
and stat.values[#stat.values].kind ~= "Call"
and stat.values[#stat.values].kind ~= "MethodCall"
and stat.values[#stat.values].kind ~= "Vararg")
if (builder.options.optimizationLevel or 1) >= 1 and not localInfo.written and (value ~= nil or missingValueIsNil) then
local constant = self:getConstant(ctx, value)
if constant then
localInfo.constKind = constant.kind
localInfo.constValue = constant.value
end
end
end
append(locals, localInfo)
end
if #locals > 0 then
local firstRegister = nil
for _, localInfo in ipairs(locals) do
if localInfo.reg ~= nil then
firstRegister = localInfo.reg
break
end
end
if firstRegister ~= nil then
local oldInitializingLocalsByReg = ctx.initializingLocalsByReg
local initializingLocalsByReg = nil
if #stat.values == 0 then
for index, localInfo in ipairs(locals) do
if localInfo.reg ~= nil then
builder:emitABC(func, LOP.LOADNIL, localInfo.reg, 0, 0, line)
end
end
elseif canUseRegisterlessConstants then
local values = {}
local registerCount = 0
for index, localInfo in ipairs(locals) do
if localInfo.reg ~= nil then
registerCount += 1
local value = stat.values[index] or node("Nil", nil, line)
values[registerCount] = value
if value.kind == "Function" then
initializingLocalsByReg = initializingLocalsByReg or {}
initializingLocalsByReg[localInfo.reg] = localInfo
end
end
end
ctx.initializingLocalsByReg = initializingLocalsByReg
self:compileExprList(ctx, values, firstRegister, registerCount)
else
for index, localInfo in ipairs(locals) do
local value = stat.values[index]
if localInfo.reg ~= nil and value and value.kind == "Function" then
initializingLocalsByReg = initializingLocalsByReg or {}
initializingLocalsByReg[localInfo.reg] = localInfo
end
end
ctx.initializingLocalsByReg = initializingLocalsByReg
self:compileExprList(ctx, stat.values, firstRegister, #locals)
end
ctx.initializingLocalsByReg = oldInitializingLocalsByReg
end
local debugStart = builder:label(func)
for index, localInfo in ipairs(locals) do
localInfo.constPath = initializerPaths[index]
localInfo.debugStart = debugStart
self:declareLocal(ctx, localInfo)
end
end
elseif stat.kind == "LocalFunction" then
local localInfo = self:addLocal(ctx, stat.name, stat.localSymbol and stat.localSymbol.written == true or false)
localInfo.inlineFunction = stat.value
localInfo.initFunction = stat.value
self:compileFunctionExpr(ctx, stat.value, localInfo.reg, stat.name, localInfo)
localInfo.debugStart = builder:label(func)
elseif stat.kind == "Assign" then
if stat.op == "=" then
if #stat.targets == 1 and #stat.values == 1 and stat.targets[1].kind == "Name" and not (stat.values[1].kind == "Table" and #stat.values[1].entries > 0) then
local localInfo = self:findLocal(ctx, stat.targets[1].name)
if localInfo then
localInfo.constPath = nil
localInfo.constKind = nil
if localInfo.reg == nil then
localInfo.reg = self:reserve(ctx, 1)
end
self:compileExpr(ctx, stat.values[1], localInfo.reg, false)
self:updateExistingLocalBytecodeType(localInfo, self:getBytecodeType(ctx, stat.values[1]))
return false
end
end
if self:tryCompileDirectLocalAssign(ctx, stat, line) then
return false
end
if #stat.targets > 1 and self:tryCompileMixedNameAssign(ctx, stat, line, baseReg) then
return false
end
if #stat.targets > 1 and self:tryCompileNonLocalNameAssign(ctx, stat, line, baseReg) then
return false
end
if #stat.targets > 1 and self:tryCompileGeneralAssign(ctx, stat, line, baseReg) then
return false
end
if #stat.targets == 1 and #stat.values == 1 and stat.targets[1].kind ~= "Name" then
local target = stat.targets[1]
local value = stat.values[1]
if target.kind == "Field" then
local objectReg = self:compileExprAsSource(ctx, target.object, baseReg)
local sourceTarget = math.max(objectReg == baseReg and baseReg + 1 or baseReg, ctx.protectedTop or 0, ctx.nextReg)
if sourceTarget == objectReg then
sourceTarget += 1
end
self:emitSourceCoverageIfNeeded(ctx, value, line)
local sourceReg = self:compileExprAsSource(ctx, value, sourceTarget)
local constant = builder:addConstantString(func, target.field)
local slot = stringHash(target.field) % 256
builder:emitABC(func, LOP.SETTABLEKS, sourceReg, objectReg, slot, line)
builder:emitAux(func, constant, line)
self:hintTemporaryExprRegType(ctx, target.object, objectReg, LBC_TYPE_TABLE, 2)
return false
elseif target.kind == "Index" then
local objectReg = self:compileExprAsSource(ctx, target.object, baseReg)
local indexConstant = (builder.options.optimizationLevel or 1) >= 1 and self:getConstant(ctx, target.index) or nil
if indexConstant and indexConstant.kind == "String" then
local sourceTarget = math.max(objectReg == baseReg and baseReg + 1 or baseReg, ctx.protectedTop or 0, ctx.nextReg)
if sourceTarget == objectReg then
sourceTarget += 1
end
self:emitSourceCoverageIfNeeded(ctx, value, line)
local sourceReg = self:compileExprAsSource(ctx, value, sourceTarget)
local constant = builder:addConstantString(func, indexConstant.value)
local slot = stringHash(indexConstant.value) % 256
builder:emitABC(func, LOP.SETTABLEKS, sourceReg, objectReg, slot, line)
builder:emitAux(func, constant, line)
self:hintTemporaryExprRegType(ctx, target.object, objectReg, LBC_TYPE_TABLE, 2)
return false
elseif indexConstant and indexConstant.kind == "Number" and indexConstant.value % 1 == 0 and indexConstant.value >= 1 and indexConstant.value <= 256 then
local sourceTarget = math.max(objectReg == baseReg and baseReg + 1 or baseReg, ctx.protectedTop or 0, ctx.nextReg)
if sourceTarget == objectReg then
sourceTarget += 1
end
self:emitSourceCoverageIfNeeded(ctx, value, line)
local sourceReg = self:compileExprAsSource(ctx, value, sourceTarget)
builder:emitABC(func, LOP.SETTABLEN, sourceReg, objectReg, indexConstant.value - 1, line)
self:hintTemporaryExprRegType(ctx, target.object, objectReg, LBC_TYPE_TABLE, 1)
return false
else
local keyTarget = objectReg == baseReg and baseReg + 1 or baseReg
local keyReg = self:compileExprAsSource(ctx, target.index, keyTarget)
local sourceTarget = math.max(baseReg, ctx.protectedTop or 0, ctx.nextReg)
while sourceTarget == objectReg or sourceTarget == keyReg do
sourceTarget += 1
end
self:emitSourceCoverageIfNeeded(ctx, value, line)
local sourceReg = self:compileExprAsSource(ctx, value, sourceTarget)
builder:emitABC(func, LOP.SETTABLE, sourceReg, objectReg, keyReg, line)
self:hintTemporaryExprRegType(ctx, target.object, objectReg, LBC_TYPE_TABLE, 1)
self:hintTemporaryExprRegType(ctx, target.index, keyReg, LBC_TYPE_NUMBER, 1)
return false
end
end
end
local singleNameLValue = nil
if #stat.targets == 1 and stat.targets[1].kind == "Name" and self:findLocal(ctx, stat.targets[1].name) == nil then
singleNameLValue = self:compileLValue(ctx, stat.targets[1], baseReg)
end
local assignLine = line
local assignTempBase = nil
if #stat.targets == 1 and #stat.values == 1 and stat.values[1].kind ~= "Call" and stat.values[1].kind ~= "MethodCall" and stat.values[1].kind ~= "Vararg" then
assignTempBase = baseReg
local beforeExpr = builder:label(func)
local value = stat.values[1]
if (builder.options.coverageLevel or 0) >= 2 and value.kind == "Name" then
local localInfo = self:findLocal(ctx, value.name)
local parentConstant = localInfo == nil and self:findParentConstant(ctx, value.name) or nil
local localConstant = self:canUseLocalConstant(localInfo)
if localConstant or parentConstant then
self:emitCoverage(ctx, value.line or line)
end
end
baseReg = self:compileExprAsSource(ctx, stat.values[1], baseReg)
assignLine = builder:label(func) == beforeExpr and line or func.lines[#func.lines] or line
if singleNameLValue ~= nil then
assignLine = stat.targets[1].line or line
end
if baseReg == assignTempBase then
assignTempBase = nil
end
else
self:compileExprList(ctx, stat.values, baseReg, #stat.targets, false)
assignLine = func.lines[#func.lines] or line
if #stat.values > #stat.targets then
local sideTarget = baseReg + #stat.targets
local oldProtectedTop = ctx.protectedTop or 0
ctx.protectedTop = math.max(oldProtectedTop, sideTarget)
for index = #stat.targets + 1, #stat.values do
self:compileExprSide(ctx, stat.values[index], sideTarget)
end
ctx.protectedTop = oldProtectedTop
assignLine = func.lines[#func.lines] or assignLine
end
end
for index, target in ipairs(stat.targets) do
if target.kind ~= "Name" or self:findLocal(ctx, target.name) == nil then
if singleNameLValue and index == 1 then
self:compileAssignLValue(ctx, singleNameLValue, baseReg + index - 1, target.line or assignLine)
else
self:compileAssignTarget(ctx, target, baseReg + index - 1, target.line or assignLine, assignTempBase)
end
end
end
for index, target in ipairs(stat.targets) do
if target.kind == "Name" and self:findLocal(ctx, target.name) ~= nil then
self:compileAssignTarget(ctx, target, baseReg + index - 1, assignLine, assignTempBase)
end
end
else
if #stat.targets ~= 1 then
error("compound assignment only supports one target", 2)
end
local target = stat.targets[1]
local localInfo = target.kind == "Name" and self:findLocal(ctx, target.name) or nil
local resultReg = localInfo and localInfo.reg or baseReg
if localInfo then
localInfo.constPath = nil
localInfo.constKind = nil
if localInfo.reg == nil then
localInfo.reg = self:reserve(ctx, 1)
end
resultReg = localInfo.reg
end
if localInfo and stat.op ~= ".." then
local op = arithmeticOps[stat.op]
if not op then
error("unsupported compound assignment operator " .. tostring(stat.op), 2)
end
local value = stat.values[1]
local rightConstant = (builder.options.optimizationLevel or 1) >= 1 and self:getConstant(ctx, value) or nil
if rightConstant and rightConstant.kind == "Number" and arithmeticKOps[stat.op] then
local constant = builder:addConstantNumber(func, rightConstant.value)
if constant <= 255 then
builder:emitABC(func, arithmeticKOps[stat.op], resultReg, resultReg, constant, line)
else
self:compileExpr(ctx, value, baseReg)
builder:emitABC(func, op, resultReg, resultReg, baseReg, line)
end
else
local rightReg = self:compileExprAsSource(ctx, value or node("Nil", nil, line), baseReg)
builder:emitABC(func, op, resultReg, resultReg, rightReg, line)
end
else
local lvalue = nil
local valueBaseReg = baseReg
if localInfo then
builder:emitABC(func, LOP.MOVE, baseReg, resultReg, 0, line)
else
local nextTemp
lvalue, nextTemp = self:compileLValue(ctx, target, baseReg)
resultReg = nextTemp
valueBaseReg = resultReg
self:compileLValueUse(ctx, lvalue, valueBaseReg, line)
end
if stat.op == ".." then
local parts = {}
self:flattenConcat(stat.values[1], parts)
for index, part in ipairs(parts) do
self:compileExpr(ctx, part, valueBaseReg + index)
end
builder:emitABC(func, LOP.CONCAT, resultReg, valueBaseReg, valueBaseReg + #parts, line)
else
local op = arithmeticOps[stat.op]
if not op then
error("unsupported compound assignment operator " .. tostring(stat.op), 2)
end
local value = stat.values[1]
local rightConstant = (builder.options.optimizationLevel or 1) >= 1 and self:getConstant(ctx, value) or nil
if rightConstant and rightConstant.kind == "Number" and arithmeticKOps[stat.op] then
local constant = builder:addConstantNumber(func, rightConstant.value)
if constant <= 255 then
builder:emitABC(func, arithmeticKOps[stat.op], resultReg, valueBaseReg, constant, line)
else
local rightReg = self:compileExprAsSource(ctx, value, valueBaseReg + 1)
builder:emitABC(func, op, resultReg, valueBaseReg, rightReg, line)
end
else
local rightReg = self:compileExprAsSource(ctx, value, valueBaseReg + 1)
builder:emitABC(func, op, resultReg, valueBaseReg, rightReg, line)
end
end
if not localInfo then
self:compileAssignLValue(ctx, lvalue, resultReg, line)
end
end
end
elseif stat.kind == "CallStat" then
if stat.expr.kind ~= "Call" or not self:compileInlineO2Call(ctx, stat.expr, baseReg, 0) then
self:compileCall(ctx, stat.expr, baseReg, 0)
end
elseif stat.kind == "Return" then
if ctx.inlineReturnFrame then
self:compileInlineO2Return(ctx, stat, ctx.inlineReturnFrame, baseReg, line)
return true
end
if #stat.values == 0 then
self:emitCloseUpvals(ctx, line)
builder:emitABC(func, LOP.RETURN, 0, 1, 0, line)
return true
end
if #stat.values == 1 and stat.values[1].kind == "IfExpr" then
self:compileReturnIfExpr(ctx, stat.values[1], line, baseReg)
return true
end
if #stat.values > 0 then
local firstReg = nil
local contiguous = true
for index, expr in ipairs(stat.values) do
local localInfo = expr.kind == "Name" and self:findLocal(ctx, expr.name) or nil
if not localInfo or localInfo.reg == nil then
contiguous = false
break
end
if index == 1 then
firstReg = localInfo.reg
elseif localInfo.reg ~= firstReg + index - 1 then
contiguous = false
break
end
end
if contiguous and firstReg ~= nil then
self:emitCloseUpvals(ctx, line)
builder:emitABC(func, LOP.RETURN, firstReg, #stat.values + 1, 0, line)
return true
end
end
if #stat.values > 0 then
for index, expr in ipairs(stat.values) do
if index == #stat.values and (expr.kind == "Call" or expr.kind == "MethodCall") then
local constant = (builder.options.optimizationLevel or 1) >= 1 and self:getConstant(ctx, expr) or nil
if constant then
self:captureParentConstantsForDebug(ctx, expr)
self:emitConstant(ctx, constant, baseReg + index - 1, line)
continue
end
local inlineReturnCount = nil
if expr.kind == "Call" then
local candidateReturnCount = self:getInlineO2ReturnCount(ctx, expr)
if candidateReturnCount == 1 then
inlineReturnCount = candidateReturnCount
end
end
if inlineReturnCount ~= nil and self:compileInlineO2Call(ctx, expr, baseReg + index - 1, inlineReturnCount) then
local returnLine = func.lines[#func.lines] or line
self:emitCloseUpvals(ctx, returnLine)
builder:emitABC(func, LOP.RETURN, baseReg, index + inlineReturnCount, 0, returnLine)
return true
elseif self:callHasOneResultO2(ctx, expr) then
self:compileCall(ctx, expr, baseReg + index - 1, 1)
else
self:compileCall(ctx, expr, baseReg + index - 1, -1)
local returnLine = func.lines[#func.lines] or line
self:emitCloseUpvals(ctx, returnLine)
builder:emitABC(func, LOP.RETURN, baseReg, 0, 0, returnLine)
return true
end
elseif index == #stat.values and expr.kind == "Vararg" then
builder:emitABC(func, LOP.GETVARARGS, baseReg + index - 1, 0, 0, expr.line or line)
self:useReg(ctx, baseReg + index - 1)
local returnLine = func.lines[#func.lines] or line
self:emitCloseUpvals(ctx, returnLine)
builder:emitABC(func, LOP.RETURN, baseReg, 0, 0, returnLine)
return true
else
self:compileExpr(ctx, expr, baseReg + index - 1)
end
end
end
local returnLine = #stat.values > 0 and (func.lines[#func.lines] or line) or line
self:emitCloseUpvals(ctx, returnLine)
builder:emitABC(func, LOP.RETURN, baseReg, #stat.values + 1, 0, returnLine)
return true
elseif stat.kind == "Do" then
local mark = self:enterScope(ctx)
local terminated = self:compileBlock(ctx, stat.body, false, tailReturn)
if not terminated then
self:emitCloseUpvals(ctx, func.lines[#func.lines] or line, mark.localCount + 1)
end
self:leaveScope(ctx, mark)
return terminated
elseif stat.kind == "If" then
if #stat.clauses == 1 and stat.elseBody == nil then
local onlyClause = stat.clauses[1]
local onlyStat = onlyClause and onlyClause.body and #onlyClause.body.body == 1 and onlyClause.body.body[1] or nil
local loopCloseStart = ctx.loopCloseStarts[#ctx.loopCloseStarts]
local needsLoopClose = loopCloseStart ~= nil and self:hasCloseUpvals(ctx, loopCloseStart)
if onlyStat and onlyStat.kind == "Break" and ctx.loopBreaks[#ctx.loopBreaks] and not needsLoopClose then
local breakJumps = {}
self:compileConditionValue(ctx, onlyClause.condition, nil, breakJumps, true)
for _, jump in ipairs(breakJumps) do
append(ctx.loopBreaks[#ctx.loopBreaks], jump)
end
return false
elseif onlyStat and onlyStat.kind == "Continue" and ctx.loopContinues[#ctx.loopContinues] and not needsLoopClose then
local continueJumps = {}
self:compileConditionValue(ctx, onlyClause.condition, nil, continueJumps, true)
for _, jump in ipairs(continueJumps) do
append(ctx.loopContinues[#ctx.loopContinues], jump)
end
return false
end
end
local endJumps = {}
local allBranchesTerminate = true
local reachedUnconditionalBranch = false
local optimizationLevel = builder.options.optimizationLevel or 1
for clauseIndex, clause in ipairs(stat.clauses) do
local condition = optimizationLevel >= 1 and self:getConstant(ctx, clause.condition) or nil
local conditionTruth = nil
if condition then
conditionTruth = self:constantTruth(condition)
end
local forcedFalseSide = nil
if not condition and optimizationLevel >= 1 and clause.condition.kind == "Bin" and clause.condition.op == "and" then
local rightConstant = self:getConstant(ctx, clause.condition.right)
if rightConstant and not self:constantTruth(rightConstant) then
forcedFalseSide = clause.condition.left
end
end
if conditionTruth == false then
-- Constant-false conditions compile as the else path only.
elseif forcedFalseSide then
self:compileExprSide(ctx, forcedFalseSide, baseReg)
else
local conditionAlwaysTrue = conditionTruth == true
local falseJumps = conditionAlwaysTrue and {} or self:compileCondJumps(ctx, clause.condition, false, baseReg)
local mark = self:enterScope(ctx)
local branchTerminates = self:compileBlock(ctx, clause.body, false, tailReturn)
local hasFollowingBranch = not conditionAlwaysTrue and (clauseIndex < #stat.clauses or stat.elseBody ~= nil)
local implicitTailReturn = false
local scopeLeft = false
if not branchTerminates then
self:emitCloseUpvals(ctx, func.lines[#func.lines] or line, mark.localCount + 1)
end
if not branchTerminates and hasFollowingBranch and tailReturn then
local returnLine = func.lines[#func.lines] or line
self:leaveScope(ctx, mark)
scopeLeft = true
local endJump = builder:label(func)
builder:emitAD(func, LOP.JUMP, 0, 0, returnLine)
append(endJumps, endJump)
branchTerminates = true
implicitTailReturn = true
end
if not scopeLeft then
self:leaveScope(ctx, mark)
end
if implicitTailReturn then
allBranchesTerminate = false
elseif not branchTerminates then
allBranchesTerminate = false
if hasFollowingBranch then
local endJump = builder:label(func)
builder:emitAD(func, LOP.JUMP, 0, 0, func.lines[#func.lines] or line)
append(endJumps, endJump)
end
end
local nextLabel = builder:label(func)
for _, jump in ipairs(falseJumps) do
builder:patchJump(func, jump, nextLabel)
end
if conditionAlwaysTrue then
reachedUnconditionalBranch = true
break
end
end
end
if reachedUnconditionalBranch then
-- Remaining elseif/else branches are unreachable.
elseif stat.elseBody then
local mark = self:enterScope(ctx)
local branchTerminates = self:compileBlock(ctx, stat.elseBody, false, tailReturn)
if not branchTerminates then
self:emitCloseUpvals(ctx, func.lines[#func.lines] or line, mark.localCount + 1)
end
self:leaveScope(ctx, mark)
if not branchTerminates then
allBranchesTerminate = false
end
else
allBranchesTerminate = false
end
local endLabel = builder:label(func)
for _, jump in ipairs(endJumps) do
builder:patchJump(func, jump, endLabel)
end
if allBranchesTerminate then
return true
end
elseif stat.kind == "While" then
if (builder.options.optimizationLevel or 1) >= 1 then
local condition = self:getConstant(ctx, stat.condition)
if condition and not self:constantTruth(condition) then
return false
end
end
ctx.hasLoops = true
local start = builder:label(func)
local falseJumps = self:compileCondJumps(ctx, stat.condition, false, baseReg)
append(ctx.loopBreaks, {})
append(ctx.loopContinues, {})
append(ctx.loopCloseStarts, #ctx.localList + 1)
append(ctx.loopKinds, "while")
local oldLoopDepth = ctx.loopDepth
ctx.loopDepth = oldLoopDepth + 1
local mark = self:enterScope(ctx)
local bodyStart = builder:label(func)
local bodyTerminated = self:compileBlock(ctx, stat.body, false)
local needsLoopClose = self:hasCloseUpvals(ctx, mark.localCount + 1)
local hasLoopCaptures = self:hasCapturedLocals(ctx, mark.localCount + 1)
self:emitCloseUpvals(ctx, func.lines[#func.lines] or line, mark.localCount + 1)
self:leaveScope(ctx, mark)
ctx.loopDepth = oldLoopDepth
local continueTarget = builder:label(func)
local continueJumps = ctx.loopContinues[#ctx.loopContinues]
local breakJumps = ctx.loopBreaks[#ctx.loopBreaks]
local backTarget = self:getLoopBackTarget(func, start, breakJumps, continueJumps, continueTarget)
local backOffset = backTarget - builder:label(func) - 1
if bodyTerminated and #breakJumps > 0 and #continueJumps == 0 and not needsLoopClose and not hasLoopCaptures and self:hasLeadingBreakJump(func, bodyStart, breakJumps) then
backOffset = 0
end
builder:emitAD(func, LOP.JUMPBACK, 0, backOffset, line)
local exit = builder:label(func)
for _, jump in ipairs(falseJumps) do
builder:patchJump(func, jump, exit)
end
local breaks = table.remove(ctx.loopBreaks) or {}
local continues = table.remove(ctx.loopContinues) or {}
table.remove(ctx.loopCloseStarts)
table.remove(ctx.loopKinds)
for _, jump in ipairs(breaks) do
builder:patchJump(func, jump, exit)
end
for _, jump in ipairs(continues) do
builder:patchJump(func, jump, continueTarget)
end
elseif stat.kind == "ForNumeric" then
local unroll = self:getNumericForLoopUnroll(ctx, stat)
if unroll then
if unroll.iterations == 0 then
return false
end
ctx.hasLoops = true
append(ctx.loopBreaks, {})
append(ctx.loopContinues, {})
append(ctx.loopCloseStarts, #ctx.localList + 1)
append(ctx.loopKinds, "for")
local oldLoopDepth = ctx.loopDepth
ctx.loopDepth = oldLoopDepth + 1
local breaks = ctx.loopBreaks[#ctx.loopBreaks]
local continues = ctx.loopContinues[#ctx.loopContinues]
for iteration = 0, unroll.iterations - 1 do
local continueStart = #continues
local mark = self:enterScope(ctx)
self:declareLocal(ctx, {
name = stat.name,
constKind = "Number",
constValue = unroll.from + iteration * unroll.step,
written = false,
})
self:compileBlock(ctx, stat.body, false)
self:emitCloseUpvals(ctx, func.lines[#func.lines] or line, mark.localCount + 1)
self:leaveScope(ctx, mark)
local continueTarget = builder:label(func)
for index = continueStart + 1, #continues do
builder:patchJump(func, continues[index], continueTarget)
end
end
local exit = builder:label(func)
for _, jump in ipairs(breaks) do
builder:patchJump(func, jump, exit)
end
table.remove(ctx.loopBreaks)
table.remove(ctx.loopContinues)
table.remove(ctx.loopCloseStarts)
table.remove(ctx.loopKinds)
ctx.loopDepth = oldLoopDepth
return false
end
ctx.hasLoops = true
local mark = self:enterScope(ctx)
local loopBase = self:reserve(ctx, 3)
local loopVarWritten = stat.localSymbol and stat.localSymbol.written == true or false
local loopVarReg = loopBase + 2
if loopVarWritten then
loopVarReg = self:reserve(ctx, 1)
end
local oldProtectedTop = ctx.protectedTop or 0
ctx.protectedTop = math.max(oldProtectedTop, ctx.nextReg)
self:compileExpr(ctx, stat.from, loopBase + 2)
self:compileExpr(ctx, stat.to, loopBase)
if stat.step then
self:compileExpr(ctx, stat.step, loopBase + 1)
else
builder:emitAD(func, LOP.LOADN, loopBase + 1, 1, line)
end
ctx.protectedTop = oldProtectedTop
local prep = builder:label(func)
builder:emitAD(func, LOP.FORNPREP, loopBase, 0, line)
local loopStart = builder:label(func)
local oldLoopDepth = ctx.loopDepth
ctx.loopDepth = oldLoopDepth + 1
if loopVarReg ~= loopBase + 2 then
builder:emitABC(func, LOP.MOVE, loopVarReg, loopBase + 2, 0, line)
end
local loopLocal = self:addLocalAt(ctx, stat.name, loopVarReg, loopVarWritten)
loopLocal.typeId = LBC_TYPE_NUMBER
append(ctx.loopBreaks, {})
append(ctx.loopContinues, {})
append(ctx.loopCloseStarts, mark.localCount + 1)
append(ctx.loopKinds, "for")
self:compileBlock(ctx, stat.body, false)
self:emitCloseUpvals(ctx, func.lines[#func.lines] or line, mark.localCount + 1)
self:leaveScope(ctx, mark)
ctx.loopDepth = oldLoopDepth
local continueTarget = builder:label(func)
local continueJumps = ctx.loopContinues[#ctx.loopContinues]
local breakJumps = ctx.loopBreaks[#ctx.loopBreaks]
local backTarget = self:getLoopBackTarget(func, loopStart, breakJumps, continueJumps, continueTarget)
local backOffset = backTarget - builder:label(func) - 1
builder:emitAD(func, LOP.FORNLOOP, loopBase, backOffset, line)
local exit = builder:label(func)
builder:patchJump(func, prep, exit)
local breaks = table.remove(ctx.loopBreaks) or {}
local continues = table.remove(ctx.loopContinues) or {}
table.remove(ctx.loopCloseStarts)
table.remove(ctx.loopKinds)
for _, jump in ipairs(breaks) do
builder:patchJump(func, jump, exit)
end
for _, jump in ipairs(continues) do
builder:patchJump(func, jump, continueTarget)
end
elseif stat.kind == "ForIn" then
ctx.hasLoops = true
local mark = self:enterScope(ctx)
local loopBase = self:reserve(ctx, 3)
self:compileExprList(ctx, stat.values, loopBase, 3)
local varsBase = self:reserve(ctx, math.max(#stat.names, 2))
local skipOp = LOP.FORGPREP
if (self.builder.options.optimizationLevel or 1) >= 1 and #stat.names <= 2 then
if #stat.values == 1 and stat.values[1].kind == "Call" then
local path = self:getImportPath(ctx, stat.values[1].callee)
if not path and stat.values[1].callee.kind == "Name" then
local localInfo = self:findLocal(ctx, stat.values[1].callee.name)
path = localInfo and localInfo.constPath or nil
end
if path and #path == 1 and path[1] == "ipairs" then
skipOp = LOP.FORGPREP_INEXT
elseif path and #path == 1 and path[1] == "pairs" then
skipOp = LOP.FORGPREP_NEXT
end
elseif #stat.values == 2 and not self.getfenvUsed and not self.setfenvUsed then
local path = self:getImportPath(ctx, stat.values[1])
if path and #path == 1 and path[1] == "next" then
skipOp = LOP.FORGPREP_NEXT
end
end
end
local prep = builder:label(func)
builder:emitAD(func, skipOp, loopBase, 0, line)
local loopStart = builder:label(func)
local oldLoopDepth = ctx.loopDepth
ctx.loopDepth = oldLoopDepth + 1
for index, name in ipairs(stat.names) do
self:addLocalAt(ctx, name, varsBase + index - 1, stat.localSymbols and stat.localSymbols[index] and stat.localSymbols[index].written == true or false)
end
append(ctx.loopBreaks, {})
append(ctx.loopContinues, {})
append(ctx.loopCloseStarts, mark.localCount + 1)
append(ctx.loopKinds, "for")
local oldBodyProtectedTop = ctx.protectedTop or 0
ctx.protectedTop = math.max(oldBodyProtectedTop, varsBase + math.max(#stat.names, 2))
local bodyStart = builder:label(func)
local bodyTerminated = self:compileBlock(ctx, stat.body, false)
local needsLoopClose = self:hasCloseUpvals(ctx, mark.localCount + 1)
local hasLoopCaptures = self:hasCapturedLocals(ctx, mark.localCount + 1)
self:emitCloseUpvals(ctx, func.lines[#func.lines] or line, mark.localCount + 1)
ctx.protectedTop = oldBodyProtectedTop
self:leaveScope(ctx, mark)
ctx.loopDepth = oldLoopDepth
local back = builder:label(func)
local continueJumps = ctx.loopContinues[#ctx.loopContinues]
local breakJumps = ctx.loopBreaks[#ctx.loopBreaks]
local backTarget = self:getLoopBackTarget(func, loopStart, breakJumps, continueJumps, back)
local breakOnlyTerminated = bodyTerminated and #breakJumps > 0 and #continueJumps == 0 and not needsLoopClose and not hasLoopCaptures and self:hasLeadingBreakJump(func, bodyStart, breakJumps)
builder:emitAD(func, LOP.FORGLOOP, loopBase, 0, line)
builder:emitAux(func, (skipOp == LOP.FORGPREP_INEXT and 0x80000000 or 0) + #stat.names, line)
local exit = builder:label(func)
builder:patchJump(func, prep, back)
builder:patchJump(func, back, breakOnlyTerminated and exit or backTarget)
local breaks = table.remove(ctx.loopBreaks) or {}
local continues = table.remove(ctx.loopContinues) or {}
table.remove(ctx.loopCloseStarts)
table.remove(ctx.loopKinds)
for _, jump in ipairs(breaks) do
builder:patchJump(func, jump, exit)
end
for _, jump in ipairs(continues) do
builder:patchJump(func, jump, back)
end
elseif stat.kind == "Repeat" then
ctx.hasLoops = true
local start = builder:label(func)
append(ctx.loopBreaks, {})
append(ctx.loopContinues, {})
append(ctx.loopCloseStarts, #ctx.localList + 1)
append(ctx.loopKinds, "repeat")
local oldLoopDepth = ctx.loopDepth
ctx.loopDepth = oldLoopDepth + 1
local mark = self:enterScope(ctx)
local oldDebugScopeDepth = ctx.debugScopeDepth
ctx.debugScopeDepth = oldDebugScopeDepth + 1
local continuesInLoop = ctx.loopContinues[#ctx.loopContinues]
local continueCount = continuesInLoop and #continuesInLoop or 0
local continueValidated = false
local conditionLocalCount = nil
for _, bodyStat in ipairs(stat.body.body) do
self:compileStatement(ctx, bodyStat, false)
ctx.nextReg = math.max(self:getLocalTop(ctx), ctx.protectedTop or 0)
if not continueValidated and continuesInLoop and #continuesInLoop > continueCount then
continueValidated = true
conditionLocalCount = #ctx.localList
end
end
ctx.debugScopeDepth = oldDebugScopeDepth
if continueValidated and conditionLocalCount ~= nil then
self:emitCloseUpvals(ctx, func.lines[#func.lines] or line, conditionLocalCount + 1)
self:popLocals(ctx, conditionLocalCount)
end
local continueTarget = builder:label(func)
local condition = (builder.options.optimizationLevel or 1) >= 1 and self:getConstant(ctx, stat.condition) or nil
local conditionLine = stat.condition.line or line
local trueJumps = {}
if not (condition and self:constantTruth(condition)) then
local beforeCondition = builder:label(func)
trueJumps = self:compileCondJumps(ctx, stat.condition, true, ctx.nextReg)
conditionLine = builder:label(func) ~= beforeCondition and (func.lines[#func.lines] or conditionLine) or conditionLine
local continues = ctx.loopContinues[#ctx.loopContinues]
local breaks = ctx.loopBreaks[#ctx.loopBreaks]
local backTarget = self:getLoopBackTarget(func, start, breaks, continues, continueTarget)
self:emitCloseUpvals(ctx, conditionLine, mark.localCount + 1)
builder:emitAD(func, LOP.JUMPBACK, 0, backTarget - builder:label(func) - 1, conditionLine)
else
self:emitCloseUpvals(ctx, conditionLine, mark.localCount + 1)
end
local exitCloseLabel = nil
if not (condition and self:constantTruth(condition)) then
exitCloseLabel = builder:label(func)
self:emitCloseUpvals(ctx, conditionLine, mark.localCount + 1)
end
self:leaveScope(ctx, mark)
ctx.loopDepth = oldLoopDepth
local exit = builder:label(func)
for _, jump in ipairs(trueJumps) do
builder:patchJump(func, jump, exitCloseLabel or exit)
end
local breaks = table.remove(ctx.loopBreaks) or {}
local continues = table.remove(ctx.loopContinues) or {}
table.remove(ctx.loopCloseStarts)
table.remove(ctx.loopKinds)
for _, jump in ipairs(breaks) do
builder:patchJump(func, jump, exit)
end
for _, jump in ipairs(continues) do
builder:patchJump(func, jump, continueTarget)
end
elseif stat.kind == "Break" then
local breaks = ctx.loopBreaks[#ctx.loopBreaks]
if not breaks then
error("break outside loop", 2)
end
self:emitCloseUpvals(ctx, line, ctx.loopCloseStarts[#ctx.loopCloseStarts])
local jump = builder:label(func)
builder:emitAD(func, LOP.JUMP, 0, 0, line)
append(breaks, jump)
return true
elseif stat.kind == "Continue" then
local continues = ctx.loopContinues[#ctx.loopContinues]
if not continues then
error("continue outside loop", 2)
end
local closeStart = ctx.loopCloseStarts[#ctx.loopCloseStarts]
if ctx.loopKinds[#ctx.loopKinds] == "repeat" then
closeStart = #ctx.localList + 1
end
self:emitCloseUpvals(ctx, line, closeStart)
local jump = builder:label(func)
builder:emitAD(func, LOP.JUMP, 0, 0, line)
append(continues, jump)
return true
elseif stat.kind == "FunctionStat" then
local debugName = nil
if stat.target.kind == "Name" then
debugName = stat.target.name
elseif stat.target.kind == "Field" then
debugName = stat.target.field
end
if stat.target.kind == "Name" then
local localInfo = self:findLocal(ctx, stat.target.name)
if localInfo and localInfo.reg ~= nil then
localInfo.constPath = nil
localInfo.constKind = nil
localInfo.constValue = nil
self:compileFunctionExpr(ctx, stat.value, localInfo.reg, debugName)
return false
end
end
self:compileFunctionExpr(ctx, stat.value, baseReg, debugName)
self:compileAssignTarget(ctx, stat.target, baseReg, line)
else
error("unsupported statement " .. tostring(stat.kind), 2)
end
ctx.nextReg = self:getLocalTop(ctx)
return false
end
-- Compiles every statement of `block` in order.
-- scoped: when truthy, the block's locals live in their own scope (enterScope/leaveScope).
-- tailReturn: when set (and optimizing), the last non-"Nop" statement is compiled
--   with the tail flag passed to compileStatement.
-- Returns true if a statement reported that control flow terminated; the
-- remaining statements are then not compiled.
function PureCompiler:compileBlock(ctx, block, scoped, tailReturn)
	local scopeMark = nil
	if scoped then
		scopeMark = self:enterScope(ctx)
	end
	local savedDebugDepth = ctx.debugScopeDepth
	ctx.debugScopeDepth = savedDebugDepth + 1

	local statements = block.body
	-- Locate the last statement that is not a "Nop"; only it may be treated as the tail.
	local lastRealIndex = nil
	if tailReturn and (ctx.builder.options.optimizationLevel or 1) >= 1 then
		local probe = #statements
		while probe >= 1 and statements[probe].kind == "Nop" do
			probe -= 1
		end
		if probe >= 1 then
			lastRealIndex = probe
		end
	end

	local didTerminate = false
	for position, statement in ipairs(statements) do
		if self:compileStatement(ctx, statement, position == lastRealIndex) then
			didTerminate = true
			break
		end
		-- Release statement temporaries, but never below the protected register floor.
		ctx.nextReg = math.max(self:getLocalTop(ctx), ctx.protectedTop or 0)
	end

	if scopeMark then
		self:leaveScope(ctx, scopeMark)
	end
	ctx.debugScopeDepth = savedDebugDepth
	return didTerminate
end
-- Compiles `block` as a function prototype taking the named `params`
-- (bound to registers 0..#params-1), optionally vararg, and registers it with
-- the builder. Returns whatever builder:addFunction returns for the new function.
function PureCompiler:compileProto(block, params, isvararg)
	local builder = self.builder
	local paramCount = #params
	local func = builder:createFunction(paramCount, isvararg)
	local ctx = self:newContext(func, nil)
	local globalWrites = self:annotateWriteSymbols(block)

	-- getfenv/setfenv usage is reset for this prototype and only scanned when optimizing.
	self.getfenvUsed = false
	self.setfenvUsed = false
	local optimize = (builder.options.optimizationLevel or 1) >= 1
	if optimize then
		self:collectFenvUses(block)
	end

	ctx.writeNames = {}
	ctx.readNames = self:collectReadNames(block)
	ctx.globalWrites = globalWrites
	local arrayHints = nil
	if optimize then
		arrayHints = self:collectTableArrayHints(block)
	end
	ctx.tableArrayHints = arrayHints or {}

	ctx.nextReg = paramCount
	ctx.maxReg = paramCount
	if isvararg then
		builder:emitABC(func, LOP.PREPVARARGS, paramCount, 0, 0, 1)
	end
	for position, paramName in ipairs(params) do
		self:declareLocal(ctx, {
			name = paramName,
			reg = position - 1,
			written = ctx.writeNames and ctx.writeNames[paramName] == true or false,
		})
	end

	local terminated = self:compileBlock(ctx, block, false, optimize)

	-- Loop-free prototypes are flagged cold; native-function marking is compiler-wide.
	if not ctx.hasLoops then
		func.flags = (func.flags or 0) + LPF_NATIVE_COLD
	end
	if self.hasNativeFunction then
		func.flags = (func.flags or 0) + LPF_NATIVE_FUNCTION
	end

	-- A body that can fall off the end gets an implicit `return` on its last line.
	if not terminated then
		local endLine = block.endLine or block.line or 1
		self:emitCloseUpvals(ctx, endLine)
		builder:emitABC(func, LOP.RETURN, 0, 1, 0, endLine)
	end
	self:popLocals(ctx, 0)
	return builder:addFunction(func)
end
-- Parses `source`, compiles it as the vararg main chunk and returns the
-- finalized bytecode produced by PureBytecodeBuilder:finalize().
function PureCompiler.compile(source, options)
	local ast = PureParser.new(source):parse()
	local builder = PureBytecodeBuilder.new(options)
	-- The main chunk has no fixed parameters and is always vararg.
	builder.mainFunction = PureCompiler.new(builder):compileProto(ast, {}, true)
	return builder:finalize()
end
local Backend = {}
-- Runs the pure-Luau compiler under pcall. On success returns its bytecode; on
-- failure converts the raised error into an error blob via BytecodeBuilder.getError.
function Backend.compile(source, options)
	local succeeded, payload = pcall(function()
		return PureCompiler.compile(source, options)
	end)
	if not succeeded then
		-- Keep only the first line of the error text.
		local message = trimStackTrace(tostring(payload))
		-- Messages that do not already start with ":<line>:" are attributed to line 1.
		if string.match(message, "^:%d+:") == nil then
			message = ":1: " .. message
		end
		return BytecodeBuilder.getError(message)
	end
	return payload
end
Luau.Backend = Backend
local Compiler = {}
Compiler.__index = Compiler
-- Public compiler facade shaped like the native Luau Compiler API; code
-- generation is delegated to Backend.compile on the raw source text.
function Compiler.new(options)
	local instance = { options = CompileOptions.new(options) }
	return setmetatable(instance, Compiler)
end
-- Compiles `source` into `bytecodeBuilder` (or a fresh BytecodeBuilder when none
-- is given) and returns its bytecode. `parseResult` and `names` are accepted for
-- API parity but are not read here; a nil `source` compiles as "".
function Compiler:compileInto(parseResult, names, bytecodeBuilder, source)
	local target = bytecodeBuilder
	if not target then
		target = BytecodeBuilder.new()
	end
	local blob = Backend.compile(source or "", self.options)
	target:setBytecode(blob)
	target:setMainFunction(0)
	target:finalize()
	return target:getBytecode()
end
-- Parses `source`, hands the parse result to compileInto and returns the bytecode
-- blob. Backend failures come back as an error blob (see Backend.compile).
function Compiler.compile(source, options, parseOptions)
	local allocator = Allocator.new()
	local nameTable = AstNameTable.new()
	local parsed = Parser.parse(source, #source, nameTable, allocator, ParseOptions.new(parseOptions))
	local output = BytecodeBuilder.new()
	Compiler.new(options):compileInto(parsed, nameTable, output, source)
	return output:getBytecode()
end
-- Like Compiler.compile, but raises the error text (the blob minus its first
-- byte) at the caller's level instead of returning an error blob.
function Compiler.compileOrThrow(source, options, parseOptions)
	local blob = Compiler.compile(source, options, parseOptions)
	if not BytecodeBuilder.isError(blob) then
		return blob
	end
	error(string.sub(blob, 2), 2)
end
Luau.Compiler = Compiler
-- Compiles Luau source text to a bytecode blob. Raises if `source` is not a string.
function Luau.compile(source, options, parseOptions)
	local isString = type(source) == "string"
	assert(isString, "source must be a string")
	return Compiler.compile(source, options, parseOptions)
end
-- Same as Luau.compile, but compile errors are raised instead of returned as an error blob.
function Luau.compileOrThrow(source, options, parseOptions)
	local isString = type(source) == "string"
	assert(isString, "source must be a string")
	return Compiler.compileOrThrow(source, options, parseOptions)
end
-- Parses Luau source text with a fresh allocator and name table and returns the
-- result of Parser.parse. Raises if `source` is not a string, matching the
-- precondition checked by Luau.compile and Luau.compileOrThrow.
function Luau.parse(source, parseOptions)
	assert(type(source) == "string", "source must be a string")
	local allocator = Allocator.new()
	local names = AstNameTable.new()
	return Parser.parse(source, #source, names, allocator, ParseOptions.new(parseOptions))
end
-- C-API-shaped entry point. Two calling conventions are accepted:
--   luau_compile(source, size, options, outsize) -- compiles the first `size` bytes
--   luau_compile(source, options, outsize)
-- When `outsize` is a table, the blob length is stored in outsize[1] and outsize.size.
function Luau.luau_compile(source, sizeOrOptions, optionsOrOutsize, maybeOutsize)
	assert(type(source) == "string", "source must be a string")
	local options, outsize
	if type(sizeOrOptions) ~= "number" then
		options, outsize = sizeOrOptions, optionsOrOutsize
	else
		source = string.sub(source, 1, sizeOrOptions)
		options, outsize = optionsOrOutsize, maybeOutsize
	end
	local blob = Luau.compile(source, options)
	if type(outsize) == "table" then
		local blobSize = #blob
		outsize[1] = blobSize
		outsize.size = blobSize
	end
	return blob
end
-- Setters mirroring the C API's luau_set_compile_constant_* functions. Each one
-- overwrites `constant.type` and `constant.value` on the table passed in.
Luau.luau_set_compile_constant_nil = function(constant)
	constant.type = "nil"
	constant.value = nil
end
-- Accepts a Lua boolean or a C-style integer flag (0 means false). A missing
-- (nil) value is false, consistent with Lua truthiness; previously nil was
-- stored as true because `nil ~= false and nil ~= 0` holds.
Luau.luau_set_compile_constant_boolean = function(constant, value)
	constant.type = "boolean"
	constant.value = value ~= nil and value ~= false and value ~= 0
end
Luau.luau_set_compile_constant_number = function(constant, value)
	constant.type = "number"
	constant.value = value
end
Luau.luau_set_compile_constant_integer64 = function(constant, value)
	constant.type = "integer64"
	constant.value = value
end
-- Stores the components as a list; `w` may presumably be nil for 3-wide vectors (TODO confirm).
Luau.luau_set_compile_constant_vector = function(constant, x, y, z, w)
	constant.type = "vector"
	constant.value = { x, y, z, w }
end
-- Stores the first `length` bytes of `value` (the whole string when `length` is nil).
Luau.luau_set_compile_constant_string = function(constant, value, length)
	constant.type = "string"
	constant.value = string.sub(value, 1, length or #value)
end
-- Parses CLI arguments into (options, files, output, selfTest, summary, help).
-- A leading "--" is skipped; unrecognized arguments are treated as input files.
-- On -h/--help it returns immediately with help = true.
local function parseCliOptions(args)
	local options = CompileOptions.new()
	local files = {}
	local output = nil
	local selfTest = false
	local summary = false

	-- `--name=value` options: prefix, CompileOptions field, whether the value is numeric.
	-- Numeric values that fail to parse leave the field unchanged.
	local longOptions = {
		{ "--optimization-level=", "optimizationLevel", true },
		{ "--debug-level=", "debugLevel", true },
		{ "--coverage-level=", "coverageLevel", true },
		{ "--vector-lib=", "vectorLib", false },
		{ "--vector-ctor=", "vectorCtor", false },
		{ "--vector-type=", "vectorType", false },
	}
	local function applyLongOption(arg)
		for _, spec in ipairs(longOptions) do
			local prefix, field, numeric = spec[1], spec[2], spec[3]
			if startsWith(arg, prefix) then
				local raw = string.sub(arg, #prefix + 1)
				if numeric then
					options[field] = tonumber(raw) or options[field]
				else
					options[field] = raw
				end
				return true
			end
		end
		return false
	end

	local cursor = args[1] == "--" and 2 or 1
	local count = #args
	while cursor <= count do
		local arg = args[cursor]
		if arg == "--compile" then
			-- Explicit marker for CLI use; modules ignore process args unless this or another CLI flag is present.
		elseif arg == "--self-test" then
			selfTest = true
		elseif arg == "--summary" then
			summary = true
		elseif arg == "-o" or arg == "--output" then
			-- The output path is the following argument.
			cursor += 1
			output = args[cursor]
		elseif startsWith(arg, "-O") and #arg >= 3 then
			options.optimizationLevel = tonumber(string.sub(arg, 3)) or options.optimizationLevel
		elseif startsWith(arg, "-g") and #arg >= 3 then
			options.debugLevel = tonumber(string.sub(arg, 3)) or options.debugLevel
		elseif applyLongOption(arg) then
			-- Handled by the long-option table.
		elseif arg == "-h" or arg == "--help" then
			return options, files, output, selfTest, summary, true
		else
			append(files, arg)
		end
		cursor += 1
	end
	return options, files, output, selfTest, summary, false
end
--- Compiles a fixed set of small snippets and, when Lune's `@lune/luau` module
-- is available, checks that the host can load each resulting bytecode blob.
-- Raises an error naming the failing snippet; returns true when all pass.
function Luau.runSelfTest()
	local snippets = {
		"print('Hello, World!')",
		"local function f(x) return x + 1 end return f(41)",
		"local t = {a = 1, [2] = 3}; return t.a + t[2]",
		"local i = 1 while i <= 3 do i += 1 end return i",
		"local s = 0 for i = 1, 3 do s += i end return s",
	}
	local hostLuau = tryRequire("@lune/luau")
	for index, source in ipairs(snippets) do
		local compiled = Luau.luau_compile(source, {
			optimizationLevel = 1,
			debugLevel = 1,
			coverageLevel = 0,
		})
		if BytecodeBuilder.isError(compiled) then
			-- The compiler's error message follows the leading marker byte.
			error(string.format("self-test snippet %d failed: %s", index, string.sub(compiled, 2)), 0)
		end
		if hostLuau then
			-- The host loader may raise instead of returning nil.
			-- Catch the error so the failing snippet index is reported.
			local ok, chunkOrError = pcall(hostLuau.load, compiled)
			if not ok then
				error(string.format(
					"self-test snippet %d produced unloadable bytecode: %s",
					index,
					trimStackTrace(tostring(chunkOrError))
				), 0)
			elseif not chunkOrError then
				error(string.format("self-test snippet %d produced unloadable bytecode", index), 0)
			end
		end
	end
	return true
end
--- lcompile command-line entry point (requires Lune's @lune/fs and @lune/stdio).
-- All input files are read, joined with newlines and compiled as a single chunk.
-- Bytecode goes to the -o/--output file, or to stdout when no output is given.
-- @param args array of argument strings (see parseCliOptions)
-- @return 0 on success, 1 when compilation (or --summary decoding) fails
function Luau.main(args)
	local fs = tryRequire("@lune/fs")
	local stdio = tryRequire("@lune/stdio")
	if not fs or not stdio then
		error("lcompile CLI requires Lune's @lune/fs and @lune/stdio modules", 2)
	end
	-- Diagnostics go to stderr so they never mix with bytecode written to stdout.
	local writeError = stdio.ewrite or stdio.write
	local options, files, output, selfTest, summary, help = parseCliOptions(args or {})
	if help then
		stdio.write(table.concat({
			"usage: lune run lcompile.luau -- [options] <file>...",
			"options:",
			"  -O<n>, --optimization-level=<n>  set the optimization level",
			"  -g<n>, --debug-level=<n>         set the debug level",
			"  --coverage-level=<n>             set the coverage level",
			"  --vector-lib=<name>              set the vector library name",
			"  --vector-ctor=<name>             set the vector constructor name",
			"  --vector-type=<name>             set the vector type name",
			"  -o <file>, --output <file>       write bytecode to <file> instead of stdout",
			"  --summary                        print a bytecode summary instead of bytecode",
			"  --self-test                      compile built-in snippets and exit",
			"  --compile                        explicit CLI marker (no effect)",
			"  -h, --help                       show this message",
			"",
		}, "\n"))
		return 0
	end
	if selfTest then
		Luau.runSelfTest()
		stdio.write("lcompile self-test passed\n")
		return 0
	end
	if #files == 0 then
		return 0
	end
	local chunks = {}
	for _, file in ipairs(files) do
		append(chunks, fs.readFile(file))
	end
	local source = table.concat(chunks, "\n")
	local bytecodeBlob = Luau.luau_compile(source, options)
	if summary then
		local info = BytecodeReader.summary(bytecodeBlob)
		if info.error then
			stdio.write(string.format("error %s\n", info.error))
		else
			stdio.write(string.format("bytecode version=%d typeVersion=%s size=%d\n", info.version, tostring(info.typeVersion), info.size))
		end
		return info.error and 1 or 0
	end
	if BytecodeBuilder.isError(bytecodeBlob) then
		-- Report the compiler message instead of silently emitting the error blob.
		-- The message follows the leading marker byte.
		writeError(string.format("lcompile: %s\n", string.sub(bytecodeBlob, 2)))
		return 1
	end
	if output then
		fs.writeFile(output, bytecodeBlob)
	else
		stdio.write(bytecodeBlob)
	end
	return 0
end
-- Publish the compiler globally as `luau_compile` and `LuauCompile` when a
-- global table exists. The write is wrapped in pcall so that a failing
-- assignment (presumably a read-only or sandboxed _G) is ignored
-- rather than aborting module load.
if type(_G) == "table" then
	local function exportGlobals()
		_G.luau_compile = Luau.luau_compile
		_G.LuauCompile = Luau
	end
	pcall(exportGlobals)
end
-- When loaded by Lune with process arguments, run as a command-line tool.
-- To suppress this, set the global __LCOMPILE_DISABLE_CLI = true before this
-- module body runs (i.e. before requiring it).
do
	local process = tryRequire("@lune/process")
	if process and process.args and #process.args > 0
		and not (type(_G) == "table" and rawget(_G, "__LCOMPILE_DISABLE_CLI") == true)
	then
		local exitCode = Luau.main(process.args)
		-- main returns a process-style exit code; previously it was discarded,
		-- so compile failures still exited with status 0.
		if type(exitCode) == "number" and exitCode ~= 0 and type(process.exit) == "function" then
			process.exit(exitCode)
		end
	end
end
return Luau
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment