Skip to content

Commit

Permalink
Include unittests that reproduce the plugin issue with memory segments
Browse files Browse the repository at this point in the history
  • Loading branch information
mairooni authored and gigiblender committed Feb 12, 2025
1 parent e7498e8 commit a3238d5
Showing 1 changed file with 76 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import java.lang.foreign.Arena;
import java.lang.foreign.MemorySegment;
import java.lang.foreign.ValueLayout;
import java.nio.DoubleBuffer;
import java.util.stream.IntStream;

import org.junit.Test;
Expand All @@ -33,6 +34,7 @@
import uk.ac.manchester.tornado.api.TaskGraph;
import uk.ac.manchester.tornado.api.TornadoExecutionPlan;
import uk.ac.manchester.tornado.api.TornadoExecutionResult;
import uk.ac.manchester.tornado.api.annotations.Parallel;
import uk.ac.manchester.tornado.api.enums.DataTransferMode;
import uk.ac.manchester.tornado.api.exceptions.TornadoExecutionPlanException;
import uk.ac.manchester.tornado.api.types.HalfFloat;
Expand Down Expand Up @@ -505,5 +507,79 @@ public void testBuildWithSegmentsWrongSize() {
IntArray intArray = IntArray.fromSegment(m);
}

public static void simpleAddition(DoubleArray a, DoubleArray b) {
for (@Parallel int i = 0; i < b.getSize(); i++) {
b.set(i, a.get(i) + 1);
}
}

@Test
public void testArrayFromMemorySegment() throws TornadoExecutionPlanException {
final int numberOfElements = 256;

double[] someArray = new double[numberOfElements];
IntStream.range(0, numberOfElements).sequential().forEach(i -> {
someArray[i] = i;
});

MemorySegment a = MemorySegment.ofArray(someArray);
DoubleArray dataA = DoubleArray.fromSegment(a);

DoubleArray dataB = new DoubleArray(numberOfElements);

DoubleArray dataBSeq = new DoubleArray(numberOfElements);

TaskGraph taskGraph = new TaskGraph("s0") //
.transferToDevice(DataTransferMode.EVERY_EXECUTION, dataA) //
.task("t0", TestAPI::simpleAddition, dataA, dataB) //
.transferToHost(DataTransferMode.EVERY_EXECUTION, dataB);

ImmutableTaskGraph immutableTaskGraph = taskGraph.snapshot();
try (TornadoExecutionPlan executionPlan = new TornadoExecutionPlan(immutableTaskGraph)) {
executionPlan.execute();
}

simpleAddition(dataA, dataBSeq);

for (int i = 0; i < dataB.getSize(); i++) {
assertEquals(dataBSeq.get(i), dataB.get(i), 0.01f);
}
}

@Test
public void testArrayFromBuffer() throws TornadoExecutionPlanException {
final int numberOfElements = 256;

double[] someArray = new double[numberOfElements];
IntStream.range(0, numberOfElements).sequential().forEach(i -> {
someArray[i] = i;
});

DoubleBuffer buffer = DoubleBuffer.allocate(someArray.length);
buffer.put(someArray);
buffer.flip();
DoubleArray dataA = DoubleArray.fromDoubleBuffer(buffer);

DoubleArray dataB = new DoubleArray(numberOfElements);

DoubleArray dataBSeq = new DoubleArray(numberOfElements);

TaskGraph taskGraph = new TaskGraph("s0") //
.transferToDevice(DataTransferMode.EVERY_EXECUTION, dataA) //
.task("t0", TestAPI::simpleAddition, dataA, dataB) //
.transferToHost(DataTransferMode.EVERY_EXECUTION, dataB);

ImmutableTaskGraph immutableTaskGraph = taskGraph.snapshot();
try (TornadoExecutionPlan executionPlan = new TornadoExecutionPlan(immutableTaskGraph)) {
executionPlan.execute();
}

simpleAddition(dataA, dataBSeq);

for (int i = 0; i < dataB.getSize(); i++) {
assertEquals(dataBSeq.get(i), dataB.get(i), 0.01f);
}
}

// CHECKSTYLE:ON
}

0 comments on commit a3238d5

Please sign in to comment.