diff --git a/mmf/common/sample.py b/mmf/common/sample.py index bafffa460..cdf0450fe 100644 --- a/mmf/common/sample.py +++ b/mmf/common/sample.py @@ -393,6 +393,61 @@ def to_dict(self) -> Dict[str, Any]: return sample_dict + def __eq__(self,other): + """Compare a sampleList with the current SampleList. + + Returns: + Bool : True or False + """ + + if not isinstance(other,SampleList): + return False + + fields = self.fields() + fields_other = other.fields() + tensor_field = self._get_tensor_field() + tensor_field_other = other._get_tensor_field() + + # Check for tensor fields comparison + if ( + len(fields) != 0 + and len(fields_other) != 0 + and tensor_field is not None + and tensor_field_other is not None + and other[tensor_field_other].size(0) != self[tensor_field].size(0) + ): + return False + + fields_set = set(fields) + fields_set_other = set(fields_other) + + # Comparison between keys and early fail + if fields_set==fields_set_other: + # Compare all the features + for field in fields: + # Compare Tensors + if ( + isinstance(self.get_field(field),torch.Tensor) + and isinstance(other.get_field(field),torch.Tensor) + ): + if not torch.equal(self.get_field(field),other.get_field(field)): + return False + + # Check for same data type + elif ( + type(self.get_field(field)) is not type(other.get_field(field)) + ): + return False + + # Compare Lists + else: + if not self.get_field(field)==other.get_field(field): + return False + + return True + + return False + def convert_batch_to_sample_list( batch: Union[SampleList, Dict[str, Any]] diff --git a/tests/common/test_sample.py b/tests/common/test_sample.py index 1f90f214e..75560c41e 100644 --- a/tests/common/test_sample.py +++ b/tests/common/test_sample.py @@ -75,7 +75,42 @@ def test_to_dict(self): self.assertTrue(all_keys) self.assertTrue(isinstance(sample_dict, dict)) - + + def test_equal(self): + sample_list1 = test_utils.build_random_sample_list() + sample_list2 = sample_list1.copy() + sample_list3 = sample_list1.copy() + + sample_list3.add_field('new',list([1,2,3,4,5])) + sample_list4 = sample_list1.copy() + tensor_size = sample_list1.get_batch_size() + sample_list4.add_field('new',torch.zeros(tensor_size)) + + sample_list5 = SampleList() + sample_list6 = SampleList() + sample_list6.add_field('new',SampleList()) + + sample_list7 = SampleList() + dict_example = {'a':1, 'b':2} + sample_list7.add_field('new',dict_example) + + sample_list8 = sample_list1.copy() + sample_list8.add_field('new',torch.ones(tensor_size)) + + + self.assertTrue(sample_list1 == sample_list2) + self.assertTrue(sample_list1 != sample_list3) + self.assertTrue(sample_list1 != sample_list4) + self.assertTrue(sample_list2 != sample_list4) + self.assertTrue(sample_list1 != sample_list5) + self.assertTrue(sample_list1 != sample_list6) + self.assertTrue(sample_list1 != sample_list7) + self.assertTrue(sample_list5 != sample_list6) + self.assertTrue(sample_list6 != sample_list7) + self.assertTrue(sample_list6 != sample_list5) + self.assertTrue(sample_list6 != sample_list1) + self.assertTrue(sample_list1 != sample_list8) + self.assertFalse(sample_list4 == sample_list8) class TestFunctions(unittest.TestCase): def test_to_device(self):