-
Notifications
You must be signed in to change notification settings - Fork 7
/
Copy pathinference.go
143 lines (120 loc) · 3.43 KB
/
inference.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
package lukai
import (
	"errors"
	"fmt"
	"time"

	tensorflow "github.com/tensorflow/tensorflow/tensorflow/go"
	context "golang.org/x/net/context"

	"github.com/luk-ai/lukai/protobuf/aggregatorpb"
	"github.com/luk-ai/lukai/tf"
)
// shouldLoadProdModelRLocked reports whether the prod model has to be
// (re)loaded: either no model is loaded at all, or mt.prod.lastUpdate is older
// than outOfDateModelTimeout. The caller must hold mt.prod's read or write
// lock.
func (mt *ModelType) shouldLoadProdModelRLocked() bool {
	if mt.prod.model == nil {
		return true
	}
	staleBefore := time.Now().Add(-outOfDateModelTimeout)
	return mt.prod.lastUpdate.Before(staleBefore)
}
// loadProdModelRLocked loads the prod model if it isn't present or is out of
// date.
//
// The caller must hold mt.prod's read lock. The read lock is temporarily
// released so the write lock can be taken, and it is re-acquired before
// returning. Any state the caller read before the call may therefore have
// changed by the time it returns.
//
// The currently loaded model (if any) is only replaced once the new model has
// been fetched successfully. If any step fails, the previous model, ID and op
// cache are left in place and the error is returned. Note that this means the
// old and new model are briefly held at the same time.
func (mt *ModelType) loadProdModelRLocked(ctx context.Context) error {
	if !mt.shouldLoadProdModelRLocked() {
		return nil
	}

	// A read lock cannot be upgraded in place: drop it, take the write lock,
	// and restore the read lock on the way out. Deferred calls run LIFO, so
	// Unlock runs before RLock.
	mt.prod.RUnlock()
	defer mt.prod.RLock()
	mt.prod.Lock()
	defer mt.prod.Unlock()

	// Another goroutine may have loaded the model while no lock was held.
	if !mt.shouldLoadProdModelRLocked() {
		return nil
	}

	conn, err := dial(ctx, EdgeAddress)
	if err != nil {
		return err
	}
	defer conn.Close()
	c := aggregatorpb.NewEdgeClient(conn)

	resp, err := c.ProdModel(ctx, &aggregatorpb.ProdModelRequest{
		Id: aggregatorpb.ModelID{
			Domain:    mt.Domain,
			ModelType: mt.ModelType,
		},
	})
	if err != nil {
		return err
	}

	// The server still reports the model we already have; just mark it fresh.
	// Requiring a loaded model keeps a zero-valued response ID from "matching"
	// the zero-valued modelID when nothing has been loaded yet. Without this,
	// the function would return success with no model.
	if mt.prod.model != nil && resp.Id == mt.prod.modelID {
		mt.prod.lastUpdate = time.Now()
		return nil
	}

	modelResp, err := c.ModelURL(ctx, &aggregatorpb.ModelURLRequest{
		Id: resp.Id,
	})
	if err != nil {
		return err
	}
	// Load into a local first so a failed download never clobbers the shared
	// fields.
	model, err := tf.GetModel(modelResp.Url)
	if err != nil {
		return err
	}

	// Swap in the new model, then close the old one. We hold the write lock,
	// so no reader (they hold the read lock while running) is using the old
	// model.
	old := mt.prod.model
	mt.prod.model = model
	mt.prod.modelID = resp.Id
	mt.prod.lastUpdate = time.Now()
	mt.prod.cache = makeTFOpCache(mt.prod.model)
	if old != nil {
		if err := old.Close(); err != nil {
			return err
		}
	}

	// TODO(d4l3k): Quantize model weights.
	// TODO(d4l3k): Train with local examples.
	return nil
}
// ID returns the current model ID. If a production model ID is set, that ID is
// returned. Otherwise the returned ID has only its Domain and ModelType
// fields set.
func (mt *ModelType) ID() aggregatorpb.ModelID {
	// ID only reads shared state, so a read lock is sufficient. It also lets
	// ID run concurrently with in-flight inference instead of serializing
	// with it.
	mt.prod.RLock()
	defer mt.prod.RUnlock()

	if (mt.prod.modelID != aggregatorpb.ModelID{}) {
		return mt.prod.modelID
	}
	return aggregatorpb.ModelID{
		Domain:    mt.Domain,
		ModelType: mt.ModelType,
	}
}
// Run runs the model with the provided tensorflow feeds, fetches and targets.
// The key for feeds, and fetches should be in the form "name:#", and the
// targets in the form "name".
//
// If inference fails, the failure is passed to reportErrorDial, and the
// inference error is returned. If reporting also fails, the returned error
// describes both failures instead of hiding the inference error.
func (mt *ModelType) Run(
	ctx context.Context, feeds map[string]*tensorflow.Tensor, fetches []string, targets []string,
) ([]*tensorflow.Tensor, error) {
	tensors, err := mt.runInternal(ctx, feeds, fetches, targets)
	if err == nil {
		return tensors, nil
	}

	// mt.ID() takes the prod lock itself; runInternal has already released it.
	if reportErr := mt.reportErrorDial(ctx, mt.ID(), aggregatorpb.ERROR_INFERENCE, err); reportErr != nil {
		return nil, fmt.Errorf("%v (reporting the error also failed: %v)", err, reportErr)
	}
	return nil, err
}
// runInternal executes one inference step against the prod model. It loads
// the model first if none is loaded.
//
// It holds mt.prod's read lock while resolving the op names and running the
// session. The model is swapped or closed only under the write lock, so it
// cannot change underneath the run.
func (mt *ModelType) runInternal(
	ctx context.Context, feeds map[string]*tensorflow.Tensor, fetches []string, targets []string,
) ([]*tensorflow.Tensor, error) {
	mt.prod.RLock()
	defer mt.prod.RUnlock()

	// Only a missing model triggers a (network) load on this path. An
	// out-of-date but present model is used as-is.
	// NOTE(review): presumably refreshing happens elsewhere — confirm.
	if mt.prod.model == nil {
		if err := mt.loadProdModelRLocked(ctx); err != nil {
			return nil, err
		}
	}

	// loadProdModelRLocked can return nil without loading anything, and it
	// drops and re-acquires the read lock. Re-check before dereferencing the
	// model or its op cache.
	if mt.prod.model == nil {
		return nil, errors.New("no production model loaded")
	}

	feedsResolved, fetchesResolved, targetsResolved, err := mt.prod.cache.resolve(
		example{
			feeds:   feeds,
			fetches: fetches,
			targets: targets,
		},
	)
	if err != nil {
		return nil, err
	}
	return mt.prod.model.Session.Run(feedsResolved, fetchesResolved, targetsResolved)
}