From 7f83be52c996287aa96aaa626abf372c333a6eb1 Mon Sep 17 00:00:00 2001
From: Jatin Chaudhary
Date: Tue, 17 Jan 2023 10:40:05 +0000
Subject: [PATCH] SWDEV-372153 - Add hipStreamGetDevice Implementation

Add hipStreamGetDevice(stream, device), which writes the device associated
with a stream to *device. On the AMD backend a null stream reports the
current device. On the NVIDIA backend the stream's context is pushed,
queried with cuCtxGetDevice and then always popped again. The new API is
also added to the profiling tables and to the exported symbol lists.

Change-Id: Ifd1f13e311e8221ca6d94cf27f9131eb97678067
---
 include/hip/amd_detail/hip_prof_str.h         | 26 ++++++++++++++++++-
 .../nvidia_detail/nvidia_hip_runtime_api.h    | 14 ++++++++++
 src/amdhip.def                                |  1 +
 src/hip_hcc.def.in                            |  1 +
 src/hip_hcc.map.in                            |  1 +
 src/hip_stream.cpp                            | 24 +++++++++++++++++
 6 files changed, 66 insertions(+), 1 deletion(-)

diff --git a/include/hip/amd_detail/hip_prof_str.h b/include/hip/amd_detail/hip_prof_str.h
index d72fd38d..d0b24d01 100644
--- a/include/hip/amd_detail/hip_prof_str.h
+++ b/include/hip/amd_detail/hip_prof_str.h
@@ -373,7 +373,8 @@ enum hip_api_id_t {
   HIP_API_ID_hipArray3DGetDescriptor = 360,
   HIP_API_ID_hipArrayGetDescriptor = 361,
   HIP_API_ID_hipArrayGetInfo = 362,
-  HIP_API_ID_LAST = 362,
+  HIP_API_ID_hipStreamGetDevice = 363,
+  HIP_API_ID_LAST = 363,
 
   HIP_API_ID_hipBindTexture = HIP_API_ID_NONE,
   HIP_API_ID_hipBindTexture2D = HIP_API_ID_NONE,
@@ -743,6 +744,7 @@ static inline const char* hip_api_name(const uint32_t id) {
     case HIP_API_ID_hipStreamEndCapture: return "hipStreamEndCapture";
     case HIP_API_ID_hipStreamGetCaptureInfo: return "hipStreamGetCaptureInfo";
     case HIP_API_ID_hipStreamGetCaptureInfo_v2: return "hipStreamGetCaptureInfo_v2";
+    case HIP_API_ID_hipStreamGetDevice: return "hipStreamGetDevice";
     case HIP_API_ID_hipStreamGetFlags: return "hipStreamGetFlags";
     case HIP_API_ID_hipStreamGetPriority: return "hipStreamGetPriority";
     case HIP_API_ID_hipStreamIsCapturing: return "hipStreamIsCapturing";
@@ -1108,6 +1110,7 @@ static inline uint32_t hipApiIdByName(const char* name) {
   if (strcmp("hipStreamEndCapture", name) == 0) return HIP_API_ID_hipStreamEndCapture;
   if (strcmp("hipStreamGetCaptureInfo", name) == 0) return HIP_API_ID_hipStreamGetCaptureInfo;
   if (strcmp("hipStreamGetCaptureInfo_v2", name) == 0) return HIP_API_ID_hipStreamGetCaptureInfo_v2;
+  if (strcmp("hipStreamGetDevice", name) == 0) return HIP_API_ID_hipStreamGetDevice;
   if (strcmp("hipStreamGetFlags", name) == 0) return HIP_API_ID_hipStreamGetFlags;
   if (strcmp("hipStreamGetPriority", name) == 0) return HIP_API_ID_hipStreamGetPriority;
   if (strcmp("hipStreamIsCapturing", name) == 0) return HIP_API_ID_hipStreamIsCapturing;
@@ -3062,6 +3065,11 @@ typedef struct hip_api_data_s {
       size_t* numDependencies_out;
       size_t numDependencies_out__val;
     } hipStreamGetCaptureInfo_v2;
+    struct {
+      hipStream_t stream;
+      hipDevice_t* device;
+      hipDevice_t device__val;
+    } hipStreamGetDevice;
     struct {
       hipStream_t stream;
       unsigned int* flags;
@@ -5231,6 +5239,11 @@ typedef struct hip_api_data_s {
   cb_data.args.hipStreamGetCaptureInfo_v2.dependencies_out = (const hipGraphNode_t**)dependencies_out; \
   cb_data.args.hipStreamGetCaptureInfo_v2.numDependencies_out = (size_t*)numDependencies_out; \
 };
+// hipStreamGetDevice[('hipStream_t', 'stream'), ('hipDevice_t*', 'device')]
+#define INIT_hipStreamGetDevice_CB_ARGS_DATA(cb_data) { \
+  cb_data.args.hipStreamGetDevice.stream = (hipStream_t)stream; \
+  cb_data.args.hipStreamGetDevice.device = (hipDevice_t*)device; \
+};
 // hipStreamGetFlags[('hipStream_t', 'stream'), ('unsigned int*', 'flags')]
 #define INIT_hipStreamGetFlags_CB_ARGS_DATA(cb_data) { \
   cb_data.args.hipStreamGetFlags.stream = (hipStream_t)stream; \
@@ -6765,6 +6778,10 @@ static inline void hipApiArgsInit(hip_api_id_t id, hip_api_data_t* data) {
       if (data->args.hipStreamGetCaptureInfo_v2.dependencies_out) data->args.hipStreamGetCaptureInfo_v2.dependencies_out__val = *(data->args.hipStreamGetCaptureInfo_v2.dependencies_out);
       if (data->args.hipStreamGetCaptureInfo_v2.numDependencies_out) data->args.hipStreamGetCaptureInfo_v2.numDependencies_out__val = *(data->args.hipStreamGetCaptureInfo_v2.numDependencies_out);
       break;
+// hipStreamGetDevice[('hipStream_t', 'stream'), ('hipDevice_t*', 'device')]
+    case HIP_API_ID_hipStreamGetDevice:
+      if (data->args.hipStreamGetDevice.device) data->args.hipStreamGetDevice.device__val = *(data->args.hipStreamGetDevice.device);
+      break;
 // hipStreamGetFlags[('hipStream_t', 'stream'), ('unsigned int*', 'flags')]
     case HIP_API_ID_hipStreamGetFlags:
       if (data->args.hipStreamGetFlags.flags) data->args.hipStreamGetFlags.flags__val = *(data->args.hipStreamGetFlags.flags);
@@ -9491,6 +9508,13 @@ static inline const char* hipApiString(hip_api_id_t id, const hip_api_data_t* da
       else { oss << ", numDependencies_out="; roctracer::hip_support::detail::operator<<(oss, data->args.hipStreamGetCaptureInfo_v2.numDependencies_out__val); }
       oss << ")";
     break;
+    case HIP_API_ID_hipStreamGetDevice:
+      oss << "hipStreamGetDevice(";
+      oss << "stream="; roctracer::hip_support::detail::operator<<(oss, data->args.hipStreamGetDevice.stream);
+      if (data->args.hipStreamGetDevice.device == NULL) oss << ", device=NULL";
+      else { oss << ", device="; roctracer::hip_support::detail::operator<<(oss, data->args.hipStreamGetDevice.device__val); }
+      oss << ")";
+    break;
     case HIP_API_ID_hipStreamGetFlags:
       oss << "hipStreamGetFlags(";
       oss << "stream="; roctracer::hip_support::detail::operator<<(oss, data->args.hipStreamGetFlags.stream);
diff --git a/include/hip/nvidia_detail/nvidia_hip_runtime_api.h b/include/hip/nvidia_detail/nvidia_hip_runtime_api.h
index d7701826..4c8be9af 100644
--- a/include/hip/nvidia_detail/nvidia_hip_runtime_api.h
+++ b/include/hip/nvidia_detail/nvidia_hip_runtime_api.h
@@ -2507,6 +2507,20 @@ inline static hipError_t hipStreamAddCallback(hipStream_t stream, hipStreamCallb
         cudaStreamAddCallback(stream, (cudaStreamCallback_t)callback, userData, flags));
 }
 
+inline static hipError_t hipStreamGetDevice(hipStream_t stream, hipDevice_t* device) {
+  hipCtx_t context;
+  auto err = hipCUResultTohipError(cuStreamGetCtx(stream, &context));
+  if (err != hipSuccess) return err;
+
+  err = hipCUResultTohipError(cuCtxPushCurrent(context));
+  if (err != hipSuccess) return err;
+
+  // Query the device, then always pop the pushed context so a failure cannot leak it.
+  err = hipCUResultTohipError(cuCtxGetDevice(device));
+  const auto popErr = hipCUResultTohipError(cuCtxPopCurrent(&context));
+  return (err != hipSuccess) ? err : popErr;
+}
+
 inline static hipError_t hipDriverGetVersion(int* driverVersion) {
     return hipCUDAErrorTohipError(cudaDriverGetVersion(driverVersion));
 }
diff --git a/src/amdhip.def b/src/amdhip.def
index 70279e21..ffaff7f5 100644
--- a/src/amdhip.def
+++ b/src/amdhip.def
@@ -193,6 +193,7 @@ hipStreamCreate
 hipStreamCreateWithFlags
 hipStreamCreateWithPriority
 hipStreamDestroy
+hipStreamGetDevice
 hipStreamGetFlags
 hipStreamQuery
 hipStreamSynchronize
diff --git a/src/hip_hcc.def.in b/src/hip_hcc.def.in
index 129fa7c4..fe219359 100644
--- a/src/hip_hcc.def.in
+++ b/src/hip_hcc.def.in
@@ -194,6 +194,7 @@ hipStreamCreate
 hipStreamCreateWithFlags
 hipStreamCreateWithPriority
 hipStreamDestroy
+hipStreamGetDevice
 hipStreamGetFlags
 hipStreamQuery
 hipStreamSynchronize
diff --git a/src/hip_hcc.map.in b/src/hip_hcc.map.in
index 81251cca..204b139f 100644
--- a/src/hip_hcc.map.in
+++ b/src/hip_hcc.map.in
@@ -169,6 +169,7 @@ global:
     hipStreamCreateWithFlags;
     hipStreamCreateWithPriority;
     hipStreamDestroy;
+    hipStreamGetDevice;
     hipStreamGetFlags;
     hipStreamQuery;
     hipStreamSynchronize;
diff --git a/src/hip_stream.cpp b/src/hip_stream.cpp
index 1342ac72..6d085fe7 100644
--- a/src/hip_stream.cpp
+++ b/src/hip_stream.cpp
@@ -795,3 +795,27 @@ hipError_t hipExtStreamGetCUMask(hipStream_t stream, uint32_t cuMaskSize, uint32
   }
   HIP_RETURN(hipSuccess);
 }
+
+// ================================================================================================
+hipError_t hipStreamGetDevice(hipStream_t stream, hipDevice_t* device) {
+  HIP_INIT_API(hipStreamGetDevice, stream, device);
+
+  if (device == nullptr) {
+    HIP_RETURN(hipErrorInvalidValue);
+  }
+
+  if (!hip::isValid(stream)) {
+    HIP_RETURN(hipErrorContextIsDestroyed);
+  }
+
+  if (stream == nullptr) {  // handle null stream
+    // null stream is associated with current device, return the device id associated with the
+    // current device
+    *device = hip::getCurrentDevice()->deviceId();
+  } else {
+    getStreamPerThread(stream);
+    *device = reinterpret_cast<hip::Stream*>(stream)->DeviceId();
+  }
+
+  HIP_RETURN(hipSuccess);
+}