From bac9a135b0b6e583ff63dd983d45fc13081d1b09 Mon Sep 17 00:00:00 2001 From: Ben Jeffery Date: Thu, 6 Jun 2024 14:59:30 +0100 Subject: [PATCH] Switch to ruff linting --- .pre-commit-config.yaml | 25 +--- tests/distribution_functions.py | 30 ++--- tests/test_accuracy.py | 17 ++- tests/test_approximations.py | 22 ++-- tests/test_cache.py | 1 + tests/test_cli.py | 21 ++-- tests/test_evaluation.py | 1 + tests/test_functions.py | 195 +++++++++++--------------------- tests/test_hypergeo.py | 1 + tests/test_inference.py | 21 ++-- tests/test_priors.py | 13 ++- tests/test_provenance.py | 5 +- tests/test_util.py | 5 +- tests/utility_functions.py | 8 +- tsdate/__init__.py | 22 ++-- tsdate/approx.py | 16 +-- tsdate/base.py | 19 ++-- tsdate/cache.py | 2 +- tsdate/cli.py | 2 + tsdate/core.py | 42 ++----- tsdate/demography.py | 1 + tsdate/evaluation.py | 20 ++-- tsdate/hypergeo.py | 14 +-- tsdate/prior.py | 73 +++++------- tsdate/provenance.py | 1 + tsdate/rescaling.py | 45 ++++---- tsdate/schemas.py | 1 + tsdate/util.py | 29 ++--- tsdate/variational.py | 26 ++--- 29 files changed, 272 insertions(+), 406 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 36b69046..2dbdc5af 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -7,23 +7,10 @@ repos: - id: mixed-line-ending - id: check-case-conflict - id: check-yaml - - repo: https://github.com/asottile/reorder_python_imports - rev: v3.10.0 + - repo: https://github.com/astral-sh/ruff-pre-commit + rev: v0.4.5 hooks: - - id: reorder-python-imports - - repo: https://github.com/asottile/pyupgrade - rev: v3.10.1 - hooks: - - id: pyupgrade - args: [--py3-plus, --py38-plus] - - repo: https://github.com/psf/black - rev: 23.7.0 - hooks: - - id: black - language_version: python3 - - repo: https://github.com/pycqa/flake8 - rev: 6.1.0 - hooks: - - id: flake8 - args: [--config=.flake8] - additional_dependencies: ["flake8-bugbear==23.7.10", "flake8-builtins==2.1.0"] + - id: ruff + args: [ "--fix", "--config", "ruff.toml" ] + - id: ruff-format + args: [ "--config", "ruff.toml" ] \ No newline at end of file diff --git a/tests/distribution_functions.py b/tests/distribution_functions.py index 7c02784d..6eb2d567 100644 --- a/tests/distribution_functions.py +++ b/tests/distribution_functions.py @@ -24,13 +24,13 @@ Utility functions to construct distributions used in variational inference, for testing purposes """ + import mpmath import numpy as np import scipy.integrate import scipy.special -from tsdate import approx -from tsdate import hypergeo +from tsdate import approx, hypergeo def kl_divergence(p, logq): @@ -81,9 +81,7 @@ def pr_a(a, n, k): if n == k: return pr_t_bar_a(t, 1) else: - return np.sum( - [pr_a(a, n, k) * pr_t_bar_a(t, a, n) for a in range(2, n - k + 2)] - ) + return np.sum([pr_a(a, n, k) * pr_t_bar_a(t, a, n) for a in range(2, n - k + 2)]) class TiltedGammaDiff: @@ -114,8 +112,12 @@ def _U(a, b, z): return float(val) def __init__(self, shape1, shape2, shape3, rate1, rate2, rate3, reorder=True): - assert shape1 > 0 and shape2 > 0 and shape3 > 0 - assert rate1 >= 0 and rate2 > 0 and rate3 >= 0 + assert shape1 > 0 + assert shape2 > 0 + assert shape3 > 0 + assert rate1 >= 0 + assert rate2 > 0 + assert rate3 >= 0 # for convergence of 2F1, we need rate2 > rate3. Invariant # transformations of 2F1 allow us to switch arguments, with # appropriate rescaling @@ -369,8 +371,12 @@ def _M(a, b, x): return float(val) def __init__(self, shape1, shape2, shape3, rate1, rate2, rate3): - assert shape1 > 0 and shape2 > 0 and shape3 > 0 - assert rate1 >= 0 and rate2 > 0 and rate3 >= 0 + assert shape1 > 0 + assert shape2 > 0 + assert shape3 > 0 + assert rate1 >= 0 + assert rate2 > 0 + assert rate3 >= 0 # for numeric stability of hypergeometric we need rate2 > rate1 # as this is a convolution, the order of (1) and (2) don't matter self.reparametrize = rate1 > rate2 @@ -481,11 +487,7 @@ def sufficient_statistics(self): + scipy.special.betaln(self.shape1, self.shape2) ) x = dF_dz * T / S**2 + B / S - xsq = ( - d2F_dz2 * T**2 / S**4 - + B * (B + 1) / S**2 - + 2 * dF_dz * (1 + B) * T / S**3 - ) + xsq = d2F_dz2 * T**2 / S**4 + B * (B + 1) / S**2 + 2 * dF_dz * (1 + B) * T / S**3 logx = dF_db + scipy.special.digamma(B) - np.log(S) return logconst, x, xsq, logx diff --git a/tests/test_accuracy.py b/tests/test_accuracy.py index b4888643..8e06ad3d 100644 --- a/tests/test_accuracy.py +++ b/tests/test_accuracy.py @@ -23,6 +23,7 @@ """ Test cases for tsdate accuracy. """ + import json import os @@ -40,7 +41,7 @@ class TestAccuracy: Test for some of the basic functions used in tsdate """ - @pytest.mark.makefiles + @pytest.mark.makefiles() def test_make_static_files(self, request): """ The function used to create the tree sequences for accuracy testing. @@ -75,7 +76,13 @@ def test_make_static_files(self, request): ts.dump(os.path.join(request.fspath.dirname, "data", f"{name}.trees")) @pytest.mark.parametrize( - "ts_name,min_r2_ts,min_r2_unconstrained,min_spear_ts,min_spear_unconstrained", + ( + "ts_name", + "min_r2_ts", + "min_r2_unconstrained", + "min_spear_ts", + "min_spear_unconstrained", + ), [ ("one_tree", 0.98601, 0.98601, 0.97719, 0.97719), ("few_trees", 0.98220, 0.98220, 0.97744, 0.97744), @@ -91,9 +98,7 @@ def test_basic( min_spear_unconstrained, request, ): - ts = tskit.load( - os.path.join(request.fspath.dirname, "data", ts_name + ".trees") - ) + ts = tskit.load(os.path.join(request.fspath.dirname, "data", ts_name + ".trees")) sim_ancestry_parameters = json.loads(ts.provenance(0).record)["parameters"] assert sim_ancestry_parameters["command"] == "sim_ancestry" @@ -144,7 +149,7 @@ def test_scaling(self, Ne): assert 0.9 < dts.node(dts.first().root).time / (2 * Ne) < 1.1 @pytest.mark.parametrize( - "bkwd_rate, trio_tmrca", + ("bkwd_rate", "trio_tmrca"), [ # calculated from simulations (-1.0, 0.76), (-0.9, 0.79), diff --git a/tests/test_approximations.py b/tests/test_approximations.py index fb41f843..dac8d508 100644 --- a/tests/test_approximations.py +++ b/tests/test_approximations.py @@ -24,17 +24,15 @@ """ Test cases for the gamma-variational approximations in tsdate """ + import numpy as np import pytest import scipy.integrate import scipy.special import scipy.stats -from distribution_functions import conditional_coalescent_pdf -from distribution_functions import kl_divergence +from distribution_functions import conditional_coalescent_pdf, kl_divergence -from tsdate import approx -from tsdate import hypergeo -from tsdate import prior +from tsdate import approx, hypergeo, prior # TODO: better test set? # TODO: test special case where child is fixed to age 0 @@ -273,9 +271,7 @@ def test_truncated_moments(self, pars): assert np.isclose(t_i, ck_t_i, rtol=1e-4) ck_var_t_i = ( scipy.integrate.quad( - lambda t_i: t_i**2 - * self.pdf_truncated(t_i, *pars_redux) - / ck_normconst, + lambda t_i: t_i**2 * self.pdf_truncated(t_i, *pars_redux) / ck_normconst, low, upp, epsabs=0, @@ -389,9 +385,7 @@ def test_average_gammas(self): E_x = np.mean(shape + 1) E_logx = np.mean(scipy.special.digamma(shape + 1)) assert np.isclose(E_x, (avg_shape + 1) / avg_rate) - assert np.isclose( - E_logx, scipy.special.digamma(avg_shape + 1) - np.log(avg_rate) - ) + assert np.isclose(E_logx, scipy.special.digamma(avg_shape + 1) - np.log(avg_rate)) class TestKLMinimizationFailed: @@ -409,10 +403,12 @@ def test_asymptotic_bound(self): alpha, _ = approx.approximate_gamma_kl(1, logx) alpha += 1 alpha_bound = -0.5 / logx - assert alpha == alpha_bound and alpha > 1e4 + assert alpha == alpha_bound + assert alpha > 1e4 # check that bound matches optimization result just under threshold logx = -0.000051 alpha, _ = approx.approximate_gamma_kl(1, logx) alpha += 1 alpha_bound = -0.5 / logx - assert np.abs(alpha - alpha_bound) < 1 and alpha < 1e4 + assert np.abs(alpha - alpha_bound) < 1 + assert alpha < 1e4 diff --git a/tests/test_cache.py b/tests/test_cache.py index ed6e28c3..45766902 100644 --- a/tests/test_cache.py +++ b/tests/test_cache.py @@ -1,6 +1,7 @@ """ Tests for the cache management code. """ + import os import pathlib import unittest diff --git a/tests/test_cli.py b/tests/test_cli.py index 16827409..2d1d5d5e 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -23,6 +23,7 @@ """ Test cases for the command line interface for tsdate. """ + import json import logging from unittest import mock @@ -75,11 +76,11 @@ def test_recombination_rate(self): parser = cli.tsdate_cli_parser() params = ["-m", "1e10"] args = parser.parse_args( - ["date", self.infile, self.output] + params + ["-r", "1e-100"] + ["date", self.infile, self.output, *params, "-r", "1e-100"] ) assert args.recombination_rate == 1e-100 args = parser.parse_args( - ["date", self.infile, self.output] + params + ["--recombination-rate", "73"] + ["date", self.infile, self.output, *params, "--recombination-rate", "73"] ) assert args.recombination_rate == 73 @@ -97,24 +98,22 @@ def test_epsilon(self): def test_num_threads(self): parser = cli.tsdate_cli_parser() params = ["--method", "maximization", "--num-threads"] - args = parser.parse_args(["date", self.infile, self.output] + params + ["1"]) + args = parser.parse_args(["date", self.infile, self.output, *params, "1"]) assert args.num_threads == 1 - args = parser.parse_args(["date", self.infile, self.output] + params + ["2"]) + args = parser.parse_args(["date", self.infile, self.output, *params, "2"]) assert args.num_threads == 2 def test_probability_space(self): parser = cli.tsdate_cli_parser() params = ["--method", "inside_outside", "--probability-space"] - args = parser.parse_args( - ["date", self.infile, self.output] + params + ["linear"] - ) + args = parser.parse_args(["date", self.infile, self.output, *params, "linear"]) assert args.probability_space == "linear" args = parser.parse_args( - ["date", self.infile, self.output] + params + ["logarithmic"] + ["date", self.infile, self.output, *params, "logarithmic"] ) assert args.probability_space == "logarithmic" - @pytest.mark.parametrize("flag, log_status", logging_flags.items()) + @pytest.mark.parametrize(("flag", "log_status"), logging_flags.items()) def test_verbosity(self, flag, log_status): parser = cli.tsdate_cli_parser() args = parser.parse_args(["preprocess", self.infile, self.output, flag]) @@ -130,7 +129,7 @@ def test_method(self, method): params = ["-m", "1e-8", "--method", method] if method != "variational_gamma": params += ["-n", "10"] - args = parser.parse_args(["date", self.infile, self.output] + params) + args = parser.parse_args(["date", self.infile, self.output, *params]) assert args.method == method def test_progress(self): @@ -231,7 +230,7 @@ def test_no_output_variational_gamma(self, tmp_path, capfd): assert out == "" assert err == "" - @pytest.mark.parametrize("flag, log_status", logging_flags.items()) + @pytest.mark.parametrize(("flag", "log_status"), logging_flags.items()) def test_verbosity(self, tmp_path, caplog, flag, log_status): popsize = 10000 ts = msprime.simulate( diff --git a/tests/test_evaluation.py b/tests/test_evaluation.py index 10f7d9a0..857c0717 100644 --- a/tests/test_evaluation.py +++ b/tests/test_evaluation.py @@ -24,6 +24,7 @@ """ Test tools for mapping between node sets of different tree sequences """ + from collections import defaultdict from itertools import combinations diff --git a/tests/test_functions.py b/tests/test_functions.py index 338b737b..e68bddca 100644 --- a/tests/test_functions.py +++ b/tests/test_functions.py @@ -24,6 +24,7 @@ """ Test cases for the python API for tsdate. """ + import collections import logging import unittest @@ -38,20 +39,23 @@ import tsdate from tsdate import base -from tsdate.core import DiscreteTimeMethod -from tsdate.core import InOutAlgorithms -from tsdate.core import InsideOutsideMethod -from tsdate.core import Likelihoods -from tsdate.core import LogLikelihoods +from tsdate.core import ( + DiscreteTimeMethod, + InOutAlgorithms, + InsideOutsideMethod, + Likelihoods, + LogLikelihoods, +) from tsdate.demography import PopulationSizeHistory -from tsdate.prior import ConditionalCoalescentTimes -from tsdate.prior import fill_priors -from tsdate.prior import gamma_approx -from tsdate.prior import MixturePrior -from tsdate.prior import PriorParams -from tsdate.prior import SpansBySamples -from tsdate.util import constrain_ages -from tsdate.util import nodes_time_unconstrained +from tsdate.prior import ( + ConditionalCoalescentTimes, + MixturePrior, + PriorParams, + SpansBySamples, + fill_priors, + gamma_approx, +) +from tsdate.util import constrain_ages, nodes_time_unconstrained class TestBasicFunctions: @@ -183,9 +187,7 @@ def test_polytomy_tree(self): assert span_data.lookup_span(3, ts.num_samples, 3) == 1.0 def test_larger_find_node_tip_spans(self): - ts = msprime.simulate( - 10, recombination_rate=5, mutation_rate=5, random_seed=123 - ) + ts = msprime.simulate(10, recombination_rate=5, mutation_rate=5, random_seed=123) assert ts.num_trees > 1 self.verify_spans(ts) @@ -253,9 +255,7 @@ def verify_priors(self, ts, prior_distr): def test_one_tree_n2(self): ts = utility_functions.single_tree_ts_n2() priors = self.verify_priors(ts, "gamma") - assert np.allclose( - priors[2], PriorParams(alpha=1.0, beta=1.0, mean=1.0, var=1.0) - ) + assert np.allclose(priors[2], PriorParams(alpha=1.0, beta=1.0, mean=1.0, var=1.0)) priors = self.verify_priors(ts, "lognorm") assert np.allclose( priors[2], @@ -285,26 +285,14 @@ def test_one_tree_n4(self): prior4mv = {"mean": np.nan, "var": np.nan} priors = self.verify_priors(ts, "lognorm") - assert np.allclose( - priors[2], PriorParams(alpha=np.nan, beta=np.nan, **prior2mv) - ) - assert np.allclose( - priors[3], PriorParams(alpha=np.nan, beta=np.nan, **prior3mv) - ) - assert np.allclose( - priors[4], PriorParams(alpha=np.nan, beta=np.nan, **prior4mv) - ) + assert np.allclose(priors[2], PriorParams(alpha=np.nan, beta=np.nan, **prior2mv)) + assert np.allclose(priors[3], PriorParams(alpha=np.nan, beta=np.nan, **prior3mv)) + assert np.allclose(priors[4], PriorParams(alpha=np.nan, beta=np.nan, **prior4mv)) priors = self.verify_priors(ts, "gamma") - assert np.allclose( - priors[2], PriorParams(alpha=np.nan, beta=np.nan, **prior2mv) - ) - assert np.allclose( - priors[3], PriorParams(alpha=np.nan, beta=np.nan, **prior3mv) - ) - assert np.allclose( - priors[4], PriorParams(alpha=np.nan, beta=np.nan, **prior4mv) - ) + assert np.allclose(priors[2], PriorParams(alpha=np.nan, beta=np.nan, **prior2mv)) + assert np.allclose(priors[3], PriorParams(alpha=np.nan, beta=np.nan, **prior3mv)) + assert np.allclose(priors[4], PriorParams(alpha=np.nan, beta=np.nan, **prior4mv)) @pytest.mark.skip("Fill in values instead of np.nan") def test_polytomy_tree(self): @@ -312,14 +300,10 @@ def test_polytomy_tree(self): prior3mv = {"mean": np.nan, "var": np.nan} priors = self.verify_priors(ts, "lognorm") - assert np.allclose( - priors[3], PriorParams(alpha=np.nan, beta=np.nan, **prior3mv) - ) + assert np.allclose(priors[3], PriorParams(alpha=np.nan, beta=np.nan, **prior3mv)) priors = self.verify_prior(ts, "gamma") - assert np.allclose( - priors[3], PriorParams(alpha=np.nan, beta=np.nan, **prior3mv) - ) + assert np.allclose(priors[3], PriorParams(alpha=np.nan, beta=np.nan, **prior3mv)) @pytest.mark.skip("Fill in values instead of np.nan") def test_two_tree_ts(self): @@ -328,20 +312,12 @@ def test_two_tree_ts(self): prior3mv = {"mean": np.nan, "var": np.nan} priors = self.verify_priors(ts, "lognorm") - assert np.allclose( - priors[2], PriorParams(alpha=np.nan, beta=np.nan, **prior2mv) - ) - assert np.allclose( - priors[3], PriorParams(alpha=np.nan, beta=np.nan, **prior3mv) - ) + assert np.allclose(priors[2], PriorParams(alpha=np.nan, beta=np.nan, **prior2mv)) + assert np.allclose(priors[3], PriorParams(alpha=np.nan, beta=np.nan, **prior3mv)) priors = self.verify_priors(ts, "gamma") - assert np.allclose( - priors[2], PriorParams(alpha=np.nan, beta=np.nan, **prior2mv) - ) - assert np.allclose( - priors[3], PriorParams(alpha=np.nan, beta=np.nan, **prior3mv) - ) + assert np.allclose(priors[2], PriorParams(alpha=np.nan, beta=np.nan, **prior2mv)) + assert np.allclose(priors[3], PriorParams(alpha=np.nan, beta=np.nan, **prior3mv)) @pytest.mark.skip("Fill in values instead of np.nan") def test_single_tree_ts_with_unary(self): @@ -350,12 +326,8 @@ def test_single_tree_ts_with_unary(self): prior3mv = {"mean": np.nan, "var": np.nan} priors = self.verify_priors(ts, "lognorm") - assert np.allclose( - priors[2], PriorParams(alpha=np.nan, beta=np.nan, **prior2mv) - ) - assert np.allclose( - priors[3], PriorParams(alpha=np.nan, beta=np.nan, **prior3mv) - ) + assert np.allclose(priors[2], PriorParams(alpha=np.nan, beta=np.nan, **prior2mv)) + assert np.allclose(priors[3], PriorParams(alpha=np.nan, beta=np.nan, **prior3mv)) priors = self.verify_priors(ts, "gamma") assert np.allclose(priors[2], PriorParams(alpha=1.0, beta=3.0, **prior2mv)) @@ -368,12 +340,8 @@ def test_two_tree_mutation_ts(self): prior3mv = {"mean": np.nan, "var": np.nan} priors = self.verify_priors(ts, "lognorm") - assert np.allclose( - priors[2], PriorParams(alpha=np.nan, beta=np.nan, **prior2mv) - ) - assert np.allclose( - priors[3], PriorParams(alpha=np.nan, beta=np.nan, **prior3mv) - ) + assert np.allclose(priors[2], PriorParams(alpha=np.nan, beta=np.nan, **prior2mv)) + assert np.allclose(priors[3], PriorParams(alpha=np.nan, beta=np.nan, **prior3mv)) priors = self.verify_priors(ts, "gamma") assert np.allclose(priors[2], PriorParams(alpha=1.0, beta=3.0, **prior2mv)) @@ -381,7 +349,7 @@ def test_two_tree_mutation_ts(self): class TestMixturePrior: - alpha_beta = [PriorParams.field_index("alpha"), PriorParams.field_index("beta")] + alpha_beta = (PriorParams.field_index("alpha"), PriorParams.field_index("beta")) def get_mixture_prior_params(self, ts, prior_distr, **kwargs): span_data = SpansBySamples(ts, **kwargs) @@ -395,9 +363,7 @@ def test_one_tree_n2(self): mixture_priors = self.get_mixture_prior_params(ts, "gamma") assert np.allclose(mixture_priors[2, self.alpha_beta], [1.0, 1.0]) mixture_priors = self.get_mixture_prior_params(ts, "lognorm") - assert np.allclose( - mixture_priors[2, self.alpha_beta], [-0.34657359, 0.69314718] - ) + assert np.allclose(mixture_priors[2, self.alpha_beta], [-0.34657359, 0.69314718]) def test_one_tree_n3(self): ts = utility_functions.single_tree_ts_n3() @@ -405,9 +371,7 @@ def test_one_tree_n3(self): assert np.allclose(mixture_priors[3, self.alpha_beta], [1.0, 3.0]) assert np.allclose(mixture_priors[4, self.alpha_beta], [1.6, 1.2]) mixture_priors = self.get_mixture_prior_params(ts, "lognorm") - assert np.allclose( - mixture_priors[3, self.alpha_beta], [-1.44518588, 0.69314718] - ) + assert np.allclose(mixture_priors[3, self.alpha_beta], [-1.44518588, 0.69314718]) assert np.allclose(mixture_priors[4, self.alpha_beta], [0.04492816, 0.48550782]) def test_one_tree_n4(self): @@ -438,9 +402,7 @@ def test_single_tree_ts_disallow_unary(self): def test_single_tree_ts_with_unary(self, caplog): ts = utility_functions.single_tree_ts_with_unary() with caplog.at_level(logging.WARNING): - mixture_priors = self.get_mixture_prior_params( - ts, "gamma", allow_unary=True - ) + mixture_priors = self.get_mixture_prior_params(ts, "gamma", allow_unary=True) assert "tsdate may give poor results" in caplog.text # Root is a 3 tip prior assert np.allclose(mixture_priors[7, self.alpha_beta], [1.6, 1.2]) @@ -669,9 +631,7 @@ def test_precalc_lik_upper_multithread(self): lik.precalculate_mutation_likelihoods(num_threads=num_threads) for edge in ts.edges(): if not ts.node(edge.child).is_sample(): - n_internal_edges += ( - 1 # only two internal edges in this tree - ) + n_internal_edges += 1 # only two internal edges in this tree assert n_internal_edges <= 2 if edge.parent == 4 and edge.child == 3: num_muts = 2 @@ -738,7 +698,8 @@ def test_logsumexp(self): def test_zeros_logsumexp(self): with np.errstate(divide="ignore"): - lls = np.log(np.concatenate([np.zeros(100), np.random.rand(1000)])) + rng = np.random.default_rng() + lls = np.log(np.concatenate([np.zeros(100), rng.random(1000)])) assert np.allclose(LogLikelihoods.logsumexp(lls), self.naive_logsumexp(lls)) def test_logsumexp_underflow(self): @@ -804,13 +765,13 @@ def test_set_and_get(self): num_nodes = 5 grid_size = 2 fill = {} + rng = np.random.default_rng(1) for ids in ([3, 4], []): - np.random.seed(1) store = base.NodeGridValues( num_nodes, np.array(ids, dtype=np.int32), np.array(range(grid_size)) ) for i in range(num_nodes): - fill[i] = np.random.random(grid_size if i in ids else None) + fill[i] = rng.random(grid_size if i in ids else None) store[i] = fill[i] for i in range(num_nodes): assert np.all(fill[i] == store[i]) @@ -857,9 +818,7 @@ def test_clone(self): assert np.all(clone.grid_data == 0) assert np.all(clone.fixed_data == scalars) - clone = base.NodeGridValues.clone_with_new_data( - orig, np.array([[1, 2], [4, 3]]) - ) + clone = base.NodeGridValues.clone_with_new_data(orig, np.array([[1, 2], [4, 3]])) for i in range(num_nodes): if i in ids: assert np.all(clone[i] == orig[i]) @@ -1022,14 +981,11 @@ def test_polytomy_tree(self): def test_two_tree_ts(self): ts = utility_functions.two_tree_ts() - algo, priors, marg_lik = self.run_inside_algorithm( - ts, "gamma", standardize=False - ) + algo, priors, marg_lik = self.run_inside_algorithm(ts, "gamma", standardize=False) mut_rate = 0.5 # priors[3][1] * Ll_(0->3)(1.2 - 0 + eps) ** 2 node3_t1 = ( - priors[3][1] - * scipy.stats.poisson.pmf(0, (1.2 + 1e-6) * mut_rate * 0.2) ** 2 + priors[3][1] * scipy.stats.poisson.pmf(0, (1.2 + 1e-6) * mut_rate * 0.2) ** 2 ) # priors[3][2] * sum(Ll_(0->3)(2 - t + eps)) node3_t2 = ( @@ -1054,10 +1010,7 @@ def test_two_tree_ts(self): * scipy.stats.poisson.pmf(0, (2 + 1e-6) * mut_rate * 0.8) * ( (scipy.stats.poisson.pmf(0, (0.8 + 1e-6) * mut_rate * 0.2) * node3_t1) - + ( - scipy.stats.poisson.pmf(0, (1e-6 + 1e-6) * mut_rate * 0.2) - * node3_t2 - ) + + (scipy.stats.poisson.pmf(0, (1e-6 + 1e-6) * mut_rate * 0.2) * node3_t2) ) ) assert np.allclose(algo.inside[4], np.array([0, node4_t1, node4_t2])) @@ -1166,9 +1119,7 @@ def run_outside_algorithm( lls.precalculate_mutation_likelihoods() algo = InOutAlgorithms(prior_vals, lls) algo.inside_pass() - algo.outside_pass( - standardize=standardize, ignore_oldest_root=ignore_oldest_root - ) + algo.outside_pass(standardize=standardize, ignore_oldest_root=ignore_oldest_root) return algo def test_one_tree_n2(self): @@ -1326,9 +1277,7 @@ def test_gil_tree(self): priors.add(ts.num_samples, approximate=False) grid = np.array([0, 0.1, 0.2, 0.5, 1, 2, 5]) mixture_prior = priors.get_mixture_prior_params(span_data) - prior_vals = fill_priors( - mixture_prior, grid, ts, Ne, prior_distr=prior_distr - ) + prior_vals = fill_priors(mixture_prior, grid, ts, Ne, prior_distr=prior_distr) prior_vals.grid_data[0] = [0, 0.5, 0.3, 0.1, 0.05, 0.02, 0.03] prior_vals.grid_data[1] = [0, 0.05, 0.1, 0.2, 0.45, 0.1, 0.1] mut_rate = 1 @@ -1468,9 +1417,7 @@ def test_one_tree_n3(self): node_4 = lls.timepoints[np.argmax(algo.inside[4])] ll_mut = scipy.stats.poisson.pmf( 0, - (node_4 - lls.timepoints[: np.argmax(algo.inside[4]) + 1] + 1e-6) - * 1 - * 1, + (node_4 - lls.timepoints[: np.argmax(algo.inside[4]) + 1] + 1e-6) * 1 * 1, ) result = ll_mut / np.max(ll_mut) inside_val = algo.inside[3][: (np.argmax(algo.inside[4]) + 1)] @@ -1539,8 +1486,8 @@ def test_recombination_not_implemented(self): def test_Ne_and_priors(self): ts = utility_functions.single_tree_ts_n2() + priors = tsdate.build_prior_grid(ts, population_size=1) with pytest.raises(ValueError): - priors = tsdate.build_prior_grid(ts, population_size=1) tsdate.inside_outside( ts, mutation_rate=None, population_size=1, priors=priors ) @@ -1714,7 +1661,8 @@ def test_constrain_ages_backcompat(self): ts = msprime.sim_mutations(ts, rate=1e-8, random_seed=1) sample_data = tsinfer.SampleData.from_tree_sequence(ts) inf_ts = tsinfer.infer(sample_data).simplify() - noise = np.random.uniform(0, 0.1, size=inf_ts.num_nodes) + rng = np.random.default_rng() + noise = rng.uniform(0, 0.1, size=inf_ts.num_nodes) nodes_time = inf_ts.nodes_time + noise eps = 1e-6 blen = nodes_time[inf_ts.edges_parent] - nodes_time[inf_ts.edges_child] @@ -1738,7 +1686,8 @@ def test_constrain_ages_leastsquare(self): ts = msprime.sim_mutations(ts, rate=1e-8, random_seed=1) sample_data = tsinfer.SampleData.from_tree_sequence(ts) inf_ts = tsinfer.infer(sample_data).simplify() - noise = np.random.uniform(0, 0.5, size=inf_ts.num_nodes) + rng = np.random.default_rng() + noise = rng.uniform(0, 0.5, size=inf_ts.num_nodes) nodes_time = inf_ts.nodes_time + noise eps = 1e-6 blen = nodes_time[inf_ts.edges_parent] - nodes_time[inf_ts.edges_child] @@ -1794,12 +1743,10 @@ def test_invariant_sites(self): assert removed.num_populations == ts.num_populations assert tsdate.preprocess_ts(ts, **{"filter_sites": True}).num_sites == 0 assert ( - tsdate.preprocess_ts(ts, **{"filter_populations": True}).num_populations - == 0 + tsdate.preprocess_ts(ts, **{"filter_populations": True}).num_populations == 0 ) assert ( - tsdate.preprocess_ts(ts, **{"filter_individuals": True}).num_individuals - == 0 + tsdate.preprocess_ts(ts, **{"filter_individuals": True}).num_individuals == 0 ) def test_no_intervals(self): @@ -1926,9 +1873,7 @@ def test_sites_time_node_selection(self): for tree in dated.trees(): for site in tree.sites(): for mut in site.mutations: - parent_sites_check[site.id] = dated_nodes_time[ - tree.parent(mut.node) - ] + parent_sites_check[site.id] = dated_nodes_time[tree.parent(mut.node)] assert np.array_equal(parent_sites_check, sites_time_parent) sites_time_arithmetic = tsdate.sites_time_from_ts( @@ -2033,9 +1978,7 @@ def test_wrong_number_of_sites(self): ts = utility_functions.single_tree_ts_2mutations_n3() sites_time = tsdate.sites_time_from_ts(ts, unconstrained=False) sites_time = np.append(sites_time, [10]) - samples = tsinfer.formats.SampleData.from_tree_sequence( - ts, use_sites_time=False - ) + samples = tsinfer.formats.SampleData.from_tree_sequence(ts, use_sites_time=False) with pytest.raises(ValueError): tsdate.add_sampledata_times(samples, sites_time) @@ -2051,9 +1994,9 @@ def test_historical_samples(self): length=1e4, random_seed=12, ) - ancient_samples = np.where(ts.tables.nodes.time[:][ts.samples()] != 0)[ - 0 - ].astype("int32") + ancient_samples = np.where(ts.tables.nodes.time[:][ts.samples()] != 0)[0].astype( + "int32" + ) ancient_samples_times = ts.tables.nodes.time[ancient_samples] samples = tsinfer.formats.SampleData.from_tree_sequence(ts) @@ -2062,13 +2005,9 @@ def test_historical_samples(self): sites_time = tsdate.sites_time_from_ts(dated) # Add in the original individual times ind_dated_sd = samples.copy() - ind_dated_sd.individuals_time[ - : - ] = tsinfer.formats.SampleData.from_tree_sequence( + ind_dated_sd.individuals_time[:] = tsinfer.formats.SampleData.from_tree_sequence( ts, use_individuals_time=True, use_sites_time=True - ).individuals_time[ - : - ] + ).individuals_time[:] ind_dated_sd.finalise() dated_samples = tsdate.add_sampledata_times(ind_dated_sd, sites_time) for variant in ts.variants(samples=ancient_samples): @@ -2088,9 +2027,7 @@ def test_sampledata(self): length=1e4, random_seed=12, ) - samples = tsinfer.formats.SampleData.from_tree_sequence( - ts, use_sites_time=False - ) + samples = tsinfer.formats.SampleData.from_tree_sequence(ts, use_sites_time=False) inferred = tsinfer.infer(samples).simplify() dated = tsdate.date(inferred, mutation_rate=1e-8) sites_time = tsdate.sites_time_from_ts(dated) diff --git a/tests/test_hypergeo.py b/tests/test_hypergeo.py index fe4586e8..d4ae9d5a 100644 --- a/tests/test_hypergeo.py +++ b/tests/test_hypergeo.py @@ -23,6 +23,7 @@ """ Test cases for numba-fied hypergeometric functions """ + import mpmath import numdifftools as nd import numpy as np diff --git a/tests/test_inference.py b/tests/test_inference.py index dbe068a9..cd53105e 100644 --- a/tests/test_inference.py +++ b/tests/test_inference.py @@ -23,6 +23,7 @@ """ Test cases for the python API for tsdate. """ + import logging import msprime @@ -33,11 +34,9 @@ import utility_functions import tsdate -from tsdate.base import LIN -from tsdate.base import LOG +from tsdate.base import LIN, LOG from tsdate.demography import PopulationSizeHistory -from tsdate.evaluation import remove_edges -from tsdate.evaluation import unsupported_edges +from tsdate.evaluation import remove_edges, unsupported_edges class TestConstants: @@ -68,9 +67,7 @@ def test_no_population_size(self): def test_no_mutation(self): ts = utility_functions.two_tree_mutation_ts() with pytest.raises(ValueError, match="method requires mutation rate"): - tsdate.date( - ts, method="maximization", population_size=1, mutation_rate=None - ) + tsdate.date(ts, method="maximization", population_size=1, mutation_rate=None) with pytest.raises(ValueError, match="method requires mutation rate"): tsdate.date(ts, method="variational_gamma", mutation_rate=None) @@ -78,9 +75,7 @@ def test_not_needed_population_size(self): ts = utility_functions.two_tree_mutation_ts() prior = tsdate.build_prior_grid(ts, population_size=1, timepoints=10) with pytest.raises(ValueError, match="Cannot specify population size"): - tsdate.inside_outside( - ts, population_size=1, mutation_rate=None, priors=prior - ) + tsdate.inside_outside(ts, population_size=1, mutation_rate=None, priors=prior) def test_bad_population_size(self): ts = utility_functions.two_tree_mutation_ts() @@ -120,8 +115,8 @@ def test_variational_gamma_unary_failure(self): with pytest.raises(ValueError, match="unary"): tsdate.variational_gamma(ts, mutation_rate=1) - @pytest.mark.parametrize("probability_space", (LOG, LIN)) - @pytest.mark.parametrize("mu", (None, 1)) + @pytest.mark.parametrize("probability_space", [LOG, LIN]) + @pytest.mark.parametrize("mu", [None, 1]) def test_fails_with_recombination(self, probability_space, mu): ts = utility_functions.two_tree_mutation_ts() with pytest.raises(NotImplementedError): @@ -400,7 +395,7 @@ class TestVariational: """ @pytest.fixture(autouse=True) - def ts(self): + def ts(self): # noqa PT004 ts = msprime.sim_ancestry( samples=10, recombination_rate=1e-8, diff --git a/tests/test_priors.py b/tests/test_priors.py index 07f24f3f..7082baaa 100644 --- a/tests/test_priors.py +++ b/tests/test_priors.py @@ -22,17 +22,20 @@ """ Test cases for prior functionality used in tsdate """ + import logging import numpy as np import pytest import utility_functions -from tsdate.prior import conditional_coalescent_variance -from tsdate.prior import ConditionalCoalescentTimes -from tsdate.prior import create_timepoints -from tsdate.prior import PriorParams -from tsdate.prior import SpansBySamples +from tsdate.prior import ( + ConditionalCoalescentTimes, + PriorParams, + SpansBySamples, + conditional_coalescent_variance, + create_timepoints, +) class TestConditionalCoalescentTimes: diff --git a/tests/test_provenance.py b/tests/test_provenance.py index 1c8c6238..b2b41b52 100644 --- a/tests/test_provenance.py +++ b/tests/test_provenance.py @@ -23,6 +23,7 @@ """ Test cases for saving provenances in tsdate. """ + import json import numpy as np @@ -106,9 +107,7 @@ def test_preprocess_defaults_recorded(self): def test_preprocess_interval_recorded(self): ts = utility_functions.ts_w_data_desert(40, 60, 100) num_provenances = ts.num_provenances - preprocessed_ts = tsdate.preprocess_ts( - ts, minimum_gap=20, remove_telomeres=False - ) + preprocessed_ts = tsdate.preprocess_ts(ts, minimum_gap=20, remove_telomeres=False) assert preprocessed_ts.num_provenances == num_provenances + 1 rec = json.loads(preprocessed_ts.provenance(-1).record) assert rec["parameters"]["minimum_gap"] == 20 diff --git a/tests/test_util.py b/tests/test_util.py index 6eb626fe..a69da1cc 100644 --- a/tests/test_util.py +++ b/tests/test_util.py @@ -22,6 +22,7 @@ """ Test cases for tsdate utility functions """ + import json import logging @@ -151,9 +152,7 @@ def test_metadata(self): tables.nodes.metadata_schema = tskit.MetadataSchema.permissive_json() tables.nodes.packset_metadata( [ - tables.nodes.metadata_schema.validate_and_encode_row( - {"xxx": f"test{x}"} - ) + tables.nodes.metadata_schema.validate_and_encode_row({"xxx": f"test{x}"}) for x in range(ts.num_nodes) ] ) diff --git a/tests/utility_functions.py b/tests/utility_functions.py index f605583e..7f0469e1 100644 --- a/tests/utility_functions.py +++ b/tests/utility_functions.py @@ -22,6 +22,7 @@ """ A collection of utilities to edit and construct tree sequences for testing purposes """ + import io import itertools @@ -1015,16 +1016,15 @@ def truncate_ts_samples(ts, average_span, random_seed, min_span=5): of the tree. """ - np.random.seed(random_seed) + rng = np.random.default_rng(random_seed) # Make a list of (left,right) tuples giving the new limits of each sample # Keyed by sample ID. # for simplicity, we pick lengths from a poisson distribution of av 300 bp - span = np.random.poisson(average_span, ts.num_samples) + span = rng.poisson(average_span, ts.num_samples) span = np.maximum(span, min_span) span = np.minimum(span, ts.sequence_length) - start = np.random.uniform(0, ts.sequence_length - span) + start = rng.uniform(0, ts.sequence_length - span) to_slice = {id_: (a, b) for id_, a, b in zip(ts.samples(), start, start + span)} - tables = ts.dump_tables() tables.edges.clear() for e in ts.tables.edges: diff --git a/tsdate/__init__.py b/tsdate/__init__.py index bc267f5c..4760b92d 100644 --- a/tsdate/__init__.py +++ b/tsdate/__init__.py @@ -19,18 +19,22 @@ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. -from .cache import * # NOQA: F401,F403 -from .core import date # NOQA: F401 -from .core import estimation_methods # NOQA: F401 -from .core import inside_outside # NOQA: F401 -from .core import maximization # NOQA: F401 -from .core import variational_gamma # NOQA: F401 +from .cache import * # noqa: F403 +from .core import ( + date, # NOQA: F401 + estimation_methods, # NOQA: F401 + inside_outside, # NOQA: F401 + maximization, # NOQA: F401 + variational_gamma, # NOQA: F401 +) from .prior import parameter_grid as build_parameter_grid # NOQA: F401 from .prior import prior_grid as build_prior_grid # NOQA: F401 from .provenance import __version__ # NOQA: F401 -from .util import add_sampledata_times # NOQA: F401 -from .util import preprocess_ts # NOQA: F401 -from .util import sites_time_from_ts # NOQA: F401 +from .util import ( + add_sampledata_times, # NOQA: F401 + preprocess_ts, # NOQA: F401 + sites_time_from_ts, # NOQA: F401 +) # Bit 20 is set in node flags when they are samples not at time zero in the sampledata # file. This should match the node flag in tsinfer. diff --git a/tsdate/approx.py b/tsdate/approx.py index cb832217..92952017 100644 --- a/tsdate/approx.py +++ b/tsdate/approx.py @@ -23,15 +23,13 @@ """ Tools for approximating combinations of Gamma variates with Gamma distributions """ -from math import exp -from math import inf -from math import lgamma -from math import log + +from math import exp, inf, lgamma, log import numba import numpy as np -from numba.types import Tuple as _tuple -from numba.types import UniTuple as _unituple +from numba.types import Tuple as _tuple # noqa N813 +from numba.types import UniTuple as _unituple # noqa N813 from . import hypergeo @@ -59,7 +57,7 @@ _b1r = numba.types.Array(_b, 1, "C", readonly=True) -class KLMinimizationFailed(Exception): +class KLMinimizationFailed(Exception): # noqa N818 pass @@ -145,9 +143,7 @@ def approximate_gamma_iqr(q1, q2, x1, x2): itt = 0 while abs(delta) > abs(alpha) * _KLMIN_RELTOL: if itt > _KLMIN_MAXITT: - raise KLMinimizationFailed( - "Maximum iterations reached in quantile matching" - ) + raise KLMinimizationFailed("Maximum iterations reached in quantile matching") y1 = hypergeo._gammainc_inv(alpha, q1) y2 = hypergeo._gammainc_inv(alpha, q2) obj = y2 / y1 - x2 / x1 diff --git a/tsdate/base.py b/tsdate/base.py index 3881c236..3129edba 100644 --- a/tsdate/base.py +++ b/tsdate/base.py @@ -23,6 +23,7 @@ """ Base classes and internal constants used by tsdate """ + import numpy as np FLOAT_DTYPE = np.float64 @@ -85,19 +86,15 @@ def __init__( self.num_nodes = num_nodes self.nonfixed_nodes = nonfixed_nodes self.num_nonfixed = len(nonfixed_nodes) - self.grid_data = np.full( - (self.num_nonfixed, grid_size), fill_value, dtype=dtype - ) - self.fixed_data = np.full( - num_nodes - self.num_nonfixed, fill_value, dtype=dtype - ) + self.grid_data = np.full((self.num_nonfixed, grid_size), fill_value, dtype=dtype) + self.fixed_data = np.full(num_nodes - self.num_nonfixed, fill_value, dtype=dtype) self.row_lookup = np.empty(num_nodes, dtype=np.int64) # non-fixed nodes get a positive value, indicating lookup in the grid_data array self.row_lookup[nonfixed_nodes] = np.arange(self.num_nonfixed) # fixed nodes get a negative value from -1, indicating lookup in the scalar array - self.row_lookup[ - np.logical_not(np.isin(np.arange(num_nodes), nonfixed_nodes)) - ] = (-np.arange(num_nodes - self.num_nonfixed) - 1) + self.row_lookup[np.logical_not(np.isin(np.arange(num_nodes), nonfixed_nodes))] = ( + -np.arange(num_nodes - self.num_nonfixed) - 1 + ) self.probability_space = LIN def force_probability_space(self, probability_space): @@ -159,9 +156,7 @@ def to_probabilities(self): in logarithmic space) """ if self.probability_space != LIN: - raise NotImplementedError( - "Can only convert to probabilities in linear space" - ) + raise NotImplementedError("Can only convert to probabilities in linear space") assert not np.any(self.grid_data < 0) self.grid_data = self.grid_data / self.grid_data.sum(axis=1)[:, np.newaxis] diff --git a/tsdate/cache.py b/tsdate/cache.py index 29247479..49b2d142 100644 --- a/tsdate/cache.py +++ b/tsdate/cache.py @@ -22,13 +22,13 @@ """ Handle cache for precalculated prior """ + import logging import os import pathlib import appdirs - logger = logging.getLogger(__name__) diff --git a/tsdate/cli.py b/tsdate/cli.py index dd715eb7..9b157468 100644 --- a/tsdate/cli.py +++ b/tsdate/cli.py @@ -23,6 +23,7 @@ """ Command line interface for tsdate. """ + import argparse import logging import sys @@ -30,6 +31,7 @@ import tskit import tsdate + from . import core logger = logging.getLogger(__name__) diff --git a/tsdate/core.py b/tsdate/core.py index f5023b47..ab0a7d8e 100644 --- a/tsdate/core.py +++ b/tsdate/core.py @@ -23,14 +23,14 @@ """ Infer the age of nodes conditional on a tree sequence topology. """ + import functools import itertools import logging import multiprocessing import operator import time # DEBUG -from collections import defaultdict -from collections import namedtuple +from collections import defaultdict, namedtuple import numba import numpy as np @@ -38,13 +38,7 @@ import tskit from tqdm.auto import tqdm -from . import base -from . import demography -from . import prior -from . import provenance -from . import schemas -from . import util -from . import variational +from . import base, demography, prior, provenance, schemas, util, variational FORMAT_NAME = "tsdate" DEFAULT_RESCALING_INTERVALS = 1000 @@ -80,9 +74,7 @@ def __init__( ): self.ts = ts self.timepoints = timepoints - self.fixednodes = ( - set(ts.samples()) if fixed_node_set is None else fixed_node_set - ) + self.fixednodes = set(ts.samples()) if fixed_node_set is None else fixed_node_set self.mut_rate = mutation_rate self.rec_rate = recombination_rate self.standardize = standardize @@ -409,17 +401,17 @@ class LogLikelihoods(Likelihoods): @staticmethod @numba.jit(nopython=True) def logsumexp(X): - alpha = -np.Inf + alpha = -np.inf r = 0.0 for x in X: - if x != -np.Inf: + if x != -np.inf: if x <= alpha: r += np.exp(x - alpha) else: r *= np.exp(alpha - x) r += 1.0 alpha = x - return -np.Inf if r == 0 else np.log(r) + alpha + return -np.inf if r == 0 else np.log(r) + alpha @staticmethod def _lik(muts, span, dt, mutation_rate, standardize=True): @@ -668,9 +660,7 @@ def inside_pass(self, *, standardize=True, cache_inside=False, progress=None): if edge.child in self.fixednodes: # NB: geometric scaling works exactly when all nodes fixed in graph # but is an approximation when times are unknown. - daughter_val = self.lik.scale_geometric( - spanfrac, inside[edge.child] - ) + daughter_val = self.lik.scale_geometric(spanfrac, inside[edge.child]) edge_lik = self.lik.get_fixed(daughter_val, edge) else: inside_values = inside[edge.child] @@ -703,9 +693,7 @@ def inside_pass(self, *, standardize=True, cache_inside=False, progress=None): for root, span_when_root in self.root_spans.items(): spanfrac = span_when_root / self.spans[root] root_val = self.lik.scale_geometric(spanfrac, inside[root]) - marginal_lik = self.lik.combine( - marginal_lik, self.lik.marginalize(root_val) - ) + marginal_lik = self.lik.combine(marginal_lik, self.lik.marginalize(root_val)) return marginal_lik def outside_pass( @@ -734,9 +722,7 @@ def outside_pass( if not hasattr(self, "inside"): raise RuntimeError("You have not yet run the inside algorithm") - outside = self.inside.clone_with_new_data( - grid_data=0, probability_space=base.LIN - ) + outside = self.inside.clone_with_new_data(grid_data=0, probability_space=base.LIN) for root, span_when_root in self.root_spans.items(): outside[root] = span_when_root / self.spans[root] outside.force_probability_space(self.inside.probability_space) @@ -1024,9 +1010,7 @@ def get_modified_ts(self, result, eps): mutations, mut_mean_t, mut_var_t, schemas.default_mutation_schema ) meta_timing -= time.time() - logging.info( - f"Inserted node and mutation metadata in {abs(meta_timing)} seconds" - ) + logging.info(f"Inserted node and mutation metadata in {abs(meta_timing)} seconds") tables.sort() return tables.tree_sequence() @@ -1217,9 +1201,7 @@ def mean_var(posteriors, constraints): fixed_node_set). This is a static method for ease of testing. """ - mn_post = np.full( - posteriors.shape[0], np.nan - ) # Fill with NaNs so we detect when + mn_post = np.full(posteriors.shape[0], np.nan) # Fill with NaNs so we detect when va_post = np.full(posteriors.shape[0], np.nan) # there's been an error fixed = constraints[:, 0] == constraints[:, 1] diff --git a/tsdate/demography.py b/tsdate/demography.py index 5cd893e0..6569640a 100644 --- a/tsdate/demography.py +++ b/tsdate/demography.py @@ -23,6 +23,7 @@ """ Routines and classes for manipulating demographic histories in tsdate """ + import numpy as np import scipy.stats diff --git a/tsdate/evaluation.py b/tsdate/evaluation.py index 7ed44784..8c468357 100644 --- a/tsdate/evaluation.py +++ b/tsdate/evaluation.py @@ -22,11 +22,11 @@ """ Tools for comparing node times between tree sequences with different node sets """ + import copy import json from collections import defaultdict -from itertools import groupby -from itertools import product +from itertools import groupby, product from math import isqrt import matplotlib.pyplot as plt @@ -78,7 +78,7 @@ def _propagate(self, edge, downdate=False): node = self.tree.parent(node) return nodes - def next(self): # noqa: A003 + def next(self): """ Advance to the next tree, returning the difference between trees as a dictionary of the form `node : (last_clade, next_clade)` @@ -384,9 +384,9 @@ def unsupported_edges(ts, per_interval=False): else: keep = ~edges_to_remove for p, c in zip(ts.edges_parent[keep], ts.edges_child[keep]): - edges_to_remove[ - np.logical_and(ts.edges_parent == p, ts.edges_child == c) - ] = False + edges_to_remove[np.logical_and(ts.edges_parent == p, ts.edges_child == c)] = ( + False + ) return np.where(edges_to_remove)[0] @@ -428,9 +428,7 @@ def node_coverage(ts, inferred_ts, alpha): upper[i] = scipy.stats.gamma.ppf(1 - alpha / 2, shape, scale=1 / rate) lower[i] = scipy.stats.gamma.ppf(alpha / 2, shape, scale=1 / rate) true = ts.nodes_time[true_child] - is_covered = np.logical_and( - true[:, np.newaxis] < upper, true[:, np.newaxis] > lower - ) + is_covered = np.logical_and(true[:, np.newaxis] < upper, true[:, np.newaxis] > lower) prop_covered = np.sum(is_covered, axis=0) / is_covered.shape[0] # import matplotlib.pyplot as plt # plt.axline((0,0), slope=1, linestyle="--", color="black") @@ -500,9 +498,7 @@ def mutation_coverage(ts, inferred_ts, alpha): upper[i] = scipy.stats.gamma.ppf(1 - alpha / 2, shape, scale=1 / rate) lower[i] = scipy.stats.gamma.ppf(alpha / 2, shape, scale=1 / rate) true = ts.mutations_time[true_mut] - is_covered = np.logical_and( - true[:, np.newaxis] < upper, true[:, np.newaxis] > lower - ) + is_covered = np.logical_and(true[:, np.newaxis] < upper, true[:, np.newaxis] > lower) prop_covered = np.sum(is_covered, axis=0) / is_covered.shape[0] # plt.clf() # plt.axline((0,0), slope=1, linestyle="--", color="black") diff --git a/tsdate/hypergeo.py b/tsdate/hypergeo.py index b64d9b69..a574cfa1 100644 --- a/tsdate/hypergeo.py +++ b/tsdate/hypergeo.py @@ -23,14 +23,9 @@ """ Numerically stable implementations of the Gauss hypergeometric function with numba. """ + import ctypes -from math import erf -from math import exp -from math import lgamma -from math import log -from math import pi -from math import pow -from math import sqrt +from math import erf, exp, lgamma, log, pi, pow, sqrt import numba import numpy as np @@ -40,7 +35,7 @@ _HYP2F1_MAXTERM = int(1e6) -class Invalid2F1(Exception): +class Invalid2F1(Exception): # noqa N818 pass @@ -220,7 +215,8 @@ def _hyp2f1_unity(a, b, c, x): limits don't converge. A good reference is Buhring 2003 "Partial sums of hypergeometric series of unit argument" """ - assert np.isclose(x, 1.0) and x < 1.0 + assert np.isclose(x, 1.0) + assert x < 1.0 g = c - a - b diff --git a/tsdate/prior.py b/tsdate/prior.py index 5d4e70dd..9a24207c 100644 --- a/tsdate/prior.py +++ b/tsdate/prior.py @@ -23,10 +23,10 @@ """ Routines and classes for creating priors and timeslices for use in tsdate """ + import logging import os -from collections import defaultdict -from collections import namedtuple +from collections import defaultdict, namedtuple import numba import numpy as np @@ -36,11 +36,7 @@ import tskit from tqdm.auto import tqdm -from . import base -from . import cache -from . import demography -from . import provenance -from . import util +from . import base, cache, demography, provenance, util class PriorParams(namedtuple("PriorParamsBase", "alpha, beta, mean, var")): @@ -244,22 +240,18 @@ def add(self, total_tips, approximate=None): # NB: it should be possible to vectorize this in numpy expectation = self.tau_expect(tips, total_tips) alpha, beta = self.func_approx(expectation, var) - priors[tips] = PriorParams( - alpha=alpha, beta=beta, mean=expectation, var=var - ) + priors[tips] = PriorParams(alpha=alpha, beta=beta, mean=expectation, var=var) self.prior_store[total_tips] = priors def precalculate_priors_for_approximation(self, precalc_approximation_n): n = precalc_approximation_n logging.warning( "Initialising your tsdate installation by creating a user cache of " - "conditional coalescent prior values for {} tips".format(n) + f"conditional coalescent prior values for {n} tips" ) logging.info( - "Creating prior lookup table for a total tree of n={} tips" - " in `{}`, this may take some time for large n".format( - n, self.get_precalc_cache(n) - ) + f"Creating prior lookup table for a total tree of n={n} tips" + f" in `{self.get_precalc_cache(n)}`, this may take some time for large n" ) # The first value should be zero tips, we don't want the 1 tip value prior_lookup_table = np.zeros((n, 2)) @@ -274,8 +266,8 @@ def clear_precalculated_priors(self): os.remove(self.get_precalc_cache(self.n_approx)) else: logging.debug( - "Precalculated priors in `{}` not yet created, so cannot be" - " cleared".format(self.get_precalc_cache(self.n_approx)) + f"Precalculated priors in `{self.get_precalc_cache(self.n_approx)}`" + "not yet created, so cannot be cleared" ) @staticmethod @@ -410,9 +402,7 @@ def get_mixture_prior_params(self, spans_by_samples): priors[node] = seen_mixtures[mixture_hash] else: # a large number of mixtures in this node - don't bother caching - priors[node] = self.func_approx( - *self.mixture_expect_and_var(mixture) - ) + priors[node] = self.func_approx(*self.mixture_expect_and_var(mixture)) else: # The node spans trees with multiple total tip numbers, # don't use the cache @@ -584,10 +574,10 @@ def save_to_spans(prev_tree, node, num_fixed_at_0_treenodes): else: coverage = 0 raise ValueError( - "Node {} is dangling (no descendant samples) at pos {}: " - "this node will have no weight in this region. Run " - "`simplify(keep_unary=False)` before dating this tree " - "sequence".format(node, stored_pos[node]) + f"Node {node} is dangling (no descendant samples) at pos " + f"{stored_pos[node]}: this node will have no weight in " + "this region. Run `simplify(keep_unary=False)` before dating " + "this tree sequence" ) if node in self.sample_node_set: return True @@ -607,8 +597,8 @@ def save_to_spans(prev_tree, node, num_fixed_at_0_treenodes): except ValueError: # Happens if we have hit the root assert top_node == tskit.NULL logging.debug( - "Unary node `{}` exists above highest coalescence in tree {}." - " Skipping for now".format(node, prev_tree.index) + f"Unary node `{node}` exists above highest coalescence in " + "tree {prev_tree.index}. Skipping for now" ) return None # for unary nodes, a proportion of the span is allocated @@ -810,8 +800,8 @@ def second_pass(self, trees_with_undated, n_tips_per_tree): continue else: logging.debug( - "Assigning prior to unary node {}: connected to node {} which" - " has a prior in tree {}".format(node, n, tree_id) + f"Assigning prior to unary node {node}: connected to " + f"node {n} which has a prior in tree {tree_id}" ) for n_tips, spans in self._spans[n].items(): for k, v in spans.items(): @@ -869,10 +859,9 @@ def finalize(self): if self.nodes_remain_to_date(): raise ValueError( - "When finalising node spans, found the following nodes not in any tree;" - " you must simplify your tree sequence first: {}".format( - self.nodes_remaining_to_date() - ) + "When finalising node spans, found the following nodes not in " + "any tree; you must simplify your tree sequence first:" + f"{self.nodes_remaining_to_date()}" ) for node, spans_by_total_tips in self._spans.items(): @@ -1084,9 +1073,7 @@ def __init__( "Passed tree sequence is not simplified and/or contains " "noncontemporaneous samples" ) - span_data = SpansBySamples( - contmpr_ts, progress=progress, allow_unary=allow_unary - ) + span_data = SpansBySamples(contmpr_ts, progress=progress, allow_unary=allow_unary) base_priors = ConditionalCoalescentTimes( approx_prior_size, prior_distribution, progress=progress @@ -1120,11 +1107,11 @@ def make_discretised_prior(self, population_size, timepoints=20, progress=False) timepoints = create_timepoints(self.base_priors, timepoints + 1) elif isinstance(timepoints, np.ndarray): try: - timepoints = np.sort( - timepoints.astype(base.FLOAT_DTYPE, casting="safe") - ) + timepoints = np.sort(timepoints.astype(base.FLOAT_DTYPE, casting="safe")) except TypeError: - raise TypeError("Timepoints array cannot be converted to float dtype") + raise TypeError( + "Timepoints array cannot be converted to float dtype" + ) from None if len(timepoints) < 2: raise ValueError("You must have at least 2 time points") elif np.any(timepoints < 0): @@ -1135,9 +1122,7 @@ def make_discretised_prior(self, population_size, timepoints=20, progress=False) # coalescent timescale to evaluate prior timepoints = population_size.to_coalescent_timescale(timepoints) else: - raise ValueError( - "time_slices must be an integer or a numpy array of floats" - ) + raise ValueError("time_slices must be an integer or a numpy array of floats") # Set all fixed nodes (i.e. samples) to have 0 variance priors = fill_priors( @@ -1286,9 +1271,7 @@ def parameter_grid( def has_locally_unary_nodes(ts): for tree, ediff in zip(ts.trees(), ts.edge_diffs()): - changed = { - e.parent for edges in (ediff.edges_out, ediff.edges_in) for e in edges - } + changed = {e.parent for edges in (ediff.edges_out, ediff.edges_in) for e in edges} if (tree.num_children_array[list(changed)] == 1).any(): return True return False diff --git a/tsdate/provenance.py b/tsdate/provenance.py index f2d8a6ba..da36bf0a 100644 --- a/tsdate/provenance.py +++ b/tsdate/provenance.py @@ -22,6 +22,7 @@ """ Versions of important dependencies and environment. """ + import json import platform diff --git a/tsdate/rescaling.py b/tsdate/rescaling.py index 0d6dbade..6fd793e9 100644 --- a/tsdate/rescaling.py +++ b/tsdate/rescaling.py @@ -22,25 +22,27 @@ """ Utilities for rescaling time according to a mutational clock """ -from math import inf -from math import log + +from math import inf, log import numba import numpy as np import tskit -from numba.types import UniTuple as _unituple - -from .approx import _b -from .approx import _b1r -from .approx import _f -from .approx import _f1r -from .approx import _f1w -from .approx import _f2r -from .approx import _f2w -from .approx import _i -from .approx import _i1r -from .approx import _i1w -from .approx import approximate_gamma_iqr +from numba.types import UniTuple as _unituple # noqa: N813 + +from .approx import ( + _b, + _b1r, + _f, + _f1r, + _f1w, + _f2r, + _f2w, + _i, + _i1r, + _i1w, + approximate_gamma_iqr, +) from .hypergeo import _gammainc_inv as gammainc_inv from .util import mutation_span_array # NOQA: F401 @@ -137,8 +139,10 @@ def mutational_timescale( """ assert edges_parent.size == edges_child.size == edges_weight.size - assert likelihoods.shape[0] == edges_parent.size and likelihoods.shape[1] == 2 - assert constraints.shape[0] == nodes_time.size and constraints.shape[1] == 2 + assert likelihoods.shape[0] == edges_parent.size + assert likelihoods.shape[1] == 2 + assert constraints.shape[0] == nodes_time.size + assert constraints.shape[1] == 2 assert max_intervals > 0 nodes_fixed = constraints[:, 0] == constraints[:, 1] @@ -236,7 +240,8 @@ def piecewise_scale_posterior( def rescale(x): i = np.searchsorted(original_breaks, x, "right") - 1 - assert i.min() >= 0 and i.max() < scalings.size # DEBUG + assert i.min() >= 0 + assert i.max() < scalings.size # DEBUG return rescaled_breaks[i] + scalings[i] * (x - original_breaks[i]) midpt = rescale(midpt) @@ -247,9 +252,7 @@ def rescale(x): # TODO: catch rare cases where lower/upper quantiles are nearly identical new_posteriors = np.zeros(posteriors.shape) for i in np.flatnonzero(freed): - alpha, beta = approximate_gamma_iqr( - quant_lower, quant_upper, lower[i], upper[i] - ) + alpha, beta = approximate_gamma_iqr(quant_lower, quant_upper, lower[i], upper[i]) beta = gammainc_inv(alpha + 1, 0.5) if use_median else (alpha + 1) beta /= midpt[i] # choose rate so as to keep mean or median new_posteriors[i] = alpha, beta diff --git a/tsdate/schemas.py b/tsdate/schemas.py index 22cebc6b..5dd72999 100644 --- a/tsdate/schemas.py +++ b/tsdate/schemas.py @@ -22,6 +22,7 @@ """ Metadata schemas used in tsdate, if no schema already provided """ + import tskit default_mutation_schema = tskit.MetadataSchema( diff --git a/tsdate/util.py b/tsdate/util.py index e47f397e..bd23ee97 100644 --- a/tsdate/util.py +++ b/tsdate/util.py @@ -23,24 +23,19 @@ Utility functions for tsdate. Many of these can be removed when tskit is updated to a more recent version which has the functionality built-in """ + import json import logging import numba import numpy as np import tskit -from numba.types import UniTuple as _unituple +from numba.types import UniTuple as _unituple # noqa: N813 import tsdate -from . import provenance -from .approx import _b1r -from .approx import _f -from .approx import _f1r -from .approx import _f1w -from .approx import _i -from .approx import _i1r -from .approx import _i1w +from . import provenance +from .approx import _b1r, _f, _f1r, _f1w, _i, _i1r, _i1w logger = logging.getLogger(__name__) @@ -149,7 +144,7 @@ def preprocess_ts( delete_intervals.append([0, first_site]) logger.info( "REMOVING TELOMERE: Snip topology " - "from 0 to first site at {}.".format(first_site) + f"from 0 to first site at {first_site}." ) last_site = sites[-1] + 1 sequence_length = tables.sequence_length @@ -157,9 +152,7 @@ def preprocess_ts( delete_intervals.append([last_site, sequence_length]) logger.info( "REMOVING TELOMERE: Snip topology " - "from {} to end of sequence at {}.".format( - last_site, sequence_length - ) + f"from {last_site} to end of sequence at {sequence_length}." ) gaps = sites[1:] - sites[:-1] threshold_gaps = np.where(gaps >= minimum_gap)[0] @@ -168,15 +161,13 @@ def preprocess_ts( gap_end = sites[gap + 1] - 1 if gap_end > gap_start: logger.info( - "Gap Size is {}. Snip topology " - "from {} to {}.".format(gap_end - gap_start, gap_start, gap_end) + f"Gap Size is {gap_end - gap_start}. Snip topology " + f"from {gap_start} to {gap_end}." ) delete_intervals.append([gap_start, gap_end]) delete_intervals = sorted(delete_intervals, key=lambda x: x[0]) if len(delete_intervals) > 0: - tables.delete_intervals( - delete_intervals, simplify=False, record_provenance=False - ) + tables.delete_intervals(delete_intervals, simplify=False, record_provenance=False) tables.simplify( filter_populations=filter_populations, filter_individuals=filter_individuals, @@ -228,7 +219,7 @@ def nodes_time_unconstrained(tree_sequence): except (KeyError, json.decoder.JSONDecodeError): raise ValueError( "Tree Sequence must be tsdated with the Inside-Outside Method." - ) + ) from None return nodes_time diff --git a/tsdate/variational.py b/tsdate/variational.py index f1a6b197..826de482 100644 --- a/tsdate/variational.py +++ b/tsdate/variational.py @@ -23,6 +23,7 @@ """ Expectation propagation implementation """ + import logging import time @@ -33,22 +34,13 @@ from tqdm.auto import tqdm from . import approx -from .approx import _b -from .approx import _b1r -from .approx import _f -from .approx import _f1r -from .approx import _f1w -from .approx import _f2r -from .approx import _f2w -from .approx import _f3r -from .approx import _f3w -from .approx import _i -from .approx import _i1r +from .approx import _b, _b1r, _f, _f1r, _f1w, _f2r, _f2w, _f3r, _f3w, _i, _i1r from .hypergeo import _gammainc_inv as gammainc_inv -from .rescaling import edge_sampling_weight -from .rescaling import mutational_timescale -from .rescaling import piecewise_scale_posterior - +from .rescaling import ( + edge_sampling_weight, + mutational_timescale, + piecewise_scale_posterior, +) # columns for edge_factors ROOTWARD = 0 # edge likelihood to parent @@ -391,9 +383,7 @@ def posterior_damping(x): @staticmethod @numba.njit(_f(_b1r, _f2w, _f3w, _f1w, _f, _i, _f)) - def propagate_prior( - free, posterior, factors, scale, max_shape, em_maxitt, em_reltol - ): + def propagate_prior(free, posterior, factors, scale, max_shape, em_maxitt, em_reltol): """ Update approximating factors for global prior.