@@ -458,7 +458,24 @@ __XETLA_API xetla_vector<T, N> xetla_load_global(
458
458
size_t SurfacePitch,
459
459
int X,
460
460
int Y) {
461
- if constexpr (BlockWidth * sizeof (T) < sizeof (uint32_t )) {
461
+ if constexpr (std::is_same_v<T, bf16>) {
462
+ auto ret = xetla_load_global<
463
+ fp16,
464
+ BlockWidth,
465
+ BlockHeight,
466
+ NBlocks,
467
+ Transposed,
468
+ Transformed,
469
+ L1H,
470
+ L2H>(
471
+ reinterpret_cast <const fp16*>(Ptr ),
472
+ SurfaceWidth,
473
+ SurfaceHeight,
474
+ SurfacePitch,
475
+ X,
476
+ Y);
477
+ return ret.xetla_format <T>();
478
+ } else if constexpr (BlockWidth * sizeof (T) < sizeof (uint32_t )) {
462
479
constexpr auto scale_factor = sizeof (uint32_t ) / sizeof (T);
463
480
xetla_vector<uint32_t , N> ret = __ESIMD_ENS::lsc_load_2d<
464
481
uint32_t ,
@@ -754,13 +771,25 @@ __XETLA_API void xetla_store_global(
754
771
int X,
755
772
int Y,
756
773
xetla_vector<T, N> Vals) {
757
- __ESIMD_ENS::lsc_store_2d<
758
- T,
759
- BlockWidth,
760
- BlockHeight,
761
- gpu::xetla::detail::get_cache_hint (L1H),
762
- gpu::xetla::detail::get_cache_hint (L2H)>(
763
- Ptr , SurfaceWidth, SurfaceHeight, SurfacePitch, X, Y, Vals);
774
+ if constexpr (std::is_same_v<T, bf16>) {
775
+ xetla_vector<fp16, N> Vals_fp16 = Vals.xetla_format <fp16>();
776
+ xetla_store_global<fp16, BlockWidth, BlockHeight, L1H, L2H>(
777
+ reinterpret_cast <fp16*>(Ptr ),
778
+ SurfaceWidth,
779
+ SurfaceHeight,
780
+ SurfacePitch,
781
+ X,
782
+ Y,
783
+ Vals_fp16);
784
+ } else {
785
+ __ESIMD_ENS::lsc_store_2d<
786
+ T,
787
+ BlockWidth,
788
+ BlockHeight,
789
+ gpu::xetla::detail::get_cache_hint (L1H),
790
+ gpu::xetla::detail::get_cache_hint (L2H)>(
791
+ Ptr , SurfaceWidth, SurfaceHeight, SurfacePitch, X, Y, Vals);
792
+ }
764
793
}
765
794
/// template <typename T, int N, int VS = 1, typename OffsetT,
766
795
/// typename PropertyListT = empty_properties_t>
0 commit comments