diff --git a/mermaid/__init__.py b/mermaid/__init__.py index 065e446..a6b3938 100644 --- a/mermaid/__init__.py +++ b/mermaid/__init__.py @@ -12,7 +12,7 @@ """ from enum import Enum -from ._main import Mermaid +from ._main import Mermaid, Position from ._utils import load, text_to_snake_case from .configuration import Config from .graph import Graph @@ -30,4 +30,7 @@ class Direction(Enum): BOTTOM_TO_TOP = 'BT' -__all__ = ['Mermaid', 'load', 'Direction', 'Graph', 'Style', 'Config', 'Icon'] +__all__ = [ + 'Mermaid', 'load', 'Direction', 'Graph', 'Style', 'Config', 'Icon', + 'Position' +] diff --git a/mermaid/_main.py b/mermaid/_main.py index ca30ccc..f920c80 100644 --- a/mermaid/_main.py +++ b/mermaid/_main.py @@ -1,4 +1,5 @@ import base64 +from enum import Enum from pathlib import Path from typing import Union @@ -8,6 +9,16 @@ from .graph import Graph +class Position(Enum): + """ + This class represents the position of the node in a Mermaid diagram. + """ + LEFT = 'left' + RIGHT = 'right' + CENTER = 'center' + NONE = 'none' + + class Mermaid: """ This class represents a Mermaid diagram. @@ -17,16 +28,30 @@ class Mermaid: svg_response (Response): The response from the GET request to the Mermaid SVG API. img_response (Response): The response from the GET request to the Mermaid IMG API. """ - def __init__(self, graph: Graph): + def __init__(self, + graph: Graph, + position: Union[Position, str] = Position.NONE): """ The constructor for the Mermaid class. Parameters: graph (Graph): The Graph object containing the Mermaid diagram script. """ + self.__position: str = position if isinstance(position, + str) else position.value self._diagram = self._process_diagram(graph.script) self._make_request_to_mermaid() + def set_position(self, position: Union[Position, str]) -> None: + """ + Set the position of the node in the Mermaid diagram. + + Parameters: + position (Union[Position, str]): The position of the node. + """ + self.__position = position if isinstance(position, + str) else position.value + @staticmethod def _process_diagram(diagram: str) -> str: """ @@ -50,7 +75,9 @@ def _repr_html_(self) -> str: Returns: str: The text of the SVG response. """ - return self.svg_response.text + if self.__position == Position.NONE.value: + return self.svg_response.text + return f'