Skip to content

Commit

Permalink
[ROCM] bindings with kernels and device local host (iree-org#6558)
Browse files Browse the repository at this point in the history
Following CUDA backend updates:
-Fix mapping of bindings to kernel argument with multiple sets.
-Fix allocation for device local host visible case.
  • Loading branch information
raikonenfnu authored Jul 27, 2021
1 parent 7968136 commit 628877c
Show file tree
Hide file tree
Showing 7 changed files with 55 additions and 8 deletions.
9 changes: 9 additions & 0 deletions experimental/rocm/descriptor_set_layout.c
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
typedef struct iree_hal_rocm_descriptor_set_layout_t {
iree_hal_resource_t resource;
iree_hal_rocm_context_wrapper_t *context;
iree_host_size_t binding_count;
} iree_hal_rocm_descriptor_set_layout_t;

extern const iree_hal_descriptor_set_layout_vtable_t
Expand Down Expand Up @@ -46,13 +47,21 @@ iree_status_t iree_hal_rocm_descriptor_set_layout_create(
iree_hal_resource_initialize(&iree_hal_rocm_descriptor_set_layout_vtable,
&descriptor_set_layout->resource);
descriptor_set_layout->context = context;
descriptor_set_layout->binding_count = binding_count;
*out_descriptor_set_layout =
(iree_hal_descriptor_set_layout_t *)descriptor_set_layout;
}
IREE_TRACE_ZONE_END(z0);
return status;
}

iree_host_size_t iree_hal_rocm_descriptor_set_layout_binding_count(
iree_hal_descriptor_set_layout_t *base_descriptor_set_layout) {
iree_hal_rocm_descriptor_set_layout_t *descriptor_set_layout =
iree_hal_rocm_descriptor_set_layout_cast(base_descriptor_set_layout);
return descriptor_set_layout->binding_count;
}

static void iree_hal_rocm_descriptor_set_layout_destroy(
iree_hal_descriptor_set_layout_t *base_descriptor_set_layout) {
iree_hal_rocm_descriptor_set_layout_t *descriptor_set_layout =
Expand Down
4 changes: 4 additions & 0 deletions experimental/rocm/descriptor_set_layout.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,10 @@ iree_status_t iree_hal_rocm_descriptor_set_layout_create(
const iree_hal_descriptor_set_layout_binding_t *bindings,
iree_hal_descriptor_set_layout_t **out_descriptor_set_layout);

// Return the binding count for the given descriptor set layout.
iree_host_size_t iree_hal_rocm_descriptor_set_layout_binding_count(
iree_hal_descriptor_set_layout_t *descriptor_set_layout);

#ifdef __cplusplus
} // extern "C"
#endif // __cplusplus
Expand Down
6 changes: 5 additions & 1 deletion experimental/rocm/direct_command_buffer.c
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
#include <stdint.h>

#include "experimental/rocm/dynamic_symbols.h"
#include "experimental/rocm/executable_layout.h"
#include "experimental/rocm/native_executable.h"
#include "experimental/rocm/rocm_buffer.h"
#include "experimental/rocm/status_util.h"
Expand Down Expand Up @@ -283,6 +284,8 @@ static iree_status_t iree_hal_rocm_direct_command_buffer_push_descriptor_set(
const iree_hal_descriptor_set_binding_t* bindings) {
iree_hal_rocm_direct_command_buffer_t* command_buffer =
iree_hal_rocm_direct_command_buffer_cast(base_command_buffer);
iree_host_size_t base_binding =
iree_hal_rocm_base_binding_index(executable_layout, set);
// Convention with the compiler side. We map bindings to kernel argument.
// We compact the bindings to get a dense set of arguments and keep them order
// based on the binding index.
Expand All @@ -303,7 +306,8 @@ static iree_status_t iree_hal_rocm_direct_command_buffer_push_descriptor_set(
iree_hal_rocm_buffer_device_pointer(
iree_hal_buffer_allocated_buffer(binding.buffer)) +
iree_hal_buffer_byte_offset(binding.buffer) + binding.offset;
*((hipDeviceptr_t*)command_buffer->current_descriptor[i]) = device_ptr;
*((hipDeviceptr_t*)command_buffer->current_descriptor[i + base_binding]) =
device_ptr;
}
return iree_ok_status();
}
Expand Down
1 change: 1 addition & 0 deletions experimental/rocm/dynamic_symbol_tables.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ RC_PFN_DECL(hipMemcpy, void *, const void *, size_t, hipMemcpyKind)
RC_PFN_DECL(hipMemcpyAsync, void *, const void *, size_t, hipMemcpyKind,
hipStream_t)
RC_PFN_DECL(hipMalloc, void **, size_t)
RC_PFN_DECL(hipMallocManaged, hipDeviceptr_t *, size_t, unsigned int)
RC_PFN_DECL(hipFree, void *)
RC_PFN_DECL(hipHostFree, void *)
RC_PFN_DECL(hipMemAllocHost, void **, size_t, unsigned int)
Expand Down
15 changes: 15 additions & 0 deletions experimental/rocm/executable_layout.c
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

#include <stddef.h>

#include "experimental/rocm/descriptor_set_layout.h"
#include "iree/base/api.h"
#include "iree/base/tracing.h"

Expand Down Expand Up @@ -76,6 +77,20 @@ static void iree_hal_rocm_executable_layout_destroy(
IREE_TRACE_ZONE_END(z0);
}

iree_host_size_t iree_hal_rocm_base_binding_index(
iree_hal_executable_layout_t *base_executable_layout, uint32_t set) {
iree_hal_rocm_executable_layout_t *executable_layout =
iree_hal_rocm_executable_layout_cast(base_executable_layout);
iree_host_size_t base_binding = 0;
for (iree_host_size_t i = 0; i < set; ++i) {
iree_host_size_t binding_count =
iree_hal_rocm_descriptor_set_layout_binding_count(
executable_layout->set_layouts[i]);
base_binding += binding_count;
}
return base_binding;
}

const iree_hal_executable_layout_vtable_t
iree_hal_rocm_executable_layout_vtable = {
.destroy = iree_hal_rocm_executable_layout_destroy,
Expand Down
4 changes: 4 additions & 0 deletions experimental/rocm/executable_layout.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,10 @@ iree_status_t iree_hal_rocm_executable_layout_create(
iree_host_size_t push_constant_count,
iree_hal_executable_layout_t **out_executable_layout);

// Return the base binding index for the given set.
iree_host_size_t iree_hal_rocm_base_binding_index(
iree_hal_executable_layout_t *executable_layout, uint32_t set);

#ifdef __cplusplus
} // extern "C"
#endif // __cplusplus
Expand Down
24 changes: 17 additions & 7 deletions experimental/rocm/rocm_allocator.c
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,19 @@ static iree_status_t iree_hal_rocm_allocator_allocate_buffer(
iree_status_t status;
void *host_ptr = NULL;
hipDeviceptr_t device_ptr = 0;
if (iree_all_bits_set(memory_type, IREE_HAL_MEMORY_TYPE_HOST_VISIBLE)) {
if (iree_all_bits_set(memory_type, IREE_HAL_MEMORY_TYPE_DEVICE_LOCAL)) {
// Device local case.
if (iree_all_bits_set(memory_type, IREE_HAL_MEMORY_TYPE_HOST_VISIBLE)) {
status = ROCM_RESULT_TO_STATUS(
allocator->context->syms,
hipMallocManaged(&device_ptr, allocation_size, hipMemAttachGlobal));
host_ptr = (void *)device_ptr;
} else {
// Device only.
status = ROCM_RESULT_TO_STATUS(allocator->context->syms,
hipMalloc(&device_ptr, allocation_size));
}
} else {
unsigned int flags = hipHostMallocMapped;
if (!iree_all_bits_set(memory_type, IREE_HAL_MEMORY_TYPE_HOST_CACHED)) {
flags |= hipHostMallocWriteCombined;
Expand All @@ -121,9 +133,6 @@ static iree_status_t iree_hal_rocm_allocator_allocate_buffer(
allocator->context->syms,
hipHostGetDevicePointer(&device_ptr, host_ptr, /*flags=*/0));
}
} else {
status = ROCM_RESULT_TO_STATUS(allocator->context->syms,
hipMalloc(&device_ptr, allocation_size));
}

if (iree_status_is_ok(status)) {
Expand All @@ -145,10 +154,11 @@ void iree_hal_rocm_allocator_free(iree_hal_allocator_t *base_allocator,
iree_hal_memory_type_t memory_type) {
iree_hal_rocm_allocator_t *allocator =
iree_hal_rocm_allocator_cast(base_allocator);
if (iree_all_bits_set(memory_type, IREE_HAL_MEMORY_TYPE_HOST_VISIBLE)) {
ROCM_IGNORE_ERROR(allocator->context->syms, hipHostFree(host_ptr));
} else {
if (iree_all_bits_set(memory_type, IREE_HAL_MEMORY_TYPE_DEVICE_LOCAL)) {
ROCM_IGNORE_ERROR(allocator->context->syms, hipFree(device_ptr));
} else {
// Host local.
ROCM_IGNORE_ERROR(allocator->context->syms, hipHostFree(host_ptr));
}
}

Expand Down

0 comments on commit 628877c

Please sign in to comment.