
Commit

force precision using --weight-format
eaidova committed Aug 13, 2024
1 parent c0ef027 commit afc5dc7
Showing 5 changed files with 8 additions and 4 deletions.
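For context, the net effect of this commit is that an explicit precision request (--weight-format fp16 or fp32 from the export CLI, carried through ov_config.dtype) now takes precedence over the checkpoint's torch_dtype when the export precision is chosen. Below is a minimal sketch of that resolution order; resolve_export_dtype is a hypothetical helper written only for illustration, not a function in the codebase.

# Illustrative sketch only: summarizes the dtype-selection order added to
# optimum/exporters/openvino/__main__.py by this commit.
import torch

def resolve_export_dtype(ov_config_dtype, checkpoint_dtype, ov_ge_2024_2=True, ov_ge_2024_3=True):
    # 1. An explicit fp16/fp32 request (e.g. via --weight-format) wins.
    if ov_config_dtype in {"fp16", "fp32"}:
        return torch.float16 if ov_config_dtype == "fp16" else torch.float32
    # 2. Otherwise fall back to the checkpoint's torch_dtype, gated on the OpenVINO version.
    if ov_ge_2024_2 and checkpoint_dtype == torch.float16:
        return torch.float16
    if ov_ge_2024_3 and checkpoint_dtype == torch.bfloat16:
        return torch.bfloat16
    return None

In practice this would be exercised from the CLI, e.g. optimum-cli export openvino --model <model_id> --weight-format fp16 <output_dir>; the flag name comes from the commit title, while the exact invocation shown here is assumed.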
6 changes: 4 additions & 2 deletions optimum/exporters/openvino/__main__.py
@@ -301,9 +301,11 @@ def main_export(
     and task.startswith("text-generation")
     and getattr(config, "torch_dtype", torch.float32) in [torch.float16, torch.bfloat16]
 ):
-    if is_openvino_version(">=", "2024.2") and config.torch_dtype == torch.float16:
+    if ov_config is not None and ov_config.dtype in {"fp16", "fp32"}:
+        dtype = torch.float16 if ov_config.dtype == "fp16" else torch.float32
+    elif is_openvino_version(">=", "2024.2") and config.torch_dtype == torch.float16:
         dtype = torch.float16
-    if is_openvino_version(">=", "2024.3") and config.torch_dtype == torch.bfloat16:
+    elif is_openvino_version(">=", "2024.3") and config.torch_dtype == torch.bfloat16:
         dtype = torch.bfloat16
 
     if dtype is not None:
1 change: 1 addition & 0 deletions optimum/exporters/openvino/convert.py
@@ -385,6 +385,7 @@ def ts_patched_forward(*args, **kwargs):
 with patcher:
     if patch_16bit_model:
         from openvino.frontend.pytorch.patch_model import __make_16bit_traceable
+
         __make_16bit_traceable(model)
     check_dummy_inputs_are_allowed(model, dummy_inputs)
     sig = inspect.signature(model.forward) if hasattr(model, "forward") else inspect.signature(model.call)
2 changes: 1 addition & 1 deletion optimum/exporters/openvino/model_configs.py
@@ -66,8 +66,8 @@
     PersimmonModelPatcher,
     Phi3ModelPatcher,
     QwenModelPatcher,
-    UpdateCausalMaskModelPatcher,
     RotaryEmbPatcher,
+    UpdateCausalMaskModelPatcher,
     XverseModelPatcher,
 )

1 change: 1 addition & 0 deletions optimum/exporters/openvino/model_patcher.py
@@ -105,6 +105,7 @@ def patch_update_causal_mask(model, transformers_version):
     if is_transformers_version(">=", transformers_version):
         model.model._update_causal_mask = types.MethodType(_llama_gemma_update_causal_mask, model.model)
 
+
 # initialization of sin/cos cached in bf16/fp16 leads to accuracy loss
 # reinitialize them to save in float32 before export
 def _reinitialize_cos_sin_cached_fp32(rotary_emb):
2 changes: 1 addition & 1 deletion optimum/intel/openvino/modeling_decoder.py
@@ -281,7 +281,7 @@ def _from_transformers(
 if load_in_8bit is None and not quantization_config:
     ov_export_config = None
 else:
-    ov_export_config = OVConfig(dtype="fp32")
+    ov_export_config = OVConfig(dtype="auto")
 
 stateful = kwargs.pop("stateful", ensure_stateful_is_available(warn=False) and use_cache)
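The modeling_decoder.py change pairs with the exporter change above: when neither load_in_8bit nor a quantization config is requested, the fallback export config no longer pins fp32, so the checkpoint precision or an explicit --weight-format can take effect. A hedged illustration, assuming OVConfig is importable from optimum.intel and that "auto" defers the dtype choice to the exporter:

from optimum.intel import OVConfig  # import path assumed from the optimum-intel public API

# Before this commit the fallback export config forced full precision:
#     ov_export_config = OVConfig(dtype="fp32")
# After it, the precision decision is deferred:
ov_export_config = OVConfig(dtype="auto")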

