 LOG = logging.getLogger(__name__)


+class StatisticsRegistry:
+    name = "statistics"
+    # names = ["mean", "stdev", "minimum", "maximum", "sums", "squares", "count"]
+    # build_names = ["minimum", "maximum", "sums", "squares", "count"]
+
+    def __init__(self, dirname, zarr_registry=None, overwrite=False):
+        self.dirname = dirname
+        self.data_dirname = os.path.join(self.dirname, self.name)
+        self.zarr_registry = zarr_registry
+        self.overwrite = overwrite
+
+    def create(self):
+        assert not os.path.exists(self.data_dirname), self.data_dirname
+        os.makedirs(self.data_dirname, exist_ok=True)
+        if self.zarr_registry:
+            self.zarr_registry.add_to_history(
+                f"{self.name}_registry_initialised", **{f"{self.name}_version": 2}
+            )
+
+    def delete(self):
+        import shutil
+
+        shutil.rmtree(self.data_dirname)
+
+    def __setitem__(self, key, data):
+        # Each entry is pickled to its own file, named after the key.
+        path = os.path.join(self.data_dirname, f"{key}.npz")
+        if not self.overwrite:
+            assert not os.path.exists(path), f"{path} already exists"
+        LOG.info(f"Writing {self.name} for {key}")
+        with open(path, "wb") as f:
+            pickle.dump((key, data), f)
+        LOG.info(f"Written {self.name} data for {key} in {path}")
+
+    def read_all(self, expected_lengths=None):
+        # Use glob to read back all the pickled (key, data) pairs.
+        files = glob.glob(self.data_dirname + "/*.npz")
+        LOG.info(f"Reading {self.name} data, found {len(files)} files in {self.data_dirname}")
+        seen = set()
+        for path in files:
+            with open(path, "rb") as f:
+                key, data = pickle.load(f)
+            assert key not in seen, f"Duplicate key {key}"
+            seen.add(key)
+            yield key, data
+
+
 def add_zarr_dataset(
     *,
     name,
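
For reviewers, a hypothetical usage sketch of the new StatisticsRegistry (the directory name, the "0-10" key and the payload dict are made up for illustration; each entry is pickled to its own file under <dirname>/statistics/):

# Illustrative only: write per-key statistics to disk, then stream them back.
registry = StatisticsRegistry("dataset-dir", zarr_registry=None, overwrite=False)
registry.create()                                      # creates dataset-dir/statistics/
registry["0-10"] = {"sums": [1.0, 2.0], "count": 2}    # pickled to dataset-dir/statistics/0-10.npz
for key, data in registry.read_all():                  # generator over every written (key, data) pair
    print(key, data["count"])
registry.delete()                                      # removes dataset-dir/statistics/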
@@ -47,21 +88,23 @@ def add_zarr_dataset(
     return a


-class ZarrRegistry:
-    synchronizer_name = None  # to be defined in subclasses
-
-    def __init__(self, path):
-        assert self.synchronizer_name is not None, self.synchronizer_name
+class ZarrBuiltRegistry:
+    name_lengths = "lengths"
+    name_flags = "flags"
+    lengths = None
+    flags = None
+    z = None

+    def __init__(self, path, synchronizer_path=None):
         import zarr

         assert isinstance(path, str), path
         self.zarr_path = path
-        self.synchronizer = zarr.ProcessSynchronizer(self._synchronizer_path)

-    @property
-    def _synchronizer_path(self):
-        return self.zarr_path + "-" + self.synchronizer_name + ".sync"
+        if synchronizer_path is None:
+            synchronizer_path = self.zarr_path + ".sync"
+        self.synchronizer_path = synchronizer_path
+        self.synchronizer = zarr.ProcessSynchronizer(self.synchronizer_path)

     def _open_write(self):
         import zarr
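
For reference, a minimal sketch of how the reworked ZarrBuiltRegistry constructor might be called now that the synchronizer path is an explicit, optional argument (the paths, action name and keyword argument are made up for illustration):

# Illustrative only: the lock file defaults to "<zarr_path>.sync" unless overridden.
registry = ZarrBuiltRegistry("dataset.zarr")
registry = ZarrBuiltRegistry("dataset.zarr", synchronizer_path="locks/dataset.sync")
registry.add_to_history("build_finished", parts="0-10")  # add_to_history(action, **kwargs) as above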
@@ -94,64 +137,6 @@ def add_to_history(self, action, **kwargs):
         z.attrs["history"] = history


-class ZarrStatisticsRegistry(ZarrRegistry):
-    names = [
-        "mean",
-        "stdev",
-        "minimum",
-        "maximum",
-        "sums",
-        "squares",
-        "count",
-    ]
-    build_names = [
-        "minimum",
-        "maximum",
-        "sums",
-        "squares",
-        "count",
-    ]
-    synchronizer_name = "statistics"
-
-    def __init__(self, path):
-        super().__init__(path)
-
-    def create(self):
-        z = self._open_read()
-        shape = z["data"].shape
-        shape = (shape[0], shape[1])
-
-        for name in self.build_names:
-            if name == "count":
-                self.new_dataset(name=name, shape=shape, fill_value=0, dtype=np.int64)
-            else:
-                self.new_dataset(
-                    name=name, shape=shape, fill_value=np.nan, dtype=np.float64
-                )
-        self.add_to_history("statistics_initialised")
-
-    def __setitem__(self, key, stats):
-        z = self._open_write()
-
-        LOG.info(f"Writting stats for {key}")
-        for name in self.build_names:
-            LOG.info(f"Writting stats for {key} {name} {stats[name].shape}")
-            z["_build"][name][key] = stats[name]
-        LOG.info(f"Written stats for {key}")
-
-    def get_by_name(self, name):
-        z = self._open_read()
-        return z["_build"][name]
-
-
-class ZarrBuiltRegistry(ZarrRegistry):
-    name_lengths = "lengths"
-    name_flags = "flags"
-    lengths = None
-    flags = None
-    z = None
-    synchronizer_name = "build"
-
     def get_slice_for(self, i):
         lengths = self.get_lengths()
         assert i >= 0 and i < len(lengths)