-
Notifications
You must be signed in to change notification settings - Fork 21
/
Copy pathhugot_training_xla.go
111 lines (93 loc) · 3.36 KB
/
hugot_training_xla.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
//go:build XLA || ALL
package hugot
import (
"fmt"
"github.com/gomlx/exceptions"
"github.com/gomlx/gomlx/graph"
"github.com/gomlx/gomlx/ml/context"
"github.com/gomlx/gomlx/ml/train"
"github.com/gomlx/gomlx/ml/train/losses"
"github.com/gomlx/gomlx/ml/train/optimizers"
"github.com/gomlx/gopjrt/dtypes"
"github.com/knights-analytics/hugot/pipelineBackends"
"github.com/knights-analytics/hugot/pipelines"
)
// XLATrainingOptions holds the GoMLX settings used when training on the XLA
// backend. NewXLATrainingSession fills in any field left nil for feature
// extraction pipelines.
type XLATrainingOptions struct {
// Optimizer is the optimizer passed to the GoMLX trainer. When nil,
// NewXLATrainingSession sets it to optimizers.StochasticGradientDescent().
Optimizer optimizers.Interface
// Loss is the loss function passed to the GoMLX trainer. When nil,
// NewXLATrainingSession sets it to losses.MeanSquaredError.
Loss losses.LossFn
}
// NewXLATrainingSession creates a training session for pipeline type T on the
// "XLA" backend and fills in any XLA training options the caller left unset.
//
// Only *pipelines.FeatureExtractionPipeline is handled. For that pipeline a
// missing options struct is allocated, a missing optimizer defaults to
// stochastic gradient descent, and a missing loss defaults to mean squared
// error. Every other pipeline type results in an error.
//
// Errors from newTrainingSession are returned unchanged.
func NewXLATrainingSession[T pipelineBackends.Pipeline](config TrainingConfig) (*TrainingSession, error) {
	s, err := newTrainingSession[T]("XLA", config)
	if err != nil {
		return nil, err
	}

	// Set defaults for whatever the caller did not configure.
	switch any(s.pipeline).(type) {
	case *pipelines.FeatureExtractionPipeline:
		if s.config.XlaTrainingOptions == nil {
			s.config.XlaTrainingOptions = &XLATrainingOptions{}
		}
		opts := s.config.XlaTrainingOptions
		if opts.Optimizer == nil {
			opts.Optimizer = optimizers.StochasticGradientDescent()
		}
		if opts.Loss == nil {
			opts.Loss = losses.MeanSquaredError
		}
	default:
		// This branch is taken for every pipeline type without a case above,
		// regardless of whether a loss was supplied, so report the real cause
		// instead of asking for a loss function.
		return nil, fmt.Errorf("XLA training is not supported for pipeline type %T", s.pipeline)
	}
	return s, nil
}
// TrainXLA trains the pipeline held by s with GoMLX on the XLA backend.
//
// For a *pipelines.FeatureExtractionPipeline it builds a model function that
// embeds both sides of each input pair with the pipeline's XLA model,
// mean-pools the embeddings over the attention mask when they have more than
// two dimensions, and outputs the cosine similarity of the two embeddings.
// That output is trained with the loss and optimizer from
// s.config.XlaTrainingOptions for s.config.Epochs epochs over
// s.config.Dataset.
//
// It returns an error when the pipeline type is not supported or when the
// training loop fails; error-typed panics raised while the loop runs are
// converted into the returned error.
func TrainXLA(s *TrainingSession) error {
	switch p := s.pipeline.(type) {
	case *pipelines.FeatureExtractionPipeline:
		xlaModel := p.Model.XLAModel
		backend := xlaModel.Backend
		ctx := xlaModel.Ctx

		modelFn := func(ctx *context.Context, spec any, inputs []*context.Node) []*context.Node {
			// inputs is assumed to hold the left-hand model inputs followed by
			// the same number of right-hand model inputs (inputIDs,
			// attentionMask, and tokenTypeIDs if present) — TODO confirm
			// against the dataset. Splitting at the midpoint keeps the 3+3
			// layout unchanged and also handles models without tokenTypeIDs.
			if len(inputs) == 0 || len(inputs)%2 != 0 {
				panic(fmt.Errorf("expected an even, non-zero number of model inputs for a pair, got %d", len(inputs)))
			}
			half := len(inputs) / 2
			inputsLhs := inputs[:half]
			inputsRhs := inputs[half:]

			embeddingLhs := xlaModel.Call(ctx.Reuse(), inputsLhs)[0]
			embeddingRhs := xlaModel.Call(ctx.Reuse(), inputsRhs)[0]

			// We mean pool the results if needed, e.g. if dimensions are
			// [batch, seq, hidden].
			if len(embeddingLhs.Shape().Dimensions) > 2 {
				if half < 2 {
					panic(fmt.Errorf("mean pooling requires an attention mask as the second model input, got %d input(s) per side", half))
				}
				// Both sides are reshaped with the left-hand batch and
				// embedding sizes; the middle dimension is inferred per side.
				batchSize := embeddingLhs.Shape().Dim(0)
				embeddingSize := embeddingLhs.Shape().Dim(-1)
				meanPool := func(embedding, attentionMask *context.Node) *context.Node {
					embedding = graph.Reshape(embedding, batchSize, -1, embeddingSize)
					mask := graph.Reshape(attentionMask, batchSize, -1, 1)
					mask = graph.BroadcastToShape(mask, embedding.Shape())
					mask = graph.ConvertDType(mask, dtypes.Bool)
					return graph.MaskedReduceMean(embedding, mask, 1)
				}
				embeddingLhs = meanPool(embeddingLhs, inputsLhs[1])
				embeddingRhs = meanPool(embeddingRhs, inputsRhs[1])
			}

			cosineSimilarity := graph.CosineSimilarity(embeddingLhs, embeddingRhs, -1)
			return []*context.Node{cosineSimilarity}
		}

		gomlxTrainer := train.NewTrainer(backend,
			ctx,
			modelFn,
			s.config.XlaTrainingOptions.Loss,
			s.config.XlaTrainingOptions.Optimizer,
			nil,
			nil)
		loop := train.NewLoop(gomlxTrainer)

		if s.config.Verbose {
			fmt.Printf("Training for %d epochs\n", s.config.Epochs)
		}

		// We rely on try/catch because an error is returned if there is an
		// initialization error, but a panic will be thrown if e.g. the dataset
		// reset fails. Re-panicking with the returned error funnels both paths
		// into a single error value.
		err := exceptions.TryCatch[error](func() {
			if _, err := loop.RunEpochs(s.config.Dataset, s.config.Epochs); err != nil {
				panic(err)
			}
		})
		if err != nil {
			return err
		}

		if s.config.Verbose {
			fmt.Println("Training complete")
		}
	default:
		// Without this branch an unsupported pipeline would return nil
		// without having trained anything.
		return fmt.Errorf("XLA training is not supported for pipeline type %T", p)
	}
	return nil
}