From dfbdf3b0c43607fcc9edb31e642e19bf6ded13fe Mon Sep 17 00:00:00 2001 From: Xiaochun Tong Date: Tue, 4 Feb 2025 15:20:37 -0500 Subject: [PATCH] good --- luisa_lang/codegen/cpp.py | 20 ++++++++- luisa_lang/hir.py | 87 ++++++++++++++++++++++++++++++++----- luisa_lang/lang_builtins.py | 33 +++++++++++++- luisa_lang/parse.py | 64 ++++++++++++++++++++++++--- 4 files changed, 184 insertions(+), 20 deletions(-) diff --git a/luisa_lang/codegen/cpp.py b/luisa_lang/codegen/cpp.py index e273509..e2253af 100644 --- a/luisa_lang/codegen/cpp.py +++ b/luisa_lang/codegen/cpp.py @@ -415,8 +415,10 @@ def impl() -> None: ty = self.base.type_cache.gen(expr.type) self.body.writeln( f"{ty} v{vid}{{ {','.join(self.gen_expr(e) for e in expr.args)} }};") - case hir.Intrinsic() as intrin: + case hir.Intrinsic() as intrin: def do(): + assert intrin.type + intrin_ty_s = self.base.type_cache.gen(intrin.type) intrin_name = intrin.name comps = intrin_name.split('.') gened_args = [self.gen_value_or_ref( @@ -426,6 +428,12 @@ def do(): ty = self.base.type_cache.gen(expr.type) self.body.writeln( f"{ty} v{vid}{{ {','.join(gened_args)} }};") + elif comps[0] == 'cast': + self.body.writeln( + f"auto v{vid} = static_cast<{intrin_ty_s}>({gened_args[0]});") + elif comps[0] == 'bitcast': + self.body.writeln( + f"auto v{vid} = lc_bit_cast<{intrin_ty_s}>({gened_args[0]});") elif comps[0] == 'cmp': cmp_dict = { '__eq__': '==', @@ -592,11 +600,19 @@ def gen_node(self, node: hir.Node) -> Optional[hir.BasicBlock]: ty = self.base.type_cache.gen(alloca.type.remove_ref()) self.body.writeln(f"{ty} v{vid}{{}}; // alloca") self.node_map[alloca] = f"v{vid}" - case hir.AggregateInit() | hir.Intrinsic() | hir.Call() | hir.Constant() | hir.Load() | hir.Index() | hir.Member() | hir.TypeValue() | hir.FunctionValue(): + case hir.Print() as print_stmt: + raise NotImplementedError("print statement") + case hir.Assert() as assert_stmt: + raise NotImplementedError("assert statement") + case hir.AggregateInit() | hir.Intrinsic() | hir.Call() | hir.Constant() | hir.Load() | hir.Index() | hir.Member() | hir.TypeValue() | hir.FunctionValue() | hir.VarValue(): if isinstance(node, hir.TypedNode) and node.is_ref(): pass else: self.gen_expr(node) + case hir.VarRef(): + pass + case _: + raise NotImplementedError(f"unsupported node: {node}") return None def gen_bb(self, bb: hir.BasicBlock): diff --git a/luisa_lang/hir.py b/luisa_lang/hir.py index d610a56..a2a77b3 100644 --- a/luisa_lang/hir.py +++ b/luisa_lang/hir.py @@ -142,7 +142,7 @@ def method(self, name: str) -> Optional[Union["Function", FunctionTemplate]]: def is_concrete(self) -> bool: return True - + def is_addressable(self) -> bool: return True @@ -163,8 +163,10 @@ class RefType(Type): def __init__(self, element: Type) -> None: super().__init__() - assert element.is_addressable(), f"RefType element {element} is not addressable" - assert not isinstance(element, (OpaqueType, RefType, FunctionType,TypeConstructorType)) + assert element.is_addressable(), f"RefType element { + element} is not addressable" + assert not isinstance( + element, (OpaqueType, RefType, FunctionType, TypeConstructorType)) self.element = element self.methods = element.methods @@ -189,18 +191,19 @@ def member(self, field: Any) -> Optional['Type']: ty = self.element.member(field) if ty is None: return None - if isinstance(ty,FunctionType): + if isinstance(ty, FunctionType): return ty return RefType(ty) @override def method(self, name: str) -> Optional[Union["Function", FunctionTemplate]]: return self.element.method(name) - + @override def is_addressable(self) -> bool: return False + class LiteralType(Type): value: Any @@ -221,7 +224,7 @@ def is_concrete(self) -> bool: @override def is_addressable(self) -> bool: return False - + def __eq__(self, value: object) -> bool: return isinstance(value, LiteralType) and value.value == self.value @@ -349,6 +352,7 @@ def is_concrete(self) -> bool: def is_addressable(self) -> bool: return False + class GenericIntType(ScalarType): @override def __eq__(self, value: object) -> bool: @@ -382,6 +386,7 @@ def is_concrete(self) -> bool: def is_addressable(self) -> bool: return False + class FloatType(ScalarType): bits: int @@ -695,6 +700,7 @@ def __repr__(self) -> str: def __str__(self) -> str: return f"~{self.name}@{self.ctx_name}" + class OpaqueType(Type): name: str extra_args: List[Any] @@ -722,7 +728,7 @@ def __str__(self) -> str: @override def is_concrete(self) -> bool: return False - + @override def is_addressable(self) -> bool: return False @@ -800,18 +806,19 @@ def __eq__(self, value: object) -> bool: def __hash__(self) -> int: return hash((ParametricType, tuple(self.params), self.body)) - + def __str__(self) -> str: return f"{self.body}[{', '.join(str(p) for p in self.params)}]" @override def is_concrete(self) -> bool: return self.body.is_concrete() - + @override def is_addressable(self) -> bool: return self.body.is_addressable() + class BoundType(Type): """ An instance of a parametric type, e.g. Foo[int] @@ -841,7 +848,7 @@ def __eq__(self, value: object) -> bool: def __hash__(self): return hash((BoundType, self.generic, tuple(self.args))) - + def __str__(self) -> str: return f"{self.generic}[{', '.join(str(a) for a in self.args)}]" @@ -862,11 +869,12 @@ def method(self, name) -> Optional[Union["Function", FunctionTemplate]]: @override def is_addressable(self) -> bool: return self.generic.is_addressable() - + @override def is_concrete(self) -> bool: return self.generic.is_concrete() + class TypeConstructorType(Type): inner: Type @@ -910,6 +918,7 @@ def size(self) -> int: def align(self) -> int: raise RuntimeError("FunctionType has no align") + class Node: """ Base class for all nodes in the HIR. A node could be a value, a reference, or a statement. @@ -999,6 +1008,7 @@ def __init__( self.name = name self.semantic = semantic + class VarValue(Value): var: Var @@ -1006,6 +1016,7 @@ def __init__(self, var: Var, span: Optional[Span]) -> None: super().__init__(var.type, span) self.var = var + class VarRef(Value): var: Var @@ -1155,6 +1166,8 @@ def __str__(self) -> str: return f"Template matching error:\n\t{self.message}" return f"Template matching error at {self.span}:\n\t{self.message}" +class ComptimeCallStack: + pass class SpannedError(Exception): span: Span | None @@ -1200,7 +1213,8 @@ class Assign(Node): value: Value def __init__(self, ref: Value, value: Value, span: Optional[Span] = None) -> None: - assert not isinstance(value.type, (FunctionType, TypeConstructorType, RefType)) + assert not isinstance( + value.type, (FunctionType, TypeConstructorType, RefType)) if not isinstance(ref.type, RefType): raise ParsingError( ref, f"cannot assign to a non-reference variable") @@ -1209,6 +1223,24 @@ def __init__(self, ref: Value, value: Value, span: Optional[Span] = None) -> Non self.value = value +class Assert(Node): + cond: Value + msg: List[Union[Value, str]] + + def __init__(self, cond: Value, msg: List[Union[Value, str]], span: Optional[Span] = None) -> None: + super().__init__(span) + self.cond = cond + self.msg = msg + + +class Print(Node): + args: List[Union[Value, str]] + + def __init__(self, args: List[Union[Value, str]], span: Optional[Span] = None) -> None: + super().__init__(span) + self.args = args + + class Terminator(Node): pass @@ -1559,6 +1591,7 @@ def __init__(self, func: Function, args: List[Value], body: BasicBlock, span: Op self.mapping[param] = arg for v in func.locals: if v in self.mapping: + # skip function parameters continue assert v.type assert v.type.is_addressable() @@ -1631,6 +1664,33 @@ def do(): self.mapping[intrin] = body.append( Intrinsic(intrin.name, args, intrin.type, node.span)) do() + case If(): + cond = self.mapping.get(node.cond) + assert isinstance(cond, Value) + then_body = BasicBlock() + else_body = BasicBlock() + merge = BasicBlock() + body.append(If(cond, then_body, else_body, merge)) + self.do_inline(node.then_body, then_body) + if node.else_body: + self.do_inline(node.else_body, else_body) + body.append(merge) + case Loop(): + prepare = BasicBlock() + if node.cond: + cond = self.mapping.get(node.cond) + else: + cond = None + assert cond is None or isinstance(cond, Value) + body_ = BasicBlock() + update = BasicBlock() + merge = BasicBlock() + body.append(Loop(prepare, cond, body_, update, merge)) + self.do_inline(node.prepare, prepare) + self.do_inline(node.body, body_) + if node.update: + self.do_inline(node.update, update) + body.append(merge) case Return(): if self.ret is not None: raise InlineError(node, "multiple return statement") @@ -1646,6 +1706,9 @@ def do(): @staticmethod def inline(func: Function, args: List[Value], body: BasicBlock, span: Optional[Span] = None) -> Value: inliner = FunctionInliner(func, args, body, span) + assert func.return_type + if func.return_type == UnitType(): + return Unit() assert inliner.ret return inliner.ret diff --git a/luisa_lang/lang_builtins.py b/luisa_lang/lang_builtins.py index 1722e7b..99da23b 100644 --- a/luisa_lang/lang_builtins.py +++ b/luisa_lang/lang_builtins.py @@ -21,7 +21,7 @@ Any, Annotated ) -from luisa_lang._builtin_decor import func, intrinsic, opaque, builtin_generic_type, byref +from luisa_lang._builtin_decor import func, intrinsic, opaque, builtin_generic_type, byref, struct from luisa_lang import parse T = TypeVar("T") @@ -317,6 +317,37 @@ def __sub__(self, offset: i32 | i64 | u32 | u64) -> 'Pointer[T]': return intrinsic("pointer.sub", Pointer[T], self, offset) +@struct +class RtxRay: + o: float3 + d: float3 + tmin: float + tmax: float + + def __init__(self, o: float3, d: float3, tmin: float, tmax: float) -> None: + self.o = o + self.d = d + self.tmin = tmin + self.tmax = tmax + + +@struct +class RtxHit: + inst_id: u32 + prim_id: u32 + bary: float2 + + def __init__(self, inst_id: u32, prim_id: u32, bary: float2) -> None: + self.inst_id = inst_id + self.prim_id = prim_id + self.bary = bary + + +@func +def ray_query_pipeline(ray: RtxRay, on_surface_hit, on_procedural_hit) -> RtxHit: + return intrinsic("ray_query_pipeline", RtxHit, ray, on_surface_hit, on_procedural_hit) + + __all__: List[str] = [ # 'Pointer', 'Buffer', diff --git a/luisa_lang/parse.py b/luisa_lang/parse.py index 37b793c..1718145 100644 --- a/luisa_lang/parse.py +++ b/luisa_lang/parse.py @@ -56,6 +56,7 @@ def mono_func(args: List[hir.Type]) -> hir.Type: TYPE_PARAMETERIC_TYPE: hir.ParametricType = _make_type_parameteric_type() + class TypeParser: ctx_name: str globalns: Dict[str, Any] @@ -229,6 +230,7 @@ def _add_special_function(name: str, f: Callable[..., Any]) -> None: SPECIAL_FUNCTIONS_DICT[name] = f SPECIAL_FUNCTIONS.add(f) +_add_special_function('print', print) NewVarHint = Literal[False, 'dsl', 'comptime'] @@ -238,6 +240,20 @@ def _friendly_error_message_for_unrecognized_type(ty: Any) -> str: return 'expected builtin function range, use lc.range instead' return f"expected DSL type but got {ty}" +class FuncStack: + st: List['FuncParser'] + + def __init__(self) -> None: + self.st = [] + + def push(self, f: 'FuncParser') -> None: + self.st.append(f) + + def pop(self) -> 'FuncParser': + return self.st.pop() + + +FUNC_STACK = FuncStack() class FuncParser: @@ -525,8 +541,7 @@ def do(expr: ast.Attribute): raise NotImplementedError() # unreachable def parse_call_impl(self, span: hir.Span | None, f: hir.Function | hir.FunctionTemplate, args: List[hir.Value]) -> hir.Value | hir.TemplateMatchingError: - - + if isinstance(f, hir.FunctionTemplate): if f.is_generic: template_resolve_args: hir.FunctionTemplateResolvingArgs = [] @@ -555,7 +570,7 @@ def parse_call_impl(self, span: hir.Span | None, f: hir.Function | hir.FunctionT resolved_f = f assert resolved_f.return_type expect_ref = isinstance(resolved_f.return_type, hir.RefType) - inline = expect_ref + inline = expect_ref or resolved_f.inline_hint == 'always' param_tys = [] for p in resolved_f.params: assert p.type, f"Parameter {p.name} has no type" @@ -628,7 +643,11 @@ def handle_intrinsic(self, expr: ast.Call) -> hir.Value: # ret_type.inner_type(), hir.Span.from_ast(expr))) def handle_special_functions(self, f: Callable[..., Any], expr: ast.Call) -> hir.Value | ComptimeValue: - if f is SPECIAL_FUNCTIONS_DICT['intrinsic']: + if f is print: + args = [self.parse_string_element(a) for a in expr.args] + self.cur_bb().append(hir.Print(args, hir.Span.from_ast(expr))) + return hir.Unit() + elif f is SPECIAL_FUNCTIONS_DICT['intrinsic']: intrin_ret = self.handle_intrinsic(expr) assert isinstance(intrin_ret, hir.Value) return intrin_ret @@ -791,7 +810,6 @@ def collect_args() -> List[hir.Value]: span, init, [tmp]+collect_args()) if isinstance(call, hir.TemplateMatchingError): raise hir.ParsingError(expr, call.message) - assert isinstance(call, hir.Call) return self.cur_bb().append(hir.Load(tmp)) assert func.type if isinstance(func.type, hir.FunctionType): @@ -1094,6 +1112,19 @@ def convert_to_value(self, value: hir.Value | ComptimeValue, span: Optional[hir. self.cur_bb().append(value) return value + def parse_string_element(self, v: ast.expr) -> Union[str, hir.Value]: + span = hir.Span.from_ast(v) + if isinstance(v, ast.Constant) and isinstance(v.value, str): + return v.value + return self.convert_to_value(self.parse_expr(v), span) + + def parse_strings(self, expr: ast.JoinedStr | ast.Constant) -> List[Union[str, hir.Value]]: + match expr: + case ast.JoinedStr(): + return [self.parse_string_element(v) for v in expr.values] + case ast.Constant(): + return [self.parse_string_element(expr)] + def parse_stmt(self, stmt: ast.stmt) -> None: span = hir.Span.from_ast(stmt) match stmt: @@ -1260,6 +1291,27 @@ def do(): self.parse_multi_assignment( [target], [], self.parse_expr(stmt.value) ) + case ast.Assert(): + def handle_assert(): + test = self.parse_expr(stmt.test) + msg = stmt.msg + if isinstance(test, ComptimeValue): + if msg: + evaled_msg = f'assertion failed for comptime value { + self.eval_expr(msg)}' + else: + evaled_msg = f'assertion failed for comptime value { + test.value}' + assert test.value, evaled_msg + else: + sep_msg: List[Union[hir.Value, str]] = [] + if msg is not None: + if not isinstance(msg, (ast.Constant, ast.JoinedStr)): + raise hir.ParsingError( + stmt, "assert message must be a string literal") + sep_msg = self.parse_strings(msg) + self.cur_bb().append(hir.Assert(test, sep_msg, span)) + handle_assert() case ast.AugAssign(): method_name = AUG_ASSIGN_TO_METHOD_NAMES[type(stmt.op)] var = self.parse_expr(stmt.target) @@ -1318,6 +1370,7 @@ def parse_anno_ty() -> hir.Type: raise RuntimeError(f"Unsupported statement: {ast.dump(stmt)}") def parse_body(self): + FUNC_STACK.push(self) assert self.parsed_func is not None body = self.func_def.body entry = hir.BasicBlock(hir.Span.from_ast(self.func_def)) @@ -1331,6 +1384,7 @@ def parse_body(self): if not self.parsed_func.return_type: self.parsed_func.return_type = hir.UnitType() self.parsed_func.complete = True + assert FUNC_STACK.pop() is self return self.parsed_func