Skip to content

Commit

Permalink
Merge pull request #382 from hyanwong/split-disjoint
Browse files Browse the repository at this point in the history
Save flags and metadata when splitting disjoint nodes
  • Loading branch information
hyanwong authored Jun 4, 2024
2 parents e69e26b + 13d0269 commit a18f33c
Show file tree
Hide file tree
Showing 6 changed files with 490 additions and 144 deletions.
16 changes: 15 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,23 @@

**Bugfixes**

- In variational gamma, Rescale messages at end of each iteration to avoid numerical
- Variational gamma uses a rescaling approach which helps considerably if e.g.
population sizes vary over time

- Variational gamma does not use mutational area of branches, but average path
length, which reduces bias in tree sequences containing polytomies

- In variational gamma, rescale messages at end of each iteration to avoid numerical
instability.

**Breaking changes**

- Variational gamma uses an improper (flat) prior, and therefore
no longer needs `population_size` specifying.

- The standalone `preprocess_ts` function also applies the `split_disjoint_nodes`
method, which creates extra nodes but improves dating accuracy.

## [0.1.6] - 2024-01-07

**Breaking changes**
Expand Down
79 changes: 0 additions & 79 deletions tests/test_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,6 @@
from tsdate.prior import SpansBySamples
from tsdate.util import constrain_ages
from tsdate.util import nodes_time_unconstrained
from tsdate.util import split_disjoint_nodes


class TestBasicFunctions:
Expand Down Expand Up @@ -2244,81 +2243,3 @@ def test_bad_arguments(self):
demography.to_natural_timescale(time)
with pytest.raises(ValueError, match="a numpy array"):
demography.to_coalescent_timescale(time)


class TestNodeSplitting:
    """
    Test that node splitting routines have the desired outcome
    """

    @staticmethod
    def has_disjoint_nodes(ts):
        """
        Brute force check for disjoint nodes, by pulling out edge intervals for
        each node (whether it is the parent or the child of the edge); taking
        the union of intervals; checking that a single interval remains.

        Nodes that are not referenced by any edge have no intervals at all, so
        they are never reported as disjoint.

        :param ts: an object exposing ``num_nodes`` and ``edges()``, where each
            edge has ``parent``, ``child``, ``left`` and ``right`` attributes.
        :return: True if the union of edge intervals for at least one node
            consists of more than one interval, False otherwise.
        """

        def merge_intervals(intervals):
            # Union of (start, stop) intervals. Intervals that overlap or merely
            # touch (start == previous stop) are fused into one.
            if not intervals:
                # A node without edges has nothing to merge; indexing
                # intervals[0] below would otherwise raise IndexError.
                return []
            intervals = sorted(intervals, key=lambda x: x[0])
            result = []
            (start_candidate, stop_candidate) = intervals[0]
            for start, stop in intervals[1:]:
                if start <= stop_candidate:
                    stop_candidate = max(stop, stop_candidate)
                else:
                    result.append((start_candidate, stop_candidate))
                    (start_candidate, stop_candidate) = (start, stop)
            result.append((start_candidate, stop_candidate))
            return result

        intervals_by_node = {i: [] for i in range(ts.num_nodes)}
        for e in ts.edges():
            intervals_by_node[e.parent].append([e.left, e.right])
            intervals_by_node[e.child].append([e.left, e.right])

        for n in range(ts.num_nodes):
            intr = merge_intervals(intervals_by_node[n])
            # More than one merged interval means there is a gap in the span
            # over which this node is attached to edges.
            if len(intr) > 1:
                return True

        return False

    @staticmethod
    def childset_changes_with_root(ts):
        """
        If root nodes are split whenever their children change, the next root
        should have the same child set if it has the same ID.

        :param ts: an object exposing ``trees()``; each tree is queried for
            ``num_edges``, ``num_roots``, ``root`` and ``children(root)``.
        :return: False if any tree with edges has more than one root, or if a
            root keeps the ID it had in the previous tree while its set of
            children differs; True otherwise.
        """
        last_childset = frozenset()
        last_root = tskit.NULL
        for t in ts.trees():
            if t.num_edges == 0:
                # Empty tree: forget the previous root so that the comparison
                # restarts at the next tree that has edges.
                last_childset = frozenset()
                last_root = tskit.NULL
            else:
                if t.num_roots > 1:
                    return False
                childset = frozenset(t.children(t.root))
                if t.root == last_root and childset != last_childset:
                    return False
                last_childset = childset
                last_root = t.root
        return True

    def test_split_disjoint_nodes(self):
        """
        Splitting an inferred tree sequence should remove all disjoint nodes,
        adding nodes but leaving the number of edges unchanged.
        """
        ts = msprime.sim_ancestry(
            10,
            population_size=1e4,
            recombination_rate=1e-8,
            sequence_length=1e6,
            random_seed=1,
        )
        ts = msprime.sim_mutations(ts, rate=1e-8, random_seed=1)
        sample_data = tsinfer.SampleData.from_tree_sequence(ts)
        inferred_ts = tsinfer.infer(sample_data).simplify()
        split_ts = split_disjoint_nodes(inferred_ts)
        assert self.has_disjoint_nodes(inferred_ts)
        assert not self.has_disjoint_nodes(split_ts)
        assert split_ts.num_edges == inferred_ts.num_edges
        assert split_ts.num_nodes > inferred_ts.num_nodes
Loading

0 comments on commit a18f33c

Please sign in to comment.