NeuralNetwork4J_CUDA

Handwritten digit classifier written in Java that uses the GPU for accelerated training and inference. It is trained on the MNIST handwritten digit dataset, which is included in .csv format and split into training and test data.
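The README does not describe the CSV layout. The sketch below assumes the common MNIST CSV export, where each row holds the digit label followed by the 784 grayscale pixel values (0-255) of a 28x28 image, with no header row. `MnistCsv` and its methods are hypothetical names and are not part of this repository. The sketch turns one row into a normalized 784-element input vector and a 10-element one-hot target, which matches the 784 inputs and 10 outputs of the networks described below.

```java
/** Hypothetical helper: parses one MNIST CSV row of the assumed form "label,p0,...,p783". */
final class MnistCsv {
    static final int PIXELS = 28 * 28; // 784 input neurons
    static final int CLASSES = 10;     // digits 0-9

    /** Scales the 784 raw pixel values (assumed 0-255) into [0, 1]. */
    static float[] pixels(String row) {
        String[] fields = row.split(",");
        if (fields.length != PIXELS + 1) {
            throw new IllegalArgumentException(
                "expected " + (PIXELS + 1) + " fields, got " + fields.length);
        }
        float[] x = new float[PIXELS];
        for (int i = 0; i < PIXELS; i++) {
            // Field 0 is the label, so pixel i lives in field i + 1.
            x[i] = Integer.parseInt(fields[i + 1].trim()) / 255f;
        }
        return x;
    }

    /** Builds a one-hot target vector from the label in the first column. */
    static float[] oneHotLabel(String row) {
        int label = Integer.parseInt(row.substring(0, row.indexOf(',')).trim());
        float[] y = new float[CLASSES];
        y[label] = 1f;
        return y;
    }
}
```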

  • The JCublas and JCuda libraries serve as the interface to the native cuBLAS and CUDA libraries (version 12.0 or later must be installed beforehand).

Using the GPU yields significant speedups over CPU-based training.

The main training and inference code is in nn/gpu/NN_GPU.java.

Currently available

Layer Types:

- Fully-connected
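A fully-connected layer computes z = Wx + b. On the GPU this product is typically handed to a cuBLAS multiply such as `cublasSgemv` or `cublasSgemm`, and cuBLAS expects column-major storage. The CPU reference below therefore stores W column-major as well, so it could be used to check GPU results. The class name `DenseLayerReference` and the memory layout are assumptions for illustration and are not taken from `NN_GPU.java`.

```java
/** CPU reference for one fully-connected layer: z = W x + b, with W stored column-major. */
final class DenseLayerReference {
    /**
     * @param w    weights (rows x cols), column-major: element (r, c) is at w[c * rows + r]
     * @param b    bias vector of length rows
     * @param x    input vector of length cols
     * @param rows number of neurons in this layer
     * @param cols number of neurons in the previous layer
     */
    static float[] forward(float[] w, float[] b, float[] x, int rows, int cols) {
        if (w.length != rows * cols || b.length != rows || x.length != cols) {
            throw new IllegalArgumentException("dimension mismatch");
        }
        float[] z = b.clone(); // start from the bias, then accumulate W x
        for (int c = 0; c < cols; c++) {
            float xc = x[c];
            // Walking down one column at a time follows the column-major memory order.
            for (int r = 0; r < rows; r++) {
                z[r] += w[c * rows + r] * xc;
            }
        }
        return z;
    }
}
```

For the 784x32x10 network below, the hidden layer would use `rows = 32, cols = 784` and the output layer `rows = 10, cols = 32`; the activation function is applied to `z` afterwards.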

Training Algorithms:

- Stochastic Gradient Descent (SGD)
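One SGD step moves every parameter against its gradient. The sketch below assumes that the per-example gradients of a mini-batch have already been summed and are averaged by dividing by the batch size. The README does not say whether `NN_GPU.java` averages or sums them. `SgdStep.apply` is a hypothetical CPU version of the update; on the GPU the same operation maps onto a single cuBLAS axpy (y = αx + y) with α = -learningRate / batchSize.

```java
/** Plain mini-batch SGD step: params -= (learningRate / batchSize) * summedGrad. */
final class SgdStep {
    static void apply(float[] params, float[] summedGrad, float learningRate, int batchSize) {
        if (params.length != summedGrad.length || batchSize <= 0) {
            throw new IllegalArgumentException("length mismatch or non-positive batch size");
        }
        // Averaging over the batch folds into a single scale factor.
        float scale = learningRate / batchSize;
        for (int i = 0; i < params.length; i++) {
            params[i] -= scale * summedGrad[i]; // updated in place
        }
    }
}
```

Under this averaging assumption, the example hyperparameters (batch size 32, learning rate 0.1) move each parameter by 0.1 / 32 = 0.003125 times its summed gradient per step.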

Cost Functions:

- Mean-squared error (MSE)
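The README does not give the exact MSE normalization. The sketch below assumes the widely used quadratic-cost convention C = 0.5 · Σ_j (a_j - y_j)² for one example, where a is the network output and y the one-hot target. With the 0.5 factor the gradient is simply dC/da_j = a_j - y_j; a true mean over outputs differs only by a constant factor that the learning rate absorbs. `QuadraticCost` is a hypothetical name.

```java
/** Quadratic (MSE-style) cost for one example, using C = 0.5 * sum_j (a_j - y_j)^2. */
final class QuadraticCost {
    private static void checkLengths(float[] a, float[] y) {
        if (a.length != y.length) {
            throw new IllegalArgumentException("output and target lengths differ");
        }
    }

    static float cost(float[] a, float[] y) {
        checkLengths(a, y);
        float sum = 0f;
        for (int j = 0; j < a.length; j++) {
            float d = a[j] - y[j];
            sum += d * d;
        }
        return 0.5f * sum;
    }

    /** dC/da_j = a_j - y_j under the 0.5 convention. */
    static float[] gradient(float[] a, float[] y) {
        checkLengths(a, y);
        float[] g = new float[a.length];
        for (int j = 0; j < a.length; j++) {
            g[j] = a[j] - y[j];
        }
        return g;
    }
}
```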

Activation Functions:

- Sigmoid
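The sigmoid σ(z) = 1 / (1 + e^(-z)) has the derivative σ'(z) = σ(z)(1 - σ(z)), so backpropagation can reuse the activations already computed in the forward pass. Assuming the quadratic cost C = 0.5 · Σ_j (a_j - y_j)², the output-layer error is δ_j = (a_j - y_j) · a_j · (1 - a_j). The sketch below shows these formulas on the CPU; `SigmoidActivation` is a hypothetical name, not the repository's class.

```java
/** Logistic sigmoid and the output-layer error for sigmoid combined with quadratic cost. */
final class SigmoidActivation {
    static float sigmoid(float z) {
        // Computed in double; for very negative z, exp overflows to Infinity and the result is 0.
        return (float) (1.0 / (1.0 + Math.exp(-z)));
    }

    /** sigma'(z), expressed through the already-computed activation a = sigma(z). */
    static float derivativeFromActivation(float a) {
        return a * (1f - a);
    }

    /** delta_j = (a_j - y_j) * a_j * (1 - a_j), i.e. dC/dz_j for C = 0.5 * sum (a - y)^2. */
    static float[] outputDelta(float[] a, float[] y) {
        float[] delta = new float[a.length];
        for (int j = 0; j < a.length; j++) {
            delta[j] = (a[j] - y[j]) * derivativeFromActivation(a[j]);
        }
        return delta;
    }
}
```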

Parameter saving: Network parameters can be saved to and loaded from custom local file formats (examples are provided under "saved_networks").

Example Performance

Network Performance

Test accuracy: 95.96%
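The README does not say how test accuracy is computed. A common approach, assumed here, is to take the index of the largest output activation as the predicted digit and count how often it matches the label. If the test split is the standard 10,000-image MNIST test set, 95.96% corresponds to 9,596 correct predictions. `Accuracy` is a hypothetical helper.

```java
/** Classification accuracy: fraction of examples whose largest output matches the label. */
final class Accuracy {
    /** Index of the largest value; ties resolve to the first such index. */
    static int argmax(float[] v) {
        int best = 0;
        for (int i = 1; i < v.length; i++) {
            if (v[i] > v[best]) best = i;
        }
        return best;
    }

    static double accuracy(float[][] outputs, int[] labels) {
        if (outputs.length == 0 || outputs.length != labels.length) {
            throw new IllegalArgumentException("outputs and labels must be non-empty and the same length");
        }
        int correct = 0;
        for (int n = 0; n < outputs.length; n++) {
            if (argmax(outputs[n]) == labels[n]) correct++;
        }
        return (double) correct / outputs.length;
    }
}
```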

Network specifications: 784x32x10, i.e. a single hidden layer with 32 neurons (saved as saved_networks/tiny.txt)

Hyperparameters: batch size = 32, learning rate = 0.1

Training took approximately 1 minute on my machine (with an RTX 4090 GPU). GPU utilization peaked at only 20% for such a small network, whereas larger networks such as 784x3000x10 (saved_networks/wide.txt) take the same amount of time to train while utilizing nearly 100% of the GPU.
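The gap in GPU utilization is easier to see in terms of model size. The sketch below counts weights and biases for a fully-connected network, assuming that every non-input layer has one bias per neuron; `ParameterCount` is a hypothetical helper. Under that assumption, 784x32x10 has 25,450 parameters and 784x3000x10 has 2,385,010, roughly 94 times as many. The wider network therefore plausibly gives each cuBLAS call enough work to keep the GPU busy.

```java
/** Counts weights and biases in a fully-connected network given its layer sizes. */
final class ParameterCount {
    static long count(int... layerSizes) {
        long total = 0;
        for (int l = 1; l < layerSizes.length; l++) {
            total += (long) layerSizes[l - 1] * layerSizes[l] // weight matrix between layers l-1 and l
                   + layerSizes[l];                           // one bias per neuron in layer l (assumed)
        }
        return total;
    }
}
```

For example, `ParameterCount.count(784, 32, 10)` covers the tiny network and `ParameterCount.count(784, 3000, 10)` the wide one.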