diff --git a/mmf/models/base_model.py b/mmf/models/base_model.py
index c10c1d702..984b29033 100644
--- a/mmf/models/base_model.py
+++ b/mmf/models/base_model.py
@@ -272,23 +272,27 @@ def __call__(self, sample_list, *args, **kwargs):
             model_output, collections.abc.Mapping
         ), "A dict must be returned from the forward of the model."

+        final_output = {"losses": {}}
+        final_output.update(model_output)
+
         if "losses" in model_output:
-            if not self._logged_warning["losses_present"]:
+            assert isinstance(
+                model_output["losses"], collections.abc.Mapping
+            ), "'losses' returned from the model must be a dict."
+
+        if hasattr(self, "losses"):
+            if "losses" in model_output and not self._logged_warning["losses_present"]:
                 warnings.warn(
-                    "'losses' already present in model output. "
-                    "No calculation will be done in base model."
+                    "'losses' already present in model output and 'loss' key "
+                    "was specified. Results from the two will be merged. "
+                    "If this is not expected, either (i) assign unique keys to "
+                    "losses returned from your model (ii) remove 'loss' key from "
+                    "your model output"
                 )
                 self._logged_warning["losses_present"] = True
+            final_output["losses"].update(self.losses(sample_list, model_output))

-            assert isinstance(
-                model_output["losses"], collections.abc.Mapping
-            ), "'losses' must be a dict."
-        elif hasattr(self, "losses"):
-            model_output["losses"] = self.losses(sample_list, model_output)
-        else:
-            model_output["losses"] = {}
-
-        return model_output
+        return final_output

     def load_requirements(self, *args, **kwargs):
         requirements = self.config.get("zoo_requirements", [])
diff --git a/mmf/models/transformers/heads/mlp.py b/mmf/models/transformers/heads/mlp.py
index 9546c9a11..76c906e53 100644
--- a/mmf/models/transformers/heads/mlp.py
+++ b/mmf/models/transformers/heads/mlp.py
@@ -20,6 +20,7 @@ class Config(BaseTransformerHead.Config):
         hidden_dropout_prob: float = 0.1
         layer_norm_eps: float = 1e-6
         hidden_act: str = "gelu"
+        output_key: str = "scores"

     def __init__(self, config: Config, *args, **kwargs):
         super().__init__(config, *args, **kwargs)
@@ -33,6 +34,7 @@ def __init__(self, config: Config, *args, **kwargs):
         )
         self.num_labels = self.config.num_labels
         self.hidden_size = self.config.hidden_size
+        self.output_key = self.config.get("output_key", "scores")

     def forward(
         self,
@@ -46,5 +48,5 @@ def forward(
         output_dict = {}
         pooled_output = self.pooler(sequence_output)
         prediction = self.classifier(pooled_output)
-        output_dict["scores"] = prediction.view(-1, self.num_labels)
+        output_dict[self.output_key] = prediction.view(-1, self.num_labels)
         return output_dict
diff --git a/tests/models/test_base_model.py b/tests/models/test_base_model.py
new file mode 100644
index 000000000..29d5f6149
--- /dev/null
+++ b/tests/models/test_base_model.py
@@ -0,0 +1,50 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+
+import unittest
+
+import torch
+from mmf.common.sample import SampleList
+from mmf.models.base_model import BaseModel
+from tests.test_utils import compare_tensors
+
+
+class LocalTestModelWithForwardLoss(BaseModel):
+    def forward(self, *args, **kwargs):
+        return {"losses": {"x": torch.tensor(1.0)}}
+
+
+class LocalTestModelWithNoLoss(BaseModel):
+    def forward(self, *args, **kwargs):
+        return {}
+
+
+class LocalTestModelWithLossAttribute(BaseModel):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.losses = lambda x, y: {"x": torch.tensor(2.0)}
+
+    def forward(self, *args, **kwargs):
+        return {}
+
+
+class TestBaseModel(unittest.TestCase):
+    def test_forward_loss(self):
+        sample_list = SampleList()
+        sample_list.add_field("x", torch.tensor(1))
+        model = LocalTestModelWithForwardLoss({})
+        with torch.no_grad():
+            output = model(sample_list)
+        self.assertTrue("losses" in output)
+        self.assertTrue(compare_tensors(output["losses"]["x"], torch.tensor(1.0)))
+
+        model = LocalTestModelWithLossAttribute({})
+        with torch.no_grad():
+            output = model(sample_list)
+        self.assertTrue("losses" in output)
+        self.assertTrue(compare_tensors(output["losses"]["x"], torch.tensor(2.0)))
+
+        model = LocalTestModelWithNoLoss({})
+        with torch.no_grad():
+            output = model(sample_list)
+        self.assertTrue("losses" in output)
+        self.assertEqual(output["losses"], {})
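Reviewer note on the `base_model.py` hunk: the contract changes from "a `losses` key in the model output suppresses loss computation" to "losses from `forward()` and from the attached loss modules are merged". A minimal standalone sketch of the new merge semantics (plain dicts and a hypothetical `loss_fn`; no mmf imports):

```python
import warnings


def merge_losses(model_output, loss_fn=None):
    # Mirror the patched BaseModel.__call__: seed "losses" so the key always
    # exists in the final output, then overlay whatever forward() returned.
    final_output = {"losses": {}}
    final_output.update(model_output)

    if loss_fn is not None:
        if "losses" in model_output:
            warnings.warn("'losses' already in model output; results will be merged.")
        # Losses from the loss modules are merged on top of (and can overwrite
        # same-named) losses returned by forward().
        final_output["losses"].update(loss_fn(model_output))

    return final_output


# forward() returned one loss; the configured loss modules contribute another.
out = merge_losses(
    {"scores": [0.1, 0.9], "losses": {"custom/contrastive": 0.3}},
    loss_fn=lambda _: {"vqa2/logit_bce": 0.7},
)
print(out["losses"])  # {'custom/contrastive': 0.3, 'vqa2/logit_bce': 0.7}
```

One property worth flagging in review: because `final_output.update(model_output)` copies a reference, `final_output["losses"]` aliases the dict returned by `forward()` when one is present, so the merge mutates that dict in place. The patch behaves the same way.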
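The `mlp.py` change is easiest to see in isolation: the head now writes its logits under a configurable `output_key` instead of the hard-coded `"scores"`, so two heads on the same model can emit, e.g., `"scores"` and `"itm_scores"` without clobbering each other. A hedged sketch (`TinyMLPHead` is a hypothetical stand-in; the real head's pooler and classifier also carry dropout, LayerNorm, and the configured activation):

```python
import torch
from torch import nn


class TinyMLPHead(nn.Module):
    def __init__(self, hidden_size=768, num_labels=2, output_key="scores"):
        super().__init__()
        self.num_labels = num_labels
        self.output_key = output_key  # new: configurable output key
        self.pooler = nn.Linear(hidden_size, hidden_size)
        self.classifier = nn.Linear(hidden_size, num_labels)

    def forward(self, sequence_output):
        # Pool the first token position, classify, emit under the chosen key.
        pooled_output = self.pooler(sequence_output[:, 0])
        prediction = self.classifier(pooled_output)
        return {self.output_key: prediction.view(-1, self.num_labels)}


head = TinyMLPHead(output_key="itm_scores")
out = head(torch.randn(4, 16, 768))
assert set(out) == {"itm_scores"} and out["itm_scores"].shape == (4, 2)
```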
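The new tests cover all three paths through `__call__` (losses returned by `forward()`, losses computed by the `losses` attribute, and no losses at all) and can be run in isolation with, e.g., `python -m pytest tests/models/test_base_model.py -v`.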