
Commit 56121f3: consider eager model in input_names

jiqing-feng committed Jan 5, 2024
1 parent dbf4a2f commit 56121f3
Showing 1 changed file: optimum/intel/generation/modeling.py (75 additions, 60 deletions)

@@ -110,7 +110,11 @@ def __init__(
         self._device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
         self.normalized_config = NormalizedConfigManager.get_normalized_config_class(config.model_type)(config)
         self.model_dtype = kwargs.get("model_dtype", None)
-        self.input_names = [inputs.debugName().split(".")[0] for inputs in model.graph.inputs()]
+        self.input_names = (
+            [inputs.debugName().split(".")[0] for inputs in model.graph.inputs()]
+            if isinstance(model, torch.jit.RecursiveScriptModule)
+            else None
+        )
 
         if is_transformers_version("<=", "4.25.1"):
             self.generation_config = None
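
Note on the __init__ change: only a scripted or traced model carries a TorchScript graph whose inputs can be enumerated; an eager nn.Module has no such graph, so input_names now falls back to None and forward (below) passes the arguments straight through. A minimal standalone sketch of the distinction, using a toy module that is not from this repository:

import torch

class Toy(torch.nn.Module):
    def forward(self, input_ids, attention_mask):
        return input_ids * attention_mask

eager = Toy()
scripted = torch.jit.script(eager)  # returns a torch.jit.RecursiveScriptModule

for model in (scripted, eager):
    if isinstance(model, torch.jit.RecursiveScriptModule):
        # Graph inputs start with "self", then the forward arguments;
        # debug names may carry ".N" suffixes, hence the split.
        input_names = [i.debugName().split(".")[0] for i in model.graph.inputs()]
    else:
        input_names = None  # eager model: no graph to inspect
    print(type(model).__name__, "->", input_names)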
@@ -266,65 +270,76 @@ def forward(
         **kwargs,
     ) -> CausalLMOutputWithPast:
         # 1. Prepare model inputs
-        inputs = {}
-        if "input_ids" in self.input_names:
-            if input_ids is None:
-                raise ValueError("input_ids is missing")
-            inputs["input_ids"] = input_ids
-
-        if "attention_mask" in self.input_names:
-            if attention_mask is None:
-                attention_mask = torch.ones_like(input_ids)
-            inputs["attention_mask"] = attention_mask
-
-        model_type = self.config.model_type.replace("_", "-")
-        if "past_key_values" in self.input_names and self.use_cache:
-            if past_key_values is None:
-                nb_pkv = 2
-                num_layers = self.normalized_config.num_layers
-                d_k = self.normalized_config.hidden_size // self.normalized_config.num_attention_heads
-                batch_size = input_ids.shape[0]
-
-                if model_type in {"mistral", "llama"}:
-                    num_attention_heads = self.normalized_config.num_key_value_heads
-                else:
-                    num_attention_heads = self.normalized_config.num_attention_heads
-
-                if model_type == "bloom":
-                    shape_key = (batch_size * num_attention_heads, d_k, 0)
-                    shape_value = (batch_size * num_attention_heads, 0, d_k)
-                    key = torch.empty(size=shape_key, dtype=self.model_dtype, device=self._device)
-                    value = torch.empty(size=shape_value, dtype=self.model_dtype, device=self._device)
-                    past_key_values = tuple(
-                        tuple(key if idx % 2 == 0 else value for idx in range(nb_pkv)) for _ in range(num_layers)
-                    )
-                elif model_type.replace("-", "_") in MULTI_QUERY_ATTN_MODELS:
-                    shape = (batch_size, 0, d_k * 2)
-                    pkv = torch.empty(size=shape, dtype=self.model_dtype, device=self._device)
-                    past_key_values = tuple(pkv for _ in range(num_layers))
-                else:
-                    shape = (batch_size, num_attention_heads, 0, d_k)
-                    pkv = torch.empty(size=shape, dtype=self.model_dtype, device=self._device)
-                    past_key_values = tuple(tuple(pkv for _ in range(nb_pkv)) for _ in range(num_layers))
-
-            inputs["past_key_values"] = past_key_values
-
-        if "position_ids" in self.input_names:
-            if position_ids is None:
-                seq_length = input_ids.shape[-1]
-                if not self.use_cache:
-                    past_key_values_length = 0
-                else:
-                    past_key_values_length = (
-                        past_key_values[0].shape[-2]
-                        if model_type.replace("-", "_") in MULTI_QUERY_ATTN_MODELS
-                        else past_key_values[0][1].shape[-2]
-                    )
-                position_ids = torch.arange(
-                    past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=self._device
-                ).unsqueeze(0)
-
-            inputs["position_ids"] = position_ids
+        if self.input_names is not None:
+            inputs = {}
+            if "input_ids" in self.input_names:
+                if input_ids is None:
+                    raise ValueError("input_ids is missing")
+                inputs["input_ids"] = input_ids
+
+            if "attention_mask" in self.input_names:
+                if attention_mask is None:
+                    attention_mask = torch.ones_like(input_ids)
+                inputs["attention_mask"] = attention_mask
+
+            model_type = self.config.model_type.replace("_", "-")
+            if "past_key_values" in self.input_names and self.use_cache:
+                if past_key_values is None:
+                    nb_pkv = 2
+                    num_layers = self.normalized_config.num_layers
+                    d_k = self.normalized_config.hidden_size // self.normalized_config.num_attention_heads
+                    batch_size = input_ids.shape[0]
+
+                    if model_type in {"mistral", "llama"}:
+                        num_attention_heads = self.normalized_config.num_key_value_heads
+                    else:
+                        num_attention_heads = self.normalized_config.num_attention_heads
+
+                    if model_type == "bloom":
+                        shape_key = (batch_size * num_attention_heads, d_k, 0)
+                        shape_value = (batch_size * num_attention_heads, 0, d_k)
+                        key = torch.empty(size=shape_key, dtype=self.model_dtype, device=self._device)
+                        value = torch.empty(size=shape_value, dtype=self.model_dtype, device=self._device)
+                        past_key_values = tuple(
+                            tuple(key if idx % 2 == 0 else value for idx in range(nb_pkv)) for _ in range(num_layers)
+                        )
+                    elif model_type.replace("-", "_") in MULTI_QUERY_ATTN_MODELS:
+                        shape = (batch_size, 0, d_k * 2)
+                        pkv = torch.empty(size=shape, dtype=self.model_dtype, device=self._device)
+                        past_key_values = tuple(pkv for _ in range(num_layers))
+                    else:
+                        shape = (batch_size, num_attention_heads, 0, d_k)
+                        pkv = torch.empty(size=shape, dtype=self.model_dtype, device=self._device)
+                        past_key_values = tuple(tuple(pkv for _ in range(nb_pkv)) for _ in range(num_layers))
+
+                inputs["past_key_values"] = past_key_values
+
+            if "position_ids" in self.input_names:
+                if position_ids is None:
+                    seq_length = input_ids.shape[-1]
+                    if not self.use_cache:
+                        past_key_values_length = 0
+                    else:
+                        past_key_values_length = (
+                            past_key_values[0].shape[-2]
+                            if model_type.replace("-", "_") in MULTI_QUERY_ATTN_MODELS
+                            else past_key_values[0][1].shape[-2]
+                        )
+                    position_ids = torch.arange(
+                        past_key_values_length,
+                        seq_length + past_key_values_length,
+                        dtype=torch.long,
+                        device=self._device,
+                    ).unsqueeze(0)
+
+                inputs["position_ids"] = position_ids
+        else:
+            inputs = {
+                "input_ids": input_ids,
+                "attention_mask": attention_mask,
+                "past_key_values": past_key_values,
+                "position_ids": position_ids,
+            }
 
         # 2. Model forward
         outputs = self.model(**inputs)
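
Note on the cache priming in the hunk above: when use_cache is set and no past_key_values are passed, a traced graph still expects cache tensors, so zero-length placeholders are allocated, and their layout depends on the architecture (bloom folds heads into the batch dimension; multi-query models fuse key and value into one tensor per layer). The new eager else-branch skips all of this and forwards the arguments unchanged. A standalone sketch of the three layouts; the function name and the use of gpt_bigcode as the multi-query example are illustrative assumptions, not part of the commit:

import torch

def empty_past_key_values(model_type, batch_size, num_heads, d_k, num_layers,
                          dtype=torch.float32, device="cpu"):
    if model_type == "bloom":
        # Heads folded into the batch dim; key is (B*H, d_k, 0), value is (B*H, 0, d_k).
        key = torch.empty(batch_size * num_heads, d_k, 0, dtype=dtype, device=device)
        value = torch.empty(batch_size * num_heads, 0, d_k, dtype=dtype, device=device)
        return tuple((key, value) for _ in range(num_layers))
    if model_type == "gpt_bigcode":  # assumed member of MULTI_QUERY_ATTN_MODELS
        # Multi-query attention: one fused key/value tensor per layer.
        pkv = torch.empty(batch_size, 0, d_k * 2, dtype=dtype, device=device)
        return tuple(pkv for _ in range(num_layers))
    # Default layout: (batch, heads, past_len, head_dim) with past_len == 0.
    pkv = torch.empty(batch_size, num_heads, 0, d_k, dtype=dtype, device=device)
    return tuple((pkv, pkv) for _ in range(num_layers))

print(empty_past_key_values("llama", 1, 32, 128, 2)[0][0].shape)
# torch.Size([1, 32, 0, 128])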
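The position_ids fallback in the same hunk is plain offset arithmetic: newly fed tokens are numbered starting at the cached sequence length. A toy check with assumed values:

import torch

past_len = 5  # tokens already held in the cache
seq_len = 1   # one new token per decoding step
position_ids = torch.arange(past_len, past_len + seq_len, dtype=torch.long).unsqueeze(0)
print(position_ids)  # tensor([[5]]): the new token occupies absolute position 5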
