Skip to content

Commit

Permalink
Replace remaining example-toolkit usages (#556)
Browse files Browse the repository at this point in the history
  • Loading branch information
mhawryluk authored Nov 14, 2024
1 parent 6ec4fd4 commit 5e6485b
Show file tree
Hide file tree
Showing 18 changed files with 530 additions and 489 deletions.
Original file line number Diff line number Diff line change
@@ -1,7 +1,3 @@
// -- Hooks into the example environment
import { addSliderPlumParameter } from '@typegpu/example-toolkit';
// --

import { type Parsed, arrayOf, f32, struct, vec2f } from 'typegpu/data';
import tgpu, {
asMutable,
Expand All @@ -21,53 +17,29 @@ const MatrixStruct = struct({
numbers: arrayOf(f32, MAX_MATRIX_SIZE ** 2),
});

const paramSettings = {
min: 1,
max: MAX_MATRIX_SIZE,
step: 1,
};

/**
* Used to force recomputation of all matrices.
*/
const forceShufflePlum = wgsl.plum<number>(0);
const firstRowCountPlum = addSliderPlumParameter('#1 rows', 3, paramSettings);
const firstColumnCountPlum = addSliderPlumParameter(
'#1 columns',
4,
paramSettings,
);
const secondColumnCountPlum = addSliderPlumParameter(
'#2 columns',
2,
paramSettings,
);

const firstMatrixPlum = wgsl.plum((get) => {
get(forceShufflePlum); // depending to force recomputation

return createMatrix(
vec2f(get(firstRowCountPlum), get(firstColumnCountPlum)),
() => Math.floor(Math.random() * 10),
);
});

const secondMatrixPlum = wgsl.plum((get) => {
get(forceShufflePlum); // depending to force recomputation
let firstRowCount = 3;
let firstColumnCount = 4;
let secondColumnCount = 2;

return createMatrix(
vec2f(get(firstColumnCountPlum), get(secondColumnCountPlum)),
() => Math.floor(Math.random() * 10),
);
});
function createMatrix(
size: vec2f,
initValue: (row: number, col: number) => number,
) {
return {
size: size,
numbers: Array(size.x * size.y)
.fill(0)
.map((_, i) => initValue(Math.floor(i / size.y), i % size.y)),
};
}

const firstMatrixBuffer = root
.createBuffer(MatrixStruct, firstMatrixPlum)
.createBuffer(MatrixStruct)
.$name('first_matrix')
.$usage('storage');

const secondMatrixBuffer = root
.createBuffer(MatrixStruct, secondMatrixPlum)
.createBuffer(MatrixStruct)
.$name('second_matrix')
.$usage('storage');

Expand All @@ -83,45 +55,60 @@ const resultMatrixData = asMutable(resultMatrixBuffer);
const program = root.makeComputePipeline({
workgroupSize: workgroupSize,
code: wgsl`
let global_id = ${builtin.globalInvocationId};
if (global_id.x >= u32(${firstMatrixData}.size.x) || global_id.y >= u32(${secondMatrixData}.size.y)) {
return;
}
if (global_id.x + global_id.y == 0u) {
${resultMatrixData}.size = vec2(${firstMatrixData}.size.x, ${secondMatrixData}.size.y);
}
let resultCell = vec2(global_id.x, global_id.y);
var result = 0.0;
for (var i = 0u; i < u32(${firstMatrixData}.size.y); i = i + 1u) {
let a = i + resultCell.x * u32(${firstMatrixData}.size.y);
let b = resultCell.y + i * u32(${secondMatrixData}.size.y);
result = result + ${firstMatrixData}.numbers[a] * ${secondMatrixData}.numbers[b];
}
let index = resultCell.y + resultCell.x * u32(${secondMatrixData}.size.y);
${resultMatrixData}.numbers[index] = result;
let global_id = ${builtin.globalInvocationId};
if (global_id.x >= u32(${firstMatrixData}.size.x) || global_id.y >= u32(${secondMatrixData}.size.y)) {
return;
}
if (global_id.x + global_id.y == 0u) {
${resultMatrixData}.size = vec2(${firstMatrixData}.size.x, ${secondMatrixData}.size.y);
}
let resultCell = vec2(global_id.x, global_id.y);
var result = 0.0;
for (var i = 0u; i < u32(${firstMatrixData}.size.y); i = i + 1u) {
let a = i + resultCell.x * u32(${firstMatrixData}.size.y);
let b = resultCell.y + i * u32(${secondMatrixData}.size.y);
result = result + ${firstMatrixData}.numbers[a] * ${secondMatrixData}.numbers[b];
}
let index = resultCell.y + resultCell.x * u32(${secondMatrixData}.size.y);
${resultMatrixData}.numbers[index] = result;
`,
});

async function run() {
const firstMatrix = createMatrix(vec2f(firstRowCount, firstColumnCount), () =>
Math.floor(Math.random() * 10),
);
const secondMatrix = createMatrix(
vec2f(firstColumnCount, secondColumnCount),
() => Math.floor(Math.random() * 10),
);

firstMatrixBuffer.write(firstMatrix);
secondMatrixBuffer.write(secondMatrix);

const workgroupCountX = Math.ceil(firstMatrix.size.x / workgroupSize[0]);
const workgroupCountY = Math.ceil(secondMatrix.size.y / workgroupSize[1]);

program.execute({ workgroups: [workgroupCountX, workgroupCountY] });
const multiplicationResult = await resultMatrixBuffer.read();

printMatrixToHtml(firstTable, firstMatrix);
printMatrixToHtml(secondTable, secondMatrix);
printMatrixToHtml(resultTable, multiplicationResult);
}

run();

// #region UI

const firstTable = document.querySelector('.matrix-a') as HTMLDivElement;
const secondTable = document.querySelector('.matrix-b') as HTMLDivElement;
const resultTable = document.querySelector('.matrix-result') as HTMLDivElement;

function createMatrix(
size: vec2f,
initValue: (row: number, col: number) => number,
) {
return {
size: size,
numbers: Array(size.x * size.y)
.fill(0)
.map((_, i) => initValue(Math.floor(i / size.y), i % size.y)),
};
}

function printMatrixToHtml(
element: HTMLDivElement,
matrix: Parsed<typeof MatrixStruct>,
Expand All @@ -133,31 +120,54 @@ function printMatrixToHtml(
.join('');
}

async function run() {
const firstMatrix = root.readPlum(firstMatrixPlum);
const secondMatrix = root.readPlum(secondMatrixPlum);
const workgroupCountX = Math.ceil(firstMatrix.size.x / workgroupSize[0]);
const workgroupCountY = Math.ceil(secondMatrix.size.y / workgroupSize[1]);
// #endregion

program.execute({ workgroups: [workgroupCountX, workgroupCountY] });
const multiplicationResult = await resultMatrixBuffer.read();
// #region Example controls

printMatrixToHtml(firstTable, firstMatrix);
printMatrixToHtml(secondTable, secondMatrix);
printMatrixToHtml(resultTable, multiplicationResult);
}

run();

root.onPlumChange(firstRowCountPlum, run);
root.onPlumChange(firstColumnCountPlum, run);
root.onPlumChange(secondColumnCountPlum, run);
root.onPlumChange(forceShufflePlum, run);
const paramSettings = {
min: 1,
max: MAX_MATRIX_SIZE,
step: 1,
};

export const controls = {
Reshuffle: {
onButtonClick: () => {
root.setPlum(forceShufflePlum, (prev) => 1 - prev);
run();
},
},

'#1 rows': {
initial: firstRowCount,
...paramSettings,
onSliderChange: (value: number) => {
firstRowCount = value;
run();
},
},

'#1 columns': {
initial: firstColumnCount,
...paramSettings,
onSliderChange: (value: number) => {
firstColumnCount = value;
run();
},
},

'#2 columns': {
initial: secondColumnCount,
...paramSettings,
onSliderChange: (value: number) => {
secondColumnCount = value;
run();
},
},
};

// #endregion

export function onCleanup() {
root.destroy();
root.device.destroy();
}
Original file line number Diff line number Diff line change
Expand Up @@ -253,7 +253,13 @@ function resetDrawing() {
}
}

let disposed = false;

function run() {
if (disposed) {
return;
}

const scale = canvas.width / SIZE;

context.clearRect(0, 0, canvas.width, canvas.height);
Expand Down Expand Up @@ -406,3 +412,13 @@ export const controls = {
};

// #endregion

// #region Resource cleanup

export function onCleanup() {
disposed = true;
root.destroy();
root.device.destroy();
}

// #endregion
Original file line number Diff line number Diff line change
Expand Up @@ -289,3 +289,8 @@ render();

root.onPlumChange(filterSize, () => render());
root.onPlumChange(iterations, () => render());

export function onCleanup() {
root.destroy();
root.device.destroy();
}
Original file line number Diff line number Diff line change
@@ -1,7 +1,3 @@
// -- Hooks into the example environment
import { onCleanup, onFrame } from '@typegpu/example-toolkit';
// --

import { f32, vec2f } from 'typegpu/data';
import tgpu, { asUniform, builtin, wgsl } from 'typegpu/experimental';

Expand Down Expand Up @@ -103,35 +99,46 @@ export const controls = {
},
};

onFrame(() => {
if (!(video.currentTime > 0)) {
let disposed = false;

function run() {
if (disposed) {
return;
}

renderProgram.execute({
colorAttachments: [
{
view: context.getCurrentTexture().createView(),
clearValue: [0, 0, 0, 1],
loadOp: 'clear',
storeOp: 'store',
},
],
if (video.currentTime > 0) {
renderProgram.execute({
colorAttachments: [
{
view: context.getCurrentTexture().createView(),
clearValue: [0, 0, 0, 1],
loadOp: 'clear',
storeOp: 'store',
},
],

vertexCount: 6,
});

root.flush();
}

vertexCount: 6,
});
requestAnimationFrame(run);
}

root.flush();
});
run();

export function onCleanup() {
disposed = true;

onCleanup(() => {
if (video.srcObject) {
for (const track of (video.srcObject as MediaStream).getTracks()) {
track.stop();
}
}

root.destroy();
});
root.device.destroy();
}

// #endregion
Loading

0 comments on commit 5e6485b

Please sign in to comment.