Skip to content

Commit

Permalink
Better MNIST network and image processing (#541)
Browse files Browse the repository at this point in the history
  • Loading branch information
reczkok authored Nov 12, 2024
1 parent 8c8b935 commit 819aba9
Show file tree
Hide file tree
Showing 18 changed files with 83 additions and 46 deletions.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@
font-family: 'Monaco', monospace;
font-size: 1rem;
background: linear-gradient(to right, transparent, #e6e6f2);
border-radius: 9999px;
@media (max-width: 1024px) {
height: 1rem;
font-size: 0.75rem;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ function getLayerData(layer: ArrayBuffer): LayerData {
function createNetwork(layers: [LayerData, LayerData][]): Network {
const buffers = layers.map(([weights, biases]) => {
if (weights.shape[1] !== biases.shape[0]) {
throw new Error('Shape mismatch');
throw new Error(`Shape mismatch: ${weights.shape} and ${biases.shape}`);
}

return {
Expand Down Expand Up @@ -235,33 +235,20 @@ function createNetwork(layers: [LayerData, LayerData][]): Network {

// Data fetching and network creation

const [
layer0Biases,
layer0Weights,
layer1Biases,
layer1Weights,
layer2Biases,
layer2Weights,
] = await Promise.all(
[
'layer0.bias.npy',
'layer0.weight.npy',
'layer1.bias.npy',
'layer1.weight.npy',
'layer2.bias.npy',
'layer2.weight.npy',
].map((fileName) =>
fetch(`/TypeGPU/mnistWeights/${fileName}`).then((res) =>
res.arrayBuffer().then((buffer) => getLayerData(buffer)),
),
const layerData: [LayerData, LayerData][] = await Promise.all(
[0, 1, 2, 3, 4, 5, 6, 7].map(
(layer) =>
Promise.all(
[`layer${layer}.weight.npy`, `layer${layer}.bias.npy`].map((fileName) =>
fetch(`/TypeGPU/mnistWeightsExperimental/${fileName}`).then((res) =>
res.arrayBuffer().then((buffer) => getLayerData(buffer)),
),
),
) as Promise<[LayerData, LayerData]>,
),
);

const network = createNetwork([
[layer0Weights, layer0Biases],
[layer1Weights, layer1Biases],
[layer2Weights, layer2Biases],
]);
const network = createNetwork(layerData);

// Canvas drawing

Expand Down Expand Up @@ -309,47 +296,95 @@ const observer = new ResizeObserver(() => {
observer.observe(canvas.parentNode?.parentNode as HTMLElement);

let isDrawing = false;
let lastPos: { x: number; y: number } | null = null;

canvas.addEventListener('mousedown', () => {
isDrawing = true;
});

window.addEventListener('mouseup', () => {
isDrawing = false;
lastPos = null;
});

let lastPos = { x: 0, y: 0 };
function centerImage(data: number[]) {
const mass = data.reduce((acc, value) => acc + value, 0);
const x =
data.reduce((acc, value, i) => acc + value * (i % SIZE), 0) /
mass;
const y =
data.reduce((acc, value, i) => acc + value * Math.floor(i / SIZE), 0) /
mass;

const offsetX = Math.round(SIZE / 2 - x);
const offsetY = Math.round(SIZE / 2 - y);

const newData = new Array(SIZE * SIZE).fill(0);
for (let i = 0; i < SIZE; i++) {
for (let j = 0; j < SIZE; j++) {
const index = i * SIZE + j;
const newIndex = (i + offsetY) * SIZE + j + offsetX;
if (newIndex >= 0 && newIndex < SIZE * SIZE) {
newData[newIndex] = data[index];
}
}
}

return newData;
}

const handleDrawing = (x: number, y: number) => {
if (!lastPos) {
lastPos = { x, y };
}

if (x === lastPos.x && y === lastPos.y) {
return;
}
lastPos = { x, y };

for (let i = -1; i <= 1; i++) {
for (let j = -1; j <= 1; j++) {
const newX = x + i;
const newY = y + j;
if (newX >= 0 && newX < SIZE && newY >= 0 && newY < SIZE) {
const distance = Math.abs(i) + Math.abs(j);
const add = distance === 0 ? 128 : distance === 1 ? 64 : 32;
const value = canvasData[newY * SIZE + newX];
canvasData[newY * SIZE + newX] = Math.min(value + add, 255);
const interpolate = (start: number, end: number, steps: number) => {
const step = (end - start) / steps;
return Array.from({ length: steps + 1 }, (_, i) => start + step * i);
};

const steps = Math.max(Math.abs(x - lastPos.x), Math.abs(y - lastPos.y));
const xPoints = interpolate(lastPos.x, x, steps);
const yPoints = interpolate(lastPos.y, y, steps);

for (let k = 0; k < xPoints.length; k++) {
const newX = Math.round(xPoints[k]);
const newY = Math.round(yPoints[k]);

for (let i = -1; i <= 1; i++) {
for (let j = -1; j <= 1; j++) {
const adjX = newX + i;
const adjY = newY + j;
if (adjX >= 0 && adjX < SIZE && adjY >= 0 && adjY < SIZE) {
const distance = Math.abs(i) + Math.abs(j);
const add = distance === 0 ? 128 : distance === 1 ? 32 : 16;
const value = canvasData[adjY * SIZE + adjX];
canvasData[adjY * SIZE + adjX] = Math.min(value + add, 255);
}
}
}
}
draw();

network.inference(canvasData.map((x) => x / 255)).then((data) => {
const max = Math.max(...data);
const index = data.indexOf(max);
const sum = data.reduce((a, b) => a + b, 0);
const normalized = data.map((x) => x / sum);
lastPos = { x, y };
draw();

bars.forEach((bar, i) => {
bar.style.setProperty('--bar-width', `${normalized[i] * 100}%`);
bar.style.setProperty('--highlight-opacity', i === index ? '1' : '0');
network
.inference(centerImage(canvasData).map((x) => (x / 255) * 3.24 - 0.42)) // scale the values from 0-255 to -0.42-2.82
.then((data) => {
const max = Math.max(...data);
const index = data.indexOf(max);
const sum = data.reduce((a, b) => a + b, 0);
const normalized = data.map((x) => x / sum);

bars.forEach((bar, i) => {
bar.style.setProperty('--bar-width', `${normalized[i] * 100}%`);
bar.style.setProperty('--highlight-opacity', i === index ? '1' : '0');
});
});
});
};

canvas.addEventListener('mousemove', (event) => {
Expand Down Expand Up @@ -380,5 +415,6 @@ resetAll();

/** @button "Reset" */
export function reset() {
lastPos = null;
resetAll();
}

0 comments on commit 819aba9

Please sign in to comment.