diff --git a/README.md b/README.md index 6b04fcb..7c098b9 100644 --- a/README.md +++ b/README.md @@ -844,7 +844,7 @@ Overlay arbitrary 3D vectors as arrows on the rendered image via a JSON file. Us ```bash -xyzrender ethanol.xyz --vectors ethanol_dip.json -o ethanol_dip.svg +xyzrender ethanol.xyz --vector ethanol_dip.json -o ethanol_dip.svg ``` Each entry in the JSON array defines one arrow: @@ -1153,7 +1153,7 @@ Available rotation axes: `x`, `y`, `z`, `xy`, `xz`, `yz`, `yx`, `zx`, `zy`. Pref | `--label-size PT` | Label font size (overrides preset) | | `--cmap FILE` | Per-atom property colormap (Viridis, 1-indexed) | | `--cmap-range VMIN VMAX` | Explicit colormap range (default: auto from file) | -| `--vectors FILE` | JSON file of vector arrows to overlay (see Vector arrows section) | +| `--vector FILE` | JSON file of vector arrows to overlay (see Vector arrows section) | | `--vector-scale FACTOR` | Global length scale for all vector arrows (default: 1.0) | | **Crystal** | | | `--crystal [{vasp,qe}]` | Load as crystal via phonopy; format auto-detected or specify explicitly | diff --git a/docs/source/cli_reference.md b/docs/source/cli_reference.md index 12e2345..05e42a1 100644 --- a/docs/source/cli_reference.md +++ b/docs/source/cli_reference.md @@ -102,7 +102,7 @@ Full flag reference for `xyzrender`. See also `xyzrender --help`. | Flag | Description | |------|-------------| -| `--vectors FILE` | Path to a JSON file defining 3D vector arrows for overlay | +| `--vector FILE` | Path to a JSON file defining 3D vector arrows for overlay | | `--vector-scale` | Global length multiplier for all vector arrows | ## GIF animations diff --git a/docs/source/examples/annotations.md b/docs/source/examples/annotations.md index f030497..c4edbb0 100644 --- a/docs/source/examples/annotations.md +++ b/docs/source/examples/annotations.md @@ -92,7 +92,7 @@ Overlay arbitrary 3D vectors as arrows on the rendered image via a JSON file. Us ```bash -xyzrender ethanol.xyz --vectors ethanol_dip.json -o ethanol_dip.svg +xyzrender ethanol.xyz --vector ethanol_dip.json -o ethanol_dip.svg ``` Each entry in the JSON array defines one arrow: diff --git a/examples/examples.ipynb b/examples/examples.ipynb index e1f73c7..0864cc9 100644 --- a/examples/examples.ipynb +++ b/examples/examples.ipynb @@ -1360,122 +1360,19 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": null, "id": "c2cb9dfe", "metadata": {}, - "outputs": [ - { - "data": { - "image/svg+xml": [ - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " μ\n", - " \n", - " \n", - " \n", - " \n", - "" - ], - "text/plain": [ - "" - ] - }, - "execution_count": 7, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "from xyzrender import load, render\n", - "\n", - "etoh = load(\"structures/ethanol.xyz\")\n", - "vec = {\n", - " \"anchor\": \"center\",\n", - " \"vectors\": [\n", - " {\n", - " \"origin\": \"com\",\n", - " \"vector\": [1.0320170291976951, -0.042708195030485986, -1.332397645862797],\n", - " \"color\": \"firebrick\",\n", - " \"label\": \"μ\",\n", - " }\n", - " ],\n", - "}\n", - "render(etoh, vectors=vec, vector_scale=1.5)" - ] + "outputs": [], + "source": "from xyzrender import load, render\n\netoh = load(\"structures/ethanol.xyz\")\nvec = {\n \"anchor\": \"center\",\n \"vectors\": [\n {\n \"origin\": \"com\",\n \"vector\": [1.0320170291976951, -0.042708195030485986, -1.332397645862797],\n \"color\": \"firebrick\",\n \"label\": \"μ\",\n }\n ],\n}\nrender(etoh, vector=vec, vector_scale=1.5)" }, { "cell_type": "code", - "execution_count": 10, + "execution_count": null, "id": "4fa885ba", "metadata": {}, - "outputs": [ - { - "data": { - "image/svg+xml": [ - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " E\n", - "" - ], - "text/plain": [ - "" - ] - }, - "execution_count": 10, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "render(etoh, hy=True, vectors=\"structures/ethanol_forces_efield.json\", vector_scale=2)" - ] + "outputs": [], + "source": "render(etoh, hy=True, vector=\"structures/ethanol_forces_efield.json\", vector_scale=2)" }, { "cell_type": "markdown", @@ -3681,4 +3578,4 @@ }, "nbformat": 4, "nbformat_minor": 5 -} +} \ No newline at end of file diff --git a/examples/generate.sh b/examples/generate.sh index 6f398a5..ea4e014 100644 --- a/examples/generate.sh +++ b/examples/generate.sh @@ -75,8 +75,8 @@ xyzrender "$DIR/bimp.out" -o "$IMG/bimp_ts_nci.svg" --ts --gif-trj --vdw 84-169 xyzrender "$DIR/bimp.out" -o "$IMG/bimp_ts_nci.svg" --gif-ts --gif-rot --vdw 84-169 --nci -go "$IMG/bimp_nci_ts.gif" echo "=== Vector arrows ===" -xyzrender "$DIR/ethanol.xyz" --vectors "$DIR/ethanol_dip.json" -o "$IMG/ethanol_dip.svg" --gif-rot -go "$IMG/ethanol_dip.gif" # dipole at center of mass, with rotation -xyzrender "$DIR/ethanol.xyz" --hy --vectors "$DIR/ethanol_forces_efield.json" --vector-scale 1.5 -o "$IMG/ethanol_forces_efield.svg" -go "$IMG/ethanol_forces_efield.gif" --gif-rot # per-atom forces, with rotation +xyzrender "$DIR/ethanol.xyz" --vector "$DIR/ethanol_dip.json" -o "$IMG/ethanol_dip.svg" --gif-rot -go "$IMG/ethanol_dip.gif" # dipole at center of mass, with rotation +xyzrender "$DIR/ethanol.xyz" --hy --vector "$DIR/ethanol_forces_efield.json" --vector-scale 1.5 -o "$IMG/ethanol_forces_efield.svg" -go "$IMG/ethanol_forces_efield.gif" --gif-rot # per-atom forces, with rotation echo "=== Crystal / unit cell ===" xyzrender "$DIR/caffeine_cell.xyz" --cell -o "$IMG/caffeine_cell.svg" --no-orient --gif-rot -go "$IMG/caffeine_cell.gif" diff --git a/src/xyzrender/api.py b/src/xyzrender/api.py index 643aee8..092d8ce 100644 --- a/src/xyzrender/api.py +++ b/src/xyzrender/api.py @@ -419,7 +419,7 @@ def render( labels: list[str] | None = None, label_file: str | None = None, # --- Vector arrows --- - vectors: str | Path | dict | list[VectorArrow] | None = None, + vector: str | Path | dict | list[VectorArrow] | None = None, vector_scale: float | None = None, vector_color: str | None = None, # --- Surface opacity --- @@ -548,6 +548,7 @@ def render( if not isinstance(config, str): # Pre-built RenderConfig — shallow copy so we don't mutate the caller's object cfg = copy.copy(config) + cfg.vectors = list(cfg.vectors) if _orient is not None: cfg.auto_orient = _orient elif mol.oriented: @@ -665,13 +666,23 @@ def render( oriented=mol.oriented, ) + # --- Vectors (user-supplied + crystal axes) --- + _combine_vector_sources( + cfg, + rmol.graph, + vector=vector, + vector_scale=vector_scale, + vector_color=vector_color, + cell_data=rmol.cell_data, + axes=axes, + ) + # --- Cell / crystal config --- if rmol.cell_data is not None: _apply_cell_config( rmol, cfg, no_cell=no_cell, - axes=axes, axis=axis, ghosts=ghosts, cell_color=cell_color, @@ -691,22 +702,6 @@ def render( inline = [s.split() for s in labels] if labels else None cfg.annotations = parse_annotations(inline_specs=inline, file_path=label_file, graph=rmol.graph) - # --- Vector arrows --- - if vector_scale is not None: - cfg.vector_scale = vector_scale - if vector_color is not None: - from xyzrender.types import resolve_color - - cfg.vector_color = resolve_color(vector_color) - if vectors is not None: - if isinstance(vectors, list): - cfg.vectors = vectors - else: - from xyzrender.annotations import load_vectors - - _vec_src = vectors if isinstance(vectors, dict) else Path(vectors) - cfg.vectors = load_vectors(_vec_src, rmol.graph, default_color=cfg.vector_color) - # --- Early overlay validation (before ghost atoms are added to g1) --- if overlay is not None and mol.cell_data is not None: msg = "overlay= is mutually exclusive with crystal/cell display" @@ -898,7 +893,7 @@ def render_gif( # --- NCI detection (gif_ts / gif_trj / gif_rot) --- detect_nci: bool = False, # --- Vector arrows (gif_rot only) --- - vectors: str | Path | dict | list[VectorArrow] | None = None, + vector: str | Path | dict | list[VectorArrow] | None = None, vector_scale: float | None = None, vector_color: str | None = None, # --- Surfaces (gif_rot only) --- @@ -1013,6 +1008,7 @@ def render_gif( # Resolve config if not isinstance(config, str): cfg = copy.copy(config) + cfg.vectors = list(cfg.vectors) # Apply hull overrides to pre-built config if hull is not None: if hull == "rings": @@ -1193,21 +1189,17 @@ def render_gif( aligned2 = align(ref_graph, g2) ref_graph = merge_graphs(ref_graph, g2, aligned2, overlay_color=cfg.overlay_color) - # --- Vector arrows (gif_rot only; needs ref_graph for COM) --- - if vector_scale is not None: - cfg.vector_scale = vector_scale - if vector_color is not None: - from xyzrender.types import resolve_color - - cfg.vector_color = resolve_color(vector_color) - if vectors is not None: - if isinstance(vectors, list): - cfg.vectors = vectors - else: - from xyzrender.annotations import load_vectors - - _vec_src = vectors if isinstance(vectors, dict) else Path(vectors) - cfg.vectors = load_vectors(_vec_src, ref_graph, default_color=cfg.vector_color) + # --- Vectors (user-supplied + crystal axes; gif_rot only) --- + _cell_data_for_vecs = molecule.cell_data if isinstance(molecule, Molecule) else None + _combine_vector_sources( + cfg, + ref_graph, + vector=vector, + vector_scale=vector_scale, + vector_color=vector_color, + cell_data=_cell_data_for_vecs, + axes=axes, + ) cube_data = molecule.cube_data if isinstance(molecule, Molecule) else None @@ -1223,7 +1215,6 @@ def render_gif( _gif_mol, cfg, no_cell=no_cell, - axes=axes, axis=axis, ghosts=ghosts, cell_color=cell_color, @@ -1322,12 +1313,65 @@ def _resolve_cmap( return load_cmap(str(cmap), graph) +def _combine_vector_sources( + cfg: "RenderConfig", + graph: "nx.Graph", + *, + vector=None, + vector_scale: "float | None" = None, + vector_color: "str | None" = None, + cell_data: "CellData | None" = None, + axes: bool = True, +) -> None: + """Populate ``cfg.vectors`` from user-supplied vectors and crystal axis arrows. + + Must be called *before* :func:`_apply_cell_config` so that all vectors are + already in ``cfg.vectors`` when :func:`orient_hkl_to_view` applies the HKL + rotation to the whole list in one pass. + """ + if vector_scale is not None: + cfg.vector_scale = vector_scale + if vector_color is not None: + from xyzrender.types import resolve_color + + cfg.vector_color = resolve_color(vector_color) + if vector is not None: + if not isinstance(vector, list): + from xyzrender.annotations import load_vectors + + _vec_src = vector if isinstance(vector, dict) else Path(vector) + vector = load_vectors(_vec_src, graph, default_color=cfg.vector_color) + cfg.vectors.extend(vector) + if cell_data is not None and axes: + from xyzrender.types import VectorArrow + + lat = cell_data.lattice + orig3d = cell_data.cell_origin + for vec, color, label in zip(lat, cfg.axis_colors, ("a", "b", "c"), strict=True): + length = float(np.linalg.norm(vec)) + if length < 1e-6: + continue + frac = min(0.25, 2.0 / length) + cfg.vectors.append( + VectorArrow( + vector=vec * frac, + origin=orig3d, + color=color, + label=label, + scale=1.0, + draw_on_top=True, + is_axis=True, + font_size=cfg.label_font_size * 1.8, + width=cfg.bond_width * 1.1, + ) + ) + + def _apply_cell_config( mol: Molecule, cfg: RenderConfig, *, no_cell: bool, - axes: bool, axis: str | None, ghosts: bool | None, cell_color: str | None, @@ -1357,7 +1401,7 @@ def _apply_cell_config( if axis is not None: from xyzrender.viewer import orient_hkl_to_view - orient_hkl_to_view(mol.graph, cell_data, axis) + orient_hkl_to_view(mol.graph, cell_data, axis, cfg) cfg.auto_orient = False # Ghost (periodic image) atoms — default: on when cell_data is present @@ -1367,31 +1411,6 @@ def _apply_cell_config( add_crystal_images(mol.graph, cell_data) - # Crystal axes a/b/c as annotation vectors at the cell origin - if axes: - from xyzrender.types import VectorArrow - - lat = cell_data.lattice - orig3d = cell_data.cell_origin - for vec, color, label in zip(lat, cfg.axis_colors, ("a", "b", "c"), strict=True): - length = float(np.linalg.norm(vec)) - if length < 1e-6: - continue - # Arrow spans 25% of the cell edge (max 2 Å) from the origin corner - frac = min(0.25, 2.0 / length) - cfg.vectors.append( - VectorArrow( - vector=vec * frac, - origin=orig3d, - color=color, - label=label, - scale=1.0, - draw_on_top=True, - font_size=cfg.label_font_size * 1.8, - width=cfg.bond_width * 1.1, - ) - ) - # Default no-bo for periodic structures (bond orders are not PBC-aware) if bo_explicit is None: cfg.bond_orders = False diff --git a/src/xyzrender/cli.py b/src/xyzrender/cli.py index 72e63c7..9a43548 100644 --- a/src/xyzrender/cli.py +++ b/src/xyzrender/cli.py @@ -295,7 +295,7 @@ def main() -> None: help="Explicit colormap range (default: auto from file values)", ) annot_g.add_argument( - "--vectors", + "--vector", default=None, metavar="FILE", help=( @@ -601,7 +601,7 @@ def main() -> None: nci_coloring=args.nci_coloring, overlay=args.overlay, overlay_color=args.overlay_color, - vectors=args.vectors, + vector=args.vector, vector_scale=args.vector_scale, output=args.output, ) @@ -656,7 +656,7 @@ def main() -> None: cell_color=args.cell_color, cell_width=args.cell_width, ghost_opacity=args.ghost_opacity, - vectors=args.vectors, + vector=args.vector, vector_scale=args.vector_scale, ) except ValueError as e: diff --git a/src/xyzrender/gif.py b/src/xyzrender/gif.py index 3ca57bc..2a899f8 100644 --- a/src/xyzrender/gif.py +++ b/src/xyzrender/gif.py @@ -353,7 +353,7 @@ def render_rotation_gif( for i, nid in enumerate(nodes): graph.nodes[nid]["position"] = tuple(_tilted[i].tolist()) - # Set up the cell origin (0,0,0) so _apply_rot_to_lattice rotates it around + # Set up the cell origin (0,0,0) so it rotates it around # the molecular centre of mass, not around the origin. # _orient_graph already does this for the PCA path; this covers --no-orient. if "lattice" in graph.graph and "lattice_origin" not in graph.graph: diff --git a/src/xyzrender/renderer.py b/src/xyzrender/renderer.py index bde109e..9323453 100644 --- a/src/xyzrender/renderer.py +++ b/src/xyzrender/renderer.py @@ -129,7 +129,8 @@ def render_svg(graph, config: RenderConfig | None = None, *, _log: bool = True) _ref_px_per_ang = (_REF_CANVAS - 2 * cfg.padding) / _REF_SPAN _vec_tips = [] for vi, va in enumerate(cfg.vectors): - scaled_vec = _vec_dirs[vi] * va.scale * cfg.vector_scale + _vec_scale = 1.0 if va.is_axis else cfg.vector_scale + scaled_vec = _vec_dirs[vi] * va.scale * _vec_scale tail3d = _vec_origins[vi] - scaled_vec / 2 if va.anchor == "center" else _vec_origins[vi] tip3d = tail3d + scaled_vec _vec_tips.append(tip3d) @@ -463,7 +464,8 @@ def render_svg(graph, config: RenderConfig | None = None, *, _log: bool = True) if cfg.vectors: for vi in range(len(cfg.vectors)): va = cfg.vectors[vi] - scaled_vec = _vec_dirs[vi] * va.scale * cfg.vector_scale + _global = 1.0 if va.is_axis else cfg.vector_scale + scaled_vec = _vec_dirs[vi] * va.scale * _global if va.anchor == "center": tail3d = _vec_origins[vi] - scaled_vec / 2 else: diff --git a/src/xyzrender/types.py b/src/xyzrender/types.py index 2b48d49..7215c8b 100644 --- a/src/xyzrender/types.py +++ b/src/xyzrender/types.py @@ -226,6 +226,7 @@ class VectorArrow: anchor: str = "tail" # "tail" (origin = arrow tail) or "center" (origin = arrow midpoint) host_atom: int | None = None # 0-based atom index, or None for com/explicit origins draw_on_top: bool = False + is_axis: bool = False # True for crystallographic axis arrows (not affected by vector_scale) font_size: float | None = None width: float | None = None @@ -437,7 +438,7 @@ class RenderConfig: "royalblue", ) # firebrick, forestgreen, royalblue axis_width_scale: float = 3.0 # multiplier on cell_line_width for axis stroke width - # Arbitrary vector arrows (--vectors) + # Arbitrary vector arrows (--vector) vectors: list[VectorArrow] = field(default_factory=list) vector_scale: float = 1.0 # global length multiplier applied to all vectors vector_color: str = "firebrick" # default arrow color (firebrick) when not specified per-arrow diff --git a/src/xyzrender/utils.py b/src/xyzrender/utils.py index afe241f..ce3b21a 100644 --- a/src/xyzrender/utils.py +++ b/src/xyzrender/utils.py @@ -174,9 +174,9 @@ def resolve_orientation( # Co-rotate crystal lattice and cell origin by the same matrix if cfg.cell_data is not None: - lat = cfg.cell_data.lattice - cfg.cell_data.lattice = (rot @ lat.T).T - cfg.cell_data.cell_origin = rot @ (cfg.cell_data.cell_origin - centroid_before) + curr_centroid + cfg.cell_data.lattice, cfg.cell_data.cell_origin = _apply_rot_to_vecs( + rot, cfg.cell_data.lattice, cfg.cell_data.cell_origin, centroid_before + ) cfg.auto_orient = False # already applied; renderer must not re-apply @@ -200,6 +200,20 @@ def resolve_orientation( return rot, atom_centroid, curr_centroid +def _apply_rot_to_vecs( + rot: np.ndarray, + directions: np.ndarray, + origins: np.ndarray, + centroid: np.ndarray, +) -> tuple[np.ndarray, np.ndarray]: + """Rotate direction vectors and translate origins around *centroid* by *rot*. + + Works for shape ``(3,)`` (single vector) or ``(N, 3)`` (row-vectors). + Returns ``(rotated_directions, rotated_origins)``. + """ + return (rot @ directions.T).T, (rot @ (origins - centroid).T).T + centroid + + def apply_axis_angle_rotation(graph: nx.Graph, axis: np.ndarray, angle: float) -> None: """Rotate all atom positions in-place around an arbitrary axis (degrees). @@ -215,8 +229,6 @@ def apply_axis_angle_rotation(graph: nx.Graph, axis: np.ndarray, angle: float) - angle: Rotation angle in degrees. """ - from xyzrender.viewer import _apply_rot_to_lattice - nodes = list(graph.nodes()) theta = np.radians(angle) k = axis / np.linalg.norm(axis) @@ -229,7 +241,11 @@ def apply_axis_angle_rotation(graph: nx.Graph, axis: np.ndarray, angle: float) - rotated = (rot @ (positions - centroid).T).T + centroid for i, nid in enumerate(nodes): graph.nodes[nid]["position"] = tuple(rotated[i].tolist()) - _apply_rot_to_lattice(graph, rot, centroid) + if "lattice" in graph.graph: + origin = np.asarray(graph.graph.get("lattice_origin", np.zeros(3)), dtype=float) + graph.graph["lattice"], graph.graph["lattice_origin"] = _apply_rot_to_vecs( + rot, graph.graph["lattice"], origin, centroid + ) def kabsch_rotation(original: np.ndarray, target: np.ndarray) -> np.ndarray: diff --git a/src/xyzrender/viewer.py b/src/xyzrender/viewer.py index 4f5592a..be6ba8c 100644 --- a/src/xyzrender/viewer.py +++ b/src/xyzrender/viewer.py @@ -15,6 +15,7 @@ import networkx as nx from vmol import Vmol + from xyzrender.config import RenderConfig from xyzrender.types import CellData _Atoms: TypeAlias = list[tuple[str, tuple[float, float, float]]] @@ -90,40 +91,7 @@ def rotate_with_viewer( return rot, c1, c2 -def apply_rotation(graph: nx.Graph, rx: float, ry: float, rz: float) -> None: - """Rotate all atom positions in-place by Euler angles (degrees). - - Rotation is around the molecular centroid so the molecule stays centered. - - Parameters - ---------- - graph: - Molecular graph whose node positions are updated in-place. - rx, ry, rz: - Rotation angles around x, y, z axes in degrees. - """ - nodes = list(graph.nodes()) - rx, ry, rz = np.radians(rx), np.radians(ry), np.radians(rz) - cx, sx = np.cos(rx), np.sin(rx) - cy, sy = np.cos(ry), np.sin(ry) - cz, sz = np.cos(rz), np.sin(rz) - # Rz @ Ry @ Rx - rot = np.array( - [ - [cy * cz, sx * sy * cz - cx * sz, cx * sy * cz + sx * sz], - [cy * sz, sx * sy * sz + cx * cz, cx * sy * sz - sx * cz], - [-sy, sx * cy, cx * cy], - ] - ) - positions = np.array([graph.nodes[n]["position"] for n in nodes]) - centroid = positions.mean(axis=0) - rotated = (rot @ (positions - centroid).T).T + centroid - for i, nid in enumerate(nodes): - graph.nodes[nid]["position"] = tuple(rotated[i].tolist()) - _apply_rot_to_lattice(graph, rot, centroid) - - -def orient_hkl_to_view(graph: nx.Graph, cell_data: "CellData", axis_str: str) -> None: +def orient_hkl_to_view(graph: nx.Graph, cell_data: "CellData", axis_str: str, cfg: "RenderConfig") -> None: """Rotate *graph* and *cell_data* so that the [hkl] direction points along +z. Parameters @@ -134,6 +102,8 @@ def orient_hkl_to_view(graph: nx.Graph, cell_data: "CellData", axis_str: str) -> Crystal cell data whose lattice and origin are updated in-place. axis_str: 3-digit Miller index string, optionally prefixed with ``-`` (e.g. ``'111'``, ``'-110'``). + cfg: + Render configuration object. Raises ------ @@ -169,32 +139,14 @@ def orient_hkl_to_view(graph: nx.Graph, cell_data: "CellData", axis_str: str) -> pos_rot = (rot_view @ (pos - centroid).T).T + centroid for idx, nid in enumerate(node_ids): graph.nodes[nid]["position"] = tuple(pos_rot[idx].tolist()) - cell_data.lattice = (rot_view @ cell_data.lattice.T).T - cell_data.cell_origin = rot_view @ (cell_data.cell_origin - centroid) + centroid - - -def _apply_rot_to_lattice(graph: nx.Graph, rot: np.ndarray, centroid: np.ndarray) -> None: - """Rotate the lattice vectors and cell origin stored on *graph* by *rot*. + from xyzrender.utils import _apply_rot_to_vecs - Both the lattice vectors and the cell origin are always updated so that - the cell box stays aligned with the atoms after any rotation. The origin - defaults to (0, 0, 0) when not explicitly present in the graph. - - Parameters - ---------- - graph: - Molecular graph (lattice stored in ``graph.graph``). - rot: - 3x3 rotation matrix. - centroid: - Centroid position to rotate around. - """ - if "lattice" not in graph.graph: - return - lat = np.array(graph.graph["lattice"], dtype=float) - graph.graph["lattice"] = (rot @ lat.T).T - origin = np.array(graph.graph.get("lattice_origin", np.zeros(3)), dtype=float) - graph.graph["lattice_origin"] = rot @ (origin - centroid) + centroid + cell_data.lattice, cell_data.cell_origin = _apply_rot_to_vecs( + rot_view, cell_data.lattice, cell_data.cell_origin, centroid + ) + if hasattr(cfg, "vectors"): + for vec in cfg.vectors: + vec.vector, vec.origin = _apply_rot_to_vecs(rot_view, vec.vector, vec.origin, centroid) def _run_viewer(viewer: Vmol, mol: dict, extra_args: list[str] | None = None) -> str: diff --git a/tests/test_crystal.py b/tests/test_crystal.py index fa19d0b..ed21421 100644 --- a/tests/test_crystal.py +++ b/tests/test_crystal.py @@ -201,133 +201,97 @@ def test_extxyz_cell_box_renders(extxyz_graph): assert len(cell_lines) == 12 -def test_cell_corotates_with_atoms(extxyz_graph): +def test_orient_hkl_cell_corotates_with_atoms(vasp_crystal): + """orient_hkl_to_view keeps lattice and atom positions mutually consistent.""" import copy - from xyzrender.types import CellData, RenderConfig - from xyzrender.viewer import apply_rotation + from xyzrender.renderer import render_svg + from xyzrender.types import RenderConfig + from xyzrender.viewer import orient_hkl_to_view - graph = copy.deepcopy(extxyz_graph) - lat_before = np.array(graph.graph["lattice"], dtype=float) - cell_data = CellData( - lattice=lat_before.copy(), cell_origin=np.array(graph.graph.get("lattice_origin", [0.0, 0.0, 0.0]), dtype=float) - ) + graph, cell_data = copy.deepcopy(vasp_crystal) + lat_before = cell_data.lattice.copy() cfg = RenderConfig(cell_data=cell_data, show_cell=True) - assert cfg.cell_data is not None - apply_rotation(graph, rx=30.0, ry=45.0, rz=15.0) - lat_graph = np.array(graph.graph["lattice"], dtype=float) - assert not np.allclose(lat_graph, lat_before, atol=1e-6) - assert np.allclose(cfg.cell_data.lattice, lat_before, atol=1e-6) - cfg.cell_data.lattice = lat_graph.copy() - cfg.cell_data.cell_origin = np.array(graph.graph.get("lattice_origin", [0.0, 0.0, 0.0]), dtype=float) - np.testing.assert_allclose(cfg.cell_data.lattice, lat_graph, atol=1e-9) - node_ids = list(graph.nodes()) - positions = np.array([graph.nodes[i]["position"] for i in node_ids], dtype=float) - centroid = positions.mean(axis=0) - cell_norm = np.linalg.norm(cfg.cell_data.lattice, axis=1).max() - origin_dist = np.linalg.norm(cfg.cell_data.cell_origin - centroid) - assert origin_dist < 2 * cell_norm - from xyzrender.renderer import render_svg + orient_hkl_to_view(graph, cell_data, "100", cfg) + # Lattice must have been rotated + assert not np.allclose(cell_data.lattice, lat_before, atol=1e-6) + # Render must still produce exactly 12 cell edges svg = render_svg(graph, cfg) - q = chr(34) - cell_lines = [ln for ln in svg.splitlines() if "class=" + q + "cell-edge" + q in ln] + cell_lines = [ln for ln in svg.splitlines() if 'class="cell-edge"' in ln] assert len(cell_lines) == 12 -def test_apply_rotation_sets_lattice_origin_when_absent(extxyz_graph): - """apply_rotation must write graph.graph['lattice_origin'] even when the - file had no explicit origin — the origin is (0,0,0) implicitly but must - be updated to the rotated value so the cell box stays aligned.""" +def test_orient_hkl_cell_origin_updated_from_zero(vasp_crystal): + """orient_hkl_to_view updates cell_origin from its zero default value.""" import copy - from xyzrender.viewer import apply_rotation + from xyzrender.types import RenderConfig + from xyzrender.viewer import orient_hkl_to_view - graph = copy.deepcopy(extxyz_graph) - # Confirm the fixture has no explicit origin (caffeine_cell.xyz has none) - assert "lattice_origin" not in graph.graph + graph, cell_data = copy.deepcopy(vasp_crystal) + assert np.allclose(cell_data.cell_origin, np.zeros(3)) - apply_rotation(graph, rx=45.0, ry=0.0, rz=0.0) + cfg = RenderConfig() + orient_hkl_to_view(graph, cell_data, "100", cfg) - # After rotation the key must exist and must NOT be zeros - assert "lattice_origin" in graph.graph, "apply_rotation must write lattice_origin" - origin = np.array(graph.graph["lattice_origin"], dtype=float) - assert not np.allclose(origin, np.zeros(3), atol=1e-6), ( - f"lattice_origin should be non-zero after a 45° rotation (got {origin})" + # cell_origin is rotated around the atom centroid; since the centroid is + # not at the origin the rotated zero-origin is non-zero + assert not np.allclose(cell_data.cell_origin, np.zeros(3), atol=1e-6), ( + f"cell_origin should be non-zero after HKL rotation (got {cell_data.cell_origin})" ) -def test_fractional_coords_preserved_after_rotation(extxyz_graph): - """Fractional coordinates of all atoms must be unchanged after a consistent - rotation of both atoms and cell (lattice + origin).""" +def test_orient_hkl_fractional_coords_preserved(vasp_crystal): + """orient_hkl_to_view preserves the fractional coordinates of all atoms.""" import copy - from xyzrender.viewer import apply_rotation + from xyzrender.types import RenderConfig + from xyzrender.viewer import orient_hkl_to_view - graph = copy.deepcopy(extxyz_graph) - lat0 = np.array(graph.graph["lattice"], dtype=float) - orig0 = np.array(graph.graph.get("lattice_origin", np.zeros(3)), dtype=float) + graph, cell_data = copy.deepcopy(vasp_crystal) node_ids = list(graph.nodes()) pos0 = np.array([graph.nodes[i]["position"] for i in node_ids], dtype=float) - # Fractional coords before: solve lat.T @ f = (pos - origin) for each atom - frac_before = np.linalg.solve(lat0.T, (pos0 - orig0).T).T # shape (n, 3) + frac_before = np.linalg.solve(cell_data.lattice.T, (pos0 - cell_data.cell_origin).T).T - apply_rotation(graph, rx=30.0, ry=20.0, rz=50.0) + cfg = RenderConfig() + orient_hkl_to_view(graph, cell_data, "111", cfg) - lat1 = np.array(graph.graph["lattice"], dtype=float) - orig1 = np.array(graph.graph["lattice_origin"], dtype=float) pos1 = np.array([graph.nodes[i]["position"] for i in node_ids], dtype=float) - frac_after = np.linalg.solve(lat1.T, (pos1 - orig1).T).T + frac_after = np.linalg.solve(cell_data.lattice.T, (pos1 - cell_data.cell_origin).T).T np.testing.assert_allclose( frac_after, frac_before, atol=1e-9, - err_msg="Fractional coordinates must be preserved after rotation", + err_msg="Fractional coordinates must be preserved after orient_hkl_to_view", ) -def test_cell_corotates_with_ghost_atoms(extxyz_graph): - """Full --cell -I pipeline: add ghost atoms, rotate, re-sync cell_data, - render — cell box must still have exactly 12 edges and be consistent.""" +def test_orient_hkl_cell_corotates_with_ghost_atoms(vasp_crystal): + """orient_hkl_to_view rotates all ghost atoms and keeps 12 cell edges.""" import copy from xyzrender.crystal import add_crystal_images from xyzrender.renderer import render_svg - from xyzrender.types import CellData, RenderConfig - from xyzrender.viewer import apply_rotation + from xyzrender.types import RenderConfig + from xyzrender.viewer import orient_hkl_to_view - graph = copy.deepcopy(extxyz_graph) - lat0 = np.array(graph.graph["lattice"], dtype=float) - cell_data = CellData(lattice=lat0.copy()) + graph, cell_data = copy.deepcopy(vasp_crystal) + n_real = graph.number_of_nodes() add_crystal_images(graph, cell_data) - n_after_ghosts = graph.number_of_nodes() - assert n_after_ghosts > extxyz_graph.number_of_nodes(), "Ghost atoms must have been added" - - # Simulate what rotate_with_viewer / --cell -I does: rotate (all nodes - # including ghosts), then re-sync cell_data from updated graph.graph. - apply_rotation(graph, rx=0.0, ry=90.0, rz=0.0) + assert graph.number_of_nodes() > n_real, "Ghost atoms must have been added" - cell_data.lattice = np.array(graph.graph["lattice"], dtype=float) - cell_data.cell_origin = np.array(graph.graph["lattice_origin"], dtype=float) + cfg = RenderConfig(cell_data=cell_data, show_cell=True, auto_orient=False) + orient_hkl_to_view(graph, cell_data, "110", cfg) - cfg = RenderConfig( - cell_data=cell_data, - show_cell=True, - auto_orient=False, - ) svg = render_svg(graph, cfg) cell_lines = [ln for ln in svg.splitlines() if 'class="cell-edge"' in ln] assert len(cell_lines) == 12 - # COM of real atoms must be within a reasonable distance of the cell box - real_ids = [i for i in graph.nodes() if not graph.nodes[i].get("image", False)] - real_pos = np.array([graph.nodes[i]["position"] for i in real_ids], dtype=float) + # COM of real atoms must be inside the cell (fractional coords in [0, 1]) + real_pos = np.array([graph.nodes[i]["position"] for i in range(n_real)], dtype=float) com = real_pos.mean(axis=0) - lat = cell_data.lattice - orig = cell_data.cell_origin - frac = np.linalg.solve(lat.T, com - orig) - # All fractional coordinates of the COM should be roughly in [0, 1] - # (within one cell width, allowing for atoms at the boundary) + frac = np.linalg.solve(cell_data.lattice.T, com - cell_data.cell_origin) assert np.all(frac > -0.5), f"COM fractional coords {frac} are far outside the cell after rotation" assert np.all(frac < 1.5), f"COM fractional coords {frac} are far outside the cell after rotation" diff --git a/tests/test_vectors.py b/tests/test_vectors.py index 8ac027d..88657e3 100644 --- a/tests/test_vectors.py +++ b/tests/test_vectors.py @@ -171,7 +171,6 @@ def test_render_with_vector_arrows(): svg = render_svg(graph, cfg) assert "#cc0000" in svg assert "μ" in svg - assert "]*fill="#ab1234"', svg), "arrowhead polygon with user color must appear in SVG" + + +def test_render_user_vector_appears_with_crystal_axes(): + """render() with cell_data + axes=True must include BOTH axis colors AND the + user-supplied vector color in the SVG output. This is the main regression + guard for the bug where cfg.vectors was overwritten instead of extended.""" + from xyzrender.api import load as api_load + from xyzrender.api import render + + mol = api_load(EXAMPLES / "caffeine_cell.xyz", cell=True) + # Use a color that differs from all three axis colors (firebrick/forestgreen/royalblue) + user_color = "#cd5c5c" # indianred — resolves to a distinct hex from axis colors + jf = _write_json([{"origin": "com", "vector": [2.0, 0.0, 0.0], "color": user_color, "label": "dipole"}]) + svg = str(render(mol, vector=str(jf), axes=True)) + # Axis colors must be present (firebrick → #b22222, forestgreen → #228b22, royalblue → #4169e1) + from xyzrender.types import resolve_color + + assert resolve_color("firebrick") in svg, "a-axis (firebrick) must appear" + assert resolve_color("forestgreen") in svg, "b-axis (forestgreen) must appear" + assert resolve_color("royalblue") in svg, "c-axis (royalblue) must appear" + # User vector must also be present + assert user_color in svg, "user vector color must appear in SVG alongside axis arrows" + assert "dipole" in svg, "user vector label must appear in SVG" + + +def test_render_user_vector_appears_with_crystal_axes_no_double_loading(): + """Calling render() twice with the same pre-built config must NOT accumulate + vectors across calls (shallow-copy list aliasing regression).""" + from xyzrender.api import load as api_load + from xyzrender.api import render + from xyzrender.config import build_config + + mol = api_load(EXAMPLES / "caffeine_cell.xyz", cell=True) + user_color = "#cd5c5c" + jf = _write_json([{"origin": "com", "vector": [2.0, 0.0, 0.0], "color": user_color, "label": "dipole"}]) + + cfg = build_config("default") + svg1 = str(render(mol, config=cfg, vector=str(jf), axes=True)) + svg2 = str(render(mol, config=cfg, vector=str(jf), axes=True)) + + # Count polygon arrowheads with user color in each SVG — should be exactly 1 each + count1 = len(re.findall(rf']*fill="{re.escape(user_color)}"', svg1)) + count2 = len(re.findall(rf']*fill="{re.escape(user_color)}"', svg2)) + assert count1 == 1, f"first render: expected 1 user-vector arrowhead, got {count1}" + assert count2 == 1, f"second render: expected 1 user-vector arrowhead, got {count2}" + + +# --------------------------------------------------------------------------- +# orient_hkl_to_view — user vectors co-rotate with crystal +# --------------------------------------------------------------------------- + + +def _cubic_graph_and_cell(a: float = 4.0): + """Return a minimal 2-atom graph and CellData for a simple cubic cell.""" + import networkx as nx + + from xyzrender.types import CellData + + g = nx.Graph() + g.add_node(0, symbol="C", position=(0.0, 0.0, 0.0)) + g.add_node(1, symbol="C", position=(a, a, a)) + lattice = np.diag([a, a, a]).astype(float) + return g, CellData(lattice=lattice, cell_origin=np.zeros(3)) + + +def test_orient_hkl_vectors_001_no_rotation(): + """Axis [001] on a cubic cell is already +z — vector direction and origin are unchanged.""" + from xyzrender.types import RenderConfig + from xyzrender.viewer import orient_hkl_to_view + + g, cell_data = _cubic_graph_and_cell() + va = VectorArrow(vector=np.array([1.0, 2.0, 3.0]), origin=np.array([1.0, 1.0, 1.0])) + cfg = RenderConfig(vectors=[va]) + + orient_hkl_to_view(g, cell_data, "001", cfg) + + np.testing.assert_allclose(np.array(cfg.vectors[0].vector), [1.0, 2.0, 3.0], atol=1e-9) + np.testing.assert_allclose(np.array(cfg.vectors[0].origin), [1.0, 1.0, 1.0], atol=1e-9) + + +def test_orient_hkl_vectors_100_direction(): + """Axis [100] on a cubic cell: a vector along lattice-a must point along +z after rotation.""" + from xyzrender.types import RenderConfig + from xyzrender.viewer import orient_hkl_to_view + + g, cell_data = _cubic_graph_and_cell() + # The a-axis is [1,0,0]; viewing along [100] must align it with +z. + va = VectorArrow(vector=np.array([1.0, 0.0, 0.0]), origin=np.array([2.0, 2.0, 2.0])) + cfg = RenderConfig(vectors=[va]) + + orient_hkl_to_view(g, cell_data, "100", cfg) + + np.testing.assert_allclose(np.array(cfg.vectors[0].vector), [0.0, 0.0, 1.0], atol=1e-9) + + +def test_orient_hkl_vectors_100_origin(): + """After [100] rotation the vector origin is rotated around the atom centroid.""" + from xyzrender.types import RenderConfig + from xyzrender.viewer import orient_hkl_to_view + + a = 4.0 + g, cell_data = _cubic_graph_and_cell(a) + pos = np.array([g.nodes[i]["position"] for i in g.nodes()]) + centroid = pos.mean(axis=0) # (2, 2, 2) + + # Place the origin displaced by a along the a-axis from the centroid. + origin = centroid + np.array([a, 0.0, 0.0]) + va = VectorArrow(vector=np.array([1.0, 0.0, 0.0]), origin=origin.copy()) + cfg = RenderConfig(vectors=[va]) + + orient_hkl_to_view(g, cell_data, "100", cfg) + + # The rotation maps displacement [a, 0, 0] → [0, 0, a]. + expected_origin = centroid + np.array([0.0, 0.0, a]) + np.testing.assert_allclose(np.array(cfg.vectors[0].origin), expected_origin, atol=1e-9) + + +def test_orient_hkl_vectors_110_direction(): + """Axis [110] on a cubic cell: a vector along the [110] diagonal must point along +z.""" + from xyzrender.types import RenderConfig + from xyzrender.viewer import orient_hkl_to_view + + g, cell_data = _cubic_graph_and_cell() + # The [110] lattice direction is [1, 1, 0] / sqrt(2). + va = VectorArrow( + vector=np.array([1.0, 1.0, 0.0]) / np.sqrt(2), + origin=np.array([2.0, 2.0, 2.0]), # at centroid → origin is unchanged + ) + cfg = RenderConfig(vectors=[va]) + + orient_hkl_to_view(g, cell_data, "110", cfg) + + np.testing.assert_allclose(np.array(cfg.vectors[0].vector), [0.0, 0.0, 1.0], atol=1e-9) + + +# --------------------------------------------------------------------------- +# Integration — user vectors via render() with axis= on a real crystal file +# --------------------------------------------------------------------------- + + +def test_render_crystal_axis_vector_projects_as_dot_when_parallel_to_view(): + """render() with axis='100' on an orthorhombic cell: a user vector aligned + with the a-axis must project as a dot () in the SVG because the + arrow is now pointing directly along the viewing axis (+z).""" + from xyzrender.api import Molecule, render + from xyzrender.types import CellData + + graph, _ = load_molecule(EXAMPLES / "caffeine_cell.xyz") + cell_data = CellData( + lattice=np.array(graph.graph["lattice"], dtype=float), + cell_origin=np.zeros(3), + ) + mol = Molecule(graph=graph, cell_data=cell_data) + + # Unit vector along the a-axis (first lattice row) + a_hat = cell_data.lattice[0] / np.linalg.norm(cell_data.lattice[0]) + positions = np.array([graph.nodes[i]["position"] for i in graph.nodes()]) + centroid = positions.mean(axis=0) + + # Distinct color so we can identify this arrow's elements in the SVG. + user_color = "#7f1234" + va = VectorArrow(vector=a_hat.copy(), origin=centroid.copy(), color=user_color) + + # View along [100]: the user vector is now parallel to the viewing axis → + # projected length ≈ 0 → _draw_arrow_svg renders it as a filled circle, not a polygon. + svg = str(render(mol, vector=[va], axes=False, ghosts=False, axis="100")) + + assert f'fill="{user_color}"' in svg, "user vector must appear in the rotated SVG" + # In the short-projection code path the arrow renders as a circle, not a polygon arrowhead. + assert re.search(rf']*fill="{re.escape(user_color)}"', svg), ( + "user vector aligned with view axis should render as a dot (circle), not a line+polygon" + ) + assert not re.search(rf']*fill="{re.escape(user_color)}"', svg), ( + "user vector aligned with view axis must NOT render as a polygon arrowhead" + )