diff --git a/lingvo/jax/layers/test_layers.py b/lingvo/jax/layers/test_layers.py index 70f00d3fd..dcb1139c6 100644 --- a/lingvo/jax/layers/test_layers.py +++ b/lingvo/jax/layers/test_layers.py @@ -232,7 +232,7 @@ def __init__(self, params): def compute_predictions(self, input_batch: NestedMap) -> JTensor: return self.bn.fprop(input_batch.inputs) - def compute_loss(self, predictions: JTensor, + def compute_loss(self, predictions: JTensor, # pytype: disable=signature-mismatch # dataclasses-replace input_batch: NestedMap) -> Tuple[NestedMap, NestedMap]: targets = input_batch.targets error = predictions - targets @@ -260,7 +260,7 @@ def __init__(self, params): def compute_predictions(self, inputs: NestedMap) -> JTensor: return self.ffwd.fprop(inputs) - def compute_loss(self, predictions: JTensor, + def compute_loss(self, predictions: JTensor, # pytype: disable=signature-mismatch # dataclasses-replace input_batch: NestedMap) -> Tuple[NestedMap, NestedMap]: loss = jnp.mean(jnp.square(predictions)) per_example_out = NestedMap(predictions=predictions)