feat: optimize diffexp (#7056)
atarashansky authored May 21, 2024
1 parent 8443107 commit 94c2cbe
Showing 36 changed files with 485 additions and 546 deletions.
47 changes: 5 additions & 42 deletions backend/de/api/v1.py
@@ -19,10 +19,7 @@
)
from backend.wmg.data.ontology_labels import gene_term_label, ontology_term_label
from backend.wmg.data.query import DeQueryCriteria, WmgFiltersQueryCriteria, WmgQuery
-from backend.wmg.data.schemas.expression_summary_cube_schemas_diffexp import (
-    base_expression_summary_indexed_dims,
-    expression_summary_secondary_dims,
-)
+from backend.wmg.data.schemas.cube_schema_diffexp import cell_counts_logical_dims_exclude_dataset_id
from backend.wmg.data.snapshot import WmgSnapshot, load_snapshot

DEPLOYMENT_STAGE = os.environ.get("DEPLOYMENT_STAGE", "")
@@ -242,57 +239,23 @@ def run_differential_expression(q: WmgQuery, criteria1, criteria2) -> Tuple[List
        set(sum([descendants(i) for i in criteria2.cell_type_ontology_term_ids], []))
    )

-    cell_counts1 = q.cell_counts_df(criteria1)
-    cell_counts2 = q.cell_counts_df(criteria2)
+    es1, cell_counts1 = q.expression_summary_and_cell_counts_diffexp(criteria1)
+    es2, cell_counts2 = q.expression_summary_and_cell_counts_diffexp(criteria2)

    n_cells1 = cell_counts1["n_total_cells"].sum()
    n_cells2 = cell_counts2["n_total_cells"].sum()
-    es1 = q.expression_summary_diffexp(criteria1)
-    es2 = q.expression_summary_diffexp(criteria2)

    # identify number of overlapping populations
    filter_columns = [
        col
-        for col in (base_expression_summary_indexed_dims + expression_summary_secondary_dims)
+        for col in cell_counts_logical_dims_exclude_dataset_id
        if col in cell_counts1.columns and col in cell_counts2.columns
    ]
    index1 = cell_counts1.set_index(filter_columns).index
    index2 = cell_counts2.set_index(filter_columns).index
    overlap_filter = index1.isin(index2)
    n_overlap = int(cell_counts1[overlap_filter]["n_total_cells"].sum())

-    """ (alec): If the specific overlapping populations are ever required, refer to this code
-    unique_overlap_indices = index1[overlap_filter].unique()
-    num_values_per_level = {
-        name: unique_overlap_indices.get_level_values(name).unique().size for name in unique_overlap_indices.names
-    }
-    overlap = []
-    for dims in [dict(zip(unique_overlap_indices.names, idx, strict=False)) for idx in unique_overlap_indices]:
-        population = {}
-        if "disease_ontology_term_id" in dims and num_values_per_level["disease_ontology_term_id"] > 1:
-            population["disease_terms"] = ontology_term_id_label_mapping(dims["disease_ontology_term_id"])
-        if "sex_ontology_term_id" in dims and num_values_per_level["sex_ontology_term_id"] > 1:
-            population["sex_terms"] = ontology_term_id_label_mapping(dims["sex_ontology_term_id"])
-        if (
-            "self_reported_ethnicity_ontology_term_id" in dims
-            and "," not in dims["self_reported_ethnicity_ontology_term_id"]
-            and num_values_per_level["self_reported_ethnicity_ontology_term_id"] > 1
-        ):
-            population["self_reported_ethnicity_terms"] = ontology_term_id_label_mapping(
-                dims["self_reported_ethnicity_ontology_term_id"]
-            )
-        if "publication_citation" in dims and num_values_per_level["publication_citation"] > 1:
-            population["publication_citations"] = dims["publication_citation"]
-        if "cell_type_ontology_term_id" in dims and num_values_per_level["cell_type_ontology_term_id"] > 1:
-            population["cell_type_terms"] = ontology_term_id_label_mapping(dims["cell_type_ontology_term_id"])
-        if "tissue_ontology_term_id" in dims and num_values_per_level["tissue_ontology_term_id"] > 1:
-            population["tissue_terms"] = ontology_term_id_label_mapping(dims["tissue_ontology_term_id"])
-        if "organism_ontology_term_id" in dims and num_values_per_level["organism_ontology_term_id"] > 1:
-            population["organism_terms"] = ontology_term_id_label_mapping(dims["organism_ontology_term_id"])
-        overlap.append(population)
-    """

    es_agg1 = es1.groupby("gene_ontology_term_id").sum(numeric_only=True)
    es_agg2 = es2.groupby("gene_ontology_term_id").sum(numeric_only=True)

@@ -340,7 +303,7 @@ def _get_cell_counts_for_query(q: WmgQuery, criteria: WmgFiltersQueryCriteria) -
    criteria.cell_type_ontology_term_ids = list(
        set(sum([descendants(i) for i in criteria.cell_type_ontology_term_ids], []))
    )
-    cell_counts = q.cell_counts_df(criteria)
+    cell_counts = q.cell_counts_diffexp_df(criteria)
    return int(cell_counts["n_total_cells"].sum())


2 changes: 1 addition & 1 deletion backend/wmg/api/wmg_api_config.py
@@ -16,7 +16,7 @@
# WMG_API_FORCE_LOAD_SNAPSHOT_ID should be set.
# LATEST_READER_SNAPSHOT_SCHEMA_VERSION in container_init.sh and WMG_API_SNAPSHOT_SCHEMA_VERSION
# below should have the same value.
-WMG_API_SNAPSHOT_SCHEMA_VERSION = "v4"
+WMG_API_SNAPSHOT_SCHEMA_VERSION = "v5"

# In the case we need to rollback or rollforward
# set this variable to a specific snapshot id
80 changes: 39 additions & 41 deletions backend/wmg/data/query.py
@@ -7,6 +7,7 @@
from pydantic import BaseModel, Field
from tiledb import Array

+from backend.wmg.data.schemas.cube_schema_diffexp import cell_counts_indexed_dims
from backend.wmg.data.snapshot import WmgSnapshot


@@ -94,13 +95,6 @@ def __init__(self, snapshot: WmgSnapshot, cube_query_params: Optional[WmgCubeQue
        self._snapshot = snapshot
        self._cube_query_params = cube_query_params

-    @tracer.wrap(name="expression_summary_diffexp", service="de-api", resource="_query", span_type="de-api")
-    def expression_summary_diffexp(self, criteria: DeQueryCriteria) -> DataFrame:
-        return self._query(
-            cube=_select_cube_with_best_discriminatory_power(self._snapshot, criteria),
-            criteria=criteria,
-        )
-
    @tracer.wrap(name="expression_summary", service="wmg-api", resource="_query", span_type="wmg-api")
    def expression_summary(self, criteria: WmgQueryCriteria, compare_dimension=None) -> DataFrame:
        return self._query(
@@ -143,6 +137,44 @@ def cell_counts_df(self, criteria: WmgQueryCriteria) -> DataFrame:
                mask &= df[key].isin(values)
        return df[mask].rename(columns={"n_cells": "n_total_cells"})

+    def cell_counts_diffexp_df(self, criteria: DeQueryCriteria) -> DataFrame:
+        df = self._snapshot.cell_counts_diffexp_df
+        mask = np.array([True] * len(df))
+        for key, values in dict(criteria).items():
+            values = values if isinstance(values, list) else [values]
+            key = depluralize(key)
+            if key in df.columns and values:
+                mask &= df[key].isin(values)
+
+        return df[mask].rename(columns={"n_cells": "n_total_cells"})
+
+    @tracer.wrap(
+        name="expression_summary_and_cell_counts_diffexp", service="de-api", resource="_query", span_type="de-api"
+    )
+    def expression_summary_and_cell_counts_diffexp(self, criteria: DeQueryCriteria) -> tuple[DataFrame, DataFrame]:
+        use_simple = not any(
+            depluralize(key) not in cell_counts_indexed_dims and values for key, values in dict(criteria).items()
+        )
+
+        cell_counts_diffexp_df = self.cell_counts_diffexp_df(criteria)
+        key = "group_id_simple" if use_simple else "group_id"
+        cube = (
+            self._snapshot.expression_summary_diffexp_simple_cube
+            if use_simple
+            else self._snapshot.expression_summary_diffexp_cube
+        )
+        group_ids = cell_counts_diffexp_df[key].unique().tolist()
+        return (
+            pd.concat(
+                cube.query(
+                    return_incomplete=True,
+                    use_arrow=True,
+                    dims=["group_id"],
+                ).df[group_ids]
+            ),
+            cell_counts_diffexp_df,
+        )
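A hedged usage sketch of the new combined query path (not part of the diff): `q` is assumed to be an already-constructed WmgQuery over a v5 snapshot, and the DeQueryCriteria field names are assumed to be the pluralized column names. Because only indexed dimensions are constrained here, the simple cube and group_id_simple would be used.

# Hedged usage sketch, not from the commit; field names below are assumed.
criteria = DeQueryCriteria(
    organism_ontology_term_ids=["NCBITaxon:9606"],   # indexed dim (assumed field name)
    tissue_ontology_term_ids=["UBERON:0002048"],     # indexed dim (assumed field name)
    cell_type_ontology_term_ids=["CL:0000236"],      # indexed dim
)
# only indexed dimensions are constrained, so use_simple is True
es, cell_counts = q.expression_summary_and_cell_counts_diffexp(criteria)
n_cells = cell_counts["n_total_cells"].sum()
es_agg = es.groupby("gene_ontology_term_id").sum(numeric_only=True)  # per-gene sum / sqsum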

    # TODO: refactor for readability: https://app.zenhub.com/workspaces/single-cell-5e2a191dad828d52cc78b028/issues
    # /chanzuckerberg/single-cell-data-portal/2133
    def _query(
@@ -268,37 +300,3 @@ def retrieve_top_n_markers(query_result, test, n_markers):
    markers = markers.sort_values("marker_score", ascending=False)
    records = markers[attrs].to_dict(orient="records")
    return records
-
-
-def _select_cube_with_best_discriminatory_power(snapshot: WmgSnapshot, criteria: DeQueryCriteria) -> Array:
-    """
-    Selects the cube with the best discriminatory power based on the given criteria.
-
-    This function evaluates each dimension's discriminatory power by comparing the number
-    of criteria specified for that dimension against its total cardinality within the snapshot.
-    It then selects the cube that maximizes this discriminatory power. If no dimension meets the
-    criteria or if the discriminatory power cannot be determined, the default cube is selected.
-
-    Parameters
-    ----------
-    snapshot : WmgSnapshot
-        The snapshot object containing all cubes and their metadata.
-    criteria : DeQueryCriteria
-        The criteria object containing dimensions and their corresponding values to filter on.
-
-    Returns
-    -------
-    Array
-        The cube with the best discriminatory power based on the given criteria.
-    """
-    cardinality_per_dimension = snapshot.cardinality_per_dimension
-    criteria_dict = criteria.dict()
-    base_indexed_dims = [dim.name for dim in snapshot.diffexp_expression_summary_cubes["default"].schema.domain]
-    discriminatory_power = {
-        depluralize(dim): len(criteria_dict[dim]) / cardinality_per_dimension[depluralize(dim)]
-        for dim in criteria_dict
-        if len(criteria_dict[dim]) > 0 and depluralize(dim) not in base_indexed_dims
-    }
-    use_default = len(discriminatory_power) == 0
-    cube_key = "default" if use_default else min(discriminatory_power, key=discriminatory_power.get)
-    return snapshot.diffexp_expression_summary_cubes[cube_key]
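For contrast with the group_id lookup that replaces it, a toy worked example (invented numbers, not from the repository) of the criteria-to-cardinality ratio the deleted helper computed; the dimension with the smallest ratio selected the cube, and `depluralize` is assumed to strip a trailing "s" from plural criteria field names.

# Toy illustration of the removed selection rule.
cardinality_per_dimension = {"disease_ontology_term_id": 100, "sex_ontology_term_id": 3}
criteria_dict = {"disease_ontology_term_ids": ["MONDO:0004992"], "sex_ontology_term_ids": ["PATO:0000383"]}

def depluralize(name: str) -> str:
    # assumed behavior: strip the trailing "s"
    return name[:-1]

discriminatory_power = {
    depluralize(dim): len(values) / cardinality_per_dimension[depluralize(dim)]
    for dim, values in criteria_dict.items()
    if len(values) > 0
}
# {"disease_ontology_term_id": 0.01, "sex_ontology_term_id": 0.333...}
cube_key = min(discriminatory_power, key=discriminatory_power.get)  # "disease_ontology_term_id"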
107 changes: 107 additions & 0 deletions backend/wmg/data/schemas/cube_schema_diffexp.py
@@ -0,0 +1,107 @@
import numpy as np
import tiledb

from backend.wmg.data.schemas.tiledb_filters import filters_categorical, filters_numeric

expression_summary_indexed_dims = [
    "group_id",
]


expression_summary_non_indexed_dims = [
    "gene_ontology_term_id",
]

# The full set of logical cube dimensions by which the cube can be queried.
expression_summary_logical_dims = expression_summary_indexed_dims + expression_summary_non_indexed_dims


expression_summary_domain = tiledb.Domain(
    [
        tiledb.Dim(
            name=cube_indexed_dim,
            domain=(0, np.iinfo(np.uint32).max - 1),
            tile=None,
            dtype=np.uint32,
            filters=filters_numeric,
        )
        for cube_indexed_dim in expression_summary_indexed_dims
    ]
)
# The cube attributes that comprise the core data stored within the cube.
expression_summary_logical_attrs = [
    tiledb.Attr(name="sum", dtype=np.float32, filters=filters_numeric),
    tiledb.Attr(name="sqsum", dtype=np.float32, filters=filters_numeric),
]

# The TileDB `Attr`s of the cube TileDB Array. This includes the
# logical cube attributes, above, along with the non-indexed logical
# cube dimensions, which we model as TileDB `Attr`s.
expression_summary_physical_attrs = [
    tiledb.Attr(name=nonindexed_dim, dtype="ascii", var=True, filters=filters_categorical)
    for nonindexed_dim in expression_summary_non_indexed_dims
] + expression_summary_logical_attrs

expression_summary_schema = tiledb.ArraySchema(
    domain=expression_summary_domain,
    sparse=True,
    allows_duplicates=True,
    attrs=expression_summary_physical_attrs,
    cell_order="row-major",
    tile_order="row-major",
    capacity=10000,
)

# Cell Counts Array

cell_counts_indexed_dims = [
    "cell_type_ontology_term_id",
    "tissue_ontology_term_id",
    "organism_ontology_term_id",
]

cell_counts_non_indexed_dims_excluding_dataset_id = [
    "publication_citation",
    "disease_ontology_term_id",
    "self_reported_ethnicity_ontology_term_id",
    "sex_ontology_term_id",
]
cell_counts_non_indexed_dims = cell_counts_non_indexed_dims_excluding_dataset_id + ["dataset_id"]

cell_counts_logical_dims = cell_counts_indexed_dims + cell_counts_non_indexed_dims

cell_counts_logical_dims_exclude_dataset_id = (
    cell_counts_indexed_dims + cell_counts_non_indexed_dims_excluding_dataset_id
)

cell_counts_domain = tiledb.Domain(
    [
        tiledb.Dim(name=cell_counts_indexed_dim, domain=None, tile=None, dtype="ascii", filters=filters_categorical)
        for cell_counts_indexed_dim in cell_counts_indexed_dims
    ]
)

cell_counts_logical_attrs = [
    # total count of cells, regardless of expression level
    tiledb.Attr(name="n_cells", dtype=np.uint32, filters=filters_numeric),
    # groups corresponding to dimensions in cell_counts_logical_dims_exclude_dataset_id
    tiledb.Attr(name="group_id", dtype=np.uint32, filters=filters_numeric),
    # groups corresponding to dimensions in cell_counts_indexed_dims
    tiledb.Attr(name="group_id_simple", dtype=np.uint32, filters=filters_numeric),
]

cell_counts_physical_attrs = [
    tiledb.Attr(name=nonindexed_dim, dtype="ascii", var=True, filters=filters_categorical)
    for nonindexed_dim in cell_counts_non_indexed_dims
] + cell_counts_logical_attrs


cell_counts_schema = tiledb.ArraySchema(
    domain=cell_counts_domain,
    sparse=True,
    allows_duplicates=True,
    attrs=cell_counts_physical_attrs,
    cell_order="row-major",
    tile_order="row-major",
    capacity=10000,
)
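A minimal sketch (not part of the commit) of how one of these schemas could be materialized and queried with tiledb-py; the URI, group id, gene ids, and attribute values are placeholders.

import numpy as np
import tiledb

from backend.wmg.data.schemas.cube_schema_diffexp import expression_summary_schema

uri = "/tmp/expression_summary_diffexp"  # placeholder location
tiledb.Array.create(uri, expression_summary_schema)

# write two cells for group_id 0, one per gene (toy values)
with tiledb.open(uri, mode="w") as arr:
    arr[np.array([0, 0], dtype=np.uint32)] = {
        "gene_ontology_term_id": np.array(["ENSG00000141510", "ENSG00000139618"], dtype=object),
        "sum": np.array([12.5, 3.25], dtype=np.float32),
        "sqsum": np.array([40.0, 3.0], dtype=np.float32),
    }

# read back by group id, mirroring how the DE query selects rows
with tiledb.open(uri) as arr:
    df = arr.query(use_arrow=True, dims=["group_id"]).df[[0]]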

