Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Allow gradient transform parameters to be dynamic #516

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
110 changes: 55 additions & 55 deletions optax/_src/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ class TraceState(NamedTuple):


def trace(
decay: float,
decay: Union[float, jax.Array],
nesterov: bool = False,
accumulator_dtype: Optional[Any] = None,
) -> base.GradientTransformation:
Expand Down Expand Up @@ -134,7 +134,7 @@ class EmaState(NamedTuple):


def ema(
decay: float,
decay: Union[float, jax.Array],
debias: bool = True,
accumulator_dtype: Optional[Any] = None
) -> base.GradientTransformation:
Expand Down Expand Up @@ -180,8 +180,8 @@ class ScaleByRssState(NamedTuple):


def scale_by_rss(
initial_accumulator_value: float = 0.1,
eps: float = 1e-7
initial_accumulator_value: Union[float, jax.Array] = 0.1,
eps: Union[float, jax.Array] = 1e-7
) -> base.GradientTransformation:
"""Rescale updates by the root of the sum of all squared gradients to date.

Expand Down Expand Up @@ -221,9 +221,9 @@ class ScaleByRmsState(NamedTuple):


def scale_by_rms(
decay: float = 0.9,
eps: float = 1e-8,
initial_scale: float = 0.
decay: Union[float, jax.Array] = 0.9,
eps: Union[float, jax.Array] = 1e-8,
initial_scale: Union[float, jax.Array] = 0.
) -> base.GradientTransformation:
"""Rescale updates by the root of the exp. moving avg of the square.

Expand Down Expand Up @@ -261,9 +261,9 @@ class ScaleByRStdDevState(NamedTuple):


def scale_by_stddev(
decay: float = 0.9,
eps: float = 1e-8,
initial_scale: float = 0.
decay: Union[float, jax.Array] = 0.9,
eps: Union[float, jax.Array] = 1e-8,
initial_scale: Union[float, jax.Array] = 0.
) -> base.GradientTransformation:
"""Rescale updates by the root of the centered exp. moving average of squares.

Expand Down Expand Up @@ -305,10 +305,10 @@ class ScaleByAdamState(NamedTuple):


def scale_by_adam(
b1: float = 0.9,
b2: float = 0.999,
eps: float = 1e-8,
eps_root: float = 0.0,
b1: Union[float, jax.Array] = 0.9,
b2: Union[float, jax.Array] = 0.999,
eps: Union[float, jax.Array] = 1e-8,
eps_root: Union[float, jax.Array] = 0.0,
mu_dtype: Optional[chex.ArrayDType] = None,
) -> base.GradientTransformation:
"""Rescale updates according to the Adam algorithm.
Expand Down Expand Up @@ -361,10 +361,10 @@ class ScaleByAmsgradState(NamedTuple):


def scale_by_amsgrad(
b1: float = 0.9,
b2: float = 0.999,
eps: float = 1e-8,
eps_root: float = 0.0,
b1: Union[float, jax.Array] = 0.9,
b2: Union[float, jax.Array] = 0.999,
eps: Union[float, jax.Array] = 1e-8,
eps_root: Union[float, jax.Array] = 0.0,
mu_dtype: Optional[chex.ArrayDType] = None,
) -> base.GradientTransformation:
"""Rescale updates according to the AMSGrad algorithm.
Expand Down Expand Up @@ -413,9 +413,9 @@ def update_fn(updates, state, params=None):


def scale_by_adamax(
b1: float = 0.9,
b2: float = 0.999,
eps: float = 1e-8
b1: Union[float, jax.Array] = 0.9,
b2: Union[float, jax.Array] = 0.999,
eps: Union[float, jax.Array] = 1e-8
) -> base.GradientTransformation:
"""Rescale updates according to the Adamax algorithm.

Expand Down Expand Up @@ -456,8 +456,8 @@ class ScaleByLionState(NamedTuple):


def scale_by_lion(
b1: float = 0.9,
b2: float = 0.99,
b1: Union[float, jax.Array] = 0.9,
b2: Union[float, jax.Array] = 0.99,
mu_dtype: Optional[chex.ArrayDType] = None,
) -> base.GradientTransformation:
"""Rescale updates according to the Lion algorithm.
Expand Down Expand Up @@ -498,7 +498,7 @@ def update_fn(updates, state, params=None):


def scale(
step_size: float
step_size: Union[float, jax.Array]
) -> base.GradientTransformation:
"""Scale updates by some fixed scalar `step_size`.

Expand All @@ -522,7 +522,7 @@ def update_fn(updates, state, params=None):


def scale_by_param_block_norm(
min_scale: float = 1e-3
min_scale: Union[float, jax.Array] = 1e-3
) -> base.GradientTransformation:
"""Scale updates for each param block by the norm of that block's parameters.

Expand Down Expand Up @@ -552,7 +552,7 @@ def update_fn(updates, state, params):


def scale_by_param_block_rms(
min_scale: float = 1e-3
min_scale: Union[float, jax.Array] = 1e-3
) -> base.GradientTransformation:
"""Scale updates by rms of the gradient for each param vector or matrix.

Expand Down Expand Up @@ -589,10 +589,10 @@ class ScaleByBeliefState(NamedTuple):


def scale_by_belief(
b1: float = 0.9,
b2: float = 0.999,
eps: float = 1e-16,
eps_root: float = 1e-16
b1: Union[float, jax.Array] = 0.9,
b2: Union[float, jax.Array] = 0.999,
eps: Union[float, jax.Array] = 1e-16,
eps_root: Union[float, jax.Array] = 1e-16
) -> base.GradientTransformation:
"""Rescale updates according to the AdaBelief algorithm.

Expand Down Expand Up @@ -634,11 +634,11 @@ def update_fn(updates, state, params=None):


def scale_by_yogi(
b1: float = 0.9,
b2: float = 0.999,
eps: float = 1e-3,
eps_root: float = 0.0,
initial_accumulator_value: float = 1e-6
b1: Union[float, jax.Array] = 0.9,
b2: Union[float, jax.Array] = 0.999,
eps: Union[float, jax.Array] = 1e-3,
eps_root: Union[float, jax.Array] = 0.0,
initial_accumulator_value: Union[float, jax.Array] = 1e-6
) -> base.GradientTransformation:
"""Rescale updates according to the Yogi algorithm.

Expand Down Expand Up @@ -684,11 +684,11 @@ def update_fn(updates, state, params=None):


def scale_by_radam(
b1: float = 0.9,
b2: float = 0.999,
eps: float = 1e-8,
eps_root: float = 0.0,
threshold: float = 5.0
b1: Union[float, jax.Array] = 0.9,
b2: Union[float, jax.Array] = 0.999,
eps: Union[float, jax.Array] = 1e-8,
eps_root: Union[float, jax.Array] = 0.0,
threshold: Union[float, jax.Array] = 5.0
) -> base.GradientTransformation:
"""Rescale updates according to the Rectified Adam algorithm.

Expand Down Expand Up @@ -816,9 +816,9 @@ class ScaleByTrustRatioState(NamedTuple):


def scale_by_trust_ratio(
min_norm: float = 0.0,
trust_coefficient: float = 1.,
eps: float = 0.,
min_norm: Union[float, jax.Array] = 0.0,
trust_coefficient: Union[float, jax.Array] = 1.,
eps: Union[float, jax.Array] = 0.,
) -> base.GradientTransformation:
"""Scale updates by `trust ratio`.

Expand Down Expand Up @@ -870,8 +870,8 @@ class AddNoiseState(NamedTuple):


def add_noise(
eta: float,
gamma: float,
eta: Union[float, jax.Array],
gamma: Union[float, jax.Array],
seed: int
) -> base.GradientTransformation:
"""Add gradient noise.
Expand Down Expand Up @@ -993,9 +993,9 @@ class ScaleBySM3State(NamedTuple):


def scale_by_sm3(
b1: float = 0.9,
b2: float = 1.0,
eps: float = 1e-8
b1: Union[float, jax.Array] = 0.9,
b2: Union[float, jax.Array] = 1.0,
eps: Union[float, jax.Array] = 1e-8
) -> base.GradientTransformation:
"""Scale updates by `sm3`.

Expand Down Expand Up @@ -1069,11 +1069,11 @@ class ScaleByNovogradState(NamedTuple):


def scale_by_novograd(
b1: float = 0.9,
b2: float = 0.25,
eps: float = 1e-8,
eps_root: float = 0.0,
weight_decay: float = 0.0,
b1: Union[float, jax.Array] = 0.9,
b2: Union[float, jax.Array] = 0.25,
eps: Union[float, jax.Array] = 1e-8,
eps_root: Union[float, jax.Array] = 0.0,
weight_decay: Union[float, jax.Array] = 0.0,
mu_dtype: Optional[chex.ArrayDType] = None,
) -> base.GradientTransformation:
"""Computes NovoGrad updates.
Expand Down Expand Up @@ -1141,8 +1141,8 @@ def update_fn(updates, state, params):
return base.GradientTransformation(init_fn, update_fn)


def scale_by_optimistic_gradient(alpha: float = 1.0,
beta: float = 1.0
def scale_by_optimistic_gradient(alpha: Union[float, jax.Array] = 1.0,
beta: Union[float, jax.Array] = 1.0
) -> base.GradientTransformation:
"""Compute generalized optimistic gradients.

Expand Down