Skip to content

Commit

Permalink
Merge pull request #294 from hyanwong/node-spans-not-weights
Browse files Browse the repository at this point in the history
Expose node spans not weights
  • Loading branch information
hyanwong authored Jul 24, 2023
2 parents 7282a30 + 6e816d2 commit 667a49d
Show file tree
Hide file tree
Showing 3 changed files with 200 additions and 145 deletions.
143 changes: 68 additions & 75 deletions tests/test_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ def test_gamma_approx(self):


class TestNodeTipWeights(unittest.TestCase):
def verify_weights(self, ts):
def verify_spans(self, ts):
span_data = SpansBySamples(ts)
# Check all non-sample nodes in a tree are represented
nonsample_nodes = collections.defaultdict(float)
Expand All @@ -106,150 +106,143 @@ def verify_weights(self, ts):
assert span == pytest.approx(span_data.node_spans[id_])
for focal_node in span_data.nodes_to_date:
wt = 0
for _, weights in span_data.get_weights(focal_node).items():
for _, spans in span_data.get_spans(focal_node).items():
assert 0 <= focal_node < ts.num_nodes
wt += np.sum(weights["weight"])
assert max(weights["descendant_tips"]) <= ts.num_samples
wt += np.sum(spans["span"])
assert max(spans["descendant_tips"]) <= ts.num_samples
if not np.isnan(wt):
# Dangling nodes will have wt=nan
assert wt == pytest.approx(1.0)
assert wt == pytest.approx(span_data.node_spans[focal_node])
return span_data

def test_one_tree_n2(self):
ts = utility_functions.single_tree_ts_n2()
span_data = self.verify_weights(ts)
# with a single tree there should only be one weight
span_data = self.verify_spans(ts)
# with a single tree there should only be one span
for node in span_data.nodes_to_date:
assert len(span_data.get_weights(node)), 1
assert 2 in span_data.get_weights(2)[ts.num_samples]["descendant_tips"]
assert len(span_data.get_spans(node)), 1
assert 2 in span_data.get_spans(2)[ts.num_samples]["descendant_tips"]

def test_one_tree_n3(self):
ts = utility_functions.single_tree_ts_n3()
n = ts.num_samples
span_data = self.verify_weights(ts)
# with a single tree there should only be one weight
span_data = self.verify_spans(ts)
# with a single tree there should only be one span
for node in span_data.nodes_to_date:
assert len(span_data.get_weights(node)), 1
assert len(span_data.get_spans(node)), 1
for nd, expd_tips in [
(4, 3), # Node 4 (root) expected to have 3 descendant tips
(3, 2),
]: # Node 3 (1st internal node) expected to have 2 descendant tips
assert np.isin(span_data.get_weights(nd)[n]["descendant_tips"], expd_tips)
assert np.isin(span_data.get_spans(nd)[n]["descendant_tips"], expd_tips)

def test_one_tree_n4(self):
ts = utility_functions.single_tree_ts_n4()
n = ts.num_samples
span_data = self.verify_weights(ts)
# with a single tree there should only be one weight
span_data = self.verify_spans(ts)
# with a single tree there should only be one span
for node in span_data.nodes_to_date:
assert len(span_data.get_weights(node)), 1
assert len(span_data.get_spans(node)), 1
for nd, expd_tips in [
(6, 4), # Node 6 (root) expected to have 4 descendant tips
(5, 3), # Node 5 (1st internal node) expected to have 3 descendant tips
(4, 2),
]: # Node 4 (2nd internal node) expected to have 3 descendant tips
assert np.isin(span_data.get_weights(nd)[n]["descendant_tips"], expd_tips)
assert np.isin(span_data.get_spans(nd)[n]["descendant_tips"], expd_tips)

def test_two_trees(self):
ts = utility_functions.two_tree_ts()
n = ts.num_samples
span_data = self.verify_weights(ts)
assert span_data.lookup_weight(5, n, 3) == 1.0 # Root on R tree
assert span_data.lookup_weight(4, n, 3) == 0.2 # Root on L tree ...
span_data = self.verify_spans(ts)
assert span_data.lookup_span(5, n, 3) == 0.8 # Root on R tree
assert span_data.lookup_span(4, n, 3) == 0.2 # Root on L tree ...
# ... but internal node on R tree
assert span_data.lookup_weight(4, n, 2) == 0.8
assert span_data.lookup_weight(3, n, 2) == 1.0 # Internal nd on L tree
assert span_data.lookup_span(4, n, 2) == 0.8
assert span_data.lookup_span(3, n, 2) == 0.2 # Internal nd on L tree

def test_missing_tree(self):
ts = utility_functions.two_tree_ts().keep_intervals([(0, 0.2)], simplify=False)
keep = 0.2
ts = utility_functions.two_tree_ts().keep_intervals([(0, keep)], simplify=False)
n = ts.num_samples
# Here we have no reference in the trees to node 5
with pytest.raises(ValueError, match="nodes not in any tree"):
SpansBySamples(ts)
ts = ts.simplify()
span_data = self.verify_weights(ts)
span_data = self.verify_spans(ts)
# Root on (deleted) R tree is missing
assert 5 not in span_data.nodes_to_date
assert span_data.lookup_weight(4, n, 3) == 1.0 # Root on L tree ...
assert span_data.lookup_span(4, n, 3) == keep # Root on L tree ...
# ... but internal on (deleted) R tree
assert not np.isin(span_data.get_weights(4)[n]["descendant_tips"], 2)
assert span_data.lookup_weight(3, n, 2) == 1.0 # Internal nd on L tree
assert not np.isin(span_data.get_spans(4)[n]["descendant_tips"], 2)
assert span_data.lookup_span(3, n, 2) == keep # Internal nd on L tree

def test_tree_with_unary_nodes(self):
ts = utility_functions.single_tree_ts_with_unary()
with pytest.raises(ValueError, match="unary"):
self.verify_weights(ts)
self.verify_spans(ts)

@pytest.mark.skip("Unary node is internal then the oldest node")
def test_tree_with_unary_nodes_oldest(self):
ts = utility_functions.two_tree_ts_with_unary_n3()
n = ts.num_samples
span_data = self.verify_weights(ts)
assert span_data.lookup_weight(9, n, 4) == 0.5
assert span_data.lookup_weight(8, n, 4) == 1.0
assert span_data.lookup_weight(7, n, 1) == 0.5
assert span_data.lookup_weight(7, n, 4) == 0.5
assert span_data.lookup_weight(6, n, 2) == 0.5
assert span_data.lookup_weight(6, n, 4) == 0.5
assert span_data.lookup_weight(5, n, 2) == 0.5
assert span_data.lookup_weight(4, n, 2) == 1.0
span_data = self.verify_spans(ts)
assert span_data.lookup_span(9, n, 4) == 0.5
assert span_data.lookup_span(8, n, 4) == 1.0
assert span_data.lookup_span(7, n, 1) == 0.5
assert span_data.lookup_span(7, n, 4) == 0.5
assert span_data.lookup_span(6, n, 2) == 0.5
assert span_data.lookup_span(6, n, 4) == 0.5
assert span_data.lookup_span(5, n, 2) == 0.5
assert span_data.lookup_span(4, n, 2) == 1.0

def test_polytomy_tree(self):
ts = utility_functions.polytomy_tree_ts()
span_data = self.verify_weights(ts)
assert span_data.lookup_weight(3, ts.num_samples, 3) == 1.0
span_data = self.verify_spans(ts)
assert span_data.lookup_span(3, ts.num_samples, 3) == 1.0

def test_larger_find_node_tip_weights(self):
def test_larger_find_node_tip_spans(self):
ts = msprime.simulate(
10, recombination_rate=5, mutation_rate=5, random_seed=123
)
assert ts.num_trees > 1
self.verify_weights(ts)
self.verify_spans(ts)

def test_dangling_nodes_error(self):
ts = utility_functions.single_tree_ts_n2_dangling()
with pytest.raises(ValueError, match="dangling"):
self.verify_weights(ts)
self.verify_spans(ts)

def test_single_tree_n2_delete_intervals(self):
ts = utility_functions.single_tree_ts_n2()
deleted_interval_ts = ts.delete_intervals([[0.5, 0.6]])
delete_span = 0.1
deleted_interval_ts = ts.delete_intervals([[0.4, 0.4 + delete_span]])
n = deleted_interval_ts.num_samples
span_data = self.verify_weights(ts)
span_data_deleted = self.verify_weights(deleted_interval_ts)
assert span_data.lookup_weight(2, n, 2) == span_data_deleted.lookup_weight(
2, n, 2
)
span_data_deleted = self.verify_spans(deleted_interval_ts)
assert span_data_deleted.lookup_span(2, n, 2) == pytest.approx(1 - delete_span)

def test_single_tree_n4_delete_intervals(self):
ts = utility_functions.single_tree_ts_n4()
deleted_interval_ts = ts.delete_intervals([[0.5, 0.6]])
delete_span = 0.1
deleted_interval_ts = ts.delete_intervals([[0.5, 0.5 + delete_span]])
n = deleted_interval_ts.num_samples
span_data = self.verify_weights(ts)
span_data_deleted = self.verify_weights(deleted_interval_ts)
assert span_data.lookup_weight(4, n, 2) == span_data_deleted.lookup_weight(
4, n, 2
)
assert span_data.lookup_weight(5, n, 3) == span_data_deleted.lookup_weight(
5, n, 3
)
assert span_data.lookup_weight(6, n, 4) == span_data_deleted.lookup_weight(
6, n, 4
)
span_data_deleted = self.verify_spans(deleted_interval_ts)
assert span_data_deleted.lookup_span(4, n, 2) == pytest.approx(1 - delete_span)
assert span_data_deleted.lookup_span(5, n, 3) == pytest.approx(1 - delete_span)
assert span_data_deleted.lookup_span(6, n, 4) == pytest.approx(1 - delete_span)

def test_two_tree_ts_delete_intervals(self):
ts = utility_functions.two_tree_ts()
deleted_interval_ts = ts.delete_intervals([[0.5, 0.6]])
delete_span = 0.1
deleted_interval_ts = ts.delete_intervals([[0.6, 0.6 + delete_span]])
n = deleted_interval_ts.num_samples
span_data = self.verify_weights(ts)
span_data_deleted = self.verify_weights(deleted_interval_ts)
assert span_data.lookup_weight(3, n, 2) == span_data_deleted.lookup_weight(
3, n, 2
)
assert span_data_deleted.lookup_weight(4, n, 2)[0] == pytest.approx(0.7 / 0.9)
assert span_data_deleted.lookup_weight(4, n, 3)[0] == pytest.approx(0.2 / 0.9)
assert span_data.lookup_weight(5, n, 3) == span_data_deleted.lookup_weight(
3, n, 2
span_data = self.verify_spans(ts)
span_data_deleted = self.verify_spans(deleted_interval_ts)
assert span_data.lookup_span(3, n, 2) == span_data_deleted.lookup_span(3, n, 2)
assert span_data_deleted.lookup_span(4, n, 2)[0] == pytest.approx(0.7)
assert span_data_deleted.lookup_span(4, n, 3)[0] == pytest.approx(0.2)
assert (
span_data.lookup_span(5, n, 3) / span_data.node_spans[5]
== span_data_deleted.lookup_span(3, n, 2) / span_data.node_spans[3]
)

@pytest.mark.skip("YAN to fix")
Expand All @@ -261,7 +254,7 @@ def test_truncated_nodes(self):
truncated_ts = utility_functions.truncate_ts_samples(
ts, average_span=200, random_seed=123
)
span_data = self.verify_weights(truncated_ts)
span_data = self.verify_spans(truncated_ts)
raise NotImplementedError(str(span_data))


Expand Down Expand Up @@ -1658,7 +1651,7 @@ def test_node_metadata_simulated_tree(self):
for met in tskit.unpack_bytes(metadata, metadata_offset)
if len(met.decode()) > 0
]
assert np.array_equal(unconstrained_mn, mn_post[larger_ts.num_samples :])
assert np.allclose(unconstrained_mn, mn_post[larger_ts.num_samples :])
assert np.all(
dated_ts.tables.nodes.time[larger_ts.num_samples :]
>= mn_post[larger_ts.num_samples :]
Expand Down Expand Up @@ -1959,7 +1952,7 @@ def test_sites_time_root_mutation(self):
def test_sites_time_multiple_mutations(self):
ts = utility_functions.single_tree_ts_2mutations_n3()
sites_time = tsdate.sites_time_from_ts(ts, unconstrained=False)
assert np.array_equal(sites_time, [10])
assert np.allclose(sites_time, [10])

def test_sites_time_simulated(self):
larger_ts = msprime.simulate(
Expand All @@ -1969,11 +1962,11 @@ def test_sites_time_simulated(self):
larger_ts, mutation_rate=None, population_size=10000
)
dated = date(larger_ts, mutation_rate=None, population_size=10000)
assert np.array_equal(
assert np.allclose(
mn_post[larger_ts.tables.mutations.node],
tsdate.sites_time_from_ts(dated, unconstrained=True, min_time=0),
)
assert np.array_equal(
assert np.allclose(
dated.tables.nodes.time[larger_ts.tables.mutations.node],
tsdate.sites_time_from_ts(dated, unconstrained=False, min_time=0),
)
Expand Down
38 changes: 38 additions & 0 deletions tests/test_priors.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
"""
import logging

import numpy as np
import pytest
import utility_functions

Expand Down Expand Up @@ -62,6 +63,43 @@ def test_clear_precalc_debug(self, caplog):
priors.clear_precalculated_priors()
assert "not yet created" in caplog.text

@pytest.mark.parametrize("logwt", [True, False])
def test_mixture_expect_and_var(self, logwt):
priors = ConditionalCoalescentTimes(None)
priors.add(3)
params = {3: {"descendant_tips": [3, 2], "span": np.array([0, 200])}}
mean1, var1 = priors.mixture_expect_and_var(params, weight_by_log_span=logwt)
params = {3: {"descendant_tips": [2], "span": np.array([100])}}
mean2, var2 = priors.mixture_expect_and_var(params, weight_by_log_span=logwt)
assert mean1 == pytest.approx(1 / 3) # 1/N for a cherry
assert var1 == pytest.approx(1 / 9)
assert mean1 == mean2
assert var1 == var2

def test_mixture_expect_and_var_weight(self):
priors = ConditionalCoalescentTimes(None)
priors.add(4)
priors.add(5)
span = np.array([1, 3])
params = {
4: {"descendant_tips": [2], "span": span[0]},
5: {"descendant_tips": [2], "span": span[1]},
}
linwt = priors.mixture_expect_and_var(params, weight_by_log_span=False)
assert linwt[0] == pytest.approx(
(1 / 4 * span[0] + 1 / 5 * span[1]) / np.sum(span)
)

# use exponential version to test log weights
# The log weighting adds one to the value, so here we subtract one
exp_span = np.exp(span) - 1
params = {
4: {"descendant_tips": [2], "span": exp_span[0]},
5: {"descendant_tips": [2], "span": exp_span[1]},
}
logwt = priors.mixture_expect_and_var(params, weight_by_log_span=True)
assert np.allclose(linwt, logwt)


class TestSpansBySamples:
def test_repr(self):
Expand Down
Loading

0 comments on commit 667a49d

Please sign in to comment.