Skip to content

Commit

Permalink
Remove unnecessary logging, add recommended model to VGSL library
Browse files Browse the repository at this point in the history
  • Loading branch information
TimKoornstra committed Mar 15, 2024
1 parent 398c694 commit 85a12d3
Show file tree
Hide file tree
Showing 5 changed files with 25 additions and 7 deletions.
3 changes: 0 additions & 3 deletions src/data/data_handling.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,9 +118,6 @@ def load_initial_charlist(charlist_location: str, existing_model: str,
f"{charlist_location} and "
"replace_final_layer is False.")

logging.info("Using charlist: %s", charlist)
logging.info("Charlist length: %s", len(charlist))

return charlist, removed_padding


Expand Down
2 changes: 1 addition & 1 deletion src/data/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,7 +232,7 @@ def _create_data(self,
# Update the charlist if it has changed
if not self.charlist:
self.charlist = sorted(list(characters))
logging.info("Created charlist: %s", self.charlist)
logging.debug("Updated charlist: %s", self.charlist)

logging.info("Created data for %s with %s samples",
partition_name, len(partitions))
Expand Down
2 changes: 2 additions & 0 deletions src/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,8 @@ def main():

# Replace the charlist with the one from the data manager
charlist = data_manager.charlist
logging.info("Using charlist: %s", charlist)
logging.info("Charlist length: %s", len(charlist))

# Additional model customization such as freezing layers, replacing
# layers, or adjusting for float32
Expand Down
8 changes: 5 additions & 3 deletions src/model/vgsl_model_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -306,7 +306,11 @@ def get_model_libary() -> dict:
("None,64,None,4 Bn Ce3,3,16 RB3,3,16 RB3,3,16 RBd3,3,32 "
"RB3,3,32 RB3,3,32 RB3,3,32 RB3,3,32 RBd3,3,64 RB3,3,64 "
"RB3,3,64 RB3,3,64 RB3,3,64 RBd3,3,128 RB3,3,128 Rc "
"Bl256,D20 Bl256,D20 Bl256,D20 O1s92")
"Bl256,D20 Bl256,D20 Bl256,D20 O1s92"),
"recommended":
("None,64,None,1 Cr3,3,24 Mp2,2,2,2 Bn Cr3,3,48 Mp2,2,2,2 Bn "
"Cr3,3,96 Mp2,2,2,2 Bn Cr3,3,96 Mp2,2,2,2 Bn Rc Bl512 D50 "
"Bl512 D50 Bl512 D50 Bl512 D50 Bl512 D50 O1s92")
}

return model_library
Expand Down Expand Up @@ -638,8 +642,6 @@ def conv2d_generator(self,
# Check parameter length and generate corresponding Conv2D layer
if len(conv_filter_params) == 3:
x, y, d = conv_filter_params
logging.warning(
"No stride provided, setting default stride of (1,1)")
return layers.Conv2D(d,
kernel_size=(y, x),
strides=(1, 1),
Expand Down
17 changes: 17 additions & 0 deletions src/setup/environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,22 @@
from setup.config import Config


class TensorFlowLogFilter(logging.Filter):
"""Filter to exclude specific TensorFlow logging messages.
This filter checks each log record for specific phrases that are to be
excluded from the logs. If any of the specified phrases are found in a log
message, the message is excluded from the logs.
"""

def filter(self, record):
# Exclude logs containing the specific message
exclude_phrases = [
"Reduce to /job:localhost/replica:0/task:0/device:CPU:",
]
return not any(phrase in record.msg for phrase in exclude_phrases)


def set_deterministic(seed: int) -> None:
"""
Sets the environment and random seeds to ensure deterministic behavior in
Expand Down Expand Up @@ -121,6 +137,7 @@ def setup_logging() -> None:

# Remove the default Tensorflow logger handlers and use our own
tf_logger = tf.get_logger()
tf_logger.addFilter(TensorFlowLogFilter())
while tf_logger.handlers:
tf_logger.handlers.pop()

Expand Down

0 comments on commit 85a12d3

Please sign in to comment.