Skip to content

Commit

Permalink
Improve fill op implementation
Browse files Browse the repository at this point in the history
Match the CUDA change from
oneapi-src#1319
in HIP.
  • Loading branch information
EwanC committed Feb 19, 2024
1 parent 25b0843 commit 90bd325
Showing 1 changed file with 29 additions and 16 deletions.
45 changes: 29 additions & 16 deletions source/adapters/hip/command_buffer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,6 @@ static ur_result_t enqueueCommandBufferFillHelper(

try {
const size_t N = Size / PatternSize;
auto Value = *static_cast<const uint32_t *>(Pattern);
auto DstPtr = DstType == hipMemoryTypeDevice
? *static_cast<hipDeviceptr_t *>(DstDevice)
: DstDevice;
Expand All @@ -168,9 +167,27 @@ static ur_result_t enqueueCommandBufferFillHelper(
NodeParams.elementSize = PatternSize;
NodeParams.height = N;
NodeParams.pitch = PatternSize;
NodeParams.value = Value;
NodeParams.width = 1;

// pattern size in bytes
switch (PatternSize) {
case 1: {
auto Value = *static_cast<const uint8_t *>(Pattern);
NodeParams.value = Value;
break;
}
case 2: {
auto Value = *static_cast<const uint16_t *>(Pattern);
NodeParams.value = Value;
break;
}
case 4: {
auto Value = *static_cast<const uint32_t *>(Pattern);
NodeParams.value = Value;
break;
}
}

UR_CHECK_ERROR(hipGraphAddMemsetNode(&GraphNode, CommandBuffer->HIPGraph,
DepsList.data(), DepsList.size(),
&NodeParams));
Expand All @@ -187,15 +204,10 @@ static ur_result_t enqueueCommandBufferFillHelper(
// This means that one hipGraphAddMemsetNode call is made for every
// byte in the pattern.

// List to handle inter-node dependencies
std::vector<hipGraphNode_t> HIPNodesList = {};
// List shared pointer that will point to the last node created
std::shared_ptr<hipGraphNode_t> GraphNodePtr;

size_t NumberOfSteps = PatternSize / sizeof(uint8_t);

// take 4 bytes of the pattern
auto ValueFirst = *(static_cast<const uint32_t *>(Pattern));
// Shared pointer that will point to the last node created
std::shared_ptr<hipGraphNode_t> GraphNodePtr;

// Create a new node
hipGraphNode_t GraphNodeFirst;
Expand All @@ -205,7 +217,7 @@ static ur_result_t enqueueCommandBufferFillHelper(
NodeParamsStepFirst.elementSize = 4;
NodeParamsStepFirst.height = Size / sizeof(uint32_t);
NodeParamsStepFirst.pitch = 4;
NodeParamsStepFirst.value = ValueFirst;
NodeParamsStepFirst.value = *(static_cast<const uint32_t *>(Pattern));
NodeParamsStepFirst.width = 1;

UR_CHECK_ERROR(hipGraphAddMemsetNode(
Expand All @@ -216,7 +228,8 @@ static ur_result_t enqueueCommandBufferFillHelper(
*SyncPoint = CommandBuffer->addSyncPoint(
std::make_shared<hipGraphNode_t>(GraphNodeFirst));

HIPNodesList.push_back(GraphNodeFirst);
DepsList.clear();
DepsList.push_back(GraphNodeFirst);

// We walk up the pattern in 1-byte steps, and add a memset node for each
// 1-byte chunk of the pattern.
Expand All @@ -233,22 +246,22 @@ static ur_result_t enqueueCommandBufferFillHelper(
// Update NodeParam
hipMemsetParams NodeParamsStep = {};
NodeParamsStep.dst = reinterpret_cast<void *>(OffsetPtr);
NodeParamsStep.elementSize = 1;
NodeParamsStep.elementSize = sizeof(uint8_t);
NodeParamsStep.height = Size / NumberOfSteps;
NodeParamsStep.pitch = NumberOfSteps * sizeof(uint8_t);
NodeParamsStep.value = Value;
NodeParamsStep.width = 1;

UR_CHECK_ERROR(hipGraphAddMemsetNode(
&GraphNode, CommandBuffer->HIPGraph, HIPNodesList.data(),
HIPNodesList.size(), &NodeParamsStep));
&GraphNode, CommandBuffer->HIPGraph, DepsList.data(),
DepsList.size(), &NodeParamsStep));

GraphNodePtr = std::make_shared<hipGraphNode_t>(GraphNode);
// Get sync point and register the node with it.
*SyncPoint = CommandBuffer->addSyncPoint(GraphNodePtr);

HIPNodesList.clear();
HIPNodesList.push_back(*GraphNodePtr.get());
DepsList.clear();
DepsList.push_back(*GraphNodePtr.get());
}
}
} catch (ur_result_t Err) {
Expand Down

0 comments on commit 90bd325

Please sign in to comment.