From e5e4618c4ed8f565c245d5643abf488c89ed0f32 Mon Sep 17 00:00:00 2001
From: Florian Pinault
Date: Mon, 4 Mar 2024 21:57:30 +0000
Subject: [PATCH] make aggregated stats faster

---
 ecml_tools/create/statistics.py | 70 ++++++++++++++++++++-------------
 1 file changed, 42 insertions(+), 28 deletions(-)

diff --git a/ecml_tools/create/statistics.py b/ecml_tools/create/statistics.py
index 9b0a58b..ec6e1a8 100644
--- a/ecml_tools/create/statistics.py
+++ b/ecml_tools/create/statistics.py
@@ -184,7 +184,6 @@ def __init__(self, dates, variables_names, owner):
         self.sums = np.full(self.shape, np.nan, dtype=np.float64)
         self.squares = np.full(self.shape, np.nan, dtype=np.float64)
         self.count = np.full(self.shape, -1, dtype=np.int64)
-        self.flags = np.full(self.shape, False, dtype=np.bool_)
 
         self._read()
 
@@ -195,44 +194,59 @@ def _date_to_index(self, date):
         return np.where(self.dates == date)[0][0]
 
     def _read(self):
-        available_dates = []
-        for _, dates, stats in self.owner._gather_data():
+        def check_type(a, b):
+            a = list(a)
+            b = list(b)
+            a = a[0] if a else None
+            b = b[0] if b else None
+            assert type(a) is type(b), (type(a), type(b))
+
+        found = set()
+        offset = 0
+        for _, _dates, stats in self.owner._gather_data():
             assert isinstance(stats, dict), stats
+            assert stats["minimum"].shape[0] == len(_dates), (stats["minimum"].shape, len(_dates))
+            assert stats["minimum"].shape[1] == len(self.variables_names), (
+                stats["minimum"].shape,
+                len(self.variables_names),
+            )
             for n in self.NAMES:
                 assert n in stats, (n, list(stats.keys()))
-            dates = to_datetimes(dates)
-
-            indexes = []
-            stats_indexes = []
-            for i, d in enumerate(dates):
-                if d not in self.dates:
-                    continue
-                stats_indexes.append(i)
-                indexes.append(self._date_to_index(d))
-                available_dates.append(d)
-
-            if not indexes:
+            _dates = to_datetimes(_dates)
+            check_type(_dates, self.dates)
+            if found:
+                check_type(found, self.dates)
+                assert found.isdisjoint(_dates), "Duplicate dates found in precomputed statistics"
+
+            # filter dates
+            dates = set(_dates) & set(self.dates)
+
+            if not dates:
+                # dates have been completely filtered for this chunk
                 continue
 
-            self.flags[indexes] = True
+            # filter data
+            bitmap = np.isin(_dates, self.dates)
+            for k in self.NAMES:
+                stats[k] = stats[k][bitmap]
+
+            assert stats["minimum"].shape[0] == len(dates), (stats["minimum"].shape, len(dates))
+
+            # store data in self
+            found |= set(dates)
             for name in self.NAMES:
                 array = getattr(self, name)
-                data = stats[name]
-                data = data[stats_indexes]
-                array[indexes] = data
+                assert stats[name].shape[0] == len(dates), (stats[name].shape, len(dates))
+                array[offset : offset + len(dates)] = stats[name]
+            offset += len(dates)
 
-        assert type(available_dates[0]) is type(self.dates[0]), (available_dates[0], self.dates[0])
-        assert len(available_dates) == len(set(available_dates)), "Duplicate dates found in statistics"
         for d in self.dates:
-            assert d in available_dates, f"Statistics for date {d} not precomputed."
-        assert len(available_dates) == len(self.dates)
-        print(f"Statistics for {len(available_dates)} dates found.")
+            assert d in found, f"Statistics for date {d} not precomputed."
+        assert len(self.dates) == len(found), "Not all dates found in precomputed statistics"
+        assert len(self.dates) == offset, "Not all dates found in precomputed statistics."
+        print(f"Statistics for {len(found)} dates found.")
 
     def aggregate(self):
-        if not np.all(self.flags):
-            not_found = np.where(self.flags == False)  # noqa: E712
-            raise Exception(f"Statistics not precomputed for {not_found}", not_found)
-
         for name in self.NAMES:
             if name == "count":
                 continue