Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions gematria/granite/python/gnn_model_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -444,6 +444,15 @@ def _execute_graph_network(self, feed_dict) -> graph_nets.graphs.GraphsTuple:
)
return graphs_tuple

def _add_histogram_summaries(self) -> None:
    """Writes histogram summaries for the graph network's weights.

    Emits one `tf.summary.histogram` per trainable variable found in the
    modules of `self._graph_network`. Colons in variable names (e.g. the
    `:0` suffix TensorFlow appends) are replaced with underscores so the
    names are valid summary tags.
    """
    for graph_layer in self._graph_network:
        for variable in graph_layer.module.trainable_variables:
            summary_tag = variable.name.replace(':', '_')
            tf.summary.histogram(summary_tag, variable)

@abc.abstractmethod
def _execute_readout_network(
self, graph_tuple, feed_dict: model_base.FeedDict
Expand Down
6 changes: 6 additions & 0 deletions gematria/model/python/main_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -290,6 +290,11 @@ def main(_):
'',
'The directory to which the summaries from the training are stored.',
)
# Command-line switch for histogram summaries; the value is forwarded to the
# model constructor as `log_histogram_summaries`. Defaults to False (no
# histogram summaries are written).
_GEMATRIA_LOG_HISTOGRAM_SUMMARIES = flags.DEFINE_bool(
    'gematria_log_histogram_summaries',
    False,
    'Whether or not the model should write histogram summaries.',
)
_GEMATRIA_SAVE_CHECKPOINT_EPOCHS = flags.DEFINE_integer(
'gematria_save_checkpoint_epochs',
100,
Expand Down Expand Up @@ -768,6 +773,7 @@ def run_gematria_model_from_command_line_flags(
num_training_worker_replicas=num_replicas,
num_training_worker_replicas_to_aggregate=num_replicas_to_aggregate,
is_chief=is_chief,
log_histogram_summaries=_GEMATRIA_LOG_HISTOGRAM_SUMMARIES.value,
**model_kwargs,
)
model.initialize()
Expand Down
18 changes: 18 additions & 0 deletions gematria/model/python/model_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,6 +240,7 @@ def __init__(
model_name: Optional[str] = None,
task_list: Optional[Sequence[str]] = None,
trained_variable_groups: Optional[Iterable[str]] = None,
log_histogram_summaries: bool = False,
) -> None:
"""Creates a new model with the provided parameters.

Expand Down Expand Up @@ -298,6 +299,8 @@ def __init__(
trained_variable_groups: The list of variable group names that are trained
by the optimizer. Only variables in this list are touched; if the list
is empty or None, all variables are trained.
log_histogram_summaries: Whether or not the model should write histogram
summaries for debugging purposes.
"""
self._dtype = dtype
self._numpy_dtype = dtype.as_numpy_dtype
Expand All @@ -320,6 +323,7 @@ def __init__(
self._grad_clip_norm = grad_clip_norm
task_list = task_list or ('default',)
self._task_list: Sequence[str] = task_list
self._log_histogram_summaries = log_histogram_summaries

self._model_name = model_name

Expand Down Expand Up @@ -544,6 +548,17 @@ def _add_error_summaries(self, error_name: str, error_tensor: tf.Tensor):
summary_name = f'{error_name}_{task_name}'
tf.summary.scalar(summary_name, error_tensor[task_idx])

def _add_histogram_summaries(self) -> None:
"""Adds histogram summaries for tensors.

Adds code for logging histogram summaries for model-specific tensors.

By default, this method is a no-op.
"""
# NOTE(vbshah): This method is not marked as abstract as it need not be
# implemented by subclasses, in which case this default (no-op)
# implementation should be invoked.

def _make_spearman_correlations(
self, expected_outputs: tf.Tensor, output_tensor: tf.Tensor
) -> tf.Tensor:
Expand Down Expand Up @@ -1393,6 +1408,9 @@ def train_batch(
)
tf.summary.scalar('overall_loss', loss_tensor)

if self._log_histogram_summaries:
self._add_histogram_summaries()

# TODO(vbshah): Consider writing delta loss summaries as well.
self._add_error_summaries('absolute_mse', loss.mean_squared_error)
self._add_error_summaries(
Expand Down