Skip to content

Commit

Permalink
CUTensorMap is only in CUDA v12
Browse files Browse the repository at this point in the history
The documentation at https://docs.nvidia.com/cuda/cutensor/latest/ say
it's supported in 11.0, 11.8, and 12.x

Our local testing against 11.7 fails because none of the required types
or functions are declared in the header.

It's safer to remain buildable with older versions but with this feature
disabled than requiring a version bump that isn't otherwise warranted.
  • Loading branch information
ldrumm committed Dec 23, 2024
1 parent 76a9623 commit f8b04f4
Showing 1 changed file with 19 additions and 0 deletions.
19 changes: 19 additions & 0 deletions source/adapters/cuda/tensor_map.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,24 @@

#include "context.hpp"

#if CUDA_VERSION < 12000
UR_APIEXPORT ur_result_t UR_APICALL urTensorMapEncodeIm2ColExp(
ur_device_handle_t, ur_exp_tensor_map_data_type_flags_t, uint32_t, void *,
const uint64_t *, const uint64_t *, const int *, const int *, uint32_t,
uint32_t, const uint32_t *, ur_exp_tensor_map_interleave_flags_t,
ur_exp_tensor_map_swizzle_flags_t, ur_exp_tensor_map_l2_promotion_flags_t,
ur_exp_tensor_map_oob_fill_flags_t, ur_exp_tensor_map_handle_t *) {
return UR_RESULT_ERROR_UNSUPPORTED_FEATURE;
}
UR_APIEXPORT ur_result_t UR_APICALL urTensorMapEncodeTiledExp(
ur_device_handle_t, ur_exp_tensor_map_data_type_flags_t, uint32_t, void *,
const uint64_t *, const uint64_t *, const uint32_t *, const uint32_t *,
ur_exp_tensor_map_interleave_flags_t, ur_exp_tensor_map_swizzle_flags_t,
ur_exp_tensor_map_l2_promotion_flags_t, ur_exp_tensor_map_oob_fill_flags_t,
ur_exp_tensor_map_handle_t *) {
return UR_RESULT_ERROR_UNSUPPORTED_FEATURE;
}
#else
struct ur_exp_tensor_map_handle_t_ {
CUtensorMap Map;
};
Expand Down Expand Up @@ -140,3 +158,4 @@ UR_APIEXPORT ur_result_t UR_APICALL urTensorMapEncodeTiledExp(
}
return UR_RESULT_SUCCESS;
}
#endif

0 comments on commit f8b04f4

Please sign in to comment.