From 076c296dfd78c448c74b7181fcdddbf4abdb6934 Mon Sep 17 00:00:00 2001 From: Zef Hemel Date: Wed, 8 Jan 2025 11:12:26 +0100 Subject: [PATCH] Lua iterator fixes --- common/space_lua/language_test.lua | 114 +++++++++++++++++++++++++++++ common/space_lua/stdlib/table.ts | 10 ++- 2 files changed, 121 insertions(+), 3 deletions(-) diff --git a/common/space_lua/language_test.lua b/common/space_lua/language_test.lua index c4c6bf1f..21554ee2 100644 --- a/common/space_lua/language_test.lua +++ b/common/space_lua/language_test.lua @@ -531,3 +531,117 @@ test_closures() test_shared_closures() test_advanced_closures() print("All tests passed!") + +-- Test custom iterators +local function test_custom_iterators() + -- Basic iterator that counts down from n to 1 + local function countdown(n) + local count = n + return function() + if count > 0 then + local current = count + count = count - 1 + return current + end + end + end + + -- Test basic iterator usage + local sum = 0 + for num in countdown(3) do + sum = sum + num + end + assert(sum == 6, "Countdown iterator should sum to 6 (3+2+1)") + + -- Iterator that returns even numbers from an array + local function even_values(arr) + local index = 0 + return function() + repeat + index = index + 1 + if index > #arr then return nil end + if arr[index] % 2 == 0 then + return index, arr[index] + end + until false + end + end + + -- Test array iterator + local arr = { 1, 2, 3, 4, 6, 7, 8 } + local count = 0 + local sum = 0 + for i, v in even_values(arr) do + count = count + 1 + sum = sum + v + end + assert(count == 4, "Should find 4 even numbers") + assert(sum == 20, "Sum of even numbers should be 20 (2+4+6+8)") + + -- Range iterator with step + local function range(from, to, step) + step = step or 1 + local current = from + return function() + if current > to then + return nil + end + local value = current + current = current + step + return value + end + end + + -- Test range iterator with different steps + local function collect_range(from, to, step) + local values = {} + for v in range(from, to, step) do + table.insert(values, v) + end + return values + end + + local values1 = collect_range(1, 5, 2) + assert(#values1 == 3, "Range with step 2 should return 3 values") + assert(values1[1] == 1 and values1[2] == 3 and values1[3] == 5, "Range values with step 2 should be correct") + + local values2 = collect_range(10, 15) + assert(#values2 == 6, "Range with default step should return 6 values") + assert(values2[1] == 10 and values2[6] == 15, "Range values with default step should be correct") + + local values3 = collect_range(1, 10, 3) + assert(#values3 == 4, "Range with step 3 should return 4 values") + assert(values3[1] == 1 and values3[2] == 4 and values3[3] == 7 and values3[4] == 10, + "Range values with step 3 should be correct") + + -- Test nested iterators + local function grid(rows, cols) + local row = 0 + return function() + row = row + 1 + if row <= rows then + local col = 0 + return function() + col = col + 1 + if col <= cols then + return row, col + end + end + end + end + end + + local points = {} + for row_iter in grid(2, 3) do + for r, c in row_iter do + table.insert(points, { r, c }) + end + end + + assert(#points == 6, "Grid should generate 6 points") + assert(points[1][1] == 1 and points[1][2] == 1, "First point should be (1,1)") + assert(points[6][1] == 2 and points[6][2] == 3, "Last point should be (2,3)") +end + +test_custom_iterators() +print("All iterator tests passed!") diff --git a/common/space_lua/stdlib/table.ts b/common/space_lua/stdlib/table.ts index 72c03c49..4b1e6002 100644 --- a/common/space_lua/stdlib/table.ts +++ b/common/space_lua/stdlib/table.ts @@ -20,10 +20,14 @@ export const tableApi = new LuaTable({ insert: new LuaBuiltinFunction( (_sf, tbl: LuaTable, posOrValue: number | any, value?: any) => { if (value === undefined) { - value = posOrValue; - posOrValue = tbl.length + 1; + let pos = 1; + while (tbl.get(pos) !== null) { + pos++; + } + tbl.set(pos, posOrValue); + } else { + tbl.insert(posOrValue, value); } - tbl.insert(posOrValue, value); }, ), remove: new LuaBuiltinFunction((_sf, tbl: LuaTable, pos?: number) => {