grad_transform
Hitting the assert below:
```
root@847841b8737c:/opt/pytorch/lightning-thunder# python /volume/pooling.py
Traceback (most recent call last):
  File "/volume/pooling.py", line 36, in <module>
    o = jit_model(image)
  File "/opt/pytorch/lightning-thunder/thunder/__init__.py", line 632, in fn_
    cache_entry, inps, pro_to_epi = get_computation_and_inputs(*args, **kwargs)
  File "/opt/pytorch/lightning-thunder/thunder/__init__.py", line 265, in cache_info_wrapper
    res = fn(*args, **kwargs)
  File "/opt/pytorch/lightning-thunder/thunder/__init__.py", line 574, in get_computation_and_inputs
    computation_trc, backward_trc = split_forward_backward(computation_trc, cd, cs, *inps)
  File "/opt/pytorch/lightning-thunder/thunder/executors/torch_autograd.py", line 216, in split_forward_backward
    fw_trace, bw_trace = forward_and_backward_from_trace(primal_trace, torch_autograd=True)
  File "/opt/pytorch/lightning-thunder/thunder/core/transforms.py", line 3879, in forward_and_backward_from_trace
    forward_trace = construct_trace()(augmented_forward_fn, *trace.args, **trace.kwargs)
  File "/opt/pytorch/lightning-thunder/thunder/core/interpreter.py", line 1292, in fn_
    return fn(*args, **kwargs)
  File "/opt/pytorch/lightning-thunder/thunder/common.py", line 528, in _trace
    result = fn(*proxyargs, **proxykwargs)
  File "/opt/pytorch/lightning-thunder/thunder/core/transforms.py", line 3850, in augmented_forward_fn
    result, env = augmented_forward_pass(*args, trace=trace, **kwargs)
  File "/opt/pytorch/lightning-thunder/thunder/core/transforms.py", line 3461, in augmented_forward_pass
    result, env = eval_trace(
  File "/opt/pytorch/lightning-thunder/thunder/core/transforms.py", line 1698, in eval_trace
    prim_func = symbol_mapper(symbol)
  File "/opt/pytorch/lightning-thunder/thunder/core/transforms.py", line 3385, in vjp_symbol_mapper
    vjp_impl, backward_fn = make_aug_forward_and_backward(symbol)
  File "/opt/pytorch/lightning-thunder/thunder/core/vjp_utils.py", line 63, in make_aug_forward_and_backward
    joint_trace = thunder.trace(inline_trace=False, use_dce=False)(joint_forward_backward, *bsym.args, **bsym.kwargs)
  File "/opt/pytorch/lightning-thunder/thunder/core/interpreter.py", line 1292, in fn_
    return fn(*args, **kwargs)
  File "/opt/pytorch/lightning-thunder/thunder/common.py", line 506, in _trace
    proxyargs, proxykwargs = _unpack_inputs(fn, trace, args, kwargs, rename_proxies=rename_proxies)
  File "/opt/pytorch/lightning-thunder/thunder/common.py", line 273, in _unpack_inputs
    si = get_siginfo(fn, args, kwargs)
  File "/opt/pytorch/lightning-thunder/thunder/core/codeutils.py", line 313, in get_siginfo
    check(
  File "/opt/pytorch/lightning-thunder/thunder/core/baseutils.py", line 103, in check
    raise exception_type(s())
NotImplementedError: Support for partials with positional args (like ('test',)) is not implemented yet
```
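The check that fails here appears to reject a `functools.partial` whose pre-bound positional args (its `.args` tuple) are non-empty, since the message prints exactly that tuple, `('test',)`. Below is a minimal standalone sketch of that condition under this assumption. It is not thunder's `get_siginfo` code; `bar`, `has_positional_partial_args`, and `check_supported` are hypothetical names used only for illustration.

```python
import functools


def bar(x, mode):
    """Hypothetical stand-in for the function wrapped by the partial."""
    return x


def has_positional_partial_args(fn):
    # functools.partial keeps pre-bound positional args in `.args`
    # and pre-bound keyword args in `.keywords`.
    return isinstance(fn, functools.partial) and len(fn.args) > 0


def check_supported(fn):
    # Hypothetical re-creation of the rule implied by the error message;
    # NOT the actual thunder implementation.
    if has_positional_partial_args(fn):
        raise NotImplementedError(
            f"Support for partials with positional args (like {fn.args}) is not implemented yet"
        )


# Binding an argument positionally reproduces the message from the traceback.
try:
    check_supported(functools.partial(bar, "test"))
except NotImplementedError as e:
    print(e)  # Support for partials with positional args (like ('test',)) is not implemented yet

# Binding only by keyword leaves `.args` empty, so this rule does not trigger.
check_supported(functools.partial(bar, mode="test"))
```

Under this reading, a partial that binds only keyword arguments would not trip this particular condition, although other parts of the tracing path may still treat it differently.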
I was trying to use something like:

```python
from functools import partial

foo = partial(bar, pos_arg0)
OperatorExecutor.register_operator(..., grad_transform=fn)
```
This isn't a high-priority issue, since we can easily work around it for now. Filing this issue just to keep track of the missing feature.
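One possible workaround, sketched here as an assumption rather than the one the reporter used, is to replace the positional partial with an ordinary function that closes over the extra argument, so nothing is pre-bound positionally. This assumes thunder handles plain Python functions as `grad_transform` in the usual way. `bar`, its signature, and `make_grad_transform` are hypothetical stand-ins.

```python
import inspect


def bar(pos_arg0, x):
    """Hypothetical grad transform: an extra leading argument, then the input."""
    return x * pos_arg0


def make_grad_transform(pos_arg0):
    # Hypothetical helper: return a plain function that captures pos_arg0
    # in a closure instead of binding it with functools.partial.
    def foo(x):
        return bar(pos_arg0, x)

    return foo


foo = make_grad_transform(2)
print(foo(3))                  # 6
print(inspect.signature(foo))  # (x)
```

The resulting `foo` would then be passed as `grad_transform=foo` to `register_operator` (other arguments elided as in the issue). Binding the value by keyword, e.g. `partial(bar, pos_arg0=2)`, might also sidestep this check, but for a leading parameter it turns every later parameter into a keyword-only one, so a positional call such as `foo_kw(3)` would raise `TypeError` (multiple values for `pos_arg0`); the closure keeps the positional calling convention intact.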