Skip to content

Commit

Permalink
fix model_dtype
Browse files Browse the repository at this point in the history
  • Loading branch information
jiqing-feng committed Feb 26, 2024
1 parent dd63ee7 commit 16706d3
Showing 1 changed file with 6 additions and 3 deletions.
9 changes: 6 additions & 3 deletions optimum/intel/ipex/modeling_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,7 @@ def _from_transformers(
save_dir_path = Path(save_dir.name)
torch.jit.save(traced_model, save_dir_path / WEIGHTS_NAME)
config.torchscript = True
config.torch_dtype = torch_dtype

return cls._from_pretrained(
model_id=save_dir_path,
Expand All @@ -134,7 +135,6 @@ def _from_transformers(
cache_dir=cache_dir,
local_files_only=local_files_only,
use_cache=use_cache,
model_dtype=torch_dtype,
)

@classmethod
Expand Down Expand Up @@ -208,6 +208,11 @@ def device(self) -> torch.device:
def dtype(self) -> torch.dtype:
return self._dtype

@property
def model_dtype(self):
logger.warning("model_dtype will be removed after v1.18.0")
return self._dtype

def to(self, device: Union[torch.device, str]):
self._device = device if isinstance(device, torch.device) else torch.device(device)
self.model.to(self._device)
Expand Down Expand Up @@ -322,8 +327,6 @@ def __init__(

model_type = config.model_type.replace("_", "-")
self.normalized_config = NormalizedConfigManager.get_normalized_config_class(model_type)(config)
self.model_dtype = kwargs.get("model_dtype", self.dtype)
self._dtype = self.model_dtype
self.use_cache = "past_key_values" in self.input_names

if use_cache ^ self.use_cache:
Expand Down

0 comments on commit 16706d3

Please sign in to comment.