From ec6db8c766139b49c280fb3a31eda4add39a266d Mon Sep 17 00:00:00 2001 From: Zef Hemel Date: Wed, 8 Jan 2025 11:04:33 +0100 Subject: [PATCH] Lua vararg testing --- common/space_lua/eval.ts | 197 ++++++++++++++++++----------- common/space_lua/language_test.lua | 141 ++++++++++----------- common/space_lua/runtime.ts | 83 ++++++++++-- 3 files changed, 262 insertions(+), 159 deletions(-) diff --git a/common/space_lua/eval.ts b/common/space_lua/eval.ts index ca3d703b..13f63b27 100644 --- a/common/space_lua/eval.ts +++ b/common/space_lua/eval.ts @@ -114,6 +114,31 @@ export function evalExpression( return evalPrefixExpression(e, env, sf); case "TableConstructor": { const table = new LuaTable(); + if ( + e.fields.length === 1 && + e.fields[0].type === "ExpressionField" && + e.fields[0].value.type === "Variable" && + e.fields[0].value.name === "..." + ) { + const varargs = env.get("..."); + if (varargs instanceof Promise) { + return varargs.then((resolvedVarargs) => { + if (resolvedVarargs instanceof LuaTable) { + const newTable = new LuaTable(); + for (let i = 1; i <= resolvedVarargs.length; i++) { + newTable.set(i, resolvedVarargs.get(i), sf); + } + return newTable; + } + return table; + }); + } else if (varargs instanceof LuaTable) { + for (let i = 1; i <= varargs.length; i++) { + table.set(i, varargs.get(i), sf); + } + } + return table; + } const promises: Promise[] = []; for (const field of e.fields) { switch (field.type) { @@ -162,21 +187,49 @@ export function evalExpression( case "ExpressionField": { const value = evalExpression(field.value, env, sf); if (value instanceof Promise) { - promises.push(value.then((value) => { - // +1 because Lua tables are 1-indexed - table.set( - table.length + 1, - singleResult(value), - sf, - ); + promises.push(value.then(async (value) => { + if ( + field.value.type === "Variable" && + field.value.name === "..." + ) { + // Special handling for {...} + const varargs = await Promise.resolve(env.get("...")); + if (varargs instanceof LuaTable) { + // Copy all values from varargs table + for (let i = 1; i <= varargs.length; i++) { + const val = await Promise.resolve(varargs.get(i)); + table.set(i, val, sf); + } + } + } else { + // Normal case + table.set(table.length + 1, singleResult(value), sf); + } })); } else { - // +1 because Lua tables are 1-indexed - table.set( - table.length + 1, - singleResult(value), - sf, - ); + if ( + field.value.type === "Variable" && field.value.name === "..." + ) { + // Special handling for {...} + const varargs = env.get("..."); + if (varargs instanceof LuaTable) { + for (let i = 1; i <= varargs.length; i++) { + const val = varargs.get(i); + if (val instanceof Promise) { + promises.push( + Promise.resolve(val).then((val) => { + table.set(i, val, sf); + }), + ); + } else { + table.set(i, val, sf); + } + } + } + } else { + // Normal case + table.set(table.length + 1, singleResult(value), sf); + } } break; } @@ -258,73 +311,65 @@ function evalPrefixExpression( sf.withCtx(e.prefix.ctx), ); } - if (prefixValue instanceof Promise) { - return prefixValue.then((prefixValue) => { - if (!prefixValue) { - throw new LuaRuntimeError( - `Attempting to call a nil value`, - sf.withCtx(e.prefix.ctx), - ); + + // Special handling for f(...) - propagate varargs + if ( + e.args.length === 1 && e.args[0].type === "Variable" && + e.args[0].name === "..." + ) { + const varargs = env.get("..."); + const resolveVarargs = async () => { + const resolvedVarargs = await Promise.resolve(varargs); + if (resolvedVarargs instanceof LuaTable) { + const args = []; + for (let i = 1; i <= resolvedVarargs.length; i++) { + const val = await Promise.resolve(resolvedVarargs.get(i)); + args.push(val); + } + return args; } - let selfArgs: LuaValue[] = []; - // Handling a:b() syntax (b is kept in .name) - if (e.name && !prefixValue.get) { - throw new LuaRuntimeError( - `Attempting to index a non-table: ${prefixValue}`, - sf.withCtx(e.prefix.ctx), - ); - } else if (e.name) { - // Two things need to happen: the actual function be called needs to be looked up in the table, and the table itself needs to be passed as the first argument - selfArgs = [prefixValue]; - prefixValue = prefixValue.get(e.name); - } - if (!prefixValue.call) { - throw new LuaRuntimeError( - `Attempting to call ${prefixValue} as a function`, - sf.withCtx(e.prefix.ctx), - ); - } - const args = evalPromiseValues( - e.args.map((arg) => evalExpression(arg, env, sf)), - ); - if (args instanceof Promise) { - return args.then((args) => - luaCall(prefixValue, [...selfArgs, ...args], e.ctx, sf) - ); - } else { - return luaCall(prefixValue, [...selfArgs, ...args], e.ctx, sf); - } - }); - } else { - let selfArgs: LuaValue[] = []; - // Handling a:b() syntax (b is kept in .name) - if (e.name && !prefixValue.get) { - throw new LuaRuntimeError( - `Attempting to index a non-table: ${prefixValue}`, - sf.withCtx(e.prefix.ctx), - ); - } else if (e.name) { - // Two things need to happen: the actual function be called needs to be looked up in the table, and the table itself needs to be passed as the first argument - selfArgs = [prefixValue]; - prefixValue = prefixValue.get(e.name); - } - if (!prefixValue.call) { - throw new LuaRuntimeError( - `Attempting to call ${prefixValue} as a function`, - sf.withCtx(e.prefix.ctx), - ); - } - const args = evalPromiseValues( - e.args.map((arg) => evalExpression(arg, env, sf)), - ); - if (args instanceof Promise) { - return args.then((args) => - luaCall(prefixValue, [...selfArgs, ...args], e.ctx, sf) - ); + return []; + }; + + if (prefixValue instanceof Promise) { + return prefixValue.then(async (resolvedPrefix) => { + const args = await resolveVarargs(); + return luaCall(resolvedPrefix, args, e.ctx, sf); + }); } else { - return luaCall(prefixValue, [...selfArgs, ...args], e.ctx, sf); + return resolveVarargs().then((args) => + luaCall(prefixValue, args, e.ctx, sf) + ); } } + + // Normal argument handling + let selfArgs: LuaValue[] = []; + if (e.name && !prefixValue.get) { + throw new LuaRuntimeError( + `Attempting to index a non-table: ${prefixValue}`, + sf.withCtx(e.prefix.ctx), + ); + } else if (e.name) { + selfArgs = [prefixValue]; + prefixValue = prefixValue.get(e.name); + } + if (!prefixValue.call) { + throw new LuaRuntimeError( + `Attempting to call ${prefixValue} as a function`, + sf.withCtx(e.prefix.ctx), + ); + } + const args = evalPromiseValues( + e.args.map((arg) => evalExpression(arg, env, sf)), + ); + if (args instanceof Promise) { + return args.then((args) => + luaCall(prefixValue, [...selfArgs, ...args], e.ctx, sf) + ); + } else { + return luaCall(prefixValue, [...selfArgs, ...args], e.ctx, sf); + } } default: throw new Error(`Unknown prefix expression type ${e.type}`); diff --git a/common/space_lua/language_test.lua b/common/space_lua/language_test.lua index 836ff012..c4c6bf1f 100644 --- a/common/space_lua/language_test.lua +++ b/common/space_lua/language_test.lua @@ -438,6 +438,68 @@ local function test_advanced_closures() assert(c2.get() == 5, "Second counter should be independent") end +-- Test varargs handling +local function test_varargs() + -- Basic varargs sum function + local function sum(...) + local total = 0 + for i, v in ipairs({ ... }) do + total = total + v + end + return total + end + + assert(sum() == 0, "Sum should handle no arguments") + assert(sum(42) == 42, "Sum should handle single argument") + assert(sum(1, 2) == 3, "Sum should handle two arguments") + + -- Test varargs propagation + local function pass_varargs(...) + return sum(...) + end + + assert(pass_varargs() == 0, "Should propagate empty varargs") + assert(pass_varargs(1, 2, 3) == 6, "Should propagate varargs") +end + +test_varargs() +print("All varargs tests passed!") + +-- Test closure behavior +local function test_closures() + -- Counter that can count by custom steps + local function make_counter_with_step() + local count = 0 + return { + increment = function(step) + count = count + (step or 1) + return count + end, + decrement = function(step) + count = count - (step or 1) + return count + end, + get = function() + return count + end + } + end + + local counter = make_counter_with_step() + assert(counter.increment(5) == 5, "Counter should increment by 5") + assert(counter.decrement(2) == 3, "Counter should decrement by 2") + assert(counter.get() == 3, "Counter should maintain state") + assert(counter.increment() == 4, "Counter should default to 1") + + -- Test multiple independent counters + local c1 = make_counter_with_step() + local c2 = make_counter_with_step() + c1.increment(10) + c2.increment(5) + assert(c1.get() == 10, "First counter should be independent") + assert(c2.get() == 5, "Second counter should be independent") +end + -- Test closures with shared upvalues local function test_shared_closures() local function make_shared_counter() @@ -463,78 +525,9 @@ local function test_shared_closures() assert(get() == 1, "Get should return current value") end --- Test varargs handling -local function test_varargs() - -- Basic varargs sum function - local function sum(...) - local args = { ... } - local total = 0 - for _, v in ipairs(args) do - total = total + v - end - return total - end - - assert(sum(1, 2, 3, 4, 5) == 15, "Sum should handle multiple arguments") - assert(sum() == 0, "Sum should handle no arguments") - assert(sum(42) == 42, "Sum should handle single argument") - - -- Test varargs propagation - local function pass_varargs(...) - return sum(...) - end - - assert(pass_varargs(1, 2, 3) == 6, "Should propagate varargs") - assert(pass_varargs() == 0, "Should propagate empty varargs") - - -- Test mixing regular args with varargs - local function first_plus_sum(first, ...) - local args = { ... } - local total = first or 0 - for _, v in ipairs(args) do - total = total + v - end - return total - end - - assert(first_plus_sum(10, 1, 2, 3) == 16, "Should handle mixed arguments") - assert(first_plus_sum(5) == 5, "Should handle only first argument") -end - --- Test closure edge cases -local function test_closure_edge_cases() - -- Test closure over loop variables - local closures = {} - for i = 1, 3 do - closures[i] = function() return i end - end - - assert(closures[1]() == 1, "Should capture loop variable") - assert(closures[2]() == 2, "Should capture loop variable") - assert(closures[3]() == 3, "Should capture loop variable") - - -- Test nested closure scopes - local function make_nested_counter(start) - local count = start - return function() - local function increment() - count = count + 1 - return count - end - return increment() - end - end - - local counter1 = make_nested_counter(5) - local counter2 = make_nested_counter(10) - assert(counter1() == 6, "First nested counter") - assert(counter1() == 7, "First nested counter increment") - assert(counter2() == 11, "Second nested counter independent") -end - --- Run the new tests -test_advanced_closures() -test_shared_closures() +-- Run all tests test_varargs() -test_closure_edge_cases() -print("All closure and varargs tests passed!") +test_closures() +test_shared_closures() +test_advanced_closures() +print("All tests passed!") diff --git a/common/space_lua/runtime.ts b/common/space_lua/runtime.ts index 67113587..fabc0b92 100644 --- a/common/space_lua/runtime.ts +++ b/common/space_lua/runtime.ts @@ -1,6 +1,6 @@ import type { ASTCtx, LuaFunctionBody } from "./ast.ts"; import { evalStatement } from "$common/space_lua/eval.ts"; -import { asyncQuickSort } from "$common/space_lua/util.ts"; +import { asyncQuickSort, evalPromiseValues } from "$common/space_lua/util.ts"; export type LuaType = | "nil" @@ -141,26 +141,95 @@ export class LuaFunction implements ILuaFunction { private capturedEnv: LuaEnv; constructor(readonly body: LuaFunctionBody, closure: LuaEnv) { - // Don't create a new environment, just store the reference to the closure environment this.capturedEnv = closure; } - call(sf: LuaStackFrame, ...args: LuaValue[]): Promise | LuaValue { + call(sf: LuaStackFrame, ...args: LuaValue[]): Promise { // Create a new environment that chains to the captured environment const env = new LuaEnv(this.capturedEnv); if (!sf) { console.trace(sf); } env.setLocal("_CTX", sf.threadLocal); + // Assign the passed arguments to the parameters for (let i = 0; i < this.body.parameters.length; i++) { + const paramName = this.body.parameters[i]; + if (paramName === "...") { + // Handle varargs by creating a table with all remaining arguments + const varargs = new LuaTable(); + // Include all remaining arguments (might be none) + for (let j = i; j < args.length; j++) { + varargs.set(j - i + 1, args[j], sf); + } + env.setLocal("...", varargs); + break; + } let arg = args[i]; if (arg === undefined) { arg = null; } env.setLocal(this.body.parameters[i], arg); } - return evalStatement(this.body.block, env, sf).catch((e: any) => { + + // If the function has varargs parameter but it wasn't set above, set an empty varargs table + if (this.body.parameters.includes("...") && !env.has("...")) { + env.setLocal("...", new LuaTable()); + } + + const resolvedArgs = evalPromiseValues(args); + if (resolvedArgs instanceof Promise) { + return resolvedArgs.then((args) => this.callWithArgs(args, env, sf)); + } + return this.callWithArgs(resolvedArgs, env, sf); + } + + toString(): string { + return ``; + } + + private callWithArgs( + args: LuaValue[], + env: LuaEnv, + sf: LuaStackFrame, + ): Promise { + // Set up parameters and varargs + for (let i = 0; i < this.body.parameters.length; i++) { + const paramName = this.body.parameters[i]; + if (paramName === "...") { + const varargs = new LuaTable(); + for (let j = i; j < args.length; j++) { + if (args[j] instanceof Promise) { + return Promise.all(args.slice(i)).then((resolvedArgs) => { + const varargs = new LuaTable(); + resolvedArgs.forEach((val, idx) => varargs.set(idx + 1, val, sf)); + env.setLocal("...", varargs); + return this.evalBody(env, sf); + }); + } + varargs.set(j - i + 1, args[j], sf); + } + env.setLocal("...", varargs); + break; + } + env.setLocal(paramName, args[i] ?? null); + } + + // Ensure empty varargs table exists if needed + if (this.body.parameters.includes("...") && !env.has("...")) { + env.setLocal("...", new LuaTable()); + } + + return this.evalBody(env, sf); + } + + private async evalBody( + env: LuaEnv, + sf: LuaStackFrame, + ): Promise { + try { + await evalStatement(this.body.block, env, sf); + } catch (e: any) { if (e instanceof LuaReturn) { if (e.values.length === 0) { return; @@ -172,11 +241,7 @@ export class LuaFunction implements ILuaFunction { } else { throw e; } - }); - } - - toString(): string { - return ``; + } } }