diff --git a/client/src/jobq/job.py b/client/src/jobq/job.py index 4dcfa8c..9cf347b 100644 --- a/client/src/jobq/job.py +++ b/client/src/jobq/job.py @@ -12,9 +12,8 @@ from collections.abc import Callable from collections.abc import Set as AbstractSet from pathlib import Path -from typing import Any, ClassVar, Generic, ParamSpec, TypedDict, TypeVar +from typing import TYPE_CHECKING, Any, ClassVar, Generic, ParamSpec, TypedDict, TypeVar -import docker.types from pydantic import BaseModel, ConfigDict, Field, StrictInt, StrictStr from typing_extensions import Self @@ -26,6 +25,9 @@ from jobq.utils.math import to_rational from jobq.utils.processes import run_command +if TYPE_CHECKING: + from docker.types import DeviceRequest + class BuildMode(enum.Enum): YAML = "yaml" @@ -80,7 +82,7 @@ def _is_yaml(path: AnyPath) -> bool: class DockerResourceOptions(TypedDict): mem_limit: str | None nano_cpus: float | None - device_requests: list[docker.types.DeviceRequest] | None + device_requests: list[DeviceRequest] | None # Functional definition of TypedDict to enable special characters in dict keys @@ -131,12 +133,14 @@ def to_str(self) -> str: return pprint.pformat(self.model_dump(by_alias=True)) def to_docker(self) -> DockerResourceOptions: + from docker.types import DeviceRequest + options: DockerResourceOptions = { "mem_limit": str(int(to_rational(self.memory))) if self.memory else None, "nano_cpus": int(to_rational(self.cpu) * 10**9) if self.cpu else None, "device_requests": ( [ - docker.types.DeviceRequest( + DeviceRequest( capabilities=[["gpu"]], count=self.gpu, )