diff --git a/captum/testing/helpers/basic.py b/captum/testing/helpers/basic.py index 129d322aa..a766de7b6 100644 --- a/captum/testing/helpers/basic.py +++ b/captum/testing/helpers/basic.py @@ -91,6 +91,35 @@ def assertTensorTuplesAlmostEqual( assertTensorAlmostEqual(test, actual, expected, delta, mode) +def assertTupleOfListOfTensorsAlmostEqual( + test: unittest.TestCase, + # pyre-fixme[2]: Parameter must be annotated. + actual, + # pyre-fixme[2]: Parameter must be annotated. + expected, + delta: float = 0.0001, + mode: str = "sum", +) -> None: + if isinstance(expected, tuple): + assert isinstance(actual, tuple) and isinstance( + expected, tuple + ), f"Both actual and expected must be tuples, got {type(actual)} and {type(expected)}" + assert len(actual) == len( + expected + ), f"Tuple lengths differ: {len(actual)} != {len(expected)}" + for i, (actual_list, expected_list) in enumerate(zip(actual, expected)): + assert isinstance(actual_list, list) and isinstance( + expected_list, list + ), f"Elements at index {i} must be lists, got {type(actual_list)} and {type(expected_list)}" + assert len(actual_list) == len( + expected_list + ), f"List lengths at tuple index {i} differ: {len(actual_list)} != {len(expected_list)}" + for j, (a_tensor, e_tensor) in enumerate(zip(actual_list, expected_list)): + assertTensorAlmostEqual(test, a_tensor, e_tensor, delta, mode) + else: + assertTensorAlmostEqual(test, actual, expected, delta, mode) + + def assertAttributionComparision( test: unittest.TestCase, attributions1: Union[Tensor, Tuple[Tensor, ...]], @@ -149,3 +178,30 @@ class BaseTest(unittest.TestCase): def setUp(self) -> None: set_all_random_seeds(1234) patch_methods(self) + + +def extracted_features_equal(a: Any, b: Any) -> bool: + """ + Recursively checks if two extracted feature structures are equal. + The structures can be: + - torch.Tensor + - list of torch.Tensor + - tuple of (torch.Tensor or list of torch.Tensor) + Args: + a: First extracted feature (tensor, list, or tuple). + b: Second extracted feature (tensor, list, or tuple). + Returns: + bool: True if the structures are equal, False otherwise. + """ + if isinstance(a, torch.Tensor) and isinstance(b, torch.Tensor): + return torch.equal(a, b) + elif isinstance(a, list) and isinstance(b, list): + if len(a) != len(b): + return False + return all(torch.equal(x, y) for x, y in zip(a, b)) + elif isinstance(a, tuple) and isinstance(b, tuple): + if len(a) != len(b): + return False + return all(extracted_features_equal(x, y) for x, y in zip(a, b)) + else: + return False