Skip to content

Commit

Permalink
Format code
Browse files Browse the repository at this point in the history
  • Loading branch information
BillHuang2001 committed Jan 6, 2025
1 parent 3d07d1e commit f5aa061
Show file tree
Hide file tree
Showing 10 changed files with 39 additions and 27 deletions.
8 changes: 5 additions & 3 deletions benchmarks/pso.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
"""Benchmark the performance of PSO algorithm in EvoX."""

import time

import torch
from torch.profiler import profile, ProfilerActivity
from evox.core import vmap, Problem, use_state, jit
from evox.workflows import StdWorkflow
from torch.profiler import ProfilerActivity, profile

from evox.algorithms import PSO
from evox.core import Problem, jit, use_state, vmap
from evox.workflows import StdWorkflow


def run_pso():

Check failure on line 13 in benchmarks/pso.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (I001)

benchmarks/pso.py:3:1: I001 Import block is un-sorted or un-formatted
Expand Down
3 changes: 2 additions & 1 deletion benchmarks/utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
"""Benchmark the performance of utils functions in EvoX"""

import torch
from torch.profiler import profile, ProfilerActivity
from torch.profiler import ProfilerActivity, profile

from evox.core import jit
from evox.utils import switch

Expand Down
7 changes: 4 additions & 3 deletions docs/fix_notebook_translation.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
"""
This script is used to wrap the translated strings in jupyter notebooks in the docs.po file
This script is used to wrap the translated strings in jupyter notebooks in the docs.po
"""

import polib
import json
import copy
import json

import polib

po = polib.pofile("docs/source/locale/zh/LC_MESSAGES/docs.po")

Expand Down
6 changes: 4 additions & 2 deletions unit_test/algorithms/test_pso.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
from unittest import TestCase

import torch
from evox.core import vmap, Problem, use_state, jit
from evox.workflows import StdWorkflow

from evox.algorithms import PSO
from evox.core import Problem, jit, use_state, vmap
from evox.workflows import StdWorkflow


class TestPSO(TestCase):

Check failure on line 10 in unit_test/algorithms/test_pso.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (I001)

unit_test/algorithms/test_pso.py:1:1: I001 Import block is un-sorted or un-formatted
Expand Down
4 changes: 3 additions & 1 deletion unit_test/core/test_jit_util.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import unittest
from functools import partial

import torch
from evox.core import vmap, jit

from evox.core import jit, vmap


@partial(vmap, example_ndim=2)

Check failure on line 9 in unit_test/core/test_jit_util.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (I001)

unit_test/core/test_jit_util.py:1:1: I001 Import block is un-sorted or un-formatted
Expand Down
7 changes: 3 additions & 4 deletions unit_test/core/test_module.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
import unittest
from typing import Dict, List

import torch
import torch.nn as nn
from typing import Dict, List

from evox.core import jit_class, ModuleBase, trace_impl, use_state
from evox.core import ModuleBase, jit_class, trace_impl, use_state


@jit_class

Check failure on line 10 in unit_test/core/test_module.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (I001)

unit_test/core/test_module.py:1:1: I001 Import block is un-sorted or un-formatted
class DummyModule(ModuleBase):

def __init__(self, threshold=0.5):
super().__init__()
self.threshold = threshold
Expand Down Expand Up @@ -36,7 +36,6 @@ def g(self, p: torch.Tensor) -> torch.Tensor:


class TestModule(unittest.TestCase):

def setUp(self):
self.test_instance = DummyModule()

Expand Down
6 changes: 4 additions & 2 deletions unit_test/problems/test_hpo_wrapper.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
import unittest

import torch
from torch import nn
from evox.core import jit_class, Problem, Algorithm, trace_impl, Parameter

from evox.core import Algorithm, Parameter, Problem, jit_class, trace_impl
from evox.problems.hpo_wrapper import HPOFitnessMonitor, HPOProblemWrapper
from evox.workflows import StdWorkflow
from evox.problems.hpo_wrapper import HPOProblemWrapper, HPOFitnessMonitor


@jit_class

Check failure on line 11 in unit_test/problems/test_hpo_wrapper.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (I001)

unit_test/problems/test_hpo_wrapper.py:1:1: I001 Import block is un-sorted or un-formatted
Expand Down
6 changes: 4 additions & 2 deletions unit_test/utils/test_control_flow.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import unittest

import torch
from evox.core import use_state, jit, vmap
from evox.utils import TracingWhile, TracingCond

from evox.core import jit, use_state, vmap
from evox.utils import TracingCond, TracingWhile


class TestControlFlow(unittest.TestCase):

Check failure on line 9 in unit_test/utils/test_control_flow.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (I001)

unit_test/utils/test_control_flow.py:1:1: I001 Import block is un-sorted or un-formatted
Expand Down
2 changes: 2 additions & 0 deletions unit_test/utils/test_jit_fix.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import unittest

import torch

from evox.core import jit
from evox.utils import switch

Expand Down
17 changes: 8 additions & 9 deletions unit_test/workflows/test_std_workflow.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,22 @@
import unittest

import torch
import torch.nn as nn

from evox.core import (
vmap,
trace_impl,
use_state,
jit,
jit_class,
Algorithm,
Problem,
jit,
jit_class,
trace_impl,
use_state,
vmap,
)
from evox.workflows import StdWorkflow, EvalMonitor
from evox.workflows import EvalMonitor, StdWorkflow


@jit_class

Check failure on line 18 in unit_test/workflows/test_std_workflow.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (I001)

unit_test/workflows/test_std_workflow.py:1:1: I001 Import block is un-sorted or un-formatted
class BasicProblem(Problem):

def __init__(self):
super().__init__()
self._eval_fn = vmap(BasicProblem._single_eval, trace=False)
Expand All @@ -34,7 +35,6 @@ def trace_evaluate(self, pop: torch.Tensor):

@jit_class
class BasicAlgorithm(Algorithm):

def __init__(self, pop_size: int, lb: torch.Tensor, ub: torch.Tensor):
super().__init__()
assert (
Expand Down Expand Up @@ -73,7 +73,6 @@ def trace_step(self):


class TestStdWorkflow(unittest.TestCase):

def setUp(self):
torch.set_default_device("cuda" if torch.cuda.is_available() else "cpu")
self.algo = BasicAlgorithm(10, -10 * torch.ones(2), 10 * torch.ones(2))
Expand Down

0 comments on commit f5aa061

Please sign in to comment.