diff --git a/src/emevo/env.py b/src/emevo/env.py index baa2e34d..eefb624a 100644 --- a/src/emevo/env.py +++ b/src/emevo/env.py @@ -93,8 +93,7 @@ def init_uniqueid(n: int, max_n: int) -> UniqueID: class ObsProtocol(Protocol): """Abstraction for agent's observation""" - def as_array(self) -> jax.Array: - ... + def as_array(self) -> jax.Array: ... OBS = TypeVar("OBS", bound="ObsProtocol") @@ -109,8 +108,7 @@ class StateProtocol(Protocol): status: Status n_born_agents: jax.Array - def is_extinct(self) -> bool: - ... + def is_extinct(self) -> bool: ... STATE = TypeVar("STATE", bound="StateProtocol") diff --git a/src/emevo/environments/env_utils.py b/src/emevo/environments/env_utils.py index 971c5291..12b9922f 100644 --- a/src/emevo/environments/env_utils.py +++ b/src/emevo/environments/env_utils.py @@ -48,8 +48,7 @@ def get_slice(self, index: int) -> Self: class ReprNumFn(Protocol): initial: int - def __call__(self, n_steps: int, state: FoodNumState) -> FoodNumState: - ... + def __call__(self, n_steps: int, state: FoodNumState) -> FoodNumState: ... @dataclasses.dataclass(frozen=True) @@ -176,16 +175,13 @@ def __call__(self, *args: Any, **kwargs: Any) -> tuple[ReprNumFn, FoodNumState]: class Coordinate(Protocol): - def bbox(self) -> tuple[tuple[float, float], tuple[float, float]]: - ... + def bbox(self) -> tuple[tuple[float, float], tuple[float, float]]: ... def contains_circle( self, center: jax.Array, radius: jax.Array | float - ) -> jax.Array: - ... + ) -> jax.Array: ... - def uniform(self, key: chex.PRNGKey) -> jax.Array: - ... + def uniform(self, key: chex.PRNGKey) -> jax.Array: ... @dataclasses.dataclass diff --git a/src/emevo/environments/phyjax2d.py b/src/emevo/environments/phyjax2d.py index 547fb6b0..cde2f396 100644 --- a/src/emevo/environments/phyjax2d.py +++ b/src/emevo/environments/phyjax2d.py @@ -99,8 +99,7 @@ class _PositionLike(Protocol): angle: jax.Array # Angular velocity (N,) xy: jax.Array # (N, 2) - def __init__(self, angle: jax.Array, xy: jax.Array) -> None: - ... + def __init__(self, angle: jax.Array, xy: jax.Array) -> None: ... def batch_size(self) -> int: return self.angle.shape[0] diff --git a/src/emevo/visualizer.py b/src/emevo/visualizer.py index a1c863d8..3a9313df 100644 --- a/src/emevo/visualizer.py +++ b/src/emevo/visualizer.py @@ -13,8 +13,7 @@ def close(self) -> None: """Close this visualizer""" ... - def get_image(self) -> NDArray: - ... + def get_image(self) -> NDArray: ... def render(self, state: STATE, **kwargs) -> None: """Render image"""