diff --git a/.flake8 b/.flake8 index 62afefc..dee98cc 100644 --- a/.flake8 +++ b/.flake8 @@ -1,4 +1,4 @@ [flake8] max-line-length = 130 exclude = .git,__pycache__,build,dist -ignore = E261 \ No newline at end of file +ignore = E261, W504 \ No newline at end of file diff --git a/NOTICE.md b/NOTICE.md new file mode 100644 index 0000000..2954973 --- /dev/null +++ b/NOTICE.md @@ -0,0 +1,44 @@ +# Dependencies and Licenses + +This project `MultiModalRouter` depends on the following libraries. All licenses are permissive and compatible with MIT licensing for this project. + +| Package | Version | License | License Link | +|---------|---------|---------|--------------| +| colorama | >=0.4.6 | BSD 3-Clause | [License](https://github.com/tartley/colorama/blob/master/LICENSE) | +| dill | >=0.4.0 | BSD | [License](https://github.com/uqfoundation/dill/blob/main/LICENSE) | +| filelock | >=3.19.1 | MIT | [License](https://github.com/tox-dev/py-filelock/blob/main/LICENSE) | +| fsspec | >=2025.9.0 | Apache 2.0 | [License](https://github.com/fsspec/filesystem_spec/blob/main/LICENSE) | +| Jinja2 | >=3.1.6 | BSD-3-Clause | [License](https://github.com/pallets/jinja/blob/main/LICENSE) | +| MarkupSafe | >=3.0.2 | BSD-3-Clause | [License](https://github.com/pallets/markupsafe/blob/main/LICENSE) | +| mpmath | >=1.3.0 | BSD | [License](https://github.com/fredrik-johansson/mpmath/blob/master/LICENSE) | +| networkx | >=3.5 | BSD | [License](https://github.com/networkx/networkx/blob/main/LICENSE.txt) | +| numpy | >=2.3.3 | BSD | [License](https://github.com/numpy/numpy/blob/main/LICENSE.txt) | +| pandas | >=2.3.2 | BSD-3-Clause | [License](https://github.com/pandas-dev/pandas/blob/main/LICENSE) | +| parquet | >=1.3.1 | Apache 2.0 | [License](https://github.com/urschrei/parquet-python/blob/master/LICENSE) | +| ply | >=3.11 | BSD | [License](https://github.com/dabeaz/ply/blob/master/LICENSE.txt) | +| pyarrow | >=21.0.0 | Apache 2.0 | [License](https://github.com/apache/arrow/blob/master/LICENSE) | +| python-dateutil | >=2.9.0.post0 | BSD | [License](https://github.com/dateutil/dateutil/blob/master/LICENSE.txt) | +| pytz | >=2025.2 | MIT | [License](https://github.com/stub42/pytz/blob/master/LICENSE) | +| setuptools | >=80.9.0 | MIT | [License](https://github.com/pypa/setuptools/blob/main/LICENSE) | +| six | >=1.17.0 | MIT | [License](https://github.com/benjaminp/six/blob/master/LICENSE) | +| sympy | >=1.14.0 | BSD | [License](https://github.com/sympy/sympy/blob/master/LICENSE) | +| thriftpy2 | >=0.5.3 | MIT | [License](https://github.com/Thriftpy/thriftpy2/blob/master/LICENSE) | +| tqdm | >=4.67.1 | MPL 2.0 | [License](https://github.com/tqdm/tqdm/blob/master/LICENSE) | +| typing_extensions | >=4.15.0 | PSF | [License](https://github.com/python/typing_extensions/blob/main/LICENSE) | +| tzdata | >=2025.2 | Public Domain | [License](https://github.com/python/tzdata) | + +## Optional Dependencies + +| Package | Version | License | License Link | +|---------|---------|---------|--------------| +| torch | >=2.8.0 | BSD | [License](https://github.com/pytorch/pytorch/blob/master/LICENSE) | +| plotly | >=6.3.0 | MIT | [License](https://github.com/plotly/plotly.py/blob/master/LICENSE) | +| pytest | >=8.0 | MIT | [License](https://github.com/pytest-dev/pytest/blob/main/LICENSE) | + +--- + +### Notes + +1. All packages listed above are permissively licensed (MIT, BSD, Apache 2.0, or Public Domain), so they are compatible with MIT licensing for this project. +2. If distributing this library, include this `DEPENDENCIES.md` file and your own MIT license file to give proper attribution. +3. Optional dependencies should be listed in documentation or `pyproject.toml` extras. diff --git a/README.md b/README.md index a566373..076df20 100644 --- a/README.md +++ b/README.md @@ -46,6 +46,12 @@ The graph can be build from any data aslong as the required fields are present ( ![example from the maze solver](./docs/solvedMaze1.png) +## graph visualizations + +Use the build-in [visualization](./docs/visualization.md) tool to plot any `2D` or `3D` Graph. + +![example plot of flight paths](./docs/FlightPathPlot.png) + ## Important considerations for your usecase Depending on your usecase and datasets some features may not be usable see solutions below @@ -67,4 +73,6 @@ Depending on your usecase and datasets some features may not be usable see solut [see here](./LICENSE.md) +[dependencies](./NOTICE.md) + diff --git a/docs/FlightPathPlot.png b/docs/FlightPathPlot.png new file mode 100644 index 0000000..4e54641 Binary files /dev/null and b/docs/FlightPathPlot.png differ diff --git a/docs/examples/flightRouter/main.py b/docs/examples/flightRouter/main.py index 2402ee8..a0a66f3 100644 --- a/docs/examples/flightRouter/main.py +++ b/docs/examples/flightRouter/main.py @@ -5,6 +5,7 @@ from multimodalrouter import RouteGraph import os + def main(): path = os.path.dirname(os.path.abspath(__file__)) # initialize the graph @@ -17,21 +18,21 @@ def main(): # build the graph graph.build() # set start and end points - start = [60.866699,-162.272996] # Atmautluak Airport - end = [60.872747,-162.5247] #Kasigluk Airport + start = [60.866699, -162.272996] # Atmautluak Airport + end = [60.872747, -162.5247] # Kasigluk Airport start_hub = graph.findClosestHub(["airport"], start) # find the hubs end_hub = graph.findClosestHub(["airport"], end) # find the route route = graph.find_shortest_path( - start_hub.id, + start_hub.id, end_hub.id, - allowed_modes=["plane","car"], + allowed_modes=["plane", "car"], verbose=True - ) + ) # print the route print(route.flatPath if route else "No route found") if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/docs/examples/flightRouter/plot.py b/docs/examples/flightRouter/plot.py new file mode 100644 index 0000000..d0ef7e9 --- /dev/null +++ b/docs/examples/flightRouter/plot.py @@ -0,0 +1,25 @@ +# dataclasses.py +# Copyright (c) 2025 Tobias Karusseit +# Licensed under the MIT License. See LICENSE file in the project root for full license information. + + +from multimodalrouter import RouteGraph +from multimodalrouter.graphics import GraphDisplay +import os + +if __name__ == "__main__": + path = os.path.dirname(os.path.abspath(__file__)) + graph = RouteGraph( + maxDistance=50, + transportModes={"airport": "fly", }, + dataPaths={"airport": os.path.join(path, "data", "fullDataset.csv")}, + compressed=False, + ) + + graph.build() + display = GraphDisplay(graph) + display.display( + displayEarth=True, + nodeTransform=GraphDisplay.degreesToCartesian3D, + edgeTransform=GraphDisplay.curvedEdges + ) diff --git a/docs/examples/mazePathfinder/data/createMaze.py b/docs/examples/mazePathfinder/data/createMaze.py index 1477da4..3612a09 100644 --- a/docs/examples/mazePathfinder/data/createMaze.py +++ b/docs/examples/mazePathfinder/data/createMaze.py @@ -5,15 +5,17 @@ import random import pandas as pd + # simple cell class for the maze class Cell: def __init__(self, x, y): - self.id = f"cell-{x,y}" + self.id = f"cell-{x, y}" self.x = x self.y = y self.visited = False self.connected = [] + def main(): # init a 10x10 maze mazeHeight = 10 @@ -53,7 +55,15 @@ def main(): cellStack.pop() # init the dataframe - data = pd.DataFrame(columns=["source", "destination", "distance", "source_lat", "source_lng", "destination_lat", "destination_lng"]) + data = pd.DataFrame(columns=[ + "source", + "destination", + "distance", + "source_lat", + "source_lng", + "destination_lat", + "destination_lng" + ]) # add the edges to the dataframe for cell in cells: for neighbor in cell.connected: @@ -61,4 +71,6 @@ def main(): # save the dataframe data.to_csv("docs/examples/mazePathfinder/data/maze.csv", index=False) -if __name__ == "__main__": main() \ No newline at end of file + +if __name__ == "__main__": + main() diff --git a/docs/examples/mazePathfinder/main.py b/docs/examples/mazePathfinder/main.py index d52ec9b..7b2775f 100644 --- a/docs/examples/mazePathfinder/main.py +++ b/docs/examples/mazePathfinder/main.py @@ -6,22 +6,26 @@ import os import pandas as pd + def main(): try: import matplotlib.pyplot as plt except ImportError: raise ImportError("matplotlib is not installed. Please install matplotlib to use this example.") - + path = os.path.dirname(os.path.abspath(__file__)) # init the maze df for the plot mazeDf = pd.read_csv(os.path.join(path, "data", "maze.csv")) # init the plot - plt.figure(figsize=(10,10)) + plt.figure(figsize=(10, 10)) # draw the maze + # draw the maze (grid lines) for _, row in mazeDf.iterrows(): - plt.plot([row.source_lat, row.destination_lat], - [row.source_lng, row.destination_lng], - "k-") # black line for edge + plt.plot( + [row.source_lng, row.destination_lng], # x = "lng" column + [row.source_lat, row.destination_lat], # y = "lat" column + "k-" + ) # initialize the graph graph = RouteGraph( @@ -35,7 +39,7 @@ def main(): graph.build() # find the shortest route route = graph.find_shortest_path( - start_id="cell-(0, 0)", + start_id="cell-(0, 0)", end_id="cell-(0, 9)", allowed_modes=["walk"], verbose=True, @@ -49,11 +53,17 @@ def main(): if s_prev is not None: h1 = graph.getHubById(s_prev) h2 = graph.getHubById(s) - plt.plot([h1.coords[0], h2.coords[0]], - [h1.coords[1], h2.coords[1]], - "b-") + # Swap coords so x=column, y=row + plt.plot( + [h1.coords[1], h2.coords[1]], # x-axis + [h1.coords[0], h2.coords[0]], # y-axis + "b-" + ) s_prev = s + # display the plot plt.show() - -if __name__ == "__main__": main() \ No newline at end of file + + +if __name__ == "__main__": + main() diff --git a/docs/examples/mazePathfinder/plot.py b/docs/examples/mazePathfinder/plot.py new file mode 100644 index 0000000..02a2ab0 --- /dev/null +++ b/docs/examples/mazePathfinder/plot.py @@ -0,0 +1,32 @@ +# dataclasses.py +# Copyright (c) 2025 Tobias Karusseit +# Licensed under the MIT License. See LICENSE file in the project root for full license information. + + +from multimodalrouter import RouteGraph +from multimodalrouter.graphics import GraphDisplay +import os + + +# custom transform to make lat lng to x y (-> lng lat) +def NodeTransform(coords): + for coord in coords: + yield list((coord[0], coord[1])) + + +if __name__ == "__main__": + path = os.path.dirname(os.path.abspath(__file__)) + # initialize the graph + graph = RouteGraph( + maxDistance=50, + transportModes={"cell": "walk", }, + dataPaths={"cell": os.path.join(path, "data", "maze.csv")}, + compressed=False, + drivingEnabled=False + ) + + graph.build() + # init the display + display = GraphDisplay(graph) + # display the graph (uses the transform to swap lat lng to x y) + display.display(nodeTransform=NodeTransform) diff --git a/docs/visualization.md b/docs/visualization.md new file mode 100644 index 0000000..f92d8e8 --- /dev/null +++ b/docs/visualization.md @@ -0,0 +1,108 @@ +[HOME](../README.md) + +# Graph Plotting + +Using the build-in graph plotting tool you can [plotly](https://plotly.com/python/) plot any graph in `2D` or `3D`, while defining [transformations](#transformations) for your coordiante space or even path curvature etc. + +## GraphDisplay + +```python +def __init__( + self, + graph: RouteGraph, + name: str = "Graph", + iconSize: int = 10 +) -> None: +``` + +#### args: + - graph: RouteDisplay = the graph instance you want to plot + - name: str = (not in use at the moment) + - iconSize: int = the size of the nodes in the plot + +#### example + +``` +gd = GraphDisplay(myGraphInstance) +``` + +[flight path CODE example on sphere](./examples/flightRouter/plot.py) + + +### display() + +The display function will collect data from your Graph and create a [plotly](https://plotly.com/python/) plot from it. + +```python +def display( + self, + nodeTransform=None, + edgeTransform=None, + displayEarth=False +): +``` + +#### args: + +- nodeTransform: function = a [transformation](#transformations) function that transformes all node coordinates +- edgeTransform: funstion = a function that [transformes](#transformations) all your edges +- displayEarth: bool = if True -> will display a sphere that (roughly) matches earth + +#### example: + +this call will create the plot for your graph while mapping all coords onto the surface of the earth + +```python +gd.display( + nodeTransform = gd.degreesToCartesian3D, + displayEarth: True +) +``` + +### transformations + +#### base function style + +IF you want to implement your own transformation function note that the call must adhere to the following parameters: + +```python +def customNodeTrandsform(coords: list[list[float]]): + return list[list[float]] + +def customEdgeTransform(start: list[list[float]], end: list[list[float]]): + return list[list[list[float]]] +``` + +#### args + +- coords: list[list[float]] = a nested list of coordinates for all nodes +- start: list[list[float]] = a nested list of all start coordinates +- end: list[list[float]] = a nested list of all end coordinates + +#### returns: + +- list[list[float]] = a list of all transformed node coordinates +- list[list[list[float]]] = a list of curves whare each curve / edge can have n points defining it + +### build-in Node Transforms: + +#### degreesToCartesian3D + +```python +@staticmethod + def degreesToCartesian3D(coords): +``` +This function maps any valid `2D` coordinates (best if in degrees) to spherical coords on the surface of earth + +### build-in Edge Transformations + +```python +@staticmethod + def curvedEdges(start, end, R=6371.0, H=0.05, n=20): +``` + +curves edges for coordinates on spheres (here earth) so that the edges curve along the spherical surface with a curvature that places the midpoint of the curve at $H \dot R$ above the surface. (great for displaying flights). + +If torch is installed this will use great-circle distance for the curves + +> Note if torch is not installed this will fall back to using `math` with quadratic bezier curves -> some curves may end up inside the sphere to bezier inaccuracy \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index b48bd36..98d57cc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "multimodalrouter" -version = "0.1.4" +version = "0.1.5" description = "A graph-based routing library for dynamic routing." readme = "README.md" license = { file = "LICENSE.md" } @@ -46,12 +46,14 @@ Repository = "https://github.com/K-T0BIAS/MultiModalRouter" [project.optional-dependencies] torch = ["torch>=2.8.0"] dev = [ - "pytest>=8.0" + "pytest>=8.0", + "plotly>=6.3.0" ] +plotly = ["plotly>=6.3.0"] [tool.setuptools] package-dir = {"" = "src"} -packages = ["multimodalrouter", "multimodalrouter.graph", "multimodalrouter.router", "multimodalrouter.utils"] +packages = ["multimodalrouter", "multimodalrouter.graph", "multimodalrouter.router", "multimodalrouter.utils", "multimodalrouter.graphics"] [project.scripts] multiModalRouter-build = "multimodalrouter.router.build:main" diff --git a/src/multimodalrouter/graph/dataclasses.py b/src/multimodalrouter/graph/dataclasses.py index c59f49b..eb1e4ac 100644 --- a/src/multimodalrouter/graph/dataclasses.py +++ b/src/multimodalrouter/graph/dataclasses.py @@ -50,6 +50,7 @@ def __init__(self, coords: list[float], id: str, hubType: str): self.coords: list[float] = coords self.id = id self.hubType = hubType + # dict like {mode -> {dest_id -> EdgeMetadata}} self.outgoing: dict[str, dict[str, EdgeMetadata]] = {} def addOutgoing(self, mode: str, dest_id: str, metrics: EdgeMetadata): diff --git a/src/multimodalrouter/graphics/__init__.py b/src/multimodalrouter/graphics/__init__.py new file mode 100644 index 0000000..a2f68fe --- /dev/null +++ b/src/multimodalrouter/graphics/__init__.py @@ -0,0 +1 @@ +from .graphicsWrapper import GraphDisplay # noqa: F401 diff --git a/src/multimodalrouter/graphics/graphicsWrapper.py b/src/multimodalrouter/graphics/graphicsWrapper.py new file mode 100644 index 0000000..40c0eb8 --- /dev/null +++ b/src/multimodalrouter/graphics/graphicsWrapper.py @@ -0,0 +1,323 @@ +# dataclasses.py +# Copyright (c) 2025 Tobias Karusseit +# Licensed under the MIT License. See LICENSE file in the project root for full license information. + + +from ..graph import RouteGraph +import plotly.graph_objects as go + + +class GraphDisplay(): + + def __init__(self, graph: RouteGraph, name: str = "Graph", iconSize: int = 10) -> None: + self.graph: RouteGraph = graph + self.name: str = name + self.iconSize: int = iconSize + + def _toPlotlyFormat( + self, + nodeTransform=None, + edgeTransform=None + ): + """ + transform the graph data into plotly format.to use the display function + + args: + - nodeTransform: function to transform the node coordinates (default = None) + - edgeTransform: function to transform the edge coordinates (default = None) + returns: + - None (modifies self.nodes and self.edges) + """ + self.nodes = { + f"{hub.hubType}-{hub.id}": { + "coords": hub.coords, + "hubType": hub.hubType, + "id": hub.id + } + for hub in self.graph._allHubs() + } + + self.edges = [ + { + "from": f"{hub.hubType}-{hub.id}", + "to": f"{self.graph.getHubById(dest).hubType}-{dest}", + **edge.allMetrics + } + for hub in self.graph._allHubs() + for _, edge in hub.outgoing.items() + for dest, edge in edge.items() + ] + self.dim = max(len(node.get("coords")) for node in self.nodes.values()) + + if nodeTransform is not None: + expandedCoords = [node.get("coords") + [0] * (self.dim - len(node.get("coords"))) for node in self.nodes.values()] + transformedCoords = nodeTransform(expandedCoords) + for node, coords in zip(self.nodes.values(), transformedCoords): + node["coords"] = coords + + self.dim = max(len(node.get("coords")) for node in self.nodes.values()) + + if edgeTransform is not None: + starts = [edge["from"] for edge in self.edges] + startCoords = [self.nodes[start]["coords"] for start in starts] + ends = [edge["to"] for edge in self.edges] + endCoords = [self.nodes[end]["coords"] for end in ends] + + transformedEdges = edgeTransform(startCoords, endCoords) + for edge, transformedEdge in zip(self.edges, transformedEdges): + edge["curve"] = transformedEdge + + def display( + self, + nodeTransform=None, + edgeTransform=None, + displayEarth=False + ): + """ + function to display any 2D or 3D RouteGraph + + args: + - nodeTransform: function to transform the node coordinates (default = None) + - edgeTransform: function to transform the edge coordinates (default = None) + - displayEarth: whether to display the earth as a background (default = False, only in 3D) + + returns: + - None (modifies self.nodes and self.edges opens the plot in a browser) + + """ + # transform the graph + self._toPlotlyFormat(nodeTransform, edgeTransform) + # init plotly placeholders + node_x, node_y, node_z, text, colors = [], [], [], [], [] + edge_x, edge_y, edge_z, edge_text = [], [], [], [] + + # add all the nodes + for node_key, node_data in self.nodes.items(): + x, y, *rest = node_data["coords"] + node_x.append(x) + node_y.append(y) + if self.dim == 3: + node_z.append(node_data["coords"][2]) + text.append(f"{node_data['id']}
Type: {node_data['hubType']}") + colors.append(hash(node_data['hubType']) % 10) + + # add all the edges + for edge in self.edges: + # check if edge has been transformed + if "curve" in edge: + curve = edge["curve"] + # add all the points of the edge + for point in curve: + edge_x.append(point[0]) + edge_y.append(point[1]) + if self.dim == 3: + edge_z.append(point[2]) + edge_x.append(None) + edge_y.append(None) + # if 3d add the extra none to close the edge + if self.dim == 3: + edge_z.append(None) + else: + source = self.nodes[edge["from"]]["coords"] + target = self.nodes[edge["to"]]["coords"] + + edge_x += [source[0], target[0], None] + edge_y += [source[1], target[1], None] + + if self.dim == 3: + edge_z += [source[2], target[2], None] + + # add text and hover display + hover = f"{edge['from']} → {edge['to']}" + metrics = {k: v for k, v in edge.items() if k not in ("from", "to", "curve")} + if metrics: + hover += "
" + "
".join(f"{k}: {v}" for k, v in metrics.items()) + edge_text.append(hover) + + if self.dim == 2: + # ceate the plot in 2d + node_trace = go.Scatter( + x=node_x, + y=node_y, + mode="markers", + hoverinfo="text", + text=text, + marker=dict( + size=self.iconSize, + color=colors, + colorscale="Viridis", + showscale=True + ) + ) + + edge_trace = go.Scatter( + x=edge_x, + y=edge_y, + line=dict(width=2, color="#888"), + hoverinfo="text", + text=edge_text, + mode="lines" + ) + + elif self.dim == 3: + # create the plot in 3d + node_trace = go.Scatter3d( + x=node_x, + y=node_y, + z=node_z, + mode="markers", + hoverinfo="text", + text=text, + marker=dict( + size=self.iconSize, + color=colors, + colorscale="Viridis", + showscale=True + ) + ) + + edge_trace = go.Scatter3d( + x=edge_x, + y=edge_y, + z=edge_z, + line=dict(width=1, color="#888"), + hoverinfo="text", + text=edge_text, + mode="lines", + opacity=0.6 + ) + + # create the plotly figure + fig = go.Figure(data=[edge_trace, node_trace]) + # render earth / sphere in 3d + if self.dim == 3 and displayEarth: + try: + import numpy as np + R = 6369.9 # sphere radius + u = np.linspace(0, 2 * np.pi, 50) # azimuthal angle + v = np.linspace(0, np.pi, 50) # polar angle + u, v = np.meshgrid(u, v) + + # Cartesian coordinates + x = R * np.cos(u) * np.sin(v) + y = R * np.sin(u) * np.sin(v) + z = R * np.cos(v) + except ImportError: + raise ImportError("numpy is required to display the earth") + + sphere_surface = go.Surface( + x=x, y=y, z=z, + colorscale='Blues', + opacity=1, + showscale=False, + hoverinfo='skip' + ) + + fig.add_trace(sphere_surface) + + fig.update_layout(title="Interactive Graph", showlegend=False, hovermode="closest") + fig.show() + + @staticmethod + def degreesToCartesian3D(coords): + try: + import torch + C = torch.tensor(coords) + if C.dim() == 1: + C = C.unsqueeze(0) + R = 6371.0 + lat = torch.deg2rad(C[:, 0]) + lng = torch.deg2rad(C[:, 1]) + x = R * torch.cos(lat) * torch.cos(lng) + y = R * torch.cos(lat) * torch.sin(lng) + z = R * torch.sin(lat) + return list(torch.stack((x, y, z), dim=1).numpy()) + except ImportError: + import math + R = 6371.0 + output = [] + for lat, lng in coords: + lat = math.radians(lat) + lng = math.radians(lng) + x = R * math.cos(lat) * math.cos(lng) + y = R * math.cos(lat) * math.sin(lng) + z = R * math.sin(lat) + output.append([x, y, z]) + return output + + @staticmethod + def curvedEdges(start, end, R=6371.0, H=0.05, n=20): + try: + # if torch and np are available calc vectorized graeter circle curves + import numpy as np + import torch + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + start_np = np.array(start, dtype=np.float32) + end_np = np.array(end, dtype=np.float32) + + start = torch.tensor(start_np, device=device) + end = torch.tensor(end_np, device=device) + start = start.float() + end = end.float() + + # normalize to sphere + start_norm = R * start / start.norm(dim=1, keepdim=True) + end_norm = R * end / end.norm(dim=1, keepdim=True) + + # compute angle between vectors + dot = (start_norm * end_norm).sum(dim=1, keepdim=True) / (R**2) + dot = torch.clamp(dot, -1.0, 1.0) + theta = torch.acos(dot).unsqueeze(2) # shape: (num_edges,1,1) + + # linear interpolation along great circle + t = torch.linspace(0, 1, n, device=device).view(1, n, 1) + one_minus_t = 1 - t + sin_theta = torch.sin(theta) + sin_theta[sin_theta == 0] = 1e-6 + + factor_start = torch.sin(one_minus_t * theta) / sin_theta + factor_end = torch.sin(t * theta) / sin_theta + + curve = factor_start * start_norm.unsqueeze(1) + factor_end * end_norm.unsqueeze(1) + + # normalize to radius + curve = R * curve / curve.norm(dim=2, keepdim=True) + + # apply radial lift at curve center using sin weight + weight = torch.sin(torch.pi * t) # 0 at endpoints, 1 at center + curve = curve * (1 + H * weight) + + return curve + except ImportError: + # fallback to calculating quadratic bezier curves with math + import math + curves_all = [] + + def multiply_vec(vec, factor): + return [factor * x for x in vec] + + def add_vec(*vecs): + return [sum(items) for items in zip(*vecs)] + + for startP, endP in zip(start, end): + mid = [(s + e) / 2 for s, e in zip(startP, endP)] + norm = math.sqrt(sum(c ** 2 for c in mid)) + mid_proj = [R * c / norm for c in mid] + mid_arch = [c * (1 + H) for c in mid_proj] + + curve = [] + for i in range(n): + t_i = i / (n - 1) + one_minus_t = 1 - t_i + point = add_vec( + multiply_vec(startP, one_minus_t ** 2), + multiply_vec(mid_arch, 2 * one_minus_t * t_i), + multiply_vec(endP, t_i ** 2) + ) + curve.append(point) + + curves_all.append(curve) + + return curves_all diff --git a/tests/unit/test_graphics_wrapper.py b/tests/unit/test_graphics_wrapper.py new file mode 100644 index 0000000..1c552d7 --- /dev/null +++ b/tests/unit/test_graphics_wrapper.py @@ -0,0 +1,84 @@ +import unittest +from unittest.mock import MagicMock +from multimodalrouter.graphics import GraphDisplay + + +class MockHub: + def __init__(self, hubType, id, coords): + self.hubType = hubType + self.id = id + self.coords = coords + self.outgoing = {} + + +class MockRouteGraph: + def __init__(self, hubs): + self._hubs = hubs + + def _allHubs(self): + return self._hubs + + def getHubById(self, hub_id): + for hub in self._hubs: + if hub.id == hub_id: + return hub + return None + + +class TestGraphDisplay(unittest.TestCase): + + def setUp(self): + # create mock hubs + self.hubs = [ + MockHub("cell", "0", [0, 0]), + MockHub("cell", "1", [1, 1]) + ] + # add an edge from 0 to 1 + edge = MagicMock() + edge.allMetrics = {"distance": 1} + self.hubs[0].outgoing = {"1": {"1": edge}} + self.graph = MockRouteGraph(self.hubs) + + self.display = GraphDisplay(self.graph) + + def test_node_transform(self): + # transform: add 1 to all coordinates + def nodeTransform(coords): + return [[x + 1, y + 1] for x, y, *rest in coords] + + self.display._toPlotlyFormat(nodeTransform=nodeTransform) + self.assertEqual(self.display.nodes["cell-0"]["coords"][:2], [1, 1]) + self.assertEqual(self.display.nodes["cell-1"]["coords"][:2], [2, 2]) + + def test_edge_transform(self): + # transform edges into straight lines + def edgeTransform(starts, ends): + return [[[s[0], s[1]], [e[0], e[1]]] for s, e in zip(starts, ends)] + + self.display._toPlotlyFormat(edgeTransform=edgeTransform) + self.assertIn("curve", self.display.edges[0]) + self.assertEqual(self.display.edges[0]["curve"][0], [0, 0]) + self.assertEqual(self.display.edges[0]["curve"][-1], [1, 1]) + + def test_degreesToCartesian3D(self): + coords = [[0, 0], [90, 0], [0, 90]] + result = GraphDisplay.degreesToCartesian3D(coords) + self.assertEqual(len(result), 3) + # first point should be on x-axis + self.assertAlmostEqual(result[0][0], 6371.0, places=0) + # second point should be at north pole + self.assertAlmostEqual(result[1][2], 6371.0, places=0) + + def test_curvedEdges(self): + start = [[1, 0, 0]] + end = [[0, 1, 0]] + curve = GraphDisplay.curvedEdges(start, end, R=1.0, H=0.0, n=5) + self.assertEqual(len(curve), 1) + self.assertEqual(len(curve[0]), 5) # n points + # first and last points match start/end + self.assertAlmostEqual(curve[0][0][0], 1.0, places=5) + self.assertAlmostEqual(curve[0][-1][1], 1.0, places=5) + + +if __name__ == "__main__": + unittest.main()