Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix # #40

Closed
wants to merge 1 commit into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 28 additions & 10 deletions lompe/utils/save_load_utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
""" Saving functionality """
import numpy as np
from scipy.interpolate import interp2d
from lompe.model.data import Data
class ArgumentError(Exception):
pass
# dictionary for finding the functions associated with each save string
Expand Down Expand Up @@ -245,8 +246,8 @@ def interp(grid, conductance, lon, lat):
return griddata((grid.xi.flatten(), grid.eta.flatten()), conductance.flatten(), grid.projection.geo2cube(lon, lat))


class DummyData(object):
def __init__(self, coordinates, data_type):
class DummyData(Data):
def __init__(self, coordinates, datatype, label=False):
"""
Generating a dummy data object that can function similar to the lompe Data object using data location values
read from file. It has limited functionality and will raise an attribute
Expand All @@ -271,11 +272,14 @@ def __init__(self, coordinates, data_type):

"""
self.coords = coordinates
self.data_type= data_type

self.datatype= datatype
if not label:
label= datatype
self.label= label
# Altering the getattr function to allow message to be added
self.__getattributeoriginal__= self.__getattribute__
self.__getattribute__= self.__getattributeCheck__
self.values = ['No data loaded from saved model']
def __getattr__(self, attribute):
"""
Altering the getattr function
Expand Down Expand Up @@ -314,6 +318,8 @@ def __getattributeCheck__(self, attribute):
return self.__getattributeoriginal__(attribute)
except AttributeError:
raise AttributeError(f'Attribute does not exist likely because data has been loaded from file so only coordinates are provided. Attribute used: {attribute}')
def subset(self, *args):
raise AttributeError(f'Attribute does not exist likely because data has been loaded from file so only coordinates are provided. Attribute used: {attribute}')



Expand Down Expand Up @@ -343,15 +349,19 @@ def data_locs_to_dict(model):
coords=[]
dtypes= []
labels= []
length= []
for ds in model.data[dtype]:
coords.append([ds.coords[key] for key in ['lon', 'lat']])
dtypes.append(dtype)
labels.append(ds.label)
length.append(len(ds.coords['lon']))
if len(coords):
coords= np.array(coords)
# coords= np.array(coords)
# np.array([np.concatenate([c for c in coords[:,0]]), np.concatenate([c for c in coords[:,1]])])
coords = np.concatenate(coords, axis=1)
data_vars.update({dtype+'_input_locations': (['time'],
[np.array([np.concatenate([c for c in coords[:,0]]), np.concatenate([c for c in coords[:,1]])]).tobytes()],
{'labels': '\t'.join([f'{label} length: {len(c)}' for label, c in zip(labels, coords[:,0])])})})
[coords.tobytes()],
{'labels': '\t'.join([f'{label} length: {l}' for label, l in zip(labels, length)])})})


return data_vars
Expand Down Expand Up @@ -405,8 +415,16 @@ def load_model(file, time='first'):
if json.loads(ds.attrs['Data_locs']):
for dtype in ['efield', 'convection', 'ground_mag', 'space_mag_full', 'space_mag_fac', 'fac']:
if dtype+'_input_locations' in ds:
lon, lat= np.frombuffer(ds[dtype+'_input_locations'].values).reshape(2, -1)
model.data[dtype].append(DummyData({'lon':lon, 'lat':lat}, dtype))

length1= None
for label_info in ds[dtype+'_input_locations'].attrs['labels'].split('\t'):
label, length2= label_info.split(' length: ')
if not length1 is None:
length2= int(length2)+length1
else:
length2= int(length2)
lon, lat= np.frombuffer(ds[dtype+'_input_locations'].values).reshape(2, -1)[:, length1:length2]
model.data[dtype].append(DummyData({'lon':lon, 'lat':lat}, dtype, label=label))

return model
def load_grid(file):
Expand Down Expand Up @@ -437,7 +455,7 @@ def load_grid(file):
from functools import partial
import json

if isinstance(file, (str, np.str_)):
if isinstance(file, (str, np.str_)):
ds= xr.load_dataset(file)
elif isinstance(file, (xr.Dataset)):
ds= file
Expand Down
Loading