diff --git a/graphistry/Plottable.py b/graphistry/Plottable.py index 18ed6cc67a..5613b7bee8 100644 --- a/graphistry/Plottable.py +++ b/graphistry/Plottable.py @@ -1,4 +1,4 @@ -from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Set, Tuple, Union, Protocol, overload +from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union, Protocol, overload from typing_extensions import Literal, runtime_checkable import pandas as pd @@ -101,6 +101,7 @@ class Plottable(Protocol): _node_features : Optional[pd.DataFrame] _node_features_raw: Optional[pd.DataFrame] _node_target : Optional[pd.DataFrame] + _node_target_encoder : Optional[Any] _node_target_raw : Optional[pd.DataFrame] _edge_embedding : Optional[pd.DataFrame] @@ -163,6 +164,9 @@ class Plottable(Protocol): _partition_offsets: Optional[Dict[str, Dict[int, float]]] # from gib + def to_file(self, path: str, format: Optional[str] = None) -> Tuple[Any, Any]: + ... + def reset_caches(self) -> None: ... diff --git a/graphistry/PlotterBase.py b/graphistry/PlotterBase.py index 6b4f6f2ac3..76b5d37a06 100644 --- a/graphistry/PlotterBase.py +++ b/graphistry/PlotterBase.py @@ -1823,6 +1823,26 @@ def graph(self, ig: Any) -> Plottable: return res + def to_file(self, path, format=None): + """Save this Plottable graph to disk as a bundle. 
+ + Requires pydantic >= 2.0: ``pip install 'graphistry[serialization]'`` + + :param path: Destination path (directory, or .zip file if format="zip") + :type path: str + :param format: None for directory (default), "zip" for zip archive + :type format: Optional[str] + :returns: Tuple of (self, BundleWriteReport) + + **Example** + :: + + g2, report = g.to_file('/tmp/my_graph') + g2, report = g.to_file('/tmp/my_graph.zip', format='zip') + """ + from graphistry.io.plottable_bundle import to_file as _to_file + return _to_file(self, path, format=format) + def settings(self, height=None, url_params={}, render=None): """Specify iframe height and add URL parameter dictionary. diff --git a/graphistry/__init__.py b/graphistry/__init__.py index 954713b346..46f6d22dce 100644 --- a/graphistry/__init__.py +++ b/graphistry/__init__.py @@ -58,7 +58,8 @@ PyGraphistry, GraphistryClient, from_igraph, - from_cugraph + from_cugraph, + from_file ) from graphistry.compute import ( diff --git a/graphistry/io/__init__.py b/graphistry/io/__init__.py index d560f10caf..36807ad24b 100644 --- a/graphistry/io/__init__.py +++ b/graphistry/io/__init__.py @@ -16,5 +16,8 @@ 'serialize_node_bindings', 'serialize_edge_bindings', 'serialize_node_encodings', - 'serialize_edge_encodings' + 'serialize_edge_encodings', ] + +# Note: to_file and from_file imported lazily via graphistry.io.plottable_bundle +# to avoid requiring pydantic at import time. diff --git a/graphistry/io/bundle.py b/graphistry/io/bundle.py new file mode 100644 index 0000000000..2d00b7eba8 --- /dev/null +++ b/graphistry/io/bundle.py @@ -0,0 +1,234 @@ +""" +Generic bundle engine for serializing/deserializing data bundles. + +Low-level, Plottable-agnostic. Handles file I/O, SHA256 integrity, +parquet read/write, manifest management, and zip/dir format support. 
import hashlib
import json
import os
import shutil
import tempfile
import zipfile
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional

import pandas as pd


# One shared message so every "pydantic unusable" path reports the same hint.
_PYDANTIC_HINT = (
    "graphistry serialization requires pydantic >= 2.0. "
    "Install with: pip install 'graphistry[serialization]'"
)


def _require_pydantic() -> Any:
    """Import and return the pydantic module, or raise a clear ImportError.

    :returns: The imported ``pydantic`` module.
    :raises ImportError: If pydantic is missing, its version cannot be
        parsed, or its major version is below 2.
    """
    try:
        import pydantic
    except ImportError as err:
        raise ImportError(_PYDANTIC_HINT) from err

    # Parse the version outside the import guard so the "too old" error is
    # raised exactly once instead of being raised and re-caught.
    try:
        major = int(str(pydantic.VERSION).split('.')[0])
    except (AttributeError, ValueError) as err:
        raise ImportError(_PYDANTIC_HINT) from err
    if major < 2:
        raise ImportError(_PYDANTIC_HINT)
    return pydantic


def sha256_file(path: str) -> str:
    """Compute the SHA256 hex digest of a file, reading it in 64 KiB chunks."""
    h = hashlib.sha256()
    with open(path, 'rb') as f:
        for chunk in iter(lambda: f.read(65536), b''):
            h.update(chunk)
    return h.hexdigest()


def sha256_bytes(data: bytes) -> str:
    """Compute the SHA256 hex digest of an in-memory bytes object."""
    return hashlib.sha256(data).hexdigest()


@dataclass
class BundleWriteReport:
    """Report from a bundle write operation."""
    # Human-readable problems that did not abort the write.
    warnings: List[str] = field(default_factory=list)
    # Names of artifacts that were written to disk.
    artifacts_written: List[str] = field(default_factory=list)
    # Names of artifacts that were present but could not be written.
    artifacts_skipped: List[str] = field(default_factory=list)

    def __repr__(self) -> str:
        return (
            f"BundleWriteReport(written={len(self.artifacts_written)}, "
            f"skipped={len(self.artifacts_skipped)}, "
            f"warnings={len(self.warnings)})"
        )


@dataclass
class BundleReadReport:
    """Report from a bundle read operation."""
    # Human-readable problems that did not abort the read.
    warnings: List[str] = field(default_factory=list)
    # Names of artifacts that were loaded.
    artifacts_loaded: List[str] = field(default_factory=list)
    # Names of artifacts listed in the manifest that could not be loaded.
    artifacts_skipped: List[str] = field(default_factory=list)
    # Flipped to False on any hash mismatch or unsafe path.
    integrity_ok: bool = True
    # Set by the caller when remote server state was present but not restored.
    remote_state_skipped: bool = False

    def __repr__(self) -> str:
        return (
            f"BundleReadReport(loaded={len(self.artifacts_loaded)}, "
            f"skipped={len(self.artifacts_skipped)}, "
            f"integrity_ok={self.integrity_ok}, "
            f"warnings={len(self.warnings)})"
        )


def _df_to_pandas(df: Any) -> Any:
    """Return ``df.to_pandas()`` when available (e.g. cuDF), else ``df`` itself.

    Conversion is best-effort: if ``to_pandas()`` raises, the input is
    returned unchanged and the caller's ``isinstance`` check decides what
    happens next.
    """
    if hasattr(df, 'to_pandas'):
        try:
            return df.to_pandas()
        except Exception:
            # Deliberate best-effort; write_df_parquet reports the skip.
            pass
    return df


def _remove_quietly(path: str) -> None:
    """Delete ``path`` if it exists, ignoring filesystem errors."""
    try:
        os.remove(path)
    except OSError:
        pass


def write_df_parquet(
    df: Any,
    name: str,
    bundle_dir: str,
    report: BundleWriteReport,
) -> Optional[Dict[str, str]]:
    """Write a DataFrame as parquet to ``bundle_dir/data/{name}.parquet``.

    :param df: pandas DataFrame, an object exposing ``to_pandas()``, or None
    :param name: Artifact name; also the parquet file stem
    :param bundle_dir: Root directory of the bundle being written
    :param report: Write report that receives warnings and written/skipped names
    :returns: Artifact dict ``{kind, path, sha256}``, or None when ``df`` is
        None or the write failed. ``path`` is always '/'-separated so the
        manifest is identical regardless of the writing platform.
    """
    if df is None:
        return None

    data_dir = os.path.join(bundle_dir, 'data')
    os.makedirs(data_dir, exist_ok=True)
    # Manifest path is POSIX-style on every OS; only abs_path is OS-native.
    rel_path = f'data/{name}.parquet'
    abs_path = os.path.join(data_dir, f'{name}.parquet')

    try:
        pdf = _df_to_pandas(df)
        if not isinstance(pdf, pd.DataFrame):
            report.warnings.append(f"{name}: not a DataFrame, skipping")
            report.artifacts_skipped.append(name)
            return None
        pdf.to_parquet(abs_path)
    except Exception as e:
        report.warnings.append(f"{name}: failed to write parquet: {e}")
        report.artifacts_skipped.append(name)
        # Do not leave a truncated file that the manifest does not list.
        _remove_quietly(abs_path)
        return None

    sha = sha256_file(abs_path)
    report.artifacts_written.append(name)
    return {
        'kind': 'parquet',
        'path': rel_path,
        'sha256': sha,
    }


def read_df_parquet(
    rel_path: str,
    bundle_dir: str,
    expected_sha: Optional[str],
    report: BundleReadReport,
) -> Optional[pd.DataFrame]:
    """Read a parquet file from ``bundle_dir`` and verify its SHA256.

    :param rel_path: Path of the parquet file relative to ``bundle_dir``
    :param bundle_dir: Root directory of the bundle
    :param expected_sha: Expected hex digest, or None to skip verification
    :param report: Read report that receives warnings and integrity status
    :returns: The DataFrame, or None if the path escapes the bundle, the file
        is missing, or parquet parsing fails. A hash mismatch is recorded on
        ``report`` but the data is still returned.
    """
    # rel_path comes from the manifest, so treat it as untrusted and keep it
    # inside the bundle, mirroring the zip-slip guard in zip_to_dir.
    root = os.path.realpath(bundle_dir)
    abs_path = os.path.realpath(os.path.join(root, rel_path))
    if not abs_path.startswith(root + os.sep):
        report.warnings.append(f"Unsafe path outside bundle: {rel_path}")
        report.integrity_ok = False
        return None

    if not os.path.exists(abs_path):
        report.warnings.append(f"File not found: {rel_path}")
        return None

    if expected_sha is not None:
        actual_sha = sha256_file(abs_path)
        if actual_sha != expected_sha:
            report.warnings.append(
                f"SHA256 mismatch for {rel_path}: "
                f"expected {expected_sha[:16]}..., got {actual_sha[:16]}..."
            )
            report.integrity_ok = False

    try:
        return pd.read_parquet(abs_path)
    except Exception as e:
        report.warnings.append(f"Failed to read parquet {rel_path}: {e}")
        return None


def write_manifest(manifest: Dict[str, Any], bundle_dir: str) -> None:
    """Write ``manifest`` as UTF-8 JSON to ``bundle_dir/manifest.json``.

    Values json cannot encode natively are stringified via ``default=str``.
    """
    path = os.path.join(bundle_dir, 'manifest.json')
    with open(path, 'w', encoding='utf-8') as f:
        json.dump(manifest, f, indent=2, ensure_ascii=False, default=str)


def read_manifest(bundle_dir: str) -> Dict[str, Any]:
    """Read ``bundle_dir/manifest.json`` and return it as a dict.

    :raises ValueError: If the file is not valid JSON or its top level is not
        a JSON object (callers use ``.get`` on the result).
    """
    path = os.path.join(bundle_dir, 'manifest.json')
    with open(path, 'r', encoding='utf-8') as f:
        manifest = json.load(f)
    if not isinstance(manifest, dict):
        raise ValueError(
            f"Invalid manifest in {bundle_dir}: expected a JSON object, "
            f"got {type(manifest).__name__}"
        )
    return manifest


def dir_to_zip(src_dir: str, zip_path: str) -> None:
    """Create a zip archive from a bundle directory.

    Members are added in sorted order so the member list does not depend on
    filesystem enumeration order. If ``zip_path`` lies inside ``src_dir`` the
    archive itself is not added as a member.
    """
    zip_real = os.path.realpath(zip_path)
    with zipfile.ZipFile(zip_path, 'w', zipfile.ZIP_DEFLATED) as zf:
        for root, dirs, files in os.walk(src_dir):
            dirs.sort()  # in-place: controls the order os.walk descends
            for fname in sorted(files):
                abs_path = os.path.join(root, fname)
                if os.path.realpath(abs_path) == zip_real:
                    continue
                zf.write(abs_path, os.path.relpath(abs_path, src_dir))


def zip_to_dir(zip_path: str, dest_dir: str) -> None:
    """Extract a zip archive to ``dest_dir`` with zip-slip protection.

    :raises ValueError: If any member would resolve outside ``dest_dir``;
        nothing is extracted in that case.
    """
    abs_dest = os.path.realpath(dest_dir)
    with zipfile.ZipFile(zip_path, 'r') as zf:
        # Validate every member before extracting anything.
        for member in zf.namelist():
            member_path = os.path.realpath(os.path.join(dest_dir, member))
            if not member_path.startswith(abs_dest + os.sep) and member_path != abs_dest:
                raise ValueError(
                    f"Zip-slip detected: {member} would extract outside {dest_dir}"
                )
        zf.extractall(dest_dir)


def detect_format(path: str) -> str:
    """Detect whether ``path`` is a directory bundle or a zip archive.

    :returns: ``"dir"`` or ``"zip"``
    :raises FileNotFoundError: If ``path`` does not exist
    :raises ValueError: If ``path`` is neither a directory nor a zip file
    """
    if not os.path.exists(path):
        raise FileNotFoundError(f"Bundle path does not exist: {path}")
    if os.path.isdir(path):
        return "dir"
    if zipfile.is_zipfile(path):
        return "zip"
    raise ValueError(f"Path is neither a directory nor a zip file: {path}")


def prepare_bundle_dir(path: str, fmt: Optional[str]) -> str:
    """Create and return the working bundle directory.

    :param path: Final destination of the bundle
    :param fmt: ``"zip"`` to stage in a temp directory that is zipped later;
        ``None`` (or ``"dir"``) to write directly into ``path``
    :returns: Directory the caller should write artifacts into
    :raises ValueError: If ``fmt`` is not a supported format, instead of
        silently writing a directory bundle for a misspelled format
    """
    if fmt == "zip":
        return tempfile.mkdtemp(prefix="graphistry_bundle_")
    if fmt not in (None, "dir"):
        raise ValueError(
            f"Unsupported bundle format: {fmt!r} (expected None or 'zip')"
        )
    os.makedirs(path, exist_ok=True)
    return path


def finalize_bundle(bundle_dir: str, path: str, fmt: Optional[str]) -> None:
    """Finalize the bundle: zip it if requested and remove the staging dir.

    For ``fmt == "zip"`` the parent directory of ``path`` is created first,
    matching directory mode where missing parents are created too. The
    staging directory is removed whether or not zipping succeeds. Other
    formats need no finalization.
    """
    if fmt != "zip":
        return
    try:
        os.makedirs(os.path.dirname(os.path.abspath(path)), exist_ok=True)
        dir_to_zip(bundle_dir, path)
    finally:
        shutil.rmtree(bundle_dir, ignore_errors=True)
+""" +import copy +import json +import os +import platform +import warnings +from collections import UserDict +from datetime import datetime, timezone +from typing import Any, Dict, List, Optional, Tuple, TYPE_CHECKING + +import pandas as pd + +from graphistry.io.bundle import ( + BundleReadReport, + BundleWriteReport, + _require_pydantic, + detect_format, + finalize_bundle, + prepare_bundle_dir, + read_df_parquet, + read_manifest, + sha256_file, + write_df_parquet, + write_manifest, + zip_to_dir, +) + +if TYPE_CHECKING: + from graphistry.Plottable import Plottable + +# --------------------------------------------------------------------------- +# Field group constants — canonical lists for tripwire tests +# --------------------------------------------------------------------------- + +TIER1_DF_FIELDS: List[str] = [ + "_edges", + "_nodes", +] + +TIER1_BINDING_FIELDS: List[str] = [ + "_source", + "_destination", + "_node", + "_edge", + "_edge_title", + "_edge_label", + "_edge_color", + "_edge_source_color", + "_edge_destination_color", + "_edge_size", + "_edge_weight", + "_edge_icon", + "_edge_opacity", + "_point_title", + "_point_label", + "_point_color", + "_point_size", + "_point_weight", + "_point_icon", + "_point_opacity", + "_point_x", + "_point_y", + "_point_longitude", + "_point_latitude", +] + +TIER1_DISPLAY_FIELDS: List[str] = [ + "_height", + "_render", + "_url_params", + "_name", + "_description", + "_style", + "_complex_encodings", +] + +TIER1_REMOTE_FIELDS: List[str] = [ + "_dataset_id", + "_url", + "_nodes_file_id", + "_edges_file_id", + "_privacy", +] + +TIER2_DF_FIELDS: List[str] = [ + "_node_embedding", + "_node_features", + "_node_features_raw", + "_node_target", + "_node_target_raw", + "_edge_embedding", + "_edge_features", + "_edge_features_raw", + "_edge_target", + "_edge_target_raw", + "_weighted_edges_df", + "_weighted_edges_df_from_nodes", + "_weighted_edges_df_from_edges", + "_xy", +] + +TIER2_JSON_ALGO_FIELDS: List[str] = [ + 
"_umap_engine", + "_umap_params", + "_umap_fit_kwargs", + "_umap_transform_kwargs", + "_n_components", + "_metric", + "_n_neighbors", + "_min_dist", + "_spread", + "_local_connectivity", + "_repulsion_strength", + "_negative_sample_rate", + "_suffix", + "_dbscan_engine", + "_dbscan_params", + "_collapse_node_col", + "_collapse_src_col", + "_collapse_dst_col", +] + +TIER2_JSON_KG_FIELDS: List[str] = [ + "_relation", + "_use_feat", + "_triplets", + "_kg_embed_dim", +] + +TIER2_JSON_LAYOUT_FIELDS: List[str] = [ + "_partition_offsets", +] + +TIER2_JSON_INDEX_FIELDS: List[str] = [ + "_entity_to_index", + "_index_to_entity", +] + +# Objects that can't be serialized as JSON/parquet in v1 +TIER3_FIELDS: List[str] = [ + "_umap", + "_node_encoder", + "_node_target_encoder", + "_edge_encoder", + "_weighted_adjacency", + "_weighted_adjacency_nodes", + "_weighted_adjacency_edges", + "_adjacency", + "_dbscan_nodes", + "_dbscan_edges", +] + +# Never serialized: session/auth/driver/DGL state +NEVER_FIELDS: List[str] = [ + "session", + "_pygraphistry", + "_bolt_driver", + "_tigergraph", + "DGL_graph", + "_dgl_graph", +] + +ALL_KNOWN_FIELDS: List[str] = sorted(set( + TIER1_DF_FIELDS + + TIER1_BINDING_FIELDS + + TIER1_DISPLAY_FIELDS + + TIER1_REMOTE_FIELDS + + TIER2_DF_FIELDS + + TIER2_JSON_ALGO_FIELDS + + TIER2_JSON_KG_FIELDS + + TIER2_JSON_LAYOUT_FIELDS + + TIER2_JSON_INDEX_FIELDS + + TIER3_FIELDS + + NEVER_FIELDS +)) + + +SCHEMA_VERSION = "1.0" + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +def _safe_json_val(val: Any) -> Any: + """Convert values that aren't directly JSON-serializable.""" + if val is None: + return None + if isinstance(val, UserDict): + return dict(val) + if isinstance(val, dict): + return {k: _safe_json_val(v) for k, v in val.items()} + if isinstance(val, (list, tuple)): + return [_safe_json_val(v) for v in val] + if isinstance(val, 
def _get_field(g: 'Plottable', field_name: str) -> Any:
    """Safely get a field from a Plottable, returning None if it is missing."""
    try:
        return getattr(g, field_name)
    except AttributeError:
        return None


def _collect_json_fields(
    g: 'Plottable', field_names: List[str]
) -> Dict[str, Any]:
    """Collect the named fields as a JSON-safe dict, omitting None values."""
    out: Dict[str, Any] = {}
    for name in field_names:
        val = _get_field(g, name)
        if val is not None:
            out[name] = _safe_json_val(val)
    return out


def _resolve_in_bundle(bundle_dir: str, rel_path: str) -> Optional[str]:
    """Resolve a manifest-relative path, refusing anything outside the bundle.

    :returns: The real absolute path when it lies strictly inside
        ``bundle_dir``; None for absolute paths, ``..`` escapes, or links
        that resolve elsewhere.
    """
    root = os.path.realpath(bundle_dir)
    candidate = os.path.realpath(os.path.join(root, rel_path))
    try:
        inside = os.path.commonpath([root, candidate]) == root
    except ValueError:
        # commonpath rejects mixed drives / mixed absolute-relative paths.
        inside = False
    if inside and candidate != root:
        return candidate
    return None


def _int_key(key: Any) -> Any:
    """Turn an integer-looking string key back into an int; else return it as is."""
    if isinstance(key, str):
        try:
            return int(key)
        except ValueError:
            return key
    return key


def _restore_partition_offsets(raw: Any) -> Any:
    """Undo JSON's key stringification for ``_partition_offsets``.

    The field is declared ``Dict[str, Dict[int, float]]`` on Plottable, but
    JSON object keys are always strings, so the inner keys come back as
    ``"0"``, ``"1"``, ... after a round trip. Outer keys are left alone, and
    non-dict input is returned unchanged.
    """
    if not isinstance(raw, dict):
        return raw
    restored: Dict[Any, Any] = {}
    for axis, offsets in raw.items():
        if isinstance(offsets, dict):
            restored[axis] = {_int_key(k): v for k, v in offsets.items()}
        else:
            restored[axis] = offsets
    return restored


# ---------------------------------------------------------------------------
# to_file
# ---------------------------------------------------------------------------

def to_file(
    g: 'Plottable',
    path: str,
    format: Optional[str] = None,
) -> "Tuple[Plottable, BundleWriteReport]":
    """Save a Plottable graph to disk as a bundle (directory or zip).

    :param g: Plottable object to serialize
    :param path: Destination path (directory or .zip file)
    :param format: None for directory (default), "zip" for zip archive
    :return: Tuple of (original Plottable, BundleWriteReport)
    :raises RuntimeError: If edges DataFrame is missing or cannot be written
    :raises ImportError: If pydantic >= 2.0 is not installed
    """
    # NOTE(review): pydantic is required here but not otherwise used in the
    # visible code; presumably the metadata serializer depends on it. Confirm.
    _require_pydantic()
    report = BundleWriteReport()

    if _get_field(g, '_edges') is None:
        raise RuntimeError(
            "Cannot save bundle: edges DataFrame is required. "
            "Set edges with g.edges(df, 'src', 'dst') first."
        )

    bundle_dir = prepare_bundle_dir(path, format)

    try:
        artifacts: Dict[str, Dict[str, str]] = {}
        files: Dict[str, str] = {}

        # --- Tier 1 DataFrames ---
        edges_art = write_df_parquet(g._edges, '_edges', bundle_dir, report)
        if edges_art is None:
            raise RuntimeError("Failed to write edges parquet — aborting bundle save")
        artifacts['_edges'] = edges_art
        files[edges_art['path']] = edges_art['sha256']

        nodes_art = write_df_parquet(
            _get_field(g, '_nodes'), '_nodes', bundle_dir, report
        )
        if nodes_art is not None:
            artifacts['_nodes'] = nodes_art
            files[nodes_art['path']] = nodes_art['sha256']

        # --- Tier 2 DataFrames ---
        for field_name in TIER2_DF_FIELDS:
            val = _get_field(g, field_name)
            if val is None:
                continue
            # Accept anything write_df_parquet can convert (pandas, or frames
            # exposing to_pandas() such as cuDF), the same as edges/nodes.
            if isinstance(val, pd.DataFrame) or hasattr(val, 'to_pandas'):
                art = write_df_parquet(val, field_name, bundle_dir, report)
                if art is not None:
                    artifacts[field_name] = art
                    files[art['path']] = art['sha256']
            else:
                report.warnings.append(
                    f"{field_name}: not a DataFrame ({type(val).__name__}), skipping"
                )
                report.artifacts_skipped.append(field_name)

        # --- Build manifest ---
        from graphistry.io.metadata import serialize_plottable_metadata
        try:
            from graphistry._version import get_versions
            graphistry_version = get_versions()["version"]
        except Exception:
            # Version lookup is informational only; never fail a save on it.
            graphistry_version = "unknown"

        manifest: Dict[str, Any] = {
            "schema_version": SCHEMA_VERSION,
            "created_at": datetime.now(timezone.utc).isoformat(),
            "python_version": platform.python_version(),
            "graphistry_version": graphistry_version,
            "plottable_metadata": serialize_plottable_metadata(g),
            "settings": {
                "height": _get_field(g, '_height'),
                "render": _get_field(g, '_render'),
                "url_params": _safe_json_val(_get_field(g, '_url_params')),
            },
            "remote": {
                "dataset_id": _get_field(g, '_dataset_id'),
                "url": _get_field(g, '_url'),
                "nodes_file_id": _get_field(g, '_nodes_file_id'),
                "edges_file_id": _get_field(g, '_edges_file_id'),
                "privacy": _safe_json_val(_get_field(g, '_privacy')),
            },
            "algorithm_config": _collect_json_fields(g, TIER2_JSON_ALGO_FIELDS),
            "kg_config": _collect_json_fields(g, TIER2_JSON_KG_FIELDS),
            "layout": _collect_json_fields(g, TIER2_JSON_LAYOUT_FIELDS),
            "graph_indices": _collect_json_fields(g, TIER2_JSON_INDEX_FIELDS),
            "artifacts": artifacts,
            "files": files,
        }

        write_manifest(manifest, bundle_dir)
        finalize_bundle(bundle_dir, path, format)

    except Exception:
        # Clean up the staging dir on failure when zipping. A partially
        # written directory bundle is left in place for inspection.
        if format == "zip" and bundle_dir != path:
            import shutil
            shutil.rmtree(bundle_dir, ignore_errors=True)
        raise

    return (g, report)


# ---------------------------------------------------------------------------
# from_file
# ---------------------------------------------------------------------------

def from_file(
    path: str,
    restore_remote: bool = False,
) -> "Tuple[Plottable, BundleReadReport]":
    """Load a Plottable graph from a bundle on disk.

    Integrity problems (hash mismatch, missing file, path escaping the
    bundle, unexpected schema version) are recorded on the report rather
    than raised.

    :param path: Path to bundle directory or .zip file
    :param restore_remote: If True, restore remote server state (dataset_id, url, etc.)
    :return: Tuple of (Plottable, BundleReadReport)
    :raises ImportError: If pydantic >= 2.0 is not installed
    :raises FileNotFoundError: If path doesn't exist
    """
    _require_pydantic()
    report = BundleReadReport()

    fmt = detect_format(path)
    bundle_dir = path
    tmp_dir = None

    try:
        if fmt == "zip":
            import tempfile
            tmp_dir = tempfile.mkdtemp(prefix="graphistry_bundle_read_")
            zip_to_dir(path, tmp_dir)
            bundle_dir = tmp_dir

        manifest = read_manifest(bundle_dir)
        artifacts = manifest.get("artifacts", {})
        file_shas = manifest.get("files", {})

        # Surface bundles written by an incompatible schema instead of
        # silently interpreting them with v1 rules.
        schema = manifest.get("schema_version")
        if schema is not None and (
            str(schema).split('.')[0] != SCHEMA_VERSION.split('.')[0]
        ):
            report.warnings.append(
                f"Bundle schema_version {schema} differs from supported "
                f"{SCHEMA_VERSION}; some fields may not load correctly"
            )

        # --- Verify file integrity ---
        for rel_path, expected_sha in file_shas.items():
            # Manifest content is untrusted: never hash files outside the bundle.
            abs_path = _resolve_in_bundle(bundle_dir, rel_path)
            if abs_path is None:
                report.warnings.append(f"Unsafe path outside bundle: {rel_path}")
                report.integrity_ok = False
                continue
            if not os.path.exists(abs_path):
                report.warnings.append(f"Missing file: {rel_path}")
                report.integrity_ok = False
                continue
            actual_sha = sha256_file(abs_path)
            if actual_sha != expected_sha:
                report.warnings.append(
                    f"SHA256 mismatch: {rel_path} "
                    f"(expected {expected_sha[:16]}..., got {actual_sha[:16]}...)"
                )
                report.integrity_ok = False

        def _load_artifact(name: str) -> Optional[pd.DataFrame]:
            """Read one parquet artifact, returning None when it cannot be loaded."""
            art = artifacts[name]
            rel = art['path']
            if _resolve_in_bundle(bundle_dir, rel) is None:
                report.warnings.append(f"Unsafe path outside bundle: {rel}")
                report.integrity_ok = False
                return None
            expected = art.get('sha256')
            # The loop above already hashed this file against the same digest;
            # re-checking would double the I/O and duplicate the warning.
            if expected is not None and file_shas.get(rel) == expected:
                expected = None
            return read_df_parquet(rel, bundle_dir, expected, report)

        # --- Hydration sequence ---
        import graphistry
        from graphistry.io.metadata import deserialize_plottable_metadata

        g: 'Plottable' = graphistry.bind()  # type: ignore[assignment]

        # Load edges
        if '_edges' in artifacts:
            edges_df = _load_artifact('_edges')
            if edges_df is not None:
                report.artifacts_loaded.append('_edges')
                g = g.edges(edges_df)
        else:
            report.warnings.append("No edges artifact in manifest")

        # Load nodes
        if '_nodes' in artifacts:
            nodes_df = _load_artifact('_nodes')
            if nodes_df is not None:
                report.artifacts_loaded.append('_nodes')
                g = g.nodes(nodes_df)

        # Apply plottable metadata (bindings, encodings, name, desc, style)
        plottable_metadata = manifest.get("plottable_metadata", {})
        if plottable_metadata:
            g = deserialize_plottable_metadata(plottable_metadata, g)  # type: ignore[assignment]

        # Create final copy for direct field assignment
        result = copy.copy(g)

        # Restore settings directly (avoid .settings() merge semantics)
        settings = manifest.get("settings", {})
        if "height" in settings and settings["height"] is not None:
            result._height = settings["height"]
        if "render" in settings and settings["render"] is not None:
            result._render = settings["render"]
        if "url_params" in settings and settings["url_params"] is not None:
            result._url_params = settings["url_params"]

        # --- Load Tier 2 DF artifacts ---
        for field_name in TIER2_DF_FIELDS:
            if field_name in artifacts:
                df = _load_artifact(field_name)
                if df is not None:
                    setattr(result, field_name, df)
                    report.artifacts_loaded.append(field_name)
                else:
                    report.artifacts_skipped.append(field_name)

        # --- Restore Tier 2 JSON fields ---
        # NOTE(review): JSON turns tuples into lists and all dict keys into
        # strings. Only _partition_offsets is repaired below because its key
        # type is declared on Plottable; _index_to_entity presumably has int
        # keys as well. Confirm before relying on it.
        json_sections = (
            ("algorithm_config", TIER2_JSON_ALGO_FIELDS),
            ("kg_config", TIER2_JSON_KG_FIELDS),
            ("layout", TIER2_JSON_LAYOUT_FIELDS),
            ("graph_indices", TIER2_JSON_INDEX_FIELDS),
        )
        for section, field_names in json_sections:
            values = manifest.get(section, {})
            for field_name in field_names:
                if field_name not in values:
                    continue
                value = values[field_name]
                if field_name == "_partition_offsets":
                    value = _restore_partition_offsets(value)
                setattr(result, field_name, value)

        # --- Remote state ---
        remote = manifest.get("remote", {})
        has_remote = any(
            remote.get(k) is not None
            for k in ["dataset_id", "url", "nodes_file_id", "edges_file_id"]
        )
        if has_remote:
            if restore_remote:
                if remote.get("dataset_id") is not None:
                    result._dataset_id = remote["dataset_id"]
                if remote.get("url") is not None:
                    result._url = remote["url"]
                if remote.get("nodes_file_id") is not None:
                    result._nodes_file_id = remote["nodes_file_id"]
                if remote.get("edges_file_id") is not None:
                    result._edges_file_id = remote["edges_file_id"]
                if remote.get("privacy") is not None:
                    result._privacy = remote["privacy"]
            else:
                report.remote_state_skipped = True
                warnings.warn(
                    "Bundle contains remote server state (dataset_id, url, etc.) "
                    "which was not restored. Pass restore_remote=True to restore it.",
                    UserWarning,
                    stacklevel=2,
                )

        return (result, report)

    finally:
        # Extracted zip contents are only needed while loading.
        if tmp_dir is not None:
            import shutil
            shutil.rmtree(tmp_dir, ignore_errors=True)
#!/usr/bin/env python
"""Generate the v1_bundle golden fixture.

Run once to create/refresh the fixture files:
    python graphistry/tests/fixtures/generate_v1_bundle.py
"""
import hashlib
import json
import os

import pandas as pd


def sha256_file(path):
    """Return the SHA256 hex digest of the file at ``path`` (64 KiB reads)."""
    h = hashlib.sha256()
    with open(path, 'rb') as f:
        for chunk in iter(lambda: f.read(65536), b''):
            h.update(chunk)
    return h.hexdigest()


def _write_parquet(df, data_dir, stem):
    """Write ``df`` to ``data_dir/{stem}.parquet`` and return the file's SHA256."""
    path = os.path.join(data_dir, f'{stem}.parquet')
    df.to_parquet(path)
    return sha256_file(path)


def main(fixture_dir=None):
    """Write the golden v1 bundle (three parquet files plus manifest.json).

    :param fixture_dir: Output directory. Defaults to ``v1_bundle`` next to
        this script, which is what the no-argument script entry point uses.
    """
    if fixture_dir is None:
        here = os.path.dirname(os.path.abspath(__file__))
        fixture_dir = os.path.join(here, 'v1_bundle')
    data_dir = os.path.join(fixture_dir, 'data')
    os.makedirs(data_dir, exist_ok=True)

    # Create edges
    edges_df = pd.DataFrame({
        's': ['a', 'b', 'c'],
        'd': ['b', 'c', 'a'],
        'w': [1.0, 2.0, 3.0],
    })
    # Create nodes
    nodes_df = pd.DataFrame({
        'id': ['a', 'b', 'c'],
        'label': ['Node A', 'Node B', 'Node C'],
    })
    # Create xy (tier 2)
    xy_df = pd.DataFrame({
        'x': [0.1, 0.2, 0.3],
        'y': [0.4, 0.5, 0.6],
    })

    # NOTE(review): the parquet bytes, and so these digests, presumably vary
    # with the installed parquet engine version. Confirm before treating a
    # regenerated fixture as byte-identical to the committed one.
    edges_sha = _write_parquet(edges_df, data_dir, '_edges')
    nodes_sha = _write_parquet(nodes_df, data_dir, '_nodes')
    xy_sha = _write_parquet(xy_df, data_dir, '_xy')

    manifest = {
        'schema_version': '1.0',
        'created_at': '2025-01-01T00:00:00+00:00',
        'python_version': '3.10.0',
        'graphistry_version': '0.35.0',
        'plottable_metadata': {
            'bindings': {
                'source': 's',
                'destination': 'd',
                'node': 'id',
                'edge_weight': 'w',
            },
            'encodings': {
                'edge_weight': 'w',
            },
            'metadata': {
                'name': 'Golden Test Graph',
                'description': 'A test graph for v1 bundle compatibility',
            },
        },
        'settings': {
            'height': 600,
            'render': 'g',
            'url_params': {'info': 'true', 'play': '2000'},
        },
        'remote': {
            'dataset_id': 'golden_dataset_123',
            'url': 'https://hub.graphistry.com/graph/golden_dataset_123',
            'nodes_file_id': None,
            'edges_file_id': None,
            'privacy': None,
        },
        'algorithm_config': {
            '_n_components': 2,
            '_metric': 'euclidean',
        },
        'kg_config': {},
        'layout': {},
        'graph_indices': {},
        'artifacts': {
            '_edges': {
                'kind': 'parquet',
                'path': 'data/_edges.parquet',
                'sha256': edges_sha,
            },
            '_nodes': {
                'kind': 'parquet',
                'path': 'data/_nodes.parquet',
                'sha256': nodes_sha,
            },
            '_xy': {
                'kind': 'parquet',
                'path': 'data/_xy.parquet',
                'sha256': xy_sha,
            },
        },
        'files': {
            'data/_edges.parquet': edges_sha,
            'data/_nodes.parquet': nodes_sha,
            'data/_xy.parquet': xy_sha,
        },
    }

    manifest_path = os.path.join(fixture_dir, 'manifest.json')
    # Explicit UTF-8 keeps the file independent of the platform's locale
    # encoding; the trailing newline keeps it POSIX text.
    with open(manifest_path, 'w', encoding='utf-8') as f:
        json.dump(manifest, f, indent=2)
        f.write('\n')

    print(f'Golden fixture generated at {fixture_dir}')
    print(f'  edges sha: {edges_sha}')
    print(f'  nodes sha: {nodes_sha}')
    print(f'  xy sha:    {xy_sha}')


if __name__ == '__main__':
    main()
b/graphistry/tests/fixtures/v1_bundle/data/_nodes.parquet new file mode 100644 index 0000000000..c01b55142b Binary files /dev/null and b/graphistry/tests/fixtures/v1_bundle/data/_nodes.parquet differ diff --git a/graphistry/tests/fixtures/v1_bundle/data/_xy.parquet b/graphistry/tests/fixtures/v1_bundle/data/_xy.parquet new file mode 100644 index 0000000000..02e6cff253 Binary files /dev/null and b/graphistry/tests/fixtures/v1_bundle/data/_xy.parquet differ diff --git a/graphistry/tests/fixtures/v1_bundle/manifest.json b/graphistry/tests/fixtures/v1_bundle/manifest.json new file mode 100644 index 0000000000..65127744fc --- /dev/null +++ b/graphistry/tests/fixtures/v1_bundle/manifest.json @@ -0,0 +1,65 @@ +{ + "schema_version": "1.0", + "created_at": "2025-01-01T00:00:00+00:00", + "python_version": "3.10.0", + "graphistry_version": "0.35.0", + "plottable_metadata": { + "bindings": { + "source": "s", + "destination": "d", + "node": "id", + "edge_weight": "w" + }, + "encodings": { + "edge_weight": "w" + }, + "metadata": { + "name": "Golden Test Graph", + "description": "A test graph for v1 bundle compatibility" + } + }, + "settings": { + "height": 600, + "render": "g", + "url_params": { + "info": "true", + "play": "2000" + } + }, + "remote": { + "dataset_id": "golden_dataset_123", + "url": "https://hub.graphistry.com/graph/golden_dataset_123", + "nodes_file_id": null, + "edges_file_id": null, + "privacy": null + }, + "algorithm_config": { + "_n_components": 2, + "_metric": "euclidean" + }, + "kg_config": {}, + "layout": {}, + "graph_indices": {}, + "artifacts": { + "_edges": { + "kind": "parquet", + "path": "data/_edges.parquet", + "sha256": "9cc1ca161bd8c82d2144c2d25fa1fc878794c9b4c448e9245f4eb40fc975ce00" + }, + "_nodes": { + "kind": "parquet", + "path": "data/_nodes.parquet", + "sha256": "63412ad383eb6e74bcecc390bcf3c86e72cf18224d00b54b223be1c133219280" + }, + "_xy": { + "kind": "parquet", + "path": "data/_xy.parquet", + "sha256": 
"503a6aa07ff3efd3be70e7ef400a44f97aa43fef0c60146d679848a68268fa48" + } + }, + "files": { + "data/_edges.parquet": "9cc1ca161bd8c82d2144c2d25fa1fc878794c9b4c448e9245f4eb40fc975ce00", + "data/_nodes.parquet": "63412ad383eb6e74bcecc390bcf3c86e72cf18224d00b54b223be1c133219280", + "data/_xy.parquet": "503a6aa07ff3efd3be70e7ef400a44f97aa43fef0c60146d679848a68268fa48" + } +} \ No newline at end of file diff --git a/graphistry/tests/test_io_bundle.py b/graphistry/tests/test_io_bundle.py new file mode 100644 index 0000000000..3d3dc71a92 --- /dev/null +++ b/graphistry/tests/test_io_bundle.py @@ -0,0 +1,209 @@ +"""Tests for graphistry.io.bundle — generic bundle engine.""" +import hashlib +import json +import os +import tempfile +import unittest +import zipfile + +import pandas as pd + +from graphistry.io.bundle import ( + BundleReadReport, + BundleWriteReport, + _require_pydantic, + detect_format, + dir_to_zip, + read_df_parquet, + read_manifest, + sha256_bytes, + sha256_file, + write_df_parquet, + write_manifest, + zip_to_dir, +) + + +class TestSHA256(unittest.TestCase): + def test_sha256_bytes_consistency(self): + data = b"hello world" + expected = hashlib.sha256(data).hexdigest() + self.assertEqual(sha256_bytes(data), expected) + + def test_sha256_file_consistency(self): + with tempfile.NamedTemporaryFile(delete=False) as f: + f.write(b"test data for hashing") + f.flush() + path = f.name + try: + expected = hashlib.sha256(b"test data for hashing").hexdigest() + self.assertEqual(sha256_file(path), expected) + finally: + os.unlink(path) + + def test_sha256_empty(self): + result = sha256_bytes(b"") + expected = hashlib.sha256(b"").hexdigest() + self.assertEqual(result, expected) + + +class TestParquetRoundtrip(unittest.TestCase): + def test_write_read_roundtrip(self): + df = pd.DataFrame({"a": [1, 2, 3], "b": ["x", "y", "z"]}) + with tempfile.TemporaryDirectory() as td: + report = BundleWriteReport() + art = write_df_parquet(df, "test_df", td, report) + 
class TestManifest(unittest.TestCase):
    """Manifest persistence: what write_manifest stores, read_manifest returns."""

    def test_write_read_roundtrip(self):
        """A nested manifest dict survives a write/read cycle unchanged."""
        edges_artifact = {"kind": "parquet", "path": "data/_edges.parquet"}
        original = {
            "schema_version": "1.0",
            "artifacts": {"_edges": edges_artifact},
            "nested": {"key": [1, 2, 3]},
        }
        with tempfile.TemporaryDirectory() as bundle_dir:
            write_manifest(original, bundle_dir)
            reloaded = read_manifest(bundle_dir)
            self.assertEqual(reloaded, original)
class TestRequirePydantic(unittest.TestCase):
    """Checks both outcomes of the lazy pydantic import guard."""

    def test_pydantic_importable(self):
        """_require_pydantic returns the module, or raises a helpful ImportError.

        When pydantic is installed the returned object must expose
        ``BaseModel``. When it is not, the ImportError message must mention
        both "pydantic" and "serialization".
        """
        # Keep the try body to the single call that can raise, and inspect the
        # exception that was actually caught. Re-invoking the guard inside a
        # second try/except would pass vacuously if the retry did not raise.
        try:
            pydantic = _require_pydantic()
        except ImportError as err:
            message = str(err)
            self.assertIn("pydantic", message)
            self.assertIn("serialization", message)
        else:
            self.assertTrue(hasattr(pydantic, 'BaseModel'))
class TestRoundtripDir(unittest.TestCase):
    """Save a graph as a directory bundle, load it back, compare."""

    def test_basic_roundtrip(self):
        """Bindings, metadata and both dataframes survive to_file/from_file."""
        original = _make_graph()
        with tempfile.TemporaryDirectory() as workdir:
            bundle_path = os.path.join(workdir, 'bundle')

            returned, write_report = to_file(original, bundle_path)
            self.assertIs(returned, original)
            for artifact in ('_edges', '_nodes'):
                self.assertIn(artifact, write_report.artifacts_written)

            restored, read_report = from_file(bundle_path)
            self.assertTrue(read_report.integrity_ok)

            # Tier 1 bindings and metadata, checked in the original order
            expected_attrs = (
                ('_source', 's'),
                ('_destination', 'd'),
                ('_node', 'id'),
                ('_edge_weight', 'w'),
                ('_name', 'Test Graph'),
                ('_description', 'A test'),
            )
            for attr, expected in expected_attrs:
                self.assertEqual(getattr(restored, attr), expected)

            # Tabular payloads
            pd.testing.assert_frame_equal(restored._edges, original._edges)
            pd.testing.assert_frame_equal(restored._nodes, original._nodes)

    def test_settings_roundtrip(self):
        """Height and url_params set via settings() come back after reload."""
        url_params = {'play': '0', 'info': 'true'}
        configured = _make_graph().settings(height=800, url_params=url_params)
        with tempfile.TemporaryDirectory() as workdir:
            bundle_path = os.path.join(workdir, 'bundle')
            to_file(configured, bundle_path)
            restored, _ = from_file(bundle_path)
            self.assertEqual(restored._height, 800)
            self.assertEqual(restored._url_params, {'play': '0', 'info': 'true'})
class TestRemoteState(unittest.TestCase):
    """Remote fields (_dataset_id, _url) are only restored on explicit request."""

    @staticmethod
    def _graph_with_remote_state():
        """Return a shallow copy of the basic test graph carrying remote fields."""
        with_remote = copy.copy(_make_graph())
        with_remote._dataset_id = 'test_ds'
        with_remote._url = 'https://example.com/graph'
        return with_remote

    def test_remote_dropped_by_default(self):
        """Plain from_file() clears remote fields, flags the report and warns."""
        source_graph = self._graph_with_remote_state()
        with tempfile.TemporaryDirectory() as workdir:
            bundle_path = os.path.join(workdir, 'bundle')

            # Warnings raised while writing are captured and discarded.
            with warnings.catch_warnings(record=True):
                to_file(source_graph, bundle_path)

            with warnings.catch_warnings(record=True) as captured:
                warnings.simplefilter("always")
                loaded, read_report = from_file(bundle_path)

            self.assertTrue(read_report.remote_state_skipped)
            self.assertIsNone(loaded._dataset_id)
            self.assertIsNone(loaded._url)

            # At least one captured warning must mention the skipped state.
            warned = any(
                'remote server state' in str(entry.message).lower()
                for entry in captured
            )
            self.assertTrue(warned)

    def test_remote_restored_when_requested(self):
        """restore_remote=True brings back the saved dataset id and URL."""
        source_graph = self._graph_with_remote_state()
        with tempfile.TemporaryDirectory() as workdir:
            bundle_path = os.path.join(workdir, 'bundle')
            to_file(source_graph, bundle_path)
            loaded, read_report = from_file(bundle_path, restore_remote=True)
            self.assertFalse(read_report.remote_state_skipped)
            self.assertEqual(loaded._dataset_id, 'test_ds')
            self.assertEqual(loaded._url, 'https://example.com/graph')
class TestTripwire(unittest.TestCase):
    """Detect new fields added to Plottable/PlotterBase that aren't
    accounted for in the serialization field groups."""

    @staticmethod
    def _parse_init_fields(source):
        """Return every ``X`` that *source* assigns to as ``self.X``.

        The source is parsed with ``ast`` rather than split on ``=``, so that
        calls with keyword arguments (``self.f(k=1)``), comparisons
        (``self.a == b``) and nested targets (``self.a.b = 1``) are not
        mistaken for field assignments. Tuple/list unpacking targets are
        flattened, and annotated or augmented assignments are included.

        :param source: Source text of a function; may be uniformly indented.
        :returns: Set of attribute names assigned directly on ``self``.
        """
        import ast
        import textwrap

        def self_attrs(target):
            # Flatten unpacking targets; keep only direct ``self.<name>``.
            if isinstance(target, (ast.Tuple, ast.List)):
                for element in target.elts:
                    yield from self_attrs(element)
            elif isinstance(target, ast.Starred):
                yield from self_attrs(target.value)
            elif (
                isinstance(target, ast.Attribute)
                and isinstance(target.value, ast.Name)
                and target.value.id == 'self'
            ):
                yield target.attr

        fields = set()
        # Method source from inspect.getsource is indented; dedent before parsing.
        for node in ast.walk(ast.parse(textwrap.dedent(source))):
            if isinstance(node, ast.Assign):
                targets = node.targets
            elif isinstance(node, ast.AugAssign):
                targets = [node.target]
            elif isinstance(node, ast.AnnAssign) and node.value is not None:
                # A bare annotation without a value does not assign anything.
                targets = [node.target]
            else:
                continue
            for target in targets:
                fields.update(self_attrs(target))
        return fields

    def _get_protocol_fields(self):
        """Extract field names from Plottable Protocol annotations."""
        return set(typing.get_type_hints(Plottable))

    def _get_init_fields(self):
        """Extract field names assigned on ``self`` in PlotterBase.__init__."""
        return self._parse_init_fields(inspect.getsource(PlotterBase.__init__))

    def test_no_unknown_protocol_fields(self):
        """Fail if Plottable Protocol has fields not in ALL_KNOWN_FIELDS."""
        # Resolve the hints once; they do not change between iterations.
        hints = typing.get_type_hints(Plottable)
        known = set(ALL_KNOWN_FIELDS)
        # get_type_hints only reports annotated attributes, so Protocol
        # methods never show up here. The one exemption applied is for
        # annotations whose origin is ``type`` (i.e. ``Type[...]``).
        data_unknown = {
            name
            for name, hint in hints.items()
            if name not in known
            and getattr(hint, '__origin__', None) is not type
        }
        self.assertEqual(
            data_unknown, set(),
            f"New fields in Plottable Protocol not in ALL_KNOWN_FIELDS: {data_unknown}. "
            "Add them to the appropriate tier in plottable_bundle.py."
        )

    def test_no_unknown_init_fields(self):
        """Fail if PlotterBase.__init__ sets fields not in ALL_KNOWN_FIELDS."""
        unknown = self._get_init_fields() - set(ALL_KNOWN_FIELDS)
        self.assertEqual(
            unknown, set(),
            f"New fields in PlotterBase.__init__ not in ALL_KNOWN_FIELDS: {unknown}. "
            "Add them to the appropriate tier in plottable_bundle.py."
        )

    def test_no_phantom_fields(self):
        """Fail if ALL_KNOWN_FIELDS has fields not in Protocol or PlotterBase.__init__."""
        all_real = self._get_protocol_fields() | self._get_init_fields()
        phantom = set(ALL_KNOWN_FIELDS) - all_real
        self.assertEqual(
            phantom, set(),
            f"Phantom fields in ALL_KNOWN_FIELDS not in Protocol or __init__: {phantom}. "
            "Remove them from plottable_bundle.py."
        )
self.assertEqual(g._url_params, {'info': 'true', 'play': '2000'}) + + # Check edges data + self.assertEqual(len(g._edges), 3) + self.assertListEqual(list(g._edges['s']), ['a', 'b', 'c']) + + # Check nodes data + self.assertEqual(len(g._nodes), 3) + self.assertListEqual(list(g._nodes['id']), ['a', 'b', 'c']) + + # Check xy (tier 2) + self.assertIsNotNone(g._xy) + self.assertEqual(len(g._xy), 3) + + # Check algorithm config + self.assertEqual(g._n_components, 2) + self.assertEqual(g._metric, 'euclidean') + + def test_golden_remote_skipped_by_default(self): + with warnings.catch_warnings(record=True): + warnings.simplefilter("always") + g, rr = from_file(self.fixture_dir) + self.assertTrue(rr.remote_state_skipped) + self.assertIsNone(g._dataset_id) + + def test_golden_remote_restored(self): + g, rr = from_file(self.fixture_dir, restore_remote=True) + self.assertEqual(g._dataset_id, 'golden_dataset_123') + + def test_schema_version(self): + from graphistry.io.bundle import read_manifest + manifest = read_manifest(self.fixture_dir) + self.assertEqual(manifest['schema_version'], '1.0') + + +if __name__ == '__main__': + unittest.main() diff --git a/setup.py b/setup.py index e8d4a74089..15c8c002ad 100755 --- a/setup.py +++ b/setup.py @@ -55,7 +55,8 @@ def unique_flatten_dict(d): 'nodexl': ['openpyxl>=3.1.5', 'xlrd'], 'jupyter': ['ipython'], 'spanner': ['google-cloud-spanner'], - 'kusto': ['azure-kusto-data', 'azure-identity'] + 'kusto': ['azure-kusto-data', 'azure-identity'], + 'serialization': ['pydantic>=2.0'], } base_extras_heavy = {