Skip to content

Commit

Permalink
Allow name to be None when creating experiments (#257)
Browse files Browse the repository at this point in the history
* allow overloads of method

* ignore line length errors

* ignore

* optional
  • Loading branch information
epwalsh authored Sep 25, 2023
1 parent 53ff02e commit cad6f76
Show file tree
Hide file tree
Showing 5 changed files with 79 additions and 10 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,10 @@ use patch releases for compatibility fixes instead.

## Unreleased

### Changed

- Allow experiment name to be `None` when creating new experiments via `Beaker.experiment.create()`.

## [v1.21.0](https://github.com/allenai/beaker-py/releases/tag/v1.21.0) - 2023-09-08

### Added
Expand Down
59 changes: 51 additions & 8 deletions beaker/services/experiment.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import time
from datetime import datetime, timedelta
from pathlib import Path
from typing import (
TYPE_CHECKING,
Any,
Expand All @@ -9,6 +10,7 @@
Optional,
Sequence,
Set,
Tuple,
Union,
)

Expand Down Expand Up @@ -64,20 +66,57 @@ def _get(id: str) -> Experiment:
return matches[0]
raise

def _parse_create_args(
self, *args, **kwargs
) -> Tuple[Union[ExperimentSpec, PathOrStr], Optional[str], Optional[Union[Workspace, str]]]:
spec: Optional[Union[ExperimentSpec, PathOrStr]] = kwargs.pop("spec", None)
name: Optional[str] = kwargs.pop("name", None)
workspace: Optional[Union[Workspace, str]] = kwargs.pop("workspace", None)
if len(args) == 2:
if name is not None or spec is not None:
raise TypeError(
"ExperimentClient.create() got an unexpected number of positional arguments"
)
if isinstance(args[0], str) and isinstance(args[1], (ExperimentSpec, Path, str)):
name, spec = args
elif isinstance(args[0], (ExperimentSpec, Path, str)) and isinstance(args[1], str):
spec, name = args
else:
raise TypeError("ExperimentClient.create() got an unexpected positional argument")
elif len(args) == 1:
if spec is None:
spec = args[0]
elif name is None:
name = args[0]
else:
raise TypeError("ExperimentClient.create() got an unexpected positional argument")
else:
raise TypeError(
"ExperimentClient.create() got an unexpected number of positional arguments"
)
if kwargs:
raise TypeError(
f"ExperimentClient.create() got unexpected keyword arguments {tuple(kwargs.keys())}"
)
assert spec is not None
return spec, name, workspace

def create(
self,
name: str,
spec: Union[ExperimentSpec, PathOrStr],
workspace: Optional[Union[Workspace, str]] = None,
*args,
**kwargs,
) -> Experiment:
"""
Create a new Beaker experiment with the given ``spec``.
:param name: The name to assign the experiment.
:param spec: The spec for the Beaker experiment. This can either be an
:class:`~beaker.data_model.experiment_spec.ExperimentSpec` instance or the path to a YAML spec file.
:param workspace: The workspace to create the experiment under. If not specified,
:type spec: :class:`~beaker.data_model.experiment_spec.ExperimentSpec` | :class:`~pathlib.Path` | :class:`str`
:param name: An optional name to assign the experiment. Must be unique.
:type name: :class:`str`, optional
:param workspace: An optional workspace to create the experiment under. If not specified,
:data:`Beaker.config.default_workspace <beaker.Config.default_workspace>` is used.
:type workspace: :class:`~beaker.data_model.workspace.Workspace` | :class:`str`, optional
:raises ValueError: If the name is invalid.
:raises ExperimentConflict: If an experiment with the given name already exists.
Expand All @@ -88,7 +127,11 @@ def create(
Beaker server.
"""
self.validate_beaker_name(name)
spec, name, workspace = self._parse_create_args(*args, **kwargs)
# For backwards compatibility we parse out the arguments like this to allow for `create(name, spec)`
# or just `create(spec)`.
if name is not None:
self.validate_beaker_name(name)
if not isinstance(spec, ExperimentSpec):
spec = ExperimentSpec.from_file(spec)
spec.validate()
Expand All @@ -98,9 +141,9 @@ def create(
experiment_data = self.request(
f"workspaces/{workspace.id}/experiments",
method="POST",
query={"name": name},
query=None if name is None else {"name": name},
data=json_spec,
exceptions_for_status={409: ExperimentConflict(name)},
exceptions_for_status=None if name is None else {409: ExperimentConflict(name)},
).json()
return self.get(experiment_data["id"])

Expand Down
2 changes: 1 addition & 1 deletion beaker/services/service_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ def make_request(session: requests.Session) -> requests.Response:
# Raise a BeakerError if we're misusing the API (4xx error code).
raise BeakerError(msg)
elif msg is not None:
raise HTTPError(msg, response=response)
raise HTTPError(msg, response=response) # type: ignore
else:
raise

Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ multi_line_output = 3

[tool.ruff]
line-length = 115
ignore = ["F403", "F405"]
ignore = ["E501", "F403", "F405"]

[tool.ruff.per-file-ignores]
"__init__.py" = ["F401"]
Expand Down
22 changes: 22 additions & 0 deletions tests/experiment_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,28 @@
)


def test_parse_create_args(client: Beaker):
spec, name, workspace = client.experiment._parse_create_args(
"my-experiment", ExperimentSpec.new(docker_image="hello-world")
)
assert workspace is None
assert name == "my-experiment"

spec, name, workspace = client.experiment._parse_create_args(
ExperimentSpec.new(docker_image="hello-world")
)
assert workspace is None
assert name is None
assert spec is not None

spec, name, workspace = client.experiment._parse_create_args(
ExperimentSpec.new(docker_image="hello-world"), name="my-experiment", workspace="ai2/petew"
)
assert workspace == "ai2/petew"
assert name == "my-experiment"
assert spec is not None


def test_experiment_get(client: Beaker, hello_world_experiment_id: str):
exp = client.experiment.get(hello_world_experiment_id)
assert exp.id == hello_world_experiment_id
Expand Down

0 comments on commit cad6f76

Please sign in to comment.