-
Notifications
You must be signed in to change notification settings - Fork 6
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
refactor: merge functionality for results
Make the merge functionality easier to reuse and extend.
- Loading branch information
Showing
3 changed files
with
154 additions
and
205 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,55 @@ | ||
import itertools | ||
from datetime import datetime | ||
from typing import List, Protocol, TypeVar | ||
|
||
from libecalc.common.utils.rates import TimeSeries | ||
from typing_extensions import Self | ||
|
||
|
||
class TabularTimeSeries(Protocol): | ||
timesteps: List[datetime] | ||
|
||
def copy(self) -> Self: | ||
... | ||
|
||
|
||
ObjectWithTimeSeries = TypeVar("ObjectWithTimeSeries", bound=TabularTimeSeries) | ||
|
||
|
||
class TabularTimeSeriesUtils: | ||
""" | ||
Utility functions for objects containing TimeSeries | ||
""" | ||
|
||
@classmethod | ||
def _merge_helper(cls, *objects_with_timeseries: ObjectWithTimeSeries) -> ObjectWithTimeSeries: | ||
first, *others = objects_with_timeseries | ||
merged_object = first.copy() | ||
|
||
for key, value in first.__dict__.items(): | ||
for other in others: | ||
accumulated_value = merged_object.__getattribute__(key) | ||
other_value = other.__getattribute__(key) | ||
if key == "timesteps": | ||
merged_timesteps = sorted(itertools.chain(accumulated_value, other_value)) | ||
merged_object.__setattr__(key, merged_timesteps) | ||
elif isinstance(value, TimeSeries): | ||
merged_object.__setattr__(key, accumulated_value.merge(other_value)) | ||
|
||
return merged_object | ||
|
||
@classmethod | ||
def merge(cls, *tabular_time_series_list: TabularTimeSeries): | ||
""" | ||
Merge objects containing TimeSeries. Other attributes will be copied from the first object. | ||
Args: | ||
*tabular_time_series_list: list of objects to merge | ||
Returns: a merged object of the same type | ||
""" | ||
# Verify that we are merging the same types | ||
if len({type(tabular_time_series) for tabular_time_series in tabular_time_series_list}) != 1: | ||
raise ValueError("Can not merge objects of differing types.") | ||
|
||
return cls._merge_helper(*tabular_time_series_list) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.