implemented load_from_spec() function, added private methods for fit()
wanxinran committed Mar 17, 2024
1 parent 9fcc5d5 commit ddec969
Showing 1 changed file with 86 additions and 3 deletions.
89 changes: 86 additions & 3 deletions llments/lm/base/hugging_face.py
@@ -22,6 +22,8 @@ def __init__(
self.text_generator: TextGenerationPipeline = pipeline(
"text-generation", model=model, device=device
)
self.model_name = model
self.device = device

def fit(
self, target: LanguageModel, task_description: str | None = None
@@ -36,7 +38,62 @@ def fit(
Returns:
The fitted language model.
"""
inputs, labels = self._prepare_training_data(target)
        dataset = self._prepare_training_dataset(inputs, labels)  # wrap in a torch Dataset for the HF Trainer

# TODO: use HF Trainer class to train the model
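        # A minimal sketch of how this TODO might be completed with the HF Trainer
        # (illustrative only, not part of this commit; the output_dir and the
        # hyperparameters below are assumptions):
        #
        #     from transformers import Trainer, TrainingArguments
        #
        #     training_args = TrainingArguments(
        #         output_dir="hf_fit_output",  # hypothetical path
        #         num_train_epochs=1,
        #         per_device_train_batch_size=8,
        #     )
        #     trainer = Trainer(
        #         model=self.text_generator.model,  # the underlying HF model
        #         args=training_args,
        #         train_dataset=dataset,
        #     )
        #     trainer.train()
        #     return self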


def _prepare_training_data(self, target: LanguageModel):
"""Generate data from the target language model, using generate() function.
Helper function of fit().
Args:
target: target language model.
Returns:
inputs: Generated data (type: HF BatchEncoding): result from calling HF tokenizer.
labels: "Up shift" each token to create the labels.
"""
# Generate samples from the target model, consider this as one batch.
samples = target.generate(condition=None, do_sample=True, max_length=50, temperature=1.0, num_return_sequences=1000)
try:
from transformers import AutoTokenizer
except ImportError:
raise ImportError("You need to install the `transformers` package to use this method.")

        tokenizer = AutoTokenizer.from_pretrained(self.model_name)
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token  # some models (e.g. GPT-2) define no pad token
        inputs = tokenizer(samples, padding=True, truncation=True, return_tensors="pt")  # return PyTorch tensors

        # Prepare labels by shifting the input IDs one position to the left
        labels = inputs.input_ids[:, 1:].clone()
        try:
            import torch
        except ImportError:
            raise ImportError("You need to install the `torch` package to use this method.")
        # Pad with -100 on the right so labels keep the same length as input_ids;
        # positions labelled -100 are ignored by the loss.
        labels = torch.nn.functional.pad(labels, (0, 1), value=-100)
        # input_ids are left unchanged, so position i is trained to predict labels[i],
        # i.e. the token at position i + 1.
        # Note: HF causal-LM heads shift labels internally when computing their own loss,
        # so passing labels = input_ids unshifted is an equally common pattern.
        return inputs, labels
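    # Worked toy example of the shift above (the token IDs are made up, not from
    # this commit): for a tokenized row [101, 7, 8, 9] the method produces
    #     input_ids -> [101, 7, 8, 9]
    #     labels    -> [  7, 8, 9, -100]
    # so position i is trained to predict the token at position i + 1, and the
    # final position, which has no target, carries the ignore index -100.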

    def _prepare_training_dataset(self, inputs, labels):
        """Return a custom Dataset object to be used with the HF Trainer class.

        Helper function of fit().

        Args:
            inputs: Tokenized generations returned by _prepare_training_data().
            labels: Shifted label IDs returned by _prepare_training_data().

        Returns:
            A GeneratedDataset wrapping the inputs and labels.
        """
        try:
            import torch  # imported only to fail early with a clear message if torch is missing
            from torch.utils.data import Dataset
        except ImportError:
            raise ImportError("You need to install the `torch` package to use this method.")
        return GeneratedDataset(inputs, labels)


def generate(
self,
@@ -86,14 +143,40 @@ def set_seed(self, seed: int):
)
set_seed(seed)

from torch.utils.data import Dataset


class GeneratedDataset(Dataset):
    """Torch Dataset wrapping tokenized generations and their shifted labels."""

    def __init__(self, encodings, labels):
        self.encodings = encodings
        self.labels = labels

    def __getitem__(self, idx):
        # Values in `encodings` are already tensors (return_tensors="pt"), so index them directly.
        item = {key: val[idx] for key, val in self.encodings.items()}
        item['labels'] = self.labels[idx]
        return item

    def __len__(self):
        return len(self.labels)
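
# Illustrative usage of GeneratedDataset (a sketch, not part of this commit;
# the "gpt2" checkpoint and the sample strings are assumptions):
#
#     from transformers import AutoTokenizer
#
#     tok = AutoTokenizer.from_pretrained("gpt2")
#     tok.pad_token = tok.eos_token
#     enc = tok(["hello world", "a longer example"], padding=True, return_tensors="pt")
#     ds = GeneratedDataset(enc, enc.input_ids.clone())
#     len(ds)        # 2
#     ds[0].keys()   # dict_keys(['input_ids', 'attention_mask', 'labels'])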


def load_from_spec(spec_file: str) -> HuggingFaceLM:
"""Load a language model from a specification file.
Args:
spec_file: The path to the specification file.
The file should specifies the model identifier "model" and any other relevant parameters such as "device".
Returns:
A language model.
A HuggingFaceLM instance.
"""
raise NotImplementedError("This is not implemented yet.")
try:
import json
except ImportError:
raise ImportError("You need to import/install json to use this function.")
with open(spec_file, 'r') as file:
spec = json.load(file)

model_name = spec.get('model')
device = spec.get('device', None)

return HuggingFaceLM(model=model_name, device=device)
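
# Example usage of load_from_spec (illustrative only; the spec path, its contents,
# and the prompt are assumptions, not part of this commit; the generate() arguments
# follow the LanguageModel interface used in fit() above). A minimal spec file
# could contain: {"model": "gpt2", "device": "cpu"}
if __name__ == "__main__":
    lm = load_from_spec("spec.json")  # hypothetical spec path
    print(lm.generate(condition="Hello", do_sample=True, max_length=20, num_return_sequences=1))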
