diff --git a/pycona/utils.py b/pycona/utils.py index 134ceef..683fb66 100644 --- a/pycona/utils.py +++ b/pycona/utils.py @@ -4,11 +4,13 @@ import cpmpy as cp from cpmpy.expressions.core import Expression, Comparison, Operator from cpmpy.expressions.variables import NDVarArray, _NumVarImpl, NegBoolView +from cpmpy.transformations.get_variables import get_variables from sklearn.utils import class_weight import numpy as np import re from cpmpy.expressions.utils import all_pairs, is_any_list + class Objectives: """ A class to manage different objectives for query generation, find scope, and find constraint. @@ -54,6 +56,7 @@ def check_value(c): """ return bool(c.value()) + def get_con_subset(B, Y): """ Get the subset of constraints whose scope is a subset of Y. @@ -77,6 +80,7 @@ def get_kappa(B, Y): Y = frozenset(Y) return [c for c in B if frozenset(get_scope(c)).issubset(Y) and check_value(c) is False] + def get_lambda(B, Y): """ Get the subset of constraints whose scope is a subset of Y and are satisfied. @@ -88,6 +92,7 @@ def get_lambda(B, Y): Y = frozenset(Y) return [c for c in B if frozenset(get_scope(c)).issubset(Y) and check_value(c) is True] + def gen_pairwise(v1, v2): """ Generate pairwise constraints between two variables. @@ -98,6 +103,7 @@ def gen_pairwise(v1, v2): """ return [v1 == v2, v1 != v2, v1 < v2, v1 > v2] + def gen_pairwise_ineq(v1, v2): """ Generate pairwise inequality constraints between two variables. @@ -108,6 +114,7 @@ def gen_pairwise_ineq(v1, v2): """ return [v1 != v2] + def alldiff_binary(grid): """ Generate all different binary constraints for a grid. @@ -119,6 +126,7 @@ def alldiff_binary(grid): for c in gen_pairwise_ineq(v1, v2): yield c + def gen_scoped_cons(grid): """ Generate scoped constraints for a grid. @@ -145,6 +153,7 @@ def gen_scoped_cons(grid): for c in gen_pairwise_ineq(grid[i1, j1], grid[i2, j2]): yield c + def gen_all_cons(grid): """ Generate all pairwise constraints for a grid. @@ -156,6 +165,7 @@ def gen_all_cons(grid): for c in gen_pairwise(v1, v2): yield c + def get_scopes_vars(C): """ Get the set of variables involved in the scopes of constraints. @@ -165,6 +175,7 @@ def get_scopes_vars(C): """ return set([x for scope in [get_scope(c) for c in C] for x in scope]) + def get_scopes(C): """ Get the list of unique scopes of constraints. @@ -174,6 +185,7 @@ def get_scopes(C): """ return list(set([tuple(get_scope(c)) for c in C])) + def get_scope(constraint): """ Get the scope (variables) of a constraint. @@ -181,18 +193,9 @@ def get_scope(constraint): :param constraint: The constraint to get the scope of. :return: List of variables in the scope of the constraint. """ - if isinstance(constraint, _NumVarImpl): - return [constraint] - elif isinstance(constraint, Expression): - all_variables = [] - for argument in constraint.args: - if isinstance(argument, _NumVarImpl): - all_variables.append(argument) - else: - all_variables.extend(get_scope(argument)) - return all_variables - else: - return [] + return get_variables(constraint) + + def compare_scopes(scope1, scope2): scope1 = set(scope1) @@ -211,17 +214,19 @@ def get_constant(constraint): :param constraint: The constraint to get the constants of. :return: List of constants involved in the constraint. """ + if isinstance(constraint, _NumVarImpl): return [] - elif isinstance(constraint, Expression): + elif isinstance(constraint, Expression) or is_any_list(constraint): constants = [] - for argument in constraint.args: + for argument in (constraint.args if isinstance(constraint, Expression) else constraint): if not isinstance(argument, _NumVarImpl): constants.extend(get_constant(argument)) return constants else: return [constraint] + def get_arity(constraint): """ Get the arity (number of variables) of a constraint. @@ -231,6 +236,7 @@ def get_arity(constraint): """ return len(get_scope(constraint)) + def get_min_arity(C): """ Get the minimum arity of a list of constraints. @@ -242,6 +248,7 @@ def get_min_arity(C): return min([get_arity(c) for c in C]) return 0 + def get_max_arity(C): """ Get the maximum arity of a list of constraints. @@ -253,6 +260,7 @@ def get_max_arity(C): return max([get_arity(c) for c in C]) return 0 + def get_relation(c, gamma): """ Get the relation index of a constraint in a given language. @@ -279,6 +287,7 @@ def get_relation(c, gamma): return -1 + def replace_variables(constraint, var_mapping): """ Replace the variables in a constraint using a dictionary mapping previous variables to new ones. @@ -316,6 +325,7 @@ def get_var_name(var): name = var.name.replace(name[0], '') return name + def get_var_ndims(var): """ Get the number of dimensions of a variable. @@ -328,6 +338,7 @@ def get_var_ndims(var): ndims = len(re.split(",", dims_str)) return ndims + def get_var_dims(var): """ Get the dimensions of a variable. @@ -341,6 +352,7 @@ def get_var_dims(var): dims = [int(dim) for dim in re.split(",", dims)] return dims + def get_divisors(n): """ Get the divisors of a number. @@ -354,6 +366,7 @@ def get_divisors(n): divisors.append(i) return divisors + def average_difference(values): """ Calculate the average difference between consecutive values in a list. @@ -385,6 +398,7 @@ def compute_sample_weights(Y): return sw + def get_variables_from_constraints(constraints): """ Get the list of variables involved in a list of constraints. @@ -392,27 +406,18 @@ def get_variables_from_constraints(constraints): :param constraints: List of constraints. :return: List of variables involved in the constraints. """ - def get_variables(expr): - if isinstance(expr, _NumVarImpl): - return [expr] - elif isinstance(expr, np.bool_): - return [] - elif isinstance(expr, np.int_) or isinstance(expr, int): - return [] - else: - # Recursively find variables in all arguments of the expression - return [var for argument in expr.args for var in get_variables(argument)] # Create set to hold unique variables variable_set = set() for constraint in constraints: - variable_set.update(get_variables(constraint)) + variable_set.update(get_scope(constraint)) extract_nums = lambda s: list(map(int, s.name[s.name.index("[") + 1:s.name.index("]")].split(','))) variable_list = sorted(variable_set, key=extract_nums) return variable_list + def combine_sets_distinct(set1, set2): """ Combine two sets into a set of distinct pairs. @@ -431,6 +436,7 @@ def combine_sets_distinct(set1, set2): result.add(tuple(sorted((a, b)))) return result + def unravel(lst, newlist): """ Recursively unravel nested lists, tuples, or arrays into a flat list. @@ -447,6 +453,7 @@ def unravel(lst, newlist): elif isinstance(e, (list, tuple, np.flatiter, np.ndarray)): unravel(e, newlist) + def get_combinations(lst, n): """ Get all combinations of a list of a given length. @@ -461,6 +468,7 @@ def get_combinations(lst, n): lst = newlist return list(combinations(lst, n)) + def restore_scope_values(scope, scope_values): """ Restore the original values of variables in a scope.