Small updates #266

Merged
merged 7 commits on Apr 3, 2024
186 changes: 150 additions & 36 deletions singlecellmultiomics/bamProcessing/bamAnalyzeCutDistances.py
@@ -6,7 +6,7 @@
import os
from scipy.optimize import curve_fit
import argparse
from singlecellmultiomics.bamProcessing.bamFunctions import get_contigs_with_reads, get_r1_counts_per_cell
from singlecellmultiomics.bamProcessing.bamFunctions import get_contigs_with_reads, get_r1_counts_per_cell, mate_iter
from singlecellmultiomics.bamProcessing.bamBinCounts import merge_overlapping_ranges
from collections import Counter, defaultdict
import numpy as np
@@ -61,12 +61,16 @@ def generate_prefix(prefix, prefix_with_region, contig, start, end ):
else:
return prefix

def keep_contig(contig):
return not ('_' in contig or contig in ('chrY', 'chrM', 'chrEBV', 'MT') or contig.startswith('KN') or contig.startswith('KZ'))
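# Illustrative behaviour of the filter above (editor's note, not in the diff):
#   keep_contig('chr1')                    -> True
#   keep_contig('chrM')                    -> False  (mitochondrial)
#   keep_contig('chr1_KI270706v1_random')  -> False  (underscore: alt/random scaffold)
#   keep_contig('KN707606.1')              -> False  (KN/KZ patch-scaffold prefix)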

def get_sc_cut_dictionary(bam_path: str, filter_function=None, strand_specific=False, prefix_with_bam=False, regions=None, prefix_with_region=False, n_threads=None):

def get_sc_cut_dictionary(bam_path: str, filter_function=None, strand_specific=False, prefix_with_bam=False, regions=None, prefix_with_region=False, n_threads=None, bulk=False, count_function=None):
"""
Generates cut distribution dictionary (contig)->sample->position->obs

"""
assert count_function is not None
if filter_function is None:
filter_function = read_counts_function
cut_sites = {}
@@ -76,9 +80,6 @@ def get_sc_cut_dictionary(bam_path: str, filter_function=None, strand_specific=F
else:
bam_paths=bam_path




with Pool(n_threads) as workers:
for bam_path in bam_paths:
if prefix_with_bam:
@@ -87,26 +88,29 @@ def get_sc_cut_dictionary(bam_path: str, filter_function=None, strand_specific=F
prefix=None

if regions is None:
regions = [(contig, None, None) for contig in get_contigs_with_reads(bam_path)]
regions = [(contig, None, None) for contig in get_contigs_with_reads(bam_path) if keep_contig(contig)]

print("Selected regions (max 10 shown)")
for r in regions[:10]:
print(f'\t{r}')

with pysam.AlignmentFile(bam_path) as alignments:
start = None
end= None
for contig,r in workers.imap_unordered(
_get_sc_cut_dictionary, (
count_function, (
(bam_path,
contig,
strand_specific,
filter_function,
generate_prefix(prefix,prefix_with_region,contig,start,end)
, start, end, n_threads)
, start, end, n_threads, bulk)
for contig, start, end in regions )):
# Perform merge:
if not contig in cut_sites:
cut_sites[contig]=r
else:
for sample, positions in r.items():
cut_sites[contig][sample].update(positions)
print(f'\tFinished {contig}')

return cut_sites
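
# Editor's sketch (not part of this diff): minimal usage of the function
# above, assuming a DS-tagged single-cell BAM at the hypothetical path
# 'tagged.bam'. The returned mapping is contig -> sample -> position -> count.
def _example_inspect_cut_sites():
    cut_sites = get_sc_cut_dictionary(
        'tagged.bam',
        filter_function=strict_read_counts_function,
        count_function=_get_ds_sc_cut_dictionary)
    for contig, per_sample in cut_sites.items():
        for sample, positions in per_sample.items():
            # positions is a Counter keyed by cut coordinate
            # (or by (is_reverse, coordinate) when strand_specific=True)
            print(contig, sample, positions.most_common(3))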

@@ -202,17 +206,31 @@ def strict_read_counts_function(read):
return True


def _get_sc_cut_dictionary(args):
def loose_read_counts_function(read):
    if read.is_duplicate or \
       read.mapping_quality < 50 or \
       'S' in read.cigarstring or \
       'I' in read.cigarstring or \
       not read.is_proper_pair or \
       read.reference_start is None or read.reference_end is None:
        return False
    return True
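
# Editor's note (not in the diff): this "loose" variant backs the --nods
# mode added below -- when reads carry no DS tag the cut site must be
# inferred from fragment ends, so soft-clipped ('S') and insertion-
# containing ('I') alignments and improper pairs are rejected to keep the
# inferred end coordinates trustworthy.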

def _get_ds_sc_cut_dictionary(args):

bam, contig, strand_specific, filter_function, prefix, start, end, n_threads = args
bam, contig, strand_specific, filter_function, prefix, start, end, n_threads, bulk = args
cut_positions = defaultdict(Counter)
with pysam.AlignmentFile(bam) as alignments:
for read in alignments.fetch(contig, start, end):

if not filter_function(read):
if not filter_function(read): #(dup qcfail etc)
continue

k = read.get_tag('SM') if prefix is None else (prefix, read.get_tag('SM'))
if bulk:
k=('bulk' if prefix is None else (prefix,'bulk'))
else:
k = read.get_tag('SM') if prefix is None else (prefix, read.get_tag('SM'))

cut_positions[k][
(read.is_reverse, read.get_tag('DS'))
@@ -222,8 +240,89 @@ def _get_sc_cut_dictionary(args):

return contig,cut_positions
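
# Editor's note (not in the diff): this counter keys each observation on
# read.get_tag('DS'), which is assumed here to hold the precomputed
# digest/cut-site coordinate written by this package's tagging step, so no
# coordinate inference from the alignment itself is needed.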


def _get_sc_cut_dictionary(args):
bam, contig, strand_specific, filter_function, prefix, start, end, n_threads, bulk = args
cut_positions = defaultdict(Counter)
print_reasons = False
reasons = Counter()
with pysam.AlignmentFile(bam) as alignments:
for R1, R2 in mate_iter(alignments, contig=contig):


if R1 is None:
if print_reasons:
reasons['r1_none'] += 1
continue
if not filter_function(R1):
if print_reasons:
if R1.reference_start is None:
reasons['norefstart'] += 1
elif R1.reference_end is None:
reasons['norefend'] += 1
elif R1.is_duplicate:
reasons['duplicate'] += 1
elif R1.mapping_quality<50:
reasons['mq'] += 1
else:
reasons['filter'] += 1
continue

if bulk:
k = ('bulk' if prefix is None else (prefix, 'bulk'))
else:
k = R1.get_tag('SM') if prefix is None else (prefix, R1.get_tag('SM'))

if R1.is_reverse is None:
if print_reasons:
reasons['norev'] += 1
continue

if R1.is_reverse:
cut_location = R1.reference_end
else:
cut_location = R1.reference_start

cut_positions[k][
(R1.is_reverse, cut_location)
if strand_specific else
cut_location
] += 1
if print_reasons:
reasons['ok'] += 1
# for i,read in enumerate([R1,]):
# if read is None or not filter_function(read):
# continue
#
# if i==0:
# if read.is_reverse:
# cut_location = read.reference_end
# else:
# cut_location = read.reference_start
# if R1.is_reverse is not None:
# cut_positions[k][
# (R1.is_reverse, cut_location)
# if strand_specific else
# cut_location
# ] += 1
# else: # R2:
# if read.is_reverse:
# cut_location = read.reference_start
# else:
# cut_location = read.reference_end
# if R2.is_reverse is not None:
# cut_positions[k][
# (not R2.is_reverse, cut_location)
# if strand_specific else
# cut_location
# ] += 1

if print_reasons:
print(reasons)
return contig, cut_positions
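
# Editor's sketch (not part of this diff): the cut-coordinate rule used
# above, factored out for clarity -- the cut is taken at the 5' end of R1:
# reference_end for reverse-strand reads, reference_start otherwise.
def _example_cut_coordinate(r1):
    """Inferred cut site for a mapped R1 (illustrative only)."""
    return r1.reference_end if r1.is_reverse else r1.reference_start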

def cuts_to_observation_vector(cell, cell_cuts, window_radius, n_bins, bin_size=1, take_n_samples=None,
log_distance=False):
log_distance=False, contig=None):
obs = np.zeros(n_bins, dtype=np.int64)

forward = np.array(list(cell_cuts.keys()))
@@ -233,6 +332,7 @@ def cuts_to_observation_vector(cell, cell_cuts, window_radius, n_bins, bin_size=
forward.sort()

total_tests = 0
print(f"Performing {len(forward)} tests on contig {contig}")
for position in forward:
distance_to_all_points = forward - position
in_bounds = np.abs(distance_to_all_points[(distance_to_all_points >= -window_radius) & (
@@ -247,18 +347,18 @@ def cuts_to_observation_vector(cell, cell_cuts, window_radius, n_bins, bin_size=
in_bounds = (np.floor(in_bounds / bin_size)).astype(int)
np.add.at(obs, in_bounds, 1)

return cell, obs, total_tests
return cell, obs, total_tests, contig
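
# Editor's sketch (not part of this diff): the binning above in miniature,
# under the assumption n_bins = window_radius // bin_size + 1. Every pair of
# cuts within window_radius is counted twice (once from each endpoint),
# matching the loop over all positions; np is imported at module top.
def _example_distance_histogram(cuts, window_radius=10, bin_size=1):
    cuts = np.sort(np.asarray(cuts))
    obs = np.zeros(window_radius // bin_size + 1, dtype=np.int64)
    for position in cuts:
        d = np.abs(cuts - position)
        d = d[(d > 0) & (d <= window_radius)]           # drop self-distance
        np.add.at(obs, (d // bin_size).astype(int), 1)  # histogram update
    return obs
# e.g. _example_distance_histogram([1, 3, 4]) yields obs[1]=2, obs[2]=2, obs[3]=2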


def _cuts_to_observation_vector(kwargs):
return cuts_to_observation_vector(**kwargs)


def analyse(bam_path,output_dir, create_plot=False, min_distance=20, max_distance=800, verbose=False, strand_specific=False):
def analyse(bam_path,output_dir, create_plot=False, min_distance=20, max_distance=800, verbose=False, strand_specific=False, bulk=False, filter_function=None,count_function=None):

if verbose:
print('Obtaining molecules per cell .. ', end='\r')
cpr = get_r1_counts_per_cell(bam_path)
cpr = get_r1_counts_per_cell(bam_path, assoc_all_to_sample='bulk' if bulk else None)

if verbose:
print('Molecules per cell: ')
@@ -268,15 +368,15 @@ def analyse(bam_path,output_dir, create_plot=False, min_distance=20, max_distanc
if verbose:
print('Obtaining cuts per cell .. ', end='\r')

cut_sites = get_sc_cut_dictionary(bam_path, strand_specific=strand_specific)
cut_sites = get_sc_cut_dictionary(bam_path, strand_specific=strand_specific, bulk=bulk, count_function=count_function,filter_function=filter_function)


all_counts = {}
for cell, total_molecules in cpr.most_common():
# Write from 0 to max_distance table
all_counts[cell] = DivCounter(dictionary_to_diff_vector(cut_sites,cell,0,max_distance))

cut_count_df = pd.DataFrame(all_counts).sort_index().sort_index(1).fillna(0)
cut_count_df = pd.DataFrame(all_counts).sort_index(axis=0).sort_index(axis=1).fillna(0)
cut_count_df.to_csv(f'{output_dir}/counts.csv')

if verbose:
@@ -417,8 +517,18 @@ def function_to_fit(xdata, period, offset, amplitude, decay, mean ):
argparser.add_argument('--legacy', action='store_true', help='Create legacy unstranded analysis plots and files')
argparser.add_argument('-max_distance', type=int,default=2000, help='Maximum distance in both plots and output tables')
argparser.add_argument('-t', type=int,default=None, help='Max processes')
argparser.add_argument('--bulk', action='store_true', help='All reads are derived from one sample')

argparser.add_argument('--nods', action='store_true', help='Reads do not have DS tags set, use loose settings and fragment ends')
args = argparser.parse_args()

if args.nods:
filter_func = loose_read_counts_function
count_func = _get_sc_cut_dictionary
else:
filter_func = strict_read_counts_function
count_func = _get_ds_sc_cut_dictionary
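
# Illustrative invocations (editor's note; --nods, --bulk and -t are defined
# in the argparser shown in this diff; the script name and the -o output flag
# are assumptions, as the output option itself is not part of this hunk):
#   DS-tagged single-cell BAM, strict filter (default):
#       bamAnalyzeCutDistances.py tagged.bam -o out_dir -t 8
#   Untagged bulk BAM: loose filter, fragment-end cuts, pooled 'bulk' sample:
#       bamAnalyzeCutDistances.py bulk.bam -o out_dir --nods --bulk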

if args.regions is not None:
regions_per_contig = defaultdict(list)
with open(args.regions) as f:
@@ -458,17 +568,17 @@ def function_to_fit(xdata, period, offset, amplitude, decay, mean ):
print('Performing legacy analysis')
if len(args.alignmentfiles)!=1:
raise ValueError('The legacy analysis only works on a single bam file')
analyse(args.alignmentfiles[0], args.o, create_plot=True, verbose=True,strand_specific=False,max_distance=args.max_distance)
analyse(args.alignmentfiles[0], args.o, create_plot=True, verbose=True,strand_specific=False,max_distance=args.max_distance, count_function=count_func)

# Stranded analysis:
sc_cut_dict_stranded = get_sc_cut_dictionary( args.alignmentfiles,strand_specific=True,filter_function=strict_read_counts_function, regions=regions, n_threads=args.t)
sc_cut_dict_stranded = get_sc_cut_dictionary( args.alignmentfiles,strand_specific=True,filter_function=filter_func, regions=regions, n_threads=args.t, bulk=args.bulk, count_function=count_func)
distance_counter_fwd_above, distance_counter_fwd_below, distance_counter_rev_above, distance_counter_rev_below = get_stranded_pairwise_counts(sc_cut_dict_stranded)

# Write tables:
pd.DataFrame(distance_counter_fwd_above).sort_index().sort_index(1).to_csv(f'{args.o}/STRANDED_fwd_above.csv')
pd.DataFrame(distance_counter_fwd_below).sort_index().sort_index(1).to_csv(f'{args.o}/STRANDED_fwd_below.csv')
pd.DataFrame(distance_counter_rev_above).sort_index().sort_index(1).to_csv(f'{args.o}/STRANDED_rev_above.csv')
pd.DataFrame(distance_counter_rev_below).sort_index().sort_index(1).to_csv(f'{args.o}/STRANDED_rev_below.csv')
pd.DataFrame(distance_counter_fwd_above).sort_index(axis=1).sort_index(axis=0).to_csv(f'{args.o}/STRANDED_fwd_above.csv')
pd.DataFrame(distance_counter_fwd_below).sort_index(axis=1).sort_index(axis=0).to_csv(f'{args.o}/STRANDED_fwd_below.csv')
pd.DataFrame(distance_counter_rev_above).sort_index(axis=1).sort_index(axis=0).to_csv(f'{args.o}/STRANDED_rev_above.csv')
pd.DataFrame(distance_counter_rev_below).sort_index(axis=1).sort_index(axis=0).to_csv(f'{args.o}/STRANDED_rev_below.csv')

del sc_cut_dict_stranded

@@ -477,17 +587,17 @@ def function_to_fit(xdata, period, offset, amplitude, decay, mean ):

#################
# Unstranded density analysis:
print("Unstranded density analysis")
prefix_with_bam=False if len(args.alignmentfiles)==1 else True
sc_cut_dict = get_sc_cut_dictionary( args.alignmentfiles,strand_specific=False,filter_function=strict_read_counts_function, prefix_with_bam=prefix_with_bam, regions=regions, n_threads=args.t)
cpr = get_r1_counts_per_cell(args.alignmentfiles, prefix_with_bam=prefix_with_bam)
sc_cut_dict = get_sc_cut_dictionary( args.alignmentfiles,strand_specific=False,filter_function=filter_func, prefix_with_bam=prefix_with_bam, regions=regions, n_threads=args.t, count_function=count_func, bulk=args.bulk)
print("Obtaining counts per cell 1/2")
cpr = get_r1_counts_per_cell(args.alignmentfiles, prefix_with_bam=prefix_with_bam, assoc_all_to_sample='bulk' if args.bulk else None)
print("Obtaining counts per cell 2/2")
counts = pd.Series(cpr).sort_values()
print(counts)



def get_commands(one_contig=None):
def get_commands(sc_cut_dict, one_contig=None):
for contig in sc_cut_dict: # sc_cut_dict:
if '_' in contig or contig in ('chrY', 'chrM', 'chrEBV'):
if not keep_contig(contig):
continue
if one_contig is not None and contig != one_contig:
continue
@@ -506,8 +616,9 @@ def get_commands(one_contig=None):
# This is a histogram of the amount of observed fragments at distances x:
obs = defaultdict(lambda: np.zeros(n_bins, dtype=np.int64))
total_tests = Counter() # cell -> tests
print("\tcuts_to_observation_vector calculation")
with Pool(args.t) as workers:
for cell, cell_obs, n_tests in workers.imap_unordered(
for cell, cell_obs, n_tests, contig in workers.imap_unordered(
_cuts_to_observation_vector,

(
@@ -517,17 +628,20 @@ def get_commands(one_contig=None):
'log_distance': False,
'n_bins': n_bins,
'bin_size': bin_size,
'take_n_samples': None # sample_target[contig]
'take_n_samples': None, # sample_target[contig]
'contig':contig
}
for cell, cell_cuts, contig in get_commands()
for cell, cell_cuts, contig in get_commands(sc_cut_dict)
)):
print(f'\tFinished {cell} [{contig}]')
obs[cell] += cell_obs
total_tests[cell] += n_tests

p_obs = pd.DataFrame(obs) / pd.Series(total_tests)
p_obs.index = x_obs

# Means per library:
print('Exporting results')

window = 35
p_obs.to_csv(f'{args.o}/strand_unspecific_density_raw.csv')