Issues with IPEXModel.from_pretrained with sentence-transformer models (all-MiniLM-L6-v2, intfloat/e5-mistral-7b-instruct) #810
Comments
Hi @rbrugaro. The 1st problem is from … . The 2nd problem is from the jit trace: you can see that the jit inputs contain position_ids, so they have to be passed in explicitly. The following script works:

```python
import torch
import torch.nn.functional as F
from torch import Tensor
from transformers import AutoTokenizer, AutoModel, MistralForCausalLM
from optimum.intel import IPEXModel


def last_token_pool(last_hidden_states: Tensor,
                    attention_mask: Tensor) -> Tensor:
    left_padding = (attention_mask[:, -1].sum() == attention_mask.shape[0])
    if left_padding:
        return last_hidden_states[:, -1]
    else:
        sequence_lengths = attention_mask.sum(dim=1) - 1
        batch_size = last_hidden_states.shape[0]
        return last_hidden_states[torch.arange(batch_size, device=last_hidden_states.device), sequence_lengths]


def get_detailed_instruct(task_description: str, query: str) -> str:
    return f'Instruct: {task_description}\nQuery: {query}'


# Each query must come with a one-sentence instruction that describes the task
task = 'Given a web search query, retrieve relevant passages that answer the query'
queries = [
    get_detailed_instruct(task, 'how much protein should a female eat'),
    get_detailed_instruct(task, 'summit define')
]
# No need to add instruction for retrieval documents
documents = [
    "As a general guideline, the CDC's average requirement of protein for women ages 19 to 70 is 46 grams per day. But, as you can see from this chart, you'll need to increase that if you're expecting or training for a marathon. Check out the chart below to see how much protein you should be eating each day.",
    "Definition of summit for English Language Learners. : 1 the highest point of a mountain : the top of a mountain. : 2 the highest level. : 3 a meeting or series of meetings between the leaders of two or more governments."
]
input_texts = queries + documents

tokenizer = AutoTokenizer.from_pretrained('intfloat/e5-mistral-7b-instruct')
# model = AutoModel.from_pretrained('intfloat/e5-mistral-7b-instruct')  # works fine
model = IPEXModel.from_pretrained('intfloat/e5-mistral-7b-instruct')

max_length = 4096
# Tokenize the input texts
batch_dict = tokenizer(input_texts, max_length=max_length, padding=True, truncation=True, return_tensors='pt')
# The traced model expects position_ids among its inputs, so build them explicitly
position_ids = MistralForCausalLM.prepare_inputs_for_generation(MistralForCausalLM, **batch_dict)["position_ids"]
outputs = model(**batch_dict, position_ids=position_ids)
embeddings = last_token_pool(outputs.last_hidden_state, batch_dict['attention_mask'])

# normalize embeddings
embeddings = F.normalize(embeddings, p=2, dim=1)
scores = (embeddings[:2] @ embeddings[2:].T) * 100
print(scores.tolist())
```
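For anyone reading along: what prepare_inputs_for_generation contributes here is just position_ids derived from the attention mask. A small self-contained sketch of that derivation (an assumption based on the Mistral implementation in transformers around this version; the tensors are illustrative):

```python
import torch

# Sketch (assumption): derive position_ids from the attention mask, mirroring
# what Mistral's prepare_inputs_for_generation does for full-sequence inputs.
attention_mask = torch.tensor([[0, 0, 1, 1, 1],   # left-padded row
                               [1, 1, 1, 1, 1]])  # unpadded row
position_ids = attention_mask.long().cumsum(-1) - 1  # running index of real tokens
position_ids.masked_fill_(attention_mask == 0, 1)    # padded slots get a dummy index
print(position_ids)
# tensor([[1, 1, 0, 1, 2],
#         [0, 1, 2, 3, 4]])
```

Either way, the key point is the same: the traced model's example inputs include position_ids, so they must be supplied explicitly at call time instead of being created inside the model.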
Thanks @jiqing-feng! I verified both issues are fixed.
Original issue:
I tried two different sentence-transformers models, and both fail during the conversion to a TorchScript model. I've already applied the patch for issue #797. A minimal sketch of the first test is below.
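The exact contents of test_ST.py aren't shown in this thread, so the following repro is an assumption: a plain feature-extraction load of all-MiniLM-L6-v2 through IPEXModel, which is the pattern that triggers the crash during tracing.

```python
# Hypothetical minimal repro (assumed shape of test_ST.py; not the actual file).
from transformers import AutoTokenizer
from optimum.intel import IPEXModel

model_id = "sentence-transformers/all-MiniLM-L6-v2"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = IPEXModel.from_pretrained(model_id)  # aborts during the TorchScript conversion

inputs = tokenizer(["This is a test sentence."], return_tensors="pt")
outputs = model(**inputs)  # never reached; the crash happens at load/trace time
```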
First model (all-MiniLM-L6-v2) error:
free(): corrupted unsorted chunks
LIBXSMM_VERSION: unconfigured (2147483647)
SPR/SP TRY JIT STA COL
0..13 4 4 0 0
14..23 0 0 0 0
24..64 24 24 0 0
Registry and code: 13 MB + 352 KB (gemm=28 gemv=4 meltw=12)
Command: python /home/rbrugaro/optimum-intel/notebooks/ipex/test_ST.py
Uptime: 0.130897 s
Aborted (core dumped)
Second model (intfloat/e5-mistral-7b-instruct), tested with the usage code from https://huggingface.co/intfloat/e5-mistral-7b-instruct, error:
$ python /home/rbrugaro/optimum-intel/notebooks/ipex/test_ST2.py
/lib/python3.10/site-packages/huggingface_hub/file_download.py:1132: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.
  warnings.warn(
Detect torchscript is false. Convert to torchscript model!
Framework not specified. Using pt to export the model.
Loading checkpoint shards: 100%|██████████| 2/2 [00:02<00:00, 1.10s/it]
Passing the argument `library_name` to `get_supported_tasks_for_model_type` is required, but got library_name=None. Defaulting to `transformers`. An error will be raised in a future version of Optimum if `library_name` is not provided.
/home/rbrugaro/anaconda3/envs/opti441LANG/lib/python3.10/site-packages/transformers/modeling_utils.py:4481: FutureWarning: `_is_quantized_training_enabled` is going to be deprecated in transformers 4.39.0. Please use `model.hf_quantizer.is_trainable` instead
  warnings.warn(
/home/rbrugaro/anaconda3/envs/opti441LANG/lib/python3.10/site-packages/transformers/modeling_attn_mask_utils.py:276: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
elif sliding_window is None or key_value_length < sliding_window:
/home/rbrugaro/anaconda3/envs/opti441LANG/lib/python3.10/site-packages/transformers/modeling_attn_mask_utils.py:114: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
if (input_shape[-1] > 1 or self.sliding_window is not None) and self.is_causal:
/home/rbrugaro/anaconda3/envs/opti441LANG/lib/python3.10/site-packages/transformers/modeling_attn_mask_utils.py:162: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
if past_key_values_length > 0:
/home/rbrugaro/anaconda3/envs/opti441LANG/lib/python3.10/site-packages/transformers/models/mistral/modeling_mistral.py:119: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
if seq_len > self.max_seq_len_cached:
/home/rbrugaro/anaconda3/envs/opti441LANG/lib/python3.10/site-packages/transformers/models/mistral/modeling_mistral.py:662: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
Passing the argument `library_name` to `get_supported_tasks_for_model_type` is required, but got library_name=None. Defaulting to `transformers`. An error will be raised in a future version of Optimum if `library_name` is not provided.
/home/rbrugaro/anaconda3/envs/opti441LANG/lib/python3.10/site-packages/torch/amp/autocast_mode.py:267: UserWarning: In CPU autocast, but the target dtype is not supported. Disabling autocast.
CPU Autocast only supports dtype of torch.bfloat16, torch.float16 currently.
  warnings.warn(error_message)
Traceback (most recent call last):
File "/home/rbrugaro/anaconda3/envs/opti441LANG/lib/python3.10/site-packages/optimum/intel/ipex/modeling_base.py", line 329, in _call_model
out = self.model(*args, **kwargs)
File "/home/rbrugaro/anaconda3/envs/opti441LANG/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/home/rbrugaro/anaconda3/envs/opti441LANG/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
return forward_call(*args, **kwargs)
RuntimeError: The following operation failed in the TorchScript interpreter.
Traceback of TorchScript (most recent call last):
/home/rbrugaro/anaconda3/envs/opti441LANG/lib/python3.10/site-packages/transformers/models/mistral/modeling_mistral.py(965): forward
/home/rbrugaro/anaconda3/envs/opti441LANG/lib/python3.10/site-packages/torch/nn/modules/module.py(1522): _slow_forward
/home/rbrugaro/anaconda3/envs/opti441LANG/lib/python3.10/site-packages/torch/nn/modules/module.py(1541): _call_impl
/home/rbrugaro/anaconda3/envs/opti441LANG/lib/python3.10/site-packages/torch/nn/modules/module.py(1532): _wrapped_call_impl
/home/rbrugaro/anaconda3/envs/opti441LANG/lib/python3.10/site-packages/torch/jit/_trace.py(1076): trace_module
/home/rbrugaro/anaconda3/envs/opti441LANG/lib/python3.10/site-packages/torch/jit/_trace.py(820): trace
/home/rbrugaro/anaconda3/envs/opti441LANG/lib/python3.10/site-packages/intel_extension_for_pytorch/jit/_trace.py(69): wrapper
/home/rbrugaro/anaconda3/envs/opti441LANG/lib/python3.10/site-packages/optimum/intel/ipex/modeling_base.py(118): ipex_jit_trace
/home/rbrugaro/anaconda3/envs/opti441LANG/lib/python3.10/site-packages/optimum/intel/ipex/modeling_base.py(162): __init__
/home/rbrugaro/anaconda3/envs/opti441LANG/lib/python3.10/site-packages/optimum/intel/ipex/modeling_base.py(248): _from_pretrained
/home/rbrugaro/anaconda3/envs/opti441LANG/lib/python3.10/site-packages/optimum/modeling_base.py(427): from_pretrained
/home/rbrugaro/optimum-intel/notebooks/ipex/test_ST2.py(131): <module>
RuntimeError: Expected a proper Tensor but got None (or an undefined Tensor in C++) for argument #0 'self'
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
File "/home/rbrugaro/optimum-intel/notebooks/ipex/test_ST2.py", line 137, in <module>
outputs = model(**batch_dict)
File "/home/rbrugaro/anaconda3/envs/opti441LANG/lib/python3.10/site-packages/optimum/modeling_base.py", line 95, in __call__
return self.forward(*args, **kwargs)
File "/home/rbrugaro/anaconda3/envs/opti441LANG/lib/python3.10/site-packages/optimum/intel/ipex/modeling_base.py", line 291, in forward
outputs = self._call_model(**inputs)
File "/home/rbrugaro/anaconda3/envs/opti441LANG/lib/python3.10/site-packages/optimum/intel/ipex/modeling_base.py", line 331, in _call_model
out = self.model(*args, **kwargs)
File "/home/rbrugaro/anaconda3/envs/opti441LANG/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/home/rbrugaro/anaconda3/envs/opti441LANG/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
return forward_call(*args, **kwargs)
RuntimeError: The following operation failed in the TorchScript interpreter.
Traceback of TorchScript (most recent call last):
/home/rbrugaro/anaconda3/envs/opti441LANG/lib/python3.10/site-packages/transformers/models/mistral/modeling_mistral.py(965): forward
/home/rbrugaro/anaconda3/envs/opti441LANG/lib/python3.10/site-packages/torch/nn/modules/module.py(1522): _slow_forward
/home/rbrugaro/anaconda3/envs/opti441LANG/lib/python3.10/site-packages/torch/nn/modules/module.py(1541): _call_impl
/home/rbrugaro/anaconda3/envs/opti441LANG/lib/python3.10/site-packages/torch/nn/modules/module.py(1532): _wrapped_call_impl
/home/rbrugaro/anaconda3/envs/opti441LANG/lib/python3.10/site-packages/torch/jit/_trace.py(1076): trace_module
/home/rbrugaro/anaconda3/envs/opti441LANG/lib/python3.10/site-packages/torch/jit/_trace.py(820): trace
/home/rbrugaro/anaconda3/envs/opti441LANG/lib/python3.10/site-packages/intel_extension_for_pytorch/jit/_trace.py(69): wrapper
/home/rbrugaro/anaconda3/envs/opti441LANG/lib/python3.10/site-packages/optimum/intel/ipex/modeling_base.py(118): ipex_jit_trace
/home/rbrugaro/anaconda3/envs/opti441LANG/lib/python3.10/site-packages/optimum/intel/ipex/modeling_base.py(162): __init__
/home/rbrugaro/anaconda3/envs/opti441LANG/lib/python3.10/site-packages/optimum/intel/ipex/modeling_base.py(248): _from_pretrained
/home/rbrugaro/anaconda3/envs/opti441LANG/lib/python3.10/site-packages/optimum/modeling_base.py(427): from_pretrained
/home/rbrugaro/optimum-intel/notebooks/ipex/test_ST2.py(131): <module>
RuntimeError: Expected a proper Tensor but got None (or an undefined Tensor in C++) for argument #0 'self'
cc: @jiqing-feng