diff --git a/benchmarks/pso.py b/benchmarks/pso.py index 43e6dbfd..9acb77d8 100644 --- a/benchmarks/pso.py +++ b/benchmarks/pso.py @@ -1,11 +1,13 @@ """Benchmark the performance of PSO algorithm in EvoX.""" import time + import torch -from torch.profiler import profile, ProfilerActivity -from evox.core import vmap, Problem, use_state, jit -from evox.workflows import StdWorkflow +from torch.profiler import ProfilerActivity, profile + from evox.algorithms import PSO +from evox.core import Problem, jit, use_state, vmap +from evox.workflows import StdWorkflow def run_pso(): diff --git a/benchmarks/utils.py b/benchmarks/utils.py index aef82cb7..e92c6fe8 100644 --- a/benchmarks/utils.py +++ b/benchmarks/utils.py @@ -1,7 +1,8 @@ """Benchmark the performance of utils functions in EvoX""" import torch -from torch.profiler import profile, ProfilerActivity +from torch.profiler import ProfilerActivity, profile + from evox.core import jit from evox.utils import switch diff --git a/docs/fix_notebook_translation.py b/docs/fix_notebook_translation.py index bc210d93..7baa70e0 100644 --- a/docs/fix_notebook_translation.py +++ b/docs/fix_notebook_translation.py @@ -1,10 +1,11 @@ """ -This script is used to wrap the translated strings in jupyter notebooks in the docs.po file +This script is used to wrap the translated strings in jupyter notebooks in the docs.po """ -import polib -import json import copy +import json + +import polib po = polib.pofile("docs/source/locale/zh/LC_MESSAGES/docs.po") diff --git a/unit_test/algorithms/test_pso.py b/unit_test/algorithms/test_pso.py index 440d0b18..72380627 100644 --- a/unit_test/algorithms/test_pso.py +++ b/unit_test/algorithms/test_pso.py @@ -1,8 +1,10 @@ from unittest import TestCase + import torch -from evox.core import vmap, Problem, use_state, jit -from evox.workflows import StdWorkflow + from evox.algorithms import PSO +from evox.core import Problem, jit, use_state, vmap +from evox.workflows import StdWorkflow class TestPSO(TestCase): diff --git a/unit_test/core/test_jit_util.py b/unit_test/core/test_jit_util.py index 6ce5bc14..aa41e716 100644 --- a/unit_test/core/test_jit_util.py +++ b/unit_test/core/test_jit_util.py @@ -1,7 +1,9 @@ import unittest from functools import partial + import torch -from evox.core import vmap, jit + +from evox.core import jit, vmap @partial(vmap, example_ndim=2) diff --git a/unit_test/core/test_module.py b/unit_test/core/test_module.py index f6966a87..96c5f584 100644 --- a/unit_test/core/test_module.py +++ b/unit_test/core/test_module.py @@ -1,14 +1,14 @@ import unittest +from typing import Dict, List + import torch import torch.nn as nn -from typing import Dict, List -from evox.core import jit_class, ModuleBase, trace_impl, use_state +from evox.core import ModuleBase, jit_class, trace_impl, use_state @jit_class class DummyModule(ModuleBase): - def __init__(self, threshold=0.5): super().__init__() self.threshold = threshold @@ -36,7 +36,6 @@ def g(self, p: torch.Tensor) -> torch.Tensor: class TestModule(unittest.TestCase): - def setUp(self): self.test_instance = DummyModule() diff --git a/unit_test/problems/test_hpo_wrapper.py b/unit_test/problems/test_hpo_wrapper.py index 48a625bf..122a4008 100644 --- a/unit_test/problems/test_hpo_wrapper.py +++ b/unit_test/problems/test_hpo_wrapper.py @@ -1,9 +1,11 @@ import unittest + import torch from torch import nn -from evox.core import jit_class, Problem, Algorithm, trace_impl, Parameter + +from evox.core import Algorithm, Parameter, Problem, jit_class, trace_impl +from evox.problems.hpo_wrapper import HPOFitnessMonitor, HPOProblemWrapper from evox.workflows import StdWorkflow -from evox.problems.hpo_wrapper import HPOProblemWrapper, HPOFitnessMonitor @jit_class diff --git a/unit_test/utils/test_control_flow.py b/unit_test/utils/test_control_flow.py index f0d418df..6c40e9a7 100644 --- a/unit_test/utils/test_control_flow.py +++ b/unit_test/utils/test_control_flow.py @@ -1,7 +1,9 @@ import unittest + import torch -from evox.core import use_state, jit, vmap -from evox.utils import TracingWhile, TracingCond + +from evox.core import jit, use_state, vmap +from evox.utils import TracingCond, TracingWhile class TestControlFlow(unittest.TestCase): diff --git a/unit_test/utils/test_jit_fix.py b/unit_test/utils/test_jit_fix.py index ad5dd5fc..d09b69d5 100644 --- a/unit_test/utils/test_jit_fix.py +++ b/unit_test/utils/test_jit_fix.py @@ -1,5 +1,7 @@ import unittest + import torch + from evox.core import jit from evox.utils import switch diff --git a/unit_test/workflows/test_std_workflow.py b/unit_test/workflows/test_std_workflow.py index 20b9c500..ed867091 100644 --- a/unit_test/workflows/test_std_workflow.py +++ b/unit_test/workflows/test_std_workflow.py @@ -1,21 +1,22 @@ import unittest + import torch import torch.nn as nn + from evox.core import ( - vmap, - trace_impl, - use_state, - jit, - jit_class, Algorithm, Problem, + jit, + jit_class, + trace_impl, + use_state, + vmap, ) -from evox.workflows import StdWorkflow, EvalMonitor +from evox.workflows import EvalMonitor, StdWorkflow @jit_class class BasicProblem(Problem): - def __init__(self): super().__init__() self._eval_fn = vmap(BasicProblem._single_eval, trace=False) @@ -34,7 +35,6 @@ def trace_evaluate(self, pop: torch.Tensor): @jit_class class BasicAlgorithm(Algorithm): - def __init__(self, pop_size: int, lb: torch.Tensor, ub: torch.Tensor): super().__init__() assert ( @@ -73,7 +73,6 @@ def trace_step(self): class TestStdWorkflow(unittest.TestCase): - def setUp(self): torch.set_default_device("cuda" if torch.cuda.is_available() else "cpu") self.algo = BasicAlgorithm(10, -10 * torch.ones(2), 10 * torch.ones(2))