Skip to content

Commit

Permalink
Merge pull request #2513 from AllanZyne/review/yang/fix_msan_usm
Browse files Browse the repository at this point in the history
[DeviceMSAN] Fix "urEnqueueUSM" APIs
  • Loading branch information
kbenzie authored Jan 6, 2025
2 parents e7366f9 + 3eeb2a1 commit 533ab9b
Show file tree
Hide file tree
Showing 6 changed files with 408 additions and 149 deletions.
258 changes: 257 additions & 1 deletion source/loader/layers/sanitizer/msan/msan_ddi.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,6 @@ ur_result_t setupContext(ur_context_handle_t Context, uint32_t numDevices,
UR_CALL(DI->allocShadowMemory(Context));
}
CI->DeviceList.emplace_back(hDevice);
CI->AllocInfosMap[hDevice];
}
return UR_RESULT_SUCCESS;
}
Expand Down Expand Up @@ -104,6 +103,17 @@ ur_result_t urUSMDeviceAlloc(
pool, size, ppMem);
}

///////////////////////////////////////////////////////////////////////////////
/// @brief Intercept function for urUSMFree
///
/// @details Emits a debug log line and then delegates to the MSAN
/// interceptor's releaseMemory(); that call's result is returned unchanged.
__urdlllocal ur_result_t UR_APICALL urUSMFree(
    ur_context_handle_t hContext, ///< [in] handle of the context object
    void *pMem                    ///< [in] pointer to USM memory object
) {
    getContext()->logger.debug("==== urUSMFree");

    // The interceptor decides what to do with the pointer; this layer only
    // forwards the request and propagates the outcome.
    const ur_result_t Result =
        getMsanInterceptor()->releaseMemory(hContext, pMem);
    return Result;
}

///////////////////////////////////////////////////////////////////////////////
/// @brief Intercept function for urProgramCreateWithIL
ur_result_t urProgramCreateWithIL(
Expand Down Expand Up @@ -1234,6 +1244,247 @@ ur_result_t urKernelSetArgMemObj(
return UR_RESULT_SUCCESS;
}

///////////////////////////////////////////////////////////////////////////////
/// @brief Intercept function for urEnqueueUSMFill
///
/// @details Forwards the fill to the next layer. If pMem belongs to an
/// allocation known to the MSAN interceptor, the matching shadow range is
/// additionally set to 0 through EnqueueUSMBlockingSet. When the caller asks
/// for phEvent, it is produced by an events-wait that depends only on the
/// commands this function actually enqueued.
ur_result_t UR_APICALL urEnqueueUSMFill(
    ur_queue_handle_t hQueue, ///< [in] handle of the queue object
    void *pMem, ///< [in][bounds(0, size)] pointer to USM memory object
    size_t
        patternSize, ///< [in] the size in bytes of the pattern. Must be a power of 2 and less
                     ///< than or equal to width.
    const void
        *pPattern, ///< [in] pointer with the bytes of the pattern to set.
    size_t
        size, ///< [in] size in bytes to be set. Must be a multiple of patternSize.
    uint32_t numEventsInWaitList, ///< [in] size of the event wait list
    const ur_event_handle_t *
        phEventWaitList, ///< [in][optional][range(0, numEventsInWaitList)] pointer to a list of
                         ///< events that must be complete before this command can be executed.
                         ///< If nullptr, the numEventsInWaitList must be 0, indicating that this
                         ///< command does not wait on any event to complete.
    ur_event_handle_t *
        phEvent ///< [out][optional] return an event object that identifies this particular
                ///< command instance. If phEventWaitList and phEvent are not NULL, phEvent
                ///< must not refer to an element of the phEventWaitList array.
) {
    auto pfnUSMFill = getContext()->urDdiTable.Enqueue.pfnUSMFill;
    getContext()->logger.debug("==== urEnqueueUSMFill");

    // hEvents[0]: the application's fill; hEvents[1]: the shadow update, which
    // is only enqueued when pMem is a tracked allocation.
    ur_event_handle_t hEvents[2] = {};
    UR_CALL(pfnUSMFill(hQueue, pMem, patternSize, pPattern, size,
                       numEventsInWaitList, phEventWaitList, &hEvents[0]));

    const auto Mem = (uptr)pMem;
    auto MemInfoItOp = getMsanInterceptor()->findAllocInfoByAddress(Mem);
    if (MemInfoItOp) {
        auto MemInfo = (*MemInfoItOp)->second;

        const auto &DeviceInfo =
            getMsanInterceptor()->getDeviceInfo(MemInfo->Device);
        const auto MemShadow = DeviceInfo->Shadow->MemToShadow(Mem);

        UR_CALL(EnqueueUSMBlockingSet(hQueue, (void *)MemShadow, 0, size, 0,
                                      nullptr, &hEvents[1]));
    }

    // Keep only the handles that were really produced: passing a null handle
    // in an event wait list is invalid, and the shadow slot stays null for
    // untracked memory.
    ur_event_handle_t hValidEvents[2] = {};
    uint32_t NumValidEvents = 0;
    for (ur_event_handle_t hEvent : hEvents) {
        if (hEvent) {
            hValidEvents[NumValidEvents++] = hEvent;
        }
    }

    if (phEvent) {
        // A wait list pointer must be null when the count is 0.
        UR_CALL(getContext()->urDdiTable.Enqueue.pfnEventsWait(
            hQueue, NumValidEvents, NumValidEvents ? hValidEvents : nullptr,
            phEvent));
    }

    // The intermediate events are owned by this layer and never reach the
    // caller, so drop our references to avoid leaking them.
    for (uint32_t i = 0; i < NumValidEvents; ++i) {
        UR_CALL(getContext()->urDdiTable.Event.pfnRelease(hValidEvents[i]));
    }

    return UR_RESULT_SUCCESS;
}

///////////////////////////////////////////////////////////////////////////////
/// @brief Intercept function for urEnqueueUSMMemcpy
///
/// @details Forwards the copy to the next layer and then mirrors it in shadow
/// memory:
///   - source and destination both tracked: the source shadow is copied to
///     the destination shadow;
///   - only the destination tracked: the destination shadow is set to 0
///     through EnqueueUSMBlockingSet;
///   - otherwise: no shadow update is enqueued.
/// When the caller asks for phEvent, it is produced by an events-wait that
/// depends only on the commands this function actually enqueued.
ur_result_t UR_APICALL urEnqueueUSMMemcpy(
    ur_queue_handle_t hQueue, ///< [in] handle of the queue object
    bool blocking,            ///< [in] blocking or non-blocking copy
    void *
        pDst, ///< [in][bounds(0, size)] pointer to the destination USM memory object
    const void *
        pSrc, ///< [in][bounds(0, size)] pointer to the source USM memory object
    size_t size,                  ///< [in] size in bytes to be copied
    uint32_t numEventsInWaitList, ///< [in] size of the event wait list
    const ur_event_handle_t *
        phEventWaitList, ///< [in][optional][range(0, numEventsInWaitList)] pointer to a list of
                         ///< events that must be complete before this command can be executed.
                         ///< If nullptr, the numEventsInWaitList must be 0, indicating that this
                         ///< command does not wait on any event to complete.
    ur_event_handle_t *
        phEvent ///< [out][optional] return an event object that identifies this particular
                ///< command instance. If phEventWaitList and phEvent are not NULL, phEvent
                ///< must not refer to an element of the phEventWaitList array.
) {
    auto pfnUSMMemcpy = getContext()->urDdiTable.Enqueue.pfnUSMMemcpy;
    // Log the intercepted API name, matching the other interceptors.
    getContext()->logger.debug("==== urEnqueueUSMMemcpy");

    // hEvents[0]: the application's copy; hEvents[1]: the shadow update, which
    // is only enqueued when the destination is a tracked allocation.
    ur_event_handle_t hEvents[2] = {};
    UR_CALL(pfnUSMMemcpy(hQueue, blocking, pDst, pSrc, size,
                         numEventsInWaitList, phEventWaitList, &hEvents[0]));

    const auto Src = (uptr)pSrc, Dst = (uptr)pDst;
    auto SrcInfoItOp = getMsanInterceptor()->findAllocInfoByAddress(Src);
    auto DstInfoItOp = getMsanInterceptor()->findAllocInfoByAddress(Dst);

    if (SrcInfoItOp && DstInfoItOp) {
        auto SrcInfo = (*SrcInfoItOp)->second;

        // NOTE(review): both shadow addresses are derived from the source
        // allocation's device; confirm this is intended when the destination
        // was allocated on a different device.
        const auto &DeviceInfo =
            getMsanInterceptor()->getDeviceInfo(SrcInfo->Device);
        const auto SrcShadow = DeviceInfo->Shadow->MemToShadow(Src);
        const auto DstShadow = DeviceInfo->Shadow->MemToShadow(Dst);

        UR_CALL(pfnUSMMemcpy(hQueue, blocking, (void *)DstShadow,
                             (void *)SrcShadow, size, 0, nullptr, &hEvents[1]));
    } else if (DstInfoItOp) {
        auto DstInfo = (*DstInfoItOp)->second;

        const auto &DeviceInfo =
            getMsanInterceptor()->getDeviceInfo(DstInfo->Device);
        auto DstShadow = DeviceInfo->Shadow->MemToShadow(Dst);

        UR_CALL(EnqueueUSMBlockingSet(hQueue, (void *)DstShadow, 0, size, 0,
                                      nullptr, &hEvents[1]));
    }

    // Keep only the handles that were really produced: passing a null handle
    // in an event wait list is invalid, and the shadow slot stays null when
    // the destination is untracked.
    ur_event_handle_t hValidEvents[2] = {};
    uint32_t NumValidEvents = 0;
    for (ur_event_handle_t hEvent : hEvents) {
        if (hEvent) {
            hValidEvents[NumValidEvents++] = hEvent;
        }
    }

    if (phEvent) {
        // A wait list pointer must be null when the count is 0.
        UR_CALL(getContext()->urDdiTable.Enqueue.pfnEventsWait(
            hQueue, NumValidEvents, NumValidEvents ? hValidEvents : nullptr,
            phEvent));
    }

    // The intermediate events are owned by this layer and never reach the
    // caller, so drop our references to avoid leaking them.
    for (uint32_t i = 0; i < NumValidEvents; ++i) {
        UR_CALL(getContext()->urDdiTable.Event.pfnRelease(hValidEvents[i]));
    }

    return UR_RESULT_SUCCESS;
}

///////////////////////////////////////////////////////////////////////////////
/// @brief Intercept function for urEnqueueUSMFill2D
///
/// @details Forwards the 2D fill to the next layer. If pMem belongs to an
/// allocation known to the MSAN interceptor, the same pitch/width/height
/// region of the shadow is filled with a single zero byte pattern. When the
/// caller asks for phEvent, it is produced by an events-wait that depends
/// only on the commands this function actually enqueued.
ur_result_t UR_APICALL urEnqueueUSMFill2D(
    ur_queue_handle_t hQueue, ///< [in] handle of the queue to submit to.
    void *
        pMem, ///< [in][bounds(0, pitch * height)] pointer to memory to be filled.
    size_t
        pitch, ///< [in] the total width of the destination memory including padding.
    size_t
        patternSize, ///< [in] the size in bytes of the pattern. Must be a power of 2 and less
                     ///< than or equal to width.
    const void
        *pPattern, ///< [in] pointer with the bytes of the pattern to set.
    size_t
        width, ///< [in] the width in bytes of each row to fill. Must be a multiple of
               ///< patternSize.
    size_t height,                ///< [in] the height of the columns to fill.
    uint32_t numEventsInWaitList, ///< [in] size of the event wait list
    const ur_event_handle_t *
        phEventWaitList, ///< [in][optional][range(0, numEventsInWaitList)] pointer to a list of
                         ///< events that must be complete before the kernel execution.
                         ///< If nullptr, the numEventsInWaitList must be 0, indicating that no wait event.
    ur_event_handle_t *
        phEvent ///< [out][optional] return an event object that identifies this particular
                ///< kernel execution instance. If phEventWaitList and phEvent are not
                ///< NULL, phEvent must not refer to an element of the phEventWaitList array.
) {
    auto pfnUSMFill2D = getContext()->urDdiTable.Enqueue.pfnUSMFill2D;
    getContext()->logger.debug("==== urEnqueueUSMFill2D");

    // hEvents[0]: the application's fill; hEvents[1]: the shadow update, which
    // is only enqueued when pMem is a tracked allocation.
    ur_event_handle_t hEvents[2] = {};
    UR_CALL(pfnUSMFill2D(hQueue, pMem, pitch, patternSize, pPattern, width,
                         height, numEventsInWaitList, phEventWaitList,
                         &hEvents[0]));

    const auto Mem = (uptr)pMem;
    auto MemInfoItOp = getMsanInterceptor()->findAllocInfoByAddress(Mem);
    if (MemInfoItOp) {
        auto MemInfo = (*MemInfoItOp)->second;

        const auto &DeviceInfo =
            getMsanInterceptor()->getDeviceInfo(MemInfo->Device);
        const auto MemShadow = DeviceInfo->Shadow->MemToShadow(Mem);

        // One-byte zero pattern, applied over the same 2D region as the fill.
        const char Pattern = 0;
        UR_CALL(pfnUSMFill2D(hQueue, (void *)MemShadow, pitch, 1, &Pattern,
                             width, height, 0, nullptr, &hEvents[1]));
    }

    // Keep only the handles that were really produced: passing a null handle
    // in an event wait list is invalid, and the shadow slot stays null for
    // untracked memory.
    ur_event_handle_t hValidEvents[2] = {};
    uint32_t NumValidEvents = 0;
    for (ur_event_handle_t hEvent : hEvents) {
        if (hEvent) {
            hValidEvents[NumValidEvents++] = hEvent;
        }
    }

    if (phEvent) {
        // A wait list pointer must be null when the count is 0.
        UR_CALL(getContext()->urDdiTable.Enqueue.pfnEventsWait(
            hQueue, NumValidEvents, NumValidEvents ? hValidEvents : nullptr,
            phEvent));
    }

    // The intermediate events are owned by this layer and never reach the
    // caller, so drop our references to avoid leaking them.
    for (uint32_t i = 0; i < NumValidEvents; ++i) {
        UR_CALL(getContext()->urDdiTable.Event.pfnRelease(hValidEvents[i]));
    }

    return UR_RESULT_SUCCESS;
}

///////////////////////////////////////////////////////////////////////////////
/// @brief Intercept function for urEnqueueUSMMemcpy2D
///
/// @details Forwards the 2D copy to the next layer and then mirrors it in
/// shadow memory:
///   - source and destination both tracked: the source shadow region is
///     copied to the destination shadow region with the same pitches;
///   - only the destination tracked: the destination shadow region is filled
///     with a single zero byte pattern;
///   - otherwise: no shadow update is enqueued.
/// When the caller asks for phEvent, it is produced by an events-wait that
/// depends only on the commands this function actually enqueued.
ur_result_t UR_APICALL urEnqueueUSMMemcpy2D(
    ur_queue_handle_t hQueue, ///< [in] handle of the queue to submit to.
    bool blocking, ///< [in] indicates if this operation should block the host.
    void *
        pDst, ///< [in][bounds(0, dstPitch * height)] pointer to memory where data will
              ///< be copied.
    size_t
        dstPitch, ///< [in] the total width of the destination memory including padding.
    const void *
        pSrc, ///< [in][bounds(0, srcPitch * height)] pointer to memory to be copied.
    size_t
        srcPitch, ///< [in] the total width of the source memory including padding.
    size_t width,  ///< [in] the width in bytes of each row to be copied.
    size_t height, ///< [in] the height of columns to be copied.
    uint32_t numEventsInWaitList, ///< [in] size of the event wait list
    const ur_event_handle_t *
        phEventWaitList, ///< [in][optional][range(0, numEventsInWaitList)] pointer to a list of
                         ///< events that must be complete before the kernel execution.
                         ///< If nullptr, the numEventsInWaitList must be 0, indicating that no wait event.
    ur_event_handle_t *
        phEvent ///< [out][optional] return an event object that identifies this particular
                ///< kernel execution instance. If phEventWaitList and phEvent are not
                ///< NULL, phEvent must not refer to an element of the phEventWaitList array.
) {
    auto pfnUSMMemcpy2D = getContext()->urDdiTable.Enqueue.pfnUSMMemcpy2D;
    // Log the intercepted API name, matching the other interceptors.
    getContext()->logger.debug("==== urEnqueueUSMMemcpy2D");

    // hEvents[0]: the application's copy; hEvents[1]: the shadow update, which
    // is only enqueued when the destination is a tracked allocation.
    ur_event_handle_t hEvents[2] = {};
    UR_CALL(pfnUSMMemcpy2D(hQueue, blocking, pDst, dstPitch, pSrc, srcPitch,
                           width, height, numEventsInWaitList, phEventWaitList,
                           &hEvents[0]));

    const auto Src = (uptr)pSrc, Dst = (uptr)pDst;
    auto SrcInfoItOp = getMsanInterceptor()->findAllocInfoByAddress(Src);
    auto DstInfoItOp = getMsanInterceptor()->findAllocInfoByAddress(Dst);

    if (SrcInfoItOp && DstInfoItOp) {
        auto SrcInfo = (*SrcInfoItOp)->second;

        // NOTE(review): both shadow addresses are derived from the source
        // allocation's device; confirm this is intended when the destination
        // was allocated on a different device.
        const auto &DeviceInfo =
            getMsanInterceptor()->getDeviceInfo(SrcInfo->Device);
        const auto SrcShadow = DeviceInfo->Shadow->MemToShadow(Src);
        const auto DstShadow = DeviceInfo->Shadow->MemToShadow(Dst);

        UR_CALL(pfnUSMMemcpy2D(hQueue, blocking, (void *)DstShadow, dstPitch,
                               (void *)SrcShadow, srcPitch, width, height, 0,
                               nullptr, &hEvents[1]));
    } else if (DstInfoItOp) {
        auto DstInfo = (*DstInfoItOp)->second;

        const auto &DeviceInfo =
            getMsanInterceptor()->getDeviceInfo(DstInfo->Device);
        const auto DstShadow = DeviceInfo->Shadow->MemToShadow(Dst);

        // One-byte zero pattern, applied over the destination's 2D region.
        const char Pattern = 0;
        UR_CALL(getContext()->urDdiTable.Enqueue.pfnUSMFill2D(
            hQueue, (void *)DstShadow, dstPitch, 1, &Pattern, width, height, 0,
            nullptr, &hEvents[1]));
    }

    // Keep only the handles that were really produced: passing a null handle
    // in an event wait list is invalid, and the shadow slot stays null when
    // the destination is untracked.
    ur_event_handle_t hValidEvents[2] = {};
    uint32_t NumValidEvents = 0;
    for (ur_event_handle_t hEvent : hEvents) {
        if (hEvent) {
            hValidEvents[NumValidEvents++] = hEvent;
        }
    }

    if (phEvent) {
        // A wait list pointer must be null when the count is 0.
        UR_CALL(getContext()->urDdiTable.Enqueue.pfnEventsWait(
            hQueue, NumValidEvents, NumValidEvents ? hValidEvents : nullptr,
            phEvent));
    }

    // The intermediate events are owned by this layer and never reach the
    // caller, so drop our references to avoid leaking them.
    for (uint32_t i = 0; i < NumValidEvents; ++i) {
        UR_CALL(getContext()->urDdiTable.Event.pfnRelease(hValidEvents[i]));
    }

    return UR_RESULT_SUCCESS;
}

///////////////////////////////////////////////////////////////////////////////
/// @brief Exported function for filling application's Global table
/// with current process' addresses
Expand Down Expand Up @@ -1391,6 +1642,10 @@ ur_result_t urGetEnqueueProcAddrTable(
pDdiTable->pfnMemUnmap = ur_sanitizer_layer::msan::urEnqueueMemUnmap;
pDdiTable->pfnKernelLaunch =
ur_sanitizer_layer::msan::urEnqueueKernelLaunch;
pDdiTable->pfnUSMFill = ur_sanitizer_layer::msan::urEnqueueUSMFill;
pDdiTable->pfnUSMMemcpy = ur_sanitizer_layer::msan::urEnqueueUSMMemcpy;
pDdiTable->pfnUSMFill2D = ur_sanitizer_layer::msan::urEnqueueUSMFill2D;
pDdiTable->pfnUSMMemcpy2D = ur_sanitizer_layer::msan::urEnqueueUSMMemcpy2D;

return result;
}
Expand All @@ -1408,6 +1663,7 @@ ur_result_t urGetUSMProcAddrTable(
ur_result_t result = UR_RESULT_SUCCESS;

pDdiTable->pfnDeviceAlloc = ur_sanitizer_layer::msan::urUSMDeviceAlloc;
pDdiTable->pfnFree = ur_sanitizer_layer::msan::urUSMFree;

return result;
}
Expand Down
Loading

0 comments on commit 533ab9b

Please sign in to comment.