diff --git a/CHANGELOG.md b/CHANGELOG.md
index d80545e..1855689 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -13,6 +13,7 @@ All notable changes to this project will be documented in this file.
- Add example for Line Results to documentation [#341](https://github.com/ie3-institute/pypsdm/issues/341)
- Add colored Line Trace to plotting [#348](https://github.com/ie3-institute/pypsdm/issues/348)
- Add colored Node Trace to plotting [#349](https://github.com/ie3-institute/pypsdm/issues/349)
+- Add functionality to plot line traces without mapbox as vector graphic [#360](https://github.com/ie3-institute/pypsdm/issues/360)
### Changed
- Move `NBVAL` to dev dependencies [#374](https://github.com/ie3-institute/pypsdm/issues/374)
diff --git a/docs/nbs/plotting_utilities_trace_vector_graphic.ipynb b/docs/nbs/plotting_utilities_trace_vector_graphic.ipynb
new file mode 100644
index 0000000..253be2a
--- /dev/null
+++ b/docs/nbs/plotting_utilities_trace_vector_graphic.ipynb
@@ -0,0 +1,161 @@
+{
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "initial_id",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# NBVAL_SKIP\n",
+ "# Some jupyter notebook magic to reload modules automatically when they change\n",
+ "# not necessary for this specific notebook but useful in general\n",
+ "%load_ext autoreload\n",
+ "%autoreload 2\n",
+ "\n",
+ "# Gives you high resolution images within the notebook\n",
+ "%config InlineBackend.figure_format = 'retina'"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "e5d8cd1a4adf4f11",
+ "metadata": {},
+ "source": [
+ "# Plotting Line traces as Vector Graphic (High Quality)\n",
+ "Since plotting line and node traces on a map, e.g. OpenStreetMap, using `Scattermapbox` the output will be rendered not as vector graphic.\n",
+ "\n",
+ "Setting `use_mapbox = False` allows to use `Scatter` which will output as vector graphic and thus allows to save figures in .svg or .pdf format."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "3ed885bf22427a4a",
+ "metadata": {},
+ "source": [
+ "## Load Data"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "81678c770f5da613",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from definitions import ROOT_DIR\n",
+ "import os\n",
+ "\n",
+ "# The PSDM specific input models can be imported from the pypsdm.models.input and\n",
+ "# pypsdm.models.result. The `GridWithResults` container is located in pypsdm.models.gwr\n",
+ "from pypsdm.models.gwr import GridWithResults\n",
+ "\n",
+ "grid_path = os.path.join(ROOT_DIR, \"tests\", \"resources\", \"simbench\", \"input\")\n",
+ "result_path = os.path.join(ROOT_DIR, \"tests\", \"resources\", \"simbench\", \"results\")\n",
+ "\n",
+ "# IO data models in general have a from_csv method to parse psdm files\n",
+ "gwr = GridWithResults.from_csv(grid_path, result_path)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "ca40b9b6b739fa82",
+ "metadata": {},
+ "source": [
+ "## Get Line Results and Calculate Utilisation"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "275adf2cb67a9d37",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# NBVAL_CHECK_OUTPUT\n",
+ "line_input_data = gwr.lines\n",
+ "line_utilization = gwr.lines_res.utilisation(line_input_data, side=\"a\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "64515d7eab450b23",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# NBVAL_CHECK_OUTPUT\n",
+ "import pandas as pd\n",
+ "\n",
+ "specific_time = pd.to_datetime(\"2016-01-02 12:00:00\")\n",
+ "# filter for timestamp\n",
+ "filtered_data = line_utilization.loc[[specific_time]].to_dict()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "d8629a2a57260668",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from pypsdm.plots.grid import create_zoom_box, grid_plot\n",
+ "\n",
+ "# zoom_box allows to focus on certain parts of your plot\n",
+ "zoom_box = create_zoom_box(53.665, 11.35, 53.62, 11.38)\n",
+ "\n",
+ "# to remove the axes and lat / lon grid simply set show_axes = False or remove the parameter\n",
+ "fig_svg = grid_plot(\n",
+ " gwr.grid,\n",
+ " cmap_lines=\"Jet\",\n",
+ " cmap_line_values=filtered_data,\n",
+ " cbar_line_title=\"Line Utilisation\",\n",
+ " zoom_box=zoom_box,\n",
+ " show_axes=True,\n",
+ " use_mapbox=False,\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "b28859b912d9c22c",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "fig_svg"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "2ba2dbff4e8bd18a",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Plot can be saved as svg or other vector format\n",
+ "# fig_svg.write_image('save_as_svg.svg')"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 2
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython2"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
diff --git a/pypsdm/plots/grid.py b/pypsdm/plots/grid.py
index 63c3155..3040a31 100644
--- a/pypsdm/plots/grid.py
+++ b/pypsdm/plots/grid.py
@@ -29,9 +29,12 @@ def grid_plot(
cmap_node_values: Optional[Union[list, dict]] = None,
cbar_node_title: Optional[str] = None,
mapbox_style: Optional[str] = "open-street-map",
+ use_mapbox: bool = True, # New parameter to control mapbox vs regular scatter
+ zoom_box: Optional[dict] = None, # New parameter for zoom box
+ show_axes: bool = True, # New parameter to control axis visibility
) -> go.Figure:
"""
- Plots the grid on an OpenStreetMap. Supports Line and Node highlighting as well as colored map for line traces. Lines that are disconnected due to open switches will be grey.
+ Unified grid plotting function that supports both mapbox and SVG-friendly modes.
ATTENTION:
We currently consider the node_b of the switches to be the auxiliary switch node.
@@ -55,6 +58,12 @@ def grid_plot(
or dict mapping node IDs to values.
cbar_node_title (Optional[str]): Title for the node colorbar.
mapbox_style (Optional[str]): Mapbox style. Defaults to open-street-map.
+ use_mapbox (bool): Whether to use mapbox (True) or regular scatter plots (False).
+ When False, creates SVG-friendly plots. Defaults to True.
+ zoom_box (Optional[dict]): Dictionary with keys 'lat_min', 'lat_max', 'lon_min', 'lon_max'
+ to zoom to a specific bounding box. If None, auto-fits to grid data.
+ show_axes (bool): Whether to show axis labels, ticks, and grid (only for non-mapbox).
+ Defaults to True.
Returns:
Figure: Plotly figure.
"""
@@ -70,8 +79,8 @@ def grid_plot(
show_line_colorbar and cmap_lines is not None and cmap_nodes is not None
)
+ # Add colorbar background if needed
if (show_line_colorbar and cmap_lines is not None) or cmap_nodes is not None:
- # Plot white half transparent rectangle as background for color bars
x0_value = 0.85 if both_color_bars else 0.925
fig.add_shape(
type="rect",
@@ -83,6 +92,7 @@ def grid_plot(
line=dict(color="rgba(255, 255, 255, 0.0)"),
)
+ # Process lines with colormap
if cmap_lines and cmap_line_values is not None:
try:
value_dict, cmin, cmax = _process_colormap_values(
@@ -100,79 +110,28 @@ def grid_plot(
value_dict=value_dict,
cbar_title=cbar_line_title,
show_colorbar=show_line_colorbar,
+ use_mapbox=use_mapbox,
),
axis=1, # type: ignore
)
if show_line_colorbar:
- custom_colorscale = [
- [i / 10, f"rgb({int(255 * (i / 10))},0,{int(255 * (1 - i / 10))})"]
- for i in range(11)
- ]
- lons, lats = _get_lons_lats(grid.lines.geo_position.iloc[0])
-
- # Add a separate trace for line colorbar (using a single point)
- fig.add_trace(
- go.Scattermapbox(
- mode="markers",
- lon=[lons[0]],
- lat=[lats[0]],
- marker=dict(
- size=0.1,
- opacity=0,
- color="#008000",
- colorscale=(
- custom_colorscale
- if cmap_lines == "fixed_line_rating_scale"
- else cmap_lines
- ),
- cmin=(
- cmin if not cmap_lines == "fixed_line_rating_scale" else 0.0
- ),
- cmax=(
- cmax if not cmap_lines == "fixed_line_rating_scale" else 1.0
- ),
- colorbar=dict(
- title=dict(
- text=cbar_line_title or "Line Value",
- font=dict(
- size=12,
- weight="normal",
- style="normal",
- color="#000000",
- ),
- ),
- x=x0_value,
- tickvals=(
- [i / 10 for i in range(11)]
- if cmap_lines == "fixed_line_rating_scale"
- else None
- ),
- ticktext=(
- [f"{round(i / 10.0, 2)}" for i in range(11)]
- if cmap_lines == "fixed_line_rating_scale"
- else None
- ),
- thickness=15,
- len=0.85,
- tickfont=dict(
- size=12,
- weight="normal",
- style="normal",
- color="#000000",
- ),
- ),
- showscale=True,
- ),
- hoverinfo="skip",
- showlegend=False,
- )
+ _add_line_colorbar(
+ fig, grid, cmap_lines, cmin, cmax, cbar_line_title, x0_value, use_mapbox
)
else:
connected_lines.data.apply(
- lambda line: _add_line_trace(fig, line, is_disconnected=False, highlights=line_highlights), axis=1 # type: ignore
+ lambda line: _add_line_trace(
+ fig,
+ line,
+ is_disconnected=False,
+ highlights=line_highlights,
+ use_mapbox=use_mapbox,
+ ),
+ axis=1,
)
+ # Add disconnected lines
disconnected_lines.data.apply(
lambda line: _add_line_trace(
fig,
@@ -180,10 +139,12 @@ def grid_plot(
is_disconnected=True,
highlights=line_highlights,
highlight_disconnected=highlight_disconnected,
- ), # type: ignore
+ use_mapbox=use_mapbox,
+ ),
axis=1,
)
+ # Add nodes
if cmap_nodes and cmap_node_values is not None:
_add_node_trace(
fig,
@@ -192,42 +153,18 @@ def grid_plot(
cmap=cmap_nodes,
cmap_node_values=cmap_node_values,
cbar_node_title=cbar_node_title,
+ use_mapbox=use_mapbox,
)
-
else:
- _add_node_trace(fig, grid, highlights=node_highlights)
-
- center_lat = grid.raw_grid.nodes.data["latitude"].mean()
- center_lon = grid.raw_grid.nodes.data["longitude"].mean()
+ _add_node_trace(fig, grid, highlights=node_highlights, use_mapbox=use_mapbox)
- # Dynamically calculate the zoom level
- lat_range = (
- grid.raw_grid.nodes.data["latitude"].max()
- - grid.raw_grid.nodes.data["latitude"].min()
- )
- lon_range = (
- grid.raw_grid.nodes.data["longitude"].max()
- - grid.raw_grid.nodes.data["longitude"].min()
- )
-
- zoom = 12 - max(lat_range, lon_range)
-
- fig.update_layout(
- # mapbox = {"zoom"=10},
- showlegend=False,
- mapbox_style=mapbox_style,
- margin={"r": 0, "t": 0, "l": 0, "b": 0},
- mapbox=dict(
- center=dict(lat=center_lat, lon=center_lon),
- zoom=zoom, # Adjust the zoom level as per the calculated heuristic
- style=mapbox_style,
- ),
- )
+ # Configure layout
+ _configure_layout(fig, grid, use_mapbox, mapbox_style, zoom_box, show_axes)
return fig
-def _process_colormap_values(cmap_vals: dict, cmap) -> (dict, float, float):
+def _process_colormap_values(cmap_vals: dict, cmap) -> tuple[dict, float, float]:
"""Process colormap values and return a dictionary with original values in case of fixed scale or one with normalized data."""
values = []
uuids = []
@@ -300,13 +237,10 @@ def _get_colormap_color(value, cmap):
color_str = colorscale[index]
rgb_string = color_str[1]
- # Remove 'rgb(' and ')' and split by commas
+
+ # Convert to hex
rgb_values = list(map(int, rgb_string[4:-1].split(",")))
- hex_string = "#%02x%02x%02x" % (
- int(rgb_values[0]),
- int(rgb_values[1]),
- int(rgb_values[2]),
- )
+ hex_string = "#%02x%02x%02x" % tuple(rgb_values)
return hex_string
@@ -320,8 +254,9 @@ def _add_line_trace(
value_dict: Optional[dict] = None,
cbar_title: Optional[str] = None,
show_colorbar: bool = True,
+ use_mapbox: bool = True,
):
- """Enhanced line trace function with colormap support."""
+ """Unified line trace function supporting both mapbox and regular scatter."""
lons, lats = _get_lons_lats(line_data.geo_position)
hover_text = line_data["id"]
@@ -331,74 +266,86 @@ def _add_line_trace(
colormap_value = None
line_id = line_data.name if hasattr(line_data, "name") else line_data["id"]
+
+ # Determine color
if not is_disconnected:
if cmap and value_dict and line_id in value_dict.keys():
value = value_dict[line_id]
colormap_value = _get_colormap_color(value, cmap)
- use_colorbar = True
+ hover_text += f"
{cbar_title or 'Value'}: {value:.3f}"
else:
colormap_value = "#008000"
- use_colorbar = False
# Check for highlights (overrides colormap)
if isinstance(highlights, dict):
for color, lines in highlights.items():
- if line_data.name in lines: # type: ignore
+ if line_data.name in lines:
line_color = rgb_to_hex(color)
highlighted = True
- use_colorbar = False
elif highlights is not None:
if line_data.name in highlights:
line_color = rgb_to_hex(RED)
highlighted = True
- use_colorbar = False
# Handle disconnected lines
if (highlight_disconnected is False) and is_disconnected:
- # Highlights override the disconnected status
if not highlighted:
line_color = rgb_to_hex(GREY)
- use_colorbar = False
- if cmap and colormap_value is not None:
- hover_text += f"
{cbar_title or 'Value'}: {value:.3f}"
+ # Use colormap color if not highlighted
+ if cmap and colormap_value is not None and not highlighted:
+ line_color = colormap_value
- # Add the lines with or without colorbar
- line_color_to_use = (
- colormap_value
- if colormap_value is not None and use_colorbar and show_colorbar is not None
- else line_color
- )
-
- fig.add_trace(
- go.Scattermapbox(
- mode="lines",
- lon=lons,
- lat=lats,
- hoverinfo="skip", # Skip hoverinfo for the lines
- line=dict(color=line_color_to_use, width=2),
- showlegend=False,
+ # Add the line trace
+ if use_mapbox:
+ fig.add_trace(
+ go.Scattermapbox(
+ mode="lines",
+ lon=lons,
+ lat=lats,
+ hoverinfo="skip",
+ line=dict(color=line_color, width=2),
+ showlegend=False,
+ )
)
- )
-
- # Create a LineString object from the line's coordinates
- line = LineString(zip(lons, lats))
-
- # Calculate the midpoint on the line based on distance
- midpoint = line.interpolate(line.length / 2)
-
- # Add a transparent marker at the midpoint of the line for hover text
- fig.add_trace(
- go.Scattermapbox(
- mode="markers",
- lon=[midpoint.x],
- lat=[midpoint.y],
- hoverinfo="text",
- hovertext=hover_text,
- marker=dict(size=0, opacity=0, color=line_color),
- showlegend=False,
+ # Add hover point at midpoint
+ line = LineString(zip(lons, lats))
+ midpoint = line.interpolate(line.length / 2)
+ fig.add_trace(
+ go.Scattermapbox(
+ mode="markers",
+ lon=[midpoint.x],
+ lat=[midpoint.y],
+ hoverinfo="text",
+ hovertext=hover_text,
+ marker=dict(size=0, opacity=0, color=line_color),
+ showlegend=False,
+ )
+ )
+ else:
+ fig.add_trace(
+ go.Scatter(
+ mode="lines",
+ x=lons,
+ y=lats,
+ hoverinfo="skip",
+ line=dict(color=line_color, width=2),
+ showlegend=False,
+ )
+ )
+ # Add hover point at midpoint
+ midpoint_idx = len(lons) // 2
+ fig.add_trace(
+ go.Scatter(
+ mode="markers",
+ x=[lons[midpoint_idx]],
+ y=[lats[midpoint_idx]],
+ hoverinfo="text",
+ hovertext=hover_text,
+ marker=dict(size=0, opacity=0, color=line_color),
+ showlegend=False,
+ )
)
- )
def _add_node_trace(
@@ -408,41 +355,23 @@ def _add_node_trace(
cmap: Optional[str] = None,
cmap_node_values: Optional[dict] = None,
cbar_node_title: Optional[str] = None,
+ use_mapbox: bool = True,
):
- """
- Node trace function with colormap support.
-
- Args:
- fig (go.Figure): The Plotly figure object.
- grid (GridContainer): The grid container holding node data.
- highlights (Optional): Highlights nodes. Defaults to None.
- List of uuids or dict[(r, g, b), str] with colors.
- cmap (Optional[str]): Name of a colormap (e.g., 'Viridis', 'Jet', etc.).
- cmap_node_values (Optional[dict]): Dictionary mapping node IDs to values for colormap.
- cbar_node_title (Optional[str]): Title for the colorbar.
+ """Unified node trace function supporting both mapbox and regular scatter."""
- Returns:
- Updates the given figure object with node traces and optional colorbar.
- """
-
- # Hover text generation
def to_hover_text_nodes(node: pd.Series):
hover_text = f"ID: {node.id}
"
-
if cmap_node_values is not None:
voltage_magnitude = cmap_node_values.get(node.name)
if voltage_magnitude is not None:
voltage_magnitude_str = f"{round(voltage_magnitude, 5)} pu"
hover_text += f"Voltage Magnitude: {voltage_magnitude_str}
"
-
hover_text += (
f"Latitude: {node['latitude']:.6f}
"
f"Longitude: {node['longitude']:.6f}"
)
-
return hover_text
- # Determine colors based on either highlights or cmap
def _get_node_color(node_uuid):
if highlights is not None:
# Handle explicit highlights first
@@ -453,14 +382,13 @@ def _get_node_color(node_uuid):
elif isinstance(highlights, list) and node_uuid in highlights:
return rgb_to_hex(RED) # Default highlight color is red
- # Handle colormap-based coloring
if (
cmap is not None
and cmap_node_values is not None
and node_uuid in cmap_node_values.keys()
):
value = cmap_node_values[node_uuid]
- # Normalize values between 0-1
+ cmin, cmax = 0.9, 1.1 # Fixed range for nodes
normalized_value = (value - cmin) / (cmax - cmin) if cmax != cmin else 0.5
return _get_colormap_color(normalized_value, cmap)
@@ -468,74 +396,291 @@ def _get_node_color(node_uuid):
nodes_data = grid.raw_grid.nodes.data
+ # Add colorbar if needed
if cmap and cmap_node_values is not None:
- cmin = 0.9
- cmax = 1.1
+ _add_node_colorbar(fig, nodes_data, cmap, cbar_node_title, use_mapbox)
- # Create a custom colorscale for the colorbar
- custom_colorscale = px.colors.get_colorscale(cmap)
- # Add a separate trace for colorbar
+ hover_texts = nodes_data.apply(
+ lambda node_data: to_hover_text_nodes(node_data), axis=1
+ )
+
+ color_list = []
+ for node_uuid in nodes_data.index:
+ color = _get_node_color(node_uuid)
+ color_list.append(color)
+
+ # Add the node trace
+ if use_mapbox:
fig.add_trace(
go.Scattermapbox(
mode="markers",
- lon=[nodes_data["longitude"][0]],
- lat=[nodes_data["latitude"][0]],
- marker=dict(
- size=0.1,
- opacity=0,
- colorscale=custom_colorscale,
- cmin=0.9,
- cmax=1.1,
- colorbar=dict(
- title=dict(
- text=cbar_node_title or "Node Value",
- font=dict(size=12, color="#000000"),
- ),
- x=0.925,
- tickvals=([0.9 + i * 2 / 100 for i in range(11)]),
- ticktext=([f"{round(0.9 + i*2 / 100, 2)}" for i in range(11)]),
- thickness=10,
- len=0.85,
- tickfont=dict(
- size=12, weight="normal", style="normal", color="#000000"
- ),
- ),
- ),
+ lon=nodes_data["longitude"],
+ lat=nodes_data["latitude"],
+ hovertext=hover_texts,
+ hoverinfo="text",
+ marker=dict(size=8, color=color_list),
+ showlegend=False,
+ )
+ )
+ else:
+ fig.add_trace(
+ go.Scatter(
+ mode="markers",
+ x=nodes_data["longitude"],
+ y=nodes_data["latitude"],
+ hovertext=hover_texts,
+ hoverinfo="text",
+ marker=dict(size=8, color=color_list),
+ showlegend=False,
+ )
+ )
+
+
+def _add_line_colorbar(
+ fig, grid, cmap_lines, cmin, cmax, cbar_line_title, x0_value, use_mapbox
+):
+ """Add line colorbar for both mapbox and regular plots."""
+ custom_colorscale = [
+ [i / 10, f"rgb({int(255 * (i / 10))},0,{int(255 * (1 - i / 10))})"]
+ for i in range(11)
+ ]
+
+ colorbar_config = dict(
+ title=dict(
+ text=cbar_line_title or "Line Value",
+ font=dict(size=12, color="#000000"),
+ ),
+ x=x0_value,
+ thickness=15,
+ len=0.85,
+ tickfont=dict(size=12, color="#000000"),
+ )
+
+ # Add tick configuration for fixed scale
+ if cmap_lines == "fixed_line_rating_scale":
+ colorbar_config.update(
+ {
+ "tickvals": [i / 10 for i in range(11)],
+ "ticktext": [f"{round(i / 10.0, 2)}" for i in range(11)],
+ }
+ )
+
+ marker_config = dict(
+ size=0.1,
+ opacity=0,
+ color="#008000",
+ colorscale=(
+ custom_colorscale if cmap_lines == "fixed_line_rating_scale" else cmap_lines
+ ),
+ cmin=(cmin if cmap_lines != "fixed_line_rating_scale" else 0.0),
+ cmax=(cmax if cmap_lines != "fixed_line_rating_scale" else 1.0),
+ colorbar=colorbar_config,
+ showscale=True,
+ )
+
+ if use_mapbox:
+ lons, lats = _get_lons_lats(grid.lines.geo_position.iloc[0])
+ fig.add_trace(
+ go.Scattermapbox(
+ mode="markers",
+ lon=[lons[0]],
+ lat=[lats[0]],
+ marker=marker_config,
+ hoverinfo="skip",
+ showlegend=False,
+ )
+ )
+ else:
+ nodes_data = grid.raw_grid.nodes.data
+ fig.add_trace(
+ go.Scatter(
+ mode="markers",
+ x=[nodes_data["longitude"].mean()],
+ y=[nodes_data["latitude"].mean()],
+ marker=marker_config,
hoverinfo="skip",
showlegend=False,
)
)
- hover_texts = nodes_data.apply(
- lambda node_data: to_hover_text_nodes(node_data), axis=1
+
+def _add_node_colorbar(fig, nodes_data, cmap, cbar_node_title, use_mapbox):
+ """Add node colorbar for both mapbox and regular plots."""
+ custom_colorscale = px.colors.get_colorscale(cmap)
+
+ colorbar_config = dict(
+ title=dict(
+ text=cbar_node_title or "Node Value",
+ font=dict(size=12, color="#000000"),
+ ),
+ x=0.925,
+ thickness=10,
+ len=0.85,
+ tickvals=[0.9 + i * 2 / 100 for i in range(11)],
+ ticktext=[f"{round(0.9 + i * 2 / 100, 2)}" for i in range(11)],
+ tickfont=dict(size=12, color="#000000"),
)
- node_colors = {}
- for _, node_data in nodes_data.iterrows():
- node_colors[node_data.name] = _get_node_color(node_data.name)
+ marker_config = dict(
+ size=0.1,
+ opacity=0,
+ colorscale=custom_colorscale,
+ cmin=0.9,
+ cmax=1.1,
+ colorbar=colorbar_config,
+ )
+
+ if use_mapbox:
+ fig.add_trace(
+ go.Scattermapbox(
+ mode="markers",
+ lon=[nodes_data["longitude"].mean()],
+ lat=[nodes_data["latitude"].mean()],
+ marker=marker_config,
+ hoverinfo="skip",
+ showlegend=False,
+ )
+ )
+ else:
+ fig.add_trace(
+ go.Scatter(
+ mode="markers",
+ x=[nodes_data["longitude"].mean()],
+ y=[nodes_data["latitude"].mean()],
+ marker=marker_config,
+ hoverinfo="skip",
+ showlegend=False,
+ )
+ )
- # Create a color list based on the ID column in nodes_data
- color_list = []
- for node_uuid in nodes_data.index:
- color = node_colors.get(
- node_uuid, rgb_to_hex(BLUE)
- ) # Default to blue if no color found
- color_list.append(color)
- fig.add_trace(
- go.Scattermapbox(
- mode="markers",
- lon=nodes_data["longitude"],
- lat=nodes_data["latitude"],
- hovertext=hover_texts,
- hoverinfo="text",
- marker=dict(size=8, color=color_list),
+def _configure_layout(fig, grid, use_mapbox, mapbox_style, zoom_box, show_axes):
+ """Configure the figure layout based on mapbox usage and zoom settings."""
+ # Determine zoom/extent settings
+ if zoom_box is not None:
+ lat_min, lat_max = zoom_box["lat_min"], zoom_box["lat_max"]
+ lon_min, lon_max = zoom_box["lon_min"], zoom_box["lon_max"]
+ center_lat = (lat_min + lat_max) / 2
+ center_lon = (lon_min + lon_max) / 2
+ lat_range = lat_max - lat_min
+ lon_range = lon_max - lon_min
+ else:
+ lat_min = grid.raw_grid.nodes.data["latitude"].min()
+ lat_max = grid.raw_grid.nodes.data["latitude"].max()
+ lon_min = grid.raw_grid.nodes.data["longitude"].min()
+ lon_max = grid.raw_grid.nodes.data["longitude"].max()
+ center_lat = grid.raw_grid.nodes.data["latitude"].mean()
+ center_lon = grid.raw_grid.nodes.data["longitude"].mean()
+ lat_range = lat_max - lat_min
+ lon_range = lon_max - lon_min
+
+ if use_mapbox:
+ # Mapbox layout
+ zoom = (
+ 14.5 - max(lat_range, lon_range)
+ if zoom_box
+ else 12 - max(lat_range, lon_range)
+ )
+ fig.update_layout(
showlegend=False,
+ mapbox_style=mapbox_style,
+ margin={"r": 0, "t": 0, "l": 0, "b": 0},
+ mapbox=dict(
+ center=dict(lat=center_lat, lon=center_lon),
+ zoom=zoom,
+ style=mapbox_style,
+ ),
+ )
+ else:
+ # Regular scatter plot layout
+ padding_lat = lat_range * 0.05 if lat_range > 0 else 0.001
+ padding_lon = lon_range * 0.05 if lon_range > 0 else 0.001
+
+ if show_axes:
+ xaxis_config = dict(
+ title="Longitude",
+ showgrid=True,
+ gridcolor="lightgray",
+ gridwidth=0.5,
+ range=[lon_min - padding_lon, lon_max + padding_lon],
+ showline=True,
+ linecolor="black",
+ linewidth=1,
+ ticks="outside",
+ showticklabels=True,
+ )
+ yaxis_config = dict(
+ title="Latitude",
+ showgrid=True,
+ gridcolor="lightgray",
+ gridwidth=0.5,
+ scaleanchor="x",
+ scaleratio=1,
+ range=[lat_min - padding_lat, lat_max + padding_lat],
+ showline=True,
+ linecolor="black",
+ linewidth=1,
+ ticks="outside",
+ showticklabels=True,
+ )
+ else:
+ xaxis_config = dict(
+ title="",
+ showgrid=False,
+ range=[lon_min - padding_lon, lon_max + padding_lon],
+ showline=False,
+ ticks="",
+ showticklabels=False,
+ zeroline=False,
+ )
+ yaxis_config = dict(
+ title="",
+ showgrid=False,
+ scaleanchor="x",
+ scaleratio=1,
+ range=[lat_min - padding_lat, lat_max + padding_lat],
+ showline=False,
+ ticks="",
+ showticklabels=False,
+ zeroline=False,
+ )
+
+ fig.update_layout(
+ showlegend=False,
+ margin={"r": 0, "t": 0, "l": 0, "b": 0},
+ xaxis=xaxis_config,
+ yaxis=yaxis_config,
+ plot_bgcolor="white",
)
- )
def _get_lons_lats(geojson: str):
"""Extract longitude and latitude coordinates from GeoJSON string."""
coordinates = json.loads(geojson)["coordinates"]
return list(zip(*coordinates)) # returns lons, lats
+
+
+def create_zoom_box(
+ upper_left_lat: float,
+ upper_left_lon: float,
+ bottom_right_lat: float,
+ bottom_right_lon: float,
+) -> dict:
+ """
+ Create a zoom box dictionary from center coordinates and span.
+
+ Args:
+ center_lat: Center latitude
+ center_lon: Center longitude
+ lat_span: Total latitude span (degrees)
+ lon_span: Total longitude span (degrees)
+
+ Returns:
+ Dictionary with lat_min, lat_max, lon_min, lon_max keys
+ """
+ return {
+ "lat_min": bottom_right_lat,
+ "lat_max": upper_left_lat,
+ "lon_min": upper_left_lon,
+ "lon_max": bottom_right_lon,
+ }
diff --git a/tests/docs/nbs/test_notebooks.py b/tests/docs/nbs/test_notebooks.py
index 4bd2bd1..12af54a 100644
--- a/tests/docs/nbs/test_notebooks.py
+++ b/tests/docs/nbs/test_notebooks.py
@@ -27,6 +27,7 @@ def test_notebook_only_for_errors_and_explicit_cell_checks():
ROOT_DIR + "/docs/nbs/plotting_utilities_colormap_lines.ipynb",
ROOT_DIR + "/docs/nbs/plotting_utilities_colormap_nodes.ipynb",
ROOT_DIR + "/docs/nbs/plotting_utilities_colormap_lines_and_nodes.ipynb",
+ ROOT_DIR + "/docs/nbs/plotting_utilities_trace_vector_graphic.ipynb",
]
exit_code = pytest.main(args)