Skip to content

Commit

Permalink
Updated AminoAcidFrequencyDistribution report: plot in addition 'freq…
Browse files Browse the repository at this point in the history
…uency change'
  • Loading branch information
LonnekeScheffer committed Nov 1, 2023
1 parent 28d5d7f commit 3e2f234
Show file tree
Hide file tree
Showing 3 changed files with 115 additions and 25 deletions.
12 changes: 12 additions & 0 deletions immuneML/reports/PlotlyUtil.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,3 +38,15 @@ def add_single_axis_labels(figure, x_label, y_label, x_label_position, y_label_p
showarrow=False, text=x_label, textangle=-0,
xref="paper", yref="paper")])
return figure

@staticmethod
def get_amino_acid_color_map():
'''To be used whenever plotting for example a barplot where each amino acid is represented,
to be used as a value for plotly's color_discrete_map'''
return {'Y': 'rgb(102, 197, 204)', 'W': 'rgb(179,222,105)', 'V': 'rgb(220, 176, 242)',
'T': 'rgb(217,217,217)', 'S': 'rgb(141,211,199)', 'R': 'rgb(251,128,114)',
'Q': 'rgb(158, 185, 243)', 'P': 'rgb(248, 156, 116)', 'N': 'rgb(135, 197, 95)',
'M': 'rgb(254, 136, 177)', 'L': 'rgb(201, 219, 116)', 'K': 'rgb(255,237,111)',
'I': 'rgb(180, 151, 231)', 'H': 'rgb(246, 207, 113)', 'G': 'rgb(190,186,218)',
'F': 'rgb(128,177,211)', 'E': 'rgb(253,180,98)', 'D': 'rgb(252,205,229)',
'C': 'rgb(188,128,189)', 'A': 'rgb(204,235,197)'}
114 changes: 93 additions & 21 deletions immuneML/reports/data_reports/AminoAcidFrequencyDistribution.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,18 @@
import warnings
from collections import Counter
from pathlib import Path

import pandas as pd
import plotly.express as px
import numpy as np

from immuneML.data_model.dataset.ReceptorDataset import ReceptorDataset
from immuneML.data_model.dataset.RepertoireDataset import RepertoireDataset
from immuneML.data_model.dataset.SequenceDataset import SequenceDataset
from immuneML.data_model.receptor.receptor_sequence.ReceptorSequence import ReceptorSequence
from immuneML.environment.EnvironmentSettings import EnvironmentSettings
from immuneML.environment.SequenceType import SequenceType
from immuneML.reports.PlotlyUtil import PlotlyUtil
from immuneML.reports.ReportOutput import ReportOutput
from immuneML.reports.ReportResult import ReportResult
from immuneML.reports.data_reports.DataReport import DataReport
Expand All @@ -28,7 +31,7 @@ class AminoAcidFrequencyDistribution(DataReport):
relative_frequency (bool): Whether to plot relative frequencies (true) or absolute counts (false) of the positional amino acids. By default, relative_frequency is True.
split_by_label (bool): Whether to split the plots by a label. If set to true, the Dataset must either contain a single label, or alternatively the label of interest can be specified under 'label'. By default, split_by_label is False.
split_by_label (bool): Whether to split the plots by a label. If set to true, the Dataset must either contain a single label, or alternatively the label of interest can be specified under 'label'. If split_by_label is set to true, the percentage-wise frequency difference between classes is plotted additionally. By default, split_by_label is False.
label (str): if split_by_label is set to True, a label can be specified here.
Expand Down Expand Up @@ -75,13 +78,29 @@ def _generate(self) -> ReportResult:

freq_dist = self._get_plotting_data()

results_table = self._write_results_table(freq_dist)
report_output_fig = self._safe_plot(freq_dist=freq_dist)
tables = []
figures = []


tables.append(self._write_output_table(freq_dist,
self.result_path / "amino_acid_frequency_distribution.tsv",
name="Table of amino acid frequencies"))

figures.append(self._safe_plot(freq_dist=freq_dist, plot_callable="_plot_distribution"))

if self.split_by_label:
frequency_change = self._compute_frequency_change(freq_dist)

tables.append(self._write_output_table(frequency_change,
self.result_path / f"frequency_change.tsv",
name=f"Log-fold change between classes"))
figures.append(self._safe_plot(frequency_change=frequency_change, plot_callable="_plot_frequency_change"))


return ReportResult(name=self.name,
info="A a barplot showing the relative frequency of each amino acid at each position in the sequences of a dataset.",
output_figures=None if report_output_fig is None else [report_output_fig],
output_tables=None if results_table is None else [results_table])
info="A barplot showing the relative frequency of each amino acid at each position in the sequences of a dataset.",
output_figures=[fig for fig in figures if fig is not None],
output_tables=[table for table in tables if table is not None])

def _get_plotting_data(self):
if isinstance(self.dataset, SequenceDataset):
Expand Down Expand Up @@ -205,34 +224,39 @@ def _get_positions(self, sequence: ReceptorSequence):
positions = PositionHelper.gen_imgt_positions_from_length(len(sequence.get_sequence(SequenceType.AMINO_ACID)),
sequence.get_attribute("region_type"))
else:
positions = list(range(len(sequence.get_sequence(SequenceType.AMINO_ACID))))
positions = list(range(1, len(sequence.get_sequence(SequenceType.AMINO_ACID))+1))

return [str(pos) for pos in positions]

def _write_results_table(self, results_table):
file_path = self.result_path / "amino_acid_frequency_distribution.csv"
file_path = self.result_path / "amino_acid_frequency_distribution.tsv"

results_table.to_csv(file_path, index=False)

return ReportOutput(path=file_path, name="Table of amino acid frequencies")

def _get_colors(self):
return ['rgb(102, 197, 204)', 'rgb(179,222,105)', 'rgb(220, 176, 242)', 'rgb(217,217,217)',
'rgb(141,211,199)', 'rgb(251,128,114)', 'rgb(158, 185, 243)', 'rgb(248, 156, 116)',
'rgb(135, 197, 95)', 'rgb(254, 136, 177)', 'rgb(201, 219, 116)', 'rgb(255,237,111)',
'rgb(180, 151, 231)', 'rgb(246, 207, 113)', 'rgb(190,186,218)', 'rgb(128,177,211)',
'rgb(253,180,98)', 'rgb(252,205,229)', 'rgb(188,128,189)', 'rgb(204,235,197)', ]

def _plot(self, freq_dist):
def _get_colors(self):
return {'Y': 'rgb(102, 197, 204)', 'W': 'rgb(179,222,105)', 'V': 'rgb(220, 176, 242)',
'T': 'rgb(217,217,217)', 'S': 'rgb(141,211,199)', 'R': 'rgb(251,128,114)',
'Q': 'rgb(158, 185, 243)', 'P': 'rgb(248, 156, 116)', 'N': 'rgb(135, 197, 95)',
'M': 'rgb(254, 136, 177)', 'L': 'rgb(201, 219, 116)', 'K': 'rgb(255,237,111)',
'I': 'rgb(180, 151, 231)', 'H': 'rgb(246, 207, 113)', 'G': 'rgb(190,186,218)',
'F': 'rgb(128,177,211)', 'E': 'rgb(253,180,98)', 'D': 'rgb(252,205,229)',
'C': 'rgb(188,128,189)', 'A': 'rgb(204,235,197)'}

def _plot_distribution(self, freq_dist):
freq_dist.sort_values(by=["amino acid"], ascending=False, inplace=True)
category_orders = None if "class" not in freq_dist.columns else {"class": sorted(set(freq_dist["class"]))}

y = "relative frequency" if self.relative_frequency else "count"

figure = px.bar(freq_dist, x="position", y=y, color="amino acid", text="amino acid",
facet_col="class" if "class" in freq_dist.columns else None,
facet_row="chain" if "chain" in freq_dist.columns else None,
color_discrete_sequence=self._get_colors(),
labels={"position": "IMGT position" if self.imgt_positions else "Sequence index",
color_discrete_map=PlotlyUtil.get_amino_acid_color_map(),
category_orders=category_orders,
labels={"position": "IMGT position" if self.imgt_positions else "Position",
"count": "Count",
"relative frequency": "Relative frequency",
"amino acid": "Amino acid"}, template="plotly_white")
Expand All @@ -250,6 +274,55 @@ def _plot(self, freq_dist):
def _get_position_order(self, positions):
return [str(int(pos)) if pos.is_integer() else str(pos) for pos in sorted(set(positions.astype(float)))]

def _compute_frequency_change(self, freq_dist):
classes = sorted(set(freq_dist["class"]))
assert len(classes) == 2, f"{AminoAcidFrequencyDistribution.__name__}: cannot compute log fold change when the number of classes is not 2: {classes}"

class_a_df = freq_dist[freq_dist["class"] == classes[0]]
class_b_df = freq_dist[freq_dist["class"] == classes[1]]

on = ["amino acid", "position"]
on = on + ["chain"] if "chain" in freq_dist.columns else on

merged_dfs = pd.merge(class_a_df, class_b_df, on=on, how="outer", suffixes=["_a", "_b"])
merged_dfs = merged_dfs[(merged_dfs["relative frequency_a"] + merged_dfs["relative frequency_b"]) > 0]

merged_dfs["frequency_change"] = merged_dfs["relative frequency_a"] - merged_dfs["relative frequency_b"]

pos_class_a = merged_dfs[merged_dfs["frequency_change"] > 0]
pos_class_b = merged_dfs[merged_dfs["frequency_change"] < 0]

pos_class_a["positive_class"] = classes[0]
pos_class_b["positive_class"] = classes[1]
pos_class_b["frequency_change"] = 0 - pos_class_b["frequency_change"]

keep_cols = on + ["frequency_change", "positive_class"]
pos_class_a = pos_class_a[keep_cols]
pos_class_b = pos_class_b[keep_cols]

return pd.concat([pos_class_a, pos_class_b])

def _plot_frequency_change(self, frequency_change):
figure = px.bar(frequency_change, x="position", y="frequency_change", color="amino acid", text="amino acid",
facet_col="positive_class",
facet_row="chain" if "chain" in frequency_change.columns else None,
color_discrete_map=PlotlyUtil.get_amino_acid_color_map(),
labels={"position": "IMGT position" if self.imgt_positions else "Position",
"positive_class": "Class",
"frequency_change": "Difference in relative frequency",
"amino acid": "Amino acid"}, template="plotly_white")

figure.update_xaxes(categoryorder='array', categoryarray=self._get_position_order(frequency_change["position"]))
figure.update_layout(showlegend=False, yaxis={'categoryorder':'category ascending'})

figure.update_yaxes(tickformat=",.0%")

file_path = self.result_path / "frequency_change.html"
figure.write_html(str(file_path))

return ReportOutput(path=file_path, name="Frequency difference between amino acid usage in the two classes")


def _get_label_name(self):
if self.split_by_label:
if self.label_name is None:
Expand All @@ -263,13 +336,12 @@ def check_prerequisites(self):
if self.split_by_label:
if self.label_name is None:
if len(self.dataset.get_label_names()) != 1:
warnings.warn(
f"{AminoAcidFrequencyDistribution.__name__}: ambiguous label: split_by_label was set to True but no label name was specified, and the number of available labels is {len(self.dataset.get_label_names())}: {self.dataset.get_label_names()}. Skipping this report...")
warnings.warn(f"{AminoAcidFrequencyDistribution.__name__}: ambiguous label: split_by_label was set to True but no label name was specified, and the number of available labels is {len(self.dataset.get_label_names())}: {self.dataset.get_label_names()}. Skipping this report...")
return False
else:
if self.label_name not in self.dataset.get_label_names():
warnings.warn(
f"{AminoAcidFrequencyDistribution.__name__}: the specified label name ({self.label_name}) was not available among the dataset labels: {self.dataset.get_label_names()}. Skipping this report...")
warnings.warn(f"{AminoAcidFrequencyDistribution.__name__}: the specified label name ({self.label_name}) was not available among the dataset labels: {self.dataset.get_label_names()}. Skipping this report...")
return False

return True

14 changes: 10 additions & 4 deletions test/reports/data_reports/test_AminoAcidFrequencyDistribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,18 @@ def test_generate_sequence_dataset(self):

params = DefaultParamsLoader.load(EnvironmentSettings.default_params_path / "reports/", "AminoAcidFrequencyDistribution")
params["dataset"] = dataset
params["split_by_label"] = True
params["result_path"] = path / "result"

report = AminoAcidFrequencyDistribution.build_object(**params)
self.assertTrue(report.check_prerequisites())

report._generate()

self.assertTrue(os.path.isfile(path / "result/amino_acid_frequency_distribution.csv"))
self.assertTrue(os.path.isfile(path / "result/amino_acid_frequency_distribution.tsv"))
self.assertTrue(os.path.isfile(path / "result/amino_acid_frequency_distribution.html"))
self.assertTrue(os.path.isfile(path / "result/frequency_change.tsv"))
self.assertTrue(os.path.isfile(path / "result/frequency_change.html"))

shutil.rmtree(path)

Expand All @@ -40,15 +43,18 @@ def test_generate_receptor_dataset(self):

params = DefaultParamsLoader.load(EnvironmentSettings.default_params_path / "reports/", "AminoAcidFrequencyDistribution")
params["dataset"] = dataset
params["split_by_label"] = True
params["result_path"] = path / "result"

report = AminoAcidFrequencyDistribution.build_object(**params)
self.assertTrue(report.check_prerequisites())

report._generate()

self.assertTrue(os.path.isfile(path / "result/amino_acid_frequency_distribution.csv"))
self.assertTrue(os.path.isfile(path / "result/amino_acid_frequency_distribution.tsv"))
self.assertTrue(os.path.isfile(path / "result/amino_acid_frequency_distribution.html"))
self.assertTrue(os.path.isfile(path / "result/frequency_change.tsv"))
self.assertTrue(os.path.isfile(path / "result/frequency_change.html"))

shutil.rmtree(path)

Expand All @@ -71,10 +77,10 @@ def test_generate_repertoire_dataset(self):

report._generate()

self.assertTrue(os.path.isfile(path / "result/amino_acid_frequency_distribution.csv"))
self.assertTrue(os.path.isfile(path / "result/amino_acid_frequency_distribution.tsv"))
self.assertTrue(os.path.isfile(path / "result/amino_acid_frequency_distribution.html"))

df = pd.read_csv(path / "result/amino_acid_frequency_distribution.csv")
df = pd.read_csv(path / "result/amino_acid_frequency_distribution.tsv", sep="\t")

# assert that the total amino acid count at each position = n_repertoires (5) * sequences_per_repertoire (20) for each positionin the sequence (10)
self.assertEqual([100] * 10, list(df.groupby("position")["count"].sum()))
Expand Down

0 comments on commit 3e2f234

Please sign in to comment.