From dc7314d76e0061ee3b753f109ba93ae30f5fb7aa Mon Sep 17 00:00:00 2001 From: ChengXiang Qi <18630816527@163.com> Date: Tue, 31 Dec 2024 15:55:41 +0800 Subject: [PATCH] feat(cell): Implement vectorized access for global to register copy. (#30) * Implement vectorize copy from global to register. * Add Vectorize struct brief. * fix empty line. * fix inline function define. --- .vscode/settings.json | 20 ++++++ include/cell/copy/global_to_register.hpp | 53 +++++++--------- include/cell/copy/mod.hpp | 1 + include/cell/copy/vectorize.hpp | 80 ++++++++++++++++++++++++ 4 files changed, 122 insertions(+), 32 deletions(-) create mode 100644 .vscode/settings.json create mode 100644 include/cell/copy/vectorize.hpp diff --git a/.vscode/settings.json b/.vscode/settings.json new file mode 100644 index 0000000..81386e8 --- /dev/null +++ b/.vscode/settings.json @@ -0,0 +1,20 @@ +{ + "gotoSymbolStack.currentStackPosition": 0, + "gotoSymbolStack.maxStackPosition": 0, + "gotoSymbolStack.filePositionInfo": [], + "files.associations": { + "*.tcc": "cpp", + "optional": "cpp", + "ratio": "cpp", + "system_error": "cpp", + "array": "cpp", + "functional": "cpp", + "tuple": "cpp", + "type_traits": "cpp", + "utility": "cpp", + "variant": "cpp", + "compare": "cpp", + "concepts": "cpp", + "random": "cpp" + } +} diff --git a/include/cell/copy/global_to_register.hpp b/include/cell/copy/global_to_register.hpp index c39c801..54d6875 100644 --- a/include/cell/copy/global_to_register.hpp +++ b/include/cell/copy/global_to_register.hpp @@ -4,6 +4,7 @@ #pragma once #include "cell/copy/constants.hpp" +#include "cell/copy/vectorize.hpp" #include "cell/copy/warp.hpp" #include "traits/base.hpp" #include "types/mod.hpp" @@ -29,14 +30,11 @@ struct GlobalToRegMatLoader { static constexpr int kStride = Global::kRowStride; DEVICE void operator()(const DType* src, BaseTile& dst) { - dst(0, 0) = src[0 * kStride + 0]; - dst(0, 1) = src[0 * kStride + 1]; - dst(1, 0) = src[0 * kStride + 8]; - dst(1, 1) = src[0 * 
kStride + 9]; - dst(0, 2) = src[8 * kStride + 0]; - dst(0, 3) = src[8 * kStride + 1]; - dst(1, 2) = src[8 * kStride + 8]; - dst(1, 3) = src[8 * kStride + 9]; + Vectorize vectorize; + vectorize.copy(src + 0 * kStride + 0, &dst(0, 0)); + vectorize.copy(src + 0 * kStride + 8, &dst(1, 0)); + vectorize.copy(src + 8 * kStride + 0, &dst(0, 2)); + vectorize.copy(src + 8 * kStride + 8, &dst(1, 2)); } }; @@ -49,14 +47,11 @@ struct GlobalToRegMatLoader { static constexpr int kStride = Global::kColStride; DEVICE void operator()(const DType* src, BaseTile& dst) { - dst(0, 0) = src[0 * kStride + 0]; - dst(1, 0) = src[0 * kStride + 1]; - dst(0, 1) = src[0 * kStride + 8]; - dst(1, 1) = src[0 * kStride + 9]; - dst(2, 0) = src[8 * kStride + 0]; - dst(3, 0) = src[8 * kStride + 1]; - dst(2, 1) = src[8 * kStride + 8]; - dst(3, 1) = src[8 * kStride + 9]; + Vectorize vectorize; + vectorize.copy(src + 0 * kStride + 0, &dst(0, 0)); + vectorize.copy(src + 0 * kStride + 8, &dst(0, 1)); + vectorize.copy(src + 8 * kStride + 0, &dst(2, 0)); + vectorize.copy(src + 8 * kStride + 8, &dst(2, 1)); } }; @@ -84,14 +79,11 @@ struct RegToGlobalMatStorer { static constexpr int kStride = Global::kRowStride; DEVICE void operator()(const BaseTile& src, DType* dst) { - dst[0 * kStride + 0] = src(0, 0); - dst[0 * kStride + 1] = src(0, 1); - dst[0 * kStride + 8] = src(1, 0); - dst[0 * kStride + 9] = src(1, 1); - dst[8 * kStride + 0] = src(0, 2); - dst[8 * kStride + 1] = src(0, 3); - dst[8 * kStride + 8] = src(1, 2); - dst[8 * kStride + 9] = src(1, 3); + Vectorize vectorize; + vectorize.copy(&src(0, 0), dst + 0 * kStride + 0); + vectorize.copy(&src(1, 0), dst + 0 * kStride + 8); + vectorize.copy(&src(0, 2), dst + 8 * kStride + 0); + vectorize.copy(&src(1, 2), dst + 8 * kStride + 8); } }; @@ -104,14 +96,11 @@ struct RegToGlobalMatStorer { static constexpr int kStride = Global::kColStride; DEVICE void operator()(const BaseTile& src, DType* dst) { - dst[0 * kStride + 0] = src(0, 0); - dst[0 * kStride + 1] = src(1, 
0); - dst[0 * kStride + 8] = src(0, 1); - dst[0 * kStride + 9] = src(1, 1); - dst[8 * kStride + 0] = src(2, 0); - dst[8 * kStride + 1] = src(3, 0); - dst[8 * kStride + 8] = src(2, 1); - dst[8 * kStride + 9] = src(3, 1); + Vectorize vectorize; + vectorize.copy(&src(0, 0), dst + 0 * kStride + 0); + vectorize.copy(&src(0, 1), dst + 0 * kStride + 8); + vectorize.copy(&src(2, 0), dst + 8 * kStride + 0); + vectorize.copy(&src(2, 1), dst + 8 * kStride + 8); } }; diff --git a/include/cell/copy/mod.hpp b/include/cell/copy/mod.hpp index 9553316..0c24090 100644 --- a/include/cell/copy/mod.hpp +++ b/include/cell/copy/mod.hpp @@ -9,4 +9,5 @@ #include "cell/copy/global_to_shared.hpp" #include "cell/copy/register.hpp" #include "cell/copy/shared_to_register.hpp" +#include "cell/copy/vectorize.hpp" #include "cell/copy/warp.hpp" diff --git a/include/cell/copy/vectorize.hpp b/include/cell/copy/vectorize.hpp new file mode 100644 index 0000000..ea2f4e5 --- /dev/null +++ b/include/cell/copy/vectorize.hpp @@ -0,0 +1,80 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "config.hpp" +#include "cuda_utils.hpp" + +#include + +namespace tilefusion::cell::copy { + +/** + * @brief Vectorize a data type. + * + * @tparam Element Data type. + * @tparam kVecNums Number of vectorized elements. + */ +template +struct Vectorize { + using UnVecType = Element; + using VecType = Element; + static constexpr int vectorize_nums = kVecNums; + + /** + * @brief Copy data from unvectorized to vectorized. + * + * @param src Source data. + * @param dst Destination data. 
+ */ + DEVICE void copy(const UnVecType* src, UnVecType* dst) { + const VecType* src_vec = reinterpret_cast(src); + VecType* dst_vec = reinterpret_cast(dst); + *dst_vec = *src_vec; + } +}; + +template <> +struct Vectorize<__half, 2> { + using UnVecType = __half; + using VecType = __half2; + static constexpr int vectorize_nums = 2; + static constexpr int vectorize_bits = 32; + + DEVICE void copy(const __half* src, __half* dst) { + const __half2* src_vec = reinterpret_cast(src); + __half2* dst_vec = reinterpret_cast<__half2*>(dst); + *dst_vec = *src_vec; + } +}; + +template <> +struct Vectorize { + using UnVecType = cutlass::half_t; + using VecType = __half2; + static constexpr int vectorize_nums = 2; + static constexpr int vectorize_bits = 32; + + DEVICE void copy(const cutlass::half_t* src, cutlass::half_t* dst) { + const __half2* src_vec = reinterpret_cast(src); + __half2* dst_vec = reinterpret_cast<__half2*>(dst); + *dst_vec = *src_vec; + } +}; + +template <> +struct Vectorize { + using UnVecType = float; + using VecType = float2; + static constexpr int vectorize_nums = 2; + static constexpr int vectorize_bits = 64; + + DEVICE void copy(const float* src, float* dst) { + const float2* src_vec = reinterpret_cast(src); + float2* dst_vec = reinterpret_cast(dst); + *dst_vec = *src_vec; + } +}; + +} // namespace tilefusion::cell::copy