Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 37 additions & 0 deletions aiutils/tftools/var_collect.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,43 @@ def collect_all(graph=None):

return var_list

def get_all_scopes_in_var(var_name):
# Split at ':' to remove the variable number at the end
wo_var_num = var_name.split(":")[0]

# Split at '/' to get all scope names and variable name
scopes_names = wo_var_num.split("/")

return set(scopes_names)


def collect_partial_scope(
name_scope,
graph=None,
var_type=tf.GraphKeys.VARIABLES):
''' Function for collecting variables that contain the given name_scope in
their scope hierarchy.
Eg. If var has var.name = 'scope1/scope2/scope3/var_name:4'
Then name_scope='scope2' will collect this variable (if it exists in
var_type collection)
name_scope cannot contain moer than 1 scope. Eg. It cannot be
'scope2/scope3'
'''
if graph == None:
graph = tf.get_default_graph()
var_list = graph.get_collection(var_type)

scope_var_list = []
for var in var_list:
var_scope_names = self.get_all_scopes_in_var(var.name)
if name_scope in var_scope_names:
scope_var_list.append(var)

assert_str = "No variable exists with name_scope '{}'".format(name_scope)
assert len(scope_var_list) != 0, assert_str

return scope_var_list


def collect_all_trainable(graph=None):
if graph == None:
Expand Down