Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions src/vector/backends/awkward_constructors.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@

import numpy

from vector._methods import _repr_momentum_to_generic


def _recname(is_momentum: bool, dimension: int) -> str:
name = "Momentum" if is_momentum else "Vector"
Expand Down Expand Up @@ -199,6 +201,20 @@ def _check_names(
if dimension == 0:
raise TypeError(complaint1 if is_momentum else complaint2)

# Check if any remaining fieldnames would conflict with already-processed coordinates
# or with each other when mapped to generic names (e.g., "x" and "px" both map to "x")
if fieldnames:
# Check leftovers against already-processed coordinates
for fname in fieldnames:
generic = _repr_momentum_to_generic.get(fname, fname)
if generic in names:
raise TypeError(complaint1 if is_momentum else complaint2)

# Check leftovers against each other for duplicates
leftover_generics = [_repr_momentum_to_generic.get(x, x) for x in fieldnames]
if len(leftover_generics) != len(set(leftover_generics)):
raise TypeError(complaint1 if is_momentum else complaint2)

for name in fieldnames:
names.append(name)
columns.append(projectable[name])
Expand Down
165 changes: 165 additions & 0 deletions src/vector/backends/numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -2071,6 +2071,168 @@ def __setitem__(self, where: typing.Any, what: typing.Any) -> None:
return _setitem(self, where, what, True)


def _validate_numpy_coordinates(fieldnames: tuple[str, ...]) -> None:
"""
Validate coordinate field names using dimension-guard pattern.

This follows the same logic as _check_names in awkward_constructors to ensure
consistent validation across backends.

Raises TypeError if duplicate or conflicting coordinates are detected.
"""
complaint1 = "duplicate coordinates (through momentum-aliases): " + ", ".join(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a picky comment, but it is sub-optimal that we duplicate these "complaint strings" across several submodules. For better maintainability, it is probably best to move them into a trivial submodule and import the strings wherever they are needed, such as here and in awkward_constructors.py.

repr(x) for x in fieldnames
)
complaint2 = (
"unrecognized combination of coordinates, allowed combinations are:\n\n"
" (2D) x= y=\n"
" (2D) rho= phi=\n"
" (3D) x= y= z=\n"
" (3D) x= y= theta=\n"
" (3D) x= y= eta=\n"
" (3D) rho= phi= z=\n"
" (3D) rho= phi= theta=\n"
" (3D) rho= phi= eta=\n"
" (4D) x= y= z= t=\n"
" (4D) x= y= z= tau=\n"
" (4D) x= y= theta= t=\n"
" (4D) x= y= theta= tau=\n"
" (4D) x= y= eta= t=\n"
" (4D) x= y= eta= tau=\n"
" (4D) rho= phi= z= t=\n"
" (4D) rho= phi= z= tau=\n"
" (4D) rho= phi= theta= t=\n"
" (4D) rho= phi= theta= tau=\n"
" (4D) rho= phi= eta= t=\n"
" (4D) rho= phi= eta= tau="
)

is_momentum = False
dimension = 0
fieldnames_copy = list(fieldnames)

# 2D azimuthal coordinates
if "x" in fieldnames_copy and "y" in fieldnames_copy:
if dimension != 0:
raise TypeError(complaint1 if is_momentum else complaint2)
dimension = 2
fieldnames_copy.remove("x")
fieldnames_copy.remove("y")
if "rho" in fieldnames_copy and "phi" in fieldnames_copy:
if dimension != 0:
raise TypeError(complaint1 if is_momentum else complaint2)
dimension = 2
fieldnames_copy.remove("rho")
fieldnames_copy.remove("phi")
if "x" in fieldnames_copy and "py" in fieldnames_copy:
is_momentum = True
if dimension != 0:
raise TypeError(complaint1 if is_momentum else complaint2)
dimension = 2
fieldnames_copy.remove("x")
fieldnames_copy.remove("py")
if "px" in fieldnames_copy and "y" in fieldnames_copy:
is_momentum = True
if dimension != 0:
raise TypeError(complaint1 if is_momentum else complaint2)
dimension = 2
fieldnames_copy.remove("px")
fieldnames_copy.remove("y")
if "px" in fieldnames_copy and "py" in fieldnames_copy:
is_momentum = True
if dimension != 0:
raise TypeError(complaint1 if is_momentum else complaint2)
dimension = 2
fieldnames_copy.remove("px")
fieldnames_copy.remove("py")
if "pt" in fieldnames_copy and "phi" in fieldnames_copy:
is_momentum = True
if dimension != 0:
raise TypeError(complaint1 if is_momentum else complaint2)
dimension = 2
fieldnames_copy.remove("pt")
fieldnames_copy.remove("phi")

# 3D longitudinal coordinates
if "z" in fieldnames_copy:
if dimension != 2:
raise TypeError(complaint1 if is_momentum else complaint2)
dimension = 3
fieldnames_copy.remove("z")
if "theta" in fieldnames_copy:
if dimension != 2:
raise TypeError(complaint1 if is_momentum else complaint2)
dimension = 3
fieldnames_copy.remove("theta")
if "eta" in fieldnames_copy:
if dimension != 2:
raise TypeError(complaint1 if is_momentum else complaint2)
dimension = 3
fieldnames_copy.remove("eta")
if "pz" in fieldnames_copy:
is_momentum = True
if dimension != 2:
raise TypeError(complaint1 if is_momentum else complaint2)
dimension = 3
fieldnames_copy.remove("pz")

# 4D temporal coordinates
if "t" in fieldnames_copy:
if dimension != 3:
raise TypeError(complaint1 if is_momentum else complaint2)
dimension = 4
fieldnames_copy.remove("t")
if "tau" in fieldnames_copy:
if dimension != 3:
raise TypeError(complaint1 if is_momentum else complaint2)
dimension = 4
fieldnames_copy.remove("tau")
if "E" in fieldnames_copy:
is_momentum = True
if dimension != 3:
raise TypeError(complaint1 if is_momentum else complaint2)
dimension = 4
fieldnames_copy.remove("E")
if "e" in fieldnames_copy:
is_momentum = True
if dimension != 3:
raise TypeError(complaint1 if is_momentum else complaint2)
dimension = 4
fieldnames_copy.remove("e")
if "energy" in fieldnames_copy:
is_momentum = True
if dimension != 3:
raise TypeError(complaint1 if is_momentum else complaint2)
dimension = 4
fieldnames_copy.remove("energy")
if "M" in fieldnames_copy:
is_momentum = True
if dimension != 3:
raise TypeError(complaint1 if is_momentum else complaint2)
dimension = 4
fieldnames_copy.remove("M")
if "m" in fieldnames_copy:
is_momentum = True
if dimension != 3:
raise TypeError(complaint1 if is_momentum else complaint2)
dimension = 4
fieldnames_copy.remove("m")
if "mass" in fieldnames_copy:
is_momentum = True
if dimension != 3:
raise TypeError(complaint1 if is_momentum else complaint2)
dimension = 4
fieldnames_copy.remove("mass")

# Check if any remaining fieldnames would conflict with already-processed coordinates
# when mapped to generic names (e.g., pt was processed, rho shouldn't remain)
if fieldnames_copy:
# Map all original fieldnames to generic names to detect conflicts
generic_names = [_repr_momentum_to_generic.get(x, x) for x in fieldnames]
if len(generic_names) != len(set(generic_names)):
raise TypeError(complaint1 if is_momentum else complaint2)


def array(*args: typing.Any, **kwargs: typing.Any) -> VectorNumpy:
"""
Constructs a NumPy array of vectors, whose type is determined by the dtype
Expand Down Expand Up @@ -2138,6 +2300,9 @@ def array(*args: typing.Any, **kwargs: typing.Any) -> VectorNumpy:

is_momentum = any(x in _repr_momentum_to_generic for x in names)

# Validate coordinates using dimension-guard pattern (same as awkward _check_names)
_validate_numpy_coordinates(names)
Comment on lines +2303 to +2304
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

vector.array is just a wrapper around the individual Vector/MomentumNumpy*D constructors, and those constructors can also be used directly to construct vectors (unlike in the Awkward backend). A check that lives only in vector.array would not run in that case, so it would be better to move this check to the __array_finalize__ method of each class:

def __array_finalize__(self, obj: typing.Any) -> None:

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I vaguely remember having an issue with when __array_finalize__ gets run while I was making these edits, but I will take another look.


if any(x in ("t", "E", "e", "energy", "tau", "M", "m", "mass") for x in names):
cls = MomentumNumpy4D if is_momentum else VectorNumpy4D
elif any(x in ("z", "pz", "theta", "eta") for x in names):
Expand Down
4 changes: 2 additions & 2 deletions src/vector/backends/object.py
Original file line number Diff line number Diff line change
Expand Up @@ -3211,7 +3211,7 @@ def obj(**coordinates: float) -> VectorObject:
if "E" in coordinates:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should this also have the `and "t" not in generic_coordinates` condition?

is_momentum = True
generic_coordinates["t"] = coordinates.pop("E")
if "e" in coordinates:
if "e" in coordinates and "t" not in generic_coordinates:
is_momentum = True
generic_coordinates["t"] = coordinates.pop("e")
if "energy" in coordinates and "t" not in generic_coordinates:
Expand All @@ -3220,7 +3220,7 @@ def obj(**coordinates: float) -> VectorObject:
if "M" in coordinates:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ditto for tau: should this also have the `and "tau" not in generic_coordinates` condition?

is_momentum = True
generic_coordinates["tau"] = coordinates.pop("M")
if "m" in coordinates:
if "m" in coordinates and "tau" not in generic_coordinates:
is_momentum = True
generic_coordinates["tau"] = coordinates.pop("m")
if "mass" in coordinates and "tau" not in generic_coordinates:
Expand Down
Loading