diff --git a/.circleci/real_config.yml b/.circleci/real_config.yml index d44e35ce566..ef5c85c34db 100644 --- a/.circleci/real_config.yml +++ b/.circleci/real_config.yml @@ -630,6 +630,7 @@ commands: command: | pytest -vv -s \ -m <> \ + --no-compare-stats \ --det-version="<>" deploy-aws-cluster: @@ -2919,6 +2920,7 @@ workflows: devcluster-config: agent-no-connection.devcluster.yaml target-stage: agent wait-for-master: false + extra-pytest-flags: "--no-compare-stats" - test-perf: name: test-perf diff --git a/.circleci/scripts/wait_for_master.py b/.circleci/scripts/wait_for_master.py index 3047c5c5abb..fe15c029679 100644 --- a/.circleci/scripts/wait_for_master.py +++ b/.circleci/scripts/wait_for_master.py @@ -4,15 +4,16 @@ import requests from determined.common import api -from determined.common.api import certs +from determined.common.api import authentication, certs def _wait_for_master(address: str) -> None: print("Checking for master at", address) cert = certs.Cert(noverify=True) + sess = api.UnauthSession(address, cert) for _ in range(150): try: - r = api.get(address, "info", authenticated=False, cert=cert) + r = sess.get("info") if r.status_code == requests.codes.ok: return except api.errors.MasterNotFoundException: diff --git a/bindings/generate_bindings_py.py b/bindings/generate_bindings_py.py index 2d931c5d5d3..1b57be0aaf6 100644 --- a/bindings/generate_bindings_py.py +++ b/bindings/generate_bindings_py.py @@ -205,7 +205,7 @@ def gen_function(func: swagger_parser.Function) -> Code: out = [f"def {func.operation_name_sc()}("] # Function parameters. - out += [' session: "api.Session",'] + out += [' session: "api.BaseSession",'] if func.params: out += [" *,"] diff --git a/docs/manage/users.rst b/docs/manage/users.rst index 40d653e61c1..aa3dea7cdbc 100644 --- a/docs/manage/users.rst +++ b/docs/manage/users.rst @@ -226,6 +226,19 @@ user and group on an agent can be configured in the ``master.yaml`` file located group: root gid: 0 +.. 
note:: + + A writable ``HOME`` directory is required by all Determined tasks. By default, all official + Determined images contain a tool called ``libnss_determined`` that injects users into the + container at runtime. If you are building custom images using a base image other than those + provided by Determined, you may need to take one of the following steps: + + - prebuild all users you might need into your custom image, or + - include ``libnss_determined`` in your custom image to ensure user injection works as + expected, or + - find an alternate solution that serves the same purpose of injecting users into the + container at runtime + .. _run-unprivileged-tasks: *********************************** @@ -235,8 +248,8 @@ user and group on an agent can be configured in the ``master.yaml`` file located Some administrators of Determined may wish to run tasks as unprivileged users by default. In Linux, unprivileged processes are sometimes run under the `nobody `_ user, which has very few privileges. However, -the ``nobody`` user does not have a writable ``HOME`` directory, which causes problems for some -common tools like ``gsutil``. +the ``nobody`` user does not have a writable ``HOME`` directory, which is a requirement for tasks in +Determined, so the ``nobody`` user will typically not work in Determined. For convenience, the default Determined environments contain an unprivileged user named ``det-nobody``, which does have a writable ``HOME`` directory. The ``det-nobody`` user is a suitable diff --git a/docs/release-notes/8437-auth-refactor.rst b/docs/release-notes/8437-auth-refactor.rst new file mode 100644 index 00000000000..4d640390f19 --- /dev/null +++ b/docs/release-notes/8437-auth-refactor.rst @@ -0,0 +1,18 @@ +:orphan: + +**Removed Features** + +- **Breaking Change** Removed the accidentally-exposed Session object from the + ``det.experimental.client`` namespace. 
It was never meant to be a public API and it was not + documented in :ref:`python-sdk`, but was nonetheless exposed in that namespace. It was also + available as a deprecated legacy alias, ``det.experimental.Session``. It is expected that most + users use the Python SDK normally and are unaffected by this change, since the + ``det.experimental.client``'s ``login()`` and ``Determined()`` are unaffected. + +- **Breaking Change** Add a new requirement for runtime configurations that there be a writable + ``$HOME`` directory in every container. Previously, there was limited support for containers + without a writable ``$HOME``, merely by coincidence. This change could impact users in scenarios + where jobs were configured to run as the ``nobody`` user inside a container, instead of the + ``det-nobody`` alternative recommended in :ref:`run-unprivileged-tasks`. Users combining non-root + tasks with custom images not based on Determined's official images may also be affected. Overall, + it is expected that few or no users are affected by this change. diff --git a/e2e_tests/.flake8 b/e2e_tests/.flake8 index 47e8d65a242..0e272b446f6 100644 --- a/e2e_tests/.flake8 +++ b/e2e_tests/.flake8 @@ -7,64 +7,11 @@ max-line-length = 100 # be used by the importer.) 
per-file-ignores = __init__.py:F401,I2041 - tests/api_utils.py:I2041 - tests/cluster/abstract_cluster.py:I2041 - tests/cluster_log_manager.py:I2041 - tests/cluster/managed_cluster_k8s.py:I2041 - tests/cluster/managed_cluster.py:I2041 - tests/cluster/managed_slurm_cluster.py:I2041 - tests/cluster/test_agent_disable.py:I2041 - tests/cluster/test_agent.py:I2041 - tests/cluster/test_agent_restart.py:I2041 - tests/cluster/test_agent_user_group.py:I2041 - tests/cluster/test_checkpoints.py:I2041 - tests/cluster/test_groups.py:I2041 - tests/cluster/test_job_queue.py:I2041 - tests/cluster/test_master_restart.py:I2041 - tests/cluster/test_master_restart_slurm.py:I2041 - tests/cluster/test_model_registry.py:I2041 - tests/cluster/test_model_registry_rbac.py:I2041 - tests/cluster/test_oauth2_scim_client.py:I2041 - tests/cluster/test_priority_scheduler.py:I2041 - tests/cluster/test_proxy.py:I2041 - tests/cluster/test_rbac_misc.py:I2041 - tests/cluster/test_rbac_ntsc.py:I2041 - tests/cluster/test_rbac.py:I2041 - tests/cluster/test_resource_manager.py:I2041 - tests/cluster/test_slurm.py:I2041 - tests/cluster/test_users.py:I2041 - tests/cluster/test_webhooks.py:I2041 - tests/cluster/test_workspace_org.py:I2041 - tests/cluster/utils.py:I2041 - tests/command/command.py:I2041 - tests/command/test_notebook.py:I2041 - tests/command/test_run.py:I2041 - tests/command/test_shell.py:I2041 - tests/command/test_tensorboard.py:I2041 - tests/config.py:I2041 - tests/conftest.py:I2041 - tests/deploy/test_local.py:I2041 - tests/experiment/experiment.py:I2041 - tests/experiment/record_profiling.py:I2041 - tests/experiment/test_allocation_csv.py:I2041 - tests/experiment/test_api.py:I2041 - tests/experiment/test_core.py:I2041 - tests/experiment/test_hpc_launch.py:I2041 - tests/experiment/test_launch.py:I2041 - tests/experiment/test_noop.py:I2041 - tests/experiment/test_pending_hpc.py:I2041 - tests/experiment/test_port_registry.py:I2041 - tests/experiment/test_profiling.py:I2041 - 
tests/filetree.py:I2041 tests/fixtures/mnist_pytorch/failable_model_def.py:I2041 tests/fixtures/mnist_pytorch/layers.py:I2041 tests/fixtures/mnist_pytorch/model_def.py:I2041 tests/fixtures/mnist_pytorch/stop_requested_model_def.py:I2041 tests/fixtures/trial_error/model_def.py:I2041 - tests/job/test_rbac.py:I2041 - tests/nightly/compute_stats.py:I2041 - tests/template/test_template.py:I2041 - tests/test_sdk.py:I2041 # Explanations for ignored error codes: # - D1* (no missing docstrings): too much effort to start enforcing diff --git a/e2e_tests/tests/api_utils.py b/e2e_tests/tests/api_utils.py index ea4bb584c0e..163f1f6ca5c 100644 --- a/e2e_tests/tests/api_utils.py +++ b/e2e_tests/tests/api_utils.py @@ -1,5 +1,5 @@ import uuid -from typing import Callable, Optional, Sequence, TypeVar +from typing import Callable, Optional, Sequence, Tuple, TypeVar import pytest @@ -7,41 +7,60 @@ from determined.common.api import authentication, bindings, certs, errors from tests import config as conf +_cert: Optional[certs.Cert] = None -def get_random_string() -> str: - return str(uuid.uuid4()) +def cert() -> certs.Cert: + global _cert + if _cert is None: + _cert = certs.default_load(conf.make_master_url()) + return _cert -def determined_test_session( - credentials: Optional[authentication.Credentials] = None, - admin: Optional[bool] = None, -) -> api.Session: - assert admin is None or credentials is None, "admin and credentials are mutually exclusive" - if credentials is None: - if admin: - credentials = conf.ADMIN_CREDENTIALS - else: - credentials = authentication.Credentials("determined", "") +def make_session(username: str, password: str) -> api.Session: + master_url = conf.make_master_url() + # Use login instead of login_with_cache() to not touch auth.json on the filesystem. 
+ utp = authentication.login(master_url, username, password, cert()) + return api.Session(master_url, utp, cert()) - murl = conf.make_master_url() - certs.cli_cert = certs.default_load(murl) - authentication.cli_auth = authentication.Authentication( - murl, requested_user=credentials.username, password=credentials.password - ) - return api.Session(murl, credentials.username, authentication.cli_auth, certs.cli_cert) + +_user_session: Optional[api.Session] = None + + +def user_session() -> api.Session: + global _user_session + if _user_session is None: + _user_session = make_session("determined", "") + return _user_session + + +_admin_session: Optional[api.Session] = None + + +def admin_session() -> api.Session: + global _admin_session + if _admin_session is None: + _admin_session = make_session("admin", "") + return _admin_session + + +def get_random_string() -> str: + return str(uuid.uuid4()) def create_test_user( - add_password: bool = False, - session: Optional[api.Session] = None, user: Optional[bindings.v1User] = None, -) -> authentication.Credentials: - session = session or determined_test_session(admin=True) - user = user or bindings.v1User(username=get_random_string(), admin=False, active=True) - password = get_random_string() if add_password else "" +) -> Tuple[api.Session, str]: + """ + Create a user as admin and return a tuple of (Session, password) for that user. 
+ """ + session = admin_session() + username = get_random_string() + user = user or bindings.v1User(username=username, admin=False, active=True) + password = get_random_string() bindings.post_PostUser(session, body=bindings.v1PostUserRequest(user=user, password=password)) - return authentication.Credentials(user.username, password) + sess = make_session(user.username, password) + return sess, password def assign_user_role(session: api.Session, user: str, role: str, workspace: Optional[str]) -> None: @@ -62,17 +81,6 @@ def assign_group_role( bindings.post_AssignRoles(session, body=req) -def configure_token_store(credentials: authentication.Credentials) -> None: - """Authenticate the user for CLI usage with the given credentials.""" - token_store = authentication.TokenStore(conf.make_master_url()) - certs.cli_cert = certs.default_load(conf.make_master_url()) - token = authentication.do_login( - conf.make_master_url(), credentials.username, credentials.password, certs.cli_cert - ) - token_store.set_token(credentials.username, token) - token_store.set_active(credentials.username) - - def launch_ntsc( session: api.Session, workspace_id: int, @@ -164,6 +172,38 @@ def list_ntsc( raise ValueError("unknown type") +F = TypeVar("F", bound=Callable) + + +_is_k8s: Optional[bool] = None + + +def _get_is_k8s() -> Optional[bool]: + global _is_k8s + + if _is_k8s is None: + try: + admin = admin_session() + resp = bindings.get_GetMasterConfig(admin) + _is_k8s = resp.config["resource_manager"]["type"] == "kubernetes" + except (errors.APIException, errors.MasterNotFoundException): + pass + + return _is_k8s + + +def skipif_not_k8s(reason: str = "test is k8s-specific") -> Callable[[F], F]: + def decorator(f: F) -> F: + is_k8s = _get_is_k8s() + if is_k8s is None: + return f + if not is_k8s: + return pytest.mark.skipif(True, reason=reason)(f) # type: ignore + return f + + return decorator + + _scheduler_type: Optional[bindings.v1SchedulerType] = None @@ -174,7 +214,7 @@ def 
_get_scheduler_type() -> Optional[bindings.v1SchedulerType]: global _scheduler_type if _scheduler_type is None: try: - sess = determined_test_session() + sess = user_session() resourcePool = bindings.get_GetResourcePools(sess).resourcePools if not resourcePool: raise ValueError( @@ -186,9 +226,6 @@ def _get_scheduler_type() -> Optional[bindings.v1SchedulerType]: return _scheduler_type -F = TypeVar("F", bound=Callable) - - def skipif_not_hpc(reason: str = "test is hpc-specific") -> Callable[[F], F]: def decorator(f: F) -> F: st = _get_scheduler_type() @@ -240,7 +277,8 @@ def _get_ee() -> Optional[bool]: if _is_ee is None: try: - info = api.get(conf.make_master_url(), "info", authenticated=False).json() + sess = api.UnauthSession(conf.make_master_url(), cert()) + info = sess.get("info").json() _is_ee = "sso_providers" in info except (errors.APIException, errors.MasterNotFoundException): pass @@ -280,7 +318,8 @@ def _get_scim_enabled() -> Optional[bool]: if _scim_enabled is None: try: - info = api.get(conf.make_master_url(), "info", authenticated=False).json() + sess = api.UnauthSession(conf.make_master_url(), cert()) + info = sess.get("info").json() _scim_enabled = bool(info.get("sso_providers") and len(info["sso_providers"]) > 0) except (errors.APIException, errors.MasterNotFoundException): pass @@ -298,3 +337,62 @@ def decorator(f: F) -> F: return f return decorator + + +_rbac_enabled: Optional[bool] = None + + +def _get_rbac_enabled() -> Optional[bool]: + global _rbac_enabled + + if _rbac_enabled is None: + try: + sess = api.UnauthSession(conf.make_master_url(), cert()) + _rbac_enabled = bindings.get_GetMaster(sess).rbacEnabled + except (errors.APIException, errors.MasterNotFoundException): + pass + + return _rbac_enabled + + +def skipif_rbac_not_enabled(reason: str = "ee is required for this test") -> Callable[[F], F]: + def decorator(f: F) -> F: + re = _get_rbac_enabled() + if re is None: + return f + if not re: + return pytest.mark.skipif(True, 
reason=reason)(f) # type: ignore + return f + + return decorator + + +_strict_q: Optional[bool] = None + + +def _get_strict_q() -> Optional[bool]: + global _strict_q + + if _strict_q is None: + try: + sess = api.UnauthSession(conf.make_master_url(), cert()) + resp = bindings.get_GetMaster(sess) + _strict_q = resp.rbacEnabled and resp.strictJobQueueControl + except (errors.APIException, errors.MasterNotFoundException): + pass + + return _strict_q + + +def skipif_strict_q_control_not_enabled( + reason: str = "rbac and strict queue control are required for this test", +) -> Callable[[F], F]: + def decorator(f: F) -> F: + sq = _get_strict_q() + if sq is None: + return f + if not sq: + return pytest.mark.skipif(True, reason=reason)(f) # type: ignore + return f + + return decorator diff --git a/e2e_tests/tests/cluster/abstract_cluster.py b/e2e_tests/tests/cluster/abstract_cluster.py index dfbd650a205..f021857a2f3 100644 --- a/e2e_tests/tests/cluster/abstract_cluster.py +++ b/e2e_tests/tests/cluster/abstract_cluster.py @@ -1,7 +1,7 @@ import abc -from pathlib import Path +import pathlib -DEVCLUSTER_LOG_PATH = Path("/tmp/devcluster") +DEVCLUSTER_LOG_PATH = pathlib.Path("/tmp/devcluster") class Cluster(metaclass=abc.ABCMeta): diff --git a/e2e_tests/tests/cluster/managed_cluster.py b/e2e_tests/tests/cluster/managed_cluster.py index 3c87dbd7aab..c41f1fc077b 100644 --- a/e2e_tests/tests/cluster/managed_cluster.py +++ b/e2e_tests/tests/cluster/managed_cluster.py @@ -1,4 +1,3 @@ -import json import os import subprocess import time @@ -6,11 +5,11 @@ import pytest +from determined.common import api +from tests import api_utils from tests import config as conf - -from .abstract_cluster import Cluster -from .test_users import logged_in_user -from .utils import now_ts, set_master_port +from tests import detproc +from tests.cluster import abstract_cluster, utils DEVCLUSTER_CONFIG_ROOT_PATH = conf.PROJECT_ROOT_PATH.joinpath(".circleci/devcluster") DEVCLUSTER_REATTACH_OFF_CONFIG_PATH = 
DEVCLUSTER_CONFIG_ROOT_PATH / "double.devcluster.yaml" @@ -18,22 +17,22 @@ DEVCLUSTER_PRIORITY_SCHEDULER_CONFIG_PATH = DEVCLUSTER_CONFIG_ROOT_PATH / "priority.devcluster.yaml" -def get_agent_data(master_url: str) -> List[Dict[str, Any]]: - command = ["det", "-m", master_url, "agent", "list", "--json"] - output = subprocess.check_output(command).decode() - agent_data = cast(List[Dict[str, Any]], json.loads(output)) +def get_agent_data(sess: api.Session) -> List[Dict[str, Any]]: + command = ["det", "agent", "list", "--json"] + output = detproc.check_json(sess, command) + agent_data = cast(List[Dict[str, Any]], output) return agent_data -class ManagedCluster(Cluster): +class ManagedCluster(abstract_cluster.Cluster): # This utility wrapper uses double agent yaml configurations, # but provides helpers to run/kill a single agent setup. def __init__(self, config: Union[str, Dict[str, Any]]) -> None: # Strategically only import devcluster on demand to avoid having it as a hard dependency. - from devcluster import Devcluster # noqa: I2000 + import devcluster # noqa: I2000 - self.dc = Devcluster(config=config) + self.dc = devcluster.Devcluster(config=config) def __enter__(self) -> "ManagedCluster": self.old_cd = os.getcwd() @@ -58,8 +57,9 @@ def kill_agent(self) -> None: self.dc.kill_stage("agent1") WAIT_FOR_KILL = 5 + sess = api_utils.user_session() for _i in range(WAIT_FOR_KILL): - agent_data = get_agent_data(conf.make_master_url()) + agent_data = get_agent_data(sess) if len(agent_data) == 0: break if len(agent_data) == 1 and agent_data[0]["draining"] is True: @@ -69,7 +69,8 @@ def kill_agent(self) -> None: pytest.fail(f"Agent is still present after {WAIT_FOR_KILL} seconds") def restart_agent(self, wait_for_amnesia: bool = True, wait_for_agent: bool = True) -> None: - agent_data = get_agent_data(conf.make_master_url()) + sess = api_utils.user_session() + agent_data = get_agent_data(sess) if len(agent_data) == 1 and agent_data[0]["enabled"]: return @@ -78,7 +79,7 @@ def 
restart_agent(self, wait_for_amnesia: bool = True, wait_for_agent: bool = Tr # Currently, we've got to wait for master to "forget" the agent before reconnecting. WAIT_FOR_AMNESIA = 60 for _i in range(WAIT_FOR_AMNESIA): - agent_data = get_agent_data(conf.make_master_url()) + agent_data = get_agent_data(sess) if len(agent_data) == 0: break time.sleep(1) @@ -95,10 +96,11 @@ def kill_proxy(self) -> None: subprocess.run(["killall", "socat"]) def restart_proxy(self, wait_for_reconnect: bool = True) -> None: + sess = api_utils.user_session() self.dc.restart_stage("proxy") if wait_for_reconnect: for _i in range(25): - agent_data = get_agent_data(conf.make_master_url()) + agent_data = get_agent_data(sess) if ( len(agent_data) == 1 and agent_data[0]["enabled"] is True @@ -110,14 +112,16 @@ def restart_proxy(self, wait_for_reconnect: bool = True) -> None: pytest.fail(f"Agent didn't reconnect after {_i} seconds") def ensure_agent_ok(self) -> None: - agent_data = get_agent_data(conf.make_master_url()) + sess = api_utils.user_session() + agent_data = get_agent_data(sess) assert len(agent_data) == 1 assert agent_data[0]["enabled"] is True assert agent_data[0]["draining"] is False def wait_for_agent_ok(self, ticks: int) -> None: + sess = api_utils.user_session() for _i in range(ticks): - agent_data = get_agent_data(conf.make_master_url()) + agent_data = get_agent_data(sess) if ( len(agent_data) == 1 and agent_data[0]["enabled"] is True @@ -129,14 +133,8 @@ def wait_for_agent_ok(self, ticks: int) -> None: pytest.fail(f"Agent didn't restart after {ticks} seconds") def fetch_config(self) -> Dict: - with logged_in_user(conf.ADMIN_CREDENTIALS): - master_config = json.loads( - subprocess.run( - ["det", "-m", conf.make_master_url(), "master", "config", "show", "--json"], - stdout=subprocess.PIPE, - check=True, - ).stdout.decode() - ) + admin = api_utils.admin_session() + master_config = detproc.check_json(admin, ["det", "master", "config", "show", "--json"]) return cast(Dict, 
master_config) def fetch_config_reattach_wait(self) -> float: @@ -165,12 +163,14 @@ def managed_cluster_priority_scheduler( managed_cluster_session_priority_scheduler: ManagedCluster, request: Any ) -> Iterator[ManagedCluster]: config = str(DEVCLUSTER_PRIORITY_SCHEDULER_CONFIG_PATH) - set_master_port(config) + utils.set_master_port(config) nodeid = request.node.nodeid - managed_cluster_session_priority_scheduler.log_marker(f"pytest [{now_ts()}] {nodeid} setup\n") + managed_cluster_session_priority_scheduler.log_marker( + f"pytest [{utils.now_ts()}] {nodeid} setup\n" + ) yield managed_cluster_session_priority_scheduler managed_cluster_session_priority_scheduler.log_marker( - f"pytest [{now_ts()}] {nodeid} teardown\n" + f"pytest [{utils.now_ts()}] {nodeid} teardown\n" ) @@ -180,11 +180,11 @@ def managed_cluster_restarts( ) -> Iterator[ManagedCluster]: # check if priority scheduler or not using config. config = str(DEVCLUSTER_REATTACH_ON_CONFIG_PATH) # port number is same for both reattach on and off config files so you can use either. 
- set_master_port(config) + utils.set_master_port(config) nodeid = request.node.nodeid - managed_cluster_session.log_marker(f"pytest [{now_ts()}] {nodeid} setup\n") + managed_cluster_session.log_marker(f"pytest [{utils.now_ts()}] {nodeid} setup\n") yield managed_cluster_session - managed_cluster_session.log_marker(f"pytest [{now_ts()}] {nodeid} teardown\n") + managed_cluster_session.log_marker(f"pytest [{utils.now_ts()}] {nodeid} teardown\n") @pytest.fixture diff --git a/e2e_tests/tests/cluster/managed_cluster_k8s.py b/e2e_tests/tests/cluster/managed_cluster_k8s.py index f67d81a3ca1..5a046645d33 100644 --- a/e2e_tests/tests/cluster/managed_cluster_k8s.py +++ b/e2e_tests/tests/cluster/managed_cluster_k8s.py @@ -4,24 +4,21 @@ import pytest from kubernetes import client, config, watch -from tests import config as conf +from tests import api_utils, detproc +from tests.cluster import abstract_cluster, managed_cluster, utils -from .abstract_cluster import Cluster -from .managed_cluster import get_agent_data -from .test_groups import det_cmd -from .utils import run_command, wait_for_command_state - -class ManagedK8sCluster(Cluster): +class ManagedK8sCluster(abstract_cluster.Cluster): def __init__(self) -> None: + sess = api_utils.user_session() config.load_kube_config() self.v1 = client.AppsV1Api() self._scale_master(up=True) # Verify we have pulled our image. # TODO this won't work if we have multiple nodes. - wait_for_command_state(run_command(0, slots=0), "TERMINATED", 300) - wait_for_command_state(run_command(0, slots=1), "TERMINATED", 300) + utils.wait_for_command_state(sess, utils.run_command(sess, 0, slots=0), "TERMINATED", 300) + utils.wait_for_command_state(sess, utils.run_command(sess, 0, slots=1), "TERMINATED", 300) def kill_master(self) -> None: self._scale_master(up=False) @@ -75,10 +72,11 @@ def _scale_master(self, up: bool) -> None: return # Wait for determined to be up. 
+ sess = api_utils.user_session() WAIT_FOR_UP = 60 for _ in range(WAIT_FOR_UP): try: - assert len(get_agent_data(conf.make_master_url())) > 0 + assert len(managed_cluster.get_agent_data(sess)) > 0 return except Exception as e: print(f"Can't reach master, retrying again {e}") @@ -88,9 +86,10 @@ def _scale_master(self, up: bool) -> None: @pytest.fixture def k8s_managed_cluster() -> Iterator[ManagedK8sCluster]: + sess = api_utils.user_session() cluster = ManagedK8sCluster() cluster._scale_master(up=True) yield cluster cluster._scale_master(up=True) - print("Master logs: ", det_cmd(["master", "logs"], check=True).stdout.decode("utf-8")) + print("Master logs: ", detproc.check_output(sess, ["det", "master", "logs"])) diff --git a/e2e_tests/tests/cluster/managed_slurm_cluster.py b/e2e_tests/tests/cluster/managed_slurm_cluster.py index 4dddb3d0717..0d04df9968b 100644 --- a/e2e_tests/tests/cluster/managed_slurm_cluster.py +++ b/e2e_tests/tests/cluster/managed_slurm_cluster.py @@ -1,21 +1,19 @@ import os +import shlex import subprocess import time -from shlex import split as sh_split from typing import Any, Iterator import pytest from tests import config as conf - -from .abstract_cluster import Cluster -from .utils import now_ts +from tests.cluster import abstract_cluster, utils # ManagedSlurmCluster is an implementation of the abstract class Cluster, to suit a slurm based # devcluster instance. It is used as part of the e2e slurm tests that require the master to be # restarted. -class ManagedSlurmCluster(Cluster): +class ManagedSlurmCluster(abstract_cluster.Cluster): def __init__(self) -> None: self.is_circleci_job = os.getenv("IS_CIRCLECI_JOB") self.dc = None @@ -33,7 +31,7 @@ def kill_master(self) -> None: if self.is_circleci_job: # Use the pre-installed determined master service when running the tests as part of a # CircleCI job. 
- subprocess.run(sh_split("sudo systemctl stop determined-master")) + subprocess.run(shlex.split("sudo systemctl stop determined-master")) else: # Use the local instance of devcluster. if self.dc: @@ -54,7 +52,7 @@ def _start_devcluster(self) -> None: if self.is_circleci_job: # Use the pre-installed determined master service when running the tests as part # of a CircleCI job. - subprocess.run(sh_split("sudo systemctl start determined-master")) + subprocess.run(shlex.split("sudo systemctl start determined-master")) else: # Use a local instance of the devcluster. master_config_file = os.getenv("MASTER_CONFIG_FILE") @@ -107,6 +105,6 @@ def managed_slurm_cluster_restarts( # Local instance of devcluster is run on port 8081 conf.MASTER_PORT = "8081" nodeid = request.node.nodeid - managed_slurm_cluster_session.log_marker(f"pytest [{now_ts()}] {nodeid} setup\n") + managed_slurm_cluster_session.log_marker(f"pytest [{utils.now_ts()}] {nodeid} setup\n") yield managed_slurm_cluster_session - managed_slurm_cluster_session.log_marker(f"pytest [{now_ts()}] {nodeid} teardown\n") + managed_slurm_cluster_session.log_marker(f"pytest [{utils.now_ts()}] {nodeid} teardown\n") diff --git a/e2e_tests/tests/cluster/test_agent.py b/e2e_tests/tests/cluster/test_agent.py index f90c6f1f91e..55e9cf57357 100644 --- a/e2e_tests/tests/cluster/test_agent.py +++ b/e2e_tests/tests/cluster/test_agent.py @@ -4,25 +4,20 @@ import pytest import determined -from determined.common import api -from determined.common.api import authentication, certs -from tests import config as conf - -from .managed_cluster import ManagedCluster +from tests import api_utils +from tests.cluster import managed_cluster # TODO: This should be marked as a cross-version test, but it can't actually be at the time of # writing, since older agent versions don't report their versions. @pytest.mark.e2e_cpu def test_agent_version() -> None: - # TODO: refactor tests to not use cli singleton auth. 
- certs.cli_cert = certs.default_load(conf.make_master_url()) - authentication.cli_auth = authentication.Authentication(conf.make_master_url()) # DET_AGENT_VERSION is available and specifies the agent version in cross-version tests; for # other tests, this evaluates to the current version. target_version = os.environ.get("DET_AGENT_VERSION") or determined.__version__ - agents = api.get(conf.make_master_url(), "api/v1/agents").json()["agents"] + sess = api_utils.user_session() + agents = sess.get("api/v1/agents").json()["agents"] assert all(agent["version"] == target_version for agent in agents) @@ -37,7 +32,7 @@ def test_agent_never_connect() -> None: @pytest.mark.managed_devcluster -def test_agent_fail_reconnect(restartable_managed_cluster: ManagedCluster) -> None: +def test_agent_fail_reconnect(restartable_managed_cluster: managed_cluster.ManagedCluster) -> None: restartable_managed_cluster.kill_proxy() for _ in range(150): # ManagedCluster agents try to reconnect for 24 * 5 seconds. TODO: eh. 
diff --git a/e2e_tests/tests/cluster/test_agent_disable.py b/e2e_tests/tests/cluster/test_agent_disable.py index 8e45453a656..60b4cf6567a 100644 --- a/e2e_tests/tests/cluster/test_agent_disable.py +++ b/e2e_tests/tests/cluster/test_agent_disable.py @@ -1,81 +1,54 @@ import contextlib -import json -import subprocess import time from typing import Any, Dict, Iterator, List, Optional, cast import pytest -from determined.common.api.bindings import experimentv1State, get_GetSlot +from determined.common import api +from determined.common.api import bindings from tests import api_utils from tests import config as conf +from tests import detproc from tests import experiment as exp - -from .test_users import logged_in_user -from .utils import assert_command_succeeded, run_zero_slot_command, wait_for_command_state +from tests.cluster import utils @pytest.mark.e2e_cpu def test_disable_and_enable_slots() -> None: - with logged_in_user(conf.ADMIN_CREDENTIALS): - command = [ - "det", - "-m", - conf.make_master_url(), - "slot", - "list", - "--json", - ] - output = subprocess.check_output(command).decode() - slots = json.loads(output) - assert len(slots) == 1 + sess = api_utils.admin_session() + + command = ["det", "slot", "list", "--json"] + slots = detproc.check_json(sess, command) + assert len(slots) == 1 + + slot_id, agent_id = slots[0]["slot_id"], slots[0]["agent_id"] + + command = ["det", "slot", "disable", agent_id, slot_id] + detproc.check_call(sess, command) + + slot = bindings.get_GetSlot(sess, agentId=agent_id, slotId=slot_id).slot + assert slot is not None + assert slot.enabled is False + + command = ["det", "slot", "enable", agent_id, slot_id] + detproc.check_call(sess, command) - slot_id, agent_id = slots[0]["slot_id"], slots[0]["agent_id"] - - command = [ - "det", - "-m", - conf.make_master_url(), - "slot", - "disable", - agent_id, - slot_id, - ] - subprocess.check_call(command) - - slot = get_GetSlot( - api_utils.determined_test_session(), agentId=agent_id, 
slotId=slot_id - ).slot - assert slot is not None - assert slot.enabled is False - - command = ["det", "-m", conf.make_master_url(), "slot", "enable", agent_id, slot_id] - subprocess.check_call(command) - - slot = get_GetSlot( - api_utils.determined_test_session(), agentId=agent_id, slotId=slot_id - ).slot - assert slot is not None - assert slot.enabled is True - - -def _fetch_slots() -> List[Dict[str, Any]]: - command = [ - "det", - "-m", - conf.make_master_url(), - "slot", - "list", - "--json", - ] - output = subprocess.check_output(command).decode() - slots = cast(List[Dict[str, str]], json.loads(output)) - return slots - - -def _wait_for_slots(min_slots_expected: int, max_ticks: int = 60 * 2) -> List[Dict[str, Any]]: + slot = bindings.get_GetSlot(sess, agentId=agent_id, slotId=slot_id).slot + assert slot is not None + assert slot.enabled is True + + +def _fetch_slots(sess: api.Session) -> List[Dict[str, Any]]: + command = ["det", "slot", "list", "--json"] + slots = detproc.check_json(sess, command) + return cast(List[Dict[str, str]], slots) + + +def _wait_for_slots( + sess: api.Session, min_slots_expected: int, max_ticks: int = 60 * 2 +) -> List[Dict[str, Any]]: for _ in range(max_ticks): - slots = _fetch_slots() + slots = _fetch_slots(sess) if len(slots) >= min_slots_expected: return slots time.sleep(1) @@ -84,21 +57,19 @@ def _wait_for_slots(min_slots_expected: int, max_ticks: int = 60 * 2) -> List[Di @contextlib.contextmanager -def _disable_agent(agent_id: str, drain: bool = False, json: bool = False) -> Iterator[str]: +def _disable_agent( + sess: api.Session, agent_id: str, drain: bool = False, json: bool = False +) -> Iterator[str]: command = ( - ["det", "-m", conf.make_master_url(), "agent", "disable"] + ["det", "agent", "disable"] + (["--drain"] if drain else []) + (["--json"] if json else []) + [agent_id] ) try: - with logged_in_user(conf.ADMIN_CREDENTIALS): - out = subprocess.check_output(command).decode() - yield out + yield detproc.check_output(sess, 
command) finally: - with logged_in_user(conf.ADMIN_CREDENTIALS): - command = ["det", "-m", conf.make_master_url(), "agent", "enable", agent_id] - subprocess.check_call(command) + detproc.check_call(sess, ["det", "agent", "enable", agent_id]) @pytest.mark.e2e_cpu @@ -108,28 +79,33 @@ def test_disable_agent_experiment_resume() -> None: Start an experiment with max_restarts=0 and ensure that being killed due to an explicit agent disable/enable (without draining) does not count toward the number of restarts. """ - slots = _fetch_slots() + admin = api_utils.admin_session() + sess = api_utils.user_session() + slots = _fetch_slots(admin) assert len(slots) == 1 agent_id = slots[0]["agent_id"] exp_id = exp.create_experiment( + sess, conf.fixtures_path("no_op/single-medium-train-step.yaml"), conf.fixtures_path("no_op"), ["--config", "max_restarts=0"], ) - exp.wait_for_experiment_state(exp_id, experimentv1State.RUNNING, max_wait_secs=300) + exp.wait_for_experiment_state( + sess, exp_id, bindings.experimentv1State.RUNNING, max_wait_secs=300 + ) - with _disable_agent(agent_id): + with _disable_agent(admin, agent_id): # Wait for the allocation to go away. for _ in range(20): - slots = _fetch_slots() + slots = _fetch_slots(admin) print(slots) if not any(s["allocation_id"] != "FREE" for s in slots): break time.sleep(1) else: pytest.fail("Experiment stayed scheduled after agent was disabled") - exp.wait_for_experiment_state(exp_id, experimentv1State.COMPLETED) + exp.wait_for_experiment_state(sess, exp_id, bindings.experimentv1State.COMPLETED) @pytest.mark.e2e_cpu @@ -139,21 +115,23 @@ def test_disable_agent_zero_slots() -> None: Start a command, disable the agent it's running on. The command should then be terminated promptly. 
""" - slots = _fetch_slots() + admin = api_utils.admin_session() + sess = api_utils.user_session() + slots = _fetch_slots(admin) assert len(slots) == 1 agent_id = slots[0]["agent_id"] - command_id = run_zero_slot_command(sleep=180) + command_id = utils.run_zero_slot_command(sess, sleep=180) # Wait for it to run. - wait_for_command_state(command_id, "RUNNING", 300) + utils.wait_for_command_state(sess, command_id, "RUNNING", 300) try: - with _disable_agent(agent_id): - wait_for_command_state(command_id, "TERMINATED", 30) + with _disable_agent(admin, agent_id): + utils.wait_for_command_state(sess, command_id, "TERMINATED", 30) finally: # Kill the command before failing so it does not linger. - command = ["det", "-m", conf.make_master_url(), "command", "kill", command_id] - subprocess.check_call(command) + command = ["det", "command", "kill", command_id] + detproc.check_call(sess, command) @pytest.mark.e2e_cpu @@ -163,62 +141,71 @@ def test_drain_agent() -> None: Start an experiment, `disable --drain` the agent once the trial is running, make sure the experiment still finishes, but the new ones won't schedule. """ + admin = api_utils.admin_session() + sess = api_utils.user_session() - slots = _fetch_slots() + slots = _fetch_slots(admin) assert len(slots) == 1 agent_id = slots[0]["agent_id"] experiment_id = exp.create_experiment( + sess, conf.fixtures_path("no_op/single-medium-train-step.yaml"), conf.fixtures_path("no_op"), ["--config", "hyperparameters.training_batch_seconds=0.15"], # Take 15 seconds. 
) - exp.wait_for_experiment_state(experiment_id, experimentv1State.RUNNING, max_wait_secs=300) - exp.wait_for_experiment_active_workload(experiment_id) - exp.wait_for_experiment_workload_progress(experiment_id) + exp.wait_for_experiment_state( + sess, experiment_id, bindings.experimentv1State.RUNNING, max_wait_secs=300 + ) + exp.wait_for_experiment_active_workload(sess, experiment_id) + exp.wait_for_experiment_workload_progress(sess, experiment_id) # Disable and quickly enable it back. - with _disable_agent(agent_id, drain=True): + with _disable_agent(admin, agent_id, drain=True): pass # Try to launch another experiment. It shouldn't get scheduled because the # slot is still busy with the first experiment. experiment_id_no_start = exp.create_experiment( + sess, conf.fixtures_path("no_op/single-medium-train-step.yaml"), conf.fixtures_path("no_op"), None, ) time.sleep(5) - exp.wait_for_experiment_state(experiment_id_no_start, experimentv1State.QUEUED) + exp.wait_for_experiment_state(sess, experiment_id_no_start, bindings.experimentv1State.QUEUED) - with _disable_agent(agent_id, drain=True): + with _disable_agent(admin, agent_id, drain=True): # Ensure the first one has finished with the correct number of workloads. - exp.wait_for_experiment_state(experiment_id, experimentv1State.COMPLETED) - trials = exp.experiment_trials(experiment_id) + exp.wait_for_experiment_state(sess, experiment_id, bindings.experimentv1State.COMPLETED) + trials = exp.experiment_trials(sess, experiment_id) assert len(trials) == 1 assert len(trials[0].workloads) == 7 # Check for 15 seconds it doesn't get scheduled into the same slot. for _ in range(15): - assert exp.experiment_state(experiment_id_no_start) == experimentv1State.QUEUED + assert ( + exp.experiment_state(sess, experiment_id_no_start) + == bindings.experimentv1State.QUEUED + ) time.sleep(1) # Ensure the slot is empty. 
- slots = _fetch_slots() + slots = _fetch_slots(admin) assert len(slots) == 1 assert slots[0]["enabled"] is False assert slots[0]["draining"] is True assert slots[0]["allocation_id"] == "FREE" # Check agent state. - command = ["det", "-m", conf.make_master_url(), "agent", "list", "--json"] - output = subprocess.check_output(command).decode() - agent_data = cast(List[Dict[str, Any]], json.loads(output))[0] + command = ["det", "agent", "list", "--json"] + output = detproc.check_json(admin, command) + agent_data = cast(List[Dict[str, Any]], output)[0] assert agent_data["id"] == agent_id assert agent_data["enabled"] is False assert agent_data["draining"] is True - exp.kill_single(experiment_id_no_start) + exp.kill_single(sess, experiment_id_no_start) @pytest.mark.e2e_cpu_2a @@ -227,32 +214,36 @@ def test_drain_agent_sched() -> None: Start an experiment, drain it. Start a second one and make sure it schedules on the second agent *before* the first one has finished. """ - slots = _wait_for_slots(2) + admin = api_utils.admin_session() + sess = api_utils.user_session() + slots = _wait_for_slots(admin, 2) assert len(slots) == 2 exp_id1 = exp.create_experiment( + sess, conf.fixtures_path("no_op/single-medium-train-step.yaml"), conf.fixtures_path("no_op"), None, ) - exp.wait_for_experiment_workload_progress(exp_id1) + exp.wait_for_experiment_workload_progress(sess, exp_id1) - slots = _fetch_slots() + slots = _fetch_slots(admin) used_slots = [s for s in slots if s["allocation_id"] != "FREE"] assert len(used_slots) == 1 agent_id1 = used_slots[0]["agent_id"] - with _disable_agent(agent_id1, drain=True): + with _disable_agent(admin, agent_id1, drain=True): exp_id2 = exp.create_experiment( + sess, conf.fixtures_path("no_op/single-medium-train-step.yaml"), conf.fixtures_path("no_op"), None, ) - exp.wait_for_experiment_state(exp_id2, experimentv1State.RUNNING) + exp.wait_for_experiment_state(sess, exp_id2, bindings.experimentv1State.RUNNING) # Wait for a state when *BOTH* 
experiments are scheduled. for _ in range(20): - slots = _fetch_slots() + slots = _fetch_slots(admin) assert len(slots) == 2 used_slots = [s for s in slots if s["allocation_id"] != "FREE"] if len(used_slots) == 2: @@ -264,24 +255,24 @@ def test_drain_agent_sched() -> None: "while the first agent was draining" ) - exp.wait_for_experiment_state(exp_id1, experimentv1State.COMPLETED) - exp.wait_for_experiment_state(exp_id2, experimentv1State.COMPLETED) + exp.wait_for_experiment_state(sess, exp_id1, bindings.experimentv1State.COMPLETED) + exp.wait_for_experiment_state(sess, exp_id2, bindings.experimentv1State.COMPLETED) - trials1 = exp.experiment_trials(exp_id1) - trials2 = exp.experiment_trials(exp_id2) + trials1 = exp.experiment_trials(sess, exp_id1) + trials2 = exp.experiment_trials(sess, exp_id2) assert len(trials1) == len(trials2) == 1 assert len(trials1[0].workloads) == len(trials2[0].workloads) == 7 -def _task_data(task_id: str) -> Optional[Dict[str, Any]]: - command = ["det", "-m", conf.make_master_url(), "task", "list", "--json"] - tasks_data: Dict[str, Dict[str, Any]] = json.loads(subprocess.check_output(command).decode()) +def _task_data(sess: api.Session, task_id: str) -> Optional[Dict[str, Any]]: + command = ["det", "task", "list", "--json"] + tasks_data: Dict[str, Dict[str, Any]] = detproc.check_json(sess, command) matches = [t for t in tasks_data.values() if t["taskId"] == task_id] return matches[0] if matches else None -def _task_agents(task_id: str) -> List[str]: - task_data = _task_data(task_id) +def _task_agents(sess: api.Session, task_id: str) -> List[str]: + task_data = _task_data(sess, task_id) if not task_data: return [] return [a for r in task_data.get("resources", []) for a in r["agentDevices"]] @@ -294,24 +285,26 @@ def test_drain_agent_sched_zeroslot() -> None: as well. Wait for them to finish, reenable both agents, and make sure next command schedules and succeeds. 
""" - slots = _wait_for_slots(2) + admin = api_utils.admin_session() + sess = api_utils.user_session() + slots = _wait_for_slots(admin, 2) assert len(slots) == 2 - command_id1 = run_zero_slot_command(60) - wait_for_command_state(command_id1, "RUNNING", 10) - agent_id1 = _task_agents(command_id1)[0] + command_id1 = utils.run_zero_slot_command(sess, 60) + utils.wait_for_command_state(sess, command_id1, "RUNNING", 10) + agent_id1 = _task_agents(sess, command_id1)[0] - with _disable_agent(agent_id1, drain=True): - command_id2 = run_zero_slot_command(60) - wait_for_command_state(command_id2, "RUNNING", 10) - agent_id2 = _task_agents(command_id2)[0] + with _disable_agent(admin, agent_id1, drain=True): + command_id2 = utils.run_zero_slot_command(sess, 60) + utils.wait_for_command_state(sess, command_id2, "RUNNING", 10) + agent_id2 = _task_agents(sess, command_id2)[0] assert agent_id1 != agent_id2 - with _disable_agent(agent_id2, drain=True): + with _disable_agent(admin, agent_id2, drain=True): for command_id in [command_id1, command_id2]: - wait_for_command_state(command_id, "TERMINATED", 60) - assert_command_succeeded(command_id) + utils.wait_for_command_state(sess, command_id, "TERMINATED", 60) + utils.assert_command_succeeded(sess, command_id) - command_id3 = run_zero_slot_command(1) - wait_for_command_state(command_id3, "TERMINATED", 60) - assert_command_succeeded(command_id3) + command_id3 = utils.run_zero_slot_command(sess, 1) + utils.wait_for_command_state(sess, command_id3, "TERMINATED", 60) + utils.assert_command_succeeded(sess, command_id3) diff --git a/e2e_tests/tests/cluster/test_agent_restart.py b/e2e_tests/tests/cluster/test_agent_restart.py index be1bee93eb1..8b730efd248 100644 --- a/e2e_tests/tests/cluster/test_agent_restart.py +++ b/e2e_tests/tests/cluster/test_agent_restart.py @@ -1,28 +1,30 @@ -import json +import pathlib import subprocess import time import uuid -from pathlib import Path from typing import Any, Dict, Iterator, Tuple import pytest -from 
determined.common.api.bindings import experimentv1State as EXP_STATE +from determined.common import api +from determined.common.api import bindings +from tests import api_utils from tests import config as conf +from tests import detproc from tests import experiment as exp - -from .managed_cluster import ManagedCluster -from .utils import assert_command_succeeded, get_command_info, run_command, wait_for_command_state +from tests.cluster import managed_cluster, utils DEVCLUSTER_CONFIG_ROOT_PATH = conf.PROJECT_ROOT_PATH.joinpath(".circleci/devcluster") DEVCLUSTER_REATTACH_OFF_CONFIG_PATH = DEVCLUSTER_CONFIG_ROOT_PATH / "double.devcluster.yaml" DEVCLUSTER_REATTACH_ON_CONFIG_PATH = DEVCLUSTER_CONFIG_ROOT_PATH / "double-reattach.devcluster.yaml" -DEVCLUSTER_LOG_PATH = Path("/tmp/devcluster") +DEVCLUSTER_LOG_PATH = pathlib.Path("/tmp/devcluster") DEVCLUSTER_MASTER_LOG_PATH = DEVCLUSTER_LOG_PATH / "master.log" @pytest.mark.managed_devcluster -def test_managed_cluster_working(restartable_managed_cluster: ManagedCluster) -> None: +def test_managed_cluster_working( + restartable_managed_cluster: managed_cluster.ManagedCluster, +) -> None: try: restartable_managed_cluster.ensure_agent_ok() restartable_managed_cluster.kill_agent() @@ -54,34 +56,38 @@ def _local_container_ids_for_command(command_id: str) -> Iterator[str]: yield container_id -def _task_list_json(master_url: str) -> Dict[str, Dict[str, Any]]: - command = ["det", "-m", master_url, "task", "list", "--json"] - tasks_data: Dict[str, Dict[str, Any]] = json.loads(subprocess.check_output(command).decode()) +def _task_list_json(sess: api.Session) -> Dict[str, Dict[str, Any]]: + command = ["det", "task", "list", "--json"] + tasks_data: Dict[str, Dict[str, Any]] = detproc.check_json(sess, command) return tasks_data @pytest.mark.managed_devcluster -def test_agent_restart_exp_container_failure(restartable_managed_cluster: ManagedCluster) -> None: +def test_agent_restart_exp_container_failure( + restartable_managed_cluster: 
managed_cluster.ManagedCluster, +) -> None: + sess = api_utils.user_session() restartable_managed_cluster.ensure_agent_ok() try: exp_id = exp.create_experiment( + sess, conf.fixtures_path("no_op/single-medium-train-step.yaml"), conf.fixtures_path("no_op"), None, ) - exp.wait_for_experiment_workload_progress(exp_id) + exp.wait_for_experiment_workload_progress(sess, exp_id) container_ids = list(_local_container_ids_for_experiment(exp_id)) if len(container_ids) != 1: pytest.fail( f"unexpected number of local containers for the experiment: {len(container_ids)}" ) # Get task id / allocation id - tasks_data = _task_list_json(conf.make_master_url()) + tasks_data = _task_list_json(sess) assert len(tasks_data) == 1 exp_task_before = list(tasks_data.values())[0] restartable_managed_cluster.kill_agent() - subprocess.run(["docker", "kill", container_ids[0]], check=True, stdout=subprocess.PIPE) + subprocess.check_call(["docker", "kill", container_ids[0]]) except Exception: restartable_managed_cluster.restart_agent() raise @@ -89,38 +95,39 @@ def test_agent_restart_exp_container_failure(restartable_managed_cluster: Manage restartable_managed_cluster.restart_agent() # As soon as the agent is back, the original allocation should be considered dead, # but the new one should be allocated. 
- state = exp.experiment_state(exp_id) + state = exp.experiment_state(sess, exp_id) # old STATE_ACTIVE got divided into three states assert state in [ - EXP_STATE.ACTIVE, - EXP_STATE.QUEUED, - EXP_STATE.PULLING, - EXP_STATE.STARTING, - EXP_STATE.RUNNING, + bindings.experimentv1State.ACTIVE, + bindings.experimentv1State.QUEUED, + bindings.experimentv1State.PULLING, + bindings.experimentv1State.STARTING, + bindings.experimentv1State.RUNNING, ] - exp.wait_for_experiment_state(exp_id, EXP_STATE.RUNNING) - tasks_data = _task_list_json(conf.make_master_url()) + exp.wait_for_experiment_state(sess, exp_id, bindings.experimentv1State.RUNNING) + tasks_data = _task_list_json(sess) assert len(tasks_data) == 1 exp_task_after = list(tasks_data.values())[0] assert exp_task_before["taskId"] == exp_task_after["taskId"] assert exp_task_before["allocationId"] != exp_task_after["allocationId"] - exp.wait_for_experiment_state(exp_id, EXP_STATE.COMPLETED) + exp.wait_for_experiment_state(sess, exp_id, bindings.experimentv1State.COMPLETED) @pytest.mark.managed_devcluster @pytest.mark.parametrize("command_duration", [10, 60]) def test_agent_restart_cmd_container_failure( - restartable_managed_cluster: ManagedCluster, command_duration: int + restartable_managed_cluster: managed_cluster.ManagedCluster, command_duration: int ) -> None: + sess = api_utils.user_session() # Launch a cmd, kill agent, wait for reconnect timeout, check it's not marked as success. # Reconnect timeout is ~25 seconds. We'd like to both test tasks that take # longer (60 seconds) and shorter (10 seconds) than that. 
restartable_managed_cluster.ensure_agent_ok() try: - command_id = run_command(command_duration) - wait_for_command_state(command_id, "RUNNING", 10) + command_id = utils.run_command(sess, command_duration) + utils.wait_for_command_state(sess, command_id, "RUNNING", 10) for _i in range(10): if len(list(_local_container_ids_for_command(command_id))) > 0: @@ -139,8 +146,8 @@ def test_agent_restart_cmd_container_failure( time.sleep(1) else: pytest.fail(f"command container didn't terminate after {_i} ticks") - wait_for_command_state(command_id, "TERMINATED", 30) - assert "success" not in get_command_info(command_id)["exitStatus"] + utils.wait_for_command_state(sess, command_id, "TERMINATED", 30) + assert "success" not in utils.get_command_info(sess, command_id)["exitStatus"] except Exception: restartable_managed_cluster.restart_agent() raise @@ -151,18 +158,19 @@ def test_agent_restart_cmd_container_failure( @pytest.mark.managed_devcluster @pytest.mark.parametrize("downtime, slots", [(0, 0), (20, 1), (60, 1)]) def test_agent_restart_recover_cmd( - restartable_managed_cluster: ManagedCluster, slots: int, downtime: int + restartable_managed_cluster: managed_cluster.ManagedCluster, slots: int, downtime: int ) -> None: + sess = api_utils.user_session() restartable_managed_cluster.ensure_agent_ok() try: - command_id = run_command(30, slots=slots) - wait_for_command_state(command_id, "RUNNING", 10) + command_id = utils.run_command(sess, 30, slots=slots) + utils.wait_for_command_state(sess, command_id, "RUNNING", 10) restartable_managed_cluster.kill_agent() time.sleep(downtime) restartable_managed_cluster.restart_agent(wait_for_amnesia=False) - wait_for_command_state(command_id, "TERMINATED", 30) + utils.wait_for_command_state(sess, command_id, "TERMINATED", 30) # If the reattach_wait <= downtime, master would have considered agent # to be dead marking the experiment fail. We can ignore such scenarios. 
@@ -171,7 +179,7 @@ def test_agent_restart_recover_cmd( # reconnected in time. reattach_wait = restartable_managed_cluster.fetch_config_reattach_wait() if reattach_wait > downtime: - assert_command_succeeded(command_id) + utils.assert_command_succeeded(sess, command_id) except Exception: restartable_managed_cluster.restart_agent() raise @@ -180,24 +188,26 @@ def test_agent_restart_recover_cmd( @pytest.mark.managed_devcluster @pytest.mark.parametrize("downtime", [0, 20, 60]) def test_agent_restart_recover_experiment( - restartable_managed_cluster: ManagedCluster, downtime: int + restartable_managed_cluster: managed_cluster.ManagedCluster, downtime: int ) -> None: + sess = api_utils.user_session() restartable_managed_cluster.ensure_agent_ok() try: exp_id = exp.create_experiment( + sess, conf.fixtures_path("no_op/single-medium-train-step.yaml"), conf.fixtures_path("no_op"), None, ) - exp.wait_for_experiment_workload_progress(exp_id) + exp.wait_for_experiment_workload_progress(sess, exp_id) if downtime >= 0: restartable_managed_cluster.kill_agent() time.sleep(downtime) restartable_managed_cluster.restart_agent(wait_for_amnesia=False) - exp.wait_for_experiment_state(exp_id, EXP_STATE.COMPLETED) - trials = exp.experiment_trials(exp_id) + exp.wait_for_experiment_state(sess, exp_id, bindings.experimentv1State.COMPLETED) + trials = exp.experiment_trials(sess, exp_id) assert len(trials) == 1 train_wls = exp.workloads_with_training(trials[0].workloads) @@ -208,21 +218,25 @@ def test_agent_restart_recover_experiment( @pytest.mark.managed_devcluster -def test_agent_reconnect_keep_experiment(restartable_managed_cluster: ManagedCluster) -> None: +def test_agent_reconnect_keep_experiment( + restartable_managed_cluster: managed_cluster.ManagedCluster, +) -> None: + sess = api_utils.user_session() try: exp_id = exp.create_experiment( + sess, conf.fixtures_path("no_op/single-medium-train-step.yaml"), conf.fixtures_path("no_op"), None, ) - 
exp.wait_for_experiment_workload_progress(exp_id) + exp.wait_for_experiment_workload_progress(sess, exp_id) restartable_managed_cluster.kill_proxy() time.sleep(1) restartable_managed_cluster.restart_proxy() - exp.wait_for_experiment_state(exp_id, EXP_STATE.COMPLETED) - trials = exp.experiment_trials(exp_id) + exp.wait_for_experiment_state(sess, exp_id, bindings.experimentv1State.COMPLETED) + trials = exp.experiment_trials(sess, exp_id) assert len(trials) == 1 train_wls = exp.workloads_with_training(trials[0].workloads) @@ -234,18 +248,21 @@ def test_agent_reconnect_keep_experiment(restartable_managed_cluster: ManagedClu @pytest.mark.managed_devcluster -def test_agent_reconnect_keep_cmd(restartable_managed_cluster: ManagedCluster) -> None: +def test_agent_reconnect_keep_cmd( + restartable_managed_cluster: managed_cluster.ManagedCluster, +) -> None: + sess = api_utils.user_session() try: - command_id = run_command(20, slots=0) - wait_for_command_state(command_id, "RUNNING", 10) + command_id = utils.run_command(sess, 20, slots=0) + utils.wait_for_command_state(sess, command_id, "RUNNING", 10) restartable_managed_cluster.kill_proxy() time.sleep(1) restartable_managed_cluster.restart_proxy() - wait_for_command_state(command_id, "TERMINATED", 30) + utils.wait_for_command_state(sess, command_id, "TERMINATED", 30) - assert_command_succeeded(command_id) + utils.assert_command_succeeded(sess, command_id) except Exception: restartable_managed_cluster.restart_proxy(wait_for_reconnect=False) restartable_managed_cluster.restart_agent() @@ -253,16 +270,19 @@ def test_agent_reconnect_keep_cmd(restartable_managed_cluster: ManagedCluster) - @pytest.mark.managed_devcluster -def test_agent_reconnect_trigger_schedule(restartable_managed_cluster: ManagedCluster) -> None: +def test_agent_reconnect_trigger_schedule( + restartable_managed_cluster: managed_cluster.ManagedCluster, +) -> None: + sess = api_utils.user_session() restartable_managed_cluster.ensure_agent_ok() try: 
restartable_managed_cluster.kill_proxy() - command_id = run_command(5, slots=1) + command_id = utils.run_command(sess, 5, slots=1) restartable_managed_cluster.restart_proxy() - wait_for_command_state(command_id, "TERMINATED", 10) + utils.wait_for_command_state(sess, command_id, "TERMINATED", 10) - assert_command_succeeded(command_id) + utils.assert_command_succeeded(sess, command_id) except Exception: restartable_managed_cluster.restart_proxy(wait_for_reconnect=False) restartable_managed_cluster.restart_agent() @@ -271,14 +291,16 @@ def test_agent_reconnect_trigger_schedule(restartable_managed_cluster: ManagedCl @pytest.mark.managed_devcluster def test_queued_experiment_restarts_with_correct_allocation_id( - restartable_managed_cluster: ManagedCluster, + restartable_managed_cluster: managed_cluster.ManagedCluster, ) -> None: + sess = api_utils.user_session() exp_id = exp.create_experiment( + sess, conf.fixtures_path("no_op/single-medium-train-step.yaml"), conf.fixtures_path("no_op"), ["--config", "resources.slots_per_trial=9999"], ) - exp.wait_for_experiment_state(exp_id, EXP_STATE.QUEUED) + exp.wait_for_experiment_state(sess, exp_id, bindings.experimentv1State.QUEUED) restartable_managed_cluster.kill_master() log_marker = str(uuid.uuid4()) diff --git a/e2e_tests/tests/cluster/test_agent_user_group.py b/e2e_tests/tests/cluster/test_agent_user_group.py index e7d2f5aad31..740f0330efa 100644 --- a/e2e_tests/tests/cluster/test_agent_user_group.py +++ b/e2e_tests/tests/cluster/test_agent_user_group.py @@ -1,13 +1,14 @@ import re +import time import uuid -from time import sleep import pytest -from determined.common.api import Session, bindings, errors -from determined.common.api.bindings import experimentv1State -from tests import api_utils, command +from determined.common import api +from determined.common.api import bindings, errors +from tests import api_utils from tests import config as conf +from tests import detproc from tests import experiment as exp GID, 
GROUPNAME = 1234, "group1234" @@ -15,14 +16,14 @@ # TODO(ilia): Add this utility to Python SDK. def _delete_workspace_and_check( - sess: Session, w: bindings.v1Workspace, max_ticks: int = 60 + sess: api.Session, w: bindings.v1Workspace, max_ticks: int = 60 ) -> None: resp = bindings.delete_DeleteWorkspace(sess, id=w.id) if resp.completed: return for _ in range(max_ticks): - sleep(1) + time.sleep(1) try: w = bindings.get_GetWorkspace(sess, id=w.id).workspace if w.state == bindings.v1WorkspaceState.DELETE_FAILED: @@ -33,32 +34,29 @@ def _delete_workspace_and_check( break -def _check_test_command(workspace_name: str) -> None: - with command.interactive_command( - "cmd", "run", "-w", workspace_name, "bash", "-c", "echo $(id -g -n):$(id -g)" - ) as cmd: - for line in cmd.stdout: - if f"{GROUPNAME}:{GID}" in line: - break - else: - raise AssertionError(f"Did not find {GROUPNAME}:{GID} in output") +def _check_test_command(sess: api.Session, workspace_name: str) -> None: + cmd = ["det", "cmd", "run", "-w", workspace_name, "bash", "-c", "echo $(id -g -n):$(id -g)"] + output = detproc.check_output(sess, cmd) + assert f"{GROUPNAME}:{GID}" in output -def _check_test_experiment(project_id: int) -> None: +def _check_test_experiment(sess: api.Session, project_id: int) -> None: # Create an experiment in that project. 
test_exp_id = exp.create_experiment( + sess, conf.fixtures_path("core_api/whoami.yaml"), conf.fixtures_path("core_api"), ["--project_id", str(project_id)], ) exp.wait_for_experiment_state( + sess, test_exp_id, - experimentv1State.COMPLETED, + bindings.experimentv1State.COMPLETED, ) - trials = exp.experiment_trials(test_exp_id) + trials = exp.experiment_trials(sess, test_exp_id) trial_id = trials[0].trial.id - trial_logs = exp.trial_logs(trial_id) + trial_logs = exp.trial_logs(sess, trial_id) marker = "id output: " for line in trial_logs: @@ -78,7 +76,7 @@ def _check_test_experiment(project_id: int) -> None: @pytest.mark.e2e_cpu def test_workspace_post_gid() -> None: - sess = api_utils.determined_test_session(admin=True) + sess = api_utils.admin_session() # Make project with workspace. resp_w = bindings.post_PostWorkspace( @@ -101,15 +99,15 @@ def test_workspace_post_gid() -> None: ) p = resp_p.project - _check_test_experiment(p.id) - _check_test_command(w.name) + _check_test_experiment(sess, p.id) + _check_test_command(sess, w.name) finally: _delete_workspace_and_check(sess, w) @pytest.mark.e2e_cpu def test_workspace_patch_gid() -> None: - sess = api_utils.determined_test_session(admin=True) + sess = api_utils.admin_session() # Make project with workspace. resp_w = bindings.post_PostWorkspace( @@ -140,8 +138,8 @@ def test_workspace_patch_gid() -> None: ) p = resp_p.project - _check_test_experiment(p.id) - _check_test_command(w.name) + _check_test_experiment(sess, p.id) + _check_test_command(sess, w.name) finally: _delete_workspace_and_check(sess, w) @@ -150,7 +148,7 @@ def test_workspace_patch_gid() -> None: def test_workspace_partial_patch() -> None: # TODO(ilia): Implement better partial patch with fieldmasks. # This may need a changes to the way python bindings generate json payloads. - sess = api_utils.determined_test_session(admin=True) + sess = api_utils.admin_session() # Make project with workspace. 
resp_w = bindings.post_PostWorkspace( diff --git a/e2e_tests/tests/cluster/test_checkpoints.py b/e2e_tests/tests/cluster/test_checkpoints.py index 88c7f251171..cdf67d9e221 100644 --- a/e2e_tests/tests/cluster/test_checkpoints.py +++ b/e2e_tests/tests/cluster/test_checkpoints.py @@ -6,31 +6,25 @@ import time from typing import Any, Dict, List, Set, Tuple -import pexpect import pytest from determined import errors from determined.common import api, storage, util -from determined.common.api import authentication, bindings, certs -from determined.common.api.bindings import checkpointv1State +from determined.common.api import bindings from tests import api_utils from tests import config as conf +from tests import detproc from tests import experiment as exp -from .test_users import det_spawn - EXPECT_TIMEOUT = 5 -def wait_for_gc_to_finish(experiment_ids: List[int]) -> None: - certs.cli_cert = certs.default_load(conf.make_master_url()) - authentication.cli_auth = authentication.Authentication(conf.make_master_url()) - +def wait_for_gc_to_finish(sess: api.Session, experiment_ids: List[int]) -> None: seen_gc_experiment_ids = set() done_gc_experiment_ids = set() # Don't wait longer than 5 minutes (as 600 half-seconds to improve our sampling resolution). for _ in range(600): - r = api.get(conf.make_master_url(), "tasks").json() + r = sess.get("tasks").json() names = [task["name"] for task in r.values()] for experiment_id in experiment_ids: @@ -52,7 +46,9 @@ def wait_for_gc_to_finish(experiment_ids: List[int]) -> None: @pytest.mark.e2e_gpu @pytest.mark.e2e_slurm_gpu def test_set_gc_policy() -> None: + sess = api_utils.user_session() exp_id = exp.run_basic_test( + sess, config_file=conf.fixtures_path("no_op/gc_checkpoints_decreasing.yaml"), model_def_file=conf.fixtures_path("no_op"), expected_trials=1, @@ -65,7 +61,7 @@ def test_set_gc_policy() -> None: # Command that uses the same gc policy as initial policy used for the experiment. 
run_command_gc_policy( - str(save_exp_best), str(save_trial_latest), str(save_trial_best), str(exp_id) + sess, str(save_exp_best), str(save_trial_latest), str(save_trial_best), str(exp_id) ) # Command that uses a diff gc policy from the initial policy used for the experiment. @@ -73,17 +69,19 @@ def test_set_gc_policy() -> None: save_trial_latest = 1 save_trial_best = 1 run_command_gc_policy( - str(save_exp_best), str(save_trial_latest), str(save_trial_best), str(exp_id) + sess, str(save_exp_best), str(save_trial_latest), str(save_trial_best), str(exp_id) ) def run_command_gc_policy( - save_exp_best: str, save_trial_latest: str, save_trial_best: str, exp_id: str + sess: api.Session, save_exp_best: str, save_trial_latest: str, save_trial_best: str, exp_id: str ) -> None: command = [ + "det", "e", "set", "gc-policy", + "--yes", "--save-experiment-best", str(save_exp_best), "--save-trial-best", @@ -92,20 +90,14 @@ def run_command_gc_policy( str(save_trial_latest), str(exp_id), ] - - child = det_spawn(command) - child.expect("Do you wish to " "proceed?", timeout=EXPECT_TIMEOUT) - child.sendline("y") - child.read() - child.wait() - child.close() - assert child.exitstatus == 0 + detproc.check_output(sess, command) -def run_command_master_checkpoint_download(uuid: str) -> None: +def run_command_master_checkpoint_download(sess: api.Session, uuid: str) -> None: with tempfile.TemporaryDirectory() as dirpath: outdir = dirpath + "/checkpoint" command = [ + "det", "checkpoint", "download", "--mode", @@ -115,13 +107,7 @@ def run_command_master_checkpoint_download(uuid: str) -> None: uuid, ] - child = det_spawn(command) - child.expect(pexpect.EOF) - child.wait() - child.close() - if child.exitstatus != 0: - print(child.before.decode("ascii"), file=sys.stderr) - assert child.exitstatus == 0 + detproc.check_call(sess, command) assert os.path.exists(outdir + "/metadata.json") @@ -137,6 +123,7 @@ def test_gc_checkpoints_lfs() -> None: @pytest.mark.e2e_cpu def 
test_delete_checkpoints() -> None: + sess = api_utils.user_session() base_conf_path = conf.fixtures_path("no_op/single-default-ckpt.yaml") config = conf.load_config(str(base_conf_path)) @@ -149,20 +136,16 @@ def test_delete_checkpoints() -> None: config["min_checkpoint_period"] = {"batches": 10} exp_id_1 = exp.run_basic_test_with_temp_config( - config, model_def_path=conf.fixtures_path("no_op"), expected_trials=1 + sess, config, model_def_path=conf.fixtures_path("no_op"), expected_trials=1 ) exp_id_2 = exp.run_basic_test_with_temp_config( - config, model_def_path=conf.fixtures_path("no_op"), expected_trials=1 + sess, config, model_def_path=conf.fixtures_path("no_op"), expected_trials=1 ) - test_session = api_utils.determined_test_session() - exp_1_checkpoints = bindings.get_GetExperimentCheckpoints( - session=test_session, id=exp_id_1 - ).checkpoints - exp_2_checkpoints = bindings.get_GetExperimentCheckpoints( - session=test_session, id=exp_id_2 - ).checkpoints + sess = api_utils.user_session() + exp_1_checkpoints = bindings.get_GetExperimentCheckpoints(session=sess, id=exp_id_1).checkpoints + exp_2_checkpoints = bindings.get_GetExperimentCheckpoints(session=sess, id=exp_id_2).checkpoints assert len(exp_1_checkpoints) > 0, f"no checkpoints found in experiment with ID:{exp_id_1}" assert len(exp_2_checkpoints) > 0, f"no checkpoints found in experiment with ID:{exp_id_2}" @@ -193,24 +176,22 @@ def test_delete_checkpoints() -> None: pytest.fail(f"checkpoint directory with uuid: {uuid} was not created.") delete_body = bindings.v1DeleteCheckpointsRequest(checkpointUuids=d_checkpoint_uuids) - bindings.delete_DeleteCheckpoints(session=test_session, body=delete_body) + bindings.delete_DeleteCheckpoints(session=sess, body=delete_body) - wait_for_gc_to_finish([exp_id_1, exp_id_2]) + wait_for_gc_to_finish(sess, [exp_id_1, exp_id_2]) for d_c in d_checkpoint_uuids: - ensure_checkpoint_deleted(test_session, d_c, storage_manager) + ensure_checkpoint_deleted(sess, d_c, 
storage_manager) -def ensure_checkpoint_deleted( - test_session: Any, d_checkpoint_uuid: Any, storage_manager: Any -) -> None: +def ensure_checkpoint_deleted(sess: Any, d_checkpoint_uuid: Any, storage_manager: Any) -> None: d_checkpoint = bindings.get_GetCheckpoint( - session=test_session, checkpointUuid=d_checkpoint_uuid + session=sess, checkpointUuid=d_checkpoint_uuid ).checkpoint if d_checkpoint is not None: assert ( - d_checkpoint.state == checkpointv1State.DELETED + d_checkpoint.state == bindings.checkpointv1State.DELETED ), f"checkpoint with uuid {d_checkpoint_uuid} does not have a deleted state" else: pytest.fail( @@ -223,6 +204,7 @@ def ensure_checkpoint_deleted( def run_gc_checkpoints_test(checkpoint_storage: Dict[str, str]) -> None: + sess = api_utils.user_session() fixtures = [ ( conf.fixtures_path("no_op/gc_checkpoints_decreasing.yaml"), @@ -271,18 +253,18 @@ def run_gc_checkpoints_test(checkpoint_storage: Dict[str, str]) -> None: with open(tf.name, "w") as f: util.yaml_safe_dump(config, f) - experiment_id = exp.create_experiment(tf.name, conf.fixtures_path("no_op")) + experiment_id = exp.create_experiment(sess, tf.name, conf.fixtures_path("no_op")) - exp.wait_for_experiment_state(experiment_id, bindings.experimentv1State.COMPLETED) + exp.wait_for_experiment_state(sess, experiment_id, bindings.experimentv1State.COMPLETED) # In some configurations, checkpoint GC will run on an auxillary machine, which may have to # be spun up still. So we'll wait for it to run. - wait_for_gc_to_finish([experiment_id]) + wait_for_gc_to_finish(sess, [experiment_id]) # Checkpoints are not marked as deleted until gc_checkpoint task starts. 
retries = 5 for retry in range(retries): - trials = exp.experiment_trials(experiment_id) + trials = exp.experiment_trials(sess, experiment_id) assert len(trials) == 1 cpoints = exp.workloads_with_checkpoint(trials[0].workloads) @@ -356,12 +338,13 @@ def run_gc_checkpoints_test(checkpoint_storage: Dict[str, str]) -> None: cs_type = checkpoint_storage["type"] if cs_type != "azure": assert type(last_checkpoint_uuid) == str - run_command_master_checkpoint_download(str(last_checkpoint_uuid)) + run_command_master_checkpoint_download(sess, str(last_checkpoint_uuid)) @pytest.mark.e2e_gpu def test_s3_no_creds(secrets: Dict[str, str]) -> None: pytest.skip("Temporarily skipping this until we find a more secure way of testing this.") + sess = api_utils.user_session() config = conf.load_config(conf.tutorials_path("mnist_pytorch/const.yaml")) config["checkpoint_storage"] = exp.s3_checkpoint_config_no_creds() config.setdefault("environment", {}) @@ -370,11 +353,12 @@ def test_s3_no_creds(secrets: Dict[str, str]) -> None: f"AWS_ACCESS_KEY_ID={secrets['INTEGRATIONS_S3_ACCESS_KEY']}", f"AWS_SECRET_ACCESS_KEY={secrets['INTEGRATIONS_S3_SECRET_KEY']}", ] - exp.run_basic_test_with_temp_config(config, conf.tutorials_path("mnist_pytorch"), 1) + exp.run_basic_test_with_temp_config(sess, config, conf.tutorials_path("mnist_pytorch"), 1) @pytest.mark.e2e_cpu def test_delete_experiment_with_no_checkpoints() -> None: + sess = api_utils.user_session() # Experiment will intentionally fail. config = conf.load_config(conf.fixtures_path("no_op/single.yaml")) config["checkpoint_storage"] = { @@ -383,18 +367,18 @@ def test_delete_experiment_with_no_checkpoints() -> None: } config["max_restarts"] = 0 exp_id = exp.run_failure_test_with_temp_config( + sess, config, conf.fixtures_path("no_op"), None, ) # Still able to delete this since it will have no checkpoints meaning no checkpoint gc task. 
- test_session = api_utils.determined_test_session() - bindings.delete_DeleteExperiment(session=test_session, experimentId=exp_id) + bindings.delete_DeleteExperiment(session=sess, experimentId=exp_id) ticks = 60 for i in range(ticks): try: - state = exp.experiment_state(exp_id) + state = exp.experiment_state(sess, exp_id) if i % 5 == 0: print(f"experiment in state {state} waiting to be deleted") time.sleep(1) @@ -406,6 +390,7 @@ def test_delete_experiment_with_no_checkpoints() -> None: @pytest.mark.e2e_cpu def test_checkpoint_partial_delete() -> None: + sess = api_utils.user_session() base_conf_path = conf.fixtures_path("no_op/single-default-ckpt.yaml") host_path = "/tmp" @@ -420,12 +405,11 @@ def test_checkpoint_partial_delete() -> None: config["min_checkpoint_period"] = {"batches": 10} exp_id = exp.run_basic_test_with_temp_config( - config, model_def_path=conf.fixtures_path("no_op"), expected_trials=1 + sess, config, model_def_path=conf.fixtures_path("no_op"), expected_trials=1 ) - test_session = api_utils.determined_test_session() checkpoints = bindings.get_GetExperimentCheckpoints( - session=test_session, + session=sess, id=exp_id, ).checkpoints completed_checkpoints = [] @@ -438,7 +422,7 @@ def test_checkpoint_partial_delete() -> None: pytest.fail("did not find two checkpoints in state completed") s = bindings.get_GetExperiment( - test_session, + sess, experimentId=exp_id, ).experiment.checkpointSize assert s is not None @@ -449,15 +433,15 @@ def assert_checkpoint_state( exp_size: int, trial_size: int, resources: Dict[str, Any], - state: checkpointv1State, + state: bindings.checkpointv1State, ) -> None: s = bindings.get_GetExperiment( - test_session, + sess, experimentId=exp_id, ).experiment.checkpointSize assert s is not None and int(s) == exp_size - trials = bindings.get_GetExperimentTrials(test_session, experimentId=exp_id).trials + trials = bindings.get_GetExperimentTrials(sess, experimentId=exp_id).trials assert len(trials) == 1 assert ( 
trials[0].totalCheckpointSize is not None @@ -465,7 +449,7 @@ def assert_checkpoint_state( ) ckpt = bindings.get_GetCheckpoint( - test_session, + sess, checkpointUuid=uuid, ).checkpoint assert ckpt.resources == resources @@ -477,8 +461,8 @@ def assert_checkpoint_state( checkpointGlobs=[], checkpointUuids=[completed_checkpoints[0].uuid], ) - bindings.post_CheckpointsRemoveFiles(test_session, body=remove_body) - wait_for_gc_to_finish([exp_id]) + bindings.post_CheckpointsRemoveFiles(sess, body=remove_body) + wait_for_gc_to_finish(sess, [exp_id]) assert_checkpoint_state( completed_checkpoints[0].uuid, @@ -503,8 +487,8 @@ def assert_checkpoint_state( checkpointGlobs=[], checkpointUuids=[completed_checkpoints[0].uuid], ) - bindings.post_CheckpointsRemoveFiles(test_session, body=remove_body) - wait_for_gc_to_finish([exp_id]) + bindings.post_CheckpointsRemoveFiles(sess, body=remove_body) + wait_for_gc_to_finish(sess, [exp_id]) assert_checkpoint_state( completed_checkpoints[0].uuid, @@ -521,8 +505,8 @@ def assert_checkpoint_state( checkpointGlobs=["**/*"], checkpointUuids=[completed_checkpoints[0].uuid], ) - bindings.post_CheckpointsRemoveFiles(test_session, body=remove_body) - wait_for_gc_to_finish([exp_id]) + bindings.post_CheckpointsRemoveFiles(sess, body=remove_body) + wait_for_gc_to_finish(sess, [exp_id]) assert_checkpoint_state( completed_checkpoints[0].uuid, @@ -544,8 +528,8 @@ def assert_checkpoint_state( checkpointGlobs=["**/*.pkl"], checkpointUuids=[completed_checkpoints[1].uuid], ) - bindings.post_CheckpointsRemoveFiles(test_session, body=remove_body) - wait_for_gc_to_finish([exp_id]) + bindings.post_CheckpointsRemoveFiles(sess, body=remove_body) + wait_for_gc_to_finish(sess, [exp_id]) assert_checkpoint_state( completed_checkpoints[1].uuid, @@ -558,10 +542,12 @@ def assert_checkpoint_state( @pytest.mark.e2e_cpu def test_fail_on_chechpoint_save() -> None: + sess = api_utils.user_session() error_log = "failed on checkpoint save" config_obj = 
conf.load_config(conf.fixtures_path("no_op/single.yaml")) config_obj["hyperparameters"]["fail_on_chechpoint_save"] = error_log exp.run_failure_test_with_temp_config( + sess, config_obj, conf.fixtures_path("no_op"), error_log, @@ -570,6 +556,7 @@ def test_fail_on_chechpoint_save() -> None: @pytest.mark.e2e_cpu def test_fail_on_preclose_chechpoint_save() -> None: + sess = api_utils.user_session() error_log = "failed on checkpoint save" config_obj = conf.load_config(conf.fixtures_path("no_op/single.yaml")) config_obj["hyperparameters"]["fail_on_chechpoint_save"] = error_log @@ -577,6 +564,7 @@ def test_fail_on_preclose_chechpoint_save() -> None: config_obj["min_validation_period"] = {"batches": 1} config_obj["max_restarts"] = 1 exp.run_failure_test_with_temp_config( + sess, config_obj, conf.fixtures_path("no_op"), error_log, diff --git a/e2e_tests/tests/cluster/test_coscheduler.py b/e2e_tests/tests/cluster/test_coscheduler.py index 8fb6d7404bd..ad73628f434 100644 --- a/e2e_tests/tests/cluster/test_coscheduler.py +++ b/e2e_tests/tests/cluster/test_coscheduler.py @@ -3,6 +3,7 @@ import pytest +from tests import api_utils from tests import config as conf from tests import experiment as exp @@ -10,6 +11,7 @@ @pytest.mark.parallel @pytest.mark.timeout(300) def test_gang_scheduling() -> None: + sess = api_utils.user_session() total_slots = os.getenv("TOTAL_SLOTS") if total_slots is None: pytest.skip("test requires a static cluster and TOTAL_SLOTS set in the environment") @@ -19,7 +21,7 @@ def test_gang_scheduling() -> None: model = conf.tutorials_path("mnist_pytorch") def submit_job() -> None: - ret_value = exp.run_basic_test_with_temp_config(config, model, 1) + ret_value = exp.run_basic_test_with_temp_config(sess, config, model, 1) print(ret_value) t = [] diff --git a/e2e_tests/tests/cluster/test_exp_continue.py b/e2e_tests/tests/cluster/test_exp_continue.py index 586d09cd2ac..b3b8a632c7f 100644 --- a/e2e_tests/tests/cluster/test_exp_continue.py +++ 
b/e2e_tests/tests/cluster/test_exp_continue.py @@ -1,6 +1,5 @@ -import subprocess import tempfile -from typing import Any, List, Tuple +from typing import List, Tuple import pytest @@ -8,44 +7,39 @@ from determined.common.api import bindings from tests import api_utils from tests import config as conf +from tests import detproc from tests import experiment as exp -# TODO move this to a package helper. -def det_cmd(cmd: List[str], **kwargs: Any) -> subprocess.CompletedProcess: - return subprocess.run( - ["det", "-m", conf.make_master_url()] + cmd, - stdout=subprocess.PIPE, - stderr=subprocess.PIPE, - **kwargs, - ) - - @pytest.mark.e2e_cpu def test_continue_config_file_cli() -> None: + sess = api_utils.user_session() exp_id = exp.create_experiment( + sess, conf.fixtures_path("no_op/single-medium-train-step.yaml"), conf.fixtures_path("no_op"), ["--config", "hyperparameters.metrics_sigma=-1.0"], ) - exp.wait_for_experiment_state(exp_id, bindings.experimentv1State.ERROR) + exp.wait_for_experiment_state(sess, exp_id, bindings.experimentv1State.ERROR) with tempfile.NamedTemporaryFile() as tf: with open(tf.name, "w") as f: util.yaml_safe_dump({"hyperparameters": {"metrics_sigma": 1.0}}, f) - det_cmd(["e", "continue", str(exp_id), "--config-file", tf.name], check=True) + detproc.check_call(sess, ["det", "e", "continue", str(exp_id), "--config-file", tf.name]) - exp.wait_for_experiment_state(exp_id, bindings.experimentv1State.COMPLETED) + exp.wait_for_experiment_state(sess, exp_id, bindings.experimentv1State.COMPLETED) @pytest.mark.e2e_cpu def test_continue_config_file_and_args_cli() -> None: + sess = api_utils.user_session() exp_id = exp.create_experiment( + sess, conf.fixtures_path("no_op/single-medium-train-step.yaml"), conf.fixtures_path("no_op"), ["--config", "hyperparameters.metrics_sigma=-1.0"], ) - exp.wait_for_experiment_state(exp_id, bindings.experimentv1State.ERROR) + exp.wait_for_experiment_state(sess, exp_id, bindings.experimentv1State.ERROR) expected_name = 
"checkThis" with tempfile.NamedTemporaryFile() as tf: @@ -54,8 +48,10 @@ def test_continue_config_file_and_args_cli() -> None: {"name": expected_name, "hyperparameters": {"metrics_sigma": -1.0}}, f ) - stdout = det_cmd( + stdout = detproc.check_output( + sess, [ + "det", "e", "continue", str(exp_id), @@ -65,13 +61,11 @@ def test_continue_config_file_and_args_cli() -> None: "hyperparameters.metrics_sigma=1.0", "-f", ], - check=True, - ).stdout + ) # Follow works till end of trial. - assert "resources exited successfully with a zero exit code" in stdout.decode("utf-8") + assert "resources exited successfully with a zero exit code" in stdout # Name is also still applied. - sess = api_utils.determined_test_session() resp = bindings.get_GetExperiment(sess, experimentId=exp_id) assert resp.experiment.config["name"] == expected_name assert ( @@ -86,71 +80,79 @@ def test_continue_config_file_and_args_cli() -> None: @pytest.mark.e2e_cpu def test_continue_fixing_broken_config() -> None: + sess = api_utils.user_session() exp_id = exp.create_experiment( + sess, conf.fixtures_path("no_op/single-medium-train-step.yaml"), conf.fixtures_path("no_op"), ["--config", "hyperparameters.metrics_sigma=-1.0"], ) - exp.wait_for_experiment_state(exp_id, bindings.experimentv1State.ERROR) + exp.wait_for_experiment_state(sess, exp_id, bindings.experimentv1State.ERROR) - det_cmd( - ["e", "continue", str(exp_id), "--config", "hyperparameters.metrics_sigma=1.0"], check=True + detproc.check_call( + sess, ["det", "e", "continue", str(exp_id), "--config", "hyperparameters.metrics_sigma=1.0"] ) - exp.wait_for_experiment_state(exp_id, bindings.experimentv1State.COMPLETED) + exp.wait_for_experiment_state(sess, exp_id, bindings.experimentv1State.COMPLETED) - trials = exp.experiment_trials(exp_id) + trials = exp.experiment_trials(sess, exp_id) assert len(trials) == 1 # Trial logs show both tasks logs with the failure message in it. 
- trial_logs = "\n".join(exp.trial_logs(trials[0].trial.id)) + trial_logs = "\n".join(exp.trial_logs(sess, trials[0].trial.id)) assert "assert 0 <= self.metrics_sigma" in trial_logs assert "resources exited successfully with a zero exit code" in trial_logs @pytest.mark.e2e_cpu def test_continue_max_restart() -> None: + sess = api_utils.user_session() exp_id = exp.create_experiment( + sess, conf.fixtures_path("no_op/single-medium-train-step.yaml"), conf.fixtures_path("no_op"), ["--config", "hyperparameters.metrics_sigma=-1.0", "--config", "max_restarts=2"], ) - exp.wait_for_experiment_state(exp_id, bindings.experimentv1State.ERROR) + exp.wait_for_experiment_state(sess, exp_id, bindings.experimentv1State.ERROR) - trials = exp.experiment_trials(exp_id) + trials = exp.experiment_trials(sess, exp_id) assert len(trials) == 1 def count_times_ran() -> int: - return "\n".join(exp.trial_logs(trials[0].trial.id)).count("assert 0 <= self.metrics_sigma") + return "\n".join(exp.trial_logs(sess, trials[0].trial.id)).count( + "assert 0 <= self.metrics_sigma" + ) def get_trial_restarts() -> int: - experiment_trials = exp.experiment_trials(exp_id) + experiment_trials = exp.experiment_trials(sess, exp_id) assert len(experiment_trials) == 1 return experiment_trials[0].trial.restarts assert count_times_ran() == 3 assert get_trial_restarts() == 2 - det_cmd(["e", "continue", str(exp_id)], check=True) - exp.wait_for_experiment_state(exp_id, bindings.experimentv1State.ERROR) + detproc.check_call(sess, ["det", "e", "continue", str(exp_id)]) + exp.wait_for_experiment_state(sess, exp_id, bindings.experimentv1State.ERROR) assert count_times_ran() == 6 assert get_trial_restarts() == 2 - det_cmd(["e", "continue", str(exp_id), "--config", "max_restarts=1"], check=True) - exp.wait_for_experiment_state(exp_id, bindings.experimentv1State.ERROR) + detproc.check_call(sess, ["det", "e", "continue", str(exp_id), "--config", "max_restarts=1"]) + exp.wait_for_experiment_state(sess, exp_id, 
bindings.experimentv1State.ERROR) assert count_times_ran() == 8 assert get_trial_restarts() == 1 @pytest.mark.e2e_cpu def test_continue_trial_time() -> None: + sess = api_utils.user_session() exp_id = exp.create_experiment( + sess, conf.fixtures_path("no_op/single-medium-train-step.yaml"), conf.fixtures_path("no_op"), ["--config", "hyperparameters.metrics_sigma=-1.0"], ) - exp.wait_for_experiment_state(exp_id, bindings.experimentv1State.ERROR) + exp.wait_for_experiment_state(sess, exp_id, bindings.experimentv1State.ERROR) - sess = api_utils.determined_test_session() + sess = api_utils.user_session() def exp_start_end_time() -> Tuple[str, str]: e = bindings.get_GetExperiment(sess, experimentId=exp_id).experiment @@ -158,7 +160,7 @@ def exp_start_end_time() -> Tuple[str, str]: return e.startTime, e.endTime def trial_start_end_time() -> Tuple[str, str]: - experiment_trials = exp.experiment_trials(exp_id) + experiment_trials = exp.experiment_trials(sess, exp_id) assert len(experiment_trials) == 1 assert experiment_trials[0].trial.endTime is not None return experiment_trials[0].trial.startTime, experiment_trials[0].trial.endTime @@ -166,8 +168,8 @@ def trial_start_end_time() -> Tuple[str, str]: exp_orig_start, exp_orig_end = exp_start_end_time() trial_orig_start, trial_orig_end = trial_start_end_time() - det_cmd(["e", "continue", str(exp_id)], check=True) - exp.wait_for_experiment_state(exp_id, bindings.experimentv1State.ERROR) + detproc.check_call(sess, ["det", "e", "continue", str(exp_id)]) + exp.wait_for_experiment_state(sess, exp_id, bindings.experimentv1State.ERROR) exp_new_start, exp_new_end = exp_start_end_time() trial_new_start, trial_new_end = trial_start_end_time() @@ -179,7 +181,7 @@ def trial_start_end_time() -> Tuple[str, str]: assert trial_new_end > trial_orig_end # Task times are updated. 
- experiment_trials = exp.experiment_trials(exp_id) + experiment_trials = exp.experiment_trials(sess, exp_id) assert len(experiment_trials) == 1 task_ids = experiment_trials[0].trial.taskIds assert task_ids is not None @@ -195,21 +197,23 @@ def trial_start_end_time() -> Tuple[str, str]: @pytest.mark.e2e_cpu def test_continue_batches() -> None: + sess = api_utils.user_session() # Experiment fails before first checkpoint. exp_id = exp.create_experiment( + sess, conf.fixtures_path("mnist_pytorch/failable.yaml"), conf.fixtures_path("mnist_pytorch"), ["--config", "environment.environment_variables=['FAIL_AT_BATCH=2']"], ) - exp.wait_for_experiment_state(exp_id, bindings.experimentv1State.ERROR) + exp.wait_for_experiment_state(sess, exp_id, bindings.experimentv1State.ERROR) - sess = api_utils.determined_test_session() - trials = exp.experiment_trials(exp_id) + sess = api_utils.user_session() + trials = exp.experiment_trials(sess, exp_id) assert len(trials) == 1 trial_id = trials[0].trial.id def assert_exited_at(batch_idx: int) -> None: - assert f"failed at this batch {batch_idx}" in "\n".join(exp.trial_logs(trial_id)) + assert f"failed at this batch {batch_idx}" in "\n".join(exp.trial_logs(sess, trial_id)) assert_exited_at(2) @@ -230,17 +234,18 @@ def get_metric_list() -> List[bindings.v1MetricsReport]: # Experiment has to start over since we didn't checkpoint. # We must invalidate all previous reported metrics. # This time experiment makes it a validation after the first checkpoint. 
- det_cmd( + detproc.check_call( + sess, [ + "det", "e", "continue", str(exp_id), "--config", "environment.environment_variables=['FAIL_AT_BATCH=5']", ], - check=True, ) - exp.wait_for_experiment_state(exp_id, bindings.experimentv1State.ERROR) + exp.wait_for_experiment_state(sess, exp_id, bindings.experimentv1State.ERROR) assert_exited_at(5) second_metric_ids = [] @@ -255,17 +260,18 @@ def get_metric_list() -> List[bindings.v1MetricsReport]: # We lose one metric since we are continuing from first checkpoint. # We correctly stop at total_batches. - det_cmd( + detproc.check_call( + sess, [ + "det", "e", "continue", str(exp_id), "--config", "environment.environment_variables=['FAIL_AT_BATCH=-1']", ], - check=True, ) - exp.wait_for_experiment_state(exp_id, bindings.experimentv1State.COMPLETED) + exp.wait_for_experiment_state(sess, exp_id, bindings.experimentv1State.COMPLETED) metrics = get_metric_list() assert len(metrics) == 8 @@ -282,15 +288,19 @@ def get_metric_list() -> List[bindings.v1MetricsReport]: @pytest.mark.e2e_cpu @pytest.mark.parametrize("continue_max_length", [499, 500]) def test_continue_workloads_searcher(continue_max_length: int) -> None: + sess = api_utils.user_session() exp_id = exp.create_experiment( + sess, conf.fixtures_path("no_op/single-medium-train-step.yaml"), conf.fixtures_path("no_op"), [], ) - exp.wait_for_experiment_state(exp_id, bindings.experimentv1State.COMPLETED) + exp.wait_for_experiment_state(sess, exp_id, bindings.experimentv1State.COMPLETED) - det_cmd( + detproc.check_call( + sess, [ + "det", "e", "continue", str(exp_id), @@ -299,24 +309,27 @@ def test_continue_workloads_searcher(continue_max_length: int) -> None: "--config", "searcher.name=single", ], - check=True, ) - exp.wait_for_experiment_state(exp_id, bindings.experimentv1State.COMPLETED) + exp.wait_for_experiment_state(sess, exp_id, bindings.experimentv1State.COMPLETED) @pytest.mark.e2e_cpu @pytest.mark.parametrize("continue_max_length", [2, 3]) def 
test_continue_pytorch_completed_searcher(continue_max_length: int) -> None: + sess = api_utils.user_session() exp_id = exp.create_experiment( + sess, conf.fixtures_path("mnist_pytorch/failable.yaml"), conf.fixtures_path("mnist_pytorch"), ["--config", "searcher.max_length.batches=3"], ) - exp.wait_for_experiment_state(exp_id, bindings.experimentv1State.COMPLETED) + exp.wait_for_experiment_state(sess, exp_id, bindings.experimentv1State.COMPLETED) # Train for less or the same time has no error. - det_cmd( + detproc.check_call( + sess, [ + "det", "e", "continue", str(exp_id), @@ -325,56 +338,59 @@ def test_continue_pytorch_completed_searcher(continue_max_length: int) -> None: "--config", "searcher.name=single", ], - check=True, ) - exp.wait_for_experiment_state(exp_id, bindings.experimentv1State.COMPLETED) + exp.wait_for_experiment_state(sess, exp_id, bindings.experimentv1State.COMPLETED) @pytest.mark.e2e_cpu @pytest.mark.parametrize("exp_config_path", ["no_op/random-short.yaml", "no_op/grid-short.yaml"]) def test_continue_hp_search_cli(exp_config_path: str) -> None: + sess = api_utils.user_session() exp_id = exp.create_experiment( + sess, conf.fixtures_path(exp_config_path), conf.fixtures_path("no_op"), [], ) - trials = exp.experiment_trials(exp_id) + trials = exp.experiment_trials(sess, exp_id) for t in trials: if t.trial.id % 2 == 0: - exp.kill_trial(t.trial.id) + exp.kill_trial(sess, t.trial.id) if exp_config_path == "no_op/random-short.yaml": assert len(trials) == 3 else: assert len(trials) == 6 - exp.wait_for_experiment_state(exp_id, bindings.experimentv1State.COMPLETED) + exp.wait_for_experiment_state(sess, exp_id, bindings.experimentv1State.COMPLETED) - det_cmd(["e", "continue", str(exp_id)], check=True) + detproc.check_call(sess, ["det", "e", "continue", str(exp_id)]) - exp.wait_for_experiment_state(exp_id, bindings.experimentv1State.COMPLETED) + exp.wait_for_experiment_state(sess, exp_id, bindings.experimentv1State.COMPLETED) - trials = 
exp.experiment_trials(exp_id) + trials = exp.experiment_trials(sess, exp_id) for t in trials: assert t.trial.state == bindings.trialv1State.COMPLETED @pytest.mark.e2e_cpu def test_continue_hp_search_single_cli() -> None: + sess = api_utils.user_session() exp_id = exp.create_experiment( + sess, conf.fixtures_path("no_op/single-medium-train-step.yaml"), conf.fixtures_path("no_op"), [], ) - trials = exp.experiment_trials(exp_id) + trials = exp.experiment_trials(sess, exp_id) assert len(trials) == 1 - exp.kill_trial(trials[0].trial.id) + exp.kill_trial(sess, trials[0].trial.id) - exp.wait_for_experiment_state(exp_id, bindings.experimentv1State.CANCELED) + exp.wait_for_experiment_state(sess, exp_id, bindings.experimentv1State.CANCELED) - det_cmd(["e", "continue", str(exp_id)], check=True) + detproc.check_call(sess, ["det", "e", "continue", str(exp_id)]) - exp.wait_for_experiment_state(exp_id, bindings.experimentv1State.COMPLETED) + exp.wait_for_experiment_state(sess, exp_id, bindings.experimentv1State.COMPLETED) - trials = exp.experiment_trials(exp_id) + trials = exp.experiment_trials(sess, exp_id) assert trials[0].trial.state == bindings.trialv1State.COMPLETED diff --git a/e2e_tests/tests/cluster/test_experiment_delete.py b/e2e_tests/tests/cluster/test_experiment_delete.py index 9a965f72704..5ffa377b2eb 100644 --- a/e2e_tests/tests/cluster/test_experiment_delete.py +++ b/e2e_tests/tests/cluster/test_experiment_delete.py @@ -1,11 +1,12 @@ import pathlib -import subprocess import time import pytest from determined.common import api +from tests import api_utils from tests import config as conf +from tests import detproc from tests import experiment as exp @@ -14,8 +15,11 @@ def test_delete_experiment_removes_tensorboard_files() -> None: """ Start a random experiment, delete the experiment and verify that TensorBoard files are deleted. 
""" + sess = api_utils.user_session() config_obj = conf.load_config(conf.fixtures_path("no_op/single-medium-train-step.yaml")) - experiment_id = exp.run_basic_test_with_temp_config(config_obj, conf.fixtures_path("no_op"), 1) + experiment_id = exp.run_basic_test_with_temp_config( + sess, config_obj, conf.fixtures_path("no_op"), 1 + ) # Check if Tensorboard files are created path = ( @@ -28,12 +32,12 @@ def test_delete_experiment_removes_tensorboard_files() -> None: assert pathlib.Path(tb_path).exists() command = ["det", "-m", conf.make_master_url(), "e", "delete", str(experiment_id), "--yes"] - subprocess.run(command, universal_newlines=True, stdout=subprocess.PIPE, check=True) + detproc.check_call(sess, command) ticks = 60 for i in range(ticks): try: - state = exp.experiment_state(experiment_id) + state = exp.experiment_state(sess, experiment_id) if i % 5 == 0: print(f"experiment in state {state} waiting to be deleted") time.sleep(1) diff --git a/e2e_tests/tests/cluster/test_groups.py b/e2e_tests/tests/cluster/test_groups.py index 8b735467a82..e910617faf9 100644 --- a/e2e_tests/tests/cluster/test_groups.py +++ b/e2e_tests/tests/cluster/test_groups.py @@ -1,152 +1,132 @@ -import json -import subprocess -from typing import Any, List +from typing import List import pytest -from tests import config as conf - -from .test_users import get_random_string, logged_in_user - - -def det_cmd(cmd: List[str], **kwargs: Any) -> subprocess.CompletedProcess: - return subprocess.run( - ["det", "-m", conf.make_master_url()] + cmd, - stdout=subprocess.PIPE, - stderr=subprocess.PIPE, - **kwargs, - ) - - -def det_cmd_json(cmd: List[str]) -> Any: - res = det_cmd(cmd, check=True) - return json.loads(res.stdout) - - -def det_cmd_expect_error(cmd: List[str], expected: str) -> None: - res = det_cmd(cmd) - assert res.returncode != 0 - assert expected.lower() in res.stderr.decode().lower() +from tests import api_utils, detproc @pytest.mark.e2e_cpu_rbac @pytest.mark.parametrize("add_users", [[], 
["admin", "determined"]]) def test_group_creation(add_users: List[str]) -> None: - with logged_in_user(conf.ADMIN_CREDENTIALS): - group_name = get_random_string() - create_group_cmd = ["user-group", "create", group_name] - for add_user in add_users: - create_group_cmd += ["--add-user", add_user] - det_cmd(create_group_cmd, check=True) - - # Can view through list. - group_list = det_cmd_json(["user-group", "list", "--json"]) + admin = api_utils.admin_session() + group_name = api_utils.get_random_string() + create_group_cmd = ["det", "user-group", "create", group_name] + for add_user in add_users: + create_group_cmd += ["--add-user", add_user] + detproc.check_call(admin, create_group_cmd) + + # Can view through list. + group_list = detproc.check_json(admin, ["det", "user-group", "list", "--json"]) + assert ( + len([group for group in group_list["groups"] if group["group"]["name"] == group_name]) == 1 + ) + + # Can view through list with userID filter. + for add_user in add_users: + group_list = detproc.check_json( + admin, ["det", "user-group", "list", "--json", "--groups-user-belongs-to", add_user] + ) assert ( len([group for group in group_list["groups"] if group["group"]["name"] == group_name]) == 1 ) - # Can view through list with userID filter. - for add_user in add_users: - group_list = det_cmd_json( - ["user-group", "list", "--json", "--groups-user-belongs-to", add_user] - ) - assert ( - len( - [ - group - for group in group_list["groups"] - if group["group"]["name"] == group_name - ] - ) - == 1 - ) - - # Can describe properly. - group_desc = det_cmd_json(["user-group", "describe", group_name, "--json"]) - assert group_desc["name"] == group_name - for add_user in add_users: - assert len([u for u in group_desc["users"] if u["username"] == add_user]) == 1 - - # Can delete. - det_cmd(["user-group", "delete", group_name, "--yes"], check=True) - det_cmd_expect_error(["user-group", "describe", group_name], "not find") + # Can describe properly. 
+ group_desc = detproc.check_json(admin, ["det", "user-group", "describe", group_name, "--json"]) + assert group_desc["name"] == group_name + for add_user in add_users: + assert len([u for u in group_desc["users"] if u["username"] == add_user]) == 1 + + # Can delete. + detproc.check_call(admin, ["det", "user-group", "delete", group_name, "--yes"]) + detproc.check_error(admin, ["det", "user-group", "describe", group_name], "not find") @pytest.mark.e2e_cpu_rbac def test_group_updates() -> None: - with logged_in_user(conf.ADMIN_CREDENTIALS): - group_name = get_random_string() - det_cmd(["user-group", "create", group_name], check=True) + admin = api_utils.admin_session() + group_name = api_utils.get_random_string() + detproc.check_call(admin, ["det", "user-group", "create", group_name]) - # Adds admin and determined to our group then remove determined. - det_cmd(["user-group", "add-user", group_name, "admin,determined"], check=True) - det_cmd(["user-group", "remove-user", group_name, "determined"], check=True) + # Adds admin and determined to our group then remove determined. + detproc.check_call(admin, ["det", "user-group", "add-user", group_name, "admin,determined"]) + detproc.check_call(admin, ["det", "user-group", "remove-user", group_name, "determined"]) - group_desc = det_cmd_json(["user-group", "describe", group_name, "--json"]) - assert group_desc["name"] == group_name - assert len(group_desc["users"]) == 1 - assert group_desc["users"][0]["username"] == "admin" + group_desc = detproc.check_json(admin, ["det", "user-group", "describe", group_name, "--json"]) + assert group_desc["name"] == group_name + assert len(group_desc["users"]) == 1 + assert group_desc["users"][0]["username"] == "admin" - # Rename our group. - new_group_name = get_random_string() - det_cmd(["user-group", "change-name", group_name, new_group_name], check=True) + # Rename our group. 
+ new_group_name = api_utils.get_random_string() + detproc.check_call(admin, ["det", "user-group", "change-name", group_name, new_group_name]) - # Old name is gone. - det_cmd_expect_error(["user-group", "describe", group_name, "--json"], "not find") + # Old name is gone. + detproc.check_error(admin, ["det", "user-group", "describe", group_name, "--json"], "not find") - # New name is here. - group_desc = det_cmd_json(["user-group", "describe", new_group_name, "--json"]) - assert group_desc["name"] == new_group_name - assert len(group_desc["users"]) == 1 - assert group_desc["users"][0]["username"] == "admin" + # New name is here. + group_desc = detproc.check_json( + admin, ["det", "user-group", "describe", new_group_name, "--json"] + ) + assert group_desc["name"] == new_group_name + assert len(group_desc["users"]) == 1 + assert group_desc["users"][0]["username"] == "admin" @pytest.mark.parametrize("offset", [0, 2]) @pytest.mark.parametrize("limit", [1, 3]) @pytest.mark.e2e_cpu_rbac def test_group_list_pagination(offset: int, limit: int) -> None: + admin = api_utils.admin_session() # Ensure we have at minimum n groups. n = 5 - group_list = det_cmd_json(["user-group", "list", "--json"])["groups"] + group_list = detproc.check_json(admin, ["det", "user-group", "list", "--json"])["groups"] needed_groups = max(n - len(group_list), 0) - with logged_in_user(conf.ADMIN_CREDENTIALS): - for _ in range(needed_groups): - det_cmd(["user-group", "create", get_random_string()], check=True) + for _ in range(needed_groups): + detproc.check_call(admin, ["det", "user-group", "create", api_utils.get_random_string()]) # Get baseline group list to compare pagination to. 
- group_list = det_cmd_json(["user-group", "list", "--json"])["groups"] + group_list = detproc.check_json(admin, ["det", "user-group", "list", "--json"])["groups"] expected = group_list[offset : offset + limit] - paged_group_list = det_cmd_json( - ["user-group", "list", "--json", "--offset", f"{offset}", "--limit", f"{limit}"] + paged_group_list = detproc.check_json( + admin, + ["det", "user-group", "list", "--json", "--offset", f"{offset}", "--limit", f"{limit}"], ) assert expected == paged_group_list["groups"] @pytest.mark.e2e_cpu_rbac def test_group_errors() -> None: - with logged_in_user(conf.ADMIN_CREDENTIALS): - fake_group = get_random_string() - group_name = get_random_string() - det_cmd(["user-group", "create", group_name], check=True) - - # Creating group with same name. - det_cmd_expect_error(["user-group", "create", group_name], "already exists") - - # Adding non existent users to groups. - fake_user = get_random_string() - det_cmd_expect_error( - ["user-group", "create", fake_group, "--add-user", fake_user], "not find" - ) - det_cmd_expect_error(["user-group", "add-user", group_name, fake_user], "not find") + admin = api_utils.admin_session() + fake_group = api_utils.get_random_string() + group_name = api_utils.get_random_string() + detproc.check_output(admin, ["det", "user-group", "create", group_name]) + + # Creating group with same name. + detproc.check_error(admin, ["det", "user-group", "create", group_name], "already exists") + + # Adding non existent users to groups. + fake_user = api_utils.get_random_string() + detproc.check_error( + admin, + ["det", "user-group", "create", fake_group, "--add-user", fake_user], + "not find", + ) + detproc.check_error(admin, ["det", "user-group", "add-user", group_name, fake_user], "not find") - # Removing a non existent user from group. - det_cmd_expect_error(["user-group", "remove-user", group_name, fake_user], "not find") + # Removing a non existent user from group. 
+ detproc.check_error( + admin, ["det", "user-group", "remove-user", group_name, fake_user], "not find" + ) - # Removing a user not in a group. - det_cmd_expect_error(["user-group", "remove-user", group_name, "admin"], "Not Found") + # Removing a user not in a group. + detproc.check_error( + admin, ["det", "user-group", "remove-user", group_name, "admin"], "not found" + ) - # Describing a non existent group. - det_cmd_expect_error(["user-group", "describe", get_random_string()], "not find") + # Describing a non existent group. + detproc.check_error( + admin, ["det", "user-group", "describe", api_utils.get_random_string()], "not find" + ) diff --git a/e2e_tests/tests/cluster/test_job_queue.py b/e2e_tests/tests/cluster/test_job_queue.py index 50cab6d50fa..cb4bfe986e7 100644 --- a/e2e_tests/tests/cluster/test_job_queue.py +++ b/e2e_tests/tests/cluster/test_job_queue.py @@ -1,36 +1,40 @@ -import subprocess -from time import sleep +import time from typing import Dict, List, Tuple import pytest +from determined.common import api + # from determined.experimental import Determined, ModelSortBy +from tests import api_utils from tests import config as conf +from tests import detproc from tests import experiment as exp @pytest.mark.e2e_cpu def test_job_queue_adjust_weight() -> None: + sess = api_utils.user_session() config = conf.tutorials_path("mnist_pytorch/const.yaml") model = conf.tutorials_path("mnist_pytorch") - exp_ids = [exp.create_experiment(config, model) for _ in range(2)] + exp_ids = [exp.create_experiment(sess, config, model) for _ in range(2)] try: - jobs = JobInfo() + jobs = JobInfo(sess) ok = jobs.refresh_until_populated() assert ok ordered_ids = jobs.get_ids() - subprocess.run(["det", "job", "update", ordered_ids[0], "--weight", "10"]) + detproc.check_call(sess, ["det", "job", "update", ordered_ids[0], "--weight", "10"]) - sleep(2) + time.sleep(2) jobs.refresh() new_weight = jobs.get_job_weight(ordered_ids[0]) assert new_weight == "10" - 
subprocess.run(["det", "job", "update-batch", f"{ordered_ids[1]}.weight=10"]) + detproc.check_call(sess, ["det", "job", "update-batch", f"{ordered_ids[1]}.weight=10"]) - sleep(2) + time.sleep(2) jobs.refresh() new_weight = jobs.get_job_weight(ordered_ids[1]) assert new_weight == "10" @@ -38,13 +42,13 @@ def test_job_queue_adjust_weight() -> None: # Avoid leaking experiments even if this test fails. # Leaking experiments can block the cluster and other tests from running other tasks # while the experiments finish. - exp.kill_experiments(exp_ids) + exp.kill_experiments(sess, exp_ids) -def get_raw_data() -> Tuple[List[Dict[str, str]], List[str]]: +def get_raw_data(sess: api.Session) -> Tuple[List[Dict[str, str]], List[str]]: data = [] ordered_ids = [] - output = subprocess.check_output(["det", "job", "list"]).decode("utf-8") + output = detproc.check_output(sess, ["det", "job", "list"]) lines = output.split("\n") keys = [line.strip() for line in lines[0].split("|")] @@ -60,18 +64,19 @@ def get_raw_data() -> Tuple[List[Dict[str, str]], List[str]]: class JobInfo: - def __init__(self) -> None: - self.values, self.ids = get_raw_data() + def __init__(self, sess: api.Session) -> None: + self.sess = sess + self.values, self.ids = get_raw_data(self.sess) def refresh(self) -> None: - self.values, self.ids = get_raw_data() + self.values, self.ids = get_raw_data(self.sess) def refresh_until_populated(self, retries: int = 10) -> bool: while retries > 0: retries -= 1 if len(self.ids) > 0: return True - sleep(0.5) + time.sleep(0.5) self.refresh() print("self.ids remains empty") return False diff --git a/e2e_tests/tests/cluster/test_log_policies.py b/e2e_tests/tests/cluster/test_log_policies.py index 2c71a59b8c0..7daebdb2c29 100644 --- a/e2e_tests/tests/cluster/test_log_policies.py +++ b/e2e_tests/tests/cluster/test_log_policies.py @@ -12,6 +12,7 @@ @pytest.mark.e2e_cpu @pytest.mark.parametrize("should_match", [True, False]) def test_log_policy_cancel_retries(should_match: bool) -> 
None: + sess = api_utils.user_session() regex = r"assert 0 <= self\.metrics_sigma" if not should_match: regex = r"(.*) this should not match (.*)" @@ -31,13 +32,13 @@ def test_log_policy_cancel_retries(should_match: bool) -> None: with tempfile.NamedTemporaryFile() as tf: with open(tf.name, "w") as f: yaml.dump(config, f) - exp_id = exp.create_experiment(tf.name, conf.fixtures_path("no_op")) + exp_id = exp.create_experiment(sess, tf.name, conf.fixtures_path("no_op")) - exp.wait_for_experiment_state(exp_id, bindings.experimentv1State.ERROR) + exp.wait_for_experiment_state(sess, exp_id, bindings.experimentv1State.ERROR) - experiment_trials = exp.experiment_trials(exp_id) + experiment_trials = exp.experiment_trials(sess, exp_id) assert len(experiment_trials) == 1 - trial_logs = "\n".join(exp.trial_logs(experiment_trials[0].trial.id)) + trial_logs = "\n".join(exp.trial_logs(sess, experiment_trials[0].trial.id)) if should_match: assert experiment_trials[0].trial.restarts == 0 @@ -50,6 +51,7 @@ def test_log_policy_cancel_retries(should_match: bool) -> None: @pytest.mark.e2e_k8s @pytest.mark.parametrize("should_match", [True, False]) def test_log_policy_exclude_node_k8s(should_match: bool) -> None: + sess = api_utils.user_session() regex = r"assert 0 <= self\.metrics_sigma" if not should_match: regex = r"\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Za-z]{2,}\b" @@ -66,7 +68,7 @@ def test_log_policy_exclude_node_k8s(should_match: bool) -> None: config["hyperparameters"]["metrics_sigma"] = -1 config["max_restarts"] = 1 - agents = bindings.get_GetAgents(api_utils.determined_test_session()).agents + agents = bindings.get_GetAgents(sess).agents assert len(agents) == 1 assert agents[0].slots is not None config["resources"] = {"slots_per_trial": len(agents[0].slots)} @@ -74,38 +76,41 @@ def test_log_policy_exclude_node_k8s(should_match: bool) -> None: with tempfile.NamedTemporaryFile() as tf: with open(tf.name, "w") as f: yaml.dump(config, f) - exp_id = exp.create_experiment(tf.name, 
conf.fixtures_path("no_op")) + exp_id = exp.create_experiment(sess, tf.name, conf.fixtures_path("no_op")) - exp.wait_for_experiment_state(exp_id, bindings.experimentv1State.RUNNING) + exp.wait_for_experiment_state(sess, exp_id, bindings.experimentv1State.RUNNING) if should_match: second_exp_id = exp.create_experiment( - conf.fixtures_path("no_op/single-one-short-step.yaml"), conf.fixtures_path("no_op") + sess, + conf.fixtures_path("no_op/single-one-short-step.yaml"), + conf.fixtures_path("no_op"), ) - exp.wait_for_experiment_state(second_exp_id, bindings.experimentv1State.COMPLETED) + exp.wait_for_experiment_state(sess, second_exp_id, bindings.experimentv1State.COMPLETED) - exp.wait_for_experiment_state(exp_id, bindings.experimentv1State.QUEUED) + exp.wait_for_experiment_state(sess, exp_id, bindings.experimentv1State.QUEUED) - experiment_trials = exp.experiment_trials(exp_id) + experiment_trials = exp.experiment_trials(sess, exp_id) assert len(experiment_trials) == 1 assert experiment_trials[0].trial.restarts == 1 - trial_logs = "\n".join(exp.trial_logs(experiment_trials[0].trial.id)) + trial_logs = "\n".join(exp.trial_logs(sess, experiment_trials[0].trial.id)) assert "therefore will not schedule on" in trial_logs - exp.kill_experiments([exp_id]) + exp.kill_experiments(sess, [exp_id]) else: - exp.wait_for_experiment_state(exp_id, bindings.experimentv1State.ERROR) + exp.wait_for_experiment_state(sess, exp_id, bindings.experimentv1State.ERROR) - experiment_trials = exp.experiment_trials(exp_id) + experiment_trials = exp.experiment_trials(sess, exp_id) assert len(experiment_trials) == 1 assert experiment_trials[0].trial.restarts == 1 - trial_logs = "\n".join(exp.trial_logs(experiment_trials[0].trial.id)) + trial_logs = "\n".join(exp.trial_logs(sess, experiment_trials[0].trial.id)) assert "therefore will not schedule on" not in trial_logs @pytest.mark.e2e_cpu @pytest.mark.parametrize("should_match", [True, False]) def 
test_log_policy_exclude_node_single_agent(should_match: bool) -> None: + sess = api_utils.user_session() regex = r"assert 0 <= self\.metrics_sigma" if not should_match: regex = r"\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Za-z]{2,}\b" @@ -122,7 +127,7 @@ def test_log_policy_exclude_node_single_agent(should_match: bool) -> None: config["hyperparameters"]["metrics_sigma"] = -1 config["max_restarts"] = 1 - agents = bindings.get_GetAgents(api_utils.determined_test_session()).agents + agents = bindings.get_GetAgents(sess).agents assert len(agents) == 1 assert agents[0].slots is not None config["resources"] = {"slots_per_trial": len(agents[0].slots)} @@ -130,26 +135,24 @@ def test_log_policy_exclude_node_single_agent(should_match: bool) -> None: with tempfile.NamedTemporaryFile() as tf: with open(tf.name, "w") as f: yaml.dump(config, f) - exp_id = exp.create_experiment(tf.name, conf.fixtures_path("no_op")) + exp_id = exp.create_experiment(sess, tf.name, conf.fixtures_path("no_op")) - exp.wait_for_experiment_state(exp_id, bindings.experimentv1State.RUNNING) + exp.wait_for_experiment_state(sess, exp_id, bindings.experimentv1State.RUNNING) - master_config = bindings.get_GetMasterConfig( - api_utils.determined_test_session(admin=True) - ).config + master_config = bindings.get_GetMasterConfig(api_utils.admin_session()).config if master_config.get("launch_error"): - exp.wait_for_experiment_state(exp_id, bindings.experimentv1State.ERROR) + exp.wait_for_experiment_state(sess, exp_id, bindings.experimentv1State.ERROR) else: if should_match: - exp.wait_for_experiment_state(exp_id, bindings.experimentv1State.QUEUED) - exp.kill_experiments([exp_id]) + exp.wait_for_experiment_state(sess, exp_id, bindings.experimentv1State.QUEUED) + exp.kill_experiments(sess, [exp_id]) else: - exp.wait_for_experiment_state(exp_id, bindings.experimentv1State.ERROR) + exp.wait_for_experiment_state(sess, exp_id, bindings.experimentv1State.ERROR) - experiment_trials = exp.experiment_trials(exp_id) + 
experiment_trials = exp.experiment_trials(sess, exp_id) assert len(experiment_trials) == 1 assert experiment_trials[0].trial.restarts == 1 - trial_logs = "\n".join(exp.trial_logs(experiment_trials[0].trial.id)) + trial_logs = "\n".join(exp.trial_logs(sess, experiment_trials[0].trial.id)) if should_match: assert "therefore will not schedule on" in trial_logs @@ -162,7 +165,8 @@ def test_log_policy_exclude_node_single_agent(should_match: bool) -> None: @pytest.mark.e2e_slurm @pytest.mark.parametrize("should_match", [True, False]) def test_log_policy_exclude_slurm(should_match: bool) -> None: - agents = bindings.get_GetAgents(api_utils.determined_test_session()).agents + sess = api_utils.user_session() + agents = bindings.get_GetAgents(sess).agents if len(agents) != 1: pytest.skip("can only be run on a single agent cluster") @@ -185,14 +189,16 @@ def test_log_policy_exclude_slurm(should_match: bool) -> None: with tempfile.NamedTemporaryFile() as tf: with open(tf.name, "w") as f: yaml.dump(config, f) - exp_id = exp.create_experiment(tf.name, conf.fixtures_path("no_op")) - exp.wait_for_experiment_state(exp_id, bindings.experimentv1State.ERROR) + exp_id = exp.create_experiment(sess, tf.name, conf.fixtures_path("no_op")) + exp.wait_for_experiment_state(sess, exp_id, bindings.experimentv1State.ERROR) - trials = exp.experiment_trials(exp_id) + trials = exp.experiment_trials(sess, exp_id) assert len(trials) == 1 assert trials[0].trial.restarts == 1 - times_ran = "\n".join(exp.trial_logs(trials[0].trial.id)).count("Validating checkpoint storage") + times_ran = "\n".join(exp.trial_logs(sess, trials[0].trial.id)).count( + "Validating checkpoint storage" + ) if should_match: assert ( times_ran == 1 diff --git a/e2e_tests/tests/cluster/test_logging.py b/e2e_tests/tests/cluster/test_logging.py index ad791ae9429..dd5f490976e 100644 --- a/e2e_tests/tests/cluster/test_logging.py +++ b/e2e_tests/tests/cluster/test_logging.py @@ -7,7 +7,8 @@ import pytest from determined.common import 
api -from determined.common.api import authentication, bindings, certs +from determined.common.api import bindings +from tests import api_utils from tests import config as conf from tests import experiment as exp @@ -26,16 +27,12 @@ @pytest.mark.e2e_pbs @pytest.mark.timeout(10 * 60) def test_trial_logs() -> None: - # TODO: refactor tests to not use cli singleton auth. - master_url = conf.make_master_url() - certs.cli_cert = certs.default_load(conf.make_master_url()) - authentication.cli_auth = authentication.Authentication(conf.make_master_url()) - session = api.Session(master_url, "determined", authentication.cli_auth, certs.cli_cert) + sess = api_utils.user_session() experiment_id = exp.run_basic_test( - conf.fixtures_path("no_op/single.yaml"), conf.fixtures_path("no_op"), 1 + sess, conf.fixtures_path("no_op/single.yaml"), conf.fixtures_path("no_op"), 1 ) - trial = exp.experiment_trials(experiment_id)[0].trial + trial = exp.experiment_trials(sess, experiment_id)[0].trial trial_id = trial.id task_id = trial.taskId assert task_id != "" @@ -45,15 +42,15 @@ def test_trial_logs() -> None: # Trial-specific APIs should work just fine. check_logs( log_regex, - functools.partial(api.trial_logs, session, trial_id), - functools.partial(bindings.get_TrialLogsFields, session, trialId=trial_id), + functools.partial(api.trial_logs, sess, trial_id), + functools.partial(bindings.get_TrialLogsFields, sess, trialId=trial_id), ) # And so should new task log APIs. 
check_logs( log_regex, - functools.partial(api.task_logs, session, task_id), - functools.partial(bindings.get_TaskLogsFields, session, taskId=task_id), + functools.partial(api.task_logs, sess, task_id), + functools.partial(bindings.get_TaskLogsFields, sess, taskId=task_id), ) @@ -73,39 +70,37 @@ def test_trial_logs() -> None: ], ) def test_task_logs(task_type: str, task_config: Dict[str, Any], log_regex: Any) -> None: - master_url = conf.make_master_url() - certs.cli_cert = certs.default_load(conf.make_master_url()) - authentication.cli_auth = authentication.Authentication(conf.make_master_url()) - session = api.Session(master_url, "determined", authentication.cli_auth, certs.cli_cert) + sess = api_utils.user_session() - rps = bindings.get_GetResourcePools(session) + rps = bindings.get_GetResourcePools(sess) assert rps.resourcePools and len(rps.resourcePools) > 0, "missing resource pool" if task_type == "tensorboard": exp_id = exp.run_basic_test( + sess, conf.fixtures_path("no_op/single.yaml"), conf.fixtures_path("no_op"), 1, ) treq = bindings.v1LaunchTensorboardRequest(config=task_config, experimentIds=[exp_id]) - task_id = bindings.post_LaunchTensorboard(session, body=treq).tensorboard.id + task_id = bindings.post_LaunchTensorboard(sess, body=treq).tensorboard.id elif task_type == "notebook": nreq = bindings.v1LaunchNotebookRequest(config=task_config) - task_id = bindings.post_LaunchNotebook(session, body=nreq).notebook.id + task_id = bindings.post_LaunchNotebook(sess, body=nreq).notebook.id elif task_type == "command": creq = bindings.v1LaunchCommandRequest(config=task_config) - task_id = bindings.post_LaunchCommand(session, body=creq).command.id + task_id = bindings.post_LaunchCommand(sess, body=creq).command.id elif task_type == "shell": sreq = bindings.v1LaunchShellRequest(config=task_config) - task_id = bindings.post_LaunchShell(session, body=sreq).shell.id + task_id = bindings.post_LaunchShell(sess, body=sreq).shell.id else: raise ValueError("unknown task 
type: {task_type}") def task_logs(**kwargs: Any) -> Iterable[Log]: - return api.task_logs(session, task_id, **kwargs) + return api.task_logs(sess, task_id, **kwargs) def task_log_fields(follow: Optional[bool] = None) -> Iterable[LogFields]: - return bindings.get_TaskLogsFields(session, taskId=task_id, follow=follow) + return bindings.get_TaskLogsFields(sess, taskId=task_id, follow=follow) try: result: Optional[Exception] = None @@ -115,8 +110,8 @@ def do_check_logs() -> None: try: check_logs( log_regex, - functools.partial(api.task_logs, session, task_id), - functools.partial(bindings.get_TaskLogsFields, session, taskId=task_id), + functools.partial(api.task_logs, sess, task_id), + functools.partial(bindings.get_TaskLogsFields, sess, taskId=task_id), ) except Exception as e: result = e @@ -139,13 +134,13 @@ def do_check_logs() -> None: finally: if task_type == "tensorboard": - bindings.post_KillTensorboard(session, tensorboardId=task_id) + bindings.post_KillTensorboard(sess, tensorboardId=task_id) elif task_type == "notebook": - bindings.post_KillNotebook(session, notebookId=task_id) + bindings.post_KillNotebook(sess, notebookId=task_id) elif task_type == "command": - bindings.post_KillCommand(session, commandId=task_id) + bindings.post_KillCommand(sess, commandId=task_id) elif task_type == "shell": - bindings.post_KillShell(session, shellId=task_id) + bindings.post_KillShell(sess, shellId=task_id) def check_logs( diff --git a/e2e_tests/tests/cluster/test_master_restart.py b/e2e_tests/tests/cluster/test_master_restart.py index cfb8884eb49..21a185ef184 100644 --- a/e2e_tests/tests/cluster/test_master_restart.py +++ b/e2e_tests/tests/cluster/test_master_restart.py @@ -1,59 +1,48 @@ import logging -import subprocess import time import docker import pytest import requests -from determined.common import constants -from determined.common.api import authentication, bindings, task_is_ready, task_logs -from determined.common.api.bindings import experimentv1State as 
EXP_STATE +from determined.common import api +from determined.common.api import bindings from tests import api_utils from tests import command as cmd from tests import config as conf +from tests import detproc from tests import experiment as exp -from tests.cluster.test_users import det_spawn +from tests.cluster import abstract_cluster, managed_cluster, managed_cluster_k8s, utils from tests.task import task -from .abstract_cluster import Cluster -from .managed_cluster import ManagedCluster, get_agent_data -from .managed_cluster_k8s import ManagedK8sCluster -from .test_groups import det_cmd, det_cmd_json -from .utils import ( - assert_command_succeeded, - run_command, - wait_for_command_state, - wait_for_task_state, -) - logger = logging.getLogger(__name__) @pytest.mark.managed_devcluster -def test_master_restart_ok(restartable_managed_cluster: ManagedCluster) -> None: +def test_master_restart_ok(restartable_managed_cluster: managed_cluster.ManagedCluster) -> None: _test_master_restart_ok(restartable_managed_cluster) restartable_managed_cluster.restart_agent(wait_for_amnesia=False) @pytest.mark.e2e_k8s -def test_master_restart_ok_k8s(k8s_managed_cluster: ManagedK8sCluster) -> None: +def test_master_restart_ok_k8s(k8s_managed_cluster: managed_cluster_k8s.ManagedK8sCluster) -> None: _test_master_restart_ok(k8s_managed_cluster) -def _test_master_restart_ok(managed_cluster: Cluster) -> None: +def _test_master_restart_ok(managed_cluster: abstract_cluster.Cluster) -> None: # - Kill master # - Restart master # - Schedule something. # Do it twice to ensure nothing gets stuck. 
+ sess = api_utils.user_session() try: for i in range(3): print("test_master_restart_ok stage %s start" % i) - cmd_ids = [run_command(1, slots) for slots in [0, 1]] + cmd_ids = [utils.run_command(sess, 1, slots) for slots in [0, 1]] for cmd_id in cmd_ids: - wait_for_command_state(cmd_id, "TERMINATED", 300) - assert_command_succeeded(cmd_id) + utils.wait_for_command_state(sess, cmd_id, "TERMINATED", 300) + utils.assert_command_succeeded(sess, cmd_id) managed_cluster.kill_master() managed_cluster.restart_master() @@ -68,7 +57,7 @@ def _test_master_restart_ok(managed_cluster: Cluster) -> None: @pytest.mark.managed_devcluster @pytest.mark.parametrize("downtime", [0, 20, 60]) def test_master_restart_reattach_recover_experiment( - restartable_managed_cluster: ManagedCluster, + restartable_managed_cluster: managed_cluster.ManagedCluster, downtime: int, ) -> None: _test_master_restart_reattach_recover_experiment(restartable_managed_cluster, downtime) @@ -77,15 +66,17 @@ def test_master_restart_reattach_recover_experiment( @pytest.mark.e2e_k8s @pytest.mark.parametrize("downtime", [0, 20, 60]) def test_master_restart_reattach_recover_experiment_k8s( - k8s_managed_cluster: ManagedK8sCluster, + k8s_managed_cluster: managed_cluster_k8s.ManagedK8sCluster, downtime: int, ) -> None: _test_master_restart_reattach_recover_experiment(k8s_managed_cluster, downtime) @pytest.mark.managed_devcluster -def test_master_restart_generic_task(managed_cluster_restarts: ManagedCluster) -> None: - test_session = api_utils.determined_test_session() +def test_master_restart_generic_task( + managed_cluster_restarts: managed_cluster.ManagedCluster, +) -> None: + sess = api_utils.user_session() with open(conf.fixtures_path("generic_task/test_config.yaml"), "r") as config_file: # Create task @@ -100,18 +91,20 @@ def test_master_restart_generic_task(managed_cluster_restarts: ManagedCluster) - inheritContext=False, noPause=False, ) - task_resp = bindings.post_CreateGenericTask(test_session, body=req) + 
task_resp = bindings.post_CreateGenericTask(sess, body=req) # Wait for task to start - task.wait_for_task_start(test_session, task_resp.taskId) + task.wait_for_task_start(sess, task_resp.taskId) managed_cluster_restarts.kill_master() managed_cluster_restarts.restart_master() - task.wait_for_task_state(test_session, task_resp.taskId, bindings.v1GenericTaskState.COMPLETED) + task.wait_for_task_state(sess, task_resp.taskId, bindings.v1GenericTaskState.COMPLETED) @pytest.mark.managed_devcluster -def test_master_restart_generic_task_pause(managed_cluster_restarts: ManagedCluster) -> None: - test_session = api_utils.determined_test_session() +def test_master_restart_generic_task_pause( + managed_cluster_restarts: managed_cluster.ManagedCluster, +) -> None: + sess = api_utils.user_session() with open(conf.fixtures_path("generic_task/test_config.yaml"), "r") as config_file: # Create task @@ -126,42 +119,46 @@ def test_master_restart_generic_task_pause(managed_cluster_restarts: ManagedClus inheritContext=False, noPause=False, ) - task_resp = bindings.post_CreateGenericTask(test_session, body=req) + task_resp = bindings.post_CreateGenericTask(sess, body=req) # Wait for task to start - task.wait_for_task_start(test_session, task_resp.taskId) + task.wait_for_task_start(sess, task_resp.taskId) # Pause task - bindings.post_PauseGenericTask(test_session, taskId=task_resp.taskId) - task.wait_for_task_state(test_session, task_resp.taskId, bindings.v1GenericTaskState.PAUSED) + bindings.post_PauseGenericTask(sess, taskId=task_resp.taskId) + task.wait_for_task_state(sess, task_resp.taskId, bindings.v1GenericTaskState.PAUSED) managed_cluster_restarts.kill_master() managed_cluster_restarts.restart_master() # Unpause task - bindings.post_UnpauseGenericTask(test_session, taskId=task_resp.taskId) - task.wait_for_task_state(test_session, task_resp.taskId, bindings.v1GenericTaskState.COMPLETED) + bindings.post_UnpauseGenericTask(sess, taskId=task_resp.taskId) + task.wait_for_task_state(sess, 
task_resp.taskId, bindings.v1GenericTaskState.COMPLETED) @pytest.mark.managed_devcluster def _test_master_restart_reattach_recover_experiment( - restartable_managed_cluster: Cluster, downtime: int + restartable_managed_cluster: abstract_cluster.Cluster, downtime: int ) -> None: + sess = api_utils.user_session() try: exp_id = exp.create_experiment( + sess, conf.fixtures_path("no_op/single-medium-train-step.yaml"), conf.fixtures_path("no_op"), None, ) # TODO(ilia): don't wait for progress. - exp.wait_for_experiment_workload_progress(exp_id) + exp.wait_for_experiment_workload_progress(sess, exp_id) if downtime >= 0: restartable_managed_cluster.kill_master() time.sleep(downtime) restartable_managed_cluster.restart_master() - exp.wait_for_experiment_state(exp_id, EXP_STATE.COMPLETED, max_wait_secs=downtime + 60) - trials = exp.experiment_trials(exp_id) + exp.wait_for_experiment_state( + sess, exp_id, bindings.experimentv1State.COMPLETED, max_wait_secs=downtime + 60 + ) + trials = exp.experiment_trials(sess, exp_id) assert len(trials) == 1 train_wls = exp.workloads_with_training(trials[0].workloads) @@ -173,16 +170,22 @@ def _test_master_restart_reattach_recover_experiment( @pytest.mark.managed_devcluster -def test_master_restart_continued_experiment(managed_cluster_restarts: ManagedCluster) -> None: +def test_master_restart_continued_experiment( + managed_cluster_restarts: managed_cluster.ManagedCluster, +) -> None: + sess = api_utils.user_session() exp_id = exp.create_experiment( + sess, conf.fixtures_path("no_op/single-medium-train-step.yaml"), conf.fixtures_path("no_op"), None, ) - exp.wait_for_experiment_state(exp_id, EXP_STATE.COMPLETED) + exp.wait_for_experiment_state(sess, exp_id, bindings.experimentv1State.COMPLETED) - det_cmd( + detproc.check_output( + sess, [ + "det", "e", "continue", str(exp_id), @@ -191,39 +194,41 @@ def test_master_restart_continued_experiment(managed_cluster_restarts: ManagedCl "--config", "searcher.name=single", ], - check=True, ) 
managed_cluster_restarts.kill_master() managed_cluster_restarts.restart_master() - exp.wait_for_experiment_state(exp_id, EXP_STATE.COMPLETED, max_wait_secs=60) + exp.wait_for_experiment_state( + sess, exp_id, bindings.experimentv1State.COMPLETED, max_wait_secs=60 + ) # We continued the latest task, not the first one. - experiment_trials = exp.experiment_trials(exp_id) + experiment_trials = exp.experiment_trials(sess, exp_id) assert len(experiment_trials) == 1 task_ids = experiment_trials[0].trial.taskIds assert task_ids is not None assert len(task_ids) == 2 - sess = api_utils.determined_test_session() - logs = task_logs(sess, task_ids[-1]) + logs = api.task_logs(sess, task_ids[-1]) assert "resources exited successfully with a zero exit code" in "".join(log.log for log in logs) @pytest.mark.managed_devcluster @pytest.mark.parametrize("wait_for_amnesia", [True, False]) def test_master_restart_error_missing_docker_container( - managed_cluster_restarts: ManagedCluster, + managed_cluster_restarts: managed_cluster.ManagedCluster, wait_for_amnesia: bool, ) -> None: + sess = api_utils.user_session() exp_id = exp.create_experiment( + sess, conf.fixtures_path("core_api/sleep.yaml"), conf.fixtures_path("core_api"), None, ) try: - exp.wait_for_experiment_state(exp_id, EXP_STATE.RUNNING) + exp.wait_for_experiment_state(sess, exp_id, bindings.experimentv1State.RUNNING) client = docker.from_env() containers = client.containers.list() @@ -238,8 +243,8 @@ def test_master_restart_error_missing_docker_container( managed_cluster_restarts.restart_master() managed_cluster_restarts.restart_agent(wait_for_amnesia=wait_for_amnesia) - exp.wait_for_experiment_state(exp_id, EXP_STATE.RUNNING) - trials = exp.experiment_trials(exp_id) + exp.wait_for_experiment_state(sess, exp_id, bindings.experimentv1State.RUNNING) + trials = exp.experiment_trials(sess, exp_id) trial_id = trials[0].trial.id expected_message = ( @@ -255,46 +260,53 @@ def test_master_restart_error_missing_docker_container( ) for 
_ in range(30): - trial_logs = "\n".join(exp.trial_logs(trial_id)) + trial_logs = "\n".join(exp.trial_logs(sess, trial_id)) if expected_message in trial_logs: break time.sleep(1) assert expected_message in trial_logs finally: - subprocess.check_call(["det", "-m", conf.make_master_url(), "e", "kill", str(exp_id)]) - exp.wait_for_experiment_state(exp_id, EXP_STATE.CANCELED, max_wait_secs=20) + detproc.check_call(sess, ["det", "e", "kill", str(exp_id)]) + exp.wait_for_experiment_state( + sess, exp_id, bindings.experimentv1State.CANCELED, max_wait_secs=20 + ) @pytest.mark.managed_devcluster -def test_master_restart_kill_works_experiment(restartable_managed_cluster: ManagedCluster) -> None: +def test_master_restart_kill_works_experiment( + restartable_managed_cluster: managed_cluster.ManagedCluster, +) -> None: _test_master_restart_kill_works(restartable_managed_cluster) @pytest.mark.e2e_k8s def test_master_restart_kill_works_k8s( - k8s_managed_cluster: ManagedK8sCluster, + k8s_managed_cluster: managed_cluster_k8s.ManagedK8sCluster, ) -> None: _test_master_restart_kill_works(k8s_managed_cluster) -def _test_master_restart_kill_works(managed_cluster_restarts: Cluster) -> None: +def _test_master_restart_kill_works(managed_cluster_restarts: abstract_cluster.Cluster) -> None: + sess = api_utils.user_session() try: exp_id = exp.create_experiment( + sess, conf.fixtures_path("no_op/single-many-long-steps.yaml"), conf.fixtures_path("no_op"), ["--config", "searcher.max_length.batches=10000", "--config", "max_restarts=0"], ) - exp.wait_for_experiment_workload_progress(exp_id) + exp.wait_for_experiment_workload_progress(sess, exp_id) managed_cluster_restarts.kill_master() time.sleep(0) managed_cluster_restarts.restart_master() - command = ["det", "-m", conf.make_master_url(), "e", "kill", str(exp_id)] - subprocess.check_call(command) + detproc.check_call(sess, ["det", "e", "kill", str(exp_id)]) - exp.wait_for_experiment_state(exp_id, EXP_STATE.CANCELED, max_wait_secs=30) + 
exp.wait_for_experiment_state( + sess, exp_id, bindings.experimentv1State.CANCELED, max_wait_secs=30 + ) managed_cluster_restarts.ensure_agent_ok() except Exception: @@ -305,7 +317,7 @@ def _test_master_restart_kill_works(managed_cluster_restarts: Cluster) -> None: @pytest.mark.managed_devcluster @pytest.mark.parametrize("downtime, slots", [(0, 0), (20, 1), (60, 1)]) def test_master_restart_cmd( - restartable_managed_cluster: ManagedCluster, slots: int, downtime: int + restartable_managed_cluster: managed_cluster.ManagedCluster, slots: int, downtime: int ) -> None: _test_master_restart_cmd(restartable_managed_cluster, slots, downtime) @@ -314,16 +326,17 @@ def test_master_restart_cmd( @pytest.mark.parametrize("slots", [0, 1]) @pytest.mark.parametrize("downtime", [0, 20, 60]) def test_master_restart_cmd_k8s( - k8s_managed_cluster: ManagedK8sCluster, slots: int, downtime: int + k8s_managed_cluster: managed_cluster_k8s.ManagedK8sCluster, slots: int, downtime: int ) -> None: _test_master_restart_cmd(k8s_managed_cluster, slots, downtime) -def _test_master_restart_cmd(managed_cluster: Cluster, slots: int, downtime: int) -> None: +def _test_master_restart_cmd( + managed_cluster: abstract_cluster.Cluster, slots: int, downtime: int +) -> None: + sess = api_utils.user_session() command = [ "det", - "-m", - conf.make_master_url(), "command", "run", "-d", @@ -331,11 +344,11 @@ def _test_master_restart_cmd(managed_cluster: Cluster, slots: int, downtime: int f"resources.slots={slots}", "echo weareready && sleep 30", ] - command_id = subprocess.check_output(command).decode().strip() + command_id = detproc.check_output(sess, command).strip() # Don't just check ready. We want to make sure the command's sleep has started. 
logs = "" - for log in task_logs(api_utils.determined_test_session(), command_id, follow=True): + for log in api.task_logs(sess, command_id, follow=True): if "weareready" in log.log: break logs += log.log @@ -347,61 +360,52 @@ def _test_master_restart_cmd(managed_cluster: Cluster, slots: int, downtime: int time.sleep(downtime) managed_cluster.restart_master() - wait_for_command_state(command_id, "TERMINATED", 30) - assert_command_succeeded(command_id) + utils.wait_for_command_state(sess, command_id, "TERMINATED", 30) + utils.assert_command_succeeded(sess, command_id) @pytest.mark.managed_devcluster @pytest.mark.parametrize("downtime", [5]) -def test_master_restart_shell(restartable_managed_cluster: ManagedCluster, downtime: int) -> None: +def test_master_restart_shell( + restartable_managed_cluster: managed_cluster.ManagedCluster, downtime: int +) -> None: _test_master_restart_shell(restartable_managed_cluster, downtime) @pytest.mark.e2e_k8s @pytest.mark.parametrize("downtime", [5]) -def test_master_restart_shell_k8s(k8s_managed_cluster: ManagedK8sCluster, downtime: int) -> None: +def test_master_restart_shell_k8s( + k8s_managed_cluster: managed_cluster_k8s.ManagedK8sCluster, downtime: int +) -> None: _test_master_restart_shell(k8s_managed_cluster, downtime) -def _test_master_restart_shell(managed_cluster: Cluster, downtime: int) -> None: - with cmd.interactive_command("shell", "start", "--detach") as shell: - task_id = shell.task_id - - assert task_id is not None +def _test_master_restart_shell(managed_cluster: abstract_cluster.Cluster, downtime: int) -> None: + sess = api_utils.user_session() + with cmd.interactive_command(sess, ["shell", "start", "--detach"]) as shell: + assert shell.task_id # Checking running is not correct here, running != ready for shells. 
- task_is_ready(api_utils.determined_test_session(), task_id) - pre_restart_queue = det_cmd_json(["job", "list", "--json"]) + api.wait_for_task_ready(sess, shell.task_id) + pre_restart_queue = detproc.check_json(sess, ["det", "job", "list", "--json"]) if downtime >= 0: managed_cluster.kill_master() time.sleep(downtime) managed_cluster.restart_master() - wait_for_task_state("shell", task_id, "RUNNING") - post_restart_queue = det_cmd_json(["job", "list", "--json"]) + utils.wait_for_task_state(sess, "shell", shell.task_id, "RUNNING") + post_restart_queue = detproc.check_json(sess, ["det", "job", "list", "--json"]) assert pre_restart_queue == post_restart_queue - child = det_spawn(["shell", "open", task_id]) - child.setecho(True) - child.expect(r".*Permanently added.+([0-9a-f-]{36}).+known hosts\.") - child.sendline("det user whoami") - child.expect("You are logged in as user \\'(.*)\\'", timeout=10) - child.sendline("exit") - child.read() - child.wait() - assert child.exitstatus == 0 - - -def _get_auth_token_for_curl() -> str: - token = authentication.TokenStore(conf.make_master_url()).get_token( - constants.DEFAULT_DETERMINED_USER - ) - assert token is not None - return token + output = detproc.check_output( + sess, ["det", "shell", "open", shell.task_id, "det", "user", "whoami"] + ) + assert "you are logged in as user" in output.lower(), output -def _check_web_url(url: str, name: str) -> None: - token = _get_auth_token_for_curl() +def _check_web_url(sess: api.Session, path: str, name: str) -> None: + token = sess.token + url = sess.master.rstrip("/") + "/" + path.lstrip("/") bad_status_codes = [] for _ in range(30): r = requests.get(url, headers={"Authorization": f"Bearer {token}"}, allow_redirects=True) @@ -420,42 +424,45 @@ def _check_web_url(url: str, name: str) -> None: pytest.fail(f"{name} {url} got error codes: {error_msg}") -def _check_notebook_url(url: str) -> None: - return _check_web_url(url, "JupyterLab") +def _check_notebook_url(sess: api.Session, path: 
str) -> None: + return _check_web_url(sess, path, "JupyterLab") -def _check_tb_url(url: str) -> None: - return _check_web_url(url, "TensorBoard") +def _check_tb_url(sess: api.Session, path: str) -> None: + return _check_web_url(sess, path, "TensorBoard") @pytest.mark.managed_devcluster @pytest.mark.parametrize("downtime", [5]) def test_master_restart_notebook( - restartable_managed_cluster: ManagedCluster, downtime: int + restartable_managed_cluster: managed_cluster.ManagedCluster, downtime: int ) -> None: return _test_master_restart_notebook(restartable_managed_cluster, downtime) @pytest.mark.e2e_k8s @pytest.mark.parametrize("downtime", [5]) -def test_master_restart_notebook_k8s(k8s_managed_cluster: ManagedK8sCluster, downtime: int) -> None: +def test_master_restart_notebook_k8s( + k8s_managed_cluster: managed_cluster_k8s.ManagedK8sCluster, downtime: int +) -> None: return _test_master_restart_notebook(k8s_managed_cluster, downtime) -def _test_master_restart_notebook(managed_cluster: Cluster, downtime: int) -> None: - with cmd.interactive_command("notebook", "start", "--detach") as notebook: +def _test_master_restart_notebook(managed_cluster: abstract_cluster.Cluster, downtime: int) -> None: + sess = api_utils.user_session() + with cmd.interactive_command(sess, ["notebook", "start", "--detach"]) as notebook: task_id = notebook.task_id assert task_id is not None - wait_for_task_state("notebook", task_id, "RUNNING") - notebook_url = f"{conf.make_master_url()}proxy/{task_id}/" - _check_notebook_url(notebook_url) + utils.wait_for_task_state(sess, "notebook", task_id, "RUNNING") + notebook_path = f"proxy/{task_id}/" + _check_notebook_url(sess, notebook_path) if downtime >= 0: managed_cluster.kill_master() time.sleep(downtime) managed_cluster.restart_master() - _check_notebook_url(notebook_url) + _check_notebook_url(sess, notebook_path) print("notebook ok") @@ -463,7 +470,7 @@ def _test_master_restart_notebook(managed_cluster: Cluster, downtime: int) -> No 
@pytest.mark.managed_devcluster @pytest.mark.parametrize("downtime", [5]) def test_master_restart_tensorboard( - restartable_managed_cluster: ManagedCluster, downtime: int + restartable_managed_cluster: managed_cluster.ManagedCluster, downtime: int ) -> None: return _test_master_restart_tensorboard(restartable_managed_cluster, downtime) @@ -471,47 +478,53 @@ def test_master_restart_tensorboard( @pytest.mark.e2e_k8s @pytest.mark.parametrize("downtime", [5]) def test_master_restart_tensorboard_k8s( - k8s_managed_cluster: ManagedK8sCluster, downtime: int + k8s_managed_cluster: managed_cluster_k8s.ManagedK8sCluster, downtime: int ) -> None: return _test_master_restart_tensorboard(k8s_managed_cluster, downtime) -def _test_master_restart_tensorboard(managed_cluster: Cluster, downtime: int) -> None: +def _test_master_restart_tensorboard( + managed_cluster: abstract_cluster.Cluster, downtime: int +) -> None: + sess = api_utils.user_session() exp_id = exp.create_experiment( + sess, conf.fixtures_path("no_op/single-default-ckpt.yaml"), conf.fixtures_path("no_op"), None, ) - exp.wait_for_experiment_state(exp_id, EXP_STATE.COMPLETED, max_wait_secs=60) + exp.wait_for_experiment_state( + sess, exp_id, bindings.experimentv1State.COMPLETED, max_wait_secs=60 + ) - with cmd.interactive_command("tensorboard", "start", "--detach", str(exp_id)) as tb: + with cmd.interactive_command(sess, ["tensorboard", "start", "--detach", str(exp_id)]) as tb: task_id = tb.task_id assert task_id is not None - wait_for_task_state("tensorboard", task_id, "RUNNING") + utils.wait_for_task_state(sess, "tensorboard", task_id, "RUNNING") - tb_url = f"{conf.make_master_url()}proxy/{task_id}/" - _check_tb_url(tb_url) + tb_path = f"proxy/{task_id}/" + _check_tb_url(sess, tb_path) if downtime >= 0: managed_cluster.kill_master() time.sleep(downtime) managed_cluster.restart_master() - _check_tb_url(tb_url) + _check_tb_url(sess, tb_path) print("tensorboard ok") @pytest.mark.managed_devcluster -def 
test_agent_devices_change(restartable_managed_cluster: ManagedCluster) -> None: - managed_cluster = restartable_managed_cluster +def test_agent_devices_change(restartable_managed_cluster: managed_cluster.ManagedCluster) -> None: + admin = api_utils.admin_session() try: - managed_cluster.kill_agent() - managed_cluster.dc.restart_stage("agent10") + restartable_managed_cluster.kill_agent() + restartable_managed_cluster.dc.restart_stage("agent10") for _i in range(5): - agent_data = get_agent_data(conf.make_master_url()) + agent_data = managed_cluster.get_agent_data(admin) if len(agent_data) == 0: # Agent has exploded and been wiped due to device mismatch, as expected. break @@ -520,30 +533,33 @@ def test_agent_devices_change(restartable_managed_cluster: ManagedCluster) -> No f"agent with different devices is still present after {_i} ticks: {agent_data}" ) finally: - managed_cluster.dc.kill_stage("agent10") - managed_cluster.restart_agent() + restartable_managed_cluster.dc.kill_stage("agent10") + restartable_managed_cluster.restart_agent() @pytest.mark.e2e_k8s -def test_master_restart_with_queued(k8s_managed_cluster: ManagedK8sCluster) -> None: - agent_data = get_agent_data(conf.make_master_url()) +def test_master_restart_with_queued( + k8s_managed_cluster: managed_cluster_k8s.ManagedK8sCluster, +) -> None: + sess = api_utils.user_session() + agent_data = managed_cluster.get_agent_data(api_utils.admin_session()) slots = sum([a["num_slots"] for a in agent_data]) - running_command_id = run_command(120, slots) - wait_for_command_state(running_command_id, "RUNNING", 30) + running_command_id = utils.run_command(sess, 120, slots) + utils.wait_for_command_state(sess, running_command_id, "RUNNING", 30) - queued_command_id = run_command(60, slots) - wait_for_command_state(queued_command_id, "QUEUED", 30) + queued_command_id = utils.run_command(sess, 60, slots) + utils.wait_for_command_state(sess, queued_command_id, "QUEUED", 30) - job_list = det_cmd_json(["job", "list", 
"--json"])["jobs"] + job_list = detproc.check_json(sess, ["det", "job", "list", "--json"])["jobs"] k8s_managed_cluster.kill_master() k8s_managed_cluster.restart_master() - post_restart_job_list = det_cmd_json(["job", "list", "--json"])["jobs"] + post_restart_job_list = detproc.check_json(sess, ["det", "job", "list", "--json"])["jobs"] assert job_list == post_restart_job_list for cmd_id in [running_command_id, queued_command_id]: - wait_for_command_state(cmd_id, "TERMINATED", 90) - assert_command_succeeded(cmd_id) + utils.wait_for_command_state(sess, cmd_id, "TERMINATED", 90) + utils.assert_command_succeeded(sess, cmd_id) diff --git a/e2e_tests/tests/cluster/test_master_restart_slurm.py b/e2e_tests/tests/cluster/test_master_restart_slurm.py index e4b03df96b3..9752b5da1fa 100644 --- a/e2e_tests/tests/cluster/test_master_restart_slurm.py +++ b/e2e_tests/tests/cluster/test_master_restart_slurm.py @@ -3,12 +3,7 @@ import pytest -from .managed_slurm_cluster import ManagedSlurmCluster -from .test_master_restart import ( - _test_master_restart_cmd, - _test_master_restart_ok, - _test_master_restart_reattach_recover_experiment, -) +from tests.cluster import managed_slurm_cluster, test_master_restart logger = logging.getLogger(__name__) @@ -16,8 +11,8 @@ # Create a pytest fixture that returns a restartable instance of ManagedSlurmCluster. @pytest.fixture def restartable_managed_slurm_cluster( - managed_slurm_cluster_restarts: ManagedSlurmCluster, -) -> Iterator[ManagedSlurmCluster]: + managed_slurm_cluster_restarts: managed_slurm_cluster.ManagedSlurmCluster, +) -> Iterator[managed_slurm_cluster.ManagedSlurmCluster]: try: yield managed_slurm_cluster_restarts except Exception: @@ -27,8 +22,10 @@ def restartable_managed_slurm_cluster( # Test to ensure master restarts successfully. 
@pytest.mark.e2e_slurm_restart -def test_master_restart_ok_slurm(managed_slurm_cluster_restarts: ManagedSlurmCluster) -> None: - _test_master_restart_ok(managed_slurm_cluster_restarts) +def test_master_restart_ok_slurm( + managed_slurm_cluster_restarts: managed_slurm_cluster.ManagedSlurmCluster, +) -> None: + test_master_restart._test_master_restart_ok(managed_slurm_cluster_restarts) # Test to ensure that master can reattach to the experiment and resume it, after the determined @@ -36,9 +33,11 @@ def test_master_restart_ok_slurm(managed_slurm_cluster_restarts: ManagedSlurmClu @pytest.mark.e2e_slurm_restart @pytest.mark.parametrize("downtime", [0, 20, 60]) def test_master_restart_reattach_recover_experiment_slurm( - managed_slurm_cluster_restarts: ManagedSlurmCluster, downtime: int + managed_slurm_cluster_restarts: managed_slurm_cluster.ManagedSlurmCluster, downtime: int ) -> None: - _test_master_restart_reattach_recover_experiment(managed_slurm_cluster_restarts, downtime) + test_master_restart._test_master_restart_reattach_recover_experiment( + managed_slurm_cluster_restarts, downtime + ) # Test to ensure that master can recover and complete a command that was in running state @@ -47,6 +46,8 @@ def test_master_restart_reattach_recover_experiment_slurm( @pytest.mark.parametrize("slots", [0, 1]) @pytest.mark.parametrize("downtime", [0, 20, 60]) def test_master_restart_cmd_slurm( - restartable_managed_slurm_cluster: ManagedSlurmCluster, slots: int, downtime: int + restartable_managed_slurm_cluster: managed_slurm_cluster.ManagedSlurmCluster, + slots: int, + downtime: int, ) -> None: - _test_master_restart_cmd(restartable_managed_slurm_cluster, slots, downtime) + test_master_restart._test_master_restart_cmd(restartable_managed_slurm_cluster, slots, downtime) diff --git a/e2e_tests/tests/cluster/test_model_registry.py b/e2e_tests/tests/cluster/test_model_registry.py index 325ff625350..82d1621126b 100644 --- a/e2e_tests/tests/cluster/test_model_registry.py +++ 
b/e2e_tests/tests/cluster/test_model_registry.py @@ -1,35 +1,32 @@ -import subprocess +import http import uuid -from http import HTTPStatus import pytest -from determined.common.api.errors import APIException -from determined.experimental import Determined, ModelSortBy +from determined.common.api import errors +from determined.experimental import client from tests import api_utils from tests import config as conf +from tests import detproc from tests import experiment as exp -from tests.cluster.test_users import log_out_user - -from .test_workspace_org import setup_workspaces +from tests.cluster import test_workspace_org @pytest.mark.e2e_cpu def test_model_registry() -> None: + sess = api_utils.user_session() exp_id = exp.run_basic_test( + sess, conf.fixtures_path("mnist_pytorch/const-pytorch11.yaml"), conf.tutorials_path("mnist_pytorch"), None, ) - - log_out_user() # Ensure that we use determined credentials. - - d = Determined(conf.make_master_url()) + d = client.Determined._from_session(sess) mnist = None objectdetect = None tform = None - existing_models = [m.name for m in d.get_models(sort_by=ModelSortBy.NAME)] + existing_models = [m.name for m in d.get_models(sort_by=client.ModelSortBy.NAME)] try: # Create a model and validate twiddling the metadata. @@ -37,12 +34,12 @@ def test_model_registry() -> None: assert mnist.metadata == {} # Attempt to create model with a duplicate name - with pytest.raises(APIException) as e: + with pytest.raises(errors.APIException) as e: duplicate_model = d.create_model( "mnist", "simple computer vision model", labels=["a", "b"] ) assert duplicate_model is None - assert e.value.status_code == HTTPStatus.CONFLICT + assert e.value.status_code == http.HTTPStatus.CONFLICT mnist.add_metadata({"testing": "metadata"}) db_model = d.get_model(mnist.name) @@ -123,6 +120,7 @@ def test_model_registry() -> None: # Run another basic test and register its checkpoint as a version as well. # Validate the latest has been updated. 
exp_id = exp.run_basic_test( + sess, conf.fixtures_path("mnist_pytorch/const-pytorch11.yaml"), conf.tutorials_path("mnist_pytorch"), None, @@ -154,7 +152,7 @@ def test_model_registry() -> None: tform = d.create_model("transformer", "all you need is attention") objectdetect = d.create_model("ac - Dc", "a test name model") - models = d.get_models(sort_by=ModelSortBy.NAME) + models = d.get_models(sort_by=client.ModelSortBy.NAME) model_names = [m.name for m in models if m.name not in existing_models] assert model_names == ["ac - Dc", "mnist", "transformer"] @@ -167,7 +165,7 @@ def test_model_registry() -> None: # Test deletion of model tform.delete() tform = None - models = d.get_models(sort_by=ModelSortBy.NAME) + models = d.get_models(sort_by=client.ModelSortBy.NAME) model_names = [m.name for m in models if m.name not in existing_models] assert model_names == ["ac - Dc", "mnist"] finally: @@ -183,45 +181,43 @@ def get_random_string() -> str: @pytest.mark.e2e_cpu def test_model_cli() -> None: + sess = api_utils.user_session() test_model_1_name = get_random_string() - master_url = conf.make_master_url() - command = ["det", "-m", master_url, "model", "create", test_model_1_name] - subprocess.run(command, check=True) - d = Determined(master_url) + command = ["det", "model", "create", test_model_1_name] + detproc.check_call(sess, command) + d = client.Determined._from_session(sess) model_1 = d.get_model(identifier=test_model_1_name) assert model_1.workspace_id == 1 # Test det model list and det model describe - command = ["det", "-m", master_url, "model", "list"] - output = str(subprocess.check_output(command)) + command = ["det", "model", "list"] + output = detproc.check_output(sess, command) assert "Workspace ID" in output and "1" in output - command = ["det", "-m", master_url, "model", "describe", test_model_1_name] - output = str(subprocess.check_output(command)) + command = ["det", "model", "describe", test_model_1_name] + output = detproc.check_output(sess, command) 
assert "Workspace ID" in output and "1" in output # add a test workspace. - admin_session = api_utils.determined_test_session(admin=True) - with setup_workspaces(admin_session) as [test_workspace]: + admin = api_utils.admin_session() + with test_workspace_org.setup_workspaces(admin) as [test_workspace]: test_workspace_name = test_workspace.name # create model in test_workspace test_model_2_name = get_random_string() command = [ "det", - "-m", - master_url, "model", "create", test_model_2_name, "-w", test_workspace_name, ] - subprocess.run(command, check=True) + detproc.check_call(sess, command) model_2 = d.get_model(identifier=test_model_2_name) assert model_2.workspace_id == test_workspace.id # Test det model list -w workspace_name and det model describe - command = ["det", "-m", master_url, "model", "list", "-w", test_workspace.name] - output = str(subprocess.check_output(command)) + command = ["det", "model", "list", "-w", test_workspace.name] + output = detproc.check_output(sess, command) assert ( "Workspace ID" in output and str(test_workspace.id) in output @@ -232,15 +228,13 @@ def test_model_cli() -> None: # move test_model_1 to test_workspace command = [ "det", - "-m", - master_url, "model", "move", test_model_1_name, "-w", test_workspace_name, ] - subprocess.run(command, check=True) + detproc.check_call(sess, command) model_1 = d.get_model(test_model_1_name) assert model_1.workspace_id == test_workspace.id # Delete test models (workspace deleted in setup_workspace) diff --git a/e2e_tests/tests/cluster/test_model_registry_rbac.py b/e2e_tests/tests/cluster/test_model_registry_rbac.py index cc4274a46f9..9cbd181e8d9 100644 --- a/e2e_tests/tests/cluster/test_model_registry_rbac.py +++ b/e2e_tests/tests/cluster/test_model_registry_rbac.py @@ -3,20 +3,13 @@ import pytest -from determined.common import util -from determined.common.api import Session, authentication, bindings, errors -from determined.common.api.bindings import experimentv1State -from 
determined.common.experimental import model -from determined.experimental import Checkpoint, Determined -from determined.experimental import client as _client +from determined.common import api, util +from determined.common.api import bindings, errors +from determined.experimental import client from tests import api_utils from tests import config as conf -from tests import experiment -from tests.cluster.test_rbac import create_workspaces_with_users -from tests.cluster.test_users import log_in_user_cli, logged_in_user - -from .test_groups import det_cmd -from .test_workspace_org import setup_workspaces +from tests import detproc, experiment +from tests.cluster import test_rbac, test_workspace_org def get_random_string() -> str: @@ -24,10 +17,10 @@ def get_random_string() -> str: def all_operations( - determined_obj: Determined, + determined_obj: client.Determined, test_workspace: bindings.v1Workspace, - checkpoint: Checkpoint, -) -> Tuple[model.Model, str]: + checkpoint: client.Checkpoint, +) -> Tuple[client.Model, str]: test_model_name = get_random_string() determined_obj.create_model(name=test_model_name, workspace_name=test_workspace.name) @@ -62,7 +55,9 @@ def all_operations( return model_obj, "Uncategorized" -def view_operations(determined_obj: Determined, model: model.Model, workspace_name: str) -> None: +def view_operations( + determined_obj: client.Determined, model: client.Model, workspace_name: str +) -> None: db_model = determined_obj.get_model(model.name) assert db_model.name == model.name models = determined_obj.get_models(workspace_names=[workspace_name]) @@ -70,7 +65,7 @@ def view_operations(determined_obj: Determined, model: model.Model, workspace_na def user_with_view_perms_test( - determined_obj: Determined, workspace_name: str, model: model.Model + determined_obj: client.Determined, workspace_name: str, model: client.Model ) -> None: view_operations(determined_obj=determined_obj, model=model, workspace_name=workspace_name) # fail edit model @@ 
-85,256 +80,202 @@ def user_with_view_perms_test( assert "access denied" in str(e.value) -def create_model_registry(session: Session, model_name: str, workspace_id: int) -> model.Model: +def create_model_registry(session: api.Session, model_name: str, workspace_id: int) -> client.Model: resp = bindings.post_PostModel( session, body=bindings.v1PostModelRequest(name=model_name, workspaceId=workspace_id), ) assert resp.model is not None - return model.Model._from_bindings(resp.model, session) + return client.Model._from_bindings(resp.model, session) def register_model_version( - creds: authentication.Credentials, model_name: str, workspace_id: int -) -> Tuple[model.Model, model.ModelVersion]: + sess: api.Session, model_name: str, workspace_id: int +) -> Tuple[client.Model, client.ModelVersion]: m = None model_version = None - session = api_utils.determined_test_session(creds) - with logged_in_user(creds): - pid = bindings.post_PostProject( - session, - body=bindings.v1PostProjectRequest(name=get_random_string(), workspaceId=workspace_id), - workspaceId=workspace_id, - ).project.id - m = create_model_registry(session, model_name, workspace_id) - experiment_id = experiment.create_experiment( - conf.fixtures_path("no_op/single.yaml"), - conf.fixtures_path("no_op"), - ["--project_id", str(pid)], - ) - experiment.wait_for_experiment_state( - experiment_id, experimentv1State.COMPLETED, credentials=creds - ) - checkpoint = bindings.get_GetExperimentCheckpoints( - id=experiment_id, session=session - ).checkpoints[0] - model_version = m.register_version(checkpoint.uuid) - assert model_version.model_version == 1 + + pid = bindings.post_PostProject( + sess, + body=bindings.v1PostProjectRequest(name=get_random_string(), workspaceId=workspace_id), + workspaceId=workspace_id, + ).project.id + m = create_model_registry(sess, model_name, workspace_id) + experiment_id = experiment.create_experiment( + sess, + conf.fixtures_path("no_op/single.yaml"), + conf.fixtures_path("no_op"), + 
["--project_id", str(pid)], + ) + experiment.wait_for_experiment_state(sess, experiment_id, bindings.experimentv1State.COMPLETED) + checkpoint = bindings.get_GetExperimentCheckpoints(sess, id=experiment_id).checkpoints[0] + model_version = m.register_version(checkpoint.uuid) + assert model_version.model_version == 1 + return m, model_version @pytest.mark.test_model_registry_rbac def test_model_registry_rbac() -> None: - log_in_user_cli(conf.ADMIN_CREDENTIALS) - test_user_editor_creds = api_utils.create_test_user() - test_user_workspace_admin_creds = api_utils.create_test_user() - test_user_viewer_creds = api_utils.create_test_user() - test_user_with_no_perms_creds = api_utils.create_test_user() - test_user_model_registry_viewer_creds = api_utils.create_test_user() - admin_session = api_utils.determined_test_session(admin=True) - with setup_workspaces(admin_session) as [test_workspace]: - with logged_in_user(conf.ADMIN_CREDENTIALS): - # Assign editor role to user in Uncategorized and test_workspace. - det_cmd( + admin = api_utils.admin_session() + editor, _ = api_utils.create_test_user() + wksp_admin, _ = api_utils.create_test_user() + viewer, _ = api_utils.create_test_user() + noperms, _ = api_utils.create_test_user() + model_registry_viewer, _ = api_utils.create_test_user() + + with test_workspace_org.setup_workspaces(admin) as [test_workspace]: + for wksp in ["Uncategorized", test_workspace.name]: + # Assign editor role. + detproc.check_call( + admin, [ + "det", "rbac", "assign-role", "Editor", "--username-to-assign", - test_user_editor_creds.username, + editor.username, "--workspace-name", - "Uncategorized", + wksp, ], - check=True, ) - det_cmd( - [ - "rbac", - "assign-role", - "Editor", - "--username-to-assign", - test_user_editor_creds.username, - "--workspace-name", - test_workspace.name, - ], - check=True, - ) - - # Assign workspace admin to user in Uncategorized and test_workspace. 
- det_cmd( - [ - "rbac", - "assign-role", - "WorkspaceAdmin", - "--username-to-assign", - test_user_workspace_admin_creds.username, - "--workspace-name", - "Uncategorized", - ], - check=True, - ) - det_cmd( + # Assign workspace admin role. + detproc.check_call( + admin, [ + "det", "rbac", "assign-role", "WorkspaceAdmin", "--username-to-assign", - test_user_workspace_admin_creds.username, + wksp_admin.username, "--workspace-name", - test_workspace.name, + wksp, ], - check=True, ) - # Assign viewer to user in Uncategorized and test_workspace. - det_cmd( + # Assign viewer role. + detproc.check_call( + admin, [ + "det", "rbac", "assign-role", "Viewer", "--username-to-assign", - test_user_viewer_creds.username, + viewer.username, "--workspace-name", - "Uncategorized", + wksp, ], - check=True, - ) - det_cmd( - [ - "rbac", - "assign-role", - "Viewer", - "--username-to-assign", - test_user_viewer_creds.username, - "--workspace-name", - test_workspace.name, - ], - check=True, ) - # Assign model registry viewer to user in Uncategorized and test_workspace. - det_cmd( + # Assign model registry viewer role. + detproc.check_call( + admin, [ + "det", "rbac", "assign-role", "ModelRegistryViewer", "--username-to-assign", - test_user_model_registry_viewer_creds.username, + model_registry_viewer.username, "--workspace-name", - "Uncategorized", + wksp, ], - check=True, - ) - det_cmd( - [ - "rbac", - "assign-role", - "ModelRegistryViewer", - "--username-to-assign", - test_user_model_registry_viewer_creds.username, - "--workspace-name", - test_workspace.name, - ], - check=True, - ) - master_url = conf.make_master_url() - - with logged_in_user(test_user_editor_creds): - # need to get a new determined obj everytime a new user is logged in. - # Same pattern is followed below. 
- d = Determined(master_url) - with open(conf.fixtures_path("no_op/single-one-short-step.yaml")) as f: - config = util.yaml_safe_load(f) - exp = d.create_experiment(config, conf.fixtures_path("no_op")) - # wait for exp state to be completed - assert exp.wait() == _client.ExperimentState.COMPLETED - checkpoint = d.get_experiment(exp.id).top_checkpoint() - # need to get a new determined obj everytime a new user is logged in. - # Same pattern is followed below. - model_1, current_model_workspace = all_operations( - determined_obj=d, test_workspace=test_workspace, checkpoint=checkpoint ) - with logged_in_user(test_user_model_registry_viewer_creds): - d = Determined(master_url) - user_with_view_perms_test( - determined_obj=d, workspace_name=current_model_workspace, model=model_1 - ) + # Test editor user. + d = client.Determined._from_session(editor) + with open(conf.fixtures_path("no_op/single-one-short-step.yaml")) as f: + config = util.yaml_safe_load(f) + exp = d.create_experiment(config, conf.fixtures_path("no_op")) + # wait for exp state to be completed + assert exp.wait() == client.ExperimentState.COMPLETED + checkpoint = d.get_experiment(exp.id).top_checkpoint() + # need to get a new determined obj everytime a new user is logged in. + # Same pattern is followed below. + model_1, current_model_workspace = all_operations( + determined_obj=d, test_workspace=test_workspace, checkpoint=checkpoint + ) - with logged_in_user(test_user_viewer_creds): - d = Determined(master_url) - user_with_view_perms_test( - determined_obj=d, workspace_name=current_model_workspace, model=model_1 - ) + # Test model_registry_viewer user. + d = client.Determined._from_session(model_registry_viewer) + user_with_view_perms_test( + determined_obj=d, workspace_name=current_model_workspace, model=model_1 + ) + + # Test viewer user. 
+ d = client.Determined._from_session(viewer) + user_with_view_perms_test( + determined_obj=d, workspace_name=current_model_workspace, model=model_1 + ) - with logged_in_user(test_user_with_no_perms_creds): - d = Determined(master_url) - with pytest.raises(Exception) as e: - d.get_models() - assert "doesn't have view permissions" in str(e.value) + # Test noperms user. + d = client.Determined._from_session(noperms) + with pytest.raises(Exception) as e: + d.get_models() + assert "doesn't have view permissions" in str(e.value) # Unassign view permissions to a certain workspace. # List should return models only in workspaces with permissions. - with logged_in_user(conf.ADMIN_CREDENTIALS): - det_cmd( - [ - "rbac", - "unassign-role", - "ModelRegistryViewer", - "--username-to-assign", - test_user_model_registry_viewer_creds.username, - "--workspace-name", - test_workspace.name, - ], - check=True, - ) - with logged_in_user(test_user_model_registry_viewer_creds): - d = Determined(master_url) - models = d.get_models() - assert test_workspace.id not in [m.workspace_id for m in models] - - with logged_in_user(test_user_editor_creds): - d = Determined(master_url) - model = d.get_model(model_1.name) - model.delete() - - with logged_in_user(test_user_workspace_admin_creds): - d = Determined(master_url) - checkpoint = d.get_experiment(exp.id).top_checkpoint() - model_2, current_model_workspace = all_operations( - determined_obj=d, test_workspace=test_workspace, checkpoint=checkpoint - ) + detproc.check_call( + admin, + [ + "det", + "rbac", + "unassign-role", + "ModelRegistryViewer", + "--username-to-assign", + model_registry_viewer.username, + "--workspace-name", + test_workspace.name, + ], + ) + + d = client.Determined._from_session(model_registry_viewer) + models = d.get_models() + assert test_workspace.id not in [m.workspace_id for m in models] + + d = client.Determined._from_session(editor) + model = d.get_model(model_1.name) + model.delete() + + d = 
client.Determined._from_session(wksp_admin) + checkpoint = d.get_experiment(exp.id).top_checkpoint() + model_2, current_model_workspace = all_operations( + determined_obj=d, test_workspace=test_workspace, checkpoint=checkpoint + ) # Remove workspace admin role for this user from test_workspace. - with logged_in_user(conf.ADMIN_CREDENTIALS): - det_cmd( - [ - "rbac", - "unassign-role", - "WorkspaceAdmin", - "--username-to-assign", - test_user_workspace_admin_creds.username, - "--workspace-name", - test_workspace.name, - ], - check=True, - ) + detproc.check_call( + admin, + [ + "det", + "rbac", + "unassign-role", + "WorkspaceAdmin", + "--username-to-assign", + wksp_admin.username, + "--workspace-name", + test_workspace.name, + ], + ) - with logged_in_user(test_user_workspace_admin_creds): - d = Determined(master_url) - model = d.get_model(model_2.name) - assert current_model_workspace == "Uncategorized" - # move model to test_workspace should fail. - with pytest.raises(errors.ForbiddenException) as e: - model.move_to_workspace(workspace_name=test_workspace.name) - assert "access denied" in str(e.value) - model.delete() + d = client.Determined._from_session(wksp_admin) + model = d.get_model(model_2.name) + assert current_model_workspace == "Uncategorized" + # move model to test_workspace should fail. 
+ with pytest.raises(errors.ForbiddenException) as e: + model.move_to_workspace(workspace_name=test_workspace.name) + assert "access denied" in str(e.value) + model.delete() @pytest.mark.test_model_registry_rbac def test_model_rbac_deletes() -> None: - with create_workspaces_with_users( + with test_rbac.create_workspaces_with_users( [ [ (0, ["Editor"]), @@ -343,41 +284,36 @@ def test_model_rbac_deletes() -> None: ) as (workspaces, creds): workspace_id = workspaces[0].id # create non-cluster admin user - editor_creds = creds[0] - editor_session = api_utils.determined_test_session(editor_creds) + editor_session = creds[0] # create cluster admin user - cluster_admin_creds = api_utils.create_test_user( - add_password=True, + cluster_admin, _ = api_utils.create_test_user( user=bindings.v1User(username=get_random_string(), active=True, admin=False), ) api_utils.assign_user_role( - session=api_utils.determined_test_session(conf.ADMIN_CREDENTIALS), - user=cluster_admin_creds.username, + session=api_utils.admin_session(), + user=cluster_admin.username, role="ClusterAdmin", workspace=None, ) - cluster_admin_session = api_utils.determined_test_session(cluster_admin_creds) # create non-cluster admin user with OSS admin flag - oss_admin_creds = api_utils.create_test_user( - add_password=True, + oss_admin, _ = api_utils.create_test_user( user=bindings.v1User(username=get_random_string(), active=True, admin=True), ) - oss_admin_session = api_utils.determined_test_session(oss_admin_creds) model_num = 0 try: # test deleting model registries tests: List[Dict[str, Any]] = [ { - "create_session": cluster_admin_session, - "delete_session": cluster_admin_session, + "create_session": cluster_admin, + "delete_session": cluster_admin, "should_error": False, }, { "create_session": editor_session, - "delete_session": cluster_admin_session, + "delete_session": cluster_admin, "should_error": False, }, { @@ -386,19 +322,19 @@ def test_model_rbac_deletes() -> None: "should_error": False, }, { - 
"create_session": cluster_admin_session, + "create_session": cluster_admin, "delete_session": editor_session, "should_error": True, }, { - "create_session": cluster_admin_session, - "delete_session": oss_admin_session, + "create_session": cluster_admin, + "delete_session": oss_admin, "should_error": True, }, ] for t in tests: - create_session: Session = t["create_session"] - delete_session: Session = t["delete_session"] + create_session: api.Session = t["create_session"] + delete_session: api.Session = t["delete_session"] should_error: bool = t["should_error"] model_name = "model_" + str(model_num) @@ -418,41 +354,41 @@ def test_model_rbac_deletes() -> None: # test deleting model versions tests = [ { - "create_creds": cluster_admin_creds, - "delete_session": cluster_admin_session, + "create_session": cluster_admin, + "delete_session": cluster_admin, "should_error": False, }, { - "create_creds": editor_creds, + "create_session": editor_session, "delete_session": editor_session, "should_error": False, }, { - "create_creds": editor_creds, - "delete_session": cluster_admin_session, + "create_session": editor_session, + "delete_session": cluster_admin, "should_error": False, }, { - "create_creds": cluster_admin_creds, + "create_session": cluster_admin, "delete_session": editor_session, "should_error": True, }, { - "create_creds": cluster_admin_creds, - "delete_session": oss_admin_session, + "create_session": cluster_admin, + "delete_session": oss_admin, "should_error": True, }, ] for t in tests: - create_creds: authentication.Credentials = t["create_creds"] + create_session = t["create_session"] delete_session = t["delete_session"] should_error = t["should_error"] model_name = "model_" + str(model_num) model_num += 1 m, ca_model_version = register_model_version( - creds=create_creds, model_name=model_name, workspace_id=workspace_id + sess=create_session, model_name=model_name, workspace_id=workspace_id ) model_version_num = ca_model_version.model_version @@ -470,14 
+406,14 @@ def test_model_rbac_deletes() -> None: ) with pytest.raises(errors.NotFoundException) as notFoundErr: bindings.get_GetModelVersion( - api_utils.determined_test_session(create_creds), + create_session, modelName=model_name, modelVersionNum=model_version_num, ) assert "not found" in str(notFoundErr.value).lower() finally: + admin_session = api_utils.admin_session() for i in range(model_num): - admin_session = api_utils.determined_test_session(conf.ADMIN_CREDENTIALS) try: bindings.delete_DeleteModel(admin_session, modelName="model_" + str(i)) # model is has already been cleaned up diff --git a/e2e_tests/tests/cluster/test_oauth2_scim_client.py b/e2e_tests/tests/cluster/test_oauth2_scim_client.py index 161307ca271..ca409f0a496 100644 --- a/e2e_tests/tests/cluster/test_oauth2_scim_client.py +++ b/e2e_tests/tests/cluster/test_oauth2_scim_client.py @@ -1,113 +1,80 @@ import re -import subprocess import pytest import determined as det from determined.common import api from determined.experimental import client as _client -from tests import api_utils -from tests import config as conf +from tests import api_utils, detproc @pytest.mark.e2e_cpu @api_utils.skipif_ee() def test_list_oauth_clients() -> None: - api_utils.configure_token_store(conf.ADMIN_CREDENTIALS) - det_obj = _client.Determined(master=conf.make_master_url()) - command = [ - "det", - "-m", - conf.make_master_url(), - "oauth", - "client", - "list", - ] - + sess = api_utils.admin_session() + det_obj = _client.Determined._from_session(sess) with pytest.raises(det.errors.EnterpriseOnlyError): det_obj.list_oauth_clients() - with pytest.raises(subprocess.CalledProcessError): - subprocess.run(command, check=True) + + command = ["det", "oauth", "client", "list"] + detproc.check_error(sess, command, "enterprise") @pytest.mark.e2e_cpu @api_utils.skipif_ee() def test_add_client() -> None: - api_utils.configure_token_store(conf.ADMIN_CREDENTIALS) - - det_obj = _client.Determined(master=conf.make_master_url()) - 
command = [ - "det", - "-m", - conf.make_master_url(), - "oauth", - "client", - "add", - "XXX", - "cli_test_oauth_client", - ] - + sess = api_utils.admin_session() + det_obj = _client.Determined._from_session(sess) with pytest.raises(det.errors.EnterpriseOnlyError): det_obj.add_oauth_client(domain="XXX", name="sdk_oauth_client_test") - with pytest.raises(subprocess.CalledProcessError): - subprocess.run(command, check=True) + command = ["det", "oauth", "client", "add", "XXX", "cli_test_oauth_client"] + detproc.check_error(sess, command, "enterprise") @pytest.mark.e2e_cpu @api_utils.skipif_ee() def test_remove_client() -> None: - api_utils.configure_token_store(conf.ADMIN_CREDENTIALS) - det_obj = _client.Determined(master=conf.make_master_url()) + sess = api_utils.admin_session() + det_obj = _client.Determined._from_session(sess) with pytest.raises(det.errors.EnterpriseOnlyError): det_obj.remove_oauth_client(client_id="3") - command = [ - "det", - "-m", - conf.make_master_url(), - "oauth", - "client", - "remove", - "4", - ] - with pytest.raises(subprocess.CalledProcessError): - subprocess.run(command, check=True) + + command = ["det", "oauth", "client", "remove", "4"] + detproc.check_error(sess, command, "enterprise") @pytest.mark.test_oauth @api_utils.skipif_not_ee() def test_list_oauth_clients_ee() -> None: - api_utils.configure_token_store(conf.ADMIN_CREDENTIALS) + sess = api_utils.admin_session() # Test SDK - det_obj = _client.Determined(master=conf.make_master_url()) + det_obj = _client.Determined._from_session(sess) det_obj.list_oauth_clients() # Test CLI command = [ "det", - "-m", - conf.make_master_url(), "oauth", "client", "list", ] - subprocess.run(command, check=True) + detproc.check_output(sess, command) # non-admin users are not allowed to call Oauth API. 
- new_creds = api_utils.create_test_user() - api_utils.configure_token_store(new_creds) + sess = api_utils.user_session() + det_obj = _client.Determined._from_session(sess) with pytest.raises(api.errors.ForbiddenException): - det_obj = _client.Determined(master=conf.make_master_url()) det_obj.list_oauth_clients() @pytest.mark.test_oauth @api_utils.skipif_not_ee() def test_add_remove_client_ee() -> None: - api_utils.configure_token_store(conf.ADMIN_CREDENTIALS) + sess = api_utils.admin_session() - # Test SDK. - det_obj = _client.Determined(master=conf.make_master_url()) + # Test SDK + det_obj = _client.Determined._from_session(sess) client = det_obj.add_oauth_client(domain="XXXSDK", name="sdk_oauth_client_test") remove_id = client.id det_obj.remove_oauth_client(client_id=remove_id) @@ -117,15 +84,13 @@ def test_add_remove_client_ee() -> None: # Test CLI. command = [ "det", - "-m", - conf.make_master_url(), "oauth", "client", "add", "XXXCLI", "cli_test_oauth_client", ] - output = str(subprocess.check_output(command)).split("\\n")[0] + output = detproc.check_output(sess, command).split("\\n")[0] assert "ID" in output r = "(.*)ID:(\\s*)(.*)" m = re.match(r, output) @@ -133,23 +98,20 @@ def test_add_remove_client_ee() -> None: remove_id = m.group(3) command = [ "det", - "-m", - conf.make_master_url(), "oauth", "client", "remove", str(remove_id), # only one OAuth client is allowed. ] - subprocess.run(command, check=True) + detproc.check_output(sess, command) list_client_ids = [oclient.id for oclient in det_obj.list_oauth_clients()] assert remove_id not in list_client_ids # non-admin users are not allowed to call Oauth API. 
- new_creds = api_utils.create_test_user() - api_utils.configure_token_store(new_creds) - det_obj = _client.Determined(master=conf.make_master_url()) + sess = api_utils.user_session() + det_obj = _client.Determined._from_session(sess) with pytest.raises(api.errors.ForbiddenException): - client = det_obj.add_oauth_client(domain="XXXSDK", name="sdk_oauth_client_test") + det_obj.add_oauth_client(domain="XXXSDK", name="sdk_oauth_client_test") # non-admin users are not allowed to call Oauth API. with pytest.raises(api.errors.ForbiddenException): diff --git a/e2e_tests/tests/cluster/test_priority_scheduler.py b/e2e_tests/tests/cluster/test_priority_scheduler.py index c68b653aa7e..cbbd8da1d82 100644 --- a/e2e_tests/tests/cluster/test_priority_scheduler.py +++ b/e2e_tests/tests/cluster/test_priority_scheduler.py @@ -2,61 +2,60 @@ import pytest +from tests import api_utils from tests import config as conf +from tests import detproc from tests import experiment as exp - -from .managed_cluster import ManagedCluster -from .utils import ( - assert_command_succeeded, - run_command, - run_command_set_priority, - wait_for_command_state, -) +from tests.cluster import managed_cluster, utils @pytest.mark.managed_devcluster def test_priortity_scheduler_noop_experiment( - managed_cluster_priority_scheduler: ManagedCluster, + managed_cluster_priority_scheduler: managed_cluster.ManagedCluster, ) -> None: + sess = api_utils.user_session() managed_cluster_priority_scheduler.ensure_agent_ok() assert str(conf.MASTER_PORT) == str(8082) # uses the default priority set in cluster config - exp.run_basic_test(conf.fixtures_path("no_op/single.yaml"), conf.fixtures_path("no_op"), 1) + exp.run_basic_test( + sess, conf.fixtures_path("no_op/single.yaml"), conf.fixtures_path("no_op"), 1 + ) # uses explicit priority exp.run_basic_test( - conf.fixtures_path("no_op/single.yaml"), conf.fixtures_path("no_op"), 1, priority=50 + sess, conf.fixtures_path("no_op/single.yaml"), conf.fixtures_path("no_op"), 1, 
priority=50 ) @pytest.mark.managed_devcluster def test_priortity_scheduler_noop_command( - managed_cluster_priority_scheduler: ManagedCluster, + managed_cluster_priority_scheduler: managed_cluster.ManagedCluster, ) -> None: + sess = api_utils.user_session() managed_cluster_priority_scheduler.ensure_agent_ok() assert str(conf.MASTER_PORT) == "8082" # without slots (and default priority) - command_id = run_command(slots=0) - wait_for_command_state(command_id, "TERMINATED", 40) - assert_command_succeeded(command_id) + command_id = utils.run_command(sess, slots=0) + utils.wait_for_command_state(sess, command_id, "TERMINATED", 40) + utils.assert_command_succeeded(sess, command_id) # with slots (and default priority) - command_id = run_command(slots=1) - wait_for_command_state(command_id, "TERMINATED", 60) - assert_command_succeeded(command_id) + command_id = utils.run_command(sess, slots=1) + utils.wait_for_command_state(sess, command_id, "TERMINATED", 60) + utils.assert_command_succeeded(sess, command_id) # explicity priority - command_id = run_command_set_priority(slots=0, priority=60) - wait_for_command_state(command_id, "TERMINATED", 60) - assert_command_succeeded(command_id) + command_id = utils.run_command_set_priority(sess, slots=0, priority=60) + utils.wait_for_command_state(sess, command_id, "TERMINATED", 60) + utils.assert_command_succeeded(sess, command_id) @pytest.mark.managed_devcluster -def test_slots_list_command(managed_cluster_priority_scheduler: ManagedCluster) -> None: +def test_slots_list_command( + managed_cluster_priority_scheduler: managed_cluster.ManagedCluster, +) -> None: + sess = api_utils.user_session() managed_cluster_priority_scheduler.ensure_agent_ok() assert str(conf.MASTER_PORT) == "8082" - command = ["det", "-m", conf.make_master_url(), "slot", "list"] - completed_process = subprocess.run( - command, universal_newlines=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE - ) - - assert completed_process.returncode == 0, "\nstdout:\n{} 
\nstderr:\n{}".format( - completed_process.stdout, completed_process.stderr + command = ["det", "slot", "list"] + p = detproc.run( + sess, command, universal_newlines=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE ) + assert p.returncode == 0, f"\nstdout:\n{p.stdout} \nstderr:\n{p.stderr}" diff --git a/e2e_tests/tests/cluster/test_proxy.py b/e2e_tests/tests/cluster/test_proxy.py index 8eee11dccc3..3e77784c330 100644 --- a/e2e_tests/tests/cluster/test_proxy.py +++ b/e2e_tests/tests/cluster/test_proxy.py @@ -1,22 +1,23 @@ import csv +import io import pathlib import re import subprocess import time -from io import StringIO import pytest import requests +from determined.common import api from determined.common.api import bindings from tests import api_utils from tests import config as conf +from tests import detproc from tests import experiment as exp from tests import ray_utils -def _experiment_task_id(exp_id: int) -> str: - sess = api_utils.determined_test_session() +def _experiment_task_id(sess: api.Session, exp_id: int) -> str: trials = bindings.get_GetExperimentTrials(sess, experimentId=exp_id).trials assert len(trials) > 0 @@ -59,17 +60,20 @@ def _ray_job_submit(exp_path: pathlib.Path, port: int = 8265) -> None: @pytest.mark.e2e_cpu @pytest.mark.timeout(600) def test_experiment_proxy_ray_tunnel() -> None: + sess = api_utils.user_session() exp_path = conf.EXAMPLES_PATH / "features" / "ports" exp_id = exp.create_experiment( + sess, str(exp_path / "ray_launcher.yaml"), str(exp_path), ["--config", "max_restarts=0", "--config", "resources.slots=1"], ) try: - exp.wait_for_experiment_state(exp_id, bindings.experimentv1State.RUNNING) - task_id = _experiment_task_id(exp_id) + exp.wait_for_experiment_state(sess, exp_id, bindings.experimentv1State.RUNNING) + task_id = _experiment_task_id(sess, exp_id) - proc = subprocess.Popen( + proc = detproc.Popen( + sess, [ "python", "-m", @@ -90,7 +94,6 @@ def test_experiment_proxy_ray_tunnel() -> None: proc.terminate() 
proc.wait(10) finally: - sess = api_utils.determined_test_session() bindings.post_KillExperiment(sess, id=exp_id) @@ -109,11 +112,11 @@ def _parse_exp_id(proc: "subprocess.Popen[str]") -> int: def _kill_all_ray_experiments() -> None: - proc = subprocess.run( + sess = api_utils.user_session() + proc = detproc.run( + sess, [ "det", - "-m", - conf.make_master_url(), "experiment", "list", "--csv", @@ -122,8 +125,7 @@ def _kill_all_ray_experiments() -> None: text=True, check=True, ) - reader = csv.DictReader(StringIO(proc.stdout)) - sess = api_utils.determined_test_session() + reader = csv.DictReader(io.StringIO(proc.stdout)) for row in reader: if row["name"] == "ray_launcher": if row["state"] not in ["CANCELED", "COMPLETED"]: @@ -134,12 +136,12 @@ def _kill_all_ray_experiments() -> None: @pytest.mark.e2e_cpu @pytest.mark.timeout(600) def test_experiment_proxy_ray_publish() -> None: + sess = api_utils.user_session() exp_path = conf.EXAMPLES_PATH / "features" / "ports" - proc = subprocess.Popen( + proc = detproc.Popen( + sess, [ "det", - "-m", - conf.make_master_url(), "experiment", "create", str(exp_path / "ray_launcher.yaml"), @@ -164,11 +166,10 @@ def test_experiment_proxy_ray_publish() -> None: raise try: - exp.wait_for_experiment_state(exp_id, bindings.experimentv1State.RUNNING) + exp.wait_for_experiment_state(sess, exp_id, bindings.experimentv1State.RUNNING) _probe_tunnel(proc) _ray_job_submit(exp_path) finally: - sess = api_utils.determined_test_session() bindings.post_KillExperiment(sess, id=exp_id) finally: proc.terminate() diff --git a/e2e_tests/tests/cluster/test_rbac.py b/e2e_tests/tests/cluster/test_rbac.py index 23dc18e088f..42da9d67829 100644 --- a/e2e_tests/tests/cluster/test_rbac.py +++ b/e2e_tests/tests/cluster/test_rbac.py @@ -1,89 +1,20 @@ import contextlib -from typing import Any, Callable, Dict, Generator, List, NamedTuple, Optional, Tuple +from typing import Any, Dict, Generator, List, NamedTuple, Optional, Tuple import pytest from 
determined.common import api -from determined.common.api import authentication, bindings, errors -from tests import api_utils -from tests import config as conf -from tests.cluster.test_workspace_org import setup_workspaces +from determined.common.api import bindings +from tests import api_utils, detproc +from tests.cluster import test_workspace_org -from .test_groups import det_cmd, det_cmd_expect_error, det_cmd_json -from .test_users import logged_in_user - - -def roles_not_implemented() -> bool: - return "Unimplemented" in det_cmd(["rbac", "my-permissions"]).stderr.decode() - - -def rbac_disabled() -> bool: - if roles_not_implemented(): - return True - try: - return not bindings.get_GetMaster(api_utils.determined_test_session()).rbacEnabled - except (errors.APIException, errors.MasterNotFoundException): - return True - - -def strict_q_control_disabled() -> bool: - if roles_not_implemented() or rbac_disabled(): - return True - try: - return not bindings.get_GetMaster(api_utils.determined_test_session()).strictJobQueueControl - except (errors.APIException, errors.MasterNotFoundException): - return True - - -PermCase = NamedTuple( - "PermCase", [("cred", api.authentication.Credentials), ("raises", Optional[Any])] -) - - -def run_permission_tests( - action: Callable[[authentication.Credentials], None], cases: List[PermCase] -) -> None: - for cred, raises in cases: - if raises is None: - action(cred) - else: - with pytest.raises(raises): - action(cred) - - -def create_users_with_gloabl_roles(user_roles: List[List[str]]) -> List[authentication.Credentials]: - """ - Set up users with the provided global role assignments. - user_roles: list of roles to assign to each user, one entry per user. 
- """ - user_creds: List[authentication.Credentials] = [] - sess = api_utils.determined_test_session(admin=True) - for roles in user_roles: - user = bindings.v1User(username=api_utils.get_random_string(), admin=False, active=True) - creds = api_utils.create_test_user(True, user=user) - for role in roles: - api_utils.assign_user_role( - session=sess, - user=creds.username, - role=role, - workspace=None, - ) - user_creds.append(creds) - return user_creds - - -@pytest.fixture(scope="session") -def cluster_admin_creds() -> authentication.Credentials: - [creds] = create_users_with_gloabl_roles([["ClusterAdmin"]]) - return creds +PermCase = NamedTuple("PermCase", [("sess", api.Session), ("raises", Optional[Any])]) @contextlib.contextmanager def create_workspaces_with_users( assignments_list: List[List[Tuple[int, List[str]]]] -) -> Generator[ - Tuple[List[bindings.v1Workspace], Dict[int, authentication.Credentials]], None, None -]: +) -> Generator[Tuple[List[bindings.v1Workspace], Dict[int, api.Session]], None, None]: """ Set up workspaces and users with the provided role assignments. 
For example the following sets up 2 workspaces and 2 users referenced @@ -100,26 +31,25 @@ def create_workspaces_with_users( ] ] """ - sess = api_utils.determined_test_session(admin=True) - api_utils.configure_token_store(conf.ADMIN_CREDENTIALS) - rid_to_creds: Dict[int, authentication.Credentials] = {} - with setup_workspaces(count=len(assignments_list)) as workspaces: + sess = api_utils.admin_session() + rid_to_sess: Dict[int, api.Session] = {} + with test_workspace_org.setup_workspaces(count=len(assignments_list)) as workspaces: for workspace, user_list in zip(workspaces, assignments_list): for rid, roles in user_list: - if rid not in rid_to_creds: - rid_to_creds[rid] = api_utils.create_test_user() + if rid not in rid_to_sess: + rid_to_sess[rid], _ = api_utils.create_test_user() for role in roles: api_utils.assign_user_role( session=sess, - user=rid_to_creds[rid].username, + user=rid_to_sess[rid].username, role=role, workspace=workspace.name, ) - yield workspaces, rid_to_creds + yield workspaces, rid_to_sess @pytest.mark.e2e_cpu_rbac -@pytest.mark.skipif(roles_not_implemented(), reason="ee is required for this test") +@api_utils.skipif_rbac_not_enabled() def test_user_role_setup() -> None: perm_assigments = [ [ @@ -130,517 +60,540 @@ def test_user_role_setup() -> None: (1, ["Viewer"]), ], ] - with create_workspaces_with_users(perm_assigments) as (workspaces, rid_to_creds): - assert len(rid_to_creds) == 2 + with create_workspaces_with_users(perm_assigments) as (workspaces, rid_to_sess): + assert len(rid_to_sess) == 2 assert len(workspaces) == 2 @pytest.mark.e2e_cpu_rbac -@pytest.mark.skipif(roles_not_implemented(), reason="ee is required for this test") +@api_utils.skipif_rbac_not_enabled() def test_rbac_permission_assignment() -> None: - api_utils.configure_token_store(conf.ADMIN_CREDENTIALS) - test_user_creds = api_utils.create_test_user() + admin = api_utils.admin_session() + sess, _ = api_utils.create_test_user() # User has no permissions. 
- with logged_in_user(test_user_creds): - assert "no permissions" in det_cmd(["rbac", "my-permissions"], check=True).stdout.decode() - json_out = det_cmd_json(["rbac", "my-permissions", "--json"]) - assert json_out["roles"] == [] - assert json_out["assignments"] == [] + assert "no permissions" in detproc.check_output(sess, ["det", "rbac", "my-permissions"]) + json_out = detproc.check_json(sess, ["det", "rbac", "my-permissions", "--json"]) + assert json_out["roles"] == [] + assert json_out["assignments"] == [] group_name = api_utils.get_random_string() - with logged_in_user(conf.ADMIN_CREDENTIALS): - # Assign user to role directly. - det_cmd( - [ - "rbac", - "assign-role", - "WorkspaceCreator", - "--username-to-assign", - test_user_creds.username, - ], - check=True, - ) - det_cmd( - [ - "rbac", - "assign-role", - "Viewer", - "--username-to-assign", - test_user_creds.username, - "--workspace-name", - "Uncategorized", - ], - check=True, - ) - - # Assign user to a group with roles. - det_cmd( - ["user-group", "create", group_name, "--add-user", test_user_creds.username], check=True - ) - det_cmd( - ["rbac", "assign-role", "WorkspaceCreator", "--group-name-to-assign", group_name], - check=True, - ) - det_cmd(["rbac", "assign-role", "Editor", "--group-name-to-assign", group_name], check=True) - det_cmd( - [ - "rbac", - "assign-role", - "Editor", - "--group-name-to-assign", - group_name, - "--workspace-name", - "Uncategorized", - ], - check=True, - ) + # Assign user to role directly. + detproc.check_call( + admin, + [ + "det", + "rbac", + "assign-role", + "WorkspaceCreator", + "--username-to-assign", + sess.username, + ], + ) + detproc.check_call( + admin, + [ + "det", + "rbac", + "assign-role", + "Viewer", + "--username-to-assign", + sess.username, + "--workspace-name", + "Uncategorized", + ], + ) + + # Assign user to a group with roles.
+ detproc.check_call( + admin, ["det", "user-group", "create", group_name, "--add-user", sess.username] + ) + detproc.check_call( + admin, + ["det", "rbac", "assign-role", "WorkspaceCreator", "--group-name-to-assign", group_name], + ) + detproc.check_call( + admin, ["det", "rbac", "assign-role", "Editor", "--group-name-to-assign", group_name] + ) + detproc.check_call( + admin, + [ + "det", + "rbac", + "assign-role", + "Editor", + "--group-name-to-assign", + group_name, + "--workspace-name", + "Uncategorized", + ], + ) # User has those roles assigned. - with logged_in_user(test_user_creds): - assert ( - "no permissions" not in det_cmd(["rbac", "my-permissions"], check=True).stdout.decode() - ) - json_out = det_cmd_json(["rbac", "my-permissions", "--json"]) - assert len(json_out["roles"]) == 3 - assert len(json_out["assignments"]) == 3 - - creator = [role for role in json_out["roles"] if role["name"] == "WorkspaceCreator"] - assert len(creator) == 1 - creator_assignment = [ - a for a in json_out["assignments"] if a["roleId"] == creator[0]["roleId"] - ] - assert creator_assignment[0]["scopeWorkspaceIds"] == [] - assert creator_assignment[0]["scopeCluster"] - - viewer = [role for role in json_out["roles"] if role["name"] == "Viewer"] - assert len(viewer) == 1 - viewer_assignment = [ - a for a in json_out["assignments"] if a["roleId"] == viewer[0]["roleId"] - ] - assert viewer_assignment[0]["scopeWorkspaceIds"] == [1] - assert not viewer_assignment[0]["scopeCluster"] - - editor = [role for role in json_out["roles"] if role["name"] == "Editor"] - assert len(editor) == 1 - editor_assignment = [ - a for a in json_out["assignments"] if a["roleId"] == editor[0]["roleId"] - ] - assert editor_assignment[0]["scopeWorkspaceIds"] == [1] - assert editor_assignment[0]["scopeCluster"] + assert "no permissions" not in detproc.check_output(sess, ["det", "rbac", "my-permissions"]) + json_out = detproc.check_json(sess, ["det", "rbac", "my-permissions", "--json"]) + assert 
len(json_out["roles"]) == 3 + assert len(json_out["assignments"]) == 3 + + creator = [role for role in json_out["roles"] if role["name"] == "WorkspaceCreator"] + assert len(creator) == 1 + creator_assignment = [a for a in json_out["assignments"] if a["roleId"] == creator[0]["roleId"]] + assert creator_assignment[0]["scopeWorkspaceIds"] == [] + assert creator_assignment[0]["scopeCluster"] + + viewer = [role for role in json_out["roles"] if role["name"] == "Viewer"] + assert len(viewer) == 1 + viewer_assignment = [a for a in json_out["assignments"] if a["roleId"] == viewer[0]["roleId"]] + assert viewer_assignment[0]["scopeWorkspaceIds"] == [1] + assert not viewer_assignment[0]["scopeCluster"] + + editor = [role for role in json_out["roles"] if role["name"] == "Editor"] + assert len(editor) == 1 + editor_assignment = [a for a in json_out["assignments"] if a["roleId"] == editor[0]["roleId"]] + assert editor_assignment[0]["scopeWorkspaceIds"] == [1] + assert editor_assignment[0]["scopeCluster"] # Remove from the group. - with logged_in_user(conf.ADMIN_CREDENTIALS): - det_cmd(["user-group", "remove-user", group_name, test_user_creds.username], check=True) + detproc.check_call(admin, ["det", "user-group", "remove-user", group_name, sess.username]) # User doesn't have any group roles assigned. 
- with logged_in_user(test_user_creds): - assert ( - "no permissions" not in det_cmd(["rbac", "my-permissions"], check=True).stdout.decode() - ) - json_out = det_cmd_json(["rbac", "my-permissions", "--json"]) + assert "no permissions" not in detproc.check_output(sess, ["det", "rbac", "my-permissions"]) + json_out = detproc.check_json(sess, ["det", "rbac", "my-permissions", "--json"]) - assert len(json_out["roles"]) == 2 - assert len(json_out["assignments"]) == 2 - assert len([role for role in json_out["roles"] if role["name"] == "Editor"]) == 0 + assert len(json_out["roles"]) == 2 + assert len(json_out["assignments"]) == 2 + assert len([role for role in json_out["roles"] if role["name"] == "Editor"]) == 0 # Remove user assignments. - with logged_in_user(conf.ADMIN_CREDENTIALS): - # Assign user to role directly. - det_cmd( - [ - "rbac", - "unassign-role", - "WorkspaceCreator", - "--username-to-assign", - test_user_creds.username, - ], - check=True, - ) - det_cmd( - [ - "rbac", - "unassign-role", - "Viewer", - "--username-to-assign", - test_user_creds.username, - "--workspace-name", - "Uncategorized", - ], - check=True, - ) + detproc.check_call( + admin, + [ + "det", + "rbac", + "unassign-role", + "WorkspaceCreator", + "--username-to-assign", + sess.username, + ], + ) + detproc.check_call( + admin, + [ + "det", + "rbac", + "unassign-role", + "Viewer", + "--username-to-assign", + sess.username, + "--workspace-name", + "Uncategorized", + ], + ) # User has no permissions. 
- with logged_in_user(test_user_creds): - assert "no permissions" in det_cmd(["rbac", "my-permissions"], check=True).stdout.decode() - json_out = det_cmd_json(["rbac", "my-permissions", "--json"]) - assert json_out["roles"] == [] - assert json_out["assignments"] == [] + assert "no permissions" in detproc.check_output(sess, ["det", "rbac", "my-permissions"]) + json_out = detproc.check_json(sess, ["det", "rbac", "my-permissions", "--json"]) + assert json_out["roles"] == [] + assert json_out["assignments"] == [] @pytest.mark.e2e_cpu_rbac -@pytest.mark.skipif(roles_not_implemented(), reason="ee is required for this test") +@api_utils.skipif_rbac_not_enabled() def test_rbac_permission_assignment_errors() -> None: - api_utils.configure_token_store(conf.ADMIN_CREDENTIALS) - with logged_in_user(conf.ADMIN_CREDENTIALS): - # Specifying args incorrectly. - det_cmd_expect_error(["rbac", "assign-role", "Viewer"], "must provide exactly one of") - det_cmd_expect_error(["rbac", "unassign-role", "Viewer"], "must provide exactly one of") - det_cmd_expect_error( - [ - "rbac", - "assign-role", - "Viewer", - "--username-to-assign", - "u", - "--group-name-to-assign", - "g", - ], - "must provide exactly one of", - ) - det_cmd_expect_error( - [ - "rbac", - "unassign-role", - "Viewer", - "--username-to-assign", - "u", - "--group-name-to-assign", - "g", - ], - "must provide exactly one of", - ) - - # Non existent role. 
- det_cmd_expect_error( - ["rbac", "assign-role", "fakeRoleNameThatDoesntExist", "--username-to-assign", "admin"], - "could not find role name", - ) - det_cmd_expect_error( - [ - "rbac", - "unassign-role", - "fakeRoleNameThatDoesntExist", - "--username-to-assign", - "admin", - ], - "could not find role name", - ) - - # Non existent user - det_cmd_expect_error( - ["rbac", "assign-role", "Viewer", "--username-to-assign", "fakeUserNotExist"], - "could not find user", - ) - det_cmd_expect_error( - ["rbac", "unassign-role", "Viewer", "--username-to-assign", "fakeUserNotExist"], - "could not find user", - ) - - # Non existent group. - det_cmd_expect_error( - ["rbac", "assign-role", "Viewer", "--group-name-to-assign", "fakeGroupNotExist"], - "could not find user group", - ) - det_cmd_expect_error( - ["rbac", "unassign-role", "Viewer", "--group-name-to-assign", "fakeGroupNotExist"], - "could not find user group", - ) - - # Non existent workspace - det_cmd_expect_error( - [ - "rbac", - "assign-role", - "Viewer", - "--workspace-name", - "fakeWorkspace", - "--username-to-assign", - "admin", - ], - "not found", - ) - det_cmd_expect_error( - [ - "rbac", - "unassign-role", - "Viewer", - "--workspace-name", - "fakeWorkspace", - "--username-to-assign", - "admin", - ], - "not found", - ) - - test_user_creds = api_utils.create_test_user() - group_name = api_utils.get_random_string() - det_cmd(["user-group", "create", group_name], check=True) - det_cmd(["rbac", "assign-role", "Viewer", "--group-name-to-assign", group_name], check=True) - det_cmd( - ["rbac", "assign-role", "Viewer", "--username-to-assign", test_user_creds.username], - check=True, - ) - - # Assign a role multiple times. - det_cmd_expect_error( - ["rbac", "assign-role", "Viewer", "--group-name-to-assign", group_name], - "row already exists", - ) - - # Unassigned role group doesn't have. 
- det_cmd_expect_error( - ["rbac", "unassign-role", "Editor", "--group-name-to-assign", group_name], "Not Found" - ) - det_cmd_expect_error( - [ - "rbac", - "unassign-role", - "Viewer", - "--group-name-to-assign", - group_name, - "--workspace-name", - "Uncategorized", - ], - "Not Found", - ) - - # Unassigned role user doesn't have. - det_cmd_expect_error( - ["rbac", "unassign-role", "Editor", "--username-to-assign", test_user_creds.username], - "Not Found", - ) - det_cmd_expect_error( - [ - "rbac", - "unassign-role", - "Viewer", - "--username-to-assign", - test_user_creds.username, - "--workspace-name", - "Uncategorized", - ], - "Not Found", - ) + admin = api_utils.admin_session() + + # Specifying args incorrectly. + detproc.check_error( + admin, ["det", "rbac", "assign-role", "Viewer"], "must provide exactly one of" + ) + detproc.check_error( + admin, ["det", "rbac", "unassign-role", "Viewer"], "must provide exactly one of" + ) + detproc.check_error( + admin, + [ + "det", + "rbac", + "assign-role", + "Viewer", + "--username-to-assign", + "u", + "--group-name-to-assign", + "g", + ], + "must provide exactly one of", + ) + detproc.check_error( + admin, + [ + "det", + "rbac", + "unassign-role", + "Viewer", + "--username-to-assign", + "u", + "--group-name-to-assign", + "g", + ], + "must provide exactly one of", + ) + + # Non existent role. 
+ detproc.check_error( + admin, + [ + "det", + "rbac", + "assign-role", + "fakeRoleNameThatDoesntExist", + "--username-to-assign", + "admin", + ], + "could not find role name", + ) + detproc.check_error( + admin, + [ + "det", + "rbac", + "unassign-role", + "fakeRoleNameThatDoesntExist", + "--username-to-assign", + "admin", + ], + "could not find role name", + ) + + # Non existent user + detproc.check_error( + admin, + ["det", "rbac", "assign-role", "Viewer", "--username-to-assign", "fakeUserNotExist"], + "could not find user", + ) + detproc.check_error( + admin, + ["det", "rbac", "unassign-role", "Viewer", "--username-to-assign", "fakeUserNotExist"], + "could not find user", + ) + + # Non existent group. + detproc.check_error( + admin, + ["det", "rbac", "assign-role", "Viewer", "--group-name-to-assign", "fakeGroupNotExist"], + "could not find user group", + ) + detproc.check_error( + admin, + ["det", "rbac", "unassign-role", "Viewer", "--group-name-to-assign", "fakeGroupNotExist"], + "could not find user group", + ) + + # Non existent workspace + detproc.check_error( + admin, + [ + "det", + "rbac", + "assign-role", + "Viewer", + "--workspace-name", + "fakeWorkspace", + "--username-to-assign", + "admin", + ], + "not found", + ) + detproc.check_error( + admin, + [ + "det", + "rbac", + "unassign-role", + "Viewer", + "--workspace-name", + "fakeWorkspace", + "--username-to-assign", + "admin", + ], + "not found", + ) + + sess, _ = api_utils.create_test_user() + group_name = api_utils.get_random_string() + detproc.check_call(admin, ["det", "user-group", "create", group_name]) + detproc.check_call( + admin, + ["det", "rbac", "assign-role", "Viewer", "--group-name-to-assign", group_name], + ) + detproc.check_call( + admin, + ["det", "rbac", "assign-role", "Viewer", "--username-to-assign", sess.username], + ) + + # Assign a role multiple times. 
+ detproc.check_error( + admin, + ["det", "rbac", "assign-role", "Viewer", "--group-name-to-assign", group_name], + "row already exists", + ) + + # Unassigned role group doesn't have. + detproc.check_error( + admin, + ["det", "rbac", "unassign-role", "Editor", "--group-name-to-assign", group_name], + "Not Found", + ) + detproc.check_error( + admin, + [ + "det", + "rbac", + "unassign-role", + "Viewer", + "--group-name-to-assign", + group_name, + "--workspace-name", + "Uncategorized", + ], + "Not Found", + ) + + # Unassigned role user doesn't have. + detproc.check_error( + admin, + ["det", "rbac", "unassign-role", "Editor", "--username-to-assign", sess.username], + "Not Found", + ) + detproc.check_error( + admin, + [ + "det", + "rbac", + "unassign-role", + "Viewer", + "--username-to-assign", + sess.username, + "--workspace-name", + "Uncategorized", + ], + "Not Found", + ) @pytest.mark.e2e_cpu_rbac -@pytest.mark.skipif(roles_not_implemented(), reason="ee is required for this test") +@api_utils.skipif_rbac_not_enabled() def test_rbac_list_roles() -> None: - with logged_in_user(conf.ADMIN_CREDENTIALS): - det_cmd(["rbac", "list-roles"], check=True) - all_roles = det_cmd_json(["rbac", "list-roles", "--json"])["roles"] - - # Test list-roles excluding global roles properly. - non_excluded_roles = det_cmd_json( - ["rbac", "list-roles", "--exclude-global-roles", "--json"] - )["roles"] - non_excluded_role_ids = {r["roleId"] for r in non_excluded_roles} - for role in all_roles: - is_excluded = role["roleId"] not in non_excluded_role_ids - is_global = any(not p["scopeTypeMask"]["workspace"] for p in role["permissions"]) - assert is_excluded == is_global - - # Test list-roles pagination. 
- json_out = det_cmd_json(["rbac", "list-roles", "--limit=2", "--json"]) - assert len(json_out["roles"]) == 2 - assert json_out["pagination"]["limit"] == 2 - assert json_out["pagination"]["total"] == len(all_roles) - assert json_out["pagination"]["offset"] == 0 - - json_out = det_cmd_json(["rbac", "list-roles", "--offset=1", "--limit=199", "--json"]) - assert len(json_out["roles"]) == len(all_roles) - 1 - assert json_out["pagination"]["limit"] == 199 - assert json_out["pagination"]["total"] == len(all_roles) - assert json_out["pagination"]["offset"] == 1 - - # Set up group/user to test with. - api_utils.configure_token_store(conf.ADMIN_CREDENTIALS) - test_user_creds = api_utils.create_test_user() - group_name = api_utils.get_random_string() - det_cmd( - ["user-group", "create", group_name, "--add-user", test_user_creds.username], check=True - ) - - # No roles should be returned since no assignmnets have happened. - list_user_roles = ["rbac", "list-users-roles", test_user_creds.username] - list_group_roles = ["rbac", "list-groups-roles", group_name] - - assert det_cmd_json(list_user_roles + ["--json"])["roles"] == [] - assert ( - "user has no role assignments" in det_cmd(list_user_roles, check=True).stdout.decode() - ) - - assert det_cmd_json(list_group_roles + ["--json"])["roles"] == [] - assert ( - "group has no role assignments" in det_cmd(list_group_roles, check=True).stdout.decode() - ) - - # Assign roles. 
- det_cmd( - ["rbac", "assign-role", "Viewer", "--username-to-assign", test_user_creds.username], - check=True, - ) - det_cmd( - [ - "rbac", - "assign-role", - "Viewer", - "--username-to-assign", - test_user_creds.username, - "--workspace-name", - "Uncategorized", - ], - check=True, - ) - - det_cmd(["rbac", "assign-role", "Editor", "--group-name-to-assign", group_name], check=True) - det_cmd( - [ - "rbac", - "assign-role", - "Editor", - "--group-name-to-assign", - group_name, - "--workspace-name", - "Uncategorized", - ], - check=True, - ) - - # Test list-users-roles. - det_cmd(list_user_roles, check=True) - json_out = det_cmd_json(list_user_roles + ["--json"]) - assert len(json_out["roles"]) == 2 - json_out["roles"].sort(key=lambda x: -1 if x["role"]["name"] == "Viewer" else 1) - assert json_out["roles"][0]["role"]["name"] == "Viewer" - - assert len(json_out["roles"][0]["groupRoleAssignments"]) == 0 - workspace_ids = [ - a["roleAssignment"]["scopeWorkspaceId"] - for a in json_out["roles"][0]["userRoleAssignments"] - ] - assert len(workspace_ids) == 2 - assert 1 in workspace_ids - assert None in workspace_ids - - assert json_out["roles"][1]["role"]["name"] == "Editor" - assert len(json_out["roles"][1]["groupRoleAssignments"]) == 2 - workspace_ids = [ - a["roleAssignment"]["scopeWorkspaceId"] - for a in json_out["roles"][1]["groupRoleAssignments"] - ] - assert len(workspace_ids) == 2 - assert len(json_out["roles"][1]["userRoleAssignments"]) == 0 + admin = api_utils.admin_session() + detproc.check_call(admin, ["det", "rbac", "list-roles"]) + all_roles = detproc.check_json(admin, ["det", "rbac", "list-roles", "--json"])["roles"] + + # Test list-roles excluding global roles properly. 
+ non_excluded_roles = detproc.check_json( + admin, ["det", "rbac", "list-roles", "--exclude-global-roles", "--json"] + )["roles"] + non_excluded_role_ids = {r["roleId"] for r in non_excluded_roles} + for role in all_roles: + is_excluded = role["roleId"] not in non_excluded_role_ids + is_global = any(not p["scopeTypeMask"]["workspace"] for p in role["permissions"]) + assert is_excluded == is_global + + # Test list-roles pagination. + json_out = detproc.check_json(admin, ["det", "rbac", "list-roles", "--limit=2", "--json"]) + assert len(json_out["roles"]) == 2 + assert json_out["pagination"]["limit"] == 2 + assert json_out["pagination"]["total"] == len(all_roles) + assert json_out["pagination"]["offset"] == 0 + + json_out = detproc.check_json( + admin, ["det", "rbac", "list-roles", "--offset=1", "--limit=199", "--json"] + ) + assert len(json_out["roles"]) == len(all_roles) - 1 + assert json_out["pagination"]["limit"] == 199 + assert json_out["pagination"]["total"] == len(all_roles) + assert json_out["pagination"]["offset"] == 1 + + # Set up group/user to test with. + sess, _ = api_utils.create_test_user() + group_name = api_utils.get_random_string() + detproc.check_call( + admin, ["det", "user-group", "create", group_name, "--add-user", sess.username] + ) + + # No roles should be returned since no assignments have happened. + list_user_roles = ["det", "rbac", "list-users-roles", sess.username] + list_group_roles = ["det", "rbac", "list-groups-roles", group_name] + + assert detproc.check_json(admin, list_user_roles + ["--json"])["roles"] == [] + assert "user has no role assignments" in detproc.check_output(admin, list_user_roles) + + assert detproc.check_json(admin, list_group_roles + ["--json"])["roles"] == [] + assert "group has no role assignments" in detproc.check_output(admin, list_group_roles) + + # Assign roles. 
+ detproc.check_call( + admin, + ["det", "rbac", "assign-role", "Viewer", "--username-to-assign", sess.username], + ) + detproc.check_call( + admin, + [ + "det", + "rbac", + "assign-role", + "Viewer", + "--username-to-assign", + sess.username, + "--workspace-name", + "Uncategorized", + ], + ) - # Test list-groups-roles. - det_cmd(list_group_roles, check=True) - json_out = det_cmd_json(list_group_roles + ["--json"]) - assert len(json_out["roles"]) == 1 - assert len(json_out["assignments"]) == 1 - assert json_out["roles"][0]["name"] == "Editor" - assert json_out["assignments"][0]["roleId"] == json_out["roles"][0]["roleId"] - assert json_out["assignments"][0]["scopeWorkspaceIds"] == [1] - assert json_out["assignments"][0]["scopeCluster"] + detproc.check_call( + admin, ["det", "rbac", "assign-role", "Editor", "--group-name-to-assign", group_name] + ) + detproc.check_call( + admin, + [ + "det", + "rbac", + "assign-role", + "Editor", + "--group-name-to-assign", + group_name, + "--workspace-name", + "Uncategorized", + ], + ) + + # Test list-users-roles. 
+ detproc.check_call(admin, list_user_roles) + json_out = detproc.check_json(admin, list_user_roles + ["--json"]) + assert len(json_out["roles"]) == 2 + json_out["roles"].sort(key=lambda x: -1 if x["role"]["name"] == "Viewer" else 1) + assert json_out["roles"][0]["role"]["name"] == "Viewer" + + assert len(json_out["roles"][0]["groupRoleAssignments"]) == 0 + workspace_ids = [ + a["roleAssignment"]["scopeWorkspaceId"] for a in json_out["roles"][0]["userRoleAssignments"] + ] + assert len(workspace_ids) == 2 + assert 1 in workspace_ids + assert None in workspace_ids + + assert json_out["roles"][1]["role"]["name"] == "Editor" + assert len(json_out["roles"][1]["groupRoleAssignments"]) == 2 + workspace_ids = [ + a["roleAssignment"]["scopeWorkspaceId"] + for a in json_out["roles"][1]["groupRoleAssignments"] + ] + assert len(workspace_ids) == 2 + assert len(json_out["roles"][1]["userRoleAssignments"]) == 0 + + # Test list-groups-roles. + detproc.check_call(admin, list_group_roles) + json_out = detproc.check_json(admin, list_group_roles + ["--json"]) + assert len(json_out["roles"]) == 1 + assert len(json_out["assignments"]) == 1 + assert json_out["roles"][0]["name"] == "Editor" + assert json_out["assignments"][0]["roleId"] == json_out["roles"][0]["roleId"] + assert json_out["assignments"][0]["scopeWorkspaceIds"] == [1] + assert json_out["assignments"][0]["scopeCluster"] @pytest.mark.e2e_cpu_rbac -@pytest.mark.skipif(roles_not_implemented(), reason="ee is required for this test") +@api_utils.skipif_rbac_not_enabled() def test_rbac_describe_role() -> None: - with logged_in_user(conf.ADMIN_CREDENTIALS): - # Role doesn't exist. - det_cmd_expect_error( - ["rbac", "describe-role", "roleDoesntExist"], "could not find role name" - ) - - # Role is assigned to our group and user. 
- api_utils.configure_token_store(conf.ADMIN_CREDENTIALS) - test_user_creds = api_utils.create_test_user() - group_name = api_utils.get_random_string() - - det_cmd(["user-group", "create", group_name], check=True) - det_cmd(["rbac", "assign-role", "Viewer", "--group-name-to-assign", group_name], check=True) - det_cmd( - [ - "rbac", - "assign-role", - "Viewer", - "--group-name-to-assign", - group_name, - "--workspace-name", - "Uncategorized", - ], - check=True, - ) - - sess = api_utils.determined_test_session(conf.ADMIN_CREDENTIALS) - user_id = api.usernames_to_user_ids(sess, [test_user_creds.username])[0] - group_id = api.group_name_to_group_id(sess, group_name) - - det_cmd( - ["rbac", "assign-role", "Viewer", "--username-to-assign", test_user_creds.username], - check=True, - ) - det_cmd( - [ - "rbac", - "assign-role", - "Viewer", - "--username-to-assign", - test_user_creds.username, - "--workspace-name", - "Uncategorized", - ], - check=True, - ) - - # No errors printing non-json output. - det_cmd(["rbac", "describe-role", "Viewer"], check=True) - - # Output is returned correctly. - json_out = det_cmd_json(["rbac", "describe-role", "Viewer", "--json"]) - assert json_out["role"]["name"] == "Viewer" - - group_assign = [a for a in json_out["groupRoleAssignments"] if a["groupId"] == group_id] - assert len(group_assign) == 2 - group_assign.sort( - key=lambda x: -1 if x["roleAssignment"]["scopeWorkspaceId"] is None else 1 - ) - assert group_assign[0]["roleAssignment"]["scopeWorkspaceId"] is None - assert group_assign[1]["roleAssignment"]["scopeWorkspaceId"] == 1 - - user_assign = [a for a in json_out["userRoleAssignments"] if a["userId"] == user_id] - assert len(user_assign) == 2 - user_assign.sort(key=lambda x: -1 if x["roleAssignment"]["scopeWorkspaceId"] is None else 1) - assert user_assign[0]["roleAssignment"]["scopeWorkspaceId"] is None - assert user_assign[1]["roleAssignment"]["scopeWorkspaceId"] == 1 + admin = api_utils.admin_session() + # Role doesn't exist. 
+ detproc.check_error( + admin, ["det", "rbac", "describe-role", "roleDoesntExist"], "could not find role name" + ) + + # Role is assigned to our group and user. + sess, _ = api_utils.create_test_user() + group_name = api_utils.get_random_string() + + detproc.check_call(admin, ["det", "user-group", "create", group_name]) + detproc.check_call( + admin, ["det", "rbac", "assign-role", "Viewer", "--group-name-to-assign", group_name] + ) + detproc.check_call( + admin, + [ + "det", + "rbac", + "assign-role", + "Viewer", + "--group-name-to-assign", + group_name, + "--workspace-name", + "Uncategorized", + ], + ) + + user_id = api.usernames_to_user_ids(admin, [sess.username])[0] + group_id = api.group_name_to_group_id(admin, group_name) + + detproc.check_call( + admin, + ["det", "rbac", "assign-role", "Viewer", "--username-to-assign", sess.username], + ) + detproc.check_call( + admin, + [ + "det", + "rbac", + "assign-role", + "Viewer", + "--username-to-assign", + sess.username, + "--workspace-name", + "Uncategorized", + ], + ) + + # No errors printing non-json output. + detproc.check_call(admin, ["det", "rbac", "describe-role", "Viewer"]) + + # Output is returned correctly. 
+ json_out = detproc.check_json(admin, ["det", "rbac", "describe-role", "Viewer", "--json"]) + assert json_out["role"]["name"] == "Viewer" + + group_assign = [a for a in json_out["groupRoleAssignments"] if a["groupId"] == group_id] + assert len(group_assign) == 2 + group_assign.sort(key=lambda x: -1 if x["roleAssignment"]["scopeWorkspaceId"] is None else 1) + assert group_assign[0]["roleAssignment"]["scopeWorkspaceId"] is None + assert group_assign[1]["roleAssignment"]["scopeWorkspaceId"] == 1 + + user_assign = [a for a in json_out["userRoleAssignments"] if a["userId"] == user_id] + assert len(user_assign) == 2 + user_assign.sort(key=lambda x: -1 if x["roleAssignment"]["scopeWorkspaceId"] is None else 1) + assert user_assign[0]["roleAssignment"]["scopeWorkspaceId"] is None + assert user_assign[1]["roleAssignment"]["scopeWorkspaceId"] == 1 @pytest.mark.e2e_cpu_rbac -@pytest.mark.skipif(roles_not_implemented(), reason="ee is required for this test") +@api_utils.skipif_rbac_not_enabled() def test_group_access() -> None: + admin = api_utils.admin_session() # create relevant workspace and project, with group having access group_name = api_utils.get_random_string() workspace_name = api_utils.get_random_string() - with logged_in_user(conf.ADMIN_CREDENTIALS): - det_cmd(["workspace", "create", workspace_name], check=True) - det_cmd(["user-group", "create", group_name], check=True) - det_cmd( - ["rbac", "assign-role", "WorkspaceAdmin", "-w", workspace_name, "-g", group_name], - check=True, - ) + detproc.check_call(admin, ["det", "workspace", "create", workspace_name]) + detproc.check_call(admin, ["det", "user-group", "create", group_name]) + detproc.check_call( + admin, + ["det", "rbac", "assign-role", "WorkspaceAdmin", "-w", workspace_name, "-g", group_name], + ) # create test user which cannot access workspace - creds1 = api_utils.create_test_user() - with logged_in_user(creds1): - det_cmd_expect_error( - ["workspace", "describe", workspace_name], "Failed to describe 
workspace" - ) + sess, _ = api_utils.create_test_user() + detproc.check_error( + sess, ["det", "workspace", "describe", workspace_name], "Failed to describe workspace" + ) # add user to group - with logged_in_user(conf.ADMIN_CREDENTIALS): - det_cmd(["user-group", "add-user", group_name, creds1.username], check=True) + detproc.check_call(admin, ["det", "user-group", "add-user", group_name, sess.username]) # with user now in group, access possible - with logged_in_user(creds1): - det_cmd(["workspace", "describe", workspace_name], check=True) + detproc.check_call(sess, ["det", "workspace", "describe", workspace_name]) diff --git a/e2e_tests/tests/cluster/test_rbac_misc.py b/e2e_tests/tests/cluster/test_rbac_misc.py index 5c8cbbb9cec..1f6fb93351a 100644 --- a/e2e_tests/tests/cluster/test_rbac_misc.py +++ b/e2e_tests/tests/cluster/test_rbac_misc.py @@ -3,15 +3,17 @@ import pytest -from determined.common.api import authentication, bindings, errors +from determined.common import api +from determined.common.api import bindings, errors from tests import api_utils -from tests.cluster.test_rbac import create_workspaces_with_users, rbac_disabled +from tests.cluster import test_rbac @pytest.mark.e2e_cpu_rbac -@pytest.mark.skipif(rbac_disabled(), reason="ee rbac is required for this test") +@api_utils.skipif_rbac_not_enabled() def test_cluster_admin_only_calls() -> None: - with create_workspaces_with_users( + admin = api_utils.admin_session() + with test_rbac.create_workspaces_with_users( [ [ (1, ["Editor"]), @@ -20,24 +22,20 @@ def test_cluster_admin_only_calls() -> None: ], ] ) as (_, creds): - u_admin_role = api_utils.create_test_user( - add_password=True, + u_admin_role, _ = api_utils.create_test_user( user=bindings.v1User(username=api_utils.get_random_string(), active=True, admin=False), ) - session = api_utils.determined_test_session(admin=True) api_utils.assign_user_role( - session=session, user=u_admin_role.username, role="ClusterAdmin", workspace=None + session=admin, 
user=u_admin_role.username, role="ClusterAdmin", workspace=None ) # normal determined admins without ClusterAdmin role. - u_det_admin = api_utils.create_test_user( - add_password=True, + u_det_admin, _ = api_utils.create_test_user( user=bindings.v1User(username=api_utils.get_random_string(), active=True, admin=True), ) - def get_agent_slot_ids(creds: authentication.Credentials) -> Tuple[str, str]: - session = api_utils.determined_test_session(creds) - agents = sorted(bindings.get_GetAgents(session).agents, key=lambda a: a.id) + def get_agent_slot_ids(sess: api.Session) -> Tuple[str, str]: + agents = sorted(bindings.get_GetAgents(sess).agents, key=lambda a: a.id) assert len(agents) > 0 agent = agents[0] assert agent.slots is not None @@ -46,76 +44,68 @@ def get_agent_slot_ids(creds: authentication.Credentials) -> Tuple[str, str]: assert slot_id is not None return agent.id, slot_id - def enable_agent(creds: authentication.Credentials) -> None: - session = api_utils.determined_test_session(creds) - agent_id, _ = get_agent_slot_ids(creds) - bindings.post_EnableAgent(session, agentId=agent_id) + def enable_agent(sess: api.Session) -> None: + agent_id, _ = get_agent_slot_ids(sess) + bindings.post_EnableAgent(sess, agentId=agent_id) - def disable_agent(creds: authentication.Credentials) -> None: - session = api_utils.determined_test_session(creds) - agent_id, _ = get_agent_slot_ids(creds) + def disable_agent(sess: api.Session) -> None: + agent_id, _ = get_agent_slot_ids(sess) bindings.post_DisableAgent( - session, agentId=agent_id, body=bindings.v1DisableAgentRequest(agentId=agent_id) + sess, agentId=agent_id, body=bindings.v1DisableAgentRequest(agentId=agent_id) ) - def enable_slot(creds: authentication.Credentials) -> None: - session = api_utils.determined_test_session(creds) - agent_id, slot_id = get_agent_slot_ids(creds) - bindings.post_EnableSlot(session, agentId=agent_id, slotId=slot_id) + def enable_slot(sess: api.Session) -> None: + agent_id, slot_id = 
get_agent_slot_ids(sess) + bindings.post_EnableSlot(sess, agentId=agent_id, slotId=slot_id) - def disable_slot(creds: authentication.Credentials) -> None: - session = api_utils.determined_test_session(creds) - agent_id, slot_id = get_agent_slot_ids(creds) + def disable_slot(sess: api.Session) -> None: + agent_id, slot_id = get_agent_slot_ids(sess) bindings.post_DisableSlot( - session, agentId=agent_id, slotId=slot_id, body=bindings.v1DisableSlotRequest() + sess, agentId=agent_id, slotId=slot_id, body=bindings.v1DisableSlotRequest() ) - def get_master_logs(creds: authentication.Credentials) -> None: - logs = list(bindings.get_MasterLogs(api_utils.determined_test_session(creds), limit=2)) + def get_master_logs(sess: api.Session) -> None: + logs = list(bindings.get_MasterLogs(sess, limit=2)) assert len(logs) == 2 - def get_allocations_raw(creds: authentication.Credentials) -> None: + def get_allocations_raw(sess: api.Session) -> None: EXPECTED_TIME_FMT = "%Y-%m-%dT%H:%M:%S.000Z" start = datetime.datetime.now() start_str = start.strftime(EXPECTED_TIME_FMT) end_str = (start + datetime.timedelta(seconds=1)).strftime(EXPECTED_TIME_FMT) entries = bindings.get_ResourceAllocationRaw( - api_utils.determined_test_session(creds), + sess, timestampAfter=start_str, timestampBefore=end_str, ).resourceEntries assert isinstance(entries, list) - def get_allocations_aggregated(creds: authentication.Credentials) -> None: + def get_allocations_aggregated(sess: api.Session) -> None: EXPECTED_TIME_FMT = "%Y-%m-%d" start = datetime.datetime.now() end = start + datetime.timedelta(seconds=1) entries = bindings.get_ResourceAllocationAggregated( - api_utils.determined_test_session(creds), - # fmt: off - period=bindings.v1ResourceAllocationAggregationPeriod\ - .DAILY, - # fmt: on + sess, + period=bindings.v1ResourceAllocationAggregationPeriod.DAILY, startDate=start.strftime(EXPECTED_TIME_FMT), endDate=end.strftime(EXPECTED_TIME_FMT), ).resourceEntries assert isinstance(entries, list) - def 
get_allocations_raw_echo(creds: authentication.Credentials) -> None: + def get_allocations_raw_echo(sess: api.Session) -> None: EXPECTED_TIME_FMT = "%Y-%m-%dT%H:%M:%S.000Z" start = datetime.datetime.now() start_str = start.strftime(EXPECTED_TIME_FMT) end_str = (start + datetime.timedelta(seconds=1)).strftime(EXPECTED_TIME_FMT) url = "/resources/allocation/raw" params = {"timestamp_after": start_str, "timestamp_before": end_str} - session = api_utils.determined_test_session(creds) - response = session.get(url, params=params) + response = sess.get(url, params=params) assert response.status_code == 200 # FIXME: these can potentially affect other tests running against the same cluster. # the targeted agent_id and slot_id are not guaranteed to be the same across checks. - checks: List[Callable[[authentication.Credentials], None]] = [ + checks: List[Callable[[api.Session], None]] = [ get_master_logs, get_allocations_raw, get_allocations_aggregated, diff --git a/e2e_tests/tests/cluster/test_rbac_ntsc.py b/e2e_tests/tests/cluster/test_rbac_ntsc.py index 74dbbec92b3..697e329d2d2 100644 --- a/e2e_tests/tests/cluster/test_rbac_ntsc.py +++ b/e2e_tests/tests/cluster/test_rbac_ntsc.py @@ -5,14 +5,10 @@ import tests.config as conf from determined.common import api -from determined.common.api import authentication, bindings, errors -from tests import api_utils +from determined.common.api import bindings, errors +from tests import api_utils, detproc from tests import experiment as exp -from tests.cluster.test_rbac import create_workspaces_with_users, rbac_disabled -from tests.cluster.test_workspace_org import setup_workspaces - -from .test_groups import det_cmd_json -from .test_users import logged_in_user +from tests.cluster import test_rbac, test_workspace_org DEFAULT_WID = 1 # default workspace ID @@ -42,22 +38,22 @@ def filter_out_ntsc( @pytest.mark.e2e_cpu_rbac -@pytest.mark.skipif(rbac_disabled(), reason="ee rbac is required for this test") 
+@api_utils.skipif_rbac_not_enabled() def test_notebook() -> None: - u_viewer_ws0 = api_utils.create_test_user(add_password=True) - u_editor_ws1 = api_utils.create_test_user(add_password=True) - admin_session = api_utils.determined_test_session(conf.ADMIN_CREDENTIALS) + u_viewer_ws0, _ = api_utils.create_test_user() + u_editor_ws0, _ = api_utils.create_test_user() + admin = api_utils.admin_session() - with setup_workspaces(count=2) as workspaces: + with test_workspace_org.setup_workspaces(count=2) as workspaces: api_utils.assign_user_role( - session=admin_session, + session=admin, user=u_viewer_ws0.username, role="Viewer", workspace=workspaces[0].name, ) api_utils.assign_user_role( - session=admin_session, - user=u_editor_ws1.username, + session=admin, + user=u_editor_ws0.username, role="Editor", workspace=workspaces[1].name, ) @@ -67,32 +63,29 @@ def test_notebook() -> None: bindings.v1LaunchNotebookRequest(workspaceId=workspaces[1].id), bindings.v1LaunchNotebookRequest(), ] - with setup_notebooks(admin_session, nb_reqs) as notebooks: - r = bindings.get_GetNotebooks(admin_session) + with setup_notebooks(admin, nb_reqs) as notebooks: + r = bindings.get_GetNotebooks(admin) assert len(filter_out_ntsc(notebooks, r.notebooks)) == 3 - r = bindings.get_GetNotebooks(admin_session, workspaceId=workspaces[0].id) + r = bindings.get_GetNotebooks(admin, workspaceId=workspaces[0].id) assert len(filter_out_ntsc(notebooks, r.notebooks)) == 1 - r = bindings.get_GetNotebooks(admin_session, workspaceId=workspaces[1].id) + r = bindings.get_GetNotebooks(admin, workspaceId=workspaces[1].id) assert len(filter_out_ntsc(notebooks, r.notebooks)) == 1 - r = bindings.get_GetNotebooks(admin_session, workspaceId=DEFAULT_WID) + r = bindings.get_GetNotebooks(admin, workspaceId=DEFAULT_WID) assert len(filter_out_ntsc(notebooks, r.notebooks)) == 1 - r = bindings.get_GetNotebooks(api_utils.determined_test_session(u_viewer_ws0)) + r = bindings.get_GetNotebooks(u_viewer_ws0) assert len(r.notebooks) 
== 1 - r = bindings.get_GetNotebooks( - api_utils.determined_test_session(u_viewer_ws0), workspaceId=workspaces[0].id - ) + r = bindings.get_GetNotebooks(u_viewer_ws0, workspaceId=workspaces[0].id) assert len(r.notebooks) == 1 with pytest.raises(errors.APIException) as e: - r = bindings.get_GetNotebooks( - api_utils.determined_test_session(u_viewer_ws0), workspaceId=workspaces[1].id - ) + r = bindings.get_GetNotebooks(u_viewer_ws0, workspaceId=workspaces[1].id) assert e.value.status_code == 404 # User with only view role on first workspace - with logged_in_user(u_viewer_ws0): - json_out = det_cmd_json(["notebook", "ls", "--all", "--json"]) - assert len(json_out) == 1 + json_out = detproc.check_json( + u_viewer_ws0, ["det", "notebook", "ls", "--all", "--json"] + ) + assert len(json_out) == 1 tensorboard_wait_time = 300 @@ -114,12 +107,11 @@ def only_tensorboard_can_launch( @pytest.mark.e2e_cpu_rbac -@pytest.mark.skipif(rbac_disabled(), reason="ee rbac is required for this test") +@api_utils.skipif_rbac_not_enabled() def test_ntsc_iface_access() -> None: - def can_access_logs(creds: authentication.Credentials, ntsc_id: str) -> bool: - session = api_utils.determined_test_session(creds) + def can_access_logs(sess: api.Session, ntsc_id: str) -> bool: try: - list(bindings.get_TaskLogs(session, taskId=ntsc_id)) + list(bindings.get_TaskLogs(sess, taskId=ntsc_id)) return True except errors.APIException as e: if e.status_code != 404 and "not found" not in e.message: @@ -127,7 +119,7 @@ def can_access_logs(creds: authentication.Credentials, ntsc_id: str) -> bool: raise e return False - with create_workspaces_with_users( + with test_rbac.create_workspaces_with_users( [ [ (0, ["Viewer", "Editor"]), @@ -147,134 +139,125 @@ def can_access_logs(creds: authentication.Credentials, ntsc_id: str) -> bool: experiment_id = None if typ == api.NTSC_Kind.tensorboard: pid = bindings.post_PostProject( - api_utils.determined_test_session(creds[0]), + creds[0], 
body=bindings.v1PostProjectRequest(name="test", workspaceId=workspaces[0].id), workspaceId=workspaces[0].id, ).project.id - with logged_in_user(creds[0]): - # experiment for tensorboard - experiment_id = exp.create_experiment( - conf.fixtures_path("no_op/single.yaml"), - conf.fixtures_path("no_op"), - ["--project_id", str(pid)], - ) + # experiment for tensorboard + experiment_id = exp.create_experiment( + creds[0], + conf.fixtures_path("no_op/single.yaml"), + conf.fixtures_path("no_op"), + ["--project_id", str(pid)], + ) - created_id = api_utils.launch_ntsc( - api_utils.determined_test_session(creds[0]), workspaces[0].id, typ, experiment_id - ).id + created_id = api_utils.launch_ntsc(creds[0], workspaces[0].id, typ, experiment_id).id # user 0 assert can_access_logs( creds[0], created_id ), f"user 0 should be able to access {typ} logs" - session = api_utils.determined_test_session(creds[0]) # user 0 should be able to get details. - api.get_ntsc_details(session, typ, created_id) + api.get_ntsc_details(creds[0], typ, created_id) # user 0 should be able to kill. - api_utils.kill_ntsc(session, typ, created_id) + api_utils.kill_ntsc(creds[0], typ, created_id) # user 0 should be able to set priority. - api_utils.set_prio_ntsc(session, typ, created_id, 1) + api_utils.set_prio_ntsc(creds[0], typ, created_id, 1) # user 0 should be able to launch in workspace 0. - api_utils.launch_ntsc(session, workspaces[0].id, typ, experiment_id) + api_utils.launch_ntsc(creds[0], workspaces[0].id, typ, experiment_id) # user 0 should be able to launch tensorboards and not NSCs in workspace 1. - only_tensorboard_can_launch(session, workspaces[1].id, typ, experiment_id) + only_tensorboard_can_launch(creds[0], workspaces[1].id, typ, experiment_id) # user 1 assert can_access_logs( creds[1], created_id ), f"user 1 should be able to access {typ} logs" - session = api_utils.determined_test_session(creds[1]) # user 1 should be able to get details. 
- api.get_ntsc_details(session, typ, created_id) + api.get_ntsc_details(creds[1], typ, created_id) with pytest.raises(errors.ForbiddenException) as fe: # user 1 should not be able to kill. - api_utils.kill_ntsc(session, typ, created_id) + api_utils.kill_ntsc(creds[1], typ, created_id) assert "access denied" in fe.value.message with pytest.raises(errors.ForbiddenException) as fe: # user 1 should not be able to set priority. - api_utils.set_prio_ntsc(session, typ, created_id, 1) + api_utils.set_prio_ntsc(creds[1], typ, created_id, 1) assert "access denied" in fe.value.message # user 1 should be able to launch tensorboards but not NSCs in workspace 0. - only_tensorboard_can_launch(session, workspaces[0].id, typ, experiment_id) + only_tensorboard_can_launch(creds[1], workspaces[0].id, typ, experiment_id) # tensorboard requires workspace access so returns workspace not found if # the user does not have access to the workspace. if typ == api.NTSC_Kind.tensorboard: with pytest.raises(errors.NotFoundException): - api_utils.launch_ntsc(session, workspaces[1].id, typ, experiment_id) + api_utils.launch_ntsc(creds[1], workspaces[1].id, typ, experiment_id) else: with pytest.raises(errors.ForbiddenException): - api_utils.launch_ntsc(session, workspaces[1].id, typ, experiment_id) + api_utils.launch_ntsc(creds[1], workspaces[1].id, typ, experiment_id) # user 2 assert not can_access_logs( creds[2], created_id ), f"user 2 should not be able to access {typ} logs" - session = api_utils.determined_test_session(creds[2]) with pytest.raises(errors.APIException) as e: # user 2 should not be able to get details. - api.get_ntsc_details(session, typ, created_id) + api.get_ntsc_details(creds[2], typ, created_id) assert e.value.status_code == 404, f"user 2 should not be able to get details for {typ}" with pytest.raises(errors.APIException) as e: # user 2 should not be able to kill or know it exists. 
- api_utils.kill_ntsc(session, typ, created_id) + api_utils.kill_ntsc(creds[2], typ, created_id) assert e.value.status_code == 404, f"user 2 should not be able to kill {typ}" with pytest.raises(errors.APIException) as e: # user 2 should not be able to set priority or know it exists. - api_utils.set_prio_ntsc(session, typ, created_id, 1) + api_utils.set_prio_ntsc(creds[2], typ, created_id, 1) assert e.value.status_code == 404, f"user 2 should not be able to set priority {typ}" if typ == api.NTSC_Kind.tensorboard: with pytest.raises(errors.NotFoundException): - api_utils.launch_ntsc(session, workspaces[0].id, typ, experiment_id) + api_utils.launch_ntsc(creds[2], workspaces[0].id, typ, experiment_id) else: with pytest.raises(errors.ForbiddenException): - api_utils.launch_ntsc(session, workspaces[0].id, typ, experiment_id) + api_utils.launch_ntsc(creds[2], workspaces[0].id, typ, experiment_id) # user 2 has view access to workspace 1 so gets forbidden instead of not found. with pytest.raises(errors.ForbiddenException): - api_utils.launch_ntsc(session, workspaces[1].id, typ, experiment_id) + api_utils.launch_ntsc(creds[2], workspaces[1].id, typ, experiment_id) # test visibility - created_id2 = api_utils.launch_ntsc( - api_utils.determined_test_session(creds[0]), workspaces[2].id, typ, experiment_id - ).id + created_id2 = api_utils.launch_ntsc(creds[0], workspaces[2].id, typ, experiment_id).id # none of the users should be able to get details for cred in [creds[1], creds[2]]: - session = api_utils.determined_test_session(cred) # exception for creds[1], who can access the experiment and tensorboard if typ != api.NTSC_Kind.tensorboard and cred == creds[2]: with pytest.raises(errors.APIException) as e: - api.get_ntsc_details(session, typ, created_id2) + api.get_ntsc_details(cred, typ, created_id2) assert e.value.status_code == 404 - results = api_utils.list_ntsc(session, typ) + results = api_utils.list_ntsc(cred, typ) for r in results: if r.id == created_id2: 
pytest.fail(f"should not be able to see {typ} {r.id} in the list results") with pytest.raises(errors.APIException) as e: - api_utils.list_ntsc(session, typ, workspace_id=workspaces[2].id) + api_utils.list_ntsc(cred, typ, workspace_id=workspaces[2].id) # FIXME only notebooks return the correct 404. assert e.value.status_code == 404, f"{typ} should fail with 404" with pytest.raises(errors.APIException) as e: - api_utils.list_ntsc(session, typ, workspace_id=12532459) + api_utils.list_ntsc(cred, typ, workspace_id=12532459) assert e.value.status_code == 404, f"{typ} should fail with 404" # kill the ntsc - api_utils.kill_ntsc(api_utils.determined_test_session(creds[0]), typ, created_id) + api_utils.kill_ntsc(creds[0], typ, created_id) @pytest.mark.e2e_cpu_rbac -@pytest.mark.skipif(rbac_disabled(), reason="ee rbac is required for this test") +@api_utils.skipif_rbac_not_enabled() def test_ntsc_proxy() -> None: - def get_proxy(creds: authentication.Credentials, task_id: str) -> Optional[errors.APIException]: - session = api_utils.determined_test_session(creds) + def get_proxy(sess: api.Session, task_id: str) -> Optional[errors.APIException]: try: - session.get(f"proxy/{task_id}/") + sess.get(f"proxy/{task_id}/") return None except errors.APIException as e: return e - with create_workspaces_with_users( + with test_rbac.create_workspaces_with_users( [ [ (0, ["Viewer", "Editor"]), @@ -290,38 +273,32 @@ def get_proxy(creds: authentication.Credentials, task_id: str) -> Optional[error experiment_id = None if typ == api.NTSC_Kind.tensorboard: pid = bindings.post_PostProject( - api_utils.determined_test_session(creds[0]), + creds[0], body=bindings.v1PostProjectRequest(name="test", workspaceId=workspaces[0].id), workspaceId=workspaces[0].id, ).project.id - with logged_in_user(creds[0]): - # experiment for tensorboard - experiment_id = exp.create_experiment( - conf.fixtures_path("no_op/single.yaml"), - conf.fixtures_path("no_op"), - ["--project_id", str(pid)], - ) + # experiment for 
tensorboard + experiment_id = exp.create_experiment( + creds[0], + conf.fixtures_path("no_op/single.yaml"), + conf.fixtures_path("no_op"), + ["--project_id", str(pid)], + ) - created_id = api_utils.launch_ntsc( - api_utils.determined_test_session(creds[0]), workspaces[0].id, typ, experiment_id - ).id + created_id = api_utils.launch_ntsc(creds[0], workspaces[0].id, typ, experiment_id).id print(f"created {typ} {created_id}") api.wait_for_ntsc_state( - api_utils.determined_test_session(creds[0]), + creds[0], api.NTSC_Kind(typ), created_id, lambda s: s == bindings.taskv1State.RUNNING, timeout=300, ) - deets = api.get_ntsc_details( - api_utils.determined_test_session(creds[0]), typ, created_id - ) + deets = api.get_ntsc_details(creds[0], typ, created_id) assert deets.state == bindings.taskv1State.RUNNING, f"{typ} should be running" - err = api.task_is_ready( - api_utils.determined_test_session(conf.ADMIN_CREDENTIALS), created_id - ) + err = api.wait_for_task_ready(api_utils.admin_session(), created_id) assert err is None, f"{typ} should be ready {err}" assert ( get_proxy(creds[0], created_id) is None @@ -334,13 +311,13 @@ def get_proxy(creds: authentication.Credentials, task_id: str) -> Optional[error assert view_err.status_code == 404, f"user 2 should error out with not found{typ}" # kill the ntsc - api_utils.kill_ntsc(api_utils.determined_test_session(creds[0]), typ, created_id) + api_utils.kill_ntsc(creds[0], typ, created_id) @pytest.mark.e2e_cpu_rbac -@pytest.mark.skipif(rbac_disabled(), reason="ee rbac is required for this test") +@api_utils.skipif_rbac_not_enabled() def test_tsb_listed() -> None: - with create_workspaces_with_users( + with test_rbac.create_workspaces_with_users( [ [ (0, ["Editor"]), @@ -348,66 +325,63 @@ def test_tsb_listed() -> None: ], ] ) as ([workspace], creds): + editor_sess = creds[0] + viewer_sess = creds[1] pid = bindings.post_PostProject( - api_utils.determined_test_session(creds[0]), + editor_sess, 
body=bindings.v1PostProjectRequest(name="test", workspaceId=workspace.id), workspaceId=workspace.id, ).project.id - session = api_utils.determined_test_session(creds[0]) - - with logged_in_user(creds[0]): - # experiment for tensorboard - experiment_id = exp.create_experiment( - conf.fixtures_path("no_op/single.yaml"), - conf.fixtures_path("no_op"), - ["--project_id", str(pid)], - ) + # experiment for tensorboard + experiment_id = exp.create_experiment( + editor_sess, + conf.fixtures_path("no_op/single.yaml"), + conf.fixtures_path("no_op"), + ["--project_id", str(pid)], + ) - created_id = api_utils.launch_ntsc( - session, workspace.id, api.NTSC_Kind.tensorboard, experiment_id - ).id + created_id = api_utils.launch_ntsc( + editor_sess, workspace.id, api.NTSC_Kind.tensorboard, experiment_id + ).id - # list tensorboards and make sure it's included in the response. - tsbs = bindings.get_GetTensorboards(session, workspaceId=workspace.id).tensorboards - assert len(tsbs) == 1, "should be one tensorboard" - assert tsbs[0].id == created_id, "should be the tensorboard we created" + # list tensorboards and make sure it's included in the response. 
+ tsbs = bindings.get_GetTensorboards(editor_sess, workspaceId=workspace.id).tensorboards + assert len(tsbs) == 1, "should be one tensorboard" + assert tsbs[0].id == created_id, "should be the tensorboard we created" - tsbs = bindings.get_GetTensorboards( - api_utils.determined_test_session(credentials=creds[1]), workspaceId=workspace.id - ).tensorboards - assert len(tsbs) == 1, "should be one tensorboard" - assert tsbs[0].id == created_id, "should be the tensorboard we created" + tsbs = bindings.get_GetTensorboards(viewer_sess, workspaceId=workspace.id).tensorboards + assert len(tsbs) == 1, "should be one tensorboard" + assert tsbs[0].id == created_id, "should be the tensorboard we created" @pytest.mark.e2e_cpu_rbac -@pytest.mark.skipif(rbac_disabled(), reason="ee rbac is required for this test") +@api_utils.skipif_rbac_not_enabled() def test_tsb_launch_on_trials() -> None: - with create_workspaces_with_users( + with test_rbac.create_workspaces_with_users( [ [ (0, ["Editor"]), ], ] ) as ([workspace], creds): - session = api_utils.determined_test_session(creds[0]) pid = bindings.post_PostProject( - session, + creds[0], body=bindings.v1PostProjectRequest(name="test", workspaceId=workspace.id), workspaceId=workspace.id, ).project.id - with logged_in_user(conf.ADMIN_CREDENTIALS): - experiment_id = exp.create_experiment( - conf.fixtures_path("no_op/single.yaml"), - conf.fixtures_path("no_op"), - ["--project_id", str(pid)], - ) + experiment_id = exp.create_experiment( + api_utils.admin_session(), + conf.fixtures_path("no_op/single.yaml"), + conf.fixtures_path("no_op"), + ["--project_id", str(pid)], + ) - trials = bindings.get_GetExperimentTrials(session, experimentId=experiment_id).trials + trials = bindings.get_GetExperimentTrials(creds[0], experimentId=experiment_id).trials trial_ids = [t.id for t in trials] assert len(trial_ids) == 1, f"we should have 1 trial, but got {trial_ids}" bindings.post_LaunchTensorboard( - session, + creds[0], 
body=bindings.v1LaunchTensorboardRequest(workspaceId=workspace.id, trialIds=trial_ids), ).tensorboard.id diff --git a/e2e_tests/tests/cluster/test_resource_manager.py b/e2e_tests/tests/cluster/test_resource_manager.py index f0103cd82cc..d6be3c557d0 100644 --- a/e2e_tests/tests/cluster/test_resource_manager.py +++ b/e2e_tests/tests/cluster/test_resource_manager.py @@ -8,12 +8,10 @@ from determined.common import util from determined.common.api import bindings -from determined.common.api.bindings import experimentv1State from tests import api_utils from tests import config as conf from tests import experiment as exp - -from .test_agent_disable import _wait_for_slots +from tests.cluster import test_agent_disable # How long we should for the Nth = 1 rank to free. RANK_ONE_WAIT_TIME = 300 @@ -26,10 +24,12 @@ def test_allocation_resources_incremental_release() -> None: Start an two container experiment and ensure one container exits before the other. Ensure resources are released and schedule-able without the other needing to be released. """ + admin = api_utils.admin_session() + sess = api_utils.user_session() cleanup_exp_ids = [] try: - slots = _wait_for_slots(2) + slots = test_agent_disable._wait_for_slots(admin, 2) assert len(slots) == 2 with tempfile.TemporaryDirectory() as context_dir, open( @@ -48,15 +48,16 @@ def test_allocation_resources_incremental_release() -> None: conf.fixtures_path("no_op/model_def.py"), os.path.join(context_dir, "model_def.py") ) - exp_id = exp.create_experiment(config_file.name, context_dir, None) + exp_id = exp.create_experiment(sess, config_file.name, context_dir, None) cleanup_exp_ids.append(exp_id) # Wait for the experiment to start and run some. 
exp.wait_for_experiment_state( + sess, exp_id, - experimentv1State.RUNNING, + bindings.experimentv1State.RUNNING, ) - exp.wait_for_experiment_active_workload(exp_id) + exp.wait_for_experiment_active_workload(sess, exp_id) # And wait for exactly one of the resources to free, while one is still in use. confirmations = 0 @@ -79,14 +80,15 @@ def test_allocation_resources_incremental_release() -> None: # Ensure we can schedule on the free slot, not only that the API says its available. exp_id_2 = exp.create_experiment( + sess, conf.fixtures_path("no_op/single.yaml"), conf.fixtures_path("no_op"), None, ) cleanup_exp_ids.append(exp_id_2) - exp.wait_for_experiment_workload_progress(exp_id_2) - exp.wait_for_experiment_state(exp_id_2, experimentv1State.COMPLETED) + exp.wait_for_experiment_workload_progress(sess, exp_id_2) + exp.wait_for_experiment_state(sess, exp_id_2, bindings.experimentv1State.COMPLETED) cleanup_exp_ids = cleanup_exp_ids[:-1] # And check the hung experiment still is holding on to its hung slot. 
@@ -96,12 +98,12 @@ def test_allocation_resources_incremental_release() -> None: finally: for exp_id in cleanup_exp_ids: - bindings.post_KillExperiment(api_utils.determined_test_session(), id=exp_id) - exp.wait_for_experiment_state(exp_id, experimentv1State.CANCELED) + bindings.post_KillExperiment(sess, id=exp_id) + exp.wait_for_experiment_state(sess, exp_id, bindings.experimentv1State.CANCELED) def list_free_agents() -> List[bindings.v1Agent]: - agents = bindings.get_GetAgents(api_utils.determined_test_session()) + agents = bindings.get_GetAgents(api_utils.user_session()) if not agents.agents: pytest.fail(f"missing agents: {agents}") @@ -111,11 +113,14 @@ def list_free_agents() -> List[bindings.v1Agent]: @pytest.mark.e2e_cpu_2a @pytest.mark.timeout(600) def test_experiment_is_single_node() -> None: - slots = _wait_for_slots(2) + admin = api_utils.admin_session() + sess = api_utils.user_session() + slots = test_agent_disable._wait_for_slots(admin, 2) assert len(slots) == 2 with pytest.raises(AssertionError): exp.create_experiment( + sess, conf.fixtures_path("no_op/single.yaml"), conf.fixtures_path("no_op"), [ diff --git a/e2e_tests/tests/cluster/test_resource_pools.py b/e2e_tests/tests/cluster/test_resource_pools.py index 4fbac35cbaf..43be9b88014 100644 --- a/e2e_tests/tests/cluster/test_resource_pools.py +++ b/e2e_tests/tests/cluster/test_resource_pools.py @@ -2,6 +2,7 @@ from determined.common import util from determined.experimental import client +from tests import api_utils from tests import config as conf @@ -14,8 +15,8 @@ def test_default_pool_task_container_defaults() -> None: # task_container_defaults: # environment_variables: # - SOMEVAR=SOMEVAL - determined_master = conf.make_master_url() - d = client.Determined(determined_master) + sess = api_utils.user_session() + d = client.Determined._from_session(sess) config_path = conf.fixtures_path("no_op/single-medium-train-step.yaml") e1 = d.create_experiment( config=config_path, diff --git 
a/e2e_tests/tests/cluster/test_slurm.py b/e2e_tests/tests/cluster/test_slurm.py index 217bf54c80b..af4ed1c6236 100644 --- a/e2e_tests/tests/cluster/test_slurm.py +++ b/e2e_tests/tests/cluster/test_slurm.py @@ -5,29 +5,33 @@ import pytest import torch +from determined.common import api from determined.common.api import bindings from tests import api_utils, command from tests import config as conf from tests import experiment as exp -def run_failure_test_multiple(config_file: str, model_def_file: str, errors: List[str]) -> int: +def run_failure_test_multiple( + sess: api.Session, config_file: str, model_def_file: str, errors: List[str] +) -> int: # Creates an experiment meant to fail and checks array of error messages # If one of the errors are present, then the assertion passes experiment_id = exp.create_experiment( + sess, config_file, model_def_file, ) exp.wait_for_experiment_state( - experiment_id, bindings.experimentv1State.ERROR, max_wait_secs=600 + sess, experiment_id, bindings.experimentv1State.ERROR, max_wait_secs=600 ) - trials = exp.experiment_trials(experiment_id) + trials = exp.experiment_trials(sess, experiment_id) for t in trials: trial = t.trial if trial.state != bindings.trialv1State.ERROR: continue - logs = exp.trial_logs(trial.id) + logs = exp.trial_logs(sess, trial.id) totalAssertion = False for e in errors: totalAssertion = totalAssertion or any(e in line for line in logs) @@ -49,6 +53,7 @@ def run_failure_test_multiple(config_file: str, model_def_file: str, errors: Lis @pytest.mark.e2e_pbs @api_utils.skipif_not_hpc() def test_unsupported_option() -> None: + sess = api_utils.user_session() # Creates an experiment with a yaml file # It attempts to supply a slurm option that is controlled by Determined # run_failure_test expects the experiment to fail and will assert the log with the string @@ -56,6 +61,7 @@ def test_unsupported_option() -> None: # Waits for experiment to reach a ERROR_STATE. 
Errors if it does not error exp.run_failure_test( + sess, conf.fixtures_path("failures/unsupported-slurm-option.yaml"), conf.fixtures_path("failures/"), "resources failed with non-zero exit code: unable to launch job: " @@ -67,6 +73,7 @@ def test_unsupported_option() -> None: @pytest.mark.e2e_pbs @api_utils.skipif_not_hpc() def test_docker_image() -> None: + sess = api_utils.user_session() # Creates an experiment with a bad docker image file that will error errors = [ # Singularity message @@ -82,7 +89,7 @@ def test_docker_image() -> None: ] run_failure_test_multiple( - conf.fixtures_path("failures/bad-image.yaml"), conf.fixtures_path("failures/"), errors + sess, conf.fixtures_path("failures/bad-image.yaml"), conf.fixtures_path("failures/"), errors ) @@ -93,6 +100,7 @@ def test_docker_image() -> None: @pytest.mark.e2e_pbs @api_utils.skipif_not_hpc() def test_node_not_available() -> None: + sess = api_utils.user_session() # Creates an experiment with a configuration that cannot be satisfied. # Verifies that the error message includes the SBATCH options of the failed submission. # Only casablanca displays the SBATCH options. 
Horizon does not upon failure @@ -102,6 +110,7 @@ def test_node_not_available() -> None: error2 = "CPU count per node can not be satisfied" errors = [error1, error2] run_failure_test_multiple( + sess, conf.fixtures_path("failures/slurm-requested-node-not-available.yaml"), conf.fixtures_path("failures/"), errors, @@ -109,7 +118,9 @@ def test_node_not_available() -> None: def bad_option_helper(config_path: str, fixture_path: str, error_string: str) -> None: + sess = api_utils.user_session() exp.run_failure_test( + sess, conf.fixtures_path(config_path), conf.fixtures_path(fixture_path), error_string, @@ -150,7 +161,9 @@ def test_docker_login() -> None: ) errorDocker = "lstat /root/.config/containers/registries.conf.d: permission denied" errors = [errorPermission, errorDocker] + sess = api_utils.user_session() run_failure_test_multiple( + sess, conf.fixtures_path("failures/docker-login-failure.yaml"), conf.fixtures_path("failures/"), errors, @@ -162,8 +175,10 @@ def test_docker_login() -> None: @pytest.mark.e2e_slurm_misconfigured @api_utils.skipif_not_slurm() def test_master_host() -> None: + sess = api_utils.user_session() # Creates an experiment normally, should error if the back communication channel is broken exp.run_failure_test( + sess, conf.fixtures_path("no_op/single-one-short-step.yaml"), conf.fixtures_path("no_op"), "Unable to reach the master at DET_MASTER=http://junkmaster:8080. 
" @@ -178,12 +193,13 @@ def test_master_host() -> None: @pytest.mark.e2e_pbs @api_utils.skipif_not_hpc() def test_mnist_pytorch_distributed() -> None: + sess = api_utils.user_session() config = conf.load_config(conf.tutorials_path("mnist_pytorch/distributed.yaml")) config["searcher"]["max_length"] = {"epochs": 1} config["records_per_epoch"] = 5000 config["max_restarts"] = 0 - exp.run_basic_test_with_temp_config(config, conf.tutorials_path("mnist_pytorch"), 1) + exp.run_basic_test_with_temp_config(sess, config, conf.tutorials_path("mnist_pytorch"), 1) # Test to ensure that determined is able to handle preemption gracefully when using dispatcher RM. @@ -207,9 +223,11 @@ def test_mnist_pytorch_distributed() -> None: @pytest.mark.e2e_slurm_preemption @api_utils.skipif_not_slurm() def test_slurm_preemption() -> None: + sess = api_utils.user_session() # Launch the cifar10_pytorch_cancellable experiment requesting 8 GPUs on defq_GPU_cancellable # partition cancelable_exp_id = exp.create_experiment( + sess, conf.cv_examples_path( "../legacy/computer_vision/cifar10_pytorch/cifar10_pytorch_cancelable.yaml" ), @@ -217,12 +235,13 @@ def test_slurm_preemption() -> None: None, ) # Wait for the first cancellable experiment to enter RUNNING state. - exp.wait_for_experiment_state(cancelable_exp_id, bindings.experimentv1State.RUNNING) + exp.wait_for_experiment_state(sess, cancelable_exp_id, bindings.experimentv1State.RUNNING) # Wait for the first cancellable experiment to complete at least one checkpoint. 
- exp.wait_for_at_least_one_checkpoint(cancelable_exp_id, 300) + exp.wait_for_at_least_one_checkpoint(sess, cancelable_exp_id, 300) # Launch the cifar10_pytorch_high_priority experiment requesting 8 GPUs on defq_GPU_hipri # partition high_priority_exp_id = exp.create_experiment( + sess, conf.cv_examples_path( "../legacy/computer_vision/cifar10_pytorch/cifar10_pytorch_high_priority.yaml" ), @@ -232,15 +251,15 @@ def test_slurm_preemption() -> None: # In this scenario, cifar10_pytorch_high_priority experiment will cause the # cifar10_pytorch_cancelable experiment to get requeued. The experiment # cifar10_pytorch_high_priority will execute to completion. - exp.wait_for_experiment_state(cancelable_exp_id, bindings.experimentv1State.QUEUED) - exp.wait_for_experiment_state(high_priority_exp_id, bindings.experimentv1State.RUNNING) - exp.wait_for_experiment_state(high_priority_exp_id, bindings.experimentv1State.COMPLETED) + exp.wait_for_experiment_state(sess, cancelable_exp_id, bindings.experimentv1State.QUEUED) + exp.wait_for_experiment_state(sess, high_priority_exp_id, bindings.experimentv1State.RUNNING) + exp.wait_for_experiment_state(sess, high_priority_exp_id, bindings.experimentv1State.COMPLETED) # Now, the experiment cifar10_pytorch_cancelable will resume as soon as the requested # resources are available. - exp.wait_for_experiment_state(cancelable_exp_id, bindings.experimentv1State.RUNNING) + exp.wait_for_experiment_state(sess, cancelable_exp_id, bindings.experimentv1State.RUNNING) # Finally, the experiment cifar10_pytorch_cancelable will complete if there are no other # interruptions. - exp.wait_for_experiment_state(cancelable_exp_id, bindings.experimentv1State.COMPLETED) + exp.wait_for_experiment_state(sess, cancelable_exp_id, bindings.experimentv1State.COMPLETED) @pytest.mark.e2e_slurm @@ -257,10 +276,11 @@ def test_start_and_verify_hpc_home() -> None: user $HOME even when the /etc/nsswitch.conf is mis-configured, so this test uses a shell. 
""" + sess = api_utils.user_session() foundLineWithUserName = False - with command.interactive_command("shell", "start") as shell: + with command.interactive_command(sess, ["shell", "start"]) as shell: # In order to identify whether we are running a Podman container, we will # check if we're running as "root", because Podman containers run as # "root". Use "$(whoami)" to report the user we running as, because the diff --git a/e2e_tests/tests/cluster/test_support-bundle.py b/e2e_tests/tests/cluster/test_support_bundle.py similarity index 59% rename from e2e_tests/tests/cluster/test_support-bundle.py rename to e2e_tests/tests/cluster/test_support_bundle.py index 93e5ebe5c76..8bd10ad7da6 100644 --- a/e2e_tests/tests/cluster/test_support-bundle.py +++ b/e2e_tests/tests/cluster/test_support_bundle.py @@ -3,28 +3,30 @@ import pytest +from tests import api_utils from tests import config as conf +from tests import detproc from tests import experiment as exp @pytest.mark.e2e_cpu def test_support_bundle() -> None: + sess = api_utils.user_session() exp_id = exp.run_basic_test( + sess, config_file=conf.fixtures_path("no_op/single-one-short-step.yaml"), model_def_file=conf.fixtures_path("no_op"), expected_trials=1, ) - trial_id = exp.experiment_first_trial(exp_id) + trial_id = exp.experiment_first_trial(sess, exp_id) output_dir = f"e2etest_trial_{trial_id}" os.mkdir(output_dir) command = ["det", "trial", "support-bundle", str(trial_id), "-o", output_dir] - completed_process = subprocess.run( - command, universal_newlines=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE + p = detproc.run( + sess, command, universal_newlines=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE ) - assert completed_process.returncode == 0, "\nstdout:\n{} \nstderr:\n{}".format( - completed_process.stdout, completed_process.stderr - ) + assert p.returncode == 0, f"\nstdout:\n{p.stdout} \nstderr:\n{p.stderr}" diff --git a/e2e_tests/tests/cluster/test_users.py 
b/e2e_tests/tests/cluster/test_users.py index 4d8a2dbd126..33eb3ed8560 100644 --- a/e2e_tests/tests/cluster/test_users.py +++ b/e2e_tests/tests/cluster/test_users.py @@ -1,144 +1,36 @@ import contextlib -import json import logging import os import pathlib -import re import shutil import subprocess import time -import uuid -from typing import Dict, Generator, Iterator, List, Optional, Tuple, cast +from typing import Generator, List, Tuple import appdirs -import pexpect import pytest -from pexpect import spawn -from determined.common import api, constants, util -from determined.common.api import authentication, bindings, certs, errors -from determined.experimental import Determined +from determined.common import api, util +from determined.common.api import bindings, errors +from determined.experimental import client from tests import api_utils, command from tests import config as conf +from tests import detproc from tests import experiment as exp -from tests.filetree import FileTree +from tests import filetree EXPECT_TIMEOUT = 5 logger = logging.getLogger(__name__) -@pytest.fixture() -def clean_auth() -> Iterator[None]: - """ - clean_auth is a session-level fixture that ensures that we run tests with no preconfigured - authentication, and that any settings we save during tests are cleaned up afterwards. 
- """ - authentication.TokenStore(conf.make_master_url()).delete_token_cache() - yield None - authentication.TokenStore(conf.make_master_url()).delete_token_cache() - - -@pytest.fixture() -def login_admin() -> None: - a_username, a_password = conf.ADMIN_CREDENTIALS - child = det_spawn(["user", "login", a_username]) - child.setecho(True) - expected = f"Password for user '{a_username}':" - child.expect(expected, timeout=EXPECT_TIMEOUT) - child.sendline(a_password) - child.read() - child.wait() - child.close() - assert child.exitstatus == 0 - - -@contextlib.contextmanager -def logged_in_user(credentials: authentication.Credentials) -> Generator: - api_utils.configure_token_store(credentials) - yield - log_out_user() - - -@pytest.mark.e2e_cpu -def test_logged_in_user() -> None: - with logged_in_user(conf.ADMIN_CREDENTIALS): - output = det_run(["user", "whoami"]) - assert f"You are logged in as user '{conf.ADMIN_CREDENTIALS.username}'" in output - - -def get_random_string() -> str: - return str(uuid.uuid4()) - - -def det_spawn(args: List[str], env: Optional[Dict[str, str]] = None) -> spawn: - args = ["-m", conf.make_master_url()] + args - return pexpect.spawn("det", args, env=env) - - -def det_run(args: List[str]) -> str: - return cast(str, pexpect.run(f"det -m {conf.make_master_url()} {' '.join(args)}").decode()) - - -def log_in_user_cli(credentials: authentication.Credentials, expectedStatus: int = 0) -> None: - username, password = credentials - child = det_spawn(["user", "login", username]) - child.setecho(True) - expected = re.escape(f"Password for user '{username}': ") - - child.expect(expected, EXPECT_TIMEOUT) - child.sendline(password) - child.read() - child.wait() - child.close() - assert child.exitstatus == expectedStatus - - -def change_user_password( - target_username: str, - target_password: str, - own: bool = False, -) -> int: - cmd = ["user", "change-password"] - if not own: - cmd.append(target_username) - - child = det_spawn(cmd) - 
expected_new_pword_prompt = f"New password for user '{target_username}':" - confirm_pword_prompt = "Confirm password:" - child.expect(expected_new_pword_prompt, timeout=EXPECT_TIMEOUT) - - child.sendline(target_password) - child.expect(confirm_pword_prompt, timeout=EXPECT_TIMEOUT) - child.sendline(target_password) - child.read() - child.wait() - child.close() - return cast(int, child.exitstatus) - - -def log_out_user(username: Optional[str] = None) -> None: - if username is not None: - args = ["-u", username, "user", "logout"] - else: - args = ["user", "logout"] - - child = det_spawn(args) - child.read() - child.wait() - child.close() - assert child.exitstatus == 0 - - -def activate_deactivate_user(active: bool, target_user: str) -> None: +def activate_deactivate_user(sess: api.Session, active: bool, target_user: str) -> None: command = [ "det", - "-m", - conf.make_master_url(), "user", "activate" if active else "deactivate", target_user, ] - subprocess.run(command, check=True) + detproc.check_output(sess, command) def extract_columns(output: str, column_indices: List[int]) -> List[Tuple[str, ...]]: @@ -161,10 +53,10 @@ def extract_id_and_owner_from_exp_list(output: str) -> List[Tuple[int, str]]: @pytest.mark.e2e_cpu -def test_post_user_api(clean_auth: None, login_admin: None) -> None: - new_username = get_random_string() +def test_post_user_api() -> None: + new_username = api_utils.get_random_string() - sess = api_utils.determined_test_session(admin=True) + sess = api_utils.admin_session() user = bindings.v1User(active=True, admin=False, username=new_username) body = bindings.v1PostUserRequest(password="", user=user) @@ -176,7 +68,7 @@ def test_post_user_api(clean_auth: None, login_admin: None) -> None: user = bindings.v1User( active=True, admin=False, - username=get_random_string(), + username=api_utils.get_random_string(), agentUserGroup=bindings.v1AgentUserGroup( agentUid=1000, agentGid=1001, agentUser="username", agentGroup="groupname" ), @@ -191,7 +83,7 @@ 
def test_post_user_api(clean_auth: None, login_admin: None) -> None: user = bindings.v1User( active=True, admin=False, - username=get_random_string(), + username=api_utils.get_random_string(), agentUserGroup=bindings.v1AgentUserGroup( agentUid=1000, agentGid=1001, @@ -203,346 +95,199 @@ def test_post_user_api(clean_auth: None, login_admin: None) -> None: @pytest.mark.e2e_cpu -def test_create_user_sdk(clean_auth: None, login_admin: None) -> None: - username = get_random_string() - password = get_random_string() - det_obj = Determined(master=conf.make_master_url()) +def test_create_user_sdk() -> None: + username = api_utils.get_random_string() + password = api_utils.get_random_string() + det_obj = client.Determined._from_session(api_utils.admin_session()) user = det_obj.create_user(username=username, admin=False, password=password) assert user.user_id is not None and user.username == username @pytest.mark.e2e_cpu -def test_logout(clean_auth: None, login_admin: None) -> None: - # Tests fallback to default determined user - creds = api_utils.create_test_user(True) - - # Set Determined password to something in order to disable auto-login. - password = get_random_string() - assert change_user_password(constants.DEFAULT_DETERMINED_USER, password) == 0 - - # Log in as new user. - api_utils.configure_token_store(creds) - # Now we should be able to list experiments. - child = det_spawn(["e", "list"]) - child.read() - child.wait() - child.close() - assert child.status == 0 - - # Exiting the logged_in_user context logs out and asserts that the exit code is 0. - log_out_user() - # Now trying to list experiments should result in an error. - child = det_spawn(["e", "list"]) - expected = "Unauthenticated" - assert expected in str(child.read()) - child.wait() - child.close() - assert child.status != 0 - - # Log in as determined. - api_utils.configure_token_store( - authentication.Credentials(constants.DEFAULT_DETERMINED_USER, password) - ) - - # Log back in as new user. 
- api_utils.configure_token_store(creds) +def test_logout() -> None: + # Make sure that a logged out session cannot be reused. + sess = api_utils.make_session("determined", "") - # Now log out as determined. - log_out_user(constants.DEFAULT_DETERMINED_USER) - - # Should still be able to list experiments because new user is logged in. - child = det_spawn(["e", "list"]) - child.read() - child.wait() - child.close() - assert child.status == 0 - - # Change Determined password back to "". - change_user_password(constants.DEFAULT_DETERMINED_USER, "") - # Clean up. + bindings.post_Logout(sess) + with pytest.raises(errors.UnauthenticatedException): + bindings.get_GetMe(sess) @pytest.mark.e2e_cpu @pytest.mark.e2e_cpu_postgres -def test_activate_deactivate(clean_auth: None, login_admin: None) -> None: - creds = api_utils.create_test_user(True) - - # Make sure we can log in as the user. - api_utils.configure_token_store(creds) - - # Log out. - log_out_user() - - # login admin again. - api_utils.configure_token_store(conf.ADMIN_CREDENTIALS) +def test_activate_deactivate() -> None: + sess, password = api_utils.create_test_user() # Deactivate user. - activate_deactivate_user(False, creds.username) + admin = api_utils.admin_session() + activate_deactivate_user(admin, False, sess.username) - # Attempt to log in again. It should have a non-zero exit status. - log_in_user_cli(creds, 1) + # Attempt to log in again. + with pytest.raises(errors.ForbiddenException): + api_utils.make_session(sess.username, password) # Activate user. - activate_deactivate_user(True, creds.username) + activate_deactivate_user(admin, True, sess.username) - # Now log in again. It should have a non-zero exit status. - api_utils.configure_token_store(creds) + # Now log in again. + api_utils.make_session(sess.username, password) # SDK testing for activating and deactivating. 
- api_utils.configure_token_store(conf.ADMIN_CREDENTIALS) - det_obj = Determined(master=conf.make_master_url()) - user = det_obj.get_user_by_name(user_name=creds.username) + det_obj = client.Determined._from_session(admin) + user = det_obj.get_user_by_name(user_name=sess.username) user.deactivate() assert user.active is not True + with pytest.raises(errors.ForbiddenException): + api_utils.make_session(sess.username, password) + user.activate() assert user.active is True - - # Now log in again. - api_utils.configure_token_store(creds) + api_utils.make_session(sess.username, password) @pytest.mark.e2e_cpu @pytest.mark.e2e_cpu_postgres -def test_change_password(clean_auth: None, login_admin: None) -> None: - # Create a user without a password. - creds = api_utils.create_test_user(False) +def test_change_password() -> None: + sess, old_password = api_utils.create_test_user() + d = client.Determined._from_session(api_utils.admin_session()) + userobj = d.get_user_by_name(sess.username) + userobj.change_password("newpass") - # Attempt to log in. - api_utils.configure_token_store(creds) + # Old password does not work anymore. + with pytest.raises(errors.UnauthenticatedException): + api_utils.make_session(sess.username, old_password) - # Log out. - log_out_user() + # New password does work. + api_utils.make_session(sess.username, "newpass") - # login admin - api_utils.configure_token_store(conf.ADMIN_CREDENTIALS) - new_password = get_random_string() - assert change_user_password(creds.username, new_password) == 0 - api_utils.configure_token_store(authentication.Credentials(creds.username, new_password)) +@pytest.mark.e2e_cpu +def test_change_own_password() -> None: + # Create a user without a password. 
+ sess, old_password = api_utils.create_test_user() - new_password_sdk = get_random_string() - det_obj = Determined(master=conf.make_master_url()) - user = det_obj.get_user_by_name(user_name=creds.username) - user.change_password(new_password=new_password_sdk) - api_utils.configure_token_store(authentication.Credentials(creds.username, new_password_sdk)) + d = client.Determined._from_session(sess) + userobj = d.get_user_by_name(sess.username) + userobj.change_password("newpass") + with pytest.raises(errors.UnauthenticatedException): + api_utils.make_session(sess.username, old_password) -@pytest.mark.e2e_cpu -def test_change_own_password(clean_auth: None, login_admin: None) -> None: - # Create a user without a password. - creds = api_utils.create_test_user(False) - log_in_user_cli(creds) - assert change_user_password(creds.username, creds.password, own=True) == 0 + api_utils.make_session(sess.username, "newpass") @pytest.mark.e2e_cpu -def test_change_username(clean_auth: None, login_admin: None) -> None: - creds = api_utils.create_test_user() +def test_change_username() -> None: + admin = api_utils.admin_session() + sess, _ = api_utils.create_test_user() + old_username = sess.username new_username = "rename-user-64" - command = ["det", "-m", conf.make_master_url(), "user", "rename", creds.username, new_username] - subprocess.run(command, check=True) - det_obj = Determined(master=conf.make_master_url()) - user = det_obj.get_user_by_name(user_name=new_username) + command = ["det", "user", "rename", old_username, new_username] + detproc.check_call(admin, command) + d = client.Determined._from_session(admin) + user = d.get_user_by_name(user_name=new_username) assert user.username == new_username - api_utils.configure_token_store(authentication.Credentials(new_username, "")) # Test SDK new_username = "rename-user-$64" user.rename(new_username) - user = det_obj.get_user_by_name(user_name=new_username) + user = d.get_user_by_name(user_name=new_username) assert 
user.username == new_username - api_utils.configure_token_store(authentication.Credentials(new_username, "")) @pytest.mark.e2e_cpu @pytest.mark.e2e_cpu_postgres @pytest.mark.e2e_cpu_cross_version -def test_experiment_creation_and_listing(clean_auth: None, login_admin: None) -> None: +def test_experiment_creation_and_listing() -> None: # Create 2 users. - creds1 = api_utils.create_test_user(True) - - creds2 = api_utils.create_test_user(True) - - # Ensure determined creds are the default values. - change_user_password("determined", "") + sess1, _ = api_utils.create_test_user() + sess2, _ = api_utils.create_test_user() # Create an experiment as first user. - with logged_in_user(creds1): - experiment_id1 = exp.run_basic_test( - conf.fixtures_path("no_op/single.yaml"), conf.fixtures_path("no_op"), 1 - ) + experiment_id1 = exp.run_basic_test( + sess1, conf.fixtures_path("no_op/single.yaml"), conf.fixtures_path("no_op"), 1 + ) # Create another experiment, this time as second user. - with logged_in_user(creds2): - experiment_id2 = exp.run_basic_test( - conf.fixtures_path("no_op/single.yaml"), conf.fixtures_path("no_op"), 1 - ) + experiment_id2 = exp.run_basic_test( + sess2, conf.fixtures_path("no_op/single.yaml"), conf.fixtures_path("no_op"), 1 + ) - with logged_in_user(creds1): - # Now it should be the other way around. - output = extract_id_and_owner_from_exp_list(det_run(["e", "list"])) - assert (experiment_id1, creds1.username) in output - assert (experiment_id2, creds2.username) not in output + # user 1 can only see user 1 experiment + output = extract_id_and_owner_from_exp_list(detproc.check_output(sess1, ["det", "e", "list"])) + assert (experiment_id1, sess1.username) in output, output + assert (experiment_id2, sess2.username) not in output, output - # Now use the -a flag to list all experiments. The output should include both experiments. 
- output = extract_id_and_owner_from_exp_list(det_run(["e", "list", "-a"])) - assert (experiment_id1, creds1.username) in output - assert (experiment_id2, creds2.username) in output + # Now use the -a flag to list all experiments. The output should include both experiments. + output = extract_id_and_owner_from_exp_list( + detproc.check_output(sess1, ["det", "e", "list", "-a"]) + ) + assert (experiment_id1, sess1.username) in output, output + assert (experiment_id2, sess2.username) in output, output - with logged_in_user(conf.ADMIN_CREDENTIALS): - # Clean up. - delete_experiments(experiment_id1, experiment_id2) + # Clean up. + delete_experiments(api_utils.admin_session(), experiment_id1, experiment_id2) @pytest.mark.e2e_cpu -def test_login_wrong_password(clean_auth: None, login_admin: None) -> None: - creds = api_utils.create_test_user(True) - - passwd_prompt = f"Password for user '{creds.username}':" - child = det_spawn(["user", "login", creds.username]) - child.setecho(True) - child.expect(passwd_prompt, timeout=EXPECT_TIMEOUT) - child.sendline("this is the wrong password") - unauth_error = "Unauthenticated" - assert unauth_error in str(child.read()) - child.wait() - child.close() - - assert child.exitstatus != 0 +def test_login_wrong_password() -> None: + sess, password = api_utils.create_test_user() + with pytest.raises(errors.UnauthenticatedException): + api_utils.make_session(sess.username, "wrong" + password) @pytest.mark.e2e_cpu -def test_login_as_non_existent_user(clean_auth: None, login_admin: None) -> None: - username = "doesNotExist" - - passwd_prompt = f"Password for user '{username}':" - unauth_error = "Unauthenticated" - - child = det_spawn(["user", "login", username]) - child.setecho(True) - child.expect(passwd_prompt, timeout=EXPECT_TIMEOUT) - child.sendline("secret") - - assert unauth_error in str(child.read()) - child.wait() - child.close() - - assert child.exitstatus != 0 +def test_login_as_non_existent_user() -> None: + with 
pytest.raises(errors.UnauthenticatedException): + api_utils.make_session("nOtArEaLuSeR", "password") @pytest.mark.e2e_cpu -def test_auth_inside_shell(clean_auth: None, login_admin: None) -> None: - creds = api_utils.create_test_user(True) - - with logged_in_user(creds): - # start a shell - child = det_spawn(["shell", "start"]) - child.setecho(True) - # shells take time to start; use the default timeout which is longer - child.expect(r".*Permanently added.+([0-9a-f-]{36}).+known hosts\.") - - shell_id = child.match.group(1).decode("utf-8") - - def check_whoami(expected_username: str) -> None: - child.sendline("det user whoami") - child.expect("You are logged in as user \\'(.*)\\'", timeout=EXPECT_TIMEOUT) - username = child.match.group(1).decode("utf-8") - logger.debug(f"They are logged in as user {username}") - assert username == expected_username - - # check the current user - check_whoami(creds.username) - - # log in as admin - child.sendline(f"det user login {conf.ADMIN_CREDENTIALS.username}") - child.expect( - f"Password for user '{conf.ADMIN_CREDENTIALS.username}'", timeout=EXPECT_TIMEOUT - ) - child.sendline(conf.ADMIN_CREDENTIALS.password) +def test_login_as_non_active_user() -> None: + sess, password = api_utils.create_test_user() + admin = api_utils.admin_session() + d = client.Determined._from_session(admin) + userobj = d.get_user_by_name(sess.username) + userobj.deactivate() - # check that whoami responds with the new user - check_whoami(conf.ADMIN_CREDENTIALS.username) - - # log out - child.sendline("det user logout") - child.expect("#", timeout=EXPECT_TIMEOUT) - - # check that we are back to who we were - check_whoami(creds.username) - - child.sendline("exit") - - child = det_spawn(["shell", "kill", shell_id]) - child.read() - child.wait() - assert child.exitstatus == 0 + with pytest.raises(errors.ForbiddenException, match="user is not active"): + api_utils.make_session(sess.username, password) @pytest.mark.e2e_cpu -def 
test_login_as_non_active_user(clean_auth: None, login_admin: None) -> None: - creds = api_utils.create_test_user(True) - - passwd_prompt = f"Password for user '{creds.username}':" - unauth_error = "user is not active" - command = ["det", "-m", conf.make_master_url(), "user", "deactivate", creds.username] - subprocess.run(command, check=True) - - child = det_spawn(["user", "login", creds.username]) - child.setecho(True) - child.expect(passwd_prompt, timeout=EXPECT_TIMEOUT) - child.sendline(creds.password) - assert unauth_error in str(child.read()) - child.wait() - child.close() - - assert child.exitstatus != 0 - +def test_non_admin_user_link_with_agent_user() -> None: + sess1 = api_utils.user_session() + sess2, _ = api_utils.create_test_user() -@pytest.mark.e2e_cpu -def test_non_admin_user_link_with_agent_user(clean_auth: None, login_admin: None) -> None: - creds = api_utils.create_test_user(True) - unauth_error = r".*Forbidden.*" - - with logged_in_user(creds): - child = det_spawn( - [ - "user", - "link-with-agent-user", - creds.username, - "--agent-uid", - str(1), - "--agent-gid", - str(1), - "--agent-user", - creds.username, - "--agent-group", - creds.username, - ] - ) - child.expect(unauth_error, timeout=EXPECT_TIMEOUT) - child.read() - child.wait() - child.close() + cmd = [ + "det", + "user", + "link-with-agent-user", + sess2.username, + "--agent-uid", + "1", + "--agent-gid", + "1", + "--agent-user", + sess2.username, + "--agent-group", + sess2.username, + ] - assert child.exitstatus != 0 + detproc.check_error(sess1, cmd, "forbidden") @pytest.mark.e2e_cpu -def test_non_admin_commands(clean_auth: None, login_admin: None) -> None: - creds = api_utils.create_test_user() - api_utils.configure_token_store(creds) +def test_non_admin_commands() -> None: + sess = api_utils.user_session() command = [ "det", - "-m", - conf.make_master_url(), "slot", "list", "--json", ] - output = subprocess.check_output(command).decode() + slots = detproc.check_json(sess, command) - 
slots = json.loads(output) - assert len(slots) == 1 slot_id = slots[0]["slot_id"] agent_id = slots[0]["agent_id"] @@ -552,12 +297,7 @@ def test_non_admin_commands(clean_auth: None, login_admin: None) -> None: disable_agents = ["agent", "disable", agent_id] config = ["master", "config"] for cmd in [disable_slots, disable_agents, enable_slots, enable_agents, config]: - child = det_spawn(["-u", constants.DEFAULT_DETERMINED_USER] + cmd) - not_allowed = "Forbidden" - assert not_allowed in str(child.read()) - child.wait() - child.close() - assert child.exitstatus != 0 + detproc.check_error(sess, ["det", *cmd], "forbidden") def run_command(session: api.Session) -> str: @@ -566,29 +306,19 @@ def run_command(session: api.Session) -> str: return cmd.id -def start_notebook() -> str: - child = det_spawn(["notebook", "start", "-d"]) - notebook_id = cast(str, child.readline().decode().rstrip()) - child.read() - child.wait() - assert child.exitstatus == 0 +def start_notebook(sess: api.Session) -> str: + return detproc.check_output(sess, ["det", "notebook", "start", "-d"]).strip() - return notebook_id +def start_tensorboard(sess: api.Session, experiment_id: int) -> str: + cmd = ["det", "tensorboard", "start", "-d", str(experiment_id)] + return detproc.check_output(sess, cmd).strip() -def start_tensorboard(experiment_id: int) -> str: - child = det_spawn(["tensorboard", "start", "-d", str(experiment_id)]) - tensorboard_id = cast(str, child.readline().decode().rstrip()) - child.read() - child.wait() - assert child.exitstatus == 0 - return tensorboard_id - -def delete_experiments(*experiment_ids: int) -> None: +def delete_experiments(sess: api.Session, *experiment_ids: int) -> None: eids = set(experiment_ids) while eids: - output = extract_columns(det_run(["e", "list", "-a"]), [0, 4]) + output = extract_columns(detproc.check_output(sess, ["det", "e", "list", "-a"]), [0, 4]) running_ids = {int(o[0]) for o in output if o[1] == "COMPLETED"} intersection = eids & running_ids @@ -597,17 
+327,16 @@ def delete_experiments(*experiment_ids: int) -> None: continue experiment_id = intersection.pop() - child = det_spawn(["e", "delete", "--yes", str(experiment_id)]) - child.read() - child.wait() - assert child.exitstatus == 0 + detproc.check_output(sess, ["det", "e", "delete", "--yes", str(experiment_id)]) eids.remove(experiment_id) -def kill_notebooks(*notebook_ids: str) -> None: +def kill_notebooks(sess: api.Session, *notebook_ids: str) -> None: nids = set(notebook_ids) while nids: - output = extract_columns(det_run(["notebook", "list", "-a"]), [0, 3]) # id, state + output = extract_columns( + detproc.check_output(sess, ["det", "notebook", "list", "-a"]), [0, 3] + ) # id, state # Get set of running IDs. running_ids = {task_id for task_id, state in output if state == "RUNNING"} @@ -618,17 +347,16 @@ def kill_notebooks(*notebook_ids: str) -> None: continue notebook_id = intersection.pop() - child = det_spawn(["notebook", "kill", notebook_id]) - child.read() - child.wait() - assert child.exitstatus == 0 + detproc.check_output(sess, ["det", "notebook", "kill", notebook_id]) nids.remove(notebook_id) -def kill_tensorboards(*tensorboard_ids: str) -> None: +def kill_tensorboards(sess: api.Session, *tensorboard_ids: str) -> None: tids = set(tensorboard_ids) while tids: - output = extract_columns(det_run(["tensorboard", "list", "-a"]), [0, 3]) + output = extract_columns( + detproc.check_output(sess, ["det", "tensorboard", "list", "-a"]), [0, 3] + ) running_ids = {task_id for task_id, state in output if state == "RUNNING"} @@ -638,173 +366,157 @@ def kill_tensorboards(*tensorboard_ids: str) -> None: continue tensorboard_id = intersection.pop() - child = det_spawn(["tensorboard", "kill", tensorboard_id]) - child.read() - child.wait() - assert child.exitstatus == 0 + detproc.check_output(sess, ["det", "tensorboard", "kill", tensorboard_id]) tids.remove(tensorboard_id) @pytest.mark.e2e_cpu -def test_notebook_creation_and_listing(clean_auth: None, login_admin: None) 
-> None: - creds1 = api_utils.create_test_user(True) - creds2 = api_utils.create_test_user(True) +def test_notebook_creation_and_listing() -> None: + sess1, _ = api_utils.create_test_user() + sess2, _ = api_utils.create_test_user() - with logged_in_user(creds1): - notebook_id1 = start_notebook() + notebook_id1 = start_notebook(sess1) - with logged_in_user(creds2): - notebook_id2 = start_notebook() + notebook_id2 = start_notebook(sess2) - # Listing should only give us user 2's experiment. - output = extract_columns(det_run(["notebook", "list"]), [0, 1]) + # Listing should only give us user 2's experiment. + output = extract_columns(detproc.check_output(sess2, ["det", "notebook", "list"]), [0, 1]) - with logged_in_user(creds1): - output = extract_columns(det_run(["notebook", "list"]), [0, 1]) - assert (notebook_id1, creds1.username) in output - assert (notebook_id2, creds2.username) not in output + output = extract_columns(detproc.check_output(sess1, ["det", "notebook", "list"]), [0, 1]) + assert (notebook_id1, sess1.username) in output + assert (notebook_id2, sess2.username) not in output - # Now test listing all. - output = extract_columns(det_run(["notebook", "list", "-a"]), [0, 1]) - assert (notebook_id1, creds1.username) in output - assert (notebook_id2, creds2.username) in output + # Now test listing all. + output = extract_columns(detproc.check_output(sess1, ["det", "notebook", "list", "-a"]), [0, 1]) + assert (notebook_id1, sess1.username) in output + assert (notebook_id2, sess2.username) in output # Clean up, killing experiments. - kill_notebooks(notebook_id1, notebook_id2) + kill_notebooks(api_utils.admin_session(), notebook_id1, notebook_id2) @pytest.mark.e2e_cpu -def test_tensorboard_creation_and_listing(clean_auth: None, login_admin: None) -> None: - creds1 = api_utils.create_test_user(True) - creds2 = api_utils.create_test_user(True) - - with logged_in_user(creds1): - # Create an experiment. 
- experiment_id1 = exp.run_basic_test( - conf.fixtures_path("no_op/single.yaml"), conf.fixtures_path("no_op"), 1 - ) +def test_tensorboard_creation_and_listing() -> None: + sess1, _ = api_utils.create_test_user() + sess2, _ = api_utils.create_test_user() - with logged_in_user(creds1): - tensorboard_id1 = start_tensorboard(experiment_id1) + # Create an experiment. + experiment_id1 = exp.run_basic_test( + sess1, + conf.fixtures_path("no_op/single-one-short-step.yaml"), + conf.fixtures_path("no_op"), + 1, + ) - with logged_in_user(creds2): - experiment_id2 = exp.run_basic_test( - conf.fixtures_path("no_op/single.yaml"), conf.fixtures_path("no_op"), 1 - ) + tensorboard_id1 = start_tensorboard(sess1, experiment_id1) - with logged_in_user(creds2): - tensorboard_id2 = start_tensorboard(experiment_id2) + experiment_id2 = exp.run_basic_test( + sess2, + conf.fixtures_path("no_op/single-one-short-step.yaml"), + conf.fixtures_path("no_op"), + 1, + ) - with logged_in_user(creds1): - output = extract_columns(det_run(["tensorboard", "list"]), [0, 1]) - assert (tensorboard_id1, creds1.username) in output - assert (tensorboard_id2, creds2.username) not in output + tensorboard_id2 = start_tensorboard(sess2, experiment_id2) - output = extract_columns(det_run(["tensorboard", "list", "-a"]), [0, 1]) - assert (tensorboard_id1, creds1.username) in output - assert (tensorboard_id2, creds2.username) in output + output = extract_columns(detproc.check_output(sess1, ["det", "tensorboard", "list"]), [0, 1]) + assert (tensorboard_id1, sess1.username) in output + assert (tensorboard_id2, sess2.username) not in output - kill_tensorboards(tensorboard_id1, tensorboard_id2) + output = extract_columns( + detproc.check_output(sess1, ["det", "tensorboard", "list", "-a"]), [0, 1] + ) + assert (tensorboard_id1, sess1.username) in output + assert (tensorboard_id2, sess2.username) in output - with logged_in_user(conf.ADMIN_CREDENTIALS): - delete_experiments(experiment_id1, experiment_id2) + admin = 
api_utils.admin_session() + kill_tensorboards(admin, tensorboard_id1, tensorboard_id2) + delete_experiments(admin, experiment_id1, experiment_id2) @pytest.mark.e2e_cpu -def test_command_creation_and_listing(clean_auth: None) -> None: - creds1 = api_utils.create_test_user(True) - creds2 = api_utils.create_test_user(True) - session1 = api_utils.determined_test_session(credentials=creds1) - session2 = api_utils.determined_test_session(credentials=creds2) - - command_id1 = run_command(session=session1) +def test_command_creation_and_listing() -> None: + sess1, _ = api_utils.create_test_user() + sess2, _ = api_utils.create_test_user() - command_id2 = run_command(session=session2) + command_id1 = run_command(session=sess1) + command_id2 = run_command(session=sess2) - cmds = bindings.get_GetCommands(session1, users=[creds1.username]).commands + cmds = bindings.get_GetCommands(sess1, users=[sess1.username]).commands output = [(cmd.id, cmd.username) for cmd in cmds] - assert (command_id1, creds1.username) in output - assert (command_id2, creds2.username) not in output + assert (command_id1, sess1.username) in output + assert (command_id2, sess2.username) not in output - cmds = bindings.get_GetCommands(session1).commands + cmds = bindings.get_GetCommands(sess1).commands output = [(cmd.id, cmd.username) for cmd in cmds] - assert (command_id1, creds1.username) in output - assert (command_id2, creds2.username) in output - - -def create_linked_user(uid: int, user: str, gid: int, group: str) -> authentication.Credentials: - user_creds = api_utils.create_test_user(False) - - child = det_spawn( - [ - "user", - "link-with-agent-user", - user_creds.username, - "--agent-uid", - str(uid), - "--agent-gid", - str(gid), - "--agent-user", - user, - "--agent-group", - group, - ] - ) - child.read() - child.wait() - child.close() - assert child.exitstatus == 0 + assert (command_id1, sess1.username) in output + assert (command_id2, sess2.username) in output + + +def create_linked_user(uid: 
int, user: str, gid: int, group: str) -> api.Session: + admin = api_utils.admin_session() + sess, _ = api_utils.create_test_user() + + cmd = [ + "det", + "user", + "link-with-agent-user", + sess.username, + "--agent-uid", + str(uid), + "--agent-gid", + str(gid), + "--agent-user", + user, + "--agent-group", + group, + ] - return user_creds + detproc.check_call(admin, cmd) + return sess -def create_linked_user_sdk( - uid: int, agent_user: str, gid: int, group: str -) -> authentication.Credentials: - creds = api_utils.create_test_user(False) - det_obj = Determined(master=conf.make_master_url()) - user = det_obj.get_user_by_name(user_name=creds.username) + +def create_linked_user_sdk(uid: int, agent_user: str, gid: int, group: str) -> api.Session: + sess, _ = api_utils.create_test_user() + det_obj = client.Determined._from_session(api_utils.admin_session()) + user = det_obj.get_user_by_name(user_name=sess.username) user.link_with_agent(agent_gid=gid, agent_uid=uid, agent_group=group, agent_user=agent_user) - return creds + return sess -def check_link_with_agent_output(user: authentication.Credentials, expected_output: str) -> None: - with logged_in_user(user), command.interactive_command( - "cmd", "run", "bash", "-c", "echo $(id -u -n):$(id -u):$(id -g -n):$(id -g)" - ) as cmd: - for line in cmd.stdout: - if expected_output in line: - break - else: - raise AssertionError(f"Did not find {expected_output} in output") +def check_link_with_agent_output(sess: api.Session, expected_output: str) -> None: + assert expected_output in detproc.check_output( + sess, + ["det", "cmd", "run", "bash", "-c", "echo $(id -u -n):$(id -u):$(id -g -n):$(id -g)"], + ) @pytest.mark.e2e_cpu -def test_link_with_agent_user(clean_auth: None, login_admin: None) -> None: - user = create_linked_user(200, "someuser", 300, "somegroup") +def test_link_with_agent_user() -> None: + sess = create_linked_user(200, "someuser", 300, "somegroup") expected_output = "someuser:200:somegroup:300" - 
check_link_with_agent_output(user, expected_output) + check_link_with_agent_output(sess, expected_output) - with logged_in_user(conf.ADMIN_CREDENTIALS): - user_sdk = create_linked_user_sdk(210, "anyuser", 310, "anygroup") - expected_output = "anyuser:210:anygroup:310" - check_link_with_agent_output(user_sdk, expected_output) + sess = create_linked_user_sdk(210, "anyuser", 310, "anygroup") + expected_output = "anyuser:210:anygroup:310" + check_link_with_agent_output(sess, expected_output) @pytest.mark.e2e_cpu -def test_link_with_large_uid(clean_auth: None, login_admin: None) -> None: - user = create_linked_user(2000000000, "someuser", 2000000000, "somegroup") +def test_link_with_large_uid() -> None: + sess = create_linked_user(2000000000, "someuser", 2000000000, "somegroup") expected_output = "someuser:2000000000:somegroup:2000000000" - check_link_with_agent_output(user, expected_output) + check_link_with_agent_output(sess, expected_output) @pytest.mark.e2e_cpu -def test_link_with_existing_agent_user(clean_auth: None, login_admin: None) -> None: - user = create_linked_user(65534, "nobody", 65534, "nogroup") +def test_link_with_existing_agent_user() -> None: + sess = create_linked_user(65533, "det-nobody", 65533, "det-nobody") - expected_output = "nobody:65534:nogroup:65534" - check_link_with_agent_output(user, expected_output) + expected_output = "det-nobody:65533:det-nobody:65533" + check_link_with_agent_output(sess, expected_output) @contextlib.contextmanager @@ -835,177 +547,144 @@ def non_tmp_shared_fs_path() -> Generator: @pytest.mark.e2e_cpu -def test_non_root_experiment(clean_auth: None, login_admin: None, tmp_path: pathlib.Path) -> None: - user = create_linked_user(65534, "nobody", 65534, "nogroup") - - with logged_in_user(user): - with open(conf.fixtures_path("no_op/model_def.py")) as f: - model_def_content = f.read() - - with open(conf.fixtures_path("no_op/single-one-short-step.yaml")) as f: - config = util.yaml_safe_load(f) - - # Use a user-owned path to 
ensure shared_fs uses the container_path and not host_path. - with non_tmp_shared_fs_path() as host_path: - config["checkpoint_storage"] = { - "type": "shared_fs", - "host_path": host_path, - } - - # Call `det --version` in a startup hook to ensure that det is on the PATH. - with FileTree( - tmp_path, - { - "startup-hook.sh": "det --version || exit 77", - "const.yaml": util.yaml_safe_dump(config), - "model_def.py": model_def_content, - }, - ) as tree: - exp.run_basic_test(str(tree.joinpath("const.yaml")), str(tree), None) +def test_non_root_experiment(tmp_path: pathlib.Path) -> None: + sess = create_linked_user(65533, "det-nobody", 65533, "det-nobody") + with open(conf.fixtures_path("no_op/model_def.py")) as f: + model_def_content = f.read() -@pytest.mark.e2e_cpu -def test_link_without_agent_user(clean_auth: None, login_admin: None) -> None: - user = api_utils.create_test_user(False) - - expected_output = "root:0:root:0" - with logged_in_user(user), command.interactive_command( - "cmd", "run", "bash", "-c", "echo $(id -u -n):$(id -u):$(id -g -n):$(id -g)" - ) as cmd: - recvd = [] - for line in cmd.stdout: - if expected_output in line: - break - recvd.append(line) - else: - output = "".join(recvd) - raise AssertionError(f"Did not find {expected_output} in output:\n{output}") + with open(conf.fixtures_path("no_op/single-one-short-step.yaml")) as f: + config = util.yaml_safe_load(f) + + # Use a user-owned path to ensure shared_fs uses the container_path and not host_path. + with non_tmp_shared_fs_path() as host_path: + config["checkpoint_storage"] = { + "type": "shared_fs", + "host_path": host_path, + } + + # Call `det --version` in a startup hook to ensure that det is on the PATH. 
+ with filetree.FileTree( + tmp_path, + { + "startup-hook.sh": "det --version || exit 77", + "const.yaml": util.yaml_safe_dump(config), + "model_def.py": model_def_content, + }, + ) as tree: + exp.run_basic_test(sess, str(tree.joinpath("const.yaml")), str(tree), None) @pytest.mark.e2e_cpu -def test_non_root_shell(clean_auth: None, login_admin: None, tmp_path: pathlib.Path) -> None: - user = create_linked_user(1234, "someuser", 1234, "somegroup") +def test_link_without_agent_user() -> None: + sess, _ = api_utils.create_test_user() - expected_output = "someuser:1234:somegroup:1234" + check_link_with_agent_output(sess, "root:0:root:0") - with logged_in_user(user), command.interactive_command("shell", "start") as shell: - shell.stdin.write(b"echo $(id -u -n):$(id -u):$(id -g -n):$(id -g)\n") - shell.stdin.close() - for line in shell.stdout: - if expected_output in line: - break - else: - raise AssertionError(f"Did not find {expected_output} in output") +@pytest.mark.e2e_cpu +def test_non_root_shell(tmp_path: pathlib.Path) -> None: + # XXX: failing because prep_conatiner has login_with_cache(), which fails reading /.config + sess = create_linked_user(1234, "someuser", 1234, "somegroup") + exp = "someuser:1234:somegroup:1234" + cmd = "echo; echo $(id -u -n):$(id -u):$(id -g -n):$(id -g)" + with command.interactive_command(sess, ["shell", "start", "--detach"]) as shell: + assert shell.task_id + assert exp in detproc.check_output( + sess, ["det", "shell", "open", shell.task_id, "--", "bash", "-c", cmd] + ) @pytest.mark.e2e_cpu -def test_experiment_delete(clean_auth: None, login_admin: None) -> None: - user = api_utils.create_test_user() - non_owner_user = api_utils.create_test_user() +def test_experiment_delete() -> None: + sess = api_utils.user_session() + other, _ = api_utils.create_test_user() - with logged_in_user(user): - experiment_id = exp.run_basic_test( - conf.fixtures_path("no_op/single.yaml"), conf.fixtures_path("no_op"), 1 - ) + experiment_id = 
exp.run_basic_test( + sess, conf.fixtures_path("no_op/single.yaml"), conf.fixtures_path("no_op"), 1 + ) - with logged_in_user(non_owner_user): - # "det experiment delete" call should fail, because the user is not an admin and - # doesn't own the experiment. - child = det_spawn(["experiment", "delete", str(experiment_id), "--yes"]) - child.read() - child.wait() - assert child.exitstatus > 0 - - with logged_in_user(user): - child = det_spawn(["experiment", "delete", str(experiment_id), "--yes"]) - child.read() - child.wait() - assert child.exitstatus == 0 - - experiment_delete_deadline = time.time() + 5 * 60 - while 1: - child = det_spawn(["experiment", "describe", str(experiment_id)]) - child.read() - child.wait() - # "det experiment describe" call should fail, because the - # experiment is no longer in the database. - if child.exitstatus > 0: - return - elif time.time() > experiment_delete_deadline: - pytest.fail("experiment didn't delete after timeout") - - -def _fetch_user_by_username(sess: api.Session, username: str) -> bindings.v1User: - # Get API bindings object for the created test user - all_users = bindings.get_GetUsers(sess).users - assert all_users is not None - return next(u for u in all_users if u.username == username) + # "det experiment delete" call should fail, because the other user is not an admin and + # doesn't own the experiment. + cmd = ["det", "experiment", "delete", str(experiment_id), "--yes"] + detproc.check_error(other, cmd, "forbidden") + + # but the owner can delete it + detproc.check_output(sess, cmd) + + experiment_delete_deadline = time.time() + 5 * 60 + while True: + # "det experiment describe" call should fail, because the + # experiment is no longer in the database. 
+ p = detproc.run( + sess, + ["det", "experiment", "describe", str(experiment_id)], + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + ) + if p.returncode != 0: + assert p.stderr and b"not found" in p.stderr, p.stderr + return + elif time.time() > experiment_delete_deadline: + pytest.fail("experiment didn't delete after timeout") @pytest.mark.e2e_cpu @pytest.mark.e2e_cpu_postgres -def test_change_displayname(clean_auth: None, login_admin: None) -> None: - u_patch = api_utils.create_test_user(False) - original_name = u_patch.username - - master_url = conf.make_master_url() - certs.cli_cert = certs.default_load(master_url) - authentication.cli_auth = authentication.Authentication( - conf.make_master_url(), requested_user=original_name, password="" - ) - sess = api.Session(master_url, original_name, authentication.cli_auth, certs.cli_cert) +def test_change_displayname() -> None: + sess, _ = api_utils.create_test_user() + original_name = sess.username - current_user = _fetch_user_by_username(sess, original_name) - assert current_user is not None and current_user.id + det_obj = client.Determined._from_session(api_utils.admin_session()) + current_user = det_obj.get_user_by_name(original_name) + assert current_user is not None and current_user.user_id # Rename user using display name patch_user = bindings.v1PatchUser(displayName="renamed display-name") - bindings.patch_PatchUser(sess, body=patch_user, userId=current_user.id) + bindings.patch_PatchUser(sess, body=patch_user, userId=current_user.user_id) - modded_user = bindings.get_GetUser(sess, userId=current_user.id).user + modded_user = bindings.get_GetUser(sess, userId=current_user.user_id).user assert modded_user is not None assert modded_user.displayName == "renamed display-name" # Rename user display name using SDK - det_obj = Determined(master=conf.make_master_url()) - user = det_obj.get_user_by_id(user_id=current_user.id) + user = det_obj.get_user_by_id(user_id=current_user.user_id) 
user.change_display_name(display_name="renamedSDK") - modded_user_sdk = det_obj.get_user_by_id(user_id=current_user.id) + modded_user_sdk = det_obj.get_user_by_id(user_id=current_user.user_id) assert modded_user_sdk is not None assert modded_user_sdk.display_name == "renamedSDK" # Avoid display name of 'admin' patch_user.displayName = "Admin" with pytest.raises(errors.APIException): - bindings.patch_PatchUser(sess, body=patch_user, userId=current_user.id) + bindings.patch_PatchUser(sess, body=patch_user, userId=current_user.user_id) # Clear display name (UI will show username) patch_user.displayName = "" - bindings.patch_PatchUser(sess, body=patch_user, userId=current_user.id) + bindings.patch_PatchUser(sess, body=patch_user, userId=current_user.user_id) - modded_user = bindings.get_GetUser(sess, userId=current_user.id).user + modded_user = bindings.get_GetUser(sess, userId=current_user.user_id).user assert modded_user is not None assert modded_user.displayName == "" @pytest.mark.e2e_cpu -def test_patch_agentusergroup(clean_auth: None, login_admin: None) -> None: - test_user_credentials = api_utils.create_test_user(False) - test_username = test_user_credentials.username +def test_patch_agentusergroup() -> None: + sess, _ = api_utils.create_test_user() # Patch - normal. 
- sess = api_utils.determined_test_session(admin=True) + admin = api_utils.admin_session() + det_obj = client.Determined._from_session(admin) patch_user = bindings.v1PatchUser( agentUserGroup=bindings.v1AgentUserGroup( agentGid=1000, agentUid=1000, agentUser="username", agentGroup="groupname" ) ) - test_user = _fetch_user_by_username(sess, test_username) - assert test_user.id - bindings.patch_PatchUser(sess, body=patch_user, userId=test_user.id) - patched_user = bindings.get_GetUser(sess, userId=test_user.id).user + test_user = det_obj.get_user_by_name(sess.username) + assert test_user.user_id + bindings.patch_PatchUser(admin, body=patch_user, userId=test_user.user_id) + patched_user = bindings.get_GetUser(admin, userId=test_user.user_id).user assert patched_user is not None and patched_user.agentUserGroup is not None assert patched_user.agentUserGroup.agentUser == "username" assert patched_user.agentUserGroup.agentGroup == "groupname" @@ -1014,71 +693,27 @@ def test_patch_agentusergroup(clean_auth: None, login_admin: None) -> None: patch_user = bindings.v1PatchUser( agentUserGroup=bindings.v1AgentUserGroup(agentGid=1000, agentUid=1000) ) - test_user = _fetch_user_by_username(sess, test_username) - assert test_user.id + test_user = det_obj.get_user_by_name(sess.username) + assert test_user.user_id with pytest.raises(errors.APIException): - bindings.patch_PatchUser(sess, body=patch_user, userId=test_user.id) - - -@pytest.mark.e2e_cpu -def test_logout_all(clean_auth: None, login_admin: None) -> None: - creds = api_utils.create_test_user(True) - - # Set Determined password to something in order to disable auto-login. - password = get_random_string() - assert change_user_password(constants.DEFAULT_DETERMINED_USER, password) == 0 - - # Log in as determined. - api_utils.configure_token_store( - authentication.Credentials(constants.DEFAULT_DETERMINED_USER, password) - ) - # login test user. 
- api_utils.configure_token_store(creds) - child = det_spawn(["user", "logout", "--all"]) - child.wait() - child.close() - assert child.status == 0 - # Trying to list experiments should result in an error. - child = det_spawn(["e", "list"]) - expected = "Unauthenticated" - assert expected in str(child.read()) - child.wait() - child.close() - assert child.status != 0 - - # Log in as determined. - api_utils.configure_token_store( - authentication.Credentials(constants.DEFAULT_DETERMINED_USER, password) - ) - # Change Determined password back to "". - change_user_password(constants.DEFAULT_DETERMINED_USER, "") + bindings.patch_PatchUser(admin, body=patch_user, userId=test_user.user_id) @pytest.mark.e2e_cpu -def test_user_edit(clean_auth: None, login_admin: None) -> None: - u_patch = api_utils.create_test_user(False) - original_name = u_patch.username - - master_url = conf.make_master_url() - certs.cli_cert = certs.default_load(master_url) - authentication.cli_auth = authentication.Authentication( - master_url, requested_user=original_name, password="" - ) - sess = api.Session(master_url, original_name, authentication.cli_auth, certs.cli_cert) +def test_user_edit() -> None: + admin = api_utils.admin_session() + sess, _ = api_utils.create_test_user() + original_name = sess.username - current_user = _fetch_user_by_username(sess, original_name) + det_obj = client.Determined._from_session(admin) + current_user = det_obj.get_user_by_name(original_name) - # Log out. - log_out_user() + new_display_name = api_utils.get_random_string() + new_username = api_utils.get_random_string() - # login admin again. 
- api_utils.configure_token_store(conf.ADMIN_CREDENTIALS) - - new_display_name = get_random_string() - new_username = get_random_string() - - assert current_user is not None and current_user.id + assert current_user is not None and current_user.user_id command = [ + "det", "user", "edit", original_name, @@ -1090,13 +725,9 @@ def test_user_edit(clean_auth: None, login_admin: None) -> None: "--remote=false", "--admin=true", ] + detproc.check_output(admin, command) - child = det_spawn(command) - child.wait() - child.close() - assert child.status == 0 - - modded_user = bindings.get_GetUser(sess, userId=current_user.id).user + modded_user = bindings.get_GetUser(admin, userId=current_user.user_id).user assert modded_user is not None assert modded_user.displayName == new_display_name assert modded_user.username == new_username @@ -1106,73 +737,19 @@ def test_user_edit(clean_auth: None, login_admin: None) -> None: @pytest.mark.e2e_cpu -def test_user_edit_no_fields(clean_auth: None, login_admin: None) -> None: - u_patch = api_utils.create_test_user(False) - original_name = u_patch.username - - master_url = conf.make_master_url() - certs.cli_cert = certs.default_load(master_url) - authentication.cli_auth = authentication.Authentication( - master_url, requested_user=original_name, password="" - ) - sess = api.Session(master_url, original_name, authentication.cli_auth, certs.cli_cert) - - current_user = _fetch_user_by_username(sess, original_name) - - # Log out. - log_out_user() - - # login admin again. 
- api_utils.configure_token_store(conf.ADMIN_CREDENTIALS) - - assert current_user is not None and current_user.id - command = [ - "user", - "edit", - original_name, - ] - - # No edited field should result in error - child = det_spawn(command) - assert "No field provided" in str(child.read()) - child.wait() - child.close() - assert child.status != 0 - - -@pytest.mark.e2e_cpu -def test_user_list(clean_auth: None, login_admin: None) -> None: - u_patch = api_utils.create_test_user(False) - command = [ - "user", - "ls", - ] +def test_user_list() -> None: + admin = api_utils.admin_session() + sess, _ = api_utils.create_test_user() + output = detproc.check_output(admin, ["det", "user", "ls"]) + assert sess.username in output - child = det_spawn(command) - assert u_patch.username in str(child.read()) # Deactivate user - activate_deactivate_user(active=False, target_user=u_patch.username) - command = [ - "user", - "ls", - ] + activate_deactivate_user(admin, active=False, target_user=sess.username) # User should no longer appear in list - child = det_spawn(command) - assert u_patch.username not in str(child.read()) - - -@pytest.mark.e2e_cpu -def test_user_list_with_inactive(clean_auth: None, login_admin: None) -> None: - u_patch = api_utils.create_test_user(False) - command = ["user", "ls", "--all"] - - child = det_spawn(command) - assert u_patch.username in str(child.read()) - # Deactivate user - activate_deactivate_user(active=False, target_user=u_patch.username) - command = ["user", "ls", "--all"] + output = detproc.check_output(admin, ["det", "user", "ls"]) + assert sess.username not in output - # User should still appear in list - child = det_spawn(command) - assert u_patch.username in str(child.read()) + # User still appears with --all + output = detproc.check_output(admin, ["det", "user", "ls", "--all"]) + assert sess.username in output diff --git a/e2e_tests/tests/cluster/test_users_experiment_api.py b/e2e_tests/tests/cluster/test_users_experiment_api.py deleted file 
mode 100644 index 043508c834b..00000000000 --- a/e2e_tests/tests/cluster/test_users_experiment_api.py +++ /dev/null @@ -1,38 +0,0 @@ -import pytest - -from determined.common.api import bindings, errors -from determined.experimental import client -from tests import api_utils -from tests import config as conf -from tests import experiment as exp -from tests.cluster import test_users - - -@pytest.mark.e2e_cpu -def test_experiment_api_determined_disabled() -> None: - api_utils.configure_token_store(conf.ADMIN_CREDENTIALS) - - determined_master = conf.make_master_url() - user_creds = api_utils.create_test_user(add_password=True) - - child = test_users.det_spawn(["user", "deactivate", "determined"]) - child.wait() - child.close() - assert child.exitstatus == 0 - try: - d = client.Determined(determined_master, user_creds.username, user_creds.password) - e = d.create_experiment( - config=conf.fixtures_path("no_op/single-medium-train-step.yaml"), - model_dir=conf.fixtures_path("no_op"), - ) - - # Determined shouldn't be able to view the experiment since it is deactivated. 
- with pytest.raises(errors.ForbiddenException): - exp.wait_for_experiment_state(e.id, bindings.experimentv1State.COMPLETED) - - assert e.wait() == client.ExperimentState.COMPLETED - finally: - child = test_users.det_spawn(["user", "activate", "determined"]) - child.wait() - child.close() - assert child.exitstatus == 0 diff --git a/e2e_tests/tests/cluster/test_webhooks.py b/e2e_tests/tests/cluster/test_webhooks.py index 0da9e5ce202..d68e2266ae4 100644 --- a/e2e_tests/tests/cluster/test_webhooks.py +++ b/e2e_tests/tests/cluster/test_webhooks.py @@ -15,7 +15,7 @@ def test_slack_webhook() -> None: port = 5005 server = utils.WebhookServer(port, allow_dupes=True) - sess = api_utils.determined_test_session(admin=True) + sess = api_utils.admin_session() webhook_trigger = bindings.v1Trigger( triggerType=bindings.v1TriggerType.EXPERIMENT_STATE_CHANGE, @@ -32,15 +32,16 @@ def test_slack_webhook() -> None: assert result.webhook.url == webhook_request.url experiment_id = exp.create_experiment( - conf.fixtures_path("no_op/single-one-short-step.yaml"), conf.fixtures_path("no_op") + sess, conf.fixtures_path("no_op/single-one-short-step.yaml"), conf.fixtures_path("no_op") ) exp.wait_for_experiment_state( + sess, experiment_id, bindings.experimentv1State.COMPLETED, max_wait_secs=conf.DEFAULT_MAX_WAIT_SECS, ) - exp_config = exp.experiment_config_json(experiment_id) + exp_config = exp.experiment_config_json(sess, experiment_id) expected_field = {"type": "mrkdwn", "text": "*Status*: Completed"} expected_payload = { "blocks": [ @@ -84,7 +85,7 @@ def test_slack_webhook() -> None: def test_log_pattern_send_webhook(should_match: bool) -> None: port = 5006 server = utils.WebhookServer(port) - sess = api_utils.determined_test_session(admin=True) + sess = api_utils.admin_session() regex = r"assert 0 <= self\.metrics_sigma" if not should_match: @@ -116,11 +117,12 @@ def test_log_pattern_send_webhook(should_match: bool) -> None: ) exp_id = exp.create_experiment( + sess, 
conf.fixtures_path("no_op/single-medium-train-step.yaml"), conf.fixtures_path("no_op"), ["--config", "hyperparameters.metrics_sigma=-1.0"], ) - exp.wait_for_experiment_state(exp_id, bindings.experimentv1State.ERROR) + exp.wait_for_experiment_state(sess, exp_id, bindings.experimentv1State.ERROR) for _ in range(10): responses = server.return_responses() diff --git a/e2e_tests/tests/cluster/test_workspace_org.py b/e2e_tests/tests/cluster/test_workspace_org.py index f645a56f9bd..67e068f79e6 100644 --- a/e2e_tests/tests/cluster/test_workspace_org.py +++ b/e2e_tests/tests/cluster/test_workspace_org.py @@ -1,36 +1,25 @@ import contextlib +import http import os import tempfile import uuid -from http import HTTPStatus from typing import Generator, List, Optional import pytest from determined.common import api -from determined.common.api import authentication, bindings, errors -from determined.common.api._util import NTSC_Kind, wait_for_ntsc_state -from determined.common.api.errors import APIException +from determined.common.api import bindings, errors from tests import api_utils from tests import config as conf -from tests.cluster.test_users import change_user_password, logged_in_user -from tests.experiment import run_basic_test, wait_for_experiment_state - -from .test_agent_user_group import _delete_workspace_and_check -from .test_groups import det_cmd, det_cmd_json +from tests import detproc +from tests import experiment as exp +from tests.cluster import test_agent_user_group @pytest.mark.e2e_cpu def test_workspace_org() -> None: - with logged_in_user(conf.ADMIN_CREDENTIALS): - change_user_password("determined", "") - master_url = conf.make_master_url() - authentication.cli_auth = authentication.Authentication(master_url) - sess = api.Session(master_url, None, None, None) - admin_auth = authentication.Authentication( - master_url, conf.ADMIN_CREDENTIALS.username, conf.ADMIN_CREDENTIALS.password - ) - admin_sess = api.Session(master_url, conf.ADMIN_CREDENTIALS.username, 
admin_auth, None) + sess = api_utils.user_session() + admin_sess = api_utils.admin_session() test_experiments: List[bindings.v1Experiment] = [] test_projects: List[bindings.v1Project] = [] @@ -321,12 +310,12 @@ def test_workspace_org() -> None: assert len(returned_notes) == 1 # Create an experiment in the default project. - test_exp_id = run_basic_test( - conf.fixtures_path("no_op/single.yaml"), conf.fixtures_path("no_op"), 1 + test_exp_id = exp.run_basic_test( + sess, conf.fixtures_path("no_op/single.yaml"), conf.fixtures_path("no_op"), 1 ) test_exp = bindings.get_GetExperiment(sess, experimentId=test_exp_id).experiment test_experiments.append(test_exp) - wait_for_experiment_state(test_exp_id, bindings.experimentv1State.COMPLETED) + exp.wait_for_experiment_state(sess, test_exp_id, bindings.experimentv1State.COMPLETED) assert test_exp.projectId == default_project.id # Move the test experiment into a user-made project @@ -370,7 +359,7 @@ def test_workspace_org() -> None: duplicate_workspace = r7.workspace assert duplicate_workspace is not None test_workspaces.append(duplicate_workspace) - with pytest.raises(APIException) as e: + with pytest.raises(errors.APIException) as e: r8 = bindings.post_PostWorkspace( sess, body=bindings.v1PostWorkspaceRequest(name="_TestDuplicate") ) @@ -378,7 +367,7 @@ def test_workspace_org() -> None: assert failed_duplicate_workspace is None if failed_duplicate_workspace is not None: test_workspaces.append(failed_duplicate_workspace) - assert e.value.status_code == HTTPStatus.CONFLICT + assert e.value.status_code == http.HTTPStatus.CONFLICT # Refuse to change a workspace name to an existing name r9 = bindings.post_PostWorkspace( @@ -389,9 +378,9 @@ def test_workspace_org() -> None: test_workspaces.append(duplicate_patch_workspace) w_patch = bindings.v1PatchWorkspace.from_json(made_workspace.to_json()) w_patch.name = "_TestDuplicate" - with pytest.raises(APIException) as e: + with pytest.raises(errors.APIException) as e: 
bindings.patch_PatchWorkspace(sess, body=w_patch, id=made_workspace.id) - assert e.value.status_code == HTTPStatus.CONFLICT + assert e.value.status_code == http.HTTPStatus.CONFLICT finally: # Clean out workspaces and all dependencies. @@ -402,7 +391,7 @@ def test_workspace_org() -> None: @pytest.mark.e2e_cpu @pytest.mark.parametrize("file_type", ["json", "yaml"]) def test_workspace_checkpoint_storage_file(file_type: str) -> None: - sess = api_utils.determined_test_session(admin=True) + sess = api_utils.admin_session() w_name = uuid.uuid4().hex[:8] with tempfile.TemporaryDirectory() as tmpdir: path = os.path.join(tmpdir, "config") @@ -416,23 +405,23 @@ def test_workspace_checkpoint_storage_file(file_type: str) -> None: host_path: /tmp/yaml""" ) - det_cmd( - ["workspace", "create", w_name, "--checkpoint-storage-config-file", path], check=True + detproc.check_call( + sess, ["det", "workspace", "create", w_name, "--checkpoint-storage-config-file", path] ) try: - w_id = det_cmd_json(["workspace", "describe", w_name, "--json"])["id"] + w_id = detproc.check_json(sess, ["det", "workspace", "describe", w_name, "--json"])["id"] w = bindings.get_GetWorkspace(sess, id=w_id).workspace assert w.checkpointStorageConfig is not None assert w.checkpointStorageConfig["type"] == "shared_fs" assert w.checkpointStorageConfig["host_path"] == "/tmp/" + file_type finally: - _delete_workspace_and_check(sess, w) + test_agent_user_group._delete_workspace_and_check(sess, w) @pytest.mark.e2e_cpu def test_reset_workspace_checkpoint_storage_conf() -> None: - sess = api_utils.determined_test_session(admin=True) + sess = api_utils.admin_session() # Make project with checkpoint storage config. 
resp_w = bindings.post_PostWorkspace( @@ -458,14 +447,14 @@ def test_reset_workspace_checkpoint_storage_conf() -> None: ) assert resp_patch.workspace.checkpointStorageConfig is None finally: - _delete_workspace_and_check(sess, resp_w.workspace) + test_agent_user_group._delete_workspace_and_check(sess, resp_w.workspace) @contextlib.contextmanager def setup_workspaces( session: Optional[api.Session] = None, count: int = 1 ) -> Generator[List[bindings.v1Workspace], None, None]: - session = session or api_utils.determined_test_session(admin=True) + session = session or api_utils.admin_session() assert session workspaces: List[bindings.v1Workspace] = [] try: @@ -498,7 +487,7 @@ def setup_workspaces( # tag: no-cli @pytest.mark.e2e_cpu def test_workspace_delete_notebook() -> None: - admin_session = api_utils.determined_test_session(admin=True) + admin_session = api_utils.admin_session() # create a workspace using bindings @@ -543,9 +532,9 @@ def test_workspace_delete_notebook() -> None: assert outside_notebook.state not in TERMINATING_STATES # check that notebook is terminated or terminating. 
- wait_for_ntsc_state( + api.wait_for_ntsc_state( admin_session, - NTSC_Kind.notebook, + api.NTSC_Kind.notebook, ntsc_id=created_resp.notebook.id, predicate=lambda state: state in TERMINATING_STATES, ) @@ -567,7 +556,7 @@ def test_workspace_delete_notebook() -> None: # tag: no_cli @pytest.mark.e2e_cpu def test_launch_in_archived() -> None: - admin_session = api_utils.determined_test_session(admin=True) + admin_session = api_utils.admin_session() with setup_workspaces(admin_session) as [workspace]: # archive the workspace @@ -588,7 +577,7 @@ def test_launch_in_archived() -> None: # tag: no_cli @pytest.mark.e2e_cpu def test_workspaceid_set() -> None: - admin_session = api_utils.determined_test_session(admin=True) + admin_session = api_utils.admin_session() with setup_workspaces(admin_session) as [workspace]: # create a command inside the workspace diff --git a/e2e_tests/tests/cluster/utils.py b/e2e_tests/tests/cluster/utils.py index a3809c1cae8..b94b769cd00 100644 --- a/e2e_tests/tests/cluster/utils.py +++ b/e2e_tests/tests/cluster/utils.py @@ -1,27 +1,26 @@ import copy -import json +import datetime +import http.server import subprocess import threading import time -from datetime import datetime, timezone -from http.server import HTTPServer, SimpleHTTPRequestHandler from typing import Any, Dict, Tuple, Type import pytest import requests -from typing_extensions import Literal +from typing_extensions import Literal # noqa:I2041 from determined.common import api -from determined.common.api import authentication, certs +from tests import command from tests import config as conf -from tests.command import print_command_logs +from tests import detproc -class _HTTPServerWithRequest(HTTPServer): +class _HTTPServerWithRequest(http.server.HTTPServer): def __init__( self, server_address: Tuple[str, int], - RequestHandlerClass: Type[SimpleHTTPRequestHandler], + RequestHandlerClass: Type[http.server.SimpleHTTPRequestHandler], allow_dupes: bool, ): super().__init__(server_address, 
RequestHandlerClass) @@ -30,7 +29,7 @@ def __init__( self.allow_dupes = allow_dupes -class _WebhookRequestHandler(SimpleHTTPRequestHandler): +class _WebhookRequestHandler(http.server.SimpleHTTPRequestHandler): def do_POST(self) -> None: assert isinstance(self.server, _HTTPServerWithRequest) with self.server.url_to_request_body_lock: @@ -66,15 +65,12 @@ def close_and_return_responses(self) -> Dict[str, str]: return self.server.url_to_request_body -def cluster_slots() -> Dict[str, Any]: +def cluster_slots(sess: api.Session) -> Dict[str, Any]: """ cluster_slots returns a dict of slots that each agent has. :return: Dict[AgentID, List[Slot]] """ - # TODO: refactor tests to not use cli singleton auth. - certs.cli_cert = certs.default_load(conf.make_master_url()) - authentication.cli_auth = authentication.Authentication(conf.make_master_url()) - r = api.get(conf.make_master_url(), "api/v1/agents") + r = sess.get("api/v1/agents") assert r.status_code == requests.codes.ok, r.text jvals = r.json() # type: Dict[str, Any] return {agent["id"]: agent["slots"].values() for agent in jvals["agents"]} @@ -90,23 +86,23 @@ def get_master_port(loaded_config: dict) -> str: return "8080" # default value if not explicit in config file -def num_slots() -> int: - return sum(len(agent_slots) for agent_slots in cluster_slots().values()) +def num_slots(sess: api.Session) -> int: + return sum(len(agent_slots) for agent_slots in cluster_slots(sess).values()) -def num_free_slots() -> int: +def num_free_slots(sess: api.Session) -> int: return sum( 0 if slot["container"] else 1 - for agent_slots in cluster_slots().values() + for agent_slots in cluster_slots(sess).values() for slot in agent_slots ) -def run_command_set_priority(sleep: int = 30, slots: int = 1, priority: int = 0) -> str: - command = [ +def run_command_set_priority( + sess: api.Session, sleep: int = 30, slots: int = 1, priority: int = 0 +) -> str: + cmd = [ "det", - "-m", - conf.make_master_url(), "command", "run", "-d", @@ -117,14 
+113,12 @@ def run_command_set_priority(sleep: int = 30, slots: int = 1, priority: int = 0) "sleep", str(sleep), ] - return subprocess.check_output(command).decode().strip() + return detproc.check_output(sess, cmd).strip() -def run_command(sleep: int = 30, slots: int = 1) -> str: - command = [ +def run_command(sess: api.Session, sleep: int = 30, slots: int = 1) -> str: + cmd = [ "det", - "-m", - conf.make_master_url(), "command", "run", "-d", @@ -133,37 +127,39 @@ def run_command(sleep: int = 30, slots: int = 1) -> str: "sleep", str(sleep), ] - return subprocess.check_output(command).decode().strip() + return detproc.check_output(sess, cmd).strip() -def run_zero_slot_command(sleep: int = 30) -> str: - return run_command(sleep=sleep, slots=0) +def run_zero_slot_command(sess: api.Session, sleep: int = 30) -> str: + return run_command(sess, sleep=sleep, slots=0) TaskType = Literal["command", "notebook", "tensorboard", "shell"] -def get_task_info(task_type: TaskType, task_id: str) -> Dict[str, Any]: - task = ["det", "-m", conf.make_master_url(), task_type, "list", "--json"] - task_data = json.loads(subprocess.check_output(task).decode()) +def get_task_info(sess: api.Session, task_type: TaskType, task_id: str) -> Dict[str, Any]: + cmd = ["det", task_type, "list", "--json"] + task_data = detproc.check_json(sess, cmd) return next((d for d in task_data if d["id"] == task_id), {}) -def get_command_info(command_id: str) -> Dict[str, Any]: - return get_task_info("command", command_id) +def get_command_info(sess: api.Session, command_id: str) -> Dict[str, Any]: + return get_task_info(sess, "command", command_id) # assert_command_succeded checks if a command succeeded or not. It prints the command logs if the # command failed. 
-def assert_command_succeeded(command_id: str) -> None: - command_info = get_command_info(command_id) +def assert_command_succeeded(sess: api.Session, command_id: str) -> None: + command_info = get_command_info(sess, command_id) succeeded = "success" in command_info["exitStatus"] - assert succeeded, print_command_logs(command_id) + assert succeeded, command.print_command_logs(sess, command_id) -def wait_for_task_state(task_type: TaskType, task_id: str, state: str, ticks: int = 60) -> None: +def wait_for_task_state( + sess: api.Session, task_type: TaskType, task_id: str, state: str, ticks: int = 60 +) -> None: for _ in range(ticks): - info = get_task_info(task_type, task_id) + info = get_task_info(sess, task_type, task_id) gotten_state = info.get("state") if gotten_state == state: return @@ -173,12 +169,12 @@ def wait_for_task_state(task_type: TaskType, task_id: str, state: str, ticks: in pytest.fail(f"{task_type} expected {state} state got {gotten_state} instead after {ticks} secs") -def wait_for_command_state(command_id: str, state: str, ticks: int = 60) -> None: - return wait_for_task_state("command", command_id, state, ticks) +def wait_for_command_state(sess: api.Session, command_id: str, state: str, ticks: int = 60) -> None: + return wait_for_task_state(sess, "command", command_id, state, ticks) def now_ts() -> str: - return datetime.now(timezone.utc).astimezone().isoformat() + return datetime.datetime.now(datetime.timezone.utc).astimezone().isoformat() def set_master_port(config: str) -> None: diff --git a/e2e_tests/tests/cluster_log_manager.py b/e2e_tests/tests/cluster_log_manager.py index 477017895f9..ede408bde86 100644 --- a/e2e_tests/tests/cluster_log_manager.py +++ b/e2e_tests/tests/cluster_log_manager.py @@ -1,5 +1,5 @@ import multiprocessing -from types import TracebackType +from types import TracebackType # noqa:I2041 from typing import Any, Callable, Optional diff --git a/e2e_tests/tests/command/command.py b/e2e_tests/tests/command/command.py index 
780327b2a2f..23d3278abcc 100644 --- a/e2e_tests/tests/command/command.py +++ b/e2e_tests/tests/command/command.py @@ -1,15 +1,14 @@ +import contextlib import os import re import subprocess -from contextlib import contextmanager -from typing import IO, Any, Iterator, Optional +from typing import IO, Any, Iterator, List, Optional import requests from determined.common import api -from determined.common.api import authentication, certs, task_logs -from tests import api_utils from tests import config as conf +from tests import detproc class _InteractiveCommandProcess: @@ -51,47 +50,42 @@ def stdin(self) -> IO: return self.process.stdin -@contextmanager -def interactive_command(*args: str) -> Iterator[_InteractiveCommandProcess]: +@contextlib.contextmanager +def interactive_command(sess: api.Session, args: List[str]) -> Iterator[_InteractiveCommandProcess]: """ Runs a Determined CLI command in a subprocess. On exit, it kills the corresponding Determined task if possible before closing the subprocess. Example usage: - with util.interactive_command("notebook", "start") as notebook: + with util.interactive_command(sess, ["notebook", "start"]) as notebook: for line in notebook.stdout: if "Jupyter Notebook is running" in line: break """ - with subprocess.Popen( - ("det", "-m", conf.make_master_url()) + args, + with detproc.Popen( + sess, + ["det"] + args, stdin=subprocess.PIPE, stdout=subprocess.PIPE, - env={"PYTHONUNBUFFERED": "1", **os.environ}, ) as p: cmd = _InteractiveCommandProcess(p, detach="--detach" in args) if cmd.task_id is None: raise AssertionError( - "Task ID for '{}' could not be found. " + f"Task ID for '{args}' could not be found. " "If it is still active, this command may persist " - "in the Determined test deployment...".format(args) + "in the Determined test deployment..." 
) try: yield cmd finally: - subprocess.check_call( - ["det", "-m", conf.make_master_url(), str(args[0]), "kill", cmd.task_id] - ) + detproc.check_call(sess, ["det", str(args[0]), "kill", cmd.task_id]) p.kill() -def get_num_active_commands() -> int: - # TODO: refactor tests to not use cli singleton auth. - certs.cli_cert = certs.default_load(conf.make_master_url()) - authentication.cli_auth = authentication.Authentication(conf.make_master_url()) - r = api.get(conf.make_master_url(), "api/v1/commands") +def get_num_active_commands(sess: api.Session) -> int: + r = sess.get("api/v1/commands") assert r.status_code == requests.codes.ok, r.text return len( @@ -107,33 +101,21 @@ def get_num_active_commands() -> int: ) -def get_command(command_id: str) -> Any: - certs.cli_cert = certs.default_load(conf.make_master_url()) - authentication.cli_auth = authentication.Authentication(conf.make_master_url()) - r = api.get(conf.make_master_url(), "api/v1/commands/" + command_id) +def get_command(sess: api.Session, command_id: str) -> Any: + r = sess.get("api/v1/commands/" + command_id) assert r.status_code == requests.codes.ok, r.text return r.json()["command"] -def get_command_config(command_type: str, task_id: str) -> str: +def get_command_config(sess: api.Session, command_type: str, task_id: str) -> str: assert command_type in ["command", "notebook", "shell"] command = ["det", "-m", conf.make_master_url(), command_type, "config", task_id] env = os.environ.copy() env["DET_DEBUG"] = "true" - completed_process = subprocess.run( - command, - universal_newlines=True, - stdout=subprocess.PIPE, - stderr=subprocess.PIPE, - env=env, - ) - assert completed_process.returncode == 0, "\nstdout:\n{} \nstderr:\n{}".format( - completed_process.stdout, completed_process.stderr - ) - return str(completed_process.stdout) + return detproc.check_output(sess, command, env) -def print_command_logs(task_id: str) -> bool: - for tl in task_logs(api_utils.determined_test_session(), task_id): +def 
print_command_logs(sess: api.Session, task_id: str) -> bool: + for tl in api.task_logs(sess, task_id): print(tl.message) return True diff --git a/e2e_tests/tests/command/test_notebook.py b/e2e_tests/tests/command/test_notebook.py index a2fb8c51b2b..f5f533385ef 100644 --- a/e2e_tests/tests/command/test_notebook.py +++ b/e2e_tests/tests/command/test_notebook.py @@ -4,10 +4,9 @@ import pytest from determined.common import api -from determined.common.api import NTSC_Kind, bindings, get_ntsc_details, wait_for_ntsc_state +from determined.common.api import bindings +from tests import api_utils from tests import command as cmd -from tests import config as conf -from tests.api_utils import determined_test_session, kill_ntsc, launch_ntsc @pytest.mark.slow @@ -15,8 +14,9 @@ @pytest.mark.e2e_slurm @pytest.mark.e2e_pbs def test_basic_notebook_start_and_kill() -> None: + sess = api_utils.user_session() lines = [] # type: List[str] - with cmd.interactive_command("notebook", "start") as notebook: + with cmd.interactive_command(sess, ["notebook", "start"]) as notebook: for line in notebook.stdout: if re.search("Jupyter Notebook .*is running at", line) is not None: return @@ -27,27 +27,27 @@ def test_basic_notebook_start_and_kill() -> None: @pytest.mark.e2e_cpu def test_notebook_proxy() -> None: - session = determined_test_session(conf.ADMIN_CREDENTIALS) + session = api_utils.user_session() def get_proxy(session: api.Session, task_id: str) -> None: session.get(f"proxy/{task_id}/") - typ = NTSC_Kind.notebook - created_id = launch_ntsc(session, 1, typ).id + typ = api.NTSC_Kind.notebook + created_id = api_utils.launch_ntsc(session, 1, typ).id print(f"created {typ} {created_id}") - wait_for_ntsc_state( + api.wait_for_ntsc_state( session, - NTSC_Kind(typ), + api.NTSC_Kind(typ), created_id, lambda s: s == bindings.taskv1State.RUNNING, timeout=300, ) - deets = get_ntsc_details(session, typ, created_id) + deets = api.get_ntsc_details(session, typ, created_id) assert deets.state == 
bindings.taskv1State.RUNNING, f"{typ} should be running" - err = api.task_is_ready(determined_test_session(conf.ADMIN_CREDENTIALS), created_id) + err = api.wait_for_task_ready(session, created_id) assert err is None, f"{typ} should be ready {err}" print(deets) try: get_proxy(session, created_id) finally: - kill_ntsc(session, typ, created_id) + api_utils.kill_ntsc(session, typ, created_id) diff --git a/e2e_tests/tests/command/test_run.py b/e2e_tests/tests/command/test_run.py index 3cdaa81d76e..6acb8d16978 100644 --- a/e2e_tests/tests/command/test_run.py +++ b/e2e_tests/tests/command/test_run.py @@ -1,283 +1,113 @@ +import contextlib import copy -import re +import pathlib import subprocess import tempfile -from pathlib import Path +import textwrap from typing import Any, Dict, List, Optional import docker import docker.errors import pytest -from determined.common import util +from determined.common import api, util +from tests import api_utils from tests import command as cmd from tests import config as conf -from tests.filetree import FileTree +from tests import detproc, filetree -@pytest.mark.slow -@pytest.mark.e2e_cpu -def test_cold_and_warm_start(tmp_path: Path) -> None: - for _ in range(3): - subprocess.check_call( - ["det", "-m", conf.make_master_url(), "cmd", "run", "echo", "hello", "world"] - ) - - -def _run_and_return_real_exit_status(args: List[str], **kwargs: Any) -> None: - """ - Wraps subprocess.check_call and extracts exit status from output. 
- """ - # TODO(#2903): remove this once exit status are propagated through cli - output = subprocess.check_output(args, **kwargs) - if re.search(b"finished command \\S+ task failed with exit code", output): - raise subprocess.CalledProcessError(1, " ".join(args), output=output) - - -def _run_and_verify_exit_code_zero(args: List[str], **kwargs: Any) -> None: - """Wraps subprocess.check_output and verifies a successful exit code.""" - # TODO(#2903): remove this once exit status are propagated through cli - output = subprocess.check_output(args, **kwargs) - assert re.search(b"resources exited successfully", output) is not None, "Output is: {}".format( - output.decode("utf-8") - ) - - -def _run_and_verify_failure(args: List[str], message: str, **kwargs: Any) -> None: - output = subprocess.check_output(args, **kwargs) - if re.search(message.encode(), output): - raise subprocess.CalledProcessError(1, " ".join(args), output=output) - - -def _run_cmd_with_config_expecting_success( - cmd: str, config: Dict[str, Any], context_path: Optional[str] = None -) -> None: - with tempfile.NamedTemporaryFile() as tf: - with open(tf.name, "w") as f: - util.yaml_safe_dump(config, f) - - command = ["det", "-m", conf.make_master_url(), "cmd", "run", "--config-file", tf.name] - if context_path: - command += ["-c", context_path] - command.append(cmd) - - _run_and_verify_exit_code_zero(command) - - -def _run_cmd_with_config_expecting_failure( - cmd: str, expected_failure: str, config: Dict[str, Any] +def _run_cmd( + sess: api.Session, + cmd: List[str], + *, + expect_success: bool, + config: Optional[Dict[str, Any]] = None, + context: Optional[str] = None, ) -> None: - with tempfile.NamedTemporaryFile() as tf: - with open(tf.name, "w") as f: - util.yaml_safe_dump(config, f) - - with pytest.raises(subprocess.CalledProcessError): - _run_and_verify_failure( - [ - "det", - "-m", - conf.make_master_url(), - "cmd", - "run", - "--config-file", - tf.name, - cmd, - ], - expected_failure, - ) - - 
-@pytest.mark.e2e_cpu -@pytest.mark.e2e_slurm -@pytest.mark.e2e_pbs -def test_exit_code_reporting() -> None: - """ - Confirm that failed commands are not reported as successful, and confirm - that our test infrastructure is valid. - """ - with pytest.raises(AssertionError): - _run_and_verify_exit_code_zero(["det", "-m", conf.make_master_url(), "cmd", "run", "false"]) + """Always expect `det cmd run` to succeed, but the command itself might fail.""" + + with contextlib.ExitStack() as es: + det_cmd = ["det", "cmd", "run"] + if config is not None: + tf = es.enter_context(tempfile.NamedTemporaryFile()) + with open(tf.name, "w") as f: + util.yaml_safe_dump(config, f) + det_cmd += ["--config-file", tf.name] + if context: + det_cmd += ["-c", context] + + output = detproc.check_output(sess, det_cmd + cmd) + + if expect_success: + assert "resources exited successfully" in output.lower(), output + else: + assert "resources failed with non-zero exit code" in output.lower(), output @pytest.mark.slow @pytest.mark.e2e_cpu @pytest.mark.e2e_slurm @pytest.mark.e2e_pbs -def test_basic_workflows(tmp_path: Path) -> None: - with FileTree(tmp_path, {"hello.py": "print('hello world')"}) as tree: - _run_and_verify_exit_code_zero( - [ - "det", - "-m", - conf.make_master_url(), - "cmd", - "run", - "--context", - str(tree), - "python", - "hello.py", - ] - ) +def test_basic_workflows(tmp_path: pathlib.Path) -> None: + sess = api_utils.user_session() + with filetree.FileTree(tmp_path, {"hello.py": "print('hello world')"}) as tree: + _run_cmd(sess, ["python", "hello.py"], context=str(tree), expect_success=True) - with FileTree(tmp_path, {"hello.py": "print('hello world')"}) as tree: + with filetree.FileTree(tmp_path, {"hello.py": "print('hello world')"}) as tree: link = tree.joinpath("hello-link.py") link.symlink_to(tree.joinpath("hello.py")) - _run_and_verify_exit_code_zero( - [ - "det", - "-m", - conf.make_master_url(), - "cmd", - "run", - "--context", - str(tree), - "python", - 
"hello-link.py", - ] - ) + _run_cmd(sess, ["python", "hello-link.py"], context=str(tree), expect_success=True) - _run_and_verify_exit_code_zero( - ["det", "-m", conf.make_master_url(), "cmd", "run", "python", "-c", "print('hello world')"] + detproc.check_error( + sess, + ["det", "cmd", "run", "--context", "non-existent-path", "true"], + "non-existent-path' doesn't exist", ) - with pytest.raises(subprocess.CalledProcessError): - _run_and_return_real_exit_status( - [ - "det", - "-m", - conf.make_master_url(), - "cmd", - "run", - "--context", - "non-existent-path-here", - "python", - "hello.py", - ] - ) - @pytest.mark.slow @pytest.mark.e2e_cpu -def test_large_uploads(tmp_path: Path) -> None: - with pytest.raises(subprocess.CalledProcessError): - with FileTree(tmp_path, {"hello.py": "print('hello world')"}) as tree: - large = tree.joinpath("large-file.bin") - large.touch() - f = large.open(mode="w") +def test_large_uploads(tmp_path: pathlib.Path) -> None: + sess = api_utils.user_session() + + with filetree.FileTree(tmp_path, {"hello.py": "print('hello world')"}) as tree: + large = tree.joinpath("large-file.bin") + with large.open("w") as f: f.seek(1024 * 1024 * 120) f.write("\0") - f.close() - - _run_and_return_real_exit_status( - [ - "det", - "-m", - conf.make_master_url(), - "cmd", - "run", - "--context", - str(tree), - "python", - "hello.py", - ] - ) - - with FileTree(tmp_path, {"hello.py": "print('hello world')", ".detignore": "*.bin"}) as tree: - large = tree.joinpath("large-file.bin") - large.touch() - f = large.open(mode="w") - f.seek(1024 * 1024 * 120) - f.write("\0") - f.close() - - _run_and_verify_exit_code_zero( - [ - "det", - "-m", - conf.make_master_url(), - "cmd", - "run", - "--context", - str(tree), - "python", - "hello.py", - ] + + # 120MB is too big. + detproc.check_error( + sess, ["det", "cmd", "run", "-c", str(tree), "true"], "maximum allowed size" ) + # .detignore makes it ok though. 
+ with tree.joinpath(".detignore").open("w") as f: + f.write("*.bin\n") + _run_cmd(sess, ["python", "hello.py"], context=str(tree), expect_success=True) + # TODO(DET-9859) we could move this test to nightly or even per release to save CI cost. # It takes around 15 seconds. @pytest.mark.e2e_k8s -def test_context_directory_larger_than_config_map_k8s(tmp_path: Path) -> None: - with FileTree(tmp_path, {"hello.py": "print('hello world')"}) as tree: +def test_context_directory_larger_than_config_map_k8s(tmp_path: pathlib.Path) -> None: + sess = api_utils.user_session() + with filetree.FileTree(tmp_path, {"hello.py": "print('hello world')"}) as tree: large = tree.joinpath("large-file.bin") - large.touch() - f = large.open(mode="w") - f.seek(1024 * 1024 * 10) - f.write("\0") - f.close() - - _run_and_verify_exit_code_zero( - [ - "det", - "-m", - conf.make_master_url(), - "cmd", - "run", - "--context", - str(tree), - "python", - "hello.py", - ] - ) - + with large.open("w") as f: + f.seek(1024 * 1024 * 10) + f.write("\0") -@pytest.mark.slow -@pytest.mark.e2e_cpu -def test_configs(tmp_path: Path) -> None: - with FileTree( - tmp_path, - { - "config.yaml": """ -resources: - slots: 1 -environment: - environment_variables: - - TEST=TEST -""" - }, - ) as tree: - config_path = tree.joinpath("config.yaml") - _run_and_verify_exit_code_zero( - [ - "det", - "-m", - conf.make_master_url(), - "cmd", - "run", - "--config-file", - str(config_path), - "python", - "-c", - """ -import os -test = os.environ["TEST"] -if test != "TEST": - print("{} != {}".format(test, "TEST")) - sys.exit(1) -""", - ] - ) + _run_cmd(sess, ["python", "hello.py"], context=str(tree), expect_success=True) @pytest.mark.slow @pytest.mark.e2e_cpu -@pytest.mark.e2e_slurm -@pytest.mark.e2e_pbs -def test_singleton_command() -> None: - _run_and_verify_exit_code_zero( - ["det", "-m", conf.make_master_url(), "cmd", "run", "echo hello && echo world"] - ) +def test_configs(tmp_path: pathlib.Path) -> None: + sess = 
api_utils.user_session() + config = {"environment": {"environment_variables": ["TEST=TEST"]}} + _run_cmd(sess, ["env | grep -q TEST=TEST"], config=config, expect_success=True) @pytest.mark.slow @@ -285,18 +115,9 @@ def test_singleton_command() -> None: @pytest.mark.e2e_slurm @pytest.mark.e2e_pbs def test_environment_variables_command() -> None: - _run_and_verify_exit_code_zero( - [ - "det", - "-m", - conf.make_master_url(), - "cmd", - "run", - "--config", - "environment.environment_variables='THISISTRUE=true','WONTCAUSEPANIC'", - 'if [ "$THISISTRUE" != "true" ]; then exit 1; fi', - ] - ) + sess = api_utils.user_session() + config_str = "environment.environment_variables='THISISTRUE=true','WONTCAUSEPANIC'" + _run_cmd(sess, ["--config", config_str, "env | grep -q THISISTRUE=true"], expect_success=True) @pytest.mark.parametrize("actual,expected", [("24576", "24"), ("1.5g", "1572864")]) @@ -304,121 +125,48 @@ def test_environment_variables_command() -> None: @pytest.mark.slow @pytest.mark.e2e_cpu def test_shm_size_command( - tmp_path: Path, actual: str, expected: str, use_config_file: bool + tmp_path: pathlib.Path, actual: str, expected: str, use_config_file: bool ) -> None: - with FileTree( - tmp_path, - { - "config.yaml": f""" -resources: - shm_size: {actual} -""" - }, - ) as tree: - config_path = tree.joinpath("config.yaml") - cmd = ["det", "-m", conf.make_master_url(), "cmd", "run"] - if use_config_file: - cmd += ["--config-file", str(config_path)] - else: - cmd += ["--config", f"resources.shm_size={actual}"] - cmd += [ - f"""df /dev/shm && \ -df /dev/shm | \ -tail -1 | \ -[ "$(awk '{{print $2}}')" = '{expected}' ]""" - ] - _run_and_verify_exit_code_zero(cmd) + sess = api_utils.user_session() + script = textwrap.dedent( + rf""" + set -e + set -o pipefail + df /dev/shm | tail -1 | test "$(awk '{{print $2}}')" = '{expected}' + """ + ) + if use_config_file: + config = {"resources": {"shm_size": actual}} + _run_cmd(sess, ["bash", "-c", script], config=config, 
expect_success=True) + else: + config_str = f"resources.shm_size={actual}" + _run_cmd(sess, ["--config", config_str, "bash", "-c", script], expect_success=True) @pytest.mark.slow @pytest.mark.e2e_cpu -def test_absolute_bind_mount(tmp_path: Path) -> None: - _run_and_verify_exit_code_zero( - [ - "det", - "-m", - conf.make_master_url(), - "cmd", - "run", - "--volume", - "/bin:/foo-bar", - "ls", - "/foo-bar", - ] +def test_absolute_bind_mount(tmp_path: pathlib.Path) -> None: + sess = api_utils.user_session() + config = {"bind_mounts": [{"host_path": "/bin", "container_path": "/foo-bar1"}]} + _run_cmd( + sess, + ["--volume", "/bin:/foo-bar2", "ls", "/foo-bar1", "/foo-bar2"], + config=config, + expect_success=True, ) - with FileTree( - tmp_path, - { - "config.yaml": """ -bind_mounts: -- host_path: /bin - container_path: /foo-bar -""" - }, - ) as tree: - config_path = tree.joinpath("config.yaml") - _run_and_verify_exit_code_zero( - [ - "det", - "-m", - conf.make_master_url(), - "cmd", - "run", - "--volume", - "/bin:/foo-bar2", - "--config-file", - str(config_path), - "ls", - "/foo-bar", - "/foo-bar2", - ] - ) - @pytest.mark.slow @pytest.mark.e2e_cpu -def test_relative_bind_mount(tmp_path: Path) -> None: - _run_and_verify_exit_code_zero( - [ - "det", - "-m", - conf.make_master_url(), - "cmd", - "run", - "--volume", - "/bin:foo-bar", - "ls", - "foo-bar", - ] +def test_relative_bind_mount(tmp_path: pathlib.Path) -> None: + sess = api_utils.user_session() + config = {"bind_mounts": [{"host_path": "/bin", "container_path": "foo-bar1"}]} + _run_cmd( + sess, + ["--volume", "/bin:foo-bar2", "ls", "foo-bar1", "foo-bar2"], + config=config, + expect_success=True, ) - with FileTree( - tmp_path, - { - "config.yaml": """ -bind_mounts: -- host_path: /bin - container_path: foo-bar -""" - }, - ) as tree: - config_path = tree.joinpath("config.yaml") - _run_and_verify_exit_code_zero( - [ - "det", - "-m", - conf.make_master_url(), - "cmd", - "run", - "--volume", - "/bin:foo-bar2", - 
"--config-file", - str(config_path), - "ls", - "foo-bar", - "foo-bar2", - ] - ) @pytest.mark.slow @@ -427,9 +175,10 @@ def test_relative_bind_mount(tmp_path: Path) -> None: @pytest.mark.e2e_pbs def test_cmd_kill() -> None: """Start a command, extract its task ID, and then kill it.""" + sess = api_utils.user_session() with cmd.interactive_command( - "command", "run", "echo hello world; echo hello world; sleep infinity" + sess, ["command", "run", "echo hello world; sleep infinity"] ) as command: assert command.task_id is not None for line in command.stdout: @@ -438,7 +187,7 @@ def test_cmd_kill() -> None: # every 10 seconds. For example, it is very likely the current job state is # STATE_PULLING when job is actually running on HPC. So instead of checking # for STATE_RUNNING, we check for other active states as well. - assert cmd.get_num_active_commands() == 1 + assert cmd.get_num_active_commands(sess) == 1 break @@ -448,13 +197,22 @@ def test_image_pull_after_remove() -> None: """ Remove pulled image and verify that it will be pulled again with auth. """ + sess = api_utils.user_session() client = docker.from_env() try: client.images.remove("python:3.8.16") except docker.errors.ImageNotFound: pass - _run_and_verify_exit_code_zero( + _run_cmd(sess, ["--config", "environment.image=python:3.8.16", "true"], expect_success=True) + + +@pytest.mark.e2e_cpu +def test_outrageous_command_rejected() -> None: + sess = api_utils.user_session() + # Specify an outrageous number of slots to be sure that it can't be scheduled. + detproc.check_error( + sess, [ "det", "-m", @@ -462,44 +220,26 @@ def test_image_pull_after_remove() -> None: "cmd", "run", "--config", - "environment.image=python:3.8.16", - "sleep 3; echo hello world", - ] + "resources.slots=10485", + "sleep infinity", + ], + "request unfulfillable", ) -@pytest.mark.e2e_cpu -def test_outrageous_command_rejected() -> None: - # Specify an outrageous number of slots to be sure that it can't be scheduled. 
- with pytest.raises(subprocess.CalledProcessError): - _run_and_verify_failure( - [ - "det", - "-m", - conf.make_master_url(), - "cmd", - "run", - "--config", - "resources.slots=10485", - "sleep infinity", - ], - "request unfulfillable", - ) - - @pytest.mark.e2e_gpu @pytest.mark.parametrize("sidecar", [True, False]) -def test_k8_mount(using_k8s: bool, sidecar: bool) -> None: - if not using_k8s: - pytest.skip("only need to run test on kubernetes") +@api_utils.skipif_not_k8s() +def test_k8s_mount(sidecar: bool) -> None: + sess = api_utils.user_session() mount_path = "/ci/" - with pytest.raises(subprocess.CalledProcessError): - _run_and_verify_failure( - ["det", "-m", conf.make_master_url(), "cmd", "run", f"sleep 3; touch {mount_path}"], - "No such file or directory", - ) + output = detproc.check_output( + sess, + ["det", "cmd", "run", "touch", mount_path], + ) + assert "No such file or directory" in output, output config = { "environment": { @@ -532,13 +272,13 @@ def test_k8_mount(using_k8s: bool, sidecar: bool) -> None: config["environment"]["pod_spec"]["spec"]["containers"][0], # type: ignore ] - _run_cmd_with_config_expecting_success(cmd=f"sleep 3; touch {mount_path}", config=config) + _run_cmd(sess, ["touch", mount_path], config=config, expect_success=True) @pytest.mark.e2e_gpu -def test_k8_init_containers(using_k8s: bool) -> None: - if not using_k8s: - pytest.skip("only need to run test on kubernetes") +@api_utils.skipif_not_k8s() +def test_k8s_init_containers() -> None: + sess = api_utils.user_session() config = { "environment": { @@ -556,19 +296,16 @@ def test_k8_init_containers(using_k8s: bool) -> None: } } } - - _run_cmd_with_config_expecting_failure( - cmd="sleep 3", expected_failure="exit code 1", config=config - ) + _run_cmd(sess, ["echo", "hi"], config=config, expect_success=False) config["environment"]["pod_spec"]["spec"]["initContainers"][0]["args"] = ["-c", "exit 0"] - _run_cmd_with_config_expecting_success(cmd="sleep 3", config=config) + 
_run_cmd(sess, ["echo", "hi"], config=config, expect_success=True) @pytest.mark.e2e_gpu -def test_k8_sidecars(using_k8s: bool) -> None: - if not using_k8s: - pytest.skip("only need to run test on kubernetes") +@api_utils.skipif_not_k8s() +def test_k8s_sidecars() -> None: + sess = api_utils.user_session() base_config = { "environment": { @@ -591,20 +328,18 @@ def set_arg(arg: str) -> Dict[str, Any]: new_config["environment"]["pod_spec"]["spec"]["containers"][0]["args"] = ["-c", arg] return new_config + # Sidecar failure should not affect command failure. configs = [set_arg("sleep 1; exit 1"), set_arg("sleep 99999999")] for config in configs: - _run_cmd_with_config_expecting_failure( - cmd="sleep 3; exit 1", expected_failure="exit code 1", config=config - ) - - _run_cmd_with_config_expecting_success(cmd="sleep 3", config=config) + _run_cmd(sess, ["false"], config=config, expect_success=False) + _run_cmd(sess, ["sleep", "3"], config=config, expect_success=True) @pytest.mark.e2e_gpu @pytest.mark.parametrize("slots", [0, 1]) -def test_k8_resource_limits(using_k8s: bool, slots: int) -> None: - if not using_k8s: - pytest.skip("only need to run test on kubernetes") +@api_utils.skipif_not_k8s() +def test_k8s_resource_limits(slots: int) -> None: + sess = api_utils.user_session() config = { "environment": { @@ -633,13 +368,15 @@ def test_k8_resource_limits(using_k8s: bool, slots: int) -> None: }, } - _run_cmd_with_config_expecting_success(cmd="sleep 3; echo hello", config=config) + _run_cmd(sess, ["true"], config=config, expect_success=True) @pytest.mark.e2e_cpu @pytest.mark.e2e_slurm @pytest.mark.e2e_pbs -def test_log_wait_timeout(tmp_path: Path, secrets: Dict[str, str]) -> None: +def test_log_wait_timeout(tmp_path: pathlib.Path, secrets: Dict[str, str]) -> None: + sess = api_utils.user_session() + # Start a subshell that prints after 5 and 20 seconds, then exit. 
cmd = 'sh -c "sleep 5; echo after 5; sleep 15; echo after 20" & echo main shell exiting' @@ -648,13 +385,13 @@ def test_log_wait_timeout(tmp_path: Path, secrets: Dict[str, str]) -> None: with open(tf.name, "w") as f: util.yaml_safe_dump(config, f) - cli = ["det", "-m", conf.make_master_url(), "cmd", "run", "--config-file", tf.name, cmd] - p = subprocess.run(cli, stdout=subprocess.PIPE, check=True) + cli = ["det", "cmd", "run", "--config-file", tf.name, cmd] + p = detproc.run(sess, cli, stdout=subprocess.PIPE, check=True) assert p.stdout is not None stdout = p.stdout.decode("utf8") # Logs should wait for the main process to die, plus 10 seconds, then shut down. - # That should capture the "after 5" but not the "after 60". + # That should capture the "after 5" but not the "after 20". # By making the "after 20" occur before the default DET_LOG_WAIT_TIME of 30, we also are testing # that the escape hatch keeps working. assert "after 5" in stdout, stdout @@ -664,8 +401,7 @@ def test_log_wait_timeout(tmp_path: Path, secrets: Dict[str, str]) -> None: @pytest.mark.parametrize("task_type", ["notebook", "command", "shell", "tensorboard"]) @pytest.mark.e2e_cpu def test_log_argument(task_type: str) -> None: + sess = api_utils.user_session() taskid = "28ad1623-dcf0-47d2-9faa-265aaa05b078" - cmd: List[str] = ["det", "-m", conf.make_master_url(), task_type, "logs", taskid] - p = subprocess.run(cmd, stderr=subprocess.PIPE, check=False) - assert p.stderr is not None - assert "not found" in p.stderr.decode("utf8"), p.stderr.decode("utf8") + cmd: List[str] = ["det", task_type, "logs", taskid] + detproc.check_error(sess, cmd, "not found") diff --git a/e2e_tests/tests/command/test_shell.py b/e2e_tests/tests/command/test_shell.py index b74e7a81c67..b90d264f6f8 100644 --- a/e2e_tests/tests/command/test_shell.py +++ b/e2e_tests/tests/command/test_shell.py @@ -1,18 +1,20 @@ -from pathlib import Path +import pathlib import pytest import determined as det +from tests import api_utils from 
tests import command as cmd -from tests.cluster import test_users +from tests import detproc @pytest.mark.slow @pytest.mark.e2e_gpu @pytest.mark.e2e_slurm @pytest.mark.e2e_pbs -def test_start_and_write_to_shell(tmp_path: Path) -> None: - with cmd.interactive_command("shell", "start") as shell: +def test_start_and_write_to_shell(tmp_path: pathlib.Path) -> None: + sess = api_utils.user_session() + with cmd.interactive_command(sess, ["shell", "start"]) as shell: # Call our cli to ensure that PATH and PYTHONUSERBASE are properly set. shell.stdin.write(b"COLUMNS=80 det --version\n") # Exit the shell, so we can read output below until EOF instead of timeout @@ -32,16 +34,9 @@ def test_start_and_write_to_shell(tmp_path: Path) -> None: @pytest.mark.e2e_cpu def test_open_shell() -> None: - with cmd.interactive_command("shell", "start", "--detach") as shell: - task_id = shell.task_id - assert task_id is not None - - child = test_users.det_spawn(["shell", "open", task_id]) - child.setecho(True) - child.expect(r".*Permanently added.+([0-9a-f-]{36}).+known hosts\.", timeout=180) - child.sendline("det user whoami") - child.expect("You are logged in as user \\'(.*)\\'", timeout=10) - child.sendline("exit") - child.read() - child.wait() - assert child.exitstatus == 0 + sess = api_utils.user_session() + with cmd.interactive_command(sess, ["shell", "start", "--detach"]) as shell: + assert shell.task_id + command = ["det", "shell", "open", shell.task_id, "det", "user", "whoami"] + output = detproc.check_output(sess, command) + assert "You are logged in as user" in output diff --git a/e2e_tests/tests/command/test_tensorboard.py b/e2e_tests/tests/command/test_tensorboard.py index a58649f082d..d793d37e4b4 100644 --- a/e2e_tests/tests/command/test_tensorboard.py +++ b/e2e_tests/tests/command/test_tensorboard.py @@ -1,15 +1,15 @@ import pathlib -import subprocess from typing import Dict, Optional import pytest from determined.common import api, util +from tests import api_utils from tests 
import command as cmd from tests import config as conf +from tests import detproc from tests import experiment as exp -from tests.api_utils import determined_test_session -from tests.filetree import FileTree +from tests import filetree num_trials = 1 @@ -66,16 +66,17 @@ def test_start_tensorboard_for_shared_fs_experiment(tmp_path: pathlib.Path) -> N TensorBoard instance pointed to the experiment, and kill the TensorBoard instance. """ - with FileTree(tmp_path, {"config.yaml": shared_fs_config(1)}) as tree: + sess = api_utils.user_session() + with filetree.FileTree(tmp_path, {"config.yaml": shared_fs_config(1)}) as tree: config_path = tree.joinpath("config.yaml") experiment_id = exp.run_basic_test( - str(config_path), conf.fixtures_path("no_op"), num_trials + sess, str(config_path), conf.fixtures_path("no_op"), num_trials ) command = ["tensorboard", "start", str(experiment_id), "--no-browser"] - with cmd.interactive_command(*command) as tensorboard: + with cmd.interactive_command(sess, command) as tensorboard: assert tensorboard.task_id is not None - err = api.task_is_ready(determined_test_session(), tensorboard.task_id) + err = api.wait_for_task_ready(sess, tensorboard.task_id) assert err is None, err @@ -91,16 +92,17 @@ def test_start_tensorboard_for_s3_experiment( TensorBoard instance pointed to the experiment, and kill the TensorBoard instance. 
""" - with FileTree(tmp_path, {"config.yaml": s3_config(1, secrets, prefix)}) as tree: + sess = api_utils.user_session() + with filetree.FileTree(tmp_path, {"config.yaml": s3_config(1, secrets, prefix)}) as tree: config_path = tree.joinpath("config.yaml") experiment_id = exp.run_basic_test( - str(config_path), conf.fixtures_path("no_op"), num_trials + sess, str(config_path), conf.fixtures_path("no_op"), num_trials ) command = ["tensorboard", "start", str(experiment_id), "--no-browser"] - with cmd.interactive_command(*command) as tensorboard: + with cmd.interactive_command(sess, command) as tensorboard: assert tensorboard.task_id is not None - err = api.task_is_ready(determined_test_session(), tensorboard.task_id) + err = api.wait_for_task_ready(sess, tensorboard.task_id) assert err is None, err @@ -117,7 +119,8 @@ def test_start_tensorboard_for_multi_experiment( start a TensorBoard instance pointed to the experiments and some select trials, and kill the TensorBoard instance. """ - with FileTree( + sess = api_utils.user_session() + with filetree.FileTree( tmp_path, { "shared_fs_config.yaml": shared_fs_config(1), @@ -127,18 +130,20 @@ def test_start_tensorboard_for_multi_experiment( ) as tree: shared_conf_path = tree.joinpath("shared_fs_config.yaml") shared_fs_exp_id = exp.run_basic_test( - str(shared_conf_path), conf.fixtures_path("no_op"), num_trials + sess, str(shared_conf_path), conf.fixtures_path("no_op"), num_trials ) s3_conf_path = tree.joinpath("s3_config.yaml") - s3_exp_id = exp.run_basic_test(str(s3_conf_path), conf.fixtures_path("no_op"), num_trials) + s3_exp_id = exp.run_basic_test( + sess, str(s3_conf_path), conf.fixtures_path("no_op"), num_trials + ) multi_trial_config = tree.joinpath("multi_trial_config.yaml") multi_trial_exp_id = exp.run_basic_test( - str(multi_trial_config), conf.fixtures_path("no_op"), 3 + sess, str(multi_trial_config), conf.fixtures_path("no_op"), 3 ) - trial_ids = [str(t.trial.id) for t in 
exp.experiment_trials(multi_trial_exp_id)] + trial_ids = [str(t.trial.id) for t in exp.experiment_trials(sess, multi_trial_exp_id)] command = [ "tensorboard", @@ -151,9 +156,9 @@ def test_start_tensorboard_for_multi_experiment( "--no-browser", ] - with cmd.interactive_command(*command) as tensorboard: + with cmd.interactive_command(sess, command) as tensorboard: assert tensorboard.task_id is not None - err = api.task_is_ready(determined_test_session(), tensorboard.task_id) + err = api.wait_for_task_ready(sess, tensorboard.task_id) assert err is None, err @@ -163,15 +168,15 @@ def test_start_tensorboard_with_custom_image() -> None: Start a random experiment, start a TensorBoard instance pointed to the experiment with custom image, verify the image has been set. """ + sess = api_utils.user_session() experiment_id = exp.run_basic_test( + sess, conf.fixtures_path("no_op/single-one-short-step.yaml"), conf.fixtures_path("no_op"), 1, ) command = [ "det", - "-m", - conf.make_master_url(), "tensorboard", "start", str(experiment_id), @@ -180,11 +185,10 @@ def test_start_tensorboard_with_custom_image() -> None: "--config", "environment.image=python:3.8.16", ] - res = subprocess.run(command, universal_newlines=True, stdout=subprocess.PIPE, check=True) - t_id = res.stdout.strip("\n") - command = ["det", "-m", conf.make_master_url(), "tensorboard", "config", t_id] - res = subprocess.run(command, universal_newlines=True, stdout=subprocess.PIPE, check=True) - config = util.yaml_safe_load(res.stdout) + t_id = detproc.check_output(sess, command).strip() + command = ["det", "tensorboard", "config", t_id] + res = detproc.check_output(sess, command) + config = util.yaml_safe_load(res) assert ( config["environment"]["image"]["cpu"] == "python:3.8.16" and config["environment"]["image"]["cuda"] == "python:3.8.16" @@ -198,27 +202,27 @@ def test_tensorboard_inherit_image_pull_secrets() -> None: Start a random experiment with image_pull_secrets, start a TensorBoard instance pointed to the 
experiment, verify the secrets are inherited. """ + sess = api_utils.user_session() exp_secrets = [{"name": "ips"}] config_obj = conf.load_config(conf.fixtures_path("no_op/single-one-short-step.yaml")) pod = config_obj.setdefault("environment", {}).setdefault("pod_spec", {}) pod.setdefault("spec", {})["imagePullSecrets"] = [{"name": "ips"}] - experiment_id = exp.run_basic_test_with_temp_config(config_obj, conf.fixtures_path("no_op"), 1) + experiment_id = exp.run_basic_test_with_temp_config( + sess, config_obj, conf.fixtures_path("no_op"), 1 + ) command = [ "det", - "-m", - conf.make_master_url(), "tensorboard", "start", str(experiment_id), "--no-browser", "--detach", ] - res = subprocess.run(command, universal_newlines=True, stdout=subprocess.PIPE, check=True) - t_id = res.stdout.strip("\n") - command = ["det", "-m", conf.make_master_url(), "tensorboard", "config", t_id] - res = subprocess.run(command, universal_newlines=True, stdout=subprocess.PIPE, check=True) - config = util.yaml_safe_load(res.stdout) + t_id = detproc.check_output(sess, command).strip() + command = ["det", "tensorboard", "config", t_id] + res = detproc.check_output(sess, command) + config = util.yaml_safe_load(res) ips = config["environment"]["pod_spec"]["spec"]["imagePullSecrets"] @@ -231,13 +235,14 @@ def test_delete_tensorboard_for_experiment() -> None: Start a random experiment, start a TensorBoard instance pointed to the experiment, delete tensorboard and verify deletion. 
""" + sess = api_utils.user_session() config_obj = conf.load_config(conf.tutorials_path("mnist_pytorch/const.yaml")) experiment_id = exp.run_basic_test_with_temp_config( - config_obj, conf.tutorials_path("mnist_pytorch"), 1 + sess, config_obj, conf.tutorials_path("mnist_pytorch"), 1 ) command = ["det", "e", "delete-tb-files", str(experiment_id)] - subprocess.run(command, universal_newlines=True, stdout=subprocess.PIPE, check=True) + detproc.check_output(sess, command) # Check if Tensorboard files are deleted tb_path = sorted(pathlib.Path("/tmp/determined-cp/").glob("*/tensorboard"))[0] @@ -247,6 +252,7 @@ def test_delete_tensorboard_for_experiment() -> None: @pytest.mark.e2e_cpu def test_tensorboard_directory_storage(tmp_path: pathlib.Path) -> None: + sess = api_utils.user_session() config_obj = conf.load_config(conf.fixtures_path("no_op/single-one-short-step.yaml")) config_obj["checkpoint_storage"] = { "type": "directory", @@ -264,7 +270,9 @@ def test_tensorboard_directory_storage(tmp_path: pathlib.Path) -> None: with tb_config_path.open("w") as fout: util.yaml_safe_dump(tb_config, fout) - experiment_id = exp.run_basic_test_with_temp_config(config_obj, conf.fixtures_path("no_op"), 1) + experiment_id = exp.run_basic_test_with_temp_config( + sess, config_obj, conf.fixtures_path("no_op"), 1 + ) command = [ "tensorboard", @@ -275,7 +283,7 @@ def test_tensorboard_directory_storage(tmp_path: pathlib.Path) -> None: str(tb_config_path), ] - with cmd.interactive_command(*command) as tensorboard: + with cmd.interactive_command(sess, command) as tensorboard: assert tensorboard.task_id is not None - err = api.task_is_ready(determined_test_session(), tensorboard.task_id) + err = api.wait_for_task_ready(sess, tensorboard.task_id) assert err is None, err diff --git a/e2e_tests/tests/config.py b/e2e_tests/tests/config.py index 8ca49c02694..0b5a13a8e3d 100644 --- a/e2e_tests/tests/config.py +++ b/e2e_tests/tests/config.py @@ -1,9 +1,8 @@ import os -from pathlib import Path +import 
pathlib from typing import Any, Dict, List, Union from determined.common import api, util -from determined.common.api import authentication MASTER_SCHEME = "http" MASTER_IP = "localhost" @@ -33,11 +32,9 @@ PT2_GPU_IMAGE = os.environ.get("PT2_GPU_IMAGE") or DEFAULT_PT2_GPU_IMAGE GPU_ENABLED = os.environ.get("DET_TEST_GPU_ENABLED", "1") not in ("0", "false") -PROJECT_ROOT_PATH = Path(__file__).resolve().parents[2] +PROJECT_ROOT_PATH = pathlib.Path(__file__).resolve().parents[2] EXAMPLES_PATH = PROJECT_ROOT_PATH / "examples" -ADMIN_CREDENTIALS = authentication.Credentials("admin", "") - SCIM_USERNAME = "determined" SCIM_PASSWORD = "password" @@ -84,7 +81,7 @@ def load_config(config_path: str) -> Any: def make_master_url(suffix: str = "") -> str: - return "{}://{}:{}/{}".format(MASTER_SCHEME, MASTER_IP, MASTER_PORT, suffix) + return f"{MASTER_SCHEME}://{MASTER_IP}:{MASTER_PORT}/{suffix}" def set_global_batch_size(config: Dict[Any, Any], batch_size: int) -> Dict[Any, Any]: diff --git a/e2e_tests/tests/conftest.py b/e2e_tests/tests/conftest.py index 85a12e7f698..4bcb236c328 100644 --- a/e2e_tests/tests/conftest.py +++ b/e2e_tests/tests/conftest.py @@ -1,22 +1,20 @@ import json +import pathlib import subprocess import time -from pathlib import Path from typing import Any, Callable, Dict, Iterator, Optional, cast +import _pytest.config.argparsing +import _pytest.fixtures import boto3 import pytest -from _pytest.config.argparsing import Parser -from _pytest.fixtures import SubRequest from botocore import exceptions as boto_exc -from determined.experimental import client as _client -from tests import config -from tests.experiment import profile_test -from tests.nightly.compute_stats import compare_stats - -from .cluster.test_users import logged_in_user -from .cluster_log_manager import ClusterLogManager +from tests import api_utils, cluster_log_manager +from tests import config as conf +from tests import detproc +from tests.experiment import record_profiling +from 
tests.nightly import compute_stats _INTEG_MARKERS = { "tensorflow1_cpu", @@ -54,7 +52,7 @@ } -def pytest_addoption(parser: Parser) -> None: +def pytest_addoption(parser: _pytest.config.argparsing.Parser) -> None: parser.addoption( "--master-config-path", action="store", default=None, help="Path to master config path" ) @@ -83,7 +81,7 @@ def pytest_addoption(parser: Parser) -> None: "--require-secrets", action="store_true", help="fail tests when s3 access fails" ) path = ( - Path(__file__) + pathlib.Path(__file__) .parents[2] .joinpath("deploy", "determined.deploy", "local", "docker-compose.yaml") ) @@ -100,23 +98,30 @@ def pytest_addoption(parser: Parser) -> None: parser.addoption("--no-compare-stats", action="store_true", help="Disable usage stats check") +def pytest_configure(config: _pytest.config.Config) -> None: + """ + pytest_configure is a pytest hook which runs before all fixtures and test decorators. + + It is important we use this hook to capture information related to accessing the master, so that + our various skipif decorators can access the master. 
+ """ + + conf.MASTER_SCHEME = config.getoption("--master-scheme") + conf.MASTER_IP = config.getoption("--master-host") + conf.MASTER_PORT = config.getoption("--master-port") + conf.DET_VERSION = config.getoption("--det-version") + + @pytest.fixture(scope="session", autouse=True) -def cluster_log_manager(request: SubRequest) -> Iterator[Optional[ClusterLogManager]]: - master_scheme = request.config.getoption("--master-scheme") - master_host = request.config.getoption("--master-host") - master_port = request.config.getoption("--master-port") - det_version = request.config.getoption("--det-version") +def cluster_log_manager_fixture( + request: _pytest.fixtures.SubRequest, +) -> Iterator[Optional[cluster_log_manager.ClusterLogManager]]: follow_local_logs = request.config.getoption("--follow-local-logs") compare_stats_enabled = not request.config.getoption("--no-compare-stats") - config.MASTER_SCHEME = master_scheme - config.MASTER_IP = master_host - config.MASTER_PORT = master_port - config.DET_VERSION = det_version - - if master_host == "localhost" and follow_local_logs: + if conf.MASTER_IP == "localhost" and follow_local_logs: project_name = request.config.getoption("--compose-project-name") - with ClusterLogManager( + with cluster_log_manager.ClusterLogManager( lambda: subprocess.run( ["det", "deploy", "local", "logs", "--cluster-name", project_name] ) @@ -129,7 +134,8 @@ def cluster_log_manager(request: SubRequest) -> Iterator[Optional[ClusterLogMana yield None if compare_stats_enabled: - compare_stats() + sess = api_utils.admin_session() + compute_stats.compare_stats(sess) def pytest_itemcollected(item: Any) -> None: @@ -137,7 +143,7 @@ def pytest_itemcollected(item: Any) -> None: pytest.exit(f"{item.nodeid} is missing an integration test mark (any of {_INTEG_MARKERS})") -def s3_secrets(request: SubRequest) -> Dict[str, str]: +def s3_secrets(request: _pytest.fixtures.SubRequest) -> Dict[str, str]: """ Connect to S3 secretsmanager to get the secret values used in 
integrations tests. """ @@ -153,7 +159,7 @@ def s3_secrets(request: SubRequest) -> Dict[str, str]: @pytest.fixture(scope="session") -def secrets(request: SubRequest) -> Dict[str, str]: +def secrets(request: _pytest.fixtures.SubRequest) -> Dict[str, str]: response = {} try: @@ -167,20 +173,13 @@ def secrets(request: SubRequest) -> Dict[str, str]: @pytest.fixture(scope="session") -def checkpoint_storage_config(request: SubRequest) -> Dict[str, Any]: - command = [ - "det", - "-m", - config.make_master_url(), - "master", - "config", - "--json", - ] +def checkpoint_storage_config(request: _pytest.fixtures.SubRequest) -> Dict[str, Any]: + command = ["det", "master", "config", "--json"] - with logged_in_user(config.ADMIN_CREDENTIALS): - output = subprocess.check_output(command, universal_newlines=True, stderr=subprocess.PIPE) + sess = api_utils.admin_session() + output = detproc.check_json(sess, command) - checkpoint_config = json.loads(output)["checkpoint_storage"] + checkpoint_config = output["checkpoint_storage"] if checkpoint_config["type"] == "s3": secret_conf = s3_secrets(request) @@ -191,26 +190,8 @@ def checkpoint_storage_config(request: SubRequest) -> Dict[str, Any]: return cast(Dict[str, Any], checkpoint_config) -@pytest.fixture(scope="session") -def using_k8s(request: SubRequest) -> bool: - command = [ - "det", - "-m", - config.make_master_url(), - "master", - "config", - "--json", - ] - - with logged_in_user(config.ADMIN_CREDENTIALS): - output = subprocess.check_output(command, universal_newlines=True, stderr=subprocess.PIPE) - - rp = json.loads(output)["resource_manager"]["type"] - return bool(rp == "kubernetes") - - @pytest.fixture(autouse=True) -def test_start_timer(request: SubRequest) -> Iterator[None]: +def test_start_timer(request: _pytest.fixtures.SubRequest) -> Iterator[None]: # If pytest is run with minimal verbosity, individual test names are not printed and the output # of this would look funny. 
if request.config.option.verbose >= 1: @@ -227,14 +208,8 @@ def collect_trial_profiles(record_property: Callable[[str, object], None]) -> Ca Currently retrieves metrics by trial (assumes one trial per experiment) using profiler API. - """ - - return profile_test(record_property=record_property) - -@pytest.fixture(scope="session") -def client() -> _client.Determined: + Note: this must be a fixture in order to use the record_property fixture provided by pytest. """ - Reduce logins by having one session-level fixture do the login. - """ - return _client.Determined(config.make_master_url()) + + return record_profiling.profile_test(record_property=record_property) diff --git a/e2e_tests/tests/deploy/test_local.py b/e2e_tests/tests/deploy/test_local.py index 67c26b16a0f..5b0f48e762d 100644 --- a/e2e_tests/tests/deploy/test_local.py +++ b/e2e_tests/tests/deploy/test_local.py @@ -1,19 +1,28 @@ -import json import os +import pathlib import random import subprocess import time -from pathlib import Path from typing import List import docker import pytest -from determined.common.api import bindings +from determined.common import api +from determined.common.api import authentication, bindings from tests import config as conf +from tests import detproc from tests import experiment as exp -from ..cluster.test_users import logged_in_user + +def mksess(host: str, port: int, username: str = "determined", password: str = "") -> api.Session: + """ + Since this file frequently creates new masters, always create a fresh Session. 
+ """ + + master_url = f"http://{host}:{port}" + utp = authentication.login(master_url, username=username, password=password) + return api.Session(master_url, utp, cert=None) def det_deploy(subcommand: List) -> None: @@ -74,23 +83,17 @@ def agent_down(arguments: List) -> None: det_deploy(command) -def agent_enable(arguments: List) -> None: - with logged_in_user(conf.ADMIN_CREDENTIALS): - subprocess.run(["det", "-m", conf.make_master_url(), "agent", "enable"] + arguments) +def agent_enable(sess: api.Session, arguments: List) -> None: + detproc.check_output(sess, ["det", "agent", "enable"] + arguments) -def agent_disable(arguments: List) -> None: - with logged_in_user(conf.ADMIN_CREDENTIALS): - subprocess.run(["det", "-m", conf.make_master_url(), "agent", "disable"] + arguments) +def agent_disable(sess: api.Session, arguments: List) -> None: + detproc.check_output(sess, ["det", "agent", "disable"] + arguments) @pytest.mark.det_deploy_local def test_cluster_down() -> None: - master_host = "localhost" - master_port = "8080" name = "test_cluster_down" - conf.MASTER_IP = master_host - conf.MASTER_PORT = master_port cluster_up(["--cluster-name", name]) @@ -108,14 +111,12 @@ def test_cluster_down() -> None: @pytest.mark.det_deploy_local def test_custom_etc() -> None: - master_host = "localhost" - master_port = "8080" name = "test_custom_etc" - conf.MASTER_IP = master_host - conf.MASTER_PORT = master_port - etc_path = str(Path(__file__).parent.joinpath("etc/master.yaml").resolve()) + etc_path = str(pathlib.Path(__file__).parent.joinpath("etc/master.yaml").resolve()) cluster_up(["--master-config-path", etc_path, "--cluster-name", name]) + sess = mksess("localhost", 8080) exp.run_basic_test( + sess, conf.fixtures_path("no_op/single-default-ckpt.yaml"), conf.fixtures_path("no_op"), 1, @@ -126,16 +127,12 @@ def test_custom_etc() -> None: @pytest.mark.det_deploy_local def test_agent_config_path() -> None: - master_host = "localhost" - master_port = "8080" cluster_name = 
"test_agent_config_path" master_name = f"{cluster_name}_determined-master_1" - conf.MASTER_IP = master_host - conf.MASTER_PORT = master_port master_up(["--master-name", master_name]) # Config makes it unmodified. - etc_path = str(Path(__file__).parent.joinpath("etc/agent.yaml").resolve()) + etc_path = str(pathlib.Path(__file__).parent.joinpath("etc/agent.yaml").resolve()) agent_name = "test-path-agent" agent_up(["--agent-config-path", etc_path]) @@ -150,7 +147,8 @@ def test_agent_config_path() -> None: # Validate CLI flags overwrite config file options. agent_name += "-2" agent_up(["--agent-name", agent_name, "--agent-config-path", etc_path]) - agent_list = json.loads(subprocess.check_output(["det", "a", "list", "--json"]).decode()) + sess = mksess("localhost", 8080) + agent_list = detproc.check_json(sess, ["det", "a", "list", "--json"]) agent_list = [el for el in agent_list if el["id"] == agent_name] assert len(agent_list) == 1 agent_down(["--agent-name", agent_name]) @@ -161,18 +159,19 @@ def test_agent_config_path() -> None: @pytest.mark.det_deploy_local def test_custom_port() -> None: name = "port_test" - master_host = "localhost" - master_port = "12321" - conf.MASTER_IP = master_host - conf.MASTER_PORT = master_port + custom_port = 12321 arguments = [ "--cluster-name", name, "--master-port", - f"{master_port}", + f"{custom_port}", ] cluster_up(arguments) + + sess = mksess("localhost", custom_port) + exp.run_basic_test( + sess, conf.fixtures_path("no_op/single-one-short-step.yaml"), conf.fixtures_path("no_op"), 1, @@ -182,12 +181,8 @@ def test_custom_port() -> None: @pytest.mark.det_deploy_local def test_agents_made() -> None: - master_host = "localhost" - master_port = "8080" name = "agents_test" num_agents = 2 - conf.MASTER_IP = master_host - conf.MASTER_PORT = master_port arguments = [ "--cluster-name", name, @@ -207,12 +202,8 @@ def test_agents_made() -> None: @pytest.mark.det_deploy_local def test_master_up_down() -> None: - master_host = "localhost" - 
master_port = "8080" cluster_name = "test_master_up_down" master_name = f"{cluster_name}_determined-master_1" - conf.MASTER_IP = master_host - conf.MASTER_PORT = master_port master_up(["--master-name", master_name]) @@ -229,11 +220,7 @@ def test_master_up_down() -> None: @pytest.mark.det_deploy_local def test_agent_up_down() -> None: - master_host = "localhost" - master_port = "8080" agent_name = "test_agent-determined-agent" - conf.MASTER_IP = master_host - conf.MASTER_PORT = master_port cluster_name = "test_agent_up_down" master_name = f"{cluster_name}_determined-master_1" @@ -257,14 +244,12 @@ def test_agent_up_down() -> None: @pytest.mark.stress_test def test_stress_agents_reconnect(steps: int, num_agents: int, should_disconnect: bool) -> None: random.seed(42) - master_host = "localhost" - master_port = "8080" cluster_name = "test_stress_agents_reconnect" master_name = f"{cluster_name}_determined-master_1" - conf.MASTER_IP = master_host - conf.MASTER_PORT = master_port master_up(["--master-name", master_name]) + sess = mksess("localhost", 8080, "admin") + # Start all agents. agents_are_up = [True] * num_agents for i in range(num_agents): @@ -286,53 +271,37 @@ def test_stress_agents_reconnect(steps: int, num_agents: int, should_disconnect: agents_are_up[agent_id] = not agents_are_up[agent_id] else: if random.choice([True, False]): - agent_disable([f"agent-{agent_id}"]) + agent_disable(sess, [f"agent-{agent_id}"]) agents_are_up[agent_id] = False else: - agent_enable([f"agent-{agent_id}"]) + agent_enable(sess, [f"agent-{agent_id}"]) agents_are_up[agent_id] = True print("agents_are_up:", agents_are_up) time.sleep(10) # Validate that our master kept track of the agent reconnect spam. 
- agent_list = json.loads( - subprocess.check_output( - [ - "det", - "agent", - "list", - "--json", - ] - ).decode() - ) + agent_list = detproc.check_json(sess, ["det", "agent", "list", "--json"]) print("agent_list:", agent_list) assert sum(agents_are_up) <= len(agent_list) for agent in agent_list: print("agent:", agent) agent_id = int(agent["id"].replace("agent-", "")) if agents_are_up[agent_id] != agent["enabled"]: - p = subprocess.run( - [ - "det", - "deploy", - "local", - "logs", - ] - ) - print(p.stdout) - print(p.stderr) + subprocess.check_call(["det", "deploy", "local", "logs"]) assert ( agents_are_up[agent_id] == agent["enabled"] ), f"agent is up: {agents_are_up[agent_id]}, agent status: {agent}" # Can we still schedule something? if any(agents_are_up): + mksess("localhost", 8080) experiment_id = exp.create_experiment( + sess, conf.fixtures_path("no_op/single-one-short-step.yaml"), conf.fixtures_path("no_op"), None, ) - exp.wait_for_experiment_state(experiment_id, bindings.experimentv1State.COMPLETED) + exp.wait_for_experiment_state(sess, experiment_id, bindings.experimentv1State.COMPLETED) for agent_id in range(num_agents): agent_down(["--agent-name", f"agent-{agent_id}"]) diff --git a/e2e_tests/tests/detproc.py b/e2e_tests/tests/detproc.py new file mode 100644 index 00000000000..b7ef620d766 --- /dev/null +++ b/e2e_tests/tests/detproc.py @@ -0,0 +1,142 @@ +""" +detproc is a subprocess-like tool for calling our CLI with explicit session management. + +e2e tests shouldn't really be relying on the persistence of cached api credentials in order to work; +they should be explicit about which login session should be used to make the test pass. + +However, lots of e2e functionality is exercised today through the CLI. Also, it's unfortunately +true that almost all of the CLI functionality is tested only with e2e tests. 
So migrating the whole +e2e test suite to api.bindings or the SDK might be nice for e2e_tests but it would probably result +in huge parts of the CLI having no test coverage at all. + +So the detproc module avoids the dilemma by continuing to use the CLI in e2e tests but offering +a mechanism for explicit session management through the CLI subprocess boundary. +""" + +import json +import os +import subprocess +from typing import Any, Dict, List, Optional + +from determined.common import api + + +class CalledProcessError(subprocess.CalledProcessError): + """ + Subclass subprocess.CalledProcessError in order to have a __str__ method that includes the + stderr of the det cli call that failed. + + That way, the actual failure surfaces in test logs in the pytest summary info section at the + bottom of the logs. + """ + + def __str__(self) -> str: + return ( + f"Command '{self.cmd}' returned non-zero exit status {self.returncode}, " + f"stderr={self.stderr}" + ) + + +def mkenv(sess: api.Session, env: Optional[Dict[str, str]]) -> Dict[str, str]: + env = env or {**os.environ} + assert "DET_USER" not in env, "if you set DET_USER you probably want to use normal subprocess" + assert ( + "DET_USER_TOKEN" not in env + ), "if you set DET_USER_TOKEN you probably want to use normal subprocess" + # Point at the same master as the session. + env["DET_MASTER"] = sess.master + # Configure the username and token directly through the environment, via the codepath normally + # designed for on-cluster auto-config situations. + env["DET_USER"] = sess.username + env["DET_USER_TOKEN"] = sess.token + # Disable the authentication cache, which, by design, is allowed to override that on-cluster + # auto-config situation. + env["DET_DEBUG_CONFIG_PATH"] = "/tmp/disable-e2e-auth-cache" + # Disable python's stdio buffering.
+ env["PYTHONUNBUFFERED"] = "1" + return env + + +def forbid_user_setting(cmd: List[str]) -> None: + if "-u" in cmd or "--user" in cmd: + raise ValueError( + "you should never be passing -u or --user to detproc; that is for setting the user " + "and that functionality belongs to the cli unit tests. If you want to run as a " + "different user, either use the sdk or pass in a different Session that is " + f"authenticated as that user. Command was: {cmd}" + ) + + +class Popen(subprocess.Popen): + def __init__( + self, + sess: api.Session, + cmd: List[str], + *args: Any, + env: Optional[Dict[str, str]] = None, + **kwargs: Any, + ) -> None: + forbid_user_setting(cmd) + super().__init__(cmd, *args, env=mkenv(sess, env), **kwargs) # type: ignore + + +def run( + sess: api.Session, + cmd: List[str], + *args: Any, + env: Optional[Dict[str, str]] = None, + **kwargs: Any, +) -> subprocess.CompletedProcess: + forbid_user_setting(cmd) + p = subprocess.run(cmd, *args, env=mkenv(sess, env), **kwargs) # type: ignore + assert isinstance(p, subprocess.CompletedProcess) + return p + + +def check_call( + sess: api.Session, + cmd: List[str], + env: Optional[Dict[str, str]] = None, +) -> subprocess.CompletedProcess: + p = run(sess, cmd, env=env, stdout=subprocess.PIPE, stderr=subprocess.PIPE) + if p.returncode != 0: + assert p.stdout is not None and p.stderr is not None + stdout = p.stdout.decode("utf8") + stderr = p.stderr.decode("utf8") + raise CalledProcessError(p.returncode, cmd, output=stdout, stderr=stderr) + return p + + +def check_output( + sess: api.Session, + cmd: List[str], + env: Optional[Dict[str, str]] = None, +) -> str: + p = run(sess, cmd, env=env, stdout=subprocess.PIPE, stderr=subprocess.PIPE) + if p.returncode != 0: + assert p.stderr is not None + stderr = p.stderr.decode("utf8") + raise CalledProcessError(p.returncode, cmd, stderr=stderr) + out = p.stdout.decode() + assert isinstance(out, str) + return out + + +def check_error( + sess: api.Session, + cmd: List[str], + 
errmsg: str, +) -> subprocess.CompletedProcess: + p = run(sess, cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE) + assert p.returncode != 0 + assert p.stderr is not None + stderr = p.stderr.decode("utf8") + assert errmsg.lower() in stderr.lower(), f"did not find '{errmsg}' in '{stderr}'" + return p + + +def check_json( + sess: api.Session, + cmd: List[str], +) -> Any: + return json.loads(check_output(sess, cmd)) diff --git a/e2e_tests/tests/experiment/__init__.py b/e2e_tests/tests/experiment/__init__.py index 751c945c4cb..40b1e10068c 100644 --- a/e2e_tests/tests/experiment/__init__.py +++ b/e2e_tests/tests/experiment/__init__.py @@ -45,7 +45,3 @@ has_at_least_one_checkpoint, wait_for_at_least_one_checkpoint, ) - -from .record_profiling import ( - profile_test, -) diff --git a/e2e_tests/tests/experiment/experiment.py b/e2e_tests/tests/experiment/experiment.py index fe17df261b8..e3cb8af665d 100644 --- a/e2e_tests/tests/experiment/experiment.py +++ b/e2e_tests/tests/experiment/experiment.py @@ -10,20 +10,21 @@ import pytest from determined.common import api, util -from determined.common.api import authentication, bindings, certs -from determined.common.api.bindings import experimentv1State, trialv1State +from determined.common.api import bindings from tests import api_utils from tests import config as conf +from tests import detproc from tests.cluster import utils as cluster_utils def maybe_create_experiment( - config_file: str, model_def_file: Optional[str] = None, create_args: Optional[List[str]] = None + sess: api.Session, + config_file: str, + model_def_file: Optional[str] = None, + create_args: Optional[List[str]] = None, ) -> subprocess.CompletedProcess: command = [ "det", - "-m", - conf.make_master_url(), "experiment", "create", config_file, @@ -38,24 +39,31 @@ def maybe_create_experiment( env = os.environ.copy() env["DET_DEBUG"] = "true" - return subprocess.run( - command, universal_newlines=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE, env=env + 
return detproc.run( + sess, + command, + universal_newlines=True, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + env=env, ) def create_experiment( - config_file: str, model_def_file: Optional[str] = None, create_args: Optional[List[str]] = None + sess: api.Session, + config_file: str, + model_def_file: Optional[str] = None, + create_args: Optional[List[str]] = None, ) -> int: - completed_process = maybe_create_experiment(config_file, model_def_file, create_args) - assert completed_process.returncode == 0, "\nstdout:\n{} \nstderr:\n{}".format( - completed_process.stdout, completed_process.stderr - ) - m = re.search(r"Created experiment (\d+)\n", str(completed_process.stdout)) + p = maybe_create_experiment(sess, config_file, model_def_file, create_args) + assert p.returncode == 0, f"\nstdout:\n{p.stdout} \nstderr:\n{p.stderr}" + m = re.search(r"Created experiment (\d+)\n", str(p.stdout)) assert m is not None return int(m.group(1)) def maybe_run_autotuning_experiment( + sess: api.Session, config_file: str, model_def_file: str, create_args: Optional[List[str]] = None, @@ -80,98 +88,115 @@ def maybe_run_autotuning_experiment( env["DET_DEBUG"] = "true" env["DET_MASTER"] = conf.make_master_url() - return subprocess.run( - command, universal_newlines=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE, env=env + return detproc.run( + sess, + command, + universal_newlines=True, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + env=env, ) def run_autotuning_experiment( + sess: api.Session, config_file: str, model_def_file: str, create_args: Optional[List[str]] = None, search_method_name: str = "_test", max_trials: int = 4, ) -> int: - completed_process = maybe_run_autotuning_experiment( - config_file, model_def_file, create_args, search_method_name, max_trials + p = maybe_run_autotuning_experiment( + sess, config_file, model_def_file, create_args, search_method_name, max_trials ) - assert completed_process.returncode == 0, "\nstdout:\n{} \nstderr:\n{}".format( - 
completed_process.stdout, completed_process.stderr - ) - m = re.search(r"Created experiment (\d+)\n", str(completed_process.stdout)) + assert p.returncode == 0, f"\nstdout:\n{p.stdout}\nstderr:\n{p.stderr}" + m = re.search(r"Created experiment (\d+)\n", str(p.stdout)) assert m is not None return int(m.group(1)) -def archive_experiments(experiment_ids: List[int], name: Optional[str] = None) -> None: +def archive_experiments( + sess: api.Session, experiment_ids: List[int], name: Optional[str] = None +) -> None: body = bindings.v1ArchiveExperimentsRequest(experimentIds=experiment_ids) if name is not None: filters = bindings.v1BulkExperimentFilters(name=name) body = bindings.v1ArchiveExperimentsRequest(experimentIds=[], filters=filters) - bindings.post_ArchiveExperiments(api_utils.determined_test_session(), body=body) + bindings.post_ArchiveExperiments(sess, body=body) -def pause_experiment(experiment_id: int) -> None: - command = ["det", "-m", conf.make_master_url(), "experiment", "pause", str(experiment_id)] - subprocess.check_call(command) +def pause_experiment(sess: api.Session, experiment_id: int) -> None: + command = ["det", "experiment", "pause", str(experiment_id)] + detproc.check_call(sess, command) -def pause_experiments(experiment_ids: List[int], name: Optional[str] = None) -> None: +def pause_experiments( + sess: api.Session, + experiment_ids: List[int], + name: Optional[str] = None, +) -> None: body = bindings.v1PauseExperimentsRequest(experimentIds=experiment_ids) if name is not None: filters = bindings.v1BulkExperimentFilters(name=name) body = bindings.v1PauseExperimentsRequest(experimentIds=[], filters=filters) - bindings.post_PauseExperiments(api_utils.determined_test_session(), body=body) + bindings.post_PauseExperiments(sess, body=body) -def activate_experiment(experiment_id: int) -> None: - command = ["det", "-m", conf.make_master_url(), "experiment", "activate", str(experiment_id)] - subprocess.check_call(command) +def activate_experiment(sess: 
api.Session, experiment_id: int) -> None: + command = ["det", "experiment", "activate", str(experiment_id)] + detproc.check_call(sess, command) -def activate_experiments(experiment_ids: List[int], name: Optional[str] = None) -> None: +def activate_experiments( + sess: api.Session, experiment_ids: List[int], name: Optional[str] = None +) -> None: if name is None: body = bindings.v1ActivateExperimentsRequest(experimentIds=experiment_ids) else: filters = bindings.v1BulkExperimentFilters(name=name) body = bindings.v1ActivateExperimentsRequest(experimentIds=[], filters=filters) - bindings.post_ActivateExperiments(api_utils.determined_test_session(), body=body) + bindings.post_ActivateExperiments(sess, body=body) -def cancel_experiment(experiment_id: int) -> None: - bindings.post_CancelExperiment(api_utils.determined_test_session(), id=experiment_id) - wait_for_experiment_state(experiment_id, experimentv1State.CANCELED) +def cancel_experiment(sess: api.Session, experiment_id: int) -> None: + bindings.post_CancelExperiment(sess, id=experiment_id) + wait_for_experiment_state(sess, experiment_id, bindings.experimentv1State.CANCELED) -def kill_experiment(experiment_id: int) -> None: - bindings.post_KillExperiment(api_utils.determined_test_session(), id=experiment_id) - wait_for_experiment_state(experiment_id, experimentv1State.CANCELED) +def kill_experiment(sess: api.Session, experiment_id: int) -> None: + bindings.post_KillExperiment(sess, id=experiment_id) + wait_for_experiment_state(sess, experiment_id, bindings.experimentv1State.CANCELED) -def cancel_experiments(experiment_ids: List[int], name: Optional[str] = None) -> None: +def cancel_experiments( + sess: api.Session, experiment_ids: List[int], name: Optional[str] = None +) -> None: if name is None: body = bindings.v1CancelExperimentsRequest(experimentIds=experiment_ids) else: filters = bindings.v1BulkExperimentFilters(name=name) body = bindings.v1CancelExperimentsRequest(experimentIds=[], filters=filters) - 
bindings.post_CancelExperiments(api_utils.determined_test_session(), body=body) + bindings.post_CancelExperiments(sess, body=body) -def kill_experiments(experiment_ids: List[int], name: Optional[str] = None) -> None: +def kill_experiments( + sess: api.Session, experiment_ids: List[int], name: Optional[str] = None +) -> None: if name is None: body = bindings.v1KillExperimentsRequest(experimentIds=experiment_ids) else: filters = bindings.v1BulkExperimentFilters(name=name) body = bindings.v1KillExperimentsRequest(experimentIds=[], filters=filters) - bindings.post_KillExperiments(api_utils.determined_test_session(), body=body) + bindings.post_KillExperiments(sess, body=body) -def kill_trial(trial_id: int) -> None: - bindings.post_KillTrial(api_utils.determined_test_session(), id=trial_id) - wait_for_trial_state(trial_id, trialv1State.CANCELED) +def kill_trial(sess: api.Session, trial_id: int) -> None: + bindings.post_KillTrial(sess, id=trial_id) + wait_for_trial_state(sess, trial_id, bindings.trialv1State.CANCELED) def wait_for_experiment_by_name_is_active( + sess: api.Session, experiment_name: str, min_trials: int = 1, max_wait_secs: int = conf.DEFAULT_MAX_WAIT_SECS, @@ -179,9 +204,7 @@ def wait_for_experiment_by_name_is_active( ) -> int: for seconds_waited in range(max_wait_secs): try: - response = bindings.get_GetExperiments( - api_utils.determined_test_session(), name=experiment_name - ).experiments + response = bindings.get_GetExperiments(sess, name=experiment_name).experiments if len(response) == 0: time.sleep(1) continue @@ -194,8 +217,7 @@ def wait_for_experiment_by_name_is_active( experiment_id = experiment.id except api.errors.NotFoundException: logging.warning( - "Experiment not yet available to check state: " - "experiment {}".format(experiment_name) + f"Experiment not yet available to check state: experiment {experiment_name}" ) time.sleep(0.25) continue @@ -207,18 +229,18 @@ def wait_for_experiment_by_name_is_active( continue if 
is_terminal_experiment_state(experiment.state): - report_failed_experiment(experiment_id) + report_failed_experiment(sess, experiment_id) pytest.fail( f"Experiment {experiment_id} terminated in {experiment.state.value} state, " - f"expected {experimentv1State.ACTIVE}" + f"expected {bindings.experimentv1State.ACTIVE}" ) if seconds_waited > 0 and seconds_waited % log_every == 0: print( f"Waited {seconds_waited} seconds for experiment {experiment_name} " f"(currently {experiment.state.value}) to reach " - f"{experimentv1State.ACTIVE}" + f"{bindings.experimentv1State.ACTIVE}" ) time.sleep(1) @@ -227,26 +249,26 @@ def wait_for_experiment_by_name_is_active( pytest.fail(f"Experiment {experiment_name} did not start any trial {max_wait_secs} seconds") -def _is_experiment_active(exp_state: experimentv1State) -> bool: +def _is_experiment_active(exp_state: bindings.experimentv1State) -> bool: return exp_state in ( - experimentv1State.ACTIVE, - experimentv1State.RUNNING, - experimentv1State.QUEUED, - experimentv1State.PULLING, - experimentv1State.STARTING, + bindings.experimentv1State.ACTIVE, + bindings.experimentv1State.RUNNING, + bindings.experimentv1State.QUEUED, + bindings.experimentv1State.PULLING, + bindings.experimentv1State.STARTING, ) def wait_for_experiment_state( + sess: api.Session, experiment_id: int, - target_state: experimentv1State, + target_state: bindings.experimentv1State, max_wait_secs: int = conf.DEFAULT_MAX_WAIT_SECS, log_every: int = 60, - credentials: Optional[authentication.Credentials] = None, ) -> None: for seconds_waited in range(max_wait_secs): try: - state = experiment_state(experiment_id, credentials) + state = experiment_state(sess, experiment_id) except api.errors.NotFoundException: logging.warning( "Experiment not yet available to check state: " @@ -260,7 +282,7 @@ def wait_for_experiment_state( if is_terminal_experiment_state(state): if state != target_state: - report_failed_experiment(experiment_id) + report_failed_experiment(sess, 
experiment_id) pytest.fail( f"Experiment {experiment_id} terminated in {state.value} state, " @@ -276,9 +298,9 @@ def wait_for_experiment_state( time.sleep(1) else: - if target_state == experimentv1State.COMPLETED: - kill_experiment(experiment_id) - report_failed_experiment(experiment_id) + if target_state == bindings.experimentv1State.COMPLETED: + kill_experiment(sess, experiment_id) + report_failed_experiment(sess, experiment_id) pytest.fail( "Experiment did not reach target state {} after {} seconds".format( target_state.value, max_wait_secs @@ -287,14 +309,15 @@ def wait_for_experiment_state( def wait_for_trial_state( + sess: api.Session, trial_id: int, - target_state: trialv1State, + target_state: bindings.trialv1State, max_wait_secs: int = conf.DEFAULT_MAX_WAIT_SECS, log_every: int = 60, ) -> None: for seconds_waited in range(max_wait_secs): try: - state = trial_state(trial_id) + state = trial_state(sess, trial_id) except api.errors.NotFoundException: logging.warning("Trial not yet available to check state: " "trial {}".format(trial_id)) time.sleep(0.25) @@ -305,7 +328,7 @@ def wait_for_trial_state( if is_terminal_state(state): if state != target_state: - report_failed_trial(trial_id, target_state=target_state, state=state) + report_failed_trial(sess, trial_id, target_state=target_state, state=state) pytest.fail( f"Trial {trial_id} terminated in {state.value} state, " @@ -321,10 +344,10 @@ def wait_for_trial_state( time.sleep(1) else: - state = trial_state(trial_id) - if target_state == trialv1State.COMPLETED: - kill_trial(trial_id) - report_failed_trial(trial_id, target_state=target_state, state=state) + state = trial_state(sess, trial_id) + if target_state == bindings.trialv1State.COMPLETED: + kill_trial(sess, trial_id) + report_failed_trial(sess, trial_id, target_state=target_state, state=state) pytest.fail( "Trial did not reach target state {} after {} seconds".format( target_state.value, max_wait_secs @@ -332,22 +355,20 @@ def wait_for_trial_state( ) 
-def experiment_has_active_workload(experiment_id: int) -> bool: - certs.cli_cert = certs.default_load(conf.make_master_url()) - authentication.cli_auth = authentication.Authentication(conf.make_master_url()) - r = api.get(conf.make_master_url(), "tasks").json() +def experiment_has_active_workload(sess: api.Session, experiment_id: int) -> bool: + r = sess.get("tasks").json() for task in r.values(): - if "Experiment {}".format(experiment_id) in task["name"] and len(task["resources"]) > 0: + if f"Experiment {experiment_id}" in task["name"] and len(task["resources"]) > 0: return True return False def wait_for_experiment_active_workload( - experiment_id: int, max_ticks: int = conf.MAX_TASK_SCHEDULED_SECS + sess: api.Session, experiment_id: int, max_ticks: int = conf.MAX_TASK_SCHEDULED_SECS ) -> None: for _ in range(conf.MAX_TASK_SCHEDULED_SECS): - if experiment_has_active_workload(experiment_id): + if experiment_has_active_workload(sess, experiment_id): return time.sleep(1) @@ -358,6 +379,7 @@ def wait_for_experiment_active_workload( def wait_for_at_least_n_trials( + sess: api.Session, experiment_id: int, n: int, timeout: int = 30, @@ -365,7 +387,7 @@ def wait_for_at_least_n_trials( """Wait for enough trials to start, then return the trials found.""" deadline = time.time() + timeout while True: - trials = experiment_trials(experiment_id) + trials = experiment_trials(sess, experiment_id) if len(trials) >= n: return trials if time.time() > deadline: @@ -373,10 +395,10 @@ def wait_for_at_least_n_trials( def wait_for_experiment_workload_progress( - experiment_id: int, max_ticks: int = conf.MAX_TRIAL_BUILD_SECS + sess: api.Session, experiment_id: int, max_ticks: int = conf.MAX_TRIAL_BUILD_SECS ) -> None: for _ in range(conf.MAX_TRIAL_BUILD_SECS): - trials = experiment_trials(experiment_id) + trials = experiment_trials(sess, experiment_id) if len(trials) > 0: only_trial = trials[0] if len(only_trial.workloads) > 1: @@ -388,10 +410,8 @@ def 
wait_for_experiment_workload_progress( ) -def experiment_has_completed_workload(experiment_id: int) -> bool: - certs.cli_cert = certs.default_load(conf.make_master_url()) - authentication.cli_auth = authentication.Authentication(conf.make_master_url()) - trials = experiment_trials(experiment_id) +def experiment_has_completed_workload(sess: api.Session, experiment_id: int) -> bool: + trials = experiment_trials(sess, experiment_id) if not any(trials): return False @@ -403,9 +423,8 @@ def experiment_has_completed_workload(experiment_id: int) -> bool: return False -def experiment_first_trial(exp_id: int) -> int: - session = api_utils.determined_test_session() - trials = bindings.get_GetExperimentTrials(session, experimentId=exp_id).trials +def experiment_first_trial(sess: api.Session, exp_id: int) -> int: + trials = bindings.get_GetExperimentTrials(sess, experimentId=exp_id).trials assert len(trials) > 0 trial = trials[0] @@ -413,23 +432,19 @@ def experiment_first_trial(exp_id: int) -> int: return trial_id -def experiment_config_json(experiment_id: int) -> Dict[str, Any]: - r = bindings.get_GetExperiment(api_utils.determined_test_session(), experimentId=experiment_id) +def experiment_config_json(sess: api.Session, experiment_id: int) -> Dict[str, Any]: + r = bindings.get_GetExperiment(api_utils.user_session(), experimentId=experiment_id) assert r.experiment and r.experiment.config return r.experiment.config -def experiment_state( - experiment_id: int, credentials: Optional[authentication.Credentials] = None -) -> experimentv1State: - r = bindings.get_GetExperiment( - api_utils.determined_test_session(credentials), experimentId=experiment_id - ) +def experiment_state(sess: api.Session, experiment_id: int) -> bindings.experimentv1State: + r = bindings.get_GetExperiment(sess, experimentId=experiment_id) return r.experiment.state -def trial_state(trial_id: int) -> trialv1State: - r = bindings.get_GetTrial(api_utils.determined_test_session(), trialId=trial_id) +def 
trial_state(sess: api.Session, trial_id: int) -> bindings.trialv1State: + r = bindings.get_GetTrial(sess, trialId=trial_id) return r.trial.state @@ -441,8 +456,7 @@ def __init__( self.workloads = workloads -def experiment_trials(experiment_id: int) -> List[TrialPlusWorkload]: - sess = api_utils.determined_test_session() +def experiment_trials(sess: api.Session, experiment_id: int) -> List[TrialPlusWorkload]: r1 = bindings.get_GetExperimentTrials(sess, experimentId=experiment_id) src_trials = r1.trials trials = [] @@ -453,75 +467,78 @@ def experiment_trials(experiment_id: int) -> List[TrialPlusWorkload]: return trials -def cancel_single(experiment_id: int, should_have_trial: bool = False) -> None: - cancel_experiment(experiment_id) +def cancel_single(sess: api.Session, experiment_id: int, should_have_trial: bool = False) -> None: + cancel_experiment(sess, experiment_id) if should_have_trial: - trials = experiment_trials(experiment_id) + trials = experiment_trials(sess, experiment_id) assert len(trials) == 1, len(trials) trial = trials[0].trial - assert trial.state == trialv1State.CANCELED + assert trial.state == bindings.trialv1State.CANCELED -def kill_single(experiment_id: int, should_have_trial: bool = False) -> None: - kill_experiment(experiment_id) +def kill_single(sess: api.Session, experiment_id: int, should_have_trial: bool = False) -> None: + kill_experiment(sess, experiment_id) if should_have_trial: - trials = experiment_trials(experiment_id) + trials = experiment_trials(sess, experiment_id) assert len(trials) == 1, len(trials) trial = trials[0].trial - assert trial.state == trialv1State.CANCELED + assert trial.state == bindings.trialv1State.CANCELED -def is_terminal_experiment_state(state: experimentv1State) -> bool: +def is_terminal_experiment_state(state: bindings.experimentv1State) -> bool: return state in ( - experimentv1State.CANCELED, - experimentv1State.COMPLETED, - experimentv1State.ERROR, + bindings.experimentv1State.CANCELED, + 
bindings.experimentv1State.COMPLETED, + bindings.experimentv1State.ERROR, ) -def is_terminal_state(state: trialv1State) -> bool: +def is_terminal_state(state: bindings.trialv1State) -> bool: return state in ( - trialv1State.CANCELED, - trialv1State.COMPLETED, - trialv1State.ERROR, + bindings.trialv1State.CANCELED, + bindings.trialv1State.COMPLETED, + bindings.trialv1State.ERROR, ) -def num_trials(experiment_id: int) -> int: - return len(experiment_trials(experiment_id)) +def num_trials(sess: api.Session, experiment_id: int) -> int: + return len(experiment_trials(sess, experiment_id)) -def num_active_trials(experiment_id: int) -> int: +def num_active_trials(sess: api.Session, experiment_id: int) -> int: return sum( 1 - if t.trial.state in [trialv1State.RUNNING, trialv1State.STARTING, trialv1State.PULLING] + if t.trial.state + in [ + bindings.trialv1State.RUNNING, + bindings.trialv1State.STARTING, + bindings.trialv1State.PULLING, + ] else 0 - for t in experiment_trials(experiment_id) + for t in experiment_trials(sess, experiment_id) ) -def num_completed_trials(experiment_id: int) -> int: +def num_completed_trials(sess: api.Session, experiment_id: int) -> int: return sum( - 1 if t.trial.state == trialv1State.COMPLETED else 0 - for t in experiment_trials(experiment_id) + 1 if t.trial.state == bindings.trialv1State.COMPLETED else 0 + for t in experiment_trials(sess, experiment_id) ) -def num_error_trials(experiment_id: int) -> int: +def num_error_trials(sess: api.Session, experiment_id: int) -> int: return sum( - 1 if t.trial.state == trialv1State.ERROR else 0 for t in experiment_trials(experiment_id) + 1 if t.trial.state == bindings.trialv1State.ERROR else 0 + for t in experiment_trials(sess, experiment_id) ) -def trial_logs(trial_id: int, follow: bool = False) -> List[str]: - return [ - tl.message - for tl in api.trial_logs(api_utils.determined_test_session(), trial_id, follow=follow) - ] +def trial_logs(sess: api.Session, trial_id: int, follow: bool = False) -> 
List[str]: + return [tl.message for tl in api.trial_logs(sess, trial_id, follow=follow)] def workloads_with_training( @@ -554,8 +571,10 @@ def workloads_with_checkpoint( return ret -def check_if_string_present_in_trial_logs(trial_id: int, target_string: str) -> bool: - logs = trial_logs(trial_id, follow=True) +def check_if_string_present_in_trial_logs( + sess: api.Session, trial_id: int, target_string: str +) -> bool: + logs = trial_logs(sess, trial_id, follow=True) for log_line in logs: if target_string in log_line: return True @@ -563,12 +582,12 @@ def check_if_string_present_in_trial_logs(trial_id: int, target_string: str) -> return False -def assert_patterns_in_trial_logs(trial_id: int, patterns: List[str]) -> None: +def assert_patterns_in_trial_logs(sess: api.Session, trial_id: int, patterns: List[str]) -> None: """Match each regex pattern in the list to the logs, one-at-a-time, in order.""" assert patterns, "must provide at least one pattern" patterns_iter = iter(patterns) p = re.compile(next(patterns_iter)) - logs = trial_logs(trial_id, follow=True) + logs = trial_logs(sess, trial_id, follow=True) for log_line in logs: if p.search(log_line) is None: continue @@ -585,8 +604,8 @@ def assert_patterns_in_trial_logs(trial_id: int, patterns: List[str]) -> None: ) -def assert_performed_initial_validation(exp_id: int) -> None: - trials = experiment_trials(exp_id) +def assert_performed_initial_validation(sess: api.Session, exp_id: int) -> None: + trials = experiment_trials(sess, exp_id) assert len(trials) > 0 workloads = trials[0].workloads @@ -619,17 +638,17 @@ def last_workload_matches_last_checkpoint( assert last_checkpoint_detail.state == bindings.checkpointv1State.COMPLETED -def assert_performed_final_checkpoint(exp_id: int) -> None: - trials = experiment_trials(exp_id) +def assert_performed_final_checkpoint(sess: api.Session, exp_id: int) -> None: + trials = experiment_trials(sess, exp_id) assert len(trials) > 0 
last_workload_matches_last_checkpoint(trials[0].workloads) -def run_cmd_and_print_on_error(cmd: List[str]) -> None: +def run_cmd_and_print_on_error(sess: api.Session, cmd: List[str]) -> None: """ We run some commands to make sure they work, but we don't need their output polluting the logs. """ - p = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE) + p = detproc.Popen(sess, cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE) out, err = p.communicate() ret = p.wait() if ret != 0: @@ -643,7 +662,7 @@ def run_cmd_and_print_on_error(cmd: List[str]) -> None: raise ValueError(f"cmd failed: {cmd} exited {ret}") -def run_describe_cli_tests(experiment_id: int) -> None: +def run_describe_cli_tests(sess: api.Session, experiment_id: int) -> None: """ Runs `det experiment describe` CLI command on a finished experiment. Will raise an exception if `det experiment describe` @@ -652,16 +671,15 @@ def run_describe_cli_tests(experiment_id: int) -> None: # "det experiment describe" without metrics. with tempfile.TemporaryDirectory() as tmpdir: run_cmd_and_print_on_error( + sess, [ "det", - "-m", - conf.make_master_url(), "experiment", "describe", str(experiment_id), "--outdir", tmpdir, - ] + ], ) assert os.path.exists(os.path.join(tmpdir, "experiments.csv")) @@ -671,17 +689,16 @@ def run_describe_cli_tests(experiment_id: int) -> None: # "det experiment describe" with metrics. 
with tempfile.TemporaryDirectory() as tmpdir: run_cmd_and_print_on_error( + sess, [ "det", - "-m", - conf.make_master_url(), "experiment", "describe", str(experiment_id), "--metrics", "--outdir", tmpdir, - ] + ], ) assert os.path.exists(os.path.join(tmpdir, "experiments.csv")) @@ -689,43 +706,47 @@ def run_describe_cli_tests(experiment_id: int) -> None: assert os.path.exists(os.path.join(tmpdir, "trials.csv")) -def run_list_cli_tests(experiment_id: int) -> None: +def run_list_cli_tests(sess: api.Session, experiment_id: int) -> None: """ Runs list-related CLI commands on a finished experiment. Will raise an exception if the CLI command encounters a traceback failure. """ + run_cmd_and_print_on_error(sess, ["det", "experiment", "list-trials", str(experiment_id)]) run_cmd_and_print_on_error( - ["det", "-m", conf.make_master_url(), "experiment", "list-trials", str(experiment_id)] - ) - run_cmd_and_print_on_error( - ["det", "-m", conf.make_master_url(), "experiment", "list-checkpoints", str(experiment_id)] + sess, + ["det", "experiment", "list-checkpoints", str(experiment_id)], ) run_cmd_and_print_on_error( + sess, [ "det", - "-m", - conf.make_master_url(), "experiment", "list-checkpoints", "--best", str(1), str(experiment_id), - ] + ], ) -def report_failed_experiment(experiment_id: int) -> None: - trials = experiment_trials(experiment_id) - active = sum(1 for t in trials if t.trial.state == trialv1State.RUNNING) - paused = sum(1 for t in trials if t.trial.state == trialv1State.PAUSED) - stopping_completed = sum(1 for t in trials if t.trial.state == trialv1State.STOPPING_COMPLETED) - stopping_canceled = sum(1 for t in trials if t.trial.state == trialv1State.STOPPING_CANCELED) - stopping_error = sum(1 for t in trials if t.trial.state == trialv1State.STOPPING_ERROR) - completed = sum(1 for t in trials if t.trial.state == trialv1State.COMPLETED) - canceled = sum(1 for t in trials if t.trial.state == trialv1State.CANCELED) - errored = sum(1 for t in trials if t.trial.state 
== trialv1State.ERROR) - stopping_killed = sum(1 for t in trials if t.trial.state == trialv1State.STOPPING_KILLED) +def report_failed_experiment(sess: api.Session, experiment_id: int) -> None: + trials = experiment_trials(sess, experiment_id) + active = sum(1 for t in trials if t.trial.state == bindings.trialv1State.RUNNING) + paused = sum(1 for t in trials if t.trial.state == bindings.trialv1State.PAUSED) + stopping_completed = sum( + 1 for t in trials if t.trial.state == bindings.trialv1State.STOPPING_COMPLETED + ) + stopping_canceled = sum( + 1 for t in trials if t.trial.state == bindings.trialv1State.STOPPING_CANCELED + ) + stopping_error = sum(1 for t in trials if t.trial.state == bindings.trialv1State.STOPPING_ERROR) + completed = sum(1 for t in trials if t.trial.state == bindings.trialv1State.COMPLETED) + canceled = sum(1 for t in trials if t.trial.state == bindings.trialv1State.CANCELED) + errored = sum(1 for t in trials if t.trial.state == bindings.trialv1State.ERROR) + stopping_killed = sum( + 1 for t in trials if t.trial.state == bindings.trialv1State.STOPPING_KILLED + ) print( f"Experiment {experiment_id}: {len(trials)} trials, {completed} completed, " @@ -736,21 +757,27 @@ def report_failed_experiment(experiment_id: int) -> None: ) for trial in trials: - print_trial_logs(trial.trial.id) + print_trial_logs(sess, trial.trial.id) -def report_failed_trial(trial_id: int, target_state: trialv1State, state: trialv1State) -> None: +def report_failed_trial( + sess: api.Session, + trial_id: int, + target_state: bindings.trialv1State, + state: bindings.trialv1State, +) -> None: print(f"Trial {trial_id} was not {target_state.value} but {state.value}", file=sys.stderr) - print_trial_logs(trial_id) + print_trial_logs(sess, trial_id) -def print_trial_logs(trial_id: int) -> None: - print("******** Start of logs for trial {} ********".format(trial_id), file=sys.stderr) - print("".join(trial_logs(trial_id)), file=sys.stderr) - print("******** End of logs for trial {} 
********".format(trial_id), file=sys.stderr) +def print_trial_logs(sess: api.Session, trial_id: int) -> None: + print(f"******** Start of logs for trial {trial_id} ********", file=sys.stderr) + print("".join(trial_logs(sess, trial_id)), file=sys.stderr) + print(f"******** End of logs for trial {trial_id} ********", file=sys.stderr) def run_basic_test( + sess: api.Session, config_file: str, model_def_file: str, expected_trials: Optional[int], @@ -761,23 +788,25 @@ def run_basic_test( priority: int = -1, ) -> int: assert os.path.isdir(model_def_file) - experiment_id = create_experiment(config_file, model_def_file, create_args) + experiment_id = create_experiment(sess, config_file, model_def_file, create_args) if priority != -1: - set_priority(experiment_id=experiment_id, priority=priority) + set_priority(sess, experiment_id=experiment_id, priority=priority) wait_for_experiment_state( + sess, experiment_id, - experimentv1State.COMPLETED, + bindings.experimentv1State.COMPLETED, max_wait_secs=max_wait_secs, ) - assert num_active_trials(experiment_id) == 0 + assert num_active_trials(sess, experiment_id) == 0 verify_completed_experiment_metadata( - experiment_id, expected_trials, expect_workloads, expect_checkpoints + sess, experiment_id, expected_trials, expect_workloads, expect_checkpoints ) return experiment_id def run_basic_autotuning_test( + sess: api.Session, config_file: str, model_def_file: str, expected_trials: Optional[int], @@ -792,56 +821,61 @@ def run_basic_autotuning_test( ) -> int: assert os.path.isdir(model_def_file) orchestrator_exp_id = run_autotuning_experiment( - config_file, model_def_file, create_args, search_method_name, max_trials + sess, config_file, model_def_file, create_args, search_method_name, max_trials ) if priority != -1: - set_priority(experiment_id=orchestrator_exp_id, priority=priority) + set_priority(sess, experiment_id=orchestrator_exp_id, priority=priority) # Wait for the Autotuning Single Searcher ("Orchestrator") to finish 
wait_for_experiment_state( + sess, orchestrator_exp_id, - experimentv1State.COMPLETED, + bindings.experimentv1State.COMPLETED, max_wait_secs=max_wait_secs, ) - assert num_active_trials(orchestrator_exp_id) == 0 + assert num_active_trials(sess, orchestrator_exp_id) == 0 verify_completed_experiment_metadata( - orchestrator_exp_id, expected_trials, expect_workloads, expect_checkpoints + sess, orchestrator_exp_id, expected_trials, expect_workloads, expect_checkpoints ) - client_exp_id = fetch_autotuning_client_experiment(orchestrator_exp_id) + client_exp_id = fetch_autotuning_client_experiment(sess, orchestrator_exp_id) # Wait for the Autotuning Custom Searcher Experiment ("Client Experiment") to finish wait_for_experiment_state( + sess, client_exp_id, - experimentv1State.COMPLETED if not expect_client_failed else experimentv1State.ERROR, + bindings.experimentv1State.COMPLETED + if not expect_client_failed + else bindings.experimentv1State.ERROR, max_wait_secs=max_wait_secs, ) - assert num_active_trials(orchestrator_exp_id) == 0 + assert num_active_trials(sess, orchestrator_exp_id) == 0 verify_completed_experiment_metadata( - orchestrator_exp_id, expected_trials, expect_workloads, expect_checkpoints + sess, orchestrator_exp_id, expected_trials, expect_workloads, expect_checkpoints ) return client_exp_id -def fetch_autotuning_client_experiment(exp_id: int) -> int: - command = ["det", "-m", conf.make_master_url(), "experiment", "logs", str(exp_id)] +def fetch_autotuning_client_experiment(sess: api.Session, exp_id: int) -> int: + command = ["det", "experiment", "logs", str(exp_id)] env = os.environ.copy() env["DET_DEBUG"] = "true" - completed_process = subprocess.run( - command, universal_newlines=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE, env=env + p = detproc.run( + sess, + command, + universal_newlines=True, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + env=env, ) - assert completed_process.returncode == 0, "\nstdout:\n{} \nstderr:\n{}".format( - 
completed_process.stdout, completed_process.stderr - ) - m = re.search(r"Created experiment (\d+)\n", str(completed_process.stdout)) + assert p.returncode == 0, f"\nstdout:\n{p.stdout} \nstderr:\n{p.stderr}" + m = re.search(r"Created experiment (\d+)\n", str(p.stdout)) assert m is not None return int(m.group(1)) -def set_priority(experiment_id: int, priority: int) -> None: +def set_priority(sess: api.Session, experiment_id: int, priority: int) -> None: command = [ "det", - "-m", - conf.make_master_url(), "experiment", "set", "priority", @@ -849,16 +883,15 @@ def set_priority(experiment_id: int, priority: int) -> None: str(priority), ] - completed_process = subprocess.run( - command, universal_newlines=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE + p = detproc.run( + sess, command, universal_newlines=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE ) - assert completed_process.returncode == 0, "\nstdout:\n{} \nstderr:\n{}".format( - completed_process.stdout, completed_process.stderr - ) + assert p.returncode == 0, f"\nstdout:\n{p.stdout} \nstderr:\n{p.stderr}" def verify_completed_experiment_metadata( + sess: api.Session, experiment_id: int, num_expected_trials: Optional[int], expect_workloads: bool = True, @@ -867,19 +900,20 @@ def verify_completed_experiment_metadata( # If `expected_trials` is None, the expected number of trials is # non-deterministic. if num_expected_trials is not None: - assert num_trials(experiment_id) == num_expected_trials - assert num_completed_trials(experiment_id) == num_expected_trials + assert num_trials(sess, experiment_id) == num_expected_trials + assert num_completed_trials(sess, experiment_id) == num_expected_trials # Check that every trial and step is COMPLETED. 
- trials = experiment_trials(experiment_id) + trials = experiment_trials(sess, experiment_id) assert len(trials) > 0 for t in trials: trial = t.trial - if trial.state != trialv1State.COMPLETED: + if trial.state != bindings.trialv1State.COMPLETED: report_failed_trial( + sess, trial.id, - target_state=trialv1State.COMPLETED, + target_state=bindings.trialv1State.COMPLETED, state=trial.state, ) pytest.fail(f"Trial {trial.id} was not STATE_COMPLETED but {trial.state.value}") @@ -888,7 +922,7 @@ def verify_completed_experiment_metadata( continue if len(t.workloads) == 0: - print_trial_logs(trial.id) + print_trial_logs(sess, trial.id) raise AssertionError( f"trial {trial.id} is in {trial.state.value} state but has 0 steps/workloads" ) @@ -920,23 +954,25 @@ def verify_completed_experiment_metadata( # polling is longer). max_secs_to_free_slots = 300 if api_utils.is_hpc() else 30 for _ in range(max_secs_to_free_slots): - if cluster_utils.num_free_slots() == cluster_utils.num_slots(): + if cluster_utils.num_free_slots(sess) == cluster_utils.num_slots(sess): break time.sleep(1) else: - raise AssertionError("Slots failed to free after experiment {}".format(experiment_id)) + raise AssertionError(f"Slots failed to free after experiment {experiment_id}") # Run a series of CLI tests on the finished experiment, to sanity check # that basic CLI commands don't raise errors. - run_describe_cli_tests(experiment_id) - run_list_cli_tests(experiment_id) + run_describe_cli_tests(sess, experiment_id) + run_list_cli_tests(sess, experiment_id) # Use Determined to run an experiment that we expect to fail. 
-def run_failure_test(config_file: str, model_def_file: str, error_str: Optional[str] = None) -> int: - experiment_id = create_experiment(config_file, model_def_file) +def run_failure_test( + sess: api.Session, config_file: str, model_def_file: str, error_str: Optional[str] = None +) -> int: + experiment_id = create_experiment(sess, config_file, model_def_file) - wait_for_experiment_state(experiment_id, experimentv1State.ERROR) + wait_for_experiment_state(sess, experiment_id, bindings.experimentv1State.ERROR) # The searcher is configured with a `max_trials` of 8. Since the # first step of each trial results in an error, there should be no @@ -947,31 +983,32 @@ def run_failure_test(config_file: str, model_def_file: str, error_str: Optional[ # might start a trial but cancel it before we hit the error in the # model definition. - assert num_active_trials(experiment_id) == 0 - assert num_completed_trials(experiment_id) == 0 - assert num_error_trials(experiment_id) >= 1 + assert num_active_trials(sess, experiment_id) == 0 + assert num_completed_trials(sess, experiment_id) == 0 + assert num_error_trials(sess, experiment_id) >= 1 # For each failed trial, check for the expected error in the logs. 
- trials = experiment_trials(experiment_id) + trials = experiment_trials(sess, experiment_id) for t in trials: trial = t.trial - if trial.state != trialv1State.ERROR: + if trial.state != bindings.trialv1State.ERROR: continue - logs = trial_logs(trial.id) + logs = trial_logs(sess, trial.id) if error_str is not None: try: assert any(error_str in line for line in logs) except AssertionError: # Display error log for triage of this failure print(f"Trial {trial.id} log did not contain expected message: {error_str}") - print_trial_logs(trial.id) + print_trial_logs(sess, trial.id) raise return experiment_id def run_basic_test_with_temp_config( + sess: api.Session, config: Dict[Any, Any], model_def_path: str, expected_trials: Optional[int], @@ -982,6 +1019,7 @@ def run_basic_test_with_temp_config( with open(tf.name, "w") as f: util.yaml_safe_dump(config, f) experiment_id = run_basic_test( + sess, tf.name, model_def_path, expected_trials, @@ -992,6 +1030,7 @@ def run_basic_test_with_temp_config( def run_failure_test_with_temp_config( + sess: api.Session, config: Dict[Any, Any], model_def_path: str, error_str: Optional[str] = None, @@ -999,7 +1038,7 @@ def run_failure_test_with_temp_config( with tempfile.NamedTemporaryFile() as tf: with open(tf.name, "w") as f: util.yaml_safe_dump(config, f) - return run_failure_test(tf.name, model_def_path, error_str=error_str) + return run_failure_test(sess, tf.name, model_def_path, error_str=error_str) def shared_fs_checkpoint_config() -> Dict[str, str]: @@ -1031,17 +1070,19 @@ def root_user_home_bind_mount() -> Dict[str, str]: return {"host_path": "/tmp", "container_path": "/root"} -def has_at_least_one_checkpoint(experiment_id: int) -> bool: - for trial in experiment_trials(experiment_id): +def has_at_least_one_checkpoint(sess: api.Session, experiment_id: int) -> bool: + for trial in experiment_trials(sess, experiment_id): if len(workloads_with_checkpoint(trial.workloads)) > 0: return True return False -def 
wait_for_at_least_one_checkpoint(experiment_id: int, timeout: int = 120) -> None: +def wait_for_at_least_one_checkpoint( + sess: api.Session, experiment_id: int, timeout: int = 120 +) -> None: for _ in range(timeout): - if has_at_least_one_checkpoint(experiment_id): + if has_at_least_one_checkpoint(sess, experiment_id): return else: time.sleep(1) - pytest.fail("Experiment did not reach at least one checkpoint after {} seconds".format(timeout)) + pytest.fail(f"Experiment did not reach at least one checkpoint after {timeout} seconds") diff --git a/e2e_tests/tests/experiment/record_profiling.py b/e2e_tests/tests/experiment/record_profiling.py index a3ac507ce58..2de0595fd2e 100644 --- a/e2e_tests/tests/experiment/record_profiling.py +++ b/e2e_tests/tests/experiment/record_profiling.py @@ -1,22 +1,20 @@ import json -from statistics import mean +import statistics from typing import Callable, Dict, List, Optional, Tuple -from urllib.parse import urlencode -from determined.common import api -from determined.profiler import SysMetricName -from tests import config as conf +from determined import profiler +from tests import api_utils -summary_methods: Dict[str, Callable] = {"avg": mean, "max": max, "min": min} +summary_methods: Dict[str, Callable] = {"avg": statistics.mean, "max": max, "min": min} default_metrics: Dict[str, List[str]] = { - SysMetricName.GPU_UTIL_METRIC: ["avg", "max"], - SysMetricName.SIMPLE_CPU_UTIL_METRIC: ["avg", "max"], - SysMetricName.DISK_IOPS_METRIC: ["avg", "max"], - SysMetricName.DISK_THRU_READ_METRIC: ["avg", "max"], - SysMetricName.DISK_THRU_WRITE_METRIC: ["avg", "max"], - SysMetricName.NET_THRU_SENT_METRIC: ["avg"], - SysMetricName.NET_THRU_RECV_METRIC: ["avg"], + profiler.SysMetricName.GPU_UTIL_METRIC: ["avg", "max"], + profiler.SysMetricName.SIMPLE_CPU_UTIL_METRIC: ["avg", "max"], + profiler.SysMetricName.DISK_IOPS_METRIC: ["avg", "max"], + profiler.SysMetricName.DISK_THRU_READ_METRIC: ["avg", "max"], + 
profiler.SysMetricName.DISK_THRU_WRITE_METRIC: ["avg", "max"], + profiler.SysMetricName.NET_THRU_SENT_METRIC: ["avg"], + profiler.SysMetricName.NET_THRU_RECV_METRIC: ["avg"], } @@ -59,18 +57,14 @@ def get_profiling_metrics(trial_id: int, metric_type: str) -> List[float]: """ Calls profiler API to return a list of metric values given trial ID and metric type """ - with api.get( - conf.make_master_url(), - "api/v1/trials/{}/profiler/metrics?{}".format( - trial_id, - urlencode( - { - "labels.name": metric_type, - "labels.metricType": "PROFILER_METRIC_TYPE_SYSTEM", - "follow": "true", - } - ), - ), + sess = api_utils.user_session() + with sess.get( + f"api/v1/trials/{trial_id}/profiler/metrics", + params={ + "labels.name": metric_type, + "labels.metricType": "PROFILER_METRIC_TYPE_SYSTEM", + "follow": "true", + }, stream=True, ) as r: return [ diff --git a/e2e_tests/tests/experiment/test_allocation_csv.py b/e2e_tests/tests/experiment/test_allocation_csv.py index 2036281b025..0f08a3f3b7d 100644 --- a/e2e_tests/tests/experiment/test_allocation_csv.py +++ b/e2e_tests/tests/experiment/test_allocation_csv.py @@ -1,13 +1,13 @@ import csv +import datetime +import io import re -from datetime import datetime, timedelta, timezone -from io import StringIO import pytest import requests -from determined.common import api -from determined.common.api.bindings import experimentv1State +from determined.common.api import bindings +from tests import api_utils from tests import cluster as clu from tests import command as cmd from tests import config as conf @@ -19,32 +19,31 @@ # Create a No_Op experiment and Check training/validation times @pytest.mark.e2e_cpu def test_experiment_capture() -> None: - start_time = datetime.now(timezone.utc).strftime("%Y-%m-%dT%H:%M:%SZ") + sess = api_utils.user_session() + start_time = datetime.datetime.now(datetime.timezone.utc).strftime("%Y-%m-%dT%H:%M:%SZ") experiment_id = exp.create_experiment( - conf.fixtures_path("no_op/single.yaml"), 
conf.fixtures_path("no_op") + sess, conf.fixtures_path("no_op/single.yaml"), conf.fixtures_path("no_op") ) - exp.wait_for_experiment_state(experiment_id, experimentv1State.COMPLETED) + exp.wait_for_experiment_state(sess, experiment_id, bindings.experimentv1State.COMPLETED) - end_time = datetime.now(timezone.utc).strftime("%Y-%m-%dT%H:%M:%SZ") - r = api.get( - conf.make_master_url(), - f"{API_URL}timestamp_after={start_time}×tamp_before={end_time}", - ) + end_time = datetime.datetime.now(datetime.timezone.utc).strftime("%Y-%m-%dT%H:%M:%SZ") + r = sess.get(f"{API_URL}timestamp_after={start_time}×tamp_before={end_time}") assert r.status_code == requests.codes.ok, r.text # Check if an entry exists for experiment that just ran - reader = csv.DictReader(StringIO(r.text)) + reader = csv.DictReader(io.StringIO(r.text)) matches = [row for row in reader if int(row["experiment_id"]) == experiment_id] assert len(matches) >= 1, f"could not find any rows for experiment {experiment_id}" @pytest.mark.e2e_cpu def test_notebook_capture() -> None: - start_time = datetime.now(timezone.utc).strftime("%Y-%m-%dT%H:%M:%SZ") + sess = api_utils.user_session() + start_time = datetime.datetime.now(datetime.timezone.utc).strftime("%Y-%m-%dT%H:%M:%SZ") task_id = None - with cmd.interactive_command("notebook", "start") as notebook: + with cmd.interactive_command(sess, ["notebook", "start"]) as notebook: task_id = notebook.task_id for line in notebook.stdout: @@ -52,11 +51,9 @@ def test_notebook_capture() -> None: return assert task_id is not None - end_time = datetime.now(timezone.utc).strftime("%Y-%m-%dT%H:%M:%SZ") - r = api.get( - conf.make_master_url(), - f"{API_URL}timestamp_after={start_time}×tamp_before={end_time}", - ) + end_time = datetime.datetime.now(datetime.timezone.utc).strftime("%Y-%m-%dT%H:%M:%SZ") + sess = api_utils.user_session() + r = sess.get(f"{API_URL}timestamp_after={start_time}×tamp_before={end_time}") assert r.status_code == requests.codes.ok, r.text assert 
re.search(f"{task_id},NOTEBOOK", r.text) is not None @@ -65,37 +62,34 @@ def test_notebook_capture() -> None: # Create a No_Op Experiment/Tensorboard & Confirm Tensorboard task is captured @pytest.mark.e2e_cpu def test_tensorboard_experiment_capture() -> None: - start_time = datetime.now(timezone.utc).strftime("%Y-%m-%dT%H:%M:%SZ") + sess = api_utils.user_session() + start_time = datetime.datetime.now(datetime.timezone.utc).strftime("%Y-%m-%dT%H:%M:%SZ") experiment_id = exp.create_experiment( - conf.fixtures_path("no_op/single.yaml"), conf.fixtures_path("no_op") + sess, conf.fixtures_path("no_op/single.yaml"), conf.fixtures_path("no_op") ) - exp.wait_for_experiment_state(experiment_id, experimentv1State.COMPLETED) + exp.wait_for_experiment_state(sess, experiment_id, bindings.experimentv1State.COMPLETED) - task_id = None - with cmd.interactive_command("tensorboard", "start", "--detach", str(experiment_id)) as tb: - task_id = tb.task_id - for line in tb.stdout: - if "TensorBoard is running at: http" in line: - break - if "TensorBoard is awaiting metrics" in line: - raise AssertionError("Tensorboard did not find metrics") - assert task_id is not None - clu.utils.wait_for_task_state("tensorboard", task_id, "TERMINATED") + with cmd.interactive_command( + sess, + ["tensorboard", "start", "--detach", str(experiment_id)], + ) as tb: + assert tb.task_id + clu.utils.wait_for_task_state(sess, "tensorboard", tb.task_id, "RUNNING") + clu.utils.wait_for_task_state(sess, "tensorboard", tb.task_id, "TERMINATED") # Ensure that end_time captures tensorboard - end_time = (datetime.now(timezone.utc) + timedelta(minutes=1)).strftime("%Y-%m-%dT%H:%M:%SZ") - r = api.get( - conf.make_master_url(), - f"{API_URL}timestamp_after={start_time}&timestamp_before={end_time}", - ) + end_time = ( + datetime.datetime.now(datetime.timezone.utc) + datetime.timedelta(minutes=1) + ).strftime("%Y-%m-%dT%H:%M:%SZ") + r = sess.get(f"{API_URL}timestamp_after={start_time}&timestamp_before={end_time}") assert
r.status_code == requests.codes.ok, r.text # Confirm Experiment is captured and valid - reader = csv.DictReader(StringIO(r.text)) + reader = csv.DictReader(io.StringIO(r.text)) matches = [row for row in reader if int(row["experiment_id"]) == experiment_id] assert len(matches) >= 1 # Confirm Tensorboard task is captured - assert re.search(f"{task_id}.*,TENSORBOARD", r.text) is not None + assert re.search(f"{tb.task_id}.*,TENSORBOARD", r.text) is not None diff --git a/e2e_tests/tests/experiment/test_api.py b/e2e_tests/tests/experiment/test_api.py index b3dd8fc71ea..5f54ebf347c 100644 --- a/e2e_tests/tests/experiment/test_api.py +++ b/e2e_tests/tests/experiment/test_api.py @@ -4,7 +4,6 @@ import pytest from determined.common.api import bindings -from determined.common.api.bindings import experimentv1State from tests import api_utils from tests import config as conf from tests import experiment as exp @@ -12,13 +11,13 @@ @pytest.mark.e2e_cpu def test_archived_proj_exp_list() -> None: - session = api_utils.determined_test_session(admin=True) + admin = api_utils.admin_session() workspaces: List[bindings.v1Workspace] = [] count = 2 for _ in range(count): body = bindings.v1PostWorkspaceRequest(name=f"workspace_{uuid.uuid4().hex[:8]}") - workspaces.append(bindings.post_PostWorkspace(session, body=body).workspace) + workspaces.append(bindings.post_PostWorkspace(admin, body=body).workspace) projects = [] experiments = [] @@ -27,7 +26,7 @@ def test_archived_proj_exp_list() -> None: name=f"p_{uuid.uuid4().hex[:8]}", workspaceId=wrkspc.id ) pid1 = bindings.post_PostProject( - session, + admin, body=body1, workspaceId=wrkspc.id, ).project.id @@ -36,7 +35,7 @@ def test_archived_proj_exp_list() -> None: name=f"p_{uuid.uuid4().hex[:8]}", workspaceId=wrkspc.id ) pid2 = bindings.post_PostProject( - session, + admin, body=body2, workspaceId=wrkspc.id, ).project.id @@ -46,6 +45,7 @@ def test_archived_proj_exp_list() -> None: experiments.append( exp.create_experiment( + admin, 
conf.fixtures_path("no_op/single.yaml"), conf.fixtures_path("no_op"), ["--project_id", str(pid1), ("--paused")], @@ -53,6 +53,7 @@ def test_archived_proj_exp_list() -> None: ) experiments.append( exp.create_experiment( + admin, conf.fixtures_path("no_op/single.yaml"), conf.fixtures_path("no_op"), ["--project_id", str(pid1), ("--paused")], @@ -60,6 +61,7 @@ def test_archived_proj_exp_list() -> None: ) experiments.append( exp.create_experiment( + admin, conf.fixtures_path("no_op/single.yaml"), conf.fixtures_path("no_op"), ["--project_id", str(pid2), ("--paused")], @@ -67,6 +69,7 @@ def test_archived_proj_exp_list() -> None: ) experiments.append( exp.create_experiment( + admin, conf.fixtures_path("no_op/single.yaml"), conf.fixtures_path("no_op"), ["--project_id", str(pid2), ("--paused")], @@ -74,47 +77,47 @@ def test_archived_proj_exp_list() -> None: ) bindings.post_KillExperiments( - session, body=bindings.v1KillExperimentsRequest(experimentIds=experiments) + admin, body=bindings.v1KillExperimentsRequest(experimentIds=experiments) ) for x in experiments: - exp.wait_for_experiment_state(experiment_id=x, target_state=experimentv1State.CANCELED) + exp.wait_for_experiment_state(admin, x, bindings.experimentv1State.CANCELED) archived_exp = [experiments[0], experiments[3], experiments[5], experiments[6]] for arch_exp in archived_exp: - bindings.post_ArchiveExperiment(session, id=arch_exp) + bindings.post_ArchiveExperiment(admin, id=arch_exp) # test1: GetExperiments shouldn't return archived experiments when archived flag is False - r1 = bindings.get_GetExperiments(session, archived=False) + r1 = bindings.get_GetExperiments(admin, archived=False) for e in r1.experiments: assert e.id not in archived_exp - bindings.post_ArchiveProject(session, id=projects[1]) - bindings.post_ArchiveProject(session, id=projects[2]) + bindings.post_ArchiveProject(admin, id=projects[1]) + bindings.post_ArchiveProject(admin, id=projects[2]) archived_exp.append(experiments[2]) 
archived_exp.append(experiments[4]) # test2: GetExperiments shouldn't return experiements from archived projects when # archived flag is False - r2 = bindings.get_GetExperiments(session, archived=False) + r2 = bindings.get_GetExperiments(admin, archived=False) for e in r2.experiments: assert e.id not in archived_exp - bindings.post_ArchiveWorkspace(session, id=workspaces[1].id) + bindings.post_ArchiveWorkspace(admin, id=workspaces[1].id) archived_exp.append(experiments[7]) # test3: GetExperiments shouldn't return experiements from archived workspaces when # archived flag is False - r3 = bindings.get_GetExperiments(session, archived=False) + r3 = bindings.get_GetExperiments(admin, archived=False) for e in r3.experiments: assert e.id not in archived_exp # test4: GetExperiments should return only unarchived experiments within an # archived project when archived flag is false - r4 = bindings.get_GetExperiments(session, archived=False, projectId=projects[2]) + r4 = bindings.get_GetExperiments(admin, archived=False, projectId=projects[2]) r4_correct_exp = [experiments[4]] assert len(r4.experiments) == len(r4_correct_exp) for e in r4.experiments: @@ -122,7 +125,7 @@ def test_archived_proj_exp_list() -> None: # test5: GetExperiments should return both archived and unarchived experiments when # archived flag is unspecified - r5 = bindings.get_GetExperiments(session) + r5 = bindings.get_GetExperiments(admin) returned_e_id = [] for e in r5.experiments: returned_e_id.append(e.id) @@ -131,4 +134,4 @@ def test_archived_proj_exp_list() -> None: assert e_id in returned_e_id for w in workspaces: - bindings.delete_DeleteWorkspace(session, id=w.id) + bindings.delete_DeleteWorkspace(admin, id=w.id) diff --git a/e2e_tests/tests/experiment/test_core.py b/e2e_tests/tests/experiment/test_core.py index 7880ec06fff..5ba42e43434 100644 --- a/e2e_tests/tests/experiment/test_core.py +++ b/e2e_tests/tests/experiment/test_core.py @@ -1,4 +1,3 @@ -import json import subprocess import tempfile @@ 
-6,18 +5,20 @@ from determined.common import api, util from determined.common.api import bindings -from determined.experimental import Determined +from determined.experimental import client from tests import api_utils from tests import command as cmd from tests import config as conf +from tests import detproc from tests import experiment as exp -from tests.api_utils import determined_test_session -from tests.cluster.test_checkpoints import wait_for_gc_to_finish +from tests.cluster import test_checkpoints @pytest.mark.e2e_cpu def test_trial_error() -> None: + sess = api_utils.user_session() exp.run_failure_test( + sess, conf.fixtures_path("trial_error/const.yaml"), conf.fixtures_path("trial_error"), "NotImplementedError", @@ -26,22 +27,22 @@ def test_trial_error() -> None: @pytest.mark.e2e_cpu def test_invalid_experiment() -> None: + sess = api_utils.user_session() completed_process = exp.maybe_create_experiment( - conf.fixtures_path("invalid_experiment/const.yaml"), conf.cv_examples_path("mnist_tf") + sess, conf.fixtures_path("invalid_experiment/const.yaml"), conf.cv_examples_path("mnist_tf") ) assert completed_process.returncode != 0 @pytest.mark.e2e_cpu def test_experiment_archive_unarchive() -> None: + sess = api_utils.user_session() experiment_id = exp.create_experiment( - conf.fixtures_path("no_op/single.yaml"), conf.fixtures_path("no_op"), ["--paused"] + sess, conf.fixtures_path("no_op/single.yaml"), conf.fixtures_path("no_op"), ["--paused"] ) describe_args = [ "det", - "-m", - conf.make_master_url(), "experiment", "describe", "--json", @@ -49,121 +50,114 @@ def test_experiment_archive_unarchive() -> None: ] # Check that the experiment is initially unarchived. - infos = json.loads(subprocess.check_output(describe_args)) + infos = detproc.check_json(sess, describe_args) assert len(infos) == 1 assert not infos[0]["experiment"]["archived"] # Check that archiving a non-terminal experiment fails, then terminate it. 
with pytest.raises(subprocess.CalledProcessError): - subprocess.check_call( - ["det", "-m", conf.make_master_url(), "experiment", "archive", str(experiment_id)] - ) - subprocess.check_call( - ["det", "-m", conf.make_master_url(), "experiment", "cancel", str(experiment_id)] - ) + detproc.check_call(sess, ["det", "experiment", "archive", str(experiment_id)]) + detproc.check_call(sess, ["det", "experiment", "cancel", str(experiment_id)]) # Check that we can archive and unarchive the experiment and see the expected effects. - subprocess.check_call( - ["det", "-m", conf.make_master_url(), "experiment", "archive", str(experiment_id)] - ) - infos = json.loads(subprocess.check_output(describe_args)) + detproc.check_call(sess, ["det", "experiment", "archive", str(experiment_id)]) + infos = detproc.check_json(sess, describe_args) assert len(infos) == 1 assert infos[0]["experiment"]["archived"] - subprocess.check_call( - ["det", "-m", conf.make_master_url(), "experiment", "unarchive", str(experiment_id)] - ) - infos = json.loads(subprocess.check_output(describe_args)) + detproc.check_call(sess, ["det", "experiment", "unarchive", str(experiment_id)]) + infos = detproc.check_json(sess, describe_args) assert len(infos) == 1 assert not infos[0]["experiment"]["archived"] @pytest.mark.e2e_cpu def test_create_test_mode() -> None: + sess = api_utils.user_session() # test-mode should succeed with a valid experiment. command = [ "det", - "-m", - conf.make_master_url(), "experiment", "create", "--test-mode", conf.fixtures_path("mnist_pytorch/adaptive_short.yaml"), conf.tutorials_path("mnist_pytorch"), ] - output = subprocess.check_output(command, universal_newlines=True) - assert "Model definition test succeeded" in output + output = detproc.check_output(sess, command) + assert "Model definition test succeeded" in output, output # test-mode should fail when an error is introduced into the trial # implementation. 
command = [ "det", - "-m", - conf.make_master_url(), "experiment", "create", "--test-mode", conf.fixtures_path("trial_error/const.yaml"), conf.fixtures_path("trial_error"), ] - with pytest.raises(subprocess.CalledProcessError): - subprocess.check_call(command) + # We expect a failing exit code, but --test-mode doesn't actually emit to stderr. + p = detproc.check_error(sess, command, "") + assert p.stdout + stdout = p.stdout.decode("utf8") + assert "resources failed with non-zero exit code" in stdout, stdout @pytest.mark.e2e_cpu def test_trial_logs() -> None: + sess = api_utils.user_session() experiment_id = exp.run_basic_test( - conf.fixtures_path("no_op/single.yaml"), conf.fixtures_path("no_op"), 1 + sess, conf.fixtures_path("no_op/single.yaml"), conf.fixtures_path("no_op"), 1 ) - trial_id = exp.experiment_trials(experiment_id)[0].trial.id - subprocess.check_call(["det", "-m", conf.make_master_url(), "trial", "logs", str(trial_id)]) - subprocess.check_call( - ["det", "-m", conf.make_master_url(), "trial", "logs", "--head", "10", str(trial_id)], + trial_id = exp.experiment_trials(sess, experiment_id)[0].trial.id + detproc.check_call(sess, ["det", "trial", "logs", str(trial_id)]) + detproc.check_call( + sess, + ["det", "trial", "logs", "--head", "10", str(trial_id)], ) - subprocess.check_call( - ["det", "-m", conf.make_master_url(), "trial", "logs", "--tail", "10", str(trial_id)], + detproc.check_call( + sess, + ["det", "trial", "logs", "--tail", "10", str(trial_id)], ) @pytest.mark.e2e_cpu def test_labels() -> None: + sess = api_utils.user_session() experiment_id = exp.create_experiment( - conf.fixtures_path("no_op/single-one-short-step.yaml"), conf.fixtures_path("no_op"), None + sess, + conf.fixtures_path("no_op/single-one-short-step.yaml"), + conf.fixtures_path("no_op"), + None, ) label = "__det_test_dummy_label__" # Add a label and check that it shows up. 
- subprocess.check_call( - ["det", "-m", conf.make_master_url(), "e", "label", "add", str(experiment_id), label] - ) - output = subprocess.check_output( - ["det", "-m", conf.make_master_url(), "e", "describe", str(experiment_id)] - ).decode() + detproc.check_call(sess, ["det", "e", "label", "add", str(experiment_id), label]) + output = detproc.check_output(sess, ["det", "e", "describe", str(experiment_id)]) assert label in output # Remove the label and check that it doesn't show up. - subprocess.check_call( - ["det", "-m", conf.make_master_url(), "e", "label", "remove", str(experiment_id), label] - ) - output = subprocess.check_output( - ["det", "-m", conf.make_master_url(), "e", "describe", str(experiment_id)] - ).decode() + detproc.check_call(sess, ["det", "e", "label", "remove", str(experiment_id), label]) + output = detproc.check_output(sess, ["det", "e", "describe", str(experiment_id)]) assert label not in output @pytest.mark.e2e_cpu def test_end_to_end_adaptive() -> None: + sess = api_utils.user_session() exp_id = exp.run_basic_test( + sess, conf.fixtures_path("mnist_pytorch/adaptive_short.yaml"), conf.tutorials_path("mnist_pytorch"), None, ) - wait_for_gc_to_finish(experiment_ids=[exp_id]) + test_checkpoints.wait_for_gc_to_finish(sess, experiment_ids=[exp_id]) # Check that validation accuracy look sane (more than 93% on MNIST). - trials = exp.experiment_trials(exp_id) + trials = exp.experiment_trials(sess, exp_id) best = None for trial in trials: assert len(trial.workloads) > 0 @@ -178,7 +172,7 @@ def test_end_to_end_adaptive() -> None: # Check that the Experiment returns a sorted order of top checkpoints # without gaps. The top 2 checkpoints should be the first 2 of the top k # checkpoints if sorting is stable. 
- d = Determined(conf.make_master_url()) + d = client.Determined._from_session(sess) exp_ref = d.get_experiment(exp_id) top_2 = exp_ref.top_n_checkpoints(2) @@ -252,46 +246,54 @@ def test_end_to_end_adaptive() -> None: @pytest.mark.e2e_cpu def test_log_null_bytes() -> None: + sess = api_utils.user_session() config_obj = conf.load_config(conf.fixtures_path("no_op/single.yaml")) config_obj["hyperparameters"]["write_null"] = True config_obj["max_restarts"] = 0 config_obj["searcher"]["max_length"] = {"batches": 1} - experiment_id = exp.run_basic_test_with_temp_config(config_obj, conf.fixtures_path("no_op"), 1) + experiment_id = exp.run_basic_test_with_temp_config( + sess, config_obj, conf.fixtures_path("no_op"), 1 + ) - trials = exp.experiment_trials(experiment_id) + trials = exp.experiment_trials(sess, experiment_id) assert len(trials) == 1 - logs = exp.trial_logs(trials[0].trial.id) + logs = exp.trial_logs(sess, trials[0].trial.id) assert len(logs) > 0 @pytest.mark.e2e_cpu def test_graceful_trial_termination() -> None: + sess = api_utils.user_session() config_obj = conf.load_config(conf.fixtures_path("no_op/grid-graceful-trial-termination.yaml")) - exp.run_basic_test_with_temp_config(config_obj, conf.fixtures_path("no_op"), 2) + exp.run_basic_test_with_temp_config(sess, config_obj, conf.fixtures_path("no_op"), 2) @pytest.mark.e2e_cpu def test_kill_experiment_ignoring_preemption() -> None: + sess = api_utils.user_session() exp_id = exp.create_experiment( + sess, conf.fixtures_path("core_api/sleep.yaml"), conf.fixtures_path("core_api"), None, ) - exp.wait_for_experiment_state(exp_id, bindings.experimentv1State.RUNNING) + exp.wait_for_experiment_state(sess, exp_id, bindings.experimentv1State.RUNNING) - bindings.post_CancelExperiment(api_utils.determined_test_session(), id=exp_id) - exp.wait_for_experiment_state(exp_id, bindings.experimentv1State.STOPPING_CANCELED) + bindings.post_CancelExperiment(sess, id=exp_id) + exp.wait_for_experiment_state(sess, exp_id, 
bindings.experimentv1State.STOPPING_CANCELED) - bindings.post_KillExperiment(api_utils.determined_test_session(), id=exp_id) - exp.wait_for_experiment_state(exp_id, bindings.experimentv1State.CANCELED) + bindings.post_KillExperiment(sess, id=exp_id) + exp.wait_for_experiment_state(sess, exp_id, bindings.experimentv1State.CANCELED) @pytest.mark.e2e_cpu def test_fail_on_first_validation() -> None: + sess = api_utils.user_session() error_log = "failed on first validation" config_obj = conf.load_config(conf.fixtures_path("no_op/single.yaml")) config_obj["hyperparameters"]["fail_on_first_validation"] = error_log exp.run_failure_test_with_temp_config( + sess, config_obj, conf.fixtures_path("no_op"), error_log, @@ -300,11 +302,12 @@ def test_fail_on_first_validation() -> None: @pytest.mark.e2e_cpu def test_perform_initial_validation() -> None: + sess = api_utils.user_session() config = conf.load_config(conf.fixtures_path("no_op/single.yaml")) config = conf.set_max_length(config, {"batches": 1}) config = conf.set_perform_initial_validation(config, True) - exp_id = exp.run_basic_test_with_temp_config(config, conf.fixtures_path("no_op"), 1) - exp.assert_performed_initial_validation(exp_id) + exp_id = exp.run_basic_test_with_temp_config(sess, config, conf.fixtures_path("no_op"), 1) + exp.assert_performed_initial_validation(sess, exp_id) @pytest.mark.e2e_cpu_2a @@ -338,6 +341,7 @@ def test_perform_initial_validation() -> None: ], ) def test_max_concurrent_trials(name: str, searcher_cfg: str) -> None: + sess = api_utils.user_session() config_obj = conf.load_config(conf.fixtures_path("no_op/single-very-many-long-steps.yaml")) config_obj["name"] = f"{name} searcher max concurrent trials test" config_obj["searcher"] = searcher_cfg @@ -351,35 +355,37 @@ def test_max_concurrent_trials(name: str, searcher_cfg: str) -> None: with tempfile.NamedTemporaryFile() as tf: with open(tf.name, "w") as f: util.yaml_safe_dump(config_obj, f) - experiment_id = exp.create_experiment(tf.name, 
conf.fixtures_path("no_op"), []) + experiment_id = exp.create_experiment(sess, tf.name, conf.fixtures_path("no_op"), []) try: - exp.wait_for_experiment_active_workload(experiment_id) - trials = exp.wait_for_at_least_n_trials(experiment_id, 1) + exp.wait_for_experiment_active_workload(sess, experiment_id) + trials = exp.wait_for_at_least_n_trials(sess, experiment_id, 1) assert len(trials) == 1, trials for t in trials: - exp.kill_trial(t.trial.id) + exp.kill_trial(sess, t.trial.id) # Give the experiment time to refill max_concurrent_trials. - trials = exp.wait_for_at_least_n_trials(experiment_id, 2) + trials = exp.wait_for_at_least_n_trials(sess, experiment_id, 2) # The experiment handling the cancel message and waiting for it to be cancelled slyly # (hackishly) allows us to synchronize with the experiment state after after canceling # the first two trials. - exp.cancel_single(experiment_id) + exp.cancel_single(sess, experiment_id) # Make sure that there were never more than 2 total trials created. 
- trials = exp.wait_for_at_least_n_trials(experiment_id, 2) + trials = exp.wait_for_at_least_n_trials(sess, experiment_id, 2) assert len(trials) == 2, trials finally: - exp.kill_single(experiment_id) + exp.kill_single(sess, experiment_id) @pytest.mark.e2e_cpu def test_experiment_list_columns() -> None: + sess = api_utils.user_session() exp.create_experiment( + sess, conf.fixtures_path("no_op/single-nested-hps.yaml"), conf.fixtures_path("no_op"), ["--project", "1"], @@ -396,7 +402,7 @@ def test_experiment_list_columns() -> None: "validation.validation_error.last", "validation.validation_error.mean", ] - columns = bindings.get_GetProjectColumns(api_utils.determined_test_session(), id=1) + columns = bindings.get_GetProjectColumns(sess, id=1) column_values = {c.column for c in columns.columns} for hp in exp_hyperparameters: @@ -407,14 +413,16 @@ def test_experiment_list_columns() -> None: @pytest.mark.e2e_cpu def test_metrics_range_by_project() -> None: + sess = api_utils.user_session() exp.run_basic_test( + sess, conf.fixtures_path("core_api/arbitrary_workload_order.yaml"), conf.fixtures_path("core_api"), 1, expect_workloads=True, expect_checkpoints=True, ) - ranges = bindings.get_GetProjectNumericMetricsRange(api_utils.determined_test_session(), id=1) + ranges = bindings.get_GetProjectNumericMetricsRange(sess, id=1) assert ranges.ranges is not None for r in ranges.ranges: @@ -423,7 +431,9 @@ def test_metrics_range_by_project() -> None: @pytest.mark.e2e_cpu def test_core_api_arbitrary_workload_order() -> None: + sess = api_utils.user_session() experiment_id = exp.run_basic_test( + sess, conf.fixtures_path("core_api/arbitrary_workload_order.yaml"), conf.fixtures_path("core_api"), 1, @@ -431,7 +441,7 @@ def test_core_api_arbitrary_workload_order() -> None: expect_checkpoints=True, ) - trials = exp.experiment_trials(experiment_id) + trials = exp.experiment_trials(sess, experiment_id) assert len(trials) == 1 trial = trials[0] @@ -456,7 +466,9 @@ def 
test_core_api_arbitrary_workload_order() -> None: def test_core_api_tutorials( stage: str, ntrials: int, expect_workloads: bool, expect_checkpoints: bool ) -> None: + sess = api_utils.user_session() exp.run_basic_test( + sess, conf.tutorials_path(f"core_api/{stage}.yaml"), conf.tutorials_path("core_api"), ntrials, @@ -467,8 +479,9 @@ def test_core_api_tutorials( @pytest.mark.parallel def test_core_api_distributed_tutorial() -> None: + sess = api_utils.user_session() exp.run_basic_test( - conf.tutorials_path("core_api/4_distributed.yaml"), conf.tutorials_path("core_api"), 1 + sess, conf.tutorials_path("core_api/4_distributed.yaml"), conf.tutorials_path("core_api"), 1 ) @@ -477,7 +490,9 @@ def test_core_api_pytorch_profiler_tensorboard() -> None: # Ensure tensorboard will load for an experiment which runs pytorch profiler, # and doesn't report metrics or checkpoints. # If the profiler trace file is not synced, the tensorboard will not load. + sess = api_utils.user_session() exp_id = exp.run_basic_test( + sess, conf.fixtures_path("core_api/pytorch_profiler_sync.yaml"), conf.fixtures_path("core_api"), 1, @@ -492,7 +507,7 @@ def test_core_api_pytorch_profiler_tensorboard() -> None: "--no-browser", ] - with cmd.interactive_command(*command) as tensorboard: + with cmd.interactive_command(sess, command) as tensorboard: assert tensorboard.task_id is not None - err = api.task_is_ready(determined_test_session(), tensorboard.task_id) + err = api.wait_for_task_ready(sess, tensorboard.task_id) assert err is None, err diff --git a/e2e_tests/tests/experiment/test_custom_searcher.py b/e2e_tests/tests/experiment/test_custom_searcher.py index 1dc0d698c86..3668f149506 100644 --- a/e2e_tests/tests/experiment/test_custom_searcher.py +++ b/e2e_tests/tests/experiment/test_custom_searcher.py @@ -1,6 +1,5 @@ import logging import pathlib -import subprocess import tempfile import time from typing import Iterator, List, Optional @@ -9,22 +8,25 @@ from urllib3 import connectionpool from 
determined import searcher -from determined.common import util +from determined.common import api, util from determined.common.api import bindings from determined.experimental import client from tests import api_utils from tests import config as conf +from tests import detproc from tests import experiment as exp from tests.fixtures.custom_searcher import searchers TIMESTAMP = int(time.time()) -def check_trial_state(trial: bindings.trialv1Trial, expect: bindings.trialv1State) -> bool: +def check_trial_state( + sess: api.Session, trial: bindings.trialv1Trial, expect: bindings.trialv1State +) -> bool: """If the trial is in an unexpected state, dump logs and return False.""" if trial.state == expect: return True - exp.print_trial_logs(trial.id) + exp.print_trial_logs(sess, trial.id) return False @@ -37,6 +39,7 @@ def client_login() -> Iterator[None]: @pytest.mark.e2e_cpu def test_run_custom_searcher_experiment(tmp_path: pathlib.Path) -> None: + sess = api_utils.user_session() # example searcher script config = conf.load_config(conf.fixtures_path("no_op/single.yaml")) config["searcher"] = { @@ -52,13 +55,13 @@ def test_run_custom_searcher_experiment(tmp_path: pathlib.Path) -> None: experiment_id = search_runner.run(config, model_dir=conf.fixtures_path("no_op")) assert client._determined is not None - session = client._determined._session - response = bindings.get_GetExperiment(session, experimentId=experiment_id) + response = bindings.get_GetExperiment(sess, experimentId=experiment_id) assert response.experiment.numTrials == 1 @pytest.mark.e2e_cpu_2a def test_run_random_searcher_exp() -> None: + sess = api_utils.user_session() config = conf.load_config(conf.fixtures_path("no_op/single.yaml")) config["searcher"] = { "name": "custom", @@ -80,9 +83,7 @@ def test_run_random_searcher_exp() -> None: search_runner = searcher.LocalSearchRunner(search_method, pathlib.Path(searcher_dir)) experiment_id = search_runner.run(config, model_dir=conf.fixtures_path("no_op")) - assert 
client._determined is not None - session = client._determined._session - response = bindings.get_GetExperiment(session, experimentId=experiment_id) + response = bindings.get_GetExperiment(sess, experimentId=experiment_id) assert response.experiment.numTrials == 5 assert search_method.created_trials == 5 assert search_method.pending_trials == 0 @@ -134,6 +135,7 @@ def test_run_random_searcher_exp_core_api( exception_points: List[str], metric_as_dict: bool, ) -> None: + sess = api_utils.user_session() config = conf.load_config(conf.fixtures_path("custom_searcher/core_api_searcher_random.yaml")) config["entrypoint"] += " --exp-name " + exp_name config["entrypoint"] += " --config-name " + config_name @@ -144,39 +146,33 @@ def test_run_random_searcher_exp_core_api( config["max_restarts"] = len(exception_points) experiment_id = exp.run_basic_test_with_temp_config( - config, conf.fixtures_path("custom_searcher"), 1 + sess, config, conf.fixtures_path("custom_searcher"), 1 ) - session = api_utils.determined_test_session() - # searcher experiment - searcher_exp = bindings.get_GetExperiment(session, experimentId=experiment_id).experiment + searcher_exp = bindings.get_GetExperiment(sess, experimentId=experiment_id).experiment assert searcher_exp.state == bindings.experimentv1State.COMPLETED # actual experiment - response = bindings.get_GetExperiments(session, name=exp_name) + response = bindings.get_GetExperiments(sess, name=exp_name) experiments = response.experiments assert len(experiments) == 1 experiment = experiments[0] assert experiment.numTrials == 5 - trials = bindings.get_GetExperimentTrials(session, experimentId=experiment.id).trials + trials = bindings.get_GetExperimentTrials(sess, experimentId=experiment.id).trials ok = True for trial in trials: - ok = ok and check_trial_state(trial, bindings.trialv1State.COMPLETED) + ok = ok and check_trial_state(sess, trial, bindings.trialv1State.COMPLETED) assert ok, "some trials failed" for trial in trials: assert 
trial.totalBatchesProcessed == 500 # check logs to ensure failures actually happened - logs = str( - subprocess.check_output( - ["det", "-m", conf.make_master_url(), "experiment", "logs", str(experiment_id)] - ) - ) + logs = detproc.check_output(sess, ["det", "experiment", "logs", str(experiment_id)]) failures = logs.count("Max retries exceeded with url: http://dummyurl (Caused by None)") assert failures == len(exception_points) @@ -187,6 +183,7 @@ def test_run_random_searcher_exp_core_api( @pytest.mark.e2e_cpu_2a def test_pause_multi_trial_random_searcher_core_api() -> None: + sess = api_utils.user_session() config = conf.load_config(conf.fixtures_path("custom_searcher/core_api_searcher_random.yaml")) exp_name = f"random-pause-{TIMESTAMP}" config["entrypoint"] += " --exp-name " + exp_name @@ -198,44 +195,39 @@ def test_pause_multi_trial_random_searcher_core_api() -> None: with open(tf.name, "w") as f: util.yaml_safe_dump(config, f) - searcher_exp_id = exp.create_experiment(tf.name, model_def_path, None) + searcher_exp_id = exp.create_experiment(sess, tf.name, model_def_path, None) exp.wait_for_experiment_state( + sess, searcher_exp_id, bindings.experimentv1State.RUNNING, ) # make sure both experiments have started by checking # that multi-trial experiment has at least 1 running trials - multi_trial_exp_id = exp.wait_for_experiment_by_name_is_active(exp_name, 1) + multi_trial_exp_id = exp.wait_for_experiment_by_name_is_active(sess, exp_name, 1) # pause multi-trial experiment - exp.pause_experiment(multi_trial_exp_id) - exp.wait_for_experiment_state(multi_trial_exp_id, bindings.experimentv1State.PAUSED) + exp.pause_experiment(sess, multi_trial_exp_id) + exp.wait_for_experiment_state(sess, multi_trial_exp_id, bindings.experimentv1State.PAUSED) # activate multi-trial experiment - exp.activate_experiment(multi_trial_exp_id) + exp.activate_experiment(sess, multi_trial_exp_id) # wait for searcher to complete - exp.wait_for_experiment_state(searcher_exp_id, 
bindings.experimentv1State.COMPLETED) + exp.wait_for_experiment_state(sess, searcher_exp_id, bindings.experimentv1State.COMPLETED) # searcher experiment - searcher_exp = bindings.get_GetExperiment( - api_utils.determined_test_session(), experimentId=searcher_exp_id - ).experiment + searcher_exp = bindings.get_GetExperiment(sess, experimentId=searcher_exp_id).experiment assert searcher_exp.state == bindings.experimentv1State.COMPLETED # actual experiment - experiment = bindings.get_GetExperiment( - api_utils.determined_test_session(), experimentId=multi_trial_exp_id - ).experiment + experiment = bindings.get_GetExperiment(sess, experimentId=multi_trial_exp_id).experiment assert experiment.numTrials == 5 - trials = bindings.get_GetExperimentTrials( - api_utils.determined_test_session(), experimentId=experiment.id - ).trials + trials = bindings.get_GetExperimentTrials(sess, experimentId=experiment.id).trials ok = True for trial in trials: - ok = ok and check_trial_state(trial, bindings.trialv1State.COMPLETED) + ok = ok and check_trial_state(sess, trial, bindings.trialv1State.COMPLETED) assert ok, "some trials failed" for trial in trials: @@ -262,6 +254,7 @@ def test_pause_multi_trial_random_searcher_core_api() -> None: ], ) def test_resume_random_searcher_exp(exceptions: List[str]) -> None: + sess = api_utils.user_session() config = conf.load_config(conf.fixtures_path("no_op/single.yaml")) config["searcher"] = { "name": "custom", @@ -310,9 +303,7 @@ def test_resume_random_searcher_exp(exceptions: List[str]) -> None: assert search_runner.state.last_event_id == 41 assert search_runner.state.experiment_completed is True - assert client._determined is not None - session = client._determined._session - response = bindings.get_GetExperiment(session, experimentId=experiment_id) + response = bindings.get_GetExperiment(sess, experimentId=experiment_id) assert response.experiment.numTrials == 5 assert search_method.created_trials == 5 assert search_method.pending_trials == 0 @@ 
-325,6 +316,7 @@ def test_resume_random_searcher_exp(exceptions: List[str]) -> None: @pytest.mark.nightly def test_run_asha_batches_exp(tmp_path: pathlib.Path, client_login: None) -> None: + sess = api_utils.user_session() config = conf.load_config(conf.fixtures_path("no_op/adaptive.yaml")) config["searcher"] = { "name": "custom", @@ -347,8 +339,7 @@ def test_run_asha_batches_exp(tmp_path: pathlib.Path, client_login: None) -> Non experiment_id = search_runner.run(config, model_dir=conf.fixtures_path("no_op")) assert client._determined is not None - session = client._determined._session - response = bindings.get_GetExperiment(session, experimentId=experiment_id) + response = bindings.get_GetExperiment(sess, experimentId=experiment_id) assert response.experiment.numTrials == 16 assert search_method.asha_search_state.pending_trials == 0 @@ -357,7 +348,7 @@ def test_run_asha_batches_exp(tmp_path: pathlib.Path, client_login: None) -> Non search_method.asha_search_state.closed_trials ) - response_trials = bindings.get_GetExperimentTrials(session, experimentId=experiment_id).trials + response_trials = bindings.get_GetExperimentTrials(sess, experimentId=experiment_id).trials # 16 trials in rung 1 (#batches = 125) assert sum(t.totalBatchesProcessed >= 125 for t in response_trials) == 16 @@ -368,7 +359,7 @@ def test_run_asha_batches_exp(tmp_path: pathlib.Path, client_login: None) -> Non ok = True for trial in response_trials: - ok = ok and check_trial_state(trial, bindings.trialv1State.COMPLETED) + ok = ok and check_trial_state(sess, trial, bindings.trialv1State.COMPLETED) assert ok, "some trials failed" @@ -399,6 +390,7 @@ def test_run_asha_batches_exp(tmp_path: pathlib.Path, client_login: None) -> Non ], ) def test_resume_asha_batches_exp(exceptions: List[str], client_login: None) -> None: + sess = api_utils.user_session() config = conf.load_config(conf.fixtures_path("no_op/adaptive.yaml")) config["searcher"] = { "name": "custom", @@ -446,9 +438,7 @@ def 
test_resume_asha_batches_exp(exceptions: List[str], client_login: None) -> N experiment_id = search_runner.run(config, model_dir=conf.fixtures_path("no_op")) assert search_runner.state.experiment_completed is True - assert client._determined is not None - session = client._determined._session - response = bindings.get_GetExperiment(session, experimentId=experiment_id) + response = bindings.get_GetExperiment(sess, experimentId=experiment_id) assert response.experiment.numTrials == 16 # asha search method state @@ -462,7 +452,7 @@ def test_resume_asha_batches_exp(exceptions: List[str], client_login: None) -> N search_method.asha_search_state.closed_trials ) - response_trials = bindings.get_GetExperimentTrials(session, experimentId=experiment_id).trials + response_trials = bindings.get_GetExperimentTrials(sess, experimentId=experiment_id).trials # 16 trials in rung 1 (#batches = 125) assert sum(t.totalBatchesProcessed >= 125 for t in response_trials) == 16 diff --git a/e2e_tests/tests/experiment/test_custom_searcher_asha_2a.py b/e2e_tests/tests/experiment/test_custom_searcher_asha_2a.py index de05aa04fea..44c54a4c574 100644 --- a/e2e_tests/tests/experiment/test_custom_searcher_asha_2a.py +++ b/e2e_tests/tests/experiment/test_custom_searcher_asha_2a.py @@ -1,4 +1,3 @@ -import subprocess import time from typing import List @@ -7,6 +6,7 @@ from determined.common.api import bindings from tests import api_utils from tests import config as conf +from tests import detproc from tests import experiment as exp TIMESTAMP = int(time.time()) @@ -45,6 +45,7 @@ def test_run_asha_searcher_exp_core_api( config_name: str, exp_name: str, exception_points: List[str] ) -> None: + sess = api_utils.user_session() config = conf.load_config(conf.fixtures_path("custom_searcher/core_api_searcher_asha.yaml")) config["entrypoint"] += " --exp-name " + exp_name config["entrypoint"] += " --config-name " + config_name @@ -53,23 +54,22 @@ def test_run_asha_searcher_exp_core_api( config["max_restarts"] 
= len(exception_points) experiment_id = exp.run_basic_test_with_temp_config( - config, conf.fixtures_path("custom_searcher"), 1 + sess, config, conf.fixtures_path("custom_searcher"), 1 ) - session = api_utils.determined_test_session() # searcher experiment - searcher_exp = bindings.get_GetExperiment(session, experimentId=experiment_id).experiment + searcher_exp = bindings.get_GetExperiment(sess, experimentId=experiment_id).experiment assert searcher_exp.state == bindings.experimentv1State.COMPLETED # actual experiment - response = bindings.get_GetExperiments(session, name=exp_name) + response = bindings.get_GetExperiments(sess, name=exp_name) experiments = response.experiments assert len(experiments) == 1 experiment = experiments[0] assert experiment.numTrials == 16 - response_trials = bindings.get_GetExperimentTrials(session, experimentId=experiment.id).trials + response_trials = bindings.get_GetExperimentTrials(sess, experimentId=experiment.id).trials # 16 trials in rung 1 (#batches = 150) assert sum(t.totalBatchesProcessed >= 150 for t in response_trials) == 16 @@ -82,11 +82,7 @@ def test_run_asha_searcher_exp_core_api( assert trial.state == bindings.trialv1State.COMPLETED # check logs to ensure failures actually happened - logs = str( - subprocess.check_output( - ["det", "-m", conf.make_master_url(), "experiment", "logs", str(experiment_id)] - ) - ) + logs = detproc.check_output(sess, ["det", "experiment", "logs", str(experiment_id)]) failures = logs.count("Max retries exceeded with url: http://dummyurl (Caused by None)") assert failures == len(exception_points) diff --git a/e2e_tests/tests/experiment/test_hpc_launch.py b/e2e_tests/tests/experiment/test_hpc_launch.py index c0a304f3488..46d75dbca1b 100644 --- a/e2e_tests/tests/experiment/test_hpc_launch.py +++ b/e2e_tests/tests/experiment/test_hpc_launch.py @@ -1,42 +1,42 @@ import sys -from typing import Callable import pytest -from determined.common.api.bindings import experimentv1State +from determined.common 
import api +from determined.common.api import bindings +from tests import api_utils from tests import config as conf from tests import experiment as exp -def run_test_case(testcase: str, message: str) -> None: +def run_test_case(sess: api.Session, testcase: str, message: str) -> None: experiment_id = exp.create_experiment( + sess, conf.fixtures_path(testcase), conf.fixtures_path("hpc"), ) try: - exp.wait_for_experiment_state(experiment_id, experimentv1State.COMPLETED, max_wait_secs=600) + exp.wait_for_experiment_state( + sess, experiment_id, bindings.experimentv1State.COMPLETED, max_wait_secs=600 + ) - trials = exp.experiment_trials(experiment_id) + trials = exp.experiment_trials(sess, experiment_id) assert exp.check_if_string_present_in_trial_logs( + sess, trials[0].trial.id, message, ) except AssertionError: # On failure print the log for triage - logs = exp.trial_logs(trials[0].trial.id, follow=True) - print( - "******** Start of logs for trial {} ********".format(trials[0].trial.id), - file=sys.stderr, - ) + logs = exp.trial_logs(sess, trials[0].trial.id, follow=True) + tid = trials[0].trial.id + print(f"******** Start of logs for trial {tid} ********", file=sys.stderr) print("".join(logs), file=sys.stderr) + print(f"******** End of logs for trial {tid} ********", file=sys.stderr) print( - "******** End of logs for trial {} ********".format(trials[0].trial.id), file=sys.stderr - ) - print( - f"Trial {trials[0].trial.id} log did not contain any of the " - + f"expected message: {message}", + f"Trial {tid} log did not contain any of the expected message: {message}", file=sys.stderr, ) raise @@ -45,16 +45,20 @@ def run_test_case(testcase: str, message: str) -> None: # This test should succeed with Slurm plus all container types # it does not yet succeed with PBS+Singularity. 
@pytest.mark.e2e_slurm -def test_launch_embedded_quotes(collect_trial_profiles: Callable[[int], None]) -> None: +def test_launch_embedded_quotes() -> None: + sess = api_utils.user_session() run_test_case( + sess, conf.fixtures_path("hpc/embedded-quotes.yaml"), 'DATA: user_defined_key=datakey="datavalue with embedded "', ) @pytest.mark.e2e_slurm -def test_launch_embedded_single_quote(collect_trial_profiles: Callable[[int], None]) -> None: +def test_launch_embedded_single_quote() -> None: + sess = api_utils.user_session() run_test_case( + sess, conf.fixtures_path("hpc/embedded-single-quote.yaml"), 'DATA: user_defined_key=datakey="datavalue with \' embedded "', ) diff --git a/e2e_tests/tests/experiment/test_launch.py b/e2e_tests/tests/experiment/test_launch.py index 86aeaea71e7..ea10a76432a 100644 --- a/e2e_tests/tests/experiment/test_launch.py +++ b/e2e_tests/tests/experiment/test_launch.py @@ -2,7 +2,8 @@ import pytest -from determined.experimental import Determined +from determined.experimental import client +from tests import api_utils from tests import config as conf from tests import experiment as exp @@ -11,6 +12,7 @@ @pytest.mark.e2e_slurm @pytest.mark.e2e_pbs def test_launch_layer_mnist(collect_trial_profiles: Callable[[int], None]) -> None: + sess = api_utils.user_session() config = conf.load_config(conf.tutorials_path("mnist_pytorch/const.yaml")) config = conf.set_max_length(config, {"batches": 200}) config = conf.set_slots_per_trial(config, 1) @@ -20,12 +22,13 @@ def test_launch_layer_mnist(collect_trial_profiles: Callable[[int], None]) -> No ) experiment_id = exp.run_basic_test_with_temp_config( - config, conf.tutorials_path("mnist_pytorch"), 1 + sess, config, conf.tutorials_path("mnist_pytorch"), 1 ) - trials = exp.experiment_trials(experiment_id) + trials = exp.experiment_trials(sess, experiment_id) collect_trial_profiles(trials[0].trial.id) assert exp.check_if_string_present_in_trial_logs( + sess, trials[0].trial.id, "resources exited successfully with 
a zero exit code", ) @@ -35,24 +38,27 @@ def test_launch_layer_mnist(collect_trial_profiles: Callable[[int], None]) -> No @pytest.mark.e2e_slurm @pytest.mark.e2e_pbs def test_launch_layer_exit(collect_trial_profiles: Callable[[int], None]) -> None: + sess = api_utils.user_session() config = conf.load_config(conf.tutorials_path("mnist_pytorch/const.yaml")) config = conf.set_entrypoint(config, "python3 -m nonexistent_launch_module python3 train.py") config["max_restarts"] = 0 experiment_id = exp.run_failure_test_with_temp_config( - config, conf.tutorials_path("mnist_pytorch") + sess, config, conf.tutorials_path("mnist_pytorch") ) - trials = exp.experiment_trials(experiment_id) - Determined(conf.make_master_url()).get_trial(trials[0].trial.id) + trials = exp.experiment_trials(sess, experiment_id) + client.Determined._from_session(sess).get_trial(trials[0].trial.id) collect_trial_profiles(trials[0].trial.id) slurm_run = exp.check_if_string_present_in_trial_logs( - trials[0].trial.id, "Exited with exit code 1" + sess, trials[0].trial.id, "Exited with exit code 1" + ) + pbs_run = exp.check_if_string_present_in_trial_logs( + sess, trials[0].trial.id, "exited with status 1" ) - pbs_run = exp.check_if_string_present_in_trial_logs(trials[0].trial.id, "exited with status 1") cpu_run = exp.check_if_string_present_in_trial_logs( - trials[0].trial.id, "container failed with non-zero exit code: 1" + sess, trials[0].trial.id, "container failed with non-zero exit code: 1" ) assert cpu_run or slurm_run or pbs_run diff --git a/e2e_tests/tests/experiment/test_metrics.py b/e2e_tests/tests/experiment/test_metrics.py index 66f99c3f4ad..e972e925337 100644 --- a/e2e_tests/tests/experiment/test_metrics.py +++ b/e2e_tests/tests/experiment/test_metrics.py @@ -1,28 +1,24 @@ import json import multiprocessing as mp -import multiprocessing.pool -import subprocess from typing import Dict, List, Set, Union import pytest -from determined.common import api -from determined.common.api import 
authentication, bindings, certs +from determined.common.api import bindings from tests import api_utils from tests import config as conf +from tests import detproc from tests import experiment as exp @pytest.mark.e2e_cpu @pytest.mark.timeout(600) def test_streaming_metrics_api() -> None: - # TODO: refactor tests to not use cli singleton auth. - certs.cli_cert = certs.default_load(conf.make_master_url()) - authentication.cli_auth = authentication.Authentication(conf.make_master_url()) - + sess = api_utils.user_session() pool = mp.pool.ThreadPool(processes=7) experiment_id = exp.create_experiment( + sess, conf.fixtures_path("mnist_pytorch/adaptive_short.yaml"), conf.tutorials_path("mnist_pytorch"), ) @@ -66,9 +62,9 @@ def test_streaming_metrics_api() -> None: def request_metric_names(experiment_id): # type: ignore - response = api.get( - conf.make_master_url(), - "api/v1/experiments/metrics-stream/metric-names?ids={}".format(experiment_id), + sess = api_utils.user_session() + response = sess.get( + f"api/v1/experiments/metrics-stream/metric-names?ids={experiment_id}", params={"period_seconds": 1}, ) results = [message["result"] for message in map(json.loads, response.text.splitlines())] @@ -102,9 +98,9 @@ def request_metric_names(experiment_id): # type: ignore def request_train_metric_batches(experiment_id): # type: ignore - response = api.get( - conf.make_master_url(), - "api/v1/experiments/{}/metrics-stream/batches".format(experiment_id), + sess = api_utils.user_session() + response = sess.get( + f"api/v1/experiments/{experiment_id}/metrics-stream/batches", params={"metric_name": "loss", "metric_type": "METRIC_TYPE_TRAINING", "period_seconds": 1}, ) results = [message["result"] for message in map(json.loads, response.text.splitlines())] @@ -126,9 +122,9 @@ def request_train_metric_batches(experiment_id): # type: ignore def request_valid_metric_batches(experiment_id): # type: ignore - response = api.get( - conf.make_master_url(), - 
"api/v1/experiments/{}/metrics-stream/batches".format(experiment_id), + sess = api_utils.user_session() + response = sess.get( + f"api/v1/experiments/{experiment_id}/metrics-stream/batches", params={ "metric_name": "accuracy", "metric_type": "METRIC_TYPE_VALIDATION", @@ -164,9 +160,9 @@ def validate_hparam_types(hparams: dict) -> Union[None, str]: def request_train_trials_snapshot(experiment_id): # type: ignore - response = api.get( - conf.make_master_url(), - "api/v1/experiments/{}/metrics-stream/trials-snapshot".format(experiment_id), + sess = api_utils.user_session() + response = sess.get( + f"api/v1/experiments/{experiment_id}/metrics-stream/trials-snapshot", params={ "metric_name": "loss", "metric_type": "METRIC_TYPE_TRAINING", @@ -194,9 +190,9 @@ def request_train_trials_snapshot(experiment_id): # type: ignore def request_valid_trials_snapshot(experiment_id): # type: ignore - response = api.get( - conf.make_master_url(), - "api/v1/experiments/{}/metrics-stream/trials-snapshot".format(experiment_id), + sess = api_utils.user_session() + response = sess.get( + f"api/v1/experiments/{experiment_id}/metrics-stream/trials-snapshot", params={ "metric_name": "accuracy", "metric_type": "METRIC_TYPE_VALIDATION", @@ -261,9 +257,9 @@ def check_trials_sample_result(results: list) -> Union[None, tuple]: def request_train_trials_sample(experiment_id): # type: ignore - response = api.get( - conf.make_master_url(), - "api/v1/experiments/{}/metrics-stream/trials-sample".format(experiment_id), + sess = api_utils.user_session() + response = sess.get( + f"api/v1/experiments/{experiment_id}/metrics-stream/trials-sample", params={ "metric_name": "loss", "metric_type": "METRIC_TYPE_TRAINING", @@ -275,9 +271,9 @@ def request_train_trials_sample(experiment_id): # type: ignore def request_valid_trials_sample(experiment_id): # type: ignore - response = api.get( - conf.make_master_url(), - "api/v1/experiments/{}/metrics-stream/trials-sample".format(experiment_id), + sess = 
api_utils.user_session() + response = sess.get( + f"api/v1/experiments/{experiment_id}/metrics-stream/trials-sample", params={ "metric_name": "accuracy", "metric_type": "METRIC_TYPE_VALIDATION", @@ -291,14 +287,15 @@ def request_valid_trials_sample(experiment_id): # type: ignore @pytest.mark.e2e_cpu @pytest.mark.parametrize("group", ["validation", "training", "abc"]) def test_trial_time_series(group: str) -> None: + sess = api_utils.user_session() exp_id = exp.create_experiment( + sess, conf.fixtures_path("no_op/single-one-short-step.yaml"), conf.fixtures_path("no_op"), ["--project_id", str(1), ("--paused")], ) - trials = exp.experiment_trials(exp_id) + trials = exp.experiment_trials(sess, exp_id) trial_id = trials[0].trial.id - sess = api_utils.determined_test_session(admin=False) metric_names = ["lossx"] trial_metrics = bindings.v1TrialMetrics( @@ -325,16 +322,15 @@ def test_trial_time_series(group: str) -> None: @pytest.mark.e2e_cpu def test_trial_describe_metrics() -> None: + sess = api_utils.user_session() exp_id = exp.run_basic_test( - conf.fixtures_path("no_op/single-one-short-step.yaml"), conf.fixtures_path("no_op"), 1 + sess, conf.fixtures_path("no_op/single-one-short-step.yaml"), conf.fixtures_path("no_op"), 1 ) - trials = exp.experiment_trials(exp_id) + trials = exp.experiment_trials(sess, exp_id) trial_id = trials[0].trial.id cmd = [ "det", - "-m", - conf.make_master_url(), "trial", "describe", "--json", @@ -342,7 +338,7 @@ def test_trial_describe_metrics() -> None: str(trial_id), ] - output = json.loads(subprocess.check_output(cmd)) + output = detproc.check_json(sess, cmd) workloads = output["workloads"] assert len(workloads) == 102 @@ -354,7 +350,6 @@ def test_trial_describe_metrics() -> None: assert len(losses) == 100 # assert summary metrics in trial - sess = api_utils.determined_test_session(admin=True) resp = bindings.get_GetTrial(session=sess, trialId=trial_id) summaryMetrics = resp.trial.summaryMetrics assert summaryMetrics is not None diff 
--git a/e2e_tests/tests/experiment/test_noop.py b/e2e_tests/tests/experiment/test_noop.py index cdb985ad015..74801fa82d7 100644 --- a/e2e_tests/tests/experiment/test_noop.py +++ b/e2e_tests/tests/experiment/test_noop.py @@ -8,7 +8,8 @@ from determined.common import check, util from determined.common.api import bindings -from determined.experimental import Determined +from determined.experimental import client +from tests import api_utils from tests import config as conf from tests import experiment as exp @@ -18,29 +19,31 @@ def test_noop_pause() -> None: """ Walk through starting, pausing, and resuming a single no-op experiment. """ + sess = api_utils.user_session() experiment_id = exp.create_experiment( + sess, conf.fixtures_path("no_op/single-medium-train-step.yaml"), conf.fixtures_path("no_op"), None, ) - exp.wait_for_experiment_state(experiment_id, bindings.experimentv1State.RUNNING) + exp.wait_for_experiment_state(sess, experiment_id, bindings.experimentv1State.RUNNING) # Wait for the only trial to get scheduled. - exp.wait_for_experiment_active_workload(experiment_id) + exp.wait_for_experiment_active_workload(sess, experiment_id) # Wait for the only trial to show progress, indicating the image is built and running. - exp.wait_for_experiment_workload_progress(experiment_id) + exp.wait_for_experiment_workload_progress(sess, experiment_id) # Pause the experiment. Note that Determined does not currently differentiate # between a "stopping paused" and a "paused" state, so we follow this check # up by ensuring the experiment cleared all scheduled workloads. - exp.pause_experiment(experiment_id) - exp.wait_for_experiment_state(experiment_id, bindings.experimentv1State.PAUSED) + exp.pause_experiment(sess, experiment_id) + exp.wait_for_experiment_state(sess, experiment_id, bindings.experimentv1State.PAUSED) # Wait at most 20 seconds for the experiment to clear all workloads (each # train step should take 5 seconds). 
for _ in range(20): - workload_active = exp.experiment_has_active_workload(experiment_id) + workload_active = exp.experiment_has_active_workload(sess, experiment_id) if not workload_active: break else: @@ -51,8 +54,8 @@ def test_noop_pause() -> None: ) # Resume the experiment and wait for completion. - exp.activate_experiment(experiment_id) - exp.wait_for_experiment_state(experiment_id, bindings.experimentv1State.COMPLETED) + exp.activate_experiment(sess, experiment_id) + exp.wait_for_experiment_state(sess, experiment_id, bindings.experimentv1State.COMPLETED) @pytest.mark.e2e_cpu @@ -60,12 +63,14 @@ def test_noop_nan_validations() -> None: """ Ensure that NaN validation metric values don't prevent an experiment from completing. """ + sess = api_utils.user_session() experiment_id = exp.create_experiment( + sess, conf.fixtures_path("no_op/single-nan-validations.yaml"), conf.fixtures_path("no_op"), None, ) - exp.wait_for_experiment_state(experiment_id, bindings.experimentv1State.COMPLETED) + exp.wait_for_experiment_state(sess, experiment_id, bindings.experimentv1State.COMPLETED) @pytest.mark.e2e_cpu @@ -73,11 +78,14 @@ def test_noop_load() -> None: """ Load a checkpoint """ + sess = api_utils.user_session() experiment_id = exp.run_basic_test( - conf.fixtures_path("no_op/single.yaml"), conf.fixtures_path("no_op"), 1 + sess, conf.fixtures_path("no_op/single.yaml"), conf.fixtures_path("no_op"), 1 + ) + trials = exp.experiment_trials(sess, experiment_id) + checkpoint = ( + client.Determined._from_session(sess).get_trial(trials[0].trial.id).top_checkpoint() ) - trials = exp.experiment_trials(experiment_id) - checkpoint = Determined(conf.make_master_url()).get_trial(trials[0].trial.id).top_checkpoint() assert checkpoint.task_id == trials[0].trial.taskId @@ -87,6 +95,7 @@ def test_noop_pause_of_experiment_without_trials() -> None: Walk through starting, pausing, and resuming a single no-op experiment which will never schedule a trial. 
""" + sess = api_utils.user_session() config_obj = conf.load_config(conf.fixtures_path("no_op/single-one-short-step.yaml")) impossibly_large = 100 config_obj["max_restarts"] = 0 @@ -94,18 +103,18 @@ def test_noop_pause_of_experiment_without_trials() -> None: with tempfile.NamedTemporaryFile() as tf: with open(tf.name, "w") as f: util.yaml_safe_dump(config_obj, f) - experiment_id = exp.create_experiment(tf.name, conf.fixtures_path("no_op"), None) - exp.pause_experiment(experiment_id) - exp.wait_for_experiment_state(experiment_id, bindings.experimentv1State.PAUSED) + experiment_id = exp.create_experiment(sess, tf.name, conf.fixtures_path("no_op"), None) + exp.pause_experiment(sess, experiment_id) + exp.wait_for_experiment_state(sess, experiment_id, bindings.experimentv1State.PAUSED) - exp.activate_experiment(experiment_id) - exp.wait_for_experiment_state(experiment_id, bindings.experimentv1State.QUEUED) + exp.activate_experiment(sess, experiment_id) + exp.wait_for_experiment_state(sess, experiment_id, bindings.experimentv1State.QUEUED) for _ in range(5): - assert exp.experiment_state(experiment_id) == bindings.experimentv1State.QUEUED + assert exp.experiment_state(sess, experiment_id) == bindings.experimentv1State.QUEUED time.sleep(1) - exp.kill_single(experiment_id) + exp.kill_single(sess, experiment_id) @pytest.mark.e2e_cpu @@ -114,6 +123,7 @@ def test_noop_pause_with_multiexperiment() -> None: Start, pause, and resume a single no-op experiment using the bulk action endpoints and ExperimentIds param. 
""" + sess = api_utils.user_session() config_obj = conf.load_config(conf.fixtures_path("no_op/single-one-short-step.yaml")) impossibly_large = 100 config_obj["max_restarts"] = 0 @@ -121,13 +131,13 @@ def test_noop_pause_with_multiexperiment() -> None: with tempfile.NamedTemporaryFile() as tf: with open(tf.name, "w") as f: util.yaml_safe_dump(config_obj, f) - experiment_id = exp.create_experiment(tf.name, conf.fixtures_path("no_op"), None) - exp.pause_experiments([experiment_id]) - exp.wait_for_experiment_state(experiment_id, bindings.experimentv1State.PAUSED) + experiment_id = exp.create_experiment(sess, tf.name, conf.fixtures_path("no_op"), None) + exp.pause_experiments(sess, [experiment_id]) + exp.wait_for_experiment_state(sess, experiment_id, bindings.experimentv1State.PAUSED) - exp.activate_experiments([experiment_id]) - exp.wait_for_experiment_state(experiment_id, bindings.experimentv1State.QUEUED) - exp.kill_experiments([experiment_id]) + exp.activate_experiments(sess, [experiment_id]) + exp.wait_for_experiment_state(sess, experiment_id, bindings.experimentv1State.QUEUED) + exp.kill_experiments(sess, [experiment_id]) @pytest.mark.e2e_cpu @@ -136,6 +146,7 @@ def test_noop_pause_with_multiexperiment_filter() -> None: Pause a single no-op experiment using the bulk action endpoint and Filters param. 
""" + sess = api_utils.user_session() config_obj = conf.load_config(conf.fixtures_path("no_op/single-one-short-step.yaml")) impossibly_large = 100 config_obj["max_restarts"] = 0 @@ -144,23 +155,24 @@ def test_noop_pause_with_multiexperiment_filter() -> None: config_obj["name"] = tf.name with open(tf.name, "w") as f: util.yaml_safe_dump(config_obj, f) - experiment_id = exp.create_experiment(tf.name, conf.fixtures_path("no_op"), None) - exp.pause_experiments([], name=tf.name) - exp.wait_for_experiment_state(experiment_id, bindings.experimentv1State.PAUSED) + experiment_id = exp.create_experiment(sess, tf.name, conf.fixtures_path("no_op"), None) + exp.pause_experiments(sess, [], name=tf.name) + exp.wait_for_experiment_state(sess, experiment_id, bindings.experimentv1State.PAUSED) # test state=nonTerminalExperimentStates() filter in cancel/kill - exp.kill_experiments([], name=tf.name) - exp.wait_for_experiment_state(experiment_id, bindings.experimentv1State.CANCELED) + exp.kill_experiments(sess, [], name=tf.name) + exp.wait_for_experiment_state(sess, experiment_id, bindings.experimentv1State.CANCELED) # test state=terminalExperimentStates() filter in archive - exp.archive_experiments([], name=tf.name) + exp.archive_experiments(sess, [], name=tf.name) @pytest.mark.e2e_cpu def test_noop_single_warm_start() -> None: + sess = api_utils.user_session() experiment_id1 = exp.run_basic_test( - conf.fixtures_path("no_op/single.yaml"), conf.fixtures_path("no_op"), 1 + sess, conf.fixtures_path("no_op/single.yaml"), conf.fixtures_path("no_op"), 1 ) - trials = exp.experiment_trials(experiment_id1) + trials = exp.experiment_trials(sess, experiment_id1) assert len(trials) == 1 first_trial = trials[0].trial @@ -182,9 +194,11 @@ def test_noop_single_warm_start() -> None: # Add a source trial ID to warm start from. 
config_obj["searcher"]["source_trial_id"] = first_trial_id - experiment_id2 = exp.run_basic_test_with_temp_config(config_obj, conf.fixtures_path("no_op"), 1) + experiment_id2 = exp.run_basic_test_with_temp_config( + sess, config_obj, conf.fixtures_path("no_op"), 1 + ) - trials = exp.experiment_trials(experiment_id2) + trials = exp.experiment_trials(sess, experiment_id2) assert len(trials) == 1 second_trial = trials[0] @@ -205,9 +219,9 @@ def test_noop_single_warm_start() -> None: with open(tf.name, "w") as f: util.yaml_safe_dump(config_obj, f) - experiment_id3 = exp.run_basic_test(tf.name, conf.fixtures_path("no_op"), 1) + experiment_id3 = exp.run_basic_test(sess, tf.name, conf.fixtures_path("no_op"), 1) - trials = exp.experiment_trials(experiment_id3) + trials = exp.experiment_trials(sess, experiment_id3) assert len(trials) == 1 third_trial = trials[0] @@ -220,62 +234,72 @@ def test_noop_single_warm_start() -> None: @pytest.mark.e2e_cpu def test_cancel_one_experiment() -> None: + sess = api_utils.user_session() experiment_id = exp.create_experiment( + sess, conf.fixtures_path("no_op/single-many-long-steps.yaml"), conf.fixtures_path("no_op"), ) - exp.cancel_single(experiment_id) + exp.cancel_single(sess, experiment_id) @pytest.mark.e2e_cpu def test_cancel_one_active_experiment_unready() -> None: + sess = api_utils.user_session() experiment_id = exp.create_experiment( + sess, conf.fixtures_path("no_op/single-many-long-steps.yaml"), conf.fixtures_path("no_op"), ) for _ in range(15): - if exp.experiment_has_active_workload(experiment_id): + if exp.experiment_has_active_workload(sess, experiment_id): break time.sleep(1) else: raise AssertionError("no workload active after 15 seconds") - exp.cancel_single(experiment_id, should_have_trial=True) + exp.cancel_single(sess, experiment_id, should_have_trial=True) @pytest.mark.e2e_cpu @pytest.mark.timeout(3 * 60) def test_cancel_one_active_experiment_ready() -> None: + sess = api_utils.user_session() experiment_id = 
exp.create_experiment( + sess, conf.tutorials_path("mnist_pytorch/const.yaml"), conf.tutorials_path("mnist_pytorch"), ) while 1: - if exp.experiment_has_completed_workload(experiment_id): + if exp.experiment_has_completed_workload(sess, experiment_id): break time.sleep(1) - exp.cancel_single(experiment_id, should_have_trial=True) - exp.assert_performed_final_checkpoint(experiment_id) + exp.cancel_single(sess, experiment_id, should_have_trial=True) + exp.assert_performed_final_checkpoint(sess, experiment_id) @pytest.mark.e2e_cpu def test_cancel_one_paused_experiment() -> None: + sess = api_utils.user_session() experiment_id = exp.create_experiment( + sess, conf.fixtures_path("no_op/single-many-long-steps.yaml"), conf.fixtures_path("no_op"), ["--paused"], ) - exp.cancel_single(experiment_id) + exp.cancel_single(sess, experiment_id) @pytest.mark.e2e_cpu def test_cancel_ten_experiments() -> None: + sess = api_utils.user_session() experiment_ids = [ exp.create_experiment( + sess, conf.fixtures_path("no_op/single-many-long-steps.yaml"), conf.fixtures_path("no_op"), ) @@ -283,13 +307,15 @@ def test_cancel_ten_experiments() -> None: ] for experiment_id in experiment_ids: - exp.cancel_single(experiment_id) + exp.cancel_single(sess, experiment_id) @pytest.mark.e2e_cpu def test_cancel_ten_paused_experiments() -> None: + sess = api_utils.user_session() experiment_ids = [ exp.create_experiment( + sess, conf.fixtures_path("no_op/single-many-long-steps.yaml"), conf.fixtures_path("no_op"), ["--paused"], @@ -298,12 +324,14 @@ def test_cancel_ten_paused_experiments() -> None: ] for experiment_id in experiment_ids: - exp.cancel_single(experiment_id) + exp.cancel_single(sess, experiment_id) @pytest.mark.e2e_cpu def test_startup_hook() -> None: + sess = api_utils.user_session() exp.run_basic_test( + sess, conf.fixtures_path("no_op/startup-hook.yaml"), conf.fixtures_path("no_op"), 1, @@ -312,26 +340,29 @@ def test_startup_hook() -> None: @pytest.mark.e2e_cpu def 
test_large_model_def_experiment() -> None: + sess = api_utils.user_session() with tempfile.TemporaryDirectory() as td: shutil.copy(conf.fixtures_path("no_op/model_def.py"), td) # Write a 94MB file into the directory. Use random data because it is not compressible. with open(os.path.join(td, "junk.txt"), "wb") as f: f.write(os.urandom(94 * 1024 * 1024)) - exp.run_basic_test(conf.fixtures_path("no_op/single-one-short-step.yaml"), td, 1) + exp.run_basic_test(sess, conf.fixtures_path("no_op/single-one-short-step.yaml"), td, 1) @pytest.mark.e2e_cpu def test_noop_experiment_config_override() -> None: + sess = api_utils.user_session() config_obj = conf.load_config(conf.fixtures_path("no_op/single-one-short-step.yaml")) with tempfile.NamedTemporaryFile() as tf: with open(tf.name, "w") as f: util.yaml_safe_dump(config_obj, f) experiment_id = exp.create_experiment( + sess, tf.name, conf.fixtures_path("no_op"), ["--config", "reproducibility.experiment_seed=8200"], ) - exp_config = exp.experiment_config_json(experiment_id) + exp_config = exp.experiment_config_json(sess, experiment_id) assert exp_config["reproducibility"]["experiment_seed"] == 8200 - exp.kill_single(experiment_id) + exp.kill_single(sess, experiment_id) diff --git a/e2e_tests/tests/experiment/test_noop_hpc.py b/e2e_tests/tests/experiment/test_noop_hpc.py index 34d8cd60761..e8581deef3c 100644 --- a/e2e_tests/tests/experiment/test_noop_hpc.py +++ b/e2e_tests/tests/experiment/test_noop_hpc.py @@ -4,8 +4,9 @@ import pytest -from determined.common import check, yaml +from determined.common import yaml from determined.common.api import bindings +from tests import api_utils from tests import config as conf from tests import experiment as exp @@ -14,6 +15,8 @@ @pytest.mark.e2e_pbs @pytest.mark.timeout(20 * 60) def test_noop_pause_hpc() -> None: + sess = api_utils.user_session() + # The original configuration file, which we will need to modify for HPC # clusters. 
We choose a configuration file that will create an experiment # that runs long enough to allow us to pause it after the first check @@ -23,14 +26,14 @@ def test_noop_pause_hpc() -> None: config_file = conf.fixtures_path("no_op/single-hpc.yaml") # Walk through starting, pausing, and resuming a single no-op experiment. - experiment_id = exp.create_experiment(config_file, conf.fixtures_path("no_op"), None) - exp.wait_for_experiment_state(experiment_id, bindings.experimentv1State.RUNNING) + experiment_id = exp.create_experiment(sess, config_file, conf.fixtures_path("no_op"), None) + exp.wait_for_experiment_state(sess, experiment_id, bindings.experimentv1State.RUNNING) # Wait for the only trial to get scheduled. - exp.wait_for_experiment_active_workload(experiment_id) + exp.wait_for_experiment_active_workload(sess, experiment_id) # Wait for the only trial to show progress, indicating the image is built and running. - exp.wait_for_experiment_workload_progress(experiment_id) + exp.wait_for_experiment_workload_progress(sess, experiment_id) # If we pause the experiment before it gets to write at least one checkpoint, # then we're really not testing whether the experiment can pick up from where @@ -38,30 +41,28 @@ def test_noop_pause_hpc() -> None: # starts from the beginning upon finding that are no checkpoints to start # from. Therefore, wait a while to give the experiment a chance to write at # least one checkpoint. - exp.wait_for_at_least_one_checkpoint(experiment_id) + exp.wait_for_at_least_one_checkpoint(sess, experiment_id) # Pause the experiment. Note that Determined does not currently differentiate # between a "stopping paused" and a "paused" state, so we follow this check # up by ensuring the experiment cleared all scheduled workloads. 
- exp.pause_experiment(experiment_id) - exp.wait_for_experiment_state(experiment_id, bindings.experimentv1State.PAUSED) + exp.pause_experiment(sess, experiment_id) + exp.wait_for_experiment_state(sess, experiment_id, bindings.experimentv1State.PAUSED) # Wait at most 420 seconds for the experiment to clear all workloads (each # train step should take 5 seconds). for _ in range(420): - workload_active = exp.experiment_has_active_workload(experiment_id) + workload_active = exp.experiment_has_active_workload(sess, experiment_id) if not workload_active: break else: time.sleep(1) - check.true( - not workload_active, - "The experiment cannot be paused within 420 seconds.", - ) + else: + raise ValueError("The experiment cannot be paused within 420 seconds.") # Resume the experiment and wait for completion. - exp.activate_experiment(experiment_id) - exp.wait_for_experiment_state(experiment_id, bindings.experimentv1State.COMPLETED) + exp.activate_experiment(sess, experiment_id) + exp.wait_for_experiment_state(sess, experiment_id, bindings.experimentv1State.COMPLETED) def remove_item_from_yaml_file(filename: str, item_name: str) -> str: diff --git a/e2e_tests/tests/experiment/test_pending_hpc.py b/e2e_tests/tests/experiment/test_pending_hpc.py index 6d407371c8d..13c2aa36b20 100644 --- a/e2e_tests/tests/experiment/test_pending_hpc.py +++ b/e2e_tests/tests/experiment/test_pending_hpc.py @@ -5,7 +5,6 @@ from determined.common import util from determined.common.api import bindings -from determined.common.api.bindings import experimentv1State from tests import api_utils from tests import config as conf from tests import experiment as exp @@ -15,7 +14,7 @@ # Queries the determined master for resource pool information to determine if # resource pool is suitable for this test. 
def skip_if_not_suitable_resource_pool() -> None: - session = api_utils.determined_test_session() + session = api_utils.user_session() rps = bindings.get_GetResourcePools(session) assert rps.resourcePools and len(rps.resourcePools) > 0, "missing resource pool" if ( @@ -32,6 +31,7 @@ def skip_if_not_suitable_resource_pool() -> None: @api_utils.skipif_not_hpc() def test_hpc_job_pending_reason() -> None: skip_if_not_suitable_resource_pool() + sess = api_utils.user_session() config = conf.load_config(conf.tutorials_path("mnist_pytorch/const.yaml")) config = conf.set_max_length(config, {"batches": 200}) @@ -51,41 +51,47 @@ def test_hpc_job_pending_reason() -> None: with tempfile.NamedTemporaryFile() as tf: with open(tf.name, "w") as f: util.yaml_safe_dump(config, f) - running_exp_id = exp.create_experiment(tf.name, conf.tutorials_path("mnist_pytorch"), None) + running_exp_id = exp.create_experiment( + sess, tf.name, conf.tutorials_path("mnist_pytorch"), None + ) print(f"Created experiment {running_exp_id}") - exp.wait_for_experiment_state(running_exp_id, experimentv1State.RUNNING) + exp.wait_for_experiment_state(sess, running_exp_id, bindings.experimentv1State.RUNNING) # Launch another experiment requesting 6 CPUs with tempfile.NamedTemporaryFile() as tf: with open(tf.name, "w") as f: util.yaml_safe_dump(config, f) - pending_exp_id = exp.create_experiment(tf.name, conf.tutorials_path("mnist_pytorch"), None) + pending_exp_id = exp.create_experiment( + sess, tf.name, conf.tutorials_path("mnist_pytorch"), None + ) print(f"Created experiment {pending_exp_id}") - exp.wait_for_experiment_state(pending_exp_id, experimentv1State.QUEUED) + exp.wait_for_experiment_state(sess, pending_exp_id, bindings.experimentv1State.QUEUED) print(f"Experiment {pending_exp_id} pending") # Kill the first experiment to shorten the test run. 
First wait for 60 seconds # for the pending job to have a chance to refresh the state and write out the # state reason in experiment logs time.sleep(60) - exp.kill_experiments([running_exp_id]) + exp.kill_experiments(sess, [running_exp_id]) # Make sure the second experiment will start running after the first experinemt # releases the CPUs - exp.wait_for_experiment_state(pending_exp_id, experimentv1State.RUNNING) + exp.wait_for_experiment_state(sess, pending_exp_id, bindings.experimentv1State.RUNNING) print(f"Experiment {pending_exp_id} running") # Now kill the second experiment to shorten the test run - exp.kill_experiments([pending_exp_id]) + exp.kill_experiments(sess, [pending_exp_id]) - trials = exp.experiment_trials(pending_exp_id) + trials = exp.experiment_trials(sess, pending_exp_id) print(f"Check logs for exp {pending_exp_id}") slurm_result = exp.check_if_string_present_in_trial_logs( + sess, trials[0].trial.id, "HPC job waiting to be scheduled: The job is waiting for resources to become available.", ) pbs_result = exp.check_if_string_present_in_trial_logs( + sess, trials[0].trial.id, "HPC job waiting to be scheduled: Not Running: Insufficient amount of resource: ncpus ", ) diff --git a/e2e_tests/tests/experiment/test_port_registry.py b/e2e_tests/tests/experiment/test_port_registry.py index cbfca089c95..a6164bce344 100644 --- a/e2e_tests/tests/experiment/test_port_registry.py +++ b/e2e_tests/tests/experiment/test_port_registry.py @@ -1,19 +1,20 @@ import pytest -from determined.common.api.bindings import experimentv1State +from determined.common.api import bindings +from tests import api_utils from tests import config as conf from tests import experiment as exp -from tests.cluster.test_users import logged_in_user @pytest.mark.port_registry def test_multi_trial_exp_port_registry() -> None: - logged_in_user(conf.ADMIN_CREDENTIALS) + sess = api_utils.user_session() experiment_id = exp.create_experiment( + sess, 
conf.tutorials_path("mnist_pytorch/dist_random.yaml"), conf.tutorials_path("mnist_pytorch"), ) exp.wait_for_experiment_state( - experiment_id=experiment_id, target_state=experimentv1State.COMPLETED + sess, experiment_id=experiment_id, target_state=bindings.experimentv1State.COMPLETED ) diff --git a/e2e_tests/tests/experiment/test_profiling.py b/e2e_tests/tests/experiment/test_profiling.py index 65db875d3bd..00a6152f800 100644 --- a/e2e_tests/tests/experiment/test_profiling.py +++ b/e2e_tests/tests/experiment/test_profiling.py @@ -1,12 +1,12 @@ import json import tempfile -from typing import Any, Dict, Optional, Sequence -from urllib.parse import urlencode +from typing import Any, Dict, Sequence import pytest from determined.common import api, util -from determined.common.api import authentication, bindings, certs +from determined.common.api import bindings +from tests import api_utils from tests import config as conf from tests import experiment as exp @@ -21,33 +21,34 @@ ], ) def test_streaming_observability_metrics_apis(model_def: str, timings_enabled: bool) -> None: - # TODO: refactor tests to not use cli singleton auth. 
- certs.cli_cert = certs.default_load(conf.make_master_url()) - authentication.cli_auth = authentication.Authentication(conf.make_master_url()) - + sess = api_utils.user_session() config_path = conf.fixtures_path("mnist_pytorch/const-profiling.yaml") config_obj = conf.load_config(config_path) with tempfile.NamedTemporaryFile() as tf: with open(tf.name, "w") as f: util.yaml_safe_dump(config_obj, f) - experiment_id = exp.create_experiment(tf.name, model_def) + experiment_id = exp.create_experiment(sess, tf.name, model_def) - exp.wait_for_experiment_state(experiment_id, bindings.experimentv1State.COMPLETED) - trials = exp.experiment_trials(experiment_id) + exp.wait_for_experiment_state(sess, experiment_id, bindings.experimentv1State.COMPLETED) + trials = exp.experiment_trials(sess, experiment_id) trial_id = trials[0].trial.id gpu_enabled = conf.GPU_ENABLED - request_profiling_metric_labels(trial_id, timings_enabled, gpu_enabled) + request_profiling_metric_labels(sess, trial_id, timings_enabled, gpu_enabled) if gpu_enabled: - request_profiling_system_metrics(trial_id, "gpu_util") + request_profiling_system_metrics(sess, trial_id, "gpu_util") if timings_enabled: - request_profiling_pytorch_timing_metrics(trial_id, "train_batch") - request_profiling_pytorch_timing_metrics(trial_id, "train_batch.backward", accumulated=True) + request_profiling_pytorch_timing_metrics(sess, trial_id, "train_batch") + request_profiling_pytorch_timing_metrics( + sess, trial_id, "train_batch.backward", accumulated=True + ) -def request_profiling_metric_labels(trial_id: int, timing_enabled: bool, gpu_enabled: bool) -> None: +def request_profiling_metric_labels( + sess: api.Session, trial_id: int, timing_enabled: bool, gpu_enabled: bool +) -> None: def validate_labels(labels: Sequence[Dict[str, Any]]) -> None: # Check some labels against the expected labels. Return the missing labels. 
expected = { @@ -86,9 +87,8 @@ def validate_labels(labels: Sequence[Dict[str, Any]]) -> None: f"expected completed experiment to have all labels but some are missing: {expected}" ) - with api.get( - conf.make_master_url(), - "api/v1/trials/{}/profiler/available_series".format(trial_id), + with sess.get( + f"api/v1/trials/{trial_id}/profiler/available_series", stream=True, ) as r: for line in r.iter_lines(): @@ -98,7 +98,7 @@ def validate_labels(labels: Sequence[Dict[str, Any]]) -> None: return -def request_profiling_system_metrics(trial_id: int, metric_name: str) -> None: +def request_profiling_system_metrics(sess: api.Session, trial_id: int, metric_name: str) -> None: def validate_gpu_metric_batch(batch: Dict[str, Any]) -> None: num_values = len(batch["values"]) num_batch_indexes = len(batch["batches"]) @@ -111,12 +111,12 @@ def validate_gpu_metric_batch(batch: Dict[str, Any]) -> None: if num_values == 0: pytest.fail(f"received batch of size 0, something went wrong: {batch}") - with api.get( - conf.make_master_url(), - "api/v1/trials/{}/profiler/metrics?{}".format( - trial_id, - to_query_params(PROFILER_METRIC_TYPE_SYSTEM, metric_name), - ), + with sess.get( + f"api/v1/trials/{trial_id}/profiler/metrics", + params={ + "labels.name": metric_name, + "labels.metricType": PROFILER_METRIC_TYPE_SYSTEM, + }, stream=True, ) as r: have_batch = False @@ -129,7 +129,7 @@ def validate_gpu_metric_batch(batch: Dict[str, Any]) -> None: def request_profiling_pytorch_timing_metrics( - trial_id: int, metric_name: str, accumulated: bool = False + sess: api.Session, trial_id: int, metric_name: str, accumulated: bool = False ) -> None: def validate_timing_batch(batch: Dict[str, Any], batch_idx: int) -> int: values = batch["values"] @@ -162,12 +162,12 @@ def validate_timing_batch(batch: Dict[str, Any], batch_idx: int) -> int: return int(batches[-1]) + 1 - with api.get( - conf.make_master_url(), - "api/v1/trials/{}/profiler/metrics?{}".format( - trial_id, - 
to_query_params(PROFILER_METRIC_TYPE_TIMING, metric_name), - ), + with sess.get( + f"api/v1/trials/{trial_id}/profiler/metrics", + params={ + "labels.name": metric_name, + "labels.metricType": PROFILER_METRIC_TYPE_TIMING, + }, stream=True, ) as r: batch_idx = 0 @@ -182,12 +182,3 @@ def validate_timing_batch(batch: Dict[str, Any], batch_idx: int) -> int: PROFILER_METRIC_TYPE_SYSTEM = "PROFILER_METRIC_TYPE_SYSTEM" PROFILER_METRIC_TYPE_TIMING = "PROFILER_METRIC_TYPE_TIMING" - - -def to_query_params(metric_type: str, metric_name: Optional[str] = None) -> str: - return urlencode( - { - "labels.name": metric_name, - "labels.metricType": metric_type, - } - ) diff --git a/e2e_tests/tests/experiment/test_pytorch.py b/e2e_tests/tests/experiment/test_pytorch.py index 6f065f481b1..9daa0fbd673 100644 --- a/e2e_tests/tests/experiment/test_pytorch.py +++ b/e2e_tests/tests/experiment/test_pytorch.py @@ -2,6 +2,7 @@ import pytest +from tests import api_utils from tests import config as conf from tests import experiment as exp @@ -9,10 +10,13 @@ @pytest.mark.parallel @pytest.mark.e2e_slurm_gpu def test_pytorch_gradient_aggregation() -> None: + sess = api_utils.user_session() config = conf.load_config(conf.fixtures_path("pytorch_identity/distributed.yaml")) - exp_id = exp.run_basic_test_with_temp_config(config, conf.fixtures_path("pytorch_identity"), 1) - trials = exp.experiment_trials(exp_id) + exp_id = exp.run_basic_test_with_temp_config( + sess, config, conf.fixtures_path("pytorch_identity"), 1 + ) + trials = exp.experiment_trials(sess, exp_id) assert len(trials) == 1 workloads = exp.workloads_with_validation(trials[0].workloads) actual_weights = [] diff --git a/e2e_tests/tests/experiment/test_tf_keras.py b/e2e_tests/tests/experiment/test_tf_keras.py index f0229d8b906..76fa60adaeb 100644 --- a/e2e_tests/tests/experiment/test_tf_keras.py +++ b/e2e_tests/tests/experiment/test_tf_keras.py @@ -4,25 +4,28 @@ import pytest from determined import keras +from determined.common import api 
from determined.experimental import client +from tests import api_utils from tests import config as conf from tests import experiment as exp -def _export_and_load_model(experiment_id: int, master_url: str) -> None: +def _export_and_load_model(sess: api.Session, experiment_id: int, master_url: str) -> None: # Normally verifying that we can load a model would be a good unit test, but making this an e2e # test ensures that our model saving and loading works with all the versions of tf that we test. - ckpt = client.Determined(master_url).get_experiment(experiment_id).top_checkpoint() + ckpt = client.Determined._from_session(sess).get_experiment(experiment_id).top_checkpoint() _ = keras.load_model_from_checkpoint_path(ckpt.download()) -def export_and_load_model(experiment_id: int) -> None: +def export_and_load_model(sess: api.Session, experiment_id: int) -> None: # We run this in a subprocess to avoid module name collisions # when performing checkpoint export of different models. ctx = multiprocessing.get_context("spawn") p = ctx.Process( target=_export_and_load_model, args=( + sess, experiment_id, conf.make_master_url(), ), @@ -37,6 +40,7 @@ def export_and_load_model(experiment_id: int) -> None: def test_tf_keras_parallel( aggregation_frequency: int, collect_trial_profiles: Callable[[int], None] ) -> None: + sess = api_utils.user_session() config = conf.load_config(conf.cv_examples_path("iris_tf_keras/const.yaml")) config = conf.set_slots_per_trial(config, 8) config = conf.set_max_length(config, {"batches": 200}) @@ -45,13 +49,13 @@ def test_tf_keras_parallel( config = conf.set_profiling_enabled(config) experiment_id = exp.run_basic_test_with_temp_config( - config, conf.cv_examples_path("iris_tf_keras"), 1 + sess, config, conf.cv_examples_path("iris_tf_keras"), 1 ) - trials = exp.experiment_trials(experiment_id) + trials = exp.experiment_trials(sess, experiment_id) assert len(trials) == 1 # Test exporting a checkpoint. 
- export_and_load_model(experiment_id) + export_and_load_model(sess, experiment_id) collect_trial_profiles(trials[0].trial.id) # Check on record/batch counts we emitted in logs. @@ -68,4 +72,4 @@ def test_tf_keras_parallel( f"trained: {scheduling_unit * global_batch_size} records.*in {scheduling_unit} batches", f"validated: {validation_size} records.*in {exp_val_batches} batches", ] - exp.assert_patterns_in_trial_logs(trials[0].trial.id, patterns) + exp.assert_patterns_in_trial_logs(sess, trials[0].trial.id, patterns) diff --git a/e2e_tests/tests/experiment/test_unmanaged.py b/e2e_tests/tests/experiment/test_unmanaged.py index c3a2da0b14d..188d4a812e9 100644 --- a/e2e_tests/tests/experiment/test_unmanaged.py +++ b/e2e_tests/tests/experiment/test_unmanaged.py @@ -54,7 +54,7 @@ def test_unmanaged_checkpoints() -> None: {"DET_TEST_EXTERNAL_EXP_ID": external_id}, ) - sess = api_utils.determined_test_session() + sess = api_utils.user_session() exps = bindings.get_GetExperiments(sess, limit=-1).experiments exps = [exp for exp in exps if exp.externalExperimentId == external_id] assert len(exps) == 1 diff --git a/e2e_tests/tests/filetree.py b/e2e_tests/tests/filetree.py index ef812710f92..bc6263d63de 100644 --- a/e2e_tests/tests/filetree.py +++ b/e2e_tests/tests/filetree.py @@ -1,27 +1,27 @@ +import pathlib import shutil import tempfile -from pathlib import Path -from types import TracebackType +from types import TracebackType # noqa:I2041 from typing import ContextManager, Dict, Optional, Type, Union -class FileTree(ContextManager[Path]): +class FileTree(ContextManager[pathlib.Path]): """ FileTree creates a set of files with their contents in their subdirectories and cleans them up later. """ - def __init__(self, tmp_path: Path, files: Dict[Union[Path, str], str]) -> None: + def __init__(self, tmp_path: pathlib.Path, files: Dict[Union[pathlib.Path, str], str]) -> None: """ Creates a file tree in tempdir with the given filenames and contents. 
""" self.tmp_path = tmp_path - self.files = {Path(k): v for k, v in files.items()} - self.dir = None # type: Optional[Path] + self.files = {pathlib.Path(k): v for k, v in files.items()} + self.dir = None # type: Optional[pathlib.Path] - def __enter__(self) -> Path: + def __enter__(self) -> pathlib.Path: """Creates FileTree and returns the root directory of the FileTree.""" - self.dir = Path(tempfile.mkdtemp(dir=str(self.tmp_path))) + self.dir = pathlib.Path(tempfile.mkdtemp(dir=str(self.tmp_path))) try: for name, contents in self.files.items(): p = self.dir.joinpath(name) diff --git a/e2e_tests/tests/fixtures/mnist_pytorch/profiling.py b/e2e_tests/tests/fixtures/mnist_pytorch/profiling.py index 6d71cd0daea..4d017a787e1 100644 --- a/e2e_tests/tests/fixtures/mnist_pytorch/profiling.py +++ b/e2e_tests/tests/fixtures/mnist_pytorch/profiling.py @@ -4,7 +4,6 @@ import determined as det from determined import pytorch -from determined.common.api import certs def run(): @@ -13,9 +12,6 @@ def run(): info = det.get_cluster_info() assert info, "Test must be run on cluster." - # TODO: refactor profiling to to not use the cli_cert. 
- certs.cli_cert = certs.default_load(info.master_url) - with pytorch.init() as train_context: trial = mnist_pytorch.MNistTrial(train_context, hparams=info.trial.hparams) trainer = pytorch.Trainer(trial, train_context) diff --git a/e2e_tests/tests/job/test_rbac.py b/e2e_tests/tests/job/test_rbac.py index 914474e81fa..986e80e0f97 100644 --- a/e2e_tests/tests/job/test_rbac.py +++ b/e2e_tests/tests/job/test_rbac.py @@ -1,86 +1,95 @@ -from typing import Dict +from typing import Callable, Dict, List import pytest import tests.config as conf from determined.common import api -from determined.common.api import NTSC_Kind, bindings, errors -from tests import api_utils +from determined.common.api import bindings, errors +from tests import api_utils, detproc from tests import experiment as exp -from tests.cluster import test_rbac as rbac -from tests.cluster.test_rbac import ( - create_users_with_gloabl_roles, - create_workspaces_with_users, - rbac_disabled, -) -from tests.cluster.test_users import det_run, logged_in_user +from tests.cluster import test_rbac def seed_workspace(ws: bindings.v1Workspace) -> None: """set up each workspace with project, exp, and one of each ntsc""" - admin_session = api_utils.determined_test_session(admin=True) + admin = api_utils.admin_session() pid = bindings.post_PostProject( - admin_session, + admin, body=bindings.v1PostProjectRequest(name="test", workspaceId=ws.id), workspaceId=ws.id, ).project.id - with logged_in_user(conf.ADMIN_CREDENTIALS): - print("creating experiment") - experiment_id = exp.create_experiment( - conf.fixtures_path("no_op/single-very-many-long-steps.yaml"), - conf.fixtures_path("no_op"), - ["--project_id", str(pid)], - ) - print(f"created experiment {experiment_id}") + print("creating experiment") + experiment_id = exp.create_experiment( + admin, + conf.fixtures_path("no_op/single-very-many-long-steps.yaml"), + conf.fixtures_path("no_op"), + ["--project_id", str(pid)], + ) + print(f"created experiment {experiment_id}") + 
for kind in conf.ALL_NTSC: print(f"creating {kind}") - ntsc = api_utils.launch_ntsc( - admin_session, workspace_id=ws.id, typ=kind, exp_id=experiment_id - ) + ntsc = api_utils.launch_ntsc(admin, workspace_id=ws.id, typ=kind, exp_id=experiment_id) print(f"created {kind} {ntsc.id}") @pytest.mark.e2e_cpu_rbac -@pytest.mark.skipif(rbac_disabled(), reason="ee rbac is required for this test") +@api_utils.skipif_rbac_not_enabled() def test_job_global_perm() -> None: - with logged_in_user(conf.ADMIN_CREDENTIALS): - experiment_id = exp.create_experiment( - conf.fixtures_path("no_op/single.yaml"), - conf.fixtures_path("no_op"), - ["--project_id", str(1)], - ) - output = det_run(["job", "ls"]) - assert str(experiment_id) in str(output) + admin = api_utils.admin_session() + experiment_id = exp.create_experiment( + admin, + conf.fixtures_path("no_op/single.yaml"), + conf.fixtures_path("no_op"), + ["--project_id", str(1)], + ) + output = detproc.check_output(admin, ["det", "job", "ls"]) + assert str(experiment_id) in output + + +def run_permission_tests( + action: Callable[[api.Session], None], cases: List[test_rbac.PermCase] +) -> None: + for cred, raises in cases: + if raises is None: + action(cred) + else: + with pytest.raises(raises): + action(cred) @pytest.mark.e2e_cpu_rbac -@pytest.mark.skipif( - rbac.strict_q_control_disabled(), - reason="ee, rbac, " + "and strict q control are required for this test", -) +@api_utils.skipif_strict_q_control_not_enabled() def test_job_strict_q_control() -> None: - [cadmin] = create_users_with_gloabl_roles([["ClusterAdmin"]]) - - with create_workspaces_with_users( + admin = api_utils.admin_session() + cadmin, _ = api_utils.create_test_user() + api_utils.assign_user_role( + session=admin, + user=cadmin.username, + role="ClusterAdmin", + workspace=None, + ) + + with test_rbac.create_workspaces_with_users( [ [ (0, ["Editor"]), ], ] ) as (workspaces, creds): - session = api_utils.determined_test_session(creds[0]) - r = 
api_utils.launch_ntsc(session, typ=NTSC_Kind.command, workspace_id=workspaces[0].id) + r = api_utils.launch_ntsc( + creds[0], typ=api.NTSC_Kind.command, workspace_id=workspaces[0].id + ) cases = [ - rbac.PermCase(creds[0], errors.ForbiddenException), - rbac.PermCase(cadmin, None), + test_rbac.PermCase(creds[0], errors.ForbiddenException), + test_rbac.PermCase(cadmin, None), ] - def action(cred: api.authentication.Credentials) -> None: - session = api_utils.determined_test_session(cred) + def action(sess: api.Session) -> None: bindings.post_UpdateJobQueue( - session, + sess, body=bindings.v1UpdateJobQueueRequest( updates=[ bindings.v1QueueControl(jobId=r.jobId, priority=3), @@ -88,13 +97,13 @@ def action(cred: api.authentication.Credentials) -> None: ), ) - rbac.run_permission_tests(action, cases) + run_permission_tests(action, cases) @pytest.mark.e2e_cpu_rbac -@pytest.mark.skipif(rbac_disabled(), reason="ee rbac is required for this test") +@api_utils.skipif_rbac_not_enabled() def test_job_filtering() -> None: - with create_workspaces_with_users( + with test_rbac.create_workspaces_with_users( [ [ (0, ["Viewer", "Editor"]), @@ -114,8 +123,9 @@ def test_job_filtering() -> None: jobs_per_ws = 5 max_jobs = jobs_per_ws * len(workspaces) - expectations: Dict[api.authentication.Credentials, int] = { - conf.ADMIN_CREDENTIALS: max_jobs, + admin = api_utils.admin_session() + expectations: Dict[api.Session, int] = { + admin: max_jobs, creds[0]: max_jobs, creds[1]: jobs_per_ws, creds[2]: jobs_per_ws, @@ -124,14 +134,14 @@ def test_job_filtering() -> None: workspace_ids = {ws.id for ws in workspaces} for cred, visible_count in expectations.items(): - v1_jobs = bindings.get_GetJobs(api_utils.determined_test_session(cred)).jobs + v1_jobs = bindings.get_GetJobs(cred).jobs # filterout jobs from other workspaces as the cluster is shared between tests v1_jobs = [j for j in v1_jobs if j.workspaceId in workspace_ids] assert ( len(v1_jobs) == visible_count ), f"expected {visible_count} 
jobs for {cred}. {v1_jobs}" - jobs = bindings.get_GetJobsV2(api_utils.determined_test_session(cred)).jobs + jobs = bindings.get_GetJobsV2(cred).jobs full_jobs = [ j for j in jobs if j.full is not None and j.full.workspaceId in workspace_ids ] diff --git a/e2e_tests/tests/model_hub/test_mmdetection.py b/e2e_tests/tests/model_hub/test_mmdetection.py index ec047c31ca8..8c14a226155 100644 --- a/e2e_tests/tests/model_hub/test_mmdetection.py +++ b/e2e_tests/tests/model_hub/test_mmdetection.py @@ -4,6 +4,7 @@ import pytest +from tests import api_utils from tests import config as conf from tests import experiment as exp @@ -27,7 +28,7 @@ def test_maskrcnn_distributed_fake() -> None: config = conf.set_max_length(config, {"batches": 200}) config = set_docker_image(config) - exp.run_basic_test_with_temp_config(config, example_path, 1) + exp.run_basic_test_with_temp_config(api_utils.user_session(), config, example_path, 1) @pytest.mark.model_hub_mmdetection @@ -41,7 +42,7 @@ def test_fasterrcnn_distributed_fake() -> None: config, "config_file", "/mmdetection/configs/faster_rcnn/faster_rcnn_r50_fpn_1x_coco.py" ) - exp.run_basic_test_with_temp_config(config, example_path, 1) + exp.run_basic_test_with_temp_config(api_utils.user_session(), config, example_path, 1) @pytest.mark.model_hub_mmdetection @@ -55,7 +56,7 @@ def test_retinanet_distributed_fake() -> None: config, "config_file", "/mmdetection/configs/retinanet/retinanet_r50_fpn_1x_coco.py" ) - exp.run_basic_test_with_temp_config(config, example_path, 1) + exp.run_basic_test_with_temp_config(api_utils.user_session(), config, example_path, 1) @pytest.mark.model_hub_mmdetection @@ -69,7 +70,7 @@ def test_gfl_distributed_fake() -> None: config, "config_file", "/mmdetection/configs/gfl/gfl_r50_fpn_1x_coco.py" ) - exp.run_basic_test_with_temp_config(config, example_path, 1) + exp.run_basic_test_with_temp_config(api_utils.user_session(), config, example_path, 1) @pytest.mark.model_hub_mmdetection @@ -83,7 +84,7 @@ def 
test_yolo_distributed_fake() -> None: config, "config_file", "/mmdetection/configs/yolo/yolov3_d53_320_273e_coco.py" ) - exp.run_basic_test_with_temp_config(config, example_path, 1) + exp.run_basic_test_with_temp_config(api_utils.user_session(), config, example_path, 1) @pytest.mark.model_hub_mmdetection @@ -96,4 +97,4 @@ def test_detr_distributed_fake() -> None: config, "config_file", "/mmdetection/configs/detr/detr_r50_8x2_150e_coco.py" ) - exp.run_basic_test_with_temp_config(config, example_path, 1) + exp.run_basic_test_with_temp_config(api_utils.user_session(), config, example_path, 1) diff --git a/e2e_tests/tests/nightly/compute_stats.py b/e2e_tests/tests/nightly/compute_stats.py index d6143dbb228..5d6d5994411 100644 --- a/e2e_tests/tests/nightly/compute_stats.py +++ b/e2e_tests/tests/nightly/compute_stats.py @@ -1,15 +1,13 @@ +import datetime import re -import subprocess import traceback -from datetime import datetime, timedelta, timezone from typing import Any, Dict, Tuple from dateutil import parser from determined.common import api from determined.common.api import bindings -from tests import api_utils -from tests import config as conf +from tests import api_utils, detproc ADD_KEY = "adding" REMOVE_KEY = "removing" @@ -66,12 +64,12 @@ def parse_log_for_gpu_stats(log_path: str) -> Tuple[int, str, str]: agent_uptime_hours = (end - start) / 3600 print(f"{agent_id}: {agent_uptime_hours} hours") - global_start = datetime.fromtimestamp(min_ts, tz=timezone(timedelta(hours=0))).strftime( - "%Y-%m-%dT%H:%M:%S.000Z" - ) - global_end = datetime.fromtimestamp(max_ts, tz=timezone(timedelta(hours=0))).strftime( - "%Y-%m-%dT%H:%M:%S.000Z" - ) + global_start = datetime.datetime.fromtimestamp( + min_ts, tz=datetime.timezone(datetime.timedelta(hours=0)) + ).strftime("%Y-%m-%dT%H:%M:%S.000Z") + global_end = datetime.datetime.fromtimestamp( + max_ts, tz=datetime.timezone(datetime.timedelta(hours=0)) + ).strftime("%Y-%m-%dT%H:%M:%S.000Z") print(f"\nMaster log time period: 
{global_start} to {global_end} \n") print(f"Total agent up seconds: {total_agent_uptime_sec} ") return total_agent_uptime_sec, global_start, global_end @@ -80,26 +78,26 @@ def parse_log_for_gpu_stats(log_path: str) -> Tuple[int, str, str]: log_path = "/tmp/det-master.log" -def fetch_master_log() -> bool: - command = ["det", "-m", conf.make_master_url(), "master", "logs"] +def fetch_master_log(sess: api.Session) -> bool: + command = ["det", "master", "logs"] try: - output = subprocess.check_output(command) + output = detproc.check_output(sess, command) except Exception: traceback.print_exc() return False - with open(log_path, "wb") as log: + with open(log_path, "w") as log: log.write(output) return True -def compare_stats() -> None: - if not fetch_master_log(): +def compare_stats(sess: api.Session) -> None: + if not fetch_master_log(sess): print("Skip compare stats because error at fetch master") return gpu_from_log, global_start, global_end = parse_log_for_gpu_stats(log_path) try: res = bindings.get_ResourceAllocationRaw( - api_utils.determined_test_session(), + sess, timestampAfter=global_start, timestampBefore=global_end, ) diff --git a/e2e_tests/tests/nightly/test_convergence.py b/e2e_tests/tests/nightly/test_convergence.py index b0e5bd13a94..b0a9359f596 100644 --- a/e2e_tests/tests/nightly/test_convergence.py +++ b/e2e_tests/tests/nightly/test_convergence.py @@ -2,47 +2,52 @@ import pytest -from determined.experimental import client as _client +from determined.experimental import client +from tests import api_utils from tests import config as conf from tests import experiment as exp -def _get_validation_metrics(client: _client.Determined, trial_id: int) -> List[Dict[str, Any]]: - return [m.metrics for m in client.stream_trials_validation_metrics([trial_id])] +def _get_validation_metrics(detobj: client.Determined, trial_id: int) -> List[Dict[str, Any]]: + return [m.metrics for m in detobj.stream_trials_validation_metrics([trial_id])] @pytest.mark.nightly -def 
test_mnist_pytorch_accuracy(client: _client.Determined) -> None: +def test_mnist_pytorch_accuracy() -> None: + sess = api_utils.user_session() config = conf.load_config(conf.tutorials_path("mnist_pytorch/const.yaml")) experiment_id = exp.run_basic_test_with_temp_config( - config, conf.tutorials_path("mnist_pytorch"), 1 + sess, config, conf.tutorials_path("mnist_pytorch"), 1 ) - trials = exp.experiment_trials(experiment_id) - validations = _get_validation_metrics(client, trials[0].trial.id) + trials = exp.experiment_trials(sess, experiment_id) + detobj = client.Determined._from_session(sess) + validations = _get_validation_metrics(detobj, trials[0].trial.id) validation_accuracies = [v["accuracy"] for v in validations] target_accuracy = 0.97 assert max(validation_accuracies) > target_accuracy, ( - "mnist_pytorch did not reach minimum target accuracy {}." - " full validation accuracy history: {}".format(target_accuracy, validation_accuracies) + f"mnist_pytorch did not reach minimum target accuracy {target_accuracy}." 
+ f" full validation accuracy history: {validation_accuracies}" ) @pytest.mark.nightly -def test_hf_trainer_api_accuracy(client: _client.Determined) -> None: +def test_hf_trainer_api_accuracy() -> None: + sess = api_utils.user_session() test_dir = "hf_image_classification" config = conf.load_config(conf.hf_trainer_examples_path(f"{test_dir}/const.yaml")) experiment_id = exp.run_basic_test_with_temp_config( - config, conf.hf_trainer_examples_path(test_dir), 1 + sess, config, conf.hf_trainer_examples_path(test_dir), 1 ) - trials = exp.experiment_trials(experiment_id) - validations = _get_validation_metrics(client, trials[0].trial.id) + trials = exp.experiment_trials(sess, experiment_id) + detobj = client.Determined._from_session(sess) + validations = _get_validation_metrics(detobj, trials[0].trial.id) validation_accuracies = [v["eval_accuracy"] for v in validations] target_accuracy = 0.82 assert max(validation_accuracies) > target_accuracy, ( - "hf_trainer_api did not reach minimum target accuracy {}." - " full validation accuracy history: {}".format(target_accuracy, validation_accuracies) + f"hf_trainer_api did not reach minimum target accuracy {target_accuracy}." 
+ f" full validation accuracy history: {validation_accuracies}" ) diff --git a/e2e_tests/tests/nightly/test_distributed.py b/e2e_tests/tests/nightly/test_distributed.py index 3e8a9080378..abf3e320ed1 100644 --- a/e2e_tests/tests/nightly/test_distributed.py +++ b/e2e_tests/tests/nightly/test_distributed.py @@ -6,6 +6,7 @@ import pytest from determined.common import util +from tests import api_utils from tests import config as conf from tests import experiment as exp @@ -13,6 +14,7 @@ @pytest.mark.distributed @pytest.mark.parametrize("image_type", ["PT", "TF2", "PT2"]) def test_mnist_pytorch_distributed(image_type: str) -> None: + sess = api_utils.user_session() config = conf.load_config(conf.tutorials_path("mnist_pytorch/distributed.yaml")) config = conf.set_max_length(config, {"batches": 200}) @@ -25,26 +27,29 @@ def test_mnist_pytorch_distributed(image_type: str) -> None: else: warnings.warn("Using default images", stacklevel=2) - exp.run_basic_test_with_temp_config(config, conf.tutorials_path("mnist_pytorch"), 1) + exp.run_basic_test_with_temp_config(sess, config, conf.tutorials_path("mnist_pytorch"), 1) @pytest.mark.distributed def test_mnist_pytorch_set_stop_requested_distributed() -> None: + sess = api_utils.user_session() config = conf.load_config(conf.fixtures_path("mnist_pytorch/distributed-stop-requested.yaml")) - exp.run_basic_test_with_temp_config(config, conf.fixtures_path("mnist_pytorch"), 1) + exp.run_basic_test_with_temp_config(sess, config, conf.fixtures_path("mnist_pytorch"), 1) @pytest.mark.distributed @pytest.mark.gpu_required def test_hf_trainer_api_integration() -> None: + sess = api_utils.user_session() test_dir = "hf_image_classification" config = conf.load_config(conf.hf_trainer_examples_path(f"{test_dir}/distributed.yaml")) - exp.run_basic_test_with_temp_config(config, conf.hf_trainer_examples_path(test_dir), 1) + exp.run_basic_test_with_temp_config(sess, config, conf.hf_trainer_examples_path(test_dir), 1) @pytest.mark.deepspeed 
@pytest.mark.gpu_required def test_gpt_neox_zero1() -> None: + sess = api_utils.user_session() config = conf.load_config(conf.deepspeed_examples_path("gpt_neox/zero1.yaml")) config = conf.set_max_length(config, {"batches": 100}) config = conf.set_min_validation_period(config, {"batches": 100}) @@ -53,7 +58,7 @@ def test_gpt_neox_zero1() -> None: config["hyperparameters"]["conf_file"] = ["350M.yml", "determined_cluster.yml"] config["hyperparameters"]["overwrite_values"]["train_batch_size"] = 32 - exp.run_basic_test_with_temp_config(config, conf.deepspeed_examples_path("gpt_neox"), 1) + exp.run_basic_test_with_temp_config(sess, config, conf.deepspeed_examples_path("gpt_neox"), 1) HUGGINGFACE_CONTEXT_ERR_MSG = """ @@ -80,6 +85,7 @@ def test_textual_inversion_stable_diffusion_finetune() -> None: The Hugging Face account details can be found at github.com/determined-ai/secrets/blob/master/ci/hugging_face.txt """ + sess = api_utils.user_session() config = conf.load_config( conf.diffusion_examples_path( "textual_inversion_stable_diffusion/finetune_const_advanced.yaml" @@ -91,7 +97,7 @@ def test_textual_inversion_stable_diffusion_finetune() -> None: config, [f'HF_AUTH_TOKEN={os.environ["HF_READ_ONLY_TOKEN"]}'] ) exp.run_basic_test_with_temp_config( - config, conf.diffusion_examples_path("textual_inversion_stable_diffusion"), 1 + sess, config, conf.diffusion_examples_path("textual_inversion_stable_diffusion"), 1 ) except KeyError as k: if str(k) == "'HF_READ_ONLY_TOKEN'": @@ -112,6 +118,7 @@ def test_textual_inversion_stable_diffusion_generate() -> None: The Hugging Face account details can be found at github.com/determined-ai/secrets/blob/master/ci/hugging_face.txt """ + sess = api_utils.user_session() config = conf.load_config( conf.diffusion_examples_path("textual_inversion_stable_diffusion/generate_grid.yaml") ) @@ -127,7 +134,7 @@ def test_textual_inversion_stable_diffusion_generate() -> None: config, [f'HF_AUTH_TOKEN={os.environ["HF_READ_ONLY_TOKEN"]}'] ) 
exp.run_basic_test_with_temp_config( - config, conf.diffusion_examples_path("textual_inversion_stable_diffusion"), 2 + sess, config, conf.diffusion_examples_path("textual_inversion_stable_diffusion"), 2 ) except KeyError as k: if str(k) == "'HF_READ_ONLY_TOKEN'": @@ -139,6 +146,7 @@ def test_textual_inversion_stable_diffusion_generate() -> None: @pytest.mark.distributed @pytest.mark.gpu_required def test_hf_trainer_image_classification_deepspeed_autotuning() -> None: + sess = api_utils.user_session() test_dir = "hf_image_classification" config_path = conf.hf_trainer_examples_path(f"{test_dir}/deepspeed.yaml") config = conf.load_config(config_path) @@ -148,6 +156,7 @@ def test_hf_trainer_image_classification_deepspeed_autotuning() -> None: # expected_trials=1 in run_basic_autotuning_test because the search runner only generates # a single trial (which in turn generates a second, possibly multi-trial experiment). _ = exp.run_basic_autotuning_test( + sess, tf.name, conf.hf_trainer_examples_path(test_dir), 1, @@ -158,6 +167,7 @@ def test_hf_trainer_image_classification_deepspeed_autotuning() -> None: @pytest.mark.distributed @pytest.mark.gpu_required def test_hf_trainer_language_modeling_deepspeed_autotuning() -> None: + sess = api_utils.user_session() test_dir = "hf_language_modeling" config_path = conf.hf_trainer_examples_path(f"{test_dir}/deepspeed.yaml") config = conf.load_config(config_path) @@ -167,6 +177,7 @@ def test_hf_trainer_language_modeling_deepspeed_autotuning() -> None: # expected_trials=1 in run_basic_autotuning_test because the search runner only generates # a single trial (which in turn generates a second, possibly multi-trial experiment). 
_ = exp.run_basic_autotuning_test( + sess, tf.name, conf.hf_trainer_examples_path(test_dir), 1, @@ -177,6 +188,7 @@ def test_hf_trainer_language_modeling_deepspeed_autotuning() -> None: @pytest.mark.distributed @pytest.mark.gpu_required def test_torchvision_core_api_deepspeed_autotuning() -> None: + sess = api_utils.user_session() test_dir = "torchvision/core_api" config_path = conf.deepspeed_autotune_examples_path(f"{test_dir}/deepspeed.yaml") config = conf.load_config(config_path) @@ -186,6 +198,7 @@ def test_torchvision_core_api_deepspeed_autotuning() -> None: # expected_trials=1 in run_basic_autotuning_test because the search runner only generates # a single trial (which in turn generates a second, possibly multi-trial experiment). _ = exp.run_basic_autotuning_test( + sess, tf.name, conf.deepspeed_autotune_examples_path(test_dir), 1, @@ -196,6 +209,7 @@ def test_torchvision_core_api_deepspeed_autotuning() -> None: @pytest.mark.distributed @pytest.mark.gpu_required def test_torchvision_deepspeed_trial_deepspeed_autotuning() -> None: + sess = api_utils.user_session() test_dir = "torchvision/deepspeed_trial" config_path = conf.deepspeed_autotune_examples_path(f"{test_dir}/deepspeed.yaml") config = conf.load_config(config_path) @@ -205,6 +219,7 @@ def test_torchvision_deepspeed_trial_deepspeed_autotuning() -> None: # expected_trials=1 in run_basic_autotuning_test because the search runner only generates # a single trial (which in turn generates a second, possibly multi-trial experiment). 
_ = exp.run_basic_autotuning_test( + sess, tf.name, conf.deepspeed_autotune_examples_path(test_dir), 1, @@ -215,6 +230,7 @@ def test_torchvision_deepspeed_trial_deepspeed_autotuning() -> None: @pytest.mark.distributed @pytest.mark.gpu_required def test_torch_batch_process_generate_embedding() -> None: + sess = api_utils.user_session() config = conf.load_config( conf.features_examples_path("torch_batch_process_embeddings/distributed.yaml") ) @@ -225,4 +241,4 @@ def test_torch_batch_process_generate_embedding() -> None: conf.features_examples_path("torch_batch_process_embeddings"), copy_destination, ) - exp.run_basic_test_with_temp_config(config, copy_destination, 1) + exp.run_basic_test_with_temp_config(sess, config, copy_destination, 1) diff --git a/e2e_tests/tests/nightly/test_pytorch2.py b/e2e_tests/tests/nightly/test_pytorch2.py index a18da90f364..919ebf47688 100644 --- a/e2e_tests/tests/nightly/test_pytorch2.py +++ b/e2e_tests/tests/nightly/test_pytorch2.py @@ -1,5 +1,6 @@ import pytest +from tests import api_utils from tests import config as conf from tests import experiment as exp @@ -8,6 +9,7 @@ @pytest.mark.gpu_required @pytest.mark.e2e_slurm_gpu def test_pytorch2_hf_language_modeling_distributed() -> None: + sess = api_utils.user_session() test_dir = "hf_language_modeling" config = conf.load_config(conf.hf_trainer_examples_path(f"{test_dir}/distributed.yaml")) @@ -23,4 +25,4 @@ def test_pytorch2_hf_language_modeling_distributed() -> None: .replace("--per_device_eval_batch_size 8", "--per_device_eval_batch_size 2"), ) - exp.run_basic_test_with_temp_config(config, conf.hf_trainer_examples_path(test_dir), 1) + exp.run_basic_test_with_temp_config(sess, config, conf.hf_trainer_examples_path(test_dir), 1) diff --git a/e2e_tests/tests/requirements.txt b/e2e_tests/tests/requirements.txt index b0b3749d88b..5a6ed3754ed 100644 --- a/e2e_tests/tests/requirements.txt +++ b/e2e_tests/tests/requirements.txt @@ -2,7 +2,6 @@ appdirs # pytest 6.0 has linter-breaking changes 
pytest>=6.0.1 pytest-timeout -pexpect torch==1.11.0 torchvision==0.12.0 tensorflow==2.12.0; sys_platform != 'darwin' or platform_machine != 'arm64' diff --git a/e2e_tests/tests/task/test_generic_tasks.py b/e2e_tests/tests/task/test_generic_tasks.py index 87ac81eb299..3a39a65e817 100644 --- a/e2e_tests/tests/task/test_generic_tasks.py +++ b/e2e_tests/tests/task/test_generic_tasks.py @@ -1,5 +1,3 @@ -import subprocess - import pytest from determined.cli import ntsc @@ -7,6 +5,7 @@ from determined.common.api import bindings from tests import api_utils from tests import config as conf +from tests import detproc from tests.task import task @@ -15,6 +14,7 @@ def test_create_generic_task() -> None: """ Start a simple task with a context directory called from the task CLI """ + sess = api_utils.user_session() command = [ "det", "-m", @@ -26,13 +26,12 @@ def test_create_generic_task() -> None: conf.fixtures_path("generic_task"), ] - res = subprocess.run(command, universal_newlines=True, stdout=subprocess.PIPE, check=True) + output = detproc.check_output(sess, command) - id_index = res.stdout.find("Created task ") - task_id = res.stdout[id_index + len("Created task ") :].strip() + id_index = output.find("Created task ") + task_id = output[id_index + len("Created task ") :].strip() - test_session = api_utils.determined_test_session() - task.wait_for_task_state(test_session, task_id, bindings.v1GenericTaskState.COMPLETED) + task.wait_for_task_state(sess, task_id, bindings.v1GenericTaskState.COMPLETED) @pytest.mark.e2e_cpu @@ -40,7 +39,7 @@ def test_generic_task_completion() -> None: """ Start a simple task and check for task completion """ - test_session = api_utils.determined_test_session() + sess = api_utils.user_session() with open(conf.fixtures_path("generic_task/test_config.yaml"), "r") as config_file: # Create task @@ -55,10 +54,10 @@ def test_generic_task_completion() -> None: inheritContext=False, noPause=False, ) - task_resp = 
bindings.post_CreateGenericTask(test_session, body=req) + task_resp = bindings.post_CreateGenericTask(sess, body=req) # Check for complete state - task.wait_for_task_state(test_session, task_resp.taskId, bindings.v1GenericTaskState.COMPLETED) + task.wait_for_task_state(sess, task_resp.taskId, bindings.v1GenericTaskState.COMPLETED) @pytest.mark.e2e_cpu @@ -66,7 +65,7 @@ def test_create_generic_task_error() -> None: """ Start a simple task that fails and check for error task state """ - test_session = api_utils.determined_test_session() + sess = api_utils.user_session() with open(conf.fixtures_path("generic_task/test_config_error.yaml"), "r") as config_file: # Create task @@ -81,10 +80,10 @@ def test_create_generic_task_error() -> None: inheritContext=False, noPause=False, ) - task_resp = bindings.post_CreateGenericTask(test_session, body=req) + task_resp = bindings.post_CreateGenericTask(sess, body=req) # Check for error state - task.wait_for_task_state(test_session, task_resp.taskId, bindings.v1GenericTaskState.ERROR) + task.wait_for_task_state(sess, task_resp.taskId, bindings.v1GenericTaskState.ERROR) @pytest.mark.e2e_cpu @@ -92,7 +91,7 @@ def test_generic_task_config() -> None: """ Start a simple task without a context directory and grab its config """ - test_session = api_utils.determined_test_session() + sess = api_utils.user_session() with open(conf.fixtures_path("generic_task/test_config.yaml"), "r") as config_file: # Create task @@ -107,18 +106,18 @@ def test_generic_task_config() -> None: inheritContext=False, noPause=False, ) - task_resp = bindings.post_CreateGenericTask(test_session, body=req) + task_resp = bindings.post_CreateGenericTask(sess, body=req) # Get config command = ["det", "-m", conf.make_master_url(), "task", "config", task_resp.taskId] - res = subprocess.run(command, universal_newlines=True, stdout=subprocess.PIPE, check=True) + output = detproc.check_output(sess, command) - result_config = util.yaml_safe_load(res.stdout) + result_config = 
util.yaml_safe_load(output) expected_config = {"entrypoint": ["echo", "task ran"]} assert result_config == expected_config - task.wait_for_task_state(test_session, task_resp.taskId, bindings.v1GenericTaskState.COMPLETED) + task.wait_for_task_state(sess, task_resp.taskId, bindings.v1GenericTaskState.COMPLETED) @pytest.mark.e2e_cpu @@ -126,7 +125,7 @@ def test_generic_task_create_with_fork() -> None: """ Start a simple task without a context directory and grab its config """ - test_session = api_utils.determined_test_session() + sess = api_utils.user_session() with open(conf.fixtures_path("generic_task/test_config.yaml"), "r") as config_file: # Create initial task @@ -142,7 +141,7 @@ def test_generic_task_create_with_fork() -> None: inheritContext=False, noPause=False, ) - task_resp = bindings.post_CreateGenericTask(test_session, body=req) + task_resp = bindings.post_CreateGenericTask(sess, body=req) # Create fork task with open(conf.fixtures_path("generic_task/test_config_fork.yaml"), "r") as fork_config_file: @@ -158,20 +157,18 @@ def test_generic_task_create_with_fork() -> None: inheritContext=False, noPause=False, ) - fork_task_resp = bindings.post_CreateGenericTask(test_session, body=req) + fork_task_resp = bindings.post_CreateGenericTask(sess, body=req) # Get fork task Config command = ["det", "-m", conf.make_master_url(), "task", "config", fork_task_resp.taskId] - res = subprocess.run(command, universal_newlines=True, stdout=subprocess.PIPE, check=True) - result_config = util.yaml_safe_load(res.stdout) + output = detproc.check_output(sess, command) + result_config = util.yaml_safe_load(output) expected_config = {"entrypoint": ["echo", "forked"]} assert result_config == expected_config - task.wait_for_task_state(test_session, task_resp.taskId, bindings.v1GenericTaskState.COMPLETED) - task.wait_for_task_state( - test_session, fork_task_resp.taskId, bindings.v1GenericTaskState.COMPLETED - ) + task.wait_for_task_state(sess, task_resp.taskId, 
bindings.v1GenericTaskState.COMPLETED) + task.wait_for_task_state(sess, fork_task_resp.taskId, bindings.v1GenericTaskState.COMPLETED) @pytest.mark.e2e_cpu @@ -179,7 +176,7 @@ def test_kill_generic_task() -> None: """ Start a simple task without a context directory and grab its config """ - test_session = api_utils.determined_test_session() + sess = api_utils.user_session() with open(conf.fixtures_path("generic_task/test_config.yaml"), "r") as config_file: # Create task @@ -195,15 +192,15 @@ def test_kill_generic_task() -> None: inheritContext=False, noPause=False, ) - task_resp = bindings.post_CreateGenericTask(test_session, body=req) + task_resp = bindings.post_CreateGenericTask(sess, body=req) # Kill task command = ["det", "-m", conf.make_master_url(), "task", "kill", task_resp.taskId] - subprocess.run(command, universal_newlines=True, stdout=subprocess.PIPE, check=True) + detproc.check_call(sess, command) - bindings.get_GetTask(test_session, taskId=task_resp.taskId) - task.wait_for_task_state(test_session, task_resp.taskId, bindings.v1GenericTaskState.CANCELED) + bindings.get_GetTask(sess, taskId=task_resp.taskId) + task.wait_for_task_state(sess, task_resp.taskId, bindings.v1GenericTaskState.CANCELED) @pytest.mark.e2e_cpu @@ -211,7 +208,7 @@ def test_pause_and_unpause_generic_task() -> None: """ Start a simple task without a context directory and grab its config """ - test_session = api_utils.determined_test_session() + sess = api_utils.user_session() with open(conf.fixtures_path("generic_task/test_config_pause.yaml"), "r") as config_file: # Create task @@ -227,22 +224,22 @@ def test_pause_and_unpause_generic_task() -> None: inheritContext=False, noPause=False, ) - task_resp = bindings.post_CreateGenericTask(test_session, body=req) + task_resp = bindings.post_CreateGenericTask(sess, body=req) # Pause task command = ["det", "-m", conf.make_master_url(), "task", "pause", task_resp.taskId] - subprocess.run(command, universal_newlines=True, stdout=subprocess.PIPE, 
check=True) + detproc.check_call(sess, command) - pause_resp = bindings.get_GetTask(test_session, taskId=task_resp.taskId) + pause_resp = bindings.get_GetTask(sess, taskId=task_resp.taskId) assert pause_resp.task.taskState == bindings.v1GenericTaskState.PAUSED # Unpause task command = ["det", "-m", conf.make_master_url(), "task", "unpause", task_resp.taskId] - subprocess.run(command, universal_newlines=True, stdout=subprocess.PIPE, check=True) + detproc.check_call(sess, command) - unpause_resp = bindings.get_GetTask(test_session, taskId=task_resp.taskId) + unpause_resp = bindings.get_GetTask(sess, taskId=task_resp.taskId) assert unpause_resp.task.taskState == bindings.v1GenericTaskState.ACTIVE - task.wait_for_task_state(test_session, task_resp.taskId, bindings.v1GenericTaskState.COMPLETED) + task.wait_for_task_state(sess, task_resp.taskId, bindings.v1GenericTaskState.COMPLETED) diff --git a/e2e_tests/tests/template/template.py b/e2e_tests/tests/template/template.py index 9062baaf9b9..84e1764efe5 100644 --- a/e2e_tests/tests/template/template.py +++ b/e2e_tests/tests/template/template.py @@ -2,36 +2,37 @@ import re import subprocess -from tests import config as conf +from determined.common import api +from tests import detproc -def set_template(template_name: str, template_file: str) -> str: - completed_process = maybe_set_template(template_name, template_file) +def set_template(sess: api.Session, template_name: str, template_file: str) -> str: + completed_process = maybe_set_template(sess, template_name, template_file) assert completed_process.returncode == 0 m = re.search(r"Set template (\w+)", str(completed_process.stdout)) assert m is not None return str(m.group(1)) -def maybe_set_template(template_name: str, template_file: str) -> subprocess.CompletedProcess: +def maybe_set_template( + sess: api.Session, template_name: str, template_file: str +) -> subprocess.CompletedProcess: command = [ "det", - "-m", - conf.make_master_url(), "template", "set", 
template_name, os.path.join(os.path.dirname(__file__), template_file), ] - return subprocess.run(command, universal_newlines=True, stdout=subprocess.PIPE) + return detproc.run(sess, command, universal_newlines=True, stdout=subprocess.PIPE) -def describe_template(template_name: str) -> str: - completed_process = maybe_describe_template(template_name) +def describe_template(sess: api.Session, template_name: str) -> str: + completed_process = maybe_describe_template(sess, template_name) assert completed_process.returncode == 0 return str(completed_process.stdout) -def maybe_describe_template(template_name: str) -> subprocess.CompletedProcess: - command = ["det", "-m", conf.make_master_url(), "template", "describe", template_name] - return subprocess.run(command, universal_newlines=True, stdout=subprocess.PIPE) +def maybe_describe_template(sess: api.Session, template_name: str) -> subprocess.CompletedProcess: + command = ["det", "template", "describe", template_name] + return detproc.run(sess, command, universal_newlines=True, stdout=subprocess.PIPE) diff --git a/e2e_tests/tests/template/test_template.py b/e2e_tests/tests/template/test_template.py index 9008d317681..3c6f1f1f18a 100644 --- a/e2e_tests/tests/template/test_template.py +++ b/e2e_tests/tests/template/test_template.py @@ -1,24 +1,24 @@ -from typing import Optional, Tuple +from typing import Optional import pytest -from determined.common import util -from determined.common.api import NTSC_Kind, Session, bindings, errors +from determined.common import api, util +from determined.common.api import bindings, errors from tests import api_utils from tests import command as cmd from tests import config as conf from tests import experiment as exp from tests import template as tpl -from tests.cluster import test_rbac as rbac -from tests.cluster import test_users as user +from tests.cluster import test_rbac @pytest.mark.e2e_cpu def test_set_template() -> None: + sess = api_utils.user_session() template_name = 
"test_set_template" template_path = conf.fixtures_path("templates/template.yaml") - tpl.set_template(template_name, template_path) - config = util.yaml_safe_load(tpl.describe_template(template_name)) + tpl.set_template(sess, template_name, template_path) + config = util.yaml_safe_load(tpl.describe_template(sess, template_name)) assert config == conf.load_config(template_path) @@ -26,39 +26,42 @@ def test_set_template() -> None: @pytest.mark.e2e_cpu @pytest.mark.e2e_cpu_cross_version def test_start_notebook_with_template() -> None: + sess = api_utils.user_session() template_name = "test_start_notebook_with_template" - tpl.set_template(template_name, conf.fixtures_path("templates/ntsc.yaml")) + tpl.set_template(sess, template_name, conf.fixtures_path("templates/ntsc.yaml")) with cmd.interactive_command( - "notebook", "start", "--template", template_name, "--detach" + sess, ["notebook", "start", "--template", template_name, "--detach"] ) as nb: - assert "SHOULDBE=SET" in cmd.get_command_config("notebook", str(nb.task_id)) + assert "SHOULDBE=SET" in cmd.get_command_config(sess, "notebook", str(nb.task_id)) @pytest.mark.slow @pytest.mark.e2e_cpu @pytest.mark.e2e_cpu_cross_version def test_start_command_with_template() -> None: + sess = api_utils.user_session() template_name = "test_start_command_with_template" - tpl.set_template(template_name, conf.fixtures_path("templates/ntsc.yaml")) + tpl.set_template(sess, template_name, conf.fixtures_path("templates/ntsc.yaml")) with cmd.interactive_command( - "command", "run", "--template", template_name, "--detach", "sleep infinity" + sess, ["command", "run", "--template", template_name, "--detach", "sleep infinity"] ) as command: - assert "SHOULDBE=SET" in cmd.get_command_config("command", str(command.task_id)) + assert "SHOULDBE=SET" in cmd.get_command_config(sess, "command", str(command.task_id)) @pytest.mark.slow @pytest.mark.e2e_cpu @pytest.mark.e2e_cpu_cross_version def test_start_shell_with_template() -> None: + sess = 
api_utils.user_session() template_name = "test_start_shell_with_template" - tpl.set_template(template_name, conf.fixtures_path("templates/ntsc.yaml")) + tpl.set_template(sess, template_name, conf.fixtures_path("templates/ntsc.yaml")) with cmd.interactive_command( - "shell", "start", "--template", template_name, "--detach" + sess, ["shell", "start", "--template", template_name, "--detach"] ) as shell: - assert "SHOULDBE=SET" in cmd.get_command_config("shell", str(shell.task_id)) + assert "SHOULDBE=SET" in cmd.get_command_config(sess, "shell", str(shell.task_id)) def assert_templates_equal(t1: bindings.v1Template, t2: bindings.v1Template) -> None: @@ -68,11 +71,10 @@ def assert_templates_equal(t1: bindings.v1Template, t2: bindings.v1Template) -> def setup_template_test( - session: Optional[Session] = None, + sess: api.Session, workspace_id: Optional[int] = None, name: str = "template", -) -> Tuple[Session, bindings.v1Template]: - session = api_utils.determined_test_session() if session is None else session +) -> bindings.v1Template: tpl = bindings.v1Template( name=api_utils.get_random_string(), config=conf.load_config(conf.fixtures_path(f"templates/{name}.yaml")), @@ -80,50 +82,54 @@ def setup_template_test( ) # create - resp = bindings.post_PostTemplate(session, body=tpl, template_name=tpl.name) + resp = bindings.post_PostTemplate(sess, body=tpl, template_name=tpl.name) assert_templates_equal(tpl, resp.template) - return (session, tpl) + return tpl @pytest.mark.e2e_cpu def test_create_template() -> None: - setup_template_test() + sess = api_utils.user_session() + setup_template_test(sess) @pytest.mark.e2e_cpu def test_read_template() -> None: - session, tpl = setup_template_test() + sess = api_utils.user_session() + tpl = setup_template_test(sess) # read - resp = bindings.get_GetTemplate(session, templateName=tpl.name) + resp = bindings.get_GetTemplate(sess, templateName=tpl.name) assert_templates_equal(tpl, resp.template) @pytest.mark.e2e_cpu def 
test_update_template() -> None: - session, tpl = setup_template_test() + sess = api_utils.user_session() + tpl = setup_template_test(sess) # update tpl.config["description"] = "updated description" - resp = bindings.patch_PatchTemplateConfig(session, body=tpl.config, templateName=tpl.name) + resp = bindings.patch_PatchTemplateConfig(sess, body=tpl.config, templateName=tpl.name) assert_templates_equal(tpl, resp.template) @pytest.mark.e2e_cpu def test_delete_template() -> None: - session, tpl = setup_template_test() + sess = api_utils.user_session() + tpl = setup_template_test(sess) # delete - bindings.delete_DeleteTemplate(session, templateName=tpl.name) + bindings.delete_DeleteTemplate(sess, templateName=tpl.name) with pytest.raises(errors.NotFoundException): - bindings.get_GetTemplate(session, templateName=tpl.name) + bindings.get_GetTemplate(sess, templateName=tpl.name) pytest.fail("template should have been deleted") @pytest.mark.e2e_cpu_rbac -@pytest.mark.skipif(rbac.rbac_disabled(), reason="ee rbac is required for this test") +@api_utils.skipif_rbac_not_enabled() def test_rbac_template_create() -> None: - with rbac.create_workspaces_with_users( + with test_rbac.create_workspaces_with_users( [ [ # can create (0, ["Editor"]), @@ -133,18 +139,18 @@ def test_rbac_template_create() -> None: (0, ["Viewer"]), ], ] - ) as (workspaces, creds): - for uid in creds: - setup_template_test(api_utils.determined_test_session(creds[uid]), workspaces[0].id) + ) as (workspaces, sessions): + for sess in sessions.values(): + setup_template_test(sess, workspaces[0].id) with pytest.raises(errors.ForbiddenException): - setup_template_test(api_utils.determined_test_session(creds[uid]), workspaces[1].id) + setup_template_test(sess, workspaces[1].id) @pytest.mark.e2e_cpu_rbac -@pytest.mark.skipif(rbac.rbac_disabled(), reason="ee rbac is required for this test") +@api_utils.skipif_rbac_not_enabled() def test_rbac_template_delete() -> None: - admin_session = 
api_utils.determined_test_session(conf.ADMIN_CREDENTIALS) - with rbac.create_workspaces_with_users( + admin = api_utils.admin_session() + with test_rbac.create_workspaces_with_users( [ [ # can delete (0, ["Editor"]), @@ -155,29 +161,23 @@ def test_rbac_template_delete() -> None: (1, []), ], ] - ) as (workspaces, creds): - for uid in creds: - _, tpl = setup_template_test(admin_session, workspaces[0].id) - bindings.delete_DeleteTemplate( - api_utils.determined_test_session(creds[uid]), templateName=tpl.name - ) + ) as (workspaces, sessions): + for sess in sessions.values(): + tpl = setup_template_test(admin, workspaces[0].id) + bindings.delete_DeleteTemplate(sess, templateName=tpl.name) - _, tpl = setup_template_test(admin_session, workspaces[1].id) + tpl = setup_template_test(admin, workspaces[1].id) with pytest.raises(errors.ForbiddenException): - bindings.delete_DeleteTemplate( - api_utils.determined_test_session(creds[0]), templateName=tpl.name - ) + bindings.delete_DeleteTemplate(sessions[0], templateName=tpl.name) with pytest.raises(errors.NotFoundException): - bindings.delete_DeleteTemplate( - api_utils.determined_test_session(creds[1]), templateName=tpl.name - ) + bindings.delete_DeleteTemplate(sessions[1], templateName=tpl.name) @pytest.mark.e2e_cpu_rbac -@pytest.mark.skipif(rbac.rbac_disabled(), reason="ee rbac is required for this test") +@api_utils.skipif_rbac_not_enabled() def test_rbac_template_view() -> None: - admin_session = api_utils.determined_test_session(conf.ADMIN_CREDENTIALS) - with rbac.create_workspaces_with_users( + admin = api_utils.admin_session() + with test_rbac.create_workspaces_with_users( [ [ # can view (0, ["Editor"]), @@ -185,21 +185,20 @@ def test_rbac_template_view() -> None: ], [], # none can view ] - ) as (workspaces, creds): - _, tpl0 = setup_template_test(admin_session, workspaces[0].id) - _, tpl1 = setup_template_test(admin_session, workspaces[1].id) - for uid in creds: - usession = 
api_utils.determined_test_session(creds[uid]) - bindings.get_GetTemplate(usession, templateName=tpl0.name) + ) as (workspaces, sessions): + tpl0 = setup_template_test(admin, workspaces[0].id) + tpl1 = setup_template_test(admin, workspaces[1].id) + for sess in sessions.values(): + bindings.get_GetTemplate(sess, templateName=tpl0.name) with pytest.raises(errors.NotFoundException): - bindings.get_GetTemplate(usession, templateName=tpl1.name) + bindings.get_GetTemplate(sess, templateName=tpl1.name) @pytest.mark.e2e_cpu_rbac -@pytest.mark.skipif(rbac.rbac_disabled(), reason="ee rbac is required for this test") +@api_utils.skipif_rbac_not_enabled() def test_rbac_template_patch_config() -> None: - admin_session = api_utils.determined_test_session(conf.ADMIN_CREDENTIALS) - with rbac.create_workspaces_with_users( + admin = api_utils.admin_session() + with test_rbac.create_workspaces_with_users( [ [ # can update (0, ["Editor"]), @@ -210,30 +209,30 @@ def test_rbac_template_patch_config() -> None: (1, ["Viewer"]), ], ] - ) as (workspaces, creds): - _, tpl0 = setup_template_test(admin_session, workspaces[0].id) - _, tpl1 = setup_template_test(admin_session, workspaces[1].id) - for uid in creds: + ) as (workspaces, sessions): + tpl0 = setup_template_test(admin, workspaces[0].id) + tpl1 = setup_template_test(admin, workspaces[1].id) + for sess in sessions.values(): tpl0.config["description"] = "updated description" bindings.patch_PatchTemplateConfig( - api_utils.determined_test_session(creds[uid]), + sess, body=tpl0.config, templateName=tpl0.name, ) with pytest.raises(errors.ForbiddenException): bindings.patch_PatchTemplateConfig( - api_utils.determined_test_session(creds[uid]), + sess, body=tpl1.config, templateName=tpl1.name, ) @pytest.mark.e2e_cpu_rbac -@pytest.mark.skipif(rbac.rbac_disabled(), reason="ee rbac is required for this test") +@api_utils.skipif_rbac_not_enabled() @pytest.mark.parametrize("kind", conf.ALL_NTSC) -def test_rbac_template_ntsc_create(kind: NTSC_Kind) 
-> None: - admin_session = api_utils.determined_test_session(conf.ADMIN_CREDENTIALS) - with rbac.create_workspaces_with_users( +def test_rbac_template_ntsc_create(kind: api.NTSC_Kind) -> None: + admin = api_utils.admin_session() + with test_rbac.create_workspaces_with_users( [ [ (0, ["Editor"]), @@ -241,41 +240,40 @@ def test_rbac_template_ntsc_create(kind: NTSC_Kind) -> None: ], [], ] - ) as (workspaces, creds): - _, tpl0 = setup_template_test(admin_session, workspaces[0].id, name="ntsc") - _, tpl1 = setup_template_test(admin_session, workspaces[1].id, name="ntsc") + ) as (workspaces, sessions): + tpl0 = setup_template_test(admin, workspaces[0].id, name="ntsc") + tpl1 = setup_template_test(admin, workspaces[1].id, name="ntsc") experiment_id = None pid = bindings.post_PostProject( - admin_session, + admin, body=bindings.v1PostProjectRequest(name="test", workspaceId=workspaces[0].id), workspaceId=workspaces[0].id, ).project.id - with user.logged_in_user(conf.ADMIN_CREDENTIALS): - experiment_id = exp.create_experiment( - conf.fixtures_path("no_op/single.yaml"), - conf.fixtures_path("no_op"), - ["--project_id", str(pid)], - ) - - for uid in creds: - usession = api_utils.determined_test_session(creds[uid]) + experiment_id = exp.create_experiment( + admin, + conf.fixtures_path("no_op/single.yaml"), + conf.fixtures_path("no_op"), + ["--project_id", str(pid)], + ) + + for sess in sessions.values(): api_utils.launch_ntsc( - usession, workspaces[0].id, kind, exp_id=experiment_id, template=tpl0.name + sess, workspaces[0].id, kind, exp_id=experiment_id, template=tpl0.name ) e = None with pytest.raises(errors.APIException) as e: api_utils.launch_ntsc( - usession, workspaces[0].id, kind, exp_id=experiment_id, template=tpl1.name + sess, workspaces[0].id, kind, exp_id=experiment_id, template=tpl1.name ) assert e.value.status_code == 404, e.value.message @pytest.mark.e2e_cpu_rbac -@pytest.mark.skipif(rbac.rbac_disabled(), reason="ee rbac is required for this test") 
+@api_utils.skipif_rbac_not_enabled() def test_rbac_template_exp_create() -> None: - admin_session = api_utils.determined_test_session(conf.ADMIN_CREDENTIALS) - with rbac.create_workspaces_with_users( + admin = api_utils.admin_session() + with test_rbac.create_workspaces_with_users( [ [ (0, ["Editor"]), @@ -283,27 +281,28 @@ def test_rbac_template_exp_create() -> None: ], [], ] - ) as (workspaces, creds): - _, tpl0 = setup_template_test(admin_session, workspaces[0].id) - _, tpl1 = setup_template_test(admin_session, workspaces[1].id) + ) as (workspaces, sessions): + tpl0 = setup_template_test(admin, workspaces[0].id) + tpl1 = setup_template_test(admin, workspaces[1].id) pid = bindings.post_PostProject( - admin_session, + admin, body=bindings.v1PostProjectRequest(name="test", workspaceId=workspaces[0].id), workspaceId=workspaces[0].id, ).project.id - for uid in creds: - with user.logged_in_user(creds[uid]): - exp.create_experiment( - conf.fixtures_path("no_op/single.yaml"), - conf.fixtures_path("no_op"), - ["--project_id", str(pid), "--template", tpl0.name], - ) - proc = exp.maybe_create_experiment( - conf.fixtures_path("no_op/single.yaml"), - conf.fixtures_path("no_op"), - ["--project_id", str(pid), "--template", tpl1.name], - ) - assert proc.returncode == 1 - assert "not found" in proc.stderr + for sess in sessions.values(): + exp.create_experiment( + sess, + conf.fixtures_path("no_op/single.yaml"), + conf.fixtures_path("no_op"), + ["--project_id", str(pid), "--template", tpl0.name], + ) + proc = exp.maybe_create_experiment( + sess, + conf.fixtures_path("no_op/single.yaml"), + conf.fixtures_path("no_op"), + ["--project_id", str(pid), "--template", tpl1.name], + ) + assert proc.returncode == 1 + assert "not found" in proc.stderr diff --git a/e2e_tests/tests/test_sdk.py b/e2e_tests/tests/test_sdk.py index 082f9829386..581448fa840 100644 --- a/e2e_tests/tests/test_sdk.py +++ b/e2e_tests/tests/test_sdk.py @@ -8,14 +8,16 @@ from determined.common import util from 
determined.common.api import bindings, errors -from determined.common.experimental import resource_pool -from determined.common.experimental.metrics import TrialMetrics -from determined.experimental import client as _client +from determined.experimental import client +from tests import api_utils from tests import config as conf @pytest.mark.e2e_cpu -def test_completed_experiment_and_checkpoint_apis(client: _client.Determined) -> None: +def test_completed_experiment_and_checkpoint_apis() -> None: + sess = api_utils.user_session() + detobj = client.Determined._from_session(sess) + with open(conf.fixtures_path("no_op/single-one-short-step.yaml")) as f: config = util.yaml_safe_load(f) config["hyperparameters"]["num_validation_metrics"] = 2 @@ -23,10 +25,10 @@ def test_completed_experiment_and_checkpoint_apis(client: _client.Determined) -> emptydir = tempfile.mkdtemp() try: model_def = conf.fixtures_path("no_op/model_def.py") - exp = client.create_experiment(config, emptydir, includes=[model_def]) + exp = detobj.create_experiment(config, emptydir, includes=[model_def]) finally: os.rmdir(emptydir) - exp = client.create_experiment(config, conf.fixtures_path("no_op")) + exp = detobj.create_experiment(config, conf.fixtures_path("no_op")) # Await first trial is safe to call before a trial has started. trial = exp.await_first_trial() @@ -34,7 +36,7 @@ def test_completed_experiment_and_checkpoint_apis(client: _client.Determined) -> # .logs(follow=True) block until the trial completes. 
all_logs = list(trial.logs(follow=True)) - assert exp.wait() == _client.ExperimentState.COMPLETED + assert exp.wait() == client.ExperimentState.COMPLETED assert all_logs == list(trial.logs()) assert list(trial.logs(head=10)) == all_logs[:10] @@ -42,7 +44,7 @@ def test_completed_experiment_and_checkpoint_apis(client: _client.Determined) -> trials = exp.get_trials() assert len(trials) == 1, trials - assert client.get_trial(trial.id).id == trial.id + assert detobj.get_trial(trial.id).id == trial.id ckpt = trial.top_checkpoint() @@ -66,7 +68,7 @@ def test_completed_experiment_and_checkpoint_apis(client: _client.Determined) -> assert exp.top_checkpoint().uuid == ckpt.uuid assert ckpt.uuid in (c.uuid for c in exp.top_n_checkpoints(100)) - assert client.get_checkpoint(ckpt.uuid).uuid == ckpt.uuid + assert detobj.get_checkpoint(ckpt.uuid).uuid == ckpt.uuid # Adding checkpoint metadata. ckpt.add_metadata({"newkey": "newvalue"}) @@ -74,25 +76,27 @@ def test_completed_experiment_and_checkpoint_apis(client: _client.Determined) -> assert ckpt.metadata assert ckpt.metadata["newkey"] == "newvalue" # Database should be updated. 
- ckpt = client.get_checkpoint(ckpt.uuid) + ckpt = detobj.get_checkpoint(ckpt.uuid) assert ckpt.metadata assert ckpt.metadata["newkey"] == "newvalue" # Removing checkpoint metadata ckpt.remove_metadata(["newkey"]) assert "newkey" not in ckpt.metadata - ckpt = client.get_checkpoint(ckpt.uuid) + ckpt = detobj.get_checkpoint(ckpt.uuid) assert ckpt.metadata assert "newkey" not in ckpt.metadata # Test creation of experiment without a model definition with open(conf.fixtures_path("no_op/empty_model_dir.yaml")) as f: config = util.yaml_safe_load(f) - exp = client.create_experiment(config) + exp = detobj.create_experiment(config) @pytest.mark.e2e_cpu -def test_checkpoint_apis(client: _client.Determined) -> None: +def test_checkpoint_apis() -> None: + sess = api_utils.user_session() + detobj = client.Determined._from_session(sess) with open(conf.fixtures_path("no_op/single-default-ckpt.yaml")) as f: config = util.yaml_safe_load(f) @@ -102,12 +106,12 @@ def test_checkpoint_apis(client: _client.Determined) -> None: config["checkpoint_storage"] = {} config["checkpoint_storage"]["save_trial_best"] = 10 - exp = client.create_experiment(config, conf.fixtures_path("no_op")) + exp = detobj.create_experiment(config, conf.fixtures_path("no_op")) # Await first trial is safe to call before a trial has started. trial = exp.await_first_trial() - assert exp.wait() == _client.ExperimentState.COMPLETED + assert exp.wait() == client.ExperimentState.COMPLETED trials = exp.get_trials() assert len(trials) == 1, trials @@ -116,14 +120,14 @@ def test_checkpoint_apis(client: _client.Determined) -> None: # Validate end (report) time sorting. checkpoints = trial.get_checkpoints( - sort_by=_client.CheckpointSortBy.END_TIME, order_by=_client.OrderBy.DESC + sort_by=client.CheckpointSortBy.END_TIME, order_by=client.OrderBy.DESC ) end_times = [checkpoint.report_time for checkpoint in checkpoints] assert all(x >= y for x, y in zip(end_times, end_times[1:])) # type: ignore # Validate state sorting. 
checkpoints = trial.get_checkpoints( - sort_by=_client.CheckpointSortBy.STATE, order_by=_client.OrderBy.ASC + sort_by=client.CheckpointSortBy.STATE, order_by=client.OrderBy.ASC ) states = [] for checkpoint in checkpoints: @@ -133,14 +137,14 @@ def test_checkpoint_apis(client: _client.Determined) -> None: # Validate UUID sorting. checkpoints = trial.get_checkpoints( - sort_by=_client.CheckpointSortBy.UUID, order_by=_client.OrderBy.ASC + sort_by=client.CheckpointSortBy.UUID, order_by=client.OrderBy.ASC ) uuids = [checkpoint.uuid for checkpoint in checkpoints] assert all(x <= y for x, y in zip(uuids, uuids[1:])) # Validate batch number sorting. checkpoints = trial.get_checkpoints( - sort_by=_client.CheckpointSortBy.BATCH_NUMBER, order_by=_client.OrderBy.DESC + sort_by=client.CheckpointSortBy.BATCH_NUMBER, order_by=client.OrderBy.DESC ) batch_numbers = [] for checkpoint in checkpoints: @@ -149,7 +153,7 @@ def test_checkpoint_apis(client: _client.Determined) -> None: assert all(x >= y for x, y in zip(batch_numbers, batch_numbers[1:])) # Validate metric sorting. 
- checkpoints = trial.get_checkpoints(sort_by="validation_error", order_by=_client.OrderBy.ASC) + checkpoints = trial.get_checkpoints(sort_by="validation_error", order_by=client.OrderBy.ASC) validation_metrics = [ checkpoint.training.validation_metrics["avgMetrics"]["validation_error"] # type: ignore for checkpoint in checkpoints @@ -160,7 +164,7 @@ def test_checkpoint_apis(client: _client.Determined) -> None: checkpoints = [ checkpoint for checkpoint in checkpoints - if checkpoint.state == _client.CheckpointState.COMPLETED + if checkpoint.state == client.CheckpointState.COMPLETED ] assert len(checkpoints) == 10 @@ -176,7 +180,7 @@ def test_checkpoint_apis(client: _client.Determined) -> None: deleted_checkpoints = [ checkpoint for checkpoint in checkpoints - if checkpoint.state == _client.CheckpointState.DELETED + if checkpoint.state == client.CheckpointState.DELETED ] if deleted_checkpoints: break @@ -199,7 +203,7 @@ def test_checkpoint_apis(client: _client.Determined) -> None: partially_deleted_checkpoints = [ checkpoint for checkpoint in checkpoints - if checkpoint.state == _client.CheckpointState.PARTIALLY_DELETED + if checkpoint.state == client.CheckpointState.PARTIALLY_DELETED ] if partially_deleted_checkpoints: break @@ -231,7 +235,7 @@ def test_checkpoint_apis(client: _client.Determined) -> None: deleted_checkpoints = [ checkpoint for checkpoint in checkpoints - if checkpoint.state == _client.CheckpointState.DELETED + if checkpoint.state == client.CheckpointState.DELETED and checkpoint.uuid == partially_deleted_checkpoint.uuid ] if deleted_checkpoints: @@ -240,12 +244,12 @@ def test_checkpoint_apis(client: _client.Determined) -> None: time.sleep(0.1) -def _make_live_experiment(client: _client.Determined) -> _client.Experiment: +def _make_live_experiment(detobj: client.Determined) -> client.Experiment: # Create an experiment that takes a long time to run with open(conf.fixtures_path("no_op/single-very-many-long-steps.yaml")) as f: config = 
util.yaml_safe_load(f) - exp = client.create_experiment(config, conf.fixtures_path("no_op")) + exp = detobj.create_experiment(config, conf.fixtures_path("no_op")) # Wait for a trial to actually start. start = time.time() deadline = start + 90 @@ -260,8 +264,10 @@ def _make_live_experiment(client: _client.Determined) -> _client.Experiment: @pytest.mark.e2e_cpu -def test_experiment_manipulation(client: _client.Determined) -> None: - exp = _make_live_experiment(client) +def test_experiment_manipulation() -> None: + sess = api_utils.user_session() + detobj = client.Determined._from_session(sess) + exp = _make_live_experiment(detobj) exp.pause() with pytest.raises(ValueError, match="Make sure the experiment is active"): @@ -271,7 +277,7 @@ def test_experiment_manipulation(client: _client.Determined) -> None: exp.activate() exp.cancel() - assert exp.wait() == _client.ExperimentState.CANCELED + assert exp.wait() == client.ExperimentState.CANCELED assert isinstance(exp.config, dict) @@ -280,37 +286,39 @@ def test_experiment_manipulation(client: _client.Determined) -> None: deleting_exp = exp # Create another experiment and kill it. - exp = _make_live_experiment(client) + exp = _make_live_experiment(detobj) exp.kill() - assert exp.wait() == _client.ExperimentState.CANCELED + assert exp.wait() == client.ExperimentState.CANCELED # Test remaining methods exp.archive() - assert bindings.get_GetExperiment(client._session, experimentId=exp.id).experiment.archived + assert bindings.get_GetExperiment(sess, experimentId=exp.id).experiment.archived exp.unarchive() - assert not bindings.get_GetExperiment(client._session, experimentId=exp.id).experiment.archived + assert not bindings.get_GetExperiment(sess, experimentId=exp.id).experiment.archived # Create another experiment and kill its trial. 
- exp = _make_live_experiment(client) + exp = _make_live_experiment(detobj) trial = exp.get_trials()[0] trial.kill() - assert exp.wait() == _client.ExperimentState.CANCELED + assert exp.wait() == client.ExperimentState.CANCELED # Make sure that the experiment we deleted earlier does actually delete. with pytest.raises(errors.APIException): for _ in range(300): - client.get_experiment(deleting_exp.id).get_trials() + detobj.get_experiment(deleting_exp.id).get_trials() time.sleep(0.1) @pytest.mark.e2e_cpu -def test_models(client: _client.Determined) -> None: +def test_models() -> None: + sess = api_utils.user_session() + detobj = client.Determined._from_session(sess) suffix = [random.sample("abcdefghijklmnopqrstuvwxyz", 1) for _ in range(10)] model_name = f"test-model-{suffix}" - model = client.create_model(model_name) + model = detobj.create_model(model_name) try: - assert model_name in (m.name for m in client.get_models()) + assert model_name in (m.name for m in detobj.get_models()) model.archive() model.unarchive() @@ -321,12 +329,12 @@ def test_models(client: _client.Determined) -> None: model.set_description("modeldescr") # Check cached values - assert set(client.get_model_labels()) == set(labels) + assert set(detobj.get_model_labels()) == set(labels) assert model.metadata == {"a": 1, "b": 2, "c": 3}, model.metadata assert model.description == "modeldescr", model.description # avoid false-positives due to caching on the model object itself - model = client.get_model(model_name) + model = detobj.get_model(model_name) assert model.labels assert set(model.labels) == set(labels) assert model.metadata == {"a": 1, "b": 2, "c": 3}, model.metadata @@ -337,7 +345,7 @@ def test_models(client: _client.Determined) -> None: # break the cache again, testing get_model_by_id. 
assert model.model_id is not None, "model_id was populated by create_model" - model = client.get_model_by_id(model.model_id) + model = detobj.get_model_by_id(model.model_id) assert model.labels == [] assert model.metadata == {"c": 3}, model.metadata @@ -345,16 +353,18 @@ def test_models(client: _client.Determined) -> None: model.delete() with pytest.raises(errors.APIException): - client.get_model(model_name) + detobj.get_model(model_name) @pytest.mark.e2e_cpu -def test_stream_metrics(client: _client.Determined) -> None: +def test_stream_metrics() -> None: + sess = api_utils.user_session() + detobj = client.Determined._from_session(sess) with open(conf.fixtures_path("no_op/single-one-short-step.yaml")) as f: config = util.yaml_safe_load(f) config["hyperparameters"]["num_validation_metrics"] = 2 - exp = client.create_experiment(config, conf.fixtures_path("no_op")) - assert exp.wait() == _client.ExperimentState.COMPLETED + exp = detobj.create_experiment(config, conf.fixtures_path("no_op")) + assert exp.wait() == client.ExperimentState.COMPLETED trials = exp.get_trials() assert len(trials) == 1 @@ -362,11 +372,11 @@ def test_stream_metrics(client: _client.Determined) -> None: for metrics in [ list(trial.stream_metrics("training")), - list(client.stream_trials_metrics([trial.id], "training")), + list(detobj.stream_trials_metrics([trial.id], "training")), ]: assert len(metrics) == config["searcher"]["max_length"]["batches"] for i, actual in enumerate(metrics): - assert actual == TrialMetrics( + assert actual == client.TrialMetrics( trial_id=trial.id, trial_run_id=1, steps_completed=i + 1, @@ -378,10 +388,10 @@ def test_stream_metrics(client: _client.Determined) -> None: for val_metrics in [ list(trial.stream_metrics("validation")), - list(client.stream_trials_metrics([trial.id], "validation")), + list(detobj.stream_trials_metrics([trial.id], "validation")), ]: assert len(val_metrics) == 1 - assert val_metrics[0] == TrialMetrics( + assert val_metrics[0] == 
client.TrialMetrics( trial_id=trial.id, trial_run_id=1, steps_completed=100, @@ -395,16 +405,18 @@ def test_stream_metrics(client: _client.Determined) -> None: @pytest.mark.e2e_cpu -def test_model_versions(client: _client.Determined) -> None: +def test_model_versions() -> None: + sess = api_utils.user_session() + detobj = client.Determined._from_session(sess) with open(conf.fixtures_path("no_op/single-one-short-step.yaml")) as f: config = util.yaml_safe_load(f) - exp = client.create_experiment(config, conf.fixtures_path("no_op")) - assert exp.wait() == _client.ExperimentState.COMPLETED + exp = detobj.create_experiment(config, conf.fixtures_path("no_op")) + assert exp.wait() == client.ExperimentState.COMPLETED ckpt = exp.top_checkpoint() suffix = [random.sample("abcdefghijklmnopqrstuvwxyz", 1) for _ in range(10)] model_name = f"test-model-{suffix}" - model = client.create_model(model_name) + model = detobj.create_model(model_name) try: ver = model.register_version(ckpt.uuid) @@ -441,7 +453,8 @@ def test_model_versions(client: _client.Determined) -> None: @pytest.mark.e2e_cpu -def test_rp_workspace_mapping(client: _client.Determined) -> None: +def test_rp_workspace_mapping() -> None: + sess = api_utils.user_session() workspace_names = ["Workspace A", "Workspace B"] overwrite_workspace_names = ["Workspace C", "Workspace D"] rp_names = ["default"] # TODO: not sure how to add more rp @@ -449,17 +462,15 @@ def test_rp_workspace_mapping(client: _client.Determined) -> None: for wn in workspace_names + overwrite_workspace_names: req = bindings.v1PostWorkspaceRequest(name=wn) - workspace_ids.append( - bindings.post_PostWorkspace(session=client._session, body=req).workspace.id - ) + workspace_ids.append(bindings.post_PostWorkspace(sess, body=req).workspace.id) try: with pytest.raises( errors.APIException, match="default resource pool default cannot be bound to any workspace", ): - rp = resource_pool.ResourcePool(client._session, rp_names[0]) + rp = client.ResourcePool(sess, 
rp_names[0]) rp.add_bindings(workspace_names) finally: for wid in workspace_ids: - bindings.delete_DeleteWorkspace(session=client._session, id=wid) + bindings.delete_DeleteWorkspace(session=sess, id=wid) diff --git a/harness/determined/_trial_controller.py b/harness/determined/_trial_controller.py index 377ae21d6a7..6ad17b947dd 100644 --- a/harness/determined/_trial_controller.py +++ b/harness/determined/_trial_controller.py @@ -5,6 +5,7 @@ import determined as det from determined import profiler, tensorboard, workload +from determined.common import api class _DistributedBackend: @@ -27,6 +28,31 @@ def use_deepspeed(self) -> bool: return bool(os.environ.get(self.DEEPSPEED, None)) +def _profiler_agent_from_env( + session: api.Session, env: det.EnvContext, global_rank: int, local_rank: int +) -> profiler.ProfilerAgent: + """ + This used to be ProfilerAgent.from_env(), but it was demoted to being a helper function here. + + The purpose of demoting it is isolating the EnvContext object to the smallest footprint + possible. As EnvContext was part of the legacy Trial-centric harness architecture, and as this + functionality was only required in this legacy file, this is a good home for it. + """ + + begin_on_batch, end_after_batch = env.experiment_config.profiling_interval() + return profiler.ProfilerAgent( + session=session, + trial_id=env.det_trial_id, + agent_id=env.det_agent_id, + profiling_is_enabled=env.experiment_config.profiling_enabled(), + global_rank=global_rank, + local_rank=local_rank, + begin_on_batch=begin_on_batch, + end_after_batch=end_after_batch, + sync_timings=env.experiment_config.profiling_sync_timings(), + ) + + class TrialController(metaclass=abc.ABCMeta): """ TrialController is the legacy class that represented the Determined-owned logic to interact with @@ -44,11 +70,14 @@ def __init__( # The only time that workloads should be non-None here is unit tests or test mode. 
self.workloads = workloads - self.prof = profiler.ProfilerAgent.from_env( - env, - global_rank=context.distributed.rank, - local_rank=context.distributed.local_rank, - ) + if hasattr(context._core.train, "_session"): + # XXX: stealing this session feels _horrible_ + sess = context._core.train._session + self.prof = _profiler_agent_from_env( + sess, env, context.distributed.rank, context.distributed.local_rank + ) + else: + self.prof = profiler.DummyProfilerAgent() distributed_backend = _DistributedBackend() self.use_horovod = distributed_backend.use_horovod() diff --git a/harness/determined/cli/__init__.py b/harness/determined/cli/__init__.py index 836160db3bb..5235bf1c6e3 100644 --- a/harness/determined/cli/__init__.py +++ b/harness/determined/cli/__init__.py @@ -2,9 +2,9 @@ output_format_args, make_pagination_args, default_pagination_args, + unauth_session, setup_session, require_feature_flag, - login_sdk_client, print_launch_warnings, wait_ntsc_ready, warn, @@ -30,3 +30,10 @@ user, workspace, ) + +from determined.common.api import certs as _certs +from typing import Optional as _Optional + +# cert is a singleton that we configure very early in the cli's main() function, before any cli +# subcommand handlers are invoked. 
+cert: _Optional[_certs.Cert] = None diff --git a/harness/determined/cli/_util.py b/harness/determined/cli/_util.py index a5b5a4d768a..577e82b0369 100644 --- a/harness/determined/cli/_util.py +++ b/harness/determined/cli/_util.py @@ -1,14 +1,13 @@ import argparse -import functools import sys from typing import Any, Callable, Dict, List, Sequence import termcolor +from determined import cli from determined.cli import errors, render from determined.common import api, declarative_argparse, util -from determined.common.api import authentication, bindings, certs -from determined.experimental import client +from determined.common.api import authentication, bindings output_format_args: Dict[str, declarative_argparse.Arg] = { "json": declarative_argparse.Arg( @@ -85,28 +84,27 @@ def make_pagination_args( default_pagination_args = make_pagination_args() -def login_sdk_client(func: Callable[[argparse.Namespace], Any]) -> Callable[..., Any]: - @functools.wraps(func) - def f(namespace: argparse.Namespace) -> Any: - client.login(master=namespace.master, user=namespace.user) - return func(namespace) - - return f +def unauth_session(args: argparse.Namespace) -> api.UnauthSession: + master_url = args.master or util.get_default_master_address() + return api.UnauthSession(master=master_url, cert=cli.cert) def setup_session(args: argparse.Namespace) -> api.Session: master_url = args.master or util.get_default_master_address() - cert = certs.default_load(master_url) - retry = api.default_retry() - - return api.Session(master_url, args.user, authentication.cli_auth, cert, retry) + utp = authentication.login_with_cache( + master_address=master_url, + requested_user=args.user, + password=None, + cert=cli.cert, + ) + return api.Session(master_url, utp, cli.cert, api.default_retry()) def require_feature_flag(feature_flag: str, error_message: str) -> Callable[..., Any]: def decorator(function: Callable[..., Any]) -> Callable[..., Any]: def wrapper(args: argparse.Namespace) -> None: - resp = 
bindings.get_GetMaster(setup_session(args)) - if not resp.to_json().get("rbacEnabled"): + resp = bindings.get_GetMaster(unauth_session(args)) + if not resp.rbacEnabled: raise errors.FeatureFlagDisabled(error_message) function(args) @@ -126,7 +124,7 @@ def wait_ntsc_ready(session: api.Session, ntsc_type: api.NTSC_Kind, eid: str) -> """ name = ntsc_type.value loading_animator = render.Animator(f"Waiting for {name} to become ready") - err_msg = api.task_is_ready( + err_msg = api.wait_for_task_ready( session=session, task_id=eid, progress_report=loading_animator.next, diff --git a/harness/determined/cli/agent.py b/harness/determined/cli/agent.py index aa186f08b51..cfd18ad8a53 100644 --- a/harness/determined/cli/agent.py +++ b/harness/determined/cli/agent.py @@ -10,8 +10,8 @@ from determined import cli from determined.cli import errors, render from determined.cli import task as cli_task -from determined.common import api, check -from determined.common.api import authentication, bindings +from determined.common import check +from determined.common.api import bindings from determined.common.declarative_argparse import Arg, Cmd, Group NO_PERMISSIONS = "NO PERMISSIONS" @@ -21,9 +21,9 @@ def local_id(address: str) -> str: return os.path.basename(address) -@authentication.required def list_agents(args: argparse.Namespace) -> None: - resp = bindings.get_GetAgents(cli.setup_session(args)) + sess = cli.setup_session(args) + resp = bindings.get_GetAgents(sess) agents = [ collections.OrderedDict( @@ -64,10 +64,10 @@ def list_agents(args: argparse.Namespace) -> None: render.tabulate_or_csv(headers, values, args.csv) -@authentication.required def list_slots(args: argparse.Namespace) -> None: - task_res = bindings.get_GetTasks(cli.setup_session(args)) - resp = bindings.get_GetAgents(cli.setup_session(args)) + sess = cli.setup_session(args) + task_res = bindings.get_GetTasks(sess) + resp = bindings.get_GetAgents(sess) allocations = task_res.allocationIdToSummary @@ -162,8 +162,8 @@ 
def get_task_name(containers: Dict[str, Any], slot: bindings.v1Slot) -> str: def patch_agent(enabled: bool) -> Callable[[argparse.Namespace], None]: - @authentication.required def patch(args: argparse.Namespace) -> None: + sess = cli.setup_session(args) check.check_false(args.all and args.agent_id) action = "enable" if enabled else "disable" @@ -176,7 +176,7 @@ def patch(args: argparse.Namespace) -> None: if args.agent_id: agent_ids = [args.agent_id] else: - resp = bindings.get_GetAgents(cli.setup_session(args)) + resp = bindings.get_GetAgents(sess) agent_ids = sorted(local_id(a.id) for a in resp.agents or []) drain_mode = None if enabled else args.drain @@ -190,14 +190,14 @@ def patch(args: argparse.Namespace) -> None: "drain": drain_mode, } - api.post(args.master, path, payload) + sess.post(path, json=payload) status = "Disabled" if not enabled else "Enabled" print(f"{status} agent {agent_id}.", file=sys.stderr) # When draining, check if there're any tasks currently running on # these slots, and list them. 
if drain_mode: - rsp = bindings.get_GetTasks(cli.setup_session(args)) + rsp = bindings.get_GetTasks(sess) tasks_data = { k: t for (k, t) in ( @@ -220,15 +220,13 @@ def patch(args: argparse.Namespace) -> None: def patch_slot(enabled: bool) -> Callable[[argparse.Namespace], None]: - @authentication.required def patch(args: argparse.Namespace) -> None: + sess = cli.setup_session(args) if enabled: - bindings.post_EnableSlot( - cli.setup_session(args), agentId=args.agent_id, slotId=args.slot_id - ) + bindings.post_EnableSlot(sess, agentId=args.agent_id, slotId=args.slot_id) else: bindings.post_DisableSlot( - cli.setup_session(args), + sess, agentId=args.agent_id, slotId=args.slot_id, body=bindings.v1DisableSlotRequest(), diff --git a/harness/determined/cli/checkpoint.py b/harness/determined/cli/checkpoint.py index 8e304f97289..041bfd670dd 100644 --- a/harness/determined/cli/checkpoint.py +++ b/harness/determined/cli/checkpoint.py @@ -4,9 +4,9 @@ from determined import cli, errors, experimental from determined.cli import render -from determined.common.api import authentication, bindings +from determined.common.api import bindings from determined.common.declarative_argparse import Arg, Cmd -from determined.experimental.client import DownloadMode +from determined.experimental import client def render_checkpoint(checkpoint: experimental.Checkpoint, path: Optional[str] = None) -> None: @@ -38,14 +38,14 @@ def render_checkpoint(checkpoint: experimental.Checkpoint, path: Optional[str] = render.tabulate_or_csv(headers, [values], False) -@authentication.required def list_checkpoints(args: argparse.Namespace) -> None: + sess = cli.setup_session(args) if args.best: sorter = bindings.checkpointv1SortBy.SEARCHER_METRIC else: sorter = bindings.checkpointv1SortBy.END_TIME r = bindings.get_GetExperimentCheckpoints( - cli.setup_session(args), + sess, id=args.experiment_id, limit=args.best, sortByAttr=sorter, @@ -92,7 +92,9 @@ def get_validation_metric(c: bindings.v1Checkpoint, metric: 
str) -> str: def download(args: argparse.Namespace) -> None: - checkpoint = experimental.Determined(args.master, args.user).get_checkpoint(args.uuid) + sess = cli.setup_session(args) + d = client.Determined._from_session(sess) + checkpoint = d.get_checkpoint(args.uuid) try: path = checkpoint.download(path=args.output_dir, mode=args.mode) @@ -106,26 +108,28 @@ def download(args: argparse.Namespace) -> None: def describe(args: argparse.Namespace) -> None: - checkpoint = experimental.Determined(args.master, args.user).get_checkpoint(args.uuid) + sess = cli.setup_session(args) + d = client.Determined._from_session(sess) + checkpoint = d.get_checkpoint(args.uuid) render_checkpoint(checkpoint) -@authentication.required def delete_checkpoints(args: argparse.Namespace) -> None: + sess = cli.setup_session(args) if args.yes or render.yes_or_no( "Deleting checkpoints will result in deletion of all data associated\n" "with each checkpoint in the checkpoint storage. Do you still want to proceed?" ): c_uuids = args.checkpoints_uuids.split(",") delete_body = bindings.v1DeleteCheckpointsRequest(checkpointUuids=c_uuids) - bindings.delete_DeleteCheckpoints(cli.setup_session(args), body=delete_body) + bindings.delete_DeleteCheckpoints(sess, body=delete_body) print("Deletion of checkpoints {} is in progress".format(args.checkpoints_uuids)) else: print("Stopping deletion of checkpoints.") -@authentication.required def checkpoints_file_rm(args: argparse.Namespace) -> None: + sess = cli.setup_session(args) if ( args.yes or len(args.glob) == 0 @@ -139,7 +143,7 @@ def checkpoints_file_rm(args: argparse.Namespace) -> None: checkpointGlobs=args.glob, checkpointUuids=c_uuids, ) - bindings.post_CheckpointsRemoveFiles(cli.setup_session(args), body=remove_body) + bindings.post_CheckpointsRemoveFiles(sess, body=remove_body) if len(args.glob) == 0: print( @@ -177,15 +181,15 @@ def checkpoints_file_rm(args: argparse.Namespace) -> None: ), Arg( "--mode", - choices=list(DownloadMode), - 
default=DownloadMode.AUTO, - type=DownloadMode, + choices=list(client.DownloadMode), + default=client.DownloadMode.AUTO, + type=client.DownloadMode, help=( "Select different download modes: " - f"'{DownloadMode.DIRECT}' to directly download from checkpoint storage; " - f"'{DownloadMode.MASTER}' to download via the master; " - f"'{DownloadMode.AUTO}' to first attempt a direct download and fall " - f"back to '{DownloadMode.MASTER}'." + f"'{client.DownloadMode.DIRECT}' to directly download from checkpoint " + f" storage; '{client.DownloadMode.MASTER}' to download via the master; " + f"'{client.DownloadMode.AUTO}' to first attempt a direct download and fall " + f"back to '{client.DownloadMode.MASTER}'." ), ), ], diff --git a/harness/determined/cli/cli.py b/harness/determined/cli/cli.py index d1d8a587f12..7dad57be246 100644 --- a/harness/determined/cli/cli.py +++ b/harness/determined/cli/cli.py @@ -19,8 +19,8 @@ from OpenSSL import SSL, crypto from termcolor import colored -import determined -import determined.cli +import determined as det +from determined import cli from determined.cli import render from determined.cli.agent import args_description as agent_args_description from determined.cli.checkpoint import args_description as checkpoint_args_description @@ -49,7 +49,7 @@ from determined.cli.version import check_version from determined.cli.workspace import args_description as workspace_args_description from determined.common import api, yaml -from determined.common.api import authentication, bindings, certs +from determined.common.api import bindings, certs from determined.common.check import check_not_none from determined.common.declarative_argparse import ( Arg, @@ -69,15 +69,15 @@ from .errors import CliError, FeatureFlagDisabled -@authentication.required def preview_search(args: Namespace) -> None: + sess = cli.setup_session(args) experiment_config = safe_load_yaml_with_exceptions(args.config_file) args.config_file.close() if "searcher" not in 
experiment_config: print("Experiment configuration must have 'searcher' section") sys.exit(1) - r = api.post(args.master, "searcher/preview", json=experiment_config) + r = sess.post("searcher/preview", json=experiment_config) j = r.json() def to_full_name(kind: str) -> str: @@ -141,7 +141,7 @@ def render_sequence(sequence: List[str]) -> str: "--version", action="version", help="print CLI version and exit", - version="%(prog)s {}".format(determined.__version__), + version="%(prog)s {}".format(det.__version__), ), Cmd( "preview-search", @@ -233,10 +233,10 @@ def main( return # Configure the CLI's Cert singleton. - certs.cli_cert = certs.default_load(parsed_args.master) + cli.cert = certs.default_load(parsed_args.master) try: - check_version(parsed_args) + check_version(cli.unauth_session(parsed_args), parsed_args) except requests.exceptions.SSLError: # An SSLError usually means that we queried a master over HTTPS and got an untrusted # cert, so allow the user to store and trust the current cert. (It could also mean @@ -282,10 +282,10 @@ def main( joined_certs = "".join(cert_pem_data) certs.CertStore(certs.default_store()).set_cert(parsed_args.master, joined_certs) # Reconfigure the CLI's Cert singleton, but preserve the certificate name. 
- old_cert_name = certs.cli_cert.name - certs.cli_cert = certs.Cert(cert_pem=joined_certs, name=old_cert_name) + old_cert_name = cli.cert.name + cli.cert = certs.Cert(cert_pem=joined_certs, name=old_cert_name) - check_version(parsed_args) + check_version(cli.unauth_session(parsed_args), parsed_args) parsed_args.func(parsed_args) except KeyboardInterrupt as e: diff --git a/harness/determined/cli/command.py b/harness/determined/cli/command.py index ebb79ead0a4..86da4413572 100644 --- a/harness/determined/cli/command.py +++ b/harness/determined/cli/command.py @@ -7,16 +7,15 @@ from determined import cli from determined.cli import ntsc, render, task, workspace from determined.common import api -from determined.common.api import authentication from determined.common.declarative_argparse import Arg, ArgsDescription, Cmd, Group -@authentication.required def run_command(args: Namespace) -> None: + sess = cli.setup_session(args) config = ntsc.parse_config(args.config_file, args.entrypoint, args.config, args.volume) workspace_id = workspace.get_workspace_id_from_args(args) resp = ntsc.launch_command( - args.master, + sess, "api/v1/commands", config, args.template, @@ -32,7 +31,7 @@ def run_command(args: Namespace) -> None: render.report_job_launched("command", resp["id"]) try: - logs = api.task_logs(cli.setup_session(args), resp["id"], follow=True) + logs = api.task_logs(sess, resp["id"], follow=True) api.pprint_logs(logs) finally: print( diff --git a/harness/determined/cli/dev.py b/harness/determined/cli/dev.py index 2ec0f71020e..16dfd8ae1e7 100644 --- a/harness/determined/cli/dev.py +++ b/harness/determined/cli/dev.py @@ -14,24 +14,21 @@ from termcolor import colored -import determined.cli.render from determined import cli -from determined.cli import errors -from determined.common.api import authentication, bindings +from determined.cli import errors, render +from determined.common.api import bindings from determined.common.api import errors as api_errors from 
determined.common.api import request from determined.common.declarative_argparse import Arg, Cmd -@authentication.required -def token(_: Namespace) -> None: - token = authentication.must_cli_auth().get_session_token() - print(token) +def token(args: Namespace) -> None: + sess = cli.setup_session(args) + print(sess.token) -@authentication.required def curl(args: Namespace) -> None: - assert authentication.cli_auth is not None + sess = cli.setup_session(args) if shutil.which("curl") is None: print(colored("curl is not installed on this machine", "red")) sys.exit(1) @@ -47,7 +44,7 @@ def curl(args: Namespace) -> None: "curl", request.make_url_new(args.master, args.path), "-H", - f"Authorization: Bearer {authentication.cli_auth.get_session_token()}", + f"Authorization: Bearer {sess.token}", "-s", ] if args.curl_args: @@ -66,7 +63,7 @@ def curl(args: Namespace) -> None: output = subprocess.run(cmd, stdout=subprocess.PIPE) try: out = output.stdout.decode("utf8") - determined.cli.render.print_json(out) + render.print_json(out) except UnicodeDecodeError: print( "Failed to decode response as utf8. 
Redirect output to capture it.", @@ -292,7 +289,6 @@ def auto_complete_binding(available_calls: List[str], fn_name: str) -> str: return fn_name -@authentication.required def call_bindings(args: Namespace) -> None: """ support calling some bindings with primitive arguments via the cli diff --git a/harness/determined/cli/experiment.py b/harness/determined/cli/experiment.py index cbe87a734dd..8c2b61076e1 100644 --- a/harness/determined/cli/experiment.py +++ b/harness/determined/cli/experiment.py @@ -20,7 +20,7 @@ from determined.cli import checkpoint, render from determined.cli.ntsc import CONFIG_DESC, parse_config_overrides from determined.common import api, context, set_logger, util -from determined.common.api import authentication, bindings, logs +from determined.common.api import bindings, logs from determined.common.declarative_argparse import Arg, Cmd, Group from determined.experimental import client @@ -35,19 +35,16 @@ ZERO_OR_ONE = "?" -@authentication.required def activate(args: Namespace) -> None: bindings.post_ActivateExperiment(cli.setup_session(args), id=args.experiment_id) print(f"Activated experiment {args.experiment_id}") -@authentication.required def archive(args: Namespace) -> None: bindings.post_ArchiveExperiment(cli.setup_session(args), id=args.experiment_id) print(f"Archived experiment {args.experiment_id}") -@authentication.required def cancel(args: Namespace) -> None: bindings.post_CancelExperiment(cli.setup_session(args), id=args.experiment_id) print(f"Canceled experiment {args.experiment_id}") @@ -175,8 +172,8 @@ def print_progress(active_stage: int, ended: bool) -> None: time.sleep(0.2) -@authentication.required def submit_experiment(args: Namespace) -> None: + sess = cli.setup_session(args) config_text = args.config_file.read() args.config_file.close() experiment_config = _parse_config_text_or_exit(config_text, args.config_file.name, args.config) @@ -190,8 +187,6 @@ def submit_experiment(args: Namespace) -> None: assert yaml_dump is not None 
config_text = yaml_dump - sess = cli.setup_session(args) - req = bindings.v1CreateExperimentRequest( activate=not args.paused, config=config_text, @@ -235,8 +230,8 @@ def submit_experiment(args: Namespace) -> None: _follow_experiment_logs(sess, resp.experiment.id) -@authentication.required def continue_experiment(args: Namespace) -> None: + sess = cli.setup_session(args) if args.config_file: config_text = args.config_file.read() args.config_file.close() @@ -248,7 +243,6 @@ def continue_experiment(args: Namespace) -> None: config_text = util.yaml_safe_dump(experiment_config) - sess = cli.setup_session(args) req = bindings.v1ContinueExperimentRequest( id=args.experiment_id, overrideConfig=config_text, @@ -296,7 +290,6 @@ def create(args: Namespace) -> None: submit_experiment(args) -@authentication.required def delete_experiment(args: Namespace) -> None: if args.yes or render.yes_or_no( "Deleting an experiment will result in the unrecoverable \n" @@ -311,12 +304,11 @@ def delete_experiment(args: Namespace) -> None: print("Aborting experiment deletion.") -@authentication.required def describe(args: Namespace) -> None: - session = cli.setup_session(args) + sess = cli.setup_session(args) responses: List[bindings.v1GetExperimentResponse] = [] for experiment_id in args.experiment_ids.split(","): - r = bindings.get_GetExperiment(session, experimentId=experiment_id) + r = bindings.get_GetExperiment(sess, experimentId=experiment_id) responses.append(r) if args.json: @@ -364,7 +356,7 @@ def describe(args: Namespace) -> None: def get_all_trials(exp_id: int) -> List[bindings.trialv1Trial]: def get_with_offset(offset: int) -> bindings.v1GetExperimentTrialsResponse: return bindings.get_GetExperimentTrials( - session, + sess, offset=offset, experimentId=exp_id, ) @@ -399,7 +391,7 @@ def get_with_offset(offset: int) -> bindings.v1GetExperimentTrialsResponse: def get_all_workloads(trial_id: int) -> List[bindings.v1WorkloadContainer]: def get_with_offset(offset: int) -> 
bindings.v1GetTrialWorkloadsResponse: return bindings.get_GetTrialWorkloads( - session, + sess, offset=offset, trialId=trial_id, limit=500, @@ -549,7 +541,6 @@ def get_with_offset(offset: int) -> bindings.v1GetTrialWorkloadsResponse: render.tabulate_or_csv(headers, values, args.csv, outfile) -@authentication.required def experiment_logs(args: Namespace) -> None: sess = cli.setup_session(args) trials = bindings.get_GetExperimentTrials(sess, experimentId=args.experiment_id).trials @@ -558,7 +549,7 @@ def experiment_logs(args: Namespace) -> None: first_trial_id = sorted(t_id.id for t_id in trials)[0] try: logs = api.trial_logs( - cli.setup_session(args), + sess, first_trial_id, head=args.head, tail=args.tail, @@ -587,7 +578,6 @@ def experiment_logs(args: Namespace) -> None: ) -@authentication.required def config(args: Namespace) -> None: result = bindings.get_GetExperiment( cli.setup_session(args), experimentId=args.experiment_id @@ -595,7 +585,6 @@ def config(args: Namespace) -> None: util.yaml_safe_dump(result, stream=sys.stdout, default_flow_style=False) -@authentication.required def download_model_def(args: Namespace) -> None: resp = bindings.get_GetModelDef(cli.setup_session(args), experimentId=args.experiment_id) dst = f"experiment_{args.experiment_id}_model_def.tgz" @@ -603,7 +592,6 @@ def download_model_def(args: Namespace) -> None: f.write(base64.b64decode(resp.b64Tgz)) -@authentication.required def download(args: Namespace) -> None: sess = cli.setup_session(args) exp = client.Experiment(args.experiment_id, sess) @@ -632,13 +620,11 @@ def download(args: Namespace) -> None: print() -@authentication.required def kill_experiment(args: Namespace) -> None: bindings.post_KillExperiment(cli.setup_session(args), id=args.experiment_id) print(f"Killed experiment {args.experiment_id}") -@authentication.required def wait(args: Namespace) -> None: sess = cli.setup_session(args) exp = client.Experiment(args.experiment_id, sess) @@ -647,17 +633,16 @@ def wait(args: 
Namespace) -> None: sys.exit(1) -@authentication.required def list_experiments(args: Namespace) -> None: - session = cli.setup_session(args) + sess = cli.setup_session(args) def get_with_offset(offset: int) -> bindings.v1GetExperimentsResponse: return bindings.get_GetExperiments( - session, + sess, offset=offset, archived=None if args.all else False, limit=args.limit, - users=None if args.all else [authentication.must_cli_auth().get_session_user()], + users=None if args.all else [sess.username], ) resps = api.read_paginated(get_with_offset, offset=args.offset, pages=args.pages) @@ -739,13 +724,12 @@ def scalar_validation_metrics_names( return set() -@authentication.required def list_trials(args: Namespace) -> None: - session = cli.setup_session(args) + sess = cli.setup_session(args) def get_with_offset(offset: int) -> bindings.v1GetExperimentTrialsResponse: return bindings.get_GetExperimentTrials( - session, + sess, offset=offset, experimentId=args.experiment_id, limit=args.limit, @@ -770,98 +754,89 @@ def get_with_offset(offset: int) -> bindings.v1GetExperimentTrialsResponse: render.tabulate_or_csv(headers, values, args.csv) -@authentication.required def pause(args: Namespace) -> None: bindings.post_PauseExperiment(cli.setup_session(args), id=args.experiment_id) print(f"Paused experiment {args.experiment_id}") -@authentication.required def set_description(args: Namespace) -> None: - session = cli.setup_session(args) - exp = bindings.get_GetExperiment(session, experimentId=args.experiment_id).experiment + sess = cli.setup_session(args) + exp = bindings.get_GetExperiment(sess, experimentId=args.experiment_id).experiment exp_patch = bindings.v1PatchExperiment.from_json(exp.to_json()) exp_patch.description = args.description - bindings.patch_PatchExperiment(session, body=exp_patch, experiment_id=args.experiment_id) + bindings.patch_PatchExperiment(sess, body=exp_patch, experiment_id=args.experiment_id) print(f"Set description of experiment {args.experiment_id} to 
'{args.description}'") -@authentication.required def set_name(args: Namespace) -> None: - session = cli.setup_session(args) - exp = bindings.get_GetExperiment(session, experimentId=args.experiment_id).experiment + sess = cli.setup_session(args) + exp = bindings.get_GetExperiment(sess, experimentId=args.experiment_id).experiment exp_patch = bindings.v1PatchExperiment.from_json(exp.to_json()) exp_patch.name = args.name - bindings.patch_PatchExperiment(session, body=exp_patch, experiment_id=args.experiment_id) + bindings.patch_PatchExperiment(sess, body=exp_patch, experiment_id=args.experiment_id) print(f"Set name of experiment {args.experiment_id} to '{args.name}'") -@authentication.required def add_label(args: Namespace) -> None: - session = cli.setup_session(args) - exp = bindings.get_GetExperiment(session, experimentId=args.experiment_id).experiment + sess = cli.setup_session(args) + exp = bindings.get_GetExperiment(sess, experimentId=args.experiment_id).experiment exp_patch = bindings.v1PatchExperiment.from_json(exp.to_json()) if exp_patch.labels is None: exp_patch.labels = [] if args.label not in exp_patch.labels: exp_patch.labels = list(exp_patch.labels) + [args.label] - bindings.patch_PatchExperiment(session, body=exp_patch, experiment_id=args.experiment_id) + bindings.patch_PatchExperiment(sess, body=exp_patch, experiment_id=args.experiment_id) print(f"Added label '{args.label}' to experiment {args.experiment_id}") -@authentication.required def remove_label(args: Namespace) -> None: - session = cli.setup_session(args) - exp = bindings.get_GetExperiment(session, experimentId=args.experiment_id).experiment + sess = cli.setup_session(args) + exp = bindings.get_GetExperiment(sess, experimentId=args.experiment_id).experiment exp_patch = bindings.v1PatchExperiment.from_json(exp.to_json()) if (exp_patch.labels) and (args.label in exp_patch.labels): exp_patch.labels = [label for label in exp_patch.labels if label != args.label] - 
bindings.patch_PatchExperiment(session, body=exp_patch, experiment_id=args.experiment_id) + bindings.patch_PatchExperiment(sess, body=exp_patch, experiment_id=args.experiment_id) print(f"Removed label '{args.label}' from experiment {args.experiment_id}") -@authentication.required def set_max_slots(args: Namespace) -> None: - session = cli.setup_session(args) + sess = cli.setup_session(args) exp_patch = bindings.v1PatchExperiment( id=args.experiment_id, resources=bindings.PatchExperimentPatchResources(maxSlots=args.max_slots), ) - bindings.patch_PatchExperiment(session, body=exp_patch, experiment_id=args.experiment_id) + bindings.patch_PatchExperiment(sess, body=exp_patch, experiment_id=args.experiment_id) print(f"Set `max_slots` of experiment {args.experiment_id} to {args.max_slots}") -@authentication.required def set_weight(args: Namespace) -> None: - session = cli.setup_session(args) + sess = cli.setup_session(args) exp_patch = bindings.v1PatchExperiment( id=args.experiment_id, resources=bindings.PatchExperimentPatchResources(weight=args.weight) ) - bindings.patch_PatchExperiment(session, body=exp_patch, experiment_id=args.experiment_id) + bindings.patch_PatchExperiment(sess, body=exp_patch, experiment_id=args.experiment_id) print(f"Set `weight` of experiment {args.experiment_id} to {args.weight}") -@authentication.required def set_priority(args: Namespace) -> None: - session = cli.setup_session(args) + sess = cli.setup_session(args) exp_patch = bindings.v1PatchExperiment( id=args.experiment_id, resources=bindings.PatchExperimentPatchResources(priority=args.priority), ) - bindings.patch_PatchExperiment(session, body=exp_patch, experiment_id=args.experiment_id) + bindings.patch_PatchExperiment(sess, body=exp_patch, experiment_id=args.experiment_id) print(f"Set `priority` of experiment {args.experiment_id} to {args.priority}") -@authentication.required def set_gc_policy(args: Namespace) -> None: + sess = cli.setup_session(args) + policy = { + 
"save_experiment_best": args.save_experiment_best, + "save_trial_best": args.save_trial_best, + "save_trial_latest": args.save_trial_latest, + } if not args.yes: - policy = { - "save_experiment_best": args.save_experiment_best, - "save_trial_best": args.save_trial_best, - "save_trial_latest": args.save_trial_latest, - } - - r = api.get(args.master, f"experiments/{args.experiment_id}/preview_gc", params=policy) + r = sess.get(f"experiments/{args.experiment_id}/preview_gc", params=policy) response = r.json() checkpoints = response["checkpoints"] metric_name = response["metric_name"] @@ -906,7 +881,6 @@ def set_gc_policy(args: Namespace) -> None: "in the unrecoverable deletion of checkpoints. Do you wish to " "proceed?" ): - session = cli.setup_session(args) exp_patch = bindings.v1PatchExperiment( id=args.experiment_id, checkpointStorage=bindings.PatchExperimentPatchCheckpointStorage( @@ -915,19 +889,17 @@ def set_gc_policy(args: Namespace) -> None: saveTrialLatest=args.save_trial_latest, ), ) - bindings.patch_PatchExperiment(session, body=exp_patch, experiment_id=args.experiment_id) + bindings.patch_PatchExperiment(sess, body=exp_patch, experiment_id=args.experiment_id) print(f"Set GC policy of experiment {args.experiment_id} to\n{pformat(policy)}") else: print("Aborting operations.") -@authentication.required def unarchive(args: Namespace) -> None: bindings.post_UnarchiveExperiment(cli.setup_session(args), id=args.experiment_id) print(f"Unarchived experiment {args.experiment_id}") -@authentication.required def move_experiment(args: Namespace) -> None: sess = cli.setup_session(args) (w, p) = project_by_name(sess, args.workspace_name, args.project_name) @@ -939,9 +911,10 @@ def move_experiment(args: Namespace) -> None: print(f'Moved experiment {args.experiment_id} to project "{args.project_name}"') -@cli.login_sdk_client def delete_tensorboard_files(args: Namespace) -> None: - exp = client.get_experiment(args.experiment_id) + sess = cli.setup_session(args) + d = 
client.Determined._from_session(sess) + exp = d.get_experiment(args.experiment_id) exp.delete_tensorboard_files() diff --git a/harness/determined/cli/job.py b/harness/determined/cli/job.py index 9fa90e3c5e5..e027226f2db 100644 --- a/harness/determined/cli/job.py +++ b/harness/determined/cli/job.py @@ -5,7 +5,7 @@ from determined import cli from determined.cli import render from determined.common import api, util -from determined.common.api import authentication, bindings +from determined.common.api import bindings from determined.common.declarative_argparse import Arg, Cmd, Group from determined.common.util import parse_protobuf_timestamp @@ -21,17 +21,16 @@ def parse_jobv2_resp( return jobs -@authentication.required def ls(args: Namespace) -> None: - session = cli.setup_session(args) - pools = bindings.get_GetResourcePools(cli.setup_session(args)) + sess = cli.setup_session(args) + pools = bindings.get_GetResourcePools(sess) is_priority = check_is_priority(pools, args.resource_pool) order_by = bindings.v1OrderBy.ASC if not args.reverse else bindings.v1OrderBy.DESC def get_with_offset(offset: int) -> bindings.v1GetJobsV2Response: return bindings.get_GetJobsV2( - session, + sess, resourcePool=args.resource_pool, offset=offset, limit=args.limit, @@ -88,8 +87,8 @@ def computed_job_name(job: bindings.v1Job) -> str: render.tabulate_or_csv(headers, values, as_csv=args.csv) -@authentication.required def update(args: Namespace) -> None: + sess = cli.setup_session(args) update = bindings.v1QueueControl( jobId=args.job_id, priority=args.priority, @@ -98,22 +97,19 @@ def update(args: Namespace) -> None: behindOf=args.behind_of, aheadOf=args.ahead_of, ) - bindings.post_UpdateJobQueue( - cli.setup_session(args), body=bindings.v1UpdateJobQueueRequest(updates=[update]) - ) + bindings.post_UpdateJobQueue(sess, body=bindings.v1UpdateJobQueueRequest(updates=[update])) -@authentication.required def process_updates(args: Namespace) -> None: - session = cli.setup_session(args) + sess = 
cli.setup_session(args) for arg in args.operation: inputs = validate_operation_args(arg) - _single_update(session=session, **inputs) + _single_update(sess, **inputs) def _single_update( - job_id: str, session: api.Session, + job_id: str, priority: str = "", weight: str = "", resource_pool: str = "", diff --git a/harness/determined/cli/master.py b/harness/determined/cli/master.py index 8e249b33847..f77abebe71b 100644 --- a/harness/determined/cli/master.py +++ b/harness/determined/cli/master.py @@ -5,21 +5,21 @@ from determined import cli from determined.cli import render from determined.common import util -from determined.common.api import authentication, bindings +from determined.common.api import bindings from determined.common.declarative_argparse import Arg, Cmd, Group -@authentication.required def show_config(args: Namespace) -> None: - resp = bindings.get_GetMasterConfig(cli.setup_session(args)).config + sess = cli.setup_session(args) + resp = bindings.get_GetMasterConfig(sess).config if args.json: render.print_json(resp) else: print(util.yaml_safe_dump(resp, default_flow_style=False)) -@authentication.required def set_master_config(args: Namespace) -> None: + sess = cli.setup_session(args) log_config = bindings.v1LogConfig() field_masks = [] if "log_color" in args: @@ -39,7 +39,7 @@ def set_master_config(args: Namespace) -> None: req = bindings.v1PatchMasterConfigRequest( config=master_config, fieldMask=bindings.protobufFieldMask(paths=field_masks) ) - bindings.patch_PatchMasterConfig(cli.setup_session(args), body=req) + bindings.patch_PatchMasterConfig(sess, body=req) cli.warn( "This will only make ephermeral changes to the master config, " + "that will be lost if the user restarts the cluster." 
@@ -48,7 +48,8 @@ def set_master_config(args: Namespace) -> None: def get_master(args: Namespace) -> None: - resp = bindings.get_GetMaster(cli.setup_session(args)) + sess = cli.setup_session(args) + resp = bindings.get_GetMaster(sess) if args.json: render.print_json(resp.to_json()) else: @@ -61,12 +62,12 @@ def format_log_entry(log: bindings.v1LogEntry) -> str: return f"{log.timestamp} [{log_level}]: {log.message}" -@authentication.required def logs(args: Namespace) -> None: + sess = cli.setup_session(args) offset: Optional[int] = None if args.tail: offset = -args.tail - responses = bindings.get_MasterLogs(cli.setup_session(args), follow=args.follow, offset=offset) + responses = bindings.get_MasterLogs(sess, follow=args.follow, offset=offset) for response in responses: print(format_log_entry(response.logEntry)) diff --git a/harness/determined/cli/model.py b/harness/determined/cli/model.py index 9129061cc12..a370a759823 100644 --- a/harness/determined/cli/model.py +++ b/harness/determined/cli/model.py @@ -6,12 +6,11 @@ from determined import cli from determined.cli import render from determined.common import api -from determined.common.api import authentication from determined.common.declarative_argparse import Arg, Cmd -from determined.experimental import Determined, Model, ModelSortBy, ModelVersion, OrderBy +from determined.experimental import client -def render_model(model: Model) -> None: +def render_model(model: client.Model) -> None: table = [ ["ID", model.model_id], ["Name", model.name], @@ -27,7 +26,7 @@ def render_model(model: Model) -> None: render.tabulate_or_csv(headers, [values], False) -def _render_model_versions(model_versions: List[ModelVersion]) -> None: +def _render_model_versions(model_versions: List[client.ModelVersion]) -> None: headers = [ "Version #", "Trial ID", @@ -60,12 +59,14 @@ def _render_model_versions(model_versions: List[ModelVersion]) -> None: def list_models(args: Namespace) -> None: + sess = cli.setup_session(args) + d = 
client.Determined._from_session(sess) workspace_names = None if args.workspace_names is not None: workspace_names = args.workspace_names.split(",") - models = Determined(args.master, args.user).list_models( - sort_by=ModelSortBy[args.sort_by.upper()], - order_by=OrderBy[args.order_by.upper()], + models = d.list_models( + sort_by=client.ModelSortBy[args.sort_by.upper()], + order_by=client.OrderBy[args.order_by.upper()], workspace_names=workspace_names, ) if args.json: @@ -88,15 +89,16 @@ def list_models(args: Namespace) -> None: render.tabulate_or_csv(headers, values, False) -def model_by_name(args: Namespace) -> Model: - return Determined(args.master, args.user).get_model(identifier=args.name) +def model_by_name(sess: api.Session, name: str) -> client.Model: + d = client.Determined._from_session(sess) + return d.get_model(identifier=name) -@authentication.required def list_versions(args: Namespace) -> None: - model = model_by_name(args) + sess = cli.setup_session(args) + model = model_by_name(sess, args.name) if args.json: - r = api.get(args.master, "api/v1/models/{}/versions".format(model.model_id)) + r = sess.get(f"api/v1/models/{model.model_id}/versions") data = r.json() determined.cli.render.print_json(data) @@ -107,10 +109,9 @@ def list_versions(args: Namespace) -> None: def create(args: Namespace) -> None: - model = Determined(args.master, args.user).create_model( - args.name, args.description, workspace_name=args.workspace_name - ) - + sess = cli.setup_session(args) + d = client.Determined._from_session(sess) + model = d.create_model(args.name, args.description, workspace_name=args.workspace_name) if args.json: determined.cli.render.print_json(render.model_to_json(model)) else: @@ -118,12 +119,14 @@ def create(args: Namespace) -> None: def move(args: Namespace) -> None: - model = model_by_name(args) + sess = cli.setup_session(args) + model = model_by_name(sess, args.name) model.move_to_workspace(args.workspace_name) def describe(args: Namespace) -> None: - 
model = model_by_name(args) + sess = cli.setup_session(args) + model = model_by_name(sess, args.name) model_version = model.get_version(args.version) if args.json: @@ -135,13 +138,12 @@ def describe(args: Namespace) -> None: _render_model_versions([model_version]) -@authentication.required def register_version(args: Namespace) -> None: - model = model_by_name(args) + sess = cli.setup_session(args) + model = model_by_name(sess, args.name) if args.json: - resp = api.post( - args.master, - "/api/v1/models/{}/versions".format(model.model_id), + resp = sess.post( + f"/api/v1/models/{model.model_id}/versions", json={"checkpointUuid": args.uuid}, ) diff --git a/harness/determined/cli/notebook.py b/harness/determined/cli/notebook.py index 1b41fd8ef70..06002b7d6df 100644 --- a/harness/determined/cli/notebook.py +++ b/harness/determined/cli/notebook.py @@ -8,13 +8,13 @@ from determined import cli from determined.cli import ntsc, render, task from determined.common import api, context -from determined.common.api import authentication, bindings, request +from determined.common.api import bindings, request from determined.common.check import check_none from determined.common.declarative_argparse import Arg, ArgsDescription, Cmd, Group -@authentication.required def start_notebook(args: Namespace) -> None: + sess = cli.setup_session(args) config = ntsc.parse_config(args.config_file, None, args.config, args.volume) files = context.read_v1_context(args.context, args.include) @@ -28,7 +28,7 @@ def start_notebook(args: Namespace) -> None: templateName=args.template, workspaceId=workspace_id, ) - resp = bindings.post_LaunchNotebook(cli.setup_session(args), body=body) + resp = bindings.post_LaunchNotebook(sess, body=body) if args.preview: print(render.format_object_as_yaml(resp.config)) @@ -48,7 +48,7 @@ def start_notebook(args: Namespace) -> None: bindings.v1LaunchWarning.CURRENT_SLOTS_EXCEEDED in resp.warnings ) - cli.wait_ntsc_ready(cli.setup_session(args), api.NTSC_Kind.notebook, 
nb.id) + cli.wait_ntsc_ready(sess, api.NTSC_Kind.notebook, nb.id) assert nb.serviceAddress is not None, "missing tensorboard serviceAddress" nb_path = request.make_interactive_task_url( @@ -65,11 +65,10 @@ def start_notebook(args: Namespace) -> None: print(colored("Jupyter Notebook is running at: {}".format(url), "green")) -@authentication.required def open_notebook(args: Namespace) -> None: - notebook_id = cast(str, ntsc.expand_uuid_prefixes(args)) - sess = cli.setup_session(args) + notebook_id = cast(str, ntsc.expand_uuid_prefixes(sess, args)) + task = bindings.get_GetTask(sess, taskId=notebook_id).task check_none(task.endTime, "Notebook has ended") diff --git a/harness/determined/cli/ntsc.py b/harness/determined/cli/ntsc.py index 9fe3e25bd73..d235c66cdd3 100644 --- a/harness/determined/cli/ntsc.py +++ b/harness/determined/cli/ntsc.py @@ -16,7 +16,6 @@ from determined import cli from determined.cli import render from determined.common import api, context, declarative_argparse, util -from determined.common.api import authentication from determined.util import merge_dicts CONFIG_DESC = """ @@ -152,7 +151,7 @@ def expand_uuid_prefixes( - args: Namespace, prefixes: Optional[Union[str, List[str]]] = None + sess: api.Session, args: Namespace, prefixes: Optional[Union[str, List[str]]] = None ) -> Union[str, List[str]]: if prefixes is None: prefixes = RemoteTaskGetIDsFunc[args._command](args) # type: ignore @@ -170,7 +169,7 @@ def expand_uuid_prefixes( ) api_path = RemoteTaskNewAPIs[args._command] api_full_path = "api/v1/{}".format(api_path) - res = api.get(args.master, api_full_path).json()[api_path] + res = sess.get(api_full_path).json()[api_path] all_ids: List[str] = [x["id"] for x in res] def expand(prefix: str) -> str: @@ -193,8 +192,8 @@ def expand(prefix: str) -> str: return prefixes -@authentication.required def list_tasks(args: Namespace) -> None: + sess = cli.setup_session(args) api_path = RemoteTaskNewAPIs[args._command] api_full_path = 
"api/v1/{}".format(api_path) table_header = RemoteTaskListTableHeaders[args._command] @@ -202,14 +201,14 @@ def list_tasks(args: Namespace) -> None: params: Dict[str, Any] = {} if "workspace_name" in args and args.workspace_name is not None: - workspace = api.workspace_by_name(cli.setup_session(args), args.workspace_name) + workspace = api.workspace_by_name(sess, args.workspace_name) params["workspaceId"] = workspace.id if not args.all: - params["users"] = [authentication.must_cli_auth().get_session_user()] + params["users"] = [sess.username] - res = api.get(args.master, api_full_path, params=params).json()[api_path] + res = sess.get(api_full_path, params=params).json()[api_path] if args.quiet: for command in res: @@ -217,7 +216,7 @@ def list_tasks(args: Namespace) -> None: return # swap workspace_id for workspace name. - w_names = cli.workspace.get_workspace_names(cli.setup_session(args)) + w_names = cli.workspace.get_workspace_names(sess) for item in res: if item["state"].startswith("STATE_"): @@ -239,14 +238,14 @@ def list_tasks(args: Namespace) -> None: render.tabulate_or_csv(table_header, values, getattr(args, "csv", False)) -@authentication.required def kill(args: Namespace) -> None: - task_ids = expand_uuid_prefixes(args) + sess = cli.setup_session(args) + task_ids = expand_uuid_prefixes(sess, args) name = RemoteTaskName[args._command] for i, task_id in enumerate(task_ids): try: - _kill(args.master, args._command, task_id) + _kill(sess, args._command, task_id) print(colored("Killed {} {}".format(name, task_id), "green")) except api.errors.APIException as e: if not args.force: @@ -256,31 +255,27 @@ def kill(args: Namespace) -> None: print(colored("Skipping: {} ({})".format(e, type(e).__name__), "red")) -def _kill(master_url: str, taskType: str, taskID: str) -> None: - api_full_path = "api/v1/{}/{}/kill".format(RemoteTaskNewAPIs[taskType], taskID) - api.post(master_url, api_full_path) +def _kill(sess: api.Session, task_type: str, task_id: str) -> None: + 
sess.post(f"api/v1/{RemoteTaskNewAPIs[task_type]}/{task_id}/kill") -@authentication.required def set_priority(args: Namespace) -> None: - task_id = expand_uuid_prefixes(args) + sess = cli.setup_session(args) + task_id = expand_uuid_prefixes(sess, args) name = RemoteTaskName[args._command] try: - api_full_path = "api/v1/{}/{}/set_priority".format( - RemoteTaskNewAPIs[args._command], task_id - ) - api.post(args.master, api_full_path, {"priority": args.priority}) - print(colored("Set priority of {} {} to {}".format(name, task_id, args.priority), "green")) + api_full_path = f"api/v1/{RemoteTaskNewAPIs[args._command]}/{task_id}/set_priority" + sess.post(api_full_path, json={"priority": args.priority}) + print(colored(f"Set priority of {name} {task_id} to {args.priority}", "green")) except api.errors.APIException as e: - print(colored("Skipping: {} ({})".format(e, type(e).__name__), "red")) + print(colored(f"Skipping: {e} ({type(e).__name__})", "red")) -@authentication.required def config(args: Namespace) -> None: - task_id = expand_uuid_prefixes(args) - api_full_path = "api/v1/{}/{}".format(RemoteTaskNewAPIs[args._command], task_id) - res_json = api.get(args.master, api_full_path).json() + sess = cli.setup_session(args) + task_id = expand_uuid_prefixes(sess, args) + res_json = sess.get(f"api/v1/{RemoteTaskNewAPIs[args._command]}/{task_id}").json() print(render.format_object_as_yaml(res_json["config"])) @@ -403,7 +398,7 @@ def parse_config( def launch_command( - master: str, + sess: api.Session, endpoint: str, config: Dict[str, Any], template: str, @@ -439,8 +434,4 @@ def launch_command( if workspace_id is not None: body["workspaceId"] = workspace_id - return api.post( - master, - endpoint, - body, - ).json() + return sess.post(endpoint, json=body).json() diff --git a/harness/determined/cli/oauth.py b/harness/determined/cli/oauth.py index 98b5ae8ece3..6d63c9c0486 100644 --- a/harness/determined/cli/oauth.py +++ b/harness/determined/cli/oauth.py @@ -1,14 +1,16 @@ -from 
argparse import Namespace +import argparse from typing import Any, List -from determined.cli import login_sdk_client, render +from determined import cli +from determined.cli import render from determined.common.declarative_argparse import Arg, Cmd from determined.experimental import client -@login_sdk_client -def list_clients(parsed_args: Namespace) -> None: - oauth_clients = client.list_oauth_clients() +def list_clients(args: argparse.Namespace) -> None: + sess = cli.setup_session(args) + d = client.Determined._from_session(sess) + oauth_clients = d.list_oauth_clients() headers = ["Name", "Client ID", "Domain"] keys = ["name", "id", "domain"] oauth_clients_dict = [ @@ -20,17 +22,19 @@ def list_clients(parsed_args: Namespace) -> None: ) -@login_sdk_client -def add_client(parsed_args: Namespace) -> None: - oauth_client = client.add_oauth_client(domain=parsed_args.domain, name=parsed_args.name) +def add_client(args: argparse.Namespace) -> None: + sess = cli.setup_session(args) + d = client.Determined._from_session(sess) + oauth_client = d.add_oauth_client(domain=args.domain, name=args.name) - print("Client ID: {}".format(oauth_client.id)) - print("Client secret: {}".format(oauth_client.secret)) + print(f"Client ID: {oauth_client.id}") + print(f"Client secret: {oauth_client.secret}") -@login_sdk_client -def remove_client(parsed_args: Namespace) -> None: - client.remove_oauth_client(client_id=parsed_args.client_id) +def remove_client(args: argparse.Namespace) -> None: + sess = cli.setup_session(args) + d = client.Determined._from_session(sess) + d.remove_oauth_client(client_id=args.client_id) # fmt: off diff --git a/harness/determined/cli/project.py b/harness/determined/cli/project.py index 795ec268a6a..1be71a78693 100644 --- a/harness/determined/cli/project.py +++ b/harness/determined/cli/project.py @@ -6,7 +6,7 @@ from determined import cli from determined.cli import render from determined.common import api -from determined.common.api import authentication, bindings, 
errors +from determined.common.api import bindings, errors from determined.common.declarative_argparse import Arg, Cmd from .workspace import list_workspace_projects, pagination_args @@ -68,7 +68,6 @@ def project_by_name( return (w, p[0]) -@authentication.required def list_project_experiments(args: Namespace) -> None: sess = cli.setup_session(args) (w, p) = project_by_name(sess, args.workspace_name, args.project_name) @@ -78,7 +77,7 @@ def list_project_experiments(args: Namespace) -> None: "sortBy": bindings.v1GetExperimentsRequestSortBy[args.sort_by.upper()], } if not args.all: - kwargs["users"] = [authentication.must_cli_auth().get_session_user()] + kwargs["users"] = [sess.username] kwargs["archived"] = "false" all_experiments: List[bindings.v1Experiment] = [] @@ -99,7 +98,6 @@ def list_project_experiments(args: Namespace) -> None: render_experiments(args, all_experiments) -@authentication.required def create_project(args: Namespace) -> None: sess = cli.setup_session(args) w = api.workspace_by_name(sess, args.workspace_name) @@ -113,7 +111,6 @@ def create_project(args: Namespace) -> None: render_project(p) -@authentication.required def describe_project(args: Namespace) -> None: sess = cli.setup_session(args) (w, p) = project_by_name(sess, args.workspace_name, args.project_name) @@ -127,7 +124,6 @@ def describe_project(args: Namespace) -> None: list_project_experiments(args) -@authentication.required def delete_project(args: Namespace) -> None: sess = cli.setup_session(args) (w, p) = project_by_name(sess, args.workspace_name, args.project_name) @@ -157,7 +153,6 @@ def delete_project(args: Namespace) -> None: print("Aborting project deletion.") -@authentication.required def edit_project(args: Namespace) -> None: sess = cli.setup_session(args) (w, p) = project_by_name(sess, args.workspace_name, args.project_name) @@ -170,7 +165,6 @@ def edit_project(args: Namespace) -> None: render_project(new_p) -@authentication.required def archive_project(args: Namespace) -> 
None: sess = cli.setup_session(args) (w, p) = project_by_name(sess, args.workspace_name, args.project_name) @@ -178,7 +172,6 @@ def archive_project(args: Namespace) -> None: print(f"Successfully archived project {args.project_name}.") -@authentication.required def unarchive_project(args: Namespace) -> None: sess = cli.setup_session(args) (w, p) = project_by_name(sess, args.workspace_name, args.project_name) diff --git a/harness/determined/cli/proxy.py b/harness/determined/cli/proxy.py index 4540526aad1..400d6648a5f 100644 --- a/harness/determined/cli/proxy.py +++ b/harness/determined/cli/proxy.py @@ -11,11 +11,12 @@ import time import urllib.request from dataclasses import dataclass -from typing import Iterator, List, Optional, Union +from typing import Iterator, List, Optional import lomond -from determined.common.api import Session, authentication, bindings, request +from determined.common import api +from determined.common.api import bindings, certs, request @dataclass @@ -31,20 +32,21 @@ class CustomSSLWebsocketSession(lomond.session.WebsocketSession): # type: ignor configured. 
""" - def __init__( - self, socket: lomond.WebSocket, cert_file: Union[str, bool, None], cert_name: Optional[str] - ) -> None: + def __init__(self, socket: lomond.WebSocket, cert: Optional[certs.Cert]) -> None: super().__init__(socket) self.ctx = ssl.create_default_context() - self.cert_name = cert_name - if cert_file is False: + + self.cert_name = cert.name if cert else None + + bundle = cert.bundle if cert else None + if bundle is False: self.ctx.check_hostname = False self.ctx.verify_mode = ssl.CERT_NONE return - if cert_file is not None: - assert isinstance(cert_file, str) - self.ctx.load_verify_locations(cafile=cert_file) + if bundle is not None: + assert isinstance(bundle, str) + self.ctx.load_verify_locations(cafile=bundle) def _wrap_socket(self, sock: socket.SocketType, host: str) -> socket.SocketType: return self.ctx.wrap_socket(sock, server_hostname=self.cert_name or host) @@ -86,13 +88,12 @@ def copy_from_websocket( f: io.RawIOBase, ws: lomond.WebSocket, ready_sem: threading.Semaphore, - cert_file: Union[str, bool, None], - cert_name: Optional[str], + cert: Optional[certs.Cert], ) -> None: try: for event in ws.connect( ping_rate=0, - session_class=lambda socket: CustomSSLWebsocketSession(socket, cert_file, cert_name), + session_class=lambda socket: CustomSSLWebsocketSession(socket, cert), ): if isinstance(event, lomond.events.Binary): f.write(event.data) @@ -113,13 +114,12 @@ def copy_from_websocket2( f: socket.socket, ws: lomond.WebSocket, ready_sem: threading.Semaphore, - cert_file: Union[str, bool, None], - cert_name: Optional[str], + cert: Optional[certs.Cert], ) -> None: try: for event in ws.connect( ping_rate=0, - session_class=lambda socket: CustomSSLWebsocketSession(socket, cert_file, cert_name), + session_class=lambda socket: CustomSSLWebsocketSession(socket, cert), ): if isinstance(event, lomond.events.Binary): f.send(event.data) @@ -138,15 +138,9 @@ def copy_from_websocket2( f.close() -def http_connect_tunnel( - master: str, - service: str, - 
cert_file: Union[str, bool, None], - cert_name: Optional[str], - authorization_token: Optional[str] = None, -) -> None: - parsed_master = request.parse_master_address(master) - assert parsed_master.hostname is not None, "Failed to parse master address: {}".format(master) +def http_connect_tunnel(sess: api.BaseSession, service: str) -> None: + parsed_master = request.parse_master_address(sess.master) + assert parsed_master.hostname is not None, f"Failed to parse master address: {sess.master}" # The "lomond.WebSocket()" function does not honor the "no_proxy" or # "NO_PROXY" environment variables. To work around that, we check if @@ -160,10 +154,10 @@ def http_connect_tunnel( # specified, the default value is "None". proxies = {} if urllib.request.proxy_bypass(parsed_master.hostname) else None # type: ignore - url = request.make_url(master, "proxy/{}/".format(service)) + url = request.make_url(sess.master, f"proxy/{service}/") ws = lomond.WebSocket(request.maybe_upgrade_ws_scheme(url), proxies=proxies) - if authorization_token is not None: - ws.add_header(b"Authorization", f"Bearer {authorization_token}".encode()) + if isinstance(sess, api.Session): + ws.add_header(b"Authorization", f"Bearer {sess.token}".encode()) # We can't send data to the WebSocket before the connection becomes ready, which takes a bit of # time; this semaphore lets the sending thread wait for that to happen. 
@@ -176,7 +170,7 @@ def http_connect_tunnel( c1 = threading.Thread(target=copy_to_websocket, args=(ws, unbuffered_stdin, ready_sem)) c2 = threading.Thread( - target=copy_from_websocket, args=(unbuffered_stdout, ws, ready_sem, cert_file, cert_name) + target=copy_from_websocket, args=(unbuffered_stdout, ws, ready_sem, sess.cert) ) c1.start() c2.start() @@ -189,18 +183,13 @@ class ReuseAddrServer(socketserver.ThreadingTCPServer): def _http_tunnel_listener( - master_addr: str, + sess: api.BaseSession, tunnel: ListenerConfig, - cert_file: Union[str, bool, None], - cert_name: Optional[str], - authorization_token: Optional[str] = None, ) -> socketserver.ThreadingTCPServer: - parsed_master = request.parse_master_address(master_addr) - assert parsed_master.hostname is not None, "Failed to parse master address: {}".format( - master_addr - ) + parsed_master = request.parse_master_address(sess.master) + assert parsed_master.hostname is not None, f"Failed to parse master address: {sess.master}" - url = request.make_url(master_addr, "proxy/{}/".format(tunnel.service_id)) + url = request.make_url(sess.master, f"proxy/{tunnel.service_id}/") class TunnelHandler(socketserver.BaseRequestHandler): def handle(self) -> None: @@ -209,8 +198,8 @@ def handle(self) -> None: ) ws = lomond.WebSocket(request.maybe_upgrade_ws_scheme(url), proxies=proxies) - if authorization_token is not None: - ws.add_header(b"Authorization", f"Bearer {authorization_token}".encode()) + if isinstance(sess, api.Session): + ws.add_header(b"Authorization", f"Bearer {sess.token}".encode()) # We can't send data to the WebSocket before the connection becomes ready, # which takes a bit of time; this semaphore lets the sending thread # wait for that to happen. 
@@ -219,7 +208,7 @@ def handle(self) -> None: c1 = threading.Thread(target=copy_to_websocket2, args=(ws, self.request, ready_sem)) c2 = threading.Thread( target=copy_from_websocket2, - args=(self.request, ws, ready_sem, cert_file, cert_name), + args=(self.request, ws, ready_sem, sess.cert), ) c1.start() c2.start() @@ -237,16 +226,10 @@ def handle(self) -> None: @contextlib.contextmanager def http_tunnel_listener( - master: str, + sess: api.BaseSession, tunnels: List[ListenerConfig], - cert_file: Union[str, bool, None], - cert_name: Optional[str], - authorization_token: Optional[str] = None, ) -> Iterator[None]: - servers = [ - _http_tunnel_listener(master, tunnel, cert_file, cert_name, authorization_token) - for tunnel in tunnels - ] + servers = [_http_tunnel_listener(sess, tunnel) for tunnel in tunnels] threads = [threading.Thread(target=lambda s: s.serve_forever(), args=(s,)) for s in servers] @@ -264,30 +247,22 @@ def http_tunnel_listener( @contextlib.contextmanager -def _tunnel_task(sess: Session, task_id: str, port_map: dict[int, int]) -> Iterator[None]: +def _tunnel_task(sess: api.Session, task_id: str, port_map: dict[int, int]) -> Iterator[None]: # Args: # port_map: dict of local port => task port. # task_id: tunneled task_id. 
- master_addr = sess._master listeners = [ ListenerConfig(service_id=f"{task_id}:{task_port}", local_port=local_port) for local_port, task_port in port_map.items() ] - cert = sess._cert - cert_file, cert_name = None, None - if cert is not None: - cert_file = cert.bundle - cert_name = cert.name - - token = authentication.must_cli_auth().get_session_token() - with http_tunnel_listener(master_addr, listeners, cert_file, cert_name, token): + with http_tunnel_listener(sess, listeners): yield @contextlib.contextmanager -def _tunnel_trial(sess: Session, trial_id: int, port_map: dict[int, int]) -> Iterator[None]: +def _tunnel_trial(sess: api.Session, trial_id: int, port_map: dict[int, int]) -> Iterator[None]: # TODO(DET-9000): perhaps the tunnel should be able to probe master for service status, # instead of us explicitly polling for task/trial status. while True: @@ -314,7 +289,7 @@ def _tunnel_trial(sess: Session, trial_id: int, port_map: dict[int, int]) -> Ite @contextlib.contextmanager def tunnel_experiment( - sess: Session, experiment_id: int, port_map: dict[int, int] + sess: api.Session, experiment_id: int, port_map: dict[int, int] ) -> Iterator[None]: while True: trials = bindings.get_GetExperimentTrials(sess, experimentId=experiment_id).trials diff --git a/harness/determined/cli/rbac.py b/harness/determined/cli/rbac.py index 668d97261b2..bdd15f00ffc 100644 --- a/harness/determined/cli/rbac.py +++ b/harness/determined/cli/rbac.py @@ -2,10 +2,10 @@ from collections import namedtuple from typing import Any, Dict, List, Set, Tuple -import determined.cli.render -from determined.cli import default_pagination_args, render, require_feature_flag, setup_session +from determined import cli +from determined.cli import render from determined.common import api -from determined.common.api import authentication, bindings +from determined.common.api import bindings from determined.common.declarative_argparse import Arg, Cmd rbac_flag_disabled_message = ( @@ -49,13 +49,12 @@ ) 
-@authentication.required -@require_feature_flag("rbacEnabled", rbac_flag_disabled_message) +@cli.require_feature_flag("rbacEnabled", rbac_flag_disabled_message) def my_permissions(args: Namespace) -> None: - session = setup_session(args) - resp = bindings.get_GetPermissionsSummary(session) + sess = cli.setup_session(args) + resp = bindings.get_GetPermissionsSummary(sess) if args.json: - determined.cli.render.print_json(resp.to_json()) + render.print_json(resp.to_json()) return role_id_to_permissions: Dict[int, Set[bindings.v1Permission]] = {} @@ -87,7 +86,7 @@ def my_permissions(args: Namespace) -> None: if wid == 0: print("global permissions assigned") else: - workspace_name = bindings.get_GetWorkspace(session, id=wid).workspace.name + workspace_name = bindings.get_GetWorkspace(sess, id=wid).workspace.name print(f"permissions assigned over workspace '{workspace_name}' with ID '{wid}'") render.render_objects( @@ -96,17 +95,17 @@ def my_permissions(args: Namespace) -> None: print() -@authentication.required -@require_feature_flag("rbacEnabled", rbac_flag_disabled_message) +@cli.require_feature_flag("rbacEnabled", rbac_flag_disabled_message) def list_roles(args: Namespace) -> None: + sess = cli.setup_session(args) req = bindings.v1SearchRolesAssignableToScopeRequest( limit=args.limit, offset=args.offset, workspaceId=1 if args.exclude_global_roles else None, ) - resp = bindings.post_SearchRolesAssignableToScope(setup_session(args), body=req) + resp = bindings.post_SearchRolesAssignableToScope(sess, body=req) if args.json: - determined.cli.render.print_json(resp.to_json()) + render.print_json(resp.to_json()) return if resp.roles is None or len(resp.roles) == 0: @@ -150,14 +149,13 @@ def role_with_assignment_to_dict( } -@authentication.required -@require_feature_flag("rbacEnabled", rbac_flag_disabled_message) +@cli.require_feature_flag("rbacEnabled", rbac_flag_disabled_message) def list_users_roles(args: Namespace) -> None: - session = setup_session(args) - user_id = 
api.usernames_to_user_ids(session, [args.username])[0] - resp = bindings.get_GetRolesAssignedToUser(session, userId=user_id) + sess = cli.setup_session(args) + user_id = api.usernames_to_user_ids(sess, [args.username])[0] + resp = bindings.get_GetRolesAssignedToUser(sess, userId=user_id) if args.json: - determined.cli.render.print_json(resp.to_json()) + render.print_json(resp.to_json()) return if resp.roles is None or len(resp.roles) == 0: @@ -168,16 +166,14 @@ def list_users_roles(args: Namespace) -> None: for r in resp.roles: if r.userRoleAssignments is not None: for u in r.userRoleAssignments: - o = role_with_assignment_to_dict(session, r, u.roleAssignment) + o = role_with_assignment_to_dict(sess, r, u.roleAssignment) o["assignedDirectlyToUser"] = True output.append(o) if r.groupRoleAssignments is not None: for g in r.groupRoleAssignments: - o = role_with_assignment_to_dict(session, r, g.roleAssignment) + o = role_with_assignment_to_dict(sess, r, g.roleAssignment) o["assignedToGroupID"] = g.groupId - o["assignedToGroupName"] = bindings.get_GetGroup( - session, groupId=g.groupId - ).group.name + o["assignedToGroupName"] = bindings.get_GetGroup(sess, groupId=g.groupId).group.name output.append(o) render.render_objects( @@ -186,14 +182,13 @@ def list_users_roles(args: Namespace) -> None: ) -@authentication.required -@require_feature_flag("rbacEnabled", rbac_flag_disabled_message) +@cli.require_feature_flag("rbacEnabled", rbac_flag_disabled_message) def list_groups_roles(args: Namespace) -> None: - session = setup_session(args) - group_id = api.group_name_to_group_id(session, args.group_name) - resp = bindings.get_GetRolesAssignedToGroup(session, groupId=group_id) + sess = cli.setup_session(args) + group_id = api.group_name_to_group_id(sess, args.group_name) + resp = bindings.get_GetRolesAssignedToGroup(sess, groupId=group_id) if args.json: - determined.cli.render.print_json(resp.to_json()) + render.print_json(resp.to_json()) return if resp.roles is None or 
len(resp.roles) == 0: @@ -209,7 +204,7 @@ def list_groups_roles(args: Namespace) -> None: workspace_ids = resp.assignments[i].scopeWorkspaceIds or [] for wid in workspace_ids: - workspace_name = bindings.get_GetWorkspace(session, id=wid).workspace.name + workspace_name = bindings.get_GetWorkspace(sess, id=wid).workspace.name workspaces.append( { "workspaceID": wid, @@ -226,15 +221,14 @@ def list_groups_roles(args: Namespace) -> None: print() -@authentication.required -@require_feature_flag("rbacEnabled", rbac_flag_disabled_message) +@cli.require_feature_flag("rbacEnabled", rbac_flag_disabled_message) def describe_role(args: Namespace) -> None: - session = setup_session(args) - role_id = api.role_name_to_role_id(session, args.role_name) + sess = cli.setup_session(args) + role_id = api.role_name_to_role_id(sess, args.role_name) req = bindings.v1GetRolesByIDRequest(roleIds=[role_id]) - resp = bindings.post_GetRolesByID(session, body=req) + resp = bindings.post_GetRolesByID(sess, body=req) if args.json: - determined.cli.render.print_json(resp.roles[0].to_json() if resp.roles else None) + render.print_json(resp.roles[0].to_json() if resp.roles else None) return if resp.roles is None or len(resp.roles) != 1: @@ -263,9 +257,9 @@ def describe_role(args: Namespace) -> None: for group_assignment in group_assignments: workspace_id = group_assignment.roleAssignment.scopeWorkspaceId workspace_name = None - group_name = bindings.get_GetGroup(session, groupId=group_assignment.groupId).group.name + group_name = bindings.get_GetGroup(sess, groupId=group_assignment.groupId).group.name if workspace_id is not None: - workspace_name = bindings.get_GetWorkspace(session, id=workspace_id).workspace.name + workspace_name = bindings.get_GetWorkspace(sess, id=workspace_id).workspace.name output.append( { @@ -290,9 +284,9 @@ def describe_role(args: Namespace) -> None: for user_assignment in user_assignments: workspace_id = user_assignment.roleAssignment.scopeWorkspaceId workspace_name = None 
- username = bindings.get_GetUser(session, userId=user_assignment.userId).user.username + username = bindings.get_GetUser(sess, userId=user_assignment.userId).user.username if workspace_id is not None: - workspace_name = bindings.get_GetWorkspace(session, id=workspace_id).workspace.name + workspace_name = bindings.get_GetWorkspace(sess, id=workspace_id).workspace.name output.append( { @@ -340,8 +334,7 @@ def make_assign_req( return user_assign, group_assign -@authentication.required -@require_feature_flag("rbacEnabled", rbac_flag_disabled_message) +@cli.require_feature_flag("rbacEnabled", rbac_flag_disabled_message) def assign_role(args: Namespace) -> None: # Valid CLI usage is enforced before even creating a session. if (args.username_to_assign is None) == (args.group_name_to_assign is None): @@ -349,12 +342,12 @@ def assign_role(args: Namespace) -> None: "must provide exactly one of --username-to-assign or --group-name-to-assign" ) - session = setup_session(args) - user_assign, group_assign = make_assign_req(session, args) + sess = cli.setup_session(args) + user_assign, group_assign = make_assign_req(sess, args) req = bindings.v1AssignRolesRequest( userRoleAssignments=user_assign, groupRoleAssignments=group_assign ) - bindings.post_AssignRoles(session, body=req) + bindings.post_AssignRoles(sess, body=req) scope = " globally" if args.workspace_name: @@ -373,8 +366,7 @@ def assign_role(args: Namespace) -> None: ) -@authentication.required -@require_feature_flag("rbacEnabled", rbac_flag_disabled_message) +@cli.require_feature_flag("rbacEnabled", rbac_flag_disabled_message) def unassign_role(args: Namespace) -> None: # Valid CLI usage is enforced before even creating a session. 
if (args.username_to_assign is None) == (args.group_name_to_assign is None): @@ -382,12 +374,12 @@ def unassign_role(args: Namespace) -> None: "must provide exactly one of --username-to-assign or --group-name-to-assign" ) - session = setup_session(args) - user_assign, group_assign = make_assign_req(session, args) + sess = cli.setup_session(args) + user_assign, group_assign = make_assign_req(sess, args) req = bindings.v1RemoveAssignmentsRequest( userRoleAssignments=user_assign, groupRoleAssignments=group_assign ) - bindings.post_RemoveAssignments(session, body=req) + bindings.post_RemoveAssignments(sess, body=req) scope = " globally" if args.workspace_name: @@ -429,7 +421,7 @@ def unassign_role(args: Namespace) -> None: help="Ignore roles with global permissions", ), Arg("--json", action="store_true", help="print as JSON"), - *default_pagination_args, + *cli.make_pagination_args(), ], is_default=True, ), diff --git a/harness/determined/cli/resource_pool.py b/harness/determined/cli/resource_pool.py index bf3b4b85464..d9067fe2360 100644 --- a/harness/determined/cli/resource_pool.py +++ b/harness/determined/cli/resource_pool.py @@ -1,17 +1,18 @@ from argparse import ONE_OR_MORE, Namespace from typing import Any, List -from determined.cli import render, setup_session -from determined.common.api import authentication, bindings +from determined import cli +from determined.cli import render +from determined.common.api import bindings from determined.common.declarative_argparse import Arg, Cmd -@authentication.required def add_binding(args: Namespace) -> None: + sess = cli.setup_session(args) body = bindings.v1BindRPToWorkspaceRequest( resourcePoolName=args.pool_name, workspaceNames=args.workspace_names ) - bindings.post_BindRPToWorkspace(setup_session(args), body=body, resourcePoolName=args.pool_name) + bindings.post_BindRPToWorkspace(sess, body=body, resourcePoolName=args.pool_name) print( f'added bindings between the resource pool "{args.pool_name}" ' @@ -20,15 +21,13 @@ 
def add_binding(args: Namespace) -> None: return -@authentication.required def remove_binding(args: Namespace) -> None: + sess = cli.setup_session(args) body = bindings.v1UnbindRPFromWorkspaceRequest( resourcePoolName=args.pool_name, workspaceNames=args.workspace_names, ) - bindings.delete_UnbindRPFromWorkspace( - setup_session(args), body=body, resourcePoolName=args.pool_name - ) + bindings.delete_UnbindRPFromWorkspace(sess, body=body, resourcePoolName=args.pool_name) print( f'removed bindings between the resource pool "{args.pool_name}" ' @@ -37,15 +36,13 @@ def remove_binding(args: Namespace) -> None: return -@authentication.required def replace_bindings(args: Namespace) -> None: + sess = cli.setup_session(args) body = bindings.v1OverwriteRPWorkspaceBindingsRequest( resourcePoolName=args.pool_name, workspaceNames=args.workspace_names, ) - bindings.put_OverwriteRPWorkspaceBindings( - setup_session(args), body=body, resourcePoolName=args.pool_name - ) + bindings.put_OverwriteRPWorkspaceBindings(sess, body=body, resourcePoolName=args.pool_name) print( f'replaced bindings of the resource pool "{args.pool_name}" ' @@ -54,17 +51,16 @@ def replace_bindings(args: Namespace) -> None: return -@authentication.required def list_workspaces(args: Namespace) -> None: - session = setup_session(args) - resp = bindings.get_ListWorkspacesBoundToRP(session, resourcePoolName=args.pool_name) + sess = cli.setup_session(args) + resp = bindings.get_ListWorkspacesBoundToRP(sess, resourcePoolName=args.pool_name) workspace_names = "" if resp.workspaceIds: workspace_names = ", ".join( [ workspace.name - for workspace in bindings.get_GetWorkspaces(session).workspaces + for workspace in bindings.get_GetWorkspaces(sess).workspaces if workspace.id in set(resp.workspaceIds) ] ) diff --git a/harness/determined/cli/resources.py b/harness/determined/cli/resources.py index 2f10cc117dc..05ed124d8c7 100644 --- a/harness/determined/cli/resources.py +++ b/harness/determined/cli/resources.py @@ -3,8 +3,7 
@@ import requests -from determined.common import api -from determined.common.api import authentication +from determined import cli from determined.common.declarative_argparse import Arg, ArgsDescription, Cmd @@ -14,15 +13,15 @@ def print_response(r: requests.Response) -> None: sys.stdout.buffer.write(chunk) -@authentication.required def raw(args: Namespace) -> None: + sess = cli.setup_session(args) params = {"timestamp_after": args.timestamp_after, "timestamp_before": args.timestamp_before} path = "api/v1/resources/allocation/raw" if args.json else "resources/allocation/raw" - print_response(api.get(args.master, path, params=params)) + print_response(sess.get(path, params=params)) -@authentication.required def aggregated(args: Namespace) -> None: + sess = cli.setup_session(args) params = { "start_date": args.start_date, "end_date": args.end_date, @@ -33,7 +32,7 @@ def aggregated(args: Namespace) -> None: path = ( "api/v1/resources/allocation/aggregated" if args.json else "resources/allocation/aggregated" ) - print_response(api.get(args.master, path, params=params)) + print_response(sess.get(path, params=params)) args_description: ArgsDescription = [ diff --git a/harness/determined/cli/shell.py b/harness/determined/cli/shell.py index 75eabd818f4..e186be68435 100644 --- a/harness/determined/cli/shell.py +++ b/harness/determined/cli/shell.py @@ -17,12 +17,12 @@ from determined import cli from determined.cli import ntsc, render, task from determined.common import api -from determined.common.api import authentication, bindings, certs +from determined.common.api import bindings from determined.common.declarative_argparse import Arg, ArgsDescription, Cmd, Group -@authentication.required def start_shell(args: argparse.Namespace) -> None: + sess = cli.setup_session(args) data = {} if args.passphrase: data["passphrase"] = getpass.getpass("Enter new passphrase: ") @@ -30,7 +30,7 @@ def start_shell(args: argparse.Namespace) -> None: workspace_id = 
cli.workspace.get_workspace_id_from_args(args) resp = ntsc.launch_command( - args.master, + sess, "api/v1/shells", config, args.template, @@ -48,12 +48,9 @@ def start_shell(args: argparse.Namespace) -> None: render.report_job_launched("shell", sid) - session = cli.setup_session(args) - - shell = bindings.get_GetShell(session, shellId=sid).shell + shell = bindings.get_GetShell(sess, shellId=sid).shell _open_shell( - session, - args.master, + sess, shell.to_json(), args.ssh_opts, retain_keys_and_print=args.show_ssh_command, @@ -61,14 +58,12 @@ def start_shell(args: argparse.Namespace) -> None: ) -@authentication.required def open_shell(args: argparse.Namespace) -> None: - shell_id = cast(str, ntsc.expand_uuid_prefixes(args)) - - shell = api.get(args.master, f"api/v1/shells/{shell_id}").json()["shell"] + sess = cli.setup_session(args) + shell_id = cast(str, ntsc.expand_uuid_prefixes(sess, args)) + shell = sess.get(f"api/v1/shells/{shell_id}").json()["shell"] _open_shell( - cli.setup_session(args), - args.master, + sess, shell, args.ssh_opts, retain_keys_and_print=args.show_ssh_command, @@ -76,7 +71,6 @@ def open_shell(args: argparse.Namespace) -> None: ) -@authentication.required def show_ssh_command(args: argparse.Namespace) -> None: if platform.system() == "Linux" and "WSL" in os.uname().release: cli.warn( @@ -85,11 +79,11 @@ def show_ssh_command(args: argparse.Namespace) -> None: "command in a Windows shell. For PyCharm users, configure the Pycharm " "ssh command to target the WSL ssh command." 
) - shell_id = ntsc.expand_uuid_prefixes(args) - shell = api.get(args.master, f"api/v1/shells/{shell_id}").json()["shell"] + sess = cli.setup_session(args) + shell_id = ntsc.expand_uuid_prefixes(sess, args) + shell = sess.get(f"api/v1/shells/{shell_id}").json()["shell"] _open_shell( - cli.setup_session(args), - args.master, + sess, shell, args.ssh_opts, retain_keys_and_print=True, @@ -141,8 +135,8 @@ def file_closer() -> Iterator[IO]: def _prepare_cert_bundle(retention_dir: Union[Path, None]) -> Union[str, bool, None]: - cert = certs.cli_cert - assert cert is not None, "cli_cert was not configured" + cert = cli.cert + assert cert is not None, "cli.cert was not configured" if retention_dir and isinstance(cert.bundle, str): retained_cert_bundle_path = retention_dir / "cert_bundle" shutil.copy2(str(cert.bundle), retained_cert_bundle_path) @@ -152,7 +146,6 @@ def _prepare_cert_bundle(retention_dir: Union[Path, None]) -> Union[str, bool, N def _open_shell( sess: api.Session, - master: str, shell: Dict[str, Any], additional_opts: List[str], retain_keys_and_print: bool, @@ -173,7 +166,7 @@ def _open_shell( # Use determined.cli.tunnel as a portable script for using the HTTP CONNECT mechanism, # similar to `nc -X CONNECT -x ...` but without any dependency on external binaries. 
- proxy_cmd = f"{sys.executable} -m determined.cli.tunnel {master} %h" + proxy_cmd = f"{sys.executable} -m determined.cli.tunnel {sess.master} %h" cert_bundle_path = _prepare_cert_bundle(cache_dir) if cert_bundle_path is False: @@ -186,8 +179,8 @@ def _open_shell( f"of type ({type(cert_bundle_path).__name__})" ) - cert = certs.cli_cert - assert cert is not None, "cli_cert was not configured" + cert = cli.cert + assert cert is not None, "cli.cert was not configured" if cert.name: proxy_cmd += f' --cert-name "{cert.name}"' diff --git a/harness/determined/cli/sso.py b/harness/determined/cli/sso.py index 026b6b502e9..740313091f2 100644 --- a/harness/determined/cli/sso.py +++ b/harness/determined/cli/sso.py @@ -6,6 +6,7 @@ from typing import Any, Callable, List from urllib.parse import parse_qs, urlparse +from determined import cli from determined.common import api from determined.common.api import authentication from determined.common.declarative_argparse import Arg, Cmd @@ -14,9 +15,9 @@ CLI_REDIRECT_PORT = 49176 -def handle_token(master_url: str, token: str) -> None: +def handle_token(sess: api.BaseSession, master_url: str, token: str) -> None: tmp_auth = {"Cookie": "auth={token}".format(token=token)} - me = api.get(master_url, "/users/me", headers=tmp_auth, authenticated=False).json() + me = sess.get("/users/me", headers=tmp_auth).json() token_store = authentication.TokenStore(master_url) token_store.set_token(me["username"], token) @@ -25,13 +26,13 @@ def handle_token(master_url: str, token: str) -> None: print("Authenticated as {}.".format(me["username"])) -def make_handler(master_url: str, close_cb: Callable[[int], None]) -> Any: +def make_handler(sess: api.BaseSession, master_url: str, close_cb: Callable[[int], None]) -> Any: class TokenAcceptHandler(BaseHTTPRequestHandler): def do_GET(self) -> None: try: """Serve a GET request.""" token = parse_qs(urlparse(self.path).query)["token"][0] - handle_token(master_url, token) + handle_token(sess, master_url, token) 
self.send_response(200) self.send_header("Content-type", "text/html") @@ -49,8 +50,9 @@ def log_message(self, format: Any, *args: List[Any]) -> None: # noqa: A002 return TokenAcceptHandler -def sso(parsed_args: Namespace) -> None: - master_info = api.get(parsed_args.master, "info", authenticated=False).json() +def sso(args: Namespace) -> None: + sess = cli.unauth_session(args) + master_info = sess.get("info").json() try: sso_providers = master_info["sso_providers"] except KeyError: @@ -58,27 +60,27 @@ def sso(parsed_args: Namespace) -> None: if not sso_providers: print("No SSO providers found.") return - elif not parsed_args.provider: + elif not args.provider: if len(sso_providers) > 1: print("Provider must be specified when multiple are available.") return matched_provider = sso_providers[0] else: matching_providers = [ - p for p in sso_providers if p["name"].lower() == parsed_args.provider.lower() + p for p in sso_providers if p["name"].lower() == args.provider.lower() ] if not matching_providers: ps = ", ".join(p["name"].lower() for p in sso_providers) - print("Provider {} unsupported. (Providers found: {})".format(parsed_args.provider, ps)) + print("Provider {} unsupported. 
(Providers found: {})".format(args.provider, ps)) return elif len(matching_providers) > 1: - print("Multiple SSO providers found with name {}.".format(parsed_args.provider)) + print("Multiple SSO providers found with name {}.".format(args.provider)) return matched_provider = matching_providers[0] sso_url = matched_provider["sso_url"] + "?relayState=cli" - if not parsed_args.headless: + if not args.headless: if webbrowser.open(sso_url): print( "Your browser should open and prompt you to sign on;" @@ -87,7 +89,7 @@ def sso(parsed_args: Namespace) -> None: print("Killing this process before signing on will cancel authentication.") with HTTPServer( ("localhost", CLI_REDIRECT_PORT), - make_handler(parsed_args.master, lambda code: sys.exit(code)), + make_handler(sess, args.master, lambda code: sys.exit(code)), ) as httpd: return httpd.serve_forever() @@ -105,13 +107,14 @@ def sso(parsed_args: Namespace) -> None: user_input_url = getpass(prompt="\n(hidden) localhost URL? ") try: token = parse_qs(urlparse(user_input_url).query)["token"][0] - handle_token(parsed_args.master, token) + handle_token(sess, args.master, token) except (KeyError, IndexError): print(f"Could not extract token from localhost URL. 
{example_url}") -def list_providers(parsed_args: Namespace) -> None: - master_info = api.get(parsed_args.master, "info", authenticated=False).json() +def list_providers(args: Namespace) -> None: + sess = cli.unauth_session(args) + master_info = sess.get("info").json() try: sso_providers = master_info["sso_providers"] diff --git a/harness/determined/cli/task.py b/harness/determined/cli/task.py index 76895175f0e..7f420b927a9 100644 --- a/harness/determined/cli/task.py +++ b/harness/determined/cli/task.py @@ -9,7 +9,7 @@ from determined import cli from determined.cli import ntsc, render from determined.common import api, context, util -from determined.common.api import authentication, bindings +from determined.common.api import bindings from determined.common.api.bindings import v1AllocationSummary, v1CreateGenericTaskResponse from determined.common.declarative_argparse import Arg, Cmd, Group @@ -71,19 +71,19 @@ def agent_info(t: v1AllocationSummary) -> Union[str, List[str]]: render.tabulate_or_csv(headers, values, args.csv) -@authentication.required def list_tasks(args: Namespace) -> None: - r = bindings.get_GetTasks(cli.setup_session(args)) + sess = cli.setup_session(args) + r = bindings.get_GetTasks(sess) tasks = r.allocationIdToSummary or {} render_tasks(args, tasks) -@authentication.required def logs(args: Namespace) -> None: - task_id = cast(str, ntsc.expand_uuid_prefixes(args, args.task_id)) + sess = cli.setup_session(args) + task_id = cast(str, ntsc.expand_uuid_prefixes(sess, args, args.task_id)) try: logs = api.task_logs( - cli.setup_session(args), + sess, task_id, head=args.head, tail=args.tail, @@ -112,7 +112,6 @@ def logs(args: Namespace) -> None: ) -@authentication.required def kill(args: Namespace) -> None: sess = cli.setup_session(args) req = bindings.v1KillGenericTaskRequest(taskId=args.task_id, killFromRoot=args.root) @@ -142,13 +141,12 @@ def task_creation_output( ) -@authentication.required def create(args: Namespace) -> None: + sess = 
cli.setup_session(args) config = ntsc.parse_config(args.config_file, None, args.config, []) config_text = util.yaml_safe_dump(config) context_directory = context.read_v1_context(args.context, args.include) - sess = cli.setup_session(args) req = bindings.v1CreateGenericTaskRequest( config=config_text, contextDirectory=context_directory, @@ -162,7 +160,6 @@ def create(args: Namespace) -> None: task_creation_output(session=sess, task_resp=task_resp, follow=args.follow) -@authentication.required def config(args: Namespace) -> None: sess = cli.setup_session(args) config_resp = bindings.get_GetGenericTaskConfig(sess, taskId=args.task_id) @@ -173,7 +170,6 @@ def config(args: Namespace) -> None: print(util.yaml_safe_dump(yaml_dict, default_flow_style=False)) -@authentication.required def fork(args: Namespace) -> None: sess = cli.setup_session(args) req = bindings.v1CreateGenericTaskRequest( @@ -187,14 +183,12 @@ def fork(args: Namespace) -> None: task_creation_output(session=sess, task_resp=task_resp, follow=args.follow) -@authentication.required def pause(args: Namespace) -> None: sess = cli.setup_session(args) bindings.post_PauseGenericTask(sess, taskId=args.task_id) print(f"Paused task: {args.task_id}") -@authentication.required def unpause(args: Namespace) -> None: sess = cli.setup_session(args) bindings.post_UnpauseGenericTask(sess, taskId=args.task_id) diff --git a/harness/determined/cli/template.py b/harness/determined/cli/template.py index 4a8d7c52c2b..ccccb165ba3 100644 --- a/harness/determined/cli/template.py +++ b/harness/determined/cli/template.py @@ -8,7 +8,7 @@ from determined.cli import render from determined.cli.workspace import get_workspace_id_from_args, workspace_arg from determined.common import api, util -from determined.common.api import authentication, bindings +from determined.common.api import bindings from determined.common.declarative_argparse import Arg, Cmd TemplateClean = namedtuple("TemplateClean", ["name", "workspace"]) @@ -20,12 +20,12 @@ 
def _parse_config(data: Dict[str, Any]) -> Any: return util.yaml_safe_dump(data, default_flow_style=False) -@authentication.required def list_template(args: Namespace) -> None: + sess = cli.setup_session(args) templates: List[TemplateAll] = [] - w_names = cli.workspace.get_workspace_names(cli.setup_session(args)) + w_names = cli.workspace.get_workspace_names(sess) - for tpl in bindings.get_GetTemplates(cli.setup_session(args)).templates: + for tpl in bindings.get_GetTemplates(sess).templates: w_name = w_names.get(tpl.workspaceId, "missing workspace") templates.append(TemplateAll(tpl.name, w_name, _parse_config(tpl.config))) if args.details: @@ -34,61 +34,54 @@ def list_template(args: Namespace) -> None: render.render_objects(TemplateClean, templates) -@authentication.required def describe_template(args: Namespace) -> None: - tpl = bindings.get_GetTemplate( - cli.setup_session(args), templateName=args.template_name - ).template + sess = cli.setup_session(args) + tpl = bindings.get_GetTemplate(sess, templateName=args.template_name).template print(_parse_config(tpl.config)) -@authentication.required def set_template(args: Namespace) -> None: with args.template_file: """ WARN: this downgrades the atomic behavior of upsert but it's an acceptable tradeoff for now until we can remove this command. 
""" - session = cli.setup_session(args) + sess = cli.setup_session(args) body = util.safe_load_yaml_with_exceptions(args.template_file) try: - bindings.get_GetTemplate(session, templateName=args.template_name).template - bindings.patch_PatchTemplateConfig(session, templateName=args.template_name, body=body) + bindings.get_GetTemplate(sess, templateName=args.template_name).template + bindings.patch_PatchTemplateConfig(sess, templateName=args.template_name, body=body) except api.errors.NotFoundException: v1_template = bindings.v1Template(name=args.template_name, config=body, workspaceId=0) - bindings.post_PostTemplate(session, template_name=args.template_name, body=v1_template) + bindings.post_PostTemplate(sess, template_name=args.template_name, body=v1_template) print(colored("Set template {}".format(args.template_name), "green")) -@authentication.required def create_template(args: Namespace) -> None: if not args.template_file: raise ArgumentError(None, "template_file is required for set command") + sess = cli.setup_session(args) body = util.safe_load_yaml_with_exceptions(args.template_file) workspace_id = get_workspace_id_from_args(args) or 0 v1_template = bindings.v1Template( name=args.template_name, config=body, workspaceId=workspace_id ) - bindings.post_PostTemplate( - cli.setup_session(args), template_name=args.template_name, body=v1_template - ) + bindings.post_PostTemplate(sess, template_name=args.template_name, body=v1_template) print(colored("Created template {}".format(args.template_name), "green")) -@authentication.required def patch_template_config(args: Namespace) -> None: if not args.template_file: raise ArgumentError(None, "template_file is required for set command") + sess = cli.setup_session(args) body = util.safe_load_yaml_with_exceptions(args.template_file) - bindings.patch_PatchTemplateConfig( - cli.setup_session(args), templateName=args.template_name, body=body - ) + bindings.patch_PatchTemplateConfig(sess, templateName=args.template_name, 
body=body) print(colored("Updated template {}".format(args.template_name), "green")) -@authentication.required def remove_templates(args: Namespace) -> None: - bindings.delete_DeleteTemplate(cli.setup_session(args), templateName=args.template_name) + sess = cli.setup_session(args) + bindings.delete_DeleteTemplate(sess, templateName=args.template_name) print(colored("Removed template {}".format(args.template_name), "green")) diff --git a/harness/determined/cli/tensorboard.py b/harness/determined/cli/tensorboard.py index 460f4d67198..ac1ba0dfb32 100644 --- a/harness/determined/cli/tensorboard.py +++ b/harness/determined/cli/tensorboard.py @@ -8,13 +8,13 @@ from determined import cli from determined.cli import ntsc, render, task from determined.common import api, context -from determined.common.api import authentication, bindings, request +from determined.common.api import bindings, request from determined.common.check import check_none from determined.common.declarative_argparse import Arg, ArgsDescription, Cmd, Group -@authentication.required def start_tensorboard(args: Namespace) -> None: + sess = cli.setup_session(args) if not (args.trial_ids or args.experiment_ids): raise ArgumentError(None, "Either experiment_ids or trial_ids must be specified.") @@ -30,7 +30,7 @@ def start_tensorboard(args: Namespace) -> None: workspaceId=workspace_id, ) - resp = bindings.post_LaunchTensorboard(cli.setup_session(args), body=body) + resp = bindings.post_LaunchTensorboard(sess, body=body) tsb = resp.tensorboard if args.detach: @@ -44,7 +44,7 @@ def start_tensorboard(args: Namespace) -> None: currentSlotsExceeded = (resp.warnings is not None) and ( bindings.v1LaunchWarning.CURRENT_SLOTS_EXCEEDED in resp.warnings ) - cli.wait_ntsc_ready(cli.setup_session(args), api.NTSC_Kind.tensorboard, tsb.id) + cli.wait_ntsc_ready(sess, api.NTSC_Kind.tensorboard, tsb.id) assert tsb.serviceAddress is not None, "missing tensorboard serviceAddress" nb_path = request.make_interactive_task_url( @@ 
-61,17 +61,15 @@ def start_tensorboard(args: Namespace) -> None: print(colored("Tensorboard is running at: {}".format(url), "green")) -@authentication.required def open_tensorboard(args: Namespace) -> None: - tensorboard_id = cast(str, ntsc.expand_uuid_prefixes(args)) - sess = cli.setup_session(args) + tensorboard_id = cast(str, ntsc.expand_uuid_prefixes(sess, args)) + task = bindings.get_GetTask(sess, taskId=tensorboard_id).task check_none(task.endTime, "Tensorboard has ended") tsb = bindings.get_GetTensorboard(sess, tensorboardId=tensorboard_id).tensorboard assert tsb.serviceAddress is not None, "missing tensorboard serviceAddress" - api.browser_open( args.master, request.make_interactive_task_url( diff --git a/harness/determined/cli/trial.py b/harness/determined/cli/trial.py index c428d24538d..bc2fba0e736 100644 --- a/harness/determined/cli/trial.py +++ b/harness/determined/cli/trial.py @@ -12,7 +12,7 @@ from determined.cli import errors, render from determined.cli.master import format_log_entry from determined.common import api -from determined.common.api import authentication, bindings +from determined.common.api import bindings from determined.common.declarative_argparse import Arg, ArgsDescription, Cmd, Group, string_to_bool from determined.experimental import client @@ -91,7 +91,6 @@ def _workloads_tabulate( return headers, values -@authentication.required def describe_trial(args: Namespace) -> None: session = cli.setup_session(args) @@ -148,7 +147,8 @@ def get_with_offset(offset: int) -> bindings.v1GetTrialWorkloadsResponse: def download(args: Namespace) -> None: - det = client.Determined(args.master, args.user) + sess = cli.setup_session(args) + det = client.Determined._from_session(sess) if [args.latest, args.best, args.uuid].count(True) != 1: raise ValueError("exactly one of --latest, --best, or --uuid must be set") @@ -198,17 +198,17 @@ def download(args: Namespace) -> None: render_checkpoint(checkpoint, path) -@authentication.required def 
kill_trial(args: Namespace) -> None: - api.post(args.master, "/api/v1/trials/{}/kill".format(args.trial_id)) - print("Killed trial {}".format(args.trial_id)) + sess = cli.setup_session(args) + sess.post(f"/api/v1/trials/{args.trial_id}/kill") + print("Killed trial", args.trial_id) -@authentication.required def trial_logs(args: Namespace) -> None: + sess = cli.setup_session(args) try: logs = api.trial_logs( - cli.setup_session(args), + sess, args.trial_id, head=args.head, tail=args.tail, @@ -237,8 +237,8 @@ def trial_logs(args: Namespace) -> None: ) -@authentication.required def generate_support_bundle(args: Namespace) -> None: + sess = cli.setup_session(args) try: output_dir = args.output_dir if output_dir is None: @@ -249,9 +249,9 @@ def generate_support_bundle(args: Namespace) -> None: fullpath = os.path.join(output_dir, f"{bundle_name}.tar.gz") with tempfile.TemporaryDirectory() as temp_dir, tarfile.open(fullpath, "w:gz") as bundle: - trial_logs_filepath = write_trial_logs(args, temp_dir) - master_logs_filepath = write_master_logs(args, temp_dir) - api_experiment_filepath, api_trail_filepath = write_api_call(args, temp_dir) + trial_logs_filepath = write_trial_logs(sess, args, temp_dir) + master_logs_filepath = write_master_logs(sess, args, temp_dir) + api_experiment_filepath, api_trail_filepath = write_api_call(sess, args, temp_dir) bundle.add( trial_logs_filepath, @@ -276,9 +276,8 @@ def generate_support_bundle(args: Namespace) -> None: print("Could not create the bundle because the output_dir provived was not found.") -def write_trial_logs(args: Namespace, temp_dir: str) -> str: - session = cli.setup_session(args) - trial_logs = api.trial_logs(session, args.trial_id) +def write_trial_logs(sess: api.Session, args: Namespace, temp_dir: str) -> str: + trial_logs = api.trial_logs(sess, args.trial_id) file_path = os.path.join(temp_dir, "trial_logs.txt") with open(file_path, "w") as f: for log in trial_logs: @@ -287,8 +286,8 @@ def write_trial_logs(args: Namespace, 
temp_dir: str) -> str: return file_path -def write_master_logs(args: Namespace, temp_dir: str) -> str: - responses = bindings.get_MasterLogs(cli.setup_session(args)) +def write_master_logs(sess: api.Session, args: Namespace, temp_dir: str) -> str: + responses = bindings.get_MasterLogs(sess) file_path = os.path.join(temp_dir, "master_logs.txt") with open(file_path, "w") as f: for response in responses: @@ -296,13 +295,13 @@ def write_master_logs(args: Namespace, temp_dir: str) -> str: return file_path -def write_api_call(args: Namespace, temp_dir: str) -> Tuple[str, str]: +def write_api_call(sess: api.Session, args: Namespace, temp_dir: str) -> Tuple[str, str]: api_experiment_filepath = os.path.join(temp_dir, "api_experiment_call.json") api_trial_filepath = os.path.join(temp_dir, "api_trial_call.json") - trial_obj = bindings.get_GetTrial(cli.setup_session(args), trialId=args.trial_id).trial + trial_obj = bindings.get_GetTrial(sess, trialId=args.trial_id).trial experiment_id = trial_obj.experimentId - exp_obj = bindings.get_GetExperiment(cli.setup_session(args), experimentId=experiment_id) + exp_obj = bindings.get_GetExperiment(sess, experimentId=experiment_id) create_json_file_in_dir(exp_obj.to_json(), api_experiment_filepath) create_json_file_in_dir(trial_obj.to_json(), api_trial_filepath) diff --git a/harness/determined/cli/tunnel.py b/harness/determined/cli/tunnel.py index 52808f0e049..c1ca57857e5 100644 --- a/harness/determined/cli/tunnel.py +++ b/harness/determined/cli/tunnel.py @@ -6,7 +6,8 @@ import argparse import time -from determined.common.api import authentication +from determined.common import api +from determined.common.api import authentication, certs from .proxy import ListenerConfig, http_connect_tunnel, http_tunnel_listener @@ -21,25 +22,27 @@ parser.add_argument("--auth", action="store_true") args = parser.parse_args() - authorization_token = None - if args.auth: - auth = authentication.Authentication(args.master_addr, args.user) - 
authorization_token = auth.get_session_token(must=True) + if args.cert_file == "noverify": + # The special string "noverify" means to not even check the TLS cert. + cert_file = None + noverify = True + else: + cert_file = args.cert_file + noverify = False + + cert = certs.default_load(args.master_addr, cert_file, args.cert_name, noverify) - # The special string "noverify" is passed to our certs.Cert object as a boolean False. - cert_file = False if args.cert_file == "noverify" else args.cert_file + if args.auth: + utp = authentication.login_with_cache(args.master_addr, args.user, cert=cert) + sess: api.BaseSession = api.Session(args.master_addr, utp, cert) + else: + sess = api.UnauthSession(args.master_addr, cert) if args.listener: with http_tunnel_listener( - args.master_addr, - [ListenerConfig(service_id=args.service_uuid, local_port=args.listener)], - cert_file, - args.cert_name, - authorization_token, + sess, [ListenerConfig(service_id=args.service_uuid, local_port=args.listener)] ): while True: time.sleep(1) else: - http_connect_tunnel( - args.master_addr, args.service_uuid, cert_file, args.cert_name, authorization_token - ) + http_connect_tunnel(sess, args.service_uuid) diff --git a/harness/determined/cli/user.py b/harness/determined/cli/user.py index edc93d6c4dc..267b033ce5f 100644 --- a/harness/determined/cli/user.py +++ b/harness/determined/cli/user.py @@ -3,12 +3,12 @@ from collections import namedtuple from typing import Any, List -from determined.cli import errors, login_sdk_client, render, setup_session +from determined import cli +from determined.cli import errors, render from determined.common import api -from determined.common.api import authentication, bindings, certs +from determined.common.api import authentication, bindings from determined.common.declarative_argparse import Arg, Cmd, string_to_bool from determined.experimental import client -from determined.experimental.client import Determined FullUser = namedtuple( "FullUser", @@ -41,65 +41,74 
@@ ) -@login_sdk_client def list_users(args: Namespace) -> None: - resp = bindings.get_GetMaster(setup_session(args)) - users_list = client.list_users(active=None if args.all else True) + sess = cli.setup_session(args) + d = client.Determined._from_session(sess) + resp = bindings.get_GetMaster(sess) + users_list = d.list_users(active=None if args.all else True) renderer = FullUser # type: Any if resp.to_json().get("rbacEnabled"): renderer = FullUserNoAdmin render.render_objects(renderer, users_list) -@login_sdk_client -def activate_user(parsed_args: Namespace) -> None: - user_obj = client.get_user_by_name(parsed_args.username) +def activate_user(args: Namespace) -> None: + sess = cli.setup_session(args) + d = client.Determined._from_session(sess) + user_obj = d.get_user_by_name(args.username) user_obj.activate() -@login_sdk_client -def deactivate_user(parsed_args: Namespace) -> None: - user_obj = client.get_user_by_name(parsed_args.username) +def deactivate_user(args: Namespace) -> None: + sess = cli.setup_session(args) + d = client.Determined._from_session(sess) + user_obj = d.get_user_by_name(args.username) user_obj.deactivate() -def log_in_user(parsed_args: Namespace) -> None: - if parsed_args.username is None: +def log_in_user(args: Namespace) -> None: + if args.username is None: username = input("Username: ") else: - username = parsed_args.username + username = args.username message = "Password for user '{}': ".format(username) password = getpass.getpass(message) - token_store = authentication.TokenStore(parsed_args.master) - token = authentication.do_login(parsed_args.master, username, password, certs.cli_cert) - token_store.set_token(username, token) - token_store.set_active(username) + token_store = authentication.TokenStore(args.master) + utp = authentication.login(args.master, username, password, cli.cert) + token_store.set_token(utp.username, utp.token) + token_store.set_active(utp.username) -def log_out_user(parsed_args: Namespace) -> None: - if 
parsed_args.all: - authentication.logout_all(parsed_args.master, certs.cli_cert) +def log_out_user(args: Namespace) -> None: + token_store = authentication.TokenStore(args.master) + if args.all: + authentication.logout_all(args.master, cli.cert) + token_store.clear_active() else: # Log out of the user specified by the command line, or the active user. - authentication.logout(parsed_args.master, parsed_args.user, certs.cli_cert) + logged_out_user = authentication.logout(args.master, args.user, cli.cert) + if logged_out_user and token_store.get_active_user() == logged_out_user: + token_store.clear_active() -@login_sdk_client -def rename(parsed_args: Namespace) -> None: - user_obj = client.get_user_by_name(parsed_args.target_user) - user_obj.rename(new_username=parsed_args.new_username) +def rename(args: Namespace) -> None: + sess = cli.setup_session(args) + d = client.Determined._from_session(sess) + user_obj = d.get_user_by_name(args.target_user) + user_obj.rename(new_username=args.new_username) -@login_sdk_client -def change_password(parsed_args: Namespace) -> None: - if parsed_args.target_user: - username = parsed_args.target_user - elif parsed_args.user: - username = parsed_args.user +def change_password(args: Namespace) -> None: + sess = cli.setup_session(args) + d = client.Determined._from_session(sess) + if args.target_user: + username = args.target_user + elif args.user: + username = args.user else: - username = client.get_session_username() + username = d.get_session_username() if not username: # The default user should have been set by now by autologin. 
@@ -111,81 +120,83 @@ def change_password(parsed_args: Namespace) -> None: if password != check_password: raise errors.CliError("Passwords do not match") - user_obj = client.get_user_by_name(username) + user_obj = d.get_user_by_name(username) user_obj.change_password(new_password=password) # If the target user's password isn't being changed by another user, reauthenticate after # password change so that the user doesn't have to do so manually. - if parsed_args.target_user is None: - token_store = authentication.TokenStore(parsed_args.master) - token = authentication.do_login(parsed_args.master, username, password, certs.cli_cert) - token_store.set_token(username, token) - token_store.set_active(username) + if args.target_user is None: + token_store = authentication.TokenStore(args.master) + utp = authentication.login(args.master, username, password, cli.cert) + token_store.set_token(utp.username, utp.token) + token_store.set_active(utp.username) -@login_sdk_client -def link_with_agent_user(parsed_args: Namespace) -> None: - if parsed_args.agent_uid is None: +def link_with_agent_user(args: Namespace) -> None: + if args.agent_uid is None: raise api.errors.BadRequestException("agent-uid argument required") - elif parsed_args.agent_user is None: + elif args.agent_user is None: raise api.errors.BadRequestException("agent-user argument required") - elif parsed_args.agent_gid is None: + elif args.agent_gid is None: raise api.errors.BadRequestException("agent-gid argument required") - elif parsed_args.agent_group is None: + elif args.agent_group is None: raise api.errors.BadRequestException("agent-group argument required") - user_obj = client.get_user_by_name(parsed_args.det_username) + sess = cli.setup_session(args) + d = client.Determined._from_session(sess) + user_obj = d.get_user_by_name(args.det_username) user_obj.link_with_agent( - agent_gid=parsed_args.agent_gid, - agent_group=parsed_args.agent_group, - agent_uid=parsed_args.agent_uid, - 
agent_user=parsed_args.agent_user, + agent_gid=args.agent_gid, + agent_group=args.agent_group, + agent_uid=args.agent_uid, + agent_user=args.agent_user, ) -@login_sdk_client -def create_user(parsed_args: Namespace) -> None: - username = parsed_args.username - admin = bool(parsed_args.admin) - remote = bool(parsed_args.remote) - client.create_user(username=username, admin=admin, remote=remote) +def create_user(args: Namespace) -> None: + sess = cli.setup_session(args) + d = client.Determined._from_session(sess) + username = args.username + admin = bool(args.admin) + remote = bool(args.remote) + d.create_user(username=username, admin=admin, remote=remote) -@login_sdk_client -def whoami(parsed_args: Namespace) -> None: - user = client.whoami() +def whoami(args: Namespace) -> None: + sess = cli.setup_session(args) + d = client.Determined._from_session(sess) + user = d.whoami() print("You are logged in as user '{}'".format(user.username)) -@authentication.required -def edit(parsed_args: Namespace) -> None: - session = setup_session(parsed_args) - det = Determined._from_session(session) - user_obj = det.get_user_by_name(parsed_args.target_user) +def edit(args: Namespace) -> None: + sess = cli.setup_session(args) + d = client.Determined._from_session(sess) + user_obj = d.get_user_by_name(args.target_user) changes = [] patch_user = bindings.v1PatchUser() - if parsed_args.display_name is not None: - patch_user.displayName = parsed_args.display_name + if args.display_name is not None: + patch_user.displayName = args.display_name changes.append("Display Name") - if parsed_args.remote is not None: - patch_user.remote = parsed_args.remote + if args.remote is not None: + patch_user.remote = args.remote changes.append("Remote") - if parsed_args.activate is not None: - patch_user.active = parsed_args.activate + if args.activate is not None: + patch_user.active = args.activate changes.append("Active") - if parsed_args.username is not None: - patch_user.username = 
parsed_args.username + if args.username is not None: + patch_user.username = args.username changes.append("Username") - if parsed_args.admin is not None: - patch_user.admin = parsed_args.admin + if args.admin is not None: + patch_user.admin = args.admin changes.append("Admin") if len(changes) > 0: - bindings.patch_PatchUser(session=session, body=patch_user, userId=user_obj.user_id) + bindings.patch_PatchUser(sess, body=patch_user, userId=user_obj.user_id) print("Changes made to the following fields: " + ", ".join(changes)) else: raise errors.CliError("No field provided. Use 'det user edit -h' for usage.") diff --git a/harness/determined/cli/user_groups.py b/harness/determined/cli/user_groups.py index d6e2c07b55c..66101ed8910 100644 --- a/harness/determined/cli/user_groups.py +++ b/harness/determined/cli/user_groups.py @@ -2,10 +2,10 @@ from collections import namedtuple from typing import Any, List -import determined.cli.render -from determined.cli import default_pagination_args, render, require_feature_flag, setup_session +from determined import cli +from determined.cli import render from determined.common import api -from determined.common.api import authentication, bindings +from determined.common.api import bindings from determined.common.declarative_argparse import Arg, Cmd v1UserHeaders = namedtuple( @@ -24,13 +24,12 @@ ) -@authentication.required -@require_feature_flag("rbacEnabled", rbac_flag_disabled_message) +@cli.require_feature_flag("rbacEnabled", rbac_flag_disabled_message) def create_group(args: Namespace) -> None: - session = setup_session(args) - add_users = api.usernames_to_user_ids(session, args.add_user) + sess = cli.setup_session(args) + add_users = api.usernames_to_user_ids(sess, args.add_user) body = bindings.v1CreateGroupRequest(name=args.group_name, addUsers=add_users) - resp = bindings.post_CreateGroup(session, body=body) + resp = bindings.post_CreateGroup(sess, body=body) group = resp.group print(f"user group with name {group.name} and ID 
{group.groupId} created") @@ -38,10 +37,9 @@ def create_group(args: Namespace) -> None: print(f"{', '.join([g.username for g in group.users])} was added to the group") -@authentication.required -@require_feature_flag("rbacEnabled", rbac_flag_disabled_message) +@cli.require_feature_flag("rbacEnabled", rbac_flag_disabled_message) def list_groups(args: Namespace) -> None: - sess = setup_session(args) + sess = cli.setup_session(args) user_id = None if args.groups_user_belongs_to: user_id = api.usernames_to_user_ids(sess, [args.groups_user_belongs_to])[0] @@ -49,7 +47,7 @@ def list_groups(args: Namespace) -> None: body = bindings.v1GetGroupsRequest(offset=args.offset, limit=args.limit, userId=user_id) resp = bindings.post_GetGroups(sess, body=body) if args.json: - determined.cli.render.print_json(resp.to_json()) + render.print_json(resp.to_json()) else: if resp.groups is None: resp.groups = [] @@ -64,16 +62,15 @@ def list_groups(args: Namespace) -> None: ) -@authentication.required -@require_feature_flag("rbacEnabled", rbac_flag_disabled_message) +@cli.require_feature_flag("rbacEnabled", rbac_flag_disabled_message) def describe_group(args: Namespace) -> None: - session = setup_session(args) - group_id = api.group_name_to_group_id(session, args.group_name) - resp = bindings.get_GetGroup(session, groupId=group_id) + sess = cli.setup_session(args) + group_id = api.group_name_to_group_id(sess, args.group_name) + resp = bindings.get_GetGroup(sess, groupId=group_id) group_details = resp.group if args.json: - determined.cli.render.print_json(group_details.to_json()) + render.print_json(group_details.to_json()) else: print(f"group ID {group_details.groupId} group name {group_details.name} with users added") if group_details.users is None: @@ -84,61 +81,57 @@ def describe_group(args: Namespace) -> None: ) -@authentication.required -@require_feature_flag("rbacEnabled", rbac_flag_disabled_message) +@cli.require_feature_flag("rbacEnabled", rbac_flag_disabled_message) def 
add_user_to_group(args: Namespace) -> None: - session = setup_session(args) + sess = cli.setup_session(args) usernames = args.usernames.split(",") - group_id = api.group_name_to_group_id(session, args.group_name) - user_ids = api.usernames_to_user_ids(session, usernames) + group_id = api.group_name_to_group_id(sess, args.group_name) + user_ids = api.usernames_to_user_ids(sess, usernames) body = bindings.v1UpdateGroupRequest(groupId=group_id, addUsers=user_ids) - resp = bindings.put_UpdateGroup(session, groupId=group_id, body=body) + resp = bindings.put_UpdateGroup(sess, groupId=group_id, body=body) print(f"user group with ID {resp.group.groupId} name {resp.group.name}") for user_id, username in zip(user_ids, usernames): print(f"user added to group with username {username} and ID {user_id}") -@authentication.required -@require_feature_flag("rbacEnabled", rbac_flag_disabled_message) +@cli.require_feature_flag("rbacEnabled", rbac_flag_disabled_message) def remove_user_from_group(args: Namespace) -> None: - session = setup_session(args) + sess = cli.setup_session(args) usernames = args.usernames.split(",") - group_id = api.group_name_to_group_id(session, args.group_name) - user_ids = api.usernames_to_user_ids(session, usernames) + group_id = api.group_name_to_group_id(sess, args.group_name) + user_ids = api.usernames_to_user_ids(sess, usernames) body = bindings.v1UpdateGroupRequest(groupId=group_id, removeUsers=user_ids) - resp = bindings.put_UpdateGroup(setup_session(args), groupId=group_id, body=body) + resp = bindings.put_UpdateGroup(sess, groupId=group_id, body=body) print(f"user group with ID {resp.group.groupId} name {resp.group.name}") for user_id, username in zip(user_ids, usernames): print(f"user removed from the group with username {username} and ID {user_id}") -@authentication.required -@require_feature_flag("rbacEnabled", rbac_flag_disabled_message) +@cli.require_feature_flag("rbacEnabled", rbac_flag_disabled_message) def change_group_name(args: Namespace) 
-> None: - session = setup_session(args) - group_id = api.group_name_to_group_id(session, args.old_group_name) + sess = cli.setup_session(args) + group_id = api.group_name_to_group_id(sess, args.old_group_name) body = bindings.v1UpdateGroupRequest(groupId=group_id, name=args.new_group_name) - resp = bindings.put_UpdateGroup(session, groupId=group_id, body=body) + resp = bindings.put_UpdateGroup(sess, groupId=group_id, body=body) g = resp.group print(f"user group with ID {g.groupId} name changed from {args.old_group_name} to {g.name}") -@authentication.required -@require_feature_flag("rbacEnabled", rbac_flag_disabled_message) +@cli.require_feature_flag("rbacEnabled", rbac_flag_disabled_message) def delete_group(args: Namespace) -> None: if args.yes or render.yes_or_no( "Deleting a group will result in an unrecoverable \n" "deletion of the group along with all the membership \n" "information of the group. Do you still wish to proceed? \n" ): - session = setup_session(args) - group_id = api.group_name_to_group_id(session, args.group_name) - bindings.delete_DeleteGroup(session, groupId=group_id) + sess = cli.setup_session(args) + group_id = api.group_name_to_group_id(sess, args.group_name) + bindings.delete_DeleteGroup(sess, groupId=group_id) print(f"user group with name {args.group_name} and ID {group_id} deleted") else: print("Skipping group deletion.") @@ -179,7 +172,7 @@ def delete_group(args: Namespace) -> None: list_groups, "list user groups", [ - *default_pagination_args, + *cli.make_pagination_args(), Arg("--groups-user-belongs-to", help="list groups that the username is in"), Arg("--json", action="store_true", help="print as JSON"), ], diff --git a/harness/determined/cli/version.py b/harness/determined/cli/version.py index d7006686550..21291159b02 100644 --- a/harness/determined/cli/version.py +++ b/harness/determined/cli/version.py @@ -6,20 +6,20 @@ import termcolor from packaging import version -import determined -import determined.cli +import determined as 
det +from determined import cli from determined.cli import render from determined.common import api from determined.common.declarative_argparse import Cmd -def get_version(host: str) -> Dict[str, Any]: - client_info = {"version": determined.__version__} +def get_version(sess: api.BaseSession) -> Dict[str, Any]: + client_info = {"version": det.__version__} master_info = {"cluster_id": "", "master_id": "", "version": ""} try: - master_info = api.get(host, "info", authenticated=False).json() + master_info = sess.get("info").json() # Most connection errors mean that the master is unreachable, which this function handles. # An SSL error, however, means it was reachable but something went wrong, so let that error # propagate out. @@ -28,11 +28,11 @@ def get_version(host: str) -> Dict[str, Any]: except api.errors.MasterNotFoundException: pass - return {"client": client_info, "master": master_info, "master_address": host} + return {"client": client_info, "master": master_info, "master_address": sess.master} -def check_version(parsed_args: argparse.Namespace) -> None: - info = get_version(parsed_args.master) +def check_version(sess: api.BaseSession, args: argparse.Namespace) -> None: + info = get_version(sess) master_version = info["master"]["version"] client_version = info["client"]["version"] @@ -42,7 +42,7 @@ def check_version(parsed_args: argparse.Namespace) -> None: "Master not found at {}. 
" "Hint: Remember to set the DET_MASTER environment variable " "to the correct Determined master IP and port or use the '-m' flag.".format( - parsed_args.master + args.master ), "yellow", ), @@ -68,8 +68,9 @@ def check_version(parsed_args: argparse.Namespace) -> None: ) -def describe_version(parsed_args: argparse.Namespace) -> None: - info = get_version(parsed_args.master) +def describe_version(args: argparse.Namespace) -> None: + sess = cli.unauth_session(args) + info = get_version(sess) print(render.format_object_as_yaml(info)) diff --git a/harness/determined/cli/workspace.py b/harness/determined/cli/workspace.py index 9bc1b656b5a..c640138886c 100644 --- a/harness/determined/cli/workspace.py +++ b/harness/determined/cli/workspace.py @@ -8,7 +8,7 @@ from determined.cli import render from determined.cli.user import AGENT_USER_GROUP_ARGS from determined.common import api, util -from determined.common.api import authentication, bindings, errors +from determined.common.api import bindings, errors from determined.common.declarative_argparse import Arg, Cmd from determined.common.experimental import workspace @@ -29,9 +29,10 @@ def get_workspace_id_from_args(args: Namespace) -> Optional[int]: + sess = cli.setup_session(args) workspace_id = None if args.workspace_name: - workspace = api.workspace_by_name(cli.setup_session(args), args.workspace_name) + workspace = api.workspace_by_name(sess, args.workspace_name) if workspace.archived: raise ArgumentError(None, f'Workspace "{args.workspace_name}" is archived.') workspace_id = workspace.id @@ -74,7 +75,6 @@ def render_workspaces( render.tabulate_or_csv(headers, values, False) -@authentication.required def list_workspaces(args: Namespace) -> None: sess = cli.setup_session(args) orderArg = bindings.v1OrderBy[args.order_by.upper()] @@ -100,7 +100,6 @@ def list_workspaces(args: Namespace) -> None: render_workspaces(all_workspaces, from_list_api=True) -@authentication.required def list_workspace_projects(args: Namespace) -> 
None: sess = cli.setup_session(args) all_projects = workspace.Workspace( @@ -135,11 +134,10 @@ def list_workspace_projects(args: Namespace) -> None: render.tabulate_or_csv(PROJECT_HEADERS, values, False) -@authentication.required def list_pools(args: Namespace) -> None: - session = cli.setup_session(args) - w = api.workspace_by_name(session, args.workspace_name) - resp = bindings.get_ListRPsBoundToWorkspace(session, workspaceId=w.id) + sess = cli.setup_session(args) + w = api.workspace_by_name(sess, args.workspace_name) + resp = bindings.get_ListRPsBoundToWorkspace(sess, workspaceId=w.id) pools_str = "" if resp.resourcePools: pools_str = ", ".join(resp.resourcePools) @@ -175,11 +173,11 @@ def _parse_checkpoint_storage_args(args: Namespace) -> Any: return checkpoint_storage -@authentication.required def create_workspace(args: Namespace) -> None: agent_user_group = _parse_agent_user_group_args(args) checkpoint_storage = _parse_checkpoint_storage_args(args) + sess = cli.setup_session(args) content = bindings.v1PostWorkspaceRequest( name=args.name, agentUserGroup=agent_user_group, @@ -187,7 +185,7 @@ def create_workspace(args: Namespace) -> None: defaultComputePool=args.default_compute_pool, defaultAuxPool=args.default_aux_pool, ) - w = bindings.post_PostWorkspace(cli.setup_session(args), body=content).workspace + w = bindings.post_PostWorkspace(sess, body=content).workspace if args.json: determined.cli.render.print_json(w.to_json()) @@ -195,7 +193,6 @@ def create_workspace(args: Namespace) -> None: render_workspaces([w]) -@authentication.required def describe_workspace(args: Namespace) -> None: sess = cli.setup_session(args) w = api.workspace_by_name(sess, args.workspace_name) @@ -205,7 +202,6 @@ def describe_workspace(args: Namespace) -> None: render_workspaces([w]) -@authentication.required def delete_workspace(args: Namespace) -> None: sess = cli.setup_session(args) w = api.workspace_by_name(sess, args.workspace_name) @@ -236,7 +232,6 @@ def delete_workspace(args: 
Namespace) -> None: print("Aborting workspace deletion.") -@authentication.required def archive_workspace(args: Namespace) -> None: sess = cli.setup_session(args) current = api.workspace_by_name(sess, args.workspace_name) @@ -244,7 +239,6 @@ def archive_workspace(args: Namespace) -> None: print(f"Successfully archived workspace {args.workspace_name}.") -@authentication.required def unarchive_workspace(args: Namespace) -> None: sess = cli.setup_session(args) current = api.workspace_by_name(sess, args.workspace_name) @@ -252,7 +246,6 @@ def unarchive_workspace(args: Namespace) -> None: print(f"Successfully un-archived workspace {args.workspace_name}.") -@authentication.required def edit_workspace(args: Namespace) -> None: checkpoint_storage = _parse_checkpoint_storage_args(args) diff --git a/harness/determined/common/api/__init__.py b/harness/determined/common/api/__init__.py index b4e38ba91c6..6aa96e3ee70 100644 --- a/harness/determined/common/api/__init__.py +++ b/harness/determined/common/api/__init__.py @@ -1,5 +1,5 @@ from determined.common.api import authentication, errors, metric, request, bindings -from determined.common.api._session import Session +from determined.common.api._session import BaseSession, UnauthSession, Session from determined.common.api._util import ( PageOpts, default_retry, @@ -7,7 +7,7 @@ read_paginated, WARNING_MESSAGE_MAP, wait_for_ntsc_state, - task_is_ready, + wait_for_task_ready, NTSC_Kind, AnyNTSC, ) @@ -20,27 +20,14 @@ workspace_by_name, not_found_errs, ) -from determined.common.api.authentication import Authentication, salt_and_hash +from determined.common.api.authentication import UsernameTokenPair, salt_and_hash from determined.common.api.logs import ( pprint_logs, trial_logs, task_logs, ) from determined.common.api.request import ( - WebSocket, - delete, - do_request, - get, make_url, browser_open, parse_master_address, - patch, - post, - put, - ws, -) -from determined.common.api.profiler import ( - 
post_trial_profiler_metrics_batches, - TrialProfilerMetricsBatch, - get_trial_profiler_available_series, ) diff --git a/harness/determined/common/api/_session.py b/harness/determined/common/api/_session.py index 11e179eea54..0cc753fd235 100644 --- a/harness/determined/common/api/_session.py +++ b/harness/determined/common/api/_session.py @@ -1,27 +1,93 @@ -from typing import Any, Dict, Optional +import abc +import json as _json +from typing import Any, Dict, Optional, Tuple, Union import requests import urllib3 +import determined as det +from determined.common import requests as det_requests from determined.common import util -from determined.common.api import authentication, certs, request +from determined.common.api import authentication, certs, errors, request -class Session: - def __init__( - self, - master: Optional[str], - user: Optional[str], - auth: Optional[authentication.Authentication], - cert: Optional[certs.Cert], - max_retries: Optional[urllib3.util.retry.Retry] = None, - ) -> None: - self._master = master or util.get_default_master_address() - self._user = user - self._auth = auth - self._cert = cert - self._max_retries = max_retries +def _do_request( + method: str, + host: str, + path: str, + params: Optional[Dict[str, Any]] = None, + json: Any = None, + data: Optional[str] = None, + headers: Optional[Dict[str, str]] = None, + cert: Optional[certs.Cert] = None, + timeout: Optional[Union[Tuple, float]] = None, + stream: bool = False, + max_retries: Optional[urllib3.util.retry.Retry] = None, +) -> requests.Response: + # Allow the json to come pre-encoded, if we need custom encoding. 
+ if json is not None and data is not None: + raise ValueError("json and data must not be provided together") + + if json: + data = det.util.json_encode(json) + + try: + r = det_requests.request( + method, + request.make_url(host, path), + params=params, + data=data, + headers=headers, + verify=cert.bundle if cert else None, + stream=stream, + timeout=timeout, + server_hostname=cert.name if cert else None, + max_retries=max_retries, + ) + except requests.exceptions.SSLError: + raise + except requests.exceptions.ConnectionError as e: + raise errors.MasterNotFoundException(str(e)) + except requests.exceptions.RequestException as e: + raise errors.BadRequestException(str(e)) + + def _get_error_str(r: requests.models.Response) -> str: + try: + json_resp = _json.loads(r.text) + mes = json_resp.get("message") + if mes is not None: + return str(mes) + # Try getting GRPC error description if message does not exist. + return str(json_resp.get("error").get("error")) + except Exception: + return "" + + if r.status_code == 403: + raise errors.ForbiddenException(message=_get_error_str(r)) + if r.status_code == 401: + raise errors.UnauthenticatedException() + elif r.status_code == 404: + raise errors.NotFoundException(_get_error_str(r)) + elif r.status_code >= 300: + raise errors.APIException(r) + return r + + +class BaseSession(metaclass=abc.ABCMeta): + """ + BaseSession is a requests-like interface that hides master url, master cert, and authz info. + + There are very few cases where BaseSession is the right type; you probably want a Session. In + a few cases, you might be ok with an UnauthSession. BaseSession is really only to express that + you don't know what kind of session you need. For example, the generated bindings take a + BaseSession because the protos aren't annotated with which endpoints are authenticated. 
+ """ + + master: str + cert: Optional[certs.Cert] + + @abc.abstractmethod def _do_request( self, method: str, @@ -33,24 +99,12 @@ def _do_request( timeout: Optional[int], stream: bool, ) -> requests.Response: - return request.do_request( - method, - self._master, - path, - params=params, - json=json, - data=data, - auth=self._auth, - cert=self._cert, - headers=headers, - timeout=timeout, - stream=stream, - max_retries=self._max_retries, - ) + pass def get( self, path: str, + *, params: Optional[Dict[str, Any]] = None, headers: Optional[Dict[str, Any]] = None, timeout: Optional[int] = None, @@ -61,6 +115,7 @@ def get( def delete( self, path: str, + *, params: Optional[Dict[str, Any]] = None, headers: Optional[Dict[str, Any]] = None, timeout: Optional[int] = None, @@ -70,6 +125,7 @@ def delete( def post( self, path: str, + *, params: Optional[Dict[str, Any]] = None, json: Any = None, data: Optional[str] = None, @@ -81,6 +137,7 @@ def post( def patch( self, path: str, + *, params: Optional[Dict[str, Any]] = None, json: Any = None, data: Optional[str] = None, @@ -92,6 +149,7 @@ def patch( def put( self, path: str, + *, params: Optional[Dict[str, Any]] = None, json: Any = None, data: Optional[str] = None, @@ -100,13 +158,92 @@ def put( ) -> requests.Response: return self._do_request("PUT", path, params, json, data, headers, timeout, False) - def with_retry(self, retry: Optional[urllib3.util.retry.Retry]) -> "Session": - """Return a copy of this session with different max_retries.""" - return type(self)( - master=self._master, - user=self._user, - auth=self._auth, - cert=self._cert, - max_retries=retry, +class UnauthSession(BaseSession): + """ + UnauthSession is mostly only useful to log in or unathenticated endpoints like /info. 
+ """ + + def __init__( + self, + master: Optional[str], + cert: Optional[certs.Cert], + max_retries: Optional[urllib3.util.retry.Retry] = None, + ) -> None: + self.master = master or util.get_default_master_address() + self.cert = cert + self._max_retries = max_retries + + def _do_request( + self, + method: str, + path: str, + params: Optional[Dict[str, Any]], + json: Any, + data: Optional[str], + headers: Optional[Dict[str, Any]], + timeout: Optional[int], + stream: bool, + ) -> requests.Response: + return _do_request( + method=method, + host=self.master, + path=path, + params=params, + json=json, + data=data, + headers=headers, + cert=self.cert, + timeout=timeout, + stream=stream, + max_retries=self._max_retries, + ) + + +class Session(BaseSession): + """ + Session authenticates every request it makes. + + By far, most BaseSessions in the codebase will be this Session subclass. + """ + + def __init__( + self, + master: Optional[str], + utp: authentication.UsernameTokenPair, + cert: Optional[certs.Cert], + max_retries: Optional[urllib3.util.retry.Retry] = None, + ) -> None: + self.master = master or util.get_default_master_address() + self.username = utp.username + self.token = utp.token + self.cert = cert + self._max_retries = max_retries + + def _do_request( + self, + method: str, + path: str, + params: Optional[Dict[str, Any]], + json: Any, + data: Optional[str], + headers: Optional[Dict[str, Any]], + timeout: Optional[int], + stream: bool, + ) -> requests.Response: + # Add authentication. 
+ headers = dict(headers) if headers is not None else {} + headers["Authorization"] = f"Bearer {self.token}" + return _do_request( + method, + self.master, + path, + params=params, + json=json, + data=data, + cert=self.cert, + headers=headers, + timeout=timeout, + stream=stream, + max_retries=self._max_retries, ) diff --git a/harness/determined/common/api/_util.py b/harness/determined/common/api/_util.py index 536ac02b4fb..857b7222ff9 100644 --- a/harness/determined/common/api/_util.py +++ b/harness/determined/common/api/_util.py @@ -104,7 +104,7 @@ def get_state() -> Tuple[bool, bindings.taskv1State]: return util.wait_for(get_state, timeout) -def task_is_ready( +def wait_for_task_ready( session: api.Session, task_id: str, progress_report: Optional[Callable] = None, diff --git a/harness/determined/common/api/authentication.py b/harness/determined/common/api/authentication.py index eff03425e15..6c7fca4ed61 100644 --- a/harness/determined/common/api/authentication.py +++ b/harness/determined/common/api/authentication.py @@ -1,30 +1,18 @@ -import argparse import contextlib -import functools import getpass import hashlib import json import pathlib -from typing import Any, Callable, Dict, Iterator, List, NamedTuple, Optional, Tuple, cast +from typing import Any, Dict, Iterator, List, Optional, Tuple import filelock -import determined as det from determined.common import api, constants, util from determined.common.api import bindings, certs -Credentials = NamedTuple("Credentials", [("username", str), ("password", str)]) - PASSWORD_SALT = "GubPEmmotfiK9TMD6Zdw" -def get_allocation_token() -> str: - info = det.get_cluster_info() - if info is None: - return "" - return info.session_token - - def salt_and_hash(password: str) -> str: if password: return hashlib.sha512((PASSWORD_SALT + password).encode()).hexdigest() @@ -38,204 +26,184 @@ def __init__(self, username: str, token: str): self.token = token +def login( + master_address: str, + username: str, + password: str, + 
cert: Optional[certs.Cert] = None, +) -> UsernameTokenPair: + """ + Log in without considering or affecting the TokenStore on the file system. + + Used as part of login_with_cache, and also useful in tests where you wish to not affect the + TokenStore. + """ + password = api.salt_and_hash(password) + unauth_session = api.UnauthSession(master=master_address, cert=cert) + login = bindings.v1LoginRequest(username=username, password=password, isHashed=True) + r = bindings.post_Login(session=unauth_session, body=login) + return UsernameTokenPair(username, r.token) + + def default_load_user_password( requested_user: Optional[str], password: Optional[str], token_store: "TokenStore", -) -> Tuple[Optional[str], Optional[str]]: +) -> Tuple[str, Optional[str], bool]: + """ + Decide on a username and password for a login attempt. + + When values are explicitly provided, they should be honored. But when they are not provided, + check environment variables and the token store before falling back to the system default. + + Args: + requested_user: a username explicitly provided by the end user + password: a password explicitly provided by the end user + + Returns: + A tuple of (username, Optional[password], was_fallback], where was_fallback indicates that + we are returning the system default username and password. + """ # Always prefer an explicitly provided user/password. if requested_user: - return requested_user, password + return requested_user, password, False # Next highest priority is user/password from environment. # Watch out! We have to check for DET_USER and DET_PASS, because containers will have DET_USER # set, but that doesn't overrule the active user in the TokenStore, because if the TokenStore in # the container has an active user, that means the user has explicitly ran `det user login` # inside the container. 
- if ( - util.get_det_username_from_env() is not None - and util.get_det_password_from_env() is not None - ): - return util.get_det_username_from_env(), util.get_det_password_from_env() - - # Last priority is the active user in the token store. - return token_store.get_active_user(), password - - -class Authentication: - def __init__( - self, - master_address: Optional[str] = None, - requested_user: Optional[str] = None, - password: Optional[str] = None, - cert: Optional[certs.Cert] = None, - ) -> None: - self.master_address = master_address or util.get_default_master_address() - self.token_store = TokenStore(self.master_address) - - self.session = self._init_session(requested_user, password, cert) - - def _init_session( - self, - requested_user: Optional[str], - password: Optional[str], - cert: Optional[certs.Cert], - ) -> UsernameTokenPair: - # Get session_user and password given the following priority: - # 1. User passed in with flag (requested_user) - # 2. User from environment if DET_PASS is set. - # 3. Active user from the token store. - session_user, password = default_load_user_password( - requested_user, password, self.token_store - ) - - # For login, we allow falling back to the default username. - if not session_user: - session_user = constants.DEFAULT_DETERMINED_USER - assert session_user is not None - - # Check the token store if this session_user has a cached token. If so, check with the - # master to verify it has not expired. Otherwise, let the token be None. 
- token = self.token_store.get_token(session_user) - if token is not None and not _is_token_valid(self.master_address, token, cert): - self.token_store.drop_user(session_user) - token = None - - # Special case: use token provided from the container environment if: - # - No token was obtained from the token store already, - # - There is a token available from the container environment, and - # - No user was explicitly requested, or the requested user matches the token available - # in the container environment. - if ( - token is None - and util.get_det_username_from_env() is not None - and util.get_det_user_token_from_env() is not None - and requested_user in (None, util.get_det_username_from_env()) - ): - session_user = util.get_det_username_from_env() - assert session_user - token = util.get_det_user_token_from_env() - - if token is not None: - return UsernameTokenPair(session_user, token) - - # We'll need to create a new token, so we'll need a password. If there was no requested - # user and we ended up falling back to the default username `determined`, then we can fall - # back to the default login as well. Otherwise, ask the user for their password. - fallback_to_default = password is None and session_user == constants.DEFAULT_DETERMINED_USER - if fallback_to_default: - password = constants.DEFAULT_DETERMINED_PASSWORD - - if password is None: - password = getpass.getpass("Password for user '{}': ".format(session_user)) - - try: - token = do_login(self.master_address, session_user, password, cert) - except api.errors.ForbiddenException: - if fallback_to_default: - raise api.errors.UnauthenticatedException(username=session_user) - raise - - self.token_store.set_token(session_user, token) + env_user = util.get_det_username_from_env() + env_pass = util.get_det_password_from_env() + if env_user is not None and env_pass is not None: + return env_user, env_pass, False + + # Next priority is the active user in the token store. 
+ active_user = token_store.get_active_user() + if active_user is not None: + return active_user, password, False + + # Last priority is the default username and password. + return ( + constants.DEFAULT_DETERMINED_USER, + password or constants.DEFAULT_DETERMINED_PASSWORD, + True, + ) + + +def login_with_cache( + master_address: Optional[str] = None, + requested_user: Optional[str] = None, + password: Optional[str] = None, + cert: Optional[certs.Cert] = None, +) -> UsernameTokenPair: + """ + Log in, preferring cached credentials in the TokenStore, if possible. - return UsernameTokenPair(session_user, token) + This is the login path for nearly all user-facing cases. - def is_user_active(self, username: str) -> bool: - return self.token_store.get_active_user() == username + There is also a special case for checking if the DET_USER_TOKEN is set in the environment (by + the determined-master). That must happen in this function because it is only used when no other + login tokens are active, but it must be considered before asking the user for a password. - def get_session_user(self) -> str: - """ - Returns the session user for the current session. If there is no active - session, then an UnauthenticatedException will be raised. - """ - return self.session.username + As a somewhat surprising side-effect re-using an existing token from the cache, it is actually + possible in cache hit scenarios for an invalid password here to result in a valid login since + the password is only used in a cache miss. - def get_session_token(self, must: bool = True) -> str: - """ - Returns the authentication token for the session user. If there is no - active session, then an UnauthenticatedException will be raised. - """ - if self.session is None: - if must: - raise api.errors.UnauthenticatedException(username="") - else: - return "" - return self.session.token + Returns: + The username and token of the logged in user. 
+ """ + master_address = master_address or util.get_default_master_address() + token_store = TokenStore(master_address) -def do_login( - master_address: str, - username: str, - password: str, - cert: Optional[certs.Cert] = None, -) -> str: - password = api.salt_and_hash(password) - unauth_session = api.Session(user=username, master=master_address, auth=None, cert=cert) - login = bindings.v1LoginRequest(username=username, password=password, isHashed=True) - r = bindings.post_Login(session=unauth_session, body=login) - token = r.token + user, password, was_fallback = default_load_user_password(requested_user, password, token_store) - return token + # Check the token store if this session_user has a cached token. If so, check with the + # master to verify it has not expired. Otherwise, let the token be None. + token = token_store.get_token(user) + if token is not None and not _is_token_valid(master_address, token, cert): + token_store.drop_user(user) + token = None + if token is not None: + return UsernameTokenPair(user, token) -class LogoutAuthentication(Authentication): - """ - An api-compatible Authentication object that is basically exactly a UserTokenPair. + # Special case: use token provided from the container environment if: + # - No token was obtained from the token store already, + # - There is a token available from the container environment, and + # - No user was explicitly requested, or the requested user matches the token available in the + # container environment. 
+ if ( + util.get_det_username_from_env() is not None + and util.get_det_user_token_from_env() is not None + and requested_user in (None, util.get_det_username_from_env()) + ): + env_user = util.get_det_username_from_env() + assert env_user + env_token = util.get_det_user_token_from_env() + assert env_token + return UsernameTokenPair(env_user, env_token) - TODO(MLG-215): delete Authentication class and write a function that returns a UsernameTokenPair - in its place, and let do_request() take UsernameTokenPair as input. - """ + if password is None: + password = getpass.getpass(f"Password for user '{user}': ") - def __init__(self, session_user: str, session_token: str) -> None: - self.session_user = session_user - self.session_token = session_token + try: + utp = login(master_address, user, password, cert) + user, token = utp.username, utp.token + except api.errors.ForbiddenException: + if was_fallback: + raise api.errors.UnauthenticatedException() + raise - def get_session_user(self) -> str: - return self.session_user + token_store.set_token(user, token) - def get_session_token(self, must: bool = True) -> str: - return self.session_token + return UsernameTokenPair(user, token) def logout( master_address: Optional[str], requested_user: Optional[str], cert: Optional[certs.Cert], -) -> None: +) -> Optional[str]: """ Logout if there is an active session for this master/username pair, otherwise do nothing. - Additionally, if the user happens to be the active user, drop the active user from the - TokenStore. + A session is active when a valid token for it can be found. In that case, the token + is sent to the master for invalidation and dropped from the token store. + + If requested_user is None, logout attempts to log out the token store's active_user. + + Logout does not affect the "active_user" entry itself in the token store, + since the whole concept of an "active user" mostly belongs to the CLI and is + handled explicitly by the CLI. 
+ + Returns: + The name of the user who was logged out, if any. """ master_address = master_address or util.get_default_master_address() token_store = TokenStore(master_address) - session_user, _ = default_load_user_password(requested_user, None, token_store) - # Don't log out of DEFAULT_DETERMINED_USER when it's not specified and not the active user. - - if session_user is None: - return + user, _, __ = default_load_user_password(requested_user, None, token_store) - if session_user == token_store.get_active_user(): - token_store.clear_active() + token = token_store.get_token(user) - session_token = token_store.get_token(session_user) + if token is None: + return None - if session_token is None: - return + token_store.drop_user(user) - token_store.drop_user(session_user) - - auth = LogoutAuthentication(session_user, session_token) - sess = api.Session(user=session_user, master=master_address, auth=auth, cert=cert) + utp = UsernameTokenPair(user, token) + sess = api.Session(master=master_address, utp=utp, cert=cert) try: bindings.post_Logout(sess) except (api.errors.UnauthenticatedException, api.errors.APIException): # This session may have expired, but we don't care. pass + return user + def logout_all(master_address: Optional[str], cert: Optional[certs.Cert]) -> None: master_address = master_address or util.get_default_master_address() @@ -246,17 +214,16 @@ def logout_all(master_address: Optional[str], cert: Optional[certs.Cert]) -> Non for user in users: logout(master_address, user, cert) - token_store.clear_active() - def _is_token_valid(master_address: str, token: str, cert: Optional[certs.Cert]) -> bool: """ Find out whether the given token is valid by attempting to use it on the "api/v1/me" endpoint. 
""" - headers = {"Authorization": "Bearer {}".format(token)} + utp = UsernameTokenPair("username-doesnt-matter", token) + sess = api.Session(master_address, utp, cert) try: - r = api.get(master_address, "api/v1/me", headers=headers, authenticated=False, cert=cert) + r = sess.get("api/v1/me") except (api.errors.UnauthenticatedException, api.errors.APIException): return False @@ -288,8 +255,14 @@ def __init__(self, master_address: str, path: Optional[pathlib.Path] = None) -> def _reconfigure_from_store(self, store: dict) -> None: substore = store.get("masters", {}).get(self.master_address, {}) - self._active_user = cast(str, substore.get("active_user")) - self._tokens = cast(Dict[str, str], substore.get("tokens", {})) + + active_user = substore.get("active_user") + assert isinstance(active_user, (str, type(None))), active_user + self._active_user = active_user + + tokens = substore.get("tokens", {}) + assert isinstance(tokens, dict), tokens + self._tokens = tokens def get_active_user(self) -> Optional[str]: return self._active_user @@ -323,7 +296,7 @@ def set_active(self, username: str) -> None: with self._persistent_store() as substore: tokens = substore.setdefault("tokens", {}) if username not in tokens: - raise api.errors.UnauthenticatedException(username=username) + raise api.errors.UnauthenticatedException() substore["active_user"] = username def clear_active(self) -> None: @@ -379,7 +352,8 @@ def _load_store_file(self) -> Dict[str, Any]: validate_token_store_v1(store) - return cast(dict, store) + assert isinstance(store, dict), store + return store except api.errors.CorruptTokenCacheException: # Delete invalid caches before exiting. @@ -481,27 +455,3 @@ def validate_token_store_v1(store: Any) -> bool: validate_token_store_v0(val) return True - - -# cli_auth is the process-wide authentication used for api calls originating from the cli. 
-cli_auth = None # type: Optional[Authentication] - - -def required(func: Callable[[argparse.Namespace], Any]) -> Callable[..., Any]: - """ - A decorator for cli functions. - """ - - @functools.wraps(func) - def f(namespace: argparse.Namespace) -> Any: - global cli_auth - cli_auth = Authentication(namespace.master, namespace.user) - return func(namespace) - - return f - - -def must_cli_auth() -> Authentication: - if not cli_auth: - raise api.errors.UnauthenticatedException(username="") - return cli_auth diff --git a/harness/determined/common/api/bindings.py b/harness/determined/common/api/bindings.py index a3da2087930..b6e4b5bb7fa 100644 --- a/harness/determined/common/api/bindings.py +++ b/harness/determined/common/api/bindings.py @@ -15269,7 +15269,7 @@ class v1WorkspaceState(DetEnum): DELETED = "WORKSPACE_STATE_DELETED" def post_AckAllocationPreemptionSignal( - session: "api.Session", + session: "api.BaseSession", *, allocationId: str, body: "v1AckAllocationPreemptionSignalRequest", @@ -15299,7 +15299,7 @@ def post_AckAllocationPreemptionSignal( raise APIHttpError("post_AckAllocationPreemptionSignal", _resp) def post_ActivateExperiment( - session: "api.Session", + session: "api.BaseSession", *, id: int, ) -> None: @@ -15323,7 +15323,7 @@ def post_ActivateExperiment( raise APIHttpError("post_ActivateExperiment", _resp) def post_ActivateExperiments( - session: "api.Session", + session: "api.BaseSession", *, body: "v1ActivateExperimentsRequest", ) -> "v1ActivateExperimentsResponse": @@ -15344,7 +15344,7 @@ def post_ActivateExperiments( raise APIHttpError("post_ActivateExperiments", _resp) def post_AddProjectNote( - session: "api.Session", + session: "api.BaseSession", *, body: "v1Note", projectId: int, @@ -15370,7 +15370,7 @@ def post_AddProjectNote( raise APIHttpError("post_AddProjectNote", _resp) def post_AllocationAllGather( - session: "api.Session", + session: "api.BaseSession", *, allocationId: str, body: "v1AllocationAllGatherRequest", @@ -15398,7 +15398,7 @@ 
def post_AllocationAllGather( raise APIHttpError("post_AllocationAllGather", _resp) def post_AllocationPendingPreemptionSignal( - session: "api.Session", + session: "api.BaseSession", *, allocationId: str, body: "v1AllocationPendingPreemptionSignalRequest", @@ -15428,7 +15428,7 @@ def post_AllocationPendingPreemptionSignal( raise APIHttpError("post_AllocationPendingPreemptionSignal", _resp) def get_AllocationPreemptionSignal( - session: "api.Session", + session: "api.BaseSession", *, allocationId: str, timeoutSeconds: "typing.Optional[int]" = None, @@ -15461,7 +15461,7 @@ def get_AllocationPreemptionSignal( raise APIHttpError("get_AllocationPreemptionSignal", _resp) def post_AllocationReady( - session: "api.Session", + session: "api.BaseSession", *, allocationId: str, body: "v1AllocationReadyRequest", @@ -15488,7 +15488,7 @@ def post_AllocationReady( raise APIHttpError("post_AllocationReady", _resp) def get_AllocationRendezvousInfo( - session: "api.Session", + session: "api.BaseSession", *, allocationId: str, resourcesId: str, @@ -15520,7 +15520,7 @@ def get_AllocationRendezvousInfo( raise APIHttpError("get_AllocationRendezvousInfo", _resp) def post_AllocationWaiting( - session: "api.Session", + session: "api.BaseSession", *, allocationId: str, body: "v1AllocationWaitingRequest", @@ -15547,7 +15547,7 @@ def post_AllocationWaiting( raise APIHttpError("post_AllocationWaiting", _resp) def post_ArchiveExperiment( - session: "api.Session", + session: "api.BaseSession", *, id: int, ) -> None: @@ -15571,7 +15571,7 @@ def post_ArchiveExperiment( raise APIHttpError("post_ArchiveExperiment", _resp) def post_ArchiveExperiments( - session: "api.Session", + session: "api.BaseSession", *, body: "v1ArchiveExperimentsRequest", ) -> "v1ArchiveExperimentsResponse": @@ -15592,7 +15592,7 @@ def post_ArchiveExperiments( raise APIHttpError("post_ArchiveExperiments", _resp) def post_ArchiveModel( - session: "api.Session", + session: "api.BaseSession", *, modelName: str, ) -> None: @@ 
-15618,7 +15618,7 @@ def post_ArchiveModel( raise APIHttpError("post_ArchiveModel", _resp) def post_ArchiveProject( - session: "api.Session", + session: "api.BaseSession", *, id: int, ) -> None: @@ -15642,7 +15642,7 @@ def post_ArchiveProject( raise APIHttpError("post_ArchiveProject", _resp) def post_ArchiveWorkspace( - session: "api.Session", + session: "api.BaseSession", *, id: int, ) -> None: @@ -15666,7 +15666,7 @@ def post_ArchiveWorkspace( raise APIHttpError("post_ArchiveWorkspace", _resp) def patch_AssignMultipleGroups( - session: "api.Session", + session: "api.BaseSession", *, body: "v1AssignMultipleGroupsRequest", ) -> None: @@ -15687,7 +15687,7 @@ def patch_AssignMultipleGroups( raise APIHttpError("patch_AssignMultipleGroups", _resp) def post_AssignRoles( - session: "api.Session", + session: "api.BaseSession", *, body: "v1AssignRolesRequest", ) -> None: @@ -15708,7 +15708,7 @@ def post_AssignRoles( raise APIHttpError("post_AssignRoles", _resp) def post_BindRPToWorkspace( - session: "api.Session", + session: "api.BaseSession", *, body: "v1BindRPToWorkspaceRequest", resourcePoolName: str, @@ -15735,7 +15735,7 @@ def post_BindRPToWorkspace( raise APIHttpError("post_BindRPToWorkspace", _resp) def post_CancelExperiment( - session: "api.Session", + session: "api.BaseSession", *, id: int, ) -> None: @@ -15759,7 +15759,7 @@ def post_CancelExperiment( raise APIHttpError("post_CancelExperiment", _resp) def post_CancelExperiments( - session: "api.Session", + session: "api.BaseSession", *, body: "v1CancelExperimentsRequest", ) -> "v1CancelExperimentsResponse": @@ -15780,7 +15780,7 @@ def post_CancelExperiments( raise APIHttpError("post_CancelExperiments", _resp) def post_CheckpointsRemoveFiles( - session: "api.Session", + session: "api.BaseSession", *, body: "v1CheckpointsRemoveFilesRequest", ) -> None: @@ -15801,7 +15801,7 @@ def post_CheckpointsRemoveFiles( raise APIHttpError("post_CheckpointsRemoveFiles", _resp) def get_CompareTrials( - session: "api.Session", + 
session: "api.BaseSession", *, endBatches: "typing.Optional[int]" = None, group: "typing.Optional[str]" = None, @@ -15898,7 +15898,7 @@ def get_CompareTrials( raise APIHttpError("get_CompareTrials", _resp) def post_CompleteTrialSearcherValidation( - session: "api.Session", + session: "api.BaseSession", *, body: "v1CompleteValidateAfterOperation", trialId: int, @@ -15925,7 +15925,7 @@ def post_CompleteTrialSearcherValidation( raise APIHttpError("post_CompleteTrialSearcherValidation", _resp) def post_ContinueExperiment( - session: "api.Session", + session: "api.BaseSession", *, body: "v1ContinueExperimentRequest", ) -> "v1ContinueExperimentResponse": @@ -15948,7 +15948,7 @@ def post_ContinueExperiment( raise APIHttpError("post_ContinueExperiment", _resp) def post_CreateExperiment( - session: "api.Session", + session: "api.BaseSession", *, body: "v1CreateExperimentRequest", ) -> "v1CreateExperimentResponse": @@ -15969,7 +15969,7 @@ def post_CreateExperiment( raise APIHttpError("post_CreateExperiment", _resp) def post_CreateGenericTask( - session: "api.Session", + session: "api.BaseSession", *, body: "v1CreateGenericTaskRequest", ) -> "v1CreateGenericTaskResponse": @@ -15990,7 +15990,7 @@ def post_CreateGenericTask( raise APIHttpError("post_CreateGenericTask", _resp) def post_CreateGroup( - session: "api.Session", + session: "api.BaseSession", *, body: "v1CreateGroupRequest", ) -> "v1CreateGroupResponse": @@ -16011,7 +16011,7 @@ def post_CreateGroup( raise APIHttpError("post_CreateGroup", _resp) def post_CreateTrial( - session: "api.Session", + session: "api.BaseSession", *, body: "v1CreateTrialRequest", ) -> "v1CreateTrialResponse": @@ -16032,7 +16032,7 @@ def post_CreateTrial( raise APIHttpError("post_CreateTrial", _resp) def get_CurrentUser( - session: "api.Session", + session: "api.BaseSession", ) -> "v1CurrentUserResponse": """Get the current user.""" _params = None @@ -16051,7 +16051,7 @@ def get_CurrentUser( raise APIHttpError("get_CurrentUser", _resp) def 
delete_DeleteCheckpoints( - session: "api.Session", + session: "api.BaseSession", *, body: "v1DeleteCheckpointsRequest", ) -> None: @@ -16072,7 +16072,7 @@ def delete_DeleteCheckpoints( raise APIHttpError("delete_DeleteCheckpoints", _resp) def delete_DeleteExperiment( - session: "api.Session", + session: "api.BaseSession", *, experimentId: int, ) -> None: @@ -16096,7 +16096,7 @@ def delete_DeleteExperiment( raise APIHttpError("delete_DeleteExperiment", _resp) def delete_DeleteExperimentLabel( - session: "api.Session", + session: "api.BaseSession", *, experimentId: int, label: str, @@ -16124,7 +16124,7 @@ def delete_DeleteExperimentLabel( raise APIHttpError("delete_DeleteExperimentLabel", _resp) def delete_DeleteExperiments( - session: "api.Session", + session: "api.BaseSession", *, body: "v1DeleteExperimentsRequest", ) -> "v1DeleteExperimentsResponse": @@ -16145,7 +16145,7 @@ def delete_DeleteExperiments( raise APIHttpError("delete_DeleteExperiments", _resp) def delete_DeleteGroup( - session: "api.Session", + session: "api.BaseSession", *, groupId: int, ) -> None: @@ -16169,7 +16169,7 @@ def delete_DeleteGroup( raise APIHttpError("delete_DeleteGroup", _resp) def delete_DeleteModel( - session: "api.Session", + session: "api.BaseSession", *, modelName: str, ) -> None: @@ -16195,7 +16195,7 @@ def delete_DeleteModel( raise APIHttpError("delete_DeleteModel", _resp) def delete_DeleteModelVersion( - session: "api.Session", + session: "api.BaseSession", *, modelName: str, modelVersionNum: int, @@ -16223,7 +16223,7 @@ def delete_DeleteModelVersion( raise APIHttpError("delete_DeleteModelVersion", _resp) def delete_DeleteProject( - session: "api.Session", + session: "api.BaseSession", *, id: int, ) -> "v1DeleteProjectResponse": @@ -16247,7 +16247,7 @@ def delete_DeleteProject( raise APIHttpError("delete_DeleteProject", _resp) def delete_DeleteTemplate( - session: "api.Session", + session: "api.BaseSession", *, templateName: str, ) -> None: @@ -16273,7 +16273,7 @@ def 
delete_DeleteTemplate( raise APIHttpError("delete_DeleteTemplate", _resp) def delete_DeleteTensorboardFiles( - session: "api.Session", + session: "api.BaseSession", *, experimentId: int, ) -> None: @@ -16297,7 +16297,7 @@ def delete_DeleteTensorboardFiles( raise APIHttpError("delete_DeleteTensorboardFiles", _resp) def delete_DeleteWebhook( - session: "api.Session", + session: "api.BaseSession", *, id: int, ) -> None: @@ -16321,7 +16321,7 @@ def delete_DeleteWebhook( raise APIHttpError("delete_DeleteWebhook", _resp) def delete_DeleteWorkspace( - session: "api.Session", + session: "api.BaseSession", *, id: int, ) -> "v1DeleteWorkspaceResponse": @@ -16345,7 +16345,7 @@ def delete_DeleteWorkspace( raise APIHttpError("delete_DeleteWorkspace", _resp) def post_DisableAgent( - session: "api.Session", + session: "api.BaseSession", *, agentId: str, body: "v1DisableAgentRequest", @@ -16372,7 +16372,7 @@ def post_DisableAgent( raise APIHttpError("post_DisableAgent", _resp) def post_DisableSlot( - session: "api.Session", + session: "api.BaseSession", *, agentId: str, body: "v1DisableSlotRequest", @@ -16403,7 +16403,7 @@ def post_DisableSlot( raise APIHttpError("post_DisableSlot", _resp) def post_EnableAgent( - session: "api.Session", + session: "api.BaseSession", *, agentId: str, ) -> "v1EnableAgentResponse": @@ -16429,7 +16429,7 @@ def post_EnableAgent( raise APIHttpError("post_EnableAgent", _resp) def post_EnableSlot( - session: "api.Session", + session: "api.BaseSession", *, agentId: str, slotId: str, @@ -16459,7 +16459,7 @@ def post_EnableSlot( raise APIHttpError("post_EnableSlot", _resp) def get_ExpMetricNames( - session: "api.Session", + session: "api.BaseSession", *, ids: "typing.Sequence[int]", periodSeconds: "typing.Optional[int]" = None, @@ -16499,7 +16499,7 @@ def get_ExpMetricNames( raise APIHttpError("get_ExpMetricNames", _resp) def get_GetActiveTasksCount( - session: "api.Session", + session: "api.BaseSession", ) -> "v1GetActiveTasksCountResponse": """Get a count 
of active tasks.""" _params = None @@ -16518,7 +16518,7 @@ def get_GetActiveTasksCount( raise APIHttpError("get_GetActiveTasksCount", _resp) def get_GetAgent( - session: "api.Session", + session: "api.BaseSession", *, agentId: str, ) -> "v1GetAgentResponse": @@ -16544,7 +16544,7 @@ def get_GetAgent( raise APIHttpError("get_GetAgent", _resp) def get_GetAgents( - session: "api.Session", + session: "api.BaseSession", *, label: "typing.Optional[str]" = None, limit: "typing.Optional[int]" = None, @@ -16591,7 +16591,7 @@ def get_GetAgents( raise APIHttpError("get_GetAgents", _resp) def get_GetAllocation( - session: "api.Session", + session: "api.BaseSession", *, allocationId: str, ) -> "v1GetAllocationResponse": @@ -16617,7 +16617,7 @@ def get_GetAllocation( raise APIHttpError("get_GetAllocation", _resp) def get_GetBestSearcherValidationMetric( - session: "api.Session", + session: "api.BaseSession", *, experimentId: int, ) -> "v1GetBestSearcherValidationMetricResponse": @@ -16641,7 +16641,7 @@ def get_GetBestSearcherValidationMetric( raise APIHttpError("get_GetBestSearcherValidationMetric", _resp) def get_GetCheckpoint( - session: "api.Session", + session: "api.BaseSession", *, checkpointUuid: str, ) -> "v1GetCheckpointResponse": @@ -16667,7 +16667,7 @@ def get_GetCheckpoint( raise APIHttpError("get_GetCheckpoint", _resp) def get_GetCommand( - session: "api.Session", + session: "api.BaseSession", *, commandId: str, ) -> "v1GetCommandResponse": @@ -16693,7 +16693,7 @@ def get_GetCommand( raise APIHttpError("get_GetCommand", _resp) def get_GetCommands( - session: "api.Session", + session: "api.BaseSession", *, limit: "typing.Optional[int]" = None, offset: "typing.Optional[int]" = None, @@ -16750,7 +16750,7 @@ def get_GetCommands( raise APIHttpError("get_GetCommands", _resp) def get_GetCurrentTrialSearcherOperation( - session: "api.Session", + session: "api.BaseSession", *, trialId: int, ) -> "v1GetCurrentTrialSearcherOperationResponse": @@ -16774,7 +16774,7 @@ def 
get_GetCurrentTrialSearcherOperation( raise APIHttpError("get_GetCurrentTrialSearcherOperation", _resp) def get_GetExperiment( - session: "api.Session", + session: "api.BaseSession", *, experimentId: int, ) -> "v1GetExperimentResponse": @@ -16798,7 +16798,7 @@ def get_GetExperiment( raise APIHttpError("get_GetExperiment", _resp) def get_GetExperimentCheckpoints( - session: "api.Session", + session: "api.BaseSession", *, id: int, limit: "typing.Optional[int]" = None, @@ -16862,7 +16862,7 @@ def get_GetExperimentCheckpoints( raise APIHttpError("get_GetExperimentCheckpoints", _resp) def get_GetExperimentLabels( - session: "api.Session", + session: "api.BaseSession", *, projectId: "typing.Optional[int]" = None, ) -> "v1GetExperimentLabelsResponse": @@ -16888,7 +16888,7 @@ def get_GetExperimentLabels( raise APIHttpError("get_GetExperimentLabels", _resp) def get_GetExperimentTrials( - session: "api.Session", + session: "api.BaseSession", *, experimentId: int, limit: "typing.Optional[int]" = None, @@ -16970,7 +16970,7 @@ def get_GetExperimentTrials( raise APIHttpError("get_GetExperimentTrials", _resp) def get_GetExperimentValidationHistory( - session: "api.Session", + session: "api.BaseSession", *, experimentId: int, ) -> "v1GetExperimentValidationHistoryResponse": @@ -16994,7 +16994,7 @@ def get_GetExperimentValidationHistory( raise APIHttpError("get_GetExperimentValidationHistory", _resp) def get_GetExperiments( - session: "api.Session", + session: "api.BaseSession", *, archived: "typing.Optional[bool]" = None, description: "typing.Optional[str]" = None, @@ -17125,7 +17125,7 @@ def get_GetExperiments( raise APIHttpError("get_GetExperiments", _resp) def get_GetGenericTaskConfig( - session: "api.Session", + session: "api.BaseSession", *, taskId: str, ) -> "v1GetGenericTaskConfigResponse": @@ -17151,7 +17151,7 @@ def get_GetGenericTaskConfig( raise APIHttpError("get_GetGenericTaskConfig", _resp) def get_GetGroup( - session: "api.Session", + session: "api.BaseSession", *, 
groupId: int, ) -> "v1GetGroupResponse": @@ -17175,7 +17175,7 @@ def get_GetGroup( raise APIHttpError("get_GetGroup", _resp) def post_GetGroups( - session: "api.Session", + session: "api.BaseSession", *, body: "v1GetGroupsRequest", ) -> "v1GetGroupsResponse": @@ -17196,7 +17196,7 @@ def post_GetGroups( raise APIHttpError("post_GetGroups", _resp) def get_GetGroupsAndUsersAssignedToWorkspace( - session: "api.Session", + session: "api.BaseSession", *, workspaceId: int, name: "typing.Optional[str]" = None, @@ -17227,7 +17227,7 @@ def get_GetGroupsAndUsersAssignedToWorkspace( raise APIHttpError("get_GetGroupsAndUsersAssignedToWorkspace", _resp) def get_GetJobQueueStats( - session: "api.Session", + session: "api.BaseSession", *, resourcePools: "typing.Optional[typing.Sequence[str]]" = None, ) -> "v1GetJobQueueStatsResponse": @@ -17253,7 +17253,7 @@ def get_GetJobQueueStats( raise APIHttpError("get_GetJobQueueStats", _resp) def get_GetJobs( - session: "api.Session", + session: "api.BaseSession", *, limit: "typing.Optional[int]" = None, offset: "typing.Optional[int]" = None, @@ -17301,7 +17301,7 @@ def get_GetJobs( raise APIHttpError("get_GetJobs", _resp) def get_GetJobsV2( - session: "api.Session", + session: "api.BaseSession", *, limit: "typing.Optional[int]" = None, offset: "typing.Optional[int]" = None, @@ -17349,7 +17349,7 @@ def get_GetJobsV2( raise APIHttpError("get_GetJobsV2", _resp) def get_GetMaster( - session: "api.Session", + session: "api.BaseSession", ) -> "v1GetMasterResponse": """Get master information.""" _params = None @@ -17368,7 +17368,7 @@ def get_GetMaster( raise APIHttpError("get_GetMaster", _resp) def get_GetMasterConfig( - session: "api.Session", + session: "api.BaseSession", ) -> "v1GetMasterConfigResponse": """Get master config.""" _params = None @@ -17387,7 +17387,7 @@ def get_GetMasterConfig( raise APIHttpError("get_GetMasterConfig", _resp) def get_GetMe( - session: "api.Session", + session: "api.BaseSession", ) -> "v1GetMeResponse": """Get the 
current user.""" _params = None @@ -17406,7 +17406,7 @@ def get_GetMe( raise APIHttpError("get_GetMe", _resp) def get_GetMetrics( - session: "api.Session", + session: "api.BaseSession", *, group: str, trialIds: "typing.Sequence[int]", @@ -17446,7 +17446,7 @@ def get_GetMetrics( raise APIHttpError("get_GetMetrics", _resp) def get_GetModel( - session: "api.Session", + session: "api.BaseSession", *, modelName: str, ) -> "v1GetModelResponse": @@ -17472,7 +17472,7 @@ def get_GetModel( raise APIHttpError("get_GetModel", _resp) def get_GetModelDef( - session: "api.Session", + session: "api.BaseSession", *, experimentId: int, ) -> "v1GetModelDefResponse": @@ -17496,7 +17496,7 @@ def get_GetModelDef( raise APIHttpError("get_GetModelDef", _resp) def post_GetModelDefFile( - session: "api.Session", + session: "api.BaseSession", *, body: "v1GetModelDefFileRequest", experimentId: int, @@ -17521,7 +17521,7 @@ def post_GetModelDefFile( raise APIHttpError("post_GetModelDefFile", _resp) def get_GetModelDefTree( - session: "api.Session", + session: "api.BaseSession", *, experimentId: int, ) -> "v1GetModelDefTreeResponse": @@ -17545,7 +17545,7 @@ def get_GetModelDefTree( raise APIHttpError("get_GetModelDefTree", _resp) def get_GetModelLabels( - session: "api.Session", + session: "api.BaseSession", *, workspaceId: "typing.Optional[int]" = None, ) -> "v1GetModelLabelsResponse": @@ -17571,7 +17571,7 @@ def get_GetModelLabels( raise APIHttpError("get_GetModelLabels", _resp) def get_GetModelVersion( - session: "api.Session", + session: "api.BaseSession", *, modelName: str, modelVersionNum: int, @@ -17599,7 +17599,7 @@ def get_GetModelVersion( raise APIHttpError("get_GetModelVersion", _resp) def get_GetModelVersions( - session: "api.Session", + session: "api.BaseSession", *, modelName: str, limit: "typing.Optional[int]" = None, @@ -17647,7 +17647,7 @@ def get_GetModelVersions( raise APIHttpError("get_GetModelVersions", _resp) def get_GetModels( - session: "api.Session", + session: 
"api.BaseSession", *, archived: "typing.Optional[bool]" = None, description: "typing.Optional[str]" = None, @@ -17722,7 +17722,7 @@ def get_GetModels( raise APIHttpError("get_GetModels", _resp) def get_GetNotebook( - session: "api.Session", + session: "api.BaseSession", *, notebookId: str, ) -> "v1GetNotebookResponse": @@ -17748,7 +17748,7 @@ def get_GetNotebook( raise APIHttpError("get_GetNotebook", _resp) def get_GetNotebooks( - session: "api.Session", + session: "api.BaseSession", *, limit: "typing.Optional[int]" = None, offset: "typing.Optional[int]" = None, @@ -17806,7 +17806,7 @@ def get_GetNotebooks( raise APIHttpError("get_GetNotebooks", _resp) def get_GetPermissionsSummary( - session: "api.Session", + session: "api.BaseSession", ) -> "v1GetPermissionsSummaryResponse": """List all permissions for the logged in user in all scopes.""" _params = None @@ -17825,7 +17825,7 @@ def get_GetPermissionsSummary( raise APIHttpError("get_GetPermissionsSummary", _resp) def get_GetProject( - session: "api.Session", + session: "api.BaseSession", *, id: int, ) -> "v1GetProjectResponse": @@ -17849,7 +17849,7 @@ def get_GetProject( raise APIHttpError("get_GetProject", _resp) def get_GetProjectColumns( - session: "api.Session", + session: "api.BaseSession", *, id: int, ) -> "v1GetProjectColumnsResponse": @@ -17873,7 +17873,7 @@ def get_GetProjectColumns( raise APIHttpError("get_GetProjectColumns", _resp) def get_GetProjectNumericMetricsRange( - session: "api.Session", + session: "api.BaseSession", *, id: int, ) -> "v1GetProjectNumericMetricsRangeResponse": @@ -17897,7 +17897,7 @@ def get_GetProjectNumericMetricsRange( raise APIHttpError("get_GetProjectNumericMetricsRange", _resp) def get_GetProjectsByUserActivity( - session: "api.Session", + session: "api.BaseSession", *, limit: "typing.Optional[int]" = None, ) -> "v1GetProjectsByUserActivityResponse": @@ -17923,7 +17923,7 @@ def get_GetProjectsByUserActivity( raise APIHttpError("get_GetProjectsByUserActivity", _resp) def 
get_GetResourcePools( - session: "api.Session", + session: "api.BaseSession", *, limit: "typing.Optional[int]" = None, offset: "typing.Optional[int]" = None, @@ -17957,7 +17957,7 @@ def get_GetResourcePools( raise APIHttpError("get_GetResourcePools", _resp) def get_GetRolesAssignedToGroup( - session: "api.Session", + session: "api.BaseSession", *, groupId: int, ) -> "v1GetRolesAssignedToGroupResponse": @@ -17981,7 +17981,7 @@ def get_GetRolesAssignedToGroup( raise APIHttpError("get_GetRolesAssignedToGroup", _resp) def get_GetRolesAssignedToUser( - session: "api.Session", + session: "api.BaseSession", *, userId: int, ) -> "v1GetRolesAssignedToUserResponse": @@ -18005,7 +18005,7 @@ def get_GetRolesAssignedToUser( raise APIHttpError("get_GetRolesAssignedToUser", _resp) def post_GetRolesByID( - session: "api.Session", + session: "api.BaseSession", *, body: "v1GetRolesByIDRequest", ) -> "v1GetRolesByIDResponse": @@ -18026,7 +18026,7 @@ def post_GetRolesByID( raise APIHttpError("post_GetRolesByID", _resp) def get_GetSearcherEvents( - session: "api.Session", + session: "api.BaseSession", *, experimentId: int, ) -> "v1GetSearcherEventsResponse": @@ -18050,7 +18050,7 @@ def get_GetSearcherEvents( raise APIHttpError("get_GetSearcherEvents", _resp) def get_GetShell( - session: "api.Session", + session: "api.BaseSession", *, shellId: str, ) -> "v1GetShellResponse": @@ -18076,7 +18076,7 @@ def get_GetShell( raise APIHttpError("get_GetShell", _resp) def get_GetShells( - session: "api.Session", + session: "api.BaseSession", *, limit: "typing.Optional[int]" = None, offset: "typing.Optional[int]" = None, @@ -18132,7 +18132,7 @@ def get_GetShells( raise APIHttpError("get_GetShells", _resp) def get_GetSlot( - session: "api.Session", + session: "api.BaseSession", *, agentId: str, slotId: str, @@ -18162,7 +18162,7 @@ def get_GetSlot( raise APIHttpError("get_GetSlot", _resp) def get_GetSlots( - session: "api.Session", + session: "api.BaseSession", *, agentId: str, ) -> 
"v1GetSlotsResponse": @@ -18188,7 +18188,7 @@ def get_GetSlots( raise APIHttpError("get_GetSlots", _resp) def get_GetTask( - session: "api.Session", + session: "api.BaseSession", *, taskId: str, ) -> "v1GetTaskResponse": @@ -18214,7 +18214,7 @@ def get_GetTask( raise APIHttpError("get_GetTask", _resp) def get_GetTaskAcceleratorData( - session: "api.Session", + session: "api.BaseSession", *, taskId: str, ) -> "v1GetTaskAcceleratorDataResponse": @@ -18241,7 +18241,7 @@ def get_GetTaskAcceleratorData( raise APIHttpError("get_GetTaskAcceleratorData", _resp) def get_GetTaskContextDirectory( - session: "api.Session", + session: "api.BaseSession", *, taskId: str, ) -> "v1GetTaskContextDirectoryResponse": @@ -18267,7 +18267,7 @@ def get_GetTaskContextDirectory( raise APIHttpError("get_GetTaskContextDirectory", _resp) def get_GetTasks( - session: "api.Session", + session: "api.BaseSession", ) -> "v1GetTasksResponse": """Get all tasks.""" _params = None @@ -18286,7 +18286,7 @@ def get_GetTasks( raise APIHttpError("get_GetTasks", _resp) def get_GetTelemetry( - session: "api.Session", + session: "api.BaseSession", ) -> "v1GetTelemetryResponse": """Get telemetry information.""" _params = None @@ -18305,7 +18305,7 @@ def get_GetTelemetry( raise APIHttpError("get_GetTelemetry", _resp) def get_GetTemplate( - session: "api.Session", + session: "api.BaseSession", *, templateName: str, ) -> "v1GetTemplateResponse": @@ -18331,7 +18331,7 @@ def get_GetTemplate( raise APIHttpError("get_GetTemplate", _resp) def get_GetTemplates( - session: "api.Session", + session: "api.BaseSession", *, limit: "typing.Optional[int]" = None, name: "typing.Optional[str]" = None, @@ -18377,7 +18377,7 @@ def get_GetTemplates( raise APIHttpError("get_GetTemplates", _resp) def get_GetTensorboard( - session: "api.Session", + session: "api.BaseSession", *, tensorboardId: str, ) -> "v1GetTensorboardResponse": @@ -18403,7 +18403,7 @@ def get_GetTensorboard( raise APIHttpError("get_GetTensorboard", _resp) def 
get_GetTensorboards( - session: "api.Session", + session: "api.BaseSession", *, limit: "typing.Optional[int]" = None, offset: "typing.Optional[int]" = None, @@ -18462,7 +18462,7 @@ def get_GetTensorboards( raise APIHttpError("get_GetTensorboards", _resp) def get_GetTrainingMetrics( - session: "api.Session", + session: "api.BaseSession", *, trialIds: "typing.Optional[typing.Sequence[int]]" = None, ) -> "typing.Iterable[v1GetTrainingMetricsResponse]": @@ -18499,7 +18499,7 @@ def get_GetTrainingMetrics( raise APIHttpError("get_GetTrainingMetrics", _resp) def get_GetTrial( - session: "api.Session", + session: "api.BaseSession", *, trialId: int, ) -> "v1GetTrialResponse": @@ -18523,7 +18523,7 @@ def get_GetTrial( raise APIHttpError("get_GetTrial", _resp) def get_GetTrialByExternalID( - session: "api.Session", + session: "api.BaseSession", *, externalExperimentId: str, externalTrialId: str, @@ -18553,7 +18553,7 @@ def get_GetTrialByExternalID( raise APIHttpError("get_GetTrialByExternalID", _resp) def get_GetTrialCheckpoints( - session: "api.Session", + session: "api.BaseSession", *, id: int, limit: "typing.Optional[int]" = None, @@ -18617,7 +18617,7 @@ def get_GetTrialCheckpoints( raise APIHttpError("get_GetTrialCheckpoints", _resp) def get_GetTrialMetricsByCheckpoint( - session: "api.Session", + session: "api.BaseSession", *, checkpointUuid: str, metricGroup: "typing.Optional[str]" = None, @@ -18655,7 +18655,7 @@ def get_GetTrialMetricsByCheckpoint( raise APIHttpError("get_GetTrialMetricsByCheckpoint", _resp) def get_GetTrialMetricsByModelVersion( - session: "api.Session", + session: "api.BaseSession", *, modelName: str, modelVersionNum: int, @@ -18695,7 +18695,7 @@ def get_GetTrialMetricsByModelVersion( raise APIHttpError("get_GetTrialMetricsByModelVersion", _resp) def get_GetTrialProfilerAvailableSeries( - session: "api.Session", + session: "api.BaseSession", *, trialId: int, follow: "typing.Optional[bool]" = None, @@ -18734,7 +18734,7 @@ def 
get_GetTrialProfilerAvailableSeries( raise APIHttpError("get_GetTrialProfilerAvailableSeries", _resp) def get_GetTrialProfilerMetrics( - session: "api.Session", + session: "api.BaseSession", *, labels_trialId: int, follow: "typing.Optional[bool]" = None, @@ -18791,7 +18791,7 @@ def get_GetTrialProfilerMetrics( raise APIHttpError("get_GetTrialProfilerMetrics", _resp) def get_GetTrialWorkloads( - session: "api.Session", + session: "api.BaseSession", *, trialId: int, filter: "typing.Optional[GetTrialWorkloadsRequestFilterOption]" = None, @@ -18858,7 +18858,7 @@ def get_GetTrialWorkloads( raise APIHttpError("get_GetTrialWorkloads", _resp) def get_GetUser( - session: "api.Session", + session: "api.BaseSession", *, userId: int, ) -> "v1GetUserResponse": @@ -18882,7 +18882,7 @@ def get_GetUser( raise APIHttpError("get_GetUser", _resp) def get_GetUserByUsername( - session: "api.Session", + session: "api.BaseSession", *, username: str, ) -> "v1GetUserByUsernameResponse": @@ -18908,7 +18908,7 @@ def get_GetUserByUsername( raise APIHttpError("get_GetUserByUsername", _resp) def get_GetUserSetting( - session: "api.Session", + session: "api.BaseSession", ) -> "v1GetUserSettingResponse": """Get a user's settings for website""" _params = None @@ -18927,7 +18927,7 @@ def get_GetUserSetting( raise APIHttpError("get_GetUserSetting", _resp) def get_GetUsers( - session: "api.Session", + session: "api.BaseSession", *, active: "typing.Optional[bool]" = None, admin: "typing.Optional[bool]" = None, @@ -18989,7 +18989,7 @@ def get_GetUsers( raise APIHttpError("get_GetUsers", _resp) def get_GetValidationMetrics( - session: "api.Session", + session: "api.BaseSession", *, trialIds: "typing.Optional[typing.Sequence[int]]" = None, ) -> "typing.Iterable[v1GetValidationMetricsResponse]": @@ -19026,7 +19026,7 @@ def get_GetValidationMetrics( raise APIHttpError("get_GetValidationMetrics", _resp) def get_GetWebhooks( - session: "api.Session", + session: "api.BaseSession", ) -> 
"v1GetWebhooksResponse": """Get a list of webhooks.""" _params = None @@ -19045,7 +19045,7 @@ def get_GetWebhooks( raise APIHttpError("get_GetWebhooks", _resp) def get_GetWorkspace( - session: "api.Session", + session: "api.BaseSession", *, id: int, ) -> "v1GetWorkspaceResponse": @@ -19069,7 +19069,7 @@ def get_GetWorkspace( raise APIHttpError("get_GetWorkspace", _resp) def get_GetWorkspaceProjects( - session: "api.Session", + session: "api.BaseSession", *, id: int, archived: "typing.Optional[bool]" = None, @@ -19131,7 +19131,7 @@ def get_GetWorkspaceProjects( raise APIHttpError("get_GetWorkspaceProjects", _resp) def get_GetWorkspaces( - session: "api.Session", + session: "api.BaseSession", *, archived: "typing.Optional[bool]" = None, limit: "typing.Optional[int]" = None, @@ -19193,7 +19193,7 @@ def get_GetWorkspaces( raise APIHttpError("get_GetWorkspaces", _resp) def put_IdleNotebook( - session: "api.Session", + session: "api.BaseSession", *, body: "v1IdleNotebookRequest", notebookId: str, @@ -19220,7 +19220,7 @@ def put_IdleNotebook( raise APIHttpError("put_IdleNotebook", _resp) def post_KillCommand( - session: "api.Session", + session: "api.BaseSession", *, commandId: str, ) -> "v1KillCommandResponse": @@ -19246,7 +19246,7 @@ def post_KillCommand( raise APIHttpError("post_KillCommand", _resp) def post_KillExperiment( - session: "api.Session", + session: "api.BaseSession", *, id: int, ) -> None: @@ -19270,7 +19270,7 @@ def post_KillExperiment( raise APIHttpError("post_KillExperiment", _resp) def post_KillExperiments( - session: "api.Session", + session: "api.BaseSession", *, body: "v1KillExperimentsRequest", ) -> "v1KillExperimentsResponse": @@ -19291,7 +19291,7 @@ def post_KillExperiments( raise APIHttpError("post_KillExperiments", _resp) def post_KillGenericTask( - session: "api.Session", + session: "api.BaseSession", *, body: "v1KillGenericTaskRequest", taskId: str, @@ -19318,7 +19318,7 @@ def post_KillGenericTask( raise APIHttpError("post_KillGenericTask", 
_resp) def post_KillNotebook( - session: "api.Session", + session: "api.BaseSession", *, notebookId: str, ) -> "v1KillNotebookResponse": @@ -19344,7 +19344,7 @@ def post_KillNotebook( raise APIHttpError("post_KillNotebook", _resp) def post_KillShell( - session: "api.Session", + session: "api.BaseSession", *, shellId: str, ) -> "v1KillShellResponse": @@ -19370,7 +19370,7 @@ def post_KillShell( raise APIHttpError("post_KillShell", _resp) def post_KillTensorboard( - session: "api.Session", + session: "api.BaseSession", *, tensorboardId: str, ) -> "v1KillTensorboardResponse": @@ -19396,7 +19396,7 @@ def post_KillTensorboard( raise APIHttpError("post_KillTensorboard", _resp) def post_KillTrial( - session: "api.Session", + session: "api.BaseSession", *, id: int, ) -> None: @@ -19420,7 +19420,7 @@ def post_KillTrial( raise APIHttpError("post_KillTrial", _resp) def post_LaunchCommand( - session: "api.Session", + session: "api.BaseSession", *, body: "v1LaunchCommandRequest", ) -> "v1LaunchCommandResponse": @@ -19441,7 +19441,7 @@ def post_LaunchCommand( raise APIHttpError("post_LaunchCommand", _resp) def post_LaunchNotebook( - session: "api.Session", + session: "api.BaseSession", *, body: "v1LaunchNotebookRequest", ) -> "v1LaunchNotebookResponse": @@ -19462,7 +19462,7 @@ def post_LaunchNotebook( raise APIHttpError("post_LaunchNotebook", _resp) def post_LaunchShell( - session: "api.Session", + session: "api.BaseSession", *, body: "v1LaunchShellRequest", ) -> "v1LaunchShellResponse": @@ -19483,7 +19483,7 @@ def post_LaunchShell( raise APIHttpError("post_LaunchShell", _resp) def post_LaunchTensorboard( - session: "api.Session", + session: "api.BaseSession", *, body: "v1LaunchTensorboardRequest", ) -> "v1LaunchTensorboardResponse": @@ -19504,7 +19504,7 @@ def post_LaunchTensorboard( raise APIHttpError("post_LaunchTensorboard", _resp) def get_ListRPsBoundToWorkspace( - session: "api.Session", + session: "api.BaseSession", *, workspaceId: int, limit: "typing.Optional[int]" = 
None, @@ -19536,7 +19536,7 @@ def get_ListRPsBoundToWorkspace( raise APIHttpError("get_ListRPsBoundToWorkspace", _resp) def post_ListRoles( - session: "api.Session", + session: "api.BaseSession", *, body: "v1ListRolesRequest", ) -> "v1ListRolesResponse": @@ -19557,7 +19557,7 @@ def post_ListRoles( raise APIHttpError("post_ListRoles", _resp) def get_ListWorkspacesBoundToRP( - session: "api.Session", + session: "api.BaseSession", *, resourcePoolName: str, limit: "typing.Optional[int]" = None, @@ -19590,7 +19590,7 @@ def get_ListWorkspacesBoundToRP( raise APIHttpError("get_ListWorkspacesBoundToRP", _resp) def post_Login( - session: "api.Session", + session: "api.BaseSession", *, body: "v1LoginRequest", ) -> "v1LoginResponse": @@ -19611,7 +19611,7 @@ def post_Login( raise APIHttpError("post_Login", _resp) def post_Logout( - session: "api.Session", + session: "api.BaseSession", ) -> None: """Logout the user.""" _params = None @@ -19630,7 +19630,7 @@ def post_Logout( raise APIHttpError("post_Logout", _resp) def post_MarkAllocationResourcesDaemon( - session: "api.Session", + session: "api.BaseSession", *, allocationId: str, body: "v1MarkAllocationResourcesDaemonRequest", @@ -19665,7 +19665,7 @@ def post_MarkAllocationResourcesDaemon( raise APIHttpError("post_MarkAllocationResourcesDaemon", _resp) def get_MasterLogs( - session: "api.Session", + session: "api.BaseSession", *, follow: "typing.Optional[bool]" = None, limit: "typing.Optional[int]" = None, @@ -19709,7 +19709,7 @@ def get_MasterLogs( raise APIHttpError("get_MasterLogs", _resp) def get_MetricBatches( - session: "api.Session", + session: "api.BaseSession", *, experimentId: int, metricName: str, @@ -19762,7 +19762,7 @@ def get_MetricBatches( raise APIHttpError("get_MetricBatches", _resp) def post_MoveExperiment( - session: "api.Session", + session: "api.BaseSession", *, body: "v1MoveExperimentRequest", experimentId: int, @@ -19787,7 +19787,7 @@ def post_MoveExperiment( raise APIHttpError("post_MoveExperiment", 
_resp) def post_MoveExperiments( - session: "api.Session", + session: "api.BaseSession", *, body: "v1MoveExperimentsRequest", ) -> "v1MoveExperimentsResponse": @@ -19808,7 +19808,7 @@ def post_MoveExperiments( raise APIHttpError("post_MoveExperiments", _resp) def post_MoveModel( - session: "api.Session", + session: "api.BaseSession", *, body: "v1MoveModelRequest", modelName: str, @@ -19835,7 +19835,7 @@ def post_MoveModel( raise APIHttpError("post_MoveModel", _resp) def post_MoveProject( - session: "api.Session", + session: "api.BaseSession", *, body: "v1MoveProjectRequest", projectId: int, @@ -19860,7 +19860,7 @@ def post_MoveProject( raise APIHttpError("post_MoveProject", _resp) def post_NotifyContainerRunning( - session: "api.Session", + session: "api.BaseSession", *, allocationId: str, body: "v1NotifyContainerRunningRequest", @@ -19892,7 +19892,7 @@ def post_NotifyContainerRunning( raise APIHttpError("post_NotifyContainerRunning", _resp) def put_OverwriteRPWorkspaceBindings( - session: "api.Session", + session: "api.BaseSession", *, body: "v1OverwriteRPWorkspaceBindingsRequest", resourcePoolName: str, @@ -19919,7 +19919,7 @@ def put_OverwriteRPWorkspaceBindings( raise APIHttpError("put_OverwriteRPWorkspaceBindings", _resp) def patch_PatchCheckpoints( - session: "api.Session", + session: "api.BaseSession", *, body: "v1PatchCheckpointsRequest", ) -> None: @@ -19940,7 +19940,7 @@ def patch_PatchCheckpoints( raise APIHttpError("patch_PatchCheckpoints", _resp) def patch_PatchExperiment( - session: "api.Session", + session: "api.BaseSession", *, body: "v1PatchExperiment", experiment_id: int, @@ -19966,7 +19966,7 @@ def patch_PatchExperiment( raise APIHttpError("patch_PatchExperiment", _resp) def patch_PatchMasterConfig( - session: "api.Session", + session: "api.BaseSession", *, body: "v1PatchMasterConfigRequest", ) -> None: @@ -19987,7 +19987,7 @@ def patch_PatchMasterConfig( raise APIHttpError("patch_PatchMasterConfig", _resp) def patch_PatchModel( - session: 
"api.Session", + session: "api.BaseSession", *, body: "v1PatchModel", modelName: str, @@ -20015,7 +20015,7 @@ def patch_PatchModel( raise APIHttpError("patch_PatchModel", _resp) def patch_PatchModelVersion( - session: "api.Session", + session: "api.BaseSession", *, body: "v1PatchModelVersion", modelName: str, @@ -20045,7 +20045,7 @@ def patch_PatchModelVersion( raise APIHttpError("patch_PatchModelVersion", _resp) def patch_PatchProject( - session: "api.Session", + session: "api.BaseSession", *, body: "v1PatchProject", id: int, @@ -20071,7 +20071,7 @@ def patch_PatchProject( raise APIHttpError("patch_PatchProject", _resp) def patch_PatchTemplateConfig( - session: "api.Session", + session: "api.BaseSession", *, body: "typing.Dict[str, typing.Any]", templateName: str, @@ -20099,7 +20099,7 @@ def patch_PatchTemplateConfig( raise APIHttpError("patch_PatchTemplateConfig", _resp) def patch_PatchTrial( - session: "api.Session", + session: "api.BaseSession", *, body: "v1PatchTrialRequest", trialId: int, @@ -20124,7 +20124,7 @@ def patch_PatchTrial( raise APIHttpError("patch_PatchTrial", _resp) def patch_PatchUser( - session: "api.Session", + session: "api.BaseSession", *, body: "v1PatchUser", userId: int, @@ -20150,7 +20150,7 @@ def patch_PatchUser( raise APIHttpError("patch_PatchUser", _resp) def patch_PatchUsers( - session: "api.Session", + session: "api.BaseSession", *, body: "v1PatchUsersRequest", ) -> "v1PatchUsersResponse": @@ -20171,7 +20171,7 @@ def patch_PatchUsers( raise APIHttpError("patch_PatchUsers", _resp) def patch_PatchWorkspace( - session: "api.Session", + session: "api.BaseSession", *, body: "v1PatchWorkspace", id: int, @@ -20197,7 +20197,7 @@ def patch_PatchWorkspace( raise APIHttpError("patch_PatchWorkspace", _resp) def post_PauseExperiment( - session: "api.Session", + session: "api.BaseSession", *, id: int, ) -> None: @@ -20221,7 +20221,7 @@ def post_PauseExperiment( raise APIHttpError("post_PauseExperiment", _resp) def post_PauseExperiments( - session: 
"api.Session", + session: "api.BaseSession", *, body: "v1PauseExperimentsRequest", ) -> "v1PauseExperimentsResponse": @@ -20242,7 +20242,7 @@ def post_PauseExperiments( raise APIHttpError("post_PauseExperiments", _resp) def post_PauseGenericTask( - session: "api.Session", + session: "api.BaseSession", *, taskId: str, ) -> None: @@ -20268,7 +20268,7 @@ def post_PauseGenericTask( raise APIHttpError("post_PauseGenericTask", _resp) def post_PinWorkspace( - session: "api.Session", + session: "api.BaseSession", *, id: int, ) -> None: @@ -20292,7 +20292,7 @@ def post_PinWorkspace( raise APIHttpError("post_PinWorkspace", _resp) def post_PostAllocationAcceleratorData( - session: "api.Session", + session: "api.BaseSession", *, allocationId: str, body: "v1PostAllocationAcceleratorDataRequest", @@ -20319,7 +20319,7 @@ def post_PostAllocationAcceleratorData( raise APIHttpError("post_PostAllocationAcceleratorData", _resp) def post_PostAllocationProxyAddress( - session: "api.Session", + session: "api.BaseSession", *, allocationId: str, body: "v1PostAllocationProxyAddressRequest", @@ -20348,7 +20348,7 @@ def post_PostAllocationProxyAddress( raise APIHttpError("post_PostAllocationProxyAddress", _resp) def post_PostCheckpointMetadata( - session: "api.Session", + session: "api.BaseSession", *, body: "v1PostCheckpointMetadataRequest", checkpoint_uuid: str, @@ -20375,7 +20375,7 @@ def post_PostCheckpointMetadata( raise APIHttpError("post_PostCheckpointMetadata", _resp) def post_PostModel( - session: "api.Session", + session: "api.BaseSession", *, body: "v1PostModelRequest", ) -> "v1PostModelResponse": @@ -20396,7 +20396,7 @@ def post_PostModel( raise APIHttpError("post_PostModel", _resp) def post_PostModelVersion( - session: "api.Session", + session: "api.BaseSession", *, body: "v1PostModelVersionRequest", modelName: str, @@ -20423,7 +20423,7 @@ def post_PostModelVersion( raise APIHttpError("post_PostModelVersion", _resp) def post_PostProject( - session: "api.Session", + session: 
"api.BaseSession", *, body: "v1PostProjectRequest", workspaceId: int, @@ -20448,7 +20448,7 @@ def post_PostProject( raise APIHttpError("post_PostProject", _resp) def post_PostSearcherOperations( - session: "api.Session", + session: "api.BaseSession", *, body: "v1PostSearcherOperationsRequest", experimentId: int, @@ -20473,7 +20473,7 @@ def post_PostSearcherOperations( raise APIHttpError("post_PostSearcherOperations", _resp) def post_PostTaskLogs( - session: "api.Session", + session: "api.BaseSession", *, body: "v1PostTaskLogsRequest", ) -> None: @@ -20494,7 +20494,7 @@ def post_PostTaskLogs( raise APIHttpError("post_PostTaskLogs", _resp) def post_PostTemplate( - session: "api.Session", + session: "api.BaseSession", *, body: "v1Template", template_name: str, @@ -20522,7 +20522,7 @@ def post_PostTemplate( raise APIHttpError("post_PostTemplate", _resp) def post_PostTrialProfilerMetricsBatch( - session: "api.Session", + session: "api.BaseSession", *, body: "v1PostTrialProfilerMetricsBatchRequest", ) -> None: @@ -20543,7 +20543,7 @@ def post_PostTrialProfilerMetricsBatch( raise APIHttpError("post_PostTrialProfilerMetricsBatch", _resp) def post_PostTrialRunnerMetadata( - session: "api.Session", + session: "api.BaseSession", *, body: "v1TrialRunnerMetadata", trialId: int, @@ -20569,7 +20569,7 @@ def post_PostTrialRunnerMetadata( raise APIHttpError("post_PostTrialRunnerMetadata", _resp) def post_PostUser( - session: "api.Session", + session: "api.BaseSession", *, body: "v1PostUserRequest", ) -> "v1PostUserResponse": @@ -20590,7 +20590,7 @@ def post_PostUser( raise APIHttpError("post_PostUser", _resp) def post_PostUserActivity( - session: "api.Session", + session: "api.BaseSession", *, body: "v1PostUserActivityRequest", ) -> None: @@ -20611,7 +20611,7 @@ def post_PostUserActivity( raise APIHttpError("post_PostUserActivity", _resp) def post_PostUserSetting( - session: "api.Session", + session: "api.BaseSession", *, body: "v1PostUserSettingRequest", ) -> None: @@ -20632,7 
+20632,7 @@ def post_PostUserSetting( raise APIHttpError("post_PostUserSetting", _resp) def post_PostWebhook( - session: "api.Session", + session: "api.BaseSession", *, body: "v1Webhook", ) -> "v1PostWebhookResponse": @@ -20657,7 +20657,7 @@ def post_PostWebhook( raise APIHttpError("post_PostWebhook", _resp) def post_PostWorkspace( - session: "api.Session", + session: "api.BaseSession", *, body: "v1PostWorkspaceRequest", ) -> "v1PostWorkspaceResponse": @@ -20678,7 +20678,7 @@ def post_PostWorkspace( raise APIHttpError("post_PostWorkspace", _resp) def post_PreviewHPSearch( - session: "api.Session", + session: "api.BaseSession", *, body: "v1PreviewHPSearchRequest", ) -> "v1PreviewHPSearchResponse": @@ -20699,7 +20699,7 @@ def post_PreviewHPSearch( raise APIHttpError("post_PreviewHPSearch", _resp) def put_PutExperiment( - session: "api.Session", + session: "api.BaseSession", *, body: "v1CreateExperimentRequest", externalExperimentId: str, @@ -20727,7 +20727,7 @@ def put_PutExperiment( raise APIHttpError("put_PutExperiment", _resp) def put_PutExperimentLabel( - session: "api.Session", + session: "api.BaseSession", *, experimentId: int, label: str, @@ -20755,7 +20755,7 @@ def put_PutExperimentLabel( raise APIHttpError("put_PutExperimentLabel", _resp) def put_PutProjectNotes( - session: "api.Session", + session: "api.BaseSession", *, body: "v1PutProjectNotesRequest", projectId: int, @@ -20780,7 +20780,7 @@ def put_PutProjectNotes( raise APIHttpError("put_PutProjectNotes", _resp) def put_PutTemplate( - session: "api.Session", + session: "api.BaseSession", *, body: "v1Template", template_name: str, @@ -20808,7 +20808,7 @@ def put_PutTemplate( raise APIHttpError("put_PutTemplate", _resp) def put_PutTrial( - session: "api.Session", + session: "api.BaseSession", *, body: "v1PutTrialRequest", ) -> "v1PutTrialResponse": @@ -20829,7 +20829,7 @@ def put_PutTrial( raise APIHttpError("put_PutTrial", _resp) def post_RemoveAssignments( - session: "api.Session", + session: 
"api.BaseSession", *, body: "v1RemoveAssignmentsRequest", ) -> None: @@ -20850,7 +20850,7 @@ def post_RemoveAssignments( raise APIHttpError("post_RemoveAssignments", _resp) def post_ReportCheckpoint( - session: "api.Session", + session: "api.BaseSession", *, body: "v1Checkpoint", ) -> None: @@ -20874,7 +20874,7 @@ def post_ReportCheckpoint( raise APIHttpError("post_ReportCheckpoint", _resp) def post_ReportTrialMetrics( - session: "api.Session", + session: "api.BaseSession", *, body: "v1ReportTrialMetricsRequest", metrics_trialId: int, @@ -20899,7 +20899,7 @@ def post_ReportTrialMetrics( raise APIHttpError("post_ReportTrialMetrics", _resp) def post_ReportTrialProgress( - session: "api.Session", + session: "api.BaseSession", *, body: float, trialId: int, @@ -20927,7 +20927,7 @@ def post_ReportTrialProgress( raise APIHttpError("post_ReportTrialProgress", _resp) def post_ReportTrialSearcherEarlyExit( - session: "api.Session", + session: "api.BaseSession", *, body: "v1TrialEarlyExit", trialId: int, @@ -20955,7 +20955,7 @@ def post_ReportTrialSearcherEarlyExit( raise APIHttpError("post_ReportTrialSearcherEarlyExit", _resp) def post_ReportTrialSourceInfo( - session: "api.Session", + session: "api.BaseSession", *, body: "v1ReportTrialSourceInfoRequest", ) -> "v1ReportTrialSourceInfoResponse": @@ -20976,7 +20976,7 @@ def post_ReportTrialSourceInfo( raise APIHttpError("post_ReportTrialSourceInfo", _resp) def post_ReportTrialTrainingMetrics( - session: "api.Session", + session: "api.BaseSession", *, body: "v1TrialMetrics", trainingMetrics_trialId: int, @@ -21002,7 +21002,7 @@ def post_ReportTrialTrainingMetrics( raise APIHttpError("post_ReportTrialTrainingMetrics", _resp) def post_ReportTrialValidationMetrics( - session: "api.Session", + session: "api.BaseSession", *, body: "v1TrialMetrics", validationMetrics_trialId: int, @@ -21028,7 +21028,7 @@ def post_ReportTrialValidationMetrics( raise APIHttpError("post_ReportTrialValidationMetrics", _resp) def post_ResetUserSetting( - 
session: "api.Session", + session: "api.BaseSession", ) -> None: """Reset a user's settings for website""" _params = None @@ -21047,7 +21047,7 @@ def post_ResetUserSetting( raise APIHttpError("post_ResetUserSetting", _resp) def get_ResourceAllocationAggregated( - session: "api.Session", + session: "api.BaseSession", *, endDate: str, period: "v1ResourceAllocationAggregationPeriod", @@ -21086,7 +21086,7 @@ def get_ResourceAllocationAggregated( raise APIHttpError("get_ResourceAllocationAggregated", _resp) def get_ResourceAllocationRaw( - session: "api.Session", + session: "api.BaseSession", *, timestampAfter: str, timestampBefore: str, @@ -21115,7 +21115,7 @@ def get_ResourceAllocationRaw( raise APIHttpError("get_ResourceAllocationRaw", _resp) def post_RunPrepareForReporting( - session: "api.Session", + session: "api.BaseSession", *, body: "v1RunPrepareForReportingRequest", ) -> "v1RunPrepareForReportingResponse": @@ -21138,7 +21138,7 @@ def post_RunPrepareForReporting( raise APIHttpError("post_RunPrepareForReporting", _resp) def get_SearchExperiments( - session: "api.Session", + session: "api.BaseSession", *, filter: "typing.Optional[str]" = None, limit: "typing.Optional[int]" = None, @@ -21176,7 +21176,7 @@ def get_SearchExperiments( raise APIHttpError("get_SearchExperiments", _resp) def post_SearchRolesAssignableToScope( - session: "api.Session", + session: "api.BaseSession", *, body: "v1SearchRolesAssignableToScopeRequest", ) -> "v1SearchRolesAssignableToScopeResponse": @@ -21197,7 +21197,7 @@ def post_SearchRolesAssignableToScope( raise APIHttpError("post_SearchRolesAssignableToScope", _resp) def post_SetCommandPriority( - session: "api.Session", + session: "api.BaseSession", *, body: "v1SetCommandPriorityRequest", commandId: str, @@ -21224,7 +21224,7 @@ def post_SetCommandPriority( raise APIHttpError("post_SetCommandPriority", _resp) def post_SetNotebookPriority( - session: "api.Session", + session: "api.BaseSession", *, body: "v1SetNotebookPriorityRequest", 
notebookId: str, @@ -21251,7 +21251,7 @@ def post_SetNotebookPriority( raise APIHttpError("post_SetNotebookPriority", _resp) def post_SetShellPriority( - session: "api.Session", + session: "api.BaseSession", *, body: "v1SetShellPriorityRequest", shellId: str, @@ -21278,7 +21278,7 @@ def post_SetShellPriority( raise APIHttpError("post_SetShellPriority", _resp) def post_SetTensorboardPriority( - session: "api.Session", + session: "api.BaseSession", *, body: "v1SetTensorboardPriorityRequest", tensorboardId: str, @@ -21305,7 +21305,7 @@ def post_SetTensorboardPriority( raise APIHttpError("post_SetTensorboardPriority", _resp) def post_SetUserPassword( - session: "api.Session", + session: "api.BaseSession", *, body: str, userId: int, @@ -21331,7 +21331,7 @@ def post_SetUserPassword( raise APIHttpError("post_SetUserPassword", _resp) def post_StartTrial( - session: "api.Session", + session: "api.BaseSession", *, body: "v1StartTrialRequest", trialId: int, @@ -21356,7 +21356,7 @@ def post_StartTrial( raise APIHttpError("post_StartTrial", _resp) def get_TaskLogs( - session: "api.Session", + session: "api.BaseSession", *, taskId: str, agentIds: "typing.Optional[typing.Sequence[str]]" = None, @@ -21445,7 +21445,7 @@ def get_TaskLogs( raise APIHttpError("get_TaskLogs", _resp) def get_TaskLogsFields( - session: "api.Session", + session: "api.BaseSession", *, taskId: str, follow: "typing.Optional[bool]" = None, @@ -21486,7 +21486,7 @@ def get_TaskLogsFields( raise APIHttpError("get_TaskLogsFields", _resp) def post_TestWebhook( - session: "api.Session", + session: "api.BaseSession", *, id: int, ) -> "v1TestWebhookResponse": @@ -21510,7 +21510,7 @@ def post_TestWebhook( raise APIHttpError("post_TestWebhook", _resp) def get_TrialLogs( - session: "api.Session", + session: "api.BaseSession", *, trialId: int, agentIds: "typing.Optional[typing.Sequence[str]]" = None, @@ -21594,7 +21594,7 @@ def get_TrialLogs( raise APIHttpError("get_TrialLogs", _resp) def get_TrialLogsFields( - session: 
"api.Session", + session: "api.BaseSession", *, trialId: int, follow: "typing.Optional[bool]" = None, @@ -21633,7 +21633,7 @@ def get_TrialLogsFields( raise APIHttpError("get_TrialLogsFields", _resp) def get_TrialsSample( - session: "api.Session", + session: "api.BaseSession", *, experimentId: int, metricName: str, @@ -21697,7 +21697,7 @@ def get_TrialsSample( raise APIHttpError("get_TrialsSample", _resp) def get_TrialsSnapshot( - session: "api.Session", + session: "api.BaseSession", *, batchesProcessed: int, experimentId: int, @@ -21756,7 +21756,7 @@ def get_TrialsSnapshot( raise APIHttpError("get_TrialsSnapshot", _resp) def post_UnarchiveExperiment( - session: "api.Session", + session: "api.BaseSession", *, id: int, ) -> None: @@ -21780,7 +21780,7 @@ def post_UnarchiveExperiment( raise APIHttpError("post_UnarchiveExperiment", _resp) def post_UnarchiveExperiments( - session: "api.Session", + session: "api.BaseSession", *, body: "v1UnarchiveExperimentsRequest", ) -> "v1UnarchiveExperimentsResponse": @@ -21801,7 +21801,7 @@ def post_UnarchiveExperiments( raise APIHttpError("post_UnarchiveExperiments", _resp) def post_UnarchiveModel( - session: "api.Session", + session: "api.BaseSession", *, modelName: str, ) -> None: @@ -21827,7 +21827,7 @@ def post_UnarchiveModel( raise APIHttpError("post_UnarchiveModel", _resp) def post_UnarchiveProject( - session: "api.Session", + session: "api.BaseSession", *, id: int, ) -> None: @@ -21851,7 +21851,7 @@ def post_UnarchiveProject( raise APIHttpError("post_UnarchiveProject", _resp) def post_UnarchiveWorkspace( - session: "api.Session", + session: "api.BaseSession", *, id: int, ) -> None: @@ -21875,7 +21875,7 @@ def post_UnarchiveWorkspace( raise APIHttpError("post_UnarchiveWorkspace", _resp) def delete_UnbindRPFromWorkspace( - session: "api.Session", + session: "api.BaseSession", *, body: "v1UnbindRPFromWorkspaceRequest", resourcePoolName: str, @@ -21902,7 +21902,7 @@ def delete_UnbindRPFromWorkspace( raise 
APIHttpError("delete_UnbindRPFromWorkspace", _resp) def post_UnpauseGenericTask( - session: "api.Session", + session: "api.BaseSession", *, taskId: str, ) -> None: @@ -21928,7 +21928,7 @@ def post_UnpauseGenericTask( raise APIHttpError("post_UnpauseGenericTask", _resp) def post_UnpinWorkspace( - session: "api.Session", + session: "api.BaseSession", *, id: int, ) -> None: @@ -21952,7 +21952,7 @@ def post_UnpinWorkspace( raise APIHttpError("post_UnpinWorkspace", _resp) def put_UpdateGroup( - session: "api.Session", + session: "api.BaseSession", *, body: "v1UpdateGroupRequest", groupId: int, @@ -21977,7 +21977,7 @@ def put_UpdateGroup( raise APIHttpError("put_UpdateGroup", _resp) def post_UpdateJobQueue( - session: "api.Session", + session: "api.BaseSession", *, body: "v1UpdateJobQueueRequest", ) -> None: diff --git a/harness/determined/common/api/certs.py b/harness/determined/common/api/certs.py index 92adfb0a522..372ee4d6191 100644 --- a/harness/determined/common/api/certs.py +++ b/harness/determined/common/api/certs.py @@ -65,9 +65,6 @@ def name(self) -> Optional[str]: return self._name -cli_cert = None # type: Optional[Cert] - - class CertStore: """ CertStore represents a persistent file-based record of certificates, each associated with a diff --git a/harness/determined/common/api/errors.py b/harness/determined/common/api/errors.py index 798cf059378..12d7f1b0547 100644 --- a/harness/determined/common/api/errors.py +++ b/harness/determined/common/api/errors.py @@ -68,23 +68,21 @@ def __init__(self, error_message: str) -> None: class ForbiddenException(BadRequestException): - def __init__(self, username: str, message: str = ""): + def __init__(self, message: str = ""): err_message = f"Forbidden({message})" if not (message == "invalid credentials" or message == "user not found"): err_message += ": Please contact your administrator in order to access this resource." 
super().__init__(message=err_message) - self.username = username class UnauthenticatedException(BadRequestException): - def __init__(self, username: str): + def __init__(self) -> None: super().__init__( message="Unauthenticated: Please use 'det user login ' for password login, or" " for Enterprise users logging in with an SSO provider," " use 'det auth login --provider='." ) - self.username = username class CorruptTokenCacheException(Exception): diff --git a/harness/determined/common/api/profiler.py b/harness/determined/common/api/profiler.py deleted file mode 100644 index 23fd3eccdb3..00000000000 --- a/harness/determined/common/api/profiler.py +++ /dev/null @@ -1,112 +0,0 @@ -import time -from typing import Any, Dict, List, Optional - -from requests import exceptions - -from determined.common import api - - -class TrialProfilerMetricsBatch: - """ - TrialProfilerMetricsBatch is the representation of a batch of trial - profiler metrics as accepted by POST /api/v1/trials/:trial_id/profiler/metrics - """ - - def __init__( - self, - values: List[float], - batches: List[int], - timestamps: List[str], - labels: Dict[str, Any], - ): - self.values = values - self.batches = batches - self.timestamps = timestamps - self.labels = labels - - -def post_trial_profiler_metrics_batches( - master_url: str, - batches: List[TrialProfilerMetricsBatch], -) -> None: - """ - Post the given metrics to the master to be persisted. Labels - must contain only a subset of the keys: trial_id, name, - gpu_uuid, agent_id and metric_type, where metric_type is one - of PROFILER_METRIC_TYPE_SYSTEM or PROFILER_METRIC_TYPE_TIMING. 
- """ - backoff_interval = 1 - max_tries = 2 - tries = 0 - - while tries < max_tries: - try: - api.post( - master_url, - "/api/v1/trials/profiler/metrics", - json={"batches": [b.__dict__ for b in batches]}, - ) - return - except exceptions.RequestException as e: - if e.response is not None and e.response.status_code < 500: - raise e - - tries += 1 - if tries == max_tries: - raise e - time.sleep(backoff_interval) - return - - -class TrialProfilerSeriesLabels: - def __init__(self, trial_id: int, name: str, agent_id: str, gpu_uuid: str, metric_type: str): - self.trial_id = str(trial_id) - self.name = name - self.agent_id = agent_id - self.gpu_uuid = gpu_uuid if gpu_uuid != "" else None # type: Optional[str] - self.metric_type = metric_type - - -def get_trial_profiler_available_series( - master_url: str, - trial_id: str, -) -> List[TrialProfilerSeriesLabels]: - """ - Get available profiler series for a trial. This uses the non-streaming version of the API - """ - follow = False - backoff_interval = 1 - max_tries = 2 - tries = 0 - - response = None - while tries < max_tries: - try: - response = api.get( - host=master_url, - path=f"/api/v1/trials/{trial_id}/profiler/available_series", - params={"follow": follow}, - ) - break - except exceptions.RequestException as e: - if e.response is not None and e.response.status_code < 500: - raise e - - tries += 1 - if tries == max_tries: - raise e - time.sleep(backoff_interval) - - assert response - j = response.json() - labels = [ - TrialProfilerSeriesLabels( - trial_id=ld["trialId"], - name=ld["name"], - agent_id=ld["agentId"], - gpu_uuid=ld["gpuUuid"], - metric_type=ld["metricType"], - ) - for ld in j["result"]["labels"] - ] - return labels diff --git a/harness/determined/common/api/request.py b/harness/determined/common/api/request.py index 56cc18dd075..7a09827695e 100644 --- a/harness/determined/common/api/request.py +++ b/harness/determined/common/api/request.py @@ -1,17 +1,7 @@ -import json as _json import os -import types 
import webbrowser -from typing import Any, Dict, Iterator, Optional, Tuple, Union from urllib import parse -import requests -import urllib3 - -import determined as det -import determined.common.requests -from determined.common.api import authentication, certs, errors - def parse_master_address(master_address: str) -> parse.ParseResult: if master_address.startswith("https://"): @@ -84,288 +74,7 @@ def make_interactive_task_url( return task_web_url -def do_request( - method: str, - host: str, - path: str, - params: Optional[Dict[str, Any]] = None, - json: Any = None, - data: Optional[str] = None, - headers: Optional[Dict[str, str]] = None, - authenticated: bool = True, - auth: Optional[authentication.Authentication] = None, - cert: Optional[certs.Cert] = None, - stream: bool = False, - timeout: Optional[Union[Tuple, float]] = None, - max_retries: Optional[urllib3.util.retry.Retry] = None, -) -> requests.Response: - if headers is None: - h: Dict[str, str] = {} - else: - h = headers - - if cert is None: - cert = certs.cli_cert - - # set the token and username based on this order: - # - argument `auth` - # - header `Authorization` - # - existing cli_auth - # - allocation_token - - username = "" - if auth is not None: - if authenticated: - h["Authorization"] = "Bearer {}".format(auth.get_session_token()) - username = auth.get_session_user() - elif h.get("Authorization") is not None: - pass - elif authentication.cli_auth is not None: - if authenticated: - h["Authorization"] = "Bearer {}".format(authentication.cli_auth.get_session_token()) - username = authentication.cli_auth.get_session_user() - elif authenticated and h.get("Grpc-Metadata-x-allocation-token") is None: - allocation_token = authentication.get_allocation_token() - if allocation_token: - h["Grpc-Metadata-x-allocation-token"] = "Bearer {}".format(allocation_token) - - if params is None: - params = {} - - # Allow the json json to come pre-encoded, if we need custom encoding. 
- if json is not None and data is not None: - raise ValueError("json and data must not be provided together") - - if json: - data = det.util.json_encode(json) - - try: - r = determined.common.requests.request( - method, - make_url(host, path), - params=params, - data=data, - headers=h, - verify=cert.bundle if cert else None, - stream=stream, - timeout=timeout, - server_hostname=cert.name if cert else None, - max_retries=max_retries, - ) - except requests.exceptions.SSLError: - raise - except requests.exceptions.ConnectionError as e: - raise errors.MasterNotFoundException(str(e)) - except requests.exceptions.RequestException as e: - raise errors.BadRequestException(str(e)) - - def _get_error_str(r: requests.models.Response) -> str: - try: - json_resp = _json.loads(r.text) - mes = json_resp.get("message") - if mes is not None: - return str(mes) - # Try getting GRPC error description if message does not exist. - return str(json_resp.get("error").get("error")) - except Exception: - return "" - - if r.status_code == 403: - raise errors.ForbiddenException(username=username, message=_get_error_str(r)) - if r.status_code == 401: - raise errors.UnauthenticatedException(username=username) - elif r.status_code == 404: - raise errors.NotFoundException(_get_error_str(r)) - elif r.status_code >= 300: - raise errors.APIException(r) - - return r - - -def get( - host: str, - path: str, - params: Optional[Dict[str, Any]] = None, - headers: Optional[Dict[str, str]] = None, - authenticated: bool = True, - auth: Optional[authentication.Authentication] = None, - cert: Optional[certs.Cert] = None, - stream: bool = False, - timeout: Optional[Union[Tuple, float]] = None, -) -> requests.Response: - """ - Send a GET request to the remote API. 
- """ - return do_request( - "GET", - host, - path, - params=params, - headers=headers, - authenticated=authenticated, - auth=auth, - cert=cert, - stream=stream, - ) - - -def delete( - host: str, - path: str, - params: Optional[Dict[str, Any]] = None, - headers: Optional[Dict[str, str]] = None, - authenticated: bool = True, - auth: Optional[authentication.Authentication] = None, - cert: Optional[certs.Cert] = None, - timeout: Optional[Union[Tuple, float]] = None, -) -> requests.Response: - """ - Send a DELETE request to the remote API. - """ - return do_request( - "DELETE", - host, - path, - params=params, - headers=headers, - authenticated=authenticated, - auth=auth, - cert=cert, - timeout=timeout, - ) - - -def post( - host: str, - path: str, - json: Any = None, - headers: Optional[Dict[str, str]] = None, - authenticated: bool = True, - auth: Optional[authentication.Authentication] = None, - cert: Optional[certs.Cert] = None, - timeout: Optional[Union[Tuple, float]] = None, -) -> requests.Response: - """ - Send a POST request to the remote API. - """ - return do_request( - "POST", - host, - path, - json=json, - headers=headers, - authenticated=authenticated, - auth=auth, - cert=cert, - timeout=timeout, - ) - - -def patch( - host: str, - path: str, - json: Dict[str, Any], - headers: Optional[Dict[str, str]] = None, - authenticated: bool = True, - auth: Optional[authentication.Authentication] = None, - cert: Optional[certs.Cert] = None, - timeout: Optional[Union[Tuple, float]] = None, -) -> requests.Response: - """ - Send a PATCH request to the remote API. 
- """ - return do_request( - "PATCH", - host, - path, - json=json, - headers=headers, - authenticated=authenticated, - auth=auth, - cert=cert, - timeout=timeout, - ) - - -def put( - host: str, - path: str, - json: Optional[Dict[str, Any]] = None, - headers: Optional[Dict[str, str]] = None, - authenticated: bool = True, - auth: Optional[authentication.Authentication] = None, - cert: Optional[certs.Cert] = None, - timeout: Optional[Union[Tuple, float]] = None, -) -> requests.Response: - """ - Send a PUT request to the remote API. - """ - return do_request( - "PUT", - host, - path, - json=json, - headers=headers, - authenticated=authenticated, - auth=auth, - cert=cert, - timeout=timeout, - ) - - def browser_open(host: str, path: str) -> str: url = make_url(host, path) webbrowser.open(url) return url - - -class WebSocket: - def __init__(self, socket: Any) -> None: - import lomond - - self.socket = socket # type: lomond.WebSocket - - def __enter__(self) -> "WebSocket": - return self - - def __iter__(self) -> Iterator[Any]: - from lomond import events - - for event in self.socket.connect(ping_rate=0): - if isinstance(event, events.Connected): - # Ignore the initial connection event. - pass - elif isinstance(event, (events.Closing, events.Disconnected)): - # The socket was successfully closed so we just return. - return - elif isinstance( - event, - (events.ConnectFail, events.Rejected, events.ProtocolError), - ): - # Any unexpected failures raise the standard API exception. - raise errors.BadRequestException(message="WebSocket failure: {}".format(event)) - elif isinstance(event, events.Text): - # All web socket connections are expected to be in a JSON - # format. 
- yield _json.loads(event.text) - - def __exit__( - self, - exc_type: Optional[type], - exc_val: Optional[BaseException], - exc_tb: Optional[types.TracebackType], - ) -> None: - if not self.socket.is_closed: - self.socket.close() - - -def ws(host: str, path: str) -> WebSocket: - """ - Connect to a web socket at the remote API. - """ - import lomond - - websocket = lomond.WebSocket(maybe_upgrade_ws_scheme(make_url(host, path))) - token = authentication.must_cli_auth().get_session_token() - websocket.add_header("Authorization".encode(), "Bearer {}".format(token).encode()) - return WebSocket(websocket) diff --git a/harness/determined/common/experimental/__init__.py b/harness/determined/common/experimental/__init__.py index 388e212eb0c..77cd5ea6935 100644 --- a/harness/determined/common/experimental/__init__.py +++ b/harness/determined/common/experimental/__init__.py @@ -1,7 +1,6 @@ import warnings # TODO: delete all of these when det.experimental.client is removed. -from determined.common.api import Session from determined.common.experimental.checkpoint import Checkpoint from determined.common.experimental.determined import Determined from determined.common.experimental.experiment import Experiment, ExperimentReference diff --git a/harness/determined/common/experimental/determined.py b/harness/determined/common/experimental/determined.py index 7784b1d524a..26f8a2b56cc 100644 --- a/harness/determined/common/experimental/determined.py +++ b/harness/determined/common/experimental/determined.py @@ -25,6 +25,7 @@ class Determined: + # Dev note: Determined is basically a wrapper around Session that calls generated bindings. """ Determined gives access to Determined API objects. 
@@ -54,14 +55,9 @@ def __init__( explicit_noverify=noverify, ) - auth = authentication.Authentication(self._master, user, password, cert=cert) + utp = authentication.login_with_cache(self._master, user, password, cert=cert) retry = api.default_retry() - self._session = api.Session(self._master, user, auth, cert, retry) - token_user = auth.token_store.get_active_user() - if token_user is not None: - self._token = auth.token_store.get_token(token_user) - else: - self._token = None + self._session = api.Session(self._master, utp, cert, retry) @classmethod def _from_session(cls, session: api.Session) -> "Determined": @@ -69,18 +65,14 @@ def _from_session(cls, session: api.Session) -> "Determined": This constructor exists to help the CLI transition to using SDK methods, most of which are derived from a Determined object at some point in their lifespan. - - WARNING: Determined objects created with this contsructor will not have a token, and so - have no access to the oauth API. """ # mypy gives new_det "Any" type, even if cls is annotated new_det = cls.__new__(cls) # type: Determined new_det._session = session - new_det._token = None return new_det def create_user( - self, username: str, admin: bool, password: Optional[str], remote: bool = False + self, username: str, admin: bool, password: Optional[str] = None, remote: bool = False ) -> user.User: create_user = bindings.v1User(username=username, admin=admin, active=True, remote=remote) hashedPassword = None @@ -105,22 +97,10 @@ def whoami(self) -> user.User: return user.User._from_bindings(resp.user, self._session) def get_session_username(self) -> str: - auth = self._session._auth - assert auth - return auth.get_session_user() + return self._session.username def logout(self) -> None: - auth = self._session._auth - # auth should only be None in the special login Session, which must not be used in a - # Determined object. 
- assert auth, "Determined.logout() found an unauthorized Session" - - user = auth.get_session_user() - # get_session_user() is allowed to return an empty string, which seems dumb, but in that - # case we do not want to trigger the authentication.logout default username lookup logic. - assert user, "Determined.logout() couldn't find a valid username" - - authentication.logout(self._session._master, user, self._session._cert) + authentication.logout(self._session.master, self._session.username, self._session.cert) def list_users(self, active: Optional[bool] = None) -> List[user.User]: def get_with_offset(offset: int) -> bindings.v1GetUsersResponse: @@ -478,8 +458,7 @@ def get_model_labels(self) -> List[str]: def list_oauth_clients(self) -> Sequence[oauth2_scim_client.Oauth2ScimClient]: try: oauth2_scim_clients: List[oauth2_scim_client.Oauth2ScimClient] = [] - headers = {"Authorization": "Bearer {}".format(self._token)} - clients = api.get(self._master, "oauth2/clients", headers=headers).json() + clients = self._session.get("oauth2/clients").json() for client in clients: osc: oauth2_scim_client.Oauth2ScimClient = oauth2_scim_client.Oauth2ScimClient( name=client["name"], client_id=client["id"], domain=client["domain"] @@ -491,14 +470,9 @@ def list_oauth_clients(self) -> Sequence[oauth2_scim_client.Oauth2ScimClient]: def add_oauth_client(self, domain: str, name: str) -> oauth2_scim_client.Oauth2ScimClient: try: - headers = {"Authorization": "Bearer {}".format(self._token)} - client = api.post( - self._master, - "oauth2/clients", - headers=headers, - json={"domain": domain, "name": name}, + client = self._session.post( + "oauth2/clients", json={"domain": domain, "name": name} ).json() - return oauth2_scim_client.Oauth2ScimClient( client_id=str(client["id"]), secret=str(client["secret"]), domain=domain, name=name ) @@ -508,8 +482,7 @@ def add_oauth_client(self, domain: str, name: str) -> oauth2_scim_client.Oauth2S def remove_oauth_client(self, client_id: str) -> None: 
try: - headers = {"Authorization": "Bearer {}".format(self._token)} - api.delete(self._master, "oauth2/clients/{}".format(client_id), headers=headers) + self._session.delete(f"oauth2/clients/{client_id}") except api.errors.NotFoundException: raise det.errors.EnterpriseOnlyError("API not found: oauth2/clients") diff --git a/harness/determined/common/experimental/session.py b/harness/determined/common/experimental/session.py deleted file mode 100644 index 3c84d185f47..00000000000 --- a/harness/determined/common/experimental/session.py +++ /dev/null @@ -1,7 +0,0 @@ -# link for backwards compatibility, since this was the original place that Session was exposed. -# At the present time, users should be using `determined.experimental.client.Session` instead, but -# since there's a big breaking change after we remove `client` from `determined.experimental`, we -# should remove this at that time. -# TODO: remove this link when we remove `client` from `determined.experimental`. - -from determined.common.api import Session # noqa: F401 diff --git a/harness/determined/core/_context.py b/harness/determined/core/_context.py index 4e2c79d0ac8..7ea00ad12f4 100644 --- a/harness/determined/core/_context.py +++ b/harness/determined/core/_context.py @@ -12,7 +12,7 @@ import determined as det from determined import core, tensorboard from determined.common import api, constants, storage, util -from determined.common.api import bindings, certs +from determined.common.api import authentication, bindings, certs from determined.common.storage import shared logger = logging.getLogger("determined.core") @@ -226,9 +226,8 @@ def init( # We are on the cluster. 
cert = certs.default_load(info.master_url) - session = api.Session( - info.master_url, None, None, cert, max_retries=util.get_max_retries_config() - ) + utp = authentication.login_with_cache(info.master_url, cert=cert) + session = api.Session(info.master_url, utp, cert, max_retries=util.get_max_retries_config()) if distributed is None: if len(info.container_addrs) > 1 or len(info.slot_ids) > 1: diff --git a/harness/determined/deploy/healthcheck.py b/harness/determined/deploy/healthcheck.py index 0b337aed872..c18a1a22a4b 100644 --- a/harness/determined/deploy/healthcheck.py +++ b/harness/determined/deploy/healthcheck.py @@ -4,7 +4,7 @@ import requests from determined.common import api -from determined.common.api import certs +from determined.common.api import authentication, certs from .errors import MasterTimeoutExpired @@ -38,7 +38,8 @@ def wait_for_master_url( try: while time.time() - start_time < timeout: try: - r = api.get(master_url, "info", authenticated=False, cert=cert) + sess = api.UnauthSession(master_url, cert=cert) + r = sess.get("info") if r.status_code == requests.codes.ok: return except api.errors.MasterNotFoundException: @@ -63,14 +64,15 @@ def wait_for_genai_url( POLL_INTERVAL = 2 polling = False start_time = time.time() - GENAI_PREFIX = "/genai" - check_path = GENAI_PREFIX + "/api/v1/workspaces" + + # Hopefully we have an active session to this master, or we can make a default one. 
+ utp = authentication.login_with_cache(master_url, cert=cert) + sess = api.Session(master_url, utp, cert) try: while time.time() - start_time < timeout: try: - auth = api.Authentication(master_address=master_url, cert=cert) - r = api.get(master_url, check_path, authenticated=True, cert=cert, auth=auth) + r = sess.get("genai/api/v1/workspaces") if r.status_code == requests.codes.ok: _ = r.json() return diff --git a/harness/determined/exec/gc_checkpoints.py b/harness/determined/exec/gc_checkpoints.py index 440af6d168a..ec4f3122425 100644 --- a/harness/determined/exec/gc_checkpoints.py +++ b/harness/determined/exec/gc_checkpoints.py @@ -12,8 +12,8 @@ import determined as det from determined import errors, tensorboard -from determined.common import api, constants, storage, util -from determined.common.api import bindings, certs +from determined.common import api, constants, storage +from determined.common.api import authentication, bindings, certs logger = logging.getLogger("determined") @@ -25,16 +25,10 @@ def patch_checkpoints(storage_ids_to_resources: Dict[str, Dict[str, int]]) -> No info._to_file() cert = certs.default_load(info.master_url) - sess = api.Session( - info.master_url, - util.get_det_username_from_env(), - None, - cert, - max_retries=urllib3.util.retry.Retry( - total=6, # With backoff retries for 64 seconds - backoff_factor=0.5, - ), - ) + utp = authentication.login_with_cache(info.master_url, cert=cert) + # With backoff retries for 64 seconds + max_retries = urllib3.util.retry.Retry(total=6, backoff_factor=0.5) + sess = api.Session(info.master_url, utp, cert, max_retries) checkpoints = [] for storage_id, resources in storage_ids_to_resources.items(): diff --git a/harness/determined/exec/harness.py b/harness/determined/exec/harness.py index 8ab9a970ac3..74e37992f0f 100644 --- a/harness/determined/exec/harness.py +++ b/harness/determined/exec/harness.py @@ -7,7 +7,7 @@ import determined as det from determined import core, horovod, load -from 
determined.common.api import analytics, certs +from determined.common.api import analytics logger = logging.getLogger("determined") @@ -28,9 +28,6 @@ def main(train_entrypoint: str) -> int: assert info is not None, "must be run on-cluster" assert info.task_type == "TRIAL", f'must be run with task_type="TRIAL", not "{info.task_type}"' - # TODO: refactor profiling to to not use the cli_cert. - certs.cli_cert = certs.default_load(info.master_url) - trial_class = load.trial_class_from_entrypoint(train_entrypoint) if info.container_rank == 0: diff --git a/harness/determined/exec/launch.py b/harness/determined/exec/launch.py index abfe00770b7..5543a015fee 100644 --- a/harness/determined/exec/launch.py +++ b/harness/determined/exec/launch.py @@ -7,8 +7,8 @@ import types import determined as det -import determined.common from determined.common import api, constants, storage +from determined.common.api import certs from determined.exec import prep_container logger = logging.getLogger("determined") @@ -21,9 +21,9 @@ def trigger_preemption(signum: int, frame: types.FrameType) -> None: # Chief container, requests preemption, others ignore logger.info("SIGTERM: Preemption imminent.") # Notify the master that we need to be preempted - api.post( - info.master_url, f"/api/v1/allocations/{info.allocation_id}/signals/pending_preemption" - ) + cert = certs.default_load(info.master_url) + sess = api.UnauthSession(info.master_url, cert) + sess.post(f"/api/v1/allocations/{info.allocation_id}/signals/pending_preemption") def launch(experiment_config: det.ExperimentConfig) -> int: @@ -75,7 +75,7 @@ def launch(experiment_config: det.ExperimentConfig) -> int: # Hack: read the full config. The experiment config is not a stable API! 
experiment_config = det.ExperimentConfig(info.trial._config) - determined.common.set_logger(experiment_config.debug_enabled()) + det.common.set_logger(experiment_config.debug_enabled()) logger.info( f"New trial runner in (container {resources_id}) on agent {info.agent_id}: " diff --git a/harness/determined/exec/prep_container.py b/harness/determined/exec/prep_container.py index cc6de67e50f..7d98ef86fdf 100644 --- a/harness/determined/exec/prep_container.py +++ b/harness/determined/exec/prep_container.py @@ -14,10 +14,9 @@ import urllib3 import determined as det -import determined.util from determined import constants, gpu -from determined.common import api, util -from determined.common.api import bindings, certs +from determined.common import api +from determined.common.api import authentication, bindings, certs logger = logging.getLogger("determined") @@ -311,16 +310,10 @@ def do_proxy(sess: api.Session, allocation_id: str) -> None: ) cert = certs.default_load(info.master_url) - sess = api.Session( - info.master_url, - util.get_det_username_from_env(), - None, - cert, - max_retries=urllib3.util.retry.Retry( - total=6, # With backoff retries for 64 seconds - backoff_factor=0.5, - ), - ) + utp = authentication.login_with_cache(info.master_url, cert=cert) + # With backoff retries for 64 seconds + max_retries = urllib3.util.retry.Retry(total=6, backoff_factor=0.5) + sess = api.Session(info.master_url, utp, cert, max_retries) # Notify the Determined Master that the container is running. # This should only be used on HPC clusters. 
diff --git a/harness/determined/experimental/__init__.py b/harness/determined/experimental/__init__.py index d7a805a95f6..463e6ed559b 100644 --- a/harness/determined/experimental/__init__.py +++ b/harness/determined/experimental/__init__.py @@ -12,7 +12,6 @@ ModelVersion, OrderBy, Project, - Session, Trial, TrialOrderBy, TrialReference, diff --git a/harness/determined/experimental/client.py b/harness/determined/experimental/client.py index e4e75b6f4b4..9ef82fc620f 100644 --- a/harness/determined/experimental/client.py +++ b/harness/determined/experimental/client.py @@ -50,7 +50,7 @@ import warnings from typing import Any, Callable, Dict, Iterable, List, Optional, Sequence, TypeVar, Union -from determined.common.api import Session # noqa: F401 +from determined.common import api as _api from determined.common.experimental._util import OrderBy from determined.common.experimental.checkpoint import ( # noqa: F401 Checkpoint, @@ -67,7 +67,12 @@ ExperimentState, ) from determined.common.experimental.metrics import TrainingMetrics, TrialMetrics, ValidationMetrics -from determined.common.experimental.model import Model, ModelOrderBy, ModelSortBy # noqa: F401 +from determined.common.experimental.model import ( # noqa: F401 + Model, + ModelOrderBy, + ModelSortBy, + ModelVersion, +) from determined.common.experimental.oauth2_scim_client import Oauth2ScimClient from determined.common.experimental.project import Project # noqa: F401 from determined.common.experimental.resource_pool import ResourcePool # noqa: F401 @@ -631,6 +636,6 @@ def stream_trials_validation_metrics(trial_ids: List[int]) -> Iterable[Validatio @_require_singleton -def _get_singleton_session() -> Session: +def _get_singleton_session() -> _api.Session: assert _determined is not None return _determined._session diff --git a/harness/determined/experimental/core_v2/_core_context_v2.py b/harness/determined/experimental/core_v2/_core_context_v2.py index 66e9cd8e412..75d11141b8b 100644 --- 
a/harness/determined/experimental/core_v2/_core_context_v2.py +++ b/harness/determined/experimental/core_v2/_core_context_v2.py @@ -7,7 +7,7 @@ import determined as det from determined import core, experimental, tensorboard from determined.common import api, constants, storage, util -from determined.common.api import certs +from determined.common.api import authentication, certs logger = logging.getLogger("determined.core") @@ -43,9 +43,8 @@ def _make_v2_context( # We are on the cluster. cert = certs.default_load(info.master_url) - session = api.Session( - info.master_url, None, None, cert, max_retries=util.get_max_retries_config() - ) + utp = authentication.login_with_cache(info.master_url, cert=cert) + session = api.Session(info.master_url, utp, cert, util.get_max_retries_config()) else: unmanaged = True diff --git a/harness/determined/experimental/core_v2/_unmanaged.py b/harness/determined/experimental/core_v2/_unmanaged.py index 7d7ab08f830..406ee3fec6b 100644 --- a/harness/determined/experimental/core_v2/_unmanaged.py +++ b/harness/determined/experimental/core_v2/_unmanaged.py @@ -198,8 +198,7 @@ def _build_unmanaged_trial_cluster_info( assert sess cluster_id = _get_cluster_id(sess) - assert sess._auth - token = sess._auth.get_session_token(True) + token = sess.token resp = _start_trial(client, trial_id, resume, distributed) diff --git a/harness/determined/launch/deepspeed.py b/harness/determined/launch/deepspeed.py index 7b501c6f0a8..a2309c7f51e 100644 --- a/harness/determined/launch/deepspeed.py +++ b/harness/determined/launch/deepspeed.py @@ -22,7 +22,7 @@ import determined.common from determined import constants, util from determined.common import api -from determined.common.api import certs +from determined.common.api import authentication, certs hostfile_path = None deepspeed_version = version.parse(deepspeed.__version__) @@ -245,10 +245,6 @@ def main(script: List[str]) -> int: resources_id = os.environ.get("DET_RESOURCES_ID") assert resources_id is not 
None, "Unable to run with DET_RESOURCES_ID unset" - # TODO: refactor websocket and profiling to to not use the cli_cert. - cert = certs.default_load(info.master_url) - certs.cli_cert = cert - # The launch layer should provide the chief_ip to the training code, so that the training code # can function with a different launch layer in a different environment. Inside Determined, the # easiest way to get the chief_ip is with container_addrs. @@ -290,11 +286,10 @@ def main(script: List[str]) -> int: # Mark sshd containers as daemon containers that the master should kill when all non-daemon # containers (deepspeed launcher, in this case) have exited. - api.post( - info.master_url, - path=f"/api/v1/allocations/{info.allocation_id}/resources/{resources_id}/daemon", - cert=cert, - ) + cert = certs.default_load(info.master_url) + utp = authentication.login_with_cache(info.master_url, cert=cert) + sess = api.Session(info.master_url, utp, cert) + sess.post(f"/api/v1/allocations/{info.allocation_id}/resources/{resources_id}/daemon") # Wrap it in a pid_server to ensure that we can't hang if a worker fails. # This is useful for deepspeed which does not have good error handling for remote processes diff --git a/harness/determined/launch/horovod.py b/harness/determined/launch/horovod.py index de77a0f004f..f6476b02fdd 100644 --- a/harness/determined/launch/horovod.py +++ b/harness/determined/launch/horovod.py @@ -15,7 +15,7 @@ import determined as det from determined import horovod, util from determined.common import api -from determined.common.api import certs +from determined.common.api import authentication, certs from determined.constants import DTRAIN_SSH_PORT logger = logging.getLogger("determined.launch.horovod") @@ -115,10 +115,6 @@ def main(hvd_args: List[str], script: List[str], autohorovod: bool) -> int: if debug: logging.getLogger().setLevel(logging.DEBUG) - # TODO: refactor websocket and profiling to to not use the cli_cert. 
- cert = certs.default_load(info.master_url) - certs.cli_cert = cert - # The launch layer should provide the chief_ip to the training code, so that the training code # can function with a different launch layer in a different environment. Inside Determined, the # easiest way to get the chief_ip is with container_addrs. @@ -132,11 +128,10 @@ def main(hvd_args: List[str], script: List[str], autohorovod: bool) -> int: # Mark sshd containers as daemon resources that the master should kill when all non-daemon # containers (horovodrun, in this case) have exited. - api.post( - info.master_url, - path=f"/api/v1/allocations/{info.allocation_id}/resources/{resources_id}/daemon", - cert=cert, - ) + cert = certs.default_load(info.master_url) + utp = authentication.login_with_cache(info.master_url, cert=cert) + sess = api.Session(info.master_url, utp, cert) + sess.post(f"/api/v1/allocations/{info.allocation_id}/resources/{resources_id}/daemon") pid_server_cmd, run_sshd_command = create_sshd_worker_cmd( info.allocation_id, len(info.slot_ids), debug=debug diff --git a/harness/determined/profiler.py b/harness/determined/profiler.py index 7eb670cd6db..1f3cbbcdfd3 100644 --- a/harness/determined/profiler.py +++ b/harness/determined/profiler.py @@ -9,10 +9,9 @@ from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple, Type, Union, cast import psutil +from requests import exceptions -import determined as det from determined.common import api, check -from determined.common.api import TrialProfilerMetricsBatch MAX_COLLECTION_SECONDS = 300 @@ -223,16 +222,120 @@ def pop_until_deadline(q: queue.Queue, deadline: float) -> Iterator[Any]: break -def profiling_metrics_exist(master_url: str, trial_id: str) -> bool: +class TrialProfilerMetricsBatch: + """ + TrialProfilerMetricsBatch is the representation of a batch of trial + profiler metrics as accepted by POST /api/v1/trials/:trial_id/profiler/metrics + """ + + def __init__( + self, + values: List[float], + batches: List[int], 
+ timestamps: List[str], + labels: Dict[str, Any], + ): + self.values = values + self.batches = batches + self.timestamps = timestamps + self.labels = labels + + +def post_trial_profiler_metrics_batches( + sess: api.Session, + batches: List[TrialProfilerMetricsBatch], +) -> None: + """ + Post the given metrics to the master to be persisted. Labels + must contain only a subset of the keys: trial_id, name, + gpu_uuid, agent_id and metric_type, where metric_type is one + of PROFILER_METRIC_TYPE_SYSTEM or PROFILER_METRIC_TYPE_TIMING. + """ + backoff_interval = 1 + max_tries = 2 + tries = 0 + + while tries < max_tries: + try: + sess.post( + "/api/v1/trials/profiler/metrics", + json={"batches": [b.__dict__ for b in batches]}, + ) + return + except exceptions.RequestException as e: + if e.response is not None and e.response.status_code < 500: + raise e + + tries += 1 + if tries == max_tries: + raise e + time.sleep(backoff_interval) + return + + +class TrialProfilerSeriesLabels: + def __init__(self, trial_id: int, name: str, agent_id: str, gpu_uuid: str, metric_type: str): + self.trial_id = str(trial_id) + self.name = name + self.agent_id = agent_id + self.gpu_uuid = gpu_uuid if gpu_uuid != "" else None # type: Optional[str] + self.metric_type = metric_type + + +def get_trial_profiler_available_series( + sess: api.Session, + trial_id: str, +) -> List[TrialProfilerSeriesLabels]: + """ + Get available profiler series for a trial. 
This uses the non-streaming version of the API + """ + follow = False + backoff_interval = 1 + max_tries = 2 + tries = 0 + + response = None + while tries < max_tries: + try: + response = sess.get( + f"/api/v1/trials/{trial_id}/profiler/available_series", + params={"follow": follow}, + ) + break + except exceptions.RequestException as e: + if e.response is not None and e.response.status_code < 500: + raise e + + tries += 1 + if tries == max_tries: + raise e + time.sleep(backoff_interval) + + assert response + j = response.json() + labels = [ + TrialProfilerSeriesLabels( + trial_id=ld["trialId"], + name=ld["name"], + agent_id=ld["agentId"], + gpu_uuid=ld["gpuUuid"], + metric_type=ld["metricType"], + ) + for ld in j["result"]["labels"] + ] + return labels + + +def profiling_metrics_exist(sess: api.Session, trial_id: str) -> bool: """ Return True if there are already profiling metrics for the trial. """ - series_labels = api.get_trial_profiler_available_series(master_url, trial_id) + series_labels = get_trial_profiler_available_series(sess, trial_id) return len(series_labels) > 0 -SendBatchFnType = Callable[[str, List[TrialProfilerMetricsBatch]], None] -CheckDataExistsFnType = Callable[[str, str], bool] +SendBatchFnType = Callable[[api.Session, List[TrialProfilerMetricsBatch]], None] +CheckDataExistsFnType = Callable[[api.Session, str], bool] class ProfilerAgent: @@ -268,22 +371,22 @@ class ProfilerAgent: def __init__( self, + session: api.Session, trial_id: str, agent_id: str, - master_url: str, profiling_is_enabled: bool, global_rank: int, local_rank: int, begin_on_batch: int, sync_timings: bool, end_after_batch: Optional[int] = None, - send_batch_fn: SendBatchFnType = api.post_trial_profiler_metrics_batches, + send_batch_fn: SendBatchFnType = post_trial_profiler_metrics_batches, check_data_exists_fn: CheckDataExistsFnType = profiling_metrics_exist, ): self.current_batch_idx = 0 + self.session = session self.trial_id = trial_id self.agent_id = agent_id - 
self.master_url = master_url self.profiling_is_enabled_in_experiment_config = profiling_is_enabled self.global_rank = global_rank self.local_rank = local_rank @@ -308,7 +411,7 @@ def __init__( self.pynvml_wrapper = PynvmlWrapper() self.disabled_due_to_preexisting_metrics = self.check_data_already_exists_fn( - self.master_url, self.trial_id + self.session, self.trial_id ) if self.disabled_due_to_preexisting_metrics and self.global_rank == 0: logger.warning( @@ -342,27 +445,12 @@ def __init__( ) self.sender_thread = ProfilerSenderThread( - self.send_queue, self.master_url, num_producers, self.send_batch_fn + self.send_queue, self.session, num_producers, self.send_batch_fn ) def _set_sync_device(self, sync_device: Callable[[], None]) -> None: self.sync_device = sync_device - @staticmethod - def from_env(env: det.EnvContext, global_rank: int, local_rank: int) -> "ProfilerAgent": - begin_on_batch, end_after_batch = env.experiment_config.profiling_interval() - return ProfilerAgent( - trial_id=env.det_trial_id, - agent_id=env.det_agent_id, - master_url=env.master_url, - profiling_is_enabled=env.experiment_config.profiling_enabled(), - global_rank=global_rank, - local_rank=local_rank, - begin_on_batch=begin_on_batch, - end_after_batch=end_after_batch, - sync_timings=env.experiment_config.profiling_sync_timings(), - ) - # Launch the children threads. 
This does not mean 'start collecting metrics' def start(self) -> None: if not self.is_enabled: @@ -897,12 +985,12 @@ class ProfilerSenderThread(threading.Thread): def __init__( self, inbound_queue: queue.Queue, - master_url: str, + session: api.Session, num_producers: int, send_batch_fn: SendBatchFnType, ) -> None: - self.master_url = master_url self.inbound_queue = inbound_queue + self.session = session self.num_producers = num_producers self.producers_shutdown = 0 self.send_batch_fn = send_batch_fn @@ -918,7 +1006,7 @@ def run(self) -> None: else: continue self.send_batch_fn( - self.master_url, + self.session, message, ) diff --git a/harness/determined/pytorch/_trainer.py b/harness/determined/pytorch/_trainer.py index 74bc61c1612..b643663a976 100644 --- a/harness/determined/pytorch/_trainer.py +++ b/harness/determined/pytorch/_trainer.py @@ -69,9 +69,10 @@ def configure_profiler( assert self._info, "Determined profiler must be run on cluster." self._det_profiler = profiler.ProfilerAgent( + # XXX: bad hack + session=self._core.train._session, trial_id=str(self._info.trial.trial_id), agent_id=self._info.agent_id, - master_url=self._info.master_url, profiling_is_enabled=enabled, global_rank=self._core.distributed.get_rank(), local_rank=self._core.distributed.get_local_rank(), diff --git a/harness/determined/searcher/_search_runner.py b/harness/determined/searcher/_search_runner.py index d7919c643e1..e03361d74bd 100644 --- a/harness/determined/searcher/_search_runner.py +++ b/harness/determined/searcher/_search_runner.py @@ -8,6 +8,7 @@ from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple, Union from determined import searcher +from determined.common import api from determined.common.api import bindings, errors from determined.experimental import client @@ -118,7 +119,7 @@ def _get_operations(self, event: bindings.v1SearcherEvent) -> List[searcher.Oper def run_experiment( self, experiment_id: int, - session: client.Session, + session: api.Session, 
prior_operations: Optional[List[searcher.Operation]], sleep_time: float = 1.0, ) -> None: @@ -186,7 +187,7 @@ def run_experiment( def post_operations( self, - session: client.Session, + session: api.Session, experiment_id: int, event: bindings.v1SearcherEvent, operations: List[searcher.Operation], @@ -221,7 +222,7 @@ def post_operations( def get_events( self, - session: client.Session, + session: api.Session, experiment_id: int, ) -> Optional[Sequence[bindings.v1SearcherEvent]]: # API is implemented with long polling. diff --git a/harness/tests/checkpoints/test_checkpoint.py b/harness/tests/checkpoints/test_checkpoint.py index 8f0304d495f..5f7a0ef2475 100644 --- a/harness/tests/checkpoints/test_checkpoint.py +++ b/harness/tests/checkpoints/test_checkpoint.py @@ -6,7 +6,8 @@ from responses import matchers from determined.common import api -from determined.common.experimental import Checkpoint +from determined.common.api import authentication +from determined.experimental import client def get_long_str(approx_len: int) -> str: @@ -78,8 +79,9 @@ def test_checkpoint_download_via_master(tmp_path: Path) -> None: ) checkpoint_path = tmp_path / uuid_tgz - Checkpoint._download_via_master( - api.Session(master="https://dummy-master.none", user=None, auth=None, cert=None), + utp = authentication.UsernameTokenPair("username", "token") + client.Checkpoint._download_via_master( + api.Session("https://dummy-master.none", utp, cert=None), uuid_tgz, checkpoint_path, ) diff --git a/harness/tests/cli/test_auth.py b/harness/tests/cli/test_auth.py index fd68411d7c3..af00ce8170f 100644 --- a/harness/tests/cli/test_auth.py +++ b/harness/tests/cli/test_auth.py @@ -236,11 +236,11 @@ class Login(ScenarioSet): ], ) @mock.patch("determined.common.api.authentication._is_token_valid") -@mock.patch("determined.common.api.authentication.do_login") +@mock.patch("determined.common.api.authentication.login") @mock.patch("getpass.getpass") def test_login_scenarios( mock_getpass: mock.MagicMock, - 
mock_do_login: mock.MagicMock, + mock_login: mock.MagicMock, mock_is_token_valid: mock.MagicMock, scenario_set: Login, ) -> None: @@ -250,12 +250,12 @@ def getpass(*_: Any) -> str: def _is_token_valid(master_url: str, token: str, cert: Any) -> bool: return token in ["cache", "env_token"] - def do_login(*_: Any) -> str: - return "new" + def login(master_address: str, username: str, *_: Any) -> authentication.UsernameTokenPair: + return authentication.UsernameTokenPair(username, "new") mock_getpass.side_effect = getpass mock_is_token_valid.side_effect = _is_token_valid - mock_do_login.side_effect = do_login + mock_login.side_effect = login for scenario in scenario_set.scenarios(): with contextlib.ExitStack() as es: @@ -280,12 +280,12 @@ def do_login(*_: Any) -> str: # function will never show evidence that it checked the environment. However, for the # remainder of the authentication flow, it never matters again that the user came from # the TokenStore.get_active_user(), so if we modeled that in the table it wouldn't - # actually increase code coverage for the Authentication object, which is what we are - # focusing on. + # actually increase code coverage for the logout_with_cache function, which is what we + # are focusing on. 
mts.get_active_user(retval=scenario.cache and "user") try: - auth = authentication.Authentication( + utp = authentication.login_with_cache( "master_url", scenario.req_user, scenario.req_pass, None ) @@ -296,12 +296,12 @@ def do_login(*_: Any) -> str: [mock.call("master_url", exp.token, None)] ) elif isinstance(exp, DoLogin): - mock_do_login.assert_has_calls( + mock_login.assert_has_calls( [mock.call("master_url", exp.username, exp.password, None)] ) elif isinstance(exp, Use): - assert auth.session.username == exp.user - assert auth.session.token == exp.token + assert utp.username == exp.user + assert utp.token == exp.token else: raise ValueError(f"unexpected result: {exp}") @@ -309,7 +309,7 @@ def do_login(*_: Any) -> str: if not any(isinstance(exp, Check) for exp in scenario_set.expected): mock_is_token_valid.assert_not_called() if not any(isinstance(exp, DoLogin) for exp in scenario_set.expected): - mock_do_login.assert_not_called() + mock_login.assert_not_called() except Exception as e: raise RuntimeError( @@ -324,7 +324,6 @@ def __init__(self, user: str) -> None: self.user = user def expect(self, rsps: responses.RequestsMock, mts: util.MockTokenStore, scenario: Any) -> None: - mts.get_active_user(retval=None) mts.get_token(self.user, retval="cache_token" if scenario.user_in_cache else None) @@ -341,6 +340,8 @@ def expect(self, rsps: responses.RequestsMock, mts: util.MockTokenStore, scenari status=200, match=[matchers.header_matcher({"Authorization": "Bearer cache_token"})], ) + mts.get_active_user(retval=self.user) + mts.clear_active() class GetActiveUser: @@ -372,7 +373,7 @@ class Logout(ScenarioSet): # Cache active user is logged out. Logout("n", "n", "y", GetActiveUser(), GetToken("cache_user"), DoLogout("cache_user")), # When no user is found, it is a noop. 
- Logout("n", "n", "n", GetActiveUser()), + Logout("n", "n", "n", GetActiveUser(), GetToken("determined")), ], ) def test_logout(scenario_set: Logout) -> None: @@ -403,3 +404,34 @@ def test_logout(scenario_set: Logout) -> None: else: cmd = ["user", "logout"] cli.main(cmd) + + +def test_logout_all() -> None: + with util.MockTokenStore(strict=True) as mts: + with responses.RequestsMock( + registry=registries.OrderedRegistry, assert_all_requests_are_fired=True + ) as rsps: + # Every active user should get logged out. + mts.get_all_users(retval=["u1", "u2"]) + + util.expect_get_info(rsps) + + mts.get_token("u1", retval="t1") + rsps.post( + f"{MOCK_MASTER_URL}/api/v1/auth/logout", + status=200, + match=[matchers.header_matcher({"Authorization": "Bearer t1"})], + ) + mts.drop_user("u1") + + # Unauthenticated errors are ignored during logout. + mts.get_token("u2", retval="t2") + rsps.post( + f"{MOCK_MASTER_URL}/api/v1/auth/logout", + status=401, + match=[matchers.header_matcher({"Authorization": "Bearer t2"})], + ) + mts.drop_user("u2") + mts.clear_active() + + cli.main(["user", "logout", "--all"]) diff --git a/harness/tests/cli/test_experiment.py b/harness/tests/cli/test_experiment.py index 80067e269b6..2c9fc5b3b9a 100644 --- a/harness/tests/cli/test_experiment.py +++ b/harness/tests/cli/test_experiment.py @@ -6,9 +6,7 @@ import determined import determined.cli -from determined.common import constants -from determined.common.api import bindings, certs -from determined.common.api.authentication import Authentication +from determined.common.api import authentication, bindings from tests.fixtures import api_responses @@ -21,34 +19,12 @@ class CliArgs: polling_interval: float = 0.01 # Short polling interval so we can run tests quickly -def mock_det_auth(user: str = "test", master_url: str = "http://localhost:8888") -> Authentication: - with mock.Mocker() as mocker: - mocker.get(master_url + "/api/v1/me", status_code=200, json={"username": user}) - fake_user = {"username": 
user, "admin": True, "active": True} - mocker.post( - master_url + "/api/v1/auth/login", - status_code=200, - json={"token": "fake-token", "user": fake_user}, - ) - mocker.get("/info", status_code=200, json={"version": "1.0"}) - mocker.get( - "/users/me", status_code=200, json={"username": constants.DEFAULT_DETERMINED_USER} - ) - auth = Authentication( - master_address=master_url, - requested_user=user, - password="password1", - cert=certs.Cert(noverify=True), - ) - return auth - - -@unittest.mock.patch("determined.common.api.authentication.Authentication") +@unittest.mock.patch("determined.common.api.authentication.login_with_cache") def test_wait_returns_error_code_when_experiment_errors( - auth_mock: unittest.mock.MagicMock, + login_with_cache_mock: unittest.mock.MagicMock, requests_mock: mock.Mocker, ) -> None: - auth_mock.return_value = mock_det_auth() + login_with_cache_mock.return_value = authentication.UsernameTokenPair("username", "token") exp = api_responses.sample_get_experiment(id=1, state=bindings.experimentv1State.COMPLETED) args = CliArgs(master="http://localhost:8888", experiment_id=1) exp.experiment.state = bindings.experimentv1State.ERROR @@ -58,5 +34,5 @@ def test_wait_returns_error_code_when_experiment_errors( json=exp.to_json(), ) with pytest.raises(SystemExit) as e: - determined.cli.experiment.wait(args) + determined.cli.experiment.wait(args) # type: ignore assert e.value.code == 1 diff --git a/harness/tests/common/api/test_authentication.py b/harness/tests/common/api/test_authentication.py index b44615a0d51..f35f33d722f 100644 --- a/harness/tests/common/api/test_authentication.py +++ b/harness/tests/common/api/test_authentication.py @@ -1,43 +1,14 @@ -import contextlib import json import pathlib import shutil -from typing import Optional - -import pytest -import responses -from responses import registries from determined.common.api import authentication from tests import confdir -from tests.cli import util MOCK_MASTER_URL = 
"http://localhost:8080" AUTH_V0_PATH = pathlib.Path(__file__).parent / "auth_v0.json" -@pytest.mark.parametrize("active_user", ["alice", "bob", None]) -def test_logout_clears_active_user(active_user: Optional[str]) -> None: - with contextlib.ExitStack() as es: - es.enter_context(util.setenv_optional("DET_MASTER", MOCK_MASTER_URL)) - rsps = es.enter_context( - responses.RequestsMock( - registry=registries.OrderedRegistry, - assert_all_requests_are_fired=True, - ) - ) - mts = es.enter_context(util.MockTokenStore(strict=True)) - - mts.get_active_user(retval=active_user) - if active_user == "alice": - mts.clear_active() - mts.get_token("alice", retval="token") - mts.drop_user("alice") - rsps.post(f"{MOCK_MASTER_URL}/api/v1/auth/logout", status=200) - - authentication.logout(MOCK_MASTER_URL, "alice", None) - - def test_auth_json_v0_upgrade() -> None: with confdir.use_test_config_dir() as config_dir: auth_json_path = config_dir / "auth.json" diff --git a/harness/tests/common/test_tls.py b/harness/tests/common/test_tls.py index d6de974eefb..75eb059a94b 100644 --- a/harness/tests/common/test_tls.py +++ b/harness/tests/common/test_tls.py @@ -8,7 +8,8 @@ import pytest import requests -from determined.common.api import certs, request +from determined.common import api +from determined.common.api import certs TRUSTED_DOMAIN = "https://google.com" UNTRUSTED_DIR = os.path.join(os.path.dirname(__file__), "untrusted-root") @@ -58,9 +59,9 @@ def test_custom_tls_certs() -> None: cert = certs.Cert(**kwargs) # Trusted domains should always work. 
- request.get(TRUSTED_DOMAIN, "", authenticated=False, cert=cert) + api.UnauthSession(TRUSTED_DOMAIN, cert=cert).get("") with contextlib.ExitStack() as ctx: if raises: ctx.enter_context(pytest.raises(requests.exceptions.SSLError)) - request.get(untrusted_url, "", authenticated=False, cert=cert) + api.UnauthSession(untrusted_url, cert=cert).get("") diff --git a/harness/tests/custom_search_mocks.py b/harness/tests/custom_search_mocks.py index 79de600495e..79ba7231ed0 100644 --- a/harness/tests/custom_search_mocks.py +++ b/harness/tests/custom_search_mocks.py @@ -5,8 +5,8 @@ from unittest.mock import Mock from determined import searcher +from determined.common import api from determined.common.api import bindings -from determined.experimental import client class MockMaster(metaclass=abc.ABCMeta): @@ -116,7 +116,7 @@ def __init__( def post_operations( self, - session: client.Session, + session: api.Session, experiment_id: int, event: bindings.v1SearcherEvent, operations: List[searcher.Operation], @@ -126,7 +126,7 @@ def post_operations( def get_events( self, - session: client.Session, + session: api.Session, experiment_id: int, ) -> Optional[Sequence[bindings.v1SearcherEvent]]: logging.info("MockMasterSearchRunner.get_events") @@ -151,7 +151,7 @@ def run( super(MockMasterSearchRunner, self).save_state(exp_id, []) experiment_id = exp_id operations: Optional[List[searcher.Operation]] = None - session: client.Session = Mock() + session: api.Session = Mock() super(MockMasterSearchRunner, self).run_experiment( experiment_id, session, operations, sleep_time=0.0 ) diff --git a/harness/tests/determined/common/experimental/test_checkpoint.py b/harness/tests/determined/common/experimental/test_checkpoint.py index be3b703086f..8cf9b017cbb 100644 --- a/harness/tests/determined/common/experimental/test_checkpoint.py +++ b/harness/tests/determined/common/experimental/test_checkpoint.py @@ -7,6 +7,7 @@ import responses from determined.common import api, storage +from 
determined.common.api import authentication from determined.common.experimental import checkpoint from tests.fixtures import api_responses @@ -17,7 +18,8 @@ @pytest.fixture def standard_session() -> api.Session: - return api.Session(master=_MASTER, user=None, auth=None, cert=None) + utp = authentication.UsernameTokenPair("username", "token") + return api.Session(_MASTER, utp, cert=None) @pytest.fixture diff --git a/harness/tests/determined/common/experimental/test_determined.py b/harness/tests/determined/common/experimental/test_determined.py index 91945fa8b88..7b44ac5a32a 100644 --- a/harness/tests/determined/common/experimental/test_determined.py +++ b/harness/tests/determined/common/experimental/test_determined.py @@ -1,45 +1,25 @@ import math -from typing import Callable, List +from typing import List from unittest import mock import pytest import responses from determined.common.api import authentication, errors -from determined.common.experimental import determined, experiment +from determined.experimental import client as _client from tests.fixtures import api_responses _MASTER = "http://localhost:8080" -Determined = determined.Determined - -@pytest.fixture -@mock.patch("determined.common.api.authentication.Authentication") -def mock_default_auth(auth_mock: mock.MagicMock) -> None: - responses.get(f"{_MASTER}/api/v1/me", status=200, json={"username": api_responses.USERNAME}) - responses.post( - f"{_MASTER}/api/v1/auth/login", - status=200, - json=api_responses.sample_login(username=api_responses.USERNAME).to_json(), - ) - auth_mock.return_value = authentication.Authentication( - master_address=_MASTER, - requested_user=api_responses.USERNAME, - password=api_responses.PASSWORD, - ) - - -@pytest.fixture -def make_client(mock_default_auth: Callable) -> Callable[[], Determined]: - def _make_client() -> Determined: - return Determined(master=_MASTER) - - return _make_client +def make_client() -> _client.Determined: + with 
mock.patch("determined.common.api.authentication.login_with_cache") as mock_login: + mock_login.return_value = authentication.UsernameTokenPair("username", "token") + return _client.Determined(_MASTER) @responses.activate -def test_default_retry_retries_transient_failures(make_client: Callable[[], Determined]) -> None: +def test_default_retry_retries_transient_failures() -> None: client = make_client() model_resp = api_responses.sample_get_model() @@ -55,7 +35,7 @@ def test_default_retry_retries_transient_failures(make_client: Callable[[], Dete @responses.activate -def test_default_retry_retries_until_max(make_client: Callable[[], Determined]) -> None: +def test_default_retry_retries_until_max() -> None: client = make_client() model_resp = api_responses.sample_get_model() get_model_url = f"{_MASTER}/api/v1/models/{model_resp.model.name}" @@ -69,7 +49,7 @@ def test_default_retry_retries_until_max(make_client: Callable[[], Determined]) @responses.activate -def test_default_retry_fails_after_max_retries(make_client: Callable[[], Determined]) -> None: +def test_default_retry_fails_after_max_retries() -> None: client = make_client() model_resp = api_responses.sample_get_model() get_model_url = f"{_MASTER}/api/v1/models/{model_resp.model.name}" @@ -80,7 +60,7 @@ def test_default_retry_fails_after_max_retries(make_client: Callable[[], Determi @responses.activate -def test_default_retry_doesnt_retry_post(make_client: Callable[[], Determined]) -> None: +def test_default_retry_doesnt_retry_post() -> None: client = make_client() model_resp = api_responses.sample_get_model() create_model_url = f"{_MASTER}/api/v1/models" @@ -95,10 +75,7 @@ def test_default_retry_doesnt_retry_post(make_client: Callable[[], Determined]) [502, 503, 504], ) @responses.activate -def test_default_retry_retries_status_forcelist( - make_client: Callable[[], Determined], - status: List[int], -) -> None: +def test_default_retry_retries_status_forcelist(status: List[int]) -> None: client = make_client() 
model_resp = api_responses.sample_get_model() get_model_url = f"{_MASTER}/api/v1/models/{model_resp.model.name}" @@ -114,10 +91,7 @@ def test_default_retry_retries_status_forcelist( [400, 404, 500], ) @responses.activate -def test_default_retry_doesnt_retry_allowed_status( - make_client: Callable[[], Determined], - status: List[int], -) -> None: +def test_default_retry_doesnt_retry_allowed_status(status: List[int]) -> None: client = make_client() model_resp = api_responses.sample_get_model() get_model_url = f"{_MASTER}/api/v1/models/{model_resp.model.name}" @@ -130,9 +104,7 @@ def test_default_retry_doesnt_retry_allowed_status( @pytest.mark.parametrize("attribute", ["summary_metrics", "state"]) @responses.activate -def test_get_trial_populates_attribute( - make_client: Callable[[], Determined], attribute: str -) -> None: +def test_get_trial_populates_attribute(attribute: str) -> None: client = make_client() trial_id = 1 tr_resp = api_responses.sample_get_trial(id=trial_id) @@ -146,20 +118,17 @@ def test_get_trial_populates_attribute( @responses.activate @mock.patch("determined.common.api.bindings.get_GetExperiments") -def test_list_experiments_calls_bindings_with_params( - mock_bindings: mock.MagicMock, - make_client: Callable[[], Determined], -) -> None: +def test_list_experiments_calls_bindings_with_params(mock_bindings: mock.MagicMock) -> None: client = make_client() exps_resp = api_responses.sample_get_experiments() params = { - "sort_by": experiment.ExperimentSortBy.ID, - "order_by": determined.OrderBy.ASCENDING, + "sort_by": _client.ExperimentSortBy.ID, + "order_by": _client.OrderBy.ASCENDING, "experiment_ids": list(range(10)), "labels": ["label1", "label2"], "users": ["user1", "user2"], - "states": [experiment.ExperimentState.COMPLETED, experiment.ExperimentState.ACTIVE], + "states": [_client.ExperimentState.COMPLETED, _client.ExperimentState.ACTIVE], "name": "exp name", "project_id": 1, } @@ -187,10 +156,7 @@ def 
test_list_experiments_calls_bindings_with_params( @responses.activate @mock.patch("determined.common.api.bindings.get_GetExperiments") -def test_list_experiments_returns_all_response_pages( - mock_bindings: mock.MagicMock, - make_client: Callable[[], Determined], -) -> None: +def test_list_experiments_returns_all_response_pages(mock_bindings: mock.MagicMock) -> None: client = make_client() exps_resp = api_responses.sample_get_experiments() total_exps = len(exps_resp.experiments) diff --git a/harness/tests/determined/common/experimental/test_experiment.py b/harness/tests/determined/common/experimental/test_experiment.py index db93b790b6d..2fda7575624 100644 --- a/harness/tests/determined/common/experimental/test_experiment.py +++ b/harness/tests/determined/common/experimental/test_experiment.py @@ -9,7 +9,7 @@ import responses from determined.common import api -from determined.common.api import bindings +from determined.common.api import authentication, bindings from determined.common.experimental import checkpoint, determined, experiment from tests.fixtures import api_responses @@ -18,7 +18,8 @@ @pytest.fixture def standard_session() -> api.Session: - return api.Session(master=_MASTER, user=None, auth=None, cert=None) + utp = authentication.UsernameTokenPair("username", "token") + return api.Session(_MASTER, utp, cert=None) @pytest.fixture diff --git a/harness/tests/determined/common/experimental/test_model.py b/harness/tests/determined/common/experimental/test_model.py index 3e562aa96e1..f63ad503cac 100644 --- a/harness/tests/determined/common/experimental/test_model.py +++ b/harness/tests/determined/common/experimental/test_model.py @@ -4,6 +4,7 @@ import responses from determined.common import api +from determined.common.api import authentication from determined.common.experimental import model from tests.fixtures import api_responses @@ -12,7 +13,8 @@ @pytest.fixture def standard_session() -> api.Session: - return api.Session(master=_MASTER, user=None, 
auth=None, cert=None) + utp = authentication.UsernameTokenPair("username", "token") + return api.Session(_MASTER, utp, cert=None) @pytest.fixture diff --git a/harness/tests/determined/common/experimental/test_project.py b/harness/tests/determined/common/experimental/test_project.py index 971840079de..2e9f4093e5b 100644 --- a/harness/tests/determined/common/experimental/test_project.py +++ b/harness/tests/determined/common/experimental/test_project.py @@ -4,7 +4,7 @@ import responses from determined.common import api -from determined.common.api import bindings +from determined.common.api import authentication, bindings from determined.common.experimental import project from tests.fixtures import api_responses @@ -13,7 +13,8 @@ @pytest.fixture def standard_session() -> api.Session: - return api.Session(master=_MASTER, user=None, auth=None, cert=None) + utp = authentication.UsernameTokenPair("username", "token") + return api.Session(_MASTER, utp, cert=None) @pytest.fixture diff --git a/harness/tests/determined/common/experimental/test_resource_pool.py b/harness/tests/determined/common/experimental/test_resource_pool.py index 656bfbbc67f..7812a311ae1 100644 --- a/harness/tests/determined/common/experimental/test_resource_pool.py +++ b/harness/tests/determined/common/experimental/test_resource_pool.py @@ -1,7 +1,7 @@ import pytest from determined.common import api -from determined.common.api import bindings +from determined.common.api import authentication, bindings from tests.fixtures import api_responses _MASTER = "http://localhost:8080" @@ -9,7 +9,8 @@ @pytest.fixture def standard_session() -> api.Session: - return api.Session(master=_MASTER, user=None, auth=None, cert=None) + utp = authentication.UsernameTokenPair("username", "token") + return api.Session(_MASTER, utp, cert=None) @pytest.fixture diff --git a/harness/tests/determined/common/experimental/test_trial.py b/harness/tests/determined/common/experimental/test_trial.py index b9cf4afbe39..2dd2a6ff626 100644 
--- a/harness/tests/determined/common/experimental/test_trial.py +++ b/harness/tests/determined/common/experimental/test_trial.py @@ -6,6 +6,7 @@ import responses from determined.common import api +from determined.common.api import authentication from determined.common.experimental import checkpoint, determined, trial from tests.fixtures import api_responses @@ -14,7 +15,8 @@ @pytest.fixture def standard_session() -> api.Session: - return api.Session(master=_MASTER, user=None, auth=None, cert=None) + utp = authentication.UsernameTokenPair("username", "token") + return api.Session(_MASTER, utp, cert=None) @pytest.fixture diff --git a/harness/tests/determined/common/experimental/test_workspace.py b/harness/tests/determined/common/experimental/test_workspace.py index 4f29398f8b0..f8a703f04e3 100644 --- a/harness/tests/determined/common/experimental/test_workspace.py +++ b/harness/tests/determined/common/experimental/test_workspace.py @@ -2,7 +2,7 @@ import responses from determined.common import api -from determined.common.api import bindings, errors +from determined.common.api import authentication, bindings, errors from determined.common.experimental import workspace from tests.fixtures import api_responses @@ -11,7 +11,8 @@ @pytest.fixture def standard_session() -> api.Session: - return api.Session(master=_MASTER, user=None, auth=None, cert=None) + utp = authentication.UsernameTokenPair("username", "token") + return api.Session(_MASTER, utp, cert=None) @pytest.fixture diff --git a/harness/tests/determined/common/storage/__init__.py b/harness/tests/determined/common/storage/__init__.py deleted file mode 100644 index e69de29bb2d..00000000000 diff --git a/harness/tests/determined/common/storage/test_s3.py b/harness/tests/determined/common/storage/test_s3.py deleted file mode 100644 index cea000296d1..00000000000 --- a/harness/tests/determined/common/storage/test_s3.py +++ /dev/null @@ -1,51 +0,0 @@ -import os - -import boto3 -import moto -import pytest - -from 
determined.common import api, storage -from determined.common.experimental import checkpoint -from tests.fixtures import api_responses - -_MASTER = "http://localhost:8080" - - -@pytest.fixture -def standard_session() -> api.Session: - return api.Session(master=_MASTER, user=None, auth=None, cert=None) - - -@pytest.fixture -def sample_checkpoint(standard_session: api.Session) -> checkpoint.Checkpoint: - bindings_checkpoint = api_responses.sample_get_checkpoint().checkpoint - return checkpoint.Checkpoint._from_bindings(bindings_checkpoint, standard_session) - - -@moto.mock_s3 -def test_download_simple_checkpoint( - sample_checkpoint: checkpoint.Checkpoint, tmp_path: os.PathLike -) -> None: - metadata_payload = "{'determined_version': '0.22.2-dev0'}" - if sample_checkpoint.training is None or sample_checkpoint.training.experiment_config is None: - raise ValueError( - "Test depends on an existing experiment_config within the tested checkpoint." - ) - storage_conf = sample_checkpoint.training.experiment_config["checkpoint_storage"] - storage_conf.update({"type": "s3", "secret_key": None, "endpoint_url": None, "prefix": None}) - - s3_client = boto3.client("s3") - s3_client.create_bucket(Bucket=storage_conf["bucket"]) - s3_client.put_object( - Body=bytes(metadata_payload, "utf-8"), - Bucket=storage_conf["bucket"], - Key=f"{sample_checkpoint.uuid}/metadata.json", - ) - - storage_manager = storage.build(storage_conf, container_path=None) - storage_manager.download(sample_checkpoint.uuid, str(tmp_path)) - - downloaded_metadata_path = os.path.join(tmp_path, "metadata.json") - assert os.path.exists(downloaded_metadata_path) - with open(downloaded_metadata_path, "r") as f: - assert f.read() == metadata_payload diff --git a/harness/tests/determined/pytorch/experimental/test_torch_batch_process.py b/harness/tests/determined/pytorch/experimental/test_torch_batch_process.py index 988dd89c5a2..b5ceadcadaf 100644 --- 
a/harness/tests/determined/pytorch/experimental/test_torch_batch_process.py +++ b/harness/tests/determined/pytorch/experimental/test_torch_batch_process.py @@ -75,6 +75,7 @@ def _get_dist_context( ) -> unittest.mock.MagicMock: mock_distributed_context = unittest.mock.MagicMock() mock_distributed_context.get_rank.return_value = rank + mock_distributed_context.local_rank = rank mock_distributed_context.broadcast.return_value = "mock_checkpoint_uuid" mock_distributed_context.allgather.return_value = all_gather_return_value mock_distributed_context.gather.return_value = gather_return_value diff --git a/harness/tests/launch/test_deepspeed.py b/harness/tests/launch/test_deepspeed.py index 19a2a6b2907..1b117f8aeba 100644 --- a/harness/tests/launch/test_deepspeed.py +++ b/harness/tests/launch/test_deepspeed.py @@ -245,9 +245,13 @@ def test_launch_fail(mock_cluster_info: mock.MagicMock, mock_subprocess: mock.Ma @mock.patch("subprocess.Popen") @mock.patch("determined.get_cluster_info") -@mock.patch("determined.common.api.post") +@mock.patch("determined.common.api.authentication.login_with_cache") +@mock.patch("determined.common.api._session.Session.post") def test_launch_worker( - mock_api: mock.MagicMock, mock_cluster_info: mock.MagicMock, mock_subprocess: mock.MagicMock + mock_post: mock.MagicMock, + mock_login: mock.MagicMock, + mock_cluster_info: mock.MagicMock, + mock_subprocess: mock.MagicMock, ) -> None: cluster_info = test_util.make_mock_cluster_info(["0.0.0.0", "0.0.0.1"], 1, 4) mock_cluster_info.return_value = cluster_info @@ -257,7 +261,8 @@ def test_launch_worker( mock_cluster_info.assert_called_once() assert os.environ["DET_CHIEF_IP"] == cluster_info.container_addrs[0] - mock_api.assert_called_once() + mock_login.assert_called_once() + mock_post.assert_called_once() pid_server_cmd = launch.deepspeed.create_pid_server_cmd( cluster_info.allocation_id, len(cluster_info.slot_ids) diff --git a/harness/tests/launch/test_horovod.py 
b/harness/tests/launch/test_horovod.py index 8da719297ff..5776702b21c 100644 --- a/harness/tests/launch/test_horovod.py +++ b/harness/tests/launch/test_horovod.py @@ -6,7 +6,6 @@ import determined.launch.horovod # noqa: F401 from determined import constants, horovod, launch -from determined.common.api import certs from tests.launch import test_util @@ -84,6 +83,9 @@ def test_horovod_chief( mock_popen.return_value = mock_proc + os.environ.pop("DET_CHIEF_IP", None) + os.environ.pop("USE_HOROVOD", None) + with test_util.set_resources_id_env_var(): assert launch.horovod.main(hvd_args, script, autohorovod) == 99 @@ -96,6 +98,8 @@ def test_horovod_chief( mock_cluster_info.assert_called_once() assert os.environ["DET_CHIEF_IP"] == info.container_addrs[0] assert os.environ["USE_HOROVOD"] == "1" + del os.environ["DET_CHIEF_IP"] + del os.environ["USE_HOROVOD"] mock_popen.assert_has_calls([mock.call(launch_cmd)]) @@ -112,9 +116,11 @@ def test_horovod_chief( @mock.patch("subprocess.Popen") @mock.patch("determined.get_cluster_info") -@mock.patch("determined.common.api.post") +@mock.patch("determined.common.api.authentication.login_with_cache") +@mock.patch("determined.common.api._session.Session.post") def test_sshd_worker( - mock_api_post: mock.MagicMock, + mock_post: mock.MagicMock, + mock_login: mock.MagicMock, mock_cluster_info: mock.MagicMock, mock_popen: mock.MagicMock, ) -> None: @@ -135,23 +141,20 @@ def test_sshd_worker( mock_popen.return_value = mock_proc + os.environ.pop("DET_CHIEF_IP", None) + with test_util.set_resources_id_env_var(): assert launch.horovod.main(hvd_args, script, True) == 99 mock_cluster_info.assert_called_once() assert os.environ["DET_CHIEF_IP"] == info.container_addrs[0] - assert os.environ["USE_HOROVOD"] == "1" + del os.environ["DET_CHIEF_IP"] mock_popen.assert_has_calls([mock.call(launch_cmd)]) - mock_api_post.assert_has_calls( - [ - mock.call( - info.master_url, - path=f"/api/v1/allocations/{info.allocation_id}/resources/resourcesId/daemon", - 
cert=certs.cli_cert, - ) - ] + mock_login.assert_called_once() + mock_post.assert_has_calls( + [mock.call(f"/api/v1/allocations/{info.allocation_id}/resources/resourcesId/daemon")] ) mock_proc.wait.assert_called_once() diff --git a/master/internal/checkpoint_gc.go b/master/internal/checkpoint_gc.go index 827ace10443..1cd0005355f 100644 --- a/master/internal/checkpoint_gc.go +++ b/master/internal/checkpoint_gc.go @@ -18,6 +18,7 @@ import ( "github.com/determined-ai/determined/master/internal/sproto" "github.com/determined-ai/determined/master/internal/storage" "github.com/determined-ai/determined/master/internal/task" + "github.com/determined-ai/determined/master/internal/user" "github.com/determined-ai/determined/master/pkg/logger" "github.com/determined-ai/determined/master/pkg/model" "github.com/determined-ai/determined/master/pkg/protoutils/protoconverter" @@ -113,6 +114,11 @@ func runCheckpointGCTask( } taskSpec.TaskContainerDefaults = tcd + userSessionToken, err := user.StartSession(context.TODO(), owner) + if err != nil { + return errors.Wrapf(err, "unable to create user session for checkpoint gc") + } + taskSpec.UserSessionToken = userSessionToken taskSpec.AgentUserGroup = agentUserGroup taskSpec.Owner = owner diff --git a/master/internal/restore.go b/master/internal/restore.go index a6d43968c68..dfe924ba906 100644 --- a/master/internal/restore.go +++ b/master/internal/restore.go @@ -119,6 +119,12 @@ func (m *Master) restoreExperiment(expModel *model.Experiment) error { } taskSpec.Owner = owner + token, err := user.StartSession(context.Background(), owner) + if err != nil { + return fmt.Errorf("unable to create user session inside task: %w", err) + } + taskSpec.UserSessionToken = token + log.WithField("experiment", expModel.ID).Debug("restoring experiment") snapshot, err := m.retrieveExperimentSnapshot(expModel) if err != nil { diff --git a/master/static/srv/check_ready_logs.py b/master/static/srv/check_ready_logs.py index df7e1c4bad2..006133bb495 100644 --- 
a/master/static/srv/check_ready_logs.py +++ b/master/static/srv/check_ready_logs.py @@ -12,42 +12,40 @@ from requests.exceptions import RequestException from determined.common import api -from determined.common.api import certs +from determined.common.api import authentication, certs BACKOFF_SECONDS = 5 -def post_ready(master_url: str, cert: certs.Cert, allocation_id: str, state: str): +def post_ready(sess: api.Session, allocation_id: str, state: str): # Since the service is virtually inaccessible by the user unless # the call completes, we may as well try forever or just wait for # them to kill us. while True: try: - api.post( - master_url, - f"/api/v1/allocations/{allocation_id}/{state}", - {}, - cert=cert, - ) + sess.post(f"/api/v1/allocations/{allocation_id}/{state}") return except RequestException as e: if e.response is not None and e.response.status_code < 500: raise e - time.sleep(BACKOFF_SECONDS) def main(ready: Pattern, waiting: Optional[Pattern] = None): master_url = str(os.environ["DET_MASTER"]) cert = certs.default_load(master_url) + # This only runs on-cluster, so it is expected the username and session token are present in the + # environment. 
+ utp = authentication.login_with_cache(master_url, cert=cert) + sess = api.Session(master_url, utp, cert) allocation_id = str(os.environ["DET_ALLOCATION_ID"]) for line in sys.stdin: if ready.match(line): - post_ready(master_url, cert, allocation_id, "ready") + post_ready(sess, allocation_id, "ready") return if waiting and waiting.match(line): - post_ready(master_url, cert, allocation_id, "waiting") - + post_ready(sess, allocation_id, "waiting") + return if __name__ == "__main__": parser = argparse.ArgumentParser(description="Read STDIN for a match and mark a task as ready") diff --git a/master/static/srv/command-entrypoint.sh b/master/static/srv/command-entrypoint.sh index d5d8af0d055..2bb3291e6af 100644 --- a/master/static/srv/command-entrypoint.sh +++ b/master/static/srv/command-entrypoint.sh @@ -4,10 +4,6 @@ source /run/determined/task-setup.sh set -e -if [ -z "$DET_PYTHON_EXECUTABLE" ]; then - export DET_PYTHON_EXECUTABLE="python3" -fi - # In order to be able to use a proxy when running a command, Python must be # available in the container, and the "determined*.whl" must be installed, # which contains the "determined/exec/prep_container.py" script that's needed diff --git a/master/static/srv/entrypoint.sh b/master/static/srv/entrypoint.sh index 3cd91b15b14..42de0ab068b 100755 --- a/master/static/srv/entrypoint.sh +++ b/master/static/srv/entrypoint.sh @@ -4,28 +4,9 @@ source /run/determined/task-setup.sh set -e -STARTUP_HOOK="startup-hook.sh" -export PATH="/run/determined/pythonuserbase/bin:$PATH" - -# If HOME is not explicitly set for a container, libcontainer (Docker) will -# try to guess it by reading /etc/password directly, which will not work with -# our libnss_determined plugin (or any user-defined NSS plugin in a container). -# The default is "/", but HOME must be a writable location for distributed -# training, so we try to query the user system for a valid HOME, or default to -# the working directory otherwise. 
-if [ "$HOME" = "/" ]; then - HOME="$( - set -o pipefail - getent passwd "$(whoami)" | cut -d: -f6 - )" || HOME="$PWD" - export HOME -fi - -if [ -z "$DET_PYTHON_EXECUTABLE" ]; then - export DET_PYTHON_EXECUTABLE="python3" -fi "$DET_PYTHON_EXECUTABLE" -m determined.exec.prep_container --download_context_directory --resources --proxy +STARTUP_HOOK="startup-hook.sh" set -x test -f "${STARTUP_HOOK}" && source "${STARTUP_HOOK}" set +x diff --git a/master/static/srv/gc-checkpoints-entrypoint.sh b/master/static/srv/gc-checkpoints-entrypoint.sh index 398d9637db6..2e84a2fb921 100644 --- a/master/static/srv/gc-checkpoints-entrypoint.sh +++ b/master/static/srv/gc-checkpoints-entrypoint.sh @@ -4,11 +4,6 @@ source /run/determined/task-setup.sh set -e -export PATH="/run/determined/pythonuserbase/bin:$PATH" -if [ -z "$DET_PYTHON_EXECUTABLE" ]; then - export DET_PYTHON_EXECUTABLE="python3" -fi - "$DET_PYTHON_EXECUTABLE" -m determined.exec.prep_container exec "$DET_PYTHON_EXECUTABLE" -m determined.exec.gc_checkpoints "$@" diff --git a/master/static/srv/notebook-entrypoint.sh b/master/static/srv/notebook-entrypoint.sh index 9c4511cc673..4b2319a313a 100755 --- a/master/static/srv/notebook-entrypoint.sh +++ b/master/static/srv/notebook-entrypoint.sh @@ -4,26 +4,6 @@ source /run/determined/task-setup.sh set -e -STARTUP_HOOK="startup-hook.sh" -export PATH="/run/determined/pythonuserbase/bin:$PATH" -if [ -z "$DET_PYTHON_EXECUTABLE" ]; then - export DET_PYTHON_EXECUTABLE="python3" -fi - -# If HOME is not explicitly set for a container, libcontainer (Docker) will -# try to guess it by reading /etc/password directly, which will not work with -# our libnss_determined plugin (or any user-defined NSS plugin in a container). -# The default is "/", but HOME must be a writable location for distributed -# training, so we try to query the user system for a valid HOME, or default to -# the working directory otherwise. 
-if [ "$HOME" = "/" ]; then - HOME="$( - set -o pipefail - getent passwd "$(whoami)" | cut -d: -f6 - )" || HOME="$PWD" - export HOME -fi - # Use user's preferred SHELL in JupyterLab terminals. SHELL="$( set -o pipefail @@ -33,6 +13,7 @@ export SHELL "$DET_PYTHON_EXECUTABLE" -m determined.exec.prep_container --resources --proxy --download_context_directory +STARTUP_HOOK="startup-hook.sh" set -x test -f "${STARTUP_HOOK}" && source "${STARTUP_HOOK}" set +x diff --git a/master/static/srv/shell-entrypoint.sh b/master/static/srv/shell-entrypoint.sh index 37180d06c29..7bfd347ab61 100755 --- a/master/static/srv/shell-entrypoint.sh +++ b/master/static/srv/shell-entrypoint.sh @@ -4,18 +4,9 @@ source /run/determined/task-setup.sh set -e -STARTUP_HOOK="startup-hook.sh" -export PATH="/run/determined/pythonuserbase/bin:$PATH" -if [ -z "$DET_PYTHON_EXECUTABLE" ]; then - export DET_PYTHON_EXECUTABLE="python3" -fi - -# Unlike trial and notebook entrypoints, the HOME directory does not need to be -# modified in this entrypoint because the HOME in the user's ssh session is set -# by sshd at a later time. - "$DET_PYTHON_EXECUTABLE" -m determined.exec.prep_container --resources --proxy --download_context_directory +STARTUP_HOOK="startup-hook.sh" set -x test -f "${STARTUP_HOOK}" && source "${STARTUP_HOOK}" set +x diff --git a/master/static/srv/task-setup.sh b/master/static/srv/task-setup.sh index fc334099be1..16ae28e0ca5 100644 --- a/master/static/srv/task-setup.sh +++ b/master/static/srv/task-setup.sh @@ -39,3 +39,17 @@ if [ "$DET_RESOURCES_TYPE" == "slurm-job" ]; then # the image, or "Running", meaning that all containers are running. "$DET_PYTHON_EXECUTABLE" -m determined.exec.prep_container --notify_container_running fi + +# If HOME is not explicitly set for a container, libcontainer (Docker) will +# try to guess it by reading /etc/password directly, which will not work with +# our libnss_determined plugin (or any user-defined NSS plugin in a container). 
+# The default is "/", but HOME must be a writable location for distributed +# training, so we try to query the user system for a valid HOME, or default to +# the working directory otherwise. +if [ "$HOME" = "/" ]; then + HOME="$( + set -o pipefail + getent passwd "$(whoami)" | cut -d: -f6 + )" || HOME="$PWD" + export HOME +fi diff --git a/master/static/srv/tensorboard-entrypoint.sh b/master/static/srv/tensorboard-entrypoint.sh index b8b2947615a..fd2bc0c95f7 100755 --- a/master/static/srv/tensorboard-entrypoint.sh +++ b/master/static/srv/tensorboard-entrypoint.sh @@ -4,12 +4,6 @@ source /run/determined/task-setup.sh set -e -STARTUP_HOOK="startup-hook.sh" -export PATH="/run/determined/pythonuserbase/bin:$PATH" -if [ -z "$DET_PYTHON_EXECUTABLE" ]; then - export DET_PYTHON_EXECUTABLE="python3" -fi - if [ -z "$DET_SKIP_PIP_INSTALL" ]; then # Install tensorboard if not already installed (for custom PyTorch images) "$DET_PYTHON_EXECUTABLE" -m pip install tensorboard tensorboard-plugin-profile @@ -17,6 +11,7 @@ fi "$DET_PYTHON_EXECUTABLE" -m determined.exec.prep_container --proxy --download_context_directory +STARTUP_HOOK="startup-hook.sh" set -x test -f "${STARTUP_HOOK}" && source "${STARTUP_HOOK}" set +x