Skip to content

Commit

Permalink
refact(rivetc.codegen): make sure to always cast numeric values ​​wit…
Browse files Browse the repository at this point in the history
…h promoted types
  • Loading branch information
StunxFS committed Oct 29, 2023
1 parent 19aedc1 commit 74173df
Show file tree
Hide file tree
Showing 5 changed files with 34 additions and 28 deletions.
1 change: 1 addition & 0 deletions lib/c/src/libc/string.ri
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,5 @@ extern (C) {

public func strerror(code: int32) -> ?[&]mut uint8;
public func strlen(cs: ?[&]uint8) -> usize;
public func strstr(needle: ?[&]uint8, haystack: ?[&]uint8) -> ?[&]uint8;
}
48 changes: 24 additions & 24 deletions lib/rivet/src/checker/types.ri
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ extend Checker {
expected_is_ptr := expected.is_pointer();
got_is_ptr := got.is_pointer();
if expected_is_ptr and got_is_ptr {
return self.check_pointer(expected, got);
return check_pointer(expected, got);
} else if (expected_is_ptr and !got_is_ptr) or (got_is_ptr and !expected_is_ptr) {
return false;
}
Expand Down Expand Up @@ -193,29 +193,6 @@ extend Checker {
return expected == got;
}

func check_pointer(self, expected: ast.Type, got: ast.Type) -> bool {
_ = self;
if expected is .Anyptr as anyptr_ {
if got is .Pointer as ptr {
if anyptr_.is_mut and !ptr.is_mut {
return false;
}
// anyptr == *T, is valid
return true;
}
return got is .Anyptr;
} else if expected is .Pointer as ptr and got is .Pointer as ptr2 {
if ptr.is_mut and !ptr2.is_mut {
return false;
}
if ptr.is_indexable and !ptr2.is_indexable {
return false;
}
return ptr.inner == ptr2.inner;
}
return false;
}

func promote(self, left_type: ast.Type, right_type: ast.Type) -> ast.Type {
if left_type == right_type {
return left_type;
Expand Down Expand Up @@ -275,3 +252,26 @@ extend Checker {
};
}
}

func check_pointer(expected: ast.Type, got: ast.Type) -> bool {
_ = self;
if expected is .Anyptr as anyptr_ {
if got is .Pointer as ptr {
if anyptr_.is_mut and !ptr.is_mut {
return false;
}
// anyptr == &T, is valid
return true;
}
return got is .Anyptr;
} else if expected is .Pointer as ptr and got is .Pointer as ptr2 {
if ptr.is_mut and !ptr2.is_mut {
return false;
}
if ptr.is_indexable and !ptr2.is_indexable {
return false;
}
return ptr.inner == ptr2.inner;
}
return false;
}
6 changes: 3 additions & 3 deletions rivetc/src/codegen/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -647,10 +647,10 @@ def gen_expr_with_cast(self, expected_typ_, expr, custom_tmp = None):
elif isinstance(res_expr,
ir.FloatLit) and self.comp.is_float(expected_typ_):
res_expr.typ = expected_typ
elif self.comp.is_comptime_number(
elif self.comp.is_number(
expr.typ
) and self.comp.is_number(expected_typ_):
res_expr.typ = expected_typ_
) and self.comp.is_number(expected_typ_) and expr.typ != expected_typ_:
res_expr = ir.Inst(ir.InstKind.Cast, [res_expr, expected_typ])

if isinstance(
res_expr.typ, ir.Pointer
Expand Down
5 changes: 5 additions & 0 deletions rivetc/src/codegen/c.py
Original file line number Diff line number Diff line change
Expand Up @@ -381,12 +381,17 @@ def gen_expr(self, expr):
# overflows `int64`, hence the consecutive subtraction by `1`.
self.write("(-9223372036854775807L - 1)")
else:
self.write("((")
self.write_type(expr.typ)
self.write(")")
self.write("(")
self.write(expr.lit)
if expr.typ.name.endswith("64"
) or expr.typ.name.endswith("size"):
if expr.typ.name.startswith("u"):
self.write("U")
self.write("L")
self.write("))")
elif isinstance(expr, ir.FloatLit):
self.write(expr.lit)
if str(expr.typ) == "float32":
Expand Down
2 changes: 1 addition & 1 deletion rivetc/src/codegen/ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -554,7 +554,7 @@ def __repr__(self):
elif self == InstKind.LoadPtr: return "load_ptr"
elif self == InstKind.GetElementPtr: return "get_element_ptr"
elif self == InstKind.GetRef: return "get_ref"
elif self == InstKind.Cast: return "as"
elif self == InstKind.Cast: return "cast"
elif self == InstKind.Cmp: return "cmp"
elif self == InstKind.DbgStmtLine: return "dbg_stmt_line"
elif self == InstKind.Breakpoint: return "breakpoint"
Expand Down

0 comments on commit 74173df

Please sign in to comment.