Skip to content

Commit

Permalink
Merge pull request #3 from finsberg/types-upgrade
Browse files Browse the repository at this point in the history
Types upgrade
  • Loading branch information
finsberg authored Dec 10, 2023
2 parents 8b9d944 + 2b7281c commit 36fa951
Show file tree
Hide file tree
Showing 8 changed files with 26 additions and 36 deletions.
3 changes: 0 additions & 3 deletions setup.py

This file was deleted.

19 changes: 8 additions & 11 deletions src/gotranx/atoms.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,5 @@
from __future__ import annotations

from typing import Optional
from typing import Union

import attr
import lark
import pint
Expand All @@ -29,7 +26,7 @@ def _set_symbol(instance, name: str) -> None:
)


def unit_from_string(unit_str: Optional[str]) -> Optional[pint.Unit]:
def unit_from_string(unit_str: str | None) -> pint.Unit | None:
if unit_str is not None:
try:
unit = ureg.Unit(unit_str)
Expand All @@ -55,13 +52,13 @@ class Atom:
"""Base class for atoms"""

name: str = attr.ib()
value: Union[float, Expression] = attr.ib()
component: Optional[str] = attr.ib(None)
description: Optional[str] = attr.ib(None)
info: Optional[str] = attr.ib(None)
value: float | Expression = attr.ib()
component: str | None = attr.ib(None)
description: str | None = attr.ib(None)
info: str | None = attr.ib(None)
symbol: sp.Symbol = attr.ib(None)
unit_str: Optional[str] = attr.ib(None, repr=False)
unit: Optional[pint.Unit] = attr.ib(None)
unit_str: str | None = attr.ib(None, repr=False)
unit: pint.Unit | None = attr.ib(None)

def __attrs_post_init__(self):
if self.unit is None:
Expand Down Expand Up @@ -138,7 +135,7 @@ class Assignment(Atom):
"""Assignments are object of the form `name = value`."""

value: Expression = attr.ib()
expr: Optional[sp.Expr] = attr.ib(None)
expr: sp.Expr | None = attr.ib(None)

def resolve_expression(self, symbols: dict[str, sp.Symbol]) -> Assignment:
expr = self.value.resolve(symbols)
Expand Down
4 changes: 2 additions & 2 deletions src/gotranx/cli/gotran2c.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
import typing
from __future__ import annotations
from pathlib import Path

from ..codegen import CCodeGenerator
from ..load import load_ode


def main(fname: Path, suffix: str = ".h", outname: typing.Optional[str] = None) -> None:
def main(fname: Path, suffix: str = ".h", outname: str | None = None) -> None:
ode = load_ode(fname)
codegen = CCodeGenerator(ode)
code = "\n".join(
Expand Down
2 changes: 1 addition & 1 deletion src/gotranx/codecomponent.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ class CodeComponent:
ode: ODE = attr.ib()
function_name: str = attr.ib()
description: str = attr.ib()
params: typing.Optional[dict[str, typing.Any]] = attr.ib(default=None)
params: dict[str, typing.Any] | None = attr.ib(default=None)


def rhs_expressions(ode: ODE, function_name: str = "rhs", result_name: str = "dy", params=None):
Expand Down
4 changes: 1 addition & 3 deletions src/gotranx/expressions.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
from __future__ import annotations

from typing import Optional

import lark
import sympy as sp

Expand All @@ -27,7 +25,7 @@ def binary_op(op: str, fst, snd):

def build_expression(
root: lark.Tree,
symbols: Optional[dict[str, sp.Symbol]] = None,
symbols: dict[str, sp.Symbol] | None = None,
) -> sp.Expr:
symbols_: dict[str, sp.Symbol] = symbols or {}

Expand Down
4 changes: 2 additions & 2 deletions src/gotranx/load.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import typing
from __future__ import annotations
from pathlib import Path

from structlog import get_logger
Expand All @@ -12,7 +12,7 @@
logger = get_logger()


def load_ode(path: typing.Union[str, Path]):
def load_ode(path: str | Path):
fname = Path(path)

logger.info(f"Load ode {path}")
Expand Down
7 changes: 3 additions & 4 deletions src/gotranx/ode.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from functools import cached_property
from graphlib import TopologicalSorter
from typing import Iterable
from typing import Optional
from typing import Sequence
from typing import TypeVar

Expand Down Expand Up @@ -134,7 +133,7 @@ def resolve_expressions(

def make_ode(
components: Sequence[Component],
comments: Optional[Sequence[atoms.Comment]] = None,
comments: Sequence[atoms.Comment] | None = None,
name: str = "ODE",
) -> ODE:
check_components(components=components)
Expand Down Expand Up @@ -172,9 +171,9 @@ class ODE:
def __init__(
self,
components: Sequence[Component],
t: Optional[sp.Symbol] = None,
t: sp.Symbol | None = None,
name: str = "ODE",
comments: Optional[Sequence[atoms.Comment]] = None,
comments: Sequence[atoms.Comment] | None = None,
):
check_components(components)
symbol_names, symbols, lookup = gather_atoms(components=components)
Expand Down
19 changes: 9 additions & 10 deletions src/gotranx/transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

from collections import defaultdict
from typing import NamedTuple
from typing import Optional
from typing import Type
from typing import TypeVar

Expand All @@ -24,21 +23,21 @@ def remove_quotes(s: str) -> str:
return s.replace("'", "").replace('"', "")


def find_assignment_component(s) -> Optional[str]:
def find_assignment_component(s) -> str | None:
component = None
if isinstance(s, lark.Token) and s.type == "COMPONENT_NAME":
component = remove_quotes(str(s))
return component


def find_assignment_info(s) -> Optional[str]:
def find_assignment_info(s) -> str | None:
info = None
if len(s) > 1 and isinstance(s[1], lark.Token) and s[1].type == "INFO":
info = remove_quotes(str(s[1]))
return info


def get_unit_from_assignment(s: lark.Tree) -> Optional[str]:
def get_unit_from_assignment(s: lark.Tree) -> str | None:
if len(s.children) >= 3:
unit = s.children[2]
try:
Expand All @@ -51,8 +50,8 @@ def get_unit_from_assignment(s: lark.Tree) -> Optional[str]:

def find_assignments(
s,
component: Optional[str] = None,
info: Optional[str] = None,
component: str | None = None,
info: str | None = None,
) -> list[atoms.Assignment]:
if isinstance(s, lark.Tree):
return [
Expand All @@ -70,9 +69,9 @@ def find_assignments(

def tree2parameter(
s: lark.Tree,
component: Optional[str],
component: str | None,
cls: Type[T],
info: Optional[str] = None,
info: str | None = None,
) -> T:
kwargs = {}
if info is not None:
Expand Down Expand Up @@ -177,7 +176,7 @@ def ode(self, s) -> LarkODE:
atoms.State: "states",
}

components: dict[Optional[str], dict[str, set[atoms.Atom]]] = defaultdict(
components: dict[str | None, dict[str, set[atoms.Atom]]] = defaultdict(
lambda: {atom: set() for atom in mapping.values()},
)
comments = []
Expand All @@ -190,7 +189,7 @@ def ode(self, s) -> LarkODE:
components[atom.component][mapping[type(atom)]].add(atom)

# Make sets frozen
frozen_components: dict[Optional[str], dict[str, frozenset[atoms.Atom]]] = {}
frozen_components: dict[str | None, dict[str, frozenset[atoms.Atom]]] = {}
for component_name, component_values in components.items():
frozen_components[component_name] = {}
for atom_name, atom_values in component_values.items():
Expand Down

0 comments on commit 36fa951

Please sign in to comment.