Skip to content

Commit

Permalink
Return invalid value from when calling USM alloc with a non-power of …
Browse files Browse the repository at this point in the history
…2 alignment value
  • Loading branch information
lbushi25 committed May 1, 2024
1 parent e8b8722 commit 5dbfda0
Showing 1 changed file with 22 additions and 23 deletions.
45 changes: 22 additions & 23 deletions source/adapters/level_zero/usm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -308,7 +308,8 @@ UR_APIEXPORT ur_result_t UR_APICALL urUSMHostAlloc(
uint32_t Align = USMDesc ? USMDesc->align : 0;
// L0 supports alignment up to 64KB and silently ignores higher values.
// We flag alignment > 64KB as an invalid value.
if (Align > 65536)
// L0 spec says alignment values that are not powers of 2 are invalid.
if (Align > 65536 || Align && (Align - 1) != 0)
return UR_RESULT_ERROR_INVALID_VALUE;

ur_platform_handle_t Plt = Context->getPlatform();
Expand Down Expand Up @@ -337,32 +338,30 @@ UR_APIEXPORT ur_result_t UR_APICALL urUSMHostAlloc(
// find the allocator depending on context as we do for Shared and Device
// allocations.
umf_memory_pool_handle_t hPoolInternal = nullptr;
if (!UseUSMAllocator ||
// L0 spec says that allocation fails if Alignment != 2^n, in order to
// keep the same behavior for the allocator, just call L0 API directly and
// return the error code.
((Align & (Align - 1)) != 0)) {
if (!UseUSMAllocator)
hPoolInternal = Context->HostMemProxyPool.get();
} else if (Pool) {
hPoolInternal = Pool->HostMemPool.get();
} else {
hPoolInternal = Context->HostMemPool.get();
}
}
else if (Pool) {
hPoolInternal = Pool->HostMemPool.get();
}
else {
hPoolInternal = Context->HostMemPool.get();
}

*RetMem = umfPoolAlignedMalloc(hPoolInternal, Size, Align);
if (*RetMem == nullptr) {
auto umfRet = umfPoolGetLastAllocationError(hPoolInternal);
return umf2urResult(umfRet);
}
*RetMem = umfPoolAlignedMalloc(hPoolInternal, Size, Align);
if (*RetMem == nullptr) {
auto umfRet = umfPoolGetLastAllocationError(hPoolInternal);
return umf2urResult(umfRet);
}

if (IndirectAccessTrackingEnabled) {
// Keep track of all memory allocations in the context
Context->MemAllocs.emplace(std::piecewise_construct,
std::forward_as_tuple(*RetMem),
std::forward_as_tuple(Context));
}
if (IndirectAccessTrackingEnabled) {
// Keep track of all memory allocations in the context
Context->MemAllocs.emplace(std::piecewise_construct,
std::forward_as_tuple(*RetMem),
std::forward_as_tuple(Context));
}

return UR_RESULT_SUCCESS;
return UR_RESULT_SUCCESS;
}

UR_APIEXPORT ur_result_t UR_APICALL urUSMDeviceAlloc(
Expand Down

0 comments on commit 5dbfda0

Please sign in to comment.