torch._dynamo.exc.Unsupported: call_method UserDefinedObjectVariable(set) __contains__ [UserDefinedObjectVariable()] {} #664

Closed
wprazuch opened this issue Jun 27, 2024 · 7 comments
Labels
mixology Issues that the mixology team has surfaced

Comments

@wprazuch
Contributor

wprazuch commented Jun 27, 2024

🐛 Bug

A torch._dynamo.exc.Unsupported error is raised when running the following model:

  • Nous-Hermes-13b

with Thunder inductor under FSDP (zero2/zero3).

To Reproduce

Steps to reproduce the behavior:

mkdir -p output
docker run --pull=always --gpus all --ipc=host --ulimit memlock=-1 --ulimit stack=67108864  -v $PWD/output:/output -it INTERNAL_IMAGE:pjnl-20240621

Run in the container:

torchrun --nproc-per-node=8 /opt/pytorch/lightning-thunder/thunder/benchmarks/benchmark_litgpt.py --model_name Nous-Hermes-13b --compile thunder_inductor_cat_cudnn --distributed_mode fsdp --shard_mode zero2 

Expected behavior

The model should run, or we should get an OOM error.

Environment

As in the Docker image

Additional context

We reproduced this with FSDP (1/2 nodes, 8 GPUs) for both zero2 and zero3.
The traceback is below:

[rank0]: Traceback (most recent call last):
[rank0]:   File "/workspace/lightning-thunder/thunder/benchmarks/benchmark_litgpt.py", line 639, in <module>
[rank0]:     CLI(benchmark_main)
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/jsonargparse/_cli.py", line 96, in CLI
[rank0]:     return _run_component(components, cfg_init)
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/jsonargparse/_cli.py", line 196, in _run_component
[rank0]:     return component(**cfg)
[rank0]:   File "/workspace/lightning-thunder/thunder/benchmarks/benchmark_litgpt.py", line 584, in benchmark_main
[rank0]:     benchmark.train()
[rank0]:   File "/workspace/lightning-thunder/thunder/benchmarks/benchmark_litgpt.py", line 491, in train
[rank0]:     loss.backward()
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/torch/_tensor.py", line 522, in backward
[rank0]:     torch.autograd.backward(
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/torch/autograd/__init__.py", line 288, in backward
[rank0]:     _engine_run_backward(
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/torch/autograd/graph.py", line 768, in _engine_run_backward
[rank0]:     return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/torch/autograd/function.py", line 306, in apply
[rank0]:     return user_fn(self, *args)
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/torch/autograd/function.py", line 599, in wrapper
[rank0]:     outputs = fn(ctx, *args)
[rank0]:   File "/opt/pytorch/lightning-thunder/thunder/executors/torch_autograd.py", line 96, in backward
[rank0]:     grads = ctx.compiled_backward([saved_tensors_list, ctx.saved_other], args)
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py", line 116, in decorate_context
[rank0]:     return func(*args, **kwargs)
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/torch/amp/autocast_mode.py", line 43, in decorate_autocast
[rank0]:     return func(*args, **kwargs)
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/torch/amp/autocast_mode.py", line 43, in decorate_autocast
[rank0]:     return func(*args, **kwargs)
[rank0]:   File "thunder.backward_fn_333", line 462, in backward_fn
[rank0]:   File "/opt/pytorch/lightning-thunder/thunder/executors/torch_compile.py", line 97, in compiled_func_wrapper
[rank0]:     return compiled_func(*args)
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py", line 116, in decorate_context
[rank0]:     return func(*args, **kwargs)
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/torch/amp/autocast_mode.py", line 43, in decorate_autocast
[rank0]:     return func(*args, **kwargs)
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/torch/amp/autocast_mode.py", line 43, in decorate_autocast
[rank0]:     return func(*args, **kwargs)
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/eval_frame.py", line 433, in _fn
[rank0]:     return fn(*args, **kwargs)
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 1116, in __call__
[rank0]:     return self._torchdynamo_orig_callable(
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 472, in __call__
[rank0]:     return _compile(
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/torch/_utils_internal.py", line 84, in wrapper_function
[rank0]:     return StrobelightCompileTimeProfiler.profile_compile_time(
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/torch/_strobelight/compile_time_profiler.py", line 129, in profile_compile_time
[rank0]:     return func(*args, **kwargs)
[rank0]:   File "/usr/lib/python3.10/contextlib.py", line 79, in inner
[rank0]:     return func(*args, **kwds)
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 817, in _compile
[rank0]:     guarded_code = compile_inner(code, one_graph, hooks, transform)
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/utils.py", line 231, in time_wrapper
[rank0]:     r = func(*args, **kwargs)
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 636, in compile_inner
[rank0]:     out_code = transform_code_object(code, transform)
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/bytecode_transformation.py", line 1185, in transform_code_object
[rank0]:     transformations(instructions, code_options)
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 178, in _fn
[rank0]:     return fn(*args, **kwargs)
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 582, in transform
[rank0]:     tracer.run()
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 2463, in run
[rank0]:     super().run()
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 904, in run
[rank0]:     while self.step():
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 816, in step
[rank0]:     self.dispatch_table[inst.opcode](self, inst)
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 510, in wrapper
[rank0]:     return inner_fn(self, inst)
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 1470, in CALL_FUNCTION
[rank0]:     self.call_function(fn, args, {})
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 754, in call_function
[rank0]:     self.push(fn.call_function(self, args, kwargs))
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/variables/functions.py", line 298, in call_function
[rank0]:     return super().call_function(tx, args, kwargs)
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/variables/functions.py", line 95, in call_function
[rank0]:     return tx.inline_user_function_return(self, [*self.self_args(), *args], kwargs)
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 760, in inline_user_function_return
[rank0]:     return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 2678, in inline_call
[rank0]:     return cls.inline_call_(parent, func, args, kwargs)
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 2794, in inline_call_
[rank0]:     tracer.run()
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 904, in run
[rank0]:     while self.step():
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 816, in step
[rank0]:     self.dispatch_table[inst.opcode](self, inst)
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 510, in wrapper
[rank0]:     return inner_fn(self, inst)
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 1470, in CALL_FUNCTION
[rank0]:     self.call_function(fn, args, {})
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 754, in call_function
[rank0]:     self.push(fn.call_function(self, args, kwargs))
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/variables/functions.py", line 298, in call_function
[rank0]:     return super().call_function(tx, args, kwargs)
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/variables/functions.py", line 95, in call_function
[rank0]:     return tx.inline_user_function_return(self, [*self.self_args(), *args], kwargs)
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 760, in inline_user_function_return
[rank0]:     return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 2678, in inline_call
[rank0]:     return cls.inline_call_(parent, func, args, kwargs)
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 2794, in inline_call_
[rank0]:     tracer.run()
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 904, in run
[rank0]:     while self.step():
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 816, in step
[rank0]:     self.dispatch_table[inst.opcode](self, inst)
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 1916, in CONTAINS_OP
[rank0]:     self.push(right.call_method(self, "__contains__", [left], {}))
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/variables/user_defined.py", line 644, in call_method
[rank0]:     return super().call_method(tx, name, args, kwargs)
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/variables/base.py", line 320, in call_method
[rank0]:     unimplemented(f"call_method {self} {name} {args} {kwargs}")
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/exc.py", line 221, in unimplemented
[rank0]:     raise Unsupported(msg)
[rank0]: torch._dynamo.exc.Unsupported: call_method UserDefinedObjectVariable(set) __contains__ [UserDefinedObjectVariable()] {}
tfogal added the mixology and triage review labels Jun 27, 2024
@tfogal
Collaborator

tfogal commented Jun 27, 2024

[rank0]:   File "/usr/local/lib/python3.10/dist-packages/torch/autograd/function.py", line 599, in wrapper
[rank0]:     outputs = fn(ctx, *args)
[rank0]:   File "/opt/pytorch/lightning-thunder/thunder/executors/torch_autograd.py", line 96, in backward
[rank0]:     grads = ctx.compiled_backward([saved_tensors_list, ctx.saved_other], args)
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py", line 116, in decorate_context
[rank0]:     return func(*args, **kwargs)

Looks like we're eventually asking dynamo to do something that it cannot due to our autograd.

triage: is there something we can do to not tickle dynamo or do we need to just report this upstream?

@IvanYashchuk
Collaborator

Asking dynamo to do something that it cannot due to our generated backward trace:

File "thunder.backward_fn_333", line 462, in backward_fn

and use of fullgraph=True (added in e0ab648)

compiled_func = torch.compile(trace_callable, fullgraph=True)

Setting fullgraph=False might fix this problem.
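
For readers unfamiliar with the flag, here is a minimal sketch of the difference, not the Thunder code path itself. It assumes a recent PyTorch 2.x build and uses an explicit torch._dynamo.graph_break() as a stand-in for whatever construct Dynamo cannot trace in the generated backward trace. The names step and raises_unsupported are made up for illustration, and backend="eager" is chosen only so that Inductor codegen is not involved.

```python
import torch
import torch._dynamo
from torch._dynamo.exc import Unsupported


def step(x):
    y = x * 2
    # Explicit graph break; stands in for any construct Dynamo cannot trace.
    torch._dynamo.graph_break()
    return y + 1


def raises_unsupported(fullgraph: bool) -> bool:
    """Return True if compiling and calling `step` raises Dynamo's Unsupported."""
    torch._dynamo.reset()  # drop cached compilations between runs
    # backend="eager" skips Inductor codegen; only Dynamo tracing is exercised.
    compiled = torch.compile(step, fullgraph=fullgraph, backend="eager")
    try:
        compiled(torch.ones(3))
    except Unsupported:
        return True
    return False


if __name__ == "__main__":
    print("fullgraph=True raises:", raises_unsupported(True))
    print("fullgraph=False raises:", raises_unsupported(False))
```

With fullgraph=True the graph break should surface as an Unsupported exception, as in the traceback above. With fullgraph=False Dynamo is expected to split the function at the break and run the untraceable part in eager mode.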

@tfogal
Collaborator

tfogal commented Jun 28, 2024

@wprazuch can I ask you to do a one-off that tests this with fullgraph=False, as Ivan points out above?

(I don't know that this is the long-term solution, but it will allow us to have a more reasoned discussion about one.)

@mpatel31415
Contributor

mpatel31415 commented Jul 1, 2024

We can confirm that after the modification in torch_compile.py: compiled_func = torch.compile(trace_callable, fullgraph=False) there is no error :)

@tfogal
Collaborator

tfogal commented Jul 1, 2024

Thanks Martyna, Wojciech!

@tfogal
Collaborator

tfogal commented Jul 1, 2024

triage review:

  • by default we use fullgraph=True for torch.compile'd regions
  • using fullgraph=True doesn't really affect performance, just gives an indicator when this kind of thing happens
  • we can revert back to fullgraph=False as the default
  • if torch.compile is going to split things, we might as well just run it in eager
  • we should investigate such cases on a case-by-case basis, as we expect that points to other bugs
  • let's expose an option to turn this to false, and then set the option to false explicitly for this benchmark (or all the benchmarks).
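
One possible shape for the option in the last bullet, as a rough sketch only. CompileOptions, compile_region and benchmark_options are hypothetical names and not part of Thunder's API. The compiler argument is there so the plumbing can be exercised without importing torch.

```python
from dataclasses import dataclass
from typing import Callable, Optional


@dataclass
class CompileOptions:
    # Hypothetical option; True matches the default described in the triage notes.
    fullgraph: bool = True


def compile_region(
    fn: Callable,
    options: CompileOptions,
    compiler: Optional[Callable] = None,
) -> Callable:
    """Compile `fn`, forwarding the fullgraph choice to the compiler."""
    if compiler is None:
        import torch  # imported lazily so the sketch can be tested without torch

        compiler = torch.compile
    return compiler(fn, fullgraph=options.fullgraph)


# A benchmark that hits the error could then opt out explicitly:
benchmark_options = CompileOptions(fullgraph=False)
```

Keeping True as the default preserves the early warning about graph breaks, while a benchmark that is known to break can pass fullgraph=False explicitly.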

@wprazuch
Contributor Author

We don't see it anymore in our logs, so I think it is resolved 👍
