From 1247162ab92f0e1118871fab781bd2bd9b7966e9 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Moritz=20R=2E=20Sch=C3=A4fer?=
Date: Sat, 11 Nov 2023 11:01:06 +0100
Subject: [PATCH] Add a higher-level parameter transfer function to clean up
 run.py

---
 apax/train/run.py                                | 11 +++--------
 apax/transfer_learning/__init__.py               |  7 +++++--
 apax/transfer_learning/parameter_transfer.py     | 14 +++++++++++++-
 .../transfer_learning/test_parameter_transfer.py |  4 ++--
 4 files changed, 23 insertions(+), 13 deletions(-)

diff --git a/apax/train/run.py b/apax/train/run.py
index 92922e43..fdf2b174 100644
--- a/apax/train/run.py
+++ b/apax/train/run.py
@@ -10,11 +10,11 @@
 from apax.model import ModelBuilder
 from apax.optimizer import get_opt
 from apax.train.callbacks import initialize_callbacks
-from apax.train.checkpoints import create_params, create_train_state, load_params
+from apax.train.checkpoints import create_params, create_train_state
 from apax.train.loss import Loss, LossCollection
 from apax.train.metrics import initialize_metrics
 from apax.train.trainer import fit
-from apax.transfer_learning import param_transfer
+from apax.transfer_learning import transfer_parameters
 from apax.utils.random import seed_py_np_tf
 
 log = logging.getLogger(__name__)
@@ -98,12 +98,7 @@ def run(user_config, log_file="train.log", log_level="error"):
     base_checkpoint = config.checkpoints.base_model_checkpoint
     do_transfer_learning = base_checkpoint is not None
     if do_transfer_learning:
-        source_params = load_params(base_checkpoint)
-        log.info("Transferring parameters from %s", base_checkpoint)
-        params = param_transfer(
-            source_params, state.params, config.checkpoints.reset_layers
-        )
-        state.replace(params=params)
+        state = transfer_parameters(state, config.checkpoints)
 
     fit(
         state,
diff --git a/apax/transfer_learning/__init__.py b/apax/transfer_learning/__init__.py
index c31db7a2..c42ecf7e 100644
--- a/apax/transfer_learning/__init__.py
+++ b/apax/transfer_learning/__init__.py
@@ -1,3 +1,6 @@
-from apax.transfer_learning.parameter_transfer import param_transfer
+from apax.transfer_learning.parameter_transfer import (
+    black_list_param_transfer,
+    transfer_parameters,
+)
 
-__all__ = ["param_transfer"]
+__all__ = ["transfer_parameters", "black_list_param_transfer"]
diff --git a/apax/transfer_learning/parameter_transfer.py b/apax/transfer_learning/parameter_transfer.py
index 64cab560..ca41d7b1 100644
--- a/apax/transfer_learning/parameter_transfer.py
+++ b/apax/transfer_learning/parameter_transfer.py
@@ -3,10 +3,12 @@
 from flax.core.frozen_dict import freeze, unfreeze
 from flax.traverse_util import flatten_dict, unflatten_dict
 
+from apax.train.checkpoints import load_params
+
 log = logging.getLogger(__name__)
 
 
-def param_transfer(source_params, target_params, param_black_list):
+def black_list_param_transfer(source_params, target_params, param_black_list):
     source_params = unfreeze(source_params)
     target_params = unfreeze(target_params)
 
@@ -20,3 +22,13 @@
     transfered_target = unflatten_dict(flat_target)
     transfered_target = freeze(transfered_target)
     return transfered_target
+
+
+def transfer_parameters(state, ckpt_config):
+    source_params = load_params(ckpt_config.base_model_checkpoint)
+    log.info("Transferring parameters from %s", ckpt_config.base_model_checkpoint)
+    params = black_list_param_transfer(
+        source_params, state.params, ckpt_config.reset_layers
+    )
+    state = state.replace(params=params)
+    return state
diff --git a/tests/unit_tests/transfer_learning/test_parameter_transfer.py b/tests/unit_tests/transfer_learning/test_parameter_transfer.py
index ad810776..3f24a979 100644
--- a/tests/unit_tests/transfer_learning/test_parameter_transfer.py
+++ b/tests/unit_tests/transfer_learning/test_parameter_transfer.py
@@ -1,4 +1,4 @@
-from apax.transfer_learning import param_transfer
+from apax.transfer_learning import black_list_param_transfer
 
 
 def test_param_transfer():
@@ -15,7 +15,7 @@ def test_param_transfer():
         }
     }
     reinitialize_layers = ["basis"]
-    transfered_target = param_transfer(source, target, reinitialize_layers)
+    transfered_target = black_list_param_transfer(source, target, reinitialize_layers)
 
     assert transfered_target["params"]["dense"]["w"] == source["params"]["dense"]["w"]
     assert transfered_target["params"]["dense"]["b"] == source["params"]["dense"]["b"]
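
A quick sketch of how the renamed helper behaves. The parameter trees below
are toy stand-ins and the "dense"/"basis" layer names simply mirror the unit
test above, so treat this as an illustration under those assumptions rather
than as part of the patch:

    from apax.transfer_learning import black_list_param_transfer

    # Pretrained (source) and freshly initialized (target) parameter trees.
    source = {"params": {"dense": {"w": 1.0, "b": 2.0}, "basis": {"w": 3.0}}}
    target = {"params": {"dense": {"w": 0.0, "b": 0.0}, "basis": {"w": 0.0}}}

    # Leaves are copied from the source, except for blacklisted layers,
    # which keep the target's fresh initialization.
    merged = black_list_param_transfer(source, target, ["basis"])

    assert merged["params"]["dense"]["w"] == 1.0  # transferred from source
    assert merged["params"]["basis"]["w"] == 0.0  # reset: "basis" is blacklisted

Bundling load_params, the blacklist merge, and the (non-mutating)
state.replace into transfer_parameters reduces run.py to a single call,
which is the point of this patch.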