Skip to content

Commit

Permalink
Add capability to set a default device for all threads (#699)
Browse files Browse the repository at this point in the history
  • Loading branch information
jeremyfelder authored Jan 5, 2025
1 parent 0c33c74 commit 29fe5de
Show file tree
Hide file tree
Showing 17 changed files with 245 additions and 55 deletions.
15 changes: 15 additions & 0 deletions docs/docs/icicle/programmers_guide/cpp.md
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,21 @@ eIcicleError result = icicle_set_device(device);
eIcicleError result = icicle_get_active_device(device);
```
### Setting and Getting the Default Device
You can set the default device for all threads:
```cpp
icicle::Device device = {"CUDA", 0}; // or other
eIcicleError result = icicle_set_default_device(device);
```

:::caution

Setting a default device should be done **once** from the main thread of the application. If another device or backend is needed for a specific thread [icicle_set_device](#setting-and-getting-active-device) should be used instead.

:::

### Querying Device Information

Retrieve the number of available devices and check if a pointer is allocated on the host or on the active device:
Expand Down
1 change: 1 addition & 0 deletions docs/docs/icicle/programmers_guide/general.md
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ ICICLE provides a device abstraction layer that allows you to interact with diff

- **Loading Backends**: Backends are loaded dynamically based on the environment configuration or a specified path.
- **Setting Active Device**: The active device for a thread can be set, allowing for targeted computation on a specific device.
- **Setting Default Device**: The default device for any thread without an active device can be set, removing the need to specify an alternative device on each thread. This is especially useful when running on a backend that is not the built-in CPU backend which is the default device to start.

## Streams

Expand Down
17 changes: 16 additions & 1 deletion docs/docs/icicle/programmers_guide/go.md
Original file line number Diff line number Diff line change
Expand Up @@ -27,12 +27,27 @@ result := runtime.LoadBackend("/path/to/backend/installdir", true)
You can set the active device for the current thread and retrieve it when needed:

```go
device = runtime.CreateDevice("CUDA", 0) // or other
device := runtime.CreateDevice("CUDA", 0) // or other
result := runtime.SetDevice(device)
// or query current (thread) device
activeDevice := runtime.GetActiveDevice()
```

### Setting and Getting the Default Device

You can set the default device for all threads:

```go
device := runtime.CreateDevice("CUDA", 0) // or other
defaultDevice := runtime.SetDefaultDevice(device);
```

:::caution

Setting a default device should be done **once** from the main thread of the application. If another device or backend is needed for a specific thread [runtime.SetDevice](#setting-and-getting-active-device) should be used instead.

:::

### Querying Device Information

Retrieve the number of available devices and check if a pointer is allocated on the host or on the active device:
Expand Down
15 changes: 15 additions & 0 deletions docs/docs/icicle/programmers_guide/rust.md
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,21 @@ icicle_runtime::set_device(&device).unwrap();
let active_device = icicle_runtime::get_active_device().unwrap();
```
### Setting and Getting the Default Device
You can set the default device for all threads:
```caution
let device = Device::new("CUDA", 0); // or other
let default_device = icicle_runtime::set_default_device(device);
```
:::note
Setting a default device should be done **once** from the main thread of the application. If another device or backend is needed for a specific thread [icicle_runtime::set_device](#setting-and-getting-active-device) should be used instead.
:::
### Querying Device Information
Retrieve the number of available devices and check if a pointer is allocated on the host or on the active device:
Expand Down
1 change: 1 addition & 0 deletions icicle/include/icicle/device_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,7 @@ namespace icicle {

public:
static eIcicleError set_thread_local_device(const Device& device);
static eIcicleError set_default_device(const Device& device);
static const Device& get_thread_local_device();
static const DeviceAPI* get_thread_local_deviceAPI();
static DeviceTracker& get_global_memory_tracker() { return sMemTracker; }
Expand Down
8 changes: 8 additions & 0 deletions icicle/include/icicle/runtime.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,14 @@ extern "C" eIcicleError icicle_load_backend_from_env_or_default();
*/
extern "C" eIcicleError icicle_set_device(const icicle::Device& device);

/**
* @brief Set default device for all threads
*
* @return eIcicleError::SUCCESS if successful, otherwise throws INVALID_DEVICE
*/
extern "C" eIcicleError icicle_set_default_device(const icicle::Device& device);

/**
* @brief Get active device for thread
*
Expand Down
16 changes: 16 additions & 0 deletions icicle/src/device_api.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,17 @@ namespace icicle {

const Device& get_default_device() { return m_default_device; }

eIcicleError set_default_device(const Device& dev)
{
if (!is_device_registered(dev.type)) {
ICICLE_LOG_ERROR << "Device type " + std::string(dev.type) + " is not valid as it has not been registered";
return eIcicleError::INVALID_DEVICE;
}

m_default_device = dev;
return eIcicleError::SUCCESS;
}

std::vector<std::string> get_registered_devices_list()
{
std::vector<std::string> registered_devices;
Expand Down Expand Up @@ -116,6 +127,11 @@ namespace icicle {
return default_deviceAPI.get();
}

eIcicleError DeviceAPI::set_default_device(const Device& dev)
{
return DeviceAPIRegistry::Global().set_default_device(dev);
}

/********************************************************************************** */

DeviceAPI* get_deviceAPI(const Device& device) { return DeviceAPIRegistry::Global().get_deviceAPI(device).get(); }
Expand Down
5 changes: 5 additions & 0 deletions icicle/src/runtime.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,11 @@ using namespace icicle;

extern "C" eIcicleError icicle_set_device(const Device& device) { return DeviceAPI::set_thread_local_device(device); }

extern "C" eIcicleError icicle_set_default_device(const Device& device)
{
return DeviceAPI::set_default_device(device);
}

extern "C" eIcicleError icicle_get_active_device(icicle::Device& device)
{
const Device& active_device = DeviceAPI::get_thread_local_device();
Expand Down
31 changes: 31 additions & 0 deletions icicle/tests/test_device_api.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@

#include <gtest/gtest.h>
#include <thread>
#include <iostream>

#include "icicle/runtime.h"
Expand All @@ -19,6 +20,36 @@ TEST_F(DeviceApiTest, UnregisteredDeviceError)
EXPECT_ANY_THROW(get_deviceAPI(dev));
}

TEST_F(DeviceApiTest, SetDefaultDevice)
{
icicle::Device active_dev = {UNKOWN_DEVICE, -1};

icicle::Device cpu_dev = {s_ref_device, 0};
EXPECT_NO_THROW(icicle_set_device(cpu_dev));
EXPECT_NO_THROW(icicle_get_active_device(active_dev));

ASSERT_EQ(cpu_dev, active_dev);

active_dev = {UNKOWN_DEVICE, -1};

icicle::Device gpu_dev = {s_main_device, 0};
EXPECT_NO_THROW(icicle_set_default_device(gpu_dev));

// setting a new default device doesn't override already set local thread devices
EXPECT_NO_THROW(icicle_get_active_device(active_dev));
ASSERT_EQ(cpu_dev, active_dev);

active_dev = {UNKOWN_DEVICE, -1};
auto thread_func = [&active_dev, &gpu_dev]() {
EXPECT_NO_THROW(icicle_get_active_device(active_dev));
ASSERT_EQ(gpu_dev, active_dev);
};

std::thread worker_thread(thread_func);

worker_thread.join();
}

TEST_F(DeviceApiTest, MemoryCopySync)
{
int input[2] = {1, 2};
Expand Down
6 changes: 6 additions & 0 deletions wrappers/golang/runtime/device.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,12 @@ func SetDevice(device *Device) EIcicleError {
return EIcicleError(cErr)
}

func SetDefaultDevice(device *Device) EIcicleError {
cDevice := (*C.Device)(unsafe.Pointer(device))
cErr := C.icicle_set_default_device(cDevice)
return EIcicleError(cErr)
}

func GetActiveDevice() (*Device, EIcicleError) {
device := CreateDevice("invalid", -1)
cDevice := (*C.Device)(unsafe.Pointer(&device))
Expand Down
1 change: 1 addition & 0 deletions wrappers/golang/runtime/include/runtime.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ typedef struct DeviceProperties DeviceProperties;
int icicle_load_backend(const char* path, bool is_recursive);
int icicle_load_backend_from_env_or_default();
int icicle_set_device(const Device* device);
int icicle_set_default_device(const Device* device);
int icicle_get_active_device(Device* device);
int icicle_is_host_memory(const void* ptr);
int icicle_is_active_device_memory(const void* ptr);
Expand Down
118 changes: 85 additions & 33 deletions wrappers/golang/runtime/tests/device_test.go
Original file line number Diff line number Diff line change
@@ -1,70 +1,122 @@
package tests

import (
"fmt"
"os/exec"
"runtime"
"strconv"
"strings"
"syscall"
"testing"

"github.com/ingonyama-zk/icicle/v3/wrappers/golang/runtime"
icicle_runtime "github.com/ingonyama-zk/icicle/v3/wrappers/golang/runtime"

"github.com/stretchr/testify/assert"
)

func TestGetDeviceType(t *testing.T) {
expectedDeviceName := "test"
config := runtime.CreateDevice(expectedDeviceName, 0)
config := icicle_runtime.CreateDevice(expectedDeviceName, 0)
assert.Equal(t, expectedDeviceName, config.GetDeviceType())

expectedDeviceNameLong := "testtesttesttesttesttesttesttesttesttesttesttesttesttesttesttest"
configLargeName := runtime.CreateDevice(expectedDeviceNameLong, 1)
configLargeName := icicle_runtime.CreateDevice(expectedDeviceNameLong, 1)
assert.NotEqual(t, expectedDeviceNameLong, configLargeName.GetDeviceType())
}

func TestIsDeviceAvailable(t *testing.T) {
runtime.LoadBackendFromEnvOrDefault()
dev := runtime.CreateDevice("CUDA", 0)
_ = runtime.SetDevice(&dev)
res, err := runtime.GetDeviceCount()

expectedNumDevices, error := exec.Command("nvidia-smi", "-L", "|", "wc", "-l").Output()
if error != nil {
t.Skip("Failed to get number of devices")
dev := icicle_runtime.CreateDevice("CUDA", 0)
_ = icicle_runtime.SetDevice(&dev)
res, err := icicle_runtime.GetDeviceCount()

smiCommand := exec.Command("nvidia-smi", "-L")
smiCommandStdout, _ := smiCommand.StdoutPipe()
wcCommand := exec.Command("wc", "-l")
wcCommand.Stdin = smiCommandStdout

smiCommand.Start()

expectedNumDevicesRaw, wcErr := wcCommand.Output()
smiCommand.Wait()

expectedNumDevicesAsString := strings.TrimRight(string(expectedNumDevicesRaw), " \n\r\t")
expectedNumDevices, _ := strconv.Atoi(expectedNumDevicesAsString)
if wcErr != nil {
t.Skip("Failed to get number of devices:", wcErr)
}

assert.Equal(t, runtime.Success, err)
assert.Equal(t, icicle_runtime.Success, err)
assert.Equal(t, expectedNumDevices, res)

err = runtime.LoadBackendFromEnvOrDefault()
assert.Equal(t, runtime.Success, err)
devCuda := runtime.CreateDevice("CUDA", 0)
assert.True(t, runtime.IsDeviceAvailable(&devCuda))
devCpu := runtime.CreateDevice("CPU", 0)
assert.True(t, runtime.IsDeviceAvailable(&devCpu))
devInvalid := runtime.CreateDevice("invalid", 0)
assert.False(t, runtime.IsDeviceAvailable(&devInvalid))
assert.Equal(t, icicle_runtime.Success, err)
devCuda := icicle_runtime.CreateDevice("CUDA", 0)
assert.True(t, icicle_runtime.IsDeviceAvailable(&devCuda))
devCpu := icicle_runtime.CreateDevice("CPU", 0)
assert.True(t, icicle_runtime.IsDeviceAvailable(&devCpu))
devInvalid := icicle_runtime.CreateDevice("invalid", 0)
assert.False(t, icicle_runtime.IsDeviceAvailable(&devInvalid))
}

func TestSetDefaultDevice(t *testing.T) {
runtime.LockOSThread()
defer runtime.UnlockOSThread()
tidOuter := syscall.Gettid()

gpuDevice := icicle_runtime.CreateDevice("CUDA", 0)
icicle_runtime.SetDefaultDevice(&gpuDevice)

activeDevice, err := icicle_runtime.GetActiveDevice()
assert.Equal(t, icicle_runtime.Success, err)
assert.Equal(t, gpuDevice, *activeDevice)

done := make(chan struct{}, 1)
go func() {
runtime.LockOSThread()
defer runtime.UnlockOSThread()

// Ensure we are operating on an OS thread other than the original one
tidInner := syscall.Gettid()
for tidInner == tidOuter {
fmt.Println("Locked thread is the same as original, getting new locked thread")
runtime.UnlockOSThread()
runtime.LockOSThread()
tidInner = syscall.Gettid()
}

activeDevice, err := icicle_runtime.GetActiveDevice()
assert.Equal(t, icicle_runtime.Success, err)
assert.Equal(t, gpuDevice, *activeDevice)

close(done)
}()

<-done

cpuDevice := icicle_runtime.CreateDevice("CPU", 0)
icicle_runtime.SetDefaultDevice(&cpuDevice)
}

func TestRegisteredDevices(t *testing.T) {
err := runtime.LoadBackendFromEnvOrDefault()
assert.Equal(t, runtime.Success, err)
devices, _ := runtime.GetRegisteredDevices()
devices, _ := icicle_runtime.GetRegisteredDevices()
assert.Equal(t, []string{"CUDA", "CPU"}, devices)
}

func TestDeviceProperties(t *testing.T) {
_, err := runtime.GetDeviceProperties()
assert.Equal(t, runtime.Success, err)
_, err := icicle_runtime.GetDeviceProperties()
assert.Equal(t, icicle_runtime.Success, err)
}

func TestActiveDevice(t *testing.T) {
runtime.SetDevice(&DEVICE)
activeDevice, err := runtime.GetActiveDevice()
assert.Equal(t, runtime.Success, err)
assert.Equal(t, DEVICE, *activeDevice)
memory1, err := runtime.GetAvailableMemory()
if err == runtime.ApiNotImplemented {
t.Skipf("GetAvailableMemory() function is not implemented on %s device", DEVICE.GetDeviceType())
devCpu := icicle_runtime.CreateDevice("CUDA", 0)
icicle_runtime.SetDevice(&devCpu)
activeDevice, err := icicle_runtime.GetActiveDevice()
assert.Equal(t, icicle_runtime.Success, err)
assert.Equal(t, devCpu, *activeDevice)
memory1, err := icicle_runtime.GetAvailableMemory()
if err == icicle_runtime.ApiNotImplemented {
t.Skipf("GetAvailableMemory() function is not implemented on %s device", devCpu.GetDeviceType())
}
assert.Equal(t, runtime.Success, err)
assert.Equal(t, icicle_runtime.Success, err)
assert.Greater(t, memory1.Total, uint(0))
assert.Greater(t, memory1.Free, uint(0))
}
14 changes: 1 addition & 13 deletions wrappers/golang/runtime/tests/main_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,19 +6,7 @@ import (
"github.com/ingonyama-zk/icicle/v3/wrappers/golang/runtime"
)

var DEVICE runtime.Device

func TestMain(m *testing.M) {
runtime.LoadBackendFromEnvOrDefault()
devices, e := runtime.GetRegisteredDevices()
if e != runtime.Success {
panic("Failed to load registered devices")
}
for _, deviceType := range devices {
DEVICE = runtime.CreateDevice(deviceType, 0)
runtime.SetDevice(&DEVICE)

// execute tests
m.Run()
}
m.Run()
}
Loading

0 comments on commit 29fe5de

Please sign in to comment.