Merged
60 commits
404263d  low level abstraction (lzhangzz, Mar 27, 2025)
81bfa75  refactor (lzhangzz, Apr 2, 2025)
770b85d  eliminate template (lzhangzz, Apr 7, 2025)
e3a9619  remove unused (lzhangzz, Apr 7, 2025)
6b9a433  refactor bindings (lzhangzz, Apr 7, 2025)
613aeec  simplify lm head (lzhangzz, Apr 7, 2025)
e3fe34c  refactor weight (lzhangzz, Apr 8, 2025)
1e057d1  fix tp (lzhangzz, Apr 8, 2025)
40e9097  cublas (lzhangzz, Apr 8, 2025)
6fc9cc9  Merge remote-tracking branch 'origin/main' into core (lzhangzz, Apr 9, 2025)
ff3b5f7  refactor sampling (lzhangzz, Apr 10, 2025)
06ff641  remove unused (lzhangzz, Apr 10, 2025)
14a7f45  simplify (lzhangzz, Apr 11, 2025)
096155c  fix AWQ support (lzhangzz, Apr 11, 2025)
5fd35ae  fix moe (lzhangzz, Apr 11, 2025)
0c5ef46  fix nccl lm_head (lzhangzz, Apr 11, 2025)
c2020b2  fix (lzhangzz, Apr 11, 2025)
510675c  refactor data types (lzhangzz, Apr 15, 2025)
00b121e  skip legacy ut (lzhangzz, Apr 15, 2025)
88d17d4  simplify (lzhangzz, Apr 15, 2025)
699c24f  rename data types (lzhangzz, Apr 15, 2025)
3ffd070  refactor (lzhangzz, Apr 15, 2025)
eed6bfb  refactor runtime states (lzhangzz, Apr 16, 2025)
d2ec3af  fix msvc build (lzhangzz, Apr 16, 2025)
2529631  fix msvc build (lzhangzz, Apr 16, 2025)
1d77856  fix msvc build (lzhangzz, Apr 16, 2025)
6e728cf  fix msvc build (lzhangzz, Apr 16, 2025)
1b6a80d  fix msvc build (lzhangzz, Apr 16, 2025)
0d976d3  fix msvc build (lzhangzz, Apr 16, 2025)
18e7602  fix msvc build (lzhangzz, Apr 16, 2025)
7fec496  fix msvc build (lzhangzz, Apr 16, 2025)
69b1841  fix msvc build (lzhangzz, Apr 16, 2025)
c8bc36d  format (lzhangzz, Apr 16, 2025)
8161c0d  remove unused (lzhangzz, Apr 16, 2025)
7459992  fix msvc build (lzhangzz, Apr 16, 2025)
d38421f  fix msvc build (lzhangzz, Apr 16, 2025)
7d6ab03  fix msvc build (lzhangzz, Apr 16, 2025)
b214a0e  fix msvc build (lzhangzz, Apr 16, 2025)
3ab38ca  fix msvc build (lzhangzz, Apr 16, 2025)
f394ef0  fix msvc build (lzhangzz, Apr 16, 2025)
105f1cc  fix msvc build (lzhangzz, Apr 16, 2025)
5ccf30c  fix msvc build (lzhangzz, Apr 16, 2025)
b59620c  fix msvc build (lzhangzz, Apr 16, 2025)
a243da0  fix msvc build (lzhangzz, Apr 16, 2025)
42172d3  fix msvc build (lzhangzz, Apr 16, 2025)
4d9910a  fix msvc build (lzhangzz, Apr 16, 2025)
bf7c213  fix msvc build (lzhangzz, Apr 16, 2025)
8651edd  fix msvc build (lzhangzz, Apr 16, 2025)
4788b80  fix msvc build (lzhangzz, Apr 16, 2025)
529225d  fix msvc build (lzhangzz, Apr 16, 2025)
6fd5f72  fix msvc build (lzhangzz, Apr 17, 2025)
86d4e86  fix msvc build (lzhangzz, Apr 17, 2025)
8ea7e20  fix ut & msvc build (lzhangzz, Apr 17, 2025)
dcf9669  fix ut & msvc build (lzhangzz, Apr 17, 2025)
7f8974b  fix gcc build (lzhangzz, Apr 17, 2025)
646813b  fix lint & ut (lzhangzz, Apr 17, 2025)
98f4840  fix lint (lzhangzz, Apr 17, 2025)
ea07957  fetch Catch2 when building tests (lzhangzz, Apr 17, 2025)
5d69923  rewind msvc build (lzhangzz, Apr 17, 2025)
d0079b5  fix sampling (lzhangzz, Apr 18, 2025)
4 changes: 2 additions & 2 deletions .github/workflows/windows-x64-gpu.yml
@@ -50,11 +50,11 @@ jobs:
INPUT_CUDA_VERSION: ${{ matrix.cudaver }}
- name: Build wheel
run: |
$env:BUILD_TEST="ON"
$env:BUILD_TEST="OFF"
mkdir build
cd build
..\builder\windows\generate.ps1
cmake --build . --config Release -- /m /v:q
cmake --build . --config Release -- /m /v:n
if (-Not $?) {
echo "build failed"
exit 1
56 changes: 34 additions & 22 deletions CMakeLists.txt
@@ -15,22 +15,28 @@
cmake_minimum_required(VERSION 3.11 FATAL_ERROR) # for PyTorch extensions, version should be greater than 3.13
project(TurboMind LANGUAGES CXX CUDA)

find_package(CUDA 10.2 REQUIRED)
if (MSVC)
# use standard conformant preprocessor
add_compile_options($<$<COMPILE_LANGUAGE:CXX>:/Zc:preprocessor>)
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -Xcompiler=/Zc:preprocessor")
endif ()

find_package(CUDAToolkit REQUIRED)

if(${CUDA_VERSION_MAJOR} VERSION_GREATER_EQUAL "11")
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL "11")
add_definitions("-DENABLE_BF16")
message("CUDA_VERSION ${CUDA_VERSION_MAJOR}.${CUDA_VERSION_MINOR} is greater or equal than 11.0, enable -DENABLE_BF16 flag")
endif()

set(CMAKE_MODULE_PATH ${PROJECT_SOURCE_DIR}/cmake/Modules)

option(BUILD_MULTI_GPU "Build multi-gpu support" ON)
option(BUILD_PY_FFI "Build python ffi" ON)
option(BUILD_TEST "Build tests" OFF)
option(SPARSITY_SUPPORT "Build project with Ampere sparsity feature support" OFF)
option(BUILD_FAST_MATH "Build in fast math mode" ON)

include(FetchContent)

if (BUILD_TEST)
FetchContent_Declare(
repo-cutlass
@@ -45,6 +51,14 @@ if (BUILD_TEST)

set(CUTLASS_HEADER_DIR ${PROJECT_SOURCE_DIR}/3rdparty/cutlass/include)
set(CUTLASS_EXTENSIONS_DIR ${PROJECT_SOURCE_DIR}/src/turbomind/cutlass_extensions/include)


FetchContent_Declare(
Catch2
GIT_REPOSITORY https://github.com/catchorg/Catch2.git
GIT_TAG v3.8.0
)
FetchContent_MakeAvailable(Catch2)
endif()

FetchContent_Declare(
Expand All @@ -56,10 +70,6 @@ set(YAML_BUILD_SHARED_LIBS OFF CACHE BOOL "Build static library of yaml-cpp")
FetchContent_MakeAvailable(yaml-cpp)


option(SPARSITY_SUPPORT "Build project with Ampere sparsity feature support" OFF)

option(BUILD_FAST_MATH "Build in fast math mode" ON)

# the environment variable
# ASAN_OPTIONS=protect_shadow_gap=0,intercept_tls_get_addr=0
# must be set at runtime
@@ -112,13 +122,13 @@ set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -Xcompiler -Wall -ldl") # -Xptxas -v
# TODO: build for sm_72 & sm_87 on aarch64 platform (Jetson devices)
if (NOT CMAKE_CUDA_ARCHITECTURES)
set(CMAKE_CUDA_ARCHITECTURES 70-real 75-real)
if (${CUDA_VERSION} VERSION_GREATER_EQUAL "11")
if (${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL "11")
list(APPEND CMAKE_CUDA_ARCHITECTURES 80-real)
endif ()
if (${CUDA_VERSION} VERSION_GREATER_EQUAL "11.1")
if (${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL "11.1")
list(APPEND CMAKE_CUDA_ARCHITECTURES 86-real)
endif ()
if (${CUDA_VERSION} VERSION_GREATER_EQUAL "11.8")
if (${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL "11.8")
list(APPEND CMAKE_CUDA_ARCHITECTURES 89-real 90-real)
endif ()
if (MSVC)
@@ -132,19 +142,23 @@ set(CMAKE_CUDA_RUNTIME_LIBRARY Shared)
set(CMAKE_C_FLAGS_DEBUG "${CMAKE_C_FLAGS_DEBUG} -Wall -O0")
set(CMAKE_CXX_FLAGS_DEBUG "${CMAKE_CXX_FLAGS_DEBUG} -Wall -O0")
# set(CMAKE_CUDA_FLAGS_DEBUG "${CMAKE_CUDA_FLAGS_DEBUG} -O0 -G -Xcompiler -Wall --ptxas-options=-v --resource-usage")
set(CMAKE_CUDA_FLAGS_DEBUG "${CMAKE_CUDA_FLAGS_DEBUG} -O0 -G -Xcompiler -Wall -DCUDA_PTX_FP8_F2FP_ENABLED")
set(CMAKE_CUDA_FLAGS_DEBUG "${CMAKE_CUDA_FLAGS_DEBUG} -O0 -G -Xcompiler -Wall")

set(CMAKE_CXX_STANDARD "${CXX_STD}")
set(CMAKE_CXX_STANDARD_REQUIRED ON)
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} --expt-extended-lambda")
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} --expt-relaxed-constexpr")
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} --std=c++${CXX_STD} -DCUDA_PTX_FP8_F2FP_ENABLED")
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} --std=c++${CXX_STD}")

string(REPLACE "-O2" "" CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE}")
string(REPLACE "-O2" "" CMAKE_CUDA_FLAGS_RELEASE "${CMAKE_CUDA_FLAGS_RELEASE}")
string(REPLACE "-O2" "" CMAKE_CXX_FLAGS_RELWITHDEBINFO "${CMAKE_CXX_FLAGS_RELWITHDEBINFO}")
string(REPLACE "-O2" "" CMAKE_CUDA_FLAGS_RELWITHDEBINFO "${CMAKE_CUDA_FLAGS_RELWITHDEBINFO}")

set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} -O3")
set(CMAKE_CXX_FLAGS_RELWITHDEBINFO "${CMAKE_CXX_FLAGS_RELWITHDEBINFO} -O3")
# set(CMAKE_CUDA_FLAGS_RELEASE "${CMAKE_CUDA_FLAGS_RELEASE} -Xcompiler -O3 --ptxas-options=--verbose")
set(CMAKE_CUDA_FLAGS_RELEASE "${CMAKE_CUDA_FLAGS_RELEASE} -Xcompiler -O3 -DCUDA_PTX_FP8_F2FP_ENABLED")
set(CMAKE_CUDA_FLAGS_RELWITHDEBINFO "${CMAKE_CUDA_FLAGS_RELWITHDEBINFO} -Xcompiler -O3 -DCUDA_PTX_FP8_F2FP_ENABLED")
set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} -O3")
set(CMAKE_CXX_FLAGS_RELWITHDEBINFO "${CMAKE_CXX_FLAGS_RELWITHDEBINFO} -O3")
set(CMAKE_CUDA_FLAGS_RELEASE "${CMAKE_CUDA_FLAGS_RELEASE} -O3")
set(CMAKE_CUDA_FLAGS_RELWITHDEBINFO "${CMAKE_CUDA_FLAGS_RELWITHDEBINFO} -O3")

if(BUILD_FAST_MATH)
set(CMAKE_CUDA_FLAGS_RELEASE "${CMAKE_CUDA_FLAGS_RELEASE} --use_fast_math")
@@ -207,13 +221,11 @@ link_directories(
${COMMON_LIB_DIRS}
)

# add_subdirectory(3rdparty)
add_subdirectory(src)
# add_subdirectory(examples)

if(BUILD_TEST)
add_subdirectory(tests/csrc)
endif()
# if(BUILD_TEST)
# add_subdirectory(tests/csrc)
# endif()

# install python api
if (BUILD_PY_FFI)
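
The /Zc:preprocessor option added at the top of this file switches MSVC (for both C++ and CUDA host compilation) to its standards-conformant preprocessor. That is what makes argument-forwarding variadic macros, such as the dtype-dispatch macros introduced later in this PR, expand correctly. A minimal illustration of the failure mode, invented for exposition and not taken from this PR:

// Without /Zc:preprocessor, MSVC's traditional preprocessor forwards
// __VA_ARGS__ to COUNT_IMPL as a single argument, so the selector picks
// the wrong slot and COUNT(a, b) evaluates to 1 instead of 2.
#define COUNT_IMPL(_1, _2, _3, N, ...) N
#define COUNT(...) COUNT_IMPL(__VA_ARGS__, 3, 2, 1)

static_assert(COUNT(a, b) == 2, "requires a conformant preprocessor");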
3 changes: 1 addition & 2 deletions builder/windows/generate.ps1
@@ -3,6 +3,5 @@ cmake .. -A x64 -T "v142,cuda=$env:CUDA_PATH" `
-DCMAKE_INSTALL_PREFIX=install `
-DBUILD_PY_FFI=ON `
-DBUILD_MULTI_GPU=OFF `
-DCMAKE_CUDA_FLAGS="-lineinfo" `
-DUSE_NVTX=ON `
-DUSE_NVTX=OFF `
-DBUILD_TEST="$env:BUILD_TEST"
6 changes: 5 additions & 1 deletion builder/windows/setup_cuda.ps1
@@ -24,6 +24,8 @@ if ($CUDA_VERSION_FULL -eq "12.1.0") {
$downloadUrl = "https://developer.download.nvidia.com/compute/cuda/12.1.0/local_installers/cuda_12.1.0_531.14_windows.exe"
} elseif ($CUDA_VERSION_FULL -eq "11.8.0") {
$downloadUrl = "https://developer.download.nvidia.com/compute/cuda/11.8.0/local_installers/cuda_11.8.0_522.06_windows.exe"
} elseif ($CUDA_VERSION_FULL -eq "12.5.0") {
$downloadUrl = "https://developer.download.nvidia.com/compute/cuda/12.5.0/local_installers/cuda_12.5.0_555.85_windows.exe"
} else {
Write-Output "Unsupported CUDA version specified"
exit 1
Expand Down Expand Up @@ -84,6 +86,8 @@ $msBuildExtensions = (Get-ChildItem "$src\visual_studio_integration\CUDAVisualS
}
}

$CUDA_FLAGS="-allow-unsupported-compiler -D_ALLOW_COMPILER_AND_STL_VERSION_MISMATCH=1"

# Add to Github env
Write-Output "Setting environment variables for GitHub Actions..."

@@ -97,7 +101,7 @@ Write-Output "CUDA_PATH_V$($CUDA_MAJOR)_$($CUDA_MINOR)=$dst" >> $env:GITHUB_ENV
Write-Output "CUDA_PATH_VX_Y=CUDA_PATH_V$($CUDA_MAJOR)_$($CUDA_MINOR)" >> $env:GITHUB_ENV
Write-Output "CudaToolkitDir=$dst" >> $env:GITHUB_ENV
Write-Output "CMAKE_CUDA_COMPILER=$dst\bin\nvcc.exe" >> $env:GITHUB_ENV
Write-Output "NVCC_APPEND_FLAGS=-allow-unsupported-compiler" >> $env:GITHUB_ENV
Write-Output "NVCC_APPEND_FLAGS=$CUDA_FLAGS" >> $env:GITHUB_ENV

Write-Output "CUDA_VERSION=$CUDA_VERSION_FULL" >> $env:GITHUB_ENV
Write-Output "Setup completed."
3 changes: 2 additions & 1 deletion lmdeploy/turbomind/deploy/module.py
@@ -319,7 +319,8 @@ def pad_weight(tensor: torch.Tensor, tp: int):
if output_weight is not None:
tp = self.model.attn_tp_size
output_weight = pad_weight(output_weight, tp=tp)
self.model.save_split(output_weight, 'output.weight', split_dim=0, split_num=tp)
# transpose
self.model.save_split(output_weight.t(), 'output.weight', split_dim=1, split_num=tp)


class Transformer:
4 changes: 2 additions & 2 deletions lmdeploy/turbomind/turbomind.py
@@ -241,7 +241,7 @@ def _from_hf(self, model_source: ModelSource, model_path: str, engine_config: Tu

model_comm = _tm.AbstractTransformerModel.create_llama_model(model_dir='',
config=yaml.safe_dump(self.config_dict),
data_type=self.config.model_config.weight_type)
weight_type=self.config.model_config.weight_type)

# create empty weight
self._create_weight(model_comm)
@@ -275,7 +275,7 @@ def _from_workspace(self, model_path: str, engine_config: TurbomindEngineConfig)
weight_dir = osp.join(model_path, 'triton_models', 'weights')
model_comm = _tm.AbstractTransformerModel.create_llama_model(model_dir=weight_dir,
config=yaml.safe_dump(self.config_dict),
data_type=self.config.weight_type)
weight_type=self.config.weight_type)

# create weight and load params
self._create_weight(model_comm)
1 change: 1 addition & 0 deletions src/turbomind/CMakeLists.txt
@@ -13,6 +13,7 @@
# limitations under the License.

add_subdirectory(utils)
add_subdirectory(core)
add_subdirectory(kernels)
add_subdirectory(layers)
add_subdirectory(comm)
5 changes: 3 additions & 2 deletions src/turbomind/comm/CMakeLists.txt
@@ -3,10 +3,11 @@
cmake_minimum_required(VERSION 3.8)

add_library(host_comm STATIC host_comm.cc thread_comm.cc)
target_link_libraries(host_comm PRIVATE core logger)
set_property(TARGET host_comm PROPERTY POSITION_INDEPENDENT_CODE ON)

add_library(device_comm STATIC device_comm.cc)
target_link_libraries(device_comm PRIVATE logger)
target_link_libraries(device_comm PRIVATE core logger)
set_property(TARGET device_comm PROPERTY POSITION_INDEPENDENT_CODE ON)
set_property(TARGET device_comm PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON)

@@ -21,7 +22,7 @@ if (BUILD_MULTI_GPU)

if (BUILD_TEST)
add_executable(test_comm test_comm.cu)
target_link_libraries(test_comm PRIVATE device_comm host_comm pthread nvtx_utils)
target_link_libraries(test_comm PRIVATE device_comm host_comm core pthread nvtx_utils)
target_compile_options(test_comm PRIVATE -O3 -march=native -mtune=native)
endif ()
endif ()
2 changes: 2 additions & 0 deletions src/turbomind/comm/cuda_ipc/CMakeLists.txt
@@ -12,6 +12,8 @@ add_library(cuda_ipc_comm STATIC
target_link_libraries(cuda_ipc_comm PRIVATE
rms_norm
host_comm
core
cuda_utils
CUDA::cuda_driver
logger)

9 changes: 4 additions & 5 deletions src/turbomind/comm/cuda_ipc/allgather.cu
@@ -4,7 +4,6 @@
#include "src/turbomind/comm/cuda_ipc/device_semaphore.h"

#include "src/turbomind/kernels/core/meta.h"
#include "src/turbomind/utils/Tensor.h"
#include "src/turbomind/utils/cuda_utils.h"

namespace turbomind::comm {
@@ -51,7 +50,7 @@ __global__ void __launch_bounds__(1024, 1) Allgather_Simple_Pull(T*
void CudaIpcCommImpl::AllGather(
const void* sendbuff, void* recvbuff, size_t sendcount, DataType type, int group, cudaStream_t stream)
{
const size_t bytesize = get_elem_size(type) * sendcount;
const size_t bytesize = turbomind::byte_size(type) * sendcount;

const int peers = this->n_ranks(group) - 1;
const int rank = this->rank(group);
@@ -165,9 +164,9 @@ void CudaIpcCommImpl::AllGather2D(const void* sendbuff,
int group,
cudaStream_t stream)
{
const size_t byte_width = get_elem_size(type) * width;
const size_t byte_pitch = get_elem_size(type) * pitch;
const size_t byte_stride = get_elem_size(type) * stride;
const size_t byte_width = byte_size(type, width);
const size_t byte_pitch = byte_size(type, pitch);
const size_t byte_stride = byte_size(type, stride);

void* base{};
size_t offset{};
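
Both hunks in this file replace the old get_elem_size helper with the new byte_size from the core library, which also gains an overload that scales an element count directly, as in byte_size(type, width). The real definition lives in src/turbomind/core/data_type.h; the sketch below only illustrates the assumed semantics, with an invented DataType subset:

// Illustrative sketch, not the actual turbomind implementation.
#include <cstddef>

enum class DataType { kFloat16, kBFloat16, kFloat32, kUint8 };  // invented subset

// Size in bytes of one element of the given type.
constexpr std::size_t byte_size(DataType t) {
    switch (t) {
        case DataType::kFloat16:
        case DataType::kBFloat16: return 2;
        case DataType::kFloat32:  return 4;
        case DataType::kUint8:    return 1;
    }
    return 0;
}

// Scale an element count into a byte count.
constexpr std::size_t byte_size(DataType t, std::size_t count) {
    return byte_size(t) * count;
}

static_assert(byte_size(DataType::kBFloat16, 8) == 16);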
11 changes: 2 additions & 9 deletions src/turbomind/comm/cuda_ipc/allreduce.cu
@@ -6,9 +6,9 @@
#include "src/turbomind/comm/cuda_ipc/cuda_ipc_comm.h"
#include "src/turbomind/comm/cuda_ipc/device_semaphore.h"

#include "src/turbomind/core/data_type.h"
#include "src/turbomind/kernels/core/array_ops.h"
#include "src/turbomind/kernels/core/meta.h"
#include "src/turbomind/utils/Tensor.h"

#include "src/turbomind/utils/cuda_utils.h"

@@ -423,14 +423,7 @@ void CudaIpcCommImpl::AllReduceSum(
}
};

switch (type) {
case DataType::TYPE_FP16:
return invoke(half{});
case DataType::TYPE_BF16:
return invoke(nv_bfloat16{});
default:
throw std::runtime_error("not implemented");
}
TM_DISPATCH_PRIMARY_DTYPES(type, invoke);
}

} // namespace turbomind::comm
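
Here the hand-written switch over DataType is replaced by TM_DISPATCH_PRIMARY_DTYPES, which applies a generic callable to each supported "primary" element type (fp16 and bf16 in the code it replaces). The actual macro is defined in the new core library and may differ; a hedged sketch of it and its _RET sibling used in fused_allreduce.cu below:

// Illustrative sketch only. func is a generic callable (e.g. a lambda taking
// auto) that receives a value of the concrete element type.
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include <stdexcept>

enum class DataType { kFloat16, kBFloat16 };  // invented subset

#define TM_DISPATCH_PRIMARY_DTYPES(type, func)                    \
    switch (type) {                                               \
        case DataType::kFloat16:  return func(half{});            \
        case DataType::kBFloat16: return func(nv_bfloat16{});     \
        default: throw std::runtime_error("dtype not supported"); \
    }

// The _RET variant reports failure instead of throwing, which matches the
// auto dispatch = [&]() -> bool { ... } usage in the fused allreduce paths.
#define TM_DISPATCH_PRIMARY_DTYPES_RET(type, func)                \
    switch (type) {                                               \
        case DataType::kFloat16:  return func(half{});            \
        case DataType::kBFloat16: return func(nv_bfloat16{});     \
        default: return false;                                    \
    }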
3 changes: 1 addition & 2 deletions src/turbomind/comm/cuda_ipc/cuda_ipc_comm.cu
@@ -1,8 +1,7 @@
// Copyright (c) OpenMMLab. All rights reserved.

#include <memory>
#include <mutex>
#include <type_traits>
#include <numeric>
#include <vector>

#include <cuda.h>
1 change: 0 additions & 1 deletion src/turbomind/comm/cuda_ipc/cuda_ipc_comm.h
@@ -10,7 +10,6 @@

#include "src/turbomind/kernels/core/array.h"

#include "src/turbomind/utils/Tensor.h"
#include "src/turbomind/utils/cuda_utils.h"

namespace turbomind::comm {
17 changes: 4 additions & 13 deletions src/turbomind/comm/cuda_ipc/fused_allreduce.cu
@@ -8,13 +8,13 @@
#include "src/turbomind/comm/cuda_ipc/device_semaphore.h"
#include "src/turbomind/comm/cuda_ipc/group_sum.h"

#include "src/turbomind/core/data_type.h"
#include "src/turbomind/kernels/core/array_ops.h"
#include "src/turbomind/kernels/core/common.h"
#include "src/turbomind/kernels/core/meta.h"

#include "src/turbomind/kernels/norm/rms_norm.h"

#include "src/turbomind/utils/Tensor.h"
#include "src/turbomind/utils/cuda_utils.h"

namespace turbomind::comm {
@@ -424,7 +424,7 @@ void CudaIpcCommImpl::AllreduceResidualBiasRMSnorm(void* hidden,
cudaStream_t stream)
{

const size_t elemsize = get_elem_size(dtype);
const size_t elemsize = byte_size(dtype);
const size_t bytesize = elemsize * token_num * dim;

const int n_ranks = this->n_ranks(group);
@@ -504,19 +504,10 @@ void CudaIpcCommImpl::AllreduceResidualBiasRMSnorm(void* hidden,
return false; // > 1024 vdim
};

auto dispatch = [&] {
switch (dtype) {
case DataType::TYPE_FP16:
return dispatch_D(half{});
case DataType::TYPE_BF16:
return dispatch_D(nv_bfloat16{});
default:
return false;
}
};
auto dispatch = [&]() -> bool { TM_DISPATCH_PRIMARY_DTYPES_RET(dtype, dispatch_D); };

if (bytesize > (1 << 19)) {
if (auto success = dispatch()) {
if (dispatch()) {
return;
}
}
14 changes: 4 additions & 10 deletions src/turbomind/comm/cuda_ipc/fused_allreduce_ex.cu
@@ -5,6 +5,7 @@
#include "src/turbomind/comm/cuda_ipc/group_sum.h"

#include "src/turbomind/comm/cuda_ipc/mscclpp.h"
#include "src/turbomind/core/data_type.h"
#include "src/turbomind/kernels/core/array_ops.h"
#include "src/turbomind/kernels/core/common.h"
#include "src/turbomind/kernels/core/meta.h"
@@ -279,18 +280,11 @@ void CudaIpcCommImpl::AllreduceResidualBiasRMSnormEx(void* hidden,
return false; // > 1024 vdim
};

auto dispatch = [&] {
switch (dtype) {
case DataType::TYPE_FP16:
return dispatch_D(half{});
case DataType::TYPE_BF16:
return dispatch_D(nv_bfloat16{});
default:
return false;
}
auto dispatch = [&]() -> bool { //
TM_DISPATCH_PRIMARY_DTYPES_RET(dtype, dispatch_D);
};

FT_CHECK(dispatch());
TM_CHECK(dispatch());
}

} // namespace turbomind::comm
2 changes: 1 addition & 1 deletion src/turbomind/comm/device_comm.cc
@@ -25,7 +25,7 @@ DeviceComm CreateDeviceCommunicator(const std::string& backend, int n_ranks, int
}
#endif

FT_CHECK_WITH_INFO(0, fmtstr("Unknown communication backend: %s", backend.c_str()));
TM_CHECK(0) << "Unknown communication backend: " << backend;
return {};
}
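
The printf-style FT_CHECK_WITH_INFO is replaced by a streaming TM_CHECK that accepts operator<< message chaining. Assuming it follows the familiar glog-style pattern (the actual implementation lives in the new core library and may differ), a minimal sketch:

// Illustrative sketch of a streaming check macro, not the actual TM_CHECK.
#include <cstdlib>
#include <iostream>
#include <sstream>

struct CheckFailure {
    std::ostringstream msg;
    template <class T>
    CheckFailure& operator<<(const T& v) { msg << v; return *this; }
    ~CheckFailure() {  // runs after all streamed operands are appended
        std::cerr << "Check failed: " << msg.str() << std::endl;
        std::abort();
    }
};

// When cond is truthy nothing happens; otherwise the temporary collects the
// streamed message and aborts on destruction at the end of the statement.
#define TM_CHECK(cond) if (cond) {} else CheckFailure{}

// Usage, mirroring the change above:
//   TM_CHECK(0) << "Unknown communication backend: " << backend;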
