Skip to content

Commit

Permalink
Documentation (#78)
Browse files Browse the repository at this point in the history
* update requirements

* Add ConstantOfShape to light API

* add slice

* changelogs

* k
  • Loading branch information
xadupre authored Feb 22, 2024
1 parent 2dd0686 commit a906010
Show file tree
Hide file tree
Showing 4 changed files with 38 additions and 2 deletions.
1 change: 1 addition & 0 deletions CHANGELOGS.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ Change Logs
0.2.0
+++++

* :pr:`77`: supports ConcatOfShape and Slice with the light API
* :pr:`76`: add a mode to compare models without execution
* :pr:`75`: add QuickGelu to ExtendedReferenceEvaluator
* :pr:`71`: adds tools to compare two onnx graphs
Expand Down
30 changes: 29 additions & 1 deletion _unittests/ut_light_api/test_light_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -538,7 +538,35 @@ def test_constant_of_shape(self):
got = ref.run(None, {"X": np.array([2, 3], dtype=np.int64)})[0]
self.assertEqualArray(np.zeros((2, 3), dtype=np.float32), got)

def test_constant_of_shape_value(self):
onx = (
start()
.vin("X", TensorProto.INT64, shape=[None, None])
.ConstantOfShape(value=np.array([1], dtype=np.float32))
.vout(shape=[])
.to_onnx()
)
ref = ReferenceEvaluator(onx)
got = ref.run(None, {"X": np.array([2, 3], dtype=np.int64)})[0]
self.assertEqualArray(np.ones((2, 3), dtype=np.float32), got)

def test_slice(self):
onx = (
start(opset=18, ir_version=9)
.cst(np.array([1], dtype=np.int64), name="one")
.cst(np.array([2], dtype=np.int64), name="two")
.vin("X", TensorProto.INT64, shape=[None, None])
.ConstantOfShape(value=np.array([1], dtype=np.float32))
.rename("CX")
.bring("CX", "one", "two", "one")
.Slice()
.vout(shape=[])
.to_onnx()
)
ref = ReferenceEvaluator(onx)
got = ref.run(None, {"X": np.array([2, 3], dtype=np.int64)})[0]
self.assertEqualArray(np.ones((2, 1), dtype=np.float32), got)


if __name__ == "__main__":
TestLightApi().test_add()
unittest.main(verbosity=2)
7 changes: 7 additions & 0 deletions onnx_array_api/light_api/_op_var.py
Original file line number Diff line number Diff line change
Expand Up @@ -314,6 +314,13 @@ def Selu(
def Shrink(self, bias: float = 0.0, lambd: float = 0.5) -> "Var":
return self.make_node("Shrink", self, bias=bias, lambd=lambd)

def Slice(
self, starts: "Var", ends: "Var", axes: "Var", steps: Optional["Var"] = None
) -> "Var":
if steps is None:
return self.make_node("Slice", self, starts, ends, axes)
return self.make_node("Slice", self, starts, ends, axes, steps)

def Softmax(self, axis: int = -1) -> "Var":
return self.make_node("Softmax", self, axis=axis)

Expand Down
2 changes: 1 addition & 1 deletion onnx_array_api/light_api/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -406,7 +406,7 @@ def to_onnx(self) -> GRAPH_PROTO:
return graph
model = make_model(graph, opset_imports=opsets)
if self.ir_version:
model.ir_version = ir_version
model.ir_version = self.ir_version
if not is_windows() or not is_azure():
# check_model fails sometimes on Windows
check_model(model)
Expand Down

0 comments on commit a906010

Please sign in to comment.