test_vjp_correctness fails with ops that return tensors that do not require grads. #120

Open
nikitaved opened this issue Apr 2, 2024 · 1 comment
Labels: bug (Something isn't working), help wanted (Extra attention is needed)


@nikitaved (Contributor)

🐛 Bug

As per the title. To reproduce, uncomment these tests in #118 to get:

```
thunder/tests/test_grad.py:423: in test_vjp_correctness
    result = run_snippet(
thunder/tests/framework.py:483: in run_snippet
    raise ex
thunder/tests/framework.py:475: in run_snippet
    snippet(*args, **kwargs)
thunder/tests/test_grad.py:394: in snippet_vjp_correctness
    check_vjp(func, *args, executor=executor)
thunder/tests/test_grad.py:304: in check_vjp
    _, J_star_v = executor.make_callable_legacy(vjp(f), disable_torch_autograd_support=True)(primals, v)
thunder/common.py:783: in _fn
    trc_or_result = trace(compile_data=cd)(processed_function, *args, **kwargs)
thunder/core/interpreter.py:1298: in fn_
    return fn(*args, **kwargs)
thunder/common.py:534: in _trace
    result = fn(*proxyargs, **proxykwargs)
thunder/core/transforms.py:3629: in _vjp
    result, vjp_result = vjp_call(flat_args, cotangents, trace=trace)
thunder/core/transforms.py:3603: in vjp_call_metafunc
    result, env = augmented_forward_pass(*primals, trace=trace, **kwargs)
thunder/core/transforms.py:3414: in augmented_forward_pass
    result, env = eval_trace(
thunder/core/transforms.py:1693: in eval_trace
    prim_func = symbol_mapper(symbol)
thunder/core/transforms.py:3338: in vjp_symbol_mapper
    vjp_impl, backward_fn = make_aug_forward_and_backward(symbol)
thunder/core/vjp_utils.py:99: in make_aug_forward_and_backward
    backward_bsyms = utils.find_producer_symbols(joint_trace, flat_bw_outputs, tree_flatten(bw_inputs)[0])
thunder/core/utils.py:1062: in find_producer_symbols
    if arg_name not in map(lambda x: x.name, stop_proxies) and arg_name not in seen:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

x = None

>   if arg_name not in map(lambda x: x.name, stop_proxies) and arg_name not in seen:
E   AttributeError: 'NoneType' object has no attribute 'name'

thunder/core/utils.py:1062: AttributeError
```
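The traceback's `x = None` shows that one element of `stop_proxies` (presumably the flattened `bw_inputs` passed from `make_aug_forward_and_backward`) is `None`, so the membership test raises as soon as `map` reaches that element before finding a match. The plain-Python sketch below reproduces only that mechanism. `FakeProxy` and both helper functions are hypothetical stand-ins, not thunder's classes, and the `None`-skipping guard is just one possible mitigation, not the project's actual fix:

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class FakeProxy:
    # Hypothetical stand-in for a thunder proxy; only `.name` matters here
    name: str


def is_stopped_unguarded(arg_name, stop_proxies):
    # Mirrors the failing expression: `.name` is read lazily from each element,
    # so a None raises AttributeError if it is reached before a match is found
    return arg_name in map(lambda x: x.name, stop_proxies)


def is_stopped_guarded(arg_name, stop_proxies):
    # Assumed mitigation (not thunder's fix): ignore None placeholders, e.g.
    # entries standing in for outputs that do not require grad
    return arg_name in {p.name for p in stop_proxies if p is not None}


stop_proxies = [FakeProxy("t0"), None, FakeProxy("t2")]

try:
    is_stopped_unguarded("t2", stop_proxies)
except AttributeError as e:
    print(f"unguarded: {e}")

print("guarded:", is_stopped_guarded("t2", stop_proxies))
```

Note that `is_stopped_unguarded("t0", stop_proxies)` would succeed, because the lookup stops at the first match; whether the bug surfaces therefore depends on where the `None` sits relative to the name being looked up.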

nikitaved added the bug and help wanted labels on Apr 2, 2024.
kshitij12345 self-assigned this on Apr 29, 2024.
@kshitij12345 (Collaborator)

The root cause seems to be in `vjp` itself, since it reproduces outside the test harness:

```python
import thunder
import torch

def foo(x):
    # topk returns (values, indices); the integer `indices` output does not require grad
    return thunder.torch.topk(x, k=2)

x = torch.ones(3, 3) * 2
outputs = torch.topk(x, k=2)
# One cotangent per output, including one for the integer `indices` output
cotangents = tuple(torch.ones_like(t) for t in outputs)
vjp_foo = thunder.core.transforms.vjp(foo)
jfoo = thunder.compile(vjp_foo, disable_preprocessing=True)
# jfoo = thunder.jit(vjp_foo)  # Doesn't work currently.

# Fails with
# File "/home/kkalambarkar/lightning-thunder/thunder/core/utils.py", line 1062, in <lambda>
#     if arg_name not in map(lambda x: x.name, stop_proxies) and arg_name not in seen:
# AttributeError: 'NoneType' object has no attribute 'name'
jfoo(primals=(x,), cotangents=cotangents)
```
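`torch.topk` fits the title's description because it returns a `(values, indices)` pair in which `indices` is an `int64` tensor that autograd never tracks. The PyTorch-only check below (independent of thunder) illustrates that property; calling `requires_grad_()` on the input is an addition for illustration, so that the differentiable output visibly contrasts with the integer one:

```python
import torch

# Input that requires grad, so differentiable outputs of topk will too
x = (torch.ones(3, 3) * 2).requires_grad_()
values, indices = torch.topk(x, k=2)

print(values.requires_grad, values.dtype)    # floating-point output, tracked by autograd
print(indices.requires_grad, indices.dtype)  # integer output, never requires grad
```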

NOTE: The test currently uses `make_callable_legacy` (which uses `thunder.compile`). We should probably wait until `thunder.jit(vjp(fn))` is supported and then re-verify. (Related issue: #198)
