-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathSave and Restore model
29 lines (27 loc) · 2.11 KB
/
Save and Restore model
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
一 只存取variable
1.要先搭好图
2.tf.train.Saver会给graph里的每一个variable或者指定的variable添加一个save和一个restore的op,然后会提供一个总的op来存和读这些variable。
save之后会产生4个文件
".meta" files: containing the graph structure
".data" files: containing the values of variables
".index" files: identifying the checkpoint
"checkpoint" file: a protocol buffer with a list of recent checkpoints
3.可以用tf.python.tools.inspect_checkpoint库来打印checkpoint中存的variable和它的值
二 SavedModel,存取一整个model,不用知道图的结构
是一个把tf.train.Saver包裹起来的API,提供了更多的feature:Signature用来表示模型的输入和输出;Asset用来表示模型所需要的外部文件
1. MetaGraph:a dataflow graph and its variables and signature.
2. MetaGraphDef: the protocol buffer representation for a MetaGraph
3. SavedModelBuilder: 创建一个SavedModel,会储存相应的MetaGraphDef和variable到disk
3.1 在添加MetaGraphDef到SavedModelBuilder时,需要一个tag,用来identify这个对应的MetaGraphDef,在load和restore的时候都需要它
3.2 先使用add_meta_graph_and_variables函数来添加第一个metagraph(比如training graph)和variable,要添加其它的metagraph(比如inference graph),再使用add_meta_graph。它里面其实都是用Saver来完成的
builder = tf.saved_model.builder.SavedModelBuilder(export_dir)
builder.add_meta_graph_and_variables(sess,
["foo-tag"],
signature_def_map=foo_signatures,
assets_collection=foo_assets) //sess里的图和变量都会被存为一个SavedModel
builder.add_meta_graph(["bar-tag", "baz-tag"],
assets_collection=bar_baz_assets)
builder.save()
4. 使用tf.saved_model.loader.load方法来加载model
例如:tf.saved_model.loader.load(sess, [tag_constants.TRAINING], export_dir)
调用之后,tag_constants.TRAINING对应的MetaGraphDef会被恢复到sess中的graph去。