Skip to content

Commit

Permalink
Merge pull request #349 from nspope/mix-prior
Browse files Browse the repository at this point in the history
Add mixture-of-gammas global prior, updated by EM
  • Loading branch information
hyanwong authored Jan 6, 2024
2 parents 50c39af + cf58a1c commit a388028
Show file tree
Hide file tree
Showing 5 changed files with 475 additions and 36 deletions.
36 changes: 34 additions & 2 deletions tests/test_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -414,17 +414,41 @@ def test_simple_sim_multi_tree(self):
ts = msprime.simulate(8, mutation_rate=5, recombination_rate=5, random_seed=2)
tsdate.date(ts, mutation_rate=5, population_size=1, method="variational_gamma")

def test_nonglobal_priors(self):
def test_invalid_priors(self):
ts = msprime.simulate(8, mutation_rate=5, recombination_rate=5, random_seed=2)
priors = tsdate.prior.MixturePrior(ts, prior_distribution="gamma")
grid = priors.make_parameter_grid(population_size=1)
grid.grid_data[:] = [1.0, 0.0] # noninformative prior
with pytest.raises(ValueError, match="Non-positive shape/rate"):
tsdate.date(
ts,
mutation_rate=5,
method="variational_gamma",
priors=grid,
)

def test_custom_priors(self):
ts = msprime.simulate(8, mutation_rate=5, recombination_rate=5, random_seed=2)
priors = tsdate.prior.MixturePrior(ts, prior_distribution="gamma")
grid = priors.make_parameter_grid(population_size=1)
grid.grid_data[:] += 1.0
tsdate.date(
ts,
mutation_rate=5,
method="variational_gamma",
priors=grid,
global_prior=False,
)

def test_prior_mixture_dim(self):
ts = msprime.simulate(8, mutation_rate=5, recombination_rate=5, random_seed=2)
priors = tsdate.prior.MixturePrior(ts, prior_distribution="gamma")
grid = priors.make_parameter_grid(population_size=1)
tsdate.date(
ts,
mutation_rate=5,
method="variational_gamma",
priors=grid,
prior_mixture_dim=2,
)

def test_bad_arguments(self):
Expand All @@ -437,6 +461,14 @@ def test_bad_arguments(self):
method="variational_gamma",
max_iterations=-1,
)
with pytest.raises(ValueError, match="must be a positive integer"):
tsdate.date(
ts,
mutation_rate=5,
population_size=1,
method="variational_gamma",
prior_mixture_dim=0.1,
)

def test_match_central_moments(self):
ts = msprime.simulate(8, mutation_rate=5, recombination_rate=5, random_seed=2)
Expand Down
31 changes: 31 additions & 0 deletions tsdate/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,34 @@ def tsdate_cli_parser():
"but does not exactly minimize KL divergence in each EP update."
),
)
parser.add_argument(
"--max-iterations",
type=int,
help=(
"The number of iterations used in the expectation propagation "
"algorithm. Default: 20"
),
default=20,
)
parser.add_argument(
"--em-iterations",
type=int,
help=(
"The number of expectation-maximization iterations used to optimize the "
"i.i.d. mixture prior at the end of each expectation propagation step. "
"Setting to zero disables optimization. Default: 10"
),
default=10,
)
parser.add_argument(
"--prior-mixture-dim",
type=int,
help=(
"The number of components in the i.i.d. mixture prior for node "
"ages. Default: 1"
),
default=1,
)
parser.set_defaults(runner=run_date)

parser = subparsers.add_parser(
Expand Down Expand Up @@ -253,8 +281,11 @@ def run_date(args):
method=args.method,
eps=args.epsilon,
progress=args.progress,
max_iterations=args.max_iterations,
max_shape=args.max_shape,
match_central_moments=args.match_central_moments,
em_iterations=args.em_iterations,
prior_mixture_dim=args.prior_mixture_dim,
)
else:
params = dict(
Expand Down
Loading

0 comments on commit a388028

Please sign in to comment.