diff --git a/CHANGELOG.md b/CHANGELOG.md index d80545e..1855689 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -13,6 +13,7 @@ All notable changes to this project will be documented in this file. - Add example for Line Results to documentation [#341](https://github.com/ie3-institute/pypsdm/issues/341) - Add colored Line Trace to plotting [#348](https://github.com/ie3-institute/pypsdm/issues/348) - Add colored Node Trace to plotting [#349](https://github.com/ie3-institute/pypsdm/issues/349) +- Add functionality to plot line traces without mapbox as vector graphic [#360](https://github.com/ie3-institute/pypsdm/issues/360) ### Changed - Move `NBVAL` to dev dependencies [#374](https://github.com/ie3-institute/pypsdm/issues/374) diff --git a/docs/nbs/plotting_utilities_trace_vector_graphic.ipynb b/docs/nbs/plotting_utilities_trace_vector_graphic.ipynb new file mode 100644 index 0000000..253be2a --- /dev/null +++ b/docs/nbs/plotting_utilities_trace_vector_graphic.ipynb @@ -0,0 +1,161 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "id": "initial_id", + "metadata": {}, + "outputs": [], + "source": [ + "# NBVAL_SKIP\n", + "# Some jupyter notebook magic to reload modules automatically when they change\n", + "# not necessary for this specific notebook but useful in general\n", + "%load_ext autoreload\n", + "%autoreload 2\n", + "\n", + "# Gives you high resolution images within the notebook\n", + "%config InlineBackend.figure_format = 'retina'" + ] + }, + { + "cell_type": "markdown", + "id": "e5d8cd1a4adf4f11", + "metadata": {}, + "source": [ + "# Plotting Line traces as Vector Graphic (High Quality)\n", + "Since plotting line and node traces on a map, e.g. OpenStreetMap, using `Scattermapbox` the output will be rendered not as vector graphic.\n", + "\n", + "Setting `use_mapbox = False` allows to use `Scatter` which will output as vector graphic and thus allows to save figures in .svg or .pdf format." + ] + }, + { + "cell_type": "markdown", + "id": "3ed885bf22427a4a", + "metadata": {}, + "source": [ + "## Load Data" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "81678c770f5da613", + "metadata": {}, + "outputs": [], + "source": [ + "from definitions import ROOT_DIR\n", + "import os\n", + "\n", + "# The PSDM specific input models can be imported from the pypsdm.models.input and\n", + "# pypsdm.models.result. The `GridWithResults` container is located in pypsdm.models.gwr\n", + "from pypsdm.models.gwr import GridWithResults\n", + "\n", + "grid_path = os.path.join(ROOT_DIR, \"tests\", \"resources\", \"simbench\", \"input\")\n", + "result_path = os.path.join(ROOT_DIR, \"tests\", \"resources\", \"simbench\", \"results\")\n", + "\n", + "# IO data models in general have a from_csv method to parse psdm files\n", + "gwr = GridWithResults.from_csv(grid_path, result_path)" + ] + }, + { + "cell_type": "markdown", + "id": "ca40b9b6b739fa82", + "metadata": {}, + "source": [ + "## Get Line Results and Calculate Utilisation" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "275adf2cb67a9d37", + "metadata": {}, + "outputs": [], + "source": [ + "# NBVAL_CHECK_OUTPUT\n", + "line_input_data = gwr.lines\n", + "line_utilization = gwr.lines_res.utilisation(line_input_data, side=\"a\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "64515d7eab450b23", + "metadata": {}, + "outputs": [], + "source": [ + "# NBVAL_CHECK_OUTPUT\n", + "import pandas as pd\n", + "\n", + "specific_time = pd.to_datetime(\"2016-01-02 12:00:00\")\n", + "# filter for timestamp\n", + "filtered_data = line_utilization.loc[[specific_time]].to_dict()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d8629a2a57260668", + "metadata": {}, + "outputs": [], + "source": [ + "from pypsdm.plots.grid import create_zoom_box, grid_plot\n", + "\n", + "# zoom_box allows to focus on certain parts of your plot\n", + "zoom_box = create_zoom_box(53.665, 11.35, 53.62, 11.38)\n", + "\n", + "# to remove the axes and lat / lon grid simply set show_axes = False or remove the parameter\n", + "fig_svg = grid_plot(\n", + " gwr.grid,\n", + " cmap_lines=\"Jet\",\n", + " cmap_line_values=filtered_data,\n", + " cbar_line_title=\"Line Utilisation\",\n", + " zoom_box=zoom_box,\n", + " show_axes=True,\n", + " use_mapbox=False,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b28859b912d9c22c", + "metadata": {}, + "outputs": [], + "source": [ + "fig_svg" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2ba2dbff4e8bd18a", + "metadata": {}, + "outputs": [], + "source": [ + "# Plot can be saved as svg or other vector format\n", + "# fig_svg.write_image('save_as_svg.svg')" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 2 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython2" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/pypsdm/plots/grid.py b/pypsdm/plots/grid.py index 63c3155..3040a31 100644 --- a/pypsdm/plots/grid.py +++ b/pypsdm/plots/grid.py @@ -29,9 +29,12 @@ def grid_plot( cmap_node_values: Optional[Union[list, dict]] = None, cbar_node_title: Optional[str] = None, mapbox_style: Optional[str] = "open-street-map", + use_mapbox: bool = True, # New parameter to control mapbox vs regular scatter + zoom_box: Optional[dict] = None, # New parameter for zoom box + show_axes: bool = True, # New parameter to control axis visibility ) -> go.Figure: """ - Plots the grid on an OpenStreetMap. Supports Line and Node highlighting as well as colored map for line traces. Lines that are disconnected due to open switches will be grey. + Unified grid plotting function that supports both mapbox and SVG-friendly modes. ATTENTION: We currently consider the node_b of the switches to be the auxiliary switch node. @@ -55,6 +58,12 @@ def grid_plot( or dict mapping node IDs to values. cbar_node_title (Optional[str]): Title for the node colorbar. mapbox_style (Optional[str]): Mapbox style. Defaults to open-street-map. + use_mapbox (bool): Whether to use mapbox (True) or regular scatter plots (False). + When False, creates SVG-friendly plots. Defaults to True. + zoom_box (Optional[dict]): Dictionary with keys 'lat_min', 'lat_max', 'lon_min', 'lon_max' + to zoom to a specific bounding box. If None, auto-fits to grid data. + show_axes (bool): Whether to show axis labels, ticks, and grid (only for non-mapbox). + Defaults to True. Returns: Figure: Plotly figure. """ @@ -70,8 +79,8 @@ def grid_plot( show_line_colorbar and cmap_lines is not None and cmap_nodes is not None ) + # Add colorbar background if needed if (show_line_colorbar and cmap_lines is not None) or cmap_nodes is not None: - # Plot white half transparent rectangle as background for color bars x0_value = 0.85 if both_color_bars else 0.925 fig.add_shape( type="rect", @@ -83,6 +92,7 @@ def grid_plot( line=dict(color="rgba(255, 255, 255, 0.0)"), ) + # Process lines with colormap if cmap_lines and cmap_line_values is not None: try: value_dict, cmin, cmax = _process_colormap_values( @@ -100,79 +110,28 @@ def grid_plot( value_dict=value_dict, cbar_title=cbar_line_title, show_colorbar=show_line_colorbar, + use_mapbox=use_mapbox, ), axis=1, # type: ignore ) if show_line_colorbar: - custom_colorscale = [ - [i / 10, f"rgb({int(255 * (i / 10))},0,{int(255 * (1 - i / 10))})"] - for i in range(11) - ] - lons, lats = _get_lons_lats(grid.lines.geo_position.iloc[0]) - - # Add a separate trace for line colorbar (using a single point) - fig.add_trace( - go.Scattermapbox( - mode="markers", - lon=[lons[0]], - lat=[lats[0]], - marker=dict( - size=0.1, - opacity=0, - color="#008000", - colorscale=( - custom_colorscale - if cmap_lines == "fixed_line_rating_scale" - else cmap_lines - ), - cmin=( - cmin if not cmap_lines == "fixed_line_rating_scale" else 0.0 - ), - cmax=( - cmax if not cmap_lines == "fixed_line_rating_scale" else 1.0 - ), - colorbar=dict( - title=dict( - text=cbar_line_title or "Line Value", - font=dict( - size=12, - weight="normal", - style="normal", - color="#000000", - ), - ), - x=x0_value, - tickvals=( - [i / 10 for i in range(11)] - if cmap_lines == "fixed_line_rating_scale" - else None - ), - ticktext=( - [f"{round(i / 10.0, 2)}" for i in range(11)] - if cmap_lines == "fixed_line_rating_scale" - else None - ), - thickness=15, - len=0.85, - tickfont=dict( - size=12, - weight="normal", - style="normal", - color="#000000", - ), - ), - showscale=True, - ), - hoverinfo="skip", - showlegend=False, - ) + _add_line_colorbar( + fig, grid, cmap_lines, cmin, cmax, cbar_line_title, x0_value, use_mapbox ) else: connected_lines.data.apply( - lambda line: _add_line_trace(fig, line, is_disconnected=False, highlights=line_highlights), axis=1 # type: ignore + lambda line: _add_line_trace( + fig, + line, + is_disconnected=False, + highlights=line_highlights, + use_mapbox=use_mapbox, + ), + axis=1, ) + # Add disconnected lines disconnected_lines.data.apply( lambda line: _add_line_trace( fig, @@ -180,10 +139,12 @@ def grid_plot( is_disconnected=True, highlights=line_highlights, highlight_disconnected=highlight_disconnected, - ), # type: ignore + use_mapbox=use_mapbox, + ), axis=1, ) + # Add nodes if cmap_nodes and cmap_node_values is not None: _add_node_trace( fig, @@ -192,42 +153,18 @@ def grid_plot( cmap=cmap_nodes, cmap_node_values=cmap_node_values, cbar_node_title=cbar_node_title, + use_mapbox=use_mapbox, ) - else: - _add_node_trace(fig, grid, highlights=node_highlights) - - center_lat = grid.raw_grid.nodes.data["latitude"].mean() - center_lon = grid.raw_grid.nodes.data["longitude"].mean() + _add_node_trace(fig, grid, highlights=node_highlights, use_mapbox=use_mapbox) - # Dynamically calculate the zoom level - lat_range = ( - grid.raw_grid.nodes.data["latitude"].max() - - grid.raw_grid.nodes.data["latitude"].min() - ) - lon_range = ( - grid.raw_grid.nodes.data["longitude"].max() - - grid.raw_grid.nodes.data["longitude"].min() - ) - - zoom = 12 - max(lat_range, lon_range) - - fig.update_layout( - # mapbox = {"zoom"=10}, - showlegend=False, - mapbox_style=mapbox_style, - margin={"r": 0, "t": 0, "l": 0, "b": 0}, - mapbox=dict( - center=dict(lat=center_lat, lon=center_lon), - zoom=zoom, # Adjust the zoom level as per the calculated heuristic - style=mapbox_style, - ), - ) + # Configure layout + _configure_layout(fig, grid, use_mapbox, mapbox_style, zoom_box, show_axes) return fig -def _process_colormap_values(cmap_vals: dict, cmap) -> (dict, float, float): +def _process_colormap_values(cmap_vals: dict, cmap) -> tuple[dict, float, float]: """Process colormap values and return a dictionary with original values in case of fixed scale or one with normalized data.""" values = [] uuids = [] @@ -300,13 +237,10 @@ def _get_colormap_color(value, cmap): color_str = colorscale[index] rgb_string = color_str[1] - # Remove 'rgb(' and ')' and split by commas + + # Convert to hex rgb_values = list(map(int, rgb_string[4:-1].split(","))) - hex_string = "#%02x%02x%02x" % ( - int(rgb_values[0]), - int(rgb_values[1]), - int(rgb_values[2]), - ) + hex_string = "#%02x%02x%02x" % tuple(rgb_values) return hex_string @@ -320,8 +254,9 @@ def _add_line_trace( value_dict: Optional[dict] = None, cbar_title: Optional[str] = None, show_colorbar: bool = True, + use_mapbox: bool = True, ): - """Enhanced line trace function with colormap support.""" + """Unified line trace function supporting both mapbox and regular scatter.""" lons, lats = _get_lons_lats(line_data.geo_position) hover_text = line_data["id"] @@ -331,74 +266,86 @@ def _add_line_trace( colormap_value = None line_id = line_data.name if hasattr(line_data, "name") else line_data["id"] + + # Determine color if not is_disconnected: if cmap and value_dict and line_id in value_dict.keys(): value = value_dict[line_id] colormap_value = _get_colormap_color(value, cmap) - use_colorbar = True + hover_text += f"
{cbar_title or 'Value'}: {value:.3f}" else: colormap_value = "#008000" - use_colorbar = False # Check for highlights (overrides colormap) if isinstance(highlights, dict): for color, lines in highlights.items(): - if line_data.name in lines: # type: ignore + if line_data.name in lines: line_color = rgb_to_hex(color) highlighted = True - use_colorbar = False elif highlights is not None: if line_data.name in highlights: line_color = rgb_to_hex(RED) highlighted = True - use_colorbar = False # Handle disconnected lines if (highlight_disconnected is False) and is_disconnected: - # Highlights override the disconnected status if not highlighted: line_color = rgb_to_hex(GREY) - use_colorbar = False - if cmap and colormap_value is not None: - hover_text += f"
{cbar_title or 'Value'}: {value:.3f}" + # Use colormap color if not highlighted + if cmap and colormap_value is not None and not highlighted: + line_color = colormap_value - # Add the lines with or without colorbar - line_color_to_use = ( - colormap_value - if colormap_value is not None and use_colorbar and show_colorbar is not None - else line_color - ) - - fig.add_trace( - go.Scattermapbox( - mode="lines", - lon=lons, - lat=lats, - hoverinfo="skip", # Skip hoverinfo for the lines - line=dict(color=line_color_to_use, width=2), - showlegend=False, + # Add the line trace + if use_mapbox: + fig.add_trace( + go.Scattermapbox( + mode="lines", + lon=lons, + lat=lats, + hoverinfo="skip", + line=dict(color=line_color, width=2), + showlegend=False, + ) ) - ) - - # Create a LineString object from the line's coordinates - line = LineString(zip(lons, lats)) - - # Calculate the midpoint on the line based on distance - midpoint = line.interpolate(line.length / 2) - - # Add a transparent marker at the midpoint of the line for hover text - fig.add_trace( - go.Scattermapbox( - mode="markers", - lon=[midpoint.x], - lat=[midpoint.y], - hoverinfo="text", - hovertext=hover_text, - marker=dict(size=0, opacity=0, color=line_color), - showlegend=False, + # Add hover point at midpoint + line = LineString(zip(lons, lats)) + midpoint = line.interpolate(line.length / 2) + fig.add_trace( + go.Scattermapbox( + mode="markers", + lon=[midpoint.x], + lat=[midpoint.y], + hoverinfo="text", + hovertext=hover_text, + marker=dict(size=0, opacity=0, color=line_color), + showlegend=False, + ) + ) + else: + fig.add_trace( + go.Scatter( + mode="lines", + x=lons, + y=lats, + hoverinfo="skip", + line=dict(color=line_color, width=2), + showlegend=False, + ) + ) + # Add hover point at midpoint + midpoint_idx = len(lons) // 2 + fig.add_trace( + go.Scatter( + mode="markers", + x=[lons[midpoint_idx]], + y=[lats[midpoint_idx]], + hoverinfo="text", + hovertext=hover_text, + marker=dict(size=0, opacity=0, color=line_color), + showlegend=False, + ) ) - ) def _add_node_trace( @@ -408,41 +355,23 @@ def _add_node_trace( cmap: Optional[str] = None, cmap_node_values: Optional[dict] = None, cbar_node_title: Optional[str] = None, + use_mapbox: bool = True, ): - """ - Node trace function with colormap support. - - Args: - fig (go.Figure): The Plotly figure object. - grid (GridContainer): The grid container holding node data. - highlights (Optional): Highlights nodes. Defaults to None. - List of uuids or dict[(r, g, b), str] with colors. - cmap (Optional[str]): Name of a colormap (e.g., 'Viridis', 'Jet', etc.). - cmap_node_values (Optional[dict]): Dictionary mapping node IDs to values for colormap. - cbar_node_title (Optional[str]): Title for the colorbar. + """Unified node trace function supporting both mapbox and regular scatter.""" - Returns: - Updates the given figure object with node traces and optional colorbar. - """ - - # Hover text generation def to_hover_text_nodes(node: pd.Series): hover_text = f"ID: {node.id}
" - if cmap_node_values is not None: voltage_magnitude = cmap_node_values.get(node.name) if voltage_magnitude is not None: voltage_magnitude_str = f"{round(voltage_magnitude, 5)} pu" hover_text += f"Voltage Magnitude: {voltage_magnitude_str}
" - hover_text += ( f"Latitude: {node['latitude']:.6f}
" f"Longitude: {node['longitude']:.6f}" ) - return hover_text - # Determine colors based on either highlights or cmap def _get_node_color(node_uuid): if highlights is not None: # Handle explicit highlights first @@ -453,14 +382,13 @@ def _get_node_color(node_uuid): elif isinstance(highlights, list) and node_uuid in highlights: return rgb_to_hex(RED) # Default highlight color is red - # Handle colormap-based coloring if ( cmap is not None and cmap_node_values is not None and node_uuid in cmap_node_values.keys() ): value = cmap_node_values[node_uuid] - # Normalize values between 0-1 + cmin, cmax = 0.9, 1.1 # Fixed range for nodes normalized_value = (value - cmin) / (cmax - cmin) if cmax != cmin else 0.5 return _get_colormap_color(normalized_value, cmap) @@ -468,74 +396,291 @@ def _get_node_color(node_uuid): nodes_data = grid.raw_grid.nodes.data + # Add colorbar if needed if cmap and cmap_node_values is not None: - cmin = 0.9 - cmax = 1.1 + _add_node_colorbar(fig, nodes_data, cmap, cbar_node_title, use_mapbox) - # Create a custom colorscale for the colorbar - custom_colorscale = px.colors.get_colorscale(cmap) - # Add a separate trace for colorbar + hover_texts = nodes_data.apply( + lambda node_data: to_hover_text_nodes(node_data), axis=1 + ) + + color_list = [] + for node_uuid in nodes_data.index: + color = _get_node_color(node_uuid) + color_list.append(color) + + # Add the node trace + if use_mapbox: fig.add_trace( go.Scattermapbox( mode="markers", - lon=[nodes_data["longitude"][0]], - lat=[nodes_data["latitude"][0]], - marker=dict( - size=0.1, - opacity=0, - colorscale=custom_colorscale, - cmin=0.9, - cmax=1.1, - colorbar=dict( - title=dict( - text=cbar_node_title or "Node Value", - font=dict(size=12, color="#000000"), - ), - x=0.925, - tickvals=([0.9 + i * 2 / 100 for i in range(11)]), - ticktext=([f"{round(0.9 + i*2 / 100, 2)}" for i in range(11)]), - thickness=10, - len=0.85, - tickfont=dict( - size=12, weight="normal", style="normal", color="#000000" - ), - ), - ), + lon=nodes_data["longitude"], + lat=nodes_data["latitude"], + hovertext=hover_texts, + hoverinfo="text", + marker=dict(size=8, color=color_list), + showlegend=False, + ) + ) + else: + fig.add_trace( + go.Scatter( + mode="markers", + x=nodes_data["longitude"], + y=nodes_data["latitude"], + hovertext=hover_texts, + hoverinfo="text", + marker=dict(size=8, color=color_list), + showlegend=False, + ) + ) + + +def _add_line_colorbar( + fig, grid, cmap_lines, cmin, cmax, cbar_line_title, x0_value, use_mapbox +): + """Add line colorbar for both mapbox and regular plots.""" + custom_colorscale = [ + [i / 10, f"rgb({int(255 * (i / 10))},0,{int(255 * (1 - i / 10))})"] + for i in range(11) + ] + + colorbar_config = dict( + title=dict( + text=cbar_line_title or "Line Value", + font=dict(size=12, color="#000000"), + ), + x=x0_value, + thickness=15, + len=0.85, + tickfont=dict(size=12, color="#000000"), + ) + + # Add tick configuration for fixed scale + if cmap_lines == "fixed_line_rating_scale": + colorbar_config.update( + { + "tickvals": [i / 10 for i in range(11)], + "ticktext": [f"{round(i / 10.0, 2)}" for i in range(11)], + } + ) + + marker_config = dict( + size=0.1, + opacity=0, + color="#008000", + colorscale=( + custom_colorscale if cmap_lines == "fixed_line_rating_scale" else cmap_lines + ), + cmin=(cmin if cmap_lines != "fixed_line_rating_scale" else 0.0), + cmax=(cmax if cmap_lines != "fixed_line_rating_scale" else 1.0), + colorbar=colorbar_config, + showscale=True, + ) + + if use_mapbox: + lons, lats = _get_lons_lats(grid.lines.geo_position.iloc[0]) + fig.add_trace( + go.Scattermapbox( + mode="markers", + lon=[lons[0]], + lat=[lats[0]], + marker=marker_config, + hoverinfo="skip", + showlegend=False, + ) + ) + else: + nodes_data = grid.raw_grid.nodes.data + fig.add_trace( + go.Scatter( + mode="markers", + x=[nodes_data["longitude"].mean()], + y=[nodes_data["latitude"].mean()], + marker=marker_config, hoverinfo="skip", showlegend=False, ) ) - hover_texts = nodes_data.apply( - lambda node_data: to_hover_text_nodes(node_data), axis=1 + +def _add_node_colorbar(fig, nodes_data, cmap, cbar_node_title, use_mapbox): + """Add node colorbar for both mapbox and regular plots.""" + custom_colorscale = px.colors.get_colorscale(cmap) + + colorbar_config = dict( + title=dict( + text=cbar_node_title or "Node Value", + font=dict(size=12, color="#000000"), + ), + x=0.925, + thickness=10, + len=0.85, + tickvals=[0.9 + i * 2 / 100 for i in range(11)], + ticktext=[f"{round(0.9 + i * 2 / 100, 2)}" for i in range(11)], + tickfont=dict(size=12, color="#000000"), ) - node_colors = {} - for _, node_data in nodes_data.iterrows(): - node_colors[node_data.name] = _get_node_color(node_data.name) + marker_config = dict( + size=0.1, + opacity=0, + colorscale=custom_colorscale, + cmin=0.9, + cmax=1.1, + colorbar=colorbar_config, + ) + + if use_mapbox: + fig.add_trace( + go.Scattermapbox( + mode="markers", + lon=[nodes_data["longitude"].mean()], + lat=[nodes_data["latitude"].mean()], + marker=marker_config, + hoverinfo="skip", + showlegend=False, + ) + ) + else: + fig.add_trace( + go.Scatter( + mode="markers", + x=[nodes_data["longitude"].mean()], + y=[nodes_data["latitude"].mean()], + marker=marker_config, + hoverinfo="skip", + showlegend=False, + ) + ) - # Create a color list based on the ID column in nodes_data - color_list = [] - for node_uuid in nodes_data.index: - color = node_colors.get( - node_uuid, rgb_to_hex(BLUE) - ) # Default to blue if no color found - color_list.append(color) - fig.add_trace( - go.Scattermapbox( - mode="markers", - lon=nodes_data["longitude"], - lat=nodes_data["latitude"], - hovertext=hover_texts, - hoverinfo="text", - marker=dict(size=8, color=color_list), +def _configure_layout(fig, grid, use_mapbox, mapbox_style, zoom_box, show_axes): + """Configure the figure layout based on mapbox usage and zoom settings.""" + # Determine zoom/extent settings + if zoom_box is not None: + lat_min, lat_max = zoom_box["lat_min"], zoom_box["lat_max"] + lon_min, lon_max = zoom_box["lon_min"], zoom_box["lon_max"] + center_lat = (lat_min + lat_max) / 2 + center_lon = (lon_min + lon_max) / 2 + lat_range = lat_max - lat_min + lon_range = lon_max - lon_min + else: + lat_min = grid.raw_grid.nodes.data["latitude"].min() + lat_max = grid.raw_grid.nodes.data["latitude"].max() + lon_min = grid.raw_grid.nodes.data["longitude"].min() + lon_max = grid.raw_grid.nodes.data["longitude"].max() + center_lat = grid.raw_grid.nodes.data["latitude"].mean() + center_lon = grid.raw_grid.nodes.data["longitude"].mean() + lat_range = lat_max - lat_min + lon_range = lon_max - lon_min + + if use_mapbox: + # Mapbox layout + zoom = ( + 14.5 - max(lat_range, lon_range) + if zoom_box + else 12 - max(lat_range, lon_range) + ) + fig.update_layout( showlegend=False, + mapbox_style=mapbox_style, + margin={"r": 0, "t": 0, "l": 0, "b": 0}, + mapbox=dict( + center=dict(lat=center_lat, lon=center_lon), + zoom=zoom, + style=mapbox_style, + ), + ) + else: + # Regular scatter plot layout + padding_lat = lat_range * 0.05 if lat_range > 0 else 0.001 + padding_lon = lon_range * 0.05 if lon_range > 0 else 0.001 + + if show_axes: + xaxis_config = dict( + title="Longitude", + showgrid=True, + gridcolor="lightgray", + gridwidth=0.5, + range=[lon_min - padding_lon, lon_max + padding_lon], + showline=True, + linecolor="black", + linewidth=1, + ticks="outside", + showticklabels=True, + ) + yaxis_config = dict( + title="Latitude", + showgrid=True, + gridcolor="lightgray", + gridwidth=0.5, + scaleanchor="x", + scaleratio=1, + range=[lat_min - padding_lat, lat_max + padding_lat], + showline=True, + linecolor="black", + linewidth=1, + ticks="outside", + showticklabels=True, + ) + else: + xaxis_config = dict( + title="", + showgrid=False, + range=[lon_min - padding_lon, lon_max + padding_lon], + showline=False, + ticks="", + showticklabels=False, + zeroline=False, + ) + yaxis_config = dict( + title="", + showgrid=False, + scaleanchor="x", + scaleratio=1, + range=[lat_min - padding_lat, lat_max + padding_lat], + showline=False, + ticks="", + showticklabels=False, + zeroline=False, + ) + + fig.update_layout( + showlegend=False, + margin={"r": 0, "t": 0, "l": 0, "b": 0}, + xaxis=xaxis_config, + yaxis=yaxis_config, + plot_bgcolor="white", ) - ) def _get_lons_lats(geojson: str): """Extract longitude and latitude coordinates from GeoJSON string.""" coordinates = json.loads(geojson)["coordinates"] return list(zip(*coordinates)) # returns lons, lats + + +def create_zoom_box( + upper_left_lat: float, + upper_left_lon: float, + bottom_right_lat: float, + bottom_right_lon: float, +) -> dict: + """ + Create a zoom box dictionary from center coordinates and span. + + Args: + center_lat: Center latitude + center_lon: Center longitude + lat_span: Total latitude span (degrees) + lon_span: Total longitude span (degrees) + + Returns: + Dictionary with lat_min, lat_max, lon_min, lon_max keys + """ + return { + "lat_min": bottom_right_lat, + "lat_max": upper_left_lat, + "lon_min": upper_left_lon, + "lon_max": bottom_right_lon, + } diff --git a/tests/docs/nbs/test_notebooks.py b/tests/docs/nbs/test_notebooks.py index 4bd2bd1..12af54a 100644 --- a/tests/docs/nbs/test_notebooks.py +++ b/tests/docs/nbs/test_notebooks.py @@ -27,6 +27,7 @@ def test_notebook_only_for_errors_and_explicit_cell_checks(): ROOT_DIR + "/docs/nbs/plotting_utilities_colormap_lines.ipynb", ROOT_DIR + "/docs/nbs/plotting_utilities_colormap_nodes.ipynb", ROOT_DIR + "/docs/nbs/plotting_utilities_colormap_lines_and_nodes.ipynb", + ROOT_DIR + "/docs/nbs/plotting_utilities_trace_vector_graphic.ipynb", ] exit_code = pytest.main(args)