Commit 5b3b80b

[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
pre-commit-ci[bot] committed Nov 11, 2024
1 parent de807f7 commit 5b3b80b
Showing 14 changed files with 78 additions and 68 deletions.
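
All changes below are mechanical auto-fixes from the repository's pre-commit hooks. Note that this text rendering of the diff has lost the leading +/- markers, so each change appears as the old line immediately followed by the new one. Three patterns recur: power operators lose the spaces around them, calls that end in a trailing comma are exploded to one argument per line, and argument-less pytest decorators lose their parentheses. A minimal sketch of the first two patterns, assuming Black/ruff-format-style hooks (the commit does not say which hooks the repository actually runs):

x = 3.0


def f(a, b, c):
    return a + b + c


y = x**2 + 1  # was `x ** 2 + 1`: the formatter hugs `**` between simple operands

# The "magic" trailing comma after `3,` keeps the call exploded, one argument per line:
result = f(
    1,
    2,
    3,
)
print(y, result)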
2 changes: 1 addition & 1 deletion adaptive_scheduler/_scheduler/local.py
@@ -123,7 +123,7 @@ def start_job(self, name: str, *, index: int | None = None) -> None:
submit_cmd = f"{self.submit_cmd} {name} {self.batch_fname(name_prefix)}"
run_submit(submit_cmd, name)

def extra_scheduler(self, *, index: int | None = None) -> str: # noqa: ARG002
def extra_scheduler(self, *, index: int | None = None) -> str:
"""Get the extra scheduler options."""
msg = "extra_scheduler is not implemented."
raise NotImplementedError(msg)
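
The only change above is dropping the `# noqa: ARG002` suppression. ARG002 is ruff's unused-method-argument rule, and ruff skips it for stub bodies that do nothing but raise NotImplementedError; if that applies here, the suppression was itself flagged as unused (RUF100) and auto-removed. A sketch of the distinction, using hypothetical classes rather than this repository's code:

class Real:
    def scale(self, *, index: int | None = None) -> str:
        return "scaled"  # `index` is unused in a real body -> ARG002 fires


class Stub:
    def scale(self, *, index: int | None = None) -> str:
        msg = "scale is not implemented."
        raise NotImplementedError(msg)  # stub body -> ARG002 is skipped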
52 changes: 31 additions & 21 deletions example.ipynb
@@ -23,21 +23,25 @@
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"import adaptive_scheduler\n",
"import random\n",
"\n",
"\n",
"def h(x, width=0.01, offset=0):\n",
" for _ in range(10): # Burn some CPU time just because\n",
" np.linalg.eig(np.random.rand(1000, 1000))\n",
" return x + width ** 2 / (width ** 2 + (x - offset) ** 2)\n",
" return x + width**2 / (width**2 + (x - offset) ** 2)\n",
"\n",
"\n",
"# Define the sequence/samples we want to run\n",
"xs = np.linspace(0, 1, 10_000)\n",
"\n",
"# ⚠️ Here a `learner` is an `adaptive` concept, read it as `jobs`.\n",
"# ⚠️ `fnames` are the result locations\n",
"learners, fnames = adaptive_scheduler.utils.split_sequence_in_sequence_learners(\n",
" h, xs, n_learners=10\n",
" h,\n",
" xs,\n",
" n_learners=10,\n",
")\n",
"\n",
"run_manager = adaptive_scheduler.slurm_run(\n",
@@ -48,7 +52,7 @@
" nodes=1, # number of nodes per `learner`\n",
" cores_per_node=1, # number of cores on 1 node per `learner`\n",
" log_interval=5, # how often to produce a log message\n",
" save_interval=5, # how often to save the results\n",
" save_interval=5, # how often to save the results\n",
")\n",
"run_manager.start()"
]
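
Two import fixes show up in this cell: the import block gains blank lines separating groups (`numpy` vs. `adaptive_scheduler`), and `import random`, which nothing in the cell uses, is plausibly the deleted line of its pair (an F401-style unused-import autofix — an inference, since the +/- markers are missing). A sketch of isort-style grouping, with the group assignments assumed:

import random  # standard library group

import numpy as np  # third-party group

import adaptive_scheduler  # first-party group (assumed classification)

print(random.random(), np.pi, adaptive_scheduler.__name__)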
@@ -85,18 +89,18 @@
"from functools import partial\n",
"\n",
"import adaptive\n",
"\n",
"import adaptive_scheduler\n",
"\n",
"\n",
"def h(x, width=0.01, offset=0):\n",
" import numpy as np\n",
" import random\n",
"\n",
" for _ in range(10): # Burn some CPU time just because\n",
" np.linalg.eig(np.random.rand(1000, 1000))\n",
"\n",
" a = width\n",
" return x + a ** 2 / (a ** 2 + (x - offset) ** 2)\n",
" return x + a**2 / (a**2 + (x - offset) ** 2)\n",
"\n",
"\n",
"offsets = [i / 10 - 0.5 for i in range(5)]\n",
@@ -266,16 +270,16 @@
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"from adaptive import SequenceLearner\n",
"from adaptive_scheduler.utils import split, combo_to_fname\n",
"\n",
"from adaptive_scheduler.utils import split\n",
"\n",
"\n",
"def g(xyz):\n",
" x, y, z = xyz\n",
" for _ in range(5): # Burn some CPU time just because\n",
" np.linalg.eig(np.random.rand(1000, 1000))\n",
" return x ** 2 + y ** 2 + z ** 2\n",
" return x**2 + y**2 + z**2\n",
"\n",
"\n",
"xs = np.linspace(0, 10, 11)\n",
@@ -302,11 +306,17 @@
"\n",
"\n",
"scheduler = adaptive_scheduler.scheduler.DefaultScheduler(\n",
" cores=10, executor_type=\"ipyparallel\",\n",
" cores=10,\n",
" executor_type=\"ipyparallel\",\n",
") # PBS or SLURM\n",
"\n",
"run_manager2 = adaptive_scheduler.server_support.RunManager(\n",
" scheduler, learners, fnames, goal=goal, log_interval=30, save_interval=30,\n",
" scheduler,\n",
" learners,\n",
" fnames,\n",
" goal=goal,\n",
" log_interval=30,\n",
" save_interval=30,\n",
")\n",
"run_manager2.start()"
]
@@ -343,19 +353,19 @@
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"from adaptive import SequenceLearner\n",
"from adaptive_scheduler.utils import split, combo2fname\n",
"from adaptive.utils import named_product\n",
"\n",
"from adaptive_scheduler.utils import combo2fname\n",
"\n",
"\n",
"def g(combo):\n",
" x, y, z = combo[\"x\"], combo[\"y\"], combo[\"z\"]\n",
"\n",
" for _ in range(5): # Burn some CPU time just because\n",
" np.linalg.eig(np.random.rand(1000, 1000))\n",
"\n",
" return x ** 2 + y ** 2 + z ** 2\n",
" return x**2 + y**2 + z**2\n",
"\n",
"\n",
"combos = named_product(x=np.linspace(0, 10), y=np.linspace(-1, 1), z=np.linspace(-3, 3))\n",
@@ -364,15 +374,15 @@
"\n",
"# We could run this as 1 job with N nodes, but we can also split it up in multiple jobs.\n",
"# This is desireable when you don't want to run a single job with 300 nodes for example.\n",
"# Note that \n",
"# Note that\n",
"# `adaptive_scheduler.utils.split_sequence_in_sequence_learners(g, combos, 100, \"data\")`\n",
"# does the same!\n",
"\n",
"njobs = 100\n",
"split_combos = list(split(combos, njobs))\n",
"\n",
"print(\n",
" f\"Length of split_combos: {len(split_combos)} and length of split_combos[0]: {len(split_combos[0])}.\"\n",
" f\"Length of split_combos: {len(split_combos)} and length of split_combos[0]: {len(split_combos[0])}.\",\n",
")\n",
"\n",
"learners = [SequenceLearner(g, combos_part) for combos_part in split_combos]\n",
@@ -393,17 +403,16 @@
"outputs": [],
"source": [
"from functools import partial\n",
"\n",
"import adaptive_scheduler\n",
"from adaptive_scheduler.scheduler import DefaultScheduler, PBS, SLURM\n",
"from adaptive_scheduler.scheduler import SLURM, DefaultScheduler\n",
"\n",
"\n",
"def goal(learner):\n",
" return learner.done() # the standard goal for a SequenceLearner\n",
"\n",
"\n",
"extra_scheduler = (\n",
" [\"--exclusive\", \"--time=24:00:00\"] if DefaultScheduler is SLURM else []\n",
")\n",
"extra_scheduler = [\"--exclusive\", \"--time=24:00:00\"] if DefaultScheduler is SLURM else []\n",
"\n",
"scheduler = adaptive_scheduler.scheduler.DefaultScheduler(\n",
" cores=10,\n",
@@ -459,7 +468,8 @@
"source": [
"run_manager3.load_learners() # load the data into the learners\n",
"result = sum(\n",
" [l.result() for l in learners], []\n",
" [l.result() for l in learners],\n",
" [],\n",
") # combine all learners' results into 1 list"
]
}
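
An editorial aside on the last cell, not part of the commit: `sum(list_of_lists, [])` concatenates the per-learner results but re-copies the accumulator on every addition, which is quadratic in the total length. The usual linear-time alternative:

from itertools import chain

per_learner = [[1, 2], [3], [4, 5]]
flat = list(chain.from_iterable(per_learner))
assert flat == [1, 2, 3, 4, 5]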
14 changes: 7 additions & 7 deletions tests/conftest.py
@@ -23,13 +23,13 @@
import zmq.asyncio


@pytest.fixture()
@pytest.fixture
def mock_scheduler(tmp_path: Path) -> MockScheduler:
"""Fixture for creating a MockScheduler instance."""
return MockScheduler(log_folder=str(tmp_path), cores=8)


@pytest.fixture()
@pytest.fixture
def db_manager(
mock_scheduler: MockScheduler,
learners: list[adaptive.Learner1D]
@@ -99,14 +99,14 @@ def fnames(
raise NotImplementedError(msg)


@pytest.fixture()
@pytest.fixture
def socket(db_manager: DatabaseManager) -> zmq.asyncio.Socket:
"""Fixture for creating a ZMQ socket."""
with get_socket(db_manager) as socket:
yield socket


@pytest.fixture()
@pytest.fixture
def job_manager(
db_manager: DatabaseManager,
mock_scheduler: MockScheduler,
@@ -116,7 +116,7 @@ def job_manager(
return JobManager(job_names, db_manager, mock_scheduler, interval=0.05)


@pytest.fixture()
@pytest.fixture
def _mock_slurm_partitions_output() -> Generator[None, None, None]:
"""Mock `slurm_partitions` function."""
mock_output = "hb120v2-low\nhb60-high\nnc24-low*\nnd40v2-mpi\n"
@@ -125,7 +125,7 @@ def _mock_slurm_partitions_output() -> Generator[None, None, None]:
yield


@pytest.fixture()
@pytest.fixture
def _mock_slurm_partitions() -> Generator[None, None, None]:
"""Mock `slurm_partitions` function."""
with (
@@ -141,7 +141,7 @@ def _mock_slurm_partitions() -> Generator[None, None, None]:
yield


@pytest.fixture()
@pytest.fixture
def _mock_slurm_queue() -> Generator[None, None, None]:
"""Mock `SLURM.queue` function."""
with patch(
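
Every fixture in this file changes from `@pytest.fixture()` to `@pytest.fixture`. The two spellings are equivalent when the decorator takes no arguments; ruff's pytest-style rule PT001 standardizes on the bare form, assuming that is the hook responsible here. A minimal self-contained sketch:

import pytest


@pytest.fixture
def numbers() -> list[int]:
    # Equivalent to `@pytest.fixture()`; the bare form is the enforced style.
    return [1, 2, 3]


def test_sum(numbers: list[int]) -> None:
    assert sum(numbers) == 6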
4 changes: 2 additions & 2 deletions tests/test_client_support.py
@@ -27,7 +27,7 @@ def client(zmq_url: str) -> zmq.Socket:
return client


@pytest.mark.asyncio()
@pytest.mark.asyncio
async def test_get_learner(zmq_url: str) -> None:
"""Test `get_learner` function."""
with tempfile.NamedTemporaryFile() as tmpfile:
@@ -94,7 +94,7 @@ async def test_get_learner(zmq_url: str) -> None:
mock_log.exception.assert_called_with("got an exception")


@pytest.mark.asyncio()
@pytest.mark.asyncio
async def test_tell_done(zmq_url: str) -> None:
"""Test `tell_done` function."""
fname = "test_learner_file.pkl"
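
The same parentheses normalization applies to marks: `@pytest.mark.asyncio()` becomes `@pytest.mark.asyncio` (ruff's PT023, under the same assumption about the hook). Both forms register the identical mark; with the pytest-asyncio plugin installed, the marked coroutine is driven on an event loop:

import asyncio

import pytest


@pytest.mark.asyncio
async def test_sleep() -> None:
    await asyncio.sleep(0)  # any awaitable works; the plugin supplies the loop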
14 changes: 7 additions & 7 deletions tests/test_database_manager.py
@@ -102,7 +102,7 @@ def test_simple_database_get_all(tmp_path: Path) -> None:
assert done_entries[1][1].fname == "file3.txt"


@pytest.mark.asyncio()
@pytest.mark.asyncio
async def test_database_manager_start_and_cancel(db_manager: DatabaseManager) -> None:
"""Test starting and canceling the DatabaseManager."""
db_manager.start()
@@ -172,7 +172,7 @@ def test_database_manager_as_dicts(
]


@pytest.mark.asyncio()
@pytest.mark.asyncio
async def test_database_manager_dispatch_start_stop(
db_manager: DatabaseManager,
learners: list[adaptive.Learner1D]
@@ -205,7 +205,7 @@ async def test_database_manager_dispatch_start_stop(
assert entry.is_done is True


@pytest.mark.asyncio()
@pytest.mark.asyncio
async def test_database_manager_start_and_update(
socket: zmq.asyncio.Socket,
db_manager: DatabaseManager,
@@ -259,7 +259,7 @@ async def test_database_manager_start_and_update(
assert entry.job_id is None


@pytest.mark.asyncio()
@pytest.mark.asyncio
async def test_database_manager_start_stop(
socket: zmq.asyncio.Socket,
db_manager: DatabaseManager,
@@ -322,7 +322,7 @@ async def test_database_manager_start_stop(
await send_message(socket, start_message)


@pytest.mark.asyncio()
@pytest.mark.asyncio
async def test_database_manager_stop_request_and_requests(
socket: zmq.asyncio.Socket,
db_manager: DatabaseManager,
@@ -531,7 +531,7 @@ def test_ensure_str_invalid_input(invalid_input: list[str]) -> None:
_ensure_str(invalid_input) # type: ignore[arg-type]


@pytest.mark.asyncio()
@pytest.mark.asyncio
async def test_dependencies(
db_manager: DatabaseManager,
fnames: list[str] | list[Path],
@@ -599,7 +599,7 @@ async def test_dependencies(
db_manager._choose_fname()


@pytest.mark.asyncio()
@pytest.mark.asyncio
async def test_replace_learner(db_manager: DatabaseManager) -> None:
"""Test replacing a learner in the DatabaseManager."""
db_manager.create_empty_db()
16 changes: 8 additions & 8 deletions tests/test_job_manager.py
@@ -12,15 +12,15 @@
from adaptive_scheduler.server_support import JobManager, MaxRestartsReachedError


@pytest.mark.asyncio()
@pytest.mark.asyncio
async def test_job_manager_init(job_manager: JobManager) -> None:
"""Test the initialization of JobManager."""
job_manager.database_manager.start()
job_manager.start()
assert job_manager.task is not None


@pytest.mark.asyncio()
@pytest.mark.asyncio
async def test_job_manager_queued(job_manager: JobManager) -> None:
"""Test the _queued method of JobManager."""
job_manager.scheduler.start_job("job1")
@@ -30,7 +30,7 @@ async def test_job_manager_queued(job_manager: JobManager) -> None:
assert job_manager._queued(job_manager.scheduler.queue()) == {"job1", "job2"}


@pytest.mark.asyncio()
@pytest.mark.asyncio
async def test_job_manager_manage_max_restarts_reached(job_manager: JobManager) -> None:
"""Test the JobManager when the maximum restarts are reached."""
job_manager.n_started = 105
@@ -48,7 +48,7 @@ async def test_job_manager_manage_max_restarts_reached(job_manager: JobManager)
job_manager.task.result()


@pytest.mark.asyncio()
@pytest.mark.asyncio
async def test_job_manager_manage_start_jobs(job_manager: JobManager) -> None:
"""Test the JobManager when managing the start of jobs."""
job_manager.database_manager.n_done = MagicMock(return_value=0) # type: ignore[method-assign]
@@ -60,7 +60,7 @@ async def test_job_manager_manage_start_jobs(job_manager: JobManager) -> None:
assert set(job_manager.scheduler._started_jobs) == {"job1", "job2"} # type: ignore[attr-defined]


@pytest.mark.asyncio()
@pytest.mark.asyncio
async def test_job_manager_manage_start_max_simultaneous_jobs(
job_manager: JobManager,
) -> None:
@@ -76,7 +76,7 @@ async def test_job_manager_manage_start_max_simultaneous_jobs(
assert len(job_manager.scheduler._started_jobs) == 1 # type: ignore[attr-defined]


@pytest.mark.asyncio()
@pytest.mark.asyncio
async def test_job_manager_manage_cancelled_error(
job_manager: JobManager,
caplog: pytest.LogCaptureFixture,
@@ -100,7 +100,7 @@ async def test_job_manager_manage_cancelled_error(
assert "task was cancelled because of a CancelledError" in caplog.text


@pytest.mark.asyncio()
@pytest.mark.asyncio
async def test_job_manager_manage_n_done_equal_job_names(
job_manager: JobManager,
) -> None:
@@ -116,7 +116,7 @@ async def test_job_manager_manage_n_done_equal_job_names(
assert job_manager.task.result() is None


@pytest.mark.asyncio()
@pytest.mark.asyncio
async def test_job_manager_manage_generic_exception(
job_manager: JobManager,
caplog: pytest.LogCaptureFixture,