diff --git a/.flake8 b/.flake8
index 62afefc..dee98cc 100644
--- a/.flake8
+++ b/.flake8
@@ -1,4 +1,4 @@
[flake8]
max-line-length = 130
exclude = .git,__pycache__,build,dist
-ignore = E261
\ No newline at end of file
+ignore = E261, W504
\ No newline at end of file
diff --git a/NOTICE.md b/NOTICE.md
new file mode 100644
index 0000000..2954973
--- /dev/null
+++ b/NOTICE.md
@@ -0,0 +1,44 @@
+# Dependencies and Licenses
+
+This project `MultiModalRouter` depends on the following libraries. All licenses are permissive and compatible with MIT licensing for this project.
+
+| Package | Version | License | License Link |
+|---------|---------|---------|--------------|
+| colorama | >=0.4.6 | BSD 3-Clause | [License](https://github.com/tartley/colorama/blob/master/LICENSE) |
+| dill | >=0.4.0 | BSD | [License](https://github.com/uqfoundation/dill/blob/main/LICENSE) |
+| filelock | >=3.19.1 | MIT | [License](https://github.com/tox-dev/py-filelock/blob/main/LICENSE) |
+| fsspec | >=2025.9.0 | Apache 2.0 | [License](https://github.com/fsspec/filesystem_spec/blob/main/LICENSE) |
+| Jinja2 | >=3.1.6 | BSD-3-Clause | [License](https://github.com/pallets/jinja/blob/main/LICENSE) |
+| MarkupSafe | >=3.0.2 | BSD-3-Clause | [License](https://github.com/pallets/markupsafe/blob/main/LICENSE) |
+| mpmath | >=1.3.0 | BSD | [License](https://github.com/fredrik-johansson/mpmath/blob/master/LICENSE) |
+| networkx | >=3.5 | BSD | [License](https://github.com/networkx/networkx/blob/main/LICENSE.txt) |
+| numpy | >=2.3.3 | BSD | [License](https://github.com/numpy/numpy/blob/main/LICENSE.txt) |
+| pandas | >=2.3.2 | BSD-3-Clause | [License](https://github.com/pandas-dev/pandas/blob/main/LICENSE) |
+| parquet | >=1.3.1 | Apache 2.0 | [License](https://github.com/urschrei/parquet-python/blob/master/LICENSE) |
+| ply | >=3.11 | BSD | [License](https://github.com/dabeaz/ply/blob/master/LICENSE.txt) |
+| pyarrow | >=21.0.0 | Apache 2.0 | [License](https://github.com/apache/arrow/blob/master/LICENSE) |
+| python-dateutil | >=2.9.0.post0 | BSD | [License](https://github.com/dateutil/dateutil/blob/master/LICENSE.txt) |
+| pytz | >=2025.2 | MIT | [License](https://github.com/stub42/pytz/blob/master/LICENSE) |
+| setuptools | >=80.9.0 | MIT | [License](https://github.com/pypa/setuptools/blob/main/LICENSE) |
+| six | >=1.17.0 | MIT | [License](https://github.com/benjaminp/six/blob/master/LICENSE) |
+| sympy | >=1.14.0 | BSD | [License](https://github.com/sympy/sympy/blob/master/LICENSE) |
+| thriftpy2 | >=0.5.3 | MIT | [License](https://github.com/Thriftpy/thriftpy2/blob/master/LICENSE) |
+| tqdm | >=4.67.1 | MPL 2.0 | [License](https://github.com/tqdm/tqdm/blob/master/LICENSE) |
+| typing_extensions | >=4.15.0 | PSF | [License](https://github.com/python/typing_extensions/blob/main/LICENSE) |
+| tzdata | >=2025.2 | Public Domain | [License](https://github.com/python/tzdata) |
+
+## Optional Dependencies
+
+| Package | Version | License | License Link |
+|---------|---------|---------|--------------|
+| torch | >=2.8.0 | BSD | [License](https://github.com/pytorch/pytorch/blob/master/LICENSE) |
+| plotly | >=6.3.0 | MIT | [License](https://github.com/plotly/plotly.py/blob/master/LICENSE) |
+| pytest | >=8.0 | MIT | [License](https://github.com/pytest-dev/pytest/blob/main/LICENSE) |
+
+---
+
+### Notes
+
+1. All packages listed above are permissively licensed (MIT, BSD, Apache 2.0, or Public Domain), so they are compatible with MIT licensing for this project.
+2. If distributing this library, include this `DEPENDENCIES.md` file and your own MIT license file to give proper attribution.
+3. Optional dependencies should be listed in documentation or `pyproject.toml` extras.
diff --git a/README.md b/README.md
index a566373..076df20 100644
--- a/README.md
+++ b/README.md
@@ -46,6 +46,12 @@ The graph can be build from any data aslong as the required fields are present (

+## graph visualizations
+
+Use the build-in [visualization](./docs/visualization.md) tool to plot any `2D` or `3D` Graph.
+
+
+
## Important considerations for your usecase
Depending on your usecase and datasets some features may not be usable see solutions below
@@ -67,4 +73,6 @@ Depending on your usecase and datasets some features may not be usable see solut
[see here](./LICENSE.md)
+[dependencies](./NOTICE.md)
+
diff --git a/docs/FlightPathPlot.png b/docs/FlightPathPlot.png
new file mode 100644
index 0000000..4e54641
Binary files /dev/null and b/docs/FlightPathPlot.png differ
diff --git a/docs/examples/flightRouter/main.py b/docs/examples/flightRouter/main.py
index 2402ee8..a0a66f3 100644
--- a/docs/examples/flightRouter/main.py
+++ b/docs/examples/flightRouter/main.py
@@ -5,6 +5,7 @@
from multimodalrouter import RouteGraph
import os
+
def main():
path = os.path.dirname(os.path.abspath(__file__))
# initialize the graph
@@ -17,21 +18,21 @@ def main():
# build the graph
graph.build()
# set start and end points
- start = [60.866699,-162.272996] # Atmautluak Airport
- end = [60.872747,-162.5247] #Kasigluk Airport
+ start = [60.866699, -162.272996] # Atmautluak Airport
+ end = [60.872747, -162.5247] # Kasigluk Airport
start_hub = graph.findClosestHub(["airport"], start) # find the hubs
end_hub = graph.findClosestHub(["airport"], end)
# find the route
route = graph.find_shortest_path(
- start_hub.id,
+ start_hub.id,
end_hub.id,
- allowed_modes=["plane","car"],
+ allowed_modes=["plane", "car"],
verbose=True
- )
+ )
# print the route
print(route.flatPath if route else "No route found")
if __name__ == "__main__":
- main()
\ No newline at end of file
+ main()
diff --git a/docs/examples/flightRouter/plot.py b/docs/examples/flightRouter/plot.py
new file mode 100644
index 0000000..d0ef7e9
--- /dev/null
+++ b/docs/examples/flightRouter/plot.py
@@ -0,0 +1,25 @@
+# dataclasses.py
+# Copyright (c) 2025 Tobias Karusseit
+# Licensed under the MIT License. See LICENSE file in the project root for full license information.
+
+
+from multimodalrouter import RouteGraph
+from multimodalrouter.graphics import GraphDisplay
+import os
+
+if __name__ == "__main__":
+ path = os.path.dirname(os.path.abspath(__file__))
+ graph = RouteGraph(
+ maxDistance=50,
+ transportModes={"airport": "fly", },
+ dataPaths={"airport": os.path.join(path, "data", "fullDataset.csv")},
+ compressed=False,
+ )
+
+ graph.build()
+ display = GraphDisplay(graph)
+ display.display(
+ displayEarth=True,
+ nodeTransform=GraphDisplay.degreesToCartesian3D,
+ edgeTransform=GraphDisplay.curvedEdges
+ )
diff --git a/docs/examples/mazePathfinder/data/createMaze.py b/docs/examples/mazePathfinder/data/createMaze.py
index 1477da4..3612a09 100644
--- a/docs/examples/mazePathfinder/data/createMaze.py
+++ b/docs/examples/mazePathfinder/data/createMaze.py
@@ -5,15 +5,17 @@
import random
import pandas as pd
+
# simple cell class for the maze
class Cell:
def __init__(self, x, y):
- self.id = f"cell-{x,y}"
+ self.id = f"cell-{x, y}"
self.x = x
self.y = y
self.visited = False
self.connected = []
+
def main():
# init a 10x10 maze
mazeHeight = 10
@@ -53,7 +55,15 @@ def main():
cellStack.pop()
# init the dataframe
- data = pd.DataFrame(columns=["source", "destination", "distance", "source_lat", "source_lng", "destination_lat", "destination_lng"])
+ data = pd.DataFrame(columns=[
+ "source",
+ "destination",
+ "distance",
+ "source_lat",
+ "source_lng",
+ "destination_lat",
+ "destination_lng"
+ ])
# add the edges to the dataframe
for cell in cells:
for neighbor in cell.connected:
@@ -61,4 +71,6 @@ def main():
# save the dataframe
data.to_csv("docs/examples/mazePathfinder/data/maze.csv", index=False)
-if __name__ == "__main__": main()
\ No newline at end of file
+
+if __name__ == "__main__":
+ main()
diff --git a/docs/examples/mazePathfinder/main.py b/docs/examples/mazePathfinder/main.py
index d52ec9b..7b2775f 100644
--- a/docs/examples/mazePathfinder/main.py
+++ b/docs/examples/mazePathfinder/main.py
@@ -6,22 +6,26 @@
import os
import pandas as pd
+
def main():
try:
import matplotlib.pyplot as plt
except ImportError:
raise ImportError("matplotlib is not installed. Please install matplotlib to use this example.")
-
+
path = os.path.dirname(os.path.abspath(__file__))
# init the maze df for the plot
mazeDf = pd.read_csv(os.path.join(path, "data", "maze.csv"))
# init the plot
- plt.figure(figsize=(10,10))
+ plt.figure(figsize=(10, 10))
# draw the maze
+ # draw the maze (grid lines)
for _, row in mazeDf.iterrows():
- plt.plot([row.source_lat, row.destination_lat],
- [row.source_lng, row.destination_lng],
- "k-") # black line for edge
+ plt.plot(
+ [row.source_lng, row.destination_lng], # x = "lng" column
+ [row.source_lat, row.destination_lat], # y = "lat" column
+ "k-"
+ )
# initialize the graph
graph = RouteGraph(
@@ -35,7 +39,7 @@ def main():
graph.build()
# find the shortest route
route = graph.find_shortest_path(
- start_id="cell-(0, 0)",
+ start_id="cell-(0, 0)",
end_id="cell-(0, 9)",
allowed_modes=["walk"],
verbose=True,
@@ -49,11 +53,17 @@ def main():
if s_prev is not None:
h1 = graph.getHubById(s_prev)
h2 = graph.getHubById(s)
- plt.plot([h1.coords[0], h2.coords[0]],
- [h1.coords[1], h2.coords[1]],
- "b-")
+ # Swap coords so x=column, y=row
+ plt.plot(
+ [h1.coords[1], h2.coords[1]], # x-axis
+ [h1.coords[0], h2.coords[0]], # y-axis
+ "b-"
+ )
s_prev = s
+
# display the plot
plt.show()
-
-if __name__ == "__main__": main()
\ No newline at end of file
+
+
+if __name__ == "__main__":
+ main()
diff --git a/docs/examples/mazePathfinder/plot.py b/docs/examples/mazePathfinder/plot.py
new file mode 100644
index 0000000..02a2ab0
--- /dev/null
+++ b/docs/examples/mazePathfinder/plot.py
@@ -0,0 +1,32 @@
+# dataclasses.py
+# Copyright (c) 2025 Tobias Karusseit
+# Licensed under the MIT License. See LICENSE file in the project root for full license information.
+
+
+from multimodalrouter import RouteGraph
+from multimodalrouter.graphics import GraphDisplay
+import os
+
+
+# custom transform to make lat lng to x y (-> lng lat)
+def NodeTransform(coords):
+ for coord in coords:
+ yield list((coord[0], coord[1]))
+
+
+if __name__ == "__main__":
+ path = os.path.dirname(os.path.abspath(__file__))
+ # initialize the graph
+ graph = RouteGraph(
+ maxDistance=50,
+ transportModes={"cell": "walk", },
+ dataPaths={"cell": os.path.join(path, "data", "maze.csv")},
+ compressed=False,
+ drivingEnabled=False
+ )
+
+ graph.build()
+ # init the display
+ display = GraphDisplay(graph)
+ # display the graph (uses the transform to swap lat lng to x y)
+ display.display(nodeTransform=NodeTransform)
diff --git a/docs/visualization.md b/docs/visualization.md
new file mode 100644
index 0000000..f92d8e8
--- /dev/null
+++ b/docs/visualization.md
@@ -0,0 +1,108 @@
+[HOME](../README.md)
+
+# Graph Plotting
+
+Using the build-in graph plotting tool you can [plotly](https://plotly.com/python/) plot any graph in `2D` or `3D`, while defining [transformations](#transformations) for your coordiante space or even path curvature etc.
+
+## GraphDisplay
+
+```python
+def __init__(
+ self,
+ graph: RouteGraph,
+ name: str = "Graph",
+ iconSize: int = 10
+) -> None:
+```
+
+#### args:
+ - graph: RouteDisplay = the graph instance you want to plot
+ - name: str = (not in use at the moment)
+ - iconSize: int = the size of the nodes in the plot
+
+#### example
+
+```
+gd = GraphDisplay(myGraphInstance)
+```
+
+[flight path CODE example on sphere](./examples/flightRouter/plot.py)
+
+
+### display()
+
+The display function will collect data from your Graph and create a [plotly](https://plotly.com/python/) plot from it.
+
+```python
+def display(
+ self,
+ nodeTransform=None,
+ edgeTransform=None,
+ displayEarth=False
+):
+```
+
+#### args:
+
+- nodeTransform: function = a [transformation](#transformations) function that transformes all node coordinates
+- edgeTransform: funstion = a function that [transformes](#transformations) all your edges
+- displayEarth: bool = if True -> will display a sphere that (roughly) matches earth
+
+#### example:
+
+this call will create the plot for your graph while mapping all coords onto the surface of the earth
+
+```python
+gd.display(
+ nodeTransform = gd.degreesToCartesian3D,
+ displayEarth: True
+)
+```
+
+### transformations
+
+#### base function style
+
+IF you want to implement your own transformation function note that the call must adhere to the following parameters:
+
+```python
+def customNodeTrandsform(coords: list[list[float]]):
+ return list[list[float]]
+
+def customEdgeTransform(start: list[list[float]], end: list[list[float]]):
+ return list[list[list[float]]]
+```
+
+#### args
+
+- coords: list[list[float]] = a nested list of coordinates for all nodes
+- start: list[list[float]] = a nested list of all start coordinates
+- end: list[list[float]] = a nested list of all end coordinates
+
+#### returns:
+
+- list[list[float]] = a list of all transformed node coordinates
+- list[list[list[float]]] = a list of curves whare each curve / edge can have n points defining it
+
+### build-in Node Transforms:
+
+#### degreesToCartesian3D
+
+```python
+@staticmethod
+ def degreesToCartesian3D(coords):
+```
+This function maps any valid `2D` coordinates (best if in degrees) to spherical coords on the surface of earth
+
+### build-in Edge Transformations
+
+```python
+@staticmethod
+ def curvedEdges(start, end, R=6371.0, H=0.05, n=20):
+```
+
+curves edges for coordinates on spheres (here earth) so that the edges curve along the spherical surface with a curvature that places the midpoint of the curve at $H \dot R$ above the surface. (great for displaying flights).
+
+If torch is installed this will use great-circle distance for the curves
+
+> Note if torch is not installed this will fall back to using `math` with quadratic bezier curves -> some curves may end up inside the sphere to bezier inaccuracy
\ No newline at end of file
diff --git a/pyproject.toml b/pyproject.toml
index b48bd36..98d57cc 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
[project]
name = "multimodalrouter"
-version = "0.1.4"
+version = "0.1.5"
description = "A graph-based routing library for dynamic routing."
readme = "README.md"
license = { file = "LICENSE.md" }
@@ -46,12 +46,14 @@ Repository = "https://github.com/K-T0BIAS/MultiModalRouter"
[project.optional-dependencies]
torch = ["torch>=2.8.0"]
dev = [
- "pytest>=8.0"
+ "pytest>=8.0",
+ "plotly>=6.3.0"
]
+plotly = ["plotly>=6.3.0"]
[tool.setuptools]
package-dir = {"" = "src"}
-packages = ["multimodalrouter", "multimodalrouter.graph", "multimodalrouter.router", "multimodalrouter.utils"]
+packages = ["multimodalrouter", "multimodalrouter.graph", "multimodalrouter.router", "multimodalrouter.utils", "multimodalrouter.graphics"]
[project.scripts]
multiModalRouter-build = "multimodalrouter.router.build:main"
diff --git a/src/multimodalrouter/graph/dataclasses.py b/src/multimodalrouter/graph/dataclasses.py
index c59f49b..eb1e4ac 100644
--- a/src/multimodalrouter/graph/dataclasses.py
+++ b/src/multimodalrouter/graph/dataclasses.py
@@ -50,6 +50,7 @@ def __init__(self, coords: list[float], id: str, hubType: str):
self.coords: list[float] = coords
self.id = id
self.hubType = hubType
+ # dict like {mode -> {dest_id -> EdgeMetadata}}
self.outgoing: dict[str, dict[str, EdgeMetadata]] = {}
def addOutgoing(self, mode: str, dest_id: str, metrics: EdgeMetadata):
diff --git a/src/multimodalrouter/graphics/__init__.py b/src/multimodalrouter/graphics/__init__.py
new file mode 100644
index 0000000..a2f68fe
--- /dev/null
+++ b/src/multimodalrouter/graphics/__init__.py
@@ -0,0 +1 @@
+from .graphicsWrapper import GraphDisplay # noqa: F401
diff --git a/src/multimodalrouter/graphics/graphicsWrapper.py b/src/multimodalrouter/graphics/graphicsWrapper.py
new file mode 100644
index 0000000..40c0eb8
--- /dev/null
+++ b/src/multimodalrouter/graphics/graphicsWrapper.py
@@ -0,0 +1,323 @@
+# dataclasses.py
+# Copyright (c) 2025 Tobias Karusseit
+# Licensed under the MIT License. See LICENSE file in the project root for full license information.
+
+
+from ..graph import RouteGraph
+import plotly.graph_objects as go
+
+
+class GraphDisplay():
+
+ def __init__(self, graph: RouteGraph, name: str = "Graph", iconSize: int = 10) -> None:
+ self.graph: RouteGraph = graph
+ self.name: str = name
+ self.iconSize: int = iconSize
+
+ def _toPlotlyFormat(
+ self,
+ nodeTransform=None,
+ edgeTransform=None
+ ):
+ """
+ transform the graph data into plotly format.to use the display function
+
+ args:
+ - nodeTransform: function to transform the node coordinates (default = None)
+ - edgeTransform: function to transform the edge coordinates (default = None)
+ returns:
+ - None (modifies self.nodes and self.edges)
+ """
+ self.nodes = {
+ f"{hub.hubType}-{hub.id}": {
+ "coords": hub.coords,
+ "hubType": hub.hubType,
+ "id": hub.id
+ }
+ for hub in self.graph._allHubs()
+ }
+
+ self.edges = [
+ {
+ "from": f"{hub.hubType}-{hub.id}",
+ "to": f"{self.graph.getHubById(dest).hubType}-{dest}",
+ **edge.allMetrics
+ }
+ for hub in self.graph._allHubs()
+ for _, edge in hub.outgoing.items()
+ for dest, edge in edge.items()
+ ]
+ self.dim = max(len(node.get("coords")) for node in self.nodes.values())
+
+ if nodeTransform is not None:
+ expandedCoords = [node.get("coords") + [0] * (self.dim - len(node.get("coords"))) for node in self.nodes.values()]
+ transformedCoords = nodeTransform(expandedCoords)
+ for node, coords in zip(self.nodes.values(), transformedCoords):
+ node["coords"] = coords
+
+ self.dim = max(len(node.get("coords")) for node in self.nodes.values())
+
+ if edgeTransform is not None:
+ starts = [edge["from"] for edge in self.edges]
+ startCoords = [self.nodes[start]["coords"] for start in starts]
+ ends = [edge["to"] for edge in self.edges]
+ endCoords = [self.nodes[end]["coords"] for end in ends]
+
+ transformedEdges = edgeTransform(startCoords, endCoords)
+ for edge, transformedEdge in zip(self.edges, transformedEdges):
+ edge["curve"] = transformedEdge
+
+ def display(
+ self,
+ nodeTransform=None,
+ edgeTransform=None,
+ displayEarth=False
+ ):
+ """
+ function to display any 2D or 3D RouteGraph
+
+ args:
+ - nodeTransform: function to transform the node coordinates (default = None)
+ - edgeTransform: function to transform the edge coordinates (default = None)
+ - displayEarth: whether to display the earth as a background (default = False, only in 3D)
+
+ returns:
+ - None (modifies self.nodes and self.edges opens the plot in a browser)
+
+ """
+ # transform the graph
+ self._toPlotlyFormat(nodeTransform, edgeTransform)
+ # init plotly placeholders
+ node_x, node_y, node_z, text, colors = [], [], [], [], []
+ edge_x, edge_y, edge_z, edge_text = [], [], [], []
+
+ # add all the nodes
+ for node_key, node_data in self.nodes.items():
+ x, y, *rest = node_data["coords"]
+ node_x.append(x)
+ node_y.append(y)
+ if self.dim == 3:
+ node_z.append(node_data["coords"][2])
+ text.append(f"{node_data['id']}
Type: {node_data['hubType']}")
+ colors.append(hash(node_data['hubType']) % 10)
+
+ # add all the edges
+ for edge in self.edges:
+ # check if edge has been transformed
+ if "curve" in edge:
+ curve = edge["curve"]
+ # add all the points of the edge
+ for point in curve:
+ edge_x.append(point[0])
+ edge_y.append(point[1])
+ if self.dim == 3:
+ edge_z.append(point[2])
+ edge_x.append(None)
+ edge_y.append(None)
+ # if 3d add the extra none to close the edge
+ if self.dim == 3:
+ edge_z.append(None)
+ else:
+ source = self.nodes[edge["from"]]["coords"]
+ target = self.nodes[edge["to"]]["coords"]
+
+ edge_x += [source[0], target[0], None]
+ edge_y += [source[1], target[1], None]
+
+ if self.dim == 3:
+ edge_z += [source[2], target[2], None]
+
+ # add text and hover display
+ hover = f"{edge['from']} → {edge['to']}"
+ metrics = {k: v for k, v in edge.items() if k not in ("from", "to", "curve")}
+ if metrics:
+ hover += "
" + "
".join(f"{k}: {v}" for k, v in metrics.items())
+ edge_text.append(hover)
+
+ if self.dim == 2:
+ # ceate the plot in 2d
+ node_trace = go.Scatter(
+ x=node_x,
+ y=node_y,
+ mode="markers",
+ hoverinfo="text",
+ text=text,
+ marker=dict(
+ size=self.iconSize,
+ color=colors,
+ colorscale="Viridis",
+ showscale=True
+ )
+ )
+
+ edge_trace = go.Scatter(
+ x=edge_x,
+ y=edge_y,
+ line=dict(width=2, color="#888"),
+ hoverinfo="text",
+ text=edge_text,
+ mode="lines"
+ )
+
+ elif self.dim == 3:
+ # create the plot in 3d
+ node_trace = go.Scatter3d(
+ x=node_x,
+ y=node_y,
+ z=node_z,
+ mode="markers",
+ hoverinfo="text",
+ text=text,
+ marker=dict(
+ size=self.iconSize,
+ color=colors,
+ colorscale="Viridis",
+ showscale=True
+ )
+ )
+
+ edge_trace = go.Scatter3d(
+ x=edge_x,
+ y=edge_y,
+ z=edge_z,
+ line=dict(width=1, color="#888"),
+ hoverinfo="text",
+ text=edge_text,
+ mode="lines",
+ opacity=0.6
+ )
+
+ # create the plotly figure
+ fig = go.Figure(data=[edge_trace, node_trace])
+ # render earth / sphere in 3d
+ if self.dim == 3 and displayEarth:
+ try:
+ import numpy as np
+ R = 6369.9 # sphere radius
+ u = np.linspace(0, 2 * np.pi, 50) # azimuthal angle
+ v = np.linspace(0, np.pi, 50) # polar angle
+ u, v = np.meshgrid(u, v)
+
+ # Cartesian coordinates
+ x = R * np.cos(u) * np.sin(v)
+ y = R * np.sin(u) * np.sin(v)
+ z = R * np.cos(v)
+ except ImportError:
+ raise ImportError("numpy is required to display the earth")
+
+ sphere_surface = go.Surface(
+ x=x, y=y, z=z,
+ colorscale='Blues',
+ opacity=1,
+ showscale=False,
+ hoverinfo='skip'
+ )
+
+ fig.add_trace(sphere_surface)
+
+ fig.update_layout(title="Interactive Graph", showlegend=False, hovermode="closest")
+ fig.show()
+
+ @staticmethod
+ def degreesToCartesian3D(coords):
+ try:
+ import torch
+ C = torch.tensor(coords)
+ if C.dim() == 1:
+ C = C.unsqueeze(0)
+ R = 6371.0
+ lat = torch.deg2rad(C[:, 0])
+ lng = torch.deg2rad(C[:, 1])
+ x = R * torch.cos(lat) * torch.cos(lng)
+ y = R * torch.cos(lat) * torch.sin(lng)
+ z = R * torch.sin(lat)
+ return list(torch.stack((x, y, z), dim=1).numpy())
+ except ImportError:
+ import math
+ R = 6371.0
+ output = []
+ for lat, lng in coords:
+ lat = math.radians(lat)
+ lng = math.radians(lng)
+ x = R * math.cos(lat) * math.cos(lng)
+ y = R * math.cos(lat) * math.sin(lng)
+ z = R * math.sin(lat)
+ output.append([x, y, z])
+ return output
+
+ @staticmethod
+ def curvedEdges(start, end, R=6371.0, H=0.05, n=20):
+ try:
+ # if torch and np are available calc vectorized graeter circle curves
+ import numpy as np
+ import torch
+
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+ start_np = np.array(start, dtype=np.float32)
+ end_np = np.array(end, dtype=np.float32)
+
+ start = torch.tensor(start_np, device=device)
+ end = torch.tensor(end_np, device=device)
+ start = start.float()
+ end = end.float()
+
+ # normalize to sphere
+ start_norm = R * start / start.norm(dim=1, keepdim=True)
+ end_norm = R * end / end.norm(dim=1, keepdim=True)
+
+ # compute angle between vectors
+ dot = (start_norm * end_norm).sum(dim=1, keepdim=True) / (R**2)
+ dot = torch.clamp(dot, -1.0, 1.0)
+ theta = torch.acos(dot).unsqueeze(2) # shape: (num_edges,1,1)
+
+ # linear interpolation along great circle
+ t = torch.linspace(0, 1, n, device=device).view(1, n, 1)
+ one_minus_t = 1 - t
+ sin_theta = torch.sin(theta)
+ sin_theta[sin_theta == 0] = 1e-6
+
+ factor_start = torch.sin(one_minus_t * theta) / sin_theta
+ factor_end = torch.sin(t * theta) / sin_theta
+
+ curve = factor_start * start_norm.unsqueeze(1) + factor_end * end_norm.unsqueeze(1)
+
+ # normalize to radius
+ curve = R * curve / curve.norm(dim=2, keepdim=True)
+
+ # apply radial lift at curve center using sin weight
+ weight = torch.sin(torch.pi * t) # 0 at endpoints, 1 at center
+ curve = curve * (1 + H * weight)
+
+ return curve
+ except ImportError:
+ # fallback to calculating quadratic bezier curves with math
+ import math
+ curves_all = []
+
+ def multiply_vec(vec, factor):
+ return [factor * x for x in vec]
+
+ def add_vec(*vecs):
+ return [sum(items) for items in zip(*vecs)]
+
+ for startP, endP in zip(start, end):
+ mid = [(s + e) / 2 for s, e in zip(startP, endP)]
+ norm = math.sqrt(sum(c ** 2 for c in mid))
+ mid_proj = [R * c / norm for c in mid]
+ mid_arch = [c * (1 + H) for c in mid_proj]
+
+ curve = []
+ for i in range(n):
+ t_i = i / (n - 1)
+ one_minus_t = 1 - t_i
+ point = add_vec(
+ multiply_vec(startP, one_minus_t ** 2),
+ multiply_vec(mid_arch, 2 * one_minus_t * t_i),
+ multiply_vec(endP, t_i ** 2)
+ )
+ curve.append(point)
+
+ curves_all.append(curve)
+
+ return curves_all
diff --git a/tests/unit/test_graphics_wrapper.py b/tests/unit/test_graphics_wrapper.py
new file mode 100644
index 0000000..1c552d7
--- /dev/null
+++ b/tests/unit/test_graphics_wrapper.py
@@ -0,0 +1,84 @@
+import unittest
+from unittest.mock import MagicMock
+from multimodalrouter.graphics import GraphDisplay
+
+
+class MockHub:
+ def __init__(self, hubType, id, coords):
+ self.hubType = hubType
+ self.id = id
+ self.coords = coords
+ self.outgoing = {}
+
+
+class MockRouteGraph:
+ def __init__(self, hubs):
+ self._hubs = hubs
+
+ def _allHubs(self):
+ return self._hubs
+
+ def getHubById(self, hub_id):
+ for hub in self._hubs:
+ if hub.id == hub_id:
+ return hub
+ return None
+
+
+class TestGraphDisplay(unittest.TestCase):
+
+ def setUp(self):
+ # create mock hubs
+ self.hubs = [
+ MockHub("cell", "0", [0, 0]),
+ MockHub("cell", "1", [1, 1])
+ ]
+ # add an edge from 0 to 1
+ edge = MagicMock()
+ edge.allMetrics = {"distance": 1}
+ self.hubs[0].outgoing = {"1": {"1": edge}}
+ self.graph = MockRouteGraph(self.hubs)
+
+ self.display = GraphDisplay(self.graph)
+
+ def test_node_transform(self):
+ # transform: add 1 to all coordinates
+ def nodeTransform(coords):
+ return [[x + 1, y + 1] for x, y, *rest in coords]
+
+ self.display._toPlotlyFormat(nodeTransform=nodeTransform)
+ self.assertEqual(self.display.nodes["cell-0"]["coords"][:2], [1, 1])
+ self.assertEqual(self.display.nodes["cell-1"]["coords"][:2], [2, 2])
+
+ def test_edge_transform(self):
+ # transform edges into straight lines
+ def edgeTransform(starts, ends):
+ return [[[s[0], s[1]], [e[0], e[1]]] for s, e in zip(starts, ends)]
+
+ self.display._toPlotlyFormat(edgeTransform=edgeTransform)
+ self.assertIn("curve", self.display.edges[0])
+ self.assertEqual(self.display.edges[0]["curve"][0], [0, 0])
+ self.assertEqual(self.display.edges[0]["curve"][-1], [1, 1])
+
+ def test_degreesToCartesian3D(self):
+ coords = [[0, 0], [90, 0], [0, 90]]
+ result = GraphDisplay.degreesToCartesian3D(coords)
+ self.assertEqual(len(result), 3)
+ # first point should be on x-axis
+ self.assertAlmostEqual(result[0][0], 6371.0, places=0)
+ # second point should be at north pole
+ self.assertAlmostEqual(result[1][2], 6371.0, places=0)
+
+ def test_curvedEdges(self):
+ start = [[1, 0, 0]]
+ end = [[0, 1, 0]]
+ curve = GraphDisplay.curvedEdges(start, end, R=1.0, H=0.0, n=5)
+ self.assertEqual(len(curve), 1)
+ self.assertEqual(len(curve[0]), 5) # n points
+ # first and last points match start/end
+ self.assertAlmostEqual(curve[0][0][0], 1.0, places=5)
+ self.assertAlmostEqual(curve[0][-1][1], 1.0, places=5)
+
+
+if __name__ == "__main__":
+ unittest.main()