Skip to content

Commit

Permalink
Mask entries in simulated SFS out if they are negative and return None
Browse files Browse the repository at this point in the history
when all values are masked during evaluation of ll
  • Loading branch information
noscode committed Sep 13, 2023
1 parent 9e6a666 commit df5e65b
Show file tree
Hide file tree
Showing 6 changed files with 15 additions and 1 deletion.
2 changes: 2 additions & 0 deletions gadma/code_generator/dadi_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,6 +227,8 @@ def _print_dadi_simulation():
ret_str = f"func_ex = dadi.Numerics.make_extrap_log_func"\
f"({FUNCTION_NAME})\n"
ret_str += "model = func_ex(p0, ns, pts)\n"
# check if some entries are negative - mask them out
ret_str += f"model.mask[np.where(model < 0)] = True\n"
return ret_str


Expand Down
2 changes: 2 additions & 0 deletions gadma/code_generator/moments_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,8 @@ def _print_moments_load_data(data_holder):

def _print_moments_simulation():
ret_str = f"model = {FUNCTION_NAME}(p0, ns)\n"
# check if some entries are negative - mask them out
ret_str += f"model.mask[np.where(model < 0)] = True\n"
return ret_str


Expand Down
3 changes: 3 additions & 0 deletions gadma/engines/dadi_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from ..models import CustomDemographicModel, Epoch, Split
from ..utils import DynamicVariable
from .. import SFSDataHolder, dadi_available
import numpy as np


class DadiEngine(DadiOrMomentsEngine):
Expand Down Expand Up @@ -162,6 +163,8 @@ def simulate(self, values, ns, sequence_length, population_labels, pts):
model = func_ex(values, ns, pts)
if population_labels is not None:
model.pop_ids = population_labels
# check if some entries are negative - mask them out
model.mask[np.where(model < 0)] = True
# TODO: Nref
return model

Expand Down
6 changes: 6 additions & 0 deletions gadma/engines/dadi_moments_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,6 +236,12 @@ def evaluate(self, values, grid_sizes):
raise ValueError(f"{self.id} engine could not process constrains "
"on demographic model parameters (bounds of time "
"splits) in not-multinom mode.")
# Check if masks intersection gives any values to work with
if np.all(np.logical_or(self.data.mask, model_sfs.mask)):
key = self._get_key(values, grid_sizes)
self.saved_add_info[key] = None
return None

if not self.multinom:
theta0_inv = self.get_N_ancestral_from_theta(1)
if theta0_inv is None:
Expand Down
2 changes: 2 additions & 0 deletions gadma/engines/moments_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,8 @@ def simulate(self, values, ns, sequence_length, population_labels,
model = self._inner_func(values, ns, dt_fac)
if population_labels is not None:
model.pop_ids = population_labels
# check if some entries are negative - mask them out
model.mask[np.where(model < 0)] = True
return model

def get_N_ancestral(self, values, dt_fac=default_dt_fac):
Expand Down
1 change: 0 additions & 1 deletion gadma/optimizers/optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -629,4 +629,3 @@ def check_variables(self, variables):
print(var.domain)
assert (isinstance(var, DiscreteVariable)
or np.all(var.domain != np.array([-np.inf, np.inf])))

0 comments on commit df5e65b

Please sign in to comment.