diff --git a/Makefile b/Makefile
index bd7c5857..a37f66fb 100644
--- a/Makefile
+++ b/Makefile
@@ -23,4 +23,5 @@ lint: cpplint mdlint
 
 postinstall:
 	cd msamp/operators/dist_op && bash build.sh && cd -
+	cd msamp/operators/arithmetic && pip install -v . && cd -
 	cd msamp/optim && pip install -v . && cd -
diff --git a/msamp/optim/common.h b/msamp/common/include/common.h
similarity index 55%
rename from msamp/optim/common.h
rename to msamp/common/include/common.h
index e559c437..9b205d69 100644
--- a/msamp/optim/common.h
+++ b/msamp/common/include/common.h
@@ -1,6 +1,9 @@
 // Copyright (c) Microsoft Corporation.
 // Licensed under the MIT License.
 
+#ifndef MSAMP_COMMON_H_
+#define MSAMP_COMMON_H_
+
 #include
 #include
 #include
@@ -20,6 +23,11 @@ using bf16 = nv_bfloat16;
 using fp8e4m3 = __nv_fp8_e4m3;
 using fp8e5m2 = __nv_fp8_e5m2;
 
+template <typename T>
+constexpr T DIVUP(const T &x, const T &y) {
+    return (((x) + ((y)-1)) / (y));
+}
+
 #define TORCH_DTYPE_SWITCH(dtype, type, ...) \
     switch (dtype) { \
     case torch::kUInt8: { \
@@ -46,6 +54,36 @@ using fp8e5m2 = __nv_fp8_e5m2;
         throw "Unexcepted data type"; \
     }
 
+#define SELECT_FP8_TYPE(is_e4m3, type, ...) \
+    if (is_e4m3){ \
+        using type = fp8e4m3; \
+        { __VA_ARGS__ } \
+    } \
+    else { \
+        using type = fp8e5m2; \
+        { __VA_ARGS__ } \
+    }
+
+
+#define TORCH_DTYPE_SWITCH_INPUT(dtype, type, ...) \
+    switch (dtype) { \
+    case torch::kFloat32: { \
+        using type = float; \
+        { __VA_ARGS__ } \
+    } break; \
+    case torch::kBFloat16: { \
+        using type = bf16; \
+        { __VA_ARGS__ } \
+    } break; \
+    case torch::kFloat16: { \
+        using type = fp16; \
+        { __VA_ARGS__ } \
+    } break; \
+    default: \
+        throw "Unexcepted data type"; \
+    }
+
+
 const int HIP_MAX_GRID_NUM = 65535;
 const int HIP_MAX_NUM_THREADS = 512;
 
@@ -68,4 +106,24 @@ template <> __host__ __device__ bf16 cast_dtype(const float value) { return __fl
 
 template <> __host__ __device__ float cast_dtype(const fp16 value) { return __half2float(value); }
 
-template <> __host__ __device__ float cast_dtype(const bf16 value) { return __bfloat162float(value); }
\ No newline at end of file
+template <> __host__ __device__ float cast_dtype(const bf16 value) { return __bfloat162float(value); }
+
+template <typename T>
+struct is_fp8 : std::false_type {};
+
+template <>
+struct is_fp8<fp8e4m3> : std::true_type {};
+
+template <>
+struct is_fp8<fp8e5m2> : std::true_type {};
+
+template <typename T>
+struct is_half : std::false_type {};
+
+template <>
+struct is_half<fp16> : std::true_type {};
+
+template <>
+struct is_half<bf16> : std::true_type {};
+
+#endif  // MSAMP_COMMON_H_
\ No newline at end of file
diff --git a/msamp/common/include/concurrency.h b/msamp/common/include/concurrency.h
new file mode 100644
index 00000000..6075f967
--- /dev/null
+++ b/msamp/common/include/concurrency.h
@@ -0,0 +1,64 @@
+// Copyright (c) Microsoft Corporation.
+// Licensed under the MIT license.
+// The file is from https://github.com/microsoft/mscclpp/blob/main/include/mscclpp/concurrency.hpp.
+
+#ifndef MSAMP_CONCURRENCY_H_
+#define MSAMP_CONCURRENCY_H_
+
+#include "poll.h"
+
+namespace msamp {
+
+/// A device-wide barrier.
+struct DeviceSyncer {
+ public:
+  /// Construct a new DeviceSyncer object.
+  DeviceSyncer() = default;
+
+  /// Destroy the DeviceSyncer object.
+  ~DeviceSyncer() = default;
+
+#ifdef __CUDACC__
+  /// Synchronize all threads inside a kernel. Guarantee that all previous work of all threads in cooperating blocks is
+  /// finished.
+  /// @param blockNum The number of blocks that will synchronize.
+  /// @param maxSpinCount The maximum number of spin counts before asserting. Never assert if negative.
+  __forceinline__ __device__ void sync(int blockNum, int64_t maxSpinCount = 100000000) {
+    unsigned int maxOldCnt = blockNum - 1;
+    __syncthreads();
+    if (blockNum == 1) return;
+    if (threadIdx.x == 0) {
+      // Need a `__threadfence()` before to flip `flag`.
+      __threadfence();
+      int tmp = isIncFlag_ ^ 1;
+      if (tmp) {
+        if (atomicInc(&count_, maxOldCnt) == maxOldCnt) {
+          flag_ = 1;
+        }
+        POLL_MAYBE_JAILBREAK(!flag_, maxSpinCount);
+      } else {
+        if (atomicInc(&count_, maxOldCnt) == maxOldCnt) {
+          flag_ = 0;
+        }
+        POLL_MAYBE_JAILBREAK(flag_, maxSpinCount);
+      }
+      isIncFlag_ = tmp;
+    }
+    // We need sync here because only a single thread is checking whether
+    // the flag is flipped.
+    __syncthreads();
+  }
+#endif
+
+ private:
+  /// The flag to indicate whether the barrier is reached by the latest thread.
+  volatile int flag_;
+  /// The counter of synchronized blocks.
+  unsigned int count_;
+  /// The flag to indicate whether to increase or decrease @ref flag_.
+  int isIncFlag_;
+};
+
+}  // namespace msamp
+
+#endif  // MSAMP_CONCURRENCY_H_
\ No newline at end of file
diff --git a/msamp/common/include/poll.h b/msamp/common/include/poll.h
new file mode 100644
index 00000000..0ea53328
--- /dev/null
+++ b/msamp/common/include/poll.h
@@ -0,0 +1,56 @@
+
+// Copyright (c) Microsoft Corporation.
+// Licensed under the MIT License.
+// The file is from https://github.com/microsoft/mscclpp/blob/main/include/mscclpp/poll.hpp.
+
+#ifndef MSAMP_POLL_H_
+#define MSAMP_POLL_H_
+
+#include
+
+extern "C" __device__ void __assert_fail(const char *__assertion, const char *__file, unsigned int __line,
+                                         const char *__function) __THROW;
+
+// If a spin is stuck, escape from it and set status to 1.
+#define POLL_MAYBE_JAILBREAK_ESCAPE(__cond, __max_spin_cnt, __status)  \
+  do {                                                                 \
+    int64_t __spin_cnt = 0;                                            \
+    __status = 0;                                                      \
+    while (__cond) {                                                   \
+      if (__max_spin_cnt >= 0 && __spin_cnt++ == __max_spin_cnt) {     \
+        __status = 1;                                                  \
+        break;                                                         \
+      }                                                                \
+    }                                                                  \
+  } while (0);
+
+// If a spin is stuck, print a warning and keep spinning.
+#define POLL_MAYBE_JAILBREAK(__cond, __max_spin_cnt)                         \
+  do {                                                                       \
+    int64_t __spin_cnt = 0;                                                  \
+    while (__cond) {                                                         \
+      if (__max_spin_cnt >= 0 && __spin_cnt++ == __max_spin_cnt) {           \
+        __assert_fail(#__cond, __FILE__, __LINE__, __PRETTY_FUNCTION__);     \
+      }                                                                      \
+    }                                                                        \
+  } while (0);
+
+// the same as POLL_MAYBE_JAILBREAK except that __cond1 is checked before __cond2
+// this is especially useful when __cond1 is faster to check
+#define OR_POLL_MAYBE_JAILBREAK(__cond1, __cond2, __max_spin_cnt)                    \
+  do {                                                                               \
+    int64_t __spin_cnt = 0;                                                          \
+    while (true) {                                                                   \
+      if (!(__cond1)) {                                                              \
+        break;                                                                       \
+      } else if (!(__cond2)) {                                                       \
+        break;                                                                       \
+      }                                                                              \
+      if (__max_spin_cnt >= 0 && __spin_cnt++ == __max_spin_cnt) {                   \
+        __assert_fail(#__cond1 #__cond2, __FILE__, __LINE__, __PRETTY_FUNCTION__);   \
+      }                                                                              \
+    }                                                                                \
+  } while (0);
+
+
+#endif  // MSAMP_POLL_H_
\ No newline at end of file
diff --git a/msamp/optim/utils.cuh b/msamp/common/include/utils.cuh
similarity index 100%
rename from msamp/optim/utils.cuh
rename to msamp/common/include/utils.cuh
diff --git a/msamp/megatron/distributed.py b/msamp/megatron/distributed.py
index 379b8741..a2079f99 100644
--- a/msamp/megatron/distributed.py
+++ b/msamp/megatron/distributed.py
@@ -11,6 +11,7 @@
 
 from msamp.common.dtype import Dtypes
 from msamp.common.tensor import ScalingMeta, ScalingTensor
+from msamp.operators.arithmetic import Arithmetic
 
 
 class FP8DistributedDataParallel(DistributedDataParallelBase):
@@ -177,10 +178,7 @@ def _fp8_make_param_hook(self, param):
         def param_hook(*unused):
             # Add the gradient to the buffer.
             if param.grad is not None:
-                param.main_grad.copy_(
-                    (param.main_grad.to(param.grad.dtype) +
-                     param.grad).cast(self.wgrad_qtype, meta=param.main_grad.meta)
-                )
+                Arithmetic.add_to_fp8(param.main_grad.value, param.main_grad.meta, param.grad)
                 # Now we can deallocate grad memory.
                 param.grad = None
 
diff --git a/msamp/operators/arithmetic/__init__.py b/msamp/operators/arithmetic/__init__.py
new file mode 100644
index 00000000..55fc214a
--- /dev/null
+++ b/msamp/operators/arithmetic/__init__.py
@@ -0,0 +1,8 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+
+"""Exposes the interface of MS-AMP arithmetic module."""
+
+from msamp.operators.arithmetic.arithmetic import Arithmetic
+
+__all__ = ['Arithmetic']
diff --git a/msamp/operators/arithmetic/arithmetic.cu b/msamp/operators/arithmetic/arithmetic.cu
new file mode 100644
index 00000000..ba07ab18
--- /dev/null
+++ b/msamp/operators/arithmetic/arithmetic.cu
@@ -0,0 +1,43 @@
+// Copyright (c) Microsoft Corporation.
+// Licensed under the MIT License.
+
+#include
+#include
+
+#include "../../common/include/common.h"
+#include "../../common/include/utils.cuh"
+#include "../../common/include/concurrency.h"
+#include "vectorized_pointwise.h"
+
+namespace msamp {
+void add_to_fp8(at::Tensor fp8_tensor,
+                at::Tensor scale,
+                at::Tensor scale_inv,
+                at::Tensor amax,
+                const at::Tensor& other,
+                bool is_e4m3) {
+  const size_t N = other.numel();
+  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+  TORCH_DTYPE_SWITCH_INPUT(other.scalar_type(), IType,
+    SELECT_FP8_TYPE(is_e4m3, OType,
+
+      constexpr int nvec = 32 / sizeof(IType);
+
+      VectorizedAddToFp8KernelLauncher<nvec>(
+        reinterpret_cast<IType*>(other.data_ptr()),
+        reinterpret_cast<OType*>(fp8_tensor.data_ptr()),
+        reinterpret_cast<fp32*>(scale.data_ptr()),
+        reinterpret_cast<fp32*>(scale_inv.data_ptr()),
+        reinterpret_cast<fp32*>(amax.data_ptr()),
+        N,
+        stream
+      );
+    );
+  );
+}
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+  m.def("add_to_fp8", &add_to_fp8, "Add to fp8");
+}
+
+}  // namespace msamp
diff --git a/msamp/operators/arithmetic/arithmetic.py b/msamp/operators/arithmetic/arithmetic.py
new file mode 100644
index 00000000..1c5ed748
--- /dev/null
+++ b/msamp/operators/arithmetic/arithmetic.py
@@ -0,0 +1,35 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+
+"""FP8 Arithmetic module."""
+
+import torch
+
+import msamp_arithmetic
+from msamp.common.dtype import Dtypes
+
+
+class Arithmetic:
+    """Arithmetic operator for FP8 tensor."""
+    @staticmethod
+    def add_to_fp8(fp8_tensor, meta, other):
+        """Add a high-precision tensor to fp8_tensor in-place.
+
+        Args:
+            fp8_tensor (torch.Tensor): fp8 tensor to add to.
+            meta (ScalingTensorMeta): meta data of fp8_tensor.
+            other (torch.Tensor): high precision tensor to add.
+        """
+        if not (fp8_tensor.is_cuda and fp8_tensor.is_contiguous()):
+            raise ValueError('The fp8 tensor is not in cuda memory or contiguous.')
+        if not (other.is_cuda and other.is_contiguous()):
+            raise ValueError('The other tensor is not in cuda memory or contiguous.')
+        if not (fp8_tensor.dtype == torch.uint8 or fp8_tensor.dtype == torch.int8):
+            raise ValueError('The fp8 tensor is not in uint8 or int8.')
+
+        if not (meta.qtype == Dtypes.kfloat8_e4m3 or meta.qtype == Dtypes.kfloat8_e5m2):
+            raise ValueError('The fp8 tensor is not in e4m3 or e5m2 format.')
+
+        is_e4m3 = meta.qtype == Dtypes.kfloat8_e4m3
+
+        msamp_arithmetic.add_to_fp8(fp8_tensor, meta.scale, meta.scale_inv, meta.amax[0], other, is_e4m3)
diff --git a/msamp/operators/arithmetic/setup.py b/msamp/operators/arithmetic/setup.py
new file mode 100644
index 00000000..fc5ad369
--- /dev/null
+++ b/msamp/operators/arithmetic/setup.py
@@ -0,0 +1,28 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+
+"""The setuptools based setup module."""
+
+from setuptools import setup
+from torch.utils import cpp_extension
+
+ext_t = cpp_extension.CUDAExtension
+ext_fnames = ['arithmetic.cu']
+define_macros = []
+nvcc_flags = [
+    '-O3', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_BFLOAT16_CONVERSIONS__', '--expt-relaxed-constexpr',
+    '--expt-extended-lambda', '--use_fast_math'
+]
+
+extra_compile_args = dict(cxx=['-fopenmp', '-O3'], nvcc=nvcc_flags)
+
+define_macros.append(('WITH_CUDA', None))
+
+setup(
+    name='msamp_arithmetic',
+    version='0.0.1',
+    ext_modules=[
+        ext_t('msamp_arithmetic', ext_fnames, define_macros=define_macros, extra_compile_args=extra_compile_args)
+    ],
+    cmdclass={'build_ext': cpp_extension.BuildExtension}
+)
diff --git a/msamp/operators/arithmetic/vectorized_pointwise.h b/msamp/operators/arithmetic/vectorized_pointwise.h
new file mode 100644
index 00000000..bd765637
--- /dev/null
+++ b/msamp/operators/arithmetic/vectorized_pointwise.h
@@ -0,0 +1,404 @@
+// Copyright (c) Microsoft Corporation.
+// Licensed under the MIT License.
+// The file is adapted from https://github.com/NVIDIA/TransformerEngine/blob/main/transformer_engine/common/util/vectorized_pointwise.h.
+
+#ifndef MSAMP_VECTORIZED_POINTWISE_H
+#define MSAMP_VECTORIZED_POINTWISE_H
+
+#include
+#include
+#include
+
+#include "../../common/include/common.h"
+#include "../../common/include/utils.cuh"
+#include "../../common/include/concurrency.h"
+
+namespace msamp {
+/* \brief Helper class that enables storing multiple values of type DType
+          as 1 value of type LType.
+*/
+template <typename DType, int n>
+class VectorizedStorage {
+ public:
+  using LType = typename transformer_engine::BytesToType<sizeof(DType) * n>::Type;
+  constexpr static int nvec = n;
+  union vectorized_storage {
+    LType aligned;
+    DType separate[nvec];  // NOLINT(*)
+
+    inline __device__ vectorized_storage() {}
+    inline __device__ ~vectorized_storage() {}
+  } scratch_;
+
+  inline __device__ VectorizedStorage() {}
+  inline __device__ VectorizedStorage(const VectorizedStorage<DType, n>& y2) {
+    scratch_.aligned = y2.scratch_.aligned;
+  }
+  inline __device__ VectorizedStorage(const LType &y2) {
+    scratch_.aligned = y2;
+  }
+  inline __device__ VectorizedStorage<DType, n>& operator+=(
+      const VectorizedStorage<DType, n>& rhs) {
+#pragma unroll
+    for (int i = 0; i < nvec; ++i) {
+      scratch_.separate[i] = add_elem(scratch_.separate[i], rhs.scratch_.separate[i]);
+    }
+    return *this;
+  }
+  inline __device__ ~VectorizedStorage() {}
+};
+
+// Returns const LType if DType is const
+template <typename DType, typename LType>
+struct select_const {
+  using type = LType;
+};
+
+template <typename DType, typename LType>
+struct select_const<const DType, LType> {
+  using type = const LType;
+};
+
+
+/* \brief Helper class that enables accessing multiple values of type DType
+          as 1 value of type LType. Additional aligned template argument
+          allows performance optimizations if the pointer and the size of
+          the allocation is aligned to sizeof(LType) / sizeof(DType) elements.
+*/
+template <typename DType, int nvec, bool aligned = false>
+class VectorizedAccessor {
+ public:
+  using StorageType = VectorizedStorage<typename std::remove_const<DType>::type,
+                                        nvec>;
+  using LType = typename select_const<DType, typename StorageType::LType>::type;
+  StorageType storage_;
+
+  LType* aligned_ptr_;
+  DType* unaligned_ptr_;
+  int alignment_;
+  size_t n_elems_;
+
+  inline __device__ VectorizedAccessor(DType* const ptr, const size_t size) {
+    unaligned_ptr_ = ptr;
+    if (aligned) {
+      alignment_ = 0;
+      aligned_ptr_ = reinterpret_cast<LType*>(ptr);
+      n_elems_ = (size + nvec - 1) / nvec;
+    } else {
+      size_t ptr_as_number = reinterpret_cast<size_t>(ptr);
+      alignment_ = (ptr_as_number % sizeof(LType)) / sizeof(DType);
+      aligned_ptr_ = reinterpret_cast<LType*>(ptr - alignment_);
+      n_elems_ = (size + alignment_ + nvec - 1) / nvec;
+    }
+  }
+
+  /* \brief Alignment of the input pointer in elements. */
+  inline __device__ int alignment() const {
+    return alignment_;
+  }
+
+  /* \brief Access to separate elements. */
+  inline __device__ DType* separate() {
+    return storage_.scratch_.separate;
+  }
+
+  /* \brief Number of aligned elements that span the entire input tensor. */
+  inline __device__ size_t num_aligned_elements() const {
+    return n_elems_;
+  }
+
+  /* \brief Load values from the input.
+     \param id Aligned index of the element.
+     \param N size of the tensor.
+  */
+  inline __device__ void load(const size_t id, const size_t N) {
+    if (aligned) {
+      storage_.scratch_.aligned = aligned_ptr_[id];
+    } else {
+      if (id > 0 && id < n_elems_ - 1) {
+        storage_.scratch_.aligned = aligned_ptr_[id];
+      } else {
+#pragma unroll
+        for (int j = 0; j < nvec; ++j) {
+          DType* ptr = reinterpret_cast<DType*>(&(aligned_ptr_[id])) + j;
+          if (reinterpret_cast<size_t>(ptr) >= reinterpret_cast<size_t>(unaligned_ptr_) &&
+              reinterpret_cast<size_t>(ptr) < reinterpret_cast<size_t>(unaligned_ptr_ + N)) {
+            storage_.scratch_.separate[j] = *ptr;
+          } else {
+            storage_.scratch_.separate[j] = DType();
+          }
+        }
+      }
+    }
+  }
+};
+
+/* \brief Class used for vectorized read-only access. */
+template <typename DType, int nvec, bool aligned = false>
+class VectorizedLoader : public VectorizedAccessor<const DType, nvec, aligned> {
+ public:
+  inline __device__ VectorizedLoader(const DType* ptr, const size_t N) :
+    VectorizedAccessor<const DType, nvec, aligned>(ptr, N) {
+  }
+};
+
+/* \brief Class used for vectorized writable access. */
+template <typename DType, int nvec, bool aligned = false>
+class VectorizedStorer : public VectorizedAccessor<DType, nvec, aligned> {
+ public:
+  inline __device__ VectorizedStorer(DType* ptr, const size_t N) :
+    VectorizedAccessor<DType, nvec, aligned>(ptr, N) {
+  }
+
+  /* \brief Store values to the output.
+     \param id Aligned index of the element.
+     \param N size of the tensor.
+  */
+  inline __device__ void store(const size_t id, const size_t N) {
+    if (aligned) {
+      this->aligned_ptr_[id] = this->storage_.scratch_.aligned;
+    } else {
+      if (id > 0 && id < this->n_elems_ - 1) {
+        this->aligned_ptr_[id] = this->storage_.scratch_.aligned;
+      } else {
+#pragma unroll
+        for (int j = 0; j < nvec; ++j) {
+          DType* ptr = reinterpret_cast<DType*>(&(this->aligned_ptr_[id])) + j;
+          if (reinterpret_cast<size_t>(ptr) >= reinterpret_cast<size_t>(this->unaligned_ptr_) &&
+              reinterpret_cast<size_t>(ptr) < reinterpret_cast<size_t>(this->unaligned_ptr_ + N)) {
+            *ptr = this->storage_.scratch_.separate[j];
+          }
+        }
+      }
+    }
+  }
+};
+
+
+constexpr int unary_kernel_threads = 512;
+constexpr float e4m3_max = 448.0;
+constexpr float e5m2_max = 57344.0;
+
+extern __device__ msamp::DeviceSyncer device_syncer;
+
+template <int nvec, bool aligned, typename ComputeType, typename InputType, typename OutputType>
+__launch_bounds__(unary_kernel_threads)
+__global__ void add_to_fp8_kernel(InputType *input,
+                                  OutputType *output,
+                                  ComputeType *scale,
+                                  ComputeType *scale_inv,
+                                  ComputeType *amax,
+                                  const size_t N,
+                                  const size_t num_aligned_elements) {
+  if (threadIdx.x == 0 && blockIdx.x == 0) {
+    *amax = 0;
+  }
+  device_syncer.sync(gridDim.x, -1);
+
+  // input is high precision, output is fp8
+  VectorizedStorer<OutputType, nvec, aligned> output_storer(output, N);
+  VectorizedStorer<InputType, nvec, aligned> input_storer(input, N);
+
+  ComputeType max = 0;
+  ComputeType s = 0;
+  if constexpr (is_fp8<OutputType>::value) {
+    if (scale_inv != nullptr) s = *scale_inv;
+  }
+  const int warp_id = threadIdx.x / THREADS_PER_WARP;
+
+  const size_t M = num_aligned_elements;
+
+  for (size_t tid = blockIdx.x * blockDim.x + threadIdx.x;
+       tid < M;
+       tid += gridDim.x * blockDim.x) {
+    input_storer.load(tid, N);
+    output_storer.load(tid, N);
+#pragma unroll
+    for (int i = 0; i < nvec; ++i) {
+      const InputType val1 = static_cast<InputType>(input_storer.separate()[i]);
+      const ComputeType val2 = static_cast<ComputeType>(output_storer.separate()[i]);
+
+      InputType temp = static_cast<InputType>(val2 * s);
+
+      if constexpr (is_half<InputType>::value) {
+        temp = static_cast<InputType>(__hadd(temp, val1));
+      } else {
+        temp += val1;
+      }
+
+      if constexpr (is_fp8<OutputType>::value) {
+        __builtin_assume(max >= 0);
+        max = fmaxf(fabsf(temp), max);
+      }
+    }
+  }
+
+  if constexpr (is_fp8<OutputType>::value) {
+    /* warp tile amax reduce*/
+    max = transformer_engine::reduce_max<unary_kernel_threads / THREADS_PER_WARP>(max, warp_id);
+
+    if (threadIdx.x == 0 && amax != nullptr) {
+      static_assert(std::is_same<ComputeType, float>::value);
+      transformer_engine::atomicMaxFloat(amax, max);
+    }
+  }
+
+  device_syncer.sync(gridDim.x, -1);
+
+  /* Compute scaling factor, translate the following python code to c++:
+     exp = torch.floor(torch.log2(fp_max / amax)) - margin
+     sf = torch.round(torch.pow(2, torch.abs(exp)))
+     sf = torch.where(amax > 0.0, sf, scale)
+     sf = torch.where(torch.isfinite(amax), sf, scale)
+     sf = torch.where(exp < 0, 1 / sf, sf)
+  */
+  ComputeType amax_value = *amax;
+
+  ComputeType fp_max = std::is_same<OutputType, fp8e4m3>::value ? e4m3_max : e5m2_max;
+
+  ComputeType exp = floorf(log2f(fp_max/(amax_value)));
+  ComputeType sf = roundf(powf(2, fabsf(exp)));
+
+  if (amax_value <= 0 || !isfinite(amax_value)) {
+    sf = *scale;
+  }
+
+  if (exp < 0) {
+    sf = 1 / sf;
+  }
+
+  // using new scaling factor to quantize the input
+  for (size_t tid = blockIdx.x * blockDim.x + threadIdx.x;
+       tid < M;
+       tid += gridDim.x * blockDim.x) {
+    input_storer.load(tid, N);
+    output_storer.load(tid, N);
+#pragma unroll
+    for (int i = 0; i < nvec; ++i) {
+      const InputType val1 = static_cast<InputType>(input_storer.separate()[i]);
+      const ComputeType val2 = static_cast<ComputeType>(output_storer.separate()[i]);
+
+      InputType temp1 = static_cast<InputType>(val2 * s);
+
+      if constexpr (is_half<InputType>::value) {
+        temp1 = static_cast<InputType>(__hadd(temp1, val1));
+      } else {
+        temp1 += val1;
+      }
+      ComputeType temp2 = sf * static_cast<ComputeType>(temp1);
+      output_storer.separate()[i] = static_cast<OutputType>(temp2);
+    }
+    output_storer.store(tid, N);
+  }
+
+  if (threadIdx.x == 0 && blockIdx.x == 0) {
+    *scale = sf;
+    *scale_inv = 1.0 / sf;
+  }
+}
+
+
+namespace {
+
+inline size_t get_num_aligned_elements(const void *ptr, const size_t lead_dim,
+                                       const int nvec, const int size) {
+  size_t ptr_as_number = reinterpret_cast<size_t>(ptr);
+  int alignment = (ptr_as_number % (nvec * size)) / size;
+  return DIVUP(lead_dim + alignment, static_cast<size_t>(nvec));
+}
+
+enum class Alignment {
+  SAME_ALIGNED,    // All tensors aligned
+  SAME_UNALIGNED,  // All tensors have the same misalignment
+  DIFFERENT        // Tensors have different alignment
+};
+
+inline int CalcAlignment(const void *ptr, const int size) {
+  size_t ptr_as_number = reinterpret_cast<size_t>(ptr);
+  return ptr_as_number % size;
+}
+
+/* \brief Check alignment of the inputs and outputs when using vectorized accesses.
+   \param lead_dim Leading dimension of the tensors.
+   \param other_dim The size of the other dimensions of the tensors.
+   \param nvec Length of the vector.
+   \param ptrs Inputs and Outputs to the operator.
+*/
+template <typename... T>
+Alignment CheckAlignment(const size_t lead_dim,
+                         const int nvec,
+                         const T... ptrs
+                         ) {
+  std::vector<int> alignments;
+  alignments.reserve(sizeof...(T));
+
+  // calculate the alignments of all ptrs and store them into alignments
+  (..., alignments.push_back(CalcAlignment(ptrs, sizeof(*ptrs) * nvec)));
+
+  bool all_same = std::all_of(alignments.cbegin(), alignments.cend(),
+                              [alignments](int val) {return val == alignments.front();});
+  if (!all_same) {
+    return Alignment::DIFFERENT;
+  }
+
+  if (alignments.front() == 0 &&
+      lead_dim % nvec == 0) {
+    // all alignment are 0
+    return Alignment::SAME_ALIGNED;
+  } else {
+    return Alignment::SAME_UNALIGNED;
+  }
+}
+
+}  // namespace
+
+template <int nvec, typename InputType, typename OutputType>
+void VectorizedAddToFp8KernelLauncher(InputType *input,
+                                      OutputType *output,
+                                      fp32 *scale,
+                                      fp32 *scale_inv,
+                                      fp32 *amax,
+                                      const size_t N,
+                                      cudaStream_t stream) {
+  if (N != 0) {
+    auto align = CheckAlignment(N, nvec, input, output);
+
+    size_t num_aligned_elements = get_num_aligned_elements(input, N, nvec,
+                                                           sizeof(InputType));
+    constexpr size_t threads = unary_kernel_threads;
+    size_t num_blocks = DIVUP(num_aligned_elements, threads);
+
+    // We use DeviceSyncer to sync the amax value between blocks, the block number should be less than
+    // (SMCount*MaxThreadsPerSM)/unary_kernel_threads, which is 132*2048/512 = 528 on H100 SXM. We set
+    // max_blocks to half of 528 to make sure it works on other H100 GPUs.
+    // constexpr size_t max_blocks = 65535;
+    constexpr size_t max_blocks = 264;
+    num_blocks = std::min(num_blocks, max_blocks);
+
+    switch (align) {
+      case Alignment::SAME_ALIGNED:
+        add_to_fp8_kernel<nvec, true, fp32><<<num_blocks, threads, 0, stream>>>(
+            input, output, scale, scale_inv, amax, N, num_aligned_elements);
+        break;
+      case Alignment::SAME_UNALIGNED:
+        add_to_fp8_kernel<nvec, false, fp32><<<num_blocks, threads, 0, stream>>>(
+            input, output, scale, scale_inv, amax, N, num_aligned_elements);
+        break;
+      case Alignment::DIFFERENT: {
+        // If the pointers are aligned differently we cannot vectorize
+        add_to_fp8_kernel<1, true, fp32><<<num_blocks, threads, 0, stream>>>(
+            input, output, scale, scale_inv, amax, N, num_aligned_elements);
+        break;
+      }
+    }
+  }
+}
+
+}  // namespace msamp
+
+#endif  // MSAMP_VECTORIZED_POINTWISE_H
\ No newline at end of file
diff --git a/msamp/optim/adamw.cu b/msamp/optim/adamw.cu
index fb4c7468..a8f6c222 100644
--- a/msamp/optim/adamw.cu
+++ b/msamp/optim/adamw.cu
@@ -9,8 +9,8 @@
 #include
 #include
 
-#include "common.h"
-#include "utils.cuh"
+#include "../common/include/common.h"
+#include "../common/include/utils.cuh"
 
 using namespace std;
 using namespace torch;
diff --git a/tests/operators/test_arithmetic.py b/tests/operators/test_arithmetic.py
new file mode 100644
index 00000000..23386991
--- /dev/null
+++ b/tests/operators/test_arithmetic.py
@@ -0,0 +1,42 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+
+"""Tests for arithmetic module."""
+
+import itertools
+import unittest
+
+import torch
+
+from tests.helper import decorator
+from msamp.common.dtype import Dtypes
+from msamp.operators.arithmetic import Arithmetic
+
+
+class ArithmeticTestCase(unittest.TestCase):
+    """A class for Arithmetic test cases."""
+    def _check_scaling_tensor(self, scaling_tensor1, scaling_tensor2):
+        self.assertTrue(torch.all(torch.eq(scaling_tensor1.value, scaling_tensor2.value)))
+        self.assertTrue(torch.all(torch.eq(scaling_tensor1.meta.scale, scaling_tensor2.meta.scale)))
+        self.assertTrue(torch.all(torch.eq(scaling_tensor1.meta.scale_inv, scaling_tensor2.meta.scale_inv)))
+        self.assertTrue(torch.all(torch.eq(scaling_tensor1.meta.amax, scaling_tensor2.meta.amax)))
+
+    @decorator.cuda_test
+    def test_add_to_fp8(self):
+        """Test the function Arithmetic.add_to_fp8()."""
+        torch.manual_seed(100)
+        sizes = list(range(1024, 8193, 1024))
+        dtypes = [torch.float16, torch.bfloat16, torch.float32]
+        qtypes = [Dtypes.kfloat8_e4m3, Dtypes.kfloat8_e5m2]
+        for i, j, dtype, qtype in itertools.product(sizes, sizes, dtypes, qtypes):
+            size = (i, j)
+            input1 = torch.rand(size, dtype=dtype, device='cuda')
+            scaling_tensor1 = input1.cast(qtype)
+            scaling_tensor2 = input1.cast(qtype)
+
+            for i in range(10):
+                input2 = torch.rand(size, dtype=dtype, device='cuda')
+                meta = scaling_tensor1.meta
+                Arithmetic.add_to_fp8(scaling_tensor1.value, meta, input2)
+                scaling_tensor2.copy_((scaling_tensor2.to(dtype) + input2).cast(qtype, meta=scaling_tensor2.meta))
+                self._check_scaling_tensor(scaling_tensor1, scaling_tensor2)
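
The new operator can also be exercised on its own, outside the Megatron gradient hook. A minimal usage sketch, not taken from the patch: it assumes a CUDA device and the ScalingTensor `.cast()` support that test_arithmetic.py above relies on, and the shapes and dtypes are illustrative only.

    import torch
    from msamp.common.dtype import Dtypes
    from msamp.operators.arithmetic import Arithmetic

    # An FP8 accumulator (like param.main_grad) and a new higher-precision gradient.
    main_grad = torch.rand((16, 16), dtype=torch.float16, device='cuda').cast(Dtypes.kfloat8_e4m3)
    grad = torch.rand((16, 16), dtype=torch.float16, device='cuda')

    # In-place update: main_grad.value is overwritten with the quantized sum main_grad + grad,
    # and main_grad.meta (amax, scale, scale_inv) is refreshed by the CUDA kernel.
    Arithmetic.add_to_fp8(main_grad.value, main_grad.meta, grad)

This is the call that param_hook in msamp/megatron/distributed.py now makes, replacing the earlier dequantize-add-requantize round trip through Python.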
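The scale update performed inside add_to_fp8_kernel after the amax reduction follows the torch pseudocode quoted in its comment, with margin fixed to 0 in the kernel. The sketch below restates that computation in Python for reference; compute_scaling_factor is a hypothetical helper name and the example values are made up.

    import torch

    def compute_scaling_factor(amax, scale, fp_max, margin=0):
        # exp < 0 means amax already exceeds fp_max, so the scale must shrink rather than grow.
        exp = torch.floor(torch.log2(fp_max / amax)) - margin
        sf = torch.round(torch.pow(2, torch.abs(exp)))
        sf = torch.where(amax > 0.0, sf, scale)            # keep the old scale when amax == 0
        sf = torch.where(torch.isfinite(amax), sf, scale)  # ...or when amax is inf/nan
        sf = torch.where(exp < 0, 1 / sf, sf)
        return sf

    # E4M3 example: fp_max = 448.0 and amax = 3.5 give exp = 7 and sf = 128.0, so values are
    # multiplied by 128 before the FP8 cast and scale_inv = 1/128 undoes it on dequantization.
    print(compute_scaling_factor(torch.tensor(3.5), torch.tensor(1.0), 448.0))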