forked from tile-ai/tilelang
-
Notifications
You must be signed in to change notification settings - Fork 4
[Feature] Support unified T.copy lowering to both SIMT and TMA for intra-node copy
#36
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
Rachmanino
wants to merge
14
commits into
main
Choose a base branch
from
wt/dev
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from all commits
Commits
Show all changes
14 commits
Select commit
Hold shift + click to select a range
5fb7df6
add draft test
Rachmanino b822ea9
draft
Rachmanino 784dd18
support SIMT push and fix a bug
Rachmanino a434339
refactor and support pull
Rachmanino ee46984
lint
Rachmanino 3664eb2
Update src/tl_templates/cuda/common.h
Rachmanino e79f865
lint
chengyupku 7cb7ec6
fix bot's comments
Rachmanino 08b84c8
lint
Rachmanino f2906f7
bugfix of parse_op
Rachmanino 10c6394
fix
Rachmanino fba4134
apply bot's suggestions
Rachmanino a510e74
add test and fix compatibility
Rachmanino a536bdd
fix previous bug
Rachmanino File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
192 changes: 192 additions & 0 deletions
192
examples/distributed/primitives/example_tilescale_copy.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,192 @@ | ||
| import os | ||
| import tilelang | ||
| import tilelang.language as T | ||
| import argparse | ||
| import torch | ||
| import torch.distributed as dist | ||
| import torch.multiprocessing | ||
| from tilelang.distributed import init_dist | ||
|
|
||
| tilelang.disable_cache() | ||
| os.environ['NCCL_DEBUG'] = 'WARN' # silence NCCL log | ||
|
|
||
|
|
||
| @tilelang.jit | ||
| def get_kernel(M, N, block_M, block_N, threads, kernel='simt_push_tile'): | ||
|
|
||
| @T.prim_func | ||
| def simt_push_buffer( | ||
| dst: T.Tensor((M, N), "float32"), | ||
| src: T.Tensor((M, N), "float32"), | ||
| ): | ||
| with T.Kernel((1), threads=threads): | ||
| rank = T.alloc_local([1], "uint64") | ||
| rank[0] = T.get_rank() | ||
|
|
||
| T.copy( | ||
| src, | ||
| dst, | ||
| dst_pe=1 - rank[0], | ||
| disable_tma=True # Ensure testing SIMT remote copy | ||
| ) | ||
|
|
||
| @T.prim_func | ||
| def simt_push_tile( | ||
| dst: T.Tensor((M, N), "float32"), | ||
| src: T.Tensor((M, N), "float32"), | ||
| ): | ||
| with T.Kernel(M // block_M, N // block_N, threads=threads) as (bx, by): | ||
| rank = T.alloc_local([1], "uint64") | ||
| rank[0] = T.get_rank() | ||
|
|
||
| smem = T.alloc_shared((block_M, block_N), "float32") | ||
| T.annotate_layout({smem: tilelang.layout.make_swizzled_layout(smem)}) | ||
|
|
||
| T.copy( | ||
| src[bx * block_M:(bx + 1) * block_M, by * block_N:(by + 1) * block_N], | ||
| smem, | ||
| disable_tma=True # Ensure testing SIMT remote copy | ||
| ) | ||
|
|
||
| T.copy( | ||
| smem, | ||
| dst[bx * block_M:(bx + 1) * block_M, by * block_N:(by + 1) * block_N], | ||
| dst_pe=1 - rank[0], | ||
| disable_tma=True # Ensure testing SIMT remote copy | ||
| ) | ||
|
|
||
| @T.prim_func | ||
| def simt_pull_tile( | ||
| dst: T.Tensor((M, N), "float32"), | ||
| src: T.Tensor((M, N), "float32"), | ||
| ): | ||
| with T.Kernel(M // block_M, N // block_N, threads=threads) as (bx, by): | ||
| rank = T.alloc_local([1], "uint64") | ||
| rank[0] = T.get_rank() | ||
|
|
||
| smem = T.alloc_shared((block_M, block_N), "float32") | ||
| T.annotate_layout({smem: tilelang.layout.make_swizzled_layout(smem)}) | ||
|
|
||
| T.copy( | ||
| src[bx * block_M:(bx + 1) * block_M, by * block_N:(by + 1) * block_N], | ||
| smem, | ||
| src_pe=1 - rank[0], | ||
| disable_tma=True # Ensure testing SIMT remote copy | ||
| ) | ||
|
|
||
| T.copy( | ||
| smem, | ||
| dst[bx * block_M:(bx + 1) * block_M, by * block_N:(by + 1) * block_N], | ||
| disable_tma=True # Ensure testing SIMT remote copy | ||
| ) | ||
|
|
||
| # TMA kernel requires run-time aware peer rank | ||
| @T.prim_func | ||
| def tma_load_tile( | ||
| dst: T.Tensor((M, N), "float32"), | ||
| src: T.Tensor((M, N), "float32"), | ||
| ): | ||
| with T.Kernel(M // block_M, N // block_N, threads=threads) as (bx, by): | ||
|
|
||
| smem = T.alloc_shared((block_M, block_N), "float32") | ||
| T.annotate_layout({smem: tilelang.layout.make_swizzled_layout(smem)}) | ||
|
|
||
| # TMA load | ||
| T.copy( | ||
| src[bx * block_M:(bx + 1) * block_M, by * block_N:(by + 1) * block_N], | ||
| smem, | ||
| src_pe=1 - T.get_rank(), | ||
| # NOTE(wt): We cannot use rank[0] as above for TMA remote copy currently. | ||
| ) | ||
|
|
||
| T.copy( | ||
| smem, | ||
| dst[bx * block_M:(bx + 1) * block_M, by * block_N:(by + 1) * block_N], | ||
| disable_tma=True # Ensure testing SIMT remote copy | ||
| ) | ||
|
|
||
| @T.prim_func | ||
| def tma_store_tile( | ||
| dst: T.Tensor((M, N), "float32"), | ||
| src: T.Tensor((M, N), "float32"), | ||
| ): | ||
| with T.Kernel(M // block_M, N // block_N, threads=threads) as (bx, by): | ||
|
|
||
| smem = T.alloc_shared((block_M, block_N), "float32") | ||
| T.annotate_layout({smem: tilelang.layout.make_swizzled_layout(smem)}) | ||
|
|
||
| T.copy( | ||
| src[bx * block_M:(bx + 1) * block_M, by * block_N:(by + 1) * block_N], | ||
| smem, | ||
| disable_tma=True # Ensure testing SIMT remote copy | ||
| ) | ||
|
|
||
| # TMA store | ||
| T.copy( | ||
| smem, | ||
| dst[bx * block_M:(bx + 1) * block_M, by * block_N:(by + 1) * block_N], | ||
| dst_pe=1 - T.get_rank()) | ||
|
|
||
| return { | ||
| 'simt_push_buffer': simt_push_buffer, | ||
| 'simt_push_tile': simt_push_tile, | ||
| 'simt_pull_tile': simt_pull_tile, | ||
| 'tma_load_tile': tma_load_tile, | ||
| 'tma_store_tile': tma_store_tile | ||
| }[kernel] | ||
|
|
||
|
|
||
| def main(local_rank: int, num_local_ranks: int, args: argparse.Namespace): | ||
| M = args.M | ||
| N = args.N | ||
| BLOCK_M = 64 | ||
| BLOCK_N = 128 | ||
| threads = 128 | ||
| assert num_local_ranks == 2, "this example only supports 2 ranks copying to each other" | ||
|
|
||
| _, _, group = init_dist(local_rank, num_local_ranks) | ||
| allocator = tilelang.get_allocator( | ||
| size=2**25, | ||
| device="cuda", | ||
| is_distributed=True, | ||
| local_rank=local_rank, | ||
| num_local_ranks=num_local_ranks, | ||
| group=group) | ||
|
|
||
| kernel = get_kernel(M, N, BLOCK_M, BLOCK_N, threads, kernel=args.kernel) | ||
| kernel.initialize(allocator=allocator) | ||
| if local_rank == 0: | ||
| print(kernel.get_kernel_source()) | ||
|
|
||
| src = tilelang.tensor((M, N), torch.float32, allocator=allocator).normal_() | ||
| dst = tilelang.tensor((M, N), torch.float32, allocator=allocator) | ||
|
|
||
| torch.cuda.synchronize() | ||
| torch.distributed.barrier(group) | ||
| kernel(dst, src) | ||
| torch.cuda.synchronize() | ||
| torch.distributed.barrier(group) | ||
|
|
||
| dst_torchs = [torch.empty_like(src) for _ in range(num_local_ranks)] | ||
| dist.all_gather(dst_torchs, src, group) | ||
| dst_torch = dst_torchs[local_rank ^ 1] | ||
|
|
||
| if torch.allclose(dst_torch, dst, atol=1e-6, rtol=1e-6): | ||
| print(f"rank {local_rank} check passed.✅") | ||
| else: | ||
| print(f"rank {local_rank} check failed.❌") | ||
| print(f"dst_torch: {dst_torch}, dst: {dst}") | ||
| raise ValueError("Test failed") | ||
|
|
||
| dist.destroy_process_group() | ||
|
|
||
|
|
||
| if __name__ == "__main__": | ||
| parser = argparse.ArgumentParser() | ||
| parser.add_argument('--M', type=int, default=1024, help='M dimension') | ||
| parser.add_argument('--N', type=int, default=1024, help='N dimension') | ||
| parser.add_argument('--kernel', type=str, default='simt_push_tile', help='kernel to use') | ||
| args = parser.parse_args() | ||
| num_processes = 2 | ||
|
|
||
| torch.multiprocessing.spawn(main, args=(num_processes, args), nprocs=num_processes) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,30 @@ | ||
| import argparse | ||
| import tilelang.testing | ||
| import torch | ||
| import torch.multiprocessing | ||
|
|
||
| import example_tilescale_copy | ||
|
|
||
|
|
||
| @tilelang.testing.requires_cuda | ||
| def test_example_tilescale_copy_simt_push_tile(): | ||
| args = argparse.Namespace(M=1024, N=1024, kernel='simt_push_tile') | ||
| torch.multiprocessing.spawn(example_tilescale_copy.main, args=(2, args), nprocs=2) | ||
|
|
||
|
|
||
| @tilelang.testing.requires_cuda | ||
| @tilelang.testing.requires_cuda_compute_version_ge(9, 0) | ||
| def test_example_tilescale_copy_tma_load_tile(): | ||
| args = argparse.Namespace(M=1024, N=1024, kernel='tma_load_tile') | ||
| torch.multiprocessing.spawn(example_tilescale_copy.main, args=(2, args), nprocs=2) | ||
|
|
||
|
|
||
| @tilelang.testing.requires_cuda | ||
| @tilelang.testing.requires_cuda_compute_version_ge(9, 0) | ||
| def test_example_tilescale_copy_tma_store_tile(): | ||
| args = argparse.Namespace(M=1024, N=1024, kernel='tma_store_tile') | ||
| torch.multiprocessing.spawn(example_tilescale_copy.main, args=(2, args), nprocs=2) | ||
|
|
||
|
|
||
| if __name__ == "__main__": | ||
| tilelang.testing.main() |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Guard remote PE defaults before remote checks.
is_remote_push()/is_remote_pull()run immediately after the optional assignments. Legacy modules compiled before this PR still emit the 5-argumenttl.copy; when we deserialize them here,args.size()is only 5, sodst_pe/src_pestayPrimExpr()and the subsequentdst_pe->IsInstancedereferences a null handle. That’s a hard crash/regression for any pre-existing artifact. Please seed both fields with-1before theifblocks and make the helpers resilient to an undefinedPrimExpr.Apply this diff to fix the issue:
Copy::Copy(Array<PrimExpr> args, BufferMap vmap) { ObjectPtr<CopyNode> node = make_object<CopyNode>(); + node->src_pe = Integer(-1); + node->dst_pe = Integer(-1); Array<Range> rgs[2]; @@ - if (args.size() >= 6) { - node->src_pe = args[5]; - } - if (args.size() >= 7) { - node->dst_pe = args[6]; - } + if (args.size() >= 6) { + node->src_pe = args[5]; + } + if (args.size() >= 7) { + node->dst_pe = args[6]; + } @@ bool CopyNode::is_remote_push() const { - return !(dst_pe->IsInstance<IntImmNode>() && - dst_pe.as<IntImmNode>()->value == -1); + if (!dst_pe.defined()) { + return false; + } + if (const auto *imm = dst_pe.as<IntImmNode>()) { + return imm->value != -1; + } + return true; } @@ bool CopyNode::is_remote_pull() const { - return !(src_pe->IsInstance<IntImmNode>() && - src_pe.as<IntImmNode>()->value == -1); + if (!src_pe.defined()) { + return false; + } + if (const auto *imm = src_pe.as<IntImmNode>()) { + return imm->value != -1; + } + return true; }