-
Notifications
You must be signed in to change notification settings - Fork 29
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge remote-tracking branch 'origin/master'
- Loading branch information
Showing
2 changed files
with
201 additions
and
0 deletions.
There are no files selected for viewing
143 changes: 143 additions & 0 deletions
143
immuneML/reports/data_reports/SequenceCountDistribution.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,143 @@ | ||
from collections import Counter | ||
from pathlib import Path | ||
|
||
import pandas as pd | ||
import plotly.express as px | ||
|
||
from immuneML.data_model.dataset.Dataset import Dataset | ||
from immuneML.data_model.dataset.ReceptorDataset import ReceptorDataset | ||
from immuneML.data_model.dataset.RepertoireDataset import RepertoireDataset | ||
from immuneML.data_model.dataset.SequenceDataset import SequenceDataset | ||
from immuneML.reports.ReportOutput import ReportOutput | ||
from immuneML.reports.ReportResult import ReportResult | ||
from immuneML.reports.ReportUtil import ReportUtil | ||
from immuneML.reports.data_reports.DataReport import DataReport | ||
from immuneML.util.PathBuilder import PathBuilder | ||
|
||
|
||
class SequenceCountDistribution(DataReport):
    """
    Generates a histogram of the duplicate counts of the sequences in a dataset.

    Specification arguments:

    - split_by_label (bool): Whether to split the plots by a label. If set to true, the Dataset must either contain a single label, or alternatively the label of interest can be specified under 'label'. By default, split_by_label is False.

    - label (str): Optional label for separating the results by color/creating separate plots. Note that this should be the name of a valid dataset label.

    YAML specification:

    .. indent with spaces
    .. code-block:: yaml

        my_sld_report:
            SequenceCountDistribution:
                label: disease
    """

    @classmethod
    def build_object(cls, **kwargs):
        """Create the report from specification kwargs after normalising the split_by_label/label arguments."""
        location = SequenceCountDistribution.__name__

        ReportUtil.update_split_by_label_kwargs(kwargs, location)

        return SequenceCountDistribution(**kwargs)

    def __init__(self, dataset: Dataset = None, result_path: Path = None, number_of_processes: int = 1,
                 split_by_label: bool = None, label: str = None, name: str = None):
        super().__init__(dataset=dataset, result_path=result_path, number_of_processes=number_of_processes, name=name)
        self.split_by_label = split_by_label
        self.label_name = label

    def _set_label_name(self):
        """
        Resolve which label (if any) the results are split by.

        When splitting is requested without an explicit label, the first label name reported by the dataset
        is used; when splitting is off, any label that was passed in is discarded.
        """
        if self.split_by_label:
            if self.label_name is None:
                self.label_name = list(self.dataset.get_label_names())[0]
        else:
            self.label_name = None

    def check_prerequisites(self):
        return True

    def _generate(self) -> ReportResult:
        """Compute the duplicate count table, write it as TSV and plot it as a bar chart."""
        self._set_label_name()

        df = self._get_sequence_counts_df()
        PathBuilder.build(self.result_path)

        output_table = self._write_output_table(df, self.result_path / "sequence_count_distribution.tsv",
                                                name="Duplicate counts of sequences in the dataset")

        # _safe_plot may hand back None; in that case the result carries no figures, only the table
        report_output_fig = self._safe_plot(df=df, output_written=False)
        output_figures = None if report_output_fig is None else [report_output_fig]
        return ReportResult(name=self.name,
                            info="The sequence count distribution of the dataset.",
                            output_figures=output_figures, output_tables=[output_table])

    def _get_sequence_counts_df(self):
        """
        Dispatch to the counting routine matching the dataset type.

        Raises:
            TypeError: if the dataset is not a RepertoireDataset, ReceptorDataset or SequenceDataset.
        """
        if isinstance(self.dataset, RepertoireDataset):
            return self._get_repertoire_df()
        if isinstance(self.dataset, (ReceptorDataset, SequenceDataset)):
            return self._get_sequence_receptor_df()

        # fail loudly here instead of returning None and breaking later when the table is written
        raise TypeError(f"{SequenceCountDistribution.__name__}: unsupported dataset type "
                        f"'{type(self.dataset).__name__}'.")

    def _get_repertoire_df(self):
        """
        Count, across all repertoires, how often each duplicate count value occurs.

        Returns:
            a DataFrame with columns 'n_observations' and 'duplicate_count', plus one column named after the
            label when split_by_label is set.
        """
        sequence_counts = Counter()

        for repertoire in self.dataset.get_data():
            label_class = repertoire.metadata[self.label_name] if self.split_by_label else None

            repertoire_counter = Counter(repertoire.get_attribute("duplicate_count"))
            # key on (duplicate count, label class) so repertoires of the same class are pooled together
            sequence_counts.update({(duplicate_count, label_class): n_observations
                                    for duplicate_count, n_observations in repertoire_counter.items()})

        keys = list(sequence_counts.keys())

        df = pd.DataFrame({"n_observations": list(sequence_counts.values()),
                           "duplicate_count": [key[0] for key in keys]})

        if self.split_by_label:
            df[self.label_name] = [key[1] for key in keys]

        return df

    def _get_sequence_receptor_df(self):
        """
        Count how often each (duplicate count, chain[, label class]) combination occurs in the dataset.

        Returns:
            a DataFrame with columns 'duplicate_count', 'chain' and 'n_observations', plus one column named
            after the label when split_by_label is set.

        Raises:
            AttributeError: if the dataset does not provide the 'duplicate_count' attribute.
        """
        try:
            counts = self.dataset.get_attribute("duplicate_count")
        except AttributeError as e:
            # name the actual dataset class (this method serves receptor and sequence datasets alike)
            # and chain the original error so the root cause stays visible in the traceback
            raise AttributeError(f"{SequenceCountDistribution.__name__}: {type(self.dataset).__name__} does not "
                                 f"contain attribute 'duplicate_count'. This report can only be run when "
                                 f"sequence counts are available.") from e

        chains = self.dataset.get_attribute(attribute="chain", as_list=True)

        if self.split_by_label:
            label_classes = self.dataset.get_attribute(attribute=self.label_name, as_list=True)
            counter = Counter(zip(counts, chains, label_classes))
        else:
            counter = Counter(zip(counts, chains))

        keys = list(counter.keys())

        df = pd.DataFrame({"duplicate_count": [key[0] for key in keys],
                           "chain": [key[1] for key in keys],
                           "n_observations": list(counter.values())})

        if self.split_by_label:
            df[self.label_name] = [key[2] for key in keys]

        return df

    def _plot(self, df: pd.DataFrame) -> ReportOutput:
        """Draw the grouped bar chart of the counts and store it as an HTML file under result_path."""
        figure = px.bar(df, x="duplicate_count", y="n_observations", barmode="group",
                        color=self.label_name if self.split_by_label else None,
                        facet_col="chain" if isinstance(self.dataset, ReceptorDataset) else None,
                        color_discrete_sequence=px.colors.diverging.Tealrose,
                        labels={"n_observations": "Number of observations",
                                "duplicate_count": "Sequence duplicate count"})
        figure.update_layout(template="plotly_white")
        # treat duplicate counts as discrete categories rather than a continuous numeric axis
        figure.update_xaxes(row=1, type="category")
        PathBuilder.build(self.result_path)

        file_path = self.result_path / "sequence_count_distribution.html"
        figure.write_html(str(file_path))
        return ReportOutput(path=file_path, name="Sequence duplicate count distribution")
|
58 changes: 58 additions & 0 deletions
58
test/reports/data_reports/test_sequenceCountDistribution.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,58 @@ | ||
import os | ||
import shutil | ||
from unittest import TestCase | ||
|
||
from immuneML.caching.CacheType import CacheType | ||
from immuneML.environment.Constants import Constants | ||
from immuneML.environment.EnvironmentSettings import EnvironmentSettings | ||
from immuneML.reports.data_reports.SequenceCountDistribution import SequenceCountDistribution | ||
from immuneML.simulation.dataset_generation.RandomDatasetGenerator import RandomDatasetGenerator | ||
from immuneML.util.PathBuilder import PathBuilder | ||
|
||
|
||
class TestSequenceCountDistribution(TestCase):
    """Smoke tests: the report has to write a figure file for sequence, receptor and repertoire datasets."""

    def setUp(self) -> None:
        os.environ[Constants.CACHE_TYPE] = CacheType.TEST.name

    def _build_tmp_path(self, dir_name: str):
        """Create a fresh temporary directory for one test and schedule its removal."""
        path = PathBuilder.remove_old_and_build(EnvironmentSettings.tmp_test_path / dir_name)
        # addCleanup runs even when the test fails, so no temporary files are left behind
        self.addCleanup(shutil.rmtree, path, ignore_errors=True)
        return path

    def _assert_figure_written(self, report):
        """Run the report and check that the first output figure exists on disk."""
        result = report.generate_report()
        self.assertTrue(os.path.isfile(result.output_figures[0].path))

    def test_sequence_counts_seq_dataset(self):
        path = self._build_tmp_path("sequence_counts")

        dataset = RandomDatasetGenerator.generate_sequence_dataset(50, {4: 0.33, 5: 0.33, 7: 0.33},
                                                                   {"l1": {"a": 0.5, "b": 0.5}}, path / 'dataset')

        scd = SequenceCountDistribution(dataset, path, 1, split_by_label=True, label="l1")

        self._assert_figure_written(scd)

    def test_sequence_lengths_receptor_dataset(self):
        path = self._build_tmp_path("receptor_counts")

        dataset = RandomDatasetGenerator.generate_receptor_dataset(receptor_count=50,
                                                                   chain_1_length_probabilities={10: 1},
                                                                   chain_2_length_probabilities={10: 1},
                                                                   labels={"l1": {"a": 0.5, "b": 0.5}},
                                                                   path=path / 'dataset')

        scd = SequenceCountDistribution(dataset, path, 1, split_by_label=False)

        self._assert_figure_written(scd)

    def test_sequence_lengths_repertoire_dataset(self):
        path = self._build_tmp_path("repertoire_counts")

        dataset = RandomDatasetGenerator.generate_repertoire_dataset(repertoire_count=10,
                                                                     sequence_count_probabilities={10: 0.5, 20: 0.5},
                                                                     sequence_length_probabilities={10: 1},
                                                                     labels={"l1": {"a": 0.5, "b": 0.5}},
                                                                     path=path / 'dataset')

        # no explicit label: the report falls back to the first label of the dataset
        scd = SequenceCountDistribution(dataset, path, 1, split_by_label=True)

        self._assert_figure_written(scd)