Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 56 additions & 0 deletions captum/testing/helpers/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,35 @@ def assertTensorTuplesAlmostEqual(
assertTensorAlmostEqual(test, actual, expected, delta, mode)


def assertTupleOfListOfTensorsAlmostEqual(
test: unittest.TestCase,
# pyre-fixme[2]: Parameter must be annotated.
actual,
# pyre-fixme[2]: Parameter must be annotated.
expected,
delta: float = 0.0001,
mode: str = "sum",
) -> None:
if isinstance(expected, tuple):
assert isinstance(actual, tuple) and isinstance(
expected, tuple
), f"Both actual and expected must be tuples, got {type(actual)} and {type(expected)}"
assert len(actual) == len(
expected
), f"Tuple lengths differ: {len(actual)} != {len(expected)}"
for i, (actual_list, expected_list) in enumerate(zip(actual, expected)):
assert isinstance(actual_list, list) and isinstance(
expected_list, list
), f"Elements at index {i} must be lists, got {type(actual_list)} and {type(expected_list)}"
assert len(actual_list) == len(
expected_list
), f"List lengths at tuple index {i} differ: {len(actual_list)} != {len(expected_list)}"
for j, (a_tensor, e_tensor) in enumerate(zip(actual_list, expected_list)):
assertTensorAlmostEqual(test, a_tensor, e_tensor, delta, mode)
else:
assertTensorAlmostEqual(test, actual, expected, delta, mode)


def assertAttributionComparision(
test: unittest.TestCase,
attributions1: Union[Tensor, Tuple[Tensor, ...]],
Expand Down Expand Up @@ -149,3 +178,30 @@ class BaseTest(unittest.TestCase):
def setUp(self) -> None:
set_all_random_seeds(1234)
patch_methods(self)


def extracted_features_equal(a: Any, b: Any) -> bool:
"""
Recursively checks if two extracted feature structures are equal.
The structures can be:
- torch.Tensor
- list of torch.Tensor
- tuple of (torch.Tensor or list of torch.Tensor)
Args:
a: First extracted feature (tensor, list, or tuple).
b: Second extracted feature (tensor, list, or tuple).
Returns:
bool: True if the structures are equal, False otherwise.
"""
if isinstance(a, torch.Tensor) and isinstance(b, torch.Tensor):
return torch.equal(a, b)
elif isinstance(a, list) and isinstance(b, list):
if len(a) != len(b):
return False
return all(torch.equal(x, y) for x, y in zip(a, b))
elif isinstance(a, tuple) and isinstance(b, tuple):
if len(a) != len(b):
return False
return all(extracted_features_equal(x, y) for x, y in zip(a, b))
else:
return False
Loading