Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion pycona/active_algorithms/gquacq.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,8 @@ def learn(self, instance: ProblemInstance, oracle: Oracle = UserOracle(), verbos
"""
if X is None:
X = instance.X
assert isinstance(X, list) and set(X).issubset(set(instance.X)), "When using .learn(), set parameter X must be a list of variables"
assert isinstance(X, list), "When using .learn(), set parameter X must be a list of variables. Instead got: {}".format(X)
assert set(X).issubset(set(instance.X)), "When using .learn(), set parameter X must be a subset of the problem instance variables. Instead got: {}".format(X)

self.env.init_state(instance, oracle, verbose, metrics)

Expand Down
3 changes: 2 additions & 1 deletion pycona/active_algorithms/growacq.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,8 @@ def learn(self, instance: ProblemInstance, oracle: Oracle = UserOracle(), verbos
"""
if X is None:
X = instance.X
assert isinstance(X, list) and set(X).issubset(set(instance.X)), "When using .learn(), set parameter X must be a list of variables"
assert isinstance(X, list), "When using .learn(), set parameter X must be a list of variables. Instead got: {}".format(X)
assert set(X).issubset(set(instance.X)), "When using .learn(), set parameter X must be a subset of the problem instance variables. Instead got: {}".format(X)

self.env.init_state(instance, oracle, verbose, metrics)

Expand Down
3 changes: 2 additions & 1 deletion pycona/active_algorithms/mquacq.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,8 @@ def learn(self, instance: ProblemInstance, oracle: Oracle = UserOracle(), verbos
"""
if X is None:
X = instance.X
assert isinstance(X, list) and set(X).issubset(set(instance.X)), "When using .learn(), set parameter X must be a list of variables"
assert isinstance(X, list), "When using .learn(), set parameter X must be a list of variables. Instead got: {}".format(X)
assert set(X).issubset(set(instance.X)), "When using .learn(), set parameter X must be a subset of the problem instance variables. Instead got: {}".format(X)

self.env.init_state(instance, oracle, verbose, metrics)

Expand Down
3 changes: 2 additions & 1 deletion pycona/active_algorithms/mquacq2.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,8 @@ def learn(self, instance: ProblemInstance, oracle: Oracle = UserOracle(), verbos
"""
if X is None:
X = instance.X
assert isinstance(X, list) and set(X).issubset(set(instance.X)), "When using .learn(), set parameter X must be a list of variables"
assert isinstance(X, list), "When using .learn(), set parameter X must be a list of variables. Instead got: {}".format(X)
assert set(X).issubset(set(instance.X)), "When using .learn(), set parameter X must be a subset of the problem instance variables. Instead got: {}".format(X)

self.env.init_state(instance, oracle, verbose, metrics)

Expand Down
3 changes: 2 additions & 1 deletion pycona/active_algorithms/pquacq.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,8 @@ def learn(self, instance: ProblemInstance, oracle: Oracle = UserOracle(), verbos
"""
if X is None:
X = instance.X
assert isinstance(X, list) and set(X).issubset(set(instance.X)), "When using .learn(), set parameter X must be a list of variables"
assert isinstance(X, list), "When using .learn(), set parameter X must be a list of variables. Instead got: {}".format(X)
assert set(X).issubset(set(instance.X)), "When using .learn(), set parameter X must be a subset of the problem instance variables. Instead got: {}".format(X)

self.env.init_state(instance, oracle, verbose, metrics)

Expand Down
3 changes: 2 additions & 1 deletion pycona/active_algorithms/quacq.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,8 @@ def learn(self, instance: ProblemInstance, oracle: Oracle = UserOracle(), verbos
"""
if X is None:
X = instance.X
assert isinstance(X, list) and set(X).issubset(set(instance.X)), "When using .learn(), set parameter X must be a list of variables"
assert isinstance(X, list), "When using .learn(), set parameter X must be a list of variables. Instead got: {}".format(X)
assert set(X).issubset(set(instance.X)), "When using .learn(), set parameter X must be a subset of the problem instance variables. Instead got: {}".format(X)

self.env.init_state(instance, oracle, verbose, metrics)

Expand Down
6 changes: 5 additions & 1 deletion pycona/problem_instance/problem_instance.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,11 @@ def variables(self, vars):
"""
self._variables = vars
if vars is not None:
self.X = list(self._variables.flatten())
if isinstance(vars, NDVarArray):
self.X = list(self._variables.flatten())
else:
self.X = vars
self._variables = cp.cpm_array(vars)

@property
def X(self):
Expand Down
18 changes: 13 additions & 5 deletions pycona/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -322,8 +322,9 @@ def get_var_name(var):
:return: The name of the variable without its indices.
"""
name = re.findall(r"\[\d+[,\d+]*\]", var.name)
name = var.name.replace(name[0], '')
return name
if name: # Check if we found any indices
return var.name.replace(name[0], '')
return var.name # Return original name if no indices found


def get_var_ndims(var):
Expand All @@ -344,9 +345,11 @@ def get_var_dims(var):
Get the dimensions of a variable.

:param var: The variable.
:return: The dimensions of the variable.
:return: The dimensions of the variable. Returns empty list if variable has no indices.
"""
dims = re.findall(r"\[\d+[,\d+]*\]", var.name)
if not dims: # If no indices found
return []
dims_str = "".join(dims)
dims = re.split(r"[\[\]]", dims_str)[1]
dims = [int(dim) for dim in re.split(",", dims)]
Expand Down Expand Up @@ -406,13 +409,18 @@ def get_variables_from_constraints(constraints):
:param constraints: List of constraints.
:return: List of variables involved in the constraints.
"""

# Create set to hold unique variables
variable_set = set()
for constraint in constraints:
variable_set.update(get_variables(constraint))

extract_nums = lambda s: list(map(int, s.name[s.name.index("[") + 1:s.name.index("]")].split(',')))
def extract_nums(s):
dims = re.findall(r"\[\d+[,\d+]*\]", s.name)
if not dims:
return [0] # Default value for variables without indices
dims_str = "".join(dims)
dims = re.split(r"[\[\]]", dims_str)[1]
return [int(dim) for dim in re.split(",", dims)]

variable_list = sorted(variable_set, key=extract_nums)
return variable_list
Expand Down