-
Notifications
You must be signed in to change notification settings - Fork 0
/
utils.py
47 lines (34 loc) · 1.46 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
import logging
from pathlib import Path
import numpy as np
import tensorflow as tf
from infer import test_model
from freeze import freeze_graph, run_console_tool
_LOG_FORMAT = '%(asctime)s %(levelname)-8s %(message)s'
_LOG_DATE_FORMAT = '%Y-%m-%d %H:%M:%S'

# Configure the root logger at import time (logging.basicConfig is a no-op
# if the root logger already has handlers), then create this module's logger.
logging.basicConfig(
    level=logging.INFO,
    format=_LOG_FORMAT,
    datefmt=_LOG_DATE_FORMAT,
)
logger = logging.getLogger(__name__)
def str2bool(v):
    """Interpret a flag-like value as a boolean.

    String matching is case-insensitive: ``"yes"``, ``"true"``, ``"t"`` and
    ``"1"`` map to ``True``; every other string maps to ``False``.

    Args:
        v: A string to interpret, or a ``bool``, which is returned unchanged.

    Returns:
        The boolean interpretation of ``v``.
    """
    # A value that is already boolean (e.g. from a programmatic call) would
    # otherwise crash with AttributeError on ``.lower()``.
    if isinstance(v, bool):
        return v
    return v.lower() in ("yes", "true", "t", "1")
def expand(x, dim, N):
    """Stack ``N`` copies of ``x`` along a new axis inserted at ``dim``.

    Args:
        x: Input tensor.
        dim: Position of the new axis, as accepted by ``tf.expand_dims``;
            also used as the concatenation axis.
        N: Number of copies to concatenate.

    Returns:
        A tensor with the shape of ``x`` plus a new axis of size ``N`` at
        ``dim``, every slice along that axis equal to ``x``.
    """
    # A single expand_dims result is reused for every copy; the values are
    # the same as expanding ``x`` separately N times.
    expanded = tf.expand_dims(x, dim)
    copies = [expanded for _ in range(N)]
    return tf.concat(copies, axis=dim)
def learned_init(units):
    """Create a learned (trainable) vector of length ``units``.

    A ``tf.compat.v1.layers.dense`` layer is applied to a constant ``[1, 1]``
    tensor of ones. Its output is therefore ``kernel + bias`` and is learned
    through the variables the dense layer creates.

    Args:
        units: Length of the vector.

    Returns:
        A tensor of shape ``[units]``.
    """
    dense_out = tf.compat.v1.layers.dense(tf.ones([1, 1]), units)
    # Squeeze only the leading batch axis. A bare tf.squeeze would also drop
    # the feature axis when units == 1 and return a scalar instead of [1].
    return tf.squeeze(dense_out, axis=0)
def create_linear_initializer(input_size, dtype=tf.float32):
    """Return a truncated-normal initializer with stddev ``1 / sqrt(input_size)``.

    Args:
        input_size: Size of the layer input used to scale the stddev; must be
            positive.
        dtype: Data type passed through to the initializer.

    Returns:
        A ``tf.compat.v1.truncated_normal_initializer``.

    Raises:
        ValueError: If ``input_size`` is not positive, since the stddev would
            otherwise be infinite (0) or NaN (negative) with only a numpy
            RuntimeWarning.
    """
    if input_size <= 0:
        raise ValueError(f'input_size must be positive, got {input_size}.')
    stddev = 1.0 / np.sqrt(input_size)
    return tf.compat.v1.truncated_normal_initializer(stddev=stddev, dtype=dtype)
def save_session_as_tf_checkpoint(session, saver, current_stage, bits_per_number,
                                  models_root='./models'):
    """Save a checkpoint of ``session`` and run the console freeze tool on it.

    The checkpoint is written with the prefix
    ``<models_root>/<current_stage>/my_model.ckpt``. That directory is then
    passed to ``run_console_tool`` as ``--checkpoint_dir`` and the tool's
    result is logged.

    Args:
        session: Session passed to ``saver.save`` (presumably a TensorFlow
            session).
        saver: Object whose ``save(session, path)`` writes the checkpoint
            (presumably a ``tf.compat.v1.train.Saver``).
        current_stage: Stage/step identifier; used as the checkpoint
            subdirectory name and in log messages.
        bits_per_number: Currently unused. It was only consumed by the
            disabled in-process test step noted in the body.
        models_root: Parent directory for the per-stage checkpoint
            directories. Defaults to ``./models``, the previously hard-coded
            location.
    """
    model_dir = Path(models_root) / f'{current_stage}'
    # Create the stage directory up front so saving does not depend on
    # ./models/<stage> already existing.
    model_dir.mkdir(parents=True, exist_ok=True)
    saver.save(session, str(model_dir / 'my_model.ckpt'))
    logger.info('Saved the trained model at step %s.', current_stage)

    # NOTE: the in-process freeze/evaluate step is disabled in favour of the
    # console tool below. It was:
    #   freeze_graph(model_dir)
    #   err = test_model(model_dir, bits_per_number=bits_per_number)
    res = run_console_tool(['--checkpoint_dir', str(model_dir)])
    logger.info(res)
    # NOTE(review): this is logged regardless of ``res``; confirm whether
    # run_console_tool reports failure through its return value.
    logger.info('Froze the model at step %s.', current_stage)