diff --git a/.gitignore b/.gitignore index 9319a34..4d09dd5 100644 --- a/.gitignore +++ b/.gitignore @@ -1,21 +1,113 @@ +# Project data artifacts *.geojson !tests/data/*.geojson *.psd -.ipynb_checkpoints + +# Cached or auto-generated files .DS_Store +.ipynb_checkpoints **/__pycache__/ .pytest_cache/ -.vscode/ -dev/ +.ruff_cache/ +.cache/ +htmlcov/ +cover/ +.coverage +.coverage* +coverage.xml +nosetests.xml +# Notebooks notebooks/** !notebooks/city2graph_demo.ipynb - docs/source/examples/** !docs/source/examples/**/*.ipynb -conda/conda-bld/ -docs/build/ +# Build and packaging outputs +.Python +build/ +dist/ +develop-eggs/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST -.coverage -.coverage* +# Python bytecode / compiled extensions +*.py[cod] +*$py.class +*.so + +# Virtual environments +.env +.envrc +.venv +venv/ +env.bak/ +venv.bak/ +.pdm-python +.pdm-build/ +.pixi + +# Dependency managers (lockfiles optional) +#uv.lock +#Pipfile.lock +#poetry.lock +#poetry.toml +#pdm.lock +#pdm.toml +#pixi.lock + +# Tooling +.pypirc +.pybuilder/ +target/ +.tox/ +.nox/ +.hypothesis/ +.pyre/ +.pytype/ +.mypy_cache/ +.dmypy.json +dmypy.json +cython_debug/ + +# Logs and installer leftovers +pip-log.txt +pip-delete-this-directory.txt + +# Misc project files +local_settings.py +*.log +db.sqlite3 +db.sqlite3-journal +instance/ +.webassets-cache +.scrapy +.spyderproject +.spyproject +.ropeproject +__pypackages__/ + +# Editors & IDEs +.vscode/ +.idea/ +.abstra/ +.cursorignore +.cursorindexingignore + +# Application-specific +dev/ +docs/_build/ +docs/build/ +conda/conda-bld/ diff --git a/city2graph/data.py b/city2graph/data.py index 3132118..6cf63a5 100644 --- a/city2graph/data.py +++ b/city2graph/data.py @@ -2,24 +2,32 @@ Data Loading and Processing Module. This module provides comprehensive functionality for loading and processing geospatial -data from various sources, with specialized support for Overture Maps data. It handles -data validation, coordinate reference system management, and geometric processing +data from various sources, with specialized support for Overture Maps data and the Transportation Networks data project. +It handles data validation, coordinate reference system management, and geometric processing operations commonly needed for urban network analysis. 
""" # Standard library imports import json +import os +import re import subprocess +import time +from dataclasses import asdict, dataclass from pathlib import Path +from typing import Dict, List, Optional +import warnings # Third-party imports import geopandas as gpd import pandas as pd +from pyparsing import lru_cache +import requests from pyproj import CRS -from shapely.geometry import LineString -from shapely.geometry import MultiLineString -from shapely.geometry import Polygon +from requests.adapters import HTTPAdapter +from shapely.geometry import LineString, MultiLineString, Polygon from shapely.ops import substring +from urllib3.util.retry import Retry # Public API definition __all__ = ["load_overture_data", "process_overture_segments"] @@ -765,3 +773,529 @@ def _calculate_passable_intervals( passable_intervals.append((current, 1.0)) return passable_intervals + + +@dataclass +class TransportationNetworkData: + """Container for transportation network data files.""" + network: Optional[pd.DataFrame] + trips: Optional[pd.DataFrame] + nodes: Optional[pd.DataFrame] + flow: Optional[pd.DataFrame] + + +def _requests_session() -> requests.Session: + """ + Create a requests session with default headers and retry strategy. + + Returns + ------- + requests.Session: Configured requests session. + """ + s = requests.Session() + retries = Retry( + total=5, + backoff_factor=0.5, + status_forcelist=(429, 500, 502, 503, 504), + allowed_methods=frozenset(["GET"]), + raise_on_status=False, + ) + s.mount("https://", HTTPAdapter(max_retries=retries)) + headers = {"User-Agent": "tntp-loader/1.0"} + s.headers.update(headers) + return s + + +def load_transportation_networks_data( + network_name: str, + output_dir: Optional[str | Path] = None, + save_to_file: bool = False, + load_network: bool = True, + load_trips: bool = True, + load_nodes: bool = True, + load_flow: bool = True, + download_if_missing: bool = True, + best_effort: bool = True, +) -> Dict[str, Optional[pd.DataFrame]]: + """ + Load transportation network data from the Transportation Networks repository + (bstabler/TransportationNetworks). + + Parameters + ---------- + network_name : str + e.g. 'SiouxFalls', 'Anaheim', 'Chicago-Sketch', ... + Will be validated against available networks. + output_dir : Optional[str|Path] + Directory to read/write local copies of tntp files when save_to_file=True. + save_to_file : bool + Save fetched files to output_dir. If False, never writes to disk. + load_* : bool + Toggle individual file types. + download_if_missing : bool + If False, only use local files; otherwise hit GitHub when needed. + best_effort : bool + If True, missing/failed parts return None; if False, raise on any requested part failure. + + Returns + ------- + dict[str, Optional[pd.DataFrame]] + Keys: 'network', 'trips', 'nodes', 'flow'. Values are DataFrames or None. + + Raises + ------ + ValueError, FileNotFoundError, requests.HTTPError + """ + session = _requests_session() + + # Validate target network against a filtered list + available = _get_available_transportation_networks(session) + if network_name not in available: + raise ValueError( + f"Network '{network_name}' not found. " + f"Try one of: {', '.join(sorted(available)[:25])}..." 
+ ) + + if save_to_file and output_dir is None: + raise ValueError("output_dir must be specified if save_to_file=True") + + if save_to_file and output_dir is not None: + Path(output_dir).mkdir(parents=True, exist_ok=True) + + data = _get_transportation_networks_data( + session=session, + network_name=network_name, + output_dir=Path(output_dir) if output_dir else None, + save_to_file=save_to_file, + load_network=load_network, + load_trips=load_trips, + load_nodes=load_nodes, + load_flow=load_flow, + download_if_missing=download_if_missing, + best_effort=best_effort, + ) + + # keep the dict return but with correct Optional typing + return asdict(data) + + +@lru_cache(maxsize=1) +def _get_available_transportation_networks(session: requests.Session) -> List[str]: + """ + Query the Transportation Networks GitHub repository for available networks. + + Parameters + ---------- + session : requests.Session + HTTP session for making requests. + + Returns + ------- + List[str] + List of available transportation network names. + """ + base_api = "https://api.github.com/repos/bstabler/TransportationNetworks/contents" + # Store to a local file to avoid repeated API hits and update only when older than 1 hour + cache_file = Path(__file__).parent / ".cache" / "transportation_networks.json" + cache_file.parent.mkdir(parents=True, exist_ok=True) + if cache_file.exists() and (time.time() - cache_file.stat().st_mtime) < 3600: + with open(cache_file, "r", encoding="utf-8") as f: + items = json.load(f) + else: + r = session.get(base_api, timeout=15) + r.raise_for_status() + items = r.json() + dirs = [i["name"] for i in items if i.get("type") == "dir" and not i["name"].startswith(('.', '_'))] + + return sorted(set(dirs)) + + +def _get_transportation_networks_data( + session: requests.Session, + network_name: str, + output_dir: Optional[Path], + save_to_file: bool, + load_network: bool, + load_trips: bool, + load_nodes: bool, + load_flow: bool, + download_if_missing: bool, + best_effort: bool, +) -> TransportationNetworkData: + """ + Retrieve and parse transportation network data from local files or remote repository. + Loads/downloads transportation network data files from the Transportation Networks GitHub repository + and parses them into pandas DataFrames. Supports loading different types of network data + (network topology, trips, nodes, flow) with flexible error handling and caching options. + + Parameters + ---------- + session (requests.Session): + HTTP session for making requests to remote repository. + network_name (str): + Name of the transportation network (e.g., 'SiouxFalls'). + output_dir (Optional[Path]): + Local directory to save/load files. If None, no local caching. + save_to_file (bool): + Whether to save downloaded files to local directory. + load_network (bool): + Whether to load network topology data (.tntp file). + load_trips (bool): + Whether to load trips data (.tntp file). + load_nodes (bool): + Whether to load nodes data (.tntp file). + load_flow (bool): + Whether to load flow data (.tntp file). + download_if_missing (bool): + Whether to download from remote if file not found locally. + best_effort (bool): + If True, continue processing even if some files fail to load. + If False, raise exception on first error. + + Returns + ------- + TransportationNetworkData: Object containing the loaded DataFrames for each data type. + DataFrames will be None for data types not requested or failed to load. 
+ + Raises + ------ + FileNotFoundError: When a required file is missing locally and download_if_missing=False, + or when best_effort=False. + requests.HTTPError: When remote download fails and best_effort=False. + Exception: When file parsing fails and best_effort=False. + """ + base_url = "https://raw.githubusercontent.com/bstabler/TransportationNetworks/master" + file_map = { + 'network': f"{network_name}_net.tntp", + 'trips': f"{network_name}_trips.tntp", + 'nodes': f"{network_name}_node.tntp", + 'flow': f"{network_name}_flow.tntp", + } + wanted = { + 'network': load_network, + 'trips': load_trips, + 'nodes': load_nodes, + 'flow': load_flow, + } + + out: Dict[str, Optional[pd.DataFrame]] = dict(network=None, trips=None, nodes=None, flow=None) + errors: Dict[str, str] = {} + + for kind, filename in file_map.items(): + if not wanted[kind]: + continue + + path = (output_dir / filename) if output_dir else None + lines: Optional[List[str]] = None + + if path and path.exists(): + with open(path, "r", encoding="utf-8", errors="ignore") as f: + lines = f.readlines() + + if lines is None and download_if_missing: + url = f"{base_url}/{network_name}/{filename}" + resp = session.get(url, timeout=20) + if resp.status_code == 404: + out[kind] = None + continue + try: + resp.raise_for_status() + except requests.HTTPError as e: + if best_effort: + errors[kind] = f"HTTP {resp.status_code} for {filename}" + out[kind] = None + continue + raise + lines = resp.text.splitlines() + + if save_to_file and path: + path.write_text(resp.text, encoding="utf-8") + + if lines is None: + msg = f"Missing {filename} locally and download_if_missing=False" + if best_effort: + errors[kind] = msg + out[kind] = None + continue + raise FileNotFoundError(msg) + + # Parse + try: + df = _parse_tntp_from_lines(lines, kind) + out[kind] = df + except Exception as e: + if best_effort: + errors[kind] = f"Parse error for {filename}: {e}" + out[kind] = None + else: + raise + + for k, v in errors.items(): + warnings.warn(f"Warning! Encountered error while loading '{k}': {v}", UserWarning) + + return TransportationNetworkData(**out) + + +def _strip_metadata(lines: List[str]) -> List[str]: + idx = None + for i, line in enumerate(lines): + if line.strip().startswith(""): + idx = i + break + data = lines[idx + 1 :] if idx is not None else lines + # Remove ~-comments and empty lines + cleaned = [] + for raw in data: + line = raw.split('~', 1)[0].strip() + if line: + cleaned.append(line) + return cleaned + + +def _parse_tntp_from_lines(data_lines: List[str], data_type: str) -> pd.DataFrame: + """ + Parse TNTP (Transportation Networks for Traffic Assignment Problems) data from lines. + This function processes different types of TNTP data files by parsing the provided + lines and returning structured DataFrames with appropriate data types and columns. + + Parameters + ---------- + data_lines (List[str]): + List of strings containing the raw data lines from a TNTP file. + data_type (str): + Type of data to parse. 
Must be one of: + - 'network': Network/link data with node connections and link properties + - 'trips': Trip matrix data with origin-destination demand + - 'flow': Flow matrix data with origin-destination flows + - 'nodes': Node coordinate data + + Returns + ------- + pd.DataFrame: + Parsed and structured DataFrame with appropriate columns and data types: + - For 'network': Contains columns like init_node, term_node, capacity, length, + free_flow_time, b, power, speed_limit, toll, link_type + - For 'trips': Contains origin, destination, and demand columns + - For 'flow': Contains origin, destination, volume, cost columns + - For 'nodes': Contains node, x, y coordinate columns + + Raises + ------ + ValueError: + If data_type is not one of the recognized types. + """ + data_lines = _strip_metadata(data_lines) + + if data_type == 'network': + df = _parse_network_file(data_lines) + numeric_cols = ["capacity", "length", "free_flow_time", "b", "power", "speed_limit", "toll"] + for col in numeric_cols: + if col not in df: + df[col] = pd.NA + df[col] = pd.to_numeric(df[col], errors="coerce") + + if "link_type" not in df: + df["link_type"] = pd.NA + return df + + if data_type == 'trips': + df = _parse_trips_file(data_lines) + if not df.empty: + df = df.groupby(["origin", "destination"], as_index=False)["demand"].sum() + df["origin"] = df["origin"].astype("int64") + df["destination"] = df["destination"].astype("int64") + df["demand"] = pd.to_numeric(df["demand"], errors="coerce") + return pd.DataFrame(df) + + if data_type == 'flow': + df = _parse_flow_file(data_lines) + if not df.empty: + df["origin"] = df["origin"].astype("int64") + df["destination"] = df["destination"].astype("int64") + df["volume"] = pd.to_numeric(df["volume"], errors="coerce") + df["cost"] = pd.to_numeric(df["cost"], errors="coerce") + return df + + if data_type == 'nodes': + df = _parse_nodes_file(data_lines) + if not df.empty: + df["node"] = df["node"].astype("int64") + df["x"] = pd.to_numeric(df["x"], errors="coerce") + df["y"] = pd.to_numeric(df["y"], errors="coerce") + return df + + raise ValueError(f"Unrecognized data_type '{data_type}'.") + + +def _parse_network_file(data_lines: List[str]) -> pd.DataFrame: + """ + Parse network data from text lines into a structured DataFrame. + + This function provides a defensive parser that handles various delimiters (semicolons, + commas, or whitespace) and automatically skips header rows. It extracts network link + information with mandatory node identifiers and optional link attributes. + + Parameters + ---------- + data_lines (List[str]): + List of text lines containing network data. Each line should represent a + network link with space/comma/semicolon-separated values. 
+ + Returns + ------- + pd.DataFrame: + DataFrame with network links containing the following columns: + - init_node (int): Initial/source node identifier + - term_node (int): Terminal/destination node identifier + - capacity (float, optional): Link capacity + - length (float, optional): Link length + - free_flow_time (float, optional): Free flow travel time + - b (float, optional): BPR function parameter + - power (float, optional): BPR function exponent + - speed_limit (float, optional): Speed limit on the link + - toll (float, optional): Toll cost for the link + - link_type (int, optional): Link type classification + """ + records = [] + for line in data_lines: + # Normalize delimiters to single space + line = re.sub(r"[;,]", " ", line) + parts = [p for p in line.split() if p] + # Skip obvious header rows (non-numeric in first 2 fields) + if len(parts) >= 2 and not (parts[0].lstrip("-").isdigit() and parts[1].lstrip("-").isdigit()): + continue + if len(parts) >= 2: + rec = { + "init_node": int(parts[0]), + "term_node": int(parts[1]), + } + # Optional fields by position if present + opt_names = [ + "capacity", "length", "free_flow_time", "b", "power", + "speed_limit", "toll", "link_type" + ] + for i, name in enumerate(opt_names, start=2): + if len(parts) > i: + val = parts[i] + rec[name] = float(val) if name != "link_type" else int(val) + records.append(rec) + return pd.DataFrame.from_records(records) + + +def _parse_flow_file(data_lines: List[str]) -> pd.DataFrame: + """ + This function processes lines of text data containing flow information and converts + them into a structured pandas DataFrame. Each line should contain at least 4 values: + origin node ID, destination node ID, flow volume, and flow cost. + + Parameters + ---------- + data_lines (List[str]): List of strings where each string represents a line + of flow data. Lines can use tabs, commas, or semicolons as separators, + and should contain at least 4 numeric values per line. + + Returns + ------- + pd.DataFrame: + A DataFrame with columns ['origin', 'destination', 'volume', 'cost']. + - origin (int): Source node identifier + - destination (int): Target node identifier + - volume (float): Flow volume between origin and destination + - cost (float): Cost associated with the flow + """ + records: List[dict[str, float | int]] = [] + for line in data_lines: + normalized = re.sub(r"[\t,;]", " ", line).strip() + if not normalized: + continue + parts = [token for token in normalized.split() if token] + if len(parts) < 4: + continue + if not (parts[0].lstrip("-").isdigit() and parts[1].lstrip("-").isdigit()): + continue + origin, destination, volume, cost = parts[:4] + records.append( + { + "origin": int(origin), + "destination": int(destination), + "volume": float(volume), + "cost": float(cost), + } + ) + return pd.DataFrame.from_records(records, columns=["origin", "destination", "volume", "cost"]) + + +def _parse_trips_file(data_lines: List[str]) -> pd.DataFrame: + """ + Parse a trips/flow file format into a pandas DataFrame. 
+    The function expects a specific file format where:
+    - Origin lines follow the pattern "Origin <origin_id>" (case-insensitive)
+    - Destination-demand pairs follow the pattern "<destination> : <demand>"
+    - Multiple destination-demand pairs can appear on the same line
+
+    Parameters
+    ----------
+    data_lines (List[str]): List of strings representing lines from a trips file
+
+    Returns
+    -------
+    pd.DataFrame: DataFrame with columns ['origin', 'destination', 'demand']
+        containing the parsed trip data
+
+    Raises
+    ------
+    ValueError:
+        If no 'Origin' sections are found in the input data
+    """
+    origin_pat = re.compile(r'^\s*Origin\s+(\d+)\s*$', flags=re.IGNORECASE)
+    pair_pat = re.compile(r'(\d+)\s*:\s*([+-]?(?:\d+(?:\.\d*)?|\.\d+)(?:[eE][+-]?\d+)?)')
+
+    rows = []
+    current_origin = None
+    saw_origin = False
+
+    for line in data_lines:
+        m = origin_pat.match(line)
+        if m:
+            current_origin = int(m.group(1))
+            saw_origin = True
+            continue
+        if current_origin is None:
+            # Skip preamble until the first Origin appears
+            continue
+        for d_str, v_str in pair_pat.findall(line):
+            rows.append(
+                {"origin": current_origin, "destination": int(d_str), "demand": float(v_str)}
+            )
+
+    if not rows and not saw_origin:
+        raise ValueError("No 'Origin' sections found in trips/flow file.")
+
+    return pd.DataFrame(rows, columns=["origin", "destination", "demand"])
+
+
+def _parse_nodes_file(data_lines: List[str]) -> pd.DataFrame:
+    """
+    Parse a nodes file format into a pandas DataFrame.
+
+    Parameters
+    ----------
+    data_lines (List[str]): List of strings representing lines from a nodes file
+
+    Returns
+    -------
+    pd.DataFrame: DataFrame with columns ['node', 'x', 'y']
+    """
+    records = []
+    for line in data_lines:
+        # Normalize delimiters
+        line = re.sub(r"[;,]", " ", line)
+        parts = [p for p in line.split() if p]
+        if len(parts) < 3:
+            continue
+        # First token should be an int node id
+        if not parts[0].lstrip("-").isdigit():
+            continue
+        records.append({"node": int(parts[0]), "x": parts[1], "y": parts[2]})
+    return pd.DataFrame.from_records(records)
diff --git a/tests/test_data.py b/tests/test_data.py
index dc77604..e7c140d 100644
--- a/tests/test_data.py
+++ b/tests/test_data.py
@@ -2,11 +2,16 @@
 import importlib.util
 import subprocess
+from pathlib import Path
+from typing import Optional
 from unittest.mock import Mock
 from unittest.mock import patch
 
 import geopandas as gpd
+import pandas as pd
 import pytest
+import requests
+from requests.adapters import HTTPAdapter
 from shapely.geometry import LineString
 from shapely.geometry import Point
 from shapely.geometry import Polygon
@@ -23,6 +28,17 @@ WGS84_CRS = data_module.WGS84_CRS
 load_overture_data = data_module.load_overture_data
 process_overture_segments = data_module.process_overture_segments
+_requests_session = data_module._requests_session
+load_transportation_networks_data = data_module.load_transportation_networks_data
+_get_available_transportation_networks = data_module._get_available_transportation_networks
+_get_transportation_networks_data = data_module._get_transportation_networks_data
+_strip_metadata = data_module._strip_metadata
+_parse_tntp_from_lines = data_module._parse_tntp_from_lines
+_parse_network_file = data_module._parse_network_file
+_parse_flow_file = data_module._parse_flow_file
+_parse_trips_file = data_module._parse_trips_file
+_parse_nodes_file = data_module._parse_nodes_file
+TransportationNetworkData = data_module.TransportationNetworkData
 
 
 # Tests for constants and basic functionality
@@ -808,3 +824,345 @@ def 
test_process_overture_segments_with_short_linestring() -> None: # Should process without errors assert len(result) == len(segments_gdf) + + +def test_requests_session_default_headers(monkeypatch: pytest.MonkeyPatch) -> None: + """Session should include retry adapter and custom user agent.""" + session = _requests_session() + try: + assert session.headers["User-Agent"] == "tntp-loader/1.0" + assert "Authorization" not in session.headers + adapter = session.adapters["https://"] + assert isinstance(adapter, HTTPAdapter) + assert adapter.max_retries.total == 5 + finally: + session.close() + +def test_load_transportation_networks_data_happy_path( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """Verify helper wiring and return conversion for transportation networks loader.""" + mock_session = Mock() + network_df = pd.DataFrame({"capacity": [100]}) + with ( + patch.object(data_module, "_requests_session", return_value=mock_session) as mock_session_factory, + patch.object(data_module, "_get_available_transportation_networks", return_value=["TestNet"]) as mock_available, + patch.object(data_module, "_get_transportation_networks_data") as mock_get_data, + ): + mock_get_data.return_value = TransportationNetworkData( + network=network_df, + trips=None, + nodes=None, + flow=None, + ) + + result = load_transportation_networks_data( + "TestNet", + load_trips=False, + load_nodes=False, + load_flow=False, + ) + + mock_session_factory.assert_called_once() + mock_available.assert_called_once_with(mock_session) + mock_get_data.assert_called_once_with( + session=mock_session, + network_name="TestNet", + output_dir=None, + save_to_file=False, + load_network=True, + load_trips=False, + load_nodes=False, + load_flow=False, + download_if_missing=True, + best_effort=True, + ) + + assert set(result.keys()) == {"network", "trips", "nodes", "flow"} + pd.testing.assert_frame_equal(result["network"], network_df) + assert result["trips"] is None + + +def test_load_transportation_networks_data_unknown_network() -> None: + """Unknown network names should raise early.""" + mock_session = Mock() + with ( + patch.object(data_module, "_requests_session", return_value=mock_session), + patch.object(data_module, "_get_available_transportation_networks", return_value=["OtherNet"]), + ): + with pytest.raises(ValueError, match="Network 'MissingNet' not found"): + load_transportation_networks_data("MissingNet") + + +def test_load_transportation_networks_data_requires_output_dir() -> None: + """Saving to disk without output directory is invalid.""" + mock_session = Mock() + with ( + patch.object(data_module, "_requests_session", return_value=mock_session), + patch.object(data_module, "_get_available_transportation_networks", return_value=["TestNet"]), + ): + with pytest.raises(ValueError, match="output_dir must be specified"): + load_transportation_networks_data("TestNet", save_to_file=True, output_dir=None) + + +class _FakeResponse: + """Minimal response object for session.get mocks.""" + + def __init__( + self, + *, + status_code: int = 200, + json_data: Optional[list[dict[str, str]]] = None, + text: str = "", + ok: Optional[bool] = None, + ) -> None: + self.status_code = status_code + self._json = json_data or [] + self.text = text + self.ok = ok if ok is not None else status_code < 400 + + def raise_for_status(self) -> None: + if self.status_code >= 400: + raise requests.HTTPError(f"status {self.status_code}") + + def json(self) -> list[dict[str, str]]: + return self._json + + +def 
test_get_available_transportation_networks_filters_valid_dirs() -> None:
+    """Only directory entries (excluding hidden and underscore-prefixed names) should be returned."""
+    session = Mock()
+    base_api = "https://api.github.com/repos/bstabler/TransportationNetworks/contents"
+
+    def _get(url: str, timeout: int) -> _FakeResponse:
+        if url == base_api:
+            return _FakeResponse(
+                json_data=[
+                    {"name": "NetA", "type": "dir"},
+                    {"name": "_ignore", "type": "dir"},
+                    {"name": "README.md", "type": "file"},
+                    {"name": "NetB", "type": "dir"},
+                ]
+            )
+        if "NetA_net.tntp" in url:
+            return _FakeResponse(status_code=200)
+        if "NetB_net.tntp" in url:
+            return _FakeResponse(status_code=404)
+        if url == f"{base_api}/NetB":
+            return _FakeResponse(
+                json_data=[{"name": "alt_net.tntp", "type": "file"}],
+                ok=True,
+            )
+        raise AssertionError(f"Unexpected URL {url}")
+
+    session.get.side_effect = _get
+
+    result = _get_available_transportation_networks(session)
+    assert result == ["NetA", "NetB"]
+
+
+def test_get_transportation_networks_data_prefers_local_files(
+    tmp_path: Path,
+) -> None:
+    """Existing local files should be used without hitting the network."""
+    session = Mock()
+    sample_path = tmp_path / "TestNet_net.tntp"
+    sample_path.write_text("1 2 3", encoding="utf-8")
+    df = pd.DataFrame({"value": [1]})
+    with patch.object(data_module, "_parse_tntp_from_lines", return_value=df) as mock_parser:
+        data = _get_transportation_networks_data(
+            session=session,
+            network_name="TestNet",
+            output_dir=tmp_path,
+            save_to_file=False,
+            load_network=True,
+            load_trips=False,
+            load_nodes=False,
+            load_flow=False,
+            download_if_missing=True,
+            best_effort=True,
+        )
+
+    session.get.assert_not_called()
+    mock_parser.assert_called_once()
+    assert isinstance(data.network, pd.DataFrame)
+    assert data.trips is None
+
+
+@patch.object(data_module, "_parse_tntp_from_lines")
+def test_get_transportation_networks_data_handles_404(
+    mock_parser: Mock,
+) -> None:
+    """404 remote responses should yield None without parsing."""
+    session = Mock()
+    response = _FakeResponse(status_code=404)
+    session.get.return_value = response
+
+    data = _get_transportation_networks_data(
+        session=session,
+        network_name="MissingNet",
+        output_dir=None,
+        save_to_file=False,
+        load_network=True,
+        load_trips=False,
+        load_nodes=False,
+        load_flow=False,
+        download_if_missing=True,
+        best_effort=True,
+    )
+
+    mock_parser.assert_not_called()
+    assert data.network is None
+
+
+def test_strip_metadata_removes_preamble_and_comments() -> None:
+    """Metadata markers and comment lines are removed."""
+    lines = [
+        "this is metadata",
+        "<END OF METADATA>",
+        "1, 2, 3 ~ inline comment",
+        " ",
+        "~ full comment",
+        "4;5;6",
+    ]
+    cleaned = _strip_metadata(lines)
+    assert cleaned == ["1, 2, 3", "4;5;6"]
+
+
+def test_parse_tntp_from_lines_network() -> None:
+    """Network parser should coerce numeric columns and add defaults."""
+    lines = [
+        "header",
+        "",
+        "~ comment",
+        "1,2,1000,1.5,12,0.15,60,30,0,3",
+    ]
+
+    df = _parse_tntp_from_lines(lines, "network")
+
+    assert {"init_node", "term_node", "capacity", "length", "link_type"}.issubset(df.columns)
+    assert df.at[0, "init_node"] == 1
+    assert df.at[0, "capacity"] == pytest.approx(1000)
+    assert df.at[0, "link_type"] == 3
+
+
+def test_parse_tntp_from_lines_trips_aggregates_duplicates() -> None:
+    """Trip matrices should aggregate duplicate OD pairs."""
+    lines = [
+        "",
+        "Origin 1",
+        "2 : 10.0 3 : 5",
+        "2 : 2",
+        "Origin 2",
+        "1 : 7",
+    ]
+
+    df = _parse_tntp_from_lines(lines, "trips")
+    df = df.sort_values(["origin",
"destination"]).reset_index(drop=True) + expected = pd.DataFrame( + { + "origin": [1, 1, 2], + "destination": [2, 3, 1], + "demand": [12.0, 5.0, 7.0], + } + ) + pd.testing.assert_frame_equal(df, expected) + + +def test_parse_tntp_from_lines_flow_table() -> None: + """Flow parser should capture volume and cost columns.""" + lines = [ + "", + "From To Volume Cost", + "1 2 10 1.5", + "3 4 5.5 2.25", + ] + + df = _parse_tntp_from_lines(lines, "flow") + expected = pd.DataFrame( + { + "origin": [1, 3], + "destination": [2, 4], + "volume": [10.0, 5.5], + "cost": [1.5, 2.25], + } + ) + pd.testing.assert_frame_equal(df, expected) + + +def test_parse_tntp_from_lines_nodes_casts_types() -> None: + """Node parser should coerce coordinates to numeric types.""" + lines = [ + "", + "node,x,y", + "1,10.5,20.1", + "2 30.2 40.3", + ] + + df = _parse_tntp_from_lines(lines, "nodes") + assert df.at[0, "node"] == 1 + assert df.at[1, "x"] == pytest.approx(30.2) + + +def test_parse_tntp_from_lines_unknown_type() -> None: + """Unsupported data types should raise ValueError.""" + with pytest.raises(ValueError): + _parse_tntp_from_lines(["", "foo"], "unknown") + + +def test_parse_network_file_skips_headers() -> None: + """Header rows and non-numeric prefixes should be ignored.""" + lines = [ + "Init,Term,Capacity", + "1, 2 , 1000 ; 1.5", + "3 4 2000 2.0 10", + ] + df = _parse_network_file(lines) + assert len(df) == 2 + assert set(df.columns) >= {"init_node", "term_node"} + assert df.at[0, "term_node"] == 2 + + +def test_parse_flow_file_skips_headers() -> None: + """Flow parser should ignore header rows and coerce numeric values.""" + lines = [ + "From To Volume Cost", + "1 2 10 1.5", + "3 4 5.5 2.25", + ] + df = _parse_flow_file(lines) + assert len(df) == 2 + assert set(df.columns) == {"origin", "destination", "volume", "cost"} + assert df.at[1, "volume"] == pytest.approx(5.5) + + +def test_parse_trips_file_parses_pairs() -> None: + """Matrix files should yield origin-destination-demand rows.""" + lines = [ + "Origin 1", + "2 : 10.0 3 : 5", + "Origin 2", + "1 : 7", + ] + df = _parse_trips_file(lines) + assert len(df) == 3 + assert {"origin", "destination", "demand"} == set(df.columns) + + +def test_parse_trips_file_requires_origin() -> None: + """Matrix parser should complain when no origin section is present.""" + with pytest.raises(ValueError): + _parse_trips_file(["2 : 1.0"]) + + +def test_parse_nodes_file_skips_invalid_rows() -> None: + """Nodes parser should ignore malformed lines.""" + lines = [ + "header", + "1, 10, 20", + "two, 30, 40", + "3 50 60", + ] + df = _parse_nodes_file(lines) + assert len(df) == 2 + assert df.at[1, "node"] == 3