Skip to content

Commit 54b4ba1

Browse files
committed
support union generation
1 parent 4f69155 commit 54b4ba1

File tree

3 files changed

+53
-8
lines changed

3 files changed

+53
-8
lines changed

rivetc/src/codegen/__init__.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3024,12 +3024,20 @@ def gen_types(self):
30243024
# TODO: in the self-hosted compiler calculate the enum value here
30253025
# not in register nor resolver.
30263026
if ts.info.is_tagged:
3027+
mangled_name = cg_utils.mangle_symbol(ts)
3028+
fields = []
3029+
for v in ts.info.variants:
3030+
if v.has_typ:
3031+
typ_sym = v.typ.symbol()
3032+
fields.append(ir.Field(f"v{typ_sym.id}", self.ir_type(v.typ)))
3033+
union_name = mangled_name + "5Union"
3034+
self.out_rir.unions.append(ir.Union(union_name, fields))
30273035
self.out_rir.structs.append(
30283036
ir.Struct(
3029-
False, cg_utils.mangle_symbol(ts), [
3037+
False, mangled_name, [
30303038
ir.Field("_rc_", ir.UINT_T),
30313039
ir.Field("_idx_", ir.UINT_T),
3032-
ir.Field("obj", ir.VOID_PTR_T)
3040+
ir.Field("obj", ir.Type(union_name))
30333041
]
30343042
)
30353043
)

rivetc/src/codegen/c.py

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -29,12 +29,15 @@ class CGen:
2929
def __init__(self, comp):
3030
self.comp = comp
3131
self.typedefs = utils.Builder()
32+
self.unions = utils.Builder()
3233
self.structs = utils.Builder()
3334
self.protos = utils.Builder()
3435
self.globals = utils.Builder()
3536
self.out = utils.Builder()
3637

3738
def gen(self, out_rir):
39+
self.comp.vlog("cgen: generating unions...")
40+
self.gen_unions(out_rir.unions)
3841
self.comp.vlog("cgen: generating structs...")
3942
self.gen_structs(out_rir.structs)
4043
self.comp.vlog("cgen: generating externs...")
@@ -51,6 +54,7 @@ def gen(self, out_rir):
5154
if self.comp.prefs.build_mode != prefs.BuildMode.Release:
5255
out.write(c_headers.RIVET_BREAKPOINT)
5356
out.write(str(self.typedefs).strip() + "\n\n")
57+
out.write(str(self.unions).strip() + "\n\n")
5458
out.write(str(self.structs).strip() + "\n\n")
5559
out.write(str(self.protos).strip() + "\n\n")
5660
out.write(str(self.globals).strip() + "\n\n")
@@ -96,6 +100,19 @@ def write(self, txt):
96100
def writeln(self, txt = ""):
97101
self.out.writeln(txt)
98102

103+
def gen_unions(self, unions):
104+
for u in unions:
105+
self.typedefs.writeln(f"typedef union {u.name} {u.name};")
106+
self.unions.writeln(f"union {u.name} {{")
107+
for i, f in enumerate(u.fields):
108+
f_name = c_escape(f.name)
109+
self.unions.write(" ")
110+
self.unions.write(self.gen_type(f.typ, f_name))
111+
if not isinstance(f.typ, (ir.Array, ir.Function)):
112+
self.unions.write(f" {f_name}")
113+
self.unions.writeln(";")
114+
self.unions.writeln("};\n")
115+
99116
def gen_structs(self, structs):
100117
for s in structs:
101118
self.typedefs.writeln(f"typedef struct {s.name} {s.name};")
@@ -107,10 +124,7 @@ def gen_structs(self, structs):
107124
self.structs.write(self.gen_type(f.typ, f_name))
108125
if not isinstance(f.typ, (ir.Array, ir.Function)):
109126
self.structs.write(f" {f_name}")
110-
if i < len(s.fields) - 1:
111-
self.structs.writeln(";")
112-
else:
113-
self.structs.writeln(";")
127+
self.structs.writeln(";")
114128
self.structs.writeln("};")
115129
self.structs.writeln()
116130

rivetc/src/codegen/ir.py

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,7 @@ def __eq__(self, other):
128128
class RIRFile:
129129
def __init__(self, mod_name):
130130
self.mod_name = mod_name
131+
self.unions = []
131132
self.structs = []
132133
self.externs = []
133134
self.globals = []
@@ -146,6 +147,11 @@ def __repr__(self):
146147
"// and is subject to change without notice. Knock yourself out."
147148
)
148149
sb.writeln()
150+
for i, u in enumerate(self.unions):
151+
sb.writeln(str(u))
152+
if i < len(self.unions) - 1:
153+
sb.writeln()
154+
sb.writeln()
149155
for i, s in enumerate(self.structs):
150156
sb.writeln(str(s))
151157
if i < len(self.structs) - 1:
@@ -189,6 +195,23 @@ def __str__(self):
189195
sb.write("}")
190196
return str(sb)
191197

198+
class Union:
199+
def __init__(self, name, fields):
200+
self.name = name
201+
self.fields = fields
202+
203+
def __str__(self):
204+
sb = utils.Builder()
205+
sb.writeln(f'union {self.name} {{')
206+
for i, f in enumerate(self.fields):
207+
sb.write(f' {f.name}: {f.typ}')
208+
if i < len(self.fields) - 1:
209+
sb.writeln(",")
210+
else:
211+
sb.writeln()
212+
sb.write("}")
213+
return str(sb)
214+
192215
class Struct:
193216
def __init__(self, is_opaque, name, fields):
194217
self.is_opaque = is_opaque
@@ -198,9 +221,9 @@ def __init__(self, is_opaque, name, fields):
198221
def __str__(self):
199222
sb = utils.Builder()
200223
if self.is_opaque:
201-
sb.write(f'type {self.name} opaque')
224+
sb.write(f'struct {self.name} opaque')
202225
else:
203-
sb.writeln(f'type {self.name} {{')
226+
sb.writeln(f'struct {self.name} {{')
204227
for i, f in enumerate(self.fields):
205228
sb.write(f' {f.name}: {f.typ}')
206229
if i < len(self.fields) - 1:

0 commit comments

Comments
 (0)