Skip to content

Commit

Permalink
small optimization
Browse files Browse the repository at this point in the history
  • Loading branch information
AllanZyne committed Jan 6, 2025
1 parent e8ea136 commit 3eeb2a1
Showing 1 changed file with 26 additions and 28 deletions.
54 changes: 26 additions & 28 deletions source/loader/layers/sanitizer/msan/msan_ddi.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1307,9 +1307,9 @@ ur_result_t UR_APICALL urEnqueueUSMFill(
auto pfnUSMFill = getContext()->urDdiTable.Enqueue.pfnUSMFill;
getContext()->logger.debug("==== urEnqueueUSMFill");

ur_event_handle_t hEvent = nullptr;
ur_event_handle_t hEvents[2] = {};
UR_CALL(pfnUSMFill(hQueue, pMem, patternSize, pPattern, size,
numEventsInWaitList, phEventWaitList, &hEvent));
numEventsInWaitList, phEventWaitList, &hEvents[0]));

const auto Mem = (uptr)pMem;
auto MemInfoItOp = getMsanInterceptor()->findAllocInfoByAddress(Mem);
Expand All @@ -1320,13 +1320,13 @@ ur_result_t UR_APICALL urEnqueueUSMFill(
getMsanInterceptor()->getDeviceInfo(MemInfo->Device);
const auto MemShadow = DeviceInfo->Shadow->MemToShadow(Mem);

const ur_event_handle_t hEventWait = hEvent;
UR_CALL(EnqueueUSMBlockingSet(hQueue, (void *)MemShadow, 0, size, 1,
&hEventWait, &hEvent));
UR_CALL(EnqueueUSMBlockingSet(hQueue, (void *)MemShadow, 0, size, 0,
nullptr, &hEvents[1]));
}

if (phEvent) {
*phEvent = hEvent;
UR_CALL(getContext()->urDdiTable.Enqueue.pfnEventsWait(
hQueue, 2, hEvents, phEvent));
}

return UR_RESULT_SUCCESS;
Expand Down Expand Up @@ -1356,9 +1356,9 @@ ur_result_t UR_APICALL urEnqueueUSMMemcpy(
auto pfnUSMMemcpy = getContext()->urDdiTable.Enqueue.pfnUSMMemcpy;
getContext()->logger.debug("==== pfnUSMMemcpy");

ur_event_handle_t hEvent = nullptr;
ur_event_handle_t hEvents[2] = {};
UR_CALL(pfnUSMMemcpy(hQueue, blocking, pDst, pSrc, size,
numEventsInWaitList, phEventWaitList, &hEvent));
numEventsInWaitList, phEventWaitList, &hEvents[0]));

const auto Src = (uptr)pSrc, Dst = (uptr)pDst;
auto SrcInfoItOp = getMsanInterceptor()->findAllocInfoByAddress(Src);
Expand All @@ -1373,23 +1373,22 @@ ur_result_t UR_APICALL urEnqueueUSMMemcpy(
const auto SrcShadow = DeviceInfo->Shadow->MemToShadow(Src);
const auto DstShadow = DeviceInfo->Shadow->MemToShadow(Dst);

const ur_event_handle_t hEventWait = hEvent;
UR_CALL(pfnUSMMemcpy(hQueue, blocking, (void *)DstShadow,
(void *)SrcShadow, size, 1, &hEventWait, &hEvent));
(void *)SrcShadow, size, 0, nullptr, &hEvents[1]));
} else if (DstInfoItOp) {
auto DstInfo = (*DstInfoItOp)->second;

const auto &DeviceInfo =
getMsanInterceptor()->getDeviceInfo(DstInfo->Device);
auto DstShadow = DeviceInfo->Shadow->MemToShadow(Dst);

const ur_event_handle_t hEventWait = hEvent;
UR_CALL(EnqueueUSMBlockingSet(hQueue, (void *)DstShadow, 0, size, 1,
&hEventWait, &hEvent));
UR_CALL(EnqueueUSMBlockingSet(hQueue, (void *)DstShadow, 0, size, 0,
nullptr, &hEvents[1]));
}

if (phEvent) {
*phEvent = hEvent;
UR_CALL(getContext()->urDdiTable.Enqueue.pfnEventsWait(
hQueue, 2, hEvents, phEvent));
}

return UR_RESULT_SUCCESS;
Expand Down Expand Up @@ -1425,10 +1424,10 @@ ur_result_t UR_APICALL urEnqueueUSMFill2D(
auto pfnUSMFill2D = getContext()->urDdiTable.Enqueue.pfnUSMFill2D;
getContext()->logger.debug("==== urEnqueueUSMFill2D");

ur_event_handle_t hEvent = nullptr;
ur_event_handle_t hEvents[2] = {};
UR_CALL(pfnUSMFill2D(hQueue, pMem, pitch, patternSize, pPattern, width,
height, numEventsInWaitList, phEventWaitList,
&hEvent));
&hEvents[0]));

const auto Mem = (uptr)pMem;
auto MemInfoItOp = getMsanInterceptor()->findAllocInfoByAddress(Mem);
Expand All @@ -1440,13 +1439,13 @@ ur_result_t UR_APICALL urEnqueueUSMFill2D(
const auto MemShadow = DeviceInfo->Shadow->MemToShadow(Mem);

const char Pattern = 0;
const ur_event_handle_t hEventWait = hEvent;
UR_CALL(pfnUSMFill2D(hQueue, (void *)MemShadow, pitch, 1, &Pattern,
width, height, 1, &hEventWait, &hEvent));
width, height, 0, nullptr, &hEvents[1]));
}

if (phEvent) {
*phEvent = hEvent;
UR_CALL(getContext()->urDdiTable.Enqueue.pfnEventsWait(
hQueue, 2, hEvents, phEvent));
}

return UR_RESULT_SUCCESS;
Expand Down Expand Up @@ -1481,10 +1480,10 @@ ur_result_t UR_APICALL urEnqueueUSMMemcpy2D(
auto pfnUSMMemcpy2D = getContext()->urDdiTable.Enqueue.pfnUSMMemcpy2D;
getContext()->logger.debug("==== pfnUSMMemcpy2D");

ur_event_handle_t hEvent = nullptr;
ur_event_handle_t hEvents[2] = {};
UR_CALL(pfnUSMMemcpy2D(hQueue, blocking, pDst, dstPitch, pSrc, srcPitch,
width, height, numEventsInWaitList, phEventWaitList,
&hEvent));
&hEvents[0]));

const auto Src = (uptr)pSrc, Dst = (uptr)pDst;
auto SrcInfoItOp = getMsanInterceptor()->findAllocInfoByAddress(Src);
Expand All @@ -1499,10 +1498,9 @@ ur_result_t UR_APICALL urEnqueueUSMMemcpy2D(
const auto SrcShadow = DeviceInfo->Shadow->MemToShadow(Src);
const auto DstShadow = DeviceInfo->Shadow->MemToShadow(Dst);

const ur_event_handle_t hEventWait = hEvent;
UR_CALL(pfnUSMMemcpy2D(hQueue, blocking, (void *)DstShadow, dstPitch,
(void *)SrcShadow, srcPitch, width, height, 1,
&hEventWait, &hEvent));
(void *)SrcShadow, srcPitch, width, height, 0,
nullptr, &hEvents[1]));
} else if (DstInfoItOp) {
auto DstInfo = (*DstInfoItOp)->second;

Expand All @@ -1511,14 +1509,14 @@ ur_result_t UR_APICALL urEnqueueUSMMemcpy2D(
const auto DstShadow = DeviceInfo->Shadow->MemToShadow(Dst);

const char Pattern = 0;
const ur_event_handle_t hEventWait = hEvent;
UR_CALL(getContext()->urDdiTable.Enqueue.pfnUSMFill2D(
hQueue, (void *)DstShadow, dstPitch, 1, &Pattern, width, height, 1,
&hEventWait, &hEvent));
hQueue, (void *)DstShadow, dstPitch, 1, &Pattern, width, height, 0,
nullptr, &hEvents[1]));
}

if (phEvent) {
*phEvent = hEvent;
UR_CALL(getContext()->urDdiTable.Enqueue.pfnEventsWait(
hQueue, 2, hEvents, phEvent));
}

return UR_RESULT_SUCCESS;
Expand Down

0 comments on commit 3eeb2a1

Please sign in to comment.