diff --git a/README.md b/README.md
index 6b04fcb..7c098b9 100644
--- a/README.md
+++ b/README.md
@@ -844,7 +844,7 @@ Overlay arbitrary 3D vectors as arrows on the rendered image via a JSON file. Us
```bash
-xyzrender ethanol.xyz --vectors ethanol_dip.json -o ethanol_dip.svg
+xyzrender ethanol.xyz --vector ethanol_dip.json -o ethanol_dip.svg
```
Each entry in the JSON array defines one arrow:
@@ -1153,7 +1153,7 @@ Available rotation axes: `x`, `y`, `z`, `xy`, `xz`, `yz`, `yx`, `zx`, `zy`. Pref
| `--label-size PT` | Label font size (overrides preset) |
| `--cmap FILE` | Per-atom property colormap (Viridis, 1-indexed) |
| `--cmap-range VMIN VMAX` | Explicit colormap range (default: auto from file) |
-| `--vectors FILE` | JSON file of vector arrows to overlay (see Vector arrows section) |
+| `--vector FILE` | JSON file of vector arrows to overlay (see Vector arrows section) |
| `--vector-scale FACTOR` | Global length scale for all vector arrows (default: 1.0) |
| **Crystal** | |
| `--crystal [{vasp,qe}]` | Load as crystal via phonopy; format auto-detected or specify explicitly |
diff --git a/docs/source/cli_reference.md b/docs/source/cli_reference.md
index 12e2345..05e42a1 100644
--- a/docs/source/cli_reference.md
+++ b/docs/source/cli_reference.md
@@ -102,7 +102,7 @@ Full flag reference for `xyzrender`. See also `xyzrender --help`.
| Flag | Description |
|------|-------------|
-| `--vectors FILE` | Path to a JSON file defining 3D vector arrows for overlay |
+| `--vector FILE` | Path to a JSON file defining 3D vector arrows for overlay |
| `--vector-scale` | Global length multiplier for all vector arrows |
## GIF animations
diff --git a/docs/source/examples/annotations.md b/docs/source/examples/annotations.md
index f030497..c4edbb0 100644
--- a/docs/source/examples/annotations.md
+++ b/docs/source/examples/annotations.md
@@ -92,7 +92,7 @@ Overlay arbitrary 3D vectors as arrows on the rendered image via a JSON file. Us
```bash
-xyzrender ethanol.xyz --vectors ethanol_dip.json -o ethanol_dip.svg
+xyzrender ethanol.xyz --vector ethanol_dip.json -o ethanol_dip.svg
```
Each entry in the JSON array defines one arrow:
diff --git a/examples/examples.ipynb b/examples/examples.ipynb
index e1f73c7..0864cc9 100644
--- a/examples/examples.ipynb
+++ b/examples/examples.ipynb
@@ -1360,122 +1360,19 @@
},
{
"cell_type": "code",
- "execution_count": 7,
+ "execution_count": null,
"id": "c2cb9dfe",
"metadata": {},
- "outputs": [
- {
- "data": {
- "image/svg+xml": [
- ""
- ],
- "text/plain": [
- ""
- ]
- },
- "execution_count": 7,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "from xyzrender import load, render\n",
- "\n",
- "etoh = load(\"structures/ethanol.xyz\")\n",
- "vec = {\n",
- " \"anchor\": \"center\",\n",
- " \"vectors\": [\n",
- " {\n",
- " \"origin\": \"com\",\n",
- " \"vector\": [1.0320170291976951, -0.042708195030485986, -1.332397645862797],\n",
- " \"color\": \"firebrick\",\n",
- " \"label\": \"μ\",\n",
- " }\n",
- " ],\n",
- "}\n",
- "render(etoh, vectors=vec, vector_scale=1.5)"
- ]
+ "outputs": [],
+ "source": "from xyzrender import load, render\n\netoh = load(\"structures/ethanol.xyz\")\nvec = {\n \"anchor\": \"center\",\n \"vectors\": [\n {\n \"origin\": \"com\",\n \"vector\": [1.0320170291976951, -0.042708195030485986, -1.332397645862797],\n \"color\": \"firebrick\",\n \"label\": \"μ\",\n }\n ],\n}\nrender(etoh, vector=vec, vector_scale=1.5)"
},
{
"cell_type": "code",
- "execution_count": 10,
+ "execution_count": null,
"id": "4fa885ba",
"metadata": {},
- "outputs": [
- {
- "data": {
- "image/svg+xml": [
- ""
- ],
- "text/plain": [
- ""
- ]
- },
- "execution_count": 10,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "render(etoh, hy=True, vectors=\"structures/ethanol_forces_efield.json\", vector_scale=2)"
- ]
+ "outputs": [],
+ "source": "render(etoh, hy=True, vector=\"structures/ethanol_forces_efield.json\", vector_scale=2)"
},
{
"cell_type": "markdown",
@@ -3681,4 +3578,4 @@
},
"nbformat": 4,
"nbformat_minor": 5
-}
+}
\ No newline at end of file
diff --git a/examples/generate.sh b/examples/generate.sh
index 6f398a5..ea4e014 100644
--- a/examples/generate.sh
+++ b/examples/generate.sh
@@ -75,8 +75,8 @@ xyzrender "$DIR/bimp.out" -o "$IMG/bimp_ts_nci.svg" --ts --gif-trj --vdw 84-169
xyzrender "$DIR/bimp.out" -o "$IMG/bimp_ts_nci.svg" --gif-ts --gif-rot --vdw 84-169 --nci -go "$IMG/bimp_nci_ts.gif"
echo "=== Vector arrows ==="
-xyzrender "$DIR/ethanol.xyz" --vectors "$DIR/ethanol_dip.json" -o "$IMG/ethanol_dip.svg" --gif-rot -go "$IMG/ethanol_dip.gif" # dipole at center of mass, with rotation
-xyzrender "$DIR/ethanol.xyz" --hy --vectors "$DIR/ethanol_forces_efield.json" --vector-scale 1.5 -o "$IMG/ethanol_forces_efield.svg" -go "$IMG/ethanol_forces_efield.gif" --gif-rot # per-atom forces, with rotation
+xyzrender "$DIR/ethanol.xyz" --vector "$DIR/ethanol_dip.json" -o "$IMG/ethanol_dip.svg" --gif-rot -go "$IMG/ethanol_dip.gif" # dipole at center of mass, with rotation
+xyzrender "$DIR/ethanol.xyz" --hy --vector "$DIR/ethanol_forces_efield.json" --vector-scale 1.5 -o "$IMG/ethanol_forces_efield.svg" -go "$IMG/ethanol_forces_efield.gif" --gif-rot # per-atom forces, with rotation
echo "=== Crystal / unit cell ==="
xyzrender "$DIR/caffeine_cell.xyz" --cell -o "$IMG/caffeine_cell.svg" --no-orient --gif-rot -go "$IMG/caffeine_cell.gif"
diff --git a/src/xyzrender/api.py b/src/xyzrender/api.py
index 643aee8..092d8ce 100644
--- a/src/xyzrender/api.py
+++ b/src/xyzrender/api.py
@@ -419,7 +419,7 @@ def render(
labels: list[str] | None = None,
label_file: str | None = None,
# --- Vector arrows ---
- vectors: str | Path | dict | list[VectorArrow] | None = None,
+ vector: str | Path | dict | list[VectorArrow] | None = None,
vector_scale: float | None = None,
vector_color: str | None = None,
# --- Surface opacity ---
@@ -548,6 +548,7 @@ def render(
if not isinstance(config, str):
# Pre-built RenderConfig — shallow copy so we don't mutate the caller's object
cfg = copy.copy(config)
+ cfg.vectors = list(cfg.vectors)
if _orient is not None:
cfg.auto_orient = _orient
elif mol.oriented:
@@ -665,13 +666,23 @@ def render(
oriented=mol.oriented,
)
+ # --- Vectors (user-supplied + crystal axes) ---
+ _combine_vector_sources(
+ cfg,
+ rmol.graph,
+ vector=vector,
+ vector_scale=vector_scale,
+ vector_color=vector_color,
+ cell_data=rmol.cell_data,
+ axes=axes,
+ )
+
# --- Cell / crystal config ---
if rmol.cell_data is not None:
_apply_cell_config(
rmol,
cfg,
no_cell=no_cell,
- axes=axes,
axis=axis,
ghosts=ghosts,
cell_color=cell_color,
@@ -691,22 +702,6 @@ def render(
inline = [s.split() for s in labels] if labels else None
cfg.annotations = parse_annotations(inline_specs=inline, file_path=label_file, graph=rmol.graph)
- # --- Vector arrows ---
- if vector_scale is not None:
- cfg.vector_scale = vector_scale
- if vector_color is not None:
- from xyzrender.types import resolve_color
-
- cfg.vector_color = resolve_color(vector_color)
- if vectors is not None:
- if isinstance(vectors, list):
- cfg.vectors = vectors
- else:
- from xyzrender.annotations import load_vectors
-
- _vec_src = vectors if isinstance(vectors, dict) else Path(vectors)
- cfg.vectors = load_vectors(_vec_src, rmol.graph, default_color=cfg.vector_color)
-
# --- Early overlay validation (before ghost atoms are added to g1) ---
if overlay is not None and mol.cell_data is not None:
msg = "overlay= is mutually exclusive with crystal/cell display"
@@ -898,7 +893,7 @@ def render_gif(
# --- NCI detection (gif_ts / gif_trj / gif_rot) ---
detect_nci: bool = False,
# --- Vector arrows (gif_rot only) ---
- vectors: str | Path | dict | list[VectorArrow] | None = None,
+ vector: str | Path | dict | list[VectorArrow] | None = None,
vector_scale: float | None = None,
vector_color: str | None = None,
# --- Surfaces (gif_rot only) ---
@@ -1013,6 +1008,7 @@ def render_gif(
# Resolve config
if not isinstance(config, str):
cfg = copy.copy(config)
+ cfg.vectors = list(cfg.vectors)
# Apply hull overrides to pre-built config
if hull is not None:
if hull == "rings":
@@ -1193,21 +1189,17 @@ def render_gif(
aligned2 = align(ref_graph, g2)
ref_graph = merge_graphs(ref_graph, g2, aligned2, overlay_color=cfg.overlay_color)
- # --- Vector arrows (gif_rot only; needs ref_graph for COM) ---
- if vector_scale is not None:
- cfg.vector_scale = vector_scale
- if vector_color is not None:
- from xyzrender.types import resolve_color
-
- cfg.vector_color = resolve_color(vector_color)
- if vectors is not None:
- if isinstance(vectors, list):
- cfg.vectors = vectors
- else:
- from xyzrender.annotations import load_vectors
-
- _vec_src = vectors if isinstance(vectors, dict) else Path(vectors)
- cfg.vectors = load_vectors(_vec_src, ref_graph, default_color=cfg.vector_color)
+ # --- Vectors (user-supplied + crystal axes; gif_rot only) ---
+ _cell_data_for_vecs = molecule.cell_data if isinstance(molecule, Molecule) else None
+ _combine_vector_sources(
+ cfg,
+ ref_graph,
+ vector=vector,
+ vector_scale=vector_scale,
+ vector_color=vector_color,
+ cell_data=_cell_data_for_vecs,
+ axes=axes,
+ )
cube_data = molecule.cube_data if isinstance(molecule, Molecule) else None
@@ -1223,7 +1215,6 @@ def render_gif(
_gif_mol,
cfg,
no_cell=no_cell,
- axes=axes,
axis=axis,
ghosts=ghosts,
cell_color=cell_color,
@@ -1322,12 +1313,65 @@ def _resolve_cmap(
return load_cmap(str(cmap), graph)
+def _combine_vector_sources(
+ cfg: "RenderConfig",
+ graph: "nx.Graph",
+ *,
+ vector=None,
+ vector_scale: "float | None" = None,
+ vector_color: "str | None" = None,
+ cell_data: "CellData | None" = None,
+ axes: bool = True,
+) -> None:
+ """Populate ``cfg.vectors`` from user-supplied vectors and crystal axis arrows.
+
+ Must be called *before* :func:`_apply_cell_config` so that all vectors are
+ already in ``cfg.vectors`` when :func:`orient_hkl_to_view` applies the HKL
+ rotation to the whole list in one pass.
+ """
+ if vector_scale is not None:
+ cfg.vector_scale = vector_scale
+ if vector_color is not None:
+ from xyzrender.types import resolve_color
+
+ cfg.vector_color = resolve_color(vector_color)
+ if vector is not None:
+ if not isinstance(vector, list):
+ from xyzrender.annotations import load_vectors
+
+ _vec_src = vector if isinstance(vector, dict) else Path(vector)
+ vector = load_vectors(_vec_src, graph, default_color=cfg.vector_color)
+ cfg.vectors.extend(vector)
+ if cell_data is not None and axes:
+ from xyzrender.types import VectorArrow
+
+ lat = cell_data.lattice
+ orig3d = cell_data.cell_origin
+ for vec, color, label in zip(lat, cfg.axis_colors, ("a", "b", "c"), strict=True):
+ length = float(np.linalg.norm(vec))
+ if length < 1e-6:
+ continue
+ frac = min(0.25, 2.0 / length)
+ cfg.vectors.append(
+ VectorArrow(
+ vector=vec * frac,
+ origin=orig3d,
+ color=color,
+ label=label,
+ scale=1.0,
+ draw_on_top=True,
+ is_axis=True,
+ font_size=cfg.label_font_size * 1.8,
+ width=cfg.bond_width * 1.1,
+ )
+ )
+
+
def _apply_cell_config(
mol: Molecule,
cfg: RenderConfig,
*,
no_cell: bool,
- axes: bool,
axis: str | None,
ghosts: bool | None,
cell_color: str | None,
@@ -1357,7 +1401,7 @@ def _apply_cell_config(
if axis is not None:
from xyzrender.viewer import orient_hkl_to_view
- orient_hkl_to_view(mol.graph, cell_data, axis)
+ orient_hkl_to_view(mol.graph, cell_data, axis, cfg)
cfg.auto_orient = False
# Ghost (periodic image) atoms — default: on when cell_data is present
@@ -1367,31 +1411,6 @@ def _apply_cell_config(
add_crystal_images(mol.graph, cell_data)
- # Crystal axes a/b/c as annotation vectors at the cell origin
- if axes:
- from xyzrender.types import VectorArrow
-
- lat = cell_data.lattice
- orig3d = cell_data.cell_origin
- for vec, color, label in zip(lat, cfg.axis_colors, ("a", "b", "c"), strict=True):
- length = float(np.linalg.norm(vec))
- if length < 1e-6:
- continue
- # Arrow spans 25% of the cell edge (max 2 Å) from the origin corner
- frac = min(0.25, 2.0 / length)
- cfg.vectors.append(
- VectorArrow(
- vector=vec * frac,
- origin=orig3d,
- color=color,
- label=label,
- scale=1.0,
- draw_on_top=True,
- font_size=cfg.label_font_size * 1.8,
- width=cfg.bond_width * 1.1,
- )
- )
-
# Default no-bo for periodic structures (bond orders are not PBC-aware)
if bo_explicit is None:
cfg.bond_orders = False
diff --git a/src/xyzrender/cli.py b/src/xyzrender/cli.py
index 72e63c7..9a43548 100644
--- a/src/xyzrender/cli.py
+++ b/src/xyzrender/cli.py
@@ -295,7 +295,7 @@ def main() -> None:
help="Explicit colormap range (default: auto from file values)",
)
annot_g.add_argument(
- "--vectors",
+ "--vector",
default=None,
metavar="FILE",
help=(
@@ -601,7 +601,7 @@ def main() -> None:
nci_coloring=args.nci_coloring,
overlay=args.overlay,
overlay_color=args.overlay_color,
- vectors=args.vectors,
+ vector=args.vector,
vector_scale=args.vector_scale,
output=args.output,
)
@@ -656,7 +656,7 @@ def main() -> None:
cell_color=args.cell_color,
cell_width=args.cell_width,
ghost_opacity=args.ghost_opacity,
- vectors=args.vectors,
+ vector=args.vector,
vector_scale=args.vector_scale,
)
except ValueError as e:
diff --git a/src/xyzrender/gif.py b/src/xyzrender/gif.py
index 3ca57bc..2a899f8 100644
--- a/src/xyzrender/gif.py
+++ b/src/xyzrender/gif.py
@@ -353,7 +353,7 @@ def render_rotation_gif(
for i, nid in enumerate(nodes):
graph.nodes[nid]["position"] = tuple(_tilted[i].tolist())
- # Set up the cell origin (0,0,0) so _apply_rot_to_lattice rotates it around
+ # Set up the cell origin (0,0,0) so it rotates it around
# the molecular centre of mass, not around the origin.
# _orient_graph already does this for the PCA path; this covers --no-orient.
if "lattice" in graph.graph and "lattice_origin" not in graph.graph:
diff --git a/src/xyzrender/renderer.py b/src/xyzrender/renderer.py
index bde109e..9323453 100644
--- a/src/xyzrender/renderer.py
+++ b/src/xyzrender/renderer.py
@@ -129,7 +129,8 @@ def render_svg(graph, config: RenderConfig | None = None, *, _log: bool = True)
_ref_px_per_ang = (_REF_CANVAS - 2 * cfg.padding) / _REF_SPAN
_vec_tips = []
for vi, va in enumerate(cfg.vectors):
- scaled_vec = _vec_dirs[vi] * va.scale * cfg.vector_scale
+ _vec_scale = 1.0 if va.is_axis else cfg.vector_scale
+ scaled_vec = _vec_dirs[vi] * va.scale * _vec_scale
tail3d = _vec_origins[vi] - scaled_vec / 2 if va.anchor == "center" else _vec_origins[vi]
tip3d = tail3d + scaled_vec
_vec_tips.append(tip3d)
@@ -463,7 +464,8 @@ def render_svg(graph, config: RenderConfig | None = None, *, _log: bool = True)
if cfg.vectors:
for vi in range(len(cfg.vectors)):
va = cfg.vectors[vi]
- scaled_vec = _vec_dirs[vi] * va.scale * cfg.vector_scale
+ _global = 1.0 if va.is_axis else cfg.vector_scale
+ scaled_vec = _vec_dirs[vi] * va.scale * _global
if va.anchor == "center":
tail3d = _vec_origins[vi] - scaled_vec / 2
else:
diff --git a/src/xyzrender/types.py b/src/xyzrender/types.py
index 2b48d49..7215c8b 100644
--- a/src/xyzrender/types.py
+++ b/src/xyzrender/types.py
@@ -226,6 +226,7 @@ class VectorArrow:
anchor: str = "tail" # "tail" (origin = arrow tail) or "center" (origin = arrow midpoint)
host_atom: int | None = None # 0-based atom index, or None for com/explicit origins
draw_on_top: bool = False
+ is_axis: bool = False # True for crystallographic axis arrows (not affected by vector_scale)
font_size: float | None = None
width: float | None = None
@@ -437,7 +438,7 @@ class RenderConfig:
"royalblue",
) # firebrick, forestgreen, royalblue
axis_width_scale: float = 3.0 # multiplier on cell_line_width for axis stroke width
- # Arbitrary vector arrows (--vectors)
+ # Arbitrary vector arrows (--vector)
vectors: list[VectorArrow] = field(default_factory=list)
vector_scale: float = 1.0 # global length multiplier applied to all vectors
vector_color: str = "firebrick" # default arrow color (firebrick) when not specified per-arrow
diff --git a/src/xyzrender/utils.py b/src/xyzrender/utils.py
index afe241f..ce3b21a 100644
--- a/src/xyzrender/utils.py
+++ b/src/xyzrender/utils.py
@@ -174,9 +174,9 @@ def resolve_orientation(
# Co-rotate crystal lattice and cell origin by the same matrix
if cfg.cell_data is not None:
- lat = cfg.cell_data.lattice
- cfg.cell_data.lattice = (rot @ lat.T).T
- cfg.cell_data.cell_origin = rot @ (cfg.cell_data.cell_origin - centroid_before) + curr_centroid
+ cfg.cell_data.lattice, cfg.cell_data.cell_origin = _apply_rot_to_vecs(
+ rot, cfg.cell_data.lattice, cfg.cell_data.cell_origin, centroid_before
+ )
cfg.auto_orient = False # already applied; renderer must not re-apply
@@ -200,6 +200,20 @@ def resolve_orientation(
return rot, atom_centroid, curr_centroid
+def _apply_rot_to_vecs(
+ rot: np.ndarray,
+ directions: np.ndarray,
+ origins: np.ndarray,
+ centroid: np.ndarray,
+) -> tuple[np.ndarray, np.ndarray]:
+ """Rotate direction vectors and translate origins around *centroid* by *rot*.
+
+ Works for shape ``(3,)`` (single vector) or ``(N, 3)`` (row-vectors).
+ Returns ``(rotated_directions, rotated_origins)``.
+ """
+ return (rot @ directions.T).T, (rot @ (origins - centroid).T).T + centroid
+
+
def apply_axis_angle_rotation(graph: nx.Graph, axis: np.ndarray, angle: float) -> None:
"""Rotate all atom positions in-place around an arbitrary axis (degrees).
@@ -215,8 +229,6 @@ def apply_axis_angle_rotation(graph: nx.Graph, axis: np.ndarray, angle: float) -
angle:
Rotation angle in degrees.
"""
- from xyzrender.viewer import _apply_rot_to_lattice
-
nodes = list(graph.nodes())
theta = np.radians(angle)
k = axis / np.linalg.norm(axis)
@@ -229,7 +241,11 @@ def apply_axis_angle_rotation(graph: nx.Graph, axis: np.ndarray, angle: float) -
rotated = (rot @ (positions - centroid).T).T + centroid
for i, nid in enumerate(nodes):
graph.nodes[nid]["position"] = tuple(rotated[i].tolist())
- _apply_rot_to_lattice(graph, rot, centroid)
+ if "lattice" in graph.graph:
+ origin = np.asarray(graph.graph.get("lattice_origin", np.zeros(3)), dtype=float)
+ graph.graph["lattice"], graph.graph["lattice_origin"] = _apply_rot_to_vecs(
+ rot, graph.graph["lattice"], origin, centroid
+ )
def kabsch_rotation(original: np.ndarray, target: np.ndarray) -> np.ndarray:
diff --git a/src/xyzrender/viewer.py b/src/xyzrender/viewer.py
index 4f5592a..be6ba8c 100644
--- a/src/xyzrender/viewer.py
+++ b/src/xyzrender/viewer.py
@@ -15,6 +15,7 @@
import networkx as nx
from vmol import Vmol
+ from xyzrender.config import RenderConfig
from xyzrender.types import CellData
_Atoms: TypeAlias = list[tuple[str, tuple[float, float, float]]]
@@ -90,40 +91,7 @@ def rotate_with_viewer(
return rot, c1, c2
-def apply_rotation(graph: nx.Graph, rx: float, ry: float, rz: float) -> None:
- """Rotate all atom positions in-place by Euler angles (degrees).
-
- Rotation is around the molecular centroid so the molecule stays centered.
-
- Parameters
- ----------
- graph:
- Molecular graph whose node positions are updated in-place.
- rx, ry, rz:
- Rotation angles around x, y, z axes in degrees.
- """
- nodes = list(graph.nodes())
- rx, ry, rz = np.radians(rx), np.radians(ry), np.radians(rz)
- cx, sx = np.cos(rx), np.sin(rx)
- cy, sy = np.cos(ry), np.sin(ry)
- cz, sz = np.cos(rz), np.sin(rz)
- # Rz @ Ry @ Rx
- rot = np.array(
- [
- [cy * cz, sx * sy * cz - cx * sz, cx * sy * cz + sx * sz],
- [cy * sz, sx * sy * sz + cx * cz, cx * sy * sz - sx * cz],
- [-sy, sx * cy, cx * cy],
- ]
- )
- positions = np.array([graph.nodes[n]["position"] for n in nodes])
- centroid = positions.mean(axis=0)
- rotated = (rot @ (positions - centroid).T).T + centroid
- for i, nid in enumerate(nodes):
- graph.nodes[nid]["position"] = tuple(rotated[i].tolist())
- _apply_rot_to_lattice(graph, rot, centroid)
-
-
-def orient_hkl_to_view(graph: nx.Graph, cell_data: "CellData", axis_str: str) -> None:
+def orient_hkl_to_view(graph: nx.Graph, cell_data: "CellData", axis_str: str, cfg: "RenderConfig") -> None:
"""Rotate *graph* and *cell_data* so that the [hkl] direction points along +z.
Parameters
@@ -134,6 +102,8 @@ def orient_hkl_to_view(graph: nx.Graph, cell_data: "CellData", axis_str: str) ->
Crystal cell data whose lattice and origin are updated in-place.
axis_str:
3-digit Miller index string, optionally prefixed with ``-`` (e.g. ``'111'``, ``'-110'``).
+ cfg:
+ Render configuration object.
Raises
------
@@ -169,32 +139,14 @@ def orient_hkl_to_view(graph: nx.Graph, cell_data: "CellData", axis_str: str) ->
pos_rot = (rot_view @ (pos - centroid).T).T + centroid
for idx, nid in enumerate(node_ids):
graph.nodes[nid]["position"] = tuple(pos_rot[idx].tolist())
- cell_data.lattice = (rot_view @ cell_data.lattice.T).T
- cell_data.cell_origin = rot_view @ (cell_data.cell_origin - centroid) + centroid
-
-
-def _apply_rot_to_lattice(graph: nx.Graph, rot: np.ndarray, centroid: np.ndarray) -> None:
- """Rotate the lattice vectors and cell origin stored on *graph* by *rot*.
+ from xyzrender.utils import _apply_rot_to_vecs
- Both the lattice vectors and the cell origin are always updated so that
- the cell box stays aligned with the atoms after any rotation. The origin
- defaults to (0, 0, 0) when not explicitly present in the graph.
-
- Parameters
- ----------
- graph:
- Molecular graph (lattice stored in ``graph.graph``).
- rot:
- 3x3 rotation matrix.
- centroid:
- Centroid position to rotate around.
- """
- if "lattice" not in graph.graph:
- return
- lat = np.array(graph.graph["lattice"], dtype=float)
- graph.graph["lattice"] = (rot @ lat.T).T
- origin = np.array(graph.graph.get("lattice_origin", np.zeros(3)), dtype=float)
- graph.graph["lattice_origin"] = rot @ (origin - centroid) + centroid
+ cell_data.lattice, cell_data.cell_origin = _apply_rot_to_vecs(
+ rot_view, cell_data.lattice, cell_data.cell_origin, centroid
+ )
+ if hasattr(cfg, "vectors"):
+ for vec in cfg.vectors:
+ vec.vector, vec.origin = _apply_rot_to_vecs(rot_view, vec.vector, vec.origin, centroid)
def _run_viewer(viewer: Vmol, mol: dict, extra_args: list[str] | None = None) -> str:
diff --git a/tests/test_crystal.py b/tests/test_crystal.py
index fa19d0b..ed21421 100644
--- a/tests/test_crystal.py
+++ b/tests/test_crystal.py
@@ -201,133 +201,97 @@ def test_extxyz_cell_box_renders(extxyz_graph):
assert len(cell_lines) == 12
-def test_cell_corotates_with_atoms(extxyz_graph):
+def test_orient_hkl_cell_corotates_with_atoms(vasp_crystal):
+ """orient_hkl_to_view keeps lattice and atom positions mutually consistent."""
import copy
- from xyzrender.types import CellData, RenderConfig
- from xyzrender.viewer import apply_rotation
+ from xyzrender.renderer import render_svg
+ from xyzrender.types import RenderConfig
+ from xyzrender.viewer import orient_hkl_to_view
- graph = copy.deepcopy(extxyz_graph)
- lat_before = np.array(graph.graph["lattice"], dtype=float)
- cell_data = CellData(
- lattice=lat_before.copy(), cell_origin=np.array(graph.graph.get("lattice_origin", [0.0, 0.0, 0.0]), dtype=float)
- )
+ graph, cell_data = copy.deepcopy(vasp_crystal)
+ lat_before = cell_data.lattice.copy()
cfg = RenderConfig(cell_data=cell_data, show_cell=True)
- assert cfg.cell_data is not None
- apply_rotation(graph, rx=30.0, ry=45.0, rz=15.0)
- lat_graph = np.array(graph.graph["lattice"], dtype=float)
- assert not np.allclose(lat_graph, lat_before, atol=1e-6)
- assert np.allclose(cfg.cell_data.lattice, lat_before, atol=1e-6)
- cfg.cell_data.lattice = lat_graph.copy()
- cfg.cell_data.cell_origin = np.array(graph.graph.get("lattice_origin", [0.0, 0.0, 0.0]), dtype=float)
- np.testing.assert_allclose(cfg.cell_data.lattice, lat_graph, atol=1e-9)
- node_ids = list(graph.nodes())
- positions = np.array([graph.nodes[i]["position"] for i in node_ids], dtype=float)
- centroid = positions.mean(axis=0)
- cell_norm = np.linalg.norm(cfg.cell_data.lattice, axis=1).max()
- origin_dist = np.linalg.norm(cfg.cell_data.cell_origin - centroid)
- assert origin_dist < 2 * cell_norm
- from xyzrender.renderer import render_svg
+ orient_hkl_to_view(graph, cell_data, "100", cfg)
+ # Lattice must have been rotated
+ assert not np.allclose(cell_data.lattice, lat_before, atol=1e-6)
+ # Render must still produce exactly 12 cell edges
svg = render_svg(graph, cfg)
- q = chr(34)
- cell_lines = [ln for ln in svg.splitlines() if "class=" + q + "cell-edge" + q in ln]
+ cell_lines = [ln for ln in svg.splitlines() if 'class="cell-edge"' in ln]
assert len(cell_lines) == 12
-def test_apply_rotation_sets_lattice_origin_when_absent(extxyz_graph):
- """apply_rotation must write graph.graph['lattice_origin'] even when the
- file had no explicit origin — the origin is (0,0,0) implicitly but must
- be updated to the rotated value so the cell box stays aligned."""
+def test_orient_hkl_cell_origin_updated_from_zero(vasp_crystal):
+ """orient_hkl_to_view updates cell_origin from its zero default value."""
import copy
- from xyzrender.viewer import apply_rotation
+ from xyzrender.types import RenderConfig
+ from xyzrender.viewer import orient_hkl_to_view
- graph = copy.deepcopy(extxyz_graph)
- # Confirm the fixture has no explicit origin (caffeine_cell.xyz has none)
- assert "lattice_origin" not in graph.graph
+ graph, cell_data = copy.deepcopy(vasp_crystal)
+ assert np.allclose(cell_data.cell_origin, np.zeros(3))
- apply_rotation(graph, rx=45.0, ry=0.0, rz=0.0)
+ cfg = RenderConfig()
+ orient_hkl_to_view(graph, cell_data, "100", cfg)
- # After rotation the key must exist and must NOT be zeros
- assert "lattice_origin" in graph.graph, "apply_rotation must write lattice_origin"
- origin = np.array(graph.graph["lattice_origin"], dtype=float)
- assert not np.allclose(origin, np.zeros(3), atol=1e-6), (
- f"lattice_origin should be non-zero after a 45° rotation (got {origin})"
+ # cell_origin is rotated around the atom centroid; since the centroid is
+ # not at the origin the rotated zero-origin is non-zero
+ assert not np.allclose(cell_data.cell_origin, np.zeros(3), atol=1e-6), (
+ f"cell_origin should be non-zero after HKL rotation (got {cell_data.cell_origin})"
)
-def test_fractional_coords_preserved_after_rotation(extxyz_graph):
- """Fractional coordinates of all atoms must be unchanged after a consistent
- rotation of both atoms and cell (lattice + origin)."""
+def test_orient_hkl_fractional_coords_preserved(vasp_crystal):
+ """orient_hkl_to_view preserves the fractional coordinates of all atoms."""
import copy
- from xyzrender.viewer import apply_rotation
+ from xyzrender.types import RenderConfig
+ from xyzrender.viewer import orient_hkl_to_view
- graph = copy.deepcopy(extxyz_graph)
- lat0 = np.array(graph.graph["lattice"], dtype=float)
- orig0 = np.array(graph.graph.get("lattice_origin", np.zeros(3)), dtype=float)
+ graph, cell_data = copy.deepcopy(vasp_crystal)
node_ids = list(graph.nodes())
pos0 = np.array([graph.nodes[i]["position"] for i in node_ids], dtype=float)
- # Fractional coords before: solve lat.T @ f = (pos - origin) for each atom
- frac_before = np.linalg.solve(lat0.T, (pos0 - orig0).T).T # shape (n, 3)
+ frac_before = np.linalg.solve(cell_data.lattice.T, (pos0 - cell_data.cell_origin).T).T
- apply_rotation(graph, rx=30.0, ry=20.0, rz=50.0)
+ cfg = RenderConfig()
+ orient_hkl_to_view(graph, cell_data, "111", cfg)
- lat1 = np.array(graph.graph["lattice"], dtype=float)
- orig1 = np.array(graph.graph["lattice_origin"], dtype=float)
pos1 = np.array([graph.nodes[i]["position"] for i in node_ids], dtype=float)
- frac_after = np.linalg.solve(lat1.T, (pos1 - orig1).T).T
+ frac_after = np.linalg.solve(cell_data.lattice.T, (pos1 - cell_data.cell_origin).T).T
np.testing.assert_allclose(
frac_after,
frac_before,
atol=1e-9,
- err_msg="Fractional coordinates must be preserved after rotation",
+ err_msg="Fractional coordinates must be preserved after orient_hkl_to_view",
)
-def test_cell_corotates_with_ghost_atoms(extxyz_graph):
- """Full --cell -I pipeline: add ghost atoms, rotate, re-sync cell_data,
- render — cell box must still have exactly 12 edges and be consistent."""
+def test_orient_hkl_cell_corotates_with_ghost_atoms(vasp_crystal):
+ """orient_hkl_to_view rotates all ghost atoms and keeps 12 cell edges."""
import copy
from xyzrender.crystal import add_crystal_images
from xyzrender.renderer import render_svg
- from xyzrender.types import CellData, RenderConfig
- from xyzrender.viewer import apply_rotation
+ from xyzrender.types import RenderConfig
+ from xyzrender.viewer import orient_hkl_to_view
- graph = copy.deepcopy(extxyz_graph)
- lat0 = np.array(graph.graph["lattice"], dtype=float)
- cell_data = CellData(lattice=lat0.copy())
+ graph, cell_data = copy.deepcopy(vasp_crystal)
+ n_real = graph.number_of_nodes()
add_crystal_images(graph, cell_data)
- n_after_ghosts = graph.number_of_nodes()
- assert n_after_ghosts > extxyz_graph.number_of_nodes(), "Ghost atoms must have been added"
-
- # Simulate what rotate_with_viewer / --cell -I does: rotate (all nodes
- # including ghosts), then re-sync cell_data from updated graph.graph.
- apply_rotation(graph, rx=0.0, ry=90.0, rz=0.0)
+ assert graph.number_of_nodes() > n_real, "Ghost atoms must have been added"
- cell_data.lattice = np.array(graph.graph["lattice"], dtype=float)
- cell_data.cell_origin = np.array(graph.graph["lattice_origin"], dtype=float)
+ cfg = RenderConfig(cell_data=cell_data, show_cell=True, auto_orient=False)
+ orient_hkl_to_view(graph, cell_data, "110", cfg)
- cfg = RenderConfig(
- cell_data=cell_data,
- show_cell=True,
- auto_orient=False,
- )
svg = render_svg(graph, cfg)
cell_lines = [ln for ln in svg.splitlines() if 'class="cell-edge"' in ln]
assert len(cell_lines) == 12
- # COM of real atoms must be within a reasonable distance of the cell box
- real_ids = [i for i in graph.nodes() if not graph.nodes[i].get("image", False)]
- real_pos = np.array([graph.nodes[i]["position"] for i in real_ids], dtype=float)
+ # COM of real atoms must be inside the cell (fractional coords in [0, 1])
+ real_pos = np.array([graph.nodes[i]["position"] for i in range(n_real)], dtype=float)
com = real_pos.mean(axis=0)
- lat = cell_data.lattice
- orig = cell_data.cell_origin
- frac = np.linalg.solve(lat.T, com - orig)
- # All fractional coordinates of the COM should be roughly in [0, 1]
- # (within one cell width, allowing for atoms at the boundary)
+ frac = np.linalg.solve(cell_data.lattice.T, com - cell_data.cell_origin)
assert np.all(frac > -0.5), f"COM fractional coords {frac} are far outside the cell after rotation"
assert np.all(frac < 1.5), f"COM fractional coords {frac} are far outside the cell after rotation"
diff --git a/tests/test_vectors.py b/tests/test_vectors.py
index 8ac027d..88657e3 100644
--- a/tests/test_vectors.py
+++ b/tests/test_vectors.py
@@ -171,7 +171,6 @@ def test_render_with_vector_arrows():
svg = render_svg(graph, cfg)
assert "#cc0000" in svg
assert "μ" in svg
- assert "]*fill="#ab1234"', svg), "arrowhead polygon with user color must appear in SVG"
+
+
+def test_render_user_vector_appears_with_crystal_axes():
+ """render() with cell_data + axes=True must include BOTH axis colors AND the
+ user-supplied vector color in the SVG output. This is the main regression
+ guard for the bug where cfg.vectors was overwritten instead of extended."""
+ from xyzrender.api import load as api_load
+ from xyzrender.api import render
+
+ mol = api_load(EXAMPLES / "caffeine_cell.xyz", cell=True)
+ # Use a color that differs from all three axis colors (firebrick/forestgreen/royalblue)
+ user_color = "#cd5c5c" # indianred — resolves to a distinct hex from axis colors
+ jf = _write_json([{"origin": "com", "vector": [2.0, 0.0, 0.0], "color": user_color, "label": "dipole"}])
+ svg = str(render(mol, vector=str(jf), axes=True))
+ # Axis colors must be present (firebrick → #b22222, forestgreen → #228b22, royalblue → #4169e1)
+ from xyzrender.types import resolve_color
+
+ assert resolve_color("firebrick") in svg, "a-axis (firebrick) must appear"
+ assert resolve_color("forestgreen") in svg, "b-axis (forestgreen) must appear"
+ assert resolve_color("royalblue") in svg, "c-axis (royalblue) must appear"
+ # User vector must also be present
+ assert user_color in svg, "user vector color must appear in SVG alongside axis arrows"
+ assert "dipole" in svg, "user vector label must appear in SVG"
+
+
+def test_render_user_vector_appears_with_crystal_axes_no_double_loading():
+ """Calling render() twice with the same pre-built config must NOT accumulate
+ vectors across calls (shallow-copy list aliasing regression)."""
+ from xyzrender.api import load as api_load
+ from xyzrender.api import render
+ from xyzrender.config import build_config
+
+ mol = api_load(EXAMPLES / "caffeine_cell.xyz", cell=True)
+ user_color = "#cd5c5c"
+ jf = _write_json([{"origin": "com", "vector": [2.0, 0.0, 0.0], "color": user_color, "label": "dipole"}])
+
+ cfg = build_config("default")
+ svg1 = str(render(mol, config=cfg, vector=str(jf), axes=True))
+ svg2 = str(render(mol, config=cfg, vector=str(jf), axes=True))
+
+ # Count polygon arrowheads with user color in each SVG — should be exactly 1 each
+ count1 = len(re.findall(rf']*fill="{re.escape(user_color)}"', svg1))
+ count2 = len(re.findall(rf']*fill="{re.escape(user_color)}"', svg2))
+ assert count1 == 1, f"first render: expected 1 user-vector arrowhead, got {count1}"
+ assert count2 == 1, f"second render: expected 1 user-vector arrowhead, got {count2}"
+
+
+# ---------------------------------------------------------------------------
+# orient_hkl_to_view — user vectors co-rotate with crystal
+# ---------------------------------------------------------------------------
+
+
+def _cubic_graph_and_cell(a: float = 4.0):
+ """Return a minimal 2-atom graph and CellData for a simple cubic cell."""
+ import networkx as nx
+
+ from xyzrender.types import CellData
+
+ g = nx.Graph()
+ g.add_node(0, symbol="C", position=(0.0, 0.0, 0.0))
+ g.add_node(1, symbol="C", position=(a, a, a))
+ lattice = np.diag([a, a, a]).astype(float)
+ return g, CellData(lattice=lattice, cell_origin=np.zeros(3))
+
+
+def test_orient_hkl_vectors_001_no_rotation():
+ """Axis [001] on a cubic cell is already +z — vector direction and origin are unchanged."""
+ from xyzrender.types import RenderConfig
+ from xyzrender.viewer import orient_hkl_to_view
+
+ g, cell_data = _cubic_graph_and_cell()
+ va = VectorArrow(vector=np.array([1.0, 2.0, 3.0]), origin=np.array([1.0, 1.0, 1.0]))
+ cfg = RenderConfig(vectors=[va])
+
+ orient_hkl_to_view(g, cell_data, "001", cfg)
+
+ np.testing.assert_allclose(np.array(cfg.vectors[0].vector), [1.0, 2.0, 3.0], atol=1e-9)
+ np.testing.assert_allclose(np.array(cfg.vectors[0].origin), [1.0, 1.0, 1.0], atol=1e-9)
+
+
+def test_orient_hkl_vectors_100_direction():
+ """Axis [100] on a cubic cell: a vector along lattice-a must point along +z after rotation."""
+ from xyzrender.types import RenderConfig
+ from xyzrender.viewer import orient_hkl_to_view
+
+ g, cell_data = _cubic_graph_and_cell()
+ # The a-axis is [1,0,0]; viewing along [100] must align it with +z.
+ va = VectorArrow(vector=np.array([1.0, 0.0, 0.0]), origin=np.array([2.0, 2.0, 2.0]))
+ cfg = RenderConfig(vectors=[va])
+
+ orient_hkl_to_view(g, cell_data, "100", cfg)
+
+ np.testing.assert_allclose(np.array(cfg.vectors[0].vector), [0.0, 0.0, 1.0], atol=1e-9)
+
+
+def test_orient_hkl_vectors_100_origin():
+ """After [100] rotation the vector origin is rotated around the atom centroid."""
+ from xyzrender.types import RenderConfig
+ from xyzrender.viewer import orient_hkl_to_view
+
+ a = 4.0
+ g, cell_data = _cubic_graph_and_cell(a)
+ pos = np.array([g.nodes[i]["position"] for i in g.nodes()])
+ centroid = pos.mean(axis=0) # (2, 2, 2)
+
+ # Place the origin displaced by a along the a-axis from the centroid.
+ origin = centroid + np.array([a, 0.0, 0.0])
+ va = VectorArrow(vector=np.array([1.0, 0.0, 0.0]), origin=origin.copy())
+ cfg = RenderConfig(vectors=[va])
+
+ orient_hkl_to_view(g, cell_data, "100", cfg)
+
+ # The rotation maps displacement [a, 0, 0] → [0, 0, a].
+ expected_origin = centroid + np.array([0.0, 0.0, a])
+ np.testing.assert_allclose(np.array(cfg.vectors[0].origin), expected_origin, atol=1e-9)
+
+
+def test_orient_hkl_vectors_110_direction():
+ """Axis [110] on a cubic cell: a vector along the [110] diagonal must point along +z."""
+ from xyzrender.types import RenderConfig
+ from xyzrender.viewer import orient_hkl_to_view
+
+ g, cell_data = _cubic_graph_and_cell()
+ # The [110] lattice direction is [1, 1, 0] / sqrt(2).
+ va = VectorArrow(
+ vector=np.array([1.0, 1.0, 0.0]) / np.sqrt(2),
+ origin=np.array([2.0, 2.0, 2.0]), # at centroid → origin is unchanged
+ )
+ cfg = RenderConfig(vectors=[va])
+
+ orient_hkl_to_view(g, cell_data, "110", cfg)
+
+ np.testing.assert_allclose(np.array(cfg.vectors[0].vector), [0.0, 0.0, 1.0], atol=1e-9)
+
+
+# ---------------------------------------------------------------------------
+# Integration — user vectors via render() with axis= on a real crystal file
+# ---------------------------------------------------------------------------
+
+
+def test_render_crystal_axis_vector_projects_as_dot_when_parallel_to_view():
+ """render() with axis='100' on an orthorhombic cell: a user vector aligned
+ with the a-axis must project as a dot () in the SVG because the
+ arrow is now pointing directly along the viewing axis (+z)."""
+ from xyzrender.api import Molecule, render
+ from xyzrender.types import CellData
+
+ graph, _ = load_molecule(EXAMPLES / "caffeine_cell.xyz")
+ cell_data = CellData(
+ lattice=np.array(graph.graph["lattice"], dtype=float),
+ cell_origin=np.zeros(3),
+ )
+ mol = Molecule(graph=graph, cell_data=cell_data)
+
+ # Unit vector along the a-axis (first lattice row)
+ a_hat = cell_data.lattice[0] / np.linalg.norm(cell_data.lattice[0])
+ positions = np.array([graph.nodes[i]["position"] for i in graph.nodes()])
+ centroid = positions.mean(axis=0)
+
+ # Distinct color so we can identify this arrow's elements in the SVG.
+ user_color = "#7f1234"
+ va = VectorArrow(vector=a_hat.copy(), origin=centroid.copy(), color=user_color)
+
+ # View along [100]: the user vector is now parallel to the viewing axis →
+ # projected length ≈ 0 → _draw_arrow_svg renders it as a filled circle, not a polygon.
+ svg = str(render(mol, vector=[va], axes=False, ghosts=False, axis="100"))
+
+ assert f'fill="{user_color}"' in svg, "user vector must appear in the rotated SVG"
+ # In the short-projection code path the arrow renders as a circle, not a polygon arrowhead.
+ assert re.search(rf']*fill="{re.escape(user_color)}"', svg), (
+ "user vector aligned with view axis should render as a dot (circle), not a line+polygon"
+ )
+ assert not re.search(rf']*fill="{re.escape(user_color)}"', svg), (
+ "user vector aligned with view axis must NOT render as a polygon arrowhead"
+ )