Skip to content

Commit e159959

Browse files
committed
More fixes
1 parent 9ca8c0b commit e159959

File tree

8 files changed

+304
-63
lines changed

8 files changed

+304
-63
lines changed

src/canonicalize/CIR.zig

Lines changed: 190 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1801,28 +1801,62 @@ pub fn canonicalizeExpr(self: *CIR, allocator: Allocator, node_idx: AST.Node.Idx
18011801

18021802
// Lambda expressions from |x| body syntax
18031803
.lambda => {
1804+
// Collect free variables from outer scope before creating new scope
1805+
// These are potential captures - we'll filter them later
1806+
var outer_scope_vars = std.ArrayList(Ident.Idx).init(allocator);
1807+
defer outer_scope_vars.deinit();
1808+
1809+
// Collect all identifiers available in the current scope
1810+
// These could become captures if referenced in the lambda body
1811+
for (self.scope_state.scopes.items) |scope| {
1812+
var iter = scope.idents.iterator();
1813+
while (iter.next()) |entry| {
1814+
try outer_scope_vars.append(entry.key_ptr.*);
1815+
}
1816+
}
1817+
18041818
// Push a new scope for the lambda (this is a function boundary)
18051819
try self.scope_state.pushScope(allocator, true); // true = function boundary
18061820
defer self.popScopeAndCheckUnused(allocator) catch {};
18071821

18081822
// Parser creates lambda nodes with body_then_args payload
18091823
// Format: [body, param1, param2, ...]
18101824
const nodes_idx = node.payload.body_then_args;
1825+
var param_idents = std.ArrayList(Ident.Idx).init(allocator);
1826+
defer param_idents.deinit();
1827+
18111828
if (!nodes_idx.isNil()) {
18121829
var iter = self.ast.*.node_slices.nodes(&nodes_idx);
18131830

18141831
// First node is the body - save it for later
18151832
const body_node = iter.next();
18161833

18171834
// Process parameters first and add them to scope
1835+
// Also track parameter identifiers so we can exclude them from captures
18181836
while (iter.next()) |param_node| {
1819-
_ = try self.canonicalizePatt(allocator, param_node);
1820-
// Pattern canonicalization registers names in scope
1837+
const patt_idx = try self.canonicalizePatt(allocator, param_node);
1838+
// Collect parameter identifiers
1839+
try self.collectPatternIdents(allocator, patt_idx, &param_idents);
18211840
}
18221841

18231842
// Now process the body with parameters in scope
18241843
if (body_node) |body| {
18251844
_ = try self.canonicalizeExpr(allocator, body, raw_src, idents);
1845+
1846+
// After canonicalizing the body, analyze captures
1847+
// Captures are free variables: referenced in body, not parameters, from outer scope
1848+
var captures = std.ArrayList(Ident.Idx).init(allocator);
1849+
defer captures.deinit();
1850+
1851+
try self.collectFreeVariables(allocator, body, &captures, &param_idents, &outer_scope_vars);
1852+
1853+
// Store captures with the lambda
1854+
// For now, just track that we analyzed captures
1855+
// In a full implementation, we'd store these in the CIR
1856+
if (captures.items.len > 0) {
1857+
// Lambda has captures - this will become a closure at runtime
1858+
// The interpreter will need to capture these values when creating the closure
1859+
}
18261860
}
18271861
}
18281862

@@ -3358,6 +3392,160 @@ pub const Scope = struct {
33583392
}
33593393
};
33603394

3395+
/// Helper function to collect all identifiers from a pattern
3396+
fn collectPatternIdents(self: *CIR, allocator: Allocator, patt_idx: Patt.Idx, idents: *std.ArrayList(Ident.Idx)) !void {
3397+
// Get the AST node for this pattern
3398+
const node_idx = @as(AST.Node.Idx, @enumFromInt(@intFromEnum(patt_idx)));
3399+
const node = self.getNode(node_idx);
3400+
3401+
switch (node.tag) {
3402+
.ident, .var_ident => {
3403+
// Simple identifier pattern - add it to the list
3404+
if (node.payload == .ident) {
3405+
try idents.append(node.payload.ident);
3406+
}
3407+
},
3408+
.underscore => {
3409+
// Underscore pattern - no identifier to collect
3410+
},
3411+
.list, .tuple => {
3412+
// Recursively collect from nested patterns
3413+
const nodes_idx = node.payload.nodes;
3414+
if (!nodes_idx.isNil()) {
3415+
var iter = self.ast.*.node_slices.nodes(&nodes_idx);
3416+
while (iter.next()) |child_node| {
3417+
const child_patt = asPattIdx(child_node);
3418+
try self.collectPatternIdents(allocator, child_patt, idents);
3419+
}
3420+
}
3421+
},
3422+
.record => {
3423+
// Collect identifiers from record field patterns
3424+
const nodes_idx = node.payload.nodes;
3425+
if (!nodes_idx.isNil()) {
3426+
var iter = self.ast.*.node_slices.nodes(&nodes_idx);
3427+
while (iter.next()) |field_node| {
3428+
const field = self.getNode(field_node);
3429+
if (field.tag == .binop_colon and field.payload == .binop) {
3430+
// Field pattern: fieldName : pattern
3431+
const binop = self.ast.*.node_slices.binOp(field.payload.binop);
3432+
// Collect from the pattern part (right side)
3433+
const rhs_patt = asPattIdx(binop.rhs);
3434+
try self.collectPatternIdents(allocator, rhs_patt, idents);
3435+
}
3436+
}
3437+
}
3438+
},
3439+
else => {
3440+
// Other pattern types - may need to handle more cases
3441+
},
3442+
}
3443+
}
3444+
3445+
/// Helper function to collect free variables referenced in an expression
3446+
fn collectFreeVariables(
3447+
self: *CIR,
3448+
allocator: Allocator,
3449+
expr_node: AST.Node.Idx,
3450+
captures: *std.ArrayList(Ident.Idx),
3451+
param_idents: *const std.ArrayList(Ident.Idx),
3452+
outer_scope_vars: *const std.ArrayList(Ident.Idx),
3453+
) !void {
3454+
const node = self.getNode(expr_node);
3455+
3456+
switch (node.tag) {
3457+
.expr_lookup => {
3458+
// Variable reference - check if it's a capture
3459+
if (node.payload == .ident) {
3460+
const ident = node.payload.ident;
3461+
3462+
// Check if this is a parameter (not a capture)
3463+
var is_param = false;
3464+
for (param_idents.items) |param| {
3465+
if (@intFromEnum(param) == @intFromEnum(ident)) {
3466+
is_param = true;
3467+
break;
3468+
}
3469+
}
3470+
3471+
if (!is_param) {
3472+
// Check if it's from outer scope (potential capture)
3473+
for (outer_scope_vars.items) |outer_var| {
3474+
if (@intFromEnum(outer_var) == @intFromEnum(ident)) {
3475+
// This is a capture - add it if not already present
3476+
var already_captured = false;
3477+
for (captures.items) |cap| {
3478+
if (@intFromEnum(cap) == @intFromEnum(ident)) {
3479+
already_captured = true;
3480+
break;
3481+
}
3482+
}
3483+
if (!already_captured) {
3484+
try captures.append(ident);
3485+
}
3486+
break;
3487+
}
3488+
}
3489+
}
3490+
}
3491+
},
3492+
.expr_bin_op, .binop_plus, .binop_minus, .binop_star, .binop_slash, .binop_double_equals, .binop_not_equals, .binop_gt, .binop_gte, .binop_lt, .binop_lte, .binop_and, .binop_or => {
3493+
// Binary operations - check both sides
3494+
if (node.payload == .binop) {
3495+
const binop = self.ast.*.node_slices.binOp(node.payload.binop);
3496+
try self.collectFreeVariables(allocator, binop.lhs, captures, param_idents, outer_scope_vars);
3497+
try self.collectFreeVariables(allocator, binop.rhs, captures, param_idents, outer_scope_vars);
3498+
}
3499+
},
3500+
.expr_call, .expr_apply => {
3501+
// Function calls - check function and arguments
3502+
const nodes_idx = node.payload.nodes;
3503+
if (!nodes_idx.isNil()) {
3504+
var iter = self.ast.*.node_slices.nodes(&nodes_idx);
3505+
while (iter.next()) |child| {
3506+
try self.collectFreeVariables(allocator, child, captures, param_idents, outer_scope_vars);
3507+
}
3508+
}
3509+
},
3510+
.expr_if => {
3511+
// If expressions - check condition, then, and else branches
3512+
const nodes_idx = node.payload.nodes;
3513+
if (!nodes_idx.isNil()) {
3514+
var iter = self.ast.*.node_slices.nodes(&nodes_idx);
3515+
while (iter.next()) |child| {
3516+
try self.collectFreeVariables(allocator, child, captures, param_idents, outer_scope_vars);
3517+
}
3518+
}
3519+
},
3520+
.expr_lambda, .lambda => {
3521+
// Nested lambda - don't traverse into it
3522+
// It has its own capture analysis
3523+
},
3524+
.expr_list_literal, .expr_tuple_literal, .expr_record_literal => {
3525+
// Collection literals - check all elements
3526+
const nodes_idx = node.payload.nodes;
3527+
if (!nodes_idx.isNil()) {
3528+
var iter = self.ast.*.node_slices.nodes(&nodes_idx);
3529+
while (iter.next()) |child| {
3530+
try self.collectFreeVariables(allocator, child, captures, param_idents, outer_scope_vars);
3531+
}
3532+
}
3533+
},
3534+
else => {
3535+
// For other expression types, recursively check child nodes if any
3536+
if (node.payload == .nodes) {
3537+
const nodes_idx = node.payload.nodes;
3538+
if (!nodes_idx.isNil()) {
3539+
var iter = self.ast.*.node_slices.nodes(&nodes_idx);
3540+
while (iter.next()) |child| {
3541+
try self.collectFreeVariables(allocator, child, captures, param_idents, outer_scope_vars);
3542+
}
3543+
}
3544+
}
3545+
},
3546+
}
3547+
}
3548+
33613549
test "CIR2 canonicalize mutable variable declaration" {
33623550
const testing = std.testing;
33633551
const allocator = testing.allocator;

src/canonicalize/mod.zig

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@ pub const CIR = @import("CIR.zig");
88
pub const Can = CIR;
99
/// The Module Environment after canonicalization (used also for type checking and serialization)
1010
pub const ModuleEnv = @import("ModuleEnv.zig");
11+
/// The Scope type for managing identifier scopes
12+
pub const Scope = @import("Scope.zig");
1113

1214
test "compile tests" {
1315
std.testing.refAllDecls(@This());

src/check/test/let_polymorphism_integration_test.zig

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ const std = @import("std");
55
const base = @import("base");
66
const parse = @import("parse");
77
const types = @import("types");
8-
const canonicalize = @import("canonicalize");
8+
const canonicalize = @import("can");
99
const check = @import("../mod.zig");
1010

1111
const testing = std.testing;
@@ -27,7 +27,7 @@ fn parseAndCanonicalizeSource(env: *ModuleEnv, source: []const u8) !CIR {
2727
defer ast.deinit(test_allocator);
2828

2929
// Create CIR and canonicalize
30-
var cir = CIR.init(&env.byte_slices, &env.types, env.getIdents());
30+
const cir = CIR.init(&ast, &env.types);
3131
// Note: In real usage, canonicalization would happen here
3232
// For these tests, we're focusing on the type checking aspects
3333

@@ -43,7 +43,7 @@ test "let polymorphism - identity function" {
4343
}
4444

4545
// Create identity function: |x| x
46-
var cir = CIR.init(&env.byte_slices, &env.types, env.getIdents());
46+
const cir = CIR.init(&env.byte_slices, &env.types, env.getIdents());
4747
defer cir.deinit(test_allocator);
4848

4949
// Create type variables for the polymorphic identity function

src/compile/test/module_env_test.zig

Lines changed: 19 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,14 @@
44
const std = @import("std");
55
const base = @import("base");
66
const types = @import("types");
7-
const canonicalize = @import("canonicalize");
7+
const canonicalize = @import("can");
88
const parse = @import("parse");
99

1010
const testing = std.testing;
1111
const test_allocator = testing.allocator;
1212
const ModuleEnv = canonicalize.ModuleEnv;
1313
const CIR = canonicalize.CIR;
14+
const AST = parse.AST;
1415

1516
test "module env - create and destroy" {
1617
const env = try test_allocator.create(ModuleEnv);
@@ -28,7 +29,9 @@ test "module env - add expression with type" {
2829
defer env.deinit();
2930

3031
// Create a simple CIR with an integer literal
31-
var cir = CIR.init(&env.byte_slices, &env.types, env.getIdents());
32+
// Need to create an AST first
33+
var ast = AST{};
34+
var cir = CIR.init(&ast, &env.types);
3235
defer cir.deinit(test_allocator);
3336

3437
// Create an integer literal expression
@@ -99,7 +102,9 @@ test "module env - type variable association" {
99102
defer env.deinit();
100103

101104
// Create a CIR
102-
var cir = CIR.init(&env.byte_slices, &env.types, env.getIdents());
105+
// Need to create an AST first
106+
var ast = AST{};
107+
var cir = CIR.init(&ast, &env.types);
103108
defer cir.deinit(test_allocator);
104109

105110
// Create various expressions and verify they get type variables
@@ -125,7 +130,9 @@ test "module env - pattern management" {
125130
defer env.deinit();
126131

127132
// Create a CIR
128-
var cir = CIR.init(&env.byte_slices, &env.types, env.getIdents());
133+
// Need to create an AST first
134+
var ast = AST{};
135+
var cir = CIR.init(&ast, &env.types);
129136
defer cir.deinit(test_allocator);
130137

131138
// Create a pattern (identifier pattern)
@@ -147,11 +154,14 @@ test "module env - scope management" {
147154
defer env.deinit();
148155

149156
// Create a CIR with scope tracking
150-
var cir = CIR.init(&env.byte_slices, &env.types, env.getIdents());
157+
// Need to create an AST first
158+
var ast = AST{};
159+
var cir = CIR.init(&ast, &env.types);
151160
defer cir.deinit(test_allocator);
152161

153162
// Create a scope
154-
const Scope = @import("canonicalize/Scope.zig");
163+
const can = @import("can");
164+
const Scope = can.Scope;
155165
var scope = Scope.init(false);
156166
defer scope.deinit(test_allocator);
157167

@@ -178,7 +188,9 @@ test "module env - binary operation creation" {
178188
defer env.deinit();
179189

180190
// Create a CIR
181-
var cir = CIR.init(&env.byte_slices, &env.types, env.getIdents());
191+
// Need to create an AST first
192+
var ast = AST{};
193+
var cir = CIR.init(&ast, &env.types);
182194
defer cir.deinit(test_allocator);
183195

184196
// Create operands

0 commit comments

Comments
 (0)