Skip to content

Commit 30e51bf

Browse files
committed
Remove weight sharing between different iterations of the transformerLayer
Signed-off-by: tdophung <tdophung@nvidia.com>. [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci. Signed-off-by: tdophung <tdophung@nvidia.com>
1 parent 3db2c9b commit 30e51bf

File tree

2 files changed

+46
-330
lines changed

2 files changed

+46
-330
lines changed

docs/examples/quickstart_jax.ipynb

Lines changed: 43 additions & 95 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@
5151
},
5252
{
5353
"cell_type": "code",
54-
"execution_count": null,
54+
"execution_count": 25,
5555
"id": "881fd001",
5656
"metadata": {},
5757
"outputs": [],
@@ -65,7 +65,7 @@
6565
},
6666
{
6767
"cell_type": "code",
68-
"execution_count": 8,
68+
"execution_count": 26,
6969
"id": "d5284a38",
7070
"metadata": {},
7171
"outputs": [],
@@ -79,7 +79,7 @@
7979
},
8080
{
8181
"cell_type": "code",
82-
"execution_count": 9,
82+
"execution_count": 27,
8383
"id": "a4d1cfdc",
8484
"metadata": {},
8585
"outputs": [],
@@ -173,7 +173,7 @@
173173
},
174174
{
175175
"cell_type": "code",
176-
"execution_count": 10,
176+
"execution_count": 28,
177177
"id": "8b44649d",
178178
"metadata": {},
179179
"outputs": [],
@@ -194,7 +194,7 @@
194194
},
195195
{
196196
"cell_type": "code",
197-
"execution_count": 11,
197+
"execution_count": 29,
198198
"id": "e44ed26d",
199199
"metadata": {},
200200
"outputs": [
@@ -224,7 +224,7 @@
224224
},
225225
{
226226
"cell_type": "code",
227-
"execution_count": 12,
227+
"execution_count": 30,
228228
"id": "de91af7a",
229229
"metadata": {},
230230
"outputs": [
@@ -250,15 +250,15 @@
250250
},
251251
{
252252
"cell_type": "code",
253-
"execution_count": 13,
253+
"execution_count": 31,
254254
"id": "037bc8d9",
255255
"metadata": {},
256256
"outputs": [
257257
{
258258
"name": "stdout",
259259
"output_type": "stream",
260260
"text": [
261-
"Mean time: 27.269372940063477 ms\n"
261+
"Mean time: 28.229827880859375 ms\n"
262262
]
263263
}
264264
],
@@ -308,7 +308,7 @@
308308
},
309309
{
310310
"cell_type": "code",
311-
"execution_count": 14,
311+
"execution_count": 32,
312312
"id": "bed20d6b",
313313
"metadata": {},
314314
"outputs": [],
@@ -328,7 +328,7 @@
328328
},
329329
{
330330
"cell_type": "code",
331-
"execution_count": 15,
331+
"execution_count": 33,
332332
"id": "56105579",
333333
"metadata": {},
334334
"outputs": [],
@@ -424,16 +424,15 @@
424424
},
425425
{
426426
"cell_type": "code",
427-
"execution_count": 16,
427+
"execution_count": 34,
428428
"id": "5146cd99",
429429
"metadata": {},
430430
"outputs": [
431431
{
432432
"name": "stdout",
433433
"output_type": "stream",
434434
"text": [
435-
"Basic TE parameter shapes: {'BasicTEMLP_0': {'DenseGeneral_0': {'bias': LogicallyPartitioned(value=(16384,), names=(), mesh=None, rules=None), 'kernel': LogicallyPartitioned(value=(4096, 16384), names=(), mesh=None, rules=None)}, 'DenseGeneral_1': {'bias': LogicallyPartitioned(value=(4096,), names=(), mesh=None, rules=None), 'kernel': LogicallyPartitioned(value=(16384, 4096), names=(), mesh=None, rules=None)}}, 'DenseGeneral_0': {'bias': LogicallyPartitioned(value=(12288,), names=(), mesh=None, rules=None), 'kernel': LogicallyPartitioned(value=(4096, 12288), names=(), mesh=None, rules=None)}, 'DenseGeneral_1': {'bias': LogicallyPartitioned(value=(4096,), names=(), mesh=None, rules=None), 'kernel': LogicallyPartitioned(value=(4096, 4096), names=(), mesh=None, rules=None)}, 'LayerNorm_0': {'ln_bias': LogicallyPartitioned(value=(4096,), names=('embed',), mesh=None, rules=None), 'scale': LogicallyPartitioned(value=(4096,), names=('embed',), mesh=None, rules=None)}, 'LayerNorm_1': {'ln_bias': LogicallyPartitioned(value=(4096,), names=('embed',), mesh=None, rules=None), 'scale': LogicallyPartitioned(value=(4096,), names=('embed',), mesh=None, rules=None)}}\n",
436-
"Mean time: 17.397570610046387 ms\n"
435+
"Mean time: 17.390952110290527 ms\n"
437436
]
438437
}
439438
],
@@ -449,21 +448,12 @@
449448
"\n",
450449
"te_params = basic_te_transformer.init(key, x, attention_mask=None, deterministic=False)\n",
451450
"\n",
452-
"# Extract the 'params' pytrees\n",
453-
"basic_params = params['params']\n",
454-
"te_params_template = te_params['params']\n",
455-
"\n",
456-
"print(f\"Basic TE parameter shapes: {jax.tree_util.tree_map(lambda x: x.shape, te_params_template)}\")\n",
457-
"\n",
458-
"shared_te_params = {}\n",
459-
"shared_te_params['params'] = utils.share_parameters_with_basic_te_model(basic_params, te_params_template)\n",
460-
"\n",
461451
"# Test forward pass\n",
462452
"y = basic_te_transformer.apply(te_params, x, attention_mask=None, deterministic=True)\n",
463453
"\n",
464454
"utils.speedometer(\n",
465455
" model_apply_fn=basic_te_transformer.apply,\n",
466-
" variables=shared_te_params, # Ensure the correct `params` is passed\n",
456+
" variables=te_params, # Ensure the correct `params` is passed\n",
467457
" input=x,\n",
468458
" output_grad=dy,\n",
469459
" dropout_key=dropout_key,\n",
@@ -502,7 +492,7 @@
502492
},
503493
{
504494
"cell_type": "code",
505-
"execution_count": 17,
495+
"execution_count": 35,
506496
"id": "11203785",
507497
"metadata": {},
508498
"outputs": [],
@@ -570,58 +560,41 @@
570560
},
571561
{
572562
"cell_type": "code",
573-
"execution_count": 18,
563+
"execution_count": 36,
574564
"id": "114de14f",
575565
"metadata": {},
576-
"outputs": [
577-
{
578-
"name": "stdout",
579-
"output_type": "stream",
580-
"text": [
581-
"Fused TE parameter shapes: {'DenseGeneral_0': {'bias': LogicallyPartitioned(value=(4096,), names=(), mesh=None, rules=None), 'kernel': LogicallyPartitioned(value=(4096, 4096), names=(), mesh=None, rules=None)}, 'LayerNormDenseGeneral_0': {'bias': LogicallyPartitioned(value=(12288,), names=(), mesh=None, rules=None), 'kernel': LogicallyPartitioned(value=(4096, 12288), names=(), mesh=None, rules=None), 'ln_bias': LogicallyPartitioned(value=(4096,), names=('embed',), mesh=None, rules=None), 'scale': LogicallyPartitioned(value=(4096,), names=('embed',), mesh=None, rules=None)}, 'LayerNormMLP_0': {'ln_bias': LogicallyPartitioned(value=(4096,), names=('embed',), mesh=None, rules=None), 'scale': LogicallyPartitioned(value=(4096,), names=('embed',), mesh=None, rules=None), 'wi_bias': LogicallyPartitioned(value=(1, 16384), names=('act', 'mlp'), mesh=None, rules=None), 'wi_kernel': LogicallyPartitioned(value=(4096, 1, 16384), names=('embed', 'act', 'mlp'), mesh=None, rules=None), 'wo_bias': LogicallyPartitioned(value=(4096,), names=('embed',), mesh=None, rules=None), 'wo_kernel': LogicallyPartitioned(value=(16384, 4096), names=('mlp', 'embed'), mesh=None, rules=None)}}\n"
582-
]
583-
}
584-
],
566+
"outputs": [],
585567
"source": [
586-
"import quickstart_jax_utils\n",
587-
"importlib.reload(quickstart_jax_utils)\n",
588-
"\n",
589568
"fused_te_transformer = FusedTETransformerLayer(\n",
590569
" hidden_size, \n",
591570
" ffn_hidden_size, \n",
592571
" num_attention_heads\n",
593572
")\n",
594573
"\n",
595-
"fused_te_params = fused_te_transformer.init(key, x, attention_mask=None, deterministic=False)\n",
596-
"\n",
597-
"fused_te_params_template = fused_te_params['params']\n",
598-
"print(f\"Fused TE parameter shapes: {jax.tree_util.tree_map(lambda x: x.shape, fused_te_params_template)}\")"
574+
"fused_te_params = fused_te_transformer.init(key, x, attention_mask=None, deterministic=False)"
599575
]
600576
},
601577
{
602578
"cell_type": "code",
603-
"execution_count": 19,
579+
"execution_count": 37,
604580
"id": "6b0c705e",
605581
"metadata": {},
606582
"outputs": [
607583
{
608584
"name": "stdout",
609585
"output_type": "stream",
610586
"text": [
611-
"Mean time: 18.0991792678833 ms\n"
587+
"Mean time: 18.087706565856934 ms\n"
612588
]
613589
}
614590
],
615591
"source": [
616-
"shared_fused_te_params = {}\n",
617-
"shared_fused_te_params['params'] = utils.share_fused_parameters_with_basic_te_model(basic_params, fused_te_params_template)\n",
618-
"\n",
619592
"# Test forward pass\n",
620593
"y = fused_te_transformer.apply(fused_te_params, x, attention_mask=None, deterministic=True)\n",
621594
"\n",
622595
"utils.speedometer(\n",
623596
" model_apply_fn=fused_te_transformer.apply,\n",
624-
" variables=shared_fused_te_params, # Ensure the correct `params` is passed\n",
597+
" variables=fused_te_params, # Ensure the correct `params` is passed\n",
625598
" input=x,\n",
626599
" output_grad=dy,\n",
627600
" dropout_key=dropout_key,\n",
@@ -639,18 +612,10 @@
639612
},
640613
{
641614
"cell_type": "code",
642-
"execution_count": 20,
615+
"execution_count": 38,
643616
"id": "7496b159",
644617
"metadata": {},
645-
"outputs": [
646-
{
647-
"name": "stdout",
648-
"output_type": "stream",
649-
"text": [
650-
"TE TransformerLayer parameter shapes: {'attention': {'out': {'bias': LogicallyPartitioned(value=(4096,), names=('nvte_w_no_shard',), mesh=None, rules=None), 'kernel': LogicallyPartitioned(value=(4096, 4096), names=('nvte_w_tp', 'nvte_w_fsdp'), mesh=None, rules=None)}, 'qkv': {'bias': LogicallyPartitioned(value=(3, 4096), names=('nvte_w_joined', 'nvte_w_tp'), mesh=None, rules=None), 'kernel': LogicallyPartitioned(value=(4096, 3, 4096), names=('nvte_w_fsdp', 'nvte_w_joined', 'nvte_w_tp'), mesh=None, rules=None), 'ln_bias': LogicallyPartitioned(value=(4096,), names=('nvte_w_no_shard',), mesh=None, rules=None), 'scale': LogicallyPartitioned(value=(4096,), names=('nvte_w_no_shard',), mesh=None, rules=None)}}, 'mlp': {'ln_bias': LogicallyPartitioned(value=(4096,), names=('nvte_w_no_shard',), mesh=None, rules=None), 'scale': LogicallyPartitioned(value=(4096,), names=('nvte_w_no_shard',), mesh=None, rules=None), 'wi_bias': LogicallyPartitioned(value=(1, 16384), names=('nvte_w_joined', 'nvte_w_tp'), mesh=None, rules=None), 'wi_kernel': LogicallyPartitioned(value=(4096, 1, 16384), names=('nvte_w_fsdp', 'nvte_w_joined', 'nvte_w_tp'), mesh=None, rules=None), 'wo_bias': LogicallyPartitioned(value=(4096,), names=('nvte_w_no_shard',), mesh=None, rules=None), 'wo_kernel': LogicallyPartitioned(value=(16384, 4096), names=('nvte_w_tp', 'nvte_w_fsdp'), mesh=None, rules=None)}, 'relpos_bias': {'rel_embedding': LogicallyPartitioned(value=(32, 32), names=('heads', 'relpos_buckets'), mesh=None, rules=None)}}\n"
651-
]
652-
}
653-
],
618+
"outputs": [],
654619
"source": [
655620
"te_transformer = te_flax.TransformerLayer(\n",
656621
" hidden_size=hidden_size,\n",
@@ -662,39 +627,30 @@
662627
" use_bias=True\n",
663628
" )\n",
664629
"\n",
665-
"te_transformer_params = te_transformer.init(key, x, deterministic=False)\n",
666-
"\n",
667-
"te_transformer_params_template = te_transformer_params['params']\n",
668-
"print(f\"TE TransformerLayer parameter shapes: {jax.tree_util.tree_map(lambda x: x.shape, te_transformer_params_template)}\")"
630+
"te_transformer_params = te_transformer.init(key, x, deterministic=False)"
669631
]
670632
},
671633
{
672634
"cell_type": "code",
673-
"execution_count": 21,
635+
"execution_count": 39,
674636
"id": "6ec0f60e",
675637
"metadata": {},
676638
"outputs": [
677639
{
678640
"name": "stdout",
679641
"output_type": "stream",
680642
"text": [
681-
"Mean time: 11.84274673461914 ms\n"
643+
"Mean time: 12.37576961517334 ms\n"
682644
]
683645
}
684646
],
685647
"source": [
686-
"import quickstart_jax_utils\n",
687-
"importlib.reload(quickstart_jax_utils)\n",
688-
"\n",
689-
"shared_te_transformer_params = {}\n",
690-
"shared_te_transformer_params['params'] = utils.share_parameters_with_transformerlayer_te_model(basic_params, te_transformer_params_template)\n",
691-
"\n",
692648
"# Test forward pass\n",
693649
"y = te_transformer.apply(te_transformer_params, x, attention_mask=None, deterministic=True)\n",
694650
"\n",
695651
"utils.speedometer(\n",
696652
" model_apply_fn=te_transformer.apply,\n",
697-
" variables=shared_te_transformer_params, # Ensure the correct `params` is passed\n",
653+
" variables=te_transformer_params, # Ensure the correct `params` is passed\n",
698654
" input=x,\n",
699655
" output_grad=dy,\n",
700656
" dropout_key=dropout_key,\n",
@@ -730,18 +686,10 @@
730686
},
731687
{
732688
"cell_type": "code",
733-
"execution_count": 22,
689+
"execution_count": 40,
734690
"id": "b2aaa8ef",
735691
"metadata": {},
736-
"outputs": [
737-
{
738-
"name": "stdout",
739-
"output_type": "stream",
740-
"text": [
741-
"TE TransformerLayer vars: {'fp8_metas': {'attention': {'out': {'grad_amax_history': (16,), 'grad_scale': (1,), 'kernel_amax_history': (16,), 'kernel_scale': (1,), 'x_amax_history': (16,), 'x_scale': (1,)}, 'qkv': {'grad_amax_history': (16,), 'grad_scale': (1,), 'kernel_amax_history': (16,), 'kernel_scale': (1,), 'x_amax_history': (16,), 'x_scale': (1,)}}, 'mlp': {'grad_0_amax_history': (16,), 'grad_0_scale': (1,), 'grad_1_amax_history': (16,), 'grad_1_scale': (1,), 'kernel_0_amax_history': (16,), 'kernel_0_scale': (1,), 'kernel_1_amax_history': (16,), 'kernel_1_scale': (1,), 'x_0_amax_history': (16,), 'x_0_scale': (1,), 'x_1_amax_history': (16,), 'x_1_scale': (1,)}}, 'params': {'attention': {'out': {'bias': LogicallyPartitioned(value=(4096,), names=('nvte_w_no_shard',), mesh=None, rules=None), 'kernel': LogicallyPartitioned(value=(4096, 4096), names=('nvte_w_tp', 'nvte_w_fsdp'), mesh=None, rules=None)}, 'qkv': {'bias': LogicallyPartitioned(value=(3, 4096), names=('nvte_w_joined', 'nvte_w_tp'), mesh=None, rules=None), 'kernel': LogicallyPartitioned(value=(4096, 3, 4096), names=('nvte_w_fsdp', 'nvte_w_joined', 'nvte_w_tp'), mesh=None, rules=None), 'ln_bias': LogicallyPartitioned(value=(4096,), names=('nvte_w_no_shard',), mesh=None, rules=None), 'scale': LogicallyPartitioned(value=(4096,), names=('nvte_w_no_shard',), mesh=None, rules=None)}}, 'mlp': {'ln_bias': LogicallyPartitioned(value=(4096,), names=('nvte_w_no_shard',), mesh=None, rules=None), 'scale': LogicallyPartitioned(value=(4096,), names=('nvte_w_no_shard',), mesh=None, rules=None), 'wi_bias': LogicallyPartitioned(value=(1, 16384), names=('nvte_w_joined', 'nvte_w_tp'), mesh=None, rules=None), 'wi_kernel': LogicallyPartitioned(value=(4096, 1, 16384), names=('nvte_w_fsdp', 'nvte_w_joined', 'nvte_w_tp'), mesh=None, rules=None), 'wo_bias': LogicallyPartitioned(value=(4096,), names=('nvte_w_no_shard',), mesh=None, rules=None), 'wo_kernel': LogicallyPartitioned(value=(16384, 4096), names=('nvte_w_tp', 
'nvte_w_fsdp'), mesh=None, rules=None)}, 'relpos_bias': {'rel_embedding': LogicallyPartitioned(value=(32, 32), names=('heads', 'relpos_buckets'), mesh=None, rules=None)}}}\n"
742-
]
743-
}
744-
],
692+
"outputs": [],
745693
"source": [
746694
"from transformer_engine.common.recipe import Format, DelayedScaling\n",
747695
"\n",
@@ -760,40 +708,28 @@
760708
"\n",
761709
"with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):\n",
762710
" te_transformer_params = te_transformer.init(key, x, deterministic=False)\n",
763-
" \n",
764-
" # When using FP8, we need to preserve the fp8_metas collection\n",
765-
" # that was created during initialization within the fp8_autocast context.\n",
766-
" # Only the 'params' are shared from basic_params, but fp8_metas must come from\n",
767-
" # the FP8-initialized model.\n",
768-
" shared_te_transformer_params = {}\n",
769-
" shared_te_transformer_params['params'] = utils.share_parameters_with_transformerlayer_te_model(basic_params, te_transformer_params_template)\n",
770-
" print(f\"TE TransformerLayer vars: {jax.tree_util.tree_map(lambda x: x.shape, te_transformer_params)}\")\n",
771-
"\n",
772-
" if 'fp8_metas' in te_transformer_params:\n",
773-
" shared_te_transformer_params['fp8_metas'] = te_transformer_params['fp8_metas']\n",
774-
"\n",
775711
" y = te_transformer.apply(te_transformer_params, x, attention_mask=None, deterministic=True)"
776712
]
777713
},
778714
{
779715
"cell_type": "code",
780-
"execution_count": 23,
716+
"execution_count": 41,
781717
"id": "b9cdbf22",
782718
"metadata": {},
783719
"outputs": [
784720
{
785721
"name": "stdout",
786722
"output_type": "stream",
787723
"text": [
788-
"Mean time: 7.96757698059082 ms\n"
724+
"Mean time: 7.956786155700684 ms\n"
789725
]
790726
}
791727
],
792728
"source": [
793729
"utils.speedometer(\n",
794730
" model_apply_fn=te_transformer.apply,\n",
795731
" model_init_fn=te_transformer.init,\n",
796-
" variables=shared_te_transformer_params, # Includes both params and fp8_metas\n",
732+
" variables=te_transformer_params, # Includes both params and fp8_metas\n",
797733
" input=x,\n",
798734
" output_grad=dy,\n",
799735
" dropout_key=dropout_key,\n",
@@ -808,6 +744,18 @@
808744
"display_name": "Python 3 (ipykernel)",
809745
"language": "python",
810746
"name": "python3"
747+
},
748+
"language_info": {
749+
"codemirror_mode": {
750+
"name": "ipython",
751+
"version": 3
752+
},
753+
"file_extension": ".py",
754+
"mimetype": "text/x-python",
755+
"name": "python",
756+
"nbconvert_exporter": "python",
757+
"pygments_lexer": "ipython3",
758+
"version": "3.12.3"
811759
}
812760
},
813761
"nbformat": 4,

0 commit comments

Comments
 (0)