Skip to content

Commit

Permalink
Rename normalisation
Browse files Browse the repository at this point in the history
  • Loading branch information
nspope committed May 10, 2024
1 parent 47a4048 commit a9e1f83
Show file tree
Hide file tree
Showing 6 changed files with 101 additions and 99 deletions.
6 changes: 3 additions & 3 deletions tests/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,7 +248,7 @@ def test_verbosity(self, tmp_path, caplog, flag, log_status):
)
def test_no_progress(self, method, tmp_path, capfd):
input_ts = msprime.simulate(4, random_seed=123)
params = f"-m 0.1 --method {method} --normalisation-intervals 0"
params = f"-m 0.1 --method {method} --rescaling-intervals 0"
self.run_tsdate_cli(tmp_path, input_ts, f"{self.popsize} {params}")
(out, err) = capfd.readouterr()
assert out == ""
Expand All @@ -257,7 +257,7 @@ def test_no_progress(self, method, tmp_path, capfd):

def test_progress(self, tmp_path, capfd):
input_ts = msprime.simulate(4, random_seed=123)
params = "--method inside_outside --progress --normalisation-intervals 0"
params = "--method inside_outside --progress --rescaling-intervals 0"
self.run_tsdate_cli(tmp_path, input_ts, f"{self.popsize} {params}")
(out, err) = capfd.readouterr()
assert out == ""
Expand All @@ -277,7 +277,7 @@ def test_progress(self, tmp_path, capfd):
def test_iterative_progress(self, tmp_path, capfd):
input_ts = msprime.simulate(4, random_seed=123)
params = "--method variational_gamma --mutation-rate 1e-8 "
params += "--progress --normalisation-intervals 0"
params += "--progress --rescaling-intervals 0"
self.run_tsdate_cli(tmp_path, input_ts, f"{self.popsize} {params}")
(out, err) = capfd.readouterr()
assert out == ""
Expand Down
1 change: 0 additions & 1 deletion tsdate/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
from .core import inside_outside # NOQA: F401
from .core import maximization # NOQA: F401
from .core import variational_gamma # NOQA: F401
from .normalisation import normalise_tree_sequence as normalise # NOQA: F401
from .prior import parameter_grid as build_parameter_grid # NOQA: F401
from .prior import prior_grid as build_prior_grid # NOQA: F401
from .provenance import __version__ # NOQA: F401
Expand Down
4 changes: 2 additions & 2 deletions tsdate/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,7 @@ def tsdate_cli_parser():
default=1000,
)
parser.add_argument(
"--normalisation-intervals",
"--rescaling-intervals",
type=float,
help=(
"The number of time intervals within which to estimate a time "
Expand Down Expand Up @@ -265,7 +265,7 @@ def run_date(args):
progress=args.progress,
max_iterations=args.max_iterations,
max_shape=args.max_shape,
normalisation_intervals=args.normalisation_intervals,
rescaling_intervals=args.rescaling_intervals,
)
else:
params = dict(
Expand Down
16 changes: 8 additions & 8 deletions tsdate/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1231,7 +1231,7 @@ def run(
max_iterations,
max_shape,
match_central_moments,
normalisation_intervals,
rescaling_intervals,
match_segregating_sites,
regularise_roots,
):
Expand All @@ -1251,9 +1251,9 @@ def run(
ep_maxitt=max_iterations,
max_shape=max_shape,
min_kl=min_kl,
norm_intervals=normalisation_intervals,
rescale_intervals=rescaling_intervals,
regularise=regularise_roots,
norm_segsites=match_segregating_sites,
rescale_segsites=match_segregating_sites,
progress=self.pbar,
)

Expand Down Expand Up @@ -1476,7 +1476,7 @@ def variational_gamma(
eps=None,
max_iterations=None,
max_shape=None,
normalisation_intervals=None,
rescaling_intervals=None,
match_central_moments=None, # undocumented
match_segregating_sites=None, # undocumented
regularise_roots=None, # undocumented
Expand Down Expand Up @@ -1505,7 +1505,7 @@ def variational_gamma(
:param float max_shape: The maximum value for the shape parameter in the variational
posteriors. This is equivalent to the maximum precision (inverse variance) on a
logarithmic scale. Default: None, treated as 1000.
:param float normalisation_intervals: For normalisation, the number of time
:param float rescaling_intervals: For time rescaling, the number of time
intervals within which to estimate a rescaling parameter. Default None,
treated as 1000.
:param \\**kwargs: Other keyword arguments as described in the :func:`date` wrapper
Expand Down Expand Up @@ -1537,8 +1537,8 @@ def variational_gamma(
max_iterations = 10
if max_shape is None:
max_shape = 1000
if normalisation_intervals is None:
normalisation_intervals = 1000
if rescaling_intervals is None:
rescaling_intervals = 1000
if match_central_moments is None:
match_central_moments = True
if match_segregating_sites is None:
Expand All @@ -1552,7 +1552,7 @@ def variational_gamma(
max_iterations=max_iterations,
max_shape=max_shape,
match_central_moments=match_central_moments,
normalisation_intervals=normalisation_intervals,
rescaling_intervals=rescaling_intervals,
match_segregating_sites=match_segregating_sites,
regularise_roots=regularise_roots,
)
Expand Down
141 changes: 71 additions & 70 deletions tsdate/normalisation.py → tsdate/rescaling.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@
from .approx import _i1w
from .approx import approximate_gamma_iqr
from .hypergeo import _gammainc_inv as gammainc_inv
from .util import mutation_span_array
from .util import mutation_span_array # NOQA: F401


@numba.njit(_i1w(_f1r, _i))
Expand Down Expand Up @@ -344,72 +344,73 @@ def edge_sampling_weight(
return edges_leaves


def normalise_tree_sequence(
ts, mutation_rate, *, normalisation_intervals=1000, match_segregating_sites=False
):
"""
Adjust the time scaling of a tree sequence so that expected mutational area
matches the expected number of mutations on a path from leaf to root, where
the expectation is taken over all paths and bases in the sequence.
:param tskit.TreeSequence ts: the tree sequence to normalise
:param float mutation_rate: the per-base mutation rate
:param int normalisation_intervals: the number of time intervals for which
to estimate a separate time rescaling parameter
:param bool match_segregating_sites: if True, match the total number of
mutations rather than the average number of differences from the ancestral
state
"""
if match_segregating_sites:
edge_weights = np.ones(ts.num_edges)
else:
has_parent = np.full(ts.num_nodes, False)
has_child = np.full(ts.num_nodes, False)
has_parent[ts.edges_child] = True
has_child[ts.edges_parent] = True
is_leaf = np.logical_and(~has_child, has_parent)
edge_weights = edge_sampling_weight(
is_leaf,
ts.edges_parent,
ts.edges_child,
ts.edges_left,
ts.edges_right,
ts.indexes_edge_insertion_order,
ts.indexes_edge_removal_order,
)
# estimate time rescaling parameter within intervals
samples = list(ts.samples())
if not np.all(ts.nodes_time[samples] == 0.0):
raise ValueError("Normalisation not implemented for ancient samples")
constraints = np.zeros((ts.num_nodes, 2))
constraints[:, 1] = np.inf
constraints[samples, :] = ts.nodes_time[samples, np.newaxis]
mutations_span, mutations_edge = mutation_span_array(ts)
mutations_span[:, 1] *= mutation_rate
original_breaks, rescaled_breaks = mutational_timescale(
ts.nodes_time,
mutations_span,
constraints,
ts.edges_parent,
ts.edges_child,
edge_weights,
normalisation_intervals,
)
# rescale node time
assert np.all(np.diff(rescaled_breaks) > 0)
assert np.all(np.diff(original_breaks) > 0)
scalings = np.append(np.diff(rescaled_breaks) / np.diff(original_breaks), 0)
idx = np.searchsorted(original_breaks, ts.nodes_time, "right") - 1
nodes_time = rescaled_breaks[idx] + scalings[idx] * (
ts.nodes_time - original_breaks[idx]
)
# calculate mutation time
mutations_parent = ts.edges_parent[mutations_edge]
mutations_child = ts.edges_child[mutations_edge]
mutations_time = (nodes_time[mutations_parent] + nodes_time[mutations_child]) / 2
above_root = mutations_edge == tskit.NULL
mutations_time[above_root] = nodes_time[mutations_child[above_root]]
tables = ts.dump_tables()
tables.nodes.time = nodes_time
tables.mutations.time = mutations_time
return tables.tree_sequence()
# TODO: standalone API for rescaling
# def rescale_tree_sequence(
# ts, mutation_rate, *, rescaling_intervals=1000, match_segregating_sites=False
# ):
# """
# Adjust the time scaling of a tree sequence so that expected mutational area
# matches the expected number of mutations on a path from leaf to root, where
# the expectation is taken over all paths and bases in the sequence.
#
# :param tskit.TreeSequence ts: the tree sequence to rescale
# :param float mutation_rate: the per-base mutation rate
# :param int rescaling_intervals: the number of time intervals for which
# to estimate a separate time rescaling parameter
# :param bool match_segregating_sites: if True, match the total number of
# mutations rather than the average number of differences from the ancestral
# state
# """
# if match_segregating_sites:
# edge_weights = np.ones(ts.num_edges)
# else:
# has_parent = np.full(ts.num_nodes, False)
# has_child = np.full(ts.num_nodes, False)
# has_parent[ts.edges_child] = True
# has_child[ts.edges_parent] = True
# is_leaf = np.logical_and(~has_child, has_parent)
# edge_weights = edge_sampling_weight(
# is_leaf,
# ts.edges_parent,
# ts.edges_child,
# ts.edges_left,
# ts.edges_right,
# ts.indexes_edge_insertion_order,
# ts.indexes_edge_removal_order,
# )
# # estimate time rescaling parameter within intervals
# samples = list(ts.samples())
# if not np.all(ts.nodes_time[samples] == 0.0):
# raise ValueError("Normalisation not implemented for ancient samples")
# constraints = np.zeros((ts.num_nodes, 2))
# constraints[:, 1] = np.inf
# constraints[samples, :] = ts.nodes_time[samples, np.newaxis]
# mutations_span, mutations_edge = mutation_span_array(ts)
# mutations_span[:, 1] *= mutation_rate
# original_breaks, rescaled_breaks = mutational_timescale(
# ts.nodes_time,
# mutations_span,
# constraints,
# ts.edges_parent,
# ts.edges_child,
# edge_weights,
# rescaling_intervals,
# )
# # rescale node time
# assert np.all(np.diff(rescaled_breaks) > 0)
# assert np.all(np.diff(original_breaks) > 0)
# scalings = np.append(np.diff(rescaled_breaks) / np.diff(original_breaks), 0)
# idx = np.searchsorted(original_breaks, ts.nodes_time, "right") - 1
# nodes_time = rescaled_breaks[idx] + scalings[idx] * (
# ts.nodes_time - original_breaks[idx]
# )
# # calculate mutation time
# mutations_parent = ts.edges_parent[mutations_edge]
# mutations_child = ts.edges_child[mutations_edge]
# mutations_time = (nodes_time[mutations_parent] + nodes_time[mutations_child]) / 2
# above_root = mutations_edge == tskit.NULL
# mutations_time[above_root] = nodes_time[mutations_child[above_root]]
# tables = ts.dump_tables()
# tables.nodes.time = nodes_time
# tables.mutations.time = mutations_time
# return tables.tree_sequence()
32 changes: 17 additions & 15 deletions tsdate/variational.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,9 +45,9 @@
from .approx import _i
from .approx import _i1r
from .hypergeo import _gammainc_inv as gammainc_inv
from .normalisation import edge_sampling_weight
from .normalisation import mutational_timescale
from .normalisation import piecewise_scale_posterior
from .rescaling import edge_sampling_weight
from .rescaling import mutational_timescale
from .rescaling import piecewise_scale_posterior


# columns for edge_factors
Expand Down Expand Up @@ -606,17 +606,17 @@ def iterate(

return np.nan # TODO: placeholder for marginal likelihood

def normalise(
def rescale(
self,
*,
norm_intervals=1000,
norm_segsites=False,
rescale_intervals=1000,
rescale_segsites=False,
use_median=False,
quantile_width=0.5,
):
"""Normalise posteriors so that empirical mutation rate is constant"""
edge_weights = (
np.ones(self.edge_weights.size) if norm_segsites else self.edge_weights
np.ones(self.edge_weights.size) if rescale_segsites else self.edge_weights
)
nodes_time = self._point_estimate(self.posterior, self.constraints, use_median)
original_breaks, rescaled_breaks = mutational_timescale(
Expand All @@ -626,7 +626,7 @@ def normalise(
self.parents,
self.children,
edge_weights,
norm_intervals,
rescale_intervals,
)
self.posterior[:] = piecewise_scale_posterior(
self.posterior,
Expand All @@ -650,8 +650,8 @@ def run(
max_shape=1000,
min_step=0.1,
min_kl=False,
norm_intervals=1000,
norm_segsites=False,
rescale_intervals=1000,
rescale_segsites=False,
regularise=True,
progress=None,
):
Expand Down Expand Up @@ -691,8 +691,10 @@ def run(
logging.info(f"Skipped {skipped_muts} mutations with invalid posteriors")
logging.info(f"Calculated mutation posteriors in {abs(muts_timing)} seconds")

if norm_intervals > 0:
norm_timing = time.time()
self.normalise(norm_intervals=norm_intervals, norm_segsites=norm_segsites)
norm_timing -= time.time()
logging.info(f"Timescale normalised in {abs(norm_timing)} seconds")
if rescale_intervals > 0:
rescale_timing = time.time()
self.rescale(
rescale_intervals=rescale_intervals, rescale_segsites=rescale_segsites
)
rescale_timing -= time.time()
logging.info(f"Timescale rescaled in {abs(rescale_timing)} seconds")

0 comments on commit a9e1f83

Please sign in to comment.