Make the base model a variable in the INI file
bagustris committed May 27, 2024
1 parent 02c0c7b commit 3c89eeb
Showing 2 changed files with 17 additions and 5 deletions.
ini_file.md: 2 additions & 0 deletions
@@ -303,6 +303,8 @@
   * device = cpu
 * **patience**: Number of epochs to wait for the result to improve (used for early stopping)
   * patience = 5
+* **model_ckpt**: Base model for finetuning/transfer learning. Variants of wav2vec2, HuBERT, and WavLM have been tested and work.
+  * model_ckpt = microsoft/wavlm-base

 ### EXPL
 * **model**: Which model to use to estimate feature importance.
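The new option slots into the existing [MODEL] section alongside device and patience. A minimal sketch of an experiment file using it (the type value and the chosen checkpoint are assumptions, not part of this commit; any of the tested wav2vec2, HuBERT, or WavLM checkpoints on the Hugging Face hub should fit):

    [MODEL]
    # assumed model type for finetuning with model_tuned.py
    type = finetune
    device = cpu
    patience = 5
    # base checkpoint to finetune from; per the code below it defaults to
    # facebook/wav2vec2-large-robust-ft-swbd-300h when unset
    model_ckpt = microsoft/wavlm-base

If model_ckpt is omitted, the code change below falls back to the previously hard-coded wav2vec2 checkpoint, so existing experiment files keep working unchanged.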
nkululeko/models/model_tuned.py: 15 additions & 5 deletions
@@ -21,6 +21,8 @@
     Wav2Vec2PreTrainedModel,
 )

+from transformers import AutoConfig, AutoModel
+
 import nkululeko.glob_conf as glob_conf
 from nkululeko.models.model import Model as BaseModel
 from nkululeko.reporting.reporter import Reporter
@@ -39,7 +41,7 @@ def __init__(self, df_train, df_test, feats_train, feats_test):
         labels = glob_conf.labels
         self.class_num = len(labels)
         # device = self.util.config_val("MODEL", "device", "cpu")
-        self.device = "cuda:0" if torch.cuda.is_available() else "cpu"
+        self.device = "cuda" if torch.cuda.is_available() else "cpu"
         self.batch_size = int(self.util.config_val("MODEL", "batch_size", "8"))
         if self.device != "cpu":
             self.util.debug(f"running on device {self.device}")
@@ -52,12 +54,16 @@ def __init__(self, df_train, df_test, feats_train, feats_test):

     def _init_model(self):
         model_path = "facebook/wav2vec2-large-robust-ft-swbd-300h"
+        model_ckpt = self.util.config_val("MODEL", "model_ckpt", model_path)
         self.num_layers = None
         self.sampling_rate = 16000
         self.max_duration_sec = 8.0
         self.accumulation_steps = 4
-        # create dataset
+
+        # print finetuning information via debug
+        self.util.debug(f"Finetuning from model: {model_ckpt}")
+
+        # create dataset
         dataset = {}
         target_name = glob_conf.target
         data_sources = {
@@ -86,12 +92,13 @@ def _init_model(self):
                                   value in target_mapping.items()}

         self.config = transformers.AutoConfig.from_pretrained(
-            model_path,
+            model_ckpt,
             num_labels=len(target_mapping),
             label2id=target_mapping,
             id2label=target_mapping_reverse,
             finetuning_task=target_name,
         )
+
         if self.num_layers is not None:
             self.config.num_hidden_layers = self.num_layers
         setattr(self.config, "sampling_rate", self.sampling_rate)
@@ -117,7 +124,7 @@ def _init_model(self):
         assert self.processor.feature_extractor.sampling_rate == self.sampling_rate

         self.model = Model.from_pretrained(
-            model_path,
+            model_ckpt,
             config=self.config,
         )
         self.model.freeze_feature_extractor()
@@ -370,8 +377,11 @@ class Model(Wav2Vec2PreTrainedModel):

     def __init__(self, config):

-        super().__init__(config)
+        if not hasattr(config, 'add_adapter'):
+            setattr(config, 'add_adapter', False)
+
+        super().__init__(config)

         self.wav2vec2 = Wav2Vec2Model(config)
         self.cat = ModelHead(config)
         self.init_weights()

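To make the flow above easier to follow outside the diff, here is a minimal, self-contained sketch of the same pattern: read the checkpoint name from an INI file with a fallback, build a classification config from it, and guard add_adapter before the Wav2Vec2 base class reads it. Only the transformers calls mirror the diff; the Classifier class, exp.ini, and the label count are illustrative assumptions, not nkululeko code:

    import configparser

    import transformers
    from transformers import Wav2Vec2Model, Wav2Vec2PreTrainedModel


    class Classifier(Wav2Vec2PreTrainedModel):
        """Hypothetical stand-in for the Model class in model_tuned.py."""

        def __init__(self, config):
            # Configs of some base models (HuBERT in particular) do not define
            # add_adapter, which Wav2Vec2Model's constructor reads
            # unconditionally, so default it to False before calling super().
            if not hasattr(config, "add_adapter"):
                setattr(config, "add_adapter", False)
            super().__init__(config)
            self.wav2vec2 = Wav2Vec2Model(config)
            self.init_weights()


    ini = configparser.ConfigParser()
    ini.read("exp.ini")  # assumed to contain a [MODEL] section as sketched above
    model_ckpt = ini.get(
        "MODEL",
        "model_ckpt",
        fallback="facebook/wav2vec2-large-robust-ft-swbd-300h",
    )

    config = transformers.AutoConfig.from_pretrained(model_ckpt, num_labels=4)
    model = Classifier.from_pretrained(model_ckpt, config=config)

With this arrangement, swapping the base model only changes a string in the INI file; the classification head and the rest of the training loop stay untouched.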