diff --git a/mmf/modules/hf_layers.py b/mmf/modules/hf_layers.py index b4e0af64c..d4922bd26 100644 --- a/mmf/modules/hf_layers.py +++ b/mmf/modules/hf_layers.py @@ -280,11 +280,17 @@ def forward( attention_mask: Optional[Tensor], encoder_hidden_states: Optional[Tensor] = None, encoder_attention_mask: Optional[Tensor] = None, - output_attentions: bool = False, - output_hidden_states: bool = False, + output_attentions: bool = None, + output_hidden_states: bool = None, return_dict: bool = False, head_mask: Optional[Tensor] = None, ) -> Tuple[Tensor]: + + if output_attentions is None: + output_attentions = self.output_attentions + if output_hidden_states is None: + output_hidden_states = self.output_hidden_states + all_hidden_states = () all_attentions = () for i, layer_module in enumerate(self.layer):