From a6ba19b03778243415f04b99dae38fe89be214e3 Mon Sep 17 00:00:00 2001
From: Wouter Saelens
Date: Thu, 23 May 2024 15:17:39 +0200
Subject: [PATCH] fix better model

---
 .../models/diff/interpret/slices.py         |   4 +-
 src/chromatinhd/models/pred/model/better.py | 138 +++++++++++++-----
 2 files changed, 105 insertions(+), 37 deletions(-)

diff --git a/src/chromatinhd/models/diff/interpret/slices.py b/src/chromatinhd/models/diff/interpret/slices.py
index c24e752d..4b9bed62 100644
--- a/src/chromatinhd/models/diff/interpret/slices.py
+++ b/src/chromatinhd/models/diff/interpret/slices.py
@@ -1,10 +1,10 @@
 import pandas as pd
 import numpy as np
+import torch
+import xarray as xr
 
 
 def filter_slices_probs(prob_cutoff=0.0):
-    import xarray as xr
-
     prob_cutoff = 0.0
     # prob_cutoff = -1.
     # prob_cutoff = -4.
diff --git a/src/chromatinhd/models/pred/model/better.py b/src/chromatinhd/models/pred/model/better.py
index d035f5b8..d0c85c7c 100644
--- a/src/chromatinhd/models/pred/model/better.py
+++ b/src/chromatinhd/models/pred/model/better.py
@@ -34,7 +34,13 @@
 
 from chromatinhd import get_default_device
 
-from .loss import paircor, paircor_loss, region_paircor_loss, pairzmse_loss, region_pairzmse_loss
+from .loss import (
+    paircor,
+    paircor_loss,
+    region_paircor_loss,
+    pairzmse_loss,
+    region_pairzmse_loss,
+)
 
 from typing import Any
 
@@ -108,23 +114,31 @@ def __init__(
             self.encoder = ExponentialEncoding(window=fragments.regions.window)
         elif encoder == "radial_binary":
             self.encoder = RadialBinaryEncoding(
-                n_frequencies=n_frequencies, window=fragments.regions.window, **encoder_kwargs
+                n_frequencies=n_frequencies,
+                window=fragments.regions.window,
+                **encoder_kwargs,
             )
         elif encoder == "radial_binary2":
             self.encoder = RadialBinaryEncoding2(window=fragments.regions.window, **encoder_kwargs)
         elif encoder == "radial_binary_center":
             self.encoder = RadialBinaryCenterEncoding(
-                n_frequencies=n_frequencies, window=fragments.regions.window, **encoder_kwargs
+                n_frequencies=n_frequencies,
+                window=fragments.regions.window,
+                **encoder_kwargs,
            )
         elif encoder == "linear_binary":
             self.encoder = LinearBinaryEncoding(
-                n_frequencies=n_frequencies, window=fragments.regions.window, **encoder_kwargs
+                n_frequencies=n_frequencies,
+                window=fragments.regions.window,
+                **encoder_kwargs,
             )
         elif encoder == "spline_binary":
             self.encoder = SplineBinaryEncoding(window=fragments.regions.window, **encoder_kwargs)
         elif encoder == "tophat_binary":
             self.encoder = TophatBinaryEncoding(
-                n_frequencies=n_frequencies, window=fragments.regions.window, **encoder_kwargs
+                n_frequencies=n_frequencies,
+                window=fragments.regions.window,
+                **encoder_kwargs,
             )
         elif encoder == "nothing":
             self.encoder = OneEncoding()
@@ -346,7 +360,11 @@ def __init__(
                 sublayers.append(nonlinear())
             if i > 0 and batchnorm:
                 sublayers.append(
-                    torch.nn.BatchNorm1d(self.n_embedding_dimensions, affine=False, track_running_stats=False)
+                    torch.nn.BatchNorm1d(
+                        self.n_embedding_dimensions,
+                        affine=False,
+                        track_running_stats=False,
+                    )
                 )
             if i > 0 and layernorm:
                 sublayers.append(torch.nn.LayerNorm((self.n_embedding_dimensions,)))
@@ -443,7 +461,7 @@ def create(
         fragment_embedder_kwargs=None,
         **kwargs: Any,
     ) -> None:
-        self = super(Model, cls).create(path=path, fragments=fragments, clustering=clustering, reset=overwrite)
+        self = super(Model, cls).create(path=path, fragments=fragments, transcriptome=transcriptome, reset=overwrite)
 
         self.fragments = fragments
         self.transcriptome = transcriptome
@@ -558,7 +576,8 @@ def forward_multiple(self, data, fragments_oi, min_fragments=1):
 
         if hasattr(self, "library_size_encoder"):
             cell_region_embedding = torch.cat(
-                [cell_region_embedding, self.library_size_encoder(data).unsqueeze(-2)], dim=-1
+                [cell_region_embedding, self.library_size_encoder(data).unsqueeze(-2)],
+                dim=-1,
             )
 
         total_expression_predicted = self.embedding_to_expression.forward(cell_region_embedding)
@@ -581,7 +600,11 @@ def forward_multiple(self, data, fragments_oi, min_fragments=1):
 
                 if hasattr(self, "library_size_encoder"):
                     cell_region_embedding = torch.cat(
-                        [cell_region_embedding, self.library_size_encoder(data).unsqueeze(-2)], dim=-1
+                        [
+                            cell_region_embedding,
+                            self.library_size_encoder(data).unsqueeze(-2),
+                        ],
+                        dim=-1,
                     )
 
                 expression_predicted = self.embedding_to_expression.forward(cell_region_embedding)
@@ -609,10 +632,13 @@ def forward_multiple2(self, data, fragments_oi, min_fragments=1):
 
         if hasattr(self, "library_size_encoder"):
             cell_region_embedding = torch.cat(
-                [cell_region_embedding, self.library_size_encoder(data).unsqueeze(-2)], dim=-1
+                [cell_region_embedding, self.library_size_encoder(data).unsqueeze(-2)],
+                dim=-1,
             )
 
-        total_expression_predicted = self.embedding_to_expression.forward(cell_region_embedding)
+        # total_expression_predicted = self.embedding_to_expression.forward(
+        #     cell_region_embedding
+        # )
 
         tot = 0.0
 
@@ -638,7 +664,11 @@ def forward_multiple2(self, data, fragments_oi, min_fragments=1):
 
                 if hasattr(self, "library_size_encoder"):
                     cell_region_embedding = torch.cat(
-                        [cell_region_embedding, self.library_size_encoder(data).unsqueeze(-2)], dim=-1
+                        [
+                            cell_region_embedding,
+                            self.library_size_encoder(data).unsqueeze(-2),
+                        ],
+                        dim=-1,
                     )
 
                 cell_region_embeddings.append(cell_region_embedding)
@@ -646,11 +676,16 @@ def forward_multiple2(self, data, fragments_oi, min_fragments=1):
                 # expression_predicted = self.embedding_to_expression.forward(cell_region_embedding)
                 # end = time.time()
                 # tot += end - start
+
             else:
                 cell_region_embedding = total_cell_region_embedding
                 if hasattr(self, "library_size_encoder"):
                     cell_region_embedding = torch.cat(
-                        [cell_region_embedding, self.library_size_encoder(data).unsqueeze(-2)], dim=-1
+                        [
+                            cell_region_embedding,
+                            self.library_size_encoder(data).unsqueeze(-2),
+                        ],
+                        dim=-1,
                     )
                 cell_region_embeddings.append(cell_region_embedding)
                 # n_fragments = total_n_fragments
@@ -861,12 +896,23 @@ def get_prediction(
             with torch.no_grad():
                 pred_mb = self.forward(data)
             predicted[
-                np.ix_(cell_mapping[data.minibatch.cells_oi], region_mapping[data.minibatch.regions_oi])
+                np.ix_(
+                    cell_mapping[data.minibatch.cells_oi],
+                    region_mapping[data.minibatch.regions_oi],
+                )
             ] = pred_mb.cpu().numpy()
             expected[
-                np.ix_(cell_mapping[data.minibatch.cells_oi], region_mapping[data.minibatch.regions_oi])
+                np.ix_(
+                    cell_mapping[data.minibatch.cells_oi],
+                    region_mapping[data.minibatch.regions_oi],
+                )
             ] = data.transcriptome.value.cpu().numpy()
-            n_fragments[np.ix_(cell_mapping[data.minibatch.cells_oi], region_mapping[data.minibatch.regions_oi])] = (
+            n_fragments[
+                np.ix_(
+                    cell_mapping[data.minibatch.cells_oi],
+                    region_mapping[data.minibatch.regions_oi],
+                )
+            ] = (
                 torch.bincount(
                     data.fragments.local_cellxregion_ix,
                     minlength=len(data.minibatch.cells_oi) * len(data.minibatch.regions_oi),
@@ -886,17 +932,26 @@ def get_prediction(
                 "predicted": xr.DataArray(
                     predicted,
                     dims=(fragments.obs.index.name, fragments.var.index.name),
-                    coords={fragments.obs.index.name: cells, fragments.var.index.name: regions},
+                    coords={
+                        fragments.obs.index.name: cells,
+                        fragments.var.index.name: regions,
+                    },
                 ),
                 "expected": xr.DataArray(
                     expected,
                     dims=(fragments.obs.index.name, fragments.var.index.name),
-                    coords={fragments.obs.index.name: cells, fragments.var.index.name: regions},
+                    coords={
+                        fragments.obs.index.name: cells,
+                        fragments.var.index.name: regions,
+                    },
                 ),
                 "n_fragments": xr.DataArray(
                     n_fragments,
                     dims=(fragments.obs.index.name, fragments.var.index.name),
-                    coords={fragments.obs.index.name: cells, fragments.var.index.name: regions},
+                    coords={
+                        fragments.obs.index.name: cells,
+                        fragments.var.index.name: regions,
+                    },
                 ),
             }
         )
@@ -975,9 +1030,12 @@ def get_prediction_censored(
             fragments_oi = censorer(data)
 
             with torch.no_grad():
-                for design_ix, (
-                    pred_mb,
-                    n_fragments_oi_mb,
+                for (
+                    design_ix,
+                    (
+                        pred_mb,
+                        n_fragments_oi_mb,
+                    ),
                 ) in enumerate(self.forward_multiple(data, fragments_oi, min_fragments=min_fragments)):
                     predicted.append(pred_mb)
                     n_fragments.append(n_fragments_oi_mb)
@@ -1052,9 +1110,12 @@ def get_performance_censored(
             fragments_oi = censorer(data)
 
             with torch.no_grad():
-                for design_ix, (
-                    pred_mb,
-                    n_fragments_oi_mb,
+                for (
+                    design_ix,
+                    (
+                        pred_mb,
+                        n_fragments_oi_mb,
+                    ),
                 ) in enumerate(self.forward_multiple(data, fragments_oi, min_fragments=min_fragments)):
                     if design_ix == 0:
                         cor_baseline = paircor(pred_mb, data.transcriptome.value).cpu().numpy()
@@ -1105,7 +1166,14 @@ def models_path(self):
         return path
 
     def train_models(
-        self, device=None, pbar=True, transcriptome=None, fragments=None, folds=None, regions_oi=None, **kwargs
+        self,
+        device=None,
+        pbar=True,
+        transcriptome=None,
+        fragments=None,
+        folds=None,
+        regions_oi=None,
+        **kwargs,
     ):
         if "device" in self.train_params and device is None:
             device = self.train_params["device"]
@@ -1146,7 +1214,9 @@ def train_models(
                     **self.model_params,
                 )
                 model.train_model(
-                    device=device, pbar=False, **{k: v for k, v in self.train_params.items() if k not in ["device"]}
+                    device=device,
+                    pbar=False,
+                    **{k: v for k, v in self.train_params.items() if k not in ["device"]},
                 )
                 model.save_state()
 
@@ -1171,10 +1241,6 @@ def __iter__(self):
             yield self[ix]
 
     def get_region_cors(self, fragments, transcriptome, folds, device=None):
-        cor_predicted = np.zeros((len(fragments.var.index), len(folds)))
-        cor_n_fragments = np.zeros((len(fragments.var.index), len(folds)))
-        n_fragments = np.zeros((len(fragments.var.index), len(folds)))
-
         regions_oi = fragments.var.index if self.regions_oi is None else self.regions_oi
 
         from itertools import product
@@ -1191,11 +1257,13 @@ def get_region_cors(self, fragments, transcriptome, folds, device=None):
             cors.append(
                 {
                     fragments.var.index.name: region_id,
-                    "cor": np.corrcoef(prediction["predicted"].values[:, 0], prediction["expected"].values[:, 0])[
-                        0, 1
-                    ],
+                    "cor": np.corrcoef(
+                        prediction["predicted"].values[:, 0],
+                        prediction["expected"].values[:, 0],
+                    )[0, 1],
                     "cor_n_fragments": np.corrcoef(
-                        prediction["n_fragments"].values[:, 0], prediction["expected"].values[:, 0]
+                        prediction["n_fragments"].values[:, 0],
+                        prediction["expected"].values[:, 0],
                     )[0, 1],
                 }
             )