From 79789a9b44d0520818a4fe5bafad6d84d92f206a Mon Sep 17 00:00:00 2001 From: Viraj Shah Date: Sat, 29 Mar 2025 01:45:36 +0530 Subject: [PATCH 1/3] Add histogram summaries to `GnnModelBase`. --- gematria/granite/python/gnn_model_base.py | 9 +++++++++ gematria/model/python/model_base.py | 11 +++++++++++ 2 files changed, 20 insertions(+) diff --git a/gematria/granite/python/gnn_model_base.py b/gematria/granite/python/gnn_model_base.py index 92439aec..e6cad6a3 100644 --- a/gematria/granite/python/gnn_model_base.py +++ b/gematria/granite/python/gnn_model_base.py @@ -444,6 +444,15 @@ def _execute_graph_network(self, feed_dict) -> graph_nets.graphs.GraphsTuple: ) return graphs_tuple + def _add_histogram_summaries(self) -> None: + """Adds histogram summaries for tensors. + + Logs histograms for all trainable variables within graph layers. + """ + for layer in self._graph_network: + for var in layer.module.trainable_variables: + tf.summary.histogram(var.name.replace(':', '_'), var) + @abc.abstractmethod def _execute_readout_network( self, graph_tuple, feed_dict: model_base.FeedDict diff --git a/gematria/model/python/model_base.py b/gematria/model/python/model_base.py index 76f1d1ca..0f1646ad 100644 --- a/gematria/model/python/model_base.py +++ b/gematria/model/python/model_base.py @@ -381,6 +381,8 @@ def __init__( def initialize(self) -> None: """Initializes the model. Must be called before any other method.""" self._create_optimizer() + self._add_histogram_summaries() + tf.summary.scalar('learning_rate', self._decayed_learning_rate) @property def use_deltas(self) -> bool: @@ -544,6 +546,15 @@ def _add_error_summaries(self, error_name: str, error_tensor: tf.Tensor): summary_name = f'{error_name}_{task_name}' tf.summary.scalar(summary_name, error_tensor[task_idx]) + @abc.abstractmethod + def _add_histogram_summaries(self) -> None: + """Adds histogram summaries for tensors. + + Adds code for logging histogram summaries for model-specific tensors. + + By default, this method is a no-op. + """ + def _make_spearman_correlations( self, expected_outputs: tf.Tensor, output_tensor: tf.Tensor ) -> tf.Tensor: From ea26017fa683055db8e25fd59a6e9c5b58969716 Mon Sep 17 00:00:00 2001 From: Viraj Shah Date: Mon, 14 Apr 2025 18:25:46 +0530 Subject: [PATCH 2/3] Remove `abstractmethod` decoration from `_add_histogram_summaries`. --- gematria/model/python/model_base.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/gematria/model/python/model_base.py b/gematria/model/python/model_base.py index 0f1646ad..c698b4b8 100644 --- a/gematria/model/python/model_base.py +++ b/gematria/model/python/model_base.py @@ -546,7 +546,6 @@ def _add_error_summaries(self, error_name: str, error_tensor: tf.Tensor): summary_name = f'{error_name}_{task_name}' tf.summary.scalar(summary_name, error_tensor[task_idx]) - @abc.abstractmethod def _add_histogram_summaries(self) -> None: """Adds histogram summaries for tensors. @@ -554,6 +553,9 @@ def _add_histogram_summaries(self) -> None: By default, this method is a no-op. """ + # NOTE(vbshah): This method is not marked as abstract as it need not be + # implemented by subclasses, in which case this default (no-op) + # implementation should be invoked. def _make_spearman_correlations( self, expected_outputs: tf.Tensor, output_tensor: tf.Tensor From 572fa738b1761aa3b885f02e01f614ee56568a57 Mon Sep 17 00:00:00 2001 From: Viraj Shah Date: Sun, 15 Jun 2025 23:44:37 +0530 Subject: [PATCH 3/3] Put histogram summaries behind flag. --- gematria/model/python/main_function.py | 6 ++++++ gematria/model/python/model_base.py | 9 +++++++-- 2 files changed, 13 insertions(+), 2 deletions(-) diff --git a/gematria/model/python/main_function.py b/gematria/model/python/main_function.py index c5c1e580..a9af06e0 100644 --- a/gematria/model/python/main_function.py +++ b/gematria/model/python/main_function.py @@ -290,6 +290,11 @@ def main(_): '', 'The directory to which the summaries from the training are stored.', ) +_GEMATRIA_LOG_HISTOGRAM_SUMMARIES = flags.DEFINE_bool( + 'gematria_log_histogram_summaries', + False, + 'Whether or not the model should write histogram summaries.', +) _GEMATRIA_SAVE_CHECKPOINT_EPOCHS = flags.DEFINE_integer( 'gematria_save_checkpoint_epochs', 100, @@ -768,6 +773,7 @@ def run_gematria_model_from_command_line_flags( num_training_worker_replicas=num_replicas, num_training_worker_replicas_to_aggregate=num_replicas_to_aggregate, is_chief=is_chief, + log_histogram_summaries=_GEMATRIA_LOG_HISTOGRAM_SUMMARIES.value, **model_kwargs, ) model.initialize() diff --git a/gematria/model/python/model_base.py b/gematria/model/python/model_base.py index c698b4b8..35ae8348 100644 --- a/gematria/model/python/model_base.py +++ b/gematria/model/python/model_base.py @@ -240,6 +240,7 @@ def __init__( model_name: Optional[str] = None, task_list: Optional[Sequence[str]] = None, trained_variable_groups: Optional[Iterable[str]] = None, + log_histogram_summaries: bool = False, ) -> None: """Creates a new model with the provided parameters. @@ -298,6 +299,8 @@ def __init__( trained_variable_groups: The list of variable group names that are trained by the optimizer. Only variables in this list are touched; if the list is empty or None, all variables are trained. + log_histogram_summaries: Whether or not the model should write histogram + summaries for debugging purposes. """ self._dtype = dtype self._numpy_dtype = dtype.as_numpy_dtype @@ -320,6 +323,7 @@ def __init__( self._grad_clip_norm = grad_clip_norm task_list = task_list or ('default',) self._task_list: Sequence[str] = task_list + self._log_histogram_summaries = log_histogram_summaries self._model_name = model_name @@ -381,8 +385,6 @@ def __init__( def initialize(self) -> None: """Initializes the model. Must be called before any other method.""" self._create_optimizer() - self._add_histogram_summaries() - tf.summary.scalar('learning_rate', self._decayed_learning_rate) @property def use_deltas(self) -> bool: @@ -1406,6 +1408,9 @@ def train_batch( ) tf.summary.scalar('overall_loss', loss_tensor) + if self._log_histogram_summaries: + self._add_histogram_summaries() + # TODO(vbshah): Consider writing delta loss summaries as well. self._add_error_summaries('absolute_mse', loss.mean_squared_error) self._add_error_summaries(