Skip to content

Commit

Permalink
Automatic torch.autocast for IPEXModel (#542)
Browse files Browse the repository at this point in the history
* Handle autocast in IPEXModel.forward

* Handle missing torch_dtype in config
  • Loading branch information
ofirzaf authored Jan 31, 2024
1 parent 398450d commit 8ee487d
Showing 1 changed file with 19 additions and 6 deletions.
25 changes: 19 additions & 6 deletions optimum/intel/ipex/modeling_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ def __init__(
OptimizedModel.__init__(self, model=model, config=config)
# To do: add XPU support
self._device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
self._dtype = self.config.torch_dtype if self.config.torch_dtype is not None else torch.float32
self.model.to(self._device)
self.model_save_dir = model_save_dir

Expand Down Expand Up @@ -188,7 +189,7 @@ def forward(
if "token_type_ids" in self.input_names:
inputs["token_type_ids"] = token_type_ids

outputs = self.model(**inputs)
outputs = self._call_model(**inputs)
return ModelOutput(**outputs) if isinstance(outputs, dict) else ModelOutput(logits=outputs[0])

def eval(self):
Expand All @@ -199,6 +200,10 @@ def eval(self):
def device(self) -> torch.device:
return self._device

@property
def dtype(self) -> torch.dtype:
    """Torch dtype used for inference: ``config.torch_dtype`` when set, else ``torch.float32``."""
    return self._dtype

def to(self, device: Union[torch.device, str]):
self._device = device if isinstance(device, torch.device) else torch.device(device)
self.model.to(self._device)
Expand All @@ -207,6 +212,14 @@ def to(self, device: Union[torch.device, str]):
def can_generate(self):
    # True only for subclasses that (presumably via transformers') mix in
    # GenerationMixin — plain task models report False.
    return isinstance(self, GenerationMixin)

def _call_model(self, *args, **kwargs):
    """Run the wrapped model, attempting a mixed-precision autocast first.

    Executes the forward pass inside ``torch.autocast`` using this model's
    device type and dtype; if that raises ``RuntimeError`` (e.g. an
    unsupported device/dtype combination for autocast), the forward pass is
    retried without autocast.
    """
    try:
        # Autocast to self.dtype; unsupported combinations raise
        # RuntimeError, handled by the fallback below.
        with torch.autocast(self.device.type, self.dtype):
            out = self.model(*args, **kwargs)
    except RuntimeError:
        # NOTE(review): a RuntimeError raised by the model itself (not by
        # autocast) also lands here and silently re-runs the forward pass,
        # repeating any side effects — confirm this is intended.
        out = self.model(*args, **kwargs)
    return out


class IPEXModelForSequenceClassification(IPEXModel):
auto_model_class = AutoModelForSequenceClassification
Expand Down Expand Up @@ -236,7 +249,7 @@ def forward(
"pixel_values": pixel_values,
}

outputs = self.model(**inputs)
outputs = self._call_model(**inputs)
return ModelOutput(**outputs) if isinstance(outputs, dict) else ModelOutput(logits=outputs[0])


Expand All @@ -257,7 +270,7 @@ def forward(
if "attention_mask" in self.input_names:
inputs["attention_mask"] = attention_mask

outputs = self.model(**inputs)
outputs = self._call_model(**inputs)
return ModelOutput(**outputs) if isinstance(outputs, dict) else ModelOutput(logits=outputs[0])


Expand All @@ -266,7 +279,7 @@ class IPEXModelForQuestionAnswering(IPEXModel):
export_feature = "question-answering"

def forward(self, *args, **kwargs):
    """Run extractive-QA inference and return start/end logits as a ``ModelOutput``.

    Accepts either a dict-like model output (keys ``start_logits`` /
    ``end_logits``) or a tuple/list (positions 0 / 1).
    """
    result = self._call_model(*args, **kwargs)
    if isinstance(result, dict):
        start, end = result["start_logits"], result["end_logits"]
    else:
        start, end = result[0], result[1]
    return ModelOutput(start_logits=start, end_logits=end)
Expand All @@ -287,7 +300,7 @@ def __init__(
super().__init__(model, config, model_save_dir=model_save_dir)

self.normalized_config = NormalizedConfigManager.get_normalized_config_class(config.model_type)(config)
self.model_dtype = kwargs.get("model_dtype", None)
self.model_dtype = kwargs.get("model_dtype", self.dtype)
self.use_cache = "past_key_values" in self.input_names

if use_cache ^ self.use_cache:
Expand Down Expand Up @@ -377,7 +390,7 @@ def forward(
inputs["past_key_values"] = past_key_values

# 2. Model forward
outputs = self.model(**inputs)
outputs = self._call_model(**inputs)

# 3. Process model outputs
if isinstance(outputs, (list, tuple)):
Expand Down

0 comments on commit 8ee487d

Please sign in to comment.