Lua vararg testing
parent
c70f8fef1a
commit
ec6db8c766
|
@ -114,6 +114,31 @@ export function evalExpression(
|
|||
return evalPrefixExpression(e, env, sf);
|
||||
case "TableConstructor": {
|
||||
const table = new LuaTable();
|
||||
if (
|
||||
e.fields.length === 1 &&
|
||||
e.fields[0].type === "ExpressionField" &&
|
||||
e.fields[0].value.type === "Variable" &&
|
||||
e.fields[0].value.name === "..."
|
||||
) {
|
||||
const varargs = env.get("...");
|
||||
if (varargs instanceof Promise) {
|
||||
return varargs.then((resolvedVarargs) => {
|
||||
if (resolvedVarargs instanceof LuaTable) {
|
||||
const newTable = new LuaTable();
|
||||
for (let i = 1; i <= resolvedVarargs.length; i++) {
|
||||
newTable.set(i, resolvedVarargs.get(i), sf);
|
||||
}
|
||||
return newTable;
|
||||
}
|
||||
return table;
|
||||
});
|
||||
} else if (varargs instanceof LuaTable) {
|
||||
for (let i = 1; i <= varargs.length; i++) {
|
||||
table.set(i, varargs.get(i), sf);
|
||||
}
|
||||
}
|
||||
return table;
|
||||
}
|
||||
const promises: Promise<void>[] = [];
|
||||
for (const field of e.fields) {
|
||||
switch (field.type) {
|
||||
|
@ -162,21 +187,49 @@ export function evalExpression(
|
|||
case "ExpressionField": {
|
||||
const value = evalExpression(field.value, env, sf);
|
||||
if (value instanceof Promise) {
|
||||
promises.push(value.then((value) => {
|
||||
// +1 because Lua tables are 1-indexed
|
||||
table.set(
|
||||
table.length + 1,
|
||||
singleResult(value),
|
||||
sf,
|
||||
);
|
||||
promises.push(value.then(async (value) => {
|
||||
if (
|
||||
field.value.type === "Variable" &&
|
||||
field.value.name === "..."
|
||||
) {
|
||||
// Special handling for {...}
|
||||
const varargs = await Promise.resolve(env.get("..."));
|
||||
if (varargs instanceof LuaTable) {
|
||||
// Copy all values from varargs table
|
||||
for (let i = 1; i <= varargs.length; i++) {
|
||||
const val = await Promise.resolve(varargs.get(i));
|
||||
table.set(i, val, sf);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Normal case
|
||||
table.set(table.length + 1, singleResult(value), sf);
|
||||
}
|
||||
}));
|
||||
} else {
|
||||
// +1 because Lua tables are 1-indexed
|
||||
table.set(
|
||||
table.length + 1,
|
||||
singleResult(value),
|
||||
sf,
|
||||
if (
|
||||
field.value.type === "Variable" && field.value.name === "..."
|
||||
) {
|
||||
// Special handling for {...}
|
||||
const varargs = env.get("...");
|
||||
if (varargs instanceof LuaTable) {
|
||||
for (let i = 1; i <= varargs.length; i++) {
|
||||
const val = varargs.get(i);
|
||||
if (val instanceof Promise) {
|
||||
promises.push(
|
||||
Promise.resolve(val).then((val) => {
|
||||
table.set(i, val, sf);
|
||||
}),
|
||||
);
|
||||
} else {
|
||||
table.set(i, val, sf);
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Normal case
|
||||
table.set(table.length + 1, singleResult(value), sf);
|
||||
}
|
||||
}
|
||||
break;
|
||||
}
|
||||
|
@ -258,53 +311,46 @@ function evalPrefixExpression(
|
|||
sf.withCtx(e.prefix.ctx),
|
||||
);
|
||||
}
|
||||
|
||||
// Special handling for f(...) - propagate varargs
|
||||
if (
|
||||
e.args.length === 1 && e.args[0].type === "Variable" &&
|
||||
e.args[0].name === "..."
|
||||
) {
|
||||
const varargs = env.get("...");
|
||||
const resolveVarargs = async () => {
|
||||
const resolvedVarargs = await Promise.resolve(varargs);
|
||||
if (resolvedVarargs instanceof LuaTable) {
|
||||
const args = [];
|
||||
for (let i = 1; i <= resolvedVarargs.length; i++) {
|
||||
const val = await Promise.resolve(resolvedVarargs.get(i));
|
||||
args.push(val);
|
||||
}
|
||||
return args;
|
||||
}
|
||||
return [];
|
||||
};
|
||||
|
||||
if (prefixValue instanceof Promise) {
|
||||
return prefixValue.then((prefixValue) => {
|
||||
if (!prefixValue) {
|
||||
throw new LuaRuntimeError(
|
||||
`Attempting to call a nil value`,
|
||||
sf.withCtx(e.prefix.ctx),
|
||||
);
|
||||
}
|
||||
let selfArgs: LuaValue[] = [];
|
||||
// Handling a:b() syntax (b is kept in .name)
|
||||
if (e.name && !prefixValue.get) {
|
||||
throw new LuaRuntimeError(
|
||||
`Attempting to index a non-table: ${prefixValue}`,
|
||||
sf.withCtx(e.prefix.ctx),
|
||||
);
|
||||
} else if (e.name) {
|
||||
// Two things need to happen: the actual function be called needs to be looked up in the table, and the table itself needs to be passed as the first argument
|
||||
selfArgs = [prefixValue];
|
||||
prefixValue = prefixValue.get(e.name);
|
||||
}
|
||||
if (!prefixValue.call) {
|
||||
throw new LuaRuntimeError(
|
||||
`Attempting to call ${prefixValue} as a function`,
|
||||
sf.withCtx(e.prefix.ctx),
|
||||
);
|
||||
}
|
||||
const args = evalPromiseValues(
|
||||
e.args.map((arg) => evalExpression(arg, env, sf)),
|
||||
);
|
||||
if (args instanceof Promise) {
|
||||
return args.then((args) =>
|
||||
luaCall(prefixValue, [...selfArgs, ...args], e.ctx, sf)
|
||||
);
|
||||
} else {
|
||||
return luaCall(prefixValue, [...selfArgs, ...args], e.ctx, sf);
|
||||
}
|
||||
return prefixValue.then(async (resolvedPrefix) => {
|
||||
const args = await resolveVarargs();
|
||||
return luaCall(resolvedPrefix, args, e.ctx, sf);
|
||||
});
|
||||
} else {
|
||||
return resolveVarargs().then((args) =>
|
||||
luaCall(prefixValue, args, e.ctx, sf)
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
// Normal argument handling
|
||||
let selfArgs: LuaValue[] = [];
|
||||
// Handling a:b() syntax (b is kept in .name)
|
||||
if (e.name && !prefixValue.get) {
|
||||
throw new LuaRuntimeError(
|
||||
`Attempting to index a non-table: ${prefixValue}`,
|
||||
sf.withCtx(e.prefix.ctx),
|
||||
);
|
||||
} else if (e.name) {
|
||||
// Two things need to happen: the actual function be called needs to be looked up in the table, and the table itself needs to be passed as the first argument
|
||||
selfArgs = [prefixValue];
|
||||
prefixValue = prefixValue.get(e.name);
|
||||
}
|
||||
|
@ -325,7 +371,6 @@ function evalPrefixExpression(
|
|||
return luaCall(prefixValue, [...selfArgs, ...args], e.ctx, sf);
|
||||
}
|
||||
}
|
||||
}
|
||||
default:
|
||||
throw new Error(`Unknown prefix expression type ${e.type}`);
|
||||
}
|
||||
|
|
|
@ -438,6 +438,68 @@ local function test_advanced_closures()
|
|||
assert(c2.get() == 5, "Second counter should be independent")
|
||||
end
|
||||
|
||||
-- Test varargs handling
|
||||
local function test_varargs()
|
||||
-- Basic varargs sum function
|
||||
local function sum(...)
|
||||
local total = 0
|
||||
for i, v in ipairs({ ... }) do
|
||||
total = total + v
|
||||
end
|
||||
return total
|
||||
end
|
||||
|
||||
assert(sum() == 0, "Sum should handle no arguments")
|
||||
assert(sum(42) == 42, "Sum should handle single argument")
|
||||
assert(sum(1, 2) == 3, "Sum should handle two arguments")
|
||||
|
||||
-- Test varargs propagation
|
||||
local function pass_varargs(...)
|
||||
return sum(...)
|
||||
end
|
||||
|
||||
assert(pass_varargs() == 0, "Should propagate empty varargs")
|
||||
assert(pass_varargs(1, 2, 3) == 6, "Should propagate varargs")
|
||||
end
|
||||
|
||||
test_varargs()
|
||||
print("All varargs tests passed!")
|
||||
|
||||
-- Test closure behavior
|
||||
local function test_closures()
|
||||
-- Counter that can count by custom steps
|
||||
local function make_counter_with_step()
|
||||
local count = 0
|
||||
return {
|
||||
increment = function(step)
|
||||
count = count + (step or 1)
|
||||
return count
|
||||
end,
|
||||
decrement = function(step)
|
||||
count = count - (step or 1)
|
||||
return count
|
||||
end,
|
||||
get = function()
|
||||
return count
|
||||
end
|
||||
}
|
||||
end
|
||||
|
||||
local counter = make_counter_with_step()
|
||||
assert(counter.increment(5) == 5, "Counter should increment by 5")
|
||||
assert(counter.decrement(2) == 3, "Counter should decrement by 2")
|
||||
assert(counter.get() == 3, "Counter should maintain state")
|
||||
assert(counter.increment() == 4, "Counter should default to 1")
|
||||
|
||||
-- Test multiple independent counters
|
||||
local c1 = make_counter_with_step()
|
||||
local c2 = make_counter_with_step()
|
||||
c1.increment(10)
|
||||
c2.increment(5)
|
||||
assert(c1.get() == 10, "First counter should be independent")
|
||||
assert(c2.get() == 5, "Second counter should be independent")
|
||||
end
|
||||
|
||||
-- Test closures with shared upvalues
|
||||
local function test_shared_closures()
|
||||
local function make_shared_counter()
|
||||
|
@ -463,78 +525,9 @@ local function test_shared_closures()
|
|||
assert(get() == 1, "Get should return current value")
|
||||
end
|
||||
|
||||
-- Test varargs handling
|
||||
local function test_varargs()
|
||||
-- Basic varargs sum function
|
||||
local function sum(...)
|
||||
local args = { ... }
|
||||
local total = 0
|
||||
for _, v in ipairs(args) do
|
||||
total = total + v
|
||||
end
|
||||
return total
|
||||
end
|
||||
|
||||
assert(sum(1, 2, 3, 4, 5) == 15, "Sum should handle multiple arguments")
|
||||
assert(sum() == 0, "Sum should handle no arguments")
|
||||
assert(sum(42) == 42, "Sum should handle single argument")
|
||||
|
||||
-- Test varargs propagation
|
||||
local function pass_varargs(...)
|
||||
return sum(...)
|
||||
end
|
||||
|
||||
assert(pass_varargs(1, 2, 3) == 6, "Should propagate varargs")
|
||||
assert(pass_varargs() == 0, "Should propagate empty varargs")
|
||||
|
||||
-- Test mixing regular args with varargs
|
||||
local function first_plus_sum(first, ...)
|
||||
local args = { ... }
|
||||
local total = first or 0
|
||||
for _, v in ipairs(args) do
|
||||
total = total + v
|
||||
end
|
||||
return total
|
||||
end
|
||||
|
||||
assert(first_plus_sum(10, 1, 2, 3) == 16, "Should handle mixed arguments")
|
||||
assert(first_plus_sum(5) == 5, "Should handle only first argument")
|
||||
end
|
||||
|
||||
-- Test closure edge cases
|
||||
local function test_closure_edge_cases()
|
||||
-- Test closure over loop variables
|
||||
local closures = {}
|
||||
for i = 1, 3 do
|
||||
closures[i] = function() return i end
|
||||
end
|
||||
|
||||
assert(closures[1]() == 1, "Should capture loop variable")
|
||||
assert(closures[2]() == 2, "Should capture loop variable")
|
||||
assert(closures[3]() == 3, "Should capture loop variable")
|
||||
|
||||
-- Test nested closure scopes
|
||||
local function make_nested_counter(start)
|
||||
local count = start
|
||||
return function()
|
||||
local function increment()
|
||||
count = count + 1
|
||||
return count
|
||||
end
|
||||
return increment()
|
||||
end
|
||||
end
|
||||
|
||||
local counter1 = make_nested_counter(5)
|
||||
local counter2 = make_nested_counter(10)
|
||||
assert(counter1() == 6, "First nested counter")
|
||||
assert(counter1() == 7, "First nested counter increment")
|
||||
assert(counter2() == 11, "Second nested counter independent")
|
||||
end
|
||||
|
||||
-- Run the new tests
|
||||
test_advanced_closures()
|
||||
test_shared_closures()
|
||||
-- Run all tests
|
||||
test_varargs()
|
||||
test_closure_edge_cases()
|
||||
print("All closure and varargs tests passed!")
|
||||
test_closures()
|
||||
test_shared_closures()
|
||||
test_advanced_closures()
|
||||
print("All tests passed!")
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
import type { ASTCtx, LuaFunctionBody } from "./ast.ts";
|
||||
import { evalStatement } from "$common/space_lua/eval.ts";
|
||||
import { asyncQuickSort } from "$common/space_lua/util.ts";
|
||||
import { asyncQuickSort, evalPromiseValues } from "$common/space_lua/util.ts";
|
||||
|
||||
export type LuaType =
|
||||
| "nil"
|
||||
|
@ -141,26 +141,95 @@ export class LuaFunction implements ILuaFunction {
|
|||
private capturedEnv: LuaEnv;
|
||||
|
||||
constructor(readonly body: LuaFunctionBody, closure: LuaEnv) {
|
||||
// Don't create a new environment, just store the reference to the closure environment
|
||||
this.capturedEnv = closure;
|
||||
}
|
||||
|
||||
call(sf: LuaStackFrame, ...args: LuaValue[]): Promise<LuaValue> | LuaValue {
|
||||
call(sf: LuaStackFrame, ...args: LuaValue[]): Promise<LuaValue> {
|
||||
// Create a new environment that chains to the captured environment
|
||||
const env = new LuaEnv(this.capturedEnv);
|
||||
if (!sf) {
|
||||
console.trace(sf);
|
||||
}
|
||||
env.setLocal("_CTX", sf.threadLocal);
|
||||
|
||||
// Assign the passed arguments to the parameters
|
||||
for (let i = 0; i < this.body.parameters.length; i++) {
|
||||
const paramName = this.body.parameters[i];
|
||||
if (paramName === "...") {
|
||||
// Handle varargs by creating a table with all remaining arguments
|
||||
const varargs = new LuaTable();
|
||||
// Include all remaining arguments (might be none)
|
||||
for (let j = i; j < args.length; j++) {
|
||||
varargs.set(j - i + 1, args[j], sf);
|
||||
}
|
||||
env.setLocal("...", varargs);
|
||||
break;
|
||||
}
|
||||
let arg = args[i];
|
||||
if (arg === undefined) {
|
||||
arg = null;
|
||||
}
|
||||
env.setLocal(this.body.parameters[i], arg);
|
||||
}
|
||||
return evalStatement(this.body.block, env, sf).catch((e: any) => {
|
||||
|
||||
// If the function has varargs parameter but it wasn't set above, set an empty varargs table
|
||||
if (this.body.parameters.includes("...") && !env.has("...")) {
|
||||
env.setLocal("...", new LuaTable());
|
||||
}
|
||||
|
||||
const resolvedArgs = evalPromiseValues(args);
|
||||
if (resolvedArgs instanceof Promise) {
|
||||
return resolvedArgs.then((args) => this.callWithArgs(args, env, sf));
|
||||
}
|
||||
return this.callWithArgs(resolvedArgs, env, sf);
|
||||
}
|
||||
|
||||
toString(): string {
|
||||
return `<lua function(${this.body.parameters.join(", ")})>`;
|
||||
}
|
||||
|
||||
private callWithArgs(
|
||||
args: LuaValue[],
|
||||
env: LuaEnv,
|
||||
sf: LuaStackFrame,
|
||||
): Promise<LuaValue> {
|
||||
// Set up parameters and varargs
|
||||
for (let i = 0; i < this.body.parameters.length; i++) {
|
||||
const paramName = this.body.parameters[i];
|
||||
if (paramName === "...") {
|
||||
const varargs = new LuaTable();
|
||||
for (let j = i; j < args.length; j++) {
|
||||
if (args[j] instanceof Promise) {
|
||||
return Promise.all(args.slice(i)).then((resolvedArgs) => {
|
||||
const varargs = new LuaTable();
|
||||
resolvedArgs.forEach((val, idx) => varargs.set(idx + 1, val, sf));
|
||||
env.setLocal("...", varargs);
|
||||
return this.evalBody(env, sf);
|
||||
});
|
||||
}
|
||||
varargs.set(j - i + 1, args[j], sf);
|
||||
}
|
||||
env.setLocal("...", varargs);
|
||||
break;
|
||||
}
|
||||
env.setLocal(paramName, args[i] ?? null);
|
||||
}
|
||||
|
||||
// Ensure empty varargs table exists if needed
|
||||
if (this.body.parameters.includes("...") && !env.has("...")) {
|
||||
env.setLocal("...", new LuaTable());
|
||||
}
|
||||
|
||||
return this.evalBody(env, sf);
|
||||
}
|
||||
|
||||
private async evalBody(
|
||||
env: LuaEnv,
|
||||
sf: LuaStackFrame,
|
||||
): Promise<LuaValue> {
|
||||
try {
|
||||
await evalStatement(this.body.block, env, sf);
|
||||
} catch (e: any) {
|
||||
if (e instanceof LuaReturn) {
|
||||
if (e.values.length === 0) {
|
||||
return;
|
||||
|
@ -172,11 +241,7 @@ export class LuaFunction implements ILuaFunction {
|
|||
} else {
|
||||
throw e;
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
toString(): string {
|
||||
return `<lua function(${this.body.parameters.join(", ")})>`;
|
||||
}
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue