Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions makani/utils/inference/inferencer.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,7 @@ def _set_eval(self):

# shorthand for inference range running over the full dataset
def inference_epoch(
self, rollout_steps: int, compute_metrics: bool = False, output_channels: List[str] = [], output_file: Optional[str] = None, output_memory_buffer_size: Optional[int] = None, bias_file: Optional[str] = None, spectrum_file: Optional[str] = None, zonal_spectrum_file: Optional[str] = None, wb2_compatible: Optional[bool] = False, profiler=None
self, rollout_steps: int, dhours: int, compute_metrics: bool = False, output_channels: List[str] = [], output_file: Optional[str] = None, output_memory_buffer_size: Optional[int] = None, bias_file: Optional[str] = None, spectrum_file: Optional[str] = None, zonal_spectrum_file: Optional[str] = None, wb2_compatible: Optional[bool] = False, profiler=None
):
"""
Runs the model in autoregressive inference mode on the entire validation dataset. Computes metrics and scores the model.
Expand All @@ -226,6 +226,7 @@ def inference_epoch(
end,
1,
rollout_steps=rollout_steps,
dhours=dhours,
batch_size=self.params.batch_size,
compute_metrics=compute_metrics,
output_channels=output_channels,
Expand All @@ -247,6 +248,7 @@ def inference_range(
end: int,
step: int,
rollout_steps: int,
dhours: int,
batch_size: int,
compute_metrics: bool = False,
metrics_file: Optional[str] = None,
Expand All @@ -267,6 +269,7 @@ def inference_range(
logs = self.inference_indexlist(
indices,
rollout_steps=rollout_steps,
dhours=dhours,
batch_size=batch_size,
compute_metrics=compute_metrics,
metrics_file=metrics_file,
Expand All @@ -286,6 +289,7 @@ def inference_indexlist(
self,
indices: Union[List[int], torch.Tensor],
rollout_steps: int,
dhours: int,
batch_size: int,
compute_metrics: bool = False,
metrics_file: Optional[str] = None,
Expand Down Expand Up @@ -343,7 +347,7 @@ def inference_indexlist(
num_samples=len(indices),
batch_size=batch_size,
num_rollout_steps=rollout_steps,
rollout_dt=self.params.dt,
rollout_dt=self.params.dt * dhours,
ensemble_size=self.params.local_ensemble_size,
img_shape=img_shape,
local_shape=local_shape,
Expand Down Expand Up @@ -813,6 +817,7 @@ def score_model(
start,
end,
step,
dhours=self.valid_dataset.dhours,
rollout_steps=rollout_steps,
batch_size=self.params.batch_size,
compute_metrics=True,
Expand Down
Loading