This interactive web application built with TensorFlow.js allows you to train and experiment with neural networks on the MNIST handwritten digit dataset:
- 🎯 Interactive Training - Train a CNN model directly in your browser
- 📊 Real-time Visualization - Monitor training metrics with live charts
- ✍️ Draw & Predict - Draw digits and get instant predictions
- 📈 Comprehensive Metrics - View batch and epoch-level performance
- 🔍 Confusion Matrix - Analyze model performance across all digit classes
- 🖼️ Dataset Preview - Visualize random samples from the MNIST dataset
- ⚡ WebGPU Acceleration - Leverage GPU for faster training
- 🧠 WebGL Backend - Fallback option for wider browser compatibility
- 📱 Responsive Design - Works seamlessly on desktop and mobile devices
The application provides comprehensive training capabilities:
| Feature | Description | Use Case |
|---|---|---|
| 🔧 Configurable Parameters | Adjust training data size, batch size, epochs | 🎛️ Experiment with different training setups |
| 📊 Live Metrics | Real-time loss and accuracy tracking | 📈 Monitor training progress |
| 🎨 Interactive Canvas | Draw digits for instant prediction | ✍️ Test model performance |
| 📉 Performance Charts | Batch and epoch-level visualizations | 📊 Analyze training dynamics |
| 🔄 Auto Prediction | Automatic inference after drawing | ⚡ Seamless user experience |
The CNN model uses modern deep learning practices with Batch Normalization for improved training stability:
```
Input (28×28×1)
        ↓
[Conv2D(32, 3×3) → BatchNorm → ReLU → MaxPool(2×2)]
        ↓
[Conv2D(64, 3×3) → BatchNorm → ReLU → MaxPool(2×2)]
        ↓
Flatten → Dropout(0.5)
        ↓
[Dense(128) → BatchNorm → ReLU → Dropout(0.5)]
        ↓
Dense(10) → Softmax
```
| Layer Type | Configuration | Output Shape | Parameters |
|---|---|---|---|
| Input | 28×28 grayscale | (28, 28, 1) | 0 |
| Conv2D | 32 filters, 3×3 kernel, HeNormal init | (26, 26, 32) | 320 |
| BatchNorm | - | (26, 26, 32) | 128 |
| ReLU | - | (26, 26, 32) | 0 |
| MaxPool2D | 2×2 pool, stride 2 | (13, 13, 32) | 0 |
| Conv2D | 64 filters, 3×3 kernel, HeNormal init | (11, 11, 64) | 18,496 |
| BatchNorm | - | (11, 11, 64) | 256 |
| ReLU | - | (11, 11, 64) | 0 |
| MaxPool2D | 2×2 pool, stride 2 | (5, 5, 64) | 0 |
| Flatten | - | (1600) | 0 |
| Dropout | rate=0.5 | (1600) | 0 |
| Dense | 128 units, HeNormal init | (128) | 204,928 |
| BatchNorm | - | (128) | 512 |
| ReLU | - | (128) | 0 |
| Dropout | rate=0.5 | (128) | 0 |
| Dense | 10 units (output) | (10) | 1,290 |
| Softmax | - | (10) | 0 |
**Total Parameters:** 225,930
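For reference, here is a minimal sketch of how this stack could be assembled with the TensorFlow.js layers API (the function name and exact layer ordering are illustrative; the project's actual implementation lives in `src/utils/model.js`):

```js
import * as tf from "@tensorflow/tfjs";

// Illustrative sketch of the Conv → BatchNorm → ReLU → MaxPool stack above.
function buildSketchModel() {
  const model = tf.sequential();
  // Block 1: 28×28×1 → 13×13×32
  model.add(tf.layers.conv2d({
    inputShape: [28, 28, 1], filters: 32, kernelSize: 3,
    kernelInitializer: "heNormal",
  }));
  model.add(tf.layers.batchNormalization());
  model.add(tf.layers.activation({ activation: "relu" }));
  model.add(tf.layers.maxPooling2d({ poolSize: 2, strides: 2 }));
  // Block 2: 13×13×32 → 5×5×64
  model.add(tf.layers.conv2d({
    filters: 64, kernelSize: 3, kernelInitializer: "heNormal",
  }));
  model.add(tf.layers.batchNormalization());
  model.add(tf.layers.activation({ activation: "relu" }));
  model.add(tf.layers.maxPooling2d({ poolSize: 2, strides: 2 }));
  // Classifier head: 1600 → 128 → 10
  model.add(tf.layers.flatten());
  model.add(tf.layers.dropout({ rate: 0.5 }));
  model.add(tf.layers.dense({ units: 128, kernelInitializer: "heNormal" }));
  model.add(tf.layers.batchNormalization());
  model.add(tf.layers.activation({ activation: "relu" }));
  model.add(tf.layers.dropout({ rate: 0.5 }));
  model.add(tf.layers.dense({ units: 10, activation: "softmax" }));
  return model;
}
```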
- 🎯 Batch Normalization: Applied after convolutions and dense layers for faster convergence
- 🔧 He Normal Initialization: Optimal weight initialization for ReLU activations
- 🛡️ Dropout Regularization: 50% dropout rate to prevent overfitting
- ⚡ Adam Optimizer: Adaptive learning rate optimization
- 📊 Categorical Crossentropy: Standard loss for multi-class classification
```js
{
  optimizer: 'adam',
  learningRate: 0.001, // Configurable in UI
  loss: 'categoricalCrossentropy',
  metrics: ['accuracy']
}
```
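In the TF.js layers API the learning rate is passed to the optimizer instance rather than as a separate compile option; a hedged sketch of how this configuration could translate into an actual `model.compile` call:

```js
import * as tf from "@tensorflow/tfjs";

// lr is the UI-configurable learning rate (default 0.001).
function compileModel(model, lr = 0.001) {
  model.compile({
    optimizer: tf.train.adam(lr),    // Adam with adaptive learning rate
    loss: "categoricalCrossentropy", // standard multi-class loss
    metrics: ["accuracy"],
  });
  return model;
}
```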
- **Batch Normalization**
  - Stabilizes training by normalizing layer inputs
  - Allows higher learning rates
  - Acts as regularization
- **He Normal Initialization**
  - Specifically designed for ReLU activations
  - Prevents vanishing/exploding gradients
- **Progressive Feature Extraction**
  - 32 filters → 64 filters: gradually increases feature complexity
  - MaxPooling: reduces spatial dimensions while preserving features
- **Dropout for Robustness**
  - Applied after the flatten and dense layers
  - Reduces overfitting on the training data
With default settings (5,500 training samples, 10 epochs):
- Training Accuracy: ~98-99%
- Validation Accuracy: ~97-98%
- Training Time: ~2-5 minutes (depending on hardware)
- Clone this repository

  ```bash
  git clone https://github.com/yourusername/mnist-playground-tfjs.git
  ```

- Navigate to the project directory

  ```bash
  cd mnist-playground-tfjs
  ```

- Install dependencies

  ```bash
  npm install
  ```

Start the development server:

```bash
npm run dev
```

Build the project:

```bash
npm run build
```

Preview the production build:

```bash
npm run preview
```

| Parameter | Range | Default | Description |
|---|---|---|---|
| Train Data | 1,000 - 60,000 | 5,500 | Number of training samples |
| Test Data | 1,000 - 10,000 | 1,000 | Number of validation samples |
| Batch Size | 1 - 512 | 128 | Number of samples per training batch |
| Epochs | 1 - 200 | 10 | Number of complete training iterations |
| Learning Rate | 0.0001 - 1 | 0.001 | Optimizer learning rate |
| Backend | WebGPU/WebGL | WebGPU | Computational backend for training |
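As a rough sketch, these parameters feed into a single `model.fit` call (the `trainXs`/`trainYs`/`testXs`/`testYs` tensors here are hypothetical outputs of the data loader):

```js
// Hypothetical wiring of the UI parameters into training.
async function train(model, trainXs, trainYs, testXs, testYs) {
  return model.fit(trainXs, trainYs, {
    batchSize: 128,                   // "Batch Size" setting
    epochs: 10,                       // "Epochs" setting
    validationData: [testXs, testYs], // the held-out "Test Data" samples
    shuffle: true,
  });
}
```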
- **Batch-Level Metrics** (surfaced via `fit` callbacks; see the sketch after this list)
  - Loss and accuracy per batch
  - Average batch processing time
  - Progress tracking
- **Epoch-Level Metrics**
  - Training and validation loss
  - Training and validation accuracy
  - Average epoch processing time
- **Confusion Matrix**
  - Overall accuracy
  - Per-class precision, recall, F1-score
  - Visual heatmap of predictions
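These numbers come from the callbacks that `model.fit` invokes during training; a minimal sketch of hooking into them (the `console.log` calls stand in for the app's chart updates):

```js
// Sketch: capture batch- and epoch-level metrics during model.fit.
const callbacks = {
  onBatchEnd: async (batch, logs) => {
    // logs.loss / logs.acc are the per-batch loss and accuracy
    console.log(`batch ${batch}: loss=${logs.loss.toFixed(4)}, acc=${logs.acc.toFixed(4)}`);
  },
  onEpochEnd: async (epoch, logs) => {
    // logs.val_loss / logs.val_acc come from the validation data
    console.log(`epoch ${epoch}: val_loss=${logs.val_loss.toFixed(4)}, val_acc=${logs.val_acc.toFixed(4)}`);
  },
};
// Passed as: model.fit(xs, ys, { ..., callbacks })
```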
- **Draw Digits**
  - Use mouse or touch to draw on the 280×280 canvas
  - Automatic prediction after 0.5 seconds of inactivity (a debounced call; see the sketch after this list)
  - Manual prediction button available
- **Prediction Display**
  - Predicted digit with confidence score
  - Probability distribution across all 10 digits
  - Color-coded confidence levels:
    - 🟢 Green (>80%): High confidence
    - 🟡 Yellow (50-80%): Medium confidence
    - 🔴 Red (<50%): Low confidence
- **Canvas Controls**
  - Clear Canvas: Reset the drawing area
  - Predict Now: Trigger immediate prediction
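The auto-prediction is a standard debounce; a sketch of the pattern (`predictFromCanvas` is a hypothetical helper that runs `model.predict` on the canvas contents):

```js
// Sketch: trigger prediction 500 ms after the last stroke ends.
let debounceTimer = null;

function onStrokeEnd() {
  clearTimeout(debounceTimer);
  debounceTimer = setTimeout(() => {
    predictFromCanvas(); // hypothetical helper wrapping model.predict
  }, 500);
}
```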
⚠️ Note: The canvas is disabled during training, and prediction requires a trained model.
- **Loss**
  - Measures how far predictions are from the actual values
  - Lower is better
  - Should decrease during training
- **Accuracy**
  - Percentage of correct predictions
  - Higher is better
  - Should increase during training
- **Confusion Matrix**
  - Shows which digits are commonly confused
  - Diagonal elements represent correct predictions
  - Off-diagonal elements show misclassifications
- **Per-Class Metrics** (see the sketch after this list)
  - Precision: of all predictions of digit X, how many were actually X?
  - Recall: of all actual instances of digit X, how many were correctly identified?
  - F1-Score: harmonic mean of precision and recall
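Given a 10×10 confusion matrix `m`, where `m[i][j]` counts samples of actual digit `i` predicted as digit `j`, these scores follow directly from the row and column sums; a small sketch:

```js
// Sketch: per-class precision, recall, and F1 from a confusion matrix.
function perClassMetrics(m, cls) {
  const tp = m[cls][cls];
  let fp = 0, fn = 0;
  for (let i = 0; i < m.length; i++) {
    if (i === cls) continue;
    fp += m[i][cls]; // predicted cls, actually i
    fn += m[cls][i]; // actually cls, predicted i
  }
  const precision = tp / (tp + fp) || 0; // guard against 0/0
  const recall = tp / (tp + fn) || 0;
  const f1 = (2 * precision * recall) / (precision + recall) || 0;
  return { precision, recall, f1 };
}
```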
- Start Small: Begin with smaller datasets (5,000-10,000 samples) for faster experimentation
- Adjust Batch Size: Larger batches (128-256) for stability, smaller for better generalization
- Monitor Overfitting: Watch for diverging training and validation accuracy
- Experiment: Try different learning rates and epochs to find optimal settings
- Draw Clearly: Make digits large and centered
- Use Bold Strokes: Thicker lines work better
- Single Digit: Draw one digit at a time
- Center Position: Keep the digit in the middle of the canvas
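These tips matter because the 280×280 drawing is downscaled to the model's 28×28×1 input before inference; a sketch of that preprocessing using standard TF.js helpers (the function name is illustrative):

```js
import * as tf from "@tensorflow/tfjs";

// Sketch: convert the 280×280 canvas into a [1, 28, 28, 1] model input.
function canvasToInput(canvas) {
  return tf.tidy(() => {
    const img = tf.browser.fromPixels(canvas, 1);         // grayscale pixels
    const small = tf.image.resizeBilinear(img, [28, 28]); // downscale
    return small.toFloat().div(255).expandDims(0);        // normalize + batch dim
  });
}
```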
You can modify the model architecture in `src/utils/model.js`:

```js
import { sequential, layers } from "@tensorflow/tfjs";

export function createModel(lr = 0.001) {
  const model = sequential();
  // layer_1 - 32 filters, 3x3 kernel
  model.add(
    layers.conv2d({
      inputShape: [28, 28, 1],
      kernelSize: 3,
      filters: 32,
      activation: "linear", // ReLU is applied separately, after BatchNorm
      kernelInitializer: "heNormal",
    })
  );
  // ... more layers
  return model;
}
```

To use a custom dataset:
- Prepare your data in the MNIST format (28×28 grayscale images)
- Update `src/utils/data.js` to load your dataset
- Adjust the number of classes if needed
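The loader's contract is simple: return image and label tensors in the shapes the model expects. A hypothetical skeleton (the name and signature are illustrative, not the project's actual API):

```js
import * as tf from "@tensorflow/tfjs";

// Hypothetical loader skeleton; adapt to your dataset's storage format.
export async function loadCustomData(numTrain) {
  // ...fetch and decode your images and labels here...
  const xs = tf.zeros([numTrain, 28, 28, 1]);   // placeholder images
  const labels = tf.zeros([numTrain], "int32"); // placeholder class ids
  const ys = tf.oneHot(labels, 10);             // 10 classes, one-hot encoded
  return { xs, ys };
}
```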
| Browser | WebGPU | WebGL | Status |
|---|---|---|---|
| Chrome (113+) | ✅ | ✅ | ✅ |
| Edge (113+) | ✅ | ✅ | ✅ |
| Firefox | 🚧 | ✅ | ✅ |
| Safari | 🚧 | ✅ | ✅ |
⚡ WebGPU Support: WebGPU is currently supported in Chrome and Edge. Other browsers will automatically fall back to WebGL.
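The fallback can be implemented with `tf.setBackend`, which resolves to `false` when a backend fails to initialize; a sketch (assumes the `@tensorflow/tfjs-backend-webgpu` package is installed):

```js
import * as tf from "@tensorflow/tfjs";
import "@tensorflow/tfjs-backend-webgpu";

// Sketch: prefer WebGPU, fall back to WebGL.
async function initBackend() {
  if (navigator.gpu && (await tf.setBackend("webgpu"))) {
    await tf.ready();
    return "webgpu";
  }
  await tf.setBackend("webgl");
  await tf.ready();
  return "webgl";
}
```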
Contributions are welcome! Please feel free to submit a Pull Request.
This project is licensed under the MIT License - see the LICENSE file for details.
- TensorFlow.js - Machine learning library
- Chart.js - Data visualization
- Bootstrap - UI framework
- Bootstrap Icons - Icon library
- MNIST Dataloader - Dataset source
For questions or feedback, please open an issue on GitHub.

