From 5ab4d6241a33fdb7189faae35e508f5411e93742 Mon Sep 17 00:00:00 2001 From: Alyssa Travitz Date: Thu, 18 Sep 2025 11:44:00 -0700 Subject: [PATCH 1/6] move to konnektor --- environment.yml | 3 +- .../test_atommapping_network_plotting.py | 2 +- openfe/tests/utils/test_network_plotting.py | 2 +- openfe/utils/__init__.py | 1 - openfe/utils/atommapping_network_plotting.py | 174 ------- openfe/utils/custom_typing.py | 19 - openfe/utils/network_plotting.py | 427 ------------------ openfecli/commands/view_ligand_network.py | 2 +- .../commands/test_ligand_network_viewer.py | 2 +- 9 files changed, 6 insertions(+), 626 deletions(-) delete mode 100644 openfe/utils/atommapping_network_plotting.py delete mode 100644 openfe/utils/custom_typing.py delete mode 100644 openfe/utils/network_plotting.py diff --git a/environment.yml b/environment.yml index acd85e538..6d41d5fd4 100644 --- a/environment.yml +++ b/environment.yml @@ -7,7 +7,7 @@ dependencies: - coverage - duecredit<0.10 - kartograf>=1.2.0 - - konnektor~=0.2.0 +# - konnektor~=0.2.0 - lomap2>=3.2.1 - networkx - numpy @@ -52,3 +52,4 @@ dependencies: - threadpoolctl - pip: - git+https://github.com/OpenFreeEnergy/gufe@main + - git+https://github.com/OpenFreeEnergy/konnektor@move_network_plotting_from_openfe diff --git a/openfe/tests/utils/test_atommapping_network_plotting.py b/openfe/tests/utils/test_atommapping_network_plotting.py index 8049c545d..a001527fa 100644 --- a/openfe/tests/utils/test_atommapping_network_plotting.py +++ b/openfe/tests/utils/test_atommapping_network_plotting.py @@ -6,7 +6,7 @@ import matplotlib.figure import importlib.resources -from openfe.utils.atommapping_network_plotting import ( +from konnektor.visualization.atommapping_network_plotting import ( AtomMappingNetworkDrawing, plot_atommapping_network, LigandNode, ) diff --git a/openfe/tests/utils/test_network_plotting.py b/openfe/tests/utils/test_network_plotting.py index c99ff1e73..9c798eb8a 100644 --- a/openfe/tests/utils/test_network_plotting.py +++ b/openfe/tests/utils/test_network_plotting.py @@ -4,7 +4,7 @@ from matplotlib import pyplot as plt import networkx as nx -from openfe.utils.network_plotting import ( +from konnektor.visualization.network_plotting import ( Node, Edge, EventHandler, GraphDrawing ) diff --git a/openfe/utils/__init__.py b/openfe/utils/__init__.py index 0d276a97e..fc489005d 100644 --- a/openfe/utils/__init__.py +++ b/openfe/utils/__init__.py @@ -1,7 +1,6 @@ # This code is part of OpenFE and is licensed under the MIT license. # For details, see https://github.com/OpenFreeEnergy/openfe -from . import custom_typing from .optional_imports import requires_package from .remove_oechem import without_oechem_backend from .system_probe import log_system_probe diff --git a/openfe/utils/atommapping_network_plotting.py b/openfe/utils/atommapping_network_plotting.py deleted file mode 100644 index 8e119b432..000000000 --- a/openfe/utils/atommapping_network_plotting.py +++ /dev/null @@ -1,174 +0,0 @@ -# This code is part of OpenFE and is licensed under the MIT license. -# For details, see https://github.com/OpenFreeEnergy/openfe -import io -import matplotlib -from rdkit import Chem -from typing import Dict, Tuple - -from openfe.utils.network_plotting import GraphDrawing, Node, Edge -from gufe.visualization.mapping_visualization import ( - draw_one_molecule_mapping, -) -from openfe.utils.custom_typing import MPL_MouseEvent -from openfe import SmallMoleculeComponent, LigandNetwork - - -class AtomMappingEdge(Edge): - """Edge to draw AtomMapping from a LigandNetwork. - - The ``select`` and ``unselect`` methods are implemented here to force - the mapped molecule to be drawn/disappear. - - Parameters - ---------- - node_artist1, node_artist2 : :class:`.Node` - GraphDrawing nodes for this edge - data : Dict - Data dictionary for this edge. Must have key ``object``, which maps - to an :class:`.AtomMapping`. - """ - def __init__(self, node_artist1: Node, node_artist2: Node, data: Dict): - super().__init__(node_artist1, node_artist2, data) - self.left_image = None - self.right_image = None - - def _draw_mapped_molecule( - self, - extent: Tuple[float, float, float, float], - molA: SmallMoleculeComponent, - molB: SmallMoleculeComponent, - molA_to_molB: Dict[int, int] - ): - # create the image in a format matplotlib can handle - d2d = Chem.Draw.rdMolDraw2D.MolDraw2DCairo(300, 300, 300, 300) - d2d.drawOptions().setBackgroundColour((1, 1, 1, 0.7)) - # TODO: use a custom draw2d object; figure size from transforms - img_bytes = draw_one_molecule_mapping(molA_to_molB, - molA.to_rdkit(), - molB.to_rdkit(), - d2d=d2d) - img_filelike = io.BytesIO(img_bytes) # imread needs filelike - img_data = matplotlib.pyplot.imread(img_filelike) - - ax = self.artist.axes - x0, x1, y0, y1 = extent - - # version A: using AxesImage - im = matplotlib.image.AxesImage(ax, extent=extent, zorder=10) - - # version B: using BboxImage - # keep this commented code around for later performance checks - # bounds = (x0, y0, x1 - x0, y1 - y0) - # bounds = (0.2, 0.2, 0.3, 0.3) - # bbox0 = matplotlib.transforms.Bbox.from_bounds(*bounds) - # bbox = matplotlib.transforms.TransformedBbox(bbox0, ax.transAxes) - # im = matplotlib.image.BboxImage(bbox) - - # set image data and register - im.set_data(img_data) - ax.add_artist(im) - return im - - def _get_image_extents(self): - # figure out the extent for left and right - x0, x1 = self.artist.axes.get_xlim() - dx = x1 - x0 - left_x0, left_x1 = 0.05 * dx + x0, 0.45 * dx + x0 - right_x0, right_x1 = 0.55 * dx + x0, 0.95 * dx + x0 - y0, y1 = self.artist.axes.get_ylim() - dy = y1 - y0 - y_bottom, y_top = 0.5 * dx + y0, 0.9 * dx + y0 - - left_extent = (left_x0, left_x1, y_bottom, y_top) - right_extent = (right_x0, right_x1, y_bottom, y_top) - return left_extent, right_extent - - def select(self, event, graph): - super().select(event, graph) - mapping = self.data['object'] - - # figure out which node is to the left and which to the right - xs = [node.xy[0] for node in self.node_artists] - if xs[0] <= xs[1]: - left = mapping.componentA - right = mapping.componentB - left_to_right = mapping.componentA_to_componentB - right_to_left = mapping.componentB_to_componentA - else: - left = mapping.componentB - right = mapping.componentA - left_to_right = mapping.componentB_to_componentA - right_to_left = mapping.componentA_to_componentB - - left_extent, right_extent = self._get_image_extents() - - self.left_image = self._draw_mapped_molecule(left_extent, - left, - right, - left_to_right) - self.right_image = self._draw_mapped_molecule(right_extent, - right, - left, - right_to_left) - graph.fig.canvas.draw() - - def unselect(self): - super().unselect() - for artist in [self.left_image, self.right_image]: - if artist is not None: - artist.remove() - - self.left_image = None - self.right_image = None - - -class LigandNode(Node): - def _make_artist(self, x, y, dx, dy): - artist = matplotlib.text.Text(x, y, self.node.name, color='blue', - backgroundcolor='white') - return artist - - def register_artist(self, ax): - ax.add_artist(self.artist) - - @property - def extent(self): - txt = self.artist - ext = txt.axes.transData.inverted().transform(txt.get_window_extent()) - [[xmin, ymin], [xmax, ymax]] = ext - return xmin, xmax, ymin, ymax - - @property - def xy(self): - return self.artist.get_position() - - -class AtomMappingNetworkDrawing(GraphDrawing): - """ - Class for drawing atom mappings from a provided ligang network. - - Parameters - ---------- - graph : nx.MultiDiGraph - NetworkX representation of the LigandNetwork - positions : Optional[Dict[SmallMoleculeComponent, Tuple[float, float]]] - mapping of node to position - """ - NodeCls = LigandNode - EdgeCls = AtomMappingEdge - - -def plot_atommapping_network(network: LigandNetwork): - """Convenience method for plotting the atom mapping network - - Parameters - ---------- - network : :class:`.Network` - the network to plot - - Returns - ------- - :class:`matplotlib.figure.Figure` : - the matplotlib figure containing the iteractive visualization - """ - return AtomMappingNetworkDrawing(network.graph).fig diff --git a/openfe/utils/custom_typing.py b/openfe/utils/custom_typing.py deleted file mode 100644 index 32440731a..000000000 --- a/openfe/utils/custom_typing.py +++ /dev/null @@ -1,19 +0,0 @@ -# This code is part of OpenFE and is licensed under the MIT license. -# For details, see https://github.com/OpenFreeEnergy/openfe - -from typing import TypeVar -from rdkit import Chem -import matplotlib.axes -import matplotlib.backend_bases - -try: - from typing import TypeAlias -except ImportError: - from typing_extensions import TypeAlias - -RDKitMol: TypeAlias = Chem.rdchem.Mol - -OEMol = TypeVar('OEMol') -MPL_FigureCanvasBase: TypeAlias = matplotlib.backend_bases.FigureCanvasBase -MPL_MouseEvent: TypeAlias = matplotlib.backend_bases.MouseEvent -MPL_Axes: TypeAlias = matplotlib.axes.Axes diff --git a/openfe/utils/network_plotting.py b/openfe/utils/network_plotting.py deleted file mode 100644 index ad767dc73..000000000 --- a/openfe/utils/network_plotting.py +++ /dev/null @@ -1,427 +0,0 @@ -# This code is part of OpenFE and is licensed under the MIT license. -# For details, see https://github.com/OpenFreeEnergy/openfe - -""" -Generic tools for plotting networks. Interfaces NetworkX and matplotlib. - -Create subclasses of ``Node``, ``Edge``, and ``GraphDrawing`` to customize -behavior how the graph is visualized or what happens on interactive events. -""" - -from __future__ import annotations - -import itertools -import networkx as nx -from matplotlib import pyplot as plt -from matplotlib.patches import Rectangle -from matplotlib.lines import Line2D - -from typing import Optional, Any, Union, cast -from openfe.utils.custom_typing import ( - MPL_MouseEvent, MPL_FigureCanvasBase, MPL_Axes, TypeAlias -) - -ClickLocation: TypeAlias = tuple[tuple[float, float], tuple[Any, Any]] - - -class Node: - """Node in the GraphDrawing network. - - This connects a node in the NetworkX graph to the matplotlib artist. - This is the only object that should directly use the matplotlib artist - for this node. This acts as an adapter class, allowing different artists - to be used, as well as enabling different functionalities. - """ - # TODO: someday it might be good to separate the artist adapter from the - # functionality on select, etc. - draggable = True - pickable = False - lock = None # lock used while dragging; only one Node dragged at a time - - def __init__(self, node, x: float, y: float, dx=0.1, dy=0.1): - self.node = node - self.dx = dx - self.dy = dx - self.artist = self._make_artist(x, y, dx, dy) - self.picked = False - self.press: Optional[ClickLocation] = None - - def _make_artist(self, x, y, dx, dy): - return Rectangle((x, y), dx, dy, color='blue') - - def register_artist(self, ax: MPL_Axes): - """Register this node's artist with the matplotlib Axes""" - ax.add_patch(self.artist) - - @property - def extent(self) -> tuple[float, float, float, float]: - """extent of this node in matplotlib data coordinates""" - bounds = self.artist.get_bbox().bounds - return (bounds[0], bounds[0] + bounds[2], - bounds[1], bounds[1] + bounds[3]) - - @property - def xy(self) -> tuple[float, float]: - """lower left (matplotlib data coordinates) position of this node""" - return self.artist.xy - - def select(self, event: MPL_MouseEvent, graph: GraphDrawing): # -no-cov- - """Set this node to its state when it is selected (clicked on)""" - return - - def unselect(self): - """Reset this node to its standard, unselected visualization""" - self.artist.set(color='blue') - - def edge_select(self, edge: Edge): - """Change node visualization when one of its edges is selected""" - self.artist.set(color='red') - - def update_location(self, x: float, y: float): - """Update the location of the underlying artist""" - self.artist.set(x=x, y=y) - - # note: much the stuff below is based on the "Draggable rectangle" - # exercise at: - # https://matplotlib.org/stable/users/explain/event_handling.html#draggable-rectangle-exercise - def contains(self, event: MPL_MouseEvent) -> bool: - """Report whether this object contains the given event""" - return self.artist.contains(event)[0] - - def on_mousedown(self, event: MPL_MouseEvent, graph: GraphDrawing): - """Handle mousedown event (button_press_event)""" - # these early returns probably won't be called in practice, since - # the event handler should only call this method when those - # conditions are met; still, defensive programming! - if event.inaxes != self.artist.axes: - return - - if not self.contains(event): - return - - # record the original click location; lock that we're the only - # object being dragged - self.press = self.xy, (event.xdata, event.ydata) - Node.lock = self - # TODO: blitting - - def on_drag(self, event: MPL_MouseEvent, graph: GraphDrawing): - """Handle dragging this node""" - if event.inaxes != self.artist.axes or Node.lock is not self: - return - - if self.press: - (x0, y0), (xpress, ypress) = self.press - else: - # this should be impossible in practice, but mypy needed the - # explicit check so it didn't unpack None - raise RuntimeError("Can't drag until mouse down!") - - dx = event.xdata - xpress - dy = event.ydata - ypress - self.update_location(x0 + dx, y0 + dy) - - # TODO: this might be cached on mousedown - edges = graph.edges_for_node(self.node) - for edge in edges: - edge.update_locations() - - # TODO: blitting - self.artist.figure.canvas.draw() - - def on_mouseup(self, event: MPL_MouseEvent, graph: GraphDrawing): - """Handle mouseup event (button_release_event)""" - self.press = None - Node.lock = None - # TODO: blitting - self.artist.figure.canvas.draw() - - -class Edge: - """Edge in the GraphDrawing network. - - This connects an edge in the NetworkX graph to the matplotlib artist. In - addition to the edge data, this needs to know the two GraphDrawing - ``Node`` instances associated with this edge. - - Parameters - ---------- - node_artist1, node_artist2 : :class:`.Node` - GraphDrawing nodes for this edge - data : Dict - data dictionary for this edge - """ - pickable = True - - def __init__(self, node_artist1: Node, node_artist2: Node, data: dict): - self.data = data - self.node_artists = [node_artist1, node_artist2] - self.artist = self._make_artist(node_artist1, node_artist2, data) - self.picked = False - - def _make_artist(self, node_artist1: Node, node_artist2: Node, - data: dict) -> Any: - xs, ys = self._edge_xs_ys(node_artist1, node_artist2) - return Line2D(xs, ys, color='black', picker=True, zorder=-1) - - def register_artist(self, ax: MPL_Axes): - """Register this edge's artist with the matplotlib Axes""" - ax.add_line(self.artist) - - def contains(self, event: MPL_MouseEvent) -> bool: - """Report whether this object contains the given event""" - return self.artist.contains(event)[0] - - @staticmethod - def _edge_xs_ys(node1: Node, node2: Node): - def get_midpoint(node): - x0, x1, y0, y1 = node.extent - return (0.5 * (x0 + x1), 0.5 * (y0 + y1)) - - midpt1 = get_midpoint(node1) - midpt2 = get_midpoint(node2) - - xs, ys = list(zip(*[midpt1, midpt2])) - return xs, ys - - def on_mousedown(self, event: MPL_MouseEvent, graph: GraphDrawing): - """Handle mousedown event (button_press_event)""" - return # -no-cov- - - def on_drag(self, event: MPL_MouseEvent, graph: GraphDrawing): - """Handle drag event""" - return # -no-cov- - - def on_mouseup(self, event: MPL_MouseEvent, graph: GraphDrawing): - """Handle mouseup event (button_release_event)""" - return # -no-cov- - - def unselect(self): - """Reset this edge to its standard, unselected visualization""" - self.artist.set(color='black') - for node_artist in self.node_artists: - node_artist.unselect() - self.picked = False - - def select(self, event: MPL_MouseEvent, graph: GraphDrawing): - """Mark this edge as selected, update visualization""" - self.artist.set(color='red') - for artist in self.node_artists: - artist.edge_select(self) - self.picked = True - return True - - def update_locations(self): - """Update the location of this edge based on node locations""" - xs, ys = self._edge_xs_ys(*self.node_artists) - self.artist.set(xdata=xs, ydata=ys) - - -class EventHandler: - """Pass event information to nodes/edges. - - This is the single place where we connect to the matplotlib event - system. This object receives matplotlib events and delegates to the - appropriate node or edge. - - Parameters - ---------- - graph : GraphDrawing - the graph drawing that we're handling events for - - Attributes - ---------- - active : Optional[Union[Node, Edge]] - Object activated by a mousedown event, or None if either no object - activated by mousedown or if mouse is not currently pressed. This is - primarily used to handle drag events. - selected : Optional[Union[Node, Edge]] - Object selected by a mouse click (after mouse is up), or None if no - object has been selected in the graph. - click_location : Optional[tuple[Optional[float], Optional[float]]] - Cached location of the mousedown event, or None if mouse is up - connections : List[int] - list of IDs for connections to matplotlib canvas - """ - def __init__(self, graph: GraphDrawing): - self.graph = graph - self.active: Optional[Union[Node, Edge]] = None - self.selected: Optional[Union[Node, Edge]] = None - self.click_location: Optional[tuple[Optional[float], Optional[float]]] = None - self.connections: list[int] = [] - - def connect(self, canvas: MPL_FigureCanvasBase): - """Connect our methods to events in the matplotlib canvas""" - self.connections.extend([ - canvas.mpl_connect('button_press_event', self.on_mousedown), # type: ignore - canvas.mpl_connect('motion_notify_event', self.on_drag), # type: ignore - canvas.mpl_connect('button_release_event', self.on_mouseup), # type: ignore - ]) - - def disconnect(self, canvas: MPL_FigureCanvasBase): - """Disconnect all connections to the canvas.""" - for cid in self.connections: - canvas.mpl_disconnect(cid) - self.connections = [] - - def _get_event_container(self, event: MPL_MouseEvent): - """Identify which object should process an event. - - Note that we prefer nodes to edges: If you click somewhere that - could be a node or an edge, it is interpreted as clicking on the - node. - """ - containers = itertools.chain(self.graph.nodes.values(), - self.graph.edges.values()) - for container in containers: - if container.contains(event): - break - else: - container = None - - return container - - def on_mousedown(self, event: MPL_MouseEvent): - """Handle mousedown event (button_press_event)""" - self.click_location = event.xdata, event.ydata - container = self._get_event_container(event) - if container is None: - return - - # cast because mypy can't tell that we did early return if None - self.active = cast(Union[Node, Edge], container) - self.active.on_mousedown(event, self.graph) - - def on_drag(self, event: MPL_MouseEvent): - """Handle dragging""" - if not self.active or event.inaxes != self.active.artist.axes: - return - - self.active.on_drag(event, self.graph) - - def on_mouseup(self, event: MPL_MouseEvent): - """Handle mouseup event (button_release_event)""" - if self.click_location == (event.xdata, event.ydata): - # mouse hasn't moved; call it a click - # first unselect whatever was previously selected - if self.selected: - self.selected.unselect() - - # if it is a click and the active object contains it, select it; - # otherwise unset selection - if self.active and self.active.contains(event): - self.active.select(event, self.graph) - self.selected = self.active - else: - self.selected = None - - if self.active: - self.active.on_mouseup(event, self.graph) - - self.active = None - self.click_location = None - self.graph.draw() - - -class GraphDrawing: - """ - Base class for drawing networks with matplotlib. - - Connects to the matplotlib figure and to the underlying NetworkX graph. - - Typical use will require a subclass with custom values of ``NodeCls`` - and ``EdgeCls`` to handle the specific visualization. - - Parameters - ---------- - graph : nx.MultiDiGraph - NetworkX graph with information in nodes and edges to be drawn - positions : Optional[Dict[Any, Tuple[float, float]]] - mapping of node to position - """ - NodeCls = Node - EdgeCls = Edge - - def __init__(self, graph: nx.Graph, positions=None, ax=None): - # TODO: use scale to scale up the positions? - self.event_handler = EventHandler(self) - self.graph = graph - self.nodes: dict[Node, Any] = {} - self.edges: dict[tuple[Node, Node], Any] = {} - - if positions is None: - positions = nx.nx_agraph.graphviz_layout(self.graph, prog='neato') - - was_interactive = plt.isinteractive() - plt.ioff() - if ax is None: - self.fig, self.ax = plt.subplots(figsize=(8, 8)) - else: - self.fig, self.ax = ax.figure, ax - - for node, pos in positions.items(): - self._register_node(node, pos) - - self.fig.canvas.draw() # required to get renderer - for edge in graph.edges(data=True): - self._register_edge(edge) - - self.reset_bounds() - self.ax.set_aspect(1) - self.ax.set_xticks([]) - self.ax.set_yticks([]) - if was_interactive: - plt.ion() # -no-cov- - - self.event_handler.connect(self.fig.canvas) - - def _ipython_display_(self): # -no-cov- - return self.fig - - def edges_for_node(self, node: Node) -> list[Edge]: - """List of edges for the given node""" - edges = (list(self.graph.in_edges(node)) - + list(self.graph.out_edges(node))) - return [self.edges[edge] for edge in edges] - - def _get_nodes_extent(self): - """Find the extent of all nodes (used in setting bounds)""" - min_xs, max_xs, min_ys, max_ys = zip(*( - node.extent for node in self.nodes.values() - )) - return min(min_xs), max(max_xs), min(min_ys), max(max_ys) - - def reset_bounds(self): - """Set the bounds of the matplotlib Axes to include all nodes""" - # I feel like the following should be a better approach, but it - # doesn't seem to work - # renderer = self.fig.canvas.get_renderer() - # bbox = self.ax.get_tightbbox(renderer) - # trans = self.ax.transData.inverted() - # [[min_x, min_y], [max_x, max_y]] = trans.transform(bbox) - min_x, max_x, min_y, max_y = self._get_nodes_extent() - pad_x = (max_x - min_x) * 0.05 - pad_y = (max_y - min_y) * 0.05 - self.ax.set_xlim(min_x - pad_x, max_x + pad_x) - self.ax.set_ylim(min_y - pad_y, max_y + pad_y) - - def draw(self): - """Draw the current canvas""" - self.fig.canvas.draw() - self.fig.canvas.flush_events() - - def _register_node(self, node: Any, position: tuple[float, float]): - """Create and register ``Node`` from NetworkX node and position""" - if node in self.nodes: - raise RuntimeError("node provided multiple times") - - draw_node = self.NodeCls(node, *position) - self.nodes[node] = draw_node - draw_node.register_artist(self.ax) - - def _register_edge(self, edge: tuple[Node, Node, dict]): - """Create and register ``Edge`` from NetworkX edge information""" - node1, node2, data = edge - draw_edge = self.EdgeCls(self.nodes[node1], self.nodes[node2], data) - self.edges[(node1, node2)] = draw_edge - draw_edge.register_artist(self.ax) diff --git a/openfecli/commands/view_ligand_network.py b/openfecli/commands/view_ligand_network.py index 5e5450e73..8fd9b4150 100644 --- a/openfecli/commands/view_ligand_network.py +++ b/openfecli/commands/view_ligand_network.py @@ -18,7 +18,7 @@ def view_ligand_network(ligand_network: os.PathLike): e.g. ``openfe view-ligand-network network_setup/ligand_network.graphml`` """ - from openfe.utils.atommapping_network_plotting import ( + from konnektor.visualization.atommapping_network_plotting import ( plot_atommapping_network ) from openfe.setup import LigandNetwork diff --git a/openfecli/tests/commands/test_ligand_network_viewer.py b/openfecli/tests/commands/test_ligand_network_viewer.py index 525447122..7c97fb469 100644 --- a/openfecli/tests/commands/test_ligand_network_viewer.py +++ b/openfecli/tests/commands/test_ligand_network_viewer.py @@ -15,7 +15,7 @@ def test_view_ligand_network(): backend = matplotlib.get_backend() matplotlib.use("ps") - loc = "openfe.utils.atommapping_network_plotting.matplotlib.use" + loc = "konnektor.visualization.atommapping_network_plotting.matplotlib.use" with runner.isolated_filesystem(): with mock.patch(loc, mock.Mock()): result = runner.invoke(view_ligand_network, [str(ref)]) From 5c0586e5348110b407a771de91b269321c62ef1c Mon Sep 17 00:00:00 2001 From: Alyssa Travitz Date: Thu, 18 Sep 2025 11:45:51 -0700 Subject: [PATCH 2/6] move network plotting tests to konnektor --- .../test_atommapping_network_plotting.py | 188 ------ openfe/tests/utils/test_network_plotting.py | 560 ------------------ 2 files changed, 748 deletions(-) delete mode 100644 openfe/tests/utils/test_atommapping_network_plotting.py delete mode 100644 openfe/tests/utils/test_network_plotting.py diff --git a/openfe/tests/utils/test_atommapping_network_plotting.py b/openfe/tests/utils/test_atommapping_network_plotting.py deleted file mode 100644 index a001527fa..000000000 --- a/openfe/tests/utils/test_atommapping_network_plotting.py +++ /dev/null @@ -1,188 +0,0 @@ -import inspect -import pytest -from unittest import mock -from matplotlib import pyplot as plt -import matplotlib -import matplotlib.figure -import importlib.resources - -from konnektor.visualization.atommapping_network_plotting import ( - AtomMappingNetworkDrawing, plot_atommapping_network, - LigandNode, -) - -from openfe.tests.utils.test_network_plotting import mock_event - - -def bound_args(func, args, kwargs): - """Return a dictionary mapping parameter name to value. - - Parameters - ---------- - func : Callable - this must be inspectable; mocks will require a spec - args : List - args list - kwargs : Dict - kwargs Dict - - Returns - ------- - Dict[str, Any] : - mapping of string name of function parameter to the value it would - be bound to - """ - sig = inspect.Signature.from_callable(func) - bound = sig.bind(*args, **kwargs) - return bound.arguments - - -@pytest.fixture -def network_drawing(simple_network): - nx_graph = simple_network.network.graph - node_dict = {node.smiles: node for node in nx_graph.nodes} - positions = { - node_dict["CC"]: (0.0, 0.0), - node_dict["CO"]: (0.5, 0.0), - node_dict["CCO"]: (0.25, 0.25) - } - graph = AtomMappingNetworkDrawing(nx_graph, positions) - graph.ax.set_xlim(0, 1) - graph.ax.set_ylim(0, 1) - yield graph - plt.close(graph.fig) - - -@pytest.fixture -def default_edge(network_drawing): - node_dict = {node.smiles: node for node in network_drawing.graph.nodes} - yield network_drawing.edges[node_dict["CC"], node_dict["CO"]] - - -@pytest.fixture -def default_node(network_drawing): - node_dict = {node.smiles: node for node in network_drawing.graph.nodes} - yield LigandNode(node_dict["CC"], 0.5, 0.5, 0.1, 0.1) - - - -class TestAtomMappingEdge: - def test_draw_mapped_molecule(self, default_edge): - assert len(default_edge.artist.axes.images) == 0 - im = default_edge._draw_mapped_molecule( - (0.05, 0.45, 0.5, 0.9), - default_edge.node_artists[0].node, - default_edge.node_artists[1].node, - {0: 0} - ) - # maybe add something about im itself? not sure what to test here - assert len(default_edge.artist.axes.images) == 1 - assert default_edge.artist.axes.images[0] == im - - def test_get_image_extents(self, default_edge): - left_extent, right_extent = default_edge._get_image_extents() - assert left_extent == (0.05, 0.45, 0.5, 0.9) - assert right_extent == (0.55, 0.95, 0.5, 0.9) - - def test_select(self, default_edge, network_drawing): - assert not default_edge.picked - assert len(default_edge.artist.axes.images) == 0 - - event = mock_event('mouseup', 0.25, 0.0, network_drawing.fig) - default_edge.select(event, network_drawing) - - assert default_edge.picked - assert len(default_edge.artist.axes.images) == 2 - - @pytest.mark.parametrize('edge_str,left_right,molA_to_molB', [ - (("CCO", "CC"), ("CC", "CCO"), {0: 0, 1: 1}), - (("CC", "CO"), ("CC", "CO"), {0: 0}), - (("CCO", "CO"), ("CCO", "CO"), {0: 0, 2: 1}), - ]) - def test_select_mock_drawing(self, edge_str, left_right, molA_to_molB, - network_drawing): - # this tests that we call _draw_mapped_molecule with the correct - # kwargs -- in particular, it ensures that we get the left and right - # molecules correctly - node_dict = {node.smiles: node - for node in network_drawing.graph.nodes} - edge_tuple = tuple(node_dict[node] for node in edge_str) - edge = network_drawing.edges[edge_tuple] - left, right = [network_drawing.nodes[node_dict[node]] - for node in left_right] - # ensure that we have them labelled correctly - assert left.xy[0] < right.xy[0] - func = edge._draw_mapped_molecule # save for bound_args - edge._draw_mapped_molecule = mock.Mock() - - event = mock_event('mouseup', 0.25, 0.0, network_drawing.fig) - edge.select(event, network_drawing) - - arg_dicts = [ - bound_args(func, call.args, call.kwargs) - for call in edge._draw_mapped_molecule.mock_calls - ] - expected_left = { - 'extent': (0.05, 0.45, 0.5, 0.9), - 'molA': left.node, - 'molB': right.node, - 'molA_to_molB': molA_to_molB, - } - expected_right = { - 'extent': (0.55, 0.95, 0.5, 0.9), - 'molA': right.node, - 'molB': left.node, - 'molA_to_molB': {v: k for k, v in molA_to_molB.items()}, - } - assert len(arg_dicts) == 2 - assert expected_left in arg_dicts - assert expected_right in arg_dicts - - def test_unselect(self, default_edge, network_drawing): - # start by selecting; hard to be sure we mocked all the side effects - # of select - event = mock_event('mouseup', 0.25, 0.0, network_drawing.fig) - default_edge.select(event, network_drawing) - assert default_edge.picked - assert len(default_edge.artist.axes.images) == 2 - assert default_edge.right_image is not None - assert default_edge.left_image is not None - - default_edge.unselect() - - assert not default_edge.picked - assert len(default_edge.artist.axes.images) == 0 - assert default_edge.right_image is None - assert default_edge.left_image is None - - -class TestLigandNode: - def setup_method(self): - self.fig, self.ax = plt.subplots() - - def teardown_method(self): - plt.close(self.fig) - - def test_register_artist(self, default_node): - assert len(self.ax.texts) == 0 - default_node.register_artist(self.ax) - assert len(self.ax.texts) == 1 - assert self.ax.texts[0] == default_node.artist - - def test_extent(self, default_node): - default_node.register_artist(self.ax) - xmin, xmax, ymin, ymax = default_node.extent - assert xmin == pytest.approx(0.5) - assert ymin == pytest.approx(0.5) - # can't do anything about upper bounds - - def test_xy(self, default_node): - # default_node.register_artist(self.ax) - x, y = default_node.xy - assert x == pytest.approx(0.5) - assert y == pytest.approx(0.5) - - -def test_plot_atommapping_network(simple_network): - fig = plot_atommapping_network(simple_network.network) - assert isinstance(fig, matplotlib.figure.Figure) diff --git a/openfe/tests/utils/test_network_plotting.py b/openfe/tests/utils/test_network_plotting.py deleted file mode 100644 index 9c798eb8a..000000000 --- a/openfe/tests/utils/test_network_plotting.py +++ /dev/null @@ -1,560 +0,0 @@ -import pytest -from unittest import mock -from numpy import testing as npt - -from matplotlib import pyplot as plt -import networkx as nx -from konnektor.visualization.network_plotting import ( - Node, Edge, EventHandler, GraphDrawing -) - -from matplotlib.backend_bases import MouseEvent, MouseButton - - -def _get_fig_ax(fig): - if fig is None: - fig, _ = plt.subplots() - - if len(fig.axes) != 1: # -no-cov- - raise RuntimeError("Error in test setup: figure must have exactly " - "one Axes object associated") - - return fig, fig.axes[0] - - -def mock_event(event_name, xdata, ydata, fig=None): - fig, ax = _get_fig_ax(fig) - name = { - 'mousedown': 'button_press_event', - 'mouseup': 'button_release_event', - 'drag': 'motion_notify_event', - }[event_name] - - matplotlib_buttons = { - 'mousedown': MouseButton.LEFT, - 'mouseup': MouseButton.LEFT, - 'drag': MouseButton.LEFT, - } - button = matplotlib_buttons.get(event_name, None) - x, y = ax.transData.transform((xdata, ydata)) - return MouseEvent(name, fig.canvas, x, y, button) - - -def make_mock_graph(fig=None): - fig, ax = _get_fig_ax(fig) - - def make_mock_node(node, x, y): - return mock.Mock(node=node, x=x, y=y) - - def make_mock_edge(node1, node2, data): - return mock.Mock(node_artists=[node1, node2], data=data) - - node_A = make_mock_node("A", 0.0, 0.0) - node_B = make_mock_node("B", 0.5, 0.0) - node_C = make_mock_node("C", 0.5, 0.5) - node_D = make_mock_node("D", 0.0, 0.5) - edge_AB = make_mock_edge(node_A, node_B, {'data': "AB"}) - edge_BC = make_mock_edge(node_B, node_C, {'data': "BC"}) - edge_BD = make_mock_edge(node_B, node_D, {'data': "BD"}) - - mock_graph = mock.Mock( - nodes={node.node: node for node in [node_A, node_B, node_C, node_D]}, - edges={tuple(edge.node_artists): edge - for edge in [edge_AB, edge_BC, edge_BD]}, - ) - return mock_graph - - -class TestNode: - def setup_method(self): - self.node = Node("B", 0.5, 0.0) - self.fig, self.ax = plt.subplots() - self.node.register_artist(self.ax) - - def teardown_method(self): - plt.close(self.fig) - - def test_register_artist(self): - node = Node("B", 0.6, 0.0) - fig, ax = plt.subplots() - assert len(ax.patches) == 0 - node.register_artist(ax) - assert len(ax.patches) == 1 - assert node.artist == ax.patches[0] - plt.close(fig) - - def test_extent(self): - assert self.node.extent == (0.5, 0.6, 0.0, 0.1) - - def test_xy(self): - assert self.node.xy == (0.5, 0.0) - - def test_unselect(self): - # initially blue; turn it red; unselect should switch it back - assert self.node.artist.get_facecolor() == (0.0, 0.0, 1.0, 1.0) - self.node.artist.set(color="red") - assert self.node.artist.get_facecolor() != (0.0, 0.0, 1.0, 1.0) - self.node.unselect() - assert self.node.artist.get_facecolor() == (0.0, 0.0, 1.0, 1.0) - - def test_edge_select(self): - # initially blue; edge_select should turn it red - assert self.node.artist.get_facecolor() == (0.0, 0.0, 1.0, 1.0) - edge = mock.Mock() # unused in this method - self.node.edge_select(edge) - assert self.node.artist.get_facecolor() == (1.0, 0.0, 0.0, 1.0) - - def test_update_location(self): - assert self.node.artist.xy == (0.5, 0.0) - self.node.update_location(0.7, 0.5) - assert self.node.artist.xy == (0.7, 0.5) - assert self.node.xy == (0.7, 0.5) - - @pytest.mark.parametrize('point,expected', [ - ((0.55, 0.05), True), - ((0.5, 0.5), False), - ((-10, -10), False), - ]) - def test_contains(self, point, expected): - event = mock_event('drag', *point, fig=self.fig) - assert self.node.contains(event) == expected - - def test_on_mousedown_in_rect(self): - event = mock_event('mousedown', 0.55, 0.05, self.fig) - drawing_graph = make_mock_graph(self.fig) - assert Node.lock is None - assert self.node.press is None - - self.node.on_mousedown(event, drawing_graph) - assert Node.lock == self.node - assert self.node.press is not None - Node.lock = None - - def test_on_mousedown_in_axes(self): - event = mock_event('mousedown', 0.25, 0.25, self.fig) - drawing_graph = make_mock_graph(self.fig) - - assert Node.lock is None - assert self.node.press is None - self.node.on_mousedown(event, drawing_graph) - assert Node.lock is None - assert self.node.press is None - - def test_on_mousedown_out_axes(self): - node = Node("B", 0.5, 0.6) - event = mock_event('mousedown', 0.55, 0.05, self.fig) - drawing_graph = make_mock_graph(self.fig) - - fig2, ax2 = plt.subplots() - node.register_artist(ax2) - - assert Node.lock is None - assert node.press is None - node.on_mousedown(event, drawing_graph) - assert Node.lock is None - assert node.press is None - plt.close(fig2) - - def test_on_drag(self): - event = mock_event('drag', 0.7, 0.7, self.fig) - # this test some integration, so we need more than a mock - drawing_graph = GraphDrawing( - nx.MultiDiGraph(([("A", "B"), ("B", "C"), ("B", "D")])), - positions={"A": (0.0, 0.0), "B": (0.5, 0.0), - "C": (0.5, 0.5), "D": (0.0, 0.5)} - ) - # set up things that should happen on mousedown - Node.lock = self.node - self.node.press = (0.5, 0.0), (0.55, 0.05) - - self.node.on_drag(event, drawing_graph) - - npt.assert_allclose(self.node.xy, (0.65, 0.65)) - - # undo the lock; normally handled by mouseup - Node.lock = None - - def test_on_drag_do_nothing(self): - event = mock_event('drag', 0.7, 0.7, self.fig) - drawing_graph = make_mock_graph(self.fig) - - # don't set lock -- early exit - original = self.node.xy - self.node.on_drag(event, drawing_graph) - assert self.node.xy == original - - def test_on_drag_no_mousedown(self): - event = mock_event('drag', 0.7, 0.7, self.fig) - drawing_graph = make_mock_graph(self.fig) - Node.lock = self.node - - with pytest.raises(RuntimeError, match="drag until mouse down"): - self.node.on_drag(event, drawing_graph) - - Node.lock = None - - def test_on_mouseup(self): - event = mock_event('drag', 0.7, 0.7, self.fig) - drawing_graph = make_mock_graph(self.fig) - Node.lock = self.node - self.node.press = (0.5, 0.0), (0.55, 0.05) - - self.node.on_mouseup(event, drawing_graph) - assert Node.lock is None - assert self.node.press is None - - def test_blitting(self): - pytest.skip("Blitting hasn't been implemented yet") - - -class TestEdge: - def setup_method(self): - self.nodes = [Node("A", 0.0, 0.0), Node("B", 0.5, 0.0)] - self.data = {"data": "values"} - self.edge = Edge(*self.nodes, self.data) - self.fig, self.ax = plt.subplots() - self.ax.set_xlim(-1, 1) - self.ax.set_ylim(-1, 1) - self.edge.register_artist(self.ax) - - def teardown_method(self): - plt.close(self.fig) - - def test_register_artist(self): - fig, ax = plt.subplots() - edge = Edge(*self.nodes, self.data) - assert len(ax.get_lines()) == 0 - edge.register_artist(ax) - assert len(ax.get_lines()) == 1 - assert ax.get_lines()[0] == edge.artist - plt.close(fig) - - @pytest.mark.parametrize('point,expected', [ - ((0.25, 0.05), True), - ((0.6, 0.1), False), - ]) - def test_contains(self, point, expected): - event = mock_event('drag', *point, fig=self.fig) - assert self.edge.contains(event) == expected - - def test_edge_xs_ys(self): - npt.assert_allclose(self.edge._edge_xs_ys(*self.nodes), - ((0.05, 0.55), (0.05, 0.05))) - - def _get_colors(self): - colors = {node: node.artist.get_facecolor() - for node in self.nodes} - colors[self.edge] = self.edge.artist.get_color() - return colors - - def test_unselect(self): - original = self._get_colors() - - for node in self.nodes: - node.artist.set(color='red') - - self.edge.artist.set(color='red') - - # ensure that we have changed from the original values - changed = self._get_colors() - for key in original: - assert changed[key] != original[key] - - self.edge.unselect() - after = self._get_colors() - assert after == original - - def test_select(self): - event = mock_event('mouseup', 0.25, 0.05, self.fig) - drawing_graph = make_mock_graph(self.fig) - original = self._get_colors() - self.edge.select(event, drawing_graph) - changed = self._get_colors() - - for key in self.nodes: - assert changed[key] != original[key] - assert changed[key] == (1.0, 0.0, 0.0, 1.0) # red - - assert changed[self.edge] == "red" # mpl doesn't convert to RGBA?! - # it might be better in the future to pass that through some MPL - # func that converts color string to RGBA; the fact that MPL keeps - # color name in line2d seems like an implementation detail - - def test_update_locations(self): - for node in self.nodes: - x, y = node.xy - node.update_location(x + 0.2, y + 0.2) - - self.edge.update_locations() - npt.assert_allclose(self.edge.artist.get_xdata(), [0.25, 0.75]) - npt.assert_allclose(self.edge.artist.get_ydata(), [0.25, 0.25]) - - -class TestEventHandler: - def setup_method(self): - self.fig, self.ax = plt.subplots() - self.event_handler = EventHandler(graph=make_mock_graph(self.fig)) - graph = self.event_handler.graph - node = graph.nodes["C"] - edge = graph.edges[graph.nodes["B"], graph.nodes["C"]] - self.setup_contains = { - "node": (node, [node]), - "edge": (edge, [edge]), - "node+edge": (node, [node, edge]), - "miss": (None, []), - } - - def teardown_method(self): - plt.close(self.fig) - - def _mock_for_connections(self): - self.event_handler.on_mousedown = mock.Mock() - self.event_handler.on_mouseup = mock.Mock() - self.event_handler.on_drag = mock.Mock() - - @pytest.mark.parametrize('event_type', ['mousedown', 'mouseup', 'drag']) - def test_connect(self, event_type): - self._mock_for_connections() - event = mock_event(event_type, 0.2, 0.2, self.fig) - - methods = { - 'mousedown': self.event_handler.on_mousedown, - 'mouseup': self.event_handler.on_mouseup, - 'drag': self.event_handler.on_drag, - } - should_call = methods[event_type] - should_not_call = set(methods.values()) - {should_call} - assert len(self.event_handler.connections) == 0 - - self.event_handler.connect(self.fig.canvas) - assert len(self.event_handler.connections) == 3 - - # check that the event is processed - self.fig.canvas.callbacks.process(event.name, event) - should_call.assert_called_once() - for method in should_not_call: - assert not method.called - - @pytest.mark.parametrize('event_type', ['mousedown', 'mouseup', 'drag']) - def test_disconnect(self, event_type): - self._mock_for_connections() - fig, _ = plt.subplots() - event = mock_event(event_type, 0.2, 0.2, fig) - - self.event_handler.connect(fig.canvas) # not quite full isolation - assert len(self.event_handler.connections) == 3 - - self.event_handler.disconnect(fig.canvas) - assert len(self.event_handler.connections) == 0 - methods = [self.event_handler.on_mousedown, - self.event_handler.on_mousedown, - self.event_handler.on_drag] - - fig.canvas.callbacks.process(event.name, event) - for method in methods: - assert not method.called - - plt.close(fig) - - def _mock_contains(self, mock_objs): - graph = self.event_handler.graph - objs = list(graph.nodes.values()) + list(graph.edges.values()) - for obj in objs: - if obj in mock_objs: - obj.contains = mock.Mock(return_value=True) - else: - obj.contains = mock.Mock(return_value=False) - - @pytest.mark.parametrize('hit', ['node', 'edge', 'node+edge', 'miss']) - def test_get_event_container_select_node(self, hit): - expected, contains_event = self.setup_contains[hit] - expected_count = { - "node": 3, # nodes A, B, C - "edge": 6, # nodes A, B, C, D; edges AB, BC - "node+edge": 3, # nodes A, B, C - "miss": 7, # nodes A, B, C, D; edges AB BC, BD - }[hit] - self._mock_contains(contains_event) - event = mock.Mock() - found = self.event_handler._get_event_container(event) - assert found is expected - for container in contains_event: - if container is not expected: - assert not container.called - - graph = self.event_handler.graph - all_objs = list(graph.nodes.values()) + list(graph.edges.values()) - contains_count = sum(obj.contains.called for obj in all_objs) - assert contains_count == expected_count - - @pytest.mark.parametrize('hit', ['node', 'edge', 'node+edge', 'miss']) - def test_on_mousedown(self, hit): - expected, contains_event = self.setup_contains[hit] - self._mock_contains(contains_event) - event = mock_event('mousedown', 0.5, 0.5) - - assert self.event_handler.click_location is None - assert self.event_handler.active is None - self.event_handler.on_mousedown(event) - npt.assert_allclose(self.event_handler.click_location, (0.5, 0.5)) - assert self.event_handler.active is expected - if expected is not None: - expected.on_mousedown.assert_called_once() - - plt.close(event.canvas.figure) - - @pytest.mark.parametrize('is_active', [True, False]) - def test_on_drag(self, is_active): - node = self.event_handler.graph.nodes["C"] - node.artist.axes = self.ax - event = mock_event('drag', 0.25, 0.25, self.fig) - if is_active: - self.event_handler.active = node - - self.event_handler.on_drag(event) - - if is_active: - node.on_drag.assert_called_once() - else: - assert not node.on_drag.called - - @pytest.mark.parametrize('has_selected', [True, False]) - def test_on_mouseup_click_select(self, has_selected): - # start: mouse hasn't moved, and something is active - graph = self.event_handler.graph - edge = graph.edges[graph.nodes["B"], graph.nodes["C"]] - if has_selected: - old_selected = graph.edges[graph.nodes["A"], graph.nodes["B"]] - self.event_handler.selected = old_selected - - self._mock_contains([edge]) - event = mock_event('mouseup', 0.25, 0.25) - self.event_handler.click_location = (event.xdata, event.ydata) - self.event_handler.active = edge - - # this should select the active object - self.event_handler.on_mouseup(event) - - if has_selected: - old_selected.unselect.assert_called_once() - - edge.select.assert_called_once() - edge.on_mouseup.assert_called_once() - assert self.event_handler.selected is edge - assert self.event_handler.active is None - assert self.event_handler.click_location is None - graph.draw.assert_called_once() - - plt.close(event.canvas.figure) - - @pytest.mark.parametrize('has_selected', [True, False]) - def test_on_mouseup_click_not_select(self, has_selected): - # start: mouse hasn't moved, nothing is active - graph = self.event_handler.graph - if has_selected: - old_selected = graph.edges[graph.nodes["A"], graph.nodes["B"]] - self.event_handler.selected = old_selected - - event = mock_event('mouseup', 0.25, 0.25) - self.event_handler.click_location = (event.xdata, event.ydata) - - self.event_handler.on_mouseup(event) - - if has_selected: - old_selected.unselect.assert_called_once() - - assert self.event_handler.selected is None - assert self.event_handler.active is None - assert self.event_handler.click_location is None - graph.draw.assert_called_once() - plt.close(event.canvas.figure) - - @pytest.mark.parametrize('has_selected', [True, False]) - def test_on_mouseup_drag(self, has_selected): - # start: mouse has moved, something is active - graph = self.event_handler.graph - edge = graph.edges[graph.nodes["B"], graph.nodes["C"]] - if has_selected: - old_selected = graph.edges[graph.nodes["A"], graph.nodes["B"]] - self.event_handler.selected = old_selected - - event = mock_event('mouseup', 0.25, 0.25) - self.event_handler.click_location = (0.5, 0.5) - self.event_handler.active = edge - - self.event_handler.on_mouseup(event) - - if has_selected: - assert not old_selected.unselect.called - - assert not edge.selected.called - edge.on_mouseup.assert_called_once() - assert self.event_handler.active is None - assert self.event_handler.click_location is None - graph.draw.assert_called_once() - plt.close(event.canvas.figure) - - -class TestGraphDrawing: - def setup_method(self): - self.nx_graph = nx.MultiDiGraph() - self.nx_graph.add_edges_from([ - ("A", "B", {'data': "AB"}), - ("B", "C", {'data': "BC"}), - ("B", "D", {'data': "BD"}), - ]) - self.positions = { - "A": (0.0, 0.0), "B": (0.5, 0.0), "C": (0.5, 0.5), - "D": (-0.1, 0.6) - } - self.graph = GraphDrawing(self.nx_graph, positions=self.positions) - - def test_init(self): - # this also tests _register_node, _register_edge - assert len(self.graph.nodes) == 4 - assert len(self.graph.edges) == 3 - assert len(self.graph.fig.axes) == 1 - assert self.graph.fig.axes[0] is self.graph.ax - assert len(self.graph.ax.patches) == 4 - assert len(self.graph.ax.lines) == 3 - - def test_init_custom_ax(self): - fig, ax = plt.subplots() - graph = GraphDrawing(self.nx_graph, positions=self.positions, - ax=ax) - assert graph.fig is fig - assert graph.ax is ax - plt.close(fig) - - def test_register_node_error(self): - with pytest.raises(RuntimeError, match="multiple times"): - self.graph._register_node( - node=list(self.nx_graph.nodes)[0], - position=(0, 0) - ) - - @pytest.mark.parametrize('node,edges', [ - ("A", [("A", "B")]), - ("B", [("A", "B"), ("B", "C"), ("B", "D")]), - ("C", [("B", "C")]), - ]) - def test_edges_for_node(self, node, edges): - expected_edges = {self.graph.edges[n1, n2] for n1, n2 in edges} - assert set(self.graph.edges_for_node(node)) == expected_edges - - def test_get_nodes_extent(self): - assert self.graph._get_nodes_extent() == (-0.1, 0.6, 0.0, 0.7) - - def test_reset_bounds(self): - old_xlim = self.graph.ax.get_xlim() - old_ylim = self.graph.ax.get_ylim() - self.graph.ax.set_xlim(old_xlim[0] * 2, old_xlim[1] * 2) - self.graph.ax.set_ylim(old_ylim[0] * 2, old_ylim[1] * 2) - self.graph.reset_bounds() - assert self.graph.ax.get_xlim() == old_xlim - assert self.graph.ax.get_ylim() == old_ylim - - def test_draw(self): - # just a smoke test; there's really nothing that we can test here - # other that integration - self.graph.draw() From 4126d96fdf151009eee256209bb071fbf404db47 Mon Sep 17 00:00:00 2001 From: Alyssa Travitz Date: Thu, 18 Sep 2025 11:54:06 -0700 Subject: [PATCH 3/6] remove unused fixture --- openfe/tests/utils/conftest.py | 21 --------------------- 1 file changed, 21 deletions(-) diff --git a/openfe/tests/utils/conftest.py b/openfe/tests/utils/conftest.py index 0addb7ca3..91a4a0659 100644 --- a/openfe/tests/utils/conftest.py +++ b/openfe/tests/utils/conftest.py @@ -10,13 +10,6 @@ from ..conftest import mol_from_smiles -class _NetworkTestContainer(NamedTuple): - """Container to facilitate network testing""" - network: LigandNetwork - nodes: Iterable[SmallMoleculeComponent] - edges: Iterable[LigandAtomMapping] - n_nodes: int - n_edges: int @pytest.fixture @@ -35,20 +28,6 @@ def std_edges(mols): edge13 = LigandAtomMapping(mol1, mol3, {0: 0, 2: 1}) return edge12, edge23, edge13 - -@pytest.fixture -def simple_network(mols, std_edges): - """Network with no edges duplicated and all nodes in edges""" - network = LigandNetwork(std_edges) - return _NetworkTestContainer( - network=network, - nodes=mols, - edges=std_edges, - n_nodes=3, - n_edges=3, - ) - - @pytest.fixture(scope='session') def benzene_transforms(): # a dict of Molecules for benzene transformations From d454bf03be109a7c82757cae836d29770bf464f3 Mon Sep 17 00:00:00 2001 From: Alyssa Travitz Date: Thu, 18 Sep 2025 12:00:11 -0700 Subject: [PATCH 4/6] removed unused code --- openfe/tests/utils/conftest.py | 40 ---------------------------------- 1 file changed, 40 deletions(-) delete mode 100644 openfe/tests/utils/conftest.py diff --git a/openfe/tests/utils/conftest.py b/openfe/tests/utils/conftest.py deleted file mode 100644 index 91a4a0659..000000000 --- a/openfe/tests/utils/conftest.py +++ /dev/null @@ -1,40 +0,0 @@ -# This code is part of OpenFE and is licensed under the MIT license. -# For details, see https://github.com/OpenFreeEnergy/openfe -import pytest -from rdkit import Chem -from importlib import resources - -from openfe import SmallMoleculeComponent, LigandAtomMapping, LigandNetwork -from typing import Iterable, NamedTuple - -from ..conftest import mol_from_smiles - - - - -@pytest.fixture -def mols(): - mol1 = SmallMoleculeComponent(mol_from_smiles("CCO")) - mol2 = SmallMoleculeComponent(mol_from_smiles("CC")) - mol3 = SmallMoleculeComponent(mol_from_smiles("CO")) - return mol1, mol2, mol3 - - -@pytest.fixture -def std_edges(mols): - mol1, mol2, mol3 = mols - edge12 = LigandAtomMapping(mol1, mol2, {0: 0, 1: 1}) - edge23 = LigandAtomMapping(mol2, mol3, {0: 0}) - edge13 = LigandAtomMapping(mol1, mol3, {0: 0, 2: 1}) - return edge12, edge23, edge13 - -@pytest.fixture(scope='session') -def benzene_transforms(): - # a dict of Molecules for benzene transformations - mols = {} - with resources.as_file(resources.files('openfe.tests.data')) as d: - fn = str(d / 'benzene_modifications.sdf') - supplier = Chem.SDMolSupplier(fn, removeHs=False) - for mol in supplier: - mols[mol.GetProp('_Name')] = SmallMoleculeComponent(mol) - return mols From da06d338834c952971def07a9541d10cbd43f5ca Mon Sep 17 00:00:00 2001 From: Alyssa Travitz Date: Thu, 18 Sep 2025 12:34:30 -0700 Subject: [PATCH 5/6] add TODO --- environment.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/environment.yml b/environment.yml index 6d41d5fd4..477751bf7 100644 --- a/environment.yml +++ b/environment.yml @@ -29,7 +29,7 @@ dependencies: - pooch - py3dmol - pydantic >=1.10.17 # practically, this almost always means pydantic v2 because of our other dependencies. - - pygraphviz + - pygraphviz # TODO: make this an optional dependency? - pytest - pytest-xdist - pytest-cov From 8e9ebb1eb343e655b2c43fc0ba786a727c2e820d Mon Sep 17 00:00:00 2001 From: Alyssa Travitz Date: Thu, 18 Sep 2025 13:33:03 -0700 Subject: [PATCH 6/6] remove typing-extensions dep --- environment.yml | 1 - 1 file changed, 1 deletion(-) diff --git a/environment.yml b/environment.yml index 477751bf7..2c7625554 100644 --- a/environment.yml +++ b/environment.yml @@ -39,7 +39,6 @@ dependencies: - rdkit - rich - tqdm - - typing-extensions - zstandard # Issue #443 - pymbar>4.0