Skip to content

Commit

Permalink
Silence some pytype errors.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 601291629
  • Loading branch information
lingvo-bot authored and copybara-github committed Jan 25, 2024
1 parent 832a520 commit fafb3f7
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions lingvo/jax/layers/test_layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,7 +232,7 @@ def __init__(self, params):
def compute_predictions(self, input_batch: NestedMap) -> JTensor:
return self.bn.fprop(input_batch.inputs)

def compute_loss(self, predictions: JTensor,
def compute_loss(self, predictions: JTensor, # pytype: disable=signature-mismatch # dataclasses-replace
input_batch: NestedMap) -> Tuple[NestedMap, NestedMap]:
targets = input_batch.targets
error = predictions - targets
Expand Down Expand Up @@ -260,7 +260,7 @@ def __init__(self, params):
def compute_predictions(self, inputs: NestedMap) -> JTensor:
return self.ffwd.fprop(inputs)

def compute_loss(self, predictions: JTensor,
def compute_loss(self, predictions: JTensor, # pytype: disable=signature-mismatch # dataclasses-replace
input_batch: NestedMap) -> Tuple[NestedMap, NestedMap]:
loss = jnp.mean(jnp.square(predictions))
per_example_out = NestedMap(predictions=predictions)
Expand Down

0 comments on commit fafb3f7

Please sign in to comment.