From ce65dbad024fc9afae74a2c533767152ce2efa20 Mon Sep 17 00:00:00 2001 From: Jostein Solaas <33114722+jsolaas@users.noreply.github.com> Date: Thu, 26 Oct 2023 14:04:04 +0200 Subject: [PATCH] refactor: use Graph object to build graph (#250) Proper graph implementation, isolate networkx and implementation details. --- src/libecalc/dto/components.py | 101 ++++++++++----------------------- src/libecalc/dto/graph.py | 20 ++++++- 2 files changed, 47 insertions(+), 74 deletions(-) diff --git a/src/libecalc/dto/components.py b/src/libecalc/dto/components.py index ea1ed496d6..cef589d647 100644 --- a/src/libecalc/dto/components.py +++ b/src/libecalc/dto/components.py @@ -2,7 +2,6 @@ from datetime import datetime from typing import Dict, List, Literal, Optional, TypeVar, Union -import networkx as nx from libecalc import dto from libecalc.common.string_utils import generate_id, get_duplicates from libecalc.common.temporal_model import TemporalExpression, TemporalModel @@ -214,17 +213,12 @@ class CompressorSystem(BaseConsumer): compressors: List[CompressorComponent] def get_graph(self) -> Graph: - component_dtos = {} - graph = nx.DiGraph() - graph.add_node(self.id) - component_dtos[self.id] = self + graph = Graph() + graph.add_node(self) for compressor in self.compressors: - component_dtos[compressor.id] = compressor + graph.add_node(compressor) graph.add_edge(self.id, compressor.id) - return Graph( - graph=graph, - components=component_dtos, - ) + return graph def evaluate_operational_settings( self, @@ -287,17 +281,12 @@ class PumpSystem(BaseConsumer): pumps: List[PumpComponent] def get_graph(self) -> Graph: - component_dtos = {} - graph = nx.DiGraph() - graph.add_node(self.id) - component_dtos[self.id] = self + graph = Graph() + graph.add_node(self) for pump in self.pumps: - component_dtos[pump.id] = pump + graph.add_node(pump) graph.add_edge(self.id, pump.id) - return Graph( - graph=graph, - components=component_dtos, - ) + return graph def evaluate_operational_settings( self, @@ -379,27 +368,17 @@ def check_mandatory_category_for_generator_set(cls, user_defined_category): return user_defined_category def get_graph(self) -> Graph: - component_dtos = {} - graph = nx.DiGraph() - graph.add_node(self.id) - component_dtos[self.id] = self - + graph = Graph() + graph.add_node(self) for electricity_consumer in self.consumers: - component_dtos[electricity_consumer.id] = electricity_consumer - if hasattr(electricity_consumer, "get_graph"): - consumer_sub_graph = electricity_consumer.get_graph() - component_dtos.update(consumer_sub_graph.components) - graph = nx.compose(graph, consumer_sub_graph.graph) - graph.add_edge(self.id, electricity_consumer.id) + graph.add_subgraph(electricity_consumer.get_graph()) else: - graph.add_node(electricity_consumer.id) - graph.add_edge(self.id, electricity_consumer.id) + graph.add_node(electricity_consumer) - return Graph( - graph=graph, - components=component_dtos, - ) + graph.add_edge(self.id, electricity_consumer.id) + + return graph class Installation(BaseComponent): @@ -435,30 +414,17 @@ def check_user_defined_category(cls, user_defined_category, values): return user_defined_category def get_graph(self) -> Graph: - component_dtos = {} - graph = nx.DiGraph() - graph.add_node(self.id) - component_dtos[self.id] = self - for fuel_consumer in self.fuel_consumers: - component_dtos[fuel_consumer.id] = fuel_consumer - - if hasattr(fuel_consumer, "get_graph"): - consumer_sub_graph = fuel_consumer.get_graph() - component_dtos.update(consumer_sub_graph.components) - graph = nx.compose(graph, consumer_sub_graph.graph) - graph.add_edge(self.id, fuel_consumer.id) + graph = Graph() + graph.add_node(self) + for component in [*self.fuel_consumers, *self.direct_emitters]: + if hasattr(component, "get_graph"): + graph.add_subgraph(component.get_graph()) else: - graph.add_node(fuel_consumer.id) - graph.add_edge(self.id, fuel_consumer.id) - - for direct_emitter in self.direct_emitters: - component_dtos[direct_emitter.id] = direct_emitter - graph.add_node(direct_emitter.id) - graph.add_edge(self.id, direct_emitter.id) - return Graph( - graph=graph, - components=component_dtos, - ) + graph.add_node(component) + + graph.add_edge(self.id, component.id) + + return graph class Asset(Component): @@ -546,20 +512,13 @@ def validate_unique_names(cls, values): return values def get_graph(self) -> Graph: - component_dtos = {} - graph = nx.DiGraph() - graph.add_node(self.id) - component_dtos[self.id] = self + graph = Graph() + graph.add_node(self) for installation in self.installations: - component_dtos[installation.id] = installation - installation_graph = installation.get_graph() - component_dtos.update(installation_graph.components) - graph = nx.compose(graph, installation_graph.graph) + graph.add_subgraph(installation.get_graph()) graph.add_edge(self.id, installation.id) - return Graph( - graph=graph, - components=component_dtos, - ) + + return graph ComponentDTO = Union[ diff --git a/src/libecalc/dto/graph.py b/src/libecalc/dto/graph.py index c5b863b233..206ea9e59b 100644 --- a/src/libecalc/dto/graph.py +++ b/src/libecalc/dto/graph.py @@ -13,9 +13,23 @@ class Graph: - def __init__(self, graph: nx.DiGraph, components: Dict[str, ComponentDTO]): - self.graph = graph - self.components = components + def __init__(self): + self.graph = nx.DiGraph() + self.components: Dict[str, ComponentDTO] = {} + + def add_node(self, component: ComponentDTO): + self.graph.add_node(component.id) + self.components[component.id] = component + + def add_edge(self, from_id: str, to_id: str): + if from_id not in self.components or to_id not in self.components: + raise ValueError("Add node before adding edges") + + self.graph.add_edge(from_id, to_id) + + def add_subgraph(self, subgraph: Graph): + self.components.update(subgraph.components) + self.graph = nx.compose(self.graph, subgraph.graph) def get_successors(self, component_id: str, recursively=False) -> List[str]: if recursively: