Skip to content

Commit

Permalink
feat(cell): Implement vectorized access for global to register copy. (#…
Browse files Browse the repository at this point in the history
…30)

* Implement vectorize copy from global to register.

* Add Vectorize struct brief.

* fix empty line.

* fix inline function define.
  • Loading branch information
KuangjuX authored Dec 31, 2024
1 parent 981a39a commit dc7314d
Show file tree
Hide file tree
Showing 4 changed files with 122 additions and 32 deletions.
20 changes: 20 additions & 0 deletions .vscode/settings.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
{
"gotoSymbolStack.currentStackPosition": 0,
"gotoSymbolStack.maxStackPosition": 0,
"gotoSymbolStack.filePositionInfo": [],
"files.associations": {
"*.tcc": "cpp",
"optional": "cpp",
"ratio": "cpp",
"system_error": "cpp",
"array": "cpp",
"functional": "cpp",
"tuple": "cpp",
"type_traits": "cpp",
"utility": "cpp",
"variant": "cpp",
"compare": "cpp",
"concepts": "cpp",
"random": "cpp"
}
}
53 changes: 21 additions & 32 deletions include/cell/copy/global_to_register.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
#pragma once

#include "cell/copy/constants.hpp"
#include "cell/copy/vectorize.hpp"
#include "cell/copy/warp.hpp"
#include "traits/base.hpp"
#include "types/mod.hpp"
Expand All @@ -29,14 +30,11 @@ struct GlobalToRegMatLoader<Global_, BaseTile_, tl::Layout::kRowMajor> {
static constexpr int kStride = Global::kRowStride;

DEVICE void operator()(const DType* src, BaseTile& dst) {
dst(0, 0) = src[0 * kStride + 0];
dst(0, 1) = src[0 * kStride + 1];
dst(1, 0) = src[0 * kStride + 8];
dst(1, 1) = src[0 * kStride + 9];
dst(0, 2) = src[8 * kStride + 0];
dst(0, 3) = src[8 * kStride + 1];
dst(1, 2) = src[8 * kStride + 8];
dst(1, 3) = src[8 * kStride + 9];
Vectorize<DType, 2> vectorize;
vectorize.copy(src + 0 * kStride + 0, &dst(0, 0));
vectorize.copy(src + 0 * kStride + 8, &dst(1, 0));
vectorize.copy(src + 8 * kStride + 0, &dst(0, 2));
vectorize.copy(src + 8 * kStride + 8, &dst(1, 2));
}
};

Expand All @@ -49,14 +47,11 @@ struct GlobalToRegMatLoader<Global_, BaseTile_, tl::Layout::kColMajor> {
static constexpr int kStride = Global::kColStride;

DEVICE void operator()(const DType* src, BaseTile& dst) {
dst(0, 0) = src[0 * kStride + 0];
dst(1, 0) = src[0 * kStride + 1];
dst(0, 1) = src[0 * kStride + 8];
dst(1, 1) = src[0 * kStride + 9];
dst(2, 0) = src[8 * kStride + 0];
dst(3, 0) = src[8 * kStride + 1];
dst(2, 1) = src[8 * kStride + 8];
dst(3, 1) = src[8 * kStride + 9];
Vectorize<DType, 2> vectorize;
vectorize.copy(src + 0 * kStride + 0, &dst(0, 0));
vectorize.copy(src + 0 * kStride + 8, &dst(0, 1));
vectorize.copy(src + 8 * kStride + 0, &dst(2, 0));
vectorize.copy(src + 8 * kStride + 8, &dst(2, 1));
}
};

Expand Down Expand Up @@ -84,14 +79,11 @@ struct RegToGlobalMatStorer<Global_, BaseTile_, tl::Layout::kRowMajor> {
static constexpr int kStride = Global::kRowStride;

DEVICE void operator()(const BaseTile& src, DType* dst) {
dst[0 * kStride + 0] = src(0, 0);
dst[0 * kStride + 1] = src(0, 1);
dst[0 * kStride + 8] = src(1, 0);
dst[0 * kStride + 9] = src(1, 1);
dst[8 * kStride + 0] = src(0, 2);
dst[8 * kStride + 1] = src(0, 3);
dst[8 * kStride + 8] = src(1, 2);
dst[8 * kStride + 9] = src(1, 3);
Vectorize<DType, 2> vectorize;
vectorize.copy(&src(0, 0), dst + 0 * kStride + 0);
vectorize.copy(&src(1, 0), dst + 0 * kStride + 8);
vectorize.copy(&src(0, 2), dst + 8 * kStride + 0);
vectorize.copy(&src(1, 2), dst + 8 * kStride + 8);
}
};

Expand All @@ -104,14 +96,11 @@ struct RegToGlobalMatStorer<Global_, BaseTile_, tl::Layout::kColMajor> {
static constexpr int kStride = Global::kColStride;

DEVICE void operator()(const BaseTile& src, DType* dst) {
dst[0 * kStride + 0] = src(0, 0);
dst[0 * kStride + 1] = src(1, 0);
dst[0 * kStride + 8] = src(0, 1);
dst[0 * kStride + 9] = src(1, 1);
dst[8 * kStride + 0] = src(2, 0);
dst[8 * kStride + 1] = src(3, 0);
dst[8 * kStride + 8] = src(2, 1);
dst[8 * kStride + 9] = src(3, 1);
Vectorize<DType, 2> vectorize;
vectorize.copy(&src(0, 0), dst + 0 * kStride + 0);
vectorize.copy(&src(0, 1), dst + 0 * kStride + 8);
vectorize.copy(&src(2, 0), dst + 8 * kStride + 0);
vectorize.copy(&src(2, 1), dst + 8 * kStride + 8);
}
};

Expand Down
1 change: 1 addition & 0 deletions include/cell/copy/mod.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,5 @@
#include "cell/copy/global_to_shared.hpp"
#include "cell/copy/register.hpp"
#include "cell/copy/shared_to_register.hpp"
#include "cell/copy/vectorize.hpp"
#include "cell/copy/warp.hpp"
80 changes: 80 additions & 0 deletions include/cell/copy/vectorize.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once

#include "config.hpp"
#include "cuda_utils.hpp"

#include <cutlass/half.h>

namespace tilefusion::cell::copy {

/**
* @brief Vectorize a data type.
*
* @tparam Element Data type.
* @tparam kVecNums Number of vecotrized elements.
*/
template <typename Element, const int kVecNums>
struct Vectorize {
using UnVecType = Element;
using VecType = Element;
static constexpr int vectorize_nums = kVecNums;

/**
* @brief Copy data from unvectorized to vectorized.
*
* @param src Source data.
* @param dst Destination data.
*/
DEVICE void copy(const UnVecType* src, UnVecType* dst) {
const VecType* src_vec = reinterpret_cast<const VecType*>(src);
VecType* dst_vec = reinterpret_cast<VecType*>(dst);
*dst_vec = *src_vec;
}
};

template <>
struct Vectorize<__half, 2> {
using UnVecType = __half;
using VecType = __half2;
static constexpr int vectorize_nums = 2;
static constexpr int vectorize_bits = 32;

DEVICE void copy(const __half* src, __half* dst) {
const __half2* src_vec = reinterpret_cast<const __half2*>(src);
__half2* dst_vec = reinterpret_cast<__half2*>(dst);
*dst_vec = *src_vec;
}
};

template <>
struct Vectorize<cutlass::half_t, 2> {
using UnVecType = cutlass::half_t;
using VecType = __half2;
static constexpr int vectorize_nums = 2;
static constexpr int vectorize_bits = 32;

DEVICE void copy(const cutlass::half_t* src, cutlass::half_t* dst) {
const __half2* src_vec = reinterpret_cast<const __half2*>(src);
__half2* dst_vec = reinterpret_cast<__half2*>(dst);
*dst_vec = *src_vec;
}
};

template <>
struct Vectorize<float, 2> {
using UnVecType = float;
using VecType = float2;
static constexpr int vectorize_nums = 2;
static constexpr int vectorize_bits = 64;

DEVICE void copy(const float* src, float* dst) {
const float2* src_vec = reinterpret_cast<const float2*>(src);
float2* dst_vec = reinterpret_cast<float2*>(dst);
*dst_vec = *src_vec;
}
};

} // namespace tilefusion::cell::copy

0 comments on commit dc7314d

Please sign in to comment.