A FasterRMSNorm implementation (based on FasterLayerNorm) #1688

Open · wants to merge 2 commits into master
4 changes: 4 additions & 0 deletions apex/contrib/csrc/layer_norm/ln.h
@@ -37,6 +37,7 @@ struct ParamsBase {
, gamma(nullptr)
, workspace(nullptr)
, barrier(nullptr)
, is_rms_only(false)
{
}

@@ -59,6 +60,9 @@ struct ParamsBase {
// Multi-CTA sync barriers in gmem.
int *barrier;

// Indicates whether this is RMSNorm rather than LayerNorm.
bool is_rms_only;

};

////////////////////////////////////////////////////////////////////////////////////////////////////
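Editor's note, not part of the diff: the is_rms_only flag lets the existing LayerNorm parameter structs and kernels carry RMSNorm as well, because RMSNorm is LayerNorm with the mean-centering and bias terms dropped:

    LayerNorm: y = gamma * (x - mean(x)) / sqrt(var(x) + eps) + beta
    RMSNorm:   y = gamma * x / sqrt(mean(x^2) + eps)

With the flag set, the mu and beta code paths are simply unused, which is why the host-side wrappers below comment them out instead of introducing a separate parameter struct.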
165 changes: 164 additions & 1 deletion apex/contrib/csrc/layer_norm/ln_api.cpp
@@ -239,8 +239,171 @@ std::vector<at::Tensor> ln_bwd(const at::Tensor &dz, // BxSxhidden_size

////////////////////////////////////////////////////////////////////////////////////////////////////

std::vector<at::Tensor> rmsnorm_fwd(const at::Tensor &x, // BxSxhidden_size
const at::Tensor &gamma, // hidden_size
const float epsilon
) {
auto itype = x.scalar_type();
auto wtype = gamma.scalar_type();
auto otype = wtype;
auto ctype = torch::kFloat32;

// TORCH_CHECK(beta.scalar_type() == wtype);

TORCH_CHECK(x.is_cuda())
TORCH_CHECK(gamma.is_cuda())
// TORCH_CHECK(beta.is_cuda())

TORCH_CHECK(x.is_contiguous());
auto sizes = x.sizes();
TORCH_CHECK(sizes.size() == 2);

const int rows = sizes[0];
const int cols = sizes[1];
auto hidden_size = gamma.numel();

// TORCH_CHECK(gamma.sizes() == beta.sizes());
TORCH_CHECK(hidden_size == cols);

TORCH_CHECK(epsilon >= 0.f);

auto opts = x.options();

auto z = torch::empty(sizes, opts.dtype(otype));

// auto mu = torch::empty({ rows }, opts.dtype(ctype));
auto rsigma = torch::empty({ rows }, opts.dtype(ctype));

layer_norm::LaunchParams<layer_norm::FwdParams> launch_params;

launch_params.props = at::cuda::getCurrentDeviceProperties();
launch_params.stream = at::cuda::getCurrentCUDAStream().stream();

// Request the kernel launcher.
auto launcher = get_fwd_launcher(wtype, itype, otype, ctype, hidden_size);

// Query the kernel-specific launch parameters.
launcher(launch_params, true);

at::Tensor workspace, barrier;

// Set the kernel runtime parameters.
layer_norm::FwdParams &params = launch_params.params;
params.rows = rows;
params.cols = cols;
params.x = x.data_ptr();
// params.mu = mu.data_ptr();
params.rs = rsigma.data_ptr();
params.gamma = gamma.data_ptr();
params.z = z.data_ptr();
params.epsilon = epsilon;
params.is_rms_only = true;

if( launch_params.barrier_size > 0 ) {
auto options = x.options();
barrier = torch::zeros(launch_params.barrier_size, options.dtype(torch::kInt32));
workspace = torch::empty(launch_params.workspace_bytes, options.dtype(torch::kChar));
params.workspace = workspace.data_ptr();
params.barrier = barrier.data_ptr<int>();
}

// Launch the kernel.
launcher(launch_params, false);

return { z, rsigma };
}
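Editor's sketch, not part of this PR: assuming the extension above is built and rmsnorm_fwd is visible from the calling translation unit, the forward output can be sanity-checked against a plain libtorch reference. The function name check_rmsnorm_fwd and the shapes below are illustrative only.

// Illustrative check only; not part of the PR.
#include <torch/torch.h>
#include <iostream>
#include <vector>

// Assumed to be provided by the extension above (signature copied from ln_api.cpp).
std::vector<at::Tensor> rmsnorm_fwd(const at::Tensor &x,
                                    const at::Tensor &gamma,
                                    const float epsilon);

void check_rmsnorm_fwd() {
    const float eps = 1e-5f;
    auto opts = torch::dtype(torch::kFloat16).device(torch::kCUDA);
    auto x = torch::randn({8, 1024}, opts);      // rows x hidden_size
    auto gamma = torch::randn({1024}, opts);

    auto out = rmsnorm_fwd(x, gamma, eps);       // returns { z, rsigma }

    // Reference in FP32: y = gamma * x / sqrt(mean(x^2) + eps).
    auto xf = x.to(torch::kFloat32);
    auto inv_rms = ((xf * xf).mean({-1}, /*keepdim=*/true) + eps).rsqrt();
    auto ref = xf * inv_rms * gamma.to(torch::kFloat32);

    std::cout << "max |z - ref| = "
              << (out[0].to(torch::kFloat32) - ref).abs().max().item<float>()
              << std::endl;
}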

////////////////////////////////////////////////////////////////////////////////////////////////////

std::vector<at::Tensor> rmsnorm_bwd(const at::Tensor &dz, // BxSxhidden_size
const at::Tensor &x, // BxSxhidden_size
const at::Tensor &rsigma, // BxS, FP32!
const at::Tensor &gamma // hidden_size
) {

auto itype = x.scalar_type();
auto wtype = gamma.scalar_type();
auto otype = wtype;
auto ctype = torch::kFloat32;

TORCH_CHECK(dz.dtype() == otype);
// TORCH_CHECK(mu.dtype() == ctype);
TORCH_CHECK(rsigma.dtype() == ctype);

TORCH_CHECK(x.is_cuda());
TORCH_CHECK(dz.is_cuda());
// TORCH_CHECK(mu.is_cuda());
TORCH_CHECK(rsigma.is_cuda());
TORCH_CHECK(gamma.is_cuda());

TORCH_CHECK(x.is_contiguous());
TORCH_CHECK(dz.is_contiguous());

auto sizes = x.sizes();
TORCH_CHECK(sizes.size() == 2);
TORCH_CHECK(dz.sizes() == sizes);
auto rows = sizes[0];
auto cols = sizes[1];

auto hidden_size = gamma.numel();

// TORCH_CHECK(mu.numel() == rows);
// TORCH_CHECK(mu.sizes() == rsigma.sizes());

TORCH_CHECK(gamma.numel() == cols);

auto options = x.options();

auto dx = torch::empty_like(x);
auto dgamma = torch::empty_like(gamma);
// auto dbeta = torch::empty_like(gamma);

layer_norm::LaunchParams<layer_norm::BwdParams> launch_params;
launch_params.stream = at::cuda::getCurrentCUDAStream().stream();
launch_params.props = at::cuda::getCurrentDeviceProperties();

auto launcher = get_bwd_launcher(wtype, itype, otype, ctype, hidden_size);

launcher(launch_params, true);

auto dgamma_part = torch::empty({ launch_params.params.ctas_per_col, hidden_size }, options.dtype(ctype));
// auto dbeta_part = torch::empty({ launch_params.params.ctas_per_col, hidden_size }, options.dtype(ctype));
at::Tensor workspace, barrier;

layer_norm::BwdParams &params = launch_params.params;
params.rows = rows;
params.cols = cols;
params.x = x.data_ptr();
// params.mu = mu.data_ptr();
params.rs = rsigma.data_ptr();
params.gamma = gamma.data_ptr();
params.dz = dz.data_ptr();
params.dx = dx.data_ptr();
// params.dbeta = dbeta.data_ptr();
params.dgamma = dgamma.data_ptr();
// params.dbeta_part = dbeta_part.data_ptr();
params.dgamma_part = dgamma_part.data_ptr();
params.is_rms_only = true;

if( launch_params.barrier_size > 0 ) {
// TODO Any way to avoid this?
barrier = torch::zeros(launch_params.barrier_size, options.dtype(torch::kInt32));
workspace = torch::empty(launch_params.workspace_bytes, options.dtype(torch::kChar));
params.workspace = workspace.data_ptr();
params.barrier = barrier.data_ptr<int>();
}

launcher(launch_params, false);

return { dx, dgamma, dgamma_part };
}
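Continuing the sketch above (editor's note, still not part of this PR), the backward outputs { dx, dgamma, dgamma_part } can be compared against autograd applied to the same reference formula, reusing rsigma from the forward call as the saved statistic:

// Illustrative check only; not part of the PR. Continues the sketch after rmsnorm_fwd.
std::vector<at::Tensor> rmsnorm_bwd(const at::Tensor &dz, const at::Tensor &x,
                                    const at::Tensor &rsigma, const at::Tensor &gamma);

void check_rmsnorm_bwd() {
    const float eps = 1e-5f;
    auto opts = torch::dtype(torch::kFloat32).device(torch::kCUDA);
    auto x = torch::randn({8, 1024}, opts).requires_grad_(true);
    auto gamma = torch::randn({1024}, opts).requires_grad_(true);
    auto dz = torch::randn({8, 1024}, opts);

    // Reference gradients via autograd on y = gamma * x / sqrt(mean(x^2) + eps).
    auto inv_rms = ((x * x).mean({-1}, /*keepdim=*/true) + eps).rsqrt();
    (x * inv_rms * gamma).backward(dz);

    // Kernel path: forward once to obtain rsigma, then backward.
    auto fwd = rmsnorm_fwd(x.detach(), gamma.detach(), eps);         // { z, rsigma }
    auto bwd = rmsnorm_bwd(dz, x.detach(), fwd[1], gamma.detach());  // { dx, dgamma, dgamma_part }

    std::cout << "max |dx - dx_ref|         = " << (bwd[0] - x.grad()).abs().max().item<float>() << "\n"
              << "max |dgamma - dgamma_ref| = " << (bwd[1] - gamma.grad()).abs().max().item<float>() << std::endl;
}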

////////////////////////////////////////////////////////////////////////////////////////////////////

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.doc() = "CUDA LayerNorm";
m.doc() = "CUDA LayerNorm & RMSNorm";
m.def("ln_fwd", &ln_fwd, "Run LayerNorm forward kernel");
m.def("ln_bwd", &ln_bwd, "Run LayerNorm backward kernel");
m.def("rmsnorm_fwd", &rmsnorm_fwd, "Run RMSNorm forward kernel");
m.def("rmsnorm_bwd", &rmsnorm_bwd, "Run RMSNorm backward kernel");
}
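Editor's note: as registered above, rmsnorm_fwd and rmsnorm_bwd sit alongside ln_fwd and ln_bwd in the same extension module. A Python-side autograd wrapper analogous to the existing fast LayerNorm one would presumably call rmsnorm_fwd with (x, gamma, epsilon), save rsigma from the forward, and call rmsnorm_bwd with (dz, x, rsigma, gamma), keeping only dx and dgamma from its return value.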