Skip to content
This repository was archived by the owner on Jan 10, 2025. It is now read-only.

Commit ef6aab6

Browse files
committed
statistics
1 parent 0205451 commit ef6aab6

File tree

1 file changed

+52
-67
lines changed

1 file changed

+52
-67
lines changed

ecml_tools/create/zarr.py

Lines changed: 52 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,47 @@
1414
LOG = logging.getLogger(__name__)
1515

1616

17+
class StatisticsRegistry:
    """Directory-backed store of per-key statistics.

    Every ``registry[key] = data`` assignment pickles the tuple
    ``(key, data)`` into ``<dirname>/statistics/<key>.npz``, and
    :meth:`read_all` yields those tuples back.

    NOTE: despite the ``.npz`` suffix the files are plain pickles, not
    NumPy archives. Only read directories written by this class:
    unpickling untrusted files can execute arbitrary code.
    """

    name = "statistics"
    # names = ["mean", "stdev", "minimum", "maximum", "sums", "squares", "count"]
    # build_names = ["minimum", "maximum", "sums", "squares", "count"]

    def __init__(self, dirname, zarr_registry=False, overwrite=False):
        """Remember where the data lives; nothing is created on disk here.

        Args:
            dirname: Parent directory; data goes in ``<dirname>/statistics``.
            zarr_registry: Optional object exposing
                ``add_to_history(action, **kwargs)``; when truthy it is
                notified by :meth:`create`.
            overwrite: When exactly ``False``, writing an existing key fails.
        """
        self.dirname = dirname
        self.data_dirname = os.path.join(self.dirname, self.name)
        # Must be stored: create() reads it back.
        self.zarr_registry = zarr_registry
        self.overwrite = overwrite

    def create(self):
        """Create the (not yet existing) data directory.

        Raises:
            AssertionError: If the data directory already exists.
        """
        assert not os.path.exists(self.data_dirname), self.data_dirname
        os.makedirs(self.data_dirname, exist_ok=True)
        if self.zarr_registry:
            # The keyword name is dynamic, so it is passed as an unpacked dict.
            self.zarr_registry.add_to_history(
                f"{self.name}_registry_initialised",
                **{f"{self.name}_version": 2},
            )

    def delete(self):
        """Remove the data directory and everything stored in it."""
        import shutil

        shutil.rmtree(self.data_dirname)

    def __setitem__(self, key, data):
        """Pickle ``(key, data)`` into the data directory.

        Raises:
            AssertionError: If ``overwrite`` is ``False`` and the file for
                ``key`` already exists.
        """
        # Write where read_all() looks, i.e. in data_dirname, not dirname.
        path = os.path.join(self.data_dirname, f"{key}.npz")
        if self.overwrite is False:
            assert not os.path.exists(path), f"{path} already exists"
        LOG.info(f"Writing {self.name} for {key}")
        with open(path, "wb") as f:
            pickle.dump((key, data), f)
        LOG.info(f"Written {self.name} data for {key} in {path}")

    def read_all(self, expected_lenghts=None):
        """Yield every stored ``(key, data)`` pair, in sorted file-name order.

        Args:
            expected_lenghts: Currently unused; the (misspelled) name is kept
                so existing keyword callers keep working.

        Raises:
            AssertionError: If two files carry the same key.
        """
        # glob order is filesystem dependent; sort for reproducible iteration.
        files = sorted(glob.glob(self.data_dirname + "/*.npz"))
        LOG.info(f"Reading {self.name} data, found {len(files)} for {self.name} in {self.dirname}")
        seen = set()
        for filename in files:
            with open(filename, "rb") as f:
                # Trusted input only: these are pickles written by __setitem__.
                key, data = pickle.load(f)
            assert key not in seen, f"Duplicate key {key}"
            seen.add(key)
            yield key, data
56+
57+
1758
def add_zarr_dataset(
1859
*,
1960
name,
@@ -47,21 +88,23 @@ def add_zarr_dataset(
4788
return a
4889

4990

50-
class ZarrRegistry:
51-
synchronizer_name = None # to be defined in subclasses
52-
53-
def __init__(self, path):
54-
assert self.synchronizer_name is not None, self.synchronizer_name
91+
class ZarrBuiltRegistry:
92+
name_lengths = "lengths"
93+
name_flags = "flags"
94+
lengths = None
95+
flags = None
96+
z = None
5597

98+
def __init__(self, path, synchronizer_path=None):
5699
import zarr
57100

58101
assert isinstance(path, str), path
59102
self.zarr_path = path
60-
self.synchronizer = zarr.ProcessSynchronizer(self._synchronizer_path)
61103

62-
@property
63-
def _synchronizer_path(self):
64-
return self.zarr_path + "-" + self.synchronizer_name + ".sync"
104+
if synchronizer_path is None:
105+
synchronizer_path = self.zarr_path + ".sync"
106+
self.synchronizer_path = synchronizer_path
107+
self.synchronizer = zarr.ProcessSynchronizer(self.synchronizer_path)
65108

66109
def _open_write(self):
67110
import zarr
@@ -94,64 +137,6 @@ def add_to_history(self, action, **kwargs):
94137
z.attrs["history"] = history
95138

96139

97-
class ZarrStatisticsRegistry(ZarrRegistry):
    """Registry that keeps per-key statistics arrays inside the zarr store.

    ``create`` allocates one dataset per entry of ``build_names`` and
    ``__setitem__`` fills them in under the ``_build`` group.
    """

    names = [
        "mean",
        "stdev",
        "minimum",
        "maximum",
        "sums",
        "squares",
        "count",
    ]
    build_names = [
        "minimum",
        "maximum",
        "sums",
        "squares",
        "count",
    ]
    synchronizer_name = "statistics"

    def __init__(self, path):
        """Delegate entirely to the base registry constructor."""
        super().__init__(path)

    def create(self):
        """Allocate one dataset per build statistic, then log it in the history.

        The datasets are two-dimensional, sized from the first two
        dimensions of the ``data`` array.
        """
        store = self._open_read()
        data_shape = store["data"].shape
        shape = (data_shape[0], data_shape[1])

        for stat_name in self.build_names:
            # "count" is an integer counter starting at 0; every other
            # statistic is a float array pre-filled with NaN.
            if name_is_count := (stat_name == "count"):
                fill_value, dtype = 0, np.int64
            else:
                fill_value, dtype = np.nan, np.float64
            self.new_dataset(
                name=stat_name, shape=shape, fill_value=fill_value, dtype=dtype
            )

        self.add_to_history("statistics_initialised")

    def __setitem__(self, key, stats):
        """Write ``stats[name]`` at ``key`` for every build statistic."""
        store = self._open_write()

        LOG.info(f"Writting stats for {key}")
        for stat_name in self.build_names:
            LOG.info(f"Writting stats for {key} {stat_name} {stats[stat_name].shape}")
            store["_build"][stat_name][key] = stats[stat_name]
        LOG.info(f"Written stats for {key}")

    def get_by_name(self, name):
        """Return the ``_build`` dataset called ``name``, opened read-only."""
        return self._open_read()["_build"][name]
145-
146-
147-
class ZarrBuiltRegistry(ZarrRegistry):
148-
name_lengths = "lengths"
149-
name_flags = "flags"
150-
lengths = None
151-
flags = None
152-
z = None
153-
synchronizer_name = "build"
154-
155140
def get_slice_for(self, i):
156141
lengths = self.get_lengths()
157142
assert i >= 0 and i < len(lengths)

0 commit comments

Comments
 (0)