Skip to content

Commit

Permalink
fix: don't raise error when missing time series on token collection
Browse files Browse the repository at this point in the history
When collecting the tokens from time series collection, we don't want to
raise an error since that would limit the information we get from
validation. Instead we want to collect the tokens we can, then validate
the model based on that context. The errors will still give information
about a missing resource, but also additional information, and file
context.
  • Loading branch information
jsolaas committed Dec 10, 2024
1 parent 2bbb989 commit d4d40a5
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 30 deletions.
66 changes: 38 additions & 28 deletions src/libecalc/presentation/yaml/domain/time_series_collections.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from datetime import datetime
from typing import Self

from libecalc.presentation.yaml.domain.time_series import TimeSeries
from libecalc.presentation.yaml.domain.time_series_collection import TimeSeriesCollection
Expand All @@ -17,7 +18,41 @@ class TimeSeriesCollections(TimeSeriesProvider):
steps in all collections.
"""

def __init__(self, time_series: list[YamlTimeSeriesCollection], resources: dict[str, Resource]):
def __init__(self, time_series_collections: dict[str, TimeSeriesCollection]):
self._time_series_collections = time_series_collections

def get_time_series_references(self) -> list[str]:
time_series_references = []
for collection in self._time_series_collections.values():
for time_series_reference in collection.get_time_series_references():
time_series_references.append(f"{collection.name};{time_series_reference}")
return time_series_references

def get_time_series(self, time_series_id: str) -> TimeSeries:
reference_id_parts = time_series_id.split(";")
if len(reference_id_parts) != 2:
raise TimeSeriesNotFound(time_series_id)
[collection_id, time_series_id] = reference_id_parts

if collection_id not in self._time_series_collections:
raise TimeSeriesNotFound(time_series_id)

return self._time_series_collections[collection_id].get_time_series(time_series_id)

def get_time_vector(self) -> set[datetime]:
time_vector: set[datetime] = set()
for time_series_collection in self._time_series_collections.values():
if time_series_collection.should_influence_time_vector():
time_vector = time_vector.union(time_series_collection.get_time_vector())
return time_vector

@classmethod
def create(
cls,
time_series: list[YamlTimeSeriesCollection],
resources: dict[str, Resource],
raise_on_error: bool,
) -> tuple[Self, list[ModelValidationError]]:
time_series_collections: dict[str, TimeSeriesCollection] = {}
errors: list[ModelValidationError] = []
for time_series_collection in time_series:
Expand Down Expand Up @@ -52,32 +87,7 @@ def __init__(self, time_series: list[YamlTimeSeriesCollection], resources: dict[
for error in e.errors()
]
)
if len(errors) != 0:
if raise_on_error and len(errors) != 0:
raise ModelValidationException(errors=errors)

self._time_series_collections = time_series_collections

def get_time_series_references(self) -> list[str]:
time_series_references = []
for collection in self._time_series_collections.values():
for time_series_reference in collection.get_time_series_references():
time_series_references.append(f"{collection.name};{time_series_reference}")
return time_series_references

def get_time_series(self, time_series_id: str) -> TimeSeries:
reference_id_parts = time_series_id.split(";")
if len(reference_id_parts) != 2:
raise TimeSeriesNotFound(time_series_id)
[collection_id, time_series_id] = reference_id_parts

if collection_id not in self._time_series_collections:
raise TimeSeriesNotFound(time_series_id)

return self._time_series_collections[collection_id].get_time_series(time_series_id)

def get_time_vector(self) -> set[datetime]:
time_vector: set[datetime] = set()
for time_series_collection in self._time_series_collections.values():
if time_series_collection.should_influence_time_vector():
time_vector = time_vector.union(time_series_collection.get_time_vector())
return time_vector
return cls(time_series_collections), errors
14 changes: 12 additions & 2 deletions src/libecalc/presentation/yaml/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,13 @@ def end(self) -> Optional[datetime]:
return self._configuration.end

def _get_time_series_collections(self) -> TimeSeriesCollections:
return TimeSeriesCollections(time_series=self._configuration.time_series, resources=self.resources)
time_series_collections, err = TimeSeriesCollections.create(
time_series=self._configuration.time_series,
resources=self.resources,
raise_on_error=True,
)
assert len(err) == 0
return time_series_collections

def _get_time_vector(self):
return get_global_time_vector(
Expand Down Expand Up @@ -158,7 +164,11 @@ def get_graph(self) -> ComponentGraph:
return self._graph

def _get_token_references(self, yaml_model: YamlValidator) -> list[str]:
token_references = self._get_time_series_collections().get_time_series_references()
# Only get references for valid time series collections
time_series_collections, _ = TimeSeriesCollections.create(
time_series=self._configuration.time_series, resources=self.resources, raise_on_error=False
)
token_references = time_series_collections.get_time_series_references()

for reference in yaml_model.variables:
token_references.append(f"$var.{reference}")
Expand Down

0 comments on commit d4d40a5

Please sign in to comment.