diff --git a/ini_file.md b/ini_file.md index 3722677..7cafbee 100644 --- a/ini_file.md +++ b/ini_file.md @@ -303,6 +303,8 @@ * device = cpu * **patience**: Number of epochs to wait if the result gets better (for early stopping) * patience = 5 +* **model_ckpt**: Base model for finetuning/transfer learning (default: facebook/wav2vec2-large-robust-ft-swbd-300h). Variants of wav2vec2, HuBERT, and WavLM have been tested and work. + * model_ckpt = microsoft/wavlm-base ### EXPL * **model**: Which model to use to estimate feature importance. diff --git a/nkululeko/models/model_tuned.py b/nkululeko/models/model_tuned.py index 1dd81ae..d477496 100644 --- a/nkululeko/models/model_tuned.py +++ b/nkululeko/models/model_tuned.py @@ -21,6 +21,8 @@ Wav2Vec2PreTrainedModel, ) +from transformers import AutoConfig, AutoModel + import nkululeko.glob_conf as glob_conf from nkululeko.models.model import Model as BaseModel from nkululeko.reporting.reporter import Reporter @@ -39,7 +41,7 @@ def __init__(self, df_train, df_test, feats_train, feats_test): labels = glob_conf.labels self.class_num = len(labels) # device = self.util.config_val("MODEL", "device", "cpu") - self.device = "cuda:0" if torch.cuda.is_available() else "cpu" + self.device = "cuda" if torch.cuda.is_available() else "cpu" self.batch_size = int(self.util.config_val("MODEL", "batch_size", "8")) if self.device != "cpu": self.util.debug(f"running on device {self.device}") @@ -52,12 +54,16 @@ def __init__(self, df_train, df_test, feats_train, feats_test): def _init_model(self): model_path = "facebook/wav2vec2-large-robust-ft-swbd-300h" + model_ckpt = self.util.config_val("MODEL", "model_ckpt", model_path) self.num_layers = None self.sampling_rate = 16000 self.max_duration_sec = 8.0 self.accumulation_steps = 4 - # create dataset + + # print finetuning information via debug + self.util.debug(f"Finetuning from model: {model_ckpt}") + # create dataset dataset = {} target_name = glob_conf.target data_sources = { @@ -86,12 +92,13 @@ def _init_model(self): value in target_mapping.items()} 
self.config = transformers.AutoConfig.from_pretrained( - model_path, + model_ckpt, num_labels=len(target_mapping), label2id=target_mapping, id2label=target_mapping_reverse, finetuning_task=target_name, ) + if self.num_layers is not None: self.config.num_hidden_layers = self.num_layers setattr(self.config, "sampling_rate", self.sampling_rate) @@ -117,7 +124,7 @@ def _init_model(self): assert self.processor.feature_extractor.sampling_rate == self.sampling_rate self.model = Model.from_pretrained( - model_path, + model_ckpt, config=self.config, ) self.model.freeze_feature_extractor() @@ -370,8 +377,11 @@ class Model(Wav2Vec2PreTrainedModel): def __init__(self, config): - super().__init__(config) + if not hasattr(config, 'add_adapter'): + setattr(config, 'add_adapter', False) + super().__init__(config) + self.wav2vec2 = Wav2Vec2Model(config) self.cat = ModelHead(config) self.init_weights()