Handle HF dynamic cache #1396

Open
t-vi opened this issue Nov 4, 2024 · 0 comments
Assignees: t-vi
Labels: huggingface (For supporting HF models), interpreter, program-coverage (Requests for model and program coverage)

Comments


t-vi (Collaborator) commented Nov 4, 2024

The HF DynamicCache is currently updated during tracing; this is not good.

    import torch
    import thunder
    from transformers.models.llama import LlamaForCausalLM, LlamaConfig
    from transformers import DynamicCache

    model_id = "meta-llama/Llama-3.2-1B"

    config_args = LlamaConfig.get_config_dict(model_id)[0]
    config_args['num_hidden_layers'] = 1
    with torch.device('cuda'):
        model = LlamaForCausalLM(LlamaConfig(**config_args)).to(torch.bfloat16).requires_grad_(False).eval()

    jm = thunder.jit(model)

    j_past_key_values = DynamicCache()
    args1 = dict(
        cache_position=torch.tensor([0, 1, 2, 3, 4, 5], device='cuda:0'),
        input_ids=torch.tensor([[128000,    791,   1401,    311,   2324,    374]], device='cuda:0'),
        inputs_embeds=None,
        attention_mask=torch.tensor([[1, 1, 1, 1, 1, 1]], device='cuda:0'),
        use_cache=True, return_dict=True,
    )
    res = jm(past_key_values=j_past_key_values, **args1)

gives

>       res = jm(past_key_values=j_past_key_values, **args1)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 
/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py:1736: in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py:1747: in _call_impl
    return forward_call(*args, **kwargs)
thunder/core/module.py:80: in forward
    res = self._forward_fn(*args, **kwargs)
thunder/__init__.py:736: in fn_
    cache_entry, inps, pro_to_epi = get_computation_and_inputs(*args, **kwargs)
thunder/core/langctxs.py:136: in _fn
    result = fn(*args, **kwargs)
thunder/__init__.py:232: in cache_info_wrapper
    res = fn(*args, **kwargs)
thunder/__init__.py:633: in get_computation_and_inputs
    inps, pro_to_epi = pro(*args, **kwargs)
thunder.prologue_0:190: in prologue
    check_number_type_and_value(i192, 0)
thunder/executors/pythonex.py:103: in _check_number_type_and_value_impl
    utils.check(
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

cond = False, s = <function _check_number_type_and_value_impl.<locals>.<lambda> at 0x7fc162c6bf60>, exception_type = <class 'RuntimeError'>

    def check(cond: bool, s: Callable[[], str], exception_type: type[Exception] = RuntimeError) -> None:
        """Helper function for raising an error_type (default: RuntimeError) if a boolean condition fails.
    
        s is a callable producing a string to avoid string construction if the error check is passed.
        """
        if not cond:
>           raise exception_type(s())
E           RuntimeError: Expected 6 to be equal to and have the type of 0

thunder/core/baseutils.py:107: RuntimeError

Handling this properly is important for Llama inference with KV caching.
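A minimal sketch of the failure mode, with hypothetical names (`ToyCache`, `trace`, `run` are illustrative, not thunder's actual internals): the prologue records a guard against the cache's state at trace time, but tracing itself executes the model and mutates the cache, so the guard is already stale when the compiled function runs.

```python
# Hypothetical illustration of the bug: a guard captured before tracing
# fails at runtime because tracing mutated the cache it guards.

class ToyCache:
    """Stand-in for HF DynamicCache: tracks how many tokens it has seen."""
    def __init__(self):
        self.seen_tokens = 0

    def update(self, n):
        self.seen_tokens += n


def trace(model_step, cache, n_tokens):
    # The prologue guard captures the cache's state at trace time (0)...
    guards = [("seen_tokens", cache.seen_tokens)]
    # ...but executing the function to trace it mutates the cache (0 -> 6).
    model_step(cache, n_tokens)
    return guards


def run(guards, cache):
    # Prologue check, analogous to check_number_type_and_value.
    for attr, expected in guards:
        actual = getattr(cache, attr)
        if actual != expected:
            raise RuntimeError(
                f"Expected {actual} to be equal to and have the type of {expected}"
            )


cache = ToyCache()
guards = trace(lambda c, n: c.update(n), cache, 6)
try:
    run(guards, cache)
except RuntimeError as e:
    print(e)  # Expected 6 to be equal to and have the type of 0
```

The fix would presumably be to either snapshot and restore the cache around tracing, or to trace the cache update symbolically instead of executing it eagerly.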

t-vi added the interpreter and program-coverage labels Nov 4, 2024
IvanYashchuk added the huggingface label Nov 4, 2024
t-vi self-assigned this Nov 11, 2024