Corrections to make JAX omnistaging work, speed up sparse decoding, and make training memory-efficient.

* setting RNGs manually after they have been traced by JAX (we're now fully omni-staged!)
* storing SparseFF weights in a decoding-friendly format (transposing in train mode), which makes the decoding timing test faster
* correcting a ReversibleSerialTrainer slots-setter bug and adding a test (without this, restoring from a checkpoint was failing)
* lowering warmup_steps in the gin config, since the inflated value was keeping the learning rate too low

PiperOrigin-RevId: 332582223
Lukasz Kaiser authored and copybara-github committed Sep 19, 2020
1 parent d24503a commit a0b31ef
Showing 6 changed files with 28 additions and 6 deletions.
2 changes: 2 additions & 0 deletions trax/layers/base.py
@@ -439,7 +439,9 @@ def state(self, state):
 
   def weights_and_state_signature(self, input_signature):
     """Return a pair containing the signatures of weights and state."""
+    rng = self.rng
     abstract_init = fastmath.abstract_eval(self.init)
+    self.rng = rng
     return abstract_init(input_signature)
 
   @property
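For context: under JAX omnistaging, every op inside a trace is staged, so abstract-evaluating `self.init` can overwrite the layer's concrete `self.rng` with a tracer; the hunk above saves the RNG before tracing and puts it back afterwards. Below is a minimal, self-contained sketch of that pattern (a hypothetical `TinyLayer`, not Trax code), assuming only `jax.eval_shape` for the abstract evaluation:

import jax
import jax.numpy as jnp


class TinyLayer:
  """Hypothetical stand-in for a Trax layer; only the RNG handling matters."""

  def __init__(self):
    self.rng = jax.random.PRNGKey(0)  # concrete RNG we want to keep

  def init(self, x):
    # Real layers split their RNG during init; under omnistaging this split is
    # staged while tracing, so it would leave a tracer behind in self.rng.
    self.rng, subkey = jax.random.split(self.rng)
    return jax.random.normal(subkey, (x.shape[-1], 4))

  def weights_signature(self, x):
    rng = self.rng                      # save the concrete RNG
    sig = jax.eval_shape(self.init, x)  # abstract tracing, no real computation
    self.rng = rng                      # restore it after tracing
    return sig


layer = TinyLayer()
print(layer.weights_signature(jnp.zeros((2, 3))))  # ShapeDtypeStruct (3, 4)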
10 changes: 8 additions & 2 deletions trax/layers/research/efficient_attention.py
@@ -1651,6 +1651,9 @@ def forward(self, x):
       Tensor of same shape and dtype as the input.
     """
     m1, m2, mb, w1, w2, b2 = self.weights
+    if self._mode != 'predict':
+      w1 = np.reshape(w1.T, (-1, self._d_ff))
+      w2 = np.reshape(w2, (self._d_ff, -1))
     x_shape = x.shape
     x = np.reshape(x, [-1, x_shape[-1]])  # Easier to operate on flattened x.
 
@@ -1684,8 +1687,8 @@ def forward(self, x):
       relu = np.where(mid <= 0, np.zeros_like(mid), mid)
       res = np.dot(relu, w2) + b2
     elif self._mode == 'predict':
-      w1 = np.reshape(w1.T, (self._d1, self._d2, -1))
-      w2 = np.reshape(w2, (self._d1, self._d2, -1))
+      # w1 = np.reshape(w1.T, (self._d1, self._d2, -1))
+      # w2 = np.reshape(w2, (self._d1, self._d2, -1))
       # This implementation mimics inference. It's not efficient for large
       # size of joint_batch, but at inference that will be 1 most of the time.
       # Shapes:
@@ -1733,4 +1736,7 @@ def init_weights_and_state(self, input_signature):
     w1 = self._kernel_initializer(shape_w1, rng_w1)
     w2 = self._kernel_initializer(shape_w2, rng_w2)
     b2 = self._bias_initializer(shape_b2, rng_b2)
+
+    w1 = np.reshape(w1.T, (self._d1, self._d2, -1))
+    w2 = np.reshape(w2, (self._d1, self._d2, -1))
     self.weights = (m1, m2, mb, w1, w2, b2)
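For context: the layer now keeps w1/w2 in the blocked, decoding-friendly `(d1, d2, ·)` layout from initialization onward, and the training-mode forward reshapes them back to flat matrices once per call instead of reshaping on every decoded token. A rough sketch of the layout trade-off with made-up sizes (plain NumPy, not the Trax layer itself):

import numpy as np

d_model, d1, d2 = 8, 4, 16                  # hypothetical sizes; d_ff = d1 * d2
d_ff = d1 * d2

w1 = np.random.randn(d_model, d_ff)         # training layout: one dense matmul
w1_stored = np.reshape(w1.T, (d1, d2, -1))  # stored layout: (block, unit, d_model)

# Training recovers a flat (d_model, d_ff) matrix with one reshape per forward
# pass (column order differs from the original w1, but it is used consistently).
w1_train = np.reshape(w1_stored.T, (-1, d_ff))
assert w1_train.shape == (d_model, d_ff)

# Decoding picks one unit per block for the current token without any
# transpose; the indices here are hypothetical stand-ins for the ones the
# layer's controller weights (m1, m2, mb) would produce.
idx = np.random.randint(0, d2, size=d1)
picked = w1_stored[np.arange(d1), idx]      # shape (d1, d_model)
assert picked.shape == (d1, d_model)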
4 changes: 2 additions & 2 deletions trax/optimizers/trainer.py
@@ -313,8 +313,8 @@ def slots(self):
   @slots.setter
   def slots(self, slots):
     """Sets the slots of all optimizers."""
-    for ((s_opts, r_opts), (s_slots, r_slots)) in zip(self._optimizers, slots):
-      for (opt, slot) in zip(s_opts + r_opts, s_slots + r_slots):
+    for ((s_opt, r_opts), (s_slots, r_slots)) in zip(self._optimizers, slots):
+      for (opt, slot) in zip([s_opt] + r_opts, [s_slots] + r_slots):
         opt.slots = slot
 
   def _pjit(self, f):
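For context: each entry of `self._optimizers` pairs a single optimizer for a standard-layer block with a list of optimizers for the reversible layers that follow it, and `slots` mirrors that structure. The old setter treated the single optimizer as if it were a list, which broke when slots were restored from a checkpoint. A toy illustration with stand-in objects (not the Trax classes):

class FakeOpt:
  def __init__(self):
    self.slots = None

# One (standard, [reversible, ...]) block, mirroring the structure implied by
# the fixed setter above.
optimizers_per_block = [(FakeOpt(), [FakeOpt(), FakeOpt()])]
slots_per_block = [('std_slots', ['rev_slots_1', 'rev_slots_2'])]

for (s_opt, r_opts), (s_slots, r_slots) in zip(optimizers_per_block,
                                               slots_per_block):
  # `s_opts + r_opts` (the old code) would try to concatenate a single
  # optimizer object with a list; wrapping both sides in one-element lists
  # lines up each optimizer with its slots.
  for opt, slot in zip([s_opt] + r_opts, [s_slots] + r_slots):
    opt.slots = slot

assert optimizers_per_block[0][1][1].slots == 'rev_slots_2'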
13 changes: 13 additions & 0 deletions trax/optimizers/trainer_test.py
@@ -65,6 +65,19 @@ def test_run_simple_task_tfnp(self):
     rng = fastmath.random.get_prng(0)
     trainer.one_step(labeled_batch, rng)
 
+  def test_run_reversible_slots(self):
+    """Tests that slots can be read and assigned in reversible trainer."""
+    layers = [tl.Dense(4), tl.Dup()]
+    rev_layers = [tl.ReversibleHalfResidual(tl.Dense(4)),
+                  tl.ReversibleSwap()]
+    loss_layer = tl.Serial(tl.Concatenate(), tl.Dense(4),
+                           tl.LogSoftmax(), tl.CrossEntropyLoss())
+    trainer = optimizers.ReversibleSerialTrainer(
+        [(layers, rev_layers)], loss_layer, optimizers.Adam)
+    slots = trainer.slots
+    trainer.slots = slots
+    self.assertEqual(slots, trainer.slots)
+
   def test_run_reversible_same_as_default_basic(self):
     """Runs the reversible trainer, check results are the same as default."""
     inputs_batch = np.arange(8).reshape((2, 4))
2 changes: 2 additions & 0 deletions trax/rl/serialization_utils_test.py
@@ -230,6 +230,7 @@ def test_extract_inner_model(self):
         action_serializer=act_serializer,
         significance_decay=0.9,
     )
+    rng = inner_model.rng
 
     obs_sig = shapes.ShapeDtype((1, 2))
     act_sig = shapes.ShapeDtype((1, 1))
@@ -239,6 +240,7 @@
     (inner_weights, inner_state) = map(
         serialization_utils.extract_inner_model, (weights, state)
     )
+    inner_model.rng = rng
     inner_model(jnp.array([[0]]), weights=inner_weights, state=inner_state)
 
   @parameterized.named_parameters(('raw', None), ('serialized', 32))
@@ -84,8 +84,7 @@ data_streams.bare_preprocess_fn=@trax.data.tf_inputs.c4_bare_preprocess_fn
 # ==============================================================================
 multifactor.constant = 1.0
 multifactor.factors = 'constant * rsqrt_decay'
-# NOTE: T5's batch is 18x bigger, so we just 18x-ed their warmup steps (10k).
-multifactor.warmup_steps = 180000
+multifactor.warmup_steps = 10000
 
 # Parameters for Adafactor:
 # ==============================================================================
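For context: with `factors = 'constant * rsqrt_decay'`, Trax's `multifactor` schedule (assuming its usual definition) gives roughly `lr = constant / sqrt(max(step, warmup_steps))`, so an oversized warmup pins the learning rate low for all of those steps. A quick back-of-the-envelope check:

import math

def lr(step, warmup_steps, constant=1.0):
  # Assumed form of the 'constant * rsqrt_decay' schedule.
  return constant / math.sqrt(max(step, warmup_steps))

print(lr(1000, warmup_steps=180000))  # ~0.0024 -- pinned low for 180k steps
print(lr(1000, warmup_steps=10000))   # 0.01    -- the intended early rate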
