Skip to content

Commit 6bc35b6

Browse files
committed
enter_expr method implemented
1 parent df2bacf commit 6bc35b6

File tree

4 files changed

+201
-157
lines changed

4 files changed

+201
-157
lines changed

jac/jaclang/compiler/passes/main/fuse_typeinfo_pass.py

Lines changed: 107 additions & 99 deletions
Original file line numberDiff line numberDiff line change
@@ -29,23 +29,32 @@ class FuseTypeInfoPass(Pass):
2929

3030
node_type_hash: dict[MypyNodes.Node | VNode, MyType] = {}
3131

32+
# Override this to support enter expression.
33+
def enter_node(self, node: ast.AstNode) -> None:
34+
"""Run on entering node."""
35+
if hasattr(self, f"enter_{pascal_to_snake(type(node).__name__)}"):
36+
getattr(self, f"enter_{pascal_to_snake(type(node).__name__)}")(node)
37+
38+
# TODO: Make (AstSymbolNode::name_spec.sym_typ and Expr::expr_type) the same
39+
# TODO: Introduce AstTypedNode to be a common parent for Expr and AstSymbolNode
40+
if isinstance(node, ast.Expr):
41+
self.enter_expr(node)
42+
3243
def __debug_print(self, msg: str) -> None:
3344
if settings.fuse_type_info_debug:
3445
self.log_info("FuseTypeInfo::" + msg)
3546

36-
def __call_type_handler(
37-
self, node: ast.AstSymbolNode, mypy_type: MypyTypes.Type
38-
) -> None:
47+
def __call_type_handler(self, mypy_type: MypyTypes.Type) -> Optional[str]:
3948
mypy_type_name = pascal_to_snake(mypy_type.__class__.__name__)
4049
type_handler_name = f"get_type_from_{mypy_type_name}"
4150
if hasattr(self, type_handler_name):
42-
getattr(self, type_handler_name)(node, mypy_type)
43-
else:
44-
self.__debug_print(
45-
f'{node.loc}"MypyTypes::{mypy_type.__class__.__name__}" isn\'t supported yet'
46-
)
51+
return getattr(self, type_handler_name)(mypy_type)
52+
self.__debug_print(
53+
f'"MypyTypes::{mypy_type.__class__.__name__}" isn\'t supported yet'
54+
)
55+
return None
4756

48-
def __set_sym_table_link(self, node: ast.AstSymbolNode) -> None:
57+
def __set_type_sym_table_link(self, node: ast.AstSymbolNode) -> None:
4958
typ = node.sym_type.split(".")
5059
typ_sym_table = self.ir.sym_tab
5160

@@ -117,7 +126,7 @@ def node_handler(self: FuseTypeInfoPass, node: T) -> None:
117126
# Jac node has only one mypy node linked to it
118127
if len(node.gen.mypy_ast) == 1:
119128
func(self, node)
120-
self.__set_sym_table_link(node)
129+
self.__set_type_sym_table_link(node)
121130
self.__collect_python_dependencies(node)
122131

123132
# Jac node has multiple mypy nodes linked to it
@@ -141,7 +150,7 @@ def node_handler(self: FuseTypeInfoPass, node: T) -> None:
141150
f"{jac_node_str} has duplicate mypy nodes associated to it"
142151
)
143152
func(self, node)
144-
self.__set_sym_table_link(node)
153+
self.__set_type_sym_table_link(node)
145154
self.__collect_python_dependencies(node)
146155

147156
# Jac node doesn't have mypy nodes linked to it
@@ -164,7 +173,9 @@ def __collect_type_from_symbol(self, node: ast.AstSymbolNode) -> None:
164173

165174
if isinstance(mypy_node, MypyNodes.MemberExpr):
166175
if mypy_node in self.node_type_hash:
167-
self.__call_type_handler(node, self.node_type_hash[mypy_node])
176+
node.name_spec.sym_type = (
177+
self.__call_type_handler(self.node_type_hash[mypy_node]) or node.name_spec.sym_type
178+
)
168179
else:
169180
self.__debug_print(f"{node.loc} MemberExpr type is not found")
170181

@@ -173,7 +184,9 @@ def __collect_type_from_symbol(self, node: ast.AstSymbolNode) -> None:
173184
mypy_node = mypy_node.node
174185

175186
if isinstance(mypy_node, (MypyNodes.Var, MypyNodes.FuncDef)):
176-
self.__call_type_handler(node, mypy_node.type)
187+
node.name_spec.sym_type = (
188+
self.__call_type_handler(mypy_node.type) or node.name_spec.sym_type
189+
)
177190

178191
elif isinstance(mypy_node, MypyNodes.MypyFile):
179192
node.name_spec.sym_type = "types.ModuleType"
@@ -182,7 +195,9 @@ def __collect_type_from_symbol(self, node: ast.AstSymbolNode) -> None:
182195
node.name_spec.sym_type = mypy_node.fullname
183196

184197
elif isinstance(mypy_node, MypyNodes.OverloadedFuncDef):
185-
self.__call_type_handler(node, mypy_node.items[0].func.type)
198+
node.name_spec.sym_type = (
199+
self.__call_type_handler(mypy_node.items[0].func.type) or node.name_spec.sym_type
200+
)
186201

187202
elif mypy_node is None:
188203
node.name_spec.sym_type = "None"
@@ -196,19 +211,60 @@ def __collect_type_from_symbol(self, node: ast.AstSymbolNode) -> None:
196211
else:
197212
if isinstance(mypy_node, MypyNodes.ClassDef):
198213
node.name_spec.sym_type = mypy_node.fullname
199-
self.__set_sym_table_link(node)
214+
self.__set_type_sym_table_link(node)
200215
elif isinstance(mypy_node, MypyNodes.FuncDef):
201-
self.__call_type_handler(node, mypy_node.type)
216+
node.name_spec.sym_type = self.__call_type_handler(mypy_node.type) or node.name_spec.sym_type
202217
elif isinstance(mypy_node, MypyNodes.Argument):
203-
self.__call_type_handler(node, mypy_node.variable.type)
218+
node.name_spec.sym_type = self.__call_type_handler(mypy_node.variable.type) or node.name_spec.sym_type
204219
elif isinstance(mypy_node, MypyNodes.Decorator):
205-
self.__call_type_handler(node, mypy_node.func.type.ret_type)
220+
node.name_spec.sym_type = (
221+
self.__call_type_handler(mypy_node.func.type.ret_type) or node.name_spec.sym_type
222+
)
206223
else:
207224
self.__debug_print(
208225
f'"{node.loc}::{node.__class__.__name__}" mypy node isn\'t supported'
209226
f"{type(mypy_node)}"
210227
)
211228

229+
collection_types_map = {
230+
ast.ListVal: "builtins.list",
231+
ast.SetVal: "builtins.set",
232+
ast.TupleVal: "builtins.tuple",
233+
ast.DictVal: "builtins.dict",
234+
ast.ListCompr: None,
235+
ast.DictCompr: None,
236+
}
237+
238+
# NOTE (Thakee): Since expression nodes are not AstSymbolNodes, I'm not decorating this with __handle_node
239+
# and IMO instead of checking if it's a symbol node or an expression, we somehow mark expressions as
240+
# valid nodes that can have symbols. At this point I'm leaving this like this and lemme know
241+
# otherwise.
242+
# NOTE (GAMAL): This will be fixed through the AstTypedNode
243+
def enter_expr(self: FuseTypeInfoPass, node: ast.Expr) -> None:
244+
"""Enter an expression node."""
245+
if len(node.gen.mypy_ast) == 0:
246+
return
247+
248+
# If the corrosponding mypy ast node type has stored here, get the values.
249+
mypy_node = node.gen.mypy_ast[0]
250+
if mypy_node in self.node_type_hash:
251+
mytype: MyType = self.node_type_hash[mypy_node]
252+
node.expr_type = self.__call_type_handler(mytype) or ""
253+
254+
# Set they symbol type for collection expression.
255+
#
256+
# GenCompr is an instance of ListCompr but we don't handle it here.
257+
# so the isinstace (node, <classes>) doesn't work, I'm going with type(...) == ...
258+
if type(node) in self.collection_types_map:
259+
assert isinstance(node, ast.AtomExpr) # To make mypy happy.
260+
collection_type = self.collection_types_map[type(node)]
261+
if collection_type is not None:
262+
node.name_spec.sym_type = collection_type
263+
if mypy_node in self.node_type_hash:
264+
node.name_spec.sym_type = (
265+
self.__call_type_handler(mytype) or node.name_spec.sym_type
266+
)
267+
212268
@__handle_node
213269
def enter_name(self, node: ast.NameAtom) -> None:
214270
"""Pass handler for name nodes."""
@@ -248,7 +304,10 @@ def enter_enum_def(self, node: ast.EnumDef) -> None:
248304
def enter_ability(self, node: ast.Ability) -> None:
249305
"""Pass handler for Ability nodes."""
250306
if isinstance(node.gen.mypy_ast[0], MypyNodes.FuncDef):
251-
self.__call_type_handler(node, node.gen.mypy_ast[0].type.ret_type)
307+
node.name_spec.sym_type = (
308+
self.__call_type_handler(node.gen.mypy_ast[0].type.ret_type)
309+
or node.name_spec.sym_type
310+
)
252311
else:
253312
self.__debug_print(
254313
f"{node.loc}: Can't get type of an ability from mypy node other than Ability. "
@@ -259,7 +318,10 @@ def enter_ability(self, node: ast.Ability) -> None:
259318
def enter_ability_def(self, node: ast.AbilityDef) -> None:
260319
"""Pass handler for AbilityDef nodes."""
261320
if isinstance(node.gen.mypy_ast[0], MypyNodes.FuncDef):
262-
self.__call_type_handler(node, node.gen.mypy_ast[0].type.ret_type)
321+
node.name_spec.sym_type = (
322+
self.__call_type_handler(node.gen.mypy_ast[0].type.ret_type)
323+
or node.name_spec.sym_type
324+
)
263325
else:
264326
self.__debug_print(
265327
f"{node.loc}: Can't get type of an AbilityDef from mypy node other than FuncDef. "
@@ -272,7 +334,9 @@ def enter_param_var(self, node: ast.ParamVar) -> None:
272334
if isinstance(node.gen.mypy_ast[0], MypyNodes.Argument):
273335
mypy_node: MypyNodes.Argument = node.gen.mypy_ast[0]
274336
if mypy_node.variable.type:
275-
self.__call_type_handler(node, mypy_node.variable.type)
337+
node.name_spec.sym_type = (
338+
self.__call_type_handler(mypy_node.variable.type) or node.name_spec.sym_type
339+
)
276340
else:
277341
self.__debug_print(
278342
f"{node.loc}: Can't get parameter value from mypyNode other than Argument"
@@ -286,7 +350,9 @@ def enter_has_var(self, node: ast.HasVar) -> None:
286350
if isinstance(mypy_node, MypyNodes.AssignmentStmt):
287351
n = mypy_node.lvalues[0].node
288352
if isinstance(n, (MypyNodes.Var, MypyNodes.FuncDef)):
289-
self.__call_type_handler(node, n.type)
353+
node.name_spec.sym_type = (
354+
self.__call_type_handler(n.type) or node.name_spec.sym_type
355+
)
290356
else:
291357
self.__debug_print(
292358
"Getting type of 'AssignmentStmt' is only supported with Var and FuncDef"
@@ -311,54 +377,6 @@ def enter_f_string(self, node: ast.FString) -> None:
311377
"""Pass handler for FString nodes."""
312378
self.__debug_print(f"Getting type not supported in {type(node)}")
313379

314-
@__handle_node
315-
def enter_list_val(self, node: ast.ListVal) -> None:
316-
"""Pass handler for ListVal nodes."""
317-
mypy_node = node.gen.mypy_ast[0]
318-
if mypy_node in self.node_type_hash:
319-
node.name_spec.sym_type = str(self.node_type_hash[mypy_node])
320-
else:
321-
node.name_spec.sym_type = "builtins.list"
322-
323-
@__handle_node
324-
def enter_set_val(self, node: ast.SetVal) -> None:
325-
"""Pass handler for SetVal nodes."""
326-
mypy_node = node.gen.mypy_ast[0]
327-
if mypy_node in self.node_type_hash:
328-
node.name_spec.sym_type = str(self.node_type_hash[mypy_node])
329-
else:
330-
node.name_spec.sym_type = "builtins.set"
331-
332-
@__handle_node
333-
def enter_tuple_val(self, node: ast.TupleVal) -> None:
334-
"""Pass handler for TupleVal nodes."""
335-
mypy_node = node.gen.mypy_ast[0]
336-
if mypy_node in self.node_type_hash:
337-
node.name_spec.sym_type = str(self.node_type_hash[mypy_node])
338-
else:
339-
node.name_spec.sym_type = "builtins.tuple"
340-
341-
@__handle_node
342-
def enter_dict_val(self, node: ast.DictVal) -> None:
343-
"""Pass handler for DictVal nodes."""
344-
mypy_node = node.gen.mypy_ast[0]
345-
if mypy_node in self.node_type_hash:
346-
node.name_spec.sym_type = str(self.node_type_hash[mypy_node])
347-
else:
348-
node.name_spec.sym_type = "builtins.dict"
349-
350-
@__handle_node
351-
def enter_list_compr(self, node: ast.ListCompr) -> None:
352-
"""Pass handler for ListCompr nodes."""
353-
mypy_node = node.gen.mypy_ast[0]
354-
node.name_spec.sym_type = str(self.node_type_hash[mypy_node])
355-
356-
@__handle_node
357-
def enter_dict_compr(self, node: ast.DictCompr) -> None:
358-
"""Pass handler for DictCompr nodes."""
359-
mypy_node = node.gen.mypy_ast[0]
360-
node.name_spec.sym_type = str(self.node_type_hash[mypy_node])
361-
362380
@__handle_node
363381
def enter_index_slice(self, node: ast.IndexSlice) -> None:
364382
"""Pass handler for IndexSlice nodes."""
@@ -370,10 +388,12 @@ def enter_arch_ref(self, node: ast.ArchRef) -> None:
370388
if isinstance(node.gen.mypy_ast[0], MypyNodes.ClassDef):
371389
mypy_node: MypyNodes.ClassDef = node.gen.mypy_ast[0]
372390
node.name_spec.sym_type = mypy_node.fullname
373-
self.__set_sym_table_link(node)
391+
self.__set_type_sym_table_link(node)
374392
elif isinstance(node.gen.mypy_ast[0], MypyNodes.FuncDef):
375393
mypy_node2: MypyNodes.FuncDef = node.gen.mypy_ast[0]
376-
self.__call_type_handler(node, mypy_node2.type)
394+
node.name_spec.sym_type = (
395+
self.__call_type_handler(mypy_node2.type) or node.name_spec.sym_type
396+
)
377397
else:
378398
self.__debug_print(
379399
f"{node.loc}: Can't get ArchRef value from mypyNode other than ClassDef "
@@ -425,48 +445,36 @@ def enter_builtin_type(self, node: ast.BuiltinType) -> None:
425445
"""Pass handler for BuiltinType nodes."""
426446
self.__collect_type_from_symbol(node)
427447

428-
def get_type_from_instance(
429-
self, node: ast.AstSymbolNode, mypy_type: MypyTypes.Instance
430-
) -> None:
448+
def get_type_from_instance(self, mypy_type: MypyTypes.Instance) -> None:
431449
"""Get type info from mypy type Instance."""
432-
node.name_spec.sym_type = str(mypy_type)
450+
return str(mypy_type)
433451

434-
def get_type_from_callable_type(
435-
self, node: ast.AstSymbolNode, mypy_type: MypyTypes.CallableType
436-
) -> None:
452+
def get_type_from_callable_type(self, mypy_type: MypyTypes.CallableType) -> Optional[str]:
437453
"""Get type info from mypy type CallableType."""
438-
self.__call_type_handler(node, mypy_type.ret_type)
454+
return self.__call_type_handler(mypy_type.ret_type)
439455

440456
# TODO: Which overloaded function to get the return value from?
441457
def get_type_from_overloaded(
442-
self, node: ast.AstSymbolNode, mypy_type: MypyTypes.Overloaded
443-
) -> None:
458+
self, mypy_type: MypyTypes.Overloaded
459+
) -> Optional[str]:
444460
"""Get type info from mypy type Overloaded."""
445-
self.__call_type_handler(node, mypy_type.items[0])
461+
return self.__call_type_handler(mypy_type.items[0])
446462

447-
def get_type_from_none_type(
448-
self, node: ast.AstSymbolNode, mypy_type: MypyTypes.NoneType
449-
) -> None:
463+
def get_type_from_none_type(self, mypy_type: MypyTypes.NoneType) -> Optional[str]:
450464
"""Get type info from mypy type NoneType."""
451-
node.name_spec.sym_type = "None"
465+
return "None"
452466

453-
def get_type_from_any_type(
454-
self, node: ast.AstSymbolNode, mypy_type: MypyTypes.AnyType
455-
) -> None:
467+
def get_type_from_any_type(self, mypy_type: MypyTypes.AnyType) -> None:
456468
"""Get type info from mypy type NoneType."""
457-
node.name_spec.sym_type = "Any"
469+
return "Any"
458470

459-
def get_type_from_tuple_type(
460-
self, node: ast.AstSymbolNode, mypy_type: MypyTypes.TupleType
461-
) -> None:
471+
def get_type_from_tuple_type(self, mypy_type: MypyTypes.TupleType) -> None:
462472
"""Get type info from mypy type TupleType."""
463-
node.name_spec.sym_type = "builtins.tuple"
473+
return "builtins.tuple"
464474

465-
def get_type_from_type_type(
466-
self, node: ast.AstSymbolNode, mypy_type: MypyTypes.TypeType
467-
) -> None:
475+
def get_type_from_type_type(self, mypy_type: MypyTypes.TypeType) -> None:
468476
"""Get type info from mypy type TypeType."""
469-
node.name_spec.sym_type = str(mypy_type.item)
477+
return str(mypy_type.item)
470478

471479
def exit_assignment(self, node: ast.Assignment) -> None:
472480
"""Add new symbols in the symbol table in case of self."""

jac/jaclang/compiler/passes/main/tests/test_type_check_pass.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,6 @@ def test_type_coverage(self) -> None:
5959
self.assertIn("HasVar - species - Type: builtins.str", out)
6060
self.assertIn("myDog - Type: type_info.Dog", out)
6161
self.assertIn("Body - Type: type_info.Dog.Body", out)
62-
self.assertEqual(out.count("Type: builtins.str"), 28)
62+
self.assertEqual(out.count("Type: builtins.str"), 39)
6363
for i in lis:
6464
self.assertNotIn(i, out)

0 commit comments

Comments
 (0)