28 changes: 16 additions & 12 deletions mmf/models/base_model.py
@@ -272,23 +272,27 @@ def __call__(self, sample_list, *args, **kwargs):
             model_output, collections.abc.Mapping
         ), "A dict must be returned from the forward of the model."
 
+        final_output = {"losses": {}}
+        final_output.update(model_output)
+
         if "losses" in model_output:
-            if not self._logged_warning["losses_present"]:
+            assert isinstance(
+                model_output["losses"], collections.abc.Mapping
+            ), "'losses' returned from the model must be a dict."
+
+        if hasattr(self, "losses"):
+            if "losses" in model_output and not self._logged_warning["losses_present"]:
                 warnings.warn(
-                    "'losses' already present in model output. "
-                    "No calculation will be done in base model."
+                    "'losses' already present in model output and 'loss' key "
+                    "was specified. Results from the two will be merged. "
+                    "If this is not expected, either (i) assign unique keys to "
+                    "losses returned from your model (ii) remove 'loss' key from "
+                    "your model output"
                 )
                 self._logged_warning["losses_present"] = True
+            final_output["losses"].update(self.losses(sample_list, model_output))
 
-            assert isinstance(
-                model_output["losses"], collections.abc.Mapping
-            ), "'losses' must be a dict."
-        elif hasattr(self, "losses"):
-            model_output["losses"] = self.losses(sample_list, model_output)
-        else:
-            model_output["losses"] = {}
-
-        return model_output
+        return final_output
 
     def load_requirements(self, *args, **kwargs):
         requirements = self.config.get("zoo_requirements", [])
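The net effect of the base_model.py change: losses returned directly from a model's forward() no longer short-circuit the losses configured on the model; both are now collected into final_output["losses"]. The sketch below is illustrative only and not part of this PR (the toy class, loss keys, and values are made up); it shows the merged output one would expect after this change.

import torch
from mmf.common.sample import SampleList
from mmf.models.base_model import BaseModel


class ToyModelWithBothLosses(BaseModel):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Stand-in for losses that would normally be built from a `losses` config section.
        self.losses = lambda sample_list, model_output: {"ce": torch.tensor(0.5)}

    def forward(self, *args, **kwargs):
        # A loss computed inside the model itself.
        return {"losses": {"custom": torch.tensor(1.0)}}


sample_list = SampleList()
sample_list.add_field("x", torch.tensor(1))

model = ToyModelWithBothLosses({})
output = model(sample_list)
# output["losses"] now holds both entries, {"custom": 1.0, "ce": 0.5}, and the
# warning shown in the diff above is emitted once because 'losses' was already
# present in the model output.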
4 changes: 3 additions & 1 deletion mmf/models/transformers/heads/mlp.py
@@ -20,6 +20,7 @@ class Config(BaseTransformerHead.Config):
         hidden_dropout_prob: float = 0.1
         layer_norm_eps: float = 1e-6
         hidden_act: str = "gelu"
+        output_key: str = "scores"
 
     def __init__(self, config: Config, *args, **kwargs):
         super().__init__(config, *args, **kwargs)
@@ -33,6 +34,7 @@ def __init__(self, config: Config, *args, **kwargs):
         )
         self.num_labels = self.config.num_labels
         self.hidden_size = self.config.hidden_size
+        self.output_key = self.config.get("output_key", "scores")
 
     def forward(
         self,
@@ -46,5 +48,5 @@
         output_dict = {}
         pooled_output = self.pooler(sequence_output)
         prediction = self.classifier(pooled_output)
-        output_dict["scores"] = prediction.view(-1, self.num_labels)
+        output_dict[self.output_key] = prediction.view(-1, self.num_labels)
         return output_dict
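For illustration only (nothing below is in the PR; the direct construction, the plain-dict config, and the "answer_scores" key are assumptions): with the new option, an MLP head whose config sets output_key writes its predictions under that key instead of the default "scores", e.g. so that predictions from different heads can be kept apart.

import torch
from mmf.models.transformers.heads.mlp import MLP

# Assumption: the head accepts a plain config dict that is merged over the
# MLP.Config defaults, as MMF transformer heads usually do.
head = MLP({"type": "mlp", "num_labels": 2, "output_key": "answer_scores"})

sequence_output = torch.rand(4, 20, 768)  # (batch, seq_len, hidden_size)
out = head(sequence_output)
# Assuming the default pooler reduces the sequence to (batch, hidden_size),
# out["answer_scores"] has shape (4, 2); with the default config the same
# tensor would have appeared under out["scores"].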
50 changes: 50 additions & 0 deletions tests/models/test_base_model.py
@@ -0,0 +1,50 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+
+import unittest
+
+import torch
+from mmf.common.sample import SampleList
+from mmf.models.base_model import BaseModel
+from tests.test_utils import compare_tensors
+
+
+class LocalTestModelWithForwardLoss(BaseModel):
+    def forward(self, *args, **kwargs):
+        return {"losses": {"x": torch.tensor(1.0)}}
+
+
+class LocalTestModelWithNoLoss(BaseModel):
+    def forward(self, *args, **kwargs):
+        return {}
+
+
+class LocalTestModelWithLossAttribute(BaseModel):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.losses = lambda x, y: {"x": torch.tensor(2.0)}
+
+    def forward(self, *args, **kwargs):
+        return {}
+
+
+class TestBaseModel(unittest.TestCase):
+    def test_forward_loss(self):
+        sample_list = SampleList()
+        sample_list.add_field("x", torch.tensor(1))
+        model = LocalTestModelWithForwardLoss({})
+        with torch.no_grad():
+            output = model(sample_list)
+        self.assertTrue("losses" in output)
+        self.assertTrue(compare_tensors(output["losses"]["x"], torch.tensor(1.0)))
+
+        model = LocalTestModelWithLossAttribute({})
+        with torch.no_grad():
+            output = model(sample_list)
+        self.assertTrue("losses" in output)
+        self.assertTrue(compare_tensors(output["losses"]["x"], torch.tensor(2.0)))
+
+        model = LocalTestModelWithNoLoss({})
+        with torch.no_grad():
+            output = model(sample_list)
+        self.assertTrue("losses" in output)
+        self.assertEqual(output["losses"], {})
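The new test module can be run on its own from the repository root, e.g. with pytest (assuming a working MMF development install):

pytest tests/models/test_base_model.py -v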