From 51e81a82d12febe1741d623a198cc41ed1dba03c Mon Sep 17 00:00:00 2001 From: Nitish Gupta Date: Sat, 10 Sep 2016 11:53:59 -0500 Subject: [PATCH] Added func. for collecting variables with partial scope match. --- aiutils/tftools/var_collect.py | 37 ++++++++++++++++++++++++++++++++++ 1 file changed, 37 insertions(+) diff --git a/aiutils/tftools/var_collect.py b/aiutils/tftools/var_collect.py index aff112d..7a934f3 100644 --- a/aiutils/tftools/var_collect.py +++ b/aiutils/tftools/var_collect.py @@ -41,6 +41,43 @@ def collect_all(graph=None): return var_list +def get_all_scopes_in_var(var_name): + # Split at ':' to remove the variable number at the end + wo_var_num = var_name.split(":")[0] + + # Split at '/' to get all scope names and variable name + scopes_names = wo_var_num.split("/") + + return set(scopes_names) + + +def collect_partial_scope( + name_scope, + graph=None, + var_type=tf.GraphKeys.VARIABLES): + ''' Function for collecting variables that contain the given name_scope in + their scope hierarchy. + Eg. If var has var.name = 'scope1/scope2/scope3/var_name:4' + Then name_scope='scope2' will collect this variable (if it exists in + var_type collection) + name_scope cannot contain moer than 1 scope. Eg. It cannot be + 'scope2/scope3' + ''' + if graph == None: + graph = tf.get_default_graph() + var_list = graph.get_collection(var_type) + + scope_var_list = [] + for var in var_list: + var_scope_names = self.get_all_scopes_in_var(var.name) + if name_scope in var_scope_names: + scope_var_list.append(var) + + assert_str = "No variable exists with name_scope '{}'".format(name_scope) + assert len(scope_var_list) != 0, assert_str + + return scope_var_list + def collect_all_trainable(graph=None): if graph == None: