
Commit

set actual device
Signed-off-by: jiqing-feng <jiqing.feng@intel.com>
jiqing-feng committed Dec 18, 2024
1 parent 9a7e931 commit b0cec9c
Showing 2 changed files with 5 additions and 4 deletions.
1 change: 1 addition & 0 deletions optimum/exporters/ipex/model_patcher.py
@@ -133,6 +133,7 @@ def _patch_vit_model(model):
 
 
 def _patch_model(model):
+    setattr(model.config, "device", model.device)
     if is_ipex_version("<", _IPEX_MINIMUM_VERSION_FOR_PATCHING):
         raise ImportError(f"Only ipex version >= {_IPEX_MINIMUM_VERSION_FOR_PATCHING} supports llama model patching")
     if is_transformers_version("<", _TRANSFORMERS_MIN_VERSION) or is_transformers_version(
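The added line records the model's overall device on its config before the submodules are patched, so the _IPEX* wrapper modules in the next file can read it from config.device instead of probing their own parameters. A minimal sketch of the pattern (illustration only, not code from the repository; ToyModel and ToyConfig are hypothetical names):

import torch.nn as nn


class ToyConfig:
    """Hypothetical stand-in for a transformers model config."""


class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.config = ToyConfig()
        self.linear = nn.Linear(4, 4)

    @property
    def device(self):
        # transformers' PreTrainedModel exposes a similar .device property
        return next(self.parameters()).device


model = ToyModel()
# Same pattern as the patched _patch_model: stash the actual device on the config.
setattr(model.config, "device", model.device)
print(model.config.device)  # cpu (or whatever device the model lives on)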
8 changes: 4 additions & 4 deletions optimum/exporters/ipex/modeling_utils.py
@@ -599,7 +599,7 @@ def __init__(self, module, config) -> None:
         super().__init__()
         _setattr_from_module(self, module)
         self.config = config
-        self.module_device = next(module.parameters()).device
+        self.module_device = config.device
         self.num_groups = self.num_heads // self.num_key_value_heads
         self.kv_head_mapping = torch.arange(
             0, self.num_key_value_heads, dtype=torch.int32, device=self.module_device
@@ -779,7 +779,7 @@ def __init__(self, module, config) -> None:
         super().__init__()
         _setattr_from_module(self, module)
         self.config = config
-        self.module_device = next(module.parameters()).device
+        self.module_device = config.device
         if getattr(config, "quantization_config", None) is None:
             if self.module_device.type == "cpu":
                 # LinearAllreduce and LinearLayer cannot use fused op LinearAdd
@@ -812,7 +812,7 @@ def __init__(self, module, config) -> None:
         super().__init__()
         _setattr_from_module(self, module)
         self.config = config
-        self.module_device = next(module.parameters()).device
+        self.module_device = config.device
         if getattr(config, "quantization_config", None) is None:
             # LinearAllreduce and LinearLayer cannot use fused op LinearAdd
             if self.module_device.type == "cpu":
@@ -911,7 +911,7 @@ class _IPEXIntermediate(nn.Module):
     def __init__(self, module, config):
         super().__init__()
         _setattr_from_module(self, module)
-        self.module_device = next(module.parameters()).device
+        self.module_device = config.device
         if getattr(config, "quantization_config", None) is None:
             if self.module_device.type == "cpu":
                 self.linear_gelu = LinearGelu(module.dense)
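A note on the likely rationale (an assumption on my part; the commit message only says "set actual device"): looking the device up with next(module.parameters()).device only works when the wrapped submodule owns parameters and those parameters already live on the device the model will actually run on, whereas the device stashed on the config by _patch_model is always available. A minimal sketch of two cases where the parameter-based lookup is fragile (illustration only, not repository code):

import torch
import torch.nn as nn

# Case 1: a submodule with no parameters has nothing to inspect.
relu = nn.ReLU()
try:
    device = next(relu.parameters()).device
except StopIteration:
    device = None  # the parameter-based lookup cannot work here
print(device)  # None

# Case 2: parameters created on the meta device (e.g. deferred weight loading)
# do not reflect the device inference will eventually run on.
with torch.device("meta"):
    linear = nn.Linear(4, 4)
print(next(linear.parameters()).device)  # meta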
