[BUG] cannot capture your model as a full graph #1132

Open
sunkun1997 opened this issue Jul 16, 2024 · 6 comments

@sunkun1997

torch version: 2.5.0.dev20240616+cu121
python version: python 3.8

I ran the llama example with torchrun --nproc-per-node 2 pippy_llama.py and got the following error:

Loading checkpoint shards: 100%|████████████████████████████████████████████████████████| 3/3 [00:15<00:00,  5.26s/it]
LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(32000, 4096)
    (layers): ModuleList(
      (0-31): 32 x LlamaDecoderLayer(
        (self_attn): LlamaSdpaAttention(
          (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (v_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (rotary_emb): LlamaRotaryEmbedding()
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=4096, out_features=11008, bias=False)
          (up_proj): Linear(in_features=4096, out_features=11008, bias=False)
          (down_proj): Linear(in_features=11008, out_features=4096, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm()
        (post_attention_layernorm): LlamaRMSNorm()
      )
    )
    (norm): LlamaRMSNorm()
  )
  (lm_head): Linear(in_features=4096, out_features=32000, bias=False)
)
Loading checkpoint shards: 100%|████████████████████████████████████████████████████████| 3/3 [00:15<00:00,  5.27s/it]
LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(32000, 4096)
    (layers): ModuleList(
      (0-31): 32 x LlamaDecoderLayer(
        (self_attn): LlamaSdpaAttention(
          (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (v_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (rotary_emb): LlamaRotaryEmbedding()
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=4096, out_features=11008, bias=False)
          (up_proj): Linear(in_features=4096, out_features=11008, bias=False)
          (down_proj): Linear(in_features=11008, out_features=4096, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm()
        (post_attention_layernorm): LlamaRMSNorm()
      )
    )
    (norm): LlamaRMSNorm()
  )
  (lm_head): Linear(in_features=4096, out_features=32000, bias=False)
)
layers_per_rank = 16
layers_per_rank = 16
[rank0]: Traceback (most recent call last):
[rank0]:   File "/home/ray/anaconda3/lib/python3.8/site-packages/torch/distributed/pipelining/_IR.py", line 1006, in _trace_with_export
[rank0]:     ep = torch.export.export(
[rank0]:   File "/home/ray/anaconda3/lib/python3.8/site-packages/torch/export/__init__.py", line 174, in export
[rank0]:     return _export(
[rank0]:   File "/home/ray/anaconda3/lib/python3.8/site-packages/torch/export/_trace.py", line 952, in wrapper
[rank0]:     raise e
[rank0]:   File "/home/ray/anaconda3/lib/python3.8/site-packages/torch/export/_trace.py", line 935, in wrapper
[rank0]:     ep = fn(*args, **kwargs)
[rank0]:   File "/home/ray/anaconda3/lib/python3.8/site-packages/torch/export/exported_program.py", line 91, in wrapper
[rank0]:     return fn(*args, **kwargs)
[rank0]:   File "/home/ray/anaconda3/lib/python3.8/site-packages/torch/export/_trace.py", line 1547, in _export
[rank0]:     exported_program = ExportedProgram(
[rank0]:   File "/home/ray/anaconda3/lib/python3.8/site-packages/torch/export/exported_program.py", line 248, in __init__
[rank0]:     self.verifier().check(self)
[rank0]:   File "/home/ray/anaconda3/lib/python3.8/site-packages/torch/_export/verifier.py", line 154, in check
[rank0]:     self._check_graph_module(ep.graph_module)
[rank0]:   File "/home/ray/anaconda3/lib/python3.8/site-packages/torch/_export/verifier.py", line 220, in _check_graph_module
[rank0]:     _check_val(node)
[rank0]:   File "/home/ray/anaconda3/lib/python3.8/site-packages/torch/_export/verifier.py", line 62, in _check_val
[rank0]:     raise SpecViolationError(f"Node.meta {node.name} is missing val field.")
[rank0]: torch._export.verifier.SpecViolationError: Node.meta _enter_autocast is missing val field.

[rank0]: The above exception was the direct cause of the following exception:

[rank0]: Traceback (most recent call last):
[rank0]:   File "pippy_llama.py", line 36, in <module>
[rank0]:     pipe = pipeline(llama, mb_args=(mb_inputs["input_ids"],))
[rank0]:   File "/home/ray/anaconda3/lib/python3.8/site-packages/torch/distributed/pipelining/_IR.py", line 1236, in pipeline
[rank0]:     return Pipe.from_tracing(
[rank0]:   File "/home/ray/anaconda3/lib/python3.8/site-packages/torch/distributed/pipelining/_IR.py", line 1044, in from_tracing
[rank0]:     exported_program = Pipe._trace_with_export(
[rank0]:   File "/home/ray/anaconda3/lib/python3.8/site-packages/torch/distributed/pipelining/_IR.py", line 1012, in _trace_with_export
[rank0]:     raise RuntimeError(
[rank0]: RuntimeError: It seems that we cannot capture your model as a full graph. Typical reasons include graph breaks, data/shape-dependent control flow, or missing meta kernels for custom operators. You can use our manual pipeline interfaces, or try to fix the graph breaks, see https://pytorch.org/docs/stable/export.html
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
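
A minimal sketch for reproducing the failing step outside of pipelining (the checkpoint name below is a placeholder; per the traceback above, torch.export.export is the same call that Pipe._trace_with_export runs internally):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "meta-llama/Llama-2-7b-chat-hf"  # placeholder; use the checkpoint from pippy_llama.py
model = AutoModelForCausalLM.from_pretrained(model_id)
tokenizer = AutoTokenizer.from_pretrained(model_id)
mb_inputs = tokenizer(["Hello world"], return_tensors="pt")

# Same call as _IR.py line 1006 in the traceback; if this raises the SpecViolationError
# about _enter_autocast, the graph break is in torch.export itself, not in the splitting step.
ep = torch.export.export(model, (mb_inputs["input_ids"],))
print(ep.graph_module.graph)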
@apresunreve

Same problem

@ishan-gaur

ishan-gaur commented Jul 17, 2024

This can (at least temporarily) be fixed by getting rid of the autocast in transformers/models/llama/modeling_llama.py and replacing everything from the # Force … comment onward in the rotary embedding's forward pass with:

freqs = (inv_freq_expanded @ position_ids_expanded).transpose(1, 2)
emb = torch.cat((freqs, freqs), dim=-1)
return emb.cos().to(dtype=x.dtype), emb.sin().to(dtype=x.dtype)

(this is basically taking the llama model back to commit 7628b3a0f40212c0f264233fc6da0d9c9cf88853 of the transformers package)
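
For context, a sketch of what the patched LlamaRotaryEmbedding.forward ends up looking like with the autocast block removed (the exact surrounding code depends on your transformers version, so treat this as an approximation rather than a drop-in patch):

@torch.no_grad()
def forward(self, x, position_ids):
    # x: [bs, num_attention_heads, seq_len, head_size]
    inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
    position_ids_expanded = position_ids[:, None, :].float()
    # The torch.autocast(..., enabled=False) context that used to wrap the lines below is
    # gone; that context is what produced the _enter_autocast node torch.export rejected.
    freqs = (inv_freq_expanded @ position_ids_expanded).transpose(1, 2)
    emb = torch.cat((freqs, freqs), dim=-1)
    return emb.cos().to(dtype=x.dtype), emb.sin().to(dtype=x.dtype)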

However, after doing this, there still seems to be a problem where the compiled (traced? split?) model graph does not match the original:

[rank0]: Traceback (most recent call last):
[rank0]:   File "/root/.local/share/code-server/extensions/ms-python.python-2022.4.1-universal/pythonFiles/lib/python/debugpy/_vendored/pydevd/pydevd.py", line 3385, in <module>
[rank0]:     main()
[rank0]:   File "/root/.local/share/code-server/extensions/ms-python.python-2022.4.1-universal/pythonFiles/lib/python/debugpy/_vendored/pydevd/pydevd.py", line 3378, in main
[rank0]:     globals = debugger.run(setup['file'], None, None, is_module)
[rank0]:   File "/root/.local/share/code-server/extensions/ms-python.python-2022.4.1-universal/pythonFiles/lib/python/debugpy/_vendored/pydevd/pydevd.py", line 2446, in run
[rank0]:     return self._exec(is_module, entry_point_fn, module_name, file, globals, locals)
[rank0]:   File "/root/.local/share/code-server/extensions/ms-python.python-2022.4.1-universal/pythonFiles/lib/python/debugpy/_vendored/pydevd/pydevd.py", line 2453, in _exec
[rank0]:     pydev_imports.execfile(file, globals, locals)  # execute the script
[rank0]:   File "/root/.local/share/code-server/extensions/ms-python.python-2022.4.1-universal/pythonFiles/lib/python/debugpy/_vendored/pydevd/_pydev_bundle/_pydev_execfile.py", line 25, in execfile
[rank0]:     exec(compile(contents + "\n", file, 'exec'), glob, loc)
[rank0]:   File ".../pippy_llama.py", line 45, in <module>
[rank0]:     pipe = pipeline(llama, example_args=(mb_inputs["input_ids"],), num_chunks=int(len(full_batch_prompts) / len(mb_inputs)))
[rank0]:   File "/root/miniforge3/envs/sequin/lib/python3.10/site-packages/torch/distributed/pipelining/_IR.py", line 1187, in pipeline
[rank0]:     return Pipe.from_tracing(
[rank0]:   File "/root/miniforge3/envs/sequin/lib/python3.10/site-packages/torch/distributed/pipelining/_IR.py", line 1030, in from_tracing
[rank0]:     pipe = Pipe._from_traced(
[rank0]:   File "/root/miniforge3/envs/sequin/lib/python3.10/site-packages/torch/distributed/pipelining/_IR.py", line 734, in _from_traced
[rank0]:     new_submod = _outline_submodules(submodule.graph)
[rank0]:   File "/root/miniforge3/envs/sequin/lib/python3.10/site-packages/torch/distributed/pipelining/_unflatten.py", line 23, in _outline_submodules
[rank0]:     ).run_outer()
[rank0]:   File "/root/miniforge3/envs/sequin/lib/python3.10/site-packages/torch/export/unflatten.py", line 862, in run_outer
[rank0]:     self.run_from(node_idx)
[rank0]:   File "/root/miniforge3/envs/sequin/lib/python3.10/site-packages/torch/export/unflatten.py", line 942, in run_from
[rank0]:     ).run_from(node_idx)
[rank0]:   File "/root/miniforge3/envs/sequin/lib/python3.10/site-packages/torch/export/unflatten.py", line 942, in run_from
[rank0]:     ).run_from(node_idx)
[rank0]:   File "/root/miniforge3/envs/sequin/lib/python3.10/site-packages/torch/export/unflatten.py", line 942, in run_from
[rank0]:     ).run_from(node_idx)
[rank0]:   File "/root/miniforge3/envs/sequin/lib/python3.10/site-packages/torch/export/unflatten.py", line 919, in run_from
[rank0]:     self.finalize_outputs()
[rank0]:   File "/root/miniforge3/envs/sequin/lib/python3.10/site-packages/torch/export/unflatten.py", line 841, in finalize_outputs
[rank0]:     _verify_graph_equivalence(self.cached_graph_module, self.module)
[rank0]:   File "/root/miniforge3/envs/sequin/lib/python3.10/site-packages/torch/export/unflatten.py", line 567, in _verify_graph_equivalence
[rank0]:     assert graph_dump(x.graph) == graph_dump(y.graph)

@sunkun1997
Author

sunkun1997 commented Jul 17, 2024

> Was able to solve the above in my case by turning off the kv-cache in the model config. Perhaps this needs to be manually managed by the user outside of the traced module.

Can you tell me exactly which lines you're replacing? And is the way to turn off the kv-cache to change "use_cache": true to false in config.json?
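
For reference, turning the kv-cache off can also be done from code rather than by editing config.json; a sketch (the checkpoint name is a placeholder):

from transformers import AutoModelForCausalLM

llama = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-chat-hf", use_cache=False)
# or, on an already loaded model:
llama.config.use_cache = False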

@ishan-gaur

Sorry, the kv-cache thing was wrong. I was trying out gpt2 earlier to make sure I could at least run something.

I also had the wrong commit number earlier. I was talking about reverting this change in the transformers library:
huggingface/transformers@d45f47a

@ishan-gaur

ishan-gaur commented Jul 20, 2024

Was able to resolve this by reverting transformers to the last December 2023 commit that passes all tests (3b7675b2b844b02d4821b827871a21ad16dd446c) and using the PiPPy v0.2.0 tag. If you need batched chat-template decoding, you also need to go find the updated tokenization utils base file (tokenization_utils_base.py) and that folder's __init__.py accordingly as well.

@Noblezhong

I encountered the same problem. Is there any solution to fix it?
