diff --git a/tests/test_adapt.py b/tests/test_adapt.py index fb9fa7f3..b8c9e914 100644 --- a/tests/test_adapt.py +++ b/tests/test_adapt.py @@ -4,6 +4,7 @@ import numpy as np import onnx import onnx.parser +import onnxruntime as ort import pytest import spox.opset.ai.onnx.v18 as op18 @@ -71,32 +72,34 @@ def inline_old_identity_twice_graph(old_identity): return results(final=z).with_opset(("ai.onnx", 17)) -@pytest.fixture -def old_squeeze_graph(old_squeeze): - class Squeeze11(StandardNode): - @dataclass - class Attributes(BaseAttributes): - axes: AttrInt64s +class Squeeze11(StandardNode): + @dataclass + class Attributes(BaseAttributes): + axes: AttrInt64s + + @dataclass + class Inputs(BaseInputs): + data: Var + + @dataclass + class Outputs(BaseOutputs): + squeezed: Var - @dataclass - class Inputs(BaseInputs): - data: Var + op_type = OpType("Squeeze", "", 11) - @dataclass - class Outputs(BaseOutputs): - squeezed: Var + attrs: Attributes + inputs: Inputs + outputs: Outputs - op_type = OpType("Squeeze", "", 11) - attrs: Attributes - inputs: Inputs - outputs: Outputs +def squeeze11(_data: Var, _axes: Iterable[int]): + return Squeeze11( + Squeeze11.Attributes(AttrInt64s(_axes, "axes")), Squeeze11.Inputs(_data) + ).outputs.squeezed - def squeeze11(_data: Var, _axes: Iterable[int]): - return Squeeze11( - Squeeze11.Attributes(AttrInt64s(_axes, "axes")), Squeeze11.Inputs(_data) - ).outputs.squeezed +@pytest.fixture +def old_squeeze_graph(old_squeeze): (data,) = arguments( data=Tensor( np.float32, @@ -233,3 +236,35 @@ def test_inline_model_custom_node_nested(old_squeeze: onnx.ModelProto): # Add another node to the model to trigger the adaption logic c = op18.identity(b) build({"a": a}, {"c": c}) + + +def test_if_adapatation_squeeze(): + cond = argument(Tensor(np.bool_, ())) + b = argument(Tensor(np.float32, (1,))) + squeezed = squeeze11(b, [0]) + out = op18.if_( + cond, + then_branch=lambda: [squeezed], + else_branch=lambda: [squeeze11(b, [0])], + ) + model = build({"b": b, "cond": cond}, {"out": out[0]}) + + # predict on model + b = np.array([1.1], dtype=np.float32) + cond = np.array(True, dtype=np.bool_) + out = ort.InferenceSession(model.SerializeToString()).run( + None, {"b": b, "cond": cond} + ) + + +def test_if_adaptation_const(): + sq = op19.const(1.1453, dtype=np.float32) + b = argument(Tensor(np.float32, ("N",))) + cond = op18.equal(sq, b) + out = op18.if_(cond, then_branch=lambda: [sq], else_branch=lambda: [sq]) + model = build({"b": b}, {"out": out[0]}) + assert model.domain == "" or model.domain == "ai.onnx" + assert ( + model.opset_import[0].domain == "ai.onnx" or model.opset_import[0].domain == "" + ) + assert model.opset_import[0].version > 11