Skip to content

Commit e1c4038

Browse files
committed
zluda hijack torch jit
1 parent c490b9c commit e1c4038

File tree

1 file changed

+6
-0
lines changed

1 file changed

+6
-0
lines changed

modules/zluda_hijacks.py

+6
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,12 @@ def topk(tensor: torch.Tensor, *args, **kwargs):
88
return torch.return_types.topk((values.to(device), indices.to(device),))
99

1010

11+
def jit_script(f, *_, **__): # experiment / provide dummy graph
12+
f.graph = torch._C.Graph() # pylint: disable=protected-access
13+
return f
14+
15+
1116
def do_hijack():
1217
torch.version.hip = "5.7"
1318
torch.topk = topk
19+
torch.jit.script = jit_script

0 commit comments

Comments
 (0)