Skip to content

Commit

Permalink
Switch to ruff linting
Browse files Browse the repository at this point in the history
  • Loading branch information
benjeffery committed Jun 6, 2024
1 parent 575e194 commit bac9a13
Show file tree
Hide file tree
Showing 29 changed files with 272 additions and 406 deletions.
25 changes: 6 additions & 19 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -7,23 +7,10 @@ repos:
- id: mixed-line-ending
- id: check-case-conflict
- id: check-yaml
- repo: https://github.com/asottile/reorder_python_imports
rev: v3.10.0
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.4.5
hooks:
- id: reorder-python-imports
- repo: https://github.com/asottile/pyupgrade
rev: v3.10.1
hooks:
- id: pyupgrade
args: [--py3-plus, --py38-plus]
- repo: https://github.com/psf/black
rev: 23.7.0
hooks:
- id: black
language_version: python3
- repo: https://github.com/pycqa/flake8
rev: 6.1.0
hooks:
- id: flake8
args: [--config=.flake8]
additional_dependencies: ["flake8-bugbear==23.7.10", "flake8-builtins==2.1.0"]
- id: ruff
args: [ "--fix", "--config", "ruff.toml" ]
- id: ruff-format
args: [ "--config", "ruff.toml" ]
30 changes: 16 additions & 14 deletions tests/distribution_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,13 @@
Utility functions to construct distributions used in variational inference,
for testing purposes
"""

import mpmath
import numpy as np
import scipy.integrate
import scipy.special

from tsdate import approx
from tsdate import hypergeo
from tsdate import approx, hypergeo


def kl_divergence(p, logq):
Expand Down Expand Up @@ -81,9 +81,7 @@ def pr_a(a, n, k):
if n == k:
return pr_t_bar_a(t, 1)
else:
return np.sum(
[pr_a(a, n, k) * pr_t_bar_a(t, a, n) for a in range(2, n - k + 2)]
)
return np.sum([pr_a(a, n, k) * pr_t_bar_a(t, a, n) for a in range(2, n - k + 2)])


class TiltedGammaDiff:
Expand Down Expand Up @@ -114,8 +112,12 @@ def _U(a, b, z):
return float(val)

def __init__(self, shape1, shape2, shape3, rate1, rate2, rate3, reorder=True):
assert shape1 > 0 and shape2 > 0 and shape3 > 0
assert rate1 >= 0 and rate2 > 0 and rate3 >= 0
assert shape1 > 0
assert shape2 > 0
assert shape3 > 0
assert rate1 >= 0
assert rate2 > 0
assert rate3 >= 0
# for convergence of 2F1, we need rate2 > rate3. Invariant
# transformations of 2F1 allow us to switch arguments, with
# appropriate rescaling
Expand Down Expand Up @@ -369,8 +371,12 @@ def _M(a, b, x):
return float(val)

def __init__(self, shape1, shape2, shape3, rate1, rate2, rate3):
assert shape1 > 0 and shape2 > 0 and shape3 > 0
assert rate1 >= 0 and rate2 > 0 and rate3 >= 0
assert shape1 > 0
assert shape2 > 0
assert shape3 > 0
assert rate1 >= 0
assert rate2 > 0
assert rate3 >= 0
# for numeric stability of hypergeometric we need rate2 > rate1
# as this is a convolution, the order of (1) and (2) don't matter
self.reparametrize = rate1 > rate2
Expand Down Expand Up @@ -481,11 +487,7 @@ def sufficient_statistics(self):
+ scipy.special.betaln(self.shape1, self.shape2)
)
x = dF_dz * T / S**2 + B / S
xsq = (
d2F_dz2 * T**2 / S**4
+ B * (B + 1) / S**2
+ 2 * dF_dz * (1 + B) * T / S**3
)
xsq = d2F_dz2 * T**2 / S**4 + B * (B + 1) / S**2 + 2 * dF_dz * (1 + B) * T / S**3
logx = dF_db + scipy.special.digamma(B) - np.log(S)
return logconst, x, xsq, logx

Expand Down
17 changes: 11 additions & 6 deletions tests/test_accuracy.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
"""
Test cases for tsdate accuracy.
"""

import json
import os

Expand All @@ -40,7 +41,7 @@ class TestAccuracy:
Test for some of the basic functions used in tsdate
"""

@pytest.mark.makefiles
@pytest.mark.makefiles()
def test_make_static_files(self, request):
"""
The function used to create the tree sequences for accuracy testing.
Expand Down Expand Up @@ -75,7 +76,13 @@ def test_make_static_files(self, request):
ts.dump(os.path.join(request.fspath.dirname, "data", f"{name}.trees"))

@pytest.mark.parametrize(
"ts_name,min_r2_ts,min_r2_unconstrained,min_spear_ts,min_spear_unconstrained",
(
"ts_name",
"min_r2_ts",
"min_r2_unconstrained",
"min_spear_ts",
"min_spear_unconstrained",
),
[
("one_tree", 0.98601, 0.98601, 0.97719, 0.97719),
("few_trees", 0.98220, 0.98220, 0.97744, 0.97744),
Expand All @@ -91,9 +98,7 @@ def test_basic(
min_spear_unconstrained,
request,
):
ts = tskit.load(
os.path.join(request.fspath.dirname, "data", ts_name + ".trees")
)
ts = tskit.load(os.path.join(request.fspath.dirname, "data", ts_name + ".trees"))

sim_ancestry_parameters = json.loads(ts.provenance(0).record)["parameters"]
assert sim_ancestry_parameters["command"] == "sim_ancestry"
Expand Down Expand Up @@ -144,7 +149,7 @@ def test_scaling(self, Ne):
assert 0.9 < dts.node(dts.first().root).time / (2 * Ne) < 1.1

@pytest.mark.parametrize(
"bkwd_rate, trio_tmrca",
("bkwd_rate", "trio_tmrca"),
[ # calculated from simulations
(-1.0, 0.76),
(-0.9, 0.79),
Expand Down
22 changes: 9 additions & 13 deletions tests/test_approximations.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,17 +24,15 @@
"""
Test cases for the gamma-variational approximations in tsdate
"""

import numpy as np
import pytest
import scipy.integrate
import scipy.special
import scipy.stats
from distribution_functions import conditional_coalescent_pdf
from distribution_functions import kl_divergence
from distribution_functions import conditional_coalescent_pdf, kl_divergence

from tsdate import approx
from tsdate import hypergeo
from tsdate import prior
from tsdate import approx, hypergeo, prior

# TODO: better test set?
# TODO: test special case where child is fixed to age 0
Expand Down Expand Up @@ -273,9 +271,7 @@ def test_truncated_moments(self, pars):
assert np.isclose(t_i, ck_t_i, rtol=1e-4)
ck_var_t_i = (
scipy.integrate.quad(
lambda t_i: t_i**2
* self.pdf_truncated(t_i, *pars_redux)
/ ck_normconst,
lambda t_i: t_i**2 * self.pdf_truncated(t_i, *pars_redux) / ck_normconst,
low,
upp,
epsabs=0,
Expand Down Expand Up @@ -389,9 +385,7 @@ def test_average_gammas(self):
E_x = np.mean(shape + 1)
E_logx = np.mean(scipy.special.digamma(shape + 1))
assert np.isclose(E_x, (avg_shape + 1) / avg_rate)
assert np.isclose(
E_logx, scipy.special.digamma(avg_shape + 1) - np.log(avg_rate)
)
assert np.isclose(E_logx, scipy.special.digamma(avg_shape + 1) - np.log(avg_rate))


class TestKLMinimizationFailed:
Expand All @@ -409,10 +403,12 @@ def test_asymptotic_bound(self):
alpha, _ = approx.approximate_gamma_kl(1, logx)
alpha += 1
alpha_bound = -0.5 / logx
assert alpha == alpha_bound and alpha > 1e4
assert alpha == alpha_bound
assert alpha > 1e4
# check that bound matches optimization result just under threshold
logx = -0.000051
alpha, _ = approx.approximate_gamma_kl(1, logx)
alpha += 1
alpha_bound = -0.5 / logx
assert np.abs(alpha - alpha_bound) < 1 and alpha < 1e4
assert np.abs(alpha - alpha_bound) < 1
assert alpha < 1e4
1 change: 1 addition & 0 deletions tests/test_cache.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""
Tests for the cache management code.
"""

import os
import pathlib
import unittest
Expand Down
21 changes: 10 additions & 11 deletions tests/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
"""
Test cases for the command line interface for tsdate.
"""

import json
import logging
from unittest import mock
Expand Down Expand Up @@ -75,11 +76,11 @@ def test_recombination_rate(self):
parser = cli.tsdate_cli_parser()
params = ["-m", "1e10"]
args = parser.parse_args(
["date", self.infile, self.output] + params + ["-r", "1e-100"]
["date", self.infile, self.output, *params, "-r", "1e-100"]
)
assert args.recombination_rate == 1e-100
args = parser.parse_args(
["date", self.infile, self.output] + params + ["--recombination-rate", "73"]
["date", self.infile, self.output, *params, "--recombination-rate", "73"]
)
assert args.recombination_rate == 73

Expand All @@ -97,24 +98,22 @@ def test_epsilon(self):
def test_num_threads(self):
parser = cli.tsdate_cli_parser()
params = ["--method", "maximization", "--num-threads"]
args = parser.parse_args(["date", self.infile, self.output] + params + ["1"])
args = parser.parse_args(["date", self.infile, self.output, *params, "1"])
assert args.num_threads == 1
args = parser.parse_args(["date", self.infile, self.output] + params + ["2"])
args = parser.parse_args(["date", self.infile, self.output, *params, "2"])
assert args.num_threads == 2

def test_probability_space(self):
parser = cli.tsdate_cli_parser()
params = ["--method", "inside_outside", "--probability-space"]
args = parser.parse_args(
["date", self.infile, self.output] + params + ["linear"]
)
args = parser.parse_args(["date", self.infile, self.output, *params, "linear"])
assert args.probability_space == "linear"
args = parser.parse_args(
["date", self.infile, self.output] + params + ["logarithmic"]
["date", self.infile, self.output, *params, "logarithmic"]
)
assert args.probability_space == "logarithmic"

@pytest.mark.parametrize("flag, log_status", logging_flags.items())
@pytest.mark.parametrize(("flag", "log_status"), logging_flags.items())
def test_verbosity(self, flag, log_status):
parser = cli.tsdate_cli_parser()
args = parser.parse_args(["preprocess", self.infile, self.output, flag])
Expand All @@ -130,7 +129,7 @@ def test_method(self, method):
params = ["-m", "1e-8", "--method", method]
if method != "variational_gamma":
params += ["-n", "10"]
args = parser.parse_args(["date", self.infile, self.output] + params)
args = parser.parse_args(["date", self.infile, self.output, *params])
assert args.method == method

def test_progress(self):
Expand Down Expand Up @@ -231,7 +230,7 @@ def test_no_output_variational_gamma(self, tmp_path, capfd):
assert out == ""
assert err == ""

@pytest.mark.parametrize("flag, log_status", logging_flags.items())
@pytest.mark.parametrize(("flag", "log_status"), logging_flags.items())
def test_verbosity(self, tmp_path, caplog, flag, log_status):
popsize = 10000
ts = msprime.simulate(
Expand Down
1 change: 1 addition & 0 deletions tests/test_evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
"""
Test tools for mapping between node sets of different tree sequences
"""

from collections import defaultdict
from itertools import combinations

Expand Down
Loading

0 comments on commit bac9a13

Please sign in to comment.