Skip to content

Commit

Permalink
IPEX page attention XPU support bug fix (#1053)
Browse files Browse the repository at this point in the history
* fix ipex xpu support issues

Signed-off-by: Liu, Kaixuan <kaixuan.liu@intel.com>

* use `device_map`

Signed-off-by: Liu, Kaixuan <kaixuan.liu@intel.com>

* small adjustment

Signed-off-by: Liu, Kaixuan <kaixuan.liu@intel.com>

* make compatible with OpenVINO

Signed-off-by: Liu, Kaixuan <kaixuan.liu@intel.com>

* fix format

Signed-off-by: Liu, Kaixuan <kaixuan.liu@intel.com>

* refine code

Signed-off-by: Liu, Kaixuan <kaixuan.liu@intel.com>

* Update tests/ipex/test_modeling.py

* update code

Signed-off-by: Liu, Kaixuan <kaixuan.liu@intel.com>

---------

Signed-off-by: Liu, Kaixuan <kaixuan.liu@intel.com>
Co-authored-by: Ilyas Moutawwakil <57442720+IlyasMoutawwakil@users.noreply.github.com>
  • Loading branch information
kaixuanliu and IlyasMoutawwakil authored Dec 9, 2024
1 parent a6cb0c0 commit 6ea6b5d
Show file tree
Hide file tree
Showing 4 changed files with 104 additions and 79 deletions.
22 changes: 13 additions & 9 deletions optimum/exporters/ipex/modeling_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -744,13 +744,13 @@ def __init__(self, module, config) -> None:
super().__init__()
_setattr_from_module(self, module)
self.config = config
self.module_device = next(module.parameters()).device.type
if self.module_device == "cpu":
self.module_device = next(module.parameters()).device
if self.module_device.type == "cpu":
# LinearAllreduce and LinearLayer cannot use fused op LinearAdd
if module.down_proj.__class__.__name__ not in ["LinearAllreduce"]:
self.mlp_linear_add = LinearAdd(module.down_proj)
self.linear_silu_mul = Linear2SiluMul(module.gate_proj, module.up_proj)
elif self.module_device == "xpu":
elif self.module_device.type == "xpu":
# LinearAllreduce and LinearLayer cannot use fused op LinearAdd
if module.down_proj.__class__.__name__ not in ["LinearAllreduce"]:
self.mlp_linear_add = XPULinearAdd(module.down_proj)
Expand All @@ -777,15 +777,15 @@ def __init__(self, module, config) -> None:
_setattr_from_module(self, module)
self.config = config
# LinearAllreduce and LinearLayer cannot use fused op LinearAdd
self.module_device = next(module.parameters()).device.type
if self.module_device == "cpu":
self.module_device = next(module.parameters()).device
if self.module_device.type == "cpu":
self.linear_gelu = LinearGelu(module.dense_h_to_4h)
elif self.module_device == "xpu":
elif self.module_device.type == "xpu":
self.linear_gelu = XPULinearGelu(module.dense_h_to_4h)
if module.dense_4h_to_h.__class__.__name__ not in ["LinearAllreduce"]:
if self.module_device == "cpu":
if self.module_device.type == "cpu":
self.linear_add_add = LinearAddAdd(module.dense_4h_to_h)
elif self.module_device == "xpu":
elif self.module_device.type == "xpu":
self.linear_add_add = XPUlinearAddAdd(module.dense_4h_to_h)

def forward(
Expand Down Expand Up @@ -870,7 +870,11 @@ class _IPEXIntermediate(nn.Module):
def __init__(self, module, config):
super().__init__()
_setattr_from_module(self, module)
self.linear_gelu = LinearGelu(module.dense)
self.module_device = next(module.parameters()).device
if self.module_device.type == "cpu":
self.linear_gelu = LinearGelu(module.dense)
elif self.module_device.type == "xpu":
self.linear_gelu = XPULinearGelu(module.dense)

def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.linear_gelu(hidden_states)
Expand Down
9 changes: 7 additions & 2 deletions optimum/intel/pipelines/pipeline_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,14 +246,17 @@ def load_ipex_model(
SUPPORTED_TASKS,
hub_kwargs: Optional[Dict[str, Any]] = None,
model_kwargs: Optional[Dict[str, Any]] = None,
device_map: Optional[torch.device] = None,
):
hub_kwargs = hub_kwargs or {}
model_kwargs = model_kwargs or {}
ipex_model_class = SUPPORTED_TASKS[targeted_task]["class"][0]

if model is None:
model_id = SUPPORTED_TASKS[targeted_task]["default"]
model = ipex_model_class.from_pretrained(model_id, export=True, **hub_kwargs, **model_kwargs)
model = ipex_model_class.from_pretrained(
model_id, export=True, **hub_kwargs, **model_kwargs, device_map=device_map
)
elif isinstance(model, str):
model_id = model
try:
Expand All @@ -262,7 +265,9 @@ def load_ipex_model(
except RuntimeError:
logger.warning("We will use IPEXModel with export=True to export the model")
export = True
model = ipex_model_class.from_pretrained(model, export=export, **hub_kwargs, **model_kwargs)
model = ipex_model_class.from_pretrained(
model, export=export, **hub_kwargs, **model_kwargs, device_map=device_map
)
elif isinstance(model, IPEXModel):
model_id = getattr(model.config, "name_or_path", None)
else:
Expand Down
Loading

0 comments on commit 6ea6b5d

Please sign in to comment.