71 changes: 40 additions & 31 deletions src/cell2sentence/csmodel.py
@@ -18,6 +18,7 @@
# Pytorch, Huggingface imports
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, Trainer, TrainingArguments
from peft import LoraConfig, get_peft_model

# Local imports
from cell2sentence.prompt_formatter import PromptFormatter
@@ -79,23 +80,27 @@ def fine_tune(self,
train_args: TrainingArguments,
loss_on_response_only: bool = True,
top_k_genes: int = 100,
max_eval_samples: int = 500
max_eval_samples: int = 500,
lora_r: int = 8, # Low-rank adaptation dimension
lora_alpha: int = 16, # Scaling factor for LoRA
lora_dropout: float = 0.1 # Dropout for LoRA layers
):
"""
Fine tune a model using the provided CSData object data
Fine tune a model using LoRA with the provided CSData object data.

Arguments:
csdata: a CSData object to be used as input for finetuning.
alternatively, data can be any generator of sequential
text that satisfies the same functional contract as
a CSData object
task: name of finetuning task (see supported tasks in prompt_formatter.py)
train_args: Huggingface Trainer arguments object
loss_on_response_only: whether to take loss only on model's answer
top_k_genes: number of genes to use for each cell sentence
max_eval_samples: number of samples to use for validation
task: name of finetuning task (see supported tasks in prompt_formatter.py).
train_args: Huggingface Trainer arguments object.
loss_on_response_only: whether to compute the loss only on the model's answer.
top_k_genes: number of genes to use for each cell sentence.
max_eval_samples: number of samples to use for validation.
lora_r: Rank dimension for LoRA adaptation.
lora_alpha: Scaling factor for LoRA layers.
lora_dropout: Dropout rate for LoRA layers.

Return:
None: an updated CSModel is generated in-place
None: an updated CSModel is generated in-place with LoRA fine-tuning.
"""
# Load data from csdata object
if csdata.dataset_backend == "arrow":
@@ -107,7 +112,7 @@ def fine_tune(self,
prompt_formatter = PromptFormatter(task=task, top_k_genes=top_k_genes)
formatted_hf_ds = prompt_formatter.format_hf_ds(hf_ds)

# Load model
# Load the model and apply LoRA
print("Reloading model from path on disk:", self.save_path)
model = AutoModelForCausalLM.from_pretrained(
self.save_path,
@@ -116,12 +121,18 @@
)
model = model.to(self.device)

# Apply LoRA to the model
lora_config = LoraConfig(
r=lora_r,
lora_alpha=lora_alpha,
target_modules=["q_proj", "v_proj"], # Apply LoRA to the attention query and value projection layers
lora_dropout=lora_dropout,
bias="none"
)
model = get_peft_model(model, lora_config)

# Tokenize data using LLM tokenizer
# - this function applies a lambda function to tokenize each dataset split in the DatasetDict
if loss_on_response_only:
tokenization_function = tokenize_loss_on_response
else:
tokenization_function = tokenize_all
tokenization_function = tokenize_loss_on_response if loss_on_response_only else tokenize_all
formatted_hf_ds = formatted_hf_ds.map(
lambda batch: tokenization_function(batch, self.tokenizer),
batched=True,
@@ -130,13 +141,11 @@
batch_size=1000,
)

# Define parameters needed in data collator:
block_size = model.config.max_position_embeddings # maximum input sequence length possible
tokenizer = self.tokenizer # define tokenizer as variable here so it is accessible in dataloader
# Define parameters for data collator
block_size = model.config.max_position_embeddings
tokenizer = self.tokenizer

def data_collator(examples):
# Note: this data collator assumes we are not using flash attention, and pads samples
# to the max size in the batch. All sample lengths are capped at the size of the
# LLM's context.
max_length = max(list(map(lambda x: len(x["input_ids"]), examples)))
batch_input_ids, batch_attention_mask, batch_labels = [], [], []
for i in range(len(examples)):
@@ -146,13 +155,13 @@ def data_collator(examples):
assert len(sample_input_ids) == len(label_input_ids) == len(attention_mask)

size_diff = max_length - len(sample_input_ids)
final_input_ids = [tokenizer.pad_token_id] * (size_diff) + sample_input_ids
final_attention_mask = [0] * (size_diff) + attention_mask
final_label_input_ids = [-100] * (size_diff) + label_input_ids
final_input_ids = [tokenizer.pad_token_id] * size_diff + sample_input_ids
final_attention_mask = [0] * size_diff + attention_mask
final_label_input_ids = [-100] * size_diff + label_input_ids

batch_input_ids.append(final_input_ids[: block_size])
batch_attention_mask.append(final_attention_mask[: block_size])
batch_labels.append(final_label_input_ids[: block_size])
batch_input_ids.append(final_input_ids[:block_size])
batch_attention_mask.append(final_attention_mask[:block_size])
batch_labels.append(final_label_input_ids[:block_size])

return {
"input_ids": torch.tensor(batch_input_ids),
@@ -176,7 +185,7 @@ def data_collator(examples):
np.save(os.path.join(output_dir, 'sampled_eval_indices.npy'), np.array(sampled_eval_indices, dtype=np.int64))
eval_dataset = eval_dataset.select(sampled_eval_indices)

# Define Trainer
# Define Trainer with LoRA model
trainer = Trainer(
model=model,
args=train_args,
@@ -186,7 +195,7 @@ def data_collator(examples):
tokenizer=self.tokenizer
)
trainer.train()
print(f"Finetuning completed. Updated model saved to disk at: {output_dir}")
print(f"LoRA finetuning completed. Updated model saved to disk at: {output_dir}")

def generate_from_prompt(self, model, prompt, max_num_tokens=1024, **kwargs):
"""
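For reference, the new LoRA arguments are exposed directly on fine_tune, so existing callers keep working with the defaults (lora_r=8, lora_alpha=16, lora_dropout=0.1). A minimal usage sketch follows; the csmodel/csdata objects, the task name, and the TrainingArguments values are illustrative assumptions, not part of this diff.

# Illustrative sketch only: assumes `csmodel` (a CSModel) and `csdata` (a CSData) have already
# been constructed as elsewhere in cell2sentence. The task name and TrainingArguments values
# below are placeholders, not taken from this PR.
from transformers import TrainingArguments

train_args = TrainingArguments(
    output_dir="./c2s_lora_finetune",
    per_device_train_batch_size=4,
    num_train_epochs=1,
    learning_rate=2e-4,
    logging_steps=50,
)

csmodel.fine_tune(
    csdata=csdata,
    task="cell_type_prediction",   # placeholder; see supported tasks in prompt_formatter.py
    train_args=train_args,
    loss_on_response_only=True,
    top_k_genes=100,
    max_eval_samples=500,
    lora_r=8,          # new: LoRA rank
    lora_alpha=16,     # new: LoRA scaling factor
    lora_dropout=0.1,  # new: LoRA dropout
)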
19 changes: 10 additions & 9 deletions src/cell2sentence/utils.py
@@ -208,8 +208,8 @@ def benchmark_expression_conversion(
plot = (
pn.ggplot(benchmark_df, pn.aes(x=x, y=y))
+ pn.geom_abline(
slope=linear_reg.coef_,
intercept=linear_reg.intercept_,
slope=linear_reg.coef_[0], # Use the first element directly
intercept=linear_reg.intercept_, # Use directly
color="darkorange",
)
+ pn.geom_point(color="blue", size=0.2)
@@ -275,20 +275,21 @@ def benchmark_expression_conversion(
"x_axis": [x],
"y_axis": [y],
"threshold": [BASE10_THRESHOLD],
"slope": [linear_reg.coef_.item()],
"intercept": [linear_reg.intercept_.item()],
"r_squared": [r_squared_score.item()],
"pearson_r_statistic": [pearson_r_score.statistic.item()],
"pearson_r_pvalue": [pearson_r_score.pvalue.item()],
"spearman_r_statistic": [spearman_r_score.statistic.item()],
"spearman_r_pvalue": [spearman_r_score.pvalue.item()],
"slope": [linear_reg.coef_[0]], # Use the first element directly
"intercept": [linear_reg.intercept_], # Use directly
"r_squared": [r_squared_score], # Directly use the float
"pearson_r_statistic": [pearson_r_score[0]], # Use the first element directly
"pearson_r_pvalue": [pearson_r_score[1]], # Use the second element directly
"spearman_r_statistic": [spearman_r_score[0]], # Use the first element directly
"spearman_r_pvalue": [spearman_r_score[1]], # Use the second element directly
})

# save benchmarking results
benchmark_results_filepath = os.path.join(benchmark_save_dir, "c2s_transformation_metrics.csv")
result_df.to_csv(benchmark_results_filepath, index=False)



def build_arrow_dataset(
cell_names: list,
sentences: list,
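As a sanity check on the simplified metric extraction above (assuming the regression is fit on a single feature, as in the expression-conversion benchmark): sklearn's LinearRegression stores coef_ as a one-element array and intercept_ as a plain scalar, and the pearsonr/spearmanr results from scipy can be indexed as (statistic, pvalue) pairs, which is what the direct indexing relies on. A small standalone sketch with toy data:

# Toy single-feature example, not taken from this PR; it only illustrates the shapes/indexing used above.
import numpy as np
from scipy.stats import pearsonr, spearmanr
from sklearn.linear_model import LinearRegression

x = np.linspace(0.0, 4.0, 50)
y = 2.0 * x + 1.0 + np.random.default_rng(0).normal(scale=0.1, size=x.shape)

linear_reg = LinearRegression().fit(x.reshape(-1, 1), y)
print(linear_reg.coef_[0])     # coef_ has shape (1,) for one feature, so [0] is the scalar slope
print(linear_reg.intercept_)   # intercept_ is already a scalar

pearson_r_score = pearsonr(x, y)
spearman_r_score = spearmanr(x, y)
print(pearson_r_score[0], pearson_r_score[1])      # (statistic, pvalue) via tuple-style indexing
print(spearman_r_score[0], spearman_r_score[1])    # the spearmanr result supports the same indexing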