Skip to content

Commit

Permalink
Format code
Browse files Browse the repository at this point in the history
  • Loading branch information
BillHuang2001 committed Jan 6, 2025
1 parent f8280ab commit 8d415de
Show file tree
Hide file tree
Showing 10 changed files with 36 additions and 25 deletions.
8 changes: 5 additions & 3 deletions benchmarks/pso.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
"""Benchmark the performance of PSO algorithm in EvoX."""

import time

import torch
from torch.profiler import profile, ProfilerActivity
from evox.core import vmap, Problem, use_state, jit
from evox.workflows import StdWorkflow
from torch.profiler import ProfilerActivity, profile

from evox.algorithms import PSO
from evox.core import Problem, jit, use_state, vmap
from evox.workflows import StdWorkflow


def run_pso():
Expand Down
3 changes: 2 additions & 1 deletion benchmarks/utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
"""Benchmark the performance of utils functions in EvoX"""

import torch
from torch.profiler import profile, ProfilerActivity
from torch.profiler import ProfilerActivity, profile

from evox.core import jit
from evox.utils import switch

Expand Down
7 changes: 4 additions & 3 deletions docs/fix_notebook_translation.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
"""
This script is used to wrap the translated strings in jupyter notebooks in the docs.po file
This script is used to wrap the translated strings in jupyter notebooks in the docs.po
"""

import polib
import json
import copy
import json

import polib

po = polib.pofile("docs/source/locale/zh/LC_MESSAGES/docs.po")

Expand Down
6 changes: 4 additions & 2 deletions unit_test/algorithms/test_pso.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
from unittest import TestCase

import torch
from evox.core import vmap, Problem, use_state, jit
from evox.workflows import StdWorkflow

from evox.algorithms import PSO
from evox.core import Problem, jit, use_state, vmap
from evox.workflows import StdWorkflow


class TestPSO(TestCase):
Expand Down
4 changes: 3 additions & 1 deletion unit_test/core/test_jit_util.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import unittest
from functools import partial

import torch
from evox.core import vmap, jit

from evox.core import jit, vmap


@partial(vmap, example_ndim=2)
Expand Down
7 changes: 3 additions & 4 deletions unit_test/core/test_module.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
import unittest
from typing import Dict, List

import torch
import torch.nn as nn
from typing import Dict, List

from evox.core import jit_class, ModuleBase, trace_impl, use_state
from evox.core import ModuleBase, jit_class, trace_impl, use_state


@jit_class
class DummyModule(ModuleBase):

def __init__(self, threshold=0.5):
super().__init__()
self.threshold = threshold
Expand Down Expand Up @@ -36,7 +36,6 @@ def g(self, p: torch.Tensor) -> torch.Tensor:


class TestModule(unittest.TestCase):

def setUp(self):
self.test_instance = DummyModule()

Expand Down
6 changes: 4 additions & 2 deletions unit_test/problems/test_hpo_wrapper.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
import unittest

import torch
from torch import nn
from evox.core import jit_class, Problem, Algorithm, trace_impl, Parameter

from evox.core import Algorithm, Parameter, Problem, jit_class, trace_impl
from evox.problems.hpo_wrapper import HPOFitnessMonitor, HPOProblemWrapper
from evox.workflows import StdWorkflow
from evox.problems.hpo_wrapper import HPOProblemWrapper, HPOFitnessMonitor


@jit_class
Expand Down
1 change: 1 addition & 0 deletions unit_test/utils/test_control_flow.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import unittest

import torch

from evox.core import ModuleBase, jit, jit_class, trace_impl, use_state, vmap
Expand Down
2 changes: 2 additions & 0 deletions unit_test/utils/test_jit_fix.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import unittest

import torch

from evox.core import jit
from evox.utils import switch

Expand Down
17 changes: 8 additions & 9 deletions unit_test/workflows/test_std_workflow.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,22 @@
import unittest

import torch
import torch.nn as nn

from evox.core import (
vmap,
trace_impl,
use_state,
jit,
jit_class,
Algorithm,
Problem,
jit,
jit_class,
trace_impl,
use_state,
vmap,
)
from evox.workflows import StdWorkflow, EvalMonitor
from evox.workflows import EvalMonitor, StdWorkflow


@jit_class
class BasicProblem(Problem):

def __init__(self):
super().__init__()
self._eval_fn = vmap(BasicProblem._single_eval, trace=False)
Expand All @@ -34,7 +35,6 @@ def trace_evaluate(self, pop: torch.Tensor):

@jit_class
class BasicAlgorithm(Algorithm):

def __init__(self, pop_size: int, lb: torch.Tensor, ub: torch.Tensor):
super().__init__()
assert (
Expand Down Expand Up @@ -73,7 +73,6 @@ def trace_step(self):


class TestStdWorkflow(unittest.TestCase):

def setUp(self):
torch.set_default_device("cuda" if torch.cuda.is_available() else "cpu")
self.algo = BasicAlgorithm(10, -10 * torch.ones(2), 10 * torch.ones(2))
Expand Down

0 comments on commit 8d415de

Please sign in to comment.