diff --git a/mermaid/__init__.py b/mermaid/__init__.py index 9bbf86b..50c01e3 100644 --- a/mermaid/__init__.py +++ b/mermaid/__init__.py @@ -1,6 +1,6 @@ """ a beter docs sttrings """ +from ._main import Mermaid from ._utils import load -from .mermaid import * __version__: str = '0.1.6' diff --git a/mermaid/_main.py b/mermaid/_main.py new file mode 100644 index 0000000..c71ae22 --- /dev/null +++ b/mermaid/_main.py @@ -0,0 +1,38 @@ +import base64 +from pathlib import Path +from typing import Union + +import requests +from requests import Response + +from .graph import Graph + + +class Mermaid: + def __init__(self, graph: Graph): + self._diagram = self._process_diagram(graph.script) + self._make_request_to_mermaid() + + @staticmethod + def _process_diagram(diagram: str) -> str: + graphbytes = diagram.encode('utf8') + base64_bytes = base64.b64encode(graphbytes) + diagram = base64_bytes.decode('ascii') + return diagram + + def _repr_html_(self) -> str: + return self.svg_response.text + + def _make_request_to_mermaid(self) -> None: + self.svg_response: Response = requests.get('https://mermaid.ink/svg/' + + self._diagram) + self.img_response: Response = requests.get('https://mermaid.ink/img/' + + self._diagram) + + def to_svg(self, path: Union[str, Path]) -> None: + with open(path, 'w', encoding='utf-8') as file: + file.write(self.svg_response.text) + + def to_png(self, path: Union[str, Path]) -> None: + with open(path, 'w', encoding='utf-8') as file: + file.write(self.img_response.text) diff --git a/mermaid/mermaid.py b/mermaid/mermaid.py deleted file mode 100644 index d17df5a..0000000 --- a/mermaid/mermaid.py +++ /dev/null @@ -1,26 +0,0 @@ -import base64 - -import requests -from requests import Response - -from .graph import Graph - - -class Mermaid: - def __init__(self, graph: Graph): - self._diagram = self._process_diagram(graph.script) - self._make_request_to_mermaid() - - @staticmethod - def _process_diagram(diagram: str) -> str: - graphbytes = diagram.encode('utf8') - base64_bytes = base64.b64encode(graphbytes) - diagram = base64_bytes.decode('ascii') - return diagram - - def _repr_html_(self) -> str: - return self.response.text - - def _make_request_to_mermaid(self) -> None: - self.response: Response = requests.get('https://mermaid.ink/svg/' + - self._diagram) diff --git a/mermaid/tests/test_mermaid.py b/mermaid/tests/test_mermaid.py index c82a906..cf5d9dd 100644 --- a/mermaid/tests/test_mermaid.py +++ b/mermaid/tests/test_mermaid.py @@ -1,4 +1,6 @@ +import os import unittest +from pathlib import Path from mermaid import Mermaid from mermaid.graph import Graph @@ -11,8 +13,53 @@ def setUp(self) -> None: A-->C; B-->D; C-->D;""" - graph: Graph = Graph('simple-graph', script) + self.name: str = 'simple-graph' + graph: Graph = Graph(self.name, script) self.mermaid_object = Mermaid(graph) - def test_make_request_to_mermaid_api(self): - self.assertTrue(self.mermaid_object.response.status_code == 200) + def test_make_request_to_mermaid_api_for_svg(self): + self.assertTrue(self.mermaid_object.svg_response.status_code == 200) + + def test_make_request_to_mermaid_api_for_png(self): + self.assertTrue(self.mermaid_object.img_response.status_code == 200) + + def test_to_svg_on_mermaid(self): + output_path: str = f'./{self.name}.svg' + self.mermaid_object.to_svg(output_path) + + self.assertTrue(os.path.exists(output_path)) + + def test_to_svg_on_mermaid_with_path(self): + output_path: Path = Path(f'./{self.name}.svg') + self.mermaid_object.to_svg(output_path) + + self.assertTrue(Path.exists(output_path)) + + def test_to_png_on_mermaid(self): + output_path: str = f'./{self.name}.png' + self.mermaid_object.to_png(output_path) + + self.assertTrue(os.path.exists(output_path)) + + def test_to_png_on_mermaid_with_path(self): + output_path: Path = Path(f'./{self.name}.png') + self.mermaid_object.to_png(output_path) + + self.assertTrue(Path.exists(output_path)) + + def tearDown(self) -> None: + output_svg: str = f'./{self.name}.svg' + output_png: str = f'./{self.name}.png' + output_svg_path: Path = Path(f'./{self.name}.svg') + output_png_path: Path = Path(f'./{self.name}.png') + + if os.path.exists(output_svg): + os.remove(output_svg) + if os.path.exists(output_png): + os.remove(output_png) + if os.path.exists(output_svg_path): + os.remove(output_svg_path) + if os.path.exists(output_png_path): + os.remove(output_png_path) + + return super().tearDown()