diff --git a/pycona/find_scope/findscope2.py b/pycona/find_scope/findscope2.py index cbae9e3..0fb3739 100644 --- a/pycona/find_scope/findscope2.py +++ b/pycona/find_scope/findscope2.py @@ -1,7 +1,8 @@ from ..ca_environment.active_ca import ActiveCAEnv +from ..ca_environment.acive_ca_proba import ProbaActiveCAEnv from .findscope_core import FindScopeBase from ..utils import get_kappa, get_con_subset - +from ..find_scope.findscope_obj import split_proba, split_half class FindScope2(FindScopeBase): """ @@ -9,14 +10,19 @@ class FindScope2(FindScopeBase): Bessiere, Christian, et al., "Learning constraints through partial queries", AIJ 2023 """ - def __init__(self, ca_env: ActiveCAEnv = None, time_limit=0.2): + def __init__(self, ca_env: ActiveCAEnv = None, split_func=None, time_limit=0.2): """ Initialize the FindScope2 class. :param ca_env: The constraint acquisition environment. :param time_limit: The time limit for findscope query generation. + :param split_func: The function used to split the variables in findscope. """ - super().__init__(ca_env, time_limit) + + if split_func is None: + split_func = split_proba if isinstance(ca_env, ProbaActiveCAEnv) else split_half + + super().__init__(ca_env, time_limit, split_func=split_func) self._kappaB = [] def run(self, Y, kappa=None): @@ -72,7 +78,7 @@ def _find_scope(self, R, Y): # Create Y1, Y2 ------------------------- proba = self.ca.bias_proba if hasattr(self.ca, 'bias_proba') else [] - Y1, Y2 = self.split_func(Y=Y, R=R, kappaB=kappaBRY, proba=proba) + Y1, Y2 = self.split_func(Y=Y, R=R, kappaB=kappaBRY, P_c=proba) S1 = set() S2 = set()