Skip to content
This repository was archived by the owner on Jan 10, 2025. It is now read-only.

Commit 07c5892

Browse files
committed
Merge branch 'develop' of https://github.com/ecmwf-lab/ecml-tools into develop
2 parents 5b7a770 + 0184cd0 commit 07c5892

File tree

5 files changed

+83
-31
lines changed

5 files changed

+83
-31
lines changed

ecml_tools/commands/inspect/zarr.py

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,10 @@ def frequency(self):
152152
def resolution(self):
153153
return self.metadata["resolution"]
154154

155+
@property
156+
def field_shape(self):
157+
return self.metadata.get("field_shape")
158+
155159
@property
156160
def shape(self):
157161
if self.data and hasattr(self.data, "shape"):
@@ -170,15 +174,16 @@ def uncompressed_data_size(self):
170174

171175
def info(self, detailed, size):
172176
print()
173-
print(f'📅 Start : {self.first_date.strftime("%Y-%m-%d %H:%M")}')
174-
print(f'📅 End : {self.last_date.strftime("%Y-%m-%d %H:%M")}')
175-
print(f"⏰ Frequency : {self.frequency}h")
177+
print(f'📅 Start : {self.first_date.strftime("%Y-%m-%d %H:%M")}')
178+
print(f'📅 End : {self.last_date.strftime("%Y-%m-%d %H:%M")}')
179+
print(f"⏰ Frequency : {self.frequency}h")
176180
if self.n_missing_dates is not None:
177-
print(f"🚫 Missing : {self.n_missing_dates:,}")
178-
print(f"🌎 Resolution: {self.resolution}")
181+
print(f"🚫 Missing : {self.n_missing_dates:,}")
182+
print(f"🌎 Resolution : {self.resolution}")
183+
print(f"🌎 Field shape: {self.field_shape}")
179184

180185
print()
181-
shape_str = "📐 Shape : "
186+
shape_str = "📐 Shape : "
182187
if self.shape:
183188
shape_str += " × ".join(["{:,}".format(s) for s in self.shape])
184189
if self.uncompressed_data_size:
@@ -237,9 +242,9 @@ def print_sizes(self, size):
237242
total_size, n = compute_directory_size(self.path)
238243

239244
if total_size is not None:
240-
print(f"💽 Size : {bytes(total_size)} ({number(total_size)})")
245+
print(f"💽 Size : {bytes(total_size)} ({number(total_size)})")
241246
if n is not None:
242-
print(f"📁 Files : {number(n)}")
247+
print(f"📁 Files : {number(n)}")
243248

244249
@property
245250
def statistics(self):

ecml_tools/create/input.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -190,6 +190,7 @@ def _build_coords(self):
190190
self._grid_points = grid_points
191191
self._resolution = first_field.resolution
192192
self._grid_values = grid_values
193+
self._field_shape = first_field.shape
193194

194195
@cached_property
195196
def variables(self):
@@ -216,6 +217,11 @@ def grid_points(self):
216217
self._build_coords
217218
return self._grid_points
218219

220+
@cached_property
221+
def field_shape(self):
222+
self._build_coords
223+
return self._field_shape
224+
219225

220226
class HasCoordsMixin:
221227
@cached_property
@@ -238,6 +244,10 @@ def grid_values(self):
238244
def grid_points(self):
239245
return self._coords.grid_points
240246

247+
@cached_property
248+
def field_shape(self):
249+
return self._coords.field_shape
250+
241251
@cached_property
242252
def shape(self):
243253
return [

ecml_tools/create/loaders.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -271,6 +271,7 @@ def initialise(self, check_name=True):
271271
metadata["variables"] = variables
272272
metadata["variables_with_nans"] = variables_with_nans
273273
metadata["resolution"] = resolution
274+
metadata["field_shape"] = self.minimal_input.field_shape
274275

275276
metadata["licence"] = self.main_config["licence"]
276277
metadata["copyright"] = self.main_config["copyright"]

ecml_tools/data.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -486,6 +486,10 @@ def statistics(self):
486486
def resolution(self):
487487
return self.z.attrs["resolution"]
488488

489+
@property
490+
def field_shape(self):
491+
return tuple(self.z.attrs["field_shape"])
492+
489493
@property
490494
def frequency(self):
491495
try:
@@ -610,6 +614,10 @@ def dates(self):
610614
def resolution(self):
611615
return self.forward.resolution
612616

617+
@property
618+
def field_shape(self):
619+
return self.forward.field_shape
620+
613621
@property
614622
def frequency(self):
615623
return self.forward.frequency
@@ -912,12 +920,10 @@ def __init__(self, forward, thinning, method):
912920
self.thinning = thinning
913921
self.method = method
914922

915-
assert method is None, f"Thinning method not supported: {method}"
916-
latitudes = sorted(set(forward.latitudes))
917-
longitudes = sorted(set(forward.longitudes))
918-
919-
latitudes = set(latitudes[::thinning])
920-
longitudes = set(longitudes[::thinning])
923+
latitudes = forward.latitudes.reshape(forward.field_shape)
924+
longitudes = forward.longitudes.reshape(forward.field_shape)
925+
latitudes = latitudes[::thinning, ::thinning].flatten()
926+
longitudes = longitudes[::thinning, ::thinning].flatten()
921927

922928
mask = [lat in latitudes and lon in longitudes for lat, lon in zip(forward.latitudes, forward.longitudes)]
923929
mask = np.array(mask, dtype=bool)

ecml_tools/grids.py

Lines changed: 47 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -14,19 +14,35 @@
1414
def plot_mask(path, mask, lats, lons, global_lats, global_lons):
1515
import matplotlib.pyplot as plt
1616

17+
middle = (np.amin(lons) + np.amax(lons)) / 2
18+
print("middle", middle)
19+
s = 1
20+
21+
# gmiddle = (np.amin(global_lons)+ np.amax(global_lons))/2
22+
23+
# print('gmiddle', gmiddle)
24+
# global_lons = global_lons-gmiddle+middle
25+
global_lons[global_lons >= 180] -= 360
26+
1727
plt.figure(figsize=(10, 5))
18-
plt.scatter(global_lons, global_lats, s=0.01, marker="o", c="r")
28+
plt.scatter(global_lons, global_lats, s=s, marker="o", c="r")
1929
plt.savefig(path + "-global.png")
2030

2131
plt.figure(figsize=(10, 5))
22-
plt.scatter(global_lons[mask], global_lats[mask], s=0.1, c="k")
32+
plt.scatter(global_lons[mask], global_lats[mask], s=s, c="k")
2333
plt.savefig(path + "-cutout.png")
2434

2535
plt.figure(figsize=(10, 5))
26-
plt.scatter(lons, lats, s=0.01)
36+
plt.scatter(lons, lats, s=s)
2737
plt.savefig(path + "-lam.png")
2838
# plt.scatter(lons, lats, s=0.01)
2939

40+
plt.figure(figsize=(10, 5))
41+
plt.scatter(global_lons[mask], global_lats[mask], s=s, c="r")
42+
plt.scatter(lons, lats, s=s)
43+
plt.savefig(path + "-both.png")
44+
# plt.scatter(lons, lats, s=0.01)
45+
3046

3147
def latlon_to_xyz(lat, lon, radius=1.0):
3248
# https://en.wikipedia.org/wiki/Geographic_coordinate_conversion#From_geodetic_to_ECEF_coordinates
@@ -64,27 +80,27 @@ def intersect(self, ray_origin, ray_direction):
6480
a = np.dot(self.v1 - self.v0, h)
6581

6682
if -epsilon < a < epsilon:
67-
return None
83+
return False
6884

6985
f = 1.0 / a
7086
s = ray_origin - self.v0
7187
u = f * np.dot(s, h)
7288

7389
if u < 0.0 or u > 1.0:
74-
return None
90+
return False
7591

7692
q = np.cross(s, self.v1 - self.v0)
7793
v = f * np.dot(ray_direction, q)
7894

7995
if v < 0.0 or u + v > 1.0:
80-
return None
96+
return False
8197

8298
t = f * np.dot(self.v2 - self.v0, q)
8399

84100
if t > epsilon:
85-
return t
101+
return True
86102

87-
return None
103+
return False
88104

89105

90106
def cropping_mask(lats, lons, north, west, south, east):
@@ -106,7 +122,7 @@ def cutout_mask(
106122
global_lats,
107123
global_lons,
108124
cropping_distance=2.0,
109-
min_distance=0.0,
125+
min_distance_km=0.0,
110126
plot=None,
111127
):
112128
"""
@@ -115,6 +131,8 @@ def cutout_mask(
115131

116132
# TODO: transform min_distance from lat/lon to xyz
117133

134+
min_distance = min_distance_km / 6371.0
135+
118136
assert global_lats.ndim == 1
119137
assert global_lons.ndim == 1
120138
assert lats.ndim == 1
@@ -140,31 +158,43 @@ def cutout_mask(
140158
)
141159

142160
# return mask
161+
# mask = np.array([True] * len(global_lats), dtype=bool)
143162
global_lats_masked = global_lats[mask]
144163
global_lons_masked = global_lons[mask]
145164

146165
global_xyx = latlon_to_xyz(global_lats_masked, global_lons_masked)
147166
global_points = np.array(global_xyx).transpose()
148167

149168
xyx = latlon_to_xyz(lats, lons)
150-
points = np.array(xyx).transpose()
169+
lam_points = np.array(xyx).transpose()
151170

152171
# Use a KDTree to find the nearest points
153-
kdtree = KDTree(points)
172+
kdtree = KDTree(lam_points)
154173
distances, indices = kdtree.query(global_points, k=3)
155174

156175
zero = np.array([0.0, 0.0, 0.0])
157176
ok = []
158177
for i, (global_point, distance, index) in enumerate(zip(global_points, distances, indices)):
159-
t = Triangle3D(points[index[0]], points[index[1]], points[index[2]])
160-
distance = np.min(distance)
178+
t = Triangle3D(lam_points[index[0]], lam_points[index[1]], lam_points[index[2]])
179+
# distance = np.min(distance)
161180
# The point is inside the triangle if the intersection with the ray
162181
# from the point to the center of the Earth is not None
163182
# (the direction of the ray is not important)
164-
ok.append(
165-
(t.intersect(zero, global_point) or t.intersect(global_point, zero))
166-
# and (distance >= min_distance)
167-
)
183+
184+
intersect = t.intersect(zero, global_point) or t.intersect(global_point, zero)
185+
close = np.min(distance) <= min_distance
186+
187+
if not intersect and False:
188+
189+
if 0 <= global_lons_masked[i] <= 30:
190+
if 55 <= global_lats_masked[i] <= 70:
191+
print(global_lats_masked[i], global_lons_masked[i], distance, intersect, close)
192+
print(lats[index[0]], lons[index[0]])
193+
print(lats[index[1]], lons[index[1]])
194+
print(lats[index[2]], lons[index[2]])
195+
assert False
196+
197+
ok.append(intersect and not close)
168198

169199
j = 0
170200
ok = np.array(ok)

0 commit comments

Comments
 (0)