"""
A module containing classes and methods for gridded data
Author: Andre
Adapted from load_gridded_data, LMR_prior, LMR_calibrate
"""
from abc import abstractmethod, ABCMeta
from netCDF4 import Dataset, num2date
from datetime import datetime, timedelta
from collections import OrderedDict
from os.path import join
from numcodecs import Blosc, Pickle
import numpy as np
import warnings
import os
import random
import zarr
import pylim.DataTools as DT
from pylim.Stats import detrend_data
import LMR_config
from LMR_utils import (regrid_sphere_gridded_object, var_to_hdf5_carray,
empty_hdf5_carray, regrid_esmpy_grid_object,
get_chunk_shape, ReqDataFractionMismatchError)
from LMR_utils import fix_lon, regular_cov_infl
# Constant definitions
_LAT = 'lat'
_LON = 'lon'
_LEV = 'lev'
_TIME = 'time'
_DEFAULT_DIM_ORDER = [_TIME, _LEV, _LAT, _LON]
_ALT_DIMENSION_DEFS = {'latitude': _LAT,
'longitude': _LON,
'plev': _LEV}
_BYPASS_DIMENSION_DEFS = {'j': _LAT,
'i': _LON}
_ftypes = LMR_config.Constants.data['file_types']
def _cnvt_to_float64(num):
if num is None:
return None
else:
num_as_array = np.array(num)
return num_as_array.astype(np.float64)
class GriddedVariable(object):
"""
Object for holding and manipulating gridded data of a single variable.
"""
PRE_PROCESSED_FILETAG = '.pre_{}.zarr'
PRE_PROCESSED_FILEDIR = 'pre_proc_files'
PRE_PROCESSED_OBJ_NODENAME = 'grid_object'
PRE_PROCESSED_DATA_NODENAME = 'grid_object_data'
def __init__(self, name, dims_ordered, data, time=None,
lev=None, lat=None, lon=None, fill_val=None,
sampled=None, avg_interval=None, regrid_method=None,
regrid_grid=None, esmpy_interp=None, lat_grid=None,
lon_grid=None, climo=None, rotated_pole=False,
cell_area=None, ref_period=None, data_req_frac=None):
"""
Parameters
----------
name: str
Name of gridded variable.
dims_ordered: list of str
Ordered list of the gridded variable dimensions. Dimension names
should match the constant definitions at the beginning of this
module.
data: ndarray
Gridded variable data
time: ndarray, optional
Array of time dimension values
lev: ndarray, optional
Array of level dimension values
lat: ndarray, optional
Array of latitude dimension values
lon: ndarray, optional
Array of longitude dimension values
fill_val: float, optional
Fill value indicating missing data
sampled: array of int, optional
List of indices indicating the sample from the original data
used to create this object.
avg_interval: str, optional
Key indicating the averaging interval of the data in this
object. Should match a definition in constants.yml.
regrid_method: str, optional
Key indicating which regridding method was used to create
data in this object.
regrid_grid: str, optional
Key indicating which grid the data is on
esmpy_interp: str, optional
Key indicating the interpolation method used when ESMPy was used to
regrid the data
lat_grid: ndarray, optional
Latitude values for the grid. Necessary for non-regular grids for
certain operations.
lon_grid: ndarray, optional
Longitude values for the grid. Necessary for non-regular grids for
certain operations.
climo: ndarray, optional
Climatology used to center the data
ref_period: tuple of int, optional
Start and end of time interval used to calculate the climatology
and center the data
rotated_pole: bool, optional
Indication of whether or not the data is on a rotated pole grid
cell_area: ndarray, optional
Grid array describing the area of each grid cell.
data_req_frac: float, optional
The fraction of required data over the average interval to qualify
as a valid data point. Used as a reference for pre-averaged data
loading to check equivalence with what is requested.
Attributes
----------
ndim: int
Number of data dimensions
nsamples: int
Number of samples in the data
space_shp: tuple of int
Shape of spatial dimensions of the data
type: str
Variable definition based on dimensions. E.g., timeseries,
2D:horizontal, 2D:vertical/meridional
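Examples
--------
A minimal construction sketch; all values below are illustrative
placeholders rather than data from any real dataset:
>>> import numpy as np
>>> time = np.arange(1850, 1853)
>>> lat = np.linspace(-90., 90., 4)
>>> lon = np.arange(0., 360., 45.)
>>> data = np.zeros((3, 4, 8))
>>> var = GriddedVariable('tas', ['time', 'lat', 'lon'], data,
...                       time=time, lat=lat, lon=lon)
>>> var.type
'2D:horizontal'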
"""
self.name = name
self.dim_order = dims_ordered
self.ndim = len(dims_ordered)
self.data = data
self.climo = climo
self.time = time
self.lev = _cnvt_to_float64(lev)
self.lat = _cnvt_to_float64(lat)
lon_adjusted = fix_lon(lon)
self.lon = _cnvt_to_float64(lon_adjusted)
self.avg_interval = avg_interval
self.regrid_method = regrid_method
self.regrid_grid = regrid_grid
self.esmpy_interp = esmpy_interp
self.lat_grid = _cnvt_to_float64(lat_grid)
self.lon_grid = _cnvt_to_float64(lon_grid)
self.rotated_pole = rotated_pole
self.cell_area = cell_area
self.ref_period = ref_period
self._fill_val = fill_val
self._idx_used_for_sample = sampled
self._data_req_frac = data_req_frac
self._dim_coord_map = {_TIME: self.time,
_LEV: self.lev,
_LAT: self.lat,
_LON: self.lon}
# TODO: Robert's code flips latitudes so it monotonically increases
# Is it necessary here?
# Make sure the number of dimensions specified matches the data
if self.ndim != len(self.data.shape):
raise ValueError('Number of dimensions given do not match data'
' dimensions.')
# Make sure each dimension has consistent number of values as data
for i, dim in enumerate(self.dim_order):
if self._dim_coord_map[dim] is None:
raise ValueError('Dimension specified but no values provided '
'for initialization')
if data.shape[i] != len(self._dim_coord_map[dim]):
raise ValueError('Dimension values provided do not match in '
'length with dimension axis of data')
# Determine sampling dimension size if any
if time is not None:
self.nsamples = len(self.time)
else:
self.nsamples = 1
self.dim_order = list(self.dim_order)
self.dim_order.insert(0, _TIME)
self.data = self.data.reshape(1, *self.data.shape)
self._space_dims = [dim for dim in dims_ordered if dim != _TIME]
# Spatial shape is left as a list for easy shape combining w/ sampling
self.space_shp = [len(self._dim_coord_map[dim])
for dim in self._space_dims]
if len(self._space_dims) > 2:
raise NotImplementedError('Class cannot handle >2D data yet!'
' spatial shape = '
'{}'.format(self.space_shp))
# Determine the type of field for this gridded variable
if not self._space_dims:
self.type = '0D:time_series'
self.space_shp = [1]
self.data = self.data.reshape(self.nsamples, 1)
elif len(self.space_shp) == 1 and _LAT in self._space_dims:
self.type = '1D:meridional'
elif _LAT in self._space_dims and _LON in self._space_dims:
self.type = '2D:horizontal'
elif _LAT in self._space_dims and _LEV in self._space_dims:
self.type = '2D:meridional_vertical'
else:
raise NotImplementedError('Unrecognized dimension combination. '
'This type of variable has not been '
'implemented yet.')
def save(self, filename):
"""
Save the gridded data object to file. Creates a zarr group store
for the data and the pickled gridded variable object.
Parameters
----------
filename: str
Absolute path to the output file.
Returns
-------
None
Notes
-----
If the file exists, it is opened in append mode. This can probably
result in very large files if resaving the same variable multiple
times.
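Examples
--------
The group path inside the zarr store is just the non-None metadata
pieces joined in order; a sketch with placeholder keys (POSIX
separators assumed):
>>> from os.path import join
>>> join('annual', 'esmpy', '2x2', 'bilinear')
'annual/esmpy/2x2/bilinear'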
"""
avg_interval = self.avg_interval
regrid_method = self.regrid_method
regrid_grid = self.regrid_grid
esmpy_interp = self.esmpy_interp
# create the path to save data within the zarr store
path_pieces = [avg_interval, regrid_method, regrid_grid, esmpy_interp]
path_pieces = [str(piece) for piece in path_pieces if piece is not None]
data_path = join(*path_pieces)
print('Pre-processed data filename: {}'.format(filename))
print('Zarr output group storage path: {}'.format(data_path))
obj_node_name = self.PRE_PROCESSED_OBJ_NODENAME
data_node_name = self.PRE_PROCESSED_DATA_NODENAME
compressor = Blosc(cname='zstd', clevel=4, shuffle=Blosc.BITSHUFFLE)
root = zarr.open(filename, mode='a')
# Create save file, overwrites the data group at the same path
group_node = root.create_group(data_path, overwrite=True)
obj_node = group_node.empty(obj_node_name, shape=1, dtype=object,
object_codec=Pickle())
data_chunks = get_chunk_shape(self.data.shape, self.data.dtype, 5)
data_node = group_node.empty_like(data_node_name, self.data,
chunks=data_chunks,
compressor=compressor,
dtype=self.data.dtype)
data_node[:] = self.data
print(data_node.info)
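# Pickle the object itself without its (potentially very large) data
# payload; the array already lives in the compressed data node above.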
tmp_dat = self.data
del self.data
obj_node[0] = self
self.data = tmp_dat
def print_data_stats(self):
"""
Print stats of the data contained in the object.
**Don't call on large data! You may run out of memory**
Returns
-------
None
"""
print('{}: Global: mean={:1.3e}, '
'std-dev={:1.3e}'.format(self.name, np.nanmean(self.data),
np.nanstd(self.data)))
def regrid(self, regrid_method, regrid_grid=None, grid_def=None,
interp_method=None):
"""
Regrid data in gridded object. Only works for 2D:horizontal data
Parameters
----------
regrid_method: str
Key indicating regridding package to use. Allowed: 'simple',
'spherical_harmonics', and 'esmpy'.
regrid_grid: str, optional
Key indicating the destination grid for spherical harmonics.
grid_def: dict, optional
Grid definition dictionary from grid_def.yml for ESMPy regridding
interp_method: str, optional
Interpolation method to use in ESMPy. Allowed: bilinear, patch
Returns
-------
GriddedVariable
New gridded variable object with regridded data.
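Examples
--------
A sketch of the ESMPy branch; the grid_def keys match the lookups
below and the numbers are placeholders:
>>> regridded = var.regrid('esmpy',
...                        grid_def={'target_nlat': 45,
...                                  'target_nlon': 90},
...                        interp_method='bilinear')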
"""
assert self.type == '2D:horizontal'
class_obj = type(self)
if regrid_method == 'simple':
raise NotImplementedError('Have not fixed simple regridding yet -AP')
elif regrid_method == 'spherical_harmonics':
[regrid_data,
new_lat,
new_lon,
climo] = regrid_sphere_gridded_object(self, regrid_grid)
elif regrid_method == 'esmpy':
target_nlat = grid_def['target_nlat']
target_nlon = grid_def['target_nlon']
[regrid_data,
new_lat,
new_lon,
climo] = regrid_esmpy_grid_object(target_nlat, target_nlon,
self,
interp_method=interp_method)
else:
raise ValueError('Unrecognized regridding method: {}'.format(regrid_method))
# Rotated pole omitted for regridded data
# TODO: Figure out how to transfer cell area
return class_obj(self.name, self.dim_order, regrid_data,
time=self.time,
lev=self.lev,
lat=new_lat[:, 0],
lon=new_lon[0],
fill_val=self._fill_val,
sampled=self._idx_used_for_sample,
avg_interval=self.avg_interval,
regrid_method=regrid_method,
regrid_grid=regrid_grid,
esmpy_interp=interp_method,
lat_grid=new_lat,
lon_grid=new_lon,
climo=climo,
data_req_frac=self._data_req_frac,
ref_period=self.ref_period)
def fill_val_to_nan(self):
"""
Convert fill value to NaN
Returns
-------
None
"""
convert_to_masked_array = False
# Steps through the data in chunks to handle instances where data is
# very large. Slower, but doesn't go into swap ;)
step = 10
for i in np.arange(0, len(self.data), step=step):
tmp_data = self.data[i:i+step]
mask = self._check_fill_val_mach_eps(tmp_data, self._fill_val)
# Determine if invalid data and set flag to convert to
# np.ma.MaskedArray
if np.any(mask):
if not convert_to_masked_array:
convert_to_masked_array = True
tmp_data[mask] = np.nan
self.data[i:i+step] = tmp_data
if convert_to_masked_array:
self.data = np.ma.masked_invalid(self.data)
@staticmethod
def _check_fill_val_mach_eps(data, fill_val):
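# Match values within |fill_val * eps| of the fill value instead of
# testing exact equality, so fill values still match after dtype
# round-trips; e.g. a fill of 1e20 matches anything in
# [1e20 - 1e20*eps, 1e20 + 1e20*eps].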
eps = np.finfo(data.dtype).eps
delta = abs(fill_val*eps)
fill_upper = fill_val + delta
fill_lower = fill_val - delta
match = (data <= fill_upper) & (data >= fill_lower)
return match
def nan_to_fill_val(self):
"""
Convert NaN to fill value.
Returns
-------
None
"""
if np.ma.is_masked(self.data):
self.data = self.data.filled(fill_value=self._fill_val)
else:
step = 10
for i in np.arange(0, len(self.data), step=step):
tmp_dat = self.data[i:i+step]
tmp_dat[np.isnan(tmp_dat)] = self._fill_val
self.data[i:i+step] = tmp_dat
def flattened_spatial(self):
"""
Get a flattened spatial field representation of the data. Preserves
sampling dimension.
Returns
-------
flat_data: ndarray
Flattened view of the data array
flat_coords: dict{str: ndarray}
Flattened full coordinate grids for each spatial dimension.
Shape will match flat_data shape.
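Examples
--------
Continuing the sketch from the class-level Examples (a (3, 4, 8)
field flattens to 3 samples by 32 gridpoints):
>>> flat_data, flat_coords = var.flattened_spatial()
>>> flat_data.shape
(3, 32)
>>> sorted(flat_coords.keys())
['lat', 'lon']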
"""
flat_data = self.data.reshape(len(self.time),
np.prod(self.space_shp))
# Get dimensions of data
coords = [self._dim_coord_map[key] for key in self._space_dims]
grids = np.meshgrid(*coords, indexing='ij')
flat_coords = {dim: grid.flatten()
for dim, grid in zip(self._space_dims, grids)}
return flat_data, flat_coords
def random_sample(self, nens, seed=None, sample_omit_edge=False):
"""
Take a random sample along the sampling dimension of the data.
Parameters
----------
nens: int
Size of sample
seed: int, optional
Seed for the random number generator
sample_omit_edge: bool, optional
Remove the first and last element from the pool of possible samples
Returns
-------
GriddedVariable
New gridded variable object with the sampled data.
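Examples
--------
A reproducible two-member draw, continuing the class-level sketch:
>>> ens = var.random_sample(2, seed=42)
>>> ens.is_sampled()
True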
"""
if sample_omit_edge:
sample_range = list(range(1, self.data.shape[0] - 1))
else:
sample_range = list(range(self.data.shape[0]))
random.seed(seed)
sample = random.sample(sample_range, nens)
return self.sample_from_idx(sample)
def _get_yr_indices(self, sample_years):
years = [time.year for time in self.time]
if not isinstance(sample_years[0], int):
sample_years = [time.year for time in sample_years]
sample_idxs = []
for sample_yr in sample_years:
curr_idx = years.index(sample_yr)
if years.count(sample_yr) > 1:
raise ValueError('Sampling by year does not work on '
'sub-annual data.')
sample_idxs.append(curr_idx)
return sample_idxs
def sample_from_yr(self, sample_years):
yr_idxs = self._get_yr_indices(sample_years)
return self.sample_from_idx(yr_idxs)
def sample_from_idx(self, sample_idxs):
"""
Take a specified sample along the sampling dimension of the data.
Parameters
----------
sample_idxs: list[int]
A list of indices to take along the sampling dimension of the data
Returns
-------
GriddedVariable
New gridded variable object with the sampled data
"""
cls = type(self)
nsamples = len(sample_idxs)
if nsamples == self.data.shape[0]:
print('Size of sample and total number of available members are '
'equivalent. No resampling performed...')
return self
print('Selecting {} ensemble members'.format(nsamples))
time_sample = self.time[sample_idxs]
data_sample = np.zeros([nsamples] + list(self.data.shape[1:]))
for k, idx in enumerate(sample_idxs):
data_sample[k] = self.data[idx]
# Account for timeseries trailing singleton dimension
data_sample = np.squeeze(data_sample)
return cls(self.name, self.dim_order, data_sample,
time=time_sample,
lev=self.lev,
lat=self.lat,
lon=self.lon,
fill_val=self._fill_val,
avg_interval=self.avg_interval,
rotated_pole=self.rotated_pole,
lat_grid=self.lat_grid,
lon_grid=self.lon_grid,
regrid_method=self.regrid_method,
regrid_grid=self.regrid_grid,
esmpy_interp=self.esmpy_interp,
climo=self.climo,
sampled=sample_idxs)
def is_sampled(self):
"""
Return whether data in the current object is from a sampling
operation.
Returns
-------
bool
"""
return self._idx_used_for_sample is not None
def get_sample_years(self):
return self.time
@staticmethod
def _get_time_range_idx(time, start, end):
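# Slice-ready indices: e.g. for time = [1850, ..., 2000] with
# start=1900, end=1950 this returns (50, 101), the index of the
# first in-range year and one past the last.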
try:
# Handle datetime
range_mask = [start <= dt.year <= end for dt in time]
except AttributeError:
# Assume array of year integers
time = np.array(time)
range_mask = (time >= start) & (time <= end)
for i, mask_val in enumerate(range_mask):
if mask_val:
begin_idx = i
break
else:
raise ValueError('No values are within the specified time range:'
'{}-{}'.format(start, end))
for i in range(len(range_mask))[::-1]:
if range_mask[i]:
end_idx = i
break
# Return one past the last in-range index so the inclusive end
# year survives slicing at the call sites.
return begin_idx, end_idx + 1
def reduce_to_time_period(self, time_range):
"""
Reduces available data to be within the specified time range.
Parameters
----------
time_range: tuple of int
Time range as a tuple of length 2 with the start and end year
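Examples
--------
Restrict the object to 1850-2000 in place (sketch):
>>> var.reduce_to_time_period((1850, 2000))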
"""
time = self.time
range_idx = self._get_time_range_idx(time, *time_range)
r_start, r_end = range_idx
self.time = self.time[r_start:r_end]
self.data = self.data[r_start:r_end]
self.nsamples = len(self.time)
def convert_to_anomaly(self, climo=None, ref_period=None):
"""
Center data by removing climatological mean.
Parameters
----------
climo: ndarray, optional
Climatological reference to center data to. If not provided,
the climatology is determined across the entire sampling
dimension.
ref_period: tuple of int, optional
Time range as a tuple of length 2 with the start and end year
Returns
-------
None
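Examples
--------
Center relative to a fixed reference period (sketch):
>>> var.convert_to_anomaly(ref_period=(1951, 1980))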
"""
if ref_period is not None:
print(f'Calculating anomaly relative to years: {ref_period}')
center_idx = self._get_time_range_idx(self.time, *ref_period)
c_start, c_end = center_idx
self.climo = np.nanmean(self.data[c_start:c_end], axis=0,
keepdims=True)
self.ref_period = ref_period
elif climo is None:
print('Calculating gridpoint anomalies over entire dataset.')
self.climo = self.data[:].mean(axis=0, keepdims=True)
else:
self.climo = climo
self.data = self.data - self.climo
def convert_to_standard(self):
"""
Add back climatology to centered data.
Returns
-------
None
"""
print('Adding temporal mean to every gridpoint...')
if self.climo is None:
raise ValueError('Cannot convert to standard state; data is not '
'an anomaly to start.')
self.data = self.data + self.climo
self.climo = None
def forecast_var_to_pylim_dataobj(self):
"""
Create a pyLIM data object for use in LIM forecasting.
Returns
-------
pylim.DataTools.BaseDataObject
Data object for a LIM that has the same dimensions.
"""
print('Converting ForecastVariable to pylim.DataObject: '
'{}'.format(self.name))
BDO = DT.BaseDataObject
key_map = {_TIME: BDO.TIME,
_LEV: BDO.LEVEL,
_LAT: BDO.LAT,
_LON: BDO.LON}
dim_coords = {key_map[dim]: (i, getattr(self, dim)[:])
for i, dim in enumerate(self.dim_order)}
coord_grids = {}
if self.lat_grid is not None:
coord_grids[BDO.LAT] = self.lat_grid
if self.lon_grid is not None:
coord_grids[BDO.LON] = self.lon_grid
if not coord_grids:
coord_grids = None
new_dobj = DT.BaseDataObject(self.data,
dim_coords=dim_coords,
coord_grids=coord_grids,
force_flat=True,
fill_value=self._fill_val,
cell_area=self.cell_area)
return new_dobj
@classmethod
def load(cls, gridded_config, varname=None,
anomaly=False, sample=None, sample_omit_edge=False, **kwargs):
"""
Load a single variable as a GriddedVariable
Parameters
----------
gridded_config: LMR_config.prior
Configuration definition object for prior variable
varname: str, optional
The name of the variable to load. If None and there are multiple
variables an error is raised.
anomaly: bool, optional
Whether to convert data to an anomaly format.
sample: list of times, optional
List of times to take as a sample. Only works for annual or
longer time deltas due to reliance on the year.
sample_omit_edge: bool, optional
If taking a random sample of the data, omit the first and last
times from the sample pool; this is safer when averaging
intervals remove edge years.
nens: int, optional
The number of ensemble members to randomly sample from the data.
seed: int, optional
Seed for the random number generator. Only used when nens is
specified.
detrend: bool, optional
Flag specifying whether or not to detrend the data along the
sampling dimension.
Returns
-------
GriddedVariable
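Examples
--------
A hedged sketch; `cfg` stands in for a fully built LMR configuration
object and the variable name is illustrative:
>>> cfg = LMR_config.Config()  # hypothetical instantiation
>>> tas = GriddedVariable.load(cfg.prior, varname='tas_sfc_Amon',
...                            anomaly=True)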
"""
file_dir = gridded_config.datadir
file_name = gridded_config.datafile
file_type = gridded_config.dataformat
datainfo = gridded_config.datainfo
avg_interval = gridded_config.avg_interval
avg_interval_kwargs = gridded_config.avg_interval_kwargs
regrid_config = gridded_config.regrid_cfg
save = regrid_config.save_pre_avg_file
ignore_pre_avg = regrid_config.ignore_pre_avg_file
regrid_method = regrid_config.regrid_method
regrid_grid = regrid_config.regrid_grid
if varname is None:
if datainfo['multiple_vars']:
raise ValueError('Selected dataset has multiple available '
'variables and none were specified to load. '
'Please input a specific variable name in the '
'load function.')
else:
varname = datainfo['available_vars'][0]
if isinstance(regrid_config.esmpy_interp_method, dict):
interp_method = regrid_config.esmpy_interp_method[varname]
else:
interp_method = regrid_config.esmpy_interp_method
esmpy_kwargs = {'grid_def': regrid_config.esmpy_grid_def,
'interp_method': interp_method}
unique_cfg_kwargs = cls._load_unique_cfg_kwargs(gridded_config)
for key, arg in kwargs.items():
if key in unique_cfg_kwargs:
unique_cfg_kwargs[key] = arg
else:
raise KeyError('Unrecognized keyword argument provided '
'to load function: {}'.format(key))
if 'rotated_pole' in datainfo:
rotated_pole = varname in datainfo['rotated_pole']
else:
rotated_pole = False
if datainfo['cell_area'] is not None:
for realm_key, realm_val in datainfo['var_realm_def'].items():
if realm_key in varname:
realm = realm_val
break
else:
raise ValueError('Realm specification in datasets.yml could '
'not be found in variable name.')
cella_template = datainfo['cell_area_template']
cella_realm_def = datainfo['cell_area_realmvar_def'][realm]
cell_area_file = datainfo['cell_area']
cell_area_file = cell_area_file.replace(cella_template,
cella_realm_def)
else:
cell_area_file = None
if datainfo['template'] is not None:
file_name = file_name.replace(datainfo['template'], varname)
varname = varname.split('_')[0]
return cls._main_load_helper(file_dir, file_name, varname, file_type,
sample=sample,
sample_omit_edge=sample_omit_edge,
save=save,
ignore_pre_avg=ignore_pre_avg,
avg_interval=avg_interval,
avg_interval_kwargs=avg_interval_kwargs,
regrid_method=regrid_method,
regrid_grid=regrid_grid,
esmpy_kwargs=esmpy_kwargs,
rotated_pole=rotated_pole,
anomaly=anomaly,
cell_area_file=cell_area_file,
**unique_cfg_kwargs)
@staticmethod
def _load_unique_cfg_kwargs(config):
"""
Grab configuration keyword arguments that are specific to the gridded
class.
Parameters
----------
config: LMR_config.prior
Configuration object for the prior class.
Returns
-------
cfg_kwargs: dict
Special keyword arguments for the current gridded class.
"""
return {}
@classmethod
def _main_load_helper(cls, file_dir, file_name, varname, file_type,
nens=None, seed=None, sample=None,
sample_omit_edge=False,
avg_interval=None, avg_interval_kwargs=None,
regrid_method=None, regrid_grid=None,
esmpy_kwargs=None,
data_req_frac=1.0, save=True,
ignore_pre_avg=False, rotated_pole=False,
anomaly=True, detrend=False,
cell_area_file=None,
anom_reference_period=None, calib_period=None):
"""
Main helper for deciding which loading function to use based on the
data. Resampling and regridding operations are decided in this
method.
"""
# Get correct loader class for specified filetype.
try:
ftype_loader = cls.get_loader_for_filetype(file_type)
except KeyError:
raise TypeError('Specified file type not supported yet.')
# Try to load pre-averaged data if it exists. Otherwise, use the
# specific loader for the filetype
try:
if ignore_pre_avg:
raise IOError('ignore_pre_avg_file is set to True.')
interp_method = esmpy_kwargs['interp_method']
var_obj = cls._load_pre_avg_obj(file_dir, file_name, varname,
avg_interval=avg_interval,
regrid_method=regrid_method,
regrid_grid=regrid_grid,
anomaly=anomaly,
nens=nens,
sample=sample,
sample_omit_edge=sample_omit_edge,
seed=seed,
interp_method=interp_method,
anom_ref=anom_reference_period,
calib_period=calib_period,
data_req_frac=data_req_frac,
detrend=detrend)
except (IOError, KeyError, ReqDataFractionMismatchError) as e:
print(e)
print('No equivalent pre-averaged file found ({}) or '
'ignore specified ... '.format(varname))
var_obj = ftype_loader(file_dir, file_name, varname, save=save,
data_req_frac=data_req_frac,
avg_interval=avg_interval,
avg_interval_kwargs=avg_interval_kwargs,
rotated_pole=rotated_pole,
anomaly=anomaly,
detrend=detrend,
cell_area_file=cell_area_file,
anom_ref=anom_reference_period,
calib_period=calib_period)
print('Loaded from file: {}/{}'.format(file_dir, file_name))
# # TODO: This may be unnecessary due to usage in loaders
# var_obj.fill_val_to_nan()
# Do regridding and save if specified
if regrid_method is not None and var_obj.regrid_method is None:
var_obj = var_obj.regrid(regrid_method=regrid_method,
regrid_grid=regrid_grid,
**esmpy_kwargs)
var_obj.print_data_stats()
if save:
pre_tag = cls.PRE_PROCESSED_FILETAG.format(varname)
pre_dir = cls.PRE_PROCESSED_FILEDIR
path = join(file_dir, pre_dir, file_name + pre_tag)
var_obj.save(path)
# Sample the data
if not var_obj.is_sampled() and (nens is not None or sample is not None):
if sample is not None:
var_obj = var_obj.sample_from_yr(sample)
else:
var_obj = var_obj.random_sample(nens, seed=seed,
sample_omit_edge=sample_omit_edge)
return var_obj
@classmethod
def get_loader_for_filetype(cls, file_type):
"""
Retrieve the correct function for loading specific filetypes
Parameters
----------
file_type: str
Key for the file type to get loader for.
Returns
-------
Method that will load data for the given file type
"""
ftype_map = {_ftypes['netcdf']: cls._load_from_netcdf}
return ftype_map[file_type]
@classmethod
def _load_pre_avg_obj(cls, dir_name, filename, varname, avg_interval=None,
regrid_method=None, regrid_grid=None,
anomaly=False, nens=None, sample=None,
sample_omit_edge=False, detrend=False,
seed=None, interp_method=None, anom_ref=None,
calib_period=None, data_req_frac=None):
"""
General structure for load pre-averaged:
1. Load data
a. If regrid is desired it searches for pre_avg regridded data
but if not found, then uses loaded data and regrids
2. Sample if desired
3. Return a gridded variable object.
"""
# Check if pre-processed averages file exists
pre_proc_tag = cls.PRE_PROCESSED_FILETAG.format(varname)
pre_filedir = cls.PRE_PROCESSED_FILEDIR
path = join(dir_name, pre_filedir, filename + pre_proc_tag)
# Look for pre_averaged_file
if not os.path.exists(path):
raise IOError('No pre-averaged file found for given specifications')
root = zarr.open(path, mode='r')
obj_node_name = cls.PRE_PROCESSED_OBJ_NODENAME
data_node_name = cls.PRE_PROCESSED_DATA_NODENAME
avg_int_group = root[avg_interval]
obj = avg_int_group[obj_node_name][0]
obj_data = avg_int_group[data_node_name]
print('Found node for avg_interval path: {}'.format(avg_interval))
if data_req_frac is not None and obj._data_req_frac != data_req_frac:
raise ReqDataFractionMismatchError(
'Requested minimum data fraction is not equivalent to '
'the pre-averaged object. obj{:1.2f} != req{:1.2f}'
''.format(obj._data_req_frac, data_req_frac)
)
do_sample = True
if regrid_method is not None:
regrid_path = [regrid_method, regrid_grid, interp_method]
regrid_path = [str(path_piece) for path_piece in regrid_path
if path_piece is not None]
regrid_obj_dir = join(avg_interval, *regrid_path)
try:
regrid_obj_grp = root[regrid_obj_dir]
regrid_obj = regrid_obj_grp[obj_node_name]
regrid_obj_data = regrid_obj_grp[data_node_name]
print('Found node for regridded data under path: '
'{}'.format(regrid_obj_dir))
obj = regrid_obj[0]
obj_data = regrid_obj_data
except KeyError:
# Do not sample, since regrid specified and might save
do_sample = False
obj_data = obj_data[:]
print('Regridded pre-processed grid object not found for '
'regridding: {}.'.format(regrid_obj_dir))
obj.data = obj_data
if (anomaly and obj.climo is None) or anom_ref is not None:
if obj.ref_period != anom_ref:
obj.convert_to_anomaly(ref_period=anom_ref)
if calib_period is not None: