ADD: Initial run at modern bert - feats within 1e-7
benleetownsend committed Jan 14, 2025
1 parent 1c3b4d7 commit 666d036
Showing 8 changed files with 518 additions and 8 deletions.
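The "feats within 1e-7" in the commit title refers to feature parity with the reference ModernBERT implementation. A minimal sketch of the kind of check that claim implies, assuming finetune's usual Classifier.featurize entry point and the Hugging Face answerdotai/ModernBERT-base checkpoint as the reference (the actual comparison harness is not part of this commit):

import numpy as np
import torch
from transformers import AutoModel, AutoTokenizer

from finetune import Classifier
from finetune.base_models.modern_bert.model import ModernBertModel

texts = ["a short sanity-check sentence"]

# Reference [CLS] features from the Hugging Face implementation
# (the checkpoint name is an assumption, not taken from this commit).
tok = AutoTokenizer.from_pretrained("answerdotai/ModernBERT-base")
hf_model = AutoModel.from_pretrained("answerdotai/ModernBERT-base").eval()
with torch.no_grad():
    batch = tok(texts, return_tensors="pt")
    ref = hf_model(**batch).last_hidden_state[:, 0, :].numpy()

# Pooled features from the TensorFlow port added in this commit (model.py below).
ours = Classifier(base_model=ModernBertModel).featurize(texts)

print(np.max(np.abs(ref - ours)))  # the commit title claims this is below 1e-7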
62 changes: 62 additions & 0 deletions finetune/base_models/modern_bert/encoding.py
@@ -0,0 +1,62 @@
import logging
import os
import finetune
from tokenizers import Tokenizer
from finetune.encoding.input_encoder import BaseEncoder, EncodedOutput

FINETUNE_FOLDER = os.path.dirname(finetune.__file__)
TOKENIZER_PATH = os.path.join(FINETUNE_FOLDER, "model", "modern_bert", "tokenizer.json")

LOGGER = logging.getLogger("finetune")

class ModernBertEncoder(BaseEncoder):
def __init__(self):
self.tokenizer = Tokenizer.from_file(TOKENIZER_PATH)
        # get_added_tokens_decoder() maps id -> AddedToken; invert it so we can
        # look up special-token ids by their string form.
        special_tokens_map = {tok.content: k for k, tok in self.tokenizer.get_added_tokens_decoder().items()}
        self.start_token = special_tokens_map["[CLS]"]
        self.delimiter_token = special_tokens_map["[SEP]"]
        self.mask_token = special_tokens_map["[MASK]"]
        # ModernBERT uses [SEP] as both the delimiter and the end-of-sequence token.
        self.end_token = special_tokens_map["[SEP]"]
self.initialized = True
self.UNK_IDX = None

@property
def vocab_size(self):
return self.tokenizer.get_vocab_size()

    def _encode(self, texts):
        """Tokenize each text, returning tokens, ids, and per-token character spans."""
batch_tokens = []
batch_token_idxs = []
batch_char_ends = []
batch_char_starts = []
for text in texts:
encoded = self.tokenizer.encode(text, add_special_tokens=False)
batch_tokens.append(encoded.tokens)
batch_token_idxs.append(encoded.ids)
token_ends = []
token_starts = []
for start, end in encoded.offsets:
                if token_ends:
                    # Finetune requires that token spans never overlap, but
                    # huggingface tokenizers can emit overlapping offsets when
                    # a single character is split across multiple tokens, so
                    # clamp each span to start no earlier than the previous end.
                    start = max(token_ends[-1], start)
                end = max(end, start)
token_starts.append(start)
token_ends.append(end)

batch_char_ends.append(token_ends)
batch_char_starts.append(token_starts)

output = EncodedOutput(
token_ids=batch_token_idxs,
tokens=batch_tokens,
token_ends=batch_char_ends,
token_starts=batch_char_starts,
)
return output

def decode(self, ids):
output = self.tokenizer.decode(ids, skip_special_tokens=True)
return output
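A short, hypothetical usage sketch of the encoder above (calling the private _encode directly for illustration); it assumes the tokenizer.json asset is installed where TOKENIZER_PATH expects it:

from finetune.base_models.modern_bert.encoding import ModernBertEncoder

encoder = ModernBertEncoder()
out = encoder._encode(["ModernBERT encodings with character offsets."])

# The clamping in _encode guarantees spans are well-formed and that
# consecutive tokens never overlap.
for token, start, end in zip(out.tokens[0], out.token_starts[0], out.token_ends[0]):
    assert start <= end
    print(token, (start, end))
for prev_end, next_start in zip(out.token_ends[0], out.token_starts[0][1:]):
    assert next_start >= prev_end

# Round-trip back to text, dropping special tokens.
print(encoder.decode(out.token_ids[0]))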
62 changes: 62 additions & 0 deletions finetune/base_models/modern_bert/model.py
@@ -0,0 +1,62 @@
import os
import tensorflow as tf
from finetune.base_models.modern_bert.modelling import ModernBert, ModernBertConfig
from finetune.base_models.modern_bert.encoding import ModernBertEncoder
from finetune.base_models import SourceModel


def featurizer(
X, encoder, config, train=False, reuse=None, lengths=None, **kwargs
):
    # Collapse any extra leading dimensions to [batch, seq_len]; the original
    # leading shape is restored on the output features below.
    initial_shape = tf.shape(input=X)
    X = tf.reshape(X, shape=tf.concat(([-1], initial_shape[-1:]), 0))
    X.set_shape([None, None])
delimiters = tf.cast(tf.equal(X, encoder.delimiter_token), tf.int32)

seq_length = tf.shape(input=delimiters)[1]
mask = tf.sequence_mask(lengths, maxlen=seq_length, dtype=tf.float32)
with tf.compat.v1.variable_scope("model/featurizer", reuse=reuse):
        # TODO: plumb the relevant finetune config values through to ModernBertConfig.
model = ModernBert(config=ModernBertConfig())
embedding = model.embeddings
sequence_out = model(input_ids=X, attention_mask=mask, training=train, seq_len=seq_length)
        # Pool by taking the representation of the first ([CLS]) token.
        pooled_out = sequence_out[:, 0, :]
pooled_out.set_shape([None, config.n_embed])
n_embed = pooled_out.shape[-1]

features = tf.reshape(
pooled_out,
shape=tf.concat((initial_shape[:-1], [n_embed]), 0),
)
sequence_features = tf.reshape(
sequence_out,
shape=tf.concat((initial_shape, [n_embed]), 0),
)

output_state = {
"embedding": embedding,
"features": features,
"sequence_features": sequence_features,
"lengths": lengths,
"inputs": X,
}

return output_state


class ModernBertModel(SourceModel):
encoder = ModernBertEncoder
featurizer = featurizer
max_length = 2048

settings = {
"base_model_path": os.path.join("modern_bert", "modern_bert.jl"),
"n_layer": 22,
"train_embeddings": True,
"num_layers_trained": 22,
"n_embed": 768,
"max_length": max_length,
"include_bos_eos": True,
}
required_files = []
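
Finally, a hedged end-to-end sketch of training through finetune's standard API with the new base model; it assumes the converted modern_bert.jl weights are available at the base_model_path configured above:

from finetune import Classifier
from finetune.base_models.modern_bert.model import ModernBertModel

# Hypothetical end-to-end usage; requires the converted modern_bert.jl
# checkpoint referenced by base_model_path.
model = Classifier(base_model=ModernBertModel, max_length=512)
model.fit(
    ["great value, works as advertised", "broke after a single use"],
    ["positive", "negative"],
)
print(model.predict(["would happily buy again"]))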
