Skip to content

Commit

Permalink
Improved API logging, added MirroredStrategy and mixed_float16
Browse files Browse the repository at this point in the history
  • Loading branch information
TimKoornstra committed Oct 13, 2023
1 parent 10c2375 commit b040ad3
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 1 deletion.
15 changes: 15 additions & 0 deletions src/api/app_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,15 @@
from flask import request


class TensorFlowLogFilter(logging.Filter):
def filter(self, record):
# Exclude logs containing the specific message
exclude_phrases = [
"Reduce to /job:localhost/replica:0/task:0/device:CPU:"
]
return not any(phrase in record.msg for phrase in exclude_phrases)


def setup_logging(level: str = "INFO") -> logging.Logger:
"""
Set up logging with the specified level and return a logger instance.
Expand Down Expand Up @@ -39,6 +48,12 @@ def setup_logging(level: str = "INFO") -> logging.Logger:
level=logging_levels[level],
)

# Get TensorFlow's logger and remove its handlers to prevent duplicate logs
tf_logger = logging.getLogger('tensorflow')
tf_logger.addFilter(TensorFlowLogFilter())
while tf_logger.handlers:
tf_logger.handlers.pop()

return logging.getLogger(__name__)


Expand Down
30 changes: 29 additions & 1 deletion src/api/batch_predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
# > Third-party dependencies
import tensorflow as tf
from tensorflow.keras.utils import get_custom_objects
from tensorflow.keras import mixed_precision


def batch_prediction_worker(batch_size: int,
Expand Down Expand Up @@ -67,9 +68,32 @@ def batch_prediction_worker(batch_size: int,
physical_devices = tf.config.experimental.list_physical_devices('GPU')
logger.debug(f"Number of GPUs available: {len(physical_devices)}")
if physical_devices:
all_gpus_support_mixed_precision = True

for device in physical_devices:
tf.config.experimental.set_memory_growth(device, True)
logger.debug(device)

# Get the compute capability of the GPU
details = tf.config.experimental.get_device_details(device)
major = details.get('compute_capability')[0]

# Check if the compute capability is less than 7.0
if int(major) < 7:
all_gpus_support_mixed_precision = False
logger.debug(
f"GPU {device} does not support efficient mixed precision."
)
break

# If all GPUs support mixed precision, enable it
if all_gpus_support_mixed_precision:
mixed_precision.set_global_policy('mixed_float16')
logger.debug("Mixed precision set to 'mixed_float16'")
else:
logger.debug(
"Not all GPUs support efficient mixed precision. Running in "
"standard mode.")
else:
logger.warning("No GPUs available")

Expand All @@ -81,8 +105,12 @@ def batch_prediction_worker(batch_size: int,

from utils import decode_batch_predictions, normalize_confidence

strategy = tf.distribute.MirroredStrategy()

try:
model, utils = create_model(model_path, charlist_path, num_channels)
with strategy.scope():
model, utils = create_model(
model_path, charlist_path, num_channels)
logger.info("Model created and utilities initialized")
except Exception as e:
logger.error(e)
Expand Down

0 comments on commit b040ad3

Please sign in to comment.