Merge branch 'main' into cast_to_scaling_fp32_or_scaling_bf16
wkcn authored Nov 28, 2023
2 parents c1aca79 + 4480ffa commit 6d57944
Showing 13 changed files with 744 additions and 7 deletions.
1 change: 1 addition & 0 deletions Makefile
@@ -23,4 +23,5 @@ lint: cpplint mdlint

postinstall:
cd msamp/operators/dist_op && bash build.sh && cd -
cd msamp/operators/arithmetic && pip install -v . && cd -
cd msamp/optim && pip install -v . && cd -
60 changes: 59 additions & 1 deletion msamp/optim/common.h → msamp/common/include/common.h
@@ -1,6 +1,9 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.

#ifndef MSAMP_COMMON_H_
#define MSAMP_COMMON_H_

#include <cublasLt.h>
#include <cublas_v2.h>
#include <cuda_bf16.h>
@@ -20,6 +23,11 @@ using bf16 = nv_bfloat16;
using fp8e4m3 = __nv_fp8_e4m3;
using fp8e5m2 = __nv_fp8_e5m2;

template <typename T>
constexpr T DIVUP(const T &x, const T &y) {
return (((x) + ((y)-1)) / (y));
}

#define TORCH_DTYPE_SWITCH(dtype, type, ...) \
switch (dtype) { \
case torch::kUInt8: { \
@@ -46,6 +54,36 @@ using fp8e5m2 = __nv_fp8_e5m2;
throw "Unexcepted data type"; \
}

#define SELECT_FP8_TYPE(is_e4m3, type, ...) \
if (is_e4m3){ \
using type = fp8e4m3; \
{ __VA_ARGS__ } \
} \
else { \
using type = fp8e5m2; \
{ __VA_ARGS__ } \
}


#define TORCH_DTYPE_SWITCH_INPUT(dtype, type, ...) \
switch (dtype) { \
case torch::kFloat32: { \
using type = float; \
{ __VA_ARGS__ } \
} break; \
case torch::kBFloat16: { \
using type = bf16; \
{ __VA_ARGS__ } \
} break; \
case torch::kFloat16: { \
using type = fp16; \
{ __VA_ARGS__ } \
} break; \
default: \
throw "Unexcepted data type"; \
}


const int HIP_MAX_GRID_NUM = 65535;
const int HIP_MAX_NUM_THREADS = 512;

@@ -68,4 +106,24 @@ template <> __host__ __device__ bf16 cast_dtype(const float value) { return __fl

template <> __host__ __device__ float cast_dtype(const fp16 value) { return __half2float(value); }

template <> __host__ __device__ float cast_dtype(const bf16 value) { return __bfloat162float(value); }

template <typename T>
struct is_fp8 : std::false_type {};

template <>
struct is_fp8<fp8e4m3> : std::true_type {};

template <>
struct is_fp8<fp8e5m2> : std::true_type {};

template <typename T>
struct is_half : std::false_type {};

template <>
struct is_half<fp16> : std::true_type {};

template <>
struct is_half<bf16> : std::true_type {};

#endif // MSAMP_COMMON_H_
64 changes: 64 additions & 0 deletions msamp/common/include/concurrency.h
@@ -0,0 +1,64 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
// The file is from https://github.com/microsoft/mscclpp/blob/main/include/mscclpp/concurrency.hpp.

#ifndef MSAMP_CONCURRENCY_H_
#define MSAMP_CONCURRENCY_H_

#include "poll.h"

namespace msamp {

/// A device-wide barrier.
struct DeviceSyncer {
public:
/// Construct a new DeviceSyncer object.
DeviceSyncer() = default;

/// Destroy the DeviceSyncer object.
~DeviceSyncer() = default;

#ifdef __CUDACC__
/// Synchronize all threads inside a kernel. Guarantee that all previous work of all threads in cooperating blocks is
/// finished.
/// @param blockNum The number of blocks that will synchronize.
/// @param maxSpinCount The maximum number of spin counts before asserting. Never assert if negative.
__forceinline__ __device__ void sync(int blockNum, int64_t maxSpinCount = 100000000) {
unsigned int maxOldCnt = blockNum - 1;
__syncthreads();
if (blockNum == 1) return;
if (threadIdx.x == 0) {
// Need a `__threadfence()` before flipping `flag_`.
__threadfence();
int tmp = isIncFlag_ ^ 1;
if (tmp) {
if (atomicInc(&count_, maxOldCnt) == maxOldCnt) {
flag_ = 1;
}
POLL_MAYBE_JAILBREAK(!flag_, maxSpinCount);
} else {
if (atomicInc(&count_, maxOldCnt) == maxOldCnt) {
flag_ = 0;
}
POLL_MAYBE_JAILBREAK(flag_, maxSpinCount);
}
isIncFlag_ = tmp;
}
// We need sync here because only a single thread is checking whether
// the flag is flipped.
__syncthreads();
}
#endif

private:
/// The flag to indicate whether the barrier is reached by the latest thread.
volatile int flag_;
/// The counter of synchronized blocks.
unsigned int count_;
/// The flag to indicate whether to increase or decrease @ref flag_.
int isIncFlag_;
};

} // namespace msamp

#endif // MSAMP_CONCURRENCY_H_
56 changes: 56 additions & 0 deletions msamp/common/include/poll.h
@@ -0,0 +1,56 @@

// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.
// The file is from https://github.com/microsoft/mscclpp/blob/main/include/mscclpp/poll.hpp.

#ifndef MSAMP_POLL_H_
#define MSAMP_POLL_H_

#include <cstdint>

extern "C" __device__ void __assert_fail(const char *__assertion, const char *__file, unsigned int __line,
const char *__function) __THROW;

// If a spin is stuck, escape from it and set status to 1.
#define POLL_MAYBE_JAILBREAK_ESCAPE(__cond, __max_spin_cnt, __status) \
do { \
int64_t __spin_cnt = 0; \
__status = 0; \
while (__cond) { \
if (__max_spin_cnt >= 0 && __spin_cnt++ == __max_spin_cnt) { \
__status = 1; \
break; \
} \
} \
} while (0);

// If a spin is stuck, fail with an assertion on the stuck condition.
#define POLL_MAYBE_JAILBREAK(__cond, __max_spin_cnt) \
do { \
int64_t __spin_cnt = 0; \
while (__cond) { \
if (__max_spin_cnt >= 0 && __spin_cnt++ == __max_spin_cnt) { \
__assert_fail(#__cond, __FILE__, __LINE__, __PRETTY_FUNCTION__); \
} \
} \
} while (0);

// Same as POLL_MAYBE_JAILBREAK except that __cond1 is checked before __cond2;
// this is especially useful when __cond1 is faster to check.
#define OR_POLL_MAYBE_JAILBREAK(__cond1, __cond2, __max_spin_cnt) \
do { \
int64_t __spin_cnt = 0; \
while (true) { \
if (!(__cond1)) { \
break; \
} else if (!(__cond2)) { \
break; \
} \
if (__max_spin_cnt >= 0 && __spin_cnt++ == __max_spin_cnt) { \
__assert_fail(#__cond1 #__cond2, __FILE__, __LINE__, __PRETTY_FUNCTION__); \
} \
} \
} while (0);


#endif // MSAMP_POLL_H_
File renamed without changes.
6 changes: 2 additions & 4 deletions msamp/megatron/distributed.py
@@ -11,6 +11,7 @@

from msamp.common.dtype import Dtypes
from msamp.common.tensor import ScalingMeta, ScalingTensor
from msamp.operators.arithmetic import Arithmetic


class FP8DistributedDataParallel(DistributedDataParallelBase):
@@ -177,10 +178,7 @@ def _fp8_make_param_hook(self, param):
def param_hook(*unused):
# Add the gradient to the buffer.
if param.grad is not None:
param.main_grad.copy_(
(param.main_grad.to(param.grad.dtype) +
param.grad).cast(self.wgrad_qtype, meta=param.main_grad.meta)
)
Arithmetic.add_to_fp8(param.main_grad.value, param.main_grad.meta, param.grad)
# Now we can deallocate grad memory.
param.grad = None

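For context, a minimal sketch (illustrative only, not part of the commit) of what the hook change above amounts to; `main_grad`, `grad`, and `wgrad_qtype` stand in for the attributes used in param_hook:

    from msamp.operators.arithmetic import Arithmetic

    def accumulate_wgrad(main_grad, grad, wgrad_qtype):
        """Illustrative sketch: accumulate a high-precision grad into an FP8 ScalingTensor in place."""
        # Old path: dequantize to grad.dtype, add, re-quantize, and copy back.
        #   main_grad.copy_((main_grad.to(grad.dtype) + grad).cast(wgrad_qtype, meta=main_grad.meta))
        # New path: one fused CUDA kernel updates the FP8 buffer and its scaling meta directly.
        Arithmetic.add_to_fp8(main_grad.value, main_grad.meta, grad)
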
8 changes: 8 additions & 0 deletions msamp/operators/arithmetic/__init__.py
@@ -0,0 +1,8 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

"""Exposes the interface of MS-AMP arithmetic module."""

from msamp.operators.arithmetic.arithmetic import Arithmetic

__all__ = ['Arithmetic']
43 changes: 43 additions & 0 deletions msamp/operators/arithmetic/arithmetic.cu
@@ -0,0 +1,43 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.

#include <torch/extension.h>
#include <c10/cuda/CUDAStream.h>

#include "../../common/include/common.h"
#include "../../common/include/utils.cuh"
#include "../../common/include/concurrency.h"
#include "vectorized_pointwise.h"

namespace msamp {
void add_to_fp8(at::Tensor fp8_tensor,
at::Tensor scale,
at::Tensor scale_inv,
at::Tensor amax,
const at::Tensor& other,
bool is_e4m3) {
const size_t N = other.numel();
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
TORCH_DTYPE_SWITCH_INPUT(other.scalar_type(), IType,
SELECT_FP8_TYPE(is_e4m3, OType,

constexpr int nvec = 32 / sizeof(IType);

VectorizedAddToFp8KernelLauncher<nvec>(
reinterpret_cast<IType*>(other.data_ptr()),
reinterpret_cast<OType*>(fp8_tensor.data_ptr()),
reinterpret_cast<fp32*>(scale.data_ptr()),
reinterpret_cast<fp32*>(scale_inv.data_ptr()),
reinterpret_cast<fp32*>(amax.data_ptr()),
N,
stream
);
);
);
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("add_to_fp8", &add_to_fp8, "Add to fp8");
}

} // namespace msamp
35 changes: 35 additions & 0 deletions msamp/operators/arithmetic/arithmetic.py
@@ -0,0 +1,35 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

"""FP8 Arithmetic module."""

import torch

import msamp_arithmetic
from msamp.common.dtype import Dtypes


class Arithmetic:
"""Arithmetic operator for FP8 tensor."""
@staticmethod
def add_to_fp8(fp8_tensor, meta, other):
"""Add high presicon tensor to fp8_tensor in-place.
Args:
fp8_tensor (torch.Tensor): fp8 tensor to add to.
meta (ScalingMeta): scaling meta data of fp8_tensor.
other (torch.Tensor): high precision tensor to add.
"""
if not (fp8_tensor.is_cuda and fp8_tensor.is_contiguous()):
raise ValueError('The fp8 tensor is not in cuda memory or contiguous.')
if not (other.is_cuda and other.is_contiguous()):
raise ValueError('The other tensor is not in cuda memory or contiguous.')
if not (fp8_tensor.dtype == torch.uint8 or fp8_tensor.dtype == torch.int8):
raise ValueError('The fp8 tensor is not in uint8 or int8.')

if not (meta.qtype == Dtypes.kfloat8_e4m3 or meta.qtype == Dtypes.kfloat8_e5m2):
raise ValueError('The fp8 tensor is not in e4m3 or e5m2 format.')

is_e4m3 = meta.qtype == Dtypes.kfloat8_e4m3

msamp_arithmetic.add_to_fp8(fp8_tensor, meta.scale, meta.scale_inv, meta.amax[0], other, is_e4m3)
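
A hypothetical usage sketch (not part of the commit): it assumes MS-AMP's patched torch.Tensor.cast returns a ScalingTensor whose .value is the uint8 FP8 buffer and .meta its ScalingMeta, as used in msamp/megatron/distributed.py above.

    import torch
    from msamp.common.dtype import Dtypes
    from msamp.operators.arithmetic import Arithmetic

    x = torch.randn(1024, device='cuda')
    fp8_x = x.cast(Dtypes.kfloat8_e4m3)        # assumed cast API: returns a ScalingTensor
    other = torch.randn(1024, device='cuda')   # high-precision tensor to accumulate

    # In place: fp8_x now holds (x + other) re-quantized to e4m3, with amax/scale updated.
    Arithmetic.add_to_fp8(fp8_x.value, fp8_x.meta, other)
    print((fp8_x.to(torch.float32) - (x + other)).abs().max())  # small quantization error expected
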
28 changes: 28 additions & 0 deletions msamp/operators/arithmetic/setup.py
@@ -0,0 +1,28 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

"""The setuptools based setup module."""

from setuptools import setup
from torch.utils import cpp_extension

ext_t = cpp_extension.CUDAExtension
ext_fnames = ['arithmetic.cu']
define_macros = []
nvcc_flags = [
'-O3', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_BFLOAT16_CONVERSIONS__', '--expt-relaxed-constexpr',
'--expt-extended-lambda', '--use_fast_math'
]

extra_compile_args = dict(cxx=['-fopenmp', '-O3'], nvcc=nvcc_flags)

define_macros.append(('WITH_CUDA', None))

setup(
name='msamp_arithmetic',
version='0.0.1',
ext_modules=[
ext_t('msamp_arithmetic', ext_fnames, define_macros=define_macros, extra_compile_args=extra_compile_args)
],
cmdclass={'build_ext': cpp_extension.BuildExtension}
)