Skip to content

Commit

Permalink
Update compare_output medhod for tester (pytorch#3016)
Browse files Browse the repository at this point in the history
Summary:

Method name update

Reviewed By: mcr229

Differential Revision: D56072265
  • Loading branch information
digantdesai authored and facebook-github-bot committed Apr 15, 2024
1 parent 057e432 commit 38d04b6
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 1 deletion.
30 changes: 29 additions & 1 deletion backends/arm/test/tester/arm_tester.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ def to_edge(self, to_edge_stage: Optional[ToEdge] = None):
to_edge_stage = ToEdge(EdgeCompileConfig(_check_ir_validity=False))
return super().to_edge(to_edge_stage)

def partition(self, partition_stage: Optional[Partition] = None):
def partition(self, partition_stage: Optional[Partition] = None): # pyre-ignore
if partition_stage is None:
arm_partitioner = ArmPartitioner(compile_spec=self.compile_spec)
partition_stage = Partition(arm_partitioner)
Expand Down Expand Up @@ -196,6 +196,34 @@ def run_method(

return self

def compare_outputs(self, atol=1e-03, rtol=1e-03, qtol=0):
"""
Compares the original of the original nn module with the output of the generated artifact.
This requres calling run_method before calling compare_outputs. As that runs the generated
artifact on the sample inputs and sets the stage output to be compared against the reference.
"""
assert self.reference_output is not None
assert self.stage_output is not None

# Wrap both outputs as tuple, since executor output is always a tuple even if single tensor
if isinstance(self.reference_output, torch.Tensor):
self.reference_output = (self.reference_output,)
if isinstance(self.stage_output, torch.Tensor):
self.stage_output = (self.stage_output,)

# If a qtol is provided and we found an dequantization node prior to the output, relax the
# atol by qtol quant units.
if self.quantization_scale is not None:
atol += self.quantization_scale * qtol

self._assert_outputs_equal(
self.stage_output,
self.reference_output,
atol=atol,
rtol=rtol,
)
return self

def _get_input_params(
self, program: ExportedProgram
) -> Tuple[str, Union[List[QuantizationParams], List[None]]]:
Expand Down
16 changes: 16 additions & 0 deletions backends/xnnpack/test/tester/tester.py
Original file line number Diff line number Diff line change
Expand Up @@ -534,6 +534,22 @@ def check_node_count(self, input: Dict[Any, int]):

return self

def run_method(
self, stage: Optional[str] = None, inputs: Optional[Tuple[torch.Tensor]] = None
):
# This is to avoid accidental ommition of compare_outputs resulting in
# false positive of the test passing.
raise NotImplementedError(
"run_method is deprecated, please use run_method_and_compare_outputs"
)

def compare_outputs(self, atol=1e-03, rtol=1e-03, qtol=0):
# This is to avoid accidental ommition of compare_outputs resulting in
# false positive of the test passing.
raise NotImplementedError(
"compare_outputs is deprecated, please use run_method_and_compare_outputs"
)

def run_method_and_compare_outputs(
self,
stage: Optional[str] = None,
Expand Down

0 comments on commit 38d04b6

Please sign in to comment.