Skip to content

Commit

Permalink
Use parent_model._device in OVEncoder and OVDecoder
Browse files Browse the repository at this point in the history
  • Loading branch information
helena-intel committed Jan 22, 2024
1 parent bc87f2f commit 218b9a9
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 8 deletions.
14 changes: 7 additions & 7 deletions optimum/intel/openvino/modeling_seq2seq.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,11 +266,11 @@ def __init__(
self.device = torch.device("cpu")
self.decoder_with_past = None
enable_compilation = kwargs.get("compile", True)
self.encoder = OVEncoder(self.encoder_model, self._device, parent_model=self)
self.decoder = OVDecoder(self.decoder_model, self._device, parent_model=self)
self.encoder = OVEncoder(self.encoder_model, parent_model=self)
self.decoder = OVDecoder(self.decoder_model, parent_model=self)

if self.use_cache:
self.decoder_with_past = OVDecoder(self.decoder_with_past_model, self._device, parent_model=self)
self.decoder_with_past = OVDecoder(self.decoder_with_past_model, parent_model=self)
if enable_compilation:
self.compile()

Expand Down Expand Up @@ -412,10 +412,10 @@ class OVEncoder:
The OpenVINO inference request associated to the encoder.
"""

def __init__(self, model: openvino.runtime.Model, device: str, parent_model: OVModelForSeq2SeqLM):
def __init__(self, model: openvino.runtime.Model, parent_model: OVModelForSeq2SeqLM):
self.model = model
self._device = device
self.parent_model = parent_model
self._device = self.parent_model._device
self.device = torch.device("cpu")
self.input_names = {key.get_any_name(): idx for idx, key in enumerate(self.model.inputs)}
self.main_input_name = self.parent_model.main_input_name or "input_ids"
Expand Down Expand Up @@ -477,10 +477,10 @@ class OVDecoder:
The device type used by this process.
"""

def __init__(self, model: openvino.runtime.Model, device: str, parent_model: OVModelForSeq2SeqLM):
def __init__(self, model: openvino.runtime.Model, parent_model: OVModelForSeq2SeqLM):
self.model = model
self._device = device
self.parent_model = parent_model
self._device = self.parent_model._device
self.device = torch.device("cpu")
self.input_names = {key.get_any_name(): idx for idx, key in enumerate(self.model.inputs)}
self.key_value_input_names = [key for key in self.input_names if "key_values" in key]
Expand Down
2 changes: 1 addition & 1 deletion tests/openvino/test_modeling.py
Original file line number Diff line number Diff line change
Expand Up @@ -593,8 +593,8 @@ def test_compare_with_and_without_past_key_values(self):
gc.collect()

def test_auto_device_loading(self):
OV_MODEL_ID = "echarlaix/distilbert-base-uncased-finetuned-sst-2-english-openvino"
for device in ("AUTO", "AUTO:CPU"):
OV_MODEL_ID = "echarlaix/distilbert-base-uncased-finetuned-sst-2-english-openvino"
model = OVModelForSequenceClassification.from_pretrained(OV_MODEL_ID, device=device)
model.half()
self.assertEqual(model._device, device)
Expand Down

0 comments on commit 218b9a9

Please sign in to comment.