diff --git a/doc/release_notes.rst b/doc/release_notes.rst index deed60d4..27647be7 100644 --- a/doc/release_notes.rst +++ b/doc/release_notes.rst @@ -4,6 +4,7 @@ Release Notes Upcoming Version ---------------- +* Add documentation about `LinearExpression.where` with `drop=True`. Add `BaseExpression.variable_names` and `BaseExpression.nvar` properties. * Add the `sphinx-copybutton` to the documentation Version 0.6.1 diff --git a/examples/creating-expressions.ipynb b/examples/creating-expressions.ipynb index aafd8a09..4067018b 100644 --- a/examples/creating-expressions.ipynb +++ b/examples/creating-expressions.ipynb @@ -160,7 +160,11 @@ "cell_type": "markdown", "id": "f7578221", "metadata": {}, - "source": ".. important::\n\n\tWhen combining variables or expression with dimensions of the same name and size, the first object will determine the coordinates of the resulting expression. For example:" + "source": [ + ".. important::\n", + "\n", + "\tWhen combining variables or expression with dimensions of the same name and size, the first object will determine the coordinates of the resulting expression. For example:" + ] }, { "cell_type": "code", @@ -308,6 +312,102 @@ "(x + y).where(mask) + xr.DataArray(5, coords=[time]).where(~mask, 0)" ] }, + { + "cell_type": "markdown", + "id": "6741e69e", + "metadata": {}, + "source": [ + "Sometimes `.where` may lead to a situation where some of the variables are completely masked" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "fc32bdca", + "metadata": {}, + "outputs": [], + "source": [ + "mask_a = xr.DataArray(False, coords=[time])\n", + "mask_b = xr.DataArray(time > 2, coords=[time])\n", + "\n", + "z = (x.where(mask_a) + y).where(mask_b)\n", + "z" + ] + }, + { + "cell_type": "markdown", + "id": "25bf798c", + "metadata": {}, + "source": [ + "In this example you can see that many of the elements of the LinearExpression are None. If you want to remove all the None terms, you can use `.where(.., drop=True)`" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "72c6b51b", + "metadata": {}, + "outputs": [], + "source": [ + "z = z.where(mask_b, drop=True)\n", + "z" + ] + }, + { + "cell_type": "markdown", + "id": "1c1e0b85", + "metadata": {}, + "source": [ + "That looks nicer!
" + ] + }, + { + "cell_type": "markdown", + "id": "d8530a08", + "metadata": {}, + "source": [ + "You may notice that the variable `x` is not used at all. The expression still contains two terms (one of them is unused) but it only has one variable `y`" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1c577863", + "metadata": {}, + "outputs": [], + "source": [ + "z.nterm" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "fe43d47d", + "metadata": {}, + "outputs": [], + "source": [ + "z.variable_names" + ] + }, + { + "cell_type": "markdown", + "id": "a76d40b1", + "metadata": {}, + "source": [ + "You can get rid of the unused term with `.simplify()`" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "fc27341c", + "metadata": {}, + "outputs": [], + "source": [ + "z = z.simplify()\n", + "z.nterm" + ] + }, { "attachments": {}, "cell_type": "markdown", diff --git a/linopy/expressions.py b/linopy/expressions.py index 10e243de..a05d9206 100644 --- a/linopy/expressions.py +++ b/linopy/expressions.py @@ -1071,6 +1071,38 @@ def nterm(self) -> int: """ return len(self.data._term) + @property + def nvar(self) -> int: + """ + Get the number of unique variables in the linear expression. + Note that nvar <= nterm, as variables can appear multiple times and there can be terms which are completely masked out. + """ + return len(self.variable_names) + + @property + def variable_names(self) -> set[str]: + """ + The names of the unique variables present in the expression + """ + if self.nterm == 0: + return set() + + # Collect all unique labels from the expression (excluding -1) while preserving order + all_labels = self.vars.values.ravel() + valid_labels = all_labels[all_labels != -1] + + if len(valid_labels) == 0: + return set() + + # Get unique labels while preserving first occurrence order + unique_labels, first_indices = np.unique(valid_labels, return_index=True) + ordered_labels = unique_labels[np.argsort(first_indices)] + + # Batch lookup variable names for all labels + positions = self.model.variables.get_label_position(ordered_labels) + + return {p[0] for p in positions if p[0] is not None} + @property def shape(self) -> tuple[int, ...]: """ diff --git a/test/test_linear_expression.py b/test/test_linear_expression.py index a75ace3f..174a4c38 100644 --- a/test/test_linear_expression.py +++ b/test/test_linear_expression.py @@ -1313,3 +1313,53 @@ def test_simplify_partial_cancellation(x: Variable, y: Variable) -> None: assert all(simplified.coeffs.values == 3.0), ( f"Expected coefficient 3.0, got {simplified.coeffs.values}" ) + + +def test_variable_names() -> None: + m = Model() + time = pd.Index(range(3), name="time") + + a = m.add_variables(name="a", coords=[time]) + b = m.add_variables(name="b", coords=[time]) + + expr = a + b + assert expr.nterm == 2 + assert expr.variable_names == {"a", "b"} + + mask = xr.DataArray(False, coords=[time]) + expr = a + (b * 1).where(mask) + assert expr.nterm == 2 + assert expr.variable_names == {"a"} + + expr = (b * 1).where(mask) + assert expr.nterm == 1 + assert expr.variable_names == set() + + expr = LinearExpression.from_constant(model=m, constant=5) + assert expr.nterm == 0 + assert expr.variable_names == set() + + +def test_nvar_and_nterm() -> None: + m = Model() + time = pd.Index(range(3), name="time") + all_false = xr.DataArray(False, coords=[time]) + not_0 = xr.DataArray([False, True, True], coords=[time]) + not_1 = xr.DataArray([True, False, True], coords=[time]) + not_2 = xr.DataArray([True, True, False], coords=[time]) + + a = m.add_variables(name="a", coords=[time]) + b = m.add_variables(name="b", coords=[time]) + c = m.add_variables(name="c", coords=[time]) + + expr = (a.where(not_0) + b.where(not_1) + c.where(not_2)).densify_terms() + assert expr.nterm == 3 + assert expr.nvar == 3 + + expr = a + b.where(all_false) + assert expr.nterm == 2 + assert expr.nvar == 1 + + expr = expr.simplify() + assert expr.nterm == 1 + assert expr.nvar == 1 diff --git a/test/test_quadratic_expression.py b/test/test_quadratic_expression.py index fc1bb25f..6cf697ba 100644 --- a/test/test_quadratic_expression.py +++ b/test/test_quadratic_expression.py @@ -360,3 +360,11 @@ def test_power_of_three(x: Variable) -> None: x**3 with pytest.raises(TypeError): (x * x) * (x * x) + + +def test_variable_names(x: Variable, y: Variable) -> None: + expr = 2 * (x * x) + 3 * y + 1 + assert expr.variable_names == {"x", "y"} + + expr = 2 * (y * y) + 1 + assert expr.variable_names == {"y"}