Lua vararg testing

pull/1219/head
Zef Hemel 2025-01-08 11:04:33 +01:00
parent c70f8fef1a
commit ec6db8c766
3 changed files with 262 additions and 159 deletions

View File

@ -114,6 +114,31 @@ export function evalExpression(
return evalPrefixExpression(e, env, sf); return evalPrefixExpression(e, env, sf);
case "TableConstructor": { case "TableConstructor": {
const table = new LuaTable(); const table = new LuaTable();
if (
e.fields.length === 1 &&
e.fields[0].type === "ExpressionField" &&
e.fields[0].value.type === "Variable" &&
e.fields[0].value.name === "..."
) {
const varargs = env.get("...");
if (varargs instanceof Promise) {
return varargs.then((resolvedVarargs) => {
if (resolvedVarargs instanceof LuaTable) {
const newTable = new LuaTable();
for (let i = 1; i <= resolvedVarargs.length; i++) {
newTable.set(i, resolvedVarargs.get(i), sf);
}
return newTable;
}
return table;
});
} else if (varargs instanceof LuaTable) {
for (let i = 1; i <= varargs.length; i++) {
table.set(i, varargs.get(i), sf);
}
}
return table;
}
const promises: Promise<void>[] = []; const promises: Promise<void>[] = [];
for (const field of e.fields) { for (const field of e.fields) {
switch (field.type) { switch (field.type) {
@ -162,21 +187,49 @@ export function evalExpression(
case "ExpressionField": { case "ExpressionField": {
const value = evalExpression(field.value, env, sf); const value = evalExpression(field.value, env, sf);
if (value instanceof Promise) { if (value instanceof Promise) {
promises.push(value.then((value) => { promises.push(value.then(async (value) => {
// +1 because Lua tables are 1-indexed if (
table.set( field.value.type === "Variable" &&
table.length + 1, field.value.name === "..."
singleResult(value), ) {
sf, // Special handling for {...}
); const varargs = await Promise.resolve(env.get("..."));
if (varargs instanceof LuaTable) {
// Copy all values from varargs table
for (let i = 1; i <= varargs.length; i++) {
const val = await Promise.resolve(varargs.get(i));
table.set(i, val, sf);
}
}
} else {
// Normal case
table.set(table.length + 1, singleResult(value), sf);
}
})); }));
} else { } else {
// +1 because Lua tables are 1-indexed if (
table.set( field.value.type === "Variable" && field.value.name === "..."
table.length + 1, ) {
singleResult(value), // Special handling for {...}
sf, const varargs = env.get("...");
); if (varargs instanceof LuaTable) {
for (let i = 1; i <= varargs.length; i++) {
const val = varargs.get(i);
if (val instanceof Promise) {
promises.push(
Promise.resolve(val).then((val) => {
table.set(i, val, sf);
}),
);
} else {
table.set(i, val, sf);
}
}
}
} else {
// Normal case
table.set(table.length + 1, singleResult(value), sf);
}
} }
break; break;
} }
@ -258,73 +311,65 @@ function evalPrefixExpression(
sf.withCtx(e.prefix.ctx), sf.withCtx(e.prefix.ctx),
); );
} }
if (prefixValue instanceof Promise) {
return prefixValue.then((prefixValue) => { // Special handling for f(...) - propagate varargs
if (!prefixValue) { if (
throw new LuaRuntimeError( e.args.length === 1 && e.args[0].type === "Variable" &&
`Attempting to call a nil value`, e.args[0].name === "..."
sf.withCtx(e.prefix.ctx), ) {
); const varargs = env.get("...");
const resolveVarargs = async () => {
const resolvedVarargs = await Promise.resolve(varargs);
if (resolvedVarargs instanceof LuaTable) {
const args = [];
for (let i = 1; i <= resolvedVarargs.length; i++) {
const val = await Promise.resolve(resolvedVarargs.get(i));
args.push(val);
}
return args;
} }
let selfArgs: LuaValue[] = []; return [];
// Handling a:b() syntax (b is kept in .name) };
if (e.name && !prefixValue.get) {
throw new LuaRuntimeError( if (prefixValue instanceof Promise) {
`Attempting to index a non-table: ${prefixValue}`, return prefixValue.then(async (resolvedPrefix) => {
sf.withCtx(e.prefix.ctx), const args = await resolveVarargs();
); return luaCall(resolvedPrefix, args, e.ctx, sf);
} else if (e.name) { });
// Two things need to happen: the actual function be called needs to be looked up in the table, and the table itself needs to be passed as the first argument
selfArgs = [prefixValue];
prefixValue = prefixValue.get(e.name);
}
if (!prefixValue.call) {
throw new LuaRuntimeError(
`Attempting to call ${prefixValue} as a function`,
sf.withCtx(e.prefix.ctx),
);
}
const args = evalPromiseValues(
e.args.map((arg) => evalExpression(arg, env, sf)),
);
if (args instanceof Promise) {
return args.then((args) =>
luaCall(prefixValue, [...selfArgs, ...args], e.ctx, sf)
);
} else {
return luaCall(prefixValue, [...selfArgs, ...args], e.ctx, sf);
}
});
} else {
let selfArgs: LuaValue[] = [];
// Handling a:b() syntax (b is kept in .name)
if (e.name && !prefixValue.get) {
throw new LuaRuntimeError(
`Attempting to index a non-table: ${prefixValue}`,
sf.withCtx(e.prefix.ctx),
);
} else if (e.name) {
// Two things need to happen: the actual function be called needs to be looked up in the table, and the table itself needs to be passed as the first argument
selfArgs = [prefixValue];
prefixValue = prefixValue.get(e.name);
}
if (!prefixValue.call) {
throw new LuaRuntimeError(
`Attempting to call ${prefixValue} as a function`,
sf.withCtx(e.prefix.ctx),
);
}
const args = evalPromiseValues(
e.args.map((arg) => evalExpression(arg, env, sf)),
);
if (args instanceof Promise) {
return args.then((args) =>
luaCall(prefixValue, [...selfArgs, ...args], e.ctx, sf)
);
} else { } else {
return luaCall(prefixValue, [...selfArgs, ...args], e.ctx, sf); return resolveVarargs().then((args) =>
luaCall(prefixValue, args, e.ctx, sf)
);
} }
} }
// Normal argument handling
let selfArgs: LuaValue[] = [];
if (e.name && !prefixValue.get) {
throw new LuaRuntimeError(
`Attempting to index a non-table: ${prefixValue}`,
sf.withCtx(e.prefix.ctx),
);
} else if (e.name) {
selfArgs = [prefixValue];
prefixValue = prefixValue.get(e.name);
}
if (!prefixValue.call) {
throw new LuaRuntimeError(
`Attempting to call ${prefixValue} as a function`,
sf.withCtx(e.prefix.ctx),
);
}
const args = evalPromiseValues(
e.args.map((arg) => evalExpression(arg, env, sf)),
);
if (args instanceof Promise) {
return args.then((args) =>
luaCall(prefixValue, [...selfArgs, ...args], e.ctx, sf)
);
} else {
return luaCall(prefixValue, [...selfArgs, ...args], e.ctx, sf);
}
} }
default: default:
throw new Error(`Unknown prefix expression type ${e.type}`); throw new Error(`Unknown prefix expression type ${e.type}`);

View File

@ -438,6 +438,68 @@ local function test_advanced_closures()
assert(c2.get() == 5, "Second counter should be independent") assert(c2.get() == 5, "Second counter should be independent")
end end
-- Test varargs handling
local function test_varargs()
-- Basic varargs sum function
local function sum(...)
local total = 0
for i, v in ipairs({ ... }) do
total = total + v
end
return total
end
assert(sum() == 0, "Sum should handle no arguments")
assert(sum(42) == 42, "Sum should handle single argument")
assert(sum(1, 2) == 3, "Sum should handle two arguments")
-- Test varargs propagation
local function pass_varargs(...)
return sum(...)
end
assert(pass_varargs() == 0, "Should propagate empty varargs")
assert(pass_varargs(1, 2, 3) == 6, "Should propagate varargs")
end
test_varargs()
print("All varargs tests passed!")
-- Test closure behavior
local function test_closures()
-- Counter that can count by custom steps
local function make_counter_with_step()
local count = 0
return {
increment = function(step)
count = count + (step or 1)
return count
end,
decrement = function(step)
count = count - (step or 1)
return count
end,
get = function()
return count
end
}
end
local counter = make_counter_with_step()
assert(counter.increment(5) == 5, "Counter should increment by 5")
assert(counter.decrement(2) == 3, "Counter should decrement by 2")
assert(counter.get() == 3, "Counter should maintain state")
assert(counter.increment() == 4, "Counter should default to 1")
-- Test multiple independent counters
local c1 = make_counter_with_step()
local c2 = make_counter_with_step()
c1.increment(10)
c2.increment(5)
assert(c1.get() == 10, "First counter should be independent")
assert(c2.get() == 5, "Second counter should be independent")
end
-- Test closures with shared upvalues -- Test closures with shared upvalues
local function test_shared_closures() local function test_shared_closures()
local function make_shared_counter() local function make_shared_counter()
@ -463,78 +525,9 @@ local function test_shared_closures()
assert(get() == 1, "Get should return current value") assert(get() == 1, "Get should return current value")
end end
-- Test varargs handling -- Run all tests
local function test_varargs()
-- Basic varargs sum function
local function sum(...)
local args = { ... }
local total = 0
for _, v in ipairs(args) do
total = total + v
end
return total
end
assert(sum(1, 2, 3, 4, 5) == 15, "Sum should handle multiple arguments")
assert(sum() == 0, "Sum should handle no arguments")
assert(sum(42) == 42, "Sum should handle single argument")
-- Test varargs propagation
local function pass_varargs(...)
return sum(...)
end
assert(pass_varargs(1, 2, 3) == 6, "Should propagate varargs")
assert(pass_varargs() == 0, "Should propagate empty varargs")
-- Test mixing regular args with varargs
local function first_plus_sum(first, ...)
local args = { ... }
local total = first or 0
for _, v in ipairs(args) do
total = total + v
end
return total
end
assert(first_plus_sum(10, 1, 2, 3) == 16, "Should handle mixed arguments")
assert(first_plus_sum(5) == 5, "Should handle only first argument")
end
-- Test closure edge cases
local function test_closure_edge_cases()
-- Test closure over loop variables
local closures = {}
for i = 1, 3 do
closures[i] = function() return i end
end
assert(closures[1]() == 1, "Should capture loop variable")
assert(closures[2]() == 2, "Should capture loop variable")
assert(closures[3]() == 3, "Should capture loop variable")
-- Test nested closure scopes
local function make_nested_counter(start)
local count = start
return function()
local function increment()
count = count + 1
return count
end
return increment()
end
end
local counter1 = make_nested_counter(5)
local counter2 = make_nested_counter(10)
assert(counter1() == 6, "First nested counter")
assert(counter1() == 7, "First nested counter increment")
assert(counter2() == 11, "Second nested counter independent")
end
-- Run the new tests
test_advanced_closures()
test_shared_closures()
test_varargs() test_varargs()
test_closure_edge_cases() test_closures()
print("All closure and varargs tests passed!") test_shared_closures()
test_advanced_closures()
print("All tests passed!")

View File

@ -1,6 +1,6 @@
import type { ASTCtx, LuaFunctionBody } from "./ast.ts"; import type { ASTCtx, LuaFunctionBody } from "./ast.ts";
import { evalStatement } from "$common/space_lua/eval.ts"; import { evalStatement } from "$common/space_lua/eval.ts";
import { asyncQuickSort } from "$common/space_lua/util.ts"; import { asyncQuickSort, evalPromiseValues } from "$common/space_lua/util.ts";
export type LuaType = export type LuaType =
| "nil" | "nil"
@ -141,26 +141,95 @@ export class LuaFunction implements ILuaFunction {
private capturedEnv: LuaEnv; private capturedEnv: LuaEnv;
constructor(readonly body: LuaFunctionBody, closure: LuaEnv) { constructor(readonly body: LuaFunctionBody, closure: LuaEnv) {
// Don't create a new environment, just store the reference to the closure environment
this.capturedEnv = closure; this.capturedEnv = closure;
} }
call(sf: LuaStackFrame, ...args: LuaValue[]): Promise<LuaValue> | LuaValue { call(sf: LuaStackFrame, ...args: LuaValue[]): Promise<LuaValue> {
// Create a new environment that chains to the captured environment // Create a new environment that chains to the captured environment
const env = new LuaEnv(this.capturedEnv); const env = new LuaEnv(this.capturedEnv);
if (!sf) { if (!sf) {
console.trace(sf); console.trace(sf);
} }
env.setLocal("_CTX", sf.threadLocal); env.setLocal("_CTX", sf.threadLocal);
// Assign the passed arguments to the parameters // Assign the passed arguments to the parameters
for (let i = 0; i < this.body.parameters.length; i++) { for (let i = 0; i < this.body.parameters.length; i++) {
const paramName = this.body.parameters[i];
if (paramName === "...") {
// Handle varargs by creating a table with all remaining arguments
const varargs = new LuaTable();
// Include all remaining arguments (might be none)
for (let j = i; j < args.length; j++) {
varargs.set(j - i + 1, args[j], sf);
}
env.setLocal("...", varargs);
break;
}
let arg = args[i]; let arg = args[i];
if (arg === undefined) { if (arg === undefined) {
arg = null; arg = null;
} }
env.setLocal(this.body.parameters[i], arg); env.setLocal(this.body.parameters[i], arg);
} }
return evalStatement(this.body.block, env, sf).catch((e: any) => {
// If the function has varargs parameter but it wasn't set above, set an empty varargs table
if (this.body.parameters.includes("...") && !env.has("...")) {
env.setLocal("...", new LuaTable());
}
const resolvedArgs = evalPromiseValues(args);
if (resolvedArgs instanceof Promise) {
return resolvedArgs.then((args) => this.callWithArgs(args, env, sf));
}
return this.callWithArgs(resolvedArgs, env, sf);
}
toString(): string {
return `<lua function(${this.body.parameters.join(", ")})>`;
}
private callWithArgs(
args: LuaValue[],
env: LuaEnv,
sf: LuaStackFrame,
): Promise<LuaValue> {
// Set up parameters and varargs
for (let i = 0; i < this.body.parameters.length; i++) {
const paramName = this.body.parameters[i];
if (paramName === "...") {
const varargs = new LuaTable();
for (let j = i; j < args.length; j++) {
if (args[j] instanceof Promise) {
return Promise.all(args.slice(i)).then((resolvedArgs) => {
const varargs = new LuaTable();
resolvedArgs.forEach((val, idx) => varargs.set(idx + 1, val, sf));
env.setLocal("...", varargs);
return this.evalBody(env, sf);
});
}
varargs.set(j - i + 1, args[j], sf);
}
env.setLocal("...", varargs);
break;
}
env.setLocal(paramName, args[i] ?? null);
}
// Ensure empty varargs table exists if needed
if (this.body.parameters.includes("...") && !env.has("...")) {
env.setLocal("...", new LuaTable());
}
return this.evalBody(env, sf);
}
private async evalBody(
env: LuaEnv,
sf: LuaStackFrame,
): Promise<LuaValue> {
try {
await evalStatement(this.body.block, env, sf);
} catch (e: any) {
if (e instanceof LuaReturn) { if (e instanceof LuaReturn) {
if (e.values.length === 0) { if (e.values.length === 0) {
return; return;
@ -172,11 +241,7 @@ export class LuaFunction implements ILuaFunction {
} else { } else {
throw e; throw e;
} }
}); }
}
toString(): string {
return `<lua function(${this.body.parameters.join(", ")})>`;
} }
} }