Handwritten digit classifier written in Java, using the GPU for accelerated training and inference. Trained on the MNIST handwritten digit dataset (included in .csv format, split into training and test data).
- The JCublas and JCuda libraries serve as the interface to the native cuBLAS and CUDA libraries (version 12.0 or later must be installed beforehand).
Running on the GPU gives a significant speedup over CPU-based training.
The main code for training and inference can be found in nn/gpu/NN_GPU.java.
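For readers who want to inspect the bundled .csv data, here is a minimal sketch of parsing one row. It assumes the common MNIST CSV layout (`label,p0,...,p783`, pixel values 0-255, no header row); the actual files and the loader in this repository may differ. The class `MnistCsvRow` and its methods are hypothetical names used only for illustration.

```java
// Hypothetical helper, not part of this repository. Assumes each row is
// "label,p0,p1,...,p783" with integer pixel values in 0..255 and no header.
final class MnistCsvRow {
    static final int PIXELS = 784; // 28 x 28 image, flattened
    static final int CLASSES = 10; // digits 0-9

    final int label;      // the digit shown in the image
    final float[] pixels; // intensities scaled to [0, 1]

    MnistCsvRow(int label, float[] pixels) {
        this.label = label;
        this.pixels = pixels;
    }

    // Parses a single CSV line into a label and a normalized pixel vector.
    static MnistCsvRow parse(String line) {
        String[] fields = line.trim().split(",");
        if (fields.length != PIXELS + 1) {
            throw new IllegalArgumentException(
                "expected " + (PIXELS + 1) + " fields, got " + fields.length);
        }
        int label = Integer.parseInt(fields[0].trim());
        float[] pixels = new float[PIXELS];
        for (int i = 0; i < PIXELS; i++) {
            pixels[i] = Integer.parseInt(fields[i + 1].trim()) / 255.0f;
        }
        return new MnistCsvRow(label, pixels);
    }

    // One-hot target vector matching a 10-neuron output layer.
    float[] oneHotLabel() {
        float[] target = new float[CLASSES];
        target[label] = 1.0f;
        return target;
    }

    public static void main(String[] args) {
        // Build a synthetic row: label 7, first pixel fully on, rest off.
        StringBuilder sb = new StringBuilder("7");
        for (int i = 0; i < PIXELS; i++) {
            sb.append(',').append(i == 0 ? 255 : 0);
        }
        MnistCsvRow row = parse(sb.toString());
        System.out.println("label=" + row.label + " firstPixel=" + row.pixels[0]);
    }
}
```

Scaling pixels to [0, 1] is a design choice made here because sigmoid units saturate on large inputs; whether NN_GPU.java normalizes the same way is not stated above.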
Layer Types:
- Fully-connected
Training Algorithms:
- Stochastic Gradient Descent (SGD)
Cost Functions:
- Mean-squared error (MSE)
Activation Functions:
- Sigmoid
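To show how the four pieces listed above fit together, here is a CPU-only reference sketch of one fully-connected hidden layer with sigmoid activations, an MSE cost, and an SGD update. It is not the code in nn/gpu/NN_GPU.java: the class name `TinySigmoidNet`, the weight initialization, the per-neuron biases, and the use of a batch of one sample are all assumptions made to keep the example short.

```java
import java.util.Random;

// Hypothetical CPU reference, not the repository's GPU implementation.
final class TinySigmoidNet {
    final int in, hidden, out;
    final double[][] w1, w2; // w1[hidden][in], w2[out][hidden]
    final double[] b1, b2;   // one bias per neuron (assumed)

    TinySigmoidNet(int in, int hidden, int out, long seed) {
        this.in = in; this.hidden = hidden; this.out = out;
        Random rng = new Random(seed);
        w1 = randomMatrix(hidden, in, rng);
        w2 = randomMatrix(out, hidden, rng);
        b1 = new double[hidden];
        b2 = new double[out];
    }

    // Gaussian initialization scaled by 1/sqrt(fan-in) (an assumed choice).
    static double[][] randomMatrix(int rows, int cols, Random rng) {
        double[][] m = new double[rows][cols];
        double scale = 1.0 / Math.sqrt(cols);
        for (double[] row : m)
            for (int j = 0; j < cols; j++) row[j] = rng.nextGaussian() * scale;
        return m;
    }

    static double sigmoid(double z) { return 1.0 / (1.0 + Math.exp(-z)); }

    // Fully-connected layer: a = sigmoid(W x + b).
    static double[] layer(double[][] w, double[] b, double[] x) {
        double[] a = new double[b.length];
        for (int i = 0; i < b.length; i++) {
            double z = b[i];
            for (int j = 0; j < x.length; j++) z += w[i][j] * x[j];
            a[i] = sigmoid(z);
        }
        return a;
    }

    double[] predict(double[] x) { return layer(w2, b2, layer(w1, b1, x)); }

    // Mean-squared error over the output neurons of one sample.
    static double mse(double[] a, double[] y) {
        double sum = 0;
        for (int i = 0; i < a.length; i++) sum += (a[i] - y[i]) * (a[i] - y[i]);
        return sum / a.length;
    }

    // One SGD step on a single sample; returns the cost before the update.
    double trainStep(double[] x, double[] y, double lr) {
        double[] h = layer(w1, b1, x);
        double[] a = layer(w2, b2, h);
        // Output delta: dC/da * sigmoid'(z), where sigmoid'(z) = a * (1 - a).
        double[] d2 = new double[out];
        for (int i = 0; i < out; i++)
            d2[i] = 2.0 * (a[i] - y[i]) / out * a[i] * (1 - a[i]);
        // Hidden delta: backpropagate through w2 before it is modified.
        double[] d1 = new double[hidden];
        for (int j = 0; j < hidden; j++) {
            double s = 0;
            for (int i = 0; i < out; i++) s += w2[i][j] * d2[i];
            d1[j] = s * h[j] * (1 - h[j]);
        }
        // Gradient descent update of weights and biases.
        for (int i = 0; i < out; i++) {
            for (int j = 0; j < hidden; j++) w2[i][j] -= lr * d2[i] * h[j];
            b2[i] -= lr * d2[i];
        }
        for (int j = 0; j < hidden; j++) {
            for (int k = 0; k < in; k++) w1[j][k] -= lr * d1[j] * x[k];
            b1[j] -= lr * d1[j];
        }
        return mse(a, y);
    }

    public static void main(String[] args) {
        TinySigmoidNet net = new TinySigmoidNet(4, 3, 2, 42L);
        double[] x = {1, 0, 0.5, 0.25};
        double[] y = {1, 0};
        double first = net.trainStep(x, y, 0.5);
        double last = first;
        for (int i = 0; i < 200; i++) last = net.trainStep(x, y, 0.5);
        System.out.println("cost before: " + first + ", after: " + last);
    }
}
```

A GPU version would typically replace the nested loops in `layer` and `trainStep` with matrix products over a whole mini-batch, which is the kind of work cuBLAS is designed for.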
Parameter saving: Networks can be saved to and loaded from custom local file formats (examples are provided under "saved_networks").
Test accuracy: 95.96%
Network specifications: 784x32x10 (i.e. a single hidden layer with 32 neurons), saved as saved_networks/tiny.txt
Hyper-parameters: batch size = 32, learning rate = 0.1
Training took approximately 1 minute on my machine (with an RTX 4090 GPU). GPU utilization peaked at only 20% on such a small network, whereas larger networks such as 784x3000x10 (saved_networks/wide.txt) take about the same amount of time to train but utilize nearly 100% of the GPU.
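To put the two network sizes in perspective, the sketch below counts trainable parameters for each shape. It assumes one bias per non-input neuron, which the saved network files may or may not include; `ParamCount` is a hypothetical helper, not part of the repository.

```java
// Hypothetical helper: counts weights plus one bias per non-input neuron
// (the biases are an assumption about the network layout).
final class ParamCount {
    static long count(int... sizes) {
        long total = 0;
        for (int i = 1; i < sizes.length; i++) {
            // weights between layer i-1 and layer i, plus biases of layer i
            total += (long) sizes[i - 1] * sizes[i] + sizes[i];
        }
        return total;
    }

    public static void main(String[] args) {
        System.out.println(count(784, 32, 10));   // prints 25450
        System.out.println(count(784, 3000, 10)); // prints 2385010
    }
}
```

Under that assumption the wide network has roughly 94 times as many parameters as the tiny one, which may be why the tiny network leaves most of the GPU idle while the wide one keeps it busy for a similar wall-clock time.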