Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions linopy/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,12 @@

TERM_DIM = "_term"
STACKED_TERM_DIM = "_stacked_term"

# Piecewise linear constraint constants
PWL_LAMBDA_SUFFIX = "_lambda"
PWL_CONVEX_SUFFIX = "_convex"
PWL_LINK_SUFFIX = "_link"
DEFAULT_BREAKPOINT_DIM = "breakpoint"
GROUPED_TERM_DIM = "_grouped_term"
GROUP_DIM = "_group"
FACTOR_DIM = "_factor"
Expand Down
274 changes: 274 additions & 0 deletions linopy/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,13 @@
to_path,
)
from linopy.constants import (
DEFAULT_BREAKPOINT_DIM,
GREATER_EQUAL,
HELPER_DIMS,
LESS_EQUAL,
PWL_CONVEX_SUFFIX,
PWL_LAMBDA_SUFFIX,
PWL_LINK_SUFFIX,
TERM_DIM,
ModelStatus,
TerminationCondition,
Expand Down Expand Up @@ -130,6 +134,7 @@ class Model:
"_cCounter",
"_varnameCounter",
"_connameCounter",
"_pwlCounter",
"_blocks",
# TODO: check if these should not be mutable
"_chunk",
Expand Down Expand Up @@ -180,6 +185,7 @@ def __init__(
self._cCounter: int = 0
self._varnameCounter: int = 0
self._connameCounter: int = 0
self._pwlCounter: int = 0
self._blocks: DataArray | None = None

self._chunk: T_Chunks = chunk
Expand Down Expand Up @@ -591,6 +597,274 @@ def add_sos_constraints(

variable.attrs.update(sos_type=sos_type, sos_dim=sos_dim)

def add_piecewise_constraints(
self,
expr: Variable | LinearExpression | dict[str, Variable | LinearExpression],
breakpoints: DataArray,
link_dim: str | None = None,
dim: str = DEFAULT_BREAKPOINT_DIM,
mask: DataArray | None = None,
name: str | None = None,
skip_nan_check: bool = False,
) -> Constraint:
"""
Add a piecewise linear constraint using SOS2 formulation.

This method creates a piecewise linear constraint that links one or more
variables/expressions together via a set of breakpoints. It uses the SOS2
(Special Ordered Set of type 2) formulation with lambda (interpolation)
variables.

The SOS2 formulation ensures that at most two adjacent lambda variables
can be non-zero, effectively selecting a segment of the piecewise linear
function.

Parameters
----------
expr : Variable, LinearExpression, or dict of these
The variable(s) or expression(s) to be linked by the piecewise constraint.
- If a single Variable/LinearExpression is passed, the breakpoints
directly specify the piecewise points for that expression.
- If a dict is passed, the keys must match coordinates in `link_dim`
of the breakpoints, allowing multiple expressions to be linked.
breakpoints : xr.DataArray
The breakpoint values defining the piecewise linear function.
Must have `dim` as one of its dimensions. If `expr` is a dict,
must also have `link_dim` dimension with coordinates matching the
dict keys.
link_dim : str, optional
The dimension in breakpoints that links to different expressions.
Required when `expr` is a dict. If None and `expr` is a dict,
will attempt to auto-detect from breakpoints dimensions.
dim : str, default "breakpoint"
The dimension in breakpoints that represents the breakpoint index.
This dimension's coordinates must be numeric (used as SOS2 weights).
mask : xr.DataArray, optional
Boolean mask indicating which piecewise constraints are valid.
If None, auto-detected from NaN values in breakpoints (unless
skip_nan_check is True).
name : str, optional
Base name for the generated variables and constraints.
If None, auto-generates names like "pwl0", "pwl1", etc.
skip_nan_check : bool, default False
If True, skip automatic NaN detection in breakpoints. Use this
when you know breakpoints contain no NaN values for better performance.

Returns
-------
Constraint
The convexity constraint (sum of lambda = 1). Lambda variables
and other constraints can be accessed via:
- `model.variables[f"{name}_lambda"]`
- `model.constraints[f"{name}_convex"]`
- `model.constraints[f"{name}_link"]`

Raises
------
ValueError
If expr is not a Variable, LinearExpression, or dict of these.
If breakpoints doesn't have the required dim dimension.
If link_dim cannot be auto-detected when expr is a dict.
If link_dim coordinates don't match dict keys.
If dim coordinates are not numeric.

Examples
--------
Single variable piecewise constraint:

>>> m = Model()
>>> x = m.add_variables(name="x")
>>> breakpoints = xr.DataArray([0, 10, 50, 100], dims=["bp"])
>>> _ = m.add_piecewise_constraints(x, breakpoints, dim="bp")

Using an expression:

>>> m = Model()
>>> x = m.add_variables(name="x")
>>> y = m.add_variables(name="y")
>>> breakpoints = xr.DataArray([0, 10, 50, 100], dims=["bp"])
>>> _ = m.add_piecewise_constraints(x + y, breakpoints, dim="bp")

Multiple linked variables (e.g., power-efficiency curve):

>>> m = Model()
>>> generators = ["gen1", "gen2"]
>>> power = m.add_variables(coords=[generators], name="power")
>>> efficiency = m.add_variables(coords=[generators], name="efficiency")
>>> breakpoints = xr.DataArray(
... [[0, 50, 100], [0.8, 0.95, 0.9]],
... coords={"var": ["power", "efficiency"], "bp": [0, 1, 2]},
... )
>>> _ = m.add_piecewise_constraints(
... {"power": power, "efficiency": efficiency},
... breakpoints,
... link_dim="var",
... dim="bp",
... )

Notes
-----
The piecewise linear constraint is formulated using SOS2 variables:

1. Lambda variables λ_i with bounds [0, 1] are created for each breakpoint
2. SOS2 constraint ensures at most two adjacent λ_i can be non-zero
3. Convexity constraint: Σ λ_i = 1
4. Linking constraints: expr = Σ λ_i × breakpoint_i (for each expression)
"""
# --- Input validation ---
if dim not in breakpoints.dims:
raise ValueError(
f"breakpoints must have dimension '{dim}', "
f"but only has dimensions {list(breakpoints.dims)}"
)

if not pd.api.types.is_numeric_dtype(breakpoints.coords[dim]):
raise ValueError(
f"Breakpoint dimension '{dim}' must have numeric coordinates "
f"for SOS2 weights, but got {breakpoints.coords[dim].dtype}"
)

# --- Generate names using counter ---
if name is None:
name = f"pwl{self._pwlCounter}"
self._pwlCounter += 1

lambda_name = f"{name}{PWL_LAMBDA_SUFFIX}"
convex_name = f"{name}{PWL_CONVEX_SUFFIX}"
link_name = f"{name}{PWL_LINK_SUFFIX}"

# --- Determine lambda coordinates, mask, and target expression ---
is_single = isinstance(expr, Variable | LinearExpression)
is_dict = isinstance(expr, dict)

if not is_single and not is_dict:
raise ValueError(
f"'expr' must be a Variable, LinearExpression, or dict of these, "
f"got {type(expr)}"
)

if is_single:
# Single expression case
assert isinstance(expr, Variable | LinearExpression)
target_expr = self._to_linexpr(expr)
# Build lambda coordinates from breakpoints dimensions
lambda_coords = [
pd.Index(breakpoints.coords[d].values, name=d) for d in breakpoints.dims
]
lambda_mask = self._compute_pwl_mask(mask, breakpoints, skip_nan_check)

else:
# Dict case - need to validate link_dim and build stacked expression
assert isinstance(expr, dict)
expr_dict: dict[str, Variable | LinearExpression] = expr
expr_keys = set(expr_dict.keys())

# Auto-detect or validate link_dim
link_dim = self._resolve_pwl_link_dim(link_dim, breakpoints, dim, expr_keys)

# Build lambda coordinates (exclude link_dim)
lambda_coords = [
pd.Index(breakpoints.coords[d].values, name=d)
for d in breakpoints.dims
if d != link_dim
]

# Compute mask
base_mask = self._compute_pwl_mask(mask, breakpoints, skip_nan_check)
lambda_mask = base_mask.any(dim=link_dim) if base_mask is not None else None

# Build stacked expression from dict
target_expr = self._build_stacked_expr(expr_dict, breakpoints, link_dim)

# --- Common: Create lambda, SOS2, convexity, and linking constraints ---
lambda_var = self.add_variables(
lower=0, upper=1, coords=lambda_coords, name=lambda_name, mask=lambda_mask
)

self.add_sos_constraints(lambda_var, sos_type=2, sos_dim=dim)

convex_con = self.add_constraints(
lambda_var.sum(dim=dim) == 1, name=convex_name
)

weighted_sum = (lambda_var * breakpoints).sum(dim=dim)
self.add_constraints(target_expr == weighted_sum, name=link_name)

return convex_con

def _to_linexpr(self, expr: Variable | LinearExpression) -> LinearExpression:
"""Convert Variable or LinearExpression to LinearExpression."""
if isinstance(expr, LinearExpression):
return expr
return expr.to_linexpr()

def _compute_pwl_mask(
self,
mask: DataArray | None,
breakpoints: DataArray,
skip_nan_check: bool,
) -> DataArray | None:
"""Compute mask for piecewise constraint, optionally skipping NaN check."""
if mask is not None:
return mask
if skip_nan_check:
return None
return ~breakpoints.isnull()

def _resolve_pwl_link_dim(
self,
link_dim: str | None,
breakpoints: DataArray,
dim: str,
expr_keys: set[str],
) -> str:
"""Auto-detect or validate link_dim for dict case."""
if link_dim is None:
for d in breakpoints.dims:
if d == dim:
continue
coords_set = set(str(c) for c in breakpoints.coords[d].values)
if coords_set == expr_keys:
return str(d)
raise ValueError(
"Could not auto-detect link_dim. Please specify it explicitly. "
f"Breakpoint dimensions: {list(breakpoints.dims)}, "
f"expression keys: {list(expr_keys)}"
)

if link_dim not in breakpoints.dims:
raise ValueError(
f"link_dim '{link_dim}' not found in breakpoints dimensions "
f"{list(breakpoints.dims)}"
)
coords_set = set(str(c) for c in breakpoints.coords[link_dim].values)
if coords_set != expr_keys:
raise ValueError(
f"link_dim '{link_dim}' coordinates {coords_set} "
f"don't match expression keys {expr_keys}"
)
return link_dim

def _build_stacked_expr(
self,
expr_dict: dict[str, Variable | LinearExpression],
breakpoints: DataArray,
link_dim: str,
) -> LinearExpression:
"""Build a stacked LinearExpression from a dict of Variables/Expressions."""
link_coords = list(breakpoints.coords[link_dim].values)

# Collect expression data and stack
expr_data_list = []
for k in link_coords:
e = expr_dict[str(k)]
linexpr = self._to_linexpr(e)
expr_data_list.append(linexpr.data.expand_dims({link_dim: [k]}))

# Concatenate along link_dim
stacked_data = xr.concat(expr_data_list, dim=link_dim)
return LinearExpression(stacked_data, self)

def add_constraints(
self,
lhs: VariableLike
Expand Down
Loading