-
Notifications
You must be signed in to change notification settings - Fork 33
Commit 77471a2
metrics: Use Array and ArrayLike types thoughout
Currently the inputs to `from_model_output` are not typed. However,
these functions cannot accept arbitrary inputs, they need to be a value
convertable to a `jax.Array`. This change fixes this so that:
- `from_model_output` takes in types of `Array` or `ArrayLike`
- Removes use of `jnp.array` as a type as it's equivalent to `Any`
- Makes members of Metric classes have type `Array`
- Moves mask checking code into its own function
While we could make everything use `Array` (instead of `ArrayLike`),
this would break code like:
```
@flax.struct.dataclass
class Collection(metrics.Collection):
train_accuracy: metrics.Accuracy
learning_rate: metrics.LastValue.from_output("learning_rate")
Collection.gather_from_model_output(learning_rate=0.02, ...)
```
which seems undesirable.
Note that `count` and `value` for `LastValue` have type `ArrayLike`,
as this code needs to support passing a plain number for `value` or
`count`. Also, the base `Metric.compute()` method has type `Any`,
because some metrics return `Array` while others use `dict[str, Array]`.
PiperOrigin-RevId: 5292272181 parent f8eec70 commit 77471a2Copy full SHA for 77471a2
File tree
Expand file treeCollapse file tree
2 files changed
+114
-87
lines changedFilter options
- clu
Expand file treeCollapse file tree
2 files changed
+114
-87
lines changed
0 commit comments