Skip to content

Commit

Permalink
Add discrepancies when comparing the execution of two models (#79)
Browse files Browse the repository at this point in the history
* update requirements

* add discrepancies figures

* fix command line

* doc
  • Loading branch information
xadupre authored Feb 28, 2024
1 parent a906010 commit dcc2ddd
Show file tree
Hide file tree
Showing 4 changed files with 81 additions and 9 deletions.
2 changes: 1 addition & 1 deletion CHANGELOGS.rst
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ Change Logs
+++++

* :pr:`77`: supports ConcatOfShape and Slice with the light API
* :pr:`76`: add a mode to compare models without execution
* :pr:`76`, :pr:`79`: add a mode to compare models without execution
* :pr:`75`: add QuickGelu to ExtendedReferenceEvaluator
* :pr:`71`: adds tools to compare two onnx graphs
* :pr:`61`: adds function to plot onnx model as graphs
Expand Down
25 changes: 25 additions & 0 deletions _unittests/ut_reference/test_evaluator_yield.py
Original file line number Diff line number Diff line change
Expand Up @@ -462,6 +462,31 @@ def test_compare_execution(self):
self.assertIn("CAAA Constant", text)
self.assertEqual(len(align), 5)

def test_compare_execution_discrepancies(self):
    """
    Compares two models with ``keep_tensor=True`` and checks that
    the rendered comparison contains the absolute (``a=``) and
    relative (``r=``) errors next to the aligned results.
    """
    model_with_add = parse_model(
        """
<ir_version: 8, opset_import: [ "": 18]>
agraph (float[N] x) => (float[N] z) {
two = Constant <value_float=2.0> ()
four = Add(two, two)
z = Mul(x, x)
}"""
    )
    model_without_add = parse_model(
        """
<ir_version: 8, opset_import: [ "": 18]>
agraph (float[N] x) => (float[N] z) {
two = Constant <value_float=2.0> ()
z = Mul(x, x)
}"""
    )
    results_1, results_2, alignment, distance = compare_onnx_execution(
        model_with_add, model_without_add, keep_tensor=True
    )
    text = distance.to_str(results_1, results_2, alignment)
    print(text)
    # The constant is summarized the same way in both models and
    # the discrepancies columns are appended to the aligned rows.
    for fragment in ("CAAA Constant", "| a=", " r="):
        self.assertIn(fragment, text)

def test_no_execution(self):
model = make_model(
make_graph(
Expand Down
14 changes: 12 additions & 2 deletions onnx_array_api/_command_lines_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,9 +106,15 @@ def get_parser_compare() -> ArgumentParser:
parser.add_argument(
"-c",
"--column-size",
default=50,
default=60,
help="column size when displaying the results",
)
parser.add_argument(
"-d",
"--discrepancies",
default=0,
help="show precise discrepancies when mode is execution",
)
return parser


Expand All @@ -120,7 +126,11 @@ def _cmd_compare(argv: List[Any]):
onx1 = onnx.load(args.model1)
onx2 = onnx.load(args.model2)
res1, res2, align, dc = compare_onnx_execution(
onx1, onx2, verbose=args.verbose, mode=args.mode
onx1,
onx2,
verbose=args.verbose,
mode=args.mode,
keep_tensor=args.discrepancies in (1, "1", "True", True),
)
text = dc.to_str(res1, res2, align, column_size=int(args.column_size))
print(text)
Expand Down
49 changes: 43 additions & 6 deletions onnx_array_api/reference/evaluator_yield.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ class ResultExecution:
summary: str
op_type: str
name: str
value: Optional[Any] = None

def __len__(self) -> int:
return 6
Expand Down Expand Up @@ -122,9 +123,11 @@ def make_summary(value: Any, length: int = 4, modulo: int = 26) -> str:
else:
value2 = value.flatten().astype(np.float64)
value4 = value2.reshape((4, -1)).sum(axis=1)
value4i = value4.astype(np.int64) % modulo
s = "".join([chr(65 + i) for i in value4i])
return s
value4 = np.where(np.abs(value4) < 1e10, value4, np.nan)
s = []
for v in value4:
s.append("?" if np.isnan(v) else (chr(65 + int(v) % modulo)))
return "".join(s)


class YieldEvaluator:
Expand Down Expand Up @@ -228,6 +231,7 @@ def enumerate_summarized(
output_names: Optional[List[str]] = None,
feed_inputs: Optional[Dict[str, Any]] = None,
raise_exc: bool = True,
keep_tensor: bool = False,
) -> Iterator[ResultExecution]:
"""
Executes the onnx model and enumerate intermediate results without their names.
Expand All @@ -236,17 +240,40 @@ def enumerate_summarized(
:param feed_inputs: dictionary `{ input name: input value }`
:param raise_exc: raises an exception if the execution fails or stop
where it is
:param keep_tensor: keeps the tensor in order to compute precise distances
:return: iterator on ResultExecution
"""
for kind, name, value, op_type in self.enumerate_results(
output_names, feed_inputs, raise_exc=raise_exc
):
summary = make_summary(value)
yield ResultExecution(
kind, value.dtype, value.shape, summary, op_type, name
kind,
value.dtype,
value.shape,
summary,
op_type,
name,
value=value if keep_tensor else None,
)


def discrepancies(
    expected: np.ndarray, value: np.ndarray, eps: float = 1e-7
) -> Dict[str, float]:
    """
    Computes absolute error and relative error between two matrices.

    :param expected: reference tensor
    :param value: tensor compared to the reference, it must hold as many
        elements as *expected* (both are flattened, shapes may differ)
    :param eps: added to ``abs(expected)`` to avoid a division by zero
    :return: dictionary with the maximum absolute error (key ``aerr``)
        and the maximum relative error (key ``rerr``)
    """
    expected = np.asarray(expected)
    value = np.asarray(value)
    assert (
        expected.size == value.size
    ), f"Incompatible shapes v1.shape={expected.shape}, v2.shape={value.shape}"
    if expected.size == 0:
        # max() is not defined on an empty tensor,
        # two empty tensors hold no differing element.
        return dict(aerr=0.0, rerr=0.0)
    # float64 preserves differences float32 cannot represent
    # (large integers, double precision tensors).
    expected = expected.ravel().astype(np.float64)
    value = value.ravel().astype(np.float64)
    diff = np.abs(expected - value)
    rel = diff / (np.abs(expected) + eps)
    return dict(aerr=float(diff.max()), rerr=float(rel.max()))


class DistanceExecution:
"""
Computes a distance between two results.
Expand Down Expand Up @@ -403,6 +430,14 @@ def to_str(
d = self.distance_pair(d1, d2)
symbol = "=" if d == 0 else "~"
line = f"{symbol} | {_align(str(d1), column_size)} | {_align(str(d2), column_size)}"
if (
d1.value is not None
and d2.value is not None
and d1.value.size == d2.value.size
):
disc = discrepancies(d1.value, d2.value)
a, r = disc["aerr"], disc["rerr"]
line += f" | a={a:.3f} r={r:.3f}"
elif i == last[0]:
d2 = s2[j]
line = (
Expand Down Expand Up @@ -551,6 +586,7 @@ def compare_onnx_execution(
verbose: int = 0,
raise_exc: bool = True,
mode: str = "execute",
keep_tensor: bool = False,
) -> Tuple[List[ResultExecution], List[ResultExecution], List[Tuple[int, int]]]:
"""
Compares the execution of two onnx models.
Expand All @@ -566,6 +602,7 @@ def compare_onnx_execution(
:param raise_exc: raise exception if the execution fails or stop at the error
:param mode: `execute` runs both models and compares their results,
`nodes` compares the nodes only, without executing the models
:param keep_tensor: keeps the tensor in order to compute a precise distance
:return: four results, a sequence of results for the first model and the second model,
the alignment between the two, DistanceExecution
"""
Expand All @@ -589,15 +626,15 @@ def compare_onnx_execution(
print("[compare_onnx_execution] execute first model")
res1 = list(
YieldEvaluator(model1).enumerate_summarized(
None, feeds1, raise_exc=raise_exc
None, feeds1, raise_exc=raise_exc, keep_tensor=keep_tensor
)
)
if verbose:
print(f"[compare_onnx_execution] got {len(res1)} results")
print("[compare_onnx_execution] execute second model")
res2 = list(
YieldEvaluator(model2).enumerate_summarized(
None, feeds2, raise_exc=raise_exc
None, feeds2, raise_exc=raise_exc, keep_tensor=keep_tensor
)
)
elif mode == "nodes":
Expand Down

0 comments on commit dcc2ddd

Please sign in to comment.