Commit
added a higher-level param transfer function to clean up run.py
M-R-Schaefer committed Nov 11, 2023
1 parent dea322f commit 1247162
Showing 4 changed files with 23 additions and 13 deletions.
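
In short, the commit moves checkpoint loading and parameter copying out of run.py and behind a single helper, so the training entry point is reduced to the call shown in the run.py diff below:

    if do_transfer_learning:
        state = transfer_parameters(state, config.checkpoints)
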
11 changes: 3 additions & 8 deletions apax/train/run.py
@@ -10,11 +10,11 @@
 from apax.model import ModelBuilder
 from apax.optimizer import get_opt
 from apax.train.callbacks import initialize_callbacks
-from apax.train.checkpoints import create_params, create_train_state, load_params
+from apax.train.checkpoints import create_params, create_train_state
 from apax.train.loss import Loss, LossCollection
 from apax.train.metrics import initialize_metrics
 from apax.train.trainer import fit
-from apax.transfer_learning import param_transfer
+from apax.transfer_learning import transfer_parameters
 from apax.utils.random import seed_py_np_tf
 
 log = logging.getLogger(__name__)
@@ -98,12 +98,7 @@ def run(user_config, log_file="train.log", log_level="error"):
     base_checkpoint = config.checkpoints.base_model_checkpoint
     do_transfer_learning = base_checkpoint is not None
     if do_transfer_learning:
-        source_params = load_params(base_checkpoint)
-        log.info("Transferring parameters from %s", base_checkpoint)
-        params = param_transfer(
-            source_params, state.params, config.checkpoints.reset_layers
-        )
-        state.replace(params=params)
+        state = transfer_parameters(state, config.checkpoints)
 
     fit(
         state,
7 changes: 5 additions & 2 deletions apax/transfer_learning/__init__.py
@@ -1,3 +1,6 @@
-from apax.transfer_learning.parameter_transfer import param_transfer
+from apax.transfer_learning.parameter_transfer import (
+    black_list_param_transfer,
+    transfer_parameters,
+)
 
-__all__ = ["param_transfer"]
+__all__ = ["transfer_parameters", "black_list_param_transfer"]
14 changes: 13 additions & 1 deletion apax/transfer_learning/parameter_transfer.py
@@ -3,10 +3,12 @@
 from flax.core.frozen_dict import freeze, unfreeze
 from flax.traverse_util import flatten_dict, unflatten_dict
 
+from apax.train.checkpoints import load_params
+
 log = logging.getLogger(__name__)
 
 
-def param_transfer(source_params, target_params, param_black_list):
+def black_list_param_transfer(source_params, target_params, param_black_list):
     source_params = unfreeze(source_params)
     target_params = unfreeze(target_params)
 
@@ -20,3 +22,13 @@ def param_transfer(source_params, target_params, param_black_list):
     transfered_target = unflatten_dict(flat_target)
     transfered_target = freeze(transfered_target)
     return transfered_target
+
+
+def transfer_parameters(state, ckpt_config):
+    source_params = load_params(ckpt_config.base_checkpoint)
+    log.info("Transferring parameters from %s", ckpt_config.base_checkpoint)
+    params = black_list_param_transfer(
+        source_params, state.params, ckpt_config.reset_layers
+    )
+    state.replace(params=params)
+    return state
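
The middle of black_list_param_transfer is collapsed between the two hunks above. Below is a minimal, self-contained sketch of the technique, assuming the hidden part copies every source entry whose flattened path contains no black-listed layer name; the name sketch_black_list_transfer and the loop itself are illustrative, not the repository's code:

    from flax.core.frozen_dict import freeze, unfreeze
    from flax.traverse_util import flatten_dict, unflatten_dict

    def sketch_black_list_transfer(source_params, target_params, param_black_list):
        # Flatten both parameter pytrees into {path-tuple: array} mappings.
        flat_source = flatten_dict(unfreeze(source_params))
        flat_target = flatten_dict(unfreeze(target_params))
        for path, value in flat_source.items():
            # Keep the target's freshly initialized values for black-listed
            # layers; overwrite everything else with the source checkpoint.
            if any(layer in path for layer in param_black_list):
                continue
            flat_target[path] = value
        return freeze(unflatten_dict(flat_target))
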
4 changes: 2 additions & 2 deletions tests/unit_tests/transfer_learning/test_parameter_transfer.py
@@ -1,4 +1,4 @@
-from apax.transfer_learning import param_transfer
+from apax.transfer_learning import black_list_param_transfer
 
 
 def test_param_transfer():
@@ -15,7 +15,7 @@ def test_param_transfer():
         }
     }
     reinitialize_layers = ["basis"]
-    transfered_target = param_transfer(source, target, reinitialize_layers)
+    transfered_target = black_list_param_transfer(source, target, reinitialize_layers)
 
     assert transfered_target["params"]["dense"]["w"] == source["params"]["dense"]["w"]
     assert transfered_target["params"]["dense"]["b"] == source["params"]["dense"]["b"]
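
For reference, a self-contained usage example in the spirit of the updated test; the dictionaries are illustrative stand-ins for the collapsed fixtures, and the expectation that black-listed layers keep the target's values is inferred from the reset_layers semantics:

    from apax.transfer_learning import black_list_param_transfer

    source = {"params": {"dense": {"w": 1.0, "b": 2.0}, "basis": {"w": 3.0}}}
    target = {"params": {"dense": {"w": 0.0, "b": 0.0}, "basis": {"w": 0.0}}}

    transfered = black_list_param_transfer(source, target, ["basis"])
    assert transfered["params"]["dense"]["w"] == source["params"]["dense"]["w"]
    assert transfered["params"]["basis"]["w"] == target["params"]["basis"]["w"]
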
