forked from r9y9/nnet
-
Notifications
You must be signed in to change notification settings - Fork 1
/
train.go
127 lines (110 loc) · 3.18 KB
/
train.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
package nnet
import (
"errors"
"fmt"
)
// SupervisedObjectiver is an interface to provide an objective function
// for supervised training.
//
// SupervisedObjective evaluates the objective for the paired data set
// (input[i] belongs with target[i]). Trainer calls it with the full training
// set after each epoch when monitoring is enabled.
type SupervisedObjectiver interface {
	SupervisedObjective(input, target [][]float64) float64
}

// SupervisedOnlineUpdater is a SupervisedObjectiver that updates itself from
// one (input, target) sample at a time. It is the model type accepted by
// Trainer.SupervisedOnlineTrain.
type SupervisedOnlineUpdater interface {
	SupervisedOnlineUpdate(input, target []float64)
	SupervisedObjectiver
}

// SupervisedMiniBatchUpdater is a SupervisedObjectiver that updates itself
// from a mini-batch of samples, where input[i] belongs with target[i]. It is
// the model type accepted by Trainer.SupervisedMiniBatchTrain.
type SupervisedMiniBatchUpdater interface {
	SupervisedMiniBatchUpdate(input, target [][]float64)
	SupervisedObjectiver
}
// UnSupervisedObjectiver is an interface to provide an objective function
// for un-supervised training.
//
// UnSupervisedObjective evaluates the objective for the given data set.
// Trainer calls it with the full training set after each epoch when
// monitoring is enabled.
type UnSupervisedObjectiver interface {
	UnSupervisedObjective(input [][]float64) float64
}

// UnSupervisedOnlineUpdater is an UnSupervisedObjectiver that updates itself
// from one sample at a time. It is the model type accepted by
// Trainer.UnSupervisedOnlineTrain.
type UnSupervisedOnlineUpdater interface {
	UnSupervisedOnlineUpdate(input []float64)
	UnSupervisedObjectiver
}

// UnSupervisedMiniBatchUpdater is an UnSupervisedObjectiver that updates
// itself from a mini-batch of samples. It is the model type accepted by
// Trainer.UnSupervisedMiniBatchTrain, which passes the 0-based epoch number
// and the 0-based index of the mini-batch within that epoch alongside the
// batch itself.
type UnSupervisedMiniBatchUpdater interface {
	UnSupervisedMiniBatchUpdate(input [][]float64, epoch,
		miniBatchIndex int)
	UnSupervisedObjectiver
}
// Trainer runs training loops (online or mini-batch, supervised or
// un-supervised) over a model exposed through the updater interfaces,
// according to Option.
type Trainer struct {
	Option BaseTrainingOption
}

// BaseTrainingOption holds the settings shared by all Trainer methods.
// The field order is part of the API for positional struct literals.
type BaseTrainingOption struct {
	// Epoches is the number of passes made over the training data.
	Epoches int
	// MiniBatchSize is the number of samples per update in the mini-batch
	// training methods. The online (per-sample) methods ignore it, although
	// Trainer.ParseTrainingOption still requires it to be positive.
	MiniBatchSize int
	// Monitoring, when true, makes the training methods print the epoch
	// index and the objective over the whole training set to standard
	// output after every epoch.
	Monitoring bool
}
// NewTrainer creates a new Trainer that uses option.
//
// The option is stored as-is without validation. Use
// Trainer.ParseTrainingOption to validate an option before training.
func NewTrainer(option BaseTrainingOption) *Trainer {
	return &Trainer{Option: option}
}
// ParseTrainingOption validates option and, if it is valid, stores it as the
// trainer's Option.
//
// MiniBatchSize and Epoches must both be positive. MiniBatchSize is required
// even though the online training methods do not use it. On failure an error
// is returned and the previously stored Option is left unchanged, so a
// Trainer never holds an option that failed validation.
func (s *Trainer) ParseTrainingOption(option BaseTrainingOption) error {
	if option.MiniBatchSize <= 0 {
		return errors.New("mini-batch size must be larger than zero")
	}
	if option.Epoches <= 0 {
		return errors.New("number of epochs must be larger than zero")
	}
	s.Option = option
	return nil
}
// SupervisedOnlineTrain trains u with per-sample (online) updates.
//
// For each of s.Option.Epoches epochs, it calls u.SupervisedOnlineUpdate once
// per sample, in order, with input[m] and target[m]. If s.Option.Monitoring
// is set, the epoch index and u.SupervisedObjective over the full data set
// are printed to standard output after every epoch. MiniBatchSize is unused.
//
// An error is returned, before u is updated at all, if input and target do
// not have the same length.
func (s *Trainer) SupervisedOnlineTrain(u SupervisedOnlineUpdater,
	input, target [][]float64) error {
	// Check up front: a short target would otherwise panic mid-epoch, after
	// u had already been partially updated.
	if len(input) != len(target) {
		return fmt.Errorf("input and target must have the same length, got %d and %d",
			len(input), len(target))
	}
	for epoch := 0; epoch < s.Option.Epoches; epoch++ {
		for m := range input {
			u.SupervisedOnlineUpdate(input[m], target[m])
		}
		if s.Option.Monitoring {
			fmt.Println(epoch, u.SupervisedObjective(input, target))
		}
	}
	return nil
}
// SupervisedMiniBatchTrain trains u with mini-batch updates.
//
// For each of s.Option.Epoches epochs, the paired data is split in order into
// consecutive mini-batches of s.Option.MiniBatchSize samples, and each batch
// is passed to u.SupervisedMiniBatchUpdate. When len(input) is not a multiple
// of MiniBatchSize, the final batch of every epoch holds the remaining
// samples, so every sample is used. If s.Option.Monitoring is set, the epoch
// index and u.SupervisedObjective over the full data set are printed to
// standard output after every epoch.
//
// An error is returned, before u is updated at all, if MiniBatchSize is not
// positive or if input and target do not have the same length.
func (s *Trainer) SupervisedMiniBatchTrain(u SupervisedMiniBatchUpdater,
	input, target [][]float64) error {
	size := s.Option.MiniBatchSize
	if size <= 0 {
		return fmt.Errorf("mini-batch size must be larger than zero, got %d", size)
	}
	if len(input) != len(target) {
		return fmt.Errorf("input and target must have the same length, got %d and %d",
			len(input), len(target))
	}
	for epoch := 0; epoch < s.Option.Epoches; epoch++ {
		for b := 0; b < len(input); b += size {
			e := b + size
			if e > len(input) {
				// Final partial batch. The sample order never changes between
				// epochs, so dropping it would mean these samples are never
				// trained on.
				e = len(input)
			}
			u.SupervisedMiniBatchUpdate(input[b:e], target[b:e])
		}
		if s.Option.Monitoring {
			fmt.Println(epoch, u.SupervisedObjective(input, target))
		}
	}
	return nil
}
// UnSupervisedOnlineTrain performs online (one sample at a time) un-supervised
// training of u over input.
//
// The data set is traversed s.Option.Epoches times, always in its original
// order, and every sample is handed to u.UnSupervisedOnlineUpdate on its own.
// With s.Option.Monitoring enabled, each completed pass prints its pass
// number (starting at 0) and u.UnSupervisedObjective of the entire input to
// standard output. s.Option.MiniBatchSize plays no role here.
//
// The returned error is currently always nil.
func (s *Trainer) UnSupervisedOnlineTrain(u UnSupervisedOnlineUpdater,
	input [][]float64) error {
	for pass := 0; pass < s.Option.Epoches; pass++ {
		for _, sample := range input {
			u.UnSupervisedOnlineUpdate(sample)
		}
		if !s.Option.Monitoring {
			continue
		}
		fmt.Println(pass, u.UnSupervisedObjective(input))
	}
	return nil
}
// UnSupervisedMiniBatchTrain trains u with un-supervised mini-batch updates.
//
// For each of s.Option.Epoches epochs, input is split in order into
// consecutive mini-batches of s.Option.MiniBatchSize samples. Each batch is
// passed to u.UnSupervisedMiniBatchUpdate together with the 0-based epoch and
// the 0-based index of the batch within that epoch. When len(input) is not a
// multiple of MiniBatchSize, the final batch of every epoch holds the
// remaining samples, so every sample is used. If s.Option.Monitoring is set,
// the epoch index and u.UnSupervisedObjective over the full input are printed
// to standard output after every epoch.
//
// An error is returned, before u is updated at all, if MiniBatchSize is not
// positive.
func (s *Trainer) UnSupervisedMiniBatchTrain(u UnSupervisedMiniBatchUpdater,
	input [][]float64) error {
	size := s.Option.MiniBatchSize
	if size <= 0 {
		return fmt.Errorf("mini-batch size must be larger than zero, got %d", size)
	}
	for epoch := 0; epoch < s.Option.Epoches; epoch++ {
		for m, b := 0, 0; b < len(input); m, b = m+1, b+size {
			e := b + size
			if e > len(input) {
				// Final partial batch. The sample order never changes between
				// epochs, so dropping it would mean these samples are never
				// trained on.
				e = len(input)
			}
			u.UnSupervisedMiniBatchUpdate(input[b:e], epoch, m)
		}
		if s.Option.Monitoring {
			fmt.Println(epoch, u.UnSupervisedObjective(input))
		}
	}
	return nil
}