@@ -89,21 +89,21 @@ def grid(self) -> Array:
89
89
the corner.
90
90
"""
91
91
# We must use meshgrid instead of mgrid here in order to be jittable
92
- N_x , N_y = self .spatial_shape
92
+ N_y , N_x = self .spatial_shape
93
93
grid = jnp .meshgrid (
94
- jnp .linspace (- N_x // 2 , N_x // 2 - 1 , num = N_x ) + 0.5 ,
95
94
jnp .linspace (- N_y // 2 , N_y // 2 - 1 , num = N_y ) + 0.5 ,
95
+ jnp .linspace (- N_x // 2 , N_x // 2 - 1 , num = N_x ) + 0.5 ,
96
96
indexing = "ij" ,
97
97
)
98
98
grid = rearrange (grid , "d h w -> d " + ("1 " * (self .ndim - 4 )) + "h w 1 1" )
99
99
return self .dx * grid
100
100
101
101
@property
102
102
def k_grid (self ) -> Array :
103
- N_x , N_y = self .spatial_shape
103
+ N_y , N_x = self .spatial_shape
104
104
grid = jnp .meshgrid (
105
- jnp .linspace (- N_x // 2 , N_x // 2 - 1 , num = N_x ) + 0.5 ,
106
105
jnp .linspace (- N_y // 2 , N_y // 2 - 1 , num = N_y ) + 0.5 ,
106
+ jnp .linspace (- N_x // 2 , N_x // 2 - 1 , num = N_x ) + 0.5 ,
107
107
indexing = "ij" ,
108
108
)
109
109
grid = rearrange (grid , "d h w -> d " + ("1 " * (self .ndim - 4 )) + "h w 1 1" )
0 commit comments