-
Notifications
You must be signed in to change notification settings - Fork 35
feat: improve errors for invalid combinations of arguments in vector constructor methods #659
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||
|---|---|---|---|---|
|
|
@@ -2071,6 +2071,168 @@ def __setitem__(self, where: typing.Any, what: typing.Any) -> None: | |||
| return _setitem(self, where, what, True) | ||||
|
|
||||
|
|
||||
| def _validate_numpy_coordinates(fieldnames: tuple[str, ...]) -> None: | ||||
| """ | ||||
| Validate coordinate field names using dimension-guard pattern. | ||||
|
|
||||
| This follows the same logic as _check_names in awkward_constructors to ensure | ||||
| consistent validation across backends. | ||||
|
|
||||
| Raises TypeError if duplicate or conflicting coordinates are detected. | ||||
| """ | ||||
| complaint1 = "duplicate coordinates (through momentum-aliases): " + ", ".join( | ||||
| repr(x) for x in fieldnames | ||||
| ) | ||||
| complaint2 = ( | ||||
| "unrecognized combination of coordinates, allowed combinations are:\n\n" | ||||
| " (2D) x= y=\n" | ||||
| " (2D) rho= phi=\n" | ||||
| " (3D) x= y= z=\n" | ||||
| " (3D) x= y= theta=\n" | ||||
| " (3D) x= y= eta=\n" | ||||
| " (3D) rho= phi= z=\n" | ||||
| " (3D) rho= phi= theta=\n" | ||||
| " (3D) rho= phi= eta=\n" | ||||
| " (4D) x= y= z= t=\n" | ||||
| " (4D) x= y= z= tau=\n" | ||||
| " (4D) x= y= theta= t=\n" | ||||
| " (4D) x= y= theta= tau=\n" | ||||
| " (4D) x= y= eta= t=\n" | ||||
| " (4D) x= y= eta= tau=\n" | ||||
| " (4D) rho= phi= z= t=\n" | ||||
| " (4D) rho= phi= z= tau=\n" | ||||
| " (4D) rho= phi= theta= t=\n" | ||||
| " (4D) rho= phi= theta= tau=\n" | ||||
| " (4D) rho= phi= eta= t=\n" | ||||
| " (4D) rho= phi= eta= tau=" | ||||
| ) | ||||
|
|
||||
| is_momentum = False | ||||
| dimension = 0 | ||||
| fieldnames_copy = list(fieldnames) | ||||
|
|
||||
| # 2D azimuthal coordinates | ||||
| if "x" in fieldnames_copy and "y" in fieldnames_copy: | ||||
| if dimension != 0: | ||||
| raise TypeError(complaint1 if is_momentum else complaint2) | ||||
| dimension = 2 | ||||
| fieldnames_copy.remove("x") | ||||
| fieldnames_copy.remove("y") | ||||
| if "rho" in fieldnames_copy and "phi" in fieldnames_copy: | ||||
| if dimension != 0: | ||||
| raise TypeError(complaint1 if is_momentum else complaint2) | ||||
| dimension = 2 | ||||
| fieldnames_copy.remove("rho") | ||||
| fieldnames_copy.remove("phi") | ||||
| if "x" in fieldnames_copy and "py" in fieldnames_copy: | ||||
| is_momentum = True | ||||
| if dimension != 0: | ||||
| raise TypeError(complaint1 if is_momentum else complaint2) | ||||
| dimension = 2 | ||||
| fieldnames_copy.remove("x") | ||||
| fieldnames_copy.remove("py") | ||||
| if "px" in fieldnames_copy and "y" in fieldnames_copy: | ||||
| is_momentum = True | ||||
| if dimension != 0: | ||||
| raise TypeError(complaint1 if is_momentum else complaint2) | ||||
| dimension = 2 | ||||
| fieldnames_copy.remove("px") | ||||
| fieldnames_copy.remove("y") | ||||
| if "px" in fieldnames_copy and "py" in fieldnames_copy: | ||||
| is_momentum = True | ||||
| if dimension != 0: | ||||
| raise TypeError(complaint1 if is_momentum else complaint2) | ||||
| dimension = 2 | ||||
| fieldnames_copy.remove("px") | ||||
| fieldnames_copy.remove("py") | ||||
| if "pt" in fieldnames_copy and "phi" in fieldnames_copy: | ||||
| is_momentum = True | ||||
| if dimension != 0: | ||||
| raise TypeError(complaint1 if is_momentum else complaint2) | ||||
| dimension = 2 | ||||
| fieldnames_copy.remove("pt") | ||||
| fieldnames_copy.remove("phi") | ||||
|
|
||||
| # 3D longitudinal coordinates | ||||
| if "z" in fieldnames_copy: | ||||
| if dimension != 2: | ||||
| raise TypeError(complaint1 if is_momentum else complaint2) | ||||
| dimension = 3 | ||||
| fieldnames_copy.remove("z") | ||||
| if "theta" in fieldnames_copy: | ||||
| if dimension != 2: | ||||
| raise TypeError(complaint1 if is_momentum else complaint2) | ||||
| dimension = 3 | ||||
| fieldnames_copy.remove("theta") | ||||
| if "eta" in fieldnames_copy: | ||||
| if dimension != 2: | ||||
| raise TypeError(complaint1 if is_momentum else complaint2) | ||||
| dimension = 3 | ||||
| fieldnames_copy.remove("eta") | ||||
| if "pz" in fieldnames_copy: | ||||
| is_momentum = True | ||||
| if dimension != 2: | ||||
| raise TypeError(complaint1 if is_momentum else complaint2) | ||||
| dimension = 3 | ||||
| fieldnames_copy.remove("pz") | ||||
|
|
||||
| # 4D temporal coordinates | ||||
| if "t" in fieldnames_copy: | ||||
| if dimension != 3: | ||||
| raise TypeError(complaint1 if is_momentum else complaint2) | ||||
| dimension = 4 | ||||
| fieldnames_copy.remove("t") | ||||
| if "tau" in fieldnames_copy: | ||||
| if dimension != 3: | ||||
| raise TypeError(complaint1 if is_momentum else complaint2) | ||||
| dimension = 4 | ||||
| fieldnames_copy.remove("tau") | ||||
| if "E" in fieldnames_copy: | ||||
| is_momentum = True | ||||
| if dimension != 3: | ||||
| raise TypeError(complaint1 if is_momentum else complaint2) | ||||
| dimension = 4 | ||||
| fieldnames_copy.remove("E") | ||||
| if "e" in fieldnames_copy: | ||||
| is_momentum = True | ||||
| if dimension != 3: | ||||
| raise TypeError(complaint1 if is_momentum else complaint2) | ||||
| dimension = 4 | ||||
| fieldnames_copy.remove("e") | ||||
| if "energy" in fieldnames_copy: | ||||
| is_momentum = True | ||||
| if dimension != 3: | ||||
| raise TypeError(complaint1 if is_momentum else complaint2) | ||||
| dimension = 4 | ||||
| fieldnames_copy.remove("energy") | ||||
| if "M" in fieldnames_copy: | ||||
| is_momentum = True | ||||
| if dimension != 3: | ||||
| raise TypeError(complaint1 if is_momentum else complaint2) | ||||
| dimension = 4 | ||||
| fieldnames_copy.remove("M") | ||||
| if "m" in fieldnames_copy: | ||||
| is_momentum = True | ||||
| if dimension != 3: | ||||
| raise TypeError(complaint1 if is_momentum else complaint2) | ||||
| dimension = 4 | ||||
| fieldnames_copy.remove("m") | ||||
| if "mass" in fieldnames_copy: | ||||
| is_momentum = True | ||||
| if dimension != 3: | ||||
| raise TypeError(complaint1 if is_momentum else complaint2) | ||||
| dimension = 4 | ||||
| fieldnames_copy.remove("mass") | ||||
|
|
||||
| # Check if any remaining fieldnames would conflict with already-processed coordinates | ||||
| # when mapped to generic names (e.g., pt was processed, rho shouldn't remain) | ||||
| if fieldnames_copy: | ||||
| # Map all original fieldnames to generic names to detect conflicts | ||||
| generic_names = [_repr_momentum_to_generic.get(x, x) for x in fieldnames] | ||||
| if len(generic_names) != len(set(generic_names)): | ||||
| raise TypeError(complaint1 if is_momentum else complaint2) | ||||
|
|
||||
|
|
||||
| def array(*args: typing.Any, **kwargs: typing.Any) -> VectorNumpy: | ||||
| """ | ||||
| Constructs a NumPy array of vectors, whose type is determined by the dtype | ||||
|
|
@@ -2138,6 +2300,9 @@ def array(*args: typing.Any, **kwargs: typing.Any) -> VectorNumpy: | |||
|
|
||||
| is_momentum = any(x in _repr_momentum_to_generic for x in names) | ||||
|
|
||||
| # Validate coordinates using dimension-guard pattern (same as awkward _check_names) | ||||
| _validate_numpy_coordinates(names) | ||||
|
Comment on lines
+2303
to
+2304
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
vector/src/vector/backends/numpy.py Line 1168 in e374915
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I vaguely remember that I had some issue with |
||||
|
|
||||
| if any(x in ("t", "E", "e", "energy", "tau", "M", "m", "mass") for x in names): | ||||
| cls = MomentumNumpy4D if is_momentum else VectorNumpy4D | ||||
| elif any(x in ("z", "pz", "theta", "eta") for x in names): | ||||
|
|
||||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -3211,7 +3211,7 @@ def obj(**coordinates: float) -> VectorObject: | |
| if "E" in coordinates: | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Should this also have the |
||
| is_momentum = True | ||
| generic_coordinates["t"] = coordinates.pop("E") | ||
| if "e" in coordinates: | ||
| if "e" in coordinates and "t" not in generic_coordinates: | ||
| is_momentum = True | ||
| generic_coordinates["t"] = coordinates.pop("e") | ||
| if "energy" in coordinates and "t" not in generic_coordinates: | ||
|
|
@@ -3220,7 +3220,7 @@ def obj(**coordinates: float) -> VectorObject: | |
| if "M" in coordinates: | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Ditto, for |
||
| is_momentum = True | ||
| generic_coordinates["tau"] = coordinates.pop("M") | ||
| if "m" in coordinates: | ||
| if "m" in coordinates and "tau" not in generic_coordinates: | ||
| is_momentum = True | ||
| generic_coordinates["tau"] = coordinates.pop("m") | ||
| if "mass" in coordinates and "tau" not in generic_coordinates: | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It's a picky comment but it's a bit sub-optimal that we duplicate these "complaint strings" in several submodules. For better maintainability it's probably best to move them to a trival submodule and import the strings in the several places, such as here but also in awkward_constructions.py for example.