From 38d04b667534f50643c0de083a9ae51a3b1f2e84 Mon Sep 17 00:00:00 2001 From: Digant Desai Date: Mon, 15 Apr 2024 08:33:05 -0700 Subject: [PATCH] Update compare_output medhod for tester (#3016) Summary: Method name update Reviewed By: mcr229 Differential Revision: D56072265 --- backends/arm/test/tester/arm_tester.py | 30 +++++++++++++++++++++++++- backends/xnnpack/test/tester/tester.py | 16 ++++++++++++++ 2 files changed, 45 insertions(+), 1 deletion(-) diff --git a/backends/arm/test/tester/arm_tester.py b/backends/arm/test/tester/arm_tester.py index 2d0816a294..20283de5ce 100644 --- a/backends/arm/test/tester/arm_tester.py +++ b/backends/arm/test/tester/arm_tester.py @@ -128,7 +128,7 @@ def to_edge(self, to_edge_stage: Optional[ToEdge] = None): to_edge_stage = ToEdge(EdgeCompileConfig(_check_ir_validity=False)) return super().to_edge(to_edge_stage) - def partition(self, partition_stage: Optional[Partition] = None): + def partition(self, partition_stage: Optional[Partition] = None): # pyre-ignore if partition_stage is None: arm_partitioner = ArmPartitioner(compile_spec=self.compile_spec) partition_stage = Partition(arm_partitioner) @@ -196,6 +196,34 @@ def run_method( return self + def compare_outputs(self, atol=1e-03, rtol=1e-03, qtol=0): + """ + Compares the original of the original nn module with the output of the generated artifact. + This requres calling run_method before calling compare_outputs. As that runs the generated + artifact on the sample inputs and sets the stage output to be compared against the reference. + """ + assert self.reference_output is not None + assert self.stage_output is not None + + # Wrap both outputs as tuple, since executor output is always a tuple even if single tensor + if isinstance(self.reference_output, torch.Tensor): + self.reference_output = (self.reference_output,) + if isinstance(self.stage_output, torch.Tensor): + self.stage_output = (self.stage_output,) + + # If a qtol is provided and we found an dequantization node prior to the output, relax the + # atol by qtol quant units. + if self.quantization_scale is not None: + atol += self.quantization_scale * qtol + + self._assert_outputs_equal( + self.stage_output, + self.reference_output, + atol=atol, + rtol=rtol, + ) + return self + def _get_input_params( self, program: ExportedProgram ) -> Tuple[str, Union[List[QuantizationParams], List[None]]]: diff --git a/backends/xnnpack/test/tester/tester.py b/backends/xnnpack/test/tester/tester.py index e0115a29ee..235558fcf2 100644 --- a/backends/xnnpack/test/tester/tester.py +++ b/backends/xnnpack/test/tester/tester.py @@ -534,6 +534,22 @@ def check_node_count(self, input: Dict[Any, int]): return self + def run_method( + self, stage: Optional[str] = None, inputs: Optional[Tuple[torch.Tensor]] = None + ): + # This is to avoid accidental ommition of compare_outputs resulting in + # false positive of the test passing. + raise NotImplementedError( + "run_method is deprecated, please use run_method_and_compare_outputs" + ) + + def compare_outputs(self, atol=1e-03, rtol=1e-03, qtol=0): + # This is to avoid accidental ommition of compare_outputs resulting in + # false positive of the test passing. + raise NotImplementedError( + "compare_outputs is deprecated, please use run_method_and_compare_outputs" + ) + def run_method_and_compare_outputs( self, stage: Optional[str] = None,