
Commit a6ba19b

fix better model
zouter committed May 23, 2024
1 parent 7ca6abb commit a6ba19b
Showing 2 changed files with 105 additions and 37 deletions.
4 changes: 2 additions & 2 deletions src/chromatinhd/models/diff/interpret/slices.py
@@ -1,10 +1,10 @@
import pandas as pd
import numpy as np
import torch
import xarray as xr


def filter_slices_probs(prob_cutoff=0.0):
import xarray as xr

prob_cutoff = 0.0
# prob_cutoff = -1.
# prob_cutoff = -4.
138 changes: 103 additions & 35 deletions src/chromatinhd/models/pred/model/better.py
@@ -34,7 +34,13 @@

from chromatinhd import get_default_device

- from .loss import paircor, paircor_loss, region_paircor_loss, pairzmse_loss, region_pairzmse_loss
+ from .loss import (
+ paircor,
+ paircor_loss,
+ region_paircor_loss,
+ pairzmse_loss,
+ region_pairzmse_loss,
+ )

from typing import Any
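Note: the .loss helpers imported above are not part of this commit. As a hedged, standalone sketch only — assuming paircor computes a column-wise Pearson correlation (e.g. per region across cells), which this diff does not confirm — such a helper could look like:

    import torch

    def paircor_sketch(x, y, dim=0, eps=1e-8):
        # Pearson correlation between x and y, computed independently per column.
        # Illustrative sketch only; not the package's actual implementation.
        xc = x - x.mean(dim=dim, keepdim=True)
        yc = y - y.mean(dim=dim, keepdim=True)
        cov = (xc * yc).mean(dim=dim)
        denom = xc.pow(2).mean(dim=dim).sqrt() * yc.pow(2).mean(dim=dim).sqrt() + eps
        return cov / denom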

@@ -108,23 +114,31 @@ def __init__(
self.encoder = ExponentialEncoding(window=fragments.regions.window)
elif encoder == "radial_binary":
self.encoder = RadialBinaryEncoding(
- n_frequencies=n_frequencies, window=fragments.regions.window, **encoder_kwargs
+ n_frequencies=n_frequencies,
+ window=fragments.regions.window,
+ **encoder_kwargs,
)
elif encoder == "radial_binary2":
self.encoder = RadialBinaryEncoding2(window=fragments.regions.window, **encoder_kwargs)
elif encoder == "radial_binary_center":
self.encoder = RadialBinaryCenterEncoding(
- n_frequencies=n_frequencies, window=fragments.regions.window, **encoder_kwargs
+ n_frequencies=n_frequencies,
+ window=fragments.regions.window,
+ **encoder_kwargs,
)
elif encoder == "linear_binary":
self.encoder = LinearBinaryEncoding(
- n_frequencies=n_frequencies, window=fragments.regions.window, **encoder_kwargs
+ n_frequencies=n_frequencies,
+ window=fragments.regions.window,
+ **encoder_kwargs,
)
elif encoder == "spline_binary":
self.encoder = SplineBinaryEncoding(window=fragments.regions.window, **encoder_kwargs)
elif encoder == "tophat_binary":
self.encoder = TophatBinaryEncoding(
- n_frequencies=n_frequencies, window=fragments.regions.window, **encoder_kwargs
+ n_frequencies=n_frequencies,
+ window=fragments.regions.window,
+ **encoder_kwargs,
)
elif encoder == "nothing":
self.encoder = OneEncoding()
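Note: the constructor selects a positional encoder from the encoder string via an if/elif chain. Purely as an illustration of that dispatch pattern — with hypothetical stand-in classes, not the package's encoders — a table-driven equivalent would be:

    class RadialBinaryEncodingStub:
        def __init__(self, n_frequencies=10, window=None, **kwargs):
            self.n_frequencies, self.window = n_frequencies, window

    class OneEncodingStub:
        def __init__(self, **kwargs):
            pass

    ENCODERS = {
        "radial_binary": RadialBinaryEncodingStub,
        "nothing": OneEncodingStub,
    }

    def make_encoder(name, window, n_frequencies=10, **encoder_kwargs):
        # Same branching as the if/elif chain, expressed as a dictionary lookup.
        if name not in ENCODERS:
            raise ValueError(f"unknown encoder: {name!r}")
        return ENCODERS[name](n_frequencies=n_frequencies, window=window, **encoder_kwargs)

    encoder = make_encoder("radial_binary", window=(-10000, 10000), n_frequencies=20)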
@@ -346,7 +360,11 @@ def __init__(
sublayers.append(nonlinear())
if i > 0 and batchnorm:
sublayers.append(
- torch.nn.BatchNorm1d(self.n_embedding_dimensions, affine=False, track_running_stats=False)
+ torch.nn.BatchNorm1d(
+ self.n_embedding_dimensions,
+ affine=False,
+ track_running_stats=False,
+ )
)
if i > 0 and layernorm:
sublayers.append(torch.nn.LayerNorm((self.n_embedding_dimensions,)))
@@ -443,7 +461,7 @@ def create(
fragment_embedder_kwargs=None,
**kwargs: Any,
) -> None:
- self = super(Model, cls).create(path=path, fragments=fragments, clustering=clustering, reset=overwrite)
+ self = super(Model, cls).create(path=path, fragments=fragments, transcriptome=transcriptome, reset=overwrite)

self.fragments = fragments
self.transcriptome = transcriptome
@@ -558,7 +576,8 @@ def forward_multiple(self, data, fragments_oi, min_fragments=1):

if hasattr(self, "library_size_encoder"):
cell_region_embedding = torch.cat(
- [cell_region_embedding, self.library_size_encoder(data).unsqueeze(-2)], dim=-1
+ [cell_region_embedding, self.library_size_encoder(data).unsqueeze(-2)],
+ dim=-1,
)

total_expression_predicted = self.embedding_to_expression.forward(cell_region_embedding)
@@ -581,7 +600,11 @@ def forward_multiple(self, data, fragments_oi, min_fragments=1):

if hasattr(self, "library_size_encoder"):
cell_region_embedding = torch.cat(
- [cell_region_embedding, self.library_size_encoder(data).unsqueeze(-2)], dim=-1
+ [
+ cell_region_embedding,
+ self.library_size_encoder(data).unsqueeze(-2),
+ ],
+ dim=-1,
)

expression_predicted = self.embedding_to_expression.forward(cell_region_embedding)
@@ -609,10 +632,13 @@ def forward_multiple2(self, data, fragments_oi, min_fragments=1):

if hasattr(self, "library_size_encoder"):
cell_region_embedding = torch.cat(
- [cell_region_embedding, self.library_size_encoder(data).unsqueeze(-2)], dim=-1
+ [cell_region_embedding, self.library_size_encoder(data).unsqueeze(-2)],
+ dim=-1,
)

- total_expression_predicted = self.embedding_to_expression.forward(cell_region_embedding)
+ # total_expression_predicted = self.embedding_to_expression.forward(
+ #     cell_region_embedding
+ # )

tot = 0.0

@@ -638,19 +664,28 @@

if hasattr(self, "library_size_encoder"):
cell_region_embedding = torch.cat(
- [cell_region_embedding, self.library_size_encoder(data).unsqueeze(-2)], dim=-1
+ [
+ cell_region_embedding,
+ self.library_size_encoder(data).unsqueeze(-2),
+ ],
+ dim=-1,
)

cell_region_embeddings.append(cell_region_embedding)

# expression_predicted = self.embedding_to_expression.forward(cell_region_embedding)
# end = time.time()
# tot += end - start

else:
cell_region_embedding = total_cell_region_embedding
if hasattr(self, "library_size_encoder"):
cell_region_embedding = torch.cat(
- [cell_region_embedding, self.library_size_encoder(data).unsqueeze(-2)], dim=-1
+ [
+ cell_region_embedding,
+ self.library_size_encoder(data).unsqueeze(-2),
+ ],
+ dim=-1,
)
cell_region_embeddings.append(cell_region_embedding)
# n_fragments = total_n_fragments
@@ -861,12 +896,23 @@ def get_prediction(
with torch.no_grad():
pred_mb = self.forward(data)
predicted[
- np.ix_(cell_mapping[data.minibatch.cells_oi], region_mapping[data.minibatch.regions_oi])
+ np.ix_(
+ cell_mapping[data.minibatch.cells_oi],
+ region_mapping[data.minibatch.regions_oi],
+ )
] = pred_mb.cpu().numpy()
expected[
- np.ix_(cell_mapping[data.minibatch.cells_oi], region_mapping[data.minibatch.regions_oi])
+ np.ix_(
+ cell_mapping[data.minibatch.cells_oi],
+ region_mapping[data.minibatch.regions_oi],
+ )
] = data.transcriptome.value.cpu().numpy()
- n_fragments[np.ix_(cell_mapping[data.minibatch.cells_oi], region_mapping[data.minibatch.regions_oi])] = (
+ n_fragments[
+ np.ix_(
+ cell_mapping[data.minibatch.cells_oi],
+ region_mapping[data.minibatch.regions_oi],
+ )
+ ] = (
torch.bincount(
data.fragments.local_cellxregion_ix,
minlength=len(data.minibatch.cells_oi) * len(data.minibatch.regions_oi),
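Note: np.ix_ above scatters each minibatch's predictions into the full cells-by-regions matrices. A minimal standalone illustration of that indexing pattern (toy sizes, unrelated to the model):

    import numpy as np

    predicted = np.zeros((5, 4))
    cells_oi = np.array([0, 2, 3])    # rows covered by this minibatch
    regions_oi = np.array([1, 3])     # columns covered by this minibatch

    # np.ix_ builds an open mesh, so the assignment fills the full 3x2 block of
    # every (cell, region) combination rather than three paired elements.
    predicted[np.ix_(cells_oi, regions_oi)] = np.arange(6).reshape(3, 2)
    print(predicted)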
@@ -886,17 +932,26 @@
"predicted": xr.DataArray(
predicted,
dims=(fragments.obs.index.name, fragments.var.index.name),
- coords={fragments.obs.index.name: cells, fragments.var.index.name: regions},
+ coords={
+ fragments.obs.index.name: cells,
+ fragments.var.index.name: regions,
+ },
),
"expected": xr.DataArray(
expected,
dims=(fragments.obs.index.name, fragments.var.index.name),
- coords={fragments.obs.index.name: cells, fragments.var.index.name: regions},
+ coords={
+ fragments.obs.index.name: cells,
+ fragments.var.index.name: regions,
+ },
),
"n_fragments": xr.DataArray(
n_fragments,
dims=(fragments.obs.index.name, fragments.var.index.name),
- coords={fragments.obs.index.name: cells, fragments.var.index.name: regions},
+ coords={
+ fragments.obs.index.name: cells,
+ fragments.var.index.name: regions,
+ },
),
}
)
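Note: the predicted, expected and n_fragments arrays are wrapped into an xarray.Dataset that shares the cell and region coordinates. A toy version of that construction (made-up labels, same structure):

    import numpy as np
    import xarray as xr

    cells = ["cell_a", "cell_b"]
    regions = ["region_1", "region_2", "region_3"]
    coords = {"cell": cells, "region": regions}

    result = xr.Dataset(
        {
            "predicted": xr.DataArray(np.random.rand(2, 3), dims=("cell", "region"), coords=coords),
            "expected": xr.DataArray(np.random.rand(2, 3), dims=("cell", "region"), coords=coords),
            "n_fragments": xr.DataArray(np.zeros((2, 3), dtype=int), dims=("cell", "region"), coords=coords),
        }
    )
    # Variables can then be selected by label, e.g. result["predicted"].sel(cell="cell_a")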
@@ -975,9 +1030,12 @@ def get_prediction_censored(
fragments_oi = censorer(data)

with torch.no_grad():
- for design_ix, (
- pred_mb,
- n_fragments_oi_mb,
+ for (
+ design_ix,
+ (
+ pred_mb,
+ n_fragments_oi_mb,
+ ),
) in enumerate(self.forward_multiple(data, fragments_oi, min_fragments=min_fragments)):
predicted.append(pred_mb)
n_fragments.append(n_fragments_oi_mb)
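Note: forward_multiple yields one (prediction, n_fragments) pair per censoring design, which is why the loop unpacks a tuple inside enumerate. A tiny standalone sketch of that iteration pattern (the generator here is a stand-in, not the model's method):

    def forward_multiple_stub():
        # Stand-in generator: yields (prediction, n_fragments) pairs.
        for design_ix in range(3):
            yield f"pred_{design_ix}", design_ix * 10

    for design_ix, (pred_mb, n_fragments_oi_mb) in enumerate(forward_multiple_stub()):
        print(design_ix, pred_mb, n_fragments_oi_mb)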
@@ -1052,9 +1110,12 @@ def get_performance_censored(
fragments_oi = censorer(data)

with torch.no_grad():
- for design_ix, (
- pred_mb,
- n_fragments_oi_mb,
+ for (
+ design_ix,
+ (
+ pred_mb,
+ n_fragments_oi_mb,
+ ),
) in enumerate(self.forward_multiple(data, fragments_oi, min_fragments=min_fragments)):
if design_ix == 0:
cor_baseline = paircor(pred_mb, data.transcriptome.value).cpu().numpy()
@@ -1105,7 +1166,14 @@ def models_path(self):
return path

def train_models(
- self, device=None, pbar=True, transcriptome=None, fragments=None, folds=None, regions_oi=None, **kwargs
+ self,
+ device=None,
+ pbar=True,
+ transcriptome=None,
+ fragments=None,
+ folds=None,
+ regions_oi=None,
+ **kwargs,
):
if "device" in self.train_params and device is None:
device = self.train_params["device"]
@@ -1146,7 +1214,9 @@ def train_models(
**self.model_params,
)
model.train_model(
- device=device, pbar=False, **{k: v for k, v in self.train_params.items() if k not in ["device"]}
+ device=device,
+ pbar=False,
+ **{k: v for k, v in self.train_params.items() if k not in ["device"]},
)
model.save_state()

@@ -1171,10 +1241,6 @@ def __iter__(self):
yield self[ix]

def get_region_cors(self, fragments, transcriptome, folds, device=None):
- cor_predicted = np.zeros((len(fragments.var.index), len(folds)))
- cor_n_fragments = np.zeros((len(fragments.var.index), len(folds)))
- n_fragments = np.zeros((len(fragments.var.index), len(folds)))

regions_oi = fragments.var.index if self.regions_oi is None else self.regions_oi

from itertools import product
@@ -1191,11 +1257,13 @@ def get_region_cors(self, fragments, transcriptome, folds, device=None):
cors.append(
{
fragments.var.index.name: region_id,
- "cor": np.corrcoef(prediction["predicted"].values[:, 0], prediction["expected"].values[:, 0])[
- 0, 1
- ],
+ "cor": np.corrcoef(
+ prediction["predicted"].values[:, 0],
+ prediction["expected"].values[:, 0],
+ )[0, 1],
"cor_n_fragments": np.corrcoef(
- prediction["n_fragments"].values[:, 0], prediction["expected"].values[:, 0]
+ prediction["n_fragments"].values[:, 0],
+ prediction["expected"].values[:, 0],
)[0, 1],
}
)
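Note: np.corrcoef returns the full 2x2 correlation matrix, so the [0, 1] index picks out the single Pearson coefficient between the two vectors, for example:

    import numpy as np

    predicted = np.array([0.1, 0.4, 0.35, 0.8])
    expected = np.array([0.0, 0.5, 0.30, 0.9])

    # np.corrcoef returns [[1, r], [r, 1]]; the off-diagonal entry is the correlation.
    r = np.corrcoef(predicted, expected)[0, 1]
    print(round(r, 3))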
