@@ -64,26 +64,26 @@ size_t dispatchGemmClusterShapeSm100(__nv_bfloat16 const* A, __nv_bfloat16 const
6464 _1SM>(A, B, D, m, n, k, b, gemmConfig, workspacePtr,
6565 workspaceBytes, stream);
6666 break ;
67- // case ClusterShape::ClusterShape_1x2x1:
68- // return genericBf16GemmKernelLauncherSm100<T, arch, CTA_M_, CTA_N_, CTA_K_,
69- // Shape<_1, _2, _1>, _1SM>(
70- // A, B, D, m, n, k, b, gemmConfig, workspacePtr, workspaceBytes, stream);
71- // break;
72- // case ClusterShape::ClusterShape_1x4x1:
73- // return genericBf16GemmKernelLauncherSm100<T, arch, CTA_M_, CTA_N_, CTA_K_,
74- // Shape<_1, _4, _1>, _1SM>(
75- // A, B, D, m, n, k, b, gemmConfig, workspacePtr, workspaceBytes, stream);
76- // break;
77- // case ClusterShape::ClusterShape_2x1x1:
78- // return genericBf16GemmKernelLauncherSm100<T, arch, CTA_M_, CTA_N_, CTA_K_,
79- // Shape<_2, _1, _1>, _2SM>(
80- // A, B, D, m, n, k, b, gemmConfig, workspacePtr, workspaceBytes, stream);
81- // break;
82- // case ClusterShape::ClusterShape_2x2x1:
83- // return genericBf16GemmKernelLauncherSm100<T, arch, CTA_M_, CTA_N_, CTA_K_,
84- // Shape<_2, _2, _1>, _2SM>(
85- // A, B, D, m, n, k, b, gemmConfig, workspacePtr, workspaceBytes, stream);
86- // break;
67+ case ClusterShape::ClusterShape_1x2x1:
68+ return genericBf16GemmKernelLauncherSm100<T, arch, CTA_M_, CTA_N_, CTA_K_, Shape<_1, _2, _1> ,
69+ _1SM>(A, B, D, m, n, k, b, gemmConfig, workspacePtr,
70+ workspaceBytes, stream);
71+ break ;
72+ case ClusterShape::ClusterShape_1x4x1:
73+ return genericBf16GemmKernelLauncherSm100<T, arch, CTA_M_, CTA_N_, CTA_K_, Shape<_1, _4, _1> ,
74+ _1SM>(A, B, D, m, n, k, b, gemmConfig, workspacePtr,
75+ workspaceBytes, stream);
76+ break ;
77+ case ClusterShape::ClusterShape_2x1x1:
78+ return genericBf16GemmKernelLauncherSm100<T, arch, CTA_M_, CTA_N_, CTA_K_, Shape<_2, _1, _1> ,
79+ _2SM>(A, B, D, m, n, k, b, gemmConfig, workspacePtr,
80+ workspaceBytes, stream);
81+ break ;
82+ case ClusterShape::ClusterShape_2x2x1:
83+ return genericBf16GemmKernelLauncherSm100<T, arch, CTA_M_, CTA_N_, CTA_K_, Shape<_2, _2, _1> ,
84+ _2SM>(A, B, D, m, n, k, b, gemmConfig, workspacePtr,
85+ workspaceBytes, stream);
86+ break ;
8787 default :
8888 throw std::runtime_error (" invalid config for bf16 gemm" );
8989 break ;
@@ -101,31 +101,22 @@ size_t dispatchToArch(__nv_bfloat16 const* A, __nv_bfloat16 const* B, void* D, i
101101 return dispatchGemmClusterShapeSm100<T, arch, 64 , 64 , 128 >(
102102 B, A, static_cast <T*>(D), n, m, k, b, gemmConfig, workspacePtr, workspaceBytes, stream);
103103 break ;
104- // case CutlassTileConfigSM100::CtaShape64x128x128B:
105- // return dispatchGemmClusterShapeSm100<T, arch, 64, 128, 128>(
106- // B, A, static_cast<T*>(D), n, m, k, b, gemmConfig, workspacePtr, workspaceBytes,
107- // stream);
108- // break;
109- // case CutlassTileConfigSM100::CtaShape64x256x128B:
110- // return dispatchGemmClusterShapeSm100<T, arch, 64, 256, 128>(
111- // B, A, static_cast<T*>(D), n, m, k, b, gemmConfig, workspacePtr, workspaceBytes,
112- // stream);
113- // break;
114- // case CutlassTileConfigSM100::CtaShape128x64x128B:
115- // return dispatchGemmClusterShapeSm100<T, arch, 128, 64, 128>(
116- // B, A, static_cast<T*>(D), n, m, k, b, gemmConfig, workspacePtr, workspaceBytes,
117- // stream);
118- // break;
119- // case CutlassTileConfigSM100::CtaShape128x128x128B:
120- // return dispatchGemmClusterShapeSm100<T, arch, 128, 128, 128>(
121- // B, A, static_cast<T*>(D), n, m, k, b, gemmConfig, workspacePtr, workspaceBytes,
122- // stream);
123- // break;
124- // case CutlassTileConfigSM100::CtaShape128x256x128B:
125- // return dispatchGemmClusterShapeSm100<T, arch, 128, 256, 128>(
126- // B, A, static_cast<T*>(D), n, m, k, b, gemmConfig, workspacePtr, workspaceBytes,
127- // stream);
128- // break;
104+ case CutlassTileConfigSM100::CtaShape64x128x128B:
105+ return dispatchGemmClusterShapeSm100<T, arch, 64 , 128 , 128 >(
106+ B, A, static_cast <T*>(D), n, m, k, b, gemmConfig, workspacePtr, workspaceBytes, stream);
107+ break ;
108+ case CutlassTileConfigSM100::CtaShape64x256x128B:
109+ return dispatchGemmClusterShapeSm100<T, arch, 64 , 256 , 128 >(
110+ B, A, static_cast <T*>(D), n, m, k, b, gemmConfig, workspacePtr, workspaceBytes, stream);
111+ break ;
112+ case CutlassTileConfigSM100::CtaShape128x64x128B:
113+ return dispatchGemmClusterShapeSm100<T, arch, 128 , 64 , 128 >(
114+ B, A, static_cast <T*>(D), n, m, k, b, gemmConfig, workspacePtr, workspaceBytes, stream);
115+ break ;
116+ case CutlassTileConfigSM100::CtaShape128x128x128B:
117+ return dispatchGemmClusterShapeSm100<T, arch, 128 , 128 , 128 >(
118+ B, A, static_cast <T*>(D), n, m, k, b, gemmConfig, workspacePtr, workspaceBytes, stream);
119+ break ;
129120
130121 default :
131122 throw std::runtime_error (" unsupported tile config for bf16 gemm" );
@@ -189,15 +180,15 @@ std::vector<CutlassGemmConfig> CutlassBf16GemmRunner<T>::getConfigs() const {
189180 std::vector<CutlassGemmConfig> candidate_configs;
190181
191182 std::vector<CutlassTileConfigSM100> tilesSm100 = {
192- CutlassTileConfigSM100::CtaShape64x64x128B, // CutlassTileConfigSM100::CtaShape64x128x128B,
193- // CutlassTileConfigSM100::CtaShape64x256x128B, CutlassTileConfigSM100::CtaShape128x64x128B,
194- // CutlassTileConfigSM100::CtaShape128x128x128B, CutlassTileConfigSM100::CtaShape128x256x128B ,
183+ CutlassTileConfigSM100::CtaShape64x64x128B, CutlassTileConfigSM100::CtaShape64x128x128B,
184+ CutlassTileConfigSM100::CtaShape64x256x128B, CutlassTileConfigSM100::CtaShape128x64x128B,
185+ CutlassTileConfigSM100::CtaShape128x128x128B,
195186 };
196187
197188 std::vector<ClusterShape> clusterShapes = {
198- ClusterShape::ClusterShape_1x1x1, // ClusterShape::ClusterShape_1x2x1,
199- // ClusterShape::ClusterShape_1x4x1, ClusterShape::ClusterShape_2x1x1,
200- // ClusterShape::ClusterShape_2x2x1,
189+ ClusterShape::ClusterShape_1x1x1, ClusterShape::ClusterShape_1x2x1,
190+ ClusterShape::ClusterShape_1x4x1, ClusterShape::ClusterShape_2x1x1,
191+ ClusterShape::ClusterShape_2x2x1,
201192 };
202193
203194 for (auto const & tile_config : tilesSm100) {
0 commit comments