From fafb3f7beac1936109a6b6dfb5593b30d7205ab4 Mon Sep 17 00:00:00 2001 From: Lingvo Maintenance Date: Wed, 24 Jan 2024 18:04:57 -0800 Subject: [PATCH] Silence some pytype errors. PiperOrigin-RevId: 601291629 --- lingvo/jax/layers/test_layers.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/lingvo/jax/layers/test_layers.py b/lingvo/jax/layers/test_layers.py index 70f00d3fd..dcb1139c6 100644 --- a/lingvo/jax/layers/test_layers.py +++ b/lingvo/jax/layers/test_layers.py @@ -232,7 +232,7 @@ def __init__(self, params): def compute_predictions(self, input_batch: NestedMap) -> JTensor: return self.bn.fprop(input_batch.inputs) - def compute_loss(self, predictions: JTensor, + def compute_loss(self, predictions: JTensor, # pytype: disable=signature-mismatch # dataclasses-replace input_batch: NestedMap) -> Tuple[NestedMap, NestedMap]: targets = input_batch.targets error = predictions - targets @@ -260,7 +260,7 @@ def __init__(self, params): def compute_predictions(self, inputs: NestedMap) -> JTensor: return self.ffwd.fprop(inputs) - def compute_loss(self, predictions: JTensor, + def compute_loss(self, predictions: JTensor, # pytype: disable=signature-mismatch # dataclasses-replace input_batch: NestedMap) -> Tuple[NestedMap, NestedMap]: loss = jnp.mean(jnp.square(predictions)) per_example_out = NestedMap(predictions=predictions)