diff --git a/tests/test_functions.py b/tests/test_functions.py index 0841a1a6..a3c37948 100644 --- a/tests/test_functions.py +++ b/tests/test_functions.py @@ -92,7 +92,7 @@ def test_gamma_approx(self): class TestNodeTipWeights(unittest.TestCase): - def verify_weights(self, ts): + def verify_spans(self, ts): span_data = SpansBySamples(ts) # Check all non-sample nodes in a tree are represented nonsample_nodes = collections.defaultdict(float) @@ -106,150 +106,143 @@ def verify_weights(self, ts): assert span == pytest.approx(span_data.node_spans[id_]) for focal_node in span_data.nodes_to_date: wt = 0 - for _, weights in span_data.get_weights(focal_node).items(): + for _, spans in span_data.get_spans(focal_node).items(): assert 0 <= focal_node < ts.num_nodes - wt += np.sum(weights["weight"]) - assert max(weights["descendant_tips"]) <= ts.num_samples + wt += np.sum(spans["span"]) + assert max(spans["descendant_tips"]) <= ts.num_samples if not np.isnan(wt): # Dangling nodes will have wt=nan - assert wt == pytest.approx(1.0) + assert wt == pytest.approx(span_data.node_spans[focal_node]) return span_data def test_one_tree_n2(self): ts = utility_functions.single_tree_ts_n2() - span_data = self.verify_weights(ts) - # with a single tree there should only be one weight + span_data = self.verify_spans(ts) + # with a single tree there should only be one span for node in span_data.nodes_to_date: - assert len(span_data.get_weights(node)), 1 - assert 2 in span_data.get_weights(2)[ts.num_samples]["descendant_tips"] + assert len(span_data.get_spans(node)), 1 + assert 2 in span_data.get_spans(2)[ts.num_samples]["descendant_tips"] def test_one_tree_n3(self): ts = utility_functions.single_tree_ts_n3() n = ts.num_samples - span_data = self.verify_weights(ts) - # with a single tree there should only be one weight + span_data = self.verify_spans(ts) + # with a single tree there should only be one span for node in span_data.nodes_to_date: - assert len(span_data.get_weights(node)), 1 + assert len(span_data.get_spans(node)), 1 for nd, expd_tips in [ (4, 3), # Node 4 (root) expected to have 3 descendant tips (3, 2), ]: # Node 3 (1st internal node) expected to have 2 descendant tips - assert np.isin(span_data.get_weights(nd)[n]["descendant_tips"], expd_tips) + assert np.isin(span_data.get_spans(nd)[n]["descendant_tips"], expd_tips) def test_one_tree_n4(self): ts = utility_functions.single_tree_ts_n4() n = ts.num_samples - span_data = self.verify_weights(ts) - # with a single tree there should only be one weight + span_data = self.verify_spans(ts) + # with a single tree there should only be one span for node in span_data.nodes_to_date: - assert len(span_data.get_weights(node)), 1 + assert len(span_data.get_spans(node)), 1 for nd, expd_tips in [ (6, 4), # Node 6 (root) expected to have 4 descendant tips (5, 3), # Node 5 (1st internal node) expected to have 3 descendant tips (4, 2), ]: # Node 4 (2nd internal node) expected to have 3 descendant tips - assert np.isin(span_data.get_weights(nd)[n]["descendant_tips"], expd_tips) + assert np.isin(span_data.get_spans(nd)[n]["descendant_tips"], expd_tips) def test_two_trees(self): ts = utility_functions.two_tree_ts() n = ts.num_samples - span_data = self.verify_weights(ts) - assert span_data.lookup_weight(5, n, 3) == 1.0 # Root on R tree - assert span_data.lookup_weight(4, n, 3) == 0.2 # Root on L tree ... + span_data = self.verify_spans(ts) + assert span_data.lookup_span(5, n, 3) == 0.8 # Root on R tree + assert span_data.lookup_span(4, n, 3) == 0.2 # Root on L tree ... # ... but internal node on R tree - assert span_data.lookup_weight(4, n, 2) == 0.8 - assert span_data.lookup_weight(3, n, 2) == 1.0 # Internal nd on L tree + assert span_data.lookup_span(4, n, 2) == 0.8 + assert span_data.lookup_span(3, n, 2) == 0.2 # Internal nd on L tree def test_missing_tree(self): - ts = utility_functions.two_tree_ts().keep_intervals([(0, 0.2)], simplify=False) + keep = 0.2 + ts = utility_functions.two_tree_ts().keep_intervals([(0, keep)], simplify=False) n = ts.num_samples # Here we have no reference in the trees to node 5 with pytest.raises(ValueError, match="nodes not in any tree"): SpansBySamples(ts) ts = ts.simplify() - span_data = self.verify_weights(ts) + span_data = self.verify_spans(ts) # Root on (deleted) R tree is missing assert 5 not in span_data.nodes_to_date - assert span_data.lookup_weight(4, n, 3) == 1.0 # Root on L tree ... + assert span_data.lookup_span(4, n, 3) == keep # Root on L tree ... # ... but internal on (deleted) R tree - assert not np.isin(span_data.get_weights(4)[n]["descendant_tips"], 2) - assert span_data.lookup_weight(3, n, 2) == 1.0 # Internal nd on L tree + assert not np.isin(span_data.get_spans(4)[n]["descendant_tips"], 2) + assert span_data.lookup_span(3, n, 2) == keep # Internal nd on L tree def test_tree_with_unary_nodes(self): ts = utility_functions.single_tree_ts_with_unary() with pytest.raises(ValueError, match="unary"): - self.verify_weights(ts) + self.verify_spans(ts) @pytest.mark.skip("Unary node is internal then the oldest node") def test_tree_with_unary_nodes_oldest(self): ts = utility_functions.two_tree_ts_with_unary_n3() n = ts.num_samples - span_data = self.verify_weights(ts) - assert span_data.lookup_weight(9, n, 4) == 0.5 - assert span_data.lookup_weight(8, n, 4) == 1.0 - assert span_data.lookup_weight(7, n, 1) == 0.5 - assert span_data.lookup_weight(7, n, 4) == 0.5 - assert span_data.lookup_weight(6, n, 2) == 0.5 - assert span_data.lookup_weight(6, n, 4) == 0.5 - assert span_data.lookup_weight(5, n, 2) == 0.5 - assert span_data.lookup_weight(4, n, 2) == 1.0 + span_data = self.verify_spans(ts) + assert span_data.lookup_span(9, n, 4) == 0.5 + assert span_data.lookup_span(8, n, 4) == 1.0 + assert span_data.lookup_span(7, n, 1) == 0.5 + assert span_data.lookup_span(7, n, 4) == 0.5 + assert span_data.lookup_span(6, n, 2) == 0.5 + assert span_data.lookup_span(6, n, 4) == 0.5 + assert span_data.lookup_span(5, n, 2) == 0.5 + assert span_data.lookup_span(4, n, 2) == 1.0 def test_polytomy_tree(self): ts = utility_functions.polytomy_tree_ts() - span_data = self.verify_weights(ts) - assert span_data.lookup_weight(3, ts.num_samples, 3) == 1.0 + span_data = self.verify_spans(ts) + assert span_data.lookup_span(3, ts.num_samples, 3) == 1.0 - def test_larger_find_node_tip_weights(self): + def test_larger_find_node_tip_spans(self): ts = msprime.simulate( 10, recombination_rate=5, mutation_rate=5, random_seed=123 ) assert ts.num_trees > 1 - self.verify_weights(ts) + self.verify_spans(ts) def test_dangling_nodes_error(self): ts = utility_functions.single_tree_ts_n2_dangling() with pytest.raises(ValueError, match="dangling"): - self.verify_weights(ts) + self.verify_spans(ts) def test_single_tree_n2_delete_intervals(self): ts = utility_functions.single_tree_ts_n2() - deleted_interval_ts = ts.delete_intervals([[0.5, 0.6]]) + delete_span = 0.1 + deleted_interval_ts = ts.delete_intervals([[0.4, 0.4 + delete_span]]) n = deleted_interval_ts.num_samples - span_data = self.verify_weights(ts) - span_data_deleted = self.verify_weights(deleted_interval_ts) - assert span_data.lookup_weight(2, n, 2) == span_data_deleted.lookup_weight( - 2, n, 2 - ) + span_data_deleted = self.verify_spans(deleted_interval_ts) + assert span_data_deleted.lookup_span(2, n, 2) == pytest.approx(1 - delete_span) def test_single_tree_n4_delete_intervals(self): ts = utility_functions.single_tree_ts_n4() - deleted_interval_ts = ts.delete_intervals([[0.5, 0.6]]) + delete_span = 0.1 + deleted_interval_ts = ts.delete_intervals([[0.5, 0.5 + delete_span]]) n = deleted_interval_ts.num_samples - span_data = self.verify_weights(ts) - span_data_deleted = self.verify_weights(deleted_interval_ts) - assert span_data.lookup_weight(4, n, 2) == span_data_deleted.lookup_weight( - 4, n, 2 - ) - assert span_data.lookup_weight(5, n, 3) == span_data_deleted.lookup_weight( - 5, n, 3 - ) - assert span_data.lookup_weight(6, n, 4) == span_data_deleted.lookup_weight( - 6, n, 4 - ) + span_data_deleted = self.verify_spans(deleted_interval_ts) + assert span_data_deleted.lookup_span(4, n, 2) == pytest.approx(1 - delete_span) + assert span_data_deleted.lookup_span(5, n, 3) == pytest.approx(1 - delete_span) + assert span_data_deleted.lookup_span(6, n, 4) == pytest.approx(1 - delete_span) def test_two_tree_ts_delete_intervals(self): ts = utility_functions.two_tree_ts() - deleted_interval_ts = ts.delete_intervals([[0.5, 0.6]]) + delete_span = 0.1 + deleted_interval_ts = ts.delete_intervals([[0.6, 0.6 + delete_span]]) n = deleted_interval_ts.num_samples - span_data = self.verify_weights(ts) - span_data_deleted = self.verify_weights(deleted_interval_ts) - assert span_data.lookup_weight(3, n, 2) == span_data_deleted.lookup_weight( - 3, n, 2 - ) - assert span_data_deleted.lookup_weight(4, n, 2)[0] == pytest.approx(0.7 / 0.9) - assert span_data_deleted.lookup_weight(4, n, 3)[0] == pytest.approx(0.2 / 0.9) - assert span_data.lookup_weight(5, n, 3) == span_data_deleted.lookup_weight( - 3, n, 2 + span_data = self.verify_spans(ts) + span_data_deleted = self.verify_spans(deleted_interval_ts) + assert span_data.lookup_span(3, n, 2) == span_data_deleted.lookup_span(3, n, 2) + assert span_data_deleted.lookup_span(4, n, 2)[0] == pytest.approx(0.7) + assert span_data_deleted.lookup_span(4, n, 3)[0] == pytest.approx(0.2) + assert ( + span_data.lookup_span(5, n, 3) / span_data.node_spans[5] + == span_data_deleted.lookup_span(3, n, 2) / span_data.node_spans[3] ) @pytest.mark.skip("YAN to fix") @@ -261,7 +254,7 @@ def test_truncated_nodes(self): truncated_ts = utility_functions.truncate_ts_samples( ts, average_span=200, random_seed=123 ) - span_data = self.verify_weights(truncated_ts) + span_data = self.verify_spans(truncated_ts) raise NotImplementedError(str(span_data)) @@ -1658,7 +1651,7 @@ def test_node_metadata_simulated_tree(self): for met in tskit.unpack_bytes(metadata, metadata_offset) if len(met.decode()) > 0 ] - assert np.array_equal(unconstrained_mn, mn_post[larger_ts.num_samples :]) + assert np.allclose(unconstrained_mn, mn_post[larger_ts.num_samples :]) assert np.all( dated_ts.tables.nodes.time[larger_ts.num_samples :] >= mn_post[larger_ts.num_samples :] @@ -1959,7 +1952,7 @@ def test_sites_time_root_mutation(self): def test_sites_time_multiple_mutations(self): ts = utility_functions.single_tree_ts_2mutations_n3() sites_time = tsdate.sites_time_from_ts(ts, unconstrained=False) - assert np.array_equal(sites_time, [10]) + assert np.allclose(sites_time, [10]) def test_sites_time_simulated(self): larger_ts = msprime.simulate( @@ -1969,11 +1962,11 @@ def test_sites_time_simulated(self): larger_ts, mutation_rate=None, population_size=10000 ) dated = date(larger_ts, mutation_rate=None, population_size=10000) - assert np.array_equal( + assert np.allclose( mn_post[larger_ts.tables.mutations.node], tsdate.sites_time_from_ts(dated, unconstrained=True, min_time=0), ) - assert np.array_equal( + assert np.allclose( dated.tables.nodes.time[larger_ts.tables.mutations.node], tsdate.sites_time_from_ts(dated, unconstrained=False, min_time=0), ) diff --git a/tests/test_priors.py b/tests/test_priors.py index cd038c1a..318a92d1 100644 --- a/tests/test_priors.py +++ b/tests/test_priors.py @@ -24,6 +24,7 @@ """ import logging +import numpy as np import pytest import utility_functions @@ -62,6 +63,43 @@ def test_clear_precalc_debug(self, caplog): priors.clear_precalculated_priors() assert "not yet created" in caplog.text + @pytest.mark.parametrize("logwt", [True, False]) + def test_mixture_expect_and_var(self, logwt): + priors = ConditionalCoalescentTimes(None) + priors.add(3) + params = {3: {"descendant_tips": [3, 2], "span": np.array([0, 200])}} + mean1, var1 = priors.mixture_expect_and_var(params, weight_by_log_span=logwt) + params = {3: {"descendant_tips": [2], "span": np.array([100])}} + mean2, var2 = priors.mixture_expect_and_var(params, weight_by_log_span=logwt) + assert mean1 == pytest.approx(1 / 3) # 1/N for a cherry + assert var1 == pytest.approx(1 / 9) + assert mean1 == mean2 + assert var1 == var2 + + def test_mixture_expect_and_var_weight(self): + priors = ConditionalCoalescentTimes(None) + priors.add(4) + priors.add(5) + span = np.array([1, 3]) + params = { + 4: {"descendant_tips": [2], "span": span[0]}, + 5: {"descendant_tips": [2], "span": span[1]}, + } + linwt = priors.mixture_expect_and_var(params, weight_by_log_span=False) + assert linwt[0] == pytest.approx( + (1 / 4 * span[0] + 1 / 5 * span[1]) / np.sum(span) + ) + + # use exponential version to test log weights + # The log weighting adds one to the value, so here we subtract one + exp_span = np.exp(span) - 1 + params = { + 4: {"descendant_tips": [2], "span": exp_span[0]}, + 5: {"descendant_tips": [2], "span": exp_span[1]}, + } + logwt = priors.mixture_expect_and_var(params, weight_by_log_span=True) + assert np.allclose(linwt, logwt) + class TestSpansBySamples: def test_repr(self): diff --git a/tsdate/prior.py b/tsdate/prior.py index 47df9c42..30370e67 100644 --- a/tsdate/prior.py +++ b/tsdate/prior.py @@ -85,6 +85,8 @@ def __init__( self.n_approx = precalc_approximation_n self.prior_store = {} self.progress = progress + self.mean_column = PriorParams.field_index("mean") + self.var_column = PriorParams.field_index("var") if precalc_approximation_n: # Create lookup table based on a large n that can be used for n > ~50 @@ -300,16 +302,53 @@ def tau_var_exact(self, total_tips, all_tips): ) ] + def mixture_expect_and_var(self, mixture, weight_by_log_span=False): + """ + Return the expectation and variance of a coalescent mixture + mixture is a dict of numpy recarrays of the form + {N: {descendant_tips: [N_tips], span: [N_spans]}} + + weight_by_log_span is a boolean that determines whether the + weights are taken as the log of the span of the node + (plus one to avoid log(0). Testing indicates that + this gives a slightly better fit to the observed values + under the coalescent with recombination. + + Note, however, that both the expected mean and (especially) + the expected variance are substantially affected by the + *total length* of the span rather than just the relative + weights, which is not taken into account here + + """ + expectation = 0 + first = secnd = 0 + weight_sum = 0 + for N, tip_dict in mixture.items(): + # assert 1 not in tip_dict.descendant_tips + mean_time = self[N][tip_dict["descendant_tips"], self.mean_column] + var_time = self[N][tip_dict["descendant_tips"], self.var_column] + # Add one to avoid log(0) + w = np.log(tip_dict["span"] + 1) if weight_by_log_span else tip_dict["span"] + # Mixture expectation + expectation += np.sum(mean_time * w) + # Mixture variance + first += np.sum(var_time * w) + secnd += np.sum(mean_time**2 * w) + weight_sum += np.sum(w) + mean = expectation / weight_sum + var = (first + secnd) / weight_sum - (mean**2) + return mean, var + def get_mixture_prior_params(self, spans_by_samples): """ - Given an object that can be queried for tip weights for a node, - and a set of conditional coalescent priors for different + Given an object that can be queried for spans by num descendant tips + for a node, and a set of conditional coalescent priors for different numbers of sample tips under a node, return distribution parameters (shape and scale) that best fit the distribution for that node. :param .SpansBySamples spans_by_samples: An instance of the :class:`SpansBySamples` class that can be used to obtain - weights for each. + spans for each node to use as weights. :return: A numpy array whose rows corresponds to the node id in ``spans_by_samples.nodes_to_date`` and whose columns are the parameter columns in PriorParams (i.e. not including the mean and variance) @@ -318,28 +357,10 @@ def get_mixture_prior_params(self, spans_by_samples): :rtype: numpy.ndarray """ - mean_column = PriorParams.field_index("mean") - var_column = PriorParams.field_index("var") param_cols = np.array( [i for i, f in enumerate(PriorParams._fields) if f not in ("mean", "var")] ) - def mixture_expect_and_var(mixture): - expectation = 0 - first = secnd = 0 - for N, tip_dict in mixture.items(): - # assert 1 not in tip_dict.descendant_tips - mean = self[N][tip_dict["descendant_tips"], mean_column] - var = self[N][tip_dict["descendant_tips"], var_column] - # Mixture expectation - expectation += np.sum(mean * tip_dict["weight"]) - # Mixture variance - first += np.sum(var * tip_dict["weight"]) - secnd += np.sum(mean**2 * tip_dict["weight"]) - mean = expectation - var = first + secnd - (expectation**2) - return mean, var - seen_mixtures = {} # allocate space for params for all nodes, even though we only use nodes_to_date num_nodes, num_params = spans_by_samples.ts.num_nodes, len(param_cols) @@ -350,32 +371,34 @@ def mixture_expect_and_var(mixture): disable=not self.progress, desc="Find Mixture Priors", ): - mixture = spans_by_samples.get_weights(node) + mixture = spans_by_samples.get_spans(node) if len(mixture) == 1: # The norm: this node spans trees that all have the same set of samples - total_tips, weight_arr = next(iter(mixture.items())) - if weight_arr.shape[0] == 1: - d_tips = weight_arr["descendant_tips"][0] + total_tips, span_arr = next(iter(mixture.items())) + if span_arr.shape[0] == 1: + d_tips = span_arr["descendant_tips"][0] # This node is not a mixture - can use the standard coalescent prior priors[node] = self[total_tips][d_tips, param_cols] - elif weight_arr.shape[0] <= 5: + elif span_arr.shape[0] <= 5: # Making mixture priors is a little expensive. We can help by caching # in those cases where we have only a few mixtures # (arbitrarily set here as <= 5 mixtures) - mixture_hash = (total_tips, weight_arr.tobytes()) + mixture_hash = (total_tips, span_arr.tobytes()) if mixture_hash not in seen_mixtures: priors[node] = seen_mixtures[mixture_hash] = self.func_approx( - *mixture_expect_and_var(mixture) + *self.mixture_expect_and_var(mixture) ) else: priors[node] = seen_mixtures[mixture_hash] else: # a large number of mixtures in this node - don't bother caching - priors[node] = self.func_approx(*mixture_expect_and_var(mixture)) + priors[node] = self.func_approx( + *self.mixture_expect_and_var(mixture) + ) else: # The node spans trees with multiple total tip numbers, # don't use the cache - priors[node] = self.func_approx(*mixture_expect_and_var(mixture)) + priors[node] = self.func_approx(*self.mixture_expect_and_var(mixture)) # Check that references to the tskit.NULL'th node return NaNs, as we will later # be indexing into the prior array using a node mapping which could have NULLs assert np.all(np.isnan(priors[tskit.NULL, :])) @@ -387,9 +410,9 @@ class SpansBySamples: A class to efficiently calculate the genomic spans covered by each non-sample node, broken down by the number of samples that descend directly from that node. This is used to calculate the conditional - coalescent prior. The main method is :meth:`get_weights`, which - returns the spans for a node, normalized by the total span that that - node covers in the tree sequence. + coalescent prior. The main method is :meth:`get_spans`, which + returns the spans for each node, broken down by the number of + samples under different regions. .. note:: This assumes that all edges connect to the same tree - i.e. there is only a single topology present at each point in the @@ -398,7 +421,7 @@ class SpansBySamples: "missing data" nodes. :ivar tree_sequence: A reference to the tree sequence that was used to - generate the spans and weights + generate the spans :vartype tree_sequence: tskit.TreeSequence :ivar total_fixed_at_0_counts: A numpy array of unique numbers which list, in no particular order, the various sample counts among the trees @@ -417,7 +440,7 @@ class SpansBySamples: :ivar nodes_to_date: An numpy array containing all the node ids in the tree sequence that we wish to date. These are usually all the non-sample nodes, and also provide the node numbers that are valid parameters for the - :meth:`weights` method. + :meth:`get_spans` method. :vartype nodes_to_date: numpy.ndarray (dtype=np.uint32) """ @@ -466,10 +489,10 @@ def __repr__(self): ret = [] for n in range(self.ts.num_nodes): items = [] - for tot_tips, weights in self.get_weights(n).items(): + for tot_tips, spans in self.get_spans(n).items(): items.append( "[{}] / {} ".format( - ", ".join([f"{a}: {b}" for a, b in weights]), tot_tips + ", ".join([f"{a}: {b}" for a, b in spans]), tot_tips ) ) ret.append(f"Node {n: >3}: " + "{" + ", ".join(items) + "}") @@ -563,7 +586,9 @@ def save_to_spans(prev_tree, node, num_fixed_at_0_treenodes): " Skipping for now".format(node, prev_tree.index) ) return None - # Weights are exponential fractions: if no unary nodes above, we have + # for unary nodes, a proportion of the span is allocated + # according to the coalescent node above and the coalescent + # node below. If there are extra unary nodes above or below # weight = 1/2 from parent. If one unary node above, 1/4 from parent, etc wt = 2 ** (unary_nodes_above + 1) # 1/wt from abpve iwt = wt / (wt - 1.0) # 1/iwt from below @@ -763,8 +788,8 @@ def second_pass(self, trees_with_undated, n_tips_per_tree): "Assigning prior to unary node {}: connected to node {} which" " has a prior in tree {}".format(node, n, tree_id) ) - for n_tips, weights in self._spans[n].items(): - for k, v in weights.items(): + for n_tips, spans in self._spans[n].items(): + for k, v in spans.items(): if k <= 0: raise ValueError(f"Node {n} has no fixed descendants") local_weight = v / self.node_spans[n] @@ -806,13 +831,13 @@ def finalize(self): """ normalize the spans in self._spans by the values in self.node_spans, and overwrite the results (as we don't need them any more), providing a - shortcut to by setting normalized_node_span_data. Also provide the + shortcut to by setting node_span_data. Also provide the nodes_to_date value. """ - assert not hasattr(self, "normalized_node_span_data"), "Already finalized" - weight_dtype = np.dtype( + assert not hasattr(self, "node_span_data"), "Already finalized" + spans_dtype = np.dtype( { - "names": ("descendant_tips", "weight"), + "names": ("descendant_tips", "span"), "formats": (np.uint64, base.FLOAT_DTYPE), } ) @@ -825,55 +850,54 @@ def finalize(self): ) ) - for node, weights_by_total_tips in self._spans.items(): + for node, spans_by_total_tips in self._spans.items(): self._spans[node] = {} # Overwrite, so we don't leave the old data around - for num_samples, weights in sorted(weights_by_total_tips.items()): - wt = np.array([(k, v) for k, v in weights.items()], dtype=weight_dtype) - with np.errstate(invalid="ignore"): - # Allow self.node_spans[node]=0 -> nan - wt["weight"] /= self.node_spans[node] + for num_samples, spans in sorted(spans_by_total_tips.items()): + wt = np.array([(k, v) for k, v in spans.items()], dtype=spans_dtype) self._spans[node][num_samples] = wt # Assign into the instance, for further reference - self.normalized_node_span_data = self._spans + self.node_span_data = self._spans self.nodes_to_date = np.array(list(self._spans.keys()), dtype=np.uint64) - def get_weights(self, node): + def get_spans(self, node): """ - Access the main calculated results from this class, returning weights - for a node contained within a dict of dicts. Weights for each node - (i.e. normalized genomic spans) sum to one, and are used to construct + Access the main calculated results from this class, returning spans + for a node contained within a dict of dicts. Spans for each node + are divided into regions with different numbers of sample descendants, + and sum to the total span over which that node is present + in trees along the tree sequence. They are used to construct the mixed conditional coalescent prior. For each coalescent node, the - returned weights are categorised firstly by the total number of sample + returned spans are categorised firstly by the total number of sample nodes (or "tips") ( :math:`T` ) in the tree(s) covered by this node, then by the number of descendant samples, :math:`k`. In other words, - ``weights(u)[T][k]`` gives the fraction of the genome over which node + ``spans(u)[T][k]`` gives the fraction of the genome over which node ``u`` is present in a tree of ``T`` total samples with exactly ``k`` samples descending from the node. Although ``k`` may take any value from 2 up to ``T``, the values are likely to be very sparse, and many values of both ``T`` and ``k`` are likely to be missing from the - returned weights. For example, if there are no trees in which the node + returned spans. For example, if there are no trees in which the node ``u`` has exactly 2 descendant samples, then none of the inner dictionaries returned by this method will have a key of 2. - Non-coalescent (unary) nodes are treated differently. A unary node - returns a 50:50 mix of the coalescent node above and the coalescent - node below it. + Non-coalescent (unary) regions of nodes are treated differently. A unary + region of a node returns a 50:50 mix of the coalescent node above and + the coalescent node below it. - :param int node: The node for which we want weights. + :param int node: The node for which we want spans. :return: A dictionary, whose keys ( :math:`n_t` ) are the total number of samples in the trees in a tree sequence, and whose values are - themselves a dictionary where key :math:`k` gives the weight (genomic - span, normalized by the total span over which the node exists) for - :math:`k` descendant samples, as a floating point number. For any node, - the normalization means that all the weights should sum to one. + themselves a dictionary where key :math:`k` gives the genomic + span for :math:`k` descendant samples, as a floating point number. + For any node ``u``, the normalization means that all the spans should + sum to ``self.node_spans[u]``. :rtype: dict(int, numpy.ndarray)' """ - return self.normalized_node_span_data[node] + return self.node_span_data[node] - def lookup_weight(self, node, total_tips, descendant_tips): + def lookup_span(self, node, total_tips, descendant_tips): # Only used for testing - which = self.get_weights(node)[total_tips]["descendant_tips"] == descendant_tips - return self.get_weights(node)[total_tips]["weight"][which] + which = self.get_spans(node)[total_tips]["descendant_tips"] == descendant_tips + return self.get_spans(node)[total_tips]["span"][which] def create_timepoints(base_priors, n_points=21):