Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 13 additions & 1 deletion dockrice/argparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import pathlib
import docker
from .dockerpath import DockerPath, DockerPathFactory, MountOption, MountSet
from .utils import get_image, run_image
from .utils import get_image, run_image, resolve_gpu_device
import warnings
import inspect

Expand Down Expand Up @@ -145,6 +145,13 @@ def __call__(self, parser, namespace, values, option_string=None):
factory_self.dockrice_verbose = True
delattr(namespace, "dockrice_verbose")
return
if option_string == "--gpus":
factory_self.docker_kwargs.setdefault("device_requests", [])
factory_self.docker_kwargs["device_requests"].extend(
resolve_gpu_device(values)
)
delattr(namespace, "gpus")
return
if option_string is not None:
self.run_command.append(option_string)
values = self._recursive_resolve_args(values)
Expand Down Expand Up @@ -231,6 +238,11 @@ def __init__(self, *args, **kwargs):
help="Increases verbosity of dockrice.",
action="store_true",
)
self.add_argument(
"--gpus",
help="Adds gpu capability to the docker container.",
type=str,
)

def add_argument(self, *args, **kwargs):
kwargs["action"] = self._docker_action_factory(
Expand Down
2 changes: 1 addition & 1 deletion dockrice/dockerpath.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@


def remove_prefix(string, prefix):
return string[(len(prefix) if string.startswith(prefix) else 0):]
return string[(len(prefix) if string.startswith(prefix) else 0) :]


class MountOption(Enum):
Expand Down
22 changes: 22 additions & 0 deletions dockrice/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,3 +204,25 @@ def run_image(image, cmd, client=None, return_logs=False, auto_remove=True, **kw
for line in container.logs(stream=True):
print(line.decode("utf-8"), end="", flush=True)
return container.wait()["StatusCode"]


def resolve_gpu_device(arg):
"""Takes a string, similar to the docker --gpus flag and converts it to dockerpy.

Parameters
----------
args: str
The string that would be passed to the --gpus flag: "all" or "device=...".

Returns:
--------
a list of DeviceRequest objects
"""

if arg == "all":
return [docker.types.DeviceRequest(count=-1, capabilities=[["gpu"]])]
elif arg.startswith("device="):
devices = arg.replace("device=", "").split(",")
return [docker.types.DeviceRequest(device_ids=devices, capabilities=[["gpu"]])]

raise ValueError(f"Unknown gpu device value: {arg}")