-
-
Notifications
You must be signed in to change notification settings - Fork 50
/
dense_softmax_test.go
287 lines (277 loc) · 12.7 KB
/
dense_softmax_test.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
package tensor
import (
"fmt"
"testing"
"github.com/stretchr/testify/assert"
)
// TestSoftMax runs SoftMax and LogSoftMax over a table of small dense tensors
// and compares the results element-wise against precomputed values.
//
// Every case uses the same twelve input values (0.0 through 1.1 in steps of
// 0.1); the cases differ in function, dtype, shape and axis. Results must
// match the expected slice to within 1e-6 per element and must keep the shape
// of the input.
func TestSoftMax(t *testing.T) {
	testCases := []struct {
		fn             func(x Tensor, axis int, opts ...FuncOpt) (Tensor, error)
		x              Tensor
		axis           int
		expectedOutput interface{}
	}{
		// LogSoftMax, float64, shape (3, 4), axis -1.
		{
			fn: LogSoftMax,
			x: New(
				Of(Float64),
				WithShape(3, 4),
				WithBacking([]float64{0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 1.1}),
			),
			axis:           -1,
			expectedOutput: []float64{-1.5425355294551628, -1.4425355294551627, -1.3425355294551626, -1.2425355294551628, -1.5425355294551628, -1.4425355294551627, -1.3425355294551626, -1.2425355294551628, -1.5425355294551628, -1.4425355294551627, -1.3425355294551629, -1.2425355294551628},
		},
		// LogSoftMax, float32, shape (3, 4), axis -1.
		{
			fn: LogSoftMax,
			x: New(
				Of(Float32),
				WithShape(3, 4),
				WithBacking([]float32{0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 1.1}),
			),
			axis:           -1,
			expectedOutput: []float32{-1.5425355294551628, -1.4425355294551627, -1.3425355294551626, -1.2425355294551628, -1.5425355294551628, -1.4425355294551627, -1.3425355294551626, -1.2425355294551628, -1.5425355294551628, -1.4425355294551627, -1.3425355294551629, -1.2425355294551628},
		},
		// LogSoftMax, float32, shape (3, 2, 2), axis -1.
		{
			fn: LogSoftMax,
			x: New(
				Of(Float32),
				WithShape(3, 2, 2),
				WithBacking([]float32{0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 1.1}),
			),
			axis:           -1,
			expectedOutput: []float32{-0.7443967, -0.64439666, -0.7443967, -0.64439666, -0.7443967, -0.64439666, -0.7443966, -0.64439666, -0.7443966, -0.64439666, -0.7443967, -0.64439666},
		},
		// LogSoftMax, float64, shape (3, 2, 2), axis 1.
		{
			fn: LogSoftMax,
			x: New(
				Of(Float64),
				WithShape(3, 2, 2),
				WithBacking([]float64{0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 1.1}),
			),
			axis:           1,
			expectedOutput: []float64{-0.7981388693815918, -0.7981388693815918, -0.5981388693815918, -0.5981388693815919, -0.7981388693815918, -0.7981388693815918, -0.5981388693815919, -0.5981388693815919, -0.7981388693815918, -0.7981388693815918, -0.5981388693815919, -0.5981388693815918},
		},
		// SoftMax, float64, shape (3, 2, 2), axis 1.
		{
			fn: SoftMax,
			x: New(
				Of(Float64),
				WithShape(3, 2, 2),
				WithBacking([]float64{0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 1.1}),
			),
			axis:           1,
			expectedOutput: []float64{0.4501660026875221, 0.45016600268752205, 0.549833997312478, 0.5498339973124778, 0.45016600268752205, 0.45016600268752205, 0.5498339973124778, 0.5498339973124778, 0.45016600268752205, 0.4501660026875221, 0.5498339973124778, 0.549833997312478},
		},
		// SoftMax, float64, shape (3, 2, 2), axis -1.
		{
			fn: SoftMax,
			x: New(
				Of(Float64),
				WithShape(3, 2, 2),
				WithBacking([]float64{0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 1.1}),
			),
			axis:           -1,
			expectedOutput: []float64{0.47502081252106, 0.52497918747894, 0.47502081252106, 0.5249791874789399, 0.47502081252106, 0.5249791874789399, 0.47502081252106, 0.5249791874789399, 0.47502081252106, 0.5249791874789399, 0.47502081252106, 0.52497918747894},
		},
		// SoftMax, float32, shape (3, 4), axis -1.
		{
			fn: SoftMax,
			x: New(
				Of(Float32),
				WithShape(3, 4),
				WithBacking([]float32{0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 1.1}),
			),
			axis:           -1,
			expectedOutput: []float32{0.21383822, 0.23632777, 0.2611826, 0.2886514, 0.21383823, 0.23632778, 0.2611826, 0.2886514, 0.21383822, 0.23632777, 0.26118258, 0.2886514},
		},
		// SoftMax, float64, shape (3, 4), axis -1.
		{
			fn: SoftMax,
			x: New(
				Of(Float64),
				WithShape(3, 4),
				WithBacking([]float64{0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 1.1}),
			),
			axis:           -1,
			expectedOutput: []float64{0.21383822, 0.23632777, 0.2611826, 0.2886514, 0.21383823, 0.23632778, 0.2611826, 0.2886514, 0.21383822, 0.23632777, 0.26118258, 0.2886514},
		},
	}

	for i, tC := range testCases {
		t.Run(fmt.Sprintf("Example #%d - %v %v", i+1, tC.x.Shape(), tC.x.Dtype()), func(t *testing.T) {
			c := assert.New(t)

			output, err := tC.fn(tC.x, tC.axis)
			// The assert helpers record a failure but keep running, so stop
			// here explicitly: calling methods on a nil result would panic
			// and bury the real error.
			if !c.NoError(err) || !c.NotNil(output) {
				return
			}
			t.Logf("output: %#v", output.Data())

			c.Equal(tC.x.Shape(), output.Shape())
			c.InDeltaSlice(tC.expectedOutput, output.Data(), 1e-6)
		})
	}
}
// TestSoftMaxB runs SoftMaxB and LogSoftMaxB over a table of cases and
// compares the returned tensor element-wise against precomputed values.
//
// Each case supplies an `output` tensor and a `grad` tensor (always the
// twelve values 0.01 through 0.12). The `output` backings reuse the expected
// results listed in TestSoftMax for the matching function, shape and axis.
// The result must match the expected slice to within 1e-6 per element and
// must keep the shape of `output`.
func TestSoftMaxB(t *testing.T) {
	testCases := []struct {
		fn             func(output, grad Tensor, axis int, opts ...FuncOpt) (Tensor, error)
		output         Tensor
		grad           Tensor
		axis           int
		expectedOutput interface{}
	}{
		// ---- float64 cases ----

		// SoftMaxB, shape (3, 4), axis -1.
		{
			fn: SoftMaxB,
			output: New(
				Of(Float64),
				WithShape(3, 4),
				WithBacking([]float64{0.21383822, 0.23632777, 0.2611826, 0.2886514, 0.21383823, 0.23632778, 0.2611826, 0.2886514, 0.21383822, 0.23632777, 0.26118258, 0.2886514}),
			),
			grad: New(
				Of(Float64),
				WithShape(3, 4),
				WithBacking([]float64{0.01, 0.02, 0.03, 0.04, 0.05, 0.06, 0.07, 0.08, 0.09, 0.10, 0.11, 0.12}),
			),
			axis:           -1,
			expectedOutput: []float64{-0.003474116568224552, -0.0014762147035963322, 0.0009803563066858392, 0.00396997522759976, -0.003474116880376028, -0.001476214931490494, 0.0009803561238580223, 0.003969975025543781, -0.0034741159267098936, -0.0014762139946130218, 0.0009803570151630109, 0.003969976093553957},
		},
		// LogSoftMaxB, shape (3, 4), axis -1.
		{
			fn: LogSoftMaxB,
			output: New(
				Of(Float64),
				WithShape(3, 4),
				WithBacking([]float64{-1.5425355294551628, -1.4425355294551627, -1.3425355294551626, -1.2425355294551628, -1.5425355294551628, -1.4425355294551627, -1.3425355294551626, -1.2425355294551628, -1.5425355294551628, -1.4425355294551627, -1.3425355294551629, -1.2425355294551628}),
			),
			grad: New(
				Of(Float64),
				WithShape(3, 4),
				WithBacking([]float64{0.01, 0.02, 0.03, 0.04, 0.05, 0.06, 0.07, 0.08, 0.09, 0.10, 0.11, 0.12}),
			),
			axis:           -1,
			expectedOutput: []float64{-0.011383822036598441, -0.003632778232153768, 0.0038817407844924366, 0.01113485948425977, -0.005597937295155945, -0.001445223403599799, 0.0020925260396803457, 0.004950634659075405, 0.00018794744628654992, 0.0007423314249541871, 0.00030331129486827163, -0.0012335901661089598},
		},
		// SoftMaxB, shape (3, 2, 2), axis -1.
		{
			fn: SoftMaxB,
			output: New(
				Of(Float64),
				WithShape(3, 2, 2),
				WithBacking([]float64{0.47502081252106, 0.52497918747894, 0.47502081252106, 0.5249791874789399, 0.47502081252106, 0.5249791874789399, 0.47502081252106, 0.5249791874789399, 0.47502081252106, 0.5249791874789399, 0.47502081252106, 0.52497918747894}),
			),
			grad: New(
				Of(Float64),
				WithShape(3, 2, 2),
				WithBacking([]float64{0.01, 0.02, 0.03, 0.04, 0.05, 0.06, 0.07, 0.08, 0.09, 0.10, 0.11, 0.12}),
			),
			axis:           -1,
			expectedOutput: []float64{-0.002493760401928919, 0.0024937604019289205, -0.0024937604019289183, 0.002493760401928922, -0.002493760401928915, 0.002493760401928922, -0.002493760401928912, 0.0024937604019289253, -0.0024937604019289183, 0.0024937604019289253, -0.0024937604019289183, 0.0024937604019289183},
		},
		// SoftMaxB, shape (3, 2, 2), axis 1.
		{
			fn: SoftMaxB,
			output: New(
				Of(Float64),
				WithShape(3, 2, 2),
				WithBacking([]float64{0.4501660026875221, 0.45016600268752205, 0.549833997312478, 0.5498339973124778, 0.45016600268752205, 0.45016600268752205, 0.5498339973124778, 0.5498339973124778, 0.45016600268752205, 0.4501660026875221, 0.5498339973124778, 0.549833997312478}),
			),
			grad: New(
				Of(Float64),
				WithShape(3, 2, 2),
				WithBacking([]float64{0.01, 0.02, 0.03, 0.04, 0.05, 0.06, 0.07, 0.08, 0.09, 0.10, 0.11, 0.12}),
			),
			axis:           1,
			expectedOutput: []float64{-0.004950331454237199, -0.004950331454237198, 0.004950331454237199, 0.0049503314542372, -0.004950331454237196, -0.004950331454237193, 0.004950331454237203, 0.0049503314542372065, -0.004950331454237193, -0.0049503314542372, 0.0049503314542372065, 0.004950331454237193},
		},
		// LogSoftMaxB, shape (3, 2, 2), axis 1.
		{
			fn: LogSoftMaxB,
			output: New(
				Of(Float64),
				WithShape(3, 2, 2),
				WithBacking([]float64{-0.7981388693815918, -0.7981388693815918, -0.5981388693815918, -0.5981388693815919, -0.7981388693815918, -0.7981388693815918, -0.5981388693815919, -0.5981388693815919, -0.7981388693815918, -0.7981388693815918, -0.5981388693815919, -0.5981388693815918}),
			),
			grad: New(
				Of(Float64),
				WithShape(3, 2, 2),
				WithBacking([]float64{0.01, 0.02, 0.03, 0.04, 0.05, 0.06, 0.07, 0.08, 0.09, 0.10, 0.11, 0.12}),
			),
			axis:           1,
			expectedOutput: []float64{-0.008006640107500884, -0.007009960161251325, 0.00800664010750088, 0.007009960161251332, -0.004019920322502654, -0.003023240376253103, 0.004019920322502661, 0.0030232403762530968, -3.32005375044292e-05, 0.0009634794087451421, 3.320053750442642e-05, -0.0009634794087451543},
		},

		// ---- float32 cases ----
		// Same numbers as the float64 cases above. The declared dtype and the
		// expected slice type are kept in step with the []float32 backing so
		// each case states plainly which precision it exercises.

		// LogSoftMaxB, shape (3, 2, 2), axis 1.
		{
			fn: LogSoftMaxB,
			output: New(
				Of(Float32),
				WithShape(3, 2, 2),
				WithBacking([]float32{-0.7981388693815918, -0.7981388693815918, -0.5981388693815918, -0.5981388693815919, -0.7981388693815918, -0.7981388693815918, -0.5981388693815919, -0.5981388693815919, -0.7981388693815918, -0.7981388693815918, -0.5981388693815919, -0.5981388693815918}),
			),
			grad: New(
				Of(Float32),
				WithShape(3, 2, 2),
				WithBacking([]float32{0.01, 0.02, 0.03, 0.04, 0.05, 0.06, 0.07, 0.08, 0.09, 0.10, 0.11, 0.12}),
			),
			axis:           1,
			expectedOutput: []float32{-0.008006640107500884, -0.007009960161251325, 0.00800664010750088, 0.007009960161251332, -0.004019920322502654, -0.003023240376253103, 0.004019920322502661, 0.0030232403762530968, -3.32005375044292e-05, 0.0009634794087451421, 3.320053750442642e-05, -0.0009634794087451543},
		},
		// SoftMaxB, shape (3, 2, 2), axis 1.
		{
			fn: SoftMaxB,
			output: New(
				Of(Float32),
				WithShape(3, 2, 2),
				WithBacking([]float32{0.4501660026875221, 0.45016600268752205, 0.549833997312478, 0.5498339973124778, 0.45016600268752205, 0.45016600268752205, 0.5498339973124778, 0.5498339973124778, 0.45016600268752205, 0.4501660026875221, 0.5498339973124778, 0.549833997312478}),
			),
			grad: New(
				Of(Float32),
				WithShape(3, 2, 2),
				WithBacking([]float32{0.01, 0.02, 0.03, 0.04, 0.05, 0.06, 0.07, 0.08, 0.09, 0.10, 0.11, 0.12}),
			),
			axis:           1,
			expectedOutput: []float32{-0.004950331454237199, -0.004950331454237198, 0.004950331454237199, 0.0049503314542372, -0.004950331454237196, -0.004950331454237193, 0.004950331454237203, 0.0049503314542372065, -0.004950331454237193, -0.0049503314542372, 0.0049503314542372065, 0.004950331454237193},
		},
		// SoftMaxB, shape (3, 4), axis -1.
		{
			fn: SoftMaxB,
			output: New(
				Of(Float32),
				WithShape(3, 4),
				WithBacking([]float32{0.21383822, 0.23632777, 0.2611826, 0.2886514, 0.21383823, 0.23632778, 0.2611826, 0.2886514, 0.21383822, 0.23632777, 0.26118258, 0.2886514}),
			),
			grad: New(
				Of(Float32),
				WithShape(3, 4),
				WithBacking([]float32{0.01, 0.02, 0.03, 0.04, 0.05, 0.06, 0.07, 0.08, 0.09, 0.10, 0.11, 0.12}),
			),
			axis:           -1,
			expectedOutput: []float32{-0.003474116568224552, -0.0014762147035963322, 0.0009803563066858392, 0.00396997522759976, -0.003474116880376028, -0.001476214931490494, 0.0009803561238580223, 0.003969975025543781, -0.0034741159267098936, -0.0014762139946130218, 0.0009803570151630109, 0.003969976093553957},
		},
		// LogSoftMaxB, shape (3, 4), axis -1.
		{
			fn: LogSoftMaxB,
			output: New(
				Of(Float32),
				WithShape(3, 4),
				WithBacking([]float32{-1.5425355294551628, -1.4425355294551627, -1.3425355294551626, -1.2425355294551628, -1.5425355294551628, -1.4425355294551627, -1.3425355294551626, -1.2425355294551628, -1.5425355294551628, -1.4425355294551627, -1.3425355294551629, -1.2425355294551628}),
			),
			grad: New(
				Of(Float32),
				WithShape(3, 4),
				WithBacking([]float32{0.01, 0.02, 0.03, 0.04, 0.05, 0.06, 0.07, 0.08, 0.09, 0.10, 0.11, 0.12}),
			),
			axis:           -1,
			expectedOutput: []float32{-0.011383822036598441, -0.003632778232153768, 0.0038817407844924366, 0.01113485948425977, -0.005597937295155945, -0.001445223403599799, 0.0020925260396803457, 0.004950634659075405, 0.00018794744628654992, 0.0007423314249541871, 0.00030331129486827163, -0.0012335901661089598},
		},
		// SoftMaxB, shape (3, 2, 2), axis -1.
		{
			fn: SoftMaxB,
			output: New(
				Of(Float32),
				WithShape(3, 2, 2),
				WithBacking([]float32{0.47502081252106, 0.52497918747894, 0.47502081252106, 0.5249791874789399, 0.47502081252106, 0.5249791874789399, 0.47502081252106, 0.5249791874789399, 0.47502081252106, 0.5249791874789399, 0.47502081252106, 0.52497918747894}),
			),
			grad: New(
				Of(Float32),
				WithShape(3, 2, 2),
				WithBacking([]float32{0.01, 0.02, 0.03, 0.04, 0.05, 0.06, 0.07, 0.08, 0.09, 0.10, 0.11, 0.12}),
			),
			axis:           -1,
			expectedOutput: []float32{-0.002493760401928919, 0.0024937604019289205, -0.0024937604019289183, 0.002493760401928922, -0.002493760401928915, 0.002493760401928922, -0.002493760401928912, 0.0024937604019289253, -0.0024937604019289183, 0.0024937604019289253, -0.0024937604019289183, 0.0024937604019289183},
		},
	}

	for i, tC := range testCases {
		t.Run(fmt.Sprintf("Example #%d - %v %v", i+1, tC.output.Shape(), tC.output.Dtype()), func(t *testing.T) {
			c := assert.New(t)

			dx, err := tC.fn(tC.output, tC.grad, tC.axis)
			t.Logf("output: %#v", tC.output.Data())
			// The assert helpers record a failure but keep running, so stop
			// here explicitly: calling methods on a nil result would panic
			// and bury the real error.
			if !c.NoError(err) || !c.NotNil(dx) {
				return
			}
			// Log the computed result too; it is what the assertions below
			// actually compare.
			t.Logf("dx: %#v", dx.Data())

			c.Equal(tC.output.Shape(), dx.Shape())
			c.InDeltaSlice(tC.expectedOutput, dx.Data(), 1e-6)
		})
	}
}