From 1906dca651fe7b9463c5de556e547a12b7c8fd51 Mon Sep 17 00:00:00 2001 From: Remco de Boer Date: Wed, 21 Jul 2021 14:36:18 +0200 Subject: [PATCH 1/7] refactor: store rendered_graph separately --- src/qrules/io/_dot.py | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/src/qrules/io/_dot.py b/src/qrules/io/_dot.py index 79bbd92a..fbf69860 100644 --- a/src/qrules/io/_dot.py +++ b/src/qrules/io/_dot.py @@ -105,11 +105,19 @@ def __graph_to_dot_content( # pylint: disable=too-many-locals,too-many-branches ) -> str: dot = "" if isinstance(graph, (StateTransition, StateTransitionGraph)): + rendered_graph: Union[ + StateTransition, + StateTransitionGraph, + Topology, + ] = graph topology = graph.topology elif isinstance(graph, Topology): + rendered_graph = graph topology = graph else: - raise NotImplementedError + raise NotImplementedError( + f"Cannot render {graph.__class__.__name__} as dot" + ) top = topology.incoming_edge_ids outs = topology.outgoing_edge_ids for edge_id in top | outs: @@ -117,7 +125,7 @@ def __graph_to_dot_content( # pylint: disable=too-many-locals,too-many-branches render = render_initial_state_id else: render = render_final_state_id - edge_label = __get_edge_label(graph, edge_id, render) + edge_label = __get_edge_label(rendered_graph, edge_id, render) dot += _DOT_DEFAULT_NODE.format( prefix + __node_name(edge_id), edge_label, @@ -134,7 +142,7 @@ def __graph_to_dot_content( # pylint: disable=too-many-locals,too-many-branches dot += _DOT_LABEL_EDGE.format( prefix + __node_name(i, k), prefix + __node_name(i, j), - __get_edge_label(graph, i, render_resonance_id), + __get_edge_label(rendered_graph, i, render_resonance_id), ) if isinstance(graph, (StateTransition, StateTransitionGraph)): if isinstance(graph, StateTransition): From 32ab3941e9370eed9b1c82f0e787ae9eda2813b9 Mon Sep 17 00:00:00 2001 From: Remco de Boer Date: Wed, 21 Jul 2021 14:45:33 +0200 Subject: [PATCH 2/7] refactor: extract ___render_edge_with_id --- src/qrules/io/_dot.py | 30 ++++++++++++++++++++---------- 1 file changed, 20 insertions(+), 10 deletions(-) diff --git a/src/qrules/io/_dot.py b/src/qrules/io/_dot.py index fbf69860..f37dada9 100644 --- a/src/qrules/io/_dot.py +++ b/src/qrules/io/_dot.py @@ -191,22 +191,32 @@ def __get_edge_label( graph = graph.to_graph() if isinstance(graph, StateTransitionGraph): edge_prop = graph.get_edge_props(edge_id) - if not edge_prop: - return str(edge_id) - edge_label = __edge_label(edge_prop) - if not render_edge_id: - return edge_label - if "\n" in edge_label: - return f"{edge_id}:\n{edge_label}" - return f"{edge_id}: {edge_label}" + return ___render_edge_with_id(edge_id, edge_prop, render_edge_id) if isinstance(graph, Topology): if render_edge_id: return str(edge_id) return "" - raise NotImplementedError + raise NotImplementedError( + f"Cannot render {graph.__class__.__name__} as dot" + ) + + +def ___render_edge_with_id( + edge_id: int, + edge_prop: Optional[Union[ParticleCollection, Particle, ParticleWithSpin]], + render_edge_id: bool, +) -> str: + if edge_prop is None or not edge_prop: + return str(edge_id) + edge_label = __render_edge_property(edge_prop) + if not render_edge_id: + return edge_label + if "\n" in edge_label: + return f"{edge_id}:\n{edge_label}" + return f"{edge_id}: {edge_label}" -def __edge_label( +def __render_edge_property( edge_prop: Union[ParticleCollection, Particle, ParticleWithSpin] ) -> str: if isinstance(edge_prop, Particle): From 7cf0f2e4392ce8c934a5781a6d8da6569c642d71 Mon Sep 17 00:00:00 2001 From: Remco de Boer Date: Wed, 21 Jul 2021 14:49:54 +0200 Subject: [PATCH 3/7] feat: render ProblemSet as DOT --- src/qrules/io/__init__.py | 3 +- src/qrules/io/_dot.py | 80 ++++++++++++++++++++++++++++++++++++--- tests/unit/io/test_dot.py | 21 ++++++++++ 3 files changed, 97 insertions(+), 7 deletions(-) diff --git a/src/qrules/io/__init__.py b/src/qrules/io/__init__.py index 3a93f599..c4e2e99b 100644 --- a/src/qrules/io/__init__.py +++ b/src/qrules/io/__init__.py @@ -16,6 +16,7 @@ from qrules.particle import Particle, ParticleCollection from qrules.topology import StateTransitionGraph, Topology from qrules.transition import ( + ProblemSet, ReactionInfo, State, StateTransition, @@ -126,7 +127,7 @@ def asdot( """ if isinstance(instance, StateTransition): instance = instance.to_graph() - if isinstance(instance, (StateTransitionGraph, Topology)): + if isinstance(instance, (ProblemSet, StateTransitionGraph, Topology)): return _dot.graph_to_dot( instance, render_node=render_node, diff --git a/src/qrules/io/_dot.py b/src/qrules/io/_dot.py index f37dada9..4933f9a3 100644 --- a/src/qrules/io/_dot.py +++ b/src/qrules/io/_dot.py @@ -3,13 +3,15 @@ See :doc:`/usage/visualize` for more info. """ +import re from collections import abc from typing import Callable, Iterable, List, Mapping, Optional, Union from qrules.particle import Particle, ParticleCollection, ParticleWithSpin from qrules.quantum_numbers import InteractionProperties, _to_fraction +from qrules.solving import EdgeSettings, NodeSettings from qrules.topology import StateTransitionGraph, Topology -from qrules.transition import StateTransition +from qrules.transition import ProblemSet, StateTransition _DOT_HEAD = """digraph { rankdir=LR; @@ -95,7 +97,12 @@ def graph_to_dot( def __graph_to_dot_content( # pylint: disable=too-many-locals,too-many-branches - graph: Union[StateTransition, StateTransitionGraph, Topology], + graph: Union[ + ProblemSet, + StateTransition, + StateTransitionGraph, + Topology, + ], prefix: str = "", *, render_node: bool, @@ -106,11 +113,15 @@ def __graph_to_dot_content( # pylint: disable=too-many-locals,too-many-branches dot = "" if isinstance(graph, (StateTransition, StateTransitionGraph)): rendered_graph: Union[ + ProblemSet, StateTransition, StateTransitionGraph, Topology, ] = graph topology = graph.topology + elif isinstance(graph, ProblemSet): + rendered_graph = graph + topology = graph.topology elif isinstance(graph, Topology): rendered_graph = graph topology = graph @@ -144,6 +155,15 @@ def __graph_to_dot_content( # pylint: disable=too-many-locals,too-many-branches prefix + __node_name(i, j), __get_edge_label(rendered_graph, i, render_resonance_id), ) + if isinstance(graph, ProblemSet): + node_props = graph.solving_settings.node_settings + for node_id, settings in node_props.items(): + node_label = "" + if render_node: + node_label = __node_label(settings) + dot += _DOT_DEFAULT_NODE.format( + f"{prefix}node{node_id}", node_label + ) if isinstance(graph, (StateTransition, StateTransitionGraph)): if isinstance(graph, StateTransition): interactions: Mapping[ @@ -183,10 +203,24 @@ def __rank_string(node_edge_ids: Iterable[int], prefix: str = "") -> str: def __get_edge_label( - graph: Union[StateTransition, StateTransitionGraph, Topology], + graph: Union[ + ProblemSet, + StateTransition, + StateTransitionGraph, + Topology, + ], edge_id: int, render_edge_id: bool, ) -> str: + if isinstance(graph, ProblemSet): + edge_setting = graph.solving_settings.edge_settings.get(edge_id) + initial_fact = graph.initial_facts.edge_props.get(edge_id) + edge_property: Optional[Union[EdgeSettings, ParticleWithSpin]] = None + if edge_setting: + edge_property = edge_setting + if initial_fact: + edge_property = initial_fact + return ___render_edge_with_id(edge_id, edge_property, render_edge_id) if isinstance(graph, StateTransition): graph = graph.to_graph() if isinstance(graph, StateTransitionGraph): @@ -203,7 +237,9 @@ def __get_edge_label( def ___render_edge_with_id( edge_id: int, - edge_prop: Optional[Union[ParticleCollection, Particle, ParticleWithSpin]], + edge_prop: Optional[ + Union[EdgeSettings, ParticleCollection, Particle, ParticleWithSpin] + ], render_edge_id: bool, ) -> str: if edge_prop is None or not edge_prop: @@ -217,8 +253,12 @@ def ___render_edge_with_id( def __render_edge_property( - edge_prop: Union[ParticleCollection, Particle, ParticleWithSpin] + edge_prop: Optional[ + Union[EdgeSettings, ParticleCollection, Particle, ParticleWithSpin] + ] ) -> str: + if isinstance(edge_prop, EdgeSettings): + return __render_settings(edge_prop) if isinstance(edge_prop, Particle): return edge_prop.name if isinstance(edge_prop, tuple): @@ -230,7 +270,9 @@ def __render_edge_property( raise NotImplementedError -def __node_label(node_prop: Union[InteractionProperties]) -> str: +def __node_label(node_prop: Union[InteractionProperties, NodeSettings]) -> str: + if isinstance(node_prop, NodeSettings): + return __render_settings(node_prop) if isinstance(node_prop, InteractionProperties): output = "" if node_prop.l_magnitude is not None: @@ -256,6 +298,32 @@ def __node_label(node_prop: Union[InteractionProperties]) -> str: raise NotImplementedError +def __render_settings(settings: Union[EdgeSettings, NodeSettings]) -> str: + output = "" + if settings.rule_priorities: + output += "RULE PRIORITIES\n" + rule_names = map( + lambda item: f"{item[0].__name__} - {item[1]}", # type: ignore + settings.rule_priorities.items(), + ) + sorted_names = sorted( + rule_names, + key=lambda s: int(re.match(r".* \- ([0-9]+)$", s)[1]), # type: ignore + reverse=True, + ) + output += "\n".join(sorted_names) + if settings.qn_domains: + if output: + output += "\n" + domains = map( + lambda item: f"{item[0].__name__} ∊ {item[1]}", # type: ignore + settings.qn_domains.items(), + ) + output += "DOMAINS\n" + output += "\n".join(sorted(domains)) + return output + + def _get_particle_graphs( graphs: Iterable[StateTransitionGraph[ParticleWithSpin]], ) -> List[StateTransitionGraph[Particle]]: diff --git a/tests/unit/io/test_dot.py b/tests/unit/io/test_dot.py index 291f1acc..996019c1 100644 --- a/tests/unit/io/test_dot.py +++ b/tests/unit/io/test_dot.py @@ -1,6 +1,8 @@ # pylint: disable=no-self-use import pydot +import pytest +import qrules from qrules import io from qrules.io._dot import _collapse_graphs, _get_particle_graphs from qrules.particle import ParticleCollection @@ -26,6 +28,25 @@ def test_asdot(reaction: ReactionInfo): assert pydot.graph_from_dot_data(dot_data) is not None +@pytest.mark.parametrize( + "formalism", + ["canonical", "canonical-helicity", "helicity"], +) +def test_asdot_problemset(formalism: str): + stm = qrules.StateTransitionManager( + initial_state=[("J/psi(1S)", [+1])], + final_state=["gamma", "pi0", "pi0"], + formalism=formalism, + ) + problem_sets = stm.create_problem_sets() + for problem_set_list in problem_sets.values(): + for problem_set in problem_set_list: + dot_data = io.asdot(problem_set) + assert pydot.graph_from_dot_data(dot_data) is not None + dot_data = io.asdot(problem_set_list) + assert pydot.graph_from_dot_data(dot_data) is not None + + def test_asdot_topology(): dot_data = io.asdot(create_n_body_topology(3, 4)) assert pydot.graph_from_dot_data(dot_data) is not None From 84f5d5692f702f83414093bd418c84646d91e570 Mon Sep 17 00:00:00 2001 From: Remco de Boer Date: Wed, 21 Jul 2021 15:27:59 +0200 Subject: [PATCH 4/7] docs: render ProblemSet in documentation --- docs/usage/reaction.ipynb | 45 +++++++++++++++++---- docs/usage/visualize.ipynb | 80 ++++++++++++++++++++++++++++++++++---- 2 files changed, 110 insertions(+), 15 deletions(-) diff --git a/docs/usage/reaction.ipynb b/docs/usage/reaction.ipynb index 938601ce..87c11256 100644 --- a/docs/usage/reaction.ipynb +++ b/docs/usage/reaction.ipynb @@ -183,7 +183,42 @@ "metadata": {}, "outputs": [], "source": [ - "problem_sets = stm.create_problem_sets()" + "problem_sets = stm.create_problem_sets()\n", + "sorted(problem_sets, reverse=True)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "To get an idea of what these {class}`.ProblemSet`s represent, you can use {func}`.asdot` and {doc}`graphviz:index` to visualize one of them (see {doc}`usage/visualize`):" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import graphviz\n", + "\n", + "from qrules import io\n", + "\n", + "some_problem_set = problem_sets[60.0][0]\n", + "dot = io.asdot(some_problem_set, render_node=True)\n", + "graphviz.Source(dot)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Each {class}`.ProblemSet` provides a mapping of {attr}`~.ProblemSet.initial_facts` that represent the initial and final states with spin projections. The nodes and edges in between these {attr}`~.ProblemSet.initial_facts` are still to be generated. This will be done from the provided {attr}`~.ProblemSet.solving_settings` ({class}`~.GraphSettings`). There are two mechanisms there:\n", + "\n", + "1. One the one hand, the {attr}`.EdgeSettings.qn_domains` and {attr}`.NodeSettings.qn_domains` contained in the {class}`~.GraphSettings` define the **domain** over which quantum number sets can be generated.\n", + "2. On the other, the {attr}`.EdgeSettings.rule_priorities` and {attr}`.NodeSettings.rule_priorities` in {class}`~.GraphSettings` define which **{mod}`.conservation_rules`** are used to determine which of the sets of generated quantum numbers are valid.\n", + "\n", + "Together, these two constraints allow the {class}`.StateTransitionManager` to generate a number of {class}`.StateTransitionGraph`s that comply with the selected {mod}`.conservation_rules`." ] }, { @@ -362,10 +397,6 @@ }, "outputs": [], "source": [ - "import graphviz\n", - "\n", - "from qrules import io\n", - "\n", "dot = io.asdot(reaction, collapse_graphs=True, render_node=False)\n", "graphviz.Source(dot)" ] @@ -401,8 +432,6 @@ "metadata": {}, "outputs": [], "source": [ - "from qrules import io\n", - "\n", "io.asdict(reaction.transition_groups[0].topology)" ] }, @@ -448,7 +477,7 @@ "metadata": { "celltoolbar": "Raw Cell Format", "kernelspec": { - "display_name": "Python 3", + "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, diff --git a/docs/usage/visualize.ipynb b/docs/usage/visualize.ipynb index 5b57ad18..9af219f2 100644 --- a/docs/usage/visualize.ipynb +++ b/docs/usage/visualize.ipynb @@ -51,7 +51,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "The {mod}`~qrules.io` module allows you to convert {class}`.StateTransitionGraph` and {class}`.Topology` instances to [DOT language](https://graphviz.org/doc/info/lang.html) with {func}`.asdot`. You can visualize its output with third-party libraries, such as [Graphviz](https://graphviz.org). This is particularly useful after running {meth}`~.StateTransitionManager.find_solutions`, which produces a {class}`.ReactionInfo` object with a {class}`.list` of {class}`.StateTransitionGraph` instances (see {doc}`/usage/reaction`)." + "The {mod}`~qrules.io` module allows you to convert {class}`.StateTransitionGraph`, {class}`.Topology` instances, and {class}`.ProblemSet`s to [DOT language](https://graphviz.org/doc/info/lang.html) with {func}`.asdot`. You can visualize its output with third-party libraries, such as [Graphviz](https://graphviz.org). This is particularly useful after running {meth}`~.StateTransitionManager.find_solutions`, which produces a {class}`.ReactionInfo` object with a {class}`.list` of {class}`.StateTransitionGraph` instances (see {doc}`/usage/reaction`)." ] }, { @@ -80,7 +80,7 @@ "source": [ "import graphviz\n", "\n", - "from qrules import io\n", + "import qrules\n", "from qrules.topology import create_isobar_topologies, create_n_body_topology" ] }, @@ -91,7 +91,7 @@ "outputs": [], "source": [ "topology = create_n_body_topology(2, 4)\n", - "graphviz.Source(io.asdot(topology, render_initial_state_id=True))" + "graphviz.Source(qrules.io.asdot(topology, render_initial_state_id=True))" ] }, { @@ -108,7 +108,7 @@ "outputs": [], "source": [ "topologies = create_isobar_topologies(4)\n", - "graphviz.Source(io.asdot(topologies))" + "graphviz.Source(qrules.io.asdot(topologies))" ] }, { @@ -125,7 +125,7 @@ "outputs": [], "source": [ "topologies = create_isobar_topologies(3)\n", - "graphviz.Source(io.asdot(topologies, render_node=False))" + "graphviz.Source(qrules.io.asdot(topologies, render_node=False))" ] }, { @@ -142,7 +142,7 @@ "outputs": [], "source": [ "topologies = create_isobar_topologies(5)\n", - "dot = io.asdot(\n", + "dot = qrules.io.asdot(\n", " topologies[0],\n", " render_final_state_id=False,\n", " render_resonance_id=True,\n", @@ -154,6 +154,72 @@ { "cell_type": "markdown", "metadata": {}, + "source": [ + "## {class}`.ProblemSet`s" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "As noted in {doc}`usage/reaction`, the {class}`.StateTransitionManager` provides more control than the façade function {func}`.generate_transitions`. One advantages, is that the {class}`.StateTransitionManager` first generates a set of {class}`.ProblemSet`s with {meth}`.create_problem_sets` that you can further configure if you wish." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "stm = qrules.StateTransitionManager(\n", + " initial_state=[\"J/psi(1S)\"],\n", + " final_state=[\"K0\", \"Sigma+\", \"p~\"],\n", + " formalism=\"canonical-helicity\",\n", + ")\n", + "problem_sets = stm.create_problem_sets()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Note that the output of {meth}`.create_problem_sets` is a {obj}`dict` with {obj}`float` values as keys (representing the interaction strength) and {obj}`list`s of {obj}`.ProblemSet`s as values." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "sorted(problem_sets, reverse=True)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "len(problem_sets[60.0])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "problem_set = problem_sets[60.0][0]\n", + "dot = qrules.io.asdot(problem_set, render_node=True)\n", + "graphviz.Source(dot)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "tags": [] + }, "source": [ "## {class}`.StateTransition`s" ] @@ -297,7 +363,7 @@ ], "metadata": { "kernelspec": { - "display_name": "Python 3", + "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, From 0ec7d8b4be37fde27cc2e7399f4665f6fdb7ebfa Mon Sep 17 00:00:00 2001 From: Remco de Boer Date: Wed, 21 Jul 2021 15:36:09 +0200 Subject: [PATCH 5/7] feat: render tuple with InitialFacts/GraphSettings --- src/qrules/io/_dot.py | 28 ++++++++++++++++++++++------ tests/unit/io/test_dot.py | 7 +++++++ 2 files changed, 29 insertions(+), 6 deletions(-) diff --git a/src/qrules/io/_dot.py b/src/qrules/io/_dot.py index 4933f9a3..0d6c129c 100644 --- a/src/qrules/io/_dot.py +++ b/src/qrules/io/_dot.py @@ -5,11 +5,12 @@ import re from collections import abc -from typing import Callable, Iterable, List, Mapping, Optional, Union +from typing import Callable, Iterable, List, Mapping, Optional, Tuple, Union +from qrules.combinatorics import InitialFacts from qrules.particle import Particle, ParticleCollection, ParticleWithSpin from qrules.quantum_numbers import InteractionProperties, _to_fraction -from qrules.solving import EdgeSettings, NodeSettings +from qrules.solving import EdgeSettings, GraphSettings, NodeSettings from qrules.topology import StateTransitionGraph, Topology from qrules.transition import ProblemSet, StateTransition @@ -96,12 +97,14 @@ def graph_to_dot( ) -def __graph_to_dot_content( # pylint: disable=too-many-locals,too-many-branches +def __graph_to_dot_content( # pylint: disable=too-many-branches,too-many-locals,too-many-statements graph: Union[ ProblemSet, StateTransition, StateTransitionGraph, Topology, + Tuple[Topology, InitialFacts], + Tuple[Topology, GraphSettings], ], prefix: str = "", *, @@ -111,17 +114,22 @@ def __graph_to_dot_content( # pylint: disable=too-many-locals,too-many-branches render_initial_state_id: bool, ) -> str: dot = "" - if isinstance(graph, (StateTransition, StateTransitionGraph)): + if isinstance(graph, tuple) and len(graph) == 2: + topology: Topology = graph[0] rendered_graph: Union[ + GraphSettings, + InitialFacts, ProblemSet, StateTransition, StateTransitionGraph, Topology, - ] = graph - topology = graph.topology + ] = graph[1] elif isinstance(graph, ProblemSet): rendered_graph = graph topology = graph.topology + elif isinstance(graph, (StateTransition, StateTransitionGraph)): + rendered_graph = graph + topology = graph.topology elif isinstance(graph, Topology): rendered_graph = graph topology = graph @@ -204,6 +212,8 @@ def __rank_string(node_edge_ids: Iterable[int], prefix: str = "") -> str: def __get_edge_label( graph: Union[ + GraphSettings, + InitialFacts, ProblemSet, StateTransition, StateTransitionGraph, @@ -212,6 +222,12 @@ def __get_edge_label( edge_id: int, render_edge_id: bool, ) -> str: + if isinstance(graph, GraphSettings): + edge_setting = graph.edge_settings.get(edge_id) + return ___render_edge_with_id(edge_id, edge_setting, render_edge_id) + if isinstance(graph, InitialFacts): + initial_fact = graph.edge_props.get(edge_id) + return ___render_edge_with_id(edge_id, initial_fact, render_edge_id) if isinstance(graph, ProblemSet): edge_setting = graph.solving_settings.edge_settings.get(edge_id) initial_fact = graph.initial_facts.edge_props.get(edge_id) diff --git a/tests/unit/io/test_dot.py b/tests/unit/io/test_dot.py index 996019c1..3af54e6a 100644 --- a/tests/unit/io/test_dot.py +++ b/tests/unit/io/test_dot.py @@ -43,6 +43,13 @@ def test_asdot_problemset(formalism: str): for problem_set in problem_set_list: dot_data = io.asdot(problem_set) assert pydot.graph_from_dot_data(dot_data) is not None + topology = problem_set.topology + initial_facts = problem_set.initial_facts + settings = problem_set.solving_settings + dot_data = io.asdot([(topology, initial_facts)]) + assert pydot.graph_from_dot_data(dot_data) is not None + dot_data = io.asdot([(topology, settings)]) + assert pydot.graph_from_dot_data(dot_data) is not None dot_data = io.asdot(problem_set_list) assert pydot.graph_from_dot_data(dot_data) is not None From 85d86f753543e7b914894e2b6d286a3d20fbc690 Mon Sep 17 00:00:00 2001 From: Remco de Boer Date: Wed, 21 Jul 2021 15:41:31 +0200 Subject: [PATCH 6/7] fix: retain class type with decorator --- src/qrules/_implementers.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/src/qrules/_implementers.py b/src/qrules/_implementers.py index 64fd1bd6..9f7881d1 100644 --- a/src/qrules/_implementers.py +++ b/src/qrules/_implementers.py @@ -1,6 +1,6 @@ """A collection of implementation tools that can be used accross all modules.""" -from typing import Any, Callable +from typing import Any, Callable, Type, TypeVar import attr @@ -10,10 +10,17 @@ PrettyPrinter = Any -def implement_pretty_repr() -> Callable[[type], type]: +_DecoratedClass = TypeVar("_DecoratedClass") + + +def implement_pretty_repr() -> Callable[ + [Type[_DecoratedClass]], Type[_DecoratedClass] +]: """Implement a pretty :code:`repr` in a `attr` decorated class.""" - def decorator(decorated_class: type) -> type: + def decorator( + decorated_class: Type[_DecoratedClass], + ) -> Type[_DecoratedClass]: if not attr.has(decorated_class): raise TypeError( "Can only implement a pretty repr for a class created with attrs" From 858eee61a40092a1626e689fe74f935b7909c209 Mon Sep 17 00:00:00 2001 From: Remco de Boer Date: Wed, 21 Jul 2021 15:47:17 +0200 Subject: [PATCH 7/7] fix: cspell ignore ipykernel --- .cspell.json | 1 + 1 file changed, 1 insertion(+) diff --git a/.cspell.json b/.cspell.json index 151a393a..e47eebeb 100644 --- a/.cspell.json +++ b/.cspell.json @@ -143,6 +143,7 @@ "heli", "heurisch", "imag", + "ipykernel", "isfunction", "isinstance", "jpsi",