Skip to content

Commit

Permalink
more consistent storage of data_graph
Browse files Browse the repository at this point in the history
  • Loading branch information
falexwolf committed Jul 31, 2017
1 parent 64a6ae0 commit 8c00b4b
Show file tree
Hide file tree
Showing 8 changed files with 69 additions and 43 deletions.
31 changes: 17 additions & 14 deletions scanpy/data_structs/data_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@
from .. import utils


N_DCS = 15 # default number of diffusion components


def add_or_update_graph_in_adata(
adata,
n_neighbors=30,
Expand All @@ -37,8 +40,8 @@ def add_or_update_graph_in_adata(
n_jobs=n_jobs)
if graph.fresh_compute:
graph.update_diffmap()
adata.add['distance'] = graph.Dsq
adata.add['Ktilde'] = graph.Ktilde
adata.add['data_graph_distance_local'] = graph.Dsq
adata.add['data_graph_norm_weights'] = graph.Ktilde
adata.smp['X_diffmap'] = graph.rbasis[:, 1:]
adata.smp['X_diffmap0'] = graph.rbasis[:, 0]
adata.add['diffmap_evals'] = graph.evals[1:]
Expand All @@ -62,10 +65,10 @@ def no_recompute_of_graph_necessary(
and (adata.smp['X_diffmap'].shape[1] >= n_dcs-1
if n_dcs is not None else True)
# make sure that it's sparse
and (issparse(adata.add['Ktilde']) == knn
and (issparse(adata.add['data_graph_norm_weights']) == knn
if knn is not None else True)
# make sure n_neighbors matches
and n_neighbors == adata.add['distance'][0].nonzero()[0].size + 1)
and n_neighbors == adata.add['data_graph_distance_local'][0].nonzero()[0].size + 1)


def get_neighbors(X, Y, k):
Expand Down Expand Up @@ -204,15 +207,15 @@ def __init__(self,
knn=True,
n_jobs=None,
n_pcs=50,
n_dcs=15,
n_dcs=N_DCS,
recompute_pca=False,
recompute_distances=False,
recompute_graph=False,
flavor='haghverdi16'):
self.sym = True # we do not allow asymetric cases
self.flavor = flavor # this is to experiment around
self.n_pcs = n_pcs
self.n_dcs = n_dcs
self.n_dcs = n_dcs if n_dcs is not None else N_DCS
self.init_iroot_and_X(adata, recompute_pca, n_pcs)
# use the graph in adata
if no_recompute_of_graph_necessary(
Expand All @@ -224,11 +227,11 @@ def __init__(self,
knn=knn,
n_dcs=n_dcs):
self.fresh_compute = False
self.knn = issparse(adata.add['Ktilde'])
self.Ktilde = adata.add['Ktilde']
self.Dsq = adata.add['distance']
self.knn = issparse(adata.add['data_graph_norm_weights'])
self.Ktilde = adata.add['data_graph_norm_weights']
self.Dsq = adata.add['data_graph_distance_local']
if self.knn:
self.k = adata.add['distance'][0].nonzero()[0].size + 1
self.k = adata.add['data_graph_distance_local'][0].nonzero()[0].size + 1
else:
self.k = None # currently do not store this, is unknown
# for output of spectrum
Expand Down Expand Up @@ -261,13 +264,13 @@ def __init__(self,
self.init_iroot_and_X(adata, recompute_pca, n_pcs)
if False: # TODO
# in case we already computed distance relations
if not recompute_distances and 'distance' in adata.add:
n_neighbors = adata.add['distance'][0].nonzero()[0].size + 1
if (knn and issparse(adata.add['distance'])
if not recompute_distances and 'data_graph_distance_local' in adata.add:
n_neighbors = adata.add['data_graph_distance_local'][0].nonzero()[0].size + 1
if (knn and issparse(adata.add['data_graph_distance_local'])
and n_neighbors == self.k):
logg.info(' using stored distances with `n_neighbors={}`'
.format(self.k))
self.Dsq = adata.add['distance']
self.Dsq = adata.add['data_graph_distance_local']

def init_iroot_directly(self, adata):
self.iroot = None
Expand Down
46 changes: 30 additions & 16 deletions scanpy/plotting/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,7 +235,7 @@ def diffmap(
writekey = 'diffmap'
if isinstance(components, list): components = ','.join([str(comp) for comp in components])
writekey += '_components' + components.replace(',', '')
if sett.savefigs or (save is not None): savefig(writekey) # TODO: cleaner
if sett.savefigs or (save is not None): utils.savefig(writekey) # TODO: cleaner
show = sett.autoshow if show is None else show
if not sett.savefigs and show: pl.show()
return axs
Expand Down Expand Up @@ -422,6 +422,7 @@ def aga(
title=None,
left_margin=0.05,
layout_graph=None,
minimal_realized_attachedness=None,
attachedness_type='relative',
show=None,
save=None):
Expand All @@ -442,9 +443,10 @@ def aga(
show=False)
axs[1].set_frame_on(False)
aga_graph(adata, root=root, fontsize=fontsize, ax=axs[1],
layout=layout_graph,
attachedness_type=attachedness_type,
show=False)
layout=layout_graph,
attachedness_type=attachedness_type,
minimal_realized_attachedness=minimal_realized_attachedness,
show=False)
utils.savefig_or_show('aga', show=show, save=save)


Expand Down Expand Up @@ -586,6 +588,7 @@ def aga_graph(
add_noise_to_node_positions=None,
left_margin=0.01,
attachedness_type='relative',
minimal_realized_attachedness=None,
force_labels_to_front=False,
show=None,
save=None,
Expand Down Expand Up @@ -620,6 +623,8 @@ def aga_graph(
raise ValueError('`colors` and `groups` lists need to have the same length.')
if title is None or isinstance(title, str): title = [title for name in groups]
if ax is None:
# 3.72 is the default figure_width obtained in utils.scatter_base
# for a single panel when rcParams['figure.figsize'][0] = 4
figure_width = rcParams['figure.figsize'][0] * len(colors)
top = 0.93
fig, axs = pl.subplots(ncols=len(colors),
Expand All @@ -641,6 +646,7 @@ def aga_graph(
node_size_power=node_size_power,
edge_width=edge_width,
attachedness_type=attachedness_type,
minimal_realized_attachedness=minimal_realized_attachedness,
ext=ext,
ax=axs[icolor],
title=title[icolor],
Expand All @@ -666,6 +672,7 @@ def _aga_graph_single(
ax=None,
layout=None,
add_noise_to_node_positions=None,
minimal_realized_attachedness=None,
attachedness_type=False,
draw_edge_labels=False,
force_labels_to_front=False):
Expand Down Expand Up @@ -733,6 +740,8 @@ def _aga_graph_single(
fig = pl.figure()
ax = pl.axes([0.08, 0.08, 0.9, 0.9], frameon=False)
# edge widths
from ..tools import aga
minimal_realized_attachedness = aga.MINIMAL_REALIZED_ATTACHEDNESS if minimal_realized_attachedness is None else minimal_realized_attachedness
base_edge_width = edge_width * 1.5*rcParams['lines.linewidth']
if 'aga_attachedness' in adata.add:
if attachedness_type == 'relative':
Expand All @@ -747,7 +756,11 @@ def _aga_graph_single(
nx_g = nx.Graph(adata.add['aga_adjacency'])
else:
nx_g = nx.Graph(adata.add['aga_adjacency_absolute'])
widths = [base_edge_width*x[-1]['weight'] for x in nx_g.edges(data=True)]
if minimal_realized_attachedness == aga.MINIMAL_REALIZED_ATTACHEDNESS:
widths = [base_edge_width*x[-1]['weight'] for x in nx_g.edges(data=True)]
else:
widths = [base_edge_width*(x[-1]['weight'] if x[-1]['weight'] != aga.MINIMAL_REALIZED_ATTACHEDNESS else minimal_realized_attachedness)
for x in nx_g.edges(data=True)]
nx.draw_networkx_edges(nx_g, pos, ax=ax, width=widths, edge_color='black')
else:
nx.draw_networkx_edges(nx_g, pos, ax=ax, width=widths, edge_color='black')
Expand Down Expand Up @@ -863,7 +876,7 @@ def moving_average(a, n=n_avg):
ret[n:] = ret[n:] - ret[:-n]
return ret[n - 1:] / n

ax = pl.gca()
ax = pl.gca() if ax is None else ax
from matplotlib import transforms
trans = transforms.blended_transform_factory(
ax.transData, ax.transAxes)
Expand All @@ -885,7 +898,7 @@ def moving_average(a, n=n_avg):
x = moving_average(x)
if ikey == 0: x_tick_locs = len(x)/old_len_x * np.array(x_tick_locs)
if not as_heatmap:
pl.plot(x[xlim[0]:xlim[1]], label=key)
ax.plot(x[xlim[0]:xlim[1]], label=key)
else:
X.append(x)
if ikey == 0:
Expand All @@ -900,23 +913,24 @@ def moving_average(a, n=n_avg):
else:
x_tick_labels.append(label)
if as_heatmap:
pl.imshow(np.array(X), aspect='auto', interpolation='nearest',
cmap=color_map)
pl.yticks(range(len(X)), keys, fontsize=ytick_fontsize)
ax = pl.gca()
img = ax.imshow(np.array(X), aspect='auto', interpolation='nearest',
cmap=color_map)
ax.set_yticks(range(len(X)))
ax.set_yticklabels(keys, fontsize=ytick_fontsize)
ax.set_frame_on(False)
pl.colorbar()
pl.colorbar(img, ax=ax)
left_margin = 0.2 if left_margin is None else left_margin
pl.subplots_adjust(left=left_margin)
else:
left_margin = 0.4 if left_margin is None else left_margin
pl.legend(frameon=False, loc='center left',
bbox_to_anchor=(-left_margin, 0.5),
fontsize=legend_fontsize)
pl.xticks(x_tick_locs, x_tick_labels)
pl.xlabel(adata.add['aga_groups_original'] if ('aga_groups_original' in adata.add
and adata.add['aga_groups_original'] != 'louvain_groups')
else 'aga groups')
ax.set_xticks(x_tick_locs)
ax.set_xticklabels(x_tick_labels)
ax.set_xlabel(adata.add['aga_groups_original'] if ('aga_groups_original' in adata.add
and adata.add['aga_groups_original'] != 'louvain_groups')
else 'aga groups')
if show_left_y_ticks:
utils.pimp_axis(pl.gca().get_yaxis())
pl.ylabel('as indicated on legend')
Expand Down
2 changes: 1 addition & 1 deletion scanpy/plotting/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ def timeseries_subplot(X,
pl.legend(frameon=False)


def timeseries_as_heatmap(X, var_names=None, highlightsX=None, color_map='viridis'):
def timeseries_as_heatmap(X, var_names=None, highlightsX=None, color_map=None):
"""Plot timeseries as heatmap.
Parameters
Expand Down
10 changes: 7 additions & 3 deletions scanpy/tools/aga.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,10 @@
from .louvain import louvain
from ..plotting import utils as pl_utils


MINIMAL_REALIZED_ATTACHEDNESS = 0.05


def aga(adata,
node_groups='louvain',
n_nodes=None,
Expand Down Expand Up @@ -158,8 +162,8 @@ def aga(adata,
adata.smp['X_diffmap'] = aga.rbasis[:, 1:]
adata.smp['X_diffmap0'] = aga.rbasis[:, 0]
adata.add['diffmap_evals'] = aga.evals[1:]
adata.add['distance'] = aga.Dsq
adata.add['Ktilde'] = aga.Ktilde
adata.add['data_graph_distance_local'] = aga.Dsq
adata.add['data_graph_norm_weights'] = aga.Ktilde
if aga.iroot is not None:
aga.set_pseudotime() # pseudotimes are random walk distances from root point
adata.add['iroot'] = aga.iroot # update iroot, might have changed when subsampling, for example
Expand Down Expand Up @@ -427,7 +431,7 @@ def detect_splits(self):
norm = np.sqrt(np.multiply.outer(self.segs_sizes, self.segs_sizes))
self.segs_attachedness_absolute /= norm

minimal_realized_attachedness = 0.1
minimal_realized_attachedness = MINIMAL_REALIZED_ATTACHEDNESS
self.segs_adjacency = sp.sparse.lil_matrix((len(segs), len(segs)), dtype=float)
self.segs_adjacency_absolute = sp.sparse.lil_matrix((len(segs), len(segs)), dtype=float)
for i, neighbors in enumerate(segs_adjacency):
Expand Down
4 changes: 2 additions & 2 deletions scanpy/tools/diffmap.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,8 +66,8 @@ def diffmap(adata, n_comps=15, n_neighbors=30, knn=True, n_pcs=50, sigma=0, n_jo
n_dcs=n_comps, n_jobs=n_jobs, recompute_graph=True,
flavor=flavor)
dmap.update_diffmap()
adata.add['distance'] = dmap.Dsq
adata.add['Ktilde'] = dmap.Ktilde
adata.add['data_graph_distance_local'] = dmap.Dsq
adata.add['data_graph_norm_weights'] = dmap.Ktilde
adata.smp['X_diffmap'] = dmap.rbasis[:, 1:]
adata.smp['X_diffmap0'] = dmap.rbasis[:, 0]
adata.add['diffmap_evals'] = dmap.evals[1:]
Expand Down
4 changes: 2 additions & 2 deletions scanpy/tools/dpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,8 +117,8 @@ def dpt(adata, n_branchings=0, n_neighbors=30, knn=True, n_pcs=50, n_dcs=10,
adata.smp['X_diffmap'] = dpt.rbasis[:, 1:]
adata.smp['X_diffmap0'] = dpt.rbasis[:, 0]
adata.add['diffmap_evals'] = dpt.evals[1:]
if knn: adata.add['distance'] = dpt.Dsq
if knn: adata.add['Ktilde'] = dpt.Ktilde
adata.add['data_graph_distance_local'] = dpt.Dsq
adata.add['data_graph_norm_weights'] = dpt.Ktilde
if n_branchings > 1: logg.info(' this uses a hierarchical implementation')
# compute DPT distance matrix, which we refer to as 'Ddiff'
if dpt.iroot is not None:
Expand Down
2 changes: 1 addition & 1 deletion scanpy/tools/draw_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def draw_graph(adata,
recompute_distances=recompute_distances,
recompute_graph=recompute_graph,
n_jobs=n_jobs)
adjacency = adata.add['Ktilde']
adjacency = adata.add['data_graph_norm_weights']
g = utils.get_igraph_from_adjacency(adjacency)
if layout in {'fr', 'drl', 'kk', 'grid_fr'}:
np.random.seed(random_state)
Expand Down
13 changes: 9 additions & 4 deletions scanpy/tools/louvain.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,12 @@ def louvain(adata,
n_neighbors : int, optional (default: 30)
Number of neighbors to use for construction of knn graph.
resolution : float or None, optional
For the default flavor, you provide a resolution, which defaults to 1.0.
For the default flavor ('vtraag'), you can provide a resolution (higher
resolution means finding more and smaller clusters), which defaults to
1.0.
flavor : {'vtraag', 'igraph'}
Choose between two packages for computing the clustering. 'vtraag' is
much more powerful.
copy : bool (default: False)
References
Expand All @@ -49,12 +54,12 @@ def louvain(adata,
adata,
n_neighbors=n_neighbors,
n_pcs=n_pcs,
n_dcs=n_dcs,
recompute_pca=recompute_pca,
recompute_distances=recompute_distances,
recompute_graph=recompute_graph,
n_dcs=n_dcs,
n_jobs=n_jobs)
adjacency = adata.add['Ktilde']
adjacency = adata.add['data_graph_norm_weights']
if flavor in {'vtraag', 'igraph'}:
if flavor == 'igraph' and resolution is not None:
logg.warn('`resolution` parameter has no effect for flavor "igraph"')
Expand Down Expand Up @@ -90,7 +95,7 @@ def louvain(adata,
# this is deprecated
import networkx as nx
import community
g = nx.Graph(adata.add['distance'])
g = nx.Graph(adata.add['data_graph_distance_local'])
partition = community.best_partition(g)
groups = np.zeros(len(partition), dtype=int)
for k, v in partition.items(): groups[k] = v
Expand Down

0 comments on commit 8c00b4b

Please sign in to comment.