diff --git a/mermaid/__init__.py b/mermaid/__init__.py index 065e446..a6b3938 100644 --- a/mermaid/__init__.py +++ b/mermaid/__init__.py @@ -12,7 +12,7 @@ """ from enum import Enum -from ._main import Mermaid +from ._main import Mermaid, Position from ._utils import load, text_to_snake_case from .configuration import Config from .graph import Graph @@ -30,4 +30,7 @@ class Direction(Enum): BOTTOM_TO_TOP = 'BT' -__all__ = ['Mermaid', 'load', 'Direction', 'Graph', 'Style', 'Config', 'Icon'] +__all__ = [ + 'Mermaid', 'load', 'Direction', 'Graph', 'Style', 'Config', 'Icon', + 'Position' +] diff --git a/mermaid/_main.py b/mermaid/_main.py index ca30ccc..f920c80 100644 --- a/mermaid/_main.py +++ b/mermaid/_main.py @@ -1,4 +1,5 @@ import base64 +from enum import Enum from pathlib import Path from typing import Union @@ -8,6 +9,16 @@ from .graph import Graph +class Position(Enum): + """ + This class represents the position of the node in a Mermaid diagram. + """ + LEFT = 'left' + RIGHT = 'right' + CENTER = 'center' + NONE = 'none' + + class Mermaid: """ This class represents a Mermaid diagram. @@ -17,16 +28,30 @@ class Mermaid: svg_response (Response): The response from the GET request to the Mermaid SVG API. img_response (Response): The response from the GET request to the Mermaid IMG API. """ - def __init__(self, graph: Graph): + def __init__(self, + graph: Graph, + position: Union[Position, str] = Position.NONE): """ The constructor for the Mermaid class. Parameters: graph (Graph): The Graph object containing the Mermaid diagram script. """ + self.__position: str = position if isinstance(position, + str) else position.value self._diagram = self._process_diagram(graph.script) self._make_request_to_mermaid() + def set_position(self, position: Union[Position, str]) -> None: + """ + Set the position of the node in the Mermaid diagram. + + Parameters: + position (Union[Position, str]): The position of the node. + """ + self.__position = position if isinstance(position, + str) else position.value + @staticmethod def _process_diagram(diagram: str) -> str: """ @@ -50,7 +75,9 @@ def _repr_html_(self) -> str: Returns: str: The text of the SVG response. """ - return self.svg_response.text + if self.__position == Position.NONE.value: + return self.svg_response.text + return f'
{self.svg_response.text}
' def _make_request_to_mermaid(self) -> None: """ diff --git a/mermaid/tests/test_mermaid.py b/mermaid/tests/test_mermaid.py index cf5d9dd..c418c6f 100644 --- a/mermaid/tests/test_mermaid.py +++ b/mermaid/tests/test_mermaid.py @@ -2,7 +2,7 @@ import unittest from pathlib import Path -from mermaid import Mermaid +from mermaid import Mermaid, Position from mermaid.graph import Graph @@ -47,6 +47,24 @@ def test_to_png_on_mermaid_with_path(self): self.assertTrue(Path.exists(output_path)) + def test_repr_html_on_mermaid_with_default_position(self): + self.assertEqual(self.mermaid_object._repr_html_(), + self.mermaid_object.svg_response.text) + + def test_repr_html_on_mermaid_with_str_position(self): + position = 'center' + self.mermaid_object.set_position(position) + self.assertTrue(self.mermaid_object._repr_html_().startswith( + f'
')) + self.assertTrue(self.mermaid_object._repr_html_().endswith('
')) + + def test_repr_html_on_mermaid_with_enum_position(self): + position = Position.RIGHT + self.mermaid_object.set_position(position) + self.assertTrue(self.mermaid_object._repr_html_().startswith( + f'
')) + self.assertTrue(self.mermaid_object._repr_html_().endswith('
')) + def tearDown(self) -> None: output_svg: str = f'./{self.name}.svg' output_png: str = f'./{self.name}.png'