Skip to content

Commit

Permalink
fix: data movement tasks only perform read operations
Browse files Browse the repository at this point in the history
  • Loading branch information
wlruys committed Jan 17, 2024
1 parent 9fb67f6 commit 7f1a52b
Show file tree
Hide file tree
Showing 6 changed files with 28 additions and 14 deletions.
4 changes: 2 additions & 2 deletions src/c/backend/include/device_manager.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ class DeviceManager {

// TODO(hc): use a customized type for device id.

const DevID_t globalid_to_parrayid(DevID_t global_dev_id) const {
const int globalid_to_parrayid(unsigned int global_dev_id) const {
Device *dev = all_devices_[global_dev_id];
if (dev->get_type() == DeviceType::CPU) {
return -1;
Expand All @@ -107,7 +107,7 @@ class DeviceManager {
}
}

const int parrayid_to_globalid(DevID_t parray_dev_id) const {
const unsigned int parrayid_to_globalid(int parray_dev_id) const {
if (parray_dev_id == -1) {
// XXX: This assumes that a CPU device is always single and
// is added at first.
Expand Down
4 changes: 2 additions & 2 deletions src/python/parla/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,10 +99,10 @@ def __enter__(self):
self.original_handler = signal.getsignal(self.sig)

def handler(signum, frame):
print("YOU PRESSED CTRL+C, INTERRUPTING ALL TASKS", flush=True)
print("Attempting to interurpt all running tasks...", flush=True)
self._sched.stop()
self.release()
self.interrupted = True
self.interuppted = True

signal.signal(self.sig, handler)
except ValueError:
Expand Down
14 changes: 9 additions & 5 deletions src/python/parla/common/globals.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,6 @@
USE_PYTHON_RUNAHEAD = os.getenv("PARLA_ENABLE_PYTHON_RUNAHEAD", "1") == "1"
PREINIT_THREADS = os.getenv("PARLA_PREINIT_THREADS", "1") == "1"

print("USE_PYTHON_RUNAHEAD: ", USE_PYTHON_RUNAHEAD)
print("CUPY_ENABLED: ", CUPY_ENABLED)
print("PREINIT_THREADS: ", PREINIT_THREADS)

_global_data_tasks = {}


Expand All @@ -65,8 +61,16 @@ class SynchronizationType(IntEnum):
else:
default_sync = SynchronizationType.NON_BLOCKING

print("DEFAULT SYNC: ", default_sync)

def print_config():
print("Parla Configuration", flush=True)
print("-------------------", flush=True)
print("Cupy Found: ", CUPY_ENABLED, flush=True)
print("Crosspy Found: ", CROSSPY_ENABLED, flush=True)
print("Preinitialize Cupy + Handles in Threads: ", PREINIT_THREADS, flush=True)
print("Runahead Scheduling Backend: ", USE_PYTHON_RUNAHEAD, flush=True)
print("Default Runahead Behavior: ", default_sync, flush=True)
print("VCU Precision: ", VCU_BASELINE, flush=True)

class DeviceType(IntEnum):
"""
Expand Down
5 changes: 2 additions & 3 deletions src/python/parla/cython/device_manager.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,8 @@ cdef extern from "include/device_manager.hpp" nogil:
DeviceManager() except +
void register_device(Device*) except +
void print_registered_devices() except +
int globalid_to_parrayid(int) except +
int parrayid_to_globalid(int) except +

int globalid_to_parrayid(unsigned int) except +
unsigned int parrayid_to_globalid(int) except +

cdef class CyDeviceManager:
cdef DeviceManager* cpp_device_manager_
Expand Down
10 changes: 10 additions & 0 deletions src/python/parla/cython/scheduler.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,8 @@ class WorkerThread(ControllableThread, SchedulerContext):
with self.scheduler.start_monitor:
self.scheduler.start_monitor.notify_all()

device_manager = self.scheduler.device_manager

while self._should_run:
self.status = "Waiting"

Expand Down Expand Up @@ -246,6 +248,14 @@ class WorkerThread(ControllableThread, SchedulerContext):
Locals.push_task(active_task)

with device_context as env:

if isinstance(active_task, ComputeTask):
# Perform write invalidations
for parray, target_idx in active_task.dataflow.inout:
target_device = parla_devices[target_idx]
global_target_id = target_device.get_global_id()
parray_target_id = device_manager.globalid_to_parrayid(global_target_id)
parray._auto_move(parray_target_id, True)

core.binlog_2("Worker", "Running task: ", active_task.inner_task, " on worker: ", self.inner_worker)
# Run the task body (this may complete the task or return a continuation)
Expand Down
5 changes: 3 additions & 2 deletions src/python/parla/cython/tasks.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -714,8 +714,9 @@ class DataMovementTask(Task):
Devices are given by the local relative device id within the TaskEnvironment.
"""

print(f"Running data movement task: {self.name}, {self.parray.name} {self.access_mode}", flush=True)
write_flag = True if self.access_mode != AccessMode.IN else False
# write_flag = True if self.access_mode != AccessMode.IN else False
# Data movement tasks should only perform read operations
write_flag = False

# TODO: Get device manager from task environment instead of scheduler at creation time
device_manager = self.scheduler.device_manager
Expand Down

0 comments on commit 7f1a52b

Please sign in to comment.