Skip to content

Commit

Permalink
Fix torch._C.Node attribute access (#372)
Browse files Browse the repository at this point in the history
Attribute access with subscripting would previously work
due to patching in pytorch/pytorch#82511
but this has been removed.

This commit uses the fix proposed in pytorch/pytorch#82628
to define a helper method to call the appropriate access method.
  • Loading branch information
jamt9000 authored Jul 8, 2023
1 parent a9b1bf5 commit a1d0717
Showing 1 changed file with 10 additions and 2 deletions.
12 changes: 10 additions & 2 deletions clip/clip.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,14 @@ def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_a
device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[])
device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1]

def _node_get(node: torch._C.Node, key: str):
"""Gets attributes of a node which is polymorphic over return type.
From https://github.com/pytorch/pytorch/pull/82628
"""
sel = node.kindOf(key)
return getattr(node, sel)(key)

def patch_device(module):
try:
graphs = [module.graph] if hasattr(module, "graph") else []
Expand All @@ -156,7 +164,7 @@ def patch_device(module):

for graph in graphs:
for node in graph.findAllNodes("prim::Constant"):
if "value" in node.attributeNames() and str(node["value"]).startswith("cuda"):
if "value" in node.attributeNames() and str(_node_get(node, "value")).startswith("cuda"):
node.copyAttributes(device_node)

model.apply(patch_device)
Expand All @@ -182,7 +190,7 @@ def patch_float(module):
for node in graph.findAllNodes("aten::to"):
inputs = list(node.inputs())
for i in [1, 2]: # dtype can be the second or third argument to aten::to()
if inputs[i].node()["value"] == 5:
if _node_get(inputs[i].node(), "value") == 5:
inputs[i].node().copyAttributes(float_node)

model.apply(patch_float)
Expand Down

0 comments on commit a1d0717

Please sign in to comment.