diff --git a/scripts/plot_bd_models.py b/scripts/plot_bd_models.py index 6e67a1a3..6d4fb1f8 100644 --- a/scripts/plot_bd_models.py +++ b/scripts/plot_bd_models.py @@ -1,3 +1,4 @@ +import importlib from pathlib import Path import matplotlib as mpl @@ -33,11 +34,9 @@ def plot_bd_models( simpletitle: bool = typer.Option(False, help="Make title simple"), birth2d: bool = typer.Option(False, help="Make 2D plot for birth rate"), ) -> None: - try: - import PySide6 - + if importlib.util.find_spec("PySide6") is not None: mpl.use("QtAgg") - except ImportError: + else: mpl.use("TkAgg") with config.open("r") as f: @@ -85,7 +84,7 @@ def plot_bd_models( else: ax = fig.add_subplot(111, projection="3d") if simpletitle: - ax.set_title(f"Birth function") # type: ignore + ax.set_title("Birth function") # type: ignore else: ax.set_title(f"{type(birth_model).__name__} Birth function") # type: ignore if birth2d: @@ -110,9 +109,8 @@ def plot_bd_models( if yes or typer.confirm("Plot survivor ship curve?"): fig = plt.figure(figsize=(5, 10)) ax = fig.add_subplot(111) - ax.set_title( - f"{type(birth_model).__name__} Survivor ship when energy={survivorship_energy}" - ) + tname = type(birth_model).__name__ + ax.set_title(f"{tname} Survivor ship when energy={survivorship_energy}") vis_survivorship(ax=ax, hazard_fn=hazard_model, age_max=age_max, initial=True) plt.show() diff --git a/src/emevo/analysis/qt_widget.py b/src/emevo/analysis/qt_widget.py index 55603fe0..64358a61 100644 --- a/src/emevo/analysis/qt_widget.py +++ b/src/emevo/analysis/qt_widget.py @@ -2,7 +2,6 @@ """ from __future__ import annotations -import dataclasses import sys from collections import deque from collections.abc import Iterable diff --git a/src/emevo/environments/moderngl_vis.py b/src/emevo/environments/moderngl_vis.py index 8d78f864..d2787c45 100644 --- a/src/emevo/environments/moderngl_vis.py +++ b/src/emevo/environments/moderngl_vis.py @@ -4,7 +4,7 @@ """ from __future__ import annotations -from typing import Any, Callable, ClassVar +from typing import Callable, ClassVar import jax.numpy as jnp import moderngl as mgl diff --git a/src/emevo/exp_utils.py b/src/emevo/exp_utils.py index e2f4a3f9..43e99367 100644 --- a/src/emevo/exp_utils.py +++ b/src/emevo/exp_utils.py @@ -333,7 +333,7 @@ def finalize(self) -> None: ] pq.write_table( pa.Table.from_pylist(profile_and_rewards), - self.logdir.joinpath(f"profile_and_rewards.parquet"), + self.logdir.joinpath("profile_and_rewards.parquet"), ) if self.mode in [LogMode.FULL, LogMode.REWARD_AND_LOG]: diff --git a/src/emevo/visualizer.py b/src/emevo/visualizer.py index 4508d906..2ec6945c 100644 --- a/src/emevo/visualizer.py +++ b/src/emevo/visualizer.py @@ -1,7 +1,7 @@ from __future__ import annotations from os import PathLike -from typing import Any, Protocol, TypeVar +from typing import Protocol, TypeVar from numpy.typing import NDArray