Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add flow constraints #34

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions motile/constraints/__init__.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,19 @@
from .constraint import Constraint
from .expression import ExpressionConstraint
from .in_out_symmetry import InOutSymmetry
from .max_children import MaxChildren
from .max_parents import MaxParents
from .min_track_length import MinTrackLength
from .pin import Pin
from .select_edge_nodes import SelectEdgeNodes

__all__ = [
"Constraint",
"ExpressionConstraint",
"InOutSymmetry",
"MaxChildren",
"MaxParents",
"MinTrackLength",
"Pin",
"SelectEdgeNodes",
]
50 changes: 50 additions & 0 deletions motile/constraints/in_out_symmetry.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
from __future__ import annotations

from typing import TYPE_CHECKING

import ilpy

from motile.constraints import Constraint
from motile.variables import EdgeSelected

if TYPE_CHECKING:
from motile.solver import Solver


class InOutSymmetry(Constraint):
r"""Ensures that all nodes, apart from the ones in the first and last
frame, have as many incoming edges as outgoing edges.

Adds the following linear constraint for nodes :math:`v` not in first or
last frame:

.. math::
\sum_{e \in \\text{in_edges}(v)} x_e = \sum{e \in \\text{out_edges}(v)} x_e
"""

def instantiate(self, solver: Solver) -> list[ilpy.Constraint]:
edge_indicators = solver.get_variables(EdgeSelected)
start, end = solver.graph.get_frames()

constraints = []
for node, attrs in solver.graph.nodes.items():
constraint = ilpy.Constraint()

if solver.graph.frame_attribute in attrs and attrs[
solver.graph.frame_attribute
] not in (
start,
end - 1, # type: ignore
):
for prev_edge in solver.graph.prev_edges[node]:
ind_e = edge_indicators[prev_edge]
constraint.set_coefficient(ind_e, 1)
for next_edge in solver.graph.next_edges[node]:
ind_e = edge_indicators[next_edge]
constraint.set_coefficient(ind_e, -1)
constraint.set_relation(ilpy.Relation.Equal)
constraint.set_value(0)

constraints.append(constraint)

return constraints
42 changes: 42 additions & 0 deletions motile/constraints/min_track_length.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
from __future__ import annotations

from typing import TYPE_CHECKING

import ilpy

from motile.constraints import Constraint
from motile.variables import EdgeSelected, NodeAppear

if TYPE_CHECKING:
pass


class MinTrackLength(Constraint):
r"""Ensures that each appearing track consists of at least ``min_edges``
edges.

Currently only supports ``min_edges = 1``.

Args:

min_edges: The minimum number of edges per track.
"""

def __init__(self, min_edges: int) -> None:
if min_edges != 1:
raise NotImplementedError(
"Can only enforce minimum track length of 1 edge."
)
self.min_edges = min_edges

def instantiate(self, solver):
appear_indicators = solver.get_variables(NodeAppear)
edge_indicators = solver.get_variables(EdgeSelected)
for node in solver.graph.nodes:
constraint = ilpy.Constraint()
constraint.set_coefficient(appear_indicators[node], 1)
for edge in solver.graph.next_edges[node]:
constraint.set_coefficient(edge_indicators[edge], -1)
constraint.set_relation(ilpy.Relation.LessEqual)
constraint.set_value(0)
yield constraint
67 changes: 65 additions & 2 deletions tests/test_constraints.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,15 @@
import motile
import pytest
from motile.constraints import ExpressionConstraint, MaxChildren, MaxParents, Pin
from motile.constraints import (
ExpressionConstraint,
InOutSymmetry,
MaxChildren,
MaxParents,
MinTrackLength,
Pin,
)
from motile.costs import EdgeSelection, NodeSelection
from motile.data import arlo_graph
from motile.data import arlo_graph, toy_graph
from motile.variables import EdgeSelected
from motile.variables.node_selected import NodeSelected

Expand Down Expand Up @@ -65,3 +72,59 @@ def test_max_parents(solver: motile.Solver) -> None:
assert _selected_edges(solver) != expect, "test invalid"
solver.add_constraints(MaxParents(1))
assert _selected_edges(solver) == expect


def test_in_out_symmetry():
graph = toy_graph()
solver = motile.Solver(graph)

solver.add_costs(NodeSelection(weight=-1.0, attribute="score"))
solver.add_costs(EdgeSelection(weight=-1.0, attribute="score"))
solver.add_constraints(MaxParents(1))
solver.add_constraints(MaxChildren(1))
solver.add_constraints(InOutSymmetry())

solution = solver.solve()

node_indicators = solver.get_variables(NodeSelected)
selected_nodes = [
node for node, index in node_indicators.items() if solution[index] > 0.5
]
edge_indicators = solver.get_variables(EdgeSelected)
selected_edges = [
edge for edge, index in edge_indicators.items() if solution[index] > 0.5
]
for node in solver.graph.nodes:
assert node in selected_nodes

assert (1, 3) in selected_edges
assert (3, 6) in selected_edges
assert (0, 2) in selected_edges
assert (2, 4) in selected_edges


def test_min_track_length():
graph = toy_graph()
solver = motile.Solver(graph)

solver.add_costs(NodeSelection(weight=-1.0, attribute="score"))
solver.add_costs(EdgeSelection(weight=-1.0, attribute="score"))
solver.add_constraints(MaxParents(1))
solver.add_constraints(MaxChildren(1))
solver.add_constraints(MinTrackLength(1))

solution = solver.solve()

node_indicators = solver.get_variables(NodeSelected)
selected_nodes = [
node for node, index in node_indicators.items() if solution[index] > 0.5
]
for node in solver.graph.nodes:
if node == 5:
assert node not in selected_nodes
else:
assert node in selected_nodes


if __name__ == "__main__":
test_in_out_symmetry()