109 changes: 90 additions & 19 deletions docs/examples/quickstart_jax.ipynb
@@ -242,7 +242,7 @@
"name": "stdout",
"output_type": "stream",
"text": [
"Mean time: 17.708301544189453 ms\n"
"Mean time: 18.840689659118652 ms\n"
]
}
],
@@ -256,8 +256,8 @@
" variables=params,\n",
" input=x,\n",
" output_grad=dy,\n",
" dropout_key=dropout_key,\n",
" forward_kwargs={\"attention_mask\": None, \"deterministic\": False},\n",
" rngs={\"dropout\": dropout_key},\n",
")"
]
},
@@ -422,7 +422,7 @@
"name": "stdout",
"output_type": "stream",
"text": [
"Mean time: 16.505107879638672 ms\n"
"Mean time: 17.292580604553223 ms\n"
]
}
],
@@ -441,8 +441,8 @@
" variables=te_params, # Ensure the correct `params` is passed\n",
" input=x,\n",
" output_grad=dy,\n",
" dropout_key=dropout_key,\n",
" forward_kwargs={\"attention_mask\": None, \"deterministic\": False},\n",
" rngs={\"dropout\": dropout_key},\n",
")\n"
]
},
@@ -460,11 +460,35 @@
"id": "5146cd99",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/mnt/jberchtold/lustre-home/transformerengine/transformer_engine/jax/flax/transformer.py:603: UserWarning: transpose_batch_sequence defaults to False in DotProductAttention starting TransformerEngine v2.10\n",
" warnings.warn(\n",
"/mnt/jberchtold/lustre-home/transformerengine/transformer_engine/jax/flax/transformer.py:711: UserWarning: Fused attention is not enabled because there is no available kernel.\n",
"Fall back to the unfused attention.\n",
"Please try to update the cuDNN and TE to the latest version.\n",
"self.dtype=<class 'jax.numpy.float32'>\n",
"qkv_layout=<QKVLayout.BSHD_BSHD_BSHD: <NVTE_QKV_Layout.NVTE_BSHD_BSHD_BSHD: 9>>\n",
"attn_bias_type=<AttnBiasType.NO_BIAS: <NVTE_Bias_Type.NVTE_NO_BIAS: 0>>\n",
"attn_mask_type=<AttnMaskType.CAUSAL_MASK: <NVTE_Mask_Type.NVTE_CAUSAL_MASK: 2>>\n",
"self.attention_dropout=0.1\n",
"self.num_attention_heads=32\n",
"self.num_gqa_groups=32\n",
"seqlen_q=2048\n",
"seqlen_kv=2048\n",
"head_dim_qk=128\n",
"head_dim_v=128\n",
"\n",
" warnings.warn(\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Mean time: 12.80329704284668 ms\n"
"Mean time: 13.253312110900879 ms\n"
]
}
],
@@ -482,8 +506,8 @@
" variables=te_params, # Ensure the correct `params` is passed\n",
" input=x,\n",
" output_grad=dy,\n",
" dropout_key=dropout_key,\n",
" forward_kwargs={\"attention_mask\": None, \"deterministic\": False},\n",
" rngs={\"dropout\": dropout_key},\n",
")"
]
},
@@ -535,7 +559,7 @@
"name": "stdout",
"output_type": "stream",
"text": [
"Mean time: 9.615030288696289 ms\n"
"Mean time: 9.897074699401855 ms\n"
]
}
],
@@ -551,9 +575,9 @@
" variables=te_unfused_params, # Ensure the correct `params` is passed\n",
" input=x,\n",
" output_grad=dy,\n",
" dropout_key=dropout_key,\n",
" forward_kwargs={\"attention_mask\": None, \"deterministic\": False},\n",
" autocast_kwargs = { \"enabled\": True, \"recipe\": fp8_recipe}\n",
" autocast_kwargs = { \"enabled\": True, \"recipe\": fp8_recipe},\n",
" rngs={\"dropout\": dropout_key},\n",
")"
]
},
@@ -659,15 +683,39 @@
},
{
"cell_type": "code",
"execution_count": 15,
"execution_count": 16,
"id": "6b0c705e",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/mnt/jberchtold/lustre-home/transformerengine/transformer_engine/jax/flax/transformer.py:603: UserWarning: transpose_batch_sequence defaults to False in DotProductAttention starting TransformerEngine v2.10\n",
" warnings.warn(\n",
"/mnt/jberchtold/lustre-home/transformerengine/transformer_engine/jax/flax/transformer.py:711: UserWarning: Fused attention is not enabled because there is no available kernel.\n",
"Fall back to the unfused attention.\n",
"Please try to update the cuDNN and TE to the latest version.\n",
"self.dtype=<class 'jax.numpy.float32'>\n",
"qkv_layout=<QKVLayout.BSHD_BSHD_BSHD: <NVTE_QKV_Layout.NVTE_BSHD_BSHD_BSHD: 9>>\n",
"attn_bias_type=<AttnBiasType.NO_BIAS: <NVTE_Bias_Type.NVTE_NO_BIAS: 0>>\n",
"attn_mask_type=<AttnMaskType.CAUSAL_MASK: <NVTE_Mask_Type.NVTE_CAUSAL_MASK: 2>>\n",
"self.attention_dropout=0.1\n",
"self.num_attention_heads=32\n",
"self.num_gqa_groups=32\n",
"seqlen_q=2048\n",
"seqlen_kv=2048\n",
"head_dim_qk=128\n",
"head_dim_v=128\n",
"\n",
" warnings.warn(\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Mean time: 9.331779479980469 ms\n"
"Mean time: 9.692559242248535 ms\n"
]
}
],
@@ -688,9 +736,9 @@
" variables=te_fused_params,\n",
" input=x,\n",
" output_grad=dy,\n",
" dropout_key=dropout_key,\n",
" forward_kwargs={\"attention_mask\": None, \"deterministic\": False},\n",
" autocast_kwargs = { \"enabled\": True, \"recipe\": fp8_recipe}\n",
" autocast_kwargs = { \"enabled\": True, \"recipe\": fp8_recipe},\n",
" rngs={\"dropout\": dropout_key},\n",
")"
]
},
@@ -704,10 +752,33 @@
},
{
"cell_type": "code",
"execution_count": 16,
"execution_count": 17,
"id": "b2aaa8ef",
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/mnt/jberchtold/lustre-home/transformerengine/transformer_engine/jax/flax/transformer.py:711: UserWarning: Fused attention is not enabled because there is no available kernel.\n",
"Fall back to the unfused attention.\n",
"Please try to update the cuDNN and TE to the latest version.\n",
"self.dtype=<class 'jax.numpy.float32'>\n",
"qkv_layout=<QKVLayout.BS3HD: <NVTE_QKV_Layout.NVTE_BS3HD: 5>>\n",
"attn_bias_type=<AttnBiasType.NO_BIAS: <NVTE_Bias_Type.NVTE_NO_BIAS: 0>>\n",
"attn_mask_type=<AttnMaskType.CAUSAL_MASK: <NVTE_Mask_Type.NVTE_CAUSAL_MASK: 2>>\n",
"self.attention_dropout=0.1\n",
"self.num_attention_heads=32\n",
"self.num_gqa_groups=32\n",
"seqlen_q=2048\n",
"seqlen_kv=2048\n",
"head_dim_qk=128\n",
"head_dim_v=128\n",
"\n",
" warnings.warn(\n"
]
}
],
"source": [
"\n",
"te_transformer = te_flax.TransformerLayer(\n",
@@ -731,15 +802,15 @@
},
{
"cell_type": "code",
"execution_count": 17,
"execution_count": 19,
"id": "b9cdbf22",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Mean time: 9.23741340637207 ms\n"
"Mean time: 9.5003080368042 ms\n"
]
}
],
@@ -750,9 +821,9 @@
" variables=te_transformer_params,\n",
" input=x,\n",
" output_grad=dy,\n",
" dropout_key=dropout_key,\n",
" forward_kwargs={\"attention_mask\": None, \"deterministic\": False},\n",
" autocast_kwargs = { \"enabled\": True, \"recipe\": fp8_recipe }\n",
" autocast_kwargs = { \"enabled\": True, \"recipe\": fp8_recipe },\n",
" rngs={\"dropout\": dropout_key},\n",
")"
]
}
35 changes: 27 additions & 8 deletions docs/examples/quickstart_jax_utils.py
@@ -19,12 +19,12 @@ def speedometer(
variables: Any,
input: jnp.ndarray,
output_grad: jnp.ndarray,
dropout_key: jax.random.PRNGKey,
model_init_fn: Callable = None,
forward_kwargs: dict = {},
autocast_kwargs: Optional[dict] = None,
timing_iters: int = 50,
warmup_iters: int = 50,
rngs: Dict[str, jax.random.PRNGKey] = None,
) -> None:
"""Measure average runtime for a JAX module
Perform forward and backward passes.
@@ -33,19 +33,21 @@
autocast_kwargs = {"enabled": False}
model_init_fn = None

if rngs is None:
rngs = {}

train_step_fn = create_train_step_fn(model_apply_fn, autocast_kwargs, forward_kwargs)

# Warm up runs
key = dropout_key
for _ in range(warmup_iters):
key, step_key = jax.random.split(key)
loss, (param_grads, other_grads) = train_step_fn(variables, input, output_grad, step_key)
rngs, step_rngs = _split_step_rngs(rngs)
loss, (param_grads, other_grads) = train_step_fn(variables, input, output_grad, step_rngs)

# Timing runs
start = time.time()
for _ in range(timing_iters):
key, step_key = jax.random.split(key)
loss, (param_grads, other_grads) = train_step_fn(variables, input, output_grad, step_key)
rngs, step_rngs = _split_step_rngs(rngs)
loss, (param_grads, other_grads) = train_step_fn(variables, input, output_grad, step_rngs)
end = time.time()

print(f"Mean time: {(end - start) * 1000 / timing_iters} ms")
@@ -63,8 +65,12 @@ def create_train_step_fn(
if forward_kwargs is None:
forward_kwargs = {}

def loss_fn(variables: Any, inp: jnp.ndarray, grad_target: jnp.ndarray, dropout_key):
rngs = {"dropout": dropout_key}
def loss_fn(
variables: Any,
inp: jnp.ndarray,
grad_target: jnp.ndarray,
rngs: Dict[str, jax.random.PRNGKey],
):
with te.autocast(**autocast_kwargs):
# Forward Pass: Apply the model using current parameters and variables
call_kwargs = {**forward_kwargs, "rngs": rngs}
@@ -84,3 +90,16 @@ def fwd_bwd_fn(*args, **kwargs):

# JIT-compile the fwd_bwd_fn
return jax.jit(fwd_bwd_fn)


def _split_step_rngs(
rngs: Dict[str, jax.random.PRNGKey],
) -> Tuple[Dict[str, jax.random.PRNGKey], Dict[str, jax.random.PRNGKey]]:
"""Splits each RNG in the rngs dictionary for a new step."""
step_rngs = {}
new_rngs = {}
for name, key in rngs.items():
new_key, step_key = jax.random.split(key)
new_rngs[name] = new_key
step_rngs[name] = step_key
return new_rngs, step_rngs
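
For reference, below is a minimal sketch (not part of the diff) of how the updated speedometer signature could be driven with the new per-collection rngs dictionary. The toy module, shapes, import path, and the leading positional model-apply argument are illustrative assumptions rather than the notebook's exact values; each key in rngs is split once per warmup and timing step by _split_step_rngs above.

import jax
import jax.numpy as jnp
import flax.linen as nn

from quickstart_jax_utils import speedometer  # assumed import path (docs/examples)


# Illustrative stand-in for the TransformerEngine layers benchmarked in the notebook.
class ToyBlock(nn.Module):
    hidden: int = 256

    @nn.compact
    def __call__(self, x, deterministic: bool = False):
        y = nn.Dense(self.hidden)(x)
        y = nn.Dropout(rate=0.1, deterministic=deterministic)(y)
        return nn.Dense(x.shape[-1])(y)


init_key, dropout_key = jax.random.split(jax.random.PRNGKey(0))
x = jnp.ones((8, 128, 64), dtype=jnp.float32)  # placeholder batch/sequence/hidden shape
dy = jnp.ones_like(x)                          # upstream gradient for the backward pass

model = ToyBlock()
params = model.init(init_key, x, deterministic=True)

speedometer(
    model.apply,                    # assumed to be the leading model-apply argument
    variables=params,
    input=x,
    output_grad=dy,
    forward_kwargs={"deterministic": False},
    rngs={"dropout": dropout_key},  # split into a fresh "dropout" key every step
)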