From 666d036eabe006e358d24660de0a46b8d923afef Mon Sep 17 00:00:00 2001 From: benleetownsend Date: Tue, 14 Jan 2025 22:06:07 +0000 Subject: [PATCH] ADD: Initial run at modern bert - feats within 1e-7 --- finetune/base_models/modern_bert/__init__.py | 0 finetune/base_models/modern_bert/encoding.py | 62 ++++ finetune/base_models/modern_bert/model.py | 62 ++++ finetune/base_models/modern_bert/modelling.py | 330 ++++++++++++++++++ .../modern_bert/modernbert_to_finetune.py | 37 ++ .../modern_bert/verify_features.py | 23 ++ finetune/datasets/reuters.py | 11 +- finetune/saver.py | 1 - 8 files changed, 518 insertions(+), 8 deletions(-) create mode 100644 finetune/base_models/modern_bert/__init__.py create mode 100644 finetune/base_models/modern_bert/encoding.py create mode 100644 finetune/base_models/modern_bert/model.py create mode 100644 finetune/base_models/modern_bert/modelling.py create mode 100644 finetune/base_models/modern_bert/modernbert_to_finetune.py create mode 100644 finetune/base_models/modern_bert/verify_features.py diff --git a/finetune/base_models/modern_bert/__init__.py b/finetune/base_models/modern_bert/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/finetune/base_models/modern_bert/encoding.py b/finetune/base_models/modern_bert/encoding.py new file mode 100644 index 000000000..8152b8847 --- /dev/null +++ b/finetune/base_models/modern_bert/encoding.py @@ -0,0 +1,62 @@ +import logging +import os +import finetune +from tokenizers import Tokenizer +from finetune.encoding.input_encoder import BaseEncoder +from finetune.encoding.input_encoder import EncodedOutput + +FINETUNE_FOLDER = os.path.dirname(finetune.__file__) +TOKENIZER_PATH = os.path.join(FINETUNE_FOLDER, "model", "modern_bert", "tokenizer.json") + +LOGGER = logging.getLogger("finetune") + +class ModernBertEncoder(BaseEncoder): + def __init__(self): + self.tokenizer = Tokenizer.from_file(TOKENIZER_PATH) + special_tokens_map = {tok.content: k for k, tok in self.tokenizer.get_added_tokens_decoder().items()} + self.start_token = special_tokens_map["[CLS]"] + self.delimiter_token = special_tokens_map["[SEP]"] + self.mask_token = special_tokens_map["[MASK]"] + self.end_token = special_tokens_map["[SEP]"] + self.initialized = True + self.UNK_IDX = None + + @property + def vocab_size(self): + return self.tokenizer.get_vocab_size() + + def _encode(self, texts): + batch_tokens = [] + batch_token_idxs = [] + batch_char_ends = [] + batch_char_starts = [] + for text in texts: + encoded = self.tokenizer.encode(text, add_special_tokens=False) + batch_tokens.append(encoded.tokens) + batch_token_idxs.append(encoded.ids) + token_ends = [] + token_starts = [] + for start, end in encoded.offsets: + if token_ends: + # Finetune requires that tokens never overlap. + # This happens in huggingface tokenizers when + # a single character is split across multiple tokens. 
+ start = max(token_ends[-1], start) + end = max(end, start) + token_starts.append(start) + token_ends.append(end) + + batch_char_ends.append(token_ends) + batch_char_starts.append(token_starts) + + output = EncodedOutput( + token_ids=batch_token_idxs, + tokens=batch_tokens, + token_ends=batch_char_ends, + token_starts=batch_char_starts, + ) + return output + + def decode(self, ids): + output = self.tokenizer.decode(ids, skip_special_tokens=True) + return output \ No newline at end of file diff --git a/finetune/base_models/modern_bert/model.py b/finetune/base_models/modern_bert/model.py new file mode 100644 index 000000000..6503ea355 --- /dev/null +++ b/finetune/base_models/modern_bert/model.py @@ -0,0 +1,62 @@ +import os +import tensorflow as tf +from finetune.base_models.modern_bert.modelling import ModernBert, ModernBertConfig +from finetune.base_models.modern_bert.encoding import ModernBertEncoder +from finetune.base_models import SourceModel + + +def featurizer( + X, encoder, config, train=False, reuse=None, lengths=None, **kwargs + ): + initial_shape = tf.shape(input=X) + X = tf.reshape(X, shape=tf.concat(([-1], initial_shape[-1:]), 0)) + X.set_shape([None, None]) + delimiters = tf.cast(tf.equal(X, encoder.delimiter_token), tf.int32) + + seq_length = tf.shape(input=delimiters)[1] + mask = tf.sequence_mask(lengths, maxlen=seq_length, dtype=tf.float32) + with tf.compat.v1.variable_scope("model/featurizer", reuse=reuse): + # TODO: plumb in the config to the finetune config. + model = ModernBert(config=ModernBertConfig()) + embedding = model.embeddings + sequence_out = model(input_ids=X, attention_mask=mask, training=train, seq_len=seq_length) + pooled_out = sequence_out[:, 0, :] + pooled_out.set_shape([None, config.n_embed]) + n_embed = pooled_out.shape[-1] + + features = tf.reshape( + pooled_out, + shape=tf.concat((initial_shape[:-1], [n_embed]), 0), + ) + sequence_features = tf.reshape( + sequence_out, + shape=tf.concat((initial_shape, [n_embed]), 0), + ) + + output_state = { + "embedding": embedding, + "features": features, + "sequence_features": sequence_features, + "lengths": lengths, + "inputs": X, + } + + return output_state + + +class ModernBertModel(SourceModel): + encoder = ModernBertEncoder + featurizer = featurizer + max_length = 2048 + + settings = { + "base_model_path": os.path.join("modern_bert", "modern_bert.jl"), + "n_layer": 22, + "train_embeddings": True, + "num_layers_trained": 22, + "n_embed": 768, + "max_length": max_length, + "include_bos_eos": True, + } + required_files = [] + diff --git a/finetune/base_models/modern_bert/modelling.py b/finetune/base_models/modern_bert/modelling.py new file mode 100644 index 000000000..d7cb2c315 --- /dev/null +++ b/finetune/base_models/modern_bert/modelling.py @@ -0,0 +1,330 @@ +import os +import tensorflow as tf +import finetune +import functools + +from typing import Optional, Tuple, Union +import json + + +FINETUNE_FOLDER = os.path.dirname(finetune.__file__) +CONFIG_PATH = os.path.join(FINETUNE_FOLDER, "model", "modern_bert", "config.json") + +class ModernBertConfig: + # TODO: make this configurable so we can support different models. + def __init__(self): + with open(CONFIG_PATH, "r") as f: + self.config = json.load(f) + + def __getattr__(self, name): + return self.config[name] + + +class EmbeddingWithPadIdx(tf.keras.layers.Embedding): + # TODO: we are not masking grads for the pad idx like the torch version. + # Shouldn't be necessary as there should be no gradient here anyway. 
+ # But worth looking at if we run into issues + def __init__(self, *args, pad_idx=None, name="EmbeddingWithPadIdx", **kwargs): + super().__init__(*args, **kwargs, name=name) + self.pad_idx = pad_idx + + def compute_mask(self, inputs, mask=None): + return tf.not_equal(inputs, self.pad_idx) + + +class ModernBertEmbeddings(tf.keras.layers.Layer): + """ + Same as BertEmbeddings with a tiny tweak for positional embeddings indexing. + """ + + def __init__(self, config: ModernBertConfig, name="Embedding"): + super().__init__(name=name) + self.config = config + self.tok_embeddings = EmbeddingWithPadIdx(config.vocab_size, config.hidden_size, pad_idx=config.pad_token_id) + self.norm = tf.keras.layers.LayerNormalization(epsilon=config.norm_eps, center=config.norm_bias, name="EmbeddingNorm") + self.drop = tf.keras.layers.Dropout(config.embedding_dropout) + + def call( + self, input_ids: tf.Tensor = None, inputs_embeds: Optional[tf.Tensor] = None, training=False + ) -> tf.Tensor: + if inputs_embeds is not None: + hidden_states = self.drop(self.norm(inputs_embeds), training=training) + else: + hidden_states = self.drop(self.norm(self.tok_embeddings(input_ids)), training=training) + return hidden_states + + +class ModernBertMLP(tf.keras.layers.Layer): + """Applies the GLU at the end of each ModernBERT layer. + + Compared to the default BERT architecture, this block replaces :class:`~transformers.model.bert.modeling_bert.BertIntermediate` + and :class:`~transformers.model.bert.modeling_bert.SelfOutput` with a single module that has similar functionality. + """ + + def __init__(self, config: ModernBertConfig, name="GLU"): + super().__init__(name=name) + self.config = config + self.Wi = tf.keras.layers.Dense(int(config.intermediate_size) * 2, use_bias=config.mlp_bias, name="Wi") + if config.hidden_activation == "gelu": + self.act = functools.partial(tf.keras.activations.gelu, approximate=False) + else: + raise ValueError(f"Unsupported activation: {config.hidden_activation}") + self.drop = tf.keras.layers.Dropout(config.mlp_dropout) + self.Wo = tf.keras.layers.Dense(config.hidden_size, use_bias=config.mlp_bias, name="Wo") + + def call(self, hidden_states: tf.Tensor) -> tf.Tensor: + input, gate = tf.split(self.Wi(hidden_states), 2, axis=-1) + result = self.Wo(self.drop(self.act(input) * gate)) + return result + + +class ModernBertRotaryEmbedding(tf.keras.layers.Layer): + def __init__(self, dim: int, base: float): + super().__init__() + self.inv_freq = 1.0 / (base ** (tf.range(0, dim, 2, dtype=tf.float32) / dim)) + + def call(self, x, position_ids): + inv_freq_expanded = tf.tile(tf.expand_dims(tf.expand_dims(self.inv_freq, 0), 2), [position_ids.shape[0], 1, 1]) + position_ids_expanded = tf.expand_dims(position_ids, 1) + freqs = tf.transpose(tf.matmul(inv_freq_expanded, position_ids_expanded), perm=[0, 2, 1]) + emb = tf.concat([freqs, freqs], axis=-1) + cos = tf.cast(tf.cos(emb), x.dtype) + sin = tf.cast(tf.sin(emb), x.dtype) + return cos, sin + + +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return tf.concat([-x2, x1], axis=-1) + + +def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. 
+ position_ids (`torch.Tensor`, *optional*): + Deprecated and unused. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + cos = tf.expand_dims(cos, unsqueeze_dim) + sin = tf.expand_dims(sin, unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + + + +class ModernBertAttention(tf.keras.layers.Layer): + """Performs multi-headed self attention on a batch of unpadded sequences. + + See `forward` method for additional details. + """ + + def __init__(self, config: ModernBertConfig, layer_id: Optional[int] = None, name="Attention"): + super().__init__(name=name) + self.config = config + self.layer_id = layer_id + + if config.hidden_size % config.num_attention_heads != 0: + raise ValueError( + f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention heads ({config.num_attention_heads})" + ) + + self.attention_dropout = config.attention_dropout + self.num_heads = config.num_attention_heads + self.head_dim = config.hidden_size // config.num_attention_heads + self.all_head_size = self.head_dim * self.num_heads + self.Wqkv = tf.keras.layers.Dense(3 * self.all_head_size, use_bias=config.attention_bias, name="Wqkv") + + if layer_id % config.global_attn_every_n_layers != 0: + self.local_attention = (config.local_attention // 2, config.local_attention // 2) + else: + self.local_attention = (-1, -1) + + rope_theta = config.global_rope_theta + if self.local_attention != (-1, -1): + if config.local_rope_theta is not None: + rope_theta = config.local_rope_theta + + self.rotary_emb = ModernBertRotaryEmbedding(dim=self.head_dim, base=rope_theta) + + self.Wo = tf.keras.layers.Dense(config.hidden_size, use_bias=config.attention_bias, name="Wo") + self.out_drop = tf.keras.layers.Dropout(config.attention_dropout) if config.attention_dropout > 0.0 else tf.keras.layers.Identity() + self.pruned_heads = set() + + + def eager_attention_forward( + self, + qkv: tf.Tensor, + attention_mask: tf.Tensor, + sliding_window_mask: tf.Tensor, + position_ids: Optional[tf.Tensor], + local_attention: Tuple[int, int], + bs: int, + dim: int, + training: bool, + ) -> Union[Tuple[tf.Tensor, tf.Tensor], Tuple[tf.Tensor]]: + # qkv: [batch_size, seqlen, 3, nheads, headdim] + cos, sin = self.rotary_emb(qkv, position_ids=position_ids) + query, key, value = tf.unstack(tf.transpose(qkv, perm=[0, 3, 2, 1, 4]), 3, axis=2) + # query, key, value: [batch_size, heads, seq_len, head_dim] + query, key = apply_rotary_pos_emb(query, key, cos, sin) + + scale = self.head_dim**-0.5 + attn_weights = tf.matmul(query, tf.transpose(key, perm=[0, 1, 3, 2])) * scale + + if local_attention != (-1, -1): + attention_mask = sliding_window_mask + + attn_weights = attn_weights + attention_mask + + # upcast attention to fp32 + attn_weights = 
tf.keras.activations.softmax(attn_weights, axis=-1) + attn_weights = tf.keras.layers.Dropout(rate=self.attention_dropout)(attn_weights, training=training) + attn_output = tf.matmul(attn_weights, value) + attn_output = tf.transpose(attn_output, perm=[0, 2, 1, 3]) + attn_output = tf.reshape(attn_output, [bs, -1, dim]) + return attn_output + + def call( + self, + hidden_states: tf.Tensor, + attention_mask, + position_ids, + sliding_window_mask, + training: bool, + ) -> tf.Tensor: + qkv = self.Wqkv(hidden_states) + bs = tf.shape(hidden_states)[0] + qkv = tf.reshape(qkv, [bs, -1, 3, self.num_heads, self.head_dim]) + + attn_outputs = self.eager_attention_forward( + qkv=qkv, + local_attention=self.local_attention, + bs=bs, + dim=self.all_head_size, + attention_mask=attention_mask, + sliding_window_mask=sliding_window_mask, + position_ids=position_ids, + training=training, + ) + hidden_states = self.out_drop(self.Wo(attn_outputs)) + return hidden_states + +class ModernBertEncoderLayer(tf.keras.layers.Layer): + def __init__(self, config: ModernBertConfig, layer_id: Optional[int] = None, name="EncoderLayer"): + super().__init__(name=name) + self.config = config + if layer_id == 0: + self.attn_norm = tf.keras.layers.Identity() + else: + self.attn_norm = tf.keras.layers.LayerNormalization(epsilon=config.norm_eps, center=config.norm_bias, name="AttnNorm") + self.attn = ModernBertAttention(config=config, layer_id=layer_id) + self.mlp_norm = tf.keras.layers.LayerNormalization(epsilon=config.norm_eps, center=config.norm_bias, name="MLPNorm") + self.mlp = ModernBertMLP(config) + + def call( + self, + hidden_states: tf.Tensor, + attention_mask: Optional[tf.Tensor] = None, + sliding_window_mask: Optional[tf.Tensor] = None, + position_ids: Optional[tf.Tensor] = None, + ) -> tf.Tensor: + + normed = self.attn_norm(hidden_states) + attn_outputs = self.attn( + normed, + attention_mask=attention_mask, + sliding_window_mask=sliding_window_mask, + position_ids=position_ids, + ) + hidden_states = hidden_states + attn_outputs + mlp_output = self.mlp(self.mlp_norm(hidden_states)) + hidden_states = hidden_states + mlp_output + + return hidden_states + + + +class ModernBert(tf.keras.layers.Layer): + def __init__(self, config, name="ModernBert"): + super().__init__(name=name) + self.config = config + self.embeddings = ModernBertEmbeddings(config) + self.layers = [ModernBertEncoderLayer(config, layer_id) for layer_id in range(config.num_hidden_layers)] + self.final_norm = tf.keras.layers.LayerNormalization(epsilon=config.norm_eps, center=config.norm_bias, name="FinalNorm") + + def get_input_embeddings(self): + return self.embeddings.tok_embeddings + + def set_input_embeddings(self, value): + self.embeddings.tok_embeddings = value + + def call( + self, + input_ids: Optional[tf.Tensor] = None, + attention_mask: Optional[tf.Tensor] = None, + sliding_window_mask: Optional[tf.Tensor] = None, + inputs_embeds: Optional[tf.Tensor] = None, + seq_len: Optional[int] = None, + training: bool = False, + ) -> Union[Tuple[tf.Tensor, ...], dict]: + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + position_ids = tf.expand_dims(tf.range(seq_len, dtype=tf.float32), 0) + + attention_mask, sliding_window_mask = self._update_attention_mask(attention_mask) + + hidden_states = self.embeddings(input_ids=input_ids, inputs_embeds=inputs_embeds, training=training) + for encoder_layer in self.layers: + layer_outputs = encoder_layer( + hidden_states, + 
attention_mask=attention_mask, + sliding_window_mask=sliding_window_mask, + position_ids=position_ids, + ) + + hidden_states = layer_outputs + + hidden_states = self.final_norm(hidden_states) + return hidden_states + + + def _update_attention_mask(self, attention_mask: tf.Tensor) -> tf.Tensor: + expanded_mask = tf.tile(attention_mask[:, None, None, :], [1, 1, tf.shape(attention_mask)[1], 1]) + inverted_mask = 1.0 - expanded_mask + + global_attention_mask = tf.where(tf.cast(inverted_mask, tf.bool), + tf.fill(tf.shape(inverted_mask), tf.float32.min), + inverted_mask) + + # Create position indices + rows = tf.expand_dims(tf.range(tf.shape(global_attention_mask)[2]), 0) + # Calculate distance between positions + distance = tf.abs(rows - tf.transpose(rows)) + + # Create sliding window mask (1 for positions within window, 0 outside) + window_mask = ( + tf.expand_dims(tf.expand_dims(distance <= self.config.local_attention // 2, 0), 0) + ) + # Combine with existing mask + sliding_window_mask = tf.where(tf.logical_not(window_mask), + tf.fill(tf.shape(global_attention_mask), tf.float32.min), + global_attention_mask) + return global_attention_mask, sliding_window_mask diff --git a/finetune/base_models/modern_bert/modernbert_to_finetune.py b/finetune/base_models/modern_bert/modernbert_to_finetune.py new file mode 100644 index 000000000..7c79cf9b5 --- /dev/null +++ b/finetune/base_models/modern_bert/modernbert_to_finetune.py @@ -0,0 +1,37 @@ + +if __name__ == "__main__": + import torch as nn + import numpy as np + import joblib as jl + + torch_model = nn.load("/Finetune/finetune/model/modern_bert/pytorch_model.bin") + finetune_model = dict() + mapping = { + "model.embeddings.tok_embeddings.weight": ("model/featurizer/ModernBert/Embedding/EmbeddingWithPadIdx/embeddings:0", False), + "model.embeddings.norm.weight": (f"model/featurizer/ModernBert/Embedding/EmbeddingNorm/gamma:0", False), + } + + for layer_i in range(22): + layer_name = "" if layer_i == 0 else f"_{layer_i}" + mapping[f"model.layers.{layer_i}.attn_norm.weight"] = (f"model/featurizer/ModernBert/EncoderLayer{layer_name}/AttnNorm/gamma:0", False) + mapping[f'model.layers.{layer_i}.attn.Wqkv.weight'] = (f'model/featurizer/ModernBert/EncoderLayer{layer_name}/Attention/Wqkv/kernel:0', True) + mapping[f'model.layers.{layer_i}.attn.Wo.weight'] = (f'model/featurizer/ModernBert/EncoderLayer{layer_name}/Attention/Wo/kernel:0', True) + mapping[f"model.layers.{layer_i}.mlp_norm.weight"] = (f"model/featurizer/ModernBert/EncoderLayer{layer_name}/MLPNorm/gamma:0", False) + mapping[f"model.layers.{layer_i}.mlp.Wi.weight"] = (f"model/featurizer/ModernBert/EncoderLayer{layer_name}/GLU/Wi/kernel:0", True) + mapping[f"model.layers.{layer_i}.mlp.Wo.weight"] = (f"model/featurizer/ModernBert/EncoderLayer{layer_name}/GLU/Wo/kernel:0", True) + mapping["model.final_norm.weight"] = ("model/featurizer/ModernBert/FinalNorm/gamma:0", False) + + print(mapping) + for k, v in torch_model.items(): + if k.startswith("head") or k.startswith("decoder"): + print("Skipping", k) + continue + print(f"===== {k} {v.shape} =====") + new_name, do_transpose = mapping[k] + in_numpy = v.cpu().numpy() + if do_transpose: + in_numpy = np.transpose(in_numpy) + print(f"Output = {new_name}, {in_numpy.shape}") + finetune_model[new_name] = in_numpy + + jl.dump(finetune_model, "modern_bert.jl") diff --git a/finetune/base_models/modern_bert/verify_features.py b/finetune/base_models/modern_bert/verify_features.py new file mode 100644 index 000000000..a6fc3c06d --- /dev/null +++ 
b/finetune/base_models/modern_bert/verify_features.py @@ -0,0 +1,23 @@ +from transformers import AutoTokenizer, ModernBertModel +import torch +from finetune.base_models.modern_bert.model import ModernBertModel as FTModernBertModel +from finetune import SequenceLabeler +import numpy as np + +if __name__ == "__main__": + text = "The quick brown fox jumps over the lazy dog" + finetune_model = SequenceLabeler(base_model=FTModernBertModel) + finetune_features = finetune_model.featurize_sequence([text])[0] + + tokenizer = AutoTokenizer.from_pretrained("answerdotai/ModernBERT-base") + model = ModernBertModel.from_pretrained("answerdotai/ModernBERT-base", attn_implementation="eager") + + inputs = tokenizer(text, return_tensors="pt") + print(inputs) + + with torch.no_grad(): + transformers_features = model(**inputs).last_hidden_state.to("cpu").numpy()[0][1:-1] + + print(finetune_features.shape == transformers_features.shape) + print(np.abs(finetune_features - transformers_features)) + diff --git a/finetune/datasets/reuters.py b/finetune/datasets/reuters.py index 46fbd0b3a..b7bfeab58 100644 --- a/finetune/datasets/reuters.py +++ b/finetune/datasets/reuters.py @@ -1,10 +1,6 @@ import os import requests -import codecs import json -import hashlib -import io -from pathlib import Path import pandas as pd from bs4 import BeautifulSoup as bs @@ -13,10 +9,10 @@ from finetune import SequenceLabeler from finetune.datasets import Dataset -from finetune.base_models import GPT, GPT2, TCN, RoBERTa -from finetune.base_models.huggingface.models import HFDebertaV3Base from finetune.encoding.sequence_encoder import finetune_to_indico_sequence from finetune.util.metrics import annotation_report, sequence_labeling_token_confusion +from finetune.base_models.modern_bert.model import ModernBertModel +from finetune.base_models.bert.model import RoBERTa XML_PATH = os.path.join("Data", "Sequence", "reuters.xml") DATA_PATH = os.path.join("Data", "Sequence", "reuters.json") @@ -79,8 +75,9 @@ def download(self): test_size=0.2, random_state=42 ) - model = SequenceLabeler(batch_size=1, n_epochs=3, val_size=0.0, max_length=512, chunk_long_sequences=True, subtoken_predictions=False, crf_sequence_labeling=True, multi_label_sequences=False) + model = SequenceLabeler(base_model=ModernBertModel, batch_size=4, n_epochs=5, val_size=0.0, max_length=512, chunk_long_sequences=True, subtoken_predictions=False, crf_sequence_labeling=True, multi_label_sequences=False) model.fit(trainX, trainY) + # print({k: v.shape for k, v in model.saver.variables.items()}) predictions = model.predict(testX) print(predictions) print(annotation_report(testY, predictions)) diff --git a/finetune/saver.py b/finetune/saver.py index 753bd1bca..599c5b02b 100644 --- a/finetune/saver.py +++ b/finetune/saver.py @@ -259,7 +259,6 @@ def init_fn(scaffold, session): all_vars = tf.compat.v1.global_variables() global_step_var = tf.compat.v1.train.get_global_step() - for var in all_vars: if self.restart_global_step and global_step_var is not None and global_step_var.name == var.name: continue
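
A minimal sketch of how the "feats within 1e-7" claim in the commit subject could be asserted explicitly, building on verify_features.py above. check_feature_parity and its atol parameter are illustrative names introduced here, not part of this patch; it assumes the two arrays are the ones verify_features.py already computes (finetune features vs. the transformers last_hidden_state with the [CLS]/[SEP] rows dropped).

import numpy as np

def check_feature_parity(finetune_features, transformers_features, atol=1e-7):
    # Shapes must agree before an element-wise comparison is meaningful
    # (verify_features.py strips the special-token rows from the transformers output).
    assert finetune_features.shape == transformers_features.shape, (
        finetune_features.shape,
        transformers_features.shape,
    )
    max_diff = np.max(np.abs(finetune_features - transformers_features))
    print(f"max abs diff: {max_diff:.2e}")
    assert max_diff < atol, f"features diverge by {max_diff:.2e} (tolerance {atol:.0e})"

Called with the two arrays from verify_features.py, this reduces the printed difference matrix to a single pass/fail number, which is easier to track if the check is ever run in CI.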
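
_update_attention_mask in modelling.py builds two additive masks: a global one derived from the padding mask, and a sliding-window one that only lets each token attend within config.local_attention // 2 positions on either side (used by every layer whose index is not a multiple of global_attn_every_n_layers). A small standalone sketch of just the window geometry, with the padding mask omitted; seq_len and local_attention values here are arbitrary illustration values, not the ModernBERT defaults.

import tensorflow as tf

seq_len = 8
local_attention = 4                     # window of local_attention // 2 tokens each side
half_window = local_attention // 2

rows = tf.expand_dims(tf.range(seq_len), 0)     # [1, seq_len]
distance = tf.abs(rows - tf.transpose(rows))    # [seq_len, seq_len] token-to-token distance
window_mask = distance <= half_window           # True where local attention is allowed

# Positions outside the window get tf.float32.min so they vanish after the softmax,
# mirroring how the sliding_window_mask is combined with the attention scores.
additive = tf.where(window_mask,
                    tf.zeros([seq_len, seq_len]),
                    tf.fill([seq_len, seq_len], tf.float32.min))
print(tf.cast(window_mask, tf.int32).numpy())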