Added pre-commit-hook for formatting code (#599)
* Added `pre-commit-hook` for formatting code
* Removed unnecessary comments
* Mentioned pre-commit hooks in the documentation
david-zwicker authored Aug 17, 2024
1 parent 20bd997 commit f6d0b6c
Showing 36 changed files with 83 additions and 72 deletions.
12 changes: 12 additions & 0 deletions .pre-commit-config.yaml
@@ -0,0 +1,12 @@
+repos:
+  - repo: https://github.com/pre-commit/pre-commit-hooks
+    rev: v2.3.0
+    hooks:
+      - id: check-yaml
+
+  - repo: https://github.com/astral-sh/ruff-pre-commit
+    rev: v0.6.1
+    hooks:
+      - id: ruff
+        args: [--fix, --show-fixes]
+      - id: ruff-format
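The hooks above run a YAML sanity check plus ruff's linter (with autofix) and formatter. Once installed with `pre-commit install`, they run on every commit; below is a minimal sketch of driving them from Python instead of the shell, assuming the `pre-commit` package is available (the helper `run_hooks` is hypothetical and not part of this commit — the usual workflow is simply `pre-commit install` followed by `pre-commit run --all-files`).

import subprocess


def run_hooks(all_files: bool = True) -> int:
    """Run the hooks from .pre-commit-config.yaml and return the exit code."""
    cmd = ["pre-commit", "run"]
    if all_files:
        cmd.append("--all-files")  # check every tracked file, not only staged ones
    return subprocess.run(cmd, check=False).returncode


if __name__ == "__main__":
    raise SystemExit(run_hooks())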
2 changes: 1 addition & 1 deletion docs/source/conf.py
@@ -26,7 +26,7 @@
project = "py-pde"
module_name = "pde"
author = "Zwicker Group"
-copyright = f"{date.today().year}, {author}" # @ReservedAssignment # noqa: A001
+copyright = f"{date.today().year}, {author}" # noqa: A001
html_logo = "_images/logo_small.png"

# Determine the version from the actual package
1 change: 1 addition & 0 deletions docs/source/manual/contributing.rst
@@ -73,6 +73,7 @@ This folder also contain a script :file:`tests_types.sh`, which uses :mod:`mypy`
to check the consistency of the python type annotations.
We use these type annotations for additional documentation and they have also
already been useful for finding some bugs.
+Finally, we have pre-commit hooks, which you should install using `pre-commit install`.

We also have some conventions that should make the package more consistent and
thus easier to use. For instance, we try to use ``properties`` instead of getter
14 changes: 7 additions & 7 deletions pde/__init__.py
@@ -24,14 +24,14 @@
import contextlib

# import all other modules that should occupy the main name space
-from .fields import * # @UnusedWildImport
-from .grids import * # @UnusedWildImport
-from .pdes import * # @UnusedWildImport
-from .solvers import * # @UnusedWildImport
-from .storage import * # @UnusedWildImport
+from .fields import *
+from .grids import *
+from .pdes import *
+from .solvers import *
+from .storage import *
from .tools.parameters import Parameter
-from .trackers import * # @UnusedWildImport
-from .visualization import * # @UnusedWildImport
+from .trackers import *
+from .visualization import *

with contextlib.suppress(ImportError):
from .tools.modelrunner import *
8 changes: 4 additions & 4 deletions pde/fields/base.py
@@ -61,7 +61,7 @@ def __init__(
self.label = label
self._logger = logging.getLogger(self.__class__.__name__)

-def __init_subclass__(cls, **kwargs): # @NoSelf
+def __init_subclass__(cls, **kwargs):
"""Register all subclassess to reconstruct them later."""
super().__init_subclass__(**kwargs)

@@ -359,7 +359,7 @@ def assert_field_compatible(
Determines whether it is acceptable that `other` is an instance of
:class:`~pde.fields.ScalarField`.
"""
-from .scalar import ScalarField # @Reimport
+from .scalar import ScalarField

# check whether they are the same class
is_scalar = accept_scalar and isinstance(other, ScalarField)
@@ -489,7 +489,7 @@ def _binary_operation(

if isinstance(other, FieldBase):
# right operator is a field
-from .scalar import ScalarField # @Reimport
+from .scalar import ScalarField

# determine the dtype of the result of the operation
dtype = np.result_type(self.data, other.data)
@@ -539,7 +539,7 @@ def _binary_operation_inplace(
"""
if isinstance(other, FieldBase):
# right operator is a field
-from .scalar import ScalarField # @Reimport
+from .scalar import ScalarField

if scalar_second:
# right operator must be a scalar
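The `__init_subclass__` hunk at the top of this file registers every subclass so that fields can later be reconstructed from their class name. A minimal sketch of that idiom follows, as an illustration only and not py-pde's actual implementation:

class FieldBase:
    _subclasses: dict[str, type] = {}

    def __init_subclass__(cls, **kwargs):
        """Register all subclasses to reconstruct them later."""
        super().__init_subclass__(**kwargs)
        FieldBase._subclasses[cls.__name__] = cls  # store under the class name


class ScalarField(FieldBase):
    pass


# look up the registered class by name to reconstruct instances later
assert FieldBase._subclasses["ScalarField"] is ScalarField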
4 changes: 1 addition & 3 deletions pde/fields/collection.py
@@ -1041,9 +1041,7 @@ def plot(
kind = [kind] * num_panels
reference = [
field.plot(kind=knd, ax=ax, action="none", **kwargs, **sp_args)
-for field, knd, ax, sp_args in zip( # @UnusedVariable
-self.fields, kind, axs, subplot_args
-)
+for field, knd, ax, sp_args in zip(self.fields, kind, axs, subplot_args)
]

# return the references for all subplots
4 changes: 2 additions & 2 deletions pde/fields/vectorial.py
@@ -205,7 +205,7 @@ def dot(
:class:`~pde.fields.scalar.ScalarField` or
:class:`~pde.fields.vectorial.VectorField`: result of applying the operator
"""
-from .tensorial import Tensor2Field # @Reimport
+from .tensorial import Tensor2Field

# check input
self.grid.assert_grid_compatible(other.grid)
@@ -253,7 +253,7 @@ def outer_product(
Returns:
:class:`~pde.fields.tensorial.Tensor2Field`: result of the operation
"""
-from .tensorial import Tensor2Field # @Reimport
+from .tensorial import Tensor2Field

self.assert_field_compatible(other)

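Both hunks above defer the `Tensor2Field` import used by `dot` and `outer_product`. A short usage sketch, assuming a working py-pde installation:

from pde import UnitGrid, VectorField

grid = UnitGrid([8, 8])
v = VectorField.random_uniform(grid)

scalar = v.dot(v)            # ScalarField: pointwise inner product
tensor = v.outer_product(v)  # Tensor2Field: pointwise outer product
print(scalar.average, tensor.data.shape)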
6 changes: 3 additions & 3 deletions pde/grids/_mesh.py
@@ -732,7 +732,7 @@ def broadcast(self, data: TData) -> TData:
Returns:
The same data, but on all nodes
"""
-from mpi4py.MPI import COMM_WORLD # @UnresolvedImport
+from mpi4py.MPI import COMM_WORLD

return COMM_WORLD.bcast(data, root=0) # type: ignore

@@ -747,7 +747,7 @@ def gather(self, data: TData) -> list[TData] | None:
None on all nodes, except the main node, which receives an ordered list with
the data from all nodes.
"""
-from mpi4py.MPI import COMM_WORLD # @UnresolvedImport
+from mpi4py.MPI import COMM_WORLD

return COMM_WORLD.gather(data, root=0)

@@ -761,7 +761,7 @@ def allgather(self, data: TData) -> list[TData]:
Returns:
list: data from all nodes.
"""
-from mpi4py.MPI import COMM_WORLD # @UnresolvedImport
+from mpi4py.MPI import COMM_WORLD

return COMM_WORLD.allgather(data)

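The three mesh methods above are thin wrappers around mpi4py's pickle-based collectives. A minimal sketch under the assumption that mpi4py is installed and the script is launched under MPI (e.g. `mpiexec -n 2 python script.py`):

from mpi4py.MPI import COMM_WORLD

data = {"node": COMM_WORLD.rank}

on_all = COMM_WORLD.bcast(data, root=0)    # root's object, replicated on every node
on_root = COMM_WORLD.gather(data, root=0)  # list of all objects on the root node, None elsewhere
everywhere = COMM_WORLD.allgather(data)    # list of all objects on every node

if COMM_WORLD.rank == 0:
    print(on_all, on_root, everywhere)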
8 changes: 4 additions & 4 deletions pde/grids/base.py
@@ -160,7 +160,7 @@ def __init__(self) -> None:
self.axes = [self.c.axes[i] for i in self._axes_described]
self.axes_symmetric = [self.c.axes[i] for i in self.axes_symmetric] # type: ignore

-def __init_subclass__(cls, **kwargs) -> None: # @NoSelf
+def __init_subclass__(cls, **kwargs) -> None:
"""Register all subclassess to reconstruct them later."""
super().__init_subclass__(**kwargs)
if cls is not GridBase:
@@ -1066,7 +1066,7 @@ def get_boundary_conditions(
PeriodicityError:
If the boundaries are not compatible with the periodic axes of the grid.
"""
-from .boundaries import Boundaries # @Reimport
+from .boundaries import Boundaries

if self._mesh is None:
# get boundary conditions for a simple grid that is not part of a mesh
@@ -1226,7 +1226,7 @@ def register_operator(factor_func_arg: OperatorFactory):

@hybridmethod # type: ignore
@property
-def operators(cls) -> set[str]: # @NoSelf
+def operators(cls) -> set[str]:
"""set: all operators defined for this class"""
result = set()
# add all customly defined operators
@@ -1574,7 +1574,7 @@ def integrate(

else:
# we are in a parallel run, so we need to gather the sub-integrals from all
-from mpi4py.MPI import COMM_WORLD # @UnresolvedImport
+from mpi4py.MPI import COMM_WORLD

integral_full = np.empty_like(integral)
COMM_WORLD.Allreduce(integral, integral_full)
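In the `integrate` hunk, every MPI node computes a sub-integral and the buffer-based `Allreduce` sums them so that all nodes receive the full value. A small sketch under the same assumption (mpi4py installed, script run under MPI):

import numpy as np
from mpi4py.MPI import COMM_WORLD

partial = np.array([float(COMM_WORLD.rank + 1)])  # this node's sub-integral
total = np.empty_like(partial)
COMM_WORLD.Allreduce(partial, total)              # default reduction is SUM; result lands on all nodes
print(COMM_WORLD.rank, total)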
2 changes: 1 addition & 1 deletion pde/grids/boundaries/local.py
@@ -289,7 +289,7 @@ def __init__(self, grid: GridBase, axis: int, upper: bool, *, rank: int = 0):

self._logger = logging.getLogger(self.__class__.__name__)

-def __init_subclass__(cls, **kwargs): # @NoSelf
+def __init_subclass__(cls, **kwargs):
"""Register all subclasses to reconstruct them later."""
super().__init_subclass__(**kwargs)

2 changes: 1 addition & 1 deletion pde/grids/cylindrical.py
@@ -434,7 +434,7 @@ def slice(self, indices: Sequence[int]) -> CartesianGrid | PolarSymGrid:

if indices[0] == 0:
# return a radial grid
-from .spherical import PolarSymGrid # @Reimport
+from .spherical import PolarSymGrid

return PolarSymGrid(self.radius, self.shape[0])

4 changes: 2 additions & 2 deletions pde/grids/operators/cartesian.py
@@ -28,8 +28,8 @@
from ...tools.typing import OperatorType
from ..boundaries import Boundaries
from ..cartesian import CartesianGrid
-from .common import make_derivative as _make_derivative # @UnusedImport
-from .common import make_derivative2 as _make_derivative2 # @UnusedImport
+from .common import make_derivative as _make_derivative
+from .common import make_derivative2 as _make_derivative2
from .common import make_general_poisson_solver, uniform_discretization

# The `make_derivative?` methods are imported for backward compatibility. Their usage is
2 changes: 1 addition & 1 deletion pde/pdes/base.py
@@ -627,7 +627,7 @@ def solve(
the current node is not the main MPI node.
"""
from ..solvers import Controller
-from ..solvers.base import SolverBase # @Reimport
+from ..solvers.base import SolverBase

# create solver instance
if callable(solver):
2 changes: 1 addition & 1 deletion pde/pdes/laplace.py
@@ -7,7 +7,7 @@

from ..fields import ScalarField
from ..grids.base import GridBase
-from ..grids.boundaries.axes import BoundariesData # @UnusedImport
+from ..grids.boundaries.axes import BoundariesData
from ..tools.docstrings import fill_in_docstring


6 changes: 3 additions & 3 deletions pde/pdes/pde.py
@@ -12,7 +12,7 @@

import numba as nb
import numpy as np
-from numba.typed import Dict as NumbaDict # @UnresolvedImport
+from numba.typed import Dict as NumbaDict
from sympy import Symbol
from sympy.core.function import UndefinedFunction

@@ -329,13 +329,13 @@ def _compile_rhs_single(
# extend the signature
signature += tuple(state.grid.axes)
# inject the spatial coordinates into the expression for the rhs
-extra_args = tuple( # @UnusedVariable
+extra_args = tuple(
state.grid.cell_coords[..., i] for i in range(state.grid.num_axes)
)

else:
# expression only depends on the actual variables
-extra_args = () # @UnusedVariable
+extra_args = ()

# check whether all variables are accounted for
extra_vars = set(expr.vars) - set(signature)
4 changes: 2 additions & 2 deletions pde/solvers/base.py
@@ -64,7 +64,7 @@ def __init__(self, pde: PDEBase, *, backend: BackendType = "auto"):
self.info["pde_class"] = self.pde.__class__.__name__
self._logger = logging.getLogger(self.__class__.__name__)

-def __init_subclass__(cls, **kwargs): # @NoSelf
+def __init_subclass__(cls, **kwargs):
"""Register all subclassess to reconstruct them later."""
super().__init_subclass__(**kwargs)
if not isabstract(cls):
@@ -113,7 +113,7 @@ def from_name(cls, name: str, pde: PDEBase, **kwargs) -> SolverBase:
return solver_class(pde, **kwargs)

@classproperty
-def registered_solvers(cls) -> list[str]: # @NoSelf
+def registered_solvers(cls) -> list[str]:
"""list of str: the names of the registered solvers"""
return sorted(cls._subclasses.keys())

8 changes: 4 additions & 4 deletions pde/storage/base.py
@@ -385,7 +385,7 @@ def extract_field(
:class:`MemoryStorage`: a storage instance that contains the data for the
single field
"""
-from .memory import MemoryStorage # @Reimport
+from .memory import MemoryStorage

if self._field is None:
self._init_field()
@@ -435,7 +435,7 @@ def extract_time_range(
Returns:
:class:`MemoryStorage`: a storage instance that contains the extracted data.
"""
-from .memory import MemoryStorage # @Reimport
+from .memory import MemoryStorage

# get the time bracket
try:
@@ -502,7 +502,7 @@ def apply(
raise TypeError("The user function must return a field")

if out is None:
-from .memory import MemoryStorage # @Reimport
+from .memory import MemoryStorage

out = MemoryStorage(field_obj=transformed)

@@ -517,7 +517,7 @@

# make sure that a storage is returned, even when no fields are present
if out is None:
-from .memory import MemoryStorage # @Reimport
+from .memory import MemoryStorage

out = MemoryStorage()

2 changes: 1 addition & 1 deletion pde/tools/cache.py
@@ -281,7 +281,7 @@ def make_unserializer(method: SerializerMethod) -> Callable:
return yaml.full_load

if method == "yaml_unsafe":
-import yaml # @Reimport
+import yaml

return yaml.unsafe_load

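For context, `make_unserializer` maps each serialization method to a loader (the hunk shows `yaml.full_load` next to the `yaml_unsafe` branch returning `yaml.unsafe_load`). A tiny round-trip sketch, assuming PyYAML is installed:

import yaml

blob = yaml.dump({"grid": [8, 8], "dt": 0.1})  # serialize to a YAML string
data = yaml.full_load(blob)                    # loader shown in the hunk above
assert data == {"grid": [8, 8], "dt": 0.1}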
8 changes: 3 additions & 5 deletions pde/tools/expressions.py
@@ -454,7 +454,7 @@ def compile_func(func):
# partial function instead of replacing the constants in the sympy expression
# directly since sympy does not work well with numpy arrays.
if constants:
-const_values = tuple(self.consts[c] for c in constants) # @UnusedVariable
+const_values = tuple(self.consts[c] for c in constants)

if prepare_compilation:
func = jit(func)
@@ -1099,13 +1099,11 @@ def evaluate(
# extend the signature
signature += tuple(grid.axes)
# inject the spatial coordinates into the expression for the rhs
-extra_args = tuple( # @UnusedVariable
-grid.cell_coords[..., i] for i in range(grid.num_axes)
-)
+extra_args = tuple(grid.cell_coords[..., i] for i in range(grid.num_axes))

else:
# expression only depends on the actual variables
-extra_args = () # @UnusedVariable
+extra_args = ()

# check whether all variables are accounted for
extra_vars = set(expr.vars) - set(signature)
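The `evaluate` hunk appends the grid's cell coordinates as extra arguments whenever the expression depends on the spatial variables. A small sketch of what those extra arguments look like, assuming py-pde is installed (the expression itself is made up for illustration):

import numpy as np
from pde import ScalarField, UnitGrid

grid = UnitGrid([4, 4])
state = ScalarField(grid, 1.0)

extra_args = tuple(grid.cell_coords[..., i] for i in range(grid.num_axes))
x, y = extra_args                 # one coordinate array per axis, shaped like the field data
rhs = state.data * np.sin(x) + y  # evaluate the made-up expression u * sin(x) + y
print(rhs.shape)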
12 changes: 6 additions & 6 deletions pde/tools/numba.py
@@ -14,7 +14,7 @@
import numpy as np
from numba.core.types import npytypes, scalars
from numba.extending import overload, register_jitable
-from numba.typed import Dict as NumbaDict # @UnresolvedImport
+from numba.typed import Dict as NumbaDict

from .. import config
from ..tools.misc import decorator_arguments
@@ -128,12 +128,12 @@ def f():
"multithreading_threshold": config["numba.multithreading_threshold"],
"fastmath": config["numba.fastmath"],
"debug": config["numba.debug"],
-"using_svml": nb.config.USING_SVML, # @UndefinedVariable
+"using_svml": nb.config.USING_SVML,
"threading_layer": threading_layer,
"omp_num_threads": os.environ.get("OMP_NUM_THREADS"),
"mkl_num_threads": os.environ.get("MKL_NUM_THREADS"),
-"num_threads": nb.config.NUMBA_NUM_THREADS, # @UndefinedVariable
-"num_threads_default": nb.config.NUMBA_DEFAULT_NUM_THREADS, # @UndefinedVariable
+"num_threads": nb.config.NUMBA_NUM_THREADS,
+"num_threads_default": nb.config.NUMBA_DEFAULT_NUM_THREADS,
"cuda_available": cuda_available,
"roc_available": roc_available,
}
@@ -203,7 +203,7 @@ def jit(function: TFunc, signature=None, parallel: bool = False, **kwargs) -> TF
return nb.jit(signature, **kwargs)(function) # type: ignore


-if nb.config.DISABLE_JIT: # @UndefinedVariable
+if nb.config.DISABLE_JIT:
# dummy function that creates a ctypes pointer
def address_as_void_pointer(addr):
"""Returns a void pointer from a given memory address.
@@ -321,7 +321,7 @@ def random_seed(seed: int = 0) -> None:
seed (int): Sets random seed
"""
np.random.seed(seed)
-if not nb.config.DISABLE_JIT: # @UndefinedVariable
+if not nb.config.DISABLE_JIT:
_random_seed_compiled(seed)


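The numba hunks drop the Eclipse-style `@UnresolvedImport`/`@UndefinedVariable` markers around `numba.typed.Dict` and the `nb.config` flags. A brief sketch of both pieces, assuming numba is installed:

import numba as nb
from numba import types
from numba.typed import Dict as NumbaDict

consts = NumbaDict.empty(key_type=types.unicode_type, value_type=types.float64)
consts["dx"] = 0.1  # typed dicts can be passed into jit-compiled functions

print(nb.config.DISABLE_JIT)        # truthy when the NUMBA_DISABLE_JIT environment variable is set
print(nb.config.NUMBA_NUM_THREADS)  # number of threads available to parallel kernels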