Merge upstream #77

Open · wants to merge 665 commits into base: amd-develop
Changes from 1 commit
Commits
665 commits
57d182a
Fix memory pool using in GEMM profiler (#556)
aakhundov Apr 10, 2023
fda4c60
dynamic seq (#560)
Apr 11, 2023
ca2d572
Fix MSVC compiler complaints (#551)
Apr 11, 2023
216cd17
Fix MSVC compiler narrowing conversion errors for cuda/gemm_epilogue_…
Apr 11, 2023
d67c97f
Add dynamic_seq_len and dynamic_num_head support in b2b bmm. (#530)
ipiszy Apr 11, 2023
c360bf8
Fix reduce ops with last input dim IntVar (#563)
aakhundov Apr 12, 2023
d5e6538
Build Cache CI Integration (#541)
kadeng Apr 12, 2023
1517314
More robust cutlass include dir generation in FBCUDA (#565)
kadeng Apr 12, 2023
c9f0f4d
add conv1d op (#562)
tenpercent Apr 13, 2023
b4b0140
arange as model param (#566)
mortzur Apr 13, 2023
f4416e7
fixed an infinite loop in move_view_ops transformation (#570)
chenyang78 Apr 13, 2023
0f1165e
Remove hack in [fx/aten]2ait (#516)
muchulee8 Apr 13, 2023
e6ad08a
Fix _fuse_strided_op_and_cat: no GEMM+concat fusion with dim>=rank (#…
Apr 13, 2023
a58d757
refactor
fsx950223 Apr 14, 2023
e17d032
merge updates
fsx950223 Apr 14, 2023
2d8511c
Accept JaggedIntVar in padded_dense_to_jagged (#577)
aakhundov Apr 14, 2023
27ec1e1
implement per op benchmark backbone (#555)
tenpercent Apr 14, 2023
c7d36cb
Add SM90 kernel foundations (#575)
aakhundov Apr 14, 2023
0e71c60
add reshape for `to` converter (#576)
frank-wei Apr 14, 2023
a4ba50a
Allow split without dim as input (#578)
ZhengkaiZ Apr 14, 2023
7fa1852
Replace infer_shape for split op. (#568)
muchulee8 Apr 15, 2023
c27702d
Add Identity op. (#579)
muchulee8 Apr 15, 2023
b0f9116
Build cache: Do not cache failed builds (#586)
kadeng Apr 17, 2023
3b537f9
Add BatchNorm FE module (#587)
Apr 17, 2023
4d6d469
add identity for `to` and `contiguous` (#588)
frank-wei Apr 17, 2023
535d064
In-place weight updates tracing (#569)
khabinov Apr 18, 2023
5c4800e
Remove hardcoded dtype float16 in acc_ops_clone (#585)
feikou Apr 18, 2023
55e150f
Fix acc_ops converter on std when keepdim=False (#593)
tissue3 Apr 19, 2023
7b97317
Fix race condition in FBCUDA (#591)
kadeng Apr 20, 2023
be47ae5
Upgrade CUTLASS to 3.1 (#584)
aakhundov Apr 20, 2023
7142046
benchmark: concat (#592)
chenyang78 Apr 20, 2023
aa84406
Ldm update (#573)
terrychenism Apr 21, 2023
d227f60
Eliminated redundant permute pairs (#529)
PivovarA Apr 21, 2023
95053be
Back out "Ldm update" (#603)
muchulee8 Apr 22, 2023
7f30588
Add name for directory created by test_vanilla_attention (#602)
muchulee8 Apr 22, 2023
d17d5f2
Fix MSVC compiler narrowing conversion errors for cuda/gemm_epilogue_…
Apr 23, 2023
e139f40
Mask select converter (#601)
wushirong Apr 23, 2023
26afef7
Skip incompatible tests outside CI (#606)
aakhundov Apr 24, 2023
3b19c9b
Add env_variables context manager to test_utils (#607)
aakhundov Apr 24, 2023
ae495ab
Add CUDA_SM90 test env to test_utils (#608)
aakhundov Apr 24, 2023
c0a97af
Ldm update (#611)
terrychenism Apr 24, 2023
5a73b00
Add testname for test_var:test_batched_var (#612)
muchulee8 Apr 24, 2023
d5c84d2
Prepare for CUTLASS 3.x integration (#614)
aakhundov Apr 24, 2023
e1ec740
Set skip_on_empty=True from filter_test_cases_by_params (#609)
aakhundov Apr 25, 2023
4a15037
Add MaxPool3d FE module (#595)
Apr 26, 2023
24a5603
Add GELU FE module (#600)
Apr 26, 2023
d6af8b0
Fill MultiScaleBlock gap (#622)
Apr 26, 2023
d761c94
Update infer_shape for argmax (#623)
muchulee8 Apr 26, 2023
1e60f7b
move use_cuda from global to class CrossAttention level (#626)
hl475 Apr 26, 2023
ad45691
Introduce AIT_USE_FAST_MATH environment flag (#627)
Apr 26, 2023
c0dd23e
Apply `transform_permute_to_reshape` when input tensor has smaller th…
Apr 27, 2023
73624b8
use current stream in mem_eff_attention kernel (#628)
mortzur Apr 27, 2023
62426cc
Minor fix to makefile normalization (#630)
kadeng Apr 27, 2023
7e930fd
Add SM90 CUTLASS 3.x kernels to gemm_rcr (#617)
aakhundov Apr 27, 2023
747852a
Add SM90 CUTLASS 3.x kernels to gemm_rrr (#621)
aakhundov Apr 27, 2023
1d9d2ed
Add SM90 CUTLASS 3.x kernels to gemm_rcr_bias (#620)
aakhundov Apr 27, 2023
9487143
Small fixes to b2b bmm kernels (#564)
ipiszy Apr 27, 2023
1d5a942
Fix AIT topk converter (#631)
wushirong Apr 28, 2023
7ec28ff
Support fp32 accumulation (#632)
kadeng Apr 28, 2023
e6f12eb
bias support (#633)
frank-wei Apr 28, 2023
5661170
Add PatchEmbed FE (#629)
Apr 28, 2023
6ad3351
Refactor test_bmm(_add) with filter_test_cases_by_test_env (#635)
aakhundov Apr 28, 2023
bace6d8
Update infer_shape for batch_gather (#624)
muchulee8 Apr 29, 2023
d9ced25
Add SM90 CUTLASS 3.x kernels to bmm_xxx and bmm_xxx_add (#637)
aakhundov Apr 30, 2023
32dd05c
dynamic_slice infer_shape (#598)
muchulee8 Apr 30, 2023
aa7010f
Replace hasattr with getattr in aitemplate/AITemplate/python/aitempla…
r-barnes May 1, 2023
be49644
Fix circleci test errors (#639)
ipiszy May 2, 2023
b041e0e
Fix internal CI issue (#640)
ipiszy May 2, 2023
a4b533c
optionally print python callstack stats for model compilation (#642)
May 3, 2023
c435114
Add VisionTransformerBasicHead (#636)
May 3, 2023
60df7ce
Duplicate LDM from Example to FB Folder (#646)
henryhu6 May 3, 2023
693f4b9
Remove manually set Parameter names (#638)
May 3, 2023
bc16a15
classic_b2b_bmm: add multi-head and strides (#625)
kadeng May 4, 2023
2707974
Fix naming error caused by dump_program(). (#649)
ipiszy May 4, 2023
091ccdb
SD example fix (#654)
terrychenism May 4, 2023
109b98a
remove eliminate_permutations for now (#655)
chenyang78 May 4, 2023
ea7ed4c
jagged SHA and MHA module support (#651)
frank-wei May 4, 2023
332eccd
Add BF16 support for ads model (#656)
wushirong May 4, 2023
83b39e7
_dlclose support for Windows (#657)
May 4, 2023
971c2c1
Add MultiscaleVisionTransformers FE + validate MVIT 21 block config E…
May 5, 2023
27d9cb6
Make GEMM ProfilerMemoryPool size computation generic (#653)
aakhundov May 5, 2023
2f8c15f
Add Windows support for owned constants (#658)
May 6, 2023
09ce751
simple multi-stream (#615)
May 6, 2023
b912ca9
Use device 0 for profiling if devices list is empty (#667)
aakhundov May 6, 2023
6bc3253
Allow setting jagged tensor total_length upper bound from offsets (#634)
aakhundov May 7, 2023
6c3a88d
Increase CircleCI no_output_timeout to 20m (#668)
aakhundov May 7, 2023
76b6b71
Integrate CUDA 12.1 (#661)
aakhundov May 7, 2023
aeaddeb
Sync CUTLASS with upstream (#662)
aakhundov May 7, 2023
d7c2877
Add SM90 CUTLASS 3.x kernels to gemm_rcr_bias_relu (#663)
aakhundov May 7, 2023
cf3639b
Fix test logging errors (#664)
May 9, 2023
d7af2cf
fix bugs
fsx950223 May 9, 2023
2f782fb
revert some changes
fsx950223 May 9, 2023
45b579e
format fx2ait code
fsx950223 May 9, 2023
c312272
Add SM90 CUTLASS 3.x kernels to remaining gemm_rcr_bias_activation op…
aakhundov May 9, 2023
3fd5d10
Add bf16 support to classic_b2b_bmm op (#665)
aakhundov May 9, 2023
ce167f2
merge_updates
fsx950223 May 9, 2023
4593c4e
Add bf16 support to fmha_style_b2b_bmm op (#666)
aakhundov May 9, 2023
d915218
Windows test_standalone.py fix (#676)
May 9, 2023
7a57a02
Refactor infrastructure (#671)
May 9, 2023
cc5cbf2
introduce get_include_directories() in target_def.py files (#672)
May 9, 2023
9a1b4f9
Add is_view_of for identity op (#677)
muchulee8 May 9, 2023
47cdc54
MSVC fix (#675)
May 10, 2023
cd5862e
Fix get_positive_dim for IntVar (#680)
tissue3 May 10, 2023
5f63da8
Add bf16 support to full op (#679)
tissue3 May 10, 2023
d897342
introduce get_host_compiler_options() and get_device_compiler_options…
May 11, 2023
0481d1b
speed up MultiScaleBlock (#681)
May 11, 2023
d1d66fb
Introduce a CMake compiler engine (#674)
May 11, 2023
f49b6e4
update ci (#589)
fsx950223 May 11, 2023
87124bd
Fix codegen condition check issue
wushirong May 11, 2023
1e79696
merge updates
fsx950223 May 12, 2023
7615743
Fix the profiler bug in bmm_xxx_add SM90 kernels (#682)
aakhundov May 12, 2023
7cdc6f5
refactor op level benchmark (#660)
frank-wei May 12, 2023
1ffbc1f
Add attribute to allow tensor not to participate in constant folding …
muchulee8 May 12, 2023
e429979
Fix a tricky fused_elementwise alignment issue. (#693)
ipiszy May 14, 2023
d468fbc
add ops
fsx950223 May 15, 2023
5c92661
Fix-forward - multiscaleblock test failure (#694)
May 15, 2023
4332686
Fix codegen duplicate dim decl issue (#692)
qxy11 May 15, 2023
5847ce1
Add fast_math to LowerSettings and AIT Target (#695)
ipiszy May 15, 2023
e6a8dfc
fix bugs
fsx950223 May 16, 2023
5968179
Fix use_fast_math (#697)
ipiszy May 16, 2023
bc49d20
Add layernorm CUDA kernel based on Welford's algorithm (#698)
aakhundov May 16, 2023
f64cb6e
fix profiler group bug
fsx950223 May 17, 2023
3c28258
fix bugs
fsx950223 May 17, 2023
334877d
CUDA debug log utility class (#688)
kadeng May 17, 2023
d56544b
refactor sm detection; add quadro card name (#690)
tenpercent May 18, 2023
f738b9b
add nix-shell config (#691)
tenpercent May 18, 2023
b9d77bd
Stable Diffusion dynamic input shape, include/exclude constants, load…
hlky May 18, 2023
9dc346d
Dump function properties at profiling time. (#684)
ipiszy May 18, 2023
a0ddbe0
add graph mode
fsx950223 May 18, 2023
d7f10b7
set target_has_graph_mode to true
fsx950223 May 18, 2023
7de589f
Add missing __init__.py files (#702)
ymwangg May 18, 2023
5da4ae2
add missing Meta headers for __init__.py files (#705)
May 19, 2023
473e3f3
Add a flag for elementwise computation in float32 (#700)
aakhundov May 19, 2023
02b04e4
Add SM90 CUTLASS 3.x kernels to gemm_rcr_bias_broadcast ops (#687)
aakhundov May 19, 2023
d7e7996
Add SM90 CUTLASS 3.x kernels to perm102_bmm ops (#689)
aakhundov May 19, 2023
0b8561a
Add support for keyword argument in AITModule (#707)
henryhu6 May 20, 2023
9c26601
adjust launch config for group_layernorm kernels (#708)
chenyang78 May 21, 2023
46a81f2
remove identity ops from the graph (#703)
chenyang78 May 21, 2023
1fb2b33
fx2ait: explicitly ensure workdir's existence (#714)
kflu May 22, 2023
39455d8
enable bmm_rrr for concat + concat fusion (#716)
chenyang78 May 22, 2023
fdb0a98
support dynamic batch size
fsx950223 May 23, 2023
f86fda0
fix a bug
fsx950223 May 23, 2023
b300446
Sync CUTLASS version with upstream (#706)
aakhundov May 23, 2023
28dda09
Merge branch 'merge_upstream_update' into merge_upstream
fsx950223 May 23, 2023
fd804db
Merge remote-tracking branch 'upstream/main' into mlperf
fsx950223 May 23, 2023
6f4e747
check if the tensor is none before copying (#704)
chengscott May 23, 2023
326aae6
Fix 02_vision_model example (#718)
ymwangg May 23, 2023
329f152
fix bert example (#710)
tenpercent May 23, 2023
4c87930
Added a pass to fuse expand + bmm (#715)
chenyang78 May 24, 2023
560041b
rename layers
fsx950223 May 24, 2023
420c83a
Fix flaky test_batch_norm (#720)
aakhundov May 24, 2023
be81c39
Stable Diffusion ControlNet (#713)
hlky May 24, 2023
8f622a1
Fix fx2ait max/avg_pool with stride=None (#701)
ymwangg May 24, 2023
7f3811e
nvcc options (#721)
chengscott May 25, 2023
201937e
2/n support layernorm and elementwise op (#711)
frank-wei May 25, 2023
08bc584
fix bugs
fsx950223 May 25, 2023
a57e3bb
Merge remote-tracking branch 'upstream/main' into mlperf
fsx950223 May 25, 2023
02460f6
Back out "fix bert example" (#728)
aakhundov May 25, 2023
89711d9
Stable Diffusion Alt fixes #722 #723 (#724)
hlky May 25, 2023
126bb1d
Sync CUTLASS with upstream again (#727)
aakhundov May 27, 2023
50d8266
add missing script
fsx950223 May 29, 2023
c246a86
Split test_strided_layernorm test cases (#729)
aakhundov May 29, 2023
42790e3
fix stable diffusion example
fsx950223 May 31, 2023
bd41366
Merge remote-tracking branch 'upstream/main' into mlperf
fsx950223 May 31, 2023
cc1dceb
Merge branch 'mlperf' into merge_upstream
fsx950223 May 31, 2023
a68c83f
format code
fsx950223 May 31, 2023
dacd043
update enabled ci types
fsx950223 May 31, 2023
395dc79
Grouped Classic B2B BMM 1 ( copy base impl ) (#736)
kadeng May 31, 2023
f0e6d8c
Add SM90-related profiler extensions (#732)
aakhundov May 31, 2023
3714b40
Disable residual in SM90 kernels of gemm_rcr / rrr (#733)
aakhundov May 31, 2023
707d818
Pass bias vector via epilogue schedule in SM90 gemm_rcr_bias (#734)
aakhundov May 31, 2023
6204f62
Pass bias vector via epilogue schedule in SM90 gemm_rcr_bias_activati…
aakhundov May 31, 2023
e5b2dac
Merge remote-tracking branch 'upstream/main' into merge_upstream
fsx950223 Jun 1, 2023
a54b4c9
fix gemm hardswish
fsx950223 Jun 1, 2023
2146627
fix hstu unit test in fx2ait (#743)
frank-wei Jun 1, 2023
daf1624
update conv2d (#725)
fsx950223 Jun 1, 2023
d49a2cd
Add BFloat16 for debugging and messaging (#744)
muchulee8 Jun 1, 2023
98e6a13
Grouped Classic B2B BMM op 2 ( padded -> jagged operator implementati…
kadeng Jun 2, 2023
22a74c7
Add support of nn.functional.hardtanh (#739)
mcremon-meta Jun 2, 2023
c320d26
Eliminate elementwise no-ops (*/1, +-0) (#746)
AlbertDachiChen Jun 6, 2023
f05c8e9
Update README.md (#745)
eltociear Jun 7, 2023
329fe8a
Add PG509 as a detected CUDA GPU (#749)
erjiang Jun 7, 2023
2f2912f
Move XRayVideo related FE modules to frontend/nn + disambiguate test …
Jun 7, 2023
8d7d819
add missing copy
carlushuang Jun 9, 2023
e63e7e2
separate cuda and rocm graph
fsx950223 Jun 2, 2023
1ee2625
update setup.py
fsx950223 Jun 9, 2023
2bdc21b
fix compile bugs
fsx950223 Jun 9, 2023
e23d04f
Fix SD compilation example to use user provided H and W (#755)
apivovarov Jun 10, 2023
fff93a1
Add --model-name param to SD download_pipeline script (#757)
apivovarov Jun 11, 2023
db2a9e9
Include split+cat in fuse_split optimization (#740)
erjiang Jun 11, 2023
1c32db2
frontend.nn.attention dtype (#759)
hlky Jun 12, 2023
5d5f5f3
Increase recursion limit in dump_program using try except (#761)
hl475 Jun 12, 2023
3fe7be3
Daily `arc lint --take BLACK`
Jun 12, 2023
82accdd
Map VAE params without AIT_AutoencoderKL (#760)
hlky Jun 12, 2023
cbb8be8
small updates
fsx950223 Jun 13, 2023
5f44043
merge updates
fsx950223 Jun 13, 2023
b3d6705
format code
fsx950223 Jun 13, 2023
f91c59c
Add # usort:skip to make both internal and OSS lint happy (#763)
hl475 Jun 13, 2023
4d00081
Fix SD image quality on some GPUs (#765)
apivovarov Jun 13, 2023
8992f3f
Add support to List[List[Tensor]] Input shape (#756)
henryhu6 Jun 13, 2023
1ec9d9a
Add ait profile timeout to lower_settings (#764)
qxy11 Jun 14, 2023
c1e9b42
fix quick gelu shape
fsx950223 Jun 15, 2023
2ae8184
correctly run bmm_rrr tests (#770)
chenyang78 Jun 15, 2023
72faba8
upstream gemm and embeddings (#726)
fsx950223 Jun 15, 2023
9ee885c
add attention backend (#741)
fsx950223 Jun 15, 2023
e98d2dd
Add bf16 support to upsampling2d nearest (#750)
henryhu6 Jun 15, 2023
d09eeaa
merge updates
fsx950223 Jun 16, 2023
4f53879
fix a bug
fsx950223 Jun 16, 2023
908d861
Add jagged_lengths_to_offsets op (#766)
aakhundov Jun 19, 2023
a996ec6
Grouped classic b2b bmm op - tuned version (#771)
kadeng Jun 19, 2023
9282317
Add jagged_lengths_to_presences op (#767)
aakhundov Jun 19, 2023
eea99c1
rename to warp size
fsx950223 Jun 20, 2023
2db5b42
native CUDA development helper (#772)
kadeng Jun 20, 2023
33f2797
Standalone test cases (#773)
kadeng Jun 20, 2023
373000f
pool2d upstream (#775)
fsx950223 Jun 21, 2023
f5896d0
upstream norm (#774)
fsx950223 Jun 21, 2023
f0f676f
Download SD model without token by default (#769)
apivovarov Jun 23, 2023
9336061
Fix invalid escape sequence (#790)
r-barnes Jun 24, 2023
57c5e03
Re-sync with internal repository (#793)
facebook-github-bot Jun 24, 2023
b5bd10d
tensor upstream (#776)
fsx950223 Jun 25, 2023
e3a2388
Support List[Tensor] arg for from_two_input_lists (#792)
henryhu6 Jun 26, 2023
3c3b256
Adding support for relational operations (#783)
AlbertDachiChen Jun 27, 2023
5e30494
skip fusing slice with a strided op if the relevant tensor accessors …
chenyang78 Jun 27, 2023
0cf1c2e
Add support for where operator (#791)
AlbertDachiChen Jun 27, 2023
dccb361
Add test cases to cover `split_large_slice_scatter` (#794)
y-sq Jun 27, 2023
a20384e
Add init_random_weights test util (#800)
Jun 27, 2023
eb4c375
Sync CUTLASS with upstream (#803)
aakhundov Jun 27, 2023
b080f5c
Fix SD Alternative pipeline README to use demo_alt.py (#784)
apivovarov Jun 29, 2023
79d10cd
Split slice_scatter into multiple ones if it has too many inputs (#801)
y-sq Jun 30, 2023
039bb9f
update frontend and mk_ck_lib (#777)
fsx950223 Jun 30, 2023
774c646
merge upstream
fsx950223 Jun 30, 2023
c9add3a
update dockerfile
fsx950223 Oct 19, 2023
f4c3c92
fix a profiler bug
fsx950223 Oct 24, 2023
fcd9302
fix dockerfile
fsx950223 Nov 24, 2023
0f319a4
pin diffusers and transformers
dejay-vu Dec 12, 2023
fb5a110
merge sdxl
fsx950223 Dec 22, 2023
ebefd83
fix bugs
fsx950223 Dec 25, 2023
7e828a7
optimize performance
fsx950223 Jan 3, 2024
1829cbb
optimize performance
fsx950223 Jan 3, 2024
8e2293e
fix pool bugs
fsx950223 Jan 3, 2024
51e10dd
fix download model bug
fsx950223 Jan 4, 2024
5a6775c
remove rocm hack
fsx950223 Jan 4, 2024
db1851c
optimizer sdxl performance
fsx950223 Jan 5, 2024
318f3aa
enable hipgraph on MI300
fsx950223 Jan 5, 2024
892e32a
fix negative prompt
fsx950223 Jan 5, 2024
65c08c2
fix profile bug
fsx950223 Jan 9, 2024
5958997
fix profiler
fsx950223 Jan 10, 2024
f54c2b5
add interwave and pipeline tuning
fsx950223 Jan 17, 2024
0a8beb7
fix a profiler bug
fsx950223 Jan 18, 2024
40248f2
fix revision
fsx950223 Jan 30, 2024
Refactor test_bmm(_add) with filter_test_cases_by_test_env (facebookincubator#635)

Summary:
Pull Request resolved: facebookincubator#635

Almost all unit tests from `test_bmm` and `test_bmm_add` are running on *both* V100 and A100 hosts. This is inefficient, as only a small fraction of the tests must run on A100. In this diff, the tests are refactored to rely on `filter_test_cases_by_test_env` instead of `filter_test_cases_by_params`, which leads to a more frugal use of A100 hosts. See the test plan for the before / after numbers.

Reviewed By: hl475

Differential Revision: D45405007

fbshipit-source-id: 98ce1c39d09b057586f759f22b9eedd6bcd4dbd5
aakhundov authored and facebook-github-bot committed Apr 28, 2023
commit 6ad335187006781d01d86ab5acdf2636dad6a8b8
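To make the shape of this refactor easier to follow, the sketch below condenses the before/after pattern visible in the diff: the parameterized, `_TEST_PARAMS`-driven test is replaced by explicitly named per-dtype methods plus a single module-level filter call. The class names `BMMTestCaseBefore`/`BMMTestCaseAfter` and the stubbed `_test_rcr` body are illustrative only; the imports, decorator, and test method names mirror the diff.

```python
import unittest

from aitemplate.testing.test_utils import (
    filter_test_cases_by_params,
    filter_test_cases_by_test_env,
    TestEnv,
)
from parameterized import parameterized

# Before: one parameterized test fans out over dtypes, and
# filter_test_cases_by_params decides which dtypes are generated
# for the current test environment.
_TEST_PARAMS = {
    TestEnv.CUDA_LESS_THAN_SM80: [("float16")],
    TestEnv.CUDA_SM80: [("float32"), ("bfloat16")],
}


class BMMTestCaseBefore(unittest.TestCase):
    def _test_rcr(self, bs, ms, N, K, test_name, dtype="float16"):
        pass  # stub; the real helper builds and runs the AIT module

    @parameterized.expand(**filter_test_cases_by_params(_TEST_PARAMS))
    def test_bmm_0_dtype(self, dtype):
        self._test_rcr([128], [64], N=8, K=64, test_name=f"static_{dtype}", dtype=dtype)


# After: each dtype/environment combination gets an explicitly named method
# (the _sm80 / _bf16 / _rocm suffixes encode where it should run), and one
# module-level call prunes the class to the current test environment.
class BMMTestCaseAfter(unittest.TestCase):
    def _test_rcr(self, bs, ms, N, K, test_name, dtype="float16"):
        pass  # stub

    def test_bmm_0_fp32_sm80(self, dtype="float32"):
        self._test_rcr([128], [64], N=8, K=64, test_name=f"static_{dtype}", dtype=dtype)

    def test_bmm_0_bf16(self, dtype="bfloat16"):
        self._test_rcr([128], [64], N=8, K=64, test_name=f"static_{dtype}", dtype=dtype)


filter_test_cases_by_test_env(BMMTestCaseAfter)
```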
193 changes: 132 additions & 61 deletions tests/unittest/ops/test_bmm.py
@@ -21,21 +21,12 @@
from aitemplate.frontend import Tensor
from aitemplate.testing import detect_target
from aitemplate.testing.test_utils import (
filter_test_cases_by_params,
filter_test_cases_by_test_env,
get_random_torch_tensor,
get_torch_empty_tensor,
TestEnv,
)
from aitemplate.utils import shape_utils

from parameterized import parameterized


_TEST_PARAMS = {
TestEnv.CUDA_LESS_THAN_SM80: [("float16")],
TestEnv.CUDA_SM80: [("float32"), ("bfloat16")],
}


class BMMTestCase(unittest.TestCase):
def _test_rcr(self, bs, ms, N, K, test_name, dtype="float16"):
@@ -68,15 +59,15 @@ def _test_rcr(self, bs, ms, N, K, test_name, dtype="float16"):

def test_rcr(self):
self._test_rcr([1024], [128], N=512, K=256, test_name="static")
if detect_target().name() == "cuda":
self._test_rcr([1, 5, 977, 1024], [32], N=512, K=256, test_name="dynamic_b")
self._test_rcr([1], [100, 200, 300], N=512, K=256, test_name="dynamic_m")
self._test_rcr(
[1, 2, 5], [100, 200, 300], N=512, K=256, test_name="dynamic_bm"
)
self._test_rcr([0], [128], N=512, K=256, test_name="zero_batch")
self._test_rcr([1], [128], N=512, K=0, test_name="zero_k")
self._test_rcr([1], [128], N=0, K=8, test_name="zero_n")
self._test_rcr([1, 5, 977, 1024], [32], N=512, K=256, test_name="dynamic_b")
self._test_rcr([1], [100, 200, 300], N=512, K=256, test_name="dynamic_m")
self._test_rcr([1, 2, 5], [100, 200, 300], N=512, K=256, test_name="dynamic_bm")
self._test_rcr([0], [128], N=512, K=256, test_name="zero_batch")
self._test_rcr([1], [128], N=512, K=0, test_name="zero_k")
self._test_rcr([1], [128], N=0, K=8, test_name="zero_n")

def test_rcr_rocm(self):
self._test_rcr([1024], [128], N=512, K=256, test_name="static")

def _test_crr(self, bs, ks, M, N, test_name, dtype="float16"):
target = detect_target()
@@ -107,10 +98,12 @@ def _test_crr(self, bs, ks, M, N, test_name, dtype="float16"):

def test_crr(self):
self._test_crr([1024], [128], M=256, N=512, test_name="static")
if detect_target().name() == "cuda":
self._test_crr([3, 977, 1024], [128], M=256, N=512, test_name="dynamic_b")
self._test_crr([5], [45, 56, 78], M=256, N=512, test_name="dynamic_k")
self._test_crr([1, 2, 5], [3, 6, 8], M=256, N=512, test_name="dynamic_bk")
self._test_crr([3, 977, 1024], [128], M=256, N=512, test_name="dynamic_b")
self._test_crr([5], [45, 56, 78], M=256, N=512, test_name="dynamic_k")
self._test_crr([1, 2, 5], [3, 6, 8], M=256, N=512, test_name="dynamic_bk")

def test_crr_rocm(self):
self._test_crr([1024], [128], M=256, N=512, test_name="static")

def _test_rrr(self, bs, ms, K, N, test_name, dtype="float16"):
target = detect_target()
@@ -138,10 +131,12 @@ def _test_rrr(self, bs, ms, K, N, test_name, dtype="float16"):

def test_rrr(self):
self._test_rrr([87], [23], K=256, N=512, test_name="static")
if detect_target().name() == "cuda":
self._test_rrr([2, 5, 99], [23], K=128, N=512, test_name="dynamic_b")
self._test_rrr([77], [4, 7, 9], K=8, N=512, test_name="dynamic_m")
self._test_rrr([2, 5, 7], [1, 7, 9], K=256, N=512, test_name="dynamic_bm")
self._test_rrr([2, 5, 99], [23], K=128, N=512, test_name="dynamic_b")
self._test_rrr([77], [4, 7, 9], K=8, N=512, test_name="dynamic_m")
self._test_rrr([2, 5, 7], [1, 7, 9], K=256, N=512, test_name="dynamic_bm")

def test_rrr_rocm(self):
self._test_rrr([87], [23], K=256, N=512, test_name="static")

def _test_ccr(self, bs, M, N, K, test_name, dtype="float16"):
target = detect_target()
@@ -166,8 +161,10 @@ def _test_ccr(self, bs, M, N, K, test_name, dtype="float16"):

def test_ccr(self):
self._test_ccr([77], M=256, N=64, K=128, test_name="static")
if detect_target().name() == "cuda":
self._test_ccr([1, 9, 101], M=256, N=64, K=128, test_name="dynamic_b")
self._test_ccr([1, 9, 101], M=256, N=64, K=128, test_name="dynamic_b")

def test_ccr_rocm(self):
self._test_ccr([77], M=256, N=64, K=128, test_name="static")

def _test_rcc(self, bs, ms, N, K, test_name, dtype="float16"):
target = detect_target()
@@ -200,15 +197,15 @@ def _test_rcc(self, bs, ms, N, K, test_name, dtype="float16"):

def test_rcc(self):
self._test_rcc([1024], [128], N=512, K=256, test_name="static")
if detect_target().name() == "cuda":
self._test_rcc([1, 5, 977, 1024], [32], N=512, K=256, test_name="dynamic_b")
self._test_rcc([1], [100, 200, 300], N=512, K=256, test_name="dynamic_m")
self._test_rcc(
[1, 2, 5], [100, 200, 300], N=512, K=256, test_name="dynamic_bm"
)
self._test_rcc([0], [128], N=512, K=256, test_name="zero_batch")
self._test_rcc([1], [128], N=512, K=0, test_name="zero_k")
self._test_rcc([1], [128], N=0, K=8, test_name="zero_n")
self._test_rcc([1, 5, 977, 1024], [32], N=512, K=256, test_name="dynamic_b")
self._test_rcc([1], [100, 200, 300], N=512, K=256, test_name="dynamic_m")
self._test_rcc([1, 2, 5], [100, 200, 300], N=512, K=256, test_name="dynamic_bm")
self._test_rcc([0], [128], N=512, K=256, test_name="zero_batch")
self._test_rcc([1], [128], N=512, K=0, test_name="zero_k")
self._test_rcc([1], [128], N=0, K=8, test_name="zero_n")

def test_rcc_rocm(self):
self._test_rcc([1024], [128], N=512, K=256, test_name="static")

def _test_crc(self, bs, ks, M, N, test_name, dtype="float16"):
target = detect_target()
@@ -240,10 +237,12 @@ def _test_crc(self, bs, ks, M, N, test_name, dtype="float16"):

def test_crc(self):
self._test_crc([1024], [128], M=256, N=512, test_name="static")
if detect_target().name() == "cuda":
self._test_crc([3, 977, 1024], [128], M=256, N=512, test_name="dynamic_b")
self._test_crc([5], [45, 56, 78], M=256, N=512, test_name="dynamic_k")
self._test_crc([1, 2, 5], [3, 6, 8], M=256, N=512, test_name="dynamic_bk")
self._test_crc([3, 977, 1024], [128], M=256, N=512, test_name="dynamic_b")
self._test_crc([5], [45, 56, 78], M=256, N=512, test_name="dynamic_k")
self._test_crc([1, 2, 5], [3, 6, 8], M=256, N=512, test_name="dynamic_bk")

def test_crc_rocm(self):
self._test_crc([1024], [128], M=256, N=512, test_name="static")

def _test_rrc(self, bs, ms, K, N, test_name, dtype="float16"):
target = detect_target()
@@ -272,10 +271,12 @@ def _test_rrc(self, bs, ms, K, N, test_name, dtype="float16"):

def test_rrc(self):
self._test_rrc([87], [23], K=256, N=512, test_name="static")
if detect_target().name() == "cuda":
self._test_rrc([2, 5, 99], [23], K=128, N=512, test_name="dynamic_b")
self._test_rrc([77], [4, 7, 9], K=8, N=512, test_name="dynamic_m")
self._test_rrc([2, 5, 7], [1, 7, 9], K=256, N=512, test_name="dynamic_bm")
self._test_rrc([2, 5, 99], [23], K=128, N=512, test_name="dynamic_b")
self._test_rrc([77], [4, 7, 9], K=8, N=512, test_name="dynamic_m")
self._test_rrc([2, 5, 7], [1, 7, 9], K=256, N=512, test_name="dynamic_bm")

def test_rrc_rocm(self):
self._test_rrc([87], [23], K=256, N=512, test_name="static")

def _test_ccc(self, bs, M, N, K, test_name, dtype="float16"):
target = detect_target()
@@ -302,12 +303,37 @@ def _test_ccc(self, bs, M, N, K, test_name, dtype="float16"):

def test_ccc(self):
self._test_ccc([77], M=256, N=64, K=128, test_name="static")
if detect_target().name() == "cuda":
self._test_ccc([1, 9, 101], M=256, N=64, K=128, test_name="dynamic_b")
self._test_ccc([1, 9, 101], M=256, N=64, K=128, test_name="dynamic_b")

def test_ccc_rocm(self):
self._test_ccc([77], M=256, N=64, K=128, test_name="static")

def test_bmm_0_fp32_sm80(self, dtype="float32"):
self._test_rcr([128], [64], N=8, K=64, test_name=f"static_{dtype}", dtype=dtype)
self._test_rcr(
[1, 5, 77, 128],
[32],
N=16,
K=64,
test_name=f"dynamic_b_{dtype}",
dtype=dtype,
)
self._test_crr(
[1, 2, 5],
[3, 6, 8],
M=24,
N=64,
test_name=f"dynamic_bk_{dtype}",
dtype=dtype,
)
self._test_rrr(
[8], [4, 7, 9], K=64, N=32, test_name=f"dynamic_m_{dtype}", dtype=dtype
)
self._test_ccr(
[1, 9, 11], M=64, N=32, K=16, test_name=f"dynamic_b_{dtype}", dtype=dtype
)

@parameterized.expand(**filter_test_cases_by_params(_TEST_PARAMS))
@unittest.skipIf(detect_target().name() == "rocm", "Not supported by ROCM.")
def test_bmm_0_dtype(self, dtype):
def test_bmm_0_bf16(self, dtype="bfloat16"):
self._test_rcr([128], [64], N=8, K=64, test_name=f"static_{dtype}", dtype=dtype)
self._test_rcr(
[1, 5, 77, 128],
@@ -332,9 +358,32 @@ def test_bmm_0_dtype(self, dtype):
[1, 9, 11], M=64, N=32, K=16, test_name=f"dynamic_b_{dtype}", dtype=dtype
)

@parameterized.expand(**filter_test_cases_by_params(_TEST_PARAMS))
@unittest.skipIf(detect_target().name() == "rocm", "Not supported by ROCM.")
def test_bmm_1_dtype(self, dtype):
def test_bmm_1_fp32_sm80(self, dtype="float32"):
self._test_rcc([128], [64], N=8, K=64, test_name=f"static_{dtype}", dtype=dtype)
self._test_rcc(
[1, 5, 77, 128],
[32],
N=16,
K=64,
test_name=f"dynamic_b_{dtype}",
dtype=dtype,
)
self._test_crc(
[1, 2, 5],
[3, 6, 8],
M=24,
N=64,
test_name=f"dynamic_bk_{dtype}",
dtype=dtype,
)
self._test_rrc(
[8], [4, 7, 9], K=64, N=32, test_name=f"dynamic_m_{dtype}", dtype=dtype
)
self._test_ccc(
[1, 9, 11], M=64, N=32, K=16, test_name=f"dynamic_b_{dtype}", dtype=dtype
)

def test_bmm_1_bf16(self, dtype="bfloat16"):
self._test_rcc([128], [64], N=8, K=64, test_name=f"static_{dtype}", dtype=dtype)
self._test_rcc(
[1, 5, 77, 128],
@@ -727,8 +776,7 @@ def test_ccc(self):
self._test_ccr([8, 16], [8, 32, 8], "2d_broadcastable_a")
self._test_ccr([8, 8, 16], [32, 8], "2d_broadcastable_b")

@parameterized.expand(**filter_test_cases_by_params(_TEST_PARAMS))
def test_bmm_broadcast_0_dtype(self, dtype):
def test_bmm_broadcast_0_fp32_sm80(self, dtype="float32"):
self._test_rcr([2, 16, 8], [1, 32, 8], f"broadcastable_b_{dtype}", dtype=dtype)
self._test_rcr([16, 8], [8, 32, 8], f"2d_broadcastable_a_{dtype}", dtype=dtype)
self._test_crr([1, 8, 16], [2, 8, 32], f"broadcastable_a_{dtype}", dtype=dtype)
@@ -738,8 +786,27 @@ def test_bmm_broadcast_0_dtype(self, dtype):
self._test_ccr([1, 8, 16], [2, 32, 8], f"broadcastable_a_{dtype}", dtype=dtype)
self._test_ccr([8, 8, 16], [32, 8], f"2d_broadcastable_b_{dtype}", dtype=dtype)

@parameterized.expand(**filter_test_cases_by_params(_TEST_PARAMS))
def test_bmm_broadcast_1_dtype(self, dtype):
def test_bmm_broadcast_0_bf16(self, dtype="bfloat16"):
self._test_rcr([2, 16, 8], [1, 32, 8], f"broadcastable_b_{dtype}", dtype=dtype)
self._test_rcr([16, 8], [8, 32, 8], f"2d_broadcastable_a_{dtype}", dtype=dtype)
self._test_crr([1, 8, 16], [2, 8, 32], f"broadcastable_a_{dtype}", dtype=dtype)
self._test_crr([8, 8, 16], [8, 32], f"2d_broadcastable_b_{dtype}", dtype=dtype)
self._test_rrr([2, 16, 8], [1, 8, 32], f"broadcastable_b_{dtype}", dtype=dtype)
self._test_rrr([16, 8], [8, 8, 32], f"2d_broadcastable_a_{dtype}", dtype=dtype)
self._test_ccr([1, 8, 16], [2, 32, 8], f"broadcastable_a_{dtype}", dtype=dtype)
self._test_ccr([8, 8, 16], [32, 8], f"2d_broadcastable_b_{dtype}", dtype=dtype)

def test_bmm_broadcast_1_fp32_sm80(self, dtype="float32"):
self._test_rcc([2, 16, 8], [1, 32, 8], f"broadcastable_b_{dtype}", dtype=dtype)
self._test_rcc([16, 8], [8, 32, 8], f"2d_broadcastable_a_{dtype}", dtype=dtype)
self._test_crc([1, 8, 16], [2, 8, 32], f"broadcastable_a_{dtype}", dtype=dtype)
self._test_crc([8, 8, 16], [8, 32], f"2d_broadcastable_b_{dtype}", dtype=dtype)
self._test_rrc([2, 16, 8], [1, 8, 32], f"broadcastable_b_{dtype}", dtype=dtype)
self._test_rrc([16, 8], [8, 8, 32], f"2d_broadcastable_a_{dtype}", dtype=dtype)
self._test_ccc([1, 8, 16], [2, 32, 8], f"broadcastable_a_{dtype}", dtype=dtype)
self._test_ccc([8, 8, 16], [32, 8], f"2d_broadcastable_b_{dtype}", dtype=dtype)

def test_bmm_broadcast_1_bf16(self, dtype="bfloat16"):
self._test_rcc([2, 16, 8], [1, 32, 8], f"broadcastable_b_{dtype}", dtype=dtype)
self._test_rcc([16, 8], [8, 32, 8], f"2d_broadcastable_a_{dtype}", dtype=dtype)
self._test_crc([1, 8, 16], [2, 8, 32], f"broadcastable_a_{dtype}", dtype=dtype)
@@ -772,7 +839,7 @@ def test_rcr_fail(self, dtype="float16"):
try:
module.run_with_tensors({"input_0": X_pt, "input_1": W_pt}, [y])
raise AssertionError(
"Shouldn't be able to run be imcompatible tensor shape!"
"Shouldn't be able to run be incompatible tensor shape!"
)
except RuntimeError:
pass
@@ -800,7 +867,7 @@ def test_rrr_fail(self, dtype="float16"):
try:
module.run_with_tensors({"input_0": X_pt, "input_1": W_pt}, [y])
raise AssertionError(
"Shouldn't be able to run be imcompatible tensor shape!"
"Shouldn't be able to run be incompatible tensor shape!"
)
except RuntimeError:
pass
@@ -828,11 +895,15 @@ def test_rcc_fail(self, dtype="float16"):
try:
module.run_with_tensors({"input_0": X_pt, "input_1": W_pt}, [y])
raise AssertionError(
"Shouldn't be able to run be imcompatible tensor shape!"
"Shouldn't be able to run be incompatible tensor shape!"
)
except RuntimeError:
pass


filter_test_cases_by_test_env(BMMTestCase)
filter_test_cases_by_test_env(BMMBroadcastTestCase)


if __name__ == "__main__":
unittest.main()
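The new scheme relies on the method-name suffixes above (`_sm80`, `_bf16`, `_rocm`) together with the module-level `filter_test_cases_by_test_env` calls at the end of the file. As a rough, hypothetical illustration of that idea only (the helper `filter_by_env_suffix`, the suffix table, and the demo class below are assumptions, not AITemplate's actual implementation), a suffix-based filter could look roughly like this:

```python
import unittest

# Assumed suffix set for this sketch; AITemplate's real helper may differ.
_ENV_SUFFIXES = ("_sm80", "_rocm", "_bf16")


def filter_by_env_suffix(test_class, enabled_suffixes):
    """Remove test_* methods whose suffix is recognized but not enabled."""
    for name in list(vars(test_class)):
        if not name.startswith("test_"):
            continue
        suffix = next((s for s in _ENV_SUFFIXES if name.endswith(s)), None)
        # Tests without a recognized suffix run everywhere; suffixed tests
        # run only when their environment is enabled on this host.
        if suffix is not None and suffix not in enabled_suffixes:
            delattr(test_class, name)


class DemoTestCase(unittest.TestCase):
    def test_generic(self):
        pass  # runs in every environment

    def test_heavy_sm80(self):
        pass  # intended for A100-class (SM80) hosts only

    def test_amd_rocm(self):
        pass  # intended for ROCm hosts only


# Example: on an SM80 host, keep the generic and _sm80 tests, drop _rocm.
filter_by_env_suffix(DemoTestCase, enabled_suffixes={"_sm80", "_bf16"})

if __name__ == "__main__":
    unittest.main()
```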