diff --git a/examples/eval.py b/examples/eval.py index ff338538..c3d20f90 100644 --- a/examples/eval.py +++ b/examples/eval.py @@ -1,4 +1,5 @@ import base64 +import copy import dataclasses import multiprocessing import re @@ -65,7 +66,7 @@ def get_test_cases(file_name: str, seed: Optional[int]) -> list[TestCase]: tests = [] lines = content.splitlines() - match = r"\s*([a-zA-Z]+):\s*([a-zA-Z]+|[+-]?[0-9]+)\s*" + match = r"\s*([a-zA-Z_]+):\s*([a-zA-Z]+|[+-]?[0-9]+)\s*" for line in lines: parts = line.split(";") case = {} @@ -123,18 +124,19 @@ def calculate_stats(durations: list[int]): worst=float(worst)) -def _clone_data(data): +def _clone_data(data, rank: int): """ Recursively goes through data and clones all tensors. """ if isinstance(data, tuple): - return tuple(_clone_data(x) for x in data) + return tuple(_clone_data(x, rank) for x in data) elif isinstance(data, list): - return [_clone_data(x) for x in data] + return [_clone_data(x, rank) for x in data] elif isinstance(data, dict): - return {k: _clone_data(v) for k, v in data.items()} + return {k: _clone_data(v, rank) for k, v in data.items()} elif isinstance(data, torch.Tensor): - return data.clone() + device = f"cuda:{rank}" + return data.clone().to(device) else: return data @@ -157,16 +159,61 @@ def _run_single_test(test: TestCase): from submission import custom_kernel data = generate_input(**test.args) torch.cuda.synchronize() - submission_output = custom_kernel(_clone_data(data)) + submission_output = custom_kernel(_clone_data(data, 0)) torch.cuda.synchronize() return wrap_check_implementation(data, submission_output) +def _run_distributed_test(test: TestCase, rank: int): + """ + Runs a single test case. 
Do not call directly + """ + from submission import custom_kernel + import torch.distributed as dist + world_size = test.args["world_size"] + os.environ["MASTER_ADDR"] = "127.0.0.1" + os.environ["MASTER_PORT"] = "12356" + dist.init_process_group("nccl", init_method="env://", rank=rank, world_size=world_size, device_id=torch.device(f'cuda:{rank}')) + try: + data = generate_input(**test.args, rank=rank) + torch.cuda.synchronize() + submission_output = custom_kernel(_clone_data(data, rank)) + torch.cuda.synchronize() + return wrap_check_implementation(data, submission_output) + finally: + dist.destroy_process_group() + + +def run_multi_gpu_test(pool: multiprocessing.Pool, test: TestCase, world_size: int): + """ + Runs a single test in another process. + """ + rets = [] + # world_size is a mandatory argument for multi-gpu tests + for i in range(world_size): + rets.append( + pool.apply_async( + _run_distributed_test, + args=(test, i), + ) + ) + # 60 seconds should be more than enough, we want tests to be fast + rets = [el.get(60) for el in rets] + + correct = all(ret[0] for ret in rets) + error_messages = str.join("\n", [f"rank {rank} - {ret[1]}" for rank, ret in enumerate(rets) if not ret[0]]) + return correct, error_messages + + def run_single_test(pool: multiprocessing.Pool, test: TestCase): """ Runs a single test in another process. 
""" - return pool.apply(_run_single_test, (test,)) + world_size = test.args.get("world_size", None) + if world_size is None: + return pool.apply(_run_single_test, (test,)) + else: + return run_multi_gpu_test(pool, test, world_size) def run_testing(logger: PopcornOutput, pool: multiprocessing.Pool, tests: list[TestCase]): @@ -208,7 +255,7 @@ def _run_single_benchmark(test: TestCase, recheck: bool, max_repeats: int, max_t durations = [] # generate input data once data = generate_input(**test.args) - check_copy = _clone_data(data) + check_copy = _clone_data(data, 0) # first, one obligatory correctness check output = custom_kernel(data) good, message = wrap_check_implementation(check_copy, output) @@ -228,7 +275,7 @@ def _run_single_benchmark(test: TestCase, recheck: bool, max_repeats: int, max_t test.args["seed"] += 13 data = generate_input(**test.args) - check_copy = _clone_data(data) + check_copy = _clone_data(data, 0) torch.cuda.synchronize() start_event = torch.cuda.Event(enable_timing=True) end_event = torch.cuda.Event(enable_timing=True) @@ -261,6 +308,144 @@ def _run_single_benchmark(test: TestCase, recheck: bool, max_repeats: int, max_t return calculate_stats(durations) +def _run_distributed_benchmark(test: TestCase, rank: int, recheck: bool, max_repeats: int, + max_time_ns: float) -> Stats | Any: + """ + Runs one distributed benchmark. Do not call directly. 
+ """ + from submission import custom_kernel + import torch.distributed as dist + + world_size = test.args["world_size"] + os.environ["MASTER_ADDR"] = "127.0.0.1" + os.environ["MASTER_PORT"] = "12356" + dist.init_process_group("nccl", init_method="env://", rank=rank, world_size=world_size, device_id=torch.device(f'cuda:{rank}')) + + try: + durations = [] + # generate input data once + data = generate_input(**test.args, rank=rank) + check_copy = _clone_data(data, rank) + + # first, one obligatory correctness check + output = custom_kernel(_clone_data(data, rank)) + good, message = wrap_check_implementation(check_copy, output) + if not good: + return message + + # now, do multiple timing runs with proper distributed synchronization + bm_start_time = time.perf_counter_ns() + for i in range(max_repeats): + error_message = None + if recheck: + # ensure we use a different seed for every benchmark + if "seed" in test.args: + test.args["seed"] += 13 + + data = generate_input(**test.args, rank=rank) + check_copy = _clone_data(data, rank) + + # Synchronize all ranks before timing + clear_l2_cache() + torch.cuda.synchronize() + dist.barrier() + + # Use distributed timing - only rank 0 records the overall time + if rank == 0: + start_time = time.perf_counter_ns() + + # All ranks execute the kernel + output = custom_kernel(_clone_data(data, rank)) + + # Synchronize all ranks after kernel execution + torch.cuda.synchronize() + dist.barrier() + + if rank == 0: + end_time = time.perf_counter_ns() + duration = end_time - start_time # Already in nanoseconds + durations.append(duration) + + if recheck: + good, message = check_implementation(check_copy, output) + if not good: + error_message = message + + del output + + has_error = torch.tensor(1 if error_message is not None else 0, dtype=torch.int32, device=f'cuda:{rank}') + dist.reduce(has_error, 0) + if has_error.item() > 0: + return error_message + + # Only rank 0 checks convergence criteria + if rank == 0 and i > 1: + 
total_bm_duration = time.perf_counter_ns() - bm_start_time + stats = calculate_stats(durations) + # stop if either + # a) relative error dips below 0.1% + # b) we exceed the total time limit for benchmarking the kernel + # c) we exceed 2 minutes of total wallclock time. + should_stop = (stats.err / stats.mean < 0.001 or + stats.mean * stats.runs > max_time_ns or + total_bm_duration > 120e9) + else: + should_stop = False + + # Broadcast stop decision to all ranks + stop_tensor = torch.tensor(should_stop, dtype=torch.bool, device=f'cuda:{rank}') + dist.broadcast(stop_tensor, 0) + + if stop_tensor.item(): + break + + # Only rank 0 returns meaningful stats + if rank == 0: + return calculate_stats(durations) + else: + # Non-zero ranks return a dummy stats object + return Stats(runs=len(durations), mean=0.0, std=0.0, err=0.0, best=0.0, worst=0.0) + + finally: + dist.destroy_process_group() + + +def run_multi_gpu_benchmark(pool: multiprocessing.Pool, test: TestCase, recheck: bool, max_repeats: int, + max_time_ns: float, world_size: int): + """ + Runs a multi-GPU benchmark across all ranks. 
+ """ + rets = [] + for i in range(world_size): + rets.append( + pool.apply_async( + _run_distributed_benchmark, + args=(test, i, recheck, max_repeats, max_time_ns), + ) + ) + + # 120 seconds for benchmarking + we run a pre-benchmark test and want to leave some slack + rets = [el.get(timeout=180) for el in rets] + + # For multi-GPU benchmarking, only rank 0 has meaningful stats + failed_ranks = [] + rank_0_result = None + + for rank, ret in enumerate(rets): + if isinstance(ret, Stats): + if rank == 0: + rank_0_result = ret + else: + # ret is an error message + failed_ranks.append((rank, ret)) + + if failed_ranks: + error_messages = str.join("\n", [f"rank {rank} - {msg}" for rank, msg in failed_ranks]) + return error_messages + else: + return rank_0_result if rank_0_result else "No stats returned from rank 0" + + def run_single_benchmark(pool: multiprocessing.Pool, test: TestCase, recheck: bool, max_repeats: int, max_time_ns: float): """ @@ -273,7 +458,12 @@ def run_single_benchmark(pool: multiprocessing.Pool, test: TestCase, recheck: bo @param max_time_ns: Timeout time in nanoseconds. @return: A Stats object for this particular benchmark case or an error if the test fails. 
""" - return pool.apply(_run_single_benchmark, (test, recheck, max_repeats, max_time_ns)) + + world_size: Optional[int] = test.args.get("world_size", None) + if world_size is None: + return pool.apply(_run_single_benchmark, (test, recheck, max_repeats, max_time_ns)) + else: + return run_multi_gpu_benchmark(pool, test, recheck, max_repeats, max_time_ns, world_size) def run_benchmarking(logger: PopcornOutput, pool: multiprocessing.Pool, tests: list[TestCase]): @@ -319,7 +509,7 @@ def run_single_profile(test: TestCase) -> str: torch.cuda.synchronize() with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof: - submission_output = custom_kernel(_clone_data(data)) + submission_output = custom_kernel(_clone_data(data, 0)) torch.cuda.synchronize() return prof.key_averages().table(sort_by="self_cuda_time_total", row_limit=20) @@ -345,6 +535,7 @@ def main(): mode = sys.argv[1] seed = os.getenv("POPCORN_SEED") os.unsetenv("POPCORN_SEED") + n_gpus = int(os.getenv("POPCORN_GPUS", "1")) seed = int(seed) if seed else None set_seed(seed or 42) tests = get_test_cases(sys.argv[2], seed) @@ -352,7 +543,7 @@ def main(): with PopcornOutput(int(fd)) as logger: import multiprocessing mp_context = multiprocessing.get_context('spawn') - with mp_context.Pool(1) as pool: + with mp_context.Pool(n_gpus) as pool: if mode == "test": return run_testing(logger, pool, tests) if mode == "benchmark": diff --git a/examples/gather/reference.py b/examples/gather/reference.py new file mode 100644 index 00000000..4cc73a93 --- /dev/null +++ b/examples/gather/reference.py @@ -0,0 +1,19 @@ +import torch +from task import input_t, output_t +from utils import verbose_allclose +from typing import Tuple + + +def generate_input(seed: int, world_size: int, rank: int) -> input_t: + local_data = torch.tensor([rank]).to(f"cuda:{rank}") + return local_data, rank, world_size + + +def check_implementation(data: input_t, output: output_t) -> Tuple[bool, str]: + data, rank, world_size = data + for i 
in range(world_size): + if output[i].get_device() != rank: + return False, f"mismatch found! output {i} of rank {rank} is on device {output[i].device}" + if (item := output[i].cpu().detach().item()) != i: + return False, f"mismatch found! custom implementation doesn't match reference: rank {rank}, entry {i} has value {item}" + return True, '' diff --git a/examples/gather/submission.py b/examples/gather/submission.py new file mode 100644 index 00000000..9be47cb0 --- /dev/null +++ b/examples/gather/submission.py @@ -0,0 +1,12 @@ +#!POPCORN leaderboard identity_py-dev + +from task import input_t, output_t +import torch +from torch import distributed as dist + + +def custom_kernel(data: input_t) -> output_t: + data, rank, world_size = data + result = [torch.empty_like(data) for _ in range(dist.get_world_size())] + dist.all_gather(result, data) + return result diff --git a/examples/gather/task.py b/examples/gather/task.py new file mode 100644 index 00000000..f1b820d2 --- /dev/null +++ b/examples/gather/task.py @@ -0,0 +1,10 @@ +from typing import TypedDict, List, Tuple +import torch + + +input_t = Tuple[torch.Tensor, int, int] +output_t = List[torch.Tensor] + + +class TestSpec(TypedDict): + pass diff --git a/examples/gather/task.yml b/examples/gather/task.yml new file mode 100644 index 00000000..f9bd843e --- /dev/null +++ b/examples/gather/task.yml @@ -0,0 +1,28 @@ +# name: identity-py + +files: + - {"name": "submission.py", "source": "@SUBMISSION@"} + - {"name": "task.py", "source": "task.py"} + - {"name": "utils.py", "source": "../utils.py"} + - {"name": "reference.py", "source": "reference.py"} + - {"name": "eval.py", "source": "../eval.py"} + +lang: "py" +multi_gpu: true +description: + A simple test task - python + +config: + main: "eval.py" + +templates: + Python: "../template.py" + +# small test cases. should be cheap to run. 
+tests: + - {"seed": 5, "world_size": 4} + +benchmarks: + - {"seed": 10, "world_size": 4} + +ranking_by: "geom" diff --git a/examples/gather/wrong.py b/examples/gather/wrong.py new file mode 100644 index 00000000..39111cbf --- /dev/null +++ b/examples/gather/wrong.py @@ -0,0 +1,11 @@ +#!POPCORN leaderboard identity_py-dev + +from task import input_t, output_t +import torch +from torch import distributed as dist + + +def custom_kernel(data: input_t) -> output_t: + data, rank, world_size = data + result = [torch.ones_like(data) for _ in range(dist.get_world_size())] + return result diff --git a/src/libkernelbot/consts.py b/src/libkernelbot/consts.py index aa1efb6f..b1a5d222 100644 --- a/src/libkernelbot/consts.py +++ b/src/libkernelbot/consts.py @@ -28,6 +28,8 @@ class ModalGPU(Enum): A100 = "A100" H100 = "H100" B200 = "B200" + # multi-gpu + L4x4 = "L4x4" @dataclasses.dataclass @@ -109,7 +111,8 @@ class RankCriterion(Enum): GPU_TO_SM = { "T4": "75", - "L4": "80", + "L4": "89", + "L4x4": "89", "A100": "80", "H100": "90a", "B200": "100", diff --git a/src/libkernelbot/run_eval.py b/src/libkernelbot/run_eval.py index 18344dd3..73e7e374 100644 --- a/src/libkernelbot/run_eval.py +++ b/src/libkernelbot/run_eval.py @@ -44,6 +44,7 @@ class RunResult: class SystemInfo: # fmt: off gpu: str = '' # Model name of the GPU + device_count: int = 1 # Number of GPUs cpu: str = '' # Model name of the CPU platform: str = '' # Platform string of the machine torch: str = '' # Torch version @@ -217,7 +218,9 @@ def compile_cuda_script( # # noqa: C901 ) -def run_program(args: list[str], seed: Optional[int], timeout: int) -> RunResult: +def run_program( + args: list[str], seed: Optional[int], timeout: int, multi_gpu: bool = False +) -> RunResult: print("[Running]") # set up a pipe so the tester can communicate its verdict with us env = os.environ.copy() @@ -226,6 +229,11 @@ def run_program(args: list[str], seed: Optional[int], timeout: int) -> RunResult if seed is not None: env["POPCORN_SEED"] 
= str(seed) + if multi_gpu: + import torch + + env["POPCORN_GPUS"] = str(torch.cuda.device_count()) + execution_start_time = time.perf_counter() try: run_process = subprocess.run( @@ -279,6 +287,8 @@ def run_program(args: list[str], seed: Optional[int], timeout: int) -> RunResult def run_single_evaluation( call: list[str], mode: str, + *, + multi_gpu: bool = False, tests: Optional[str] = None, benchmarks: Optional[str] = None, test_timeout: int = Timeout.TEST, @@ -295,7 +305,9 @@ def run_single_evaluation( with tempfile.NamedTemporaryFile("w") as tests_file: tests_file.write(tests) tests_file.flush() - return run_program(call + [mode, tests_file.name], seed=seed, timeout=test_timeout) + return run_program( + call + [mode, tests_file.name], seed=seed, timeout=test_timeout, multi_gpu=multi_gpu + ) elif mode in ["benchmark", "profile", "leaderboard"]: timeout = ranked_timeout if mode == "leaderboard" else benchmark_timeout with tempfile.NamedTemporaryFile("w") as bench_file: @@ -304,7 +316,9 @@ def run_single_evaluation( else: bench_file.write(benchmarks) bench_file.flush() - return run_program(call + [mode, bench_file.name], seed=seed, timeout=timeout) + return run_program( + call + [mode, bench_file.name], seed=seed, timeout=timeout, multi_gpu=multi_gpu + ) else: raise ValueError(f"Invalid mode {mode}") @@ -319,6 +333,7 @@ def make_system_info() -> SystemInfo: # https://pytorch.org/docs/stable/notes/hip.html if torch.cuda.is_available(): info.gpu = torch.cuda.get_device_name() + info.device_count = torch.cuda.device_count() except ImportError: # get GPU info manually try: @@ -551,6 +566,7 @@ def run_config(config: dict): "ranked_timeout": config.get("ranked_timeout", Timeout.RANKED), "benchmark_timeout": config.get("benchmark_timeout", Timeout.BENCHMARK), "test_timeout": config.get("test_timeout", Timeout.TEST), + "multi_gpu": config.get("multi_gpu", False), } if config["lang"] == "py": runner = functools.partial( diff --git a/src/libkernelbot/task.py 
b/src/libkernelbot/task.py index 9f3aa43c..26c90780 100644 --- a/src/libkernelbot/task.py +++ b/src/libkernelbot/task.py @@ -61,6 +61,7 @@ class LeaderboardTask: ranked_timeout: int = 180 ranking_by: RankCriterion = RankCriterion.LAST seed: Optional[int] = None + multi_gpu: bool = False def __post_init__(self): if self.lang == Language.Python and not isinstance(self.config, PythonTaskData): @@ -75,6 +76,7 @@ def from_dict(cls, data: dict): criterion = RankCriterion(data.get("ranking_by", RankCriterion.LAST)) data_["lang"] = lang data_["ranking_by"] = criterion + data_["multi_gpu"] = data.get("multi_gpu", False) if lang == Language.Python: data_["config"] = PythonTaskData(**data["config"]) else: @@ -112,7 +114,7 @@ class LeaderboardDefinition: templates: dict[str, str] = dataclasses.field(default_factory=dict) -def make_task_definition(yaml_file: str | Path) -> LeaderboardDefinition: +def make_task_definition(yaml_file: str | Path) -> LeaderboardDefinition: # noqa: C901 if Path(yaml_file).is_dir(): yaml_file = Path(yaml_file) / "task.yml" @@ -149,6 +151,15 @@ def make_task_definition(yaml_file: str | Path) -> LeaderboardDefinition: description = raw["description"] del raw["description"] task = LeaderboardTask.from_dict(raw) + + # basic validation: + if task.multi_gpu: + for test in task.tests: + if "world_size" not in test: + raise KernelBotError(f"multi-gpu test {test} does not specify world_size") + for benchmark in task.benchmarks: + if "world_size" not in benchmark: + raise KernelBotError(f"multi-gpu benchmark {benchmark} does not specify world_size") return LeaderboardDefinition(task=task, templates=templates, description=description) @@ -176,6 +187,7 @@ def build_task_config( "ranked_timeout": task.ranked_timeout, "ranking_by": task.ranking_by.value, "seed": task.seed, + "multi_gpu": task.multi_gpu, } if task.lang == Language.Python: diff --git a/src/runners/modal_runner.py b/src/runners/modal_runner.py index bb8d952a..26e31257 100644 --- 
a/src/runners/modal_runner.py +++ b/src/runners/modal_runner.py @@ -40,7 +40,7 @@ ) # other frameworks .pip_install( - "jax[cuda12]==0.5.3", # 0.6 want's cudnn 9.8 in conflict with torch 2.7 + "jax[cuda12]==0.5.3",  # 0.6 wants cudnn 9.8 in conflict with torch 2.7 "jax2torch==0.0.7", "tinygrad~=0.10", ) @@ -50,8 +50,8 @@ "nvidia-cutlass-dsl~=4.0", "cuda-core[cu12]~=0.3", "cuda-python[all]==12.8", - #"nvmath-python[cu12]~=0.4", - #"numba-cuda[cu12]~=0.15", + # "nvmath-python[cu12]~=0.4", + # "numba-cuda[cu12]~=0.15", ) ) diff --git a/src/runners/modal_runner_archs.py b/src/runners/modal_runner_archs.py index 3b230bcd..f1557f5b 100644 --- a/src/runners/modal_runner_archs.py +++ b/src/runners/modal_runner_archs.py @@ -2,9 +2,9 @@ # Modal apps on specific devices. We will fix this later. from modal_runner import app, cuda_image, modal_run_config -gpus = ["T4", "L4", "A100-80GB", "H100!", "B200"] +gpus = ["T4", "L4", "L4:4", "A100-80GB", "H100!", "B200"] for gpu in gpus: - gpu_slug = gpu.lower().split("-")[0].strip("!") + gpu_slug = gpu.lower().split("-")[0].strip("!").replace(":", "x") app.function(gpu=gpu, image=cuda_image, name=f"run_cuda_script_{gpu_slug}", serialized=True)( modal_run_config ) diff --git a/tests/conftest.py b/tests/conftest.py index 0408b82e..1a049250 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -106,6 +106,32 @@ def bot(docker_compose, database): CUDA: "template.cu" """ +MULTI_GPU_TASK_YAML = """ +lang: py +description: "Test task description" +ranking_by: geom +multi_gpu: true +test_timeout: 120 +files: + - name: "kernel.py" + source: "kernel.py" + - name: "submission.py" + source: "@SUBMISSION@" +config: + main: "kernel.py" +tests: + - input_size: 1000 + world_size: 4 + dtype: "float32" +benchmarks: + - input_size: 10000 + world_size: 4 + dtype: "float32" +templates: + Python: "template.py" + CUDA: "template.cu" +""" + @pytest.fixture def task_directory(tmp_path): @@ -117,6 +143,7 @@ def task_directory(tmp_path): # Create task.yml 
Path.write_text(tmp_path / "task.yml", TASK_YAML) + Path.write_text(tmp_path / "multi-task.yml", MULTI_GPU_TASK_YAML) return tmp_path diff --git a/tests/test_backend.py b/tests/test_backend.py index dcc4e8d6..585674cd 100644 --- a/tests/test_backend.py +++ b/tests/test_backend.py @@ -99,6 +99,7 @@ async def test_handle_submission(bot: backend.KernelBackend, task_directory): "lang": "py", "main": "kernel.py", "mode": "leaderboard", + "multi_gpu": False, "ranked_timeout": 180, "ranking_by": "geom", "seed": None, @@ -152,6 +153,7 @@ async def test_submit_leaderboard(bot: backend.KernelBackend, task_directory): "lang": "py", "main": "kernel.py", "mode": "leaderboard", + "multi_gpu": False, "ranked_timeout": 180, "ranking_by": "geom", "seed": 1337, @@ -205,6 +207,7 @@ async def test_submit_leaderboard(bot: backend.KernelBackend, task_directory): "start_time": eval_result.start.replace(tzinfo=datetime.timezone.utc), "system": { "cpu": "Intel i9-12900K", + "device_count": 1, "gpu": "NVIDIA RTX 4090", "platform": "Linux-5.15.0", "torch": "2.0.1+cu118", @@ -309,6 +312,7 @@ async def test_submit_full(bot: backend.KernelBackend, task_directory): "start_time": ANY, "system": { "cpu": "Intel i9-12900K", + "device_count": 1, "gpu": "NVIDIA RTX 4090", "platform": "Linux-5.15.0", "torch": "2.0.1+cu118", @@ -350,6 +354,7 @@ async def test_submit_full(bot: backend.KernelBackend, task_directory): "start_time": ANY, "system": { "cpu": "Intel i9-12900K", + "device_count": 1, "gpu": "NVIDIA RTX 4090", "platform": "Linux-5.15.0", "torch": "2.0.1+cu118", diff --git a/tests/test_modal.py b/tests/test_modal.py index 9fa1725e..14531015 100644 --- a/tests/test_modal.py +++ b/tests/test_modal.py @@ -1,4 +1,5 @@ import os +import pprint import subprocess from pathlib import Path from typing import Tuple @@ -183,6 +184,112 @@ async def test_modal_launcher_python_script( assert reporter.updates == ["✅ Waiting for modal run to finish... 
Done"] +@pytest.mark.integration +@pytest.mark.asyncio +@pytest.mark.parametrize("script, good", [("submission.py", True), ("wrong.py", False)]) +async def test_modal_multi_gpu(modal_deployment, project_root: Path, script: str, good: bool): + """ + This isn't really a modal test, but instead a test using modal to check + that multi-gpu submission testing works (on modal...). + """ + launcher = ModalLauncher(add_include_dirs=[]) + reporter = MockProgressReporter("progress") + + # Load the real identity_py task + task_path = project_root / "examples" / "gather" + if not task_path.exists(): + pytest.skip("examples/gather not found - skipping Modal multi-gpu test") + + # Load the task definition + task_definition = make_task_definition(task_path) + + # Use the actual working submission from the examples + submission_content = (task_path / script).read_text() + + config = build_task_config( + task=task_definition.task, + submission_content=submission_content, + arch=GPU_TO_SM[ModalGPU.L4x4.name], + mode=SubmissionMode.TEST, + ) + + result = await launcher.run_submission(config, ModalGPU.L4x4, reporter) + + # Basic structure and success + assert result.success, f"Expected successful run, got: {result.error}" + assert result.error == "" + assert isinstance(result.runs, dict) + + # System info - test actual expected values + pprint.pprint(result) + assert result.system.device_count == 4 + + # Test run structure + assert "test" in result.runs + test_run = result.runs["test"] + + # For Python runs, compilation is None + assert test_run.compilation is None + + # Run needs to succeed + assert test_run.run.success is True + assert test_run.run.passed is good + + +@pytest.mark.integration +@pytest.mark.asyncio +@pytest.mark.parametrize("script, good", [("submission.py", True), ("wrong.py", False)]) +async def test_modal_multi_gpu_benchmark( + modal_deployment, project_root: Path, script: str, good: bool +): + """ + This isn't really a modal test, but instead a test using modal + 
to check that multi-gpu submission testing works (on modal...). + """ + launcher = ModalLauncher(add_include_dirs=[]) + reporter = MockProgressReporter("progress") + + # Load the real identity_py task + task_path = project_root / "examples" / "gather" + if not task_path.exists(): + pytest.skip("examples/gather not found - skipping Modal multi-gpu test") + + # Load the task definition + task_definition = make_task_definition(task_path) + + # Use the actual working submission from the examples + submission_content = (task_path / script).read_text() + + config = build_task_config( + task=task_definition.task, + submission_content=submission_content, + arch=GPU_TO_SM[ModalGPU.L4x4.name], + mode=SubmissionMode.BENCHMARK, + ) + + result = await launcher.run_submission(config, ModalGPU.L4x4, reporter) + + # Basic structure and success + assert result.success, f"Expected successful run, got: {result.error}" + assert result.error == "" + assert isinstance(result.runs, dict) + + # System info - test actual expected values + pprint.pprint(result) + assert result.system.device_count == 4 + + # Test run structure + assert "benchmark" in result.runs + bench_run = result.runs["benchmark"] + + # For Python runs, compilation is None + assert bench_run.compilation is None + + # Run needs to succeed + assert bench_run.run.success is True + assert bench_run.run.passed is good + + @pytest.mark.integration @pytest.mark.asyncio @pytest.mark.parametrize("script", ["cheat-fd.py", "cheat-input.py", "cheat-rng.py"]) diff --git a/tests/test_task.py b/tests/test_task.py index 6bbd73a8..809a6907 100644 --- a/tests/test_task.py +++ b/tests/test_task.py @@ -14,6 +14,7 @@ build_task_config, make_task_definition, ) +from libkernelbot.utils import KernelBotError @pytest.fixture() @@ -148,6 +149,7 @@ def test_build_task_config_python(leaderboard_task): {"input_size": 5000, "dtype": "float16"}, ], "mode": mode.value, + "multi_gpu": False, "test_timeout": 120, "benchmark_timeout": 180, 
"ranked_timeout": 180, @@ -201,6 +203,7 @@ def test_build_task_config_cuda(): {"input_size": 5000, "dtype": "float16"}, ], "mode": mode.value, + "multi_gpu": False, "test_timeout": 120, "benchmark_timeout": 180, "ranked_timeout": 180, @@ -234,3 +237,16 @@ def test_make_task_definition(task_directory): assert task.benchmarks == [{"input_size": 10000, "dtype": "float32"}] assert isinstance(task.config, PythonTaskData) assert task.config.main == "kernel.py" + + +def test_multi_gpu_task(task_directory): + """Test make_task_definition with a multi-GPU task""" + orig = (task_directory / "task.yml").read_text() + (task_directory / "task.yml").write_text(orig + "\nmulti_gpu: true") + + # no world size specified => Error + with pytest.raises(KernelBotError, match="does not specify world_size"): + make_task_definition(task_directory / "task.yml") + + result = make_task_definition(task_directory / "multi-task.yml") + assert result.task.multi_gpu is True