diff --git a/benchmarks/pso.py b/benchmarks/pso.py index 43e6dbfde..9acb77d8e 100644 --- a/benchmarks/pso.py +++ b/benchmarks/pso.py @@ -1,11 +1,13 @@ """Benchmark the performance of PSO algorithm in EvoX.""" import time + import torch -from torch.profiler import profile, ProfilerActivity -from evox.core import vmap, Problem, use_state, jit -from evox.workflows import StdWorkflow +from torch.profiler import ProfilerActivity, profile + from evox.algorithms import PSO +from evox.core import Problem, jit, use_state, vmap +from evox.workflows import StdWorkflow def run_pso(): diff --git a/benchmarks/utils.py b/benchmarks/utils.py index aef82cb7e..e92c6fe81 100644 --- a/benchmarks/utils.py +++ b/benchmarks/utils.py @@ -1,7 +1,8 @@ """Benchmark the performance of utils functions in EvoX""" import torch -from torch.profiler import profile, ProfilerActivity +from torch.profiler import ProfilerActivity, profile + from evox.core import jit from evox.utils import switch diff --git a/docs/fix_notebook_translation.py b/docs/fix_notebook_translation.py index bc210d93c..7baa70e05 100644 --- a/docs/fix_notebook_translation.py +++ b/docs/fix_notebook_translation.py @@ -1,10 +1,11 @@ """ -This script is used to wrap the translated strings in jupyter notebooks in the docs.po file +This script is used to wrap the translated strings in jupyter notebooks in the docs.po """ -import polib -import json import copy +import json + +import polib po = polib.pofile("docs/source/locale/zh/LC_MESSAGES/docs.po") diff --git a/unit_test/algorithms/test_pso.py b/unit_test/algorithms/test_pso.py index 440d0b18f..723806273 100644 --- a/unit_test/algorithms/test_pso.py +++ b/unit_test/algorithms/test_pso.py @@ -1,8 +1,10 @@ from unittest import TestCase + import torch -from evox.core import vmap, Problem, use_state, jit -from evox.workflows import StdWorkflow + from evox.algorithms import PSO +from evox.core import Problem, jit, use_state, vmap +from evox.workflows import StdWorkflow class TestPSO(TestCase): diff --git a/unit_test/core/test_jit_util.py b/unit_test/core/test_jit_util.py index 1d2f0c90d..67217df36 100644 --- a/unit_test/core/test_jit_util.py +++ b/unit_test/core/test_jit_util.py @@ -1,7 +1,9 @@ import unittest from functools import partial + import torch -from evox.core import vmap, jit + +from evox.core import jit, vmap @partial(vmap, example_ndim=2) diff --git a/unit_test/core/test_module.py b/unit_test/core/test_module.py index f6966a87f..96c5f584b 100644 --- a/unit_test/core/test_module.py +++ b/unit_test/core/test_module.py @@ -1,14 +1,14 @@ import unittest +from typing import Dict, List + import torch import torch.nn as nn -from typing import Dict, List -from evox.core import jit_class, ModuleBase, trace_impl, use_state +from evox.core import ModuleBase, jit_class, trace_impl, use_state @jit_class class DummyModule(ModuleBase): - def __init__(self, threshold=0.5): super().__init__() self.threshold = threshold @@ -36,7 +36,6 @@ def g(self, p: torch.Tensor) -> torch.Tensor: class TestModule(unittest.TestCase): - def setUp(self): self.test_instance = DummyModule() diff --git a/unit_test/problems/test_hpo_wrapper.py b/unit_test/problems/test_hpo_wrapper.py index a520d9712..707c533ef 100644 --- a/unit_test/problems/test_hpo_wrapper.py +++ b/unit_test/problems/test_hpo_wrapper.py @@ -1,9 +1,11 @@ import unittest + import torch from torch import nn -from evox.core import jit_class, Problem, Algorithm, trace_impl, Parameter + +from evox.core import Algorithm, Parameter, Problem, jit_class, trace_impl +from evox.problems.hpo_wrapper import HPOFitnessMonitor, HPOProblemWrapper from evox.workflows import StdWorkflow -from evox.problems.hpo_wrapper import HPOProblemWrapper, HPOFitnessMonitor @jit_class diff --git a/unit_test/utils/test_control_flow.py b/unit_test/utils/test_control_flow.py index e172a38e7..a0d144f6c 100644 --- a/unit_test/utils/test_control_flow.py +++ b/unit_test/utils/test_control_flow.py @@ -1,4 +1,5 @@ import unittest + import torch from evox.core import ModuleBase, jit, jit_class, trace_impl, use_state, vmap diff --git a/unit_test/utils/test_jit_fix.py b/unit_test/utils/test_jit_fix.py index 56a05b7ad..4171942dc 100644 --- a/unit_test/utils/test_jit_fix.py +++ b/unit_test/utils/test_jit_fix.py @@ -1,5 +1,7 @@ import unittest + import torch + from evox.core import jit from evox.utils import switch diff --git a/unit_test/workflows/test_std_workflow.py b/unit_test/workflows/test_std_workflow.py index 20b9c5009..ed8670910 100644 --- a/unit_test/workflows/test_std_workflow.py +++ b/unit_test/workflows/test_std_workflow.py @@ -1,21 +1,22 @@ import unittest + import torch import torch.nn as nn + from evox.core import ( - vmap, - trace_impl, - use_state, - jit, - jit_class, Algorithm, Problem, + jit, + jit_class, + trace_impl, + use_state, + vmap, ) -from evox.workflows import StdWorkflow, EvalMonitor +from evox.workflows import EvalMonitor, StdWorkflow @jit_class class BasicProblem(Problem): - def __init__(self): super().__init__() self._eval_fn = vmap(BasicProblem._single_eval, trace=False) @@ -34,7 +35,6 @@ def trace_evaluate(self, pop: torch.Tensor): @jit_class class BasicAlgorithm(Algorithm): - def __init__(self, pop_size: int, lb: torch.Tensor, ub: torch.Tensor): super().__init__() assert ( @@ -73,7 +73,6 @@ def trace_step(self): class TestStdWorkflow(unittest.TestCase): - def setUp(self): torch.set_default_device("cuda" if torch.cuda.is_available() else "cpu") self.algo = BasicAlgorithm(10, -10 * torch.ones(2), 10 * torch.ones(2))