diff --git a/lib/c/src/libc/string.ri b/lib/c/src/libc/string.ri index 26e9bc73d..65cfd06be 100644 --- a/lib/c/src/libc/string.ri +++ b/lib/c/src/libc/string.ri @@ -11,4 +11,5 @@ extern (C) { public func strerror(code: int32) -> ?[&]mut uint8; public func strlen(cs: ?[&]uint8) -> usize; + public func strstr(needle: ?[&]uint8, haystack: ?[&]uint8) -> ?[&]uint8; } diff --git a/lib/rivet/src/checker/types.ri b/lib/rivet/src/checker/types.ri index 07bdfe6eb..071dfdda0 100644 --- a/lib/rivet/src/checker/types.ri +++ b/lib/rivet/src/checker/types.ri @@ -109,7 +109,7 @@ extend Checker { expected_is_ptr := expected.is_pointer(); got_is_ptr := got.is_pointer(); if expected_is_ptr and got_is_ptr { - return self.check_pointer(expected, got); + return check_pointer(expected, got); } else if (expected_is_ptr and !got_is_ptr) or (got_is_ptr and !expected_is_ptr) { return false; } @@ -193,29 +193,6 @@ extend Checker { return expected == got; } - func check_pointer(self, expected: ast.Type, got: ast.Type) -> bool { - _ = self; - if expected is .Anyptr as anyptr_ { - if got is .Pointer as ptr { - if anyptr_.is_mut and !ptr.is_mut { - return false; - } - // anyptr == *T, is valid - return true; - } - return got is .Anyptr; - } else if expected is .Pointer as ptr and got is .Pointer as ptr2 { - if ptr.is_mut and !ptr2.is_mut { - return false; - } - if ptr.is_indexable and !ptr2.is_indexable { - return false; - } - return ptr.inner == ptr2.inner; - } - return false; - } - func promote(self, left_type: ast.Type, right_type: ast.Type) -> ast.Type { if left_type == right_type { return left_type; @@ -275,3 +252,26 @@ extend Checker { }; } } + +func check_pointer(expected: ast.Type, got: ast.Type) -> bool { + _ = self; + if expected is .Anyptr as anyptr_ { + if got is .Pointer as ptr { + if anyptr_.is_mut and !ptr.is_mut { + return false; + } + // anyptr == &T, is valid + return true; + } + return got is .Anyptr; + } else if expected is .Pointer as ptr and got is .Pointer as ptr2 { + if ptr.is_mut and !ptr2.is_mut { + return false; + } + if ptr.is_indexable and !ptr2.is_indexable { + return false; + } + return ptr.inner == ptr2.inner; + } + return false; +} diff --git a/rivetc/src/codegen/__init__.py b/rivetc/src/codegen/__init__.py index e454026ac..cc6e7c6de 100644 --- a/rivetc/src/codegen/__init__.py +++ b/rivetc/src/codegen/__init__.py @@ -647,10 +647,10 @@ def gen_expr_with_cast(self, expected_typ_, expr, custom_tmp = None): elif isinstance(res_expr, ir.FloatLit) and self.comp.is_float(expected_typ_): res_expr.typ = expected_typ - elif self.comp.is_comptime_number( + elif self.comp.is_number( expr.typ - ) and self.comp.is_number(expected_typ_): - res_expr.typ = expected_typ_ + ) and self.comp.is_number(expected_typ_) and expr.typ != expected_typ_: + res_expr = ir.Inst(ir.InstKind.Cast, [res_expr, expected_typ]) if isinstance( res_expr.typ, ir.Pointer diff --git a/rivetc/src/codegen/c.py b/rivetc/src/codegen/c.py index 92d676229..c8aefad0b 100644 --- a/rivetc/src/codegen/c.py +++ b/rivetc/src/codegen/c.py @@ -381,12 +381,17 @@ def gen_expr(self, expr): # overflows `int64`, hence the consecutive subtraction by `1`. self.write("(-9223372036854775807L - 1)") else: + self.write("((") + self.write_type(expr.typ) + self.write(")") + self.write("(") self.write(expr.lit) if expr.typ.name.endswith("64" ) or expr.typ.name.endswith("size"): if expr.typ.name.startswith("u"): self.write("U") self.write("L") + self.write("))") elif isinstance(expr, ir.FloatLit): self.write(expr.lit) if str(expr.typ) == "float32": diff --git a/rivetc/src/codegen/ir.py b/rivetc/src/codegen/ir.py index e0604e1a9..4884456e9 100644 --- a/rivetc/src/codegen/ir.py +++ b/rivetc/src/codegen/ir.py @@ -554,7 +554,7 @@ def __repr__(self): elif self == InstKind.LoadPtr: return "load_ptr" elif self == InstKind.GetElementPtr: return "get_element_ptr" elif self == InstKind.GetRef: return "get_ref" - elif self == InstKind.Cast: return "as" + elif self == InstKind.Cast: return "cast" elif self == InstKind.Cmp: return "cmp" elif self == InstKind.DbgStmtLine: return "dbg_stmt_line" elif self == InstKind.Breakpoint: return "breakpoint"