diff --git a/doc/release_notes.rst b/doc/release_notes.rst
index ba6b3bd5..9c2aa99e 100644
--- a/doc/release_notes.rst
+++ b/doc/release_notes.rst
@@ -4,6 +4,8 @@ Release Notes
 Upcoming Version
 ----------------
 
+* When writing out an LP file, large variables and constraints are now chunked to avoid memory issues. This is especially useful for large models whose constraints have many terms. The chunk size can be set with the `slice_size` argument of the `solve` function.
+
 Version 0.3.15
 --------------
 
diff --git a/linopy/common.py b/linopy/common.py
index c9532a62..4a80114c 100644
--- a/linopy/common.py
+++ b/linopy/common.py
@@ -5,12 +5,14 @@
 This module contains commonly used functions.
 """
 
+from __future__ import annotations
+
 import operator
 import os
-from collections.abc import Hashable, Iterable, Mapping, Sequence
+from collections.abc import Generator, Hashable, Iterable, Mapping, Sequence
 from functools import reduce, wraps
 from pathlib import Path
-from typing import Any, Callable, Union, overload
+from typing import TYPE_CHECKING, Any, Callable, overload
 from warnings import warn
 
 import numpy as np
@@ -30,6 +32,11 @@
     sign_replace_dict,
 )
 
+if TYPE_CHECKING:
+    from linopy.constraints import Constraint
+    from linopy.expressions import LinearExpression
+    from linopy.variables import Variable
+
 
 def maybe_replace_sign(sign: str) -> str:
     """
@@ -86,7 +93,7 @@ def format_string_as_variable_name(name: Hashable):
     return str(name).replace(" ", "_").replace("-", "_")
 
 
-def get_from_iterable(lst: Union[str, Iterable[Hashable], None], index: int):
+def get_from_iterable(lst: str | Iterable[Hashable] | None, index: int):
     """
     Returns the element at the specified index of the list, or None if the
     index is out of bounds.
@@ -99,9 +106,9 @@ def get_from_iterable(lst: Union[str, Iterable[Hashable], None], index: int):
 
 
 def pandas_to_dataarray(
-    arr: Union[pd.DataFrame, pd.Series],
-    coords: Union[Sequence[Union[Sequence, pd.Index, DataArray]], Mapping, None] = None,
-    dims: Union[Iterable[Hashable], None] = None,
+    arr: pd.DataFrame | pd.Series,
+    coords: Sequence[Sequence | pd.Index | DataArray] | Mapping | None = None,
+    dims: Iterable[Hashable] | None = None,
     **kwargs,
 ) -> DataArray:
     """
@@ -156,8 +163,8 @@ def pandas_to_dataarray(
 
 def numpy_to_dataarray(
     arr: np.ndarray,
-    coords: Union[Sequence[Union[Sequence, pd.Index, DataArray]], Mapping, None] = None,
-    dims: Union[str, Iterable[Hashable], None] = None,
+    coords: Sequence[Sequence | pd.Index | DataArray] | Mapping | None = None,
+    dims: str | Iterable[Hashable] | None = None,
     **kwargs,
 ) -> DataArray:
     """
@@ -195,8 +202,8 @@ def numpy_to_dataarray(
 
 def as_dataarray(
     arr,
-    coords: Union[Sequence[Union[Sequence, pd.Index, DataArray]], Mapping, None] = None,
-    dims: Union[str, Iterable[Hashable], None] = None,
+    coords: Sequence[Sequence | pd.Index | DataArray] | Mapping | None = None,
+    dims: str | Iterable[Hashable] | None = None,
     **kwargs,
 ) -> DataArray:
     """
@@ -246,7 +253,7 @@ def as_dataarray(
 
 
 # TODO: rename to to_pandas_dataframe
-def to_dataframe(ds: Dataset, mask_func: Union[Callable, None] = None):
+def to_dataframe(ds: Dataset, mask_func: Callable | None = None):
     """
     Convert an xarray Dataset to a pandas DataFrame.
@@ -467,6 +474,68 @@ def fill_missing_coords(ds, fill_helper_dims: bool = False):
 
     return ds
 
 
+def iterate_slices(
+    ds: Dataset | Variable | LinearExpression | Constraint,
+    slice_size: int | None = 10_000,
+    slice_dims: list | None = None,
+) -> Generator[Dataset | Variable | LinearExpression | Constraint, None, None]:
+    """
+    Generate slices of an xarray Dataset or a linopy object with a soft maximum size.
+
+    The slicing is performed along the largest dimension that is eligible for
+    slicing. If `slice_size` is None, or the total size of the object does not
+    exceed it, the function yields the original object unchanged.
+
+    Parameters
+    ----------
+    ds : xarray.Dataset, Variable, LinearExpression or Constraint
+        The input object to be sliced.
+    slice_size : int, optional
+        The soft maximum number of elements in each slice. If None, the object
+        is yielded as a single slice. The default is 10_000.
+    slice_dims : list, optional
+        The dimensions to slice along. If None, the dimensions in `coord_dims`
+        are used if `coord_dims` is an attribute of the input object; otherwise
+        all dimensions are used. An empty list also falls back to all dimensions.
+
+    Yields
+    ------
+    xarray.Dataset, Variable, LinearExpression or Constraint
+        A slice of the input object.
+
+    """
+    if slice_dims is None:
+        slice_dims = list(getattr(ds, "coord_dims", ds.dims))
+    if not slice_dims:
+        slice_dims = list(ds.dims)
+
+    # Calculate the total number of elements in the dataset
+    size = np.prod([ds.sizes[dim] for dim in ds.dims], dtype=int)
+
+    if slice_size is None or size <= slice_size:
+        yield ds
+        return
+
+    # number of slices
+    n_slices = max(size // slice_size, 1)
+
+    # leading dimension (the largest dimension eligible for slicing)
+    leading_dim = max(slice_dims, key=lambda dim: ds.sizes[dim])
+    size_of_leading_dim = ds.sizes[leading_dim]
+
+    if size_of_leading_dim < n_slices:
+        n_slices = size_of_leading_dim
+
+    chunk_size = size_of_leading_dim // n_slices
+
+    # yield contiguous chunks along the leading dimension; the last chunk is
+    # extended to the end of the dimension so that no elements are dropped
+    for i in range(n_slices):
+        start = i * chunk_size
+        end = start + chunk_size if i < n_slices - 1 else size_of_leading_dim
+        yield ds.isel({leading_dim: slice(start, end)})
+
+
 def _remap(array, mapping):
     return mapping[array.ravel()].reshape(array.shape)
 
 
@@ -484,7 +553,7 @@ def replace_by_map(ds, mapping):
     )
 
 
-def to_path(path: Union[str, Path, None]) -> Union[Path, None]:
+def to_path(path: str | Path | None) -> Path | None:
     """
     Convert a string to a Path object.
     """
@@ -526,7 +595,7 @@ def generate_indices_for_printout(dim_sizes, max_lines):
             yield tuple(np.unravel_index(i, dim_sizes))
 
 
-def align_lines_by_delimiter(lines: list[str], delimiter: Union[str, list[str]]):
+def align_lines_by_delimiter(lines: list[str], delimiter: str | list[str]):
     # Determine the maximum position of the delimiter
     if isinstance(delimiter, str):
         delimiter = [delimiter]
@@ -548,17 +617,18 @@ def align_lines_by_delimiter(lines: list[str], delimiter: Union[str, list[str]])
 
 
 def get_label_position(
-    obj, values: Union[int, np.ndarray]
-) -> Union[
-    Union[tuple[str, dict], tuple[None, None]],
-    list[Union[tuple[str, dict], tuple[None, None]]],
-    list[list[Union[tuple[str, dict], tuple[None, None]]]],
-]:
+    obj, values: int | np.ndarray
+) -> (
+    tuple[str, dict]
+    | tuple[None, None]
+    | list[tuple[str, dict] | tuple[None, None]]
+    | list[list[tuple[str, dict] | tuple[None, None]]]
+):
     """
     Get tuple of name and coordinate for variable labels.
""" - def find_single(value: int) -> Union[tuple[str, dict], tuple[None, None]]: + def find_single(value: int) -> tuple[str, dict] | tuple[None, None]: if value == -1: return None, None for name, val in obj.items(): diff --git a/linopy/constraints.py b/linopy/constraints.py index fabe74d2..458fa419 100644 --- a/linopy/constraints.py +++ b/linopy/constraints.py @@ -40,6 +40,7 @@ has_optimized_model, infer_schema_polars, is_constant, + iterate_slices, maybe_replace_signs, print_coord, print_single_constraint, @@ -658,6 +659,8 @@ def to_polars(self): stack = conwrap(Dataset.stack) + iterate_slices = iterate_slices + @dataclass(repr=False) class Constraints: diff --git a/linopy/expressions.py b/linopy/expressions.py index 49a7240d..08af65a8 100644 --- a/linopy/expressions.py +++ b/linopy/expressions.py @@ -52,6 +52,7 @@ get_index_map, group_terms_polars, has_optimized_model, + iterate_slices, print_single_expression, to_dataframe, to_polars, @@ -1457,6 +1458,8 @@ def to_polars(self) -> pl.DataFrame: stack = exprwrap(Dataset.stack) + iterate_slices = iterate_slices + class QuadraticExpression(LinearExpression): """ diff --git a/linopy/io.py b/linopy/io.py index 93120a97..509a8644 100644 --- a/linopy/io.py +++ b/linopy/io.py @@ -127,7 +127,11 @@ def objective_to_file( def constraints_to_file( - m: Model, f: TextIOWrapper, log: bool = False, batch_size: int = 50000 + m: Model, + f: TextIOWrapper, + log: bool = False, + batch_size: int = 50_000, + slice_size: int = 100_000, ) -> None: if not len(m.constraints): return @@ -143,54 +147,60 @@ def constraints_to_file( batch = [] for name in names: - df = m.constraints[name].flat - - labels = df.labels.values - vars = df.vars.values - coeffs = df.coeffs.values - rhs = df.rhs.values - sign = df.sign.values - - len_df = len(df) # compute length once - if not len_df: - continue - - # write out the start to enable a fast loop afterwards - idx = 0 - label = labels[idx] - coeff = coeffs[idx] - var = vars[idx] - batch.append(f"c{label}:\n{coeff:+.12g} x{var}\n") - prev_label = label - prev_sign = sign[idx] - prev_rhs = rhs[idx] - - for idx in range(1, len_df): + con = m.constraints[name] + for con_slice in con.iterate_slices(slice_size): + df = con_slice.flat + + labels = df.labels.values + vars = df.vars.values + coeffs = df.coeffs.values + rhs = df.rhs.values + sign = df.sign.values + + len_df = len(df) # compute length once + if not len_df: + continue + + # write out the start to enable a fast loop afterwards + idx = 0 label = labels[idx] coeff = coeffs[idx] var = vars[idx] + batch.append(f"c{label}:\n{coeff:+.12g} x{var}\n") + prev_label = label + prev_sign = sign[idx] + prev_rhs = rhs[idx] - if label != prev_label: - batch.append( - f"{prev_sign} {prev_rhs:+.12g}\n\nc{label}:\n{coeff:+.12g} x{var}\n" - ) - prev_sign = sign[idx] - prev_rhs = rhs[idx] - else: - batch.append(f"{coeff:+.12g} x{var}\n") + for idx in range(1, len_df): + label = labels[idx] + coeff = coeffs[idx] + var = vars[idx] - batch = handle_batch(batch, f, batch_size) + if label != prev_label: + batch.append( + f"{prev_sign} {prev_rhs:+.12g}\n\nc{label}:\n{coeff:+.12g} x{var}\n" + ) + prev_sign = sign[idx] + prev_rhs = rhs[idx] + else: + batch.append(f"{coeff:+.12g} x{var}\n") - prev_label = label + batch = handle_batch(batch, f, batch_size) + + prev_label = label - batch.append(f"{prev_sign} {prev_rhs:+.12g}\n") + batch.append(f"{prev_sign} {prev_rhs:+.12g}\n") if batch: # write the remaining lines f.writelines(batch) def bounds_to_file( - m: Model, f: TextIOWrapper, log: bool = False, 
batch_size: int = 10000 + m: Model, + f: TextIOWrapper, + log: bool = False, + batch_size: int = 10000, + slice_size: int = 100_000, ) -> None: """ Write out variables of a model to a lp file. @@ -209,25 +219,31 @@ def bounds_to_file( batch = [] # to store batch of lines for name in names: - df = m.variables[name].flat + var = m.variables[name] + for var_slice in var.iterate_slices(slice_size): + df = var_slice.flat - labels = df.labels.values - lowers = df.lower.values - uppers = df.upper.values + labels = df.labels.values + lowers = df.lower.values + uppers = df.upper.values - for idx in range(len(df)): - label = labels[idx] - lower = lowers[idx] - upper = uppers[idx] - batch.append(f"{lower:+.12g} <= x{label} <= {upper:+.12g}\n") - batch = handle_batch(batch, f, batch_size) + for idx in range(len(df)): + label = labels[idx] + lower = lowers[idx] + upper = uppers[idx] + batch.append(f"{lower:+.12g} <= x{label} <= {upper:+.12g}\n") + batch = handle_batch(batch, f, batch_size) if batch: # write the remaining lines f.writelines(batch) def binaries_to_file( - m: Model, f: TextIOWrapper, log: bool = False, batch_size: int = 1000 + m: Model, + f: TextIOWrapper, + log: bool = False, + batch_size: int = 1000, + slice_size: int = 100_000, ) -> None: """ Write out binaries of a model to a lp file. @@ -246,11 +262,13 @@ def binaries_to_file( batch = [] # to store batch of lines for name in names: - df = m.variables[name].flat + var = m.variables[name] + for var_slice in var.iterate_slices(slice_size): + df = var_slice.flat - for label in df.labels.values: - batch.append(f"x{label}\n") - batch = handle_batch(batch, f, batch_size) + for label in df.labels.values: + batch.append(f"x{label}\n") + batch = handle_batch(batch, f, batch_size) if batch: # write the remaining lines f.writelines(batch) @@ -261,6 +279,7 @@ def integers_to_file( f: TextIOWrapper, log: bool = False, batch_size: int = 1000, + slice_size: int = 100_000, integer_label: str = "general", ) -> None: """ @@ -280,17 +299,19 @@ def integers_to_file( batch = [] # to store batch of lines for name in names: - df = m.variables[name].flat + var = m.variables[name] + for var_slice in var.iterate_slices(slice_size): + df = var_slice.flat - for label in df.labels.values: - batch.append(f"x{label}\n") - batch = handle_batch(batch, f, batch_size) + for label in df.labels.values: + batch.append(f"x{label}\n") + batch = handle_batch(batch, f, batch_size) if batch: # write the remaining lines f.writelines(batch) -def to_lp_file(m, fn, integer_label): +def to_lp_file(m: Model, fn: Path, integer_label: str, slice_size: int = 10_000_000): log = m._xCounter > 10_000 batch_size = 5000 @@ -302,11 +323,18 @@ def to_lp_file(m, fn, integer_label): raise ValueError("File not found.") objective_to_file(m, f, log=log) - constraints_to_file(m, f=f, log=log, batch_size=batch_size) - bounds_to_file(m, f=f, log=log, batch_size=batch_size) - binaries_to_file(m, f=f, log=log, batch_size=batch_size) + constraints_to_file( + m, f=f, log=log, batch_size=batch_size, slice_size=slice_size + ) + bounds_to_file(m, f=f, log=log, batch_size=batch_size, slice_size=slice_size) + binaries_to_file(m, f=f, log=log, batch_size=batch_size, slice_size=slice_size) integers_to_file( - m, integer_label=integer_label, f=f, log=log, batch_size=batch_size + m, + integer_label=integer_label, + f=f, + log=log, + batch_size=batch_size, + slice_size=slice_size, ) f.write("end\n") @@ -371,7 +399,7 @@ def objective_to_file_polars(m, f, log=False): objective_write_quadratic_terms_polars(f, quads) 
-def bounds_to_file_polars(m, f, log=False): +def bounds_to_file_polars(m, f, log=False, slice_size=2_000_000): """ Write out variables of a model to a lp file. """ @@ -388,26 +416,28 @@ def bounds_to_file_polars(m, f, log=False): ) for name in names: - df = m.variables[name].to_polars() - - columns = [ - pl.when(pl.col("lower") >= 0).then(pl.lit("+")).otherwise(pl.lit("")), - pl.col("lower").cast(pl.String), - pl.lit(" <= x"), - pl.col("labels").cast(pl.String), - pl.lit(" <= "), - pl.when(pl.col("upper") >= 0).then(pl.lit("+")).otherwise(pl.lit("")), - pl.col("upper").cast(pl.String), - ] - - kwargs = dict( - separator=" ", null_value="", quote_style="never", include_header=False - ) - formatted = df.select(pl.concat_str(columns, ignore_nulls=True)) - formatted.write_csv(f, **kwargs) + var = m.variables[name] + for var_slice in var.iterate_slices(slice_size): + df = var_slice.to_polars() + + columns = [ + pl.when(pl.col("lower") >= 0).then(pl.lit("+")).otherwise(pl.lit("")), + pl.col("lower").cast(pl.String), + pl.lit(" <= x"), + pl.col("labels").cast(pl.String), + pl.lit(" <= "), + pl.when(pl.col("upper") >= 0).then(pl.lit("+")).otherwise(pl.lit("")), + pl.col("upper").cast(pl.String), + ] + + kwargs = dict( + separator=" ", null_value="", quote_style="never", include_header=False + ) + formatted = df.select(pl.concat_str(columns, ignore_nulls=True)) + formatted.write_csv(f, **kwargs) -def binaries_to_file_polars(m, f, log=False): +def binaries_to_file_polars(m, f, log=False, slice_size=2_000_000): """ Write out binaries of a model to a lp file. """ @@ -424,21 +454,25 @@ def binaries_to_file_polars(m, f, log=False): ) for name in names: - df = m.variables[name].to_polars() + var = m.variables[name] + for var_slice in var.iterate_slices(slice_size): + df = var_slice.to_polars() - columns = [ - pl.lit("x"), - pl.col("labels").cast(pl.String), - ] + columns = [ + pl.lit("x"), + pl.col("labels").cast(pl.String), + ] - kwargs = dict( - separator=" ", null_value="", quote_style="never", include_header=False - ) - formatted = df.select(pl.concat_str(columns, ignore_nulls=True)) - formatted.write_csv(f, **kwargs) + kwargs = dict( + separator=" ", null_value="", quote_style="never", include_header=False + ) + formatted = df.select(pl.concat_str(columns, ignore_nulls=True)) + formatted.write_csv(f, **kwargs) -def integers_to_file_polars(m, f, log=False, integer_label="general"): +def integers_to_file_polars( + m, f, log=False, integer_label="general", slice_size=2_000_000 +): """ Write out integers of a model to a lp file. 
""" @@ -455,21 +489,23 @@ def integers_to_file_polars(m, f, log=False, integer_label="general"): ) for name in names: - df = m.variables[name].to_polars() + var = m.variables[name] + for var_slice in var.iterate_slices(slice_size): + df = var_slice.to_polars() - columns = [ - pl.lit("x"), - pl.col("labels").cast(pl.String), - ] + columns = [ + pl.lit("x"), + pl.col("labels").cast(pl.String), + ] - kwargs = dict( - separator=" ", null_value="", quote_style="never", include_header=False - ) - formatted = df.select(pl.concat_str(columns, ignore_nulls=True)) - formatted.write_csv(f, **kwargs) + kwargs = dict( + separator=" ", null_value="", quote_style="never", include_header=False + ) + formatted = df.select(pl.concat_str(columns, ignore_nulls=True)) + formatted.write_csv(f, **kwargs) -def constraints_to_file_polars(m, f, log=False, lazy=False): +def constraints_to_file_polars(m, f, log=False, lazy=False, slice_size=2_000_000): if not len(m.constraints): return @@ -485,53 +521,57 @@ def constraints_to_file_polars(m, f, log=False, lazy=False): # to make this even faster, we can use polars expression # https://docs.pola.rs/user-guide/expressions/plugins/#output-data-types for name in names: - df = m.constraints[name].to_polars() - - # df = df.lazy() - # filter out repeated label values - df = df.with_columns( - pl.when(pl.col("labels").is_first_distinct()) - .then(pl.col("labels")) - .otherwise(pl.lit(None)) - .alias("labels") - ) - - columns = [ - pl.when(pl.col("labels").is_not_null()).then(pl.lit("c")).alias("c"), - pl.col("labels").cast(pl.String), - pl.when(pl.col("labels").is_not_null()).then(pl.lit(":\n")).alias(":"), - pl.when(pl.col("coeffs") >= 0).then(pl.lit("+")), - pl.col("coeffs").cast(pl.String), - pl.when(pl.col("vars").is_not_null()).then(pl.lit(" x")).alias("x"), - pl.col("vars").cast(pl.String), - "sign", - pl.lit(" "), - pl.col("rhs").cast(pl.String), - ] + con = m.constraints[name] + for con_slice in con.iterate_slices(slice_size): + df = con_slice.to_polars() + + # df = df.lazy() + # filter out repeated label values + df = df.with_columns( + pl.when(pl.col("labels").is_first_distinct()) + .then(pl.col("labels")) + .otherwise(pl.lit(None)) + .alias("labels") + ) - kwargs = dict( - separator=" ", null_value="", quote_style="never", include_header=False - ) - formatted = df.select(pl.concat_str(columns, ignore_nulls=True)) - formatted.write_csv(f, **kwargs) + columns = [ + pl.when(pl.col("labels").is_not_null()).then(pl.lit("c")).alias("c"), + pl.col("labels").cast(pl.String), + pl.when(pl.col("labels").is_not_null()).then(pl.lit(":\n")).alias(":"), + pl.when(pl.col("coeffs") >= 0).then(pl.lit("+")), + pl.col("coeffs").cast(pl.String), + pl.when(pl.col("vars").is_not_null()).then(pl.lit(" x")).alias("x"), + pl.col("vars").cast(pl.String), + "sign", + pl.lit(" "), + pl.col("rhs").cast(pl.String), + ] + + kwargs = dict( + separator=" ", null_value="", quote_style="never", include_header=False + ) + formatted = df.select(pl.concat_str(columns, ignore_nulls=True)) + formatted.write_csv(f, **kwargs) - # in the future, we could use lazy dataframes when they support appending - # tp existent files - # formatted = df.lazy().select(pl.concat_str(columns, ignore_nulls=True)) - # formatted.sink_csv(f, **kwargs) + # in the future, we could use lazy dataframes when they support appending + # tp existent files + # formatted = df.lazy().select(pl.concat_str(columns, ignore_nulls=True)) + # formatted.sink_csv(f, **kwargs) -def to_lp_file_polars(m, fn, integer_label="general"): +def 
to_lp_file_polars(m, fn, integer_label="general", slice_size=2_000_000):
     log = m._xCounter > 10_000
 
     with open(fn, mode="wb") as f:
         start = time.time()
 
         objective_to_file_polars(m, f, log=log)
-        constraints_to_file_polars(m, f=f, log=log)
-        bounds_to_file_polars(m, f=f, log=log)
-        binaries_to_file_polars(m, f=f, log=log)
-        integers_to_file_polars(m, integer_label=integer_label, f=f, log=log)
+        constraints_to_file_polars(m, f=f, log=log, slice_size=slice_size)
+        bounds_to_file_polars(m, f=f, log=log, slice_size=slice_size)
+        binaries_to_file_polars(m, f=f, log=log, slice_size=slice_size)
+        integers_to_file_polars(
+            m, integer_label=integer_label, f=f, log=log, slice_size=slice_size
+        )
         f.write(b"end\n")
 
         logger.info(f" Writing time: {round(time.time()-start, 2)}s")
@@ -539,15 +579,18 @@ def to_lp_file_polars(m, fn, integer_label="general"):
 
 def to_file(
     m: Model,
-    fn: Path | None,
+    fn: Path | str | None,
     io_api: str | None = None,
     integer_label: str = "general",
+    slice_size: int = 2_000_000,
 ) -> Path:
     """
     Write out a model to a lp or mps file.
     """
     if fn is None:
         fn = Path(m.get_problem_file())
+    if isinstance(fn, str):
+        fn = Path(fn)
 
     if fn.exists():
         fn.unlink()
@@ -555,9 +598,9 @@ def to_file(
         io_api = fn.suffix[1:]
 
     if io_api == "lp":
-        to_lp_file(m, fn, integer_label)
+        to_lp_file(m, fn, integer_label, slice_size=slice_size)
     elif io_api == "lp-polars":
-        to_lp_file_polars(m, fn, integer_label)
+        to_lp_file_polars(m, fn, integer_label, slice_size=slice_size)
 
     elif io_api == "mps":
         if "highs" not in solvers.available_solvers:
diff --git a/linopy/model.py b/linopy/model.py
index 728c61bb..d1f34015 100644
--- a/linopy/model.py
+++ b/linopy/model.py
@@ -953,6 +953,7 @@ def solve(
         keep_files: bool = False,
         env: None = None,
         sanitize_zeros: bool = True,
+        slice_size: int = 2_000_000,
         remote: None = None,
         **solver_options,
     ) -> tuple[str, str]:
@@ -1002,6 +1003,10 @@ def solve(
             Whether to set terms with zero coefficient as missing.
             This will remove unneeded overhead in the lp file writing.
             The default is True.
+        slice_size : int, optional
+            Soft maximum number of elements per chunk when writing the LP
+            file. Large variables and constraints are split into chunks of
+            this size to avoid memory issues. The default is 2_000_000.
         remote : linopy.remote.RemoteHandler
             Remote handler to use for solving model on a server.
             Note that when solving on a remote server, ...
diff --git a/linopy/variables.py b/linopy/variables.py
index fe0a214d..e6584141 100644
--- a/linopy/variables.py
+++ b/linopy/variables.py
@@ -42,6 +42,7 @@
     get_label_position,
     has_optimized_model,
     is_constant,
+    iterate_slices,
     print_coord,
     print_single_variable,
     save_join,
@@ -1075,6 +1076,8 @@ def equals(self, other: Variable) -> bool:
 
     stack = varwrap(Dataset.stack)
 
+    iterate_slices = iterate_slices
+
 
 class AtIndexer:
     __slots__ = ("object",)
diff --git a/test/test_common.py b/test/test_common.py
index c5937c06..90f5df89 100644
--- a/test/test_common.py
+++ b/test/test_common.py
@@ -11,7 +11,12 @@
 import xarray as xr
 from xarray import DataArray
 
-from linopy.common import as_dataarray, assign_multiindex_safe, best_int
+from linopy.common import (
+    as_dataarray,
+    assign_multiindex_safe,
+    best_int,
+    iterate_slices,
+)
 
 
 def test_as_dataarray_with_series_dims_default():
@@ -453,3 +458,74 @@ def test_assign_multiindex_safe():
     assert "value" in result
     assert result["humidity"].equals(data)
     assert result["pressure"].equals(data)
+
+
+def test_iterate_slices_basic():
+    ds = xr.Dataset(
+        {"var": (("x", "y"), np.random.rand(10, 10))},  # noqa: NPY002
+        coords={"x": np.arange(10), "y": np.arange(10)},
+    )
+    slices = list(iterate_slices(ds, slice_size=20))
+    assert len(slices) == 5
+    for s in slices:
+        assert isinstance(s, xr.Dataset)
+        assert set(s.dims) == set(ds.dims)
+
+
+def test_iterate_slices_with_slice_dims():
+    ds = xr.Dataset(
+        {"var": (("x", "y"), np.random.rand(10, 10))},  # noqa: NPY002
+        coords={"x": np.arange(10), "y": np.arange(10)},
+    )
+    slices = list(iterate_slices(ds, slice_size=20, slice_dims=["x"]))
+    assert len(slices) == 5
+    for s in slices:
+        assert isinstance(s, xr.Dataset)
+        assert set(s.dims) == set(ds.dims)
+
+
+def test_iterate_slices_large_max_size():
+    ds = xr.Dataset(
+        {"var": (("x", "y"), np.random.rand(10, 10))},  # noqa: NPY002
+        coords={"x": np.arange(10), "y": np.arange(10)},
+    )
+    slices = list(iterate_slices(ds, slice_size=200))
+    assert len(slices) == 1
+    for s in slices:
+        assert isinstance(s, xr.Dataset)
+        assert set(s.dims) == set(ds.dims)
+
+
+def test_iterate_slices_small_max_size():
+    ds = xr.Dataset(
+        {"var": (("x", "y"), np.random.rand(10, 10))},  # noqa: NPY002
+        coords={"x": np.arange(10), "y": np.arange(10)},
+    )
+    slices = list(iterate_slices(ds, slice_size=8, slice_dims=[]))
+    assert len(slices) == 10
+    for s in slices:
+        assert isinstance(s, xr.Dataset)
+        assert set(s.dims) == set(ds.dims)
+
+
+def test_iterate_slices_slice_size_none():
+    ds = xr.Dataset(
+        {"var": (("x", "y"), np.random.rand(10, 10))},  # noqa: NPY002
+        coords={"x": np.arange(10), "y": np.arange(10)},
+    )
+    slices = list(iterate_slices(ds, slice_size=None))
+    assert len(slices) == 1
+    for s in slices:
+        assert ds.equals(s)
+
+
+def test_iterate_slices_no_slice_dims():
+    ds = xr.Dataset(
+        {"var": (("x", "y"), np.random.rand(10, 10))},  # noqa: NPY002
+        coords={"x": np.arange(10), "y": np.arange(10)},
+    )
+    slices = list(iterate_slices(ds, slice_size=50, slice_dims=[]))
+    assert len(slices) == 2
+    for s in slices:
+        assert isinstance(s, xr.Dataset)
+        assert set(s.dims) == set(ds.dims)
diff --git a/test/test_constraint.py b/test/test_constraint.py
index 307d4453..0d0080fa 100644
--- a/test/test_constraint.py
+++ b/test/test_constraint.py
@@ -380,6 +380,12 @@ def test_constraint_flat(c):
     assert isinstance(c.flat, pd.DataFrame)
 
 
+def test_iterate_slices(c):
+    for i in c.iterate_slices(slice_size=2):
+        assert isinstance(i, Constraint)
+        assert
c.coord_dims == i.coord_dims + + def test_constraint_to_polars(c): assert isinstance(c.to_polars(), pl.DataFrame) diff --git a/test/test_linear_expression.py b/test/test_linear_expression.py index 05b9b280..277e6443 100644 --- a/test/test_linear_expression.py +++ b/test/test_linear_expression.py @@ -526,6 +526,14 @@ def test_linear_expression_flat(v): assert (df.coeffs == coeff).all() +def test_iterate_slices(x, y): + expr = x + 10 * y + for s in expr.iterate_slices(slice_size=2): + assert isinstance(s, LinearExpression) + assert s.nterm == expr.nterm + assert s.coord_dims == expr.coord_dims + + def test_linear_expression_to_polars(v): coeff = np.arange(1, 21) # use non-zero coefficients expr = coeff * v diff --git a/test/test_optimization.py b/test/test_optimization.py index 2a76efb0..c5e040c8 100644 --- a/test/test_optimization.py +++ b/test/test_optimization.py @@ -400,6 +400,15 @@ def test_default_settings_chunked(model_chunked, solver, io_api): assert np.isclose(model_chunked.objective.value, 3.3) +@pytest.mark.parametrize("solver,io_api", params) +def test_default_settings_small_slices(model, solver, io_api): + assert model.objective.sense == "min" + status, condition = model.solve(solver, io_api=io_api, slice_size=2) + assert status == "ok" + assert np.isclose(model.objective.value, 3.3) + assert model.solver_name == solver + + @pytest.mark.parametrize("solver,io_api", params) def test_solver_options(model, solver, io_api): time_limit_option = { diff --git a/test/test_variable.py b/test/test_variable.py index bf718558..321bc3d5 100644 --- a/test/test_variable.py +++ b/test/test_variable.py @@ -286,3 +286,10 @@ def test_variable_sanitize(x): x = x.sanitize() assert isinstance(x, linopy.variables.Variable) assert x.labels[9] == -1 + + +def test_variable_iterate_slices(x): + slices = x.iterate_slices(slice_size=2) + for s in slices: + assert isinstance(s, linopy.variables.Variable) + assert s.size <= 2
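
A minimal usage sketch of the new `slice_size` argument. The toy model below is illustrative only (any installed solver can stand in for `highs`); the `solve` call is what exercises the chunked LP writing from the diff:

```python
import pandas as pd

import linopy

m = linopy.Model()
time = pd.RangeIndex(10_000, name="time")

x = m.add_variables(lower=0, coords=[time], name="x")
y = m.add_variables(lower=0, coords=[time], name="y")

m.add_constraints(x + 2 * y >= 10, name="demand")
m.add_objective((x + y).sum())

# Variables and constraints are flattened and written to the LP file in
# chunks of roughly 100 elements each instead of all at once.
m.solve(solver_name="highs", io_api="lp", slice_size=100)
```

The same argument should also apply when writing the problem file without solving, e.g. `m.to_file("problem.lp", slice_size=100)`, assuming `Model.to_file` forwards its keyword arguments to the updated `to_file` shown above.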
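
And a sketch of the slicing semantics themselves, mirroring the new tests in `test/test_common.py`. The commented output assumes the chunking logic in `iterate_slices` above, where slicing happens along the largest eligible dimension and the final slice absorbs the remainder:

```python
import numpy as np
import xarray as xr

from linopy.common import iterate_slices

ds = xr.Dataset({"v": (("x", "y"), np.zeros((10, 10)))})

# 100 elements against a soft cap of 30 -> 100 // 30 = 3 slices,
# taken along "x", the largest dimension eligible for slicing.
for s in iterate_slices(ds, slice_size=30):
    print(dict(s.sizes))
# {'x': 3, 'y': 10}
# {'x': 3, 'y': 10}
# {'x': 4, 'y': 10}

# slice_size=None disables chunking and yields the object unchanged.
assert len(list(iterate_slices(ds, slice_size=None))) == 1
```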