Bad performance on gene value prediction #285

Open

Thodorissio opened this issue Jan 28, 2025 · 0 comments
Bad Performance Issue

Hello, and thank you for your great research. I would like your help with a problem I am facing. I am trying to fine-tune your pretrained foundation model for a specific disease, but after implementing my training code and starting training, I noticed that the training loss did not improve at all. I therefore tried running inference on some basic data from the CELLxGENE portal, and the results were confusing.
Here is the data sample I used for the reproducible example.
I also experimented with the following datasets, which yielded similar results:
second_dataset, third_dataset (B cell compartment data)

Is there a bug or a misconfiguration in my example, in the preprocessing, model loading, or inference, or is this the model's actual performance?

The most confusing point is that all of the output values hover around 30, even at positions that were not masked.
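As a quick way to quantify this, here is a small check on one batch from the reproducible example below (hypothetical snippet; output_dict and target_values are the variables defined there):

preds = output_dict["mlm_output"].float()
tgts = target_values.float()
# if the model has collapsed to a near-constant output, the std of preds
# will be small while the std of the targets stays large
print(f"pred mean/std:   {preds.mean():.2f} / {preds.std():.2f}")
print(f"target mean/std: {tgts.mean():.2f} / {tgts.std():.2f}")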

Key configurations of my example:

{
    "hvg": 1500,
    "mask_ratio": 0.15,
    "mask_value": -1,
    "pad_token": "<pad>",
    "pad_value": 0,
    "binning": 51,
    "d_hid": 512,
    "d_model": 512,
    "nhead": 8,
    "nlayers": 12,
    "dropout": 0,
    "do_mvc": true
    "filter_gene_by_counts": 3,
    "filter_cell_by_counts": 1
}
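
(For context: with "binning": 51, the binned targets are integer bin indices, which I believe fall in [0, 50], so predictions clustered around 30 sit mid-range regardless of the true value. A quick sanity check on the binned layer, assuming the preprocessing from the example below has run:)

binned = adata.layers["X_binned"]
print(binned.min(), binned.max())  # I expect 0 and 50 for binning = 51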

Reproducible Example

import torch
import logging
import scanpy as sc
import numpy as np

from tqdm import tqdm

from scgpt.tokenizer import GeneVocab, gene_tokenizer, random_mask_value
from scgpt.preprocess import Preprocessor
from scgpt.loss import masked_mse_loss
from scgpt.model import TransformerModel
from scgpt.utils.util import load_pretrained
from scgpt.trainer import prepare_dataloader

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

vocab_path = "/home/tsiozos/scgpt_tests/models/scGPT_human/vocab.json"
vocab = GeneVocab.from_file(vocab_path)
logging.basicConfig(level=logging.INFO)

# loading data
data_path = "/home/tsiozos/scgpt_tests/data/scGPT-data/strati_2023_bone_marrow.h5ad"

adata = sc.read_h5ad(data_path)
adata.var.index = adata.var.feature_name

logging.info(f"Loaded data with shape: {adata.shape}")

preprocessor = Preprocessor(
    use_key="X",
    result_normed_key="X_normed",
    result_log1p_key="X_log1p",
    binning=51,
    filter_gene_by_counts=3,
    filter_cell_by_counts=1,
)
preprocessor(adata)
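# note: subset_hvg was not passed above, so no HVG selection happened; the
# n_hvg = 1500 from the config only enters via max_len in the tokenization
# below (log1p is not enabled either, only result_log1p_key is set, which
# matches the preprocessing log in the output)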
logging.info("Preprocessed data!")

# loading model
pretrained_path = "/home/tsiozos/scgpt_tests/models/scGPT_human/best_model.pt"

model = TransformerModel(
    ntoken=len(vocab),
    vocab=vocab,
    d_hid=512,
    d_model=512,
    nhead=8,
    nlayers=12,
    dropout=0,
    do_mvc=True,
)

pretrained_params = torch.load(pretrained_path, map_location=DEVICE)
model = load_pretrained(model, pretrained_params, verbose=False)
model.to(DEVICE)
logging.info("Loaded pretrained model!")

# setting up gene ids
gene_ids = adata.var_names
gene_ids = gene_ids.tolist()

genes_in_vocab = sum([1 for gene in gene_ids if gene in vocab])
logging.info(f"Genes in vocab: {genes_in_vocab:,}/{len(gene_ids):,}")

gene_ids = vocab(gene_ids)
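# genes missing from the vocab (~6,000 here, per the count logged above) are
# mapped to the vocab's default index by this lookup (assuming
# GeneVocab.from_file sets one; otherwise the call above would raise on
# out-of-vocabulary tokens)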
gene_ids = np.array(gene_ids)

# preparing data

n_hvg = 1500
pad_token = "<pad>"
pad_value = 0
mask_value = -1

expression_values = adata.layers["X_binned"]

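# max_len = n_hvg + 1 leaves room for the <cls> token that
# tokenize_and_pad_batch prepends (assuming its append_cls default is True)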
tokenized_values = gene_tokenizer.tokenize_and_pad_batch(
    data=expression_values,
    gene_ids=gene_ids,
    max_len=n_hvg + 1,
    vocab=vocab,
    pad_token=pad_token,
    pad_value=pad_value,
)

gene_ids = tokenized_values["genes"]
expression_values = tokenized_values["values"]

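# mask_ratio is left at its default here (0.15, if I read the signature
# correctly), matching the mask_ratio in the key configuration above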
masked_values = random_mask_value(
    values=expression_values,
    mask_value=mask_value,
)

train_data = {
    "gene_ids": gene_ids,
    "values": masked_values,
    "target_values": expression_values,
}

train_loader = prepare_dataloader(
    data_pt=train_data,
    batch_size=8,
)

# predicting
logging.info("Starting Predictions...")
gep_losses = []
mvc_losses = []
losses = []
model.eval()
for train_batch in tqdm(train_loader):
    input_gene_ids = train_batch["gene_ids"].to(DEVICE)
    input_values = train_batch["values"].to(DEVICE)
    target_values = train_batch["target_values"].to(DEVICE)

    src_key_padding_mask = input_gene_ids.eq(vocab[pad_token])
    masked_positions = input_values.eq(mask_value)
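    # the masked MSE losses below are computed only over these masked positions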

    with torch.cuda.amp.autocast():
        output_dict = model(
            input_gene_ids,
            input_values,
            src_key_padding_mask=src_key_padding_mask,
            MVC=True,
            CLS=False,
            ECS=False,
        )
    loss_gep = masked_mse_loss(
        output_dict["mlm_output"], target_values, masked_positions
    )
    loss_mvc = masked_mse_loss(
        output_dict["mvc_output"], target_values, masked_positions
    )
    loss = loss_gep + loss_mvc

    gep_losses.append(loss_gep.item())
    mvc_losses.append(loss_mvc.item())
    losses.append(loss.item())

logging.info("Inference finished!")

# presenting result
average_gep_loss = sum(gep_losses) / len(gep_losses)
average_mvc_loss = sum(mvc_losses) / len(mvc_losses)
average_loss = sum(losses) / len(losses)

logging.info("Results:")
logging.info(f"Average GEP Loss: {average_gep_loss:.2f}")
logging.info(f"Average MVC Loss: {average_mvc_loss:.2f}")
logging.info(f"Average Loss: {average_loss:.2f}")

logging.info("Inspecting predictions...")
logging.info("Predictions sample:")
logging.info(output_dict["mlm_output"][0, :10])

logging.info("Target sample:")
logging.info(target_values[0, :10])

Output

/home/tsiozos/.conda/envs/scgpt-env/lib/python3.10/site-packages/scgpt/model/model.py:21: UserWarning: flash_attn is not installed
  warnings.warn("flash_attn is not installed")
/home/tsiozos/.conda/envs/scgpt-env/lib/python3.10/site-packages/scgpt/model/multiomic_model.py:19: UserWarning: flash_attn is not installed
  warnings.warn("flash_attn is not installed")
INFO:root:Loaded data with shape: (92676, 26832)
scGPT - INFO - Filtering genes by counts ...
scGPT - INFO - Filtering cells by counts ...
scGPT - INFO - Normalizing total counts ...
scGPT - INFO - Binning data ...
INFO:root:Preprocessed data!
INFO:root:Loaded pretrained model!
INFO:root:Genes in vocab: 19,325/25,515
INFO:root:Starting Predictions...
100%|██████████| 11585/11585 [04:02<00:00, 47.84it/s]
INFO:root:Inference finished!
INFO:root:Results:
INFO:root:Average GEP Loss: 226.49
INFO:root:Average MVC Loss: 154.32
INFO:root:Average Loss: 380.81
INFO:root:Inspecting predictions...
INFO:root:Predictions sample:
INFO:root:tensor([30.0469, 31.2031, 28.8281, 29.6719, 29.0938, 29.3125, 29.4531, 28.9844,
        29.5156, 32.6250], device='cuda:0', dtype=torch.float16,
       grad_fn=<SliceBackward0>)
INFO:root:Target sample:
INFO:root:tensor([ 0., 19., 16., 35., 28., 19.,  8.,  3., 13., 42.], device='cuda:0')
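
One detail visible above: the printed predictions carry grad_fn=<SliceBackward0>, meaning the loop builds autograd graphs during inference. I would not expect that to change the values, but a pure evaluation pass would normally wrap the forward call, e.g. (sketch of the same loop body; variable names as in the reproducible example):

with torch.no_grad(), torch.cuda.amp.autocast():
    output_dict = model(
        input_gene_ids,
        input_values,
        src_key_padding_mask=src_key_padding_mask,
        MVC=True,
        CLS=False,
        ECS=False,
    )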

Package Versions

python==3.10.13
scanpy==1.10.4
scgpt==0.2.1
numpy==1.26.4
torch==2.1.2
flash_attn==1.0.4 (although it fails to import)