
Commit dcbc17a

small notes, enable other tile sizes

Signed-off-by: Raayan Dhar <raayan.dhar@gmail.com>

1 parent 8a58e45 commit dcbc17a

5 files changed: +87 -69 lines changed

csrc/bf16_gemm_cutlass.jinja

Lines changed: 4 additions & 4 deletions
@@ -19,9 +19,9 @@
 namespace flashinfer {
 namespace gemm {
 INSTANCE_BF16_GEMM_TEMPLATE_SM100({{ type }}, {{ cta_m }}, {{ cta_n }}, {{ cta_k }}, 1, 1, 1, _1SM);
-// INSTANCE_BF16_GEMM_TEMPLATE_SM100({{ type }}, {{ cta_m }}, {{ cta_n }}, {{ cta_k }}, 1, 2, 1, _1SM);
-// INSTANCE_BF16_GEMM_TEMPLATE_SM100({{ type }}, {{ cta_m }}, {{ cta_n }}, {{ cta_k }}, 1, 4, 1, _1SM);
-// INSTANCE_BF16_GEMM_TEMPLATE_SM100({{ type }}, {{ cta_m }}, {{ cta_n }}, {{ cta_k }}, 2, 1, 1, _2SM);
-// INSTANCE_BF16_GEMM_TEMPLATE_SM100({{ type }}, {{ cta_m }}, {{ cta_n }}, {{ cta_k }}, 2, 2, 1, _2SM);
+INSTANCE_BF16_GEMM_TEMPLATE_SM100({{ type }}, {{ cta_m }}, {{ cta_n }}, {{ cta_k }}, 1, 2, 1, _1SM);
+INSTANCE_BF16_GEMM_TEMPLATE_SM100({{ type }}, {{ cta_m }}, {{ cta_n }}, {{ cta_k }}, 1, 4, 1, _1SM);
+INSTANCE_BF16_GEMM_TEMPLATE_SM100({{ type }}, {{ cta_m }}, {{ cta_n }}, {{ cta_k }}, 2, 1, 1, _2SM);
+INSTANCE_BF16_GEMM_TEMPLATE_SM100({{ type }}, {{ cta_m }}, {{ cta_n }}, {{ cta_k }}, 2, 2, 1, _2SM);
 } // namespace gemm
 } // namespace flashinfer
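
For context on how these placeholders get filled: gen_gemm_sm100_module_cutlass_bf16 in flashinfer/jit/gemm/core.py (further down in this commit) renders this template once per (dtype, CTA tile) pair, and each line stamps out one explicit instantiation per cluster shape. Below is a minimal, self-contained C++ sketch of that instantiation pattern; all names here are illustrative stand-ins, not flashinfer's actual launcher or macro.

#include <cstddef>
#include <cstdio>

// Mock launcher standing in for the real templated GEMM kernel launcher.
// Each explicit instantiation bakes one (CTA tile, cluster shape) choice
// into the compiled object, so the runtime dispatcher can pick among them.
template <typename T, int CtaM, int CtaN, int CtaK, int ClsM, int ClsN, int ClsK>
std::size_t mock_gemm_launcher() {
  std::printf("tile %dx%dx%d, cluster %dx%dx%d\n", CtaM, CtaN, CtaK, ClsM, ClsN, ClsK);
  return 0;
}

// Stand-in for INSTANCE_BF16_GEMM_TEMPLATE_SM100: an explicit instantiation macro.
#define INSTANCE_MOCK(T, M, N, K, CM, CN, CK) \
  template std::size_t mock_gemm_launcher<T, M, N, K, CM, CN, CK>();

// Roughly what one rendered copy of the template expands to:
INSTANCE_MOCK(float, 64, 128, 128, 1, 1, 1)
INSTANCE_MOCK(float, 64, 128, 128, 1, 2, 1)
INSTANCE_MOCK(float, 64, 128, 128, 2, 2, 1)

int main() { return static_cast<int>(mock_gemm_launcher<float, 64, 128, 128, 1, 2, 1>()); }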

flashinfer/gemm/gemm_base.py

Lines changed: 24 additions & 2 deletions
@@ -211,7 +211,18 @@ def mm_bf16(
     torch.Tensor
         Out tensor, shape (m, n), bf16 or fp16.

-    # Note: add Examples section here
+    Examples
+    --------
+    >>> import torch
+    >>> import torch.nn.functional as F
+    >>> import flashinfer
+    >>> input = torch.randn([48, 64], device="cuda", dtype=torch.bfloat16)
+    >>> weight = torch.randn([80, 64], device="cuda", dtype=torch.bfloat16).transpose(-2, -1)
+    >>> out = flashinfer.mm_bf16(input, weight)
+    >>> print(out.shape)
+    torch.Size([48, 80])
+    >>> out.dtype
+    torch.bfloat16
     """
     if backend != "cutlass":
         raise ValueError(f"Unsupported backend: {backend}. Only cutlass is available.")
@@ -277,7 +288,18 @@ def bmm_bf16(
     torch.Tensor
         Out tensor, shape (b, m, n), bf16 or fp16.

-    # Note: add Examples section here
+    Examples
+    --------
+    >>> import torch
+    >>> import torch.nn.functional as F
+    >>> import flashinfer
+    >>> input = torch.randn([16, 48, 64], device="cuda", dtype=torch.bfloat16)
+    >>> weight = torch.randn([16, 80, 64], device="cuda", dtype=torch.bfloat16).transpose(-2, -1)
+    >>> out = flashinfer.bmm_bf16(input, weight)
+    >>> print(out.shape)
+    torch.Size([16, 48, 80])
+    >>> out.dtype
+    torch.bfloat16
     """
     if backend != "cutlass":
         raise ValueError(f"Unsupported backend: {backend}. Only cutlass is available.")

flashinfer/jit/gemm/core.py

Lines changed: 4 additions & 5 deletions
@@ -202,11 +202,10 @@ def gen_gemm_sm100_module_cutlass_bf16() -> JitSpec:
     dtype_list = ["__nv_bfloat16", "half"]
     cta_m_n_k_list = [
         (64, 64, 128),
-        # (64, 128, 128),
-        # (64, 256, 128),
-        # (128, 64, 128),
-        # (128, 128, 128),
-        # (128, 256, 128),
+        (64, 128, 128),
+        (64, 256, 128),
+        (128, 64, 128),
+        (128, 128, 128),
     ]
     for cta_m, cta_n, cta_k in cta_m_n_k_list:
         for dtype in dtype_list:

include/flashinfer/gemm/bf16_gemm_cutlass_template.h

Lines changed: 42 additions & 51 deletions
@@ -64,26 +64,26 @@ size_t dispatchGemmClusterShapeSm100(__nv_bfloat16 const* A, __nv_bfloat16 const
                                    _1SM>(A, B, D, m, n, k, b, gemmConfig, workspacePtr,
                                          workspaceBytes, stream);
       break;
-    // case ClusterShape::ClusterShape_1x2x1:
-    //   return genericBf16GemmKernelLauncherSm100<T, arch, CTA_M_, CTA_N_, CTA_K_,
-    //                                             Shape<_1, _2, _1>, _1SM>(
-    //       A, B, D, m, n, k, b, gemmConfig, workspacePtr, workspaceBytes, stream);
-    //   break;
-    // case ClusterShape::ClusterShape_1x4x1:
-    //   return genericBf16GemmKernelLauncherSm100<T, arch, CTA_M_, CTA_N_, CTA_K_,
-    //                                             Shape<_1, _4, _1>, _1SM>(
-    //       A, B, D, m, n, k, b, gemmConfig, workspacePtr, workspaceBytes, stream);
-    //   break;
-    // case ClusterShape::ClusterShape_2x1x1:
-    //   return genericBf16GemmKernelLauncherSm100<T, arch, CTA_M_, CTA_N_, CTA_K_,
-    //                                             Shape<_2, _1, _1>, _2SM>(
-    //       A, B, D, m, n, k, b, gemmConfig, workspacePtr, workspaceBytes, stream);
-    //   break;
-    // case ClusterShape::ClusterShape_2x2x1:
-    //   return genericBf16GemmKernelLauncherSm100<T, arch, CTA_M_, CTA_N_, CTA_K_,
-    //                                             Shape<_2, _2, _1>, _2SM>(
-    //       A, B, D, m, n, k, b, gemmConfig, workspacePtr, workspaceBytes, stream);
-    //   break;
+    case ClusterShape::ClusterShape_1x2x1:
+      return genericBf16GemmKernelLauncherSm100<T, arch, CTA_M_, CTA_N_, CTA_K_, Shape<_1, _2, _1>,
+                                                _1SM>(A, B, D, m, n, k, b, gemmConfig, workspacePtr,
+                                                      workspaceBytes, stream);
+      break;
+    case ClusterShape::ClusterShape_1x4x1:
+      return genericBf16GemmKernelLauncherSm100<T, arch, CTA_M_, CTA_N_, CTA_K_, Shape<_1, _4, _1>,
+                                                _1SM>(A, B, D, m, n, k, b, gemmConfig, workspacePtr,
+                                                      workspaceBytes, stream);
+      break;
+    case ClusterShape::ClusterShape_2x1x1:
+      return genericBf16GemmKernelLauncherSm100<T, arch, CTA_M_, CTA_N_, CTA_K_, Shape<_2, _1, _1>,
+                                                _2SM>(A, B, D, m, n, k, b, gemmConfig, workspacePtr,
+                                                      workspaceBytes, stream);
+      break;
+    case ClusterShape::ClusterShape_2x2x1:
+      return genericBf16GemmKernelLauncherSm100<T, arch, CTA_M_, CTA_N_, CTA_K_, Shape<_2, _2, _1>,
+                                                _2SM>(A, B, D, m, n, k, b, gemmConfig, workspacePtr,
+                                                      workspaceBytes, stream);
+      break;
     default:
       throw std::runtime_error("invalid config for bf16 gemm");
       break;
@@ -101,31 +101,22 @@ size_t dispatchToArch(__nv_bfloat16 const* A, __nv_bfloat16 const* B, void* D, i
       return dispatchGemmClusterShapeSm100<T, arch, 64, 64, 128>(
           B, A, static_cast<T*>(D), n, m, k, b, gemmConfig, workspacePtr, workspaceBytes, stream);
       break;
-    // case CutlassTileConfigSM100::CtaShape64x128x128B:
-    //   return dispatchGemmClusterShapeSm100<T, arch, 64, 128, 128>(
-    //       B, A, static_cast<T*>(D), n, m, k, b, gemmConfig, workspacePtr, workspaceBytes,
-    //       stream);
-    //   break;
-    // case CutlassTileConfigSM100::CtaShape64x256x128B:
-    //   return dispatchGemmClusterShapeSm100<T, arch, 64, 256, 128>(
-    //       B, A, static_cast<T*>(D), n, m, k, b, gemmConfig, workspacePtr, workspaceBytes,
-    //       stream);
-    //   break;
-    // case CutlassTileConfigSM100::CtaShape128x64x128B:
-    //   return dispatchGemmClusterShapeSm100<T, arch, 128, 64, 128>(
-    //       B, A, static_cast<T*>(D), n, m, k, b, gemmConfig, workspacePtr, workspaceBytes,
-    //       stream);
-    //   break;
-    // case CutlassTileConfigSM100::CtaShape128x128x128B:
-    //   return dispatchGemmClusterShapeSm100<T, arch, 128, 128, 128>(
-    //       B, A, static_cast<T*>(D), n, m, k, b, gemmConfig, workspacePtr, workspaceBytes,
-    //       stream);
-    //   break;
-    // case CutlassTileConfigSM100::CtaShape128x256x128B:
-    //   return dispatchGemmClusterShapeSm100<T, arch, 128, 256, 128>(
-    //       B, A, static_cast<T*>(D), n, m, k, b, gemmConfig, workspacePtr, workspaceBytes,
-    //       stream);
-    //   break;
+    case CutlassTileConfigSM100::CtaShape64x128x128B:
+      return dispatchGemmClusterShapeSm100<T, arch, 64, 128, 128>(
+          B, A, static_cast<T*>(D), n, m, k, b, gemmConfig, workspacePtr, workspaceBytes, stream);
+      break;
+    case CutlassTileConfigSM100::CtaShape64x256x128B:
+      return dispatchGemmClusterShapeSm100<T, arch, 64, 256, 128>(
+          B, A, static_cast<T*>(D), n, m, k, b, gemmConfig, workspacePtr, workspaceBytes, stream);
+      break;
+    case CutlassTileConfigSM100::CtaShape128x64x128B:
+      return dispatchGemmClusterShapeSm100<T, arch, 128, 64, 128>(
+          B, A, static_cast<T*>(D), n, m, k, b, gemmConfig, workspacePtr, workspaceBytes, stream);
+      break;
+    case CutlassTileConfigSM100::CtaShape128x128x128B:
+      return dispatchGemmClusterShapeSm100<T, arch, 128, 128, 128>(
+          B, A, static_cast<T*>(D), n, m, k, b, gemmConfig, workspacePtr, workspaceBytes, stream);
+      break;

     default:
       throw std::runtime_error("unsupported tile config for bf16 gemm");
@@ -189,15 +180,15 @@ std::vector<CutlassGemmConfig> CutlassBf16GemmRunner<T>::getConfigs() const {
   std::vector<CutlassGemmConfig> candidate_configs;

   std::vector<CutlassTileConfigSM100> tilesSm100 = {
-      CutlassTileConfigSM100::CtaShape64x64x128B,  // CutlassTileConfigSM100::CtaShape64x128x128B,
-      // CutlassTileConfigSM100::CtaShape64x256x128B, CutlassTileConfigSM100::CtaShape128x64x128B,
-      // CutlassTileConfigSM100::CtaShape128x128x128B, CutlassTileConfigSM100::CtaShape128x256x128B,
+      CutlassTileConfigSM100::CtaShape64x64x128B,   CutlassTileConfigSM100::CtaShape64x128x128B,
+      CutlassTileConfigSM100::CtaShape64x256x128B,  CutlassTileConfigSM100::CtaShape128x64x128B,
+      CutlassTileConfigSM100::CtaShape128x128x128B,
   };

   std::vector<ClusterShape> clusterShapes = {
-      ClusterShape::ClusterShape_1x1x1,  // ClusterShape::ClusterShape_1x2x1,
-      // ClusterShape::ClusterShape_1x4x1, ClusterShape::ClusterShape_2x1x1,
-      // ClusterShape::ClusterShape_2x2x1,
+      ClusterShape::ClusterShape_1x1x1, ClusterShape::ClusterShape_1x2x1,
+      ClusterShape::ClusterShape_1x4x1, ClusterShape::ClusterShape_2x1x1,
+      ClusterShape::ClusterShape_2x2x1,
   };

   for (auto const& tile_config : tilesSm100) {
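
The two dispatch functions restored above follow a common CUTLASS-style pattern: nested switches lower runtime enums (tile config first, then cluster shape) onto compile-time template parameters, so each case lands on one of the explicitly instantiated kernels. A compilable mock of that pattern follows, with illustrative names only; the real code dispatches on CutlassTileConfigSM100 and ClusterShape.

#include <cstdio>
#include <stdexcept>

// Runtime configuration enums (stand-ins for the real tile/cluster enums).
enum class Tile { T64x64, T64x128 };
enum class Cluster { C1x1, C1x2 };

// Stand-in for genericBf16GemmKernelLauncherSm100: fully specialized at compile time.
template <int CtaM, int CtaN, int ClsM, int ClsN>
void launch() {
  std::printf("CTA %dx%d, cluster %dx%d\n", CtaM, CtaN, ClsM, ClsN);
}

// Inner switch: runtime cluster shape -> compile-time cluster parameters.
template <int CtaM, int CtaN>
void dispatchCluster(Cluster c) {
  switch (c) {
    case Cluster::C1x1: return launch<CtaM, CtaN, 1, 1>();
    case Cluster::C1x2: return launch<CtaM, CtaN, 1, 2>();
    default: throw std::runtime_error("invalid cluster shape");
  }
}

// Outer switch: runtime tile config -> compile-time CTA tile parameters.
void dispatchTile(Tile t, Cluster c) {
  switch (t) {
    case Tile::T64x64:  return dispatchCluster<64, 64>(c);
    case Tile::T64x128: return dispatchCluster<64, 128>(c);
    default: throw std::runtime_error("unsupported tile config");
  }
}

int main() { dispatchTile(Tile::T64x128, Cluster::C1x2); }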

include/flashinfer/gemm/bf16_gemm_template_sm100.h

Lines changed: 13 additions & 7 deletions
@@ -151,25 +151,31 @@ size_t genericBf16GemmKernelLauncherSm100(__nv_bfloat16 const* A, __nv_bfloat16

   Gemm gemm;

-  CUTLASS_CHECK(gemm.can_implement(arguments));
+  // Return workspace size
+  if (!A && !B && !D) {
+    return gemm.get_workspace_size(arguments);
+  }

-  size_t workspace_size = gemm.get_workspace_size(arguments);
-  if (workspace_size > workspaceBytes) {
+  if (gemm.get_workspace_size(arguments) > workspaceBytes) {
     throw std::runtime_error("[Bf16 Gemm Runner] insufficient workspace");
   }

-  // NOTE: These can also be simplified using CUTLASS_CHECK. Same goes for some of the other files.
-  cutlass::Status initStatus = gemm.initialize(arguments, workspacePtr, stream);
+  auto can_implement = gemm.can_implement(arguments);
+  if (can_implement != cutlass::Status::kSuccess) {
+    throw std::runtime_error("[Bf16 Gemm Runner] cutlass kernel not implemented given the params");
+  }
+
+  auto initStatus = gemm.initialize(arguments, workspacePtr);
   if (initStatus != cutlass::Status::kSuccess) {
     throw std::runtime_error("[Bf16 Gemm Runner] failed to initialize");
   }

-  cutlass::Status runStatus = gemm.run(stream);
+  auto runStatus = gemm.run(stream);
   if (runStatus != cutlass::Status::kSuccess) {
     throw std::runtime_error("[Bf16 Gemm Runner] failed to run");
   }

-  return workspace_size;
+  return gemm.get_workspace_size(arguments);
 }

 } // namespace gemm
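
The restructured launcher above now doubles as a workspace-size query: calling it with null A, B, and D returns gemm.get_workspace_size(arguments) without checking, initializing, or running anything. Below is a hedged, self-contained sketch of that calling convention, with mock names and a made-up size constant in place of the CUTLASS gemm object.

#include <cstddef>
#include <cstdio>
#include <stdexcept>
#include <vector>

// Mock of the query-then-run convention (all names are illustrative).
// A null (A, B, D) call returns the required workspace size; the caller
// allocates that much and calls again with real pointers.
std::size_t mock_launcher(const float* A, const float* B, float* D,
                          void* workspace, std::size_t workspaceBytes) {
  const std::size_t needed = 1024;  // stand-in for gemm.get_workspace_size(arguments)
  if (!A && !B && !D) return needed;  // size-query mode: no kernel work
  if (needed > workspaceBytes) throw std::runtime_error("insufficient workspace");
  std::printf("running with %zu-byte workspace\n", workspaceBytes);
  return needed;
}

int main() {
  // First call: query the size. Second call: execute with an adequate buffer.
  std::size_t bytes = mock_launcher(nullptr, nullptr, nullptr, nullptr, 0);
  std::vector<unsigned char> workspace(bytes);
  float a = 0.f, b = 0.f, d = 0.f;  // dummy operands; only their addresses matter here
  mock_launcher(&a, &b, &d, workspace.data(), workspace.size());
}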
