-
Hey, I have a question. Does it mean that it is better to mark `graphdef` as a `static_argnum` (see the line marked `here` below)?
import typing as tp
import functools as ft
import warnings

import chex
import jax
import matplotlib.pyplot as plt
import optax
from flax import nnx
from jax import numpy as jnp
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
# jax.jit maps static_argnums=0 to the parameter name ``graphdef`` via the
# function signature, so keyword calls such as step(graphdef=...) are covered.
@ft.partial(jax.jit, static_argnums=0)
def step(graphdef: nnx.GraphDef,  # compile-time constant
         state: nnx.GraphState,
         features: chex.ArrayBatched,
         labels: chex.ArrayBatched,
         train: bool
         ) -> nnx.GraphState:
    """Run one jitted train or eval step on a single batch.

    ``(model, optimizer, metrics)`` crosses the ``jax.jit`` boundary in split
    form (``graphdef`` + ``state``). It is merged back inside, updated, and
    split again for the return value.

    Args:
        graphdef: Structure of ``(model, optimizer, metrics)`` as returned by
            ``nnx.split``. Marking it static is optional: a GraphDef holds no
            arrays, so JAX already treats it as a static structure.
        state: Array state of ``(model, optimizer, metrics)``.
        features: Batch of input features.
        labels: Batch of integer class labels.
        train: If true, take one optimizer step; otherwise only compute the
            loss. It is not static, so it is traced. ``nnx.cond`` therefore
            traces both branches, and one compiled function serves both modes.

    Returns:
        Updated state of ``(model, optimizer, metrics)``. The caller writes it
        back with ``nnx.update``.
    """
    def loss_fn(
        model: nnx.Module,
        features: chex.ArrayBatched,
        labels: chex.ArrayBatched
    ) -> tuple[chex.Array, chex.ArrayBatched]:
        """Return ``(mean softmax cross-entropy, logits)`` computed with optax."""
        logits = model(features)
        loss = jnp.mean(
            optax.softmax_cross_entropy_with_integer_labels(
                logits=logits, labels=labels
            )
        )
        return loss, logits

    def train_step(
        args: tuple[nnx.Module, chex.ArrayBatched, chex.ArrayBatched,
                    nnx.Optimizer]
    ) -> tuple[chex.Array, chex.ArrayBatched]:
        """Compute loss and gradients, then apply one optimizer update."""
        model, features, labels, optimizer = args
        grad_fn = nnx.value_and_grad(loss_fn, has_aux=True)
        (loss, logits), grads = grad_fn(model, features, labels)
        optimizer.update(grads)
        return loss, logits

    def eval_step(
        args: tuple[nnx.Module, chex.ArrayBatched, chex.ArrayBatched,
                    nnx.Optimizer]
    ) -> tuple[chex.Array, chex.ArrayBatched]:
        """Compute loss and logits; the optimizer is passed through untouched."""
        model, features, labels, _ = args
        return loss_fn(model, features, labels)

    model, optimizer, metrics = nnx.merge(graphdef, state)
    # Both branches receive the same operand tuple and return the same
    # (loss, logits) structure, as nnx.cond requires for a traced predicate.
    loss, logits = nnx.cond(
        train,
        train_step,
        eval_step,
        (model, features, labels, optimizer)
    )
    metrics.update(loss=loss, logits=logits, labels=labels)
    _, state = nnx.split((model, optimizer, metrics))
    return state
# Load the Iris dataset and hold out 20% of it as a shuffled, seeded test set.
iris = load_iris()
X, y = iris.data, iris.target
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, shuffle=True, random_state=42,
)
def batch(X: chex.Array,
          y: chex.Array,
          batch_size: int,
          key: chex.PRNGKey,
          train: bool = False
          ) -> tp.Generator[tuple[chex.ArrayBatched, chex.ArrayBatched], None, None]:
    """Yield ``(X_batch, y_batch)`` mini-batches of exactly ``batch_size`` rows.

    Trailing rows that cannot fill a whole batch are dropped. A
    ``UserWarning`` is emitted when that happens.

    Args:
        X: Features, batched along the first axis.
        y: Labels, same length as ``X`` along the first axis.
        batch_size: Rows per batch; must be a positive integer.
        key: PRNG key used to shuffle rows when ``train`` is true. Ignored
            otherwise.
        train: If true, visit rows in a random permutation drawn from
            ``key``. Otherwise visit them in stored order.

    Yields:
        ``(X_batch, y_batch)`` tuples.

    Raises:
        ValueError: If ``X`` and ``y`` differ in length, or if ``batch_size``
            is not positive. Because this is a generator, the checks run on
            the first ``next()`` call, not when ``batch`` is called.
    """
    n = X.shape[0]
    # Raise explicitly rather than assert: asserts are stripped under -O.
    if n != y.shape[0]:
        raise ValueError(
            f"Features and labels must have same dimensions, but got {n} and {y.shape[0]} respectively"
        )
    if batch_size < 1:
        # Without this, 0 crashes on `%` and negatives silently yield nothing.
        raise ValueError(f"batch_size must be a positive integer, got {batch_size}")
    if n % batch_size:
        warnings.warn(f"Can't split X with {n} examples in equal batches with {batch_size=}. Last batch will be droped", category=UserWarning)

    # Training visits rows through a random index permutation. Evaluation
    # keeps stored order and can use plain contiguous slices.
    order = jax.random.permutation(key=key, x=n) if train else None
    for i in range(n // batch_size):
        window = slice(i * batch_size, (i + 1) * batch_size)
        if order is None:
            yield X[window], y[window]
        else:
            idx = order[window]
            yield X[idx], y[idx]
def _run_epoch(model: nnx.Module,
               optimizer: nnx.Optimizer,
               metrics: nnx.MultiMetric,
               X: chex.ArrayNumpy,
               y: chex.ArrayNumpy,
               batch_size: int,
               key: chex.PRNGKey,
               train: bool) -> dict:
    """Apply ``step`` to every batch of ``(X, y)`` once and return the metrics.

    ``model``, ``optimizer`` and ``metrics`` are written back in place with
    ``nnx.update``. ``metrics`` is reset after its values have been computed.
    """
    graphdef, state = nnx.split((model, optimizer, metrics))
    for X_batched, y_batched in batch(X=X,
                                      y=y,
                                      batch_size=batch_size,
                                      key=key,
                                      train=train):
        state = step(graphdef=graphdef,
                     state=state,
                     features=X_batched,
                     labels=y_batched,
                     train=train)
    nnx.update((model, optimizer, metrics), state)
    results = metrics.compute()
    metrics.reset()
    return results


def run(X_train: chex.ArrayNumpy,
        y_train: chex.ArrayNumpy,
        X_test: chex.ArrayNumpy,
        y_test: chex.ArrayNumpy,
        batch_size: int = 8,
        key: tp.Optional[chex.PRNGKey] = None,
        num_epochs: int = 10) -> dict[str, list]:
    """Initialize the model, optimizer and metrics, then train and evaluate.

    Each epoch makes one shuffled training pass over ``(X_train, y_train)``
    and one evaluation pass over ``(X_test, y_test)``, then prints both
    results.

    Args:
        X_train, y_train: Training features and integer labels.
        X_test, y_test: Test features and integer labels.
        batch_size: Rows per batch. Trailing rows that do not fill a batch
            are dropped.
        key: Root PRNG key; defaults to ``jax.random.key(42)``.
        num_epochs: Number of train/eval epochs.

    Returns:
        History dict with keys ``train_loss``, ``train_accuracy``,
        ``test_loss`` and ``test_accuracy``. Entry ``i`` of each list belongs
        to epoch ``i``. Existing callers may ignore the return value.
    """
    if key is None:
        key = jax.random.key(42)
    batch_key, model_key = jax.random.split(key)
    rngs = nnx.Rngs(model_key)
    model = nnx.Linear(4, 3, rngs=rngs)
    optimizer = nnx.Optimizer(model, optax.adamw(5e-3, 0.9))
    metrics = nnx.MultiMetric(
        accuracy=nnx.metrics.Accuracy(),
        loss=nnx.metrics.Average('loss'),
    )
    # All series start empty so every list has exactly one entry per epoch.
    metrics_history = {
        'train_loss': [],
        'train_accuracy': [],
        'test_loss': [],
        'test_accuracy': [],
    }
    for i in range(num_epochs):
        # New shuffle order each epoch. One half of the split is consumed by
        # the permutation; the other half is carried forward. This way no key
        # is both used for sampling and split again.
        batch_key, shuffle_key = jax.random.split(batch_key)

        # train
        model.train()
        train_results = _run_epoch(model, optimizer, metrics, X_train,
                                   y_train, batch_size, shuffle_key,
                                   train=True)
        for metric, value in train_results.items():
            metrics_history[f'train_{metric}'].append(value)

        # eval (the key is ignored when train=False)
        model.eval()
        test_results = _run_epoch(model, optimizer, metrics, X_test, y_test,
                                  batch_size, shuffle_key, train=False)
        for metric, value in test_results.items():
            metrics_history[f'test_{metric}'].append(value)

        print(
            f"[train] epoch: {i}, "
            f"loss: {metrics_history['train_loss'][-1]}, "
            f"accuracy: {metrics_history['train_accuracy'][-1] * 100}"
        )
        print(
            f"[test] epoch: {i}, "
            f"loss: {metrics_history['test_loss'][-1]}, "
            f"accuracy: {metrics_history['test_accuracy'][-1] * 100}"
        )
    return metrics_history
# Train and evaluate on the Iris split above for 20 epochs, printing the
# per-epoch loss and accuracy.
run(X_train=X_train,
    y_train=y_train,
    X_test=X_test,
    y_test=y_test,
    batch_size=10,
    key=jax.random.key(42),
    num_epochs=20)
Beta Was this translation helpful? Give feedback.
Replies: 1 comment 1 reply
-
Hi @Tomas542, there are two things going on here:
|
Beta Was this translation helpful? Give feedback.
Hi @Tomas542, there are two things going on here:
The `graphdef` contains all the structural information: the topology of the object graph and all metadata fields (strings, integers, etc.) that you would pass as a `static_argnum` if you passed them individually. You can mark `graphdef` as a `static_argnum` if you want, but there is no need: we use `jax.tree_util.register_static` on `NodeDef`, so JAX understands that it's a static structure with no Arrays.