Skip to content

Commit

Permalink
pass pytype checks
Browse files Browse the repository at this point in the history
  • Loading branch information
leoschwarz committed Oct 29, 2024
1 parent 97929f5 commit 99bd1fe
Showing 1 changed file with 9 additions and 11 deletions.
Original file line number | Diff line number | Diff line change
Expand Up @@ -28,7 +28,7 @@ class ConfigDispatchIndividualResources(BaseModel):
dataset_flow: ConfigDatasetFlow | None

@model_validator(mode="after")
- def check_at_least_one_flow(self) -> Self:
+ def check_at_least_one_flow(self) -> ConfigDispatchIndividualResources:
if self.resource_flow is None and self.dataset_flow is None:
raise ValueError("either resource_flow or dataset_flow must be provided")
return self
Expand Down Expand Up @@ -80,14 +80,13 @@ def _write_chunks(self, chunks: list[Path]) -> None:
yaml.safe_dump(data, f)

def _dispatch_jobs_resource_flow(self, definition: WorkunitDefinition, params: dict[str, Any]) -> list[Path]:
- if self._config.resource_flow is None:
+ config = self._config.resource_flow
+ if config is None:
raise ValueError("resource_flow is not configured")
resources = Resource.find_all(ids=definition.execution.resources, client=self._client)
paths = []
for resource in sorted(resources.values()):
- if self._config.resource_flow.filter_suffix is not None and not resource["relativepath"].endswith(
- self._config.resource_flow.filter_suffix
- ):
+ if config.filter_suffix is not None and not resource["relativepath"].endswith(config.filter_suffix):
logger.info(
f"Skipping resource {resource['relativepath']!r} as it does not match the extension filter."
)
Expand All @@ -96,16 +95,15 @@ def _dispatch_jobs_resource_flow(self, definition: WorkunitDefinition, params: d
return paths

def _dispatch_jobs_dataset_flow(self, definition: WorkunitDefinition, params: dict[str, Any]) -> list[Path]:
- if self._config.dataset_flow is None:
+ config = self._config.dataset_flow
+ if config is None:
raise ValueError("dataset_flow is not configured")
dataset = Dataset.find(id=definition.execution.dataset, client=self._client)
dataset_df = dataset.to_polars()
- resources = Resource.find_all(
- ids=dataset_df[self._config.dataset_flow.resource_column].unique().to_list(), client=self._client
- )
+ resources = Resource.find_all(ids=dataset_df[config.resource_column].unique().to_list(), client=self._client)
paths = []
for row in dataset_df.iter_rows(named=True):
- resource_id = int(row[self._config.dataset_flow.resource_column])
- row_params = {name: row[dataset_name] for dataset_name, name in self._config.dataset_flow.param_columns}
+ resource_id = int(row[config.resource_column])
+ row_params = {name: row[dataset_name] for dataset_name, name in config.param_columns}
paths.append(self.dispatch_job(resource=resources[resource_id], params=params | row_params))
return paths

0 comments on commit 99bd1fe

Please sign in to comment.