Partial function is not supported in grad_transform #171

Open
jjsjann123 opened this issue Apr 12, 2024 · 0 comments
Labels
autograd, enhancement (New feature or request)

Comments

@jjsjann123
Collaborator

jjsjann123 commented Apr 12, 2024

🚀 Feature

Hitting the check below when the grad_transform is a functools.partial with a bound positional argument:

root@847841b8737c:/opt/pytorch/lightning-thunder# python /volume/pooling.py
Traceback (most recent call last):
  File "/volume/pooling.py", line 36, in <module>
    o = jit_model(image)
  File "/opt/pytorch/lightning-thunder/thunder/__init__.py", line 632, in fn_
    cache_entry, inps, pro_to_epi = get_computation_and_inputs(*args, **kwargs)
  File "/opt/pytorch/lightning-thunder/thunder/__init__.py", line 265, in cache_info_wrapper
    res = fn(*args, **kwargs)
  File "/opt/pytorch/lightning-thunder/thunder/__init__.py", line 574, in get_computation_and_inputs
    computation_trc, backward_trc = split_forward_backward(computation_trc, cd, cs, *inps)
  File "/opt/pytorch/lightning-thunder/thunder/executors/torch_autograd.py", line 216, in split_forward_backward
    fw_trace, bw_trace = forward_and_backward_from_trace(primal_trace, torch_autograd=True)
  File "/opt/pytorch/lightning-thunder/thunder/core/transforms.py", line 3879, in forward_and_backward_from_trace
    forward_trace = construct_trace()(augmented_forward_fn, *trace.args, **trace.kwargs)
  File "/opt/pytorch/lightning-thunder/thunder/core/interpreter.py", line 1292, in fn_
    return fn(*args, **kwargs)
  File "/opt/pytorch/lightning-thunder/thunder/common.py", line 528, in _trace
    result = fn(*proxyargs, **proxykwargs)
  File "/opt/pytorch/lightning-thunder/thunder/core/transforms.py", line 3850, in augmented_forward_fn
    result, env = augmented_forward_pass(*args, trace=trace, **kwargs)
  File "/opt/pytorch/lightning-thunder/thunder/core/transforms.py", line 3461, in augmented_forward_pass
    result, env = eval_trace(
  File "/opt/pytorch/lightning-thunder/thunder/core/transforms.py", line 1698, in eval_trace
    prim_func = symbol_mapper(symbol)
  File "/opt/pytorch/lightning-thunder/thunder/core/transforms.py", line 3385, in vjp_symbol_mapper
    vjp_impl, backward_fn = make_aug_forward_and_backward(symbol)
  File "/opt/pytorch/lightning-thunder/thunder/core/vjp_utils.py", line 63, in make_aug_forward_and_backward
    joint_trace = thunder.trace(inline_trace=False, use_dce=False)(joint_forward_backward, *bsym.args, **bsym.kwargs)
  File "/opt/pytorch/lightning-thunder/thunder/core/interpreter.py", line 1292, in fn_
    return fn(*args, **kwargs)
  File "/opt/pytorch/lightning-thunder/thunder/common.py", line 506, in _trace
    proxyargs, proxykwargs = _unpack_inputs(fn, trace, args, kwargs, rename_proxies=rename_proxies)
  File "/opt/pytorch/lightning-thunder/thunder/common.py", line 273, in _unpack_inputs
    si = get_siginfo(fn, args, kwargs)
  File "/opt/pytorch/lightning-thunder/thunder/core/codeutils.py", line 313, in get_siginfo
    check(
  File "/opt/pytorch/lightning-thunder/thunder/core/baseutils.py", line 103, in check
    raise exception_type(s())
NotImplementedError: Support for partials with positional args (like ('test',)) is not implemented yet

I was trying to use something like

from functools import partial
foo = partial(bar, pos_arg0)
OperatorExecutor.register_operator(..., grad_transform=foo)
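
To make the failing case concrete, here is a minimal sketch using only the standard library. It shows what distinguishes this kind of partial: its args tuple is non-empty. The function bar and its parameters are hypothetical stand-ins, and 'test' is borrowed from the error message. The actual check in thunder/core/codeutils.py is not reproduced here.

```python
from functools import partial


def bar(pos_arg0, x):
    # Hypothetical stand-in for the function wrapped by the partial.
    return (pos_arg0, x)


# Bind the first positional argument, as in the snippet above.
foo = partial(bar, "test")

# A partial stores its bound positional and keyword arguments separately.
print(foo.args)      # ('test',)
print(foo.keywords)  # {}
print(foo.func is bar)
```

The error message mentions only partials with positional args, so the non-empty args tuple is presumably what the check rejects.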

This isn't a high-priority issue, since we can easily work around it for now. Filing this issue just to keep track of the missing feature.
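
The issue does not say which workaround was used. One possible approach, sketched here with hypothetical names, is to capture the bound argument in a closure, so that what gets registered is a plain function rather than a partial. Whether this satisfies every other requirement of grad_transform is an assumption.

```python
import inspect
from functools import partial


def bar(pos_arg0, x):
    # Hypothetical stand-in for the function wrapped by the partial.
    return (pos_arg0, x)


def make_foo(pos_arg0):
    # Return a plain function with pos_arg0 captured in a closure,
    # instead of functools.partial(bar, pos_arg0).
    def foo(x):
        return bar(pos_arg0, x)

    return foo


foo = make_foo("test")

# foo is an ordinary function, so signature inspection sees only x.
print(isinstance(foo, partial))
print(inspect.signature(foo))
```

The closure spells out the remaining parameters explicitly, so the bound argument never appears in the visible signature.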

@jjsjann123 jjsjann123 added the enhancement New feature or request label Apr 12, 2024
@jjsjann123 jjsjann123 mentioned this issue Apr 12, 2024