
symbolic values: error in infer_tensor_properties, through reshape or numel #1257

Closed
t-vi opened this issue Oct 3, 2024 · 4 comments
Labels
dynamic constraints · program-coverage (Requests for model and program coverage)

Comments

@t-vi
Collaborator

t-vi commented Oct 3, 2024

Running the following (with the latest litgpt to get Llama 3.2, but a llama2-like config also works)

import torch
import litgpt
import thunder

with torch.device('cuda'):
    m = litgpt.GPT.from_name('Llama-3.2-1B').bfloat16().requires_grad_(False)
    m.set_kv_cache(1)
inp1 = torch.ones(1, 16, device="cuda", dtype=torch.int32)
inp_pos1 = torch.arange(16, device="cuda")
jm = thunder.jit(m, cache=thunder.core.options.CACHE_OPTIONS.SYMBOLIC_VALUES)
jm(inp1, inp_pos1)

gives

File ~/data/firma/grid/thunder/lightning-thunder/thunder/torch/__init__.py:4578, in embedding(a, weight, padding_idx, max_norm, norm_type, scale_grad_by_freq, sparse)
   4575     return clang.take(weight, a, 0)
   4577 output_shape = list(a.shape) + list(weight.shape[1:])
-> 4578 flatten_indices = reshape(a, [a.numel()])
   4579 flatten_output = clang.take(weight, flatten_indices, 0)
   4580 return reshape(flatten_output, output_shape)

TypeError: _infer_tensor_properties.<locals>.<lambda>() missing 1 required positional argument: 'tp'
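The failing pattern can be distilled without litgpt: thunder's embedding lowering (shown in the traceback above) reshapes the index tensor to `[a.numel()]`. Here is a minimal sketch that exercises the same code path; whether it reproduces this exact error depends on the thunder version, and the names `f`, `idx`, `weight` are just placeholders for illustration:

import torch
import thunder

def f(idx, weight):
    # F.embedding with a 2-D index tensor goes through thunder.torch.embedding,
    # which flattens the indices via reshape(a, [a.numel()]).
    return torch.nn.functional.embedding(idx, weight)

jf = thunder.jit(f, cache=thunder.core.options.CACHE_OPTIONS.SYMBOLIC_VALUES)
idx = torch.ones(1, 16, dtype=torch.int64)
weight = torch.randn(320, 64)
jf(idx, weight)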

@jjsjann123

@t-vi added the dynamic constraints and program-coverage labels Oct 3, 2024
@jjsjann123 jjsjann123 self-assigned this Oct 3, 2024
@jjsjann123
Collaborator

Marking here that the old assert is gone with #1260.

I'm hitting a new assert here

raise NotImplementedError(f"constant of type {type(provenance.value)} {provenance.value}")

I think it's coming from using numel in the code logic.

One way to resolve this is to add static constraints, which we should do. Meanwhile, I'm not sure why we needed that numel call in the first place; I'll try to remove it in #1451.
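For reference, a minimal plain-PyTorch sketch of how the flatten can avoid a concrete `numel()` value, e.g. by reshaping with -1; this is an illustration only, and the actual change in #1451 may look different:

import torch

def embedding_without_numel(a: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
    # Hypothetical illustration, not thunder's actual code.
    output_shape = list(a.shape) + list(weight.shape[1:])
    # reshape(-1) lets the framework infer the flattened extent, so no concrete
    # a.numel() value has to be materialized when shapes are symbolic.
    flat_indices = a.reshape(-1)
    flat_output = weight.index_select(0, flat_indices)
    return flat_output.reshape(output_shape)

For the non-padded case this matches torch.nn.functional.embedding, e.g. embedding_without_numel(torch.ones(1, 16, dtype=torch.long), torch.randn(320, 64)).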

@jjsjann123
Collaborator

With #1260, #1450, and #1451:

The script now runs, but we don't get a cache hit on the second iteration, which is a bit odd. It looks like one of the inputs changes dtype before the second iteration. I confirmed the same behavior happens even with static input shapes, so it's presumably just something the model does itself.

But the cache hit does happen afterwards. 🥳
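To pin down which tensor changes dtype, here is a quick sketch (a hypothetical helper using only standard PyTorch, not part of thunder) that diffs buffer dtypes around one forward call; `m`, `inp1`, `inp_pos1` are the objects from the repro script in the next comment:

import torch

def buffer_dtype_diff(module: torch.nn.Module, *args):
    # Snapshot buffer dtypes, run one forward, and report any buffer whose
    # dtype the model changed on its own (e.g. a KV-cache buffer moved to the
    # activation dtype).
    before = {name: buf.dtype for name, buf in module.named_buffers()}
    with torch.no_grad():
        module(*args)
    for name, buf in module.named_buffers():
        if name in before and before[name] != buf.dtype:
            print(f"{name}: {before[name]} -> {buf.dtype}")

# e.g. buffer_dtype_diff(m, inp1, inp_pos1), run before the first jm(...) call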

@jjsjann123
Collaborator

import torch

configs = [
    dict(
        name="llama2-like",
        vocab_size=320,
        padding_multiple=64,
        n_layer=2,
        n_head=4,
        n_embd=64,
        rotary_percentage=1.0,
        parallel_residual=False,
        bias=False,
        norm_class_name="RMSNorm",
        mlp_class_name="LLaMAMLP",
        intermediate_size=1376,
    ),
]

name_to_config = {config["name"]: config for config in configs}

import litgpt

# register the custom config with litgpt, and pull litgpt's built-in configs into the local dict
litgpt.config.name_to_config.update(name_to_config)
name_to_config.update(litgpt.config.name_to_config)

# manually expose for backwards compatibility
Config = litgpt.Config
GPT = litgpt.GPT
RMSNorm = litgpt.model.RMSNorm
CausalSelfAttention = litgpt.model.CausalSelfAttention
LLaMAMLP = litgpt.model.LLaMAMLP
build_rope_cache = litgpt.model.build_rope_cache
apply_rope = litgpt.model.apply_rope
Block = litgpt.model.Block

import thunder

with torch.device('cuda'):
    m = GPT.from_name('llama2-like').bfloat16().requires_grad_(False)
    m.set_kv_cache(1)
inp1 = torch.ones(1, 16, device="cuda", dtype=torch.int32)
inp_pos1 = torch.arange(16, device="cuda")
jm = thunder.jit(m, cache=thunder.core.options.CACHE_OPTIONS.SYMBOLIC_VALUES)
#jm = thunder.jit(m)
out = jm(inp1, inp_pos1)
ref_out = m(inp1, inp_pos1)

#print("\n\tprologue:\n", thunder.last_prologue_traces(jm)[-1])
#print("\n\ttrace:\n", thunder.last_traces(jm)[-1])
print(f"\n\t{thunder.cache_hits(jm)=}")

assert out.allclose(ref_out)
inp1 = torch.ones(1, 17, device="cuda", dtype=torch.int32)
inp_pos1 = torch.arange(17, device="cuda")
out = jm(inp1, inp_pos1)
ref_out = m(inp1, inp_pos1)
assert out.allclose(ref_out)

print(f"\n\t{thunder.cache_hits(jm)=}")

inp1 = torch.ones(1, 18, device="cuda", dtype=torch.int32)
inp_pos1 = torch.arange(18, device="cuda")
out = jm(inp1, inp_pos1)
ref_out = m(inp1, inp_pos1)
assert out.allclose(ref_out)
print(f"\n\t{thunder.cache_hits(jm)=}")

gives

/opt/pytorch/lightning-thunder/thunder/core/options.py:78: The 'symbolic values' cache option is highly experimental and for development only.

        thunder.cache_hits(jm)=0
cache miss: _=AssertionError('expected tensor with (1, 4, 4096, 16), cuda:0, torch.float32, requires_grad=False, got (1, 4, 4096, 16), cuda:0, torch.bfloat16, False')

        thunder.cache_hits(jm)=0

        thunder.cache_hits(jm)=1

@jjsjann123
Collaborator

I think we are good to close this issue. I'll let @t-vi do that.

@t-vi t-vi closed this as completed Dec 9, 2024