diff --git a/nkululeko/models/finetune_model.py b/nkululeko/models/finetune_model.py
new file mode 100644
index 00000000..71d2d867
--- /dev/null
+++ b/nkululeko/models/finetune_model.py
@@ -0,0 +1,181 @@
+import dataclasses
+import typing
+
+import torch
+import transformers
+from transformers.models.wav2vec2.modeling_wav2vec2 import (
+    Wav2Vec2PreTrainedModel,
+    Wav2Vec2Model,
+)
+
+
+class ConcordanceCorCoeff(torch.nn.Module):
+
+    def __init__(self):
+
+        super().__init__()
+
+        self.mean = torch.mean
+        self.var = torch.var
+        self.sum = torch.sum
+        self.sqrt = torch.sqrt
+        self.std = torch.std
+
+    def forward(self, prediction, ground_truth):
+
+        mean_gt = self.mean(ground_truth, 0)
+        mean_pred = self.mean(prediction, 0)
+        var_gt = self.var(ground_truth, 0)
+        var_pred = self.var(prediction, 0)
+        v_pred = prediction - mean_pred
+        v_gt = ground_truth - mean_gt
+        cor = self.sum(v_pred * v_gt) / (
+            self.sqrt(self.sum(v_pred**2)) * self.sqrt(self.sum(v_gt**2))
+        )
+        sd_gt = self.std(ground_truth)
+        sd_pred = self.std(prediction)
+        numerator = 2 * cor * sd_gt * sd_pred
+        denominator = var_gt + var_pred + (mean_gt - mean_pred) ** 2
+        ccc = numerator / denominator
+
+        return 1 - ccc
+
+
+@dataclasses.dataclass
+class ModelOutput(transformers.file_utils.ModelOutput):
+
+    logits_cat: torch.FloatTensor = None
+    hidden_states: typing.Tuple[torch.FloatTensor] = None
+    cnn_features: torch.FloatTensor = None
+
+
+class ModelHead(torch.nn.Module):
+
+    def __init__(self, config, num_labels):
+
+        super().__init__()
+
+        self.dense = torch.nn.Linear(config.hidden_size, config.hidden_size)
+        self.dropout = torch.nn.Dropout(config.final_dropout)
+        self.out_proj = torch.nn.Linear(config.hidden_size, num_labels)
+
+    def forward(self, features, **kwargs):
+
+        x = features
+        x = self.dropout(x)
+        x = self.dense(x)
+        x = torch.tanh(x)
+        x = self.dropout(x)
+        x = self.out_proj(x)
+
+        return x
+
+
+class Model(Wav2Vec2PreTrainedModel):
+
+    def __init__(self, config):
+
+        super().__init__(config)
+
+        self.wav2vec2 = Wav2Vec2Model(config)
+        self.cat = ModelHead(config, 2)
+        self.init_weights()
+
+    def freeze_feature_extractor(self):
+        self.wav2vec2.feature_extractor._freeze_parameters()
+
+    def pooling(
+        self,
+        hidden_states,
+        attention_mask,
+    ):
+
+        if attention_mask is None:  # For evaluation with batch_size==1
+            outputs = torch.mean(hidden_states, dim=1)
+        else:
+            attention_mask = self._get_feature_vector_attention_mask(
+                hidden_states.shape[1],
+                attention_mask,
+            )
+            hidden_states = hidden_states * torch.reshape(
+                attention_mask,
+                (-1, attention_mask.shape[-1], 1),
+            )
+            outputs = torch.sum(hidden_states, dim=1)
+            attention_sum = torch.sum(attention_mask, dim=1)
+            outputs = outputs / torch.reshape(attention_sum, (-1, 1))
+
+        return outputs
+
+    def forward(
+        self,
+        input_values,
+        attention_mask=None,
+        labels=None,
+        return_hidden=False,
+    ):
+
+        outputs = self.wav2vec2(
+            input_values,
+            attention_mask=attention_mask,
+        )
+
+        cnn_features = outputs.extract_features
+        hidden_states_framewise = outputs.last_hidden_state
+        hidden_states = self.pooling(
+            hidden_states_framewise,
+            attention_mask,
+        )
+        logits_cat = self.cat(hidden_states)
+
+        if not self.training:
+            logits_cat = torch.softmax(logits_cat, dim=1)
+
+        if return_hidden:
+
+            # make time last axis
+            cnn_features = torch.transpose(cnn_features, 1, 2)
+
+            return ModelOutput(
+                logits_cat=logits_cat,
+                hidden_states=hidden_states,
+                cnn_features=cnn_features,
+            )
+
+        else:
+
+            return ModelOutput(
+                logits_cat=logits_cat,
+            )
+
+
+class ModelWithPreProcessing(Model):
+
+    def __init__(self, config):
+        super().__init__(config)
+
+    def forward(
+        self,
+        input_values,
+    ):
+        # Wav2Vec2FeatureExtractor.zero_mean_unit_var_norm():
+        # normed_slice = (vector - vector[:length].mean()) / np.sqrt(vector[:length].var() + 1e-7)
+
+        mean = input_values.mean()
+
+        # var = input_values.var()
+        # raises: onnxruntime.capi.onnxruntime_pybind11_state.NotImplemented: [ONNXRuntimeError] : 9 : NOT_IMPLEMENTED : Could not find an implementation for the node ReduceProd_3:ReduceProd(11)
+
+        var = torch.square(input_values - mean).mean()
+        input_values = (input_values - mean) / torch.sqrt(var + 1e-7)
+
+        output = super().forward(
+            input_values,
+            return_hidden=True,
+        )
+
+        return (
+            output.hidden_states,
+            output.logits_cat,
+            output.cnn_features,
+        )
diff --git a/nkululeko/test_pretrain.py b/nkululeko/test_pretrain.py
index 0b94840e..54242027 100644
--- a/nkululeko/test_pretrain.py
+++ b/nkululeko/test_pretrain.py
@@ -11,11 +11,14 @@
 
 import audeer
 import audiofile
+import audmetric
 
 from nkululeko.constants import VERSION
 import nkululeko.experiment as exp
+import nkululeko.models.finetune_model as fm
 import nkululeko.glob_conf as glob_conf
 from nkululeko.utils.util import Util
+import json
 
 
 def doit(config_file):
@@ -50,28 +53,42 @@
     expr.fill_train_and_tests()
     util.debug(f"train shape : {expr.df_train.shape}, test shape:{expr.df_test.shape}")
 
+    log_root = audeer.mkdir("log")
+    model_root = audeer.mkdir("model")
+    torch_root = audeer.path(model_root, "torch")
+
+    metrics_gender = {
+        "UAR": audmetric.unweighted_average_recall,
+        "ACC": audmetric.accuracy,
+    }
+
     sampling_rate = 16000
     max_duration_sec = 8.0
 
     model_path = "facebook/wav2vec2-large-robust-ft-swbd-300h"
     num_layers = None
 
+    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
+    os.environ["CUDA_VISIBLE_DEVICES"] = "3"
+
     batch_size = 16
     accumulation_steps = 4
-
     # create dataset
     dataset = {}
+    target_name = glob_conf.target
     data_sources = {
-        "train": pd.DataFrame(expr.df_train[glob_conf.target]),
-        "dev": pd.DataFrame(expr.df_test[glob_conf.target]),
+        "train": pd.DataFrame(expr.df_train[target_name]),
+        "dev": pd.DataFrame(expr.df_test[target_name]),
     }
 
     for split in ["train", "dev"]:
+        df = data_sources[split]
+        df[target_name] = df[target_name].astype("float")
         y = pd.Series(
-            data=data_sources[split].itertuples(index=False, name=None),
-            index=data_sources[split].index,
+            data=df.itertuples(index=False, name=None),
+            index=df.index,
             dtype=object,
             name="labels",
         )
@@ -80,23 +97,183 @@
         df = y.reset_index()
         df.start = df.start.dt.total_seconds()
         df.end = df.end.dt.total_seconds()
+        print(f"{split}: {len(df)}")
+
         ds = datasets.Dataset.from_pandas(df)
         dataset[split] = ds
 
-        dataset = datasets.DatasetDict(dataset)
+    dataset = datasets.DatasetDict(dataset)
+
+    # load pre-trained model
+    le = glob_conf.label_encoder
+    mapping = dict(zip(le.classes_, range(len(le.classes_))))
+    target_mapping = {k: int(v) for k, v in mapping.items()}
+    target_mapping_reverse = {value: key for key, value in target_mapping.items()}
 
     config = transformers.AutoConfig.from_pretrained(
         model_path,
-        num_labels=len(util.la),
-        label2id=data.gender_mapping,
-        id2label=data.gender_mapping_reverse,
-        finetuning_task="age-gender",
+        num_labels=len(target_mapping),
+        label2id=target_mapping,
+        id2label=target_mapping_reverse,
+        finetuning_task=target_name,
     )
     if num_layers is not None:
        config.num_hidden_layers = num_layers
     setattr(config, "sampling_rate", sampling_rate)
-    setattr(config, "data", ",".join(sources))
+    setattr(config, "data", util.get_data_name())
+
+    vocab_dict = {}
+    with open("vocab.json", "w") as vocab_file:
+        json.dump(vocab_dict, vocab_file)
+    tokenizer = transformers.Wav2Vec2CTCTokenizer("./vocab.json")
+    tokenizer.save_pretrained(".")
+
+    feature_extractor = transformers.Wav2Vec2FeatureExtractor(
+        feature_size=1,
+        sampling_rate=16000,
+        padding_value=0.0,
+        do_normalize=True,
+        return_attention_mask=True,
+    )
+    processor = transformers.Wav2Vec2Processor(
+        feature_extractor=feature_extractor,
+        tokenizer=tokenizer,
+    )
+    assert processor.feature_extractor.sampling_rate == sampling_rate
+
+    model = fm.Model.from_pretrained(
+        model_path,
+        config=config,
+    )
+    model.freeze_feature_extractor()
+    model.train()
+
+    # training
+
+    def data_collator(data):
+
+        files = [d["file"] for d in data]
+        starts = [d["start"] for d in data]
+        ends = [d["end"] for d in data]
+        targets = [d["targets"] for d in data]
+
+        signals = []
+        for file, start, end in zip(
+            files,
+            starts,
+            ends,
+        ):
+            offset = start
+            duration = end - offset
+            if max_duration_sec is not None:
+                duration = min(duration, max_duration_sec)
+            signal, _ = audiofile.read(
+                file,
+                offset=offset,
+                duration=duration,
+            )
+            signals.append(signal.squeeze())
+
+        input_values = processor(
+            signals,
+            sampling_rate=sampling_rate,
+            padding=True,
+        )
+        batch = processor.pad(
+            input_values,
+            padding=True,
+            return_tensors="pt",
+        )
+
+        batch["labels"] = torch.tensor(targets)
+
+        return batch
+
+    def compute_metrics(p: transformers.EvalPrediction):
+
+        truth_gender = p.label_ids[:, 0].astype(int)
+        preds = p.predictions
+        preds_gender = np.argmax(preds, axis=1)
+
+        scores = {}
+
+        for name, metric in metrics_gender.items():
+            scores[f"gender-{name}"] = metric(truth_gender, preds_gender)
+
+        scores["combined"] = scores["gender-UAR"]
+
+        return scores
+
+    targets = pd.DataFrame(dataset["train"]["targets"])
+    counts = targets[0].value_counts().sort_index()
+    train_weights = 1 / counts
+    train_weights /= train_weights.sum()
+
+    print(train_weights)
+
+    criterion_gender = torch.nn.CrossEntropyLoss(
+        weight=torch.Tensor(train_weights).to("cuda"),
+    )
+
+    class Trainer(transformers.Trainer):
+
+        def compute_loss(
+            self,
+            model,
+            inputs,
+            return_outputs=False,
+        ):
+
+            targets = inputs.pop("labels").squeeze()
+            targets_gender = targets.type(torch.long)
+
+            outputs = model(**inputs)
+            logits_gender = outputs[0].squeeze()
+
+            loss_gender = criterion_gender(logits_gender, targets_gender)
+
+            loss = loss_gender
+
+            return (loss, outputs) if return_outputs else loss
+
+    num_steps = len(dataset["train"]) // (batch_size * accumulation_steps) // 5
+    num_steps = max(1, num_steps)
+    print(num_steps)
+
+    training_args = transformers.TrainingArguments(
+        output_dir=model_root,
+        logging_dir=log_root,
+        per_device_train_batch_size=batch_size,
+        per_device_eval_batch_size=batch_size,
+        gradient_accumulation_steps=accumulation_steps,
+        evaluation_strategy="steps",
+        num_train_epochs=5.0,
+        fp16=True,
+        save_steps=num_steps,
+        eval_steps=num_steps,
+        logging_steps=num_steps,
+        learning_rate=1e-4,
+        save_total_limit=2,
+        metric_for_best_model="combined",
+        greater_is_better=True,
+        load_best_model_at_end=True,
+        remove_unused_columns=False,
+    )
+
+    trainer = Trainer(
+        model=model,
+        data_collator=data_collator,
+        args=training_args,
+        compute_metrics=compute_metrics,
+        train_dataset=dataset["train"],
+        eval_dataset=dataset["dev"],
+        tokenizer=processor.feature_extractor,
+        callbacks=[transformers.integrations.TensorBoardCallback()],
+    )
+
+    trainer.train()
+    trainer.save_model(torch_root)
 
     print("DONE")
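
For reference, a minimal inference sketch (not part of the patch above). It assumes the checkpoint written by trainer.save_model(torch_root) ends up in model/torch together with the feature-extractor config the Trainer saves alongside it, and uses a hypothetical 16 kHz mono file test.wav.

# Hypothetical follow-up: reload the fine-tuned checkpoint and run one prediction.
# "model/torch" and "test.wav" are illustrative assumptions, not part of the diff.
import audiofile
import torch
import transformers

import nkululeko.models.finetune_model as fm

model_dir = "model/torch"  # torch_root above, written by trainer.save_model()
feature_extractor = transformers.Wav2Vec2FeatureExtractor.from_pretrained(model_dir)
model = fm.Model.from_pretrained(model_dir)
model.eval()  # in eval mode Model.forward() applies softmax to the head output

signal, sr = audiofile.read("test.wav")  # assumed 16 kHz mono
inputs = feature_extractor(signal, sampling_rate=sr, return_tensors="pt")
with torch.no_grad():
    probs = model(inputs.input_values).logits_cat
pred = probs.argmax(dim=1).item()
print(model.config.id2label[pred])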