diff --git a/dockrice/argparse.py b/dockrice/argparse.py index 773842c..91fdbd5 100644 --- a/dockrice/argparse.py +++ b/dockrice/argparse.py @@ -3,7 +3,7 @@ import pathlib import docker from .dockerpath import DockerPath, DockerPathFactory, MountOption, MountSet -from .utils import get_image, run_image +from .utils import get_image, run_image, resolve_gpu_device import warnings import inspect @@ -145,6 +145,13 @@ def __call__(self, parser, namespace, values, option_string=None): factory_self.dockrice_verbose = True delattr(namespace, "dockrice_verbose") return + if option_string == "--gpus": + factory_self.docker_kwargs.setdefault("device_requests", []) + factory_self.docker_kwargs["device_requests"].extend( + resolve_gpu_device(values) + ) + delattr(namespace, "gpus") + return if option_string is not None: self.run_command.append(option_string) values = self._recursive_resolve_args(values) @@ -231,6 +238,11 @@ def __init__(self, *args, **kwargs): help="Increases verbosity of dockrice.", action="store_true", ) + self.add_argument( + "--gpus", + help="Adds gpu capability to the docker container.", + type=str, + ) def add_argument(self, *args, **kwargs): kwargs["action"] = self._docker_action_factory( diff --git a/dockrice/dockerpath.py b/dockrice/dockerpath.py index af5464a..747b7f8 100644 --- a/dockrice/dockerpath.py +++ b/dockrice/dockerpath.py @@ -11,7 +11,7 @@ def remove_prefix(string, prefix): - return string[(len(prefix) if string.startswith(prefix) else 0):] + return string[(len(prefix) if string.startswith(prefix) else 0) :] class MountOption(Enum): diff --git a/dockrice/utils.py b/dockrice/utils.py index 1e70e98..23d0c23 100644 --- a/dockrice/utils.py +++ b/dockrice/utils.py @@ -204,3 +204,25 @@ def run_image(image, cmd, client=None, return_logs=False, auto_remove=True, **kw for line in container.logs(stream=True): print(line.decode("utf-8"), end="", flush=True) return container.wait()["StatusCode"] + + +def resolve_gpu_device(arg): + """Takes a string, similar to the docker --gpus flag and converts it to dockerpy. + + Parameters + ---------- + args: str + The string that would be passed to the --gpus flag: "all" or "device=...". + + Returns: + -------- + a list of DeviceRequest objects + """ + + if arg == "all": + return [docker.types.DeviceRequest(count=-1, capabilities=[["gpu"]])] + elif arg.startswith("device="): + devices = arg.replace("device=", "").split(",") + return [docker.types.DeviceRequest(device_ids=devices, capabilities=[["gpu"]])] + + raise ValueError(f"Unknown gpu device value: {arg}")