88import enum
99from typing import Annotated
1010
11- from pydantic import BaseModel , Field , PlainSerializer , model_validator
11+ from pydantic import AfterValidator , BaseModel , Field , PlainSerializer , model_validator
1212from typing_extensions import Self
1313
1414from .rendering import RenderedObject
@@ -59,6 +59,21 @@ def serialise_grid_dist(value):
5959 )
6060
6161
62+ def all_gt (iterable , m ):
63+ if all (correct := [x > m for x in iterable ]):
64+ return iterable
65+ else :
66+ message = f"{ iterable = } contains values <= { m = } while all should be greater than m. Valid are the following: { correct = } ."
67+ raise ValueError (message )
68+
69+
70+ def grid_dist_validate (grid_dist ):
71+ if grid_dist is None :
72+ return None
73+ if all_gt (sum (grid_dist , []), 0 ):
74+ return grid_dist
75+
76+
6277class Grid3D (BaseModel , RenderedObject ):
6378 """
6479 PIConGPU 3 dimensional (cartesian) grid
@@ -68,10 +83,10 @@ class Grid3D(BaseModel, RenderedObject):
6883 The bounding box is implicitly given as TODO.
6984 """
7085
71- cell_size : Vec3_float = Field (alias = "cell_size_si" )
86+ cell_size : Annotated [ Vec3_float , AfterValidator ( lambda x : all_gt ( x , 0 ))] = Field (alias = "cell_size_si" )
7287 """Width of individual cell in each direction"""
7388
74- cell_cnt : Vec3_int
89+ cell_cnt : Annotated [ Vec3_int , AfterValidator ( lambda x : all_gt ( x , 0 ))]
7590 """total number of cells in each direction"""
7691
7792 boundary_condition : Annotated [
@@ -80,10 +95,14 @@ class Grid3D(BaseModel, RenderedObject):
8095 ]
8196 """behavior towards particles crossing each boundary"""
8297
83- gpu_cnt : Vec3_int = Field ((1 , 1 , 1 ), alias = "n_gpus" )
98+ gpu_cnt : Annotated [ Vec3_int , AfterValidator ( lambda x : all_gt ( x , 0 ))] = Field ((1 , 1 , 1 ), alias = "n_gpus" )
8499 """number of GPUs in x y and z direction as 3-integer tuple"""
85100
86- grid_dist : Annotated [tuple [list [int ], list [int ], list [int ]] | None , PlainSerializer (serialise_grid_dist )] = None
101+ grid_dist : Annotated [
102+ tuple [list [int ], list [int ], list [int ]] | None ,
103+ PlainSerializer (serialise_grid_dist ),
104+ AfterValidator (grid_dist_validate ),
105+ ] = None
87106 """distribution of grid cells to GPUs for each axis"""
88107
89108 super_cell_size : Vec3_int
@@ -92,8 +111,6 @@ class Grid3D(BaseModel, RenderedObject):
92111 @model_validator (mode = "after" )
93112 def check (self ) -> Self :
94113 """serialized representation provided for RenderedObject"""
95- assert all (x > 0 for x in self .cell_cnt ), "cell_cnt must be greater than 0"
96- assert all (x > 0 for x in self .gpu_cnt ), "all n_gpus entries must be greater than 0"
97114 if self .grid_dist is not None :
98115 assert sum (self .grid_dist [0 ]) == self .cell_cnt [0 ], "sum of grid_dists in x must be equal to number_of_cells"
99116 assert sum (self .grid_dist [1 ]) == self .cell_cnt [1 ], "sum of grid_dists in y must be equal to number_of_cells"
0 commit comments