Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 33 additions & 25 deletions pycona/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,13 @@
import cpmpy as cp
from cpmpy.expressions.core import Expression, Comparison, Operator
from cpmpy.expressions.variables import NDVarArray, _NumVarImpl, NegBoolView
from cpmpy.transformations.get_variables import get_variables
from sklearn.utils import class_weight
import numpy as np
import re
from cpmpy.expressions.utils import all_pairs, is_any_list


class Objectives:
"""
A class to manage different objectives for query generation, find scope, and find constraint.
Expand Down Expand Up @@ -54,6 +56,7 @@ def check_value(c):
"""
return bool(c.value())


def get_con_subset(B, Y):
"""
Get the subset of constraints whose scope is a subset of Y.
Expand All @@ -77,6 +80,7 @@ def get_kappa(B, Y):
Y = frozenset(Y)
return [c for c in B if frozenset(get_scope(c)).issubset(Y) and check_value(c) is False]


def get_lambda(B, Y):
"""
Get the subset of constraints whose scope is a subset of Y and are satisfied.
Expand All @@ -88,6 +92,7 @@ def get_lambda(B, Y):
Y = frozenset(Y)
return [c for c in B if frozenset(get_scope(c)).issubset(Y) and check_value(c) is True]


def gen_pairwise(v1, v2):
"""
Generate pairwise constraints between two variables.
Expand All @@ -98,6 +103,7 @@ def gen_pairwise(v1, v2):
"""
return [v1 == v2, v1 != v2, v1 < v2, v1 > v2]


def gen_pairwise_ineq(v1, v2):
"""
Generate pairwise inequality constraints between two variables.
Expand All @@ -108,6 +114,7 @@ def gen_pairwise_ineq(v1, v2):
"""
return [v1 != v2]


def alldiff_binary(grid):
"""
Generate all different binary constraints for a grid.
Expand All @@ -119,6 +126,7 @@ def alldiff_binary(grid):
for c in gen_pairwise_ineq(v1, v2):
yield c


def gen_scoped_cons(grid):
"""
Generate scoped constraints for a grid.
Expand All @@ -145,6 +153,7 @@ def gen_scoped_cons(grid):
for c in gen_pairwise_ineq(grid[i1, j1], grid[i2, j2]):
yield c


def gen_all_cons(grid):
"""
Generate all pairwise constraints for a grid.
Expand All @@ -156,6 +165,7 @@ def gen_all_cons(grid):
for c in gen_pairwise(v1, v2):
yield c


def get_scopes_vars(C):
"""
Get the set of variables involved in the scopes of constraints.
Expand All @@ -165,6 +175,7 @@ def get_scopes_vars(C):
"""
return set([x for scope in [get_scope(c) for c in C] for x in scope])


def get_scopes(C):
"""
Get the list of unique scopes of constraints.
Expand All @@ -174,25 +185,17 @@ def get_scopes(C):
"""
return list(set([tuple(get_scope(c)) for c in C]))


def get_scope(constraint):
"""
Get the scope (variables) of a constraint.

:param constraint: The constraint to get the scope of.
:return: List of variables in the scope of the constraint.
"""
if isinstance(constraint, _NumVarImpl):
return [constraint]
elif isinstance(constraint, Expression):
all_variables = []
for argument in constraint.args:
if isinstance(argument, _NumVarImpl):
all_variables.append(argument)
else:
all_variables.extend(get_scope(argument))
return all_variables
else:
return []
return get_variables(constraint)


def compare_scopes(scope1, scope2):

scope1 = set(scope1)
Expand All @@ -211,17 +214,19 @@ def get_constant(constraint):
:param constraint: The constraint to get the constants of.
:return: List of constants involved in the constraint.
"""

if isinstance(constraint, _NumVarImpl):
return []
elif isinstance(constraint, Expression):
elif isinstance(constraint, Expression) or is_any_list(constraint):
constants = []
for argument in constraint.args:
for argument in (constraint.args if isinstance(constraint, Expression) else constraint):
if not isinstance(argument, _NumVarImpl):
constants.extend(get_constant(argument))
return constants
else:
return [constraint]


def get_arity(constraint):
"""
Get the arity (number of variables) of a constraint.
Expand All @@ -231,6 +236,7 @@ def get_arity(constraint):
"""
return len(get_scope(constraint))


def get_min_arity(C):
"""
Get the minimum arity of a list of constraints.
Expand All @@ -242,6 +248,7 @@ def get_min_arity(C):
return min([get_arity(c) for c in C])
return 0


def get_max_arity(C):
"""
Get the maximum arity of a list of constraints.
Expand All @@ -253,6 +260,7 @@ def get_max_arity(C):
return max([get_arity(c) for c in C])
return 0


def get_relation(c, gamma):
"""
Get the relation index of a constraint in a given language.
Expand All @@ -279,6 +287,7 @@ def get_relation(c, gamma):

return -1


def replace_variables(constraint, var_mapping):
"""
Replace the variables in a constraint using a dictionary mapping previous variables to new ones.
Expand Down Expand Up @@ -316,6 +325,7 @@ def get_var_name(var):
name = var.name.replace(name[0], '')
return name


def get_var_ndims(var):
"""
Get the number of dimensions of a variable.
Expand All @@ -328,6 +338,7 @@ def get_var_ndims(var):
ndims = len(re.split(",", dims_str))
return ndims


def get_var_dims(var):
"""
Get the dimensions of a variable.
Expand All @@ -341,6 +352,7 @@ def get_var_dims(var):
dims = [int(dim) for dim in re.split(",", dims)]
return dims


def get_divisors(n):
"""
Get the divisors of a number.
Expand All @@ -354,6 +366,7 @@ def get_divisors(n):
divisors.append(i)
return divisors


def average_difference(values):
"""
Calculate the average difference between consecutive values in a list.
Expand Down Expand Up @@ -385,34 +398,26 @@ def compute_sample_weights(Y):

return sw


def get_variables_from_constraints(constraints):
"""
Get the list of variables involved in a list of constraints.

:param constraints: List of constraints.
:return: List of variables involved in the constraints.
"""
def get_variables(expr):
if isinstance(expr, _NumVarImpl):
return [expr]
elif isinstance(expr, np.bool_):
return []
elif isinstance(expr, np.int_) or isinstance(expr, int):
return []
else:
# Recursively find variables in all arguments of the expression
return [var for argument in expr.args for var in get_variables(argument)]

# Create set to hold unique variables
variable_set = set()
for constraint in constraints:
variable_set.update(get_variables(constraint))
variable_set.update(get_scope(constraint))

extract_nums = lambda s: list(map(int, s.name[s.name.index("[") + 1:s.name.index("]")].split(',')))

variable_list = sorted(variable_set, key=extract_nums)
return variable_list


def combine_sets_distinct(set1, set2):
"""
Combine two sets into a set of distinct pairs.
Expand All @@ -431,6 +436,7 @@ def combine_sets_distinct(set1, set2):
result.add(tuple(sorted((a, b))))
return result


def unravel(lst, newlist):
"""
Recursively unravel nested lists, tuples, or arrays into a flat list.
Expand All @@ -447,6 +453,7 @@ def unravel(lst, newlist):
elif isinstance(e, (list, tuple, np.flatiter, np.ndarray)):
unravel(e, newlist)


def get_combinations(lst, n):
"""
Get all combinations of a list of a given length.
Expand All @@ -461,6 +468,7 @@ def get_combinations(lst, n):
lst = newlist
return list(combinations(lst, n))


def restore_scope_values(scope, scope_values):
"""
Restore the original values of variables in a scope.
Expand Down
Loading