-
Notifications
You must be signed in to change notification settings - Fork 22
Multi-GPU #335
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Multi-GPU #335
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,4 +1,5 @@ | ||
| import base64 | ||
| import copy | ||
| import dataclasses | ||
| import multiprocessing | ||
| import re | ||
|
|
@@ -65,7 +66,7 @@ def get_test_cases(file_name: str, seed: Optional[int]) -> list[TestCase]: | |
|
|
||
| tests = [] | ||
| lines = content.splitlines() | ||
| match = r"\s*([a-zA-Z]+):\s*([a-zA-Z]+|[+-]?[0-9]+)\s*" | ||
| match = r"\s*([a-zA-Z_]+):\s*([a-zA-Z]+|[+-]?[0-9]+)\s*" | ||
| for line in lines: | ||
| parts = line.split(";") | ||
| case = {} | ||
|
|
@@ -123,18 +124,19 @@ def calculate_stats(durations: list[int]): | |
| worst=float(worst)) | ||
|
|
||
|
|
||
| def _clone_data(data): | ||
| def _clone_data(data, rank: int): | ||
| """ | ||
| Recursively goes through data and clones all tensors. | ||
| """ | ||
| if isinstance(data, tuple): | ||
| return tuple(_clone_data(x) for x in data) | ||
| return tuple(_clone_data(x, rank) for x in data) | ||
| elif isinstance(data, list): | ||
| return [_clone_data(x) for x in data] | ||
| return [_clone_data(x, rank) for x in data] | ||
| elif isinstance(data, dict): | ||
| return {k: _clone_data(v) for k, v in data.items()} | ||
| return {k: _clone_data(v, rank) for k, v in data.items()} | ||
| elif isinstance(data, torch.Tensor): | ||
| return data.clone() | ||
| device = f"cuda:{rank}" | ||
| return data.clone().to(device) | ||
| else: | ||
| return data | ||
|
|
||
|
|
@@ -157,16 +159,61 @@ def _run_single_test(test: TestCase): | |
| from submission import custom_kernel | ||
| data = generate_input(**test.args) | ||
| torch.cuda.synchronize() | ||
| submission_output = custom_kernel(_clone_data(data)) | ||
| submission_output = custom_kernel(_clone_data(data, 0)) | ||
| torch.cuda.synchronize() | ||
| return wrap_check_implementation(data, submission_output) | ||
|
|
||
|
|
||
| def _run_distributed_test(test: TestCase, rank: int): | ||
| """ | ||
| Runs a single test case. Do not call directly | ||
| """ | ||
| from submission import custom_kernel | ||
| import torch.distributed as dist | ||
| world_size = test.args["world_size"] | ||
| os.environ["MASTER_ADDR"] = "127.0.0.1" | ||
| os.environ["MASTER_PORT"] = "12356" | ||
| dist.init_process_group("nccl", init_method="env://", rank=rank, world_size=world_size, device_id=torch.device(f'cuda:{rank}')) | ||
| try: | ||
| data = generate_input(**test.args, rank=rank) | ||
| torch.cuda.synchronize() | ||
| submission_output = custom_kernel(_clone_data(data, rank)) | ||
| torch.cuda.synchronize() | ||
| return wrap_check_implementation(data, submission_output) | ||
| finally: | ||
| dist.destroy_process_group() | ||
|
|
||
|
|
||
| def run_multi_gpu_test(pool: multiprocessing.Pool, test: TestCase, world_size: int): | ||
| """ | ||
| Runs a single test in another process. | ||
| """ | ||
| rets = [] | ||
| # world_size is a mandatory argument for multi-gpu tests | ||
| for i in range(world_size): | ||
| rets.append( | ||
| pool.apply_async( | ||
| _run_distributed_test, | ||
| args=(test, i), | ||
| ) | ||
| ) | ||
| # 60 seconds should be more than enough, we want tests to be fast | ||
| rets = [el.get(60) for el in rets] | ||
|
|
||
| correct = all(ret[0] for ret in rets) | ||
| error_messages = str.join("\n", [f"rank {rank} - {ret[1]}" for rank, ret in enumerate(rets) if not ret[0]]) | ||
| return correct, error_messages | ||
|
|
||
|
|
||
| def run_single_test(pool: multiprocessing.Pool, test: TestCase): | ||
| """ | ||
| Runs a single test in another process. | ||
| """ | ||
| return pool.apply(_run_single_test, (test,)) | ||
| world_size = test.args.get("world_size", None) | ||
| if world_size is None: | ||
| return pool.apply(_run_single_test, (test,)) | ||
| else: | ||
| return run_multi_gpu_test(pool, test, world_size) | ||
|
|
||
|
|
||
| def run_testing(logger: PopcornOutput, pool: multiprocessing.Pool, tests: list[TestCase]): | ||
|
|
@@ -208,7 +255,7 @@ def _run_single_benchmark(test: TestCase, recheck: bool, max_repeats: int, max_t | |
| durations = [] | ||
| # generate input data once | ||
| data = generate_input(**test.args) | ||
| check_copy = _clone_data(data) | ||
| check_copy = _clone_data(data, 0) | ||
| # first, one obligatory correctness check | ||
| output = custom_kernel(data) | ||
| good, message = wrap_check_implementation(check_copy, output) | ||
|
|
@@ -228,7 +275,7 @@ def _run_single_benchmark(test: TestCase, recheck: bool, max_repeats: int, max_t | |
| test.args["seed"] += 13 | ||
|
|
||
| data = generate_input(**test.args) | ||
| check_copy = _clone_data(data) | ||
| check_copy = _clone_data(data, 0) | ||
| torch.cuda.synchronize() | ||
| start_event = torch.cuda.Event(enable_timing=True) | ||
| end_event = torch.cuda.Event(enable_timing=True) | ||
|
|
@@ -261,6 +308,144 @@ def _run_single_benchmark(test: TestCase, recheck: bool, max_repeats: int, max_t | |
| return calculate_stats(durations) | ||
|
|
||
|
|
||
| def _run_distributed_benchmark(test: TestCase, rank: int, recheck: bool, max_repeats: int, | ||
| max_time_ns: float) -> Stats | Any: | ||
| """ | ||
| Runs one distributed benchmark. Do not call directly. | ||
| """ | ||
| from submission import custom_kernel | ||
| import torch.distributed as dist | ||
|
|
||
| world_size = test.args["world_size"] | ||
| os.environ["MASTER_ADDR"] = "127.0.0.1" | ||
| os.environ["MASTER_PORT"] = "12356" | ||
| dist.init_process_group("nccl", init_method="env://", rank=rank, world_size=world_size, device_id=torch.device(f'cuda:{rank}')) | ||
|
|
||
| try: | ||
| durations = [] | ||
| # generate input data once | ||
| data = generate_input(**test.args, rank=rank) | ||
| check_copy = _clone_data(data, rank) | ||
|
|
||
| # first, one obligatory correctness check | ||
| output = custom_kernel(_clone_data(data, rank)) | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I'd like to confirm: should users' custom_kernel be designed as an API that accepts and outputs single-rank data? My ref_kernel accepts all-rank data and outputs all-rank results — is that in conflict, or does something need to be changed in my PR? https://github.com/gpu-mode/reference-kernels/pull/51/files#diff-4634bd7a4a47ab89859ee0db3f4f3f3c8123cf18981fb4d24b4655a412777013R240
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I think essentially, the user kernel should be what is called
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. yes,
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Big thanks — I saw siro's commit to my PR, which is compatible with this PR |
||
| good, message = wrap_check_implementation(check_copy, output) | ||
| if not good: | ||
| return message | ||
|
|
||
| # now, do multiple timing runs with proper distributed synchronization | ||
| bm_start_time = time.perf_counter_ns() | ||
| for i in range(max_repeats): | ||
| error_message = None | ||
| if recheck: | ||
| # ensure we use a different seed for every benchmark | ||
| if "seed" in test.args: | ||
| test.args["seed"] += 13 | ||
|
|
||
| data = generate_input(**test.args, rank=rank) | ||
| check_copy = _clone_data(data, rank) | ||
|
|
||
| # Synchronize all ranks before timing | ||
| clear_l2_cache() | ||
| torch.cuda.synchronize() | ||
| dist.barrier() | ||
|
|
||
| # Use distributed timing - only rank 0 records the overall time | ||
| if rank == 0: | ||
| start_time = time.perf_counter_ns() | ||
|
|
||
| # All ranks execute the kernel | ||
| output = custom_kernel(_clone_data(data, rank)) | ||
|
|
||
| # Synchronize all ranks after kernel execution | ||
| torch.cuda.synchronize() | ||
| dist.barrier() | ||
|
|
||
| if rank == 0: | ||
| end_time = time.perf_counter_ns() | ||
| duration = end_time - start_time # Already in nanoseconds | ||
| durations.append(duration) | ||
|
|
||
| if recheck: | ||
| good, message = check_implementation(check_copy, output) | ||
| if not good: | ||
| error_message = message | ||
|
|
||
| del output | ||
|
|
||
| has_error = torch.tensor(1 if error_message is not None else 0, dtype=torch.int32, device=f'cuda:{rank}') | ||
| dist.reduce(has_error, 0) | ||
| if has_error.item() > 0: | ||
| return error_message | ||
|
|
||
| # Only rank 0 checks convergence criteria | ||
| if rank == 0 and i > 1: | ||
| total_bm_duration = time.perf_counter_ns() - bm_start_time | ||
| stats = calculate_stats(durations) | ||
| # stop if either | ||
| # a) relative error dips below 0.1% | ||
| # b) we exceed the total time limit for benchmarking the kernel | ||
| # c) we exceed 2 minutes of total wallclock time. | ||
| should_stop = (stats.err / stats.mean < 0.001 or | ||
| stats.mean * stats.runs > max_time_ns or | ||
| total_bm_duration > 120e9) | ||
| else: | ||
| should_stop = False | ||
|
|
||
| # Broadcast stop decision to all ranks | ||
| stop_tensor = torch.tensor(should_stop, dtype=torch.bool, device=f'cuda:{rank}') | ||
| dist.broadcast(stop_tensor, 0) | ||
|
|
||
| if stop_tensor.item(): | ||
| break | ||
|
|
||
| # Only rank 0 returns meaningful stats | ||
| if rank == 0: | ||
| return calculate_stats(durations) | ||
| else: | ||
| # Non-zero ranks return a dummy stats object | ||
| return Stats(runs=len(durations), mean=0.0, std=0.0, err=0.0, best=0.0, worst=0.0) | ||
|
|
||
| finally: | ||
| dist.destroy_process_group() | ||
|
|
||
|
|
||
| def run_multi_gpu_benchmark(pool: multiprocessing.Pool, test: TestCase, recheck: bool, max_repeats: int, | ||
| max_time_ns: float, world_size: int): | ||
| """ | ||
| Runs a multi-GPU benchmark across all ranks. | ||
| """ | ||
| rets = [] | ||
| for i in range(world_size): | ||
| rets.append( | ||
| pool.apply_async( | ||
| _run_distributed_benchmark, | ||
| args=(test, i, recheck, max_repeats, max_time_ns), | ||
| ) | ||
| ) | ||
|
|
||
| # 120 seconds for benchmarking + we run a pre-benchmark test and want to leave some slack | ||
| rets = [el.get(timeout=180) for el in rets] | ||
|
|
||
| # For multi-GPU benchmarking, only rank 0 has meaningful stats | ||
| failed_ranks = [] | ||
| rank_0_result = None | ||
|
|
||
| for rank, ret in enumerate(rets): | ||
| if isinstance(ret, Stats): | ||
| if rank == 0: | ||
| rank_0_result = ret | ||
| else: | ||
| # ret is an error message | ||
| failed_ranks.append((rank, ret)) | ||
|
|
||
| if failed_ranks: | ||
| error_messages = str.join("\n", [f"rank {rank} - {msg}" for rank, msg in failed_ranks]) | ||
| return error_messages | ||
| else: | ||
| return rank_0_result if rank_0_result else "No stats returned from rank 0" | ||
|
|
||
|
|
||
| def run_single_benchmark(pool: multiprocessing.Pool, test: TestCase, recheck: bool, max_repeats: int, | ||
| max_time_ns: float): | ||
| """ | ||
|
|
@@ -273,7 +458,12 @@ def run_single_benchmark(pool: multiprocessing.Pool, test: TestCase, recheck: bo | |
| @param max_time_ns: Timeout time in nanoseconds. | ||
| @return: A Stats object for this particular benchmark case or an error if the test fails. | ||
| """ | ||
| return pool.apply(_run_single_benchmark, (test, recheck, max_repeats, max_time_ns)) | ||
|
|
||
| world_size: Optional[int] = test.args.get("world_size", None) | ||
| if world_size is None: | ||
| return pool.apply(_run_single_benchmark, (test, recheck, max_repeats, max_time_ns)) | ||
| else: | ||
| return run_multi_gpu_benchmark(pool, test, recheck, max_repeats, max_time_ns, world_size) | ||
|
|
||
|
|
||
| def run_benchmarking(logger: PopcornOutput, pool: multiprocessing.Pool, tests: list[TestCase]): | ||
|
|
@@ -319,7 +509,7 @@ def run_single_profile(test: TestCase) -> str: | |
| torch.cuda.synchronize() | ||
|
|
||
| with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof: | ||
| submission_output = custom_kernel(_clone_data(data)) | ||
| submission_output = custom_kernel(_clone_data(data, 0)) | ||
| torch.cuda.synchronize() | ||
| return prof.key_averages().table(sort_by="self_cuda_time_total", row_limit=20) | ||
|
|
||
|
|
@@ -345,14 +535,15 @@ def main(): | |
| mode = sys.argv[1] | ||
| seed = os.getenv("POPCORN_SEED") | ||
| os.unsetenv("POPCORN_SEED") | ||
| n_gpus = int(os.getenv("POPCORN_GPUS", "1")) | ||
| seed = int(seed) if seed else None | ||
| set_seed(seed or 42) | ||
| tests = get_test_cases(sys.argv[2], seed) | ||
|
|
||
| with PopcornOutput(int(fd)) as logger: | ||
| import multiprocessing | ||
| mp_context = multiprocessing.get_context('spawn') | ||
| with mp_context.Pool(1) as pool: | ||
| with mp_context.Pool(n_gpus) as pool: | ||
| if mode == "test": | ||
| return run_testing(logger, pool, tests) | ||
| if mode == "benchmark": | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,19 @@ | ||
| import torch | ||
| from task import input_t, output_t | ||
| from utils import verbose_allclose | ||
| from typing import Tuple | ||
|
|
||
|
|
||
| def generate_input(seed: int, world_size: int, rank: int) -> input_t: | ||
| local_data = torch.tensor([rank]).to(f"cuda:{rank}") | ||
| return local_data, rank, world_size | ||
|
|
||
|
|
||
| def check_implementation(data: input_t, output: output_t) -> Tuple[bool, str]: | ||
| data, rank, world_size = data | ||
| for i in range(world_size): | ||
| if output[i].get_device() != rank: | ||
| return False, f"mismatch found! output {i} of rank {rank} is on device {output[i].device}" | ||
| if (item := output[i].cpu().detach().item()) != i: | ||
| return False, f"mismatch found! custom implementation doesn't match reference: rank {rank}, entry {i} has value {item}" | ||
| return True, '' |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,12 @@ | ||
| #!POPCORN leaderboard identity_py-dev | ||
|
|
||
| from task import input_t, output_t | ||
| import torch | ||
| from torch import distributed as dist | ||
|
|
||
|
|
||
| def custom_kernel(data: input_t) -> output_t: | ||
| data, rank, world_size = data | ||
| result = [torch.empty_like(data) for _ in range(dist.get_world_size())] | ||
| dist.all_gather(result, data) | ||
| return result |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,10 @@ | ||
| from typing import TypedDict, List, Tuple | ||
| import torch | ||
|
|
||
|
|
||
| input_t = Tuple[torch.Tensor, int, int] | ||
| output_t = List[torch.Tensor] | ||
|
|
||
|
|
||
| class TestSpec(TypedDict): | ||
| pass |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,28 @@ | ||
| # name: identity-py | ||
|
|
||
| files: | ||
| - {"name": "submission.py", "source": "@SUBMISSION@"} | ||
| - {"name": "task.py", "source": "task.py"} | ||
| - {"name": "utils.py", "source": "../utils.py"} | ||
| - {"name": "reference.py", "source": "reference.py"} | ||
| - {"name": "eval.py", "source": "../eval.py"} | ||
|
|
||
| lang: "py" | ||
| multi_gpu: true | ||
| description: | ||
| A simple test task - python | ||
|
|
||
| config: | ||
| main: "eval.py" | ||
|
|
||
| templates: | ||
| Python: "../template.py" | ||
|
|
||
| # small test cases. should be cheap to run. | ||
| tests: | ||
| - {"seed": 5, "world_size": 4} | ||
|
|
||
| benchmarks: | ||
| - {"seed": 10, "world_size": 4} | ||
|
|
||
| ranking_by: "geom" |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,11 @@ | ||
| #!POPCORN leaderboard identity_py-dev | ||
|
|
||
| from task import input_t, output_t | ||
| import torch | ||
| from torch import distributed as dist | ||
|
|
||
|
|
||
| def custom_kernel(data: input_t) -> output_t: | ||
| data, rank, world_size = data | ||
| result = [torch.ones_like(data) for _ in range(dist.get_world_size())] | ||
| return result |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
might want to set up a large random port, in case a job fails for some reason and the port doesn't get released