Skip to content
This repository was archived by the owner on Aug 30, 2024. It is now read-only.

Commit ee6f1c2

Browse files
committed
fix bf16 load/store
1 parent 2dbe90c commit ee6f1c2

File tree

2 files changed: +319 -294 lines changed

2 files changed: +319 -294 lines changed

include/common/core/memory.hpp

Lines changed: 37 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -458,7 +458,24 @@ __XETLA_API xetla_vector<T, N> xetla_load_global(
458458
size_t SurfacePitch,
459459
int X,
460460
int Y) {
461-
if constexpr (BlockWidth * sizeof(T) < sizeof(uint32_t)) {
461+
if constexpr (std::is_same_v<T, bf16>) {
462+
auto ret = xetla_load_global<
463+
fp16,
464+
BlockWidth,
465+
BlockHeight,
466+
NBlocks,
467+
Transposed,
468+
Transformed,
469+
L1H,
470+
L2H>(
471+
reinterpret_cast<const fp16*>(Ptr),
472+
SurfaceWidth,
473+
SurfaceHeight,
474+
SurfacePitch,
475+
X,
476+
Y);
477+
return ret.xetla_format<T>();
478+
} else if constexpr (BlockWidth * sizeof(T) < sizeof(uint32_t)) {
462479
constexpr auto scale_factor = sizeof(uint32_t) / sizeof(T);
463480
xetla_vector<uint32_t, N> ret = __ESIMD_ENS::lsc_load_2d<
464481
uint32_t,
@@ -754,13 +771,25 @@ __XETLA_API void xetla_store_global(
754771
int X,
755772
int Y,
756773
xetla_vector<T, N> Vals) {
757-
__ESIMD_ENS::lsc_store_2d<
758-
T,
759-
BlockWidth,
760-
BlockHeight,
761-
gpu::xetla::detail::get_cache_hint(L1H),
762-
gpu::xetla::detail::get_cache_hint(L2H)>(
763-
Ptr, SurfaceWidth, SurfaceHeight, SurfacePitch, X, Y, Vals);
774+
if constexpr (std::is_same_v<T, bf16>) {
775+
xetla_vector<fp16, N> Vals_fp16 = Vals.xetla_format<fp16>();
776+
xetla_store_global<fp16, BlockWidth, BlockHeight, L1H, L2H>(
777+
reinterpret_cast<fp16*>(Ptr),
778+
SurfaceWidth,
779+
SurfaceHeight,
780+
SurfacePitch,
781+
X,
782+
Y,
783+
Vals_fp16);
784+
} else {
785+
__ESIMD_ENS::lsc_store_2d<
786+
T,
787+
BlockWidth,
788+
BlockHeight,
789+
gpu::xetla::detail::get_cache_hint(L1H),
790+
gpu::xetla::detail::get_cache_hint(L2H)>(
791+
Ptr, SurfaceWidth, SurfaceHeight, SurfacePitch, X, Y, Vals);
792+
}
764793
}
765794
/// template <typename T, int N, int VS = 1, typename OffsetT,
766795
/// typename PropertyListT = empty_properties_t>

0 commit comments

Comments
 (0)