Optimize performance by fusing the addition of a high-precision tensor to an FP8 tensor (#132)

**Description**
Optimize performance by fusing the addition of a high-precision tensor to an FP8 tensor.

**Major Revision**
- Add an extension msamp_arithmetic
- Add a fused kernel in the extension for adding a high-precision tensor to an FP8 tensor
- Move common header files to msamp/common/include
- Add UT
- Apply it to the Megatron FP8DistributedDataParallel
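As a rough illustration of what this fusion buys, the sketch below spells out the unfused reference computation that would otherwise take several kernel launches and extra passes over memory. It is not code from this commit: the function name, the use of torch.float8_e4m3fn (available only in recent PyTorch builds), and the assumption that the existing scale is reused for re-quantization are illustrative guesses based only on the kernel arguments (FP8 data, scale, scale_inv, amax, high-precision other) visible in the diff below.

# Hypothetical, unfused reference of the operation; every name here is
# illustrative and not part of this commit.
import torch

def add_to_fp8_reference(fp8_data, scale, scale_inv, amax, other,
                         fp8_dtype=torch.float8_e4m3fn):
    """Dequantize, add in high precision, re-quantize: several passes over
    memory, which the fused CUDA kernel collapses into a single pass."""
    value = fp8_data.view(fp8_dtype).float() * scale_inv             # dequantize
    value = value + other.float()                                    # accumulate in high precision
    amax.fill_(value.abs().max())                                    # record the new absolute maximum
    fp8_data.copy_((value * scale).to(fp8_dtype).view(torch.uint8))  # re-quantize in place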
Showing 13 changed files with 744 additions and 7 deletions.
@@ -0,0 +1,64 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
// The file is from https://github.com/microsoft/mscclpp/blob/main/include/mscclpp/concurrency.hpp.

#ifndef MSAMP_CONCURRENCY_H_
#define MSAMP_CONCURRENCY_H_

#include "poll.h"

namespace msamp {

/// A device-wide barrier.
struct DeviceSyncer {
 public:
  /// Construct a new DeviceSyncer object.
  DeviceSyncer() = default;

  /// Destroy the DeviceSyncer object.
  ~DeviceSyncer() = default;

#ifdef __CUDACC__
  /// Synchronize all threads inside a kernel. Guarantee that all previous work of all threads in cooperating blocks is
  /// finished.
  /// @param blockNum The number of blocks that will synchronize.
  /// @param maxSpinCount The maximum number of spin counts before asserting. Never assert if negative.
  __forceinline__ __device__ void sync(int blockNum, int64_t maxSpinCount = 100000000) {
    unsigned int maxOldCnt = blockNum - 1;
    __syncthreads();
    if (blockNum == 1) return;
    if (threadIdx.x == 0) {
      // Need a `__threadfence()` before to flip `flag`.
      __threadfence();
      int tmp = isIncFlag_ ^ 1;
      if (tmp) {
        if (atomicInc(&count_, maxOldCnt) == maxOldCnt) {
          flag_ = 1;
        }
        POLL_MAYBE_JAILBREAK(!flag_, maxSpinCount);
      } else {
        if (atomicInc(&count_, maxOldCnt) == maxOldCnt) {
          flag_ = 0;
        }
        POLL_MAYBE_JAILBREAK(flag_, maxSpinCount);
      }
      isIncFlag_ = tmp;
    }
    // We need sync here because only a single thread is checking whether
    // the flag is flipped.
    __syncthreads();
  }
#endif

 private:
  /// The flag to indicate whether the barrier is reached by the latest thread.
  volatile int flag_;
  /// The counter of synchronized blocks.
  unsigned int count_;
  /// The flag to indicate whether to increase or decrease @ref flag_.
  int isIncFlag_;
};

}  // namespace msamp

#endif  // MSAMP_CONCURRENCY_H_
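For reference, DeviceSyncer implements a sense-reversing, grid-wide barrier: thread 0 of each block bumps count_ with atomicInc, which wraps the counter back to zero when the last block arrives, and that last block flips flag_ while the other blocks spin on it through POLL_MAYBE_JAILBREAK. Because isIncFlag_ alternates the expected value of flag_ between calls, a zero-initialized DeviceSyncer placed in device memory can be reused across consecutive sync() calls without being reset.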
@@ -0,0 +1,56 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.
// The file is from https://github.com/microsoft/mscclpp/blob/main/include/mscclpp/poll.hpp.

#ifndef MSAMP_POLL_H_
#define MSAMP_POLL_H_

#include <cstdint>

extern "C" __device__ void __assert_fail(const char *__assertion, const char *__file, unsigned int __line,
                                         const char *__function) __THROW;

// If a spin is stuck, escape from it and set status to 1.
#define POLL_MAYBE_JAILBREAK_ESCAPE(__cond, __max_spin_cnt, __status)  \
  do {                                                                 \
    int64_t __spin_cnt = 0;                                            \
    __status = 0;                                                      \
    while (__cond) {                                                   \
      if (__max_spin_cnt >= 0 && __spin_cnt++ == __max_spin_cnt) {     \
        __status = 1;                                                  \
        break;                                                         \
      }                                                                \
    }                                                                  \
  } while (0);

// If a spin is stuck, fail an assertion.
#define POLL_MAYBE_JAILBREAK(__cond, __max_spin_cnt)                     \
  do {                                                                   \
    int64_t __spin_cnt = 0;                                              \
    while (__cond) {                                                     \
      if (__max_spin_cnt >= 0 && __spin_cnt++ == __max_spin_cnt) {       \
        __assert_fail(#__cond, __FILE__, __LINE__, __PRETTY_FUNCTION__); \
      }                                                                  \
    }                                                                    \
  } while (0);

// The same as POLL_MAYBE_JAILBREAK except that __cond1 is checked before __cond2;
// this is especially useful when __cond1 is faster to check.
#define OR_POLL_MAYBE_JAILBREAK(__cond1, __cond2, __max_spin_cnt)                  \
  do {                                                                             \
    int64_t __spin_cnt = 0;                                                        \
    while (true) {                                                                 \
      if (!(__cond1)) {                                                            \
        break;                                                                     \
      } else if (!(__cond2)) {                                                     \
        break;                                                                     \
      }                                                                            \
      if (__max_spin_cnt >= 0 && __spin_cnt++ == __max_spin_cnt) {                 \
        __assert_fail(#__cond1 #__cond2, __FILE__, __LINE__, __PRETTY_FUNCTION__); \
      }                                                                            \
    }                                                                              \
  } while (0);

#endif  // MSAMP_POLL_H_
File renamed without changes.
@@ -0,0 +1,8 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

"""Exposes the interface of the MS-AMP arithmetic module."""

from msamp.operators.arithmetic.arithmetic import Arithmetic

__all__ = ['Arithmetic']
@@ -0,0 +1,43 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.

#include <torch/extension.h>
#include <c10/cuda/CUDAStream.h>

#include "../../common/include/common.h"
#include "../../common/include/utils.cuh"
#include "../../common/include/concurrency.h"
#include "vectorized_pointwise.h"

namespace msamp {
void add_to_fp8(at::Tensor fp8_tensor,
                at::Tensor scale,
                at::Tensor scale_inv,
                at::Tensor amax,
                const at::Tensor& other,
                bool is_e4m3) {
  const size_t N = other.numel();
  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  TORCH_DTYPE_SWITCH_INPUT(other.scalar_type(), IType,
    SELECT_FP8_TYPE(is_e4m3, OType,

      constexpr int nvec = 32 / sizeof(IType);

      VectorizedAddToFp8KernelLauncher<nvec>(
        reinterpret_cast<IType*>(other.data_ptr()),
        reinterpret_cast<OType*>(fp8_tensor.data_ptr()),
        reinterpret_cast<fp32*>(scale.data_ptr()),
        reinterpret_cast<fp32*>(scale_inv.data_ptr()),
        reinterpret_cast<fp32*>(amax.data_ptr()),
        N,
        stream
      );
    );
  );
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("add_to_fp8", &add_to_fp8, "Add to fp8");
}

}  // namespace msamp
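TORCH_DTYPE_SWITCH_INPUT and SELECT_FP8_TYPE are macros from the common headers included above; as used here they appear to bind IType to the high-precision input dtype and OType to the FP8 format selected by is_e4m3, so that VectorizedAddToFp8KernelLauncher is instantiated once per supported combination. The vector width nvec = 32 / sizeof(IType) sizes each vectorized element group to cover 32 bytes of the high-precision input.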
@@ -0,0 +1,35 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

"""FP8 Arithmetic module."""

import torch

import msamp_arithmetic
from msamp.common.dtype import Dtypes


class Arithmetic:
    """Arithmetic operator for FP8 tensor."""
    @staticmethod
    def add_to_fp8(fp8_tensor, meta, other):
        """Add a high-precision tensor to fp8_tensor in place.

        Args:
            fp8_tensor (torch.Tensor): fp8 tensor to add to.
            meta (ScalingTensorMeta): meta data of fp8_tensor.
            other (torch.Tensor): high precision tensor to add.
        """
        if not (fp8_tensor.is_cuda and fp8_tensor.is_contiguous()):
            raise ValueError('The fp8 tensor is not in cuda memory or contiguous.')
        if not (other.is_cuda and other.is_contiguous()):
            raise ValueError('The other tensor is not in cuda memory or contiguous.')
        if not (fp8_tensor.dtype == torch.uint8 or fp8_tensor.dtype == torch.int8):
            raise ValueError('The fp8 tensor is not in uint8 or int8.')

        if not (meta.qtype == Dtypes.kfloat8_e4m3 or meta.qtype == Dtypes.kfloat8_e5m2):
            raise ValueError('The fp8 tensor is not in e4m3 or e5m2 format.')

        is_e4m3 = meta.qtype == Dtypes.kfloat8_e4m3

        msamp_arithmetic.add_to_fp8(fp8_tensor, meta.scale, meta.scale_inv, meta.amax[0], other, is_e4m3)
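A minimal usage sketch of this API follows, under stated assumptions: the SimpleMeta class below is a made-up stand-in that carries only the attributes this method actually reads (qtype, scale, scale_inv, amax); in MS-AMP the metadata would come from the library's own scaling meta object, and the shapes and dtypes chosen for scale, scale_inv, and amax are guesses consistent with the fp32 pointers the extension expects.

# Hypothetical usage sketch; SimpleMeta is NOT an MS-AMP class, just a stand-in
# exposing the attributes read by Arithmetic.add_to_fp8.
import torch

from msamp.common.dtype import Dtypes
from msamp.operators.arithmetic import Arithmetic

class SimpleMeta:
    """Carries only qtype, scale, scale_inv, and amax (assumed 1-element fp32 CUDA tensors)."""
    def __init__(self, scale=1.0):
        self.qtype = Dtypes.kfloat8_e4m3
        self.scale = torch.tensor([scale], dtype=torch.float32, device='cuda')
        self.scale_inv = torch.tensor([1.0 / scale], dtype=torch.float32, device='cuda')
        self.amax = torch.zeros(1, dtype=torch.float32, device='cuda')

meta = SimpleMeta(scale=1.0)
fp8_buffer = torch.zeros(1024, dtype=torch.uint8, device='cuda')  # raw FP8 (e4m3) storage
grad = torch.randn(1024, dtype=torch.float16, device='cuda')      # high-precision tensor to accumulate

# Expected to dequantize fp8_buffer, add grad, and re-quantize in place with one fused kernel.
Arithmetic.add_to_fp8(fp8_buffer, meta, grad)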
@@ -0,0 +1,28 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

"""The setuptools based setup module."""

from setuptools import setup
from torch.utils import cpp_extension

ext_t = cpp_extension.CUDAExtension
ext_fnames = ['arithmetic.cu']
define_macros = []
nvcc_flags = [
    '-O3', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_BFLOAT16_CONVERSIONS__', '--expt-relaxed-constexpr',
    '--expt-extended-lambda', '--use_fast_math'
]

extra_compile_args = dict(cxx=['-fopenmp', '-O3'], nvcc=nvcc_flags)

define_macros.append(('WITH_CUDA', None))

setup(
    name='msamp_arithmetic',
    version='0.0.1',
    ext_modules=[
        ext_t('msamp_arithmetic', ext_fnames, define_macros=define_macros, extra_compile_args=extra_compile_args)
    ],
    cmdclass={'build_ext': cpp_extension.BuildExtension}
)
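Once built, for example with `pip install .` or `python setup.py build_ext --inplace` from this directory (as for any PyTorch CUDA extension), the compiled msamp_arithmetic module is what arithmetic.py imports above.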