diff --git a/src/cell2sentence/csmodel.py b/src/cell2sentence/csmodel.py
index c391177..a3ca4a4 100644
--- a/src/cell2sentence/csmodel.py
+++ b/src/cell2sentence/csmodel.py
@@ -18,6 +18,7 @@
 # Pytorch, Huggingface imports
 import torch
 from transformers import AutoTokenizer, AutoModelForCausalLM, Trainer, TrainingArguments
+from peft import LoraConfig, get_peft_model
 
 # Local imports
 from cell2sentence.prompt_formatter import PromptFormatter
@@ -79,23 +80,27 @@ def fine_tune(self,
         train_args: TrainingArguments,
         loss_on_response_only: bool = True,
         top_k_genes: int = 100,
-        max_eval_samples: int = 500
+        max_eval_samples: int = 500,
+        lora_r: int = 8,  # rank of the LoRA update matrices
+        lora_alpha: int = 16,  # scaling factor for LoRA
+        lora_dropout: float = 0.1  # dropout rate inside the LoRA layers
     ):
         """
-        Fine tune a model using the provided CSData object data
+        Fine-tune a model with LoRA on data from the provided CSData object.
 
         Arguments:
             csdata: a CSData object to be used as input for finetuning.
-                alternatively, data can be any generator of sequential
-                text that satisfies the same functional contract as
-                a CSData object
-            task: name of finetuning task (see supported tasks in prompt_formatter.py)
-            train_args: Huggingface Trainer arguments object
-            loss_on_response_only: whether to take loss only on model's answer
-            top_k_genes: number of genes to use for each cell sentence
-            max_eval_samples: number of samples to use for validation
+            task: name of finetuning task (see supported tasks in prompt_formatter.py).
+            train_args: Huggingface Trainer arguments object.
+            loss_on_response_only: whether to compute the loss only on the model's answer.
+            top_k_genes: number of genes to use for each cell sentence.
+            max_eval_samples: number of samples to use for validation.
+            lora_r: rank of the LoRA update matrices.
+            lora_alpha: scaling factor for the LoRA layers.
+            lora_dropout: dropout rate for the LoRA layers.
+
         Return:
-            None: an updated CSModel is generated in-place
+            None: an updated CSModel is generated in-place with LoRA fine-tuning.
         """
         # Load data from csdata object
         if csdata.dataset_backend == "arrow":
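For orientation, this is how the updated signature might be invoked. A minimal sketch; the csmodel and csdata objects, the task name, and the TrainingArguments values are illustrative assumptions, not part of this patch:

from transformers import TrainingArguments

# csmodel (a CSModel) and csdata (a CSData object) are assumed to exist already.
train_args = TrainingArguments(
    output_dir="./c2s_lora_run",
    per_device_train_batch_size=8,
    num_train_epochs=1,
)

csmodel.fine_tune(
    csdata=csdata,
    task="cell_type_prediction",  # hypothetical task name; see prompt_formatter.py
    train_args=train_args,
    loss_on_response_only=True,
    top_k_genes=100,
    lora_r=8,          # rank of the low-rank update matrices
    lora_alpha=16,     # the LoRA update is scaled by lora_alpha / lora_r
    lora_dropout=0.1,
)

Because get_peft_model freezes the base weights and marks only the adapter matrices as trainable, the defaults (r=8, alpha=16) keep the number of trained parameters a small fraction of the full model.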
""" # Load data from csdata object if csdata.dataset_backend == "arrow": @@ -107,7 +112,7 @@ def fine_tune(self, prompt_formatter = PromptFormatter(task=task, top_k_genes=top_k_genes) formatted_hf_ds = prompt_formatter.format_hf_ds(hf_ds) - # Load model + # Load the model and apply LoRA print("Reloading model from path on disk:", self.save_path) model = AutoModelForCausalLM.from_pretrained( self.save_path, @@ -116,12 +121,18 @@ def fine_tune(self, ) model = model.to(self.device) + # Apply LoRA to the model + lora_config = LoraConfig( + r=lora_r, + lora_alpha=lora_alpha, + target_modules=["q_proj", "v_proj"], # Specific target layers for LoRA in transformers + lora_dropout=lora_dropout, + bias="none" + ) + model = get_peft_model(model, lora_config) + # Tokenize data using LLM tokenizer - # - this function applies a lambda function to tokenize each dataset split in the DatasetDict - if loss_on_response_only: - tokenization_function = tokenize_loss_on_response - else: - tokenization_function = tokenize_all + tokenization_function = tokenize_loss_on_response if loss_on_response_only else tokenize_all formatted_hf_ds = formatted_hf_ds.map( lambda batch: tokenization_function(batch, self.tokenizer), batched=True, @@ -130,13 +141,11 @@ def fine_tune(self, batch_size=1000, ) - # Define parameters needed in data collator: - block_size = model.config.max_position_embeddings # maximum input sequence length possible - tokenizer = self.tokenizer # define tokenizer as variable here so it is accessible in dataloader + # Define parameters for data collator + block_size = model.config.max_position_embeddings + tokenizer = self.tokenizer + def data_collator(examples): - # Note: this data collator assumes we are not using flash attention, and pads samples - # to the max size in the batch. All sample lengths are capped at the size of the - # LLM's context). max_length = max(list(map(lambda x: len(x["input_ids"]), examples))) batch_input_ids, batch_attention_mask, batch_labels = [], [], [] for i in range(len(examples)): @@ -146,13 +155,13 @@ def data_collator(examples): assert len(sample_input_ids) == len(label_input_ids) == len(attention_mask) size_diff = max_length - len(sample_input_ids) - final_input_ids = [tokenizer.pad_token_id] * (size_diff) + sample_input_ids - final_attention_mask = [0] * (size_diff) + attention_mask - final_label_input_ids = [-100] * (size_diff) + label_input_ids + final_input_ids = [tokenizer.pad_token_id] * size_diff + sample_input_ids + final_attention_mask = [0] * size_diff + attention_mask + final_label_input_ids = [-100] * size_diff + label_input_ids - batch_input_ids.append(final_input_ids[: block_size]) - batch_attention_mask.append(final_attention_mask[: block_size]) - batch_labels.append(final_label_input_ids[: block_size]) + batch_input_ids.append(final_input_ids[:block_size]) + batch_attention_mask.append(final_attention_mask[:block_size]) + batch_labels.append(final_label_input_ids[:block_size]) return { "input_ids": torch.tensor(batch_input_ids), @@ -176,7 +185,7 @@ def data_collator(examples): np.save(os.path.join(output_dir, 'sampled_eval_indices.npy'), np.array(sampled_eval_indices, dtype=np.int64)) eval_dataset = eval_dataset.select(sampled_eval_indices) - # Define Trainer + # Define Trainer with LoRA model trainer = Trainer( model=model, args=train_args, @@ -186,7 +195,7 @@ def data_collator(examples): tokenizer=self.tokenizer ) trainer.train() - print(f"Finetuning completed. Updated model saved to disk at: {output_dir}") + print(f"LoRA finetuning completed. 
diff --git a/src/cell2sentence/utils.py b/src/cell2sentence/utils.py
index 627fdf7..598e691 100644
--- a/src/cell2sentence/utils.py
+++ b/src/cell2sentence/utils.py
@@ -208,8 +208,8 @@ def benchmark_expression_conversion(
     plot = (
         pn.ggplot(benchmark_df, pn.aes(x=x, y=y))
         + pn.geom_abline(
-            slope=linear_reg.coef_,
-            intercept=linear_reg.intercept_,
+            slope=linear_reg.coef_[0],  # coef_ is a 1-D array for a single target
+            intercept=linear_reg.intercept_,  # intercept_ is already a scalar
             color="darkorange",
         )
         + pn.geom_point(color="blue", size=0.2)
@@ -275,13 +275,13 @@
         "x_axis": [x],
         "y_axis": [y],
         "threshold": [BASE10_THRESHOLD],
-        "slope": [linear_reg.coef_.item()],
-        "intercept": [linear_reg.intercept_.item()],
-        "r_squared": [r_squared_score.item()],
-        "pearson_r_statistic": [pearson_r_score.statistic.item()],
-        "pearson_r_pvalue": [pearson_r_score.pvalue.item()],
-        "spearman_r_statistic": [spearman_r_score.statistic.item()],
-        "spearman_r_pvalue": [spearman_r_score.pvalue.item()],
+        "slope": [linear_reg.coef_[0]],  # first element of the 1-D coef_ array
+        "intercept": [linear_reg.intercept_],  # already a scalar
+        "r_squared": [r_squared_score],  # already a plain float
+        "pearson_r_statistic": [pearson_r_score[0]],  # result indexes as (statistic, pvalue)
+        "pearson_r_pvalue": [pearson_r_score[1]],
+        "spearman_r_statistic": [spearman_r_score[0]],  # result indexes as (statistic, pvalue)
+        "spearman_r_pvalue": [spearman_r_score[1]],
     })
 
     # save benchmarking results
@@ -289,6 +289,7 @@
     result_df.to_csv(benchmark_results_filepath, index=False)
 
 
+
 def build_arrow_dataset(
     cell_names: list,
     sentences: list,
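The utils.py changes replace .item() calls with direct indexing. A small sketch with made-up data, showing the shapes and result types those accesses rely on:

import numpy as np
from scipy.stats import pearsonr, spearmanr
from sklearn.linear_model import LinearRegression

x = np.arange(10, dtype=float).reshape(-1, 1)
y = 2.0 * x.ravel() + 1.0

linear_reg = LinearRegression().fit(x, y)
print(linear_reg.coef_.shape)        # (1,): 1-D for a single target, hence coef_[0]
print(float(linear_reg.intercept_))  # intercept_ is already a scalar for 1-D targets

# Both correlation results index as (statistic, pvalue), which is what
# pearson_r_score[0] / pearson_r_score[1] in the patch rely on.
pearson_r_score = pearsonr(x.ravel(), y)
spearman_r_score = spearmanr(x.ravel(), y)
print(pearson_r_score[0], pearson_r_score[1])
print(spearman_r_score[0], spearman_r_score[1])

This also sidesteps a failure mode of .item(): a plain Python float (for example an r_squared value computed as a float) has no .item() method, so indexing or direct use is the safer access pattern here.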