use HfArgumentParser()
stolzenp committed Feb 8, 2024
1 parent e6d20d5 commit 0a36f6c
Showing 3 changed files with 33 additions and 39 deletions.
14 changes: 14 additions & 0 deletions src/small_model_training/config.json
@@ -0,0 +1,14 @@
{
    "model_name_or_path": "bert-base-uncased",
    "tokenizer_name": "distilbert-base-uncased",
    "output_dir": "my_awesome_model",
    "learning_rate": 2e-5,
    "per_device_train_batch_size": 16,
    "per_device_eval_batch_size": 16,
    "num_train_epochs": 2,
    "weight_decay": 0.01,
    "evaluation_strategy": "epoch",
    "save_strategy": "epoch",
    "load_best_model_at_end": true,
    "push_to_hub": false
}
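
For reference, below is a minimal sketch (not part of this commit) of how a config file like this is consumed by HfArgumentParser: keys matching the ModelArguments dataclass (mirrored here from the one added in text_classification.py below) land in model_args, while the remaining keys map onto the standard transformers TrainingArguments. The standalone script and its print statements are illustrative only.

# Sketch only: parse config.json into dataclasses with HfArgumentParser.
from dataclasses import dataclass, field

from transformers import HfArgumentParser, TrainingArguments


@dataclass
class ModelArguments:
    # Mirrors the dataclass added in text_classification.py in this commit.
    model_name_or_path: str = field(
        metadata={"help": "Path to pretrained model or model from huggingface.co/models"}
    )
    tokenizer_name: str = field(
        metadata={"help": "Path to pretrained tokenizer or model from huggingface.co/models"}
    )


if __name__ == "__main__":
    parser = HfArgumentParser((ModelArguments, TrainingArguments))
    # Keys that match ModelArguments fields go to model_args;
    # everything else (output_dir, learning_rate, ...) goes to training_args.
    model_args, training_args = parser.parse_json_file(json_file="config.json")
    print(model_args.model_name_or_path)  # "bert-base-uncased"
    print(training_args.learning_rate)    # 2e-05
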
54 changes: 19 additions & 35 deletions src/small_model_training/text_classification.py
@@ -1,16 +1,28 @@
import numpy as np
+from dataclasses import dataclass, field
from transformers import AutoTokenizer
from transformers import DataCollatorWithPadding
-from transformers import AutoModelForSequenceClassification, TrainingArguments, Trainer
+from transformers import AutoModelForSequenceClassification, TrainingArguments, Trainer, HfArgumentParser
from datasets import load_dataset
import json
import evaluate
+@dataclass
+class ModelArguments:
+    model_name_or_path: str = field(
+        metadata={"help": "Path to pretrained model or model from huggingface.co/models"}
+    )
+    tokenizer_name: str = field(
+        metadata={"help": "Path to pretrained tokenizer or model from huggingface.co/models"}
+    )

def get_influential_subset(dataset):
-    # get parameters from dict
-    data = get_training_parameters()
-    small_model = data['small_model']
-    batch_size = data['batch_size']
+    # get parameters from config
+    parser = HfArgumentParser((ModelArguments, TrainingArguments))
+    model_args, training_args = parser.parse_json_file('config.json')

+    tokenizer = AutoTokenizer.from_pretrained(model_args.tokenizer_name)

+    def preprocess_function(examples):
+        return tokenizer(examples["text"], truncation=True)

    tokenized_imdb = dataset.map(preprocess_function, batched=True)

@@ -20,20 +32,7 @@ def get_influential_subset(dataset):
    label2id = {"NEGATIVE": 0, "POSITIVE": 1}

    model = AutoModelForSequenceClassification.from_pretrained(
-        "distilbert-base-uncased", num_labels=2, id2label=id2label, label2id=label2id
-    )
-
-    training_args = TrainingArguments(
-        output_dir="my_awesome_model",
-        learning_rate=2e-5,
-        per_device_train_batch_size=16,
-        per_device_eval_batch_size=16,
-        num_train_epochs=2,
-        weight_decay=0.01,
-        evaluation_strategy="epoch",
-        save_strategy="epoch",
-        load_best_model_at_end=True,
-        push_to_hub=False,
+        model_args.model_name_or_path, num_labels=2, id2label=id2label, label2id=label2id
    )

    trainer = Trainer(
@@ -55,28 +54,13 @@ def get_influential_subset(dataset):
    # TO-DO: check for pre-processing
    return inf_subset

-def get_training_parameters():
-
-    # open config file
-    f = open('training_parameters.json')
-
-    # return json object as dict
-    data = json.load(f)
-
-    # close file
-    f.close()
-    return data

-def preprocess_function(examples):
-    return tokenizer(examples["text"], truncation=True)

def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    predictions = np.argmax(predictions, axis=1)
    return accuracy.compute(predictions=predictions, references=labels)

accuracy = evaluate.load("accuracy")
-tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")

# example dataset for debugging
imdb = load_dataset("imdb")
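
The Trainer(...) call itself is collapsed in the diff above. The following is a hedged sketch of how the parsed training_args would typically be wired into it; the helper name train_small_model and the exact keyword set (train/test splits, padding collator, compute_metrics) are assumptions for illustration, not the committed code.

# Sketch (not the committed code): wiring the parsed arguments into Trainer.
# Assumes the objects built inside get_influential_subset above: model, tokenizer,
# tokenized_imdb, compute_metrics, plus the training_args parsed from config.json.
from transformers import DataCollatorWithPadding, Trainer

def train_small_model(model, tokenizer, tokenized_imdb, training_args, compute_metrics):
    trainer = Trainer(
        model=model,
        args=training_args,  # no hand-built TrainingArguments(...) needed any more
        train_dataset=tokenized_imdb["train"],
        eval_dataset=tokenized_imdb["test"],
        tokenizer=tokenizer,
        data_collator=DataCollatorWithPadding(tokenizer=tokenizer),
        compute_metrics=compute_metrics,
    )
    trainer.train()
    return trainer
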
4 changes: 0 additions & 4 deletions src/small_model_training/training_parameters.json

This file was deleted.
