|
51 | 51 | }, |
52 | 52 | { |
53 | 53 | "cell_type": "code", |
54 | | - "execution_count": null, |
| 54 | + "execution_count": 25, |
55 | 55 | "id": "881fd001", |
56 | 56 | "metadata": {}, |
57 | 57 | "outputs": [], |
|
65 | 65 | }, |
66 | 66 | { |
67 | 67 | "cell_type": "code", |
68 | | - "execution_count": 8, |
| 68 | + "execution_count": 26, |
69 | 69 | "id": "d5284a38", |
70 | 70 | "metadata": {}, |
71 | 71 | "outputs": [], |
|
79 | 79 | }, |
80 | 80 | { |
81 | 81 | "cell_type": "code", |
82 | | - "execution_count": 9, |
| 82 | + "execution_count": 27, |
83 | 83 | "id": "a4d1cfdc", |
84 | 84 | "metadata": {}, |
85 | 85 | "outputs": [], |
|
173 | 173 | }, |
174 | 174 | { |
175 | 175 | "cell_type": "code", |
176 | | - "execution_count": 10, |
| 176 | + "execution_count": 28, |
177 | 177 | "id": "8b44649d", |
178 | 178 | "metadata": {}, |
179 | 179 | "outputs": [], |
|
194 | 194 | }, |
195 | 195 | { |
196 | 196 | "cell_type": "code", |
197 | | - "execution_count": 11, |
| 197 | + "execution_count": 29, |
198 | 198 | "id": "e44ed26d", |
199 | 199 | "metadata": {}, |
200 | 200 | "outputs": [ |
|
224 | 224 | }, |
225 | 225 | { |
226 | 226 | "cell_type": "code", |
227 | | - "execution_count": 12, |
| 227 | + "execution_count": 30, |
228 | 228 | "id": "de91af7a", |
229 | 229 | "metadata": {}, |
230 | 230 | "outputs": [ |
|
250 | 250 | }, |
251 | 251 | { |
252 | 252 | "cell_type": "code", |
253 | | - "execution_count": 13, |
| 253 | + "execution_count": 31, |
254 | 254 | "id": "037bc8d9", |
255 | 255 | "metadata": {}, |
256 | 256 | "outputs": [ |
257 | 257 | { |
258 | 258 | "name": "stdout", |
259 | 259 | "output_type": "stream", |
260 | 260 | "text": [ |
261 | | - "Mean time: 27.269372940063477 ms\n" |
| 261 | + "Mean time: 28.229827880859375 ms\n" |
262 | 262 | ] |
263 | 263 | } |
264 | 264 | ], |
|
308 | 308 | }, |
309 | 309 | { |
310 | 310 | "cell_type": "code", |
311 | | - "execution_count": 14, |
| 311 | + "execution_count": 32, |
312 | 312 | "id": "bed20d6b", |
313 | 313 | "metadata": {}, |
314 | 314 | "outputs": [], |
|
328 | 328 | }, |
329 | 329 | { |
330 | 330 | "cell_type": "code", |
331 | | - "execution_count": 15, |
| 331 | + "execution_count": 33, |
332 | 332 | "id": "56105579", |
333 | 333 | "metadata": {}, |
334 | 334 | "outputs": [], |
|
424 | 424 | }, |
425 | 425 | { |
426 | 426 | "cell_type": "code", |
427 | | - "execution_count": 16, |
| 427 | + "execution_count": 34, |
428 | 428 | "id": "5146cd99", |
429 | 429 | "metadata": {}, |
430 | 430 | "outputs": [ |
431 | 431 | { |
432 | 432 | "name": "stdout", |
433 | 433 | "output_type": "stream", |
434 | 434 | "text": [ |
435 | | - "Basic TE parameter shapes: {'BasicTEMLP_0': {'DenseGeneral_0': {'bias': LogicallyPartitioned(value=(16384,), names=(), mesh=None, rules=None), 'kernel': LogicallyPartitioned(value=(4096, 16384), names=(), mesh=None, rules=None)}, 'DenseGeneral_1': {'bias': LogicallyPartitioned(value=(4096,), names=(), mesh=None, rules=None), 'kernel': LogicallyPartitioned(value=(16384, 4096), names=(), mesh=None, rules=None)}}, 'DenseGeneral_0': {'bias': LogicallyPartitioned(value=(12288,), names=(), mesh=None, rules=None), 'kernel': LogicallyPartitioned(value=(4096, 12288), names=(), mesh=None, rules=None)}, 'DenseGeneral_1': {'bias': LogicallyPartitioned(value=(4096,), names=(), mesh=None, rules=None), 'kernel': LogicallyPartitioned(value=(4096, 4096), names=(), mesh=None, rules=None)}, 'LayerNorm_0': {'ln_bias': LogicallyPartitioned(value=(4096,), names=('embed',), mesh=None, rules=None), 'scale': LogicallyPartitioned(value=(4096,), names=('embed',), mesh=None, rules=None)}, 'LayerNorm_1': {'ln_bias': LogicallyPartitioned(value=(4096,), names=('embed',), mesh=None, rules=None), 'scale': LogicallyPartitioned(value=(4096,), names=('embed',), mesh=None, rules=None)}}\n", |
436 | | - "Mean time: 17.397570610046387 ms\n" |
| 435 | + "Mean time: 17.390952110290527 ms\n" |
437 | 436 | ] |
438 | 437 | } |
439 | 438 | ], |
|
449 | 448 | "\n", |
450 | 449 | "te_params = basic_te_transformer.init(key, x, attention_mask=None, deterministic=False)\n", |
451 | 450 | "\n", |
452 | | - "# Extract the 'params' pytrees\n", |
453 | | - "basic_params = params['params']\n", |
454 | | - "te_params_template = te_params['params']\n", |
455 | | - "\n", |
456 | | - "print(f\"Basic TE parameter shapes: {jax.tree_util.tree_map(lambda x: x.shape, te_params_template)}\")\n", |
457 | | - "\n", |
458 | | - "shared_te_params = {}\n", |
459 | | - "shared_te_params['params'] = utils.share_parameters_with_basic_te_model(basic_params, te_params_template)\n", |
460 | | - "\n", |
461 | 451 | "# Test forward pass\n", |
462 | 452 | "y = basic_te_transformer.apply(te_params, x, attention_mask=None, deterministic=True)\n", |
463 | 453 | "\n", |
464 | 454 | "utils.speedometer(\n", |
465 | 455 | " model_apply_fn=basic_te_transformer.apply,\n", |
466 | | - " variables=shared_te_params, # Ensure the correct `params` is passed\n", |
| 456 | + " variables=te_params, # Ensure the correct `params` is passed\n", |
467 | 457 | " input=x,\n", |
468 | 458 | " output_grad=dy,\n", |
469 | 459 | " dropout_key=dropout_key,\n", |
|
502 | 492 | }, |
503 | 493 | { |
504 | 494 | "cell_type": "code", |
505 | | - "execution_count": 17, |
| 495 | + "execution_count": 35, |
506 | 496 | "id": "11203785", |
507 | 497 | "metadata": {}, |
508 | 498 | "outputs": [], |
|
570 | 560 | }, |
571 | 561 | { |
572 | 562 | "cell_type": "code", |
573 | | - "execution_count": 18, |
| 563 | + "execution_count": 36, |
574 | 564 | "id": "114de14f", |
575 | 565 | "metadata": {}, |
576 | | - "outputs": [ |
577 | | - { |
578 | | - "name": "stdout", |
579 | | - "output_type": "stream", |
580 | | - "text": [ |
581 | | - "Fused TE parameter shapes: {'DenseGeneral_0': {'bias': LogicallyPartitioned(value=(4096,), names=(), mesh=None, rules=None), 'kernel': LogicallyPartitioned(value=(4096, 4096), names=(), mesh=None, rules=None)}, 'LayerNormDenseGeneral_0': {'bias': LogicallyPartitioned(value=(12288,), names=(), mesh=None, rules=None), 'kernel': LogicallyPartitioned(value=(4096, 12288), names=(), mesh=None, rules=None), 'ln_bias': LogicallyPartitioned(value=(4096,), names=('embed',), mesh=None, rules=None), 'scale': LogicallyPartitioned(value=(4096,), names=('embed',), mesh=None, rules=None)}, 'LayerNormMLP_0': {'ln_bias': LogicallyPartitioned(value=(4096,), names=('embed',), mesh=None, rules=None), 'scale': LogicallyPartitioned(value=(4096,), names=('embed',), mesh=None, rules=None), 'wi_bias': LogicallyPartitioned(value=(1, 16384), names=('act', 'mlp'), mesh=None, rules=None), 'wi_kernel': LogicallyPartitioned(value=(4096, 1, 16384), names=('embed', 'act', 'mlp'), mesh=None, rules=None), 'wo_bias': LogicallyPartitioned(value=(4096,), names=('embed',), mesh=None, rules=None), 'wo_kernel': LogicallyPartitioned(value=(16384, 4096), names=('mlp', 'embed'), mesh=None, rules=None)}}\n" |
582 | | - ] |
583 | | - } |
584 | | - ], |
| 566 | + "outputs": [], |
585 | 567 | "source": [ |
586 | | - "import quickstart_jax_utils\n", |
587 | | - "importlib.reload(quickstart_jax_utils)\n", |
588 | | - "\n", |
589 | 568 | "fused_te_transformer = FusedTETransformerLayer(\n", |
590 | 569 | " hidden_size, \n", |
591 | 570 | " ffn_hidden_size, \n", |
592 | 571 | " num_attention_heads\n", |
593 | 572 | ")\n", |
594 | 573 | "\n", |
595 | | - "fused_te_params = fused_te_transformer.init(key, x, attention_mask=None, deterministic=False)\n", |
596 | | - "\n", |
597 | | - "fused_te_params_template = fused_te_params['params']\n", |
598 | | - "print(f\"Fused TE parameter shapes: {jax.tree_util.tree_map(lambda x: x.shape, fused_te_params_template)}\")" |
| 574 | + "fused_te_params = fused_te_transformer.init(key, x, attention_mask=None, deterministic=False)" |
599 | 575 | ] |
600 | 576 | }, |
601 | 577 | { |
602 | 578 | "cell_type": "code", |
603 | | - "execution_count": 19, |
| 579 | + "execution_count": 37, |
604 | 580 | "id": "6b0c705e", |
605 | 581 | "metadata": {}, |
606 | 582 | "outputs": [ |
607 | 583 | { |
608 | 584 | "name": "stdout", |
609 | 585 | "output_type": "stream", |
610 | 586 | "text": [ |
611 | | - "Mean time: 18.0991792678833 ms\n" |
| 587 | + "Mean time: 18.087706565856934 ms\n" |
612 | 588 | ] |
613 | 589 | } |
614 | 590 | ], |
615 | 591 | "source": [ |
616 | | - "shared_fused_te_params = {}\n", |
617 | | - "shared_fused_te_params['params'] = utils.share_fused_parameters_with_basic_te_model(basic_params, fused_te_params_template)\n", |
618 | | - "\n", |
619 | 592 | "# Test forward pass\n", |
620 | 593 | "y = fused_te_transformer.apply(fused_te_params, x, attention_mask=None, deterministic=True)\n", |
621 | 594 | "\n", |
622 | 595 | "utils.speedometer(\n", |
623 | 596 | " model_apply_fn=fused_te_transformer.apply,\n", |
624 | | - " variables=shared_fused_te_params, # Ensure the correct `params` is passed\n", |
| 597 | + " variables=fused_te_params, # Ensure the correct `params` is passed\n", |
625 | 598 | " input=x,\n", |
626 | 599 | " output_grad=dy,\n", |
627 | 600 | " dropout_key=dropout_key,\n", |
|
639 | 612 | }, |
640 | 613 | { |
641 | 614 | "cell_type": "code", |
642 | | - "execution_count": 20, |
| 615 | + "execution_count": 38, |
643 | 616 | "id": "7496b159", |
644 | 617 | "metadata": {}, |
645 | | - "outputs": [ |
646 | | - { |
647 | | - "name": "stdout", |
648 | | - "output_type": "stream", |
649 | | - "text": [ |
650 | | - "TE TransformerLayer parameter shapes: {'attention': {'out': {'bias': LogicallyPartitioned(value=(4096,), names=('nvte_w_no_shard',), mesh=None, rules=None), 'kernel': LogicallyPartitioned(value=(4096, 4096), names=('nvte_w_tp', 'nvte_w_fsdp'), mesh=None, rules=None)}, 'qkv': {'bias': LogicallyPartitioned(value=(3, 4096), names=('nvte_w_joined', 'nvte_w_tp'), mesh=None, rules=None), 'kernel': LogicallyPartitioned(value=(4096, 3, 4096), names=('nvte_w_fsdp', 'nvte_w_joined', 'nvte_w_tp'), mesh=None, rules=None), 'ln_bias': LogicallyPartitioned(value=(4096,), names=('nvte_w_no_shard',), mesh=None, rules=None), 'scale': LogicallyPartitioned(value=(4096,), names=('nvte_w_no_shard',), mesh=None, rules=None)}}, 'mlp': {'ln_bias': LogicallyPartitioned(value=(4096,), names=('nvte_w_no_shard',), mesh=None, rules=None), 'scale': LogicallyPartitioned(value=(4096,), names=('nvte_w_no_shard',), mesh=None, rules=None), 'wi_bias': LogicallyPartitioned(value=(1, 16384), names=('nvte_w_joined', 'nvte_w_tp'), mesh=None, rules=None), 'wi_kernel': LogicallyPartitioned(value=(4096, 1, 16384), names=('nvte_w_fsdp', 'nvte_w_joined', 'nvte_w_tp'), mesh=None, rules=None), 'wo_bias': LogicallyPartitioned(value=(4096,), names=('nvte_w_no_shard',), mesh=None, rules=None), 'wo_kernel': LogicallyPartitioned(value=(16384, 4096), names=('nvte_w_tp', 'nvte_w_fsdp'), mesh=None, rules=None)}, 'relpos_bias': {'rel_embedding': LogicallyPartitioned(value=(32, 32), names=('heads', 'relpos_buckets'), mesh=None, rules=None)}}\n" |
651 | | - ] |
652 | | - } |
653 | | - ], |
| 618 | + "outputs": [], |
654 | 619 | "source": [ |
655 | 620 | "te_transformer = te_flax.TransformerLayer(\n", |
656 | 621 | " hidden_size=hidden_size,\n", |
|
662 | 627 | " use_bias=True\n", |
663 | 628 | " )\n", |
664 | 629 | "\n", |
665 | | - "te_transformer_params = te_transformer.init(key, x, deterministic=False)\n", |
666 | | - "\n", |
667 | | - "te_transformer_params_template = te_transformer_params['params']\n", |
668 | | - "print(f\"TE TransformerLayer parameter shapes: {jax.tree_util.tree_map(lambda x: x.shape, te_transformer_params_template)}\")" |
| 630 | + "te_transformer_params = te_transformer.init(key, x, deterministic=False)" |
669 | 631 | ] |
670 | 632 | }, |
671 | 633 | { |
672 | 634 | "cell_type": "code", |
673 | | - "execution_count": 21, |
| 635 | + "execution_count": 39, |
674 | 636 | "id": "6ec0f60e", |
675 | 637 | "metadata": {}, |
676 | 638 | "outputs": [ |
677 | 639 | { |
678 | 640 | "name": "stdout", |
679 | 641 | "output_type": "stream", |
680 | 642 | "text": [ |
681 | | - "Mean time: 11.84274673461914 ms\n" |
| 643 | + "Mean time: 12.37576961517334 ms\n" |
682 | 644 | ] |
683 | 645 | } |
684 | 646 | ], |
685 | 647 | "source": [ |
686 | | - "import quickstart_jax_utils\n", |
687 | | - "importlib.reload(quickstart_jax_utils)\n", |
688 | | - "\n", |
689 | | - "shared_te_transformer_params = {}\n", |
690 | | - "shared_te_transformer_params['params'] = utils.share_parameters_with_transformerlayer_te_model(basic_params, te_transformer_params_template)\n", |
691 | | - "\n", |
692 | 648 | "# Test forward pass\n", |
693 | 649 | "y = te_transformer.apply(te_transformer_params, x, attention_mask=None, deterministic=True)\n", |
694 | 650 | "\n", |
695 | 651 | "utils.speedometer(\n", |
696 | 652 | " model_apply_fn=te_transformer.apply,\n", |
697 | | - " variables=shared_te_transformer_params, # Ensure the correct `params` is passed\n", |
| 653 | + " variables=te_transformer_params, # Ensure the correct `params` is passed\n", |
698 | 654 | " input=x,\n", |
699 | 655 | " output_grad=dy,\n", |
700 | 656 | " dropout_key=dropout_key,\n", |
|
730 | 686 | }, |
731 | 687 | { |
732 | 688 | "cell_type": "code", |
733 | | - "execution_count": 22, |
| 689 | + "execution_count": 40, |
734 | 690 | "id": "b2aaa8ef", |
735 | 691 | "metadata": {}, |
736 | | - "outputs": [ |
737 | | - { |
738 | | - "name": "stdout", |
739 | | - "output_type": "stream", |
740 | | - "text": [ |
741 | | - "TE TransformerLayer vars: {'fp8_metas': {'attention': {'out': {'grad_amax_history': (16,), 'grad_scale': (1,), 'kernel_amax_history': (16,), 'kernel_scale': (1,), 'x_amax_history': (16,), 'x_scale': (1,)}, 'qkv': {'grad_amax_history': (16,), 'grad_scale': (1,), 'kernel_amax_history': (16,), 'kernel_scale': (1,), 'x_amax_history': (16,), 'x_scale': (1,)}}, 'mlp': {'grad_0_amax_history': (16,), 'grad_0_scale': (1,), 'grad_1_amax_history': (16,), 'grad_1_scale': (1,), 'kernel_0_amax_history': (16,), 'kernel_0_scale': (1,), 'kernel_1_amax_history': (16,), 'kernel_1_scale': (1,), 'x_0_amax_history': (16,), 'x_0_scale': (1,), 'x_1_amax_history': (16,), 'x_1_scale': (1,)}}, 'params': {'attention': {'out': {'bias': LogicallyPartitioned(value=(4096,), names=('nvte_w_no_shard',), mesh=None, rules=None), 'kernel': LogicallyPartitioned(value=(4096, 4096), names=('nvte_w_tp', 'nvte_w_fsdp'), mesh=None, rules=None)}, 'qkv': {'bias': LogicallyPartitioned(value=(3, 4096), names=('nvte_w_joined', 'nvte_w_tp'), mesh=None, rules=None), 'kernel': LogicallyPartitioned(value=(4096, 3, 4096), names=('nvte_w_fsdp', 'nvte_w_joined', 'nvte_w_tp'), mesh=None, rules=None), 'ln_bias': LogicallyPartitioned(value=(4096,), names=('nvte_w_no_shard',), mesh=None, rules=None), 'scale': LogicallyPartitioned(value=(4096,), names=('nvte_w_no_shard',), mesh=None, rules=None)}}, 'mlp': {'ln_bias': LogicallyPartitioned(value=(4096,), names=('nvte_w_no_shard',), mesh=None, rules=None), 'scale': LogicallyPartitioned(value=(4096,), names=('nvte_w_no_shard',), mesh=None, rules=None), 'wi_bias': LogicallyPartitioned(value=(1, 16384), names=('nvte_w_joined', 'nvte_w_tp'), mesh=None, rules=None), 'wi_kernel': LogicallyPartitioned(value=(4096, 1, 16384), names=('nvte_w_fsdp', 'nvte_w_joined', 'nvte_w_tp'), mesh=None, rules=None), 'wo_bias': LogicallyPartitioned(value=(4096,), names=('nvte_w_no_shard',), mesh=None, rules=None), 'wo_kernel': LogicallyPartitioned(value=(16384, 4096), names=('nvte_w_tp', 'nvte_w_fsdp'), mesh=None, rules=None)}, 'relpos_bias': {'rel_embedding': LogicallyPartitioned(value=(32, 32), names=('heads', 'relpos_buckets'), mesh=None, rules=None)}}}\n" |
742 | | - ] |
743 | | - } |
744 | | - ], |
| 692 | + "outputs": [], |
745 | 693 | "source": [ |
746 | 694 | "from transformer_engine.common.recipe import Format, DelayedScaling\n", |
747 | 695 | "\n", |
|
760 | 708 | "\n", |
761 | 709 | "with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):\n", |
762 | 710 | " te_transformer_params = te_transformer.init(key, x, deterministic=False)\n", |
763 | | - " \n", |
764 | | - " # When using FP8, we need to preserve the fp8_metas collection\n", |
765 | | - " # that was created during initialization within the fp8_autocast context.\n", |
766 | | - " # Only the 'params' are shared from basic_params, but fp8_metas must come from\n", |
767 | | - " # the FP8-initialized model.\n", |
768 | | - " shared_te_transformer_params = {}\n", |
769 | | - " shared_te_transformer_params['params'] = utils.share_parameters_with_transformerlayer_te_model(basic_params, te_transformer_params_template)\n", |
770 | | - " print(f\"TE TransformerLayer vars: {jax.tree_util.tree_map(lambda x: x.shape, te_transformer_params)}\")\n", |
771 | | - "\n", |
772 | | - " if 'fp8_metas' in te_transformer_params:\n", |
773 | | - " shared_te_transformer_params['fp8_metas'] = te_transformer_params['fp8_metas']\n", |
774 | | - "\n", |
775 | 711 | " y = te_transformer.apply(te_transformer_params, x, attention_mask=None, deterministic=True)" |
776 | 712 | ] |
777 | 713 | }, |
778 | 714 | { |
779 | 715 | "cell_type": "code", |
780 | | - "execution_count": 23, |
| 716 | + "execution_count": 41, |
781 | 717 | "id": "b9cdbf22", |
782 | 718 | "metadata": {}, |
783 | 719 | "outputs": [ |
784 | 720 | { |
785 | 721 | "name": "stdout", |
786 | 722 | "output_type": "stream", |
787 | 723 | "text": [ |
788 | | - "Mean time: 7.96757698059082 ms\n" |
| 724 | + "Mean time: 7.956786155700684 ms\n" |
789 | 725 | ] |
790 | 726 | } |
791 | 727 | ], |
792 | 728 | "source": [ |
793 | 729 | "utils.speedometer(\n", |
794 | 730 | " model_apply_fn=te_transformer.apply,\n", |
795 | 731 | " model_init_fn=te_transformer.init,\n", |
796 | | - " variables=shared_te_transformer_params, # Includes both params and fp8_metas\n", |
| 732 | + " variables=te_transformer_params, # Includes both params and fp8_metas\n", |
797 | 733 | " input=x,\n", |
798 | 734 | " output_grad=dy,\n", |
799 | 735 | " dropout_key=dropout_key,\n", |
|
808 | 744 | "display_name": "Python 3 (ipykernel)", |
809 | 745 | "language": "python", |
810 | 746 | "name": "python3" |
| 747 | + }, |
| 748 | + "language_info": { |
| 749 | + "codemirror_mode": { |
| 750 | + "name": "ipython", |
| 751 | + "version": 3 |
| 752 | + }, |
| 753 | + "file_extension": ".py", |
| 754 | + "mimetype": "text/x-python", |
| 755 | + "name": "python", |
| 756 | + "nbconvert_exporter": "python", |
| 757 | + "pygments_lexer": "ipython3", |
| 758 | + "version": "3.12.3" |
811 | 759 | } |
812 | 760 | }, |
813 | 761 | "nbformat": 4, |
|