! Copyright (c) 2024-2025, The Regents of the University of California and Sourcery Institute
! Terms of use are as specified in LICENSE.txt

module linear_2d_layer_test_m
  use julienne_m, only : &
     test_t, test_description_t, test_diagnosis_t, test_result_t &
    ,operator(.equalsExpected.), operator(//), operator(.approximates.), operator(.within.), operator(.also.), operator(.all.)
  use nf_linear2d_layer, only: linear2d_layer
  implicit none

  type, extends(test_t) :: linear_2d_layer_test_t
  contains
    procedure, nopass :: subject
    procedure, nopass :: results
  end type

contains

  pure function subject() result(test_subject)
    character(len=:), allocatable :: test_subject
    test_subject = 'A linear_2d_layer'
  end function

  function results() result(test_results)
    type(linear_2d_layer_test_t) linear_2d_layer_test
    type(test_result_t), allocatable :: test_results(:)
    test_results = linear_2d_layer_test%run( &
      [test_description_t('updating gradients', check_gradient_updates) &
    ])
  end function

  function check_gradient_updates() result(test_diagnosis)
    type(test_diagnosis_t) test_diagnosis

    real :: input(3,4) = reshape([0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 0.11, 0.12], [3, 4])
    real :: gradient(3,2) = reshape([0.0, 10., 0.2, 3., 0.4, 1.], [3, 2])
    type(linear2d_layer) :: linear
    real, pointer :: w_ptr(:)
    real, pointer :: b_ptr(:)

    integer :: num_parameters
    real, allocatable :: parameters(:) ! weights and biases concatenated into one vector
    real :: expected_parameters(10) = [ &
      0.100000001, 0.100000001, 0.100000001, 0.100000001, 0.100000001, 0.100000001, 0.100000001, 0.100000001, &
      0.109999999, 0.109999999 &
    ]
    real :: gradients(10)
    real :: expected_gradients(10) = [ &
      1.03999996, 4.09999990, 7.15999985, 1.12400007, 0.240000010, 1.56000006, 2.88000011, 2.86399961, &
      10.1999998, 4.40000010 &
    ]
    real :: updated_parameters(10)
    real :: updated_weights(8)
    real :: updated_biases(2)
    real :: expected_weights(8) = [ &
      0.203999996, 0.509999990, 0.816000044, 0.212400019, 0.124000005, 0.256000012, 0.388000011, 0.386399955 &
    ]
    real :: expected_biases(2) = [1.13000000, 0.550000012]

    integer :: i
    real, parameter :: tolerance = 0. ! zero tolerance: the expected values must be matched exactly

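    ! Build a linear2d layer with 2 output features for 3x4 input, set uniform
    ! weights (0.1) and biases (0.11), then run one forward/backward pass so the
    ! layer accumulates gradients.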
    linear = linear2d_layer(out_features=2)
    call linear%init([3, 4])
    linear%weights = 0.1
    linear%biases = 0.11
    call linear%forward(input)
    call linear%backward(input, gradient)
    num_parameters = linear%get_num_params()

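    ! Expect 10 parameters: 8 weights plus 2 biases.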
    test_diagnosis = (num_parameters .equalsExpected. 10) // " (number of parameters)"

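    ! Gather the weights and biases into a single parameter vector via the
    ! pointers returned by get_params_ptr and compare against the values set above.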
    call linear%get_params_ptr(w_ptr, b_ptr)
    allocate(parameters(size(w_ptr) + size(b_ptr)))
    parameters(1:size(w_ptr)) = w_ptr
    parameters(size(w_ptr)+1:) = b_ptr
    test_diagnosis = test_diagnosis .also. &
      (.all. (parameters .approximates. expected_parameters .within. tolerance) // " (parameters)")

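    ! get_gradients() should report the gradients accumulated by the backward pass.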
    gradients = linear%get_gradients()
    test_diagnosis = test_diagnosis .also. &
      (.all. (gradients .approximates. expected_gradients .within. tolerance) // " (gradients)")

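    ! Nudge every parameter by 0.1 times its gradient to exercise a parameter update.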
    do i = 1, num_parameters
      updated_parameters(i) = parameters(i) + 0.1 * gradients(i)
    end do

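    ! Write the updated values back through the parameter pointers and check that
    ! the layer's weights and biases reflect them.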
    call linear%get_params_ptr(w_ptr, b_ptr)
    w_ptr = updated_parameters(1:size(w_ptr))
    b_ptr = updated_parameters(size(w_ptr)+1:)
    updated_weights = reshape(linear%weights, shape(expected_weights))
    test_diagnosis = test_diagnosis .also. &
      (.all. (updated_weights .approximates. expected_weights .within. tolerance) // " (updated weights)")

    updated_biases = linear%biases
    test_diagnosis = test_diagnosis .also. &
      (.all. (updated_biases .approximates. expected_biases .within. tolerance) // " (updated biases)")

  end function

end module linear_2d_layer_test_m