diff --git a/rivetc/src/codegen/__init__.py b/rivetc/src/codegen/__init__.py index f4cf68dae..f651250a3 100644 --- a/rivetc/src/codegen/__init__.py +++ b/rivetc/src/codegen/__init__.py @@ -1877,17 +1877,25 @@ def gen_expr(self, expr, custom_tmp = None): var_t2 = var_t if isinstance( var_t, ir.Pointer ) or expr.var.is_mut else var_t.ptr() - val = ir.Inst( - ir.InstKind.Cast, [ - ir.Selector(ir.VOID_PTR_T, left, ir.Name("obj")), - var_t2 - ] - ) - if not ( - (isinstance(var_t2, ir.Pointer) and var_t2.is_managed) - or expr.var.is_mut - ): - val = ir.Inst(ir.InstKind.LoadPtr, [val], var_t2) + if left_sym.kind == TypeKind.Enum: + union_name = f"{cg_utils.mangle_symbol(left_sym)}5Union" + union_type = ir.Type(union_name) + obj_val = ir.Selector(union_type, left, ir.Name("obj")) + val = ir.Selector( + ir.Type(self.ir_type(expr.typ)), obj_val, ir.Name(f"v{expr.var.typ.symbol().id}") + ) + else: + val = ir.Inst( + ir.InstKind.Cast, [ + ir.Selector(ir.VOID_PTR_T, left, ir.Name("obj")), + var_t2 + ] + ) + if not ( + (isinstance(var_t2, ir.Pointer) and var_t2.is_managed) + or expr.var.is_mut + ): + val = ir.Inst(ir.InstKind.LoadPtr, [val], var_t2) unique_name = self.cur_func.unique_name(expr.var.name) expr.scope.update_ir_name(expr.var.name, unique_name) self.cur_func.inline_alloca(var_t, unique_name, val) @@ -2227,14 +2235,14 @@ def gen_expr(self, expr, custom_tmp = None): var_t = self.ir_type(b.var_typ) var_t2 = var_t.ptr( ) if not isinstance(var_t, ir.Pointer) else var_t - if expr.expr.typ.symbol().kind == TypeKind.Enum: - val = ir.Inst( - ir.InstKind.Cast, [ - ir.Selector( - ir.VOID_PTR_T, match_expr, - ir.Name("obj") - ), var_t2 - ] + e_expr_typ_sym = expr.expr.typ.symbol() + if e_expr_typ_sym.kind == TypeKind.Enum: + obj_f = ir.Selector( + e_expr_typ_sym.name + "5Union", match_expr, ir.Name("obj") + ) + val = ir.Selector( + self.ir_type(p.variant_info.typ), + obj_f, ir.Name(f"v{p.variant_info.typ.symbol().id}") ) else: val = ir.Inst( @@ -2245,13 +2253,13 @@ def gen_expr(self, expr, custom_tmp = None): ), var_t ] ) - if not ( - b.var_is_mut or ( - isinstance(var_t, ir.Pointer) - and var_t.is_managed - ) - ): - val = ir.Inst(ir.InstKind.LoadPtr, [val]) + if not ( + b.var_is_mut or ( + isinstance(var_t, ir.Pointer) + and var_t.is_managed + ) + ): + val = ir.Inst(ir.InstKind.LoadPtr, [val]) if b.var_is_mut and not isinstance( var_t, ir.Pointer ): @@ -2809,21 +2817,14 @@ def tagged_enum_value( if variant_info.has_typ and value and not isinstance( value, ast.EmptyExpr ): + variant_typ_sym = variant_info.typ.symbol() arg0 = self.gen_expr_with_cast(variant_info.typ, value) size, _ = self.comp.type_size(variant_info.typ) - if isinstance(arg0.typ, ir.Pointer): - value = arg0 - else: - value = ir.Inst( - ir.InstKind.Call, [ - ir.Name("_R4core7mem_dupF"), - ir.Inst(ir.InstKind.GetRef, [arg0]), - ir.IntLit(uint_t, str(size)) - ] - ) - else: - value = ir.NoneLit(ir.VOID_PTR_T) - self.cur_func.store(ir.Selector(uint_t, tmp, ir.Name("obj")), value) + obj_f = ir.Selector(ir.Type(f"{cg_utils.mangle_symbol(enum_sym)}5Union"), tmp, ir.Name("obj")) + self.cur_func.store( + ir.Selector(self.ir_type(variant_info.typ), obj_f, ir.Name(f"v{variant_typ_sym.id}")), + arg0 + ) return tmp def tagged_enum_variant_with_fields_value( @@ -2836,11 +2837,15 @@ def tagged_enum_variant_with_fields_value( cg_utils.mangle_symbol(enum_sym), enum_sym.id ) variant_info = enum_sym.info.get_variant(variant_name) + variant_typ_sym = variant_info.typ.symbol() self.cur_func.store( ir.Selector(ir.UINT_T, tmp, ir.Name("_idx_")), ir.IntLit(ir.UINT_T, variant_info.value) ) - self.cur_func.store(ir.Selector(ir.UINT_T, tmp, ir.Name("obj")), value) + obj_f = ir.Selector(ir.Type(f"{cg_utils.mangle_symbol(enum_sym)}5Union"), tmp, ir.Name("obj")) + self.cur_func.store( + ir.Selector(self.ir_type(variant_info.typ), obj_f, ir.Name(f"v{variant_typ_sym.id}")), value + ) return tmp def gen_return_trace_add(self, pos):