From 8ee487dc2ade5bd0023d1bbe0a0103d6af8821e0 Mon Sep 17 00:00:00 2001 From: Ofir Zafrir Date: Wed, 31 Jan 2024 13:42:14 +0200 Subject: [PATCH] Automatic `torch.autocast` for IPEXModel (#542) * Handle autocast in IPEXModel.forward * Handle missing torch_dtype in config --- optimum/intel/ipex/modeling_base.py | 25 +++++++++++++++++++------ 1 file changed, 19 insertions(+), 6 deletions(-) diff --git a/optimum/intel/ipex/modeling_base.py b/optimum/intel/ipex/modeling_base.py index b79f720348..901e90a421 100644 --- a/optimum/intel/ipex/modeling_base.py +++ b/optimum/intel/ipex/modeling_base.py @@ -69,6 +69,7 @@ def __init__( OptimizedModel.__init__(self, model=model, config=config) # To do: add XPU support self._device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + self._dtype = self.config.torch_dtype if self.config.torch_dtype is not None else torch.float32 self.model.to(self._device) self.model_save_dir = model_save_dir @@ -188,7 +189,7 @@ def forward( if "token_type_ids" in self.input_names: inputs["token_type_ids"] = token_type_ids - outputs = self.model(**inputs) + outputs = self._call_model(**inputs) return ModelOutput(**outputs) if isinstance(outputs, dict) else ModelOutput(logits=outputs[0]) def eval(self): @@ -199,6 +200,10 @@ def eval(self): def device(self) -> torch.device: return self._device + @property + def dtype(self) -> torch.dtype: + return self._dtype + def to(self, device: Union[torch.device, str]): self._device = device if isinstance(device, torch.device) else torch.device(device) self.model.to(self._device) @@ -207,6 +212,14 @@ def to(self, device: Union[torch.device, str]): def can_generate(self): return isinstance(self, GenerationMixin) + def _call_model(self, *args, **kwargs): + try: + with torch.autocast(self.device.type, self.dtype): + out = self.model(*args, **kwargs) + except RuntimeError: + out = self.model(*args, **kwargs) + return out + class IPEXModelForSequenceClassification(IPEXModel): auto_model_class = AutoModelForSequenceClassification @@ -236,7 +249,7 @@ def forward( "pixel_values": pixel_values, } - outputs = self.model(**inputs) + outputs = self._call_model(**inputs) return ModelOutput(**outputs) if isinstance(outputs, dict) else ModelOutput(logits=outputs[0]) @@ -257,7 +270,7 @@ def forward( if "attention_mask" in self.input_names: inputs["attention_mask"] = attention_mask - outputs = self.model(**inputs) + outputs = self._call_model(**inputs) return ModelOutput(**outputs) if isinstance(outputs, dict) else ModelOutput(logits=outputs[0]) @@ -266,7 +279,7 @@ class IPEXModelForQuestionAnswering(IPEXModel): export_feature = "question-answering" def forward(self, *args, **kwargs): - outputs = self.model(*args, **kwargs) + outputs = self._call_model(*args, **kwargs) start_logits = outputs["start_logits"] if isinstance(outputs, dict) else outputs[0] end_logits = outputs["end_logits"] if isinstance(outputs, dict) else outputs[1] return ModelOutput(start_logits=start_logits, end_logits=end_logits) @@ -287,7 +300,7 @@ def __init__( super().__init__(model, config, model_save_dir=model_save_dir) self.normalized_config = NormalizedConfigManager.get_normalized_config_class(config.model_type)(config) - self.model_dtype = kwargs.get("model_dtype", None) + self.model_dtype = kwargs.get("model_dtype", self.dtype) self.use_cache = "past_key_values" in self.input_names if use_cache ^ self.use_cache: @@ -377,7 +390,7 @@ def forward( inputs["past_key_values"] = past_key_values # 2. Model forward - outputs = self.model(**inputs) + outputs = self._call_model(**inputs) # 3. Process model outputs if isinstance(outputs, (list, tuple)):