Constant folding doesn't work for ThunderFX and NeMo #1478

Open
mpatel31415 opened this issue Nov 26, 2024 · 1 comment
Labels: mixology (issues that the mixology team has surfaced), nemo (issues needed to support NVIDIA NeMo models), thunderfx (for things that could be applicable to the dynamo+thunder frontend), transforms

mpatel31415 (Contributor) commented Nov 26, 2024

🐛 Bug

When running Phi-3.5-mini-instruct, Mistral-Nemo-Base-2407, and Qwen2.5-7B-Instruct with NeMo + ThunderFX and constant folding enabled, we get the following error:

File "<eval_with_key>.1546", line 7, in forward
0: thunder_2 = self.thunder_2(inductor_1, l__self___model_layers_0_input_layernorm_weight); l__self___model_layers_0_input_layernorm_weight = None
0: File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1552, in _wrapped_call_impl
0: return self._call_impl(*args, **kwargs)
0: File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1561, in _call_impl
0: return forward_call(*args, **kwargs)
0: File "/usr/local/lib/python3.10/dist-packages/thunder/core/module.py", line 80, in forward
0: res = self.forward_fn(*args, **kwargs)
0: File "/usr/local/lib/python3.10/dist-packages/thunder/init.py", line 774, in wrapped
0: return fn(*args, **kwargs)
0: File "/usr/local/lib/python3.10/dist-packages/thunder/init.py", line 824, in fn

0: cache_entry, inps, pro_to_epi = get_computation_and_inputs(*args, **kwargs)
0: File "/usr/local/lib/python3.10/dist-packages/thunder/init.py", line 756, in wrapped
0: cache_entry, inps, pro_to_epi = get_computation_and_inputs_fn(*args, **kwargs)
0: File "/usr/local/lib/python3.10/dist-packages/thunder/core/langctxs.py", line 136, in _fn
0: result = fn(*args, **kwargs)
0: File "/usr/local/lib/python3.10/dist-packages/thunder/init.py", line 236, in cache_info_wrapper
0: res = fn(*args, **kwargs)
0: File "/usr/local/lib/python3.10/dist-packages/thunder/init.py", line 602, in get_computation_and_inputs
0: new_prologue_trc, new_computation_trc, new_epilogue_trc = transform.transform_traces_pre_prologue(
0: File "/usr/local/lib/python3.10/dist-packages/thunder/transforms/constant_folding.py", line 105, in transform_traces_pre_prologue
0: new_concrete_output = compute_with_constant_tensors(bsym, const_values)
0: File "/usr/local/lib/python3.10/dist-packages/thunder/transforms/constant_folding.py", line 63, in compute_with_constant_tensors
0: return torch_fn(*new_args, **new_kwargs)
0: TypeError: arange() received an invalid combination of arguments - got (int, int, int, dtype=NoneType, device=Device), but expected one of:
0: * (Number end, *, Tensor out = None, torch.dtype dtype = None, torch.layout layout = None, torch.device device = None, bool pin_memory = False, bool requires_grad = False)
0: * (Number start, Number end, *, torch.dtype dtype = None, torch.layout layout = None, torch.device device = None, bool pin_memory = False, bool requires_grad = False)
0: * (Number start, Number end, Number step = 1, *, Tensor out = None, torch.dtype dtype = None, torch.layout layout = None, torch.device device = None, bool pin_memory = False, bool requires_grad = False)
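
For context (my reading of the error, not verified against the transform's code): the call would otherwise match the third overload above, but the device keyword appears to be Thunder's own Device object rather than a torch.device, which torch's argument parser rejects. A minimal, hypothetical sketch that reproduces the same TypeError with a stand-in class:

import torch

class FakeDevice:  # hypothetical stand-in for a non-torch device object (e.g. Thunder's Device)
    def __init__(self, name):
        self.name = name

# torch.arange accepts dtype=None, but a device argument that is not a
# torch.device (or a string/int) makes overload resolution fail with
# "arange() received an invalid combination of arguments".
torch.arange(0, 2048, 1, dtype=None, device=FakeDevice("cuda:0"))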

To Reproduce

The error is present on 1xH100.

Dockerfile used (I built it yesterday; I'm not sure yet how nemo:dev images are versioned, so I can't provide its exact version):

FROM nvcr.io/nvidia/nemo:dev
ARG NVFUSER_REPO=git+https://github.com/NVIDIA/Fuser.git
ARG THUNDER_REPO=git+https://github.com/Lightning-AI/lightning-thunder.git

# Add cloned NeMo latest code
RUN git clone --recursive https://github.com/NVIDIA/NeMo.git /NeMo_cloned
RUN (cd /NeMo_cloned && python -m pip install .)


# Install requirements needed for NeMo, Thunder and NVFuser.
# We must install them in this convoluted way because otherwise Thunder is not
# updated and we cannot use the latest version.
RUN python -m pip install -r /NeMo_cloned/requirements/requirements_lightning.txt && \
    python -m pip install --upgrade ${NVFUSER_REPO}  && \
    python -m pip install --upgrade ${THUNDER_REPO} && \
    python -m pip install --upgrade --no-deps --force-reinstall ${NVFUSER_REPO} && \
    python -m pip install --upgrade --no-deps --force-reinstall ${THUNDER_REPO}
 
# Install Mixology requirements (this can be skipped, so I'm commenting it out)
# COPY requirements/mixology.txt mixology_requirements.txt
# RUN pip install --upgrade -r mixology_requirements.txt

Inside the Docker container, please run:

model=microsoft/Phi-3.5-mini-instruct
# Download the model (you might need to set HF_TOKEN and accept the model's terms of use on the website)
huggingface-cli download $model --local-dir checkpoints/$model --cache-dir checkpoints/$model 
# Run benchmark
python bench_targets/llm_peft/_nemo.py --model checkpoints/$model --mbs 1 --seq-length 2048 --jit-backend thunder

The script bench_targets/llm_peft/_nemo.py can be obtained from the internal GitLab repository akoumparouli/nemo_bench. You can contact me or @tfogal if you have any questions. To enable constant folding, its code around line 90 must be modified:

from thunder.transforms.constant_folding import ConstantFolding
xforms: list = [NvtxProfileTransform(), ConstantFolding()]
nvtx.mark("thunder compilation", domain="model")
be = thunder.dynamo.ThunderCompiler(transforms=xforms)
pl_module.model.compile(backend=be)
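
For reference, here is a smaller, NeMo-free sketch of the same setup. It assumes (untested) that any torch.compile'd module whose graph contains a torch.arange over constant inputs exercises the same constant-folding path; the module and shapes below are made up for illustration:

import torch
from thunder.dynamo import ThunderCompiler
from thunder.transforms.constant_folding import ConstantFolding

class ArangeModule(torch.nn.Module):
    def forward(self, x):
        # With static shapes, this arange depends only on constants, so the
        # ConstantFolding transform may try to evaluate it at trace time.
        positions = torch.arange(0, x.shape[-1], 1, device=x.device)
        return x + positions

backend = ThunderCompiler(transforms=[ConstantFolding()])
model = torch.compile(ArangeModule().cuda(), backend=backend)  # requires a GPU
model(torch.randn(2, 2048, device="cuda"))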

Expected behavior

No error.

Environment

cc @tfogal

IvanYashchuk added the nemo, mixology, and transforms labels Nov 26, 2024
IvanYashchuk (Collaborator) commented:

@kshitij12345, would you be available to take a look at what's going on here?

tfogal added the thunderfx label Nov 26, 2024