ADD: Initial run at modern bert - feats within 1e-7
1 parent 1c3b4d7 · commit 666d036
Showing 8 changed files with 518 additions and 8 deletions.
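The commit message claims the new featurizer reproduces reference features to within 1e-7. Below is a minimal sketch of how such a parity check might look; it assumes the Hugging Face answerdotai/ModernBERT-base checkpoint as the reference, a transformers version with ModernBERT support, and finetune's usual featurize API. The import path for ModernBertModel is a guess based on the files in this commit, not confirmed by it:

import numpy as np
import torch
from transformers import AutoModel, AutoTokenizer

from finetune import Classifier
# Hypothetical module path; the commit shows the package but not this module name.
from finetune.base_models.modern_bert.modern_bert import ModernBertModel

texts = ["ModernBERT parity check."]

# Reference: the [CLS] hidden state from the Hugging Face implementation,
# matching the pooling used by the featurizer in this commit.
hf_tokenizer = AutoTokenizer.from_pretrained("answerdotai/ModernBERT-base")
hf_model = AutoModel.from_pretrained("answerdotai/ModernBERT-base")
with torch.no_grad():
    hf_out = hf_model(**hf_tokenizer(texts, return_tensors="pt"))
reference = hf_out.last_hidden_state[:, 0, :].numpy()

# Candidate: features produced by finetune's new base model.
candidate = Classifier(base_model=ModernBertModel).featurize(texts)

assert np.allclose(candidate, reference, atol=1e-7)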
Empty file.
@@ -0,0 +1,62 @@
import logging
import os

import finetune
from tokenizers import Tokenizer
from finetune.encoding.input_encoder import BaseEncoder
from finetune.encoding.input_encoder import EncodedOutput


FINETUNE_FOLDER = os.path.dirname(finetune.__file__)
TOKENIZER_PATH = os.path.join(FINETUNE_FOLDER, "model", "modern_bert", "tokenizer.json")

LOGGER = logging.getLogger("finetune")


class ModernBertEncoder(BaseEncoder):
    def __init__(self):
        self.tokenizer = Tokenizer.from_file(TOKENIZER_PATH)
        special_tokens_map = {tok.content: k for k, tok in self.tokenizer.get_added_tokens_decoder().items()}
        self.start_token = special_tokens_map["[CLS]"]
        self.delimiter_token = special_tokens_map["[SEP]"]
        self.mask_token = special_tokens_map["[MASK]"]
        self.end_token = special_tokens_map["[SEP]"]
        self.initialized = True
        self.UNK_IDX = None

    @property
    def vocab_size(self):
        return self.tokenizer.get_vocab_size()

    def _encode(self, texts):
        batch_tokens = []
        batch_token_idxs = []
        batch_char_ends = []
        batch_char_starts = []
        for text in texts:
            encoded = self.tokenizer.encode(text, add_special_tokens=False)
            batch_tokens.append(encoded.tokens)
            batch_token_idxs.append(encoded.ids)
            token_ends = []
            token_starts = []
            for start, end in encoded.offsets:
                if token_ends:
                    # Finetune requires that tokens never overlap.
                    # This happens in huggingface tokenizers when
                    # a single character is split across multiple tokens.
                    start = max(token_ends[-1], start)
                end = max(end, start)
                token_starts.append(start)
                token_ends.append(end)

            batch_char_ends.append(token_ends)
            batch_char_starts.append(token_starts)

        output = EncodedOutput(
            token_ids=batch_token_idxs,
            tokens=batch_tokens,
            token_ends=batch_char_ends,
            token_starts=batch_char_starts,
        )
        return output

    def decode(self, ids):
        output = self.tokenizer.decode(ids, skip_special_tokens=True)
        return output
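A usage sketch (not part of the commit) showing what the offset clamping in _encode guarantees: character spans are well-formed and consecutive tokens never overlap, even when the tokenizer splits a single character across multiple tokens. The private _encode is called directly here purely for illustration:

encoder = ModernBertEncoder()
out = encoder._encode(["finetune ModernBERT"])

for tok, start, end in zip(out.tokens[0], out.token_starts[0], out.token_ends[0]):
    print(f"{tok!r}: [{start}, {end})")

# Every span has end >= start, and each token starts at or after
# the previous token's end.
starts, ends = out.token_starts[0], out.token_ends[0]
assert all(e >= s for s, e in zip(starts, ends))
assert all(starts[i + 1] >= ends[i] for i in range(len(ends) - 1))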
@@ -0,0 +1,62 @@
import os

import tensorflow as tf

from finetune.base_models.modern_bert.modelling import ModernBert, ModernBertConfig
from finetune.base_models.modern_bert.encoding import ModernBertEncoder
from finetune.base_models import SourceModel


def featurizer(
    X, encoder, config, train=False, reuse=None, lengths=None, **kwargs
):
    initial_shape = tf.shape(input=X)
    X = tf.reshape(X, shape=tf.concat(([-1], initial_shape[-1:]), 0))
    X.set_shape([None, None])
    delimiters = tf.cast(tf.equal(X, encoder.delimiter_token), tf.int32)

    seq_length = tf.shape(input=delimiters)[1]
    mask = tf.sequence_mask(lengths, maxlen=seq_length, dtype=tf.float32)
    with tf.compat.v1.variable_scope("model/featurizer", reuse=reuse):
        # TODO: plumb in the config to the finetune config.
        model = ModernBert(config=ModernBertConfig())
        embedding = model.embeddings
        sequence_out = model(input_ids=X, attention_mask=mask, training=train, seq_len=seq_length)
        pooled_out = sequence_out[:, 0, :]
        pooled_out.set_shape([None, config.n_embed])
        n_embed = pooled_out.shape[-1]

    features = tf.reshape(
        pooled_out,
        shape=tf.concat((initial_shape[:-1], [n_embed]), 0),
    )
    sequence_features = tf.reshape(
        sequence_out,
        shape=tf.concat((initial_shape, [n_embed]), 0),
    )

    output_state = {
        "embedding": embedding,
        "features": features,
        "sequence_features": sequence_features,
        "lengths": lengths,
        "inputs": X,
    }

    return output_state


class ModernBertModel(SourceModel):
    encoder = ModernBertEncoder
    featurizer = featurizer
    max_length = 2048

    settings = {
        "base_model_path": os.path.join("modern_bert", "modern_bert.jl"),
        "n_layer": 22,
        "train_embeddings": True,
        "num_layers_trained": 22,
        "n_embed": 768,
        "max_length": max_length,
        "include_bos_eos": True,
    }
    required_files = []
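A sketch of how the new base model would plug into finetune's usual base_model interface, assuming the library's standard Classifier workflow; the training data is illustrative only and the import path for ModernBertModel is a hypothetical guess from this commit's package layout:

from finetune import Classifier
# Hypothetical module path for the class defined above.
from finetune.base_models.modern_bert.modern_bert import ModernBertModel

model = Classifier(base_model=ModernBertModel)
model.fit(["a glowing review", "a scathing review"], ["positive", "negative"])
print(model.predict(["an ambivalent review"]))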