Skip to content

Commit 11b0fc7

Browse files
authored
Merge pull request #375 from apax-hub/py12
Py 3.12 compatibility, rm chex and jaxtyping
2 parents 82d5d6b + f742cef commit 11b0fc7

File tree

7 files changed

+591
-800
lines changed

7 files changed

+591
-800
lines changed

apax/md/md_checkpoint.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,13 @@
11
import logging
2+
import typing as t
23
from pathlib import Path
34

45
from flax.training import checkpoints
5-
from jaxtyping import PyTree
66

77
log = logging.getLogger(__name__)
88

99

10-
def load_md_state(state: PyTree, ckpt_dir: Path) -> tuple[PyTree, int]:
10+
def load_md_state(state: t.Any, ckpt_dir: Path) -> tuple[t.Any, int]:
1111
try:
1212
log.info(f"loading MD state from {ckpt_dir}")
1313
target = {"state": state, "step": 0}

apax/nodes/optimizer/__init__.py

Lines changed: 0 additions & 3 deletions
This file was deleted.

apax/nodes/optimizer/get_optimizer.py

Lines changed: 0 additions & 157 deletions
This file was deleted.

apax/nodes/optimizer/optimizers.py

Lines changed: 0 additions & 84 deletions
This file was deleted.

apax/optimizer/optimizers.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from typing import NamedTuple
22

3-
import chex
3+
import jax
44
import jax.numpy as jnp
55
import optax
66
from jax import tree_util as jtu
@@ -10,8 +10,8 @@
1010

1111

1212
class ScaleByAdemamixState(NamedTuple):
13-
count: chex.Array
14-
count_m2: chex.Array
13+
count: jax.Array
14+
count_m2: jax.Array
1515
m1: base.Updates
1616
m2: base.Updates
1717
nu: base.Updates

0 commit comments

Comments
 (0)