Skip to content

Commit 84c9d35

Browse files
committed
Move constraints from Grid3D schema to validation logic
1 parent c9c2f97 commit 84c9d35

File tree

2 files changed

+28
-11
lines changed
  • lib/python/picongpu/pypicongpu
  • test/python/picongpu/quick/pypicongpu

2 files changed

+28
-11
lines changed

lib/python/picongpu/pypicongpu/grid.py

Lines changed: 24 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
import enum
99
from typing import Annotated
1010

11-
from pydantic import BaseModel, Field, PlainSerializer, model_validator
11+
from pydantic import AfterValidator, BaseModel, Field, PlainSerializer, model_validator
1212
from typing_extensions import Self
1313

1414
from .rendering import RenderedObject
@@ -59,6 +59,21 @@ def serialise_grid_dist(value):
5959
)
6060

6161

62+
def all_gt(iterable, m):
63+
if all(correct := [x > m for x in iterable]):
64+
return iterable
65+
else:
66+
message = f"{iterable=} contains values <= {m=} while all should be greater than m. Valid are the following: {correct=}."
67+
raise ValueError(message)
68+
69+
70+
def grid_dist_validate(grid_dist):
71+
if grid_dist is None:
72+
return None
73+
if all_gt(sum(grid_dist, []), 0):
74+
return grid_dist
75+
76+
6277
class Grid3D(BaseModel, RenderedObject):
6378
"""
6479
PIConGPU 3 dimensional (cartesian) grid
@@ -68,10 +83,10 @@ class Grid3D(BaseModel, RenderedObject):
6883
The bounding box is implicitly given as TODO.
6984
"""
7085

71-
cell_size: Vec3_float = Field(alias="cell_size_si")
86+
cell_size: Annotated[Vec3_float, AfterValidator(lambda x: all_gt(x, 0))] = Field(alias="cell_size_si")
7287
"""Width of individual cell in each direction"""
7388

74-
cell_cnt: Vec3_int
89+
cell_cnt: Annotated[Vec3_int, AfterValidator(lambda x: all_gt(x, 0))]
7590
"""total number of cells in each direction"""
7691

7792
boundary_condition: Annotated[
@@ -80,10 +95,14 @@ class Grid3D(BaseModel, RenderedObject):
8095
]
8196
"""behavior towards particles crossing each boundary"""
8297

83-
gpu_cnt: Vec3_int = Field((1, 1, 1), alias="n_gpus")
98+
gpu_cnt: Annotated[Vec3_int, AfterValidator(lambda x: all_gt(x, 0))] = Field((1, 1, 1), alias="n_gpus")
8499
"""number of GPUs in x y and z direction as 3-integer tuple"""
85100

86-
grid_dist: Annotated[tuple[list[int], list[int], list[int]] | None, PlainSerializer(serialise_grid_dist)] = None
101+
grid_dist: Annotated[
102+
tuple[list[int], list[int], list[int]] | None,
103+
PlainSerializer(serialise_grid_dist),
104+
AfterValidator(grid_dist_validate),
105+
] = None
87106
"""distribution of grid cells to GPUs for each axis"""
88107

89108
super_cell_size: Vec3_int
@@ -92,8 +111,6 @@ class Grid3D(BaseModel, RenderedObject):
92111
@model_validator(mode="after")
93112
def check(self) -> Self:
94113
"""serialized representation provided for RenderedObject"""
95-
assert all(x > 0 for x in self.cell_cnt), "cell_cnt must be greater than 0"
96-
assert all(x > 0 for x in self.gpu_cnt), "all n_gpus entries must be greater than 0"
97114
if self.grid_dist is not None:
98115
assert sum(self.grid_dist[0]) == self.cell_cnt[0], "sum of grid_dists in x must be equal to number_of_cells"
99116
assert sum(self.grid_dist[1]) == self.cell_cnt[1], "sum of grid_dists in y must be equal to number_of_cells"

test/python/picongpu/quick/pypicongpu/grid.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -58,17 +58,17 @@ def test_types(self):
5858

5959
def test_gpu_and_cell_cnt_positive(self):
6060
"""test if n_gpus and cell number s are >0"""
61-
with self.assertRaisesRegex(ValidationError, ".*cell_cnt.*greater than 0.*"):
61+
with self.assertRaisesRegex(ValidationError, ".*contains values <= m=0.*"):
6262
Grid3D(**self.kwargs | dict(cell_cnt=(-1, 7, 8)))
6363

64-
with self.assertRaisesRegex(ValidationError, ".*cell_cnt.*greater than 0.*"):
64+
with self.assertRaisesRegex(ValidationError, ".*contains values <= m=0.*"):
6565
Grid3D(**self.kwargs | dict(cell_cnt=(6, -2, 8)))
6666

67-
with self.assertRaisesRegex(ValidationError, ".*cell_cnt.*greater than 0.*"):
67+
with self.assertRaisesRegex(ValidationError, ".*contains values <= m=0.*"):
6868
Grid3D(**self.kwargs | dict(cell_cnt=(6, 7, 0)))
6969

7070
for wrong_n_gpus in [tuple([-1, 1, 1]), tuple([1, 1, 0])]:
71-
with self.assertRaisesRegex(ValidationError, ".*greater than 0.*"):
71+
with self.assertRaisesRegex(ValidationError, ".*contains values <= m=0.*"):
7272
Grid3D(**self.kwargs | dict(n_gpus=wrong_n_gpus))
7373

7474
def test_mandatory(self):

0 commit comments

Comments
 (0)