From e814ee7fb877fb7094db3bebb502f6a61b88adc0 Mon Sep 17 00:00:00 2001 From: Tim O'Donnell Date: Sat, 6 Mar 2021 13:34:07 -0500 Subject: [PATCH] api change --- phipkit/antigen_analysis.py | 17 ++++++----------- phipkit/call_antigens.py | 14 +------------- phipkit/call_hits.py | 2 +- phipkit/common.py | 14 +++++++++++++- phipkit/plot_antigens.py | 7 ++++--- test/test_integration.py | 1 - 6 files changed, 25 insertions(+), 30 deletions(-) diff --git a/phipkit/antigen_analysis.py b/phipkit/antigen_analysis.py index 8ea71cf..d8ce873 100644 --- a/phipkit/antigen_analysis.py +++ b/phipkit/antigen_analysis.py @@ -66,22 +66,17 @@ def compute_coverage(antigen, sequence, blast_df): class AntigenAnalysis(object): - def __init__(self, blast_df, antigens_df, hits_df=None, sample_to_kind=None): + def __init__(self, blast_df, antigens_df, sample_to_hit_clones, sample_to_kind=None): self.blast_df = blast_df self.antigens_df = antigens_df - self.hits_df = hits_df self.sample_to_kind = sample_to_kind - say("Collecting hits.") - self.sample_to_hit_clones = collections.defaultdict(set) - all_clones = set() - for _, row in tqdm(self.hits_df.iterrows(), total=len(self.hits_df)): - self.sample_to_hit_clones[row.sample_id].add(row.clone1) - self.sample_to_hit_clones[row.sample_id].add(row.clone2) - all_clones.add(row.clone1) - all_clones.add(row.clone2) self.sample_to_hit_clones = dict( - (k, list(v)) for (k, v) in self.sample_to_hit_clones.items()) + (k, list(v)) for (k, v) in sample_to_hit_clones.items()) + + all_clones = set() + for clones in self.sample_to_hit_clones.values(): + all_clones.update(clones) self.clone_by_sample_hits_matrix = pandas.DataFrame( index=sorted(all_clones), diff --git a/phipkit/call_antigens.py b/phipkit/call_antigens.py index 340b93c..22fd8fa 100644 --- a/phipkit/call_antigens.py +++ b/phipkit/call_antigens.py @@ -65,7 +65,7 @@ import pandas -from . common import say, reconstruct_antigen_sequences +from .common import say, reconstruct_antigen_sequences, hits_to_dict parser = argparse.ArgumentParser( description=__doc__, @@ -168,18 +168,6 @@ def run(argv=sys.argv[1:]): say("Wrote: ", args.out) -def hits_to_dict(hits_df): - """ - Given a hits_df, return a dict of sample id -> list of hits - """ - sample_to_clones = {} - for sample, sub_hits_df in hits_df.groupby("sample_id"): - sample_to_clones[sample] = sub_hits_df[ - ["clone1", "clone2"] - ].stack().unique() - return sample_to_clones - - def find_consensus(sequences, threshold=0.7): """ Given aligned sequences, return a string where each position i is the diff --git a/phipkit/call_hits.py b/phipkit/call_hits.py index 6ae8712..3319626 100644 --- a/phipkit/call_hits.py +++ b/phipkit/call_hits.py @@ -45,7 +45,7 @@ import pandas -from . common import say +from .common import say DEFAULTS = { 'fdr': 0.01, diff --git a/phipkit/common.py b/phipkit/common.py index ec54f71..7765397 100644 --- a/phipkit/common.py +++ b/phipkit/common.py @@ -44,4 +44,16 @@ def reconstruct_antigen_sequences(blast_df): antigen_sequences[title][hit_from - 1: hit_to] = hseq.encode('ascii') antigen_sequences = antigen_sequences.map(lambda arr: arr.decode()) - return antigen_sequences \ No newline at end of file + return antigen_sequences + + +def hits_to_dict(hits_df): + """ + Given a hits_df, return a dict of sample id -> list of hits + """ + sample_to_clones = {} + for sample, sub_hits_df in hits_df.groupby("sample_id"): + sample_to_clones[sample] = sub_hits_df[ + ["clone1", "clone2"] + ].stack().unique() + return sample_to_clones \ No newline at end of file diff --git a/phipkit/plot_antigens.py b/phipkit/plot_antigens.py index 5badab5..a378aa8 100644 --- a/phipkit/plot_antigens.py +++ b/phipkit/plot_antigens.py @@ -20,7 +20,7 @@ from matplotlib import pyplot from matplotlib.backends.backend_pdf import PdfPages -from . common import say +from .common import say, hits_to_dict from .antigen_analysis import AntigenAnalysis parser = argparse.ArgumentParser( @@ -121,8 +121,9 @@ def plot_antigens(blast_df, hits_df, antigens_df, out, include_redundant=False): analyzer = AntigenAnalysis( blast_df=blast_df, - hits_df=hits_df, - antigens_df=antigens_df) + antigens_df=antigens_df, + sample_to_hit_clones=hits_to_dict(hits_df)) + say("Generating plots") antigens = antigens_df.antigen.unique() diff --git a/test/test_integration.py b/test/test_integration.py index 9460cdd..6b9fa4c 100644 --- a/test/test_integration.py +++ b/test/test_integration.py @@ -305,7 +305,6 @@ def test_integrated(save_dir=None): assert os.path.exists(out) - if __name__ == "__main__": import sys import argparse