diff --git a/src/anemoi/inference/runner.py b/src/anemoi/inference/runner.py index c4df810..0e89a9f 100644 --- a/src/anemoi/inference/runner.py +++ b/src/anemoi/inference/runner.py @@ -20,12 +20,16 @@ AUTOCAST = { - "16": torch.float16, "16-mixed": torch.float16, + "16": torch.float16, "32": torch.float32, - "b16": torch.bfloat16, "b16-mixed": torch.bfloat16, + "b16": torch.bfloat16, + "bf16-mixed": torch.bfloat16, + "bf16": torch.bfloat16, "bfloat16": torch.bfloat16, + "f16": torch.float16, + "f32": torch.float32, "float16": torch.float16, "float32": torch.float32, }