forked from modern-fortran/neural-fortran
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path mnist.f90
62 lines (47 loc) · 1.49 KB
/
mnist.f90
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
!> Train a small fully-connected network on MNIST with SGD and report
!> validation accuracy before training and after every epoch.
!>
!> All progress output is written from image 1 only, so that a
!> multi-image (coarray) run does not repeat each line once per image.
program mnist

  use nf, only: dense, input, network, sgd, label_digits, load_mnist

  implicit none

  ! Network shape and training hyperparameters.
  ! num_inputs must match the feature dimension (first extent) of the image
  ! arrays returned by load_mnist -- presumably 28x28 flattened pixels; confirm.
  integer, parameter :: num_inputs = 784
  integer, parameter :: num_hidden = 30
  ! num_outputs must match the number of rows produced by label_digits.
  integer, parameter :: num_outputs = 10
  integer, parameter :: num_epochs = 10
  integer, parameter :: batch_size = 100
  real, parameter :: learning_rate = 3.

  type(network) :: net

  real, allocatable :: training_images(:,:), training_labels(:)
  real, allocatable :: validation_images(:,:), validation_labels(:)

  ! Targets expanded by label_digits. They are computed once up front
  ! because the labels do not change between epochs.
  real, allocatable :: training_targets(:,:), validation_targets(:,:)

  integer :: n

  call load_mnist(training_images, training_labels, &
                  validation_images, validation_labels)

  training_targets = label_digits(training_labels)
  validation_targets = label_digits(validation_labels)

  if (this_image() == 1) then
    print '("MNIST")'
    print '(60("="))'
  end if

  net = network([ &
    input(num_inputs), &
    dense(num_hidden), &
    dense(num_outputs) &
  ])

  ! NOTE(review): not guarded by this_image(); whether print_info already
  ! restricts itself to one image is not visible here -- confirm.
  call net % print_info()

  ! f6.2 rather than f5.2: a value of 100.00 needs six columns and would
  ! otherwise be printed as asterisks.
  if (this_image() == 1) &
    print '(a,f6.2,a)', 'Initial accuracy: ', &
      accuracy(net, validation_images, validation_targets) * 100, ' %'

  epochs: do n = 1, num_epochs

    call net % train( &
      training_images, &
      training_targets, &
      batch_size=batch_size, &
      epochs=1, &
      optimizer=sgd(learning_rate=learning_rate) &
    )

    ! i0 keeps the epoch number readable even when num_epochs exceeds 99.
    if (this_image() == 1) &
      print '(a,i0,a,f6.2,a)', 'Epoch ', n, ' done, Accuracy: ', &
        accuracy(net, validation_images, validation_targets) * 100, ' %'

  end do epochs

contains

  !> Fraction of samples, in [0, 1], for which the index of the largest
  !> network output equals the index of the largest target value.
  !>
  !> net -- network to evaluate. It is intent(in out) because
  !>        net % predict is called on it; presumably predict updates
  !>        internal layer state -- confirm.
  !> x   -- inputs, one sample per column.
  !> y   -- targets, one sample per column, aligned with the columns of x.
  !>
  !> Returns 0 when x has no columns, instead of evaluating 0/0.
  real function accuracy(net, x, y)
    type(network), intent(in out) :: net
    real, intent(in) :: x(:,:), y(:,:)
    integer :: i, good, num_samples

    num_samples = size(x, dim=2)
    if (num_samples == 0) then
      accuracy = 0.0
      return
    end if

    good = 0
    do i = 1, num_samples
      ! maxloc(..., dim=1) yields a scalar index (first maximum on ties),
      ! so no all() reduction over a size-1 array is needed.
      if (maxloc(net % predict(x(:,i)), dim=1) == maxloc(y(:,i), dim=1)) then
        good = good + 1
      end if
    end do

    accuracy = real(good) / real(num_samples)
  end function accuracy

end program mnist