Skip to content

Commit

Permalink
Fix deprecated shared memory arg (huggingface#499)
Browse files Browse the repository at this point in the history
  • Loading branch information
sanbuphy authored and PenghuiCheng committed Jan 16, 2024
1 parent 009c6ce commit 6178c69
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 9 deletions.
2 changes: 1 addition & 1 deletion optimum/intel/openvino/modeling_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -388,7 +388,7 @@ def forward(
inputs["position_ids"] = position_ids

# Run inference
self.request.start_async(inputs, shared_memory=True)
self.request.start_async(inputs, share_inputs=True)
self.request.wait()
logits = torch.from_numpy(self.request.get_tensor("logits").data).to(self.device)

Expand Down
8 changes: 4 additions & 4 deletions optimum/intel/openvino/modeling_diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -570,7 +570,7 @@ def __call__(self, input_ids: np.ndarray):
inputs = {
"input_ids": input_ids,
}
outputs = self.request(inputs, shared_memory=True)
outputs = self.request(inputs, share_inputs=True)
return list(outputs.values())


Expand Down Expand Up @@ -604,7 +604,7 @@ def __call__(
if timestep_cond is not None:
inputs["timestep_cond"] = timestep_cond

outputs = self.request(inputs, shared_memory=True)
outputs = self.request(inputs, share_inputs=True)
return list(outputs.values())


Expand All @@ -620,7 +620,7 @@ def __call__(self, latent_sample: np.ndarray):
inputs = {
"latent_sample": latent_sample,
}
outputs = self.request(inputs, shared_memory=True)
outputs = self.request(inputs, share_inputs=True)
return list(outputs.values())

def _compile(self):
Expand All @@ -641,7 +641,7 @@ def __call__(self, sample: np.ndarray):
inputs = {
"sample": sample,
}
outputs = self.request(inputs, shared_memory=True)
outputs = self.request(inputs, share_inputs=True)
return list(outputs.values())

def _compile(self):
Expand Down
8 changes: 4 additions & 4 deletions optimum/intel/openvino/quantization.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,18 +304,18 @@ def __call__(self, *args, **kwargs):
data_cache.append(*args)
return self.request(*args, *kwargs)

def infer(self, inputs: Any = None, shared_memory: bool = False):
def infer(self, inputs: Any = None, share_inputs: bool = False):
data_cache.append(inputs)
return self.request.infer(inputs, shared_memory)
return self.request.infer(inputs, share_inputs)

def start_async(
self,
inputs: Any = None,
userdata: Any = None,
shared_memory: bool = False,
share_inputs: bool = False,
):
data_cache.append(inputs)
self.request.infer(inputs, shared_memory)
self.request.infer(inputs, share_inputs)

def wait(self):
pass
Expand Down

0 comments on commit 6178c69

Please sign in to comment.