From 49dbbbb41083fd3e41e7cfbc6568c82125b3850b Mon Sep 17 00:00:00 2001 From: Ivan Kobzarev Date: Fri, 19 Dec 2025 19:46:42 +0000 Subject: [PATCH] [WIP][overlap] Overlap simulation on 1d, 2d variants of llama3 for 64, 256 gpus stack-info: PR: https://github.com/meta-pytorch/autoparallel/pull/289, branch: IvanKobzarev/stack/12 --- .../tools/overlap_simulator/colls32_8.table | 8 + .../tools/overlap_simulator/colls8_8.table | 7 + .../repro_llama3_8b_bw_256_1d_32layers.py | 8954 ++++++++++++ .../repro_llama3_8b_bw_256_2d_32layers.py | 11446 ++++++++++++++++ .../repro_llama3_8b_bw_64_1d_32layers.py | 8953 ++++++++++++ .../repro_llama3_8b_bw_64_2d_32layers.py | 5783 ++++++++ .../repro_llama3_8b_fw_256_1d_32layers.py | 4153 ++++++ .../repro_llama3_8b_fw_256_2d_32layers.py | 5658 ++++++++ .../repro_llama3_8b_fw_64_1d_32layers.py | 4153 ++++++ .../repro_llama3_8b_fw_64_2d_32layers.py | 5657 ++++++++ autoparallel/tools/overlap_simulator/run.py | 809 ++ 11 files changed, 55581 insertions(+) create mode 100644 autoparallel/tools/overlap_simulator/colls32_8.table create mode 100644 autoparallel/tools/overlap_simulator/colls8_8.table create mode 100644 autoparallel/tools/overlap_simulator/repro_llama3_8b_bw_256_1d_32layers.py create mode 100644 autoparallel/tools/overlap_simulator/repro_llama3_8b_bw_256_2d_32layers.py create mode 100644 autoparallel/tools/overlap_simulator/repro_llama3_8b_bw_64_1d_32layers.py create mode 100644 autoparallel/tools/overlap_simulator/repro_llama3_8b_bw_64_2d_32layers.py create mode 100644 autoparallel/tools/overlap_simulator/repro_llama3_8b_fw_256_1d_32layers.py create mode 100644 autoparallel/tools/overlap_simulator/repro_llama3_8b_fw_256_2d_32layers.py create mode 100644 autoparallel/tools/overlap_simulator/repro_llama3_8b_fw_64_1d_32layers.py create mode 100644 autoparallel/tools/overlap_simulator/repro_llama3_8b_fw_64_2d_32layers.py create mode 100644 autoparallel/tools/overlap_simulator/run.py diff --git 
a/autoparallel/tools/overlap_simulator/colls32_8.table b/autoparallel/tools/overlap_simulator/colls32_8.table new file mode 100644 index 00000000..50426ef2 --- /dev/null +++ b/autoparallel/tools/overlap_simulator/colls32_8.table @@ -0,0 +1,8 @@ + Group Group Size Collective 1MB (ms) 2MB (ms) 4MB (ms) 8MB (ms) 16MB (ms) 32MB (ms) 64MB (ms) 128MB (ms) 256MB (ms) 512MB (ms) 1024MB (ms) 2048MB (ms) +------- ------------ -------------------------- ---------- ---------- ---------- ---------- ----------- ----------- ----------- ------------ ------------ ------------ ------------- ------------- + 1 8 all_gather_into_tensor 0.0495 0.0716 0.1138 0.1953 0.3584 0.6846 1.3371 2.642 5.2518 10.4714 20.9105 41.7888 + 1 8 reduce_scatter_tensor 0.0173 0.0238 0.0368 0.0495 0.0716 0.1138 0.1953 0.3584 0.6846 1.3371 2.642 5.2518 + 1 8 all_reduce 0.028 0.041 0.0628 0.0849 0.1292 0.2179 0.3822 0.7084 1.3609 2.6658 5.2756 10.4952 + 0 32 all_gather_into_tensor 1.0136 1.7497 3.1512 5.86 11.2777 22.113 43.7835 87.1247 173.807 347.171 693.901 1387.36 + 0 32 reduce_scatter_tensor 0.2114 0.2612 0.3608 0.4615 0.6455 1.0136 1.7497 3.1512 5.86 11.2777 22.113 43.7835 + 0 32 all_gather_into_tensor_out 1.0136 1.7497 3.1512 5.86 11.2777 22.113 43.7835 87.1247 173.807 347.171 693.901 1387.36 diff --git a/autoparallel/tools/overlap_simulator/colls8_8.table b/autoparallel/tools/overlap_simulator/colls8_8.table new file mode 100644 index 00000000..9d75dac7 --- /dev/null +++ b/autoparallel/tools/overlap_simulator/colls8_8.table @@ -0,0 +1,7 @@ + Group Group Size Collective 1MB (ms) 2MB (ms) 4MB (ms) 8MB (ms) 16MB (ms) 32MB (ms) 64MB (ms) 128MB (ms) 256MB (ms) 512MB (ms) 1024MB (ms) 2048MB (ms) +------- ------------ -------------------------- ---------- ---------- ---------- ---------- ----------- ----------- ----------- ------------ ------------ ------------ ------------- ------------- + 1 8 all_reduce 0.028 0.041 0.0628 0.0849 0.1292 0.2179 0.3822 0.7084 1.3609 2.6658 5.2756 10.4952 + 1 8 
all_gather_into_tensor 0.0495 0.0716 0.1138 0.1953 0.3584 0.6846 1.3371 2.642 5.2518 10.4714 20.9105 41.7888 + 0 8 reduce_scatter_tensor 0.0866 0.1151 0.1566 0.2397 0.4059 0.7181 1.3297 2.5531 4.9998 9.8931 19.6798 39.2532 + 0 8 all_gather_into_tensor_out 0.2397 0.4059 0.7181 1.3297 2.5531 4.9998 9.8931 19.6798 39.2532 78.4001 156.694 313.281 + 0 8 all_gather_into_tensor 0.2397 0.4059 0.7181 1.3297 2.5531 4.9998 9.8931 19.6798 39.2532 78.4001 156.694 313.281 diff --git a/autoparallel/tools/overlap_simulator/repro_llama3_8b_bw_256_1d_32layers.py b/autoparallel/tools/overlap_simulator/repro_llama3_8b_bw_256_1d_32layers.py new file mode 100644 index 00000000..ed3b65b8 --- /dev/null +++ b/autoparallel/tools/overlap_simulator/repro_llama3_8b_bw_256_1d_32layers.py @@ -0,0 +1,8954 @@ +import torch +from torch.nn import * +from torch import tensor, device + + +class Repro(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + + def forward(self, primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8, primals_9, primals_10, primals_11, primals_12, primals_13, primals_14, primals_15, primals_16, primals_17, primals_18, primals_19, primals_20, primals_21, primals_22, primals_23, primals_24, primals_25, primals_26, primals_27, primals_28, primals_29, primals_30, primals_31, primals_32, primals_33, primals_34, primals_35, primals_36, primals_37, primals_38, primals_39, primals_40, primals_41, primals_42, primals_43, primals_44, primals_45, primals_46, primals_47, primals_48, primals_49, primals_50, primals_51, primals_52, primals_53, primals_54, primals_55, primals_56, primals_57, primals_58, primals_59, primals_60, primals_61, primals_62, primals_63, primals_64, primals_65, primals_66, primals_67, primals_68, primals_69, primals_70, primals_71, primals_72, primals_73, primals_74, primals_75, primals_76, primals_77, primals_78, primals_79, primals_80, primals_81, primals_82, primals_83, primals_84, primals_85, primals_86, 
primals_87, primals_88, primals_89, primals_90, primals_91, primals_92, primals_93, primals_94, primals_95, primals_96, primals_97, primals_98, primals_99, primals_100, primals_101, primals_102, primals_103, primals_104, primals_105, primals_106, primals_107, primals_108, primals_109, primals_110, primals_111, primals_112, primals_113, primals_114, primals_115, primals_116, primals_117, primals_118, primals_119, primals_120, primals_121, primals_122, primals_123, primals_124, primals_125, primals_126, primals_127, primals_128, primals_129, primals_130, primals_131, primals_132, primals_133, primals_134, primals_135, primals_136, primals_137, primals_138, primals_139, primals_140, primals_141, primals_142, primals_143, primals_144, primals_145, primals_146, primals_147, primals_148, primals_149, primals_150, primals_151, primals_152, primals_153, primals_154, primals_155, primals_156, primals_157, primals_158, primals_159, primals_160, primals_161, primals_162, primals_163, primals_164, primals_165, primals_166, primals_167, primals_168, primals_169, primals_170, primals_171, primals_172, primals_173, primals_174, primals_175, primals_176, primals_177, primals_178, primals_179, primals_180, primals_181, primals_182, primals_183, primals_184, primals_185, primals_186, primals_187, primals_188, primals_189, primals_190, primals_191, primals_192, primals_193, primals_194, primals_195, primals_196, primals_197, primals_198, primals_199, primals_200, primals_201, primals_202, primals_203, primals_204, primals_205, primals_206, primals_207, primals_208, primals_209, primals_210, primals_211, primals_212, primals_213, primals_214, primals_215, primals_216, primals_217, primals_218, primals_219, primals_220, primals_221, primals_222, primals_223, primals_224, primals_225, primals_226, primals_227, primals_228, primals_229, primals_230, primals_231, primals_232, primals_233, primals_234, primals_235, primals_236, primals_237, primals_238, primals_239, primals_240, 
primals_241, primals_242, primals_243, primals_244, primals_245, primals_246, primals_247, primals_248, primals_249, primals_250, primals_251, primals_252, primals_253, primals_254, primals_255, primals_256, primals_257, primals_258, primals_259, primals_260, primals_261, primals_262, primals_263, primals_264, primals_265, primals_266, primals_267, primals_268, primals_269, primals_270, primals_271, primals_272, primals_273, primals_274, primals_275, primals_276, primals_277, primals_278, primals_279, primals_280, primals_281, primals_282, primals_283, primals_284, primals_285, primals_286, primals_287, primals_288, primals_289, primals_290, primals_291, primals_292, primals_293, embedding, mm, mm_2, getitem, getitem_1, getitem_6, getitem_7, mm_4, add_3, mm_7, mm_9, getitem_9, getitem_10, getitem_15, getitem_16, mm_11, add_7, mm_14, mm_16, getitem_18, getitem_19, getitem_24, getitem_25, mm_18, add_11, mm_21, mm_23, getitem_27, getitem_28, getitem_33, getitem_34, mm_25, add_15, mm_28, mm_30, getitem_36, getitem_37, getitem_42, getitem_43, mm_32, add_19, mm_35, mm_37, getitem_45, getitem_46, getitem_51, getitem_52, mm_39, add_23, mm_42, mm_44, getitem_54, getitem_55, getitem_60, getitem_61, mm_46, add_27, mm_49, mm_51, getitem_63, getitem_64, getitem_69, getitem_70, mm_53, add_31, mm_56, mm_58, getitem_72, getitem_73, getitem_78, getitem_79, mm_60, add_35, mm_63, mm_65, getitem_81, getitem_82, getitem_87, getitem_88, mm_67, add_39, mm_70, mm_72, getitem_90, getitem_91, getitem_96, getitem_97, mm_74, add_43, mm_77, mm_79, getitem_99, getitem_100, getitem_105, getitem_106, mm_81, add_47, mm_84, mm_86, getitem_108, getitem_109, getitem_114, getitem_115, mm_88, add_51, mm_91, mm_93, getitem_117, getitem_118, getitem_123, getitem_124, mm_95, add_55, mm_98, mm_100, getitem_126, getitem_127, getitem_132, getitem_133, mm_102, add_59, mm_105, mm_107, getitem_135, getitem_136, getitem_141, getitem_142, mm_109, add_63, mm_112, mm_114, getitem_144, getitem_145, getitem_150, 
getitem_151, mm_116, add_67, mm_119, mm_121, getitem_153, getitem_154, getitem_159, getitem_160, mm_123, add_71, mm_126, mm_128, getitem_162, getitem_163, getitem_168, getitem_169, mm_130, add_75, mm_133, mm_135, getitem_171, getitem_172, getitem_177, getitem_178, mm_137, add_79, mm_140, mm_142, getitem_180, getitem_181, getitem_186, getitem_187, mm_144, add_83, mm_147, mm_149, getitem_189, getitem_190, getitem_195, getitem_196, mm_151, add_87, mm_154, mm_156, getitem_198, getitem_199, getitem_204, getitem_205, mm_158, add_91, mm_161, mm_163, getitem_207, getitem_208, getitem_213, getitem_214, mm_165, add_95, mm_168, mm_170, getitem_216, getitem_217, getitem_222, getitem_223, mm_172, add_99, mm_175, mm_177, getitem_225, getitem_226, getitem_231, getitem_232, mm_179, add_103, mm_182, mm_184, getitem_234, getitem_235, getitem_240, getitem_241, mm_186, add_107, mm_189, mm_191, getitem_243, getitem_244, getitem_249, getitem_250, mm_193, add_111, mm_196, mm_198, getitem_252, getitem_253, getitem_258, getitem_259, mm_200, add_115, mm_203, mm_205, getitem_261, getitem_262, getitem_267, getitem_268, mm_207, add_119, mm_210, mm_212, getitem_270, getitem_271, getitem_276, getitem_277, mm_214, add_123, mm_217, mm_219, getitem_279, getitem_280, getitem_285, getitem_286, mm_221, mm_223, rsqrt_64, view_1091, tangents_1): + view_1093 = torch.ops.aten.view.default(tangents_1, [16384, 128256]); tangents_1 = None + permute_353 = torch.ops.aten.permute.default(view_1093, [1, 0]) + mm_225 = torch.ops.aten.mm.default(permute_353, view_1091); permute_353 = view_1091 = None + convert_element_type_1060 = torch.ops.prims.convert_element_type.default(primals_293, torch.bfloat16); primals_293 = None + all_gather_into_tensor_290 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1060, 256, '0'); convert_element_type_1060 = None + wait_tensor_290 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_290); all_gather_into_tensor_290 = None + 
permute_352 = torch.ops.aten.permute.default(wait_tensor_290, [1, 0]); wait_tensor_290 = None + permute_355 = torch.ops.aten.permute.default(permute_352, [1, 0]); permute_352 = None + mm_226 = torch.ops.aten.mm.default(view_1093, permute_355); view_1093 = permute_355 = None + view_1094 = torch.ops.aten.view.default(mm_226, [2, 8192, 4096]); mm_226 = None + convert_element_type_1067 = torch.ops.prims.convert_element_type.default(mm_225, torch.float32); mm_225 = None + reduce_scatter_tensor = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1067, 'avg', 256, '0'); convert_element_type_1067 = None + wait_tensor_291 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor); reduce_scatter_tensor = None + convert_element_type_1068 = torch.ops.prims.convert_element_type.default(view_1094, torch.float32); view_1094 = None + convert_element_type_1057 = torch.ops.prims.convert_element_type.default(primals_292, torch.bfloat16); primals_292 = None + all_gather_into_tensor_289 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1057, 256, '0'); convert_element_type_1057 = None + wait_tensor_289 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_289); all_gather_into_tensor_289 = None + convert_element_type_1070 = torch.ops.prims.convert_element_type.default(wait_tensor_289, torch.float32); wait_tensor_289 = None + mul_258 = torch.ops.aten.mul.Tensor(convert_element_type_1068, convert_element_type_1070); convert_element_type_1070 = None + permute_347 = torch.ops.aten.permute.default(getitem_279, [0, 2, 1, 3]) + view_1075 = torch.ops.aten.view.default(permute_347, [2, 8192, -1]); permute_347 = None + convert_element_type_1040 = torch.ops.prims.convert_element_type.default(primals_287, torch.bfloat16); primals_287 = None + all_gather_into_tensor_284 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1040, 256, '0'); convert_element_type_1040 = None + 
wait_tensor_284 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_284); all_gather_into_tensor_284 = None + permute_348 = torch.ops.aten.permute.default(wait_tensor_284, [1, 0]); wait_tensor_284 = None + view_1077 = torch.ops.aten.view.default(view_1075, [16384, 4096]); view_1075 = None + mm_220 = torch.ops.aten.mm.default(view_1077, permute_348) + view_1078 = torch.ops.aten.view.default(mm_220, [2, 8192, 4096]); mm_220 = None + add_125 = torch.ops.aten.add.Tensor(add_123, view_1078); view_1078 = None + view_1088 = torch.ops.aten.view.default(mm_223, [2, 8192, 4096]); mm_223 = None + add_127 = torch.ops.aten.add.Tensor(add_125, view_1088); view_1088 = None + convert_element_type_1058 = torch.ops.prims.convert_element_type.default(add_127, torch.float32); add_127 = None + mul_256 = torch.ops.aten.mul.Tensor(convert_element_type_1058, rsqrt_64); convert_element_type_1058 = None + mul_260 = torch.ops.aten.mul.Tensor(mul_256, mul_258) + sum_1 = torch.ops.aten.sum.dim_IntList(mul_260, [2], True); mul_260 = None + div = torch.ops.aten.div.Tensor(mul_256, 4096) + mul_261 = torch.ops.aten.mul.Tensor(div, sum_1); div = sum_1 = None + sub = torch.ops.aten.sub.Tensor(mul_258, mul_261); mul_258 = mul_261 = None + mul_262 = torch.ops.aten.mul.Tensor(sub, rsqrt_64); sub = rsqrt_64 = None + mul_263 = torch.ops.aten.mul.Tensor(convert_element_type_1068, mul_256); convert_element_type_1068 = mul_256 = None + sum_2 = torch.ops.aten.sum.dim_IntList(mul_263, [0, 1]); mul_263 = None + convert_element_type_1071 = torch.ops.prims.convert_element_type.default(mul_262, torch.bfloat16); mul_262 = None + convert_element_type_default_65 = torch.ops.prims.convert_element_type.default(sum_2, torch.float32); sum_2 = None + reduce_scatter_tensor_1 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_65, 'avg', 256, '0'); convert_element_type_default_65 = None + wait_tensor_292 = 
torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_1); reduce_scatter_tensor_1 = None + view_1095 = torch.ops.aten.view.default(convert_element_type_1071, [16384, 4096]) + permute_357 = torch.ops.aten.permute.default(view_1095, [1, 0]) + convert_element_type_1043 = torch.ops.prims.convert_element_type.default(primals_288, torch.bfloat16); primals_288 = None + all_gather_into_tensor_285 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1043, 256, '0'); convert_element_type_1043 = None + wait_tensor_285 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_285); all_gather_into_tensor_285 = None + convert_element_type_1044 = torch.ops.prims.convert_element_type.default(add_125, torch.float32); add_125 = None + pow_64 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_1044, 2) + mean_63 = torch.ops.aten.mean.dim(pow_64, [2], True); pow_64 = None + add_126 = torch.ops.aten.add.Scalar(mean_63, 1e-05); mean_63 = None + rsqrt_63 = torch.ops.aten.rsqrt.default(add_126); add_126 = None + mul_252 = torch.ops.aten.mul.Tensor(convert_element_type_1044, rsqrt_63); convert_element_type_1044 = None + mul_253 = torch.ops.aten.mul.Tensor(mul_252, wait_tensor_285) + convert_element_type_1045 = torch.ops.prims.convert_element_type.default(mul_253, torch.bfloat16); mul_253 = None + view_1081 = torch.ops.aten.view.default(convert_element_type_1045, [16384, 4096]); convert_element_type_1045 = None + view_1082 = torch.ops.aten.view.default(mm_221, [2, 8192, 14336]); mm_221 = None + convert_element_type_1049 = torch.ops.prims.convert_element_type.default(view_1082, torch.float32); view_1082 = None + sigmoid_31 = torch.ops.aten.sigmoid.default(convert_element_type_1049) + mul_254 = torch.ops.aten.mul.Tensor(convert_element_type_1049, sigmoid_31); sigmoid_31 = None + convert_element_type_1050 = torch.ops.prims.convert_element_type.default(mul_254, torch.bfloat16); mul_254 = None + convert_element_type_1051 = 
torch.ops.prims.convert_element_type.default(primals_290, torch.bfloat16); primals_290 = None + all_gather_into_tensor_287 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1051, 256, '0'); convert_element_type_1051 = None + wait_tensor_287 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_287); all_gather_into_tensor_287 = None + permute_350 = torch.ops.aten.permute.default(wait_tensor_287, [1, 0]); wait_tensor_287 = None + mm_222 = torch.ops.aten.mm.default(view_1081, permute_350) + view_1085 = torch.ops.aten.view.default(mm_222, [2, 8192, 14336]); mm_222 = None + mul_255 = torch.ops.aten.mul.Tensor(convert_element_type_1050, view_1085) + view_1087 = torch.ops.aten.view.default(mul_255, [16384, 14336]); mul_255 = None + mm_227 = torch.ops.aten.mm.default(permute_357, view_1087); permute_357 = view_1087 = None + convert_element_type_1054 = torch.ops.prims.convert_element_type.default(primals_291, torch.bfloat16); primals_291 = None + all_gather_into_tensor_288 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1054, 256, '0'); convert_element_type_1054 = None + wait_tensor_288 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_288); all_gather_into_tensor_288 = None + permute_351 = torch.ops.aten.permute.default(wait_tensor_288, [1, 0]); wait_tensor_288 = None + permute_359 = torch.ops.aten.permute.default(permute_351, [1, 0]); permute_351 = None + mm_228 = torch.ops.aten.mm.default(view_1095, permute_359); view_1095 = permute_359 = None + view_1096 = torch.ops.aten.view.default(mm_228, [2, 8192, 14336]); mm_228 = None + convert_element_type_1078 = torch.ops.prims.convert_element_type.default(mm_227, torch.float32); mm_227 = None + reduce_scatter_tensor_2 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1078, 'avg', 256, '0'); convert_element_type_1078 = None + wait_tensor_293 = 
torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_2); reduce_scatter_tensor_2 = None + mul_264 = torch.ops.aten.mul.Tensor(view_1096, convert_element_type_1050); convert_element_type_1050 = None + mul_265 = torch.ops.aten.mul.Tensor(view_1096, view_1085); view_1096 = view_1085 = None + view_1097 = torch.ops.aten.view.default(mul_264, [16384, 14336]); mul_264 = None + permute_361 = torch.ops.aten.permute.default(view_1097, [1, 0]) + mm_229 = torch.ops.aten.mm.default(permute_361, view_1081); permute_361 = None + permute_363 = torch.ops.aten.permute.default(permute_350, [1, 0]); permute_350 = None + mm_230 = torch.ops.aten.mm.default(view_1097, permute_363); view_1097 = permute_363 = None + view_1098 = torch.ops.aten.view.default(mm_230, [2, 8192, 4096]); mm_230 = None + convert_element_type_1083 = torch.ops.prims.convert_element_type.default(mm_229, torch.float32); mm_229 = None + reduce_scatter_tensor_3 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1083, 'avg', 256, '0'); convert_element_type_1083 = None + wait_tensor_294 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_3); reduce_scatter_tensor_3 = None + convert_element_type_1084 = torch.ops.prims.convert_element_type.default(mul_265, torch.float32); mul_265 = None + neg = torch.ops.aten.neg.default(convert_element_type_1049) + exp = torch.ops.aten.exp.default(neg); neg = None + add_129 = torch.ops.aten.add.Tensor(exp, 1); exp = None + reciprocal = torch.ops.aten.reciprocal.default(add_129); add_129 = None + mul_266 = torch.ops.aten.mul.Tensor(reciprocal, 1); reciprocal = None + mul_267 = torch.ops.aten.mul.Tensor(convert_element_type_1084, mul_266); convert_element_type_1084 = None + sub_1 = torch.ops.aten.sub.Tensor(1, mul_266); mul_266 = None + mul_268 = torch.ops.aten.mul.Tensor(convert_element_type_1049, sub_1); convert_element_type_1049 = sub_1 = None + add_130 = torch.ops.aten.add.Tensor(mul_268, 1); mul_268 = None + mul_269 = 
torch.ops.aten.mul.Tensor(mul_267, add_130); mul_267 = add_130 = None + convert_element_type_1086 = torch.ops.prims.convert_element_type.default(mul_269, torch.bfloat16); mul_269 = None + view_1099 = torch.ops.aten.view.default(convert_element_type_1086, [16384, 14336]); convert_element_type_1086 = None + permute_365 = torch.ops.aten.permute.default(view_1099, [1, 0]) + mm_231 = torch.ops.aten.mm.default(permute_365, view_1081); permute_365 = view_1081 = None + convert_element_type_1046 = torch.ops.prims.convert_element_type.default(primals_289, torch.bfloat16); primals_289 = None + all_gather_into_tensor_286 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1046, 256, '0'); convert_element_type_1046 = None + wait_tensor_286 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_286); all_gather_into_tensor_286 = None + permute_349 = torch.ops.aten.permute.default(wait_tensor_286, [1, 0]); wait_tensor_286 = None + permute_367 = torch.ops.aten.permute.default(permute_349, [1, 0]); permute_349 = None + mm_232 = torch.ops.aten.mm.default(view_1099, permute_367); view_1099 = permute_367 = None + view_1100 = torch.ops.aten.view.default(mm_232, [2, 8192, 4096]); mm_232 = None + add_131 = torch.ops.aten.add.Tensor(view_1098, view_1100); view_1098 = view_1100 = None + convert_element_type_1091 = torch.ops.prims.convert_element_type.default(mm_231, torch.float32); mm_231 = None + reduce_scatter_tensor_4 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1091, 'avg', 256, '0'); convert_element_type_1091 = None + wait_tensor_295 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_4); reduce_scatter_tensor_4 = None + convert_element_type_1092 = torch.ops.prims.convert_element_type.default(add_131, torch.float32); add_131 = None + convert_element_type_1094 = torch.ops.prims.convert_element_type.default(wait_tensor_285, torch.float32); wait_tensor_285 = None + mul_270 = 
torch.ops.aten.mul.Tensor(convert_element_type_1092, convert_element_type_1094); convert_element_type_1094 = None + mul_272 = torch.ops.aten.mul.Tensor(mul_252, mul_270) + sum_3 = torch.ops.aten.sum.dim_IntList(mul_272, [2], True); mul_272 = None + div_1 = torch.ops.aten.div.Tensor(mul_252, 4096) + mul_273 = torch.ops.aten.mul.Tensor(div_1, sum_3); div_1 = sum_3 = None + sub_2 = torch.ops.aten.sub.Tensor(mul_270, mul_273); mul_270 = mul_273 = None + mul_274 = torch.ops.aten.mul.Tensor(sub_2, rsqrt_63); sub_2 = rsqrt_63 = None + mul_275 = torch.ops.aten.mul.Tensor(convert_element_type_1092, mul_252); convert_element_type_1092 = mul_252 = None + sum_4 = torch.ops.aten.sum.dim_IntList(mul_275, [0, 1]); mul_275 = None + convert_element_type_1095 = torch.ops.prims.convert_element_type.default(mul_274, torch.bfloat16); mul_274 = None + add_132 = torch.ops.aten.add.Tensor(convert_element_type_1071, convert_element_type_1095); convert_element_type_1071 = convert_element_type_1095 = None + convert_element_type_default_64 = torch.ops.prims.convert_element_type.default(sum_4, torch.float32); sum_4 = None + reduce_scatter_tensor_5 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_64, 'avg', 256, '0'); convert_element_type_default_64 = None + wait_tensor_296 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_5); reduce_scatter_tensor_5 = None + view_1101 = torch.ops.aten.view.default(add_132, [16384, 4096]) + permute_369 = torch.ops.aten.permute.default(view_1101, [1, 0]) + mm_233 = torch.ops.aten.mm.default(permute_369, view_1077); permute_369 = view_1077 = None + permute_371 = torch.ops.aten.permute.default(permute_348, [1, 0]); permute_348 = None + mm_234 = torch.ops.aten.mm.default(view_1101, permute_371); view_1101 = permute_371 = None + view_1102 = torch.ops.aten.view.default(mm_234, [2, 8192, 4096]); mm_234 = None + convert_element_type_1102 = torch.ops.prims.convert_element_type.default(mm_233, torch.float32); 
mm_233 = None + reduce_scatter_tensor_6 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1102, 'avg', 256, '0'); convert_element_type_1102 = None + wait_tensor_297 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_6); reduce_scatter_tensor_6 = None + view_1103 = torch.ops.aten.view.default(view_1102, [2, 8192, 32, 128]); view_1102 = None + permute_373 = torch.ops.aten.permute.default(view_1103, [0, 2, 1, 3]); view_1103 = None + view_16 = torch.ops.aten.view.default(primals_3, [1, 8192, 1, 64]); primals_3 = None + convert_element_type_1024 = torch.ops.prims.convert_element_type.default(primals_283, torch.bfloat16); primals_283 = None + all_gather_into_tensor_280 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1024, 256, '0'); convert_element_type_1024 = None + wait_tensor_280 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_280); all_gather_into_tensor_280 = None + convert_element_type_1025 = torch.ops.prims.convert_element_type.default(add_123, torch.float32); add_123 = None + pow_63 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_1025, 2) + mean_62 = torch.ops.aten.mean.dim(pow_63, [2], True); pow_63 = None + add_124 = torch.ops.aten.add.Scalar(mean_62, 1e-05); mean_62 = None + rsqrt_62 = torch.ops.aten.rsqrt.default(add_124); add_124 = None + mul_248 = torch.ops.aten.mul.Tensor(convert_element_type_1025, rsqrt_62); convert_element_type_1025 = None + mul_249 = torch.ops.aten.mul.Tensor(mul_248, wait_tensor_280) + convert_element_type_1026 = torch.ops.prims.convert_element_type.default(mul_249, torch.bfloat16); mul_249 = None + view_1057 = torch.ops.aten.view.default(convert_element_type_1026, [16384, 4096]); convert_element_type_1026 = None + view_1058 = torch.ops.aten.view.default(mm_217, [2, 8192, 4096]); mm_217 = None + convert_element_type_1030 = torch.ops.prims.convert_element_type.default(primals_285, torch.bfloat16); primals_285 = None + 
all_gather_into_tensor_282 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1030, 256, '0'); convert_element_type_1030 = None + wait_tensor_282 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_282); all_gather_into_tensor_282 = None + permute_342 = torch.ops.aten.permute.default(wait_tensor_282, [1, 0]); wait_tensor_282 = None + mm_218 = torch.ops.aten.mm.default(view_1057, permute_342) + view_1061 = torch.ops.aten.view.default(mm_218, [2, 8192, 1024]); mm_218 = None + view_1064 = torch.ops.aten.view.default(mm_219, [2, 8192, 1024]); mm_219 = None + view_1065 = torch.ops.aten.view.default(view_1058, [2, 8192, -1, 128]); view_1058 = None + view_1066 = torch.ops.aten.view.default(view_1061, [2, 8192, -1, 128]); view_1061 = None + view_1067 = torch.ops.aten.view.default(view_1064, [2, 8192, -1, 128]); view_1064 = None + convert_element_type_1036 = torch.ops.prims.convert_element_type.default(view_1065, torch.float32); view_1065 = None + view_1068 = torch.ops.aten.view.default(convert_element_type_1036, [2, 8192, 32, -1, 2]); convert_element_type_1036 = None + view_as_complex_62 = torch.ops.aten.view_as_complex.default(view_1068); view_1068 = None + convert_element_type_1037 = torch.ops.prims.convert_element_type.default(view_1066, torch.float32); view_1066 = None + view_1069 = torch.ops.aten.view.default(convert_element_type_1037, [2, 8192, 8, -1, 2]); convert_element_type_1037 = None + view_as_complex_63 = torch.ops.aten.view_as_complex.default(view_1069); view_1069 = None + mul_250 = torch.ops.aten.mul.Tensor(view_as_complex_62, view_16); view_as_complex_62 = None + view_as_real_62 = torch.ops.aten.view_as_real.default(mul_250); mul_250 = None + view_1071 = torch.ops.aten.view.default(view_as_real_62, [2, 8192, 32, 128]); view_as_real_62 = None + mul_251 = torch.ops.aten.mul.Tensor(view_as_complex_63, view_16); view_as_complex_63 = None + view_as_real_63 = torch.ops.aten.view_as_real.default(mul_251); mul_251 = 
None + view_1072 = torch.ops.aten.view.default(view_as_real_63, [2, 8192, 8, 128]); view_as_real_63 = None + convert_element_type_1038 = torch.ops.prims.convert_element_type.default(view_1071, torch.bfloat16); view_1071 = None + convert_element_type_1039 = torch.ops.prims.convert_element_type.default(view_1072, torch.bfloat16); view_1072 = None + unsqueeze_62 = torch.ops.aten.unsqueeze.default(convert_element_type_1039, 3); convert_element_type_1039 = None + expand_62 = torch.ops.aten.expand.default(unsqueeze_62, [2, 8192, 8, 4, 128]); unsqueeze_62 = None + clone_62 = torch.ops.aten.clone.default(expand_62, memory_format = torch.contiguous_format); expand_62 = None + view_1073 = torch.ops.aten.view.default(clone_62, [2, 8192, 32, 128]); clone_62 = None + unsqueeze_63 = torch.ops.aten.unsqueeze.default(view_1067, 3); view_1067 = None + expand_63 = torch.ops.aten.expand.default(unsqueeze_63, [2, 8192, 8, 4, 128]); unsqueeze_63 = None + clone_63 = torch.ops.aten.clone.default(expand_63, memory_format = torch.contiguous_format); expand_63 = None + view_1074 = torch.ops.aten.view.default(clone_63, [2, 8192, 32, 128]); clone_63 = None + permute_344 = torch.ops.aten.permute.default(convert_element_type_1038, [0, 2, 1, 3]); convert_element_type_1038 = None + permute_345 = torch.ops.aten.permute.default(view_1073, [0, 2, 1, 3]); view_1073 = None + permute_346 = torch.ops.aten.permute.default(view_1074, [0, 2, 1, 3]); view_1074 = None + _scaled_dot_product_cudnn_attention_backward = torch.ops.aten._scaled_dot_product_cudnn_attention_backward.default(permute_373, permute_344, permute_345, permute_346, getitem_279, getitem_280, getitem_285, getitem_286, None, None, None, 8192, 8192, 0.0, True); permute_373 = permute_344 = permute_345 = permute_346 = getitem_279 = getitem_280 = getitem_285 = getitem_286 = None + getitem_288 = _scaled_dot_product_cudnn_attention_backward[0] + getitem_289 = _scaled_dot_product_cudnn_attention_backward[1] + getitem_290 = 
_scaled_dot_product_cudnn_attention_backward[2]; _scaled_dot_product_cudnn_attention_backward = None + permute_374 = torch.ops.aten.permute.default(getitem_290, [0, 2, 1, 3]); getitem_290 = None + permute_375 = torch.ops.aten.permute.default(getitem_289, [0, 2, 1, 3]); getitem_289 = None + permute_376 = torch.ops.aten.permute.default(getitem_288, [0, 2, 1, 3]); getitem_288 = None + view_1104 = torch.ops.aten.view.default(permute_374, [2, 8192, 8, 4, 128]); permute_374 = None + sum_5 = torch.ops.aten.sum.dim_IntList(view_1104, [3], True); view_1104 = None + squeeze = torch.ops.aten.squeeze.dim(sum_5, 3); sum_5 = None + view_1105 = torch.ops.aten.view.default(permute_375, [2, 8192, 8, 4, 128]); permute_375 = None + sum_6 = torch.ops.aten.sum.dim_IntList(view_1105, [3], True); view_1105 = None + squeeze_1 = torch.ops.aten.squeeze.dim(sum_6, 3); sum_6 = None + convert_element_type_1103 = torch.ops.prims.convert_element_type.default(squeeze_1, torch.float32); squeeze_1 = None + convert_element_type_1104 = torch.ops.prims.convert_element_type.default(permute_376, torch.float32); permute_376 = None + view_1106 = torch.ops.aten.view.default(convert_element_type_1103, [2, 8192, 8, 64, 2]); convert_element_type_1103 = None + view_as_complex_64 = torch.ops.aten.view_as_complex.default(view_1106); view_1106 = None + _conj = torch.ops.aten._conj.default(view_16) + mul_276 = torch.ops.aten.mul.Tensor(view_as_complex_64, _conj); view_as_complex_64 = None + view_1107 = torch.ops.aten.view.default(convert_element_type_1104, [2, 8192, 32, 64, 2]); convert_element_type_1104 = None + view_as_complex_65 = torch.ops.aten.view_as_complex.default(view_1107); view_1107 = None + mul_277 = torch.ops.aten.mul.Tensor(view_as_complex_65, _conj); view_as_complex_65 = None + view_as_real_64 = torch.ops.aten.view_as_real.default(mul_276); mul_276 = None + view_1108 = torch.ops.aten.view.default(view_as_real_64, [2, 8192, 8, 128]); view_as_real_64 = None + convert_element_type_1105 = 
torch.ops.prims.convert_element_type.default(view_1108, torch.bfloat16); view_1108 = None + view_as_real_65 = torch.ops.aten.view_as_real.default(mul_277); mul_277 = None + view_1109 = torch.ops.aten.view.default(view_as_real_65, [2, 8192, 32, 128]); view_as_real_65 = None + convert_element_type_1106 = torch.ops.prims.convert_element_type.default(view_1109, torch.bfloat16); view_1109 = None + view_1110 = torch.ops.aten.view.default(squeeze, [2, 8192, 1024]); squeeze = None + view_1111 = torch.ops.aten.view.default(convert_element_type_1105, [2, 8192, 1024]); convert_element_type_1105 = None + view_1112 = torch.ops.aten.view.default(convert_element_type_1106, [2, 8192, 4096]); convert_element_type_1106 = None + view_1113 = torch.ops.aten.view.default(view_1110, [16384, 1024]); view_1110 = None + permute_377 = torch.ops.aten.permute.default(view_1113, [1, 0]) + mm_235 = torch.ops.aten.mm.default(permute_377, view_1057); permute_377 = None + convert_element_type_1033 = torch.ops.prims.convert_element_type.default(primals_286, torch.bfloat16); primals_286 = None + all_gather_into_tensor_283 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1033, 256, '0'); convert_element_type_1033 = None + wait_tensor_283 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_283); all_gather_into_tensor_283 = None + permute_343 = torch.ops.aten.permute.default(wait_tensor_283, [1, 0]); wait_tensor_283 = None + permute_379 = torch.ops.aten.permute.default(permute_343, [1, 0]); permute_343 = None + mm_236 = torch.ops.aten.mm.default(view_1113, permute_379); view_1113 = permute_379 = None + view_1114 = torch.ops.aten.view.default(mm_236, [2, 8192, 4096]); mm_236 = None + convert_element_type_1111 = torch.ops.prims.convert_element_type.default(mm_235, torch.float32); mm_235 = None + reduce_scatter_tensor_7 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1111, 'avg', 256, '0'); convert_element_type_1111 = None 
+ wait_tensor_298 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_7); reduce_scatter_tensor_7 = None + view_1115 = torch.ops.aten.view.default(view_1111, [16384, 1024]); view_1111 = None + permute_381 = torch.ops.aten.permute.default(view_1115, [1, 0]) + mm_237 = torch.ops.aten.mm.default(permute_381, view_1057); permute_381 = None + permute_383 = torch.ops.aten.permute.default(permute_342, [1, 0]); permute_342 = None + mm_238 = torch.ops.aten.mm.default(view_1115, permute_383); view_1115 = permute_383 = None + view_1116 = torch.ops.aten.view.default(mm_238, [2, 8192, 4096]); mm_238 = None + add_133 = torch.ops.aten.add.Tensor(view_1114, view_1116); view_1114 = view_1116 = None + convert_element_type_1116 = torch.ops.prims.convert_element_type.default(mm_237, torch.float32); mm_237 = None + reduce_scatter_tensor_8 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1116, 'avg', 256, '0'); convert_element_type_1116 = None + wait_tensor_299 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_8); reduce_scatter_tensor_8 = None + view_1117 = torch.ops.aten.view.default(view_1112, [16384, 4096]); view_1112 = None + permute_385 = torch.ops.aten.permute.default(view_1117, [1, 0]) + mm_239 = torch.ops.aten.mm.default(permute_385, view_1057); permute_385 = view_1057 = None + convert_element_type_1027 = torch.ops.prims.convert_element_type.default(primals_284, torch.bfloat16); primals_284 = None + all_gather_into_tensor_281 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1027, 256, '0'); convert_element_type_1027 = None + wait_tensor_281 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_281); all_gather_into_tensor_281 = None + permute_341 = torch.ops.aten.permute.default(wait_tensor_281, [1, 0]); wait_tensor_281 = None + permute_387 = torch.ops.aten.permute.default(permute_341, [1, 0]); permute_341 = None + mm_240 = 
torch.ops.aten.mm.default(view_1117, permute_387); view_1117 = permute_387 = None + view_1118 = torch.ops.aten.view.default(mm_240, [2, 8192, 4096]); mm_240 = None + add_134 = torch.ops.aten.add.Tensor(add_133, view_1118); add_133 = view_1118 = None + convert_element_type_1121 = torch.ops.prims.convert_element_type.default(mm_239, torch.float32); mm_239 = None + reduce_scatter_tensor_9 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1121, 'avg', 256, '0'); convert_element_type_1121 = None + wait_tensor_300 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_9); reduce_scatter_tensor_9 = None + convert_element_type_1122 = torch.ops.prims.convert_element_type.default(add_134, torch.float32); add_134 = None + convert_element_type_1124 = torch.ops.prims.convert_element_type.default(wait_tensor_280, torch.float32); wait_tensor_280 = None + mul_278 = torch.ops.aten.mul.Tensor(convert_element_type_1122, convert_element_type_1124); convert_element_type_1124 = None + mul_280 = torch.ops.aten.mul.Tensor(mul_248, mul_278) + sum_7 = torch.ops.aten.sum.dim_IntList(mul_280, [2], True); mul_280 = None + div_2 = torch.ops.aten.div.Tensor(mul_248, 4096) + mul_281 = torch.ops.aten.mul.Tensor(div_2, sum_7); div_2 = sum_7 = None + sub_3 = torch.ops.aten.sub.Tensor(mul_278, mul_281); mul_278 = mul_281 = None + mul_282 = torch.ops.aten.mul.Tensor(sub_3, rsqrt_62); sub_3 = rsqrt_62 = None + mul_283 = torch.ops.aten.mul.Tensor(convert_element_type_1122, mul_248); convert_element_type_1122 = mul_248 = None + sum_8 = torch.ops.aten.sum.dim_IntList(mul_283, [0, 1]); mul_283 = None + convert_element_type_1125 = torch.ops.prims.convert_element_type.default(mul_282, torch.bfloat16); mul_282 = None + add_135 = torch.ops.aten.add.Tensor(add_132, convert_element_type_1125); add_132 = convert_element_type_1125 = None + convert_element_type_default_63 = torch.ops.prims.convert_element_type.default(sum_8, torch.float32); sum_8 = None + 
reduce_scatter_tensor_10 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_63, 'avg', 256, '0'); convert_element_type_default_63 = None + wait_tensor_301 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_10); reduce_scatter_tensor_10 = None + view_1119 = torch.ops.aten.view.default(add_135, [16384, 4096]) + permute_389 = torch.ops.aten.permute.default(view_1119, [1, 0]) + permute_336 = torch.ops.aten.permute.default(getitem_270, [0, 2, 1, 3]) + view_1041 = torch.ops.aten.view.default(permute_336, [2, 8192, -1]); permute_336 = None + convert_element_type_1007 = torch.ops.prims.convert_element_type.default(primals_278, torch.bfloat16); primals_278 = None + all_gather_into_tensor_275 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1007, 256, '0'); convert_element_type_1007 = None + wait_tensor_275 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_275); all_gather_into_tensor_275 = None + permute_337 = torch.ops.aten.permute.default(wait_tensor_275, [1, 0]); wait_tensor_275 = None + view_1043 = torch.ops.aten.view.default(view_1041, [16384, 4096]); view_1041 = None + mm_213 = torch.ops.aten.mm.default(view_1043, permute_337) + view_1044 = torch.ops.aten.view.default(mm_213, [2, 8192, 4096]); mm_213 = None + add_121 = torch.ops.aten.add.Tensor(add_119, view_1044); view_1044 = None + convert_element_type_1010 = torch.ops.prims.convert_element_type.default(primals_279, torch.bfloat16); primals_279 = None + all_gather_into_tensor_276 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1010, 256, '0'); convert_element_type_1010 = None + wait_tensor_276 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_276); all_gather_into_tensor_276 = None + convert_element_type_1011 = torch.ops.prims.convert_element_type.default(add_121, torch.float32); add_121 = None + pow_62 = 
torch.ops.aten.pow.Tensor_Scalar(convert_element_type_1011, 2) + mean_61 = torch.ops.aten.mean.dim(pow_62, [2], True); pow_62 = None + add_122 = torch.ops.aten.add.Scalar(mean_61, 1e-05); mean_61 = None + rsqrt_61 = torch.ops.aten.rsqrt.default(add_122); add_122 = None + mul_244 = torch.ops.aten.mul.Tensor(convert_element_type_1011, rsqrt_61); convert_element_type_1011 = None + mul_245 = torch.ops.aten.mul.Tensor(mul_244, wait_tensor_276) + convert_element_type_1012 = torch.ops.prims.convert_element_type.default(mul_245, torch.bfloat16); mul_245 = None + view_1047 = torch.ops.aten.view.default(convert_element_type_1012, [16384, 4096]); convert_element_type_1012 = None + view_1048 = torch.ops.aten.view.default(mm_214, [2, 8192, 14336]); mm_214 = None + convert_element_type_1016 = torch.ops.prims.convert_element_type.default(view_1048, torch.float32); view_1048 = None + sigmoid_30 = torch.ops.aten.sigmoid.default(convert_element_type_1016) + mul_246 = torch.ops.aten.mul.Tensor(convert_element_type_1016, sigmoid_30); sigmoid_30 = None + convert_element_type_1017 = torch.ops.prims.convert_element_type.default(mul_246, torch.bfloat16); mul_246 = None + convert_element_type_1018 = torch.ops.prims.convert_element_type.default(primals_281, torch.bfloat16); primals_281 = None + all_gather_into_tensor_278 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1018, 256, '0'); convert_element_type_1018 = None + wait_tensor_278 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_278); all_gather_into_tensor_278 = None + permute_339 = torch.ops.aten.permute.default(wait_tensor_278, [1, 0]); wait_tensor_278 = None + mm_215 = torch.ops.aten.mm.default(view_1047, permute_339) + view_1051 = torch.ops.aten.view.default(mm_215, [2, 8192, 14336]); mm_215 = None + mul_247 = torch.ops.aten.mul.Tensor(convert_element_type_1017, view_1051) + view_1053 = torch.ops.aten.view.default(mul_247, [16384, 14336]); mul_247 = None + mm_241 = 
torch.ops.aten.mm.default(permute_389, view_1053); permute_389 = view_1053 = None + convert_element_type_1021 = torch.ops.prims.convert_element_type.default(primals_282, torch.bfloat16); primals_282 = None + all_gather_into_tensor_279 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1021, 256, '0'); convert_element_type_1021 = None + wait_tensor_279 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_279); all_gather_into_tensor_279 = None + permute_340 = torch.ops.aten.permute.default(wait_tensor_279, [1, 0]); wait_tensor_279 = None + permute_391 = torch.ops.aten.permute.default(permute_340, [1, 0]); permute_340 = None + mm_242 = torch.ops.aten.mm.default(view_1119, permute_391); view_1119 = permute_391 = None + view_1120 = torch.ops.aten.view.default(mm_242, [2, 8192, 14336]); mm_242 = None + convert_element_type_1132 = torch.ops.prims.convert_element_type.default(mm_241, torch.float32); mm_241 = None + reduce_scatter_tensor_11 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1132, 'avg', 256, '0'); convert_element_type_1132 = None + wait_tensor_302 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_11); reduce_scatter_tensor_11 = None + mul_284 = torch.ops.aten.mul.Tensor(view_1120, convert_element_type_1017); convert_element_type_1017 = None + mul_285 = torch.ops.aten.mul.Tensor(view_1120, view_1051); view_1120 = view_1051 = None + view_1121 = torch.ops.aten.view.default(mul_284, [16384, 14336]); mul_284 = None + permute_393 = torch.ops.aten.permute.default(view_1121, [1, 0]) + mm_243 = torch.ops.aten.mm.default(permute_393, view_1047); permute_393 = None + permute_395 = torch.ops.aten.permute.default(permute_339, [1, 0]); permute_339 = None + mm_244 = torch.ops.aten.mm.default(view_1121, permute_395); view_1121 = permute_395 = None + view_1122 = torch.ops.aten.view.default(mm_244, [2, 8192, 4096]); mm_244 = None + convert_element_type_1137 = 
torch.ops.prims.convert_element_type.default(mm_243, torch.float32); mm_243 = None + reduce_scatter_tensor_12 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1137, 'avg', 256, '0'); convert_element_type_1137 = None + wait_tensor_303 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_12); reduce_scatter_tensor_12 = None + convert_element_type_1138 = torch.ops.prims.convert_element_type.default(mul_285, torch.float32); mul_285 = None + neg_1 = torch.ops.aten.neg.default(convert_element_type_1016) + exp_1 = torch.ops.aten.exp.default(neg_1); neg_1 = None + add_136 = torch.ops.aten.add.Tensor(exp_1, 1); exp_1 = None + reciprocal_1 = torch.ops.aten.reciprocal.default(add_136); add_136 = None + mul_286 = torch.ops.aten.mul.Tensor(reciprocal_1, 1); reciprocal_1 = None + mul_287 = torch.ops.aten.mul.Tensor(convert_element_type_1138, mul_286); convert_element_type_1138 = None + sub_4 = torch.ops.aten.sub.Tensor(1, mul_286); mul_286 = None + mul_288 = torch.ops.aten.mul.Tensor(convert_element_type_1016, sub_4); convert_element_type_1016 = sub_4 = None + add_137 = torch.ops.aten.add.Tensor(mul_288, 1); mul_288 = None + mul_289 = torch.ops.aten.mul.Tensor(mul_287, add_137); mul_287 = add_137 = None + convert_element_type_1140 = torch.ops.prims.convert_element_type.default(mul_289, torch.bfloat16); mul_289 = None + view_1123 = torch.ops.aten.view.default(convert_element_type_1140, [16384, 14336]); convert_element_type_1140 = None + permute_397 = torch.ops.aten.permute.default(view_1123, [1, 0]) + mm_245 = torch.ops.aten.mm.default(permute_397, view_1047); permute_397 = view_1047 = None + convert_element_type_1013 = torch.ops.prims.convert_element_type.default(primals_280, torch.bfloat16); primals_280 = None + all_gather_into_tensor_277 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1013, 256, '0'); convert_element_type_1013 = None + wait_tensor_277 = 
torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_277); all_gather_into_tensor_277 = None + permute_338 = torch.ops.aten.permute.default(wait_tensor_277, [1, 0]); wait_tensor_277 = None + permute_399 = torch.ops.aten.permute.default(permute_338, [1, 0]); permute_338 = None + mm_246 = torch.ops.aten.mm.default(view_1123, permute_399); view_1123 = permute_399 = None + view_1124 = torch.ops.aten.view.default(mm_246, [2, 8192, 4096]); mm_246 = None + add_138 = torch.ops.aten.add.Tensor(view_1122, view_1124); view_1122 = view_1124 = None + convert_element_type_1145 = torch.ops.prims.convert_element_type.default(mm_245, torch.float32); mm_245 = None + reduce_scatter_tensor_13 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1145, 'avg', 256, '0'); convert_element_type_1145 = None + wait_tensor_304 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_13); reduce_scatter_tensor_13 = None + convert_element_type_1146 = torch.ops.prims.convert_element_type.default(add_138, torch.float32); add_138 = None + convert_element_type_1148 = torch.ops.prims.convert_element_type.default(wait_tensor_276, torch.float32); wait_tensor_276 = None + mul_290 = torch.ops.aten.mul.Tensor(convert_element_type_1146, convert_element_type_1148); convert_element_type_1148 = None + mul_292 = torch.ops.aten.mul.Tensor(mul_244, mul_290) + sum_9 = torch.ops.aten.sum.dim_IntList(mul_292, [2], True); mul_292 = None + div_3 = torch.ops.aten.div.Tensor(mul_244, 4096) + mul_293 = torch.ops.aten.mul.Tensor(div_3, sum_9); div_3 = sum_9 = None + sub_5 = torch.ops.aten.sub.Tensor(mul_290, mul_293); mul_290 = mul_293 = None + mul_294 = torch.ops.aten.mul.Tensor(sub_5, rsqrt_61); sub_5 = rsqrt_61 = None + mul_295 = torch.ops.aten.mul.Tensor(convert_element_type_1146, mul_244); convert_element_type_1146 = mul_244 = None + sum_10 = torch.ops.aten.sum.dim_IntList(mul_295, [0, 1]); mul_295 = None + convert_element_type_1149 = 
torch.ops.prims.convert_element_type.default(mul_294, torch.bfloat16); mul_294 = None + add_139 = torch.ops.aten.add.Tensor(add_135, convert_element_type_1149); add_135 = convert_element_type_1149 = None + convert_element_type_default_62 = torch.ops.prims.convert_element_type.default(sum_10, torch.float32); sum_10 = None + reduce_scatter_tensor_14 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_62, 'avg', 256, '0'); convert_element_type_default_62 = None + wait_tensor_305 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_14); reduce_scatter_tensor_14 = None + view_1125 = torch.ops.aten.view.default(add_139, [16384, 4096]) + permute_401 = torch.ops.aten.permute.default(view_1125, [1, 0]) + mm_247 = torch.ops.aten.mm.default(permute_401, view_1043); permute_401 = view_1043 = None + permute_403 = torch.ops.aten.permute.default(permute_337, [1, 0]); permute_337 = None + mm_248 = torch.ops.aten.mm.default(view_1125, permute_403); view_1125 = permute_403 = None + view_1126 = torch.ops.aten.view.default(mm_248, [2, 8192, 4096]); mm_248 = None + convert_element_type_1156 = torch.ops.prims.convert_element_type.default(mm_247, torch.float32); mm_247 = None + reduce_scatter_tensor_15 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1156, 'avg', 256, '0'); convert_element_type_1156 = None + wait_tensor_306 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_15); reduce_scatter_tensor_15 = None + view_1127 = torch.ops.aten.view.default(view_1126, [2, 8192, 32, 128]); view_1126 = None + permute_405 = torch.ops.aten.permute.default(view_1127, [0, 2, 1, 3]); view_1127 = None + convert_element_type_991 = torch.ops.prims.convert_element_type.default(primals_274, torch.bfloat16); primals_274 = None + all_gather_into_tensor_271 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_991, 256, '0'); convert_element_type_991 = None + wait_tensor_271 = 
torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_271); all_gather_into_tensor_271 = None + convert_element_type_992 = torch.ops.prims.convert_element_type.default(add_119, torch.float32); add_119 = None + pow_61 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_992, 2) + mean_60 = torch.ops.aten.mean.dim(pow_61, [2], True); pow_61 = None + add_120 = torch.ops.aten.add.Scalar(mean_60, 1e-05); mean_60 = None + rsqrt_60 = torch.ops.aten.rsqrt.default(add_120); add_120 = None + mul_240 = torch.ops.aten.mul.Tensor(convert_element_type_992, rsqrt_60); convert_element_type_992 = None + mul_241 = torch.ops.aten.mul.Tensor(mul_240, wait_tensor_271) + convert_element_type_993 = torch.ops.prims.convert_element_type.default(mul_241, torch.bfloat16); mul_241 = None + view_1023 = torch.ops.aten.view.default(convert_element_type_993, [16384, 4096]); convert_element_type_993 = None + view_1024 = torch.ops.aten.view.default(mm_210, [2, 8192, 4096]); mm_210 = None + convert_element_type_997 = torch.ops.prims.convert_element_type.default(primals_276, torch.bfloat16); primals_276 = None + all_gather_into_tensor_273 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_997, 256, '0'); convert_element_type_997 = None + wait_tensor_273 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_273); all_gather_into_tensor_273 = None + permute_331 = torch.ops.aten.permute.default(wait_tensor_273, [1, 0]); wait_tensor_273 = None + mm_211 = torch.ops.aten.mm.default(view_1023, permute_331) + view_1027 = torch.ops.aten.view.default(mm_211, [2, 8192, 1024]); mm_211 = None + view_1030 = torch.ops.aten.view.default(mm_212, [2, 8192, 1024]); mm_212 = None + view_1031 = torch.ops.aten.view.default(view_1024, [2, 8192, -1, 128]); view_1024 = None + view_1032 = torch.ops.aten.view.default(view_1027, [2, 8192, -1, 128]); view_1027 = None + view_1033 = torch.ops.aten.view.default(view_1030, [2, 8192, -1, 128]); view_1030 = None + 
convert_element_type_1003 = torch.ops.prims.convert_element_type.default(view_1031, torch.float32); view_1031 = None + view_1034 = torch.ops.aten.view.default(convert_element_type_1003, [2, 8192, 32, -1, 2]); convert_element_type_1003 = None + view_as_complex_60 = torch.ops.aten.view_as_complex.default(view_1034); view_1034 = None + convert_element_type_1004 = torch.ops.prims.convert_element_type.default(view_1032, torch.float32); view_1032 = None + view_1035 = torch.ops.aten.view.default(convert_element_type_1004, [2, 8192, 8, -1, 2]); convert_element_type_1004 = None + view_as_complex_61 = torch.ops.aten.view_as_complex.default(view_1035); view_1035 = None + mul_242 = torch.ops.aten.mul.Tensor(view_as_complex_60, view_16); view_as_complex_60 = None + view_as_real_60 = torch.ops.aten.view_as_real.default(mul_242); mul_242 = None + view_1037 = torch.ops.aten.view.default(view_as_real_60, [2, 8192, 32, 128]); view_as_real_60 = None + mul_243 = torch.ops.aten.mul.Tensor(view_as_complex_61, view_16); view_as_complex_61 = None + view_as_real_61 = torch.ops.aten.view_as_real.default(mul_243); mul_243 = None + view_1038 = torch.ops.aten.view.default(view_as_real_61, [2, 8192, 8, 128]); view_as_real_61 = None + convert_element_type_1005 = torch.ops.prims.convert_element_type.default(view_1037, torch.bfloat16); view_1037 = None + convert_element_type_1006 = torch.ops.prims.convert_element_type.default(view_1038, torch.bfloat16); view_1038 = None + unsqueeze_60 = torch.ops.aten.unsqueeze.default(convert_element_type_1006, 3); convert_element_type_1006 = None + expand_60 = torch.ops.aten.expand.default(unsqueeze_60, [2, 8192, 8, 4, 128]); unsqueeze_60 = None + clone_60 = torch.ops.aten.clone.default(expand_60, memory_format = torch.contiguous_format); expand_60 = None + view_1039 = torch.ops.aten.view.default(clone_60, [2, 8192, 32, 128]); clone_60 = None + unsqueeze_61 = torch.ops.aten.unsqueeze.default(view_1033, 3); view_1033 = None + expand_61 = 
torch.ops.aten.expand.default(unsqueeze_61, [2, 8192, 8, 4, 128]); unsqueeze_61 = None + clone_61 = torch.ops.aten.clone.default(expand_61, memory_format = torch.contiguous_format); expand_61 = None + view_1040 = torch.ops.aten.view.default(clone_61, [2, 8192, 32, 128]); clone_61 = None + permute_333 = torch.ops.aten.permute.default(convert_element_type_1005, [0, 2, 1, 3]); convert_element_type_1005 = None + permute_334 = torch.ops.aten.permute.default(view_1039, [0, 2, 1, 3]); view_1039 = None + permute_335 = torch.ops.aten.permute.default(view_1040, [0, 2, 1, 3]); view_1040 = None + _scaled_dot_product_cudnn_attention_backward_1 = torch.ops.aten._scaled_dot_product_cudnn_attention_backward.default(permute_405, permute_333, permute_334, permute_335, getitem_270, getitem_271, getitem_276, getitem_277, None, None, None, 8192, 8192, 0.0, True); permute_405 = permute_333 = permute_334 = permute_335 = getitem_270 = getitem_271 = getitem_276 = getitem_277 = None + getitem_291 = _scaled_dot_product_cudnn_attention_backward_1[0] + getitem_292 = _scaled_dot_product_cudnn_attention_backward_1[1] + getitem_293 = _scaled_dot_product_cudnn_attention_backward_1[2]; _scaled_dot_product_cudnn_attention_backward_1 = None + permute_406 = torch.ops.aten.permute.default(getitem_293, [0, 2, 1, 3]); getitem_293 = None + permute_407 = torch.ops.aten.permute.default(getitem_292, [0, 2, 1, 3]); getitem_292 = None + permute_408 = torch.ops.aten.permute.default(getitem_291, [0, 2, 1, 3]); getitem_291 = None + view_1128 = torch.ops.aten.view.default(permute_406, [2, 8192, 8, 4, 128]); permute_406 = None + sum_11 = torch.ops.aten.sum.dim_IntList(view_1128, [3], True); view_1128 = None + squeeze_2 = torch.ops.aten.squeeze.dim(sum_11, 3); sum_11 = None + view_1129 = torch.ops.aten.view.default(permute_407, [2, 8192, 8, 4, 128]); permute_407 = None + sum_12 = torch.ops.aten.sum.dim_IntList(view_1129, [3], True); view_1129 = None + squeeze_3 = torch.ops.aten.squeeze.dim(sum_12, 3); sum_12 = None 
+ convert_element_type_1157 = torch.ops.prims.convert_element_type.default(squeeze_3, torch.float32); squeeze_3 = None + convert_element_type_1158 = torch.ops.prims.convert_element_type.default(permute_408, torch.float32); permute_408 = None + view_1130 = torch.ops.aten.view.default(convert_element_type_1157, [2, 8192, 8, 64, 2]); convert_element_type_1157 = None + view_as_complex_66 = torch.ops.aten.view_as_complex.default(view_1130); view_1130 = None + mul_296 = torch.ops.aten.mul.Tensor(view_as_complex_66, _conj); view_as_complex_66 = None + view_1131 = torch.ops.aten.view.default(convert_element_type_1158, [2, 8192, 32, 64, 2]); convert_element_type_1158 = None + view_as_complex_67 = torch.ops.aten.view_as_complex.default(view_1131); view_1131 = None + mul_297 = torch.ops.aten.mul.Tensor(view_as_complex_67, _conj); view_as_complex_67 = None + view_as_real_66 = torch.ops.aten.view_as_real.default(mul_296); mul_296 = None + view_1132 = torch.ops.aten.view.default(view_as_real_66, [2, 8192, 8, 128]); view_as_real_66 = None + convert_element_type_1159 = torch.ops.prims.convert_element_type.default(view_1132, torch.bfloat16); view_1132 = None + view_as_real_67 = torch.ops.aten.view_as_real.default(mul_297); mul_297 = None + view_1133 = torch.ops.aten.view.default(view_as_real_67, [2, 8192, 32, 128]); view_as_real_67 = None + convert_element_type_1160 = torch.ops.prims.convert_element_type.default(view_1133, torch.bfloat16); view_1133 = None + view_1134 = torch.ops.aten.view.default(squeeze_2, [2, 8192, 1024]); squeeze_2 = None + view_1135 = torch.ops.aten.view.default(convert_element_type_1159, [2, 8192, 1024]); convert_element_type_1159 = None + view_1136 = torch.ops.aten.view.default(convert_element_type_1160, [2, 8192, 4096]); convert_element_type_1160 = None + view_1137 = torch.ops.aten.view.default(view_1134, [16384, 1024]); view_1134 = None + permute_409 = torch.ops.aten.permute.default(view_1137, [1, 0]) + mm_249 = torch.ops.aten.mm.default(permute_409, 
view_1023); permute_409 = None + convert_element_type_1000 = torch.ops.prims.convert_element_type.default(primals_277, torch.bfloat16); primals_277 = None + all_gather_into_tensor_274 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1000, 256, '0'); convert_element_type_1000 = None + wait_tensor_274 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_274); all_gather_into_tensor_274 = None + permute_332 = torch.ops.aten.permute.default(wait_tensor_274, [1, 0]); wait_tensor_274 = None + permute_411 = torch.ops.aten.permute.default(permute_332, [1, 0]); permute_332 = None + mm_250 = torch.ops.aten.mm.default(view_1137, permute_411); view_1137 = permute_411 = None + view_1138 = torch.ops.aten.view.default(mm_250, [2, 8192, 4096]); mm_250 = None + convert_element_type_1165 = torch.ops.prims.convert_element_type.default(mm_249, torch.float32); mm_249 = None + reduce_scatter_tensor_16 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1165, 'avg', 256, '0'); convert_element_type_1165 = None + wait_tensor_307 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_16); reduce_scatter_tensor_16 = None + view_1139 = torch.ops.aten.view.default(view_1135, [16384, 1024]); view_1135 = None + permute_413 = torch.ops.aten.permute.default(view_1139, [1, 0]) + mm_251 = torch.ops.aten.mm.default(permute_413, view_1023); permute_413 = None + permute_415 = torch.ops.aten.permute.default(permute_331, [1, 0]); permute_331 = None + mm_252 = torch.ops.aten.mm.default(view_1139, permute_415); view_1139 = permute_415 = None + view_1140 = torch.ops.aten.view.default(mm_252, [2, 8192, 4096]); mm_252 = None + add_140 = torch.ops.aten.add.Tensor(view_1138, view_1140); view_1138 = view_1140 = None + convert_element_type_1170 = torch.ops.prims.convert_element_type.default(mm_251, torch.float32); mm_251 = None + reduce_scatter_tensor_17 = 
torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1170, 'avg', 256, '0'); convert_element_type_1170 = None + wait_tensor_308 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_17); reduce_scatter_tensor_17 = None + view_1141 = torch.ops.aten.view.default(view_1136, [16384, 4096]); view_1136 = None + permute_417 = torch.ops.aten.permute.default(view_1141, [1, 0]) + mm_253 = torch.ops.aten.mm.default(permute_417, view_1023); permute_417 = view_1023 = None + convert_element_type_994 = torch.ops.prims.convert_element_type.default(primals_275, torch.bfloat16); primals_275 = None + all_gather_into_tensor_272 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_994, 256, '0'); convert_element_type_994 = None + wait_tensor_272 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_272); all_gather_into_tensor_272 = None + permute_330 = torch.ops.aten.permute.default(wait_tensor_272, [1, 0]); wait_tensor_272 = None + permute_419 = torch.ops.aten.permute.default(permute_330, [1, 0]); permute_330 = None + mm_254 = torch.ops.aten.mm.default(view_1141, permute_419); view_1141 = permute_419 = None + view_1142 = torch.ops.aten.view.default(mm_254, [2, 8192, 4096]); mm_254 = None + add_141 = torch.ops.aten.add.Tensor(add_140, view_1142); add_140 = view_1142 = None + convert_element_type_1175 = torch.ops.prims.convert_element_type.default(mm_253, torch.float32); mm_253 = None + reduce_scatter_tensor_18 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1175, 'avg', 256, '0'); convert_element_type_1175 = None + wait_tensor_309 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_18); reduce_scatter_tensor_18 = None + convert_element_type_1176 = torch.ops.prims.convert_element_type.default(add_141, torch.float32); add_141 = None + convert_element_type_1178 = torch.ops.prims.convert_element_type.default(wait_tensor_271, torch.float32); 
wait_tensor_271 = None + mul_298 = torch.ops.aten.mul.Tensor(convert_element_type_1176, convert_element_type_1178); convert_element_type_1178 = None + mul_300 = torch.ops.aten.mul.Tensor(mul_240, mul_298) + sum_13 = torch.ops.aten.sum.dim_IntList(mul_300, [2], True); mul_300 = None + div_4 = torch.ops.aten.div.Tensor(mul_240, 4096) + mul_301 = torch.ops.aten.mul.Tensor(div_4, sum_13); div_4 = sum_13 = None + sub_6 = torch.ops.aten.sub.Tensor(mul_298, mul_301); mul_298 = mul_301 = None + mul_302 = torch.ops.aten.mul.Tensor(sub_6, rsqrt_60); sub_6 = rsqrt_60 = None + mul_303 = torch.ops.aten.mul.Tensor(convert_element_type_1176, mul_240); convert_element_type_1176 = mul_240 = None + sum_14 = torch.ops.aten.sum.dim_IntList(mul_303, [0, 1]); mul_303 = None + convert_element_type_1179 = torch.ops.prims.convert_element_type.default(mul_302, torch.bfloat16); mul_302 = None + add_142 = torch.ops.aten.add.Tensor(add_139, convert_element_type_1179); add_139 = convert_element_type_1179 = None + convert_element_type_default_61 = torch.ops.prims.convert_element_type.default(sum_14, torch.float32); sum_14 = None + reduce_scatter_tensor_19 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_61, 'avg', 256, '0'); convert_element_type_default_61 = None + wait_tensor_310 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_19); reduce_scatter_tensor_19 = None + view_1143 = torch.ops.aten.view.default(add_142, [16384, 4096]) + permute_421 = torch.ops.aten.permute.default(view_1143, [1, 0]) + permute_325 = torch.ops.aten.permute.default(getitem_261, [0, 2, 1, 3]) + view_1007 = torch.ops.aten.view.default(permute_325, [2, 8192, -1]); permute_325 = None + convert_element_type_974 = torch.ops.prims.convert_element_type.default(primals_269, torch.bfloat16); primals_269 = None + all_gather_into_tensor_266 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_974, 256, '0'); convert_element_type_974 = None + 
wait_tensor_266 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_266); all_gather_into_tensor_266 = None + permute_326 = torch.ops.aten.permute.default(wait_tensor_266, [1, 0]); wait_tensor_266 = None + view_1009 = torch.ops.aten.view.default(view_1007, [16384, 4096]); view_1007 = None + mm_206 = torch.ops.aten.mm.default(view_1009, permute_326) + view_1010 = torch.ops.aten.view.default(mm_206, [2, 8192, 4096]); mm_206 = None + add_117 = torch.ops.aten.add.Tensor(add_115, view_1010); view_1010 = None + convert_element_type_977 = torch.ops.prims.convert_element_type.default(primals_270, torch.bfloat16); primals_270 = None + all_gather_into_tensor_267 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_977, 256, '0'); convert_element_type_977 = None + wait_tensor_267 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_267); all_gather_into_tensor_267 = None + convert_element_type_978 = torch.ops.prims.convert_element_type.default(add_117, torch.float32); add_117 = None + pow_60 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_978, 2) + mean_59 = torch.ops.aten.mean.dim(pow_60, [2], True); pow_60 = None + add_118 = torch.ops.aten.add.Scalar(mean_59, 1e-05); mean_59 = None + rsqrt_59 = torch.ops.aten.rsqrt.default(add_118); add_118 = None + mul_236 = torch.ops.aten.mul.Tensor(convert_element_type_978, rsqrt_59); convert_element_type_978 = None + mul_237 = torch.ops.aten.mul.Tensor(mul_236, wait_tensor_267) + convert_element_type_979 = torch.ops.prims.convert_element_type.default(mul_237, torch.bfloat16); mul_237 = None + view_1013 = torch.ops.aten.view.default(convert_element_type_979, [16384, 4096]); convert_element_type_979 = None + view_1014 = torch.ops.aten.view.default(mm_207, [2, 8192, 14336]); mm_207 = None + convert_element_type_983 = torch.ops.prims.convert_element_type.default(view_1014, torch.float32); view_1014 = None + sigmoid_29 = 
torch.ops.aten.sigmoid.default(convert_element_type_983) + mul_238 = torch.ops.aten.mul.Tensor(convert_element_type_983, sigmoid_29); sigmoid_29 = None + convert_element_type_984 = torch.ops.prims.convert_element_type.default(mul_238, torch.bfloat16); mul_238 = None + convert_element_type_985 = torch.ops.prims.convert_element_type.default(primals_272, torch.bfloat16); primals_272 = None + all_gather_into_tensor_269 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_985, 256, '0'); convert_element_type_985 = None + wait_tensor_269 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_269); all_gather_into_tensor_269 = None + permute_328 = torch.ops.aten.permute.default(wait_tensor_269, [1, 0]); wait_tensor_269 = None + mm_208 = torch.ops.aten.mm.default(view_1013, permute_328) + view_1017 = torch.ops.aten.view.default(mm_208, [2, 8192, 14336]); mm_208 = None + mul_239 = torch.ops.aten.mul.Tensor(convert_element_type_984, view_1017) + view_1019 = torch.ops.aten.view.default(mul_239, [16384, 14336]); mul_239 = None + mm_255 = torch.ops.aten.mm.default(permute_421, view_1019); permute_421 = view_1019 = None + convert_element_type_988 = torch.ops.prims.convert_element_type.default(primals_273, torch.bfloat16); primals_273 = None + all_gather_into_tensor_270 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_988, 256, '0'); convert_element_type_988 = None + wait_tensor_270 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_270); all_gather_into_tensor_270 = None + permute_329 = torch.ops.aten.permute.default(wait_tensor_270, [1, 0]); wait_tensor_270 = None + permute_423 = torch.ops.aten.permute.default(permute_329, [1, 0]); permute_329 = None + mm_256 = torch.ops.aten.mm.default(view_1143, permute_423); view_1143 = permute_423 = None + view_1144 = torch.ops.aten.view.default(mm_256, [2, 8192, 14336]); mm_256 = None + convert_element_type_1186 = 
torch.ops.prims.convert_element_type.default(mm_255, torch.float32); mm_255 = None + reduce_scatter_tensor_20 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1186, 'avg', 256, '0'); convert_element_type_1186 = None + wait_tensor_311 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_20); reduce_scatter_tensor_20 = None + mul_304 = torch.ops.aten.mul.Tensor(view_1144, convert_element_type_984); convert_element_type_984 = None + mul_305 = torch.ops.aten.mul.Tensor(view_1144, view_1017); view_1144 = view_1017 = None + view_1145 = torch.ops.aten.view.default(mul_304, [16384, 14336]); mul_304 = None + permute_425 = torch.ops.aten.permute.default(view_1145, [1, 0]) + mm_257 = torch.ops.aten.mm.default(permute_425, view_1013); permute_425 = None + permute_427 = torch.ops.aten.permute.default(permute_328, [1, 0]); permute_328 = None + mm_258 = torch.ops.aten.mm.default(view_1145, permute_427); view_1145 = permute_427 = None + view_1146 = torch.ops.aten.view.default(mm_258, [2, 8192, 4096]); mm_258 = None + convert_element_type_1191 = torch.ops.prims.convert_element_type.default(mm_257, torch.float32); mm_257 = None + reduce_scatter_tensor_21 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1191, 'avg', 256, '0'); convert_element_type_1191 = None + wait_tensor_312 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_21); reduce_scatter_tensor_21 = None + convert_element_type_1192 = torch.ops.prims.convert_element_type.default(mul_305, torch.float32); mul_305 = None + neg_2 = torch.ops.aten.neg.default(convert_element_type_983) + exp_2 = torch.ops.aten.exp.default(neg_2); neg_2 = None + add_143 = torch.ops.aten.add.Tensor(exp_2, 1); exp_2 = None + reciprocal_2 = torch.ops.aten.reciprocal.default(add_143); add_143 = None + mul_306 = torch.ops.aten.mul.Tensor(reciprocal_2, 1); reciprocal_2 = None + mul_307 = torch.ops.aten.mul.Tensor(convert_element_type_1192, mul_306); 
convert_element_type_1192 = None + sub_7 = torch.ops.aten.sub.Tensor(1, mul_306); mul_306 = None + mul_308 = torch.ops.aten.mul.Tensor(convert_element_type_983, sub_7); convert_element_type_983 = sub_7 = None + add_144 = torch.ops.aten.add.Tensor(mul_308, 1); mul_308 = None + mul_309 = torch.ops.aten.mul.Tensor(mul_307, add_144); mul_307 = add_144 = None + convert_element_type_1194 = torch.ops.prims.convert_element_type.default(mul_309, torch.bfloat16); mul_309 = None + view_1147 = torch.ops.aten.view.default(convert_element_type_1194, [16384, 14336]); convert_element_type_1194 = None + permute_429 = torch.ops.aten.permute.default(view_1147, [1, 0]) + mm_259 = torch.ops.aten.mm.default(permute_429, view_1013); permute_429 = view_1013 = None + convert_element_type_980 = torch.ops.prims.convert_element_type.default(primals_271, torch.bfloat16); primals_271 = None + all_gather_into_tensor_268 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_980, 256, '0'); convert_element_type_980 = None + wait_tensor_268 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_268); all_gather_into_tensor_268 = None + permute_327 = torch.ops.aten.permute.default(wait_tensor_268, [1, 0]); wait_tensor_268 = None + permute_431 = torch.ops.aten.permute.default(permute_327, [1, 0]); permute_327 = None + mm_260 = torch.ops.aten.mm.default(view_1147, permute_431); view_1147 = permute_431 = None + view_1148 = torch.ops.aten.view.default(mm_260, [2, 8192, 4096]); mm_260 = None + add_145 = torch.ops.aten.add.Tensor(view_1146, view_1148); view_1146 = view_1148 = None + convert_element_type_1199 = torch.ops.prims.convert_element_type.default(mm_259, torch.float32); mm_259 = None + reduce_scatter_tensor_22 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1199, 'avg', 256, '0'); convert_element_type_1199 = None + wait_tensor_313 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_22); 
reduce_scatter_tensor_22 = None + convert_element_type_1200 = torch.ops.prims.convert_element_type.default(add_145, torch.float32); add_145 = None + convert_element_type_1202 = torch.ops.prims.convert_element_type.default(wait_tensor_267, torch.float32); wait_tensor_267 = None + mul_310 = torch.ops.aten.mul.Tensor(convert_element_type_1200, convert_element_type_1202); convert_element_type_1202 = None + mul_312 = torch.ops.aten.mul.Tensor(mul_236, mul_310) + sum_15 = torch.ops.aten.sum.dim_IntList(mul_312, [2], True); mul_312 = None + div_5 = torch.ops.aten.div.Tensor(mul_236, 4096) + mul_313 = torch.ops.aten.mul.Tensor(div_5, sum_15); div_5 = sum_15 = None + sub_8 = torch.ops.aten.sub.Tensor(mul_310, mul_313); mul_310 = mul_313 = None + mul_314 = torch.ops.aten.mul.Tensor(sub_8, rsqrt_59); sub_8 = rsqrt_59 = None + mul_315 = torch.ops.aten.mul.Tensor(convert_element_type_1200, mul_236); convert_element_type_1200 = mul_236 = None + sum_16 = torch.ops.aten.sum.dim_IntList(mul_315, [0, 1]); mul_315 = None + convert_element_type_1203 = torch.ops.prims.convert_element_type.default(mul_314, torch.bfloat16); mul_314 = None + add_146 = torch.ops.aten.add.Tensor(add_142, convert_element_type_1203); add_142 = convert_element_type_1203 = None + convert_element_type_default_60 = torch.ops.prims.convert_element_type.default(sum_16, torch.float32); sum_16 = None + reduce_scatter_tensor_23 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_60, 'avg', 256, '0'); convert_element_type_default_60 = None + wait_tensor_314 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_23); reduce_scatter_tensor_23 = None + view_1149 = torch.ops.aten.view.default(add_146, [16384, 4096]) + permute_433 = torch.ops.aten.permute.default(view_1149, [1, 0]) + mm_261 = torch.ops.aten.mm.default(permute_433, view_1009); permute_433 = view_1009 = None + permute_435 = torch.ops.aten.permute.default(permute_326, [1, 0]); permute_326 = None + mm_262 = 
torch.ops.aten.mm.default(view_1149, permute_435); view_1149 = permute_435 = None + view_1150 = torch.ops.aten.view.default(mm_262, [2, 8192, 4096]); mm_262 = None + convert_element_type_1210 = torch.ops.prims.convert_element_type.default(mm_261, torch.float32); mm_261 = None + reduce_scatter_tensor_24 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1210, 'avg', 256, '0'); convert_element_type_1210 = None + wait_tensor_315 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_24); reduce_scatter_tensor_24 = None + view_1151 = torch.ops.aten.view.default(view_1150, [2, 8192, 32, 128]); view_1150 = None + permute_437 = torch.ops.aten.permute.default(view_1151, [0, 2, 1, 3]); view_1151 = None + convert_element_type_958 = torch.ops.prims.convert_element_type.default(primals_265, torch.bfloat16); primals_265 = None + all_gather_into_tensor_262 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_958, 256, '0'); convert_element_type_958 = None + wait_tensor_262 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_262); all_gather_into_tensor_262 = None + convert_element_type_959 = torch.ops.prims.convert_element_type.default(add_115, torch.float32); add_115 = None + pow_59 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_959, 2) + mean_58 = torch.ops.aten.mean.dim(pow_59, [2], True); pow_59 = None + add_116 = torch.ops.aten.add.Scalar(mean_58, 1e-05); mean_58 = None + rsqrt_58 = torch.ops.aten.rsqrt.default(add_116); add_116 = None + mul_232 = torch.ops.aten.mul.Tensor(convert_element_type_959, rsqrt_58); convert_element_type_959 = None + mul_233 = torch.ops.aten.mul.Tensor(mul_232, wait_tensor_262) + convert_element_type_960 = torch.ops.prims.convert_element_type.default(mul_233, torch.bfloat16); mul_233 = None + view_989 = torch.ops.aten.view.default(convert_element_type_960, [16384, 4096]); convert_element_type_960 = None + view_990 = 
torch.ops.aten.view.default(mm_203, [2, 8192, 4096]); mm_203 = None + convert_element_type_964 = torch.ops.prims.convert_element_type.default(primals_267, torch.bfloat16); primals_267 = None + all_gather_into_tensor_264 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_964, 256, '0'); convert_element_type_964 = None + wait_tensor_264 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_264); all_gather_into_tensor_264 = None + permute_320 = torch.ops.aten.permute.default(wait_tensor_264, [1, 0]); wait_tensor_264 = None + mm_204 = torch.ops.aten.mm.default(view_989, permute_320) + view_993 = torch.ops.aten.view.default(mm_204, [2, 8192, 1024]); mm_204 = None + view_996 = torch.ops.aten.view.default(mm_205, [2, 8192, 1024]); mm_205 = None + view_997 = torch.ops.aten.view.default(view_990, [2, 8192, -1, 128]); view_990 = None + view_998 = torch.ops.aten.view.default(view_993, [2, 8192, -1, 128]); view_993 = None + view_999 = torch.ops.aten.view.default(view_996, [2, 8192, -1, 128]); view_996 = None + convert_element_type_970 = torch.ops.prims.convert_element_type.default(view_997, torch.float32); view_997 = None + view_1000 = torch.ops.aten.view.default(convert_element_type_970, [2, 8192, 32, -1, 2]); convert_element_type_970 = None + view_as_complex_58 = torch.ops.aten.view_as_complex.default(view_1000); view_1000 = None + convert_element_type_971 = torch.ops.prims.convert_element_type.default(view_998, torch.float32); view_998 = None + view_1001 = torch.ops.aten.view.default(convert_element_type_971, [2, 8192, 8, -1, 2]); convert_element_type_971 = None + view_as_complex_59 = torch.ops.aten.view_as_complex.default(view_1001); view_1001 = None + mul_234 = torch.ops.aten.mul.Tensor(view_as_complex_58, view_16); view_as_complex_58 = None + view_as_real_58 = torch.ops.aten.view_as_real.default(mul_234); mul_234 = None + view_1003 = torch.ops.aten.view.default(view_as_real_58, [2, 8192, 32, 128]); view_as_real_58 = None + 
mul_235 = torch.ops.aten.mul.Tensor(view_as_complex_59, view_16); view_as_complex_59 = None + view_as_real_59 = torch.ops.aten.view_as_real.default(mul_235); mul_235 = None + view_1004 = torch.ops.aten.view.default(view_as_real_59, [2, 8192, 8, 128]); view_as_real_59 = None + convert_element_type_972 = torch.ops.prims.convert_element_type.default(view_1003, torch.bfloat16); view_1003 = None + convert_element_type_973 = torch.ops.prims.convert_element_type.default(view_1004, torch.bfloat16); view_1004 = None + unsqueeze_58 = torch.ops.aten.unsqueeze.default(convert_element_type_973, 3); convert_element_type_973 = None + expand_58 = torch.ops.aten.expand.default(unsqueeze_58, [2, 8192, 8, 4, 128]); unsqueeze_58 = None + clone_58 = torch.ops.aten.clone.default(expand_58, memory_format = torch.contiguous_format); expand_58 = None + view_1005 = torch.ops.aten.view.default(clone_58, [2, 8192, 32, 128]); clone_58 = None + unsqueeze_59 = torch.ops.aten.unsqueeze.default(view_999, 3); view_999 = None + expand_59 = torch.ops.aten.expand.default(unsqueeze_59, [2, 8192, 8, 4, 128]); unsqueeze_59 = None + clone_59 = torch.ops.aten.clone.default(expand_59, memory_format = torch.contiguous_format); expand_59 = None + view_1006 = torch.ops.aten.view.default(clone_59, [2, 8192, 32, 128]); clone_59 = None + permute_322 = torch.ops.aten.permute.default(convert_element_type_972, [0, 2, 1, 3]); convert_element_type_972 = None + permute_323 = torch.ops.aten.permute.default(view_1005, [0, 2, 1, 3]); view_1005 = None + permute_324 = torch.ops.aten.permute.default(view_1006, [0, 2, 1, 3]); view_1006 = None + _scaled_dot_product_cudnn_attention_backward_2 = torch.ops.aten._scaled_dot_product_cudnn_attention_backward.default(permute_437, permute_322, permute_323, permute_324, getitem_261, getitem_262, getitem_267, getitem_268, None, None, None, 8192, 8192, 0.0, True); permute_437 = permute_322 = permute_323 = permute_324 = getitem_261 = getitem_262 = getitem_267 = getitem_268 = None + 
getitem_294 = _scaled_dot_product_cudnn_attention_backward_2[0] + getitem_295 = _scaled_dot_product_cudnn_attention_backward_2[1] + getitem_296 = _scaled_dot_product_cudnn_attention_backward_2[2]; _scaled_dot_product_cudnn_attention_backward_2 = None + permute_438 = torch.ops.aten.permute.default(getitem_296, [0, 2, 1, 3]); getitem_296 = None + permute_439 = torch.ops.aten.permute.default(getitem_295, [0, 2, 1, 3]); getitem_295 = None + permute_440 = torch.ops.aten.permute.default(getitem_294, [0, 2, 1, 3]); getitem_294 = None + view_1152 = torch.ops.aten.view.default(permute_438, [2, 8192, 8, 4, 128]); permute_438 = None + sum_17 = torch.ops.aten.sum.dim_IntList(view_1152, [3], True); view_1152 = None + squeeze_4 = torch.ops.aten.squeeze.dim(sum_17, 3); sum_17 = None + view_1153 = torch.ops.aten.view.default(permute_439, [2, 8192, 8, 4, 128]); permute_439 = None + sum_18 = torch.ops.aten.sum.dim_IntList(view_1153, [3], True); view_1153 = None + squeeze_5 = torch.ops.aten.squeeze.dim(sum_18, 3); sum_18 = None + convert_element_type_1211 = torch.ops.prims.convert_element_type.default(squeeze_5, torch.float32); squeeze_5 = None + convert_element_type_1212 = torch.ops.prims.convert_element_type.default(permute_440, torch.float32); permute_440 = None + view_1154 = torch.ops.aten.view.default(convert_element_type_1211, [2, 8192, 8, 64, 2]); convert_element_type_1211 = None + view_as_complex_68 = torch.ops.aten.view_as_complex.default(view_1154); view_1154 = None + mul_316 = torch.ops.aten.mul.Tensor(view_as_complex_68, _conj); view_as_complex_68 = None + view_1155 = torch.ops.aten.view.default(convert_element_type_1212, [2, 8192, 32, 64, 2]); convert_element_type_1212 = None + view_as_complex_69 = torch.ops.aten.view_as_complex.default(view_1155); view_1155 = None + mul_317 = torch.ops.aten.mul.Tensor(view_as_complex_69, _conj); view_as_complex_69 = None + view_as_real_68 = torch.ops.aten.view_as_real.default(mul_316); mul_316 = None + view_1156 = 
torch.ops.aten.view.default(view_as_real_68, [2, 8192, 8, 128]); view_as_real_68 = None + convert_element_type_1213 = torch.ops.prims.convert_element_type.default(view_1156, torch.bfloat16); view_1156 = None + view_as_real_69 = torch.ops.aten.view_as_real.default(mul_317); mul_317 = None + view_1157 = torch.ops.aten.view.default(view_as_real_69, [2, 8192, 32, 128]); view_as_real_69 = None + convert_element_type_1214 = torch.ops.prims.convert_element_type.default(view_1157, torch.bfloat16); view_1157 = None + view_1158 = torch.ops.aten.view.default(squeeze_4, [2, 8192, 1024]); squeeze_4 = None + view_1159 = torch.ops.aten.view.default(convert_element_type_1213, [2, 8192, 1024]); convert_element_type_1213 = None + view_1160 = torch.ops.aten.view.default(convert_element_type_1214, [2, 8192, 4096]); convert_element_type_1214 = None + view_1161 = torch.ops.aten.view.default(view_1158, [16384, 1024]); view_1158 = None + permute_441 = torch.ops.aten.permute.default(view_1161, [1, 0]) + mm_263 = torch.ops.aten.mm.default(permute_441, view_989); permute_441 = None + convert_element_type_967 = torch.ops.prims.convert_element_type.default(primals_268, torch.bfloat16); primals_268 = None + all_gather_into_tensor_265 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_967, 256, '0'); convert_element_type_967 = None + wait_tensor_265 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_265); all_gather_into_tensor_265 = None + permute_321 = torch.ops.aten.permute.default(wait_tensor_265, [1, 0]); wait_tensor_265 = None + permute_443 = torch.ops.aten.permute.default(permute_321, [1, 0]); permute_321 = None + mm_264 = torch.ops.aten.mm.default(view_1161, permute_443); view_1161 = permute_443 = None + view_1162 = torch.ops.aten.view.default(mm_264, [2, 8192, 4096]); mm_264 = None + convert_element_type_1219 = torch.ops.prims.convert_element_type.default(mm_263, torch.float32); mm_263 = None + reduce_scatter_tensor_25 = 
torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1219, 'avg', 256, '0'); convert_element_type_1219 = None + wait_tensor_316 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_25); reduce_scatter_tensor_25 = None + view_1163 = torch.ops.aten.view.default(view_1159, [16384, 1024]); view_1159 = None + permute_445 = torch.ops.aten.permute.default(view_1163, [1, 0]) + mm_265 = torch.ops.aten.mm.default(permute_445, view_989); permute_445 = None + permute_447 = torch.ops.aten.permute.default(permute_320, [1, 0]); permute_320 = None + mm_266 = torch.ops.aten.mm.default(view_1163, permute_447); view_1163 = permute_447 = None + view_1164 = torch.ops.aten.view.default(mm_266, [2, 8192, 4096]); mm_266 = None + add_147 = torch.ops.aten.add.Tensor(view_1162, view_1164); view_1162 = view_1164 = None + convert_element_type_1224 = torch.ops.prims.convert_element_type.default(mm_265, torch.float32); mm_265 = None + reduce_scatter_tensor_26 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1224, 'avg', 256, '0'); convert_element_type_1224 = None + wait_tensor_317 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_26); reduce_scatter_tensor_26 = None + view_1165 = torch.ops.aten.view.default(view_1160, [16384, 4096]); view_1160 = None + permute_449 = torch.ops.aten.permute.default(view_1165, [1, 0]) + mm_267 = torch.ops.aten.mm.default(permute_449, view_989); permute_449 = view_989 = None + convert_element_type_961 = torch.ops.prims.convert_element_type.default(primals_266, torch.bfloat16); primals_266 = None + all_gather_into_tensor_263 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_961, 256, '0'); convert_element_type_961 = None + wait_tensor_263 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_263); all_gather_into_tensor_263 = None + permute_319 = torch.ops.aten.permute.default(wait_tensor_263, [1, 0]); wait_tensor_263 = None 
+ permute_451 = torch.ops.aten.permute.default(permute_319, [1, 0]); permute_319 = None + mm_268 = torch.ops.aten.mm.default(view_1165, permute_451); view_1165 = permute_451 = None + view_1166 = torch.ops.aten.view.default(mm_268, [2, 8192, 4096]); mm_268 = None + add_148 = torch.ops.aten.add.Tensor(add_147, view_1166); add_147 = view_1166 = None + convert_element_type_1229 = torch.ops.prims.convert_element_type.default(mm_267, torch.float32); mm_267 = None + reduce_scatter_tensor_27 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1229, 'avg', 256, '0'); convert_element_type_1229 = None + wait_tensor_318 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_27); reduce_scatter_tensor_27 = None + convert_element_type_1230 = torch.ops.prims.convert_element_type.default(add_148, torch.float32); add_148 = None + convert_element_type_1232 = torch.ops.prims.convert_element_type.default(wait_tensor_262, torch.float32); wait_tensor_262 = None + mul_318 = torch.ops.aten.mul.Tensor(convert_element_type_1230, convert_element_type_1232); convert_element_type_1232 = None + mul_320 = torch.ops.aten.mul.Tensor(mul_232, mul_318) + sum_19 = torch.ops.aten.sum.dim_IntList(mul_320, [2], True); mul_320 = None + div_6 = torch.ops.aten.div.Tensor(mul_232, 4096) + mul_321 = torch.ops.aten.mul.Tensor(div_6, sum_19); div_6 = sum_19 = None + sub_9 = torch.ops.aten.sub.Tensor(mul_318, mul_321); mul_318 = mul_321 = None + mul_322 = torch.ops.aten.mul.Tensor(sub_9, rsqrt_58); sub_9 = rsqrt_58 = None + mul_323 = torch.ops.aten.mul.Tensor(convert_element_type_1230, mul_232); convert_element_type_1230 = mul_232 = None + sum_20 = torch.ops.aten.sum.dim_IntList(mul_323, [0, 1]); mul_323 = None + convert_element_type_1233 = torch.ops.prims.convert_element_type.default(mul_322, torch.bfloat16); mul_322 = None + add_149 = torch.ops.aten.add.Tensor(add_146, convert_element_type_1233); add_146 = convert_element_type_1233 = None + 
convert_element_type_default_59 = torch.ops.prims.convert_element_type.default(sum_20, torch.float32); sum_20 = None + reduce_scatter_tensor_28 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_59, 'avg', 256, '0'); convert_element_type_default_59 = None + wait_tensor_319 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_28); reduce_scatter_tensor_28 = None + view_1167 = torch.ops.aten.view.default(add_149, [16384, 4096]) + permute_453 = torch.ops.aten.permute.default(view_1167, [1, 0]) + permute_314 = torch.ops.aten.permute.default(getitem_252, [0, 2, 1, 3]) + view_973 = torch.ops.aten.view.default(permute_314, [2, 8192, -1]); permute_314 = None + convert_element_type_941 = torch.ops.prims.convert_element_type.default(primals_260, torch.bfloat16); primals_260 = None + all_gather_into_tensor_257 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_941, 256, '0'); convert_element_type_941 = None + wait_tensor_257 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_257); all_gather_into_tensor_257 = None + permute_315 = torch.ops.aten.permute.default(wait_tensor_257, [1, 0]); wait_tensor_257 = None + view_975 = torch.ops.aten.view.default(view_973, [16384, 4096]); view_973 = None + mm_199 = torch.ops.aten.mm.default(view_975, permute_315) + view_976 = torch.ops.aten.view.default(mm_199, [2, 8192, 4096]); mm_199 = None + add_113 = torch.ops.aten.add.Tensor(add_111, view_976); view_976 = None + convert_element_type_944 = torch.ops.prims.convert_element_type.default(primals_261, torch.bfloat16); primals_261 = None + all_gather_into_tensor_258 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_944, 256, '0'); convert_element_type_944 = None + wait_tensor_258 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_258); all_gather_into_tensor_258 = None + convert_element_type_945 = 
torch.ops.prims.convert_element_type.default(add_113, torch.float32); add_113 = None + pow_58 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_945, 2) + mean_57 = torch.ops.aten.mean.dim(pow_58, [2], True); pow_58 = None + add_114 = torch.ops.aten.add.Scalar(mean_57, 1e-05); mean_57 = None + rsqrt_57 = torch.ops.aten.rsqrt.default(add_114); add_114 = None + mul_228 = torch.ops.aten.mul.Tensor(convert_element_type_945, rsqrt_57); convert_element_type_945 = None + mul_229 = torch.ops.aten.mul.Tensor(mul_228, wait_tensor_258) + convert_element_type_946 = torch.ops.prims.convert_element_type.default(mul_229, torch.bfloat16); mul_229 = None + view_979 = torch.ops.aten.view.default(convert_element_type_946, [16384, 4096]); convert_element_type_946 = None + view_980 = torch.ops.aten.view.default(mm_200, [2, 8192, 14336]); mm_200 = None + convert_element_type_950 = torch.ops.prims.convert_element_type.default(view_980, torch.float32); view_980 = None + sigmoid_28 = torch.ops.aten.sigmoid.default(convert_element_type_950) + mul_230 = torch.ops.aten.mul.Tensor(convert_element_type_950, sigmoid_28); sigmoid_28 = None + convert_element_type_951 = torch.ops.prims.convert_element_type.default(mul_230, torch.bfloat16); mul_230 = None + convert_element_type_952 = torch.ops.prims.convert_element_type.default(primals_263, torch.bfloat16); primals_263 = None + all_gather_into_tensor_260 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_952, 256, '0'); convert_element_type_952 = None + wait_tensor_260 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_260); all_gather_into_tensor_260 = None + permute_317 = torch.ops.aten.permute.default(wait_tensor_260, [1, 0]); wait_tensor_260 = None + mm_201 = torch.ops.aten.mm.default(view_979, permute_317) + view_983 = torch.ops.aten.view.default(mm_201, [2, 8192, 14336]); mm_201 = None + mul_231 = torch.ops.aten.mul.Tensor(convert_element_type_951, view_983) + view_985 = 
torch.ops.aten.view.default(mul_231, [16384, 14336]); mul_231 = None + mm_269 = torch.ops.aten.mm.default(permute_453, view_985); permute_453 = view_985 = None + convert_element_type_955 = torch.ops.prims.convert_element_type.default(primals_264, torch.bfloat16); primals_264 = None + all_gather_into_tensor_261 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_955, 256, '0'); convert_element_type_955 = None + wait_tensor_261 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_261); all_gather_into_tensor_261 = None + permute_318 = torch.ops.aten.permute.default(wait_tensor_261, [1, 0]); wait_tensor_261 = None + permute_455 = torch.ops.aten.permute.default(permute_318, [1, 0]); permute_318 = None + mm_270 = torch.ops.aten.mm.default(view_1167, permute_455); view_1167 = permute_455 = None + view_1168 = torch.ops.aten.view.default(mm_270, [2, 8192, 14336]); mm_270 = None + convert_element_type_1240 = torch.ops.prims.convert_element_type.default(mm_269, torch.float32); mm_269 = None + reduce_scatter_tensor_29 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1240, 'avg', 256, '0'); convert_element_type_1240 = None + wait_tensor_320 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_29); reduce_scatter_tensor_29 = None + mul_324 = torch.ops.aten.mul.Tensor(view_1168, convert_element_type_951); convert_element_type_951 = None + mul_325 = torch.ops.aten.mul.Tensor(view_1168, view_983); view_1168 = view_983 = None + view_1169 = torch.ops.aten.view.default(mul_324, [16384, 14336]); mul_324 = None + permute_457 = torch.ops.aten.permute.default(view_1169, [1, 0]) + mm_271 = torch.ops.aten.mm.default(permute_457, view_979); permute_457 = None + permute_459 = torch.ops.aten.permute.default(permute_317, [1, 0]); permute_317 = None + mm_272 = torch.ops.aten.mm.default(view_1169, permute_459); view_1169 = permute_459 = None + view_1170 = torch.ops.aten.view.default(mm_272, [2, 8192, 
4096]); mm_272 = None + convert_element_type_1245 = torch.ops.prims.convert_element_type.default(mm_271, torch.float32); mm_271 = None + reduce_scatter_tensor_30 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1245, 'avg', 256, '0'); convert_element_type_1245 = None + wait_tensor_321 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_30); reduce_scatter_tensor_30 = None + convert_element_type_1246 = torch.ops.prims.convert_element_type.default(mul_325, torch.float32); mul_325 = None + neg_3 = torch.ops.aten.neg.default(convert_element_type_950) + exp_3 = torch.ops.aten.exp.default(neg_3); neg_3 = None + add_150 = torch.ops.aten.add.Tensor(exp_3, 1); exp_3 = None + reciprocal_3 = torch.ops.aten.reciprocal.default(add_150); add_150 = None + mul_326 = torch.ops.aten.mul.Tensor(reciprocal_3, 1); reciprocal_3 = None + mul_327 = torch.ops.aten.mul.Tensor(convert_element_type_1246, mul_326); convert_element_type_1246 = None + sub_10 = torch.ops.aten.sub.Tensor(1, mul_326); mul_326 = None + mul_328 = torch.ops.aten.mul.Tensor(convert_element_type_950, sub_10); convert_element_type_950 = sub_10 = None + add_151 = torch.ops.aten.add.Tensor(mul_328, 1); mul_328 = None + mul_329 = torch.ops.aten.mul.Tensor(mul_327, add_151); mul_327 = add_151 = None + convert_element_type_1248 = torch.ops.prims.convert_element_type.default(mul_329, torch.bfloat16); mul_329 = None + view_1171 = torch.ops.aten.view.default(convert_element_type_1248, [16384, 14336]); convert_element_type_1248 = None + permute_461 = torch.ops.aten.permute.default(view_1171, [1, 0]) + mm_273 = torch.ops.aten.mm.default(permute_461, view_979); permute_461 = view_979 = None + convert_element_type_947 = torch.ops.prims.convert_element_type.default(primals_262, torch.bfloat16); primals_262 = None + all_gather_into_tensor_259 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_947, 256, '0'); convert_element_type_947 = None + 
wait_tensor_259 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_259); all_gather_into_tensor_259 = None + permute_316 = torch.ops.aten.permute.default(wait_tensor_259, [1, 0]); wait_tensor_259 = None + permute_463 = torch.ops.aten.permute.default(permute_316, [1, 0]); permute_316 = None + mm_274 = torch.ops.aten.mm.default(view_1171, permute_463); view_1171 = permute_463 = None + view_1172 = torch.ops.aten.view.default(mm_274, [2, 8192, 4096]); mm_274 = None + add_152 = torch.ops.aten.add.Tensor(view_1170, view_1172); view_1170 = view_1172 = None + convert_element_type_1253 = torch.ops.prims.convert_element_type.default(mm_273, torch.float32); mm_273 = None + reduce_scatter_tensor_31 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1253, 'avg', 256, '0'); convert_element_type_1253 = None + wait_tensor_322 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_31); reduce_scatter_tensor_31 = None + convert_element_type_1254 = torch.ops.prims.convert_element_type.default(add_152, torch.float32); add_152 = None + convert_element_type_1256 = torch.ops.prims.convert_element_type.default(wait_tensor_258, torch.float32); wait_tensor_258 = None + mul_330 = torch.ops.aten.mul.Tensor(convert_element_type_1254, convert_element_type_1256); convert_element_type_1256 = None + mul_332 = torch.ops.aten.mul.Tensor(mul_228, mul_330) + sum_21 = torch.ops.aten.sum.dim_IntList(mul_332, [2], True); mul_332 = None + div_7 = torch.ops.aten.div.Tensor(mul_228, 4096) + mul_333 = torch.ops.aten.mul.Tensor(div_7, sum_21); div_7 = sum_21 = None + sub_11 = torch.ops.aten.sub.Tensor(mul_330, mul_333); mul_330 = mul_333 = None + mul_334 = torch.ops.aten.mul.Tensor(sub_11, rsqrt_57); sub_11 = rsqrt_57 = None + mul_335 = torch.ops.aten.mul.Tensor(convert_element_type_1254, mul_228); convert_element_type_1254 = mul_228 = None + sum_22 = torch.ops.aten.sum.dim_IntList(mul_335, [0, 1]); mul_335 = None + convert_element_type_1257 = 
torch.ops.prims.convert_element_type.default(mul_334, torch.bfloat16); mul_334 = None + add_153 = torch.ops.aten.add.Tensor(add_149, convert_element_type_1257); add_149 = convert_element_type_1257 = None + convert_element_type_default_58 = torch.ops.prims.convert_element_type.default(sum_22, torch.float32); sum_22 = None + reduce_scatter_tensor_32 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_58, 'avg', 256, '0'); convert_element_type_default_58 = None + wait_tensor_323 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_32); reduce_scatter_tensor_32 = None + view_1173 = torch.ops.aten.view.default(add_153, [16384, 4096]) + permute_465 = torch.ops.aten.permute.default(view_1173, [1, 0]) + mm_275 = torch.ops.aten.mm.default(permute_465, view_975); permute_465 = view_975 = None + permute_467 = torch.ops.aten.permute.default(permute_315, [1, 0]); permute_315 = None + mm_276 = torch.ops.aten.mm.default(view_1173, permute_467); view_1173 = permute_467 = None + view_1174 = torch.ops.aten.view.default(mm_276, [2, 8192, 4096]); mm_276 = None + convert_element_type_1264 = torch.ops.prims.convert_element_type.default(mm_275, torch.float32); mm_275 = None + reduce_scatter_tensor_33 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1264, 'avg', 256, '0'); convert_element_type_1264 = None + wait_tensor_324 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_33); reduce_scatter_tensor_33 = None + view_1175 = torch.ops.aten.view.default(view_1174, [2, 8192, 32, 128]); view_1174 = None + permute_469 = torch.ops.aten.permute.default(view_1175, [0, 2, 1, 3]); view_1175 = None + convert_element_type_925 = torch.ops.prims.convert_element_type.default(primals_256, torch.bfloat16); primals_256 = None + all_gather_into_tensor_253 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_925, 256, '0'); convert_element_type_925 = None + wait_tensor_253 = 
torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_253); all_gather_into_tensor_253 = None + convert_element_type_926 = torch.ops.prims.convert_element_type.default(add_111, torch.float32); add_111 = None + pow_57 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_926, 2) + mean_56 = torch.ops.aten.mean.dim(pow_57, [2], True); pow_57 = None + add_112 = torch.ops.aten.add.Scalar(mean_56, 1e-05); mean_56 = None + rsqrt_56 = torch.ops.aten.rsqrt.default(add_112); add_112 = None + mul_224 = torch.ops.aten.mul.Tensor(convert_element_type_926, rsqrt_56); convert_element_type_926 = None + mul_225 = torch.ops.aten.mul.Tensor(mul_224, wait_tensor_253) + convert_element_type_927 = torch.ops.prims.convert_element_type.default(mul_225, torch.bfloat16); mul_225 = None + view_955 = torch.ops.aten.view.default(convert_element_type_927, [16384, 4096]); convert_element_type_927 = None + view_956 = torch.ops.aten.view.default(mm_196, [2, 8192, 4096]); mm_196 = None + convert_element_type_931 = torch.ops.prims.convert_element_type.default(primals_258, torch.bfloat16); primals_258 = None + all_gather_into_tensor_255 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_931, 256, '0'); convert_element_type_931 = None + wait_tensor_255 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_255); all_gather_into_tensor_255 = None + permute_309 = torch.ops.aten.permute.default(wait_tensor_255, [1, 0]); wait_tensor_255 = None + mm_197 = torch.ops.aten.mm.default(view_955, permute_309) + view_959 = torch.ops.aten.view.default(mm_197, [2, 8192, 1024]); mm_197 = None + view_962 = torch.ops.aten.view.default(mm_198, [2, 8192, 1024]); mm_198 = None + view_963 = torch.ops.aten.view.default(view_956, [2, 8192, -1, 128]); view_956 = None + view_964 = torch.ops.aten.view.default(view_959, [2, 8192, -1, 128]); view_959 = None + view_965 = torch.ops.aten.view.default(view_962, [2, 8192, -1, 128]); view_962 = None + 
convert_element_type_937 = torch.ops.prims.convert_element_type.default(view_963, torch.float32); view_963 = None + view_966 = torch.ops.aten.view.default(convert_element_type_937, [2, 8192, 32, -1, 2]); convert_element_type_937 = None + view_as_complex_56 = torch.ops.aten.view_as_complex.default(view_966); view_966 = None + convert_element_type_938 = torch.ops.prims.convert_element_type.default(view_964, torch.float32); view_964 = None + view_967 = torch.ops.aten.view.default(convert_element_type_938, [2, 8192, 8, -1, 2]); convert_element_type_938 = None + view_as_complex_57 = torch.ops.aten.view_as_complex.default(view_967); view_967 = None + mul_226 = torch.ops.aten.mul.Tensor(view_as_complex_56, view_16); view_as_complex_56 = None + view_as_real_56 = torch.ops.aten.view_as_real.default(mul_226); mul_226 = None + view_969 = torch.ops.aten.view.default(view_as_real_56, [2, 8192, 32, 128]); view_as_real_56 = None + mul_227 = torch.ops.aten.mul.Tensor(view_as_complex_57, view_16); view_as_complex_57 = None + view_as_real_57 = torch.ops.aten.view_as_real.default(mul_227); mul_227 = None + view_970 = torch.ops.aten.view.default(view_as_real_57, [2, 8192, 8, 128]); view_as_real_57 = None + convert_element_type_939 = torch.ops.prims.convert_element_type.default(view_969, torch.bfloat16); view_969 = None + convert_element_type_940 = torch.ops.prims.convert_element_type.default(view_970, torch.bfloat16); view_970 = None + unsqueeze_56 = torch.ops.aten.unsqueeze.default(convert_element_type_940, 3); convert_element_type_940 = None + expand_56 = torch.ops.aten.expand.default(unsqueeze_56, [2, 8192, 8, 4, 128]); unsqueeze_56 = None + clone_56 = torch.ops.aten.clone.default(expand_56, memory_format = torch.contiguous_format); expand_56 = None + view_971 = torch.ops.aten.view.default(clone_56, [2, 8192, 32, 128]); clone_56 = None + unsqueeze_57 = torch.ops.aten.unsqueeze.default(view_965, 3); view_965 = None + expand_57 = torch.ops.aten.expand.default(unsqueeze_57, [2, 8192, 
8, 4, 128]); unsqueeze_57 = None + clone_57 = torch.ops.aten.clone.default(expand_57, memory_format = torch.contiguous_format); expand_57 = None + view_972 = torch.ops.aten.view.default(clone_57, [2, 8192, 32, 128]); clone_57 = None + permute_311 = torch.ops.aten.permute.default(convert_element_type_939, [0, 2, 1, 3]); convert_element_type_939 = None + permute_312 = torch.ops.aten.permute.default(view_971, [0, 2, 1, 3]); view_971 = None + permute_313 = torch.ops.aten.permute.default(view_972, [0, 2, 1, 3]); view_972 = None + _scaled_dot_product_cudnn_attention_backward_3 = torch.ops.aten._scaled_dot_product_cudnn_attention_backward.default(permute_469, permute_311, permute_312, permute_313, getitem_252, getitem_253, getitem_258, getitem_259, None, None, None, 8192, 8192, 0.0, True); permute_469 = permute_311 = permute_312 = permute_313 = getitem_252 = getitem_253 = getitem_258 = getitem_259 = None + getitem_297 = _scaled_dot_product_cudnn_attention_backward_3[0] + getitem_298 = _scaled_dot_product_cudnn_attention_backward_3[1] + getitem_299 = _scaled_dot_product_cudnn_attention_backward_3[2]; _scaled_dot_product_cudnn_attention_backward_3 = None + permute_470 = torch.ops.aten.permute.default(getitem_299, [0, 2, 1, 3]); getitem_299 = None + permute_471 = torch.ops.aten.permute.default(getitem_298, [0, 2, 1, 3]); getitem_298 = None + permute_472 = torch.ops.aten.permute.default(getitem_297, [0, 2, 1, 3]); getitem_297 = None + view_1176 = torch.ops.aten.view.default(permute_470, [2, 8192, 8, 4, 128]); permute_470 = None + sum_23 = torch.ops.aten.sum.dim_IntList(view_1176, [3], True); view_1176 = None + squeeze_6 = torch.ops.aten.squeeze.dim(sum_23, 3); sum_23 = None + view_1177 = torch.ops.aten.view.default(permute_471, [2, 8192, 8, 4, 128]); permute_471 = None + sum_24 = torch.ops.aten.sum.dim_IntList(view_1177, [3], True); view_1177 = None + squeeze_7 = torch.ops.aten.squeeze.dim(sum_24, 3); sum_24 = None + convert_element_type_1265 = 
torch.ops.prims.convert_element_type.default(squeeze_7, torch.float32); squeeze_7 = None + convert_element_type_1266 = torch.ops.prims.convert_element_type.default(permute_472, torch.float32); permute_472 = None + view_1178 = torch.ops.aten.view.default(convert_element_type_1265, [2, 8192, 8, 64, 2]); convert_element_type_1265 = None + view_as_complex_70 = torch.ops.aten.view_as_complex.default(view_1178); view_1178 = None + mul_336 = torch.ops.aten.mul.Tensor(view_as_complex_70, _conj); view_as_complex_70 = None + view_1179 = torch.ops.aten.view.default(convert_element_type_1266, [2, 8192, 32, 64, 2]); convert_element_type_1266 = None + view_as_complex_71 = torch.ops.aten.view_as_complex.default(view_1179); view_1179 = None + mul_337 = torch.ops.aten.mul.Tensor(view_as_complex_71, _conj); view_as_complex_71 = None + view_as_real_70 = torch.ops.aten.view_as_real.default(mul_336); mul_336 = None + view_1180 = torch.ops.aten.view.default(view_as_real_70, [2, 8192, 8, 128]); view_as_real_70 = None + convert_element_type_1267 = torch.ops.prims.convert_element_type.default(view_1180, torch.bfloat16); view_1180 = None + view_as_real_71 = torch.ops.aten.view_as_real.default(mul_337); mul_337 = None + view_1181 = torch.ops.aten.view.default(view_as_real_71, [2, 8192, 32, 128]); view_as_real_71 = None + convert_element_type_1268 = torch.ops.prims.convert_element_type.default(view_1181, torch.bfloat16); view_1181 = None + view_1182 = torch.ops.aten.view.default(squeeze_6, [2, 8192, 1024]); squeeze_6 = None + view_1183 = torch.ops.aten.view.default(convert_element_type_1267, [2, 8192, 1024]); convert_element_type_1267 = None + view_1184 = torch.ops.aten.view.default(convert_element_type_1268, [2, 8192, 4096]); convert_element_type_1268 = None + view_1185 = torch.ops.aten.view.default(view_1182, [16384, 1024]); view_1182 = None + permute_473 = torch.ops.aten.permute.default(view_1185, [1, 0]) + mm_277 = torch.ops.aten.mm.default(permute_473, view_955); permute_473 = None + 
convert_element_type_934 = torch.ops.prims.convert_element_type.default(primals_259, torch.bfloat16); primals_259 = None + all_gather_into_tensor_256 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_934, 256, '0'); convert_element_type_934 = None + wait_tensor_256 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_256); all_gather_into_tensor_256 = None + permute_310 = torch.ops.aten.permute.default(wait_tensor_256, [1, 0]); wait_tensor_256 = None + permute_475 = torch.ops.aten.permute.default(permute_310, [1, 0]); permute_310 = None + mm_278 = torch.ops.aten.mm.default(view_1185, permute_475); view_1185 = permute_475 = None + view_1186 = torch.ops.aten.view.default(mm_278, [2, 8192, 4096]); mm_278 = None + convert_element_type_1273 = torch.ops.prims.convert_element_type.default(mm_277, torch.float32); mm_277 = None + reduce_scatter_tensor_34 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1273, 'avg', 256, '0'); convert_element_type_1273 = None + wait_tensor_325 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_34); reduce_scatter_tensor_34 = None + view_1187 = torch.ops.aten.view.default(view_1183, [16384, 1024]); view_1183 = None + permute_477 = torch.ops.aten.permute.default(view_1187, [1, 0]) + mm_279 = torch.ops.aten.mm.default(permute_477, view_955); permute_477 = None + permute_479 = torch.ops.aten.permute.default(permute_309, [1, 0]); permute_309 = None + mm_280 = torch.ops.aten.mm.default(view_1187, permute_479); view_1187 = permute_479 = None + view_1188 = torch.ops.aten.view.default(mm_280, [2, 8192, 4096]); mm_280 = None + add_154 = torch.ops.aten.add.Tensor(view_1186, view_1188); view_1186 = view_1188 = None + convert_element_type_1278 = torch.ops.prims.convert_element_type.default(mm_279, torch.float32); mm_279 = None + reduce_scatter_tensor_35 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1278, 'avg', 256, '0'); 
convert_element_type_1278 = None + wait_tensor_326 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_35); reduce_scatter_tensor_35 = None + view_1189 = torch.ops.aten.view.default(view_1184, [16384, 4096]); view_1184 = None + permute_481 = torch.ops.aten.permute.default(view_1189, [1, 0]) + mm_281 = torch.ops.aten.mm.default(permute_481, view_955); permute_481 = view_955 = None + convert_element_type_928 = torch.ops.prims.convert_element_type.default(primals_257, torch.bfloat16); primals_257 = None + all_gather_into_tensor_254 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_928, 256, '0'); convert_element_type_928 = None + wait_tensor_254 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_254); all_gather_into_tensor_254 = None + permute_308 = torch.ops.aten.permute.default(wait_tensor_254, [1, 0]); wait_tensor_254 = None + permute_483 = torch.ops.aten.permute.default(permute_308, [1, 0]); permute_308 = None + mm_282 = torch.ops.aten.mm.default(view_1189, permute_483); view_1189 = permute_483 = None + view_1190 = torch.ops.aten.view.default(mm_282, [2, 8192, 4096]); mm_282 = None + add_155 = torch.ops.aten.add.Tensor(add_154, view_1190); add_154 = view_1190 = None + convert_element_type_1283 = torch.ops.prims.convert_element_type.default(mm_281, torch.float32); mm_281 = None + reduce_scatter_tensor_36 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1283, 'avg', 256, '0'); convert_element_type_1283 = None + wait_tensor_327 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_36); reduce_scatter_tensor_36 = None + convert_element_type_1284 = torch.ops.prims.convert_element_type.default(add_155, torch.float32); add_155 = None + convert_element_type_1286 = torch.ops.prims.convert_element_type.default(wait_tensor_253, torch.float32); wait_tensor_253 = None + mul_338 = torch.ops.aten.mul.Tensor(convert_element_type_1284, convert_element_type_1286); 
convert_element_type_1286 = None + mul_340 = torch.ops.aten.mul.Tensor(mul_224, mul_338) + sum_25 = torch.ops.aten.sum.dim_IntList(mul_340, [2], True); mul_340 = None + div_8 = torch.ops.aten.div.Tensor(mul_224, 4096) + mul_341 = torch.ops.aten.mul.Tensor(div_8, sum_25); div_8 = sum_25 = None + sub_12 = torch.ops.aten.sub.Tensor(mul_338, mul_341); mul_338 = mul_341 = None + mul_342 = torch.ops.aten.mul.Tensor(sub_12, rsqrt_56); sub_12 = rsqrt_56 = None + mul_343 = torch.ops.aten.mul.Tensor(convert_element_type_1284, mul_224); convert_element_type_1284 = mul_224 = None + sum_26 = torch.ops.aten.sum.dim_IntList(mul_343, [0, 1]); mul_343 = None + convert_element_type_1287 = torch.ops.prims.convert_element_type.default(mul_342, torch.bfloat16); mul_342 = None + add_156 = torch.ops.aten.add.Tensor(add_153, convert_element_type_1287); add_153 = convert_element_type_1287 = None + convert_element_type_default_57 = torch.ops.prims.convert_element_type.default(sum_26, torch.float32); sum_26 = None + reduce_scatter_tensor_37 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_57, 'avg', 256, '0'); convert_element_type_default_57 = None + wait_tensor_328 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_37); reduce_scatter_tensor_37 = None + view_1191 = torch.ops.aten.view.default(add_156, [16384, 4096]) + permute_485 = torch.ops.aten.permute.default(view_1191, [1, 0]) + permute_303 = torch.ops.aten.permute.default(getitem_243, [0, 2, 1, 3]) + view_939 = torch.ops.aten.view.default(permute_303, [2, 8192, -1]); permute_303 = None + convert_element_type_908 = torch.ops.prims.convert_element_type.default(primals_251, torch.bfloat16); primals_251 = None + all_gather_into_tensor_248 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_908, 256, '0'); convert_element_type_908 = None + wait_tensor_248 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_248); 
all_gather_into_tensor_248 = None + permute_304 = torch.ops.aten.permute.default(wait_tensor_248, [1, 0]); wait_tensor_248 = None + view_941 = torch.ops.aten.view.default(view_939, [16384, 4096]); view_939 = None + mm_192 = torch.ops.aten.mm.default(view_941, permute_304) + view_942 = torch.ops.aten.view.default(mm_192, [2, 8192, 4096]); mm_192 = None + add_109 = torch.ops.aten.add.Tensor(add_107, view_942); view_942 = None + convert_element_type_911 = torch.ops.prims.convert_element_type.default(primals_252, torch.bfloat16); primals_252 = None + all_gather_into_tensor_249 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_911, 256, '0'); convert_element_type_911 = None + wait_tensor_249 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_249); all_gather_into_tensor_249 = None + convert_element_type_912 = torch.ops.prims.convert_element_type.default(add_109, torch.float32); add_109 = None + pow_56 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_912, 2) + mean_55 = torch.ops.aten.mean.dim(pow_56, [2], True); pow_56 = None + add_110 = torch.ops.aten.add.Scalar(mean_55, 1e-05); mean_55 = None + rsqrt_55 = torch.ops.aten.rsqrt.default(add_110); add_110 = None + mul_220 = torch.ops.aten.mul.Tensor(convert_element_type_912, rsqrt_55); convert_element_type_912 = None + mul_221 = torch.ops.aten.mul.Tensor(mul_220, wait_tensor_249) + convert_element_type_913 = torch.ops.prims.convert_element_type.default(mul_221, torch.bfloat16); mul_221 = None + view_945 = torch.ops.aten.view.default(convert_element_type_913, [16384, 4096]); convert_element_type_913 = None + view_946 = torch.ops.aten.view.default(mm_193, [2, 8192, 14336]); mm_193 = None + convert_element_type_917 = torch.ops.prims.convert_element_type.default(view_946, torch.float32); view_946 = None + sigmoid_27 = torch.ops.aten.sigmoid.default(convert_element_type_917) + mul_222 = torch.ops.aten.mul.Tensor(convert_element_type_917, sigmoid_27); sigmoid_27 = None 
+ convert_element_type_918 = torch.ops.prims.convert_element_type.default(mul_222, torch.bfloat16); mul_222 = None + convert_element_type_919 = torch.ops.prims.convert_element_type.default(primals_254, torch.bfloat16); primals_254 = None + all_gather_into_tensor_251 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_919, 256, '0'); convert_element_type_919 = None + wait_tensor_251 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_251); all_gather_into_tensor_251 = None + permute_306 = torch.ops.aten.permute.default(wait_tensor_251, [1, 0]); wait_tensor_251 = None + mm_194 = torch.ops.aten.mm.default(view_945, permute_306) + view_949 = torch.ops.aten.view.default(mm_194, [2, 8192, 14336]); mm_194 = None + mul_223 = torch.ops.aten.mul.Tensor(convert_element_type_918, view_949) + view_951 = torch.ops.aten.view.default(mul_223, [16384, 14336]); mul_223 = None + mm_283 = torch.ops.aten.mm.default(permute_485, view_951); permute_485 = view_951 = None + convert_element_type_922 = torch.ops.prims.convert_element_type.default(primals_255, torch.bfloat16); primals_255 = None + all_gather_into_tensor_252 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_922, 256, '0'); convert_element_type_922 = None + wait_tensor_252 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_252); all_gather_into_tensor_252 = None + permute_307 = torch.ops.aten.permute.default(wait_tensor_252, [1, 0]); wait_tensor_252 = None + permute_487 = torch.ops.aten.permute.default(permute_307, [1, 0]); permute_307 = None + mm_284 = torch.ops.aten.mm.default(view_1191, permute_487); view_1191 = permute_487 = None + view_1192 = torch.ops.aten.view.default(mm_284, [2, 8192, 14336]); mm_284 = None + convert_element_type_1294 = torch.ops.prims.convert_element_type.default(mm_283, torch.float32); mm_283 = None + reduce_scatter_tensor_38 = 
torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1294, 'avg', 256, '0'); convert_element_type_1294 = None + wait_tensor_329 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_38); reduce_scatter_tensor_38 = None + mul_344 = torch.ops.aten.mul.Tensor(view_1192, convert_element_type_918); convert_element_type_918 = None + mul_345 = torch.ops.aten.mul.Tensor(view_1192, view_949); view_1192 = view_949 = None + view_1193 = torch.ops.aten.view.default(mul_344, [16384, 14336]); mul_344 = None + permute_489 = torch.ops.aten.permute.default(view_1193, [1, 0]) + mm_285 = torch.ops.aten.mm.default(permute_489, view_945); permute_489 = None + permute_491 = torch.ops.aten.permute.default(permute_306, [1, 0]); permute_306 = None + mm_286 = torch.ops.aten.mm.default(view_1193, permute_491); view_1193 = permute_491 = None + view_1194 = torch.ops.aten.view.default(mm_286, [2, 8192, 4096]); mm_286 = None + convert_element_type_1299 = torch.ops.prims.convert_element_type.default(mm_285, torch.float32); mm_285 = None + reduce_scatter_tensor_39 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1299, 'avg', 256, '0'); convert_element_type_1299 = None + wait_tensor_330 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_39); reduce_scatter_tensor_39 = None + convert_element_type_1300 = torch.ops.prims.convert_element_type.default(mul_345, torch.float32); mul_345 = None + neg_4 = torch.ops.aten.neg.default(convert_element_type_917) + exp_4 = torch.ops.aten.exp.default(neg_4); neg_4 = None + add_157 = torch.ops.aten.add.Tensor(exp_4, 1); exp_4 = None + reciprocal_4 = torch.ops.aten.reciprocal.default(add_157); add_157 = None + mul_346 = torch.ops.aten.mul.Tensor(reciprocal_4, 1); reciprocal_4 = None + mul_347 = torch.ops.aten.mul.Tensor(convert_element_type_1300, mul_346); convert_element_type_1300 = None + sub_13 = torch.ops.aten.sub.Tensor(1, mul_346); mul_346 = None + mul_348 = 
torch.ops.aten.mul.Tensor(convert_element_type_917, sub_13); convert_element_type_917 = sub_13 = None + add_158 = torch.ops.aten.add.Tensor(mul_348, 1); mul_348 = None + mul_349 = torch.ops.aten.mul.Tensor(mul_347, add_158); mul_347 = add_158 = None + convert_element_type_1302 = torch.ops.prims.convert_element_type.default(mul_349, torch.bfloat16); mul_349 = None + view_1195 = torch.ops.aten.view.default(convert_element_type_1302, [16384, 14336]); convert_element_type_1302 = None + permute_493 = torch.ops.aten.permute.default(view_1195, [1, 0]) + mm_287 = torch.ops.aten.mm.default(permute_493, view_945); permute_493 = view_945 = None + convert_element_type_914 = torch.ops.prims.convert_element_type.default(primals_253, torch.bfloat16); primals_253 = None + all_gather_into_tensor_250 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_914, 256, '0'); convert_element_type_914 = None + wait_tensor_250 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_250); all_gather_into_tensor_250 = None + permute_305 = torch.ops.aten.permute.default(wait_tensor_250, [1, 0]); wait_tensor_250 = None + permute_495 = torch.ops.aten.permute.default(permute_305, [1, 0]); permute_305 = None + mm_288 = torch.ops.aten.mm.default(view_1195, permute_495); view_1195 = permute_495 = None + view_1196 = torch.ops.aten.view.default(mm_288, [2, 8192, 4096]); mm_288 = None + add_159 = torch.ops.aten.add.Tensor(view_1194, view_1196); view_1194 = view_1196 = None + convert_element_type_1307 = torch.ops.prims.convert_element_type.default(mm_287, torch.float32); mm_287 = None + reduce_scatter_tensor_40 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1307, 'avg', 256, '0'); convert_element_type_1307 = None + wait_tensor_331 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_40); reduce_scatter_tensor_40 = None + convert_element_type_1308 = torch.ops.prims.convert_element_type.default(add_159, 
torch.float32); add_159 = None + convert_element_type_1310 = torch.ops.prims.convert_element_type.default(wait_tensor_249, torch.float32); wait_tensor_249 = None + mul_350 = torch.ops.aten.mul.Tensor(convert_element_type_1308, convert_element_type_1310); convert_element_type_1310 = None + mul_352 = torch.ops.aten.mul.Tensor(mul_220, mul_350) + sum_27 = torch.ops.aten.sum.dim_IntList(mul_352, [2], True); mul_352 = None + div_9 = torch.ops.aten.div.Tensor(mul_220, 4096) + mul_353 = torch.ops.aten.mul.Tensor(div_9, sum_27); div_9 = sum_27 = None + sub_14 = torch.ops.aten.sub.Tensor(mul_350, mul_353); mul_350 = mul_353 = None + mul_354 = torch.ops.aten.mul.Tensor(sub_14, rsqrt_55); sub_14 = rsqrt_55 = None + mul_355 = torch.ops.aten.mul.Tensor(convert_element_type_1308, mul_220); convert_element_type_1308 = mul_220 = None + sum_28 = torch.ops.aten.sum.dim_IntList(mul_355, [0, 1]); mul_355 = None + convert_element_type_1311 = torch.ops.prims.convert_element_type.default(mul_354, torch.bfloat16); mul_354 = None + add_160 = torch.ops.aten.add.Tensor(add_156, convert_element_type_1311); add_156 = convert_element_type_1311 = None + convert_element_type_default_56 = torch.ops.prims.convert_element_type.default(sum_28, torch.float32); sum_28 = None + reduce_scatter_tensor_41 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_56, 'avg', 256, '0'); convert_element_type_default_56 = None + wait_tensor_332 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_41); reduce_scatter_tensor_41 = None + view_1197 = torch.ops.aten.view.default(add_160, [16384, 4096]) + permute_497 = torch.ops.aten.permute.default(view_1197, [1, 0]) + mm_289 = torch.ops.aten.mm.default(permute_497, view_941); permute_497 = view_941 = None + permute_499 = torch.ops.aten.permute.default(permute_304, [1, 0]); permute_304 = None + mm_290 = torch.ops.aten.mm.default(view_1197, permute_499); view_1197 = permute_499 = None + view_1198 = 
torch.ops.aten.view.default(mm_290, [2, 8192, 4096]); mm_290 = None + convert_element_type_1318 = torch.ops.prims.convert_element_type.default(mm_289, torch.float32); mm_289 = None + reduce_scatter_tensor_42 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1318, 'avg', 256, '0'); convert_element_type_1318 = None + wait_tensor_333 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_42); reduce_scatter_tensor_42 = None + view_1199 = torch.ops.aten.view.default(view_1198, [2, 8192, 32, 128]); view_1198 = None + permute_501 = torch.ops.aten.permute.default(view_1199, [0, 2, 1, 3]); view_1199 = None + convert_element_type_892 = torch.ops.prims.convert_element_type.default(primals_247, torch.bfloat16); primals_247 = None + all_gather_into_tensor_244 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_892, 256, '0'); convert_element_type_892 = None + wait_tensor_244 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_244); all_gather_into_tensor_244 = None + convert_element_type_893 = torch.ops.prims.convert_element_type.default(add_107, torch.float32); add_107 = None + pow_55 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_893, 2) + mean_54 = torch.ops.aten.mean.dim(pow_55, [2], True); pow_55 = None + add_108 = torch.ops.aten.add.Scalar(mean_54, 1e-05); mean_54 = None + rsqrt_54 = torch.ops.aten.rsqrt.default(add_108); add_108 = None + mul_216 = torch.ops.aten.mul.Tensor(convert_element_type_893, rsqrt_54); convert_element_type_893 = None + mul_217 = torch.ops.aten.mul.Tensor(mul_216, wait_tensor_244) + convert_element_type_894 = torch.ops.prims.convert_element_type.default(mul_217, torch.bfloat16); mul_217 = None + view_921 = torch.ops.aten.view.default(convert_element_type_894, [16384, 4096]); convert_element_type_894 = None + view_922 = torch.ops.aten.view.default(mm_189, [2, 8192, 4096]); mm_189 = None + convert_element_type_898 = 
torch.ops.prims.convert_element_type.default(primals_249, torch.bfloat16); primals_249 = None + all_gather_into_tensor_246 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_898, 256, '0'); convert_element_type_898 = None + wait_tensor_246 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_246); all_gather_into_tensor_246 = None + permute_298 = torch.ops.aten.permute.default(wait_tensor_246, [1, 0]); wait_tensor_246 = None + mm_190 = torch.ops.aten.mm.default(view_921, permute_298) + view_925 = torch.ops.aten.view.default(mm_190, [2, 8192, 1024]); mm_190 = None + view_928 = torch.ops.aten.view.default(mm_191, [2, 8192, 1024]); mm_191 = None + view_929 = torch.ops.aten.view.default(view_922, [2, 8192, -1, 128]); view_922 = None + view_930 = torch.ops.aten.view.default(view_925, [2, 8192, -1, 128]); view_925 = None + view_931 = torch.ops.aten.view.default(view_928, [2, 8192, -1, 128]); view_928 = None + convert_element_type_904 = torch.ops.prims.convert_element_type.default(view_929, torch.float32); view_929 = None + view_932 = torch.ops.aten.view.default(convert_element_type_904, [2, 8192, 32, -1, 2]); convert_element_type_904 = None + view_as_complex_54 = torch.ops.aten.view_as_complex.default(view_932); view_932 = None + convert_element_type_905 = torch.ops.prims.convert_element_type.default(view_930, torch.float32); view_930 = None + view_933 = torch.ops.aten.view.default(convert_element_type_905, [2, 8192, 8, -1, 2]); convert_element_type_905 = None + view_as_complex_55 = torch.ops.aten.view_as_complex.default(view_933); view_933 = None + mul_218 = torch.ops.aten.mul.Tensor(view_as_complex_54, view_16); view_as_complex_54 = None + view_as_real_54 = torch.ops.aten.view_as_real.default(mul_218); mul_218 = None + view_935 = torch.ops.aten.view.default(view_as_real_54, [2, 8192, 32, 128]); view_as_real_54 = None + mul_219 = torch.ops.aten.mul.Tensor(view_as_complex_55, view_16); view_as_complex_55 = None + 
view_as_real_55 = torch.ops.aten.view_as_real.default(mul_219); mul_219 = None + view_936 = torch.ops.aten.view.default(view_as_real_55, [2, 8192, 8, 128]); view_as_real_55 = None + convert_element_type_906 = torch.ops.prims.convert_element_type.default(view_935, torch.bfloat16); view_935 = None + convert_element_type_907 = torch.ops.prims.convert_element_type.default(view_936, torch.bfloat16); view_936 = None + unsqueeze_54 = torch.ops.aten.unsqueeze.default(convert_element_type_907, 3); convert_element_type_907 = None + expand_54 = torch.ops.aten.expand.default(unsqueeze_54, [2, 8192, 8, 4, 128]); unsqueeze_54 = None + clone_54 = torch.ops.aten.clone.default(expand_54, memory_format = torch.contiguous_format); expand_54 = None + view_937 = torch.ops.aten.view.default(clone_54, [2, 8192, 32, 128]); clone_54 = None + unsqueeze_55 = torch.ops.aten.unsqueeze.default(view_931, 3); view_931 = None + expand_55 = torch.ops.aten.expand.default(unsqueeze_55, [2, 8192, 8, 4, 128]); unsqueeze_55 = None + clone_55 = torch.ops.aten.clone.default(expand_55, memory_format = torch.contiguous_format); expand_55 = None + view_938 = torch.ops.aten.view.default(clone_55, [2, 8192, 32, 128]); clone_55 = None + permute_300 = torch.ops.aten.permute.default(convert_element_type_906, [0, 2, 1, 3]); convert_element_type_906 = None + permute_301 = torch.ops.aten.permute.default(view_937, [0, 2, 1, 3]); view_937 = None + permute_302 = torch.ops.aten.permute.default(view_938, [0, 2, 1, 3]); view_938 = None + _scaled_dot_product_cudnn_attention_backward_4 = torch.ops.aten._scaled_dot_product_cudnn_attention_backward.default(permute_501, permute_300, permute_301, permute_302, getitem_243, getitem_244, getitem_249, getitem_250, None, None, None, 8192, 8192, 0.0, True); permute_501 = permute_300 = permute_301 = permute_302 = getitem_243 = getitem_244 = getitem_249 = getitem_250 = None + getitem_300 = _scaled_dot_product_cudnn_attention_backward_4[0] + getitem_301 = 
_scaled_dot_product_cudnn_attention_backward_4[1] + getitem_302 = _scaled_dot_product_cudnn_attention_backward_4[2]; _scaled_dot_product_cudnn_attention_backward_4 = None + permute_502 = torch.ops.aten.permute.default(getitem_302, [0, 2, 1, 3]); getitem_302 = None + permute_503 = torch.ops.aten.permute.default(getitem_301, [0, 2, 1, 3]); getitem_301 = None + permute_504 = torch.ops.aten.permute.default(getitem_300, [0, 2, 1, 3]); getitem_300 = None + view_1200 = torch.ops.aten.view.default(permute_502, [2, 8192, 8, 4, 128]); permute_502 = None + sum_29 = torch.ops.aten.sum.dim_IntList(view_1200, [3], True); view_1200 = None + squeeze_8 = torch.ops.aten.squeeze.dim(sum_29, 3); sum_29 = None + view_1201 = torch.ops.aten.view.default(permute_503, [2, 8192, 8, 4, 128]); permute_503 = None + sum_30 = torch.ops.aten.sum.dim_IntList(view_1201, [3], True); view_1201 = None + squeeze_9 = torch.ops.aten.squeeze.dim(sum_30, 3); sum_30 = None + convert_element_type_1319 = torch.ops.prims.convert_element_type.default(squeeze_9, torch.float32); squeeze_9 = None + convert_element_type_1320 = torch.ops.prims.convert_element_type.default(permute_504, torch.float32); permute_504 = None + view_1202 = torch.ops.aten.view.default(convert_element_type_1319, [2, 8192, 8, 64, 2]); convert_element_type_1319 = None + view_as_complex_72 = torch.ops.aten.view_as_complex.default(view_1202); view_1202 = None + mul_356 = torch.ops.aten.mul.Tensor(view_as_complex_72, _conj); view_as_complex_72 = None + view_1203 = torch.ops.aten.view.default(convert_element_type_1320, [2, 8192, 32, 64, 2]); convert_element_type_1320 = None + view_as_complex_73 = torch.ops.aten.view_as_complex.default(view_1203); view_1203 = None + mul_357 = torch.ops.aten.mul.Tensor(view_as_complex_73, _conj); view_as_complex_73 = None + view_as_real_72 = torch.ops.aten.view_as_real.default(mul_356); mul_356 = None + view_1204 = torch.ops.aten.view.default(view_as_real_72, [2, 8192, 8, 128]); view_as_real_72 = None + 
convert_element_type_1321 = torch.ops.prims.convert_element_type.default(view_1204, torch.bfloat16); view_1204 = None + view_as_real_73 = torch.ops.aten.view_as_real.default(mul_357); mul_357 = None + view_1205 = torch.ops.aten.view.default(view_as_real_73, [2, 8192, 32, 128]); view_as_real_73 = None + convert_element_type_1322 = torch.ops.prims.convert_element_type.default(view_1205, torch.bfloat16); view_1205 = None + view_1206 = torch.ops.aten.view.default(squeeze_8, [2, 8192, 1024]); squeeze_8 = None + view_1207 = torch.ops.aten.view.default(convert_element_type_1321, [2, 8192, 1024]); convert_element_type_1321 = None + view_1208 = torch.ops.aten.view.default(convert_element_type_1322, [2, 8192, 4096]); convert_element_type_1322 = None + view_1209 = torch.ops.aten.view.default(view_1206, [16384, 1024]); view_1206 = None + permute_505 = torch.ops.aten.permute.default(view_1209, [1, 0]) + mm_291 = torch.ops.aten.mm.default(permute_505, view_921); permute_505 = None + convert_element_type_901 = torch.ops.prims.convert_element_type.default(primals_250, torch.bfloat16); primals_250 = None + all_gather_into_tensor_247 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_901, 256, '0'); convert_element_type_901 = None + wait_tensor_247 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_247); all_gather_into_tensor_247 = None + permute_299 = torch.ops.aten.permute.default(wait_tensor_247, [1, 0]); wait_tensor_247 = None + permute_507 = torch.ops.aten.permute.default(permute_299, [1, 0]); permute_299 = None + mm_292 = torch.ops.aten.mm.default(view_1209, permute_507); view_1209 = permute_507 = None + view_1210 = torch.ops.aten.view.default(mm_292, [2, 8192, 4096]); mm_292 = None + convert_element_type_1327 = torch.ops.prims.convert_element_type.default(mm_291, torch.float32); mm_291 = None + reduce_scatter_tensor_43 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1327, 'avg', 256, '0'); 
convert_element_type_1327 = None + wait_tensor_334 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_43); reduce_scatter_tensor_43 = None + view_1211 = torch.ops.aten.view.default(view_1207, [16384, 1024]); view_1207 = None + permute_509 = torch.ops.aten.permute.default(view_1211, [1, 0]) + mm_293 = torch.ops.aten.mm.default(permute_509, view_921); permute_509 = None + permute_511 = torch.ops.aten.permute.default(permute_298, [1, 0]); permute_298 = None + mm_294 = torch.ops.aten.mm.default(view_1211, permute_511); view_1211 = permute_511 = None + view_1212 = torch.ops.aten.view.default(mm_294, [2, 8192, 4096]); mm_294 = None + add_161 = torch.ops.aten.add.Tensor(view_1210, view_1212); view_1210 = view_1212 = None + convert_element_type_1332 = torch.ops.prims.convert_element_type.default(mm_293, torch.float32); mm_293 = None + reduce_scatter_tensor_44 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1332, 'avg', 256, '0'); convert_element_type_1332 = None + wait_tensor_335 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_44); reduce_scatter_tensor_44 = None + view_1213 = torch.ops.aten.view.default(view_1208, [16384, 4096]); view_1208 = None + permute_513 = torch.ops.aten.permute.default(view_1213, [1, 0]) + mm_295 = torch.ops.aten.mm.default(permute_513, view_921); permute_513 = view_921 = None + convert_element_type_895 = torch.ops.prims.convert_element_type.default(primals_248, torch.bfloat16); primals_248 = None + all_gather_into_tensor_245 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_895, 256, '0'); convert_element_type_895 = None + wait_tensor_245 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_245); all_gather_into_tensor_245 = None + permute_297 = torch.ops.aten.permute.default(wait_tensor_245, [1, 0]); wait_tensor_245 = None + permute_515 = torch.ops.aten.permute.default(permute_297, [1, 0]); permute_297 = None + mm_296 = 
torch.ops.aten.mm.default(view_1213, permute_515); view_1213 = permute_515 = None + view_1214 = torch.ops.aten.view.default(mm_296, [2, 8192, 4096]); mm_296 = None + add_162 = torch.ops.aten.add.Tensor(add_161, view_1214); add_161 = view_1214 = None + convert_element_type_1337 = torch.ops.prims.convert_element_type.default(mm_295, torch.float32); mm_295 = None + reduce_scatter_tensor_45 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1337, 'avg', 256, '0'); convert_element_type_1337 = None + wait_tensor_336 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_45); reduce_scatter_tensor_45 = None + convert_element_type_1338 = torch.ops.prims.convert_element_type.default(add_162, torch.float32); add_162 = None + convert_element_type_1340 = torch.ops.prims.convert_element_type.default(wait_tensor_244, torch.float32); wait_tensor_244 = None + mul_358 = torch.ops.aten.mul.Tensor(convert_element_type_1338, convert_element_type_1340); convert_element_type_1340 = None + mul_360 = torch.ops.aten.mul.Tensor(mul_216, mul_358) + sum_31 = torch.ops.aten.sum.dim_IntList(mul_360, [2], True); mul_360 = None + div_10 = torch.ops.aten.div.Tensor(mul_216, 4096) + mul_361 = torch.ops.aten.mul.Tensor(div_10, sum_31); div_10 = sum_31 = None + sub_15 = torch.ops.aten.sub.Tensor(mul_358, mul_361); mul_358 = mul_361 = None + mul_362 = torch.ops.aten.mul.Tensor(sub_15, rsqrt_54); sub_15 = rsqrt_54 = None + mul_363 = torch.ops.aten.mul.Tensor(convert_element_type_1338, mul_216); convert_element_type_1338 = mul_216 = None + sum_32 = torch.ops.aten.sum.dim_IntList(mul_363, [0, 1]); mul_363 = None + convert_element_type_1341 = torch.ops.prims.convert_element_type.default(mul_362, torch.bfloat16); mul_362 = None + add_163 = torch.ops.aten.add.Tensor(add_160, convert_element_type_1341); add_160 = convert_element_type_1341 = None + convert_element_type_default_55 = torch.ops.prims.convert_element_type.default(sum_32, torch.float32); sum_32 = None + 
reduce_scatter_tensor_46 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_55, 'avg', 256, '0'); convert_element_type_default_55 = None + wait_tensor_337 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_46); reduce_scatter_tensor_46 = None + view_1215 = torch.ops.aten.view.default(add_163, [16384, 4096]) + permute_517 = torch.ops.aten.permute.default(view_1215, [1, 0]) + permute_292 = torch.ops.aten.permute.default(getitem_234, [0, 2, 1, 3]) + view_905 = torch.ops.aten.view.default(permute_292, [2, 8192, -1]); permute_292 = None + convert_element_type_875 = torch.ops.prims.convert_element_type.default(primals_242, torch.bfloat16); primals_242 = None + all_gather_into_tensor_239 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_875, 256, '0'); convert_element_type_875 = None + wait_tensor_239 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_239); all_gather_into_tensor_239 = None + permute_293 = torch.ops.aten.permute.default(wait_tensor_239, [1, 0]); wait_tensor_239 = None + view_907 = torch.ops.aten.view.default(view_905, [16384, 4096]); view_905 = None + mm_185 = torch.ops.aten.mm.default(view_907, permute_293) + view_908 = torch.ops.aten.view.default(mm_185, [2, 8192, 4096]); mm_185 = None + add_105 = torch.ops.aten.add.Tensor(add_103, view_908); view_908 = None + convert_element_type_878 = torch.ops.prims.convert_element_type.default(primals_243, torch.bfloat16); primals_243 = None + all_gather_into_tensor_240 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_878, 256, '0'); convert_element_type_878 = None + wait_tensor_240 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_240); all_gather_into_tensor_240 = None + convert_element_type_879 = torch.ops.prims.convert_element_type.default(add_105, torch.float32); add_105 = None + pow_54 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_879, 2) + 
mean_53 = torch.ops.aten.mean.dim(pow_54, [2], True); pow_54 = None + add_106 = torch.ops.aten.add.Scalar(mean_53, 1e-05); mean_53 = None + rsqrt_53 = torch.ops.aten.rsqrt.default(add_106); add_106 = None + mul_212 = torch.ops.aten.mul.Tensor(convert_element_type_879, rsqrt_53); convert_element_type_879 = None + mul_213 = torch.ops.aten.mul.Tensor(mul_212, wait_tensor_240) + convert_element_type_880 = torch.ops.prims.convert_element_type.default(mul_213, torch.bfloat16); mul_213 = None + view_911 = torch.ops.aten.view.default(convert_element_type_880, [16384, 4096]); convert_element_type_880 = None + view_912 = torch.ops.aten.view.default(mm_186, [2, 8192, 14336]); mm_186 = None + convert_element_type_884 = torch.ops.prims.convert_element_type.default(view_912, torch.float32); view_912 = None + sigmoid_26 = torch.ops.aten.sigmoid.default(convert_element_type_884) + mul_214 = torch.ops.aten.mul.Tensor(convert_element_type_884, sigmoid_26); sigmoid_26 = None + convert_element_type_885 = torch.ops.prims.convert_element_type.default(mul_214, torch.bfloat16); mul_214 = None + convert_element_type_886 = torch.ops.prims.convert_element_type.default(primals_245, torch.bfloat16); primals_245 = None + all_gather_into_tensor_242 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_886, 256, '0'); convert_element_type_886 = None + wait_tensor_242 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_242); all_gather_into_tensor_242 = None + permute_295 = torch.ops.aten.permute.default(wait_tensor_242, [1, 0]); wait_tensor_242 = None + mm_187 = torch.ops.aten.mm.default(view_911, permute_295) + view_915 = torch.ops.aten.view.default(mm_187, [2, 8192, 14336]); mm_187 = None + mul_215 = torch.ops.aten.mul.Tensor(convert_element_type_885, view_915) + view_917 = torch.ops.aten.view.default(mul_215, [16384, 14336]); mul_215 = None + mm_297 = torch.ops.aten.mm.default(permute_517, view_917); permute_517 = view_917 = None + 
convert_element_type_889 = torch.ops.prims.convert_element_type.default(primals_246, torch.bfloat16); primals_246 = None + all_gather_into_tensor_243 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_889, 256, '0'); convert_element_type_889 = None + wait_tensor_243 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_243); all_gather_into_tensor_243 = None + permute_296 = torch.ops.aten.permute.default(wait_tensor_243, [1, 0]); wait_tensor_243 = None + permute_519 = torch.ops.aten.permute.default(permute_296, [1, 0]); permute_296 = None + mm_298 = torch.ops.aten.mm.default(view_1215, permute_519); view_1215 = permute_519 = None + view_1216 = torch.ops.aten.view.default(mm_298, [2, 8192, 14336]); mm_298 = None + convert_element_type_1348 = torch.ops.prims.convert_element_type.default(mm_297, torch.float32); mm_297 = None + reduce_scatter_tensor_47 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1348, 'avg', 256, '0'); convert_element_type_1348 = None + wait_tensor_338 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_47); reduce_scatter_tensor_47 = None + mul_364 = torch.ops.aten.mul.Tensor(view_1216, convert_element_type_885); convert_element_type_885 = None + mul_365 = torch.ops.aten.mul.Tensor(view_1216, view_915); view_1216 = view_915 = None + view_1217 = torch.ops.aten.view.default(mul_364, [16384, 14336]); mul_364 = None + permute_521 = torch.ops.aten.permute.default(view_1217, [1, 0]) + mm_299 = torch.ops.aten.mm.default(permute_521, view_911); permute_521 = None + permute_523 = torch.ops.aten.permute.default(permute_295, [1, 0]); permute_295 = None + mm_300 = torch.ops.aten.mm.default(view_1217, permute_523); view_1217 = permute_523 = None + view_1218 = torch.ops.aten.view.default(mm_300, [2, 8192, 4096]); mm_300 = None + convert_element_type_1353 = torch.ops.prims.convert_element_type.default(mm_299, torch.float32); mm_299 = None + reduce_scatter_tensor_48 
= torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1353, 'avg', 256, '0'); convert_element_type_1353 = None + wait_tensor_339 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_48); reduce_scatter_tensor_48 = None + convert_element_type_1354 = torch.ops.prims.convert_element_type.default(mul_365, torch.float32); mul_365 = None + neg_5 = torch.ops.aten.neg.default(convert_element_type_884) + exp_5 = torch.ops.aten.exp.default(neg_5); neg_5 = None + add_164 = torch.ops.aten.add.Tensor(exp_5, 1); exp_5 = None + reciprocal_5 = torch.ops.aten.reciprocal.default(add_164); add_164 = None + mul_366 = torch.ops.aten.mul.Tensor(reciprocal_5, 1); reciprocal_5 = None + mul_367 = torch.ops.aten.mul.Tensor(convert_element_type_1354, mul_366); convert_element_type_1354 = None + sub_16 = torch.ops.aten.sub.Tensor(1, mul_366); mul_366 = None + mul_368 = torch.ops.aten.mul.Tensor(convert_element_type_884, sub_16); convert_element_type_884 = sub_16 = None + add_165 = torch.ops.aten.add.Tensor(mul_368, 1); mul_368 = None + mul_369 = torch.ops.aten.mul.Tensor(mul_367, add_165); mul_367 = add_165 = None + convert_element_type_1356 = torch.ops.prims.convert_element_type.default(mul_369, torch.bfloat16); mul_369 = None + view_1219 = torch.ops.aten.view.default(convert_element_type_1356, [16384, 14336]); convert_element_type_1356 = None + permute_525 = torch.ops.aten.permute.default(view_1219, [1, 0]) + mm_301 = torch.ops.aten.mm.default(permute_525, view_911); permute_525 = view_911 = None + convert_element_type_881 = torch.ops.prims.convert_element_type.default(primals_244, torch.bfloat16); primals_244 = None + all_gather_into_tensor_241 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_881, 256, '0'); convert_element_type_881 = None + wait_tensor_241 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_241); all_gather_into_tensor_241 = None + permute_294 = 
torch.ops.aten.permute.default(wait_tensor_241, [1, 0]); wait_tensor_241 = None + permute_527 = torch.ops.aten.permute.default(permute_294, [1, 0]); permute_294 = None + mm_302 = torch.ops.aten.mm.default(view_1219, permute_527); view_1219 = permute_527 = None + view_1220 = torch.ops.aten.view.default(mm_302, [2, 8192, 4096]); mm_302 = None + add_166 = torch.ops.aten.add.Tensor(view_1218, view_1220); view_1218 = view_1220 = None + convert_element_type_1361 = torch.ops.prims.convert_element_type.default(mm_301, torch.float32); mm_301 = None + reduce_scatter_tensor_49 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1361, 'avg', 256, '0'); convert_element_type_1361 = None + wait_tensor_340 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_49); reduce_scatter_tensor_49 = None + convert_element_type_1362 = torch.ops.prims.convert_element_type.default(add_166, torch.float32); add_166 = None + convert_element_type_1364 = torch.ops.prims.convert_element_type.default(wait_tensor_240, torch.float32); wait_tensor_240 = None + mul_370 = torch.ops.aten.mul.Tensor(convert_element_type_1362, convert_element_type_1364); convert_element_type_1364 = None + mul_372 = torch.ops.aten.mul.Tensor(mul_212, mul_370) + sum_33 = torch.ops.aten.sum.dim_IntList(mul_372, [2], True); mul_372 = None + div_11 = torch.ops.aten.div.Tensor(mul_212, 4096) + mul_373 = torch.ops.aten.mul.Tensor(div_11, sum_33); div_11 = sum_33 = None + sub_17 = torch.ops.aten.sub.Tensor(mul_370, mul_373); mul_370 = mul_373 = None + mul_374 = torch.ops.aten.mul.Tensor(sub_17, rsqrt_53); sub_17 = rsqrt_53 = None + mul_375 = torch.ops.aten.mul.Tensor(convert_element_type_1362, mul_212); convert_element_type_1362 = mul_212 = None + sum_34 = torch.ops.aten.sum.dim_IntList(mul_375, [0, 1]); mul_375 = None + convert_element_type_1365 = torch.ops.prims.convert_element_type.default(mul_374, torch.bfloat16); mul_374 = None + add_167 = torch.ops.aten.add.Tensor(add_163, 
convert_element_type_1365); add_163 = convert_element_type_1365 = None + convert_element_type_default_54 = torch.ops.prims.convert_element_type.default(sum_34, torch.float32); sum_34 = None + reduce_scatter_tensor_50 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_54, 'avg', 256, '0'); convert_element_type_default_54 = None + wait_tensor_341 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_50); reduce_scatter_tensor_50 = None + view_1221 = torch.ops.aten.view.default(add_167, [16384, 4096]) + permute_529 = torch.ops.aten.permute.default(view_1221, [1, 0]) + mm_303 = torch.ops.aten.mm.default(permute_529, view_907); permute_529 = view_907 = None + permute_531 = torch.ops.aten.permute.default(permute_293, [1, 0]); permute_293 = None + mm_304 = torch.ops.aten.mm.default(view_1221, permute_531); view_1221 = permute_531 = None + view_1222 = torch.ops.aten.view.default(mm_304, [2, 8192, 4096]); mm_304 = None + convert_element_type_1372 = torch.ops.prims.convert_element_type.default(mm_303, torch.float32); mm_303 = None + reduce_scatter_tensor_51 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1372, 'avg', 256, '0'); convert_element_type_1372 = None + wait_tensor_342 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_51); reduce_scatter_tensor_51 = None + view_1223 = torch.ops.aten.view.default(view_1222, [2, 8192, 32, 128]); view_1222 = None + permute_533 = torch.ops.aten.permute.default(view_1223, [0, 2, 1, 3]); view_1223 = None + convert_element_type_859 = torch.ops.prims.convert_element_type.default(primals_238, torch.bfloat16); primals_238 = None + all_gather_into_tensor_235 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_859, 256, '0'); convert_element_type_859 = None + wait_tensor_235 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_235); all_gather_into_tensor_235 = None + convert_element_type_860 
= torch.ops.prims.convert_element_type.default(add_103, torch.float32); add_103 = None + pow_53 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_860, 2) + mean_52 = torch.ops.aten.mean.dim(pow_53, [2], True); pow_53 = None + add_104 = torch.ops.aten.add.Scalar(mean_52, 1e-05); mean_52 = None + rsqrt_52 = torch.ops.aten.rsqrt.default(add_104); add_104 = None + mul_208 = torch.ops.aten.mul.Tensor(convert_element_type_860, rsqrt_52); convert_element_type_860 = None + mul_209 = torch.ops.aten.mul.Tensor(mul_208, wait_tensor_235) + convert_element_type_861 = torch.ops.prims.convert_element_type.default(mul_209, torch.bfloat16); mul_209 = None + view_887 = torch.ops.aten.view.default(convert_element_type_861, [16384, 4096]); convert_element_type_861 = None + view_888 = torch.ops.aten.view.default(mm_182, [2, 8192, 4096]); mm_182 = None + convert_element_type_865 = torch.ops.prims.convert_element_type.default(primals_240, torch.bfloat16); primals_240 = None + all_gather_into_tensor_237 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_865, 256, '0'); convert_element_type_865 = None + wait_tensor_237 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_237); all_gather_into_tensor_237 = None + permute_287 = torch.ops.aten.permute.default(wait_tensor_237, [1, 0]); wait_tensor_237 = None + mm_183 = torch.ops.aten.mm.default(view_887, permute_287) + view_891 = torch.ops.aten.view.default(mm_183, [2, 8192, 1024]); mm_183 = None + view_894 = torch.ops.aten.view.default(mm_184, [2, 8192, 1024]); mm_184 = None + view_895 = torch.ops.aten.view.default(view_888, [2, 8192, -1, 128]); view_888 = None + view_896 = torch.ops.aten.view.default(view_891, [2, 8192, -1, 128]); view_891 = None + view_897 = torch.ops.aten.view.default(view_894, [2, 8192, -1, 128]); view_894 = None + convert_element_type_871 = torch.ops.prims.convert_element_type.default(view_895, torch.float32); view_895 = None + view_898 = 
torch.ops.aten.view.default(convert_element_type_871, [2, 8192, 32, -1, 2]); convert_element_type_871 = None + view_as_complex_52 = torch.ops.aten.view_as_complex.default(view_898); view_898 = None + convert_element_type_872 = torch.ops.prims.convert_element_type.default(view_896, torch.float32); view_896 = None + view_899 = torch.ops.aten.view.default(convert_element_type_872, [2, 8192, 8, -1, 2]); convert_element_type_872 = None + view_as_complex_53 = torch.ops.aten.view_as_complex.default(view_899); view_899 = None + mul_210 = torch.ops.aten.mul.Tensor(view_as_complex_52, view_16); view_as_complex_52 = None + view_as_real_52 = torch.ops.aten.view_as_real.default(mul_210); mul_210 = None + view_901 = torch.ops.aten.view.default(view_as_real_52, [2, 8192, 32, 128]); view_as_real_52 = None + mul_211 = torch.ops.aten.mul.Tensor(view_as_complex_53, view_16); view_as_complex_53 = None + view_as_real_53 = torch.ops.aten.view_as_real.default(mul_211); mul_211 = None + view_902 = torch.ops.aten.view.default(view_as_real_53, [2, 8192, 8, 128]); view_as_real_53 = None + convert_element_type_873 = torch.ops.prims.convert_element_type.default(view_901, torch.bfloat16); view_901 = None + convert_element_type_874 = torch.ops.prims.convert_element_type.default(view_902, torch.bfloat16); view_902 = None + unsqueeze_52 = torch.ops.aten.unsqueeze.default(convert_element_type_874, 3); convert_element_type_874 = None + expand_52 = torch.ops.aten.expand.default(unsqueeze_52, [2, 8192, 8, 4, 128]); unsqueeze_52 = None + clone_52 = torch.ops.aten.clone.default(expand_52, memory_format = torch.contiguous_format); expand_52 = None + view_903 = torch.ops.aten.view.default(clone_52, [2, 8192, 32, 128]); clone_52 = None + unsqueeze_53 = torch.ops.aten.unsqueeze.default(view_897, 3); view_897 = None + expand_53 = torch.ops.aten.expand.default(unsqueeze_53, [2, 8192, 8, 4, 128]); unsqueeze_53 = None + clone_53 = torch.ops.aten.clone.default(expand_53, memory_format = torch.contiguous_format); 
expand_53 = None + view_904 = torch.ops.aten.view.default(clone_53, [2, 8192, 32, 128]); clone_53 = None + permute_289 = torch.ops.aten.permute.default(convert_element_type_873, [0, 2, 1, 3]); convert_element_type_873 = None + permute_290 = torch.ops.aten.permute.default(view_903, [0, 2, 1, 3]); view_903 = None + permute_291 = torch.ops.aten.permute.default(view_904, [0, 2, 1, 3]); view_904 = None + _scaled_dot_product_cudnn_attention_backward_5 = torch.ops.aten._scaled_dot_product_cudnn_attention_backward.default(permute_533, permute_289, permute_290, permute_291, getitem_234, getitem_235, getitem_240, getitem_241, None, None, None, 8192, 8192, 0.0, True); permute_533 = permute_289 = permute_290 = permute_291 = getitem_234 = getitem_235 = getitem_240 = getitem_241 = None + getitem_303 = _scaled_dot_product_cudnn_attention_backward_5[0] + getitem_304 = _scaled_dot_product_cudnn_attention_backward_5[1] + getitem_305 = _scaled_dot_product_cudnn_attention_backward_5[2]; _scaled_dot_product_cudnn_attention_backward_5 = None + permute_534 = torch.ops.aten.permute.default(getitem_305, [0, 2, 1, 3]); getitem_305 = None + permute_535 = torch.ops.aten.permute.default(getitem_304, [0, 2, 1, 3]); getitem_304 = None + permute_536 = torch.ops.aten.permute.default(getitem_303, [0, 2, 1, 3]); getitem_303 = None + view_1224 = torch.ops.aten.view.default(permute_534, [2, 8192, 8, 4, 128]); permute_534 = None + sum_35 = torch.ops.aten.sum.dim_IntList(view_1224, [3], True); view_1224 = None + squeeze_10 = torch.ops.aten.squeeze.dim(sum_35, 3); sum_35 = None + view_1225 = torch.ops.aten.view.default(permute_535, [2, 8192, 8, 4, 128]); permute_535 = None + sum_36 = torch.ops.aten.sum.dim_IntList(view_1225, [3], True); view_1225 = None + squeeze_11 = torch.ops.aten.squeeze.dim(sum_36, 3); sum_36 = None + convert_element_type_1373 = torch.ops.prims.convert_element_type.default(squeeze_11, torch.float32); squeeze_11 = None + convert_element_type_1374 = 
torch.ops.prims.convert_element_type.default(permute_536, torch.float32); permute_536 = None + view_1226 = torch.ops.aten.view.default(convert_element_type_1373, [2, 8192, 8, 64, 2]); convert_element_type_1373 = None + view_as_complex_74 = torch.ops.aten.view_as_complex.default(view_1226); view_1226 = None + mul_376 = torch.ops.aten.mul.Tensor(view_as_complex_74, _conj); view_as_complex_74 = None + view_1227 = torch.ops.aten.view.default(convert_element_type_1374, [2, 8192, 32, 64, 2]); convert_element_type_1374 = None + view_as_complex_75 = torch.ops.aten.view_as_complex.default(view_1227); view_1227 = None + mul_377 = torch.ops.aten.mul.Tensor(view_as_complex_75, _conj); view_as_complex_75 = None + view_as_real_74 = torch.ops.aten.view_as_real.default(mul_376); mul_376 = None + view_1228 = torch.ops.aten.view.default(view_as_real_74, [2, 8192, 8, 128]); view_as_real_74 = None + convert_element_type_1375 = torch.ops.prims.convert_element_type.default(view_1228, torch.bfloat16); view_1228 = None + view_as_real_75 = torch.ops.aten.view_as_real.default(mul_377); mul_377 = None + view_1229 = torch.ops.aten.view.default(view_as_real_75, [2, 8192, 32, 128]); view_as_real_75 = None + convert_element_type_1376 = torch.ops.prims.convert_element_type.default(view_1229, torch.bfloat16); view_1229 = None + view_1230 = torch.ops.aten.view.default(squeeze_10, [2, 8192, 1024]); squeeze_10 = None + view_1231 = torch.ops.aten.view.default(convert_element_type_1375, [2, 8192, 1024]); convert_element_type_1375 = None + view_1232 = torch.ops.aten.view.default(convert_element_type_1376, [2, 8192, 4096]); convert_element_type_1376 = None + view_1233 = torch.ops.aten.view.default(view_1230, [16384, 1024]); view_1230 = None + permute_537 = torch.ops.aten.permute.default(view_1233, [1, 0]) + mm_305 = torch.ops.aten.mm.default(permute_537, view_887); permute_537 = None + convert_element_type_868 = torch.ops.prims.convert_element_type.default(primals_241, torch.bfloat16); primals_241 = None 
+ all_gather_into_tensor_238 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_868, 256, '0'); convert_element_type_868 = None + wait_tensor_238 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_238); all_gather_into_tensor_238 = None + permute_288 = torch.ops.aten.permute.default(wait_tensor_238, [1, 0]); wait_tensor_238 = None + permute_539 = torch.ops.aten.permute.default(permute_288, [1, 0]); permute_288 = None + mm_306 = torch.ops.aten.mm.default(view_1233, permute_539); view_1233 = permute_539 = None + view_1234 = torch.ops.aten.view.default(mm_306, [2, 8192, 4096]); mm_306 = None + convert_element_type_1381 = torch.ops.prims.convert_element_type.default(mm_305, torch.float32); mm_305 = None + reduce_scatter_tensor_52 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1381, 'avg', 256, '0'); convert_element_type_1381 = None + wait_tensor_343 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_52); reduce_scatter_tensor_52 = None + view_1235 = torch.ops.aten.view.default(view_1231, [16384, 1024]); view_1231 = None + permute_541 = torch.ops.aten.permute.default(view_1235, [1, 0]) + mm_307 = torch.ops.aten.mm.default(permute_541, view_887); permute_541 = None + permute_543 = torch.ops.aten.permute.default(permute_287, [1, 0]); permute_287 = None + mm_308 = torch.ops.aten.mm.default(view_1235, permute_543); view_1235 = permute_543 = None + view_1236 = torch.ops.aten.view.default(mm_308, [2, 8192, 4096]); mm_308 = None + add_168 = torch.ops.aten.add.Tensor(view_1234, view_1236); view_1234 = view_1236 = None + convert_element_type_1386 = torch.ops.prims.convert_element_type.default(mm_307, torch.float32); mm_307 = None + reduce_scatter_tensor_53 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1386, 'avg', 256, '0'); convert_element_type_1386 = None + wait_tensor_344 = 
torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_53); reduce_scatter_tensor_53 = None + view_1237 = torch.ops.aten.view.default(view_1232, [16384, 4096]); view_1232 = None + permute_545 = torch.ops.aten.permute.default(view_1237, [1, 0]) + mm_309 = torch.ops.aten.mm.default(permute_545, view_887); permute_545 = view_887 = None + convert_element_type_862 = torch.ops.prims.convert_element_type.default(primals_239, torch.bfloat16); primals_239 = None + all_gather_into_tensor_236 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_862, 256, '0'); convert_element_type_862 = None + wait_tensor_236 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_236); all_gather_into_tensor_236 = None + permute_286 = torch.ops.aten.permute.default(wait_tensor_236, [1, 0]); wait_tensor_236 = None + permute_547 = torch.ops.aten.permute.default(permute_286, [1, 0]); permute_286 = None + mm_310 = torch.ops.aten.mm.default(view_1237, permute_547); view_1237 = permute_547 = None + view_1238 = torch.ops.aten.view.default(mm_310, [2, 8192, 4096]); mm_310 = None + add_169 = torch.ops.aten.add.Tensor(add_168, view_1238); add_168 = view_1238 = None + convert_element_type_1391 = torch.ops.prims.convert_element_type.default(mm_309, torch.float32); mm_309 = None + reduce_scatter_tensor_54 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1391, 'avg', 256, '0'); convert_element_type_1391 = None + wait_tensor_345 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_54); reduce_scatter_tensor_54 = None + convert_element_type_1392 = torch.ops.prims.convert_element_type.default(add_169, torch.float32); add_169 = None + convert_element_type_1394 = torch.ops.prims.convert_element_type.default(wait_tensor_235, torch.float32); wait_tensor_235 = None + mul_378 = torch.ops.aten.mul.Tensor(convert_element_type_1392, convert_element_type_1394); convert_element_type_1394 = None + mul_380 = 
torch.ops.aten.mul.Tensor(mul_208, mul_378) + sum_37 = torch.ops.aten.sum.dim_IntList(mul_380, [2], True); mul_380 = None + div_12 = torch.ops.aten.div.Tensor(mul_208, 4096) + mul_381 = torch.ops.aten.mul.Tensor(div_12, sum_37); div_12 = sum_37 = None + sub_18 = torch.ops.aten.sub.Tensor(mul_378, mul_381); mul_378 = mul_381 = None + mul_382 = torch.ops.aten.mul.Tensor(sub_18, rsqrt_52); sub_18 = rsqrt_52 = None + mul_383 = torch.ops.aten.mul.Tensor(convert_element_type_1392, mul_208); convert_element_type_1392 = mul_208 = None + sum_38 = torch.ops.aten.sum.dim_IntList(mul_383, [0, 1]); mul_383 = None + convert_element_type_1395 = torch.ops.prims.convert_element_type.default(mul_382, torch.bfloat16); mul_382 = None + add_170 = torch.ops.aten.add.Tensor(add_167, convert_element_type_1395); add_167 = convert_element_type_1395 = None + convert_element_type_default_53 = torch.ops.prims.convert_element_type.default(sum_38, torch.float32); sum_38 = None + reduce_scatter_tensor_55 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_53, 'avg', 256, '0'); convert_element_type_default_53 = None + wait_tensor_346 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_55); reduce_scatter_tensor_55 = None + view_1239 = torch.ops.aten.view.default(add_170, [16384, 4096]) + permute_549 = torch.ops.aten.permute.default(view_1239, [1, 0]) + permute_281 = torch.ops.aten.permute.default(getitem_225, [0, 2, 1, 3]) + view_871 = torch.ops.aten.view.default(permute_281, [2, 8192, -1]); permute_281 = None + convert_element_type_842 = torch.ops.prims.convert_element_type.default(primals_233, torch.bfloat16); primals_233 = None + all_gather_into_tensor_230 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_842, 256, '0'); convert_element_type_842 = None + wait_tensor_230 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_230); all_gather_into_tensor_230 = None + permute_282 = 
torch.ops.aten.permute.default(wait_tensor_230, [1, 0]); wait_tensor_230 = None + view_873 = torch.ops.aten.view.default(view_871, [16384, 4096]); view_871 = None + mm_178 = torch.ops.aten.mm.default(view_873, permute_282) + view_874 = torch.ops.aten.view.default(mm_178, [2, 8192, 4096]); mm_178 = None + add_101 = torch.ops.aten.add.Tensor(add_99, view_874); view_874 = None + convert_element_type_845 = torch.ops.prims.convert_element_type.default(primals_234, torch.bfloat16); primals_234 = None + all_gather_into_tensor_231 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_845, 256, '0'); convert_element_type_845 = None + wait_tensor_231 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_231); all_gather_into_tensor_231 = None + convert_element_type_846 = torch.ops.prims.convert_element_type.default(add_101, torch.float32); add_101 = None + pow_52 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_846, 2) + mean_51 = torch.ops.aten.mean.dim(pow_52, [2], True); pow_52 = None + add_102 = torch.ops.aten.add.Scalar(mean_51, 1e-05); mean_51 = None + rsqrt_51 = torch.ops.aten.rsqrt.default(add_102); add_102 = None + mul_204 = torch.ops.aten.mul.Tensor(convert_element_type_846, rsqrt_51); convert_element_type_846 = None + mul_205 = torch.ops.aten.mul.Tensor(mul_204, wait_tensor_231) + convert_element_type_847 = torch.ops.prims.convert_element_type.default(mul_205, torch.bfloat16); mul_205 = None + view_877 = torch.ops.aten.view.default(convert_element_type_847, [16384, 4096]); convert_element_type_847 = None + view_878 = torch.ops.aten.view.default(mm_179, [2, 8192, 14336]); mm_179 = None + convert_element_type_851 = torch.ops.prims.convert_element_type.default(view_878, torch.float32); view_878 = None + sigmoid_25 = torch.ops.aten.sigmoid.default(convert_element_type_851) + mul_206 = torch.ops.aten.mul.Tensor(convert_element_type_851, sigmoid_25); sigmoid_25 = None + convert_element_type_852 = 
torch.ops.prims.convert_element_type.default(mul_206, torch.bfloat16); mul_206 = None + convert_element_type_853 = torch.ops.prims.convert_element_type.default(primals_236, torch.bfloat16); primals_236 = None + all_gather_into_tensor_233 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_853, 256, '0'); convert_element_type_853 = None + wait_tensor_233 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_233); all_gather_into_tensor_233 = None + permute_284 = torch.ops.aten.permute.default(wait_tensor_233, [1, 0]); wait_tensor_233 = None + mm_180 = torch.ops.aten.mm.default(view_877, permute_284) + view_881 = torch.ops.aten.view.default(mm_180, [2, 8192, 14336]); mm_180 = None + mul_207 = torch.ops.aten.mul.Tensor(convert_element_type_852, view_881) + view_883 = torch.ops.aten.view.default(mul_207, [16384, 14336]); mul_207 = None + mm_311 = torch.ops.aten.mm.default(permute_549, view_883); permute_549 = view_883 = None + convert_element_type_856 = torch.ops.prims.convert_element_type.default(primals_237, torch.bfloat16); primals_237 = None + all_gather_into_tensor_234 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_856, 256, '0'); convert_element_type_856 = None + wait_tensor_234 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_234); all_gather_into_tensor_234 = None + permute_285 = torch.ops.aten.permute.default(wait_tensor_234, [1, 0]); wait_tensor_234 = None + permute_551 = torch.ops.aten.permute.default(permute_285, [1, 0]); permute_285 = None + mm_312 = torch.ops.aten.mm.default(view_1239, permute_551); view_1239 = permute_551 = None + view_1240 = torch.ops.aten.view.default(mm_312, [2, 8192, 14336]); mm_312 = None + convert_element_type_1402 = torch.ops.prims.convert_element_type.default(mm_311, torch.float32); mm_311 = None + reduce_scatter_tensor_56 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1402, 'avg', 256, '0'); 
convert_element_type_1402 = None + wait_tensor_347 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_56); reduce_scatter_tensor_56 = None + mul_384 = torch.ops.aten.mul.Tensor(view_1240, convert_element_type_852); convert_element_type_852 = None + mul_385 = torch.ops.aten.mul.Tensor(view_1240, view_881); view_1240 = view_881 = None + view_1241 = torch.ops.aten.view.default(mul_384, [16384, 14336]); mul_384 = None + permute_553 = torch.ops.aten.permute.default(view_1241, [1, 0]) + mm_313 = torch.ops.aten.mm.default(permute_553, view_877); permute_553 = None + permute_555 = torch.ops.aten.permute.default(permute_284, [1, 0]); permute_284 = None + mm_314 = torch.ops.aten.mm.default(view_1241, permute_555); view_1241 = permute_555 = None + view_1242 = torch.ops.aten.view.default(mm_314, [2, 8192, 4096]); mm_314 = None + convert_element_type_1407 = torch.ops.prims.convert_element_type.default(mm_313, torch.float32); mm_313 = None + reduce_scatter_tensor_57 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1407, 'avg', 256, '0'); convert_element_type_1407 = None + wait_tensor_348 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_57); reduce_scatter_tensor_57 = None + convert_element_type_1408 = torch.ops.prims.convert_element_type.default(mul_385, torch.float32); mul_385 = None + neg_6 = torch.ops.aten.neg.default(convert_element_type_851) + exp_6 = torch.ops.aten.exp.default(neg_6); neg_6 = None + add_171 = torch.ops.aten.add.Tensor(exp_6, 1); exp_6 = None + reciprocal_6 = torch.ops.aten.reciprocal.default(add_171); add_171 = None + mul_386 = torch.ops.aten.mul.Tensor(reciprocal_6, 1); reciprocal_6 = None + mul_387 = torch.ops.aten.mul.Tensor(convert_element_type_1408, mul_386); convert_element_type_1408 = None + sub_19 = torch.ops.aten.sub.Tensor(1, mul_386); mul_386 = None + mul_388 = torch.ops.aten.mul.Tensor(convert_element_type_851, sub_19); convert_element_type_851 = sub_19 = None + add_172 = 
torch.ops.aten.add.Tensor(mul_388, 1); mul_388 = None + mul_389 = torch.ops.aten.mul.Tensor(mul_387, add_172); mul_387 = add_172 = None + convert_element_type_1410 = torch.ops.prims.convert_element_type.default(mul_389, torch.bfloat16); mul_389 = None + view_1243 = torch.ops.aten.view.default(convert_element_type_1410, [16384, 14336]); convert_element_type_1410 = None + permute_557 = torch.ops.aten.permute.default(view_1243, [1, 0]) + mm_315 = torch.ops.aten.mm.default(permute_557, view_877); permute_557 = view_877 = None + convert_element_type_848 = torch.ops.prims.convert_element_type.default(primals_235, torch.bfloat16); primals_235 = None + all_gather_into_tensor_232 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_848, 256, '0'); convert_element_type_848 = None + wait_tensor_232 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_232); all_gather_into_tensor_232 = None + permute_283 = torch.ops.aten.permute.default(wait_tensor_232, [1, 0]); wait_tensor_232 = None + permute_559 = torch.ops.aten.permute.default(permute_283, [1, 0]); permute_283 = None + mm_316 = torch.ops.aten.mm.default(view_1243, permute_559); view_1243 = permute_559 = None + view_1244 = torch.ops.aten.view.default(mm_316, [2, 8192, 4096]); mm_316 = None + add_173 = torch.ops.aten.add.Tensor(view_1242, view_1244); view_1242 = view_1244 = None + convert_element_type_1415 = torch.ops.prims.convert_element_type.default(mm_315, torch.float32); mm_315 = None + reduce_scatter_tensor_58 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1415, 'avg', 256, '0'); convert_element_type_1415 = None + wait_tensor_349 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_58); reduce_scatter_tensor_58 = None + convert_element_type_1416 = torch.ops.prims.convert_element_type.default(add_173, torch.float32); add_173 = None + convert_element_type_1418 = torch.ops.prims.convert_element_type.default(wait_tensor_231, 
torch.float32); wait_tensor_231 = None + mul_390 = torch.ops.aten.mul.Tensor(convert_element_type_1416, convert_element_type_1418); convert_element_type_1418 = None + mul_392 = torch.ops.aten.mul.Tensor(mul_204, mul_390) + sum_39 = torch.ops.aten.sum.dim_IntList(mul_392, [2], True); mul_392 = None + div_13 = torch.ops.aten.div.Tensor(mul_204, 4096) + mul_393 = torch.ops.aten.mul.Tensor(div_13, sum_39); div_13 = sum_39 = None + sub_20 = torch.ops.aten.sub.Tensor(mul_390, mul_393); mul_390 = mul_393 = None + mul_394 = torch.ops.aten.mul.Tensor(sub_20, rsqrt_51); sub_20 = rsqrt_51 = None + mul_395 = torch.ops.aten.mul.Tensor(convert_element_type_1416, mul_204); convert_element_type_1416 = mul_204 = None + sum_40 = torch.ops.aten.sum.dim_IntList(mul_395, [0, 1]); mul_395 = None + convert_element_type_1419 = torch.ops.prims.convert_element_type.default(mul_394, torch.bfloat16); mul_394 = None + add_174 = torch.ops.aten.add.Tensor(add_170, convert_element_type_1419); add_170 = convert_element_type_1419 = None + convert_element_type_default_52 = torch.ops.prims.convert_element_type.default(sum_40, torch.float32); sum_40 = None + reduce_scatter_tensor_59 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_52, 'avg', 256, '0'); convert_element_type_default_52 = None + wait_tensor_350 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_59); reduce_scatter_tensor_59 = None + view_1245 = torch.ops.aten.view.default(add_174, [16384, 4096]) + permute_561 = torch.ops.aten.permute.default(view_1245, [1, 0]) + mm_317 = torch.ops.aten.mm.default(permute_561, view_873); permute_561 = view_873 = None + permute_563 = torch.ops.aten.permute.default(permute_282, [1, 0]); permute_282 = None + mm_318 = torch.ops.aten.mm.default(view_1245, permute_563); view_1245 = permute_563 = None + view_1246 = torch.ops.aten.view.default(mm_318, [2, 8192, 4096]); mm_318 = None + convert_element_type_1426 = 
torch.ops.prims.convert_element_type.default(mm_317, torch.float32); mm_317 = None + reduce_scatter_tensor_60 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1426, 'avg', 256, '0'); convert_element_type_1426 = None + wait_tensor_351 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_60); reduce_scatter_tensor_60 = None + view_1247 = torch.ops.aten.view.default(view_1246, [2, 8192, 32, 128]); view_1246 = None + permute_565 = torch.ops.aten.permute.default(view_1247, [0, 2, 1, 3]); view_1247 = None + convert_element_type_826 = torch.ops.prims.convert_element_type.default(primals_229, torch.bfloat16); primals_229 = None + all_gather_into_tensor_226 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_826, 256, '0'); convert_element_type_826 = None + wait_tensor_226 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_226); all_gather_into_tensor_226 = None + convert_element_type_827 = torch.ops.prims.convert_element_type.default(add_99, torch.float32); add_99 = None + pow_51 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_827, 2) + mean_50 = torch.ops.aten.mean.dim(pow_51, [2], True); pow_51 = None + add_100 = torch.ops.aten.add.Scalar(mean_50, 1e-05); mean_50 = None + rsqrt_50 = torch.ops.aten.rsqrt.default(add_100); add_100 = None + mul_200 = torch.ops.aten.mul.Tensor(convert_element_type_827, rsqrt_50); convert_element_type_827 = None + mul_201 = torch.ops.aten.mul.Tensor(mul_200, wait_tensor_226) + convert_element_type_828 = torch.ops.prims.convert_element_type.default(mul_201, torch.bfloat16); mul_201 = None + view_853 = torch.ops.aten.view.default(convert_element_type_828, [16384, 4096]); convert_element_type_828 = None + view_854 = torch.ops.aten.view.default(mm_175, [2, 8192, 4096]); mm_175 = None + convert_element_type_832 = torch.ops.prims.convert_element_type.default(primals_231, torch.bfloat16); primals_231 = None + all_gather_into_tensor_228 = 
torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_832, 256, '0'); convert_element_type_832 = None + wait_tensor_228 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_228); all_gather_into_tensor_228 = None + permute_276 = torch.ops.aten.permute.default(wait_tensor_228, [1, 0]); wait_tensor_228 = None + mm_176 = torch.ops.aten.mm.default(view_853, permute_276) + view_857 = torch.ops.aten.view.default(mm_176, [2, 8192, 1024]); mm_176 = None + view_860 = torch.ops.aten.view.default(mm_177, [2, 8192, 1024]); mm_177 = None + view_861 = torch.ops.aten.view.default(view_854, [2, 8192, -1, 128]); view_854 = None + view_862 = torch.ops.aten.view.default(view_857, [2, 8192, -1, 128]); view_857 = None + view_863 = torch.ops.aten.view.default(view_860, [2, 8192, -1, 128]); view_860 = None + convert_element_type_838 = torch.ops.prims.convert_element_type.default(view_861, torch.float32); view_861 = None + view_864 = torch.ops.aten.view.default(convert_element_type_838, [2, 8192, 32, -1, 2]); convert_element_type_838 = None + view_as_complex_50 = torch.ops.aten.view_as_complex.default(view_864); view_864 = None + convert_element_type_839 = torch.ops.prims.convert_element_type.default(view_862, torch.float32); view_862 = None + view_865 = torch.ops.aten.view.default(convert_element_type_839, [2, 8192, 8, -1, 2]); convert_element_type_839 = None + view_as_complex_51 = torch.ops.aten.view_as_complex.default(view_865); view_865 = None + mul_202 = torch.ops.aten.mul.Tensor(view_as_complex_50, view_16); view_as_complex_50 = None + view_as_real_50 = torch.ops.aten.view_as_real.default(mul_202); mul_202 = None + view_867 = torch.ops.aten.view.default(view_as_real_50, [2, 8192, 32, 128]); view_as_real_50 = None + mul_203 = torch.ops.aten.mul.Tensor(view_as_complex_51, view_16); view_as_complex_51 = None + view_as_real_51 = torch.ops.aten.view_as_real.default(mul_203); mul_203 = None + view_868 = 
torch.ops.aten.view.default(view_as_real_51, [2, 8192, 8, 128]); view_as_real_51 = None + convert_element_type_840 = torch.ops.prims.convert_element_type.default(view_867, torch.bfloat16); view_867 = None + convert_element_type_841 = torch.ops.prims.convert_element_type.default(view_868, torch.bfloat16); view_868 = None + unsqueeze_50 = torch.ops.aten.unsqueeze.default(convert_element_type_841, 3); convert_element_type_841 = None + expand_50 = torch.ops.aten.expand.default(unsqueeze_50, [2, 8192, 8, 4, 128]); unsqueeze_50 = None + clone_50 = torch.ops.aten.clone.default(expand_50, memory_format = torch.contiguous_format); expand_50 = None + view_869 = torch.ops.aten.view.default(clone_50, [2, 8192, 32, 128]); clone_50 = None + unsqueeze_51 = torch.ops.aten.unsqueeze.default(view_863, 3); view_863 = None + expand_51 = torch.ops.aten.expand.default(unsqueeze_51, [2, 8192, 8, 4, 128]); unsqueeze_51 = None + clone_51 = torch.ops.aten.clone.default(expand_51, memory_format = torch.contiguous_format); expand_51 = None + view_870 = torch.ops.aten.view.default(clone_51, [2, 8192, 32, 128]); clone_51 = None + permute_278 = torch.ops.aten.permute.default(convert_element_type_840, [0, 2, 1, 3]); convert_element_type_840 = None + permute_279 = torch.ops.aten.permute.default(view_869, [0, 2, 1, 3]); view_869 = None + permute_280 = torch.ops.aten.permute.default(view_870, [0, 2, 1, 3]); view_870 = None + _scaled_dot_product_cudnn_attention_backward_6 = torch.ops.aten._scaled_dot_product_cudnn_attention_backward.default(permute_565, permute_278, permute_279, permute_280, getitem_225, getitem_226, getitem_231, getitem_232, None, None, None, 8192, 8192, 0.0, True); permute_565 = permute_278 = permute_279 = permute_280 = getitem_225 = getitem_226 = getitem_231 = getitem_232 = None + getitem_306 = _scaled_dot_product_cudnn_attention_backward_6[0] + getitem_307 = _scaled_dot_product_cudnn_attention_backward_6[1] + getitem_308 = _scaled_dot_product_cudnn_attention_backward_6[2]; 
_scaled_dot_product_cudnn_attention_backward_6 = None + permute_566 = torch.ops.aten.permute.default(getitem_308, [0, 2, 1, 3]); getitem_308 = None + permute_567 = torch.ops.aten.permute.default(getitem_307, [0, 2, 1, 3]); getitem_307 = None + permute_568 = torch.ops.aten.permute.default(getitem_306, [0, 2, 1, 3]); getitem_306 = None + view_1248 = torch.ops.aten.view.default(permute_566, [2, 8192, 8, 4, 128]); permute_566 = None + sum_41 = torch.ops.aten.sum.dim_IntList(view_1248, [3], True); view_1248 = None + squeeze_12 = torch.ops.aten.squeeze.dim(sum_41, 3); sum_41 = None + view_1249 = torch.ops.aten.view.default(permute_567, [2, 8192, 8, 4, 128]); permute_567 = None + sum_42 = torch.ops.aten.sum.dim_IntList(view_1249, [3], True); view_1249 = None + squeeze_13 = torch.ops.aten.squeeze.dim(sum_42, 3); sum_42 = None + convert_element_type_1427 = torch.ops.prims.convert_element_type.default(squeeze_13, torch.float32); squeeze_13 = None + convert_element_type_1428 = torch.ops.prims.convert_element_type.default(permute_568, torch.float32); permute_568 = None + view_1250 = torch.ops.aten.view.default(convert_element_type_1427, [2, 8192, 8, 64, 2]); convert_element_type_1427 = None + view_as_complex_76 = torch.ops.aten.view_as_complex.default(view_1250); view_1250 = None + mul_396 = torch.ops.aten.mul.Tensor(view_as_complex_76, _conj); view_as_complex_76 = None + view_1251 = torch.ops.aten.view.default(convert_element_type_1428, [2, 8192, 32, 64, 2]); convert_element_type_1428 = None + view_as_complex_77 = torch.ops.aten.view_as_complex.default(view_1251); view_1251 = None + mul_397 = torch.ops.aten.mul.Tensor(view_as_complex_77, _conj); view_as_complex_77 = None + view_as_real_76 = torch.ops.aten.view_as_real.default(mul_396); mul_396 = None + view_1252 = torch.ops.aten.view.default(view_as_real_76, [2, 8192, 8, 128]); view_as_real_76 = None + convert_element_type_1429 = torch.ops.prims.convert_element_type.default(view_1252, torch.bfloat16); view_1252 = None + 
view_as_real_77 = torch.ops.aten.view_as_real.default(mul_397); mul_397 = None + view_1253 = torch.ops.aten.view.default(view_as_real_77, [2, 8192, 32, 128]); view_as_real_77 = None + convert_element_type_1430 = torch.ops.prims.convert_element_type.default(view_1253, torch.bfloat16); view_1253 = None + view_1254 = torch.ops.aten.view.default(squeeze_12, [2, 8192, 1024]); squeeze_12 = None + view_1255 = torch.ops.aten.view.default(convert_element_type_1429, [2, 8192, 1024]); convert_element_type_1429 = None + view_1256 = torch.ops.aten.view.default(convert_element_type_1430, [2, 8192, 4096]); convert_element_type_1430 = None + view_1257 = torch.ops.aten.view.default(view_1254, [16384, 1024]); view_1254 = None + permute_569 = torch.ops.aten.permute.default(view_1257, [1, 0]) + mm_319 = torch.ops.aten.mm.default(permute_569, view_853); permute_569 = None + convert_element_type_835 = torch.ops.prims.convert_element_type.default(primals_232, torch.bfloat16); primals_232 = None + all_gather_into_tensor_229 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_835, 256, '0'); convert_element_type_835 = None + wait_tensor_229 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_229); all_gather_into_tensor_229 = None + permute_277 = torch.ops.aten.permute.default(wait_tensor_229, [1, 0]); wait_tensor_229 = None + permute_571 = torch.ops.aten.permute.default(permute_277, [1, 0]); permute_277 = None + mm_320 = torch.ops.aten.mm.default(view_1257, permute_571); view_1257 = permute_571 = None + view_1258 = torch.ops.aten.view.default(mm_320, [2, 8192, 4096]); mm_320 = None + convert_element_type_1435 = torch.ops.prims.convert_element_type.default(mm_319, torch.float32); mm_319 = None + reduce_scatter_tensor_61 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1435, 'avg', 256, '0'); convert_element_type_1435 = None + wait_tensor_352 = 
torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_61); reduce_scatter_tensor_61 = None + view_1259 = torch.ops.aten.view.default(view_1255, [16384, 1024]); view_1255 = None + permute_573 = torch.ops.aten.permute.default(view_1259, [1, 0]) + mm_321 = torch.ops.aten.mm.default(permute_573, view_853); permute_573 = None + permute_575 = torch.ops.aten.permute.default(permute_276, [1, 0]); permute_276 = None + mm_322 = torch.ops.aten.mm.default(view_1259, permute_575); view_1259 = permute_575 = None + view_1260 = torch.ops.aten.view.default(mm_322, [2, 8192, 4096]); mm_322 = None + add_175 = torch.ops.aten.add.Tensor(view_1258, view_1260); view_1258 = view_1260 = None + convert_element_type_1440 = torch.ops.prims.convert_element_type.default(mm_321, torch.float32); mm_321 = None + reduce_scatter_tensor_62 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1440, 'avg', 256, '0'); convert_element_type_1440 = None + wait_tensor_353 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_62); reduce_scatter_tensor_62 = None + view_1261 = torch.ops.aten.view.default(view_1256, [16384, 4096]); view_1256 = None + permute_577 = torch.ops.aten.permute.default(view_1261, [1, 0]) + mm_323 = torch.ops.aten.mm.default(permute_577, view_853); permute_577 = view_853 = None + convert_element_type_829 = torch.ops.prims.convert_element_type.default(primals_230, torch.bfloat16); primals_230 = None + all_gather_into_tensor_227 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_829, 256, '0'); convert_element_type_829 = None + wait_tensor_227 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_227); all_gather_into_tensor_227 = None + permute_275 = torch.ops.aten.permute.default(wait_tensor_227, [1, 0]); wait_tensor_227 = None + permute_579 = torch.ops.aten.permute.default(permute_275, [1, 0]); permute_275 = None + mm_324 = torch.ops.aten.mm.default(view_1261, permute_579); 
view_1261 = permute_579 = None + view_1262 = torch.ops.aten.view.default(mm_324, [2, 8192, 4096]); mm_324 = None + add_176 = torch.ops.aten.add.Tensor(add_175, view_1262); add_175 = view_1262 = None + convert_element_type_1445 = torch.ops.prims.convert_element_type.default(mm_323, torch.float32); mm_323 = None + reduce_scatter_tensor_63 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1445, 'avg', 256, '0'); convert_element_type_1445 = None + wait_tensor_354 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_63); reduce_scatter_tensor_63 = None + convert_element_type_1446 = torch.ops.prims.convert_element_type.default(add_176, torch.float32); add_176 = None + convert_element_type_1448 = torch.ops.prims.convert_element_type.default(wait_tensor_226, torch.float32); wait_tensor_226 = None + mul_398 = torch.ops.aten.mul.Tensor(convert_element_type_1446, convert_element_type_1448); convert_element_type_1448 = None + mul_400 = torch.ops.aten.mul.Tensor(mul_200, mul_398) + sum_43 = torch.ops.aten.sum.dim_IntList(mul_400, [2], True); mul_400 = None + div_14 = torch.ops.aten.div.Tensor(mul_200, 4096) + mul_401 = torch.ops.aten.mul.Tensor(div_14, sum_43); div_14 = sum_43 = None + sub_21 = torch.ops.aten.sub.Tensor(mul_398, mul_401); mul_398 = mul_401 = None + mul_402 = torch.ops.aten.mul.Tensor(sub_21, rsqrt_50); sub_21 = rsqrt_50 = None + mul_403 = torch.ops.aten.mul.Tensor(convert_element_type_1446, mul_200); convert_element_type_1446 = mul_200 = None + sum_44 = torch.ops.aten.sum.dim_IntList(mul_403, [0, 1]); mul_403 = None + convert_element_type_1449 = torch.ops.prims.convert_element_type.default(mul_402, torch.bfloat16); mul_402 = None + add_177 = torch.ops.aten.add.Tensor(add_174, convert_element_type_1449); add_174 = convert_element_type_1449 = None + convert_element_type_default_51 = torch.ops.prims.convert_element_type.default(sum_44, torch.float32); sum_44 = None + reduce_scatter_tensor_64 = 
torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_51, 'avg', 256, '0'); convert_element_type_default_51 = None + wait_tensor_355 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_64); reduce_scatter_tensor_64 = None + view_1263 = torch.ops.aten.view.default(add_177, [16384, 4096]) + permute_581 = torch.ops.aten.permute.default(view_1263, [1, 0]) + permute_270 = torch.ops.aten.permute.default(getitem_216, [0, 2, 1, 3]) + view_837 = torch.ops.aten.view.default(permute_270, [2, 8192, -1]); permute_270 = None + convert_element_type_809 = torch.ops.prims.convert_element_type.default(primals_224, torch.bfloat16); primals_224 = None + all_gather_into_tensor_221 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_809, 256, '0'); convert_element_type_809 = None + wait_tensor_221 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_221); all_gather_into_tensor_221 = None + permute_271 = torch.ops.aten.permute.default(wait_tensor_221, [1, 0]); wait_tensor_221 = None + view_839 = torch.ops.aten.view.default(view_837, [16384, 4096]); view_837 = None + mm_171 = torch.ops.aten.mm.default(view_839, permute_271) + view_840 = torch.ops.aten.view.default(mm_171, [2, 8192, 4096]); mm_171 = None + add_97 = torch.ops.aten.add.Tensor(add_95, view_840); view_840 = None + convert_element_type_812 = torch.ops.prims.convert_element_type.default(primals_225, torch.bfloat16); primals_225 = None + all_gather_into_tensor_222 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_812, 256, '0'); convert_element_type_812 = None + wait_tensor_222 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_222); all_gather_into_tensor_222 = None + convert_element_type_813 = torch.ops.prims.convert_element_type.default(add_97, torch.float32); add_97 = None + pow_50 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_813, 2) + mean_49 = 
torch.ops.aten.mean.dim(pow_50, [2], True); pow_50 = None + add_98 = torch.ops.aten.add.Scalar(mean_49, 1e-05); mean_49 = None + rsqrt_49 = torch.ops.aten.rsqrt.default(add_98); add_98 = None + mul_196 = torch.ops.aten.mul.Tensor(convert_element_type_813, rsqrt_49); convert_element_type_813 = None + mul_197 = torch.ops.aten.mul.Tensor(mul_196, wait_tensor_222) + convert_element_type_814 = torch.ops.prims.convert_element_type.default(mul_197, torch.bfloat16); mul_197 = None + view_843 = torch.ops.aten.view.default(convert_element_type_814, [16384, 4096]); convert_element_type_814 = None + view_844 = torch.ops.aten.view.default(mm_172, [2, 8192, 14336]); mm_172 = None + convert_element_type_818 = torch.ops.prims.convert_element_type.default(view_844, torch.float32); view_844 = None + sigmoid_24 = torch.ops.aten.sigmoid.default(convert_element_type_818) + mul_198 = torch.ops.aten.mul.Tensor(convert_element_type_818, sigmoid_24); sigmoid_24 = None + convert_element_type_819 = torch.ops.prims.convert_element_type.default(mul_198, torch.bfloat16); mul_198 = None + convert_element_type_820 = torch.ops.prims.convert_element_type.default(primals_227, torch.bfloat16); primals_227 = None + all_gather_into_tensor_224 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_820, 256, '0'); convert_element_type_820 = None + wait_tensor_224 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_224); all_gather_into_tensor_224 = None + permute_273 = torch.ops.aten.permute.default(wait_tensor_224, [1, 0]); wait_tensor_224 = None + mm_173 = torch.ops.aten.mm.default(view_843, permute_273) + view_847 = torch.ops.aten.view.default(mm_173, [2, 8192, 14336]); mm_173 = None + mul_199 = torch.ops.aten.mul.Tensor(convert_element_type_819, view_847) + view_849 = torch.ops.aten.view.default(mul_199, [16384, 14336]); mul_199 = None + mm_325 = torch.ops.aten.mm.default(permute_581, view_849); permute_581 = view_849 = None + convert_element_type_823 = 
torch.ops.prims.convert_element_type.default(primals_228, torch.bfloat16); primals_228 = None + all_gather_into_tensor_225 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_823, 256, '0'); convert_element_type_823 = None + wait_tensor_225 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_225); all_gather_into_tensor_225 = None + permute_274 = torch.ops.aten.permute.default(wait_tensor_225, [1, 0]); wait_tensor_225 = None + permute_583 = torch.ops.aten.permute.default(permute_274, [1, 0]); permute_274 = None + mm_326 = torch.ops.aten.mm.default(view_1263, permute_583); view_1263 = permute_583 = None + view_1264 = torch.ops.aten.view.default(mm_326, [2, 8192, 14336]); mm_326 = None + convert_element_type_1456 = torch.ops.prims.convert_element_type.default(mm_325, torch.float32); mm_325 = None + reduce_scatter_tensor_65 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1456, 'avg', 256, '0'); convert_element_type_1456 = None + wait_tensor_356 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_65); reduce_scatter_tensor_65 = None + mul_404 = torch.ops.aten.mul.Tensor(view_1264, convert_element_type_819); convert_element_type_819 = None + mul_405 = torch.ops.aten.mul.Tensor(view_1264, view_847); view_1264 = view_847 = None + view_1265 = torch.ops.aten.view.default(mul_404, [16384, 14336]); mul_404 = None + permute_585 = torch.ops.aten.permute.default(view_1265, [1, 0]) + mm_327 = torch.ops.aten.mm.default(permute_585, view_843); permute_585 = None + permute_587 = torch.ops.aten.permute.default(permute_273, [1, 0]); permute_273 = None + mm_328 = torch.ops.aten.mm.default(view_1265, permute_587); view_1265 = permute_587 = None + view_1266 = torch.ops.aten.view.default(mm_328, [2, 8192, 4096]); mm_328 = None + convert_element_type_1461 = torch.ops.prims.convert_element_type.default(mm_327, torch.float32); mm_327 = None + reduce_scatter_tensor_66 = 
torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1461, 'avg', 256, '0'); convert_element_type_1461 = None + wait_tensor_357 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_66); reduce_scatter_tensor_66 = None + convert_element_type_1462 = torch.ops.prims.convert_element_type.default(mul_405, torch.float32); mul_405 = None + neg_7 = torch.ops.aten.neg.default(convert_element_type_818) + exp_7 = torch.ops.aten.exp.default(neg_7); neg_7 = None + add_178 = torch.ops.aten.add.Tensor(exp_7, 1); exp_7 = None + reciprocal_7 = torch.ops.aten.reciprocal.default(add_178); add_178 = None + mul_406 = torch.ops.aten.mul.Tensor(reciprocal_7, 1); reciprocal_7 = None + mul_407 = torch.ops.aten.mul.Tensor(convert_element_type_1462, mul_406); convert_element_type_1462 = None + sub_22 = torch.ops.aten.sub.Tensor(1, mul_406); mul_406 = None + mul_408 = torch.ops.aten.mul.Tensor(convert_element_type_818, sub_22); convert_element_type_818 = sub_22 = None + add_179 = torch.ops.aten.add.Tensor(mul_408, 1); mul_408 = None + mul_409 = torch.ops.aten.mul.Tensor(mul_407, add_179); mul_407 = add_179 = None + convert_element_type_1464 = torch.ops.prims.convert_element_type.default(mul_409, torch.bfloat16); mul_409 = None + view_1267 = torch.ops.aten.view.default(convert_element_type_1464, [16384, 14336]); convert_element_type_1464 = None + permute_589 = torch.ops.aten.permute.default(view_1267, [1, 0]) + mm_329 = torch.ops.aten.mm.default(permute_589, view_843); permute_589 = view_843 = None + convert_element_type_815 = torch.ops.prims.convert_element_type.default(primals_226, torch.bfloat16); primals_226 = None + all_gather_into_tensor_223 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_815, 256, '0'); convert_element_type_815 = None + wait_tensor_223 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_223); all_gather_into_tensor_223 = None + permute_272 = 
torch.ops.aten.permute.default(wait_tensor_223, [1, 0]); wait_tensor_223 = None + permute_591 = torch.ops.aten.permute.default(permute_272, [1, 0]); permute_272 = None + mm_330 = torch.ops.aten.mm.default(view_1267, permute_591); view_1267 = permute_591 = None + view_1268 = torch.ops.aten.view.default(mm_330, [2, 8192, 4096]); mm_330 = None + add_180 = torch.ops.aten.add.Tensor(view_1266, view_1268); view_1266 = view_1268 = None + convert_element_type_1469 = torch.ops.prims.convert_element_type.default(mm_329, torch.float32); mm_329 = None + reduce_scatter_tensor_67 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1469, 'avg', 256, '0'); convert_element_type_1469 = None + wait_tensor_358 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_67); reduce_scatter_tensor_67 = None + convert_element_type_1470 = torch.ops.prims.convert_element_type.default(add_180, torch.float32); add_180 = None + convert_element_type_1472 = torch.ops.prims.convert_element_type.default(wait_tensor_222, torch.float32); wait_tensor_222 = None + mul_410 = torch.ops.aten.mul.Tensor(convert_element_type_1470, convert_element_type_1472); convert_element_type_1472 = None + mul_412 = torch.ops.aten.mul.Tensor(mul_196, mul_410) + sum_45 = torch.ops.aten.sum.dim_IntList(mul_412, [2], True); mul_412 = None + div_15 = torch.ops.aten.div.Tensor(mul_196, 4096) + mul_413 = torch.ops.aten.mul.Tensor(div_15, sum_45); div_15 = sum_45 = None + sub_23 = torch.ops.aten.sub.Tensor(mul_410, mul_413); mul_410 = mul_413 = None + mul_414 = torch.ops.aten.mul.Tensor(sub_23, rsqrt_49); sub_23 = rsqrt_49 = None + mul_415 = torch.ops.aten.mul.Tensor(convert_element_type_1470, mul_196); convert_element_type_1470 = mul_196 = None + sum_46 = torch.ops.aten.sum.dim_IntList(mul_415, [0, 1]); mul_415 = None + convert_element_type_1473 = torch.ops.prims.convert_element_type.default(mul_414, torch.bfloat16); mul_414 = None + add_181 = torch.ops.aten.add.Tensor(add_177, 
convert_element_type_1473); add_177 = convert_element_type_1473 = None + convert_element_type_default_50 = torch.ops.prims.convert_element_type.default(sum_46, torch.float32); sum_46 = None + reduce_scatter_tensor_68 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_50, 'avg', 256, '0'); convert_element_type_default_50 = None + wait_tensor_359 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_68); reduce_scatter_tensor_68 = None + view_1269 = torch.ops.aten.view.default(add_181, [16384, 4096]) + permute_593 = torch.ops.aten.permute.default(view_1269, [1, 0]) + mm_331 = torch.ops.aten.mm.default(permute_593, view_839); permute_593 = view_839 = None + permute_595 = torch.ops.aten.permute.default(permute_271, [1, 0]); permute_271 = None + mm_332 = torch.ops.aten.mm.default(view_1269, permute_595); view_1269 = permute_595 = None + view_1270 = torch.ops.aten.view.default(mm_332, [2, 8192, 4096]); mm_332 = None + convert_element_type_1480 = torch.ops.prims.convert_element_type.default(mm_331, torch.float32); mm_331 = None + reduce_scatter_tensor_69 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1480, 'avg', 256, '0'); convert_element_type_1480 = None + wait_tensor_360 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_69); reduce_scatter_tensor_69 = None + view_1271 = torch.ops.aten.view.default(view_1270, [2, 8192, 32, 128]); view_1270 = None + permute_597 = torch.ops.aten.permute.default(view_1271, [0, 2, 1, 3]); view_1271 = None + convert_element_type_793 = torch.ops.prims.convert_element_type.default(primals_220, torch.bfloat16); primals_220 = None + all_gather_into_tensor_217 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_793, 256, '0'); convert_element_type_793 = None + wait_tensor_217 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_217); all_gather_into_tensor_217 = None + convert_element_type_794 
= torch.ops.prims.convert_element_type.default(add_95, torch.float32); add_95 = None + pow_49 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_794, 2) + mean_48 = torch.ops.aten.mean.dim(pow_49, [2], True); pow_49 = None + add_96 = torch.ops.aten.add.Scalar(mean_48, 1e-05); mean_48 = None + rsqrt_48 = torch.ops.aten.rsqrt.default(add_96); add_96 = None + mul_192 = torch.ops.aten.mul.Tensor(convert_element_type_794, rsqrt_48); convert_element_type_794 = None + mul_193 = torch.ops.aten.mul.Tensor(mul_192, wait_tensor_217) + convert_element_type_795 = torch.ops.prims.convert_element_type.default(mul_193, torch.bfloat16); mul_193 = None + view_819 = torch.ops.aten.view.default(convert_element_type_795, [16384, 4096]); convert_element_type_795 = None + view_820 = torch.ops.aten.view.default(mm_168, [2, 8192, 4096]); mm_168 = None + convert_element_type_799 = torch.ops.prims.convert_element_type.default(primals_222, torch.bfloat16); primals_222 = None + all_gather_into_tensor_219 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_799, 256, '0'); convert_element_type_799 = None + wait_tensor_219 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_219); all_gather_into_tensor_219 = None + permute_265 = torch.ops.aten.permute.default(wait_tensor_219, [1, 0]); wait_tensor_219 = None + mm_169 = torch.ops.aten.mm.default(view_819, permute_265) + view_823 = torch.ops.aten.view.default(mm_169, [2, 8192, 1024]); mm_169 = None + view_826 = torch.ops.aten.view.default(mm_170, [2, 8192, 1024]); mm_170 = None + view_827 = torch.ops.aten.view.default(view_820, [2, 8192, -1, 128]); view_820 = None + view_828 = torch.ops.aten.view.default(view_823, [2, 8192, -1, 128]); view_823 = None + view_829 = torch.ops.aten.view.default(view_826, [2, 8192, -1, 128]); view_826 = None + convert_element_type_805 = torch.ops.prims.convert_element_type.default(view_827, torch.float32); view_827 = None + view_830 = 
torch.ops.aten.view.default(convert_element_type_805, [2, 8192, 32, -1, 2]); convert_element_type_805 = None + view_as_complex_48 = torch.ops.aten.view_as_complex.default(view_830); view_830 = None + convert_element_type_806 = torch.ops.prims.convert_element_type.default(view_828, torch.float32); view_828 = None + view_831 = torch.ops.aten.view.default(convert_element_type_806, [2, 8192, 8, -1, 2]); convert_element_type_806 = None + view_as_complex_49 = torch.ops.aten.view_as_complex.default(view_831); view_831 = None + mul_194 = torch.ops.aten.mul.Tensor(view_as_complex_48, view_16); view_as_complex_48 = None + view_as_real_48 = torch.ops.aten.view_as_real.default(mul_194); mul_194 = None + view_833 = torch.ops.aten.view.default(view_as_real_48, [2, 8192, 32, 128]); view_as_real_48 = None + mul_195 = torch.ops.aten.mul.Tensor(view_as_complex_49, view_16); view_as_complex_49 = None + view_as_real_49 = torch.ops.aten.view_as_real.default(mul_195); mul_195 = None + view_834 = torch.ops.aten.view.default(view_as_real_49, [2, 8192, 8, 128]); view_as_real_49 = None + convert_element_type_807 = torch.ops.prims.convert_element_type.default(view_833, torch.bfloat16); view_833 = None + convert_element_type_808 = torch.ops.prims.convert_element_type.default(view_834, torch.bfloat16); view_834 = None + unsqueeze_48 = torch.ops.aten.unsqueeze.default(convert_element_type_808, 3); convert_element_type_808 = None + expand_48 = torch.ops.aten.expand.default(unsqueeze_48, [2, 8192, 8, 4, 128]); unsqueeze_48 = None + clone_48 = torch.ops.aten.clone.default(expand_48, memory_format = torch.contiguous_format); expand_48 = None + view_835 = torch.ops.aten.view.default(clone_48, [2, 8192, 32, 128]); clone_48 = None + unsqueeze_49 = torch.ops.aten.unsqueeze.default(view_829, 3); view_829 = None + expand_49 = torch.ops.aten.expand.default(unsqueeze_49, [2, 8192, 8, 4, 128]); unsqueeze_49 = None + clone_49 = torch.ops.aten.clone.default(expand_49, memory_format = torch.contiguous_format); 
expand_49 = None + view_836 = torch.ops.aten.view.default(clone_49, [2, 8192, 32, 128]); clone_49 = None + permute_267 = torch.ops.aten.permute.default(convert_element_type_807, [0, 2, 1, 3]); convert_element_type_807 = None + permute_268 = torch.ops.aten.permute.default(view_835, [0, 2, 1, 3]); view_835 = None + permute_269 = torch.ops.aten.permute.default(view_836, [0, 2, 1, 3]); view_836 = None + _scaled_dot_product_cudnn_attention_backward_7 = torch.ops.aten._scaled_dot_product_cudnn_attention_backward.default(permute_597, permute_267, permute_268, permute_269, getitem_216, getitem_217, getitem_222, getitem_223, None, None, None, 8192, 8192, 0.0, True); permute_597 = permute_267 = permute_268 = permute_269 = getitem_216 = getitem_217 = getitem_222 = getitem_223 = None + getitem_309 = _scaled_dot_product_cudnn_attention_backward_7[0] + getitem_310 = _scaled_dot_product_cudnn_attention_backward_7[1] + getitem_311 = _scaled_dot_product_cudnn_attention_backward_7[2]; _scaled_dot_product_cudnn_attention_backward_7 = None + permute_598 = torch.ops.aten.permute.default(getitem_311, [0, 2, 1, 3]); getitem_311 = None + permute_599 = torch.ops.aten.permute.default(getitem_310, [0, 2, 1, 3]); getitem_310 = None + permute_600 = torch.ops.aten.permute.default(getitem_309, [0, 2, 1, 3]); getitem_309 = None + view_1272 = torch.ops.aten.view.default(permute_598, [2, 8192, 8, 4, 128]); permute_598 = None + sum_47 = torch.ops.aten.sum.dim_IntList(view_1272, [3], True); view_1272 = None + squeeze_14 = torch.ops.aten.squeeze.dim(sum_47, 3); sum_47 = None + view_1273 = torch.ops.aten.view.default(permute_599, [2, 8192, 8, 4, 128]); permute_599 = None + sum_48 = torch.ops.aten.sum.dim_IntList(view_1273, [3], True); view_1273 = None + squeeze_15 = torch.ops.aten.squeeze.dim(sum_48, 3); sum_48 = None + convert_element_type_1481 = torch.ops.prims.convert_element_type.default(squeeze_15, torch.float32); squeeze_15 = None + convert_element_type_1482 = 
torch.ops.prims.convert_element_type.default(permute_600, torch.float32); permute_600 = None + view_1274 = torch.ops.aten.view.default(convert_element_type_1481, [2, 8192, 8, 64, 2]); convert_element_type_1481 = None + view_as_complex_78 = torch.ops.aten.view_as_complex.default(view_1274); view_1274 = None + mul_416 = torch.ops.aten.mul.Tensor(view_as_complex_78, _conj); view_as_complex_78 = None + view_1275 = torch.ops.aten.view.default(convert_element_type_1482, [2, 8192, 32, 64, 2]); convert_element_type_1482 = None + view_as_complex_79 = torch.ops.aten.view_as_complex.default(view_1275); view_1275 = None + mul_417 = torch.ops.aten.mul.Tensor(view_as_complex_79, _conj); view_as_complex_79 = None + view_as_real_78 = torch.ops.aten.view_as_real.default(mul_416); mul_416 = None + view_1276 = torch.ops.aten.view.default(view_as_real_78, [2, 8192, 8, 128]); view_as_real_78 = None + convert_element_type_1483 = torch.ops.prims.convert_element_type.default(view_1276, torch.bfloat16); view_1276 = None + view_as_real_79 = torch.ops.aten.view_as_real.default(mul_417); mul_417 = None + view_1277 = torch.ops.aten.view.default(view_as_real_79, [2, 8192, 32, 128]); view_as_real_79 = None + convert_element_type_1484 = torch.ops.prims.convert_element_type.default(view_1277, torch.bfloat16); view_1277 = None + view_1278 = torch.ops.aten.view.default(squeeze_14, [2, 8192, 1024]); squeeze_14 = None + view_1279 = torch.ops.aten.view.default(convert_element_type_1483, [2, 8192, 1024]); convert_element_type_1483 = None + view_1280 = torch.ops.aten.view.default(convert_element_type_1484, [2, 8192, 4096]); convert_element_type_1484 = None + view_1281 = torch.ops.aten.view.default(view_1278, [16384, 1024]); view_1278 = None + permute_601 = torch.ops.aten.permute.default(view_1281, [1, 0]) + mm_333 = torch.ops.aten.mm.default(permute_601, view_819); permute_601 = None + convert_element_type_802 = torch.ops.prims.convert_element_type.default(primals_223, torch.bfloat16); primals_223 = None 
+ all_gather_into_tensor_220 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_802, 256, '0'); convert_element_type_802 = None + wait_tensor_220 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_220); all_gather_into_tensor_220 = None + permute_266 = torch.ops.aten.permute.default(wait_tensor_220, [1, 0]); wait_tensor_220 = None + permute_603 = torch.ops.aten.permute.default(permute_266, [1, 0]); permute_266 = None + mm_334 = torch.ops.aten.mm.default(view_1281, permute_603); view_1281 = permute_603 = None + view_1282 = torch.ops.aten.view.default(mm_334, [2, 8192, 4096]); mm_334 = None + convert_element_type_1489 = torch.ops.prims.convert_element_type.default(mm_333, torch.float32); mm_333 = None + reduce_scatter_tensor_70 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1489, 'avg', 256, '0'); convert_element_type_1489 = None + wait_tensor_361 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_70); reduce_scatter_tensor_70 = None + view_1283 = torch.ops.aten.view.default(view_1279, [16384, 1024]); view_1279 = None + permute_605 = torch.ops.aten.permute.default(view_1283, [1, 0]) + mm_335 = torch.ops.aten.mm.default(permute_605, view_819); permute_605 = None + permute_607 = torch.ops.aten.permute.default(permute_265, [1, 0]); permute_265 = None + mm_336 = torch.ops.aten.mm.default(view_1283, permute_607); view_1283 = permute_607 = None + view_1284 = torch.ops.aten.view.default(mm_336, [2, 8192, 4096]); mm_336 = None + add_182 = torch.ops.aten.add.Tensor(view_1282, view_1284); view_1282 = view_1284 = None + convert_element_type_1494 = torch.ops.prims.convert_element_type.default(mm_335, torch.float32); mm_335 = None + reduce_scatter_tensor_71 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1494, 'avg', 256, '0'); convert_element_type_1494 = None + wait_tensor_362 = 
torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_71); reduce_scatter_tensor_71 = None + view_1285 = torch.ops.aten.view.default(view_1280, [16384, 4096]); view_1280 = None + permute_609 = torch.ops.aten.permute.default(view_1285, [1, 0]) + mm_337 = torch.ops.aten.mm.default(permute_609, view_819); permute_609 = view_819 = None + convert_element_type_796 = torch.ops.prims.convert_element_type.default(primals_221, torch.bfloat16); primals_221 = None + all_gather_into_tensor_218 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_796, 256, '0'); convert_element_type_796 = None + wait_tensor_218 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_218); all_gather_into_tensor_218 = None + permute_264 = torch.ops.aten.permute.default(wait_tensor_218, [1, 0]); wait_tensor_218 = None + permute_611 = torch.ops.aten.permute.default(permute_264, [1, 0]); permute_264 = None + mm_338 = torch.ops.aten.mm.default(view_1285, permute_611); view_1285 = permute_611 = None + view_1286 = torch.ops.aten.view.default(mm_338, [2, 8192, 4096]); mm_338 = None + add_183 = torch.ops.aten.add.Tensor(add_182, view_1286); add_182 = view_1286 = None + convert_element_type_1499 = torch.ops.prims.convert_element_type.default(mm_337, torch.float32); mm_337 = None + reduce_scatter_tensor_72 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1499, 'avg', 256, '0'); convert_element_type_1499 = None + wait_tensor_363 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_72); reduce_scatter_tensor_72 = None + convert_element_type_1500 = torch.ops.prims.convert_element_type.default(add_183, torch.float32); add_183 = None + convert_element_type_1502 = torch.ops.prims.convert_element_type.default(wait_tensor_217, torch.float32); wait_tensor_217 = None + mul_418 = torch.ops.aten.mul.Tensor(convert_element_type_1500, convert_element_type_1502); convert_element_type_1502 = None + mul_420 = 
torch.ops.aten.mul.Tensor(mul_192, mul_418) + sum_49 = torch.ops.aten.sum.dim_IntList(mul_420, [2], True); mul_420 = None + div_16 = torch.ops.aten.div.Tensor(mul_192, 4096) + mul_421 = torch.ops.aten.mul.Tensor(div_16, sum_49); div_16 = sum_49 = None + sub_24 = torch.ops.aten.sub.Tensor(mul_418, mul_421); mul_418 = mul_421 = None + mul_422 = torch.ops.aten.mul.Tensor(sub_24, rsqrt_48); sub_24 = rsqrt_48 = None + mul_423 = torch.ops.aten.mul.Tensor(convert_element_type_1500, mul_192); convert_element_type_1500 = mul_192 = None + sum_50 = torch.ops.aten.sum.dim_IntList(mul_423, [0, 1]); mul_423 = None + convert_element_type_1503 = torch.ops.prims.convert_element_type.default(mul_422, torch.bfloat16); mul_422 = None + add_184 = torch.ops.aten.add.Tensor(add_181, convert_element_type_1503); add_181 = convert_element_type_1503 = None + convert_element_type_default_49 = torch.ops.prims.convert_element_type.default(sum_50, torch.float32); sum_50 = None + reduce_scatter_tensor_73 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_49, 'avg', 256, '0'); convert_element_type_default_49 = None + wait_tensor_364 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_73); reduce_scatter_tensor_73 = None + view_1287 = torch.ops.aten.view.default(add_184, [16384, 4096]) + permute_613 = torch.ops.aten.permute.default(view_1287, [1, 0]) + permute_259 = torch.ops.aten.permute.default(getitem_207, [0, 2, 1, 3]) + view_803 = torch.ops.aten.view.default(permute_259, [2, 8192, -1]); permute_259 = None + convert_element_type_776 = torch.ops.prims.convert_element_type.default(primals_215, torch.bfloat16); primals_215 = None + all_gather_into_tensor_212 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_776, 256, '0'); convert_element_type_776 = None + wait_tensor_212 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_212); all_gather_into_tensor_212 = None + permute_260 = 
torch.ops.aten.permute.default(wait_tensor_212, [1, 0]); wait_tensor_212 = None + view_805 = torch.ops.aten.view.default(view_803, [16384, 4096]); view_803 = None + mm_164 = torch.ops.aten.mm.default(view_805, permute_260) + view_806 = torch.ops.aten.view.default(mm_164, [2, 8192, 4096]); mm_164 = None + add_93 = torch.ops.aten.add.Tensor(add_91, view_806); view_806 = None + convert_element_type_779 = torch.ops.prims.convert_element_type.default(primals_216, torch.bfloat16); primals_216 = None + all_gather_into_tensor_213 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_779, 256, '0'); convert_element_type_779 = None + wait_tensor_213 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_213); all_gather_into_tensor_213 = None + convert_element_type_780 = torch.ops.prims.convert_element_type.default(add_93, torch.float32); add_93 = None + pow_48 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_780, 2) + mean_47 = torch.ops.aten.mean.dim(pow_48, [2], True); pow_48 = None + add_94 = torch.ops.aten.add.Scalar(mean_47, 1e-05); mean_47 = None + rsqrt_47 = torch.ops.aten.rsqrt.default(add_94); add_94 = None + mul_188 = torch.ops.aten.mul.Tensor(convert_element_type_780, rsqrt_47); convert_element_type_780 = None + mul_189 = torch.ops.aten.mul.Tensor(mul_188, wait_tensor_213) + convert_element_type_781 = torch.ops.prims.convert_element_type.default(mul_189, torch.bfloat16); mul_189 = None + view_809 = torch.ops.aten.view.default(convert_element_type_781, [16384, 4096]); convert_element_type_781 = None + view_810 = torch.ops.aten.view.default(mm_165, [2, 8192, 14336]); mm_165 = None + convert_element_type_785 = torch.ops.prims.convert_element_type.default(view_810, torch.float32); view_810 = None + sigmoid_23 = torch.ops.aten.sigmoid.default(convert_element_type_785) + mul_190 = torch.ops.aten.mul.Tensor(convert_element_type_785, sigmoid_23); sigmoid_23 = None + convert_element_type_786 = 
torch.ops.prims.convert_element_type.default(mul_190, torch.bfloat16); mul_190 = None + convert_element_type_787 = torch.ops.prims.convert_element_type.default(primals_218, torch.bfloat16); primals_218 = None + all_gather_into_tensor_215 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_787, 256, '0'); convert_element_type_787 = None + wait_tensor_215 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_215); all_gather_into_tensor_215 = None + permute_262 = torch.ops.aten.permute.default(wait_tensor_215, [1, 0]); wait_tensor_215 = None + mm_166 = torch.ops.aten.mm.default(view_809, permute_262) + view_813 = torch.ops.aten.view.default(mm_166, [2, 8192, 14336]); mm_166 = None + mul_191 = torch.ops.aten.mul.Tensor(convert_element_type_786, view_813) + view_815 = torch.ops.aten.view.default(mul_191, [16384, 14336]); mul_191 = None + mm_339 = torch.ops.aten.mm.default(permute_613, view_815); permute_613 = view_815 = None + convert_element_type_790 = torch.ops.prims.convert_element_type.default(primals_219, torch.bfloat16); primals_219 = None + all_gather_into_tensor_216 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_790, 256, '0'); convert_element_type_790 = None + wait_tensor_216 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_216); all_gather_into_tensor_216 = None + permute_263 = torch.ops.aten.permute.default(wait_tensor_216, [1, 0]); wait_tensor_216 = None + permute_615 = torch.ops.aten.permute.default(permute_263, [1, 0]); permute_263 = None + mm_340 = torch.ops.aten.mm.default(view_1287, permute_615); view_1287 = permute_615 = None + view_1288 = torch.ops.aten.view.default(mm_340, [2, 8192, 14336]); mm_340 = None + convert_element_type_1510 = torch.ops.prims.convert_element_type.default(mm_339, torch.float32); mm_339 = None + reduce_scatter_tensor_74 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1510, 'avg', 256, '0'); 
convert_element_type_1510 = None + wait_tensor_365 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_74); reduce_scatter_tensor_74 = None + mul_424 = torch.ops.aten.mul.Tensor(view_1288, convert_element_type_786); convert_element_type_786 = None + mul_425 = torch.ops.aten.mul.Tensor(view_1288, view_813); view_1288 = view_813 = None + view_1289 = torch.ops.aten.view.default(mul_424, [16384, 14336]); mul_424 = None + permute_617 = torch.ops.aten.permute.default(view_1289, [1, 0]) + mm_341 = torch.ops.aten.mm.default(permute_617, view_809); permute_617 = None + permute_619 = torch.ops.aten.permute.default(permute_262, [1, 0]); permute_262 = None + mm_342 = torch.ops.aten.mm.default(view_1289, permute_619); view_1289 = permute_619 = None + view_1290 = torch.ops.aten.view.default(mm_342, [2, 8192, 4096]); mm_342 = None + convert_element_type_1515 = torch.ops.prims.convert_element_type.default(mm_341, torch.float32); mm_341 = None + reduce_scatter_tensor_75 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1515, 'avg', 256, '0'); convert_element_type_1515 = None + wait_tensor_366 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_75); reduce_scatter_tensor_75 = None + convert_element_type_1516 = torch.ops.prims.convert_element_type.default(mul_425, torch.float32); mul_425 = None + neg_8 = torch.ops.aten.neg.default(convert_element_type_785) + exp_8 = torch.ops.aten.exp.default(neg_8); neg_8 = None + add_185 = torch.ops.aten.add.Tensor(exp_8, 1); exp_8 = None + reciprocal_8 = torch.ops.aten.reciprocal.default(add_185); add_185 = None + mul_426 = torch.ops.aten.mul.Tensor(reciprocal_8, 1); reciprocal_8 = None + mul_427 = torch.ops.aten.mul.Tensor(convert_element_type_1516, mul_426); convert_element_type_1516 = None + sub_25 = torch.ops.aten.sub.Tensor(1, mul_426); mul_426 = None + mul_428 = torch.ops.aten.mul.Tensor(convert_element_type_785, sub_25); convert_element_type_785 = sub_25 = None + add_186 = 
torch.ops.aten.add.Tensor(mul_428, 1); mul_428 = None + mul_429 = torch.ops.aten.mul.Tensor(mul_427, add_186); mul_427 = add_186 = None + convert_element_type_1518 = torch.ops.prims.convert_element_type.default(mul_429, torch.bfloat16); mul_429 = None + view_1291 = torch.ops.aten.view.default(convert_element_type_1518, [16384, 14336]); convert_element_type_1518 = None + permute_621 = torch.ops.aten.permute.default(view_1291, [1, 0]) + mm_343 = torch.ops.aten.mm.default(permute_621, view_809); permute_621 = view_809 = None + convert_element_type_782 = torch.ops.prims.convert_element_type.default(primals_217, torch.bfloat16); primals_217 = None + all_gather_into_tensor_214 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_782, 256, '0'); convert_element_type_782 = None + wait_tensor_214 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_214); all_gather_into_tensor_214 = None + permute_261 = torch.ops.aten.permute.default(wait_tensor_214, [1, 0]); wait_tensor_214 = None + permute_623 = torch.ops.aten.permute.default(permute_261, [1, 0]); permute_261 = None + mm_344 = torch.ops.aten.mm.default(view_1291, permute_623); view_1291 = permute_623 = None + view_1292 = torch.ops.aten.view.default(mm_344, [2, 8192, 4096]); mm_344 = None + add_187 = torch.ops.aten.add.Tensor(view_1290, view_1292); view_1290 = view_1292 = None + convert_element_type_1523 = torch.ops.prims.convert_element_type.default(mm_343, torch.float32); mm_343 = None + reduce_scatter_tensor_76 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1523, 'avg', 256, '0'); convert_element_type_1523 = None + wait_tensor_367 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_76); reduce_scatter_tensor_76 = None + convert_element_type_1524 = torch.ops.prims.convert_element_type.default(add_187, torch.float32); add_187 = None + convert_element_type_1526 = torch.ops.prims.convert_element_type.default(wait_tensor_213, 
torch.float32); wait_tensor_213 = None + mul_430 = torch.ops.aten.mul.Tensor(convert_element_type_1524, convert_element_type_1526); convert_element_type_1526 = None + mul_432 = torch.ops.aten.mul.Tensor(mul_188, mul_430) + sum_51 = torch.ops.aten.sum.dim_IntList(mul_432, [2], True); mul_432 = None + div_17 = torch.ops.aten.div.Tensor(mul_188, 4096) + mul_433 = torch.ops.aten.mul.Tensor(div_17, sum_51); div_17 = sum_51 = None + sub_26 = torch.ops.aten.sub.Tensor(mul_430, mul_433); mul_430 = mul_433 = None + mul_434 = torch.ops.aten.mul.Tensor(sub_26, rsqrt_47); sub_26 = rsqrt_47 = None + mul_435 = torch.ops.aten.mul.Tensor(convert_element_type_1524, mul_188); convert_element_type_1524 = mul_188 = None + sum_52 = torch.ops.aten.sum.dim_IntList(mul_435, [0, 1]); mul_435 = None + convert_element_type_1527 = torch.ops.prims.convert_element_type.default(mul_434, torch.bfloat16); mul_434 = None + add_188 = torch.ops.aten.add.Tensor(add_184, convert_element_type_1527); add_184 = convert_element_type_1527 = None + convert_element_type_default_48 = torch.ops.prims.convert_element_type.default(sum_52, torch.float32); sum_52 = None + reduce_scatter_tensor_77 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_48, 'avg', 256, '0'); convert_element_type_default_48 = None + wait_tensor_368 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_77); reduce_scatter_tensor_77 = None + view_1293 = torch.ops.aten.view.default(add_188, [16384, 4096]) + permute_625 = torch.ops.aten.permute.default(view_1293, [1, 0]) + mm_345 = torch.ops.aten.mm.default(permute_625, view_805); permute_625 = view_805 = None + permute_627 = torch.ops.aten.permute.default(permute_260, [1, 0]); permute_260 = None + mm_346 = torch.ops.aten.mm.default(view_1293, permute_627); view_1293 = permute_627 = None + view_1294 = torch.ops.aten.view.default(mm_346, [2, 8192, 4096]); mm_346 = None + convert_element_type_1534 = 
torch.ops.prims.convert_element_type.default(mm_345, torch.float32); mm_345 = None + reduce_scatter_tensor_78 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1534, 'avg', 256, '0'); convert_element_type_1534 = None + wait_tensor_369 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_78); reduce_scatter_tensor_78 = None + view_1295 = torch.ops.aten.view.default(view_1294, [2, 8192, 32, 128]); view_1294 = None + permute_629 = torch.ops.aten.permute.default(view_1295, [0, 2, 1, 3]); view_1295 = None + convert_element_type_760 = torch.ops.prims.convert_element_type.default(primals_211, torch.bfloat16); primals_211 = None + all_gather_into_tensor_208 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_760, 256, '0'); convert_element_type_760 = None + wait_tensor_208 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_208); all_gather_into_tensor_208 = None + convert_element_type_761 = torch.ops.prims.convert_element_type.default(add_91, torch.float32); add_91 = None + pow_47 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_761, 2) + mean_46 = torch.ops.aten.mean.dim(pow_47, [2], True); pow_47 = None + add_92 = torch.ops.aten.add.Scalar(mean_46, 1e-05); mean_46 = None + rsqrt_46 = torch.ops.aten.rsqrt.default(add_92); add_92 = None + mul_184 = torch.ops.aten.mul.Tensor(convert_element_type_761, rsqrt_46); convert_element_type_761 = None + mul_185 = torch.ops.aten.mul.Tensor(mul_184, wait_tensor_208) + convert_element_type_762 = torch.ops.prims.convert_element_type.default(mul_185, torch.bfloat16); mul_185 = None + view_785 = torch.ops.aten.view.default(convert_element_type_762, [16384, 4096]); convert_element_type_762 = None + view_786 = torch.ops.aten.view.default(mm_161, [2, 8192, 4096]); mm_161 = None + convert_element_type_766 = torch.ops.prims.convert_element_type.default(primals_213, torch.bfloat16); primals_213 = None + all_gather_into_tensor_210 = 
torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_766, 256, '0'); convert_element_type_766 = None + wait_tensor_210 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_210); all_gather_into_tensor_210 = None + permute_254 = torch.ops.aten.permute.default(wait_tensor_210, [1, 0]); wait_tensor_210 = None + mm_162 = torch.ops.aten.mm.default(view_785, permute_254) + view_789 = torch.ops.aten.view.default(mm_162, [2, 8192, 1024]); mm_162 = None + view_792 = torch.ops.aten.view.default(mm_163, [2, 8192, 1024]); mm_163 = None + view_793 = torch.ops.aten.view.default(view_786, [2, 8192, -1, 128]); view_786 = None + view_794 = torch.ops.aten.view.default(view_789, [2, 8192, -1, 128]); view_789 = None + view_795 = torch.ops.aten.view.default(view_792, [2, 8192, -1, 128]); view_792 = None + convert_element_type_772 = torch.ops.prims.convert_element_type.default(view_793, torch.float32); view_793 = None + view_796 = torch.ops.aten.view.default(convert_element_type_772, [2, 8192, 32, -1, 2]); convert_element_type_772 = None + view_as_complex_46 = torch.ops.aten.view_as_complex.default(view_796); view_796 = None + convert_element_type_773 = torch.ops.prims.convert_element_type.default(view_794, torch.float32); view_794 = None + view_797 = torch.ops.aten.view.default(convert_element_type_773, [2, 8192, 8, -1, 2]); convert_element_type_773 = None + view_as_complex_47 = torch.ops.aten.view_as_complex.default(view_797); view_797 = None + mul_186 = torch.ops.aten.mul.Tensor(view_as_complex_46, view_16); view_as_complex_46 = None + view_as_real_46 = torch.ops.aten.view_as_real.default(mul_186); mul_186 = None + view_799 = torch.ops.aten.view.default(view_as_real_46, [2, 8192, 32, 128]); view_as_real_46 = None + mul_187 = torch.ops.aten.mul.Tensor(view_as_complex_47, view_16); view_as_complex_47 = None + view_as_real_47 = torch.ops.aten.view_as_real.default(mul_187); mul_187 = None + view_800 = 
torch.ops.aten.view.default(view_as_real_47, [2, 8192, 8, 128]); view_as_real_47 = None + convert_element_type_774 = torch.ops.prims.convert_element_type.default(view_799, torch.bfloat16); view_799 = None + convert_element_type_775 = torch.ops.prims.convert_element_type.default(view_800, torch.bfloat16); view_800 = None + unsqueeze_46 = torch.ops.aten.unsqueeze.default(convert_element_type_775, 3); convert_element_type_775 = None + expand_46 = torch.ops.aten.expand.default(unsqueeze_46, [2, 8192, 8, 4, 128]); unsqueeze_46 = None + clone_46 = torch.ops.aten.clone.default(expand_46, memory_format = torch.contiguous_format); expand_46 = None + view_801 = torch.ops.aten.view.default(clone_46, [2, 8192, 32, 128]); clone_46 = None + unsqueeze_47 = torch.ops.aten.unsqueeze.default(view_795, 3); view_795 = None + expand_47 = torch.ops.aten.expand.default(unsqueeze_47, [2, 8192, 8, 4, 128]); unsqueeze_47 = None + clone_47 = torch.ops.aten.clone.default(expand_47, memory_format = torch.contiguous_format); expand_47 = None + view_802 = torch.ops.aten.view.default(clone_47, [2, 8192, 32, 128]); clone_47 = None + permute_256 = torch.ops.aten.permute.default(convert_element_type_774, [0, 2, 1, 3]); convert_element_type_774 = None + permute_257 = torch.ops.aten.permute.default(view_801, [0, 2, 1, 3]); view_801 = None + permute_258 = torch.ops.aten.permute.default(view_802, [0, 2, 1, 3]); view_802 = None + _scaled_dot_product_cudnn_attention_backward_8 = torch.ops.aten._scaled_dot_product_cudnn_attention_backward.default(permute_629, permute_256, permute_257, permute_258, getitem_207, getitem_208, getitem_213, getitem_214, None, None, None, 8192, 8192, 0.0, True); permute_629 = permute_256 = permute_257 = permute_258 = getitem_207 = getitem_208 = getitem_213 = getitem_214 = None + getitem_312 = _scaled_dot_product_cudnn_attention_backward_8[0] + getitem_313 = _scaled_dot_product_cudnn_attention_backward_8[1] + getitem_314 = _scaled_dot_product_cudnn_attention_backward_8[2]; 
_scaled_dot_product_cudnn_attention_backward_8 = None + permute_630 = torch.ops.aten.permute.default(getitem_314, [0, 2, 1, 3]); getitem_314 = None + permute_631 = torch.ops.aten.permute.default(getitem_313, [0, 2, 1, 3]); getitem_313 = None + permute_632 = torch.ops.aten.permute.default(getitem_312, [0, 2, 1, 3]); getitem_312 = None + view_1296 = torch.ops.aten.view.default(permute_630, [2, 8192, 8, 4, 128]); permute_630 = None + sum_53 = torch.ops.aten.sum.dim_IntList(view_1296, [3], True); view_1296 = None + squeeze_16 = torch.ops.aten.squeeze.dim(sum_53, 3); sum_53 = None + view_1297 = torch.ops.aten.view.default(permute_631, [2, 8192, 8, 4, 128]); permute_631 = None + sum_54 = torch.ops.aten.sum.dim_IntList(view_1297, [3], True); view_1297 = None + squeeze_17 = torch.ops.aten.squeeze.dim(sum_54, 3); sum_54 = None + convert_element_type_1535 = torch.ops.prims.convert_element_type.default(squeeze_17, torch.float32); squeeze_17 = None + convert_element_type_1536 = torch.ops.prims.convert_element_type.default(permute_632, torch.float32); permute_632 = None + view_1298 = torch.ops.aten.view.default(convert_element_type_1535, [2, 8192, 8, 64, 2]); convert_element_type_1535 = None + view_as_complex_80 = torch.ops.aten.view_as_complex.default(view_1298); view_1298 = None + mul_436 = torch.ops.aten.mul.Tensor(view_as_complex_80, _conj); view_as_complex_80 = None + view_1299 = torch.ops.aten.view.default(convert_element_type_1536, [2, 8192, 32, 64, 2]); convert_element_type_1536 = None + view_as_complex_81 = torch.ops.aten.view_as_complex.default(view_1299); view_1299 = None + mul_437 = torch.ops.aten.mul.Tensor(view_as_complex_81, _conj); view_as_complex_81 = None + view_as_real_80 = torch.ops.aten.view_as_real.default(mul_436); mul_436 = None + view_1300 = torch.ops.aten.view.default(view_as_real_80, [2, 8192, 8, 128]); view_as_real_80 = None + convert_element_type_1537 = torch.ops.prims.convert_element_type.default(view_1300, torch.bfloat16); view_1300 = None + 
view_as_real_81 = torch.ops.aten.view_as_real.default(mul_437); mul_437 = None + view_1301 = torch.ops.aten.view.default(view_as_real_81, [2, 8192, 32, 128]); view_as_real_81 = None + convert_element_type_1538 = torch.ops.prims.convert_element_type.default(view_1301, torch.bfloat16); view_1301 = None + view_1302 = torch.ops.aten.view.default(squeeze_16, [2, 8192, 1024]); squeeze_16 = None + view_1303 = torch.ops.aten.view.default(convert_element_type_1537, [2, 8192, 1024]); convert_element_type_1537 = None + view_1304 = torch.ops.aten.view.default(convert_element_type_1538, [2, 8192, 4096]); convert_element_type_1538 = None + view_1305 = torch.ops.aten.view.default(view_1302, [16384, 1024]); view_1302 = None + permute_633 = torch.ops.aten.permute.default(view_1305, [1, 0]) + mm_347 = torch.ops.aten.mm.default(permute_633, view_785); permute_633 = None + convert_element_type_769 = torch.ops.prims.convert_element_type.default(primals_214, torch.bfloat16); primals_214 = None + all_gather_into_tensor_211 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_769, 256, '0'); convert_element_type_769 = None + wait_tensor_211 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_211); all_gather_into_tensor_211 = None + permute_255 = torch.ops.aten.permute.default(wait_tensor_211, [1, 0]); wait_tensor_211 = None + permute_635 = torch.ops.aten.permute.default(permute_255, [1, 0]); permute_255 = None + mm_348 = torch.ops.aten.mm.default(view_1305, permute_635); view_1305 = permute_635 = None + view_1306 = torch.ops.aten.view.default(mm_348, [2, 8192, 4096]); mm_348 = None + convert_element_type_1543 = torch.ops.prims.convert_element_type.default(mm_347, torch.float32); mm_347 = None + reduce_scatter_tensor_79 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1543, 'avg', 256, '0'); convert_element_type_1543 = None + wait_tensor_370 = 
torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_79); reduce_scatter_tensor_79 = None + view_1307 = torch.ops.aten.view.default(view_1303, [16384, 1024]); view_1303 = None + permute_637 = torch.ops.aten.permute.default(view_1307, [1, 0]) + mm_349 = torch.ops.aten.mm.default(permute_637, view_785); permute_637 = None + permute_639 = torch.ops.aten.permute.default(permute_254, [1, 0]); permute_254 = None + mm_350 = torch.ops.aten.mm.default(view_1307, permute_639); view_1307 = permute_639 = None + view_1308 = torch.ops.aten.view.default(mm_350, [2, 8192, 4096]); mm_350 = None + add_189 = torch.ops.aten.add.Tensor(view_1306, view_1308); view_1306 = view_1308 = None + convert_element_type_1548 = torch.ops.prims.convert_element_type.default(mm_349, torch.float32); mm_349 = None + reduce_scatter_tensor_80 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1548, 'avg', 256, '0'); convert_element_type_1548 = None + wait_tensor_371 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_80); reduce_scatter_tensor_80 = None + view_1309 = torch.ops.aten.view.default(view_1304, [16384, 4096]); view_1304 = None + permute_641 = torch.ops.aten.permute.default(view_1309, [1, 0]) + mm_351 = torch.ops.aten.mm.default(permute_641, view_785); permute_641 = view_785 = None + convert_element_type_763 = torch.ops.prims.convert_element_type.default(primals_212, torch.bfloat16); primals_212 = None + all_gather_into_tensor_209 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_763, 256, '0'); convert_element_type_763 = None + wait_tensor_209 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_209); all_gather_into_tensor_209 = None + permute_253 = torch.ops.aten.permute.default(wait_tensor_209, [1, 0]); wait_tensor_209 = None + permute_643 = torch.ops.aten.permute.default(permute_253, [1, 0]); permute_253 = None + mm_352 = torch.ops.aten.mm.default(view_1309, permute_643); 
view_1309 = permute_643 = None + view_1310 = torch.ops.aten.view.default(mm_352, [2, 8192, 4096]); mm_352 = None + add_190 = torch.ops.aten.add.Tensor(add_189, view_1310); add_189 = view_1310 = None + convert_element_type_1553 = torch.ops.prims.convert_element_type.default(mm_351, torch.float32); mm_351 = None + reduce_scatter_tensor_81 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1553, 'avg', 256, '0'); convert_element_type_1553 = None + wait_tensor_372 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_81); reduce_scatter_tensor_81 = None + convert_element_type_1554 = torch.ops.prims.convert_element_type.default(add_190, torch.float32); add_190 = None + convert_element_type_1556 = torch.ops.prims.convert_element_type.default(wait_tensor_208, torch.float32); wait_tensor_208 = None + mul_438 = torch.ops.aten.mul.Tensor(convert_element_type_1554, convert_element_type_1556); convert_element_type_1556 = None + mul_440 = torch.ops.aten.mul.Tensor(mul_184, mul_438) + sum_55 = torch.ops.aten.sum.dim_IntList(mul_440, [2], True); mul_440 = None + div_18 = torch.ops.aten.div.Tensor(mul_184, 4096) + mul_441 = torch.ops.aten.mul.Tensor(div_18, sum_55); div_18 = sum_55 = None + sub_27 = torch.ops.aten.sub.Tensor(mul_438, mul_441); mul_438 = mul_441 = None + mul_442 = torch.ops.aten.mul.Tensor(sub_27, rsqrt_46); sub_27 = rsqrt_46 = None + mul_443 = torch.ops.aten.mul.Tensor(convert_element_type_1554, mul_184); convert_element_type_1554 = mul_184 = None + sum_56 = torch.ops.aten.sum.dim_IntList(mul_443, [0, 1]); mul_443 = None + convert_element_type_1557 = torch.ops.prims.convert_element_type.default(mul_442, torch.bfloat16); mul_442 = None + add_191 = torch.ops.aten.add.Tensor(add_188, convert_element_type_1557); add_188 = convert_element_type_1557 = None + convert_element_type_default_47 = torch.ops.prims.convert_element_type.default(sum_56, torch.float32); sum_56 = None + reduce_scatter_tensor_82 = 
torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_47, 'avg', 256, '0'); convert_element_type_default_47 = None + wait_tensor_373 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_82); reduce_scatter_tensor_82 = None + view_1311 = torch.ops.aten.view.default(add_191, [16384, 4096]) + permute_645 = torch.ops.aten.permute.default(view_1311, [1, 0]) + permute_248 = torch.ops.aten.permute.default(getitem_198, [0, 2, 1, 3]) + view_769 = torch.ops.aten.view.default(permute_248, [2, 8192, -1]); permute_248 = None + convert_element_type_743 = torch.ops.prims.convert_element_type.default(primals_206, torch.bfloat16); primals_206 = None + all_gather_into_tensor_203 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_743, 256, '0'); convert_element_type_743 = None + wait_tensor_203 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_203); all_gather_into_tensor_203 = None + permute_249 = torch.ops.aten.permute.default(wait_tensor_203, [1, 0]); wait_tensor_203 = None + view_771 = torch.ops.aten.view.default(view_769, [16384, 4096]); view_769 = None + mm_157 = torch.ops.aten.mm.default(view_771, permute_249) + view_772 = torch.ops.aten.view.default(mm_157, [2, 8192, 4096]); mm_157 = None + add_89 = torch.ops.aten.add.Tensor(add_87, view_772); view_772 = None + convert_element_type_746 = torch.ops.prims.convert_element_type.default(primals_207, torch.bfloat16); primals_207 = None + all_gather_into_tensor_204 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_746, 256, '0'); convert_element_type_746 = None + wait_tensor_204 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_204); all_gather_into_tensor_204 = None + convert_element_type_747 = torch.ops.prims.convert_element_type.default(add_89, torch.float32); add_89 = None + pow_46 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_747, 2) + mean_45 = 
torch.ops.aten.mean.dim(pow_46, [2], True); pow_46 = None + add_90 = torch.ops.aten.add.Scalar(mean_45, 1e-05); mean_45 = None + rsqrt_45 = torch.ops.aten.rsqrt.default(add_90); add_90 = None + mul_180 = torch.ops.aten.mul.Tensor(convert_element_type_747, rsqrt_45); convert_element_type_747 = None + mul_181 = torch.ops.aten.mul.Tensor(mul_180, wait_tensor_204) + convert_element_type_748 = torch.ops.prims.convert_element_type.default(mul_181, torch.bfloat16); mul_181 = None + view_775 = torch.ops.aten.view.default(convert_element_type_748, [16384, 4096]); convert_element_type_748 = None + view_776 = torch.ops.aten.view.default(mm_158, [2, 8192, 14336]); mm_158 = None + convert_element_type_752 = torch.ops.prims.convert_element_type.default(view_776, torch.float32); view_776 = None + sigmoid_22 = torch.ops.aten.sigmoid.default(convert_element_type_752) + mul_182 = torch.ops.aten.mul.Tensor(convert_element_type_752, sigmoid_22); sigmoid_22 = None + convert_element_type_753 = torch.ops.prims.convert_element_type.default(mul_182, torch.bfloat16); mul_182 = None + convert_element_type_754 = torch.ops.prims.convert_element_type.default(primals_209, torch.bfloat16); primals_209 = None + all_gather_into_tensor_206 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_754, 256, '0'); convert_element_type_754 = None + wait_tensor_206 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_206); all_gather_into_tensor_206 = None + permute_251 = torch.ops.aten.permute.default(wait_tensor_206, [1, 0]); wait_tensor_206 = None + mm_159 = torch.ops.aten.mm.default(view_775, permute_251) + view_779 = torch.ops.aten.view.default(mm_159, [2, 8192, 14336]); mm_159 = None + mul_183 = torch.ops.aten.mul.Tensor(convert_element_type_753, view_779) + view_781 = torch.ops.aten.view.default(mul_183, [16384, 14336]); mul_183 = None + mm_353 = torch.ops.aten.mm.default(permute_645, view_781); permute_645 = view_781 = None + convert_element_type_757 = 
torch.ops.prims.convert_element_type.default(primals_210, torch.bfloat16); primals_210 = None + all_gather_into_tensor_207 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_757, 256, '0'); convert_element_type_757 = None + wait_tensor_207 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_207); all_gather_into_tensor_207 = None + permute_252 = torch.ops.aten.permute.default(wait_tensor_207, [1, 0]); wait_tensor_207 = None + permute_647 = torch.ops.aten.permute.default(permute_252, [1, 0]); permute_252 = None + mm_354 = torch.ops.aten.mm.default(view_1311, permute_647); view_1311 = permute_647 = None + view_1312 = torch.ops.aten.view.default(mm_354, [2, 8192, 14336]); mm_354 = None + convert_element_type_1564 = torch.ops.prims.convert_element_type.default(mm_353, torch.float32); mm_353 = None + reduce_scatter_tensor_83 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1564, 'avg', 256, '0'); convert_element_type_1564 = None + wait_tensor_374 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_83); reduce_scatter_tensor_83 = None + mul_444 = torch.ops.aten.mul.Tensor(view_1312, convert_element_type_753); convert_element_type_753 = None + mul_445 = torch.ops.aten.mul.Tensor(view_1312, view_779); view_1312 = view_779 = None + view_1313 = torch.ops.aten.view.default(mul_444, [16384, 14336]); mul_444 = None + permute_649 = torch.ops.aten.permute.default(view_1313, [1, 0]) + mm_355 = torch.ops.aten.mm.default(permute_649, view_775); permute_649 = None + permute_651 = torch.ops.aten.permute.default(permute_251, [1, 0]); permute_251 = None + mm_356 = torch.ops.aten.mm.default(view_1313, permute_651); view_1313 = permute_651 = None + view_1314 = torch.ops.aten.view.default(mm_356, [2, 8192, 4096]); mm_356 = None + convert_element_type_1569 = torch.ops.prims.convert_element_type.default(mm_355, torch.float32); mm_355 = None + reduce_scatter_tensor_84 = 
torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1569, 'avg', 256, '0'); convert_element_type_1569 = None + wait_tensor_375 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_84); reduce_scatter_tensor_84 = None + convert_element_type_1570 = torch.ops.prims.convert_element_type.default(mul_445, torch.float32); mul_445 = None + neg_9 = torch.ops.aten.neg.default(convert_element_type_752) + exp_9 = torch.ops.aten.exp.default(neg_9); neg_9 = None + add_192 = torch.ops.aten.add.Tensor(exp_9, 1); exp_9 = None + reciprocal_9 = torch.ops.aten.reciprocal.default(add_192); add_192 = None + mul_446 = torch.ops.aten.mul.Tensor(reciprocal_9, 1); reciprocal_9 = None + mul_447 = torch.ops.aten.mul.Tensor(convert_element_type_1570, mul_446); convert_element_type_1570 = None + sub_28 = torch.ops.aten.sub.Tensor(1, mul_446); mul_446 = None + mul_448 = torch.ops.aten.mul.Tensor(convert_element_type_752, sub_28); convert_element_type_752 = sub_28 = None + add_193 = torch.ops.aten.add.Tensor(mul_448, 1); mul_448 = None + mul_449 = torch.ops.aten.mul.Tensor(mul_447, add_193); mul_447 = add_193 = None + convert_element_type_1572 = torch.ops.prims.convert_element_type.default(mul_449, torch.bfloat16); mul_449 = None + view_1315 = torch.ops.aten.view.default(convert_element_type_1572, [16384, 14336]); convert_element_type_1572 = None + permute_653 = torch.ops.aten.permute.default(view_1315, [1, 0]) + mm_357 = torch.ops.aten.mm.default(permute_653, view_775); permute_653 = view_775 = None + convert_element_type_749 = torch.ops.prims.convert_element_type.default(primals_208, torch.bfloat16); primals_208 = None + all_gather_into_tensor_205 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_749, 256, '0'); convert_element_type_749 = None + wait_tensor_205 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_205); all_gather_into_tensor_205 = None + permute_250 = 
torch.ops.aten.permute.default(wait_tensor_205, [1, 0]); wait_tensor_205 = None + permute_655 = torch.ops.aten.permute.default(permute_250, [1, 0]); permute_250 = None + mm_358 = torch.ops.aten.mm.default(view_1315, permute_655); view_1315 = permute_655 = None + view_1316 = torch.ops.aten.view.default(mm_358, [2, 8192, 4096]); mm_358 = None + add_194 = torch.ops.aten.add.Tensor(view_1314, view_1316); view_1314 = view_1316 = None + convert_element_type_1577 = torch.ops.prims.convert_element_type.default(mm_357, torch.float32); mm_357 = None + reduce_scatter_tensor_85 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1577, 'avg', 256, '0'); convert_element_type_1577 = None + wait_tensor_376 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_85); reduce_scatter_tensor_85 = None + convert_element_type_1578 = torch.ops.prims.convert_element_type.default(add_194, torch.float32); add_194 = None + convert_element_type_1580 = torch.ops.prims.convert_element_type.default(wait_tensor_204, torch.float32); wait_tensor_204 = None + mul_450 = torch.ops.aten.mul.Tensor(convert_element_type_1578, convert_element_type_1580); convert_element_type_1580 = None + mul_452 = torch.ops.aten.mul.Tensor(mul_180, mul_450) + sum_57 = torch.ops.aten.sum.dim_IntList(mul_452, [2], True); mul_452 = None + div_19 = torch.ops.aten.div.Tensor(mul_180, 4096) + mul_453 = torch.ops.aten.mul.Tensor(div_19, sum_57); div_19 = sum_57 = None + sub_29 = torch.ops.aten.sub.Tensor(mul_450, mul_453); mul_450 = mul_453 = None + mul_454 = torch.ops.aten.mul.Tensor(sub_29, rsqrt_45); sub_29 = rsqrt_45 = None + mul_455 = torch.ops.aten.mul.Tensor(convert_element_type_1578, mul_180); convert_element_type_1578 = mul_180 = None + sum_58 = torch.ops.aten.sum.dim_IntList(mul_455, [0, 1]); mul_455 = None + convert_element_type_1581 = torch.ops.prims.convert_element_type.default(mul_454, torch.bfloat16); mul_454 = None + add_195 = torch.ops.aten.add.Tensor(add_191, 
convert_element_type_1581); add_191 = convert_element_type_1581 = None + convert_element_type_default_46 = torch.ops.prims.convert_element_type.default(sum_58, torch.float32); sum_58 = None + reduce_scatter_tensor_86 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_46, 'avg', 256, '0'); convert_element_type_default_46 = None + wait_tensor_377 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_86); reduce_scatter_tensor_86 = None + view_1317 = torch.ops.aten.view.default(add_195, [16384, 4096]) + permute_657 = torch.ops.aten.permute.default(view_1317, [1, 0]) + mm_359 = torch.ops.aten.mm.default(permute_657, view_771); permute_657 = view_771 = None + permute_659 = torch.ops.aten.permute.default(permute_249, [1, 0]); permute_249 = None + mm_360 = torch.ops.aten.mm.default(view_1317, permute_659); view_1317 = permute_659 = None + view_1318 = torch.ops.aten.view.default(mm_360, [2, 8192, 4096]); mm_360 = None + convert_element_type_1588 = torch.ops.prims.convert_element_type.default(mm_359, torch.float32); mm_359 = None + reduce_scatter_tensor_87 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1588, 'avg', 256, '0'); convert_element_type_1588 = None + wait_tensor_378 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_87); reduce_scatter_tensor_87 = None + view_1319 = torch.ops.aten.view.default(view_1318, [2, 8192, 32, 128]); view_1318 = None + permute_661 = torch.ops.aten.permute.default(view_1319, [0, 2, 1, 3]); view_1319 = None + convert_element_type_727 = torch.ops.prims.convert_element_type.default(primals_202, torch.bfloat16); primals_202 = None + all_gather_into_tensor_199 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_727, 256, '0'); convert_element_type_727 = None + wait_tensor_199 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_199); all_gather_into_tensor_199 = None + convert_element_type_728 
= torch.ops.prims.convert_element_type.default(add_87, torch.float32); add_87 = None + pow_45 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_728, 2) + mean_44 = torch.ops.aten.mean.dim(pow_45, [2], True); pow_45 = None + add_88 = torch.ops.aten.add.Scalar(mean_44, 1e-05); mean_44 = None + rsqrt_44 = torch.ops.aten.rsqrt.default(add_88); add_88 = None + mul_176 = torch.ops.aten.mul.Tensor(convert_element_type_728, rsqrt_44); convert_element_type_728 = None + mul_177 = torch.ops.aten.mul.Tensor(mul_176, wait_tensor_199) + convert_element_type_729 = torch.ops.prims.convert_element_type.default(mul_177, torch.bfloat16); mul_177 = None + view_751 = torch.ops.aten.view.default(convert_element_type_729, [16384, 4096]); convert_element_type_729 = None + view_752 = torch.ops.aten.view.default(mm_154, [2, 8192, 4096]); mm_154 = None + convert_element_type_733 = torch.ops.prims.convert_element_type.default(primals_204, torch.bfloat16); primals_204 = None + all_gather_into_tensor_201 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_733, 256, '0'); convert_element_type_733 = None + wait_tensor_201 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_201); all_gather_into_tensor_201 = None + permute_243 = torch.ops.aten.permute.default(wait_tensor_201, [1, 0]); wait_tensor_201 = None + mm_155 = torch.ops.aten.mm.default(view_751, permute_243) + view_755 = torch.ops.aten.view.default(mm_155, [2, 8192, 1024]); mm_155 = None + view_758 = torch.ops.aten.view.default(mm_156, [2, 8192, 1024]); mm_156 = None + view_759 = torch.ops.aten.view.default(view_752, [2, 8192, -1, 128]); view_752 = None + view_760 = torch.ops.aten.view.default(view_755, [2, 8192, -1, 128]); view_755 = None + view_761 = torch.ops.aten.view.default(view_758, [2, 8192, -1, 128]); view_758 = None + convert_element_type_739 = torch.ops.prims.convert_element_type.default(view_759, torch.float32); view_759 = None + view_762 = 
torch.ops.aten.view.default(convert_element_type_739, [2, 8192, 32, -1, 2]); convert_element_type_739 = None + view_as_complex_44 = torch.ops.aten.view_as_complex.default(view_762); view_762 = None + convert_element_type_740 = torch.ops.prims.convert_element_type.default(view_760, torch.float32); view_760 = None + view_763 = torch.ops.aten.view.default(convert_element_type_740, [2, 8192, 8, -1, 2]); convert_element_type_740 = None + view_as_complex_45 = torch.ops.aten.view_as_complex.default(view_763); view_763 = None + mul_178 = torch.ops.aten.mul.Tensor(view_as_complex_44, view_16); view_as_complex_44 = None + view_as_real_44 = torch.ops.aten.view_as_real.default(mul_178); mul_178 = None + view_765 = torch.ops.aten.view.default(view_as_real_44, [2, 8192, 32, 128]); view_as_real_44 = None + mul_179 = torch.ops.aten.mul.Tensor(view_as_complex_45, view_16); view_as_complex_45 = None + view_as_real_45 = torch.ops.aten.view_as_real.default(mul_179); mul_179 = None + view_766 = torch.ops.aten.view.default(view_as_real_45, [2, 8192, 8, 128]); view_as_real_45 = None + convert_element_type_741 = torch.ops.prims.convert_element_type.default(view_765, torch.bfloat16); view_765 = None + convert_element_type_742 = torch.ops.prims.convert_element_type.default(view_766, torch.bfloat16); view_766 = None + unsqueeze_44 = torch.ops.aten.unsqueeze.default(convert_element_type_742, 3); convert_element_type_742 = None + expand_44 = torch.ops.aten.expand.default(unsqueeze_44, [2, 8192, 8, 4, 128]); unsqueeze_44 = None + clone_44 = torch.ops.aten.clone.default(expand_44, memory_format = torch.contiguous_format); expand_44 = None + view_767 = torch.ops.aten.view.default(clone_44, [2, 8192, 32, 128]); clone_44 = None + unsqueeze_45 = torch.ops.aten.unsqueeze.default(view_761, 3); view_761 = None + expand_45 = torch.ops.aten.expand.default(unsqueeze_45, [2, 8192, 8, 4, 128]); unsqueeze_45 = None + clone_45 = torch.ops.aten.clone.default(expand_45, memory_format = torch.contiguous_format); 
expand_45 = None + view_768 = torch.ops.aten.view.default(clone_45, [2, 8192, 32, 128]); clone_45 = None + permute_245 = torch.ops.aten.permute.default(convert_element_type_741, [0, 2, 1, 3]); convert_element_type_741 = None + permute_246 = torch.ops.aten.permute.default(view_767, [0, 2, 1, 3]); view_767 = None + permute_247 = torch.ops.aten.permute.default(view_768, [0, 2, 1, 3]); view_768 = None + _scaled_dot_product_cudnn_attention_backward_9 = torch.ops.aten._scaled_dot_product_cudnn_attention_backward.default(permute_661, permute_245, permute_246, permute_247, getitem_198, getitem_199, getitem_204, getitem_205, None, None, None, 8192, 8192, 0.0, True); permute_661 = permute_245 = permute_246 = permute_247 = getitem_198 = getitem_199 = getitem_204 = getitem_205 = None + getitem_315 = _scaled_dot_product_cudnn_attention_backward_9[0] + getitem_316 = _scaled_dot_product_cudnn_attention_backward_9[1] + getitem_317 = _scaled_dot_product_cudnn_attention_backward_9[2]; _scaled_dot_product_cudnn_attention_backward_9 = None + permute_662 = torch.ops.aten.permute.default(getitem_317, [0, 2, 1, 3]); getitem_317 = None + permute_663 = torch.ops.aten.permute.default(getitem_316, [0, 2, 1, 3]); getitem_316 = None + permute_664 = torch.ops.aten.permute.default(getitem_315, [0, 2, 1, 3]); getitem_315 = None + view_1320 = torch.ops.aten.view.default(permute_662, [2, 8192, 8, 4, 128]); permute_662 = None + sum_59 = torch.ops.aten.sum.dim_IntList(view_1320, [3], True); view_1320 = None + squeeze_18 = torch.ops.aten.squeeze.dim(sum_59, 3); sum_59 = None + view_1321 = torch.ops.aten.view.default(permute_663, [2, 8192, 8, 4, 128]); permute_663 = None + sum_60 = torch.ops.aten.sum.dim_IntList(view_1321, [3], True); view_1321 = None + squeeze_19 = torch.ops.aten.squeeze.dim(sum_60, 3); sum_60 = None + convert_element_type_1589 = torch.ops.prims.convert_element_type.default(squeeze_19, torch.float32); squeeze_19 = None + convert_element_type_1590 = 
torch.ops.prims.convert_element_type.default(permute_664, torch.float32); permute_664 = None + view_1322 = torch.ops.aten.view.default(convert_element_type_1589, [2, 8192, 8, 64, 2]); convert_element_type_1589 = None + view_as_complex_82 = torch.ops.aten.view_as_complex.default(view_1322); view_1322 = None + mul_456 = torch.ops.aten.mul.Tensor(view_as_complex_82, _conj); view_as_complex_82 = None + view_1323 = torch.ops.aten.view.default(convert_element_type_1590, [2, 8192, 32, 64, 2]); convert_element_type_1590 = None + view_as_complex_83 = torch.ops.aten.view_as_complex.default(view_1323); view_1323 = None + mul_457 = torch.ops.aten.mul.Tensor(view_as_complex_83, _conj); view_as_complex_83 = None + view_as_real_82 = torch.ops.aten.view_as_real.default(mul_456); mul_456 = None + view_1324 = torch.ops.aten.view.default(view_as_real_82, [2, 8192, 8, 128]); view_as_real_82 = None + convert_element_type_1591 = torch.ops.prims.convert_element_type.default(view_1324, torch.bfloat16); view_1324 = None + view_as_real_83 = torch.ops.aten.view_as_real.default(mul_457); mul_457 = None + view_1325 = torch.ops.aten.view.default(view_as_real_83, [2, 8192, 32, 128]); view_as_real_83 = None + convert_element_type_1592 = torch.ops.prims.convert_element_type.default(view_1325, torch.bfloat16); view_1325 = None + view_1326 = torch.ops.aten.view.default(squeeze_18, [2, 8192, 1024]); squeeze_18 = None + view_1327 = torch.ops.aten.view.default(convert_element_type_1591, [2, 8192, 1024]); convert_element_type_1591 = None + view_1328 = torch.ops.aten.view.default(convert_element_type_1592, [2, 8192, 4096]); convert_element_type_1592 = None + view_1329 = torch.ops.aten.view.default(view_1326, [16384, 1024]); view_1326 = None + permute_665 = torch.ops.aten.permute.default(view_1329, [1, 0]) + mm_361 = torch.ops.aten.mm.default(permute_665, view_751); permute_665 = None + convert_element_type_736 = torch.ops.prims.convert_element_type.default(primals_205, torch.bfloat16); primals_205 = None 
+ all_gather_into_tensor_202 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_736, 256, '0'); convert_element_type_736 = None + wait_tensor_202 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_202); all_gather_into_tensor_202 = None + permute_244 = torch.ops.aten.permute.default(wait_tensor_202, [1, 0]); wait_tensor_202 = None + permute_667 = torch.ops.aten.permute.default(permute_244, [1, 0]); permute_244 = None + mm_362 = torch.ops.aten.mm.default(view_1329, permute_667); view_1329 = permute_667 = None + view_1330 = torch.ops.aten.view.default(mm_362, [2, 8192, 4096]); mm_362 = None + convert_element_type_1597 = torch.ops.prims.convert_element_type.default(mm_361, torch.float32); mm_361 = None + reduce_scatter_tensor_88 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1597, 'avg', 256, '0'); convert_element_type_1597 = None + wait_tensor_379 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_88); reduce_scatter_tensor_88 = None + view_1331 = torch.ops.aten.view.default(view_1327, [16384, 1024]); view_1327 = None + permute_669 = torch.ops.aten.permute.default(view_1331, [1, 0]) + mm_363 = torch.ops.aten.mm.default(permute_669, view_751); permute_669 = None + permute_671 = torch.ops.aten.permute.default(permute_243, [1, 0]); permute_243 = None + mm_364 = torch.ops.aten.mm.default(view_1331, permute_671); view_1331 = permute_671 = None + view_1332 = torch.ops.aten.view.default(mm_364, [2, 8192, 4096]); mm_364 = None + add_196 = torch.ops.aten.add.Tensor(view_1330, view_1332); view_1330 = view_1332 = None + convert_element_type_1602 = torch.ops.prims.convert_element_type.default(mm_363, torch.float32); mm_363 = None + reduce_scatter_tensor_89 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1602, 'avg', 256, '0'); convert_element_type_1602 = None + wait_tensor_380 = 
torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_89); reduce_scatter_tensor_89 = None + view_1333 = torch.ops.aten.view.default(view_1328, [16384, 4096]); view_1328 = None + permute_673 = torch.ops.aten.permute.default(view_1333, [1, 0]) + mm_365 = torch.ops.aten.mm.default(permute_673, view_751); permute_673 = view_751 = None + convert_element_type_730 = torch.ops.prims.convert_element_type.default(primals_203, torch.bfloat16); primals_203 = None + all_gather_into_tensor_200 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_730, 256, '0'); convert_element_type_730 = None + wait_tensor_200 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_200); all_gather_into_tensor_200 = None + permute_242 = torch.ops.aten.permute.default(wait_tensor_200, [1, 0]); wait_tensor_200 = None + permute_675 = torch.ops.aten.permute.default(permute_242, [1, 0]); permute_242 = None + mm_366 = torch.ops.aten.mm.default(view_1333, permute_675); view_1333 = permute_675 = None + view_1334 = torch.ops.aten.view.default(mm_366, [2, 8192, 4096]); mm_366 = None + add_197 = torch.ops.aten.add.Tensor(add_196, view_1334); add_196 = view_1334 = None + convert_element_type_1607 = torch.ops.prims.convert_element_type.default(mm_365, torch.float32); mm_365 = None + reduce_scatter_tensor_90 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1607, 'avg', 256, '0'); convert_element_type_1607 = None + wait_tensor_381 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_90); reduce_scatter_tensor_90 = None + convert_element_type_1608 = torch.ops.prims.convert_element_type.default(add_197, torch.float32); add_197 = None + convert_element_type_1610 = torch.ops.prims.convert_element_type.default(wait_tensor_199, torch.float32); wait_tensor_199 = None + mul_458 = torch.ops.aten.mul.Tensor(convert_element_type_1608, convert_element_type_1610); convert_element_type_1610 = None + mul_460 = 
torch.ops.aten.mul.Tensor(mul_176, mul_458) + sum_61 = torch.ops.aten.sum.dim_IntList(mul_460, [2], True); mul_460 = None + div_20 = torch.ops.aten.div.Tensor(mul_176, 4096) + mul_461 = torch.ops.aten.mul.Tensor(div_20, sum_61); div_20 = sum_61 = None + sub_30 = torch.ops.aten.sub.Tensor(mul_458, mul_461); mul_458 = mul_461 = None + mul_462 = torch.ops.aten.mul.Tensor(sub_30, rsqrt_44); sub_30 = rsqrt_44 = None + mul_463 = torch.ops.aten.mul.Tensor(convert_element_type_1608, mul_176); convert_element_type_1608 = mul_176 = None + sum_62 = torch.ops.aten.sum.dim_IntList(mul_463, [0, 1]); mul_463 = None + convert_element_type_1611 = torch.ops.prims.convert_element_type.default(mul_462, torch.bfloat16); mul_462 = None + add_198 = torch.ops.aten.add.Tensor(add_195, convert_element_type_1611); add_195 = convert_element_type_1611 = None + convert_element_type_default_45 = torch.ops.prims.convert_element_type.default(sum_62, torch.float32); sum_62 = None + reduce_scatter_tensor_91 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_45, 'avg', 256, '0'); convert_element_type_default_45 = None + wait_tensor_382 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_91); reduce_scatter_tensor_91 = None + view_1335 = torch.ops.aten.view.default(add_198, [16384, 4096]) + permute_677 = torch.ops.aten.permute.default(view_1335, [1, 0]) + permute_237 = torch.ops.aten.permute.default(getitem_189, [0, 2, 1, 3]) + view_735 = torch.ops.aten.view.default(permute_237, [2, 8192, -1]); permute_237 = None + convert_element_type_710 = torch.ops.prims.convert_element_type.default(primals_197, torch.bfloat16); primals_197 = None + all_gather_into_tensor_194 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_710, 256, '0'); convert_element_type_710 = None + wait_tensor_194 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_194); all_gather_into_tensor_194 = None + permute_238 = 
torch.ops.aten.permute.default(wait_tensor_194, [1, 0]); wait_tensor_194 = None + view_737 = torch.ops.aten.view.default(view_735, [16384, 4096]); view_735 = None + mm_150 = torch.ops.aten.mm.default(view_737, permute_238) + view_738 = torch.ops.aten.view.default(mm_150, [2, 8192, 4096]); mm_150 = None + add_85 = torch.ops.aten.add.Tensor(add_83, view_738); view_738 = None + convert_element_type_713 = torch.ops.prims.convert_element_type.default(primals_198, torch.bfloat16); primals_198 = None + all_gather_into_tensor_195 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_713, 256, '0'); convert_element_type_713 = None + wait_tensor_195 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_195); all_gather_into_tensor_195 = None + convert_element_type_714 = torch.ops.prims.convert_element_type.default(add_85, torch.float32); add_85 = None + pow_44 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_714, 2) + mean_43 = torch.ops.aten.mean.dim(pow_44, [2], True); pow_44 = None + add_86 = torch.ops.aten.add.Scalar(mean_43, 1e-05); mean_43 = None + rsqrt_43 = torch.ops.aten.rsqrt.default(add_86); add_86 = None + mul_172 = torch.ops.aten.mul.Tensor(convert_element_type_714, rsqrt_43); convert_element_type_714 = None + mul_173 = torch.ops.aten.mul.Tensor(mul_172, wait_tensor_195) + convert_element_type_715 = torch.ops.prims.convert_element_type.default(mul_173, torch.bfloat16); mul_173 = None + view_741 = torch.ops.aten.view.default(convert_element_type_715, [16384, 4096]); convert_element_type_715 = None + view_742 = torch.ops.aten.view.default(mm_151, [2, 8192, 14336]); mm_151 = None + convert_element_type_719 = torch.ops.prims.convert_element_type.default(view_742, torch.float32); view_742 = None + sigmoid_21 = torch.ops.aten.sigmoid.default(convert_element_type_719) + mul_174 = torch.ops.aten.mul.Tensor(convert_element_type_719, sigmoid_21); sigmoid_21 = None + convert_element_type_720 = 
torch.ops.prims.convert_element_type.default(mul_174, torch.bfloat16); mul_174 = None + convert_element_type_721 = torch.ops.prims.convert_element_type.default(primals_200, torch.bfloat16); primals_200 = None + all_gather_into_tensor_197 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_721, 256, '0'); convert_element_type_721 = None + wait_tensor_197 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_197); all_gather_into_tensor_197 = None + permute_240 = torch.ops.aten.permute.default(wait_tensor_197, [1, 0]); wait_tensor_197 = None + mm_152 = torch.ops.aten.mm.default(view_741, permute_240) + view_745 = torch.ops.aten.view.default(mm_152, [2, 8192, 14336]); mm_152 = None + mul_175 = torch.ops.aten.mul.Tensor(convert_element_type_720, view_745) + view_747 = torch.ops.aten.view.default(mul_175, [16384, 14336]); mul_175 = None + mm_367 = torch.ops.aten.mm.default(permute_677, view_747); permute_677 = view_747 = None + convert_element_type_724 = torch.ops.prims.convert_element_type.default(primals_201, torch.bfloat16); primals_201 = None + all_gather_into_tensor_198 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_724, 256, '0'); convert_element_type_724 = None + wait_tensor_198 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_198); all_gather_into_tensor_198 = None + permute_241 = torch.ops.aten.permute.default(wait_tensor_198, [1, 0]); wait_tensor_198 = None + permute_679 = torch.ops.aten.permute.default(permute_241, [1, 0]); permute_241 = None + mm_368 = torch.ops.aten.mm.default(view_1335, permute_679); view_1335 = permute_679 = None + view_1336 = torch.ops.aten.view.default(mm_368, [2, 8192, 14336]); mm_368 = None + convert_element_type_1618 = torch.ops.prims.convert_element_type.default(mm_367, torch.float32); mm_367 = None + reduce_scatter_tensor_92 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1618, 'avg', 256, '0'); 
convert_element_type_1618 = None + wait_tensor_383 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_92); reduce_scatter_tensor_92 = None + mul_464 = torch.ops.aten.mul.Tensor(view_1336, convert_element_type_720); convert_element_type_720 = None + mul_465 = torch.ops.aten.mul.Tensor(view_1336, view_745); view_1336 = view_745 = None + view_1337 = torch.ops.aten.view.default(mul_464, [16384, 14336]); mul_464 = None + permute_681 = torch.ops.aten.permute.default(view_1337, [1, 0]) + mm_369 = torch.ops.aten.mm.default(permute_681, view_741); permute_681 = None + permute_683 = torch.ops.aten.permute.default(permute_240, [1, 0]); permute_240 = None + mm_370 = torch.ops.aten.mm.default(view_1337, permute_683); view_1337 = permute_683 = None + view_1338 = torch.ops.aten.view.default(mm_370, [2, 8192, 4096]); mm_370 = None + convert_element_type_1623 = torch.ops.prims.convert_element_type.default(mm_369, torch.float32); mm_369 = None + reduce_scatter_tensor_93 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1623, 'avg', 256, '0'); convert_element_type_1623 = None + wait_tensor_384 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_93); reduce_scatter_tensor_93 = None + convert_element_type_1624 = torch.ops.prims.convert_element_type.default(mul_465, torch.float32); mul_465 = None + neg_10 = torch.ops.aten.neg.default(convert_element_type_719) + exp_10 = torch.ops.aten.exp.default(neg_10); neg_10 = None + add_199 = torch.ops.aten.add.Tensor(exp_10, 1); exp_10 = None + reciprocal_10 = torch.ops.aten.reciprocal.default(add_199); add_199 = None + mul_466 = torch.ops.aten.mul.Tensor(reciprocal_10, 1); reciprocal_10 = None + mul_467 = torch.ops.aten.mul.Tensor(convert_element_type_1624, mul_466); convert_element_type_1624 = None + sub_31 = torch.ops.aten.sub.Tensor(1, mul_466); mul_466 = None + mul_468 = torch.ops.aten.mul.Tensor(convert_element_type_719, sub_31); convert_element_type_719 = sub_31 = None + 
add_200 = torch.ops.aten.add.Tensor(mul_468, 1); mul_468 = None + mul_469 = torch.ops.aten.mul.Tensor(mul_467, add_200); mul_467 = add_200 = None + convert_element_type_1626 = torch.ops.prims.convert_element_type.default(mul_469, torch.bfloat16); mul_469 = None + view_1339 = torch.ops.aten.view.default(convert_element_type_1626, [16384, 14336]); convert_element_type_1626 = None + permute_685 = torch.ops.aten.permute.default(view_1339, [1, 0]) + mm_371 = torch.ops.aten.mm.default(permute_685, view_741); permute_685 = view_741 = None + convert_element_type_716 = torch.ops.prims.convert_element_type.default(primals_199, torch.bfloat16); primals_199 = None + all_gather_into_tensor_196 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_716, 256, '0'); convert_element_type_716 = None + wait_tensor_196 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_196); all_gather_into_tensor_196 = None + permute_239 = torch.ops.aten.permute.default(wait_tensor_196, [1, 0]); wait_tensor_196 = None + permute_687 = torch.ops.aten.permute.default(permute_239, [1, 0]); permute_239 = None + mm_372 = torch.ops.aten.mm.default(view_1339, permute_687); view_1339 = permute_687 = None + view_1340 = torch.ops.aten.view.default(mm_372, [2, 8192, 4096]); mm_372 = None + add_201 = torch.ops.aten.add.Tensor(view_1338, view_1340); view_1338 = view_1340 = None + convert_element_type_1631 = torch.ops.prims.convert_element_type.default(mm_371, torch.float32); mm_371 = None + reduce_scatter_tensor_94 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1631, 'avg', 256, '0'); convert_element_type_1631 = None + wait_tensor_385 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_94); reduce_scatter_tensor_94 = None + convert_element_type_1632 = torch.ops.prims.convert_element_type.default(add_201, torch.float32); add_201 = None + convert_element_type_1634 = 
torch.ops.prims.convert_element_type.default(wait_tensor_195, torch.float32); wait_tensor_195 = None + mul_470 = torch.ops.aten.mul.Tensor(convert_element_type_1632, convert_element_type_1634); convert_element_type_1634 = None + mul_472 = torch.ops.aten.mul.Tensor(mul_172, mul_470) + sum_63 = torch.ops.aten.sum.dim_IntList(mul_472, [2], True); mul_472 = None + div_21 = torch.ops.aten.div.Tensor(mul_172, 4096) + mul_473 = torch.ops.aten.mul.Tensor(div_21, sum_63); div_21 = sum_63 = None + sub_32 = torch.ops.aten.sub.Tensor(mul_470, mul_473); mul_470 = mul_473 = None + mul_474 = torch.ops.aten.mul.Tensor(sub_32, rsqrt_43); sub_32 = rsqrt_43 = None + mul_475 = torch.ops.aten.mul.Tensor(convert_element_type_1632, mul_172); convert_element_type_1632 = mul_172 = None + sum_64 = torch.ops.aten.sum.dim_IntList(mul_475, [0, 1]); mul_475 = None + convert_element_type_1635 = torch.ops.prims.convert_element_type.default(mul_474, torch.bfloat16); mul_474 = None + add_202 = torch.ops.aten.add.Tensor(add_198, convert_element_type_1635); add_198 = convert_element_type_1635 = None + convert_element_type_default_44 = torch.ops.prims.convert_element_type.default(sum_64, torch.float32); sum_64 = None + reduce_scatter_tensor_95 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_44, 'avg', 256, '0'); convert_element_type_default_44 = None + wait_tensor_386 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_95); reduce_scatter_tensor_95 = None + view_1341 = torch.ops.aten.view.default(add_202, [16384, 4096]) + permute_689 = torch.ops.aten.permute.default(view_1341, [1, 0]) + mm_373 = torch.ops.aten.mm.default(permute_689, view_737); permute_689 = view_737 = None + permute_691 = torch.ops.aten.permute.default(permute_238, [1, 0]); permute_238 = None + mm_374 = torch.ops.aten.mm.default(view_1341, permute_691); view_1341 = permute_691 = None + view_1342 = torch.ops.aten.view.default(mm_374, [2, 8192, 4096]); mm_374 = None + 
convert_element_type_1642 = torch.ops.prims.convert_element_type.default(mm_373, torch.float32); mm_373 = None + reduce_scatter_tensor_96 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1642, 'avg', 256, '0'); convert_element_type_1642 = None + wait_tensor_387 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_96); reduce_scatter_tensor_96 = None + view_1343 = torch.ops.aten.view.default(view_1342, [2, 8192, 32, 128]); view_1342 = None + permute_693 = torch.ops.aten.permute.default(view_1343, [0, 2, 1, 3]); view_1343 = None + convert_element_type_694 = torch.ops.prims.convert_element_type.default(primals_193, torch.bfloat16); primals_193 = None + all_gather_into_tensor_190 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_694, 256, '0'); convert_element_type_694 = None + wait_tensor_190 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_190); all_gather_into_tensor_190 = None + convert_element_type_695 = torch.ops.prims.convert_element_type.default(add_83, torch.float32); add_83 = None + pow_43 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_695, 2) + mean_42 = torch.ops.aten.mean.dim(pow_43, [2], True); pow_43 = None + add_84 = torch.ops.aten.add.Scalar(mean_42, 1e-05); mean_42 = None + rsqrt_42 = torch.ops.aten.rsqrt.default(add_84); add_84 = None + mul_168 = torch.ops.aten.mul.Tensor(convert_element_type_695, rsqrt_42); convert_element_type_695 = None + mul_169 = torch.ops.aten.mul.Tensor(mul_168, wait_tensor_190) + convert_element_type_696 = torch.ops.prims.convert_element_type.default(mul_169, torch.bfloat16); mul_169 = None + view_717 = torch.ops.aten.view.default(convert_element_type_696, [16384, 4096]); convert_element_type_696 = None + view_718 = torch.ops.aten.view.default(mm_147, [2, 8192, 4096]); mm_147 = None + convert_element_type_700 = torch.ops.prims.convert_element_type.default(primals_195, torch.bfloat16); primals_195 = None + 
all_gather_into_tensor_192 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_700, 256, '0'); convert_element_type_700 = None + wait_tensor_192 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_192); all_gather_into_tensor_192 = None + permute_232 = torch.ops.aten.permute.default(wait_tensor_192, [1, 0]); wait_tensor_192 = None + mm_148 = torch.ops.aten.mm.default(view_717, permute_232) + view_721 = torch.ops.aten.view.default(mm_148, [2, 8192, 1024]); mm_148 = None + view_724 = torch.ops.aten.view.default(mm_149, [2, 8192, 1024]); mm_149 = None + view_725 = torch.ops.aten.view.default(view_718, [2, 8192, -1, 128]); view_718 = None + view_726 = torch.ops.aten.view.default(view_721, [2, 8192, -1, 128]); view_721 = None + view_727 = torch.ops.aten.view.default(view_724, [2, 8192, -1, 128]); view_724 = None + convert_element_type_706 = torch.ops.prims.convert_element_type.default(view_725, torch.float32); view_725 = None + view_728 = torch.ops.aten.view.default(convert_element_type_706, [2, 8192, 32, -1, 2]); convert_element_type_706 = None + view_as_complex_42 = torch.ops.aten.view_as_complex.default(view_728); view_728 = None + convert_element_type_707 = torch.ops.prims.convert_element_type.default(view_726, torch.float32); view_726 = None + view_729 = torch.ops.aten.view.default(convert_element_type_707, [2, 8192, 8, -1, 2]); convert_element_type_707 = None + view_as_complex_43 = torch.ops.aten.view_as_complex.default(view_729); view_729 = None + mul_170 = torch.ops.aten.mul.Tensor(view_as_complex_42, view_16); view_as_complex_42 = None + view_as_real_42 = torch.ops.aten.view_as_real.default(mul_170); mul_170 = None + view_731 = torch.ops.aten.view.default(view_as_real_42, [2, 8192, 32, 128]); view_as_real_42 = None + mul_171 = torch.ops.aten.mul.Tensor(view_as_complex_43, view_16); view_as_complex_43 = None + view_as_real_43 = torch.ops.aten.view_as_real.default(mul_171); mul_171 = None + view_732 = 
torch.ops.aten.view.default(view_as_real_43, [2, 8192, 8, 128]); view_as_real_43 = None + convert_element_type_708 = torch.ops.prims.convert_element_type.default(view_731, torch.bfloat16); view_731 = None + convert_element_type_709 = torch.ops.prims.convert_element_type.default(view_732, torch.bfloat16); view_732 = None + unsqueeze_42 = torch.ops.aten.unsqueeze.default(convert_element_type_709, 3); convert_element_type_709 = None + expand_42 = torch.ops.aten.expand.default(unsqueeze_42, [2, 8192, 8, 4, 128]); unsqueeze_42 = None + clone_42 = torch.ops.aten.clone.default(expand_42, memory_format = torch.contiguous_format); expand_42 = None + view_733 = torch.ops.aten.view.default(clone_42, [2, 8192, 32, 128]); clone_42 = None + unsqueeze_43 = torch.ops.aten.unsqueeze.default(view_727, 3); view_727 = None + expand_43 = torch.ops.aten.expand.default(unsqueeze_43, [2, 8192, 8, 4, 128]); unsqueeze_43 = None + clone_43 = torch.ops.aten.clone.default(expand_43, memory_format = torch.contiguous_format); expand_43 = None + view_734 = torch.ops.aten.view.default(clone_43, [2, 8192, 32, 128]); clone_43 = None + permute_234 = torch.ops.aten.permute.default(convert_element_type_708, [0, 2, 1, 3]); convert_element_type_708 = None + permute_235 = torch.ops.aten.permute.default(view_733, [0, 2, 1, 3]); view_733 = None + permute_236 = torch.ops.aten.permute.default(view_734, [0, 2, 1, 3]); view_734 = None + _scaled_dot_product_cudnn_attention_backward_10 = torch.ops.aten._scaled_dot_product_cudnn_attention_backward.default(permute_693, permute_234, permute_235, permute_236, getitem_189, getitem_190, getitem_195, getitem_196, None, None, None, 8192, 8192, 0.0, True); permute_693 = permute_234 = permute_235 = permute_236 = getitem_189 = getitem_190 = getitem_195 = getitem_196 = None + getitem_318 = _scaled_dot_product_cudnn_attention_backward_10[0] + getitem_319 = _scaled_dot_product_cudnn_attention_backward_10[1] + getitem_320 = _scaled_dot_product_cudnn_attention_backward_10[2]; 
_scaled_dot_product_cudnn_attention_backward_10 = None + permute_694 = torch.ops.aten.permute.default(getitem_320, [0, 2, 1, 3]); getitem_320 = None + permute_695 = torch.ops.aten.permute.default(getitem_319, [0, 2, 1, 3]); getitem_319 = None + permute_696 = torch.ops.aten.permute.default(getitem_318, [0, 2, 1, 3]); getitem_318 = None + view_1344 = torch.ops.aten.view.default(permute_694, [2, 8192, 8, 4, 128]); permute_694 = None + sum_65 = torch.ops.aten.sum.dim_IntList(view_1344, [3], True); view_1344 = None + squeeze_20 = torch.ops.aten.squeeze.dim(sum_65, 3); sum_65 = None + view_1345 = torch.ops.aten.view.default(permute_695, [2, 8192, 8, 4, 128]); permute_695 = None + sum_66 = torch.ops.aten.sum.dim_IntList(view_1345, [3], True); view_1345 = None + squeeze_21 = torch.ops.aten.squeeze.dim(sum_66, 3); sum_66 = None + convert_element_type_1643 = torch.ops.prims.convert_element_type.default(squeeze_21, torch.float32); squeeze_21 = None + convert_element_type_1644 = torch.ops.prims.convert_element_type.default(permute_696, torch.float32); permute_696 = None + view_1346 = torch.ops.aten.view.default(convert_element_type_1643, [2, 8192, 8, 64, 2]); convert_element_type_1643 = None + view_as_complex_84 = torch.ops.aten.view_as_complex.default(view_1346); view_1346 = None + mul_476 = torch.ops.aten.mul.Tensor(view_as_complex_84, _conj); view_as_complex_84 = None + view_1347 = torch.ops.aten.view.default(convert_element_type_1644, [2, 8192, 32, 64, 2]); convert_element_type_1644 = None + view_as_complex_85 = torch.ops.aten.view_as_complex.default(view_1347); view_1347 = None + mul_477 = torch.ops.aten.mul.Tensor(view_as_complex_85, _conj); view_as_complex_85 = None + view_as_real_84 = torch.ops.aten.view_as_real.default(mul_476); mul_476 = None + view_1348 = torch.ops.aten.view.default(view_as_real_84, [2, 8192, 8, 128]); view_as_real_84 = None + convert_element_type_1645 = torch.ops.prims.convert_element_type.default(view_1348, torch.bfloat16); view_1348 = None + 
view_as_real_85 = torch.ops.aten.view_as_real.default(mul_477); mul_477 = None + view_1349 = torch.ops.aten.view.default(view_as_real_85, [2, 8192, 32, 128]); view_as_real_85 = None + convert_element_type_1646 = torch.ops.prims.convert_element_type.default(view_1349, torch.bfloat16); view_1349 = None + view_1350 = torch.ops.aten.view.default(squeeze_20, [2, 8192, 1024]); squeeze_20 = None + view_1351 = torch.ops.aten.view.default(convert_element_type_1645, [2, 8192, 1024]); convert_element_type_1645 = None + view_1352 = torch.ops.aten.view.default(convert_element_type_1646, [2, 8192, 4096]); convert_element_type_1646 = None + view_1353 = torch.ops.aten.view.default(view_1350, [16384, 1024]); view_1350 = None + permute_697 = torch.ops.aten.permute.default(view_1353, [1, 0]) + mm_375 = torch.ops.aten.mm.default(permute_697, view_717); permute_697 = None + convert_element_type_703 = torch.ops.prims.convert_element_type.default(primals_196, torch.bfloat16); primals_196 = None + all_gather_into_tensor_193 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_703, 256, '0'); convert_element_type_703 = None + wait_tensor_193 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_193); all_gather_into_tensor_193 = None + permute_233 = torch.ops.aten.permute.default(wait_tensor_193, [1, 0]); wait_tensor_193 = None + permute_699 = torch.ops.aten.permute.default(permute_233, [1, 0]); permute_233 = None + mm_376 = torch.ops.aten.mm.default(view_1353, permute_699); view_1353 = permute_699 = None + view_1354 = torch.ops.aten.view.default(mm_376, [2, 8192, 4096]); mm_376 = None + convert_element_type_1651 = torch.ops.prims.convert_element_type.default(mm_375, torch.float32); mm_375 = None + reduce_scatter_tensor_97 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1651, 'avg', 256, '0'); convert_element_type_1651 = None + wait_tensor_388 = 
torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_97); reduce_scatter_tensor_97 = None + view_1355 = torch.ops.aten.view.default(view_1351, [16384, 1024]); view_1351 = None + permute_701 = torch.ops.aten.permute.default(view_1355, [1, 0]) + mm_377 = torch.ops.aten.mm.default(permute_701, view_717); permute_701 = None + permute_703 = torch.ops.aten.permute.default(permute_232, [1, 0]); permute_232 = None + mm_378 = torch.ops.aten.mm.default(view_1355, permute_703); view_1355 = permute_703 = None + view_1356 = torch.ops.aten.view.default(mm_378, [2, 8192, 4096]); mm_378 = None + add_203 = torch.ops.aten.add.Tensor(view_1354, view_1356); view_1354 = view_1356 = None + convert_element_type_1656 = torch.ops.prims.convert_element_type.default(mm_377, torch.float32); mm_377 = None + reduce_scatter_tensor_98 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1656, 'avg', 256, '0'); convert_element_type_1656 = None + wait_tensor_389 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_98); reduce_scatter_tensor_98 = None + view_1357 = torch.ops.aten.view.default(view_1352, [16384, 4096]); view_1352 = None + permute_705 = torch.ops.aten.permute.default(view_1357, [1, 0]) + mm_379 = torch.ops.aten.mm.default(permute_705, view_717); permute_705 = view_717 = None + convert_element_type_697 = torch.ops.prims.convert_element_type.default(primals_194, torch.bfloat16); primals_194 = None + all_gather_into_tensor_191 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_697, 256, '0'); convert_element_type_697 = None + wait_tensor_191 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_191); all_gather_into_tensor_191 = None + permute_231 = torch.ops.aten.permute.default(wait_tensor_191, [1, 0]); wait_tensor_191 = None + permute_707 = torch.ops.aten.permute.default(permute_231, [1, 0]); permute_231 = None + mm_380 = torch.ops.aten.mm.default(view_1357, permute_707); 
view_1357 = permute_707 = None + view_1358 = torch.ops.aten.view.default(mm_380, [2, 8192, 4096]); mm_380 = None + add_204 = torch.ops.aten.add.Tensor(add_203, view_1358); add_203 = view_1358 = None + convert_element_type_1661 = torch.ops.prims.convert_element_type.default(mm_379, torch.float32); mm_379 = None + reduce_scatter_tensor_99 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1661, 'avg', 256, '0'); convert_element_type_1661 = None + wait_tensor_390 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_99); reduce_scatter_tensor_99 = None + convert_element_type_1662 = torch.ops.prims.convert_element_type.default(add_204, torch.float32); add_204 = None + convert_element_type_1664 = torch.ops.prims.convert_element_type.default(wait_tensor_190, torch.float32); wait_tensor_190 = None + mul_478 = torch.ops.aten.mul.Tensor(convert_element_type_1662, convert_element_type_1664); convert_element_type_1664 = None + mul_480 = torch.ops.aten.mul.Tensor(mul_168, mul_478) + sum_67 = torch.ops.aten.sum.dim_IntList(mul_480, [2], True); mul_480 = None + div_22 = torch.ops.aten.div.Tensor(mul_168, 4096) + mul_481 = torch.ops.aten.mul.Tensor(div_22, sum_67); div_22 = sum_67 = None + sub_33 = torch.ops.aten.sub.Tensor(mul_478, mul_481); mul_478 = mul_481 = None + mul_482 = torch.ops.aten.mul.Tensor(sub_33, rsqrt_42); sub_33 = rsqrt_42 = None + mul_483 = torch.ops.aten.mul.Tensor(convert_element_type_1662, mul_168); convert_element_type_1662 = mul_168 = None + sum_68 = torch.ops.aten.sum.dim_IntList(mul_483, [0, 1]); mul_483 = None + convert_element_type_1665 = torch.ops.prims.convert_element_type.default(mul_482, torch.bfloat16); mul_482 = None + add_205 = torch.ops.aten.add.Tensor(add_202, convert_element_type_1665); add_202 = convert_element_type_1665 = None + convert_element_type_default_43 = torch.ops.prims.convert_element_type.default(sum_68, torch.float32); sum_68 = None + reduce_scatter_tensor_100 = 
torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_43, 'avg', 256, '0'); convert_element_type_default_43 = None + wait_tensor_391 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_100); reduce_scatter_tensor_100 = None + view_1359 = torch.ops.aten.view.default(add_205, [16384, 4096]) + permute_709 = torch.ops.aten.permute.default(view_1359, [1, 0]) + permute_226 = torch.ops.aten.permute.default(getitem_180, [0, 2, 1, 3]) + view_701 = torch.ops.aten.view.default(permute_226, [2, 8192, -1]); permute_226 = None + convert_element_type_677 = torch.ops.prims.convert_element_type.default(primals_188, torch.bfloat16); primals_188 = None + all_gather_into_tensor_185 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_677, 256, '0'); convert_element_type_677 = None + wait_tensor_185 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_185); all_gather_into_tensor_185 = None + permute_227 = torch.ops.aten.permute.default(wait_tensor_185, [1, 0]); wait_tensor_185 = None + view_703 = torch.ops.aten.view.default(view_701, [16384, 4096]); view_701 = None + mm_143 = torch.ops.aten.mm.default(view_703, permute_227) + view_704 = torch.ops.aten.view.default(mm_143, [2, 8192, 4096]); mm_143 = None + add_81 = torch.ops.aten.add.Tensor(add_79, view_704); view_704 = None + convert_element_type_680 = torch.ops.prims.convert_element_type.default(primals_189, torch.bfloat16); primals_189 = None + all_gather_into_tensor_186 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_680, 256, '0'); convert_element_type_680 = None + wait_tensor_186 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_186); all_gather_into_tensor_186 = None + convert_element_type_681 = torch.ops.prims.convert_element_type.default(add_81, torch.float32); add_81 = None + pow_42 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_681, 2) + mean_41 = 
torch.ops.aten.mean.dim(pow_42, [2], True); pow_42 = None + add_82 = torch.ops.aten.add.Scalar(mean_41, 1e-05); mean_41 = None + rsqrt_41 = torch.ops.aten.rsqrt.default(add_82); add_82 = None + mul_164 = torch.ops.aten.mul.Tensor(convert_element_type_681, rsqrt_41); convert_element_type_681 = None + mul_165 = torch.ops.aten.mul.Tensor(mul_164, wait_tensor_186) + convert_element_type_682 = torch.ops.prims.convert_element_type.default(mul_165, torch.bfloat16); mul_165 = None + view_707 = torch.ops.aten.view.default(convert_element_type_682, [16384, 4096]); convert_element_type_682 = None + view_708 = torch.ops.aten.view.default(mm_144, [2, 8192, 14336]); mm_144 = None + convert_element_type_686 = torch.ops.prims.convert_element_type.default(view_708, torch.float32); view_708 = None + sigmoid_20 = torch.ops.aten.sigmoid.default(convert_element_type_686) + mul_166 = torch.ops.aten.mul.Tensor(convert_element_type_686, sigmoid_20); sigmoid_20 = None + convert_element_type_687 = torch.ops.prims.convert_element_type.default(mul_166, torch.bfloat16); mul_166 = None + convert_element_type_688 = torch.ops.prims.convert_element_type.default(primals_191, torch.bfloat16); primals_191 = None + all_gather_into_tensor_188 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_688, 256, '0'); convert_element_type_688 = None + wait_tensor_188 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_188); all_gather_into_tensor_188 = None + permute_229 = torch.ops.aten.permute.default(wait_tensor_188, [1, 0]); wait_tensor_188 = None + mm_145 = torch.ops.aten.mm.default(view_707, permute_229) + view_711 = torch.ops.aten.view.default(mm_145, [2, 8192, 14336]); mm_145 = None + mul_167 = torch.ops.aten.mul.Tensor(convert_element_type_687, view_711) + view_713 = torch.ops.aten.view.default(mul_167, [16384, 14336]); mul_167 = None + mm_381 = torch.ops.aten.mm.default(permute_709, view_713); permute_709 = view_713 = None + convert_element_type_691 = 
torch.ops.prims.convert_element_type.default(primals_192, torch.bfloat16); primals_192 = None + all_gather_into_tensor_189 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_691, 256, '0'); convert_element_type_691 = None + wait_tensor_189 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_189); all_gather_into_tensor_189 = None + permute_230 = torch.ops.aten.permute.default(wait_tensor_189, [1, 0]); wait_tensor_189 = None + permute_711 = torch.ops.aten.permute.default(permute_230, [1, 0]); permute_230 = None + mm_382 = torch.ops.aten.mm.default(view_1359, permute_711); view_1359 = permute_711 = None + view_1360 = torch.ops.aten.view.default(mm_382, [2, 8192, 14336]); mm_382 = None + convert_element_type_1672 = torch.ops.prims.convert_element_type.default(mm_381, torch.float32); mm_381 = None + reduce_scatter_tensor_101 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1672, 'avg', 256, '0'); convert_element_type_1672 = None + wait_tensor_392 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_101); reduce_scatter_tensor_101 = None + mul_484 = torch.ops.aten.mul.Tensor(view_1360, convert_element_type_687); convert_element_type_687 = None + mul_485 = torch.ops.aten.mul.Tensor(view_1360, view_711); view_1360 = view_711 = None + view_1361 = torch.ops.aten.view.default(mul_484, [16384, 14336]); mul_484 = None + permute_713 = torch.ops.aten.permute.default(view_1361, [1, 0]) + mm_383 = torch.ops.aten.mm.default(permute_713, view_707); permute_713 = None + permute_715 = torch.ops.aten.permute.default(permute_229, [1, 0]); permute_229 = None + mm_384 = torch.ops.aten.mm.default(view_1361, permute_715); view_1361 = permute_715 = None + view_1362 = torch.ops.aten.view.default(mm_384, [2, 8192, 4096]); mm_384 = None + convert_element_type_1677 = torch.ops.prims.convert_element_type.default(mm_383, torch.float32); mm_383 = None + reduce_scatter_tensor_102 = 
torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1677, 'avg', 256, '0'); convert_element_type_1677 = None + wait_tensor_393 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_102); reduce_scatter_tensor_102 = None + convert_element_type_1678 = torch.ops.prims.convert_element_type.default(mul_485, torch.float32); mul_485 = None + neg_11 = torch.ops.aten.neg.default(convert_element_type_686) + exp_11 = torch.ops.aten.exp.default(neg_11); neg_11 = None + add_206 = torch.ops.aten.add.Tensor(exp_11, 1); exp_11 = None + reciprocal_11 = torch.ops.aten.reciprocal.default(add_206); add_206 = None + mul_486 = torch.ops.aten.mul.Tensor(reciprocal_11, 1); reciprocal_11 = None + mul_487 = torch.ops.aten.mul.Tensor(convert_element_type_1678, mul_486); convert_element_type_1678 = None + sub_34 = torch.ops.aten.sub.Tensor(1, mul_486); mul_486 = None + mul_488 = torch.ops.aten.mul.Tensor(convert_element_type_686, sub_34); convert_element_type_686 = sub_34 = None + add_207 = torch.ops.aten.add.Tensor(mul_488, 1); mul_488 = None + mul_489 = torch.ops.aten.mul.Tensor(mul_487, add_207); mul_487 = add_207 = None + convert_element_type_1680 = torch.ops.prims.convert_element_type.default(mul_489, torch.bfloat16); mul_489 = None + view_1363 = torch.ops.aten.view.default(convert_element_type_1680, [16384, 14336]); convert_element_type_1680 = None + permute_717 = torch.ops.aten.permute.default(view_1363, [1, 0]) + mm_385 = torch.ops.aten.mm.default(permute_717, view_707); permute_717 = view_707 = None + convert_element_type_683 = torch.ops.prims.convert_element_type.default(primals_190, torch.bfloat16); primals_190 = None + all_gather_into_tensor_187 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_683, 256, '0'); convert_element_type_683 = None + wait_tensor_187 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_187); all_gather_into_tensor_187 = None + permute_228 = 
torch.ops.aten.permute.default(wait_tensor_187, [1, 0]); wait_tensor_187 = None + permute_719 = torch.ops.aten.permute.default(permute_228, [1, 0]); permute_228 = None + mm_386 = torch.ops.aten.mm.default(view_1363, permute_719); view_1363 = permute_719 = None + view_1364 = torch.ops.aten.view.default(mm_386, [2, 8192, 4096]); mm_386 = None + add_208 = torch.ops.aten.add.Tensor(view_1362, view_1364); view_1362 = view_1364 = None + convert_element_type_1685 = torch.ops.prims.convert_element_type.default(mm_385, torch.float32); mm_385 = None + reduce_scatter_tensor_103 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1685, 'avg', 256, '0'); convert_element_type_1685 = None + wait_tensor_394 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_103); reduce_scatter_tensor_103 = None + convert_element_type_1686 = torch.ops.prims.convert_element_type.default(add_208, torch.float32); add_208 = None + convert_element_type_1688 = torch.ops.prims.convert_element_type.default(wait_tensor_186, torch.float32); wait_tensor_186 = None + mul_490 = torch.ops.aten.mul.Tensor(convert_element_type_1686, convert_element_type_1688); convert_element_type_1688 = None + mul_492 = torch.ops.aten.mul.Tensor(mul_164, mul_490) + sum_69 = torch.ops.aten.sum.dim_IntList(mul_492, [2], True); mul_492 = None + div_23 = torch.ops.aten.div.Tensor(mul_164, 4096) + mul_493 = torch.ops.aten.mul.Tensor(div_23, sum_69); div_23 = sum_69 = None + sub_35 = torch.ops.aten.sub.Tensor(mul_490, mul_493); mul_490 = mul_493 = None + mul_494 = torch.ops.aten.mul.Tensor(sub_35, rsqrt_41); sub_35 = rsqrt_41 = None + mul_495 = torch.ops.aten.mul.Tensor(convert_element_type_1686, mul_164); convert_element_type_1686 = mul_164 = None + sum_70 = torch.ops.aten.sum.dim_IntList(mul_495, [0, 1]); mul_495 = None + convert_element_type_1689 = torch.ops.prims.convert_element_type.default(mul_494, torch.bfloat16); mul_494 = None + add_209 = torch.ops.aten.add.Tensor(add_205, 
convert_element_type_1689); add_205 = convert_element_type_1689 = None + convert_element_type_default_42 = torch.ops.prims.convert_element_type.default(sum_70, torch.float32); sum_70 = None + reduce_scatter_tensor_104 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_42, 'avg', 256, '0'); convert_element_type_default_42 = None + wait_tensor_395 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_104); reduce_scatter_tensor_104 = None + view_1365 = torch.ops.aten.view.default(add_209, [16384, 4096]) + permute_721 = torch.ops.aten.permute.default(view_1365, [1, 0]) + mm_387 = torch.ops.aten.mm.default(permute_721, view_703); permute_721 = view_703 = None + permute_723 = torch.ops.aten.permute.default(permute_227, [1, 0]); permute_227 = None + mm_388 = torch.ops.aten.mm.default(view_1365, permute_723); view_1365 = permute_723 = None + view_1366 = torch.ops.aten.view.default(mm_388, [2, 8192, 4096]); mm_388 = None + convert_element_type_1696 = torch.ops.prims.convert_element_type.default(mm_387, torch.float32); mm_387 = None + reduce_scatter_tensor_105 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1696, 'avg', 256, '0'); convert_element_type_1696 = None + wait_tensor_396 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_105); reduce_scatter_tensor_105 = None + view_1367 = torch.ops.aten.view.default(view_1366, [2, 8192, 32, 128]); view_1366 = None + permute_725 = torch.ops.aten.permute.default(view_1367, [0, 2, 1, 3]); view_1367 = None + convert_element_type_661 = torch.ops.prims.convert_element_type.default(primals_184, torch.bfloat16); primals_184 = None + all_gather_into_tensor_181 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_661, 256, '0'); convert_element_type_661 = None + wait_tensor_181 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_181); all_gather_into_tensor_181 = None + 
convert_element_type_662 = torch.ops.prims.convert_element_type.default(add_79, torch.float32); add_79 = None + pow_41 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_662, 2) + mean_40 = torch.ops.aten.mean.dim(pow_41, [2], True); pow_41 = None + add_80 = torch.ops.aten.add.Scalar(mean_40, 1e-05); mean_40 = None + rsqrt_40 = torch.ops.aten.rsqrt.default(add_80); add_80 = None + mul_160 = torch.ops.aten.mul.Tensor(convert_element_type_662, rsqrt_40); convert_element_type_662 = None + mul_161 = torch.ops.aten.mul.Tensor(mul_160, wait_tensor_181) + convert_element_type_663 = torch.ops.prims.convert_element_type.default(mul_161, torch.bfloat16); mul_161 = None + view_683 = torch.ops.aten.view.default(convert_element_type_663, [16384, 4096]); convert_element_type_663 = None + view_684 = torch.ops.aten.view.default(mm_140, [2, 8192, 4096]); mm_140 = None + convert_element_type_667 = torch.ops.prims.convert_element_type.default(primals_186, torch.bfloat16); primals_186 = None + all_gather_into_tensor_183 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_667, 256, '0'); convert_element_type_667 = None + wait_tensor_183 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_183); all_gather_into_tensor_183 = None + permute_221 = torch.ops.aten.permute.default(wait_tensor_183, [1, 0]); wait_tensor_183 = None + mm_141 = torch.ops.aten.mm.default(view_683, permute_221) + view_687 = torch.ops.aten.view.default(mm_141, [2, 8192, 1024]); mm_141 = None + view_690 = torch.ops.aten.view.default(mm_142, [2, 8192, 1024]); mm_142 = None + view_691 = torch.ops.aten.view.default(view_684, [2, 8192, -1, 128]); view_684 = None + view_692 = torch.ops.aten.view.default(view_687, [2, 8192, -1, 128]); view_687 = None + view_693 = torch.ops.aten.view.default(view_690, [2, 8192, -1, 128]); view_690 = None + convert_element_type_673 = torch.ops.prims.convert_element_type.default(view_691, torch.float32); view_691 = None + view_694 = 
torch.ops.aten.view.default(convert_element_type_673, [2, 8192, 32, -1, 2]); convert_element_type_673 = None + view_as_complex_40 = torch.ops.aten.view_as_complex.default(view_694); view_694 = None + convert_element_type_674 = torch.ops.prims.convert_element_type.default(view_692, torch.float32); view_692 = None + view_695 = torch.ops.aten.view.default(convert_element_type_674, [2, 8192, 8, -1, 2]); convert_element_type_674 = None + view_as_complex_41 = torch.ops.aten.view_as_complex.default(view_695); view_695 = None + mul_162 = torch.ops.aten.mul.Tensor(view_as_complex_40, view_16); view_as_complex_40 = None + view_as_real_40 = torch.ops.aten.view_as_real.default(mul_162); mul_162 = None + view_697 = torch.ops.aten.view.default(view_as_real_40, [2, 8192, 32, 128]); view_as_real_40 = None + mul_163 = torch.ops.aten.mul.Tensor(view_as_complex_41, view_16); view_as_complex_41 = None + view_as_real_41 = torch.ops.aten.view_as_real.default(mul_163); mul_163 = None + view_698 = torch.ops.aten.view.default(view_as_real_41, [2, 8192, 8, 128]); view_as_real_41 = None + convert_element_type_675 = torch.ops.prims.convert_element_type.default(view_697, torch.bfloat16); view_697 = None + convert_element_type_676 = torch.ops.prims.convert_element_type.default(view_698, torch.bfloat16); view_698 = None + unsqueeze_40 = torch.ops.aten.unsqueeze.default(convert_element_type_676, 3); convert_element_type_676 = None + expand_40 = torch.ops.aten.expand.default(unsqueeze_40, [2, 8192, 8, 4, 128]); unsqueeze_40 = None + clone_40 = torch.ops.aten.clone.default(expand_40, memory_format = torch.contiguous_format); expand_40 = None + view_699 = torch.ops.aten.view.default(clone_40, [2, 8192, 32, 128]); clone_40 = None + unsqueeze_41 = torch.ops.aten.unsqueeze.default(view_693, 3); view_693 = None + expand_41 = torch.ops.aten.expand.default(unsqueeze_41, [2, 8192, 8, 4, 128]); unsqueeze_41 = None + clone_41 = torch.ops.aten.clone.default(expand_41, memory_format = torch.contiguous_format); 
expand_41 = None + view_700 = torch.ops.aten.view.default(clone_41, [2, 8192, 32, 128]); clone_41 = None + permute_223 = torch.ops.aten.permute.default(convert_element_type_675, [0, 2, 1, 3]); convert_element_type_675 = None + permute_224 = torch.ops.aten.permute.default(view_699, [0, 2, 1, 3]); view_699 = None + permute_225 = torch.ops.aten.permute.default(view_700, [0, 2, 1, 3]); view_700 = None + _scaled_dot_product_cudnn_attention_backward_11 = torch.ops.aten._scaled_dot_product_cudnn_attention_backward.default(permute_725, permute_223, permute_224, permute_225, getitem_180, getitem_181, getitem_186, getitem_187, None, None, None, 8192, 8192, 0.0, True); permute_725 = permute_223 = permute_224 = permute_225 = getitem_180 = getitem_181 = getitem_186 = getitem_187 = None + getitem_321 = _scaled_dot_product_cudnn_attention_backward_11[0] + getitem_322 = _scaled_dot_product_cudnn_attention_backward_11[1] + getitem_323 = _scaled_dot_product_cudnn_attention_backward_11[2]; _scaled_dot_product_cudnn_attention_backward_11 = None + permute_726 = torch.ops.aten.permute.default(getitem_323, [0, 2, 1, 3]); getitem_323 = None + permute_727 = torch.ops.aten.permute.default(getitem_322, [0, 2, 1, 3]); getitem_322 = None + permute_728 = torch.ops.aten.permute.default(getitem_321, [0, 2, 1, 3]); getitem_321 = None + view_1368 = torch.ops.aten.view.default(permute_726, [2, 8192, 8, 4, 128]); permute_726 = None + sum_71 = torch.ops.aten.sum.dim_IntList(view_1368, [3], True); view_1368 = None + squeeze_22 = torch.ops.aten.squeeze.dim(sum_71, 3); sum_71 = None + view_1369 = torch.ops.aten.view.default(permute_727, [2, 8192, 8, 4, 128]); permute_727 = None + sum_72 = torch.ops.aten.sum.dim_IntList(view_1369, [3], True); view_1369 = None + squeeze_23 = torch.ops.aten.squeeze.dim(sum_72, 3); sum_72 = None + convert_element_type_1697 = torch.ops.prims.convert_element_type.default(squeeze_23, torch.float32); squeeze_23 = None + convert_element_type_1698 = 
torch.ops.prims.convert_element_type.default(permute_728, torch.float32); permute_728 = None + view_1370 = torch.ops.aten.view.default(convert_element_type_1697, [2, 8192, 8, 64, 2]); convert_element_type_1697 = None + view_as_complex_86 = torch.ops.aten.view_as_complex.default(view_1370); view_1370 = None + mul_496 = torch.ops.aten.mul.Tensor(view_as_complex_86, _conj); view_as_complex_86 = None + view_1371 = torch.ops.aten.view.default(convert_element_type_1698, [2, 8192, 32, 64, 2]); convert_element_type_1698 = None + view_as_complex_87 = torch.ops.aten.view_as_complex.default(view_1371); view_1371 = None + mul_497 = torch.ops.aten.mul.Tensor(view_as_complex_87, _conj); view_as_complex_87 = None + view_as_real_86 = torch.ops.aten.view_as_real.default(mul_496); mul_496 = None + view_1372 = torch.ops.aten.view.default(view_as_real_86, [2, 8192, 8, 128]); view_as_real_86 = None + convert_element_type_1699 = torch.ops.prims.convert_element_type.default(view_1372, torch.bfloat16); view_1372 = None + view_as_real_87 = torch.ops.aten.view_as_real.default(mul_497); mul_497 = None + view_1373 = torch.ops.aten.view.default(view_as_real_87, [2, 8192, 32, 128]); view_as_real_87 = None + convert_element_type_1700 = torch.ops.prims.convert_element_type.default(view_1373, torch.bfloat16); view_1373 = None + view_1374 = torch.ops.aten.view.default(squeeze_22, [2, 8192, 1024]); squeeze_22 = None + view_1375 = torch.ops.aten.view.default(convert_element_type_1699, [2, 8192, 1024]); convert_element_type_1699 = None + view_1376 = torch.ops.aten.view.default(convert_element_type_1700, [2, 8192, 4096]); convert_element_type_1700 = None + view_1377 = torch.ops.aten.view.default(view_1374, [16384, 1024]); view_1374 = None + permute_729 = torch.ops.aten.permute.default(view_1377, [1, 0]) + mm_389 = torch.ops.aten.mm.default(permute_729, view_683); permute_729 = None + convert_element_type_670 = torch.ops.prims.convert_element_type.default(primals_187, torch.bfloat16); primals_187 = None 
+ all_gather_into_tensor_184 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_670, 256, '0'); convert_element_type_670 = None + wait_tensor_184 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_184); all_gather_into_tensor_184 = None + permute_222 = torch.ops.aten.permute.default(wait_tensor_184, [1, 0]); wait_tensor_184 = None + permute_731 = torch.ops.aten.permute.default(permute_222, [1, 0]); permute_222 = None + mm_390 = torch.ops.aten.mm.default(view_1377, permute_731); view_1377 = permute_731 = None + view_1378 = torch.ops.aten.view.default(mm_390, [2, 8192, 4096]); mm_390 = None + convert_element_type_1705 = torch.ops.prims.convert_element_type.default(mm_389, torch.float32); mm_389 = None + reduce_scatter_tensor_106 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1705, 'avg', 256, '0'); convert_element_type_1705 = None + wait_tensor_397 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_106); reduce_scatter_tensor_106 = None + view_1379 = torch.ops.aten.view.default(view_1375, [16384, 1024]); view_1375 = None + permute_733 = torch.ops.aten.permute.default(view_1379, [1, 0]) + mm_391 = torch.ops.aten.mm.default(permute_733, view_683); permute_733 = None + permute_735 = torch.ops.aten.permute.default(permute_221, [1, 0]); permute_221 = None + mm_392 = torch.ops.aten.mm.default(view_1379, permute_735); view_1379 = permute_735 = None + view_1380 = torch.ops.aten.view.default(mm_392, [2, 8192, 4096]); mm_392 = None + add_210 = torch.ops.aten.add.Tensor(view_1378, view_1380); view_1378 = view_1380 = None + convert_element_type_1710 = torch.ops.prims.convert_element_type.default(mm_391, torch.float32); mm_391 = None + reduce_scatter_tensor_107 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1710, 'avg', 256, '0'); convert_element_type_1710 = None + wait_tensor_398 = 
torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_107); reduce_scatter_tensor_107 = None + view_1381 = torch.ops.aten.view.default(view_1376, [16384, 4096]); view_1376 = None + permute_737 = torch.ops.aten.permute.default(view_1381, [1, 0]) + mm_393 = torch.ops.aten.mm.default(permute_737, view_683); permute_737 = view_683 = None + convert_element_type_664 = torch.ops.prims.convert_element_type.default(primals_185, torch.bfloat16); primals_185 = None + all_gather_into_tensor_182 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_664, 256, '0'); convert_element_type_664 = None + wait_tensor_182 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_182); all_gather_into_tensor_182 = None + permute_220 = torch.ops.aten.permute.default(wait_tensor_182, [1, 0]); wait_tensor_182 = None + permute_739 = torch.ops.aten.permute.default(permute_220, [1, 0]); permute_220 = None + mm_394 = torch.ops.aten.mm.default(view_1381, permute_739); view_1381 = permute_739 = None + view_1382 = torch.ops.aten.view.default(mm_394, [2, 8192, 4096]); mm_394 = None + add_211 = torch.ops.aten.add.Tensor(add_210, view_1382); add_210 = view_1382 = None + convert_element_type_1715 = torch.ops.prims.convert_element_type.default(mm_393, torch.float32); mm_393 = None + reduce_scatter_tensor_108 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1715, 'avg', 256, '0'); convert_element_type_1715 = None + wait_tensor_399 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_108); reduce_scatter_tensor_108 = None + convert_element_type_1716 = torch.ops.prims.convert_element_type.default(add_211, torch.float32); add_211 = None + convert_element_type_1718 = torch.ops.prims.convert_element_type.default(wait_tensor_181, torch.float32); wait_tensor_181 = None + mul_498 = torch.ops.aten.mul.Tensor(convert_element_type_1716, convert_element_type_1718); convert_element_type_1718 = None + mul_500 = 
torch.ops.aten.mul.Tensor(mul_160, mul_498) + sum_73 = torch.ops.aten.sum.dim_IntList(mul_500, [2], True); mul_500 = None + div_24 = torch.ops.aten.div.Tensor(mul_160, 4096) + mul_501 = torch.ops.aten.mul.Tensor(div_24, sum_73); div_24 = sum_73 = None + sub_36 = torch.ops.aten.sub.Tensor(mul_498, mul_501); mul_498 = mul_501 = None + mul_502 = torch.ops.aten.mul.Tensor(sub_36, rsqrt_40); sub_36 = rsqrt_40 = None + mul_503 = torch.ops.aten.mul.Tensor(convert_element_type_1716, mul_160); convert_element_type_1716 = mul_160 = None + sum_74 = torch.ops.aten.sum.dim_IntList(mul_503, [0, 1]); mul_503 = None + convert_element_type_1719 = torch.ops.prims.convert_element_type.default(mul_502, torch.bfloat16); mul_502 = None + add_212 = torch.ops.aten.add.Tensor(add_209, convert_element_type_1719); add_209 = convert_element_type_1719 = None + convert_element_type_default_41 = torch.ops.prims.convert_element_type.default(sum_74, torch.float32); sum_74 = None + reduce_scatter_tensor_109 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_41, 'avg', 256, '0'); convert_element_type_default_41 = None + wait_tensor_400 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_109); reduce_scatter_tensor_109 = None + view_1383 = torch.ops.aten.view.default(add_212, [16384, 4096]) + permute_741 = torch.ops.aten.permute.default(view_1383, [1, 0]) + permute_215 = torch.ops.aten.permute.default(getitem_171, [0, 2, 1, 3]) + view_667 = torch.ops.aten.view.default(permute_215, [2, 8192, -1]); permute_215 = None + convert_element_type_644 = torch.ops.prims.convert_element_type.default(primals_179, torch.bfloat16); primals_179 = None + all_gather_into_tensor_176 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_644, 256, '0'); convert_element_type_644 = None + wait_tensor_176 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_176); all_gather_into_tensor_176 = None + permute_216 = 
torch.ops.aten.permute.default(wait_tensor_176, [1, 0]); wait_tensor_176 = None + view_669 = torch.ops.aten.view.default(view_667, [16384, 4096]); view_667 = None + mm_136 = torch.ops.aten.mm.default(view_669, permute_216) + view_670 = torch.ops.aten.view.default(mm_136, [2, 8192, 4096]); mm_136 = None + add_77 = torch.ops.aten.add.Tensor(add_75, view_670); view_670 = None + convert_element_type_647 = torch.ops.prims.convert_element_type.default(primals_180, torch.bfloat16); primals_180 = None + all_gather_into_tensor_177 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_647, 256, '0'); convert_element_type_647 = None + wait_tensor_177 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_177); all_gather_into_tensor_177 = None + convert_element_type_648 = torch.ops.prims.convert_element_type.default(add_77, torch.float32); add_77 = None + pow_40 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_648, 2) + mean_39 = torch.ops.aten.mean.dim(pow_40, [2], True); pow_40 = None + add_78 = torch.ops.aten.add.Scalar(mean_39, 1e-05); mean_39 = None + rsqrt_39 = torch.ops.aten.rsqrt.default(add_78); add_78 = None + mul_156 = torch.ops.aten.mul.Tensor(convert_element_type_648, rsqrt_39); convert_element_type_648 = None + mul_157 = torch.ops.aten.mul.Tensor(mul_156, wait_tensor_177) + convert_element_type_649 = torch.ops.prims.convert_element_type.default(mul_157, torch.bfloat16); mul_157 = None + view_673 = torch.ops.aten.view.default(convert_element_type_649, [16384, 4096]); convert_element_type_649 = None + view_674 = torch.ops.aten.view.default(mm_137, [2, 8192, 14336]); mm_137 = None + convert_element_type_653 = torch.ops.prims.convert_element_type.default(view_674, torch.float32); view_674 = None + sigmoid_19 = torch.ops.aten.sigmoid.default(convert_element_type_653) + mul_158 = torch.ops.aten.mul.Tensor(convert_element_type_653, sigmoid_19); sigmoid_19 = None + convert_element_type_654 = 
torch.ops.prims.convert_element_type.default(mul_158, torch.bfloat16); mul_158 = None + convert_element_type_655 = torch.ops.prims.convert_element_type.default(primals_182, torch.bfloat16); primals_182 = None + all_gather_into_tensor_179 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_655, 256, '0'); convert_element_type_655 = None + wait_tensor_179 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_179); all_gather_into_tensor_179 = None + permute_218 = torch.ops.aten.permute.default(wait_tensor_179, [1, 0]); wait_tensor_179 = None + mm_138 = torch.ops.aten.mm.default(view_673, permute_218) + view_677 = torch.ops.aten.view.default(mm_138, [2, 8192, 14336]); mm_138 = None + mul_159 = torch.ops.aten.mul.Tensor(convert_element_type_654, view_677) + view_679 = torch.ops.aten.view.default(mul_159, [16384, 14336]); mul_159 = None + mm_395 = torch.ops.aten.mm.default(permute_741, view_679); permute_741 = view_679 = None + convert_element_type_658 = torch.ops.prims.convert_element_type.default(primals_183, torch.bfloat16); primals_183 = None + all_gather_into_tensor_180 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_658, 256, '0'); convert_element_type_658 = None + wait_tensor_180 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_180); all_gather_into_tensor_180 = None + permute_219 = torch.ops.aten.permute.default(wait_tensor_180, [1, 0]); wait_tensor_180 = None + permute_743 = torch.ops.aten.permute.default(permute_219, [1, 0]); permute_219 = None + mm_396 = torch.ops.aten.mm.default(view_1383, permute_743); view_1383 = permute_743 = None + view_1384 = torch.ops.aten.view.default(mm_396, [2, 8192, 14336]); mm_396 = None + convert_element_type_1726 = torch.ops.prims.convert_element_type.default(mm_395, torch.float32); mm_395 = None + reduce_scatter_tensor_110 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1726, 'avg', 256, '0'); 
convert_element_type_1726 = None + wait_tensor_401 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_110); reduce_scatter_tensor_110 = None + mul_504 = torch.ops.aten.mul.Tensor(view_1384, convert_element_type_654); convert_element_type_654 = None + mul_505 = torch.ops.aten.mul.Tensor(view_1384, view_677); view_1384 = view_677 = None + view_1385 = torch.ops.aten.view.default(mul_504, [16384, 14336]); mul_504 = None + permute_745 = torch.ops.aten.permute.default(view_1385, [1, 0]) + mm_397 = torch.ops.aten.mm.default(permute_745, view_673); permute_745 = None + permute_747 = torch.ops.aten.permute.default(permute_218, [1, 0]); permute_218 = None + mm_398 = torch.ops.aten.mm.default(view_1385, permute_747); view_1385 = permute_747 = None + view_1386 = torch.ops.aten.view.default(mm_398, [2, 8192, 4096]); mm_398 = None + convert_element_type_1731 = torch.ops.prims.convert_element_type.default(mm_397, torch.float32); mm_397 = None + reduce_scatter_tensor_111 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1731, 'avg', 256, '0'); convert_element_type_1731 = None + wait_tensor_402 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_111); reduce_scatter_tensor_111 = None + convert_element_type_1732 = torch.ops.prims.convert_element_type.default(mul_505, torch.float32); mul_505 = None + neg_12 = torch.ops.aten.neg.default(convert_element_type_653) + exp_12 = torch.ops.aten.exp.default(neg_12); neg_12 = None + add_213 = torch.ops.aten.add.Tensor(exp_12, 1); exp_12 = None + reciprocal_12 = torch.ops.aten.reciprocal.default(add_213); add_213 = None + mul_506 = torch.ops.aten.mul.Tensor(reciprocal_12, 1); reciprocal_12 = None + mul_507 = torch.ops.aten.mul.Tensor(convert_element_type_1732, mul_506); convert_element_type_1732 = None + sub_37 = torch.ops.aten.sub.Tensor(1, mul_506); mul_506 = None + mul_508 = torch.ops.aten.mul.Tensor(convert_element_type_653, sub_37); convert_element_type_653 = sub_37 = 
None + add_214 = torch.ops.aten.add.Tensor(mul_508, 1); mul_508 = None + mul_509 = torch.ops.aten.mul.Tensor(mul_507, add_214); mul_507 = add_214 = None + convert_element_type_1734 = torch.ops.prims.convert_element_type.default(mul_509, torch.bfloat16); mul_509 = None + view_1387 = torch.ops.aten.view.default(convert_element_type_1734, [16384, 14336]); convert_element_type_1734 = None + permute_749 = torch.ops.aten.permute.default(view_1387, [1, 0]) + mm_399 = torch.ops.aten.mm.default(permute_749, view_673); permute_749 = view_673 = None + convert_element_type_650 = torch.ops.prims.convert_element_type.default(primals_181, torch.bfloat16); primals_181 = None + all_gather_into_tensor_178 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_650, 256, '0'); convert_element_type_650 = None + wait_tensor_178 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_178); all_gather_into_tensor_178 = None + permute_217 = torch.ops.aten.permute.default(wait_tensor_178, [1, 0]); wait_tensor_178 = None + permute_751 = torch.ops.aten.permute.default(permute_217, [1, 0]); permute_217 = None + mm_400 = torch.ops.aten.mm.default(view_1387, permute_751); view_1387 = permute_751 = None + view_1388 = torch.ops.aten.view.default(mm_400, [2, 8192, 4096]); mm_400 = None + add_215 = torch.ops.aten.add.Tensor(view_1386, view_1388); view_1386 = view_1388 = None + convert_element_type_1739 = torch.ops.prims.convert_element_type.default(mm_399, torch.float32); mm_399 = None + reduce_scatter_tensor_112 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1739, 'avg', 256, '0'); convert_element_type_1739 = None + wait_tensor_403 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_112); reduce_scatter_tensor_112 = None + convert_element_type_1740 = torch.ops.prims.convert_element_type.default(add_215, torch.float32); add_215 = None + convert_element_type_1742 = 
torch.ops.prims.convert_element_type.default(wait_tensor_177, torch.float32); wait_tensor_177 = None + mul_510 = torch.ops.aten.mul.Tensor(convert_element_type_1740, convert_element_type_1742); convert_element_type_1742 = None + mul_512 = torch.ops.aten.mul.Tensor(mul_156, mul_510) + sum_75 = torch.ops.aten.sum.dim_IntList(mul_512, [2], True); mul_512 = None + div_25 = torch.ops.aten.div.Tensor(mul_156, 4096) + mul_513 = torch.ops.aten.mul.Tensor(div_25, sum_75); div_25 = sum_75 = None + sub_38 = torch.ops.aten.sub.Tensor(mul_510, mul_513); mul_510 = mul_513 = None + mul_514 = torch.ops.aten.mul.Tensor(sub_38, rsqrt_39); sub_38 = rsqrt_39 = None + mul_515 = torch.ops.aten.mul.Tensor(convert_element_type_1740, mul_156); convert_element_type_1740 = mul_156 = None + sum_76 = torch.ops.aten.sum.dim_IntList(mul_515, [0, 1]); mul_515 = None + convert_element_type_1743 = torch.ops.prims.convert_element_type.default(mul_514, torch.bfloat16); mul_514 = None + add_216 = torch.ops.aten.add.Tensor(add_212, convert_element_type_1743); add_212 = convert_element_type_1743 = None + convert_element_type_default_40 = torch.ops.prims.convert_element_type.default(sum_76, torch.float32); sum_76 = None + reduce_scatter_tensor_113 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_40, 'avg', 256, '0'); convert_element_type_default_40 = None + wait_tensor_404 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_113); reduce_scatter_tensor_113 = None + view_1389 = torch.ops.aten.view.default(add_216, [16384, 4096]) + permute_753 = torch.ops.aten.permute.default(view_1389, [1, 0]) + mm_401 = torch.ops.aten.mm.default(permute_753, view_669); permute_753 = view_669 = None + permute_755 = torch.ops.aten.permute.default(permute_216, [1, 0]); permute_216 = None + mm_402 = torch.ops.aten.mm.default(view_1389, permute_755); view_1389 = permute_755 = None + view_1390 = torch.ops.aten.view.default(mm_402, [2, 8192, 4096]); mm_402 = None + 
convert_element_type_1750 = torch.ops.prims.convert_element_type.default(mm_401, torch.float32); mm_401 = None + reduce_scatter_tensor_114 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1750, 'avg', 256, '0'); convert_element_type_1750 = None + wait_tensor_405 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_114); reduce_scatter_tensor_114 = None + view_1391 = torch.ops.aten.view.default(view_1390, [2, 8192, 32, 128]); view_1390 = None + permute_757 = torch.ops.aten.permute.default(view_1391, [0, 2, 1, 3]); view_1391 = None + convert_element_type_628 = torch.ops.prims.convert_element_type.default(primals_175, torch.bfloat16); primals_175 = None + all_gather_into_tensor_172 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_628, 256, '0'); convert_element_type_628 = None + wait_tensor_172 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_172); all_gather_into_tensor_172 = None + convert_element_type_629 = torch.ops.prims.convert_element_type.default(add_75, torch.float32); add_75 = None + pow_39 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_629, 2) + mean_38 = torch.ops.aten.mean.dim(pow_39, [2], True); pow_39 = None + add_76 = torch.ops.aten.add.Scalar(mean_38, 1e-05); mean_38 = None + rsqrt_38 = torch.ops.aten.rsqrt.default(add_76); add_76 = None + mul_152 = torch.ops.aten.mul.Tensor(convert_element_type_629, rsqrt_38); convert_element_type_629 = None + mul_153 = torch.ops.aten.mul.Tensor(mul_152, wait_tensor_172) + convert_element_type_630 = torch.ops.prims.convert_element_type.default(mul_153, torch.bfloat16); mul_153 = None + view_649 = torch.ops.aten.view.default(convert_element_type_630, [16384, 4096]); convert_element_type_630 = None + view_650 = torch.ops.aten.view.default(mm_133, [2, 8192, 4096]); mm_133 = None + convert_element_type_634 = torch.ops.prims.convert_element_type.default(primals_177, torch.bfloat16); primals_177 = None + 
all_gather_into_tensor_174 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_634, 256, '0'); convert_element_type_634 = None + wait_tensor_174 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_174); all_gather_into_tensor_174 = None + permute_210 = torch.ops.aten.permute.default(wait_tensor_174, [1, 0]); wait_tensor_174 = None + mm_134 = torch.ops.aten.mm.default(view_649, permute_210) + view_653 = torch.ops.aten.view.default(mm_134, [2, 8192, 1024]); mm_134 = None + view_656 = torch.ops.aten.view.default(mm_135, [2, 8192, 1024]); mm_135 = None + view_657 = torch.ops.aten.view.default(view_650, [2, 8192, -1, 128]); view_650 = None + view_658 = torch.ops.aten.view.default(view_653, [2, 8192, -1, 128]); view_653 = None + view_659 = torch.ops.aten.view.default(view_656, [2, 8192, -1, 128]); view_656 = None + convert_element_type_640 = torch.ops.prims.convert_element_type.default(view_657, torch.float32); view_657 = None + view_660 = torch.ops.aten.view.default(convert_element_type_640, [2, 8192, 32, -1, 2]); convert_element_type_640 = None + view_as_complex_38 = torch.ops.aten.view_as_complex.default(view_660); view_660 = None + convert_element_type_641 = torch.ops.prims.convert_element_type.default(view_658, torch.float32); view_658 = None + view_661 = torch.ops.aten.view.default(convert_element_type_641, [2, 8192, 8, -1, 2]); convert_element_type_641 = None + view_as_complex_39 = torch.ops.aten.view_as_complex.default(view_661); view_661 = None + mul_154 = torch.ops.aten.mul.Tensor(view_as_complex_38, view_16); view_as_complex_38 = None + view_as_real_38 = torch.ops.aten.view_as_real.default(mul_154); mul_154 = None + view_663 = torch.ops.aten.view.default(view_as_real_38, [2, 8192, 32, 128]); view_as_real_38 = None + mul_155 = torch.ops.aten.mul.Tensor(view_as_complex_39, view_16); view_as_complex_39 = None + view_as_real_39 = torch.ops.aten.view_as_real.default(mul_155); mul_155 = None + view_664 = 
torch.ops.aten.view.default(view_as_real_39, [2, 8192, 8, 128]); view_as_real_39 = None + convert_element_type_642 = torch.ops.prims.convert_element_type.default(view_663, torch.bfloat16); view_663 = None + convert_element_type_643 = torch.ops.prims.convert_element_type.default(view_664, torch.bfloat16); view_664 = None + unsqueeze_38 = torch.ops.aten.unsqueeze.default(convert_element_type_643, 3); convert_element_type_643 = None + expand_38 = torch.ops.aten.expand.default(unsqueeze_38, [2, 8192, 8, 4, 128]); unsqueeze_38 = None + clone_38 = torch.ops.aten.clone.default(expand_38, memory_format = torch.contiguous_format); expand_38 = None + view_665 = torch.ops.aten.view.default(clone_38, [2, 8192, 32, 128]); clone_38 = None + unsqueeze_39 = torch.ops.aten.unsqueeze.default(view_659, 3); view_659 = None + expand_39 = torch.ops.aten.expand.default(unsqueeze_39, [2, 8192, 8, 4, 128]); unsqueeze_39 = None + clone_39 = torch.ops.aten.clone.default(expand_39, memory_format = torch.contiguous_format); expand_39 = None + view_666 = torch.ops.aten.view.default(clone_39, [2, 8192, 32, 128]); clone_39 = None + permute_212 = torch.ops.aten.permute.default(convert_element_type_642, [0, 2, 1, 3]); convert_element_type_642 = None + permute_213 = torch.ops.aten.permute.default(view_665, [0, 2, 1, 3]); view_665 = None + permute_214 = torch.ops.aten.permute.default(view_666, [0, 2, 1, 3]); view_666 = None + _scaled_dot_product_cudnn_attention_backward_12 = torch.ops.aten._scaled_dot_product_cudnn_attention_backward.default(permute_757, permute_212, permute_213, permute_214, getitem_171, getitem_172, getitem_177, getitem_178, None, None, None, 8192, 8192, 0.0, True); permute_757 = permute_212 = permute_213 = permute_214 = getitem_171 = getitem_172 = getitem_177 = getitem_178 = None + getitem_324 = _scaled_dot_product_cudnn_attention_backward_12[0] + getitem_325 = _scaled_dot_product_cudnn_attention_backward_12[1] + getitem_326 = _scaled_dot_product_cudnn_attention_backward_12[2]; 
_scaled_dot_product_cudnn_attention_backward_12 = None + permute_758 = torch.ops.aten.permute.default(getitem_326, [0, 2, 1, 3]); getitem_326 = None + permute_759 = torch.ops.aten.permute.default(getitem_325, [0, 2, 1, 3]); getitem_325 = None + permute_760 = torch.ops.aten.permute.default(getitem_324, [0, 2, 1, 3]); getitem_324 = None + view_1392 = torch.ops.aten.view.default(permute_758, [2, 8192, 8, 4, 128]); permute_758 = None + sum_77 = torch.ops.aten.sum.dim_IntList(view_1392, [3], True); view_1392 = None + squeeze_24 = torch.ops.aten.squeeze.dim(sum_77, 3); sum_77 = None + view_1393 = torch.ops.aten.view.default(permute_759, [2, 8192, 8, 4, 128]); permute_759 = None + sum_78 = torch.ops.aten.sum.dim_IntList(view_1393, [3], True); view_1393 = None + squeeze_25 = torch.ops.aten.squeeze.dim(sum_78, 3); sum_78 = None + convert_element_type_1751 = torch.ops.prims.convert_element_type.default(squeeze_25, torch.float32); squeeze_25 = None + convert_element_type_1752 = torch.ops.prims.convert_element_type.default(permute_760, torch.float32); permute_760 = None + view_1394 = torch.ops.aten.view.default(convert_element_type_1751, [2, 8192, 8, 64, 2]); convert_element_type_1751 = None + view_as_complex_88 = torch.ops.aten.view_as_complex.default(view_1394); view_1394 = None + mul_516 = torch.ops.aten.mul.Tensor(view_as_complex_88, _conj); view_as_complex_88 = None + view_1395 = torch.ops.aten.view.default(convert_element_type_1752, [2, 8192, 32, 64, 2]); convert_element_type_1752 = None + view_as_complex_89 = torch.ops.aten.view_as_complex.default(view_1395); view_1395 = None + mul_517 = torch.ops.aten.mul.Tensor(view_as_complex_89, _conj); view_as_complex_89 = None + view_as_real_88 = torch.ops.aten.view_as_real.default(mul_516); mul_516 = None + view_1396 = torch.ops.aten.view.default(view_as_real_88, [2, 8192, 8, 128]); view_as_real_88 = None + convert_element_type_1753 = torch.ops.prims.convert_element_type.default(view_1396, torch.bfloat16); view_1396 = None + 
view_as_real_89 = torch.ops.aten.view_as_real.default(mul_517); mul_517 = None + view_1397 = torch.ops.aten.view.default(view_as_real_89, [2, 8192, 32, 128]); view_as_real_89 = None + convert_element_type_1754 = torch.ops.prims.convert_element_type.default(view_1397, torch.bfloat16); view_1397 = None + view_1398 = torch.ops.aten.view.default(squeeze_24, [2, 8192, 1024]); squeeze_24 = None + view_1399 = torch.ops.aten.view.default(convert_element_type_1753, [2, 8192, 1024]); convert_element_type_1753 = None + view_1400 = torch.ops.aten.view.default(convert_element_type_1754, [2, 8192, 4096]); convert_element_type_1754 = None + view_1401 = torch.ops.aten.view.default(view_1398, [16384, 1024]); view_1398 = None + permute_761 = torch.ops.aten.permute.default(view_1401, [1, 0]) + mm_403 = torch.ops.aten.mm.default(permute_761, view_649); permute_761 = None + convert_element_type_637 = torch.ops.prims.convert_element_type.default(primals_178, torch.bfloat16); primals_178 = None + all_gather_into_tensor_175 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_637, 256, '0'); convert_element_type_637 = None + wait_tensor_175 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_175); all_gather_into_tensor_175 = None + permute_211 = torch.ops.aten.permute.default(wait_tensor_175, [1, 0]); wait_tensor_175 = None + permute_763 = torch.ops.aten.permute.default(permute_211, [1, 0]); permute_211 = None + mm_404 = torch.ops.aten.mm.default(view_1401, permute_763); view_1401 = permute_763 = None + view_1402 = torch.ops.aten.view.default(mm_404, [2, 8192, 4096]); mm_404 = None + convert_element_type_1759 = torch.ops.prims.convert_element_type.default(mm_403, torch.float32); mm_403 = None + reduce_scatter_tensor_115 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1759, 'avg', 256, '0'); convert_element_type_1759 = None + wait_tensor_406 = 
torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_115); reduce_scatter_tensor_115 = None + view_1403 = torch.ops.aten.view.default(view_1399, [16384, 1024]); view_1399 = None + permute_765 = torch.ops.aten.permute.default(view_1403, [1, 0]) + mm_405 = torch.ops.aten.mm.default(permute_765, view_649); permute_765 = None + permute_767 = torch.ops.aten.permute.default(permute_210, [1, 0]); permute_210 = None + mm_406 = torch.ops.aten.mm.default(view_1403, permute_767); view_1403 = permute_767 = None + view_1404 = torch.ops.aten.view.default(mm_406, [2, 8192, 4096]); mm_406 = None + add_217 = torch.ops.aten.add.Tensor(view_1402, view_1404); view_1402 = view_1404 = None + convert_element_type_1764 = torch.ops.prims.convert_element_type.default(mm_405, torch.float32); mm_405 = None + reduce_scatter_tensor_116 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1764, 'avg', 256, '0'); convert_element_type_1764 = None + wait_tensor_407 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_116); reduce_scatter_tensor_116 = None + view_1405 = torch.ops.aten.view.default(view_1400, [16384, 4096]); view_1400 = None + permute_769 = torch.ops.aten.permute.default(view_1405, [1, 0]) + mm_407 = torch.ops.aten.mm.default(permute_769, view_649); permute_769 = view_649 = None + convert_element_type_631 = torch.ops.prims.convert_element_type.default(primals_176, torch.bfloat16); primals_176 = None + all_gather_into_tensor_173 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_631, 256, '0'); convert_element_type_631 = None + wait_tensor_173 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_173); all_gather_into_tensor_173 = None + permute_209 = torch.ops.aten.permute.default(wait_tensor_173, [1, 0]); wait_tensor_173 = None + permute_771 = torch.ops.aten.permute.default(permute_209, [1, 0]); permute_209 = None + mm_408 = torch.ops.aten.mm.default(view_1405, permute_771); 
view_1405 = permute_771 = None + view_1406 = torch.ops.aten.view.default(mm_408, [2, 8192, 4096]); mm_408 = None + add_218 = torch.ops.aten.add.Tensor(add_217, view_1406); add_217 = view_1406 = None + convert_element_type_1769 = torch.ops.prims.convert_element_type.default(mm_407, torch.float32); mm_407 = None + reduce_scatter_tensor_117 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1769, 'avg', 256, '0'); convert_element_type_1769 = None + wait_tensor_408 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_117); reduce_scatter_tensor_117 = None + convert_element_type_1770 = torch.ops.prims.convert_element_type.default(add_218, torch.float32); add_218 = None + convert_element_type_1772 = torch.ops.prims.convert_element_type.default(wait_tensor_172, torch.float32); wait_tensor_172 = None + mul_518 = torch.ops.aten.mul.Tensor(convert_element_type_1770, convert_element_type_1772); convert_element_type_1772 = None + mul_520 = torch.ops.aten.mul.Tensor(mul_152, mul_518) + sum_79 = torch.ops.aten.sum.dim_IntList(mul_520, [2], True); mul_520 = None + div_26 = torch.ops.aten.div.Tensor(mul_152, 4096) + mul_521 = torch.ops.aten.mul.Tensor(div_26, sum_79); div_26 = sum_79 = None + sub_39 = torch.ops.aten.sub.Tensor(mul_518, mul_521); mul_518 = mul_521 = None + mul_522 = torch.ops.aten.mul.Tensor(sub_39, rsqrt_38); sub_39 = rsqrt_38 = None + mul_523 = torch.ops.aten.mul.Tensor(convert_element_type_1770, mul_152); convert_element_type_1770 = mul_152 = None + sum_80 = torch.ops.aten.sum.dim_IntList(mul_523, [0, 1]); mul_523 = None + convert_element_type_1773 = torch.ops.prims.convert_element_type.default(mul_522, torch.bfloat16); mul_522 = None + add_219 = torch.ops.aten.add.Tensor(add_216, convert_element_type_1773); add_216 = convert_element_type_1773 = None + convert_element_type_default_39 = torch.ops.prims.convert_element_type.default(sum_80, torch.float32); sum_80 = None + reduce_scatter_tensor_118 = 
torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_39, 'avg', 256, '0'); convert_element_type_default_39 = None + wait_tensor_409 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_118); reduce_scatter_tensor_118 = None + view_1407 = torch.ops.aten.view.default(add_219, [16384, 4096]) + permute_773 = torch.ops.aten.permute.default(view_1407, [1, 0]) + permute_204 = torch.ops.aten.permute.default(getitem_162, [0, 2, 1, 3]) + view_633 = torch.ops.aten.view.default(permute_204, [2, 8192, -1]); permute_204 = None + convert_element_type_611 = torch.ops.prims.convert_element_type.default(primals_170, torch.bfloat16); primals_170 = None + all_gather_into_tensor_167 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_611, 256, '0'); convert_element_type_611 = None + wait_tensor_167 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_167); all_gather_into_tensor_167 = None + permute_205 = torch.ops.aten.permute.default(wait_tensor_167, [1, 0]); wait_tensor_167 = None + view_635 = torch.ops.aten.view.default(view_633, [16384, 4096]); view_633 = None + mm_129 = torch.ops.aten.mm.default(view_635, permute_205) + view_636 = torch.ops.aten.view.default(mm_129, [2, 8192, 4096]); mm_129 = None + add_73 = torch.ops.aten.add.Tensor(add_71, view_636); view_636 = None + convert_element_type_614 = torch.ops.prims.convert_element_type.default(primals_171, torch.bfloat16); primals_171 = None + all_gather_into_tensor_168 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_614, 256, '0'); convert_element_type_614 = None + wait_tensor_168 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_168); all_gather_into_tensor_168 = None + convert_element_type_615 = torch.ops.prims.convert_element_type.default(add_73, torch.float32); add_73 = None + pow_38 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_615, 2) + mean_37 = 
torch.ops.aten.mean.dim(pow_38, [2], True); pow_38 = None + add_74 = torch.ops.aten.add.Scalar(mean_37, 1e-05); mean_37 = None + rsqrt_37 = torch.ops.aten.rsqrt.default(add_74); add_74 = None + mul_148 = torch.ops.aten.mul.Tensor(convert_element_type_615, rsqrt_37); convert_element_type_615 = None + mul_149 = torch.ops.aten.mul.Tensor(mul_148, wait_tensor_168) + convert_element_type_616 = torch.ops.prims.convert_element_type.default(mul_149, torch.bfloat16); mul_149 = None + view_639 = torch.ops.aten.view.default(convert_element_type_616, [16384, 4096]); convert_element_type_616 = None + view_640 = torch.ops.aten.view.default(mm_130, [2, 8192, 14336]); mm_130 = None + convert_element_type_620 = torch.ops.prims.convert_element_type.default(view_640, torch.float32); view_640 = None + sigmoid_18 = torch.ops.aten.sigmoid.default(convert_element_type_620) + mul_150 = torch.ops.aten.mul.Tensor(convert_element_type_620, sigmoid_18); sigmoid_18 = None + convert_element_type_621 = torch.ops.prims.convert_element_type.default(mul_150, torch.bfloat16); mul_150 = None + convert_element_type_622 = torch.ops.prims.convert_element_type.default(primals_173, torch.bfloat16); primals_173 = None + all_gather_into_tensor_170 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_622, 256, '0'); convert_element_type_622 = None + wait_tensor_170 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_170); all_gather_into_tensor_170 = None + permute_207 = torch.ops.aten.permute.default(wait_tensor_170, [1, 0]); wait_tensor_170 = None + mm_131 = torch.ops.aten.mm.default(view_639, permute_207) + view_643 = torch.ops.aten.view.default(mm_131, [2, 8192, 14336]); mm_131 = None + mul_151 = torch.ops.aten.mul.Tensor(convert_element_type_621, view_643) + view_645 = torch.ops.aten.view.default(mul_151, [16384, 14336]); mul_151 = None + mm_409 = torch.ops.aten.mm.default(permute_773, view_645); permute_773 = view_645 = None + convert_element_type_625 = 
torch.ops.prims.convert_element_type.default(primals_174, torch.bfloat16); primals_174 = None + all_gather_into_tensor_171 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_625, 256, '0'); convert_element_type_625 = None + wait_tensor_171 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_171); all_gather_into_tensor_171 = None + permute_208 = torch.ops.aten.permute.default(wait_tensor_171, [1, 0]); wait_tensor_171 = None + permute_775 = torch.ops.aten.permute.default(permute_208, [1, 0]); permute_208 = None + mm_410 = torch.ops.aten.mm.default(view_1407, permute_775); view_1407 = permute_775 = None + view_1408 = torch.ops.aten.view.default(mm_410, [2, 8192, 14336]); mm_410 = None + convert_element_type_1780 = torch.ops.prims.convert_element_type.default(mm_409, torch.float32); mm_409 = None + reduce_scatter_tensor_119 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1780, 'avg', 256, '0'); convert_element_type_1780 = None + wait_tensor_410 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_119); reduce_scatter_tensor_119 = None + mul_524 = torch.ops.aten.mul.Tensor(view_1408, convert_element_type_621); convert_element_type_621 = None + mul_525 = torch.ops.aten.mul.Tensor(view_1408, view_643); view_1408 = view_643 = None + view_1409 = torch.ops.aten.view.default(mul_524, [16384, 14336]); mul_524 = None + permute_777 = torch.ops.aten.permute.default(view_1409, [1, 0]) + mm_411 = torch.ops.aten.mm.default(permute_777, view_639); permute_777 = None + permute_779 = torch.ops.aten.permute.default(permute_207, [1, 0]); permute_207 = None + mm_412 = torch.ops.aten.mm.default(view_1409, permute_779); view_1409 = permute_779 = None + view_1410 = torch.ops.aten.view.default(mm_412, [2, 8192, 4096]); mm_412 = None + convert_element_type_1785 = torch.ops.prims.convert_element_type.default(mm_411, torch.float32); mm_411 = None + reduce_scatter_tensor_120 = 
torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1785, 'avg', 256, '0'); convert_element_type_1785 = None + wait_tensor_411 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_120); reduce_scatter_tensor_120 = None + convert_element_type_1786 = torch.ops.prims.convert_element_type.default(mul_525, torch.float32); mul_525 = None + neg_13 = torch.ops.aten.neg.default(convert_element_type_620) + exp_13 = torch.ops.aten.exp.default(neg_13); neg_13 = None + add_220 = torch.ops.aten.add.Tensor(exp_13, 1); exp_13 = None + reciprocal_13 = torch.ops.aten.reciprocal.default(add_220); add_220 = None + mul_526 = torch.ops.aten.mul.Tensor(reciprocal_13, 1); reciprocal_13 = None + mul_527 = torch.ops.aten.mul.Tensor(convert_element_type_1786, mul_526); convert_element_type_1786 = None + sub_40 = torch.ops.aten.sub.Tensor(1, mul_526); mul_526 = None + mul_528 = torch.ops.aten.mul.Tensor(convert_element_type_620, sub_40); convert_element_type_620 = sub_40 = None + add_221 = torch.ops.aten.add.Tensor(mul_528, 1); mul_528 = None + mul_529 = torch.ops.aten.mul.Tensor(mul_527, add_221); mul_527 = add_221 = None + convert_element_type_1788 = torch.ops.prims.convert_element_type.default(mul_529, torch.bfloat16); mul_529 = None + view_1411 = torch.ops.aten.view.default(convert_element_type_1788, [16384, 14336]); convert_element_type_1788 = None + permute_781 = torch.ops.aten.permute.default(view_1411, [1, 0]) + mm_413 = torch.ops.aten.mm.default(permute_781, view_639); permute_781 = view_639 = None + convert_element_type_617 = torch.ops.prims.convert_element_type.default(primals_172, torch.bfloat16); primals_172 = None + all_gather_into_tensor_169 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_617, 256, '0'); convert_element_type_617 = None + wait_tensor_169 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_169); all_gather_into_tensor_169 = None + permute_206 = 
torch.ops.aten.permute.default(wait_tensor_169, [1, 0]); wait_tensor_169 = None + permute_783 = torch.ops.aten.permute.default(permute_206, [1, 0]); permute_206 = None + mm_414 = torch.ops.aten.mm.default(view_1411, permute_783); view_1411 = permute_783 = None + view_1412 = torch.ops.aten.view.default(mm_414, [2, 8192, 4096]); mm_414 = None + add_222 = torch.ops.aten.add.Tensor(view_1410, view_1412); view_1410 = view_1412 = None + convert_element_type_1793 = torch.ops.prims.convert_element_type.default(mm_413, torch.float32); mm_413 = None + reduce_scatter_tensor_121 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1793, 'avg', 256, '0'); convert_element_type_1793 = None + wait_tensor_412 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_121); reduce_scatter_tensor_121 = None + convert_element_type_1794 = torch.ops.prims.convert_element_type.default(add_222, torch.float32); add_222 = None + convert_element_type_1796 = torch.ops.prims.convert_element_type.default(wait_tensor_168, torch.float32); wait_tensor_168 = None + mul_530 = torch.ops.aten.mul.Tensor(convert_element_type_1794, convert_element_type_1796); convert_element_type_1796 = None + mul_532 = torch.ops.aten.mul.Tensor(mul_148, mul_530) + sum_81 = torch.ops.aten.sum.dim_IntList(mul_532, [2], True); mul_532 = None + div_27 = torch.ops.aten.div.Tensor(mul_148, 4096) + mul_533 = torch.ops.aten.mul.Tensor(div_27, sum_81); div_27 = sum_81 = None + sub_41 = torch.ops.aten.sub.Tensor(mul_530, mul_533); mul_530 = mul_533 = None + mul_534 = torch.ops.aten.mul.Tensor(sub_41, rsqrt_37); sub_41 = rsqrt_37 = None + mul_535 = torch.ops.aten.mul.Tensor(convert_element_type_1794, mul_148); convert_element_type_1794 = mul_148 = None + sum_82 = torch.ops.aten.sum.dim_IntList(mul_535, [0, 1]); mul_535 = None + convert_element_type_1797 = torch.ops.prims.convert_element_type.default(mul_534, torch.bfloat16); mul_534 = None + add_223 = torch.ops.aten.add.Tensor(add_219, 
convert_element_type_1797); add_219 = convert_element_type_1797 = None + convert_element_type_default_38 = torch.ops.prims.convert_element_type.default(sum_82, torch.float32); sum_82 = None + reduce_scatter_tensor_122 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_38, 'avg', 256, '0'); convert_element_type_default_38 = None + wait_tensor_413 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_122); reduce_scatter_tensor_122 = None + view_1413 = torch.ops.aten.view.default(add_223, [16384, 4096]) + permute_785 = torch.ops.aten.permute.default(view_1413, [1, 0]) + mm_415 = torch.ops.aten.mm.default(permute_785, view_635); permute_785 = view_635 = None + permute_787 = torch.ops.aten.permute.default(permute_205, [1, 0]); permute_205 = None + mm_416 = torch.ops.aten.mm.default(view_1413, permute_787); view_1413 = permute_787 = None + view_1414 = torch.ops.aten.view.default(mm_416, [2, 8192, 4096]); mm_416 = None + convert_element_type_1804 = torch.ops.prims.convert_element_type.default(mm_415, torch.float32); mm_415 = None + reduce_scatter_tensor_123 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1804, 'avg', 256, '0'); convert_element_type_1804 = None + wait_tensor_414 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_123); reduce_scatter_tensor_123 = None + view_1415 = torch.ops.aten.view.default(view_1414, [2, 8192, 32, 128]); view_1414 = None + permute_789 = torch.ops.aten.permute.default(view_1415, [0, 2, 1, 3]); view_1415 = None + convert_element_type_595 = torch.ops.prims.convert_element_type.default(primals_166, torch.bfloat16); primals_166 = None + all_gather_into_tensor_163 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_595, 256, '0'); convert_element_type_595 = None + wait_tensor_163 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_163); all_gather_into_tensor_163 = None + 
convert_element_type_596 = torch.ops.prims.convert_element_type.default(add_71, torch.float32); add_71 = None + pow_37 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_596, 2) + mean_36 = torch.ops.aten.mean.dim(pow_37, [2], True); pow_37 = None + add_72 = torch.ops.aten.add.Scalar(mean_36, 1e-05); mean_36 = None + rsqrt_36 = torch.ops.aten.rsqrt.default(add_72); add_72 = None + mul_144 = torch.ops.aten.mul.Tensor(convert_element_type_596, rsqrt_36); convert_element_type_596 = None + mul_145 = torch.ops.aten.mul.Tensor(mul_144, wait_tensor_163) + convert_element_type_597 = torch.ops.prims.convert_element_type.default(mul_145, torch.bfloat16); mul_145 = None + view_615 = torch.ops.aten.view.default(convert_element_type_597, [16384, 4096]); convert_element_type_597 = None + view_616 = torch.ops.aten.view.default(mm_126, [2, 8192, 4096]); mm_126 = None + convert_element_type_601 = torch.ops.prims.convert_element_type.default(primals_168, torch.bfloat16); primals_168 = None + all_gather_into_tensor_165 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_601, 256, '0'); convert_element_type_601 = None + wait_tensor_165 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_165); all_gather_into_tensor_165 = None + permute_199 = torch.ops.aten.permute.default(wait_tensor_165, [1, 0]); wait_tensor_165 = None + mm_127 = torch.ops.aten.mm.default(view_615, permute_199) + view_619 = torch.ops.aten.view.default(mm_127, [2, 8192, 1024]); mm_127 = None + view_622 = torch.ops.aten.view.default(mm_128, [2, 8192, 1024]); mm_128 = None + view_623 = torch.ops.aten.view.default(view_616, [2, 8192, -1, 128]); view_616 = None + view_624 = torch.ops.aten.view.default(view_619, [2, 8192, -1, 128]); view_619 = None + view_625 = torch.ops.aten.view.default(view_622, [2, 8192, -1, 128]); view_622 = None + convert_element_type_607 = torch.ops.prims.convert_element_type.default(view_623, torch.float32); view_623 = None + view_626 = 
torch.ops.aten.view.default(convert_element_type_607, [2, 8192, 32, -1, 2]); convert_element_type_607 = None + view_as_complex_36 = torch.ops.aten.view_as_complex.default(view_626); view_626 = None + convert_element_type_608 = torch.ops.prims.convert_element_type.default(view_624, torch.float32); view_624 = None + view_627 = torch.ops.aten.view.default(convert_element_type_608, [2, 8192, 8, -1, 2]); convert_element_type_608 = None + view_as_complex_37 = torch.ops.aten.view_as_complex.default(view_627); view_627 = None + mul_146 = torch.ops.aten.mul.Tensor(view_as_complex_36, view_16); view_as_complex_36 = None + view_as_real_36 = torch.ops.aten.view_as_real.default(mul_146); mul_146 = None + view_629 = torch.ops.aten.view.default(view_as_real_36, [2, 8192, 32, 128]); view_as_real_36 = None + mul_147 = torch.ops.aten.mul.Tensor(view_as_complex_37, view_16); view_as_complex_37 = None + view_as_real_37 = torch.ops.aten.view_as_real.default(mul_147); mul_147 = None + view_630 = torch.ops.aten.view.default(view_as_real_37, [2, 8192, 8, 128]); view_as_real_37 = None + convert_element_type_609 = torch.ops.prims.convert_element_type.default(view_629, torch.bfloat16); view_629 = None + convert_element_type_610 = torch.ops.prims.convert_element_type.default(view_630, torch.bfloat16); view_630 = None + unsqueeze_36 = torch.ops.aten.unsqueeze.default(convert_element_type_610, 3); convert_element_type_610 = None + expand_36 = torch.ops.aten.expand.default(unsqueeze_36, [2, 8192, 8, 4, 128]); unsqueeze_36 = None + clone_36 = torch.ops.aten.clone.default(expand_36, memory_format = torch.contiguous_format); expand_36 = None + view_631 = torch.ops.aten.view.default(clone_36, [2, 8192, 32, 128]); clone_36 = None + unsqueeze_37 = torch.ops.aten.unsqueeze.default(view_625, 3); view_625 = None + expand_37 = torch.ops.aten.expand.default(unsqueeze_37, [2, 8192, 8, 4, 128]); unsqueeze_37 = None + clone_37 = torch.ops.aten.clone.default(expand_37, memory_format = torch.contiguous_format); 
expand_37 = None + view_632 = torch.ops.aten.view.default(clone_37, [2, 8192, 32, 128]); clone_37 = None + permute_201 = torch.ops.aten.permute.default(convert_element_type_609, [0, 2, 1, 3]); convert_element_type_609 = None + permute_202 = torch.ops.aten.permute.default(view_631, [0, 2, 1, 3]); view_631 = None + permute_203 = torch.ops.aten.permute.default(view_632, [0, 2, 1, 3]); view_632 = None + _scaled_dot_product_cudnn_attention_backward_13 = torch.ops.aten._scaled_dot_product_cudnn_attention_backward.default(permute_789, permute_201, permute_202, permute_203, getitem_162, getitem_163, getitem_168, getitem_169, None, None, None, 8192, 8192, 0.0, True); permute_789 = permute_201 = permute_202 = permute_203 = getitem_162 = getitem_163 = getitem_168 = getitem_169 = None + getitem_327 = _scaled_dot_product_cudnn_attention_backward_13[0] + getitem_328 = _scaled_dot_product_cudnn_attention_backward_13[1] + getitem_329 = _scaled_dot_product_cudnn_attention_backward_13[2]; _scaled_dot_product_cudnn_attention_backward_13 = None + permute_790 = torch.ops.aten.permute.default(getitem_329, [0, 2, 1, 3]); getitem_329 = None + permute_791 = torch.ops.aten.permute.default(getitem_328, [0, 2, 1, 3]); getitem_328 = None + permute_792 = torch.ops.aten.permute.default(getitem_327, [0, 2, 1, 3]); getitem_327 = None + view_1416 = torch.ops.aten.view.default(permute_790, [2, 8192, 8, 4, 128]); permute_790 = None + sum_83 = torch.ops.aten.sum.dim_IntList(view_1416, [3], True); view_1416 = None + squeeze_26 = torch.ops.aten.squeeze.dim(sum_83, 3); sum_83 = None + view_1417 = torch.ops.aten.view.default(permute_791, [2, 8192, 8, 4, 128]); permute_791 = None + sum_84 = torch.ops.aten.sum.dim_IntList(view_1417, [3], True); view_1417 = None + squeeze_27 = torch.ops.aten.squeeze.dim(sum_84, 3); sum_84 = None + convert_element_type_1805 = torch.ops.prims.convert_element_type.default(squeeze_27, torch.float32); squeeze_27 = None + convert_element_type_1806 = 
torch.ops.prims.convert_element_type.default(permute_792, torch.float32); permute_792 = None + view_1418 = torch.ops.aten.view.default(convert_element_type_1805, [2, 8192, 8, 64, 2]); convert_element_type_1805 = None + view_as_complex_90 = torch.ops.aten.view_as_complex.default(view_1418); view_1418 = None + mul_536 = torch.ops.aten.mul.Tensor(view_as_complex_90, _conj); view_as_complex_90 = None + view_1419 = torch.ops.aten.view.default(convert_element_type_1806, [2, 8192, 32, 64, 2]); convert_element_type_1806 = None + view_as_complex_91 = torch.ops.aten.view_as_complex.default(view_1419); view_1419 = None + mul_537 = torch.ops.aten.mul.Tensor(view_as_complex_91, _conj); view_as_complex_91 = None + view_as_real_90 = torch.ops.aten.view_as_real.default(mul_536); mul_536 = None + view_1420 = torch.ops.aten.view.default(view_as_real_90, [2, 8192, 8, 128]); view_as_real_90 = None + convert_element_type_1807 = torch.ops.prims.convert_element_type.default(view_1420, torch.bfloat16); view_1420 = None + view_as_real_91 = torch.ops.aten.view_as_real.default(mul_537); mul_537 = None + view_1421 = torch.ops.aten.view.default(view_as_real_91, [2, 8192, 32, 128]); view_as_real_91 = None + convert_element_type_1808 = torch.ops.prims.convert_element_type.default(view_1421, torch.bfloat16); view_1421 = None + view_1422 = torch.ops.aten.view.default(squeeze_26, [2, 8192, 1024]); squeeze_26 = None + view_1423 = torch.ops.aten.view.default(convert_element_type_1807, [2, 8192, 1024]); convert_element_type_1807 = None + view_1424 = torch.ops.aten.view.default(convert_element_type_1808, [2, 8192, 4096]); convert_element_type_1808 = None + view_1425 = torch.ops.aten.view.default(view_1422, [16384, 1024]); view_1422 = None + permute_793 = torch.ops.aten.permute.default(view_1425, [1, 0]) + mm_417 = torch.ops.aten.mm.default(permute_793, view_615); permute_793 = None + convert_element_type_604 = torch.ops.prims.convert_element_type.default(primals_169, torch.bfloat16); primals_169 = None 
+ all_gather_into_tensor_166 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_604, 256, '0'); convert_element_type_604 = None + wait_tensor_166 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_166); all_gather_into_tensor_166 = None + permute_200 = torch.ops.aten.permute.default(wait_tensor_166, [1, 0]); wait_tensor_166 = None + permute_795 = torch.ops.aten.permute.default(permute_200, [1, 0]); permute_200 = None + mm_418 = torch.ops.aten.mm.default(view_1425, permute_795); view_1425 = permute_795 = None + view_1426 = torch.ops.aten.view.default(mm_418, [2, 8192, 4096]); mm_418 = None + convert_element_type_1813 = torch.ops.prims.convert_element_type.default(mm_417, torch.float32); mm_417 = None + reduce_scatter_tensor_124 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1813, 'avg', 256, '0'); convert_element_type_1813 = None + wait_tensor_415 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_124); reduce_scatter_tensor_124 = None + view_1427 = torch.ops.aten.view.default(view_1423, [16384, 1024]); view_1423 = None + permute_797 = torch.ops.aten.permute.default(view_1427, [1, 0]) + mm_419 = torch.ops.aten.mm.default(permute_797, view_615); permute_797 = None + permute_799 = torch.ops.aten.permute.default(permute_199, [1, 0]); permute_199 = None + mm_420 = torch.ops.aten.mm.default(view_1427, permute_799); view_1427 = permute_799 = None + view_1428 = torch.ops.aten.view.default(mm_420, [2, 8192, 4096]); mm_420 = None + add_224 = torch.ops.aten.add.Tensor(view_1426, view_1428); view_1426 = view_1428 = None + convert_element_type_1818 = torch.ops.prims.convert_element_type.default(mm_419, torch.float32); mm_419 = None + reduce_scatter_tensor_125 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1818, 'avg', 256, '0'); convert_element_type_1818 = None + wait_tensor_416 = 
torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_125); reduce_scatter_tensor_125 = None + view_1429 = torch.ops.aten.view.default(view_1424, [16384, 4096]); view_1424 = None + permute_801 = torch.ops.aten.permute.default(view_1429, [1, 0]) + mm_421 = torch.ops.aten.mm.default(permute_801, view_615); permute_801 = view_615 = None + convert_element_type_598 = torch.ops.prims.convert_element_type.default(primals_167, torch.bfloat16); primals_167 = None + all_gather_into_tensor_164 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_598, 256, '0'); convert_element_type_598 = None + wait_tensor_164 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_164); all_gather_into_tensor_164 = None + permute_198 = torch.ops.aten.permute.default(wait_tensor_164, [1, 0]); wait_tensor_164 = None + permute_803 = torch.ops.aten.permute.default(permute_198, [1, 0]); permute_198 = None + mm_422 = torch.ops.aten.mm.default(view_1429, permute_803); view_1429 = permute_803 = None + view_1430 = torch.ops.aten.view.default(mm_422, [2, 8192, 4096]); mm_422 = None + add_225 = torch.ops.aten.add.Tensor(add_224, view_1430); add_224 = view_1430 = None + convert_element_type_1823 = torch.ops.prims.convert_element_type.default(mm_421, torch.float32); mm_421 = None + reduce_scatter_tensor_126 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1823, 'avg', 256, '0'); convert_element_type_1823 = None + wait_tensor_417 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_126); reduce_scatter_tensor_126 = None + convert_element_type_1824 = torch.ops.prims.convert_element_type.default(add_225, torch.float32); add_225 = None + convert_element_type_1826 = torch.ops.prims.convert_element_type.default(wait_tensor_163, torch.float32); wait_tensor_163 = None + mul_538 = torch.ops.aten.mul.Tensor(convert_element_type_1824, convert_element_type_1826); convert_element_type_1826 = None + mul_540 = 
torch.ops.aten.mul.Tensor(mul_144, mul_538) + sum_85 = torch.ops.aten.sum.dim_IntList(mul_540, [2], True); mul_540 = None + div_28 = torch.ops.aten.div.Tensor(mul_144, 4096) + mul_541 = torch.ops.aten.mul.Tensor(div_28, sum_85); div_28 = sum_85 = None + sub_42 = torch.ops.aten.sub.Tensor(mul_538, mul_541); mul_538 = mul_541 = None + mul_542 = torch.ops.aten.mul.Tensor(sub_42, rsqrt_36); sub_42 = rsqrt_36 = None + mul_543 = torch.ops.aten.mul.Tensor(convert_element_type_1824, mul_144); convert_element_type_1824 = mul_144 = None + sum_86 = torch.ops.aten.sum.dim_IntList(mul_543, [0, 1]); mul_543 = None + convert_element_type_1827 = torch.ops.prims.convert_element_type.default(mul_542, torch.bfloat16); mul_542 = None + add_226 = torch.ops.aten.add.Tensor(add_223, convert_element_type_1827); add_223 = convert_element_type_1827 = None + convert_element_type_default_37 = torch.ops.prims.convert_element_type.default(sum_86, torch.float32); sum_86 = None + reduce_scatter_tensor_127 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_37, 'avg', 256, '0'); convert_element_type_default_37 = None + wait_tensor_418 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_127); reduce_scatter_tensor_127 = None + view_1431 = torch.ops.aten.view.default(add_226, [16384, 4096]) + permute_805 = torch.ops.aten.permute.default(view_1431, [1, 0]) + permute_193 = torch.ops.aten.permute.default(getitem_153, [0, 2, 1, 3]) + view_599 = torch.ops.aten.view.default(permute_193, [2, 8192, -1]); permute_193 = None + convert_element_type_578 = torch.ops.prims.convert_element_type.default(primals_161, torch.bfloat16); primals_161 = None + all_gather_into_tensor_158 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_578, 256, '0'); convert_element_type_578 = None + wait_tensor_158 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_158); all_gather_into_tensor_158 = None + permute_194 = 
torch.ops.aten.permute.default(wait_tensor_158, [1, 0]); wait_tensor_158 = None + view_601 = torch.ops.aten.view.default(view_599, [16384, 4096]); view_599 = None + mm_122 = torch.ops.aten.mm.default(view_601, permute_194) + view_602 = torch.ops.aten.view.default(mm_122, [2, 8192, 4096]); mm_122 = None + add_69 = torch.ops.aten.add.Tensor(add_67, view_602); view_602 = None + convert_element_type_581 = torch.ops.prims.convert_element_type.default(primals_162, torch.bfloat16); primals_162 = None + all_gather_into_tensor_159 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_581, 256, '0'); convert_element_type_581 = None + wait_tensor_159 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_159); all_gather_into_tensor_159 = None + convert_element_type_582 = torch.ops.prims.convert_element_type.default(add_69, torch.float32); add_69 = None + pow_36 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_582, 2) + mean_35 = torch.ops.aten.mean.dim(pow_36, [2], True); pow_36 = None + add_70 = torch.ops.aten.add.Scalar(mean_35, 1e-05); mean_35 = None + rsqrt_35 = torch.ops.aten.rsqrt.default(add_70); add_70 = None + mul_140 = torch.ops.aten.mul.Tensor(convert_element_type_582, rsqrt_35); convert_element_type_582 = None + mul_141 = torch.ops.aten.mul.Tensor(mul_140, wait_tensor_159) + convert_element_type_583 = torch.ops.prims.convert_element_type.default(mul_141, torch.bfloat16); mul_141 = None + view_605 = torch.ops.aten.view.default(convert_element_type_583, [16384, 4096]); convert_element_type_583 = None + view_606 = torch.ops.aten.view.default(mm_123, [2, 8192, 14336]); mm_123 = None + convert_element_type_587 = torch.ops.prims.convert_element_type.default(view_606, torch.float32); view_606 = None + sigmoid_17 = torch.ops.aten.sigmoid.default(convert_element_type_587) + mul_142 = torch.ops.aten.mul.Tensor(convert_element_type_587, sigmoid_17); sigmoid_17 = None + convert_element_type_588 = 
torch.ops.prims.convert_element_type.default(mul_142, torch.bfloat16); mul_142 = None + convert_element_type_589 = torch.ops.prims.convert_element_type.default(primals_164, torch.bfloat16); primals_164 = None + all_gather_into_tensor_161 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_589, 256, '0'); convert_element_type_589 = None + wait_tensor_161 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_161); all_gather_into_tensor_161 = None + permute_196 = torch.ops.aten.permute.default(wait_tensor_161, [1, 0]); wait_tensor_161 = None + mm_124 = torch.ops.aten.mm.default(view_605, permute_196) + view_609 = torch.ops.aten.view.default(mm_124, [2, 8192, 14336]); mm_124 = None + mul_143 = torch.ops.aten.mul.Tensor(convert_element_type_588, view_609) + view_611 = torch.ops.aten.view.default(mul_143, [16384, 14336]); mul_143 = None + mm_423 = torch.ops.aten.mm.default(permute_805, view_611); permute_805 = view_611 = None + convert_element_type_592 = torch.ops.prims.convert_element_type.default(primals_165, torch.bfloat16); primals_165 = None + all_gather_into_tensor_162 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_592, 256, '0'); convert_element_type_592 = None + wait_tensor_162 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_162); all_gather_into_tensor_162 = None + permute_197 = torch.ops.aten.permute.default(wait_tensor_162, [1, 0]); wait_tensor_162 = None + permute_807 = torch.ops.aten.permute.default(permute_197, [1, 0]); permute_197 = None + mm_424 = torch.ops.aten.mm.default(view_1431, permute_807); view_1431 = permute_807 = None + view_1432 = torch.ops.aten.view.default(mm_424, [2, 8192, 14336]); mm_424 = None + convert_element_type_1834 = torch.ops.prims.convert_element_type.default(mm_423, torch.float32); mm_423 = None + reduce_scatter_tensor_128 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1834, 'avg', 256, '0'); 
convert_element_type_1834 = None + wait_tensor_419 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_128); reduce_scatter_tensor_128 = None + mul_544 = torch.ops.aten.mul.Tensor(view_1432, convert_element_type_588); convert_element_type_588 = None + mul_545 = torch.ops.aten.mul.Tensor(view_1432, view_609); view_1432 = view_609 = None + view_1433 = torch.ops.aten.view.default(mul_544, [16384, 14336]); mul_544 = None + permute_809 = torch.ops.aten.permute.default(view_1433, [1, 0]) + mm_425 = torch.ops.aten.mm.default(permute_809, view_605); permute_809 = None + permute_811 = torch.ops.aten.permute.default(permute_196, [1, 0]); permute_196 = None + mm_426 = torch.ops.aten.mm.default(view_1433, permute_811); view_1433 = permute_811 = None + view_1434 = torch.ops.aten.view.default(mm_426, [2, 8192, 4096]); mm_426 = None + convert_element_type_1839 = torch.ops.prims.convert_element_type.default(mm_425, torch.float32); mm_425 = None + reduce_scatter_tensor_129 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1839, 'avg', 256, '0'); convert_element_type_1839 = None + wait_tensor_420 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_129); reduce_scatter_tensor_129 = None + convert_element_type_1840 = torch.ops.prims.convert_element_type.default(mul_545, torch.float32); mul_545 = None + neg_14 = torch.ops.aten.neg.default(convert_element_type_587) + exp_14 = torch.ops.aten.exp.default(neg_14); neg_14 = None + add_227 = torch.ops.aten.add.Tensor(exp_14, 1); exp_14 = None + reciprocal_14 = torch.ops.aten.reciprocal.default(add_227); add_227 = None + mul_546 = torch.ops.aten.mul.Tensor(reciprocal_14, 1); reciprocal_14 = None + mul_547 = torch.ops.aten.mul.Tensor(convert_element_type_1840, mul_546); convert_element_type_1840 = None + sub_43 = torch.ops.aten.sub.Tensor(1, mul_546); mul_546 = None + mul_548 = torch.ops.aten.mul.Tensor(convert_element_type_587, sub_43); convert_element_type_587 = sub_43 = 
None + add_228 = torch.ops.aten.add.Tensor(mul_548, 1); mul_548 = None + mul_549 = torch.ops.aten.mul.Tensor(mul_547, add_228); mul_547 = add_228 = None + convert_element_type_1842 = torch.ops.prims.convert_element_type.default(mul_549, torch.bfloat16); mul_549 = None + view_1435 = torch.ops.aten.view.default(convert_element_type_1842, [16384, 14336]); convert_element_type_1842 = None + permute_813 = torch.ops.aten.permute.default(view_1435, [1, 0]) + mm_427 = torch.ops.aten.mm.default(permute_813, view_605); permute_813 = view_605 = None + convert_element_type_584 = torch.ops.prims.convert_element_type.default(primals_163, torch.bfloat16); primals_163 = None + all_gather_into_tensor_160 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_584, 256, '0'); convert_element_type_584 = None + wait_tensor_160 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_160); all_gather_into_tensor_160 = None + permute_195 = torch.ops.aten.permute.default(wait_tensor_160, [1, 0]); wait_tensor_160 = None + permute_815 = torch.ops.aten.permute.default(permute_195, [1, 0]); permute_195 = None + mm_428 = torch.ops.aten.mm.default(view_1435, permute_815); view_1435 = permute_815 = None + view_1436 = torch.ops.aten.view.default(mm_428, [2, 8192, 4096]); mm_428 = None + add_229 = torch.ops.aten.add.Tensor(view_1434, view_1436); view_1434 = view_1436 = None + convert_element_type_1847 = torch.ops.prims.convert_element_type.default(mm_427, torch.float32); mm_427 = None + reduce_scatter_tensor_130 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1847, 'avg', 256, '0'); convert_element_type_1847 = None + wait_tensor_421 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_130); reduce_scatter_tensor_130 = None + convert_element_type_1848 = torch.ops.prims.convert_element_type.default(add_229, torch.float32); add_229 = None + convert_element_type_1850 = 
torch.ops.prims.convert_element_type.default(wait_tensor_159, torch.float32); wait_tensor_159 = None + mul_550 = torch.ops.aten.mul.Tensor(convert_element_type_1848, convert_element_type_1850); convert_element_type_1850 = None + mul_552 = torch.ops.aten.mul.Tensor(mul_140, mul_550) + sum_87 = torch.ops.aten.sum.dim_IntList(mul_552, [2], True); mul_552 = None + div_29 = torch.ops.aten.div.Tensor(mul_140, 4096) + mul_553 = torch.ops.aten.mul.Tensor(div_29, sum_87); div_29 = sum_87 = None + sub_44 = torch.ops.aten.sub.Tensor(mul_550, mul_553); mul_550 = mul_553 = None + mul_554 = torch.ops.aten.mul.Tensor(sub_44, rsqrt_35); sub_44 = rsqrt_35 = None + mul_555 = torch.ops.aten.mul.Tensor(convert_element_type_1848, mul_140); convert_element_type_1848 = mul_140 = None + sum_88 = torch.ops.aten.sum.dim_IntList(mul_555, [0, 1]); mul_555 = None + convert_element_type_1851 = torch.ops.prims.convert_element_type.default(mul_554, torch.bfloat16); mul_554 = None + add_230 = torch.ops.aten.add.Tensor(add_226, convert_element_type_1851); add_226 = convert_element_type_1851 = None + convert_element_type_default_36 = torch.ops.prims.convert_element_type.default(sum_88, torch.float32); sum_88 = None + reduce_scatter_tensor_131 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_36, 'avg', 256, '0'); convert_element_type_default_36 = None + wait_tensor_422 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_131); reduce_scatter_tensor_131 = None + view_1437 = torch.ops.aten.view.default(add_230, [16384, 4096]) + permute_817 = torch.ops.aten.permute.default(view_1437, [1, 0]) + mm_429 = torch.ops.aten.mm.default(permute_817, view_601); permute_817 = view_601 = None + permute_819 = torch.ops.aten.permute.default(permute_194, [1, 0]); permute_194 = None + mm_430 = torch.ops.aten.mm.default(view_1437, permute_819); view_1437 = permute_819 = None + view_1438 = torch.ops.aten.view.default(mm_430, [2, 8192, 4096]); mm_430 = None + 
convert_element_type_1858 = torch.ops.prims.convert_element_type.default(mm_429, torch.float32); mm_429 = None + reduce_scatter_tensor_132 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1858, 'avg', 256, '0'); convert_element_type_1858 = None + wait_tensor_423 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_132); reduce_scatter_tensor_132 = None + view_1439 = torch.ops.aten.view.default(view_1438, [2, 8192, 32, 128]); view_1438 = None + permute_821 = torch.ops.aten.permute.default(view_1439, [0, 2, 1, 3]); view_1439 = None + convert_element_type_562 = torch.ops.prims.convert_element_type.default(primals_157, torch.bfloat16); primals_157 = None + all_gather_into_tensor_154 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_562, 256, '0'); convert_element_type_562 = None + wait_tensor_154 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_154); all_gather_into_tensor_154 = None + convert_element_type_563 = torch.ops.prims.convert_element_type.default(add_67, torch.float32); add_67 = None + pow_35 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_563, 2) + mean_34 = torch.ops.aten.mean.dim(pow_35, [2], True); pow_35 = None + add_68 = torch.ops.aten.add.Scalar(mean_34, 1e-05); mean_34 = None + rsqrt_34 = torch.ops.aten.rsqrt.default(add_68); add_68 = None + mul_136 = torch.ops.aten.mul.Tensor(convert_element_type_563, rsqrt_34); convert_element_type_563 = None + mul_137 = torch.ops.aten.mul.Tensor(mul_136, wait_tensor_154) + convert_element_type_564 = torch.ops.prims.convert_element_type.default(mul_137, torch.bfloat16); mul_137 = None + view_581 = torch.ops.aten.view.default(convert_element_type_564, [16384, 4096]); convert_element_type_564 = None + view_582 = torch.ops.aten.view.default(mm_119, [2, 8192, 4096]); mm_119 = None + convert_element_type_568 = torch.ops.prims.convert_element_type.default(primals_159, torch.bfloat16); primals_159 = None + 
all_gather_into_tensor_156 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_568, 256, '0'); convert_element_type_568 = None + wait_tensor_156 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_156); all_gather_into_tensor_156 = None + permute_188 = torch.ops.aten.permute.default(wait_tensor_156, [1, 0]); wait_tensor_156 = None + mm_120 = torch.ops.aten.mm.default(view_581, permute_188) + view_585 = torch.ops.aten.view.default(mm_120, [2, 8192, 1024]); mm_120 = None + view_588 = torch.ops.aten.view.default(mm_121, [2, 8192, 1024]); mm_121 = None + view_589 = torch.ops.aten.view.default(view_582, [2, 8192, -1, 128]); view_582 = None + view_590 = torch.ops.aten.view.default(view_585, [2, 8192, -1, 128]); view_585 = None + view_591 = torch.ops.aten.view.default(view_588, [2, 8192, -1, 128]); view_588 = None + convert_element_type_574 = torch.ops.prims.convert_element_type.default(view_589, torch.float32); view_589 = None + view_592 = torch.ops.aten.view.default(convert_element_type_574, [2, 8192, 32, -1, 2]); convert_element_type_574 = None + view_as_complex_34 = torch.ops.aten.view_as_complex.default(view_592); view_592 = None + convert_element_type_575 = torch.ops.prims.convert_element_type.default(view_590, torch.float32); view_590 = None + view_593 = torch.ops.aten.view.default(convert_element_type_575, [2, 8192, 8, -1, 2]); convert_element_type_575 = None + view_as_complex_35 = torch.ops.aten.view_as_complex.default(view_593); view_593 = None + mul_138 = torch.ops.aten.mul.Tensor(view_as_complex_34, view_16); view_as_complex_34 = None + view_as_real_34 = torch.ops.aten.view_as_real.default(mul_138); mul_138 = None + view_595 = torch.ops.aten.view.default(view_as_real_34, [2, 8192, 32, 128]); view_as_real_34 = None + mul_139 = torch.ops.aten.mul.Tensor(view_as_complex_35, view_16); view_as_complex_35 = None + view_as_real_35 = torch.ops.aten.view_as_real.default(mul_139); mul_139 = None + view_596 = 
torch.ops.aten.view.default(view_as_real_35, [2, 8192, 8, 128]); view_as_real_35 = None + convert_element_type_576 = torch.ops.prims.convert_element_type.default(view_595, torch.bfloat16); view_595 = None + convert_element_type_577 = torch.ops.prims.convert_element_type.default(view_596, torch.bfloat16); view_596 = None + unsqueeze_34 = torch.ops.aten.unsqueeze.default(convert_element_type_577, 3); convert_element_type_577 = None + expand_34 = torch.ops.aten.expand.default(unsqueeze_34, [2, 8192, 8, 4, 128]); unsqueeze_34 = None + clone_34 = torch.ops.aten.clone.default(expand_34, memory_format = torch.contiguous_format); expand_34 = None + view_597 = torch.ops.aten.view.default(clone_34, [2, 8192, 32, 128]); clone_34 = None + unsqueeze_35 = torch.ops.aten.unsqueeze.default(view_591, 3); view_591 = None + expand_35 = torch.ops.aten.expand.default(unsqueeze_35, [2, 8192, 8, 4, 128]); unsqueeze_35 = None + clone_35 = torch.ops.aten.clone.default(expand_35, memory_format = torch.contiguous_format); expand_35 = None + view_598 = torch.ops.aten.view.default(clone_35, [2, 8192, 32, 128]); clone_35 = None + permute_190 = torch.ops.aten.permute.default(convert_element_type_576, [0, 2, 1, 3]); convert_element_type_576 = None + permute_191 = torch.ops.aten.permute.default(view_597, [0, 2, 1, 3]); view_597 = None + permute_192 = torch.ops.aten.permute.default(view_598, [0, 2, 1, 3]); view_598 = None + _scaled_dot_product_cudnn_attention_backward_14 = torch.ops.aten._scaled_dot_product_cudnn_attention_backward.default(permute_821, permute_190, permute_191, permute_192, getitem_153, getitem_154, getitem_159, getitem_160, None, None, None, 8192, 8192, 0.0, True); permute_821 = permute_190 = permute_191 = permute_192 = getitem_153 = getitem_154 = getitem_159 = getitem_160 = None + getitem_330 = _scaled_dot_product_cudnn_attention_backward_14[0] + getitem_331 = _scaled_dot_product_cudnn_attention_backward_14[1] + getitem_332 = _scaled_dot_product_cudnn_attention_backward_14[2]; 
_scaled_dot_product_cudnn_attention_backward_14 = None + permute_822 = torch.ops.aten.permute.default(getitem_332, [0, 2, 1, 3]); getitem_332 = None + permute_823 = torch.ops.aten.permute.default(getitem_331, [0, 2, 1, 3]); getitem_331 = None + permute_824 = torch.ops.aten.permute.default(getitem_330, [0, 2, 1, 3]); getitem_330 = None + view_1440 = torch.ops.aten.view.default(permute_822, [2, 8192, 8, 4, 128]); permute_822 = None + sum_89 = torch.ops.aten.sum.dim_IntList(view_1440, [3], True); view_1440 = None + squeeze_28 = torch.ops.aten.squeeze.dim(sum_89, 3); sum_89 = None + view_1441 = torch.ops.aten.view.default(permute_823, [2, 8192, 8, 4, 128]); permute_823 = None + sum_90 = torch.ops.aten.sum.dim_IntList(view_1441, [3], True); view_1441 = None + squeeze_29 = torch.ops.aten.squeeze.dim(sum_90, 3); sum_90 = None + convert_element_type_1859 = torch.ops.prims.convert_element_type.default(squeeze_29, torch.float32); squeeze_29 = None + convert_element_type_1860 = torch.ops.prims.convert_element_type.default(permute_824, torch.float32); permute_824 = None + view_1442 = torch.ops.aten.view.default(convert_element_type_1859, [2, 8192, 8, 64, 2]); convert_element_type_1859 = None + view_as_complex_92 = torch.ops.aten.view_as_complex.default(view_1442); view_1442 = None + mul_556 = torch.ops.aten.mul.Tensor(view_as_complex_92, _conj); view_as_complex_92 = None + view_1443 = torch.ops.aten.view.default(convert_element_type_1860, [2, 8192, 32, 64, 2]); convert_element_type_1860 = None + view_as_complex_93 = torch.ops.aten.view_as_complex.default(view_1443); view_1443 = None + mul_557 = torch.ops.aten.mul.Tensor(view_as_complex_93, _conj); view_as_complex_93 = None + view_as_real_92 = torch.ops.aten.view_as_real.default(mul_556); mul_556 = None + view_1444 = torch.ops.aten.view.default(view_as_real_92, [2, 8192, 8, 128]); view_as_real_92 = None + convert_element_type_1861 = torch.ops.prims.convert_element_type.default(view_1444, torch.bfloat16); view_1444 = None + 
view_as_real_93 = torch.ops.aten.view_as_real.default(mul_557); mul_557 = None + view_1445 = torch.ops.aten.view.default(view_as_real_93, [2, 8192, 32, 128]); view_as_real_93 = None + convert_element_type_1862 = torch.ops.prims.convert_element_type.default(view_1445, torch.bfloat16); view_1445 = None + view_1446 = torch.ops.aten.view.default(squeeze_28, [2, 8192, 1024]); squeeze_28 = None + view_1447 = torch.ops.aten.view.default(convert_element_type_1861, [2, 8192, 1024]); convert_element_type_1861 = None + view_1448 = torch.ops.aten.view.default(convert_element_type_1862, [2, 8192, 4096]); convert_element_type_1862 = None + view_1449 = torch.ops.aten.view.default(view_1446, [16384, 1024]); view_1446 = None + permute_825 = torch.ops.aten.permute.default(view_1449, [1, 0]) + mm_431 = torch.ops.aten.mm.default(permute_825, view_581); permute_825 = None + convert_element_type_571 = torch.ops.prims.convert_element_type.default(primals_160, torch.bfloat16); primals_160 = None + all_gather_into_tensor_157 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_571, 256, '0'); convert_element_type_571 = None + wait_tensor_157 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_157); all_gather_into_tensor_157 = None + permute_189 = torch.ops.aten.permute.default(wait_tensor_157, [1, 0]); wait_tensor_157 = None + permute_827 = torch.ops.aten.permute.default(permute_189, [1, 0]); permute_189 = None + mm_432 = torch.ops.aten.mm.default(view_1449, permute_827); view_1449 = permute_827 = None + view_1450 = torch.ops.aten.view.default(mm_432, [2, 8192, 4096]); mm_432 = None + convert_element_type_1867 = torch.ops.prims.convert_element_type.default(mm_431, torch.float32); mm_431 = None + reduce_scatter_tensor_133 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1867, 'avg', 256, '0'); convert_element_type_1867 = None + wait_tensor_424 = 
torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_133); reduce_scatter_tensor_133 = None + view_1451 = torch.ops.aten.view.default(view_1447, [16384, 1024]); view_1447 = None + permute_829 = torch.ops.aten.permute.default(view_1451, [1, 0]) + mm_433 = torch.ops.aten.mm.default(permute_829, view_581); permute_829 = None + permute_831 = torch.ops.aten.permute.default(permute_188, [1, 0]); permute_188 = None + mm_434 = torch.ops.aten.mm.default(view_1451, permute_831); view_1451 = permute_831 = None + view_1452 = torch.ops.aten.view.default(mm_434, [2, 8192, 4096]); mm_434 = None + add_231 = torch.ops.aten.add.Tensor(view_1450, view_1452); view_1450 = view_1452 = None + convert_element_type_1872 = torch.ops.prims.convert_element_type.default(mm_433, torch.float32); mm_433 = None + reduce_scatter_tensor_134 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1872, 'avg', 256, '0'); convert_element_type_1872 = None + wait_tensor_425 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_134); reduce_scatter_tensor_134 = None + view_1453 = torch.ops.aten.view.default(view_1448, [16384, 4096]); view_1448 = None + permute_833 = torch.ops.aten.permute.default(view_1453, [1, 0]) + mm_435 = torch.ops.aten.mm.default(permute_833, view_581); permute_833 = view_581 = None + convert_element_type_565 = torch.ops.prims.convert_element_type.default(primals_158, torch.bfloat16); primals_158 = None + all_gather_into_tensor_155 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_565, 256, '0'); convert_element_type_565 = None + wait_tensor_155 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_155); all_gather_into_tensor_155 = None + permute_187 = torch.ops.aten.permute.default(wait_tensor_155, [1, 0]); wait_tensor_155 = None + permute_835 = torch.ops.aten.permute.default(permute_187, [1, 0]); permute_187 = None + mm_436 = torch.ops.aten.mm.default(view_1453, permute_835); 
view_1453 = permute_835 = None + view_1454 = torch.ops.aten.view.default(mm_436, [2, 8192, 4096]); mm_436 = None + add_232 = torch.ops.aten.add.Tensor(add_231, view_1454); add_231 = view_1454 = None + convert_element_type_1877 = torch.ops.prims.convert_element_type.default(mm_435, torch.float32); mm_435 = None + reduce_scatter_tensor_135 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1877, 'avg', 256, '0'); convert_element_type_1877 = None + wait_tensor_426 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_135); reduce_scatter_tensor_135 = None + convert_element_type_1878 = torch.ops.prims.convert_element_type.default(add_232, torch.float32); add_232 = None + convert_element_type_1880 = torch.ops.prims.convert_element_type.default(wait_tensor_154, torch.float32); wait_tensor_154 = None + mul_558 = torch.ops.aten.mul.Tensor(convert_element_type_1878, convert_element_type_1880); convert_element_type_1880 = None + mul_560 = torch.ops.aten.mul.Tensor(mul_136, mul_558) + sum_91 = torch.ops.aten.sum.dim_IntList(mul_560, [2], True); mul_560 = None + div_30 = torch.ops.aten.div.Tensor(mul_136, 4096) + mul_561 = torch.ops.aten.mul.Tensor(div_30, sum_91); div_30 = sum_91 = None + sub_45 = torch.ops.aten.sub.Tensor(mul_558, mul_561); mul_558 = mul_561 = None + mul_562 = torch.ops.aten.mul.Tensor(sub_45, rsqrt_34); sub_45 = rsqrt_34 = None + mul_563 = torch.ops.aten.mul.Tensor(convert_element_type_1878, mul_136); convert_element_type_1878 = mul_136 = None + sum_92 = torch.ops.aten.sum.dim_IntList(mul_563, [0, 1]); mul_563 = None + convert_element_type_1881 = torch.ops.prims.convert_element_type.default(mul_562, torch.bfloat16); mul_562 = None + add_233 = torch.ops.aten.add.Tensor(add_230, convert_element_type_1881); add_230 = convert_element_type_1881 = None + convert_element_type_default_35 = torch.ops.prims.convert_element_type.default(sum_92, torch.float32); sum_92 = None + reduce_scatter_tensor_136 = 
torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_35, 'avg', 256, '0'); convert_element_type_default_35 = None + wait_tensor_427 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_136); reduce_scatter_tensor_136 = None + view_1455 = torch.ops.aten.view.default(add_233, [16384, 4096]) + permute_837 = torch.ops.aten.permute.default(view_1455, [1, 0]) + permute_182 = torch.ops.aten.permute.default(getitem_144, [0, 2, 1, 3]) + view_565 = torch.ops.aten.view.default(permute_182, [2, 8192, -1]); permute_182 = None + convert_element_type_545 = torch.ops.prims.convert_element_type.default(primals_152, torch.bfloat16); primals_152 = None + all_gather_into_tensor_149 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_545, 256, '0'); convert_element_type_545 = None + wait_tensor_149 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_149); all_gather_into_tensor_149 = None + permute_183 = torch.ops.aten.permute.default(wait_tensor_149, [1, 0]); wait_tensor_149 = None + view_567 = torch.ops.aten.view.default(view_565, [16384, 4096]); view_565 = None + mm_115 = torch.ops.aten.mm.default(view_567, permute_183) + view_568 = torch.ops.aten.view.default(mm_115, [2, 8192, 4096]); mm_115 = None + add_65 = torch.ops.aten.add.Tensor(add_63, view_568); view_568 = None + convert_element_type_548 = torch.ops.prims.convert_element_type.default(primals_153, torch.bfloat16); primals_153 = None + all_gather_into_tensor_150 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_548, 256, '0'); convert_element_type_548 = None + wait_tensor_150 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_150); all_gather_into_tensor_150 = None + convert_element_type_549 = torch.ops.prims.convert_element_type.default(add_65, torch.float32); add_65 = None + pow_34 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_549, 2) + mean_33 = 
torch.ops.aten.mean.dim(pow_34, [2], True); pow_34 = None + add_66 = torch.ops.aten.add.Scalar(mean_33, 1e-05); mean_33 = None + rsqrt_33 = torch.ops.aten.rsqrt.default(add_66); add_66 = None + mul_132 = torch.ops.aten.mul.Tensor(convert_element_type_549, rsqrt_33); convert_element_type_549 = None + mul_133 = torch.ops.aten.mul.Tensor(mul_132, wait_tensor_150) + convert_element_type_550 = torch.ops.prims.convert_element_type.default(mul_133, torch.bfloat16); mul_133 = None + view_571 = torch.ops.aten.view.default(convert_element_type_550, [16384, 4096]); convert_element_type_550 = None + view_572 = torch.ops.aten.view.default(mm_116, [2, 8192, 14336]); mm_116 = None + convert_element_type_554 = torch.ops.prims.convert_element_type.default(view_572, torch.float32); view_572 = None + sigmoid_16 = torch.ops.aten.sigmoid.default(convert_element_type_554) + mul_134 = torch.ops.aten.mul.Tensor(convert_element_type_554, sigmoid_16); sigmoid_16 = None + convert_element_type_555 = torch.ops.prims.convert_element_type.default(mul_134, torch.bfloat16); mul_134 = None + convert_element_type_556 = torch.ops.prims.convert_element_type.default(primals_155, torch.bfloat16); primals_155 = None + all_gather_into_tensor_152 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_556, 256, '0'); convert_element_type_556 = None + wait_tensor_152 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_152); all_gather_into_tensor_152 = None + permute_185 = torch.ops.aten.permute.default(wait_tensor_152, [1, 0]); wait_tensor_152 = None + mm_117 = torch.ops.aten.mm.default(view_571, permute_185) + view_575 = torch.ops.aten.view.default(mm_117, [2, 8192, 14336]); mm_117 = None + mul_135 = torch.ops.aten.mul.Tensor(convert_element_type_555, view_575) + view_577 = torch.ops.aten.view.default(mul_135, [16384, 14336]); mul_135 = None + mm_437 = torch.ops.aten.mm.default(permute_837, view_577); permute_837 = view_577 = None + convert_element_type_559 = 
torch.ops.prims.convert_element_type.default(primals_156, torch.bfloat16); primals_156 = None + all_gather_into_tensor_153 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_559, 256, '0'); convert_element_type_559 = None + wait_tensor_153 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_153); all_gather_into_tensor_153 = None + permute_186 = torch.ops.aten.permute.default(wait_tensor_153, [1, 0]); wait_tensor_153 = None + permute_839 = torch.ops.aten.permute.default(permute_186, [1, 0]); permute_186 = None + mm_438 = torch.ops.aten.mm.default(view_1455, permute_839); view_1455 = permute_839 = None + view_1456 = torch.ops.aten.view.default(mm_438, [2, 8192, 14336]); mm_438 = None + convert_element_type_1888 = torch.ops.prims.convert_element_type.default(mm_437, torch.float32); mm_437 = None + reduce_scatter_tensor_137 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1888, 'avg', 256, '0'); convert_element_type_1888 = None + wait_tensor_428 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_137); reduce_scatter_tensor_137 = None + mul_564 = torch.ops.aten.mul.Tensor(view_1456, convert_element_type_555); convert_element_type_555 = None + mul_565 = torch.ops.aten.mul.Tensor(view_1456, view_575); view_1456 = view_575 = None + view_1457 = torch.ops.aten.view.default(mul_564, [16384, 14336]); mul_564 = None + permute_841 = torch.ops.aten.permute.default(view_1457, [1, 0]) + mm_439 = torch.ops.aten.mm.default(permute_841, view_571); permute_841 = None + permute_843 = torch.ops.aten.permute.default(permute_185, [1, 0]); permute_185 = None + mm_440 = torch.ops.aten.mm.default(view_1457, permute_843); view_1457 = permute_843 = None + view_1458 = torch.ops.aten.view.default(mm_440, [2, 8192, 4096]); mm_440 = None + convert_element_type_1893 = torch.ops.prims.convert_element_type.default(mm_439, torch.float32); mm_439 = None + reduce_scatter_tensor_138 = 
torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1893, 'avg', 256, '0'); convert_element_type_1893 = None + wait_tensor_429 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_138); reduce_scatter_tensor_138 = None + convert_element_type_1894 = torch.ops.prims.convert_element_type.default(mul_565, torch.float32); mul_565 = None + neg_15 = torch.ops.aten.neg.default(convert_element_type_554) + exp_15 = torch.ops.aten.exp.default(neg_15); neg_15 = None + add_234 = torch.ops.aten.add.Tensor(exp_15, 1); exp_15 = None + reciprocal_15 = torch.ops.aten.reciprocal.default(add_234); add_234 = None + mul_566 = torch.ops.aten.mul.Tensor(reciprocal_15, 1); reciprocal_15 = None + mul_567 = torch.ops.aten.mul.Tensor(convert_element_type_1894, mul_566); convert_element_type_1894 = None + sub_46 = torch.ops.aten.sub.Tensor(1, mul_566); mul_566 = None + mul_568 = torch.ops.aten.mul.Tensor(convert_element_type_554, sub_46); convert_element_type_554 = sub_46 = None + add_235 = torch.ops.aten.add.Tensor(mul_568, 1); mul_568 = None + mul_569 = torch.ops.aten.mul.Tensor(mul_567, add_235); mul_567 = add_235 = None + convert_element_type_1896 = torch.ops.prims.convert_element_type.default(mul_569, torch.bfloat16); mul_569 = None + view_1459 = torch.ops.aten.view.default(convert_element_type_1896, [16384, 14336]); convert_element_type_1896 = None + permute_845 = torch.ops.aten.permute.default(view_1459, [1, 0]) + mm_441 = torch.ops.aten.mm.default(permute_845, view_571); permute_845 = view_571 = None + convert_element_type_551 = torch.ops.prims.convert_element_type.default(primals_154, torch.bfloat16); primals_154 = None + all_gather_into_tensor_151 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_551, 256, '0'); convert_element_type_551 = None + wait_tensor_151 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_151); all_gather_into_tensor_151 = None + permute_184 = 
torch.ops.aten.permute.default(wait_tensor_151, [1, 0]); wait_tensor_151 = None + permute_847 = torch.ops.aten.permute.default(permute_184, [1, 0]); permute_184 = None + mm_442 = torch.ops.aten.mm.default(view_1459, permute_847); view_1459 = permute_847 = None + view_1460 = torch.ops.aten.view.default(mm_442, [2, 8192, 4096]); mm_442 = None + add_236 = torch.ops.aten.add.Tensor(view_1458, view_1460); view_1458 = view_1460 = None + convert_element_type_1901 = torch.ops.prims.convert_element_type.default(mm_441, torch.float32); mm_441 = None + reduce_scatter_tensor_139 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1901, 'avg', 256, '0'); convert_element_type_1901 = None + wait_tensor_430 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_139); reduce_scatter_tensor_139 = None + convert_element_type_1902 = torch.ops.prims.convert_element_type.default(add_236, torch.float32); add_236 = None + convert_element_type_1904 = torch.ops.prims.convert_element_type.default(wait_tensor_150, torch.float32); wait_tensor_150 = None + mul_570 = torch.ops.aten.mul.Tensor(convert_element_type_1902, convert_element_type_1904); convert_element_type_1904 = None + mul_572 = torch.ops.aten.mul.Tensor(mul_132, mul_570) + sum_93 = torch.ops.aten.sum.dim_IntList(mul_572, [2], True); mul_572 = None + div_31 = torch.ops.aten.div.Tensor(mul_132, 4096) + mul_573 = torch.ops.aten.mul.Tensor(div_31, sum_93); div_31 = sum_93 = None + sub_47 = torch.ops.aten.sub.Tensor(mul_570, mul_573); mul_570 = mul_573 = None + mul_574 = torch.ops.aten.mul.Tensor(sub_47, rsqrt_33); sub_47 = rsqrt_33 = None + mul_575 = torch.ops.aten.mul.Tensor(convert_element_type_1902, mul_132); convert_element_type_1902 = mul_132 = None + sum_94 = torch.ops.aten.sum.dim_IntList(mul_575, [0, 1]); mul_575 = None + convert_element_type_1905 = torch.ops.prims.convert_element_type.default(mul_574, torch.bfloat16); mul_574 = None + add_237 = torch.ops.aten.add.Tensor(add_233, 
convert_element_type_1905); add_233 = convert_element_type_1905 = None + convert_element_type_default_34 = torch.ops.prims.convert_element_type.default(sum_94, torch.float32); sum_94 = None + reduce_scatter_tensor_140 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_34, 'avg', 256, '0'); convert_element_type_default_34 = None + wait_tensor_431 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_140); reduce_scatter_tensor_140 = None + view_1461 = torch.ops.aten.view.default(add_237, [16384, 4096]) + permute_849 = torch.ops.aten.permute.default(view_1461, [1, 0]) + mm_443 = torch.ops.aten.mm.default(permute_849, view_567); permute_849 = view_567 = None + permute_851 = torch.ops.aten.permute.default(permute_183, [1, 0]); permute_183 = None + mm_444 = torch.ops.aten.mm.default(view_1461, permute_851); view_1461 = permute_851 = None + view_1462 = torch.ops.aten.view.default(mm_444, [2, 8192, 4096]); mm_444 = None + convert_element_type_1912 = torch.ops.prims.convert_element_type.default(mm_443, torch.float32); mm_443 = None + reduce_scatter_tensor_141 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1912, 'avg', 256, '0'); convert_element_type_1912 = None + wait_tensor_432 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_141); reduce_scatter_tensor_141 = None + view_1463 = torch.ops.aten.view.default(view_1462, [2, 8192, 32, 128]); view_1462 = None + permute_853 = torch.ops.aten.permute.default(view_1463, [0, 2, 1, 3]); view_1463 = None + convert_element_type_529 = torch.ops.prims.convert_element_type.default(primals_148, torch.bfloat16); primals_148 = None + all_gather_into_tensor_145 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_529, 256, '0'); convert_element_type_529 = None + wait_tensor_145 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_145); all_gather_into_tensor_145 = None + 
convert_element_type_530 = torch.ops.prims.convert_element_type.default(add_63, torch.float32); add_63 = None + pow_33 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_530, 2) + mean_32 = torch.ops.aten.mean.dim(pow_33, [2], True); pow_33 = None + add_64 = torch.ops.aten.add.Scalar(mean_32, 1e-05); mean_32 = None + rsqrt_32 = torch.ops.aten.rsqrt.default(add_64); add_64 = None + mul_128 = torch.ops.aten.mul.Tensor(convert_element_type_530, rsqrt_32); convert_element_type_530 = None + mul_129 = torch.ops.aten.mul.Tensor(mul_128, wait_tensor_145) + convert_element_type_531 = torch.ops.prims.convert_element_type.default(mul_129, torch.bfloat16); mul_129 = None + view_547 = torch.ops.aten.view.default(convert_element_type_531, [16384, 4096]); convert_element_type_531 = None + view_548 = torch.ops.aten.view.default(mm_112, [2, 8192, 4096]); mm_112 = None + convert_element_type_535 = torch.ops.prims.convert_element_type.default(primals_150, torch.bfloat16); primals_150 = None + all_gather_into_tensor_147 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_535, 256, '0'); convert_element_type_535 = None + wait_tensor_147 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_147); all_gather_into_tensor_147 = None + permute_177 = torch.ops.aten.permute.default(wait_tensor_147, [1, 0]); wait_tensor_147 = None + mm_113 = torch.ops.aten.mm.default(view_547, permute_177) + view_551 = torch.ops.aten.view.default(mm_113, [2, 8192, 1024]); mm_113 = None + view_554 = torch.ops.aten.view.default(mm_114, [2, 8192, 1024]); mm_114 = None + view_555 = torch.ops.aten.view.default(view_548, [2, 8192, -1, 128]); view_548 = None + view_556 = torch.ops.aten.view.default(view_551, [2, 8192, -1, 128]); view_551 = None + view_557 = torch.ops.aten.view.default(view_554, [2, 8192, -1, 128]); view_554 = None + convert_element_type_541 = torch.ops.prims.convert_element_type.default(view_555, torch.float32); view_555 = None + view_558 = 
torch.ops.aten.view.default(convert_element_type_541, [2, 8192, 32, -1, 2]); convert_element_type_541 = None + view_as_complex_32 = torch.ops.aten.view_as_complex.default(view_558); view_558 = None + convert_element_type_542 = torch.ops.prims.convert_element_type.default(view_556, torch.float32); view_556 = None + view_559 = torch.ops.aten.view.default(convert_element_type_542, [2, 8192, 8, -1, 2]); convert_element_type_542 = None + view_as_complex_33 = torch.ops.aten.view_as_complex.default(view_559); view_559 = None + mul_130 = torch.ops.aten.mul.Tensor(view_as_complex_32, view_16); view_as_complex_32 = None + view_as_real_32 = torch.ops.aten.view_as_real.default(mul_130); mul_130 = None + view_561 = torch.ops.aten.view.default(view_as_real_32, [2, 8192, 32, 128]); view_as_real_32 = None + mul_131 = torch.ops.aten.mul.Tensor(view_as_complex_33, view_16); view_as_complex_33 = None + view_as_real_33 = torch.ops.aten.view_as_real.default(mul_131); mul_131 = None + view_562 = torch.ops.aten.view.default(view_as_real_33, [2, 8192, 8, 128]); view_as_real_33 = None + convert_element_type_543 = torch.ops.prims.convert_element_type.default(view_561, torch.bfloat16); view_561 = None + convert_element_type_544 = torch.ops.prims.convert_element_type.default(view_562, torch.bfloat16); view_562 = None + unsqueeze_32 = torch.ops.aten.unsqueeze.default(convert_element_type_544, 3); convert_element_type_544 = None + expand_32 = torch.ops.aten.expand.default(unsqueeze_32, [2, 8192, 8, 4, 128]); unsqueeze_32 = None + clone_32 = torch.ops.aten.clone.default(expand_32, memory_format = torch.contiguous_format); expand_32 = None + view_563 = torch.ops.aten.view.default(clone_32, [2, 8192, 32, 128]); clone_32 = None + unsqueeze_33 = torch.ops.aten.unsqueeze.default(view_557, 3); view_557 = None + expand_33 = torch.ops.aten.expand.default(unsqueeze_33, [2, 8192, 8, 4, 128]); unsqueeze_33 = None + clone_33 = torch.ops.aten.clone.default(expand_33, memory_format = torch.contiguous_format); 
expand_33 = None + view_564 = torch.ops.aten.view.default(clone_33, [2, 8192, 32, 128]); clone_33 = None + permute_179 = torch.ops.aten.permute.default(convert_element_type_543, [0, 2, 1, 3]); convert_element_type_543 = None + permute_180 = torch.ops.aten.permute.default(view_563, [0, 2, 1, 3]); view_563 = None + permute_181 = torch.ops.aten.permute.default(view_564, [0, 2, 1, 3]); view_564 = None + _scaled_dot_product_cudnn_attention_backward_15 = torch.ops.aten._scaled_dot_product_cudnn_attention_backward.default(permute_853, permute_179, permute_180, permute_181, getitem_144, getitem_145, getitem_150, getitem_151, None, None, None, 8192, 8192, 0.0, True); permute_853 = permute_179 = permute_180 = permute_181 = getitem_144 = getitem_145 = getitem_150 = getitem_151 = None + getitem_333 = _scaled_dot_product_cudnn_attention_backward_15[0] + getitem_334 = _scaled_dot_product_cudnn_attention_backward_15[1] + getitem_335 = _scaled_dot_product_cudnn_attention_backward_15[2]; _scaled_dot_product_cudnn_attention_backward_15 = None + permute_854 = torch.ops.aten.permute.default(getitem_335, [0, 2, 1, 3]); getitem_335 = None + permute_855 = torch.ops.aten.permute.default(getitem_334, [0, 2, 1, 3]); getitem_334 = None + permute_856 = torch.ops.aten.permute.default(getitem_333, [0, 2, 1, 3]); getitem_333 = None + view_1464 = torch.ops.aten.view.default(permute_854, [2, 8192, 8, 4, 128]); permute_854 = None + sum_95 = torch.ops.aten.sum.dim_IntList(view_1464, [3], True); view_1464 = None + squeeze_30 = torch.ops.aten.squeeze.dim(sum_95, 3); sum_95 = None + view_1465 = torch.ops.aten.view.default(permute_855, [2, 8192, 8, 4, 128]); permute_855 = None + sum_96 = torch.ops.aten.sum.dim_IntList(view_1465, [3], True); view_1465 = None + squeeze_31 = torch.ops.aten.squeeze.dim(sum_96, 3); sum_96 = None + convert_element_type_1913 = torch.ops.prims.convert_element_type.default(squeeze_31, torch.float32); squeeze_31 = None + convert_element_type_1914 = 
torch.ops.prims.convert_element_type.default(permute_856, torch.float32); permute_856 = None + view_1466 = torch.ops.aten.view.default(convert_element_type_1913, [2, 8192, 8, 64, 2]); convert_element_type_1913 = None + view_as_complex_94 = torch.ops.aten.view_as_complex.default(view_1466); view_1466 = None + mul_576 = torch.ops.aten.mul.Tensor(view_as_complex_94, _conj); view_as_complex_94 = None + view_1467 = torch.ops.aten.view.default(convert_element_type_1914, [2, 8192, 32, 64, 2]); convert_element_type_1914 = None + view_as_complex_95 = torch.ops.aten.view_as_complex.default(view_1467); view_1467 = None + mul_577 = torch.ops.aten.mul.Tensor(view_as_complex_95, _conj); view_as_complex_95 = None + view_as_real_94 = torch.ops.aten.view_as_real.default(mul_576); mul_576 = None + view_1468 = torch.ops.aten.view.default(view_as_real_94, [2, 8192, 8, 128]); view_as_real_94 = None + convert_element_type_1915 = torch.ops.prims.convert_element_type.default(view_1468, torch.bfloat16); view_1468 = None + view_as_real_95 = torch.ops.aten.view_as_real.default(mul_577); mul_577 = None + view_1469 = torch.ops.aten.view.default(view_as_real_95, [2, 8192, 32, 128]); view_as_real_95 = None + convert_element_type_1916 = torch.ops.prims.convert_element_type.default(view_1469, torch.bfloat16); view_1469 = None + view_1470 = torch.ops.aten.view.default(squeeze_30, [2, 8192, 1024]); squeeze_30 = None + view_1471 = torch.ops.aten.view.default(convert_element_type_1915, [2, 8192, 1024]); convert_element_type_1915 = None + view_1472 = torch.ops.aten.view.default(convert_element_type_1916, [2, 8192, 4096]); convert_element_type_1916 = None + view_1473 = torch.ops.aten.view.default(view_1470, [16384, 1024]); view_1470 = None + permute_857 = torch.ops.aten.permute.default(view_1473, [1, 0]) + mm_445 = torch.ops.aten.mm.default(permute_857, view_547); permute_857 = None + convert_element_type_538 = torch.ops.prims.convert_element_type.default(primals_151, torch.bfloat16); primals_151 = None 
+ all_gather_into_tensor_148 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_538, 256, '0'); convert_element_type_538 = None + wait_tensor_148 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_148); all_gather_into_tensor_148 = None + permute_178 = torch.ops.aten.permute.default(wait_tensor_148, [1, 0]); wait_tensor_148 = None + permute_859 = torch.ops.aten.permute.default(permute_178, [1, 0]); permute_178 = None + mm_446 = torch.ops.aten.mm.default(view_1473, permute_859); view_1473 = permute_859 = None + view_1474 = torch.ops.aten.view.default(mm_446, [2, 8192, 4096]); mm_446 = None + convert_element_type_1921 = torch.ops.prims.convert_element_type.default(mm_445, torch.float32); mm_445 = None + reduce_scatter_tensor_142 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1921, 'avg', 256, '0'); convert_element_type_1921 = None + wait_tensor_433 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_142); reduce_scatter_tensor_142 = None + view_1475 = torch.ops.aten.view.default(view_1471, [16384, 1024]); view_1471 = None + permute_861 = torch.ops.aten.permute.default(view_1475, [1, 0]) + mm_447 = torch.ops.aten.mm.default(permute_861, view_547); permute_861 = None + permute_863 = torch.ops.aten.permute.default(permute_177, [1, 0]); permute_177 = None + mm_448 = torch.ops.aten.mm.default(view_1475, permute_863); view_1475 = permute_863 = None + view_1476 = torch.ops.aten.view.default(mm_448, [2, 8192, 4096]); mm_448 = None + add_238 = torch.ops.aten.add.Tensor(view_1474, view_1476); view_1474 = view_1476 = None + convert_element_type_1926 = torch.ops.prims.convert_element_type.default(mm_447, torch.float32); mm_447 = None + reduce_scatter_tensor_143 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1926, 'avg', 256, '0'); convert_element_type_1926 = None + wait_tensor_434 = 
torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_143); reduce_scatter_tensor_143 = None + view_1477 = torch.ops.aten.view.default(view_1472, [16384, 4096]); view_1472 = None + permute_865 = torch.ops.aten.permute.default(view_1477, [1, 0]) + mm_449 = torch.ops.aten.mm.default(permute_865, view_547); permute_865 = view_547 = None + convert_element_type_532 = torch.ops.prims.convert_element_type.default(primals_149, torch.bfloat16); primals_149 = None + all_gather_into_tensor_146 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_532, 256, '0'); convert_element_type_532 = None + wait_tensor_146 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_146); all_gather_into_tensor_146 = None + permute_176 = torch.ops.aten.permute.default(wait_tensor_146, [1, 0]); wait_tensor_146 = None + permute_867 = torch.ops.aten.permute.default(permute_176, [1, 0]); permute_176 = None + mm_450 = torch.ops.aten.mm.default(view_1477, permute_867); view_1477 = permute_867 = None + view_1478 = torch.ops.aten.view.default(mm_450, [2, 8192, 4096]); mm_450 = None + add_239 = torch.ops.aten.add.Tensor(add_238, view_1478); add_238 = view_1478 = None + convert_element_type_1931 = torch.ops.prims.convert_element_type.default(mm_449, torch.float32); mm_449 = None + reduce_scatter_tensor_144 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1931, 'avg', 256, '0'); convert_element_type_1931 = None + wait_tensor_435 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_144); reduce_scatter_tensor_144 = None + convert_element_type_1932 = torch.ops.prims.convert_element_type.default(add_239, torch.float32); add_239 = None + convert_element_type_1934 = torch.ops.prims.convert_element_type.default(wait_tensor_145, torch.float32); wait_tensor_145 = None + mul_578 = torch.ops.aten.mul.Tensor(convert_element_type_1932, convert_element_type_1934); convert_element_type_1934 = None + mul_580 = 
torch.ops.aten.mul.Tensor(mul_128, mul_578) + sum_97 = torch.ops.aten.sum.dim_IntList(mul_580, [2], True); mul_580 = None + div_32 = torch.ops.aten.div.Tensor(mul_128, 4096) + mul_581 = torch.ops.aten.mul.Tensor(div_32, sum_97); div_32 = sum_97 = None + sub_48 = torch.ops.aten.sub.Tensor(mul_578, mul_581); mul_578 = mul_581 = None + mul_582 = torch.ops.aten.mul.Tensor(sub_48, rsqrt_32); sub_48 = rsqrt_32 = None + mul_583 = torch.ops.aten.mul.Tensor(convert_element_type_1932, mul_128); convert_element_type_1932 = mul_128 = None + sum_98 = torch.ops.aten.sum.dim_IntList(mul_583, [0, 1]); mul_583 = None + convert_element_type_1935 = torch.ops.prims.convert_element_type.default(mul_582, torch.bfloat16); mul_582 = None + add_240 = torch.ops.aten.add.Tensor(add_237, convert_element_type_1935); add_237 = convert_element_type_1935 = None + convert_element_type_default_33 = torch.ops.prims.convert_element_type.default(sum_98, torch.float32); sum_98 = None + reduce_scatter_tensor_145 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_33, 'avg', 256, '0'); convert_element_type_default_33 = None + wait_tensor_436 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_145); reduce_scatter_tensor_145 = None + view_1479 = torch.ops.aten.view.default(add_240, [16384, 4096]) + permute_869 = torch.ops.aten.permute.default(view_1479, [1, 0]) + permute_171 = torch.ops.aten.permute.default(getitem_135, [0, 2, 1, 3]) + view_531 = torch.ops.aten.view.default(permute_171, [2, 8192, -1]); permute_171 = None + convert_element_type_512 = torch.ops.prims.convert_element_type.default(primals_143, torch.bfloat16); primals_143 = None + all_gather_into_tensor_140 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_512, 256, '0'); convert_element_type_512 = None + wait_tensor_140 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_140); all_gather_into_tensor_140 = None + permute_172 = 
torch.ops.aten.permute.default(wait_tensor_140, [1, 0]); wait_tensor_140 = None + view_533 = torch.ops.aten.view.default(view_531, [16384, 4096]); view_531 = None + mm_108 = torch.ops.aten.mm.default(view_533, permute_172) + view_534 = torch.ops.aten.view.default(mm_108, [2, 8192, 4096]); mm_108 = None + add_61 = torch.ops.aten.add.Tensor(add_59, view_534); view_534 = None + convert_element_type_515 = torch.ops.prims.convert_element_type.default(primals_144, torch.bfloat16); primals_144 = None + all_gather_into_tensor_141 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_515, 256, '0'); convert_element_type_515 = None + wait_tensor_141 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_141); all_gather_into_tensor_141 = None + convert_element_type_516 = torch.ops.prims.convert_element_type.default(add_61, torch.float32); add_61 = None + pow_32 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_516, 2) + mean_31 = torch.ops.aten.mean.dim(pow_32, [2], True); pow_32 = None + add_62 = torch.ops.aten.add.Scalar(mean_31, 1e-05); mean_31 = None + rsqrt_31 = torch.ops.aten.rsqrt.default(add_62); add_62 = None + mul_124 = torch.ops.aten.mul.Tensor(convert_element_type_516, rsqrt_31); convert_element_type_516 = None + mul_125 = torch.ops.aten.mul.Tensor(mul_124, wait_tensor_141) + convert_element_type_517 = torch.ops.prims.convert_element_type.default(mul_125, torch.bfloat16); mul_125 = None + view_537 = torch.ops.aten.view.default(convert_element_type_517, [16384, 4096]); convert_element_type_517 = None + view_538 = torch.ops.aten.view.default(mm_109, [2, 8192, 14336]); mm_109 = None + convert_element_type_521 = torch.ops.prims.convert_element_type.default(view_538, torch.float32); view_538 = None + sigmoid_15 = torch.ops.aten.sigmoid.default(convert_element_type_521) + mul_126 = torch.ops.aten.mul.Tensor(convert_element_type_521, sigmoid_15); sigmoid_15 = None + convert_element_type_522 = 
torch.ops.prims.convert_element_type.default(mul_126, torch.bfloat16); mul_126 = None + convert_element_type_523 = torch.ops.prims.convert_element_type.default(primals_146, torch.bfloat16); primals_146 = None + all_gather_into_tensor_143 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_523, 256, '0'); convert_element_type_523 = None + wait_tensor_143 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_143); all_gather_into_tensor_143 = None + permute_174 = torch.ops.aten.permute.default(wait_tensor_143, [1, 0]); wait_tensor_143 = None + mm_110 = torch.ops.aten.mm.default(view_537, permute_174) + view_541 = torch.ops.aten.view.default(mm_110, [2, 8192, 14336]); mm_110 = None + mul_127 = torch.ops.aten.mul.Tensor(convert_element_type_522, view_541) + view_543 = torch.ops.aten.view.default(mul_127, [16384, 14336]); mul_127 = None + mm_451 = torch.ops.aten.mm.default(permute_869, view_543); permute_869 = view_543 = None + convert_element_type_526 = torch.ops.prims.convert_element_type.default(primals_147, torch.bfloat16); primals_147 = None + all_gather_into_tensor_144 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_526, 256, '0'); convert_element_type_526 = None + wait_tensor_144 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_144); all_gather_into_tensor_144 = None + permute_175 = torch.ops.aten.permute.default(wait_tensor_144, [1, 0]); wait_tensor_144 = None + permute_871 = torch.ops.aten.permute.default(permute_175, [1, 0]); permute_175 = None + mm_452 = torch.ops.aten.mm.default(view_1479, permute_871); view_1479 = permute_871 = None + view_1480 = torch.ops.aten.view.default(mm_452, [2, 8192, 14336]); mm_452 = None + convert_element_type_1942 = torch.ops.prims.convert_element_type.default(mm_451, torch.float32); mm_451 = None + reduce_scatter_tensor_146 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1942, 'avg', 256, '0'); 
convert_element_type_1942 = None + wait_tensor_437 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_146); reduce_scatter_tensor_146 = None + mul_584 = torch.ops.aten.mul.Tensor(view_1480, convert_element_type_522); convert_element_type_522 = None + mul_585 = torch.ops.aten.mul.Tensor(view_1480, view_541); view_1480 = view_541 = None + view_1481 = torch.ops.aten.view.default(mul_584, [16384, 14336]); mul_584 = None + permute_873 = torch.ops.aten.permute.default(view_1481, [1, 0]) + mm_453 = torch.ops.aten.mm.default(permute_873, view_537); permute_873 = None + permute_875 = torch.ops.aten.permute.default(permute_174, [1, 0]); permute_174 = None + mm_454 = torch.ops.aten.mm.default(view_1481, permute_875); view_1481 = permute_875 = None + view_1482 = torch.ops.aten.view.default(mm_454, [2, 8192, 4096]); mm_454 = None + convert_element_type_1947 = torch.ops.prims.convert_element_type.default(mm_453, torch.float32); mm_453 = None + reduce_scatter_tensor_147 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1947, 'avg', 256, '0'); convert_element_type_1947 = None + wait_tensor_438 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_147); reduce_scatter_tensor_147 = None + convert_element_type_1948 = torch.ops.prims.convert_element_type.default(mul_585, torch.float32); mul_585 = None + neg_16 = torch.ops.aten.neg.default(convert_element_type_521) + exp_16 = torch.ops.aten.exp.default(neg_16); neg_16 = None + add_241 = torch.ops.aten.add.Tensor(exp_16, 1); exp_16 = None + reciprocal_16 = torch.ops.aten.reciprocal.default(add_241); add_241 = None + mul_586 = torch.ops.aten.mul.Tensor(reciprocal_16, 1); reciprocal_16 = None + mul_587 = torch.ops.aten.mul.Tensor(convert_element_type_1948, mul_586); convert_element_type_1948 = None + sub_49 = torch.ops.aten.sub.Tensor(1, mul_586); mul_586 = None + mul_588 = torch.ops.aten.mul.Tensor(convert_element_type_521, sub_49); convert_element_type_521 = sub_49 = 
None + add_242 = torch.ops.aten.add.Tensor(mul_588, 1); mul_588 = None + mul_589 = torch.ops.aten.mul.Tensor(mul_587, add_242); mul_587 = add_242 = None + convert_element_type_1950 = torch.ops.prims.convert_element_type.default(mul_589, torch.bfloat16); mul_589 = None + view_1483 = torch.ops.aten.view.default(convert_element_type_1950, [16384, 14336]); convert_element_type_1950 = None + permute_877 = torch.ops.aten.permute.default(view_1483, [1, 0]) + mm_455 = torch.ops.aten.mm.default(permute_877, view_537); permute_877 = view_537 = None + convert_element_type_518 = torch.ops.prims.convert_element_type.default(primals_145, torch.bfloat16); primals_145 = None + all_gather_into_tensor_142 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_518, 256, '0'); convert_element_type_518 = None + wait_tensor_142 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_142); all_gather_into_tensor_142 = None + permute_173 = torch.ops.aten.permute.default(wait_tensor_142, [1, 0]); wait_tensor_142 = None + permute_879 = torch.ops.aten.permute.default(permute_173, [1, 0]); permute_173 = None + mm_456 = torch.ops.aten.mm.default(view_1483, permute_879); view_1483 = permute_879 = None + view_1484 = torch.ops.aten.view.default(mm_456, [2, 8192, 4096]); mm_456 = None + add_243 = torch.ops.aten.add.Tensor(view_1482, view_1484); view_1482 = view_1484 = None + convert_element_type_1955 = torch.ops.prims.convert_element_type.default(mm_455, torch.float32); mm_455 = None + reduce_scatter_tensor_148 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1955, 'avg', 256, '0'); convert_element_type_1955 = None + wait_tensor_439 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_148); reduce_scatter_tensor_148 = None + convert_element_type_1956 = torch.ops.prims.convert_element_type.default(add_243, torch.float32); add_243 = None + convert_element_type_1958 = 
torch.ops.prims.convert_element_type.default(wait_tensor_141, torch.float32); wait_tensor_141 = None + mul_590 = torch.ops.aten.mul.Tensor(convert_element_type_1956, convert_element_type_1958); convert_element_type_1958 = None + mul_592 = torch.ops.aten.mul.Tensor(mul_124, mul_590) + sum_99 = torch.ops.aten.sum.dim_IntList(mul_592, [2], True); mul_592 = None + div_33 = torch.ops.aten.div.Tensor(mul_124, 4096) + mul_593 = torch.ops.aten.mul.Tensor(div_33, sum_99); div_33 = sum_99 = None + sub_50 = torch.ops.aten.sub.Tensor(mul_590, mul_593); mul_590 = mul_593 = None + mul_594 = torch.ops.aten.mul.Tensor(sub_50, rsqrt_31); sub_50 = rsqrt_31 = None + mul_595 = torch.ops.aten.mul.Tensor(convert_element_type_1956, mul_124); convert_element_type_1956 = mul_124 = None + sum_100 = torch.ops.aten.sum.dim_IntList(mul_595, [0, 1]); mul_595 = None + convert_element_type_1959 = torch.ops.prims.convert_element_type.default(mul_594, torch.bfloat16); mul_594 = None + add_244 = torch.ops.aten.add.Tensor(add_240, convert_element_type_1959); add_240 = convert_element_type_1959 = None + convert_element_type_default_32 = torch.ops.prims.convert_element_type.default(sum_100, torch.float32); sum_100 = None + reduce_scatter_tensor_149 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_32, 'avg', 256, '0'); convert_element_type_default_32 = None + wait_tensor_440 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_149); reduce_scatter_tensor_149 = None + view_1485 = torch.ops.aten.view.default(add_244, [16384, 4096]) + permute_881 = torch.ops.aten.permute.default(view_1485, [1, 0]) + mm_457 = torch.ops.aten.mm.default(permute_881, view_533); permute_881 = view_533 = None + permute_883 = torch.ops.aten.permute.default(permute_172, [1, 0]); permute_172 = None + mm_458 = torch.ops.aten.mm.default(view_1485, permute_883); view_1485 = permute_883 = None + view_1486 = torch.ops.aten.view.default(mm_458, [2, 8192, 4096]); mm_458 = None + 
convert_element_type_1966 = torch.ops.prims.convert_element_type.default(mm_457, torch.float32); mm_457 = None + reduce_scatter_tensor_150 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1966, 'avg', 256, '0'); convert_element_type_1966 = None + wait_tensor_441 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_150); reduce_scatter_tensor_150 = None + view_1487 = torch.ops.aten.view.default(view_1486, [2, 8192, 32, 128]); view_1486 = None + permute_885 = torch.ops.aten.permute.default(view_1487, [0, 2, 1, 3]); view_1487 = None + convert_element_type_496 = torch.ops.prims.convert_element_type.default(primals_139, torch.bfloat16); primals_139 = None + all_gather_into_tensor_136 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_496, 256, '0'); convert_element_type_496 = None + wait_tensor_136 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_136); all_gather_into_tensor_136 = None + convert_element_type_497 = torch.ops.prims.convert_element_type.default(add_59, torch.float32); add_59 = None + pow_31 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_497, 2) + mean_30 = torch.ops.aten.mean.dim(pow_31, [2], True); pow_31 = None + add_60 = torch.ops.aten.add.Scalar(mean_30, 1e-05); mean_30 = None + rsqrt_30 = torch.ops.aten.rsqrt.default(add_60); add_60 = None + mul_120 = torch.ops.aten.mul.Tensor(convert_element_type_497, rsqrt_30); convert_element_type_497 = None + mul_121 = torch.ops.aten.mul.Tensor(mul_120, wait_tensor_136) + convert_element_type_498 = torch.ops.prims.convert_element_type.default(mul_121, torch.bfloat16); mul_121 = None + view_513 = torch.ops.aten.view.default(convert_element_type_498, [16384, 4096]); convert_element_type_498 = None + view_514 = torch.ops.aten.view.default(mm_105, [2, 8192, 4096]); mm_105 = None + convert_element_type_502 = torch.ops.prims.convert_element_type.default(primals_141, torch.bfloat16); primals_141 = None + 
all_gather_into_tensor_138 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_502, 256, '0'); convert_element_type_502 = None + wait_tensor_138 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_138); all_gather_into_tensor_138 = None + permute_166 = torch.ops.aten.permute.default(wait_tensor_138, [1, 0]); wait_tensor_138 = None + mm_106 = torch.ops.aten.mm.default(view_513, permute_166) + view_517 = torch.ops.aten.view.default(mm_106, [2, 8192, 1024]); mm_106 = None + view_520 = torch.ops.aten.view.default(mm_107, [2, 8192, 1024]); mm_107 = None + view_521 = torch.ops.aten.view.default(view_514, [2, 8192, -1, 128]); view_514 = None + view_522 = torch.ops.aten.view.default(view_517, [2, 8192, -1, 128]); view_517 = None + view_523 = torch.ops.aten.view.default(view_520, [2, 8192, -1, 128]); view_520 = None + convert_element_type_508 = torch.ops.prims.convert_element_type.default(view_521, torch.float32); view_521 = None + view_524 = torch.ops.aten.view.default(convert_element_type_508, [2, 8192, 32, -1, 2]); convert_element_type_508 = None + view_as_complex_30 = torch.ops.aten.view_as_complex.default(view_524); view_524 = None + convert_element_type_509 = torch.ops.prims.convert_element_type.default(view_522, torch.float32); view_522 = None + view_525 = torch.ops.aten.view.default(convert_element_type_509, [2, 8192, 8, -1, 2]); convert_element_type_509 = None + view_as_complex_31 = torch.ops.aten.view_as_complex.default(view_525); view_525 = None + mul_122 = torch.ops.aten.mul.Tensor(view_as_complex_30, view_16); view_as_complex_30 = None + view_as_real_30 = torch.ops.aten.view_as_real.default(mul_122); mul_122 = None + view_527 = torch.ops.aten.view.default(view_as_real_30, [2, 8192, 32, 128]); view_as_real_30 = None + mul_123 = torch.ops.aten.mul.Tensor(view_as_complex_31, view_16); view_as_complex_31 = None + view_as_real_31 = torch.ops.aten.view_as_real.default(mul_123); mul_123 = None + view_528 = 
torch.ops.aten.view.default(view_as_real_31, [2, 8192, 8, 128]); view_as_real_31 = None + convert_element_type_510 = torch.ops.prims.convert_element_type.default(view_527, torch.bfloat16); view_527 = None + convert_element_type_511 = torch.ops.prims.convert_element_type.default(view_528, torch.bfloat16); view_528 = None + unsqueeze_30 = torch.ops.aten.unsqueeze.default(convert_element_type_511, 3); convert_element_type_511 = None + expand_30 = torch.ops.aten.expand.default(unsqueeze_30, [2, 8192, 8, 4, 128]); unsqueeze_30 = None + clone_30 = torch.ops.aten.clone.default(expand_30, memory_format = torch.contiguous_format); expand_30 = None + view_529 = torch.ops.aten.view.default(clone_30, [2, 8192, 32, 128]); clone_30 = None + unsqueeze_31 = torch.ops.aten.unsqueeze.default(view_523, 3); view_523 = None + expand_31 = torch.ops.aten.expand.default(unsqueeze_31, [2, 8192, 8, 4, 128]); unsqueeze_31 = None + clone_31 = torch.ops.aten.clone.default(expand_31, memory_format = torch.contiguous_format); expand_31 = None + view_530 = torch.ops.aten.view.default(clone_31, [2, 8192, 32, 128]); clone_31 = None + permute_168 = torch.ops.aten.permute.default(convert_element_type_510, [0, 2, 1, 3]); convert_element_type_510 = None + permute_169 = torch.ops.aten.permute.default(view_529, [0, 2, 1, 3]); view_529 = None + permute_170 = torch.ops.aten.permute.default(view_530, [0, 2, 1, 3]); view_530 = None + _scaled_dot_product_cudnn_attention_backward_16 = torch.ops.aten._scaled_dot_product_cudnn_attention_backward.default(permute_885, permute_168, permute_169, permute_170, getitem_135, getitem_136, getitem_141, getitem_142, None, None, None, 8192, 8192, 0.0, True); permute_885 = permute_168 = permute_169 = permute_170 = getitem_135 = getitem_136 = getitem_141 = getitem_142 = None + getitem_336 = _scaled_dot_product_cudnn_attention_backward_16[0] + getitem_337 = _scaled_dot_product_cudnn_attention_backward_16[1] + getitem_338 = _scaled_dot_product_cudnn_attention_backward_16[2]; 
_scaled_dot_product_cudnn_attention_backward_16 = None + permute_886 = torch.ops.aten.permute.default(getitem_338, [0, 2, 1, 3]); getitem_338 = None + permute_887 = torch.ops.aten.permute.default(getitem_337, [0, 2, 1, 3]); getitem_337 = None + permute_888 = torch.ops.aten.permute.default(getitem_336, [0, 2, 1, 3]); getitem_336 = None + view_1488 = torch.ops.aten.view.default(permute_886, [2, 8192, 8, 4, 128]); permute_886 = None + sum_101 = torch.ops.aten.sum.dim_IntList(view_1488, [3], True); view_1488 = None + squeeze_32 = torch.ops.aten.squeeze.dim(sum_101, 3); sum_101 = None + view_1489 = torch.ops.aten.view.default(permute_887, [2, 8192, 8, 4, 128]); permute_887 = None + sum_102 = torch.ops.aten.sum.dim_IntList(view_1489, [3], True); view_1489 = None + squeeze_33 = torch.ops.aten.squeeze.dim(sum_102, 3); sum_102 = None + convert_element_type_1967 = torch.ops.prims.convert_element_type.default(squeeze_33, torch.float32); squeeze_33 = None + convert_element_type_1968 = torch.ops.prims.convert_element_type.default(permute_888, torch.float32); permute_888 = None + view_1490 = torch.ops.aten.view.default(convert_element_type_1967, [2, 8192, 8, 64, 2]); convert_element_type_1967 = None + view_as_complex_96 = torch.ops.aten.view_as_complex.default(view_1490); view_1490 = None + mul_596 = torch.ops.aten.mul.Tensor(view_as_complex_96, _conj); view_as_complex_96 = None + view_1491 = torch.ops.aten.view.default(convert_element_type_1968, [2, 8192, 32, 64, 2]); convert_element_type_1968 = None + view_as_complex_97 = torch.ops.aten.view_as_complex.default(view_1491); view_1491 = None + mul_597 = torch.ops.aten.mul.Tensor(view_as_complex_97, _conj); view_as_complex_97 = None + view_as_real_96 = torch.ops.aten.view_as_real.default(mul_596); mul_596 = None + view_1492 = torch.ops.aten.view.default(view_as_real_96, [2, 8192, 8, 128]); view_as_real_96 = None + convert_element_type_1969 = torch.ops.prims.convert_element_type.default(view_1492, torch.bfloat16); view_1492 = None 
+ view_as_real_97 = torch.ops.aten.view_as_real.default(mul_597); mul_597 = None + view_1493 = torch.ops.aten.view.default(view_as_real_97, [2, 8192, 32, 128]); view_as_real_97 = None + convert_element_type_1970 = torch.ops.prims.convert_element_type.default(view_1493, torch.bfloat16); view_1493 = None + view_1494 = torch.ops.aten.view.default(squeeze_32, [2, 8192, 1024]); squeeze_32 = None + view_1495 = torch.ops.aten.view.default(convert_element_type_1969, [2, 8192, 1024]); convert_element_type_1969 = None + view_1496 = torch.ops.aten.view.default(convert_element_type_1970, [2, 8192, 4096]); convert_element_type_1970 = None + view_1497 = torch.ops.aten.view.default(view_1494, [16384, 1024]); view_1494 = None + permute_889 = torch.ops.aten.permute.default(view_1497, [1, 0]) + mm_459 = torch.ops.aten.mm.default(permute_889, view_513); permute_889 = None + convert_element_type_505 = torch.ops.prims.convert_element_type.default(primals_142, torch.bfloat16); primals_142 = None + all_gather_into_tensor_139 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_505, 256, '0'); convert_element_type_505 = None + wait_tensor_139 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_139); all_gather_into_tensor_139 = None + permute_167 = torch.ops.aten.permute.default(wait_tensor_139, [1, 0]); wait_tensor_139 = None + permute_891 = torch.ops.aten.permute.default(permute_167, [1, 0]); permute_167 = None + mm_460 = torch.ops.aten.mm.default(view_1497, permute_891); view_1497 = permute_891 = None + view_1498 = torch.ops.aten.view.default(mm_460, [2, 8192, 4096]); mm_460 = None + convert_element_type_1975 = torch.ops.prims.convert_element_type.default(mm_459, torch.float32); mm_459 = None + reduce_scatter_tensor_151 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1975, 'avg', 256, '0'); convert_element_type_1975 = None + wait_tensor_442 = 
torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_151); reduce_scatter_tensor_151 = None + view_1499 = torch.ops.aten.view.default(view_1495, [16384, 1024]); view_1495 = None + permute_893 = torch.ops.aten.permute.default(view_1499, [1, 0]) + mm_461 = torch.ops.aten.mm.default(permute_893, view_513); permute_893 = None + permute_895 = torch.ops.aten.permute.default(permute_166, [1, 0]); permute_166 = None + mm_462 = torch.ops.aten.mm.default(view_1499, permute_895); view_1499 = permute_895 = None + view_1500 = torch.ops.aten.view.default(mm_462, [2, 8192, 4096]); mm_462 = None + add_245 = torch.ops.aten.add.Tensor(view_1498, view_1500); view_1498 = view_1500 = None + convert_element_type_1980 = torch.ops.prims.convert_element_type.default(mm_461, torch.float32); mm_461 = None + reduce_scatter_tensor_152 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1980, 'avg', 256, '0'); convert_element_type_1980 = None + wait_tensor_443 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_152); reduce_scatter_tensor_152 = None + view_1501 = torch.ops.aten.view.default(view_1496, [16384, 4096]); view_1496 = None + permute_897 = torch.ops.aten.permute.default(view_1501, [1, 0]) + mm_463 = torch.ops.aten.mm.default(permute_897, view_513); permute_897 = view_513 = None + convert_element_type_499 = torch.ops.prims.convert_element_type.default(primals_140, torch.bfloat16); primals_140 = None + all_gather_into_tensor_137 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_499, 256, '0'); convert_element_type_499 = None + wait_tensor_137 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_137); all_gather_into_tensor_137 = None + permute_165 = torch.ops.aten.permute.default(wait_tensor_137, [1, 0]); wait_tensor_137 = None + permute_899 = torch.ops.aten.permute.default(permute_165, [1, 0]); permute_165 = None + mm_464 = torch.ops.aten.mm.default(view_1501, permute_899); 
view_1501 = permute_899 = None + view_1502 = torch.ops.aten.view.default(mm_464, [2, 8192, 4096]); mm_464 = None + add_246 = torch.ops.aten.add.Tensor(add_245, view_1502); add_245 = view_1502 = None + convert_element_type_1985 = torch.ops.prims.convert_element_type.default(mm_463, torch.float32); mm_463 = None + reduce_scatter_tensor_153 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1985, 'avg', 256, '0'); convert_element_type_1985 = None + wait_tensor_444 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_153); reduce_scatter_tensor_153 = None + convert_element_type_1986 = torch.ops.prims.convert_element_type.default(add_246, torch.float32); add_246 = None + convert_element_type_1988 = torch.ops.prims.convert_element_type.default(wait_tensor_136, torch.float32); wait_tensor_136 = None + mul_598 = torch.ops.aten.mul.Tensor(convert_element_type_1986, convert_element_type_1988); convert_element_type_1988 = None + mul_600 = torch.ops.aten.mul.Tensor(mul_120, mul_598) + sum_103 = torch.ops.aten.sum.dim_IntList(mul_600, [2], True); mul_600 = None + div_34 = torch.ops.aten.div.Tensor(mul_120, 4096) + mul_601 = torch.ops.aten.mul.Tensor(div_34, sum_103); div_34 = sum_103 = None + sub_51 = torch.ops.aten.sub.Tensor(mul_598, mul_601); mul_598 = mul_601 = None + mul_602 = torch.ops.aten.mul.Tensor(sub_51, rsqrt_30); sub_51 = rsqrt_30 = None + mul_603 = torch.ops.aten.mul.Tensor(convert_element_type_1986, mul_120); convert_element_type_1986 = mul_120 = None + sum_104 = torch.ops.aten.sum.dim_IntList(mul_603, [0, 1]); mul_603 = None + convert_element_type_1989 = torch.ops.prims.convert_element_type.default(mul_602, torch.bfloat16); mul_602 = None + add_247 = torch.ops.aten.add.Tensor(add_244, convert_element_type_1989); add_244 = convert_element_type_1989 = None + convert_element_type_default_31 = torch.ops.prims.convert_element_type.default(sum_104, torch.float32); sum_104 = None + reduce_scatter_tensor_154 = 
torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_31, 'avg', 256, '0'); convert_element_type_default_31 = None + wait_tensor_445 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_154); reduce_scatter_tensor_154 = None + view_1503 = torch.ops.aten.view.default(add_247, [16384, 4096]) + permute_901 = torch.ops.aten.permute.default(view_1503, [1, 0]) + permute_160 = torch.ops.aten.permute.default(getitem_126, [0, 2, 1, 3]) + view_497 = torch.ops.aten.view.default(permute_160, [2, 8192, -1]); permute_160 = None + convert_element_type_479 = torch.ops.prims.convert_element_type.default(primals_134, torch.bfloat16); primals_134 = None + all_gather_into_tensor_131 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_479, 256, '0'); convert_element_type_479 = None + wait_tensor_131 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_131); all_gather_into_tensor_131 = None + permute_161 = torch.ops.aten.permute.default(wait_tensor_131, [1, 0]); wait_tensor_131 = None + view_499 = torch.ops.aten.view.default(view_497, [16384, 4096]); view_497 = None + mm_101 = torch.ops.aten.mm.default(view_499, permute_161) + view_500 = torch.ops.aten.view.default(mm_101, [2, 8192, 4096]); mm_101 = None + add_57 = torch.ops.aten.add.Tensor(add_55, view_500); view_500 = None + convert_element_type_482 = torch.ops.prims.convert_element_type.default(primals_135, torch.bfloat16); primals_135 = None + all_gather_into_tensor_132 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_482, 256, '0'); convert_element_type_482 = None + wait_tensor_132 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_132); all_gather_into_tensor_132 = None + convert_element_type_483 = torch.ops.prims.convert_element_type.default(add_57, torch.float32); add_57 = None + pow_30 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_483, 2) + mean_29 = 
torch.ops.aten.mean.dim(pow_30, [2], True); pow_30 = None + add_58 = torch.ops.aten.add.Scalar(mean_29, 1e-05); mean_29 = None + rsqrt_29 = torch.ops.aten.rsqrt.default(add_58); add_58 = None + mul_116 = torch.ops.aten.mul.Tensor(convert_element_type_483, rsqrt_29); convert_element_type_483 = None + mul_117 = torch.ops.aten.mul.Tensor(mul_116, wait_tensor_132) + convert_element_type_484 = torch.ops.prims.convert_element_type.default(mul_117, torch.bfloat16); mul_117 = None + view_503 = torch.ops.aten.view.default(convert_element_type_484, [16384, 4096]); convert_element_type_484 = None + view_504 = torch.ops.aten.view.default(mm_102, [2, 8192, 14336]); mm_102 = None + convert_element_type_488 = torch.ops.prims.convert_element_type.default(view_504, torch.float32); view_504 = None + sigmoid_14 = torch.ops.aten.sigmoid.default(convert_element_type_488) + mul_118 = torch.ops.aten.mul.Tensor(convert_element_type_488, sigmoid_14); sigmoid_14 = None + convert_element_type_489 = torch.ops.prims.convert_element_type.default(mul_118, torch.bfloat16); mul_118 = None + convert_element_type_490 = torch.ops.prims.convert_element_type.default(primals_137, torch.bfloat16); primals_137 = None + all_gather_into_tensor_134 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_490, 256, '0'); convert_element_type_490 = None + wait_tensor_134 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_134); all_gather_into_tensor_134 = None + permute_163 = torch.ops.aten.permute.default(wait_tensor_134, [1, 0]); wait_tensor_134 = None + mm_103 = torch.ops.aten.mm.default(view_503, permute_163) + view_507 = torch.ops.aten.view.default(mm_103, [2, 8192, 14336]); mm_103 = None + mul_119 = torch.ops.aten.mul.Tensor(convert_element_type_489, view_507) + view_509 = torch.ops.aten.view.default(mul_119, [16384, 14336]); mul_119 = None + mm_465 = torch.ops.aten.mm.default(permute_901, view_509); permute_901 = view_509 = None + convert_element_type_493 = 
torch.ops.prims.convert_element_type.default(primals_138, torch.bfloat16); primals_138 = None + all_gather_into_tensor_135 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_493, 256, '0'); convert_element_type_493 = None + wait_tensor_135 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_135); all_gather_into_tensor_135 = None + permute_164 = torch.ops.aten.permute.default(wait_tensor_135, [1, 0]); wait_tensor_135 = None + permute_903 = torch.ops.aten.permute.default(permute_164, [1, 0]); permute_164 = None + mm_466 = torch.ops.aten.mm.default(view_1503, permute_903); view_1503 = permute_903 = None + view_1504 = torch.ops.aten.view.default(mm_466, [2, 8192, 14336]); mm_466 = None + convert_element_type_1996 = torch.ops.prims.convert_element_type.default(mm_465, torch.float32); mm_465 = None + reduce_scatter_tensor_155 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1996, 'avg', 256, '0'); convert_element_type_1996 = None + wait_tensor_446 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_155); reduce_scatter_tensor_155 = None + mul_604 = torch.ops.aten.mul.Tensor(view_1504, convert_element_type_489); convert_element_type_489 = None + mul_605 = torch.ops.aten.mul.Tensor(view_1504, view_507); view_1504 = view_507 = None + view_1505 = torch.ops.aten.view.default(mul_604, [16384, 14336]); mul_604 = None + permute_905 = torch.ops.aten.permute.default(view_1505, [1, 0]) + mm_467 = torch.ops.aten.mm.default(permute_905, view_503); permute_905 = None + permute_907 = torch.ops.aten.permute.default(permute_163, [1, 0]); permute_163 = None + mm_468 = torch.ops.aten.mm.default(view_1505, permute_907); view_1505 = permute_907 = None + view_1506 = torch.ops.aten.view.default(mm_468, [2, 8192, 4096]); mm_468 = None + convert_element_type_2001 = torch.ops.prims.convert_element_type.default(mm_467, torch.float32); mm_467 = None + reduce_scatter_tensor_156 = 
torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2001, 'avg', 256, '0'); convert_element_type_2001 = None + wait_tensor_447 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_156); reduce_scatter_tensor_156 = None + convert_element_type_2002 = torch.ops.prims.convert_element_type.default(mul_605, torch.float32); mul_605 = None + neg_17 = torch.ops.aten.neg.default(convert_element_type_488) + exp_17 = torch.ops.aten.exp.default(neg_17); neg_17 = None + add_248 = torch.ops.aten.add.Tensor(exp_17, 1); exp_17 = None + reciprocal_17 = torch.ops.aten.reciprocal.default(add_248); add_248 = None + mul_606 = torch.ops.aten.mul.Tensor(reciprocal_17, 1); reciprocal_17 = None + mul_607 = torch.ops.aten.mul.Tensor(convert_element_type_2002, mul_606); convert_element_type_2002 = None + sub_52 = torch.ops.aten.sub.Tensor(1, mul_606); mul_606 = None + mul_608 = torch.ops.aten.mul.Tensor(convert_element_type_488, sub_52); convert_element_type_488 = sub_52 = None + add_249 = torch.ops.aten.add.Tensor(mul_608, 1); mul_608 = None + mul_609 = torch.ops.aten.mul.Tensor(mul_607, add_249); mul_607 = add_249 = None + convert_element_type_2004 = torch.ops.prims.convert_element_type.default(mul_609, torch.bfloat16); mul_609 = None + view_1507 = torch.ops.aten.view.default(convert_element_type_2004, [16384, 14336]); convert_element_type_2004 = None + permute_909 = torch.ops.aten.permute.default(view_1507, [1, 0]) + mm_469 = torch.ops.aten.mm.default(permute_909, view_503); permute_909 = view_503 = None + convert_element_type_485 = torch.ops.prims.convert_element_type.default(primals_136, torch.bfloat16); primals_136 = None + all_gather_into_tensor_133 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_485, 256, '0'); convert_element_type_485 = None + wait_tensor_133 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_133); all_gather_into_tensor_133 = None + permute_162 = 
torch.ops.aten.permute.default(wait_tensor_133, [1, 0]); wait_tensor_133 = None + permute_911 = torch.ops.aten.permute.default(permute_162, [1, 0]); permute_162 = None + mm_470 = torch.ops.aten.mm.default(view_1507, permute_911); view_1507 = permute_911 = None + view_1508 = torch.ops.aten.view.default(mm_470, [2, 8192, 4096]); mm_470 = None + add_250 = torch.ops.aten.add.Tensor(view_1506, view_1508); view_1506 = view_1508 = None + convert_element_type_2009 = torch.ops.prims.convert_element_type.default(mm_469, torch.float32); mm_469 = None + reduce_scatter_tensor_157 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2009, 'avg', 256, '0'); convert_element_type_2009 = None + wait_tensor_448 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_157); reduce_scatter_tensor_157 = None + convert_element_type_2010 = torch.ops.prims.convert_element_type.default(add_250, torch.float32); add_250 = None + convert_element_type_2012 = torch.ops.prims.convert_element_type.default(wait_tensor_132, torch.float32); wait_tensor_132 = None + mul_610 = torch.ops.aten.mul.Tensor(convert_element_type_2010, convert_element_type_2012); convert_element_type_2012 = None + mul_612 = torch.ops.aten.mul.Tensor(mul_116, mul_610) + sum_105 = torch.ops.aten.sum.dim_IntList(mul_612, [2], True); mul_612 = None + div_35 = torch.ops.aten.div.Tensor(mul_116, 4096) + mul_613 = torch.ops.aten.mul.Tensor(div_35, sum_105); div_35 = sum_105 = None + sub_53 = torch.ops.aten.sub.Tensor(mul_610, mul_613); mul_610 = mul_613 = None + mul_614 = torch.ops.aten.mul.Tensor(sub_53, rsqrt_29); sub_53 = rsqrt_29 = None + mul_615 = torch.ops.aten.mul.Tensor(convert_element_type_2010, mul_116); convert_element_type_2010 = mul_116 = None + sum_106 = torch.ops.aten.sum.dim_IntList(mul_615, [0, 1]); mul_615 = None + convert_element_type_2013 = torch.ops.prims.convert_element_type.default(mul_614, torch.bfloat16); mul_614 = None + add_251 = torch.ops.aten.add.Tensor(add_247, 
convert_element_type_2013); add_247 = convert_element_type_2013 = None + convert_element_type_default_30 = torch.ops.prims.convert_element_type.default(sum_106, torch.float32); sum_106 = None + reduce_scatter_tensor_158 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_30, 'avg', 256, '0'); convert_element_type_default_30 = None + wait_tensor_449 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_158); reduce_scatter_tensor_158 = None + view_1509 = torch.ops.aten.view.default(add_251, [16384, 4096]) + permute_913 = torch.ops.aten.permute.default(view_1509, [1, 0]) + mm_471 = torch.ops.aten.mm.default(permute_913, view_499); permute_913 = view_499 = None + permute_915 = torch.ops.aten.permute.default(permute_161, [1, 0]); permute_161 = None + mm_472 = torch.ops.aten.mm.default(view_1509, permute_915); view_1509 = permute_915 = None + view_1510 = torch.ops.aten.view.default(mm_472, [2, 8192, 4096]); mm_472 = None + convert_element_type_2020 = torch.ops.prims.convert_element_type.default(mm_471, torch.float32); mm_471 = None + reduce_scatter_tensor_159 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2020, 'avg', 256, '0'); convert_element_type_2020 = None + wait_tensor_450 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_159); reduce_scatter_tensor_159 = None + view_1511 = torch.ops.aten.view.default(view_1510, [2, 8192, 32, 128]); view_1510 = None + permute_917 = torch.ops.aten.permute.default(view_1511, [0, 2, 1, 3]); view_1511 = None + convert_element_type_463 = torch.ops.prims.convert_element_type.default(primals_130, torch.bfloat16); primals_130 = None + all_gather_into_tensor_127 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_463, 256, '0'); convert_element_type_463 = None + wait_tensor_127 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_127); all_gather_into_tensor_127 = None + 
convert_element_type_464 = torch.ops.prims.convert_element_type.default(add_55, torch.float32); add_55 = None + pow_29 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_464, 2) + mean_28 = torch.ops.aten.mean.dim(pow_29, [2], True); pow_29 = None + add_56 = torch.ops.aten.add.Scalar(mean_28, 1e-05); mean_28 = None + rsqrt_28 = torch.ops.aten.rsqrt.default(add_56); add_56 = None + mul_112 = torch.ops.aten.mul.Tensor(convert_element_type_464, rsqrt_28); convert_element_type_464 = None + mul_113 = torch.ops.aten.mul.Tensor(mul_112, wait_tensor_127) + convert_element_type_465 = torch.ops.prims.convert_element_type.default(mul_113, torch.bfloat16); mul_113 = None + view_479 = torch.ops.aten.view.default(convert_element_type_465, [16384, 4096]); convert_element_type_465 = None + view_480 = torch.ops.aten.view.default(mm_98, [2, 8192, 4096]); mm_98 = None + convert_element_type_469 = torch.ops.prims.convert_element_type.default(primals_132, torch.bfloat16); primals_132 = None + all_gather_into_tensor_129 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_469, 256, '0'); convert_element_type_469 = None + wait_tensor_129 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_129); all_gather_into_tensor_129 = None + permute_155 = torch.ops.aten.permute.default(wait_tensor_129, [1, 0]); wait_tensor_129 = None + mm_99 = torch.ops.aten.mm.default(view_479, permute_155) + view_483 = torch.ops.aten.view.default(mm_99, [2, 8192, 1024]); mm_99 = None + view_486 = torch.ops.aten.view.default(mm_100, [2, 8192, 1024]); mm_100 = None + view_487 = torch.ops.aten.view.default(view_480, [2, 8192, -1, 128]); view_480 = None + view_488 = torch.ops.aten.view.default(view_483, [2, 8192, -1, 128]); view_483 = None + view_489 = torch.ops.aten.view.default(view_486, [2, 8192, -1, 128]); view_486 = None + convert_element_type_475 = torch.ops.prims.convert_element_type.default(view_487, torch.float32); view_487 = None + view_490 = 
torch.ops.aten.view.default(convert_element_type_475, [2, 8192, 32, -1, 2]); convert_element_type_475 = None + view_as_complex_28 = torch.ops.aten.view_as_complex.default(view_490); view_490 = None + convert_element_type_476 = torch.ops.prims.convert_element_type.default(view_488, torch.float32); view_488 = None + view_491 = torch.ops.aten.view.default(convert_element_type_476, [2, 8192, 8, -1, 2]); convert_element_type_476 = None + view_as_complex_29 = torch.ops.aten.view_as_complex.default(view_491); view_491 = None + mul_114 = torch.ops.aten.mul.Tensor(view_as_complex_28, view_16); view_as_complex_28 = None + view_as_real_28 = torch.ops.aten.view_as_real.default(mul_114); mul_114 = None + view_493 = torch.ops.aten.view.default(view_as_real_28, [2, 8192, 32, 128]); view_as_real_28 = None + mul_115 = torch.ops.aten.mul.Tensor(view_as_complex_29, view_16); view_as_complex_29 = None + view_as_real_29 = torch.ops.aten.view_as_real.default(mul_115); mul_115 = None + view_494 = torch.ops.aten.view.default(view_as_real_29, [2, 8192, 8, 128]); view_as_real_29 = None + convert_element_type_477 = torch.ops.prims.convert_element_type.default(view_493, torch.bfloat16); view_493 = None + convert_element_type_478 = torch.ops.prims.convert_element_type.default(view_494, torch.bfloat16); view_494 = None + unsqueeze_28 = torch.ops.aten.unsqueeze.default(convert_element_type_478, 3); convert_element_type_478 = None + expand_28 = torch.ops.aten.expand.default(unsqueeze_28, [2, 8192, 8, 4, 128]); unsqueeze_28 = None + clone_28 = torch.ops.aten.clone.default(expand_28, memory_format = torch.contiguous_format); expand_28 = None + view_495 = torch.ops.aten.view.default(clone_28, [2, 8192, 32, 128]); clone_28 = None + unsqueeze_29 = torch.ops.aten.unsqueeze.default(view_489, 3); view_489 = None + expand_29 = torch.ops.aten.expand.default(unsqueeze_29, [2, 8192, 8, 4, 128]); unsqueeze_29 = None + clone_29 = torch.ops.aten.clone.default(expand_29, memory_format = torch.contiguous_format); 
expand_29 = None + view_496 = torch.ops.aten.view.default(clone_29, [2, 8192, 32, 128]); clone_29 = None + permute_157 = torch.ops.aten.permute.default(convert_element_type_477, [0, 2, 1, 3]); convert_element_type_477 = None + permute_158 = torch.ops.aten.permute.default(view_495, [0, 2, 1, 3]); view_495 = None + permute_159 = torch.ops.aten.permute.default(view_496, [0, 2, 1, 3]); view_496 = None + _scaled_dot_product_cudnn_attention_backward_17 = torch.ops.aten._scaled_dot_product_cudnn_attention_backward.default(permute_917, permute_157, permute_158, permute_159, getitem_126, getitem_127, getitem_132, getitem_133, None, None, None, 8192, 8192, 0.0, True); permute_917 = permute_157 = permute_158 = permute_159 = getitem_126 = getitem_127 = getitem_132 = getitem_133 = None + getitem_339 = _scaled_dot_product_cudnn_attention_backward_17[0] + getitem_340 = _scaled_dot_product_cudnn_attention_backward_17[1] + getitem_341 = _scaled_dot_product_cudnn_attention_backward_17[2]; _scaled_dot_product_cudnn_attention_backward_17 = None + permute_918 = torch.ops.aten.permute.default(getitem_341, [0, 2, 1, 3]); getitem_341 = None + permute_919 = torch.ops.aten.permute.default(getitem_340, [0, 2, 1, 3]); getitem_340 = None + permute_920 = torch.ops.aten.permute.default(getitem_339, [0, 2, 1, 3]); getitem_339 = None + view_1512 = torch.ops.aten.view.default(permute_918, [2, 8192, 8, 4, 128]); permute_918 = None + sum_107 = torch.ops.aten.sum.dim_IntList(view_1512, [3], True); view_1512 = None + squeeze_34 = torch.ops.aten.squeeze.dim(sum_107, 3); sum_107 = None + view_1513 = torch.ops.aten.view.default(permute_919, [2, 8192, 8, 4, 128]); permute_919 = None + sum_108 = torch.ops.aten.sum.dim_IntList(view_1513, [3], True); view_1513 = None + squeeze_35 = torch.ops.aten.squeeze.dim(sum_108, 3); sum_108 = None + convert_element_type_2021 = torch.ops.prims.convert_element_type.default(squeeze_35, torch.float32); squeeze_35 = None + convert_element_type_2022 = 
torch.ops.prims.convert_element_type.default(permute_920, torch.float32); permute_920 = None + view_1514 = torch.ops.aten.view.default(convert_element_type_2021, [2, 8192, 8, 64, 2]); convert_element_type_2021 = None + view_as_complex_98 = torch.ops.aten.view_as_complex.default(view_1514); view_1514 = None + mul_616 = torch.ops.aten.mul.Tensor(view_as_complex_98, _conj); view_as_complex_98 = None + view_1515 = torch.ops.aten.view.default(convert_element_type_2022, [2, 8192, 32, 64, 2]); convert_element_type_2022 = None + view_as_complex_99 = torch.ops.aten.view_as_complex.default(view_1515); view_1515 = None + mul_617 = torch.ops.aten.mul.Tensor(view_as_complex_99, _conj); view_as_complex_99 = None + view_as_real_98 = torch.ops.aten.view_as_real.default(mul_616); mul_616 = None + view_1516 = torch.ops.aten.view.default(view_as_real_98, [2, 8192, 8, 128]); view_as_real_98 = None + convert_element_type_2023 = torch.ops.prims.convert_element_type.default(view_1516, torch.bfloat16); view_1516 = None + view_as_real_99 = torch.ops.aten.view_as_real.default(mul_617); mul_617 = None + view_1517 = torch.ops.aten.view.default(view_as_real_99, [2, 8192, 32, 128]); view_as_real_99 = None + convert_element_type_2024 = torch.ops.prims.convert_element_type.default(view_1517, torch.bfloat16); view_1517 = None + view_1518 = torch.ops.aten.view.default(squeeze_34, [2, 8192, 1024]); squeeze_34 = None + view_1519 = torch.ops.aten.view.default(convert_element_type_2023, [2, 8192, 1024]); convert_element_type_2023 = None + view_1520 = torch.ops.aten.view.default(convert_element_type_2024, [2, 8192, 4096]); convert_element_type_2024 = None + view_1521 = torch.ops.aten.view.default(view_1518, [16384, 1024]); view_1518 = None + permute_921 = torch.ops.aten.permute.default(view_1521, [1, 0]) + mm_473 = torch.ops.aten.mm.default(permute_921, view_479); permute_921 = None + convert_element_type_472 = torch.ops.prims.convert_element_type.default(primals_133, torch.bfloat16); primals_133 = None 
+ all_gather_into_tensor_130 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_472, 256, '0'); convert_element_type_472 = None + wait_tensor_130 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_130); all_gather_into_tensor_130 = None + permute_156 = torch.ops.aten.permute.default(wait_tensor_130, [1, 0]); wait_tensor_130 = None + permute_923 = torch.ops.aten.permute.default(permute_156, [1, 0]); permute_156 = None + mm_474 = torch.ops.aten.mm.default(view_1521, permute_923); view_1521 = permute_923 = None + view_1522 = torch.ops.aten.view.default(mm_474, [2, 8192, 4096]); mm_474 = None + convert_element_type_2029 = torch.ops.prims.convert_element_type.default(mm_473, torch.float32); mm_473 = None + reduce_scatter_tensor_160 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2029, 'avg', 256, '0'); convert_element_type_2029 = None + wait_tensor_451 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_160); reduce_scatter_tensor_160 = None + view_1523 = torch.ops.aten.view.default(view_1519, [16384, 1024]); view_1519 = None + permute_925 = torch.ops.aten.permute.default(view_1523, [1, 0]) + mm_475 = torch.ops.aten.mm.default(permute_925, view_479); permute_925 = None + permute_927 = torch.ops.aten.permute.default(permute_155, [1, 0]); permute_155 = None + mm_476 = torch.ops.aten.mm.default(view_1523, permute_927); view_1523 = permute_927 = None + view_1524 = torch.ops.aten.view.default(mm_476, [2, 8192, 4096]); mm_476 = None + add_252 = torch.ops.aten.add.Tensor(view_1522, view_1524); view_1522 = view_1524 = None + convert_element_type_2034 = torch.ops.prims.convert_element_type.default(mm_475, torch.float32); mm_475 = None + reduce_scatter_tensor_161 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2034, 'avg', 256, '0'); convert_element_type_2034 = None + wait_tensor_452 = 
torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_161); reduce_scatter_tensor_161 = None + view_1525 = torch.ops.aten.view.default(view_1520, [16384, 4096]); view_1520 = None + permute_929 = torch.ops.aten.permute.default(view_1525, [1, 0]) + mm_477 = torch.ops.aten.mm.default(permute_929, view_479); permute_929 = view_479 = None + convert_element_type_466 = torch.ops.prims.convert_element_type.default(primals_131, torch.bfloat16); primals_131 = None + all_gather_into_tensor_128 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_466, 256, '0'); convert_element_type_466 = None + wait_tensor_128 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_128); all_gather_into_tensor_128 = None + permute_154 = torch.ops.aten.permute.default(wait_tensor_128, [1, 0]); wait_tensor_128 = None + permute_931 = torch.ops.aten.permute.default(permute_154, [1, 0]); permute_154 = None + mm_478 = torch.ops.aten.mm.default(view_1525, permute_931); view_1525 = permute_931 = None + view_1526 = torch.ops.aten.view.default(mm_478, [2, 8192, 4096]); mm_478 = None + add_253 = torch.ops.aten.add.Tensor(add_252, view_1526); add_252 = view_1526 = None + convert_element_type_2039 = torch.ops.prims.convert_element_type.default(mm_477, torch.float32); mm_477 = None + reduce_scatter_tensor_162 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2039, 'avg', 256, '0'); convert_element_type_2039 = None + wait_tensor_453 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_162); reduce_scatter_tensor_162 = None + convert_element_type_2040 = torch.ops.prims.convert_element_type.default(add_253, torch.float32); add_253 = None + convert_element_type_2042 = torch.ops.prims.convert_element_type.default(wait_tensor_127, torch.float32); wait_tensor_127 = None + mul_618 = torch.ops.aten.mul.Tensor(convert_element_type_2040, convert_element_type_2042); convert_element_type_2042 = None + mul_620 = 
torch.ops.aten.mul.Tensor(mul_112, mul_618) + sum_109 = torch.ops.aten.sum.dim_IntList(mul_620, [2], True); mul_620 = None + div_36 = torch.ops.aten.div.Tensor(mul_112, 4096) + mul_621 = torch.ops.aten.mul.Tensor(div_36, sum_109); div_36 = sum_109 = None + sub_54 = torch.ops.aten.sub.Tensor(mul_618, mul_621); mul_618 = mul_621 = None + mul_622 = torch.ops.aten.mul.Tensor(sub_54, rsqrt_28); sub_54 = rsqrt_28 = None + mul_623 = torch.ops.aten.mul.Tensor(convert_element_type_2040, mul_112); convert_element_type_2040 = mul_112 = None + sum_110 = torch.ops.aten.sum.dim_IntList(mul_623, [0, 1]); mul_623 = None + convert_element_type_2043 = torch.ops.prims.convert_element_type.default(mul_622, torch.bfloat16); mul_622 = None + add_254 = torch.ops.aten.add.Tensor(add_251, convert_element_type_2043); add_251 = convert_element_type_2043 = None + convert_element_type_default_29 = torch.ops.prims.convert_element_type.default(sum_110, torch.float32); sum_110 = None + reduce_scatter_tensor_163 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_29, 'avg', 256, '0'); convert_element_type_default_29 = None + wait_tensor_454 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_163); reduce_scatter_tensor_163 = None + view_1527 = torch.ops.aten.view.default(add_254, [16384, 4096]) + permute_933 = torch.ops.aten.permute.default(view_1527, [1, 0]) + permute_149 = torch.ops.aten.permute.default(getitem_117, [0, 2, 1, 3]) + view_463 = torch.ops.aten.view.default(permute_149, [2, 8192, -1]); permute_149 = None + convert_element_type_446 = torch.ops.prims.convert_element_type.default(primals_125, torch.bfloat16); primals_125 = None + all_gather_into_tensor_122 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_446, 256, '0'); convert_element_type_446 = None + wait_tensor_122 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_122); all_gather_into_tensor_122 = None + permute_150 = 
torch.ops.aten.permute.default(wait_tensor_122, [1, 0]); wait_tensor_122 = None + view_465 = torch.ops.aten.view.default(view_463, [16384, 4096]); view_463 = None + mm_94 = torch.ops.aten.mm.default(view_465, permute_150) + view_466 = torch.ops.aten.view.default(mm_94, [2, 8192, 4096]); mm_94 = None + add_53 = torch.ops.aten.add.Tensor(add_51, view_466); view_466 = None + convert_element_type_449 = torch.ops.prims.convert_element_type.default(primals_126, torch.bfloat16); primals_126 = None + all_gather_into_tensor_123 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_449, 256, '0'); convert_element_type_449 = None + wait_tensor_123 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_123); all_gather_into_tensor_123 = None + convert_element_type_450 = torch.ops.prims.convert_element_type.default(add_53, torch.float32); add_53 = None + pow_28 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_450, 2) + mean_27 = torch.ops.aten.mean.dim(pow_28, [2], True); pow_28 = None + add_54 = torch.ops.aten.add.Scalar(mean_27, 1e-05); mean_27 = None + rsqrt_27 = torch.ops.aten.rsqrt.default(add_54); add_54 = None + mul_108 = torch.ops.aten.mul.Tensor(convert_element_type_450, rsqrt_27); convert_element_type_450 = None + mul_109 = torch.ops.aten.mul.Tensor(mul_108, wait_tensor_123) + convert_element_type_451 = torch.ops.prims.convert_element_type.default(mul_109, torch.bfloat16); mul_109 = None + view_469 = torch.ops.aten.view.default(convert_element_type_451, [16384, 4096]); convert_element_type_451 = None + view_470 = torch.ops.aten.view.default(mm_95, [2, 8192, 14336]); mm_95 = None + convert_element_type_455 = torch.ops.prims.convert_element_type.default(view_470, torch.float32); view_470 = None + sigmoid_13 = torch.ops.aten.sigmoid.default(convert_element_type_455) + mul_110 = torch.ops.aten.mul.Tensor(convert_element_type_455, sigmoid_13); sigmoid_13 = None + convert_element_type_456 = 
torch.ops.prims.convert_element_type.default(mul_110, torch.bfloat16); mul_110 = None + convert_element_type_457 = torch.ops.prims.convert_element_type.default(primals_128, torch.bfloat16); primals_128 = None + all_gather_into_tensor_125 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_457, 256, '0'); convert_element_type_457 = None + wait_tensor_125 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_125); all_gather_into_tensor_125 = None + permute_152 = torch.ops.aten.permute.default(wait_tensor_125, [1, 0]); wait_tensor_125 = None + mm_96 = torch.ops.aten.mm.default(view_469, permute_152) + view_473 = torch.ops.aten.view.default(mm_96, [2, 8192, 14336]); mm_96 = None + mul_111 = torch.ops.aten.mul.Tensor(convert_element_type_456, view_473) + view_475 = torch.ops.aten.view.default(mul_111, [16384, 14336]); mul_111 = None + mm_479 = torch.ops.aten.mm.default(permute_933, view_475); permute_933 = view_475 = None + convert_element_type_460 = torch.ops.prims.convert_element_type.default(primals_129, torch.bfloat16); primals_129 = None + all_gather_into_tensor_126 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_460, 256, '0'); convert_element_type_460 = None + wait_tensor_126 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_126); all_gather_into_tensor_126 = None + permute_153 = torch.ops.aten.permute.default(wait_tensor_126, [1, 0]); wait_tensor_126 = None + permute_935 = torch.ops.aten.permute.default(permute_153, [1, 0]); permute_153 = None + mm_480 = torch.ops.aten.mm.default(view_1527, permute_935); view_1527 = permute_935 = None + view_1528 = torch.ops.aten.view.default(mm_480, [2, 8192, 14336]); mm_480 = None + convert_element_type_2050 = torch.ops.prims.convert_element_type.default(mm_479, torch.float32); mm_479 = None + reduce_scatter_tensor_164 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2050, 'avg', 256, '0'); 
convert_element_type_2050 = None + wait_tensor_455 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_164); reduce_scatter_tensor_164 = None + mul_624 = torch.ops.aten.mul.Tensor(view_1528, convert_element_type_456); convert_element_type_456 = None + mul_625 = torch.ops.aten.mul.Tensor(view_1528, view_473); view_1528 = view_473 = None + view_1529 = torch.ops.aten.view.default(mul_624, [16384, 14336]); mul_624 = None + permute_937 = torch.ops.aten.permute.default(view_1529, [1, 0]) + mm_481 = torch.ops.aten.mm.default(permute_937, view_469); permute_937 = None + permute_939 = torch.ops.aten.permute.default(permute_152, [1, 0]); permute_152 = None + mm_482 = torch.ops.aten.mm.default(view_1529, permute_939); view_1529 = permute_939 = None + view_1530 = torch.ops.aten.view.default(mm_482, [2, 8192, 4096]); mm_482 = None + convert_element_type_2055 = torch.ops.prims.convert_element_type.default(mm_481, torch.float32); mm_481 = None + reduce_scatter_tensor_165 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2055, 'avg', 256, '0'); convert_element_type_2055 = None + wait_tensor_456 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_165); reduce_scatter_tensor_165 = None + convert_element_type_2056 = torch.ops.prims.convert_element_type.default(mul_625, torch.float32); mul_625 = None + neg_18 = torch.ops.aten.neg.default(convert_element_type_455) + exp_18 = torch.ops.aten.exp.default(neg_18); neg_18 = None + add_255 = torch.ops.aten.add.Tensor(exp_18, 1); exp_18 = None + reciprocal_18 = torch.ops.aten.reciprocal.default(add_255); add_255 = None + mul_626 = torch.ops.aten.mul.Tensor(reciprocal_18, 1); reciprocal_18 = None + mul_627 = torch.ops.aten.mul.Tensor(convert_element_type_2056, mul_626); convert_element_type_2056 = None + sub_55 = torch.ops.aten.sub.Tensor(1, mul_626); mul_626 = None + mul_628 = torch.ops.aten.mul.Tensor(convert_element_type_455, sub_55); convert_element_type_455 = sub_55 = 
None + add_256 = torch.ops.aten.add.Tensor(mul_628, 1); mul_628 = None + mul_629 = torch.ops.aten.mul.Tensor(mul_627, add_256); mul_627 = add_256 = None + convert_element_type_2058 = torch.ops.prims.convert_element_type.default(mul_629, torch.bfloat16); mul_629 = None + view_1531 = torch.ops.aten.view.default(convert_element_type_2058, [16384, 14336]); convert_element_type_2058 = None + permute_941 = torch.ops.aten.permute.default(view_1531, [1, 0]) + mm_483 = torch.ops.aten.mm.default(permute_941, view_469); permute_941 = view_469 = None + convert_element_type_452 = torch.ops.prims.convert_element_type.default(primals_127, torch.bfloat16); primals_127 = None + all_gather_into_tensor_124 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_452, 256, '0'); convert_element_type_452 = None + wait_tensor_124 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_124); all_gather_into_tensor_124 = None + permute_151 = torch.ops.aten.permute.default(wait_tensor_124, [1, 0]); wait_tensor_124 = None + permute_943 = torch.ops.aten.permute.default(permute_151, [1, 0]); permute_151 = None + mm_484 = torch.ops.aten.mm.default(view_1531, permute_943); view_1531 = permute_943 = None + view_1532 = torch.ops.aten.view.default(mm_484, [2, 8192, 4096]); mm_484 = None + add_257 = torch.ops.aten.add.Tensor(view_1530, view_1532); view_1530 = view_1532 = None + convert_element_type_2063 = torch.ops.prims.convert_element_type.default(mm_483, torch.float32); mm_483 = None + reduce_scatter_tensor_166 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2063, 'avg', 256, '0'); convert_element_type_2063 = None + wait_tensor_457 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_166); reduce_scatter_tensor_166 = None + convert_element_type_2064 = torch.ops.prims.convert_element_type.default(add_257, torch.float32); add_257 = None + convert_element_type_2066 = 
torch.ops.prims.convert_element_type.default(wait_tensor_123, torch.float32); wait_tensor_123 = None + mul_630 = torch.ops.aten.mul.Tensor(convert_element_type_2064, convert_element_type_2066); convert_element_type_2066 = None + mul_632 = torch.ops.aten.mul.Tensor(mul_108, mul_630) + sum_111 = torch.ops.aten.sum.dim_IntList(mul_632, [2], True); mul_632 = None + div_37 = torch.ops.aten.div.Tensor(mul_108, 4096) + mul_633 = torch.ops.aten.mul.Tensor(div_37, sum_111); div_37 = sum_111 = None + sub_56 = torch.ops.aten.sub.Tensor(mul_630, mul_633); mul_630 = mul_633 = None + mul_634 = torch.ops.aten.mul.Tensor(sub_56, rsqrt_27); sub_56 = rsqrt_27 = None + mul_635 = torch.ops.aten.mul.Tensor(convert_element_type_2064, mul_108); convert_element_type_2064 = mul_108 = None + sum_112 = torch.ops.aten.sum.dim_IntList(mul_635, [0, 1]); mul_635 = None + convert_element_type_2067 = torch.ops.prims.convert_element_type.default(mul_634, torch.bfloat16); mul_634 = None + add_258 = torch.ops.aten.add.Tensor(add_254, convert_element_type_2067); add_254 = convert_element_type_2067 = None + convert_element_type_default_28 = torch.ops.prims.convert_element_type.default(sum_112, torch.float32); sum_112 = None + reduce_scatter_tensor_167 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_28, 'avg', 256, '0'); convert_element_type_default_28 = None + wait_tensor_458 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_167); reduce_scatter_tensor_167 = None + view_1533 = torch.ops.aten.view.default(add_258, [16384, 4096]) + permute_945 = torch.ops.aten.permute.default(view_1533, [1, 0]) + mm_485 = torch.ops.aten.mm.default(permute_945, view_465); permute_945 = view_465 = None + permute_947 = torch.ops.aten.permute.default(permute_150, [1, 0]); permute_150 = None + mm_486 = torch.ops.aten.mm.default(view_1533, permute_947); view_1533 = permute_947 = None + view_1534 = torch.ops.aten.view.default(mm_486, [2, 8192, 4096]); mm_486 = None + 
convert_element_type_2074 = torch.ops.prims.convert_element_type.default(mm_485, torch.float32); mm_485 = None + reduce_scatter_tensor_168 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2074, 'avg', 256, '0'); convert_element_type_2074 = None + wait_tensor_459 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_168); reduce_scatter_tensor_168 = None + view_1535 = torch.ops.aten.view.default(view_1534, [2, 8192, 32, 128]); view_1534 = None + permute_949 = torch.ops.aten.permute.default(view_1535, [0, 2, 1, 3]); view_1535 = None + convert_element_type_430 = torch.ops.prims.convert_element_type.default(primals_121, torch.bfloat16); primals_121 = None + all_gather_into_tensor_118 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_430, 256, '0'); convert_element_type_430 = None + wait_tensor_118 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_118); all_gather_into_tensor_118 = None + convert_element_type_431 = torch.ops.prims.convert_element_type.default(add_51, torch.float32); add_51 = None + pow_27 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_431, 2) + mean_26 = torch.ops.aten.mean.dim(pow_27, [2], True); pow_27 = None + add_52 = torch.ops.aten.add.Scalar(mean_26, 1e-05); mean_26 = None + rsqrt_26 = torch.ops.aten.rsqrt.default(add_52); add_52 = None + mul_104 = torch.ops.aten.mul.Tensor(convert_element_type_431, rsqrt_26); convert_element_type_431 = None + mul_105 = torch.ops.aten.mul.Tensor(mul_104, wait_tensor_118) + convert_element_type_432 = torch.ops.prims.convert_element_type.default(mul_105, torch.bfloat16); mul_105 = None + view_445 = torch.ops.aten.view.default(convert_element_type_432, [16384, 4096]); convert_element_type_432 = None + view_446 = torch.ops.aten.view.default(mm_91, [2, 8192, 4096]); mm_91 = None + convert_element_type_436 = torch.ops.prims.convert_element_type.default(primals_123, torch.bfloat16); primals_123 = None + 
all_gather_into_tensor_120 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_436, 256, '0'); convert_element_type_436 = None + wait_tensor_120 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_120); all_gather_into_tensor_120 = None + permute_144 = torch.ops.aten.permute.default(wait_tensor_120, [1, 0]); wait_tensor_120 = None + mm_92 = torch.ops.aten.mm.default(view_445, permute_144) + view_449 = torch.ops.aten.view.default(mm_92, [2, 8192, 1024]); mm_92 = None + view_452 = torch.ops.aten.view.default(mm_93, [2, 8192, 1024]); mm_93 = None + view_453 = torch.ops.aten.view.default(view_446, [2, 8192, -1, 128]); view_446 = None + view_454 = torch.ops.aten.view.default(view_449, [2, 8192, -1, 128]); view_449 = None + view_455 = torch.ops.aten.view.default(view_452, [2, 8192, -1, 128]); view_452 = None + convert_element_type_442 = torch.ops.prims.convert_element_type.default(view_453, torch.float32); view_453 = None + view_456 = torch.ops.aten.view.default(convert_element_type_442, [2, 8192, 32, -1, 2]); convert_element_type_442 = None + view_as_complex_26 = torch.ops.aten.view_as_complex.default(view_456); view_456 = None + convert_element_type_443 = torch.ops.prims.convert_element_type.default(view_454, torch.float32); view_454 = None + view_457 = torch.ops.aten.view.default(convert_element_type_443, [2, 8192, 8, -1, 2]); convert_element_type_443 = None + view_as_complex_27 = torch.ops.aten.view_as_complex.default(view_457); view_457 = None + mul_106 = torch.ops.aten.mul.Tensor(view_as_complex_26, view_16); view_as_complex_26 = None + view_as_real_26 = torch.ops.aten.view_as_real.default(mul_106); mul_106 = None + view_459 = torch.ops.aten.view.default(view_as_real_26, [2, 8192, 32, 128]); view_as_real_26 = None + mul_107 = torch.ops.aten.mul.Tensor(view_as_complex_27, view_16); view_as_complex_27 = None + view_as_real_27 = torch.ops.aten.view_as_real.default(mul_107); mul_107 = None + view_460 = 
torch.ops.aten.view.default(view_as_real_27, [2, 8192, 8, 128]); view_as_real_27 = None + convert_element_type_444 = torch.ops.prims.convert_element_type.default(view_459, torch.bfloat16); view_459 = None + convert_element_type_445 = torch.ops.prims.convert_element_type.default(view_460, torch.bfloat16); view_460 = None + unsqueeze_26 = torch.ops.aten.unsqueeze.default(convert_element_type_445, 3); convert_element_type_445 = None + expand_26 = torch.ops.aten.expand.default(unsqueeze_26, [2, 8192, 8, 4, 128]); unsqueeze_26 = None + clone_26 = torch.ops.aten.clone.default(expand_26, memory_format = torch.contiguous_format); expand_26 = None + view_461 = torch.ops.aten.view.default(clone_26, [2, 8192, 32, 128]); clone_26 = None + unsqueeze_27 = torch.ops.aten.unsqueeze.default(view_455, 3); view_455 = None + expand_27 = torch.ops.aten.expand.default(unsqueeze_27, [2, 8192, 8, 4, 128]); unsqueeze_27 = None + clone_27 = torch.ops.aten.clone.default(expand_27, memory_format = torch.contiguous_format); expand_27 = None + view_462 = torch.ops.aten.view.default(clone_27, [2, 8192, 32, 128]); clone_27 = None + permute_146 = torch.ops.aten.permute.default(convert_element_type_444, [0, 2, 1, 3]); convert_element_type_444 = None + permute_147 = torch.ops.aten.permute.default(view_461, [0, 2, 1, 3]); view_461 = None + permute_148 = torch.ops.aten.permute.default(view_462, [0, 2, 1, 3]); view_462 = None + _scaled_dot_product_cudnn_attention_backward_18 = torch.ops.aten._scaled_dot_product_cudnn_attention_backward.default(permute_949, permute_146, permute_147, permute_148, getitem_117, getitem_118, getitem_123, getitem_124, None, None, None, 8192, 8192, 0.0, True); permute_949 = permute_146 = permute_147 = permute_148 = getitem_117 = getitem_118 = getitem_123 = getitem_124 = None + getitem_342 = _scaled_dot_product_cudnn_attention_backward_18[0] + getitem_343 = _scaled_dot_product_cudnn_attention_backward_18[1] + getitem_344 = _scaled_dot_product_cudnn_attention_backward_18[2]; 
_scaled_dot_product_cudnn_attention_backward_18 = None + permute_950 = torch.ops.aten.permute.default(getitem_344, [0, 2, 1, 3]); getitem_344 = None + permute_951 = torch.ops.aten.permute.default(getitem_343, [0, 2, 1, 3]); getitem_343 = None + permute_952 = torch.ops.aten.permute.default(getitem_342, [0, 2, 1, 3]); getitem_342 = None + view_1536 = torch.ops.aten.view.default(permute_950, [2, 8192, 8, 4, 128]); permute_950 = None + sum_113 = torch.ops.aten.sum.dim_IntList(view_1536, [3], True); view_1536 = None + squeeze_36 = torch.ops.aten.squeeze.dim(sum_113, 3); sum_113 = None + view_1537 = torch.ops.aten.view.default(permute_951, [2, 8192, 8, 4, 128]); permute_951 = None + sum_114 = torch.ops.aten.sum.dim_IntList(view_1537, [3], True); view_1537 = None + squeeze_37 = torch.ops.aten.squeeze.dim(sum_114, 3); sum_114 = None + convert_element_type_2075 = torch.ops.prims.convert_element_type.default(squeeze_37, torch.float32); squeeze_37 = None + convert_element_type_2076 = torch.ops.prims.convert_element_type.default(permute_952, torch.float32); permute_952 = None + view_1538 = torch.ops.aten.view.default(convert_element_type_2075, [2, 8192, 8, 64, 2]); convert_element_type_2075 = None + view_as_complex_100 = torch.ops.aten.view_as_complex.default(view_1538); view_1538 = None + mul_636 = torch.ops.aten.mul.Tensor(view_as_complex_100, _conj); view_as_complex_100 = None + view_1539 = torch.ops.aten.view.default(convert_element_type_2076, [2, 8192, 32, 64, 2]); convert_element_type_2076 = None + view_as_complex_101 = torch.ops.aten.view_as_complex.default(view_1539); view_1539 = None + mul_637 = torch.ops.aten.mul.Tensor(view_as_complex_101, _conj); view_as_complex_101 = None + view_as_real_100 = torch.ops.aten.view_as_real.default(mul_636); mul_636 = None + view_1540 = torch.ops.aten.view.default(view_as_real_100, [2, 8192, 8, 128]); view_as_real_100 = None + convert_element_type_2077 = torch.ops.prims.convert_element_type.default(view_1540, torch.bfloat16); 
view_1540 = None + view_as_real_101 = torch.ops.aten.view_as_real.default(mul_637); mul_637 = None + view_1541 = torch.ops.aten.view.default(view_as_real_101, [2, 8192, 32, 128]); view_as_real_101 = None + convert_element_type_2078 = torch.ops.prims.convert_element_type.default(view_1541, torch.bfloat16); view_1541 = None + view_1542 = torch.ops.aten.view.default(squeeze_36, [2, 8192, 1024]); squeeze_36 = None + view_1543 = torch.ops.aten.view.default(convert_element_type_2077, [2, 8192, 1024]); convert_element_type_2077 = None + view_1544 = torch.ops.aten.view.default(convert_element_type_2078, [2, 8192, 4096]); convert_element_type_2078 = None + view_1545 = torch.ops.aten.view.default(view_1542, [16384, 1024]); view_1542 = None + permute_953 = torch.ops.aten.permute.default(view_1545, [1, 0]) + mm_487 = torch.ops.aten.mm.default(permute_953, view_445); permute_953 = None + convert_element_type_439 = torch.ops.prims.convert_element_type.default(primals_124, torch.bfloat16); primals_124 = None + all_gather_into_tensor_121 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_439, 256, '0'); convert_element_type_439 = None + wait_tensor_121 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_121); all_gather_into_tensor_121 = None + permute_145 = torch.ops.aten.permute.default(wait_tensor_121, [1, 0]); wait_tensor_121 = None + permute_955 = torch.ops.aten.permute.default(permute_145, [1, 0]); permute_145 = None + mm_488 = torch.ops.aten.mm.default(view_1545, permute_955); view_1545 = permute_955 = None + view_1546 = torch.ops.aten.view.default(mm_488, [2, 8192, 4096]); mm_488 = None + convert_element_type_2083 = torch.ops.prims.convert_element_type.default(mm_487, torch.float32); mm_487 = None + reduce_scatter_tensor_169 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2083, 'avg', 256, '0'); convert_element_type_2083 = None + wait_tensor_460 = 
torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_169); reduce_scatter_tensor_169 = None + view_1547 = torch.ops.aten.view.default(view_1543, [16384, 1024]); view_1543 = None + permute_957 = torch.ops.aten.permute.default(view_1547, [1, 0]) + mm_489 = torch.ops.aten.mm.default(permute_957, view_445); permute_957 = None + permute_959 = torch.ops.aten.permute.default(permute_144, [1, 0]); permute_144 = None + mm_490 = torch.ops.aten.mm.default(view_1547, permute_959); view_1547 = permute_959 = None + view_1548 = torch.ops.aten.view.default(mm_490, [2, 8192, 4096]); mm_490 = None + add_259 = torch.ops.aten.add.Tensor(view_1546, view_1548); view_1546 = view_1548 = None + convert_element_type_2088 = torch.ops.prims.convert_element_type.default(mm_489, torch.float32); mm_489 = None + reduce_scatter_tensor_170 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2088, 'avg', 256, '0'); convert_element_type_2088 = None + wait_tensor_461 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_170); reduce_scatter_tensor_170 = None + view_1549 = torch.ops.aten.view.default(view_1544, [16384, 4096]); view_1544 = None + permute_961 = torch.ops.aten.permute.default(view_1549, [1, 0]) + mm_491 = torch.ops.aten.mm.default(permute_961, view_445); permute_961 = view_445 = None + convert_element_type_433 = torch.ops.prims.convert_element_type.default(primals_122, torch.bfloat16); primals_122 = None + all_gather_into_tensor_119 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_433, 256, '0'); convert_element_type_433 = None + wait_tensor_119 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_119); all_gather_into_tensor_119 = None + permute_143 = torch.ops.aten.permute.default(wait_tensor_119, [1, 0]); wait_tensor_119 = None + permute_963 = torch.ops.aten.permute.default(permute_143, [1, 0]); permute_143 = None + mm_492 = torch.ops.aten.mm.default(view_1549, permute_963); 
view_1549 = permute_963 = None + view_1550 = torch.ops.aten.view.default(mm_492, [2, 8192, 4096]); mm_492 = None + add_260 = torch.ops.aten.add.Tensor(add_259, view_1550); add_259 = view_1550 = None + convert_element_type_2093 = torch.ops.prims.convert_element_type.default(mm_491, torch.float32); mm_491 = None + reduce_scatter_tensor_171 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2093, 'avg', 256, '0'); convert_element_type_2093 = None + wait_tensor_462 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_171); reduce_scatter_tensor_171 = None + convert_element_type_2094 = torch.ops.prims.convert_element_type.default(add_260, torch.float32); add_260 = None + convert_element_type_2096 = torch.ops.prims.convert_element_type.default(wait_tensor_118, torch.float32); wait_tensor_118 = None + mul_638 = torch.ops.aten.mul.Tensor(convert_element_type_2094, convert_element_type_2096); convert_element_type_2096 = None + mul_640 = torch.ops.aten.mul.Tensor(mul_104, mul_638) + sum_115 = torch.ops.aten.sum.dim_IntList(mul_640, [2], True); mul_640 = None + div_38 = torch.ops.aten.div.Tensor(mul_104, 4096) + mul_641 = torch.ops.aten.mul.Tensor(div_38, sum_115); div_38 = sum_115 = None + sub_57 = torch.ops.aten.sub.Tensor(mul_638, mul_641); mul_638 = mul_641 = None + mul_642 = torch.ops.aten.mul.Tensor(sub_57, rsqrt_26); sub_57 = rsqrt_26 = None + mul_643 = torch.ops.aten.mul.Tensor(convert_element_type_2094, mul_104); convert_element_type_2094 = mul_104 = None + sum_116 = torch.ops.aten.sum.dim_IntList(mul_643, [0, 1]); mul_643 = None + convert_element_type_2097 = torch.ops.prims.convert_element_type.default(mul_642, torch.bfloat16); mul_642 = None + add_261 = torch.ops.aten.add.Tensor(add_258, convert_element_type_2097); add_258 = convert_element_type_2097 = None + convert_element_type_default_27 = torch.ops.prims.convert_element_type.default(sum_116, torch.float32); sum_116 = None + reduce_scatter_tensor_172 = 
torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_27, 'avg', 256, '0'); convert_element_type_default_27 = None + wait_tensor_463 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_172); reduce_scatter_tensor_172 = None + view_1551 = torch.ops.aten.view.default(add_261, [16384, 4096]) + permute_965 = torch.ops.aten.permute.default(view_1551, [1, 0]) + permute_138 = torch.ops.aten.permute.default(getitem_108, [0, 2, 1, 3]) + view_429 = torch.ops.aten.view.default(permute_138, [2, 8192, -1]); permute_138 = None + convert_element_type_413 = torch.ops.prims.convert_element_type.default(primals_116, torch.bfloat16); primals_116 = None + all_gather_into_tensor_113 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_413, 256, '0'); convert_element_type_413 = None + wait_tensor_113 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_113); all_gather_into_tensor_113 = None + permute_139 = torch.ops.aten.permute.default(wait_tensor_113, [1, 0]); wait_tensor_113 = None + view_431 = torch.ops.aten.view.default(view_429, [16384, 4096]); view_429 = None + mm_87 = torch.ops.aten.mm.default(view_431, permute_139) + view_432 = torch.ops.aten.view.default(mm_87, [2, 8192, 4096]); mm_87 = None + add_49 = torch.ops.aten.add.Tensor(add_47, view_432); view_432 = None + convert_element_type_416 = torch.ops.prims.convert_element_type.default(primals_117, torch.bfloat16); primals_117 = None + all_gather_into_tensor_114 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_416, 256, '0'); convert_element_type_416 = None + wait_tensor_114 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_114); all_gather_into_tensor_114 = None + convert_element_type_417 = torch.ops.prims.convert_element_type.default(add_49, torch.float32); add_49 = None + pow_26 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_417, 2) + mean_25 = 
torch.ops.aten.mean.dim(pow_26, [2], True); pow_26 = None + add_50 = torch.ops.aten.add.Scalar(mean_25, 1e-05); mean_25 = None + rsqrt_25 = torch.ops.aten.rsqrt.default(add_50); add_50 = None + mul_100 = torch.ops.aten.mul.Tensor(convert_element_type_417, rsqrt_25); convert_element_type_417 = None + mul_101 = torch.ops.aten.mul.Tensor(mul_100, wait_tensor_114) + convert_element_type_418 = torch.ops.prims.convert_element_type.default(mul_101, torch.bfloat16); mul_101 = None + view_435 = torch.ops.aten.view.default(convert_element_type_418, [16384, 4096]); convert_element_type_418 = None + view_436 = torch.ops.aten.view.default(mm_88, [2, 8192, 14336]); mm_88 = None + convert_element_type_422 = torch.ops.prims.convert_element_type.default(view_436, torch.float32); view_436 = None + sigmoid_12 = torch.ops.aten.sigmoid.default(convert_element_type_422) + mul_102 = torch.ops.aten.mul.Tensor(convert_element_type_422, sigmoid_12); sigmoid_12 = None + convert_element_type_423 = torch.ops.prims.convert_element_type.default(mul_102, torch.bfloat16); mul_102 = None + convert_element_type_424 = torch.ops.prims.convert_element_type.default(primals_119, torch.bfloat16); primals_119 = None + all_gather_into_tensor_116 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_424, 256, '0'); convert_element_type_424 = None + wait_tensor_116 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_116); all_gather_into_tensor_116 = None + permute_141 = torch.ops.aten.permute.default(wait_tensor_116, [1, 0]); wait_tensor_116 = None + mm_89 = torch.ops.aten.mm.default(view_435, permute_141) + view_439 = torch.ops.aten.view.default(mm_89, [2, 8192, 14336]); mm_89 = None + mul_103 = torch.ops.aten.mul.Tensor(convert_element_type_423, view_439) + view_441 = torch.ops.aten.view.default(mul_103, [16384, 14336]); mul_103 = None + mm_493 = torch.ops.aten.mm.default(permute_965, view_441); permute_965 = view_441 = None + convert_element_type_427 = 
torch.ops.prims.convert_element_type.default(primals_120, torch.bfloat16); primals_120 = None + all_gather_into_tensor_117 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_427, 256, '0'); convert_element_type_427 = None + wait_tensor_117 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_117); all_gather_into_tensor_117 = None + permute_142 = torch.ops.aten.permute.default(wait_tensor_117, [1, 0]); wait_tensor_117 = None + permute_967 = torch.ops.aten.permute.default(permute_142, [1, 0]); permute_142 = None + mm_494 = torch.ops.aten.mm.default(view_1551, permute_967); view_1551 = permute_967 = None + view_1552 = torch.ops.aten.view.default(mm_494, [2, 8192, 14336]); mm_494 = None + convert_element_type_2104 = torch.ops.prims.convert_element_type.default(mm_493, torch.float32); mm_493 = None + reduce_scatter_tensor_173 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2104, 'avg', 256, '0'); convert_element_type_2104 = None + wait_tensor_464 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_173); reduce_scatter_tensor_173 = None + mul_644 = torch.ops.aten.mul.Tensor(view_1552, convert_element_type_423); convert_element_type_423 = None + mul_645 = torch.ops.aten.mul.Tensor(view_1552, view_439); view_1552 = view_439 = None + view_1553 = torch.ops.aten.view.default(mul_644, [16384, 14336]); mul_644 = None + permute_969 = torch.ops.aten.permute.default(view_1553, [1, 0]) + mm_495 = torch.ops.aten.mm.default(permute_969, view_435); permute_969 = None + permute_971 = torch.ops.aten.permute.default(permute_141, [1, 0]); permute_141 = None + mm_496 = torch.ops.aten.mm.default(view_1553, permute_971); view_1553 = permute_971 = None + view_1554 = torch.ops.aten.view.default(mm_496, [2, 8192, 4096]); mm_496 = None + convert_element_type_2109 = torch.ops.prims.convert_element_type.default(mm_495, torch.float32); mm_495 = None + reduce_scatter_tensor_174 = 
torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2109, 'avg', 256, '0'); convert_element_type_2109 = None + wait_tensor_465 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_174); reduce_scatter_tensor_174 = None + convert_element_type_2110 = torch.ops.prims.convert_element_type.default(mul_645, torch.float32); mul_645 = None + neg_19 = torch.ops.aten.neg.default(convert_element_type_422) + exp_19 = torch.ops.aten.exp.default(neg_19); neg_19 = None + add_262 = torch.ops.aten.add.Tensor(exp_19, 1); exp_19 = None + reciprocal_19 = torch.ops.aten.reciprocal.default(add_262); add_262 = None + mul_646 = torch.ops.aten.mul.Tensor(reciprocal_19, 1); reciprocal_19 = None + mul_647 = torch.ops.aten.mul.Tensor(convert_element_type_2110, mul_646); convert_element_type_2110 = None + sub_58 = torch.ops.aten.sub.Tensor(1, mul_646); mul_646 = None + mul_648 = torch.ops.aten.mul.Tensor(convert_element_type_422, sub_58); convert_element_type_422 = sub_58 = None + add_263 = torch.ops.aten.add.Tensor(mul_648, 1); mul_648 = None + mul_649 = torch.ops.aten.mul.Tensor(mul_647, add_263); mul_647 = add_263 = None + convert_element_type_2112 = torch.ops.prims.convert_element_type.default(mul_649, torch.bfloat16); mul_649 = None + view_1555 = torch.ops.aten.view.default(convert_element_type_2112, [16384, 14336]); convert_element_type_2112 = None + permute_973 = torch.ops.aten.permute.default(view_1555, [1, 0]) + mm_497 = torch.ops.aten.mm.default(permute_973, view_435); permute_973 = view_435 = None + convert_element_type_419 = torch.ops.prims.convert_element_type.default(primals_118, torch.bfloat16); primals_118 = None + all_gather_into_tensor_115 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_419, 256, '0'); convert_element_type_419 = None + wait_tensor_115 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_115); all_gather_into_tensor_115 = None + permute_140 = 
torch.ops.aten.permute.default(wait_tensor_115, [1, 0]); wait_tensor_115 = None + permute_975 = torch.ops.aten.permute.default(permute_140, [1, 0]); permute_140 = None + mm_498 = torch.ops.aten.mm.default(view_1555, permute_975); view_1555 = permute_975 = None + view_1556 = torch.ops.aten.view.default(mm_498, [2, 8192, 4096]); mm_498 = None + add_264 = torch.ops.aten.add.Tensor(view_1554, view_1556); view_1554 = view_1556 = None + convert_element_type_2117 = torch.ops.prims.convert_element_type.default(mm_497, torch.float32); mm_497 = None + reduce_scatter_tensor_175 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2117, 'avg', 256, '0'); convert_element_type_2117 = None + wait_tensor_466 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_175); reduce_scatter_tensor_175 = None + convert_element_type_2118 = torch.ops.prims.convert_element_type.default(add_264, torch.float32); add_264 = None + convert_element_type_2120 = torch.ops.prims.convert_element_type.default(wait_tensor_114, torch.float32); wait_tensor_114 = None + mul_650 = torch.ops.aten.mul.Tensor(convert_element_type_2118, convert_element_type_2120); convert_element_type_2120 = None + mul_652 = torch.ops.aten.mul.Tensor(mul_100, mul_650) + sum_117 = torch.ops.aten.sum.dim_IntList(mul_652, [2], True); mul_652 = None + div_39 = torch.ops.aten.div.Tensor(mul_100, 4096) + mul_653 = torch.ops.aten.mul.Tensor(div_39, sum_117); div_39 = sum_117 = None + sub_59 = torch.ops.aten.sub.Tensor(mul_650, mul_653); mul_650 = mul_653 = None + mul_654 = torch.ops.aten.mul.Tensor(sub_59, rsqrt_25); sub_59 = rsqrt_25 = None + mul_655 = torch.ops.aten.mul.Tensor(convert_element_type_2118, mul_100); convert_element_type_2118 = mul_100 = None + sum_118 = torch.ops.aten.sum.dim_IntList(mul_655, [0, 1]); mul_655 = None + convert_element_type_2121 = torch.ops.prims.convert_element_type.default(mul_654, torch.bfloat16); mul_654 = None + add_265 = torch.ops.aten.add.Tensor(add_261, 
convert_element_type_2121); add_261 = convert_element_type_2121 = None + convert_element_type_default_26 = torch.ops.prims.convert_element_type.default(sum_118, torch.float32); sum_118 = None + reduce_scatter_tensor_176 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_26, 'avg', 256, '0'); convert_element_type_default_26 = None + wait_tensor_467 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_176); reduce_scatter_tensor_176 = None + view_1557 = torch.ops.aten.view.default(add_265, [16384, 4096]) + permute_977 = torch.ops.aten.permute.default(view_1557, [1, 0]) + mm_499 = torch.ops.aten.mm.default(permute_977, view_431); permute_977 = view_431 = None + permute_979 = torch.ops.aten.permute.default(permute_139, [1, 0]); permute_139 = None + mm_500 = torch.ops.aten.mm.default(view_1557, permute_979); view_1557 = permute_979 = None + view_1558 = torch.ops.aten.view.default(mm_500, [2, 8192, 4096]); mm_500 = None + convert_element_type_2128 = torch.ops.prims.convert_element_type.default(mm_499, torch.float32); mm_499 = None + reduce_scatter_tensor_177 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2128, 'avg', 256, '0'); convert_element_type_2128 = None + wait_tensor_468 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_177); reduce_scatter_tensor_177 = None + view_1559 = torch.ops.aten.view.default(view_1558, [2, 8192, 32, 128]); view_1558 = None + permute_981 = torch.ops.aten.permute.default(view_1559, [0, 2, 1, 3]); view_1559 = None + convert_element_type_397 = torch.ops.prims.convert_element_type.default(primals_112, torch.bfloat16); primals_112 = None + all_gather_into_tensor_109 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_397, 256, '0'); convert_element_type_397 = None + wait_tensor_109 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_109); all_gather_into_tensor_109 = None + 
convert_element_type_398 = torch.ops.prims.convert_element_type.default(add_47, torch.float32); add_47 = None + pow_25 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_398, 2) + mean_24 = torch.ops.aten.mean.dim(pow_25, [2], True); pow_25 = None + add_48 = torch.ops.aten.add.Scalar(mean_24, 1e-05); mean_24 = None + rsqrt_24 = torch.ops.aten.rsqrt.default(add_48); add_48 = None + mul_96 = torch.ops.aten.mul.Tensor(convert_element_type_398, rsqrt_24); convert_element_type_398 = None + mul_97 = torch.ops.aten.mul.Tensor(mul_96, wait_tensor_109) + convert_element_type_399 = torch.ops.prims.convert_element_type.default(mul_97, torch.bfloat16); mul_97 = None + view_411 = torch.ops.aten.view.default(convert_element_type_399, [16384, 4096]); convert_element_type_399 = None + view_412 = torch.ops.aten.view.default(mm_84, [2, 8192, 4096]); mm_84 = None + convert_element_type_403 = torch.ops.prims.convert_element_type.default(primals_114, torch.bfloat16); primals_114 = None + all_gather_into_tensor_111 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_403, 256, '0'); convert_element_type_403 = None + wait_tensor_111 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_111); all_gather_into_tensor_111 = None + permute_133 = torch.ops.aten.permute.default(wait_tensor_111, [1, 0]); wait_tensor_111 = None + mm_85 = torch.ops.aten.mm.default(view_411, permute_133) + view_415 = torch.ops.aten.view.default(mm_85, [2, 8192, 1024]); mm_85 = None + view_418 = torch.ops.aten.view.default(mm_86, [2, 8192, 1024]); mm_86 = None + view_419 = torch.ops.aten.view.default(view_412, [2, 8192, -1, 128]); view_412 = None + view_420 = torch.ops.aten.view.default(view_415, [2, 8192, -1, 128]); view_415 = None + view_421 = torch.ops.aten.view.default(view_418, [2, 8192, -1, 128]); view_418 = None + convert_element_type_409 = torch.ops.prims.convert_element_type.default(view_419, torch.float32); view_419 = None + view_422 = 
torch.ops.aten.view.default(convert_element_type_409, [2, 8192, 32, -1, 2]); convert_element_type_409 = None + view_as_complex_24 = torch.ops.aten.view_as_complex.default(view_422); view_422 = None + convert_element_type_410 = torch.ops.prims.convert_element_type.default(view_420, torch.float32); view_420 = None + view_423 = torch.ops.aten.view.default(convert_element_type_410, [2, 8192, 8, -1, 2]); convert_element_type_410 = None + view_as_complex_25 = torch.ops.aten.view_as_complex.default(view_423); view_423 = None + mul_98 = torch.ops.aten.mul.Tensor(view_as_complex_24, view_16); view_as_complex_24 = None + view_as_real_24 = torch.ops.aten.view_as_real.default(mul_98); mul_98 = None + view_425 = torch.ops.aten.view.default(view_as_real_24, [2, 8192, 32, 128]); view_as_real_24 = None + mul_99 = torch.ops.aten.mul.Tensor(view_as_complex_25, view_16); view_as_complex_25 = None + view_as_real_25 = torch.ops.aten.view_as_real.default(mul_99); mul_99 = None + view_426 = torch.ops.aten.view.default(view_as_real_25, [2, 8192, 8, 128]); view_as_real_25 = None + convert_element_type_411 = torch.ops.prims.convert_element_type.default(view_425, torch.bfloat16); view_425 = None + convert_element_type_412 = torch.ops.prims.convert_element_type.default(view_426, torch.bfloat16); view_426 = None + unsqueeze_24 = torch.ops.aten.unsqueeze.default(convert_element_type_412, 3); convert_element_type_412 = None + expand_24 = torch.ops.aten.expand.default(unsqueeze_24, [2, 8192, 8, 4, 128]); unsqueeze_24 = None + clone_24 = torch.ops.aten.clone.default(expand_24, memory_format = torch.contiguous_format); expand_24 = None + view_427 = torch.ops.aten.view.default(clone_24, [2, 8192, 32, 128]); clone_24 = None + unsqueeze_25 = torch.ops.aten.unsqueeze.default(view_421, 3); view_421 = None + expand_25 = torch.ops.aten.expand.default(unsqueeze_25, [2, 8192, 8, 4, 128]); unsqueeze_25 = None + clone_25 = torch.ops.aten.clone.default(expand_25, memory_format = torch.contiguous_format); 
expand_25 = None + view_428 = torch.ops.aten.view.default(clone_25, [2, 8192, 32, 128]); clone_25 = None + permute_135 = torch.ops.aten.permute.default(convert_element_type_411, [0, 2, 1, 3]); convert_element_type_411 = None + permute_136 = torch.ops.aten.permute.default(view_427, [0, 2, 1, 3]); view_427 = None + permute_137 = torch.ops.aten.permute.default(view_428, [0, 2, 1, 3]); view_428 = None + _scaled_dot_product_cudnn_attention_backward_19 = torch.ops.aten._scaled_dot_product_cudnn_attention_backward.default(permute_981, permute_135, permute_136, permute_137, getitem_108, getitem_109, getitem_114, getitem_115, None, None, None, 8192, 8192, 0.0, True); permute_981 = permute_135 = permute_136 = permute_137 = getitem_108 = getitem_109 = getitem_114 = getitem_115 = None + getitem_345 = _scaled_dot_product_cudnn_attention_backward_19[0] + getitem_346 = _scaled_dot_product_cudnn_attention_backward_19[1] + getitem_347 = _scaled_dot_product_cudnn_attention_backward_19[2]; _scaled_dot_product_cudnn_attention_backward_19 = None + permute_982 = torch.ops.aten.permute.default(getitem_347, [0, 2, 1, 3]); getitem_347 = None + permute_983 = torch.ops.aten.permute.default(getitem_346, [0, 2, 1, 3]); getitem_346 = None + permute_984 = torch.ops.aten.permute.default(getitem_345, [0, 2, 1, 3]); getitem_345 = None + view_1560 = torch.ops.aten.view.default(permute_982, [2, 8192, 8, 4, 128]); permute_982 = None + sum_119 = torch.ops.aten.sum.dim_IntList(view_1560, [3], True); view_1560 = None + squeeze_38 = torch.ops.aten.squeeze.dim(sum_119, 3); sum_119 = None + view_1561 = torch.ops.aten.view.default(permute_983, [2, 8192, 8, 4, 128]); permute_983 = None + sum_120 = torch.ops.aten.sum.dim_IntList(view_1561, [3], True); view_1561 = None + squeeze_39 = torch.ops.aten.squeeze.dim(sum_120, 3); sum_120 = None + convert_element_type_2129 = torch.ops.prims.convert_element_type.default(squeeze_39, torch.float32); squeeze_39 = None + convert_element_type_2130 = 
torch.ops.prims.convert_element_type.default(permute_984, torch.float32); permute_984 = None + view_1562 = torch.ops.aten.view.default(convert_element_type_2129, [2, 8192, 8, 64, 2]); convert_element_type_2129 = None + view_as_complex_102 = torch.ops.aten.view_as_complex.default(view_1562); view_1562 = None + mul_656 = torch.ops.aten.mul.Tensor(view_as_complex_102, _conj); view_as_complex_102 = None + view_1563 = torch.ops.aten.view.default(convert_element_type_2130, [2, 8192, 32, 64, 2]); convert_element_type_2130 = None + view_as_complex_103 = torch.ops.aten.view_as_complex.default(view_1563); view_1563 = None + mul_657 = torch.ops.aten.mul.Tensor(view_as_complex_103, _conj); view_as_complex_103 = None + view_as_real_102 = torch.ops.aten.view_as_real.default(mul_656); mul_656 = None + view_1564 = torch.ops.aten.view.default(view_as_real_102, [2, 8192, 8, 128]); view_as_real_102 = None + convert_element_type_2131 = torch.ops.prims.convert_element_type.default(view_1564, torch.bfloat16); view_1564 = None + view_as_real_103 = torch.ops.aten.view_as_real.default(mul_657); mul_657 = None + view_1565 = torch.ops.aten.view.default(view_as_real_103, [2, 8192, 32, 128]); view_as_real_103 = None + convert_element_type_2132 = torch.ops.prims.convert_element_type.default(view_1565, torch.bfloat16); view_1565 = None + view_1566 = torch.ops.aten.view.default(squeeze_38, [2, 8192, 1024]); squeeze_38 = None + view_1567 = torch.ops.aten.view.default(convert_element_type_2131, [2, 8192, 1024]); convert_element_type_2131 = None + view_1568 = torch.ops.aten.view.default(convert_element_type_2132, [2, 8192, 4096]); convert_element_type_2132 = None + view_1569 = torch.ops.aten.view.default(view_1566, [16384, 1024]); view_1566 = None + permute_985 = torch.ops.aten.permute.default(view_1569, [1, 0]) + mm_501 = torch.ops.aten.mm.default(permute_985, view_411); permute_985 = None + convert_element_type_406 = torch.ops.prims.convert_element_type.default(primals_115, torch.bfloat16); 
primals_115 = None + all_gather_into_tensor_112 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_406, 256, '0'); convert_element_type_406 = None + wait_tensor_112 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_112); all_gather_into_tensor_112 = None + permute_134 = torch.ops.aten.permute.default(wait_tensor_112, [1, 0]); wait_tensor_112 = None + permute_987 = torch.ops.aten.permute.default(permute_134, [1, 0]); permute_134 = None + mm_502 = torch.ops.aten.mm.default(view_1569, permute_987); view_1569 = permute_987 = None + view_1570 = torch.ops.aten.view.default(mm_502, [2, 8192, 4096]); mm_502 = None + convert_element_type_2137 = torch.ops.prims.convert_element_type.default(mm_501, torch.float32); mm_501 = None + reduce_scatter_tensor_178 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2137, 'avg', 256, '0'); convert_element_type_2137 = None + wait_tensor_469 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_178); reduce_scatter_tensor_178 = None + view_1571 = torch.ops.aten.view.default(view_1567, [16384, 1024]); view_1567 = None + permute_989 = torch.ops.aten.permute.default(view_1571, [1, 0]) + mm_503 = torch.ops.aten.mm.default(permute_989, view_411); permute_989 = None + permute_991 = torch.ops.aten.permute.default(permute_133, [1, 0]); permute_133 = None + mm_504 = torch.ops.aten.mm.default(view_1571, permute_991); view_1571 = permute_991 = None + view_1572 = torch.ops.aten.view.default(mm_504, [2, 8192, 4096]); mm_504 = None + add_266 = torch.ops.aten.add.Tensor(view_1570, view_1572); view_1570 = view_1572 = None + convert_element_type_2142 = torch.ops.prims.convert_element_type.default(mm_503, torch.float32); mm_503 = None + reduce_scatter_tensor_179 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2142, 'avg', 256, '0'); convert_element_type_2142 = None + wait_tensor_470 = 
torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_179); reduce_scatter_tensor_179 = None + view_1573 = torch.ops.aten.view.default(view_1568, [16384, 4096]); view_1568 = None + permute_993 = torch.ops.aten.permute.default(view_1573, [1, 0]) + mm_505 = torch.ops.aten.mm.default(permute_993, view_411); permute_993 = view_411 = None + convert_element_type_400 = torch.ops.prims.convert_element_type.default(primals_113, torch.bfloat16); primals_113 = None + all_gather_into_tensor_110 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_400, 256, '0'); convert_element_type_400 = None + wait_tensor_110 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_110); all_gather_into_tensor_110 = None + permute_132 = torch.ops.aten.permute.default(wait_tensor_110, [1, 0]); wait_tensor_110 = None + permute_995 = torch.ops.aten.permute.default(permute_132, [1, 0]); permute_132 = None + mm_506 = torch.ops.aten.mm.default(view_1573, permute_995); view_1573 = permute_995 = None + view_1574 = torch.ops.aten.view.default(mm_506, [2, 8192, 4096]); mm_506 = None + add_267 = torch.ops.aten.add.Tensor(add_266, view_1574); add_266 = view_1574 = None + convert_element_type_2147 = torch.ops.prims.convert_element_type.default(mm_505, torch.float32); mm_505 = None + reduce_scatter_tensor_180 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2147, 'avg', 256, '0'); convert_element_type_2147 = None + wait_tensor_471 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_180); reduce_scatter_tensor_180 = None + convert_element_type_2148 = torch.ops.prims.convert_element_type.default(add_267, torch.float32); add_267 = None + convert_element_type_2150 = torch.ops.prims.convert_element_type.default(wait_tensor_109, torch.float32); wait_tensor_109 = None + mul_658 = torch.ops.aten.mul.Tensor(convert_element_type_2148, convert_element_type_2150); convert_element_type_2150 = None + mul_660 = 
torch.ops.aten.mul.Tensor(mul_96, mul_658) + sum_121 = torch.ops.aten.sum.dim_IntList(mul_660, [2], True); mul_660 = None + div_40 = torch.ops.aten.div.Tensor(mul_96, 4096) + mul_661 = torch.ops.aten.mul.Tensor(div_40, sum_121); div_40 = sum_121 = None + sub_60 = torch.ops.aten.sub.Tensor(mul_658, mul_661); mul_658 = mul_661 = None + mul_662 = torch.ops.aten.mul.Tensor(sub_60, rsqrt_24); sub_60 = rsqrt_24 = None + mul_663 = torch.ops.aten.mul.Tensor(convert_element_type_2148, mul_96); convert_element_type_2148 = mul_96 = None + sum_122 = torch.ops.aten.sum.dim_IntList(mul_663, [0, 1]); mul_663 = None + convert_element_type_2151 = torch.ops.prims.convert_element_type.default(mul_662, torch.bfloat16); mul_662 = None + add_268 = torch.ops.aten.add.Tensor(add_265, convert_element_type_2151); add_265 = convert_element_type_2151 = None + convert_element_type_default_25 = torch.ops.prims.convert_element_type.default(sum_122, torch.float32); sum_122 = None + reduce_scatter_tensor_181 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_25, 'avg', 256, '0'); convert_element_type_default_25 = None + wait_tensor_472 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_181); reduce_scatter_tensor_181 = None + view_1575 = torch.ops.aten.view.default(add_268, [16384, 4096]) + permute_997 = torch.ops.aten.permute.default(view_1575, [1, 0]) + permute_127 = torch.ops.aten.permute.default(getitem_99, [0, 2, 1, 3]) + view_395 = torch.ops.aten.view.default(permute_127, [2, 8192, -1]); permute_127 = None + convert_element_type_380 = torch.ops.prims.convert_element_type.default(primals_107, torch.bfloat16); primals_107 = None + all_gather_into_tensor_104 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_380, 256, '0'); convert_element_type_380 = None + wait_tensor_104 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_104); all_gather_into_tensor_104 = None + permute_128 = 
torch.ops.aten.permute.default(wait_tensor_104, [1, 0]); wait_tensor_104 = None + view_397 = torch.ops.aten.view.default(view_395, [16384, 4096]); view_395 = None + mm_80 = torch.ops.aten.mm.default(view_397, permute_128) + view_398 = torch.ops.aten.view.default(mm_80, [2, 8192, 4096]); mm_80 = None + add_45 = torch.ops.aten.add.Tensor(add_43, view_398); view_398 = None + convert_element_type_383 = torch.ops.prims.convert_element_type.default(primals_108, torch.bfloat16); primals_108 = None + all_gather_into_tensor_105 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_383, 256, '0'); convert_element_type_383 = None + wait_tensor_105 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_105); all_gather_into_tensor_105 = None + convert_element_type_384 = torch.ops.prims.convert_element_type.default(add_45, torch.float32); add_45 = None + pow_24 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_384, 2) + mean_23 = torch.ops.aten.mean.dim(pow_24, [2], True); pow_24 = None + add_46 = torch.ops.aten.add.Scalar(mean_23, 1e-05); mean_23 = None + rsqrt_23 = torch.ops.aten.rsqrt.default(add_46); add_46 = None + mul_92 = torch.ops.aten.mul.Tensor(convert_element_type_384, rsqrt_23); convert_element_type_384 = None + mul_93 = torch.ops.aten.mul.Tensor(mul_92, wait_tensor_105) + convert_element_type_385 = torch.ops.prims.convert_element_type.default(mul_93, torch.bfloat16); mul_93 = None + view_401 = torch.ops.aten.view.default(convert_element_type_385, [16384, 4096]); convert_element_type_385 = None + view_402 = torch.ops.aten.view.default(mm_81, [2, 8192, 14336]); mm_81 = None + convert_element_type_389 = torch.ops.prims.convert_element_type.default(view_402, torch.float32); view_402 = None + sigmoid_11 = torch.ops.aten.sigmoid.default(convert_element_type_389) + mul_94 = torch.ops.aten.mul.Tensor(convert_element_type_389, sigmoid_11); sigmoid_11 = None + convert_element_type_390 = 
torch.ops.prims.convert_element_type.default(mul_94, torch.bfloat16); mul_94 = None + convert_element_type_391 = torch.ops.prims.convert_element_type.default(primals_110, torch.bfloat16); primals_110 = None + all_gather_into_tensor_107 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_391, 256, '0'); convert_element_type_391 = None + wait_tensor_107 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_107); all_gather_into_tensor_107 = None + permute_130 = torch.ops.aten.permute.default(wait_tensor_107, [1, 0]); wait_tensor_107 = None + mm_82 = torch.ops.aten.mm.default(view_401, permute_130) + view_405 = torch.ops.aten.view.default(mm_82, [2, 8192, 14336]); mm_82 = None + mul_95 = torch.ops.aten.mul.Tensor(convert_element_type_390, view_405) + view_407 = torch.ops.aten.view.default(mul_95, [16384, 14336]); mul_95 = None + mm_507 = torch.ops.aten.mm.default(permute_997, view_407); permute_997 = view_407 = None + convert_element_type_394 = torch.ops.prims.convert_element_type.default(primals_111, torch.bfloat16); primals_111 = None + all_gather_into_tensor_108 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_394, 256, '0'); convert_element_type_394 = None + wait_tensor_108 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_108); all_gather_into_tensor_108 = None + permute_131 = torch.ops.aten.permute.default(wait_tensor_108, [1, 0]); wait_tensor_108 = None + permute_999 = torch.ops.aten.permute.default(permute_131, [1, 0]); permute_131 = None + mm_508 = torch.ops.aten.mm.default(view_1575, permute_999); view_1575 = permute_999 = None + view_1576 = torch.ops.aten.view.default(mm_508, [2, 8192, 14336]); mm_508 = None + convert_element_type_2158 = torch.ops.prims.convert_element_type.default(mm_507, torch.float32); mm_507 = None + reduce_scatter_tensor_182 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2158, 'avg', 256, '0'); 
convert_element_type_2158 = None + wait_tensor_473 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_182); reduce_scatter_tensor_182 = None + mul_664 = torch.ops.aten.mul.Tensor(view_1576, convert_element_type_390); convert_element_type_390 = None + mul_665 = torch.ops.aten.mul.Tensor(view_1576, view_405); view_1576 = view_405 = None + view_1577 = torch.ops.aten.view.default(mul_664, [16384, 14336]); mul_664 = None + permute_1001 = torch.ops.aten.permute.default(view_1577, [1, 0]) + mm_509 = torch.ops.aten.mm.default(permute_1001, view_401); permute_1001 = None + permute_1003 = torch.ops.aten.permute.default(permute_130, [1, 0]); permute_130 = None + mm_510 = torch.ops.aten.mm.default(view_1577, permute_1003); view_1577 = permute_1003 = None + view_1578 = torch.ops.aten.view.default(mm_510, [2, 8192, 4096]); mm_510 = None + convert_element_type_2163 = torch.ops.prims.convert_element_type.default(mm_509, torch.float32); mm_509 = None + reduce_scatter_tensor_183 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2163, 'avg', 256, '0'); convert_element_type_2163 = None + wait_tensor_474 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_183); reduce_scatter_tensor_183 = None + convert_element_type_2164 = torch.ops.prims.convert_element_type.default(mul_665, torch.float32); mul_665 = None + neg_20 = torch.ops.aten.neg.default(convert_element_type_389) + exp_20 = torch.ops.aten.exp.default(neg_20); neg_20 = None + add_269 = torch.ops.aten.add.Tensor(exp_20, 1); exp_20 = None + reciprocal_20 = torch.ops.aten.reciprocal.default(add_269); add_269 = None + mul_666 = torch.ops.aten.mul.Tensor(reciprocal_20, 1); reciprocal_20 = None + mul_667 = torch.ops.aten.mul.Tensor(convert_element_type_2164, mul_666); convert_element_type_2164 = None + sub_61 = torch.ops.aten.sub.Tensor(1, mul_666); mul_666 = None + mul_668 = torch.ops.aten.mul.Tensor(convert_element_type_389, sub_61); convert_element_type_389 = sub_61 
= None + add_270 = torch.ops.aten.add.Tensor(mul_668, 1); mul_668 = None + mul_669 = torch.ops.aten.mul.Tensor(mul_667, add_270); mul_667 = add_270 = None + convert_element_type_2166 = torch.ops.prims.convert_element_type.default(mul_669, torch.bfloat16); mul_669 = None + view_1579 = torch.ops.aten.view.default(convert_element_type_2166, [16384, 14336]); convert_element_type_2166 = None + permute_1005 = torch.ops.aten.permute.default(view_1579, [1, 0]) + mm_511 = torch.ops.aten.mm.default(permute_1005, view_401); permute_1005 = view_401 = None + convert_element_type_386 = torch.ops.prims.convert_element_type.default(primals_109, torch.bfloat16); primals_109 = None + all_gather_into_tensor_106 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_386, 256, '0'); convert_element_type_386 = None + wait_tensor_106 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_106); all_gather_into_tensor_106 = None + permute_129 = torch.ops.aten.permute.default(wait_tensor_106, [1, 0]); wait_tensor_106 = None + permute_1007 = torch.ops.aten.permute.default(permute_129, [1, 0]); permute_129 = None + mm_512 = torch.ops.aten.mm.default(view_1579, permute_1007); view_1579 = permute_1007 = None + view_1580 = torch.ops.aten.view.default(mm_512, [2, 8192, 4096]); mm_512 = None + add_271 = torch.ops.aten.add.Tensor(view_1578, view_1580); view_1578 = view_1580 = None + convert_element_type_2171 = torch.ops.prims.convert_element_type.default(mm_511, torch.float32); mm_511 = None + reduce_scatter_tensor_184 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2171, 'avg', 256, '0'); convert_element_type_2171 = None + wait_tensor_475 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_184); reduce_scatter_tensor_184 = None + convert_element_type_2172 = torch.ops.prims.convert_element_type.default(add_271, torch.float32); add_271 = None + convert_element_type_2174 = 
torch.ops.prims.convert_element_type.default(wait_tensor_105, torch.float32); wait_tensor_105 = None + mul_670 = torch.ops.aten.mul.Tensor(convert_element_type_2172, convert_element_type_2174); convert_element_type_2174 = None + mul_672 = torch.ops.aten.mul.Tensor(mul_92, mul_670) + sum_123 = torch.ops.aten.sum.dim_IntList(mul_672, [2], True); mul_672 = None + div_41 = torch.ops.aten.div.Tensor(mul_92, 4096) + mul_673 = torch.ops.aten.mul.Tensor(div_41, sum_123); div_41 = sum_123 = None + sub_62 = torch.ops.aten.sub.Tensor(mul_670, mul_673); mul_670 = mul_673 = None + mul_674 = torch.ops.aten.mul.Tensor(sub_62, rsqrt_23); sub_62 = rsqrt_23 = None + mul_675 = torch.ops.aten.mul.Tensor(convert_element_type_2172, mul_92); convert_element_type_2172 = mul_92 = None + sum_124 = torch.ops.aten.sum.dim_IntList(mul_675, [0, 1]); mul_675 = None + convert_element_type_2175 = torch.ops.prims.convert_element_type.default(mul_674, torch.bfloat16); mul_674 = None + add_272 = torch.ops.aten.add.Tensor(add_268, convert_element_type_2175); add_268 = convert_element_type_2175 = None + convert_element_type_default_24 = torch.ops.prims.convert_element_type.default(sum_124, torch.float32); sum_124 = None + reduce_scatter_tensor_185 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_24, 'avg', 256, '0'); convert_element_type_default_24 = None + wait_tensor_476 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_185); reduce_scatter_tensor_185 = None + view_1581 = torch.ops.aten.view.default(add_272, [16384, 4096]) + permute_1009 = torch.ops.aten.permute.default(view_1581, [1, 0]) + mm_513 = torch.ops.aten.mm.default(permute_1009, view_397); permute_1009 = view_397 = None + permute_1011 = torch.ops.aten.permute.default(permute_128, [1, 0]); permute_128 = None + mm_514 = torch.ops.aten.mm.default(view_1581, permute_1011); view_1581 = permute_1011 = None + view_1582 = torch.ops.aten.view.default(mm_514, [2, 8192, 4096]); mm_514 = None 
+ convert_element_type_2182 = torch.ops.prims.convert_element_type.default(mm_513, torch.float32); mm_513 = None + reduce_scatter_tensor_186 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2182, 'avg', 256, '0'); convert_element_type_2182 = None + wait_tensor_477 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_186); reduce_scatter_tensor_186 = None + view_1583 = torch.ops.aten.view.default(view_1582, [2, 8192, 32, 128]); view_1582 = None + permute_1013 = torch.ops.aten.permute.default(view_1583, [0, 2, 1, 3]); view_1583 = None + convert_element_type_364 = torch.ops.prims.convert_element_type.default(primals_103, torch.bfloat16); primals_103 = None + all_gather_into_tensor_100 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_364, 256, '0'); convert_element_type_364 = None + wait_tensor_100 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_100); all_gather_into_tensor_100 = None + convert_element_type_365 = torch.ops.prims.convert_element_type.default(add_43, torch.float32); add_43 = None + pow_23 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_365, 2) + mean_22 = torch.ops.aten.mean.dim(pow_23, [2], True); pow_23 = None + add_44 = torch.ops.aten.add.Scalar(mean_22, 1e-05); mean_22 = None + rsqrt_22 = torch.ops.aten.rsqrt.default(add_44); add_44 = None + mul_88 = torch.ops.aten.mul.Tensor(convert_element_type_365, rsqrt_22); convert_element_type_365 = None + mul_89 = torch.ops.aten.mul.Tensor(mul_88, wait_tensor_100) + convert_element_type_366 = torch.ops.prims.convert_element_type.default(mul_89, torch.bfloat16); mul_89 = None + view_377 = torch.ops.aten.view.default(convert_element_type_366, [16384, 4096]); convert_element_type_366 = None + view_378 = torch.ops.aten.view.default(mm_77, [2, 8192, 4096]); mm_77 = None + convert_element_type_370 = torch.ops.prims.convert_element_type.default(primals_105, torch.bfloat16); primals_105 = None + 
all_gather_into_tensor_102 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_370, 256, '0'); convert_element_type_370 = None + wait_tensor_102 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_102); all_gather_into_tensor_102 = None + permute_122 = torch.ops.aten.permute.default(wait_tensor_102, [1, 0]); wait_tensor_102 = None + mm_78 = torch.ops.aten.mm.default(view_377, permute_122) + view_381 = torch.ops.aten.view.default(mm_78, [2, 8192, 1024]); mm_78 = None + view_384 = torch.ops.aten.view.default(mm_79, [2, 8192, 1024]); mm_79 = None + view_385 = torch.ops.aten.view.default(view_378, [2, 8192, -1, 128]); view_378 = None + view_386 = torch.ops.aten.view.default(view_381, [2, 8192, -1, 128]); view_381 = None + view_387 = torch.ops.aten.view.default(view_384, [2, 8192, -1, 128]); view_384 = None + convert_element_type_376 = torch.ops.prims.convert_element_type.default(view_385, torch.float32); view_385 = None + view_388 = torch.ops.aten.view.default(convert_element_type_376, [2, 8192, 32, -1, 2]); convert_element_type_376 = None + view_as_complex_22 = torch.ops.aten.view_as_complex.default(view_388); view_388 = None + convert_element_type_377 = torch.ops.prims.convert_element_type.default(view_386, torch.float32); view_386 = None + view_389 = torch.ops.aten.view.default(convert_element_type_377, [2, 8192, 8, -1, 2]); convert_element_type_377 = None + view_as_complex_23 = torch.ops.aten.view_as_complex.default(view_389); view_389 = None + mul_90 = torch.ops.aten.mul.Tensor(view_as_complex_22, view_16); view_as_complex_22 = None + view_as_real_22 = torch.ops.aten.view_as_real.default(mul_90); mul_90 = None + view_391 = torch.ops.aten.view.default(view_as_real_22, [2, 8192, 32, 128]); view_as_real_22 = None + mul_91 = torch.ops.aten.mul.Tensor(view_as_complex_23, view_16); view_as_complex_23 = None + view_as_real_23 = torch.ops.aten.view_as_real.default(mul_91); mul_91 = None + view_392 = 
torch.ops.aten.view.default(view_as_real_23, [2, 8192, 8, 128]); view_as_real_23 = None + convert_element_type_378 = torch.ops.prims.convert_element_type.default(view_391, torch.bfloat16); view_391 = None + convert_element_type_379 = torch.ops.prims.convert_element_type.default(view_392, torch.bfloat16); view_392 = None + unsqueeze_22 = torch.ops.aten.unsqueeze.default(convert_element_type_379, 3); convert_element_type_379 = None + expand_22 = torch.ops.aten.expand.default(unsqueeze_22, [2, 8192, 8, 4, 128]); unsqueeze_22 = None + clone_22 = torch.ops.aten.clone.default(expand_22, memory_format = torch.contiguous_format); expand_22 = None + view_393 = torch.ops.aten.view.default(clone_22, [2, 8192, 32, 128]); clone_22 = None + unsqueeze_23 = torch.ops.aten.unsqueeze.default(view_387, 3); view_387 = None + expand_23 = torch.ops.aten.expand.default(unsqueeze_23, [2, 8192, 8, 4, 128]); unsqueeze_23 = None + clone_23 = torch.ops.aten.clone.default(expand_23, memory_format = torch.contiguous_format); expand_23 = None + view_394 = torch.ops.aten.view.default(clone_23, [2, 8192, 32, 128]); clone_23 = None + permute_124 = torch.ops.aten.permute.default(convert_element_type_378, [0, 2, 1, 3]); convert_element_type_378 = None + permute_125 = torch.ops.aten.permute.default(view_393, [0, 2, 1, 3]); view_393 = None + permute_126 = torch.ops.aten.permute.default(view_394, [0, 2, 1, 3]); view_394 = None + _scaled_dot_product_cudnn_attention_backward_20 = torch.ops.aten._scaled_dot_product_cudnn_attention_backward.default(permute_1013, permute_124, permute_125, permute_126, getitem_99, getitem_100, getitem_105, getitem_106, None, None, None, 8192, 8192, 0.0, True); permute_1013 = permute_124 = permute_125 = permute_126 = getitem_99 = getitem_100 = getitem_105 = getitem_106 = None + getitem_348 = _scaled_dot_product_cudnn_attention_backward_20[0] + getitem_349 = _scaled_dot_product_cudnn_attention_backward_20[1] + getitem_350 = _scaled_dot_product_cudnn_attention_backward_20[2]; 
_scaled_dot_product_cudnn_attention_backward_20 = None + permute_1014 = torch.ops.aten.permute.default(getitem_350, [0, 2, 1, 3]); getitem_350 = None + permute_1015 = torch.ops.aten.permute.default(getitem_349, [0, 2, 1, 3]); getitem_349 = None + permute_1016 = torch.ops.aten.permute.default(getitem_348, [0, 2, 1, 3]); getitem_348 = None + view_1584 = torch.ops.aten.view.default(permute_1014, [2, 8192, 8, 4, 128]); permute_1014 = None + sum_125 = torch.ops.aten.sum.dim_IntList(view_1584, [3], True); view_1584 = None + squeeze_40 = torch.ops.aten.squeeze.dim(sum_125, 3); sum_125 = None + view_1585 = torch.ops.aten.view.default(permute_1015, [2, 8192, 8, 4, 128]); permute_1015 = None + sum_126 = torch.ops.aten.sum.dim_IntList(view_1585, [3], True); view_1585 = None + squeeze_41 = torch.ops.aten.squeeze.dim(sum_126, 3); sum_126 = None + convert_element_type_2183 = torch.ops.prims.convert_element_type.default(squeeze_41, torch.float32); squeeze_41 = None + convert_element_type_2184 = torch.ops.prims.convert_element_type.default(permute_1016, torch.float32); permute_1016 = None + view_1586 = torch.ops.aten.view.default(convert_element_type_2183, [2, 8192, 8, 64, 2]); convert_element_type_2183 = None + view_as_complex_104 = torch.ops.aten.view_as_complex.default(view_1586); view_1586 = None + mul_676 = torch.ops.aten.mul.Tensor(view_as_complex_104, _conj); view_as_complex_104 = None + view_1587 = torch.ops.aten.view.default(convert_element_type_2184, [2, 8192, 32, 64, 2]); convert_element_type_2184 = None + view_as_complex_105 = torch.ops.aten.view_as_complex.default(view_1587); view_1587 = None + mul_677 = torch.ops.aten.mul.Tensor(view_as_complex_105, _conj); view_as_complex_105 = None + view_as_real_104 = torch.ops.aten.view_as_real.default(mul_676); mul_676 = None + view_1588 = torch.ops.aten.view.default(view_as_real_104, [2, 8192, 8, 128]); view_as_real_104 = None + convert_element_type_2185 = torch.ops.prims.convert_element_type.default(view_1588, torch.bfloat16); 
view_1588 = None + view_as_real_105 = torch.ops.aten.view_as_real.default(mul_677); mul_677 = None + view_1589 = torch.ops.aten.view.default(view_as_real_105, [2, 8192, 32, 128]); view_as_real_105 = None + convert_element_type_2186 = torch.ops.prims.convert_element_type.default(view_1589, torch.bfloat16); view_1589 = None + view_1590 = torch.ops.aten.view.default(squeeze_40, [2, 8192, 1024]); squeeze_40 = None + view_1591 = torch.ops.aten.view.default(convert_element_type_2185, [2, 8192, 1024]); convert_element_type_2185 = None + view_1592 = torch.ops.aten.view.default(convert_element_type_2186, [2, 8192, 4096]); convert_element_type_2186 = None + view_1593 = torch.ops.aten.view.default(view_1590, [16384, 1024]); view_1590 = None + permute_1017 = torch.ops.aten.permute.default(view_1593, [1, 0]) + mm_515 = torch.ops.aten.mm.default(permute_1017, view_377); permute_1017 = None + convert_element_type_373 = torch.ops.prims.convert_element_type.default(primals_106, torch.bfloat16); primals_106 = None + all_gather_into_tensor_103 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_373, 256, '0'); convert_element_type_373 = None + wait_tensor_103 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_103); all_gather_into_tensor_103 = None + permute_123 = torch.ops.aten.permute.default(wait_tensor_103, [1, 0]); wait_tensor_103 = None + permute_1019 = torch.ops.aten.permute.default(permute_123, [1, 0]); permute_123 = None + mm_516 = torch.ops.aten.mm.default(view_1593, permute_1019); view_1593 = permute_1019 = None + view_1594 = torch.ops.aten.view.default(mm_516, [2, 8192, 4096]); mm_516 = None + convert_element_type_2191 = torch.ops.prims.convert_element_type.default(mm_515, torch.float32); mm_515 = None + reduce_scatter_tensor_187 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2191, 'avg', 256, '0'); convert_element_type_2191 = None + wait_tensor_478 = 
torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_187); reduce_scatter_tensor_187 = None + view_1595 = torch.ops.aten.view.default(view_1591, [16384, 1024]); view_1591 = None + permute_1021 = torch.ops.aten.permute.default(view_1595, [1, 0]) + mm_517 = torch.ops.aten.mm.default(permute_1021, view_377); permute_1021 = None + permute_1023 = torch.ops.aten.permute.default(permute_122, [1, 0]); permute_122 = None + mm_518 = torch.ops.aten.mm.default(view_1595, permute_1023); view_1595 = permute_1023 = None + view_1596 = torch.ops.aten.view.default(mm_518, [2, 8192, 4096]); mm_518 = None + add_273 = torch.ops.aten.add.Tensor(view_1594, view_1596); view_1594 = view_1596 = None + convert_element_type_2196 = torch.ops.prims.convert_element_type.default(mm_517, torch.float32); mm_517 = None + reduce_scatter_tensor_188 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2196, 'avg', 256, '0'); convert_element_type_2196 = None + wait_tensor_479 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_188); reduce_scatter_tensor_188 = None + view_1597 = torch.ops.aten.view.default(view_1592, [16384, 4096]); view_1592 = None + permute_1025 = torch.ops.aten.permute.default(view_1597, [1, 0]) + mm_519 = torch.ops.aten.mm.default(permute_1025, view_377); permute_1025 = view_377 = None + convert_element_type_367 = torch.ops.prims.convert_element_type.default(primals_104, torch.bfloat16); primals_104 = None + all_gather_into_tensor_101 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_367, 256, '0'); convert_element_type_367 = None + wait_tensor_101 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_101); all_gather_into_tensor_101 = None + permute_121 = torch.ops.aten.permute.default(wait_tensor_101, [1, 0]); wait_tensor_101 = None + permute_1027 = torch.ops.aten.permute.default(permute_121, [1, 0]); permute_121 = None + mm_520 = torch.ops.aten.mm.default(view_1597, 
permute_1027); view_1597 = permute_1027 = None + view_1598 = torch.ops.aten.view.default(mm_520, [2, 8192, 4096]); mm_520 = None + add_274 = torch.ops.aten.add.Tensor(add_273, view_1598); add_273 = view_1598 = None + convert_element_type_2201 = torch.ops.prims.convert_element_type.default(mm_519, torch.float32); mm_519 = None + reduce_scatter_tensor_189 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2201, 'avg', 256, '0'); convert_element_type_2201 = None + wait_tensor_480 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_189); reduce_scatter_tensor_189 = None + convert_element_type_2202 = torch.ops.prims.convert_element_type.default(add_274, torch.float32); add_274 = None + convert_element_type_2204 = torch.ops.prims.convert_element_type.default(wait_tensor_100, torch.float32); wait_tensor_100 = None + mul_678 = torch.ops.aten.mul.Tensor(convert_element_type_2202, convert_element_type_2204); convert_element_type_2204 = None + mul_680 = torch.ops.aten.mul.Tensor(mul_88, mul_678) + sum_127 = torch.ops.aten.sum.dim_IntList(mul_680, [2], True); mul_680 = None + div_42 = torch.ops.aten.div.Tensor(mul_88, 4096) + mul_681 = torch.ops.aten.mul.Tensor(div_42, sum_127); div_42 = sum_127 = None + sub_63 = torch.ops.aten.sub.Tensor(mul_678, mul_681); mul_678 = mul_681 = None + mul_682 = torch.ops.aten.mul.Tensor(sub_63, rsqrt_22); sub_63 = rsqrt_22 = None + mul_683 = torch.ops.aten.mul.Tensor(convert_element_type_2202, mul_88); convert_element_type_2202 = mul_88 = None + sum_128 = torch.ops.aten.sum.dim_IntList(mul_683, [0, 1]); mul_683 = None + convert_element_type_2205 = torch.ops.prims.convert_element_type.default(mul_682, torch.bfloat16); mul_682 = None + add_275 = torch.ops.aten.add.Tensor(add_272, convert_element_type_2205); add_272 = convert_element_type_2205 = None + convert_element_type_default_23 = torch.ops.prims.convert_element_type.default(sum_128, torch.float32); sum_128 = None + reduce_scatter_tensor_190 = 
torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_23, 'avg', 256, '0'); convert_element_type_default_23 = None + wait_tensor_481 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_190); reduce_scatter_tensor_190 = None + view_1599 = torch.ops.aten.view.default(add_275, [16384, 4096]) + permute_1029 = torch.ops.aten.permute.default(view_1599, [1, 0]) + permute_116 = torch.ops.aten.permute.default(getitem_90, [0, 2, 1, 3]) + view_361 = torch.ops.aten.view.default(permute_116, [2, 8192, -1]); permute_116 = None + convert_element_type_347 = torch.ops.prims.convert_element_type.default(primals_98, torch.bfloat16); primals_98 = None + all_gather_into_tensor_95 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_347, 256, '0'); convert_element_type_347 = None + wait_tensor_95 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_95); all_gather_into_tensor_95 = None + permute_117 = torch.ops.aten.permute.default(wait_tensor_95, [1, 0]); wait_tensor_95 = None + view_363 = torch.ops.aten.view.default(view_361, [16384, 4096]); view_361 = None + mm_73 = torch.ops.aten.mm.default(view_363, permute_117) + view_364 = torch.ops.aten.view.default(mm_73, [2, 8192, 4096]); mm_73 = None + add_41 = torch.ops.aten.add.Tensor(add_39, view_364); view_364 = None + convert_element_type_350 = torch.ops.prims.convert_element_type.default(primals_99, torch.bfloat16); primals_99 = None + all_gather_into_tensor_96 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_350, 256, '0'); convert_element_type_350 = None + wait_tensor_96 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_96); all_gather_into_tensor_96 = None + convert_element_type_351 = torch.ops.prims.convert_element_type.default(add_41, torch.float32); add_41 = None + pow_22 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_351, 2) + mean_21 = torch.ops.aten.mean.dim(pow_22, 
[2], True); pow_22 = None + add_42 = torch.ops.aten.add.Scalar(mean_21, 1e-05); mean_21 = None + rsqrt_21 = torch.ops.aten.rsqrt.default(add_42); add_42 = None + mul_84 = torch.ops.aten.mul.Tensor(convert_element_type_351, rsqrt_21); convert_element_type_351 = None + mul_85 = torch.ops.aten.mul.Tensor(mul_84, wait_tensor_96) + convert_element_type_352 = torch.ops.prims.convert_element_type.default(mul_85, torch.bfloat16); mul_85 = None + view_367 = torch.ops.aten.view.default(convert_element_type_352, [16384, 4096]); convert_element_type_352 = None + view_368 = torch.ops.aten.view.default(mm_74, [2, 8192, 14336]); mm_74 = None + convert_element_type_356 = torch.ops.prims.convert_element_type.default(view_368, torch.float32); view_368 = None + sigmoid_10 = torch.ops.aten.sigmoid.default(convert_element_type_356) + mul_86 = torch.ops.aten.mul.Tensor(convert_element_type_356, sigmoid_10); sigmoid_10 = None + convert_element_type_357 = torch.ops.prims.convert_element_type.default(mul_86, torch.bfloat16); mul_86 = None + convert_element_type_358 = torch.ops.prims.convert_element_type.default(primals_101, torch.bfloat16); primals_101 = None + all_gather_into_tensor_98 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_358, 256, '0'); convert_element_type_358 = None + wait_tensor_98 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_98); all_gather_into_tensor_98 = None + permute_119 = torch.ops.aten.permute.default(wait_tensor_98, [1, 0]); wait_tensor_98 = None + mm_75 = torch.ops.aten.mm.default(view_367, permute_119) + view_371 = torch.ops.aten.view.default(mm_75, [2, 8192, 14336]); mm_75 = None + mul_87 = torch.ops.aten.mul.Tensor(convert_element_type_357, view_371) + view_373 = torch.ops.aten.view.default(mul_87, [16384, 14336]); mul_87 = None + mm_521 = torch.ops.aten.mm.default(permute_1029, view_373); permute_1029 = view_373 = None + convert_element_type_361 = 
torch.ops.prims.convert_element_type.default(primals_102, torch.bfloat16); primals_102 = None + all_gather_into_tensor_99 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_361, 256, '0'); convert_element_type_361 = None + wait_tensor_99 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_99); all_gather_into_tensor_99 = None + permute_120 = torch.ops.aten.permute.default(wait_tensor_99, [1, 0]); wait_tensor_99 = None + permute_1031 = torch.ops.aten.permute.default(permute_120, [1, 0]); permute_120 = None + mm_522 = torch.ops.aten.mm.default(view_1599, permute_1031); view_1599 = permute_1031 = None + view_1600 = torch.ops.aten.view.default(mm_522, [2, 8192, 14336]); mm_522 = None + convert_element_type_2212 = torch.ops.prims.convert_element_type.default(mm_521, torch.float32); mm_521 = None + reduce_scatter_tensor_191 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2212, 'avg', 256, '0'); convert_element_type_2212 = None + wait_tensor_482 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_191); reduce_scatter_tensor_191 = None + mul_684 = torch.ops.aten.mul.Tensor(view_1600, convert_element_type_357); convert_element_type_357 = None + mul_685 = torch.ops.aten.mul.Tensor(view_1600, view_371); view_1600 = view_371 = None + view_1601 = torch.ops.aten.view.default(mul_684, [16384, 14336]); mul_684 = None + permute_1033 = torch.ops.aten.permute.default(view_1601, [1, 0]) + mm_523 = torch.ops.aten.mm.default(permute_1033, view_367); permute_1033 = None + permute_1035 = torch.ops.aten.permute.default(permute_119, [1, 0]); permute_119 = None + mm_524 = torch.ops.aten.mm.default(view_1601, permute_1035); view_1601 = permute_1035 = None + view_1602 = torch.ops.aten.view.default(mm_524, [2, 8192, 4096]); mm_524 = None + convert_element_type_2217 = torch.ops.prims.convert_element_type.default(mm_523, torch.float32); mm_523 = None + reduce_scatter_tensor_192 = 
torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2217, 'avg', 256, '0'); convert_element_type_2217 = None + wait_tensor_483 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_192); reduce_scatter_tensor_192 = None + convert_element_type_2218 = torch.ops.prims.convert_element_type.default(mul_685, torch.float32); mul_685 = None + neg_21 = torch.ops.aten.neg.default(convert_element_type_356) + exp_21 = torch.ops.aten.exp.default(neg_21); neg_21 = None + add_276 = torch.ops.aten.add.Tensor(exp_21, 1); exp_21 = None + reciprocal_21 = torch.ops.aten.reciprocal.default(add_276); add_276 = None + mul_686 = torch.ops.aten.mul.Tensor(reciprocal_21, 1); reciprocal_21 = None + mul_687 = torch.ops.aten.mul.Tensor(convert_element_type_2218, mul_686); convert_element_type_2218 = None + sub_64 = torch.ops.aten.sub.Tensor(1, mul_686); mul_686 = None + mul_688 = torch.ops.aten.mul.Tensor(convert_element_type_356, sub_64); convert_element_type_356 = sub_64 = None + add_277 = torch.ops.aten.add.Tensor(mul_688, 1); mul_688 = None + mul_689 = torch.ops.aten.mul.Tensor(mul_687, add_277); mul_687 = add_277 = None + convert_element_type_2220 = torch.ops.prims.convert_element_type.default(mul_689, torch.bfloat16); mul_689 = None + view_1603 = torch.ops.aten.view.default(convert_element_type_2220, [16384, 14336]); convert_element_type_2220 = None + permute_1037 = torch.ops.aten.permute.default(view_1603, [1, 0]) + mm_525 = torch.ops.aten.mm.default(permute_1037, view_367); permute_1037 = view_367 = None + convert_element_type_353 = torch.ops.prims.convert_element_type.default(primals_100, torch.bfloat16); primals_100 = None + all_gather_into_tensor_97 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_353, 256, '0'); convert_element_type_353 = None + wait_tensor_97 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_97); all_gather_into_tensor_97 = None + permute_118 = 
torch.ops.aten.permute.default(wait_tensor_97, [1, 0]); wait_tensor_97 = None + permute_1039 = torch.ops.aten.permute.default(permute_118, [1, 0]); permute_118 = None + mm_526 = torch.ops.aten.mm.default(view_1603, permute_1039); view_1603 = permute_1039 = None + view_1604 = torch.ops.aten.view.default(mm_526, [2, 8192, 4096]); mm_526 = None + add_278 = torch.ops.aten.add.Tensor(view_1602, view_1604); view_1602 = view_1604 = None + convert_element_type_2225 = torch.ops.prims.convert_element_type.default(mm_525, torch.float32); mm_525 = None + reduce_scatter_tensor_193 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2225, 'avg', 256, '0'); convert_element_type_2225 = None + wait_tensor_484 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_193); reduce_scatter_tensor_193 = None + convert_element_type_2226 = torch.ops.prims.convert_element_type.default(add_278, torch.float32); add_278 = None + convert_element_type_2228 = torch.ops.prims.convert_element_type.default(wait_tensor_96, torch.float32); wait_tensor_96 = None + mul_690 = torch.ops.aten.mul.Tensor(convert_element_type_2226, convert_element_type_2228); convert_element_type_2228 = None + mul_692 = torch.ops.aten.mul.Tensor(mul_84, mul_690) + sum_129 = torch.ops.aten.sum.dim_IntList(mul_692, [2], True); mul_692 = None + div_43 = torch.ops.aten.div.Tensor(mul_84, 4096) + mul_693 = torch.ops.aten.mul.Tensor(div_43, sum_129); div_43 = sum_129 = None + sub_65 = torch.ops.aten.sub.Tensor(mul_690, mul_693); mul_690 = mul_693 = None + mul_694 = torch.ops.aten.mul.Tensor(sub_65, rsqrt_21); sub_65 = rsqrt_21 = None + mul_695 = torch.ops.aten.mul.Tensor(convert_element_type_2226, mul_84); convert_element_type_2226 = mul_84 = None + sum_130 = torch.ops.aten.sum.dim_IntList(mul_695, [0, 1]); mul_695 = None + convert_element_type_2229 = torch.ops.prims.convert_element_type.default(mul_694, torch.bfloat16); mul_694 = None + add_279 = torch.ops.aten.add.Tensor(add_275, 
convert_element_type_2229); add_275 = convert_element_type_2229 = None + convert_element_type_default_22 = torch.ops.prims.convert_element_type.default(sum_130, torch.float32); sum_130 = None + reduce_scatter_tensor_194 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_22, 'avg', 256, '0'); convert_element_type_default_22 = None + wait_tensor_485 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_194); reduce_scatter_tensor_194 = None + view_1605 = torch.ops.aten.view.default(add_279, [16384, 4096]) + permute_1041 = torch.ops.aten.permute.default(view_1605, [1, 0]) + mm_527 = torch.ops.aten.mm.default(permute_1041, view_363); permute_1041 = view_363 = None + permute_1043 = torch.ops.aten.permute.default(permute_117, [1, 0]); permute_117 = None + mm_528 = torch.ops.aten.mm.default(view_1605, permute_1043); view_1605 = permute_1043 = None + view_1606 = torch.ops.aten.view.default(mm_528, [2, 8192, 4096]); mm_528 = None + convert_element_type_2236 = torch.ops.prims.convert_element_type.default(mm_527, torch.float32); mm_527 = None + reduce_scatter_tensor_195 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2236, 'avg', 256, '0'); convert_element_type_2236 = None + wait_tensor_486 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_195); reduce_scatter_tensor_195 = None + view_1607 = torch.ops.aten.view.default(view_1606, [2, 8192, 32, 128]); view_1606 = None + permute_1045 = torch.ops.aten.permute.default(view_1607, [0, 2, 1, 3]); view_1607 = None + convert_element_type_331 = torch.ops.prims.convert_element_type.default(primals_94, torch.bfloat16); primals_94 = None + all_gather_into_tensor_91 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_331, 256, '0'); convert_element_type_331 = None + wait_tensor_91 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_91); all_gather_into_tensor_91 = None + 
convert_element_type_332 = torch.ops.prims.convert_element_type.default(add_39, torch.float32); add_39 = None + pow_21 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_332, 2) + mean_20 = torch.ops.aten.mean.dim(pow_21, [2], True); pow_21 = None + add_40 = torch.ops.aten.add.Scalar(mean_20, 1e-05); mean_20 = None + rsqrt_20 = torch.ops.aten.rsqrt.default(add_40); add_40 = None + mul_80 = torch.ops.aten.mul.Tensor(convert_element_type_332, rsqrt_20); convert_element_type_332 = None + mul_81 = torch.ops.aten.mul.Tensor(mul_80, wait_tensor_91) + convert_element_type_333 = torch.ops.prims.convert_element_type.default(mul_81, torch.bfloat16); mul_81 = None + view_343 = torch.ops.aten.view.default(convert_element_type_333, [16384, 4096]); convert_element_type_333 = None + view_344 = torch.ops.aten.view.default(mm_70, [2, 8192, 4096]); mm_70 = None + convert_element_type_337 = torch.ops.prims.convert_element_type.default(primals_96, torch.bfloat16); primals_96 = None + all_gather_into_tensor_93 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_337, 256, '0'); convert_element_type_337 = None + wait_tensor_93 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_93); all_gather_into_tensor_93 = None + permute_111 = torch.ops.aten.permute.default(wait_tensor_93, [1, 0]); wait_tensor_93 = None + mm_71 = torch.ops.aten.mm.default(view_343, permute_111) + view_347 = torch.ops.aten.view.default(mm_71, [2, 8192, 1024]); mm_71 = None + view_350 = torch.ops.aten.view.default(mm_72, [2, 8192, 1024]); mm_72 = None + view_351 = torch.ops.aten.view.default(view_344, [2, 8192, -1, 128]); view_344 = None + view_352 = torch.ops.aten.view.default(view_347, [2, 8192, -1, 128]); view_347 = None + view_353 = torch.ops.aten.view.default(view_350, [2, 8192, -1, 128]); view_350 = None + convert_element_type_343 = torch.ops.prims.convert_element_type.default(view_351, torch.float32); view_351 = None + view_354 = 
torch.ops.aten.view.default(convert_element_type_343, [2, 8192, 32, -1, 2]); convert_element_type_343 = None + view_as_complex_20 = torch.ops.aten.view_as_complex.default(view_354); view_354 = None + convert_element_type_344 = torch.ops.prims.convert_element_type.default(view_352, torch.float32); view_352 = None + view_355 = torch.ops.aten.view.default(convert_element_type_344, [2, 8192, 8, -1, 2]); convert_element_type_344 = None + view_as_complex_21 = torch.ops.aten.view_as_complex.default(view_355); view_355 = None + mul_82 = torch.ops.aten.mul.Tensor(view_as_complex_20, view_16); view_as_complex_20 = None + view_as_real_20 = torch.ops.aten.view_as_real.default(mul_82); mul_82 = None + view_357 = torch.ops.aten.view.default(view_as_real_20, [2, 8192, 32, 128]); view_as_real_20 = None + mul_83 = torch.ops.aten.mul.Tensor(view_as_complex_21, view_16); view_as_complex_21 = None + view_as_real_21 = torch.ops.aten.view_as_real.default(mul_83); mul_83 = None + view_358 = torch.ops.aten.view.default(view_as_real_21, [2, 8192, 8, 128]); view_as_real_21 = None + convert_element_type_345 = torch.ops.prims.convert_element_type.default(view_357, torch.bfloat16); view_357 = None + convert_element_type_346 = torch.ops.prims.convert_element_type.default(view_358, torch.bfloat16); view_358 = None + unsqueeze_20 = torch.ops.aten.unsqueeze.default(convert_element_type_346, 3); convert_element_type_346 = None + expand_20 = torch.ops.aten.expand.default(unsqueeze_20, [2, 8192, 8, 4, 128]); unsqueeze_20 = None + clone_20 = torch.ops.aten.clone.default(expand_20, memory_format = torch.contiguous_format); expand_20 = None + view_359 = torch.ops.aten.view.default(clone_20, [2, 8192, 32, 128]); clone_20 = None + unsqueeze_21 = torch.ops.aten.unsqueeze.default(view_353, 3); view_353 = None + expand_21 = torch.ops.aten.expand.default(unsqueeze_21, [2, 8192, 8, 4, 128]); unsqueeze_21 = None + clone_21 = torch.ops.aten.clone.default(expand_21, memory_format = torch.contiguous_format); 
expand_21 = None + view_360 = torch.ops.aten.view.default(clone_21, [2, 8192, 32, 128]); clone_21 = None + permute_113 = torch.ops.aten.permute.default(convert_element_type_345, [0, 2, 1, 3]); convert_element_type_345 = None + permute_114 = torch.ops.aten.permute.default(view_359, [0, 2, 1, 3]); view_359 = None + permute_115 = torch.ops.aten.permute.default(view_360, [0, 2, 1, 3]); view_360 = None + _scaled_dot_product_cudnn_attention_backward_21 = torch.ops.aten._scaled_dot_product_cudnn_attention_backward.default(permute_1045, permute_113, permute_114, permute_115, getitem_90, getitem_91, getitem_96, getitem_97, None, None, None, 8192, 8192, 0.0, True); permute_1045 = permute_113 = permute_114 = permute_115 = getitem_90 = getitem_91 = getitem_96 = getitem_97 = None + getitem_351 = _scaled_dot_product_cudnn_attention_backward_21[0] + getitem_352 = _scaled_dot_product_cudnn_attention_backward_21[1] + getitem_353 = _scaled_dot_product_cudnn_attention_backward_21[2]; _scaled_dot_product_cudnn_attention_backward_21 = None + permute_1046 = torch.ops.aten.permute.default(getitem_353, [0, 2, 1, 3]); getitem_353 = None + permute_1047 = torch.ops.aten.permute.default(getitem_352, [0, 2, 1, 3]); getitem_352 = None + permute_1048 = torch.ops.aten.permute.default(getitem_351, [0, 2, 1, 3]); getitem_351 = None + view_1608 = torch.ops.aten.view.default(permute_1046, [2, 8192, 8, 4, 128]); permute_1046 = None + sum_131 = torch.ops.aten.sum.dim_IntList(view_1608, [3], True); view_1608 = None + squeeze_42 = torch.ops.aten.squeeze.dim(sum_131, 3); sum_131 = None + view_1609 = torch.ops.aten.view.default(permute_1047, [2, 8192, 8, 4, 128]); permute_1047 = None + sum_132 = torch.ops.aten.sum.dim_IntList(view_1609, [3], True); view_1609 = None + squeeze_43 = torch.ops.aten.squeeze.dim(sum_132, 3); sum_132 = None + convert_element_type_2237 = torch.ops.prims.convert_element_type.default(squeeze_43, torch.float32); squeeze_43 = None + convert_element_type_2238 = 
torch.ops.prims.convert_element_type.default(permute_1048, torch.float32); permute_1048 = None + view_1610 = torch.ops.aten.view.default(convert_element_type_2237, [2, 8192, 8, 64, 2]); convert_element_type_2237 = None + view_as_complex_106 = torch.ops.aten.view_as_complex.default(view_1610); view_1610 = None + mul_696 = torch.ops.aten.mul.Tensor(view_as_complex_106, _conj); view_as_complex_106 = None + view_1611 = torch.ops.aten.view.default(convert_element_type_2238, [2, 8192, 32, 64, 2]); convert_element_type_2238 = None + view_as_complex_107 = torch.ops.aten.view_as_complex.default(view_1611); view_1611 = None + mul_697 = torch.ops.aten.mul.Tensor(view_as_complex_107, _conj); view_as_complex_107 = None + view_as_real_106 = torch.ops.aten.view_as_real.default(mul_696); mul_696 = None + view_1612 = torch.ops.aten.view.default(view_as_real_106, [2, 8192, 8, 128]); view_as_real_106 = None + convert_element_type_2239 = torch.ops.prims.convert_element_type.default(view_1612, torch.bfloat16); view_1612 = None + view_as_real_107 = torch.ops.aten.view_as_real.default(mul_697); mul_697 = None + view_1613 = torch.ops.aten.view.default(view_as_real_107, [2, 8192, 32, 128]); view_as_real_107 = None + convert_element_type_2240 = torch.ops.prims.convert_element_type.default(view_1613, torch.bfloat16); view_1613 = None + view_1614 = torch.ops.aten.view.default(squeeze_42, [2, 8192, 1024]); squeeze_42 = None + view_1615 = torch.ops.aten.view.default(convert_element_type_2239, [2, 8192, 1024]); convert_element_type_2239 = None + view_1616 = torch.ops.aten.view.default(convert_element_type_2240, [2, 8192, 4096]); convert_element_type_2240 = None + view_1617 = torch.ops.aten.view.default(view_1614, [16384, 1024]); view_1614 = None + permute_1049 = torch.ops.aten.permute.default(view_1617, [1, 0]) + mm_529 = torch.ops.aten.mm.default(permute_1049, view_343); permute_1049 = None + convert_element_type_340 = torch.ops.prims.convert_element_type.default(primals_97, torch.bfloat16); 
primals_97 = None + all_gather_into_tensor_94 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_340, 256, '0'); convert_element_type_340 = None + wait_tensor_94 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_94); all_gather_into_tensor_94 = None + permute_112 = torch.ops.aten.permute.default(wait_tensor_94, [1, 0]); wait_tensor_94 = None + permute_1051 = torch.ops.aten.permute.default(permute_112, [1, 0]); permute_112 = None + mm_530 = torch.ops.aten.mm.default(view_1617, permute_1051); view_1617 = permute_1051 = None + view_1618 = torch.ops.aten.view.default(mm_530, [2, 8192, 4096]); mm_530 = None + convert_element_type_2245 = torch.ops.prims.convert_element_type.default(mm_529, torch.float32); mm_529 = None + reduce_scatter_tensor_196 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2245, 'avg', 256, '0'); convert_element_type_2245 = None + wait_tensor_487 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_196); reduce_scatter_tensor_196 = None + view_1619 = torch.ops.aten.view.default(view_1615, [16384, 1024]); view_1615 = None + permute_1053 = torch.ops.aten.permute.default(view_1619, [1, 0]) + mm_531 = torch.ops.aten.mm.default(permute_1053, view_343); permute_1053 = None + permute_1055 = torch.ops.aten.permute.default(permute_111, [1, 0]); permute_111 = None + mm_532 = torch.ops.aten.mm.default(view_1619, permute_1055); view_1619 = permute_1055 = None + view_1620 = torch.ops.aten.view.default(mm_532, [2, 8192, 4096]); mm_532 = None + add_280 = torch.ops.aten.add.Tensor(view_1618, view_1620); view_1618 = view_1620 = None + convert_element_type_2250 = torch.ops.prims.convert_element_type.default(mm_531, torch.float32); mm_531 = None + reduce_scatter_tensor_197 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2250, 'avg', 256, '0'); convert_element_type_2250 = None + wait_tensor_488 = 
torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_197); reduce_scatter_tensor_197 = None + view_1621 = torch.ops.aten.view.default(view_1616, [16384, 4096]); view_1616 = None + permute_1057 = torch.ops.aten.permute.default(view_1621, [1, 0]) + mm_533 = torch.ops.aten.mm.default(permute_1057, view_343); permute_1057 = view_343 = None + convert_element_type_334 = torch.ops.prims.convert_element_type.default(primals_95, torch.bfloat16); primals_95 = None + all_gather_into_tensor_92 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_334, 256, '0'); convert_element_type_334 = None + wait_tensor_92 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_92); all_gather_into_tensor_92 = None + permute_110 = torch.ops.aten.permute.default(wait_tensor_92, [1, 0]); wait_tensor_92 = None + permute_1059 = torch.ops.aten.permute.default(permute_110, [1, 0]); permute_110 = None + mm_534 = torch.ops.aten.mm.default(view_1621, permute_1059); view_1621 = permute_1059 = None + view_1622 = torch.ops.aten.view.default(mm_534, [2, 8192, 4096]); mm_534 = None + add_281 = torch.ops.aten.add.Tensor(add_280, view_1622); add_280 = view_1622 = None + convert_element_type_2255 = torch.ops.prims.convert_element_type.default(mm_533, torch.float32); mm_533 = None + reduce_scatter_tensor_198 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2255, 'avg', 256, '0'); convert_element_type_2255 = None + wait_tensor_489 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_198); reduce_scatter_tensor_198 = None + convert_element_type_2256 = torch.ops.prims.convert_element_type.default(add_281, torch.float32); add_281 = None + convert_element_type_2258 = torch.ops.prims.convert_element_type.default(wait_tensor_91, torch.float32); wait_tensor_91 = None + mul_698 = torch.ops.aten.mul.Tensor(convert_element_type_2256, convert_element_type_2258); convert_element_type_2258 = None + mul_700 = 
torch.ops.aten.mul.Tensor(mul_80, mul_698) + sum_133 = torch.ops.aten.sum.dim_IntList(mul_700, [2], True); mul_700 = None + div_44 = torch.ops.aten.div.Tensor(mul_80, 4096) + mul_701 = torch.ops.aten.mul.Tensor(div_44, sum_133); div_44 = sum_133 = None + sub_66 = torch.ops.aten.sub.Tensor(mul_698, mul_701); mul_698 = mul_701 = None + mul_702 = torch.ops.aten.mul.Tensor(sub_66, rsqrt_20); sub_66 = rsqrt_20 = None + mul_703 = torch.ops.aten.mul.Tensor(convert_element_type_2256, mul_80); convert_element_type_2256 = mul_80 = None + sum_134 = torch.ops.aten.sum.dim_IntList(mul_703, [0, 1]); mul_703 = None + convert_element_type_2259 = torch.ops.prims.convert_element_type.default(mul_702, torch.bfloat16); mul_702 = None + add_282 = torch.ops.aten.add.Tensor(add_279, convert_element_type_2259); add_279 = convert_element_type_2259 = None + convert_element_type_default_21 = torch.ops.prims.convert_element_type.default(sum_134, torch.float32); sum_134 = None + reduce_scatter_tensor_199 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_21, 'avg', 256, '0'); convert_element_type_default_21 = None + wait_tensor_490 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_199); reduce_scatter_tensor_199 = None + view_1623 = torch.ops.aten.view.default(add_282, [16384, 4096]) + permute_1061 = torch.ops.aten.permute.default(view_1623, [1, 0]) + permute_105 = torch.ops.aten.permute.default(getitem_81, [0, 2, 1, 3]) + view_327 = torch.ops.aten.view.default(permute_105, [2, 8192, -1]); permute_105 = None + convert_element_type_314 = torch.ops.prims.convert_element_type.default(primals_89, torch.bfloat16); primals_89 = None + all_gather_into_tensor_86 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_314, 256, '0'); convert_element_type_314 = None + wait_tensor_86 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_86); all_gather_into_tensor_86 = None + permute_106 = 
torch.ops.aten.permute.default(wait_tensor_86, [1, 0]); wait_tensor_86 = None + view_329 = torch.ops.aten.view.default(view_327, [16384, 4096]); view_327 = None + mm_66 = torch.ops.aten.mm.default(view_329, permute_106) + view_330 = torch.ops.aten.view.default(mm_66, [2, 8192, 4096]); mm_66 = None + add_37 = torch.ops.aten.add.Tensor(add_35, view_330); view_330 = None + convert_element_type_317 = torch.ops.prims.convert_element_type.default(primals_90, torch.bfloat16); primals_90 = None + all_gather_into_tensor_87 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_317, 256, '0'); convert_element_type_317 = None + wait_tensor_87 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_87); all_gather_into_tensor_87 = None + convert_element_type_318 = torch.ops.prims.convert_element_type.default(add_37, torch.float32); add_37 = None + pow_20 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_318, 2) + mean_19 = torch.ops.aten.mean.dim(pow_20, [2], True); pow_20 = None + add_38 = torch.ops.aten.add.Scalar(mean_19, 1e-05); mean_19 = None + rsqrt_19 = torch.ops.aten.rsqrt.default(add_38); add_38 = None + mul_76 = torch.ops.aten.mul.Tensor(convert_element_type_318, rsqrt_19); convert_element_type_318 = None + mul_77 = torch.ops.aten.mul.Tensor(mul_76, wait_tensor_87) + convert_element_type_319 = torch.ops.prims.convert_element_type.default(mul_77, torch.bfloat16); mul_77 = None + view_333 = torch.ops.aten.view.default(convert_element_type_319, [16384, 4096]); convert_element_type_319 = None + view_334 = torch.ops.aten.view.default(mm_67, [2, 8192, 14336]); mm_67 = None + convert_element_type_323 = torch.ops.prims.convert_element_type.default(view_334, torch.float32); view_334 = None + sigmoid_9 = torch.ops.aten.sigmoid.default(convert_element_type_323) + mul_78 = torch.ops.aten.mul.Tensor(convert_element_type_323, sigmoid_9); sigmoid_9 = None + convert_element_type_324 = 
torch.ops.prims.convert_element_type.default(mul_78, torch.bfloat16); mul_78 = None + convert_element_type_325 = torch.ops.prims.convert_element_type.default(primals_92, torch.bfloat16); primals_92 = None + all_gather_into_tensor_89 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_325, 256, '0'); convert_element_type_325 = None + wait_tensor_89 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_89); all_gather_into_tensor_89 = None + permute_108 = torch.ops.aten.permute.default(wait_tensor_89, [1, 0]); wait_tensor_89 = None + mm_68 = torch.ops.aten.mm.default(view_333, permute_108) + view_337 = torch.ops.aten.view.default(mm_68, [2, 8192, 14336]); mm_68 = None + mul_79 = torch.ops.aten.mul.Tensor(convert_element_type_324, view_337) + view_339 = torch.ops.aten.view.default(mul_79, [16384, 14336]); mul_79 = None + mm_535 = torch.ops.aten.mm.default(permute_1061, view_339); permute_1061 = view_339 = None + convert_element_type_328 = torch.ops.prims.convert_element_type.default(primals_93, torch.bfloat16); primals_93 = None + all_gather_into_tensor_90 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_328, 256, '0'); convert_element_type_328 = None + wait_tensor_90 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_90); all_gather_into_tensor_90 = None + permute_109 = torch.ops.aten.permute.default(wait_tensor_90, [1, 0]); wait_tensor_90 = None + permute_1063 = torch.ops.aten.permute.default(permute_109, [1, 0]); permute_109 = None + mm_536 = torch.ops.aten.mm.default(view_1623, permute_1063); view_1623 = permute_1063 = None + view_1624 = torch.ops.aten.view.default(mm_536, [2, 8192, 14336]); mm_536 = None + convert_element_type_2266 = torch.ops.prims.convert_element_type.default(mm_535, torch.float32); mm_535 = None + reduce_scatter_tensor_200 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2266, 'avg', 256, '0'); 
convert_element_type_2266 = None + wait_tensor_491 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_200); reduce_scatter_tensor_200 = None + mul_704 = torch.ops.aten.mul.Tensor(view_1624, convert_element_type_324); convert_element_type_324 = None + mul_705 = torch.ops.aten.mul.Tensor(view_1624, view_337); view_1624 = view_337 = None + view_1625 = torch.ops.aten.view.default(mul_704, [16384, 14336]); mul_704 = None + permute_1065 = torch.ops.aten.permute.default(view_1625, [1, 0]) + mm_537 = torch.ops.aten.mm.default(permute_1065, view_333); permute_1065 = None + permute_1067 = torch.ops.aten.permute.default(permute_108, [1, 0]); permute_108 = None + mm_538 = torch.ops.aten.mm.default(view_1625, permute_1067); view_1625 = permute_1067 = None + view_1626 = torch.ops.aten.view.default(mm_538, [2, 8192, 4096]); mm_538 = None + convert_element_type_2271 = torch.ops.prims.convert_element_type.default(mm_537, torch.float32); mm_537 = None + reduce_scatter_tensor_201 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2271, 'avg', 256, '0'); convert_element_type_2271 = None + wait_tensor_492 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_201); reduce_scatter_tensor_201 = None + convert_element_type_2272 = torch.ops.prims.convert_element_type.default(mul_705, torch.float32); mul_705 = None + neg_22 = torch.ops.aten.neg.default(convert_element_type_323) + exp_22 = torch.ops.aten.exp.default(neg_22); neg_22 = None + add_283 = torch.ops.aten.add.Tensor(exp_22, 1); exp_22 = None + reciprocal_22 = torch.ops.aten.reciprocal.default(add_283); add_283 = None + mul_706 = torch.ops.aten.mul.Tensor(reciprocal_22, 1); reciprocal_22 = None + mul_707 = torch.ops.aten.mul.Tensor(convert_element_type_2272, mul_706); convert_element_type_2272 = None + sub_67 = torch.ops.aten.sub.Tensor(1, mul_706); mul_706 = None + mul_708 = torch.ops.aten.mul.Tensor(convert_element_type_323, sub_67); convert_element_type_323 = sub_67 
= None + add_284 = torch.ops.aten.add.Tensor(mul_708, 1); mul_708 = None + mul_709 = torch.ops.aten.mul.Tensor(mul_707, add_284); mul_707 = add_284 = None + convert_element_type_2274 = torch.ops.prims.convert_element_type.default(mul_709, torch.bfloat16); mul_709 = None + view_1627 = torch.ops.aten.view.default(convert_element_type_2274, [16384, 14336]); convert_element_type_2274 = None + permute_1069 = torch.ops.aten.permute.default(view_1627, [1, 0]) + mm_539 = torch.ops.aten.mm.default(permute_1069, view_333); permute_1069 = view_333 = None + convert_element_type_320 = torch.ops.prims.convert_element_type.default(primals_91, torch.bfloat16); primals_91 = None + all_gather_into_tensor_88 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_320, 256, '0'); convert_element_type_320 = None + wait_tensor_88 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_88); all_gather_into_tensor_88 = None + permute_107 = torch.ops.aten.permute.default(wait_tensor_88, [1, 0]); wait_tensor_88 = None + permute_1071 = torch.ops.aten.permute.default(permute_107, [1, 0]); permute_107 = None + mm_540 = torch.ops.aten.mm.default(view_1627, permute_1071); view_1627 = permute_1071 = None + view_1628 = torch.ops.aten.view.default(mm_540, [2, 8192, 4096]); mm_540 = None + add_285 = torch.ops.aten.add.Tensor(view_1626, view_1628); view_1626 = view_1628 = None + convert_element_type_2279 = torch.ops.prims.convert_element_type.default(mm_539, torch.float32); mm_539 = None + reduce_scatter_tensor_202 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2279, 'avg', 256, '0'); convert_element_type_2279 = None + wait_tensor_493 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_202); reduce_scatter_tensor_202 = None + convert_element_type_2280 = torch.ops.prims.convert_element_type.default(add_285, torch.float32); add_285 = None + convert_element_type_2282 = 
torch.ops.prims.convert_element_type.default(wait_tensor_87, torch.float32); wait_tensor_87 = None + mul_710 = torch.ops.aten.mul.Tensor(convert_element_type_2280, convert_element_type_2282); convert_element_type_2282 = None + mul_712 = torch.ops.aten.mul.Tensor(mul_76, mul_710) + sum_135 = torch.ops.aten.sum.dim_IntList(mul_712, [2], True); mul_712 = None + div_45 = torch.ops.aten.div.Tensor(mul_76, 4096) + mul_713 = torch.ops.aten.mul.Tensor(div_45, sum_135); div_45 = sum_135 = None + sub_68 = torch.ops.aten.sub.Tensor(mul_710, mul_713); mul_710 = mul_713 = None + mul_714 = torch.ops.aten.mul.Tensor(sub_68, rsqrt_19); sub_68 = rsqrt_19 = None + mul_715 = torch.ops.aten.mul.Tensor(convert_element_type_2280, mul_76); convert_element_type_2280 = mul_76 = None + sum_136 = torch.ops.aten.sum.dim_IntList(mul_715, [0, 1]); mul_715 = None + convert_element_type_2283 = torch.ops.prims.convert_element_type.default(mul_714, torch.bfloat16); mul_714 = None + add_286 = torch.ops.aten.add.Tensor(add_282, convert_element_type_2283); add_282 = convert_element_type_2283 = None + convert_element_type_default_20 = torch.ops.prims.convert_element_type.default(sum_136, torch.float32); sum_136 = None + reduce_scatter_tensor_203 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_20, 'avg', 256, '0'); convert_element_type_default_20 = None + wait_tensor_494 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_203); reduce_scatter_tensor_203 = None + view_1629 = torch.ops.aten.view.default(add_286, [16384, 4096]) + permute_1073 = torch.ops.aten.permute.default(view_1629, [1, 0]) + mm_541 = torch.ops.aten.mm.default(permute_1073, view_329); permute_1073 = view_329 = None + permute_1075 = torch.ops.aten.permute.default(permute_106, [1, 0]); permute_106 = None + mm_542 = torch.ops.aten.mm.default(view_1629, permute_1075); view_1629 = permute_1075 = None + view_1630 = torch.ops.aten.view.default(mm_542, [2, 8192, 4096]); mm_542 = None + 
convert_element_type_2290 = torch.ops.prims.convert_element_type.default(mm_541, torch.float32); mm_541 = None + reduce_scatter_tensor_204 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2290, 'avg', 256, '0'); convert_element_type_2290 = None + wait_tensor_495 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_204); reduce_scatter_tensor_204 = None + view_1631 = torch.ops.aten.view.default(view_1630, [2, 8192, 32, 128]); view_1630 = None + permute_1077 = torch.ops.aten.permute.default(view_1631, [0, 2, 1, 3]); view_1631 = None + convert_element_type_298 = torch.ops.prims.convert_element_type.default(primals_85, torch.bfloat16); primals_85 = None + all_gather_into_tensor_82 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_298, 256, '0'); convert_element_type_298 = None + wait_tensor_82 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_82); all_gather_into_tensor_82 = None + convert_element_type_299 = torch.ops.prims.convert_element_type.default(add_35, torch.float32); add_35 = None + pow_19 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_299, 2) + mean_18 = torch.ops.aten.mean.dim(pow_19, [2], True); pow_19 = None + add_36 = torch.ops.aten.add.Scalar(mean_18, 1e-05); mean_18 = None + rsqrt_18 = torch.ops.aten.rsqrt.default(add_36); add_36 = None + mul_72 = torch.ops.aten.mul.Tensor(convert_element_type_299, rsqrt_18); convert_element_type_299 = None + mul_73 = torch.ops.aten.mul.Tensor(mul_72, wait_tensor_82) + convert_element_type_300 = torch.ops.prims.convert_element_type.default(mul_73, torch.bfloat16); mul_73 = None + view_309 = torch.ops.aten.view.default(convert_element_type_300, [16384, 4096]); convert_element_type_300 = None + view_310 = torch.ops.aten.view.default(mm_63, [2, 8192, 4096]); mm_63 = None + convert_element_type_304 = torch.ops.prims.convert_element_type.default(primals_87, torch.bfloat16); primals_87 = None + 
all_gather_into_tensor_84 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_304, 256, '0'); convert_element_type_304 = None + wait_tensor_84 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_84); all_gather_into_tensor_84 = None + permute_100 = torch.ops.aten.permute.default(wait_tensor_84, [1, 0]); wait_tensor_84 = None + mm_64 = torch.ops.aten.mm.default(view_309, permute_100) + view_313 = torch.ops.aten.view.default(mm_64, [2, 8192, 1024]); mm_64 = None + view_316 = torch.ops.aten.view.default(mm_65, [2, 8192, 1024]); mm_65 = None + view_317 = torch.ops.aten.view.default(view_310, [2, 8192, -1, 128]); view_310 = None + view_318 = torch.ops.aten.view.default(view_313, [2, 8192, -1, 128]); view_313 = None + view_319 = torch.ops.aten.view.default(view_316, [2, 8192, -1, 128]); view_316 = None + convert_element_type_310 = torch.ops.prims.convert_element_type.default(view_317, torch.float32); view_317 = None + view_320 = torch.ops.aten.view.default(convert_element_type_310, [2, 8192, 32, -1, 2]); convert_element_type_310 = None + view_as_complex_18 = torch.ops.aten.view_as_complex.default(view_320); view_320 = None + convert_element_type_311 = torch.ops.prims.convert_element_type.default(view_318, torch.float32); view_318 = None + view_321 = torch.ops.aten.view.default(convert_element_type_311, [2, 8192, 8, -1, 2]); convert_element_type_311 = None + view_as_complex_19 = torch.ops.aten.view_as_complex.default(view_321); view_321 = None + mul_74 = torch.ops.aten.mul.Tensor(view_as_complex_18, view_16); view_as_complex_18 = None + view_as_real_18 = torch.ops.aten.view_as_real.default(mul_74); mul_74 = None + view_323 = torch.ops.aten.view.default(view_as_real_18, [2, 8192, 32, 128]); view_as_real_18 = None + mul_75 = torch.ops.aten.mul.Tensor(view_as_complex_19, view_16); view_as_complex_19 = None + view_as_real_19 = torch.ops.aten.view_as_real.default(mul_75); mul_75 = None + view_324 = 
torch.ops.aten.view.default(view_as_real_19, [2, 8192, 8, 128]); view_as_real_19 = None + convert_element_type_312 = torch.ops.prims.convert_element_type.default(view_323, torch.bfloat16); view_323 = None + convert_element_type_313 = torch.ops.prims.convert_element_type.default(view_324, torch.bfloat16); view_324 = None + unsqueeze_18 = torch.ops.aten.unsqueeze.default(convert_element_type_313, 3); convert_element_type_313 = None + expand_18 = torch.ops.aten.expand.default(unsqueeze_18, [2, 8192, 8, 4, 128]); unsqueeze_18 = None + clone_18 = torch.ops.aten.clone.default(expand_18, memory_format = torch.contiguous_format); expand_18 = None + view_325 = torch.ops.aten.view.default(clone_18, [2, 8192, 32, 128]); clone_18 = None + unsqueeze_19 = torch.ops.aten.unsqueeze.default(view_319, 3); view_319 = None + expand_19 = torch.ops.aten.expand.default(unsqueeze_19, [2, 8192, 8, 4, 128]); unsqueeze_19 = None + clone_19 = torch.ops.aten.clone.default(expand_19, memory_format = torch.contiguous_format); expand_19 = None + view_326 = torch.ops.aten.view.default(clone_19, [2, 8192, 32, 128]); clone_19 = None + permute_102 = torch.ops.aten.permute.default(convert_element_type_312, [0, 2, 1, 3]); convert_element_type_312 = None + permute_103 = torch.ops.aten.permute.default(view_325, [0, 2, 1, 3]); view_325 = None + permute_104 = torch.ops.aten.permute.default(view_326, [0, 2, 1, 3]); view_326 = None + _scaled_dot_product_cudnn_attention_backward_22 = torch.ops.aten._scaled_dot_product_cudnn_attention_backward.default(permute_1077, permute_102, permute_103, permute_104, getitem_81, getitem_82, getitem_87, getitem_88, None, None, None, 8192, 8192, 0.0, True); permute_1077 = permute_102 = permute_103 = permute_104 = getitem_81 = getitem_82 = getitem_87 = getitem_88 = None + getitem_354 = _scaled_dot_product_cudnn_attention_backward_22[0] + getitem_355 = _scaled_dot_product_cudnn_attention_backward_22[1] + getitem_356 = _scaled_dot_product_cudnn_attention_backward_22[2]; 
_scaled_dot_product_cudnn_attention_backward_22 = None + permute_1078 = torch.ops.aten.permute.default(getitem_356, [0, 2, 1, 3]); getitem_356 = None + permute_1079 = torch.ops.aten.permute.default(getitem_355, [0, 2, 1, 3]); getitem_355 = None + permute_1080 = torch.ops.aten.permute.default(getitem_354, [0, 2, 1, 3]); getitem_354 = None + view_1632 = torch.ops.aten.view.default(permute_1078, [2, 8192, 8, 4, 128]); permute_1078 = None + sum_137 = torch.ops.aten.sum.dim_IntList(view_1632, [3], True); view_1632 = None + squeeze_44 = torch.ops.aten.squeeze.dim(sum_137, 3); sum_137 = None + view_1633 = torch.ops.aten.view.default(permute_1079, [2, 8192, 8, 4, 128]); permute_1079 = None + sum_138 = torch.ops.aten.sum.dim_IntList(view_1633, [3], True); view_1633 = None + squeeze_45 = torch.ops.aten.squeeze.dim(sum_138, 3); sum_138 = None + convert_element_type_2291 = torch.ops.prims.convert_element_type.default(squeeze_45, torch.float32); squeeze_45 = None + convert_element_type_2292 = torch.ops.prims.convert_element_type.default(permute_1080, torch.float32); permute_1080 = None + view_1634 = torch.ops.aten.view.default(convert_element_type_2291, [2, 8192, 8, 64, 2]); convert_element_type_2291 = None + view_as_complex_108 = torch.ops.aten.view_as_complex.default(view_1634); view_1634 = None + mul_716 = torch.ops.aten.mul.Tensor(view_as_complex_108, _conj); view_as_complex_108 = None + view_1635 = torch.ops.aten.view.default(convert_element_type_2292, [2, 8192, 32, 64, 2]); convert_element_type_2292 = None + view_as_complex_109 = torch.ops.aten.view_as_complex.default(view_1635); view_1635 = None + mul_717 = torch.ops.aten.mul.Tensor(view_as_complex_109, _conj); view_as_complex_109 = None + view_as_real_108 = torch.ops.aten.view_as_real.default(mul_716); mul_716 = None + view_1636 = torch.ops.aten.view.default(view_as_real_108, [2, 8192, 8, 128]); view_as_real_108 = None + convert_element_type_2293 = torch.ops.prims.convert_element_type.default(view_1636, torch.bfloat16); 
view_1636 = None + view_as_real_109 = torch.ops.aten.view_as_real.default(mul_717); mul_717 = None + view_1637 = torch.ops.aten.view.default(view_as_real_109, [2, 8192, 32, 128]); view_as_real_109 = None + convert_element_type_2294 = torch.ops.prims.convert_element_type.default(view_1637, torch.bfloat16); view_1637 = None + view_1638 = torch.ops.aten.view.default(squeeze_44, [2, 8192, 1024]); squeeze_44 = None + view_1639 = torch.ops.aten.view.default(convert_element_type_2293, [2, 8192, 1024]); convert_element_type_2293 = None + view_1640 = torch.ops.aten.view.default(convert_element_type_2294, [2, 8192, 4096]); convert_element_type_2294 = None + view_1641 = torch.ops.aten.view.default(view_1638, [16384, 1024]); view_1638 = None + permute_1081 = torch.ops.aten.permute.default(view_1641, [1, 0]) + mm_543 = torch.ops.aten.mm.default(permute_1081, view_309); permute_1081 = None + convert_element_type_307 = torch.ops.prims.convert_element_type.default(primals_88, torch.bfloat16); primals_88 = None + all_gather_into_tensor_85 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_307, 256, '0'); convert_element_type_307 = None + wait_tensor_85 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_85); all_gather_into_tensor_85 = None + permute_101 = torch.ops.aten.permute.default(wait_tensor_85, [1, 0]); wait_tensor_85 = None + permute_1083 = torch.ops.aten.permute.default(permute_101, [1, 0]); permute_101 = None + mm_544 = torch.ops.aten.mm.default(view_1641, permute_1083); view_1641 = permute_1083 = None + view_1642 = torch.ops.aten.view.default(mm_544, [2, 8192, 4096]); mm_544 = None + convert_element_type_2299 = torch.ops.prims.convert_element_type.default(mm_543, torch.float32); mm_543 = None + reduce_scatter_tensor_205 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2299, 'avg', 256, '0'); convert_element_type_2299 = None + wait_tensor_496 = 
torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_205); reduce_scatter_tensor_205 = None + view_1643 = torch.ops.aten.view.default(view_1639, [16384, 1024]); view_1639 = None + permute_1085 = torch.ops.aten.permute.default(view_1643, [1, 0]) + mm_545 = torch.ops.aten.mm.default(permute_1085, view_309); permute_1085 = None + permute_1087 = torch.ops.aten.permute.default(permute_100, [1, 0]); permute_100 = None + mm_546 = torch.ops.aten.mm.default(view_1643, permute_1087); view_1643 = permute_1087 = None + view_1644 = torch.ops.aten.view.default(mm_546, [2, 8192, 4096]); mm_546 = None + add_287 = torch.ops.aten.add.Tensor(view_1642, view_1644); view_1642 = view_1644 = None + convert_element_type_2304 = torch.ops.prims.convert_element_type.default(mm_545, torch.float32); mm_545 = None + reduce_scatter_tensor_206 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2304, 'avg', 256, '0'); convert_element_type_2304 = None + wait_tensor_497 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_206); reduce_scatter_tensor_206 = None + view_1645 = torch.ops.aten.view.default(view_1640, [16384, 4096]); view_1640 = None + permute_1089 = torch.ops.aten.permute.default(view_1645, [1, 0]) + mm_547 = torch.ops.aten.mm.default(permute_1089, view_309); permute_1089 = view_309 = None + convert_element_type_301 = torch.ops.prims.convert_element_type.default(primals_86, torch.bfloat16); primals_86 = None + all_gather_into_tensor_83 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_301, 256, '0'); convert_element_type_301 = None + wait_tensor_83 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_83); all_gather_into_tensor_83 = None + permute_99 = torch.ops.aten.permute.default(wait_tensor_83, [1, 0]); wait_tensor_83 = None + permute_1091 = torch.ops.aten.permute.default(permute_99, [1, 0]); permute_99 = None + mm_548 = torch.ops.aten.mm.default(view_1645, permute_1091); 
view_1645 = permute_1091 = None + view_1646 = torch.ops.aten.view.default(mm_548, [2, 8192, 4096]); mm_548 = None + add_288 = torch.ops.aten.add.Tensor(add_287, view_1646); add_287 = view_1646 = None + convert_element_type_2309 = torch.ops.prims.convert_element_type.default(mm_547, torch.float32); mm_547 = None + reduce_scatter_tensor_207 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2309, 'avg', 256, '0'); convert_element_type_2309 = None + wait_tensor_498 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_207); reduce_scatter_tensor_207 = None + convert_element_type_2310 = torch.ops.prims.convert_element_type.default(add_288, torch.float32); add_288 = None + convert_element_type_2312 = torch.ops.prims.convert_element_type.default(wait_tensor_82, torch.float32); wait_tensor_82 = None + mul_718 = torch.ops.aten.mul.Tensor(convert_element_type_2310, convert_element_type_2312); convert_element_type_2312 = None + mul_720 = torch.ops.aten.mul.Tensor(mul_72, mul_718) + sum_139 = torch.ops.aten.sum.dim_IntList(mul_720, [2], True); mul_720 = None + div_46 = torch.ops.aten.div.Tensor(mul_72, 4096) + mul_721 = torch.ops.aten.mul.Tensor(div_46, sum_139); div_46 = sum_139 = None + sub_69 = torch.ops.aten.sub.Tensor(mul_718, mul_721); mul_718 = mul_721 = None + mul_722 = torch.ops.aten.mul.Tensor(sub_69, rsqrt_18); sub_69 = rsqrt_18 = None + mul_723 = torch.ops.aten.mul.Tensor(convert_element_type_2310, mul_72); convert_element_type_2310 = mul_72 = None + sum_140 = torch.ops.aten.sum.dim_IntList(mul_723, [0, 1]); mul_723 = None + convert_element_type_2313 = torch.ops.prims.convert_element_type.default(mul_722, torch.bfloat16); mul_722 = None + add_289 = torch.ops.aten.add.Tensor(add_286, convert_element_type_2313); add_286 = convert_element_type_2313 = None + convert_element_type_default_19 = torch.ops.prims.convert_element_type.default(sum_140, torch.float32); sum_140 = None + reduce_scatter_tensor_208 = 
torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_19, 'avg', 256, '0'); convert_element_type_default_19 = None + wait_tensor_499 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_208); reduce_scatter_tensor_208 = None + view_1647 = torch.ops.aten.view.default(add_289, [16384, 4096]) + permute_1093 = torch.ops.aten.permute.default(view_1647, [1, 0]) + permute_94 = torch.ops.aten.permute.default(getitem_72, [0, 2, 1, 3]) + view_293 = torch.ops.aten.view.default(permute_94, [2, 8192, -1]); permute_94 = None + convert_element_type_281 = torch.ops.prims.convert_element_type.default(primals_80, torch.bfloat16); primals_80 = None + all_gather_into_tensor_77 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_281, 256, '0'); convert_element_type_281 = None + wait_tensor_77 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_77); all_gather_into_tensor_77 = None + permute_95 = torch.ops.aten.permute.default(wait_tensor_77, [1, 0]); wait_tensor_77 = None + view_295 = torch.ops.aten.view.default(view_293, [16384, 4096]); view_293 = None + mm_59 = torch.ops.aten.mm.default(view_295, permute_95) + view_296 = torch.ops.aten.view.default(mm_59, [2, 8192, 4096]); mm_59 = None + add_33 = torch.ops.aten.add.Tensor(add_31, view_296); view_296 = None + convert_element_type_284 = torch.ops.prims.convert_element_type.default(primals_81, torch.bfloat16); primals_81 = None + all_gather_into_tensor_78 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_284, 256, '0'); convert_element_type_284 = None + wait_tensor_78 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_78); all_gather_into_tensor_78 = None + convert_element_type_285 = torch.ops.prims.convert_element_type.default(add_33, torch.float32); add_33 = None + pow_18 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_285, 2) + mean_17 = torch.ops.aten.mean.dim(pow_18, [2], 
True); pow_18 = None + add_34 = torch.ops.aten.add.Scalar(mean_17, 1e-05); mean_17 = None + rsqrt_17 = torch.ops.aten.rsqrt.default(add_34); add_34 = None + mul_68 = torch.ops.aten.mul.Tensor(convert_element_type_285, rsqrt_17); convert_element_type_285 = None + mul_69 = torch.ops.aten.mul.Tensor(mul_68, wait_tensor_78) + convert_element_type_286 = torch.ops.prims.convert_element_type.default(mul_69, torch.bfloat16); mul_69 = None + view_299 = torch.ops.aten.view.default(convert_element_type_286, [16384, 4096]); convert_element_type_286 = None + view_300 = torch.ops.aten.view.default(mm_60, [2, 8192, 14336]); mm_60 = None + convert_element_type_290 = torch.ops.prims.convert_element_type.default(view_300, torch.float32); view_300 = None + sigmoid_8 = torch.ops.aten.sigmoid.default(convert_element_type_290) + mul_70 = torch.ops.aten.mul.Tensor(convert_element_type_290, sigmoid_8); sigmoid_8 = None + convert_element_type_291 = torch.ops.prims.convert_element_type.default(mul_70, torch.bfloat16); mul_70 = None + convert_element_type_292 = torch.ops.prims.convert_element_type.default(primals_83, torch.bfloat16); primals_83 = None + all_gather_into_tensor_80 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_292, 256, '0'); convert_element_type_292 = None + wait_tensor_80 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_80); all_gather_into_tensor_80 = None + permute_97 = torch.ops.aten.permute.default(wait_tensor_80, [1, 0]); wait_tensor_80 = None + mm_61 = torch.ops.aten.mm.default(view_299, permute_97) + view_303 = torch.ops.aten.view.default(mm_61, [2, 8192, 14336]); mm_61 = None + mul_71 = torch.ops.aten.mul.Tensor(convert_element_type_291, view_303) + view_305 = torch.ops.aten.view.default(mul_71, [16384, 14336]); mul_71 = None + mm_549 = torch.ops.aten.mm.default(permute_1093, view_305); permute_1093 = view_305 = None + convert_element_type_295 = torch.ops.prims.convert_element_type.default(primals_84, 
torch.bfloat16); primals_84 = None + all_gather_into_tensor_81 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_295, 256, '0'); convert_element_type_295 = None + wait_tensor_81 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_81); all_gather_into_tensor_81 = None + permute_98 = torch.ops.aten.permute.default(wait_tensor_81, [1, 0]); wait_tensor_81 = None + permute_1095 = torch.ops.aten.permute.default(permute_98, [1, 0]); permute_98 = None + mm_550 = torch.ops.aten.mm.default(view_1647, permute_1095); view_1647 = permute_1095 = None + view_1648 = torch.ops.aten.view.default(mm_550, [2, 8192, 14336]); mm_550 = None + convert_element_type_2320 = torch.ops.prims.convert_element_type.default(mm_549, torch.float32); mm_549 = None + reduce_scatter_tensor_209 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2320, 'avg', 256, '0'); convert_element_type_2320 = None + wait_tensor_500 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_209); reduce_scatter_tensor_209 = None + mul_724 = torch.ops.aten.mul.Tensor(view_1648, convert_element_type_291); convert_element_type_291 = None + mul_725 = torch.ops.aten.mul.Tensor(view_1648, view_303); view_1648 = view_303 = None + view_1649 = torch.ops.aten.view.default(mul_724, [16384, 14336]); mul_724 = None + permute_1097 = torch.ops.aten.permute.default(view_1649, [1, 0]) + mm_551 = torch.ops.aten.mm.default(permute_1097, view_299); permute_1097 = None + permute_1099 = torch.ops.aten.permute.default(permute_97, [1, 0]); permute_97 = None + mm_552 = torch.ops.aten.mm.default(view_1649, permute_1099); view_1649 = permute_1099 = None + view_1650 = torch.ops.aten.view.default(mm_552, [2, 8192, 4096]); mm_552 = None + convert_element_type_2325 = torch.ops.prims.convert_element_type.default(mm_551, torch.float32); mm_551 = None + reduce_scatter_tensor_210 = 
torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2325, 'avg', 256, '0'); convert_element_type_2325 = None + wait_tensor_501 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_210); reduce_scatter_tensor_210 = None + convert_element_type_2326 = torch.ops.prims.convert_element_type.default(mul_725, torch.float32); mul_725 = None + neg_23 = torch.ops.aten.neg.default(convert_element_type_290) + exp_23 = torch.ops.aten.exp.default(neg_23); neg_23 = None + add_290 = torch.ops.aten.add.Tensor(exp_23, 1); exp_23 = None + reciprocal_23 = torch.ops.aten.reciprocal.default(add_290); add_290 = None + mul_726 = torch.ops.aten.mul.Tensor(reciprocal_23, 1); reciprocal_23 = None + mul_727 = torch.ops.aten.mul.Tensor(convert_element_type_2326, mul_726); convert_element_type_2326 = None + sub_70 = torch.ops.aten.sub.Tensor(1, mul_726); mul_726 = None + mul_728 = torch.ops.aten.mul.Tensor(convert_element_type_290, sub_70); convert_element_type_290 = sub_70 = None + add_291 = torch.ops.aten.add.Tensor(mul_728, 1); mul_728 = None + mul_729 = torch.ops.aten.mul.Tensor(mul_727, add_291); mul_727 = add_291 = None + convert_element_type_2328 = torch.ops.prims.convert_element_type.default(mul_729, torch.bfloat16); mul_729 = None + view_1651 = torch.ops.aten.view.default(convert_element_type_2328, [16384, 14336]); convert_element_type_2328 = None + permute_1101 = torch.ops.aten.permute.default(view_1651, [1, 0]) + mm_553 = torch.ops.aten.mm.default(permute_1101, view_299); permute_1101 = view_299 = None + convert_element_type_287 = torch.ops.prims.convert_element_type.default(primals_82, torch.bfloat16); primals_82 = None + all_gather_into_tensor_79 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_287, 256, '0'); convert_element_type_287 = None + wait_tensor_79 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_79); all_gather_into_tensor_79 = None + permute_96 = 
torch.ops.aten.permute.default(wait_tensor_79, [1, 0]); wait_tensor_79 = None + permute_1103 = torch.ops.aten.permute.default(permute_96, [1, 0]); permute_96 = None + mm_554 = torch.ops.aten.mm.default(view_1651, permute_1103); view_1651 = permute_1103 = None + view_1652 = torch.ops.aten.view.default(mm_554, [2, 8192, 4096]); mm_554 = None + add_292 = torch.ops.aten.add.Tensor(view_1650, view_1652); view_1650 = view_1652 = None + convert_element_type_2333 = torch.ops.prims.convert_element_type.default(mm_553, torch.float32); mm_553 = None + reduce_scatter_tensor_211 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2333, 'avg', 256, '0'); convert_element_type_2333 = None + wait_tensor_502 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_211); reduce_scatter_tensor_211 = None + convert_element_type_2334 = torch.ops.prims.convert_element_type.default(add_292, torch.float32); add_292 = None + convert_element_type_2336 = torch.ops.prims.convert_element_type.default(wait_tensor_78, torch.float32); wait_tensor_78 = None + mul_730 = torch.ops.aten.mul.Tensor(convert_element_type_2334, convert_element_type_2336); convert_element_type_2336 = None + mul_732 = torch.ops.aten.mul.Tensor(mul_68, mul_730) + sum_141 = torch.ops.aten.sum.dim_IntList(mul_732, [2], True); mul_732 = None + div_47 = torch.ops.aten.div.Tensor(mul_68, 4096) + mul_733 = torch.ops.aten.mul.Tensor(div_47, sum_141); div_47 = sum_141 = None + sub_71 = torch.ops.aten.sub.Tensor(mul_730, mul_733); mul_730 = mul_733 = None + mul_734 = torch.ops.aten.mul.Tensor(sub_71, rsqrt_17); sub_71 = rsqrt_17 = None + mul_735 = torch.ops.aten.mul.Tensor(convert_element_type_2334, mul_68); convert_element_type_2334 = mul_68 = None + sum_142 = torch.ops.aten.sum.dim_IntList(mul_735, [0, 1]); mul_735 = None + convert_element_type_2337 = torch.ops.prims.convert_element_type.default(mul_734, torch.bfloat16); mul_734 = None + add_293 = torch.ops.aten.add.Tensor(add_289, 
convert_element_type_2337); add_289 = convert_element_type_2337 = None + convert_element_type_default_18 = torch.ops.prims.convert_element_type.default(sum_142, torch.float32); sum_142 = None + reduce_scatter_tensor_212 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_18, 'avg', 256, '0'); convert_element_type_default_18 = None + wait_tensor_503 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_212); reduce_scatter_tensor_212 = None + view_1653 = torch.ops.aten.view.default(add_293, [16384, 4096]) + permute_1105 = torch.ops.aten.permute.default(view_1653, [1, 0]) + mm_555 = torch.ops.aten.mm.default(permute_1105, view_295); permute_1105 = view_295 = None + permute_1107 = torch.ops.aten.permute.default(permute_95, [1, 0]); permute_95 = None + mm_556 = torch.ops.aten.mm.default(view_1653, permute_1107); view_1653 = permute_1107 = None + view_1654 = torch.ops.aten.view.default(mm_556, [2, 8192, 4096]); mm_556 = None + convert_element_type_2344 = torch.ops.prims.convert_element_type.default(mm_555, torch.float32); mm_555 = None + reduce_scatter_tensor_213 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2344, 'avg', 256, '0'); convert_element_type_2344 = None + wait_tensor_504 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_213); reduce_scatter_tensor_213 = None + view_1655 = torch.ops.aten.view.default(view_1654, [2, 8192, 32, 128]); view_1654 = None + permute_1109 = torch.ops.aten.permute.default(view_1655, [0, 2, 1, 3]); view_1655 = None + convert_element_type_265 = torch.ops.prims.convert_element_type.default(primals_76, torch.bfloat16); primals_76 = None + all_gather_into_tensor_73 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_265, 256, '0'); convert_element_type_265 = None + wait_tensor_73 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_73); all_gather_into_tensor_73 = None + 
convert_element_type_266 = torch.ops.prims.convert_element_type.default(add_31, torch.float32); add_31 = None + pow_17 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_266, 2) + mean_16 = torch.ops.aten.mean.dim(pow_17, [2], True); pow_17 = None + add_32 = torch.ops.aten.add.Scalar(mean_16, 1e-05); mean_16 = None + rsqrt_16 = torch.ops.aten.rsqrt.default(add_32); add_32 = None + mul_64 = torch.ops.aten.mul.Tensor(convert_element_type_266, rsqrt_16); convert_element_type_266 = None + mul_65 = torch.ops.aten.mul.Tensor(mul_64, wait_tensor_73) + convert_element_type_267 = torch.ops.prims.convert_element_type.default(mul_65, torch.bfloat16); mul_65 = None + view_275 = torch.ops.aten.view.default(convert_element_type_267, [16384, 4096]); convert_element_type_267 = None + view_276 = torch.ops.aten.view.default(mm_56, [2, 8192, 4096]); mm_56 = None + convert_element_type_271 = torch.ops.prims.convert_element_type.default(primals_78, torch.bfloat16); primals_78 = None + all_gather_into_tensor_75 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_271, 256, '0'); convert_element_type_271 = None + wait_tensor_75 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_75); all_gather_into_tensor_75 = None + permute_89 = torch.ops.aten.permute.default(wait_tensor_75, [1, 0]); wait_tensor_75 = None + mm_57 = torch.ops.aten.mm.default(view_275, permute_89) + view_279 = torch.ops.aten.view.default(mm_57, [2, 8192, 1024]); mm_57 = None + view_282 = torch.ops.aten.view.default(mm_58, [2, 8192, 1024]); mm_58 = None + view_283 = torch.ops.aten.view.default(view_276, [2, 8192, -1, 128]); view_276 = None + view_284 = torch.ops.aten.view.default(view_279, [2, 8192, -1, 128]); view_279 = None + view_285 = torch.ops.aten.view.default(view_282, [2, 8192, -1, 128]); view_282 = None + convert_element_type_277 = torch.ops.prims.convert_element_type.default(view_283, torch.float32); view_283 = None + view_286 = 
torch.ops.aten.view.default(convert_element_type_277, [2, 8192, 32, -1, 2]); convert_element_type_277 = None + view_as_complex_16 = torch.ops.aten.view_as_complex.default(view_286); view_286 = None + convert_element_type_278 = torch.ops.prims.convert_element_type.default(view_284, torch.float32); view_284 = None + view_287 = torch.ops.aten.view.default(convert_element_type_278, [2, 8192, 8, -1, 2]); convert_element_type_278 = None + view_as_complex_17 = torch.ops.aten.view_as_complex.default(view_287); view_287 = None + mul_66 = torch.ops.aten.mul.Tensor(view_as_complex_16, view_16); view_as_complex_16 = None + view_as_real_16 = torch.ops.aten.view_as_real.default(mul_66); mul_66 = None + view_289 = torch.ops.aten.view.default(view_as_real_16, [2, 8192, 32, 128]); view_as_real_16 = None + mul_67 = torch.ops.aten.mul.Tensor(view_as_complex_17, view_16); view_as_complex_17 = None + view_as_real_17 = torch.ops.aten.view_as_real.default(mul_67); mul_67 = None + view_290 = torch.ops.aten.view.default(view_as_real_17, [2, 8192, 8, 128]); view_as_real_17 = None + convert_element_type_279 = torch.ops.prims.convert_element_type.default(view_289, torch.bfloat16); view_289 = None + convert_element_type_280 = torch.ops.prims.convert_element_type.default(view_290, torch.bfloat16); view_290 = None + unsqueeze_16 = torch.ops.aten.unsqueeze.default(convert_element_type_280, 3); convert_element_type_280 = None + expand_16 = torch.ops.aten.expand.default(unsqueeze_16, [2, 8192, 8, 4, 128]); unsqueeze_16 = None + clone_16 = torch.ops.aten.clone.default(expand_16, memory_format = torch.contiguous_format); expand_16 = None + view_291 = torch.ops.aten.view.default(clone_16, [2, 8192, 32, 128]); clone_16 = None + unsqueeze_17 = torch.ops.aten.unsqueeze.default(view_285, 3); view_285 = None + expand_17 = torch.ops.aten.expand.default(unsqueeze_17, [2, 8192, 8, 4, 128]); unsqueeze_17 = None + clone_17 = torch.ops.aten.clone.default(expand_17, memory_format = torch.contiguous_format); 
expand_17 = None + view_292 = torch.ops.aten.view.default(clone_17, [2, 8192, 32, 128]); clone_17 = None + permute_91 = torch.ops.aten.permute.default(convert_element_type_279, [0, 2, 1, 3]); convert_element_type_279 = None + permute_92 = torch.ops.aten.permute.default(view_291, [0, 2, 1, 3]); view_291 = None + permute_93 = torch.ops.aten.permute.default(view_292, [0, 2, 1, 3]); view_292 = None + _scaled_dot_product_cudnn_attention_backward_23 = torch.ops.aten._scaled_dot_product_cudnn_attention_backward.default(permute_1109, permute_91, permute_92, permute_93, getitem_72, getitem_73, getitem_78, getitem_79, None, None, None, 8192, 8192, 0.0, True); permute_1109 = permute_91 = permute_92 = permute_93 = getitem_72 = getitem_73 = getitem_78 = getitem_79 = None + getitem_357 = _scaled_dot_product_cudnn_attention_backward_23[0] + getitem_358 = _scaled_dot_product_cudnn_attention_backward_23[1] + getitem_359 = _scaled_dot_product_cudnn_attention_backward_23[2]; _scaled_dot_product_cudnn_attention_backward_23 = None + permute_1110 = torch.ops.aten.permute.default(getitem_359, [0, 2, 1, 3]); getitem_359 = None + permute_1111 = torch.ops.aten.permute.default(getitem_358, [0, 2, 1, 3]); getitem_358 = None + permute_1112 = torch.ops.aten.permute.default(getitem_357, [0, 2, 1, 3]); getitem_357 = None + view_1656 = torch.ops.aten.view.default(permute_1110, [2, 8192, 8, 4, 128]); permute_1110 = None + sum_143 = torch.ops.aten.sum.dim_IntList(view_1656, [3], True); view_1656 = None + squeeze_46 = torch.ops.aten.squeeze.dim(sum_143, 3); sum_143 = None + view_1657 = torch.ops.aten.view.default(permute_1111, [2, 8192, 8, 4, 128]); permute_1111 = None + sum_144 = torch.ops.aten.sum.dim_IntList(view_1657, [3], True); view_1657 = None + squeeze_47 = torch.ops.aten.squeeze.dim(sum_144, 3); sum_144 = None + convert_element_type_2345 = torch.ops.prims.convert_element_type.default(squeeze_47, torch.float32); squeeze_47 = None + convert_element_type_2346 = 
torch.ops.prims.convert_element_type.default(permute_1112, torch.float32); permute_1112 = None + view_1658 = torch.ops.aten.view.default(convert_element_type_2345, [2, 8192, 8, 64, 2]); convert_element_type_2345 = None + view_as_complex_110 = torch.ops.aten.view_as_complex.default(view_1658); view_1658 = None + mul_736 = torch.ops.aten.mul.Tensor(view_as_complex_110, _conj); view_as_complex_110 = None + view_1659 = torch.ops.aten.view.default(convert_element_type_2346, [2, 8192, 32, 64, 2]); convert_element_type_2346 = None + view_as_complex_111 = torch.ops.aten.view_as_complex.default(view_1659); view_1659 = None + mul_737 = torch.ops.aten.mul.Tensor(view_as_complex_111, _conj); view_as_complex_111 = None + view_as_real_110 = torch.ops.aten.view_as_real.default(mul_736); mul_736 = None + view_1660 = torch.ops.aten.view.default(view_as_real_110, [2, 8192, 8, 128]); view_as_real_110 = None + convert_element_type_2347 = torch.ops.prims.convert_element_type.default(view_1660, torch.bfloat16); view_1660 = None + view_as_real_111 = torch.ops.aten.view_as_real.default(mul_737); mul_737 = None + view_1661 = torch.ops.aten.view.default(view_as_real_111, [2, 8192, 32, 128]); view_as_real_111 = None + convert_element_type_2348 = torch.ops.prims.convert_element_type.default(view_1661, torch.bfloat16); view_1661 = None + view_1662 = torch.ops.aten.view.default(squeeze_46, [2, 8192, 1024]); squeeze_46 = None + view_1663 = torch.ops.aten.view.default(convert_element_type_2347, [2, 8192, 1024]); convert_element_type_2347 = None + view_1664 = torch.ops.aten.view.default(convert_element_type_2348, [2, 8192, 4096]); convert_element_type_2348 = None + view_1665 = torch.ops.aten.view.default(view_1662, [16384, 1024]); view_1662 = None + permute_1113 = torch.ops.aten.permute.default(view_1665, [1, 0]) + mm_557 = torch.ops.aten.mm.default(permute_1113, view_275); permute_1113 = None + convert_element_type_274 = torch.ops.prims.convert_element_type.default(primals_79, torch.bfloat16); 
primals_79 = None + all_gather_into_tensor_76 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_274, 256, '0'); convert_element_type_274 = None + wait_tensor_76 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_76); all_gather_into_tensor_76 = None + permute_90 = torch.ops.aten.permute.default(wait_tensor_76, [1, 0]); wait_tensor_76 = None + permute_1115 = torch.ops.aten.permute.default(permute_90, [1, 0]); permute_90 = None + mm_558 = torch.ops.aten.mm.default(view_1665, permute_1115); view_1665 = permute_1115 = None + view_1666 = torch.ops.aten.view.default(mm_558, [2, 8192, 4096]); mm_558 = None + convert_element_type_2353 = torch.ops.prims.convert_element_type.default(mm_557, torch.float32); mm_557 = None + reduce_scatter_tensor_214 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2353, 'avg', 256, '0'); convert_element_type_2353 = None + wait_tensor_505 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_214); reduce_scatter_tensor_214 = None + view_1667 = torch.ops.aten.view.default(view_1663, [16384, 1024]); view_1663 = None + permute_1117 = torch.ops.aten.permute.default(view_1667, [1, 0]) + mm_559 = torch.ops.aten.mm.default(permute_1117, view_275); permute_1117 = None + permute_1119 = torch.ops.aten.permute.default(permute_89, [1, 0]); permute_89 = None + mm_560 = torch.ops.aten.mm.default(view_1667, permute_1119); view_1667 = permute_1119 = None + view_1668 = torch.ops.aten.view.default(mm_560, [2, 8192, 4096]); mm_560 = None + add_294 = torch.ops.aten.add.Tensor(view_1666, view_1668); view_1666 = view_1668 = None + convert_element_type_2358 = torch.ops.prims.convert_element_type.default(mm_559, torch.float32); mm_559 = None + reduce_scatter_tensor_215 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2358, 'avg', 256, '0'); convert_element_type_2358 = None + wait_tensor_506 = 
torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_215); reduce_scatter_tensor_215 = None + view_1669 = torch.ops.aten.view.default(view_1664, [16384, 4096]); view_1664 = None + permute_1121 = torch.ops.aten.permute.default(view_1669, [1, 0]) + mm_561 = torch.ops.aten.mm.default(permute_1121, view_275); permute_1121 = view_275 = None + convert_element_type_268 = torch.ops.prims.convert_element_type.default(primals_77, torch.bfloat16); primals_77 = None + all_gather_into_tensor_74 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_268, 256, '0'); convert_element_type_268 = None + wait_tensor_74 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_74); all_gather_into_tensor_74 = None + permute_88 = torch.ops.aten.permute.default(wait_tensor_74, [1, 0]); wait_tensor_74 = None + permute_1123 = torch.ops.aten.permute.default(permute_88, [1, 0]); permute_88 = None + mm_562 = torch.ops.aten.mm.default(view_1669, permute_1123); view_1669 = permute_1123 = None + view_1670 = torch.ops.aten.view.default(mm_562, [2, 8192, 4096]); mm_562 = None + add_295 = torch.ops.aten.add.Tensor(add_294, view_1670); add_294 = view_1670 = None + convert_element_type_2363 = torch.ops.prims.convert_element_type.default(mm_561, torch.float32); mm_561 = None + reduce_scatter_tensor_216 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2363, 'avg', 256, '0'); convert_element_type_2363 = None + wait_tensor_507 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_216); reduce_scatter_tensor_216 = None + convert_element_type_2364 = torch.ops.prims.convert_element_type.default(add_295, torch.float32); add_295 = None + convert_element_type_2366 = torch.ops.prims.convert_element_type.default(wait_tensor_73, torch.float32); wait_tensor_73 = None + mul_738 = torch.ops.aten.mul.Tensor(convert_element_type_2364, convert_element_type_2366); convert_element_type_2366 = None + mul_740 = 
torch.ops.aten.mul.Tensor(mul_64, mul_738) + sum_145 = torch.ops.aten.sum.dim_IntList(mul_740, [2], True); mul_740 = None + div_48 = torch.ops.aten.div.Tensor(mul_64, 4096) + mul_741 = torch.ops.aten.mul.Tensor(div_48, sum_145); div_48 = sum_145 = None + sub_72 = torch.ops.aten.sub.Tensor(mul_738, mul_741); mul_738 = mul_741 = None + mul_742 = torch.ops.aten.mul.Tensor(sub_72, rsqrt_16); sub_72 = rsqrt_16 = None + mul_743 = torch.ops.aten.mul.Tensor(convert_element_type_2364, mul_64); convert_element_type_2364 = mul_64 = None + sum_146 = torch.ops.aten.sum.dim_IntList(mul_743, [0, 1]); mul_743 = None + convert_element_type_2367 = torch.ops.prims.convert_element_type.default(mul_742, torch.bfloat16); mul_742 = None + add_296 = torch.ops.aten.add.Tensor(add_293, convert_element_type_2367); add_293 = convert_element_type_2367 = None + convert_element_type_default_17 = torch.ops.prims.convert_element_type.default(sum_146, torch.float32); sum_146 = None + reduce_scatter_tensor_217 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_17, 'avg', 256, '0'); convert_element_type_default_17 = None + wait_tensor_508 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_217); reduce_scatter_tensor_217 = None + view_1671 = torch.ops.aten.view.default(add_296, [16384, 4096]) + permute_1125 = torch.ops.aten.permute.default(view_1671, [1, 0]) + permute_83 = torch.ops.aten.permute.default(getitem_63, [0, 2, 1, 3]) + view_259 = torch.ops.aten.view.default(permute_83, [2, 8192, -1]); permute_83 = None + convert_element_type_248 = torch.ops.prims.convert_element_type.default(primals_71, torch.bfloat16); primals_71 = None + all_gather_into_tensor_68 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_248, 256, '0'); convert_element_type_248 = None + wait_tensor_68 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_68); all_gather_into_tensor_68 = None + permute_84 = 
torch.ops.aten.permute.default(wait_tensor_68, [1, 0]); wait_tensor_68 = None + view_261 = torch.ops.aten.view.default(view_259, [16384, 4096]); view_259 = None + mm_52 = torch.ops.aten.mm.default(view_261, permute_84) + view_262 = torch.ops.aten.view.default(mm_52, [2, 8192, 4096]); mm_52 = None + add_29 = torch.ops.aten.add.Tensor(add_27, view_262); view_262 = None + convert_element_type_251 = torch.ops.prims.convert_element_type.default(primals_72, torch.bfloat16); primals_72 = None + all_gather_into_tensor_69 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_251, 256, '0'); convert_element_type_251 = None + wait_tensor_69 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_69); all_gather_into_tensor_69 = None + convert_element_type_252 = torch.ops.prims.convert_element_type.default(add_29, torch.float32); add_29 = None + pow_16 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_252, 2) + mean_15 = torch.ops.aten.mean.dim(pow_16, [2], True); pow_16 = None + add_30 = torch.ops.aten.add.Scalar(mean_15, 1e-05); mean_15 = None + rsqrt_15 = torch.ops.aten.rsqrt.default(add_30); add_30 = None + mul_60 = torch.ops.aten.mul.Tensor(convert_element_type_252, rsqrt_15); convert_element_type_252 = None + mul_61 = torch.ops.aten.mul.Tensor(mul_60, wait_tensor_69) + convert_element_type_253 = torch.ops.prims.convert_element_type.default(mul_61, torch.bfloat16); mul_61 = None + view_265 = torch.ops.aten.view.default(convert_element_type_253, [16384, 4096]); convert_element_type_253 = None + view_266 = torch.ops.aten.view.default(mm_53, [2, 8192, 14336]); mm_53 = None + convert_element_type_257 = torch.ops.prims.convert_element_type.default(view_266, torch.float32); view_266 = None + sigmoid_7 = torch.ops.aten.sigmoid.default(convert_element_type_257) + mul_62 = torch.ops.aten.mul.Tensor(convert_element_type_257, sigmoid_7); sigmoid_7 = None + convert_element_type_258 = torch.ops.prims.convert_element_type.default(mul_62, 
torch.bfloat16); mul_62 = None + convert_element_type_259 = torch.ops.prims.convert_element_type.default(primals_74, torch.bfloat16); primals_74 = None + all_gather_into_tensor_71 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_259, 256, '0'); convert_element_type_259 = None + wait_tensor_71 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_71); all_gather_into_tensor_71 = None + permute_86 = torch.ops.aten.permute.default(wait_tensor_71, [1, 0]); wait_tensor_71 = None + mm_54 = torch.ops.aten.mm.default(view_265, permute_86) + view_269 = torch.ops.aten.view.default(mm_54, [2, 8192, 14336]); mm_54 = None + mul_63 = torch.ops.aten.mul.Tensor(convert_element_type_258, view_269) + view_271 = torch.ops.aten.view.default(mul_63, [16384, 14336]); mul_63 = None + mm_563 = torch.ops.aten.mm.default(permute_1125, view_271); permute_1125 = view_271 = None + convert_element_type_262 = torch.ops.prims.convert_element_type.default(primals_75, torch.bfloat16); primals_75 = None + all_gather_into_tensor_72 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_262, 256, '0'); convert_element_type_262 = None + wait_tensor_72 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_72); all_gather_into_tensor_72 = None + permute_87 = torch.ops.aten.permute.default(wait_tensor_72, [1, 0]); wait_tensor_72 = None + permute_1127 = torch.ops.aten.permute.default(permute_87, [1, 0]); permute_87 = None + mm_564 = torch.ops.aten.mm.default(view_1671, permute_1127); view_1671 = permute_1127 = None + view_1672 = torch.ops.aten.view.default(mm_564, [2, 8192, 14336]); mm_564 = None + convert_element_type_2374 = torch.ops.prims.convert_element_type.default(mm_563, torch.float32); mm_563 = None + reduce_scatter_tensor_218 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2374, 'avg', 256, '0'); convert_element_type_2374 = None + wait_tensor_509 = 
torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_218); reduce_scatter_tensor_218 = None + mul_744 = torch.ops.aten.mul.Tensor(view_1672, convert_element_type_258); convert_element_type_258 = None + mul_745 = torch.ops.aten.mul.Tensor(view_1672, view_269); view_1672 = view_269 = None + view_1673 = torch.ops.aten.view.default(mul_744, [16384, 14336]); mul_744 = None + permute_1129 = torch.ops.aten.permute.default(view_1673, [1, 0]) + mm_565 = torch.ops.aten.mm.default(permute_1129, view_265); permute_1129 = None + permute_1131 = torch.ops.aten.permute.default(permute_86, [1, 0]); permute_86 = None + mm_566 = torch.ops.aten.mm.default(view_1673, permute_1131); view_1673 = permute_1131 = None + view_1674 = torch.ops.aten.view.default(mm_566, [2, 8192, 4096]); mm_566 = None + convert_element_type_2379 = torch.ops.prims.convert_element_type.default(mm_565, torch.float32); mm_565 = None + reduce_scatter_tensor_219 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2379, 'avg', 256, '0'); convert_element_type_2379 = None + wait_tensor_510 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_219); reduce_scatter_tensor_219 = None + convert_element_type_2380 = torch.ops.prims.convert_element_type.default(mul_745, torch.float32); mul_745 = None + neg_24 = torch.ops.aten.neg.default(convert_element_type_257) + exp_24 = torch.ops.aten.exp.default(neg_24); neg_24 = None + add_297 = torch.ops.aten.add.Tensor(exp_24, 1); exp_24 = None + reciprocal_24 = torch.ops.aten.reciprocal.default(add_297); add_297 = None + mul_746 = torch.ops.aten.mul.Tensor(reciprocal_24, 1); reciprocal_24 = None + mul_747 = torch.ops.aten.mul.Tensor(convert_element_type_2380, mul_746); convert_element_type_2380 = None + sub_73 = torch.ops.aten.sub.Tensor(1, mul_746); mul_746 = None + mul_748 = torch.ops.aten.mul.Tensor(convert_element_type_257, sub_73); convert_element_type_257 = sub_73 = None + add_298 = torch.ops.aten.add.Tensor(mul_748, 
1); mul_748 = None + mul_749 = torch.ops.aten.mul.Tensor(mul_747, add_298); mul_747 = add_298 = None + convert_element_type_2382 = torch.ops.prims.convert_element_type.default(mul_749, torch.bfloat16); mul_749 = None + view_1675 = torch.ops.aten.view.default(convert_element_type_2382, [16384, 14336]); convert_element_type_2382 = None + permute_1133 = torch.ops.aten.permute.default(view_1675, [1, 0]) + mm_567 = torch.ops.aten.mm.default(permute_1133, view_265); permute_1133 = view_265 = None + convert_element_type_254 = torch.ops.prims.convert_element_type.default(primals_73, torch.bfloat16); primals_73 = None + all_gather_into_tensor_70 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_254, 256, '0'); convert_element_type_254 = None + wait_tensor_70 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_70); all_gather_into_tensor_70 = None + permute_85 = torch.ops.aten.permute.default(wait_tensor_70, [1, 0]); wait_tensor_70 = None + permute_1135 = torch.ops.aten.permute.default(permute_85, [1, 0]); permute_85 = None + mm_568 = torch.ops.aten.mm.default(view_1675, permute_1135); view_1675 = permute_1135 = None + view_1676 = torch.ops.aten.view.default(mm_568, [2, 8192, 4096]); mm_568 = None + add_299 = torch.ops.aten.add.Tensor(view_1674, view_1676); view_1674 = view_1676 = None + convert_element_type_2387 = torch.ops.prims.convert_element_type.default(mm_567, torch.float32); mm_567 = None + reduce_scatter_tensor_220 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2387, 'avg', 256, '0'); convert_element_type_2387 = None + wait_tensor_511 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_220); reduce_scatter_tensor_220 = None + convert_element_type_2388 = torch.ops.prims.convert_element_type.default(add_299, torch.float32); add_299 = None + convert_element_type_2390 = torch.ops.prims.convert_element_type.default(wait_tensor_69, torch.float32); wait_tensor_69 = None + 
mul_750 = torch.ops.aten.mul.Tensor(convert_element_type_2388, convert_element_type_2390); convert_element_type_2390 = None + mul_752 = torch.ops.aten.mul.Tensor(mul_60, mul_750) + sum_147 = torch.ops.aten.sum.dim_IntList(mul_752, [2], True); mul_752 = None + div_49 = torch.ops.aten.div.Tensor(mul_60, 4096) + mul_753 = torch.ops.aten.mul.Tensor(div_49, sum_147); div_49 = sum_147 = None + sub_74 = torch.ops.aten.sub.Tensor(mul_750, mul_753); mul_750 = mul_753 = None + mul_754 = torch.ops.aten.mul.Tensor(sub_74, rsqrt_15); sub_74 = rsqrt_15 = None + mul_755 = torch.ops.aten.mul.Tensor(convert_element_type_2388, mul_60); convert_element_type_2388 = mul_60 = None + sum_148 = torch.ops.aten.sum.dim_IntList(mul_755, [0, 1]); mul_755 = None + convert_element_type_2391 = torch.ops.prims.convert_element_type.default(mul_754, torch.bfloat16); mul_754 = None + add_300 = torch.ops.aten.add.Tensor(add_296, convert_element_type_2391); add_296 = convert_element_type_2391 = None + convert_element_type_default_16 = torch.ops.prims.convert_element_type.default(sum_148, torch.float32); sum_148 = None + reduce_scatter_tensor_221 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_16, 'avg', 256, '0'); convert_element_type_default_16 = None + wait_tensor_512 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_221); reduce_scatter_tensor_221 = None + view_1677 = torch.ops.aten.view.default(add_300, [16384, 4096]) + permute_1137 = torch.ops.aten.permute.default(view_1677, [1, 0]) + mm_569 = torch.ops.aten.mm.default(permute_1137, view_261); permute_1137 = view_261 = None + permute_1139 = torch.ops.aten.permute.default(permute_84, [1, 0]); permute_84 = None + mm_570 = torch.ops.aten.mm.default(view_1677, permute_1139); view_1677 = permute_1139 = None + view_1678 = torch.ops.aten.view.default(mm_570, [2, 8192, 4096]); mm_570 = None + convert_element_type_2398 = torch.ops.prims.convert_element_type.default(mm_569, torch.float32); mm_569 
= None + reduce_scatter_tensor_222 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2398, 'avg', 256, '0'); convert_element_type_2398 = None + wait_tensor_513 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_222); reduce_scatter_tensor_222 = None + view_1679 = torch.ops.aten.view.default(view_1678, [2, 8192, 32, 128]); view_1678 = None + permute_1141 = torch.ops.aten.permute.default(view_1679, [0, 2, 1, 3]); view_1679 = None + convert_element_type_232 = torch.ops.prims.convert_element_type.default(primals_67, torch.bfloat16); primals_67 = None + all_gather_into_tensor_64 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_232, 256, '0'); convert_element_type_232 = None + wait_tensor_64 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_64); all_gather_into_tensor_64 = None + convert_element_type_233 = torch.ops.prims.convert_element_type.default(add_27, torch.float32); add_27 = None + pow_15 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_233, 2) + mean_14 = torch.ops.aten.mean.dim(pow_15, [2], True); pow_15 = None + add_28 = torch.ops.aten.add.Scalar(mean_14, 1e-05); mean_14 = None + rsqrt_14 = torch.ops.aten.rsqrt.default(add_28); add_28 = None + mul_56 = torch.ops.aten.mul.Tensor(convert_element_type_233, rsqrt_14); convert_element_type_233 = None + mul_57 = torch.ops.aten.mul.Tensor(mul_56, wait_tensor_64) + convert_element_type_234 = torch.ops.prims.convert_element_type.default(mul_57, torch.bfloat16); mul_57 = None + view_241 = torch.ops.aten.view.default(convert_element_type_234, [16384, 4096]); convert_element_type_234 = None + view_242 = torch.ops.aten.view.default(mm_49, [2, 8192, 4096]); mm_49 = None + convert_element_type_238 = torch.ops.prims.convert_element_type.default(primals_69, torch.bfloat16); primals_69 = None + all_gather_into_tensor_66 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_238, 256, '0'); 
convert_element_type_238 = None + wait_tensor_66 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_66); all_gather_into_tensor_66 = None + permute_78 = torch.ops.aten.permute.default(wait_tensor_66, [1, 0]); wait_tensor_66 = None + mm_50 = torch.ops.aten.mm.default(view_241, permute_78) + view_245 = torch.ops.aten.view.default(mm_50, [2, 8192, 1024]); mm_50 = None + view_248 = torch.ops.aten.view.default(mm_51, [2, 8192, 1024]); mm_51 = None + view_249 = torch.ops.aten.view.default(view_242, [2, 8192, -1, 128]); view_242 = None + view_250 = torch.ops.aten.view.default(view_245, [2, 8192, -1, 128]); view_245 = None + view_251 = torch.ops.aten.view.default(view_248, [2, 8192, -1, 128]); view_248 = None + convert_element_type_244 = torch.ops.prims.convert_element_type.default(view_249, torch.float32); view_249 = None + view_252 = torch.ops.aten.view.default(convert_element_type_244, [2, 8192, 32, -1, 2]); convert_element_type_244 = None + view_as_complex_14 = torch.ops.aten.view_as_complex.default(view_252); view_252 = None + convert_element_type_245 = torch.ops.prims.convert_element_type.default(view_250, torch.float32); view_250 = None + view_253 = torch.ops.aten.view.default(convert_element_type_245, [2, 8192, 8, -1, 2]); convert_element_type_245 = None + view_as_complex_15 = torch.ops.aten.view_as_complex.default(view_253); view_253 = None + mul_58 = torch.ops.aten.mul.Tensor(view_as_complex_14, view_16); view_as_complex_14 = None + view_as_real_14 = torch.ops.aten.view_as_real.default(mul_58); mul_58 = None + view_255 = torch.ops.aten.view.default(view_as_real_14, [2, 8192, 32, 128]); view_as_real_14 = None + mul_59 = torch.ops.aten.mul.Tensor(view_as_complex_15, view_16); view_as_complex_15 = None + view_as_real_15 = torch.ops.aten.view_as_real.default(mul_59); mul_59 = None + view_256 = torch.ops.aten.view.default(view_as_real_15, [2, 8192, 8, 128]); view_as_real_15 = None + convert_element_type_246 = 
torch.ops.prims.convert_element_type.default(view_255, torch.bfloat16); view_255 = None + convert_element_type_247 = torch.ops.prims.convert_element_type.default(view_256, torch.bfloat16); view_256 = None + unsqueeze_14 = torch.ops.aten.unsqueeze.default(convert_element_type_247, 3); convert_element_type_247 = None + expand_14 = torch.ops.aten.expand.default(unsqueeze_14, [2, 8192, 8, 4, 128]); unsqueeze_14 = None + clone_14 = torch.ops.aten.clone.default(expand_14, memory_format = torch.contiguous_format); expand_14 = None + view_257 = torch.ops.aten.view.default(clone_14, [2, 8192, 32, 128]); clone_14 = None + unsqueeze_15 = torch.ops.aten.unsqueeze.default(view_251, 3); view_251 = None + expand_15 = torch.ops.aten.expand.default(unsqueeze_15, [2, 8192, 8, 4, 128]); unsqueeze_15 = None + clone_15 = torch.ops.aten.clone.default(expand_15, memory_format = torch.contiguous_format); expand_15 = None + view_258 = torch.ops.aten.view.default(clone_15, [2, 8192, 32, 128]); clone_15 = None + permute_80 = torch.ops.aten.permute.default(convert_element_type_246, [0, 2, 1, 3]); convert_element_type_246 = None + permute_81 = torch.ops.aten.permute.default(view_257, [0, 2, 1, 3]); view_257 = None + permute_82 = torch.ops.aten.permute.default(view_258, [0, 2, 1, 3]); view_258 = None + _scaled_dot_product_cudnn_attention_backward_24 = torch.ops.aten._scaled_dot_product_cudnn_attention_backward.default(permute_1141, permute_80, permute_81, permute_82, getitem_63, getitem_64, getitem_69, getitem_70, None, None, None, 8192, 8192, 0.0, True); permute_1141 = permute_80 = permute_81 = permute_82 = getitem_63 = getitem_64 = getitem_69 = getitem_70 = None + getitem_360 = _scaled_dot_product_cudnn_attention_backward_24[0] + getitem_361 = _scaled_dot_product_cudnn_attention_backward_24[1] + getitem_362 = _scaled_dot_product_cudnn_attention_backward_24[2]; _scaled_dot_product_cudnn_attention_backward_24 = None + permute_1142 = torch.ops.aten.permute.default(getitem_362, [0, 2, 1, 3]); 
getitem_362 = None + permute_1143 = torch.ops.aten.permute.default(getitem_361, [0, 2, 1, 3]); getitem_361 = None + permute_1144 = torch.ops.aten.permute.default(getitem_360, [0, 2, 1, 3]); getitem_360 = None + view_1680 = torch.ops.aten.view.default(permute_1142, [2, 8192, 8, 4, 128]); permute_1142 = None + sum_149 = torch.ops.aten.sum.dim_IntList(view_1680, [3], True); view_1680 = None + squeeze_48 = torch.ops.aten.squeeze.dim(sum_149, 3); sum_149 = None + view_1681 = torch.ops.aten.view.default(permute_1143, [2, 8192, 8, 4, 128]); permute_1143 = None + sum_150 = torch.ops.aten.sum.dim_IntList(view_1681, [3], True); view_1681 = None + squeeze_49 = torch.ops.aten.squeeze.dim(sum_150, 3); sum_150 = None + convert_element_type_2399 = torch.ops.prims.convert_element_type.default(squeeze_49, torch.float32); squeeze_49 = None + convert_element_type_2400 = torch.ops.prims.convert_element_type.default(permute_1144, torch.float32); permute_1144 = None + view_1682 = torch.ops.aten.view.default(convert_element_type_2399, [2, 8192, 8, 64, 2]); convert_element_type_2399 = None + view_as_complex_112 = torch.ops.aten.view_as_complex.default(view_1682); view_1682 = None + mul_756 = torch.ops.aten.mul.Tensor(view_as_complex_112, _conj); view_as_complex_112 = None + view_1683 = torch.ops.aten.view.default(convert_element_type_2400, [2, 8192, 32, 64, 2]); convert_element_type_2400 = None + view_as_complex_113 = torch.ops.aten.view_as_complex.default(view_1683); view_1683 = None + mul_757 = torch.ops.aten.mul.Tensor(view_as_complex_113, _conj); view_as_complex_113 = None + view_as_real_112 = torch.ops.aten.view_as_real.default(mul_756); mul_756 = None + view_1684 = torch.ops.aten.view.default(view_as_real_112, [2, 8192, 8, 128]); view_as_real_112 = None + convert_element_type_2401 = torch.ops.prims.convert_element_type.default(view_1684, torch.bfloat16); view_1684 = None + view_as_real_113 = torch.ops.aten.view_as_real.default(mul_757); mul_757 = None + view_1685 = 
torch.ops.aten.view.default(view_as_real_113, [2, 8192, 32, 128]); view_as_real_113 = None + convert_element_type_2402 = torch.ops.prims.convert_element_type.default(view_1685, torch.bfloat16); view_1685 = None + view_1686 = torch.ops.aten.view.default(squeeze_48, [2, 8192, 1024]); squeeze_48 = None + view_1687 = torch.ops.aten.view.default(convert_element_type_2401, [2, 8192, 1024]); convert_element_type_2401 = None + view_1688 = torch.ops.aten.view.default(convert_element_type_2402, [2, 8192, 4096]); convert_element_type_2402 = None + view_1689 = torch.ops.aten.view.default(view_1686, [16384, 1024]); view_1686 = None + permute_1145 = torch.ops.aten.permute.default(view_1689, [1, 0]) + mm_571 = torch.ops.aten.mm.default(permute_1145, view_241); permute_1145 = None + convert_element_type_241 = torch.ops.prims.convert_element_type.default(primals_70, torch.bfloat16); primals_70 = None + all_gather_into_tensor_67 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_241, 256, '0'); convert_element_type_241 = None + wait_tensor_67 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_67); all_gather_into_tensor_67 = None + permute_79 = torch.ops.aten.permute.default(wait_tensor_67, [1, 0]); wait_tensor_67 = None + permute_1147 = torch.ops.aten.permute.default(permute_79, [1, 0]); permute_79 = None + mm_572 = torch.ops.aten.mm.default(view_1689, permute_1147); view_1689 = permute_1147 = None + view_1690 = torch.ops.aten.view.default(mm_572, [2, 8192, 4096]); mm_572 = None + convert_element_type_2407 = torch.ops.prims.convert_element_type.default(mm_571, torch.float32); mm_571 = None + reduce_scatter_tensor_223 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2407, 'avg', 256, '0'); convert_element_type_2407 = None + wait_tensor_514 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_223); reduce_scatter_tensor_223 = None + view_1691 = torch.ops.aten.view.default(view_1687, 
[16384, 1024]); view_1687 = None + permute_1149 = torch.ops.aten.permute.default(view_1691, [1, 0]) + mm_573 = torch.ops.aten.mm.default(permute_1149, view_241); permute_1149 = None + permute_1151 = torch.ops.aten.permute.default(permute_78, [1, 0]); permute_78 = None + mm_574 = torch.ops.aten.mm.default(view_1691, permute_1151); view_1691 = permute_1151 = None + view_1692 = torch.ops.aten.view.default(mm_574, [2, 8192, 4096]); mm_574 = None + add_301 = torch.ops.aten.add.Tensor(view_1690, view_1692); view_1690 = view_1692 = None + convert_element_type_2412 = torch.ops.prims.convert_element_type.default(mm_573, torch.float32); mm_573 = None + reduce_scatter_tensor_224 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2412, 'avg', 256, '0'); convert_element_type_2412 = None + wait_tensor_515 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_224); reduce_scatter_tensor_224 = None + view_1693 = torch.ops.aten.view.default(view_1688, [16384, 4096]); view_1688 = None + permute_1153 = torch.ops.aten.permute.default(view_1693, [1, 0]) + mm_575 = torch.ops.aten.mm.default(permute_1153, view_241); permute_1153 = view_241 = None + convert_element_type_235 = torch.ops.prims.convert_element_type.default(primals_68, torch.bfloat16); primals_68 = None + all_gather_into_tensor_65 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_235, 256, '0'); convert_element_type_235 = None + wait_tensor_65 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_65); all_gather_into_tensor_65 = None + permute_77 = torch.ops.aten.permute.default(wait_tensor_65, [1, 0]); wait_tensor_65 = None + permute_1155 = torch.ops.aten.permute.default(permute_77, [1, 0]); permute_77 = None + mm_576 = torch.ops.aten.mm.default(view_1693, permute_1155); view_1693 = permute_1155 = None + view_1694 = torch.ops.aten.view.default(mm_576, [2, 8192, 4096]); mm_576 = None + add_302 = torch.ops.aten.add.Tensor(add_301, 
view_1694); add_301 = view_1694 = None + convert_element_type_2417 = torch.ops.prims.convert_element_type.default(mm_575, torch.float32); mm_575 = None + reduce_scatter_tensor_225 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2417, 'avg', 256, '0'); convert_element_type_2417 = None + wait_tensor_516 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_225); reduce_scatter_tensor_225 = None + convert_element_type_2418 = torch.ops.prims.convert_element_type.default(add_302, torch.float32); add_302 = None + convert_element_type_2420 = torch.ops.prims.convert_element_type.default(wait_tensor_64, torch.float32); wait_tensor_64 = None + mul_758 = torch.ops.aten.mul.Tensor(convert_element_type_2418, convert_element_type_2420); convert_element_type_2420 = None + mul_760 = torch.ops.aten.mul.Tensor(mul_56, mul_758) + sum_151 = torch.ops.aten.sum.dim_IntList(mul_760, [2], True); mul_760 = None + div_50 = torch.ops.aten.div.Tensor(mul_56, 4096) + mul_761 = torch.ops.aten.mul.Tensor(div_50, sum_151); div_50 = sum_151 = None + sub_75 = torch.ops.aten.sub.Tensor(mul_758, mul_761); mul_758 = mul_761 = None + mul_762 = torch.ops.aten.mul.Tensor(sub_75, rsqrt_14); sub_75 = rsqrt_14 = None + mul_763 = torch.ops.aten.mul.Tensor(convert_element_type_2418, mul_56); convert_element_type_2418 = mul_56 = None + sum_152 = torch.ops.aten.sum.dim_IntList(mul_763, [0, 1]); mul_763 = None + convert_element_type_2421 = torch.ops.prims.convert_element_type.default(mul_762, torch.bfloat16); mul_762 = None + add_303 = torch.ops.aten.add.Tensor(add_300, convert_element_type_2421); add_300 = convert_element_type_2421 = None + convert_element_type_default_15 = torch.ops.prims.convert_element_type.default(sum_152, torch.float32); sum_152 = None + reduce_scatter_tensor_226 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_15, 'avg', 256, '0'); convert_element_type_default_15 = None + wait_tensor_517 = 
torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_226); reduce_scatter_tensor_226 = None + view_1695 = torch.ops.aten.view.default(add_303, [16384, 4096]) + permute_1157 = torch.ops.aten.permute.default(view_1695, [1, 0]) + permute_72 = torch.ops.aten.permute.default(getitem_54, [0, 2, 1, 3]) + view_225 = torch.ops.aten.view.default(permute_72, [2, 8192, -1]); permute_72 = None + convert_element_type_215 = torch.ops.prims.convert_element_type.default(primals_62, torch.bfloat16); primals_62 = None + all_gather_into_tensor_59 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_215, 256, '0'); convert_element_type_215 = None + wait_tensor_59 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_59); all_gather_into_tensor_59 = None + permute_73 = torch.ops.aten.permute.default(wait_tensor_59, [1, 0]); wait_tensor_59 = None + view_227 = torch.ops.aten.view.default(view_225, [16384, 4096]); view_225 = None + mm_45 = torch.ops.aten.mm.default(view_227, permute_73) + view_228 = torch.ops.aten.view.default(mm_45, [2, 8192, 4096]); mm_45 = None + add_25 = torch.ops.aten.add.Tensor(add_23, view_228); view_228 = None + convert_element_type_218 = torch.ops.prims.convert_element_type.default(primals_63, torch.bfloat16); primals_63 = None + all_gather_into_tensor_60 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_218, 256, '0'); convert_element_type_218 = None + wait_tensor_60 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_60); all_gather_into_tensor_60 = None + convert_element_type_219 = torch.ops.prims.convert_element_type.default(add_25, torch.float32); add_25 = None + pow_14 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_219, 2) + mean_13 = torch.ops.aten.mean.dim(pow_14, [2], True); pow_14 = None + add_26 = torch.ops.aten.add.Scalar(mean_13, 1e-05); mean_13 = None + rsqrt_13 = torch.ops.aten.rsqrt.default(add_26); add_26 = None + mul_52 = 
torch.ops.aten.mul.Tensor(convert_element_type_219, rsqrt_13); convert_element_type_219 = None + mul_53 = torch.ops.aten.mul.Tensor(mul_52, wait_tensor_60) + convert_element_type_220 = torch.ops.prims.convert_element_type.default(mul_53, torch.bfloat16); mul_53 = None + view_231 = torch.ops.aten.view.default(convert_element_type_220, [16384, 4096]); convert_element_type_220 = None + view_232 = torch.ops.aten.view.default(mm_46, [2, 8192, 14336]); mm_46 = None + convert_element_type_224 = torch.ops.prims.convert_element_type.default(view_232, torch.float32); view_232 = None + sigmoid_6 = torch.ops.aten.sigmoid.default(convert_element_type_224) + mul_54 = torch.ops.aten.mul.Tensor(convert_element_type_224, sigmoid_6); sigmoid_6 = None + convert_element_type_225 = torch.ops.prims.convert_element_type.default(mul_54, torch.bfloat16); mul_54 = None + convert_element_type_226 = torch.ops.prims.convert_element_type.default(primals_65, torch.bfloat16); primals_65 = None + all_gather_into_tensor_62 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_226, 256, '0'); convert_element_type_226 = None + wait_tensor_62 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_62); all_gather_into_tensor_62 = None + permute_75 = torch.ops.aten.permute.default(wait_tensor_62, [1, 0]); wait_tensor_62 = None + mm_47 = torch.ops.aten.mm.default(view_231, permute_75) + view_235 = torch.ops.aten.view.default(mm_47, [2, 8192, 14336]); mm_47 = None + mul_55 = torch.ops.aten.mul.Tensor(convert_element_type_225, view_235) + view_237 = torch.ops.aten.view.default(mul_55, [16384, 14336]); mul_55 = None + mm_577 = torch.ops.aten.mm.default(permute_1157, view_237); permute_1157 = view_237 = None + convert_element_type_229 = torch.ops.prims.convert_element_type.default(primals_66, torch.bfloat16); primals_66 = None + all_gather_into_tensor_63 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_229, 256, '0'); 
convert_element_type_229 = None + wait_tensor_63 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_63); all_gather_into_tensor_63 = None + permute_76 = torch.ops.aten.permute.default(wait_tensor_63, [1, 0]); wait_tensor_63 = None + permute_1159 = torch.ops.aten.permute.default(permute_76, [1, 0]); permute_76 = None + mm_578 = torch.ops.aten.mm.default(view_1695, permute_1159); view_1695 = permute_1159 = None + view_1696 = torch.ops.aten.view.default(mm_578, [2, 8192, 14336]); mm_578 = None + convert_element_type_2428 = torch.ops.prims.convert_element_type.default(mm_577, torch.float32); mm_577 = None + reduce_scatter_tensor_227 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2428, 'avg', 256, '0'); convert_element_type_2428 = None + wait_tensor_518 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_227); reduce_scatter_tensor_227 = None + mul_764 = torch.ops.aten.mul.Tensor(view_1696, convert_element_type_225); convert_element_type_225 = None + mul_765 = torch.ops.aten.mul.Tensor(view_1696, view_235); view_1696 = view_235 = None + view_1697 = torch.ops.aten.view.default(mul_764, [16384, 14336]); mul_764 = None + permute_1161 = torch.ops.aten.permute.default(view_1697, [1, 0]) + mm_579 = torch.ops.aten.mm.default(permute_1161, view_231); permute_1161 = None + permute_1163 = torch.ops.aten.permute.default(permute_75, [1, 0]); permute_75 = None + mm_580 = torch.ops.aten.mm.default(view_1697, permute_1163); view_1697 = permute_1163 = None + view_1698 = torch.ops.aten.view.default(mm_580, [2, 8192, 4096]); mm_580 = None + convert_element_type_2433 = torch.ops.prims.convert_element_type.default(mm_579, torch.float32); mm_579 = None + reduce_scatter_tensor_228 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2433, 'avg', 256, '0'); convert_element_type_2433 = None + wait_tensor_519 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_228); 
reduce_scatter_tensor_228 = None + convert_element_type_2434 = torch.ops.prims.convert_element_type.default(mul_765, torch.float32); mul_765 = None + neg_25 = torch.ops.aten.neg.default(convert_element_type_224) + exp_25 = torch.ops.aten.exp.default(neg_25); neg_25 = None + add_304 = torch.ops.aten.add.Tensor(exp_25, 1); exp_25 = None + reciprocal_25 = torch.ops.aten.reciprocal.default(add_304); add_304 = None + mul_766 = torch.ops.aten.mul.Tensor(reciprocal_25, 1); reciprocal_25 = None + mul_767 = torch.ops.aten.mul.Tensor(convert_element_type_2434, mul_766); convert_element_type_2434 = None + sub_76 = torch.ops.aten.sub.Tensor(1, mul_766); mul_766 = None + mul_768 = torch.ops.aten.mul.Tensor(convert_element_type_224, sub_76); convert_element_type_224 = sub_76 = None + add_305 = torch.ops.aten.add.Tensor(mul_768, 1); mul_768 = None + mul_769 = torch.ops.aten.mul.Tensor(mul_767, add_305); mul_767 = add_305 = None + convert_element_type_2436 = torch.ops.prims.convert_element_type.default(mul_769, torch.bfloat16); mul_769 = None + view_1699 = torch.ops.aten.view.default(convert_element_type_2436, [16384, 14336]); convert_element_type_2436 = None + permute_1165 = torch.ops.aten.permute.default(view_1699, [1, 0]) + mm_581 = torch.ops.aten.mm.default(permute_1165, view_231); permute_1165 = view_231 = None + convert_element_type_221 = torch.ops.prims.convert_element_type.default(primals_64, torch.bfloat16); primals_64 = None + all_gather_into_tensor_61 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_221, 256, '0'); convert_element_type_221 = None + wait_tensor_61 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_61); all_gather_into_tensor_61 = None + permute_74 = torch.ops.aten.permute.default(wait_tensor_61, [1, 0]); wait_tensor_61 = None + permute_1167 = torch.ops.aten.permute.default(permute_74, [1, 0]); permute_74 = None + mm_582 = torch.ops.aten.mm.default(view_1699, permute_1167); view_1699 = permute_1167 = 
None + view_1700 = torch.ops.aten.view.default(mm_582, [2, 8192, 4096]); mm_582 = None + add_306 = torch.ops.aten.add.Tensor(view_1698, view_1700); view_1698 = view_1700 = None + convert_element_type_2441 = torch.ops.prims.convert_element_type.default(mm_581, torch.float32); mm_581 = None + reduce_scatter_tensor_229 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2441, 'avg', 256, '0'); convert_element_type_2441 = None + wait_tensor_520 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_229); reduce_scatter_tensor_229 = None + convert_element_type_2442 = torch.ops.prims.convert_element_type.default(add_306, torch.float32); add_306 = None + convert_element_type_2444 = torch.ops.prims.convert_element_type.default(wait_tensor_60, torch.float32); wait_tensor_60 = None + mul_770 = torch.ops.aten.mul.Tensor(convert_element_type_2442, convert_element_type_2444); convert_element_type_2444 = None + mul_772 = torch.ops.aten.mul.Tensor(mul_52, mul_770) + sum_153 = torch.ops.aten.sum.dim_IntList(mul_772, [2], True); mul_772 = None + div_51 = torch.ops.aten.div.Tensor(mul_52, 4096) + mul_773 = torch.ops.aten.mul.Tensor(div_51, sum_153); div_51 = sum_153 = None + sub_77 = torch.ops.aten.sub.Tensor(mul_770, mul_773); mul_770 = mul_773 = None + mul_774 = torch.ops.aten.mul.Tensor(sub_77, rsqrt_13); sub_77 = rsqrt_13 = None + mul_775 = torch.ops.aten.mul.Tensor(convert_element_type_2442, mul_52); convert_element_type_2442 = mul_52 = None + sum_154 = torch.ops.aten.sum.dim_IntList(mul_775, [0, 1]); mul_775 = None + convert_element_type_2445 = torch.ops.prims.convert_element_type.default(mul_774, torch.bfloat16); mul_774 = None + add_307 = torch.ops.aten.add.Tensor(add_303, convert_element_type_2445); add_303 = convert_element_type_2445 = None + convert_element_type_default_14 = torch.ops.prims.convert_element_type.default(sum_154, torch.float32); sum_154 = None + reduce_scatter_tensor_230 = 
torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_14, 'avg', 256, '0'); convert_element_type_default_14 = None + wait_tensor_521 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_230); reduce_scatter_tensor_230 = None + view_1701 = torch.ops.aten.view.default(add_307, [16384, 4096]) + permute_1169 = torch.ops.aten.permute.default(view_1701, [1, 0]) + mm_583 = torch.ops.aten.mm.default(permute_1169, view_227); permute_1169 = view_227 = None + permute_1171 = torch.ops.aten.permute.default(permute_73, [1, 0]); permute_73 = None + mm_584 = torch.ops.aten.mm.default(view_1701, permute_1171); view_1701 = permute_1171 = None + view_1702 = torch.ops.aten.view.default(mm_584, [2, 8192, 4096]); mm_584 = None + convert_element_type_2452 = torch.ops.prims.convert_element_type.default(mm_583, torch.float32); mm_583 = None + reduce_scatter_tensor_231 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2452, 'avg', 256, '0'); convert_element_type_2452 = None + wait_tensor_522 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_231); reduce_scatter_tensor_231 = None + view_1703 = torch.ops.aten.view.default(view_1702, [2, 8192, 32, 128]); view_1702 = None + permute_1173 = torch.ops.aten.permute.default(view_1703, [0, 2, 1, 3]); view_1703 = None + convert_element_type_199 = torch.ops.prims.convert_element_type.default(primals_58, torch.bfloat16); primals_58 = None + all_gather_into_tensor_55 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_199, 256, '0'); convert_element_type_199 = None + wait_tensor_55 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_55); all_gather_into_tensor_55 = None + convert_element_type_200 = torch.ops.prims.convert_element_type.default(add_23, torch.float32); add_23 = None + pow_13 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_200, 2) + mean_12 = torch.ops.aten.mean.dim(pow_13, [2], True); 
pow_13 = None + add_24 = torch.ops.aten.add.Scalar(mean_12, 1e-05); mean_12 = None + rsqrt_12 = torch.ops.aten.rsqrt.default(add_24); add_24 = None + mul_48 = torch.ops.aten.mul.Tensor(convert_element_type_200, rsqrt_12); convert_element_type_200 = None + mul_49 = torch.ops.aten.mul.Tensor(mul_48, wait_tensor_55) + convert_element_type_201 = torch.ops.prims.convert_element_type.default(mul_49, torch.bfloat16); mul_49 = None + view_207 = torch.ops.aten.view.default(convert_element_type_201, [16384, 4096]); convert_element_type_201 = None + view_208 = torch.ops.aten.view.default(mm_42, [2, 8192, 4096]); mm_42 = None + convert_element_type_205 = torch.ops.prims.convert_element_type.default(primals_60, torch.bfloat16); primals_60 = None + all_gather_into_tensor_57 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_205, 256, '0'); convert_element_type_205 = None + wait_tensor_57 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_57); all_gather_into_tensor_57 = None + permute_67 = torch.ops.aten.permute.default(wait_tensor_57, [1, 0]); wait_tensor_57 = None + mm_43 = torch.ops.aten.mm.default(view_207, permute_67) + view_211 = torch.ops.aten.view.default(mm_43, [2, 8192, 1024]); mm_43 = None + view_214 = torch.ops.aten.view.default(mm_44, [2, 8192, 1024]); mm_44 = None + view_215 = torch.ops.aten.view.default(view_208, [2, 8192, -1, 128]); view_208 = None + view_216 = torch.ops.aten.view.default(view_211, [2, 8192, -1, 128]); view_211 = None + view_217 = torch.ops.aten.view.default(view_214, [2, 8192, -1, 128]); view_214 = None + convert_element_type_211 = torch.ops.prims.convert_element_type.default(view_215, torch.float32); view_215 = None + view_218 = torch.ops.aten.view.default(convert_element_type_211, [2, 8192, 32, -1, 2]); convert_element_type_211 = None + view_as_complex_12 = torch.ops.aten.view_as_complex.default(view_218); view_218 = None + convert_element_type_212 = 
torch.ops.prims.convert_element_type.default(view_216, torch.float32); view_216 = None + view_219 = torch.ops.aten.view.default(convert_element_type_212, [2, 8192, 8, -1, 2]); convert_element_type_212 = None + view_as_complex_13 = torch.ops.aten.view_as_complex.default(view_219); view_219 = None + mul_50 = torch.ops.aten.mul.Tensor(view_as_complex_12, view_16); view_as_complex_12 = None + view_as_real_12 = torch.ops.aten.view_as_real.default(mul_50); mul_50 = None + view_221 = torch.ops.aten.view.default(view_as_real_12, [2, 8192, 32, 128]); view_as_real_12 = None + mul_51 = torch.ops.aten.mul.Tensor(view_as_complex_13, view_16); view_as_complex_13 = None + view_as_real_13 = torch.ops.aten.view_as_real.default(mul_51); mul_51 = None + view_222 = torch.ops.aten.view.default(view_as_real_13, [2, 8192, 8, 128]); view_as_real_13 = None + convert_element_type_213 = torch.ops.prims.convert_element_type.default(view_221, torch.bfloat16); view_221 = None + convert_element_type_214 = torch.ops.prims.convert_element_type.default(view_222, torch.bfloat16); view_222 = None + unsqueeze_12 = torch.ops.aten.unsqueeze.default(convert_element_type_214, 3); convert_element_type_214 = None + expand_12 = torch.ops.aten.expand.default(unsqueeze_12, [2, 8192, 8, 4, 128]); unsqueeze_12 = None + clone_12 = torch.ops.aten.clone.default(expand_12, memory_format = torch.contiguous_format); expand_12 = None + view_223 = torch.ops.aten.view.default(clone_12, [2, 8192, 32, 128]); clone_12 = None + unsqueeze_13 = torch.ops.aten.unsqueeze.default(view_217, 3); view_217 = None + expand_13 = torch.ops.aten.expand.default(unsqueeze_13, [2, 8192, 8, 4, 128]); unsqueeze_13 = None + clone_13 = torch.ops.aten.clone.default(expand_13, memory_format = torch.contiguous_format); expand_13 = None + view_224 = torch.ops.aten.view.default(clone_13, [2, 8192, 32, 128]); clone_13 = None + permute_69 = torch.ops.aten.permute.default(convert_element_type_213, [0, 2, 1, 3]); convert_element_type_213 = None + 
permute_70 = torch.ops.aten.permute.default(view_223, [0, 2, 1, 3]); view_223 = None + permute_71 = torch.ops.aten.permute.default(view_224, [0, 2, 1, 3]); view_224 = None + _scaled_dot_product_cudnn_attention_backward_25 = torch.ops.aten._scaled_dot_product_cudnn_attention_backward.default(permute_1173, permute_69, permute_70, permute_71, getitem_54, getitem_55, getitem_60, getitem_61, None, None, None, 8192, 8192, 0.0, True); permute_1173 = permute_69 = permute_70 = permute_71 = getitem_54 = getitem_55 = getitem_60 = getitem_61 = None + getitem_363 = _scaled_dot_product_cudnn_attention_backward_25[0] + getitem_364 = _scaled_dot_product_cudnn_attention_backward_25[1] + getitem_365 = _scaled_dot_product_cudnn_attention_backward_25[2]; _scaled_dot_product_cudnn_attention_backward_25 = None + permute_1174 = torch.ops.aten.permute.default(getitem_365, [0, 2, 1, 3]); getitem_365 = None + permute_1175 = torch.ops.aten.permute.default(getitem_364, [0, 2, 1, 3]); getitem_364 = None + permute_1176 = torch.ops.aten.permute.default(getitem_363, [0, 2, 1, 3]); getitem_363 = None + view_1704 = torch.ops.aten.view.default(permute_1174, [2, 8192, 8, 4, 128]); permute_1174 = None + sum_155 = torch.ops.aten.sum.dim_IntList(view_1704, [3], True); view_1704 = None + squeeze_50 = torch.ops.aten.squeeze.dim(sum_155, 3); sum_155 = None + view_1705 = torch.ops.aten.view.default(permute_1175, [2, 8192, 8, 4, 128]); permute_1175 = None + sum_156 = torch.ops.aten.sum.dim_IntList(view_1705, [3], True); view_1705 = None + squeeze_51 = torch.ops.aten.squeeze.dim(sum_156, 3); sum_156 = None + convert_element_type_2453 = torch.ops.prims.convert_element_type.default(squeeze_51, torch.float32); squeeze_51 = None + convert_element_type_2454 = torch.ops.prims.convert_element_type.default(permute_1176, torch.float32); permute_1176 = None + view_1706 = torch.ops.aten.view.default(convert_element_type_2453, [2, 8192, 8, 64, 2]); convert_element_type_2453 = None + view_as_complex_114 = 
torch.ops.aten.view_as_complex.default(view_1706); view_1706 = None + mul_776 = torch.ops.aten.mul.Tensor(view_as_complex_114, _conj); view_as_complex_114 = None + view_1707 = torch.ops.aten.view.default(convert_element_type_2454, [2, 8192, 32, 64, 2]); convert_element_type_2454 = None + view_as_complex_115 = torch.ops.aten.view_as_complex.default(view_1707); view_1707 = None + mul_777 = torch.ops.aten.mul.Tensor(view_as_complex_115, _conj); view_as_complex_115 = None + view_as_real_114 = torch.ops.aten.view_as_real.default(mul_776); mul_776 = None + view_1708 = torch.ops.aten.view.default(view_as_real_114, [2, 8192, 8, 128]); view_as_real_114 = None + convert_element_type_2455 = torch.ops.prims.convert_element_type.default(view_1708, torch.bfloat16); view_1708 = None + view_as_real_115 = torch.ops.aten.view_as_real.default(mul_777); mul_777 = None + view_1709 = torch.ops.aten.view.default(view_as_real_115, [2, 8192, 32, 128]); view_as_real_115 = None + convert_element_type_2456 = torch.ops.prims.convert_element_type.default(view_1709, torch.bfloat16); view_1709 = None + view_1710 = torch.ops.aten.view.default(squeeze_50, [2, 8192, 1024]); squeeze_50 = None + view_1711 = torch.ops.aten.view.default(convert_element_type_2455, [2, 8192, 1024]); convert_element_type_2455 = None + view_1712 = torch.ops.aten.view.default(convert_element_type_2456, [2, 8192, 4096]); convert_element_type_2456 = None + view_1713 = torch.ops.aten.view.default(view_1710, [16384, 1024]); view_1710 = None + permute_1177 = torch.ops.aten.permute.default(view_1713, [1, 0]) + mm_585 = torch.ops.aten.mm.default(permute_1177, view_207); permute_1177 = None + convert_element_type_208 = torch.ops.prims.convert_element_type.default(primals_61, torch.bfloat16); primals_61 = None + all_gather_into_tensor_58 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_208, 256, '0'); convert_element_type_208 = None + wait_tensor_58 = 
torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_58); all_gather_into_tensor_58 = None + permute_68 = torch.ops.aten.permute.default(wait_tensor_58, [1, 0]); wait_tensor_58 = None + permute_1179 = torch.ops.aten.permute.default(permute_68, [1, 0]); permute_68 = None + mm_586 = torch.ops.aten.mm.default(view_1713, permute_1179); view_1713 = permute_1179 = None + view_1714 = torch.ops.aten.view.default(mm_586, [2, 8192, 4096]); mm_586 = None + convert_element_type_2461 = torch.ops.prims.convert_element_type.default(mm_585, torch.float32); mm_585 = None + reduce_scatter_tensor_232 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2461, 'avg', 256, '0'); convert_element_type_2461 = None + wait_tensor_523 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_232); reduce_scatter_tensor_232 = None + view_1715 = torch.ops.aten.view.default(view_1711, [16384, 1024]); view_1711 = None + permute_1181 = torch.ops.aten.permute.default(view_1715, [1, 0]) + mm_587 = torch.ops.aten.mm.default(permute_1181, view_207); permute_1181 = None + permute_1183 = torch.ops.aten.permute.default(permute_67, [1, 0]); permute_67 = None + mm_588 = torch.ops.aten.mm.default(view_1715, permute_1183); view_1715 = permute_1183 = None + view_1716 = torch.ops.aten.view.default(mm_588, [2, 8192, 4096]); mm_588 = None + add_308 = torch.ops.aten.add.Tensor(view_1714, view_1716); view_1714 = view_1716 = None + convert_element_type_2466 = torch.ops.prims.convert_element_type.default(mm_587, torch.float32); mm_587 = None + reduce_scatter_tensor_233 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2466, 'avg', 256, '0'); convert_element_type_2466 = None + wait_tensor_524 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_233); reduce_scatter_tensor_233 = None + view_1717 = torch.ops.aten.view.default(view_1712, [16384, 4096]); view_1712 = None + permute_1185 = 
torch.ops.aten.permute.default(view_1717, [1, 0]) + mm_589 = torch.ops.aten.mm.default(permute_1185, view_207); permute_1185 = view_207 = None + convert_element_type_202 = torch.ops.prims.convert_element_type.default(primals_59, torch.bfloat16); primals_59 = None + all_gather_into_tensor_56 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_202, 256, '0'); convert_element_type_202 = None + wait_tensor_56 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_56); all_gather_into_tensor_56 = None + permute_66 = torch.ops.aten.permute.default(wait_tensor_56, [1, 0]); wait_tensor_56 = None + permute_1187 = torch.ops.aten.permute.default(permute_66, [1, 0]); permute_66 = None + mm_590 = torch.ops.aten.mm.default(view_1717, permute_1187); view_1717 = permute_1187 = None + view_1718 = torch.ops.aten.view.default(mm_590, [2, 8192, 4096]); mm_590 = None + add_309 = torch.ops.aten.add.Tensor(add_308, view_1718); add_308 = view_1718 = None + convert_element_type_2471 = torch.ops.prims.convert_element_type.default(mm_589, torch.float32); mm_589 = None + reduce_scatter_tensor_234 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2471, 'avg', 256, '0'); convert_element_type_2471 = None + wait_tensor_525 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_234); reduce_scatter_tensor_234 = None + convert_element_type_2472 = torch.ops.prims.convert_element_type.default(add_309, torch.float32); add_309 = None + convert_element_type_2474 = torch.ops.prims.convert_element_type.default(wait_tensor_55, torch.float32); wait_tensor_55 = None + mul_778 = torch.ops.aten.mul.Tensor(convert_element_type_2472, convert_element_type_2474); convert_element_type_2474 = None + mul_780 = torch.ops.aten.mul.Tensor(mul_48, mul_778) + sum_157 = torch.ops.aten.sum.dim_IntList(mul_780, [2], True); mul_780 = None + div_52 = torch.ops.aten.div.Tensor(mul_48, 4096) + mul_781 = torch.ops.aten.mul.Tensor(div_52, 
sum_157); div_52 = sum_157 = None + sub_78 = torch.ops.aten.sub.Tensor(mul_778, mul_781); mul_778 = mul_781 = None + mul_782 = torch.ops.aten.mul.Tensor(sub_78, rsqrt_12); sub_78 = rsqrt_12 = None + mul_783 = torch.ops.aten.mul.Tensor(convert_element_type_2472, mul_48); convert_element_type_2472 = mul_48 = None + sum_158 = torch.ops.aten.sum.dim_IntList(mul_783, [0, 1]); mul_783 = None + convert_element_type_2475 = torch.ops.prims.convert_element_type.default(mul_782, torch.bfloat16); mul_782 = None + add_310 = torch.ops.aten.add.Tensor(add_307, convert_element_type_2475); add_307 = convert_element_type_2475 = None + convert_element_type_default_13 = torch.ops.prims.convert_element_type.default(sum_158, torch.float32); sum_158 = None + reduce_scatter_tensor_235 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_13, 'avg', 256, '0'); convert_element_type_default_13 = None + wait_tensor_526 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_235); reduce_scatter_tensor_235 = None + view_1719 = torch.ops.aten.view.default(add_310, [16384, 4096]) + permute_1189 = torch.ops.aten.permute.default(view_1719, [1, 0]) + permute_61 = torch.ops.aten.permute.default(getitem_45, [0, 2, 1, 3]) + view_191 = torch.ops.aten.view.default(permute_61, [2, 8192, -1]); permute_61 = None + convert_element_type_182 = torch.ops.prims.convert_element_type.default(primals_53, torch.bfloat16); primals_53 = None + all_gather_into_tensor_50 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_182, 256, '0'); convert_element_type_182 = None + wait_tensor_50 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_50); all_gather_into_tensor_50 = None + permute_62 = torch.ops.aten.permute.default(wait_tensor_50, [1, 0]); wait_tensor_50 = None + view_193 = torch.ops.aten.view.default(view_191, [16384, 4096]); view_191 = None + mm_38 = torch.ops.aten.mm.default(view_193, permute_62) + view_194 = 
torch.ops.aten.view.default(mm_38, [2, 8192, 4096]); mm_38 = None + add_21 = torch.ops.aten.add.Tensor(add_19, view_194); view_194 = None + convert_element_type_185 = torch.ops.prims.convert_element_type.default(primals_54, torch.bfloat16); primals_54 = None + all_gather_into_tensor_51 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_185, 256, '0'); convert_element_type_185 = None + wait_tensor_51 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_51); all_gather_into_tensor_51 = None + convert_element_type_186 = torch.ops.prims.convert_element_type.default(add_21, torch.float32); add_21 = None + pow_12 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_186, 2) + mean_11 = torch.ops.aten.mean.dim(pow_12, [2], True); pow_12 = None + add_22 = torch.ops.aten.add.Scalar(mean_11, 1e-05); mean_11 = None + rsqrt_11 = torch.ops.aten.rsqrt.default(add_22); add_22 = None + mul_44 = torch.ops.aten.mul.Tensor(convert_element_type_186, rsqrt_11); convert_element_type_186 = None + mul_45 = torch.ops.aten.mul.Tensor(mul_44, wait_tensor_51) + convert_element_type_187 = torch.ops.prims.convert_element_type.default(mul_45, torch.bfloat16); mul_45 = None + view_197 = torch.ops.aten.view.default(convert_element_type_187, [16384, 4096]); convert_element_type_187 = None + view_198 = torch.ops.aten.view.default(mm_39, [2, 8192, 14336]); mm_39 = None + convert_element_type_191 = torch.ops.prims.convert_element_type.default(view_198, torch.float32); view_198 = None + sigmoid_5 = torch.ops.aten.sigmoid.default(convert_element_type_191) + mul_46 = torch.ops.aten.mul.Tensor(convert_element_type_191, sigmoid_5); sigmoid_5 = None + convert_element_type_192 = torch.ops.prims.convert_element_type.default(mul_46, torch.bfloat16); mul_46 = None + convert_element_type_193 = torch.ops.prims.convert_element_type.default(primals_56, torch.bfloat16); primals_56 = None + all_gather_into_tensor_53 = 
torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_193, 256, '0'); convert_element_type_193 = None + wait_tensor_53 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_53); all_gather_into_tensor_53 = None + permute_64 = torch.ops.aten.permute.default(wait_tensor_53, [1, 0]); wait_tensor_53 = None + mm_40 = torch.ops.aten.mm.default(view_197, permute_64) + view_201 = torch.ops.aten.view.default(mm_40, [2, 8192, 14336]); mm_40 = None + mul_47 = torch.ops.aten.mul.Tensor(convert_element_type_192, view_201) + view_203 = torch.ops.aten.view.default(mul_47, [16384, 14336]); mul_47 = None + mm_591 = torch.ops.aten.mm.default(permute_1189, view_203); permute_1189 = view_203 = None + convert_element_type_196 = torch.ops.prims.convert_element_type.default(primals_57, torch.bfloat16); primals_57 = None + all_gather_into_tensor_54 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_196, 256, '0'); convert_element_type_196 = None + wait_tensor_54 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_54); all_gather_into_tensor_54 = None + permute_65 = torch.ops.aten.permute.default(wait_tensor_54, [1, 0]); wait_tensor_54 = None + permute_1191 = torch.ops.aten.permute.default(permute_65, [1, 0]); permute_65 = None + mm_592 = torch.ops.aten.mm.default(view_1719, permute_1191); view_1719 = permute_1191 = None + view_1720 = torch.ops.aten.view.default(mm_592, [2, 8192, 14336]); mm_592 = None + convert_element_type_2482 = torch.ops.prims.convert_element_type.default(mm_591, torch.float32); mm_591 = None + reduce_scatter_tensor_236 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2482, 'avg', 256, '0'); convert_element_type_2482 = None + wait_tensor_527 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_236); reduce_scatter_tensor_236 = None + mul_784 = torch.ops.aten.mul.Tensor(view_1720, convert_element_type_192); 
convert_element_type_192 = None + mul_785 = torch.ops.aten.mul.Tensor(view_1720, view_201); view_1720 = view_201 = None + view_1721 = torch.ops.aten.view.default(mul_784, [16384, 14336]); mul_784 = None + permute_1193 = torch.ops.aten.permute.default(view_1721, [1, 0]) + mm_593 = torch.ops.aten.mm.default(permute_1193, view_197); permute_1193 = None + permute_1195 = torch.ops.aten.permute.default(permute_64, [1, 0]); permute_64 = None + mm_594 = torch.ops.aten.mm.default(view_1721, permute_1195); view_1721 = permute_1195 = None + view_1722 = torch.ops.aten.view.default(mm_594, [2, 8192, 4096]); mm_594 = None + convert_element_type_2487 = torch.ops.prims.convert_element_type.default(mm_593, torch.float32); mm_593 = None + reduce_scatter_tensor_237 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2487, 'avg', 256, '0'); convert_element_type_2487 = None + wait_tensor_528 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_237); reduce_scatter_tensor_237 = None + convert_element_type_2488 = torch.ops.prims.convert_element_type.default(mul_785, torch.float32); mul_785 = None + neg_26 = torch.ops.aten.neg.default(convert_element_type_191) + exp_26 = torch.ops.aten.exp.default(neg_26); neg_26 = None + add_311 = torch.ops.aten.add.Tensor(exp_26, 1); exp_26 = None + reciprocal_26 = torch.ops.aten.reciprocal.default(add_311); add_311 = None + mul_786 = torch.ops.aten.mul.Tensor(reciprocal_26, 1); reciprocal_26 = None + mul_787 = torch.ops.aten.mul.Tensor(convert_element_type_2488, mul_786); convert_element_type_2488 = None + sub_79 = torch.ops.aten.sub.Tensor(1, mul_786); mul_786 = None + mul_788 = torch.ops.aten.mul.Tensor(convert_element_type_191, sub_79); convert_element_type_191 = sub_79 = None + add_312 = torch.ops.aten.add.Tensor(mul_788, 1); mul_788 = None + mul_789 = torch.ops.aten.mul.Tensor(mul_787, add_312); mul_787 = add_312 = None + convert_element_type_2490 = torch.ops.prims.convert_element_type.default(mul_789, 
torch.bfloat16); mul_789 = None + view_1723 = torch.ops.aten.view.default(convert_element_type_2490, [16384, 14336]); convert_element_type_2490 = None + permute_1197 = torch.ops.aten.permute.default(view_1723, [1, 0]) + mm_595 = torch.ops.aten.mm.default(permute_1197, view_197); permute_1197 = view_197 = None + convert_element_type_188 = torch.ops.prims.convert_element_type.default(primals_55, torch.bfloat16); primals_55 = None + all_gather_into_tensor_52 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_188, 256, '0'); convert_element_type_188 = None + wait_tensor_52 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_52); all_gather_into_tensor_52 = None + permute_63 = torch.ops.aten.permute.default(wait_tensor_52, [1, 0]); wait_tensor_52 = None + permute_1199 = torch.ops.aten.permute.default(permute_63, [1, 0]); permute_63 = None + mm_596 = torch.ops.aten.mm.default(view_1723, permute_1199); view_1723 = permute_1199 = None + view_1724 = torch.ops.aten.view.default(mm_596, [2, 8192, 4096]); mm_596 = None + add_313 = torch.ops.aten.add.Tensor(view_1722, view_1724); view_1722 = view_1724 = None + convert_element_type_2495 = torch.ops.prims.convert_element_type.default(mm_595, torch.float32); mm_595 = None + reduce_scatter_tensor_238 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2495, 'avg', 256, '0'); convert_element_type_2495 = None + wait_tensor_529 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_238); reduce_scatter_tensor_238 = None + convert_element_type_2496 = torch.ops.prims.convert_element_type.default(add_313, torch.float32); add_313 = None + convert_element_type_2498 = torch.ops.prims.convert_element_type.default(wait_tensor_51, torch.float32); wait_tensor_51 = None + mul_790 = torch.ops.aten.mul.Tensor(convert_element_type_2496, convert_element_type_2498); convert_element_type_2498 = None + mul_792 = torch.ops.aten.mul.Tensor(mul_44, mul_790) + 
sum_159 = torch.ops.aten.sum.dim_IntList(mul_792, [2], True); mul_792 = None + div_53 = torch.ops.aten.div.Tensor(mul_44, 4096) + mul_793 = torch.ops.aten.mul.Tensor(div_53, sum_159); div_53 = sum_159 = None + sub_80 = torch.ops.aten.sub.Tensor(mul_790, mul_793); mul_790 = mul_793 = None + mul_794 = torch.ops.aten.mul.Tensor(sub_80, rsqrt_11); sub_80 = rsqrt_11 = None + mul_795 = torch.ops.aten.mul.Tensor(convert_element_type_2496, mul_44); convert_element_type_2496 = mul_44 = None + sum_160 = torch.ops.aten.sum.dim_IntList(mul_795, [0, 1]); mul_795 = None + convert_element_type_2499 = torch.ops.prims.convert_element_type.default(mul_794, torch.bfloat16); mul_794 = None + add_314 = torch.ops.aten.add.Tensor(add_310, convert_element_type_2499); add_310 = convert_element_type_2499 = None + convert_element_type_default_12 = torch.ops.prims.convert_element_type.default(sum_160, torch.float32); sum_160 = None + reduce_scatter_tensor_239 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_12, 'avg', 256, '0'); convert_element_type_default_12 = None + wait_tensor_530 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_239); reduce_scatter_tensor_239 = None + view_1725 = torch.ops.aten.view.default(add_314, [16384, 4096]) + permute_1201 = torch.ops.aten.permute.default(view_1725, [1, 0]) + mm_597 = torch.ops.aten.mm.default(permute_1201, view_193); permute_1201 = view_193 = None + permute_1203 = torch.ops.aten.permute.default(permute_62, [1, 0]); permute_62 = None + mm_598 = torch.ops.aten.mm.default(view_1725, permute_1203); view_1725 = permute_1203 = None + view_1726 = torch.ops.aten.view.default(mm_598, [2, 8192, 4096]); mm_598 = None + convert_element_type_2506 = torch.ops.prims.convert_element_type.default(mm_597, torch.float32); mm_597 = None + reduce_scatter_tensor_240 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2506, 'avg', 256, '0'); convert_element_type_2506 = None + 
wait_tensor_531 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_240); reduce_scatter_tensor_240 = None + view_1727 = torch.ops.aten.view.default(view_1726, [2, 8192, 32, 128]); view_1726 = None + permute_1205 = torch.ops.aten.permute.default(view_1727, [0, 2, 1, 3]); view_1727 = None + convert_element_type_166 = torch.ops.prims.convert_element_type.default(primals_49, torch.bfloat16); primals_49 = None + all_gather_into_tensor_46 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_166, 256, '0'); convert_element_type_166 = None + wait_tensor_46 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_46); all_gather_into_tensor_46 = None + convert_element_type_167 = torch.ops.prims.convert_element_type.default(add_19, torch.float32); add_19 = None + pow_11 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_167, 2) + mean_10 = torch.ops.aten.mean.dim(pow_11, [2], True); pow_11 = None + add_20 = torch.ops.aten.add.Scalar(mean_10, 1e-05); mean_10 = None + rsqrt_10 = torch.ops.aten.rsqrt.default(add_20); add_20 = None + mul_40 = torch.ops.aten.mul.Tensor(convert_element_type_167, rsqrt_10); convert_element_type_167 = None + mul_41 = torch.ops.aten.mul.Tensor(mul_40, wait_tensor_46) + convert_element_type_168 = torch.ops.prims.convert_element_type.default(mul_41, torch.bfloat16); mul_41 = None + view_173 = torch.ops.aten.view.default(convert_element_type_168, [16384, 4096]); convert_element_type_168 = None + view_174 = torch.ops.aten.view.default(mm_35, [2, 8192, 4096]); mm_35 = None + convert_element_type_172 = torch.ops.prims.convert_element_type.default(primals_51, torch.bfloat16); primals_51 = None + all_gather_into_tensor_48 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_172, 256, '0'); convert_element_type_172 = None + wait_tensor_48 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_48); all_gather_into_tensor_48 = None + permute_56 = 
torch.ops.aten.permute.default(wait_tensor_48, [1, 0]); wait_tensor_48 = None + mm_36 = torch.ops.aten.mm.default(view_173, permute_56) + view_177 = torch.ops.aten.view.default(mm_36, [2, 8192, 1024]); mm_36 = None + view_180 = torch.ops.aten.view.default(mm_37, [2, 8192, 1024]); mm_37 = None + view_181 = torch.ops.aten.view.default(view_174, [2, 8192, -1, 128]); view_174 = None + view_182 = torch.ops.aten.view.default(view_177, [2, 8192, -1, 128]); view_177 = None + view_183 = torch.ops.aten.view.default(view_180, [2, 8192, -1, 128]); view_180 = None + convert_element_type_178 = torch.ops.prims.convert_element_type.default(view_181, torch.float32); view_181 = None + view_184 = torch.ops.aten.view.default(convert_element_type_178, [2, 8192, 32, -1, 2]); convert_element_type_178 = None + view_as_complex_10 = torch.ops.aten.view_as_complex.default(view_184); view_184 = None + convert_element_type_179 = torch.ops.prims.convert_element_type.default(view_182, torch.float32); view_182 = None + view_185 = torch.ops.aten.view.default(convert_element_type_179, [2, 8192, 8, -1, 2]); convert_element_type_179 = None + view_as_complex_11 = torch.ops.aten.view_as_complex.default(view_185); view_185 = None + mul_42 = torch.ops.aten.mul.Tensor(view_as_complex_10, view_16); view_as_complex_10 = None + view_as_real_10 = torch.ops.aten.view_as_real.default(mul_42); mul_42 = None + view_187 = torch.ops.aten.view.default(view_as_real_10, [2, 8192, 32, 128]); view_as_real_10 = None + mul_43 = torch.ops.aten.mul.Tensor(view_as_complex_11, view_16); view_as_complex_11 = None + view_as_real_11 = torch.ops.aten.view_as_real.default(mul_43); mul_43 = None + view_188 = torch.ops.aten.view.default(view_as_real_11, [2, 8192, 8, 128]); view_as_real_11 = None + convert_element_type_180 = torch.ops.prims.convert_element_type.default(view_187, torch.bfloat16); view_187 = None + convert_element_type_181 = torch.ops.prims.convert_element_type.default(view_188, torch.bfloat16); view_188 = None + 
unsqueeze_10 = torch.ops.aten.unsqueeze.default(convert_element_type_181, 3); convert_element_type_181 = None + expand_10 = torch.ops.aten.expand.default(unsqueeze_10, [2, 8192, 8, 4, 128]); unsqueeze_10 = None + clone_10 = torch.ops.aten.clone.default(expand_10, memory_format = torch.contiguous_format); expand_10 = None + view_189 = torch.ops.aten.view.default(clone_10, [2, 8192, 32, 128]); clone_10 = None + unsqueeze_11 = torch.ops.aten.unsqueeze.default(view_183, 3); view_183 = None + expand_11 = torch.ops.aten.expand.default(unsqueeze_11, [2, 8192, 8, 4, 128]); unsqueeze_11 = None + clone_11 = torch.ops.aten.clone.default(expand_11, memory_format = torch.contiguous_format); expand_11 = None + view_190 = torch.ops.aten.view.default(clone_11, [2, 8192, 32, 128]); clone_11 = None + permute_58 = torch.ops.aten.permute.default(convert_element_type_180, [0, 2, 1, 3]); convert_element_type_180 = None + permute_59 = torch.ops.aten.permute.default(view_189, [0, 2, 1, 3]); view_189 = None + permute_60 = torch.ops.aten.permute.default(view_190, [0, 2, 1, 3]); view_190 = None + _scaled_dot_product_cudnn_attention_backward_26 = torch.ops.aten._scaled_dot_product_cudnn_attention_backward.default(permute_1205, permute_58, permute_59, permute_60, getitem_45, getitem_46, getitem_51, getitem_52, None, None, None, 8192, 8192, 0.0, True); permute_1205 = permute_58 = permute_59 = permute_60 = getitem_45 = getitem_46 = getitem_51 = getitem_52 = None + getitem_366 = _scaled_dot_product_cudnn_attention_backward_26[0] + getitem_367 = _scaled_dot_product_cudnn_attention_backward_26[1] + getitem_368 = _scaled_dot_product_cudnn_attention_backward_26[2]; _scaled_dot_product_cudnn_attention_backward_26 = None + permute_1206 = torch.ops.aten.permute.default(getitem_368, [0, 2, 1, 3]); getitem_368 = None + permute_1207 = torch.ops.aten.permute.default(getitem_367, [0, 2, 1, 3]); getitem_367 = None + permute_1208 = torch.ops.aten.permute.default(getitem_366, [0, 2, 1, 3]); getitem_366 = None + 
view_1728 = torch.ops.aten.view.default(permute_1206, [2, 8192, 8, 4, 128]); permute_1206 = None + sum_161 = torch.ops.aten.sum.dim_IntList(view_1728, [3], True); view_1728 = None + squeeze_52 = torch.ops.aten.squeeze.dim(sum_161, 3); sum_161 = None + view_1729 = torch.ops.aten.view.default(permute_1207, [2, 8192, 8, 4, 128]); permute_1207 = None + sum_162 = torch.ops.aten.sum.dim_IntList(view_1729, [3], True); view_1729 = None + squeeze_53 = torch.ops.aten.squeeze.dim(sum_162, 3); sum_162 = None + convert_element_type_2507 = torch.ops.prims.convert_element_type.default(squeeze_53, torch.float32); squeeze_53 = None + convert_element_type_2508 = torch.ops.prims.convert_element_type.default(permute_1208, torch.float32); permute_1208 = None + view_1730 = torch.ops.aten.view.default(convert_element_type_2507, [2, 8192, 8, 64, 2]); convert_element_type_2507 = None + view_as_complex_116 = torch.ops.aten.view_as_complex.default(view_1730); view_1730 = None + mul_796 = torch.ops.aten.mul.Tensor(view_as_complex_116, _conj); view_as_complex_116 = None + view_1731 = torch.ops.aten.view.default(convert_element_type_2508, [2, 8192, 32, 64, 2]); convert_element_type_2508 = None + view_as_complex_117 = torch.ops.aten.view_as_complex.default(view_1731); view_1731 = None + mul_797 = torch.ops.aten.mul.Tensor(view_as_complex_117, _conj); view_as_complex_117 = None + view_as_real_116 = torch.ops.aten.view_as_real.default(mul_796); mul_796 = None + view_1732 = torch.ops.aten.view.default(view_as_real_116, [2, 8192, 8, 128]); view_as_real_116 = None + convert_element_type_2509 = torch.ops.prims.convert_element_type.default(view_1732, torch.bfloat16); view_1732 = None + view_as_real_117 = torch.ops.aten.view_as_real.default(mul_797); mul_797 = None + view_1733 = torch.ops.aten.view.default(view_as_real_117, [2, 8192, 32, 128]); view_as_real_117 = None + convert_element_type_2510 = torch.ops.prims.convert_element_type.default(view_1733, torch.bfloat16); view_1733 = None + view_1734 = 
torch.ops.aten.view.default(squeeze_52, [2, 8192, 1024]); squeeze_52 = None + view_1735 = torch.ops.aten.view.default(convert_element_type_2509, [2, 8192, 1024]); convert_element_type_2509 = None + view_1736 = torch.ops.aten.view.default(convert_element_type_2510, [2, 8192, 4096]); convert_element_type_2510 = None + view_1737 = torch.ops.aten.view.default(view_1734, [16384, 1024]); view_1734 = None + permute_1209 = torch.ops.aten.permute.default(view_1737, [1, 0]) + mm_599 = torch.ops.aten.mm.default(permute_1209, view_173); permute_1209 = None + convert_element_type_175 = torch.ops.prims.convert_element_type.default(primals_52, torch.bfloat16); primals_52 = None + all_gather_into_tensor_49 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_175, 256, '0'); convert_element_type_175 = None + wait_tensor_49 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_49); all_gather_into_tensor_49 = None + permute_57 = torch.ops.aten.permute.default(wait_tensor_49, [1, 0]); wait_tensor_49 = None + permute_1211 = torch.ops.aten.permute.default(permute_57, [1, 0]); permute_57 = None + mm_600 = torch.ops.aten.mm.default(view_1737, permute_1211); view_1737 = permute_1211 = None + view_1738 = torch.ops.aten.view.default(mm_600, [2, 8192, 4096]); mm_600 = None + convert_element_type_2515 = torch.ops.prims.convert_element_type.default(mm_599, torch.float32); mm_599 = None + reduce_scatter_tensor_241 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2515, 'avg', 256, '0'); convert_element_type_2515 = None + wait_tensor_532 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_241); reduce_scatter_tensor_241 = None + view_1739 = torch.ops.aten.view.default(view_1735, [16384, 1024]); view_1735 = None + permute_1213 = torch.ops.aten.permute.default(view_1739, [1, 0]) + mm_601 = torch.ops.aten.mm.default(permute_1213, view_173); permute_1213 = None + permute_1215 = 
torch.ops.aten.permute.default(permute_56, [1, 0]); permute_56 = None + mm_602 = torch.ops.aten.mm.default(view_1739, permute_1215); view_1739 = permute_1215 = None + view_1740 = torch.ops.aten.view.default(mm_602, [2, 8192, 4096]); mm_602 = None + add_315 = torch.ops.aten.add.Tensor(view_1738, view_1740); view_1738 = view_1740 = None + convert_element_type_2520 = torch.ops.prims.convert_element_type.default(mm_601, torch.float32); mm_601 = None + reduce_scatter_tensor_242 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2520, 'avg', 256, '0'); convert_element_type_2520 = None + wait_tensor_533 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_242); reduce_scatter_tensor_242 = None + view_1741 = torch.ops.aten.view.default(view_1736, [16384, 4096]); view_1736 = None + permute_1217 = torch.ops.aten.permute.default(view_1741, [1, 0]) + mm_603 = torch.ops.aten.mm.default(permute_1217, view_173); permute_1217 = view_173 = None + convert_element_type_169 = torch.ops.prims.convert_element_type.default(primals_50, torch.bfloat16); primals_50 = None + all_gather_into_tensor_47 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_169, 256, '0'); convert_element_type_169 = None + wait_tensor_47 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_47); all_gather_into_tensor_47 = None + permute_55 = torch.ops.aten.permute.default(wait_tensor_47, [1, 0]); wait_tensor_47 = None + permute_1219 = torch.ops.aten.permute.default(permute_55, [1, 0]); permute_55 = None + mm_604 = torch.ops.aten.mm.default(view_1741, permute_1219); view_1741 = permute_1219 = None + view_1742 = torch.ops.aten.view.default(mm_604, [2, 8192, 4096]); mm_604 = None + add_316 = torch.ops.aten.add.Tensor(add_315, view_1742); add_315 = view_1742 = None + convert_element_type_2525 = torch.ops.prims.convert_element_type.default(mm_603, torch.float32); mm_603 = None + reduce_scatter_tensor_243 = 
torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2525, 'avg', 256, '0'); convert_element_type_2525 = None + wait_tensor_534 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_243); reduce_scatter_tensor_243 = None + convert_element_type_2526 = torch.ops.prims.convert_element_type.default(add_316, torch.float32); add_316 = None + convert_element_type_2528 = torch.ops.prims.convert_element_type.default(wait_tensor_46, torch.float32); wait_tensor_46 = None + mul_798 = torch.ops.aten.mul.Tensor(convert_element_type_2526, convert_element_type_2528); convert_element_type_2528 = None + mul_800 = torch.ops.aten.mul.Tensor(mul_40, mul_798) + sum_163 = torch.ops.aten.sum.dim_IntList(mul_800, [2], True); mul_800 = None + div_54 = torch.ops.aten.div.Tensor(mul_40, 4096) + mul_801 = torch.ops.aten.mul.Tensor(div_54, sum_163); div_54 = sum_163 = None + sub_81 = torch.ops.aten.sub.Tensor(mul_798, mul_801); mul_798 = mul_801 = None + mul_802 = torch.ops.aten.mul.Tensor(sub_81, rsqrt_10); sub_81 = rsqrt_10 = None + mul_803 = torch.ops.aten.mul.Tensor(convert_element_type_2526, mul_40); convert_element_type_2526 = mul_40 = None + sum_164 = torch.ops.aten.sum.dim_IntList(mul_803, [0, 1]); mul_803 = None + convert_element_type_2529 = torch.ops.prims.convert_element_type.default(mul_802, torch.bfloat16); mul_802 = None + add_317 = torch.ops.aten.add.Tensor(add_314, convert_element_type_2529); add_314 = convert_element_type_2529 = None + convert_element_type_default_11 = torch.ops.prims.convert_element_type.default(sum_164, torch.float32); sum_164 = None + reduce_scatter_tensor_244 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_11, 'avg', 256, '0'); convert_element_type_default_11 = None + wait_tensor_535 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_244); reduce_scatter_tensor_244 = None + view_1743 = torch.ops.aten.view.default(add_317, [16384, 4096]) + permute_1221 = 
torch.ops.aten.permute.default(view_1743, [1, 0]) + permute_50 = torch.ops.aten.permute.default(getitem_36, [0, 2, 1, 3]) + view_157 = torch.ops.aten.view.default(permute_50, [2, 8192, -1]); permute_50 = None + convert_element_type_149 = torch.ops.prims.convert_element_type.default(primals_44, torch.bfloat16); primals_44 = None + all_gather_into_tensor_41 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_149, 256, '0'); convert_element_type_149 = None + wait_tensor_41 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_41); all_gather_into_tensor_41 = None + permute_51 = torch.ops.aten.permute.default(wait_tensor_41, [1, 0]); wait_tensor_41 = None + view_159 = torch.ops.aten.view.default(view_157, [16384, 4096]); view_157 = None + mm_31 = torch.ops.aten.mm.default(view_159, permute_51) + view_160 = torch.ops.aten.view.default(mm_31, [2, 8192, 4096]); mm_31 = None + add_17 = torch.ops.aten.add.Tensor(add_15, view_160); view_160 = None + convert_element_type_152 = torch.ops.prims.convert_element_type.default(primals_45, torch.bfloat16); primals_45 = None + all_gather_into_tensor_42 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_152, 256, '0'); convert_element_type_152 = None + wait_tensor_42 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_42); all_gather_into_tensor_42 = None + convert_element_type_153 = torch.ops.prims.convert_element_type.default(add_17, torch.float32); add_17 = None + pow_10 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_153, 2) + mean_9 = torch.ops.aten.mean.dim(pow_10, [2], True); pow_10 = None + add_18 = torch.ops.aten.add.Scalar(mean_9, 1e-05); mean_9 = None + rsqrt_9 = torch.ops.aten.rsqrt.default(add_18); add_18 = None + mul_36 = torch.ops.aten.mul.Tensor(convert_element_type_153, rsqrt_9); convert_element_type_153 = None + mul_37 = torch.ops.aten.mul.Tensor(mul_36, wait_tensor_42) + convert_element_type_154 = 
torch.ops.prims.convert_element_type.default(mul_37, torch.bfloat16); mul_37 = None + view_163 = torch.ops.aten.view.default(convert_element_type_154, [16384, 4096]); convert_element_type_154 = None + view_164 = torch.ops.aten.view.default(mm_32, [2, 8192, 14336]); mm_32 = None + convert_element_type_158 = torch.ops.prims.convert_element_type.default(view_164, torch.float32); view_164 = None + sigmoid_4 = torch.ops.aten.sigmoid.default(convert_element_type_158) + mul_38 = torch.ops.aten.mul.Tensor(convert_element_type_158, sigmoid_4); sigmoid_4 = None + convert_element_type_159 = torch.ops.prims.convert_element_type.default(mul_38, torch.bfloat16); mul_38 = None + convert_element_type_160 = torch.ops.prims.convert_element_type.default(primals_47, torch.bfloat16); primals_47 = None + all_gather_into_tensor_44 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_160, 256, '0'); convert_element_type_160 = None + wait_tensor_44 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_44); all_gather_into_tensor_44 = None + permute_53 = torch.ops.aten.permute.default(wait_tensor_44, [1, 0]); wait_tensor_44 = None + mm_33 = torch.ops.aten.mm.default(view_163, permute_53) + view_167 = torch.ops.aten.view.default(mm_33, [2, 8192, 14336]); mm_33 = None + mul_39 = torch.ops.aten.mul.Tensor(convert_element_type_159, view_167) + view_169 = torch.ops.aten.view.default(mul_39, [16384, 14336]); mul_39 = None + mm_605 = torch.ops.aten.mm.default(permute_1221, view_169); permute_1221 = view_169 = None + convert_element_type_163 = torch.ops.prims.convert_element_type.default(primals_48, torch.bfloat16); primals_48 = None + all_gather_into_tensor_45 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_163, 256, '0'); convert_element_type_163 = None + wait_tensor_45 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_45); all_gather_into_tensor_45 = None + permute_54 = 
torch.ops.aten.permute.default(wait_tensor_45, [1, 0]); wait_tensor_45 = None + permute_1223 = torch.ops.aten.permute.default(permute_54, [1, 0]); permute_54 = None + mm_606 = torch.ops.aten.mm.default(view_1743, permute_1223); view_1743 = permute_1223 = None + view_1744 = torch.ops.aten.view.default(mm_606, [2, 8192, 14336]); mm_606 = None + convert_element_type_2536 = torch.ops.prims.convert_element_type.default(mm_605, torch.float32); mm_605 = None + reduce_scatter_tensor_245 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2536, 'avg', 256, '0'); convert_element_type_2536 = None + wait_tensor_536 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_245); reduce_scatter_tensor_245 = None + mul_804 = torch.ops.aten.mul.Tensor(view_1744, convert_element_type_159); convert_element_type_159 = None + mul_805 = torch.ops.aten.mul.Tensor(view_1744, view_167); view_1744 = view_167 = None + view_1745 = torch.ops.aten.view.default(mul_804, [16384, 14336]); mul_804 = None + permute_1225 = torch.ops.aten.permute.default(view_1745, [1, 0]) + mm_607 = torch.ops.aten.mm.default(permute_1225, view_163); permute_1225 = None + permute_1227 = torch.ops.aten.permute.default(permute_53, [1, 0]); permute_53 = None + mm_608 = torch.ops.aten.mm.default(view_1745, permute_1227); view_1745 = permute_1227 = None + view_1746 = torch.ops.aten.view.default(mm_608, [2, 8192, 4096]); mm_608 = None + convert_element_type_2541 = torch.ops.prims.convert_element_type.default(mm_607, torch.float32); mm_607 = None + reduce_scatter_tensor_246 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2541, 'avg', 256, '0'); convert_element_type_2541 = None + wait_tensor_537 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_246); reduce_scatter_tensor_246 = None + convert_element_type_2542 = torch.ops.prims.convert_element_type.default(mul_805, torch.float32); mul_805 = None + neg_27 = 
torch.ops.aten.neg.default(convert_element_type_158) + exp_27 = torch.ops.aten.exp.default(neg_27); neg_27 = None + add_318 = torch.ops.aten.add.Tensor(exp_27, 1); exp_27 = None + reciprocal_27 = torch.ops.aten.reciprocal.default(add_318); add_318 = None + mul_806 = torch.ops.aten.mul.Tensor(reciprocal_27, 1); reciprocal_27 = None + mul_807 = torch.ops.aten.mul.Tensor(convert_element_type_2542, mul_806); convert_element_type_2542 = None + sub_82 = torch.ops.aten.sub.Tensor(1, mul_806); mul_806 = None + mul_808 = torch.ops.aten.mul.Tensor(convert_element_type_158, sub_82); convert_element_type_158 = sub_82 = None + add_319 = torch.ops.aten.add.Tensor(mul_808, 1); mul_808 = None + mul_809 = torch.ops.aten.mul.Tensor(mul_807, add_319); mul_807 = add_319 = None + convert_element_type_2544 = torch.ops.prims.convert_element_type.default(mul_809, torch.bfloat16); mul_809 = None + view_1747 = torch.ops.aten.view.default(convert_element_type_2544, [16384, 14336]); convert_element_type_2544 = None + permute_1229 = torch.ops.aten.permute.default(view_1747, [1, 0]) + mm_609 = torch.ops.aten.mm.default(permute_1229, view_163); permute_1229 = view_163 = None + convert_element_type_155 = torch.ops.prims.convert_element_type.default(primals_46, torch.bfloat16); primals_46 = None + all_gather_into_tensor_43 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_155, 256, '0'); convert_element_type_155 = None + wait_tensor_43 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_43); all_gather_into_tensor_43 = None + permute_52 = torch.ops.aten.permute.default(wait_tensor_43, [1, 0]); wait_tensor_43 = None + permute_1231 = torch.ops.aten.permute.default(permute_52, [1, 0]); permute_52 = None + mm_610 = torch.ops.aten.mm.default(view_1747, permute_1231); view_1747 = permute_1231 = None + view_1748 = torch.ops.aten.view.default(mm_610, [2, 8192, 4096]); mm_610 = None + add_320 = torch.ops.aten.add.Tensor(view_1746, view_1748); view_1746 = 
view_1748 = None + convert_element_type_2549 = torch.ops.prims.convert_element_type.default(mm_609, torch.float32); mm_609 = None + reduce_scatter_tensor_247 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2549, 'avg', 256, '0'); convert_element_type_2549 = None + wait_tensor_538 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_247); reduce_scatter_tensor_247 = None + convert_element_type_2550 = torch.ops.prims.convert_element_type.default(add_320, torch.float32); add_320 = None + convert_element_type_2552 = torch.ops.prims.convert_element_type.default(wait_tensor_42, torch.float32); wait_tensor_42 = None + mul_810 = torch.ops.aten.mul.Tensor(convert_element_type_2550, convert_element_type_2552); convert_element_type_2552 = None + mul_812 = torch.ops.aten.mul.Tensor(mul_36, mul_810) + sum_165 = torch.ops.aten.sum.dim_IntList(mul_812, [2], True); mul_812 = None + div_55 = torch.ops.aten.div.Tensor(mul_36, 4096) + mul_813 = torch.ops.aten.mul.Tensor(div_55, sum_165); div_55 = sum_165 = None + sub_83 = torch.ops.aten.sub.Tensor(mul_810, mul_813); mul_810 = mul_813 = None + mul_814 = torch.ops.aten.mul.Tensor(sub_83, rsqrt_9); sub_83 = rsqrt_9 = None + mul_815 = torch.ops.aten.mul.Tensor(convert_element_type_2550, mul_36); convert_element_type_2550 = mul_36 = None + sum_166 = torch.ops.aten.sum.dim_IntList(mul_815, [0, 1]); mul_815 = None + convert_element_type_2553 = torch.ops.prims.convert_element_type.default(mul_814, torch.bfloat16); mul_814 = None + add_321 = torch.ops.aten.add.Tensor(add_317, convert_element_type_2553); add_317 = convert_element_type_2553 = None + convert_element_type_default_10 = torch.ops.prims.convert_element_type.default(sum_166, torch.float32); sum_166 = None + reduce_scatter_tensor_248 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_10, 'avg', 256, '0'); convert_element_type_default_10 = None + wait_tensor_539 = 
torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_248); reduce_scatter_tensor_248 = None + view_1749 = torch.ops.aten.view.default(add_321, [16384, 4096]) + permute_1233 = torch.ops.aten.permute.default(view_1749, [1, 0]) + mm_611 = torch.ops.aten.mm.default(permute_1233, view_159); permute_1233 = view_159 = None + permute_1235 = torch.ops.aten.permute.default(permute_51, [1, 0]); permute_51 = None + mm_612 = torch.ops.aten.mm.default(view_1749, permute_1235); view_1749 = permute_1235 = None + view_1750 = torch.ops.aten.view.default(mm_612, [2, 8192, 4096]); mm_612 = None + convert_element_type_2560 = torch.ops.prims.convert_element_type.default(mm_611, torch.float32); mm_611 = None + reduce_scatter_tensor_249 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2560, 'avg', 256, '0'); convert_element_type_2560 = None + wait_tensor_540 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_249); reduce_scatter_tensor_249 = None + view_1751 = torch.ops.aten.view.default(view_1750, [2, 8192, 32, 128]); view_1750 = None + permute_1237 = torch.ops.aten.permute.default(view_1751, [0, 2, 1, 3]); view_1751 = None + convert_element_type_133 = torch.ops.prims.convert_element_type.default(primals_40, torch.bfloat16); primals_40 = None + all_gather_into_tensor_37 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_133, 256, '0'); convert_element_type_133 = None + wait_tensor_37 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_37); all_gather_into_tensor_37 = None + convert_element_type_134 = torch.ops.prims.convert_element_type.default(add_15, torch.float32); add_15 = None + pow_9 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_134, 2) + mean_8 = torch.ops.aten.mean.dim(pow_9, [2], True); pow_9 = None + add_16 = torch.ops.aten.add.Scalar(mean_8, 1e-05); mean_8 = None + rsqrt_8 = torch.ops.aten.rsqrt.default(add_16); add_16 = None + mul_32 = 
torch.ops.aten.mul.Tensor(convert_element_type_134, rsqrt_8); convert_element_type_134 = None + mul_33 = torch.ops.aten.mul.Tensor(mul_32, wait_tensor_37) + convert_element_type_135 = torch.ops.prims.convert_element_type.default(mul_33, torch.bfloat16); mul_33 = None + view_139 = torch.ops.aten.view.default(convert_element_type_135, [16384, 4096]); convert_element_type_135 = None + view_140 = torch.ops.aten.view.default(mm_28, [2, 8192, 4096]); mm_28 = None + convert_element_type_139 = torch.ops.prims.convert_element_type.default(primals_42, torch.bfloat16); primals_42 = None + all_gather_into_tensor_39 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_139, 256, '0'); convert_element_type_139 = None + wait_tensor_39 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_39); all_gather_into_tensor_39 = None + permute_45 = torch.ops.aten.permute.default(wait_tensor_39, [1, 0]); wait_tensor_39 = None + mm_29 = torch.ops.aten.mm.default(view_139, permute_45) + view_143 = torch.ops.aten.view.default(mm_29, [2, 8192, 1024]); mm_29 = None + view_146 = torch.ops.aten.view.default(mm_30, [2, 8192, 1024]); mm_30 = None + view_147 = torch.ops.aten.view.default(view_140, [2, 8192, -1, 128]); view_140 = None + view_148 = torch.ops.aten.view.default(view_143, [2, 8192, -1, 128]); view_143 = None + view_149 = torch.ops.aten.view.default(view_146, [2, 8192, -1, 128]); view_146 = None + convert_element_type_145 = torch.ops.prims.convert_element_type.default(view_147, torch.float32); view_147 = None + view_150 = torch.ops.aten.view.default(convert_element_type_145, [2, 8192, 32, -1, 2]); convert_element_type_145 = None + view_as_complex_8 = torch.ops.aten.view_as_complex.default(view_150); view_150 = None + convert_element_type_146 = torch.ops.prims.convert_element_type.default(view_148, torch.float32); view_148 = None + view_151 = torch.ops.aten.view.default(convert_element_type_146, [2, 8192, 8, -1, 2]); convert_element_type_146 = 
None + view_as_complex_9 = torch.ops.aten.view_as_complex.default(view_151); view_151 = None + mul_34 = torch.ops.aten.mul.Tensor(view_as_complex_8, view_16); view_as_complex_8 = None + view_as_real_8 = torch.ops.aten.view_as_real.default(mul_34); mul_34 = None + view_153 = torch.ops.aten.view.default(view_as_real_8, [2, 8192, 32, 128]); view_as_real_8 = None + mul_35 = torch.ops.aten.mul.Tensor(view_as_complex_9, view_16); view_as_complex_9 = None + view_as_real_9 = torch.ops.aten.view_as_real.default(mul_35); mul_35 = None + view_154 = torch.ops.aten.view.default(view_as_real_9, [2, 8192, 8, 128]); view_as_real_9 = None + convert_element_type_147 = torch.ops.prims.convert_element_type.default(view_153, torch.bfloat16); view_153 = None + convert_element_type_148 = torch.ops.prims.convert_element_type.default(view_154, torch.bfloat16); view_154 = None + unsqueeze_8 = torch.ops.aten.unsqueeze.default(convert_element_type_148, 3); convert_element_type_148 = None + expand_8 = torch.ops.aten.expand.default(unsqueeze_8, [2, 8192, 8, 4, 128]); unsqueeze_8 = None + clone_8 = torch.ops.aten.clone.default(expand_8, memory_format = torch.contiguous_format); expand_8 = None + view_155 = torch.ops.aten.view.default(clone_8, [2, 8192, 32, 128]); clone_8 = None + unsqueeze_9 = torch.ops.aten.unsqueeze.default(view_149, 3); view_149 = None + expand_9 = torch.ops.aten.expand.default(unsqueeze_9, [2, 8192, 8, 4, 128]); unsqueeze_9 = None + clone_9 = torch.ops.aten.clone.default(expand_9, memory_format = torch.contiguous_format); expand_9 = None + view_156 = torch.ops.aten.view.default(clone_9, [2, 8192, 32, 128]); clone_9 = None + permute_47 = torch.ops.aten.permute.default(convert_element_type_147, [0, 2, 1, 3]); convert_element_type_147 = None + permute_48 = torch.ops.aten.permute.default(view_155, [0, 2, 1, 3]); view_155 = None + permute_49 = torch.ops.aten.permute.default(view_156, [0, 2, 1, 3]); view_156 = None + _scaled_dot_product_cudnn_attention_backward_27 = 
torch.ops.aten._scaled_dot_product_cudnn_attention_backward.default(permute_1237, permute_47, permute_48, permute_49, getitem_36, getitem_37, getitem_42, getitem_43, None, None, None, 8192, 8192, 0.0, True); permute_1237 = permute_47 = permute_48 = permute_49 = getitem_36 = getitem_37 = getitem_42 = getitem_43 = None + getitem_369 = _scaled_dot_product_cudnn_attention_backward_27[0] + getitem_370 = _scaled_dot_product_cudnn_attention_backward_27[1] + getitem_371 = _scaled_dot_product_cudnn_attention_backward_27[2]; _scaled_dot_product_cudnn_attention_backward_27 = None + permute_1238 = torch.ops.aten.permute.default(getitem_371, [0, 2, 1, 3]); getitem_371 = None + permute_1239 = torch.ops.aten.permute.default(getitem_370, [0, 2, 1, 3]); getitem_370 = None + permute_1240 = torch.ops.aten.permute.default(getitem_369, [0, 2, 1, 3]); getitem_369 = None + view_1752 = torch.ops.aten.view.default(permute_1238, [2, 8192, 8, 4, 128]); permute_1238 = None + sum_167 = torch.ops.aten.sum.dim_IntList(view_1752, [3], True); view_1752 = None + squeeze_54 = torch.ops.aten.squeeze.dim(sum_167, 3); sum_167 = None + view_1753 = torch.ops.aten.view.default(permute_1239, [2, 8192, 8, 4, 128]); permute_1239 = None + sum_168 = torch.ops.aten.sum.dim_IntList(view_1753, [3], True); view_1753 = None + squeeze_55 = torch.ops.aten.squeeze.dim(sum_168, 3); sum_168 = None + convert_element_type_2561 = torch.ops.prims.convert_element_type.default(squeeze_55, torch.float32); squeeze_55 = None + convert_element_type_2562 = torch.ops.prims.convert_element_type.default(permute_1240, torch.float32); permute_1240 = None + view_1754 = torch.ops.aten.view.default(convert_element_type_2561, [2, 8192, 8, 64, 2]); convert_element_type_2561 = None + view_as_complex_118 = torch.ops.aten.view_as_complex.default(view_1754); view_1754 = None + mul_816 = torch.ops.aten.mul.Tensor(view_as_complex_118, _conj); view_as_complex_118 = None + view_1755 = torch.ops.aten.view.default(convert_element_type_2562, [2, 8192, 
32, 64, 2]); convert_element_type_2562 = None + view_as_complex_119 = torch.ops.aten.view_as_complex.default(view_1755); view_1755 = None + mul_817 = torch.ops.aten.mul.Tensor(view_as_complex_119, _conj); view_as_complex_119 = None + view_as_real_118 = torch.ops.aten.view_as_real.default(mul_816); mul_816 = None + view_1756 = torch.ops.aten.view.default(view_as_real_118, [2, 8192, 8, 128]); view_as_real_118 = None + convert_element_type_2563 = torch.ops.prims.convert_element_type.default(view_1756, torch.bfloat16); view_1756 = None + view_as_real_119 = torch.ops.aten.view_as_real.default(mul_817); mul_817 = None + view_1757 = torch.ops.aten.view.default(view_as_real_119, [2, 8192, 32, 128]); view_as_real_119 = None + convert_element_type_2564 = torch.ops.prims.convert_element_type.default(view_1757, torch.bfloat16); view_1757 = None + view_1758 = torch.ops.aten.view.default(squeeze_54, [2, 8192, 1024]); squeeze_54 = None + view_1759 = torch.ops.aten.view.default(convert_element_type_2563, [2, 8192, 1024]); convert_element_type_2563 = None + view_1760 = torch.ops.aten.view.default(convert_element_type_2564, [2, 8192, 4096]); convert_element_type_2564 = None + view_1761 = torch.ops.aten.view.default(view_1758, [16384, 1024]); view_1758 = None + permute_1241 = torch.ops.aten.permute.default(view_1761, [1, 0]) + mm_613 = torch.ops.aten.mm.default(permute_1241, view_139); permute_1241 = None + convert_element_type_142 = torch.ops.prims.convert_element_type.default(primals_43, torch.bfloat16); primals_43 = None + all_gather_into_tensor_40 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_142, 256, '0'); convert_element_type_142 = None + wait_tensor_40 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_40); all_gather_into_tensor_40 = None + permute_46 = torch.ops.aten.permute.default(wait_tensor_40, [1, 0]); wait_tensor_40 = None + permute_1243 = torch.ops.aten.permute.default(permute_46, [1, 0]); permute_46 = None + 
mm_614 = torch.ops.aten.mm.default(view_1761, permute_1243); view_1761 = permute_1243 = None + view_1762 = torch.ops.aten.view.default(mm_614, [2, 8192, 4096]); mm_614 = None + convert_element_type_2569 = torch.ops.prims.convert_element_type.default(mm_613, torch.float32); mm_613 = None + reduce_scatter_tensor_250 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2569, 'avg', 256, '0'); convert_element_type_2569 = None + wait_tensor_541 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_250); reduce_scatter_tensor_250 = None + view_1763 = torch.ops.aten.view.default(view_1759, [16384, 1024]); view_1759 = None + permute_1245 = torch.ops.aten.permute.default(view_1763, [1, 0]) + mm_615 = torch.ops.aten.mm.default(permute_1245, view_139); permute_1245 = None + permute_1247 = torch.ops.aten.permute.default(permute_45, [1, 0]); permute_45 = None + mm_616 = torch.ops.aten.mm.default(view_1763, permute_1247); view_1763 = permute_1247 = None + view_1764 = torch.ops.aten.view.default(mm_616, [2, 8192, 4096]); mm_616 = None + add_322 = torch.ops.aten.add.Tensor(view_1762, view_1764); view_1762 = view_1764 = None + convert_element_type_2574 = torch.ops.prims.convert_element_type.default(mm_615, torch.float32); mm_615 = None + reduce_scatter_tensor_251 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2574, 'avg', 256, '0'); convert_element_type_2574 = None + wait_tensor_542 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_251); reduce_scatter_tensor_251 = None + view_1765 = torch.ops.aten.view.default(view_1760, [16384, 4096]); view_1760 = None + permute_1249 = torch.ops.aten.permute.default(view_1765, [1, 0]) + mm_617 = torch.ops.aten.mm.default(permute_1249, view_139); permute_1249 = view_139 = None + convert_element_type_136 = torch.ops.prims.convert_element_type.default(primals_41, torch.bfloat16); primals_41 = None + all_gather_into_tensor_38 = 
torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_136, 256, '0'); convert_element_type_136 = None + wait_tensor_38 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_38); all_gather_into_tensor_38 = None + permute_44 = torch.ops.aten.permute.default(wait_tensor_38, [1, 0]); wait_tensor_38 = None + permute_1251 = torch.ops.aten.permute.default(permute_44, [1, 0]); permute_44 = None + mm_618 = torch.ops.aten.mm.default(view_1765, permute_1251); view_1765 = permute_1251 = None + view_1766 = torch.ops.aten.view.default(mm_618, [2, 8192, 4096]); mm_618 = None + add_323 = torch.ops.aten.add.Tensor(add_322, view_1766); add_322 = view_1766 = None + convert_element_type_2579 = torch.ops.prims.convert_element_type.default(mm_617, torch.float32); mm_617 = None + reduce_scatter_tensor_252 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2579, 'avg', 256, '0'); convert_element_type_2579 = None + wait_tensor_543 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_252); reduce_scatter_tensor_252 = None + convert_element_type_2580 = torch.ops.prims.convert_element_type.default(add_323, torch.float32); add_323 = None + convert_element_type_2582 = torch.ops.prims.convert_element_type.default(wait_tensor_37, torch.float32); wait_tensor_37 = None + mul_818 = torch.ops.aten.mul.Tensor(convert_element_type_2580, convert_element_type_2582); convert_element_type_2582 = None + mul_820 = torch.ops.aten.mul.Tensor(mul_32, mul_818) + sum_169 = torch.ops.aten.sum.dim_IntList(mul_820, [2], True); mul_820 = None + div_56 = torch.ops.aten.div.Tensor(mul_32, 4096) + mul_821 = torch.ops.aten.mul.Tensor(div_56, sum_169); div_56 = sum_169 = None + sub_84 = torch.ops.aten.sub.Tensor(mul_818, mul_821); mul_818 = mul_821 = None + mul_822 = torch.ops.aten.mul.Tensor(sub_84, rsqrt_8); sub_84 = rsqrt_8 = None + mul_823 = torch.ops.aten.mul.Tensor(convert_element_type_2580, mul_32); convert_element_type_2580 
= mul_32 = None + sum_170 = torch.ops.aten.sum.dim_IntList(mul_823, [0, 1]); mul_823 = None + convert_element_type_2583 = torch.ops.prims.convert_element_type.default(mul_822, torch.bfloat16); mul_822 = None + add_324 = torch.ops.aten.add.Tensor(add_321, convert_element_type_2583); add_321 = convert_element_type_2583 = None + convert_element_type_default_9 = torch.ops.prims.convert_element_type.default(sum_170, torch.float32); sum_170 = None + reduce_scatter_tensor_253 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_9, 'avg', 256, '0'); convert_element_type_default_9 = None + wait_tensor_544 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_253); reduce_scatter_tensor_253 = None + view_1767 = torch.ops.aten.view.default(add_324, [16384, 4096]) + permute_1253 = torch.ops.aten.permute.default(view_1767, [1, 0]) + permute_39 = torch.ops.aten.permute.default(getitem_27, [0, 2, 1, 3]) + view_123 = torch.ops.aten.view.default(permute_39, [2, 8192, -1]); permute_39 = None + convert_element_type_116 = torch.ops.prims.convert_element_type.default(primals_35, torch.bfloat16); primals_35 = None + all_gather_into_tensor_32 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_116, 256, '0'); convert_element_type_116 = None + wait_tensor_32 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_32); all_gather_into_tensor_32 = None + permute_40 = torch.ops.aten.permute.default(wait_tensor_32, [1, 0]); wait_tensor_32 = None + view_125 = torch.ops.aten.view.default(view_123, [16384, 4096]); view_123 = None + mm_24 = torch.ops.aten.mm.default(view_125, permute_40) + view_126 = torch.ops.aten.view.default(mm_24, [2, 8192, 4096]); mm_24 = None + add_13 = torch.ops.aten.add.Tensor(add_11, view_126); view_126 = None + convert_element_type_119 = torch.ops.prims.convert_element_type.default(primals_36, torch.bfloat16); primals_36 = None + all_gather_into_tensor_33 = 
torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_119, 256, '0'); convert_element_type_119 = None + wait_tensor_33 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_33); all_gather_into_tensor_33 = None + convert_element_type_120 = torch.ops.prims.convert_element_type.default(add_13, torch.float32); add_13 = None + pow_8 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_120, 2) + mean_7 = torch.ops.aten.mean.dim(pow_8, [2], True); pow_8 = None + add_14 = torch.ops.aten.add.Scalar(mean_7, 1e-05); mean_7 = None + rsqrt_7 = torch.ops.aten.rsqrt.default(add_14); add_14 = None + mul_28 = torch.ops.aten.mul.Tensor(convert_element_type_120, rsqrt_7); convert_element_type_120 = None + mul_29 = torch.ops.aten.mul.Tensor(mul_28, wait_tensor_33) + convert_element_type_121 = torch.ops.prims.convert_element_type.default(mul_29, torch.bfloat16); mul_29 = None + view_129 = torch.ops.aten.view.default(convert_element_type_121, [16384, 4096]); convert_element_type_121 = None + view_130 = torch.ops.aten.view.default(mm_25, [2, 8192, 14336]); mm_25 = None + convert_element_type_125 = torch.ops.prims.convert_element_type.default(view_130, torch.float32); view_130 = None + sigmoid_3 = torch.ops.aten.sigmoid.default(convert_element_type_125) + mul_30 = torch.ops.aten.mul.Tensor(convert_element_type_125, sigmoid_3); sigmoid_3 = None + convert_element_type_126 = torch.ops.prims.convert_element_type.default(mul_30, torch.bfloat16); mul_30 = None + convert_element_type_127 = torch.ops.prims.convert_element_type.default(primals_38, torch.bfloat16); primals_38 = None + all_gather_into_tensor_35 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_127, 256, '0'); convert_element_type_127 = None + wait_tensor_35 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_35); all_gather_into_tensor_35 = None + permute_42 = torch.ops.aten.permute.default(wait_tensor_35, [1, 0]); wait_tensor_35 = None 
+ mm_26 = torch.ops.aten.mm.default(view_129, permute_42) + view_133 = torch.ops.aten.view.default(mm_26, [2, 8192, 14336]); mm_26 = None + mul_31 = torch.ops.aten.mul.Tensor(convert_element_type_126, view_133) + view_135 = torch.ops.aten.view.default(mul_31, [16384, 14336]); mul_31 = None + mm_619 = torch.ops.aten.mm.default(permute_1253, view_135); permute_1253 = view_135 = None + convert_element_type_130 = torch.ops.prims.convert_element_type.default(primals_39, torch.bfloat16); primals_39 = None + all_gather_into_tensor_36 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_130, 256, '0'); convert_element_type_130 = None + wait_tensor_36 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_36); all_gather_into_tensor_36 = None + permute_43 = torch.ops.aten.permute.default(wait_tensor_36, [1, 0]); wait_tensor_36 = None + permute_1255 = torch.ops.aten.permute.default(permute_43, [1, 0]); permute_43 = None + mm_620 = torch.ops.aten.mm.default(view_1767, permute_1255); view_1767 = permute_1255 = None + view_1768 = torch.ops.aten.view.default(mm_620, [2, 8192, 14336]); mm_620 = None + convert_element_type_2590 = torch.ops.prims.convert_element_type.default(mm_619, torch.float32); mm_619 = None + reduce_scatter_tensor_254 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2590, 'avg', 256, '0'); convert_element_type_2590 = None + wait_tensor_545 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_254); reduce_scatter_tensor_254 = None + mul_824 = torch.ops.aten.mul.Tensor(view_1768, convert_element_type_126); convert_element_type_126 = None + mul_825 = torch.ops.aten.mul.Tensor(view_1768, view_133); view_1768 = view_133 = None + view_1769 = torch.ops.aten.view.default(mul_824, [16384, 14336]); mul_824 = None + permute_1257 = torch.ops.aten.permute.default(view_1769, [1, 0]) + mm_621 = torch.ops.aten.mm.default(permute_1257, view_129); permute_1257 = None + permute_1259 = 
torch.ops.aten.permute.default(permute_42, [1, 0]); permute_42 = None + mm_622 = torch.ops.aten.mm.default(view_1769, permute_1259); view_1769 = permute_1259 = None + view_1770 = torch.ops.aten.view.default(mm_622, [2, 8192, 4096]); mm_622 = None + convert_element_type_2595 = torch.ops.prims.convert_element_type.default(mm_621, torch.float32); mm_621 = None + reduce_scatter_tensor_255 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2595, 'avg', 256, '0'); convert_element_type_2595 = None + wait_tensor_546 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_255); reduce_scatter_tensor_255 = None + convert_element_type_2596 = torch.ops.prims.convert_element_type.default(mul_825, torch.float32); mul_825 = None + neg_28 = torch.ops.aten.neg.default(convert_element_type_125) + exp_28 = torch.ops.aten.exp.default(neg_28); neg_28 = None + add_325 = torch.ops.aten.add.Tensor(exp_28, 1); exp_28 = None + reciprocal_28 = torch.ops.aten.reciprocal.default(add_325); add_325 = None + mul_826 = torch.ops.aten.mul.Tensor(reciprocal_28, 1); reciprocal_28 = None + mul_827 = torch.ops.aten.mul.Tensor(convert_element_type_2596, mul_826); convert_element_type_2596 = None + sub_85 = torch.ops.aten.sub.Tensor(1, mul_826); mul_826 = None + mul_828 = torch.ops.aten.mul.Tensor(convert_element_type_125, sub_85); convert_element_type_125 = sub_85 = None + add_326 = torch.ops.aten.add.Tensor(mul_828, 1); mul_828 = None + mul_829 = torch.ops.aten.mul.Tensor(mul_827, add_326); mul_827 = add_326 = None + convert_element_type_2598 = torch.ops.prims.convert_element_type.default(mul_829, torch.bfloat16); mul_829 = None + view_1771 = torch.ops.aten.view.default(convert_element_type_2598, [16384, 14336]); convert_element_type_2598 = None + permute_1261 = torch.ops.aten.permute.default(view_1771, [1, 0]) + mm_623 = torch.ops.aten.mm.default(permute_1261, view_129); permute_1261 = view_129 = None + convert_element_type_122 = 
torch.ops.prims.convert_element_type.default(primals_37, torch.bfloat16); primals_37 = None + all_gather_into_tensor_34 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_122, 256, '0'); convert_element_type_122 = None + wait_tensor_34 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_34); all_gather_into_tensor_34 = None + permute_41 = torch.ops.aten.permute.default(wait_tensor_34, [1, 0]); wait_tensor_34 = None + permute_1263 = torch.ops.aten.permute.default(permute_41, [1, 0]); permute_41 = None + mm_624 = torch.ops.aten.mm.default(view_1771, permute_1263); view_1771 = permute_1263 = None + view_1772 = torch.ops.aten.view.default(mm_624, [2, 8192, 4096]); mm_624 = None + add_327 = torch.ops.aten.add.Tensor(view_1770, view_1772); view_1770 = view_1772 = None + convert_element_type_2603 = torch.ops.prims.convert_element_type.default(mm_623, torch.float32); mm_623 = None + reduce_scatter_tensor_256 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2603, 'avg', 256, '0'); convert_element_type_2603 = None + wait_tensor_547 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_256); reduce_scatter_tensor_256 = None + convert_element_type_2604 = torch.ops.prims.convert_element_type.default(add_327, torch.float32); add_327 = None + convert_element_type_2606 = torch.ops.prims.convert_element_type.default(wait_tensor_33, torch.float32); wait_tensor_33 = None + mul_830 = torch.ops.aten.mul.Tensor(convert_element_type_2604, convert_element_type_2606); convert_element_type_2606 = None + mul_832 = torch.ops.aten.mul.Tensor(mul_28, mul_830) + sum_171 = torch.ops.aten.sum.dim_IntList(mul_832, [2], True); mul_832 = None + div_57 = torch.ops.aten.div.Tensor(mul_28, 4096) + mul_833 = torch.ops.aten.mul.Tensor(div_57, sum_171); div_57 = sum_171 = None + sub_86 = torch.ops.aten.sub.Tensor(mul_830, mul_833); mul_830 = mul_833 = None + mul_834 = torch.ops.aten.mul.Tensor(sub_86, 
rsqrt_7); sub_86 = rsqrt_7 = None + mul_835 = torch.ops.aten.mul.Tensor(convert_element_type_2604, mul_28); convert_element_type_2604 = mul_28 = None + sum_172 = torch.ops.aten.sum.dim_IntList(mul_835, [0, 1]); mul_835 = None + convert_element_type_2607 = torch.ops.prims.convert_element_type.default(mul_834, torch.bfloat16); mul_834 = None + add_328 = torch.ops.aten.add.Tensor(add_324, convert_element_type_2607); add_324 = convert_element_type_2607 = None + convert_element_type_default_8 = torch.ops.prims.convert_element_type.default(sum_172, torch.float32); sum_172 = None + reduce_scatter_tensor_257 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_8, 'avg', 256, '0'); convert_element_type_default_8 = None + wait_tensor_548 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_257); reduce_scatter_tensor_257 = None + view_1773 = torch.ops.aten.view.default(add_328, [16384, 4096]) + permute_1265 = torch.ops.aten.permute.default(view_1773, [1, 0]) + mm_625 = torch.ops.aten.mm.default(permute_1265, view_125); permute_1265 = view_125 = None + permute_1267 = torch.ops.aten.permute.default(permute_40, [1, 0]); permute_40 = None + mm_626 = torch.ops.aten.mm.default(view_1773, permute_1267); view_1773 = permute_1267 = None + view_1774 = torch.ops.aten.view.default(mm_626, [2, 8192, 4096]); mm_626 = None + convert_element_type_2614 = torch.ops.prims.convert_element_type.default(mm_625, torch.float32); mm_625 = None + reduce_scatter_tensor_258 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2614, 'avg', 256, '0'); convert_element_type_2614 = None + wait_tensor_549 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_258); reduce_scatter_tensor_258 = None + view_1775 = torch.ops.aten.view.default(view_1774, [2, 8192, 32, 128]); view_1774 = None + permute_1269 = torch.ops.aten.permute.default(view_1775, [0, 2, 1, 3]); view_1775 = None + convert_element_type_100 = 
torch.ops.prims.convert_element_type.default(primals_31, torch.bfloat16); primals_31 = None + all_gather_into_tensor_28 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_100, 256, '0'); convert_element_type_100 = None + wait_tensor_28 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_28); all_gather_into_tensor_28 = None + convert_element_type_101 = torch.ops.prims.convert_element_type.default(add_11, torch.float32); add_11 = None + pow_7 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_101, 2) + mean_6 = torch.ops.aten.mean.dim(pow_7, [2], True); pow_7 = None + add_12 = torch.ops.aten.add.Scalar(mean_6, 1e-05); mean_6 = None + rsqrt_6 = torch.ops.aten.rsqrt.default(add_12); add_12 = None + mul_24 = torch.ops.aten.mul.Tensor(convert_element_type_101, rsqrt_6); convert_element_type_101 = None + mul_25 = torch.ops.aten.mul.Tensor(mul_24, wait_tensor_28) + convert_element_type_102 = torch.ops.prims.convert_element_type.default(mul_25, torch.bfloat16); mul_25 = None + view_105 = torch.ops.aten.view.default(convert_element_type_102, [16384, 4096]); convert_element_type_102 = None + view_106 = torch.ops.aten.view.default(mm_21, [2, 8192, 4096]); mm_21 = None + convert_element_type_106 = torch.ops.prims.convert_element_type.default(primals_33, torch.bfloat16); primals_33 = None + all_gather_into_tensor_30 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_106, 256, '0'); convert_element_type_106 = None + wait_tensor_30 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_30); all_gather_into_tensor_30 = None + permute_34 = torch.ops.aten.permute.default(wait_tensor_30, [1, 0]); wait_tensor_30 = None + mm_22 = torch.ops.aten.mm.default(view_105, permute_34) + view_109 = torch.ops.aten.view.default(mm_22, [2, 8192, 1024]); mm_22 = None + view_112 = torch.ops.aten.view.default(mm_23, [2, 8192, 1024]); mm_23 = None + view_113 = torch.ops.aten.view.default(view_106, [2, 
8192, -1, 128]); view_106 = None + view_114 = torch.ops.aten.view.default(view_109, [2, 8192, -1, 128]); view_109 = None + view_115 = torch.ops.aten.view.default(view_112, [2, 8192, -1, 128]); view_112 = None + convert_element_type_112 = torch.ops.prims.convert_element_type.default(view_113, torch.float32); view_113 = None + view_116 = torch.ops.aten.view.default(convert_element_type_112, [2, 8192, 32, -1, 2]); convert_element_type_112 = None + view_as_complex_6 = torch.ops.aten.view_as_complex.default(view_116); view_116 = None + convert_element_type_113 = torch.ops.prims.convert_element_type.default(view_114, torch.float32); view_114 = None + view_117 = torch.ops.aten.view.default(convert_element_type_113, [2, 8192, 8, -1, 2]); convert_element_type_113 = None + view_as_complex_7 = torch.ops.aten.view_as_complex.default(view_117); view_117 = None + mul_26 = torch.ops.aten.mul.Tensor(view_as_complex_6, view_16); view_as_complex_6 = None + view_as_real_6 = torch.ops.aten.view_as_real.default(mul_26); mul_26 = None + view_119 = torch.ops.aten.view.default(view_as_real_6, [2, 8192, 32, 128]); view_as_real_6 = None + mul_27 = torch.ops.aten.mul.Tensor(view_as_complex_7, view_16); view_as_complex_7 = None + view_as_real_7 = torch.ops.aten.view_as_real.default(mul_27); mul_27 = None + view_120 = torch.ops.aten.view.default(view_as_real_7, [2, 8192, 8, 128]); view_as_real_7 = None + convert_element_type_114 = torch.ops.prims.convert_element_type.default(view_119, torch.bfloat16); view_119 = None + convert_element_type_115 = torch.ops.prims.convert_element_type.default(view_120, torch.bfloat16); view_120 = None + unsqueeze_6 = torch.ops.aten.unsqueeze.default(convert_element_type_115, 3); convert_element_type_115 = None + expand_6 = torch.ops.aten.expand.default(unsqueeze_6, [2, 8192, 8, 4, 128]); unsqueeze_6 = None + clone_6 = torch.ops.aten.clone.default(expand_6, memory_format = torch.contiguous_format); expand_6 = None + view_121 = torch.ops.aten.view.default(clone_6, 
[2, 8192, 32, 128]); clone_6 = None + unsqueeze_7 = torch.ops.aten.unsqueeze.default(view_115, 3); view_115 = None + expand_7 = torch.ops.aten.expand.default(unsqueeze_7, [2, 8192, 8, 4, 128]); unsqueeze_7 = None + clone_7 = torch.ops.aten.clone.default(expand_7, memory_format = torch.contiguous_format); expand_7 = None + view_122 = torch.ops.aten.view.default(clone_7, [2, 8192, 32, 128]); clone_7 = None + permute_36 = torch.ops.aten.permute.default(convert_element_type_114, [0, 2, 1, 3]); convert_element_type_114 = None + permute_37 = torch.ops.aten.permute.default(view_121, [0, 2, 1, 3]); view_121 = None + permute_38 = torch.ops.aten.permute.default(view_122, [0, 2, 1, 3]); view_122 = None + _scaled_dot_product_cudnn_attention_backward_28 = torch.ops.aten._scaled_dot_product_cudnn_attention_backward.default(permute_1269, permute_36, permute_37, permute_38, getitem_27, getitem_28, getitem_33, getitem_34, None, None, None, 8192, 8192, 0.0, True); permute_1269 = permute_36 = permute_37 = permute_38 = getitem_27 = getitem_28 = getitem_33 = getitem_34 = None + getitem_372 = _scaled_dot_product_cudnn_attention_backward_28[0] + getitem_373 = _scaled_dot_product_cudnn_attention_backward_28[1] + getitem_374 = _scaled_dot_product_cudnn_attention_backward_28[2]; _scaled_dot_product_cudnn_attention_backward_28 = None + permute_1270 = torch.ops.aten.permute.default(getitem_374, [0, 2, 1, 3]); getitem_374 = None + permute_1271 = torch.ops.aten.permute.default(getitem_373, [0, 2, 1, 3]); getitem_373 = None + permute_1272 = torch.ops.aten.permute.default(getitem_372, [0, 2, 1, 3]); getitem_372 = None + view_1776 = torch.ops.aten.view.default(permute_1270, [2, 8192, 8, 4, 128]); permute_1270 = None + sum_173 = torch.ops.aten.sum.dim_IntList(view_1776, [3], True); view_1776 = None + squeeze_56 = torch.ops.aten.squeeze.dim(sum_173, 3); sum_173 = None + view_1777 = torch.ops.aten.view.default(permute_1271, [2, 8192, 8, 4, 128]); permute_1271 = None + sum_174 = 
torch.ops.aten.sum.dim_IntList(view_1777, [3], True); view_1777 = None + squeeze_57 = torch.ops.aten.squeeze.dim(sum_174, 3); sum_174 = None + convert_element_type_2615 = torch.ops.prims.convert_element_type.default(squeeze_57, torch.float32); squeeze_57 = None + convert_element_type_2616 = torch.ops.prims.convert_element_type.default(permute_1272, torch.float32); permute_1272 = None + view_1778 = torch.ops.aten.view.default(convert_element_type_2615, [2, 8192, 8, 64, 2]); convert_element_type_2615 = None + view_as_complex_120 = torch.ops.aten.view_as_complex.default(view_1778); view_1778 = None + mul_836 = torch.ops.aten.mul.Tensor(view_as_complex_120, _conj); view_as_complex_120 = None + view_1779 = torch.ops.aten.view.default(convert_element_type_2616, [2, 8192, 32, 64, 2]); convert_element_type_2616 = None + view_as_complex_121 = torch.ops.aten.view_as_complex.default(view_1779); view_1779 = None + mul_837 = torch.ops.aten.mul.Tensor(view_as_complex_121, _conj); view_as_complex_121 = None + view_as_real_120 = torch.ops.aten.view_as_real.default(mul_836); mul_836 = None + view_1780 = torch.ops.aten.view.default(view_as_real_120, [2, 8192, 8, 128]); view_as_real_120 = None + convert_element_type_2617 = torch.ops.prims.convert_element_type.default(view_1780, torch.bfloat16); view_1780 = None + view_as_real_121 = torch.ops.aten.view_as_real.default(mul_837); mul_837 = None + view_1781 = torch.ops.aten.view.default(view_as_real_121, [2, 8192, 32, 128]); view_as_real_121 = None + convert_element_type_2618 = torch.ops.prims.convert_element_type.default(view_1781, torch.bfloat16); view_1781 = None + view_1782 = torch.ops.aten.view.default(squeeze_56, [2, 8192, 1024]); squeeze_56 = None + view_1783 = torch.ops.aten.view.default(convert_element_type_2617, [2, 8192, 1024]); convert_element_type_2617 = None + view_1784 = torch.ops.aten.view.default(convert_element_type_2618, [2, 8192, 4096]); convert_element_type_2618 = None + view_1785 = 
torch.ops.aten.view.default(view_1782, [16384, 1024]); view_1782 = None + permute_1273 = torch.ops.aten.permute.default(view_1785, [1, 0]) + mm_627 = torch.ops.aten.mm.default(permute_1273, view_105); permute_1273 = None + convert_element_type_109 = torch.ops.prims.convert_element_type.default(primals_34, torch.bfloat16); primals_34 = None + all_gather_into_tensor_31 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_109, 256, '0'); convert_element_type_109 = None + wait_tensor_31 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_31); all_gather_into_tensor_31 = None + permute_35 = torch.ops.aten.permute.default(wait_tensor_31, [1, 0]); wait_tensor_31 = None + permute_1275 = torch.ops.aten.permute.default(permute_35, [1, 0]); permute_35 = None + mm_628 = torch.ops.aten.mm.default(view_1785, permute_1275); view_1785 = permute_1275 = None + view_1786 = torch.ops.aten.view.default(mm_628, [2, 8192, 4096]); mm_628 = None + convert_element_type_2623 = torch.ops.prims.convert_element_type.default(mm_627, torch.float32); mm_627 = None + reduce_scatter_tensor_259 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2623, 'avg', 256, '0'); convert_element_type_2623 = None + wait_tensor_550 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_259); reduce_scatter_tensor_259 = None + view_1787 = torch.ops.aten.view.default(view_1783, [16384, 1024]); view_1783 = None + permute_1277 = torch.ops.aten.permute.default(view_1787, [1, 0]) + mm_629 = torch.ops.aten.mm.default(permute_1277, view_105); permute_1277 = None + permute_1279 = torch.ops.aten.permute.default(permute_34, [1, 0]); permute_34 = None + mm_630 = torch.ops.aten.mm.default(view_1787, permute_1279); view_1787 = permute_1279 = None + view_1788 = torch.ops.aten.view.default(mm_630, [2, 8192, 4096]); mm_630 = None + add_329 = torch.ops.aten.add.Tensor(view_1786, view_1788); view_1786 = view_1788 = None + 
convert_element_type_2628 = torch.ops.prims.convert_element_type.default(mm_629, torch.float32); mm_629 = None + reduce_scatter_tensor_260 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2628, 'avg', 256, '0'); convert_element_type_2628 = None + wait_tensor_551 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_260); reduce_scatter_tensor_260 = None + view_1789 = torch.ops.aten.view.default(view_1784, [16384, 4096]); view_1784 = None + permute_1281 = torch.ops.aten.permute.default(view_1789, [1, 0]) + mm_631 = torch.ops.aten.mm.default(permute_1281, view_105); permute_1281 = view_105 = None + convert_element_type_103 = torch.ops.prims.convert_element_type.default(primals_32, torch.bfloat16); primals_32 = None + all_gather_into_tensor_29 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_103, 256, '0'); convert_element_type_103 = None + wait_tensor_29 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_29); all_gather_into_tensor_29 = None + permute_33 = torch.ops.aten.permute.default(wait_tensor_29, [1, 0]); wait_tensor_29 = None + permute_1283 = torch.ops.aten.permute.default(permute_33, [1, 0]); permute_33 = None + mm_632 = torch.ops.aten.mm.default(view_1789, permute_1283); view_1789 = permute_1283 = None + view_1790 = torch.ops.aten.view.default(mm_632, [2, 8192, 4096]); mm_632 = None + add_330 = torch.ops.aten.add.Tensor(add_329, view_1790); add_329 = view_1790 = None + convert_element_type_2633 = torch.ops.prims.convert_element_type.default(mm_631, torch.float32); mm_631 = None + reduce_scatter_tensor_261 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2633, 'avg', 256, '0'); convert_element_type_2633 = None + wait_tensor_552 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_261); reduce_scatter_tensor_261 = None + convert_element_type_2634 = torch.ops.prims.convert_element_type.default(add_330, 
torch.float32); add_330 = None + convert_element_type_2636 = torch.ops.prims.convert_element_type.default(wait_tensor_28, torch.float32); wait_tensor_28 = None + mul_838 = torch.ops.aten.mul.Tensor(convert_element_type_2634, convert_element_type_2636); convert_element_type_2636 = None + mul_840 = torch.ops.aten.mul.Tensor(mul_24, mul_838) + sum_175 = torch.ops.aten.sum.dim_IntList(mul_840, [2], True); mul_840 = None + div_58 = torch.ops.aten.div.Tensor(mul_24, 4096) + mul_841 = torch.ops.aten.mul.Tensor(div_58, sum_175); div_58 = sum_175 = None + sub_87 = torch.ops.aten.sub.Tensor(mul_838, mul_841); mul_838 = mul_841 = None + mul_842 = torch.ops.aten.mul.Tensor(sub_87, rsqrt_6); sub_87 = rsqrt_6 = None + mul_843 = torch.ops.aten.mul.Tensor(convert_element_type_2634, mul_24); convert_element_type_2634 = mul_24 = None + sum_176 = torch.ops.aten.sum.dim_IntList(mul_843, [0, 1]); mul_843 = None + convert_element_type_2637 = torch.ops.prims.convert_element_type.default(mul_842, torch.bfloat16); mul_842 = None + add_331 = torch.ops.aten.add.Tensor(add_328, convert_element_type_2637); add_328 = convert_element_type_2637 = None + convert_element_type_default_7 = torch.ops.prims.convert_element_type.default(sum_176, torch.float32); sum_176 = None + reduce_scatter_tensor_262 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_7, 'avg', 256, '0'); convert_element_type_default_7 = None + wait_tensor_553 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_262); reduce_scatter_tensor_262 = None + view_1791 = torch.ops.aten.view.default(add_331, [16384, 4096]) + permute_1285 = torch.ops.aten.permute.default(view_1791, [1, 0]) + permute_28 = torch.ops.aten.permute.default(getitem_18, [0, 2, 1, 3]) + view_89 = torch.ops.aten.view.default(permute_28, [2, 8192, -1]); permute_28 = None + convert_element_type_83 = torch.ops.prims.convert_element_type.default(primals_26, torch.bfloat16); primals_26 = None + all_gather_into_tensor_23 
= torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_83, 256, '0'); convert_element_type_83 = None + wait_tensor_23 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_23); all_gather_into_tensor_23 = None + permute_29 = torch.ops.aten.permute.default(wait_tensor_23, [1, 0]); wait_tensor_23 = None + view_91 = torch.ops.aten.view.default(view_89, [16384, 4096]); view_89 = None + mm_17 = torch.ops.aten.mm.default(view_91, permute_29) + view_92 = torch.ops.aten.view.default(mm_17, [2, 8192, 4096]); mm_17 = None + add_9 = torch.ops.aten.add.Tensor(add_7, view_92); view_92 = None + convert_element_type_86 = torch.ops.prims.convert_element_type.default(primals_27, torch.bfloat16); primals_27 = None + all_gather_into_tensor_24 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_86, 256, '0'); convert_element_type_86 = None + wait_tensor_24 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_24); all_gather_into_tensor_24 = None + convert_element_type_87 = torch.ops.prims.convert_element_type.default(add_9, torch.float32); add_9 = None + pow_6 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_87, 2) + mean_5 = torch.ops.aten.mean.dim(pow_6, [2], True); pow_6 = None + add_10 = torch.ops.aten.add.Scalar(mean_5, 1e-05); mean_5 = None + rsqrt_5 = torch.ops.aten.rsqrt.default(add_10); add_10 = None + mul_20 = torch.ops.aten.mul.Tensor(convert_element_type_87, rsqrt_5); convert_element_type_87 = None + mul_21 = torch.ops.aten.mul.Tensor(mul_20, wait_tensor_24) + convert_element_type_88 = torch.ops.prims.convert_element_type.default(mul_21, torch.bfloat16); mul_21 = None + view_95 = torch.ops.aten.view.default(convert_element_type_88, [16384, 4096]); convert_element_type_88 = None + view_96 = torch.ops.aten.view.default(mm_18, [2, 8192, 14336]); mm_18 = None + convert_element_type_92 = torch.ops.prims.convert_element_type.default(view_96, torch.float32); view_96 = None + 
sigmoid_2 = torch.ops.aten.sigmoid.default(convert_element_type_92) + mul_22 = torch.ops.aten.mul.Tensor(convert_element_type_92, sigmoid_2); sigmoid_2 = None + convert_element_type_93 = torch.ops.prims.convert_element_type.default(mul_22, torch.bfloat16); mul_22 = None + convert_element_type_94 = torch.ops.prims.convert_element_type.default(primals_29, torch.bfloat16); primals_29 = None + all_gather_into_tensor_26 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_94, 256, '0'); convert_element_type_94 = None + wait_tensor_26 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_26); all_gather_into_tensor_26 = None + permute_31 = torch.ops.aten.permute.default(wait_tensor_26, [1, 0]); wait_tensor_26 = None + mm_19 = torch.ops.aten.mm.default(view_95, permute_31) + view_99 = torch.ops.aten.view.default(mm_19, [2, 8192, 14336]); mm_19 = None + mul_23 = torch.ops.aten.mul.Tensor(convert_element_type_93, view_99) + view_101 = torch.ops.aten.view.default(mul_23, [16384, 14336]); mul_23 = None + mm_633 = torch.ops.aten.mm.default(permute_1285, view_101); permute_1285 = view_101 = None + convert_element_type_97 = torch.ops.prims.convert_element_type.default(primals_30, torch.bfloat16); primals_30 = None + all_gather_into_tensor_27 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_97, 256, '0'); convert_element_type_97 = None + wait_tensor_27 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_27); all_gather_into_tensor_27 = None + permute_32 = torch.ops.aten.permute.default(wait_tensor_27, [1, 0]); wait_tensor_27 = None + permute_1287 = torch.ops.aten.permute.default(permute_32, [1, 0]); permute_32 = None + mm_634 = torch.ops.aten.mm.default(view_1791, permute_1287); view_1791 = permute_1287 = None + view_1792 = torch.ops.aten.view.default(mm_634, [2, 8192, 14336]); mm_634 = None + convert_element_type_2644 = torch.ops.prims.convert_element_type.default(mm_633, 
torch.float32); mm_633 = None + reduce_scatter_tensor_263 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2644, 'avg', 256, '0'); convert_element_type_2644 = None + wait_tensor_554 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_263); reduce_scatter_tensor_263 = None + mul_844 = torch.ops.aten.mul.Tensor(view_1792, convert_element_type_93); convert_element_type_93 = None + mul_845 = torch.ops.aten.mul.Tensor(view_1792, view_99); view_1792 = view_99 = None + view_1793 = torch.ops.aten.view.default(mul_844, [16384, 14336]); mul_844 = None + permute_1289 = torch.ops.aten.permute.default(view_1793, [1, 0]) + mm_635 = torch.ops.aten.mm.default(permute_1289, view_95); permute_1289 = None + permute_1291 = torch.ops.aten.permute.default(permute_31, [1, 0]); permute_31 = None + mm_636 = torch.ops.aten.mm.default(view_1793, permute_1291); view_1793 = permute_1291 = None + view_1794 = torch.ops.aten.view.default(mm_636, [2, 8192, 4096]); mm_636 = None + convert_element_type_2649 = torch.ops.prims.convert_element_type.default(mm_635, torch.float32); mm_635 = None + reduce_scatter_tensor_264 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2649, 'avg', 256, '0'); convert_element_type_2649 = None + wait_tensor_555 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_264); reduce_scatter_tensor_264 = None + convert_element_type_2650 = torch.ops.prims.convert_element_type.default(mul_845, torch.float32); mul_845 = None + neg_29 = torch.ops.aten.neg.default(convert_element_type_92) + exp_29 = torch.ops.aten.exp.default(neg_29); neg_29 = None + add_332 = torch.ops.aten.add.Tensor(exp_29, 1); exp_29 = None + reciprocal_29 = torch.ops.aten.reciprocal.default(add_332); add_332 = None + mul_846 = torch.ops.aten.mul.Tensor(reciprocal_29, 1); reciprocal_29 = None + mul_847 = torch.ops.aten.mul.Tensor(convert_element_type_2650, mul_846); convert_element_type_2650 = None + sub_88 = 
torch.ops.aten.sub.Tensor(1, mul_846); mul_846 = None + mul_848 = torch.ops.aten.mul.Tensor(convert_element_type_92, sub_88); convert_element_type_92 = sub_88 = None + add_333 = torch.ops.aten.add.Tensor(mul_848, 1); mul_848 = None + mul_849 = torch.ops.aten.mul.Tensor(mul_847, add_333); mul_847 = add_333 = None + convert_element_type_2652 = torch.ops.prims.convert_element_type.default(mul_849, torch.bfloat16); mul_849 = None + view_1795 = torch.ops.aten.view.default(convert_element_type_2652, [16384, 14336]); convert_element_type_2652 = None + permute_1293 = torch.ops.aten.permute.default(view_1795, [1, 0]) + mm_637 = torch.ops.aten.mm.default(permute_1293, view_95); permute_1293 = view_95 = None + convert_element_type_89 = torch.ops.prims.convert_element_type.default(primals_28, torch.bfloat16); primals_28 = None + all_gather_into_tensor_25 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_89, 256, '0'); convert_element_type_89 = None + wait_tensor_25 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_25); all_gather_into_tensor_25 = None + permute_30 = torch.ops.aten.permute.default(wait_tensor_25, [1, 0]); wait_tensor_25 = None + permute_1295 = torch.ops.aten.permute.default(permute_30, [1, 0]); permute_30 = None + mm_638 = torch.ops.aten.mm.default(view_1795, permute_1295); view_1795 = permute_1295 = None + view_1796 = torch.ops.aten.view.default(mm_638, [2, 8192, 4096]); mm_638 = None + add_334 = torch.ops.aten.add.Tensor(view_1794, view_1796); view_1794 = view_1796 = None + convert_element_type_2657 = torch.ops.prims.convert_element_type.default(mm_637, torch.float32); mm_637 = None + reduce_scatter_tensor_265 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2657, 'avg', 256, '0'); convert_element_type_2657 = None + wait_tensor_556 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_265); reduce_scatter_tensor_265 = None + convert_element_type_2658 = 
torch.ops.prims.convert_element_type.default(add_334, torch.float32); add_334 = None + convert_element_type_2660 = torch.ops.prims.convert_element_type.default(wait_tensor_24, torch.float32); wait_tensor_24 = None + mul_850 = torch.ops.aten.mul.Tensor(convert_element_type_2658, convert_element_type_2660); convert_element_type_2660 = None + mul_852 = torch.ops.aten.mul.Tensor(mul_20, mul_850) + sum_177 = torch.ops.aten.sum.dim_IntList(mul_852, [2], True); mul_852 = None + div_59 = torch.ops.aten.div.Tensor(mul_20, 4096) + mul_853 = torch.ops.aten.mul.Tensor(div_59, sum_177); div_59 = sum_177 = None + sub_89 = torch.ops.aten.sub.Tensor(mul_850, mul_853); mul_850 = mul_853 = None + mul_854 = torch.ops.aten.mul.Tensor(sub_89, rsqrt_5); sub_89 = rsqrt_5 = None + mul_855 = torch.ops.aten.mul.Tensor(convert_element_type_2658, mul_20); convert_element_type_2658 = mul_20 = None + sum_178 = torch.ops.aten.sum.dim_IntList(mul_855, [0, 1]); mul_855 = None + convert_element_type_2661 = torch.ops.prims.convert_element_type.default(mul_854, torch.bfloat16); mul_854 = None + add_335 = torch.ops.aten.add.Tensor(add_331, convert_element_type_2661); add_331 = convert_element_type_2661 = None + convert_element_type_default_6 = torch.ops.prims.convert_element_type.default(sum_178, torch.float32); sum_178 = None + reduce_scatter_tensor_266 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_6, 'avg', 256, '0'); convert_element_type_default_6 = None + wait_tensor_557 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_266); reduce_scatter_tensor_266 = None + view_1797 = torch.ops.aten.view.default(add_335, [16384, 4096]) + permute_1297 = torch.ops.aten.permute.default(view_1797, [1, 0]) + mm_639 = torch.ops.aten.mm.default(permute_1297, view_91); permute_1297 = view_91 = None + permute_1299 = torch.ops.aten.permute.default(permute_29, [1, 0]); permute_29 = None + mm_640 = torch.ops.aten.mm.default(view_1797, permute_1299); view_1797 
= permute_1299 = None + view_1798 = torch.ops.aten.view.default(mm_640, [2, 8192, 4096]); mm_640 = None + convert_element_type_2668 = torch.ops.prims.convert_element_type.default(mm_639, torch.float32); mm_639 = None + reduce_scatter_tensor_267 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2668, 'avg', 256, '0'); convert_element_type_2668 = None + wait_tensor_558 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_267); reduce_scatter_tensor_267 = None + view_1799 = torch.ops.aten.view.default(view_1798, [2, 8192, 32, 128]); view_1798 = None + permute_1301 = torch.ops.aten.permute.default(view_1799, [0, 2, 1, 3]); view_1799 = None + convert_element_type_67 = torch.ops.prims.convert_element_type.default(primals_22, torch.bfloat16); primals_22 = None + all_gather_into_tensor_19 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_67, 256, '0'); convert_element_type_67 = None + wait_tensor_19 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_19); all_gather_into_tensor_19 = None + convert_element_type_68 = torch.ops.prims.convert_element_type.default(add_7, torch.float32); add_7 = None + pow_5 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_68, 2) + mean_4 = torch.ops.aten.mean.dim(pow_5, [2], True); pow_5 = None + add_8 = torch.ops.aten.add.Scalar(mean_4, 1e-05); mean_4 = None + rsqrt_4 = torch.ops.aten.rsqrt.default(add_8); add_8 = None + mul_16 = torch.ops.aten.mul.Tensor(convert_element_type_68, rsqrt_4); convert_element_type_68 = None + mul_17 = torch.ops.aten.mul.Tensor(mul_16, wait_tensor_19) + convert_element_type_69 = torch.ops.prims.convert_element_type.default(mul_17, torch.bfloat16); mul_17 = None + view_71 = torch.ops.aten.view.default(convert_element_type_69, [16384, 4096]); convert_element_type_69 = None + view_72 = torch.ops.aten.view.default(mm_14, [2, 8192, 4096]); mm_14 = None + convert_element_type_73 = 
torch.ops.prims.convert_element_type.default(primals_24, torch.bfloat16); primals_24 = None + all_gather_into_tensor_21 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_73, 256, '0'); convert_element_type_73 = None + wait_tensor_21 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_21); all_gather_into_tensor_21 = None + permute_23 = torch.ops.aten.permute.default(wait_tensor_21, [1, 0]); wait_tensor_21 = None + mm_15 = torch.ops.aten.mm.default(view_71, permute_23) + view_75 = torch.ops.aten.view.default(mm_15, [2, 8192, 1024]); mm_15 = None + view_78 = torch.ops.aten.view.default(mm_16, [2, 8192, 1024]); mm_16 = None + view_79 = torch.ops.aten.view.default(view_72, [2, 8192, -1, 128]); view_72 = None + view_80 = torch.ops.aten.view.default(view_75, [2, 8192, -1, 128]); view_75 = None + view_81 = torch.ops.aten.view.default(view_78, [2, 8192, -1, 128]); view_78 = None + convert_element_type_79 = torch.ops.prims.convert_element_type.default(view_79, torch.float32); view_79 = None + view_82 = torch.ops.aten.view.default(convert_element_type_79, [2, 8192, 32, -1, 2]); convert_element_type_79 = None + view_as_complex_4 = torch.ops.aten.view_as_complex.default(view_82); view_82 = None + convert_element_type_80 = torch.ops.prims.convert_element_type.default(view_80, torch.float32); view_80 = None + view_83 = torch.ops.aten.view.default(convert_element_type_80, [2, 8192, 8, -1, 2]); convert_element_type_80 = None + view_as_complex_5 = torch.ops.aten.view_as_complex.default(view_83); view_83 = None + mul_18 = torch.ops.aten.mul.Tensor(view_as_complex_4, view_16); view_as_complex_4 = None + view_as_real_4 = torch.ops.aten.view_as_real.default(mul_18); mul_18 = None + view_85 = torch.ops.aten.view.default(view_as_real_4, [2, 8192, 32, 128]); view_as_real_4 = None + mul_19 = torch.ops.aten.mul.Tensor(view_as_complex_5, view_16); view_as_complex_5 = None + view_as_real_5 = torch.ops.aten.view_as_real.default(mul_19); mul_19 
= None + view_86 = torch.ops.aten.view.default(view_as_real_5, [2, 8192, 8, 128]); view_as_real_5 = None + convert_element_type_81 = torch.ops.prims.convert_element_type.default(view_85, torch.bfloat16); view_85 = None + convert_element_type_82 = torch.ops.prims.convert_element_type.default(view_86, torch.bfloat16); view_86 = None + unsqueeze_4 = torch.ops.aten.unsqueeze.default(convert_element_type_82, 3); convert_element_type_82 = None + expand_4 = torch.ops.aten.expand.default(unsqueeze_4, [2, 8192, 8, 4, 128]); unsqueeze_4 = None + clone_4 = torch.ops.aten.clone.default(expand_4, memory_format = torch.contiguous_format); expand_4 = None + view_87 = torch.ops.aten.view.default(clone_4, [2, 8192, 32, 128]); clone_4 = None + unsqueeze_5 = torch.ops.aten.unsqueeze.default(view_81, 3); view_81 = None + expand_5 = torch.ops.aten.expand.default(unsqueeze_5, [2, 8192, 8, 4, 128]); unsqueeze_5 = None + clone_5 = torch.ops.aten.clone.default(expand_5, memory_format = torch.contiguous_format); expand_5 = None + view_88 = torch.ops.aten.view.default(clone_5, [2, 8192, 32, 128]); clone_5 = None + permute_25 = torch.ops.aten.permute.default(convert_element_type_81, [0, 2, 1, 3]); convert_element_type_81 = None + permute_26 = torch.ops.aten.permute.default(view_87, [0, 2, 1, 3]); view_87 = None + permute_27 = torch.ops.aten.permute.default(view_88, [0, 2, 1, 3]); view_88 = None + _scaled_dot_product_cudnn_attention_backward_29 = torch.ops.aten._scaled_dot_product_cudnn_attention_backward.default(permute_1301, permute_25, permute_26, permute_27, getitem_18, getitem_19, getitem_24, getitem_25, None, None, None, 8192, 8192, 0.0, True); permute_1301 = permute_25 = permute_26 = permute_27 = getitem_18 = getitem_19 = getitem_24 = getitem_25 = None + getitem_375 = _scaled_dot_product_cudnn_attention_backward_29[0] + getitem_376 = _scaled_dot_product_cudnn_attention_backward_29[1] + getitem_377 = _scaled_dot_product_cudnn_attention_backward_29[2]; 
_scaled_dot_product_cudnn_attention_backward_29 = None + permute_1302 = torch.ops.aten.permute.default(getitem_377, [0, 2, 1, 3]); getitem_377 = None + permute_1303 = torch.ops.aten.permute.default(getitem_376, [0, 2, 1, 3]); getitem_376 = None + permute_1304 = torch.ops.aten.permute.default(getitem_375, [0, 2, 1, 3]); getitem_375 = None + view_1800 = torch.ops.aten.view.default(permute_1302, [2, 8192, 8, 4, 128]); permute_1302 = None + sum_179 = torch.ops.aten.sum.dim_IntList(view_1800, [3], True); view_1800 = None + squeeze_58 = torch.ops.aten.squeeze.dim(sum_179, 3); sum_179 = None + view_1801 = torch.ops.aten.view.default(permute_1303, [2, 8192, 8, 4, 128]); permute_1303 = None + sum_180 = torch.ops.aten.sum.dim_IntList(view_1801, [3], True); view_1801 = None + squeeze_59 = torch.ops.aten.squeeze.dim(sum_180, 3); sum_180 = None + convert_element_type_2669 = torch.ops.prims.convert_element_type.default(squeeze_59, torch.float32); squeeze_59 = None + convert_element_type_2670 = torch.ops.prims.convert_element_type.default(permute_1304, torch.float32); permute_1304 = None + view_1802 = torch.ops.aten.view.default(convert_element_type_2669, [2, 8192, 8, 64, 2]); convert_element_type_2669 = None + view_as_complex_122 = torch.ops.aten.view_as_complex.default(view_1802); view_1802 = None + mul_856 = torch.ops.aten.mul.Tensor(view_as_complex_122, _conj); view_as_complex_122 = None + view_1803 = torch.ops.aten.view.default(convert_element_type_2670, [2, 8192, 32, 64, 2]); convert_element_type_2670 = None + view_as_complex_123 = torch.ops.aten.view_as_complex.default(view_1803); view_1803 = None + mul_857 = torch.ops.aten.mul.Tensor(view_as_complex_123, _conj); view_as_complex_123 = None + view_as_real_122 = torch.ops.aten.view_as_real.default(mul_856); mul_856 = None + view_1804 = torch.ops.aten.view.default(view_as_real_122, [2, 8192, 8, 128]); view_as_real_122 = None + convert_element_type_2671 = torch.ops.prims.convert_element_type.default(view_1804, torch.bfloat16); 
view_1804 = None + view_as_real_123 = torch.ops.aten.view_as_real.default(mul_857); mul_857 = None + view_1805 = torch.ops.aten.view.default(view_as_real_123, [2, 8192, 32, 128]); view_as_real_123 = None + convert_element_type_2672 = torch.ops.prims.convert_element_type.default(view_1805, torch.bfloat16); view_1805 = None + view_1806 = torch.ops.aten.view.default(squeeze_58, [2, 8192, 1024]); squeeze_58 = None + view_1807 = torch.ops.aten.view.default(convert_element_type_2671, [2, 8192, 1024]); convert_element_type_2671 = None + view_1808 = torch.ops.aten.view.default(convert_element_type_2672, [2, 8192, 4096]); convert_element_type_2672 = None + view_1809 = torch.ops.aten.view.default(view_1806, [16384, 1024]); view_1806 = None + permute_1305 = torch.ops.aten.permute.default(view_1809, [1, 0]) + mm_641 = torch.ops.aten.mm.default(permute_1305, view_71); permute_1305 = None + convert_element_type_76 = torch.ops.prims.convert_element_type.default(primals_25, torch.bfloat16); primals_25 = None + all_gather_into_tensor_22 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_76, 256, '0'); convert_element_type_76 = None + wait_tensor_22 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_22); all_gather_into_tensor_22 = None + permute_24 = torch.ops.aten.permute.default(wait_tensor_22, [1, 0]); wait_tensor_22 = None + permute_1307 = torch.ops.aten.permute.default(permute_24, [1, 0]); permute_24 = None + mm_642 = torch.ops.aten.mm.default(view_1809, permute_1307); view_1809 = permute_1307 = None + view_1810 = torch.ops.aten.view.default(mm_642, [2, 8192, 4096]); mm_642 = None + convert_element_type_2677 = torch.ops.prims.convert_element_type.default(mm_641, torch.float32); mm_641 = None + reduce_scatter_tensor_268 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2677, 'avg', 256, '0'); convert_element_type_2677 = None + wait_tensor_559 = 
torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_268); reduce_scatter_tensor_268 = None + view_1811 = torch.ops.aten.view.default(view_1807, [16384, 1024]); view_1807 = None + permute_1309 = torch.ops.aten.permute.default(view_1811, [1, 0]) + mm_643 = torch.ops.aten.mm.default(permute_1309, view_71); permute_1309 = None + permute_1311 = torch.ops.aten.permute.default(permute_23, [1, 0]); permute_23 = None + mm_644 = torch.ops.aten.mm.default(view_1811, permute_1311); view_1811 = permute_1311 = None + view_1812 = torch.ops.aten.view.default(mm_644, [2, 8192, 4096]); mm_644 = None + add_336 = torch.ops.aten.add.Tensor(view_1810, view_1812); view_1810 = view_1812 = None + convert_element_type_2682 = torch.ops.prims.convert_element_type.default(mm_643, torch.float32); mm_643 = None + reduce_scatter_tensor_269 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2682, 'avg', 256, '0'); convert_element_type_2682 = None + wait_tensor_560 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_269); reduce_scatter_tensor_269 = None + view_1813 = torch.ops.aten.view.default(view_1808, [16384, 4096]); view_1808 = None + permute_1313 = torch.ops.aten.permute.default(view_1813, [1, 0]) + mm_645 = torch.ops.aten.mm.default(permute_1313, view_71); permute_1313 = view_71 = None + convert_element_type_70 = torch.ops.prims.convert_element_type.default(primals_23, torch.bfloat16); primals_23 = None + all_gather_into_tensor_20 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_70, 256, '0'); convert_element_type_70 = None + wait_tensor_20 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_20); all_gather_into_tensor_20 = None + permute_22 = torch.ops.aten.permute.default(wait_tensor_20, [1, 0]); wait_tensor_20 = None + permute_1315 = torch.ops.aten.permute.default(permute_22, [1, 0]); permute_22 = None + mm_646 = torch.ops.aten.mm.default(view_1813, permute_1315); 
view_1813 = permute_1315 = None + view_1814 = torch.ops.aten.view.default(mm_646, [2, 8192, 4096]); mm_646 = None + add_337 = torch.ops.aten.add.Tensor(add_336, view_1814); add_336 = view_1814 = None + convert_element_type_2687 = torch.ops.prims.convert_element_type.default(mm_645, torch.float32); mm_645 = None + reduce_scatter_tensor_270 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2687, 'avg', 256, '0'); convert_element_type_2687 = None + wait_tensor_561 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_270); reduce_scatter_tensor_270 = None + convert_element_type_2688 = torch.ops.prims.convert_element_type.default(add_337, torch.float32); add_337 = None + convert_element_type_2690 = torch.ops.prims.convert_element_type.default(wait_tensor_19, torch.float32); wait_tensor_19 = None + mul_858 = torch.ops.aten.mul.Tensor(convert_element_type_2688, convert_element_type_2690); convert_element_type_2690 = None + mul_860 = torch.ops.aten.mul.Tensor(mul_16, mul_858) + sum_181 = torch.ops.aten.sum.dim_IntList(mul_860, [2], True); mul_860 = None + div_60 = torch.ops.aten.div.Tensor(mul_16, 4096) + mul_861 = torch.ops.aten.mul.Tensor(div_60, sum_181); div_60 = sum_181 = None + sub_90 = torch.ops.aten.sub.Tensor(mul_858, mul_861); mul_858 = mul_861 = None + mul_862 = torch.ops.aten.mul.Tensor(sub_90, rsqrt_4); sub_90 = rsqrt_4 = None + mul_863 = torch.ops.aten.mul.Tensor(convert_element_type_2688, mul_16); convert_element_type_2688 = mul_16 = None + sum_182 = torch.ops.aten.sum.dim_IntList(mul_863, [0, 1]); mul_863 = None + convert_element_type_2691 = torch.ops.prims.convert_element_type.default(mul_862, torch.bfloat16); mul_862 = None + add_338 = torch.ops.aten.add.Tensor(add_335, convert_element_type_2691); add_335 = convert_element_type_2691 = None + convert_element_type_default_5 = torch.ops.prims.convert_element_type.default(sum_182, torch.float32); sum_182 = None + reduce_scatter_tensor_271 = 
torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_5, 'avg', 256, '0'); convert_element_type_default_5 = None + wait_tensor_562 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_271); reduce_scatter_tensor_271 = None + view_1815 = torch.ops.aten.view.default(add_338, [16384, 4096]) + permute_1317 = torch.ops.aten.permute.default(view_1815, [1, 0]) + permute_17 = torch.ops.aten.permute.default(getitem_9, [0, 2, 1, 3]) + view_55 = torch.ops.aten.view.default(permute_17, [2, 8192, -1]); permute_17 = None + convert_element_type_50 = torch.ops.prims.convert_element_type.default(primals_17, torch.bfloat16); primals_17 = None + all_gather_into_tensor_14 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_50, 256, '0'); convert_element_type_50 = None + wait_tensor_14 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_14); all_gather_into_tensor_14 = None + permute_18 = torch.ops.aten.permute.default(wait_tensor_14, [1, 0]); wait_tensor_14 = None + view_57 = torch.ops.aten.view.default(view_55, [16384, 4096]); view_55 = None + mm_10 = torch.ops.aten.mm.default(view_57, permute_18) + view_58 = torch.ops.aten.view.default(mm_10, [2, 8192, 4096]); mm_10 = None + add_5 = torch.ops.aten.add.Tensor(add_3, view_58); view_58 = None + convert_element_type_53 = torch.ops.prims.convert_element_type.default(primals_18, torch.bfloat16); primals_18 = None + all_gather_into_tensor_15 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_53, 256, '0'); convert_element_type_53 = None + wait_tensor_15 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_15); all_gather_into_tensor_15 = None + convert_element_type_54 = torch.ops.prims.convert_element_type.default(add_5, torch.float32); add_5 = None + pow_4 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_54, 2) + mean_3 = torch.ops.aten.mean.dim(pow_4, [2], True); pow_4 = None + add_6 = 
torch.ops.aten.add.Scalar(mean_3, 1e-05); mean_3 = None + rsqrt_3 = torch.ops.aten.rsqrt.default(add_6); add_6 = None + mul_12 = torch.ops.aten.mul.Tensor(convert_element_type_54, rsqrt_3); convert_element_type_54 = None + mul_13 = torch.ops.aten.mul.Tensor(mul_12, wait_tensor_15) + convert_element_type_55 = torch.ops.prims.convert_element_type.default(mul_13, torch.bfloat16); mul_13 = None + view_61 = torch.ops.aten.view.default(convert_element_type_55, [16384, 4096]); convert_element_type_55 = None + view_62 = torch.ops.aten.view.default(mm_11, [2, 8192, 14336]); mm_11 = None + convert_element_type_59 = torch.ops.prims.convert_element_type.default(view_62, torch.float32); view_62 = None + sigmoid_1 = torch.ops.aten.sigmoid.default(convert_element_type_59) + mul_14 = torch.ops.aten.mul.Tensor(convert_element_type_59, sigmoid_1); sigmoid_1 = None + convert_element_type_60 = torch.ops.prims.convert_element_type.default(mul_14, torch.bfloat16); mul_14 = None + convert_element_type_61 = torch.ops.prims.convert_element_type.default(primals_20, torch.bfloat16); primals_20 = None + all_gather_into_tensor_17 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_61, 256, '0'); convert_element_type_61 = None + wait_tensor_17 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_17); all_gather_into_tensor_17 = None + permute_20 = torch.ops.aten.permute.default(wait_tensor_17, [1, 0]); wait_tensor_17 = None + mm_12 = torch.ops.aten.mm.default(view_61, permute_20) + view_65 = torch.ops.aten.view.default(mm_12, [2, 8192, 14336]); mm_12 = None + mul_15 = torch.ops.aten.mul.Tensor(convert_element_type_60, view_65) + view_67 = torch.ops.aten.view.default(mul_15, [16384, 14336]); mul_15 = None + mm_647 = torch.ops.aten.mm.default(permute_1317, view_67); permute_1317 = view_67 = None + convert_element_type_64 = torch.ops.prims.convert_element_type.default(primals_21, torch.bfloat16); primals_21 = None + all_gather_into_tensor_18 = 
torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_64, 256, '0'); convert_element_type_64 = None + wait_tensor_18 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_18); all_gather_into_tensor_18 = None + permute_21 = torch.ops.aten.permute.default(wait_tensor_18, [1, 0]); wait_tensor_18 = None + permute_1319 = torch.ops.aten.permute.default(permute_21, [1, 0]); permute_21 = None + mm_648 = torch.ops.aten.mm.default(view_1815, permute_1319); view_1815 = permute_1319 = None + view_1816 = torch.ops.aten.view.default(mm_648, [2, 8192, 14336]); mm_648 = None + convert_element_type_2698 = torch.ops.prims.convert_element_type.default(mm_647, torch.float32); mm_647 = None + reduce_scatter_tensor_272 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2698, 'avg', 256, '0'); convert_element_type_2698 = None + wait_tensor_563 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_272); reduce_scatter_tensor_272 = None + mul_864 = torch.ops.aten.mul.Tensor(view_1816, convert_element_type_60); convert_element_type_60 = None + mul_865 = torch.ops.aten.mul.Tensor(view_1816, view_65); view_1816 = view_65 = None + view_1817 = torch.ops.aten.view.default(mul_864, [16384, 14336]); mul_864 = None + permute_1321 = torch.ops.aten.permute.default(view_1817, [1, 0]) + mm_649 = torch.ops.aten.mm.default(permute_1321, view_61); permute_1321 = None + permute_1323 = torch.ops.aten.permute.default(permute_20, [1, 0]); permute_20 = None + mm_650 = torch.ops.aten.mm.default(view_1817, permute_1323); view_1817 = permute_1323 = None + view_1818 = torch.ops.aten.view.default(mm_650, [2, 8192, 4096]); mm_650 = None + convert_element_type_2703 = torch.ops.prims.convert_element_type.default(mm_649, torch.float32); mm_649 = None + reduce_scatter_tensor_273 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2703, 'avg', 256, '0'); convert_element_type_2703 = None + wait_tensor_564 
= torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_273); reduce_scatter_tensor_273 = None + convert_element_type_2704 = torch.ops.prims.convert_element_type.default(mul_865, torch.float32); mul_865 = None + neg_30 = torch.ops.aten.neg.default(convert_element_type_59) + exp_30 = torch.ops.aten.exp.default(neg_30); neg_30 = None + add_339 = torch.ops.aten.add.Tensor(exp_30, 1); exp_30 = None + reciprocal_30 = torch.ops.aten.reciprocal.default(add_339); add_339 = None + mul_866 = torch.ops.aten.mul.Tensor(reciprocal_30, 1); reciprocal_30 = None + mul_867 = torch.ops.aten.mul.Tensor(convert_element_type_2704, mul_866); convert_element_type_2704 = None + sub_91 = torch.ops.aten.sub.Tensor(1, mul_866); mul_866 = None + mul_868 = torch.ops.aten.mul.Tensor(convert_element_type_59, sub_91); convert_element_type_59 = sub_91 = None + add_340 = torch.ops.aten.add.Tensor(mul_868, 1); mul_868 = None + mul_869 = torch.ops.aten.mul.Tensor(mul_867, add_340); mul_867 = add_340 = None + convert_element_type_2706 = torch.ops.prims.convert_element_type.default(mul_869, torch.bfloat16); mul_869 = None + view_1819 = torch.ops.aten.view.default(convert_element_type_2706, [16384, 14336]); convert_element_type_2706 = None + permute_1325 = torch.ops.aten.permute.default(view_1819, [1, 0]) + mm_651 = torch.ops.aten.mm.default(permute_1325, view_61); permute_1325 = view_61 = None + convert_element_type_56 = torch.ops.prims.convert_element_type.default(primals_19, torch.bfloat16); primals_19 = None + all_gather_into_tensor_16 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_56, 256, '0'); convert_element_type_56 = None + wait_tensor_16 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_16); all_gather_into_tensor_16 = None + permute_19 = torch.ops.aten.permute.default(wait_tensor_16, [1, 0]); wait_tensor_16 = None + permute_1327 = torch.ops.aten.permute.default(permute_19, [1, 0]); permute_19 = None + mm_652 = 
torch.ops.aten.mm.default(view_1819, permute_1327); view_1819 = permute_1327 = None + view_1820 = torch.ops.aten.view.default(mm_652, [2, 8192, 4096]); mm_652 = None + add_341 = torch.ops.aten.add.Tensor(view_1818, view_1820); view_1818 = view_1820 = None + convert_element_type_2711 = torch.ops.prims.convert_element_type.default(mm_651, torch.float32); mm_651 = None + reduce_scatter_tensor_274 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2711, 'avg', 256, '0'); convert_element_type_2711 = None + wait_tensor_565 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_274); reduce_scatter_tensor_274 = None + convert_element_type_2712 = torch.ops.prims.convert_element_type.default(add_341, torch.float32); add_341 = None + convert_element_type_2714 = torch.ops.prims.convert_element_type.default(wait_tensor_15, torch.float32); wait_tensor_15 = None + mul_870 = torch.ops.aten.mul.Tensor(convert_element_type_2712, convert_element_type_2714); convert_element_type_2714 = None + mul_872 = torch.ops.aten.mul.Tensor(mul_12, mul_870) + sum_183 = torch.ops.aten.sum.dim_IntList(mul_872, [2], True); mul_872 = None + div_61 = torch.ops.aten.div.Tensor(mul_12, 4096) + mul_873 = torch.ops.aten.mul.Tensor(div_61, sum_183); div_61 = sum_183 = None + sub_92 = torch.ops.aten.sub.Tensor(mul_870, mul_873); mul_870 = mul_873 = None + mul_874 = torch.ops.aten.mul.Tensor(sub_92, rsqrt_3); sub_92 = rsqrt_3 = None + mul_875 = torch.ops.aten.mul.Tensor(convert_element_type_2712, mul_12); convert_element_type_2712 = mul_12 = None + sum_184 = torch.ops.aten.sum.dim_IntList(mul_875, [0, 1]); mul_875 = None + convert_element_type_2715 = torch.ops.prims.convert_element_type.default(mul_874, torch.bfloat16); mul_874 = None + add_342 = torch.ops.aten.add.Tensor(add_338, convert_element_type_2715); add_338 = convert_element_type_2715 = None + convert_element_type_default_4 = torch.ops.prims.convert_element_type.default(sum_184, torch.float32); sum_184 = 
None + reduce_scatter_tensor_275 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_4, 'avg', 256, '0'); convert_element_type_default_4 = None + wait_tensor_566 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_275); reduce_scatter_tensor_275 = None + view_1821 = torch.ops.aten.view.default(add_342, [16384, 4096]) + permute_1329 = torch.ops.aten.permute.default(view_1821, [1, 0]) + mm_653 = torch.ops.aten.mm.default(permute_1329, view_57); permute_1329 = view_57 = None + permute_1331 = torch.ops.aten.permute.default(permute_18, [1, 0]); permute_18 = None + mm_654 = torch.ops.aten.mm.default(view_1821, permute_1331); view_1821 = permute_1331 = None + view_1822 = torch.ops.aten.view.default(mm_654, [2, 8192, 4096]); mm_654 = None + convert_element_type_2722 = torch.ops.prims.convert_element_type.default(mm_653, torch.float32); mm_653 = None + reduce_scatter_tensor_276 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2722, 'avg', 256, '0'); convert_element_type_2722 = None + wait_tensor_567 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_276); reduce_scatter_tensor_276 = None + view_1823 = torch.ops.aten.view.default(view_1822, [2, 8192, 32, 128]); view_1822 = None + permute_1333 = torch.ops.aten.permute.default(view_1823, [0, 2, 1, 3]); view_1823 = None + convert_element_type_34 = torch.ops.prims.convert_element_type.default(primals_13, torch.bfloat16); primals_13 = None + all_gather_into_tensor_10 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_34, 256, '0'); convert_element_type_34 = None + wait_tensor_10 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_10); all_gather_into_tensor_10 = None + convert_element_type_35 = torch.ops.prims.convert_element_type.default(add_3, torch.float32); add_3 = None + pow_3 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_35, 2) + mean_2 = 
torch.ops.aten.mean.dim(pow_3, [2], True); pow_3 = None + add_4 = torch.ops.aten.add.Scalar(mean_2, 1e-05); mean_2 = None + rsqrt_2 = torch.ops.aten.rsqrt.default(add_4); add_4 = None + mul_8 = torch.ops.aten.mul.Tensor(convert_element_type_35, rsqrt_2); convert_element_type_35 = None + mul_9 = torch.ops.aten.mul.Tensor(mul_8, wait_tensor_10) + convert_element_type_36 = torch.ops.prims.convert_element_type.default(mul_9, torch.bfloat16); mul_9 = None + view_37 = torch.ops.aten.view.default(convert_element_type_36, [16384, 4096]); convert_element_type_36 = None + view_38 = torch.ops.aten.view.default(mm_7, [2, 8192, 4096]); mm_7 = None + convert_element_type_40 = torch.ops.prims.convert_element_type.default(primals_15, torch.bfloat16); primals_15 = None + all_gather_into_tensor_12 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_40, 256, '0'); convert_element_type_40 = None + wait_tensor_12 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_12); all_gather_into_tensor_12 = None + permute_12 = torch.ops.aten.permute.default(wait_tensor_12, [1, 0]); wait_tensor_12 = None + mm_8 = torch.ops.aten.mm.default(view_37, permute_12) + view_41 = torch.ops.aten.view.default(mm_8, [2, 8192, 1024]); mm_8 = None + view_44 = torch.ops.aten.view.default(mm_9, [2, 8192, 1024]); mm_9 = None + view_45 = torch.ops.aten.view.default(view_38, [2, 8192, -1, 128]); view_38 = None + view_46 = torch.ops.aten.view.default(view_41, [2, 8192, -1, 128]); view_41 = None + view_47 = torch.ops.aten.view.default(view_44, [2, 8192, -1, 128]); view_44 = None + convert_element_type_46 = torch.ops.prims.convert_element_type.default(view_45, torch.float32); view_45 = None + view_48 = torch.ops.aten.view.default(convert_element_type_46, [2, 8192, 32, -1, 2]); convert_element_type_46 = None + view_as_complex_2 = torch.ops.aten.view_as_complex.default(view_48); view_48 = None + convert_element_type_47 = 
torch.ops.prims.convert_element_type.default(view_46, torch.float32); view_46 = None + view_49 = torch.ops.aten.view.default(convert_element_type_47, [2, 8192, 8, -1, 2]); convert_element_type_47 = None + view_as_complex_3 = torch.ops.aten.view_as_complex.default(view_49); view_49 = None + mul_10 = torch.ops.aten.mul.Tensor(view_as_complex_2, view_16); view_as_complex_2 = None + view_as_real_2 = torch.ops.aten.view_as_real.default(mul_10); mul_10 = None + view_51 = torch.ops.aten.view.default(view_as_real_2, [2, 8192, 32, 128]); view_as_real_2 = None + mul_11 = torch.ops.aten.mul.Tensor(view_as_complex_3, view_16); view_as_complex_3 = None + view_as_real_3 = torch.ops.aten.view_as_real.default(mul_11); mul_11 = None + view_52 = torch.ops.aten.view.default(view_as_real_3, [2, 8192, 8, 128]); view_as_real_3 = None + convert_element_type_48 = torch.ops.prims.convert_element_type.default(view_51, torch.bfloat16); view_51 = None + convert_element_type_49 = torch.ops.prims.convert_element_type.default(view_52, torch.bfloat16); view_52 = None + unsqueeze_2 = torch.ops.aten.unsqueeze.default(convert_element_type_49, 3); convert_element_type_49 = None + expand_2 = torch.ops.aten.expand.default(unsqueeze_2, [2, 8192, 8, 4, 128]); unsqueeze_2 = None + clone_2 = torch.ops.aten.clone.default(expand_2, memory_format = torch.contiguous_format); expand_2 = None + view_53 = torch.ops.aten.view.default(clone_2, [2, 8192, 32, 128]); clone_2 = None + unsqueeze_3 = torch.ops.aten.unsqueeze.default(view_47, 3); view_47 = None + expand_3 = torch.ops.aten.expand.default(unsqueeze_3, [2, 8192, 8, 4, 128]); unsqueeze_3 = None + clone_3 = torch.ops.aten.clone.default(expand_3, memory_format = torch.contiguous_format); expand_3 = None + view_54 = torch.ops.aten.view.default(clone_3, [2, 8192, 32, 128]); clone_3 = None + permute_14 = torch.ops.aten.permute.default(convert_element_type_48, [0, 2, 1, 3]); convert_element_type_48 = None + permute_15 = torch.ops.aten.permute.default(view_53, [0, 
2, 1, 3]); view_53 = None + permute_16 = torch.ops.aten.permute.default(view_54, [0, 2, 1, 3]); view_54 = None + _scaled_dot_product_cudnn_attention_backward_30 = torch.ops.aten._scaled_dot_product_cudnn_attention_backward.default(permute_1333, permute_14, permute_15, permute_16, getitem_9, getitem_10, getitem_15, getitem_16, None, None, None, 8192, 8192, 0.0, True); permute_1333 = permute_14 = permute_15 = permute_16 = getitem_9 = getitem_10 = getitem_15 = getitem_16 = None + getitem_378 = _scaled_dot_product_cudnn_attention_backward_30[0] + getitem_379 = _scaled_dot_product_cudnn_attention_backward_30[1] + getitem_380 = _scaled_dot_product_cudnn_attention_backward_30[2]; _scaled_dot_product_cudnn_attention_backward_30 = None + permute_1334 = torch.ops.aten.permute.default(getitem_380, [0, 2, 1, 3]); getitem_380 = None + permute_1335 = torch.ops.aten.permute.default(getitem_379, [0, 2, 1, 3]); getitem_379 = None + permute_1336 = torch.ops.aten.permute.default(getitem_378, [0, 2, 1, 3]); getitem_378 = None + view_1824 = torch.ops.aten.view.default(permute_1334, [2, 8192, 8, 4, 128]); permute_1334 = None + sum_185 = torch.ops.aten.sum.dim_IntList(view_1824, [3], True); view_1824 = None + squeeze_60 = torch.ops.aten.squeeze.dim(sum_185, 3); sum_185 = None + view_1825 = torch.ops.aten.view.default(permute_1335, [2, 8192, 8, 4, 128]); permute_1335 = None + sum_186 = torch.ops.aten.sum.dim_IntList(view_1825, [3], True); view_1825 = None + squeeze_61 = torch.ops.aten.squeeze.dim(sum_186, 3); sum_186 = None + convert_element_type_2723 = torch.ops.prims.convert_element_type.default(squeeze_61, torch.float32); squeeze_61 = None + convert_element_type_2724 = torch.ops.prims.convert_element_type.default(permute_1336, torch.float32); permute_1336 = None + view_1826 = torch.ops.aten.view.default(convert_element_type_2723, [2, 8192, 8, 64, 2]); convert_element_type_2723 = None + view_as_complex_124 = torch.ops.aten.view_as_complex.default(view_1826); view_1826 = None + mul_876 = 
torch.ops.aten.mul.Tensor(view_as_complex_124, _conj); view_as_complex_124 = None + view_1827 = torch.ops.aten.view.default(convert_element_type_2724, [2, 8192, 32, 64, 2]); convert_element_type_2724 = None + view_as_complex_125 = torch.ops.aten.view_as_complex.default(view_1827); view_1827 = None + mul_877 = torch.ops.aten.mul.Tensor(view_as_complex_125, _conj); view_as_complex_125 = None + view_as_real_124 = torch.ops.aten.view_as_real.default(mul_876); mul_876 = None + view_1828 = torch.ops.aten.view.default(view_as_real_124, [2, 8192, 8, 128]); view_as_real_124 = None + convert_element_type_2725 = torch.ops.prims.convert_element_type.default(view_1828, torch.bfloat16); view_1828 = None + view_as_real_125 = torch.ops.aten.view_as_real.default(mul_877); mul_877 = None + view_1829 = torch.ops.aten.view.default(view_as_real_125, [2, 8192, 32, 128]); view_as_real_125 = None + convert_element_type_2726 = torch.ops.prims.convert_element_type.default(view_1829, torch.bfloat16); view_1829 = None + view_1830 = torch.ops.aten.view.default(squeeze_60, [2, 8192, 1024]); squeeze_60 = None + view_1831 = torch.ops.aten.view.default(convert_element_type_2725, [2, 8192, 1024]); convert_element_type_2725 = None + view_1832 = torch.ops.aten.view.default(convert_element_type_2726, [2, 8192, 4096]); convert_element_type_2726 = None + view_1833 = torch.ops.aten.view.default(view_1830, [16384, 1024]); view_1830 = None + permute_1337 = torch.ops.aten.permute.default(view_1833, [1, 0]) + mm_655 = torch.ops.aten.mm.default(permute_1337, view_37); permute_1337 = None + convert_element_type_43 = torch.ops.prims.convert_element_type.default(primals_16, torch.bfloat16); primals_16 = None + all_gather_into_tensor_13 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_43, 256, '0'); convert_element_type_43 = None + wait_tensor_13 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_13); all_gather_into_tensor_13 = None + permute_13 = 
torch.ops.aten.permute.default(wait_tensor_13, [1, 0]); wait_tensor_13 = None + permute_1339 = torch.ops.aten.permute.default(permute_13, [1, 0]); permute_13 = None + mm_656 = torch.ops.aten.mm.default(view_1833, permute_1339); view_1833 = permute_1339 = None + view_1834 = torch.ops.aten.view.default(mm_656, [2, 8192, 4096]); mm_656 = None + convert_element_type_2731 = torch.ops.prims.convert_element_type.default(mm_655, torch.float32); mm_655 = None + reduce_scatter_tensor_277 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2731, 'avg', 256, '0'); convert_element_type_2731 = None + wait_tensor_568 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_277); reduce_scatter_tensor_277 = None + view_1835 = torch.ops.aten.view.default(view_1831, [16384, 1024]); view_1831 = None + permute_1341 = torch.ops.aten.permute.default(view_1835, [1, 0]) + mm_657 = torch.ops.aten.mm.default(permute_1341, view_37); permute_1341 = None + permute_1343 = torch.ops.aten.permute.default(permute_12, [1, 0]); permute_12 = None + mm_658 = torch.ops.aten.mm.default(view_1835, permute_1343); view_1835 = permute_1343 = None + view_1836 = torch.ops.aten.view.default(mm_658, [2, 8192, 4096]); mm_658 = None + add_343 = torch.ops.aten.add.Tensor(view_1834, view_1836); view_1834 = view_1836 = None + convert_element_type_2736 = torch.ops.prims.convert_element_type.default(mm_657, torch.float32); mm_657 = None + reduce_scatter_tensor_278 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2736, 'avg', 256, '0'); convert_element_type_2736 = None + wait_tensor_569 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_278); reduce_scatter_tensor_278 = None + view_1837 = torch.ops.aten.view.default(view_1832, [16384, 4096]); view_1832 = None + permute_1345 = torch.ops.aten.permute.default(view_1837, [1, 0]) + mm_659 = torch.ops.aten.mm.default(permute_1345, view_37); permute_1345 = view_37 = None + 
convert_element_type_37 = torch.ops.prims.convert_element_type.default(primals_14, torch.bfloat16); primals_14 = None + all_gather_into_tensor_11 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_37, 256, '0'); convert_element_type_37 = None + wait_tensor_11 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_11); all_gather_into_tensor_11 = None + permute_11 = torch.ops.aten.permute.default(wait_tensor_11, [1, 0]); wait_tensor_11 = None + permute_1347 = torch.ops.aten.permute.default(permute_11, [1, 0]); permute_11 = None + mm_660 = torch.ops.aten.mm.default(view_1837, permute_1347); view_1837 = permute_1347 = None + view_1838 = torch.ops.aten.view.default(mm_660, [2, 8192, 4096]); mm_660 = None + add_344 = torch.ops.aten.add.Tensor(add_343, view_1838); add_343 = view_1838 = None + convert_element_type_2741 = torch.ops.prims.convert_element_type.default(mm_659, torch.float32); mm_659 = None + reduce_scatter_tensor_279 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2741, 'avg', 256, '0'); convert_element_type_2741 = None + wait_tensor_570 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_279); reduce_scatter_tensor_279 = None + convert_element_type_2742 = torch.ops.prims.convert_element_type.default(add_344, torch.float32); add_344 = None + convert_element_type_2744 = torch.ops.prims.convert_element_type.default(wait_tensor_10, torch.float32); wait_tensor_10 = None + mul_878 = torch.ops.aten.mul.Tensor(convert_element_type_2742, convert_element_type_2744); convert_element_type_2744 = None + mul_880 = torch.ops.aten.mul.Tensor(mul_8, mul_878) + sum_187 = torch.ops.aten.sum.dim_IntList(mul_880, [2], True); mul_880 = None + div_62 = torch.ops.aten.div.Tensor(mul_8, 4096) + mul_881 = torch.ops.aten.mul.Tensor(div_62, sum_187); div_62 = sum_187 = None + sub_93 = torch.ops.aten.sub.Tensor(mul_878, mul_881); mul_878 = mul_881 = None + mul_882 = 
torch.ops.aten.mul.Tensor(sub_93, rsqrt_2); sub_93 = rsqrt_2 = None + mul_883 = torch.ops.aten.mul.Tensor(convert_element_type_2742, mul_8); convert_element_type_2742 = mul_8 = None + sum_188 = torch.ops.aten.sum.dim_IntList(mul_883, [0, 1]); mul_883 = None + convert_element_type_2745 = torch.ops.prims.convert_element_type.default(mul_882, torch.bfloat16); mul_882 = None + add_345 = torch.ops.aten.add.Tensor(add_342, convert_element_type_2745); add_342 = convert_element_type_2745 = None + convert_element_type_default_3 = torch.ops.prims.convert_element_type.default(sum_188, torch.float32); sum_188 = None + reduce_scatter_tensor_280 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_3, 'avg', 256, '0'); convert_element_type_default_3 = None + wait_tensor_571 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_280); reduce_scatter_tensor_280 = None + view_1839 = torch.ops.aten.view.default(add_345, [16384, 4096]) + permute_1349 = torch.ops.aten.permute.default(view_1839, [1, 0]) + permute_6 = torch.ops.aten.permute.default(getitem, [0, 2, 1, 3]) + view_21 = torch.ops.aten.view.default(permute_6, [2, 8192, -1]); permute_6 = None + convert_element_type_17 = torch.ops.prims.convert_element_type.default(primals_8, torch.bfloat16); primals_8 = None + all_gather_into_tensor_5 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_17, 256, '0'); convert_element_type_17 = None + wait_tensor_5 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_5); all_gather_into_tensor_5 = None + permute_7 = torch.ops.aten.permute.default(wait_tensor_5, [1, 0]); wait_tensor_5 = None + view_23 = torch.ops.aten.view.default(view_21, [16384, 4096]); view_21 = None + mm_3 = torch.ops.aten.mm.default(view_23, permute_7) + view_24 = torch.ops.aten.view.default(mm_3, [2, 8192, 4096]); mm_3 = None + add_1 = torch.ops.aten.add.Tensor(embedding, view_24); view_24 = None + convert_element_type_20 = 
torch.ops.prims.convert_element_type.default(primals_9, torch.bfloat16); primals_9 = None + all_gather_into_tensor_6 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_20, 256, '0'); convert_element_type_20 = None + wait_tensor_6 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_6); all_gather_into_tensor_6 = None + convert_element_type_21 = torch.ops.prims.convert_element_type.default(add_1, torch.float32); add_1 = None + pow_2 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_21, 2) + mean_1 = torch.ops.aten.mean.dim(pow_2, [2], True); pow_2 = None + add_2 = torch.ops.aten.add.Scalar(mean_1, 1e-05); mean_1 = None + rsqrt_1 = torch.ops.aten.rsqrt.default(add_2); add_2 = None + mul_4 = torch.ops.aten.mul.Tensor(convert_element_type_21, rsqrt_1); convert_element_type_21 = None + mul_5 = torch.ops.aten.mul.Tensor(mul_4, wait_tensor_6) + convert_element_type_22 = torch.ops.prims.convert_element_type.default(mul_5, torch.bfloat16); mul_5 = None + view_27 = torch.ops.aten.view.default(convert_element_type_22, [16384, 4096]); convert_element_type_22 = None + view_28 = torch.ops.aten.view.default(mm_4, [2, 8192, 14336]); mm_4 = None + convert_element_type_26 = torch.ops.prims.convert_element_type.default(view_28, torch.float32); view_28 = None + sigmoid = torch.ops.aten.sigmoid.default(convert_element_type_26) + mul_6 = torch.ops.aten.mul.Tensor(convert_element_type_26, sigmoid); sigmoid = None + convert_element_type_27 = torch.ops.prims.convert_element_type.default(mul_6, torch.bfloat16); mul_6 = None + convert_element_type_28 = torch.ops.prims.convert_element_type.default(primals_11, torch.bfloat16); primals_11 = None + all_gather_into_tensor_8 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_28, 256, '0'); convert_element_type_28 = None + wait_tensor_8 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_8); all_gather_into_tensor_8 = None + permute_9 = 
torch.ops.aten.permute.default(wait_tensor_8, [1, 0]); wait_tensor_8 = None + mm_5 = torch.ops.aten.mm.default(view_27, permute_9) + view_31 = torch.ops.aten.view.default(mm_5, [2, 8192, 14336]); mm_5 = None + mul_7 = torch.ops.aten.mul.Tensor(convert_element_type_27, view_31) + view_33 = torch.ops.aten.view.default(mul_7, [16384, 14336]); mul_7 = None + mm_661 = torch.ops.aten.mm.default(permute_1349, view_33); permute_1349 = view_33 = None + convert_element_type_31 = torch.ops.prims.convert_element_type.default(primals_12, torch.bfloat16); primals_12 = None + all_gather_into_tensor_9 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_31, 256, '0'); convert_element_type_31 = None + wait_tensor_9 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_9); all_gather_into_tensor_9 = None + permute_10 = torch.ops.aten.permute.default(wait_tensor_9, [1, 0]); wait_tensor_9 = None + permute_1351 = torch.ops.aten.permute.default(permute_10, [1, 0]); permute_10 = None + mm_662 = torch.ops.aten.mm.default(view_1839, permute_1351); view_1839 = permute_1351 = None + view_1840 = torch.ops.aten.view.default(mm_662, [2, 8192, 14336]); mm_662 = None + convert_element_type_2752 = torch.ops.prims.convert_element_type.default(mm_661, torch.float32); mm_661 = None + reduce_scatter_tensor_281 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2752, 'avg', 256, '0'); convert_element_type_2752 = None + wait_tensor_572 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_281); reduce_scatter_tensor_281 = None + mul_884 = torch.ops.aten.mul.Tensor(view_1840, convert_element_type_27); convert_element_type_27 = None + mul_885 = torch.ops.aten.mul.Tensor(view_1840, view_31); view_1840 = view_31 = None + view_1841 = torch.ops.aten.view.default(mul_884, [16384, 14336]); mul_884 = None + permute_1353 = torch.ops.aten.permute.default(view_1841, [1, 0]) + mm_663 = torch.ops.aten.mm.default(permute_1353, 
view_27); permute_1353 = None + permute_1355 = torch.ops.aten.permute.default(permute_9, [1, 0]); permute_9 = None + mm_664 = torch.ops.aten.mm.default(view_1841, permute_1355); view_1841 = permute_1355 = None + view_1842 = torch.ops.aten.view.default(mm_664, [2, 8192, 4096]); mm_664 = None + convert_element_type_2757 = torch.ops.prims.convert_element_type.default(mm_663, torch.float32); mm_663 = None + reduce_scatter_tensor_282 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2757, 'avg', 256, '0'); convert_element_type_2757 = None + wait_tensor_573 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_282); reduce_scatter_tensor_282 = None + convert_element_type_2758 = torch.ops.prims.convert_element_type.default(mul_885, torch.float32); mul_885 = None + neg_31 = torch.ops.aten.neg.default(convert_element_type_26) + exp_31 = torch.ops.aten.exp.default(neg_31); neg_31 = None + add_346 = torch.ops.aten.add.Tensor(exp_31, 1); exp_31 = None + reciprocal_31 = torch.ops.aten.reciprocal.default(add_346); add_346 = None + mul_886 = torch.ops.aten.mul.Tensor(reciprocal_31, 1); reciprocal_31 = None + mul_887 = torch.ops.aten.mul.Tensor(convert_element_type_2758, mul_886); convert_element_type_2758 = None + sub_94 = torch.ops.aten.sub.Tensor(1, mul_886); mul_886 = None + mul_888 = torch.ops.aten.mul.Tensor(convert_element_type_26, sub_94); convert_element_type_26 = sub_94 = None + add_347 = torch.ops.aten.add.Tensor(mul_888, 1); mul_888 = None + mul_889 = torch.ops.aten.mul.Tensor(mul_887, add_347); mul_887 = add_347 = None + convert_element_type_2760 = torch.ops.prims.convert_element_type.default(mul_889, torch.bfloat16); mul_889 = None + view_1843 = torch.ops.aten.view.default(convert_element_type_2760, [16384, 14336]); convert_element_type_2760 = None + permute_1357 = torch.ops.aten.permute.default(view_1843, [1, 0]) + mm_665 = torch.ops.aten.mm.default(permute_1357, view_27); permute_1357 = view_27 = None + 
convert_element_type_23 = torch.ops.prims.convert_element_type.default(primals_10, torch.bfloat16); primals_10 = None + all_gather_into_tensor_7 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_23, 256, '0'); convert_element_type_23 = None + wait_tensor_7 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_7); all_gather_into_tensor_7 = None + permute_8 = torch.ops.aten.permute.default(wait_tensor_7, [1, 0]); wait_tensor_7 = None + permute_1359 = torch.ops.aten.permute.default(permute_8, [1, 0]); permute_8 = None + mm_666 = torch.ops.aten.mm.default(view_1843, permute_1359); view_1843 = permute_1359 = None + view_1844 = torch.ops.aten.view.default(mm_666, [2, 8192, 4096]); mm_666 = None + add_348 = torch.ops.aten.add.Tensor(view_1842, view_1844); view_1842 = view_1844 = None + convert_element_type_2765 = torch.ops.prims.convert_element_type.default(mm_665, torch.float32); mm_665 = None + reduce_scatter_tensor_283 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2765, 'avg', 256, '0'); convert_element_type_2765 = None + wait_tensor_574 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_283); reduce_scatter_tensor_283 = None + convert_element_type_2766 = torch.ops.prims.convert_element_type.default(add_348, torch.float32); add_348 = None + convert_element_type_2768 = torch.ops.prims.convert_element_type.default(wait_tensor_6, torch.float32); wait_tensor_6 = None + mul_890 = torch.ops.aten.mul.Tensor(convert_element_type_2766, convert_element_type_2768); convert_element_type_2768 = None + mul_892 = torch.ops.aten.mul.Tensor(mul_4, mul_890) + sum_189 = torch.ops.aten.sum.dim_IntList(mul_892, [2], True); mul_892 = None + div_63 = torch.ops.aten.div.Tensor(mul_4, 4096) + mul_893 = torch.ops.aten.mul.Tensor(div_63, sum_189); div_63 = sum_189 = None + sub_95 = torch.ops.aten.sub.Tensor(mul_890, mul_893); mul_890 = mul_893 = None + mul_894 = 
torch.ops.aten.mul.Tensor(sub_95, rsqrt_1); sub_95 = rsqrt_1 = None + mul_895 = torch.ops.aten.mul.Tensor(convert_element_type_2766, mul_4); convert_element_type_2766 = mul_4 = None + sum_190 = torch.ops.aten.sum.dim_IntList(mul_895, [0, 1]); mul_895 = None + convert_element_type_2769 = torch.ops.prims.convert_element_type.default(mul_894, torch.bfloat16); mul_894 = None + add_349 = torch.ops.aten.add.Tensor(add_345, convert_element_type_2769); add_345 = convert_element_type_2769 = None + convert_element_type_default_2 = torch.ops.prims.convert_element_type.default(sum_190, torch.float32); sum_190 = None + reduce_scatter_tensor_284 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_2, 'avg', 256, '0'); convert_element_type_default_2 = None + wait_tensor_575 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_284); reduce_scatter_tensor_284 = None + view_1845 = torch.ops.aten.view.default(add_349, [16384, 4096]) + permute_1361 = torch.ops.aten.permute.default(view_1845, [1, 0]) + mm_667 = torch.ops.aten.mm.default(permute_1361, view_23); permute_1361 = view_23 = None + permute_1363 = torch.ops.aten.permute.default(permute_7, [1, 0]); permute_7 = None + mm_668 = torch.ops.aten.mm.default(view_1845, permute_1363); view_1845 = permute_1363 = None + view_1846 = torch.ops.aten.view.default(mm_668, [2, 8192, 4096]); mm_668 = None + convert_element_type_2776 = torch.ops.prims.convert_element_type.default(mm_667, torch.float32); mm_667 = None + reduce_scatter_tensor_285 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2776, 'avg', 256, '0'); convert_element_type_2776 = None + wait_tensor_576 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_285); reduce_scatter_tensor_285 = None + view_1847 = torch.ops.aten.view.default(view_1846, [2, 8192, 32, 128]); view_1846 = None + permute_1365 = torch.ops.aten.permute.default(view_1847, [0, 2, 1, 3]); view_1847 = None + 
convert_element_type_1 = torch.ops.prims.convert_element_type.default(primals_4, torch.bfloat16); primals_4 = None + all_gather_into_tensor_1 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1, 256, '0'); convert_element_type_1 = None + wait_tensor_1 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_1); all_gather_into_tensor_1 = None + convert_element_type_2 = torch.ops.prims.convert_element_type.default(embedding, torch.float32); embedding = None + pow_1 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_2, 2) + mean = torch.ops.aten.mean.dim(pow_1, [2], True); pow_1 = None + add = torch.ops.aten.add.Scalar(mean, 1e-05); mean = None + rsqrt = torch.ops.aten.rsqrt.default(add); add = None + mul = torch.ops.aten.mul.Tensor(convert_element_type_2, rsqrt); convert_element_type_2 = None + mul_1 = torch.ops.aten.mul.Tensor(mul, wait_tensor_1) + convert_element_type_3 = torch.ops.prims.convert_element_type.default(mul_1, torch.bfloat16); mul_1 = None + view_3 = torch.ops.aten.view.default(convert_element_type_3, [16384, 4096]); convert_element_type_3 = None + view_4 = torch.ops.aten.view.default(mm, [2, 8192, 4096]); mm = None + convert_element_type_7 = torch.ops.prims.convert_element_type.default(primals_6, torch.bfloat16); primals_6 = None + all_gather_into_tensor_3 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_7, 256, '0'); convert_element_type_7 = None + wait_tensor_3 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_3); all_gather_into_tensor_3 = None + permute_1 = torch.ops.aten.permute.default(wait_tensor_3, [1, 0]); wait_tensor_3 = None + mm_1 = torch.ops.aten.mm.default(view_3, permute_1) + view_7 = torch.ops.aten.view.default(mm_1, [2, 8192, 1024]); mm_1 = None + view_10 = torch.ops.aten.view.default(mm_2, [2, 8192, 1024]); mm_2 = None + view_11 = torch.ops.aten.view.default(view_4, [2, 8192, -1, 128]); view_4 = None + view_12 = 
torch.ops.aten.view.default(view_7, [2, 8192, -1, 128]); view_7 = None + view_13 = torch.ops.aten.view.default(view_10, [2, 8192, -1, 128]); view_10 = None + convert_element_type_13 = torch.ops.prims.convert_element_type.default(view_11, torch.float32); view_11 = None + view_14 = torch.ops.aten.view.default(convert_element_type_13, [2, 8192, 32, -1, 2]); convert_element_type_13 = None + view_as_complex = torch.ops.aten.view_as_complex.default(view_14); view_14 = None + convert_element_type_14 = torch.ops.prims.convert_element_type.default(view_12, torch.float32); view_12 = None + view_15 = torch.ops.aten.view.default(convert_element_type_14, [2, 8192, 8, -1, 2]); convert_element_type_14 = None + view_as_complex_1 = torch.ops.aten.view_as_complex.default(view_15); view_15 = None + mul_2 = torch.ops.aten.mul.Tensor(view_as_complex, view_16); view_as_complex = None + view_as_real = torch.ops.aten.view_as_real.default(mul_2); mul_2 = None + view_17 = torch.ops.aten.view.default(view_as_real, [2, 8192, 32, 128]); view_as_real = None + mul_3 = torch.ops.aten.mul.Tensor(view_as_complex_1, view_16); view_as_complex_1 = view_16 = None + view_as_real_1 = torch.ops.aten.view_as_real.default(mul_3); mul_3 = None + view_18 = torch.ops.aten.view.default(view_as_real_1, [2, 8192, 8, 128]); view_as_real_1 = None + convert_element_type_15 = torch.ops.prims.convert_element_type.default(view_17, torch.bfloat16); view_17 = None + convert_element_type_16 = torch.ops.prims.convert_element_type.default(view_18, torch.bfloat16); view_18 = None + unsqueeze = torch.ops.aten.unsqueeze.default(convert_element_type_16, 3); convert_element_type_16 = None + expand = torch.ops.aten.expand.default(unsqueeze, [2, 8192, 8, 4, 128]); unsqueeze = None + clone = torch.ops.aten.clone.default(expand, memory_format = torch.contiguous_format); expand = None + view_19 = torch.ops.aten.view.default(clone, [2, 8192, 32, 128]); clone = None + unsqueeze_1 = torch.ops.aten.unsqueeze.default(view_13, 3); view_13 
= None + expand_1 = torch.ops.aten.expand.default(unsqueeze_1, [2, 8192, 8, 4, 128]); unsqueeze_1 = None + clone_1 = torch.ops.aten.clone.default(expand_1, memory_format = torch.contiguous_format); expand_1 = None + view_20 = torch.ops.aten.view.default(clone_1, [2, 8192, 32, 128]); clone_1 = None + permute_3 = torch.ops.aten.permute.default(convert_element_type_15, [0, 2, 1, 3]); convert_element_type_15 = None + permute_4 = torch.ops.aten.permute.default(view_19, [0, 2, 1, 3]); view_19 = None + permute_5 = torch.ops.aten.permute.default(view_20, [0, 2, 1, 3]); view_20 = None + _scaled_dot_product_cudnn_attention_backward_31 = torch.ops.aten._scaled_dot_product_cudnn_attention_backward.default(permute_1365, permute_3, permute_4, permute_5, getitem, getitem_1, getitem_6, getitem_7, None, None, None, 8192, 8192, 0.0, True); permute_1365 = permute_3 = permute_4 = permute_5 = getitem = getitem_1 = getitem_6 = getitem_7 = None + getitem_381 = _scaled_dot_product_cudnn_attention_backward_31[0] + getitem_382 = _scaled_dot_product_cudnn_attention_backward_31[1] + getitem_383 = _scaled_dot_product_cudnn_attention_backward_31[2]; _scaled_dot_product_cudnn_attention_backward_31 = None + permute_1366 = torch.ops.aten.permute.default(getitem_383, [0, 2, 1, 3]); getitem_383 = None + permute_1367 = torch.ops.aten.permute.default(getitem_382, [0, 2, 1, 3]); getitem_382 = None + permute_1368 = torch.ops.aten.permute.default(getitem_381, [0, 2, 1, 3]); getitem_381 = None + view_1848 = torch.ops.aten.view.default(permute_1366, [2, 8192, 8, 4, 128]); permute_1366 = None + sum_191 = torch.ops.aten.sum.dim_IntList(view_1848, [3], True); view_1848 = None + squeeze_62 = torch.ops.aten.squeeze.dim(sum_191, 3); sum_191 = None + view_1849 = torch.ops.aten.view.default(permute_1367, [2, 8192, 8, 4, 128]); permute_1367 = None + sum_192 = torch.ops.aten.sum.dim_IntList(view_1849, [3], True); view_1849 = None + squeeze_63 = torch.ops.aten.squeeze.dim(sum_192, 3); sum_192 = None + 
convert_element_type_2777 = torch.ops.prims.convert_element_type.default(squeeze_63, torch.float32); squeeze_63 = None + convert_element_type_2778 = torch.ops.prims.convert_element_type.default(permute_1368, torch.float32); permute_1368 = None + view_1850 = torch.ops.aten.view.default(convert_element_type_2777, [2, 8192, 8, 64, 2]); convert_element_type_2777 = None + view_as_complex_126 = torch.ops.aten.view_as_complex.default(view_1850); view_1850 = None + mul_896 = torch.ops.aten.mul.Tensor(view_as_complex_126, _conj); view_as_complex_126 = None + view_1851 = torch.ops.aten.view.default(convert_element_type_2778, [2, 8192, 32, 64, 2]); convert_element_type_2778 = None + view_as_complex_127 = torch.ops.aten.view_as_complex.default(view_1851); view_1851 = None + mul_897 = torch.ops.aten.mul.Tensor(view_as_complex_127, _conj); view_as_complex_127 = _conj = None + view_as_real_126 = torch.ops.aten.view_as_real.default(mul_896); mul_896 = None + view_1852 = torch.ops.aten.view.default(view_as_real_126, [2, 8192, 8, 128]); view_as_real_126 = None + convert_element_type_2779 = torch.ops.prims.convert_element_type.default(view_1852, torch.bfloat16); view_1852 = None + view_as_real_127 = torch.ops.aten.view_as_real.default(mul_897); mul_897 = None + view_1853 = torch.ops.aten.view.default(view_as_real_127, [2, 8192, 32, 128]); view_as_real_127 = None + convert_element_type_2780 = torch.ops.prims.convert_element_type.default(view_1853, torch.bfloat16); view_1853 = None + view_1854 = torch.ops.aten.view.default(squeeze_62, [2, 8192, 1024]); squeeze_62 = None + view_1855 = torch.ops.aten.view.default(convert_element_type_2779, [2, 8192, 1024]); convert_element_type_2779 = None + view_1856 = torch.ops.aten.view.default(convert_element_type_2780, [2, 8192, 4096]); convert_element_type_2780 = None + view_1857 = torch.ops.aten.view.default(view_1854, [16384, 1024]); view_1854 = None + permute_1369 = torch.ops.aten.permute.default(view_1857, [1, 0]) + mm_669 = 
torch.ops.aten.mm.default(permute_1369, view_3); permute_1369 = None + convert_element_type_10 = torch.ops.prims.convert_element_type.default(primals_7, torch.bfloat16); primals_7 = None + all_gather_into_tensor_4 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_10, 256, '0'); convert_element_type_10 = None + wait_tensor_4 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_4); all_gather_into_tensor_4 = None + permute_2 = torch.ops.aten.permute.default(wait_tensor_4, [1, 0]); wait_tensor_4 = None + permute_1371 = torch.ops.aten.permute.default(permute_2, [1, 0]); permute_2 = None + mm_670 = torch.ops.aten.mm.default(view_1857, permute_1371); view_1857 = permute_1371 = None + view_1858 = torch.ops.aten.view.default(mm_670, [2, 8192, 4096]); mm_670 = None + convert_element_type_2785 = torch.ops.prims.convert_element_type.default(mm_669, torch.float32); mm_669 = None + reduce_scatter_tensor_286 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2785, 'avg', 256, '0'); convert_element_type_2785 = None + wait_tensor_577 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_286); reduce_scatter_tensor_286 = None + view_1859 = torch.ops.aten.view.default(view_1855, [16384, 1024]); view_1855 = None + permute_1373 = torch.ops.aten.permute.default(view_1859, [1, 0]) + mm_671 = torch.ops.aten.mm.default(permute_1373, view_3); permute_1373 = None + permute_1375 = torch.ops.aten.permute.default(permute_1, [1, 0]); permute_1 = None + mm_672 = torch.ops.aten.mm.default(view_1859, permute_1375); view_1859 = permute_1375 = None + view_1860 = torch.ops.aten.view.default(mm_672, [2, 8192, 4096]); mm_672 = None + add_350 = torch.ops.aten.add.Tensor(view_1858, view_1860); view_1858 = view_1860 = None + convert_element_type_2790 = torch.ops.prims.convert_element_type.default(mm_671, torch.float32); mm_671 = None + reduce_scatter_tensor_287 = 
torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2790, 'avg', 256, '0'); convert_element_type_2790 = None + wait_tensor_578 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_287); reduce_scatter_tensor_287 = None + view_1861 = torch.ops.aten.view.default(view_1856, [16384, 4096]); view_1856 = None + permute_1377 = torch.ops.aten.permute.default(view_1861, [1, 0]) + mm_673 = torch.ops.aten.mm.default(permute_1377, view_3); permute_1377 = view_3 = None + convert_element_type_4 = torch.ops.prims.convert_element_type.default(primals_5, torch.bfloat16); primals_5 = None + all_gather_into_tensor_2 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_4, 256, '0'); convert_element_type_4 = None + wait_tensor_2 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_2); all_gather_into_tensor_2 = None + permute = torch.ops.aten.permute.default(wait_tensor_2, [1, 0]); wait_tensor_2 = None + permute_1379 = torch.ops.aten.permute.default(permute, [1, 0]); permute = None + mm_674 = torch.ops.aten.mm.default(view_1861, permute_1379); view_1861 = permute_1379 = None + view_1862 = torch.ops.aten.view.default(mm_674, [2, 8192, 4096]); mm_674 = None + add_351 = torch.ops.aten.add.Tensor(add_350, view_1862); add_350 = view_1862 = None + convert_element_type_2795 = torch.ops.prims.convert_element_type.default(mm_673, torch.float32); mm_673 = None + reduce_scatter_tensor_288 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2795, 'avg', 256, '0'); convert_element_type_2795 = None + wait_tensor_579 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_288); reduce_scatter_tensor_288 = None + convert_element_type_2796 = torch.ops.prims.convert_element_type.default(add_351, torch.float32); add_351 = None + convert_element_type_2798 = torch.ops.prims.convert_element_type.default(wait_tensor_1, torch.float32); wait_tensor_1 = None + mul_898 = 
torch.ops.aten.mul.Tensor(convert_element_type_2796, convert_element_type_2798); convert_element_type_2798 = None + mul_900 = torch.ops.aten.mul.Tensor(mul, mul_898) + sum_193 = torch.ops.aten.sum.dim_IntList(mul_900, [2], True); mul_900 = None + div_64 = torch.ops.aten.div.Tensor(mul, 4096) + mul_901 = torch.ops.aten.mul.Tensor(div_64, sum_193); div_64 = sum_193 = None + sub_96 = torch.ops.aten.sub.Tensor(mul_898, mul_901); mul_898 = mul_901 = None + mul_902 = torch.ops.aten.mul.Tensor(sub_96, rsqrt); sub_96 = rsqrt = None + mul_903 = torch.ops.aten.mul.Tensor(convert_element_type_2796, mul); convert_element_type_2796 = mul = None + sum_194 = torch.ops.aten.sum.dim_IntList(mul_903, [0, 1]); mul_903 = None + convert_element_type_2799 = torch.ops.prims.convert_element_type.default(mul_902, torch.bfloat16); mul_902 = None + add_352 = torch.ops.aten.add.Tensor(add_349, convert_element_type_2799); add_349 = convert_element_type_2799 = None + convert_element_type_default_1 = torch.ops.prims.convert_element_type.default(sum_194, torch.float32); sum_194 = None + reduce_scatter_tensor_289 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_1, 'avg', 256, '0'); convert_element_type_default_1 = None + wait_tensor_580 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_289); reduce_scatter_tensor_289 = None + convert_element_type_2802 = torch.ops.prims.convert_element_type.default(add_352, torch.float32); add_352 = None + eq = torch.ops.aten.eq.Scalar(primals_2, -1) + unsqueeze_64 = torch.ops.aten.unsqueeze.default(eq, -1); eq = None + full_default = torch.ops.aten.full.default([], 0.0, dtype = torch.float32, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False) + where = torch.ops.aten.where.self(unsqueeze_64, full_default, convert_element_type_2802); unsqueeze_64 = full_default = convert_element_type_2802 = None + full_default_1 = torch.ops.aten.full.default([128256, 4096], 0, dtype = 
torch.float32, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False) + index_put = torch.ops.aten.index_put.default(full_default_1, [primals_2], where, True); full_default_1 = primals_2 = where = None + convert_element_type_default = torch.ops.prims.convert_element_type.default(index_put, torch.float32); index_put = None + reduce_scatter_tensor_290 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default, 'avg', 256, '0'); convert_element_type_default = None + wait_tensor_581 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_290); reduce_scatter_tensor_290 = None + return (wait_tensor_581, None, None, wait_tensor_580, wait_tensor_579, wait_tensor_578, wait_tensor_577, wait_tensor_576, wait_tensor_575, wait_tensor_574, wait_tensor_573, wait_tensor_572, wait_tensor_571, wait_tensor_570, wait_tensor_569, wait_tensor_568, wait_tensor_567, wait_tensor_566, wait_tensor_565, wait_tensor_564, wait_tensor_563, wait_tensor_562, wait_tensor_561, wait_tensor_560, wait_tensor_559, wait_tensor_558, wait_tensor_557, wait_tensor_556, wait_tensor_555, wait_tensor_554, wait_tensor_553, wait_tensor_552, wait_tensor_551, wait_tensor_550, wait_tensor_549, wait_tensor_548, wait_tensor_547, wait_tensor_546, wait_tensor_545, wait_tensor_544, wait_tensor_543, wait_tensor_542, wait_tensor_541, wait_tensor_540, wait_tensor_539, wait_tensor_538, wait_tensor_537, wait_tensor_536, wait_tensor_535, wait_tensor_534, wait_tensor_533, wait_tensor_532, wait_tensor_531, wait_tensor_530, wait_tensor_529, wait_tensor_528, wait_tensor_527, wait_tensor_526, wait_tensor_525, wait_tensor_524, wait_tensor_523, wait_tensor_522, wait_tensor_521, wait_tensor_520, wait_tensor_519, wait_tensor_518, wait_tensor_517, wait_tensor_516, wait_tensor_515, wait_tensor_514, wait_tensor_513, wait_tensor_512, wait_tensor_511, wait_tensor_510, wait_tensor_509, wait_tensor_508, wait_tensor_507, wait_tensor_506, wait_tensor_505, wait_tensor_504, 
wait_tensor_503, wait_tensor_502, wait_tensor_501, wait_tensor_500, wait_tensor_499, wait_tensor_498, wait_tensor_497, wait_tensor_496, wait_tensor_495, wait_tensor_494, wait_tensor_493, wait_tensor_492, wait_tensor_491, wait_tensor_490, wait_tensor_489, wait_tensor_488, wait_tensor_487, wait_tensor_486, wait_tensor_485, wait_tensor_484, wait_tensor_483, wait_tensor_482, wait_tensor_481, wait_tensor_480, wait_tensor_479, wait_tensor_478, wait_tensor_477, wait_tensor_476, wait_tensor_475, wait_tensor_474, wait_tensor_473, wait_tensor_472, wait_tensor_471, wait_tensor_470, wait_tensor_469, wait_tensor_468, wait_tensor_467, wait_tensor_466, wait_tensor_465, wait_tensor_464, wait_tensor_463, wait_tensor_462, wait_tensor_461, wait_tensor_460, wait_tensor_459, wait_tensor_458, wait_tensor_457, wait_tensor_456, wait_tensor_455, wait_tensor_454, wait_tensor_453, wait_tensor_452, wait_tensor_451, wait_tensor_450, wait_tensor_449, wait_tensor_448, wait_tensor_447, wait_tensor_446, wait_tensor_445, wait_tensor_444, wait_tensor_443, wait_tensor_442, wait_tensor_441, wait_tensor_440, wait_tensor_439, wait_tensor_438, wait_tensor_437, wait_tensor_436, wait_tensor_435, wait_tensor_434, wait_tensor_433, wait_tensor_432, wait_tensor_431, wait_tensor_430, wait_tensor_429, wait_tensor_428, wait_tensor_427, wait_tensor_426, wait_tensor_425, wait_tensor_424, wait_tensor_423, wait_tensor_422, wait_tensor_421, wait_tensor_420, wait_tensor_419, wait_tensor_418, wait_tensor_417, wait_tensor_416, wait_tensor_415, wait_tensor_414, wait_tensor_413, wait_tensor_412, wait_tensor_411, wait_tensor_410, wait_tensor_409, wait_tensor_408, wait_tensor_407, wait_tensor_406, wait_tensor_405, wait_tensor_404, wait_tensor_403, wait_tensor_402, wait_tensor_401, wait_tensor_400, wait_tensor_399, wait_tensor_398, wait_tensor_397, wait_tensor_396, wait_tensor_395, wait_tensor_394, wait_tensor_393, wait_tensor_392, wait_tensor_391, wait_tensor_390, wait_tensor_389, wait_tensor_388, wait_tensor_387, 
wait_tensor_386, wait_tensor_385, wait_tensor_384, wait_tensor_383, wait_tensor_382, wait_tensor_381, wait_tensor_380, wait_tensor_379, wait_tensor_378, wait_tensor_377, wait_tensor_376, wait_tensor_375, wait_tensor_374, wait_tensor_373, wait_tensor_372, wait_tensor_371, wait_tensor_370, wait_tensor_369, wait_tensor_368, wait_tensor_367, wait_tensor_366, wait_tensor_365, wait_tensor_364, wait_tensor_363, wait_tensor_362, wait_tensor_361, wait_tensor_360, wait_tensor_359, wait_tensor_358, wait_tensor_357, wait_tensor_356, wait_tensor_355, wait_tensor_354, wait_tensor_353, wait_tensor_352, wait_tensor_351, wait_tensor_350, wait_tensor_349, wait_tensor_348, wait_tensor_347, wait_tensor_346, wait_tensor_345, wait_tensor_344, wait_tensor_343, wait_tensor_342, wait_tensor_341, wait_tensor_340, wait_tensor_339, wait_tensor_338, wait_tensor_337, wait_tensor_336, wait_tensor_335, wait_tensor_334, wait_tensor_333, wait_tensor_332, wait_tensor_331, wait_tensor_330, wait_tensor_329, wait_tensor_328, wait_tensor_327, wait_tensor_326, wait_tensor_325, wait_tensor_324, wait_tensor_323, wait_tensor_322, wait_tensor_321, wait_tensor_320, wait_tensor_319, wait_tensor_318, wait_tensor_317, wait_tensor_316, wait_tensor_315, wait_tensor_314, wait_tensor_313, wait_tensor_312, wait_tensor_311, wait_tensor_310, wait_tensor_309, wait_tensor_308, wait_tensor_307, wait_tensor_306, wait_tensor_305, wait_tensor_304, wait_tensor_303, wait_tensor_302, wait_tensor_301, wait_tensor_300, wait_tensor_299, wait_tensor_298, wait_tensor_297, wait_tensor_296, wait_tensor_295, wait_tensor_294, wait_tensor_293, wait_tensor_292, wait_tensor_291) + +def load_args(reader): + buf0 = reader.storage(None, 8208384, device=device(type='cuda', index=0)) + reader.tensor(buf0, (501, 4096), is_leaf=True) # primals_1 + buf1 = reader.storage(None, 131072, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf1, (2, 8192), dtype=torch.int64, is_leaf=True) # primals_2 + buf2 = 
reader.storage(None, 4194304, device=device(type='cuda', index=0), dtype_hint=torch.complex64) + reader.tensor(buf2, (8192, 64), dtype=torch.complex64, is_leaf=True) # primals_3 + buf3 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf3, (16,), is_leaf=True) # primals_4 + buf4 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf4, (16, 4096), is_leaf=True) # primals_5 + buf5 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf5, (4, 4096), is_leaf=True) # primals_6 + buf6 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf6, (4, 4096), is_leaf=True) # primals_7 + buf7 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf7, (16, 4096), is_leaf=True) # primals_8 + buf8 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf8, (16,), is_leaf=True) # primals_9 + buf9 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf9, (56, 4096), is_leaf=True) # primals_10 + buf10 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf10, (56, 4096), is_leaf=True) # primals_11 + buf11 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf11, (16, 14336), is_leaf=True) # primals_12 + buf12 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf12, (16,), is_leaf=True) # primals_13 + buf13 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf13, (16, 4096), is_leaf=True) # primals_14 + buf14 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf14, (4, 4096), is_leaf=True) # primals_15 + buf15 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf15, (4, 4096), is_leaf=True) # primals_16 + buf16 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + 
reader.tensor(buf16, (16, 4096), is_leaf=True) # primals_17 + buf17 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf17, (16,), is_leaf=True) # primals_18 + buf18 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf18, (56, 4096), is_leaf=True) # primals_19 + buf19 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf19, (56, 4096), is_leaf=True) # primals_20 + buf20 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf20, (16, 14336), is_leaf=True) # primals_21 + buf21 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf21, (16,), is_leaf=True) # primals_22 + buf22 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf22, (16, 4096), is_leaf=True) # primals_23 + buf23 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf23, (4, 4096), is_leaf=True) # primals_24 + buf24 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf24, (4, 4096), is_leaf=True) # primals_25 + buf25 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf25, (16, 4096), is_leaf=True) # primals_26 + buf26 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf26, (16,), is_leaf=True) # primals_27 + buf27 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf27, (56, 4096), is_leaf=True) # primals_28 + buf28 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf28, (56, 4096), is_leaf=True) # primals_29 + buf29 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf29, (16, 14336), is_leaf=True) # primals_30 + buf30 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf30, (16,), is_leaf=True) # primals_31 + buf31 = reader.storage(None, 262144, 
device=device(type='cuda', index=0)) + reader.tensor(buf31, (16, 4096), is_leaf=True) # primals_32 + buf32 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf32, (4, 4096), is_leaf=True) # primals_33 + buf33 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf33, (4, 4096), is_leaf=True) # primals_34 + buf34 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf34, (16, 4096), is_leaf=True) # primals_35 + buf35 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf35, (16,), is_leaf=True) # primals_36 + buf36 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf36, (56, 4096), is_leaf=True) # primals_37 + buf37 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf37, (56, 4096), is_leaf=True) # primals_38 + buf38 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf38, (16, 14336), is_leaf=True) # primals_39 + buf39 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf39, (16,), is_leaf=True) # primals_40 + buf40 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf40, (16, 4096), is_leaf=True) # primals_41 + buf41 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf41, (4, 4096), is_leaf=True) # primals_42 + buf42 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf42, (4, 4096), is_leaf=True) # primals_43 + buf43 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf43, (16, 4096), is_leaf=True) # primals_44 + buf44 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf44, (16,), is_leaf=True) # primals_45 + buf45 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf45, (56, 4096), is_leaf=True) # primals_46 + 
buf46 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf46, (56, 4096), is_leaf=True) # primals_47 + buf47 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf47, (16, 14336), is_leaf=True) # primals_48 + buf48 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf48, (16,), is_leaf=True) # primals_49 + buf49 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf49, (16, 4096), is_leaf=True) # primals_50 + buf50 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf50, (4, 4096), is_leaf=True) # primals_51 + buf51 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf51, (4, 4096), is_leaf=True) # primals_52 + buf52 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf52, (16, 4096), is_leaf=True) # primals_53 + buf53 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf53, (16,), is_leaf=True) # primals_54 + buf54 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf54, (56, 4096), is_leaf=True) # primals_55 + buf55 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf55, (56, 4096), is_leaf=True) # primals_56 + buf56 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf56, (16, 14336), is_leaf=True) # primals_57 + buf57 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf57, (16,), is_leaf=True) # primals_58 + buf58 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf58, (16, 4096), is_leaf=True) # primals_59 + buf59 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf59, (4, 4096), is_leaf=True) # primals_60 + buf60 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf60, 
(4, 4096), is_leaf=True) # primals_61 + buf61 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf61, (16, 4096), is_leaf=True) # primals_62 + buf62 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf62, (16,), is_leaf=True) # primals_63 + buf63 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf63, (56, 4096), is_leaf=True) # primals_64 + buf64 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf64, (56, 4096), is_leaf=True) # primals_65 + buf65 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf65, (16, 14336), is_leaf=True) # primals_66 + buf66 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf66, (16,), is_leaf=True) # primals_67 + buf67 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf67, (16, 4096), is_leaf=True) # primals_68 + buf68 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf68, (4, 4096), is_leaf=True) # primals_69 + buf69 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf69, (4, 4096), is_leaf=True) # primals_70 + buf70 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf70, (16, 4096), is_leaf=True) # primals_71 + buf71 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf71, (16,), is_leaf=True) # primals_72 + buf72 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf72, (56, 4096), is_leaf=True) # primals_73 + buf73 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf73, (56, 4096), is_leaf=True) # primals_74 + buf74 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf74, (16, 14336), is_leaf=True) # primals_75 + buf75 = reader.storage(None, 64, 
device=device(type='cuda', index=0)) + reader.tensor(buf75, (16,), is_leaf=True) # primals_76 + buf76 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf76, (16, 4096), is_leaf=True) # primals_77 + buf77 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf77, (4, 4096), is_leaf=True) # primals_78 + buf78 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf78, (4, 4096), is_leaf=True) # primals_79 + buf79 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf79, (16, 4096), is_leaf=True) # primals_80 + buf80 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf80, (16,), is_leaf=True) # primals_81 + buf81 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf81, (56, 4096), is_leaf=True) # primals_82 + buf82 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf82, (56, 4096), is_leaf=True) # primals_83 + buf83 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf83, (16, 14336), is_leaf=True) # primals_84 + buf84 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf84, (16,), is_leaf=True) # primals_85 + buf85 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf85, (16, 4096), is_leaf=True) # primals_86 + buf86 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf86, (4, 4096), is_leaf=True) # primals_87 + buf87 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf87, (4, 4096), is_leaf=True) # primals_88 + buf88 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf88, (16, 4096), is_leaf=True) # primals_89 + buf89 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf89, (16,), is_leaf=True) # primals_90 + buf90 
= reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf90, (56, 4096), is_leaf=True) # primals_91 + buf91 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf91, (56, 4096), is_leaf=True) # primals_92 + buf92 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf92, (16, 14336), is_leaf=True) # primals_93 + buf93 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf93, (16,), is_leaf=True) # primals_94 + buf94 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf94, (16, 4096), is_leaf=True) # primals_95 + buf95 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf95, (4, 4096), is_leaf=True) # primals_96 + buf96 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf96, (4, 4096), is_leaf=True) # primals_97 + buf97 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf97, (16, 4096), is_leaf=True) # primals_98 + buf98 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf98, (16,), is_leaf=True) # primals_99 + buf99 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf99, (56, 4096), is_leaf=True) # primals_100 + buf100 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf100, (56, 4096), is_leaf=True) # primals_101 + buf101 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf101, (16, 14336), is_leaf=True) # primals_102 + buf102 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf102, (16,), is_leaf=True) # primals_103 + buf103 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf103, (16, 4096), is_leaf=True) # primals_104 + buf104 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + 
reader.tensor(buf104, (4, 4096), is_leaf=True) # primals_105 + buf105 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf105, (4, 4096), is_leaf=True) # primals_106 + buf106 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf106, (16, 4096), is_leaf=True) # primals_107 + buf107 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf107, (16,), is_leaf=True) # primals_108 + buf108 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf108, (56, 4096), is_leaf=True) # primals_109 + buf109 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf109, (56, 4096), is_leaf=True) # primals_110 + buf110 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf110, (16, 14336), is_leaf=True) # primals_111 + buf111 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf111, (16,), is_leaf=True) # primals_112 + buf112 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf112, (16, 4096), is_leaf=True) # primals_113 + buf113 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf113, (4, 4096), is_leaf=True) # primals_114 + buf114 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf114, (4, 4096), is_leaf=True) # primals_115 + buf115 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf115, (16, 4096), is_leaf=True) # primals_116 + buf116 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf116, (16,), is_leaf=True) # primals_117 + buf117 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf117, (56, 4096), is_leaf=True) # primals_118 + buf118 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf118, (56, 4096), is_leaf=True) # 
primals_119 + buf119 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf119, (16, 14336), is_leaf=True) # primals_120 + buf120 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf120, (16,), is_leaf=True) # primals_121 + buf121 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf121, (16, 4096), is_leaf=True) # primals_122 + buf122 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf122, (4, 4096), is_leaf=True) # primals_123 + buf123 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf123, (4, 4096), is_leaf=True) # primals_124 + buf124 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf124, (16, 4096), is_leaf=True) # primals_125 + buf125 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf125, (16,), is_leaf=True) # primals_126 + buf126 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf126, (56, 4096), is_leaf=True) # primals_127 + buf127 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf127, (56, 4096), is_leaf=True) # primals_128 + buf128 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf128, (16, 14336), is_leaf=True) # primals_129 + buf129 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf129, (16,), is_leaf=True) # primals_130 + buf130 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf130, (16, 4096), is_leaf=True) # primals_131 + buf131 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf131, (4, 4096), is_leaf=True) # primals_132 + buf132 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf132, (4, 4096), is_leaf=True) # primals_133 + buf133 = reader.storage(None, 262144, 
device=device(type='cuda', index=0)) + reader.tensor(buf133, (16, 4096), is_leaf=True) # primals_134 + buf134 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf134, (16,), is_leaf=True) # primals_135 + buf135 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf135, (56, 4096), is_leaf=True) # primals_136 + buf136 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf136, (56, 4096), is_leaf=True) # primals_137 + buf137 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf137, (16, 14336), is_leaf=True) # primals_138 + buf138 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf138, (16,), is_leaf=True) # primals_139 + buf139 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf139, (16, 4096), is_leaf=True) # primals_140 + buf140 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf140, (4, 4096), is_leaf=True) # primals_141 + buf141 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf141, (4, 4096), is_leaf=True) # primals_142 + buf142 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf142, (16, 4096), is_leaf=True) # primals_143 + buf143 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf143, (16,), is_leaf=True) # primals_144 + buf144 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf144, (56, 4096), is_leaf=True) # primals_145 + buf145 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf145, (56, 4096), is_leaf=True) # primals_146 + buf146 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf146, (16, 14336), is_leaf=True) # primals_147 + buf147 = reader.storage(None, 64, device=device(type='cuda', index=0)) + 
reader.tensor(buf147, (16,), is_leaf=True) # primals_148 + buf148 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf148, (16, 4096), is_leaf=True) # primals_149 + buf149 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf149, (4, 4096), is_leaf=True) # primals_150 + buf150 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf150, (4, 4096), is_leaf=True) # primals_151 + buf151 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf151, (16, 4096), is_leaf=True) # primals_152 + buf152 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf152, (16,), is_leaf=True) # primals_153 + buf153 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf153, (56, 4096), is_leaf=True) # primals_154 + buf154 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf154, (56, 4096), is_leaf=True) # primals_155 + buf155 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf155, (16, 14336), is_leaf=True) # primals_156 + buf156 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf156, (16,), is_leaf=True) # primals_157 + buf157 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf157, (16, 4096), is_leaf=True) # primals_158 + buf158 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf158, (4, 4096), is_leaf=True) # primals_159 + buf159 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf159, (4, 4096), is_leaf=True) # primals_160 + buf160 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf160, (16, 4096), is_leaf=True) # primals_161 + buf161 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf161, (16,), is_leaf=True) # primals_162 + 
buf162 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf162, (56, 4096), is_leaf=True) # primals_163 + buf163 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf163, (56, 4096), is_leaf=True) # primals_164 + buf164 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf164, (16, 14336), is_leaf=True) # primals_165 + buf165 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf165, (16,), is_leaf=True) # primals_166 + buf166 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf166, (16, 4096), is_leaf=True) # primals_167 + buf167 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf167, (4, 4096), is_leaf=True) # primals_168 + buf168 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf168, (4, 4096), is_leaf=True) # primals_169 + buf169 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf169, (16, 4096), is_leaf=True) # primals_170 + buf170 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf170, (16,), is_leaf=True) # primals_171 + buf171 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf171, (56, 4096), is_leaf=True) # primals_172 + buf172 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf172, (56, 4096), is_leaf=True) # primals_173 + buf173 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf173, (16, 14336), is_leaf=True) # primals_174 + buf174 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf174, (16,), is_leaf=True) # primals_175 + buf175 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf175, (16, 4096), is_leaf=True) # primals_176 + buf176 = reader.storage(None, 65536, 
device=device(type='cuda', index=0)) + reader.tensor(buf176, (4, 4096), is_leaf=True) # primals_177 + buf177 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf177, (4, 4096), is_leaf=True) # primals_178 + buf178 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf178, (16, 4096), is_leaf=True) # primals_179 + buf179 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf179, (16,), is_leaf=True) # primals_180 + buf180 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf180, (56, 4096), is_leaf=True) # primals_181 + buf181 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf181, (56, 4096), is_leaf=True) # primals_182 + buf182 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf182, (16, 14336), is_leaf=True) # primals_183 + buf183 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf183, (16,), is_leaf=True) # primals_184 + buf184 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf184, (16, 4096), is_leaf=True) # primals_185 + buf185 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf185, (4, 4096), is_leaf=True) # primals_186 + buf186 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf186, (4, 4096), is_leaf=True) # primals_187 + buf187 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf187, (16, 4096), is_leaf=True) # primals_188 + buf188 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf188, (16,), is_leaf=True) # primals_189 + buf189 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf189, (56, 4096), is_leaf=True) # primals_190 + buf190 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + 
reader.tensor(buf190, (56, 4096), is_leaf=True) # primals_191 + buf191 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf191, (16, 14336), is_leaf=True) # primals_192 + buf192 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf192, (16,), is_leaf=True) # primals_193 + buf193 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf193, (16, 4096), is_leaf=True) # primals_194 + buf194 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf194, (4, 4096), is_leaf=True) # primals_195 + buf195 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf195, (4, 4096), is_leaf=True) # primals_196 + buf196 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf196, (16, 4096), is_leaf=True) # primals_197 + buf197 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf197, (16,), is_leaf=True) # primals_198 + buf198 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf198, (56, 4096), is_leaf=True) # primals_199 + buf199 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf199, (56, 4096), is_leaf=True) # primals_200 + buf200 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf200, (16, 14336), is_leaf=True) # primals_201 + buf201 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf201, (16,), is_leaf=True) # primals_202 + buf202 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf202, (16, 4096), is_leaf=True) # primals_203 + buf203 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf203, (4, 4096), is_leaf=True) # primals_204 + buf204 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf204, (4, 4096), is_leaf=True) # 
primals_205 + buf205 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf205, (16, 4096), is_leaf=True) # primals_206 + buf206 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf206, (16,), is_leaf=True) # primals_207 + buf207 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf207, (56, 4096), is_leaf=True) # primals_208 + buf208 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf208, (56, 4096), is_leaf=True) # primals_209 + buf209 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf209, (16, 14336), is_leaf=True) # primals_210 + buf210 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf210, (16,), is_leaf=True) # primals_211 + buf211 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf211, (16, 4096), is_leaf=True) # primals_212 + buf212 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf212, (4, 4096), is_leaf=True) # primals_213 + buf213 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf213, (4, 4096), is_leaf=True) # primals_214 + buf214 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf214, (16, 4096), is_leaf=True) # primals_215 + buf215 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf215, (16,), is_leaf=True) # primals_216 + buf216 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf216, (56, 4096), is_leaf=True) # primals_217 + buf217 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf217, (56, 4096), is_leaf=True) # primals_218 + buf218 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf218, (16, 14336), is_leaf=True) # primals_219 + buf219 = reader.storage(None, 64, 
device=device(type='cuda', index=0)) + reader.tensor(buf219, (16,), is_leaf=True) # primals_220 + buf220 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf220, (16, 4096), is_leaf=True) # primals_221 + buf221 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf221, (4, 4096), is_leaf=True) # primals_222 + buf222 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf222, (4, 4096), is_leaf=True) # primals_223 + buf223 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf223, (16, 4096), is_leaf=True) # primals_224 + buf224 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf224, (16,), is_leaf=True) # primals_225 + buf225 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf225, (56, 4096), is_leaf=True) # primals_226 + buf226 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf226, (56, 4096), is_leaf=True) # primals_227 + buf227 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf227, (16, 14336), is_leaf=True) # primals_228 + buf228 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf228, (16,), is_leaf=True) # primals_229 + buf229 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf229, (16, 4096), is_leaf=True) # primals_230 + buf230 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf230, (4, 4096), is_leaf=True) # primals_231 + buf231 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf231, (4, 4096), is_leaf=True) # primals_232 + buf232 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf232, (16, 4096), is_leaf=True) # primals_233 + buf233 = reader.storage(None, 64, device=device(type='cuda', index=0)) + 
reader.tensor(buf233, (16,), is_leaf=True) # primals_234 + buf234 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf234, (56, 4096), is_leaf=True) # primals_235 + buf235 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf235, (56, 4096), is_leaf=True) # primals_236 + buf236 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf236, (16, 14336), is_leaf=True) # primals_237 + buf237 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf237, (16,), is_leaf=True) # primals_238 + buf238 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf238, (16, 4096), is_leaf=True) # primals_239 + buf239 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf239, (4, 4096), is_leaf=True) # primals_240 + buf240 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf240, (4, 4096), is_leaf=True) # primals_241 + buf241 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf241, (16, 4096), is_leaf=True) # primals_242 + buf242 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf242, (16,), is_leaf=True) # primals_243 + buf243 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf243, (56, 4096), is_leaf=True) # primals_244 + buf244 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf244, (56, 4096), is_leaf=True) # primals_245 + buf245 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf245, (16, 14336), is_leaf=True) # primals_246 + buf246 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf246, (16,), is_leaf=True) # primals_247 + buf247 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf247, (16, 4096), is_leaf=True) # 
primals_248 + buf248 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf248, (4, 4096), is_leaf=True) # primals_249 + buf249 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf249, (4, 4096), is_leaf=True) # primals_250 + buf250 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf250, (16, 4096), is_leaf=True) # primals_251 + buf251 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf251, (16,), is_leaf=True) # primals_252 + buf252 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf252, (56, 4096), is_leaf=True) # primals_253 + buf253 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf253, (56, 4096), is_leaf=True) # primals_254 + buf254 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf254, (16, 14336), is_leaf=True) # primals_255 + buf255 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf255, (16,), is_leaf=True) # primals_256 + buf256 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf256, (16, 4096), is_leaf=True) # primals_257 + buf257 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf257, (4, 4096), is_leaf=True) # primals_258 + buf258 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf258, (4, 4096), is_leaf=True) # primals_259 + buf259 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf259, (16, 4096), is_leaf=True) # primals_260 + buf260 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf260, (16,), is_leaf=True) # primals_261 + buf261 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf261, (56, 4096), is_leaf=True) # primals_262 + buf262 = reader.storage(None, 917504, 
device=device(type='cuda', index=0)) + reader.tensor(buf262, (56, 4096), is_leaf=True) # primals_263 + buf263 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf263, (16, 14336), is_leaf=True) # primals_264 + buf264 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf264, (16,), is_leaf=True) # primals_265 + buf265 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf265, (16, 4096), is_leaf=True) # primals_266 + buf266 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf266, (4, 4096), is_leaf=True) # primals_267 + buf267 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf267, (4, 4096), is_leaf=True) # primals_268 + buf268 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf268, (16, 4096), is_leaf=True) # primals_269 + buf269 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf269, (16,), is_leaf=True) # primals_270 + buf270 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf270, (56, 4096), is_leaf=True) # primals_271 + buf271 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf271, (56, 4096), is_leaf=True) # primals_272 + buf272 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf272, (16, 14336), is_leaf=True) # primals_273 + buf273 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf273, (16,), is_leaf=True) # primals_274 + buf274 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf274, (16, 4096), is_leaf=True) # primals_275 + buf275 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf275, (4, 4096), is_leaf=True) # primals_276 + buf276 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + 
reader.tensor(buf276, (4, 4096), is_leaf=True) # primals_277 + buf277 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf277, (16, 4096), is_leaf=True) # primals_278 + buf278 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf278, (16,), is_leaf=True) # primals_279 + buf279 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf279, (56, 4096), is_leaf=True) # primals_280 + buf280 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf280, (56, 4096), is_leaf=True) # primals_281 + buf281 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf281, (16, 14336), is_leaf=True) # primals_282 + buf282 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf282, (16,), is_leaf=True) # primals_283 + buf283 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf283, (16, 4096), is_leaf=True) # primals_284 + buf284 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf284, (4, 4096), is_leaf=True) # primals_285 + buf285 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf285, (4, 4096), is_leaf=True) # primals_286 + buf286 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf286, (16, 4096), is_leaf=True) # primals_287 + buf287 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf287, (16,), is_leaf=True) # primals_288 + buf288 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf288, (56, 4096), is_leaf=True) # primals_289 + buf289 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf289, (56, 4096), is_leaf=True) # primals_290 + buf290 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf290, (16, 14336), is_leaf=True) # 
primals_291 + buf291 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf291, (16,), is_leaf=True) # primals_292 + buf292 = reader.storage(None, 8208384, device=device(type='cuda', index=0)) + reader.tensor(buf292, (501, 4096), is_leaf=True) # primals_293 + buf293 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf293, (2, 8192, 4096), dtype=torch.bfloat16, is_leaf=True) # embedding + buf294 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf294, (16384, 4096), dtype=torch.bfloat16, is_leaf=True) # mm + buf295 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf295, (16384, 1024), dtype=torch.bfloat16, is_leaf=True) # mm_2 + buf296 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf296, (2, 32, 8192, 128), (33554432, 128, 4096, 1), dtype=torch.bfloat16, is_leaf=True) # getitem + buf297 = reader.storage(None, 2097152, device=device(type='cuda', index=0)) + reader.tensor(buf297, (2, 32, 8192, 1), is_leaf=True) # getitem_1 + buf298 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf298, (), dtype=torch.int64, is_leaf=True) # getitem_6 + buf299 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf299, (), dtype=torch.int64, is_leaf=True) # getitem_7 + buf300 = reader.storage(None, 469762048, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf300, (16384, 14336), dtype=torch.bfloat16, is_leaf=True) # mm_4 + buf301 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf301, (2, 8192, 4096), dtype=torch.bfloat16, is_leaf=True) # add_3 + buf302 = reader.storage(None, 134217728, 
device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf302, (16384, 4096), dtype=torch.bfloat16, is_leaf=True) # mm_7 + buf303 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf303, (16384, 1024), dtype=torch.bfloat16, is_leaf=True) # mm_9 + buf304 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf304, (2, 32, 8192, 128), (33554432, 128, 4096, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_9 + buf305 = reader.storage(None, 2097152, device=device(type='cuda', index=0)) + reader.tensor(buf305, (2, 32, 8192, 1), is_leaf=True) # getitem_10 + buf306 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf306, (), dtype=torch.int64, is_leaf=True) # getitem_15 + buf307 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf307, (), dtype=torch.int64, is_leaf=True) # getitem_16 + buf308 = reader.storage(None, 469762048, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf308, (16384, 14336), dtype=torch.bfloat16, is_leaf=True) # mm_11 + buf309 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf309, (2, 8192, 4096), dtype=torch.bfloat16, is_leaf=True) # add_7 + buf310 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf310, (16384, 4096), dtype=torch.bfloat16, is_leaf=True) # mm_14 + buf311 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf311, (16384, 1024), dtype=torch.bfloat16, is_leaf=True) # mm_16 + buf312 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf312, (2, 32, 8192, 128), (33554432, 128, 4096, 1), 
dtype=torch.bfloat16, is_leaf=True) # getitem_18 + buf313 = reader.storage(None, 2097152, device=device(type='cuda', index=0)) + reader.tensor(buf313, (2, 32, 8192, 1), is_leaf=True) # getitem_19 + buf314 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf314, (), dtype=torch.int64, is_leaf=True) # getitem_24 + buf315 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf315, (), dtype=torch.int64, is_leaf=True) # getitem_25 + buf316 = reader.storage(None, 469762048, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf316, (16384, 14336), dtype=torch.bfloat16, is_leaf=True) # mm_18 + buf317 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf317, (2, 8192, 4096), dtype=torch.bfloat16, is_leaf=True) # add_11 + buf318 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf318, (16384, 4096), dtype=torch.bfloat16, is_leaf=True) # mm_21 + buf319 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf319, (16384, 1024), dtype=torch.bfloat16, is_leaf=True) # mm_23 + buf320 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf320, (2, 32, 8192, 128), (33554432, 128, 4096, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_27 + buf321 = reader.storage(None, 2097152, device=device(type='cuda', index=0)) + reader.tensor(buf321, (2, 32, 8192, 1), is_leaf=True) # getitem_28 + buf322 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf322, (), dtype=torch.int64, is_leaf=True) # getitem_33 + buf323 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf323, (), dtype=torch.int64, is_leaf=True) 
# getitem_34 + buf324 = reader.storage(None, 469762048, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf324, (16384, 14336), dtype=torch.bfloat16, is_leaf=True) # mm_25 + buf325 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf325, (2, 8192, 4096), dtype=torch.bfloat16, is_leaf=True) # add_15 + buf326 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf326, (16384, 4096), dtype=torch.bfloat16, is_leaf=True) # mm_28 + buf327 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf327, (16384, 1024), dtype=torch.bfloat16, is_leaf=True) # mm_30 + buf328 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf328, (2, 32, 8192, 128), (33554432, 128, 4096, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_36 + buf329 = reader.storage(None, 2097152, device=device(type='cuda', index=0)) + reader.tensor(buf329, (2, 32, 8192, 1), is_leaf=True) # getitem_37 + buf330 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf330, (), dtype=torch.int64, is_leaf=True) # getitem_42 + buf331 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf331, (), dtype=torch.int64, is_leaf=True) # getitem_43 + buf332 = reader.storage(None, 469762048, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf332, (16384, 14336), dtype=torch.bfloat16, is_leaf=True) # mm_32 + buf333 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf333, (2, 8192, 4096), dtype=torch.bfloat16, is_leaf=True) # add_19 + buf334 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + 
reader.tensor(buf334, (16384, 4096), dtype=torch.bfloat16, is_leaf=True) # mm_35 + buf335 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf335, (16384, 1024), dtype=torch.bfloat16, is_leaf=True) # mm_37 + buf336 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf336, (2, 32, 8192, 128), (33554432, 128, 4096, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_45 + buf337 = reader.storage(None, 2097152, device=device(type='cuda', index=0)) + reader.tensor(buf337, (2, 32, 8192, 1), is_leaf=True) # getitem_46 + buf338 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf338, (), dtype=torch.int64, is_leaf=True) # getitem_51 + buf339 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf339, (), dtype=torch.int64, is_leaf=True) # getitem_52 + buf340 = reader.storage(None, 469762048, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf340, (16384, 14336), dtype=torch.bfloat16, is_leaf=True) # mm_39 + buf341 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf341, (2, 8192, 4096), dtype=torch.bfloat16, is_leaf=True) # add_23 + buf342 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf342, (16384, 4096), dtype=torch.bfloat16, is_leaf=True) # mm_42 + buf343 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf343, (16384, 1024), dtype=torch.bfloat16, is_leaf=True) # mm_44 + buf344 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf344, (2, 32, 8192, 128), (33554432, 128, 4096, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_54 + buf345 = 
reader.storage(None, 2097152, device=device(type='cuda', index=0)) + reader.tensor(buf345, (2, 32, 8192, 1), is_leaf=True) # getitem_55 + buf346 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf346, (), dtype=torch.int64, is_leaf=True) # getitem_60 + buf347 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf347, (), dtype=torch.int64, is_leaf=True) # getitem_61 + buf348 = reader.storage(None, 469762048, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf348, (16384, 14336), dtype=torch.bfloat16, is_leaf=True) # mm_46 + buf349 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf349, (2, 8192, 4096), dtype=torch.bfloat16, is_leaf=True) # add_27 + buf350 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf350, (16384, 4096), dtype=torch.bfloat16, is_leaf=True) # mm_49 + buf351 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf351, (16384, 1024), dtype=torch.bfloat16, is_leaf=True) # mm_51 + buf352 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf352, (2, 32, 8192, 128), (33554432, 128, 4096, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_63 + buf353 = reader.storage(None, 2097152, device=device(type='cuda', index=0)) + reader.tensor(buf353, (2, 32, 8192, 1), is_leaf=True) # getitem_64 + buf354 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf354, (), dtype=torch.int64, is_leaf=True) # getitem_69 + buf355 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf355, (), dtype=torch.int64, is_leaf=True) # getitem_70 + buf356 = reader.storage(None, 469762048, 
device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf356, (16384, 14336), dtype=torch.bfloat16, is_leaf=True) # mm_53 + buf357 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf357, (2, 8192, 4096), dtype=torch.bfloat16, is_leaf=True) # add_31 + buf358 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf358, (16384, 4096), dtype=torch.bfloat16, is_leaf=True) # mm_56 + buf359 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf359, (16384, 1024), dtype=torch.bfloat16, is_leaf=True) # mm_58 + buf360 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf360, (2, 32, 8192, 128), (33554432, 128, 4096, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_72 + buf361 = reader.storage(None, 2097152, device=device(type='cuda', index=0)) + reader.tensor(buf361, (2, 32, 8192, 1), is_leaf=True) # getitem_73 + buf362 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf362, (), dtype=torch.int64, is_leaf=True) # getitem_78 + buf363 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf363, (), dtype=torch.int64, is_leaf=True) # getitem_79 + buf364 = reader.storage(None, 469762048, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf364, (16384, 14336), dtype=torch.bfloat16, is_leaf=True) # mm_60 + buf365 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf365, (2, 8192, 4096), dtype=torch.bfloat16, is_leaf=True) # add_35 + buf366 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf366, (16384, 4096), dtype=torch.bfloat16, is_leaf=True) 
# mm_63 + buf367 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf367, (16384, 1024), dtype=torch.bfloat16, is_leaf=True) # mm_65 + buf368 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf368, (2, 32, 8192, 128), (33554432, 128, 4096, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_81 + buf369 = reader.storage(None, 2097152, device=device(type='cuda', index=0)) + reader.tensor(buf369, (2, 32, 8192, 1), is_leaf=True) # getitem_82 + buf370 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf370, (), dtype=torch.int64, is_leaf=True) # getitem_87 + buf371 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf371, (), dtype=torch.int64, is_leaf=True) # getitem_88 + buf372 = reader.storage(None, 469762048, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf372, (16384, 14336), dtype=torch.bfloat16, is_leaf=True) # mm_67 + buf373 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf373, (2, 8192, 4096), dtype=torch.bfloat16, is_leaf=True) # add_39 + buf374 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf374, (16384, 4096), dtype=torch.bfloat16, is_leaf=True) # mm_70 + buf375 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf375, (16384, 1024), dtype=torch.bfloat16, is_leaf=True) # mm_72 + buf376 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf376, (2, 32, 8192, 128), (33554432, 128, 4096, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_90 + buf377 = reader.storage(None, 2097152, device=device(type='cuda', index=0)) + 
reader.tensor(buf377, (2, 32, 8192, 1), is_leaf=True) # getitem_91 + buf378 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf378, (), dtype=torch.int64, is_leaf=True) # getitem_96 + buf379 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf379, (), dtype=torch.int64, is_leaf=True) # getitem_97 + buf380 = reader.storage(None, 469762048, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf380, (16384, 14336), dtype=torch.bfloat16, is_leaf=True) # mm_74 + buf381 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf381, (2, 8192, 4096), dtype=torch.bfloat16, is_leaf=True) # add_43 + buf382 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf382, (16384, 4096), dtype=torch.bfloat16, is_leaf=True) # mm_77 + buf383 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf383, (16384, 1024), dtype=torch.bfloat16, is_leaf=True) # mm_79 + buf384 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf384, (2, 32, 8192, 128), (33554432, 128, 4096, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_99 + buf385 = reader.storage(None, 2097152, device=device(type='cuda', index=0)) + reader.tensor(buf385, (2, 32, 8192, 1), is_leaf=True) # getitem_100 + buf386 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf386, (), dtype=torch.int64, is_leaf=True) # getitem_105 + buf387 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf387, (), dtype=torch.int64, is_leaf=True) # getitem_106 + buf388 = reader.storage(None, 469762048, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + 
reader.tensor(buf388, (16384, 14336), dtype=torch.bfloat16, is_leaf=True) # mm_81 + buf389 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf389, (2, 8192, 4096), dtype=torch.bfloat16, is_leaf=True) # add_47 + buf390 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf390, (16384, 4096), dtype=torch.bfloat16, is_leaf=True) # mm_84 + buf391 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf391, (16384, 1024), dtype=torch.bfloat16, is_leaf=True) # mm_86 + buf392 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf392, (2, 32, 8192, 128), (33554432, 128, 4096, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_108 + buf393 = reader.storage(None, 2097152, device=device(type='cuda', index=0)) + reader.tensor(buf393, (2, 32, 8192, 1), is_leaf=True) # getitem_109 + buf394 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf394, (), dtype=torch.int64, is_leaf=True) # getitem_114 + buf395 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf395, (), dtype=torch.int64, is_leaf=True) # getitem_115 + buf396 = reader.storage(None, 469762048, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf396, (16384, 14336), dtype=torch.bfloat16, is_leaf=True) # mm_88 + buf397 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf397, (2, 8192, 4096), dtype=torch.bfloat16, is_leaf=True) # add_51 + buf398 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf398, (16384, 4096), dtype=torch.bfloat16, is_leaf=True) # mm_91 + buf399 = reader.storage(None, 33554432, 
device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf399, (16384, 1024), dtype=torch.bfloat16, is_leaf=True) # mm_93 + buf400 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf400, (2, 32, 8192, 128), (33554432, 128, 4096, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_117 + buf401 = reader.storage(None, 2097152, device=device(type='cuda', index=0)) + reader.tensor(buf401, (2, 32, 8192, 1), is_leaf=True) # getitem_118 + buf402 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf402, (), dtype=torch.int64, is_leaf=True) # getitem_123 + buf403 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf403, (), dtype=torch.int64, is_leaf=True) # getitem_124 + buf404 = reader.storage(None, 469762048, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf404, (16384, 14336), dtype=torch.bfloat16, is_leaf=True) # mm_95 + buf405 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf405, (2, 8192, 4096), dtype=torch.bfloat16, is_leaf=True) # add_55 + buf406 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf406, (16384, 4096), dtype=torch.bfloat16, is_leaf=True) # mm_98 + buf407 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf407, (16384, 1024), dtype=torch.bfloat16, is_leaf=True) # mm_100 + buf408 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf408, (2, 32, 8192, 128), (33554432, 128, 4096, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_126 + buf409 = reader.storage(None, 2097152, device=device(type='cuda', index=0)) + reader.tensor(buf409, (2, 32, 8192, 1), is_leaf=True) # 
getitem_127 + buf410 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf410, (), dtype=torch.int64, is_leaf=True) # getitem_132 + buf411 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf411, (), dtype=torch.int64, is_leaf=True) # getitem_133 + buf412 = reader.storage(None, 469762048, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf412, (16384, 14336), dtype=torch.bfloat16, is_leaf=True) # mm_102 + buf413 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf413, (2, 8192, 4096), dtype=torch.bfloat16, is_leaf=True) # add_59 + buf414 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf414, (16384, 4096), dtype=torch.bfloat16, is_leaf=True) # mm_105 + buf415 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf415, (16384, 1024), dtype=torch.bfloat16, is_leaf=True) # mm_107 + buf416 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf416, (2, 32, 8192, 128), (33554432, 128, 4096, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_135 + buf417 = reader.storage(None, 2097152, device=device(type='cuda', index=0)) + reader.tensor(buf417, (2, 32, 8192, 1), is_leaf=True) # getitem_136 + buf418 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf418, (), dtype=torch.int64, is_leaf=True) # getitem_141 + buf419 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf419, (), dtype=torch.int64, is_leaf=True) # getitem_142 + buf420 = reader.storage(None, 469762048, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf420, (16384, 14336), 
dtype=torch.bfloat16, is_leaf=True) # mm_109 + buf421 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf421, (2, 8192, 4096), dtype=torch.bfloat16, is_leaf=True) # add_63 + buf422 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf422, (16384, 4096), dtype=torch.bfloat16, is_leaf=True) # mm_112 + buf423 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf423, (16384, 1024), dtype=torch.bfloat16, is_leaf=True) # mm_114 + buf424 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf424, (2, 32, 8192, 128), (33554432, 128, 4096, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_144 + buf425 = reader.storage(None, 2097152, device=device(type='cuda', index=0)) + reader.tensor(buf425, (2, 32, 8192, 1), is_leaf=True) # getitem_145 + buf426 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf426, (), dtype=torch.int64, is_leaf=True) # getitem_150 + buf427 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf427, (), dtype=torch.int64, is_leaf=True) # getitem_151 + buf428 = reader.storage(None, 469762048, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf428, (16384, 14336), dtype=torch.bfloat16, is_leaf=True) # mm_116 + buf429 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf429, (2, 8192, 4096), dtype=torch.bfloat16, is_leaf=True) # add_67 + buf430 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf430, (16384, 4096), dtype=torch.bfloat16, is_leaf=True) # mm_119 + buf431 = reader.storage(None, 33554432, device=device(type='cuda', index=0), 
dtype_hint=torch.bfloat16) + reader.tensor(buf431, (16384, 1024), dtype=torch.bfloat16, is_leaf=True) # mm_121 + buf432 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf432, (2, 32, 8192, 128), (33554432, 128, 4096, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_153 + buf433 = reader.storage(None, 2097152, device=device(type='cuda', index=0)) + reader.tensor(buf433, (2, 32, 8192, 1), is_leaf=True) # getitem_154 + buf434 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf434, (), dtype=torch.int64, is_leaf=True) # getitem_159 + buf435 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf435, (), dtype=torch.int64, is_leaf=True) # getitem_160 + buf436 = reader.storage(None, 469762048, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf436, (16384, 14336), dtype=torch.bfloat16, is_leaf=True) # mm_123 + buf437 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf437, (2, 8192, 4096), dtype=torch.bfloat16, is_leaf=True) # add_71 + buf438 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf438, (16384, 4096), dtype=torch.bfloat16, is_leaf=True) # mm_126 + buf439 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf439, (16384, 1024), dtype=torch.bfloat16, is_leaf=True) # mm_128 + buf440 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf440, (2, 32, 8192, 128), (33554432, 128, 4096, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_162 + buf441 = reader.storage(None, 2097152, device=device(type='cuda', index=0)) + reader.tensor(buf441, (2, 32, 8192, 1), is_leaf=True) # getitem_163 + buf442 = 
reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf442, (), dtype=torch.int64, is_leaf=True) # getitem_168 + buf443 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf443, (), dtype=torch.int64, is_leaf=True) # getitem_169 + buf444 = reader.storage(None, 469762048, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf444, (16384, 14336), dtype=torch.bfloat16, is_leaf=True) # mm_130 + buf445 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf445, (2, 8192, 4096), dtype=torch.bfloat16, is_leaf=True) # add_75 + buf446 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf446, (16384, 4096), dtype=torch.bfloat16, is_leaf=True) # mm_133 + buf447 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf447, (16384, 1024), dtype=torch.bfloat16, is_leaf=True) # mm_135 + buf448 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf448, (2, 32, 8192, 128), (33554432, 128, 4096, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_171 + buf449 = reader.storage(None, 2097152, device=device(type='cuda', index=0)) + reader.tensor(buf449, (2, 32, 8192, 1), is_leaf=True) # getitem_172 + buf450 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf450, (), dtype=torch.int64, is_leaf=True) # getitem_177 + buf451 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf451, (), dtype=torch.int64, is_leaf=True) # getitem_178 + buf452 = reader.storage(None, 469762048, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf452, (16384, 14336), dtype=torch.bfloat16, is_leaf=True) # 
mm_137 + buf453 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf453, (2, 8192, 4096), dtype=torch.bfloat16, is_leaf=True) # add_79 + buf454 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf454, (16384, 4096), dtype=torch.bfloat16, is_leaf=True) # mm_140 + buf455 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf455, (16384, 1024), dtype=torch.bfloat16, is_leaf=True) # mm_142 + buf456 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf456, (2, 32, 8192, 128), (33554432, 128, 4096, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_180 + buf457 = reader.storage(None, 2097152, device=device(type='cuda', index=0)) + reader.tensor(buf457, (2, 32, 8192, 1), is_leaf=True) # getitem_181 + buf458 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf458, (), dtype=torch.int64, is_leaf=True) # getitem_186 + buf459 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf459, (), dtype=torch.int64, is_leaf=True) # getitem_187 + buf460 = reader.storage(None, 469762048, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf460, (16384, 14336), dtype=torch.bfloat16, is_leaf=True) # mm_144 + buf461 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf461, (2, 8192, 4096), dtype=torch.bfloat16, is_leaf=True) # add_83 + buf462 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf462, (16384, 4096), dtype=torch.bfloat16, is_leaf=True) # mm_147 + buf463 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + 
reader.tensor(buf463, (16384, 1024), dtype=torch.bfloat16, is_leaf=True) # mm_149 + buf464 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf464, (2, 32, 8192, 128), (33554432, 128, 4096, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_189 + buf465 = reader.storage(None, 2097152, device=device(type='cuda', index=0)) + reader.tensor(buf465, (2, 32, 8192, 1), is_leaf=True) # getitem_190 + buf466 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf466, (), dtype=torch.int64, is_leaf=True) # getitem_195 + buf467 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf467, (), dtype=torch.int64, is_leaf=True) # getitem_196 + buf468 = reader.storage(None, 469762048, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf468, (16384, 14336), dtype=torch.bfloat16, is_leaf=True) # mm_151 + buf469 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf469, (2, 8192, 4096), dtype=torch.bfloat16, is_leaf=True) # add_87 + buf470 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf470, (16384, 4096), dtype=torch.bfloat16, is_leaf=True) # mm_154 + buf471 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf471, (16384, 1024), dtype=torch.bfloat16, is_leaf=True) # mm_156 + buf472 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf472, (2, 32, 8192, 128), (33554432, 128, 4096, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_198 + buf473 = reader.storage(None, 2097152, device=device(type='cuda', index=0)) + reader.tensor(buf473, (2, 32, 8192, 1), is_leaf=True) # getitem_199 + buf474 = reader.storage(None, 8, 
device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf474, (), dtype=torch.int64, is_leaf=True) # getitem_204 + buf475 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf475, (), dtype=torch.int64, is_leaf=True) # getitem_205 + buf476 = reader.storage(None, 469762048, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf476, (16384, 14336), dtype=torch.bfloat16, is_leaf=True) # mm_158 + buf477 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf477, (2, 8192, 4096), dtype=torch.bfloat16, is_leaf=True) # add_91 + buf478 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf478, (16384, 4096), dtype=torch.bfloat16, is_leaf=True) # mm_161 + buf479 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf479, (16384, 1024), dtype=torch.bfloat16, is_leaf=True) # mm_163 + buf480 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf480, (2, 32, 8192, 128), (33554432, 128, 4096, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_207 + buf481 = reader.storage(None, 2097152, device=device(type='cuda', index=0)) + reader.tensor(buf481, (2, 32, 8192, 1), is_leaf=True) # getitem_208 + buf482 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf482, (), dtype=torch.int64, is_leaf=True) # getitem_213 + buf483 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf483, (), dtype=torch.int64, is_leaf=True) # getitem_214 + buf484 = reader.storage(None, 469762048, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf484, (16384, 14336), dtype=torch.bfloat16, is_leaf=True) # mm_165 + buf485 = 
reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf485, (2, 8192, 4096), dtype=torch.bfloat16, is_leaf=True) # add_95 + buf486 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf486, (16384, 4096), dtype=torch.bfloat16, is_leaf=True) # mm_168 + buf487 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf487, (16384, 1024), dtype=torch.bfloat16, is_leaf=True) # mm_170 + buf488 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf488, (2, 32, 8192, 128), (33554432, 128, 4096, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_216 + buf489 = reader.storage(None, 2097152, device=device(type='cuda', index=0)) + reader.tensor(buf489, (2, 32, 8192, 1), is_leaf=True) # getitem_217 + buf490 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf490, (), dtype=torch.int64, is_leaf=True) # getitem_222 + buf491 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf491, (), dtype=torch.int64, is_leaf=True) # getitem_223 + buf492 = reader.storage(None, 469762048, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf492, (16384, 14336), dtype=torch.bfloat16, is_leaf=True) # mm_172 + buf493 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf493, (2, 8192, 4096), dtype=torch.bfloat16, is_leaf=True) # add_99 + buf494 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf494, (16384, 4096), dtype=torch.bfloat16, is_leaf=True) # mm_175 + buf495 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf495, (16384, 
1024), dtype=torch.bfloat16, is_leaf=True) # mm_177 + buf496 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf496, (2, 32, 8192, 128), (33554432, 128, 4096, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_225 + buf497 = reader.storage(None, 2097152, device=device(type='cuda', index=0)) + reader.tensor(buf497, (2, 32, 8192, 1), is_leaf=True) # getitem_226 + buf498 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf498, (), dtype=torch.int64, is_leaf=True) # getitem_231 + buf499 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf499, (), dtype=torch.int64, is_leaf=True) # getitem_232 + buf500 = reader.storage(None, 469762048, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf500, (16384, 14336), dtype=torch.bfloat16, is_leaf=True) # mm_179 + buf501 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf501, (2, 8192, 4096), dtype=torch.bfloat16, is_leaf=True) # add_103 + buf502 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf502, (16384, 4096), dtype=torch.bfloat16, is_leaf=True) # mm_182 + buf503 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf503, (16384, 1024), dtype=torch.bfloat16, is_leaf=True) # mm_184 + buf504 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf504, (2, 32, 8192, 128), (33554432, 128, 4096, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_234 + buf505 = reader.storage(None, 2097152, device=device(type='cuda', index=0)) + reader.tensor(buf505, (2, 32, 8192, 1), is_leaf=True) # getitem_235 + buf506 = reader.storage(None, 8, device=device(type='cuda', index=0), 
dtype_hint=torch.int64) + reader.tensor(buf506, (), dtype=torch.int64, is_leaf=True) # getitem_240 + buf507 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf507, (), dtype=torch.int64, is_leaf=True) # getitem_241 + buf508 = reader.storage(None, 469762048, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf508, (16384, 14336), dtype=torch.bfloat16, is_leaf=True) # mm_186 + buf509 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf509, (2, 8192, 4096), dtype=torch.bfloat16, is_leaf=True) # add_107 + buf510 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf510, (16384, 4096), dtype=torch.bfloat16, is_leaf=True) # mm_189 + buf511 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf511, (16384, 1024), dtype=torch.bfloat16, is_leaf=True) # mm_191 + buf512 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf512, (2, 32, 8192, 128), (33554432, 128, 4096, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_243 + buf513 = reader.storage(None, 2097152, device=device(type='cuda', index=0)) + reader.tensor(buf513, (2, 32, 8192, 1), is_leaf=True) # getitem_244 + buf514 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf514, (), dtype=torch.int64, is_leaf=True) # getitem_249 + buf515 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf515, (), dtype=torch.int64, is_leaf=True) # getitem_250 + buf516 = reader.storage(None, 469762048, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf516, (16384, 14336), dtype=torch.bfloat16, is_leaf=True) # mm_193 + buf517 = reader.storage(None, 134217728, 
device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf517, (2, 8192, 4096), dtype=torch.bfloat16, is_leaf=True) # add_111 + buf518 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf518, (16384, 4096), dtype=torch.bfloat16, is_leaf=True) # mm_196 + buf519 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf519, (16384, 1024), dtype=torch.bfloat16, is_leaf=True) # mm_198 + buf520 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf520, (2, 32, 8192, 128), (33554432, 128, 4096, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_252 + buf521 = reader.storage(None, 2097152, device=device(type='cuda', index=0)) + reader.tensor(buf521, (2, 32, 8192, 1), is_leaf=True) # getitem_253 + buf522 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf522, (), dtype=torch.int64, is_leaf=True) # getitem_258 + buf523 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf523, (), dtype=torch.int64, is_leaf=True) # getitem_259 + buf524 = reader.storage(None, 469762048, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf524, (16384, 14336), dtype=torch.bfloat16, is_leaf=True) # mm_200 + buf525 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf525, (2, 8192, 4096), dtype=torch.bfloat16, is_leaf=True) # add_115 + buf526 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf526, (16384, 4096), dtype=torch.bfloat16, is_leaf=True) # mm_203 + buf527 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf527, (16384, 1024), dtype=torch.bfloat16, 
is_leaf=True) # mm_205 + buf528 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf528, (2, 32, 8192, 128), (33554432, 128, 4096, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_261 + buf529 = reader.storage(None, 2097152, device=device(type='cuda', index=0)) + reader.tensor(buf529, (2, 32, 8192, 1), is_leaf=True) # getitem_262 + buf530 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf530, (), dtype=torch.int64, is_leaf=True) # getitem_267 + buf531 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf531, (), dtype=torch.int64, is_leaf=True) # getitem_268 + buf532 = reader.storage(None, 469762048, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf532, (16384, 14336), dtype=torch.bfloat16, is_leaf=True) # mm_207 + buf533 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf533, (2, 8192, 4096), dtype=torch.bfloat16, is_leaf=True) # add_119 + buf534 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf534, (16384, 4096), dtype=torch.bfloat16, is_leaf=True) # mm_210 + buf535 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf535, (16384, 1024), dtype=torch.bfloat16, is_leaf=True) # mm_212 + buf536 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf536, (2, 32, 8192, 128), (33554432, 128, 4096, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_270 + buf537 = reader.storage(None, 2097152, device=device(type='cuda', index=0)) + reader.tensor(buf537, (2, 32, 8192, 1), is_leaf=True) # getitem_271 + buf538 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + 
reader.tensor(buf538, (), dtype=torch.int64, is_leaf=True) # getitem_276 + buf539 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf539, (), dtype=torch.int64, is_leaf=True) # getitem_277 + buf540 = reader.storage(None, 469762048, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf540, (16384, 14336), dtype=torch.bfloat16, is_leaf=True) # mm_214 + buf541 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf541, (2, 8192, 4096), dtype=torch.bfloat16, is_leaf=True) # add_123 + buf542 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf542, (16384, 4096), dtype=torch.bfloat16, is_leaf=True) # mm_217 + buf543 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf543, (16384, 1024), dtype=torch.bfloat16, is_leaf=True) # mm_219 + buf544 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf544, (2, 32, 8192, 128), (33554432, 128, 4096, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_279 + buf545 = reader.storage(None, 2097152, device=device(type='cuda', index=0)) + reader.tensor(buf545, (2, 32, 8192, 1), is_leaf=True) # getitem_280 + buf546 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf546, (), dtype=torch.int64, is_leaf=True) # getitem_285 + buf547 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf547, (), dtype=torch.int64, is_leaf=True) # getitem_286 + buf548 = reader.storage(None, 469762048, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf548, (16384, 14336), dtype=torch.bfloat16, is_leaf=True) # mm_221 + buf549 = reader.storage(None, 134217728, device=device(type='cuda', index=0), 
dtype_hint=torch.bfloat16) + reader.tensor(buf549, (16384, 4096), dtype=torch.bfloat16, is_leaf=True) # mm_223 + buf550 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf550, (2, 8192, 1), is_leaf=True) # rsqrt_64 + buf551 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf551, (16384, 4096), dtype=torch.bfloat16, is_leaf=True) # view_1091 + buf552 = reader.storage(None, 4202692608, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf552, (2, 8192, 128256), dtype=torch.bfloat16, is_leaf=True) # tangents_1 + +load_args._version = 0 + +def get_mesh_sizes(): + return 256, + +def get_colls_estimations_file(): + return "colls32_8.table" + +def get_pg_names(): + return "0", + diff --git a/autoparallel/tools/overlap_simulator/repro_llama3_8b_bw_256_2d_32layers.py b/autoparallel/tools/overlap_simulator/repro_llama3_8b_bw_256_2d_32layers.py new file mode 100644 index 00000000..da109fc1 --- /dev/null +++ b/autoparallel/tools/overlap_simulator/repro_llama3_8b_bw_256_2d_32layers.py @@ -0,0 +1,11446 @@ +import torch +from torch.nn import * +from torch import tensor, device + + +class Repro(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + + def forward(self, primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8, primals_9, primals_10, primals_11, primals_12, primals_13, primals_14, primals_15, primals_16, primals_17, primals_18, primals_19, primals_20, primals_21, primals_22, primals_23, primals_24, primals_25, primals_26, primals_27, primals_28, primals_29, primals_30, primals_31, primals_32, primals_33, primals_34, primals_35, primals_36, primals_37, primals_38, primals_39, primals_40, primals_41, primals_42, primals_43, primals_44, primals_45, primals_46, primals_47, primals_48, primals_49, primals_50, primals_51, primals_52, primals_53, primals_54, primals_55, primals_56, primals_57, 
primals_58, primals_59, primals_60, primals_61, primals_62, primals_63, primals_64, primals_65, primals_66, primals_67, primals_68, primals_69, primals_70, primals_71, primals_72, primals_73, primals_74, primals_75, primals_76, primals_77, primals_78, primals_79, primals_80, primals_81, primals_82, primals_83, primals_84, primals_85, primals_86, primals_87, primals_88, primals_89, primals_90, primals_91, primals_92, primals_93, primals_94, primals_95, primals_96, primals_97, primals_98, primals_99, primals_100, primals_101, primals_102, primals_103, primals_104, primals_105, primals_106, primals_107, primals_108, primals_109, primals_110, primals_111, primals_112, primals_113, primals_114, primals_115, primals_116, primals_117, primals_118, primals_119, primals_120, primals_121, primals_122, primals_123, primals_124, primals_125, primals_126, primals_127, primals_128, primals_129, primals_130, primals_131, primals_132, primals_133, primals_134, primals_135, primals_136, primals_137, primals_138, primals_139, primals_140, primals_141, primals_142, primals_143, primals_144, primals_145, primals_146, primals_147, primals_148, primals_149, primals_150, primals_151, primals_152, primals_153, primals_154, primals_155, primals_156, primals_157, primals_158, primals_159, primals_160, primals_161, primals_162, primals_163, primals_164, primals_165, primals_166, primals_167, primals_168, primals_169, primals_170, primals_171, primals_172, primals_173, primals_174, primals_175, primals_176, primals_177, primals_178, primals_179, primals_180, primals_181, primals_182, primals_183, primals_184, primals_185, primals_186, primals_187, primals_188, primals_189, primals_190, primals_191, primals_192, primals_193, primals_194, primals_195, primals_196, primals_197, primals_198, primals_199, primals_200, primals_201, primals_202, primals_203, primals_204, primals_205, primals_206, primals_207, primals_208, primals_209, primals_210, primals_211, primals_212, primals_213, primals_214, 
primals_215, primals_216, primals_217, primals_218, primals_219, primals_220, primals_221, primals_222, primals_223, primals_224, primals_225, primals_226, primals_227, primals_228, primals_229, primals_230, primals_231, primals_232, primals_233, primals_234, primals_235, primals_236, primals_237, primals_238, primals_239, primals_240, primals_241, primals_242, primals_243, primals_244, primals_245, primals_246, primals_247, primals_248, primals_249, primals_250, primals_251, primals_252, primals_253, primals_254, primals_255, primals_256, primals_257, primals_258, primals_259, primals_260, primals_261, primals_262, primals_263, primals_264, primals_265, primals_266, primals_267, primals_268, primals_269, primals_270, primals_271, primals_272, primals_273, primals_274, primals_275, primals_276, primals_277, primals_278, primals_279, primals_280, primals_281, primals_282, primals_283, primals_284, primals_285, primals_286, primals_287, primals_288, primals_289, primals_290, primals_291, primals_292, primals_293, wait_tensor_1, mm, mm_2, getitem_80, getitem_81, getitem_86, getitem_87, reduce_scatter_tensor_1, mm_4, add_3, mm_7, mm_9, getitem_121, getitem_122, getitem_127, getitem_128, reduce_scatter_tensor_3, mm_11, add_7, mm_14, mm_16, getitem_162, getitem_163, getitem_168, getitem_169, reduce_scatter_tensor_5, mm_18, add_11, mm_21, mm_23, getitem_203, getitem_204, getitem_209, getitem_210, reduce_scatter_tensor_7, mm_25, add_15, mm_28, mm_30, getitem_244, getitem_245, getitem_250, getitem_251, reduce_scatter_tensor_9, mm_32, add_19, mm_35, mm_37, getitem_285, getitem_286, getitem_291, getitem_292, reduce_scatter_tensor_11, mm_39, add_23, mm_42, mm_44, getitem_326, getitem_327, getitem_332, getitem_333, reduce_scatter_tensor_13, mm_46, add_27, mm_49, mm_51, getitem_367, getitem_368, getitem_373, getitem_374, reduce_scatter_tensor_15, mm_53, add_31, mm_56, mm_58, getitem_408, getitem_409, getitem_414, getitem_415, reduce_scatter_tensor_17, mm_60, add_35, mm_63, 
mm_65, getitem_449, getitem_450, getitem_455, getitem_456, reduce_scatter_tensor_19, mm_67, add_39, mm_70, mm_72, getitem_490, getitem_491, getitem_496, getitem_497, reduce_scatter_tensor_21, mm_74, add_43, mm_77, mm_79, getitem_531, getitem_532, getitem_537, getitem_538, reduce_scatter_tensor_23, mm_81, add_47, mm_84, mm_86, getitem_572, getitem_573, getitem_578, getitem_579, reduce_scatter_tensor_25, mm_88, add_51, mm_91, mm_93, getitem_613, getitem_614, getitem_619, getitem_620, reduce_scatter_tensor_27, mm_95, add_55, mm_98, mm_100, getitem_654, getitem_655, getitem_660, getitem_661, reduce_scatter_tensor_29, mm_102, add_59, mm_105, mm_107, getitem_695, getitem_696, getitem_701, getitem_702, reduce_scatter_tensor_31, mm_109, add_63, mm_112, mm_114, getitem_736, getitem_737, getitem_742, getitem_743, reduce_scatter_tensor_33, mm_116, add_67, mm_119, mm_121, getitem_777, getitem_778, getitem_783, getitem_784, reduce_scatter_tensor_35, mm_123, add_71, mm_126, mm_128, getitem_818, getitem_819, getitem_824, getitem_825, reduce_scatter_tensor_37, mm_130, add_75, mm_133, mm_135, getitem_859, getitem_860, getitem_865, getitem_866, reduce_scatter_tensor_39, mm_137, add_79, mm_140, mm_142, getitem_900, getitem_901, getitem_906, getitem_907, reduce_scatter_tensor_41, mm_144, add_83, mm_147, mm_149, getitem_941, getitem_942, getitem_947, getitem_948, reduce_scatter_tensor_43, mm_151, add_87, mm_154, mm_156, getitem_982, getitem_983, getitem_988, getitem_989, reduce_scatter_tensor_45, mm_158, add_91, mm_161, mm_163, getitem_1023, getitem_1024, getitem_1029, getitem_1030, reduce_scatter_tensor_47, mm_165, add_95, mm_168, mm_170, getitem_1064, getitem_1065, getitem_1070, getitem_1071, reduce_scatter_tensor_49, mm_172, add_99, mm_175, mm_177, getitem_1105, getitem_1106, getitem_1111, getitem_1112, reduce_scatter_tensor_51, mm_179, add_103, mm_182, mm_184, getitem_1146, getitem_1147, getitem_1152, getitem_1153, reduce_scatter_tensor_53, mm_186, add_107, mm_189, mm_191, 
getitem_1187, getitem_1188, getitem_1193, getitem_1194, reduce_scatter_tensor_55, mm_193, add_111, mm_196, mm_198, getitem_1228, getitem_1229, getitem_1234, getitem_1235, reduce_scatter_tensor_57, mm_200, add_115, mm_203, mm_205, getitem_1269, getitem_1270, getitem_1275, getitem_1276, reduce_scatter_tensor_59, mm_207, add_119, mm_210, mm_212, getitem_1310, getitem_1311, getitem_1316, getitem_1317, reduce_scatter_tensor_61, mm_214, add_123, mm_217, mm_219, getitem_1351, getitem_1352, getitem_1357, getitem_1358, reduce_scatter_tensor_63, mm_221, reduce_scatter_tensor_64, rsqrt_64, view_2319, tangents_1): + view_2321 = torch.ops.aten.view.default(tangents_1, [16384, 16032]); tangents_1 = None + permute_353 = torch.ops.aten.permute.default(view_2321, [1, 0]) + mm_225 = torch.ops.aten.mm.default(permute_353, view_2319); permute_353 = view_2319 = None + convert_element_type_1060 = torch.ops.prims.convert_element_type.default(primals_293, torch.bfloat16); primals_293 = None + all_gather_into_tensor_355 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1060, 32, '0'); convert_element_type_1060 = None + wait_tensor_420 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_355); all_gather_into_tensor_355 = None + permute_352 = torch.ops.aten.permute.default(wait_tensor_420, [1, 0]); wait_tensor_420 = None + permute_355 = torch.ops.aten.permute.default(permute_352, [1, 0]); permute_352 = None + mm_226 = torch.ops.aten.mm.default(view_2321, permute_355); view_2321 = permute_355 = None + view_2322 = torch.ops.aten.view.default(mm_226, [2, 8192, 4096]); mm_226 = None + convert_element_type_1067 = torch.ops.prims.convert_element_type.default(mm_225, torch.float32); mm_225 = None + reduce_scatter_tensor_65 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1067, 'avg', 32, '0'); convert_element_type_1067 = None + wait_tensor_421 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_65); 
reduce_scatter_tensor_65 = None + split_138 = torch.ops.aten.split.Tensor(view_2322, 1024, 1); view_2322 = None + getitem_1392 = split_138[0] + getitem_1393 = split_138[1] + getitem_1394 = split_138[2] + getitem_1395 = split_138[3] + getitem_1396 = split_138[4] + getitem_1397 = split_138[5] + getitem_1398 = split_138[6] + getitem_1399 = split_138[7]; split_138 = None + cat_130 = torch.ops.aten.cat.default([getitem_1392, getitem_1393, getitem_1394, getitem_1395, getitem_1396, getitem_1397, getitem_1398, getitem_1399]); getitem_1392 = getitem_1393 = getitem_1394 = getitem_1395 = getitem_1396 = getitem_1397 = getitem_1398 = getitem_1399 = None + reduce_scatter_tensor_66 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_130, 'sum', 8, '1'); cat_130 = None + wait_tensor_422 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_66); reduce_scatter_tensor_66 = None + convert_element_type_1068 = torch.ops.prims.convert_element_type.default(wait_tensor_422, torch.float32); wait_tensor_422 = None + convert_element_type_1057 = torch.ops.prims.convert_element_type.default(primals_292, torch.bfloat16); primals_292 = None + all_gather_into_tensor_353 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1057, 32, '0'); convert_element_type_1057 = None + wait_tensor_418 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_353); all_gather_into_tensor_353 = None + convert_element_type_1070 = torch.ops.prims.convert_element_type.default(wait_tensor_418, torch.float32); wait_tensor_418 = None + mul_258 = torch.ops.aten.mul.Tensor(convert_element_type_1068, convert_element_type_1070); convert_element_type_1070 = None + wait_tensor_411 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_63); reduce_scatter_tensor_63 = None + add_125 = torch.ops.aten.add.Tensor(add_123, wait_tensor_411); wait_tensor_411 = None + wait_tensor_417 = 
torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_64); reduce_scatter_tensor_64 = None + add_127 = torch.ops.aten.add.Tensor(add_125, wait_tensor_417); wait_tensor_417 = None + convert_element_type_1058 = torch.ops.prims.convert_element_type.default(add_127, torch.float32); add_127 = None + mul_256 = torch.ops.aten.mul.Tensor(convert_element_type_1058, rsqrt_64); convert_element_type_1058 = None + mul_260 = torch.ops.aten.mul.Tensor(mul_256, mul_258) + sum_1 = torch.ops.aten.sum.dim_IntList(mul_260, [2], True); mul_260 = None + div = torch.ops.aten.div.Tensor(mul_256, 4096) + mul_261 = torch.ops.aten.mul.Tensor(div, sum_1); div = sum_1 = None + sub_1 = torch.ops.aten.sub.Tensor(mul_258, mul_261); mul_258 = mul_261 = None + mul_262 = torch.ops.aten.mul.Tensor(sub_1, rsqrt_64); sub_1 = rsqrt_64 = None + mul_263 = torch.ops.aten.mul.Tensor(convert_element_type_1068, mul_256); convert_element_type_1068 = mul_256 = None + sum_2 = torch.ops.aten.sum.dim_IntList(mul_263, [0, 1]); mul_263 = None + convert_element_type_1071 = torch.ops.prims.convert_element_type.default(mul_262, torch.bfloat16); mul_262 = None + convert_element_type_1072 = torch.ops.prims.convert_element_type.default(sum_2, torch.bfloat16); sum_2 = None + all_reduce = torch.ops._c10d_functional.all_reduce.default(convert_element_type_1072, 'sum', '1'); convert_element_type_1072 = None + wait_tensor_423 = torch.ops._c10d_functional.wait_tensor.default(all_reduce); all_reduce = None + convert_element_type_1073 = torch.ops.prims.convert_element_type.default(wait_tensor_423, torch.float32); wait_tensor_423 = None + reduce_scatter_tensor_67 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1073, 'avg', 32, '0'); convert_element_type_1073 = None + wait_tensor_424 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_67); reduce_scatter_tensor_67 = None + all_gather_into_tensor_356 = 
torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1071, 8, '1') + wait_tensor_425 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_356); all_gather_into_tensor_356 = None + split_139 = torch.ops.aten.split.Tensor(wait_tensor_425, 2); wait_tensor_425 = None + getitem_1400 = split_139[0] + getitem_1401 = split_139[1] + getitem_1402 = split_139[2] + getitem_1403 = split_139[3] + getitem_1404 = split_139[4] + getitem_1405 = split_139[5] + getitem_1406 = split_139[6] + getitem_1407 = split_139[7]; split_139 = None + cat_131 = torch.ops.aten.cat.default([getitem_1400, getitem_1401, getitem_1402, getitem_1403, getitem_1404, getitem_1405, getitem_1406, getitem_1407], 1); getitem_1400 = getitem_1401 = getitem_1402 = getitem_1403 = getitem_1404 = getitem_1405 = getitem_1406 = getitem_1407 = None + view_2323 = torch.ops.aten.view.default(cat_131, [16384, 4096]); cat_131 = None + permute_357 = torch.ops.aten.permute.default(view_2323, [1, 0]) + convert_element_type_1043 = torch.ops.prims.convert_element_type.default(primals_288, torch.bfloat16); primals_288 = None + all_gather_into_tensor_348 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1043, 32, '0'); convert_element_type_1043 = None + wait_tensor_412 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_348); all_gather_into_tensor_348 = None + convert_element_type_1044 = torch.ops.prims.convert_element_type.default(add_125, torch.float32); add_125 = None + pow_64 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_1044, 2) + mean_63 = torch.ops.aten.mean.dim(pow_64, [2], True); pow_64 = None + add_126 = torch.ops.aten.add.Scalar(mean_63, 1e-05); mean_63 = None + rsqrt_63 = torch.ops.aten.rsqrt.default(add_126); add_126 = None + mul_252 = torch.ops.aten.mul.Tensor(convert_element_type_1044, rsqrt_63); convert_element_type_1044 = None + mul_253 = torch.ops.aten.mul.Tensor(mul_252, wait_tensor_412) + 
convert_element_type_1045 = torch.ops.prims.convert_element_type.default(mul_253, torch.bfloat16); mul_253 = None + all_gather_into_tensor_349 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1045, 8, '1'); convert_element_type_1045 = None + wait_tensor_413 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_349); all_gather_into_tensor_349 = None + split_135 = torch.ops.aten.split.Tensor(wait_tensor_413, 2); wait_tensor_413 = None + getitem_1368 = split_135[0] + getitem_1369 = split_135[1] + getitem_1370 = split_135[2] + getitem_1371 = split_135[3] + getitem_1372 = split_135[4] + getitem_1373 = split_135[5] + getitem_1374 = split_135[6] + getitem_1375 = split_135[7]; split_135 = None + cat_127 = torch.ops.aten.cat.default([getitem_1368, getitem_1369, getitem_1370, getitem_1371, getitem_1372, getitem_1373, getitem_1374, getitem_1375], 1); getitem_1368 = getitem_1369 = getitem_1370 = getitem_1371 = getitem_1372 = getitem_1373 = getitem_1374 = getitem_1375 = None + view_2292 = torch.ops.aten.view.default(cat_127, [16384, 4096]); cat_127 = None + view_2293 = torch.ops.aten.view.default(mm_221, [2, 8192, 1792]); mm_221 = None + convert_element_type_1049 = torch.ops.prims.convert_element_type.default(view_2293, torch.float32); view_2293 = None + sigmoid_31 = torch.ops.aten.sigmoid.default(convert_element_type_1049) + mul_254 = torch.ops.aten.mul.Tensor(convert_element_type_1049, sigmoid_31); sigmoid_31 = None + convert_element_type_1050 = torch.ops.prims.convert_element_type.default(mul_254, torch.bfloat16); mul_254 = None + convert_element_type_1051 = torch.ops.prims.convert_element_type.default(primals_290, torch.bfloat16); primals_290 = None + all_gather_into_tensor_351 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1051, 32, '0'); convert_element_type_1051 = None + wait_tensor_415 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_351); 
all_gather_into_tensor_351 = None + permute_350 = torch.ops.aten.permute.default(wait_tensor_415, [1, 0]); wait_tensor_415 = None + mm_222 = torch.ops.aten.mm.default(view_2292, permute_350) + view_2300 = torch.ops.aten.view.default(mm_222, [2, 8192, 1792]); mm_222 = None + mul_255 = torch.ops.aten.mul.Tensor(convert_element_type_1050, view_2300) + view_2307 = torch.ops.aten.view.default(mul_255, [16384, 1792]); mul_255 = None + mm_227 = torch.ops.aten.mm.default(permute_357, view_2307); permute_357 = view_2307 = None + convert_element_type_1054 = torch.ops.prims.convert_element_type.default(primals_291, torch.bfloat16); primals_291 = None + all_gather_into_tensor_352 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1054, 32, '0'); convert_element_type_1054 = None + wait_tensor_416 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_352); all_gather_into_tensor_352 = None + permute_351 = torch.ops.aten.permute.default(wait_tensor_416, [1, 0]); wait_tensor_416 = None + permute_359 = torch.ops.aten.permute.default(permute_351, [1, 0]); permute_351 = None + mm_228 = torch.ops.aten.mm.default(view_2323, permute_359); view_2323 = permute_359 = None + view_2324 = torch.ops.aten.view.default(mm_228, [2, 8192, 1792]); mm_228 = None + convert_element_type_1078 = torch.ops.prims.convert_element_type.default(mm_227, torch.float32); mm_227 = None + reduce_scatter_tensor_68 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1078, 'avg', 32, '0'); convert_element_type_1078 = None + wait_tensor_426 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_68); reduce_scatter_tensor_68 = None + mul_264 = torch.ops.aten.mul.Tensor(view_2324, convert_element_type_1050); convert_element_type_1050 = None + mul_265 = torch.ops.aten.mul.Tensor(view_2324, view_2300); view_2324 = view_2300 = None + view_2325 = torch.ops.aten.view.default(mul_264, [16384, 1792]); mul_264 = None + permute_361 = 
torch.ops.aten.permute.default(view_2325, [1, 0]) + mm_229 = torch.ops.aten.mm.default(permute_361, view_2292); permute_361 = None + permute_363 = torch.ops.aten.permute.default(permute_350, [1, 0]); permute_350 = None + mm_230 = torch.ops.aten.mm.default(view_2325, permute_363); view_2325 = permute_363 = None + view_2326 = torch.ops.aten.view.default(mm_230, [2, 8192, 4096]); mm_230 = None + convert_element_type_1083 = torch.ops.prims.convert_element_type.default(mm_229, torch.float32); mm_229 = None + reduce_scatter_tensor_69 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1083, 'avg', 32, '0'); convert_element_type_1083 = None + wait_tensor_427 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_69); reduce_scatter_tensor_69 = None + convert_element_type_1084 = torch.ops.prims.convert_element_type.default(mul_265, torch.float32); mul_265 = None + neg = torch.ops.aten.neg.default(convert_element_type_1049) + exp = torch.ops.aten.exp.default(neg); neg = None + add_129 = torch.ops.aten.add.Tensor(exp, 1); exp = None + reciprocal = torch.ops.aten.reciprocal.default(add_129); add_129 = None + mul_266 = torch.ops.aten.mul.Tensor(reciprocal, 1); reciprocal = None + mul_267 = torch.ops.aten.mul.Tensor(convert_element_type_1084, mul_266); convert_element_type_1084 = None + sub_2 = torch.ops.aten.sub.Tensor(1, mul_266); mul_266 = None + mul_268 = torch.ops.aten.mul.Tensor(convert_element_type_1049, sub_2); convert_element_type_1049 = sub_2 = None + add_130 = torch.ops.aten.add.Tensor(mul_268, 1); mul_268 = None + mul_269 = torch.ops.aten.mul.Tensor(mul_267, add_130); mul_267 = add_130 = None + convert_element_type_1086 = torch.ops.prims.convert_element_type.default(mul_269, torch.bfloat16); mul_269 = None + view_2327 = torch.ops.aten.view.default(convert_element_type_1086, [16384, 1792]); convert_element_type_1086 = None + permute_365 = torch.ops.aten.permute.default(view_2327, [1, 0]) + mm_231 = 
torch.ops.aten.mm.default(permute_365, view_2292); permute_365 = view_2292 = None + convert_element_type_1046 = torch.ops.prims.convert_element_type.default(primals_289, torch.bfloat16); primals_289 = None + all_gather_into_tensor_350 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1046, 32, '0'); convert_element_type_1046 = None + wait_tensor_414 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_350); all_gather_into_tensor_350 = None + permute_349 = torch.ops.aten.permute.default(wait_tensor_414, [1, 0]); wait_tensor_414 = None + permute_367 = torch.ops.aten.permute.default(permute_349, [1, 0]); permute_349 = None + mm_232 = torch.ops.aten.mm.default(view_2327, permute_367); view_2327 = permute_367 = None + view_2328 = torch.ops.aten.view.default(mm_232, [2, 8192, 4096]); mm_232 = None + add_131 = torch.ops.aten.add.Tensor(view_2326, view_2328); view_2326 = view_2328 = None + convert_element_type_1091 = torch.ops.prims.convert_element_type.default(mm_231, torch.float32); mm_231 = None + reduce_scatter_tensor_70 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1091, 'avg', 32, '0'); convert_element_type_1091 = None + wait_tensor_428 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_70); reduce_scatter_tensor_70 = None + split_140 = torch.ops.aten.split.Tensor(add_131, 1024, 1); add_131 = None + getitem_1408 = split_140[0] + getitem_1409 = split_140[1] + getitem_1410 = split_140[2] + getitem_1411 = split_140[3] + getitem_1412 = split_140[4] + getitem_1413 = split_140[5] + getitem_1414 = split_140[6] + getitem_1415 = split_140[7]; split_140 = None + cat_132 = torch.ops.aten.cat.default([getitem_1408, getitem_1409, getitem_1410, getitem_1411, getitem_1412, getitem_1413, getitem_1414, getitem_1415]); getitem_1408 = getitem_1409 = getitem_1410 = getitem_1411 = getitem_1412 = getitem_1413 = getitem_1414 = getitem_1415 = None + reduce_scatter_tensor_71 = 
torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_132, 'sum', 8, '1'); cat_132 = None + wait_tensor_429 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_71); reduce_scatter_tensor_71 = None + convert_element_type_1092 = torch.ops.prims.convert_element_type.default(wait_tensor_429, torch.float32); wait_tensor_429 = None + convert_element_type_1094 = torch.ops.prims.convert_element_type.default(wait_tensor_412, torch.float32); wait_tensor_412 = None + mul_270 = torch.ops.aten.mul.Tensor(convert_element_type_1092, convert_element_type_1094); convert_element_type_1094 = None + mul_272 = torch.ops.aten.mul.Tensor(mul_252, mul_270) + sum_3 = torch.ops.aten.sum.dim_IntList(mul_272, [2], True); mul_272 = None + div_1 = torch.ops.aten.div.Tensor(mul_252, 4096) + mul_273 = torch.ops.aten.mul.Tensor(div_1, sum_3); div_1 = sum_3 = None + sub_3 = torch.ops.aten.sub.Tensor(mul_270, mul_273); mul_270 = mul_273 = None + mul_274 = torch.ops.aten.mul.Tensor(sub_3, rsqrt_63); sub_3 = rsqrt_63 = None + mul_275 = torch.ops.aten.mul.Tensor(convert_element_type_1092, mul_252); convert_element_type_1092 = mul_252 = None + sum_4 = torch.ops.aten.sum.dim_IntList(mul_275, [0, 1]); mul_275 = None + convert_element_type_1095 = torch.ops.prims.convert_element_type.default(mul_274, torch.bfloat16); mul_274 = None + convert_element_type_1096 = torch.ops.prims.convert_element_type.default(sum_4, torch.bfloat16); sum_4 = None + all_reduce_1 = torch.ops._c10d_functional.all_reduce.default(convert_element_type_1096, 'sum', '1'); convert_element_type_1096 = None + wait_tensor_430 = torch.ops._c10d_functional.wait_tensor.default(all_reduce_1); all_reduce_1 = None + convert_element_type_1097 = torch.ops.prims.convert_element_type.default(wait_tensor_430, torch.float32); wait_tensor_430 = None + reduce_scatter_tensor_72 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1097, 'avg', 32, '0'); convert_element_type_1097 = None + wait_tensor_431 
= torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_72); reduce_scatter_tensor_72 = None + add_132 = torch.ops.aten.add.Tensor(convert_element_type_1071, convert_element_type_1095); convert_element_type_1071 = convert_element_type_1095 = None + all_gather_into_tensor_357 = torch.ops._c10d_functional.all_gather_into_tensor.default(add_132, 8, '1') + wait_tensor_432 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_357); all_gather_into_tensor_357 = None + split_141 = torch.ops.aten.split.Tensor(wait_tensor_432, 2); wait_tensor_432 = None + getitem_1416 = split_141[0] + getitem_1417 = split_141[1] + getitem_1418 = split_141[2] + getitem_1419 = split_141[3] + getitem_1420 = split_141[4] + getitem_1421 = split_141[5] + getitem_1422 = split_141[6] + getitem_1423 = split_141[7]; split_141 = None + cat_133 = torch.ops.aten.cat.default([getitem_1416, getitem_1417, getitem_1418, getitem_1419, getitem_1420, getitem_1421, getitem_1422, getitem_1423], 1); getitem_1416 = getitem_1417 = getitem_1418 = getitem_1419 = getitem_1420 = getitem_1421 = getitem_1422 = getitem_1423 = None + view_2329 = torch.ops.aten.view.default(cat_133, [16384, 4096]); cat_133 = None + permute_369 = torch.ops.aten.permute.default(view_2329, [1, 0]) + permute_347 = torch.ops.aten.permute.default(getitem_1351, [0, 2, 1, 3]) + view_2274 = torch.ops.aten.view.default(permute_347, [2, 8192, -1]); permute_347 = None + view_2280 = torch.ops.aten.view.default(view_2274, [16384, 512]); view_2274 = None + mm_233 = torch.ops.aten.mm.default(permute_369, view_2280); permute_369 = view_2280 = None + convert_element_type_1040 = torch.ops.prims.convert_element_type.default(primals_287, torch.bfloat16); primals_287 = None + all_gather_into_tensor_347 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1040, 32, '0'); convert_element_type_1040 = None + wait_tensor_410 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_347); 
all_gather_into_tensor_347 = None + permute_348 = torch.ops.aten.permute.default(wait_tensor_410, [1, 0]); wait_tensor_410 = None + permute_371 = torch.ops.aten.permute.default(permute_348, [1, 0]); permute_348 = None + mm_234 = torch.ops.aten.mm.default(view_2329, permute_371); view_2329 = permute_371 = None + view_2330 = torch.ops.aten.view.default(mm_234, [2, 8192, 512]); mm_234 = None + convert_element_type_1102 = torch.ops.prims.convert_element_type.default(mm_233, torch.float32); mm_233 = None + reduce_scatter_tensor_73 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1102, 'avg', 32, '0'); convert_element_type_1102 = None + wait_tensor_433 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_73); reduce_scatter_tensor_73 = None + view_2331 = torch.ops.aten.view.default(view_2330, [2, 8192, 4, 128]); view_2330 = None + permute_373 = torch.ops.aten.permute.default(view_2331, [0, 2, 1, 3]); view_2331 = None + view_37 = torch.ops.aten.view.default(primals_3, [1, 8192, 1, 64]); primals_3 = None + convert_element_type_1024 = torch.ops.prims.convert_element_type.default(primals_283, torch.bfloat16); primals_283 = None + all_gather_into_tensor_342 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1024, 32, '0'); convert_element_type_1024 = None + wait_tensor_405 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_342); all_gather_into_tensor_342 = None + convert_element_type_1025 = torch.ops.prims.convert_element_type.default(add_123, torch.float32); add_123 = None + pow_63 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_1025, 2) + mean_62 = torch.ops.aten.mean.dim(pow_63, [2], True); pow_63 = None + add_124 = torch.ops.aten.add.Scalar(mean_62, 1e-05); mean_62 = None + rsqrt_62 = torch.ops.aten.rsqrt.default(add_124); add_124 = None + mul_248 = torch.ops.aten.mul.Tensor(convert_element_type_1025, rsqrt_62); convert_element_type_1025 = None + mul_249 = 
torch.ops.aten.mul.Tensor(mul_248, wait_tensor_405) + convert_element_type_1026 = torch.ops.prims.convert_element_type.default(mul_249, torch.bfloat16); mul_249 = None + all_gather_into_tensor_343 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1026, 8, '1'); convert_element_type_1026 = None + wait_tensor_406 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_343); all_gather_into_tensor_343 = None + split_133 = torch.ops.aten.split.Tensor(wait_tensor_406, 2); wait_tensor_406 = None + getitem_1343 = split_133[0] + getitem_1344 = split_133[1] + getitem_1345 = split_133[2] + getitem_1346 = split_133[3] + getitem_1347 = split_133[4] + getitem_1348 = split_133[5] + getitem_1349 = split_133[6] + getitem_1350 = split_133[7]; split_133 = None + cat_125 = torch.ops.aten.cat.default([getitem_1343, getitem_1344, getitem_1345, getitem_1346, getitem_1347, getitem_1348, getitem_1349, getitem_1350], 1); getitem_1343 = getitem_1344 = getitem_1345 = getitem_1346 = getitem_1347 = getitem_1348 = getitem_1349 = getitem_1350 = None + view_2247 = torch.ops.aten.view.default(cat_125, [16384, 4096]); cat_125 = None + view_2248 = torch.ops.aten.view.default(mm_217, [2, 8192, 512]); mm_217 = None + convert_element_type_1030 = torch.ops.prims.convert_element_type.default(primals_285, torch.bfloat16); primals_285 = None + all_gather_into_tensor_345 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1030, 32, '0'); convert_element_type_1030 = None + wait_tensor_408 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_345); all_gather_into_tensor_345 = None + permute_342 = torch.ops.aten.permute.default(wait_tensor_408, [1, 0]); wait_tensor_408 = None + mm_218 = torch.ops.aten.mm.default(view_2247, permute_342) + view_2255 = torch.ops.aten.view.default(mm_218, [2, 8192, 128]); mm_218 = None + view_2262 = torch.ops.aten.view.default(mm_219, [2, 8192, 128]); mm_219 = None + view_2264 = 
torch.ops.aten.view.default(view_2248, [2, 8192, -1, 128]); view_2248 = None + view_2265 = torch.ops.aten.view.default(view_2255, [2, 8192, -1, 128]); view_2255 = None + view_2266 = torch.ops.aten.view.default(view_2262, [2, 8192, -1, 128]); view_2262 = None + convert_element_type_1036 = torch.ops.prims.convert_element_type.default(view_2264, torch.float32); view_2264 = None + view_2267 = torch.ops.aten.view.default(convert_element_type_1036, [2, 8192, 4, -1, 2]); convert_element_type_1036 = None + view_as_complex_62 = torch.ops.aten.view_as_complex.default(view_2267); view_2267 = None + convert_element_type_1037 = torch.ops.prims.convert_element_type.default(view_2265, torch.float32); view_2265 = None + view_2268 = torch.ops.aten.view.default(convert_element_type_1037, [2, 8192, 1, -1, 2]); convert_element_type_1037 = None + view_as_complex_63 = torch.ops.aten.view_as_complex.default(view_2268); view_2268 = None + mul_250 = torch.ops.aten.mul.Tensor(view_as_complex_62, view_37); view_as_complex_62 = None + view_as_real_62 = torch.ops.aten.view_as_real.default(mul_250); mul_250 = None + view_2270 = torch.ops.aten.view.default(view_as_real_62, [2, 8192, 4, 128]); view_as_real_62 = None + mul_251 = torch.ops.aten.mul.Tensor(view_as_complex_63, view_37); view_as_complex_63 = None + view_as_real_63 = torch.ops.aten.view_as_real.default(mul_251); mul_251 = None + view_2271 = torch.ops.aten.view.default(view_as_real_63, [2, 8192, 1, 128]); view_as_real_63 = None + convert_element_type_1038 = torch.ops.prims.convert_element_type.default(view_2270, torch.bfloat16); view_2270 = None + convert_element_type_1039 = torch.ops.prims.convert_element_type.default(view_2271, torch.bfloat16); view_2271 = None + unsqueeze_62 = torch.ops.aten.unsqueeze.default(convert_element_type_1039, 3); convert_element_type_1039 = None + expand_62 = torch.ops.aten.expand.default(unsqueeze_62, [2, 8192, 1, 4, 128]); unsqueeze_62 = None + view_2272 = torch.ops.aten.view.default(expand_62, [2, 8192, 
4, 128]); expand_62 = None + unsqueeze_63 = torch.ops.aten.unsqueeze.default(view_2266, 3); view_2266 = None + expand_63 = torch.ops.aten.expand.default(unsqueeze_63, [2, 8192, 1, 4, 128]); unsqueeze_63 = None + view_2273 = torch.ops.aten.view.default(expand_63, [2, 8192, 4, 128]); expand_63 = None + permute_344 = torch.ops.aten.permute.default(convert_element_type_1038, [0, 2, 1, 3]); convert_element_type_1038 = None + permute_345 = torch.ops.aten.permute.default(view_2272, [0, 2, 1, 3]); view_2272 = None + permute_346 = torch.ops.aten.permute.default(view_2273, [0, 2, 1, 3]); view_2273 = None + _scaled_dot_product_cudnn_attention_backward = torch.ops.aten._scaled_dot_product_cudnn_attention_backward.default(permute_373, permute_344, permute_345, permute_346, getitem_1351, getitem_1352, getitem_1357, getitem_1358, None, None, None, 8192, 8192, 0.0, True); permute_373 = permute_344 = permute_345 = permute_346 = getitem_1351 = getitem_1352 = getitem_1357 = getitem_1358 = None + getitem_1424 = _scaled_dot_product_cudnn_attention_backward[0] + getitem_1425 = _scaled_dot_product_cudnn_attention_backward[1] + getitem_1426 = _scaled_dot_product_cudnn_attention_backward[2]; _scaled_dot_product_cudnn_attention_backward = None + permute_374 = torch.ops.aten.permute.default(getitem_1426, [0, 2, 1, 3]); getitem_1426 = None + permute_375 = torch.ops.aten.permute.default(getitem_1425, [0, 2, 1, 3]); getitem_1425 = None + permute_376 = torch.ops.aten.permute.default(getitem_1424, [0, 2, 1, 3]); getitem_1424 = None + view_2332 = torch.ops.aten.view.default(permute_374, [2, 8192, 1, 4, 128]); permute_374 = None + sum_5 = torch.ops.aten.sum.dim_IntList(view_2332, [3], True); view_2332 = None + squeeze = torch.ops.aten.squeeze.dim(sum_5, 3); sum_5 = None + view_2333 = torch.ops.aten.view.default(permute_375, [2, 8192, 1, 4, 128]); permute_375 = None + sum_6 = torch.ops.aten.sum.dim_IntList(view_2333, [3], True); view_2333 = None + squeeze_1 = torch.ops.aten.squeeze.dim(sum_6, 3); 
sum_6 = None + convert_element_type_1103 = torch.ops.prims.convert_element_type.default(squeeze_1, torch.float32); squeeze_1 = None + convert_element_type_1104 = torch.ops.prims.convert_element_type.default(permute_376, torch.float32); permute_376 = None + view_2334 = torch.ops.aten.view.default(convert_element_type_1103, [2, 8192, 1, 64, 2]); convert_element_type_1103 = None + view_as_complex_64 = torch.ops.aten.view_as_complex.default(view_2334); view_2334 = None + _conj = torch.ops.aten._conj.default(view_37) + mul_276 = torch.ops.aten.mul.Tensor(view_as_complex_64, _conj); view_as_complex_64 = None + view_2335 = torch.ops.aten.view.default(convert_element_type_1104, [2, 8192, 4, 64, 2]); convert_element_type_1104 = None + view_as_complex_65 = torch.ops.aten.view_as_complex.default(view_2335); view_2335 = None + mul_277 = torch.ops.aten.mul.Tensor(view_as_complex_65, _conj); view_as_complex_65 = None + view_as_real_64 = torch.ops.aten.view_as_real.default(mul_276); mul_276 = None + view_2336 = torch.ops.aten.view.default(view_as_real_64, [2, 8192, 1, 128]); view_as_real_64 = None + convert_element_type_1105 = torch.ops.prims.convert_element_type.default(view_2336, torch.bfloat16); view_2336 = None + view_as_real_65 = torch.ops.aten.view_as_real.default(mul_277); mul_277 = None + view_2337 = torch.ops.aten.view.default(view_as_real_65, [2, 8192, 4, 128]); view_as_real_65 = None + convert_element_type_1106 = torch.ops.prims.convert_element_type.default(view_2337, torch.bfloat16); view_2337 = None + view_2338 = torch.ops.aten.view.default(squeeze, [2, 8192, 128]); squeeze = None + view_2339 = torch.ops.aten.view.default(convert_element_type_1105, [2, 8192, 128]); convert_element_type_1105 = None + view_2340 = torch.ops.aten.view.default(convert_element_type_1106, [2, 8192, 512]); convert_element_type_1106 = None + view_2341 = torch.ops.aten.view.default(view_2338, [16384, 128]); view_2338 = None + permute_377 = torch.ops.aten.permute.default(view_2341, [1, 0]) + 
mm_235 = torch.ops.aten.mm.default(permute_377, view_2247); permute_377 = None + convert_element_type_1033 = torch.ops.prims.convert_element_type.default(primals_286, torch.bfloat16); primals_286 = None + all_gather_into_tensor_346 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1033, 32, '0'); convert_element_type_1033 = None + wait_tensor_409 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_346); all_gather_into_tensor_346 = None + permute_343 = torch.ops.aten.permute.default(wait_tensor_409, [1, 0]); wait_tensor_409 = None + permute_379 = torch.ops.aten.permute.default(permute_343, [1, 0]); permute_343 = None + mm_236 = torch.ops.aten.mm.default(view_2341, permute_379); view_2341 = permute_379 = None + view_2342 = torch.ops.aten.view.default(mm_236, [2, 8192, 4096]); mm_236 = None + convert_element_type_1111 = torch.ops.prims.convert_element_type.default(mm_235, torch.float32); mm_235 = None + reduce_scatter_tensor_74 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1111, 'avg', 32, '0'); convert_element_type_1111 = None + wait_tensor_434 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_74); reduce_scatter_tensor_74 = None + view_2343 = torch.ops.aten.view.default(view_2339, [16384, 128]); view_2339 = None + permute_381 = torch.ops.aten.permute.default(view_2343, [1, 0]) + mm_237 = torch.ops.aten.mm.default(permute_381, view_2247); permute_381 = None + permute_383 = torch.ops.aten.permute.default(permute_342, [1, 0]); permute_342 = None + mm_238 = torch.ops.aten.mm.default(view_2343, permute_383); view_2343 = permute_383 = None + view_2344 = torch.ops.aten.view.default(mm_238, [2, 8192, 4096]); mm_238 = None + add_133 = torch.ops.aten.add.Tensor(view_2342, view_2344); view_2342 = view_2344 = None + convert_element_type_1116 = torch.ops.prims.convert_element_type.default(mm_237, torch.float32); mm_237 = None + reduce_scatter_tensor_75 = 
torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1116, 'avg', 32, '0'); convert_element_type_1116 = None + wait_tensor_435 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_75); reduce_scatter_tensor_75 = None + view_2345 = torch.ops.aten.view.default(view_2340, [16384, 512]); view_2340 = None + permute_385 = torch.ops.aten.permute.default(view_2345, [1, 0]) + mm_239 = torch.ops.aten.mm.default(permute_385, view_2247); permute_385 = view_2247 = None + convert_element_type_1027 = torch.ops.prims.convert_element_type.default(primals_284, torch.bfloat16); primals_284 = None + all_gather_into_tensor_344 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1027, 32, '0'); convert_element_type_1027 = None + wait_tensor_407 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_344); all_gather_into_tensor_344 = None + permute_341 = torch.ops.aten.permute.default(wait_tensor_407, [1, 0]); wait_tensor_407 = None + permute_387 = torch.ops.aten.permute.default(permute_341, [1, 0]); permute_341 = None + mm_240 = torch.ops.aten.mm.default(view_2345, permute_387); view_2345 = permute_387 = None + view_2346 = torch.ops.aten.view.default(mm_240, [2, 8192, 4096]); mm_240 = None + add_134 = torch.ops.aten.add.Tensor(add_133, view_2346); add_133 = view_2346 = None + convert_element_type_1121 = torch.ops.prims.convert_element_type.default(mm_239, torch.float32); mm_239 = None + reduce_scatter_tensor_76 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1121, 'avg', 32, '0'); convert_element_type_1121 = None + wait_tensor_436 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_76); reduce_scatter_tensor_76 = None + split_142 = torch.ops.aten.split.Tensor(add_134, 1024, 1); add_134 = None + getitem_1427 = split_142[0] + getitem_1428 = split_142[1] + getitem_1429 = split_142[2] + getitem_1430 = split_142[3] + getitem_1431 = split_142[4] + 
getitem_1432 = split_142[5] + getitem_1433 = split_142[6] + getitem_1434 = split_142[7]; split_142 = None + cat_134 = torch.ops.aten.cat.default([getitem_1427, getitem_1428, getitem_1429, getitem_1430, getitem_1431, getitem_1432, getitem_1433, getitem_1434]); getitem_1427 = getitem_1428 = getitem_1429 = getitem_1430 = getitem_1431 = getitem_1432 = getitem_1433 = getitem_1434 = None + reduce_scatter_tensor_77 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_134, 'sum', 8, '1'); cat_134 = None + wait_tensor_437 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_77); reduce_scatter_tensor_77 = None + convert_element_type_1122 = torch.ops.prims.convert_element_type.default(wait_tensor_437, torch.float32); wait_tensor_437 = None + convert_element_type_1124 = torch.ops.prims.convert_element_type.default(wait_tensor_405, torch.float32); wait_tensor_405 = None + mul_278 = torch.ops.aten.mul.Tensor(convert_element_type_1122, convert_element_type_1124); convert_element_type_1124 = None + mul_280 = torch.ops.aten.mul.Tensor(mul_248, mul_278) + sum_7 = torch.ops.aten.sum.dim_IntList(mul_280, [2], True); mul_280 = None + div_2 = torch.ops.aten.div.Tensor(mul_248, 4096) + mul_281 = torch.ops.aten.mul.Tensor(div_2, sum_7); div_2 = sum_7 = None + sub_4 = torch.ops.aten.sub.Tensor(mul_278, mul_281); mul_278 = mul_281 = None + mul_282 = torch.ops.aten.mul.Tensor(sub_4, rsqrt_62); sub_4 = rsqrt_62 = None + mul_283 = torch.ops.aten.mul.Tensor(convert_element_type_1122, mul_248); convert_element_type_1122 = mul_248 = None + sum_8 = torch.ops.aten.sum.dim_IntList(mul_283, [0, 1]); mul_283 = None + convert_element_type_1125 = torch.ops.prims.convert_element_type.default(mul_282, torch.bfloat16); mul_282 = None + convert_element_type_1126 = torch.ops.prims.convert_element_type.default(sum_8, torch.bfloat16); sum_8 = None + all_reduce_2 = torch.ops._c10d_functional.all_reduce.default(convert_element_type_1126, 'sum', '1'); convert_element_type_1126 = None 
+ wait_tensor_438 = torch.ops._c10d_functional.wait_tensor.default(all_reduce_2); all_reduce_2 = None + convert_element_type_1127 = torch.ops.prims.convert_element_type.default(wait_tensor_438, torch.float32); wait_tensor_438 = None + reduce_scatter_tensor_78 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1127, 'avg', 32, '0'); convert_element_type_1127 = None + wait_tensor_439 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_78); reduce_scatter_tensor_78 = None + add_135 = torch.ops.aten.add.Tensor(add_132, convert_element_type_1125); add_132 = convert_element_type_1125 = None + all_gather_into_tensor_358 = torch.ops._c10d_functional.all_gather_into_tensor.default(add_135, 8, '1') + wait_tensor_440 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_358); all_gather_into_tensor_358 = None + split_143 = torch.ops.aten.split.Tensor(wait_tensor_440, 2); wait_tensor_440 = None + getitem_1435 = split_143[0] + getitem_1436 = split_143[1] + getitem_1437 = split_143[2] + getitem_1438 = split_143[3] + getitem_1439 = split_143[4] + getitem_1440 = split_143[5] + getitem_1441 = split_143[6] + getitem_1442 = split_143[7]; split_143 = None + cat_135 = torch.ops.aten.cat.default([getitem_1435, getitem_1436, getitem_1437, getitem_1438, getitem_1439, getitem_1440, getitem_1441, getitem_1442], 1); getitem_1435 = getitem_1436 = getitem_1437 = getitem_1438 = getitem_1439 = getitem_1440 = getitem_1441 = getitem_1442 = None + view_2347 = torch.ops.aten.view.default(cat_135, [16384, 4096]); cat_135 = None + permute_389 = torch.ops.aten.permute.default(view_2347, [1, 0]) + wait_tensor_398 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_61); reduce_scatter_tensor_61 = None + add_121 = torch.ops.aten.add.Tensor(add_119, wait_tensor_398); wait_tensor_398 = None + convert_element_type_1010 = torch.ops.prims.convert_element_type.default(primals_279, torch.bfloat16); primals_279 = None + 
all_gather_into_tensor_337 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1010, 32, '0'); convert_element_type_1010 = None + wait_tensor_399 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_337); all_gather_into_tensor_337 = None + convert_element_type_1011 = torch.ops.prims.convert_element_type.default(add_121, torch.float32); add_121 = None + pow_62 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_1011, 2) + mean_61 = torch.ops.aten.mean.dim(pow_62, [2], True); pow_62 = None + add_122 = torch.ops.aten.add.Scalar(mean_61, 1e-05); mean_61 = None + rsqrt_61 = torch.ops.aten.rsqrt.default(add_122); add_122 = None + mul_244 = torch.ops.aten.mul.Tensor(convert_element_type_1011, rsqrt_61); convert_element_type_1011 = None + mul_245 = torch.ops.aten.mul.Tensor(mul_244, wait_tensor_399) + convert_element_type_1012 = torch.ops.prims.convert_element_type.default(mul_245, torch.bfloat16); mul_245 = None + all_gather_into_tensor_338 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1012, 8, '1'); convert_element_type_1012 = None + wait_tensor_400 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_338); all_gather_into_tensor_338 = None + split_131 = torch.ops.aten.split.Tensor(wait_tensor_400, 2); wait_tensor_400 = None + getitem_1327 = split_131[0] + getitem_1328 = split_131[1] + getitem_1329 = split_131[2] + getitem_1330 = split_131[3] + getitem_1331 = split_131[4] + getitem_1332 = split_131[5] + getitem_1333 = split_131[6] + getitem_1334 = split_131[7]; split_131 = None + cat_123 = torch.ops.aten.cat.default([getitem_1327, getitem_1328, getitem_1329, getitem_1330, getitem_1331, getitem_1332, getitem_1333, getitem_1334], 1); getitem_1327 = getitem_1328 = getitem_1329 = getitem_1330 = getitem_1331 = getitem_1332 = getitem_1333 = getitem_1334 = None + view_2220 = torch.ops.aten.view.default(cat_123, [16384, 4096]); cat_123 = None + view_2221 = 
torch.ops.aten.view.default(mm_214, [2, 8192, 1792]); mm_214 = None + convert_element_type_1016 = torch.ops.prims.convert_element_type.default(view_2221, torch.float32); view_2221 = None + sigmoid_30 = torch.ops.aten.sigmoid.default(convert_element_type_1016) + mul_246 = torch.ops.aten.mul.Tensor(convert_element_type_1016, sigmoid_30); sigmoid_30 = None + convert_element_type_1017 = torch.ops.prims.convert_element_type.default(mul_246, torch.bfloat16); mul_246 = None + convert_element_type_1018 = torch.ops.prims.convert_element_type.default(primals_281, torch.bfloat16); primals_281 = None + all_gather_into_tensor_340 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1018, 32, '0'); convert_element_type_1018 = None + wait_tensor_402 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_340); all_gather_into_tensor_340 = None + permute_339 = torch.ops.aten.permute.default(wait_tensor_402, [1, 0]); wait_tensor_402 = None + mm_215 = torch.ops.aten.mm.default(view_2220, permute_339) + view_2228 = torch.ops.aten.view.default(mm_215, [2, 8192, 1792]); mm_215 = None + mul_247 = torch.ops.aten.mul.Tensor(convert_element_type_1017, view_2228) + view_2235 = torch.ops.aten.view.default(mul_247, [16384, 1792]); mul_247 = None + mm_241 = torch.ops.aten.mm.default(permute_389, view_2235); permute_389 = view_2235 = None + convert_element_type_1021 = torch.ops.prims.convert_element_type.default(primals_282, torch.bfloat16); primals_282 = None + all_gather_into_tensor_341 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1021, 32, '0'); convert_element_type_1021 = None + wait_tensor_403 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_341); all_gather_into_tensor_341 = None + permute_340 = torch.ops.aten.permute.default(wait_tensor_403, [1, 0]); wait_tensor_403 = None + permute_391 = torch.ops.aten.permute.default(permute_340, [1, 0]); permute_340 = None + mm_242 = 
torch.ops.aten.mm.default(view_2347, permute_391); view_2347 = permute_391 = None + view_2348 = torch.ops.aten.view.default(mm_242, [2, 8192, 1792]); mm_242 = None + convert_element_type_1132 = torch.ops.prims.convert_element_type.default(mm_241, torch.float32); mm_241 = None + reduce_scatter_tensor_79 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1132, 'avg', 32, '0'); convert_element_type_1132 = None + wait_tensor_441 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_79); reduce_scatter_tensor_79 = None + mul_284 = torch.ops.aten.mul.Tensor(view_2348, convert_element_type_1017); convert_element_type_1017 = None + mul_285 = torch.ops.aten.mul.Tensor(view_2348, view_2228); view_2348 = view_2228 = None + view_2349 = torch.ops.aten.view.default(mul_284, [16384, 1792]); mul_284 = None + permute_393 = torch.ops.aten.permute.default(view_2349, [1, 0]) + mm_243 = torch.ops.aten.mm.default(permute_393, view_2220); permute_393 = None + permute_395 = torch.ops.aten.permute.default(permute_339, [1, 0]); permute_339 = None + mm_244 = torch.ops.aten.mm.default(view_2349, permute_395); view_2349 = permute_395 = None + view_2350 = torch.ops.aten.view.default(mm_244, [2, 8192, 4096]); mm_244 = None + convert_element_type_1137 = torch.ops.prims.convert_element_type.default(mm_243, torch.float32); mm_243 = None + reduce_scatter_tensor_80 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1137, 'avg', 32, '0'); convert_element_type_1137 = None + wait_tensor_442 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_80); reduce_scatter_tensor_80 = None + convert_element_type_1138 = torch.ops.prims.convert_element_type.default(mul_285, torch.float32); mul_285 = None + neg_1 = torch.ops.aten.neg.default(convert_element_type_1016) + exp_1 = torch.ops.aten.exp.default(neg_1); neg_1 = None + add_136 = torch.ops.aten.add.Tensor(exp_1, 1); exp_1 = None + reciprocal_1 = 
torch.ops.aten.reciprocal.default(add_136); add_136 = None + mul_286 = torch.ops.aten.mul.Tensor(reciprocal_1, 1); reciprocal_1 = None + mul_287 = torch.ops.aten.mul.Tensor(convert_element_type_1138, mul_286); convert_element_type_1138 = None + sub_5 = torch.ops.aten.sub.Tensor(1, mul_286); mul_286 = None + mul_288 = torch.ops.aten.mul.Tensor(convert_element_type_1016, sub_5); convert_element_type_1016 = sub_5 = None + add_137 = torch.ops.aten.add.Tensor(mul_288, 1); mul_288 = None + mul_289 = torch.ops.aten.mul.Tensor(mul_287, add_137); mul_287 = add_137 = None + convert_element_type_1140 = torch.ops.prims.convert_element_type.default(mul_289, torch.bfloat16); mul_289 = None + view_2351 = torch.ops.aten.view.default(convert_element_type_1140, [16384, 1792]); convert_element_type_1140 = None + permute_397 = torch.ops.aten.permute.default(view_2351, [1, 0]) + mm_245 = torch.ops.aten.mm.default(permute_397, view_2220); permute_397 = view_2220 = None + convert_element_type_1013 = torch.ops.prims.convert_element_type.default(primals_280, torch.bfloat16); primals_280 = None + all_gather_into_tensor_339 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1013, 32, '0'); convert_element_type_1013 = None + wait_tensor_401 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_339); all_gather_into_tensor_339 = None + permute_338 = torch.ops.aten.permute.default(wait_tensor_401, [1, 0]); wait_tensor_401 = None + permute_399 = torch.ops.aten.permute.default(permute_338, [1, 0]); permute_338 = None + mm_246 = torch.ops.aten.mm.default(view_2351, permute_399); view_2351 = permute_399 = None + view_2352 = torch.ops.aten.view.default(mm_246, [2, 8192, 4096]); mm_246 = None + add_138 = torch.ops.aten.add.Tensor(view_2350, view_2352); view_2350 = view_2352 = None + convert_element_type_1145 = torch.ops.prims.convert_element_type.default(mm_245, torch.float32); mm_245 = None + reduce_scatter_tensor_81 = 
torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1145, 'avg', 32, '0'); convert_element_type_1145 = None + wait_tensor_443 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_81); reduce_scatter_tensor_81 = None + split_144 = torch.ops.aten.split.Tensor(add_138, 1024, 1); add_138 = None + getitem_1443 = split_144[0] + getitem_1444 = split_144[1] + getitem_1445 = split_144[2] + getitem_1446 = split_144[3] + getitem_1447 = split_144[4] + getitem_1448 = split_144[5] + getitem_1449 = split_144[6] + getitem_1450 = split_144[7]; split_144 = None + cat_136 = torch.ops.aten.cat.default([getitem_1443, getitem_1444, getitem_1445, getitem_1446, getitem_1447, getitem_1448, getitem_1449, getitem_1450]); getitem_1443 = getitem_1444 = getitem_1445 = getitem_1446 = getitem_1447 = getitem_1448 = getitem_1449 = getitem_1450 = None + reduce_scatter_tensor_82 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_136, 'sum', 8, '1'); cat_136 = None + wait_tensor_444 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_82); reduce_scatter_tensor_82 = None + convert_element_type_1146 = torch.ops.prims.convert_element_type.default(wait_tensor_444, torch.float32); wait_tensor_444 = None + convert_element_type_1148 = torch.ops.prims.convert_element_type.default(wait_tensor_399, torch.float32); wait_tensor_399 = None + mul_290 = torch.ops.aten.mul.Tensor(convert_element_type_1146, convert_element_type_1148); convert_element_type_1148 = None + mul_292 = torch.ops.aten.mul.Tensor(mul_244, mul_290) + sum_9 = torch.ops.aten.sum.dim_IntList(mul_292, [2], True); mul_292 = None + div_3 = torch.ops.aten.div.Tensor(mul_244, 4096) + mul_293 = torch.ops.aten.mul.Tensor(div_3, sum_9); div_3 = sum_9 = None + sub_6 = torch.ops.aten.sub.Tensor(mul_290, mul_293); mul_290 = mul_293 = None + mul_294 = torch.ops.aten.mul.Tensor(sub_6, rsqrt_61); sub_6 = rsqrt_61 = None + mul_295 = torch.ops.aten.mul.Tensor(convert_element_type_1146, 
mul_244); convert_element_type_1146 = mul_244 = None + sum_10 = torch.ops.aten.sum.dim_IntList(mul_295, [0, 1]); mul_295 = None + convert_element_type_1149 = torch.ops.prims.convert_element_type.default(mul_294, torch.bfloat16); mul_294 = None + convert_element_type_1150 = torch.ops.prims.convert_element_type.default(sum_10, torch.bfloat16); sum_10 = None + all_reduce_3 = torch.ops._c10d_functional.all_reduce.default(convert_element_type_1150, 'sum', '1'); convert_element_type_1150 = None + wait_tensor_445 = torch.ops._c10d_functional.wait_tensor.default(all_reduce_3); all_reduce_3 = None + convert_element_type_1151 = torch.ops.prims.convert_element_type.default(wait_tensor_445, torch.float32); wait_tensor_445 = None + reduce_scatter_tensor_83 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1151, 'avg', 32, '0'); convert_element_type_1151 = None + wait_tensor_446 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_83); reduce_scatter_tensor_83 = None + add_139 = torch.ops.aten.add.Tensor(add_135, convert_element_type_1149); add_135 = convert_element_type_1149 = None + all_gather_into_tensor_359 = torch.ops._c10d_functional.all_gather_into_tensor.default(add_139, 8, '1') + wait_tensor_447 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_359); all_gather_into_tensor_359 = None + split_145 = torch.ops.aten.split.Tensor(wait_tensor_447, 2); wait_tensor_447 = None + getitem_1451 = split_145[0] + getitem_1452 = split_145[1] + getitem_1453 = split_145[2] + getitem_1454 = split_145[3] + getitem_1455 = split_145[4] + getitem_1456 = split_145[5] + getitem_1457 = split_145[6] + getitem_1458 = split_145[7]; split_145 = None + cat_137 = torch.ops.aten.cat.default([getitem_1451, getitem_1452, getitem_1453, getitem_1454, getitem_1455, getitem_1456, getitem_1457, getitem_1458], 1); getitem_1451 = getitem_1452 = getitem_1453 = getitem_1454 = getitem_1455 = getitem_1456 = getitem_1457 = getitem_1458 = None + 
view_2353 = torch.ops.aten.view.default(cat_137, [16384, 4096]); cat_137 = None + permute_401 = torch.ops.aten.permute.default(view_2353, [1, 0]) + permute_336 = torch.ops.aten.permute.default(getitem_1310, [0, 2, 1, 3]) + view_2202 = torch.ops.aten.view.default(permute_336, [2, 8192, -1]); permute_336 = None + view_2208 = torch.ops.aten.view.default(view_2202, [16384, 512]); view_2202 = None + mm_247 = torch.ops.aten.mm.default(permute_401, view_2208); permute_401 = view_2208 = None + convert_element_type_1007 = torch.ops.prims.convert_element_type.default(primals_278, torch.bfloat16); primals_278 = None + all_gather_into_tensor_336 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1007, 32, '0'); convert_element_type_1007 = None + wait_tensor_397 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_336); all_gather_into_tensor_336 = None + permute_337 = torch.ops.aten.permute.default(wait_tensor_397, [1, 0]); wait_tensor_397 = None + permute_403 = torch.ops.aten.permute.default(permute_337, [1, 0]); permute_337 = None + mm_248 = torch.ops.aten.mm.default(view_2353, permute_403); view_2353 = permute_403 = None + view_2354 = torch.ops.aten.view.default(mm_248, [2, 8192, 512]); mm_248 = None + convert_element_type_1156 = torch.ops.prims.convert_element_type.default(mm_247, torch.float32); mm_247 = None + reduce_scatter_tensor_84 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1156, 'avg', 32, '0'); convert_element_type_1156 = None + wait_tensor_448 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_84); reduce_scatter_tensor_84 = None + view_2355 = torch.ops.aten.view.default(view_2354, [2, 8192, 4, 128]); view_2354 = None + permute_405 = torch.ops.aten.permute.default(view_2355, [0, 2, 1, 3]); view_2355 = None + convert_element_type_991 = torch.ops.prims.convert_element_type.default(primals_274, torch.bfloat16); primals_274 = None + all_gather_into_tensor_331 = 
torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_991, 32, '0'); convert_element_type_991 = None + wait_tensor_392 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_331); all_gather_into_tensor_331 = None + convert_element_type_992 = torch.ops.prims.convert_element_type.default(add_119, torch.float32); add_119 = None + pow_61 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_992, 2) + mean_60 = torch.ops.aten.mean.dim(pow_61, [2], True); pow_61 = None + add_120 = torch.ops.aten.add.Scalar(mean_60, 1e-05); mean_60 = None + rsqrt_60 = torch.ops.aten.rsqrt.default(add_120); add_120 = None + mul_240 = torch.ops.aten.mul.Tensor(convert_element_type_992, rsqrt_60); convert_element_type_992 = None + mul_241 = torch.ops.aten.mul.Tensor(mul_240, wait_tensor_392) + convert_element_type_993 = torch.ops.prims.convert_element_type.default(mul_241, torch.bfloat16); mul_241 = None + all_gather_into_tensor_332 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_993, 8, '1'); convert_element_type_993 = None + wait_tensor_393 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_332); all_gather_into_tensor_332 = None + split_129 = torch.ops.aten.split.Tensor(wait_tensor_393, 2); wait_tensor_393 = None + getitem_1302 = split_129[0] + getitem_1303 = split_129[1] + getitem_1304 = split_129[2] + getitem_1305 = split_129[3] + getitem_1306 = split_129[4] + getitem_1307 = split_129[5] + getitem_1308 = split_129[6] + getitem_1309 = split_129[7]; split_129 = None + cat_121 = torch.ops.aten.cat.default([getitem_1302, getitem_1303, getitem_1304, getitem_1305, getitem_1306, getitem_1307, getitem_1308, getitem_1309], 1); getitem_1302 = getitem_1303 = getitem_1304 = getitem_1305 = getitem_1306 = getitem_1307 = getitem_1308 = getitem_1309 = None + view_2175 = torch.ops.aten.view.default(cat_121, [16384, 4096]); cat_121 = None + view_2176 = torch.ops.aten.view.default(mm_210, [2, 8192, 512]); 
mm_210 = None + convert_element_type_997 = torch.ops.prims.convert_element_type.default(primals_276, torch.bfloat16); primals_276 = None + all_gather_into_tensor_334 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_997, 32, '0'); convert_element_type_997 = None + wait_tensor_395 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_334); all_gather_into_tensor_334 = None + permute_331 = torch.ops.aten.permute.default(wait_tensor_395, [1, 0]); wait_tensor_395 = None + mm_211 = torch.ops.aten.mm.default(view_2175, permute_331) + view_2183 = torch.ops.aten.view.default(mm_211, [2, 8192, 128]); mm_211 = None + view_2190 = torch.ops.aten.view.default(mm_212, [2, 8192, 128]); mm_212 = None + view_2192 = torch.ops.aten.view.default(view_2176, [2, 8192, -1, 128]); view_2176 = None + view_2193 = torch.ops.aten.view.default(view_2183, [2, 8192, -1, 128]); view_2183 = None + view_2194 = torch.ops.aten.view.default(view_2190, [2, 8192, -1, 128]); view_2190 = None + convert_element_type_1003 = torch.ops.prims.convert_element_type.default(view_2192, torch.float32); view_2192 = None + view_2195 = torch.ops.aten.view.default(convert_element_type_1003, [2, 8192, 4, -1, 2]); convert_element_type_1003 = None + view_as_complex_60 = torch.ops.aten.view_as_complex.default(view_2195); view_2195 = None + convert_element_type_1004 = torch.ops.prims.convert_element_type.default(view_2193, torch.float32); view_2193 = None + view_2196 = torch.ops.aten.view.default(convert_element_type_1004, [2, 8192, 1, -1, 2]); convert_element_type_1004 = None + view_as_complex_61 = torch.ops.aten.view_as_complex.default(view_2196); view_2196 = None + mul_242 = torch.ops.aten.mul.Tensor(view_as_complex_60, view_37); view_as_complex_60 = None + view_as_real_60 = torch.ops.aten.view_as_real.default(mul_242); mul_242 = None + view_2198 = torch.ops.aten.view.default(view_as_real_60, [2, 8192, 4, 128]); view_as_real_60 = None + mul_243 = 
torch.ops.aten.mul.Tensor(view_as_complex_61, view_37); view_as_complex_61 = None + view_as_real_61 = torch.ops.aten.view_as_real.default(mul_243); mul_243 = None + view_2199 = torch.ops.aten.view.default(view_as_real_61, [2, 8192, 1, 128]); view_as_real_61 = None + convert_element_type_1005 = torch.ops.prims.convert_element_type.default(view_2198, torch.bfloat16); view_2198 = None + convert_element_type_1006 = torch.ops.prims.convert_element_type.default(view_2199, torch.bfloat16); view_2199 = None + unsqueeze_60 = torch.ops.aten.unsqueeze.default(convert_element_type_1006, 3); convert_element_type_1006 = None + expand_60 = torch.ops.aten.expand.default(unsqueeze_60, [2, 8192, 1, 4, 128]); unsqueeze_60 = None + view_2200 = torch.ops.aten.view.default(expand_60, [2, 8192, 4, 128]); expand_60 = None + unsqueeze_61 = torch.ops.aten.unsqueeze.default(view_2194, 3); view_2194 = None + expand_61 = torch.ops.aten.expand.default(unsqueeze_61, [2, 8192, 1, 4, 128]); unsqueeze_61 = None + view_2201 = torch.ops.aten.view.default(expand_61, [2, 8192, 4, 128]); expand_61 = None + permute_333 = torch.ops.aten.permute.default(convert_element_type_1005, [0, 2, 1, 3]); convert_element_type_1005 = None + permute_334 = torch.ops.aten.permute.default(view_2200, [0, 2, 1, 3]); view_2200 = None + permute_335 = torch.ops.aten.permute.default(view_2201, [0, 2, 1, 3]); view_2201 = None + _scaled_dot_product_cudnn_attention_backward_1 = torch.ops.aten._scaled_dot_product_cudnn_attention_backward.default(permute_405, permute_333, permute_334, permute_335, getitem_1310, getitem_1311, getitem_1316, getitem_1317, None, None, None, 8192, 8192, 0.0, True); permute_405 = permute_333 = permute_334 = permute_335 = getitem_1310 = getitem_1311 = getitem_1316 = getitem_1317 = None + getitem_1459 = _scaled_dot_product_cudnn_attention_backward_1[0] + getitem_1460 = _scaled_dot_product_cudnn_attention_backward_1[1] + getitem_1461 = _scaled_dot_product_cudnn_attention_backward_1[2]; 
_scaled_dot_product_cudnn_attention_backward_1 = None + permute_406 = torch.ops.aten.permute.default(getitem_1461, [0, 2, 1, 3]); getitem_1461 = None + permute_407 = torch.ops.aten.permute.default(getitem_1460, [0, 2, 1, 3]); getitem_1460 = None + permute_408 = torch.ops.aten.permute.default(getitem_1459, [0, 2, 1, 3]); getitem_1459 = None + view_2356 = torch.ops.aten.view.default(permute_406, [2, 8192, 1, 4, 128]); permute_406 = None + sum_11 = torch.ops.aten.sum.dim_IntList(view_2356, [3], True); view_2356 = None + squeeze_2 = torch.ops.aten.squeeze.dim(sum_11, 3); sum_11 = None + view_2357 = torch.ops.aten.view.default(permute_407, [2, 8192, 1, 4, 128]); permute_407 = None + sum_12 = torch.ops.aten.sum.dim_IntList(view_2357, [3], True); view_2357 = None + squeeze_3 = torch.ops.aten.squeeze.dim(sum_12, 3); sum_12 = None + convert_element_type_1157 = torch.ops.prims.convert_element_type.default(squeeze_3, torch.float32); squeeze_3 = None + convert_element_type_1158 = torch.ops.prims.convert_element_type.default(permute_408, torch.float32); permute_408 = None + view_2358 = torch.ops.aten.view.default(convert_element_type_1157, [2, 8192, 1, 64, 2]); convert_element_type_1157 = None + view_as_complex_66 = torch.ops.aten.view_as_complex.default(view_2358); view_2358 = None + mul_296 = torch.ops.aten.mul.Tensor(view_as_complex_66, _conj); view_as_complex_66 = None + view_2359 = torch.ops.aten.view.default(convert_element_type_1158, [2, 8192, 4, 64, 2]); convert_element_type_1158 = None + view_as_complex_67 = torch.ops.aten.view_as_complex.default(view_2359); view_2359 = None + mul_297 = torch.ops.aten.mul.Tensor(view_as_complex_67, _conj); view_as_complex_67 = None + view_as_real_66 = torch.ops.aten.view_as_real.default(mul_296); mul_296 = None + view_2360 = torch.ops.aten.view.default(view_as_real_66, [2, 8192, 1, 128]); view_as_real_66 = None + convert_element_type_1159 = torch.ops.prims.convert_element_type.default(view_2360, torch.bfloat16); view_2360 = None + 
view_as_real_67 = torch.ops.aten.view_as_real.default(mul_297); mul_297 = None + view_2361 = torch.ops.aten.view.default(view_as_real_67, [2, 8192, 4, 128]); view_as_real_67 = None + convert_element_type_1160 = torch.ops.prims.convert_element_type.default(view_2361, torch.bfloat16); view_2361 = None + view_2362 = torch.ops.aten.view.default(squeeze_2, [2, 8192, 128]); squeeze_2 = None + view_2363 = torch.ops.aten.view.default(convert_element_type_1159, [2, 8192, 128]); convert_element_type_1159 = None + view_2364 = torch.ops.aten.view.default(convert_element_type_1160, [2, 8192, 512]); convert_element_type_1160 = None + view_2365 = torch.ops.aten.view.default(view_2362, [16384, 128]); view_2362 = None + permute_409 = torch.ops.aten.permute.default(view_2365, [1, 0]) + mm_249 = torch.ops.aten.mm.default(permute_409, view_2175); permute_409 = None + convert_element_type_1000 = torch.ops.prims.convert_element_type.default(primals_277, torch.bfloat16); primals_277 = None + all_gather_into_tensor_335 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1000, 32, '0'); convert_element_type_1000 = None + wait_tensor_396 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_335); all_gather_into_tensor_335 = None + permute_332 = torch.ops.aten.permute.default(wait_tensor_396, [1, 0]); wait_tensor_396 = None + permute_411 = torch.ops.aten.permute.default(permute_332, [1, 0]); permute_332 = None + mm_250 = torch.ops.aten.mm.default(view_2365, permute_411); view_2365 = permute_411 = None + view_2366 = torch.ops.aten.view.default(mm_250, [2, 8192, 4096]); mm_250 = None + convert_element_type_1165 = torch.ops.prims.convert_element_type.default(mm_249, torch.float32); mm_249 = None + reduce_scatter_tensor_85 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1165, 'avg', 32, '0'); convert_element_type_1165 = None + wait_tensor_449 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_85); 
reduce_scatter_tensor_85 = None + view_2367 = torch.ops.aten.view.default(view_2363, [16384, 128]); view_2363 = None + permute_413 = torch.ops.aten.permute.default(view_2367, [1, 0]) + mm_251 = torch.ops.aten.mm.default(permute_413, view_2175); permute_413 = None + permute_415 = torch.ops.aten.permute.default(permute_331, [1, 0]); permute_331 = None + mm_252 = torch.ops.aten.mm.default(view_2367, permute_415); view_2367 = permute_415 = None + view_2368 = torch.ops.aten.view.default(mm_252, [2, 8192, 4096]); mm_252 = None + add_140 = torch.ops.aten.add.Tensor(view_2366, view_2368); view_2366 = view_2368 = None + convert_element_type_1170 = torch.ops.prims.convert_element_type.default(mm_251, torch.float32); mm_251 = None + reduce_scatter_tensor_86 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1170, 'avg', 32, '0'); convert_element_type_1170 = None + wait_tensor_450 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_86); reduce_scatter_tensor_86 = None + view_2369 = torch.ops.aten.view.default(view_2364, [16384, 512]); view_2364 = None + permute_417 = torch.ops.aten.permute.default(view_2369, [1, 0]) + mm_253 = torch.ops.aten.mm.default(permute_417, view_2175); permute_417 = view_2175 = None + convert_element_type_994 = torch.ops.prims.convert_element_type.default(primals_275, torch.bfloat16); primals_275 = None + all_gather_into_tensor_333 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_994, 32, '0'); convert_element_type_994 = None + wait_tensor_394 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_333); all_gather_into_tensor_333 = None + permute_330 = torch.ops.aten.permute.default(wait_tensor_394, [1, 0]); wait_tensor_394 = None + permute_419 = torch.ops.aten.permute.default(permute_330, [1, 0]); permute_330 = None + mm_254 = torch.ops.aten.mm.default(view_2369, permute_419); view_2369 = permute_419 = None + view_2370 = torch.ops.aten.view.default(mm_254, 
[2, 8192, 4096]); mm_254 = None + add_141 = torch.ops.aten.add.Tensor(add_140, view_2370); add_140 = view_2370 = None + convert_element_type_1175 = torch.ops.prims.convert_element_type.default(mm_253, torch.float32); mm_253 = None + reduce_scatter_tensor_87 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1175, 'avg', 32, '0'); convert_element_type_1175 = None + wait_tensor_451 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_87); reduce_scatter_tensor_87 = None + split_146 = torch.ops.aten.split.Tensor(add_141, 1024, 1); add_141 = None + getitem_1462 = split_146[0] + getitem_1463 = split_146[1] + getitem_1464 = split_146[2] + getitem_1465 = split_146[3] + getitem_1466 = split_146[4] + getitem_1467 = split_146[5] + getitem_1468 = split_146[6] + getitem_1469 = split_146[7]; split_146 = None + cat_138 = torch.ops.aten.cat.default([getitem_1462, getitem_1463, getitem_1464, getitem_1465, getitem_1466, getitem_1467, getitem_1468, getitem_1469]); getitem_1462 = getitem_1463 = getitem_1464 = getitem_1465 = getitem_1466 = getitem_1467 = getitem_1468 = getitem_1469 = None + reduce_scatter_tensor_88 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_138, 'sum', 8, '1'); cat_138 = None + wait_tensor_452 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_88); reduce_scatter_tensor_88 = None + convert_element_type_1176 = torch.ops.prims.convert_element_type.default(wait_tensor_452, torch.float32); wait_tensor_452 = None + convert_element_type_1178 = torch.ops.prims.convert_element_type.default(wait_tensor_392, torch.float32); wait_tensor_392 = None + mul_298 = torch.ops.aten.mul.Tensor(convert_element_type_1176, convert_element_type_1178); convert_element_type_1178 = None + mul_300 = torch.ops.aten.mul.Tensor(mul_240, mul_298) + sum_13 = torch.ops.aten.sum.dim_IntList(mul_300, [2], True); mul_300 = None + div_4 = torch.ops.aten.div.Tensor(mul_240, 4096) + mul_301 = 
torch.ops.aten.mul.Tensor(div_4, sum_13); div_4 = sum_13 = None + sub_7 = torch.ops.aten.sub.Tensor(mul_298, mul_301); mul_298 = mul_301 = None + mul_302 = torch.ops.aten.mul.Tensor(sub_7, rsqrt_60); sub_7 = rsqrt_60 = None + mul_303 = torch.ops.aten.mul.Tensor(convert_element_type_1176, mul_240); convert_element_type_1176 = mul_240 = None + sum_14 = torch.ops.aten.sum.dim_IntList(mul_303, [0, 1]); mul_303 = None + convert_element_type_1179 = torch.ops.prims.convert_element_type.default(mul_302, torch.bfloat16); mul_302 = None + convert_element_type_1180 = torch.ops.prims.convert_element_type.default(sum_14, torch.bfloat16); sum_14 = None + all_reduce_4 = torch.ops._c10d_functional.all_reduce.default(convert_element_type_1180, 'sum', '1'); convert_element_type_1180 = None + wait_tensor_453 = torch.ops._c10d_functional.wait_tensor.default(all_reduce_4); all_reduce_4 = None + convert_element_type_1181 = torch.ops.prims.convert_element_type.default(wait_tensor_453, torch.float32); wait_tensor_453 = None + reduce_scatter_tensor_89 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1181, 'avg', 32, '0'); convert_element_type_1181 = None + wait_tensor_454 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_89); reduce_scatter_tensor_89 = None + add_142 = torch.ops.aten.add.Tensor(add_139, convert_element_type_1179); add_139 = convert_element_type_1179 = None + all_gather_into_tensor_360 = torch.ops._c10d_functional.all_gather_into_tensor.default(add_142, 8, '1') + wait_tensor_455 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_360); all_gather_into_tensor_360 = None + split_147 = torch.ops.aten.split.Tensor(wait_tensor_455, 2); wait_tensor_455 = None + getitem_1470 = split_147[0] + getitem_1471 = split_147[1] + getitem_1472 = split_147[2] + getitem_1473 = split_147[3] + getitem_1474 = split_147[4] + getitem_1475 = split_147[5] + getitem_1476 = split_147[6] + getitem_1477 = split_147[7]; split_147 = 
None + cat_139 = torch.ops.aten.cat.default([getitem_1470, getitem_1471, getitem_1472, getitem_1473, getitem_1474, getitem_1475, getitem_1476, getitem_1477], 1); getitem_1470 = getitem_1471 = getitem_1472 = getitem_1473 = getitem_1474 = getitem_1475 = getitem_1476 = getitem_1477 = None + view_2371 = torch.ops.aten.view.default(cat_139, [16384, 4096]); cat_139 = None + permute_421 = torch.ops.aten.permute.default(view_2371, [1, 0]) + wait_tensor_385 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_59); reduce_scatter_tensor_59 = None + add_117 = torch.ops.aten.add.Tensor(add_115, wait_tensor_385); wait_tensor_385 = None + convert_element_type_977 = torch.ops.prims.convert_element_type.default(primals_270, torch.bfloat16); primals_270 = None + all_gather_into_tensor_326 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_977, 32, '0'); convert_element_type_977 = None + wait_tensor_386 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_326); all_gather_into_tensor_326 = None + convert_element_type_978 = torch.ops.prims.convert_element_type.default(add_117, torch.float32); add_117 = None + pow_60 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_978, 2) + mean_59 = torch.ops.aten.mean.dim(pow_60, [2], True); pow_60 = None + add_118 = torch.ops.aten.add.Scalar(mean_59, 1e-05); mean_59 = None + rsqrt_59 = torch.ops.aten.rsqrt.default(add_118); add_118 = None + mul_236 = torch.ops.aten.mul.Tensor(convert_element_type_978, rsqrt_59); convert_element_type_978 = None + mul_237 = torch.ops.aten.mul.Tensor(mul_236, wait_tensor_386) + convert_element_type_979 = torch.ops.prims.convert_element_type.default(mul_237, torch.bfloat16); mul_237 = None + all_gather_into_tensor_327 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_979, 8, '1'); convert_element_type_979 = None + wait_tensor_387 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_327); 
all_gather_into_tensor_327 = None + split_127 = torch.ops.aten.split.Tensor(wait_tensor_387, 2); wait_tensor_387 = None + getitem_1286 = split_127[0] + getitem_1287 = split_127[1] + getitem_1288 = split_127[2] + getitem_1289 = split_127[3] + getitem_1290 = split_127[4] + getitem_1291 = split_127[5] + getitem_1292 = split_127[6] + getitem_1293 = split_127[7]; split_127 = None + cat_119 = torch.ops.aten.cat.default([getitem_1286, getitem_1287, getitem_1288, getitem_1289, getitem_1290, getitem_1291, getitem_1292, getitem_1293], 1); getitem_1286 = getitem_1287 = getitem_1288 = getitem_1289 = getitem_1290 = getitem_1291 = getitem_1292 = getitem_1293 = None + view_2148 = torch.ops.aten.view.default(cat_119, [16384, 4096]); cat_119 = None + view_2149 = torch.ops.aten.view.default(mm_207, [2, 8192, 1792]); mm_207 = None + convert_element_type_983 = torch.ops.prims.convert_element_type.default(view_2149, torch.float32); view_2149 = None + sigmoid_29 = torch.ops.aten.sigmoid.default(convert_element_type_983) + mul_238 = torch.ops.aten.mul.Tensor(convert_element_type_983, sigmoid_29); sigmoid_29 = None + convert_element_type_984 = torch.ops.prims.convert_element_type.default(mul_238, torch.bfloat16); mul_238 = None + convert_element_type_985 = torch.ops.prims.convert_element_type.default(primals_272, torch.bfloat16); primals_272 = None + all_gather_into_tensor_329 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_985, 32, '0'); convert_element_type_985 = None + wait_tensor_389 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_329); all_gather_into_tensor_329 = None + permute_328 = torch.ops.aten.permute.default(wait_tensor_389, [1, 0]); wait_tensor_389 = None + mm_208 = torch.ops.aten.mm.default(view_2148, permute_328) + view_2156 = torch.ops.aten.view.default(mm_208, [2, 8192, 1792]); mm_208 = None + mul_239 = torch.ops.aten.mul.Tensor(convert_element_type_984, view_2156) + view_2163 = torch.ops.aten.view.default(mul_239, 
[16384, 1792]); mul_239 = None + mm_255 = torch.ops.aten.mm.default(permute_421, view_2163); permute_421 = view_2163 = None + convert_element_type_988 = torch.ops.prims.convert_element_type.default(primals_273, torch.bfloat16); primals_273 = None + all_gather_into_tensor_330 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_988, 32, '0'); convert_element_type_988 = None + wait_tensor_390 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_330); all_gather_into_tensor_330 = None + permute_329 = torch.ops.aten.permute.default(wait_tensor_390, [1, 0]); wait_tensor_390 = None + permute_423 = torch.ops.aten.permute.default(permute_329, [1, 0]); permute_329 = None + mm_256 = torch.ops.aten.mm.default(view_2371, permute_423); view_2371 = permute_423 = None + view_2372 = torch.ops.aten.view.default(mm_256, [2, 8192, 1792]); mm_256 = None + convert_element_type_1186 = torch.ops.prims.convert_element_type.default(mm_255, torch.float32); mm_255 = None + reduce_scatter_tensor_90 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1186, 'avg', 32, '0'); convert_element_type_1186 = None + wait_tensor_456 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_90); reduce_scatter_tensor_90 = None + mul_304 = torch.ops.aten.mul.Tensor(view_2372, convert_element_type_984); convert_element_type_984 = None + mul_305 = torch.ops.aten.mul.Tensor(view_2372, view_2156); view_2372 = view_2156 = None + view_2373 = torch.ops.aten.view.default(mul_304, [16384, 1792]); mul_304 = None + permute_425 = torch.ops.aten.permute.default(view_2373, [1, 0]) + mm_257 = torch.ops.aten.mm.default(permute_425, view_2148); permute_425 = None + permute_427 = torch.ops.aten.permute.default(permute_328, [1, 0]); permute_328 = None + mm_258 = torch.ops.aten.mm.default(view_2373, permute_427); view_2373 = permute_427 = None + view_2374 = torch.ops.aten.view.default(mm_258, [2, 8192, 4096]); mm_258 = None + 
convert_element_type_1191 = torch.ops.prims.convert_element_type.default(mm_257, torch.float32); mm_257 = None + reduce_scatter_tensor_91 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1191, 'avg', 32, '0'); convert_element_type_1191 = None + wait_tensor_457 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_91); reduce_scatter_tensor_91 = None + convert_element_type_1192 = torch.ops.prims.convert_element_type.default(mul_305, torch.float32); mul_305 = None + neg_2 = torch.ops.aten.neg.default(convert_element_type_983) + exp_2 = torch.ops.aten.exp.default(neg_2); neg_2 = None + add_143 = torch.ops.aten.add.Tensor(exp_2, 1); exp_2 = None + reciprocal_2 = torch.ops.aten.reciprocal.default(add_143); add_143 = None + mul_306 = torch.ops.aten.mul.Tensor(reciprocal_2, 1); reciprocal_2 = None + mul_307 = torch.ops.aten.mul.Tensor(convert_element_type_1192, mul_306); convert_element_type_1192 = None + sub_8 = torch.ops.aten.sub.Tensor(1, mul_306); mul_306 = None + mul_308 = torch.ops.aten.mul.Tensor(convert_element_type_983, sub_8); convert_element_type_983 = sub_8 = None + add_144 = torch.ops.aten.add.Tensor(mul_308, 1); mul_308 = None + mul_309 = torch.ops.aten.mul.Tensor(mul_307, add_144); mul_307 = add_144 = None + convert_element_type_1194 = torch.ops.prims.convert_element_type.default(mul_309, torch.bfloat16); mul_309 = None + view_2375 = torch.ops.aten.view.default(convert_element_type_1194, [16384, 1792]); convert_element_type_1194 = None + permute_429 = torch.ops.aten.permute.default(view_2375, [1, 0]) + mm_259 = torch.ops.aten.mm.default(permute_429, view_2148); permute_429 = view_2148 = None + convert_element_type_980 = torch.ops.prims.convert_element_type.default(primals_271, torch.bfloat16); primals_271 = None + all_gather_into_tensor_328 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_980, 32, '0'); convert_element_type_980 = None + wait_tensor_388 = 
torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_328); all_gather_into_tensor_328 = None + permute_327 = torch.ops.aten.permute.default(wait_tensor_388, [1, 0]); wait_tensor_388 = None + permute_431 = torch.ops.aten.permute.default(permute_327, [1, 0]); permute_327 = None + mm_260 = torch.ops.aten.mm.default(view_2375, permute_431); view_2375 = permute_431 = None + view_2376 = torch.ops.aten.view.default(mm_260, [2, 8192, 4096]); mm_260 = None + add_145 = torch.ops.aten.add.Tensor(view_2374, view_2376); view_2374 = view_2376 = None + convert_element_type_1199 = torch.ops.prims.convert_element_type.default(mm_259, torch.float32); mm_259 = None + reduce_scatter_tensor_92 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1199, 'avg', 32, '0'); convert_element_type_1199 = None + wait_tensor_458 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_92); reduce_scatter_tensor_92 = None + split_148 = torch.ops.aten.split.Tensor(add_145, 1024, 1); add_145 = None + getitem_1478 = split_148[0] + getitem_1479 = split_148[1] + getitem_1480 = split_148[2] + getitem_1481 = split_148[3] + getitem_1482 = split_148[4] + getitem_1483 = split_148[5] + getitem_1484 = split_148[6] + getitem_1485 = split_148[7]; split_148 = None + cat_140 = torch.ops.aten.cat.default([getitem_1478, getitem_1479, getitem_1480, getitem_1481, getitem_1482, getitem_1483, getitem_1484, getitem_1485]); getitem_1478 = getitem_1479 = getitem_1480 = getitem_1481 = getitem_1482 = getitem_1483 = getitem_1484 = getitem_1485 = None + reduce_scatter_tensor_93 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_140, 'sum', 8, '1'); cat_140 = None + wait_tensor_459 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_93); reduce_scatter_tensor_93 = None + convert_element_type_1200 = torch.ops.prims.convert_element_type.default(wait_tensor_459, torch.float32); wait_tensor_459 = None + convert_element_type_1202 = 
torch.ops.prims.convert_element_type.default(wait_tensor_386, torch.float32); wait_tensor_386 = None + mul_310 = torch.ops.aten.mul.Tensor(convert_element_type_1200, convert_element_type_1202); convert_element_type_1202 = None + mul_312 = torch.ops.aten.mul.Tensor(mul_236, mul_310) + sum_15 = torch.ops.aten.sum.dim_IntList(mul_312, [2], True); mul_312 = None + div_5 = torch.ops.aten.div.Tensor(mul_236, 4096) + mul_313 = torch.ops.aten.mul.Tensor(div_5, sum_15); div_5 = sum_15 = None + sub_9 = torch.ops.aten.sub.Tensor(mul_310, mul_313); mul_310 = mul_313 = None + mul_314 = torch.ops.aten.mul.Tensor(sub_9, rsqrt_59); sub_9 = rsqrt_59 = None + mul_315 = torch.ops.aten.mul.Tensor(convert_element_type_1200, mul_236); convert_element_type_1200 = mul_236 = None + sum_16 = torch.ops.aten.sum.dim_IntList(mul_315, [0, 1]); mul_315 = None + convert_element_type_1203 = torch.ops.prims.convert_element_type.default(mul_314, torch.bfloat16); mul_314 = None + convert_element_type_1204 = torch.ops.prims.convert_element_type.default(sum_16, torch.bfloat16); sum_16 = None + all_reduce_5 = torch.ops._c10d_functional.all_reduce.default(convert_element_type_1204, 'sum', '1'); convert_element_type_1204 = None + wait_tensor_460 = torch.ops._c10d_functional.wait_tensor.default(all_reduce_5); all_reduce_5 = None + convert_element_type_1205 = torch.ops.prims.convert_element_type.default(wait_tensor_460, torch.float32); wait_tensor_460 = None + reduce_scatter_tensor_94 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1205, 'avg', 32, '0'); convert_element_type_1205 = None + wait_tensor_461 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_94); reduce_scatter_tensor_94 = None + add_146 = torch.ops.aten.add.Tensor(add_142, convert_element_type_1203); add_142 = convert_element_type_1203 = None + all_gather_into_tensor_361 = torch.ops._c10d_functional.all_gather_into_tensor.default(add_146, 8, '1') + wait_tensor_462 = 
torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_361); all_gather_into_tensor_361 = None + split_149 = torch.ops.aten.split.Tensor(wait_tensor_462, 2); wait_tensor_462 = None + getitem_1486 = split_149[0] + getitem_1487 = split_149[1] + getitem_1488 = split_149[2] + getitem_1489 = split_149[3] + getitem_1490 = split_149[4] + getitem_1491 = split_149[5] + getitem_1492 = split_149[6] + getitem_1493 = split_149[7]; split_149 = None + cat_141 = torch.ops.aten.cat.default([getitem_1486, getitem_1487, getitem_1488, getitem_1489, getitem_1490, getitem_1491, getitem_1492, getitem_1493], 1); getitem_1486 = getitem_1487 = getitem_1488 = getitem_1489 = getitem_1490 = getitem_1491 = getitem_1492 = getitem_1493 = None + view_2377 = torch.ops.aten.view.default(cat_141, [16384, 4096]); cat_141 = None + permute_433 = torch.ops.aten.permute.default(view_2377, [1, 0]) + permute_325 = torch.ops.aten.permute.default(getitem_1269, [0, 2, 1, 3]) + view_2130 = torch.ops.aten.view.default(permute_325, [2, 8192, -1]); permute_325 = None + view_2136 = torch.ops.aten.view.default(view_2130, [16384, 512]); view_2130 = None + mm_261 = torch.ops.aten.mm.default(permute_433, view_2136); permute_433 = view_2136 = None + convert_element_type_974 = torch.ops.prims.convert_element_type.default(primals_269, torch.bfloat16); primals_269 = None + all_gather_into_tensor_325 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_974, 32, '0'); convert_element_type_974 = None + wait_tensor_384 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_325); all_gather_into_tensor_325 = None + permute_326 = torch.ops.aten.permute.default(wait_tensor_384, [1, 0]); wait_tensor_384 = None + permute_435 = torch.ops.aten.permute.default(permute_326, [1, 0]); permute_326 = None + mm_262 = torch.ops.aten.mm.default(view_2377, permute_435); view_2377 = permute_435 = None + view_2378 = torch.ops.aten.view.default(mm_262, [2, 8192, 512]); mm_262 = None + 
convert_element_type_1210 = torch.ops.prims.convert_element_type.default(mm_261, torch.float32); mm_261 = None + reduce_scatter_tensor_95 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1210, 'avg', 32, '0'); convert_element_type_1210 = None + wait_tensor_463 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_95); reduce_scatter_tensor_95 = None + view_2379 = torch.ops.aten.view.default(view_2378, [2, 8192, 4, 128]); view_2378 = None + permute_437 = torch.ops.aten.permute.default(view_2379, [0, 2, 1, 3]); view_2379 = None + convert_element_type_958 = torch.ops.prims.convert_element_type.default(primals_265, torch.bfloat16); primals_265 = None + all_gather_into_tensor_320 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_958, 32, '0'); convert_element_type_958 = None + wait_tensor_379 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_320); all_gather_into_tensor_320 = None + convert_element_type_959 = torch.ops.prims.convert_element_type.default(add_115, torch.float32); add_115 = None + pow_59 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_959, 2) + mean_58 = torch.ops.aten.mean.dim(pow_59, [2], True); pow_59 = None + add_116 = torch.ops.aten.add.Scalar(mean_58, 1e-05); mean_58 = None + rsqrt_58 = torch.ops.aten.rsqrt.default(add_116); add_116 = None + mul_232 = torch.ops.aten.mul.Tensor(convert_element_type_959, rsqrt_58); convert_element_type_959 = None + mul_233 = torch.ops.aten.mul.Tensor(mul_232, wait_tensor_379) + convert_element_type_960 = torch.ops.prims.convert_element_type.default(mul_233, torch.bfloat16); mul_233 = None + all_gather_into_tensor_321 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_960, 8, '1'); convert_element_type_960 = None + wait_tensor_380 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_321); all_gather_into_tensor_321 = None + split_125 = 
torch.ops.aten.split.Tensor(wait_tensor_380, 2); wait_tensor_380 = None + getitem_1261 = split_125[0] + getitem_1262 = split_125[1] + getitem_1263 = split_125[2] + getitem_1264 = split_125[3] + getitem_1265 = split_125[4] + getitem_1266 = split_125[5] + getitem_1267 = split_125[6] + getitem_1268 = split_125[7]; split_125 = None + cat_117 = torch.ops.aten.cat.default([getitem_1261, getitem_1262, getitem_1263, getitem_1264, getitem_1265, getitem_1266, getitem_1267, getitem_1268], 1); getitem_1261 = getitem_1262 = getitem_1263 = getitem_1264 = getitem_1265 = getitem_1266 = getitem_1267 = getitem_1268 = None + view_2103 = torch.ops.aten.view.default(cat_117, [16384, 4096]); cat_117 = None + view_2104 = torch.ops.aten.view.default(mm_203, [2, 8192, 512]); mm_203 = None + convert_element_type_964 = torch.ops.prims.convert_element_type.default(primals_267, torch.bfloat16); primals_267 = None + all_gather_into_tensor_323 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_964, 32, '0'); convert_element_type_964 = None + wait_tensor_382 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_323); all_gather_into_tensor_323 = None + permute_320 = torch.ops.aten.permute.default(wait_tensor_382, [1, 0]); wait_tensor_382 = None + mm_204 = torch.ops.aten.mm.default(view_2103, permute_320) + view_2111 = torch.ops.aten.view.default(mm_204, [2, 8192, 128]); mm_204 = None + view_2118 = torch.ops.aten.view.default(mm_205, [2, 8192, 128]); mm_205 = None + view_2120 = torch.ops.aten.view.default(view_2104, [2, 8192, -1, 128]); view_2104 = None + view_2121 = torch.ops.aten.view.default(view_2111, [2, 8192, -1, 128]); view_2111 = None + view_2122 = torch.ops.aten.view.default(view_2118, [2, 8192, -1, 128]); view_2118 = None + convert_element_type_970 = torch.ops.prims.convert_element_type.default(view_2120, torch.float32); view_2120 = None + view_2123 = torch.ops.aten.view.default(convert_element_type_970, [2, 8192, 4, -1, 2]); 
convert_element_type_970 = None + view_as_complex_58 = torch.ops.aten.view_as_complex.default(view_2123); view_2123 = None + convert_element_type_971 = torch.ops.prims.convert_element_type.default(view_2121, torch.float32); view_2121 = None + view_2124 = torch.ops.aten.view.default(convert_element_type_971, [2, 8192, 1, -1, 2]); convert_element_type_971 = None + view_as_complex_59 = torch.ops.aten.view_as_complex.default(view_2124); view_2124 = None + mul_234 = torch.ops.aten.mul.Tensor(view_as_complex_58, view_37); view_as_complex_58 = None + view_as_real_58 = torch.ops.aten.view_as_real.default(mul_234); mul_234 = None + view_2126 = torch.ops.aten.view.default(view_as_real_58, [2, 8192, 4, 128]); view_as_real_58 = None + mul_235 = torch.ops.aten.mul.Tensor(view_as_complex_59, view_37); view_as_complex_59 = None + view_as_real_59 = torch.ops.aten.view_as_real.default(mul_235); mul_235 = None + view_2127 = torch.ops.aten.view.default(view_as_real_59, [2, 8192, 1, 128]); view_as_real_59 = None + convert_element_type_972 = torch.ops.prims.convert_element_type.default(view_2126, torch.bfloat16); view_2126 = None + convert_element_type_973 = torch.ops.prims.convert_element_type.default(view_2127, torch.bfloat16); view_2127 = None + unsqueeze_58 = torch.ops.aten.unsqueeze.default(convert_element_type_973, 3); convert_element_type_973 = None + expand_58 = torch.ops.aten.expand.default(unsqueeze_58, [2, 8192, 1, 4, 128]); unsqueeze_58 = None + view_2128 = torch.ops.aten.view.default(expand_58, [2, 8192, 4, 128]); expand_58 = None + unsqueeze_59 = torch.ops.aten.unsqueeze.default(view_2122, 3); view_2122 = None + expand_59 = torch.ops.aten.expand.default(unsqueeze_59, [2, 8192, 1, 4, 128]); unsqueeze_59 = None + view_2129 = torch.ops.aten.view.default(expand_59, [2, 8192, 4, 128]); expand_59 = None + permute_322 = torch.ops.aten.permute.default(convert_element_type_972, [0, 2, 1, 3]); convert_element_type_972 = None + permute_323 = torch.ops.aten.permute.default(view_2128, 
[0, 2, 1, 3]); view_2128 = None + permute_324 = torch.ops.aten.permute.default(view_2129, [0, 2, 1, 3]); view_2129 = None + _scaled_dot_product_cudnn_attention_backward_2 = torch.ops.aten._scaled_dot_product_cudnn_attention_backward.default(permute_437, permute_322, permute_323, permute_324, getitem_1269, getitem_1270, getitem_1275, getitem_1276, None, None, None, 8192, 8192, 0.0, True); permute_437 = permute_322 = permute_323 = permute_324 = getitem_1269 = getitem_1270 = getitem_1275 = getitem_1276 = None + getitem_1494 = _scaled_dot_product_cudnn_attention_backward_2[0] + getitem_1495 = _scaled_dot_product_cudnn_attention_backward_2[1] + getitem_1496 = _scaled_dot_product_cudnn_attention_backward_2[2]; _scaled_dot_product_cudnn_attention_backward_2 = None + permute_438 = torch.ops.aten.permute.default(getitem_1496, [0, 2, 1, 3]); getitem_1496 = None + permute_439 = torch.ops.aten.permute.default(getitem_1495, [0, 2, 1, 3]); getitem_1495 = None + permute_440 = torch.ops.aten.permute.default(getitem_1494, [0, 2, 1, 3]); getitem_1494 = None + view_2380 = torch.ops.aten.view.default(permute_438, [2, 8192, 1, 4, 128]); permute_438 = None + sum_17 = torch.ops.aten.sum.dim_IntList(view_2380, [3], True); view_2380 = None + squeeze_4 = torch.ops.aten.squeeze.dim(sum_17, 3); sum_17 = None + view_2381 = torch.ops.aten.view.default(permute_439, [2, 8192, 1, 4, 128]); permute_439 = None + sum_18 = torch.ops.aten.sum.dim_IntList(view_2381, [3], True); view_2381 = None + squeeze_5 = torch.ops.aten.squeeze.dim(sum_18, 3); sum_18 = None + convert_element_type_1211 = torch.ops.prims.convert_element_type.default(squeeze_5, torch.float32); squeeze_5 = None + convert_element_type_1212 = torch.ops.prims.convert_element_type.default(permute_440, torch.float32); permute_440 = None + view_2382 = torch.ops.aten.view.default(convert_element_type_1211, [2, 8192, 1, 64, 2]); convert_element_type_1211 = None + view_as_complex_68 = torch.ops.aten.view_as_complex.default(view_2382); view_2382 = 
None + mul_316 = torch.ops.aten.mul.Tensor(view_as_complex_68, _conj); view_as_complex_68 = None + view_2383 = torch.ops.aten.view.default(convert_element_type_1212, [2, 8192, 4, 64, 2]); convert_element_type_1212 = None + view_as_complex_69 = torch.ops.aten.view_as_complex.default(view_2383); view_2383 = None + mul_317 = torch.ops.aten.mul.Tensor(view_as_complex_69, _conj); view_as_complex_69 = None + view_as_real_68 = torch.ops.aten.view_as_real.default(mul_316); mul_316 = None + view_2384 = torch.ops.aten.view.default(view_as_real_68, [2, 8192, 1, 128]); view_as_real_68 = None + convert_element_type_1213 = torch.ops.prims.convert_element_type.default(view_2384, torch.bfloat16); view_2384 = None + view_as_real_69 = torch.ops.aten.view_as_real.default(mul_317); mul_317 = None + view_2385 = torch.ops.aten.view.default(view_as_real_69, [2, 8192, 4, 128]); view_as_real_69 = None + convert_element_type_1214 = torch.ops.prims.convert_element_type.default(view_2385, torch.bfloat16); view_2385 = None + view_2386 = torch.ops.aten.view.default(squeeze_4, [2, 8192, 128]); squeeze_4 = None + view_2387 = torch.ops.aten.view.default(convert_element_type_1213, [2, 8192, 128]); convert_element_type_1213 = None + view_2388 = torch.ops.aten.view.default(convert_element_type_1214, [2, 8192, 512]); convert_element_type_1214 = None + view_2389 = torch.ops.aten.view.default(view_2386, [16384, 128]); view_2386 = None + permute_441 = torch.ops.aten.permute.default(view_2389, [1, 0]) + mm_263 = torch.ops.aten.mm.default(permute_441, view_2103); permute_441 = None + convert_element_type_967 = torch.ops.prims.convert_element_type.default(primals_268, torch.bfloat16); primals_268 = None + all_gather_into_tensor_324 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_967, 32, '0'); convert_element_type_967 = None + wait_tensor_383 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_324); all_gather_into_tensor_324 = None + permute_321 = 
torch.ops.aten.permute.default(wait_tensor_383, [1, 0]); wait_tensor_383 = None + permute_443 = torch.ops.aten.permute.default(permute_321, [1, 0]); permute_321 = None + mm_264 = torch.ops.aten.mm.default(view_2389, permute_443); view_2389 = permute_443 = None + view_2390 = torch.ops.aten.view.default(mm_264, [2, 8192, 4096]); mm_264 = None + convert_element_type_1219 = torch.ops.prims.convert_element_type.default(mm_263, torch.float32); mm_263 = None + reduce_scatter_tensor_96 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1219, 'avg', 32, '0'); convert_element_type_1219 = None + wait_tensor_464 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_96); reduce_scatter_tensor_96 = None + view_2391 = torch.ops.aten.view.default(view_2387, [16384, 128]); view_2387 = None + permute_445 = torch.ops.aten.permute.default(view_2391, [1, 0]) + mm_265 = torch.ops.aten.mm.default(permute_445, view_2103); permute_445 = None + permute_447 = torch.ops.aten.permute.default(permute_320, [1, 0]); permute_320 = None + mm_266 = torch.ops.aten.mm.default(view_2391, permute_447); view_2391 = permute_447 = None + view_2392 = torch.ops.aten.view.default(mm_266, [2, 8192, 4096]); mm_266 = None + add_147 = torch.ops.aten.add.Tensor(view_2390, view_2392); view_2390 = view_2392 = None + convert_element_type_1224 = torch.ops.prims.convert_element_type.default(mm_265, torch.float32); mm_265 = None + reduce_scatter_tensor_97 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1224, 'avg', 32, '0'); convert_element_type_1224 = None + wait_tensor_465 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_97); reduce_scatter_tensor_97 = None + view_2393 = torch.ops.aten.view.default(view_2388, [16384, 512]); view_2388 = None + permute_449 = torch.ops.aten.permute.default(view_2393, [1, 0]) + mm_267 = torch.ops.aten.mm.default(permute_449, view_2103); permute_449 = view_2103 = None + 
convert_element_type_961 = torch.ops.prims.convert_element_type.default(primals_266, torch.bfloat16); primals_266 = None + all_gather_into_tensor_322 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_961, 32, '0'); convert_element_type_961 = None + wait_tensor_381 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_322); all_gather_into_tensor_322 = None + permute_319 = torch.ops.aten.permute.default(wait_tensor_381, [1, 0]); wait_tensor_381 = None + permute_451 = torch.ops.aten.permute.default(permute_319, [1, 0]); permute_319 = None + mm_268 = torch.ops.aten.mm.default(view_2393, permute_451); view_2393 = permute_451 = None + view_2394 = torch.ops.aten.view.default(mm_268, [2, 8192, 4096]); mm_268 = None + add_148 = torch.ops.aten.add.Tensor(add_147, view_2394); add_147 = view_2394 = None + convert_element_type_1229 = torch.ops.prims.convert_element_type.default(mm_267, torch.float32); mm_267 = None + reduce_scatter_tensor_98 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1229, 'avg', 32, '0'); convert_element_type_1229 = None + wait_tensor_466 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_98); reduce_scatter_tensor_98 = None + split_150 = torch.ops.aten.split.Tensor(add_148, 1024, 1); add_148 = None + getitem_1497 = split_150[0] + getitem_1498 = split_150[1] + getitem_1499 = split_150[2] + getitem_1500 = split_150[3] + getitem_1501 = split_150[4] + getitem_1502 = split_150[5] + getitem_1503 = split_150[6] + getitem_1504 = split_150[7]; split_150 = None + cat_142 = torch.ops.aten.cat.default([getitem_1497, getitem_1498, getitem_1499, getitem_1500, getitem_1501, getitem_1502, getitem_1503, getitem_1504]); getitem_1497 = getitem_1498 = getitem_1499 = getitem_1500 = getitem_1501 = getitem_1502 = getitem_1503 = getitem_1504 = None + reduce_scatter_tensor_99 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_142, 'sum', 8, '1'); cat_142 = None + 
wait_tensor_467 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_99); reduce_scatter_tensor_99 = None + convert_element_type_1230 = torch.ops.prims.convert_element_type.default(wait_tensor_467, torch.float32); wait_tensor_467 = None + convert_element_type_1232 = torch.ops.prims.convert_element_type.default(wait_tensor_379, torch.float32); wait_tensor_379 = None + mul_318 = torch.ops.aten.mul.Tensor(convert_element_type_1230, convert_element_type_1232); convert_element_type_1232 = None + mul_320 = torch.ops.aten.mul.Tensor(mul_232, mul_318) + sum_19 = torch.ops.aten.sum.dim_IntList(mul_320, [2], True); mul_320 = None + div_6 = torch.ops.aten.div.Tensor(mul_232, 4096) + mul_321 = torch.ops.aten.mul.Tensor(div_6, sum_19); div_6 = sum_19 = None + sub_10 = torch.ops.aten.sub.Tensor(mul_318, mul_321); mul_318 = mul_321 = None + mul_322 = torch.ops.aten.mul.Tensor(sub_10, rsqrt_58); sub_10 = rsqrt_58 = None + mul_323 = torch.ops.aten.mul.Tensor(convert_element_type_1230, mul_232); convert_element_type_1230 = mul_232 = None + sum_20 = torch.ops.aten.sum.dim_IntList(mul_323, [0, 1]); mul_323 = None + convert_element_type_1233 = torch.ops.prims.convert_element_type.default(mul_322, torch.bfloat16); mul_322 = None + convert_element_type_1234 = torch.ops.prims.convert_element_type.default(sum_20, torch.bfloat16); sum_20 = None + all_reduce_6 = torch.ops._c10d_functional.all_reduce.default(convert_element_type_1234, 'sum', '1'); convert_element_type_1234 = None + wait_tensor_468 = torch.ops._c10d_functional.wait_tensor.default(all_reduce_6); all_reduce_6 = None + convert_element_type_1235 = torch.ops.prims.convert_element_type.default(wait_tensor_468, torch.float32); wait_tensor_468 = None + reduce_scatter_tensor_100 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1235, 'avg', 32, '0'); convert_element_type_1235 = None + wait_tensor_469 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_100); 
reduce_scatter_tensor_100 = None + add_149 = torch.ops.aten.add.Tensor(add_146, convert_element_type_1233); add_146 = convert_element_type_1233 = None + all_gather_into_tensor_362 = torch.ops._c10d_functional.all_gather_into_tensor.default(add_149, 8, '1') + wait_tensor_470 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_362); all_gather_into_tensor_362 = None + split_151 = torch.ops.aten.split.Tensor(wait_tensor_470, 2); wait_tensor_470 = None + getitem_1505 = split_151[0] + getitem_1506 = split_151[1] + getitem_1507 = split_151[2] + getitem_1508 = split_151[3] + getitem_1509 = split_151[4] + getitem_1510 = split_151[5] + getitem_1511 = split_151[6] + getitem_1512 = split_151[7]; split_151 = None + cat_143 = torch.ops.aten.cat.default([getitem_1505, getitem_1506, getitem_1507, getitem_1508, getitem_1509, getitem_1510, getitem_1511, getitem_1512], 1); getitem_1505 = getitem_1506 = getitem_1507 = getitem_1508 = getitem_1509 = getitem_1510 = getitem_1511 = getitem_1512 = None + view_2395 = torch.ops.aten.view.default(cat_143, [16384, 4096]); cat_143 = None + permute_453 = torch.ops.aten.permute.default(view_2395, [1, 0]) + wait_tensor_372 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_57); reduce_scatter_tensor_57 = None + add_113 = torch.ops.aten.add.Tensor(add_111, wait_tensor_372); wait_tensor_372 = None + convert_element_type_944 = torch.ops.prims.convert_element_type.default(primals_261, torch.bfloat16); primals_261 = None + all_gather_into_tensor_315 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_944, 32, '0'); convert_element_type_944 = None + wait_tensor_373 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_315); all_gather_into_tensor_315 = None + convert_element_type_945 = torch.ops.prims.convert_element_type.default(add_113, torch.float32); add_113 = None + pow_58 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_945, 2) + mean_57 = 
torch.ops.aten.mean.dim(pow_58, [2], True); pow_58 = None + add_114 = torch.ops.aten.add.Scalar(mean_57, 1e-05); mean_57 = None + rsqrt_57 = torch.ops.aten.rsqrt.default(add_114); add_114 = None + mul_228 = torch.ops.aten.mul.Tensor(convert_element_type_945, rsqrt_57); convert_element_type_945 = None + mul_229 = torch.ops.aten.mul.Tensor(mul_228, wait_tensor_373) + convert_element_type_946 = torch.ops.prims.convert_element_type.default(mul_229, torch.bfloat16); mul_229 = None + all_gather_into_tensor_316 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_946, 8, '1'); convert_element_type_946 = None + wait_tensor_374 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_316); all_gather_into_tensor_316 = None + split_123 = torch.ops.aten.split.Tensor(wait_tensor_374, 2); wait_tensor_374 = None + getitem_1245 = split_123[0] + getitem_1246 = split_123[1] + getitem_1247 = split_123[2] + getitem_1248 = split_123[3] + getitem_1249 = split_123[4] + getitem_1250 = split_123[5] + getitem_1251 = split_123[6] + getitem_1252 = split_123[7]; split_123 = None + cat_115 = torch.ops.aten.cat.default([getitem_1245, getitem_1246, getitem_1247, getitem_1248, getitem_1249, getitem_1250, getitem_1251, getitem_1252], 1); getitem_1245 = getitem_1246 = getitem_1247 = getitem_1248 = getitem_1249 = getitem_1250 = getitem_1251 = getitem_1252 = None + view_2076 = torch.ops.aten.view.default(cat_115, [16384, 4096]); cat_115 = None + view_2077 = torch.ops.aten.view.default(mm_200, [2, 8192, 1792]); mm_200 = None + convert_element_type_950 = torch.ops.prims.convert_element_type.default(view_2077, torch.float32); view_2077 = None + sigmoid_28 = torch.ops.aten.sigmoid.default(convert_element_type_950) + mul_230 = torch.ops.aten.mul.Tensor(convert_element_type_950, sigmoid_28); sigmoid_28 = None + convert_element_type_951 = torch.ops.prims.convert_element_type.default(mul_230, torch.bfloat16); mul_230 = None + convert_element_type_952 = 
torch.ops.prims.convert_element_type.default(primals_263, torch.bfloat16); primals_263 = None + all_gather_into_tensor_318 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_952, 32, '0'); convert_element_type_952 = None + wait_tensor_376 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_318); all_gather_into_tensor_318 = None + permute_317 = torch.ops.aten.permute.default(wait_tensor_376, [1, 0]); wait_tensor_376 = None + mm_201 = torch.ops.aten.mm.default(view_2076, permute_317) + view_2084 = torch.ops.aten.view.default(mm_201, [2, 8192, 1792]); mm_201 = None + mul_231 = torch.ops.aten.mul.Tensor(convert_element_type_951, view_2084) + view_2091 = torch.ops.aten.view.default(mul_231, [16384, 1792]); mul_231 = None + mm_269 = torch.ops.aten.mm.default(permute_453, view_2091); permute_453 = view_2091 = None + convert_element_type_955 = torch.ops.prims.convert_element_type.default(primals_264, torch.bfloat16); primals_264 = None + all_gather_into_tensor_319 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_955, 32, '0'); convert_element_type_955 = None + wait_tensor_377 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_319); all_gather_into_tensor_319 = None + permute_318 = torch.ops.aten.permute.default(wait_tensor_377, [1, 0]); wait_tensor_377 = None + permute_455 = torch.ops.aten.permute.default(permute_318, [1, 0]); permute_318 = None + mm_270 = torch.ops.aten.mm.default(view_2395, permute_455); view_2395 = permute_455 = None + view_2396 = torch.ops.aten.view.default(mm_270, [2, 8192, 1792]); mm_270 = None + convert_element_type_1240 = torch.ops.prims.convert_element_type.default(mm_269, torch.float32); mm_269 = None + reduce_scatter_tensor_101 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1240, 'avg', 32, '0'); convert_element_type_1240 = None + wait_tensor_471 = 
torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_101); reduce_scatter_tensor_101 = None + mul_324 = torch.ops.aten.mul.Tensor(view_2396, convert_element_type_951); convert_element_type_951 = None + mul_325 = torch.ops.aten.mul.Tensor(view_2396, view_2084); view_2396 = view_2084 = None + view_2397 = torch.ops.aten.view.default(mul_324, [16384, 1792]); mul_324 = None + permute_457 = torch.ops.aten.permute.default(view_2397, [1, 0]) + mm_271 = torch.ops.aten.mm.default(permute_457, view_2076); permute_457 = None + permute_459 = torch.ops.aten.permute.default(permute_317, [1, 0]); permute_317 = None + mm_272 = torch.ops.aten.mm.default(view_2397, permute_459); view_2397 = permute_459 = None + view_2398 = torch.ops.aten.view.default(mm_272, [2, 8192, 4096]); mm_272 = None + convert_element_type_1245 = torch.ops.prims.convert_element_type.default(mm_271, torch.float32); mm_271 = None + reduce_scatter_tensor_102 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1245, 'avg', 32, '0'); convert_element_type_1245 = None + wait_tensor_472 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_102); reduce_scatter_tensor_102 = None + convert_element_type_1246 = torch.ops.prims.convert_element_type.default(mul_325, torch.float32); mul_325 = None + neg_3 = torch.ops.aten.neg.default(convert_element_type_950) + exp_3 = torch.ops.aten.exp.default(neg_3); neg_3 = None + add_150 = torch.ops.aten.add.Tensor(exp_3, 1); exp_3 = None + reciprocal_3 = torch.ops.aten.reciprocal.default(add_150); add_150 = None + mul_326 = torch.ops.aten.mul.Tensor(reciprocal_3, 1); reciprocal_3 = None + mul_327 = torch.ops.aten.mul.Tensor(convert_element_type_1246, mul_326); convert_element_type_1246 = None + sub_11 = torch.ops.aten.sub.Tensor(1, mul_326); mul_326 = None + mul_328 = torch.ops.aten.mul.Tensor(convert_element_type_950, sub_11); convert_element_type_950 = sub_11 = None + add_151 = torch.ops.aten.add.Tensor(mul_328, 1); mul_328 
= None + mul_329 = torch.ops.aten.mul.Tensor(mul_327, add_151); mul_327 = add_151 = None + convert_element_type_1248 = torch.ops.prims.convert_element_type.default(mul_329, torch.bfloat16); mul_329 = None + view_2399 = torch.ops.aten.view.default(convert_element_type_1248, [16384, 1792]); convert_element_type_1248 = None + permute_461 = torch.ops.aten.permute.default(view_2399, [1, 0]) + mm_273 = torch.ops.aten.mm.default(permute_461, view_2076); permute_461 = view_2076 = None + convert_element_type_947 = torch.ops.prims.convert_element_type.default(primals_262, torch.bfloat16); primals_262 = None + all_gather_into_tensor_317 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_947, 32, '0'); convert_element_type_947 = None + wait_tensor_375 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_317); all_gather_into_tensor_317 = None + permute_316 = torch.ops.aten.permute.default(wait_tensor_375, [1, 0]); wait_tensor_375 = None + permute_463 = torch.ops.aten.permute.default(permute_316, [1, 0]); permute_316 = None + mm_274 = torch.ops.aten.mm.default(view_2399, permute_463); view_2399 = permute_463 = None + view_2400 = torch.ops.aten.view.default(mm_274, [2, 8192, 4096]); mm_274 = None + add_152 = torch.ops.aten.add.Tensor(view_2398, view_2400); view_2398 = view_2400 = None + convert_element_type_1253 = torch.ops.prims.convert_element_type.default(mm_273, torch.float32); mm_273 = None + reduce_scatter_tensor_103 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1253, 'avg', 32, '0'); convert_element_type_1253 = None + wait_tensor_473 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_103); reduce_scatter_tensor_103 = None + split_152 = torch.ops.aten.split.Tensor(add_152, 1024, 1); add_152 = None + getitem_1513 = split_152[0] + getitem_1514 = split_152[1] + getitem_1515 = split_152[2] + getitem_1516 = split_152[3] + getitem_1517 = split_152[4] + getitem_1518 = 
split_152[5] + getitem_1519 = split_152[6] + getitem_1520 = split_152[7]; split_152 = None + cat_144 = torch.ops.aten.cat.default([getitem_1513, getitem_1514, getitem_1515, getitem_1516, getitem_1517, getitem_1518, getitem_1519, getitem_1520]); getitem_1513 = getitem_1514 = getitem_1515 = getitem_1516 = getitem_1517 = getitem_1518 = getitem_1519 = getitem_1520 = None + reduce_scatter_tensor_104 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_144, 'sum', 8, '1'); cat_144 = None + wait_tensor_474 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_104); reduce_scatter_tensor_104 = None + convert_element_type_1254 = torch.ops.prims.convert_element_type.default(wait_tensor_474, torch.float32); wait_tensor_474 = None + convert_element_type_1256 = torch.ops.prims.convert_element_type.default(wait_tensor_373, torch.float32); wait_tensor_373 = None + mul_330 = torch.ops.aten.mul.Tensor(convert_element_type_1254, convert_element_type_1256); convert_element_type_1256 = None + mul_332 = torch.ops.aten.mul.Tensor(mul_228, mul_330) + sum_21 = torch.ops.aten.sum.dim_IntList(mul_332, [2], True); mul_332 = None + div_7 = torch.ops.aten.div.Tensor(mul_228, 4096) + mul_333 = torch.ops.aten.mul.Tensor(div_7, sum_21); div_7 = sum_21 = None + sub_12 = torch.ops.aten.sub.Tensor(mul_330, mul_333); mul_330 = mul_333 = None + mul_334 = torch.ops.aten.mul.Tensor(sub_12, rsqrt_57); sub_12 = rsqrt_57 = None + mul_335 = torch.ops.aten.mul.Tensor(convert_element_type_1254, mul_228); convert_element_type_1254 = mul_228 = None + sum_22 = torch.ops.aten.sum.dim_IntList(mul_335, [0, 1]); mul_335 = None + convert_element_type_1257 = torch.ops.prims.convert_element_type.default(mul_334, torch.bfloat16); mul_334 = None + convert_element_type_1258 = torch.ops.prims.convert_element_type.default(sum_22, torch.bfloat16); sum_22 = None + all_reduce_7 = torch.ops._c10d_functional.all_reduce.default(convert_element_type_1258, 'sum', '1'); convert_element_type_1258 = None + 
wait_tensor_475 = torch.ops._c10d_functional.wait_tensor.default(all_reduce_7); all_reduce_7 = None + convert_element_type_1259 = torch.ops.prims.convert_element_type.default(wait_tensor_475, torch.float32); wait_tensor_475 = None + reduce_scatter_tensor_105 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1259, 'avg', 32, '0'); convert_element_type_1259 = None + wait_tensor_476 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_105); reduce_scatter_tensor_105 = None + add_153 = torch.ops.aten.add.Tensor(add_149, convert_element_type_1257); add_149 = convert_element_type_1257 = None + all_gather_into_tensor_363 = torch.ops._c10d_functional.all_gather_into_tensor.default(add_153, 8, '1') + wait_tensor_477 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_363); all_gather_into_tensor_363 = None + split_153 = torch.ops.aten.split.Tensor(wait_tensor_477, 2); wait_tensor_477 = None + getitem_1521 = split_153[0] + getitem_1522 = split_153[1] + getitem_1523 = split_153[2] + getitem_1524 = split_153[3] + getitem_1525 = split_153[4] + getitem_1526 = split_153[5] + getitem_1527 = split_153[6] + getitem_1528 = split_153[7]; split_153 = None + cat_145 = torch.ops.aten.cat.default([getitem_1521, getitem_1522, getitem_1523, getitem_1524, getitem_1525, getitem_1526, getitem_1527, getitem_1528], 1); getitem_1521 = getitem_1522 = getitem_1523 = getitem_1524 = getitem_1525 = getitem_1526 = getitem_1527 = getitem_1528 = None + view_2401 = torch.ops.aten.view.default(cat_145, [16384, 4096]); cat_145 = None + permute_465 = torch.ops.aten.permute.default(view_2401, [1, 0]) + permute_314 = torch.ops.aten.permute.default(getitem_1228, [0, 2, 1, 3]) + view_2058 = torch.ops.aten.view.default(permute_314, [2, 8192, -1]); permute_314 = None + view_2064 = torch.ops.aten.view.default(view_2058, [16384, 512]); view_2058 = None + mm_275 = torch.ops.aten.mm.default(permute_465, view_2064); permute_465 = view_2064 = None + 
convert_element_type_941 = torch.ops.prims.convert_element_type.default(primals_260, torch.bfloat16); primals_260 = None + all_gather_into_tensor_314 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_941, 32, '0'); convert_element_type_941 = None + wait_tensor_371 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_314); all_gather_into_tensor_314 = None + permute_315 = torch.ops.aten.permute.default(wait_tensor_371, [1, 0]); wait_tensor_371 = None + permute_467 = torch.ops.aten.permute.default(permute_315, [1, 0]); permute_315 = None + mm_276 = torch.ops.aten.mm.default(view_2401, permute_467); view_2401 = permute_467 = None + view_2402 = torch.ops.aten.view.default(mm_276, [2, 8192, 512]); mm_276 = None + convert_element_type_1264 = torch.ops.prims.convert_element_type.default(mm_275, torch.float32); mm_275 = None + reduce_scatter_tensor_106 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1264, 'avg', 32, '0'); convert_element_type_1264 = None + wait_tensor_478 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_106); reduce_scatter_tensor_106 = None + view_2403 = torch.ops.aten.view.default(view_2402, [2, 8192, 4, 128]); view_2402 = None + permute_469 = torch.ops.aten.permute.default(view_2403, [0, 2, 1, 3]); view_2403 = None + convert_element_type_925 = torch.ops.prims.convert_element_type.default(primals_256, torch.bfloat16); primals_256 = None + all_gather_into_tensor_309 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_925, 32, '0'); convert_element_type_925 = None + wait_tensor_366 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_309); all_gather_into_tensor_309 = None + convert_element_type_926 = torch.ops.prims.convert_element_type.default(add_111, torch.float32); add_111 = None + pow_57 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_926, 2) + mean_56 = torch.ops.aten.mean.dim(pow_57, [2], 
True); pow_57 = None + add_112 = torch.ops.aten.add.Scalar(mean_56, 1e-05); mean_56 = None + rsqrt_56 = torch.ops.aten.rsqrt.default(add_112); add_112 = None + mul_224 = torch.ops.aten.mul.Tensor(convert_element_type_926, rsqrt_56); convert_element_type_926 = None + mul_225 = torch.ops.aten.mul.Tensor(mul_224, wait_tensor_366) + convert_element_type_927 = torch.ops.prims.convert_element_type.default(mul_225, torch.bfloat16); mul_225 = None + all_gather_into_tensor_310 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_927, 8, '1'); convert_element_type_927 = None + wait_tensor_367 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_310); all_gather_into_tensor_310 = None + split_121 = torch.ops.aten.split.Tensor(wait_tensor_367, 2); wait_tensor_367 = None + getitem_1220 = split_121[0] + getitem_1221 = split_121[1] + getitem_1222 = split_121[2] + getitem_1223 = split_121[3] + getitem_1224 = split_121[4] + getitem_1225 = split_121[5] + getitem_1226 = split_121[6] + getitem_1227 = split_121[7]; split_121 = None + cat_113 = torch.ops.aten.cat.default([getitem_1220, getitem_1221, getitem_1222, getitem_1223, getitem_1224, getitem_1225, getitem_1226, getitem_1227], 1); getitem_1220 = getitem_1221 = getitem_1222 = getitem_1223 = getitem_1224 = getitem_1225 = getitem_1226 = getitem_1227 = None + view_2031 = torch.ops.aten.view.default(cat_113, [16384, 4096]); cat_113 = None + view_2032 = torch.ops.aten.view.default(mm_196, [2, 8192, 512]); mm_196 = None + convert_element_type_931 = torch.ops.prims.convert_element_type.default(primals_258, torch.bfloat16); primals_258 = None + all_gather_into_tensor_312 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_931, 32, '0'); convert_element_type_931 = None + wait_tensor_369 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_312); all_gather_into_tensor_312 = None + permute_309 = torch.ops.aten.permute.default(wait_tensor_369, [1, 
0]); wait_tensor_369 = None + mm_197 = torch.ops.aten.mm.default(view_2031, permute_309) + view_2039 = torch.ops.aten.view.default(mm_197, [2, 8192, 128]); mm_197 = None + view_2046 = torch.ops.aten.view.default(mm_198, [2, 8192, 128]); mm_198 = None + view_2048 = torch.ops.aten.view.default(view_2032, [2, 8192, -1, 128]); view_2032 = None + view_2049 = torch.ops.aten.view.default(view_2039, [2, 8192, -1, 128]); view_2039 = None + view_2050 = torch.ops.aten.view.default(view_2046, [2, 8192, -1, 128]); view_2046 = None + convert_element_type_937 = torch.ops.prims.convert_element_type.default(view_2048, torch.float32); view_2048 = None + view_2051 = torch.ops.aten.view.default(convert_element_type_937, [2, 8192, 4, -1, 2]); convert_element_type_937 = None + view_as_complex_56 = torch.ops.aten.view_as_complex.default(view_2051); view_2051 = None + convert_element_type_938 = torch.ops.prims.convert_element_type.default(view_2049, torch.float32); view_2049 = None + view_2052 = torch.ops.aten.view.default(convert_element_type_938, [2, 8192, 1, -1, 2]); convert_element_type_938 = None + view_as_complex_57 = torch.ops.aten.view_as_complex.default(view_2052); view_2052 = None + mul_226 = torch.ops.aten.mul.Tensor(view_as_complex_56, view_37); view_as_complex_56 = None + view_as_real_56 = torch.ops.aten.view_as_real.default(mul_226); mul_226 = None + view_2054 = torch.ops.aten.view.default(view_as_real_56, [2, 8192, 4, 128]); view_as_real_56 = None + mul_227 = torch.ops.aten.mul.Tensor(view_as_complex_57, view_37); view_as_complex_57 = None + view_as_real_57 = torch.ops.aten.view_as_real.default(mul_227); mul_227 = None + view_2055 = torch.ops.aten.view.default(view_as_real_57, [2, 8192, 1, 128]); view_as_real_57 = None + convert_element_type_939 = torch.ops.prims.convert_element_type.default(view_2054, torch.bfloat16); view_2054 = None + convert_element_type_940 = torch.ops.prims.convert_element_type.default(view_2055, torch.bfloat16); view_2055 = None + unsqueeze_56 = 
torch.ops.aten.unsqueeze.default(convert_element_type_940, 3); convert_element_type_940 = None + expand_56 = torch.ops.aten.expand.default(unsqueeze_56, [2, 8192, 1, 4, 128]); unsqueeze_56 = None + view_2056 = torch.ops.aten.view.default(expand_56, [2, 8192, 4, 128]); expand_56 = None + unsqueeze_57 = torch.ops.aten.unsqueeze.default(view_2050, 3); view_2050 = None + expand_57 = torch.ops.aten.expand.default(unsqueeze_57, [2, 8192, 1, 4, 128]); unsqueeze_57 = None + view_2057 = torch.ops.aten.view.default(expand_57, [2, 8192, 4, 128]); expand_57 = None + permute_311 = torch.ops.aten.permute.default(convert_element_type_939, [0, 2, 1, 3]); convert_element_type_939 = None + permute_312 = torch.ops.aten.permute.default(view_2056, [0, 2, 1, 3]); view_2056 = None + permute_313 = torch.ops.aten.permute.default(view_2057, [0, 2, 1, 3]); view_2057 = None + _scaled_dot_product_cudnn_attention_backward_3 = torch.ops.aten._scaled_dot_product_cudnn_attention_backward.default(permute_469, permute_311, permute_312, permute_313, getitem_1228, getitem_1229, getitem_1234, getitem_1235, None, None, None, 8192, 8192, 0.0, True); permute_469 = permute_311 = permute_312 = permute_313 = getitem_1228 = getitem_1229 = getitem_1234 = getitem_1235 = None + getitem_1529 = _scaled_dot_product_cudnn_attention_backward_3[0] + getitem_1530 = _scaled_dot_product_cudnn_attention_backward_3[1] + getitem_1531 = _scaled_dot_product_cudnn_attention_backward_3[2]; _scaled_dot_product_cudnn_attention_backward_3 = None + permute_470 = torch.ops.aten.permute.default(getitem_1531, [0, 2, 1, 3]); getitem_1531 = None + permute_471 = torch.ops.aten.permute.default(getitem_1530, [0, 2, 1, 3]); getitem_1530 = None + permute_472 = torch.ops.aten.permute.default(getitem_1529, [0, 2, 1, 3]); getitem_1529 = None + view_2404 = torch.ops.aten.view.default(permute_470, [2, 8192, 1, 4, 128]); permute_470 = None + sum_23 = torch.ops.aten.sum.dim_IntList(view_2404, [3], True); view_2404 = None + squeeze_6 = 
torch.ops.aten.squeeze.dim(sum_23, 3); sum_23 = None + view_2405 = torch.ops.aten.view.default(permute_471, [2, 8192, 1, 4, 128]); permute_471 = None + sum_24 = torch.ops.aten.sum.dim_IntList(view_2405, [3], True); view_2405 = None + squeeze_7 = torch.ops.aten.squeeze.dim(sum_24, 3); sum_24 = None + convert_element_type_1265 = torch.ops.prims.convert_element_type.default(squeeze_7, torch.float32); squeeze_7 = None + convert_element_type_1266 = torch.ops.prims.convert_element_type.default(permute_472, torch.float32); permute_472 = None + view_2406 = torch.ops.aten.view.default(convert_element_type_1265, [2, 8192, 1, 64, 2]); convert_element_type_1265 = None + view_as_complex_70 = torch.ops.aten.view_as_complex.default(view_2406); view_2406 = None + mul_336 = torch.ops.aten.mul.Tensor(view_as_complex_70, _conj); view_as_complex_70 = None + view_2407 = torch.ops.aten.view.default(convert_element_type_1266, [2, 8192, 4, 64, 2]); convert_element_type_1266 = None + view_as_complex_71 = torch.ops.aten.view_as_complex.default(view_2407); view_2407 = None + mul_337 = torch.ops.aten.mul.Tensor(view_as_complex_71, _conj); view_as_complex_71 = None + view_as_real_70 = torch.ops.aten.view_as_real.default(mul_336); mul_336 = None + view_2408 = torch.ops.aten.view.default(view_as_real_70, [2, 8192, 1, 128]); view_as_real_70 = None + convert_element_type_1267 = torch.ops.prims.convert_element_type.default(view_2408, torch.bfloat16); view_2408 = None + view_as_real_71 = torch.ops.aten.view_as_real.default(mul_337); mul_337 = None + view_2409 = torch.ops.aten.view.default(view_as_real_71, [2, 8192, 4, 128]); view_as_real_71 = None + convert_element_type_1268 = torch.ops.prims.convert_element_type.default(view_2409, torch.bfloat16); view_2409 = None + view_2410 = torch.ops.aten.view.default(squeeze_6, [2, 8192, 128]); squeeze_6 = None + view_2411 = torch.ops.aten.view.default(convert_element_type_1267, [2, 8192, 128]); convert_element_type_1267 = None + view_2412 = 
torch.ops.aten.view.default(convert_element_type_1268, [2, 8192, 512]); convert_element_type_1268 = None + view_2413 = torch.ops.aten.view.default(view_2410, [16384, 128]); view_2410 = None + permute_473 = torch.ops.aten.permute.default(view_2413, [1, 0]) + mm_277 = torch.ops.aten.mm.default(permute_473, view_2031); permute_473 = None + convert_element_type_934 = torch.ops.prims.convert_element_type.default(primals_259, torch.bfloat16); primals_259 = None + all_gather_into_tensor_313 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_934, 32, '0'); convert_element_type_934 = None + wait_tensor_370 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_313); all_gather_into_tensor_313 = None + permute_310 = torch.ops.aten.permute.default(wait_tensor_370, [1, 0]); wait_tensor_370 = None + permute_475 = torch.ops.aten.permute.default(permute_310, [1, 0]); permute_310 = None + mm_278 = torch.ops.aten.mm.default(view_2413, permute_475); view_2413 = permute_475 = None + view_2414 = torch.ops.aten.view.default(mm_278, [2, 8192, 4096]); mm_278 = None + convert_element_type_1273 = torch.ops.prims.convert_element_type.default(mm_277, torch.float32); mm_277 = None + reduce_scatter_tensor_107 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1273, 'avg', 32, '0'); convert_element_type_1273 = None + wait_tensor_479 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_107); reduce_scatter_tensor_107 = None + view_2415 = torch.ops.aten.view.default(view_2411, [16384, 128]); view_2411 = None + permute_477 = torch.ops.aten.permute.default(view_2415, [1, 0]) + mm_279 = torch.ops.aten.mm.default(permute_477, view_2031); permute_477 = None + permute_479 = torch.ops.aten.permute.default(permute_309, [1, 0]); permute_309 = None + mm_280 = torch.ops.aten.mm.default(view_2415, permute_479); view_2415 = permute_479 = None + view_2416 = torch.ops.aten.view.default(mm_280, [2, 8192, 4096]); mm_280 
= None + add_154 = torch.ops.aten.add.Tensor(view_2414, view_2416); view_2414 = view_2416 = None + convert_element_type_1278 = torch.ops.prims.convert_element_type.default(mm_279, torch.float32); mm_279 = None + reduce_scatter_tensor_108 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1278, 'avg', 32, '0'); convert_element_type_1278 = None + wait_tensor_480 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_108); reduce_scatter_tensor_108 = None + view_2417 = torch.ops.aten.view.default(view_2412, [16384, 512]); view_2412 = None + permute_481 = torch.ops.aten.permute.default(view_2417, [1, 0]) + mm_281 = torch.ops.aten.mm.default(permute_481, view_2031); permute_481 = view_2031 = None + convert_element_type_928 = torch.ops.prims.convert_element_type.default(primals_257, torch.bfloat16); primals_257 = None + all_gather_into_tensor_311 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_928, 32, '0'); convert_element_type_928 = None + wait_tensor_368 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_311); all_gather_into_tensor_311 = None + permute_308 = torch.ops.aten.permute.default(wait_tensor_368, [1, 0]); wait_tensor_368 = None + permute_483 = torch.ops.aten.permute.default(permute_308, [1, 0]); permute_308 = None + mm_282 = torch.ops.aten.mm.default(view_2417, permute_483); view_2417 = permute_483 = None + view_2418 = torch.ops.aten.view.default(mm_282, [2, 8192, 4096]); mm_282 = None + add_155 = torch.ops.aten.add.Tensor(add_154, view_2418); add_154 = view_2418 = None + convert_element_type_1283 = torch.ops.prims.convert_element_type.default(mm_281, torch.float32); mm_281 = None + reduce_scatter_tensor_109 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1283, 'avg', 32, '0'); convert_element_type_1283 = None + wait_tensor_481 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_109); reduce_scatter_tensor_109 
= None + split_154 = torch.ops.aten.split.Tensor(add_155, 1024, 1); add_155 = None + getitem_1532 = split_154[0] + getitem_1533 = split_154[1] + getitem_1534 = split_154[2] + getitem_1535 = split_154[3] + getitem_1536 = split_154[4] + getitem_1537 = split_154[5] + getitem_1538 = split_154[6] + getitem_1539 = split_154[7]; split_154 = None + cat_146 = torch.ops.aten.cat.default([getitem_1532, getitem_1533, getitem_1534, getitem_1535, getitem_1536, getitem_1537, getitem_1538, getitem_1539]); getitem_1532 = getitem_1533 = getitem_1534 = getitem_1535 = getitem_1536 = getitem_1537 = getitem_1538 = getitem_1539 = None + reduce_scatter_tensor_110 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_146, 'sum', 8, '1'); cat_146 = None + wait_tensor_482 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_110); reduce_scatter_tensor_110 = None + convert_element_type_1284 = torch.ops.prims.convert_element_type.default(wait_tensor_482, torch.float32); wait_tensor_482 = None + convert_element_type_1286 = torch.ops.prims.convert_element_type.default(wait_tensor_366, torch.float32); wait_tensor_366 = None + mul_338 = torch.ops.aten.mul.Tensor(convert_element_type_1284, convert_element_type_1286); convert_element_type_1286 = None + mul_340 = torch.ops.aten.mul.Tensor(mul_224, mul_338) + sum_25 = torch.ops.aten.sum.dim_IntList(mul_340, [2], True); mul_340 = None + div_8 = torch.ops.aten.div.Tensor(mul_224, 4096) + mul_341 = torch.ops.aten.mul.Tensor(div_8, sum_25); div_8 = sum_25 = None + sub_13 = torch.ops.aten.sub.Tensor(mul_338, mul_341); mul_338 = mul_341 = None + mul_342 = torch.ops.aten.mul.Tensor(sub_13, rsqrt_56); sub_13 = rsqrt_56 = None + mul_343 = torch.ops.aten.mul.Tensor(convert_element_type_1284, mul_224); convert_element_type_1284 = mul_224 = None + sum_26 = torch.ops.aten.sum.dim_IntList(mul_343, [0, 1]); mul_343 = None + convert_element_type_1287 = torch.ops.prims.convert_element_type.default(mul_342, torch.bfloat16); mul_342 = None + 
convert_element_type_1288 = torch.ops.prims.convert_element_type.default(sum_26, torch.bfloat16); sum_26 = None + all_reduce_8 = torch.ops._c10d_functional.all_reduce.default(convert_element_type_1288, 'sum', '1'); convert_element_type_1288 = None + wait_tensor_483 = torch.ops._c10d_functional.wait_tensor.default(all_reduce_8); all_reduce_8 = None + convert_element_type_1289 = torch.ops.prims.convert_element_type.default(wait_tensor_483, torch.float32); wait_tensor_483 = None + reduce_scatter_tensor_111 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1289, 'avg', 32, '0'); convert_element_type_1289 = None + wait_tensor_484 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_111); reduce_scatter_tensor_111 = None + add_156 = torch.ops.aten.add.Tensor(add_153, convert_element_type_1287); add_153 = convert_element_type_1287 = None + all_gather_into_tensor_364 = torch.ops._c10d_functional.all_gather_into_tensor.default(add_156, 8, '1') + wait_tensor_485 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_364); all_gather_into_tensor_364 = None + split_155 = torch.ops.aten.split.Tensor(wait_tensor_485, 2); wait_tensor_485 = None + getitem_1540 = split_155[0] + getitem_1541 = split_155[1] + getitem_1542 = split_155[2] + getitem_1543 = split_155[3] + getitem_1544 = split_155[4] + getitem_1545 = split_155[5] + getitem_1546 = split_155[6] + getitem_1547 = split_155[7]; split_155 = None + cat_147 = torch.ops.aten.cat.default([getitem_1540, getitem_1541, getitem_1542, getitem_1543, getitem_1544, getitem_1545, getitem_1546, getitem_1547], 1); getitem_1540 = getitem_1541 = getitem_1542 = getitem_1543 = getitem_1544 = getitem_1545 = getitem_1546 = getitem_1547 = None + view_2419 = torch.ops.aten.view.default(cat_147, [16384, 4096]); cat_147 = None + permute_485 = torch.ops.aten.permute.default(view_2419, [1, 0]) + wait_tensor_359 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_55); 
reduce_scatter_tensor_55 = None + add_109 = torch.ops.aten.add.Tensor(add_107, wait_tensor_359); wait_tensor_359 = None + convert_element_type_911 = torch.ops.prims.convert_element_type.default(primals_252, torch.bfloat16); primals_252 = None + all_gather_into_tensor_304 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_911, 32, '0'); convert_element_type_911 = None + wait_tensor_360 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_304); all_gather_into_tensor_304 = None + convert_element_type_912 = torch.ops.prims.convert_element_type.default(add_109, torch.float32); add_109 = None + pow_56 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_912, 2) + mean_55 = torch.ops.aten.mean.dim(pow_56, [2], True); pow_56 = None + add_110 = torch.ops.aten.add.Scalar(mean_55, 1e-05); mean_55 = None + rsqrt_55 = torch.ops.aten.rsqrt.default(add_110); add_110 = None + mul_220 = torch.ops.aten.mul.Tensor(convert_element_type_912, rsqrt_55); convert_element_type_912 = None + mul_221 = torch.ops.aten.mul.Tensor(mul_220, wait_tensor_360) + convert_element_type_913 = torch.ops.prims.convert_element_type.default(mul_221, torch.bfloat16); mul_221 = None + all_gather_into_tensor_305 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_913, 8, '1'); convert_element_type_913 = None + wait_tensor_361 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_305); all_gather_into_tensor_305 = None + split_119 = torch.ops.aten.split.Tensor(wait_tensor_361, 2); wait_tensor_361 = None + getitem_1204 = split_119[0] + getitem_1205 = split_119[1] + getitem_1206 = split_119[2] + getitem_1207 = split_119[3] + getitem_1208 = split_119[4] + getitem_1209 = split_119[5] + getitem_1210 = split_119[6] + getitem_1211 = split_119[7]; split_119 = None + cat_111 = torch.ops.aten.cat.default([getitem_1204, getitem_1205, getitem_1206, getitem_1207, getitem_1208, getitem_1209, getitem_1210, getitem_1211], 1); 
getitem_1204 = getitem_1205 = getitem_1206 = getitem_1207 = getitem_1208 = getitem_1209 = getitem_1210 = getitem_1211 = None + view_2004 = torch.ops.aten.view.default(cat_111, [16384, 4096]); cat_111 = None + view_2005 = torch.ops.aten.view.default(mm_193, [2, 8192, 1792]); mm_193 = None + convert_element_type_917 = torch.ops.prims.convert_element_type.default(view_2005, torch.float32); view_2005 = None + sigmoid_27 = torch.ops.aten.sigmoid.default(convert_element_type_917) + mul_222 = torch.ops.aten.mul.Tensor(convert_element_type_917, sigmoid_27); sigmoid_27 = None + convert_element_type_918 = torch.ops.prims.convert_element_type.default(mul_222, torch.bfloat16); mul_222 = None + convert_element_type_919 = torch.ops.prims.convert_element_type.default(primals_254, torch.bfloat16); primals_254 = None + all_gather_into_tensor_307 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_919, 32, '0'); convert_element_type_919 = None + wait_tensor_363 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_307); all_gather_into_tensor_307 = None + permute_306 = torch.ops.aten.permute.default(wait_tensor_363, [1, 0]); wait_tensor_363 = None + mm_194 = torch.ops.aten.mm.default(view_2004, permute_306) + view_2012 = torch.ops.aten.view.default(mm_194, [2, 8192, 1792]); mm_194 = None + mul_223 = torch.ops.aten.mul.Tensor(convert_element_type_918, view_2012) + view_2019 = torch.ops.aten.view.default(mul_223, [16384, 1792]); mul_223 = None + mm_283 = torch.ops.aten.mm.default(permute_485, view_2019); permute_485 = view_2019 = None + convert_element_type_922 = torch.ops.prims.convert_element_type.default(primals_255, torch.bfloat16); primals_255 = None + all_gather_into_tensor_308 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_922, 32, '0'); convert_element_type_922 = None + wait_tensor_364 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_308); all_gather_into_tensor_308 = None + 
permute_307 = torch.ops.aten.permute.default(wait_tensor_364, [1, 0]); wait_tensor_364 = None + permute_487 = torch.ops.aten.permute.default(permute_307, [1, 0]); permute_307 = None + mm_284 = torch.ops.aten.mm.default(view_2419, permute_487); view_2419 = permute_487 = None + view_2420 = torch.ops.aten.view.default(mm_284, [2, 8192, 1792]); mm_284 = None + convert_element_type_1294 = torch.ops.prims.convert_element_type.default(mm_283, torch.float32); mm_283 = None + reduce_scatter_tensor_112 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1294, 'avg', 32, '0'); convert_element_type_1294 = None + wait_tensor_486 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_112); reduce_scatter_tensor_112 = None + mul_344 = torch.ops.aten.mul.Tensor(view_2420, convert_element_type_918); convert_element_type_918 = None + mul_345 = torch.ops.aten.mul.Tensor(view_2420, view_2012); view_2420 = view_2012 = None + view_2421 = torch.ops.aten.view.default(mul_344, [16384, 1792]); mul_344 = None + permute_489 = torch.ops.aten.permute.default(view_2421, [1, 0]) + mm_285 = torch.ops.aten.mm.default(permute_489, view_2004); permute_489 = None + permute_491 = torch.ops.aten.permute.default(permute_306, [1, 0]); permute_306 = None + mm_286 = torch.ops.aten.mm.default(view_2421, permute_491); view_2421 = permute_491 = None + view_2422 = torch.ops.aten.view.default(mm_286, [2, 8192, 4096]); mm_286 = None + convert_element_type_1299 = torch.ops.prims.convert_element_type.default(mm_285, torch.float32); mm_285 = None + reduce_scatter_tensor_113 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1299, 'avg', 32, '0'); convert_element_type_1299 = None + wait_tensor_487 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_113); reduce_scatter_tensor_113 = None + convert_element_type_1300 = torch.ops.prims.convert_element_type.default(mul_345, torch.float32); mul_345 = None + neg_4 = 
torch.ops.aten.neg.default(convert_element_type_917) + exp_4 = torch.ops.aten.exp.default(neg_4); neg_4 = None + add_157 = torch.ops.aten.add.Tensor(exp_4, 1); exp_4 = None + reciprocal_4 = torch.ops.aten.reciprocal.default(add_157); add_157 = None + mul_346 = torch.ops.aten.mul.Tensor(reciprocal_4, 1); reciprocal_4 = None + mul_347 = torch.ops.aten.mul.Tensor(convert_element_type_1300, mul_346); convert_element_type_1300 = None + sub_14 = torch.ops.aten.sub.Tensor(1, mul_346); mul_346 = None + mul_348 = torch.ops.aten.mul.Tensor(convert_element_type_917, sub_14); convert_element_type_917 = sub_14 = None + add_158 = torch.ops.aten.add.Tensor(mul_348, 1); mul_348 = None + mul_349 = torch.ops.aten.mul.Tensor(mul_347, add_158); mul_347 = add_158 = None + convert_element_type_1302 = torch.ops.prims.convert_element_type.default(mul_349, torch.bfloat16); mul_349 = None + view_2423 = torch.ops.aten.view.default(convert_element_type_1302, [16384, 1792]); convert_element_type_1302 = None + permute_493 = torch.ops.aten.permute.default(view_2423, [1, 0]) + mm_287 = torch.ops.aten.mm.default(permute_493, view_2004); permute_493 = view_2004 = None + convert_element_type_914 = torch.ops.prims.convert_element_type.default(primals_253, torch.bfloat16); primals_253 = None + all_gather_into_tensor_306 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_914, 32, '0'); convert_element_type_914 = None + wait_tensor_362 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_306); all_gather_into_tensor_306 = None + permute_305 = torch.ops.aten.permute.default(wait_tensor_362, [1, 0]); wait_tensor_362 = None + permute_495 = torch.ops.aten.permute.default(permute_305, [1, 0]); permute_305 = None + mm_288 = torch.ops.aten.mm.default(view_2423, permute_495); view_2423 = permute_495 = None + view_2424 = torch.ops.aten.view.default(mm_288, [2, 8192, 4096]); mm_288 = None + add_159 = torch.ops.aten.add.Tensor(view_2422, view_2424); view_2422 = 
view_2424 = None + convert_element_type_1307 = torch.ops.prims.convert_element_type.default(mm_287, torch.float32); mm_287 = None + reduce_scatter_tensor_114 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1307, 'avg', 32, '0'); convert_element_type_1307 = None + wait_tensor_488 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_114); reduce_scatter_tensor_114 = None + split_156 = torch.ops.aten.split.Tensor(add_159, 1024, 1); add_159 = None + getitem_1548 = split_156[0] + getitem_1549 = split_156[1] + getitem_1550 = split_156[2] + getitem_1551 = split_156[3] + getitem_1552 = split_156[4] + getitem_1553 = split_156[5] + getitem_1554 = split_156[6] + getitem_1555 = split_156[7]; split_156 = None + cat_148 = torch.ops.aten.cat.default([getitem_1548, getitem_1549, getitem_1550, getitem_1551, getitem_1552, getitem_1553, getitem_1554, getitem_1555]); getitem_1548 = getitem_1549 = getitem_1550 = getitem_1551 = getitem_1552 = getitem_1553 = getitem_1554 = getitem_1555 = None + reduce_scatter_tensor_115 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_148, 'sum', 8, '1'); cat_148 = None + wait_tensor_489 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_115); reduce_scatter_tensor_115 = None + convert_element_type_1308 = torch.ops.prims.convert_element_type.default(wait_tensor_489, torch.float32); wait_tensor_489 = None + convert_element_type_1310 = torch.ops.prims.convert_element_type.default(wait_tensor_360, torch.float32); wait_tensor_360 = None + mul_350 = torch.ops.aten.mul.Tensor(convert_element_type_1308, convert_element_type_1310); convert_element_type_1310 = None + mul_352 = torch.ops.aten.mul.Tensor(mul_220, mul_350) + sum_27 = torch.ops.aten.sum.dim_IntList(mul_352, [2], True); mul_352 = None + div_9 = torch.ops.aten.div.Tensor(mul_220, 4096) + mul_353 = torch.ops.aten.mul.Tensor(div_9, sum_27); div_9 = sum_27 = None + sub_15 = torch.ops.aten.sub.Tensor(mul_350, mul_353); 
mul_350 = mul_353 = None + mul_354 = torch.ops.aten.mul.Tensor(sub_15, rsqrt_55); sub_15 = rsqrt_55 = None + mul_355 = torch.ops.aten.mul.Tensor(convert_element_type_1308, mul_220); convert_element_type_1308 = mul_220 = None + sum_28 = torch.ops.aten.sum.dim_IntList(mul_355, [0, 1]); mul_355 = None + convert_element_type_1311 = torch.ops.prims.convert_element_type.default(mul_354, torch.bfloat16); mul_354 = None + convert_element_type_1312 = torch.ops.prims.convert_element_type.default(sum_28, torch.bfloat16); sum_28 = None + all_reduce_9 = torch.ops._c10d_functional.all_reduce.default(convert_element_type_1312, 'sum', '1'); convert_element_type_1312 = None + wait_tensor_490 = torch.ops._c10d_functional.wait_tensor.default(all_reduce_9); all_reduce_9 = None + convert_element_type_1313 = torch.ops.prims.convert_element_type.default(wait_tensor_490, torch.float32); wait_tensor_490 = None + reduce_scatter_tensor_116 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1313, 'avg', 32, '0'); convert_element_type_1313 = None + wait_tensor_491 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_116); reduce_scatter_tensor_116 = None + add_160 = torch.ops.aten.add.Tensor(add_156, convert_element_type_1311); add_156 = convert_element_type_1311 = None + all_gather_into_tensor_365 = torch.ops._c10d_functional.all_gather_into_tensor.default(add_160, 8, '1') + wait_tensor_492 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_365); all_gather_into_tensor_365 = None + split_157 = torch.ops.aten.split.Tensor(wait_tensor_492, 2); wait_tensor_492 = None + getitem_1556 = split_157[0] + getitem_1557 = split_157[1] + getitem_1558 = split_157[2] + getitem_1559 = split_157[3] + getitem_1560 = split_157[4] + getitem_1561 = split_157[5] + getitem_1562 = split_157[6] + getitem_1563 = split_157[7]; split_157 = None + cat_149 = torch.ops.aten.cat.default([getitem_1556, getitem_1557, getitem_1558, getitem_1559, getitem_1560, 
getitem_1561, getitem_1562, getitem_1563], 1); getitem_1556 = getitem_1557 = getitem_1558 = getitem_1559 = getitem_1560 = getitem_1561 = getitem_1562 = getitem_1563 = None + view_2425 = torch.ops.aten.view.default(cat_149, [16384, 4096]); cat_149 = None + permute_497 = torch.ops.aten.permute.default(view_2425, [1, 0]) + permute_303 = torch.ops.aten.permute.default(getitem_1187, [0, 2, 1, 3]) + view_1986 = torch.ops.aten.view.default(permute_303, [2, 8192, -1]); permute_303 = None + view_1992 = torch.ops.aten.view.default(view_1986, [16384, 512]); view_1986 = None + mm_289 = torch.ops.aten.mm.default(permute_497, view_1992); permute_497 = view_1992 = None + convert_element_type_908 = torch.ops.prims.convert_element_type.default(primals_251, torch.bfloat16); primals_251 = None + all_gather_into_tensor_303 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_908, 32, '0'); convert_element_type_908 = None + wait_tensor_358 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_303); all_gather_into_tensor_303 = None + permute_304 = torch.ops.aten.permute.default(wait_tensor_358, [1, 0]); wait_tensor_358 = None + permute_499 = torch.ops.aten.permute.default(permute_304, [1, 0]); permute_304 = None + mm_290 = torch.ops.aten.mm.default(view_2425, permute_499); view_2425 = permute_499 = None + view_2426 = torch.ops.aten.view.default(mm_290, [2, 8192, 512]); mm_290 = None + convert_element_type_1318 = torch.ops.prims.convert_element_type.default(mm_289, torch.float32); mm_289 = None + reduce_scatter_tensor_117 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1318, 'avg', 32, '0'); convert_element_type_1318 = None + wait_tensor_493 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_117); reduce_scatter_tensor_117 = None + view_2427 = torch.ops.aten.view.default(view_2426, [2, 8192, 4, 128]); view_2426 = None + permute_501 = torch.ops.aten.permute.default(view_2427, [0, 2, 1, 3]); 
view_2427 = None + convert_element_type_892 = torch.ops.prims.convert_element_type.default(primals_247, torch.bfloat16); primals_247 = None + all_gather_into_tensor_298 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_892, 32, '0'); convert_element_type_892 = None + wait_tensor_353 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_298); all_gather_into_tensor_298 = None + convert_element_type_893 = torch.ops.prims.convert_element_type.default(add_107, torch.float32); add_107 = None + pow_55 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_893, 2) + mean_54 = torch.ops.aten.mean.dim(pow_55, [2], True); pow_55 = None + add_108 = torch.ops.aten.add.Scalar(mean_54, 1e-05); mean_54 = None + rsqrt_54 = torch.ops.aten.rsqrt.default(add_108); add_108 = None + mul_216 = torch.ops.aten.mul.Tensor(convert_element_type_893, rsqrt_54); convert_element_type_893 = None + mul_217 = torch.ops.aten.mul.Tensor(mul_216, wait_tensor_353) + convert_element_type_894 = torch.ops.prims.convert_element_type.default(mul_217, torch.bfloat16); mul_217 = None + all_gather_into_tensor_299 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_894, 8, '1'); convert_element_type_894 = None + wait_tensor_354 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_299); all_gather_into_tensor_299 = None + split_117 = torch.ops.aten.split.Tensor(wait_tensor_354, 2); wait_tensor_354 = None + getitem_1179 = split_117[0] + getitem_1180 = split_117[1] + getitem_1181 = split_117[2] + getitem_1182 = split_117[3] + getitem_1183 = split_117[4] + getitem_1184 = split_117[5] + getitem_1185 = split_117[6] + getitem_1186 = split_117[7]; split_117 = None + cat_109 = torch.ops.aten.cat.default([getitem_1179, getitem_1180, getitem_1181, getitem_1182, getitem_1183, getitem_1184, getitem_1185, getitem_1186], 1); getitem_1179 = getitem_1180 = getitem_1181 = getitem_1182 = getitem_1183 = getitem_1184 = getitem_1185 = 
getitem_1186 = None + view_1959 = torch.ops.aten.view.default(cat_109, [16384, 4096]); cat_109 = None + view_1960 = torch.ops.aten.view.default(mm_189, [2, 8192, 512]); mm_189 = None + convert_element_type_898 = torch.ops.prims.convert_element_type.default(primals_249, torch.bfloat16); primals_249 = None + all_gather_into_tensor_301 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_898, 32, '0'); convert_element_type_898 = None + wait_tensor_356 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_301); all_gather_into_tensor_301 = None + permute_298 = torch.ops.aten.permute.default(wait_tensor_356, [1, 0]); wait_tensor_356 = None + mm_190 = torch.ops.aten.mm.default(view_1959, permute_298) + view_1967 = torch.ops.aten.view.default(mm_190, [2, 8192, 128]); mm_190 = None + view_1974 = torch.ops.aten.view.default(mm_191, [2, 8192, 128]); mm_191 = None + view_1976 = torch.ops.aten.view.default(view_1960, [2, 8192, -1, 128]); view_1960 = None + view_1977 = torch.ops.aten.view.default(view_1967, [2, 8192, -1, 128]); view_1967 = None + view_1978 = torch.ops.aten.view.default(view_1974, [2, 8192, -1, 128]); view_1974 = None + convert_element_type_904 = torch.ops.prims.convert_element_type.default(view_1976, torch.float32); view_1976 = None + view_1979 = torch.ops.aten.view.default(convert_element_type_904, [2, 8192, 4, -1, 2]); convert_element_type_904 = None + view_as_complex_54 = torch.ops.aten.view_as_complex.default(view_1979); view_1979 = None + convert_element_type_905 = torch.ops.prims.convert_element_type.default(view_1977, torch.float32); view_1977 = None + view_1980 = torch.ops.aten.view.default(convert_element_type_905, [2, 8192, 1, -1, 2]); convert_element_type_905 = None + view_as_complex_55 = torch.ops.aten.view_as_complex.default(view_1980); view_1980 = None + mul_218 = torch.ops.aten.mul.Tensor(view_as_complex_54, view_37); view_as_complex_54 = None + view_as_real_54 = 
torch.ops.aten.view_as_real.default(mul_218); mul_218 = None + view_1982 = torch.ops.aten.view.default(view_as_real_54, [2, 8192, 4, 128]); view_as_real_54 = None + mul_219 = torch.ops.aten.mul.Tensor(view_as_complex_55, view_37); view_as_complex_55 = None + view_as_real_55 = torch.ops.aten.view_as_real.default(mul_219); mul_219 = None + view_1983 = torch.ops.aten.view.default(view_as_real_55, [2, 8192, 1, 128]); view_as_real_55 = None + convert_element_type_906 = torch.ops.prims.convert_element_type.default(view_1982, torch.bfloat16); view_1982 = None + convert_element_type_907 = torch.ops.prims.convert_element_type.default(view_1983, torch.bfloat16); view_1983 = None + unsqueeze_54 = torch.ops.aten.unsqueeze.default(convert_element_type_907, 3); convert_element_type_907 = None + expand_54 = torch.ops.aten.expand.default(unsqueeze_54, [2, 8192, 1, 4, 128]); unsqueeze_54 = None + view_1984 = torch.ops.aten.view.default(expand_54, [2, 8192, 4, 128]); expand_54 = None + unsqueeze_55 = torch.ops.aten.unsqueeze.default(view_1978, 3); view_1978 = None + expand_55 = torch.ops.aten.expand.default(unsqueeze_55, [2, 8192, 1, 4, 128]); unsqueeze_55 = None + view_1985 = torch.ops.aten.view.default(expand_55, [2, 8192, 4, 128]); expand_55 = None + permute_300 = torch.ops.aten.permute.default(convert_element_type_906, [0, 2, 1, 3]); convert_element_type_906 = None + permute_301 = torch.ops.aten.permute.default(view_1984, [0, 2, 1, 3]); view_1984 = None + permute_302 = torch.ops.aten.permute.default(view_1985, [0, 2, 1, 3]); view_1985 = None + _scaled_dot_product_cudnn_attention_backward_4 = torch.ops.aten._scaled_dot_product_cudnn_attention_backward.default(permute_501, permute_300, permute_301, permute_302, getitem_1187, getitem_1188, getitem_1193, getitem_1194, None, None, None, 8192, 8192, 0.0, True); permute_501 = permute_300 = permute_301 = permute_302 = getitem_1187 = getitem_1188 = getitem_1193 = getitem_1194 = None + getitem_1564 = 
_scaled_dot_product_cudnn_attention_backward_4[0] + getitem_1565 = _scaled_dot_product_cudnn_attention_backward_4[1] + getitem_1566 = _scaled_dot_product_cudnn_attention_backward_4[2]; _scaled_dot_product_cudnn_attention_backward_4 = None + permute_502 = torch.ops.aten.permute.default(getitem_1566, [0, 2, 1, 3]); getitem_1566 = None + permute_503 = torch.ops.aten.permute.default(getitem_1565, [0, 2, 1, 3]); getitem_1565 = None + permute_504 = torch.ops.aten.permute.default(getitem_1564, [0, 2, 1, 3]); getitem_1564 = None + view_2428 = torch.ops.aten.view.default(permute_502, [2, 8192, 1, 4, 128]); permute_502 = None + sum_29 = torch.ops.aten.sum.dim_IntList(view_2428, [3], True); view_2428 = None + squeeze_8 = torch.ops.aten.squeeze.dim(sum_29, 3); sum_29 = None + view_2429 = torch.ops.aten.view.default(permute_503, [2, 8192, 1, 4, 128]); permute_503 = None + sum_30 = torch.ops.aten.sum.dim_IntList(view_2429, [3], True); view_2429 = None + squeeze_9 = torch.ops.aten.squeeze.dim(sum_30, 3); sum_30 = None + convert_element_type_1319 = torch.ops.prims.convert_element_type.default(squeeze_9, torch.float32); squeeze_9 = None + convert_element_type_1320 = torch.ops.prims.convert_element_type.default(permute_504, torch.float32); permute_504 = None + view_2430 = torch.ops.aten.view.default(convert_element_type_1319, [2, 8192, 1, 64, 2]); convert_element_type_1319 = None + view_as_complex_72 = torch.ops.aten.view_as_complex.default(view_2430); view_2430 = None + mul_356 = torch.ops.aten.mul.Tensor(view_as_complex_72, _conj); view_as_complex_72 = None + view_2431 = torch.ops.aten.view.default(convert_element_type_1320, [2, 8192, 4, 64, 2]); convert_element_type_1320 = None + view_as_complex_73 = torch.ops.aten.view_as_complex.default(view_2431); view_2431 = None + mul_357 = torch.ops.aten.mul.Tensor(view_as_complex_73, _conj); view_as_complex_73 = None + view_as_real_72 = torch.ops.aten.view_as_real.default(mul_356); mul_356 = None + view_2432 = 
torch.ops.aten.view.default(view_as_real_72, [2, 8192, 1, 128]); view_as_real_72 = None + convert_element_type_1321 = torch.ops.prims.convert_element_type.default(view_2432, torch.bfloat16); view_2432 = None + view_as_real_73 = torch.ops.aten.view_as_real.default(mul_357); mul_357 = None + view_2433 = torch.ops.aten.view.default(view_as_real_73, [2, 8192, 4, 128]); view_as_real_73 = None + convert_element_type_1322 = torch.ops.prims.convert_element_type.default(view_2433, torch.bfloat16); view_2433 = None + view_2434 = torch.ops.aten.view.default(squeeze_8, [2, 8192, 128]); squeeze_8 = None + view_2435 = torch.ops.aten.view.default(convert_element_type_1321, [2, 8192, 128]); convert_element_type_1321 = None + view_2436 = torch.ops.aten.view.default(convert_element_type_1322, [2, 8192, 512]); convert_element_type_1322 = None + view_2437 = torch.ops.aten.view.default(view_2434, [16384, 128]); view_2434 = None + permute_505 = torch.ops.aten.permute.default(view_2437, [1, 0]) + mm_291 = torch.ops.aten.mm.default(permute_505, view_1959); permute_505 = None + convert_element_type_901 = torch.ops.prims.convert_element_type.default(primals_250, torch.bfloat16); primals_250 = None + all_gather_into_tensor_302 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_901, 32, '0'); convert_element_type_901 = None + wait_tensor_357 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_302); all_gather_into_tensor_302 = None + permute_299 = torch.ops.aten.permute.default(wait_tensor_357, [1, 0]); wait_tensor_357 = None + permute_507 = torch.ops.aten.permute.default(permute_299, [1, 0]); permute_299 = None + mm_292 = torch.ops.aten.mm.default(view_2437, permute_507); view_2437 = permute_507 = None + view_2438 = torch.ops.aten.view.default(mm_292, [2, 8192, 4096]); mm_292 = None + convert_element_type_1327 = torch.ops.prims.convert_element_type.default(mm_291, torch.float32); mm_291 = None + reduce_scatter_tensor_118 = 
torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1327, 'avg', 32, '0'); convert_element_type_1327 = None + wait_tensor_494 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_118); reduce_scatter_tensor_118 = None + view_2439 = torch.ops.aten.view.default(view_2435, [16384, 128]); view_2435 = None + permute_509 = torch.ops.aten.permute.default(view_2439, [1, 0]) + mm_293 = torch.ops.aten.mm.default(permute_509, view_1959); permute_509 = None + permute_511 = torch.ops.aten.permute.default(permute_298, [1, 0]); permute_298 = None + mm_294 = torch.ops.aten.mm.default(view_2439, permute_511); view_2439 = permute_511 = None + view_2440 = torch.ops.aten.view.default(mm_294, [2, 8192, 4096]); mm_294 = None + add_161 = torch.ops.aten.add.Tensor(view_2438, view_2440); view_2438 = view_2440 = None + convert_element_type_1332 = torch.ops.prims.convert_element_type.default(mm_293, torch.float32); mm_293 = None + reduce_scatter_tensor_119 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1332, 'avg', 32, '0'); convert_element_type_1332 = None + wait_tensor_495 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_119); reduce_scatter_tensor_119 = None + view_2441 = torch.ops.aten.view.default(view_2436, [16384, 512]); view_2436 = None + permute_513 = torch.ops.aten.permute.default(view_2441, [1, 0]) + mm_295 = torch.ops.aten.mm.default(permute_513, view_1959); permute_513 = view_1959 = None + convert_element_type_895 = torch.ops.prims.convert_element_type.default(primals_248, torch.bfloat16); primals_248 = None + all_gather_into_tensor_300 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_895, 32, '0'); convert_element_type_895 = None + wait_tensor_355 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_300); all_gather_into_tensor_300 = None + permute_297 = torch.ops.aten.permute.default(wait_tensor_355, [1, 0]); wait_tensor_355 = 
None + permute_515 = torch.ops.aten.permute.default(permute_297, [1, 0]); permute_297 = None + mm_296 = torch.ops.aten.mm.default(view_2441, permute_515); view_2441 = permute_515 = None + view_2442 = torch.ops.aten.view.default(mm_296, [2, 8192, 4096]); mm_296 = None + add_162 = torch.ops.aten.add.Tensor(add_161, view_2442); add_161 = view_2442 = None + convert_element_type_1337 = torch.ops.prims.convert_element_type.default(mm_295, torch.float32); mm_295 = None + reduce_scatter_tensor_120 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1337, 'avg', 32, '0'); convert_element_type_1337 = None + wait_tensor_496 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_120); reduce_scatter_tensor_120 = None + split_158 = torch.ops.aten.split.Tensor(add_162, 1024, 1); add_162 = None + getitem_1567 = split_158[0] + getitem_1568 = split_158[1] + getitem_1569 = split_158[2] + getitem_1570 = split_158[3] + getitem_1571 = split_158[4] + getitem_1572 = split_158[5] + getitem_1573 = split_158[6] + getitem_1574 = split_158[7]; split_158 = None + cat_150 = torch.ops.aten.cat.default([getitem_1567, getitem_1568, getitem_1569, getitem_1570, getitem_1571, getitem_1572, getitem_1573, getitem_1574]); getitem_1567 = getitem_1568 = getitem_1569 = getitem_1570 = getitem_1571 = getitem_1572 = getitem_1573 = getitem_1574 = None + reduce_scatter_tensor_121 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_150, 'sum', 8, '1'); cat_150 = None + wait_tensor_497 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_121); reduce_scatter_tensor_121 = None + convert_element_type_1338 = torch.ops.prims.convert_element_type.default(wait_tensor_497, torch.float32); wait_tensor_497 = None + convert_element_type_1340 = torch.ops.prims.convert_element_type.default(wait_tensor_353, torch.float32); wait_tensor_353 = None + mul_358 = torch.ops.aten.mul.Tensor(convert_element_type_1338, convert_element_type_1340); 
convert_element_type_1340 = None + mul_360 = torch.ops.aten.mul.Tensor(mul_216, mul_358) + sum_31 = torch.ops.aten.sum.dim_IntList(mul_360, [2], True); mul_360 = None + div_10 = torch.ops.aten.div.Tensor(mul_216, 4096) + mul_361 = torch.ops.aten.mul.Tensor(div_10, sum_31); div_10 = sum_31 = None + sub_16 = torch.ops.aten.sub.Tensor(mul_358, mul_361); mul_358 = mul_361 = None + mul_362 = torch.ops.aten.mul.Tensor(sub_16, rsqrt_54); sub_16 = rsqrt_54 = None + mul_363 = torch.ops.aten.mul.Tensor(convert_element_type_1338, mul_216); convert_element_type_1338 = mul_216 = None + sum_32 = torch.ops.aten.sum.dim_IntList(mul_363, [0, 1]); mul_363 = None + convert_element_type_1341 = torch.ops.prims.convert_element_type.default(mul_362, torch.bfloat16); mul_362 = None + convert_element_type_1342 = torch.ops.prims.convert_element_type.default(sum_32, torch.bfloat16); sum_32 = None + all_reduce_10 = torch.ops._c10d_functional.all_reduce.default(convert_element_type_1342, 'sum', '1'); convert_element_type_1342 = None + wait_tensor_498 = torch.ops._c10d_functional.wait_tensor.default(all_reduce_10); all_reduce_10 = None + convert_element_type_1343 = torch.ops.prims.convert_element_type.default(wait_tensor_498, torch.float32); wait_tensor_498 = None + reduce_scatter_tensor_122 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1343, 'avg', 32, '0'); convert_element_type_1343 = None + wait_tensor_499 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_122); reduce_scatter_tensor_122 = None + add_163 = torch.ops.aten.add.Tensor(add_160, convert_element_type_1341); add_160 = convert_element_type_1341 = None + all_gather_into_tensor_366 = torch.ops._c10d_functional.all_gather_into_tensor.default(add_163, 8, '1') + wait_tensor_500 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_366); all_gather_into_tensor_366 = None + split_159 = torch.ops.aten.split.Tensor(wait_tensor_500, 2); wait_tensor_500 = None + 
getitem_1575 = split_159[0] + getitem_1576 = split_159[1] + getitem_1577 = split_159[2] + getitem_1578 = split_159[3] + getitem_1579 = split_159[4] + getitem_1580 = split_159[5] + getitem_1581 = split_159[6] + getitem_1582 = split_159[7]; split_159 = None + cat_151 = torch.ops.aten.cat.default([getitem_1575, getitem_1576, getitem_1577, getitem_1578, getitem_1579, getitem_1580, getitem_1581, getitem_1582], 1); getitem_1575 = getitem_1576 = getitem_1577 = getitem_1578 = getitem_1579 = getitem_1580 = getitem_1581 = getitem_1582 = None + view_2443 = torch.ops.aten.view.default(cat_151, [16384, 4096]); cat_151 = None + permute_517 = torch.ops.aten.permute.default(view_2443, [1, 0]) + wait_tensor_346 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_53); reduce_scatter_tensor_53 = None + add_105 = torch.ops.aten.add.Tensor(add_103, wait_tensor_346); wait_tensor_346 = None + convert_element_type_878 = torch.ops.prims.convert_element_type.default(primals_243, torch.bfloat16); primals_243 = None + all_gather_into_tensor_293 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_878, 32, '0'); convert_element_type_878 = None + wait_tensor_347 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_293); all_gather_into_tensor_293 = None + convert_element_type_879 = torch.ops.prims.convert_element_type.default(add_105, torch.float32); add_105 = None + pow_54 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_879, 2) + mean_53 = torch.ops.aten.mean.dim(pow_54, [2], True); pow_54 = None + add_106 = torch.ops.aten.add.Scalar(mean_53, 1e-05); mean_53 = None + rsqrt_53 = torch.ops.aten.rsqrt.default(add_106); add_106 = None + mul_212 = torch.ops.aten.mul.Tensor(convert_element_type_879, rsqrt_53); convert_element_type_879 = None + mul_213 = torch.ops.aten.mul.Tensor(mul_212, wait_tensor_347) + convert_element_type_880 = torch.ops.prims.convert_element_type.default(mul_213, torch.bfloat16); mul_213 = None + 
all_gather_into_tensor_294 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_880, 8, '1'); convert_element_type_880 = None + wait_tensor_348 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_294); all_gather_into_tensor_294 = None + split_115 = torch.ops.aten.split.Tensor(wait_tensor_348, 2); wait_tensor_348 = None + getitem_1163 = split_115[0] + getitem_1164 = split_115[1] + getitem_1165 = split_115[2] + getitem_1166 = split_115[3] + getitem_1167 = split_115[4] + getitem_1168 = split_115[5] + getitem_1169 = split_115[6] + getitem_1170 = split_115[7]; split_115 = None + cat_107 = torch.ops.aten.cat.default([getitem_1163, getitem_1164, getitem_1165, getitem_1166, getitem_1167, getitem_1168, getitem_1169, getitem_1170], 1); getitem_1163 = getitem_1164 = getitem_1165 = getitem_1166 = getitem_1167 = getitem_1168 = getitem_1169 = getitem_1170 = None + view_1932 = torch.ops.aten.view.default(cat_107, [16384, 4096]); cat_107 = None + view_1933 = torch.ops.aten.view.default(mm_186, [2, 8192, 1792]); mm_186 = None + convert_element_type_884 = torch.ops.prims.convert_element_type.default(view_1933, torch.float32); view_1933 = None + sigmoid_26 = torch.ops.aten.sigmoid.default(convert_element_type_884) + mul_214 = torch.ops.aten.mul.Tensor(convert_element_type_884, sigmoid_26); sigmoid_26 = None + convert_element_type_885 = torch.ops.prims.convert_element_type.default(mul_214, torch.bfloat16); mul_214 = None + convert_element_type_886 = torch.ops.prims.convert_element_type.default(primals_245, torch.bfloat16); primals_245 = None + all_gather_into_tensor_296 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_886, 32, '0'); convert_element_type_886 = None + wait_tensor_350 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_296); all_gather_into_tensor_296 = None + permute_295 = torch.ops.aten.permute.default(wait_tensor_350, [1, 0]); wait_tensor_350 = None + mm_187 = 
torch.ops.aten.mm.default(view_1932, permute_295) + view_1940 = torch.ops.aten.view.default(mm_187, [2, 8192, 1792]); mm_187 = None + mul_215 = torch.ops.aten.mul.Tensor(convert_element_type_885, view_1940) + view_1947 = torch.ops.aten.view.default(mul_215, [16384, 1792]); mul_215 = None + mm_297 = torch.ops.aten.mm.default(permute_517, view_1947); permute_517 = view_1947 = None + convert_element_type_889 = torch.ops.prims.convert_element_type.default(primals_246, torch.bfloat16); primals_246 = None + all_gather_into_tensor_297 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_889, 32, '0'); convert_element_type_889 = None + wait_tensor_351 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_297); all_gather_into_tensor_297 = None + permute_296 = torch.ops.aten.permute.default(wait_tensor_351, [1, 0]); wait_tensor_351 = None + permute_519 = torch.ops.aten.permute.default(permute_296, [1, 0]); permute_296 = None + mm_298 = torch.ops.aten.mm.default(view_2443, permute_519); view_2443 = permute_519 = None + view_2444 = torch.ops.aten.view.default(mm_298, [2, 8192, 1792]); mm_298 = None + convert_element_type_1348 = torch.ops.prims.convert_element_type.default(mm_297, torch.float32); mm_297 = None + reduce_scatter_tensor_123 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1348, 'avg', 32, '0'); convert_element_type_1348 = None + wait_tensor_501 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_123); reduce_scatter_tensor_123 = None + mul_364 = torch.ops.aten.mul.Tensor(view_2444, convert_element_type_885); convert_element_type_885 = None + mul_365 = torch.ops.aten.mul.Tensor(view_2444, view_1940); view_2444 = view_1940 = None + view_2445 = torch.ops.aten.view.default(mul_364, [16384, 1792]); mul_364 = None + permute_521 = torch.ops.aten.permute.default(view_2445, [1, 0]) + mm_299 = torch.ops.aten.mm.default(permute_521, view_1932); permute_521 = None + permute_523 = 
torch.ops.aten.permute.default(permute_295, [1, 0]); permute_295 = None + mm_300 = torch.ops.aten.mm.default(view_2445, permute_523); view_2445 = permute_523 = None + view_2446 = torch.ops.aten.view.default(mm_300, [2, 8192, 4096]); mm_300 = None + convert_element_type_1353 = torch.ops.prims.convert_element_type.default(mm_299, torch.float32); mm_299 = None + reduce_scatter_tensor_124 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1353, 'avg', 32, '0'); convert_element_type_1353 = None + wait_tensor_502 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_124); reduce_scatter_tensor_124 = None + convert_element_type_1354 = torch.ops.prims.convert_element_type.default(mul_365, torch.float32); mul_365 = None + neg_5 = torch.ops.aten.neg.default(convert_element_type_884) + exp_5 = torch.ops.aten.exp.default(neg_5); neg_5 = None + add_164 = torch.ops.aten.add.Tensor(exp_5, 1); exp_5 = None + reciprocal_5 = torch.ops.aten.reciprocal.default(add_164); add_164 = None + mul_366 = torch.ops.aten.mul.Tensor(reciprocal_5, 1); reciprocal_5 = None + mul_367 = torch.ops.aten.mul.Tensor(convert_element_type_1354, mul_366); convert_element_type_1354 = None + sub_17 = torch.ops.aten.sub.Tensor(1, mul_366); mul_366 = None + mul_368 = torch.ops.aten.mul.Tensor(convert_element_type_884, sub_17); convert_element_type_884 = sub_17 = None + add_165 = torch.ops.aten.add.Tensor(mul_368, 1); mul_368 = None + mul_369 = torch.ops.aten.mul.Tensor(mul_367, add_165); mul_367 = add_165 = None + convert_element_type_1356 = torch.ops.prims.convert_element_type.default(mul_369, torch.bfloat16); mul_369 = None + view_2447 = torch.ops.aten.view.default(convert_element_type_1356, [16384, 1792]); convert_element_type_1356 = None + permute_525 = torch.ops.aten.permute.default(view_2447, [1, 0]) + mm_301 = torch.ops.aten.mm.default(permute_525, view_1932); permute_525 = view_1932 = None + convert_element_type_881 = 
torch.ops.prims.convert_element_type.default(primals_244, torch.bfloat16); primals_244 = None + all_gather_into_tensor_295 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_881, 32, '0'); convert_element_type_881 = None + wait_tensor_349 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_295); all_gather_into_tensor_295 = None + permute_294 = torch.ops.aten.permute.default(wait_tensor_349, [1, 0]); wait_tensor_349 = None + permute_527 = torch.ops.aten.permute.default(permute_294, [1, 0]); permute_294 = None + mm_302 = torch.ops.aten.mm.default(view_2447, permute_527); view_2447 = permute_527 = None + view_2448 = torch.ops.aten.view.default(mm_302, [2, 8192, 4096]); mm_302 = None + add_166 = torch.ops.aten.add.Tensor(view_2446, view_2448); view_2446 = view_2448 = None + convert_element_type_1361 = torch.ops.prims.convert_element_type.default(mm_301, torch.float32); mm_301 = None + reduce_scatter_tensor_125 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1361, 'avg', 32, '0'); convert_element_type_1361 = None + wait_tensor_503 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_125); reduce_scatter_tensor_125 = None + split_160 = torch.ops.aten.split.Tensor(add_166, 1024, 1); add_166 = None + getitem_1583 = split_160[0] + getitem_1584 = split_160[1] + getitem_1585 = split_160[2] + getitem_1586 = split_160[3] + getitem_1587 = split_160[4] + getitem_1588 = split_160[5] + getitem_1589 = split_160[6] + getitem_1590 = split_160[7]; split_160 = None + cat_152 = torch.ops.aten.cat.default([getitem_1583, getitem_1584, getitem_1585, getitem_1586, getitem_1587, getitem_1588, getitem_1589, getitem_1590]); getitem_1583 = getitem_1584 = getitem_1585 = getitem_1586 = getitem_1587 = getitem_1588 = getitem_1589 = getitem_1590 = None + reduce_scatter_tensor_126 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_152, 'sum', 8, '1'); cat_152 = None + wait_tensor_504 = 
torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_126); reduce_scatter_tensor_126 = None + convert_element_type_1362 = torch.ops.prims.convert_element_type.default(wait_tensor_504, torch.float32); wait_tensor_504 = None + convert_element_type_1364 = torch.ops.prims.convert_element_type.default(wait_tensor_347, torch.float32); wait_tensor_347 = None + mul_370 = torch.ops.aten.mul.Tensor(convert_element_type_1362, convert_element_type_1364); convert_element_type_1364 = None + mul_372 = torch.ops.aten.mul.Tensor(mul_212, mul_370) + sum_33 = torch.ops.aten.sum.dim_IntList(mul_372, [2], True); mul_372 = None + div_11 = torch.ops.aten.div.Tensor(mul_212, 4096) + mul_373 = torch.ops.aten.mul.Tensor(div_11, sum_33); div_11 = sum_33 = None + sub_18 = torch.ops.aten.sub.Tensor(mul_370, mul_373); mul_370 = mul_373 = None + mul_374 = torch.ops.aten.mul.Tensor(sub_18, rsqrt_53); sub_18 = rsqrt_53 = None + mul_375 = torch.ops.aten.mul.Tensor(convert_element_type_1362, mul_212); convert_element_type_1362 = mul_212 = None + sum_34 = torch.ops.aten.sum.dim_IntList(mul_375, [0, 1]); mul_375 = None + convert_element_type_1365 = torch.ops.prims.convert_element_type.default(mul_374, torch.bfloat16); mul_374 = None + convert_element_type_1366 = torch.ops.prims.convert_element_type.default(sum_34, torch.bfloat16); sum_34 = None + all_reduce_11 = torch.ops._c10d_functional.all_reduce.default(convert_element_type_1366, 'sum', '1'); convert_element_type_1366 = None + wait_tensor_505 = torch.ops._c10d_functional.wait_tensor.default(all_reduce_11); all_reduce_11 = None + convert_element_type_1367 = torch.ops.prims.convert_element_type.default(wait_tensor_505, torch.float32); wait_tensor_505 = None + reduce_scatter_tensor_127 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1367, 'avg', 32, '0'); convert_element_type_1367 = None + wait_tensor_506 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_127); 
reduce_scatter_tensor_127 = None + add_167 = torch.ops.aten.add.Tensor(add_163, convert_element_type_1365); add_163 = convert_element_type_1365 = None + all_gather_into_tensor_367 = torch.ops._c10d_functional.all_gather_into_tensor.default(add_167, 8, '1') + wait_tensor_507 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_367); all_gather_into_tensor_367 = None + split_161 = torch.ops.aten.split.Tensor(wait_tensor_507, 2); wait_tensor_507 = None + getitem_1591 = split_161[0] + getitem_1592 = split_161[1] + getitem_1593 = split_161[2] + getitem_1594 = split_161[3] + getitem_1595 = split_161[4] + getitem_1596 = split_161[5] + getitem_1597 = split_161[6] + getitem_1598 = split_161[7]; split_161 = None + cat_153 = torch.ops.aten.cat.default([getitem_1591, getitem_1592, getitem_1593, getitem_1594, getitem_1595, getitem_1596, getitem_1597, getitem_1598], 1); getitem_1591 = getitem_1592 = getitem_1593 = getitem_1594 = getitem_1595 = getitem_1596 = getitem_1597 = getitem_1598 = None + view_2449 = torch.ops.aten.view.default(cat_153, [16384, 4096]); cat_153 = None + permute_529 = torch.ops.aten.permute.default(view_2449, [1, 0]) + permute_292 = torch.ops.aten.permute.default(getitem_1146, [0, 2, 1, 3]) + view_1914 = torch.ops.aten.view.default(permute_292, [2, 8192, -1]); permute_292 = None + view_1920 = torch.ops.aten.view.default(view_1914, [16384, 512]); view_1914 = None + mm_303 = torch.ops.aten.mm.default(permute_529, view_1920); permute_529 = view_1920 = None + convert_element_type_875 = torch.ops.prims.convert_element_type.default(primals_242, torch.bfloat16); primals_242 = None + all_gather_into_tensor_292 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_875, 32, '0'); convert_element_type_875 = None + wait_tensor_345 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_292); all_gather_into_tensor_292 = None + permute_293 = torch.ops.aten.permute.default(wait_tensor_345, [1, 0]); 
wait_tensor_345 = None + permute_531 = torch.ops.aten.permute.default(permute_293, [1, 0]); permute_293 = None + mm_304 = torch.ops.aten.mm.default(view_2449, permute_531); view_2449 = permute_531 = None + view_2450 = torch.ops.aten.view.default(mm_304, [2, 8192, 512]); mm_304 = None + convert_element_type_1372 = torch.ops.prims.convert_element_type.default(mm_303, torch.float32); mm_303 = None + reduce_scatter_tensor_128 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1372, 'avg', 32, '0'); convert_element_type_1372 = None + wait_tensor_508 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_128); reduce_scatter_tensor_128 = None + view_2451 = torch.ops.aten.view.default(view_2450, [2, 8192, 4, 128]); view_2450 = None + permute_533 = torch.ops.aten.permute.default(view_2451, [0, 2, 1, 3]); view_2451 = None + convert_element_type_859 = torch.ops.prims.convert_element_type.default(primals_238, torch.bfloat16); primals_238 = None + all_gather_into_tensor_287 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_859, 32, '0'); convert_element_type_859 = None + wait_tensor_340 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_287); all_gather_into_tensor_287 = None + convert_element_type_860 = torch.ops.prims.convert_element_type.default(add_103, torch.float32); add_103 = None + pow_53 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_860, 2) + mean_52 = torch.ops.aten.mean.dim(pow_53, [2], True); pow_53 = None + add_104 = torch.ops.aten.add.Scalar(mean_52, 1e-05); mean_52 = None + rsqrt_52 = torch.ops.aten.rsqrt.default(add_104); add_104 = None + mul_208 = torch.ops.aten.mul.Tensor(convert_element_type_860, rsqrt_52); convert_element_type_860 = None + mul_209 = torch.ops.aten.mul.Tensor(mul_208, wait_tensor_340) + convert_element_type_861 = torch.ops.prims.convert_element_type.default(mul_209, torch.bfloat16); mul_209 = None + all_gather_into_tensor_288 = 
torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_861, 8, '1'); convert_element_type_861 = None + wait_tensor_341 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_288); all_gather_into_tensor_288 = None + split_113 = torch.ops.aten.split.Tensor(wait_tensor_341, 2); wait_tensor_341 = None + getitem_1138 = split_113[0] + getitem_1139 = split_113[1] + getitem_1140 = split_113[2] + getitem_1141 = split_113[3] + getitem_1142 = split_113[4] + getitem_1143 = split_113[5] + getitem_1144 = split_113[6] + getitem_1145 = split_113[7]; split_113 = None + cat_105 = torch.ops.aten.cat.default([getitem_1138, getitem_1139, getitem_1140, getitem_1141, getitem_1142, getitem_1143, getitem_1144, getitem_1145], 1); getitem_1138 = getitem_1139 = getitem_1140 = getitem_1141 = getitem_1142 = getitem_1143 = getitem_1144 = getitem_1145 = None + view_1887 = torch.ops.aten.view.default(cat_105, [16384, 4096]); cat_105 = None + view_1888 = torch.ops.aten.view.default(mm_182, [2, 8192, 512]); mm_182 = None + convert_element_type_865 = torch.ops.prims.convert_element_type.default(primals_240, torch.bfloat16); primals_240 = None + all_gather_into_tensor_290 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_865, 32, '0'); convert_element_type_865 = None + wait_tensor_343 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_290); all_gather_into_tensor_290 = None + permute_287 = torch.ops.aten.permute.default(wait_tensor_343, [1, 0]); wait_tensor_343 = None + mm_183 = torch.ops.aten.mm.default(view_1887, permute_287) + view_1895 = torch.ops.aten.view.default(mm_183, [2, 8192, 128]); mm_183 = None + view_1902 = torch.ops.aten.view.default(mm_184, [2, 8192, 128]); mm_184 = None + view_1904 = torch.ops.aten.view.default(view_1888, [2, 8192, -1, 128]); view_1888 = None + view_1905 = torch.ops.aten.view.default(view_1895, [2, 8192, -1, 128]); view_1895 = None + view_1906 = 
torch.ops.aten.view.default(view_1902, [2, 8192, -1, 128]); view_1902 = None + convert_element_type_871 = torch.ops.prims.convert_element_type.default(view_1904, torch.float32); view_1904 = None + view_1907 = torch.ops.aten.view.default(convert_element_type_871, [2, 8192, 4, -1, 2]); convert_element_type_871 = None + view_as_complex_52 = torch.ops.aten.view_as_complex.default(view_1907); view_1907 = None + convert_element_type_872 = torch.ops.prims.convert_element_type.default(view_1905, torch.float32); view_1905 = None + view_1908 = torch.ops.aten.view.default(convert_element_type_872, [2, 8192, 1, -1, 2]); convert_element_type_872 = None + view_as_complex_53 = torch.ops.aten.view_as_complex.default(view_1908); view_1908 = None + mul_210 = torch.ops.aten.mul.Tensor(view_as_complex_52, view_37); view_as_complex_52 = None + view_as_real_52 = torch.ops.aten.view_as_real.default(mul_210); mul_210 = None + view_1910 = torch.ops.aten.view.default(view_as_real_52, [2, 8192, 4, 128]); view_as_real_52 = None + mul_211 = torch.ops.aten.mul.Tensor(view_as_complex_53, view_37); view_as_complex_53 = None + view_as_real_53 = torch.ops.aten.view_as_real.default(mul_211); mul_211 = None + view_1911 = torch.ops.aten.view.default(view_as_real_53, [2, 8192, 1, 128]); view_as_real_53 = None + convert_element_type_873 = torch.ops.prims.convert_element_type.default(view_1910, torch.bfloat16); view_1910 = None + convert_element_type_874 = torch.ops.prims.convert_element_type.default(view_1911, torch.bfloat16); view_1911 = None + unsqueeze_52 = torch.ops.aten.unsqueeze.default(convert_element_type_874, 3); convert_element_type_874 = None + expand_52 = torch.ops.aten.expand.default(unsqueeze_52, [2, 8192, 1, 4, 128]); unsqueeze_52 = None + view_1912 = torch.ops.aten.view.default(expand_52, [2, 8192, 4, 128]); expand_52 = None + unsqueeze_53 = torch.ops.aten.unsqueeze.default(view_1906, 3); view_1906 = None + expand_53 = torch.ops.aten.expand.default(unsqueeze_53, [2, 8192, 1, 4, 128]); 
unsqueeze_53 = None + view_1913 = torch.ops.aten.view.default(expand_53, [2, 8192, 4, 128]); expand_53 = None + permute_289 = torch.ops.aten.permute.default(convert_element_type_873, [0, 2, 1, 3]); convert_element_type_873 = None + permute_290 = torch.ops.aten.permute.default(view_1912, [0, 2, 1, 3]); view_1912 = None + permute_291 = torch.ops.aten.permute.default(view_1913, [0, 2, 1, 3]); view_1913 = None + _scaled_dot_product_cudnn_attention_backward_5 = torch.ops.aten._scaled_dot_product_cudnn_attention_backward.default(permute_533, permute_289, permute_290, permute_291, getitem_1146, getitem_1147, getitem_1152, getitem_1153, None, None, None, 8192, 8192, 0.0, True); permute_533 = permute_289 = permute_290 = permute_291 = getitem_1146 = getitem_1147 = getitem_1152 = getitem_1153 = None + getitem_1599 = _scaled_dot_product_cudnn_attention_backward_5[0] + getitem_1600 = _scaled_dot_product_cudnn_attention_backward_5[1] + getitem_1601 = _scaled_dot_product_cudnn_attention_backward_5[2]; _scaled_dot_product_cudnn_attention_backward_5 = None + permute_534 = torch.ops.aten.permute.default(getitem_1601, [0, 2, 1, 3]); getitem_1601 = None + permute_535 = torch.ops.aten.permute.default(getitem_1600, [0, 2, 1, 3]); getitem_1600 = None + permute_536 = torch.ops.aten.permute.default(getitem_1599, [0, 2, 1, 3]); getitem_1599 = None + view_2452 = torch.ops.aten.view.default(permute_534, [2, 8192, 1, 4, 128]); permute_534 = None + sum_35 = torch.ops.aten.sum.dim_IntList(view_2452, [3], True); view_2452 = None + squeeze_10 = torch.ops.aten.squeeze.dim(sum_35, 3); sum_35 = None + view_2453 = torch.ops.aten.view.default(permute_535, [2, 8192, 1, 4, 128]); permute_535 = None + sum_36 = torch.ops.aten.sum.dim_IntList(view_2453, [3], True); view_2453 = None + squeeze_11 = torch.ops.aten.squeeze.dim(sum_36, 3); sum_36 = None + convert_element_type_1373 = torch.ops.prims.convert_element_type.default(squeeze_11, torch.float32); squeeze_11 = None + convert_element_type_1374 = 
torch.ops.prims.convert_element_type.default(permute_536, torch.float32); permute_536 = None + view_2454 = torch.ops.aten.view.default(convert_element_type_1373, [2, 8192, 1, 64, 2]); convert_element_type_1373 = None + view_as_complex_74 = torch.ops.aten.view_as_complex.default(view_2454); view_2454 = None + mul_376 = torch.ops.aten.mul.Tensor(view_as_complex_74, _conj); view_as_complex_74 = None + view_2455 = torch.ops.aten.view.default(convert_element_type_1374, [2, 8192, 4, 64, 2]); convert_element_type_1374 = None + view_as_complex_75 = torch.ops.aten.view_as_complex.default(view_2455); view_2455 = None + mul_377 = torch.ops.aten.mul.Tensor(view_as_complex_75, _conj); view_as_complex_75 = None + view_as_real_74 = torch.ops.aten.view_as_real.default(mul_376); mul_376 = None + view_2456 = torch.ops.aten.view.default(view_as_real_74, [2, 8192, 1, 128]); view_as_real_74 = None + convert_element_type_1375 = torch.ops.prims.convert_element_type.default(view_2456, torch.bfloat16); view_2456 = None + view_as_real_75 = torch.ops.aten.view_as_real.default(mul_377); mul_377 = None + view_2457 = torch.ops.aten.view.default(view_as_real_75, [2, 8192, 4, 128]); view_as_real_75 = None + convert_element_type_1376 = torch.ops.prims.convert_element_type.default(view_2457, torch.bfloat16); view_2457 = None + view_2458 = torch.ops.aten.view.default(squeeze_10, [2, 8192, 128]); squeeze_10 = None + view_2459 = torch.ops.aten.view.default(convert_element_type_1375, [2, 8192, 128]); convert_element_type_1375 = None + view_2460 = torch.ops.aten.view.default(convert_element_type_1376, [2, 8192, 512]); convert_element_type_1376 = None + view_2461 = torch.ops.aten.view.default(view_2458, [16384, 128]); view_2458 = None + permute_537 = torch.ops.aten.permute.default(view_2461, [1, 0]) + mm_305 = torch.ops.aten.mm.default(permute_537, view_1887); permute_537 = None + convert_element_type_868 = torch.ops.prims.convert_element_type.default(primals_241, torch.bfloat16); primals_241 = None + 
all_gather_into_tensor_291 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_868, 32, '0'); convert_element_type_868 = None + wait_tensor_344 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_291); all_gather_into_tensor_291 = None + permute_288 = torch.ops.aten.permute.default(wait_tensor_344, [1, 0]); wait_tensor_344 = None + permute_539 = torch.ops.aten.permute.default(permute_288, [1, 0]); permute_288 = None + mm_306 = torch.ops.aten.mm.default(view_2461, permute_539); view_2461 = permute_539 = None + view_2462 = torch.ops.aten.view.default(mm_306, [2, 8192, 4096]); mm_306 = None + convert_element_type_1381 = torch.ops.prims.convert_element_type.default(mm_305, torch.float32); mm_305 = None + reduce_scatter_tensor_129 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1381, 'avg', 32, '0'); convert_element_type_1381 = None + wait_tensor_509 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_129); reduce_scatter_tensor_129 = None + view_2463 = torch.ops.aten.view.default(view_2459, [16384, 128]); view_2459 = None + permute_541 = torch.ops.aten.permute.default(view_2463, [1, 0]) + mm_307 = torch.ops.aten.mm.default(permute_541, view_1887); permute_541 = None + permute_543 = torch.ops.aten.permute.default(permute_287, [1, 0]); permute_287 = None + mm_308 = torch.ops.aten.mm.default(view_2463, permute_543); view_2463 = permute_543 = None + view_2464 = torch.ops.aten.view.default(mm_308, [2, 8192, 4096]); mm_308 = None + add_168 = torch.ops.aten.add.Tensor(view_2462, view_2464); view_2462 = view_2464 = None + convert_element_type_1386 = torch.ops.prims.convert_element_type.default(mm_307, torch.float32); mm_307 = None + reduce_scatter_tensor_130 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1386, 'avg', 32, '0'); convert_element_type_1386 = None + wait_tensor_510 = 
torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_130); reduce_scatter_tensor_130 = None + view_2465 = torch.ops.aten.view.default(view_2460, [16384, 512]); view_2460 = None + permute_545 = torch.ops.aten.permute.default(view_2465, [1, 0]) + mm_309 = torch.ops.aten.mm.default(permute_545, view_1887); permute_545 = view_1887 = None + convert_element_type_862 = torch.ops.prims.convert_element_type.default(primals_239, torch.bfloat16); primals_239 = None + all_gather_into_tensor_289 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_862, 32, '0'); convert_element_type_862 = None + wait_tensor_342 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_289); all_gather_into_tensor_289 = None + permute_286 = torch.ops.aten.permute.default(wait_tensor_342, [1, 0]); wait_tensor_342 = None + permute_547 = torch.ops.aten.permute.default(permute_286, [1, 0]); permute_286 = None + mm_310 = torch.ops.aten.mm.default(view_2465, permute_547); view_2465 = permute_547 = None + view_2466 = torch.ops.aten.view.default(mm_310, [2, 8192, 4096]); mm_310 = None + add_169 = torch.ops.aten.add.Tensor(add_168, view_2466); add_168 = view_2466 = None + convert_element_type_1391 = torch.ops.prims.convert_element_type.default(mm_309, torch.float32); mm_309 = None + reduce_scatter_tensor_131 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1391, 'avg', 32, '0'); convert_element_type_1391 = None + wait_tensor_511 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_131); reduce_scatter_tensor_131 = None + split_162 = torch.ops.aten.split.Tensor(add_169, 1024, 1); add_169 = None + getitem_1602 = split_162[0] + getitem_1603 = split_162[1] + getitem_1604 = split_162[2] + getitem_1605 = split_162[3] + getitem_1606 = split_162[4] + getitem_1607 = split_162[5] + getitem_1608 = split_162[6] + getitem_1609 = split_162[7]; split_162 = None + cat_154 = torch.ops.aten.cat.default([getitem_1602, 
getitem_1603, getitem_1604, getitem_1605, getitem_1606, getitem_1607, getitem_1608, getitem_1609]); getitem_1602 = getitem_1603 = getitem_1604 = getitem_1605 = getitem_1606 = getitem_1607 = getitem_1608 = getitem_1609 = None + reduce_scatter_tensor_132 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_154, 'sum', 8, '1'); cat_154 = None + wait_tensor_512 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_132); reduce_scatter_tensor_132 = None + convert_element_type_1392 = torch.ops.prims.convert_element_type.default(wait_tensor_512, torch.float32); wait_tensor_512 = None + convert_element_type_1394 = torch.ops.prims.convert_element_type.default(wait_tensor_340, torch.float32); wait_tensor_340 = None + mul_378 = torch.ops.aten.mul.Tensor(convert_element_type_1392, convert_element_type_1394); convert_element_type_1394 = None + mul_380 = torch.ops.aten.mul.Tensor(mul_208, mul_378) + sum_37 = torch.ops.aten.sum.dim_IntList(mul_380, [2], True); mul_380 = None + div_12 = torch.ops.aten.div.Tensor(mul_208, 4096) + mul_381 = torch.ops.aten.mul.Tensor(div_12, sum_37); div_12 = sum_37 = None + sub_19 = torch.ops.aten.sub.Tensor(mul_378, mul_381); mul_378 = mul_381 = None + mul_382 = torch.ops.aten.mul.Tensor(sub_19, rsqrt_52); sub_19 = rsqrt_52 = None + mul_383 = torch.ops.aten.mul.Tensor(convert_element_type_1392, mul_208); convert_element_type_1392 = mul_208 = None + sum_38 = torch.ops.aten.sum.dim_IntList(mul_383, [0, 1]); mul_383 = None + convert_element_type_1395 = torch.ops.prims.convert_element_type.default(mul_382, torch.bfloat16); mul_382 = None + convert_element_type_1396 = torch.ops.prims.convert_element_type.default(sum_38, torch.bfloat16); sum_38 = None + all_reduce_12 = torch.ops._c10d_functional.all_reduce.default(convert_element_type_1396, 'sum', '1'); convert_element_type_1396 = None + wait_tensor_513 = torch.ops._c10d_functional.wait_tensor.default(all_reduce_12); all_reduce_12 = None + convert_element_type_1397 = 
torch.ops.prims.convert_element_type.default(wait_tensor_513, torch.float32); wait_tensor_513 = None + reduce_scatter_tensor_133 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1397, 'avg', 32, '0'); convert_element_type_1397 = None + wait_tensor_514 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_133); reduce_scatter_tensor_133 = None + add_170 = torch.ops.aten.add.Tensor(add_167, convert_element_type_1395); add_167 = convert_element_type_1395 = None + all_gather_into_tensor_368 = torch.ops._c10d_functional.all_gather_into_tensor.default(add_170, 8, '1') + wait_tensor_515 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_368); all_gather_into_tensor_368 = None + split_163 = torch.ops.aten.split.Tensor(wait_tensor_515, 2); wait_tensor_515 = None + getitem_1610 = split_163[0] + getitem_1611 = split_163[1] + getitem_1612 = split_163[2] + getitem_1613 = split_163[3] + getitem_1614 = split_163[4] + getitem_1615 = split_163[5] + getitem_1616 = split_163[6] + getitem_1617 = split_163[7]; split_163 = None + cat_155 = torch.ops.aten.cat.default([getitem_1610, getitem_1611, getitem_1612, getitem_1613, getitem_1614, getitem_1615, getitem_1616, getitem_1617], 1); getitem_1610 = getitem_1611 = getitem_1612 = getitem_1613 = getitem_1614 = getitem_1615 = getitem_1616 = getitem_1617 = None + view_2467 = torch.ops.aten.view.default(cat_155, [16384, 4096]); cat_155 = None + permute_549 = torch.ops.aten.permute.default(view_2467, [1, 0]) + wait_tensor_333 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_51); reduce_scatter_tensor_51 = None + add_101 = torch.ops.aten.add.Tensor(add_99, wait_tensor_333); wait_tensor_333 = None + convert_element_type_845 = torch.ops.prims.convert_element_type.default(primals_234, torch.bfloat16); primals_234 = None + all_gather_into_tensor_282 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_845, 32, '0'); 
convert_element_type_845 = None + wait_tensor_334 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_282); all_gather_into_tensor_282 = None + convert_element_type_846 = torch.ops.prims.convert_element_type.default(add_101, torch.float32); add_101 = None + pow_52 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_846, 2) + mean_51 = torch.ops.aten.mean.dim(pow_52, [2], True); pow_52 = None + add_102 = torch.ops.aten.add.Scalar(mean_51, 1e-05); mean_51 = None + rsqrt_51 = torch.ops.aten.rsqrt.default(add_102); add_102 = None + mul_204 = torch.ops.aten.mul.Tensor(convert_element_type_846, rsqrt_51); convert_element_type_846 = None + mul_205 = torch.ops.aten.mul.Tensor(mul_204, wait_tensor_334) + convert_element_type_847 = torch.ops.prims.convert_element_type.default(mul_205, torch.bfloat16); mul_205 = None + all_gather_into_tensor_283 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_847, 8, '1'); convert_element_type_847 = None + wait_tensor_335 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_283); all_gather_into_tensor_283 = None + split_111 = torch.ops.aten.split.Tensor(wait_tensor_335, 2); wait_tensor_335 = None + getitem_1122 = split_111[0] + getitem_1123 = split_111[1] + getitem_1124 = split_111[2] + getitem_1125 = split_111[3] + getitem_1126 = split_111[4] + getitem_1127 = split_111[5] + getitem_1128 = split_111[6] + getitem_1129 = split_111[7]; split_111 = None + cat_103 = torch.ops.aten.cat.default([getitem_1122, getitem_1123, getitem_1124, getitem_1125, getitem_1126, getitem_1127, getitem_1128, getitem_1129], 1); getitem_1122 = getitem_1123 = getitem_1124 = getitem_1125 = getitem_1126 = getitem_1127 = getitem_1128 = getitem_1129 = None + view_1860 = torch.ops.aten.view.default(cat_103, [16384, 4096]); cat_103 = None + view_1861 = torch.ops.aten.view.default(mm_179, [2, 8192, 1792]); mm_179 = None + convert_element_type_851 = 
torch.ops.prims.convert_element_type.default(view_1861, torch.float32); view_1861 = None + sigmoid_25 = torch.ops.aten.sigmoid.default(convert_element_type_851) + mul_206 = torch.ops.aten.mul.Tensor(convert_element_type_851, sigmoid_25); sigmoid_25 = None + convert_element_type_852 = torch.ops.prims.convert_element_type.default(mul_206, torch.bfloat16); mul_206 = None + convert_element_type_853 = torch.ops.prims.convert_element_type.default(primals_236, torch.bfloat16); primals_236 = None + all_gather_into_tensor_285 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_853, 32, '0'); convert_element_type_853 = None + wait_tensor_337 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_285); all_gather_into_tensor_285 = None + permute_284 = torch.ops.aten.permute.default(wait_tensor_337, [1, 0]); wait_tensor_337 = None + mm_180 = torch.ops.aten.mm.default(view_1860, permute_284) + view_1868 = torch.ops.aten.view.default(mm_180, [2, 8192, 1792]); mm_180 = None + mul_207 = torch.ops.aten.mul.Tensor(convert_element_type_852, view_1868) + view_1875 = torch.ops.aten.view.default(mul_207, [16384, 1792]); mul_207 = None + mm_311 = torch.ops.aten.mm.default(permute_549, view_1875); permute_549 = view_1875 = None + convert_element_type_856 = torch.ops.prims.convert_element_type.default(primals_237, torch.bfloat16); primals_237 = None + all_gather_into_tensor_286 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_856, 32, '0'); convert_element_type_856 = None + wait_tensor_338 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_286); all_gather_into_tensor_286 = None + permute_285 = torch.ops.aten.permute.default(wait_tensor_338, [1, 0]); wait_tensor_338 = None + permute_551 = torch.ops.aten.permute.default(permute_285, [1, 0]); permute_285 = None + mm_312 = torch.ops.aten.mm.default(view_2467, permute_551); view_2467 = permute_551 = None + view_2468 = 
torch.ops.aten.view.default(mm_312, [2, 8192, 1792]); mm_312 = None + convert_element_type_1402 = torch.ops.prims.convert_element_type.default(mm_311, torch.float32); mm_311 = None + reduce_scatter_tensor_134 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1402, 'avg', 32, '0'); convert_element_type_1402 = None + wait_tensor_516 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_134); reduce_scatter_tensor_134 = None + mul_384 = torch.ops.aten.mul.Tensor(view_2468, convert_element_type_852); convert_element_type_852 = None + mul_385 = torch.ops.aten.mul.Tensor(view_2468, view_1868); view_2468 = view_1868 = None + view_2469 = torch.ops.aten.view.default(mul_384, [16384, 1792]); mul_384 = None + permute_553 = torch.ops.aten.permute.default(view_2469, [1, 0]) + mm_313 = torch.ops.aten.mm.default(permute_553, view_1860); permute_553 = None + permute_555 = torch.ops.aten.permute.default(permute_284, [1, 0]); permute_284 = None + mm_314 = torch.ops.aten.mm.default(view_2469, permute_555); view_2469 = permute_555 = None + view_2470 = torch.ops.aten.view.default(mm_314, [2, 8192, 4096]); mm_314 = None + convert_element_type_1407 = torch.ops.prims.convert_element_type.default(mm_313, torch.float32); mm_313 = None + reduce_scatter_tensor_135 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1407, 'avg', 32, '0'); convert_element_type_1407 = None + wait_tensor_517 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_135); reduce_scatter_tensor_135 = None + convert_element_type_1408 = torch.ops.prims.convert_element_type.default(mul_385, torch.float32); mul_385 = None + neg_6 = torch.ops.aten.neg.default(convert_element_type_851) + exp_6 = torch.ops.aten.exp.default(neg_6); neg_6 = None + add_171 = torch.ops.aten.add.Tensor(exp_6, 1); exp_6 = None + reciprocal_6 = torch.ops.aten.reciprocal.default(add_171); add_171 = None + mul_386 = torch.ops.aten.mul.Tensor(reciprocal_6, 1); 
reciprocal_6 = None + mul_387 = torch.ops.aten.mul.Tensor(convert_element_type_1408, mul_386); convert_element_type_1408 = None + sub_20 = torch.ops.aten.sub.Tensor(1, mul_386); mul_386 = None + mul_388 = torch.ops.aten.mul.Tensor(convert_element_type_851, sub_20); convert_element_type_851 = sub_20 = None + add_172 = torch.ops.aten.add.Tensor(mul_388, 1); mul_388 = None + mul_389 = torch.ops.aten.mul.Tensor(mul_387, add_172); mul_387 = add_172 = None + convert_element_type_1410 = torch.ops.prims.convert_element_type.default(mul_389, torch.bfloat16); mul_389 = None + view_2471 = torch.ops.aten.view.default(convert_element_type_1410, [16384, 1792]); convert_element_type_1410 = None + permute_557 = torch.ops.aten.permute.default(view_2471, [1, 0]) + mm_315 = torch.ops.aten.mm.default(permute_557, view_1860); permute_557 = view_1860 = None + convert_element_type_848 = torch.ops.prims.convert_element_type.default(primals_235, torch.bfloat16); primals_235 = None + all_gather_into_tensor_284 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_848, 32, '0'); convert_element_type_848 = None + wait_tensor_336 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_284); all_gather_into_tensor_284 = None + permute_283 = torch.ops.aten.permute.default(wait_tensor_336, [1, 0]); wait_tensor_336 = None + permute_559 = torch.ops.aten.permute.default(permute_283, [1, 0]); permute_283 = None + mm_316 = torch.ops.aten.mm.default(view_2471, permute_559); view_2471 = permute_559 = None + view_2472 = torch.ops.aten.view.default(mm_316, [2, 8192, 4096]); mm_316 = None + add_173 = torch.ops.aten.add.Tensor(view_2470, view_2472); view_2470 = view_2472 = None + convert_element_type_1415 = torch.ops.prims.convert_element_type.default(mm_315, torch.float32); mm_315 = None + reduce_scatter_tensor_136 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1415, 'avg', 32, '0'); convert_element_type_1415 = None + 
wait_tensor_518 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_136); reduce_scatter_tensor_136 = None + split_164 = torch.ops.aten.split.Tensor(add_173, 1024, 1); add_173 = None + getitem_1618 = split_164[0] + getitem_1619 = split_164[1] + getitem_1620 = split_164[2] + getitem_1621 = split_164[3] + getitem_1622 = split_164[4] + getitem_1623 = split_164[5] + getitem_1624 = split_164[6] + getitem_1625 = split_164[7]; split_164 = None + cat_156 = torch.ops.aten.cat.default([getitem_1618, getitem_1619, getitem_1620, getitem_1621, getitem_1622, getitem_1623, getitem_1624, getitem_1625]); getitem_1618 = getitem_1619 = getitem_1620 = getitem_1621 = getitem_1622 = getitem_1623 = getitem_1624 = getitem_1625 = None + reduce_scatter_tensor_137 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_156, 'sum', 8, '1'); cat_156 = None + wait_tensor_519 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_137); reduce_scatter_tensor_137 = None + convert_element_type_1416 = torch.ops.prims.convert_element_type.default(wait_tensor_519, torch.float32); wait_tensor_519 = None + convert_element_type_1418 = torch.ops.prims.convert_element_type.default(wait_tensor_334, torch.float32); wait_tensor_334 = None + mul_390 = torch.ops.aten.mul.Tensor(convert_element_type_1416, convert_element_type_1418); convert_element_type_1418 = None + mul_392 = torch.ops.aten.mul.Tensor(mul_204, mul_390) + sum_39 = torch.ops.aten.sum.dim_IntList(mul_392, [2], True); mul_392 = None + div_13 = torch.ops.aten.div.Tensor(mul_204, 4096) + mul_393 = torch.ops.aten.mul.Tensor(div_13, sum_39); div_13 = sum_39 = None + sub_21 = torch.ops.aten.sub.Tensor(mul_390, mul_393); mul_390 = mul_393 = None + mul_394 = torch.ops.aten.mul.Tensor(sub_21, rsqrt_51); sub_21 = rsqrt_51 = None + mul_395 = torch.ops.aten.mul.Tensor(convert_element_type_1416, mul_204); convert_element_type_1416 = mul_204 = None + sum_40 = torch.ops.aten.sum.dim_IntList(mul_395, [0, 1]); mul_395 = 
None + convert_element_type_1419 = torch.ops.prims.convert_element_type.default(mul_394, torch.bfloat16); mul_394 = None + convert_element_type_1420 = torch.ops.prims.convert_element_type.default(sum_40, torch.bfloat16); sum_40 = None + all_reduce_13 = torch.ops._c10d_functional.all_reduce.default(convert_element_type_1420, 'sum', '1'); convert_element_type_1420 = None + wait_tensor_520 = torch.ops._c10d_functional.wait_tensor.default(all_reduce_13); all_reduce_13 = None + convert_element_type_1421 = torch.ops.prims.convert_element_type.default(wait_tensor_520, torch.float32); wait_tensor_520 = None + reduce_scatter_tensor_138 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1421, 'avg', 32, '0'); convert_element_type_1421 = None + wait_tensor_521 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_138); reduce_scatter_tensor_138 = None + add_174 = torch.ops.aten.add.Tensor(add_170, convert_element_type_1419); add_170 = convert_element_type_1419 = None + all_gather_into_tensor_369 = torch.ops._c10d_functional.all_gather_into_tensor.default(add_174, 8, '1') + wait_tensor_522 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_369); all_gather_into_tensor_369 = None + split_165 = torch.ops.aten.split.Tensor(wait_tensor_522, 2); wait_tensor_522 = None + getitem_1626 = split_165[0] + getitem_1627 = split_165[1] + getitem_1628 = split_165[2] + getitem_1629 = split_165[3] + getitem_1630 = split_165[4] + getitem_1631 = split_165[5] + getitem_1632 = split_165[6] + getitem_1633 = split_165[7]; split_165 = None + cat_157 = torch.ops.aten.cat.default([getitem_1626, getitem_1627, getitem_1628, getitem_1629, getitem_1630, getitem_1631, getitem_1632, getitem_1633], 1); getitem_1626 = getitem_1627 = getitem_1628 = getitem_1629 = getitem_1630 = getitem_1631 = getitem_1632 = getitem_1633 = None + view_2473 = torch.ops.aten.view.default(cat_157, [16384, 4096]); cat_157 = None + permute_561 = 
torch.ops.aten.permute.default(view_2473, [1, 0]) + permute_281 = torch.ops.aten.permute.default(getitem_1105, [0, 2, 1, 3]) + view_1842 = torch.ops.aten.view.default(permute_281, [2, 8192, -1]); permute_281 = None + view_1848 = torch.ops.aten.view.default(view_1842, [16384, 512]); view_1842 = None + mm_317 = torch.ops.aten.mm.default(permute_561, view_1848); permute_561 = view_1848 = None + convert_element_type_842 = torch.ops.prims.convert_element_type.default(primals_233, torch.bfloat16); primals_233 = None + all_gather_into_tensor_281 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_842, 32, '0'); convert_element_type_842 = None + wait_tensor_332 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_281); all_gather_into_tensor_281 = None + permute_282 = torch.ops.aten.permute.default(wait_tensor_332, [1, 0]); wait_tensor_332 = None + permute_563 = torch.ops.aten.permute.default(permute_282, [1, 0]); permute_282 = None + mm_318 = torch.ops.aten.mm.default(view_2473, permute_563); view_2473 = permute_563 = None + view_2474 = torch.ops.aten.view.default(mm_318, [2, 8192, 512]); mm_318 = None + convert_element_type_1426 = torch.ops.prims.convert_element_type.default(mm_317, torch.float32); mm_317 = None + reduce_scatter_tensor_139 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1426, 'avg', 32, '0'); convert_element_type_1426 = None + wait_tensor_523 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_139); reduce_scatter_tensor_139 = None + view_2475 = torch.ops.aten.view.default(view_2474, [2, 8192, 4, 128]); view_2474 = None + permute_565 = torch.ops.aten.permute.default(view_2475, [0, 2, 1, 3]); view_2475 = None + convert_element_type_826 = torch.ops.prims.convert_element_type.default(primals_229, torch.bfloat16); primals_229 = None + all_gather_into_tensor_276 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_826, 32, '0'); 
convert_element_type_826 = None + wait_tensor_327 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_276); all_gather_into_tensor_276 = None + convert_element_type_827 = torch.ops.prims.convert_element_type.default(add_99, torch.float32); add_99 = None + pow_51 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_827, 2) + mean_50 = torch.ops.aten.mean.dim(pow_51, [2], True); pow_51 = None + add_100 = torch.ops.aten.add.Scalar(mean_50, 1e-05); mean_50 = None + rsqrt_50 = torch.ops.aten.rsqrt.default(add_100); add_100 = None + mul_200 = torch.ops.aten.mul.Tensor(convert_element_type_827, rsqrt_50); convert_element_type_827 = None + mul_201 = torch.ops.aten.mul.Tensor(mul_200, wait_tensor_327) + convert_element_type_828 = torch.ops.prims.convert_element_type.default(mul_201, torch.bfloat16); mul_201 = None + all_gather_into_tensor_277 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_828, 8, '1'); convert_element_type_828 = None + wait_tensor_328 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_277); all_gather_into_tensor_277 = None + split_109 = torch.ops.aten.split.Tensor(wait_tensor_328, 2); wait_tensor_328 = None + getitem_1097 = split_109[0] + getitem_1098 = split_109[1] + getitem_1099 = split_109[2] + getitem_1100 = split_109[3] + getitem_1101 = split_109[4] + getitem_1102 = split_109[5] + getitem_1103 = split_109[6] + getitem_1104 = split_109[7]; split_109 = None + cat_101 = torch.ops.aten.cat.default([getitem_1097, getitem_1098, getitem_1099, getitem_1100, getitem_1101, getitem_1102, getitem_1103, getitem_1104], 1); getitem_1097 = getitem_1098 = getitem_1099 = getitem_1100 = getitem_1101 = getitem_1102 = getitem_1103 = getitem_1104 = None + view_1815 = torch.ops.aten.view.default(cat_101, [16384, 4096]); cat_101 = None + view_1816 = torch.ops.aten.view.default(mm_175, [2, 8192, 512]); mm_175 = None + convert_element_type_832 = 
torch.ops.prims.convert_element_type.default(primals_231, torch.bfloat16); primals_231 = None + all_gather_into_tensor_279 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_832, 32, '0'); convert_element_type_832 = None + wait_tensor_330 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_279); all_gather_into_tensor_279 = None + permute_276 = torch.ops.aten.permute.default(wait_tensor_330, [1, 0]); wait_tensor_330 = None + mm_176 = torch.ops.aten.mm.default(view_1815, permute_276) + view_1823 = torch.ops.aten.view.default(mm_176, [2, 8192, 128]); mm_176 = None + view_1830 = torch.ops.aten.view.default(mm_177, [2, 8192, 128]); mm_177 = None + view_1832 = torch.ops.aten.view.default(view_1816, [2, 8192, -1, 128]); view_1816 = None + view_1833 = torch.ops.aten.view.default(view_1823, [2, 8192, -1, 128]); view_1823 = None + view_1834 = torch.ops.aten.view.default(view_1830, [2, 8192, -1, 128]); view_1830 = None + convert_element_type_838 = torch.ops.prims.convert_element_type.default(view_1832, torch.float32); view_1832 = None + view_1835 = torch.ops.aten.view.default(convert_element_type_838, [2, 8192, 4, -1, 2]); convert_element_type_838 = None + view_as_complex_50 = torch.ops.aten.view_as_complex.default(view_1835); view_1835 = None + convert_element_type_839 = torch.ops.prims.convert_element_type.default(view_1833, torch.float32); view_1833 = None + view_1836 = torch.ops.aten.view.default(convert_element_type_839, [2, 8192, 1, -1, 2]); convert_element_type_839 = None + view_as_complex_51 = torch.ops.aten.view_as_complex.default(view_1836); view_1836 = None + mul_202 = torch.ops.aten.mul.Tensor(view_as_complex_50, view_37); view_as_complex_50 = None + view_as_real_50 = torch.ops.aten.view_as_real.default(mul_202); mul_202 = None + view_1838 = torch.ops.aten.view.default(view_as_real_50, [2, 8192, 4, 128]); view_as_real_50 = None + mul_203 = torch.ops.aten.mul.Tensor(view_as_complex_51, view_37); view_as_complex_51 
= None + view_as_real_51 = torch.ops.aten.view_as_real.default(mul_203); mul_203 = None + view_1839 = torch.ops.aten.view.default(view_as_real_51, [2, 8192, 1, 128]); view_as_real_51 = None + convert_element_type_840 = torch.ops.prims.convert_element_type.default(view_1838, torch.bfloat16); view_1838 = None + convert_element_type_841 = torch.ops.prims.convert_element_type.default(view_1839, torch.bfloat16); view_1839 = None + unsqueeze_50 = torch.ops.aten.unsqueeze.default(convert_element_type_841, 3); convert_element_type_841 = None + expand_50 = torch.ops.aten.expand.default(unsqueeze_50, [2, 8192, 1, 4, 128]); unsqueeze_50 = None + view_1840 = torch.ops.aten.view.default(expand_50, [2, 8192, 4, 128]); expand_50 = None + unsqueeze_51 = torch.ops.aten.unsqueeze.default(view_1834, 3); view_1834 = None + expand_51 = torch.ops.aten.expand.default(unsqueeze_51, [2, 8192, 1, 4, 128]); unsqueeze_51 = None + view_1841 = torch.ops.aten.view.default(expand_51, [2, 8192, 4, 128]); expand_51 = None + permute_278 = torch.ops.aten.permute.default(convert_element_type_840, [0, 2, 1, 3]); convert_element_type_840 = None + permute_279 = torch.ops.aten.permute.default(view_1840, [0, 2, 1, 3]); view_1840 = None + permute_280 = torch.ops.aten.permute.default(view_1841, [0, 2, 1, 3]); view_1841 = None + _scaled_dot_product_cudnn_attention_backward_6 = torch.ops.aten._scaled_dot_product_cudnn_attention_backward.default(permute_565, permute_278, permute_279, permute_280, getitem_1105, getitem_1106, getitem_1111, getitem_1112, None, None, None, 8192, 8192, 0.0, True); permute_565 = permute_278 = permute_279 = permute_280 = getitem_1105 = getitem_1106 = getitem_1111 = getitem_1112 = None + getitem_1634 = _scaled_dot_product_cudnn_attention_backward_6[0] + getitem_1635 = _scaled_dot_product_cudnn_attention_backward_6[1] + getitem_1636 = _scaled_dot_product_cudnn_attention_backward_6[2]; _scaled_dot_product_cudnn_attention_backward_6 = None + permute_566 = 
torch.ops.aten.permute.default(getitem_1636, [0, 2, 1, 3]); getitem_1636 = None + permute_567 = torch.ops.aten.permute.default(getitem_1635, [0, 2, 1, 3]); getitem_1635 = None + permute_568 = torch.ops.aten.permute.default(getitem_1634, [0, 2, 1, 3]); getitem_1634 = None + view_2476 = torch.ops.aten.view.default(permute_566, [2, 8192, 1, 4, 128]); permute_566 = None + sum_41 = torch.ops.aten.sum.dim_IntList(view_2476, [3], True); view_2476 = None + squeeze_12 = torch.ops.aten.squeeze.dim(sum_41, 3); sum_41 = None + view_2477 = torch.ops.aten.view.default(permute_567, [2, 8192, 1, 4, 128]); permute_567 = None + sum_42 = torch.ops.aten.sum.dim_IntList(view_2477, [3], True); view_2477 = None + squeeze_13 = torch.ops.aten.squeeze.dim(sum_42, 3); sum_42 = None + convert_element_type_1427 = torch.ops.prims.convert_element_type.default(squeeze_13, torch.float32); squeeze_13 = None + convert_element_type_1428 = torch.ops.prims.convert_element_type.default(permute_568, torch.float32); permute_568 = None + view_2478 = torch.ops.aten.view.default(convert_element_type_1427, [2, 8192, 1, 64, 2]); convert_element_type_1427 = None + view_as_complex_76 = torch.ops.aten.view_as_complex.default(view_2478); view_2478 = None + mul_396 = torch.ops.aten.mul.Tensor(view_as_complex_76, _conj); view_as_complex_76 = None + view_2479 = torch.ops.aten.view.default(convert_element_type_1428, [2, 8192, 4, 64, 2]); convert_element_type_1428 = None + view_as_complex_77 = torch.ops.aten.view_as_complex.default(view_2479); view_2479 = None + mul_397 = torch.ops.aten.mul.Tensor(view_as_complex_77, _conj); view_as_complex_77 = None + view_as_real_76 = torch.ops.aten.view_as_real.default(mul_396); mul_396 = None + view_2480 = torch.ops.aten.view.default(view_as_real_76, [2, 8192, 1, 128]); view_as_real_76 = None + convert_element_type_1429 = torch.ops.prims.convert_element_type.default(view_2480, torch.bfloat16); view_2480 = None + view_as_real_77 = torch.ops.aten.view_as_real.default(mul_397); 
mul_397 = None + view_2481 = torch.ops.aten.view.default(view_as_real_77, [2, 8192, 4, 128]); view_as_real_77 = None + convert_element_type_1430 = torch.ops.prims.convert_element_type.default(view_2481, torch.bfloat16); view_2481 = None + view_2482 = torch.ops.aten.view.default(squeeze_12, [2, 8192, 128]); squeeze_12 = None + view_2483 = torch.ops.aten.view.default(convert_element_type_1429, [2, 8192, 128]); convert_element_type_1429 = None + view_2484 = torch.ops.aten.view.default(convert_element_type_1430, [2, 8192, 512]); convert_element_type_1430 = None + view_2485 = torch.ops.aten.view.default(view_2482, [16384, 128]); view_2482 = None + permute_569 = torch.ops.aten.permute.default(view_2485, [1, 0]) + mm_319 = torch.ops.aten.mm.default(permute_569, view_1815); permute_569 = None + convert_element_type_835 = torch.ops.prims.convert_element_type.default(primals_232, torch.bfloat16); primals_232 = None + all_gather_into_tensor_280 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_835, 32, '0'); convert_element_type_835 = None + wait_tensor_331 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_280); all_gather_into_tensor_280 = None + permute_277 = torch.ops.aten.permute.default(wait_tensor_331, [1, 0]); wait_tensor_331 = None + permute_571 = torch.ops.aten.permute.default(permute_277, [1, 0]); permute_277 = None + mm_320 = torch.ops.aten.mm.default(view_2485, permute_571); view_2485 = permute_571 = None + view_2486 = torch.ops.aten.view.default(mm_320, [2, 8192, 4096]); mm_320 = None + convert_element_type_1435 = torch.ops.prims.convert_element_type.default(mm_319, torch.float32); mm_319 = None + reduce_scatter_tensor_140 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1435, 'avg', 32, '0'); convert_element_type_1435 = None + wait_tensor_524 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_140); reduce_scatter_tensor_140 = None + view_2487 = 
torch.ops.aten.view.default(view_2483, [16384, 128]); view_2483 = None + permute_573 = torch.ops.aten.permute.default(view_2487, [1, 0]) + mm_321 = torch.ops.aten.mm.default(permute_573, view_1815); permute_573 = None + permute_575 = torch.ops.aten.permute.default(permute_276, [1, 0]); permute_276 = None + mm_322 = torch.ops.aten.mm.default(view_2487, permute_575); view_2487 = permute_575 = None + view_2488 = torch.ops.aten.view.default(mm_322, [2, 8192, 4096]); mm_322 = None + add_175 = torch.ops.aten.add.Tensor(view_2486, view_2488); view_2486 = view_2488 = None + convert_element_type_1440 = torch.ops.prims.convert_element_type.default(mm_321, torch.float32); mm_321 = None + reduce_scatter_tensor_141 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1440, 'avg', 32, '0'); convert_element_type_1440 = None + wait_tensor_525 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_141); reduce_scatter_tensor_141 = None + view_2489 = torch.ops.aten.view.default(view_2484, [16384, 512]); view_2484 = None + permute_577 = torch.ops.aten.permute.default(view_2489, [1, 0]) + mm_323 = torch.ops.aten.mm.default(permute_577, view_1815); permute_577 = view_1815 = None + convert_element_type_829 = torch.ops.prims.convert_element_type.default(primals_230, torch.bfloat16); primals_230 = None + all_gather_into_tensor_278 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_829, 32, '0'); convert_element_type_829 = None + wait_tensor_329 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_278); all_gather_into_tensor_278 = None + permute_275 = torch.ops.aten.permute.default(wait_tensor_329, [1, 0]); wait_tensor_329 = None + permute_579 = torch.ops.aten.permute.default(permute_275, [1, 0]); permute_275 = None + mm_324 = torch.ops.aten.mm.default(view_2489, permute_579); view_2489 = permute_579 = None + view_2490 = torch.ops.aten.view.default(mm_324, [2, 8192, 4096]); mm_324 = None + add_176 
= torch.ops.aten.add.Tensor(add_175, view_2490); add_175 = view_2490 = None + convert_element_type_1445 = torch.ops.prims.convert_element_type.default(mm_323, torch.float32); mm_323 = None + reduce_scatter_tensor_142 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1445, 'avg', 32, '0'); convert_element_type_1445 = None + wait_tensor_526 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_142); reduce_scatter_tensor_142 = None + split_166 = torch.ops.aten.split.Tensor(add_176, 1024, 1); add_176 = None + getitem_1637 = split_166[0] + getitem_1638 = split_166[1] + getitem_1639 = split_166[2] + getitem_1640 = split_166[3] + getitem_1641 = split_166[4] + getitem_1642 = split_166[5] + getitem_1643 = split_166[6] + getitem_1644 = split_166[7]; split_166 = None + cat_158 = torch.ops.aten.cat.default([getitem_1637, getitem_1638, getitem_1639, getitem_1640, getitem_1641, getitem_1642, getitem_1643, getitem_1644]); getitem_1637 = getitem_1638 = getitem_1639 = getitem_1640 = getitem_1641 = getitem_1642 = getitem_1643 = getitem_1644 = None + reduce_scatter_tensor_143 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_158, 'sum', 8, '1'); cat_158 = None + wait_tensor_527 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_143); reduce_scatter_tensor_143 = None + convert_element_type_1446 = torch.ops.prims.convert_element_type.default(wait_tensor_527, torch.float32); wait_tensor_527 = None + convert_element_type_1448 = torch.ops.prims.convert_element_type.default(wait_tensor_327, torch.float32); wait_tensor_327 = None + mul_398 = torch.ops.aten.mul.Tensor(convert_element_type_1446, convert_element_type_1448); convert_element_type_1448 = None + mul_400 = torch.ops.aten.mul.Tensor(mul_200, mul_398) + sum_43 = torch.ops.aten.sum.dim_IntList(mul_400, [2], True); mul_400 = None + div_14 = torch.ops.aten.div.Tensor(mul_200, 4096) + mul_401 = torch.ops.aten.mul.Tensor(div_14, sum_43); div_14 = sum_43 = 
None + sub_22 = torch.ops.aten.sub.Tensor(mul_398, mul_401); mul_398 = mul_401 = None + mul_402 = torch.ops.aten.mul.Tensor(sub_22, rsqrt_50); sub_22 = rsqrt_50 = None + mul_403 = torch.ops.aten.mul.Tensor(convert_element_type_1446, mul_200); convert_element_type_1446 = mul_200 = None + sum_44 = torch.ops.aten.sum.dim_IntList(mul_403, [0, 1]); mul_403 = None + convert_element_type_1449 = torch.ops.prims.convert_element_type.default(mul_402, torch.bfloat16); mul_402 = None + convert_element_type_1450 = torch.ops.prims.convert_element_type.default(sum_44, torch.bfloat16); sum_44 = None + all_reduce_14 = torch.ops._c10d_functional.all_reduce.default(convert_element_type_1450, 'sum', '1'); convert_element_type_1450 = None + wait_tensor_528 = torch.ops._c10d_functional.wait_tensor.default(all_reduce_14); all_reduce_14 = None + convert_element_type_1451 = torch.ops.prims.convert_element_type.default(wait_tensor_528, torch.float32); wait_tensor_528 = None + reduce_scatter_tensor_144 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1451, 'avg', 32, '0'); convert_element_type_1451 = None + wait_tensor_529 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_144); reduce_scatter_tensor_144 = None + add_177 = torch.ops.aten.add.Tensor(add_174, convert_element_type_1449); add_174 = convert_element_type_1449 = None + all_gather_into_tensor_370 = torch.ops._c10d_functional.all_gather_into_tensor.default(add_177, 8, '1') + wait_tensor_530 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_370); all_gather_into_tensor_370 = None + split_167 = torch.ops.aten.split.Tensor(wait_tensor_530, 2); wait_tensor_530 = None + getitem_1645 = split_167[0] + getitem_1646 = split_167[1] + getitem_1647 = split_167[2] + getitem_1648 = split_167[3] + getitem_1649 = split_167[4] + getitem_1650 = split_167[5] + getitem_1651 = split_167[6] + getitem_1652 = split_167[7]; split_167 = None + cat_159 = 
torch.ops.aten.cat.default([getitem_1645, getitem_1646, getitem_1647, getitem_1648, getitem_1649, getitem_1650, getitem_1651, getitem_1652], 1); getitem_1645 = getitem_1646 = getitem_1647 = getitem_1648 = getitem_1649 = getitem_1650 = getitem_1651 = getitem_1652 = None + view_2491 = torch.ops.aten.view.default(cat_159, [16384, 4096]); cat_159 = None + permute_581 = torch.ops.aten.permute.default(view_2491, [1, 0]) + wait_tensor_320 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_49); reduce_scatter_tensor_49 = None + add_97 = torch.ops.aten.add.Tensor(add_95, wait_tensor_320); wait_tensor_320 = None + convert_element_type_812 = torch.ops.prims.convert_element_type.default(primals_225, torch.bfloat16); primals_225 = None + all_gather_into_tensor_271 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_812, 32, '0'); convert_element_type_812 = None + wait_tensor_321 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_271); all_gather_into_tensor_271 = None + convert_element_type_813 = torch.ops.prims.convert_element_type.default(add_97, torch.float32); add_97 = None + pow_50 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_813, 2) + mean_49 = torch.ops.aten.mean.dim(pow_50, [2], True); pow_50 = None + add_98 = torch.ops.aten.add.Scalar(mean_49, 1e-05); mean_49 = None + rsqrt_49 = torch.ops.aten.rsqrt.default(add_98); add_98 = None + mul_196 = torch.ops.aten.mul.Tensor(convert_element_type_813, rsqrt_49); convert_element_type_813 = None + mul_197 = torch.ops.aten.mul.Tensor(mul_196, wait_tensor_321) + convert_element_type_814 = torch.ops.prims.convert_element_type.default(mul_197, torch.bfloat16); mul_197 = None + all_gather_into_tensor_272 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_814, 8, '1'); convert_element_type_814 = None + wait_tensor_322 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_272); all_gather_into_tensor_272 = 
None + split_107 = torch.ops.aten.split.Tensor(wait_tensor_322, 2); wait_tensor_322 = None + getitem_1081 = split_107[0] + getitem_1082 = split_107[1] + getitem_1083 = split_107[2] + getitem_1084 = split_107[3] + getitem_1085 = split_107[4] + getitem_1086 = split_107[5] + getitem_1087 = split_107[6] + getitem_1088 = split_107[7]; split_107 = None + cat_99 = torch.ops.aten.cat.default([getitem_1081, getitem_1082, getitem_1083, getitem_1084, getitem_1085, getitem_1086, getitem_1087, getitem_1088], 1); getitem_1081 = getitem_1082 = getitem_1083 = getitem_1084 = getitem_1085 = getitem_1086 = getitem_1087 = getitem_1088 = None + view_1788 = torch.ops.aten.view.default(cat_99, [16384, 4096]); cat_99 = None + view_1789 = torch.ops.aten.view.default(mm_172, [2, 8192, 1792]); mm_172 = None + convert_element_type_818 = torch.ops.prims.convert_element_type.default(view_1789, torch.float32); view_1789 = None + sigmoid_24 = torch.ops.aten.sigmoid.default(convert_element_type_818) + mul_198 = torch.ops.aten.mul.Tensor(convert_element_type_818, sigmoid_24); sigmoid_24 = None + convert_element_type_819 = torch.ops.prims.convert_element_type.default(mul_198, torch.bfloat16); mul_198 = None + convert_element_type_820 = torch.ops.prims.convert_element_type.default(primals_227, torch.bfloat16); primals_227 = None + all_gather_into_tensor_274 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_820, 32, '0'); convert_element_type_820 = None + wait_tensor_324 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_274); all_gather_into_tensor_274 = None + permute_273 = torch.ops.aten.permute.default(wait_tensor_324, [1, 0]); wait_tensor_324 = None + mm_173 = torch.ops.aten.mm.default(view_1788, permute_273) + view_1796 = torch.ops.aten.view.default(mm_173, [2, 8192, 1792]); mm_173 = None + mul_199 = torch.ops.aten.mul.Tensor(convert_element_type_819, view_1796) + view_1803 = torch.ops.aten.view.default(mul_199, [16384, 1792]); mul_199 = None + 
mm_325 = torch.ops.aten.mm.default(permute_581, view_1803); permute_581 = view_1803 = None + convert_element_type_823 = torch.ops.prims.convert_element_type.default(primals_228, torch.bfloat16); primals_228 = None + all_gather_into_tensor_275 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_823, 32, '0'); convert_element_type_823 = None + wait_tensor_325 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_275); all_gather_into_tensor_275 = None + permute_274 = torch.ops.aten.permute.default(wait_tensor_325, [1, 0]); wait_tensor_325 = None + permute_583 = torch.ops.aten.permute.default(permute_274, [1, 0]); permute_274 = None + mm_326 = torch.ops.aten.mm.default(view_2491, permute_583); view_2491 = permute_583 = None + view_2492 = torch.ops.aten.view.default(mm_326, [2, 8192, 1792]); mm_326 = None + convert_element_type_1456 = torch.ops.prims.convert_element_type.default(mm_325, torch.float32); mm_325 = None + reduce_scatter_tensor_145 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1456, 'avg', 32, '0'); convert_element_type_1456 = None + wait_tensor_531 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_145); reduce_scatter_tensor_145 = None + mul_404 = torch.ops.aten.mul.Tensor(view_2492, convert_element_type_819); convert_element_type_819 = None + mul_405 = torch.ops.aten.mul.Tensor(view_2492, view_1796); view_2492 = view_1796 = None + view_2493 = torch.ops.aten.view.default(mul_404, [16384, 1792]); mul_404 = None + permute_585 = torch.ops.aten.permute.default(view_2493, [1, 0]) + mm_327 = torch.ops.aten.mm.default(permute_585, view_1788); permute_585 = None + permute_587 = torch.ops.aten.permute.default(permute_273, [1, 0]); permute_273 = None + mm_328 = torch.ops.aten.mm.default(view_2493, permute_587); view_2493 = permute_587 = None + view_2494 = torch.ops.aten.view.default(mm_328, [2, 8192, 4096]); mm_328 = None + convert_element_type_1461 = 
torch.ops.prims.convert_element_type.default(mm_327, torch.float32); mm_327 = None + reduce_scatter_tensor_146 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1461, 'avg', 32, '0'); convert_element_type_1461 = None + wait_tensor_532 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_146); reduce_scatter_tensor_146 = None + convert_element_type_1462 = torch.ops.prims.convert_element_type.default(mul_405, torch.float32); mul_405 = None + neg_7 = torch.ops.aten.neg.default(convert_element_type_818) + exp_7 = torch.ops.aten.exp.default(neg_7); neg_7 = None + add_178 = torch.ops.aten.add.Tensor(exp_7, 1); exp_7 = None + reciprocal_7 = torch.ops.aten.reciprocal.default(add_178); add_178 = None + mul_406 = torch.ops.aten.mul.Tensor(reciprocal_7, 1); reciprocal_7 = None + mul_407 = torch.ops.aten.mul.Tensor(convert_element_type_1462, mul_406); convert_element_type_1462 = None + sub_23 = torch.ops.aten.sub.Tensor(1, mul_406); mul_406 = None + mul_408 = torch.ops.aten.mul.Tensor(convert_element_type_818, sub_23); convert_element_type_818 = sub_23 = None + add_179 = torch.ops.aten.add.Tensor(mul_408, 1); mul_408 = None + mul_409 = torch.ops.aten.mul.Tensor(mul_407, add_179); mul_407 = add_179 = None + convert_element_type_1464 = torch.ops.prims.convert_element_type.default(mul_409, torch.bfloat16); mul_409 = None + view_2495 = torch.ops.aten.view.default(convert_element_type_1464, [16384, 1792]); convert_element_type_1464 = None + permute_589 = torch.ops.aten.permute.default(view_2495, [1, 0]) + mm_329 = torch.ops.aten.mm.default(permute_589, view_1788); permute_589 = view_1788 = None + convert_element_type_815 = torch.ops.prims.convert_element_type.default(primals_226, torch.bfloat16); primals_226 = None + all_gather_into_tensor_273 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_815, 32, '0'); convert_element_type_815 = None + wait_tensor_323 = 
torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_273); all_gather_into_tensor_273 = None + permute_272 = torch.ops.aten.permute.default(wait_tensor_323, [1, 0]); wait_tensor_323 = None + permute_591 = torch.ops.aten.permute.default(permute_272, [1, 0]); permute_272 = None + mm_330 = torch.ops.aten.mm.default(view_2495, permute_591); view_2495 = permute_591 = None + view_2496 = torch.ops.aten.view.default(mm_330, [2, 8192, 4096]); mm_330 = None + add_180 = torch.ops.aten.add.Tensor(view_2494, view_2496); view_2494 = view_2496 = None + convert_element_type_1469 = torch.ops.prims.convert_element_type.default(mm_329, torch.float32); mm_329 = None + reduce_scatter_tensor_147 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1469, 'avg', 32, '0'); convert_element_type_1469 = None + wait_tensor_533 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_147); reduce_scatter_tensor_147 = None + split_168 = torch.ops.aten.split.Tensor(add_180, 1024, 1); add_180 = None + getitem_1653 = split_168[0] + getitem_1654 = split_168[1] + getitem_1655 = split_168[2] + getitem_1656 = split_168[3] + getitem_1657 = split_168[4] + getitem_1658 = split_168[5] + getitem_1659 = split_168[6] + getitem_1660 = split_168[7]; split_168 = None + cat_160 = torch.ops.aten.cat.default([getitem_1653, getitem_1654, getitem_1655, getitem_1656, getitem_1657, getitem_1658, getitem_1659, getitem_1660]); getitem_1653 = getitem_1654 = getitem_1655 = getitem_1656 = getitem_1657 = getitem_1658 = getitem_1659 = getitem_1660 = None + reduce_scatter_tensor_148 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_160, 'sum', 8, '1'); cat_160 = None + wait_tensor_534 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_148); reduce_scatter_tensor_148 = None + convert_element_type_1470 = torch.ops.prims.convert_element_type.default(wait_tensor_534, torch.float32); wait_tensor_534 = None + convert_element_type_1472 = 
torch.ops.prims.convert_element_type.default(wait_tensor_321, torch.float32); wait_tensor_321 = None + mul_410 = torch.ops.aten.mul.Tensor(convert_element_type_1470, convert_element_type_1472); convert_element_type_1472 = None + mul_412 = torch.ops.aten.mul.Tensor(mul_196, mul_410) + sum_45 = torch.ops.aten.sum.dim_IntList(mul_412, [2], True); mul_412 = None + div_15 = torch.ops.aten.div.Tensor(mul_196, 4096) + mul_413 = torch.ops.aten.mul.Tensor(div_15, sum_45); div_15 = sum_45 = None + sub_24 = torch.ops.aten.sub.Tensor(mul_410, mul_413); mul_410 = mul_413 = None + mul_414 = torch.ops.aten.mul.Tensor(sub_24, rsqrt_49); sub_24 = rsqrt_49 = None + mul_415 = torch.ops.aten.mul.Tensor(convert_element_type_1470, mul_196); convert_element_type_1470 = mul_196 = None + sum_46 = torch.ops.aten.sum.dim_IntList(mul_415, [0, 1]); mul_415 = None + convert_element_type_1473 = torch.ops.prims.convert_element_type.default(mul_414, torch.bfloat16); mul_414 = None + convert_element_type_1474 = torch.ops.prims.convert_element_type.default(sum_46, torch.bfloat16); sum_46 = None + all_reduce_15 = torch.ops._c10d_functional.all_reduce.default(convert_element_type_1474, 'sum', '1'); convert_element_type_1474 = None + wait_tensor_535 = torch.ops._c10d_functional.wait_tensor.default(all_reduce_15); all_reduce_15 = None + convert_element_type_1475 = torch.ops.prims.convert_element_type.default(wait_tensor_535, torch.float32); wait_tensor_535 = None + reduce_scatter_tensor_149 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1475, 'avg', 32, '0'); convert_element_type_1475 = None + wait_tensor_536 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_149); reduce_scatter_tensor_149 = None + add_181 = torch.ops.aten.add.Tensor(add_177, convert_element_type_1473); add_177 = convert_element_type_1473 = None + all_gather_into_tensor_371 = torch.ops._c10d_functional.all_gather_into_tensor.default(add_181, 8, '1') + wait_tensor_537 = 
torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_371); all_gather_into_tensor_371 = None + split_169 = torch.ops.aten.split.Tensor(wait_tensor_537, 2); wait_tensor_537 = None + getitem_1661 = split_169[0] + getitem_1662 = split_169[1] + getitem_1663 = split_169[2] + getitem_1664 = split_169[3] + getitem_1665 = split_169[4] + getitem_1666 = split_169[5] + getitem_1667 = split_169[6] + getitem_1668 = split_169[7]; split_169 = None + cat_161 = torch.ops.aten.cat.default([getitem_1661, getitem_1662, getitem_1663, getitem_1664, getitem_1665, getitem_1666, getitem_1667, getitem_1668], 1); getitem_1661 = getitem_1662 = getitem_1663 = getitem_1664 = getitem_1665 = getitem_1666 = getitem_1667 = getitem_1668 = None + view_2497 = torch.ops.aten.view.default(cat_161, [16384, 4096]); cat_161 = None + permute_593 = torch.ops.aten.permute.default(view_2497, [1, 0]) + permute_270 = torch.ops.aten.permute.default(getitem_1064, [0, 2, 1, 3]) + view_1770 = torch.ops.aten.view.default(permute_270, [2, 8192, -1]); permute_270 = None + view_1776 = torch.ops.aten.view.default(view_1770, [16384, 512]); view_1770 = None + mm_331 = torch.ops.aten.mm.default(permute_593, view_1776); permute_593 = view_1776 = None + convert_element_type_809 = torch.ops.prims.convert_element_type.default(primals_224, torch.bfloat16); primals_224 = None + all_gather_into_tensor_270 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_809, 32, '0'); convert_element_type_809 = None + wait_tensor_319 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_270); all_gather_into_tensor_270 = None + permute_271 = torch.ops.aten.permute.default(wait_tensor_319, [1, 0]); wait_tensor_319 = None + permute_595 = torch.ops.aten.permute.default(permute_271, [1, 0]); permute_271 = None + mm_332 = torch.ops.aten.mm.default(view_2497, permute_595); view_2497 = permute_595 = None + view_2498 = torch.ops.aten.view.default(mm_332, [2, 8192, 512]); mm_332 = None + 
convert_element_type_1480 = torch.ops.prims.convert_element_type.default(mm_331, torch.float32); mm_331 = None + reduce_scatter_tensor_150 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1480, 'avg', 32, '0'); convert_element_type_1480 = None + wait_tensor_538 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_150); reduce_scatter_tensor_150 = None + view_2499 = torch.ops.aten.view.default(view_2498, [2, 8192, 4, 128]); view_2498 = None + permute_597 = torch.ops.aten.permute.default(view_2499, [0, 2, 1, 3]); view_2499 = None + convert_element_type_793 = torch.ops.prims.convert_element_type.default(primals_220, torch.bfloat16); primals_220 = None + all_gather_into_tensor_265 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_793, 32, '0'); convert_element_type_793 = None + wait_tensor_314 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_265); all_gather_into_tensor_265 = None + convert_element_type_794 = torch.ops.prims.convert_element_type.default(add_95, torch.float32); add_95 = None + pow_49 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_794, 2) + mean_48 = torch.ops.aten.mean.dim(pow_49, [2], True); pow_49 = None + add_96 = torch.ops.aten.add.Scalar(mean_48, 1e-05); mean_48 = None + rsqrt_48 = torch.ops.aten.rsqrt.default(add_96); add_96 = None + mul_192 = torch.ops.aten.mul.Tensor(convert_element_type_794, rsqrt_48); convert_element_type_794 = None + mul_193 = torch.ops.aten.mul.Tensor(mul_192, wait_tensor_314) + convert_element_type_795 = torch.ops.prims.convert_element_type.default(mul_193, torch.bfloat16); mul_193 = None + all_gather_into_tensor_266 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_795, 8, '1'); convert_element_type_795 = None + wait_tensor_315 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_266); all_gather_into_tensor_266 = None + split_105 = 
torch.ops.aten.split.Tensor(wait_tensor_315, 2); wait_tensor_315 = None + getitem_1056 = split_105[0] + getitem_1057 = split_105[1] + getitem_1058 = split_105[2] + getitem_1059 = split_105[3] + getitem_1060 = split_105[4] + getitem_1061 = split_105[5] + getitem_1062 = split_105[6] + getitem_1063 = split_105[7]; split_105 = None + cat_97 = torch.ops.aten.cat.default([getitem_1056, getitem_1057, getitem_1058, getitem_1059, getitem_1060, getitem_1061, getitem_1062, getitem_1063], 1); getitem_1056 = getitem_1057 = getitem_1058 = getitem_1059 = getitem_1060 = getitem_1061 = getitem_1062 = getitem_1063 = None + view_1743 = torch.ops.aten.view.default(cat_97, [16384, 4096]); cat_97 = None + view_1744 = torch.ops.aten.view.default(mm_168, [2, 8192, 512]); mm_168 = None + convert_element_type_799 = torch.ops.prims.convert_element_type.default(primals_222, torch.bfloat16); primals_222 = None + all_gather_into_tensor_268 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_799, 32, '0'); convert_element_type_799 = None + wait_tensor_317 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_268); all_gather_into_tensor_268 = None + permute_265 = torch.ops.aten.permute.default(wait_tensor_317, [1, 0]); wait_tensor_317 = None + mm_169 = torch.ops.aten.mm.default(view_1743, permute_265) + view_1751 = torch.ops.aten.view.default(mm_169, [2, 8192, 128]); mm_169 = None + view_1758 = torch.ops.aten.view.default(mm_170, [2, 8192, 128]); mm_170 = None + view_1760 = torch.ops.aten.view.default(view_1744, [2, 8192, -1, 128]); view_1744 = None + view_1761 = torch.ops.aten.view.default(view_1751, [2, 8192, -1, 128]); view_1751 = None + view_1762 = torch.ops.aten.view.default(view_1758, [2, 8192, -1, 128]); view_1758 = None + convert_element_type_805 = torch.ops.prims.convert_element_type.default(view_1760, torch.float32); view_1760 = None + view_1763 = torch.ops.aten.view.default(convert_element_type_805, [2, 8192, 4, -1, 2]); 
convert_element_type_805 = None + view_as_complex_48 = torch.ops.aten.view_as_complex.default(view_1763); view_1763 = None + convert_element_type_806 = torch.ops.prims.convert_element_type.default(view_1761, torch.float32); view_1761 = None + view_1764 = torch.ops.aten.view.default(convert_element_type_806, [2, 8192, 1, -1, 2]); convert_element_type_806 = None + view_as_complex_49 = torch.ops.aten.view_as_complex.default(view_1764); view_1764 = None + mul_194 = torch.ops.aten.mul.Tensor(view_as_complex_48, view_37); view_as_complex_48 = None + view_as_real_48 = torch.ops.aten.view_as_real.default(mul_194); mul_194 = None + view_1766 = torch.ops.aten.view.default(view_as_real_48, [2, 8192, 4, 128]); view_as_real_48 = None + mul_195 = torch.ops.aten.mul.Tensor(view_as_complex_49, view_37); view_as_complex_49 = None + view_as_real_49 = torch.ops.aten.view_as_real.default(mul_195); mul_195 = None + view_1767 = torch.ops.aten.view.default(view_as_real_49, [2, 8192, 1, 128]); view_as_real_49 = None + convert_element_type_807 = torch.ops.prims.convert_element_type.default(view_1766, torch.bfloat16); view_1766 = None + convert_element_type_808 = torch.ops.prims.convert_element_type.default(view_1767, torch.bfloat16); view_1767 = None + unsqueeze_48 = torch.ops.aten.unsqueeze.default(convert_element_type_808, 3); convert_element_type_808 = None + expand_48 = torch.ops.aten.expand.default(unsqueeze_48, [2, 8192, 1, 4, 128]); unsqueeze_48 = None + view_1768 = torch.ops.aten.view.default(expand_48, [2, 8192, 4, 128]); expand_48 = None + unsqueeze_49 = torch.ops.aten.unsqueeze.default(view_1762, 3); view_1762 = None + expand_49 = torch.ops.aten.expand.default(unsqueeze_49, [2, 8192, 1, 4, 128]); unsqueeze_49 = None + view_1769 = torch.ops.aten.view.default(expand_49, [2, 8192, 4, 128]); expand_49 = None + permute_267 = torch.ops.aten.permute.default(convert_element_type_807, [0, 2, 1, 3]); convert_element_type_807 = None + permute_268 = torch.ops.aten.permute.default(view_1768, 
[0, 2, 1, 3]); view_1768 = None + permute_269 = torch.ops.aten.permute.default(view_1769, [0, 2, 1, 3]); view_1769 = None + _scaled_dot_product_cudnn_attention_backward_7 = torch.ops.aten._scaled_dot_product_cudnn_attention_backward.default(permute_597, permute_267, permute_268, permute_269, getitem_1064, getitem_1065, getitem_1070, getitem_1071, None, None, None, 8192, 8192, 0.0, True); permute_597 = permute_267 = permute_268 = permute_269 = getitem_1064 = getitem_1065 = getitem_1070 = getitem_1071 = None + getitem_1669 = _scaled_dot_product_cudnn_attention_backward_7[0] + getitem_1670 = _scaled_dot_product_cudnn_attention_backward_7[1] + getitem_1671 = _scaled_dot_product_cudnn_attention_backward_7[2]; _scaled_dot_product_cudnn_attention_backward_7 = None + permute_598 = torch.ops.aten.permute.default(getitem_1671, [0, 2, 1, 3]); getitem_1671 = None + permute_599 = torch.ops.aten.permute.default(getitem_1670, [0, 2, 1, 3]); getitem_1670 = None + permute_600 = torch.ops.aten.permute.default(getitem_1669, [0, 2, 1, 3]); getitem_1669 = None + view_2500 = torch.ops.aten.view.default(permute_598, [2, 8192, 1, 4, 128]); permute_598 = None + sum_47 = torch.ops.aten.sum.dim_IntList(view_2500, [3], True); view_2500 = None + squeeze_14 = torch.ops.aten.squeeze.dim(sum_47, 3); sum_47 = None + view_2501 = torch.ops.aten.view.default(permute_599, [2, 8192, 1, 4, 128]); permute_599 = None + sum_48 = torch.ops.aten.sum.dim_IntList(view_2501, [3], True); view_2501 = None + squeeze_15 = torch.ops.aten.squeeze.dim(sum_48, 3); sum_48 = None + convert_element_type_1481 = torch.ops.prims.convert_element_type.default(squeeze_15, torch.float32); squeeze_15 = None + convert_element_type_1482 = torch.ops.prims.convert_element_type.default(permute_600, torch.float32); permute_600 = None + view_2502 = torch.ops.aten.view.default(convert_element_type_1481, [2, 8192, 1, 64, 2]); convert_element_type_1481 = None + view_as_complex_78 = torch.ops.aten.view_as_complex.default(view_2502); 
view_2502 = None + mul_416 = torch.ops.aten.mul.Tensor(view_as_complex_78, _conj); view_as_complex_78 = None + view_2503 = torch.ops.aten.view.default(convert_element_type_1482, [2, 8192, 4, 64, 2]); convert_element_type_1482 = None + view_as_complex_79 = torch.ops.aten.view_as_complex.default(view_2503); view_2503 = None + mul_417 = torch.ops.aten.mul.Tensor(view_as_complex_79, _conj); view_as_complex_79 = None + view_as_real_78 = torch.ops.aten.view_as_real.default(mul_416); mul_416 = None + view_2504 = torch.ops.aten.view.default(view_as_real_78, [2, 8192, 1, 128]); view_as_real_78 = None + convert_element_type_1483 = torch.ops.prims.convert_element_type.default(view_2504, torch.bfloat16); view_2504 = None + view_as_real_79 = torch.ops.aten.view_as_real.default(mul_417); mul_417 = None + view_2505 = torch.ops.aten.view.default(view_as_real_79, [2, 8192, 4, 128]); view_as_real_79 = None + convert_element_type_1484 = torch.ops.prims.convert_element_type.default(view_2505, torch.bfloat16); view_2505 = None + view_2506 = torch.ops.aten.view.default(squeeze_14, [2, 8192, 128]); squeeze_14 = None + view_2507 = torch.ops.aten.view.default(convert_element_type_1483, [2, 8192, 128]); convert_element_type_1483 = None + view_2508 = torch.ops.aten.view.default(convert_element_type_1484, [2, 8192, 512]); convert_element_type_1484 = None + view_2509 = torch.ops.aten.view.default(view_2506, [16384, 128]); view_2506 = None + permute_601 = torch.ops.aten.permute.default(view_2509, [1, 0]) + mm_333 = torch.ops.aten.mm.default(permute_601, view_1743); permute_601 = None + convert_element_type_802 = torch.ops.prims.convert_element_type.default(primals_223, torch.bfloat16); primals_223 = None + all_gather_into_tensor_269 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_802, 32, '0'); convert_element_type_802 = None + wait_tensor_318 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_269); all_gather_into_tensor_269 = None + 
permute_266 = torch.ops.aten.permute.default(wait_tensor_318, [1, 0]); wait_tensor_318 = None + permute_603 = torch.ops.aten.permute.default(permute_266, [1, 0]); permute_266 = None + mm_334 = torch.ops.aten.mm.default(view_2509, permute_603); view_2509 = permute_603 = None + view_2510 = torch.ops.aten.view.default(mm_334, [2, 8192, 4096]); mm_334 = None + convert_element_type_1489 = torch.ops.prims.convert_element_type.default(mm_333, torch.float32); mm_333 = None + reduce_scatter_tensor_151 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1489, 'avg', 32, '0'); convert_element_type_1489 = None + wait_tensor_539 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_151); reduce_scatter_tensor_151 = None + view_2511 = torch.ops.aten.view.default(view_2507, [16384, 128]); view_2507 = None + permute_605 = torch.ops.aten.permute.default(view_2511, [1, 0]) + mm_335 = torch.ops.aten.mm.default(permute_605, view_1743); permute_605 = None + permute_607 = torch.ops.aten.permute.default(permute_265, [1, 0]); permute_265 = None + mm_336 = torch.ops.aten.mm.default(view_2511, permute_607); view_2511 = permute_607 = None + view_2512 = torch.ops.aten.view.default(mm_336, [2, 8192, 4096]); mm_336 = None + add_182 = torch.ops.aten.add.Tensor(view_2510, view_2512); view_2510 = view_2512 = None + convert_element_type_1494 = torch.ops.prims.convert_element_type.default(mm_335, torch.float32); mm_335 = None + reduce_scatter_tensor_152 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1494, 'avg', 32, '0'); convert_element_type_1494 = None + wait_tensor_540 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_152); reduce_scatter_tensor_152 = None + view_2513 = torch.ops.aten.view.default(view_2508, [16384, 512]); view_2508 = None + permute_609 = torch.ops.aten.permute.default(view_2513, [1, 0]) + mm_337 = torch.ops.aten.mm.default(permute_609, view_1743); permute_609 = view_1743 = None + 
convert_element_type_796 = torch.ops.prims.convert_element_type.default(primals_221, torch.bfloat16); primals_221 = None + all_gather_into_tensor_267 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_796, 32, '0'); convert_element_type_796 = None + wait_tensor_316 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_267); all_gather_into_tensor_267 = None + permute_264 = torch.ops.aten.permute.default(wait_tensor_316, [1, 0]); wait_tensor_316 = None + permute_611 = torch.ops.aten.permute.default(permute_264, [1, 0]); permute_264 = None + mm_338 = torch.ops.aten.mm.default(view_2513, permute_611); view_2513 = permute_611 = None + view_2514 = torch.ops.aten.view.default(mm_338, [2, 8192, 4096]); mm_338 = None + add_183 = torch.ops.aten.add.Tensor(add_182, view_2514); add_182 = view_2514 = None + convert_element_type_1499 = torch.ops.prims.convert_element_type.default(mm_337, torch.float32); mm_337 = None + reduce_scatter_tensor_153 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1499, 'avg', 32, '0'); convert_element_type_1499 = None + wait_tensor_541 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_153); reduce_scatter_tensor_153 = None + split_170 = torch.ops.aten.split.Tensor(add_183, 1024, 1); add_183 = None + getitem_1672 = split_170[0] + getitem_1673 = split_170[1] + getitem_1674 = split_170[2] + getitem_1675 = split_170[3] + getitem_1676 = split_170[4] + getitem_1677 = split_170[5] + getitem_1678 = split_170[6] + getitem_1679 = split_170[7]; split_170 = None + cat_162 = torch.ops.aten.cat.default([getitem_1672, getitem_1673, getitem_1674, getitem_1675, getitem_1676, getitem_1677, getitem_1678, getitem_1679]); getitem_1672 = getitem_1673 = getitem_1674 = getitem_1675 = getitem_1676 = getitem_1677 = getitem_1678 = getitem_1679 = None + reduce_scatter_tensor_154 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_162, 'sum', 8, '1'); cat_162 = None + 
wait_tensor_542 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_154); reduce_scatter_tensor_154 = None + convert_element_type_1500 = torch.ops.prims.convert_element_type.default(wait_tensor_542, torch.float32); wait_tensor_542 = None + convert_element_type_1502 = torch.ops.prims.convert_element_type.default(wait_tensor_314, torch.float32); wait_tensor_314 = None + mul_418 = torch.ops.aten.mul.Tensor(convert_element_type_1500, convert_element_type_1502); convert_element_type_1502 = None + mul_420 = torch.ops.aten.mul.Tensor(mul_192, mul_418) + sum_49 = torch.ops.aten.sum.dim_IntList(mul_420, [2], True); mul_420 = None + div_16 = torch.ops.aten.div.Tensor(mul_192, 4096) + mul_421 = torch.ops.aten.mul.Tensor(div_16, sum_49); div_16 = sum_49 = None + sub_25 = torch.ops.aten.sub.Tensor(mul_418, mul_421); mul_418 = mul_421 = None + mul_422 = torch.ops.aten.mul.Tensor(sub_25, rsqrt_48); sub_25 = rsqrt_48 = None + mul_423 = torch.ops.aten.mul.Tensor(convert_element_type_1500, mul_192); convert_element_type_1500 = mul_192 = None + sum_50 = torch.ops.aten.sum.dim_IntList(mul_423, [0, 1]); mul_423 = None + convert_element_type_1503 = torch.ops.prims.convert_element_type.default(mul_422, torch.bfloat16); mul_422 = None + convert_element_type_1504 = torch.ops.prims.convert_element_type.default(sum_50, torch.bfloat16); sum_50 = None + all_reduce_16 = torch.ops._c10d_functional.all_reduce.default(convert_element_type_1504, 'sum', '1'); convert_element_type_1504 = None + wait_tensor_543 = torch.ops._c10d_functional.wait_tensor.default(all_reduce_16); all_reduce_16 = None + convert_element_type_1505 = torch.ops.prims.convert_element_type.default(wait_tensor_543, torch.float32); wait_tensor_543 = None + reduce_scatter_tensor_155 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1505, 'avg', 32, '0'); convert_element_type_1505 = None + wait_tensor_544 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_155); 
reduce_scatter_tensor_155 = None + add_184 = torch.ops.aten.add.Tensor(add_181, convert_element_type_1503); add_181 = convert_element_type_1503 = None + all_gather_into_tensor_372 = torch.ops._c10d_functional.all_gather_into_tensor.default(add_184, 8, '1') + wait_tensor_545 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_372); all_gather_into_tensor_372 = None + split_171 = torch.ops.aten.split.Tensor(wait_tensor_545, 2); wait_tensor_545 = None + getitem_1680 = split_171[0] + getitem_1681 = split_171[1] + getitem_1682 = split_171[2] + getitem_1683 = split_171[3] + getitem_1684 = split_171[4] + getitem_1685 = split_171[5] + getitem_1686 = split_171[6] + getitem_1687 = split_171[7]; split_171 = None + cat_163 = torch.ops.aten.cat.default([getitem_1680, getitem_1681, getitem_1682, getitem_1683, getitem_1684, getitem_1685, getitem_1686, getitem_1687], 1); getitem_1680 = getitem_1681 = getitem_1682 = getitem_1683 = getitem_1684 = getitem_1685 = getitem_1686 = getitem_1687 = None + view_2515 = torch.ops.aten.view.default(cat_163, [16384, 4096]); cat_163 = None + permute_613 = torch.ops.aten.permute.default(view_2515, [1, 0]) + wait_tensor_307 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_47); reduce_scatter_tensor_47 = None + add_93 = torch.ops.aten.add.Tensor(add_91, wait_tensor_307); wait_tensor_307 = None + convert_element_type_779 = torch.ops.prims.convert_element_type.default(primals_216, torch.bfloat16); primals_216 = None + all_gather_into_tensor_260 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_779, 32, '0'); convert_element_type_779 = None + wait_tensor_308 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_260); all_gather_into_tensor_260 = None + convert_element_type_780 = torch.ops.prims.convert_element_type.default(add_93, torch.float32); add_93 = None + pow_48 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_780, 2) + mean_47 = 
torch.ops.aten.mean.dim(pow_48, [2], True); pow_48 = None + add_94 = torch.ops.aten.add.Scalar(mean_47, 1e-05); mean_47 = None + rsqrt_47 = torch.ops.aten.rsqrt.default(add_94); add_94 = None + mul_188 = torch.ops.aten.mul.Tensor(convert_element_type_780, rsqrt_47); convert_element_type_780 = None + mul_189 = torch.ops.aten.mul.Tensor(mul_188, wait_tensor_308) + convert_element_type_781 = torch.ops.prims.convert_element_type.default(mul_189, torch.bfloat16); mul_189 = None + all_gather_into_tensor_261 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_781, 8, '1'); convert_element_type_781 = None + wait_tensor_309 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_261); all_gather_into_tensor_261 = None + split_103 = torch.ops.aten.split.Tensor(wait_tensor_309, 2); wait_tensor_309 = None + getitem_1040 = split_103[0] + getitem_1041 = split_103[1] + getitem_1042 = split_103[2] + getitem_1043 = split_103[3] + getitem_1044 = split_103[4] + getitem_1045 = split_103[5] + getitem_1046 = split_103[6] + getitem_1047 = split_103[7]; split_103 = None + cat_95 = torch.ops.aten.cat.default([getitem_1040, getitem_1041, getitem_1042, getitem_1043, getitem_1044, getitem_1045, getitem_1046, getitem_1047], 1); getitem_1040 = getitem_1041 = getitem_1042 = getitem_1043 = getitem_1044 = getitem_1045 = getitem_1046 = getitem_1047 = None + view_1716 = torch.ops.aten.view.default(cat_95, [16384, 4096]); cat_95 = None + view_1717 = torch.ops.aten.view.default(mm_165, [2, 8192, 1792]); mm_165 = None + convert_element_type_785 = torch.ops.prims.convert_element_type.default(view_1717, torch.float32); view_1717 = None + sigmoid_23 = torch.ops.aten.sigmoid.default(convert_element_type_785) + mul_190 = torch.ops.aten.mul.Tensor(convert_element_type_785, sigmoid_23); sigmoid_23 = None + convert_element_type_786 = torch.ops.prims.convert_element_type.default(mul_190, torch.bfloat16); mul_190 = None + convert_element_type_787 = 
torch.ops.prims.convert_element_type.default(primals_218, torch.bfloat16); primals_218 = None + all_gather_into_tensor_263 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_787, 32, '0'); convert_element_type_787 = None + wait_tensor_311 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_263); all_gather_into_tensor_263 = None + permute_262 = torch.ops.aten.permute.default(wait_tensor_311, [1, 0]); wait_tensor_311 = None + mm_166 = torch.ops.aten.mm.default(view_1716, permute_262) + view_1724 = torch.ops.aten.view.default(mm_166, [2, 8192, 1792]); mm_166 = None + mul_191 = torch.ops.aten.mul.Tensor(convert_element_type_786, view_1724) + view_1731 = torch.ops.aten.view.default(mul_191, [16384, 1792]); mul_191 = None + mm_339 = torch.ops.aten.mm.default(permute_613, view_1731); permute_613 = view_1731 = None + convert_element_type_790 = torch.ops.prims.convert_element_type.default(primals_219, torch.bfloat16); primals_219 = None + all_gather_into_tensor_264 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_790, 32, '0'); convert_element_type_790 = None + wait_tensor_312 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_264); all_gather_into_tensor_264 = None + permute_263 = torch.ops.aten.permute.default(wait_tensor_312, [1, 0]); wait_tensor_312 = None + permute_615 = torch.ops.aten.permute.default(permute_263, [1, 0]); permute_263 = None + mm_340 = torch.ops.aten.mm.default(view_2515, permute_615); view_2515 = permute_615 = None + view_2516 = torch.ops.aten.view.default(mm_340, [2, 8192, 1792]); mm_340 = None + convert_element_type_1510 = torch.ops.prims.convert_element_type.default(mm_339, torch.float32); mm_339 = None + reduce_scatter_tensor_156 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1510, 'avg', 32, '0'); convert_element_type_1510 = None + wait_tensor_546 = 
torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_156); reduce_scatter_tensor_156 = None + mul_424 = torch.ops.aten.mul.Tensor(view_2516, convert_element_type_786); convert_element_type_786 = None + mul_425 = torch.ops.aten.mul.Tensor(view_2516, view_1724); view_2516 = view_1724 = None + view_2517 = torch.ops.aten.view.default(mul_424, [16384, 1792]); mul_424 = None + permute_617 = torch.ops.aten.permute.default(view_2517, [1, 0]) + mm_341 = torch.ops.aten.mm.default(permute_617, view_1716); permute_617 = None + permute_619 = torch.ops.aten.permute.default(permute_262, [1, 0]); permute_262 = None + mm_342 = torch.ops.aten.mm.default(view_2517, permute_619); view_2517 = permute_619 = None + view_2518 = torch.ops.aten.view.default(mm_342, [2, 8192, 4096]); mm_342 = None + convert_element_type_1515 = torch.ops.prims.convert_element_type.default(mm_341, torch.float32); mm_341 = None + reduce_scatter_tensor_157 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1515, 'avg', 32, '0'); convert_element_type_1515 = None + wait_tensor_547 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_157); reduce_scatter_tensor_157 = None + convert_element_type_1516 = torch.ops.prims.convert_element_type.default(mul_425, torch.float32); mul_425 = None + neg_8 = torch.ops.aten.neg.default(convert_element_type_785) + exp_8 = torch.ops.aten.exp.default(neg_8); neg_8 = None + add_185 = torch.ops.aten.add.Tensor(exp_8, 1); exp_8 = None + reciprocal_8 = torch.ops.aten.reciprocal.default(add_185); add_185 = None + mul_426 = torch.ops.aten.mul.Tensor(reciprocal_8, 1); reciprocal_8 = None + mul_427 = torch.ops.aten.mul.Tensor(convert_element_type_1516, mul_426); convert_element_type_1516 = None + sub_26 = torch.ops.aten.sub.Tensor(1, mul_426); mul_426 = None + mul_428 = torch.ops.aten.mul.Tensor(convert_element_type_785, sub_26); convert_element_type_785 = sub_26 = None + add_186 = torch.ops.aten.add.Tensor(mul_428, 1); mul_428 
= None + mul_429 = torch.ops.aten.mul.Tensor(mul_427, add_186); mul_427 = add_186 = None + convert_element_type_1518 = torch.ops.prims.convert_element_type.default(mul_429, torch.bfloat16); mul_429 = None + view_2519 = torch.ops.aten.view.default(convert_element_type_1518, [16384, 1792]); convert_element_type_1518 = None + permute_621 = torch.ops.aten.permute.default(view_2519, [1, 0]) + mm_343 = torch.ops.aten.mm.default(permute_621, view_1716); permute_621 = view_1716 = None + convert_element_type_782 = torch.ops.prims.convert_element_type.default(primals_217, torch.bfloat16); primals_217 = None + all_gather_into_tensor_262 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_782, 32, '0'); convert_element_type_782 = None + wait_tensor_310 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_262); all_gather_into_tensor_262 = None + permute_261 = torch.ops.aten.permute.default(wait_tensor_310, [1, 0]); wait_tensor_310 = None + permute_623 = torch.ops.aten.permute.default(permute_261, [1, 0]); permute_261 = None + mm_344 = torch.ops.aten.mm.default(view_2519, permute_623); view_2519 = permute_623 = None + view_2520 = torch.ops.aten.view.default(mm_344, [2, 8192, 4096]); mm_344 = None + add_187 = torch.ops.aten.add.Tensor(view_2518, view_2520); view_2518 = view_2520 = None + convert_element_type_1523 = torch.ops.prims.convert_element_type.default(mm_343, torch.float32); mm_343 = None + reduce_scatter_tensor_158 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1523, 'avg', 32, '0'); convert_element_type_1523 = None + wait_tensor_548 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_158); reduce_scatter_tensor_158 = None + split_172 = torch.ops.aten.split.Tensor(add_187, 1024, 1); add_187 = None + getitem_1688 = split_172[0] + getitem_1689 = split_172[1] + getitem_1690 = split_172[2] + getitem_1691 = split_172[3] + getitem_1692 = split_172[4] + getitem_1693 = 
split_172[5] + getitem_1694 = split_172[6] + getitem_1695 = split_172[7]; split_172 = None + cat_164 = torch.ops.aten.cat.default([getitem_1688, getitem_1689, getitem_1690, getitem_1691, getitem_1692, getitem_1693, getitem_1694, getitem_1695]); getitem_1688 = getitem_1689 = getitem_1690 = getitem_1691 = getitem_1692 = getitem_1693 = getitem_1694 = getitem_1695 = None + reduce_scatter_tensor_159 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_164, 'sum', 8, '1'); cat_164 = None + wait_tensor_549 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_159); reduce_scatter_tensor_159 = None + convert_element_type_1524 = torch.ops.prims.convert_element_type.default(wait_tensor_549, torch.float32); wait_tensor_549 = None + convert_element_type_1526 = torch.ops.prims.convert_element_type.default(wait_tensor_308, torch.float32); wait_tensor_308 = None + mul_430 = torch.ops.aten.mul.Tensor(convert_element_type_1524, convert_element_type_1526); convert_element_type_1526 = None + mul_432 = torch.ops.aten.mul.Tensor(mul_188, mul_430) + sum_51 = torch.ops.aten.sum.dim_IntList(mul_432, [2], True); mul_432 = None + div_17 = torch.ops.aten.div.Tensor(mul_188, 4096) + mul_433 = torch.ops.aten.mul.Tensor(div_17, sum_51); div_17 = sum_51 = None + sub_27 = torch.ops.aten.sub.Tensor(mul_430, mul_433); mul_430 = mul_433 = None + mul_434 = torch.ops.aten.mul.Tensor(sub_27, rsqrt_47); sub_27 = rsqrt_47 = None + mul_435 = torch.ops.aten.mul.Tensor(convert_element_type_1524, mul_188); convert_element_type_1524 = mul_188 = None + sum_52 = torch.ops.aten.sum.dim_IntList(mul_435, [0, 1]); mul_435 = None + convert_element_type_1527 = torch.ops.prims.convert_element_type.default(mul_434, torch.bfloat16); mul_434 = None + convert_element_type_1528 = torch.ops.prims.convert_element_type.default(sum_52, torch.bfloat16); sum_52 = None + all_reduce_17 = torch.ops._c10d_functional.all_reduce.default(convert_element_type_1528, 'sum', '1'); convert_element_type_1528 = None 
+ wait_tensor_550 = torch.ops._c10d_functional.wait_tensor.default(all_reduce_17); all_reduce_17 = None + convert_element_type_1529 = torch.ops.prims.convert_element_type.default(wait_tensor_550, torch.float32); wait_tensor_550 = None + reduce_scatter_tensor_160 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1529, 'avg', 32, '0'); convert_element_type_1529 = None + wait_tensor_551 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_160); reduce_scatter_tensor_160 = None + add_188 = torch.ops.aten.add.Tensor(add_184, convert_element_type_1527); add_184 = convert_element_type_1527 = None + all_gather_into_tensor_373 = torch.ops._c10d_functional.all_gather_into_tensor.default(add_188, 8, '1') + wait_tensor_552 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_373); all_gather_into_tensor_373 = None + split_173 = torch.ops.aten.split.Tensor(wait_tensor_552, 2); wait_tensor_552 = None + getitem_1696 = split_173[0] + getitem_1697 = split_173[1] + getitem_1698 = split_173[2] + getitem_1699 = split_173[3] + getitem_1700 = split_173[4] + getitem_1701 = split_173[5] + getitem_1702 = split_173[6] + getitem_1703 = split_173[7]; split_173 = None + cat_165 = torch.ops.aten.cat.default([getitem_1696, getitem_1697, getitem_1698, getitem_1699, getitem_1700, getitem_1701, getitem_1702, getitem_1703], 1); getitem_1696 = getitem_1697 = getitem_1698 = getitem_1699 = getitem_1700 = getitem_1701 = getitem_1702 = getitem_1703 = None + view_2521 = torch.ops.aten.view.default(cat_165, [16384, 4096]); cat_165 = None + permute_625 = torch.ops.aten.permute.default(view_2521, [1, 0]) + permute_259 = torch.ops.aten.permute.default(getitem_1023, [0, 2, 1, 3]) + view_1698 = torch.ops.aten.view.default(permute_259, [2, 8192, -1]); permute_259 = None + view_1704 = torch.ops.aten.view.default(view_1698, [16384, 512]); view_1698 = None + mm_345 = torch.ops.aten.mm.default(permute_625, view_1704); permute_625 = view_1704 = None + 
convert_element_type_776 = torch.ops.prims.convert_element_type.default(primals_215, torch.bfloat16); primals_215 = None + all_gather_into_tensor_259 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_776, 32, '0'); convert_element_type_776 = None + wait_tensor_306 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_259); all_gather_into_tensor_259 = None + permute_260 = torch.ops.aten.permute.default(wait_tensor_306, [1, 0]); wait_tensor_306 = None + permute_627 = torch.ops.aten.permute.default(permute_260, [1, 0]); permute_260 = None + mm_346 = torch.ops.aten.mm.default(view_2521, permute_627); view_2521 = permute_627 = None + view_2522 = torch.ops.aten.view.default(mm_346, [2, 8192, 512]); mm_346 = None + convert_element_type_1534 = torch.ops.prims.convert_element_type.default(mm_345, torch.float32); mm_345 = None + reduce_scatter_tensor_161 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1534, 'avg', 32, '0'); convert_element_type_1534 = None + wait_tensor_553 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_161); reduce_scatter_tensor_161 = None + view_2523 = torch.ops.aten.view.default(view_2522, [2, 8192, 4, 128]); view_2522 = None + permute_629 = torch.ops.aten.permute.default(view_2523, [0, 2, 1, 3]); view_2523 = None + convert_element_type_760 = torch.ops.prims.convert_element_type.default(primals_211, torch.bfloat16); primals_211 = None + all_gather_into_tensor_254 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_760, 32, '0'); convert_element_type_760 = None + wait_tensor_301 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_254); all_gather_into_tensor_254 = None + convert_element_type_761 = torch.ops.prims.convert_element_type.default(add_91, torch.float32); add_91 = None + pow_47 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_761, 2) + mean_46 = torch.ops.aten.mean.dim(pow_47, [2], 
True); pow_47 = None + add_92 = torch.ops.aten.add.Scalar(mean_46, 1e-05); mean_46 = None + rsqrt_46 = torch.ops.aten.rsqrt.default(add_92); add_92 = None + mul_184 = torch.ops.aten.mul.Tensor(convert_element_type_761, rsqrt_46); convert_element_type_761 = None + mul_185 = torch.ops.aten.mul.Tensor(mul_184, wait_tensor_301) + convert_element_type_762 = torch.ops.prims.convert_element_type.default(mul_185, torch.bfloat16); mul_185 = None + all_gather_into_tensor_255 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_762, 8, '1'); convert_element_type_762 = None + wait_tensor_302 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_255); all_gather_into_tensor_255 = None + split_101 = torch.ops.aten.split.Tensor(wait_tensor_302, 2); wait_tensor_302 = None + getitem_1015 = split_101[0] + getitem_1016 = split_101[1] + getitem_1017 = split_101[2] + getitem_1018 = split_101[3] + getitem_1019 = split_101[4] + getitem_1020 = split_101[5] + getitem_1021 = split_101[6] + getitem_1022 = split_101[7]; split_101 = None + cat_93 = torch.ops.aten.cat.default([getitem_1015, getitem_1016, getitem_1017, getitem_1018, getitem_1019, getitem_1020, getitem_1021, getitem_1022], 1); getitem_1015 = getitem_1016 = getitem_1017 = getitem_1018 = getitem_1019 = getitem_1020 = getitem_1021 = getitem_1022 = None + view_1671 = torch.ops.aten.view.default(cat_93, [16384, 4096]); cat_93 = None + view_1672 = torch.ops.aten.view.default(mm_161, [2, 8192, 512]); mm_161 = None + convert_element_type_766 = torch.ops.prims.convert_element_type.default(primals_213, torch.bfloat16); primals_213 = None + all_gather_into_tensor_257 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_766, 32, '0'); convert_element_type_766 = None + wait_tensor_304 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_257); all_gather_into_tensor_257 = None + permute_254 = torch.ops.aten.permute.default(wait_tensor_304, [1, 0]); 
wait_tensor_304 = None + mm_162 = torch.ops.aten.mm.default(view_1671, permute_254) + view_1679 = torch.ops.aten.view.default(mm_162, [2, 8192, 128]); mm_162 = None + view_1686 = torch.ops.aten.view.default(mm_163, [2, 8192, 128]); mm_163 = None + view_1688 = torch.ops.aten.view.default(view_1672, [2, 8192, -1, 128]); view_1672 = None + view_1689 = torch.ops.aten.view.default(view_1679, [2, 8192, -1, 128]); view_1679 = None + view_1690 = torch.ops.aten.view.default(view_1686, [2, 8192, -1, 128]); view_1686 = None + convert_element_type_772 = torch.ops.prims.convert_element_type.default(view_1688, torch.float32); view_1688 = None + view_1691 = torch.ops.aten.view.default(convert_element_type_772, [2, 8192, 4, -1, 2]); convert_element_type_772 = None + view_as_complex_46 = torch.ops.aten.view_as_complex.default(view_1691); view_1691 = None + convert_element_type_773 = torch.ops.prims.convert_element_type.default(view_1689, torch.float32); view_1689 = None + view_1692 = torch.ops.aten.view.default(convert_element_type_773, [2, 8192, 1, -1, 2]); convert_element_type_773 = None + view_as_complex_47 = torch.ops.aten.view_as_complex.default(view_1692); view_1692 = None + mul_186 = torch.ops.aten.mul.Tensor(view_as_complex_46, view_37); view_as_complex_46 = None + view_as_real_46 = torch.ops.aten.view_as_real.default(mul_186); mul_186 = None + view_1694 = torch.ops.aten.view.default(view_as_real_46, [2, 8192, 4, 128]); view_as_real_46 = None + mul_187 = torch.ops.aten.mul.Tensor(view_as_complex_47, view_37); view_as_complex_47 = None + view_as_real_47 = torch.ops.aten.view_as_real.default(mul_187); mul_187 = None + view_1695 = torch.ops.aten.view.default(view_as_real_47, [2, 8192, 1, 128]); view_as_real_47 = None + convert_element_type_774 = torch.ops.prims.convert_element_type.default(view_1694, torch.bfloat16); view_1694 = None + convert_element_type_775 = torch.ops.prims.convert_element_type.default(view_1695, torch.bfloat16); view_1695 = None + unsqueeze_46 = 
torch.ops.aten.unsqueeze.default(convert_element_type_775, 3); convert_element_type_775 = None + expand_46 = torch.ops.aten.expand.default(unsqueeze_46, [2, 8192, 1, 4, 128]); unsqueeze_46 = None + view_1696 = torch.ops.aten.view.default(expand_46, [2, 8192, 4, 128]); expand_46 = None + unsqueeze_47 = torch.ops.aten.unsqueeze.default(view_1690, 3); view_1690 = None + expand_47 = torch.ops.aten.expand.default(unsqueeze_47, [2, 8192, 1, 4, 128]); unsqueeze_47 = None + view_1697 = torch.ops.aten.view.default(expand_47, [2, 8192, 4, 128]); expand_47 = None + permute_256 = torch.ops.aten.permute.default(convert_element_type_774, [0, 2, 1, 3]); convert_element_type_774 = None + permute_257 = torch.ops.aten.permute.default(view_1696, [0, 2, 1, 3]); view_1696 = None + permute_258 = torch.ops.aten.permute.default(view_1697, [0, 2, 1, 3]); view_1697 = None + _scaled_dot_product_cudnn_attention_backward_8 = torch.ops.aten._scaled_dot_product_cudnn_attention_backward.default(permute_629, permute_256, permute_257, permute_258, getitem_1023, getitem_1024, getitem_1029, getitem_1030, None, None, None, 8192, 8192, 0.0, True); permute_629 = permute_256 = permute_257 = permute_258 = getitem_1023 = getitem_1024 = getitem_1029 = getitem_1030 = None + getitem_1704 = _scaled_dot_product_cudnn_attention_backward_8[0] + getitem_1705 = _scaled_dot_product_cudnn_attention_backward_8[1] + getitem_1706 = _scaled_dot_product_cudnn_attention_backward_8[2]; _scaled_dot_product_cudnn_attention_backward_8 = None + permute_630 = torch.ops.aten.permute.default(getitem_1706, [0, 2, 1, 3]); getitem_1706 = None + permute_631 = torch.ops.aten.permute.default(getitem_1705, [0, 2, 1, 3]); getitem_1705 = None + permute_632 = torch.ops.aten.permute.default(getitem_1704, [0, 2, 1, 3]); getitem_1704 = None + view_2524 = torch.ops.aten.view.default(permute_630, [2, 8192, 1, 4, 128]); permute_630 = None + sum_53 = torch.ops.aten.sum.dim_IntList(view_2524, [3], True); view_2524 = None + squeeze_16 = 
torch.ops.aten.squeeze.dim(sum_53, 3); sum_53 = None + view_2525 = torch.ops.aten.view.default(permute_631, [2, 8192, 1, 4, 128]); permute_631 = None + sum_54 = torch.ops.aten.sum.dim_IntList(view_2525, [3], True); view_2525 = None + squeeze_17 = torch.ops.aten.squeeze.dim(sum_54, 3); sum_54 = None + convert_element_type_1535 = torch.ops.prims.convert_element_type.default(squeeze_17, torch.float32); squeeze_17 = None + convert_element_type_1536 = torch.ops.prims.convert_element_type.default(permute_632, torch.float32); permute_632 = None + view_2526 = torch.ops.aten.view.default(convert_element_type_1535, [2, 8192, 1, 64, 2]); convert_element_type_1535 = None + view_as_complex_80 = torch.ops.aten.view_as_complex.default(view_2526); view_2526 = None + mul_436 = torch.ops.aten.mul.Tensor(view_as_complex_80, _conj); view_as_complex_80 = None + view_2527 = torch.ops.aten.view.default(convert_element_type_1536, [2, 8192, 4, 64, 2]); convert_element_type_1536 = None + view_as_complex_81 = torch.ops.aten.view_as_complex.default(view_2527); view_2527 = None + mul_437 = torch.ops.aten.mul.Tensor(view_as_complex_81, _conj); view_as_complex_81 = None + view_as_real_80 = torch.ops.aten.view_as_real.default(mul_436); mul_436 = None + view_2528 = torch.ops.aten.view.default(view_as_real_80, [2, 8192, 1, 128]); view_as_real_80 = None + convert_element_type_1537 = torch.ops.prims.convert_element_type.default(view_2528, torch.bfloat16); view_2528 = None + view_as_real_81 = torch.ops.aten.view_as_real.default(mul_437); mul_437 = None + view_2529 = torch.ops.aten.view.default(view_as_real_81, [2, 8192, 4, 128]); view_as_real_81 = None + convert_element_type_1538 = torch.ops.prims.convert_element_type.default(view_2529, torch.bfloat16); view_2529 = None + view_2530 = torch.ops.aten.view.default(squeeze_16, [2, 8192, 128]); squeeze_16 = None + view_2531 = torch.ops.aten.view.default(convert_element_type_1537, [2, 8192, 128]); convert_element_type_1537 = None + view_2532 = 
torch.ops.aten.view.default(convert_element_type_1538, [2, 8192, 512]); convert_element_type_1538 = None + view_2533 = torch.ops.aten.view.default(view_2530, [16384, 128]); view_2530 = None + permute_633 = torch.ops.aten.permute.default(view_2533, [1, 0]) + mm_347 = torch.ops.aten.mm.default(permute_633, view_1671); permute_633 = None + convert_element_type_769 = torch.ops.prims.convert_element_type.default(primals_214, torch.bfloat16); primals_214 = None + all_gather_into_tensor_258 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_769, 32, '0'); convert_element_type_769 = None + wait_tensor_305 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_258); all_gather_into_tensor_258 = None + permute_255 = torch.ops.aten.permute.default(wait_tensor_305, [1, 0]); wait_tensor_305 = None + permute_635 = torch.ops.aten.permute.default(permute_255, [1, 0]); permute_255 = None + mm_348 = torch.ops.aten.mm.default(view_2533, permute_635); view_2533 = permute_635 = None + view_2534 = torch.ops.aten.view.default(mm_348, [2, 8192, 4096]); mm_348 = None + convert_element_type_1543 = torch.ops.prims.convert_element_type.default(mm_347, torch.float32); mm_347 = None + reduce_scatter_tensor_162 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1543, 'avg', 32, '0'); convert_element_type_1543 = None + wait_tensor_554 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_162); reduce_scatter_tensor_162 = None + view_2535 = torch.ops.aten.view.default(view_2531, [16384, 128]); view_2531 = None + permute_637 = torch.ops.aten.permute.default(view_2535, [1, 0]) + mm_349 = torch.ops.aten.mm.default(permute_637, view_1671); permute_637 = None + permute_639 = torch.ops.aten.permute.default(permute_254, [1, 0]); permute_254 = None + mm_350 = torch.ops.aten.mm.default(view_2535, permute_639); view_2535 = permute_639 = None + view_2536 = torch.ops.aten.view.default(mm_350, [2, 8192, 4096]); mm_350 
= None + add_189 = torch.ops.aten.add.Tensor(view_2534, view_2536); view_2534 = view_2536 = None + convert_element_type_1548 = torch.ops.prims.convert_element_type.default(mm_349, torch.float32); mm_349 = None + reduce_scatter_tensor_163 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1548, 'avg', 32, '0'); convert_element_type_1548 = None + wait_tensor_555 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_163); reduce_scatter_tensor_163 = None + view_2537 = torch.ops.aten.view.default(view_2532, [16384, 512]); view_2532 = None + permute_641 = torch.ops.aten.permute.default(view_2537, [1, 0]) + mm_351 = torch.ops.aten.mm.default(permute_641, view_1671); permute_641 = view_1671 = None + convert_element_type_763 = torch.ops.prims.convert_element_type.default(primals_212, torch.bfloat16); primals_212 = None + all_gather_into_tensor_256 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_763, 32, '0'); convert_element_type_763 = None + wait_tensor_303 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_256); all_gather_into_tensor_256 = None + permute_253 = torch.ops.aten.permute.default(wait_tensor_303, [1, 0]); wait_tensor_303 = None + permute_643 = torch.ops.aten.permute.default(permute_253, [1, 0]); permute_253 = None + mm_352 = torch.ops.aten.mm.default(view_2537, permute_643); view_2537 = permute_643 = None + view_2538 = torch.ops.aten.view.default(mm_352, [2, 8192, 4096]); mm_352 = None + add_190 = torch.ops.aten.add.Tensor(add_189, view_2538); add_189 = view_2538 = None + convert_element_type_1553 = torch.ops.prims.convert_element_type.default(mm_351, torch.float32); mm_351 = None + reduce_scatter_tensor_164 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1553, 'avg', 32, '0'); convert_element_type_1553 = None + wait_tensor_556 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_164); reduce_scatter_tensor_164 
= None + split_174 = torch.ops.aten.split.Tensor(add_190, 1024, 1); add_190 = None + getitem_1707 = split_174[0] + getitem_1708 = split_174[1] + getitem_1709 = split_174[2] + getitem_1710 = split_174[3] + getitem_1711 = split_174[4] + getitem_1712 = split_174[5] + getitem_1713 = split_174[6] + getitem_1714 = split_174[7]; split_174 = None + cat_166 = torch.ops.aten.cat.default([getitem_1707, getitem_1708, getitem_1709, getitem_1710, getitem_1711, getitem_1712, getitem_1713, getitem_1714]); getitem_1707 = getitem_1708 = getitem_1709 = getitem_1710 = getitem_1711 = getitem_1712 = getitem_1713 = getitem_1714 = None + reduce_scatter_tensor_165 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_166, 'sum', 8, '1'); cat_166 = None + wait_tensor_557 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_165); reduce_scatter_tensor_165 = None + convert_element_type_1554 = torch.ops.prims.convert_element_type.default(wait_tensor_557, torch.float32); wait_tensor_557 = None + convert_element_type_1556 = torch.ops.prims.convert_element_type.default(wait_tensor_301, torch.float32); wait_tensor_301 = None + mul_438 = torch.ops.aten.mul.Tensor(convert_element_type_1554, convert_element_type_1556); convert_element_type_1556 = None + mul_440 = torch.ops.aten.mul.Tensor(mul_184, mul_438) + sum_55 = torch.ops.aten.sum.dim_IntList(mul_440, [2], True); mul_440 = None + div_18 = torch.ops.aten.div.Tensor(mul_184, 4096) + mul_441 = torch.ops.aten.mul.Tensor(div_18, sum_55); div_18 = sum_55 = None + sub_28 = torch.ops.aten.sub.Tensor(mul_438, mul_441); mul_438 = mul_441 = None + mul_442 = torch.ops.aten.mul.Tensor(sub_28, rsqrt_46); sub_28 = rsqrt_46 = None + mul_443 = torch.ops.aten.mul.Tensor(convert_element_type_1554, mul_184); convert_element_type_1554 = mul_184 = None + sum_56 = torch.ops.aten.sum.dim_IntList(mul_443, [0, 1]); mul_443 = None + convert_element_type_1557 = torch.ops.prims.convert_element_type.default(mul_442, torch.bfloat16); mul_442 = None 
+ convert_element_type_1558 = torch.ops.prims.convert_element_type.default(sum_56, torch.bfloat16); sum_56 = None + all_reduce_18 = torch.ops._c10d_functional.all_reduce.default(convert_element_type_1558, 'sum', '1'); convert_element_type_1558 = None + wait_tensor_558 = torch.ops._c10d_functional.wait_tensor.default(all_reduce_18); all_reduce_18 = None + convert_element_type_1559 = torch.ops.prims.convert_element_type.default(wait_tensor_558, torch.float32); wait_tensor_558 = None + reduce_scatter_tensor_166 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1559, 'avg', 32, '0'); convert_element_type_1559 = None + wait_tensor_559 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_166); reduce_scatter_tensor_166 = None + add_191 = torch.ops.aten.add.Tensor(add_188, convert_element_type_1557); add_188 = convert_element_type_1557 = None + all_gather_into_tensor_374 = torch.ops._c10d_functional.all_gather_into_tensor.default(add_191, 8, '1') + wait_tensor_560 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_374); all_gather_into_tensor_374 = None + split_175 = torch.ops.aten.split.Tensor(wait_tensor_560, 2); wait_tensor_560 = None + getitem_1715 = split_175[0] + getitem_1716 = split_175[1] + getitem_1717 = split_175[2] + getitem_1718 = split_175[3] + getitem_1719 = split_175[4] + getitem_1720 = split_175[5] + getitem_1721 = split_175[6] + getitem_1722 = split_175[7]; split_175 = None + cat_167 = torch.ops.aten.cat.default([getitem_1715, getitem_1716, getitem_1717, getitem_1718, getitem_1719, getitem_1720, getitem_1721, getitem_1722], 1); getitem_1715 = getitem_1716 = getitem_1717 = getitem_1718 = getitem_1719 = getitem_1720 = getitem_1721 = getitem_1722 = None + view_2539 = torch.ops.aten.view.default(cat_167, [16384, 4096]); cat_167 = None + permute_645 = torch.ops.aten.permute.default(view_2539, [1, 0]) + wait_tensor_294 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_45); 
reduce_scatter_tensor_45 = None + add_89 = torch.ops.aten.add.Tensor(add_87, wait_tensor_294); wait_tensor_294 = None + convert_element_type_746 = torch.ops.prims.convert_element_type.default(primals_207, torch.bfloat16); primals_207 = None + all_gather_into_tensor_249 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_746, 32, '0'); convert_element_type_746 = None + wait_tensor_295 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_249); all_gather_into_tensor_249 = None + convert_element_type_747 = torch.ops.prims.convert_element_type.default(add_89, torch.float32); add_89 = None + pow_46 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_747, 2) + mean_45 = torch.ops.aten.mean.dim(pow_46, [2], True); pow_46 = None + add_90 = torch.ops.aten.add.Scalar(mean_45, 1e-05); mean_45 = None + rsqrt_45 = torch.ops.aten.rsqrt.default(add_90); add_90 = None + mul_180 = torch.ops.aten.mul.Tensor(convert_element_type_747, rsqrt_45); convert_element_type_747 = None + mul_181 = torch.ops.aten.mul.Tensor(mul_180, wait_tensor_295) + convert_element_type_748 = torch.ops.prims.convert_element_type.default(mul_181, torch.bfloat16); mul_181 = None + all_gather_into_tensor_250 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_748, 8, '1'); convert_element_type_748 = None + wait_tensor_296 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_250); all_gather_into_tensor_250 = None + split_99 = torch.ops.aten.split.Tensor(wait_tensor_296, 2); wait_tensor_296 = None + getitem_999 = split_99[0] + getitem_1000 = split_99[1] + getitem_1001 = split_99[2] + getitem_1002 = split_99[3] + getitem_1003 = split_99[4] + getitem_1004 = split_99[5] + getitem_1005 = split_99[6] + getitem_1006 = split_99[7]; split_99 = None + cat_91 = torch.ops.aten.cat.default([getitem_999, getitem_1000, getitem_1001, getitem_1002, getitem_1003, getitem_1004, getitem_1005, getitem_1006], 1); getitem_999 = 
getitem_1000 = getitem_1001 = getitem_1002 = getitem_1003 = getitem_1004 = getitem_1005 = getitem_1006 = None + view_1644 = torch.ops.aten.view.default(cat_91, [16384, 4096]); cat_91 = None + view_1645 = torch.ops.aten.view.default(mm_158, [2, 8192, 1792]); mm_158 = None + convert_element_type_752 = torch.ops.prims.convert_element_type.default(view_1645, torch.float32); view_1645 = None + sigmoid_22 = torch.ops.aten.sigmoid.default(convert_element_type_752) + mul_182 = torch.ops.aten.mul.Tensor(convert_element_type_752, sigmoid_22); sigmoid_22 = None + convert_element_type_753 = torch.ops.prims.convert_element_type.default(mul_182, torch.bfloat16); mul_182 = None + convert_element_type_754 = torch.ops.prims.convert_element_type.default(primals_209, torch.bfloat16); primals_209 = None + all_gather_into_tensor_252 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_754, 32, '0'); convert_element_type_754 = None + wait_tensor_298 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_252); all_gather_into_tensor_252 = None + permute_251 = torch.ops.aten.permute.default(wait_tensor_298, [1, 0]); wait_tensor_298 = None + mm_159 = torch.ops.aten.mm.default(view_1644, permute_251) + view_1652 = torch.ops.aten.view.default(mm_159, [2, 8192, 1792]); mm_159 = None + mul_183 = torch.ops.aten.mul.Tensor(convert_element_type_753, view_1652) + view_1659 = torch.ops.aten.view.default(mul_183, [16384, 1792]); mul_183 = None + mm_353 = torch.ops.aten.mm.default(permute_645, view_1659); permute_645 = view_1659 = None + convert_element_type_757 = torch.ops.prims.convert_element_type.default(primals_210, torch.bfloat16); primals_210 = None + all_gather_into_tensor_253 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_757, 32, '0'); convert_element_type_757 = None + wait_tensor_299 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_253); all_gather_into_tensor_253 = None + permute_252 = 
torch.ops.aten.permute.default(wait_tensor_299, [1, 0]); wait_tensor_299 = None + permute_647 = torch.ops.aten.permute.default(permute_252, [1, 0]); permute_252 = None + mm_354 = torch.ops.aten.mm.default(view_2539, permute_647); view_2539 = permute_647 = None + view_2540 = torch.ops.aten.view.default(mm_354, [2, 8192, 1792]); mm_354 = None + convert_element_type_1564 = torch.ops.prims.convert_element_type.default(mm_353, torch.float32); mm_353 = None + reduce_scatter_tensor_167 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1564, 'avg', 32, '0'); convert_element_type_1564 = None + wait_tensor_561 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_167); reduce_scatter_tensor_167 = None + mul_444 = torch.ops.aten.mul.Tensor(view_2540, convert_element_type_753); convert_element_type_753 = None + mul_445 = torch.ops.aten.mul.Tensor(view_2540, view_1652); view_2540 = view_1652 = None + view_2541 = torch.ops.aten.view.default(mul_444, [16384, 1792]); mul_444 = None + permute_649 = torch.ops.aten.permute.default(view_2541, [1, 0]) + mm_355 = torch.ops.aten.mm.default(permute_649, view_1644); permute_649 = None + permute_651 = torch.ops.aten.permute.default(permute_251, [1, 0]); permute_251 = None + mm_356 = torch.ops.aten.mm.default(view_2541, permute_651); view_2541 = permute_651 = None + view_2542 = torch.ops.aten.view.default(mm_356, [2, 8192, 4096]); mm_356 = None + convert_element_type_1569 = torch.ops.prims.convert_element_type.default(mm_355, torch.float32); mm_355 = None + reduce_scatter_tensor_168 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1569, 'avg', 32, '0'); convert_element_type_1569 = None + wait_tensor_562 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_168); reduce_scatter_tensor_168 = None + convert_element_type_1570 = torch.ops.prims.convert_element_type.default(mul_445, torch.float32); mul_445 = None + neg_9 = 
torch.ops.aten.neg.default(convert_element_type_752) + exp_9 = torch.ops.aten.exp.default(neg_9); neg_9 = None + add_192 = torch.ops.aten.add.Tensor(exp_9, 1); exp_9 = None + reciprocal_9 = torch.ops.aten.reciprocal.default(add_192); add_192 = None + mul_446 = torch.ops.aten.mul.Tensor(reciprocal_9, 1); reciprocal_9 = None + mul_447 = torch.ops.aten.mul.Tensor(convert_element_type_1570, mul_446); convert_element_type_1570 = None + sub_29 = torch.ops.aten.sub.Tensor(1, mul_446); mul_446 = None + mul_448 = torch.ops.aten.mul.Tensor(convert_element_type_752, sub_29); convert_element_type_752 = sub_29 = None + add_193 = torch.ops.aten.add.Tensor(mul_448, 1); mul_448 = None + mul_449 = torch.ops.aten.mul.Tensor(mul_447, add_193); mul_447 = add_193 = None + convert_element_type_1572 = torch.ops.prims.convert_element_type.default(mul_449, torch.bfloat16); mul_449 = None + view_2543 = torch.ops.aten.view.default(convert_element_type_1572, [16384, 1792]); convert_element_type_1572 = None + permute_653 = torch.ops.aten.permute.default(view_2543, [1, 0]) + mm_357 = torch.ops.aten.mm.default(permute_653, view_1644); permute_653 = view_1644 = None + convert_element_type_749 = torch.ops.prims.convert_element_type.default(primals_208, torch.bfloat16); primals_208 = None + all_gather_into_tensor_251 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_749, 32, '0'); convert_element_type_749 = None + wait_tensor_297 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_251); all_gather_into_tensor_251 = None + permute_250 = torch.ops.aten.permute.default(wait_tensor_297, [1, 0]); wait_tensor_297 = None + permute_655 = torch.ops.aten.permute.default(permute_250, [1, 0]); permute_250 = None + mm_358 = torch.ops.aten.mm.default(view_2543, permute_655); view_2543 = permute_655 = None + view_2544 = torch.ops.aten.view.default(mm_358, [2, 8192, 4096]); mm_358 = None + add_194 = torch.ops.aten.add.Tensor(view_2542, view_2544); view_2542 = 
view_2544 = None + convert_element_type_1577 = torch.ops.prims.convert_element_type.default(mm_357, torch.float32); mm_357 = None + reduce_scatter_tensor_169 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1577, 'avg', 32, '0'); convert_element_type_1577 = None + wait_tensor_563 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_169); reduce_scatter_tensor_169 = None + split_176 = torch.ops.aten.split.Tensor(add_194, 1024, 1); add_194 = None + getitem_1723 = split_176[0] + getitem_1724 = split_176[1] + getitem_1725 = split_176[2] + getitem_1726 = split_176[3] + getitem_1727 = split_176[4] + getitem_1728 = split_176[5] + getitem_1729 = split_176[6] + getitem_1730 = split_176[7]; split_176 = None + cat_168 = torch.ops.aten.cat.default([getitem_1723, getitem_1724, getitem_1725, getitem_1726, getitem_1727, getitem_1728, getitem_1729, getitem_1730]); getitem_1723 = getitem_1724 = getitem_1725 = getitem_1726 = getitem_1727 = getitem_1728 = getitem_1729 = getitem_1730 = None + reduce_scatter_tensor_170 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_168, 'sum', 8, '1'); cat_168 = None + wait_tensor_564 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_170); reduce_scatter_tensor_170 = None + convert_element_type_1578 = torch.ops.prims.convert_element_type.default(wait_tensor_564, torch.float32); wait_tensor_564 = None + convert_element_type_1580 = torch.ops.prims.convert_element_type.default(wait_tensor_295, torch.float32); wait_tensor_295 = None + mul_450 = torch.ops.aten.mul.Tensor(convert_element_type_1578, convert_element_type_1580); convert_element_type_1580 = None + mul_452 = torch.ops.aten.mul.Tensor(mul_180, mul_450) + sum_57 = torch.ops.aten.sum.dim_IntList(mul_452, [2], True); mul_452 = None + div_19 = torch.ops.aten.div.Tensor(mul_180, 4096) + mul_453 = torch.ops.aten.mul.Tensor(div_19, sum_57); div_19 = sum_57 = None + sub_30 = torch.ops.aten.sub.Tensor(mul_450, mul_453); 
mul_450 = mul_453 = None + mul_454 = torch.ops.aten.mul.Tensor(sub_30, rsqrt_45); sub_30 = rsqrt_45 = None + mul_455 = torch.ops.aten.mul.Tensor(convert_element_type_1578, mul_180); convert_element_type_1578 = mul_180 = None + sum_58 = torch.ops.aten.sum.dim_IntList(mul_455, [0, 1]); mul_455 = None + convert_element_type_1581 = torch.ops.prims.convert_element_type.default(mul_454, torch.bfloat16); mul_454 = None + convert_element_type_1582 = torch.ops.prims.convert_element_type.default(sum_58, torch.bfloat16); sum_58 = None + all_reduce_19 = torch.ops._c10d_functional.all_reduce.default(convert_element_type_1582, 'sum', '1'); convert_element_type_1582 = None + wait_tensor_565 = torch.ops._c10d_functional.wait_tensor.default(all_reduce_19); all_reduce_19 = None + convert_element_type_1583 = torch.ops.prims.convert_element_type.default(wait_tensor_565, torch.float32); wait_tensor_565 = None + reduce_scatter_tensor_171 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1583, 'avg', 32, '0'); convert_element_type_1583 = None + wait_tensor_566 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_171); reduce_scatter_tensor_171 = None + add_195 = torch.ops.aten.add.Tensor(add_191, convert_element_type_1581); add_191 = convert_element_type_1581 = None + all_gather_into_tensor_375 = torch.ops._c10d_functional.all_gather_into_tensor.default(add_195, 8, '1') + wait_tensor_567 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_375); all_gather_into_tensor_375 = None + split_177 = torch.ops.aten.split.Tensor(wait_tensor_567, 2); wait_tensor_567 = None + getitem_1731 = split_177[0] + getitem_1732 = split_177[1] + getitem_1733 = split_177[2] + getitem_1734 = split_177[3] + getitem_1735 = split_177[4] + getitem_1736 = split_177[5] + getitem_1737 = split_177[6] + getitem_1738 = split_177[7]; split_177 = None + cat_169 = torch.ops.aten.cat.default([getitem_1731, getitem_1732, getitem_1733, getitem_1734, 
getitem_1735, getitem_1736, getitem_1737, getitem_1738], 1); getitem_1731 = getitem_1732 = getitem_1733 = getitem_1734 = getitem_1735 = getitem_1736 = getitem_1737 = getitem_1738 = None + view_2545 = torch.ops.aten.view.default(cat_169, [16384, 4096]); cat_169 = None + permute_657 = torch.ops.aten.permute.default(view_2545, [1, 0]) + permute_248 = torch.ops.aten.permute.default(getitem_982, [0, 2, 1, 3]) + view_1626 = torch.ops.aten.view.default(permute_248, [2, 8192, -1]); permute_248 = None + view_1632 = torch.ops.aten.view.default(view_1626, [16384, 512]); view_1626 = None + mm_359 = torch.ops.aten.mm.default(permute_657, view_1632); permute_657 = view_1632 = None + convert_element_type_743 = torch.ops.prims.convert_element_type.default(primals_206, torch.bfloat16); primals_206 = None + all_gather_into_tensor_248 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_743, 32, '0'); convert_element_type_743 = None + wait_tensor_293 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_248); all_gather_into_tensor_248 = None + permute_249 = torch.ops.aten.permute.default(wait_tensor_293, [1, 0]); wait_tensor_293 = None + permute_659 = torch.ops.aten.permute.default(permute_249, [1, 0]); permute_249 = None + mm_360 = torch.ops.aten.mm.default(view_2545, permute_659); view_2545 = permute_659 = None + view_2546 = torch.ops.aten.view.default(mm_360, [2, 8192, 512]); mm_360 = None + convert_element_type_1588 = torch.ops.prims.convert_element_type.default(mm_359, torch.float32); mm_359 = None + reduce_scatter_tensor_172 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1588, 'avg', 32, '0'); convert_element_type_1588 = None + wait_tensor_568 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_172); reduce_scatter_tensor_172 = None + view_2547 = torch.ops.aten.view.default(view_2546, [2, 8192, 4, 128]); view_2546 = None + permute_661 = torch.ops.aten.permute.default(view_2547, 
[0, 2, 1, 3]); view_2547 = None + convert_element_type_727 = torch.ops.prims.convert_element_type.default(primals_202, torch.bfloat16); primals_202 = None + all_gather_into_tensor_243 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_727, 32, '0'); convert_element_type_727 = None + wait_tensor_288 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_243); all_gather_into_tensor_243 = None + convert_element_type_728 = torch.ops.prims.convert_element_type.default(add_87, torch.float32); add_87 = None + pow_45 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_728, 2) + mean_44 = torch.ops.aten.mean.dim(pow_45, [2], True); pow_45 = None + add_88 = torch.ops.aten.add.Scalar(mean_44, 1e-05); mean_44 = None + rsqrt_44 = torch.ops.aten.rsqrt.default(add_88); add_88 = None + mul_176 = torch.ops.aten.mul.Tensor(convert_element_type_728, rsqrt_44); convert_element_type_728 = None + mul_177 = torch.ops.aten.mul.Tensor(mul_176, wait_tensor_288) + convert_element_type_729 = torch.ops.prims.convert_element_type.default(mul_177, torch.bfloat16); mul_177 = None + all_gather_into_tensor_244 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_729, 8, '1'); convert_element_type_729 = None + wait_tensor_289 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_244); all_gather_into_tensor_244 = None + split_97 = torch.ops.aten.split.Tensor(wait_tensor_289, 2); wait_tensor_289 = None + getitem_974 = split_97[0] + getitem_975 = split_97[1] + getitem_976 = split_97[2] + getitem_977 = split_97[3] + getitem_978 = split_97[4] + getitem_979 = split_97[5] + getitem_980 = split_97[6] + getitem_981 = split_97[7]; split_97 = None + cat_89 = torch.ops.aten.cat.default([getitem_974, getitem_975, getitem_976, getitem_977, getitem_978, getitem_979, getitem_980, getitem_981], 1); getitem_974 = getitem_975 = getitem_976 = getitem_977 = getitem_978 = getitem_979 = getitem_980 = getitem_981 = None + 
view_1599 = torch.ops.aten.view.default(cat_89, [16384, 4096]); cat_89 = None + view_1600 = torch.ops.aten.view.default(mm_154, [2, 8192, 512]); mm_154 = None + convert_element_type_733 = torch.ops.prims.convert_element_type.default(primals_204, torch.bfloat16); primals_204 = None + all_gather_into_tensor_246 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_733, 32, '0'); convert_element_type_733 = None + wait_tensor_291 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_246); all_gather_into_tensor_246 = None + permute_243 = torch.ops.aten.permute.default(wait_tensor_291, [1, 0]); wait_tensor_291 = None + mm_155 = torch.ops.aten.mm.default(view_1599, permute_243) + view_1607 = torch.ops.aten.view.default(mm_155, [2, 8192, 128]); mm_155 = None + view_1614 = torch.ops.aten.view.default(mm_156, [2, 8192, 128]); mm_156 = None + view_1616 = torch.ops.aten.view.default(view_1600, [2, 8192, -1, 128]); view_1600 = None + view_1617 = torch.ops.aten.view.default(view_1607, [2, 8192, -1, 128]); view_1607 = None + view_1618 = torch.ops.aten.view.default(view_1614, [2, 8192, -1, 128]); view_1614 = None + convert_element_type_739 = torch.ops.prims.convert_element_type.default(view_1616, torch.float32); view_1616 = None + view_1619 = torch.ops.aten.view.default(convert_element_type_739, [2, 8192, 4, -1, 2]); convert_element_type_739 = None + view_as_complex_44 = torch.ops.aten.view_as_complex.default(view_1619); view_1619 = None + convert_element_type_740 = torch.ops.prims.convert_element_type.default(view_1617, torch.float32); view_1617 = None + view_1620 = torch.ops.aten.view.default(convert_element_type_740, [2, 8192, 1, -1, 2]); convert_element_type_740 = None + view_as_complex_45 = torch.ops.aten.view_as_complex.default(view_1620); view_1620 = None + mul_178 = torch.ops.aten.mul.Tensor(view_as_complex_44, view_37); view_as_complex_44 = None + view_as_real_44 = torch.ops.aten.view_as_real.default(mul_178); mul_178 = None + 
view_1622 = torch.ops.aten.view.default(view_as_real_44, [2, 8192, 4, 128]); view_as_real_44 = None + mul_179 = torch.ops.aten.mul.Tensor(view_as_complex_45, view_37); view_as_complex_45 = None + view_as_real_45 = torch.ops.aten.view_as_real.default(mul_179); mul_179 = None + view_1623 = torch.ops.aten.view.default(view_as_real_45, [2, 8192, 1, 128]); view_as_real_45 = None + convert_element_type_741 = torch.ops.prims.convert_element_type.default(view_1622, torch.bfloat16); view_1622 = None + convert_element_type_742 = torch.ops.prims.convert_element_type.default(view_1623, torch.bfloat16); view_1623 = None + unsqueeze_44 = torch.ops.aten.unsqueeze.default(convert_element_type_742, 3); convert_element_type_742 = None + expand_44 = torch.ops.aten.expand.default(unsqueeze_44, [2, 8192, 1, 4, 128]); unsqueeze_44 = None + view_1624 = torch.ops.aten.view.default(expand_44, [2, 8192, 4, 128]); expand_44 = None + unsqueeze_45 = torch.ops.aten.unsqueeze.default(view_1618, 3); view_1618 = None + expand_45 = torch.ops.aten.expand.default(unsqueeze_45, [2, 8192, 1, 4, 128]); unsqueeze_45 = None + view_1625 = torch.ops.aten.view.default(expand_45, [2, 8192, 4, 128]); expand_45 = None + permute_245 = torch.ops.aten.permute.default(convert_element_type_741, [0, 2, 1, 3]); convert_element_type_741 = None + permute_246 = torch.ops.aten.permute.default(view_1624, [0, 2, 1, 3]); view_1624 = None + permute_247 = torch.ops.aten.permute.default(view_1625, [0, 2, 1, 3]); view_1625 = None + _scaled_dot_product_cudnn_attention_backward_9 = torch.ops.aten._scaled_dot_product_cudnn_attention_backward.default(permute_661, permute_245, permute_246, permute_247, getitem_982, getitem_983, getitem_988, getitem_989, None, None, None, 8192, 8192, 0.0, True); permute_661 = permute_245 = permute_246 = permute_247 = getitem_982 = getitem_983 = getitem_988 = getitem_989 = None + getitem_1739 = _scaled_dot_product_cudnn_attention_backward_9[0] + getitem_1740 = 
_scaled_dot_product_cudnn_attention_backward_9[1] + getitem_1741 = _scaled_dot_product_cudnn_attention_backward_9[2]; _scaled_dot_product_cudnn_attention_backward_9 = None + permute_662 = torch.ops.aten.permute.default(getitem_1741, [0, 2, 1, 3]); getitem_1741 = None + permute_663 = torch.ops.aten.permute.default(getitem_1740, [0, 2, 1, 3]); getitem_1740 = None + permute_664 = torch.ops.aten.permute.default(getitem_1739, [0, 2, 1, 3]); getitem_1739 = None + view_2548 = torch.ops.aten.view.default(permute_662, [2, 8192, 1, 4, 128]); permute_662 = None + sum_59 = torch.ops.aten.sum.dim_IntList(view_2548, [3], True); view_2548 = None + squeeze_18 = torch.ops.aten.squeeze.dim(sum_59, 3); sum_59 = None + view_2549 = torch.ops.aten.view.default(permute_663, [2, 8192, 1, 4, 128]); permute_663 = None + sum_60 = torch.ops.aten.sum.dim_IntList(view_2549, [3], True); view_2549 = None + squeeze_19 = torch.ops.aten.squeeze.dim(sum_60, 3); sum_60 = None + convert_element_type_1589 = torch.ops.prims.convert_element_type.default(squeeze_19, torch.float32); squeeze_19 = None + convert_element_type_1590 = torch.ops.prims.convert_element_type.default(permute_664, torch.float32); permute_664 = None + view_2550 = torch.ops.aten.view.default(convert_element_type_1589, [2, 8192, 1, 64, 2]); convert_element_type_1589 = None + view_as_complex_82 = torch.ops.aten.view_as_complex.default(view_2550); view_2550 = None + mul_456 = torch.ops.aten.mul.Tensor(view_as_complex_82, _conj); view_as_complex_82 = None + view_2551 = torch.ops.aten.view.default(convert_element_type_1590, [2, 8192, 4, 64, 2]); convert_element_type_1590 = None + view_as_complex_83 = torch.ops.aten.view_as_complex.default(view_2551); view_2551 = None + mul_457 = torch.ops.aten.mul.Tensor(view_as_complex_83, _conj); view_as_complex_83 = None + view_as_real_82 = torch.ops.aten.view_as_real.default(mul_456); mul_456 = None + view_2552 = torch.ops.aten.view.default(view_as_real_82, [2, 8192, 1, 128]); view_as_real_82 = None + 
convert_element_type_1591 = torch.ops.prims.convert_element_type.default(view_2552, torch.bfloat16); view_2552 = None + view_as_real_83 = torch.ops.aten.view_as_real.default(mul_457); mul_457 = None + view_2553 = torch.ops.aten.view.default(view_as_real_83, [2, 8192, 4, 128]); view_as_real_83 = None + convert_element_type_1592 = torch.ops.prims.convert_element_type.default(view_2553, torch.bfloat16); view_2553 = None + view_2554 = torch.ops.aten.view.default(squeeze_18, [2, 8192, 128]); squeeze_18 = None + view_2555 = torch.ops.aten.view.default(convert_element_type_1591, [2, 8192, 128]); convert_element_type_1591 = None + view_2556 = torch.ops.aten.view.default(convert_element_type_1592, [2, 8192, 512]); convert_element_type_1592 = None + view_2557 = torch.ops.aten.view.default(view_2554, [16384, 128]); view_2554 = None + permute_665 = torch.ops.aten.permute.default(view_2557, [1, 0]) + mm_361 = torch.ops.aten.mm.default(permute_665, view_1599); permute_665 = None + convert_element_type_736 = torch.ops.prims.convert_element_type.default(primals_205, torch.bfloat16); primals_205 = None + all_gather_into_tensor_247 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_736, 32, '0'); convert_element_type_736 = None + wait_tensor_292 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_247); all_gather_into_tensor_247 = None + permute_244 = torch.ops.aten.permute.default(wait_tensor_292, [1, 0]); wait_tensor_292 = None + permute_667 = torch.ops.aten.permute.default(permute_244, [1, 0]); permute_244 = None + mm_362 = torch.ops.aten.mm.default(view_2557, permute_667); view_2557 = permute_667 = None + view_2558 = torch.ops.aten.view.default(mm_362, [2, 8192, 4096]); mm_362 = None + convert_element_type_1597 = torch.ops.prims.convert_element_type.default(mm_361, torch.float32); mm_361 = None + reduce_scatter_tensor_173 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1597, 'avg', 32, '0'); 
convert_element_type_1597 = None + wait_tensor_569 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_173); reduce_scatter_tensor_173 = None + view_2559 = torch.ops.aten.view.default(view_2555, [16384, 128]); view_2555 = None + permute_669 = torch.ops.aten.permute.default(view_2559, [1, 0]) + mm_363 = torch.ops.aten.mm.default(permute_669, view_1599); permute_669 = None + permute_671 = torch.ops.aten.permute.default(permute_243, [1, 0]); permute_243 = None + mm_364 = torch.ops.aten.mm.default(view_2559, permute_671); view_2559 = permute_671 = None + view_2560 = torch.ops.aten.view.default(mm_364, [2, 8192, 4096]); mm_364 = None + add_196 = torch.ops.aten.add.Tensor(view_2558, view_2560); view_2558 = view_2560 = None + convert_element_type_1602 = torch.ops.prims.convert_element_type.default(mm_363, torch.float32); mm_363 = None + reduce_scatter_tensor_174 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1602, 'avg', 32, '0'); convert_element_type_1602 = None + wait_tensor_570 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_174); reduce_scatter_tensor_174 = None + view_2561 = torch.ops.aten.view.default(view_2556, [16384, 512]); view_2556 = None + permute_673 = torch.ops.aten.permute.default(view_2561, [1, 0]) + mm_365 = torch.ops.aten.mm.default(permute_673, view_1599); permute_673 = view_1599 = None + convert_element_type_730 = torch.ops.prims.convert_element_type.default(primals_203, torch.bfloat16); primals_203 = None + all_gather_into_tensor_245 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_730, 32, '0'); convert_element_type_730 = None + wait_tensor_290 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_245); all_gather_into_tensor_245 = None + permute_242 = torch.ops.aten.permute.default(wait_tensor_290, [1, 0]); wait_tensor_290 = None + permute_675 = torch.ops.aten.permute.default(permute_242, [1, 0]); permute_242 = None + mm_366 = 
torch.ops.aten.mm.default(view_2561, permute_675); view_2561 = permute_675 = None + view_2562 = torch.ops.aten.view.default(mm_366, [2, 8192, 4096]); mm_366 = None + add_197 = torch.ops.aten.add.Tensor(add_196, view_2562); add_196 = view_2562 = None + convert_element_type_1607 = torch.ops.prims.convert_element_type.default(mm_365, torch.float32); mm_365 = None + reduce_scatter_tensor_175 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1607, 'avg', 32, '0'); convert_element_type_1607 = None + wait_tensor_571 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_175); reduce_scatter_tensor_175 = None + split_178 = torch.ops.aten.split.Tensor(add_197, 1024, 1); add_197 = None + getitem_1742 = split_178[0] + getitem_1743 = split_178[1] + getitem_1744 = split_178[2] + getitem_1745 = split_178[3] + getitem_1746 = split_178[4] + getitem_1747 = split_178[5] + getitem_1748 = split_178[6] + getitem_1749 = split_178[7]; split_178 = None + cat_170 = torch.ops.aten.cat.default([getitem_1742, getitem_1743, getitem_1744, getitem_1745, getitem_1746, getitem_1747, getitem_1748, getitem_1749]); getitem_1742 = getitem_1743 = getitem_1744 = getitem_1745 = getitem_1746 = getitem_1747 = getitem_1748 = getitem_1749 = None + reduce_scatter_tensor_176 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_170, 'sum', 8, '1'); cat_170 = None + wait_tensor_572 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_176); reduce_scatter_tensor_176 = None + convert_element_type_1608 = torch.ops.prims.convert_element_type.default(wait_tensor_572, torch.float32); wait_tensor_572 = None + convert_element_type_1610 = torch.ops.prims.convert_element_type.default(wait_tensor_288, torch.float32); wait_tensor_288 = None + mul_458 = torch.ops.aten.mul.Tensor(convert_element_type_1608, convert_element_type_1610); convert_element_type_1610 = None + mul_460 = torch.ops.aten.mul.Tensor(mul_176, mul_458) + sum_61 = 
torch.ops.aten.sum.dim_IntList(mul_460, [2], True); mul_460 = None + div_20 = torch.ops.aten.div.Tensor(mul_176, 4096) + mul_461 = torch.ops.aten.mul.Tensor(div_20, sum_61); div_20 = sum_61 = None + sub_31 = torch.ops.aten.sub.Tensor(mul_458, mul_461); mul_458 = mul_461 = None + mul_462 = torch.ops.aten.mul.Tensor(sub_31, rsqrt_44); sub_31 = rsqrt_44 = None + mul_463 = torch.ops.aten.mul.Tensor(convert_element_type_1608, mul_176); convert_element_type_1608 = mul_176 = None + sum_62 = torch.ops.aten.sum.dim_IntList(mul_463, [0, 1]); mul_463 = None + convert_element_type_1611 = torch.ops.prims.convert_element_type.default(mul_462, torch.bfloat16); mul_462 = None + convert_element_type_1612 = torch.ops.prims.convert_element_type.default(sum_62, torch.bfloat16); sum_62 = None + all_reduce_20 = torch.ops._c10d_functional.all_reduce.default(convert_element_type_1612, 'sum', '1'); convert_element_type_1612 = None + wait_tensor_573 = torch.ops._c10d_functional.wait_tensor.default(all_reduce_20); all_reduce_20 = None + convert_element_type_1613 = torch.ops.prims.convert_element_type.default(wait_tensor_573, torch.float32); wait_tensor_573 = None + reduce_scatter_tensor_177 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1613, 'avg', 32, '0'); convert_element_type_1613 = None + wait_tensor_574 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_177); reduce_scatter_tensor_177 = None + add_198 = torch.ops.aten.add.Tensor(add_195, convert_element_type_1611); add_195 = convert_element_type_1611 = None + all_gather_into_tensor_376 = torch.ops._c10d_functional.all_gather_into_tensor.default(add_198, 8, '1') + wait_tensor_575 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_376); all_gather_into_tensor_376 = None + split_179 = torch.ops.aten.split.Tensor(wait_tensor_575, 2); wait_tensor_575 = None + getitem_1750 = split_179[0] + getitem_1751 = split_179[1] + getitem_1752 = split_179[2] + getitem_1753 = 
split_179[3] + getitem_1754 = split_179[4] + getitem_1755 = split_179[5] + getitem_1756 = split_179[6] + getitem_1757 = split_179[7]; split_179 = None + cat_171 = torch.ops.aten.cat.default([getitem_1750, getitem_1751, getitem_1752, getitem_1753, getitem_1754, getitem_1755, getitem_1756, getitem_1757], 1); getitem_1750 = getitem_1751 = getitem_1752 = getitem_1753 = getitem_1754 = getitem_1755 = getitem_1756 = getitem_1757 = None + view_2563 = torch.ops.aten.view.default(cat_171, [16384, 4096]); cat_171 = None + permute_677 = torch.ops.aten.permute.default(view_2563, [1, 0]) + wait_tensor_281 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_43); reduce_scatter_tensor_43 = None + add_85 = torch.ops.aten.add.Tensor(add_83, wait_tensor_281); wait_tensor_281 = None + convert_element_type_713 = torch.ops.prims.convert_element_type.default(primals_198, torch.bfloat16); primals_198 = None + all_gather_into_tensor_238 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_713, 32, '0'); convert_element_type_713 = None + wait_tensor_282 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_238); all_gather_into_tensor_238 = None + convert_element_type_714 = torch.ops.prims.convert_element_type.default(add_85, torch.float32); add_85 = None + pow_44 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_714, 2) + mean_43 = torch.ops.aten.mean.dim(pow_44, [2], True); pow_44 = None + add_86 = torch.ops.aten.add.Scalar(mean_43, 1e-05); mean_43 = None + rsqrt_43 = torch.ops.aten.rsqrt.default(add_86); add_86 = None + mul_172 = torch.ops.aten.mul.Tensor(convert_element_type_714, rsqrt_43); convert_element_type_714 = None + mul_173 = torch.ops.aten.mul.Tensor(mul_172, wait_tensor_282) + convert_element_type_715 = torch.ops.prims.convert_element_type.default(mul_173, torch.bfloat16); mul_173 = None + all_gather_into_tensor_239 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_715, 8, 
'1'); convert_element_type_715 = None + wait_tensor_283 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_239); all_gather_into_tensor_239 = None + split_95 = torch.ops.aten.split.Tensor(wait_tensor_283, 2); wait_tensor_283 = None + getitem_958 = split_95[0] + getitem_959 = split_95[1] + getitem_960 = split_95[2] + getitem_961 = split_95[3] + getitem_962 = split_95[4] + getitem_963 = split_95[5] + getitem_964 = split_95[6] + getitem_965 = split_95[7]; split_95 = None + cat_87 = torch.ops.aten.cat.default([getitem_958, getitem_959, getitem_960, getitem_961, getitem_962, getitem_963, getitem_964, getitem_965], 1); getitem_958 = getitem_959 = getitem_960 = getitem_961 = getitem_962 = getitem_963 = getitem_964 = getitem_965 = None + view_1572 = torch.ops.aten.view.default(cat_87, [16384, 4096]); cat_87 = None + view_1573 = torch.ops.aten.view.default(mm_151, [2, 8192, 1792]); mm_151 = None + convert_element_type_719 = torch.ops.prims.convert_element_type.default(view_1573, torch.float32); view_1573 = None + sigmoid_21 = torch.ops.aten.sigmoid.default(convert_element_type_719) + mul_174 = torch.ops.aten.mul.Tensor(convert_element_type_719, sigmoid_21); sigmoid_21 = None + convert_element_type_720 = torch.ops.prims.convert_element_type.default(mul_174, torch.bfloat16); mul_174 = None + convert_element_type_721 = torch.ops.prims.convert_element_type.default(primals_200, torch.bfloat16); primals_200 = None + all_gather_into_tensor_241 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_721, 32, '0'); convert_element_type_721 = None + wait_tensor_285 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_241); all_gather_into_tensor_241 = None + permute_240 = torch.ops.aten.permute.default(wait_tensor_285, [1, 0]); wait_tensor_285 = None + mm_152 = torch.ops.aten.mm.default(view_1572, permute_240) + view_1580 = torch.ops.aten.view.default(mm_152, [2, 8192, 1792]); mm_152 = None + mul_175 = 
torch.ops.aten.mul.Tensor(convert_element_type_720, view_1580) + view_1587 = torch.ops.aten.view.default(mul_175, [16384, 1792]); mul_175 = None + mm_367 = torch.ops.aten.mm.default(permute_677, view_1587); permute_677 = view_1587 = None + convert_element_type_724 = torch.ops.prims.convert_element_type.default(primals_201, torch.bfloat16); primals_201 = None + all_gather_into_tensor_242 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_724, 32, '0'); convert_element_type_724 = None + wait_tensor_286 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_242); all_gather_into_tensor_242 = None + permute_241 = torch.ops.aten.permute.default(wait_tensor_286, [1, 0]); wait_tensor_286 = None + permute_679 = torch.ops.aten.permute.default(permute_241, [1, 0]); permute_241 = None + mm_368 = torch.ops.aten.mm.default(view_2563, permute_679); view_2563 = permute_679 = None + view_2564 = torch.ops.aten.view.default(mm_368, [2, 8192, 1792]); mm_368 = None + convert_element_type_1618 = torch.ops.prims.convert_element_type.default(mm_367, torch.float32); mm_367 = None + reduce_scatter_tensor_178 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1618, 'avg', 32, '0'); convert_element_type_1618 = None + wait_tensor_576 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_178); reduce_scatter_tensor_178 = None + mul_464 = torch.ops.aten.mul.Tensor(view_2564, convert_element_type_720); convert_element_type_720 = None + mul_465 = torch.ops.aten.mul.Tensor(view_2564, view_1580); view_2564 = view_1580 = None + view_2565 = torch.ops.aten.view.default(mul_464, [16384, 1792]); mul_464 = None + permute_681 = torch.ops.aten.permute.default(view_2565, [1, 0]) + mm_369 = torch.ops.aten.mm.default(permute_681, view_1572); permute_681 = None + permute_683 = torch.ops.aten.permute.default(permute_240, [1, 0]); permute_240 = None + mm_370 = torch.ops.aten.mm.default(view_2565, permute_683); view_2565 
= permute_683 = None + view_2566 = torch.ops.aten.view.default(mm_370, [2, 8192, 4096]); mm_370 = None + convert_element_type_1623 = torch.ops.prims.convert_element_type.default(mm_369, torch.float32); mm_369 = None + reduce_scatter_tensor_179 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1623, 'avg', 32, '0'); convert_element_type_1623 = None + wait_tensor_577 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_179); reduce_scatter_tensor_179 = None + convert_element_type_1624 = torch.ops.prims.convert_element_type.default(mul_465, torch.float32); mul_465 = None + neg_10 = torch.ops.aten.neg.default(convert_element_type_719) + exp_10 = torch.ops.aten.exp.default(neg_10); neg_10 = None + add_199 = torch.ops.aten.add.Tensor(exp_10, 1); exp_10 = None + reciprocal_10 = torch.ops.aten.reciprocal.default(add_199); add_199 = None + mul_466 = torch.ops.aten.mul.Tensor(reciprocal_10, 1); reciprocal_10 = None + mul_467 = torch.ops.aten.mul.Tensor(convert_element_type_1624, mul_466); convert_element_type_1624 = None + sub_32 = torch.ops.aten.sub.Tensor(1, mul_466); mul_466 = None + mul_468 = torch.ops.aten.mul.Tensor(convert_element_type_719, sub_32); convert_element_type_719 = sub_32 = None + add_200 = torch.ops.aten.add.Tensor(mul_468, 1); mul_468 = None + mul_469 = torch.ops.aten.mul.Tensor(mul_467, add_200); mul_467 = add_200 = None + convert_element_type_1626 = torch.ops.prims.convert_element_type.default(mul_469, torch.bfloat16); mul_469 = None + view_2567 = torch.ops.aten.view.default(convert_element_type_1626, [16384, 1792]); convert_element_type_1626 = None + permute_685 = torch.ops.aten.permute.default(view_2567, [1, 0]) + mm_371 = torch.ops.aten.mm.default(permute_685, view_1572); permute_685 = view_1572 = None + convert_element_type_716 = torch.ops.prims.convert_element_type.default(primals_199, torch.bfloat16); primals_199 = None + all_gather_into_tensor_240 = 
torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_716, 32, '0'); convert_element_type_716 = None + wait_tensor_284 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_240); all_gather_into_tensor_240 = None + permute_239 = torch.ops.aten.permute.default(wait_tensor_284, [1, 0]); wait_tensor_284 = None + permute_687 = torch.ops.aten.permute.default(permute_239, [1, 0]); permute_239 = None + mm_372 = torch.ops.aten.mm.default(view_2567, permute_687); view_2567 = permute_687 = None + view_2568 = torch.ops.aten.view.default(mm_372, [2, 8192, 4096]); mm_372 = None + add_201 = torch.ops.aten.add.Tensor(view_2566, view_2568); view_2566 = view_2568 = None + convert_element_type_1631 = torch.ops.prims.convert_element_type.default(mm_371, torch.float32); mm_371 = None + reduce_scatter_tensor_180 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1631, 'avg', 32, '0'); convert_element_type_1631 = None + wait_tensor_578 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_180); reduce_scatter_tensor_180 = None + split_180 = torch.ops.aten.split.Tensor(add_201, 1024, 1); add_201 = None + getitem_1758 = split_180[0] + getitem_1759 = split_180[1] + getitem_1760 = split_180[2] + getitem_1761 = split_180[3] + getitem_1762 = split_180[4] + getitem_1763 = split_180[5] + getitem_1764 = split_180[6] + getitem_1765 = split_180[7]; split_180 = None + cat_172 = torch.ops.aten.cat.default([getitem_1758, getitem_1759, getitem_1760, getitem_1761, getitem_1762, getitem_1763, getitem_1764, getitem_1765]); getitem_1758 = getitem_1759 = getitem_1760 = getitem_1761 = getitem_1762 = getitem_1763 = getitem_1764 = getitem_1765 = None + reduce_scatter_tensor_181 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_172, 'sum', 8, '1'); cat_172 = None + wait_tensor_579 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_181); reduce_scatter_tensor_181 = None + 
convert_element_type_1632 = torch.ops.prims.convert_element_type.default(wait_tensor_579, torch.float32); wait_tensor_579 = None + convert_element_type_1634 = torch.ops.prims.convert_element_type.default(wait_tensor_282, torch.float32); wait_tensor_282 = None + mul_470 = torch.ops.aten.mul.Tensor(convert_element_type_1632, convert_element_type_1634); convert_element_type_1634 = None + mul_472 = torch.ops.aten.mul.Tensor(mul_172, mul_470) + sum_63 = torch.ops.aten.sum.dim_IntList(mul_472, [2], True); mul_472 = None + div_21 = torch.ops.aten.div.Tensor(mul_172, 4096) + mul_473 = torch.ops.aten.mul.Tensor(div_21, sum_63); div_21 = sum_63 = None + sub_33 = torch.ops.aten.sub.Tensor(mul_470, mul_473); mul_470 = mul_473 = None + mul_474 = torch.ops.aten.mul.Tensor(sub_33, rsqrt_43); sub_33 = rsqrt_43 = None + mul_475 = torch.ops.aten.mul.Tensor(convert_element_type_1632, mul_172); convert_element_type_1632 = mul_172 = None + sum_64 = torch.ops.aten.sum.dim_IntList(mul_475, [0, 1]); mul_475 = None + convert_element_type_1635 = torch.ops.prims.convert_element_type.default(mul_474, torch.bfloat16); mul_474 = None + convert_element_type_1636 = torch.ops.prims.convert_element_type.default(sum_64, torch.bfloat16); sum_64 = None + all_reduce_21 = torch.ops._c10d_functional.all_reduce.default(convert_element_type_1636, 'sum', '1'); convert_element_type_1636 = None + wait_tensor_580 = torch.ops._c10d_functional.wait_tensor.default(all_reduce_21); all_reduce_21 = None + convert_element_type_1637 = torch.ops.prims.convert_element_type.default(wait_tensor_580, torch.float32); wait_tensor_580 = None + reduce_scatter_tensor_182 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1637, 'avg', 32, '0'); convert_element_type_1637 = None + wait_tensor_581 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_182); reduce_scatter_tensor_182 = None + add_202 = torch.ops.aten.add.Tensor(add_198, convert_element_type_1635); add_198 = 
convert_element_type_1635 = None + all_gather_into_tensor_377 = torch.ops._c10d_functional.all_gather_into_tensor.default(add_202, 8, '1') + wait_tensor_582 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_377); all_gather_into_tensor_377 = None + split_181 = torch.ops.aten.split.Tensor(wait_tensor_582, 2); wait_tensor_582 = None + getitem_1766 = split_181[0] + getitem_1767 = split_181[1] + getitem_1768 = split_181[2] + getitem_1769 = split_181[3] + getitem_1770 = split_181[4] + getitem_1771 = split_181[5] + getitem_1772 = split_181[6] + getitem_1773 = split_181[7]; split_181 = None + cat_173 = torch.ops.aten.cat.default([getitem_1766, getitem_1767, getitem_1768, getitem_1769, getitem_1770, getitem_1771, getitem_1772, getitem_1773], 1); getitem_1766 = getitem_1767 = getitem_1768 = getitem_1769 = getitem_1770 = getitem_1771 = getitem_1772 = getitem_1773 = None + view_2569 = torch.ops.aten.view.default(cat_173, [16384, 4096]); cat_173 = None + permute_689 = torch.ops.aten.permute.default(view_2569, [1, 0]) + permute_237 = torch.ops.aten.permute.default(getitem_941, [0, 2, 1, 3]) + view_1554 = torch.ops.aten.view.default(permute_237, [2, 8192, -1]); permute_237 = None + view_1560 = torch.ops.aten.view.default(view_1554, [16384, 512]); view_1554 = None + mm_373 = torch.ops.aten.mm.default(permute_689, view_1560); permute_689 = view_1560 = None + convert_element_type_710 = torch.ops.prims.convert_element_type.default(primals_197, torch.bfloat16); primals_197 = None + all_gather_into_tensor_237 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_710, 32, '0'); convert_element_type_710 = None + wait_tensor_280 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_237); all_gather_into_tensor_237 = None + permute_238 = torch.ops.aten.permute.default(wait_tensor_280, [1, 0]); wait_tensor_280 = None + permute_691 = torch.ops.aten.permute.default(permute_238, [1, 0]); permute_238 = None + mm_374 = 
torch.ops.aten.mm.default(view_2569, permute_691); view_2569 = permute_691 = None + view_2570 = torch.ops.aten.view.default(mm_374, [2, 8192, 512]); mm_374 = None + convert_element_type_1642 = torch.ops.prims.convert_element_type.default(mm_373, torch.float32); mm_373 = None + reduce_scatter_tensor_183 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1642, 'avg', 32, '0'); convert_element_type_1642 = None + wait_tensor_583 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_183); reduce_scatter_tensor_183 = None + view_2571 = torch.ops.aten.view.default(view_2570, [2, 8192, 4, 128]); view_2570 = None + permute_693 = torch.ops.aten.permute.default(view_2571, [0, 2, 1, 3]); view_2571 = None + convert_element_type_694 = torch.ops.prims.convert_element_type.default(primals_193, torch.bfloat16); primals_193 = None + all_gather_into_tensor_232 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_694, 32, '0'); convert_element_type_694 = None + wait_tensor_275 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_232); all_gather_into_tensor_232 = None + convert_element_type_695 = torch.ops.prims.convert_element_type.default(add_83, torch.float32); add_83 = None + pow_43 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_695, 2) + mean_42 = torch.ops.aten.mean.dim(pow_43, [2], True); pow_43 = None + add_84 = torch.ops.aten.add.Scalar(mean_42, 1e-05); mean_42 = None + rsqrt_42 = torch.ops.aten.rsqrt.default(add_84); add_84 = None + mul_168 = torch.ops.aten.mul.Tensor(convert_element_type_695, rsqrt_42); convert_element_type_695 = None + mul_169 = torch.ops.aten.mul.Tensor(mul_168, wait_tensor_275) + convert_element_type_696 = torch.ops.prims.convert_element_type.default(mul_169, torch.bfloat16); mul_169 = None + all_gather_into_tensor_233 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_696, 8, '1'); convert_element_type_696 = None + 
wait_tensor_276 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_233); all_gather_into_tensor_233 = None + split_93 = torch.ops.aten.split.Tensor(wait_tensor_276, 2); wait_tensor_276 = None + getitem_933 = split_93[0] + getitem_934 = split_93[1] + getitem_935 = split_93[2] + getitem_936 = split_93[3] + getitem_937 = split_93[4] + getitem_938 = split_93[5] + getitem_939 = split_93[6] + getitem_940 = split_93[7]; split_93 = None + cat_85 = torch.ops.aten.cat.default([getitem_933, getitem_934, getitem_935, getitem_936, getitem_937, getitem_938, getitem_939, getitem_940], 1); getitem_933 = getitem_934 = getitem_935 = getitem_936 = getitem_937 = getitem_938 = getitem_939 = getitem_940 = None + view_1527 = torch.ops.aten.view.default(cat_85, [16384, 4096]); cat_85 = None + view_1528 = torch.ops.aten.view.default(mm_147, [2, 8192, 512]); mm_147 = None + convert_element_type_700 = torch.ops.prims.convert_element_type.default(primals_195, torch.bfloat16); primals_195 = None + all_gather_into_tensor_235 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_700, 32, '0'); convert_element_type_700 = None + wait_tensor_278 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_235); all_gather_into_tensor_235 = None + permute_232 = torch.ops.aten.permute.default(wait_tensor_278, [1, 0]); wait_tensor_278 = None + mm_148 = torch.ops.aten.mm.default(view_1527, permute_232) + view_1535 = torch.ops.aten.view.default(mm_148, [2, 8192, 128]); mm_148 = None + view_1542 = torch.ops.aten.view.default(mm_149, [2, 8192, 128]); mm_149 = None + view_1544 = torch.ops.aten.view.default(view_1528, [2, 8192, -1, 128]); view_1528 = None + view_1545 = torch.ops.aten.view.default(view_1535, [2, 8192, -1, 128]); view_1535 = None + view_1546 = torch.ops.aten.view.default(view_1542, [2, 8192, -1, 128]); view_1542 = None + convert_element_type_706 = torch.ops.prims.convert_element_type.default(view_1544, torch.float32); view_1544 = None 
+ view_1547 = torch.ops.aten.view.default(convert_element_type_706, [2, 8192, 4, -1, 2]); convert_element_type_706 = None + view_as_complex_42 = torch.ops.aten.view_as_complex.default(view_1547); view_1547 = None + convert_element_type_707 = torch.ops.prims.convert_element_type.default(view_1545, torch.float32); view_1545 = None + view_1548 = torch.ops.aten.view.default(convert_element_type_707, [2, 8192, 1, -1, 2]); convert_element_type_707 = None + view_as_complex_43 = torch.ops.aten.view_as_complex.default(view_1548); view_1548 = None + mul_170 = torch.ops.aten.mul.Tensor(view_as_complex_42, view_37); view_as_complex_42 = None + view_as_real_42 = torch.ops.aten.view_as_real.default(mul_170); mul_170 = None + view_1550 = torch.ops.aten.view.default(view_as_real_42, [2, 8192, 4, 128]); view_as_real_42 = None + mul_171 = torch.ops.aten.mul.Tensor(view_as_complex_43, view_37); view_as_complex_43 = None + view_as_real_43 = torch.ops.aten.view_as_real.default(mul_171); mul_171 = None + view_1551 = torch.ops.aten.view.default(view_as_real_43, [2, 8192, 1, 128]); view_as_real_43 = None + convert_element_type_708 = torch.ops.prims.convert_element_type.default(view_1550, torch.bfloat16); view_1550 = None + convert_element_type_709 = torch.ops.prims.convert_element_type.default(view_1551, torch.bfloat16); view_1551 = None + unsqueeze_42 = torch.ops.aten.unsqueeze.default(convert_element_type_709, 3); convert_element_type_709 = None + expand_42 = torch.ops.aten.expand.default(unsqueeze_42, [2, 8192, 1, 4, 128]); unsqueeze_42 = None + view_1552 = torch.ops.aten.view.default(expand_42, [2, 8192, 4, 128]); expand_42 = None + unsqueeze_43 = torch.ops.aten.unsqueeze.default(view_1546, 3); view_1546 = None + expand_43 = torch.ops.aten.expand.default(unsqueeze_43, [2, 8192, 1, 4, 128]); unsqueeze_43 = None + view_1553 = torch.ops.aten.view.default(expand_43, [2, 8192, 4, 128]); expand_43 = None + permute_234 = torch.ops.aten.permute.default(convert_element_type_708, [0, 2, 1, 3]); 
convert_element_type_708 = None + permute_235 = torch.ops.aten.permute.default(view_1552, [0, 2, 1, 3]); view_1552 = None + permute_236 = torch.ops.aten.permute.default(view_1553, [0, 2, 1, 3]); view_1553 = None + _scaled_dot_product_cudnn_attention_backward_10 = torch.ops.aten._scaled_dot_product_cudnn_attention_backward.default(permute_693, permute_234, permute_235, permute_236, getitem_941, getitem_942, getitem_947, getitem_948, None, None, None, 8192, 8192, 0.0, True); permute_693 = permute_234 = permute_235 = permute_236 = getitem_941 = getitem_942 = getitem_947 = getitem_948 = None + getitem_1774 = _scaled_dot_product_cudnn_attention_backward_10[0] + getitem_1775 = _scaled_dot_product_cudnn_attention_backward_10[1] + getitem_1776 = _scaled_dot_product_cudnn_attention_backward_10[2]; _scaled_dot_product_cudnn_attention_backward_10 = None + permute_694 = torch.ops.aten.permute.default(getitem_1776, [0, 2, 1, 3]); getitem_1776 = None + permute_695 = torch.ops.aten.permute.default(getitem_1775, [0, 2, 1, 3]); getitem_1775 = None + permute_696 = torch.ops.aten.permute.default(getitem_1774, [0, 2, 1, 3]); getitem_1774 = None + view_2572 = torch.ops.aten.view.default(permute_694, [2, 8192, 1, 4, 128]); permute_694 = None + sum_65 = torch.ops.aten.sum.dim_IntList(view_2572, [3], True); view_2572 = None + squeeze_20 = torch.ops.aten.squeeze.dim(sum_65, 3); sum_65 = None + view_2573 = torch.ops.aten.view.default(permute_695, [2, 8192, 1, 4, 128]); permute_695 = None + sum_66 = torch.ops.aten.sum.dim_IntList(view_2573, [3], True); view_2573 = None + squeeze_21 = torch.ops.aten.squeeze.dim(sum_66, 3); sum_66 = None + convert_element_type_1643 = torch.ops.prims.convert_element_type.default(squeeze_21, torch.float32); squeeze_21 = None + convert_element_type_1644 = torch.ops.prims.convert_element_type.default(permute_696, torch.float32); permute_696 = None + view_2574 = torch.ops.aten.view.default(convert_element_type_1643, [2, 8192, 1, 64, 2]); convert_element_type_1643 = 
None + view_as_complex_84 = torch.ops.aten.view_as_complex.default(view_2574); view_2574 = None + mul_476 = torch.ops.aten.mul.Tensor(view_as_complex_84, _conj); view_as_complex_84 = None + view_2575 = torch.ops.aten.view.default(convert_element_type_1644, [2, 8192, 4, 64, 2]); convert_element_type_1644 = None + view_as_complex_85 = torch.ops.aten.view_as_complex.default(view_2575); view_2575 = None + mul_477 = torch.ops.aten.mul.Tensor(view_as_complex_85, _conj); view_as_complex_85 = None + view_as_real_84 = torch.ops.aten.view_as_real.default(mul_476); mul_476 = None + view_2576 = torch.ops.aten.view.default(view_as_real_84, [2, 8192, 1, 128]); view_as_real_84 = None + convert_element_type_1645 = torch.ops.prims.convert_element_type.default(view_2576, torch.bfloat16); view_2576 = None + view_as_real_85 = torch.ops.aten.view_as_real.default(mul_477); mul_477 = None + view_2577 = torch.ops.aten.view.default(view_as_real_85, [2, 8192, 4, 128]); view_as_real_85 = None + convert_element_type_1646 = torch.ops.prims.convert_element_type.default(view_2577, torch.bfloat16); view_2577 = None + view_2578 = torch.ops.aten.view.default(squeeze_20, [2, 8192, 128]); squeeze_20 = None + view_2579 = torch.ops.aten.view.default(convert_element_type_1645, [2, 8192, 128]); convert_element_type_1645 = None + view_2580 = torch.ops.aten.view.default(convert_element_type_1646, [2, 8192, 512]); convert_element_type_1646 = None + view_2581 = torch.ops.aten.view.default(view_2578, [16384, 128]); view_2578 = None + permute_697 = torch.ops.aten.permute.default(view_2581, [1, 0]) + mm_375 = torch.ops.aten.mm.default(permute_697, view_1527); permute_697 = None + convert_element_type_703 = torch.ops.prims.convert_element_type.default(primals_196, torch.bfloat16); primals_196 = None + all_gather_into_tensor_236 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_703, 32, '0'); convert_element_type_703 = None + wait_tensor_279 = 
torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_236); all_gather_into_tensor_236 = None + permute_233 = torch.ops.aten.permute.default(wait_tensor_279, [1, 0]); wait_tensor_279 = None + permute_699 = torch.ops.aten.permute.default(permute_233, [1, 0]); permute_233 = None + mm_376 = torch.ops.aten.mm.default(view_2581, permute_699); view_2581 = permute_699 = None + view_2582 = torch.ops.aten.view.default(mm_376, [2, 8192, 4096]); mm_376 = None + convert_element_type_1651 = torch.ops.prims.convert_element_type.default(mm_375, torch.float32); mm_375 = None + reduce_scatter_tensor_184 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1651, 'avg', 32, '0'); convert_element_type_1651 = None + wait_tensor_584 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_184); reduce_scatter_tensor_184 = None + view_2583 = torch.ops.aten.view.default(view_2579, [16384, 128]); view_2579 = None + permute_701 = torch.ops.aten.permute.default(view_2583, [1, 0]) + mm_377 = torch.ops.aten.mm.default(permute_701, view_1527); permute_701 = None + permute_703 = torch.ops.aten.permute.default(permute_232, [1, 0]); permute_232 = None + mm_378 = torch.ops.aten.mm.default(view_2583, permute_703); view_2583 = permute_703 = None + view_2584 = torch.ops.aten.view.default(mm_378, [2, 8192, 4096]); mm_378 = None + add_203 = torch.ops.aten.add.Tensor(view_2582, view_2584); view_2582 = view_2584 = None + convert_element_type_1656 = torch.ops.prims.convert_element_type.default(mm_377, torch.float32); mm_377 = None + reduce_scatter_tensor_185 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1656, 'avg', 32, '0'); convert_element_type_1656 = None + wait_tensor_585 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_185); reduce_scatter_tensor_185 = None + view_2585 = torch.ops.aten.view.default(view_2580, [16384, 512]); view_2580 = None + permute_705 = 
torch.ops.aten.permute.default(view_2585, [1, 0]) + mm_379 = torch.ops.aten.mm.default(permute_705, view_1527); permute_705 = view_1527 = None + convert_element_type_697 = torch.ops.prims.convert_element_type.default(primals_194, torch.bfloat16); primals_194 = None + all_gather_into_tensor_234 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_697, 32, '0'); convert_element_type_697 = None + wait_tensor_277 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_234); all_gather_into_tensor_234 = None + permute_231 = torch.ops.aten.permute.default(wait_tensor_277, [1, 0]); wait_tensor_277 = None + permute_707 = torch.ops.aten.permute.default(permute_231, [1, 0]); permute_231 = None + mm_380 = torch.ops.aten.mm.default(view_2585, permute_707); view_2585 = permute_707 = None + view_2586 = torch.ops.aten.view.default(mm_380, [2, 8192, 4096]); mm_380 = None + add_204 = torch.ops.aten.add.Tensor(add_203, view_2586); add_203 = view_2586 = None + convert_element_type_1661 = torch.ops.prims.convert_element_type.default(mm_379, torch.float32); mm_379 = None + reduce_scatter_tensor_186 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1661, 'avg', 32, '0'); convert_element_type_1661 = None + wait_tensor_586 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_186); reduce_scatter_tensor_186 = None + split_182 = torch.ops.aten.split.Tensor(add_204, 1024, 1); add_204 = None + getitem_1777 = split_182[0] + getitem_1778 = split_182[1] + getitem_1779 = split_182[2] + getitem_1780 = split_182[3] + getitem_1781 = split_182[4] + getitem_1782 = split_182[5] + getitem_1783 = split_182[6] + getitem_1784 = split_182[7]; split_182 = None + cat_174 = torch.ops.aten.cat.default([getitem_1777, getitem_1778, getitem_1779, getitem_1780, getitem_1781, getitem_1782, getitem_1783, getitem_1784]); getitem_1777 = getitem_1778 = getitem_1779 = getitem_1780 = getitem_1781 = getitem_1782 = getitem_1783 = 
getitem_1784 = None + reduce_scatter_tensor_187 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_174, 'sum', 8, '1'); cat_174 = None + wait_tensor_587 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_187); reduce_scatter_tensor_187 = None + convert_element_type_1662 = torch.ops.prims.convert_element_type.default(wait_tensor_587, torch.float32); wait_tensor_587 = None + convert_element_type_1664 = torch.ops.prims.convert_element_type.default(wait_tensor_275, torch.float32); wait_tensor_275 = None + mul_478 = torch.ops.aten.mul.Tensor(convert_element_type_1662, convert_element_type_1664); convert_element_type_1664 = None + mul_480 = torch.ops.aten.mul.Tensor(mul_168, mul_478) + sum_67 = torch.ops.aten.sum.dim_IntList(mul_480, [2], True); mul_480 = None + div_22 = torch.ops.aten.div.Tensor(mul_168, 4096) + mul_481 = torch.ops.aten.mul.Tensor(div_22, sum_67); div_22 = sum_67 = None + sub_34 = torch.ops.aten.sub.Tensor(mul_478, mul_481); mul_478 = mul_481 = None + mul_482 = torch.ops.aten.mul.Tensor(sub_34, rsqrt_42); sub_34 = rsqrt_42 = None + mul_483 = torch.ops.aten.mul.Tensor(convert_element_type_1662, mul_168); convert_element_type_1662 = mul_168 = None + sum_68 = torch.ops.aten.sum.dim_IntList(mul_483, [0, 1]); mul_483 = None + convert_element_type_1665 = torch.ops.prims.convert_element_type.default(mul_482, torch.bfloat16); mul_482 = None + convert_element_type_1666 = torch.ops.prims.convert_element_type.default(sum_68, torch.bfloat16); sum_68 = None + all_reduce_22 = torch.ops._c10d_functional.all_reduce.default(convert_element_type_1666, 'sum', '1'); convert_element_type_1666 = None + wait_tensor_588 = torch.ops._c10d_functional.wait_tensor.default(all_reduce_22); all_reduce_22 = None + convert_element_type_1667 = torch.ops.prims.convert_element_type.default(wait_tensor_588, torch.float32); wait_tensor_588 = None + reduce_scatter_tensor_188 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1667, 
'avg', 32, '0'); convert_element_type_1667 = None + wait_tensor_589 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_188); reduce_scatter_tensor_188 = None + add_205 = torch.ops.aten.add.Tensor(add_202, convert_element_type_1665); add_202 = convert_element_type_1665 = None + all_gather_into_tensor_378 = torch.ops._c10d_functional.all_gather_into_tensor.default(add_205, 8, '1') + wait_tensor_590 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_378); all_gather_into_tensor_378 = None + split_183 = torch.ops.aten.split.Tensor(wait_tensor_590, 2); wait_tensor_590 = None + getitem_1785 = split_183[0] + getitem_1786 = split_183[1] + getitem_1787 = split_183[2] + getitem_1788 = split_183[3] + getitem_1789 = split_183[4] + getitem_1790 = split_183[5] + getitem_1791 = split_183[6] + getitem_1792 = split_183[7]; split_183 = None + cat_175 = torch.ops.aten.cat.default([getitem_1785, getitem_1786, getitem_1787, getitem_1788, getitem_1789, getitem_1790, getitem_1791, getitem_1792], 1); getitem_1785 = getitem_1786 = getitem_1787 = getitem_1788 = getitem_1789 = getitem_1790 = getitem_1791 = getitem_1792 = None + view_2587 = torch.ops.aten.view.default(cat_175, [16384, 4096]); cat_175 = None + permute_709 = torch.ops.aten.permute.default(view_2587, [1, 0]) + wait_tensor_268 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_41); reduce_scatter_tensor_41 = None + add_81 = torch.ops.aten.add.Tensor(add_79, wait_tensor_268); wait_tensor_268 = None + convert_element_type_680 = torch.ops.prims.convert_element_type.default(primals_189, torch.bfloat16); primals_189 = None + all_gather_into_tensor_227 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_680, 32, '0'); convert_element_type_680 = None + wait_tensor_269 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_227); all_gather_into_tensor_227 = None + convert_element_type_681 = 
torch.ops.prims.convert_element_type.default(add_81, torch.float32); add_81 = None + pow_42 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_681, 2) + mean_41 = torch.ops.aten.mean.dim(pow_42, [2], True); pow_42 = None + add_82 = torch.ops.aten.add.Scalar(mean_41, 1e-05); mean_41 = None + rsqrt_41 = torch.ops.aten.rsqrt.default(add_82); add_82 = None + mul_164 = torch.ops.aten.mul.Tensor(convert_element_type_681, rsqrt_41); convert_element_type_681 = None + mul_165 = torch.ops.aten.mul.Tensor(mul_164, wait_tensor_269) + convert_element_type_682 = torch.ops.prims.convert_element_type.default(mul_165, torch.bfloat16); mul_165 = None + all_gather_into_tensor_228 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_682, 8, '1'); convert_element_type_682 = None + wait_tensor_270 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_228); all_gather_into_tensor_228 = None + split_91 = torch.ops.aten.split.Tensor(wait_tensor_270, 2); wait_tensor_270 = None + getitem_917 = split_91[0] + getitem_918 = split_91[1] + getitem_919 = split_91[2] + getitem_920 = split_91[3] + getitem_921 = split_91[4] + getitem_922 = split_91[5] + getitem_923 = split_91[6] + getitem_924 = split_91[7]; split_91 = None + cat_83 = torch.ops.aten.cat.default([getitem_917, getitem_918, getitem_919, getitem_920, getitem_921, getitem_922, getitem_923, getitem_924], 1); getitem_917 = getitem_918 = getitem_919 = getitem_920 = getitem_921 = getitem_922 = getitem_923 = getitem_924 = None + view_1500 = torch.ops.aten.view.default(cat_83, [16384, 4096]); cat_83 = None + view_1501 = torch.ops.aten.view.default(mm_144, [2, 8192, 1792]); mm_144 = None + convert_element_type_686 = torch.ops.prims.convert_element_type.default(view_1501, torch.float32); view_1501 = None + sigmoid_20 = torch.ops.aten.sigmoid.default(convert_element_type_686) + mul_166 = torch.ops.aten.mul.Tensor(convert_element_type_686, sigmoid_20); sigmoid_20 = None + convert_element_type_687 = 
torch.ops.prims.convert_element_type.default(mul_166, torch.bfloat16); mul_166 = None + convert_element_type_688 = torch.ops.prims.convert_element_type.default(primals_191, torch.bfloat16); primals_191 = None + all_gather_into_tensor_230 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_688, 32, '0'); convert_element_type_688 = None + wait_tensor_272 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_230); all_gather_into_tensor_230 = None + permute_229 = torch.ops.aten.permute.default(wait_tensor_272, [1, 0]); wait_tensor_272 = None + mm_145 = torch.ops.aten.mm.default(view_1500, permute_229) + view_1508 = torch.ops.aten.view.default(mm_145, [2, 8192, 1792]); mm_145 = None + mul_167 = torch.ops.aten.mul.Tensor(convert_element_type_687, view_1508) + view_1515 = torch.ops.aten.view.default(mul_167, [16384, 1792]); mul_167 = None + mm_381 = torch.ops.aten.mm.default(permute_709, view_1515); permute_709 = view_1515 = None + convert_element_type_691 = torch.ops.prims.convert_element_type.default(primals_192, torch.bfloat16); primals_192 = None + all_gather_into_tensor_231 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_691, 32, '0'); convert_element_type_691 = None + wait_tensor_273 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_231); all_gather_into_tensor_231 = None + permute_230 = torch.ops.aten.permute.default(wait_tensor_273, [1, 0]); wait_tensor_273 = None + permute_711 = torch.ops.aten.permute.default(permute_230, [1, 0]); permute_230 = None + mm_382 = torch.ops.aten.mm.default(view_2587, permute_711); view_2587 = permute_711 = None + view_2588 = torch.ops.aten.view.default(mm_382, [2, 8192, 1792]); mm_382 = None + convert_element_type_1672 = torch.ops.prims.convert_element_type.default(mm_381, torch.float32); mm_381 = None + reduce_scatter_tensor_189 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1672, 'avg', 32, '0'); 
convert_element_type_1672 = None + wait_tensor_591 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_189); reduce_scatter_tensor_189 = None + mul_484 = torch.ops.aten.mul.Tensor(view_2588, convert_element_type_687); convert_element_type_687 = None + mul_485 = torch.ops.aten.mul.Tensor(view_2588, view_1508); view_2588 = view_1508 = None + view_2589 = torch.ops.aten.view.default(mul_484, [16384, 1792]); mul_484 = None + permute_713 = torch.ops.aten.permute.default(view_2589, [1, 0]) + mm_383 = torch.ops.aten.mm.default(permute_713, view_1500); permute_713 = None + permute_715 = torch.ops.aten.permute.default(permute_229, [1, 0]); permute_229 = None + mm_384 = torch.ops.aten.mm.default(view_2589, permute_715); view_2589 = permute_715 = None + view_2590 = torch.ops.aten.view.default(mm_384, [2, 8192, 4096]); mm_384 = None + convert_element_type_1677 = torch.ops.prims.convert_element_type.default(mm_383, torch.float32); mm_383 = None + reduce_scatter_tensor_190 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1677, 'avg', 32, '0'); convert_element_type_1677 = None + wait_tensor_592 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_190); reduce_scatter_tensor_190 = None + convert_element_type_1678 = torch.ops.prims.convert_element_type.default(mul_485, torch.float32); mul_485 = None + neg_11 = torch.ops.aten.neg.default(convert_element_type_686) + exp_11 = torch.ops.aten.exp.default(neg_11); neg_11 = None + add_206 = torch.ops.aten.add.Tensor(exp_11, 1); exp_11 = None + reciprocal_11 = torch.ops.aten.reciprocal.default(add_206); add_206 = None + mul_486 = torch.ops.aten.mul.Tensor(reciprocal_11, 1); reciprocal_11 = None + mul_487 = torch.ops.aten.mul.Tensor(convert_element_type_1678, mul_486); convert_element_type_1678 = None + sub_35 = torch.ops.aten.sub.Tensor(1, mul_486); mul_486 = None + mul_488 = torch.ops.aten.mul.Tensor(convert_element_type_686, sub_35); convert_element_type_686 = sub_35 = 
None + add_207 = torch.ops.aten.add.Tensor(mul_488, 1); mul_488 = None + mul_489 = torch.ops.aten.mul.Tensor(mul_487, add_207); mul_487 = add_207 = None + convert_element_type_1680 = torch.ops.prims.convert_element_type.default(mul_489, torch.bfloat16); mul_489 = None + view_2591 = torch.ops.aten.view.default(convert_element_type_1680, [16384, 1792]); convert_element_type_1680 = None + permute_717 = torch.ops.aten.permute.default(view_2591, [1, 0]) + mm_385 = torch.ops.aten.mm.default(permute_717, view_1500); permute_717 = view_1500 = None + convert_element_type_683 = torch.ops.prims.convert_element_type.default(primals_190, torch.bfloat16); primals_190 = None + all_gather_into_tensor_229 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_683, 32, '0'); convert_element_type_683 = None + wait_tensor_271 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_229); all_gather_into_tensor_229 = None + permute_228 = torch.ops.aten.permute.default(wait_tensor_271, [1, 0]); wait_tensor_271 = None + permute_719 = torch.ops.aten.permute.default(permute_228, [1, 0]); permute_228 = None + mm_386 = torch.ops.aten.mm.default(view_2591, permute_719); view_2591 = permute_719 = None + view_2592 = torch.ops.aten.view.default(mm_386, [2, 8192, 4096]); mm_386 = None + add_208 = torch.ops.aten.add.Tensor(view_2590, view_2592); view_2590 = view_2592 = None + convert_element_type_1685 = torch.ops.prims.convert_element_type.default(mm_385, torch.float32); mm_385 = None + reduce_scatter_tensor_191 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1685, 'avg', 32, '0'); convert_element_type_1685 = None + wait_tensor_593 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_191); reduce_scatter_tensor_191 = None + split_184 = torch.ops.aten.split.Tensor(add_208, 1024, 1); add_208 = None + getitem_1793 = split_184[0] + getitem_1794 = split_184[1] + getitem_1795 = split_184[2] + getitem_1796 = 
split_184[3] + getitem_1797 = split_184[4] + getitem_1798 = split_184[5] + getitem_1799 = split_184[6] + getitem_1800 = split_184[7]; split_184 = None + cat_176 = torch.ops.aten.cat.default([getitem_1793, getitem_1794, getitem_1795, getitem_1796, getitem_1797, getitem_1798, getitem_1799, getitem_1800]); getitem_1793 = getitem_1794 = getitem_1795 = getitem_1796 = getitem_1797 = getitem_1798 = getitem_1799 = getitem_1800 = None + reduce_scatter_tensor_192 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_176, 'sum', 8, '1'); cat_176 = None + wait_tensor_594 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_192); reduce_scatter_tensor_192 = None + convert_element_type_1686 = torch.ops.prims.convert_element_type.default(wait_tensor_594, torch.float32); wait_tensor_594 = None + convert_element_type_1688 = torch.ops.prims.convert_element_type.default(wait_tensor_269, torch.float32); wait_tensor_269 = None + mul_490 = torch.ops.aten.mul.Tensor(convert_element_type_1686, convert_element_type_1688); convert_element_type_1688 = None + mul_492 = torch.ops.aten.mul.Tensor(mul_164, mul_490) + sum_69 = torch.ops.aten.sum.dim_IntList(mul_492, [2], True); mul_492 = None + div_23 = torch.ops.aten.div.Tensor(mul_164, 4096) + mul_493 = torch.ops.aten.mul.Tensor(div_23, sum_69); div_23 = sum_69 = None + sub_36 = torch.ops.aten.sub.Tensor(mul_490, mul_493); mul_490 = mul_493 = None + mul_494 = torch.ops.aten.mul.Tensor(sub_36, rsqrt_41); sub_36 = rsqrt_41 = None + mul_495 = torch.ops.aten.mul.Tensor(convert_element_type_1686, mul_164); convert_element_type_1686 = mul_164 = None + sum_70 = torch.ops.aten.sum.dim_IntList(mul_495, [0, 1]); mul_495 = None + convert_element_type_1689 = torch.ops.prims.convert_element_type.default(mul_494, torch.bfloat16); mul_494 = None + convert_element_type_1690 = torch.ops.prims.convert_element_type.default(sum_70, torch.bfloat16); sum_70 = None + all_reduce_23 = 
torch.ops._c10d_functional.all_reduce.default(convert_element_type_1690, 'sum', '1'); convert_element_type_1690 = None + wait_tensor_595 = torch.ops._c10d_functional.wait_tensor.default(all_reduce_23); all_reduce_23 = None + convert_element_type_1691 = torch.ops.prims.convert_element_type.default(wait_tensor_595, torch.float32); wait_tensor_595 = None + reduce_scatter_tensor_193 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1691, 'avg', 32, '0'); convert_element_type_1691 = None + wait_tensor_596 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_193); reduce_scatter_tensor_193 = None + add_209 = torch.ops.aten.add.Tensor(add_205, convert_element_type_1689); add_205 = convert_element_type_1689 = None + all_gather_into_tensor_379 = torch.ops._c10d_functional.all_gather_into_tensor.default(add_209, 8, '1') + wait_tensor_597 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_379); all_gather_into_tensor_379 = None + split_185 = torch.ops.aten.split.Tensor(wait_tensor_597, 2); wait_tensor_597 = None + getitem_1801 = split_185[0] + getitem_1802 = split_185[1] + getitem_1803 = split_185[2] + getitem_1804 = split_185[3] + getitem_1805 = split_185[4] + getitem_1806 = split_185[5] + getitem_1807 = split_185[6] + getitem_1808 = split_185[7]; split_185 = None + cat_177 = torch.ops.aten.cat.default([getitem_1801, getitem_1802, getitem_1803, getitem_1804, getitem_1805, getitem_1806, getitem_1807, getitem_1808], 1); getitem_1801 = getitem_1802 = getitem_1803 = getitem_1804 = getitem_1805 = getitem_1806 = getitem_1807 = getitem_1808 = None + view_2593 = torch.ops.aten.view.default(cat_177, [16384, 4096]); cat_177 = None + permute_721 = torch.ops.aten.permute.default(view_2593, [1, 0]) + permute_226 = torch.ops.aten.permute.default(getitem_900, [0, 2, 1, 3]) + view_1482 = torch.ops.aten.view.default(permute_226, [2, 8192, -1]); permute_226 = None + view_1488 = torch.ops.aten.view.default(view_1482, [16384, 
512]); view_1482 = None + mm_387 = torch.ops.aten.mm.default(permute_721, view_1488); permute_721 = view_1488 = None + convert_element_type_677 = torch.ops.prims.convert_element_type.default(primals_188, torch.bfloat16); primals_188 = None + all_gather_into_tensor_226 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_677, 32, '0'); convert_element_type_677 = None + wait_tensor_267 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_226); all_gather_into_tensor_226 = None + permute_227 = torch.ops.aten.permute.default(wait_tensor_267, [1, 0]); wait_tensor_267 = None + permute_723 = torch.ops.aten.permute.default(permute_227, [1, 0]); permute_227 = None + mm_388 = torch.ops.aten.mm.default(view_2593, permute_723); view_2593 = permute_723 = None + view_2594 = torch.ops.aten.view.default(mm_388, [2, 8192, 512]); mm_388 = None + convert_element_type_1696 = torch.ops.prims.convert_element_type.default(mm_387, torch.float32); mm_387 = None + reduce_scatter_tensor_194 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1696, 'avg', 32, '0'); convert_element_type_1696 = None + wait_tensor_598 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_194); reduce_scatter_tensor_194 = None + view_2595 = torch.ops.aten.view.default(view_2594, [2, 8192, 4, 128]); view_2594 = None + permute_725 = torch.ops.aten.permute.default(view_2595, [0, 2, 1, 3]); view_2595 = None + convert_element_type_661 = torch.ops.prims.convert_element_type.default(primals_184, torch.bfloat16); primals_184 = None + all_gather_into_tensor_221 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_661, 32, '0'); convert_element_type_661 = None + wait_tensor_262 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_221); all_gather_into_tensor_221 = None + convert_element_type_662 = torch.ops.prims.convert_element_type.default(add_79, torch.float32); add_79 = None + 
pow_41 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_662, 2) + mean_40 = torch.ops.aten.mean.dim(pow_41, [2], True); pow_41 = None + add_80 = torch.ops.aten.add.Scalar(mean_40, 1e-05); mean_40 = None + rsqrt_40 = torch.ops.aten.rsqrt.default(add_80); add_80 = None + mul_160 = torch.ops.aten.mul.Tensor(convert_element_type_662, rsqrt_40); convert_element_type_662 = None + mul_161 = torch.ops.aten.mul.Tensor(mul_160, wait_tensor_262) + convert_element_type_663 = torch.ops.prims.convert_element_type.default(mul_161, torch.bfloat16); mul_161 = None + all_gather_into_tensor_222 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_663, 8, '1'); convert_element_type_663 = None + wait_tensor_263 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_222); all_gather_into_tensor_222 = None + split_89 = torch.ops.aten.split.Tensor(wait_tensor_263, 2); wait_tensor_263 = None + getitem_892 = split_89[0] + getitem_893 = split_89[1] + getitem_894 = split_89[2] + getitem_895 = split_89[3] + getitem_896 = split_89[4] + getitem_897 = split_89[5] + getitem_898 = split_89[6] + getitem_899 = split_89[7]; split_89 = None + cat_81 = torch.ops.aten.cat.default([getitem_892, getitem_893, getitem_894, getitem_895, getitem_896, getitem_897, getitem_898, getitem_899], 1); getitem_892 = getitem_893 = getitem_894 = getitem_895 = getitem_896 = getitem_897 = getitem_898 = getitem_899 = None + view_1455 = torch.ops.aten.view.default(cat_81, [16384, 4096]); cat_81 = None + view_1456 = torch.ops.aten.view.default(mm_140, [2, 8192, 512]); mm_140 = None + convert_element_type_667 = torch.ops.prims.convert_element_type.default(primals_186, torch.bfloat16); primals_186 = None + all_gather_into_tensor_224 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_667, 32, '0'); convert_element_type_667 = None + wait_tensor_265 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_224); 
all_gather_into_tensor_224 = None + permute_221 = torch.ops.aten.permute.default(wait_tensor_265, [1, 0]); wait_tensor_265 = None + mm_141 = torch.ops.aten.mm.default(view_1455, permute_221) + view_1463 = torch.ops.aten.view.default(mm_141, [2, 8192, 128]); mm_141 = None + view_1470 = torch.ops.aten.view.default(mm_142, [2, 8192, 128]); mm_142 = None + view_1472 = torch.ops.aten.view.default(view_1456, [2, 8192, -1, 128]); view_1456 = None + view_1473 = torch.ops.aten.view.default(view_1463, [2, 8192, -1, 128]); view_1463 = None + view_1474 = torch.ops.aten.view.default(view_1470, [2, 8192, -1, 128]); view_1470 = None + convert_element_type_673 = torch.ops.prims.convert_element_type.default(view_1472, torch.float32); view_1472 = None + view_1475 = torch.ops.aten.view.default(convert_element_type_673, [2, 8192, 4, -1, 2]); convert_element_type_673 = None + view_as_complex_40 = torch.ops.aten.view_as_complex.default(view_1475); view_1475 = None + convert_element_type_674 = torch.ops.prims.convert_element_type.default(view_1473, torch.float32); view_1473 = None + view_1476 = torch.ops.aten.view.default(convert_element_type_674, [2, 8192, 1, -1, 2]); convert_element_type_674 = None + view_as_complex_41 = torch.ops.aten.view_as_complex.default(view_1476); view_1476 = None + mul_162 = torch.ops.aten.mul.Tensor(view_as_complex_40, view_37); view_as_complex_40 = None + view_as_real_40 = torch.ops.aten.view_as_real.default(mul_162); mul_162 = None + view_1478 = torch.ops.aten.view.default(view_as_real_40, [2, 8192, 4, 128]); view_as_real_40 = None + mul_163 = torch.ops.aten.mul.Tensor(view_as_complex_41, view_37); view_as_complex_41 = None + view_as_real_41 = torch.ops.aten.view_as_real.default(mul_163); mul_163 = None + view_1479 = torch.ops.aten.view.default(view_as_real_41, [2, 8192, 1, 128]); view_as_real_41 = None + convert_element_type_675 = torch.ops.prims.convert_element_type.default(view_1478, torch.bfloat16); view_1478 = None + convert_element_type_676 = 
torch.ops.prims.convert_element_type.default(view_1479, torch.bfloat16); view_1479 = None + unsqueeze_40 = torch.ops.aten.unsqueeze.default(convert_element_type_676, 3); convert_element_type_676 = None + expand_40 = torch.ops.aten.expand.default(unsqueeze_40, [2, 8192, 1, 4, 128]); unsqueeze_40 = None + view_1480 = torch.ops.aten.view.default(expand_40, [2, 8192, 4, 128]); expand_40 = None + unsqueeze_41 = torch.ops.aten.unsqueeze.default(view_1474, 3); view_1474 = None + expand_41 = torch.ops.aten.expand.default(unsqueeze_41, [2, 8192, 1, 4, 128]); unsqueeze_41 = None + view_1481 = torch.ops.aten.view.default(expand_41, [2, 8192, 4, 128]); expand_41 = None + permute_223 = torch.ops.aten.permute.default(convert_element_type_675, [0, 2, 1, 3]); convert_element_type_675 = None + permute_224 = torch.ops.aten.permute.default(view_1480, [0, 2, 1, 3]); view_1480 = None + permute_225 = torch.ops.aten.permute.default(view_1481, [0, 2, 1, 3]); view_1481 = None + _scaled_dot_product_cudnn_attention_backward_11 = torch.ops.aten._scaled_dot_product_cudnn_attention_backward.default(permute_725, permute_223, permute_224, permute_225, getitem_900, getitem_901, getitem_906, getitem_907, None, None, None, 8192, 8192, 0.0, True); permute_725 = permute_223 = permute_224 = permute_225 = getitem_900 = getitem_901 = getitem_906 = getitem_907 = None + getitem_1809 = _scaled_dot_product_cudnn_attention_backward_11[0] + getitem_1810 = _scaled_dot_product_cudnn_attention_backward_11[1] + getitem_1811 = _scaled_dot_product_cudnn_attention_backward_11[2]; _scaled_dot_product_cudnn_attention_backward_11 = None + permute_726 = torch.ops.aten.permute.default(getitem_1811, [0, 2, 1, 3]); getitem_1811 = None + permute_727 = torch.ops.aten.permute.default(getitem_1810, [0, 2, 1, 3]); getitem_1810 = None + permute_728 = torch.ops.aten.permute.default(getitem_1809, [0, 2, 1, 3]); getitem_1809 = None + view_2596 = torch.ops.aten.view.default(permute_726, [2, 8192, 1, 4, 128]); permute_726 = None + 
sum_71 = torch.ops.aten.sum.dim_IntList(view_2596, [3], True); view_2596 = None + squeeze_22 = torch.ops.aten.squeeze.dim(sum_71, 3); sum_71 = None + view_2597 = torch.ops.aten.view.default(permute_727, [2, 8192, 1, 4, 128]); permute_727 = None + sum_72 = torch.ops.aten.sum.dim_IntList(view_2597, [3], True); view_2597 = None + squeeze_23 = torch.ops.aten.squeeze.dim(sum_72, 3); sum_72 = None + convert_element_type_1697 = torch.ops.prims.convert_element_type.default(squeeze_23, torch.float32); squeeze_23 = None + convert_element_type_1698 = torch.ops.prims.convert_element_type.default(permute_728, torch.float32); permute_728 = None + view_2598 = torch.ops.aten.view.default(convert_element_type_1697, [2, 8192, 1, 64, 2]); convert_element_type_1697 = None + view_as_complex_86 = torch.ops.aten.view_as_complex.default(view_2598); view_2598 = None + mul_496 = torch.ops.aten.mul.Tensor(view_as_complex_86, _conj); view_as_complex_86 = None + view_2599 = torch.ops.aten.view.default(convert_element_type_1698, [2, 8192, 4, 64, 2]); convert_element_type_1698 = None + view_as_complex_87 = torch.ops.aten.view_as_complex.default(view_2599); view_2599 = None + mul_497 = torch.ops.aten.mul.Tensor(view_as_complex_87, _conj); view_as_complex_87 = None + view_as_real_86 = torch.ops.aten.view_as_real.default(mul_496); mul_496 = None + view_2600 = torch.ops.aten.view.default(view_as_real_86, [2, 8192, 1, 128]); view_as_real_86 = None + convert_element_type_1699 = torch.ops.prims.convert_element_type.default(view_2600, torch.bfloat16); view_2600 = None + view_as_real_87 = torch.ops.aten.view_as_real.default(mul_497); mul_497 = None + view_2601 = torch.ops.aten.view.default(view_as_real_87, [2, 8192, 4, 128]); view_as_real_87 = None + convert_element_type_1700 = torch.ops.prims.convert_element_type.default(view_2601, torch.bfloat16); view_2601 = None + view_2602 = torch.ops.aten.view.default(squeeze_22, [2, 8192, 128]); squeeze_22 = None + view_2603 = 
torch.ops.aten.view.default(convert_element_type_1699, [2, 8192, 128]); convert_element_type_1699 = None + view_2604 = torch.ops.aten.view.default(convert_element_type_1700, [2, 8192, 512]); convert_element_type_1700 = None + view_2605 = torch.ops.aten.view.default(view_2602, [16384, 128]); view_2602 = None + permute_729 = torch.ops.aten.permute.default(view_2605, [1, 0]) + mm_389 = torch.ops.aten.mm.default(permute_729, view_1455); permute_729 = None + convert_element_type_670 = torch.ops.prims.convert_element_type.default(primals_187, torch.bfloat16); primals_187 = None + all_gather_into_tensor_225 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_670, 32, '0'); convert_element_type_670 = None + wait_tensor_266 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_225); all_gather_into_tensor_225 = None + permute_222 = torch.ops.aten.permute.default(wait_tensor_266, [1, 0]); wait_tensor_266 = None + permute_731 = torch.ops.aten.permute.default(permute_222, [1, 0]); permute_222 = None + mm_390 = torch.ops.aten.mm.default(view_2605, permute_731); view_2605 = permute_731 = None + view_2606 = torch.ops.aten.view.default(mm_390, [2, 8192, 4096]); mm_390 = None + convert_element_type_1705 = torch.ops.prims.convert_element_type.default(mm_389, torch.float32); mm_389 = None + reduce_scatter_tensor_195 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1705, 'avg', 32, '0'); convert_element_type_1705 = None + wait_tensor_599 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_195); reduce_scatter_tensor_195 = None + view_2607 = torch.ops.aten.view.default(view_2603, [16384, 128]); view_2603 = None + permute_733 = torch.ops.aten.permute.default(view_2607, [1, 0]) + mm_391 = torch.ops.aten.mm.default(permute_733, view_1455); permute_733 = None + permute_735 = torch.ops.aten.permute.default(permute_221, [1, 0]); permute_221 = None + mm_392 = torch.ops.aten.mm.default(view_2607, 
permute_735); view_2607 = permute_735 = None + view_2608 = torch.ops.aten.view.default(mm_392, [2, 8192, 4096]); mm_392 = None + add_210 = torch.ops.aten.add.Tensor(view_2606, view_2608); view_2606 = view_2608 = None + convert_element_type_1710 = torch.ops.prims.convert_element_type.default(mm_391, torch.float32); mm_391 = None + reduce_scatter_tensor_196 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1710, 'avg', 32, '0'); convert_element_type_1710 = None + wait_tensor_600 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_196); reduce_scatter_tensor_196 = None + view_2609 = torch.ops.aten.view.default(view_2604, [16384, 512]); view_2604 = None + permute_737 = torch.ops.aten.permute.default(view_2609, [1, 0]) + mm_393 = torch.ops.aten.mm.default(permute_737, view_1455); permute_737 = view_1455 = None + convert_element_type_664 = torch.ops.prims.convert_element_type.default(primals_185, torch.bfloat16); primals_185 = None + all_gather_into_tensor_223 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_664, 32, '0'); convert_element_type_664 = None + wait_tensor_264 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_223); all_gather_into_tensor_223 = None + permute_220 = torch.ops.aten.permute.default(wait_tensor_264, [1, 0]); wait_tensor_264 = None + permute_739 = torch.ops.aten.permute.default(permute_220, [1, 0]); permute_220 = None + mm_394 = torch.ops.aten.mm.default(view_2609, permute_739); view_2609 = permute_739 = None + view_2610 = torch.ops.aten.view.default(mm_394, [2, 8192, 4096]); mm_394 = None + add_211 = torch.ops.aten.add.Tensor(add_210, view_2610); add_210 = view_2610 = None + convert_element_type_1715 = torch.ops.prims.convert_element_type.default(mm_393, torch.float32); mm_393 = None + reduce_scatter_tensor_197 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1715, 'avg', 32, '0'); convert_element_type_1715 = None + 
wait_tensor_601 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_197); reduce_scatter_tensor_197 = None + split_186 = torch.ops.aten.split.Tensor(add_211, 1024, 1); add_211 = None + getitem_1812 = split_186[0] + getitem_1813 = split_186[1] + getitem_1814 = split_186[2] + getitem_1815 = split_186[3] + getitem_1816 = split_186[4] + getitem_1817 = split_186[5] + getitem_1818 = split_186[6] + getitem_1819 = split_186[7]; split_186 = None + cat_178 = torch.ops.aten.cat.default([getitem_1812, getitem_1813, getitem_1814, getitem_1815, getitem_1816, getitem_1817, getitem_1818, getitem_1819]); getitem_1812 = getitem_1813 = getitem_1814 = getitem_1815 = getitem_1816 = getitem_1817 = getitem_1818 = getitem_1819 = None + reduce_scatter_tensor_198 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_178, 'sum', 8, '1'); cat_178 = None + wait_tensor_602 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_198); reduce_scatter_tensor_198 = None + convert_element_type_1716 = torch.ops.prims.convert_element_type.default(wait_tensor_602, torch.float32); wait_tensor_602 = None + convert_element_type_1718 = torch.ops.prims.convert_element_type.default(wait_tensor_262, torch.float32); wait_tensor_262 = None + mul_498 = torch.ops.aten.mul.Tensor(convert_element_type_1716, convert_element_type_1718); convert_element_type_1718 = None + mul_500 = torch.ops.aten.mul.Tensor(mul_160, mul_498) + sum_73 = torch.ops.aten.sum.dim_IntList(mul_500, [2], True); mul_500 = None + div_24 = torch.ops.aten.div.Tensor(mul_160, 4096) + mul_501 = torch.ops.aten.mul.Tensor(div_24, sum_73); div_24 = sum_73 = None + sub_37 = torch.ops.aten.sub.Tensor(mul_498, mul_501); mul_498 = mul_501 = None + mul_502 = torch.ops.aten.mul.Tensor(sub_37, rsqrt_40); sub_37 = rsqrt_40 = None + mul_503 = torch.ops.aten.mul.Tensor(convert_element_type_1716, mul_160); convert_element_type_1716 = mul_160 = None + sum_74 = torch.ops.aten.sum.dim_IntList(mul_503, [0, 1]); mul_503 = 
None + convert_element_type_1719 = torch.ops.prims.convert_element_type.default(mul_502, torch.bfloat16); mul_502 = None + convert_element_type_1720 = torch.ops.prims.convert_element_type.default(sum_74, torch.bfloat16); sum_74 = None + all_reduce_24 = torch.ops._c10d_functional.all_reduce.default(convert_element_type_1720, 'sum', '1'); convert_element_type_1720 = None + wait_tensor_603 = torch.ops._c10d_functional.wait_tensor.default(all_reduce_24); all_reduce_24 = None + convert_element_type_1721 = torch.ops.prims.convert_element_type.default(wait_tensor_603, torch.float32); wait_tensor_603 = None + reduce_scatter_tensor_199 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1721, 'avg', 32, '0'); convert_element_type_1721 = None + wait_tensor_604 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_199); reduce_scatter_tensor_199 = None + add_212 = torch.ops.aten.add.Tensor(add_209, convert_element_type_1719); add_209 = convert_element_type_1719 = None + all_gather_into_tensor_380 = torch.ops._c10d_functional.all_gather_into_tensor.default(add_212, 8, '1') + wait_tensor_605 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_380); all_gather_into_tensor_380 = None + split_187 = torch.ops.aten.split.Tensor(wait_tensor_605, 2); wait_tensor_605 = None + getitem_1820 = split_187[0] + getitem_1821 = split_187[1] + getitem_1822 = split_187[2] + getitem_1823 = split_187[3] + getitem_1824 = split_187[4] + getitem_1825 = split_187[5] + getitem_1826 = split_187[6] + getitem_1827 = split_187[7]; split_187 = None + cat_179 = torch.ops.aten.cat.default([getitem_1820, getitem_1821, getitem_1822, getitem_1823, getitem_1824, getitem_1825, getitem_1826, getitem_1827], 1); getitem_1820 = getitem_1821 = getitem_1822 = getitem_1823 = getitem_1824 = getitem_1825 = getitem_1826 = getitem_1827 = None + view_2611 = torch.ops.aten.view.default(cat_179, [16384, 4096]); cat_179 = None + permute_741 = 
torch.ops.aten.permute.default(view_2611, [1, 0]) + wait_tensor_255 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_39); reduce_scatter_tensor_39 = None + add_77 = torch.ops.aten.add.Tensor(add_75, wait_tensor_255); wait_tensor_255 = None + convert_element_type_647 = torch.ops.prims.convert_element_type.default(primals_180, torch.bfloat16); primals_180 = None + all_gather_into_tensor_216 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_647, 32, '0'); convert_element_type_647 = None + wait_tensor_256 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_216); all_gather_into_tensor_216 = None + convert_element_type_648 = torch.ops.prims.convert_element_type.default(add_77, torch.float32); add_77 = None + pow_40 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_648, 2) + mean_39 = torch.ops.aten.mean.dim(pow_40, [2], True); pow_40 = None + add_78 = torch.ops.aten.add.Scalar(mean_39, 1e-05); mean_39 = None + rsqrt_39 = torch.ops.aten.rsqrt.default(add_78); add_78 = None + mul_156 = torch.ops.aten.mul.Tensor(convert_element_type_648, rsqrt_39); convert_element_type_648 = None + mul_157 = torch.ops.aten.mul.Tensor(mul_156, wait_tensor_256) + convert_element_type_649 = torch.ops.prims.convert_element_type.default(mul_157, torch.bfloat16); mul_157 = None + all_gather_into_tensor_217 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_649, 8, '1'); convert_element_type_649 = None + wait_tensor_257 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_217); all_gather_into_tensor_217 = None + split_87 = torch.ops.aten.split.Tensor(wait_tensor_257, 2); wait_tensor_257 = None + getitem_876 = split_87[0] + getitem_877 = split_87[1] + getitem_878 = split_87[2] + getitem_879 = split_87[3] + getitem_880 = split_87[4] + getitem_881 = split_87[5] + getitem_882 = split_87[6] + getitem_883 = split_87[7]; split_87 = None + cat_79 = 
torch.ops.aten.cat.default([getitem_876, getitem_877, getitem_878, getitem_879, getitem_880, getitem_881, getitem_882, getitem_883], 1); getitem_876 = getitem_877 = getitem_878 = getitem_879 = getitem_880 = getitem_881 = getitem_882 = getitem_883 = None + view_1428 = torch.ops.aten.view.default(cat_79, [16384, 4096]); cat_79 = None + view_1429 = torch.ops.aten.view.default(mm_137, [2, 8192, 1792]); mm_137 = None + convert_element_type_653 = torch.ops.prims.convert_element_type.default(view_1429, torch.float32); view_1429 = None + sigmoid_19 = torch.ops.aten.sigmoid.default(convert_element_type_653) + mul_158 = torch.ops.aten.mul.Tensor(convert_element_type_653, sigmoid_19); sigmoid_19 = None + convert_element_type_654 = torch.ops.prims.convert_element_type.default(mul_158, torch.bfloat16); mul_158 = None + convert_element_type_655 = torch.ops.prims.convert_element_type.default(primals_182, torch.bfloat16); primals_182 = None + all_gather_into_tensor_219 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_655, 32, '0'); convert_element_type_655 = None + wait_tensor_259 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_219); all_gather_into_tensor_219 = None + permute_218 = torch.ops.aten.permute.default(wait_tensor_259, [1, 0]); wait_tensor_259 = None + mm_138 = torch.ops.aten.mm.default(view_1428, permute_218) + view_1436 = torch.ops.aten.view.default(mm_138, [2, 8192, 1792]); mm_138 = None + mul_159 = torch.ops.aten.mul.Tensor(convert_element_type_654, view_1436) + view_1443 = torch.ops.aten.view.default(mul_159, [16384, 1792]); mul_159 = None + mm_395 = torch.ops.aten.mm.default(permute_741, view_1443); permute_741 = view_1443 = None + convert_element_type_658 = torch.ops.prims.convert_element_type.default(primals_183, torch.bfloat16); primals_183 = None + all_gather_into_tensor_220 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_658, 32, '0'); convert_element_type_658 = None + 
wait_tensor_260 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_220); all_gather_into_tensor_220 = None + permute_219 = torch.ops.aten.permute.default(wait_tensor_260, [1, 0]); wait_tensor_260 = None + permute_743 = torch.ops.aten.permute.default(permute_219, [1, 0]); permute_219 = None + mm_396 = torch.ops.aten.mm.default(view_2611, permute_743); view_2611 = permute_743 = None + view_2612 = torch.ops.aten.view.default(mm_396, [2, 8192, 1792]); mm_396 = None + convert_element_type_1726 = torch.ops.prims.convert_element_type.default(mm_395, torch.float32); mm_395 = None + reduce_scatter_tensor_200 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1726, 'avg', 32, '0'); convert_element_type_1726 = None + wait_tensor_606 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_200); reduce_scatter_tensor_200 = None + mul_504 = torch.ops.aten.mul.Tensor(view_2612, convert_element_type_654); convert_element_type_654 = None + mul_505 = torch.ops.aten.mul.Tensor(view_2612, view_1436); view_2612 = view_1436 = None + view_2613 = torch.ops.aten.view.default(mul_504, [16384, 1792]); mul_504 = None + permute_745 = torch.ops.aten.permute.default(view_2613, [1, 0]) + mm_397 = torch.ops.aten.mm.default(permute_745, view_1428); permute_745 = None + permute_747 = torch.ops.aten.permute.default(permute_218, [1, 0]); permute_218 = None + mm_398 = torch.ops.aten.mm.default(view_2613, permute_747); view_2613 = permute_747 = None + view_2614 = torch.ops.aten.view.default(mm_398, [2, 8192, 4096]); mm_398 = None + convert_element_type_1731 = torch.ops.prims.convert_element_type.default(mm_397, torch.float32); mm_397 = None + reduce_scatter_tensor_201 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1731, 'avg', 32, '0'); convert_element_type_1731 = None + wait_tensor_607 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_201); reduce_scatter_tensor_201 = None + 
convert_element_type_1732 = torch.ops.prims.convert_element_type.default(mul_505, torch.float32); mul_505 = None + neg_12 = torch.ops.aten.neg.default(convert_element_type_653) + exp_12 = torch.ops.aten.exp.default(neg_12); neg_12 = None + add_213 = torch.ops.aten.add.Tensor(exp_12, 1); exp_12 = None + reciprocal_12 = torch.ops.aten.reciprocal.default(add_213); add_213 = None + mul_506 = torch.ops.aten.mul.Tensor(reciprocal_12, 1); reciprocal_12 = None + mul_507 = torch.ops.aten.mul.Tensor(convert_element_type_1732, mul_506); convert_element_type_1732 = None + sub_38 = torch.ops.aten.sub.Tensor(1, mul_506); mul_506 = None + mul_508 = torch.ops.aten.mul.Tensor(convert_element_type_653, sub_38); convert_element_type_653 = sub_38 = None + add_214 = torch.ops.aten.add.Tensor(mul_508, 1); mul_508 = None + mul_509 = torch.ops.aten.mul.Tensor(mul_507, add_214); mul_507 = add_214 = None + convert_element_type_1734 = torch.ops.prims.convert_element_type.default(mul_509, torch.bfloat16); mul_509 = None + view_2615 = torch.ops.aten.view.default(convert_element_type_1734, [16384, 1792]); convert_element_type_1734 = None + permute_749 = torch.ops.aten.permute.default(view_2615, [1, 0]) + mm_399 = torch.ops.aten.mm.default(permute_749, view_1428); permute_749 = view_1428 = None + convert_element_type_650 = torch.ops.prims.convert_element_type.default(primals_181, torch.bfloat16); primals_181 = None + all_gather_into_tensor_218 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_650, 32, '0'); convert_element_type_650 = None + wait_tensor_258 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_218); all_gather_into_tensor_218 = None + permute_217 = torch.ops.aten.permute.default(wait_tensor_258, [1, 0]); wait_tensor_258 = None + permute_751 = torch.ops.aten.permute.default(permute_217, [1, 0]); permute_217 = None + mm_400 = torch.ops.aten.mm.default(view_2615, permute_751); view_2615 = permute_751 = None + view_2616 = 
torch.ops.aten.view.default(mm_400, [2, 8192, 4096]); mm_400 = None + add_215 = torch.ops.aten.add.Tensor(view_2614, view_2616); view_2614 = view_2616 = None + convert_element_type_1739 = torch.ops.prims.convert_element_type.default(mm_399, torch.float32); mm_399 = None + reduce_scatter_tensor_202 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1739, 'avg', 32, '0'); convert_element_type_1739 = None + wait_tensor_608 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_202); reduce_scatter_tensor_202 = None + split_188 = torch.ops.aten.split.Tensor(add_215, 1024, 1); add_215 = None + getitem_1828 = split_188[0] + getitem_1829 = split_188[1] + getitem_1830 = split_188[2] + getitem_1831 = split_188[3] + getitem_1832 = split_188[4] + getitem_1833 = split_188[5] + getitem_1834 = split_188[6] + getitem_1835 = split_188[7]; split_188 = None + cat_180 = torch.ops.aten.cat.default([getitem_1828, getitem_1829, getitem_1830, getitem_1831, getitem_1832, getitem_1833, getitem_1834, getitem_1835]); getitem_1828 = getitem_1829 = getitem_1830 = getitem_1831 = getitem_1832 = getitem_1833 = getitem_1834 = getitem_1835 = None + reduce_scatter_tensor_203 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_180, 'sum', 8, '1'); cat_180 = None + wait_tensor_609 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_203); reduce_scatter_tensor_203 = None + convert_element_type_1740 = torch.ops.prims.convert_element_type.default(wait_tensor_609, torch.float32); wait_tensor_609 = None + convert_element_type_1742 = torch.ops.prims.convert_element_type.default(wait_tensor_256, torch.float32); wait_tensor_256 = None + mul_510 = torch.ops.aten.mul.Tensor(convert_element_type_1740, convert_element_type_1742); convert_element_type_1742 = None + mul_512 = torch.ops.aten.mul.Tensor(mul_156, mul_510) + sum_75 = torch.ops.aten.sum.dim_IntList(mul_512, [2], True); mul_512 = None + div_25 = torch.ops.aten.div.Tensor(mul_156, 
4096) + mul_513 = torch.ops.aten.mul.Tensor(div_25, sum_75); div_25 = sum_75 = None + sub_39 = torch.ops.aten.sub.Tensor(mul_510, mul_513); mul_510 = mul_513 = None + mul_514 = torch.ops.aten.mul.Tensor(sub_39, rsqrt_39); sub_39 = rsqrt_39 = None + mul_515 = torch.ops.aten.mul.Tensor(convert_element_type_1740, mul_156); convert_element_type_1740 = mul_156 = None + sum_76 = torch.ops.aten.sum.dim_IntList(mul_515, [0, 1]); mul_515 = None + convert_element_type_1743 = torch.ops.prims.convert_element_type.default(mul_514, torch.bfloat16); mul_514 = None + convert_element_type_1744 = torch.ops.prims.convert_element_type.default(sum_76, torch.bfloat16); sum_76 = None + all_reduce_25 = torch.ops._c10d_functional.all_reduce.default(convert_element_type_1744, 'sum', '1'); convert_element_type_1744 = None + wait_tensor_610 = torch.ops._c10d_functional.wait_tensor.default(all_reduce_25); all_reduce_25 = None + convert_element_type_1745 = torch.ops.prims.convert_element_type.default(wait_tensor_610, torch.float32); wait_tensor_610 = None + reduce_scatter_tensor_204 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1745, 'avg', 32, '0'); convert_element_type_1745 = None + wait_tensor_611 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_204); reduce_scatter_tensor_204 = None + add_216 = torch.ops.aten.add.Tensor(add_212, convert_element_type_1743); add_212 = convert_element_type_1743 = None + all_gather_into_tensor_381 = torch.ops._c10d_functional.all_gather_into_tensor.default(add_216, 8, '1') + wait_tensor_612 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_381); all_gather_into_tensor_381 = None + split_189 = torch.ops.aten.split.Tensor(wait_tensor_612, 2); wait_tensor_612 = None + getitem_1836 = split_189[0] + getitem_1837 = split_189[1] + getitem_1838 = split_189[2] + getitem_1839 = split_189[3] + getitem_1840 = split_189[4] + getitem_1841 = split_189[5] + getitem_1842 = split_189[6] + getitem_1843 
= split_189[7]; split_189 = None + cat_181 = torch.ops.aten.cat.default([getitem_1836, getitem_1837, getitem_1838, getitem_1839, getitem_1840, getitem_1841, getitem_1842, getitem_1843], 1); getitem_1836 = getitem_1837 = getitem_1838 = getitem_1839 = getitem_1840 = getitem_1841 = getitem_1842 = getitem_1843 = None + view_2617 = torch.ops.aten.view.default(cat_181, [16384, 4096]); cat_181 = None + permute_753 = torch.ops.aten.permute.default(view_2617, [1, 0]) + permute_215 = torch.ops.aten.permute.default(getitem_859, [0, 2, 1, 3]) + view_1410 = torch.ops.aten.view.default(permute_215, [2, 8192, -1]); permute_215 = None + view_1416 = torch.ops.aten.view.default(view_1410, [16384, 512]); view_1410 = None + mm_401 = torch.ops.aten.mm.default(permute_753, view_1416); permute_753 = view_1416 = None + convert_element_type_644 = torch.ops.prims.convert_element_type.default(primals_179, torch.bfloat16); primals_179 = None + all_gather_into_tensor_215 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_644, 32, '0'); convert_element_type_644 = None + wait_tensor_254 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_215); all_gather_into_tensor_215 = None + permute_216 = torch.ops.aten.permute.default(wait_tensor_254, [1, 0]); wait_tensor_254 = None + permute_755 = torch.ops.aten.permute.default(permute_216, [1, 0]); permute_216 = None + mm_402 = torch.ops.aten.mm.default(view_2617, permute_755); view_2617 = permute_755 = None + view_2618 = torch.ops.aten.view.default(mm_402, [2, 8192, 512]); mm_402 = None + convert_element_type_1750 = torch.ops.prims.convert_element_type.default(mm_401, torch.float32); mm_401 = None + reduce_scatter_tensor_205 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1750, 'avg', 32, '0'); convert_element_type_1750 = None + wait_tensor_613 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_205); reduce_scatter_tensor_205 = None + view_2619 = 
torch.ops.aten.view.default(view_2618, [2, 8192, 4, 128]); view_2618 = None + permute_757 = torch.ops.aten.permute.default(view_2619, [0, 2, 1, 3]); view_2619 = None + convert_element_type_628 = torch.ops.prims.convert_element_type.default(primals_175, torch.bfloat16); primals_175 = None + all_gather_into_tensor_210 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_628, 32, '0'); convert_element_type_628 = None + wait_tensor_249 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_210); all_gather_into_tensor_210 = None + convert_element_type_629 = torch.ops.prims.convert_element_type.default(add_75, torch.float32); add_75 = None + pow_39 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_629, 2) + mean_38 = torch.ops.aten.mean.dim(pow_39, [2], True); pow_39 = None + add_76 = torch.ops.aten.add.Scalar(mean_38, 1e-05); mean_38 = None + rsqrt_38 = torch.ops.aten.rsqrt.default(add_76); add_76 = None + mul_152 = torch.ops.aten.mul.Tensor(convert_element_type_629, rsqrt_38); convert_element_type_629 = None + mul_153 = torch.ops.aten.mul.Tensor(mul_152, wait_tensor_249) + convert_element_type_630 = torch.ops.prims.convert_element_type.default(mul_153, torch.bfloat16); mul_153 = None + all_gather_into_tensor_211 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_630, 8, '1'); convert_element_type_630 = None + wait_tensor_250 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_211); all_gather_into_tensor_211 = None + split_85 = torch.ops.aten.split.Tensor(wait_tensor_250, 2); wait_tensor_250 = None + getitem_851 = split_85[0] + getitem_852 = split_85[1] + getitem_853 = split_85[2] + getitem_854 = split_85[3] + getitem_855 = split_85[4] + getitem_856 = split_85[5] + getitem_857 = split_85[6] + getitem_858 = split_85[7]; split_85 = None + cat_77 = torch.ops.aten.cat.default([getitem_851, getitem_852, getitem_853, getitem_854, getitem_855, getitem_856, getitem_857, 
getitem_858], 1); getitem_851 = getitem_852 = getitem_853 = getitem_854 = getitem_855 = getitem_856 = getitem_857 = getitem_858 = None + view_1383 = torch.ops.aten.view.default(cat_77, [16384, 4096]); cat_77 = None + view_1384 = torch.ops.aten.view.default(mm_133, [2, 8192, 512]); mm_133 = None + convert_element_type_634 = torch.ops.prims.convert_element_type.default(primals_177, torch.bfloat16); primals_177 = None + all_gather_into_tensor_213 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_634, 32, '0'); convert_element_type_634 = None + wait_tensor_252 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_213); all_gather_into_tensor_213 = None + permute_210 = torch.ops.aten.permute.default(wait_tensor_252, [1, 0]); wait_tensor_252 = None + mm_134 = torch.ops.aten.mm.default(view_1383, permute_210) + view_1391 = torch.ops.aten.view.default(mm_134, [2, 8192, 128]); mm_134 = None + view_1398 = torch.ops.aten.view.default(mm_135, [2, 8192, 128]); mm_135 = None + view_1400 = torch.ops.aten.view.default(view_1384, [2, 8192, -1, 128]); view_1384 = None + view_1401 = torch.ops.aten.view.default(view_1391, [2, 8192, -1, 128]); view_1391 = None + view_1402 = torch.ops.aten.view.default(view_1398, [2, 8192, -1, 128]); view_1398 = None + convert_element_type_640 = torch.ops.prims.convert_element_type.default(view_1400, torch.float32); view_1400 = None + view_1403 = torch.ops.aten.view.default(convert_element_type_640, [2, 8192, 4, -1, 2]); convert_element_type_640 = None + view_as_complex_38 = torch.ops.aten.view_as_complex.default(view_1403); view_1403 = None + convert_element_type_641 = torch.ops.prims.convert_element_type.default(view_1401, torch.float32); view_1401 = None + view_1404 = torch.ops.aten.view.default(convert_element_type_641, [2, 8192, 1, -1, 2]); convert_element_type_641 = None + view_as_complex_39 = torch.ops.aten.view_as_complex.default(view_1404); view_1404 = None + mul_154 = 
torch.ops.aten.mul.Tensor(view_as_complex_38, view_37); view_as_complex_38 = None + view_as_real_38 = torch.ops.aten.view_as_real.default(mul_154); mul_154 = None + view_1406 = torch.ops.aten.view.default(view_as_real_38, [2, 8192, 4, 128]); view_as_real_38 = None + mul_155 = torch.ops.aten.mul.Tensor(view_as_complex_39, view_37); view_as_complex_39 = None + view_as_real_39 = torch.ops.aten.view_as_real.default(mul_155); mul_155 = None + view_1407 = torch.ops.aten.view.default(view_as_real_39, [2, 8192, 1, 128]); view_as_real_39 = None + convert_element_type_642 = torch.ops.prims.convert_element_type.default(view_1406, torch.bfloat16); view_1406 = None + convert_element_type_643 = torch.ops.prims.convert_element_type.default(view_1407, torch.bfloat16); view_1407 = None + unsqueeze_38 = torch.ops.aten.unsqueeze.default(convert_element_type_643, 3); convert_element_type_643 = None + expand_38 = torch.ops.aten.expand.default(unsqueeze_38, [2, 8192, 1, 4, 128]); unsqueeze_38 = None + view_1408 = torch.ops.aten.view.default(expand_38, [2, 8192, 4, 128]); expand_38 = None + unsqueeze_39 = torch.ops.aten.unsqueeze.default(view_1402, 3); view_1402 = None + expand_39 = torch.ops.aten.expand.default(unsqueeze_39, [2, 8192, 1, 4, 128]); unsqueeze_39 = None + view_1409 = torch.ops.aten.view.default(expand_39, [2, 8192, 4, 128]); expand_39 = None + permute_212 = torch.ops.aten.permute.default(convert_element_type_642, [0, 2, 1, 3]); convert_element_type_642 = None + permute_213 = torch.ops.aten.permute.default(view_1408, [0, 2, 1, 3]); view_1408 = None + permute_214 = torch.ops.aten.permute.default(view_1409, [0, 2, 1, 3]); view_1409 = None + _scaled_dot_product_cudnn_attention_backward_12 = torch.ops.aten._scaled_dot_product_cudnn_attention_backward.default(permute_757, permute_212, permute_213, permute_214, getitem_859, getitem_860, getitem_865, getitem_866, None, None, None, 8192, 8192, 0.0, True); permute_757 = permute_212 = permute_213 = permute_214 = getitem_859 = 
getitem_860 = getitem_865 = getitem_866 = None + getitem_1844 = _scaled_dot_product_cudnn_attention_backward_12[0] + getitem_1845 = _scaled_dot_product_cudnn_attention_backward_12[1] + getitem_1846 = _scaled_dot_product_cudnn_attention_backward_12[2]; _scaled_dot_product_cudnn_attention_backward_12 = None + permute_758 = torch.ops.aten.permute.default(getitem_1846, [0, 2, 1, 3]); getitem_1846 = None + permute_759 = torch.ops.aten.permute.default(getitem_1845, [0, 2, 1, 3]); getitem_1845 = None + permute_760 = torch.ops.aten.permute.default(getitem_1844, [0, 2, 1, 3]); getitem_1844 = None + view_2620 = torch.ops.aten.view.default(permute_758, [2, 8192, 1, 4, 128]); permute_758 = None + sum_77 = torch.ops.aten.sum.dim_IntList(view_2620, [3], True); view_2620 = None + squeeze_24 = torch.ops.aten.squeeze.dim(sum_77, 3); sum_77 = None + view_2621 = torch.ops.aten.view.default(permute_759, [2, 8192, 1, 4, 128]); permute_759 = None + sum_78 = torch.ops.aten.sum.dim_IntList(view_2621, [3], True); view_2621 = None + squeeze_25 = torch.ops.aten.squeeze.dim(sum_78, 3); sum_78 = None + convert_element_type_1751 = torch.ops.prims.convert_element_type.default(squeeze_25, torch.float32); squeeze_25 = None + convert_element_type_1752 = torch.ops.prims.convert_element_type.default(permute_760, torch.float32); permute_760 = None + view_2622 = torch.ops.aten.view.default(convert_element_type_1751, [2, 8192, 1, 64, 2]); convert_element_type_1751 = None + view_as_complex_88 = torch.ops.aten.view_as_complex.default(view_2622); view_2622 = None + mul_516 = torch.ops.aten.mul.Tensor(view_as_complex_88, _conj); view_as_complex_88 = None + view_2623 = torch.ops.aten.view.default(convert_element_type_1752, [2, 8192, 4, 64, 2]); convert_element_type_1752 = None + view_as_complex_89 = torch.ops.aten.view_as_complex.default(view_2623); view_2623 = None + mul_517 = torch.ops.aten.mul.Tensor(view_as_complex_89, _conj); view_as_complex_89 = None + view_as_real_88 = 
torch.ops.aten.view_as_real.default(mul_516); mul_516 = None + view_2624 = torch.ops.aten.view.default(view_as_real_88, [2, 8192, 1, 128]); view_as_real_88 = None + convert_element_type_1753 = torch.ops.prims.convert_element_type.default(view_2624, torch.bfloat16); view_2624 = None + view_as_real_89 = torch.ops.aten.view_as_real.default(mul_517); mul_517 = None + view_2625 = torch.ops.aten.view.default(view_as_real_89, [2, 8192, 4, 128]); view_as_real_89 = None + convert_element_type_1754 = torch.ops.prims.convert_element_type.default(view_2625, torch.bfloat16); view_2625 = None + view_2626 = torch.ops.aten.view.default(squeeze_24, [2, 8192, 128]); squeeze_24 = None + view_2627 = torch.ops.aten.view.default(convert_element_type_1753, [2, 8192, 128]); convert_element_type_1753 = None + view_2628 = torch.ops.aten.view.default(convert_element_type_1754, [2, 8192, 512]); convert_element_type_1754 = None + view_2629 = torch.ops.aten.view.default(view_2626, [16384, 128]); view_2626 = None + permute_761 = torch.ops.aten.permute.default(view_2629, [1, 0]) + mm_403 = torch.ops.aten.mm.default(permute_761, view_1383); permute_761 = None + convert_element_type_637 = torch.ops.prims.convert_element_type.default(primals_178, torch.bfloat16); primals_178 = None + all_gather_into_tensor_214 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_637, 32, '0'); convert_element_type_637 = None + wait_tensor_253 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_214); all_gather_into_tensor_214 = None + permute_211 = torch.ops.aten.permute.default(wait_tensor_253, [1, 0]); wait_tensor_253 = None + permute_763 = torch.ops.aten.permute.default(permute_211, [1, 0]); permute_211 = None + mm_404 = torch.ops.aten.mm.default(view_2629, permute_763); view_2629 = permute_763 = None + view_2630 = torch.ops.aten.view.default(mm_404, [2, 8192, 4096]); mm_404 = None + convert_element_type_1759 = torch.ops.prims.convert_element_type.default(mm_403, 
torch.float32); mm_403 = None + reduce_scatter_tensor_206 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1759, 'avg', 32, '0'); convert_element_type_1759 = None + wait_tensor_614 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_206); reduce_scatter_tensor_206 = None + view_2631 = torch.ops.aten.view.default(view_2627, [16384, 128]); view_2627 = None + permute_765 = torch.ops.aten.permute.default(view_2631, [1, 0]) + mm_405 = torch.ops.aten.mm.default(permute_765, view_1383); permute_765 = None + permute_767 = torch.ops.aten.permute.default(permute_210, [1, 0]); permute_210 = None + mm_406 = torch.ops.aten.mm.default(view_2631, permute_767); view_2631 = permute_767 = None + view_2632 = torch.ops.aten.view.default(mm_406, [2, 8192, 4096]); mm_406 = None + add_217 = torch.ops.aten.add.Tensor(view_2630, view_2632); view_2630 = view_2632 = None + convert_element_type_1764 = torch.ops.prims.convert_element_type.default(mm_405, torch.float32); mm_405 = None + reduce_scatter_tensor_207 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1764, 'avg', 32, '0'); convert_element_type_1764 = None + wait_tensor_615 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_207); reduce_scatter_tensor_207 = None + view_2633 = torch.ops.aten.view.default(view_2628, [16384, 512]); view_2628 = None + permute_769 = torch.ops.aten.permute.default(view_2633, [1, 0]) + mm_407 = torch.ops.aten.mm.default(permute_769, view_1383); permute_769 = view_1383 = None + convert_element_type_631 = torch.ops.prims.convert_element_type.default(primals_176, torch.bfloat16); primals_176 = None + all_gather_into_tensor_212 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_631, 32, '0'); convert_element_type_631 = None + wait_tensor_251 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_212); all_gather_into_tensor_212 = None + permute_209 = 
torch.ops.aten.permute.default(wait_tensor_251, [1, 0]); wait_tensor_251 = None + permute_771 = torch.ops.aten.permute.default(permute_209, [1, 0]); permute_209 = None + mm_408 = torch.ops.aten.mm.default(view_2633, permute_771); view_2633 = permute_771 = None + view_2634 = torch.ops.aten.view.default(mm_408, [2, 8192, 4096]); mm_408 = None + add_218 = torch.ops.aten.add.Tensor(add_217, view_2634); add_217 = view_2634 = None + convert_element_type_1769 = torch.ops.prims.convert_element_type.default(mm_407, torch.float32); mm_407 = None + reduce_scatter_tensor_208 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1769, 'avg', 32, '0'); convert_element_type_1769 = None + wait_tensor_616 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_208); reduce_scatter_tensor_208 = None + split_190 = torch.ops.aten.split.Tensor(add_218, 1024, 1); add_218 = None + getitem_1847 = split_190[0] + getitem_1848 = split_190[1] + getitem_1849 = split_190[2] + getitem_1850 = split_190[3] + getitem_1851 = split_190[4] + getitem_1852 = split_190[5] + getitem_1853 = split_190[6] + getitem_1854 = split_190[7]; split_190 = None + cat_182 = torch.ops.aten.cat.default([getitem_1847, getitem_1848, getitem_1849, getitem_1850, getitem_1851, getitem_1852, getitem_1853, getitem_1854]); getitem_1847 = getitem_1848 = getitem_1849 = getitem_1850 = getitem_1851 = getitem_1852 = getitem_1853 = getitem_1854 = None + reduce_scatter_tensor_209 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_182, 'sum', 8, '1'); cat_182 = None + wait_tensor_617 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_209); reduce_scatter_tensor_209 = None + convert_element_type_1770 = torch.ops.prims.convert_element_type.default(wait_tensor_617, torch.float32); wait_tensor_617 = None + convert_element_type_1772 = torch.ops.prims.convert_element_type.default(wait_tensor_249, torch.float32); wait_tensor_249 = None + mul_518 = 
torch.ops.aten.mul.Tensor(convert_element_type_1770, convert_element_type_1772); convert_element_type_1772 = None + mul_520 = torch.ops.aten.mul.Tensor(mul_152, mul_518) + sum_79 = torch.ops.aten.sum.dim_IntList(mul_520, [2], True); mul_520 = None + div_26 = torch.ops.aten.div.Tensor(mul_152, 4096) + mul_521 = torch.ops.aten.mul.Tensor(div_26, sum_79); div_26 = sum_79 = None + sub_40 = torch.ops.aten.sub.Tensor(mul_518, mul_521); mul_518 = mul_521 = None + mul_522 = torch.ops.aten.mul.Tensor(sub_40, rsqrt_38); sub_40 = rsqrt_38 = None + mul_523 = torch.ops.aten.mul.Tensor(convert_element_type_1770, mul_152); convert_element_type_1770 = mul_152 = None + sum_80 = torch.ops.aten.sum.dim_IntList(mul_523, [0, 1]); mul_523 = None + convert_element_type_1773 = torch.ops.prims.convert_element_type.default(mul_522, torch.bfloat16); mul_522 = None + convert_element_type_1774 = torch.ops.prims.convert_element_type.default(sum_80, torch.bfloat16); sum_80 = None + all_reduce_26 = torch.ops._c10d_functional.all_reduce.default(convert_element_type_1774, 'sum', '1'); convert_element_type_1774 = None + wait_tensor_618 = torch.ops._c10d_functional.wait_tensor.default(all_reduce_26); all_reduce_26 = None + convert_element_type_1775 = torch.ops.prims.convert_element_type.default(wait_tensor_618, torch.float32); wait_tensor_618 = None + reduce_scatter_tensor_210 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1775, 'avg', 32, '0'); convert_element_type_1775 = None + wait_tensor_619 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_210); reduce_scatter_tensor_210 = None + add_219 = torch.ops.aten.add.Tensor(add_216, convert_element_type_1773); add_216 = convert_element_type_1773 = None + all_gather_into_tensor_382 = torch.ops._c10d_functional.all_gather_into_tensor.default(add_219, 8, '1') + wait_tensor_620 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_382); all_gather_into_tensor_382 = None + split_191 = 
torch.ops.aten.split.Tensor(wait_tensor_620, 2); wait_tensor_620 = None + getitem_1855 = split_191[0] + getitem_1856 = split_191[1] + getitem_1857 = split_191[2] + getitem_1858 = split_191[3] + getitem_1859 = split_191[4] + getitem_1860 = split_191[5] + getitem_1861 = split_191[6] + getitem_1862 = split_191[7]; split_191 = None + cat_183 = torch.ops.aten.cat.default([getitem_1855, getitem_1856, getitem_1857, getitem_1858, getitem_1859, getitem_1860, getitem_1861, getitem_1862], 1); getitem_1855 = getitem_1856 = getitem_1857 = getitem_1858 = getitem_1859 = getitem_1860 = getitem_1861 = getitem_1862 = None + view_2635 = torch.ops.aten.view.default(cat_183, [16384, 4096]); cat_183 = None + permute_773 = torch.ops.aten.permute.default(view_2635, [1, 0]) + wait_tensor_242 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_37); reduce_scatter_tensor_37 = None + add_73 = torch.ops.aten.add.Tensor(add_71, wait_tensor_242); wait_tensor_242 = None + convert_element_type_614 = torch.ops.prims.convert_element_type.default(primals_171, torch.bfloat16); primals_171 = None + all_gather_into_tensor_205 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_614, 32, '0'); convert_element_type_614 = None + wait_tensor_243 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_205); all_gather_into_tensor_205 = None + convert_element_type_615 = torch.ops.prims.convert_element_type.default(add_73, torch.float32); add_73 = None + pow_38 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_615, 2) + mean_37 = torch.ops.aten.mean.dim(pow_38, [2], True); pow_38 = None + add_74 = torch.ops.aten.add.Scalar(mean_37, 1e-05); mean_37 = None + rsqrt_37 = torch.ops.aten.rsqrt.default(add_74); add_74 = None + mul_148 = torch.ops.aten.mul.Tensor(convert_element_type_615, rsqrt_37); convert_element_type_615 = None + mul_149 = torch.ops.aten.mul.Tensor(mul_148, wait_tensor_243) + convert_element_type_616 = 
torch.ops.prims.convert_element_type.default(mul_149, torch.bfloat16); mul_149 = None + all_gather_into_tensor_206 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_616, 8, '1'); convert_element_type_616 = None + wait_tensor_244 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_206); all_gather_into_tensor_206 = None + split_83 = torch.ops.aten.split.Tensor(wait_tensor_244, 2); wait_tensor_244 = None + getitem_835 = split_83[0] + getitem_836 = split_83[1] + getitem_837 = split_83[2] + getitem_838 = split_83[3] + getitem_839 = split_83[4] + getitem_840 = split_83[5] + getitem_841 = split_83[6] + getitem_842 = split_83[7]; split_83 = None + cat_75 = torch.ops.aten.cat.default([getitem_835, getitem_836, getitem_837, getitem_838, getitem_839, getitem_840, getitem_841, getitem_842], 1); getitem_835 = getitem_836 = getitem_837 = getitem_838 = getitem_839 = getitem_840 = getitem_841 = getitem_842 = None + view_1356 = torch.ops.aten.view.default(cat_75, [16384, 4096]); cat_75 = None + view_1357 = torch.ops.aten.view.default(mm_130, [2, 8192, 1792]); mm_130 = None + convert_element_type_620 = torch.ops.prims.convert_element_type.default(view_1357, torch.float32); view_1357 = None + sigmoid_18 = torch.ops.aten.sigmoid.default(convert_element_type_620) + mul_150 = torch.ops.aten.mul.Tensor(convert_element_type_620, sigmoid_18); sigmoid_18 = None + convert_element_type_621 = torch.ops.prims.convert_element_type.default(mul_150, torch.bfloat16); mul_150 = None + convert_element_type_622 = torch.ops.prims.convert_element_type.default(primals_173, torch.bfloat16); primals_173 = None + all_gather_into_tensor_208 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_622, 32, '0'); convert_element_type_622 = None + wait_tensor_246 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_208); all_gather_into_tensor_208 = None + permute_207 = torch.ops.aten.permute.default(wait_tensor_246, 
[1, 0]); wait_tensor_246 = None + mm_131 = torch.ops.aten.mm.default(view_1356, permute_207) + view_1364 = torch.ops.aten.view.default(mm_131, [2, 8192, 1792]); mm_131 = None + mul_151 = torch.ops.aten.mul.Tensor(convert_element_type_621, view_1364) + view_1371 = torch.ops.aten.view.default(mul_151, [16384, 1792]); mul_151 = None + mm_409 = torch.ops.aten.mm.default(permute_773, view_1371); permute_773 = view_1371 = None + convert_element_type_625 = torch.ops.prims.convert_element_type.default(primals_174, torch.bfloat16); primals_174 = None + all_gather_into_tensor_209 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_625, 32, '0'); convert_element_type_625 = None + wait_tensor_247 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_209); all_gather_into_tensor_209 = None + permute_208 = torch.ops.aten.permute.default(wait_tensor_247, [1, 0]); wait_tensor_247 = None + permute_775 = torch.ops.aten.permute.default(permute_208, [1, 0]); permute_208 = None + mm_410 = torch.ops.aten.mm.default(view_2635, permute_775); view_2635 = permute_775 = None + view_2636 = torch.ops.aten.view.default(mm_410, [2, 8192, 1792]); mm_410 = None + convert_element_type_1780 = torch.ops.prims.convert_element_type.default(mm_409, torch.float32); mm_409 = None + reduce_scatter_tensor_211 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1780, 'avg', 32, '0'); convert_element_type_1780 = None + wait_tensor_621 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_211); reduce_scatter_tensor_211 = None + mul_524 = torch.ops.aten.mul.Tensor(view_2636, convert_element_type_621); convert_element_type_621 = None + mul_525 = torch.ops.aten.mul.Tensor(view_2636, view_1364); view_2636 = view_1364 = None + view_2637 = torch.ops.aten.view.default(mul_524, [16384, 1792]); mul_524 = None + permute_777 = torch.ops.aten.permute.default(view_2637, [1, 0]) + mm_411 = torch.ops.aten.mm.default(permute_777, 
view_1356); permute_777 = None + permute_779 = torch.ops.aten.permute.default(permute_207, [1, 0]); permute_207 = None + mm_412 = torch.ops.aten.mm.default(view_2637, permute_779); view_2637 = permute_779 = None + view_2638 = torch.ops.aten.view.default(mm_412, [2, 8192, 4096]); mm_412 = None + convert_element_type_1785 = torch.ops.prims.convert_element_type.default(mm_411, torch.float32); mm_411 = None + reduce_scatter_tensor_212 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1785, 'avg', 32, '0'); convert_element_type_1785 = None + wait_tensor_622 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_212); reduce_scatter_tensor_212 = None + convert_element_type_1786 = torch.ops.prims.convert_element_type.default(mul_525, torch.float32); mul_525 = None + neg_13 = torch.ops.aten.neg.default(convert_element_type_620) + exp_13 = torch.ops.aten.exp.default(neg_13); neg_13 = None + add_220 = torch.ops.aten.add.Tensor(exp_13, 1); exp_13 = None + reciprocal_13 = torch.ops.aten.reciprocal.default(add_220); add_220 = None + mul_526 = torch.ops.aten.mul.Tensor(reciprocal_13, 1); reciprocal_13 = None + mul_527 = torch.ops.aten.mul.Tensor(convert_element_type_1786, mul_526); convert_element_type_1786 = None + sub_41 = torch.ops.aten.sub.Tensor(1, mul_526); mul_526 = None + mul_528 = torch.ops.aten.mul.Tensor(convert_element_type_620, sub_41); convert_element_type_620 = sub_41 = None + add_221 = torch.ops.aten.add.Tensor(mul_528, 1); mul_528 = None + mul_529 = torch.ops.aten.mul.Tensor(mul_527, add_221); mul_527 = add_221 = None + convert_element_type_1788 = torch.ops.prims.convert_element_type.default(mul_529, torch.bfloat16); mul_529 = None + view_2639 = torch.ops.aten.view.default(convert_element_type_1788, [16384, 1792]); convert_element_type_1788 = None + permute_781 = torch.ops.aten.permute.default(view_2639, [1, 0]) + mm_413 = torch.ops.aten.mm.default(permute_781, view_1356); permute_781 = view_1356 = None + 
convert_element_type_617 = torch.ops.prims.convert_element_type.default(primals_172, torch.bfloat16); primals_172 = None + all_gather_into_tensor_207 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_617, 32, '0'); convert_element_type_617 = None + wait_tensor_245 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_207); all_gather_into_tensor_207 = None + permute_206 = torch.ops.aten.permute.default(wait_tensor_245, [1, 0]); wait_tensor_245 = None + permute_783 = torch.ops.aten.permute.default(permute_206, [1, 0]); permute_206 = None + mm_414 = torch.ops.aten.mm.default(view_2639, permute_783); view_2639 = permute_783 = None + view_2640 = torch.ops.aten.view.default(mm_414, [2, 8192, 4096]); mm_414 = None + add_222 = torch.ops.aten.add.Tensor(view_2638, view_2640); view_2638 = view_2640 = None + convert_element_type_1793 = torch.ops.prims.convert_element_type.default(mm_413, torch.float32); mm_413 = None + reduce_scatter_tensor_213 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1793, 'avg', 32, '0'); convert_element_type_1793 = None + wait_tensor_623 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_213); reduce_scatter_tensor_213 = None + split_192 = torch.ops.aten.split.Tensor(add_222, 1024, 1); add_222 = None + getitem_1863 = split_192[0] + getitem_1864 = split_192[1] + getitem_1865 = split_192[2] + getitem_1866 = split_192[3] + getitem_1867 = split_192[4] + getitem_1868 = split_192[5] + getitem_1869 = split_192[6] + getitem_1870 = split_192[7]; split_192 = None + cat_184 = torch.ops.aten.cat.default([getitem_1863, getitem_1864, getitem_1865, getitem_1866, getitem_1867, getitem_1868, getitem_1869, getitem_1870]); getitem_1863 = getitem_1864 = getitem_1865 = getitem_1866 = getitem_1867 = getitem_1868 = getitem_1869 = getitem_1870 = None + reduce_scatter_tensor_214 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_184, 'sum', 8, '1'); cat_184 = 
None + wait_tensor_624 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_214); reduce_scatter_tensor_214 = None + convert_element_type_1794 = torch.ops.prims.convert_element_type.default(wait_tensor_624, torch.float32); wait_tensor_624 = None + convert_element_type_1796 = torch.ops.prims.convert_element_type.default(wait_tensor_243, torch.float32); wait_tensor_243 = None + mul_530 = torch.ops.aten.mul.Tensor(convert_element_type_1794, convert_element_type_1796); convert_element_type_1796 = None + mul_532 = torch.ops.aten.mul.Tensor(mul_148, mul_530) + sum_81 = torch.ops.aten.sum.dim_IntList(mul_532, [2], True); mul_532 = None + div_27 = torch.ops.aten.div.Tensor(mul_148, 4096) + mul_533 = torch.ops.aten.mul.Tensor(div_27, sum_81); div_27 = sum_81 = None + sub_42 = torch.ops.aten.sub.Tensor(mul_530, mul_533); mul_530 = mul_533 = None + mul_534 = torch.ops.aten.mul.Tensor(sub_42, rsqrt_37); sub_42 = rsqrt_37 = None + mul_535 = torch.ops.aten.mul.Tensor(convert_element_type_1794, mul_148); convert_element_type_1794 = mul_148 = None + sum_82 = torch.ops.aten.sum.dim_IntList(mul_535, [0, 1]); mul_535 = None + convert_element_type_1797 = torch.ops.prims.convert_element_type.default(mul_534, torch.bfloat16); mul_534 = None + convert_element_type_1798 = torch.ops.prims.convert_element_type.default(sum_82, torch.bfloat16); sum_82 = None + all_reduce_27 = torch.ops._c10d_functional.all_reduce.default(convert_element_type_1798, 'sum', '1'); convert_element_type_1798 = None + wait_tensor_625 = torch.ops._c10d_functional.wait_tensor.default(all_reduce_27); all_reduce_27 = None + convert_element_type_1799 = torch.ops.prims.convert_element_type.default(wait_tensor_625, torch.float32); wait_tensor_625 = None + reduce_scatter_tensor_215 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1799, 'avg', 32, '0'); convert_element_type_1799 = None + wait_tensor_626 = 
torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_215); reduce_scatter_tensor_215 = None + add_223 = torch.ops.aten.add.Tensor(add_219, convert_element_type_1797); add_219 = convert_element_type_1797 = None + all_gather_into_tensor_383 = torch.ops._c10d_functional.all_gather_into_tensor.default(add_223, 8, '1') + wait_tensor_627 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_383); all_gather_into_tensor_383 = None + split_193 = torch.ops.aten.split.Tensor(wait_tensor_627, 2); wait_tensor_627 = None + getitem_1871 = split_193[0] + getitem_1872 = split_193[1] + getitem_1873 = split_193[2] + getitem_1874 = split_193[3] + getitem_1875 = split_193[4] + getitem_1876 = split_193[5] + getitem_1877 = split_193[6] + getitem_1878 = split_193[7]; split_193 = None + cat_185 = torch.ops.aten.cat.default([getitem_1871, getitem_1872, getitem_1873, getitem_1874, getitem_1875, getitem_1876, getitem_1877, getitem_1878], 1); getitem_1871 = getitem_1872 = getitem_1873 = getitem_1874 = getitem_1875 = getitem_1876 = getitem_1877 = getitem_1878 = None + view_2641 = torch.ops.aten.view.default(cat_185, [16384, 4096]); cat_185 = None + permute_785 = torch.ops.aten.permute.default(view_2641, [1, 0]) + permute_204 = torch.ops.aten.permute.default(getitem_818, [0, 2, 1, 3]) + view_1338 = torch.ops.aten.view.default(permute_204, [2, 8192, -1]); permute_204 = None + view_1344 = torch.ops.aten.view.default(view_1338, [16384, 512]); view_1338 = None + mm_415 = torch.ops.aten.mm.default(permute_785, view_1344); permute_785 = view_1344 = None + convert_element_type_611 = torch.ops.prims.convert_element_type.default(primals_170, torch.bfloat16); primals_170 = None + all_gather_into_tensor_204 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_611, 32, '0'); convert_element_type_611 = None + wait_tensor_241 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_204); all_gather_into_tensor_204 = None + 
permute_205 = torch.ops.aten.permute.default(wait_tensor_241, [1, 0]); wait_tensor_241 = None + permute_787 = torch.ops.aten.permute.default(permute_205, [1, 0]); permute_205 = None + mm_416 = torch.ops.aten.mm.default(view_2641, permute_787); view_2641 = permute_787 = None + view_2642 = torch.ops.aten.view.default(mm_416, [2, 8192, 512]); mm_416 = None + convert_element_type_1804 = torch.ops.prims.convert_element_type.default(mm_415, torch.float32); mm_415 = None + reduce_scatter_tensor_216 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1804, 'avg', 32, '0'); convert_element_type_1804 = None + wait_tensor_628 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_216); reduce_scatter_tensor_216 = None + view_2643 = torch.ops.aten.view.default(view_2642, [2, 8192, 4, 128]); view_2642 = None + permute_789 = torch.ops.aten.permute.default(view_2643, [0, 2, 1, 3]); view_2643 = None + convert_element_type_595 = torch.ops.prims.convert_element_type.default(primals_166, torch.bfloat16); primals_166 = None + all_gather_into_tensor_199 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_595, 32, '0'); convert_element_type_595 = None + wait_tensor_236 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_199); all_gather_into_tensor_199 = None + convert_element_type_596 = torch.ops.prims.convert_element_type.default(add_71, torch.float32); add_71 = None + pow_37 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_596, 2) + mean_36 = torch.ops.aten.mean.dim(pow_37, [2], True); pow_37 = None + add_72 = torch.ops.aten.add.Scalar(mean_36, 1e-05); mean_36 = None + rsqrt_36 = torch.ops.aten.rsqrt.default(add_72); add_72 = None + mul_144 = torch.ops.aten.mul.Tensor(convert_element_type_596, rsqrt_36); convert_element_type_596 = None + mul_145 = torch.ops.aten.mul.Tensor(mul_144, wait_tensor_236) + convert_element_type_597 = torch.ops.prims.convert_element_type.default(mul_145, 
torch.bfloat16); mul_145 = None + all_gather_into_tensor_200 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_597, 8, '1'); convert_element_type_597 = None + wait_tensor_237 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_200); all_gather_into_tensor_200 = None + split_81 = torch.ops.aten.split.Tensor(wait_tensor_237, 2); wait_tensor_237 = None + getitem_810 = split_81[0] + getitem_811 = split_81[1] + getitem_812 = split_81[2] + getitem_813 = split_81[3] + getitem_814 = split_81[4] + getitem_815 = split_81[5] + getitem_816 = split_81[6] + getitem_817 = split_81[7]; split_81 = None + cat_73 = torch.ops.aten.cat.default([getitem_810, getitem_811, getitem_812, getitem_813, getitem_814, getitem_815, getitem_816, getitem_817], 1); getitem_810 = getitem_811 = getitem_812 = getitem_813 = getitem_814 = getitem_815 = getitem_816 = getitem_817 = None + view_1311 = torch.ops.aten.view.default(cat_73, [16384, 4096]); cat_73 = None + view_1312 = torch.ops.aten.view.default(mm_126, [2, 8192, 512]); mm_126 = None + convert_element_type_601 = torch.ops.prims.convert_element_type.default(primals_168, torch.bfloat16); primals_168 = None + all_gather_into_tensor_202 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_601, 32, '0'); convert_element_type_601 = None + wait_tensor_239 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_202); all_gather_into_tensor_202 = None + permute_199 = torch.ops.aten.permute.default(wait_tensor_239, [1, 0]); wait_tensor_239 = None + mm_127 = torch.ops.aten.mm.default(view_1311, permute_199) + view_1319 = torch.ops.aten.view.default(mm_127, [2, 8192, 128]); mm_127 = None + view_1326 = torch.ops.aten.view.default(mm_128, [2, 8192, 128]); mm_128 = None + view_1328 = torch.ops.aten.view.default(view_1312, [2, 8192, -1, 128]); view_1312 = None + view_1329 = torch.ops.aten.view.default(view_1319, [2, 8192, -1, 128]); view_1319 = None + view_1330 = 
torch.ops.aten.view.default(view_1326, [2, 8192, -1, 128]); view_1326 = None + convert_element_type_607 = torch.ops.prims.convert_element_type.default(view_1328, torch.float32); view_1328 = None + view_1331 = torch.ops.aten.view.default(convert_element_type_607, [2, 8192, 4, -1, 2]); convert_element_type_607 = None + view_as_complex_36 = torch.ops.aten.view_as_complex.default(view_1331); view_1331 = None + convert_element_type_608 = torch.ops.prims.convert_element_type.default(view_1329, torch.float32); view_1329 = None + view_1332 = torch.ops.aten.view.default(convert_element_type_608, [2, 8192, 1, -1, 2]); convert_element_type_608 = None + view_as_complex_37 = torch.ops.aten.view_as_complex.default(view_1332); view_1332 = None + mul_146 = torch.ops.aten.mul.Tensor(view_as_complex_36, view_37); view_as_complex_36 = None + view_as_real_36 = torch.ops.aten.view_as_real.default(mul_146); mul_146 = None + view_1334 = torch.ops.aten.view.default(view_as_real_36, [2, 8192, 4, 128]); view_as_real_36 = None + mul_147 = torch.ops.aten.mul.Tensor(view_as_complex_37, view_37); view_as_complex_37 = None + view_as_real_37 = torch.ops.aten.view_as_real.default(mul_147); mul_147 = None + view_1335 = torch.ops.aten.view.default(view_as_real_37, [2, 8192, 1, 128]); view_as_real_37 = None + convert_element_type_609 = torch.ops.prims.convert_element_type.default(view_1334, torch.bfloat16); view_1334 = None + convert_element_type_610 = torch.ops.prims.convert_element_type.default(view_1335, torch.bfloat16); view_1335 = None + unsqueeze_36 = torch.ops.aten.unsqueeze.default(convert_element_type_610, 3); convert_element_type_610 = None + expand_36 = torch.ops.aten.expand.default(unsqueeze_36, [2, 8192, 1, 4, 128]); unsqueeze_36 = None + view_1336 = torch.ops.aten.view.default(expand_36, [2, 8192, 4, 128]); expand_36 = None + unsqueeze_37 = torch.ops.aten.unsqueeze.default(view_1330, 3); view_1330 = None + expand_37 = torch.ops.aten.expand.default(unsqueeze_37, [2, 8192, 1, 4, 128]); 
unsqueeze_37 = None + view_1337 = torch.ops.aten.view.default(expand_37, [2, 8192, 4, 128]); expand_37 = None + permute_201 = torch.ops.aten.permute.default(convert_element_type_609, [0, 2, 1, 3]); convert_element_type_609 = None + permute_202 = torch.ops.aten.permute.default(view_1336, [0, 2, 1, 3]); view_1336 = None + permute_203 = torch.ops.aten.permute.default(view_1337, [0, 2, 1, 3]); view_1337 = None + _scaled_dot_product_cudnn_attention_backward_13 = torch.ops.aten._scaled_dot_product_cudnn_attention_backward.default(permute_789, permute_201, permute_202, permute_203, getitem_818, getitem_819, getitem_824, getitem_825, None, None, None, 8192, 8192, 0.0, True); permute_789 = permute_201 = permute_202 = permute_203 = getitem_818 = getitem_819 = getitem_824 = getitem_825 = None + getitem_1879 = _scaled_dot_product_cudnn_attention_backward_13[0] + getitem_1880 = _scaled_dot_product_cudnn_attention_backward_13[1] + getitem_1881 = _scaled_dot_product_cudnn_attention_backward_13[2]; _scaled_dot_product_cudnn_attention_backward_13 = None + permute_790 = torch.ops.aten.permute.default(getitem_1881, [0, 2, 1, 3]); getitem_1881 = None + permute_791 = torch.ops.aten.permute.default(getitem_1880, [0, 2, 1, 3]); getitem_1880 = None + permute_792 = torch.ops.aten.permute.default(getitem_1879, [0, 2, 1, 3]); getitem_1879 = None + view_2644 = torch.ops.aten.view.default(permute_790, [2, 8192, 1, 4, 128]); permute_790 = None + sum_83 = torch.ops.aten.sum.dim_IntList(view_2644, [3], True); view_2644 = None + squeeze_26 = torch.ops.aten.squeeze.dim(sum_83, 3); sum_83 = None + view_2645 = torch.ops.aten.view.default(permute_791, [2, 8192, 1, 4, 128]); permute_791 = None + sum_84 = torch.ops.aten.sum.dim_IntList(view_2645, [3], True); view_2645 = None + squeeze_27 = torch.ops.aten.squeeze.dim(sum_84, 3); sum_84 = None + convert_element_type_1805 = torch.ops.prims.convert_element_type.default(squeeze_27, torch.float32); squeeze_27 = None + convert_element_type_1806 = 
torch.ops.prims.convert_element_type.default(permute_792, torch.float32); permute_792 = None + view_2646 = torch.ops.aten.view.default(convert_element_type_1805, [2, 8192, 1, 64, 2]); convert_element_type_1805 = None + view_as_complex_90 = torch.ops.aten.view_as_complex.default(view_2646); view_2646 = None + mul_536 = torch.ops.aten.mul.Tensor(view_as_complex_90, _conj); view_as_complex_90 = None + view_2647 = torch.ops.aten.view.default(convert_element_type_1806, [2, 8192, 4, 64, 2]); convert_element_type_1806 = None + view_as_complex_91 = torch.ops.aten.view_as_complex.default(view_2647); view_2647 = None + mul_537 = torch.ops.aten.mul.Tensor(view_as_complex_91, _conj); view_as_complex_91 = None + view_as_real_90 = torch.ops.aten.view_as_real.default(mul_536); mul_536 = None + view_2648 = torch.ops.aten.view.default(view_as_real_90, [2, 8192, 1, 128]); view_as_real_90 = None + convert_element_type_1807 = torch.ops.prims.convert_element_type.default(view_2648, torch.bfloat16); view_2648 = None + view_as_real_91 = torch.ops.aten.view_as_real.default(mul_537); mul_537 = None + view_2649 = torch.ops.aten.view.default(view_as_real_91, [2, 8192, 4, 128]); view_as_real_91 = None + convert_element_type_1808 = torch.ops.prims.convert_element_type.default(view_2649, torch.bfloat16); view_2649 = None + view_2650 = torch.ops.aten.view.default(squeeze_26, [2, 8192, 128]); squeeze_26 = None + view_2651 = torch.ops.aten.view.default(convert_element_type_1807, [2, 8192, 128]); convert_element_type_1807 = None + view_2652 = torch.ops.aten.view.default(convert_element_type_1808, [2, 8192, 512]); convert_element_type_1808 = None + view_2653 = torch.ops.aten.view.default(view_2650, [16384, 128]); view_2650 = None + permute_793 = torch.ops.aten.permute.default(view_2653, [1, 0]) + mm_417 = torch.ops.aten.mm.default(permute_793, view_1311); permute_793 = None + convert_element_type_604 = torch.ops.prims.convert_element_type.default(primals_169, torch.bfloat16); primals_169 = None + 
all_gather_into_tensor_203 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_604, 32, '0'); convert_element_type_604 = None + wait_tensor_240 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_203); all_gather_into_tensor_203 = None + permute_200 = torch.ops.aten.permute.default(wait_tensor_240, [1, 0]); wait_tensor_240 = None + permute_795 = torch.ops.aten.permute.default(permute_200, [1, 0]); permute_200 = None + mm_418 = torch.ops.aten.mm.default(view_2653, permute_795); view_2653 = permute_795 = None + view_2654 = torch.ops.aten.view.default(mm_418, [2, 8192, 4096]); mm_418 = None + convert_element_type_1813 = torch.ops.prims.convert_element_type.default(mm_417, torch.float32); mm_417 = None + reduce_scatter_tensor_217 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1813, 'avg', 32, '0'); convert_element_type_1813 = None + wait_tensor_629 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_217); reduce_scatter_tensor_217 = None + view_2655 = torch.ops.aten.view.default(view_2651, [16384, 128]); view_2651 = None + permute_797 = torch.ops.aten.permute.default(view_2655, [1, 0]) + mm_419 = torch.ops.aten.mm.default(permute_797, view_1311); permute_797 = None + permute_799 = torch.ops.aten.permute.default(permute_199, [1, 0]); permute_199 = None + mm_420 = torch.ops.aten.mm.default(view_2655, permute_799); view_2655 = permute_799 = None + view_2656 = torch.ops.aten.view.default(mm_420, [2, 8192, 4096]); mm_420 = None + add_224 = torch.ops.aten.add.Tensor(view_2654, view_2656); view_2654 = view_2656 = None + convert_element_type_1818 = torch.ops.prims.convert_element_type.default(mm_419, torch.float32); mm_419 = None + reduce_scatter_tensor_218 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1818, 'avg', 32, '0'); convert_element_type_1818 = None + wait_tensor_630 = 
torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_218); reduce_scatter_tensor_218 = None + view_2657 = torch.ops.aten.view.default(view_2652, [16384, 512]); view_2652 = None + permute_801 = torch.ops.aten.permute.default(view_2657, [1, 0]) + mm_421 = torch.ops.aten.mm.default(permute_801, view_1311); permute_801 = view_1311 = None + convert_element_type_598 = torch.ops.prims.convert_element_type.default(primals_167, torch.bfloat16); primals_167 = None + all_gather_into_tensor_201 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_598, 32, '0'); convert_element_type_598 = None + wait_tensor_238 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_201); all_gather_into_tensor_201 = None + permute_198 = torch.ops.aten.permute.default(wait_tensor_238, [1, 0]); wait_tensor_238 = None + permute_803 = torch.ops.aten.permute.default(permute_198, [1, 0]); permute_198 = None + mm_422 = torch.ops.aten.mm.default(view_2657, permute_803); view_2657 = permute_803 = None + view_2658 = torch.ops.aten.view.default(mm_422, [2, 8192, 4096]); mm_422 = None + add_225 = torch.ops.aten.add.Tensor(add_224, view_2658); add_224 = view_2658 = None + convert_element_type_1823 = torch.ops.prims.convert_element_type.default(mm_421, torch.float32); mm_421 = None + reduce_scatter_tensor_219 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1823, 'avg', 32, '0'); convert_element_type_1823 = None + wait_tensor_631 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_219); reduce_scatter_tensor_219 = None + split_194 = torch.ops.aten.split.Tensor(add_225, 1024, 1); add_225 = None + getitem_1882 = split_194[0] + getitem_1883 = split_194[1] + getitem_1884 = split_194[2] + getitem_1885 = split_194[3] + getitem_1886 = split_194[4] + getitem_1887 = split_194[5] + getitem_1888 = split_194[6] + getitem_1889 = split_194[7]; split_194 = None + cat_186 = torch.ops.aten.cat.default([getitem_1882, 
getitem_1883, getitem_1884, getitem_1885, getitem_1886, getitem_1887, getitem_1888, getitem_1889]); getitem_1882 = getitem_1883 = getitem_1884 = getitem_1885 = getitem_1886 = getitem_1887 = getitem_1888 = getitem_1889 = None + reduce_scatter_tensor_220 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_186, 'sum', 8, '1'); cat_186 = None + wait_tensor_632 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_220); reduce_scatter_tensor_220 = None + convert_element_type_1824 = torch.ops.prims.convert_element_type.default(wait_tensor_632, torch.float32); wait_tensor_632 = None + convert_element_type_1826 = torch.ops.prims.convert_element_type.default(wait_tensor_236, torch.float32); wait_tensor_236 = None + mul_538 = torch.ops.aten.mul.Tensor(convert_element_type_1824, convert_element_type_1826); convert_element_type_1826 = None + mul_540 = torch.ops.aten.mul.Tensor(mul_144, mul_538) + sum_85 = torch.ops.aten.sum.dim_IntList(mul_540, [2], True); mul_540 = None + div_28 = torch.ops.aten.div.Tensor(mul_144, 4096) + mul_541 = torch.ops.aten.mul.Tensor(div_28, sum_85); div_28 = sum_85 = None + sub_43 = torch.ops.aten.sub.Tensor(mul_538, mul_541); mul_538 = mul_541 = None + mul_542 = torch.ops.aten.mul.Tensor(sub_43, rsqrt_36); sub_43 = rsqrt_36 = None + mul_543 = torch.ops.aten.mul.Tensor(convert_element_type_1824, mul_144); convert_element_type_1824 = mul_144 = None + sum_86 = torch.ops.aten.sum.dim_IntList(mul_543, [0, 1]); mul_543 = None + convert_element_type_1827 = torch.ops.prims.convert_element_type.default(mul_542, torch.bfloat16); mul_542 = None + convert_element_type_1828 = torch.ops.prims.convert_element_type.default(sum_86, torch.bfloat16); sum_86 = None + all_reduce_28 = torch.ops._c10d_functional.all_reduce.default(convert_element_type_1828, 'sum', '1'); convert_element_type_1828 = None + wait_tensor_633 = torch.ops._c10d_functional.wait_tensor.default(all_reduce_28); all_reduce_28 = None + convert_element_type_1829 = 
torch.ops.prims.convert_element_type.default(wait_tensor_633, torch.float32); wait_tensor_633 = None + reduce_scatter_tensor_221 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1829, 'avg', 32, '0'); convert_element_type_1829 = None + wait_tensor_634 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_221); reduce_scatter_tensor_221 = None + add_226 = torch.ops.aten.add.Tensor(add_223, convert_element_type_1827); add_223 = convert_element_type_1827 = None + all_gather_into_tensor_384 = torch.ops._c10d_functional.all_gather_into_tensor.default(add_226, 8, '1') + wait_tensor_635 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_384); all_gather_into_tensor_384 = None + split_195 = torch.ops.aten.split.Tensor(wait_tensor_635, 2); wait_tensor_635 = None + getitem_1890 = split_195[0] + getitem_1891 = split_195[1] + getitem_1892 = split_195[2] + getitem_1893 = split_195[3] + getitem_1894 = split_195[4] + getitem_1895 = split_195[5] + getitem_1896 = split_195[6] + getitem_1897 = split_195[7]; split_195 = None + cat_187 = torch.ops.aten.cat.default([getitem_1890, getitem_1891, getitem_1892, getitem_1893, getitem_1894, getitem_1895, getitem_1896, getitem_1897], 1); getitem_1890 = getitem_1891 = getitem_1892 = getitem_1893 = getitem_1894 = getitem_1895 = getitem_1896 = getitem_1897 = None + view_2659 = torch.ops.aten.view.default(cat_187, [16384, 4096]); cat_187 = None + permute_805 = torch.ops.aten.permute.default(view_2659, [1, 0]) + wait_tensor_229 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_35); reduce_scatter_tensor_35 = None + add_69 = torch.ops.aten.add.Tensor(add_67, wait_tensor_229); wait_tensor_229 = None + convert_element_type_581 = torch.ops.prims.convert_element_type.default(primals_162, torch.bfloat16); primals_162 = None + all_gather_into_tensor_194 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_581, 32, '0'); 
convert_element_type_581 = None + wait_tensor_230 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_194); all_gather_into_tensor_194 = None + convert_element_type_582 = torch.ops.prims.convert_element_type.default(add_69, torch.float32); add_69 = None + pow_36 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_582, 2) + mean_35 = torch.ops.aten.mean.dim(pow_36, [2], True); pow_36 = None + add_70 = torch.ops.aten.add.Scalar(mean_35, 1e-05); mean_35 = None + rsqrt_35 = torch.ops.aten.rsqrt.default(add_70); add_70 = None + mul_140 = torch.ops.aten.mul.Tensor(convert_element_type_582, rsqrt_35); convert_element_type_582 = None + mul_141 = torch.ops.aten.mul.Tensor(mul_140, wait_tensor_230) + convert_element_type_583 = torch.ops.prims.convert_element_type.default(mul_141, torch.bfloat16); mul_141 = None + all_gather_into_tensor_195 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_583, 8, '1'); convert_element_type_583 = None + wait_tensor_231 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_195); all_gather_into_tensor_195 = None + split_79 = torch.ops.aten.split.Tensor(wait_tensor_231, 2); wait_tensor_231 = None + getitem_794 = split_79[0] + getitem_795 = split_79[1] + getitem_796 = split_79[2] + getitem_797 = split_79[3] + getitem_798 = split_79[4] + getitem_799 = split_79[5] + getitem_800 = split_79[6] + getitem_801 = split_79[7]; split_79 = None + cat_71 = torch.ops.aten.cat.default([getitem_794, getitem_795, getitem_796, getitem_797, getitem_798, getitem_799, getitem_800, getitem_801], 1); getitem_794 = getitem_795 = getitem_796 = getitem_797 = getitem_798 = getitem_799 = getitem_800 = getitem_801 = None + view_1284 = torch.ops.aten.view.default(cat_71, [16384, 4096]); cat_71 = None + view_1285 = torch.ops.aten.view.default(mm_123, [2, 8192, 1792]); mm_123 = None + convert_element_type_587 = torch.ops.prims.convert_element_type.default(view_1285, torch.float32); view_1285 = None + 
sigmoid_17 = torch.ops.aten.sigmoid.default(convert_element_type_587) + mul_142 = torch.ops.aten.mul.Tensor(convert_element_type_587, sigmoid_17); sigmoid_17 = None + convert_element_type_588 = torch.ops.prims.convert_element_type.default(mul_142, torch.bfloat16); mul_142 = None + convert_element_type_589 = torch.ops.prims.convert_element_type.default(primals_164, torch.bfloat16); primals_164 = None + all_gather_into_tensor_197 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_589, 32, '0'); convert_element_type_589 = None + wait_tensor_233 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_197); all_gather_into_tensor_197 = None + permute_196 = torch.ops.aten.permute.default(wait_tensor_233, [1, 0]); wait_tensor_233 = None + mm_124 = torch.ops.aten.mm.default(view_1284, permute_196) + view_1292 = torch.ops.aten.view.default(mm_124, [2, 8192, 1792]); mm_124 = None + mul_143 = torch.ops.aten.mul.Tensor(convert_element_type_588, view_1292) + view_1299 = torch.ops.aten.view.default(mul_143, [16384, 1792]); mul_143 = None + mm_423 = torch.ops.aten.mm.default(permute_805, view_1299); permute_805 = view_1299 = None + convert_element_type_592 = torch.ops.prims.convert_element_type.default(primals_165, torch.bfloat16); primals_165 = None + all_gather_into_tensor_198 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_592, 32, '0'); convert_element_type_592 = None + wait_tensor_234 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_198); all_gather_into_tensor_198 = None + permute_197 = torch.ops.aten.permute.default(wait_tensor_234, [1, 0]); wait_tensor_234 = None + permute_807 = torch.ops.aten.permute.default(permute_197, [1, 0]); permute_197 = None + mm_424 = torch.ops.aten.mm.default(view_2659, permute_807); view_2659 = permute_807 = None + view_2660 = torch.ops.aten.view.default(mm_424, [2, 8192, 1792]); mm_424 = None + convert_element_type_1834 = 
torch.ops.prims.convert_element_type.default(mm_423, torch.float32); mm_423 = None + reduce_scatter_tensor_222 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1834, 'avg', 32, '0'); convert_element_type_1834 = None + wait_tensor_636 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_222); reduce_scatter_tensor_222 = None + mul_544 = torch.ops.aten.mul.Tensor(view_2660, convert_element_type_588); convert_element_type_588 = None + mul_545 = torch.ops.aten.mul.Tensor(view_2660, view_1292); view_2660 = view_1292 = None + view_2661 = torch.ops.aten.view.default(mul_544, [16384, 1792]); mul_544 = None + permute_809 = torch.ops.aten.permute.default(view_2661, [1, 0]) + mm_425 = torch.ops.aten.mm.default(permute_809, view_1284); permute_809 = None + permute_811 = torch.ops.aten.permute.default(permute_196, [1, 0]); permute_196 = None + mm_426 = torch.ops.aten.mm.default(view_2661, permute_811); view_2661 = permute_811 = None + view_2662 = torch.ops.aten.view.default(mm_426, [2, 8192, 4096]); mm_426 = None + convert_element_type_1839 = torch.ops.prims.convert_element_type.default(mm_425, torch.float32); mm_425 = None + reduce_scatter_tensor_223 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1839, 'avg', 32, '0'); convert_element_type_1839 = None + wait_tensor_637 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_223); reduce_scatter_tensor_223 = None + convert_element_type_1840 = torch.ops.prims.convert_element_type.default(mul_545, torch.float32); mul_545 = None + neg_14 = torch.ops.aten.neg.default(convert_element_type_587) + exp_14 = torch.ops.aten.exp.default(neg_14); neg_14 = None + add_227 = torch.ops.aten.add.Tensor(exp_14, 1); exp_14 = None + reciprocal_14 = torch.ops.aten.reciprocal.default(add_227); add_227 = None + mul_546 = torch.ops.aten.mul.Tensor(reciprocal_14, 1); reciprocal_14 = None + mul_547 = torch.ops.aten.mul.Tensor(convert_element_type_1840, 
mul_546); convert_element_type_1840 = None + sub_44 = torch.ops.aten.sub.Tensor(1, mul_546); mul_546 = None + mul_548 = torch.ops.aten.mul.Tensor(convert_element_type_587, sub_44); convert_element_type_587 = sub_44 = None + add_228 = torch.ops.aten.add.Tensor(mul_548, 1); mul_548 = None + mul_549 = torch.ops.aten.mul.Tensor(mul_547, add_228); mul_547 = add_228 = None + convert_element_type_1842 = torch.ops.prims.convert_element_type.default(mul_549, torch.bfloat16); mul_549 = None + view_2663 = torch.ops.aten.view.default(convert_element_type_1842, [16384, 1792]); convert_element_type_1842 = None + permute_813 = torch.ops.aten.permute.default(view_2663, [1, 0]) + mm_427 = torch.ops.aten.mm.default(permute_813, view_1284); permute_813 = view_1284 = None + convert_element_type_584 = torch.ops.prims.convert_element_type.default(primals_163, torch.bfloat16); primals_163 = None + all_gather_into_tensor_196 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_584, 32, '0'); convert_element_type_584 = None + wait_tensor_232 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_196); all_gather_into_tensor_196 = None + permute_195 = torch.ops.aten.permute.default(wait_tensor_232, [1, 0]); wait_tensor_232 = None + permute_815 = torch.ops.aten.permute.default(permute_195, [1, 0]); permute_195 = None + mm_428 = torch.ops.aten.mm.default(view_2663, permute_815); view_2663 = permute_815 = None + view_2664 = torch.ops.aten.view.default(mm_428, [2, 8192, 4096]); mm_428 = None + add_229 = torch.ops.aten.add.Tensor(view_2662, view_2664); view_2662 = view_2664 = None + convert_element_type_1847 = torch.ops.prims.convert_element_type.default(mm_427, torch.float32); mm_427 = None + reduce_scatter_tensor_224 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1847, 'avg', 32, '0'); convert_element_type_1847 = None + wait_tensor_638 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_224); 
reduce_scatter_tensor_224 = None + split_196 = torch.ops.aten.split.Tensor(add_229, 1024, 1); add_229 = None + getitem_1898 = split_196[0] + getitem_1899 = split_196[1] + getitem_1900 = split_196[2] + getitem_1901 = split_196[3] + getitem_1902 = split_196[4] + getitem_1903 = split_196[5] + getitem_1904 = split_196[6] + getitem_1905 = split_196[7]; split_196 = None + cat_188 = torch.ops.aten.cat.default([getitem_1898, getitem_1899, getitem_1900, getitem_1901, getitem_1902, getitem_1903, getitem_1904, getitem_1905]); getitem_1898 = getitem_1899 = getitem_1900 = getitem_1901 = getitem_1902 = getitem_1903 = getitem_1904 = getitem_1905 = None + reduce_scatter_tensor_225 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_188, 'sum', 8, '1'); cat_188 = None + wait_tensor_639 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_225); reduce_scatter_tensor_225 = None + convert_element_type_1848 = torch.ops.prims.convert_element_type.default(wait_tensor_639, torch.float32); wait_tensor_639 = None + convert_element_type_1850 = torch.ops.prims.convert_element_type.default(wait_tensor_230, torch.float32); wait_tensor_230 = None + mul_550 = torch.ops.aten.mul.Tensor(convert_element_type_1848, convert_element_type_1850); convert_element_type_1850 = None + mul_552 = torch.ops.aten.mul.Tensor(mul_140, mul_550) + sum_87 = torch.ops.aten.sum.dim_IntList(mul_552, [2], True); mul_552 = None + div_29 = torch.ops.aten.div.Tensor(mul_140, 4096) + mul_553 = torch.ops.aten.mul.Tensor(div_29, sum_87); div_29 = sum_87 = None + sub_45 = torch.ops.aten.sub.Tensor(mul_550, mul_553); mul_550 = mul_553 = None + mul_554 = torch.ops.aten.mul.Tensor(sub_45, rsqrt_35); sub_45 = rsqrt_35 = None + mul_555 = torch.ops.aten.mul.Tensor(convert_element_type_1848, mul_140); convert_element_type_1848 = mul_140 = None + sum_88 = torch.ops.aten.sum.dim_IntList(mul_555, [0, 1]); mul_555 = None + convert_element_type_1851 = torch.ops.prims.convert_element_type.default(mul_554, 
torch.bfloat16); mul_554 = None + convert_element_type_1852 = torch.ops.prims.convert_element_type.default(sum_88, torch.bfloat16); sum_88 = None + all_reduce_29 = torch.ops._c10d_functional.all_reduce.default(convert_element_type_1852, 'sum', '1'); convert_element_type_1852 = None + wait_tensor_640 = torch.ops._c10d_functional.wait_tensor.default(all_reduce_29); all_reduce_29 = None + convert_element_type_1853 = torch.ops.prims.convert_element_type.default(wait_tensor_640, torch.float32); wait_tensor_640 = None + reduce_scatter_tensor_226 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1853, 'avg', 32, '0'); convert_element_type_1853 = None + wait_tensor_641 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_226); reduce_scatter_tensor_226 = None + add_230 = torch.ops.aten.add.Tensor(add_226, convert_element_type_1851); add_226 = convert_element_type_1851 = None + all_gather_into_tensor_385 = torch.ops._c10d_functional.all_gather_into_tensor.default(add_230, 8, '1') + wait_tensor_642 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_385); all_gather_into_tensor_385 = None + split_197 = torch.ops.aten.split.Tensor(wait_tensor_642, 2); wait_tensor_642 = None + getitem_1906 = split_197[0] + getitem_1907 = split_197[1] + getitem_1908 = split_197[2] + getitem_1909 = split_197[3] + getitem_1910 = split_197[4] + getitem_1911 = split_197[5] + getitem_1912 = split_197[6] + getitem_1913 = split_197[7]; split_197 = None + cat_189 = torch.ops.aten.cat.default([getitem_1906, getitem_1907, getitem_1908, getitem_1909, getitem_1910, getitem_1911, getitem_1912, getitem_1913], 1); getitem_1906 = getitem_1907 = getitem_1908 = getitem_1909 = getitem_1910 = getitem_1911 = getitem_1912 = getitem_1913 = None + view_2665 = torch.ops.aten.view.default(cat_189, [16384, 4096]); cat_189 = None + permute_817 = torch.ops.aten.permute.default(view_2665, [1, 0]) + permute_193 = torch.ops.aten.permute.default(getitem_777, 
[0, 2, 1, 3]) + view_1266 = torch.ops.aten.view.default(permute_193, [2, 8192, -1]); permute_193 = None + view_1272 = torch.ops.aten.view.default(view_1266, [16384, 512]); view_1266 = None + mm_429 = torch.ops.aten.mm.default(permute_817, view_1272); permute_817 = view_1272 = None + convert_element_type_578 = torch.ops.prims.convert_element_type.default(primals_161, torch.bfloat16); primals_161 = None + all_gather_into_tensor_193 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_578, 32, '0'); convert_element_type_578 = None + wait_tensor_228 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_193); all_gather_into_tensor_193 = None + permute_194 = torch.ops.aten.permute.default(wait_tensor_228, [1, 0]); wait_tensor_228 = None + permute_819 = torch.ops.aten.permute.default(permute_194, [1, 0]); permute_194 = None + mm_430 = torch.ops.aten.mm.default(view_2665, permute_819); view_2665 = permute_819 = None + view_2666 = torch.ops.aten.view.default(mm_430, [2, 8192, 512]); mm_430 = None + convert_element_type_1858 = torch.ops.prims.convert_element_type.default(mm_429, torch.float32); mm_429 = None + reduce_scatter_tensor_227 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1858, 'avg', 32, '0'); convert_element_type_1858 = None + wait_tensor_643 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_227); reduce_scatter_tensor_227 = None + view_2667 = torch.ops.aten.view.default(view_2666, [2, 8192, 4, 128]); view_2666 = None + permute_821 = torch.ops.aten.permute.default(view_2667, [0, 2, 1, 3]); view_2667 = None + convert_element_type_562 = torch.ops.prims.convert_element_type.default(primals_157, torch.bfloat16); primals_157 = None + all_gather_into_tensor_188 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_562, 32, '0'); convert_element_type_562 = None + wait_tensor_223 = 
torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_188); all_gather_into_tensor_188 = None + convert_element_type_563 = torch.ops.prims.convert_element_type.default(add_67, torch.float32); add_67 = None + pow_35 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_563, 2) + mean_34 = torch.ops.aten.mean.dim(pow_35, [2], True); pow_35 = None + add_68 = torch.ops.aten.add.Scalar(mean_34, 1e-05); mean_34 = None + rsqrt_34 = torch.ops.aten.rsqrt.default(add_68); add_68 = None + mul_136 = torch.ops.aten.mul.Tensor(convert_element_type_563, rsqrt_34); convert_element_type_563 = None + mul_137 = torch.ops.aten.mul.Tensor(mul_136, wait_tensor_223) + convert_element_type_564 = torch.ops.prims.convert_element_type.default(mul_137, torch.bfloat16); mul_137 = None + all_gather_into_tensor_189 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_564, 8, '1'); convert_element_type_564 = None + wait_tensor_224 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_189); all_gather_into_tensor_189 = None + split_77 = torch.ops.aten.split.Tensor(wait_tensor_224, 2); wait_tensor_224 = None + getitem_769 = split_77[0] + getitem_770 = split_77[1] + getitem_771 = split_77[2] + getitem_772 = split_77[3] + getitem_773 = split_77[4] + getitem_774 = split_77[5] + getitem_775 = split_77[6] + getitem_776 = split_77[7]; split_77 = None + cat_69 = torch.ops.aten.cat.default([getitem_769, getitem_770, getitem_771, getitem_772, getitem_773, getitem_774, getitem_775, getitem_776], 1); getitem_769 = getitem_770 = getitem_771 = getitem_772 = getitem_773 = getitem_774 = getitem_775 = getitem_776 = None + view_1239 = torch.ops.aten.view.default(cat_69, [16384, 4096]); cat_69 = None + view_1240 = torch.ops.aten.view.default(mm_119, [2, 8192, 512]); mm_119 = None + convert_element_type_568 = torch.ops.prims.convert_element_type.default(primals_159, torch.bfloat16); primals_159 = None + all_gather_into_tensor_191 = 
torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_568, 32, '0'); convert_element_type_568 = None + wait_tensor_226 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_191); all_gather_into_tensor_191 = None + permute_188 = torch.ops.aten.permute.default(wait_tensor_226, [1, 0]); wait_tensor_226 = None + mm_120 = torch.ops.aten.mm.default(view_1239, permute_188) + view_1247 = torch.ops.aten.view.default(mm_120, [2, 8192, 128]); mm_120 = None + view_1254 = torch.ops.aten.view.default(mm_121, [2, 8192, 128]); mm_121 = None + view_1256 = torch.ops.aten.view.default(view_1240, [2, 8192, -1, 128]); view_1240 = None + view_1257 = torch.ops.aten.view.default(view_1247, [2, 8192, -1, 128]); view_1247 = None + view_1258 = torch.ops.aten.view.default(view_1254, [2, 8192, -1, 128]); view_1254 = None + convert_element_type_574 = torch.ops.prims.convert_element_type.default(view_1256, torch.float32); view_1256 = None + view_1259 = torch.ops.aten.view.default(convert_element_type_574, [2, 8192, 4, -1, 2]); convert_element_type_574 = None + view_as_complex_34 = torch.ops.aten.view_as_complex.default(view_1259); view_1259 = None + convert_element_type_575 = torch.ops.prims.convert_element_type.default(view_1257, torch.float32); view_1257 = None + view_1260 = torch.ops.aten.view.default(convert_element_type_575, [2, 8192, 1, -1, 2]); convert_element_type_575 = None + view_as_complex_35 = torch.ops.aten.view_as_complex.default(view_1260); view_1260 = None + mul_138 = torch.ops.aten.mul.Tensor(view_as_complex_34, view_37); view_as_complex_34 = None + view_as_real_34 = torch.ops.aten.view_as_real.default(mul_138); mul_138 = None + view_1262 = torch.ops.aten.view.default(view_as_real_34, [2, 8192, 4, 128]); view_as_real_34 = None + mul_139 = torch.ops.aten.mul.Tensor(view_as_complex_35, view_37); view_as_complex_35 = None + view_as_real_35 = torch.ops.aten.view_as_real.default(mul_139); mul_139 = None + view_1263 = 
torch.ops.aten.view.default(view_as_real_35, [2, 8192, 1, 128]); view_as_real_35 = None + convert_element_type_576 = torch.ops.prims.convert_element_type.default(view_1262, torch.bfloat16); view_1262 = None + convert_element_type_577 = torch.ops.prims.convert_element_type.default(view_1263, torch.bfloat16); view_1263 = None + unsqueeze_34 = torch.ops.aten.unsqueeze.default(convert_element_type_577, 3); convert_element_type_577 = None + expand_34 = torch.ops.aten.expand.default(unsqueeze_34, [2, 8192, 1, 4, 128]); unsqueeze_34 = None + view_1264 = torch.ops.aten.view.default(expand_34, [2, 8192, 4, 128]); expand_34 = None + unsqueeze_35 = torch.ops.aten.unsqueeze.default(view_1258, 3); view_1258 = None + expand_35 = torch.ops.aten.expand.default(unsqueeze_35, [2, 8192, 1, 4, 128]); unsqueeze_35 = None + view_1265 = torch.ops.aten.view.default(expand_35, [2, 8192, 4, 128]); expand_35 = None + permute_190 = torch.ops.aten.permute.default(convert_element_type_576, [0, 2, 1, 3]); convert_element_type_576 = None + permute_191 = torch.ops.aten.permute.default(view_1264, [0, 2, 1, 3]); view_1264 = None + permute_192 = torch.ops.aten.permute.default(view_1265, [0, 2, 1, 3]); view_1265 = None + _scaled_dot_product_cudnn_attention_backward_14 = torch.ops.aten._scaled_dot_product_cudnn_attention_backward.default(permute_821, permute_190, permute_191, permute_192, getitem_777, getitem_778, getitem_783, getitem_784, None, None, None, 8192, 8192, 0.0, True); permute_821 = permute_190 = permute_191 = permute_192 = getitem_777 = getitem_778 = getitem_783 = getitem_784 = None + getitem_1914 = _scaled_dot_product_cudnn_attention_backward_14[0] + getitem_1915 = _scaled_dot_product_cudnn_attention_backward_14[1] + getitem_1916 = _scaled_dot_product_cudnn_attention_backward_14[2]; _scaled_dot_product_cudnn_attention_backward_14 = None + permute_822 = torch.ops.aten.permute.default(getitem_1916, [0, 2, 1, 3]); getitem_1916 = None + permute_823 = 
torch.ops.aten.permute.default(getitem_1915, [0, 2, 1, 3]); getitem_1915 = None + permute_824 = torch.ops.aten.permute.default(getitem_1914, [0, 2, 1, 3]); getitem_1914 = None + view_2668 = torch.ops.aten.view.default(permute_822, [2, 8192, 1, 4, 128]); permute_822 = None + sum_89 = torch.ops.aten.sum.dim_IntList(view_2668, [3], True); view_2668 = None + squeeze_28 = torch.ops.aten.squeeze.dim(sum_89, 3); sum_89 = None + view_2669 = torch.ops.aten.view.default(permute_823, [2, 8192, 1, 4, 128]); permute_823 = None + sum_90 = torch.ops.aten.sum.dim_IntList(view_2669, [3], True); view_2669 = None + squeeze_29 = torch.ops.aten.squeeze.dim(sum_90, 3); sum_90 = None + convert_element_type_1859 = torch.ops.prims.convert_element_type.default(squeeze_29, torch.float32); squeeze_29 = None + convert_element_type_1860 = torch.ops.prims.convert_element_type.default(permute_824, torch.float32); permute_824 = None + view_2670 = torch.ops.aten.view.default(convert_element_type_1859, [2, 8192, 1, 64, 2]); convert_element_type_1859 = None + view_as_complex_92 = torch.ops.aten.view_as_complex.default(view_2670); view_2670 = None + mul_556 = torch.ops.aten.mul.Tensor(view_as_complex_92, _conj); view_as_complex_92 = None + view_2671 = torch.ops.aten.view.default(convert_element_type_1860, [2, 8192, 4, 64, 2]); convert_element_type_1860 = None + view_as_complex_93 = torch.ops.aten.view_as_complex.default(view_2671); view_2671 = None + mul_557 = torch.ops.aten.mul.Tensor(view_as_complex_93, _conj); view_as_complex_93 = None + view_as_real_92 = torch.ops.aten.view_as_real.default(mul_556); mul_556 = None + view_2672 = torch.ops.aten.view.default(view_as_real_92, [2, 8192, 1, 128]); view_as_real_92 = None + convert_element_type_1861 = torch.ops.prims.convert_element_type.default(view_2672, torch.bfloat16); view_2672 = None + view_as_real_93 = torch.ops.aten.view_as_real.default(mul_557); mul_557 = None + view_2673 = torch.ops.aten.view.default(view_as_real_93, [2, 8192, 4, 128]); 
view_as_real_93 = None + convert_element_type_1862 = torch.ops.prims.convert_element_type.default(view_2673, torch.bfloat16); view_2673 = None + view_2674 = torch.ops.aten.view.default(squeeze_28, [2, 8192, 128]); squeeze_28 = None + view_2675 = torch.ops.aten.view.default(convert_element_type_1861, [2, 8192, 128]); convert_element_type_1861 = None + view_2676 = torch.ops.aten.view.default(convert_element_type_1862, [2, 8192, 512]); convert_element_type_1862 = None + view_2677 = torch.ops.aten.view.default(view_2674, [16384, 128]); view_2674 = None + permute_825 = torch.ops.aten.permute.default(view_2677, [1, 0]) + mm_431 = torch.ops.aten.mm.default(permute_825, view_1239); permute_825 = None + convert_element_type_571 = torch.ops.prims.convert_element_type.default(primals_160, torch.bfloat16); primals_160 = None + all_gather_into_tensor_192 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_571, 32, '0'); convert_element_type_571 = None + wait_tensor_227 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_192); all_gather_into_tensor_192 = None + permute_189 = torch.ops.aten.permute.default(wait_tensor_227, [1, 0]); wait_tensor_227 = None + permute_827 = torch.ops.aten.permute.default(permute_189, [1, 0]); permute_189 = None + mm_432 = torch.ops.aten.mm.default(view_2677, permute_827); view_2677 = permute_827 = None + view_2678 = torch.ops.aten.view.default(mm_432, [2, 8192, 4096]); mm_432 = None + convert_element_type_1867 = torch.ops.prims.convert_element_type.default(mm_431, torch.float32); mm_431 = None + reduce_scatter_tensor_228 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1867, 'avg', 32, '0'); convert_element_type_1867 = None + wait_tensor_644 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_228); reduce_scatter_tensor_228 = None + view_2679 = torch.ops.aten.view.default(view_2675, [16384, 128]); view_2675 = None + permute_829 = 
torch.ops.aten.permute.default(view_2679, [1, 0]) + mm_433 = torch.ops.aten.mm.default(permute_829, view_1239); permute_829 = None + permute_831 = torch.ops.aten.permute.default(permute_188, [1, 0]); permute_188 = None + mm_434 = torch.ops.aten.mm.default(view_2679, permute_831); view_2679 = permute_831 = None + view_2680 = torch.ops.aten.view.default(mm_434, [2, 8192, 4096]); mm_434 = None + add_231 = torch.ops.aten.add.Tensor(view_2678, view_2680); view_2678 = view_2680 = None + convert_element_type_1872 = torch.ops.prims.convert_element_type.default(mm_433, torch.float32); mm_433 = None + reduce_scatter_tensor_229 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1872, 'avg', 32, '0'); convert_element_type_1872 = None + wait_tensor_645 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_229); reduce_scatter_tensor_229 = None + view_2681 = torch.ops.aten.view.default(view_2676, [16384, 512]); view_2676 = None + permute_833 = torch.ops.aten.permute.default(view_2681, [1, 0]) + mm_435 = torch.ops.aten.mm.default(permute_833, view_1239); permute_833 = view_1239 = None + convert_element_type_565 = torch.ops.prims.convert_element_type.default(primals_158, torch.bfloat16); primals_158 = None + all_gather_into_tensor_190 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_565, 32, '0'); convert_element_type_565 = None + wait_tensor_225 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_190); all_gather_into_tensor_190 = None + permute_187 = torch.ops.aten.permute.default(wait_tensor_225, [1, 0]); wait_tensor_225 = None + permute_835 = torch.ops.aten.permute.default(permute_187, [1, 0]); permute_187 = None + mm_436 = torch.ops.aten.mm.default(view_2681, permute_835); view_2681 = permute_835 = None + view_2682 = torch.ops.aten.view.default(mm_436, [2, 8192, 4096]); mm_436 = None + add_232 = torch.ops.aten.add.Tensor(add_231, view_2682); add_231 = view_2682 = None + 
convert_element_type_1877 = torch.ops.prims.convert_element_type.default(mm_435, torch.float32); mm_435 = None + reduce_scatter_tensor_230 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1877, 'avg', 32, '0'); convert_element_type_1877 = None + wait_tensor_646 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_230); reduce_scatter_tensor_230 = None + split_198 = torch.ops.aten.split.Tensor(add_232, 1024, 1); add_232 = None + getitem_1917 = split_198[0] + getitem_1918 = split_198[1] + getitem_1919 = split_198[2] + getitem_1920 = split_198[3] + getitem_1921 = split_198[4] + getitem_1922 = split_198[5] + getitem_1923 = split_198[6] + getitem_1924 = split_198[7]; split_198 = None + cat_190 = torch.ops.aten.cat.default([getitem_1917, getitem_1918, getitem_1919, getitem_1920, getitem_1921, getitem_1922, getitem_1923, getitem_1924]); getitem_1917 = getitem_1918 = getitem_1919 = getitem_1920 = getitem_1921 = getitem_1922 = getitem_1923 = getitem_1924 = None + reduce_scatter_tensor_231 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_190, 'sum', 8, '1'); cat_190 = None + wait_tensor_647 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_231); reduce_scatter_tensor_231 = None + convert_element_type_1878 = torch.ops.prims.convert_element_type.default(wait_tensor_647, torch.float32); wait_tensor_647 = None + convert_element_type_1880 = torch.ops.prims.convert_element_type.default(wait_tensor_223, torch.float32); wait_tensor_223 = None + mul_558 = torch.ops.aten.mul.Tensor(convert_element_type_1878, convert_element_type_1880); convert_element_type_1880 = None + mul_560 = torch.ops.aten.mul.Tensor(mul_136, mul_558) + sum_91 = torch.ops.aten.sum.dim_IntList(mul_560, [2], True); mul_560 = None + div_30 = torch.ops.aten.div.Tensor(mul_136, 4096) + mul_561 = torch.ops.aten.mul.Tensor(div_30, sum_91); div_30 = sum_91 = None + sub_46 = torch.ops.aten.sub.Tensor(mul_558, mul_561); mul_558 = mul_561 = 
None + mul_562 = torch.ops.aten.mul.Tensor(sub_46, rsqrt_34); sub_46 = rsqrt_34 = None + mul_563 = torch.ops.aten.mul.Tensor(convert_element_type_1878, mul_136); convert_element_type_1878 = mul_136 = None + sum_92 = torch.ops.aten.sum.dim_IntList(mul_563, [0, 1]); mul_563 = None + convert_element_type_1881 = torch.ops.prims.convert_element_type.default(mul_562, torch.bfloat16); mul_562 = None + convert_element_type_1882 = torch.ops.prims.convert_element_type.default(sum_92, torch.bfloat16); sum_92 = None + all_reduce_30 = torch.ops._c10d_functional.all_reduce.default(convert_element_type_1882, 'sum', '1'); convert_element_type_1882 = None + wait_tensor_648 = torch.ops._c10d_functional.wait_tensor.default(all_reduce_30); all_reduce_30 = None + convert_element_type_1883 = torch.ops.prims.convert_element_type.default(wait_tensor_648, torch.float32); wait_tensor_648 = None + reduce_scatter_tensor_232 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1883, 'avg', 32, '0'); convert_element_type_1883 = None + wait_tensor_649 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_232); reduce_scatter_tensor_232 = None + add_233 = torch.ops.aten.add.Tensor(add_230, convert_element_type_1881); add_230 = convert_element_type_1881 = None + all_gather_into_tensor_386 = torch.ops._c10d_functional.all_gather_into_tensor.default(add_233, 8, '1') + wait_tensor_650 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_386); all_gather_into_tensor_386 = None + split_199 = torch.ops.aten.split.Tensor(wait_tensor_650, 2); wait_tensor_650 = None + getitem_1925 = split_199[0] + getitem_1926 = split_199[1] + getitem_1927 = split_199[2] + getitem_1928 = split_199[3] + getitem_1929 = split_199[4] + getitem_1930 = split_199[5] + getitem_1931 = split_199[6] + getitem_1932 = split_199[7]; split_199 = None + cat_191 = torch.ops.aten.cat.default([getitem_1925, getitem_1926, getitem_1927, getitem_1928, getitem_1929, getitem_1930, 
getitem_1931, getitem_1932], 1); getitem_1925 = getitem_1926 = getitem_1927 = getitem_1928 = getitem_1929 = getitem_1930 = getitem_1931 = getitem_1932 = None + view_2683 = torch.ops.aten.view.default(cat_191, [16384, 4096]); cat_191 = None + permute_837 = torch.ops.aten.permute.default(view_2683, [1, 0]) + wait_tensor_216 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_33); reduce_scatter_tensor_33 = None + add_65 = torch.ops.aten.add.Tensor(add_63, wait_tensor_216); wait_tensor_216 = None + convert_element_type_548 = torch.ops.prims.convert_element_type.default(primals_153, torch.bfloat16); primals_153 = None + all_gather_into_tensor_183 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_548, 32, '0'); convert_element_type_548 = None + wait_tensor_217 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_183); all_gather_into_tensor_183 = None + convert_element_type_549 = torch.ops.prims.convert_element_type.default(add_65, torch.float32); add_65 = None + pow_34 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_549, 2) + mean_33 = torch.ops.aten.mean.dim(pow_34, [2], True); pow_34 = None + add_66 = torch.ops.aten.add.Scalar(mean_33, 1e-05); mean_33 = None + rsqrt_33 = torch.ops.aten.rsqrt.default(add_66); add_66 = None + mul_132 = torch.ops.aten.mul.Tensor(convert_element_type_549, rsqrt_33); convert_element_type_549 = None + mul_133 = torch.ops.aten.mul.Tensor(mul_132, wait_tensor_217) + convert_element_type_550 = torch.ops.prims.convert_element_type.default(mul_133, torch.bfloat16); mul_133 = None + all_gather_into_tensor_184 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_550, 8, '1'); convert_element_type_550 = None + wait_tensor_218 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_184); all_gather_into_tensor_184 = None + split_75 = torch.ops.aten.split.Tensor(wait_tensor_218, 2); wait_tensor_218 = None + getitem_753 = 
split_75[0] + getitem_754 = split_75[1] + getitem_755 = split_75[2] + getitem_756 = split_75[3] + getitem_757 = split_75[4] + getitem_758 = split_75[5] + getitem_759 = split_75[6] + getitem_760 = split_75[7]; split_75 = None + cat_67 = torch.ops.aten.cat.default([getitem_753, getitem_754, getitem_755, getitem_756, getitem_757, getitem_758, getitem_759, getitem_760], 1); getitem_753 = getitem_754 = getitem_755 = getitem_756 = getitem_757 = getitem_758 = getitem_759 = getitem_760 = None + view_1212 = torch.ops.aten.view.default(cat_67, [16384, 4096]); cat_67 = None + view_1213 = torch.ops.aten.view.default(mm_116, [2, 8192, 1792]); mm_116 = None + convert_element_type_554 = torch.ops.prims.convert_element_type.default(view_1213, torch.float32); view_1213 = None + sigmoid_16 = torch.ops.aten.sigmoid.default(convert_element_type_554) + mul_134 = torch.ops.aten.mul.Tensor(convert_element_type_554, sigmoid_16); sigmoid_16 = None + convert_element_type_555 = torch.ops.prims.convert_element_type.default(mul_134, torch.bfloat16); mul_134 = None + convert_element_type_556 = torch.ops.prims.convert_element_type.default(primals_155, torch.bfloat16); primals_155 = None + all_gather_into_tensor_186 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_556, 32, '0'); convert_element_type_556 = None + wait_tensor_220 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_186); all_gather_into_tensor_186 = None + permute_185 = torch.ops.aten.permute.default(wait_tensor_220, [1, 0]); wait_tensor_220 = None + mm_117 = torch.ops.aten.mm.default(view_1212, permute_185) + view_1220 = torch.ops.aten.view.default(mm_117, [2, 8192, 1792]); mm_117 = None + mul_135 = torch.ops.aten.mul.Tensor(convert_element_type_555, view_1220) + view_1227 = torch.ops.aten.view.default(mul_135, [16384, 1792]); mul_135 = None + mm_437 = torch.ops.aten.mm.default(permute_837, view_1227); permute_837 = view_1227 = None + convert_element_type_559 = 
torch.ops.prims.convert_element_type.default(primals_156, torch.bfloat16); primals_156 = None + all_gather_into_tensor_187 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_559, 32, '0'); convert_element_type_559 = None + wait_tensor_221 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_187); all_gather_into_tensor_187 = None + permute_186 = torch.ops.aten.permute.default(wait_tensor_221, [1, 0]); wait_tensor_221 = None + permute_839 = torch.ops.aten.permute.default(permute_186, [1, 0]); permute_186 = None + mm_438 = torch.ops.aten.mm.default(view_2683, permute_839); view_2683 = permute_839 = None + view_2684 = torch.ops.aten.view.default(mm_438, [2, 8192, 1792]); mm_438 = None + convert_element_type_1888 = torch.ops.prims.convert_element_type.default(mm_437, torch.float32); mm_437 = None + reduce_scatter_tensor_233 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1888, 'avg', 32, '0'); convert_element_type_1888 = None + wait_tensor_651 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_233); reduce_scatter_tensor_233 = None + mul_564 = torch.ops.aten.mul.Tensor(view_2684, convert_element_type_555); convert_element_type_555 = None + mul_565 = torch.ops.aten.mul.Tensor(view_2684, view_1220); view_2684 = view_1220 = None + view_2685 = torch.ops.aten.view.default(mul_564, [16384, 1792]); mul_564 = None + permute_841 = torch.ops.aten.permute.default(view_2685, [1, 0]) + mm_439 = torch.ops.aten.mm.default(permute_841, view_1212); permute_841 = None + permute_843 = torch.ops.aten.permute.default(permute_185, [1, 0]); permute_185 = None + mm_440 = torch.ops.aten.mm.default(view_2685, permute_843); view_2685 = permute_843 = None + view_2686 = torch.ops.aten.view.default(mm_440, [2, 8192, 4096]); mm_440 = None + convert_element_type_1893 = torch.ops.prims.convert_element_type.default(mm_439, torch.float32); mm_439 = None + reduce_scatter_tensor_234 = 
torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1893, 'avg', 32, '0'); convert_element_type_1893 = None + wait_tensor_652 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_234); reduce_scatter_tensor_234 = None + convert_element_type_1894 = torch.ops.prims.convert_element_type.default(mul_565, torch.float32); mul_565 = None + neg_15 = torch.ops.aten.neg.default(convert_element_type_554) + exp_15 = torch.ops.aten.exp.default(neg_15); neg_15 = None + add_234 = torch.ops.aten.add.Tensor(exp_15, 1); exp_15 = None + reciprocal_15 = torch.ops.aten.reciprocal.default(add_234); add_234 = None + mul_566 = torch.ops.aten.mul.Tensor(reciprocal_15, 1); reciprocal_15 = None + mul_567 = torch.ops.aten.mul.Tensor(convert_element_type_1894, mul_566); convert_element_type_1894 = None + sub_47 = torch.ops.aten.sub.Tensor(1, mul_566); mul_566 = None + mul_568 = torch.ops.aten.mul.Tensor(convert_element_type_554, sub_47); convert_element_type_554 = sub_47 = None + add_235 = torch.ops.aten.add.Tensor(mul_568, 1); mul_568 = None + mul_569 = torch.ops.aten.mul.Tensor(mul_567, add_235); mul_567 = add_235 = None + convert_element_type_1896 = torch.ops.prims.convert_element_type.default(mul_569, torch.bfloat16); mul_569 = None + view_2687 = torch.ops.aten.view.default(convert_element_type_1896, [16384, 1792]); convert_element_type_1896 = None + permute_845 = torch.ops.aten.permute.default(view_2687, [1, 0]) + mm_441 = torch.ops.aten.mm.default(permute_845, view_1212); permute_845 = view_1212 = None + convert_element_type_551 = torch.ops.prims.convert_element_type.default(primals_154, torch.bfloat16); primals_154 = None + all_gather_into_tensor_185 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_551, 32, '0'); convert_element_type_551 = None + wait_tensor_219 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_185); all_gather_into_tensor_185 = None + permute_184 = 
torch.ops.aten.permute.default(wait_tensor_219, [1, 0]); wait_tensor_219 = None + permute_847 = torch.ops.aten.permute.default(permute_184, [1, 0]); permute_184 = None + mm_442 = torch.ops.aten.mm.default(view_2687, permute_847); view_2687 = permute_847 = None + view_2688 = torch.ops.aten.view.default(mm_442, [2, 8192, 4096]); mm_442 = None + add_236 = torch.ops.aten.add.Tensor(view_2686, view_2688); view_2686 = view_2688 = None + convert_element_type_1901 = torch.ops.prims.convert_element_type.default(mm_441, torch.float32); mm_441 = None + reduce_scatter_tensor_235 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1901, 'avg', 32, '0'); convert_element_type_1901 = None + wait_tensor_653 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_235); reduce_scatter_tensor_235 = None + split_200 = torch.ops.aten.split.Tensor(add_236, 1024, 1); add_236 = None + getitem_1933 = split_200[0] + getitem_1934 = split_200[1] + getitem_1935 = split_200[2] + getitem_1936 = split_200[3] + getitem_1937 = split_200[4] + getitem_1938 = split_200[5] + getitem_1939 = split_200[6] + getitem_1940 = split_200[7]; split_200 = None + cat_192 = torch.ops.aten.cat.default([getitem_1933, getitem_1934, getitem_1935, getitem_1936, getitem_1937, getitem_1938, getitem_1939, getitem_1940]); getitem_1933 = getitem_1934 = getitem_1935 = getitem_1936 = getitem_1937 = getitem_1938 = getitem_1939 = getitem_1940 = None + reduce_scatter_tensor_236 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_192, 'sum', 8, '1'); cat_192 = None + wait_tensor_654 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_236); reduce_scatter_tensor_236 = None + convert_element_type_1902 = torch.ops.prims.convert_element_type.default(wait_tensor_654, torch.float32); wait_tensor_654 = None + convert_element_type_1904 = torch.ops.prims.convert_element_type.default(wait_tensor_217, torch.float32); wait_tensor_217 = None + mul_570 = 
torch.ops.aten.mul.Tensor(convert_element_type_1902, convert_element_type_1904); convert_element_type_1904 = None + mul_572 = torch.ops.aten.mul.Tensor(mul_132, mul_570) + sum_93 = torch.ops.aten.sum.dim_IntList(mul_572, [2], True); mul_572 = None + div_31 = torch.ops.aten.div.Tensor(mul_132, 4096) + mul_573 = torch.ops.aten.mul.Tensor(div_31, sum_93); div_31 = sum_93 = None + sub_48 = torch.ops.aten.sub.Tensor(mul_570, mul_573); mul_570 = mul_573 = None + mul_574 = torch.ops.aten.mul.Tensor(sub_48, rsqrt_33); sub_48 = rsqrt_33 = None + mul_575 = torch.ops.aten.mul.Tensor(convert_element_type_1902, mul_132); convert_element_type_1902 = mul_132 = None + sum_94 = torch.ops.aten.sum.dim_IntList(mul_575, [0, 1]); mul_575 = None + convert_element_type_1905 = torch.ops.prims.convert_element_type.default(mul_574, torch.bfloat16); mul_574 = None + convert_element_type_1906 = torch.ops.prims.convert_element_type.default(sum_94, torch.bfloat16); sum_94 = None + all_reduce_31 = torch.ops._c10d_functional.all_reduce.default(convert_element_type_1906, 'sum', '1'); convert_element_type_1906 = None + wait_tensor_655 = torch.ops._c10d_functional.wait_tensor.default(all_reduce_31); all_reduce_31 = None + convert_element_type_1907 = torch.ops.prims.convert_element_type.default(wait_tensor_655, torch.float32); wait_tensor_655 = None + reduce_scatter_tensor_237 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1907, 'avg', 32, '0'); convert_element_type_1907 = None + wait_tensor_656 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_237); reduce_scatter_tensor_237 = None + add_237 = torch.ops.aten.add.Tensor(add_233, convert_element_type_1905); add_233 = convert_element_type_1905 = None + all_gather_into_tensor_387 = torch.ops._c10d_functional.all_gather_into_tensor.default(add_237, 8, '1') + wait_tensor_657 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_387); all_gather_into_tensor_387 = None + split_201 = 
torch.ops.aten.split.Tensor(wait_tensor_657, 2); wait_tensor_657 = None + getitem_1941 = split_201[0] + getitem_1942 = split_201[1] + getitem_1943 = split_201[2] + getitem_1944 = split_201[3] + getitem_1945 = split_201[4] + getitem_1946 = split_201[5] + getitem_1947 = split_201[6] + getitem_1948 = split_201[7]; split_201 = None + cat_193 = torch.ops.aten.cat.default([getitem_1941, getitem_1942, getitem_1943, getitem_1944, getitem_1945, getitem_1946, getitem_1947, getitem_1948], 1); getitem_1941 = getitem_1942 = getitem_1943 = getitem_1944 = getitem_1945 = getitem_1946 = getitem_1947 = getitem_1948 = None + view_2689 = torch.ops.aten.view.default(cat_193, [16384, 4096]); cat_193 = None + permute_849 = torch.ops.aten.permute.default(view_2689, [1, 0]) + permute_182 = torch.ops.aten.permute.default(getitem_736, [0, 2, 1, 3]) + view_1194 = torch.ops.aten.view.default(permute_182, [2, 8192, -1]); permute_182 = None + view_1200 = torch.ops.aten.view.default(view_1194, [16384, 512]); view_1194 = None + mm_443 = torch.ops.aten.mm.default(permute_849, view_1200); permute_849 = view_1200 = None + convert_element_type_545 = torch.ops.prims.convert_element_type.default(primals_152, torch.bfloat16); primals_152 = None + all_gather_into_tensor_182 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_545, 32, '0'); convert_element_type_545 = None + wait_tensor_215 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_182); all_gather_into_tensor_182 = None + permute_183 = torch.ops.aten.permute.default(wait_tensor_215, [1, 0]); wait_tensor_215 = None + permute_851 = torch.ops.aten.permute.default(permute_183, [1, 0]); permute_183 = None + mm_444 = torch.ops.aten.mm.default(view_2689, permute_851); view_2689 = permute_851 = None + view_2690 = torch.ops.aten.view.default(mm_444, [2, 8192, 512]); mm_444 = None + convert_element_type_1912 = torch.ops.prims.convert_element_type.default(mm_443, torch.float32); mm_443 = None + 
reduce_scatter_tensor_238 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1912, 'avg', 32, '0'); convert_element_type_1912 = None + wait_tensor_658 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_238); reduce_scatter_tensor_238 = None + view_2691 = torch.ops.aten.view.default(view_2690, [2, 8192, 4, 128]); view_2690 = None + permute_853 = torch.ops.aten.permute.default(view_2691, [0, 2, 1, 3]); view_2691 = None + convert_element_type_529 = torch.ops.prims.convert_element_type.default(primals_148, torch.bfloat16); primals_148 = None + all_gather_into_tensor_177 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_529, 32, '0'); convert_element_type_529 = None + wait_tensor_210 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_177); all_gather_into_tensor_177 = None + convert_element_type_530 = torch.ops.prims.convert_element_type.default(add_63, torch.float32); add_63 = None + pow_33 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_530, 2) + mean_32 = torch.ops.aten.mean.dim(pow_33, [2], True); pow_33 = None + add_64 = torch.ops.aten.add.Scalar(mean_32, 1e-05); mean_32 = None + rsqrt_32 = torch.ops.aten.rsqrt.default(add_64); add_64 = None + mul_128 = torch.ops.aten.mul.Tensor(convert_element_type_530, rsqrt_32); convert_element_type_530 = None + mul_129 = torch.ops.aten.mul.Tensor(mul_128, wait_tensor_210) + convert_element_type_531 = torch.ops.prims.convert_element_type.default(mul_129, torch.bfloat16); mul_129 = None + all_gather_into_tensor_178 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_531, 8, '1'); convert_element_type_531 = None + wait_tensor_211 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_178); all_gather_into_tensor_178 = None + split_73 = torch.ops.aten.split.Tensor(wait_tensor_211, 2); wait_tensor_211 = None + getitem_728 = split_73[0] + getitem_729 = split_73[1] + getitem_730 = 
split_73[2] + getitem_731 = split_73[3] + getitem_732 = split_73[4] + getitem_733 = split_73[5] + getitem_734 = split_73[6] + getitem_735 = split_73[7]; split_73 = None + cat_65 = torch.ops.aten.cat.default([getitem_728, getitem_729, getitem_730, getitem_731, getitem_732, getitem_733, getitem_734, getitem_735], 1); getitem_728 = getitem_729 = getitem_730 = getitem_731 = getitem_732 = getitem_733 = getitem_734 = getitem_735 = None + view_1167 = torch.ops.aten.view.default(cat_65, [16384, 4096]); cat_65 = None + view_1168 = torch.ops.aten.view.default(mm_112, [2, 8192, 512]); mm_112 = None + convert_element_type_535 = torch.ops.prims.convert_element_type.default(primals_150, torch.bfloat16); primals_150 = None + all_gather_into_tensor_180 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_535, 32, '0'); convert_element_type_535 = None + wait_tensor_213 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_180); all_gather_into_tensor_180 = None + permute_177 = torch.ops.aten.permute.default(wait_tensor_213, [1, 0]); wait_tensor_213 = None + mm_113 = torch.ops.aten.mm.default(view_1167, permute_177) + view_1175 = torch.ops.aten.view.default(mm_113, [2, 8192, 128]); mm_113 = None + view_1182 = torch.ops.aten.view.default(mm_114, [2, 8192, 128]); mm_114 = None + view_1184 = torch.ops.aten.view.default(view_1168, [2, 8192, -1, 128]); view_1168 = None + view_1185 = torch.ops.aten.view.default(view_1175, [2, 8192, -1, 128]); view_1175 = None + view_1186 = torch.ops.aten.view.default(view_1182, [2, 8192, -1, 128]); view_1182 = None + convert_element_type_541 = torch.ops.prims.convert_element_type.default(view_1184, torch.float32); view_1184 = None + view_1187 = torch.ops.aten.view.default(convert_element_type_541, [2, 8192, 4, -1, 2]); convert_element_type_541 = None + view_as_complex_32 = torch.ops.aten.view_as_complex.default(view_1187); view_1187 = None + convert_element_type_542 = 
torch.ops.prims.convert_element_type.default(view_1185, torch.float32); view_1185 = None + view_1188 = torch.ops.aten.view.default(convert_element_type_542, [2, 8192, 1, -1, 2]); convert_element_type_542 = None + view_as_complex_33 = torch.ops.aten.view_as_complex.default(view_1188); view_1188 = None + mul_130 = torch.ops.aten.mul.Tensor(view_as_complex_32, view_37); view_as_complex_32 = None + view_as_real_32 = torch.ops.aten.view_as_real.default(mul_130); mul_130 = None + view_1190 = torch.ops.aten.view.default(view_as_real_32, [2, 8192, 4, 128]); view_as_real_32 = None + mul_131 = torch.ops.aten.mul.Tensor(view_as_complex_33, view_37); view_as_complex_33 = None + view_as_real_33 = torch.ops.aten.view_as_real.default(mul_131); mul_131 = None + view_1191 = torch.ops.aten.view.default(view_as_real_33, [2, 8192, 1, 128]); view_as_real_33 = None + convert_element_type_543 = torch.ops.prims.convert_element_type.default(view_1190, torch.bfloat16); view_1190 = None + convert_element_type_544 = torch.ops.prims.convert_element_type.default(view_1191, torch.bfloat16); view_1191 = None + unsqueeze_32 = torch.ops.aten.unsqueeze.default(convert_element_type_544, 3); convert_element_type_544 = None + expand_32 = torch.ops.aten.expand.default(unsqueeze_32, [2, 8192, 1, 4, 128]); unsqueeze_32 = None + view_1192 = torch.ops.aten.view.default(expand_32, [2, 8192, 4, 128]); expand_32 = None + unsqueeze_33 = torch.ops.aten.unsqueeze.default(view_1186, 3); view_1186 = None + expand_33 = torch.ops.aten.expand.default(unsqueeze_33, [2, 8192, 1, 4, 128]); unsqueeze_33 = None + view_1193 = torch.ops.aten.view.default(expand_33, [2, 8192, 4, 128]); expand_33 = None + permute_179 = torch.ops.aten.permute.default(convert_element_type_543, [0, 2, 1, 3]); convert_element_type_543 = None + permute_180 = torch.ops.aten.permute.default(view_1192, [0, 2, 1, 3]); view_1192 = None + permute_181 = torch.ops.aten.permute.default(view_1193, [0, 2, 1, 3]); view_1193 = None + 
_scaled_dot_product_cudnn_attention_backward_15 = torch.ops.aten._scaled_dot_product_cudnn_attention_backward.default(permute_853, permute_179, permute_180, permute_181, getitem_736, getitem_737, getitem_742, getitem_743, None, None, None, 8192, 8192, 0.0, True); permute_853 = permute_179 = permute_180 = permute_181 = getitem_736 = getitem_737 = getitem_742 = getitem_743 = None + getitem_1949 = _scaled_dot_product_cudnn_attention_backward_15[0] + getitem_1950 = _scaled_dot_product_cudnn_attention_backward_15[1] + getitem_1951 = _scaled_dot_product_cudnn_attention_backward_15[2]; _scaled_dot_product_cudnn_attention_backward_15 = None + permute_854 = torch.ops.aten.permute.default(getitem_1951, [0, 2, 1, 3]); getitem_1951 = None + permute_855 = torch.ops.aten.permute.default(getitem_1950, [0, 2, 1, 3]); getitem_1950 = None + permute_856 = torch.ops.aten.permute.default(getitem_1949, [0, 2, 1, 3]); getitem_1949 = None + view_2692 = torch.ops.aten.view.default(permute_854, [2, 8192, 1, 4, 128]); permute_854 = None + sum_95 = torch.ops.aten.sum.dim_IntList(view_2692, [3], True); view_2692 = None + squeeze_30 = torch.ops.aten.squeeze.dim(sum_95, 3); sum_95 = None + view_2693 = torch.ops.aten.view.default(permute_855, [2, 8192, 1, 4, 128]); permute_855 = None + sum_96 = torch.ops.aten.sum.dim_IntList(view_2693, [3], True); view_2693 = None + squeeze_31 = torch.ops.aten.squeeze.dim(sum_96, 3); sum_96 = None + convert_element_type_1913 = torch.ops.prims.convert_element_type.default(squeeze_31, torch.float32); squeeze_31 = None + convert_element_type_1914 = torch.ops.prims.convert_element_type.default(permute_856, torch.float32); permute_856 = None + view_2694 = torch.ops.aten.view.default(convert_element_type_1913, [2, 8192, 1, 64, 2]); convert_element_type_1913 = None + view_as_complex_94 = torch.ops.aten.view_as_complex.default(view_2694); view_2694 = None + mul_576 = torch.ops.aten.mul.Tensor(view_as_complex_94, _conj); view_as_complex_94 = None + view_2695 = 
torch.ops.aten.view.default(convert_element_type_1914, [2, 8192, 4, 64, 2]); convert_element_type_1914 = None + view_as_complex_95 = torch.ops.aten.view_as_complex.default(view_2695); view_2695 = None + mul_577 = torch.ops.aten.mul.Tensor(view_as_complex_95, _conj); view_as_complex_95 = None + view_as_real_94 = torch.ops.aten.view_as_real.default(mul_576); mul_576 = None + view_2696 = torch.ops.aten.view.default(view_as_real_94, [2, 8192, 1, 128]); view_as_real_94 = None + convert_element_type_1915 = torch.ops.prims.convert_element_type.default(view_2696, torch.bfloat16); view_2696 = None + view_as_real_95 = torch.ops.aten.view_as_real.default(mul_577); mul_577 = None + view_2697 = torch.ops.aten.view.default(view_as_real_95, [2, 8192, 4, 128]); view_as_real_95 = None + convert_element_type_1916 = torch.ops.prims.convert_element_type.default(view_2697, torch.bfloat16); view_2697 = None + view_2698 = torch.ops.aten.view.default(squeeze_30, [2, 8192, 128]); squeeze_30 = None + view_2699 = torch.ops.aten.view.default(convert_element_type_1915, [2, 8192, 128]); convert_element_type_1915 = None + view_2700 = torch.ops.aten.view.default(convert_element_type_1916, [2, 8192, 512]); convert_element_type_1916 = None + view_2701 = torch.ops.aten.view.default(view_2698, [16384, 128]); view_2698 = None + permute_857 = torch.ops.aten.permute.default(view_2701, [1, 0]) + mm_445 = torch.ops.aten.mm.default(permute_857, view_1167); permute_857 = None + convert_element_type_538 = torch.ops.prims.convert_element_type.default(primals_151, torch.bfloat16); primals_151 = None + all_gather_into_tensor_181 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_538, 32, '0'); convert_element_type_538 = None + wait_tensor_214 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_181); all_gather_into_tensor_181 = None + permute_178 = torch.ops.aten.permute.default(wait_tensor_214, [1, 0]); wait_tensor_214 = None + permute_859 = 
torch.ops.aten.permute.default(permute_178, [1, 0]); permute_178 = None + mm_446 = torch.ops.aten.mm.default(view_2701, permute_859); view_2701 = permute_859 = None + view_2702 = torch.ops.aten.view.default(mm_446, [2, 8192, 4096]); mm_446 = None + convert_element_type_1921 = torch.ops.prims.convert_element_type.default(mm_445, torch.float32); mm_445 = None + reduce_scatter_tensor_239 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1921, 'avg', 32, '0'); convert_element_type_1921 = None + wait_tensor_659 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_239); reduce_scatter_tensor_239 = None + view_2703 = torch.ops.aten.view.default(view_2699, [16384, 128]); view_2699 = None + permute_861 = torch.ops.aten.permute.default(view_2703, [1, 0]) + mm_447 = torch.ops.aten.mm.default(permute_861, view_1167); permute_861 = None + permute_863 = torch.ops.aten.permute.default(permute_177, [1, 0]); permute_177 = None + mm_448 = torch.ops.aten.mm.default(view_2703, permute_863); view_2703 = permute_863 = None + view_2704 = torch.ops.aten.view.default(mm_448, [2, 8192, 4096]); mm_448 = None + add_238 = torch.ops.aten.add.Tensor(view_2702, view_2704); view_2702 = view_2704 = None + convert_element_type_1926 = torch.ops.prims.convert_element_type.default(mm_447, torch.float32); mm_447 = None + reduce_scatter_tensor_240 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1926, 'avg', 32, '0'); convert_element_type_1926 = None + wait_tensor_660 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_240); reduce_scatter_tensor_240 = None + view_2705 = torch.ops.aten.view.default(view_2700, [16384, 512]); view_2700 = None + permute_865 = torch.ops.aten.permute.default(view_2705, [1, 0]) + mm_449 = torch.ops.aten.mm.default(permute_865, view_1167); permute_865 = view_1167 = None + convert_element_type_532 = torch.ops.prims.convert_element_type.default(primals_149, torch.bfloat16); 
primals_149 = None + all_gather_into_tensor_179 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_532, 32, '0'); convert_element_type_532 = None + wait_tensor_212 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_179); all_gather_into_tensor_179 = None + permute_176 = torch.ops.aten.permute.default(wait_tensor_212, [1, 0]); wait_tensor_212 = None + permute_867 = torch.ops.aten.permute.default(permute_176, [1, 0]); permute_176 = None + mm_450 = torch.ops.aten.mm.default(view_2705, permute_867); view_2705 = permute_867 = None + view_2706 = torch.ops.aten.view.default(mm_450, [2, 8192, 4096]); mm_450 = None + add_239 = torch.ops.aten.add.Tensor(add_238, view_2706); add_238 = view_2706 = None + convert_element_type_1931 = torch.ops.prims.convert_element_type.default(mm_449, torch.float32); mm_449 = None + reduce_scatter_tensor_241 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1931, 'avg', 32, '0'); convert_element_type_1931 = None + wait_tensor_661 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_241); reduce_scatter_tensor_241 = None + split_202 = torch.ops.aten.split.Tensor(add_239, 1024, 1); add_239 = None + getitem_1952 = split_202[0] + getitem_1953 = split_202[1] + getitem_1954 = split_202[2] + getitem_1955 = split_202[3] + getitem_1956 = split_202[4] + getitem_1957 = split_202[5] + getitem_1958 = split_202[6] + getitem_1959 = split_202[7]; split_202 = None + cat_194 = torch.ops.aten.cat.default([getitem_1952, getitem_1953, getitem_1954, getitem_1955, getitem_1956, getitem_1957, getitem_1958, getitem_1959]); getitem_1952 = getitem_1953 = getitem_1954 = getitem_1955 = getitem_1956 = getitem_1957 = getitem_1958 = getitem_1959 = None + reduce_scatter_tensor_242 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_194, 'sum', 8, '1'); cat_194 = None + wait_tensor_662 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_242); 
reduce_scatter_tensor_242 = None + convert_element_type_1932 = torch.ops.prims.convert_element_type.default(wait_tensor_662, torch.float32); wait_tensor_662 = None + convert_element_type_1934 = torch.ops.prims.convert_element_type.default(wait_tensor_210, torch.float32); wait_tensor_210 = None + mul_578 = torch.ops.aten.mul.Tensor(convert_element_type_1932, convert_element_type_1934); convert_element_type_1934 = None + mul_580 = torch.ops.aten.mul.Tensor(mul_128, mul_578) + sum_97 = torch.ops.aten.sum.dim_IntList(mul_580, [2], True); mul_580 = None + div_32 = torch.ops.aten.div.Tensor(mul_128, 4096) + mul_581 = torch.ops.aten.mul.Tensor(div_32, sum_97); div_32 = sum_97 = None + sub_49 = torch.ops.aten.sub.Tensor(mul_578, mul_581); mul_578 = mul_581 = None + mul_582 = torch.ops.aten.mul.Tensor(sub_49, rsqrt_32); sub_49 = rsqrt_32 = None + mul_583 = torch.ops.aten.mul.Tensor(convert_element_type_1932, mul_128); convert_element_type_1932 = mul_128 = None + sum_98 = torch.ops.aten.sum.dim_IntList(mul_583, [0, 1]); mul_583 = None + convert_element_type_1935 = torch.ops.prims.convert_element_type.default(mul_582, torch.bfloat16); mul_582 = None + convert_element_type_1936 = torch.ops.prims.convert_element_type.default(sum_98, torch.bfloat16); sum_98 = None + all_reduce_32 = torch.ops._c10d_functional.all_reduce.default(convert_element_type_1936, 'sum', '1'); convert_element_type_1936 = None + wait_tensor_663 = torch.ops._c10d_functional.wait_tensor.default(all_reduce_32); all_reduce_32 = None + convert_element_type_1937 = torch.ops.prims.convert_element_type.default(wait_tensor_663, torch.float32); wait_tensor_663 = None + reduce_scatter_tensor_243 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1937, 'avg', 32, '0'); convert_element_type_1937 = None + wait_tensor_664 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_243); reduce_scatter_tensor_243 = None + add_240 = torch.ops.aten.add.Tensor(add_237, 
convert_element_type_1935); add_237 = convert_element_type_1935 = None + all_gather_into_tensor_388 = torch.ops._c10d_functional.all_gather_into_tensor.default(add_240, 8, '1') + wait_tensor_665 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_388); all_gather_into_tensor_388 = None + split_203 = torch.ops.aten.split.Tensor(wait_tensor_665, 2); wait_tensor_665 = None + getitem_1960 = split_203[0] + getitem_1961 = split_203[1] + getitem_1962 = split_203[2] + getitem_1963 = split_203[3] + getitem_1964 = split_203[4] + getitem_1965 = split_203[5] + getitem_1966 = split_203[6] + getitem_1967 = split_203[7]; split_203 = None + cat_195 = torch.ops.aten.cat.default([getitem_1960, getitem_1961, getitem_1962, getitem_1963, getitem_1964, getitem_1965, getitem_1966, getitem_1967], 1); getitem_1960 = getitem_1961 = getitem_1962 = getitem_1963 = getitem_1964 = getitem_1965 = getitem_1966 = getitem_1967 = None + view_2707 = torch.ops.aten.view.default(cat_195, [16384, 4096]); cat_195 = None + permute_869 = torch.ops.aten.permute.default(view_2707, [1, 0]) + wait_tensor_203 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_31); reduce_scatter_tensor_31 = None + add_61 = torch.ops.aten.add.Tensor(add_59, wait_tensor_203); wait_tensor_203 = None + convert_element_type_515 = torch.ops.prims.convert_element_type.default(primals_144, torch.bfloat16); primals_144 = None + all_gather_into_tensor_172 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_515, 32, '0'); convert_element_type_515 = None + wait_tensor_204 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_172); all_gather_into_tensor_172 = None + convert_element_type_516 = torch.ops.prims.convert_element_type.default(add_61, torch.float32); add_61 = None + pow_32 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_516, 2) + mean_31 = torch.ops.aten.mean.dim(pow_32, [2], True); pow_32 = None + add_62 = 
torch.ops.aten.add.Scalar(mean_31, 1e-05); mean_31 = None + rsqrt_31 = torch.ops.aten.rsqrt.default(add_62); add_62 = None + mul_124 = torch.ops.aten.mul.Tensor(convert_element_type_516, rsqrt_31); convert_element_type_516 = None + mul_125 = torch.ops.aten.mul.Tensor(mul_124, wait_tensor_204) + convert_element_type_517 = torch.ops.prims.convert_element_type.default(mul_125, torch.bfloat16); mul_125 = None + all_gather_into_tensor_173 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_517, 8, '1'); convert_element_type_517 = None + wait_tensor_205 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_173); all_gather_into_tensor_173 = None + split_71 = torch.ops.aten.split.Tensor(wait_tensor_205, 2); wait_tensor_205 = None + getitem_712 = split_71[0] + getitem_713 = split_71[1] + getitem_714 = split_71[2] + getitem_715 = split_71[3] + getitem_716 = split_71[4] + getitem_717 = split_71[5] + getitem_718 = split_71[6] + getitem_719 = split_71[7]; split_71 = None + cat_63 = torch.ops.aten.cat.default([getitem_712, getitem_713, getitem_714, getitem_715, getitem_716, getitem_717, getitem_718, getitem_719], 1); getitem_712 = getitem_713 = getitem_714 = getitem_715 = getitem_716 = getitem_717 = getitem_718 = getitem_719 = None + view_1140 = torch.ops.aten.view.default(cat_63, [16384, 4096]); cat_63 = None + view_1141 = torch.ops.aten.view.default(mm_109, [2, 8192, 1792]); mm_109 = None + convert_element_type_521 = torch.ops.prims.convert_element_type.default(view_1141, torch.float32); view_1141 = None + sigmoid_15 = torch.ops.aten.sigmoid.default(convert_element_type_521) + mul_126 = torch.ops.aten.mul.Tensor(convert_element_type_521, sigmoid_15); sigmoid_15 = None + convert_element_type_522 = torch.ops.prims.convert_element_type.default(mul_126, torch.bfloat16); mul_126 = None + convert_element_type_523 = torch.ops.prims.convert_element_type.default(primals_146, torch.bfloat16); primals_146 = None + all_gather_into_tensor_175 
= torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_523, 32, '0'); convert_element_type_523 = None + wait_tensor_207 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_175); all_gather_into_tensor_175 = None + permute_174 = torch.ops.aten.permute.default(wait_tensor_207, [1, 0]); wait_tensor_207 = None + mm_110 = torch.ops.aten.mm.default(view_1140, permute_174) + view_1148 = torch.ops.aten.view.default(mm_110, [2, 8192, 1792]); mm_110 = None + mul_127 = torch.ops.aten.mul.Tensor(convert_element_type_522, view_1148) + view_1155 = torch.ops.aten.view.default(mul_127, [16384, 1792]); mul_127 = None + mm_451 = torch.ops.aten.mm.default(permute_869, view_1155); permute_869 = view_1155 = None + convert_element_type_526 = torch.ops.prims.convert_element_type.default(primals_147, torch.bfloat16); primals_147 = None + all_gather_into_tensor_176 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_526, 32, '0'); convert_element_type_526 = None + wait_tensor_208 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_176); all_gather_into_tensor_176 = None + permute_175 = torch.ops.aten.permute.default(wait_tensor_208, [1, 0]); wait_tensor_208 = None + permute_871 = torch.ops.aten.permute.default(permute_175, [1, 0]); permute_175 = None + mm_452 = torch.ops.aten.mm.default(view_2707, permute_871); view_2707 = permute_871 = None + view_2708 = torch.ops.aten.view.default(mm_452, [2, 8192, 1792]); mm_452 = None + convert_element_type_1942 = torch.ops.prims.convert_element_type.default(mm_451, torch.float32); mm_451 = None + reduce_scatter_tensor_244 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1942, 'avg', 32, '0'); convert_element_type_1942 = None + wait_tensor_666 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_244); reduce_scatter_tensor_244 = None + mul_584 = torch.ops.aten.mul.Tensor(view_2708, convert_element_type_522); 
convert_element_type_522 = None + mul_585 = torch.ops.aten.mul.Tensor(view_2708, view_1148); view_2708 = view_1148 = None + view_2709 = torch.ops.aten.view.default(mul_584, [16384, 1792]); mul_584 = None + permute_873 = torch.ops.aten.permute.default(view_2709, [1, 0]) + mm_453 = torch.ops.aten.mm.default(permute_873, view_1140); permute_873 = None + permute_875 = torch.ops.aten.permute.default(permute_174, [1, 0]); permute_174 = None + mm_454 = torch.ops.aten.mm.default(view_2709, permute_875); view_2709 = permute_875 = None + view_2710 = torch.ops.aten.view.default(mm_454, [2, 8192, 4096]); mm_454 = None + convert_element_type_1947 = torch.ops.prims.convert_element_type.default(mm_453, torch.float32); mm_453 = None + reduce_scatter_tensor_245 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1947, 'avg', 32, '0'); convert_element_type_1947 = None + wait_tensor_667 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_245); reduce_scatter_tensor_245 = None + convert_element_type_1948 = torch.ops.prims.convert_element_type.default(mul_585, torch.float32); mul_585 = None + neg_16 = torch.ops.aten.neg.default(convert_element_type_521) + exp_16 = torch.ops.aten.exp.default(neg_16); neg_16 = None + add_241 = torch.ops.aten.add.Tensor(exp_16, 1); exp_16 = None + reciprocal_16 = torch.ops.aten.reciprocal.default(add_241); add_241 = None + mul_586 = torch.ops.aten.mul.Tensor(reciprocal_16, 1); reciprocal_16 = None + mul_587 = torch.ops.aten.mul.Tensor(convert_element_type_1948, mul_586); convert_element_type_1948 = None + sub_50 = torch.ops.aten.sub.Tensor(1, mul_586); mul_586 = None + mul_588 = torch.ops.aten.mul.Tensor(convert_element_type_521, sub_50); convert_element_type_521 = sub_50 = None + add_242 = torch.ops.aten.add.Tensor(mul_588, 1); mul_588 = None + mul_589 = torch.ops.aten.mul.Tensor(mul_587, add_242); mul_587 = add_242 = None + convert_element_type_1950 = torch.ops.prims.convert_element_type.default(mul_589, 
torch.bfloat16); mul_589 = None + view_2711 = torch.ops.aten.view.default(convert_element_type_1950, [16384, 1792]); convert_element_type_1950 = None + permute_877 = torch.ops.aten.permute.default(view_2711, [1, 0]) + mm_455 = torch.ops.aten.mm.default(permute_877, view_1140); permute_877 = view_1140 = None + convert_element_type_518 = torch.ops.prims.convert_element_type.default(primals_145, torch.bfloat16); primals_145 = None + all_gather_into_tensor_174 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_518, 32, '0'); convert_element_type_518 = None + wait_tensor_206 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_174); all_gather_into_tensor_174 = None + permute_173 = torch.ops.aten.permute.default(wait_tensor_206, [1, 0]); wait_tensor_206 = None + permute_879 = torch.ops.aten.permute.default(permute_173, [1, 0]); permute_173 = None + mm_456 = torch.ops.aten.mm.default(view_2711, permute_879); view_2711 = permute_879 = None + view_2712 = torch.ops.aten.view.default(mm_456, [2, 8192, 4096]); mm_456 = None + add_243 = torch.ops.aten.add.Tensor(view_2710, view_2712); view_2710 = view_2712 = None + convert_element_type_1955 = torch.ops.prims.convert_element_type.default(mm_455, torch.float32); mm_455 = None + reduce_scatter_tensor_246 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1955, 'avg', 32, '0'); convert_element_type_1955 = None + wait_tensor_668 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_246); reduce_scatter_tensor_246 = None + split_204 = torch.ops.aten.split.Tensor(add_243, 1024, 1); add_243 = None + getitem_1968 = split_204[0] + getitem_1969 = split_204[1] + getitem_1970 = split_204[2] + getitem_1971 = split_204[3] + getitem_1972 = split_204[4] + getitem_1973 = split_204[5] + getitem_1974 = split_204[6] + getitem_1975 = split_204[7]; split_204 = None + cat_196 = torch.ops.aten.cat.default([getitem_1968, getitem_1969, getitem_1970, 
getitem_1971, getitem_1972, getitem_1973, getitem_1974, getitem_1975]); getitem_1968 = getitem_1969 = getitem_1970 = getitem_1971 = getitem_1972 = getitem_1973 = getitem_1974 = getitem_1975 = None + reduce_scatter_tensor_247 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_196, 'sum', 8, '1'); cat_196 = None + wait_tensor_669 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_247); reduce_scatter_tensor_247 = None + convert_element_type_1956 = torch.ops.prims.convert_element_type.default(wait_tensor_669, torch.float32); wait_tensor_669 = None + convert_element_type_1958 = torch.ops.prims.convert_element_type.default(wait_tensor_204, torch.float32); wait_tensor_204 = None + mul_590 = torch.ops.aten.mul.Tensor(convert_element_type_1956, convert_element_type_1958); convert_element_type_1958 = None + mul_592 = torch.ops.aten.mul.Tensor(mul_124, mul_590) + sum_99 = torch.ops.aten.sum.dim_IntList(mul_592, [2], True); mul_592 = None + div_33 = torch.ops.aten.div.Tensor(mul_124, 4096) + mul_593 = torch.ops.aten.mul.Tensor(div_33, sum_99); div_33 = sum_99 = None + sub_51 = torch.ops.aten.sub.Tensor(mul_590, mul_593); mul_590 = mul_593 = None + mul_594 = torch.ops.aten.mul.Tensor(sub_51, rsqrt_31); sub_51 = rsqrt_31 = None + mul_595 = torch.ops.aten.mul.Tensor(convert_element_type_1956, mul_124); convert_element_type_1956 = mul_124 = None + sum_100 = torch.ops.aten.sum.dim_IntList(mul_595, [0, 1]); mul_595 = None + convert_element_type_1959 = torch.ops.prims.convert_element_type.default(mul_594, torch.bfloat16); mul_594 = None + convert_element_type_1960 = torch.ops.prims.convert_element_type.default(sum_100, torch.bfloat16); sum_100 = None + all_reduce_33 = torch.ops._c10d_functional.all_reduce.default(convert_element_type_1960, 'sum', '1'); convert_element_type_1960 = None + wait_tensor_670 = torch.ops._c10d_functional.wait_tensor.default(all_reduce_33); all_reduce_33 = None + convert_element_type_1961 = 
torch.ops.prims.convert_element_type.default(wait_tensor_670, torch.float32); wait_tensor_670 = None + reduce_scatter_tensor_248 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1961, 'avg', 32, '0'); convert_element_type_1961 = None + wait_tensor_671 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_248); reduce_scatter_tensor_248 = None + add_244 = torch.ops.aten.add.Tensor(add_240, convert_element_type_1959); add_240 = convert_element_type_1959 = None + all_gather_into_tensor_389 = torch.ops._c10d_functional.all_gather_into_tensor.default(add_244, 8, '1') + wait_tensor_672 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_389); all_gather_into_tensor_389 = None + split_205 = torch.ops.aten.split.Tensor(wait_tensor_672, 2); wait_tensor_672 = None + getitem_1976 = split_205[0] + getitem_1977 = split_205[1] + getitem_1978 = split_205[2] + getitem_1979 = split_205[3] + getitem_1980 = split_205[4] + getitem_1981 = split_205[5] + getitem_1982 = split_205[6] + getitem_1983 = split_205[7]; split_205 = None + cat_197 = torch.ops.aten.cat.default([getitem_1976, getitem_1977, getitem_1978, getitem_1979, getitem_1980, getitem_1981, getitem_1982, getitem_1983], 1); getitem_1976 = getitem_1977 = getitem_1978 = getitem_1979 = getitem_1980 = getitem_1981 = getitem_1982 = getitem_1983 = None + view_2713 = torch.ops.aten.view.default(cat_197, [16384, 4096]); cat_197 = None + permute_881 = torch.ops.aten.permute.default(view_2713, [1, 0]) + permute_171 = torch.ops.aten.permute.default(getitem_695, [0, 2, 1, 3]) + view_1122 = torch.ops.aten.view.default(permute_171, [2, 8192, -1]); permute_171 = None + view_1128 = torch.ops.aten.view.default(view_1122, [16384, 512]); view_1122 = None + mm_457 = torch.ops.aten.mm.default(permute_881, view_1128); permute_881 = view_1128 = None + convert_element_type_512 = torch.ops.prims.convert_element_type.default(primals_143, torch.bfloat16); primals_143 = None + 
all_gather_into_tensor_171 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_512, 32, '0'); convert_element_type_512 = None + wait_tensor_202 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_171); all_gather_into_tensor_171 = None + permute_172 = torch.ops.aten.permute.default(wait_tensor_202, [1, 0]); wait_tensor_202 = None + permute_883 = torch.ops.aten.permute.default(permute_172, [1, 0]); permute_172 = None + mm_458 = torch.ops.aten.mm.default(view_2713, permute_883); view_2713 = permute_883 = None + view_2714 = torch.ops.aten.view.default(mm_458, [2, 8192, 512]); mm_458 = None + convert_element_type_1966 = torch.ops.prims.convert_element_type.default(mm_457, torch.float32); mm_457 = None + reduce_scatter_tensor_249 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1966, 'avg', 32, '0'); convert_element_type_1966 = None + wait_tensor_673 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_249); reduce_scatter_tensor_249 = None + view_2715 = torch.ops.aten.view.default(view_2714, [2, 8192, 4, 128]); view_2714 = None + permute_885 = torch.ops.aten.permute.default(view_2715, [0, 2, 1, 3]); view_2715 = None + convert_element_type_496 = torch.ops.prims.convert_element_type.default(primals_139, torch.bfloat16); primals_139 = None + all_gather_into_tensor_166 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_496, 32, '0'); convert_element_type_496 = None + wait_tensor_197 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_166); all_gather_into_tensor_166 = None + convert_element_type_497 = torch.ops.prims.convert_element_type.default(add_59, torch.float32); add_59 = None + pow_31 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_497, 2) + mean_30 = torch.ops.aten.mean.dim(pow_31, [2], True); pow_31 = None + add_60 = torch.ops.aten.add.Scalar(mean_30, 1e-05); mean_30 = None + rsqrt_30 = 
torch.ops.aten.rsqrt.default(add_60); add_60 = None + mul_120 = torch.ops.aten.mul.Tensor(convert_element_type_497, rsqrt_30); convert_element_type_497 = None + mul_121 = torch.ops.aten.mul.Tensor(mul_120, wait_tensor_197) + convert_element_type_498 = torch.ops.prims.convert_element_type.default(mul_121, torch.bfloat16); mul_121 = None + all_gather_into_tensor_167 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_498, 8, '1'); convert_element_type_498 = None + wait_tensor_198 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_167); all_gather_into_tensor_167 = None + split_69 = torch.ops.aten.split.Tensor(wait_tensor_198, 2); wait_tensor_198 = None + getitem_687 = split_69[0] + getitem_688 = split_69[1] + getitem_689 = split_69[2] + getitem_690 = split_69[3] + getitem_691 = split_69[4] + getitem_692 = split_69[5] + getitem_693 = split_69[6] + getitem_694 = split_69[7]; split_69 = None + cat_61 = torch.ops.aten.cat.default([getitem_687, getitem_688, getitem_689, getitem_690, getitem_691, getitem_692, getitem_693, getitem_694], 1); getitem_687 = getitem_688 = getitem_689 = getitem_690 = getitem_691 = getitem_692 = getitem_693 = getitem_694 = None + view_1095 = torch.ops.aten.view.default(cat_61, [16384, 4096]); cat_61 = None + view_1096 = torch.ops.aten.view.default(mm_105, [2, 8192, 512]); mm_105 = None + convert_element_type_502 = torch.ops.prims.convert_element_type.default(primals_141, torch.bfloat16); primals_141 = None + all_gather_into_tensor_169 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_502, 32, '0'); convert_element_type_502 = None + wait_tensor_200 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_169); all_gather_into_tensor_169 = None + permute_166 = torch.ops.aten.permute.default(wait_tensor_200, [1, 0]); wait_tensor_200 = None + mm_106 = torch.ops.aten.mm.default(view_1095, permute_166) + view_1103 = torch.ops.aten.view.default(mm_106, [2, 
8192, 128]); mm_106 = None + view_1110 = torch.ops.aten.view.default(mm_107, [2, 8192, 128]); mm_107 = None + view_1112 = torch.ops.aten.view.default(view_1096, [2, 8192, -1, 128]); view_1096 = None + view_1113 = torch.ops.aten.view.default(view_1103, [2, 8192, -1, 128]); view_1103 = None + view_1114 = torch.ops.aten.view.default(view_1110, [2, 8192, -1, 128]); view_1110 = None + convert_element_type_508 = torch.ops.prims.convert_element_type.default(view_1112, torch.float32); view_1112 = None + view_1115 = torch.ops.aten.view.default(convert_element_type_508, [2, 8192, 4, -1, 2]); convert_element_type_508 = None + view_as_complex_30 = torch.ops.aten.view_as_complex.default(view_1115); view_1115 = None + convert_element_type_509 = torch.ops.prims.convert_element_type.default(view_1113, torch.float32); view_1113 = None + view_1116 = torch.ops.aten.view.default(convert_element_type_509, [2, 8192, 1, -1, 2]); convert_element_type_509 = None + view_as_complex_31 = torch.ops.aten.view_as_complex.default(view_1116); view_1116 = None + mul_122 = torch.ops.aten.mul.Tensor(view_as_complex_30, view_37); view_as_complex_30 = None + view_as_real_30 = torch.ops.aten.view_as_real.default(mul_122); mul_122 = None + view_1118 = torch.ops.aten.view.default(view_as_real_30, [2, 8192, 4, 128]); view_as_real_30 = None + mul_123 = torch.ops.aten.mul.Tensor(view_as_complex_31, view_37); view_as_complex_31 = None + view_as_real_31 = torch.ops.aten.view_as_real.default(mul_123); mul_123 = None + view_1119 = torch.ops.aten.view.default(view_as_real_31, [2, 8192, 1, 128]); view_as_real_31 = None + convert_element_type_510 = torch.ops.prims.convert_element_type.default(view_1118, torch.bfloat16); view_1118 = None + convert_element_type_511 = torch.ops.prims.convert_element_type.default(view_1119, torch.bfloat16); view_1119 = None + unsqueeze_30 = torch.ops.aten.unsqueeze.default(convert_element_type_511, 3); convert_element_type_511 = None + expand_30 = 
torch.ops.aten.expand.default(unsqueeze_30, [2, 8192, 1, 4, 128]); unsqueeze_30 = None + view_1120 = torch.ops.aten.view.default(expand_30, [2, 8192, 4, 128]); expand_30 = None + unsqueeze_31 = torch.ops.aten.unsqueeze.default(view_1114, 3); view_1114 = None + expand_31 = torch.ops.aten.expand.default(unsqueeze_31, [2, 8192, 1, 4, 128]); unsqueeze_31 = None + view_1121 = torch.ops.aten.view.default(expand_31, [2, 8192, 4, 128]); expand_31 = None + permute_168 = torch.ops.aten.permute.default(convert_element_type_510, [0, 2, 1, 3]); convert_element_type_510 = None + permute_169 = torch.ops.aten.permute.default(view_1120, [0, 2, 1, 3]); view_1120 = None + permute_170 = torch.ops.aten.permute.default(view_1121, [0, 2, 1, 3]); view_1121 = None + _scaled_dot_product_cudnn_attention_backward_16 = torch.ops.aten._scaled_dot_product_cudnn_attention_backward.default(permute_885, permute_168, permute_169, permute_170, getitem_695, getitem_696, getitem_701, getitem_702, None, None, None, 8192, 8192, 0.0, True); permute_885 = permute_168 = permute_169 = permute_170 = getitem_695 = getitem_696 = getitem_701 = getitem_702 = None + getitem_1984 = _scaled_dot_product_cudnn_attention_backward_16[0] + getitem_1985 = _scaled_dot_product_cudnn_attention_backward_16[1] + getitem_1986 = _scaled_dot_product_cudnn_attention_backward_16[2]; _scaled_dot_product_cudnn_attention_backward_16 = None + permute_886 = torch.ops.aten.permute.default(getitem_1986, [0, 2, 1, 3]); getitem_1986 = None + permute_887 = torch.ops.aten.permute.default(getitem_1985, [0, 2, 1, 3]); getitem_1985 = None + permute_888 = torch.ops.aten.permute.default(getitem_1984, [0, 2, 1, 3]); getitem_1984 = None + view_2716 = torch.ops.aten.view.default(permute_886, [2, 8192, 1, 4, 128]); permute_886 = None + sum_101 = torch.ops.aten.sum.dim_IntList(view_2716, [3], True); view_2716 = None + squeeze_32 = torch.ops.aten.squeeze.dim(sum_101, 3); sum_101 = None + view_2717 = torch.ops.aten.view.default(permute_887, [2, 8192, 1, 
4, 128]); permute_887 = None + sum_102 = torch.ops.aten.sum.dim_IntList(view_2717, [3], True); view_2717 = None + squeeze_33 = torch.ops.aten.squeeze.dim(sum_102, 3); sum_102 = None + convert_element_type_1967 = torch.ops.prims.convert_element_type.default(squeeze_33, torch.float32); squeeze_33 = None + convert_element_type_1968 = torch.ops.prims.convert_element_type.default(permute_888, torch.float32); permute_888 = None + view_2718 = torch.ops.aten.view.default(convert_element_type_1967, [2, 8192, 1, 64, 2]); convert_element_type_1967 = None + view_as_complex_96 = torch.ops.aten.view_as_complex.default(view_2718); view_2718 = None + mul_596 = torch.ops.aten.mul.Tensor(view_as_complex_96, _conj); view_as_complex_96 = None + view_2719 = torch.ops.aten.view.default(convert_element_type_1968, [2, 8192, 4, 64, 2]); convert_element_type_1968 = None + view_as_complex_97 = torch.ops.aten.view_as_complex.default(view_2719); view_2719 = None + mul_597 = torch.ops.aten.mul.Tensor(view_as_complex_97, _conj); view_as_complex_97 = None + view_as_real_96 = torch.ops.aten.view_as_real.default(mul_596); mul_596 = None + view_2720 = torch.ops.aten.view.default(view_as_real_96, [2, 8192, 1, 128]); view_as_real_96 = None + convert_element_type_1969 = torch.ops.prims.convert_element_type.default(view_2720, torch.bfloat16); view_2720 = None + view_as_real_97 = torch.ops.aten.view_as_real.default(mul_597); mul_597 = None + view_2721 = torch.ops.aten.view.default(view_as_real_97, [2, 8192, 4, 128]); view_as_real_97 = None + convert_element_type_1970 = torch.ops.prims.convert_element_type.default(view_2721, torch.bfloat16); view_2721 = None + view_2722 = torch.ops.aten.view.default(squeeze_32, [2, 8192, 128]); squeeze_32 = None + view_2723 = torch.ops.aten.view.default(convert_element_type_1969, [2, 8192, 128]); convert_element_type_1969 = None + view_2724 = torch.ops.aten.view.default(convert_element_type_1970, [2, 8192, 512]); convert_element_type_1970 = None + view_2725 = 
torch.ops.aten.view.default(view_2722, [16384, 128]); view_2722 = None + permute_889 = torch.ops.aten.permute.default(view_2725, [1, 0]) + mm_459 = torch.ops.aten.mm.default(permute_889, view_1095); permute_889 = None + convert_element_type_505 = torch.ops.prims.convert_element_type.default(primals_142, torch.bfloat16); primals_142 = None + all_gather_into_tensor_170 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_505, 32, '0'); convert_element_type_505 = None + wait_tensor_201 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_170); all_gather_into_tensor_170 = None + permute_167 = torch.ops.aten.permute.default(wait_tensor_201, [1, 0]); wait_tensor_201 = None + permute_891 = torch.ops.aten.permute.default(permute_167, [1, 0]); permute_167 = None + mm_460 = torch.ops.aten.mm.default(view_2725, permute_891); view_2725 = permute_891 = None + view_2726 = torch.ops.aten.view.default(mm_460, [2, 8192, 4096]); mm_460 = None + convert_element_type_1975 = torch.ops.prims.convert_element_type.default(mm_459, torch.float32); mm_459 = None + reduce_scatter_tensor_250 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1975, 'avg', 32, '0'); convert_element_type_1975 = None + wait_tensor_674 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_250); reduce_scatter_tensor_250 = None + view_2727 = torch.ops.aten.view.default(view_2723, [16384, 128]); view_2723 = None + permute_893 = torch.ops.aten.permute.default(view_2727, [1, 0]) + mm_461 = torch.ops.aten.mm.default(permute_893, view_1095); permute_893 = None + permute_895 = torch.ops.aten.permute.default(permute_166, [1, 0]); permute_166 = None + mm_462 = torch.ops.aten.mm.default(view_2727, permute_895); view_2727 = permute_895 = None + view_2728 = torch.ops.aten.view.default(mm_462, [2, 8192, 4096]); mm_462 = None + add_245 = torch.ops.aten.add.Tensor(view_2726, view_2728); view_2726 = view_2728 = None + 
convert_element_type_1980 = torch.ops.prims.convert_element_type.default(mm_461, torch.float32); mm_461 = None + reduce_scatter_tensor_251 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1980, 'avg', 32, '0'); convert_element_type_1980 = None + wait_tensor_675 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_251); reduce_scatter_tensor_251 = None + view_2729 = torch.ops.aten.view.default(view_2724, [16384, 512]); view_2724 = None + permute_897 = torch.ops.aten.permute.default(view_2729, [1, 0]) + mm_463 = torch.ops.aten.mm.default(permute_897, view_1095); permute_897 = view_1095 = None + convert_element_type_499 = torch.ops.prims.convert_element_type.default(primals_140, torch.bfloat16); primals_140 = None + all_gather_into_tensor_168 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_499, 32, '0'); convert_element_type_499 = None + wait_tensor_199 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_168); all_gather_into_tensor_168 = None + permute_165 = torch.ops.aten.permute.default(wait_tensor_199, [1, 0]); wait_tensor_199 = None + permute_899 = torch.ops.aten.permute.default(permute_165, [1, 0]); permute_165 = None + mm_464 = torch.ops.aten.mm.default(view_2729, permute_899); view_2729 = permute_899 = None + view_2730 = torch.ops.aten.view.default(mm_464, [2, 8192, 4096]); mm_464 = None + add_246 = torch.ops.aten.add.Tensor(add_245, view_2730); add_245 = view_2730 = None + convert_element_type_1985 = torch.ops.prims.convert_element_type.default(mm_463, torch.float32); mm_463 = None + reduce_scatter_tensor_252 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1985, 'avg', 32, '0'); convert_element_type_1985 = None + wait_tensor_676 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_252); reduce_scatter_tensor_252 = None + split_206 = torch.ops.aten.split.Tensor(add_246, 1024, 1); add_246 = None + getitem_1987 = 
split_206[0] + getitem_1988 = split_206[1] + getitem_1989 = split_206[2] + getitem_1990 = split_206[3] + getitem_1991 = split_206[4] + getitem_1992 = split_206[5] + getitem_1993 = split_206[6] + getitem_1994 = split_206[7]; split_206 = None + cat_198 = torch.ops.aten.cat.default([getitem_1987, getitem_1988, getitem_1989, getitem_1990, getitem_1991, getitem_1992, getitem_1993, getitem_1994]); getitem_1987 = getitem_1988 = getitem_1989 = getitem_1990 = getitem_1991 = getitem_1992 = getitem_1993 = getitem_1994 = None + reduce_scatter_tensor_253 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_198, 'sum', 8, '1'); cat_198 = None + wait_tensor_677 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_253); reduce_scatter_tensor_253 = None + convert_element_type_1986 = torch.ops.prims.convert_element_type.default(wait_tensor_677, torch.float32); wait_tensor_677 = None + convert_element_type_1988 = torch.ops.prims.convert_element_type.default(wait_tensor_197, torch.float32); wait_tensor_197 = None + mul_598 = torch.ops.aten.mul.Tensor(convert_element_type_1986, convert_element_type_1988); convert_element_type_1988 = None + mul_600 = torch.ops.aten.mul.Tensor(mul_120, mul_598) + sum_103 = torch.ops.aten.sum.dim_IntList(mul_600, [2], True); mul_600 = None + div_34 = torch.ops.aten.div.Tensor(mul_120, 4096) + mul_601 = torch.ops.aten.mul.Tensor(div_34, sum_103); div_34 = sum_103 = None + sub_52 = torch.ops.aten.sub.Tensor(mul_598, mul_601); mul_598 = mul_601 = None + mul_602 = torch.ops.aten.mul.Tensor(sub_52, rsqrt_30); sub_52 = rsqrt_30 = None + mul_603 = torch.ops.aten.mul.Tensor(convert_element_type_1986, mul_120); convert_element_type_1986 = mul_120 = None + sum_104 = torch.ops.aten.sum.dim_IntList(mul_603, [0, 1]); mul_603 = None + convert_element_type_1989 = torch.ops.prims.convert_element_type.default(mul_602, torch.bfloat16); mul_602 = None + convert_element_type_1990 = torch.ops.prims.convert_element_type.default(sum_104, 
torch.bfloat16); sum_104 = None + all_reduce_34 = torch.ops._c10d_functional.all_reduce.default(convert_element_type_1990, 'sum', '1'); convert_element_type_1990 = None + wait_tensor_678 = torch.ops._c10d_functional.wait_tensor.default(all_reduce_34); all_reduce_34 = None + convert_element_type_1991 = torch.ops.prims.convert_element_type.default(wait_tensor_678, torch.float32); wait_tensor_678 = None + reduce_scatter_tensor_254 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1991, 'avg', 32, '0'); convert_element_type_1991 = None + wait_tensor_679 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_254); reduce_scatter_tensor_254 = None + add_247 = torch.ops.aten.add.Tensor(add_244, convert_element_type_1989); add_244 = convert_element_type_1989 = None + all_gather_into_tensor_390 = torch.ops._c10d_functional.all_gather_into_tensor.default(add_247, 8, '1') + wait_tensor_680 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_390); all_gather_into_tensor_390 = None + split_207 = torch.ops.aten.split.Tensor(wait_tensor_680, 2); wait_tensor_680 = None + getitem_1995 = split_207[0] + getitem_1996 = split_207[1] + getitem_1997 = split_207[2] + getitem_1998 = split_207[3] + getitem_1999 = split_207[4] + getitem_2000 = split_207[5] + getitem_2001 = split_207[6] + getitem_2002 = split_207[7]; split_207 = None + cat_199 = torch.ops.aten.cat.default([getitem_1995, getitem_1996, getitem_1997, getitem_1998, getitem_1999, getitem_2000, getitem_2001, getitem_2002], 1); getitem_1995 = getitem_1996 = getitem_1997 = getitem_1998 = getitem_1999 = getitem_2000 = getitem_2001 = getitem_2002 = None + view_2731 = torch.ops.aten.view.default(cat_199, [16384, 4096]); cat_199 = None + permute_901 = torch.ops.aten.permute.default(view_2731, [1, 0]) + wait_tensor_190 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_29); reduce_scatter_tensor_29 = None + add_57 = torch.ops.aten.add.Tensor(add_55, 
wait_tensor_190); wait_tensor_190 = None + convert_element_type_482 = torch.ops.prims.convert_element_type.default(primals_135, torch.bfloat16); primals_135 = None + all_gather_into_tensor_161 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_482, 32, '0'); convert_element_type_482 = None + wait_tensor_191 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_161); all_gather_into_tensor_161 = None + convert_element_type_483 = torch.ops.prims.convert_element_type.default(add_57, torch.float32); add_57 = None + pow_30 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_483, 2) + mean_29 = torch.ops.aten.mean.dim(pow_30, [2], True); pow_30 = None + add_58 = torch.ops.aten.add.Scalar(mean_29, 1e-05); mean_29 = None + rsqrt_29 = torch.ops.aten.rsqrt.default(add_58); add_58 = None + mul_116 = torch.ops.aten.mul.Tensor(convert_element_type_483, rsqrt_29); convert_element_type_483 = None + mul_117 = torch.ops.aten.mul.Tensor(mul_116, wait_tensor_191) + convert_element_type_484 = torch.ops.prims.convert_element_type.default(mul_117, torch.bfloat16); mul_117 = None + all_gather_into_tensor_162 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_484, 8, '1'); convert_element_type_484 = None + wait_tensor_192 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_162); all_gather_into_tensor_162 = None + split_67 = torch.ops.aten.split.Tensor(wait_tensor_192, 2); wait_tensor_192 = None + getitem_671 = split_67[0] + getitem_672 = split_67[1] + getitem_673 = split_67[2] + getitem_674 = split_67[3] + getitem_675 = split_67[4] + getitem_676 = split_67[5] + getitem_677 = split_67[6] + getitem_678 = split_67[7]; split_67 = None + cat_59 = torch.ops.aten.cat.default([getitem_671, getitem_672, getitem_673, getitem_674, getitem_675, getitem_676, getitem_677, getitem_678], 1); getitem_671 = getitem_672 = getitem_673 = getitem_674 = getitem_675 = getitem_676 = getitem_677 = getitem_678 = 
None + view_1068 = torch.ops.aten.view.default(cat_59, [16384, 4096]); cat_59 = None + view_1069 = torch.ops.aten.view.default(mm_102, [2, 8192, 1792]); mm_102 = None + convert_element_type_488 = torch.ops.prims.convert_element_type.default(view_1069, torch.float32); view_1069 = None + sigmoid_14 = torch.ops.aten.sigmoid.default(convert_element_type_488) + mul_118 = torch.ops.aten.mul.Tensor(convert_element_type_488, sigmoid_14); sigmoid_14 = None + convert_element_type_489 = torch.ops.prims.convert_element_type.default(mul_118, torch.bfloat16); mul_118 = None + convert_element_type_490 = torch.ops.prims.convert_element_type.default(primals_137, torch.bfloat16); primals_137 = None + all_gather_into_tensor_164 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_490, 32, '0'); convert_element_type_490 = None + wait_tensor_194 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_164); all_gather_into_tensor_164 = None + permute_163 = torch.ops.aten.permute.default(wait_tensor_194, [1, 0]); wait_tensor_194 = None + mm_103 = torch.ops.aten.mm.default(view_1068, permute_163) + view_1076 = torch.ops.aten.view.default(mm_103, [2, 8192, 1792]); mm_103 = None + mul_119 = torch.ops.aten.mul.Tensor(convert_element_type_489, view_1076) + view_1083 = torch.ops.aten.view.default(mul_119, [16384, 1792]); mul_119 = None + mm_465 = torch.ops.aten.mm.default(permute_901, view_1083); permute_901 = view_1083 = None + convert_element_type_493 = torch.ops.prims.convert_element_type.default(primals_138, torch.bfloat16); primals_138 = None + all_gather_into_tensor_165 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_493, 32, '0'); convert_element_type_493 = None + wait_tensor_195 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_165); all_gather_into_tensor_165 = None + permute_164 = torch.ops.aten.permute.default(wait_tensor_195, [1, 0]); wait_tensor_195 = None + permute_903 = 
torch.ops.aten.permute.default(permute_164, [1, 0]); permute_164 = None + mm_466 = torch.ops.aten.mm.default(view_2731, permute_903); view_2731 = permute_903 = None + view_2732 = torch.ops.aten.view.default(mm_466, [2, 8192, 1792]); mm_466 = None + convert_element_type_1996 = torch.ops.prims.convert_element_type.default(mm_465, torch.float32); mm_465 = None + reduce_scatter_tensor_255 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1996, 'avg', 32, '0'); convert_element_type_1996 = None + wait_tensor_681 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_255); reduce_scatter_tensor_255 = None + mul_604 = torch.ops.aten.mul.Tensor(view_2732, convert_element_type_489); convert_element_type_489 = None + mul_605 = torch.ops.aten.mul.Tensor(view_2732, view_1076); view_2732 = view_1076 = None + view_2733 = torch.ops.aten.view.default(mul_604, [16384, 1792]); mul_604 = None + permute_905 = torch.ops.aten.permute.default(view_2733, [1, 0]) + mm_467 = torch.ops.aten.mm.default(permute_905, view_1068); permute_905 = None + permute_907 = torch.ops.aten.permute.default(permute_163, [1, 0]); permute_163 = None + mm_468 = torch.ops.aten.mm.default(view_2733, permute_907); view_2733 = permute_907 = None + view_2734 = torch.ops.aten.view.default(mm_468, [2, 8192, 4096]); mm_468 = None + convert_element_type_2001 = torch.ops.prims.convert_element_type.default(mm_467, torch.float32); mm_467 = None + reduce_scatter_tensor_256 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2001, 'avg', 32, '0'); convert_element_type_2001 = None + wait_tensor_682 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_256); reduce_scatter_tensor_256 = None + convert_element_type_2002 = torch.ops.prims.convert_element_type.default(mul_605, torch.float32); mul_605 = None + neg_17 = torch.ops.aten.neg.default(convert_element_type_488) + exp_17 = torch.ops.aten.exp.default(neg_17); neg_17 = None + add_248 = 
torch.ops.aten.add.Tensor(exp_17, 1); exp_17 = None + reciprocal_17 = torch.ops.aten.reciprocal.default(add_248); add_248 = None + mul_606 = torch.ops.aten.mul.Tensor(reciprocal_17, 1); reciprocal_17 = None + mul_607 = torch.ops.aten.mul.Tensor(convert_element_type_2002, mul_606); convert_element_type_2002 = None + sub_53 = torch.ops.aten.sub.Tensor(1, mul_606); mul_606 = None + mul_608 = torch.ops.aten.mul.Tensor(convert_element_type_488, sub_53); convert_element_type_488 = sub_53 = None + add_249 = torch.ops.aten.add.Tensor(mul_608, 1); mul_608 = None + mul_609 = torch.ops.aten.mul.Tensor(mul_607, add_249); mul_607 = add_249 = None + convert_element_type_2004 = torch.ops.prims.convert_element_type.default(mul_609, torch.bfloat16); mul_609 = None + view_2735 = torch.ops.aten.view.default(convert_element_type_2004, [16384, 1792]); convert_element_type_2004 = None + permute_909 = torch.ops.aten.permute.default(view_2735, [1, 0]) + mm_469 = torch.ops.aten.mm.default(permute_909, view_1068); permute_909 = view_1068 = None + convert_element_type_485 = torch.ops.prims.convert_element_type.default(primals_136, torch.bfloat16); primals_136 = None + all_gather_into_tensor_163 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_485, 32, '0'); convert_element_type_485 = None + wait_tensor_193 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_163); all_gather_into_tensor_163 = None + permute_162 = torch.ops.aten.permute.default(wait_tensor_193, [1, 0]); wait_tensor_193 = None + permute_911 = torch.ops.aten.permute.default(permute_162, [1, 0]); permute_162 = None + mm_470 = torch.ops.aten.mm.default(view_2735, permute_911); view_2735 = permute_911 = None + view_2736 = torch.ops.aten.view.default(mm_470, [2, 8192, 4096]); mm_470 = None + add_250 = torch.ops.aten.add.Tensor(view_2734, view_2736); view_2734 = view_2736 = None + convert_element_type_2009 = torch.ops.prims.convert_element_type.default(mm_469, torch.float32); mm_469 
= None + reduce_scatter_tensor_257 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2009, 'avg', 32, '0'); convert_element_type_2009 = None + wait_tensor_683 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_257); reduce_scatter_tensor_257 = None + split_208 = torch.ops.aten.split.Tensor(add_250, 1024, 1); add_250 = None + getitem_2003 = split_208[0] + getitem_2004 = split_208[1] + getitem_2005 = split_208[2] + getitem_2006 = split_208[3] + getitem_2007 = split_208[4] + getitem_2008 = split_208[5] + getitem_2009 = split_208[6] + getitem_2010 = split_208[7]; split_208 = None + cat_200 = torch.ops.aten.cat.default([getitem_2003, getitem_2004, getitem_2005, getitem_2006, getitem_2007, getitem_2008, getitem_2009, getitem_2010]); getitem_2003 = getitem_2004 = getitem_2005 = getitem_2006 = getitem_2007 = getitem_2008 = getitem_2009 = getitem_2010 = None + reduce_scatter_tensor_258 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_200, 'sum', 8, '1'); cat_200 = None + wait_tensor_684 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_258); reduce_scatter_tensor_258 = None + convert_element_type_2010 = torch.ops.prims.convert_element_type.default(wait_tensor_684, torch.float32); wait_tensor_684 = None + convert_element_type_2012 = torch.ops.prims.convert_element_type.default(wait_tensor_191, torch.float32); wait_tensor_191 = None + mul_610 = torch.ops.aten.mul.Tensor(convert_element_type_2010, convert_element_type_2012); convert_element_type_2012 = None + mul_612 = torch.ops.aten.mul.Tensor(mul_116, mul_610) + sum_105 = torch.ops.aten.sum.dim_IntList(mul_612, [2], True); mul_612 = None + div_35 = torch.ops.aten.div.Tensor(mul_116, 4096) + mul_613 = torch.ops.aten.mul.Tensor(div_35, sum_105); div_35 = sum_105 = None + sub_54 = torch.ops.aten.sub.Tensor(mul_610, mul_613); mul_610 = mul_613 = None + mul_614 = torch.ops.aten.mul.Tensor(sub_54, rsqrt_29); sub_54 = rsqrt_29 = None + mul_615 = 
torch.ops.aten.mul.Tensor(convert_element_type_2010, mul_116); convert_element_type_2010 = mul_116 = None + sum_106 = torch.ops.aten.sum.dim_IntList(mul_615, [0, 1]); mul_615 = None + convert_element_type_2013 = torch.ops.prims.convert_element_type.default(mul_614, torch.bfloat16); mul_614 = None + convert_element_type_2014 = torch.ops.prims.convert_element_type.default(sum_106, torch.bfloat16); sum_106 = None + all_reduce_35 = torch.ops._c10d_functional.all_reduce.default(convert_element_type_2014, 'sum', '1'); convert_element_type_2014 = None + wait_tensor_685 = torch.ops._c10d_functional.wait_tensor.default(all_reduce_35); all_reduce_35 = None + convert_element_type_2015 = torch.ops.prims.convert_element_type.default(wait_tensor_685, torch.float32); wait_tensor_685 = None + reduce_scatter_tensor_259 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2015, 'avg', 32, '0'); convert_element_type_2015 = None + wait_tensor_686 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_259); reduce_scatter_tensor_259 = None + add_251 = torch.ops.aten.add.Tensor(add_247, convert_element_type_2013); add_247 = convert_element_type_2013 = None + all_gather_into_tensor_391 = torch.ops._c10d_functional.all_gather_into_tensor.default(add_251, 8, '1') + wait_tensor_687 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_391); all_gather_into_tensor_391 = None + split_209 = torch.ops.aten.split.Tensor(wait_tensor_687, 2); wait_tensor_687 = None + getitem_2011 = split_209[0] + getitem_2012 = split_209[1] + getitem_2013 = split_209[2] + getitem_2014 = split_209[3] + getitem_2015 = split_209[4] + getitem_2016 = split_209[5] + getitem_2017 = split_209[6] + getitem_2018 = split_209[7]; split_209 = None + cat_201 = torch.ops.aten.cat.default([getitem_2011, getitem_2012, getitem_2013, getitem_2014, getitem_2015, getitem_2016, getitem_2017, getitem_2018], 1); getitem_2011 = getitem_2012 = getitem_2013 = getitem_2014 = 
getitem_2015 = getitem_2016 = getitem_2017 = getitem_2018 = None + view_2737 = torch.ops.aten.view.default(cat_201, [16384, 4096]); cat_201 = None + permute_913 = torch.ops.aten.permute.default(view_2737, [1, 0]) + permute_160 = torch.ops.aten.permute.default(getitem_654, [0, 2, 1, 3]) + view_1050 = torch.ops.aten.view.default(permute_160, [2, 8192, -1]); permute_160 = None + view_1056 = torch.ops.aten.view.default(view_1050, [16384, 512]); view_1050 = None + mm_471 = torch.ops.aten.mm.default(permute_913, view_1056); permute_913 = view_1056 = None + convert_element_type_479 = torch.ops.prims.convert_element_type.default(primals_134, torch.bfloat16); primals_134 = None + all_gather_into_tensor_160 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_479, 32, '0'); convert_element_type_479 = None + wait_tensor_189 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_160); all_gather_into_tensor_160 = None + permute_161 = torch.ops.aten.permute.default(wait_tensor_189, [1, 0]); wait_tensor_189 = None + permute_915 = torch.ops.aten.permute.default(permute_161, [1, 0]); permute_161 = None + mm_472 = torch.ops.aten.mm.default(view_2737, permute_915); view_2737 = permute_915 = None + view_2738 = torch.ops.aten.view.default(mm_472, [2, 8192, 512]); mm_472 = None + convert_element_type_2020 = torch.ops.prims.convert_element_type.default(mm_471, torch.float32); mm_471 = None + reduce_scatter_tensor_260 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2020, 'avg', 32, '0'); convert_element_type_2020 = None + wait_tensor_688 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_260); reduce_scatter_tensor_260 = None + view_2739 = torch.ops.aten.view.default(view_2738, [2, 8192, 4, 128]); view_2738 = None + permute_917 = torch.ops.aten.permute.default(view_2739, [0, 2, 1, 3]); view_2739 = None + convert_element_type_463 = torch.ops.prims.convert_element_type.default(primals_130, 
torch.bfloat16); primals_130 = None + all_gather_into_tensor_155 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_463, 32, '0'); convert_element_type_463 = None + wait_tensor_184 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_155); all_gather_into_tensor_155 = None + convert_element_type_464 = torch.ops.prims.convert_element_type.default(add_55, torch.float32); add_55 = None + pow_29 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_464, 2) + mean_28 = torch.ops.aten.mean.dim(pow_29, [2], True); pow_29 = None + add_56 = torch.ops.aten.add.Scalar(mean_28, 1e-05); mean_28 = None + rsqrt_28 = torch.ops.aten.rsqrt.default(add_56); add_56 = None + mul_112 = torch.ops.aten.mul.Tensor(convert_element_type_464, rsqrt_28); convert_element_type_464 = None + mul_113 = torch.ops.aten.mul.Tensor(mul_112, wait_tensor_184) + convert_element_type_465 = torch.ops.prims.convert_element_type.default(mul_113, torch.bfloat16); mul_113 = None + all_gather_into_tensor_156 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_465, 8, '1'); convert_element_type_465 = None + wait_tensor_185 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_156); all_gather_into_tensor_156 = None + split_65 = torch.ops.aten.split.Tensor(wait_tensor_185, 2); wait_tensor_185 = None + getitem_646 = split_65[0] + getitem_647 = split_65[1] + getitem_648 = split_65[2] + getitem_649 = split_65[3] + getitem_650 = split_65[4] + getitem_651 = split_65[5] + getitem_652 = split_65[6] + getitem_653 = split_65[7]; split_65 = None + cat_57 = torch.ops.aten.cat.default([getitem_646, getitem_647, getitem_648, getitem_649, getitem_650, getitem_651, getitem_652, getitem_653], 1); getitem_646 = getitem_647 = getitem_648 = getitem_649 = getitem_650 = getitem_651 = getitem_652 = getitem_653 = None + view_1023 = torch.ops.aten.view.default(cat_57, [16384, 4096]); cat_57 = None + view_1024 = 
torch.ops.aten.view.default(mm_98, [2, 8192, 512]); mm_98 = None + convert_element_type_469 = torch.ops.prims.convert_element_type.default(primals_132, torch.bfloat16); primals_132 = None + all_gather_into_tensor_158 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_469, 32, '0'); convert_element_type_469 = None + wait_tensor_187 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_158); all_gather_into_tensor_158 = None + permute_155 = torch.ops.aten.permute.default(wait_tensor_187, [1, 0]); wait_tensor_187 = None + mm_99 = torch.ops.aten.mm.default(view_1023, permute_155) + view_1031 = torch.ops.aten.view.default(mm_99, [2, 8192, 128]); mm_99 = None + view_1038 = torch.ops.aten.view.default(mm_100, [2, 8192, 128]); mm_100 = None + view_1040 = torch.ops.aten.view.default(view_1024, [2, 8192, -1, 128]); view_1024 = None + view_1041 = torch.ops.aten.view.default(view_1031, [2, 8192, -1, 128]); view_1031 = None + view_1042 = torch.ops.aten.view.default(view_1038, [2, 8192, -1, 128]); view_1038 = None + convert_element_type_475 = torch.ops.prims.convert_element_type.default(view_1040, torch.float32); view_1040 = None + view_1043 = torch.ops.aten.view.default(convert_element_type_475, [2, 8192, 4, -1, 2]); convert_element_type_475 = None + view_as_complex_28 = torch.ops.aten.view_as_complex.default(view_1043); view_1043 = None + convert_element_type_476 = torch.ops.prims.convert_element_type.default(view_1041, torch.float32); view_1041 = None + view_1044 = torch.ops.aten.view.default(convert_element_type_476, [2, 8192, 1, -1, 2]); convert_element_type_476 = None + view_as_complex_29 = torch.ops.aten.view_as_complex.default(view_1044); view_1044 = None + mul_114 = torch.ops.aten.mul.Tensor(view_as_complex_28, view_37); view_as_complex_28 = None + view_as_real_28 = torch.ops.aten.view_as_real.default(mul_114); mul_114 = None + view_1046 = torch.ops.aten.view.default(view_as_real_28, [2, 8192, 4, 128]); view_as_real_28 = 
None + mul_115 = torch.ops.aten.mul.Tensor(view_as_complex_29, view_37); view_as_complex_29 = None + view_as_real_29 = torch.ops.aten.view_as_real.default(mul_115); mul_115 = None + view_1047 = torch.ops.aten.view.default(view_as_real_29, [2, 8192, 1, 128]); view_as_real_29 = None + convert_element_type_477 = torch.ops.prims.convert_element_type.default(view_1046, torch.bfloat16); view_1046 = None + convert_element_type_478 = torch.ops.prims.convert_element_type.default(view_1047, torch.bfloat16); view_1047 = None + unsqueeze_28 = torch.ops.aten.unsqueeze.default(convert_element_type_478, 3); convert_element_type_478 = None + expand_28 = torch.ops.aten.expand.default(unsqueeze_28, [2, 8192, 1, 4, 128]); unsqueeze_28 = None + view_1048 = torch.ops.aten.view.default(expand_28, [2, 8192, 4, 128]); expand_28 = None + unsqueeze_29 = torch.ops.aten.unsqueeze.default(view_1042, 3); view_1042 = None + expand_29 = torch.ops.aten.expand.default(unsqueeze_29, [2, 8192, 1, 4, 128]); unsqueeze_29 = None + view_1049 = torch.ops.aten.view.default(expand_29, [2, 8192, 4, 128]); expand_29 = None + permute_157 = torch.ops.aten.permute.default(convert_element_type_477, [0, 2, 1, 3]); convert_element_type_477 = None + permute_158 = torch.ops.aten.permute.default(view_1048, [0, 2, 1, 3]); view_1048 = None + permute_159 = torch.ops.aten.permute.default(view_1049, [0, 2, 1, 3]); view_1049 = None + _scaled_dot_product_cudnn_attention_backward_17 = torch.ops.aten._scaled_dot_product_cudnn_attention_backward.default(permute_917, permute_157, permute_158, permute_159, getitem_654, getitem_655, getitem_660, getitem_661, None, None, None, 8192, 8192, 0.0, True); permute_917 = permute_157 = permute_158 = permute_159 = getitem_654 = getitem_655 = getitem_660 = getitem_661 = None + getitem_2019 = _scaled_dot_product_cudnn_attention_backward_17[0] + getitem_2020 = _scaled_dot_product_cudnn_attention_backward_17[1] + getitem_2021 = _scaled_dot_product_cudnn_attention_backward_17[2]; 
_scaled_dot_product_cudnn_attention_backward_17 = None + permute_918 = torch.ops.aten.permute.default(getitem_2021, [0, 2, 1, 3]); getitem_2021 = None + permute_919 = torch.ops.aten.permute.default(getitem_2020, [0, 2, 1, 3]); getitem_2020 = None + permute_920 = torch.ops.aten.permute.default(getitem_2019, [0, 2, 1, 3]); getitem_2019 = None + view_2740 = torch.ops.aten.view.default(permute_918, [2, 8192, 1, 4, 128]); permute_918 = None + sum_107 = torch.ops.aten.sum.dim_IntList(view_2740, [3], True); view_2740 = None + squeeze_34 = torch.ops.aten.squeeze.dim(sum_107, 3); sum_107 = None + view_2741 = torch.ops.aten.view.default(permute_919, [2, 8192, 1, 4, 128]); permute_919 = None + sum_108 = torch.ops.aten.sum.dim_IntList(view_2741, [3], True); view_2741 = None + squeeze_35 = torch.ops.aten.squeeze.dim(sum_108, 3); sum_108 = None + convert_element_type_2021 = torch.ops.prims.convert_element_type.default(squeeze_35, torch.float32); squeeze_35 = None + convert_element_type_2022 = torch.ops.prims.convert_element_type.default(permute_920, torch.float32); permute_920 = None + view_2742 = torch.ops.aten.view.default(convert_element_type_2021, [2, 8192, 1, 64, 2]); convert_element_type_2021 = None + view_as_complex_98 = torch.ops.aten.view_as_complex.default(view_2742); view_2742 = None + mul_616 = torch.ops.aten.mul.Tensor(view_as_complex_98, _conj); view_as_complex_98 = None + view_2743 = torch.ops.aten.view.default(convert_element_type_2022, [2, 8192, 4, 64, 2]); convert_element_type_2022 = None + view_as_complex_99 = torch.ops.aten.view_as_complex.default(view_2743); view_2743 = None + mul_617 = torch.ops.aten.mul.Tensor(view_as_complex_99, _conj); view_as_complex_99 = None + view_as_real_98 = torch.ops.aten.view_as_real.default(mul_616); mul_616 = None + view_2744 = torch.ops.aten.view.default(view_as_real_98, [2, 8192, 1, 128]); view_as_real_98 = None + convert_element_type_2023 = torch.ops.prims.convert_element_type.default(view_2744, torch.bfloat16); view_2744 = 
None + view_as_real_99 = torch.ops.aten.view_as_real.default(mul_617); mul_617 = None + view_2745 = torch.ops.aten.view.default(view_as_real_99, [2, 8192, 4, 128]); view_as_real_99 = None + convert_element_type_2024 = torch.ops.prims.convert_element_type.default(view_2745, torch.bfloat16); view_2745 = None + view_2746 = torch.ops.aten.view.default(squeeze_34, [2, 8192, 128]); squeeze_34 = None + view_2747 = torch.ops.aten.view.default(convert_element_type_2023, [2, 8192, 128]); convert_element_type_2023 = None + view_2748 = torch.ops.aten.view.default(convert_element_type_2024, [2, 8192, 512]); convert_element_type_2024 = None + view_2749 = torch.ops.aten.view.default(view_2746, [16384, 128]); view_2746 = None + permute_921 = torch.ops.aten.permute.default(view_2749, [1, 0]) + mm_473 = torch.ops.aten.mm.default(permute_921, view_1023); permute_921 = None + convert_element_type_472 = torch.ops.prims.convert_element_type.default(primals_133, torch.bfloat16); primals_133 = None + all_gather_into_tensor_159 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_472, 32, '0'); convert_element_type_472 = None + wait_tensor_188 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_159); all_gather_into_tensor_159 = None + permute_156 = torch.ops.aten.permute.default(wait_tensor_188, [1, 0]); wait_tensor_188 = None + permute_923 = torch.ops.aten.permute.default(permute_156, [1, 0]); permute_156 = None + mm_474 = torch.ops.aten.mm.default(view_2749, permute_923); view_2749 = permute_923 = None + view_2750 = torch.ops.aten.view.default(mm_474, [2, 8192, 4096]); mm_474 = None + convert_element_type_2029 = torch.ops.prims.convert_element_type.default(mm_473, torch.float32); mm_473 = None + reduce_scatter_tensor_261 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2029, 'avg', 32, '0'); convert_element_type_2029 = None + wait_tensor_689 = 
torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_261); reduce_scatter_tensor_261 = None + view_2751 = torch.ops.aten.view.default(view_2747, [16384, 128]); view_2747 = None + permute_925 = torch.ops.aten.permute.default(view_2751, [1, 0]) + mm_475 = torch.ops.aten.mm.default(permute_925, view_1023); permute_925 = None + permute_927 = torch.ops.aten.permute.default(permute_155, [1, 0]); permute_155 = None + mm_476 = torch.ops.aten.mm.default(view_2751, permute_927); view_2751 = permute_927 = None + view_2752 = torch.ops.aten.view.default(mm_476, [2, 8192, 4096]); mm_476 = None + add_252 = torch.ops.aten.add.Tensor(view_2750, view_2752); view_2750 = view_2752 = None + convert_element_type_2034 = torch.ops.prims.convert_element_type.default(mm_475, torch.float32); mm_475 = None + reduce_scatter_tensor_262 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2034, 'avg', 32, '0'); convert_element_type_2034 = None + wait_tensor_690 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_262); reduce_scatter_tensor_262 = None + view_2753 = torch.ops.aten.view.default(view_2748, [16384, 512]); view_2748 = None + permute_929 = torch.ops.aten.permute.default(view_2753, [1, 0]) + mm_477 = torch.ops.aten.mm.default(permute_929, view_1023); permute_929 = view_1023 = None + convert_element_type_466 = torch.ops.prims.convert_element_type.default(primals_131, torch.bfloat16); primals_131 = None + all_gather_into_tensor_157 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_466, 32, '0'); convert_element_type_466 = None + wait_tensor_186 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_157); all_gather_into_tensor_157 = None + permute_154 = torch.ops.aten.permute.default(wait_tensor_186, [1, 0]); wait_tensor_186 = None + permute_931 = torch.ops.aten.permute.default(permute_154, [1, 0]); permute_154 = None + mm_478 = torch.ops.aten.mm.default(view_2753, permute_931); 
view_2753 = permute_931 = None + view_2754 = torch.ops.aten.view.default(mm_478, [2, 8192, 4096]); mm_478 = None + add_253 = torch.ops.aten.add.Tensor(add_252, view_2754); add_252 = view_2754 = None + convert_element_type_2039 = torch.ops.prims.convert_element_type.default(mm_477, torch.float32); mm_477 = None + reduce_scatter_tensor_263 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2039, 'avg', 32, '0'); convert_element_type_2039 = None + wait_tensor_691 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_263); reduce_scatter_tensor_263 = None + split_210 = torch.ops.aten.split.Tensor(add_253, 1024, 1); add_253 = None + getitem_2022 = split_210[0] + getitem_2023 = split_210[1] + getitem_2024 = split_210[2] + getitem_2025 = split_210[3] + getitem_2026 = split_210[4] + getitem_2027 = split_210[5] + getitem_2028 = split_210[6] + getitem_2029 = split_210[7]; split_210 = None + cat_202 = torch.ops.aten.cat.default([getitem_2022, getitem_2023, getitem_2024, getitem_2025, getitem_2026, getitem_2027, getitem_2028, getitem_2029]); getitem_2022 = getitem_2023 = getitem_2024 = getitem_2025 = getitem_2026 = getitem_2027 = getitem_2028 = getitem_2029 = None + reduce_scatter_tensor_264 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_202, 'sum', 8, '1'); cat_202 = None + wait_tensor_692 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_264); reduce_scatter_tensor_264 = None + convert_element_type_2040 = torch.ops.prims.convert_element_type.default(wait_tensor_692, torch.float32); wait_tensor_692 = None + convert_element_type_2042 = torch.ops.prims.convert_element_type.default(wait_tensor_184, torch.float32); wait_tensor_184 = None + mul_618 = torch.ops.aten.mul.Tensor(convert_element_type_2040, convert_element_type_2042); convert_element_type_2042 = None + mul_620 = torch.ops.aten.mul.Tensor(mul_112, mul_618) + sum_109 = torch.ops.aten.sum.dim_IntList(mul_620, [2], True); mul_620 = None + 
div_36 = torch.ops.aten.div.Tensor(mul_112, 4096) + mul_621 = torch.ops.aten.mul.Tensor(div_36, sum_109); div_36 = sum_109 = None + sub_55 = torch.ops.aten.sub.Tensor(mul_618, mul_621); mul_618 = mul_621 = None + mul_622 = torch.ops.aten.mul.Tensor(sub_55, rsqrt_28); sub_55 = rsqrt_28 = None + mul_623 = torch.ops.aten.mul.Tensor(convert_element_type_2040, mul_112); convert_element_type_2040 = mul_112 = None + sum_110 = torch.ops.aten.sum.dim_IntList(mul_623, [0, 1]); mul_623 = None + convert_element_type_2043 = torch.ops.prims.convert_element_type.default(mul_622, torch.bfloat16); mul_622 = None + convert_element_type_2044 = torch.ops.prims.convert_element_type.default(sum_110, torch.bfloat16); sum_110 = None + all_reduce_36 = torch.ops._c10d_functional.all_reduce.default(convert_element_type_2044, 'sum', '1'); convert_element_type_2044 = None + wait_tensor_693 = torch.ops._c10d_functional.wait_tensor.default(all_reduce_36); all_reduce_36 = None + convert_element_type_2045 = torch.ops.prims.convert_element_type.default(wait_tensor_693, torch.float32); wait_tensor_693 = None + reduce_scatter_tensor_265 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2045, 'avg', 32, '0'); convert_element_type_2045 = None + wait_tensor_694 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_265); reduce_scatter_tensor_265 = None + add_254 = torch.ops.aten.add.Tensor(add_251, convert_element_type_2043); add_251 = convert_element_type_2043 = None + all_gather_into_tensor_392 = torch.ops._c10d_functional.all_gather_into_tensor.default(add_254, 8, '1') + wait_tensor_695 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_392); all_gather_into_tensor_392 = None + split_211 = torch.ops.aten.split.Tensor(wait_tensor_695, 2); wait_tensor_695 = None + getitem_2030 = split_211[0] + getitem_2031 = split_211[1] + getitem_2032 = split_211[2] + getitem_2033 = split_211[3] + getitem_2034 = split_211[4] + getitem_2035 = 
split_211[5] + getitem_2036 = split_211[6] + getitem_2037 = split_211[7]; split_211 = None + cat_203 = torch.ops.aten.cat.default([getitem_2030, getitem_2031, getitem_2032, getitem_2033, getitem_2034, getitem_2035, getitem_2036, getitem_2037], 1); getitem_2030 = getitem_2031 = getitem_2032 = getitem_2033 = getitem_2034 = getitem_2035 = getitem_2036 = getitem_2037 = None + view_2755 = torch.ops.aten.view.default(cat_203, [16384, 4096]); cat_203 = None + permute_933 = torch.ops.aten.permute.default(view_2755, [1, 0]) + wait_tensor_177 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_27); reduce_scatter_tensor_27 = None + add_53 = torch.ops.aten.add.Tensor(add_51, wait_tensor_177); wait_tensor_177 = None + convert_element_type_449 = torch.ops.prims.convert_element_type.default(primals_126, torch.bfloat16); primals_126 = None + all_gather_into_tensor_150 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_449, 32, '0'); convert_element_type_449 = None + wait_tensor_178 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_150); all_gather_into_tensor_150 = None + convert_element_type_450 = torch.ops.prims.convert_element_type.default(add_53, torch.float32); add_53 = None + pow_28 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_450, 2) + mean_27 = torch.ops.aten.mean.dim(pow_28, [2], True); pow_28 = None + add_54 = torch.ops.aten.add.Scalar(mean_27, 1e-05); mean_27 = None + rsqrt_27 = torch.ops.aten.rsqrt.default(add_54); add_54 = None + mul_108 = torch.ops.aten.mul.Tensor(convert_element_type_450, rsqrt_27); convert_element_type_450 = None + mul_109 = torch.ops.aten.mul.Tensor(mul_108, wait_tensor_178) + convert_element_type_451 = torch.ops.prims.convert_element_type.default(mul_109, torch.bfloat16); mul_109 = None + all_gather_into_tensor_151 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_451, 8, '1'); convert_element_type_451 = None + wait_tensor_179 = 
torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_151); all_gather_into_tensor_151 = None + split_63 = torch.ops.aten.split.Tensor(wait_tensor_179, 2); wait_tensor_179 = None + getitem_630 = split_63[0] + getitem_631 = split_63[1] + getitem_632 = split_63[2] + getitem_633 = split_63[3] + getitem_634 = split_63[4] + getitem_635 = split_63[5] + getitem_636 = split_63[6] + getitem_637 = split_63[7]; split_63 = None + cat_55 = torch.ops.aten.cat.default([getitem_630, getitem_631, getitem_632, getitem_633, getitem_634, getitem_635, getitem_636, getitem_637], 1); getitem_630 = getitem_631 = getitem_632 = getitem_633 = getitem_634 = getitem_635 = getitem_636 = getitem_637 = None + view_996 = torch.ops.aten.view.default(cat_55, [16384, 4096]); cat_55 = None + view_997 = torch.ops.aten.view.default(mm_95, [2, 8192, 1792]); mm_95 = None + convert_element_type_455 = torch.ops.prims.convert_element_type.default(view_997, torch.float32); view_997 = None + sigmoid_13 = torch.ops.aten.sigmoid.default(convert_element_type_455) + mul_110 = torch.ops.aten.mul.Tensor(convert_element_type_455, sigmoid_13); sigmoid_13 = None + convert_element_type_456 = torch.ops.prims.convert_element_type.default(mul_110, torch.bfloat16); mul_110 = None + convert_element_type_457 = torch.ops.prims.convert_element_type.default(primals_128, torch.bfloat16); primals_128 = None + all_gather_into_tensor_153 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_457, 32, '0'); convert_element_type_457 = None + wait_tensor_181 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_153); all_gather_into_tensor_153 = None + permute_152 = torch.ops.aten.permute.default(wait_tensor_181, [1, 0]); wait_tensor_181 = None + mm_96 = torch.ops.aten.mm.default(view_996, permute_152) + view_1004 = torch.ops.aten.view.default(mm_96, [2, 8192, 1792]); mm_96 = None + mul_111 = torch.ops.aten.mul.Tensor(convert_element_type_456, view_1004) + view_1011 = 
torch.ops.aten.view.default(mul_111, [16384, 1792]); mul_111 = None + mm_479 = torch.ops.aten.mm.default(permute_933, view_1011); permute_933 = view_1011 = None + convert_element_type_460 = torch.ops.prims.convert_element_type.default(primals_129, torch.bfloat16); primals_129 = None + all_gather_into_tensor_154 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_460, 32, '0'); convert_element_type_460 = None + wait_tensor_182 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_154); all_gather_into_tensor_154 = None + permute_153 = torch.ops.aten.permute.default(wait_tensor_182, [1, 0]); wait_tensor_182 = None + permute_935 = torch.ops.aten.permute.default(permute_153, [1, 0]); permute_153 = None + mm_480 = torch.ops.aten.mm.default(view_2755, permute_935); view_2755 = permute_935 = None + view_2756 = torch.ops.aten.view.default(mm_480, [2, 8192, 1792]); mm_480 = None + convert_element_type_2050 = torch.ops.prims.convert_element_type.default(mm_479, torch.float32); mm_479 = None + reduce_scatter_tensor_266 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2050, 'avg', 32, '0'); convert_element_type_2050 = None + wait_tensor_696 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_266); reduce_scatter_tensor_266 = None + mul_624 = torch.ops.aten.mul.Tensor(view_2756, convert_element_type_456); convert_element_type_456 = None + mul_625 = torch.ops.aten.mul.Tensor(view_2756, view_1004); view_2756 = view_1004 = None + view_2757 = torch.ops.aten.view.default(mul_624, [16384, 1792]); mul_624 = None + permute_937 = torch.ops.aten.permute.default(view_2757, [1, 0]) + mm_481 = torch.ops.aten.mm.default(permute_937, view_996); permute_937 = None + permute_939 = torch.ops.aten.permute.default(permute_152, [1, 0]); permute_152 = None + mm_482 = torch.ops.aten.mm.default(view_2757, permute_939); view_2757 = permute_939 = None + view_2758 = torch.ops.aten.view.default(mm_482, [2, 
8192, 4096]); mm_482 = None + convert_element_type_2055 = torch.ops.prims.convert_element_type.default(mm_481, torch.float32); mm_481 = None + reduce_scatter_tensor_267 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2055, 'avg', 32, '0'); convert_element_type_2055 = None + wait_tensor_697 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_267); reduce_scatter_tensor_267 = None + convert_element_type_2056 = torch.ops.prims.convert_element_type.default(mul_625, torch.float32); mul_625 = None + neg_18 = torch.ops.aten.neg.default(convert_element_type_455) + exp_18 = torch.ops.aten.exp.default(neg_18); neg_18 = None + add_255 = torch.ops.aten.add.Tensor(exp_18, 1); exp_18 = None + reciprocal_18 = torch.ops.aten.reciprocal.default(add_255); add_255 = None + mul_626 = torch.ops.aten.mul.Tensor(reciprocal_18, 1); reciprocal_18 = None + mul_627 = torch.ops.aten.mul.Tensor(convert_element_type_2056, mul_626); convert_element_type_2056 = None + sub_56 = torch.ops.aten.sub.Tensor(1, mul_626); mul_626 = None + mul_628 = torch.ops.aten.mul.Tensor(convert_element_type_455, sub_56); convert_element_type_455 = sub_56 = None + add_256 = torch.ops.aten.add.Tensor(mul_628, 1); mul_628 = None + mul_629 = torch.ops.aten.mul.Tensor(mul_627, add_256); mul_627 = add_256 = None + convert_element_type_2058 = torch.ops.prims.convert_element_type.default(mul_629, torch.bfloat16); mul_629 = None + view_2759 = torch.ops.aten.view.default(convert_element_type_2058, [16384, 1792]); convert_element_type_2058 = None + permute_941 = torch.ops.aten.permute.default(view_2759, [1, 0]) + mm_483 = torch.ops.aten.mm.default(permute_941, view_996); permute_941 = view_996 = None + convert_element_type_452 = torch.ops.prims.convert_element_type.default(primals_127, torch.bfloat16); primals_127 = None + all_gather_into_tensor_152 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_452, 32, '0'); convert_element_type_452 = None + 
wait_tensor_180 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_152); all_gather_into_tensor_152 = None + permute_151 = torch.ops.aten.permute.default(wait_tensor_180, [1, 0]); wait_tensor_180 = None + permute_943 = torch.ops.aten.permute.default(permute_151, [1, 0]); permute_151 = None + mm_484 = torch.ops.aten.mm.default(view_2759, permute_943); view_2759 = permute_943 = None + view_2760 = torch.ops.aten.view.default(mm_484, [2, 8192, 4096]); mm_484 = None + add_257 = torch.ops.aten.add.Tensor(view_2758, view_2760); view_2758 = view_2760 = None + convert_element_type_2063 = torch.ops.prims.convert_element_type.default(mm_483, torch.float32); mm_483 = None + reduce_scatter_tensor_268 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2063, 'avg', 32, '0'); convert_element_type_2063 = None + wait_tensor_698 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_268); reduce_scatter_tensor_268 = None + split_212 = torch.ops.aten.split.Tensor(add_257, 1024, 1); add_257 = None + getitem_2038 = split_212[0] + getitem_2039 = split_212[1] + getitem_2040 = split_212[2] + getitem_2041 = split_212[3] + getitem_2042 = split_212[4] + getitem_2043 = split_212[5] + getitem_2044 = split_212[6] + getitem_2045 = split_212[7]; split_212 = None + cat_204 = torch.ops.aten.cat.default([getitem_2038, getitem_2039, getitem_2040, getitem_2041, getitem_2042, getitem_2043, getitem_2044, getitem_2045]); getitem_2038 = getitem_2039 = getitem_2040 = getitem_2041 = getitem_2042 = getitem_2043 = getitem_2044 = getitem_2045 = None + reduce_scatter_tensor_269 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_204, 'sum', 8, '1'); cat_204 = None + wait_tensor_699 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_269); reduce_scatter_tensor_269 = None + convert_element_type_2064 = torch.ops.prims.convert_element_type.default(wait_tensor_699, torch.float32); wait_tensor_699 = None + 
convert_element_type_2066 = torch.ops.prims.convert_element_type.default(wait_tensor_178, torch.float32); wait_tensor_178 = None + mul_630 = torch.ops.aten.mul.Tensor(convert_element_type_2064, convert_element_type_2066); convert_element_type_2066 = None + mul_632 = torch.ops.aten.mul.Tensor(mul_108, mul_630) + sum_111 = torch.ops.aten.sum.dim_IntList(mul_632, [2], True); mul_632 = None + div_37 = torch.ops.aten.div.Tensor(mul_108, 4096) + mul_633 = torch.ops.aten.mul.Tensor(div_37, sum_111); div_37 = sum_111 = None + sub_57 = torch.ops.aten.sub.Tensor(mul_630, mul_633); mul_630 = mul_633 = None + mul_634 = torch.ops.aten.mul.Tensor(sub_57, rsqrt_27); sub_57 = rsqrt_27 = None + mul_635 = torch.ops.aten.mul.Tensor(convert_element_type_2064, mul_108); convert_element_type_2064 = mul_108 = None + sum_112 = torch.ops.aten.sum.dim_IntList(mul_635, [0, 1]); mul_635 = None + convert_element_type_2067 = torch.ops.prims.convert_element_type.default(mul_634, torch.bfloat16); mul_634 = None + convert_element_type_2068 = torch.ops.prims.convert_element_type.default(sum_112, torch.bfloat16); sum_112 = None + all_reduce_37 = torch.ops._c10d_functional.all_reduce.default(convert_element_type_2068, 'sum', '1'); convert_element_type_2068 = None + wait_tensor_700 = torch.ops._c10d_functional.wait_tensor.default(all_reduce_37); all_reduce_37 = None + convert_element_type_2069 = torch.ops.prims.convert_element_type.default(wait_tensor_700, torch.float32); wait_tensor_700 = None + reduce_scatter_tensor_270 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2069, 'avg', 32, '0'); convert_element_type_2069 = None + wait_tensor_701 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_270); reduce_scatter_tensor_270 = None + add_258 = torch.ops.aten.add.Tensor(add_254, convert_element_type_2067); add_254 = convert_element_type_2067 = None + all_gather_into_tensor_393 = torch.ops._c10d_functional.all_gather_into_tensor.default(add_258, 8, '1') 
+ wait_tensor_702 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_393); all_gather_into_tensor_393 = None + split_213 = torch.ops.aten.split.Tensor(wait_tensor_702, 2); wait_tensor_702 = None + getitem_2046 = split_213[0] + getitem_2047 = split_213[1] + getitem_2048 = split_213[2] + getitem_2049 = split_213[3] + getitem_2050 = split_213[4] + getitem_2051 = split_213[5] + getitem_2052 = split_213[6] + getitem_2053 = split_213[7]; split_213 = None + cat_205 = torch.ops.aten.cat.default([getitem_2046, getitem_2047, getitem_2048, getitem_2049, getitem_2050, getitem_2051, getitem_2052, getitem_2053], 1); getitem_2046 = getitem_2047 = getitem_2048 = getitem_2049 = getitem_2050 = getitem_2051 = getitem_2052 = getitem_2053 = None + view_2761 = torch.ops.aten.view.default(cat_205, [16384, 4096]); cat_205 = None + permute_945 = torch.ops.aten.permute.default(view_2761, [1, 0]) + permute_149 = torch.ops.aten.permute.default(getitem_613, [0, 2, 1, 3]) + view_978 = torch.ops.aten.view.default(permute_149, [2, 8192, -1]); permute_149 = None + view_984 = torch.ops.aten.view.default(view_978, [16384, 512]); view_978 = None + mm_485 = torch.ops.aten.mm.default(permute_945, view_984); permute_945 = view_984 = None + convert_element_type_446 = torch.ops.prims.convert_element_type.default(primals_125, torch.bfloat16); primals_125 = None + all_gather_into_tensor_149 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_446, 32, '0'); convert_element_type_446 = None + wait_tensor_176 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_149); all_gather_into_tensor_149 = None + permute_150 = torch.ops.aten.permute.default(wait_tensor_176, [1, 0]); wait_tensor_176 = None + permute_947 = torch.ops.aten.permute.default(permute_150, [1, 0]); permute_150 = None + mm_486 = torch.ops.aten.mm.default(view_2761, permute_947); view_2761 = permute_947 = None + view_2762 = torch.ops.aten.view.default(mm_486, [2, 8192, 512]); 
mm_486 = None + convert_element_type_2074 = torch.ops.prims.convert_element_type.default(mm_485, torch.float32); mm_485 = None + reduce_scatter_tensor_271 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2074, 'avg', 32, '0'); convert_element_type_2074 = None + wait_tensor_703 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_271); reduce_scatter_tensor_271 = None + view_2763 = torch.ops.aten.view.default(view_2762, [2, 8192, 4, 128]); view_2762 = None + permute_949 = torch.ops.aten.permute.default(view_2763, [0, 2, 1, 3]); view_2763 = None + convert_element_type_430 = torch.ops.prims.convert_element_type.default(primals_121, torch.bfloat16); primals_121 = None + all_gather_into_tensor_144 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_430, 32, '0'); convert_element_type_430 = None + wait_tensor_171 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_144); all_gather_into_tensor_144 = None + convert_element_type_431 = torch.ops.prims.convert_element_type.default(add_51, torch.float32); add_51 = None + pow_27 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_431, 2) + mean_26 = torch.ops.aten.mean.dim(pow_27, [2], True); pow_27 = None + add_52 = torch.ops.aten.add.Scalar(mean_26, 1e-05); mean_26 = None + rsqrt_26 = torch.ops.aten.rsqrt.default(add_52); add_52 = None + mul_104 = torch.ops.aten.mul.Tensor(convert_element_type_431, rsqrt_26); convert_element_type_431 = None + mul_105 = torch.ops.aten.mul.Tensor(mul_104, wait_tensor_171) + convert_element_type_432 = torch.ops.prims.convert_element_type.default(mul_105, torch.bfloat16); mul_105 = None + all_gather_into_tensor_145 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_432, 8, '1'); convert_element_type_432 = None + wait_tensor_172 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_145); all_gather_into_tensor_145 = None + split_61 = 
torch.ops.aten.split.Tensor(wait_tensor_172, 2); wait_tensor_172 = None + getitem_605 = split_61[0] + getitem_606 = split_61[1] + getitem_607 = split_61[2] + getitem_608 = split_61[3] + getitem_609 = split_61[4] + getitem_610 = split_61[5] + getitem_611 = split_61[6] + getitem_612 = split_61[7]; split_61 = None + cat_53 = torch.ops.aten.cat.default([getitem_605, getitem_606, getitem_607, getitem_608, getitem_609, getitem_610, getitem_611, getitem_612], 1); getitem_605 = getitem_606 = getitem_607 = getitem_608 = getitem_609 = getitem_610 = getitem_611 = getitem_612 = None + view_951 = torch.ops.aten.view.default(cat_53, [16384, 4096]); cat_53 = None + view_952 = torch.ops.aten.view.default(mm_91, [2, 8192, 512]); mm_91 = None + convert_element_type_436 = torch.ops.prims.convert_element_type.default(primals_123, torch.bfloat16); primals_123 = None + all_gather_into_tensor_147 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_436, 32, '0'); convert_element_type_436 = None + wait_tensor_174 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_147); all_gather_into_tensor_147 = None + permute_144 = torch.ops.aten.permute.default(wait_tensor_174, [1, 0]); wait_tensor_174 = None + mm_92 = torch.ops.aten.mm.default(view_951, permute_144) + view_959 = torch.ops.aten.view.default(mm_92, [2, 8192, 128]); mm_92 = None + view_966 = torch.ops.aten.view.default(mm_93, [2, 8192, 128]); mm_93 = None + view_968 = torch.ops.aten.view.default(view_952, [2, 8192, -1, 128]); view_952 = None + view_969 = torch.ops.aten.view.default(view_959, [2, 8192, -1, 128]); view_959 = None + view_970 = torch.ops.aten.view.default(view_966, [2, 8192, -1, 128]); view_966 = None + convert_element_type_442 = torch.ops.prims.convert_element_type.default(view_968, torch.float32); view_968 = None + view_971 = torch.ops.aten.view.default(convert_element_type_442, [2, 8192, 4, -1, 2]); convert_element_type_442 = None + view_as_complex_26 = 
torch.ops.aten.view_as_complex.default(view_971); view_971 = None + convert_element_type_443 = torch.ops.prims.convert_element_type.default(view_969, torch.float32); view_969 = None + view_972 = torch.ops.aten.view.default(convert_element_type_443, [2, 8192, 1, -1, 2]); convert_element_type_443 = None + view_as_complex_27 = torch.ops.aten.view_as_complex.default(view_972); view_972 = None + mul_106 = torch.ops.aten.mul.Tensor(view_as_complex_26, view_37); view_as_complex_26 = None + view_as_real_26 = torch.ops.aten.view_as_real.default(mul_106); mul_106 = None + view_974 = torch.ops.aten.view.default(view_as_real_26, [2, 8192, 4, 128]); view_as_real_26 = None + mul_107 = torch.ops.aten.mul.Tensor(view_as_complex_27, view_37); view_as_complex_27 = None + view_as_real_27 = torch.ops.aten.view_as_real.default(mul_107); mul_107 = None + view_975 = torch.ops.aten.view.default(view_as_real_27, [2, 8192, 1, 128]); view_as_real_27 = None + convert_element_type_444 = torch.ops.prims.convert_element_type.default(view_974, torch.bfloat16); view_974 = None + convert_element_type_445 = torch.ops.prims.convert_element_type.default(view_975, torch.bfloat16); view_975 = None + unsqueeze_26 = torch.ops.aten.unsqueeze.default(convert_element_type_445, 3); convert_element_type_445 = None + expand_26 = torch.ops.aten.expand.default(unsqueeze_26, [2, 8192, 1, 4, 128]); unsqueeze_26 = None + view_976 = torch.ops.aten.view.default(expand_26, [2, 8192, 4, 128]); expand_26 = None + unsqueeze_27 = torch.ops.aten.unsqueeze.default(view_970, 3); view_970 = None + expand_27 = torch.ops.aten.expand.default(unsqueeze_27, [2, 8192, 1, 4, 128]); unsqueeze_27 = None + view_977 = torch.ops.aten.view.default(expand_27, [2, 8192, 4, 128]); expand_27 = None + permute_146 = torch.ops.aten.permute.default(convert_element_type_444, [0, 2, 1, 3]); convert_element_type_444 = None + permute_147 = torch.ops.aten.permute.default(view_976, [0, 2, 1, 3]); view_976 = None + permute_148 = 
torch.ops.aten.permute.default(view_977, [0, 2, 1, 3]); view_977 = None + _scaled_dot_product_cudnn_attention_backward_18 = torch.ops.aten._scaled_dot_product_cudnn_attention_backward.default(permute_949, permute_146, permute_147, permute_148, getitem_613, getitem_614, getitem_619, getitem_620, None, None, None, 8192, 8192, 0.0, True); permute_949 = permute_146 = permute_147 = permute_148 = getitem_613 = getitem_614 = getitem_619 = getitem_620 = None + getitem_2054 = _scaled_dot_product_cudnn_attention_backward_18[0] + getitem_2055 = _scaled_dot_product_cudnn_attention_backward_18[1] + getitem_2056 = _scaled_dot_product_cudnn_attention_backward_18[2]; _scaled_dot_product_cudnn_attention_backward_18 = None + permute_950 = torch.ops.aten.permute.default(getitem_2056, [0, 2, 1, 3]); getitem_2056 = None + permute_951 = torch.ops.aten.permute.default(getitem_2055, [0, 2, 1, 3]); getitem_2055 = None + permute_952 = torch.ops.aten.permute.default(getitem_2054, [0, 2, 1, 3]); getitem_2054 = None + view_2764 = torch.ops.aten.view.default(permute_950, [2, 8192, 1, 4, 128]); permute_950 = None + sum_113 = torch.ops.aten.sum.dim_IntList(view_2764, [3], True); view_2764 = None + squeeze_36 = torch.ops.aten.squeeze.dim(sum_113, 3); sum_113 = None + view_2765 = torch.ops.aten.view.default(permute_951, [2, 8192, 1, 4, 128]); permute_951 = None + sum_114 = torch.ops.aten.sum.dim_IntList(view_2765, [3], True); view_2765 = None + squeeze_37 = torch.ops.aten.squeeze.dim(sum_114, 3); sum_114 = None + convert_element_type_2075 = torch.ops.prims.convert_element_type.default(squeeze_37, torch.float32); squeeze_37 = None + convert_element_type_2076 = torch.ops.prims.convert_element_type.default(permute_952, torch.float32); permute_952 = None + view_2766 = torch.ops.aten.view.default(convert_element_type_2075, [2, 8192, 1, 64, 2]); convert_element_type_2075 = None + view_as_complex_100 = torch.ops.aten.view_as_complex.default(view_2766); view_2766 = None + mul_636 = 
torch.ops.aten.mul.Tensor(view_as_complex_100, _conj); view_as_complex_100 = None + view_2767 = torch.ops.aten.view.default(convert_element_type_2076, [2, 8192, 4, 64, 2]); convert_element_type_2076 = None + view_as_complex_101 = torch.ops.aten.view_as_complex.default(view_2767); view_2767 = None + mul_637 = torch.ops.aten.mul.Tensor(view_as_complex_101, _conj); view_as_complex_101 = None + view_as_real_100 = torch.ops.aten.view_as_real.default(mul_636); mul_636 = None + view_2768 = torch.ops.aten.view.default(view_as_real_100, [2, 8192, 1, 128]); view_as_real_100 = None + convert_element_type_2077 = torch.ops.prims.convert_element_type.default(view_2768, torch.bfloat16); view_2768 = None + view_as_real_101 = torch.ops.aten.view_as_real.default(mul_637); mul_637 = None + view_2769 = torch.ops.aten.view.default(view_as_real_101, [2, 8192, 4, 128]); view_as_real_101 = None + convert_element_type_2078 = torch.ops.prims.convert_element_type.default(view_2769, torch.bfloat16); view_2769 = None + view_2770 = torch.ops.aten.view.default(squeeze_36, [2, 8192, 128]); squeeze_36 = None + view_2771 = torch.ops.aten.view.default(convert_element_type_2077, [2, 8192, 128]); convert_element_type_2077 = None + view_2772 = torch.ops.aten.view.default(convert_element_type_2078, [2, 8192, 512]); convert_element_type_2078 = None + view_2773 = torch.ops.aten.view.default(view_2770, [16384, 128]); view_2770 = None + permute_953 = torch.ops.aten.permute.default(view_2773, [1, 0]) + mm_487 = torch.ops.aten.mm.default(permute_953, view_951); permute_953 = None + convert_element_type_439 = torch.ops.prims.convert_element_type.default(primals_124, torch.bfloat16); primals_124 = None + all_gather_into_tensor_148 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_439, 32, '0'); convert_element_type_439 = None + wait_tensor_175 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_148); all_gather_into_tensor_148 = None + permute_145 = 
torch.ops.aten.permute.default(wait_tensor_175, [1, 0]); wait_tensor_175 = None + permute_955 = torch.ops.aten.permute.default(permute_145, [1, 0]); permute_145 = None + mm_488 = torch.ops.aten.mm.default(view_2773, permute_955); view_2773 = permute_955 = None + view_2774 = torch.ops.aten.view.default(mm_488, [2, 8192, 4096]); mm_488 = None + convert_element_type_2083 = torch.ops.prims.convert_element_type.default(mm_487, torch.float32); mm_487 = None + reduce_scatter_tensor_272 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2083, 'avg', 32, '0'); convert_element_type_2083 = None + wait_tensor_704 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_272); reduce_scatter_tensor_272 = None + view_2775 = torch.ops.aten.view.default(view_2771, [16384, 128]); view_2771 = None + permute_957 = torch.ops.aten.permute.default(view_2775, [1, 0]) + mm_489 = torch.ops.aten.mm.default(permute_957, view_951); permute_957 = None + permute_959 = torch.ops.aten.permute.default(permute_144, [1, 0]); permute_144 = None + mm_490 = torch.ops.aten.mm.default(view_2775, permute_959); view_2775 = permute_959 = None + view_2776 = torch.ops.aten.view.default(mm_490, [2, 8192, 4096]); mm_490 = None + add_259 = torch.ops.aten.add.Tensor(view_2774, view_2776); view_2774 = view_2776 = None + convert_element_type_2088 = torch.ops.prims.convert_element_type.default(mm_489, torch.float32); mm_489 = None + reduce_scatter_tensor_273 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2088, 'avg', 32, '0'); convert_element_type_2088 = None + wait_tensor_705 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_273); reduce_scatter_tensor_273 = None + view_2777 = torch.ops.aten.view.default(view_2772, [16384, 512]); view_2772 = None + permute_961 = torch.ops.aten.permute.default(view_2777, [1, 0]) + mm_491 = torch.ops.aten.mm.default(permute_961, view_951); permute_961 = view_951 = None + 
convert_element_type_433 = torch.ops.prims.convert_element_type.default(primals_122, torch.bfloat16); primals_122 = None + all_gather_into_tensor_146 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_433, 32, '0'); convert_element_type_433 = None + wait_tensor_173 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_146); all_gather_into_tensor_146 = None + permute_143 = torch.ops.aten.permute.default(wait_tensor_173, [1, 0]); wait_tensor_173 = None + permute_963 = torch.ops.aten.permute.default(permute_143, [1, 0]); permute_143 = None + mm_492 = torch.ops.aten.mm.default(view_2777, permute_963); view_2777 = permute_963 = None + view_2778 = torch.ops.aten.view.default(mm_492, [2, 8192, 4096]); mm_492 = None + add_260 = torch.ops.aten.add.Tensor(add_259, view_2778); add_259 = view_2778 = None + convert_element_type_2093 = torch.ops.prims.convert_element_type.default(mm_491, torch.float32); mm_491 = None + reduce_scatter_tensor_274 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2093, 'avg', 32, '0'); convert_element_type_2093 = None + wait_tensor_706 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_274); reduce_scatter_tensor_274 = None + split_214 = torch.ops.aten.split.Tensor(add_260, 1024, 1); add_260 = None + getitem_2057 = split_214[0] + getitem_2058 = split_214[1] + getitem_2059 = split_214[2] + getitem_2060 = split_214[3] + getitem_2061 = split_214[4] + getitem_2062 = split_214[5] + getitem_2063 = split_214[6] + getitem_2064 = split_214[7]; split_214 = None + cat_206 = torch.ops.aten.cat.default([getitem_2057, getitem_2058, getitem_2059, getitem_2060, getitem_2061, getitem_2062, getitem_2063, getitem_2064]); getitem_2057 = getitem_2058 = getitem_2059 = getitem_2060 = getitem_2061 = getitem_2062 = getitem_2063 = getitem_2064 = None + reduce_scatter_tensor_275 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_206, 'sum', 8, '1'); cat_206 = None + 
wait_tensor_707 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_275); reduce_scatter_tensor_275 = None + convert_element_type_2094 = torch.ops.prims.convert_element_type.default(wait_tensor_707, torch.float32); wait_tensor_707 = None + convert_element_type_2096 = torch.ops.prims.convert_element_type.default(wait_tensor_171, torch.float32); wait_tensor_171 = None + mul_638 = torch.ops.aten.mul.Tensor(convert_element_type_2094, convert_element_type_2096); convert_element_type_2096 = None + mul_640 = torch.ops.aten.mul.Tensor(mul_104, mul_638) + sum_115 = torch.ops.aten.sum.dim_IntList(mul_640, [2], True); mul_640 = None + div_38 = torch.ops.aten.div.Tensor(mul_104, 4096) + mul_641 = torch.ops.aten.mul.Tensor(div_38, sum_115); div_38 = sum_115 = None + sub_58 = torch.ops.aten.sub.Tensor(mul_638, mul_641); mul_638 = mul_641 = None + mul_642 = torch.ops.aten.mul.Tensor(sub_58, rsqrt_26); sub_58 = rsqrt_26 = None + mul_643 = torch.ops.aten.mul.Tensor(convert_element_type_2094, mul_104); convert_element_type_2094 = mul_104 = None + sum_116 = torch.ops.aten.sum.dim_IntList(mul_643, [0, 1]); mul_643 = None + convert_element_type_2097 = torch.ops.prims.convert_element_type.default(mul_642, torch.bfloat16); mul_642 = None + convert_element_type_2098 = torch.ops.prims.convert_element_type.default(sum_116, torch.bfloat16); sum_116 = None + all_reduce_38 = torch.ops._c10d_functional.all_reduce.default(convert_element_type_2098, 'sum', '1'); convert_element_type_2098 = None + wait_tensor_708 = torch.ops._c10d_functional.wait_tensor.default(all_reduce_38); all_reduce_38 = None + convert_element_type_2099 = torch.ops.prims.convert_element_type.default(wait_tensor_708, torch.float32); wait_tensor_708 = None + reduce_scatter_tensor_276 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2099, 'avg', 32, '0'); convert_element_type_2099 = None + wait_tensor_709 = 
torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_276); reduce_scatter_tensor_276 = None + add_261 = torch.ops.aten.add.Tensor(add_258, convert_element_type_2097); add_258 = convert_element_type_2097 = None + all_gather_into_tensor_394 = torch.ops._c10d_functional.all_gather_into_tensor.default(add_261, 8, '1') + wait_tensor_710 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_394); all_gather_into_tensor_394 = None + split_215 = torch.ops.aten.split.Tensor(wait_tensor_710, 2); wait_tensor_710 = None + getitem_2065 = split_215[0] + getitem_2066 = split_215[1] + getitem_2067 = split_215[2] + getitem_2068 = split_215[3] + getitem_2069 = split_215[4] + getitem_2070 = split_215[5] + getitem_2071 = split_215[6] + getitem_2072 = split_215[7]; split_215 = None + cat_207 = torch.ops.aten.cat.default([getitem_2065, getitem_2066, getitem_2067, getitem_2068, getitem_2069, getitem_2070, getitem_2071, getitem_2072], 1); getitem_2065 = getitem_2066 = getitem_2067 = getitem_2068 = getitem_2069 = getitem_2070 = getitem_2071 = getitem_2072 = None + view_2779 = torch.ops.aten.view.default(cat_207, [16384, 4096]); cat_207 = None + permute_965 = torch.ops.aten.permute.default(view_2779, [1, 0]) + wait_tensor_164 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_25); reduce_scatter_tensor_25 = None + add_49 = torch.ops.aten.add.Tensor(add_47, wait_tensor_164); wait_tensor_164 = None + convert_element_type_416 = torch.ops.prims.convert_element_type.default(primals_117, torch.bfloat16); primals_117 = None + all_gather_into_tensor_139 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_416, 32, '0'); convert_element_type_416 = None + wait_tensor_165 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_139); all_gather_into_tensor_139 = None + convert_element_type_417 = torch.ops.prims.convert_element_type.default(add_49, torch.float32); add_49 = None + pow_26 = 
torch.ops.aten.pow.Tensor_Scalar(convert_element_type_417, 2) + mean_25 = torch.ops.aten.mean.dim(pow_26, [2], True); pow_26 = None + add_50 = torch.ops.aten.add.Scalar(mean_25, 1e-05); mean_25 = None + rsqrt_25 = torch.ops.aten.rsqrt.default(add_50); add_50 = None + mul_100 = torch.ops.aten.mul.Tensor(convert_element_type_417, rsqrt_25); convert_element_type_417 = None + mul_101 = torch.ops.aten.mul.Tensor(mul_100, wait_tensor_165) + convert_element_type_418 = torch.ops.prims.convert_element_type.default(mul_101, torch.bfloat16); mul_101 = None + all_gather_into_tensor_140 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_418, 8, '1'); convert_element_type_418 = None + wait_tensor_166 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_140); all_gather_into_tensor_140 = None + split_59 = torch.ops.aten.split.Tensor(wait_tensor_166, 2); wait_tensor_166 = None + getitem_589 = split_59[0] + getitem_590 = split_59[1] + getitem_591 = split_59[2] + getitem_592 = split_59[3] + getitem_593 = split_59[4] + getitem_594 = split_59[5] + getitem_595 = split_59[6] + getitem_596 = split_59[7]; split_59 = None + cat_51 = torch.ops.aten.cat.default([getitem_589, getitem_590, getitem_591, getitem_592, getitem_593, getitem_594, getitem_595, getitem_596], 1); getitem_589 = getitem_590 = getitem_591 = getitem_592 = getitem_593 = getitem_594 = getitem_595 = getitem_596 = None + view_924 = torch.ops.aten.view.default(cat_51, [16384, 4096]); cat_51 = None + view_925 = torch.ops.aten.view.default(mm_88, [2, 8192, 1792]); mm_88 = None + convert_element_type_422 = torch.ops.prims.convert_element_type.default(view_925, torch.float32); view_925 = None + sigmoid_12 = torch.ops.aten.sigmoid.default(convert_element_type_422) + mul_102 = torch.ops.aten.mul.Tensor(convert_element_type_422, sigmoid_12); sigmoid_12 = None + convert_element_type_423 = torch.ops.prims.convert_element_type.default(mul_102, torch.bfloat16); mul_102 = None + 
convert_element_type_424 = torch.ops.prims.convert_element_type.default(primals_119, torch.bfloat16); primals_119 = None + all_gather_into_tensor_142 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_424, 32, '0'); convert_element_type_424 = None + wait_tensor_168 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_142); all_gather_into_tensor_142 = None + permute_141 = torch.ops.aten.permute.default(wait_tensor_168, [1, 0]); wait_tensor_168 = None + mm_89 = torch.ops.aten.mm.default(view_924, permute_141) + view_932 = torch.ops.aten.view.default(mm_89, [2, 8192, 1792]); mm_89 = None + mul_103 = torch.ops.aten.mul.Tensor(convert_element_type_423, view_932) + view_939 = torch.ops.aten.view.default(mul_103, [16384, 1792]); mul_103 = None + mm_493 = torch.ops.aten.mm.default(permute_965, view_939); permute_965 = view_939 = None + convert_element_type_427 = torch.ops.prims.convert_element_type.default(primals_120, torch.bfloat16); primals_120 = None + all_gather_into_tensor_143 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_427, 32, '0'); convert_element_type_427 = None + wait_tensor_169 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_143); all_gather_into_tensor_143 = None + permute_142 = torch.ops.aten.permute.default(wait_tensor_169, [1, 0]); wait_tensor_169 = None + permute_967 = torch.ops.aten.permute.default(permute_142, [1, 0]); permute_142 = None + mm_494 = torch.ops.aten.mm.default(view_2779, permute_967); view_2779 = permute_967 = None + view_2780 = torch.ops.aten.view.default(mm_494, [2, 8192, 1792]); mm_494 = None + convert_element_type_2104 = torch.ops.prims.convert_element_type.default(mm_493, torch.float32); mm_493 = None + reduce_scatter_tensor_277 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2104, 'avg', 32, '0'); convert_element_type_2104 = None + wait_tensor_711 = 
torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_277); reduce_scatter_tensor_277 = None + mul_644 = torch.ops.aten.mul.Tensor(view_2780, convert_element_type_423); convert_element_type_423 = None + mul_645 = torch.ops.aten.mul.Tensor(view_2780, view_932); view_2780 = view_932 = None + view_2781 = torch.ops.aten.view.default(mul_644, [16384, 1792]); mul_644 = None + permute_969 = torch.ops.aten.permute.default(view_2781, [1, 0]) + mm_495 = torch.ops.aten.mm.default(permute_969, view_924); permute_969 = None + permute_971 = torch.ops.aten.permute.default(permute_141, [1, 0]); permute_141 = None + mm_496 = torch.ops.aten.mm.default(view_2781, permute_971); view_2781 = permute_971 = None + view_2782 = torch.ops.aten.view.default(mm_496, [2, 8192, 4096]); mm_496 = None + convert_element_type_2109 = torch.ops.prims.convert_element_type.default(mm_495, torch.float32); mm_495 = None + reduce_scatter_tensor_278 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2109, 'avg', 32, '0'); convert_element_type_2109 = None + wait_tensor_712 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_278); reduce_scatter_tensor_278 = None + convert_element_type_2110 = torch.ops.prims.convert_element_type.default(mul_645, torch.float32); mul_645 = None + neg_19 = torch.ops.aten.neg.default(convert_element_type_422) + exp_19 = torch.ops.aten.exp.default(neg_19); neg_19 = None + add_262 = torch.ops.aten.add.Tensor(exp_19, 1); exp_19 = None + reciprocal_19 = torch.ops.aten.reciprocal.default(add_262); add_262 = None + mul_646 = torch.ops.aten.mul.Tensor(reciprocal_19, 1); reciprocal_19 = None + mul_647 = torch.ops.aten.mul.Tensor(convert_element_type_2110, mul_646); convert_element_type_2110 = None + sub_59 = torch.ops.aten.sub.Tensor(1, mul_646); mul_646 = None + mul_648 = torch.ops.aten.mul.Tensor(convert_element_type_422, sub_59); convert_element_type_422 = sub_59 = None + add_263 = torch.ops.aten.add.Tensor(mul_648, 1); 
mul_648 = None + mul_649 = torch.ops.aten.mul.Tensor(mul_647, add_263); mul_647 = add_263 = None + convert_element_type_2112 = torch.ops.prims.convert_element_type.default(mul_649, torch.bfloat16); mul_649 = None + view_2783 = torch.ops.aten.view.default(convert_element_type_2112, [16384, 1792]); convert_element_type_2112 = None + permute_973 = torch.ops.aten.permute.default(view_2783, [1, 0]) + mm_497 = torch.ops.aten.mm.default(permute_973, view_924); permute_973 = view_924 = None + convert_element_type_419 = torch.ops.prims.convert_element_type.default(primals_118, torch.bfloat16); primals_118 = None + all_gather_into_tensor_141 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_419, 32, '0'); convert_element_type_419 = None + wait_tensor_167 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_141); all_gather_into_tensor_141 = None + permute_140 = torch.ops.aten.permute.default(wait_tensor_167, [1, 0]); wait_tensor_167 = None + permute_975 = torch.ops.aten.permute.default(permute_140, [1, 0]); permute_140 = None + mm_498 = torch.ops.aten.mm.default(view_2783, permute_975); view_2783 = permute_975 = None + view_2784 = torch.ops.aten.view.default(mm_498, [2, 8192, 4096]); mm_498 = None + add_264 = torch.ops.aten.add.Tensor(view_2782, view_2784); view_2782 = view_2784 = None + convert_element_type_2117 = torch.ops.prims.convert_element_type.default(mm_497, torch.float32); mm_497 = None + reduce_scatter_tensor_279 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2117, 'avg', 32, '0'); convert_element_type_2117 = None + wait_tensor_713 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_279); reduce_scatter_tensor_279 = None + split_216 = torch.ops.aten.split.Tensor(add_264, 1024, 1); add_264 = None + getitem_2073 = split_216[0] + getitem_2074 = split_216[1] + getitem_2075 = split_216[2] + getitem_2076 = split_216[3] + getitem_2077 = split_216[4] + getitem_2078 = 
split_216[5] + getitem_2079 = split_216[6] + getitem_2080 = split_216[7]; split_216 = None + cat_208 = torch.ops.aten.cat.default([getitem_2073, getitem_2074, getitem_2075, getitem_2076, getitem_2077, getitem_2078, getitem_2079, getitem_2080]); getitem_2073 = getitem_2074 = getitem_2075 = getitem_2076 = getitem_2077 = getitem_2078 = getitem_2079 = getitem_2080 = None + reduce_scatter_tensor_280 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_208, 'sum', 8, '1'); cat_208 = None + wait_tensor_714 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_280); reduce_scatter_tensor_280 = None + convert_element_type_2118 = torch.ops.prims.convert_element_type.default(wait_tensor_714, torch.float32); wait_tensor_714 = None + convert_element_type_2120 = torch.ops.prims.convert_element_type.default(wait_tensor_165, torch.float32); wait_tensor_165 = None + mul_650 = torch.ops.aten.mul.Tensor(convert_element_type_2118, convert_element_type_2120); convert_element_type_2120 = None + mul_652 = torch.ops.aten.mul.Tensor(mul_100, mul_650) + sum_117 = torch.ops.aten.sum.dim_IntList(mul_652, [2], True); mul_652 = None + div_39 = torch.ops.aten.div.Tensor(mul_100, 4096) + mul_653 = torch.ops.aten.mul.Tensor(div_39, sum_117); div_39 = sum_117 = None + sub_60 = torch.ops.aten.sub.Tensor(mul_650, mul_653); mul_650 = mul_653 = None + mul_654 = torch.ops.aten.mul.Tensor(sub_60, rsqrt_25); sub_60 = rsqrt_25 = None + mul_655 = torch.ops.aten.mul.Tensor(convert_element_type_2118, mul_100); convert_element_type_2118 = mul_100 = None + sum_118 = torch.ops.aten.sum.dim_IntList(mul_655, [0, 1]); mul_655 = None + convert_element_type_2121 = torch.ops.prims.convert_element_type.default(mul_654, torch.bfloat16); mul_654 = None + convert_element_type_2122 = torch.ops.prims.convert_element_type.default(sum_118, torch.bfloat16); sum_118 = None + all_reduce_39 = torch.ops._c10d_functional.all_reduce.default(convert_element_type_2122, 'sum', '1'); convert_element_type_2122 
= None + wait_tensor_715 = torch.ops._c10d_functional.wait_tensor.default(all_reduce_39); all_reduce_39 = None + convert_element_type_2123 = torch.ops.prims.convert_element_type.default(wait_tensor_715, torch.float32); wait_tensor_715 = None + reduce_scatter_tensor_281 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2123, 'avg', 32, '0'); convert_element_type_2123 = None + wait_tensor_716 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_281); reduce_scatter_tensor_281 = None + add_265 = torch.ops.aten.add.Tensor(add_261, convert_element_type_2121); add_261 = convert_element_type_2121 = None + all_gather_into_tensor_395 = torch.ops._c10d_functional.all_gather_into_tensor.default(add_265, 8, '1') + wait_tensor_717 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_395); all_gather_into_tensor_395 = None + split_217 = torch.ops.aten.split.Tensor(wait_tensor_717, 2); wait_tensor_717 = None + getitem_2081 = split_217[0] + getitem_2082 = split_217[1] + getitem_2083 = split_217[2] + getitem_2084 = split_217[3] + getitem_2085 = split_217[4] + getitem_2086 = split_217[5] + getitem_2087 = split_217[6] + getitem_2088 = split_217[7]; split_217 = None + cat_209 = torch.ops.aten.cat.default([getitem_2081, getitem_2082, getitem_2083, getitem_2084, getitem_2085, getitem_2086, getitem_2087, getitem_2088], 1); getitem_2081 = getitem_2082 = getitem_2083 = getitem_2084 = getitem_2085 = getitem_2086 = getitem_2087 = getitem_2088 = None + view_2785 = torch.ops.aten.view.default(cat_209, [16384, 4096]); cat_209 = None + permute_977 = torch.ops.aten.permute.default(view_2785, [1, 0]) + permute_138 = torch.ops.aten.permute.default(getitem_572, [0, 2, 1, 3]) + view_906 = torch.ops.aten.view.default(permute_138, [2, 8192, -1]); permute_138 = None + view_912 = torch.ops.aten.view.default(view_906, [16384, 512]); view_906 = None + mm_499 = torch.ops.aten.mm.default(permute_977, view_912); permute_977 = view_912 = None + 
convert_element_type_413 = torch.ops.prims.convert_element_type.default(primals_116, torch.bfloat16); primals_116 = None + all_gather_into_tensor_138 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_413, 32, '0'); convert_element_type_413 = None + wait_tensor_163 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_138); all_gather_into_tensor_138 = None + permute_139 = torch.ops.aten.permute.default(wait_tensor_163, [1, 0]); wait_tensor_163 = None + permute_979 = torch.ops.aten.permute.default(permute_139, [1, 0]); permute_139 = None + mm_500 = torch.ops.aten.mm.default(view_2785, permute_979); view_2785 = permute_979 = None + view_2786 = torch.ops.aten.view.default(mm_500, [2, 8192, 512]); mm_500 = None + convert_element_type_2128 = torch.ops.prims.convert_element_type.default(mm_499, torch.float32); mm_499 = None + reduce_scatter_tensor_282 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2128, 'avg', 32, '0'); convert_element_type_2128 = None + wait_tensor_718 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_282); reduce_scatter_tensor_282 = None + view_2787 = torch.ops.aten.view.default(view_2786, [2, 8192, 4, 128]); view_2786 = None + permute_981 = torch.ops.aten.permute.default(view_2787, [0, 2, 1, 3]); view_2787 = None + convert_element_type_397 = torch.ops.prims.convert_element_type.default(primals_112, torch.bfloat16); primals_112 = None + all_gather_into_tensor_133 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_397, 32, '0'); convert_element_type_397 = None + wait_tensor_158 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_133); all_gather_into_tensor_133 = None + convert_element_type_398 = torch.ops.prims.convert_element_type.default(add_47, torch.float32); add_47 = None + pow_25 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_398, 2) + mean_24 = torch.ops.aten.mean.dim(pow_25, [2], 
True); pow_25 = None + add_48 = torch.ops.aten.add.Scalar(mean_24, 1e-05); mean_24 = None + rsqrt_24 = torch.ops.aten.rsqrt.default(add_48); add_48 = None + mul_96 = torch.ops.aten.mul.Tensor(convert_element_type_398, rsqrt_24); convert_element_type_398 = None + mul_97 = torch.ops.aten.mul.Tensor(mul_96, wait_tensor_158) + convert_element_type_399 = torch.ops.prims.convert_element_type.default(mul_97, torch.bfloat16); mul_97 = None + all_gather_into_tensor_134 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_399, 8, '1'); convert_element_type_399 = None + wait_tensor_159 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_134); all_gather_into_tensor_134 = None + split_57 = torch.ops.aten.split.Tensor(wait_tensor_159, 2); wait_tensor_159 = None + getitem_564 = split_57[0] + getitem_565 = split_57[1] + getitem_566 = split_57[2] + getitem_567 = split_57[3] + getitem_568 = split_57[4] + getitem_569 = split_57[5] + getitem_570 = split_57[6] + getitem_571 = split_57[7]; split_57 = None + cat_49 = torch.ops.aten.cat.default([getitem_564, getitem_565, getitem_566, getitem_567, getitem_568, getitem_569, getitem_570, getitem_571], 1); getitem_564 = getitem_565 = getitem_566 = getitem_567 = getitem_568 = getitem_569 = getitem_570 = getitem_571 = None + view_879 = torch.ops.aten.view.default(cat_49, [16384, 4096]); cat_49 = None + view_880 = torch.ops.aten.view.default(mm_84, [2, 8192, 512]); mm_84 = None + convert_element_type_403 = torch.ops.prims.convert_element_type.default(primals_114, torch.bfloat16); primals_114 = None + all_gather_into_tensor_136 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_403, 32, '0'); convert_element_type_403 = None + wait_tensor_161 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_136); all_gather_into_tensor_136 = None + permute_133 = torch.ops.aten.permute.default(wait_tensor_161, [1, 0]); wait_tensor_161 = None + mm_85 = 
torch.ops.aten.mm.default(view_879, permute_133) + view_887 = torch.ops.aten.view.default(mm_85, [2, 8192, 128]); mm_85 = None + view_894 = torch.ops.aten.view.default(mm_86, [2, 8192, 128]); mm_86 = None + view_896 = torch.ops.aten.view.default(view_880, [2, 8192, -1, 128]); view_880 = None + view_897 = torch.ops.aten.view.default(view_887, [2, 8192, -1, 128]); view_887 = None + view_898 = torch.ops.aten.view.default(view_894, [2, 8192, -1, 128]); view_894 = None + convert_element_type_409 = torch.ops.prims.convert_element_type.default(view_896, torch.float32); view_896 = None + view_899 = torch.ops.aten.view.default(convert_element_type_409, [2, 8192, 4, -1, 2]); convert_element_type_409 = None + view_as_complex_24 = torch.ops.aten.view_as_complex.default(view_899); view_899 = None + convert_element_type_410 = torch.ops.prims.convert_element_type.default(view_897, torch.float32); view_897 = None + view_900 = torch.ops.aten.view.default(convert_element_type_410, [2, 8192, 1, -1, 2]); convert_element_type_410 = None + view_as_complex_25 = torch.ops.aten.view_as_complex.default(view_900); view_900 = None + mul_98 = torch.ops.aten.mul.Tensor(view_as_complex_24, view_37); view_as_complex_24 = None + view_as_real_24 = torch.ops.aten.view_as_real.default(mul_98); mul_98 = None + view_902 = torch.ops.aten.view.default(view_as_real_24, [2, 8192, 4, 128]); view_as_real_24 = None + mul_99 = torch.ops.aten.mul.Tensor(view_as_complex_25, view_37); view_as_complex_25 = None + view_as_real_25 = torch.ops.aten.view_as_real.default(mul_99); mul_99 = None + view_903 = torch.ops.aten.view.default(view_as_real_25, [2, 8192, 1, 128]); view_as_real_25 = None + convert_element_type_411 = torch.ops.prims.convert_element_type.default(view_902, torch.bfloat16); view_902 = None + convert_element_type_412 = torch.ops.prims.convert_element_type.default(view_903, torch.bfloat16); view_903 = None + unsqueeze_24 = torch.ops.aten.unsqueeze.default(convert_element_type_412, 3); 
convert_element_type_412 = None + expand_24 = torch.ops.aten.expand.default(unsqueeze_24, [2, 8192, 1, 4, 128]); unsqueeze_24 = None + view_904 = torch.ops.aten.view.default(expand_24, [2, 8192, 4, 128]); expand_24 = None + unsqueeze_25 = torch.ops.aten.unsqueeze.default(view_898, 3); view_898 = None + expand_25 = torch.ops.aten.expand.default(unsqueeze_25, [2, 8192, 1, 4, 128]); unsqueeze_25 = None + view_905 = torch.ops.aten.view.default(expand_25, [2, 8192, 4, 128]); expand_25 = None + permute_135 = torch.ops.aten.permute.default(convert_element_type_411, [0, 2, 1, 3]); convert_element_type_411 = None + permute_136 = torch.ops.aten.permute.default(view_904, [0, 2, 1, 3]); view_904 = None + permute_137 = torch.ops.aten.permute.default(view_905, [0, 2, 1, 3]); view_905 = None + _scaled_dot_product_cudnn_attention_backward_19 = torch.ops.aten._scaled_dot_product_cudnn_attention_backward.default(permute_981, permute_135, permute_136, permute_137, getitem_572, getitem_573, getitem_578, getitem_579, None, None, None, 8192, 8192, 0.0, True); permute_981 = permute_135 = permute_136 = permute_137 = getitem_572 = getitem_573 = getitem_578 = getitem_579 = None + getitem_2089 = _scaled_dot_product_cudnn_attention_backward_19[0] + getitem_2090 = _scaled_dot_product_cudnn_attention_backward_19[1] + getitem_2091 = _scaled_dot_product_cudnn_attention_backward_19[2]; _scaled_dot_product_cudnn_attention_backward_19 = None + permute_982 = torch.ops.aten.permute.default(getitem_2091, [0, 2, 1, 3]); getitem_2091 = None + permute_983 = torch.ops.aten.permute.default(getitem_2090, [0, 2, 1, 3]); getitem_2090 = None + permute_984 = torch.ops.aten.permute.default(getitem_2089, [0, 2, 1, 3]); getitem_2089 = None + view_2788 = torch.ops.aten.view.default(permute_982, [2, 8192, 1, 4, 128]); permute_982 = None + sum_119 = torch.ops.aten.sum.dim_IntList(view_2788, [3], True); view_2788 = None + squeeze_38 = torch.ops.aten.squeeze.dim(sum_119, 3); sum_119 = None + view_2789 = 
torch.ops.aten.view.default(permute_983, [2, 8192, 1, 4, 128]); permute_983 = None + sum_120 = torch.ops.aten.sum.dim_IntList(view_2789, [3], True); view_2789 = None + squeeze_39 = torch.ops.aten.squeeze.dim(sum_120, 3); sum_120 = None + convert_element_type_2129 = torch.ops.prims.convert_element_type.default(squeeze_39, torch.float32); squeeze_39 = None + convert_element_type_2130 = torch.ops.prims.convert_element_type.default(permute_984, torch.float32); permute_984 = None + view_2790 = torch.ops.aten.view.default(convert_element_type_2129, [2, 8192, 1, 64, 2]); convert_element_type_2129 = None + view_as_complex_102 = torch.ops.aten.view_as_complex.default(view_2790); view_2790 = None + mul_656 = torch.ops.aten.mul.Tensor(view_as_complex_102, _conj); view_as_complex_102 = None + view_2791 = torch.ops.aten.view.default(convert_element_type_2130, [2, 8192, 4, 64, 2]); convert_element_type_2130 = None + view_as_complex_103 = torch.ops.aten.view_as_complex.default(view_2791); view_2791 = None + mul_657 = torch.ops.aten.mul.Tensor(view_as_complex_103, _conj); view_as_complex_103 = None + view_as_real_102 = torch.ops.aten.view_as_real.default(mul_656); mul_656 = None + view_2792 = torch.ops.aten.view.default(view_as_real_102, [2, 8192, 1, 128]); view_as_real_102 = None + convert_element_type_2131 = torch.ops.prims.convert_element_type.default(view_2792, torch.bfloat16); view_2792 = None + view_as_real_103 = torch.ops.aten.view_as_real.default(mul_657); mul_657 = None + view_2793 = torch.ops.aten.view.default(view_as_real_103, [2, 8192, 4, 128]); view_as_real_103 = None + convert_element_type_2132 = torch.ops.prims.convert_element_type.default(view_2793, torch.bfloat16); view_2793 = None + view_2794 = torch.ops.aten.view.default(squeeze_38, [2, 8192, 128]); squeeze_38 = None + view_2795 = torch.ops.aten.view.default(convert_element_type_2131, [2, 8192, 128]); convert_element_type_2131 = None + view_2796 = torch.ops.aten.view.default(convert_element_type_2132, [2, 8192, 
512]); convert_element_type_2132 = None + view_2797 = torch.ops.aten.view.default(view_2794, [16384, 128]); view_2794 = None + permute_985 = torch.ops.aten.permute.default(view_2797, [1, 0]) + mm_501 = torch.ops.aten.mm.default(permute_985, view_879); permute_985 = None + convert_element_type_406 = torch.ops.prims.convert_element_type.default(primals_115, torch.bfloat16); primals_115 = None + all_gather_into_tensor_137 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_406, 32, '0'); convert_element_type_406 = None + wait_tensor_162 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_137); all_gather_into_tensor_137 = None + permute_134 = torch.ops.aten.permute.default(wait_tensor_162, [1, 0]); wait_tensor_162 = None + permute_987 = torch.ops.aten.permute.default(permute_134, [1, 0]); permute_134 = None + mm_502 = torch.ops.aten.mm.default(view_2797, permute_987); view_2797 = permute_987 = None + view_2798 = torch.ops.aten.view.default(mm_502, [2, 8192, 4096]); mm_502 = None + convert_element_type_2137 = torch.ops.prims.convert_element_type.default(mm_501, torch.float32); mm_501 = None + reduce_scatter_tensor_283 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2137, 'avg', 32, '0'); convert_element_type_2137 = None + wait_tensor_719 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_283); reduce_scatter_tensor_283 = None + view_2799 = torch.ops.aten.view.default(view_2795, [16384, 128]); view_2795 = None + permute_989 = torch.ops.aten.permute.default(view_2799, [1, 0]) + mm_503 = torch.ops.aten.mm.default(permute_989, view_879); permute_989 = None + permute_991 = torch.ops.aten.permute.default(permute_133, [1, 0]); permute_133 = None + mm_504 = torch.ops.aten.mm.default(view_2799, permute_991); view_2799 = permute_991 = None + view_2800 = torch.ops.aten.view.default(mm_504, [2, 8192, 4096]); mm_504 = None + add_266 = torch.ops.aten.add.Tensor(view_2798, view_2800); 
view_2798 = view_2800 = None + convert_element_type_2142 = torch.ops.prims.convert_element_type.default(mm_503, torch.float32); mm_503 = None + reduce_scatter_tensor_284 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2142, 'avg', 32, '0'); convert_element_type_2142 = None + wait_tensor_720 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_284); reduce_scatter_tensor_284 = None + view_2801 = torch.ops.aten.view.default(view_2796, [16384, 512]); view_2796 = None + permute_993 = torch.ops.aten.permute.default(view_2801, [1, 0]) + mm_505 = torch.ops.aten.mm.default(permute_993, view_879); permute_993 = view_879 = None + convert_element_type_400 = torch.ops.prims.convert_element_type.default(primals_113, torch.bfloat16); primals_113 = None + all_gather_into_tensor_135 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_400, 32, '0'); convert_element_type_400 = None + wait_tensor_160 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_135); all_gather_into_tensor_135 = None + permute_132 = torch.ops.aten.permute.default(wait_tensor_160, [1, 0]); wait_tensor_160 = None + permute_995 = torch.ops.aten.permute.default(permute_132, [1, 0]); permute_132 = None + mm_506 = torch.ops.aten.mm.default(view_2801, permute_995); view_2801 = permute_995 = None + view_2802 = torch.ops.aten.view.default(mm_506, [2, 8192, 4096]); mm_506 = None + add_267 = torch.ops.aten.add.Tensor(add_266, view_2802); add_266 = view_2802 = None + convert_element_type_2147 = torch.ops.prims.convert_element_type.default(mm_505, torch.float32); mm_505 = None + reduce_scatter_tensor_285 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2147, 'avg', 32, '0'); convert_element_type_2147 = None + wait_tensor_721 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_285); reduce_scatter_tensor_285 = None + split_218 = torch.ops.aten.split.Tensor(add_267, 1024, 1); 
add_267 = None + getitem_2092 = split_218[0] + getitem_2093 = split_218[1] + getitem_2094 = split_218[2] + getitem_2095 = split_218[3] + getitem_2096 = split_218[4] + getitem_2097 = split_218[5] + getitem_2098 = split_218[6] + getitem_2099 = split_218[7]; split_218 = None + cat_210 = torch.ops.aten.cat.default([getitem_2092, getitem_2093, getitem_2094, getitem_2095, getitem_2096, getitem_2097, getitem_2098, getitem_2099]); getitem_2092 = getitem_2093 = getitem_2094 = getitem_2095 = getitem_2096 = getitem_2097 = getitem_2098 = getitem_2099 = None + reduce_scatter_tensor_286 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_210, 'sum', 8, '1'); cat_210 = None + wait_tensor_722 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_286); reduce_scatter_tensor_286 = None + convert_element_type_2148 = torch.ops.prims.convert_element_type.default(wait_tensor_722, torch.float32); wait_tensor_722 = None + convert_element_type_2150 = torch.ops.prims.convert_element_type.default(wait_tensor_158, torch.float32); wait_tensor_158 = None + mul_658 = torch.ops.aten.mul.Tensor(convert_element_type_2148, convert_element_type_2150); convert_element_type_2150 = None + mul_660 = torch.ops.aten.mul.Tensor(mul_96, mul_658) + sum_121 = torch.ops.aten.sum.dim_IntList(mul_660, [2], True); mul_660 = None + div_40 = torch.ops.aten.div.Tensor(mul_96, 4096) + mul_661 = torch.ops.aten.mul.Tensor(div_40, sum_121); div_40 = sum_121 = None + sub_61 = torch.ops.aten.sub.Tensor(mul_658, mul_661); mul_658 = mul_661 = None + mul_662 = torch.ops.aten.mul.Tensor(sub_61, rsqrt_24); sub_61 = rsqrt_24 = None + mul_663 = torch.ops.aten.mul.Tensor(convert_element_type_2148, mul_96); convert_element_type_2148 = mul_96 = None + sum_122 = torch.ops.aten.sum.dim_IntList(mul_663, [0, 1]); mul_663 = None + convert_element_type_2151 = torch.ops.prims.convert_element_type.default(mul_662, torch.bfloat16); mul_662 = None + convert_element_type_2152 = 
torch.ops.prims.convert_element_type.default(sum_122, torch.bfloat16); sum_122 = None + all_reduce_40 = torch.ops._c10d_functional.all_reduce.default(convert_element_type_2152, 'sum', '1'); convert_element_type_2152 = None + wait_tensor_723 = torch.ops._c10d_functional.wait_tensor.default(all_reduce_40); all_reduce_40 = None + convert_element_type_2153 = torch.ops.prims.convert_element_type.default(wait_tensor_723, torch.float32); wait_tensor_723 = None + reduce_scatter_tensor_287 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2153, 'avg', 32, '0'); convert_element_type_2153 = None + wait_tensor_724 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_287); reduce_scatter_tensor_287 = None + add_268 = torch.ops.aten.add.Tensor(add_265, convert_element_type_2151); add_265 = convert_element_type_2151 = None + all_gather_into_tensor_396 = torch.ops._c10d_functional.all_gather_into_tensor.default(add_268, 8, '1') + wait_tensor_725 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_396); all_gather_into_tensor_396 = None + split_219 = torch.ops.aten.split.Tensor(wait_tensor_725, 2); wait_tensor_725 = None + getitem_2100 = split_219[0] + getitem_2101 = split_219[1] + getitem_2102 = split_219[2] + getitem_2103 = split_219[3] + getitem_2104 = split_219[4] + getitem_2105 = split_219[5] + getitem_2106 = split_219[6] + getitem_2107 = split_219[7]; split_219 = None + cat_211 = torch.ops.aten.cat.default([getitem_2100, getitem_2101, getitem_2102, getitem_2103, getitem_2104, getitem_2105, getitem_2106, getitem_2107], 1); getitem_2100 = getitem_2101 = getitem_2102 = getitem_2103 = getitem_2104 = getitem_2105 = getitem_2106 = getitem_2107 = None + view_2803 = torch.ops.aten.view.default(cat_211, [16384, 4096]); cat_211 = None + permute_997 = torch.ops.aten.permute.default(view_2803, [1, 0]) + wait_tensor_151 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_23); reduce_scatter_tensor_23 = 
None + add_45 = torch.ops.aten.add.Tensor(add_43, wait_tensor_151); wait_tensor_151 = None + convert_element_type_383 = torch.ops.prims.convert_element_type.default(primals_108, torch.bfloat16); primals_108 = None + all_gather_into_tensor_128 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_383, 32, '0'); convert_element_type_383 = None + wait_tensor_152 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_128); all_gather_into_tensor_128 = None + convert_element_type_384 = torch.ops.prims.convert_element_type.default(add_45, torch.float32); add_45 = None + pow_24 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_384, 2) + mean_23 = torch.ops.aten.mean.dim(pow_24, [2], True); pow_24 = None + add_46 = torch.ops.aten.add.Scalar(mean_23, 1e-05); mean_23 = None + rsqrt_23 = torch.ops.aten.rsqrt.default(add_46); add_46 = None + mul_92 = torch.ops.aten.mul.Tensor(convert_element_type_384, rsqrt_23); convert_element_type_384 = None + mul_93 = torch.ops.aten.mul.Tensor(mul_92, wait_tensor_152) + convert_element_type_385 = torch.ops.prims.convert_element_type.default(mul_93, torch.bfloat16); mul_93 = None + all_gather_into_tensor_129 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_385, 8, '1'); convert_element_type_385 = None + wait_tensor_153 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_129); all_gather_into_tensor_129 = None + split_55 = torch.ops.aten.split.Tensor(wait_tensor_153, 2); wait_tensor_153 = None + getitem_548 = split_55[0] + getitem_549 = split_55[1] + getitem_550 = split_55[2] + getitem_551 = split_55[3] + getitem_552 = split_55[4] + getitem_553 = split_55[5] + getitem_554 = split_55[6] + getitem_555 = split_55[7]; split_55 = None + cat_47 = torch.ops.aten.cat.default([getitem_548, getitem_549, getitem_550, getitem_551, getitem_552, getitem_553, getitem_554, getitem_555], 1); getitem_548 = getitem_549 = getitem_550 = getitem_551 = getitem_552 
= getitem_553 = getitem_554 = getitem_555 = None + view_852 = torch.ops.aten.view.default(cat_47, [16384, 4096]); cat_47 = None + view_853 = torch.ops.aten.view.default(mm_81, [2, 8192, 1792]); mm_81 = None + convert_element_type_389 = torch.ops.prims.convert_element_type.default(view_853, torch.float32); view_853 = None + sigmoid_11 = torch.ops.aten.sigmoid.default(convert_element_type_389) + mul_94 = torch.ops.aten.mul.Tensor(convert_element_type_389, sigmoid_11); sigmoid_11 = None + convert_element_type_390 = torch.ops.prims.convert_element_type.default(mul_94, torch.bfloat16); mul_94 = None + convert_element_type_391 = torch.ops.prims.convert_element_type.default(primals_110, torch.bfloat16); primals_110 = None + all_gather_into_tensor_131 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_391, 32, '0'); convert_element_type_391 = None + wait_tensor_155 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_131); all_gather_into_tensor_131 = None + permute_130 = torch.ops.aten.permute.default(wait_tensor_155, [1, 0]); wait_tensor_155 = None + mm_82 = torch.ops.aten.mm.default(view_852, permute_130) + view_860 = torch.ops.aten.view.default(mm_82, [2, 8192, 1792]); mm_82 = None + mul_95 = torch.ops.aten.mul.Tensor(convert_element_type_390, view_860) + view_867 = torch.ops.aten.view.default(mul_95, [16384, 1792]); mul_95 = None + mm_507 = torch.ops.aten.mm.default(permute_997, view_867); permute_997 = view_867 = None + convert_element_type_394 = torch.ops.prims.convert_element_type.default(primals_111, torch.bfloat16); primals_111 = None + all_gather_into_tensor_132 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_394, 32, '0'); convert_element_type_394 = None + wait_tensor_156 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_132); all_gather_into_tensor_132 = None + permute_131 = torch.ops.aten.permute.default(wait_tensor_156, [1, 0]); wait_tensor_156 = None + 
permute_999 = torch.ops.aten.permute.default(permute_131, [1, 0]); permute_131 = None + mm_508 = torch.ops.aten.mm.default(view_2803, permute_999); view_2803 = permute_999 = None + view_2804 = torch.ops.aten.view.default(mm_508, [2, 8192, 1792]); mm_508 = None + convert_element_type_2158 = torch.ops.prims.convert_element_type.default(mm_507, torch.float32); mm_507 = None + reduce_scatter_tensor_288 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2158, 'avg', 32, '0'); convert_element_type_2158 = None + wait_tensor_726 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_288); reduce_scatter_tensor_288 = None + mul_664 = torch.ops.aten.mul.Tensor(view_2804, convert_element_type_390); convert_element_type_390 = None + mul_665 = torch.ops.aten.mul.Tensor(view_2804, view_860); view_2804 = view_860 = None + view_2805 = torch.ops.aten.view.default(mul_664, [16384, 1792]); mul_664 = None + permute_1001 = torch.ops.aten.permute.default(view_2805, [1, 0]) + mm_509 = torch.ops.aten.mm.default(permute_1001, view_852); permute_1001 = None + permute_1003 = torch.ops.aten.permute.default(permute_130, [1, 0]); permute_130 = None + mm_510 = torch.ops.aten.mm.default(view_2805, permute_1003); view_2805 = permute_1003 = None + view_2806 = torch.ops.aten.view.default(mm_510, [2, 8192, 4096]); mm_510 = None + convert_element_type_2163 = torch.ops.prims.convert_element_type.default(mm_509, torch.float32); mm_509 = None + reduce_scatter_tensor_289 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2163, 'avg', 32, '0'); convert_element_type_2163 = None + wait_tensor_727 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_289); reduce_scatter_tensor_289 = None + convert_element_type_2164 = torch.ops.prims.convert_element_type.default(mul_665, torch.float32); mul_665 = None + neg_20 = torch.ops.aten.neg.default(convert_element_type_389) + exp_20 = torch.ops.aten.exp.default(neg_20); neg_20 = 
None + add_269 = torch.ops.aten.add.Tensor(exp_20, 1); exp_20 = None + reciprocal_20 = torch.ops.aten.reciprocal.default(add_269); add_269 = None + mul_666 = torch.ops.aten.mul.Tensor(reciprocal_20, 1); reciprocal_20 = None + mul_667 = torch.ops.aten.mul.Tensor(convert_element_type_2164, mul_666); convert_element_type_2164 = None + sub_62 = torch.ops.aten.sub.Tensor(1, mul_666); mul_666 = None + mul_668 = torch.ops.aten.mul.Tensor(convert_element_type_389, sub_62); convert_element_type_389 = sub_62 = None + add_270 = torch.ops.aten.add.Tensor(mul_668, 1); mul_668 = None + mul_669 = torch.ops.aten.mul.Tensor(mul_667, add_270); mul_667 = add_270 = None + convert_element_type_2166 = torch.ops.prims.convert_element_type.default(mul_669, torch.bfloat16); mul_669 = None + view_2807 = torch.ops.aten.view.default(convert_element_type_2166, [16384, 1792]); convert_element_type_2166 = None + permute_1005 = torch.ops.aten.permute.default(view_2807, [1, 0]) + mm_511 = torch.ops.aten.mm.default(permute_1005, view_852); permute_1005 = view_852 = None + convert_element_type_386 = torch.ops.prims.convert_element_type.default(primals_109, torch.bfloat16); primals_109 = None + all_gather_into_tensor_130 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_386, 32, '0'); convert_element_type_386 = None + wait_tensor_154 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_130); all_gather_into_tensor_130 = None + permute_129 = torch.ops.aten.permute.default(wait_tensor_154, [1, 0]); wait_tensor_154 = None + permute_1007 = torch.ops.aten.permute.default(permute_129, [1, 0]); permute_129 = None + mm_512 = torch.ops.aten.mm.default(view_2807, permute_1007); view_2807 = permute_1007 = None + view_2808 = torch.ops.aten.view.default(mm_512, [2, 8192, 4096]); mm_512 = None + add_271 = torch.ops.aten.add.Tensor(view_2806, view_2808); view_2806 = view_2808 = None + convert_element_type_2171 = torch.ops.prims.convert_element_type.default(mm_511, 
torch.float32); mm_511 = None + reduce_scatter_tensor_290 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2171, 'avg', 32, '0'); convert_element_type_2171 = None + wait_tensor_728 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_290); reduce_scatter_tensor_290 = None + split_220 = torch.ops.aten.split.Tensor(add_271, 1024, 1); add_271 = None + getitem_2108 = split_220[0] + getitem_2109 = split_220[1] + getitem_2110 = split_220[2] + getitem_2111 = split_220[3] + getitem_2112 = split_220[4] + getitem_2113 = split_220[5] + getitem_2114 = split_220[6] + getitem_2115 = split_220[7]; split_220 = None + cat_212 = torch.ops.aten.cat.default([getitem_2108, getitem_2109, getitem_2110, getitem_2111, getitem_2112, getitem_2113, getitem_2114, getitem_2115]); getitem_2108 = getitem_2109 = getitem_2110 = getitem_2111 = getitem_2112 = getitem_2113 = getitem_2114 = getitem_2115 = None + reduce_scatter_tensor_291 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_212, 'sum', 8, '1'); cat_212 = None + wait_tensor_729 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_291); reduce_scatter_tensor_291 = None + convert_element_type_2172 = torch.ops.prims.convert_element_type.default(wait_tensor_729, torch.float32); wait_tensor_729 = None + convert_element_type_2174 = torch.ops.prims.convert_element_type.default(wait_tensor_152, torch.float32); wait_tensor_152 = None + mul_670 = torch.ops.aten.mul.Tensor(convert_element_type_2172, convert_element_type_2174); convert_element_type_2174 = None + mul_672 = torch.ops.aten.mul.Tensor(mul_92, mul_670) + sum_123 = torch.ops.aten.sum.dim_IntList(mul_672, [2], True); mul_672 = None + div_41 = torch.ops.aten.div.Tensor(mul_92, 4096) + mul_673 = torch.ops.aten.mul.Tensor(div_41, sum_123); div_41 = sum_123 = None + sub_63 = torch.ops.aten.sub.Tensor(mul_670, mul_673); mul_670 = mul_673 = None + mul_674 = torch.ops.aten.mul.Tensor(sub_63, rsqrt_23); sub_63 = rsqrt_23 
= None + mul_675 = torch.ops.aten.mul.Tensor(convert_element_type_2172, mul_92); convert_element_type_2172 = mul_92 = None + sum_124 = torch.ops.aten.sum.dim_IntList(mul_675, [0, 1]); mul_675 = None + convert_element_type_2175 = torch.ops.prims.convert_element_type.default(mul_674, torch.bfloat16); mul_674 = None + convert_element_type_2176 = torch.ops.prims.convert_element_type.default(sum_124, torch.bfloat16); sum_124 = None + all_reduce_41 = torch.ops._c10d_functional.all_reduce.default(convert_element_type_2176, 'sum', '1'); convert_element_type_2176 = None + wait_tensor_730 = torch.ops._c10d_functional.wait_tensor.default(all_reduce_41); all_reduce_41 = None + convert_element_type_2177 = torch.ops.prims.convert_element_type.default(wait_tensor_730, torch.float32); wait_tensor_730 = None + reduce_scatter_tensor_292 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2177, 'avg', 32, '0'); convert_element_type_2177 = None + wait_tensor_731 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_292); reduce_scatter_tensor_292 = None + add_272 = torch.ops.aten.add.Tensor(add_268, convert_element_type_2175); add_268 = convert_element_type_2175 = None + all_gather_into_tensor_397 = torch.ops._c10d_functional.all_gather_into_tensor.default(add_272, 8, '1') + wait_tensor_732 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_397); all_gather_into_tensor_397 = None + split_221 = torch.ops.aten.split.Tensor(wait_tensor_732, 2); wait_tensor_732 = None + getitem_2116 = split_221[0] + getitem_2117 = split_221[1] + getitem_2118 = split_221[2] + getitem_2119 = split_221[3] + getitem_2120 = split_221[4] + getitem_2121 = split_221[5] + getitem_2122 = split_221[6] + getitem_2123 = split_221[7]; split_221 = None + cat_213 = torch.ops.aten.cat.default([getitem_2116, getitem_2117, getitem_2118, getitem_2119, getitem_2120, getitem_2121, getitem_2122, getitem_2123], 1); getitem_2116 = getitem_2117 = getitem_2118 = 
getitem_2119 = getitem_2120 = getitem_2121 = getitem_2122 = getitem_2123 = None + view_2809 = torch.ops.aten.view.default(cat_213, [16384, 4096]); cat_213 = None + permute_1009 = torch.ops.aten.permute.default(view_2809, [1, 0]) + permute_127 = torch.ops.aten.permute.default(getitem_531, [0, 2, 1, 3]) + view_834 = torch.ops.aten.view.default(permute_127, [2, 8192, -1]); permute_127 = None + view_840 = torch.ops.aten.view.default(view_834, [16384, 512]); view_834 = None + mm_513 = torch.ops.aten.mm.default(permute_1009, view_840); permute_1009 = view_840 = None + convert_element_type_380 = torch.ops.prims.convert_element_type.default(primals_107, torch.bfloat16); primals_107 = None + all_gather_into_tensor_127 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_380, 32, '0'); convert_element_type_380 = None + wait_tensor_150 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_127); all_gather_into_tensor_127 = None + permute_128 = torch.ops.aten.permute.default(wait_tensor_150, [1, 0]); wait_tensor_150 = None + permute_1011 = torch.ops.aten.permute.default(permute_128, [1, 0]); permute_128 = None + mm_514 = torch.ops.aten.mm.default(view_2809, permute_1011); view_2809 = permute_1011 = None + view_2810 = torch.ops.aten.view.default(mm_514, [2, 8192, 512]); mm_514 = None + convert_element_type_2182 = torch.ops.prims.convert_element_type.default(mm_513, torch.float32); mm_513 = None + reduce_scatter_tensor_293 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2182, 'avg', 32, '0'); convert_element_type_2182 = None + wait_tensor_733 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_293); reduce_scatter_tensor_293 = None + view_2811 = torch.ops.aten.view.default(view_2810, [2, 8192, 4, 128]); view_2810 = None + permute_1013 = torch.ops.aten.permute.default(view_2811, [0, 2, 1, 3]); view_2811 = None + convert_element_type_364 = 
torch.ops.prims.convert_element_type.default(primals_103, torch.bfloat16); primals_103 = None + all_gather_into_tensor_122 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_364, 32, '0'); convert_element_type_364 = None + wait_tensor_145 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_122); all_gather_into_tensor_122 = None + convert_element_type_365 = torch.ops.prims.convert_element_type.default(add_43, torch.float32); add_43 = None + pow_23 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_365, 2) + mean_22 = torch.ops.aten.mean.dim(pow_23, [2], True); pow_23 = None + add_44 = torch.ops.aten.add.Scalar(mean_22, 1e-05); mean_22 = None + rsqrt_22 = torch.ops.aten.rsqrt.default(add_44); add_44 = None + mul_88 = torch.ops.aten.mul.Tensor(convert_element_type_365, rsqrt_22); convert_element_type_365 = None + mul_89 = torch.ops.aten.mul.Tensor(mul_88, wait_tensor_145) + convert_element_type_366 = torch.ops.prims.convert_element_type.default(mul_89, torch.bfloat16); mul_89 = None + all_gather_into_tensor_123 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_366, 8, '1'); convert_element_type_366 = None + wait_tensor_146 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_123); all_gather_into_tensor_123 = None + split_53 = torch.ops.aten.split.Tensor(wait_tensor_146, 2); wait_tensor_146 = None + getitem_523 = split_53[0] + getitem_524 = split_53[1] + getitem_525 = split_53[2] + getitem_526 = split_53[3] + getitem_527 = split_53[4] + getitem_528 = split_53[5] + getitem_529 = split_53[6] + getitem_530 = split_53[7]; split_53 = None + cat_45 = torch.ops.aten.cat.default([getitem_523, getitem_524, getitem_525, getitem_526, getitem_527, getitem_528, getitem_529, getitem_530], 1); getitem_523 = getitem_524 = getitem_525 = getitem_526 = getitem_527 = getitem_528 = getitem_529 = getitem_530 = None + view_807 = torch.ops.aten.view.default(cat_45, [16384, 4096]); cat_45 
= None + view_808 = torch.ops.aten.view.default(mm_77, [2, 8192, 512]); mm_77 = None + convert_element_type_370 = torch.ops.prims.convert_element_type.default(primals_105, torch.bfloat16); primals_105 = None + all_gather_into_tensor_125 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_370, 32, '0'); convert_element_type_370 = None + wait_tensor_148 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_125); all_gather_into_tensor_125 = None + permute_122 = torch.ops.aten.permute.default(wait_tensor_148, [1, 0]); wait_tensor_148 = None + mm_78 = torch.ops.aten.mm.default(view_807, permute_122) + view_815 = torch.ops.aten.view.default(mm_78, [2, 8192, 128]); mm_78 = None + view_822 = torch.ops.aten.view.default(mm_79, [2, 8192, 128]); mm_79 = None + view_824 = torch.ops.aten.view.default(view_808, [2, 8192, -1, 128]); view_808 = None + view_825 = torch.ops.aten.view.default(view_815, [2, 8192, -1, 128]); view_815 = None + view_826 = torch.ops.aten.view.default(view_822, [2, 8192, -1, 128]); view_822 = None + convert_element_type_376 = torch.ops.prims.convert_element_type.default(view_824, torch.float32); view_824 = None + view_827 = torch.ops.aten.view.default(convert_element_type_376, [2, 8192, 4, -1, 2]); convert_element_type_376 = None + view_as_complex_22 = torch.ops.aten.view_as_complex.default(view_827); view_827 = None + convert_element_type_377 = torch.ops.prims.convert_element_type.default(view_825, torch.float32); view_825 = None + view_828 = torch.ops.aten.view.default(convert_element_type_377, [2, 8192, 1, -1, 2]); convert_element_type_377 = None + view_as_complex_23 = torch.ops.aten.view_as_complex.default(view_828); view_828 = None + mul_90 = torch.ops.aten.mul.Tensor(view_as_complex_22, view_37); view_as_complex_22 = None + view_as_real_22 = torch.ops.aten.view_as_real.default(mul_90); mul_90 = None + view_830 = torch.ops.aten.view.default(view_as_real_22, [2, 8192, 4, 128]); view_as_real_22 = None + 
mul_91 = torch.ops.aten.mul.Tensor(view_as_complex_23, view_37); view_as_complex_23 = None + view_as_real_23 = torch.ops.aten.view_as_real.default(mul_91); mul_91 = None + view_831 = torch.ops.aten.view.default(view_as_real_23, [2, 8192, 1, 128]); view_as_real_23 = None + convert_element_type_378 = torch.ops.prims.convert_element_type.default(view_830, torch.bfloat16); view_830 = None + convert_element_type_379 = torch.ops.prims.convert_element_type.default(view_831, torch.bfloat16); view_831 = None + unsqueeze_22 = torch.ops.aten.unsqueeze.default(convert_element_type_379, 3); convert_element_type_379 = None + expand_22 = torch.ops.aten.expand.default(unsqueeze_22, [2, 8192, 1, 4, 128]); unsqueeze_22 = None + view_832 = torch.ops.aten.view.default(expand_22, [2, 8192, 4, 128]); expand_22 = None + unsqueeze_23 = torch.ops.aten.unsqueeze.default(view_826, 3); view_826 = None + expand_23 = torch.ops.aten.expand.default(unsqueeze_23, [2, 8192, 1, 4, 128]); unsqueeze_23 = None + view_833 = torch.ops.aten.view.default(expand_23, [2, 8192, 4, 128]); expand_23 = None + permute_124 = torch.ops.aten.permute.default(convert_element_type_378, [0, 2, 1, 3]); convert_element_type_378 = None + permute_125 = torch.ops.aten.permute.default(view_832, [0, 2, 1, 3]); view_832 = None + permute_126 = torch.ops.aten.permute.default(view_833, [0, 2, 1, 3]); view_833 = None + _scaled_dot_product_cudnn_attention_backward_20 = torch.ops.aten._scaled_dot_product_cudnn_attention_backward.default(permute_1013, permute_124, permute_125, permute_126, getitem_531, getitem_532, getitem_537, getitem_538, None, None, None, 8192, 8192, 0.0, True); permute_1013 = permute_124 = permute_125 = permute_126 = getitem_531 = getitem_532 = getitem_537 = getitem_538 = None + getitem_2124 = _scaled_dot_product_cudnn_attention_backward_20[0] + getitem_2125 = _scaled_dot_product_cudnn_attention_backward_20[1] + getitem_2126 = _scaled_dot_product_cudnn_attention_backward_20[2]; 
_scaled_dot_product_cudnn_attention_backward_20 = None + permute_1014 = torch.ops.aten.permute.default(getitem_2126, [0, 2, 1, 3]); getitem_2126 = None + permute_1015 = torch.ops.aten.permute.default(getitem_2125, [0, 2, 1, 3]); getitem_2125 = None + permute_1016 = torch.ops.aten.permute.default(getitem_2124, [0, 2, 1, 3]); getitem_2124 = None + view_2812 = torch.ops.aten.view.default(permute_1014, [2, 8192, 1, 4, 128]); permute_1014 = None + sum_125 = torch.ops.aten.sum.dim_IntList(view_2812, [3], True); view_2812 = None + squeeze_40 = torch.ops.aten.squeeze.dim(sum_125, 3); sum_125 = None + view_2813 = torch.ops.aten.view.default(permute_1015, [2, 8192, 1, 4, 128]); permute_1015 = None + sum_126 = torch.ops.aten.sum.dim_IntList(view_2813, [3], True); view_2813 = None + squeeze_41 = torch.ops.aten.squeeze.dim(sum_126, 3); sum_126 = None + convert_element_type_2183 = torch.ops.prims.convert_element_type.default(squeeze_41, torch.float32); squeeze_41 = None + convert_element_type_2184 = torch.ops.prims.convert_element_type.default(permute_1016, torch.float32); permute_1016 = None + view_2814 = torch.ops.aten.view.default(convert_element_type_2183, [2, 8192, 1, 64, 2]); convert_element_type_2183 = None + view_as_complex_104 = torch.ops.aten.view_as_complex.default(view_2814); view_2814 = None + mul_676 = torch.ops.aten.mul.Tensor(view_as_complex_104, _conj); view_as_complex_104 = None + view_2815 = torch.ops.aten.view.default(convert_element_type_2184, [2, 8192, 4, 64, 2]); convert_element_type_2184 = None + view_as_complex_105 = torch.ops.aten.view_as_complex.default(view_2815); view_2815 = None + mul_677 = torch.ops.aten.mul.Tensor(view_as_complex_105, _conj); view_as_complex_105 = None + view_as_real_104 = torch.ops.aten.view_as_real.default(mul_676); mul_676 = None + view_2816 = torch.ops.aten.view.default(view_as_real_104, [2, 8192, 1, 128]); view_as_real_104 = None + convert_element_type_2185 = torch.ops.prims.convert_element_type.default(view_2816, 
torch.bfloat16); view_2816 = None + view_as_real_105 = torch.ops.aten.view_as_real.default(mul_677); mul_677 = None + view_2817 = torch.ops.aten.view.default(view_as_real_105, [2, 8192, 4, 128]); view_as_real_105 = None + convert_element_type_2186 = torch.ops.prims.convert_element_type.default(view_2817, torch.bfloat16); view_2817 = None + view_2818 = torch.ops.aten.view.default(squeeze_40, [2, 8192, 128]); squeeze_40 = None + view_2819 = torch.ops.aten.view.default(convert_element_type_2185, [2, 8192, 128]); convert_element_type_2185 = None + view_2820 = torch.ops.aten.view.default(convert_element_type_2186, [2, 8192, 512]); convert_element_type_2186 = None + view_2821 = torch.ops.aten.view.default(view_2818, [16384, 128]); view_2818 = None + permute_1017 = torch.ops.aten.permute.default(view_2821, [1, 0]) + mm_515 = torch.ops.aten.mm.default(permute_1017, view_807); permute_1017 = None + convert_element_type_373 = torch.ops.prims.convert_element_type.default(primals_106, torch.bfloat16); primals_106 = None + all_gather_into_tensor_126 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_373, 32, '0'); convert_element_type_373 = None + wait_tensor_149 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_126); all_gather_into_tensor_126 = None + permute_123 = torch.ops.aten.permute.default(wait_tensor_149, [1, 0]); wait_tensor_149 = None + permute_1019 = torch.ops.aten.permute.default(permute_123, [1, 0]); permute_123 = None + mm_516 = torch.ops.aten.mm.default(view_2821, permute_1019); view_2821 = permute_1019 = None + view_2822 = torch.ops.aten.view.default(mm_516, [2, 8192, 4096]); mm_516 = None + convert_element_type_2191 = torch.ops.prims.convert_element_type.default(mm_515, torch.float32); mm_515 = None + reduce_scatter_tensor_294 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2191, 'avg', 32, '0'); convert_element_type_2191 = None + wait_tensor_734 = 
torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_294); reduce_scatter_tensor_294 = None + view_2823 = torch.ops.aten.view.default(view_2819, [16384, 128]); view_2819 = None + permute_1021 = torch.ops.aten.permute.default(view_2823, [1, 0]) + mm_517 = torch.ops.aten.mm.default(permute_1021, view_807); permute_1021 = None + permute_1023 = torch.ops.aten.permute.default(permute_122, [1, 0]); permute_122 = None + mm_518 = torch.ops.aten.mm.default(view_2823, permute_1023); view_2823 = permute_1023 = None + view_2824 = torch.ops.aten.view.default(mm_518, [2, 8192, 4096]); mm_518 = None + add_273 = torch.ops.aten.add.Tensor(view_2822, view_2824); view_2822 = view_2824 = None + convert_element_type_2196 = torch.ops.prims.convert_element_type.default(mm_517, torch.float32); mm_517 = None + reduce_scatter_tensor_295 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2196, 'avg', 32, '0'); convert_element_type_2196 = None + wait_tensor_735 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_295); reduce_scatter_tensor_295 = None + view_2825 = torch.ops.aten.view.default(view_2820, [16384, 512]); view_2820 = None + permute_1025 = torch.ops.aten.permute.default(view_2825, [1, 0]) + mm_519 = torch.ops.aten.mm.default(permute_1025, view_807); permute_1025 = view_807 = None + convert_element_type_367 = torch.ops.prims.convert_element_type.default(primals_104, torch.bfloat16); primals_104 = None + all_gather_into_tensor_124 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_367, 32, '0'); convert_element_type_367 = None + wait_tensor_147 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_124); all_gather_into_tensor_124 = None + permute_121 = torch.ops.aten.permute.default(wait_tensor_147, [1, 0]); wait_tensor_147 = None + permute_1027 = torch.ops.aten.permute.default(permute_121, [1, 0]); permute_121 = None + mm_520 = torch.ops.aten.mm.default(view_2825, 
permute_1027); view_2825 = permute_1027 = None + view_2826 = torch.ops.aten.view.default(mm_520, [2, 8192, 4096]); mm_520 = None + add_274 = torch.ops.aten.add.Tensor(add_273, view_2826); add_273 = view_2826 = None + convert_element_type_2201 = torch.ops.prims.convert_element_type.default(mm_519, torch.float32); mm_519 = None + reduce_scatter_tensor_296 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2201, 'avg', 32, '0'); convert_element_type_2201 = None + wait_tensor_736 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_296); reduce_scatter_tensor_296 = None + split_222 = torch.ops.aten.split.Tensor(add_274, 1024, 1); add_274 = None + getitem_2127 = split_222[0] + getitem_2128 = split_222[1] + getitem_2129 = split_222[2] + getitem_2130 = split_222[3] + getitem_2131 = split_222[4] + getitem_2132 = split_222[5] + getitem_2133 = split_222[6] + getitem_2134 = split_222[7]; split_222 = None + cat_214 = torch.ops.aten.cat.default([getitem_2127, getitem_2128, getitem_2129, getitem_2130, getitem_2131, getitem_2132, getitem_2133, getitem_2134]); getitem_2127 = getitem_2128 = getitem_2129 = getitem_2130 = getitem_2131 = getitem_2132 = getitem_2133 = getitem_2134 = None + reduce_scatter_tensor_297 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_214, 'sum', 8, '1'); cat_214 = None + wait_tensor_737 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_297); reduce_scatter_tensor_297 = None + convert_element_type_2202 = torch.ops.prims.convert_element_type.default(wait_tensor_737, torch.float32); wait_tensor_737 = None + convert_element_type_2204 = torch.ops.prims.convert_element_type.default(wait_tensor_145, torch.float32); wait_tensor_145 = None + mul_678 = torch.ops.aten.mul.Tensor(convert_element_type_2202, convert_element_type_2204); convert_element_type_2204 = None + mul_680 = torch.ops.aten.mul.Tensor(mul_88, mul_678) + sum_127 = torch.ops.aten.sum.dim_IntList(mul_680, [2], True); 
mul_680 = None + div_42 = torch.ops.aten.div.Tensor(mul_88, 4096) + mul_681 = torch.ops.aten.mul.Tensor(div_42, sum_127); div_42 = sum_127 = None + sub_64 = torch.ops.aten.sub.Tensor(mul_678, mul_681); mul_678 = mul_681 = None + mul_682 = torch.ops.aten.mul.Tensor(sub_64, rsqrt_22); sub_64 = rsqrt_22 = None + mul_683 = torch.ops.aten.mul.Tensor(convert_element_type_2202, mul_88); convert_element_type_2202 = mul_88 = None + sum_128 = torch.ops.aten.sum.dim_IntList(mul_683, [0, 1]); mul_683 = None + convert_element_type_2205 = torch.ops.prims.convert_element_type.default(mul_682, torch.bfloat16); mul_682 = None + convert_element_type_2206 = torch.ops.prims.convert_element_type.default(sum_128, torch.bfloat16); sum_128 = None + all_reduce_42 = torch.ops._c10d_functional.all_reduce.default(convert_element_type_2206, 'sum', '1'); convert_element_type_2206 = None + wait_tensor_738 = torch.ops._c10d_functional.wait_tensor.default(all_reduce_42); all_reduce_42 = None + convert_element_type_2207 = torch.ops.prims.convert_element_type.default(wait_tensor_738, torch.float32); wait_tensor_738 = None + reduce_scatter_tensor_298 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2207, 'avg', 32, '0'); convert_element_type_2207 = None + wait_tensor_739 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_298); reduce_scatter_tensor_298 = None + add_275 = torch.ops.aten.add.Tensor(add_272, convert_element_type_2205); add_272 = convert_element_type_2205 = None + all_gather_into_tensor_398 = torch.ops._c10d_functional.all_gather_into_tensor.default(add_275, 8, '1') + wait_tensor_740 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_398); all_gather_into_tensor_398 = None + split_223 = torch.ops.aten.split.Tensor(wait_tensor_740, 2); wait_tensor_740 = None + getitem_2135 = split_223[0] + getitem_2136 = split_223[1] + getitem_2137 = split_223[2] + getitem_2138 = split_223[3] + getitem_2139 = split_223[4] + 
getitem_2140 = split_223[5] + getitem_2141 = split_223[6] + getitem_2142 = split_223[7]; split_223 = None + cat_215 = torch.ops.aten.cat.default([getitem_2135, getitem_2136, getitem_2137, getitem_2138, getitem_2139, getitem_2140, getitem_2141, getitem_2142], 1); getitem_2135 = getitem_2136 = getitem_2137 = getitem_2138 = getitem_2139 = getitem_2140 = getitem_2141 = getitem_2142 = None + view_2827 = torch.ops.aten.view.default(cat_215, [16384, 4096]); cat_215 = None + permute_1029 = torch.ops.aten.permute.default(view_2827, [1, 0]) + wait_tensor_138 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_21); reduce_scatter_tensor_21 = None + add_41 = torch.ops.aten.add.Tensor(add_39, wait_tensor_138); wait_tensor_138 = None + convert_element_type_350 = torch.ops.prims.convert_element_type.default(primals_99, torch.bfloat16); primals_99 = None + all_gather_into_tensor_117 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_350, 32, '0'); convert_element_type_350 = None + wait_tensor_139 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_117); all_gather_into_tensor_117 = None + convert_element_type_351 = torch.ops.prims.convert_element_type.default(add_41, torch.float32); add_41 = None + pow_22 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_351, 2) + mean_21 = torch.ops.aten.mean.dim(pow_22, [2], True); pow_22 = None + add_42 = torch.ops.aten.add.Scalar(mean_21, 1e-05); mean_21 = None + rsqrt_21 = torch.ops.aten.rsqrt.default(add_42); add_42 = None + mul_84 = torch.ops.aten.mul.Tensor(convert_element_type_351, rsqrt_21); convert_element_type_351 = None + mul_85 = torch.ops.aten.mul.Tensor(mul_84, wait_tensor_139) + convert_element_type_352 = torch.ops.prims.convert_element_type.default(mul_85, torch.bfloat16); mul_85 = None + all_gather_into_tensor_118 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_352, 8, '1'); convert_element_type_352 = None + 
wait_tensor_140 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_118); all_gather_into_tensor_118 = None + split_51 = torch.ops.aten.split.Tensor(wait_tensor_140, 2); wait_tensor_140 = None + getitem_507 = split_51[0] + getitem_508 = split_51[1] + getitem_509 = split_51[2] + getitem_510 = split_51[3] + getitem_511 = split_51[4] + getitem_512 = split_51[5] + getitem_513 = split_51[6] + getitem_514 = split_51[7]; split_51 = None + cat_43 = torch.ops.aten.cat.default([getitem_507, getitem_508, getitem_509, getitem_510, getitem_511, getitem_512, getitem_513, getitem_514], 1); getitem_507 = getitem_508 = getitem_509 = getitem_510 = getitem_511 = getitem_512 = getitem_513 = getitem_514 = None + view_780 = torch.ops.aten.view.default(cat_43, [16384, 4096]); cat_43 = None + view_781 = torch.ops.aten.view.default(mm_74, [2, 8192, 1792]); mm_74 = None + convert_element_type_356 = torch.ops.prims.convert_element_type.default(view_781, torch.float32); view_781 = None + sigmoid_10 = torch.ops.aten.sigmoid.default(convert_element_type_356) + mul_86 = torch.ops.aten.mul.Tensor(convert_element_type_356, sigmoid_10); sigmoid_10 = None + convert_element_type_357 = torch.ops.prims.convert_element_type.default(mul_86, torch.bfloat16); mul_86 = None + convert_element_type_358 = torch.ops.prims.convert_element_type.default(primals_101, torch.bfloat16); primals_101 = None + all_gather_into_tensor_120 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_358, 32, '0'); convert_element_type_358 = None + wait_tensor_142 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_120); all_gather_into_tensor_120 = None + permute_119 = torch.ops.aten.permute.default(wait_tensor_142, [1, 0]); wait_tensor_142 = None + mm_75 = torch.ops.aten.mm.default(view_780, permute_119) + view_788 = torch.ops.aten.view.default(mm_75, [2, 8192, 1792]); mm_75 = None + mul_87 = torch.ops.aten.mul.Tensor(convert_element_type_357, view_788) + view_795 
= torch.ops.aten.view.default(mul_87, [16384, 1792]); mul_87 = None + mm_521 = torch.ops.aten.mm.default(permute_1029, view_795); permute_1029 = view_795 = None + convert_element_type_361 = torch.ops.prims.convert_element_type.default(primals_102, torch.bfloat16); primals_102 = None + all_gather_into_tensor_121 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_361, 32, '0'); convert_element_type_361 = None + wait_tensor_143 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_121); all_gather_into_tensor_121 = None + permute_120 = torch.ops.aten.permute.default(wait_tensor_143, [1, 0]); wait_tensor_143 = None + permute_1031 = torch.ops.aten.permute.default(permute_120, [1, 0]); permute_120 = None + mm_522 = torch.ops.aten.mm.default(view_2827, permute_1031); view_2827 = permute_1031 = None + view_2828 = torch.ops.aten.view.default(mm_522, [2, 8192, 1792]); mm_522 = None + convert_element_type_2212 = torch.ops.prims.convert_element_type.default(mm_521, torch.float32); mm_521 = None + reduce_scatter_tensor_299 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2212, 'avg', 32, '0'); convert_element_type_2212 = None + wait_tensor_741 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_299); reduce_scatter_tensor_299 = None + mul_684 = torch.ops.aten.mul.Tensor(view_2828, convert_element_type_357); convert_element_type_357 = None + mul_685 = torch.ops.aten.mul.Tensor(view_2828, view_788); view_2828 = view_788 = None + view_2829 = torch.ops.aten.view.default(mul_684, [16384, 1792]); mul_684 = None + permute_1033 = torch.ops.aten.permute.default(view_2829, [1, 0]) + mm_523 = torch.ops.aten.mm.default(permute_1033, view_780); permute_1033 = None + permute_1035 = torch.ops.aten.permute.default(permute_119, [1, 0]); permute_119 = None + mm_524 = torch.ops.aten.mm.default(view_2829, permute_1035); view_2829 = permute_1035 = None + view_2830 = torch.ops.aten.view.default(mm_524, 
[2, 8192, 4096]); mm_524 = None + convert_element_type_2217 = torch.ops.prims.convert_element_type.default(mm_523, torch.float32); mm_523 = None + reduce_scatter_tensor_300 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2217, 'avg', 32, '0'); convert_element_type_2217 = None + wait_tensor_742 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_300); reduce_scatter_tensor_300 = None + convert_element_type_2218 = torch.ops.prims.convert_element_type.default(mul_685, torch.float32); mul_685 = None + neg_21 = torch.ops.aten.neg.default(convert_element_type_356) + exp_21 = torch.ops.aten.exp.default(neg_21); neg_21 = None + add_276 = torch.ops.aten.add.Tensor(exp_21, 1); exp_21 = None + reciprocal_21 = torch.ops.aten.reciprocal.default(add_276); add_276 = None + mul_686 = torch.ops.aten.mul.Tensor(reciprocal_21, 1); reciprocal_21 = None + mul_687 = torch.ops.aten.mul.Tensor(convert_element_type_2218, mul_686); convert_element_type_2218 = None + sub_65 = torch.ops.aten.sub.Tensor(1, mul_686); mul_686 = None + mul_688 = torch.ops.aten.mul.Tensor(convert_element_type_356, sub_65); convert_element_type_356 = sub_65 = None + add_277 = torch.ops.aten.add.Tensor(mul_688, 1); mul_688 = None + mul_689 = torch.ops.aten.mul.Tensor(mul_687, add_277); mul_687 = add_277 = None + convert_element_type_2220 = torch.ops.prims.convert_element_type.default(mul_689, torch.bfloat16); mul_689 = None + view_2831 = torch.ops.aten.view.default(convert_element_type_2220, [16384, 1792]); convert_element_type_2220 = None + permute_1037 = torch.ops.aten.permute.default(view_2831, [1, 0]) + mm_525 = torch.ops.aten.mm.default(permute_1037, view_780); permute_1037 = view_780 = None + convert_element_type_353 = torch.ops.prims.convert_element_type.default(primals_100, torch.bfloat16); primals_100 = None + all_gather_into_tensor_119 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_353, 32, '0'); convert_element_type_353 = 
None + wait_tensor_141 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_119); all_gather_into_tensor_119 = None + permute_118 = torch.ops.aten.permute.default(wait_tensor_141, [1, 0]); wait_tensor_141 = None + permute_1039 = torch.ops.aten.permute.default(permute_118, [1, 0]); permute_118 = None + mm_526 = torch.ops.aten.mm.default(view_2831, permute_1039); view_2831 = permute_1039 = None + view_2832 = torch.ops.aten.view.default(mm_526, [2, 8192, 4096]); mm_526 = None + add_278 = torch.ops.aten.add.Tensor(view_2830, view_2832); view_2830 = view_2832 = None + convert_element_type_2225 = torch.ops.prims.convert_element_type.default(mm_525, torch.float32); mm_525 = None + reduce_scatter_tensor_301 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2225, 'avg', 32, '0'); convert_element_type_2225 = None + wait_tensor_743 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_301); reduce_scatter_tensor_301 = None + split_224 = torch.ops.aten.split.Tensor(add_278, 1024, 1); add_278 = None + getitem_2143 = split_224[0] + getitem_2144 = split_224[1] + getitem_2145 = split_224[2] + getitem_2146 = split_224[3] + getitem_2147 = split_224[4] + getitem_2148 = split_224[5] + getitem_2149 = split_224[6] + getitem_2150 = split_224[7]; split_224 = None + cat_216 = torch.ops.aten.cat.default([getitem_2143, getitem_2144, getitem_2145, getitem_2146, getitem_2147, getitem_2148, getitem_2149, getitem_2150]); getitem_2143 = getitem_2144 = getitem_2145 = getitem_2146 = getitem_2147 = getitem_2148 = getitem_2149 = getitem_2150 = None + reduce_scatter_tensor_302 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_216, 'sum', 8, '1'); cat_216 = None + wait_tensor_744 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_302); reduce_scatter_tensor_302 = None + convert_element_type_2226 = torch.ops.prims.convert_element_type.default(wait_tensor_744, torch.float32); wait_tensor_744 = None + 
convert_element_type_2228 = torch.ops.prims.convert_element_type.default(wait_tensor_139, torch.float32); wait_tensor_139 = None + mul_690 = torch.ops.aten.mul.Tensor(convert_element_type_2226, convert_element_type_2228); convert_element_type_2228 = None + mul_692 = torch.ops.aten.mul.Tensor(mul_84, mul_690) + sum_129 = torch.ops.aten.sum.dim_IntList(mul_692, [2], True); mul_692 = None + div_43 = torch.ops.aten.div.Tensor(mul_84, 4096) + mul_693 = torch.ops.aten.mul.Tensor(div_43, sum_129); div_43 = sum_129 = None + sub_66 = torch.ops.aten.sub.Tensor(mul_690, mul_693); mul_690 = mul_693 = None + mul_694 = torch.ops.aten.mul.Tensor(sub_66, rsqrt_21); sub_66 = rsqrt_21 = None + mul_695 = torch.ops.aten.mul.Tensor(convert_element_type_2226, mul_84); convert_element_type_2226 = mul_84 = None + sum_130 = torch.ops.aten.sum.dim_IntList(mul_695, [0, 1]); mul_695 = None + convert_element_type_2229 = torch.ops.prims.convert_element_type.default(mul_694, torch.bfloat16); mul_694 = None + convert_element_type_2230 = torch.ops.prims.convert_element_type.default(sum_130, torch.bfloat16); sum_130 = None + all_reduce_43 = torch.ops._c10d_functional.all_reduce.default(convert_element_type_2230, 'sum', '1'); convert_element_type_2230 = None + wait_tensor_745 = torch.ops._c10d_functional.wait_tensor.default(all_reduce_43); all_reduce_43 = None + convert_element_type_2231 = torch.ops.prims.convert_element_type.default(wait_tensor_745, torch.float32); wait_tensor_745 = None + reduce_scatter_tensor_303 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2231, 'avg', 32, '0'); convert_element_type_2231 = None + wait_tensor_746 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_303); reduce_scatter_tensor_303 = None + add_279 = torch.ops.aten.add.Tensor(add_275, convert_element_type_2229); add_275 = convert_element_type_2229 = None + all_gather_into_tensor_399 = torch.ops._c10d_functional.all_gather_into_tensor.default(add_279, 8, '1') + 
wait_tensor_747 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_399); all_gather_into_tensor_399 = None + split_225 = torch.ops.aten.split.Tensor(wait_tensor_747, 2); wait_tensor_747 = None + getitem_2151 = split_225[0] + getitem_2152 = split_225[1] + getitem_2153 = split_225[2] + getitem_2154 = split_225[3] + getitem_2155 = split_225[4] + getitem_2156 = split_225[5] + getitem_2157 = split_225[6] + getitem_2158 = split_225[7]; split_225 = None + cat_217 = torch.ops.aten.cat.default([getitem_2151, getitem_2152, getitem_2153, getitem_2154, getitem_2155, getitem_2156, getitem_2157, getitem_2158], 1); getitem_2151 = getitem_2152 = getitem_2153 = getitem_2154 = getitem_2155 = getitem_2156 = getitem_2157 = getitem_2158 = None + view_2833 = torch.ops.aten.view.default(cat_217, [16384, 4096]); cat_217 = None + permute_1041 = torch.ops.aten.permute.default(view_2833, [1, 0]) + permute_116 = torch.ops.aten.permute.default(getitem_490, [0, 2, 1, 3]) + view_762 = torch.ops.aten.view.default(permute_116, [2, 8192, -1]); permute_116 = None + view_768 = torch.ops.aten.view.default(view_762, [16384, 512]); view_762 = None + mm_527 = torch.ops.aten.mm.default(permute_1041, view_768); permute_1041 = view_768 = None + convert_element_type_347 = torch.ops.prims.convert_element_type.default(primals_98, torch.bfloat16); primals_98 = None + all_gather_into_tensor_116 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_347, 32, '0'); convert_element_type_347 = None + wait_tensor_137 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_116); all_gather_into_tensor_116 = None + permute_117 = torch.ops.aten.permute.default(wait_tensor_137, [1, 0]); wait_tensor_137 = None + permute_1043 = torch.ops.aten.permute.default(permute_117, [1, 0]); permute_117 = None + mm_528 = torch.ops.aten.mm.default(view_2833, permute_1043); view_2833 = permute_1043 = None + view_2834 = torch.ops.aten.view.default(mm_528, [2, 8192, 512]); 
mm_528 = None + convert_element_type_2236 = torch.ops.prims.convert_element_type.default(mm_527, torch.float32); mm_527 = None + reduce_scatter_tensor_304 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2236, 'avg', 32, '0'); convert_element_type_2236 = None + wait_tensor_748 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_304); reduce_scatter_tensor_304 = None + view_2835 = torch.ops.aten.view.default(view_2834, [2, 8192, 4, 128]); view_2834 = None + permute_1045 = torch.ops.aten.permute.default(view_2835, [0, 2, 1, 3]); view_2835 = None + convert_element_type_331 = torch.ops.prims.convert_element_type.default(primals_94, torch.bfloat16); primals_94 = None + all_gather_into_tensor_111 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_331, 32, '0'); convert_element_type_331 = None + wait_tensor_132 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_111); all_gather_into_tensor_111 = None + convert_element_type_332 = torch.ops.prims.convert_element_type.default(add_39, torch.float32); add_39 = None + pow_21 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_332, 2) + mean_20 = torch.ops.aten.mean.dim(pow_21, [2], True); pow_21 = None + add_40 = torch.ops.aten.add.Scalar(mean_20, 1e-05); mean_20 = None + rsqrt_20 = torch.ops.aten.rsqrt.default(add_40); add_40 = None + mul_80 = torch.ops.aten.mul.Tensor(convert_element_type_332, rsqrt_20); convert_element_type_332 = None + mul_81 = torch.ops.aten.mul.Tensor(mul_80, wait_tensor_132) + convert_element_type_333 = torch.ops.prims.convert_element_type.default(mul_81, torch.bfloat16); mul_81 = None + all_gather_into_tensor_112 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_333, 8, '1'); convert_element_type_333 = None + wait_tensor_133 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_112); all_gather_into_tensor_112 = None + split_49 = 
torch.ops.aten.split.Tensor(wait_tensor_133, 2); wait_tensor_133 = None + getitem_482 = split_49[0] + getitem_483 = split_49[1] + getitem_484 = split_49[2] + getitem_485 = split_49[3] + getitem_486 = split_49[4] + getitem_487 = split_49[5] + getitem_488 = split_49[6] + getitem_489 = split_49[7]; split_49 = None + cat_41 = torch.ops.aten.cat.default([getitem_482, getitem_483, getitem_484, getitem_485, getitem_486, getitem_487, getitem_488, getitem_489], 1); getitem_482 = getitem_483 = getitem_484 = getitem_485 = getitem_486 = getitem_487 = getitem_488 = getitem_489 = None + view_735 = torch.ops.aten.view.default(cat_41, [16384, 4096]); cat_41 = None + view_736 = torch.ops.aten.view.default(mm_70, [2, 8192, 512]); mm_70 = None + convert_element_type_337 = torch.ops.prims.convert_element_type.default(primals_96, torch.bfloat16); primals_96 = None + all_gather_into_tensor_114 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_337, 32, '0'); convert_element_type_337 = None + wait_tensor_135 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_114); all_gather_into_tensor_114 = None + permute_111 = torch.ops.aten.permute.default(wait_tensor_135, [1, 0]); wait_tensor_135 = None + mm_71 = torch.ops.aten.mm.default(view_735, permute_111) + view_743 = torch.ops.aten.view.default(mm_71, [2, 8192, 128]); mm_71 = None + view_750 = torch.ops.aten.view.default(mm_72, [2, 8192, 128]); mm_72 = None + view_752 = torch.ops.aten.view.default(view_736, [2, 8192, -1, 128]); view_736 = None + view_753 = torch.ops.aten.view.default(view_743, [2, 8192, -1, 128]); view_743 = None + view_754 = torch.ops.aten.view.default(view_750, [2, 8192, -1, 128]); view_750 = None + convert_element_type_343 = torch.ops.prims.convert_element_type.default(view_752, torch.float32); view_752 = None + view_755 = torch.ops.aten.view.default(convert_element_type_343, [2, 8192, 4, -1, 2]); convert_element_type_343 = None + view_as_complex_20 = 
torch.ops.aten.view_as_complex.default(view_755); view_755 = None + convert_element_type_344 = torch.ops.prims.convert_element_type.default(view_753, torch.float32); view_753 = None + view_756 = torch.ops.aten.view.default(convert_element_type_344, [2, 8192, 1, -1, 2]); convert_element_type_344 = None + view_as_complex_21 = torch.ops.aten.view_as_complex.default(view_756); view_756 = None + mul_82 = torch.ops.aten.mul.Tensor(view_as_complex_20, view_37); view_as_complex_20 = None + view_as_real_20 = torch.ops.aten.view_as_real.default(mul_82); mul_82 = None + view_758 = torch.ops.aten.view.default(view_as_real_20, [2, 8192, 4, 128]); view_as_real_20 = None + mul_83 = torch.ops.aten.mul.Tensor(view_as_complex_21, view_37); view_as_complex_21 = None + view_as_real_21 = torch.ops.aten.view_as_real.default(mul_83); mul_83 = None + view_759 = torch.ops.aten.view.default(view_as_real_21, [2, 8192, 1, 128]); view_as_real_21 = None + convert_element_type_345 = torch.ops.prims.convert_element_type.default(view_758, torch.bfloat16); view_758 = None + convert_element_type_346 = torch.ops.prims.convert_element_type.default(view_759, torch.bfloat16); view_759 = None + unsqueeze_20 = torch.ops.aten.unsqueeze.default(convert_element_type_346, 3); convert_element_type_346 = None + expand_20 = torch.ops.aten.expand.default(unsqueeze_20, [2, 8192, 1, 4, 128]); unsqueeze_20 = None + view_760 = torch.ops.aten.view.default(expand_20, [2, 8192, 4, 128]); expand_20 = None + unsqueeze_21 = torch.ops.aten.unsqueeze.default(view_754, 3); view_754 = None + expand_21 = torch.ops.aten.expand.default(unsqueeze_21, [2, 8192, 1, 4, 128]); unsqueeze_21 = None + view_761 = torch.ops.aten.view.default(expand_21, [2, 8192, 4, 128]); expand_21 = None + permute_113 = torch.ops.aten.permute.default(convert_element_type_345, [0, 2, 1, 3]); convert_element_type_345 = None + permute_114 = torch.ops.aten.permute.default(view_760, [0, 2, 1, 3]); view_760 = None + permute_115 = 
torch.ops.aten.permute.default(view_761, [0, 2, 1, 3]); view_761 = None + _scaled_dot_product_cudnn_attention_backward_21 = torch.ops.aten._scaled_dot_product_cudnn_attention_backward.default(permute_1045, permute_113, permute_114, permute_115, getitem_490, getitem_491, getitem_496, getitem_497, None, None, None, 8192, 8192, 0.0, True); permute_1045 = permute_113 = permute_114 = permute_115 = getitem_490 = getitem_491 = getitem_496 = getitem_497 = None + getitem_2159 = _scaled_dot_product_cudnn_attention_backward_21[0] + getitem_2160 = _scaled_dot_product_cudnn_attention_backward_21[1] + getitem_2161 = _scaled_dot_product_cudnn_attention_backward_21[2]; _scaled_dot_product_cudnn_attention_backward_21 = None + permute_1046 = torch.ops.aten.permute.default(getitem_2161, [0, 2, 1, 3]); getitem_2161 = None + permute_1047 = torch.ops.aten.permute.default(getitem_2160, [0, 2, 1, 3]); getitem_2160 = None + permute_1048 = torch.ops.aten.permute.default(getitem_2159, [0, 2, 1, 3]); getitem_2159 = None + view_2836 = torch.ops.aten.view.default(permute_1046, [2, 8192, 1, 4, 128]); permute_1046 = None + sum_131 = torch.ops.aten.sum.dim_IntList(view_2836, [3], True); view_2836 = None + squeeze_42 = torch.ops.aten.squeeze.dim(sum_131, 3); sum_131 = None + view_2837 = torch.ops.aten.view.default(permute_1047, [2, 8192, 1, 4, 128]); permute_1047 = None + sum_132 = torch.ops.aten.sum.dim_IntList(view_2837, [3], True); view_2837 = None + squeeze_43 = torch.ops.aten.squeeze.dim(sum_132, 3); sum_132 = None + convert_element_type_2237 = torch.ops.prims.convert_element_type.default(squeeze_43, torch.float32); squeeze_43 = None + convert_element_type_2238 = torch.ops.prims.convert_element_type.default(permute_1048, torch.float32); permute_1048 = None + view_2838 = torch.ops.aten.view.default(convert_element_type_2237, [2, 8192, 1, 64, 2]); convert_element_type_2237 = None + view_as_complex_106 = torch.ops.aten.view_as_complex.default(view_2838); view_2838 = None + mul_696 = 
torch.ops.aten.mul.Tensor(view_as_complex_106, _conj); view_as_complex_106 = None + view_2839 = torch.ops.aten.view.default(convert_element_type_2238, [2, 8192, 4, 64, 2]); convert_element_type_2238 = None + view_as_complex_107 = torch.ops.aten.view_as_complex.default(view_2839); view_2839 = None + mul_697 = torch.ops.aten.mul.Tensor(view_as_complex_107, _conj); view_as_complex_107 = None + view_as_real_106 = torch.ops.aten.view_as_real.default(mul_696); mul_696 = None + view_2840 = torch.ops.aten.view.default(view_as_real_106, [2, 8192, 1, 128]); view_as_real_106 = None + convert_element_type_2239 = torch.ops.prims.convert_element_type.default(view_2840, torch.bfloat16); view_2840 = None + view_as_real_107 = torch.ops.aten.view_as_real.default(mul_697); mul_697 = None + view_2841 = torch.ops.aten.view.default(view_as_real_107, [2, 8192, 4, 128]); view_as_real_107 = None + convert_element_type_2240 = torch.ops.prims.convert_element_type.default(view_2841, torch.bfloat16); view_2841 = None + view_2842 = torch.ops.aten.view.default(squeeze_42, [2, 8192, 128]); squeeze_42 = None + view_2843 = torch.ops.aten.view.default(convert_element_type_2239, [2, 8192, 128]); convert_element_type_2239 = None + view_2844 = torch.ops.aten.view.default(convert_element_type_2240, [2, 8192, 512]); convert_element_type_2240 = None + view_2845 = torch.ops.aten.view.default(view_2842, [16384, 128]); view_2842 = None + permute_1049 = torch.ops.aten.permute.default(view_2845, [1, 0]) + mm_529 = torch.ops.aten.mm.default(permute_1049, view_735); permute_1049 = None + convert_element_type_340 = torch.ops.prims.convert_element_type.default(primals_97, torch.bfloat16); primals_97 = None + all_gather_into_tensor_115 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_340, 32, '0'); convert_element_type_340 = None + wait_tensor_136 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_115); all_gather_into_tensor_115 = None + permute_112 = 
torch.ops.aten.permute.default(wait_tensor_136, [1, 0]); wait_tensor_136 = None + permute_1051 = torch.ops.aten.permute.default(permute_112, [1, 0]); permute_112 = None + mm_530 = torch.ops.aten.mm.default(view_2845, permute_1051); view_2845 = permute_1051 = None + view_2846 = torch.ops.aten.view.default(mm_530, [2, 8192, 4096]); mm_530 = None + convert_element_type_2245 = torch.ops.prims.convert_element_type.default(mm_529, torch.float32); mm_529 = None + reduce_scatter_tensor_305 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2245, 'avg', 32, '0'); convert_element_type_2245 = None + wait_tensor_749 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_305); reduce_scatter_tensor_305 = None + view_2847 = torch.ops.aten.view.default(view_2843, [16384, 128]); view_2843 = None + permute_1053 = torch.ops.aten.permute.default(view_2847, [1, 0]) + mm_531 = torch.ops.aten.mm.default(permute_1053, view_735); permute_1053 = None + permute_1055 = torch.ops.aten.permute.default(permute_111, [1, 0]); permute_111 = None + mm_532 = torch.ops.aten.mm.default(view_2847, permute_1055); view_2847 = permute_1055 = None + view_2848 = torch.ops.aten.view.default(mm_532, [2, 8192, 4096]); mm_532 = None + add_280 = torch.ops.aten.add.Tensor(view_2846, view_2848); view_2846 = view_2848 = None + convert_element_type_2250 = torch.ops.prims.convert_element_type.default(mm_531, torch.float32); mm_531 = None + reduce_scatter_tensor_306 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2250, 'avg', 32, '0'); convert_element_type_2250 = None + wait_tensor_750 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_306); reduce_scatter_tensor_306 = None + view_2849 = torch.ops.aten.view.default(view_2844, [16384, 512]); view_2844 = None + permute_1057 = torch.ops.aten.permute.default(view_2849, [1, 0]) + mm_533 = torch.ops.aten.mm.default(permute_1057, view_735); permute_1057 = view_735 = None + 
convert_element_type_334 = torch.ops.prims.convert_element_type.default(primals_95, torch.bfloat16); primals_95 = None + all_gather_into_tensor_113 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_334, 32, '0'); convert_element_type_334 = None + wait_tensor_134 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_113); all_gather_into_tensor_113 = None + permute_110 = torch.ops.aten.permute.default(wait_tensor_134, [1, 0]); wait_tensor_134 = None + permute_1059 = torch.ops.aten.permute.default(permute_110, [1, 0]); permute_110 = None + mm_534 = torch.ops.aten.mm.default(view_2849, permute_1059); view_2849 = permute_1059 = None + view_2850 = torch.ops.aten.view.default(mm_534, [2, 8192, 4096]); mm_534 = None + add_281 = torch.ops.aten.add.Tensor(add_280, view_2850); add_280 = view_2850 = None + convert_element_type_2255 = torch.ops.prims.convert_element_type.default(mm_533, torch.float32); mm_533 = None + reduce_scatter_tensor_307 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2255, 'avg', 32, '0'); convert_element_type_2255 = None + wait_tensor_751 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_307); reduce_scatter_tensor_307 = None + split_226 = torch.ops.aten.split.Tensor(add_281, 1024, 1); add_281 = None + getitem_2162 = split_226[0] + getitem_2163 = split_226[1] + getitem_2164 = split_226[2] + getitem_2165 = split_226[3] + getitem_2166 = split_226[4] + getitem_2167 = split_226[5] + getitem_2168 = split_226[6] + getitem_2169 = split_226[7]; split_226 = None + cat_218 = torch.ops.aten.cat.default([getitem_2162, getitem_2163, getitem_2164, getitem_2165, getitem_2166, getitem_2167, getitem_2168, getitem_2169]); getitem_2162 = getitem_2163 = getitem_2164 = getitem_2165 = getitem_2166 = getitem_2167 = getitem_2168 = getitem_2169 = None + reduce_scatter_tensor_308 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_218, 'sum', 8, '1'); cat_218 = None 
+ wait_tensor_752 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_308); reduce_scatter_tensor_308 = None + convert_element_type_2256 = torch.ops.prims.convert_element_type.default(wait_tensor_752, torch.float32); wait_tensor_752 = None + convert_element_type_2258 = torch.ops.prims.convert_element_type.default(wait_tensor_132, torch.float32); wait_tensor_132 = None + mul_698 = torch.ops.aten.mul.Tensor(convert_element_type_2256, convert_element_type_2258); convert_element_type_2258 = None + mul_700 = torch.ops.aten.mul.Tensor(mul_80, mul_698) + sum_133 = torch.ops.aten.sum.dim_IntList(mul_700, [2], True); mul_700 = None + div_44 = torch.ops.aten.div.Tensor(mul_80, 4096) + mul_701 = torch.ops.aten.mul.Tensor(div_44, sum_133); div_44 = sum_133 = None + sub_67 = torch.ops.aten.sub.Tensor(mul_698, mul_701); mul_698 = mul_701 = None + mul_702 = torch.ops.aten.mul.Tensor(sub_67, rsqrt_20); sub_67 = rsqrt_20 = None + mul_703 = torch.ops.aten.mul.Tensor(convert_element_type_2256, mul_80); convert_element_type_2256 = mul_80 = None + sum_134 = torch.ops.aten.sum.dim_IntList(mul_703, [0, 1]); mul_703 = None + convert_element_type_2259 = torch.ops.prims.convert_element_type.default(mul_702, torch.bfloat16); mul_702 = None + convert_element_type_2260 = torch.ops.prims.convert_element_type.default(sum_134, torch.bfloat16); sum_134 = None + all_reduce_44 = torch.ops._c10d_functional.all_reduce.default(convert_element_type_2260, 'sum', '1'); convert_element_type_2260 = None + wait_tensor_753 = torch.ops._c10d_functional.wait_tensor.default(all_reduce_44); all_reduce_44 = None + convert_element_type_2261 = torch.ops.prims.convert_element_type.default(wait_tensor_753, torch.float32); wait_tensor_753 = None + reduce_scatter_tensor_309 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2261, 'avg', 32, '0'); convert_element_type_2261 = None + wait_tensor_754 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_309); 
reduce_scatter_tensor_309 = None + add_282 = torch.ops.aten.add.Tensor(add_279, convert_element_type_2259); add_279 = convert_element_type_2259 = None + all_gather_into_tensor_400 = torch.ops._c10d_functional.all_gather_into_tensor.default(add_282, 8, '1') + wait_tensor_755 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_400); all_gather_into_tensor_400 = None + split_227 = torch.ops.aten.split.Tensor(wait_tensor_755, 2); wait_tensor_755 = None + getitem_2170 = split_227[0] + getitem_2171 = split_227[1] + getitem_2172 = split_227[2] + getitem_2173 = split_227[3] + getitem_2174 = split_227[4] + getitem_2175 = split_227[5] + getitem_2176 = split_227[6] + getitem_2177 = split_227[7]; split_227 = None + cat_219 = torch.ops.aten.cat.default([getitem_2170, getitem_2171, getitem_2172, getitem_2173, getitem_2174, getitem_2175, getitem_2176, getitem_2177], 1); getitem_2170 = getitem_2171 = getitem_2172 = getitem_2173 = getitem_2174 = getitem_2175 = getitem_2176 = getitem_2177 = None + view_2851 = torch.ops.aten.view.default(cat_219, [16384, 4096]); cat_219 = None + permute_1061 = torch.ops.aten.permute.default(view_2851, [1, 0]) + wait_tensor_125 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_19); reduce_scatter_tensor_19 = None + add_37 = torch.ops.aten.add.Tensor(add_35, wait_tensor_125); wait_tensor_125 = None + convert_element_type_317 = torch.ops.prims.convert_element_type.default(primals_90, torch.bfloat16); primals_90 = None + all_gather_into_tensor_106 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_317, 32, '0'); convert_element_type_317 = None + wait_tensor_126 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_106); all_gather_into_tensor_106 = None + convert_element_type_318 = torch.ops.prims.convert_element_type.default(add_37, torch.float32); add_37 = None + pow_20 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_318, 2) + mean_19 = 
torch.ops.aten.mean.dim(pow_20, [2], True); pow_20 = None + add_38 = torch.ops.aten.add.Scalar(mean_19, 1e-05); mean_19 = None + rsqrt_19 = torch.ops.aten.rsqrt.default(add_38); add_38 = None + mul_76 = torch.ops.aten.mul.Tensor(convert_element_type_318, rsqrt_19); convert_element_type_318 = None + mul_77 = torch.ops.aten.mul.Tensor(mul_76, wait_tensor_126) + convert_element_type_319 = torch.ops.prims.convert_element_type.default(mul_77, torch.bfloat16); mul_77 = None + all_gather_into_tensor_107 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_319, 8, '1'); convert_element_type_319 = None + wait_tensor_127 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_107); all_gather_into_tensor_107 = None + split_47 = torch.ops.aten.split.Tensor(wait_tensor_127, 2); wait_tensor_127 = None + getitem_466 = split_47[0] + getitem_467 = split_47[1] + getitem_468 = split_47[2] + getitem_469 = split_47[3] + getitem_470 = split_47[4] + getitem_471 = split_47[5] + getitem_472 = split_47[6] + getitem_473 = split_47[7]; split_47 = None + cat_39 = torch.ops.aten.cat.default([getitem_466, getitem_467, getitem_468, getitem_469, getitem_470, getitem_471, getitem_472, getitem_473], 1); getitem_466 = getitem_467 = getitem_468 = getitem_469 = getitem_470 = getitem_471 = getitem_472 = getitem_473 = None + view_708 = torch.ops.aten.view.default(cat_39, [16384, 4096]); cat_39 = None + view_709 = torch.ops.aten.view.default(mm_67, [2, 8192, 1792]); mm_67 = None + convert_element_type_323 = torch.ops.prims.convert_element_type.default(view_709, torch.float32); view_709 = None + sigmoid_9 = torch.ops.aten.sigmoid.default(convert_element_type_323) + mul_78 = torch.ops.aten.mul.Tensor(convert_element_type_323, sigmoid_9); sigmoid_9 = None + convert_element_type_324 = torch.ops.prims.convert_element_type.default(mul_78, torch.bfloat16); mul_78 = None + convert_element_type_325 = torch.ops.prims.convert_element_type.default(primals_92, 
torch.bfloat16); primals_92 = None + all_gather_into_tensor_109 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_325, 32, '0'); convert_element_type_325 = None + wait_tensor_129 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_109); all_gather_into_tensor_109 = None + permute_108 = torch.ops.aten.permute.default(wait_tensor_129, [1, 0]); wait_tensor_129 = None + mm_68 = torch.ops.aten.mm.default(view_708, permute_108) + view_716 = torch.ops.aten.view.default(mm_68, [2, 8192, 1792]); mm_68 = None + mul_79 = torch.ops.aten.mul.Tensor(convert_element_type_324, view_716) + view_723 = torch.ops.aten.view.default(mul_79, [16384, 1792]); mul_79 = None + mm_535 = torch.ops.aten.mm.default(permute_1061, view_723); permute_1061 = view_723 = None + convert_element_type_328 = torch.ops.prims.convert_element_type.default(primals_93, torch.bfloat16); primals_93 = None + all_gather_into_tensor_110 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_328, 32, '0'); convert_element_type_328 = None + wait_tensor_130 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_110); all_gather_into_tensor_110 = None + permute_109 = torch.ops.aten.permute.default(wait_tensor_130, [1, 0]); wait_tensor_130 = None + permute_1063 = torch.ops.aten.permute.default(permute_109, [1, 0]); permute_109 = None + mm_536 = torch.ops.aten.mm.default(view_2851, permute_1063); view_2851 = permute_1063 = None + view_2852 = torch.ops.aten.view.default(mm_536, [2, 8192, 1792]); mm_536 = None + convert_element_type_2266 = torch.ops.prims.convert_element_type.default(mm_535, torch.float32); mm_535 = None + reduce_scatter_tensor_310 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2266, 'avg', 32, '0'); convert_element_type_2266 = None + wait_tensor_756 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_310); reduce_scatter_tensor_310 = None + mul_704 = 
torch.ops.aten.mul.Tensor(view_2852, convert_element_type_324); convert_element_type_324 = None + mul_705 = torch.ops.aten.mul.Tensor(view_2852, view_716); view_2852 = view_716 = None + view_2853 = torch.ops.aten.view.default(mul_704, [16384, 1792]); mul_704 = None + permute_1065 = torch.ops.aten.permute.default(view_2853, [1, 0]) + mm_537 = torch.ops.aten.mm.default(permute_1065, view_708); permute_1065 = None + permute_1067 = torch.ops.aten.permute.default(permute_108, [1, 0]); permute_108 = None + mm_538 = torch.ops.aten.mm.default(view_2853, permute_1067); view_2853 = permute_1067 = None + view_2854 = torch.ops.aten.view.default(mm_538, [2, 8192, 4096]); mm_538 = None + convert_element_type_2271 = torch.ops.prims.convert_element_type.default(mm_537, torch.float32); mm_537 = None + reduce_scatter_tensor_311 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2271, 'avg', 32, '0'); convert_element_type_2271 = None + wait_tensor_757 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_311); reduce_scatter_tensor_311 = None + convert_element_type_2272 = torch.ops.prims.convert_element_type.default(mul_705, torch.float32); mul_705 = None + neg_22 = torch.ops.aten.neg.default(convert_element_type_323) + exp_22 = torch.ops.aten.exp.default(neg_22); neg_22 = None + add_283 = torch.ops.aten.add.Tensor(exp_22, 1); exp_22 = None + reciprocal_22 = torch.ops.aten.reciprocal.default(add_283); add_283 = None + mul_706 = torch.ops.aten.mul.Tensor(reciprocal_22, 1); reciprocal_22 = None + mul_707 = torch.ops.aten.mul.Tensor(convert_element_type_2272, mul_706); convert_element_type_2272 = None + sub_68 = torch.ops.aten.sub.Tensor(1, mul_706); mul_706 = None + mul_708 = torch.ops.aten.mul.Tensor(convert_element_type_323, sub_68); convert_element_type_323 = sub_68 = None + add_284 = torch.ops.aten.add.Tensor(mul_708, 1); mul_708 = None + mul_709 = torch.ops.aten.mul.Tensor(mul_707, add_284); mul_707 = add_284 = None + 
convert_element_type_2274 = torch.ops.prims.convert_element_type.default(mul_709, torch.bfloat16); mul_709 = None + view_2855 = torch.ops.aten.view.default(convert_element_type_2274, [16384, 1792]); convert_element_type_2274 = None + permute_1069 = torch.ops.aten.permute.default(view_2855, [1, 0]) + mm_539 = torch.ops.aten.mm.default(permute_1069, view_708); permute_1069 = view_708 = None + convert_element_type_320 = torch.ops.prims.convert_element_type.default(primals_91, torch.bfloat16); primals_91 = None + all_gather_into_tensor_108 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_320, 32, '0'); convert_element_type_320 = None + wait_tensor_128 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_108); all_gather_into_tensor_108 = None + permute_107 = torch.ops.aten.permute.default(wait_tensor_128, [1, 0]); wait_tensor_128 = None + permute_1071 = torch.ops.aten.permute.default(permute_107, [1, 0]); permute_107 = None + mm_540 = torch.ops.aten.mm.default(view_2855, permute_1071); view_2855 = permute_1071 = None + view_2856 = torch.ops.aten.view.default(mm_540, [2, 8192, 4096]); mm_540 = None + add_285 = torch.ops.aten.add.Tensor(view_2854, view_2856); view_2854 = view_2856 = None + convert_element_type_2279 = torch.ops.prims.convert_element_type.default(mm_539, torch.float32); mm_539 = None + reduce_scatter_tensor_312 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2279, 'avg', 32, '0'); convert_element_type_2279 = None + wait_tensor_758 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_312); reduce_scatter_tensor_312 = None + split_228 = torch.ops.aten.split.Tensor(add_285, 1024, 1); add_285 = None + getitem_2178 = split_228[0] + getitem_2179 = split_228[1] + getitem_2180 = split_228[2] + getitem_2181 = split_228[3] + getitem_2182 = split_228[4] + getitem_2183 = split_228[5] + getitem_2184 = split_228[6] + getitem_2185 = split_228[7]; split_228 = None + 
cat_220 = torch.ops.aten.cat.default([getitem_2178, getitem_2179, getitem_2180, getitem_2181, getitem_2182, getitem_2183, getitem_2184, getitem_2185]); getitem_2178 = getitem_2179 = getitem_2180 = getitem_2181 = getitem_2182 = getitem_2183 = getitem_2184 = getitem_2185 = None + reduce_scatter_tensor_313 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_220, 'sum', 8, '1'); cat_220 = None + wait_tensor_759 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_313); reduce_scatter_tensor_313 = None + convert_element_type_2280 = torch.ops.prims.convert_element_type.default(wait_tensor_759, torch.float32); wait_tensor_759 = None + convert_element_type_2282 = torch.ops.prims.convert_element_type.default(wait_tensor_126, torch.float32); wait_tensor_126 = None + mul_710 = torch.ops.aten.mul.Tensor(convert_element_type_2280, convert_element_type_2282); convert_element_type_2282 = None + mul_712 = torch.ops.aten.mul.Tensor(mul_76, mul_710) + sum_135 = torch.ops.aten.sum.dim_IntList(mul_712, [2], True); mul_712 = None + div_45 = torch.ops.aten.div.Tensor(mul_76, 4096) + mul_713 = torch.ops.aten.mul.Tensor(div_45, sum_135); div_45 = sum_135 = None + sub_69 = torch.ops.aten.sub.Tensor(mul_710, mul_713); mul_710 = mul_713 = None + mul_714 = torch.ops.aten.mul.Tensor(sub_69, rsqrt_19); sub_69 = rsqrt_19 = None + mul_715 = torch.ops.aten.mul.Tensor(convert_element_type_2280, mul_76); convert_element_type_2280 = mul_76 = None + sum_136 = torch.ops.aten.sum.dim_IntList(mul_715, [0, 1]); mul_715 = None + convert_element_type_2283 = torch.ops.prims.convert_element_type.default(mul_714, torch.bfloat16); mul_714 = None + convert_element_type_2284 = torch.ops.prims.convert_element_type.default(sum_136, torch.bfloat16); sum_136 = None + all_reduce_45 = torch.ops._c10d_functional.all_reduce.default(convert_element_type_2284, 'sum', '1'); convert_element_type_2284 = None + wait_tensor_760 = torch.ops._c10d_functional.wait_tensor.default(all_reduce_45); 
all_reduce_45 = None + convert_element_type_2285 = torch.ops.prims.convert_element_type.default(wait_tensor_760, torch.float32); wait_tensor_760 = None + reduce_scatter_tensor_314 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2285, 'avg', 32, '0'); convert_element_type_2285 = None + wait_tensor_761 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_314); reduce_scatter_tensor_314 = None + add_286 = torch.ops.aten.add.Tensor(add_282, convert_element_type_2283); add_282 = convert_element_type_2283 = None + all_gather_into_tensor_401 = torch.ops._c10d_functional.all_gather_into_tensor.default(add_286, 8, '1') + wait_tensor_762 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_401); all_gather_into_tensor_401 = None + split_229 = torch.ops.aten.split.Tensor(wait_tensor_762, 2); wait_tensor_762 = None + getitem_2186 = split_229[0] + getitem_2187 = split_229[1] + getitem_2188 = split_229[2] + getitem_2189 = split_229[3] + getitem_2190 = split_229[4] + getitem_2191 = split_229[5] + getitem_2192 = split_229[6] + getitem_2193 = split_229[7]; split_229 = None + cat_221 = torch.ops.aten.cat.default([getitem_2186, getitem_2187, getitem_2188, getitem_2189, getitem_2190, getitem_2191, getitem_2192, getitem_2193], 1); getitem_2186 = getitem_2187 = getitem_2188 = getitem_2189 = getitem_2190 = getitem_2191 = getitem_2192 = getitem_2193 = None + view_2857 = torch.ops.aten.view.default(cat_221, [16384, 4096]); cat_221 = None + permute_1073 = torch.ops.aten.permute.default(view_2857, [1, 0]) + permute_105 = torch.ops.aten.permute.default(getitem_449, [0, 2, 1, 3]) + view_690 = torch.ops.aten.view.default(permute_105, [2, 8192, -1]); permute_105 = None + view_696 = torch.ops.aten.view.default(view_690, [16384, 512]); view_690 = None + mm_541 = torch.ops.aten.mm.default(permute_1073, view_696); permute_1073 = view_696 = None + convert_element_type_314 = torch.ops.prims.convert_element_type.default(primals_89, 
torch.bfloat16); primals_89 = None + all_gather_into_tensor_105 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_314, 32, '0'); convert_element_type_314 = None + wait_tensor_124 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_105); all_gather_into_tensor_105 = None + permute_106 = torch.ops.aten.permute.default(wait_tensor_124, [1, 0]); wait_tensor_124 = None + permute_1075 = torch.ops.aten.permute.default(permute_106, [1, 0]); permute_106 = None + mm_542 = torch.ops.aten.mm.default(view_2857, permute_1075); view_2857 = permute_1075 = None + view_2858 = torch.ops.aten.view.default(mm_542, [2, 8192, 512]); mm_542 = None + convert_element_type_2290 = torch.ops.prims.convert_element_type.default(mm_541, torch.float32); mm_541 = None + reduce_scatter_tensor_315 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2290, 'avg', 32, '0'); convert_element_type_2290 = None + wait_tensor_763 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_315); reduce_scatter_tensor_315 = None + view_2859 = torch.ops.aten.view.default(view_2858, [2, 8192, 4, 128]); view_2858 = None + permute_1077 = torch.ops.aten.permute.default(view_2859, [0, 2, 1, 3]); view_2859 = None + convert_element_type_298 = torch.ops.prims.convert_element_type.default(primals_85, torch.bfloat16); primals_85 = None + all_gather_into_tensor_100 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_298, 32, '0'); convert_element_type_298 = None + wait_tensor_119 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_100); all_gather_into_tensor_100 = None + convert_element_type_299 = torch.ops.prims.convert_element_type.default(add_35, torch.float32); add_35 = None + pow_19 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_299, 2) + mean_18 = torch.ops.aten.mean.dim(pow_19, [2], True); pow_19 = None + add_36 = torch.ops.aten.add.Scalar(mean_18, 1e-05); mean_18 = 
None + rsqrt_18 = torch.ops.aten.rsqrt.default(add_36); add_36 = None + mul_72 = torch.ops.aten.mul.Tensor(convert_element_type_299, rsqrt_18); convert_element_type_299 = None + mul_73 = torch.ops.aten.mul.Tensor(mul_72, wait_tensor_119) + convert_element_type_300 = torch.ops.prims.convert_element_type.default(mul_73, torch.bfloat16); mul_73 = None + all_gather_into_tensor_101 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_300, 8, '1'); convert_element_type_300 = None + wait_tensor_120 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_101); all_gather_into_tensor_101 = None + split_45 = torch.ops.aten.split.Tensor(wait_tensor_120, 2); wait_tensor_120 = None + getitem_441 = split_45[0] + getitem_442 = split_45[1] + getitem_443 = split_45[2] + getitem_444 = split_45[3] + getitem_445 = split_45[4] + getitem_446 = split_45[5] + getitem_447 = split_45[6] + getitem_448 = split_45[7]; split_45 = None + cat_37 = torch.ops.aten.cat.default([getitem_441, getitem_442, getitem_443, getitem_444, getitem_445, getitem_446, getitem_447, getitem_448], 1); getitem_441 = getitem_442 = getitem_443 = getitem_444 = getitem_445 = getitem_446 = getitem_447 = getitem_448 = None + view_663 = torch.ops.aten.view.default(cat_37, [16384, 4096]); cat_37 = None + view_664 = torch.ops.aten.view.default(mm_63, [2, 8192, 512]); mm_63 = None + convert_element_type_304 = torch.ops.prims.convert_element_type.default(primals_87, torch.bfloat16); primals_87 = None + all_gather_into_tensor_103 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_304, 32, '0'); convert_element_type_304 = None + wait_tensor_122 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_103); all_gather_into_tensor_103 = None + permute_100 = torch.ops.aten.permute.default(wait_tensor_122, [1, 0]); wait_tensor_122 = None + mm_64 = torch.ops.aten.mm.default(view_663, permute_100) + view_671 = torch.ops.aten.view.default(mm_64, [2, 
8192, 128]); mm_64 = None + view_678 = torch.ops.aten.view.default(mm_65, [2, 8192, 128]); mm_65 = None + view_680 = torch.ops.aten.view.default(view_664, [2, 8192, -1, 128]); view_664 = None + view_681 = torch.ops.aten.view.default(view_671, [2, 8192, -1, 128]); view_671 = None + view_682 = torch.ops.aten.view.default(view_678, [2, 8192, -1, 128]); view_678 = None + convert_element_type_310 = torch.ops.prims.convert_element_type.default(view_680, torch.float32); view_680 = None + view_683 = torch.ops.aten.view.default(convert_element_type_310, [2, 8192, 4, -1, 2]); convert_element_type_310 = None + view_as_complex_18 = torch.ops.aten.view_as_complex.default(view_683); view_683 = None + convert_element_type_311 = torch.ops.prims.convert_element_type.default(view_681, torch.float32); view_681 = None + view_684 = torch.ops.aten.view.default(convert_element_type_311, [2, 8192, 1, -1, 2]); convert_element_type_311 = None + view_as_complex_19 = torch.ops.aten.view_as_complex.default(view_684); view_684 = None + mul_74 = torch.ops.aten.mul.Tensor(view_as_complex_18, view_37); view_as_complex_18 = None + view_as_real_18 = torch.ops.aten.view_as_real.default(mul_74); mul_74 = None + view_686 = torch.ops.aten.view.default(view_as_real_18, [2, 8192, 4, 128]); view_as_real_18 = None + mul_75 = torch.ops.aten.mul.Tensor(view_as_complex_19, view_37); view_as_complex_19 = None + view_as_real_19 = torch.ops.aten.view_as_real.default(mul_75); mul_75 = None + view_687 = torch.ops.aten.view.default(view_as_real_19, [2, 8192, 1, 128]); view_as_real_19 = None + convert_element_type_312 = torch.ops.prims.convert_element_type.default(view_686, torch.bfloat16); view_686 = None + convert_element_type_313 = torch.ops.prims.convert_element_type.default(view_687, torch.bfloat16); view_687 = None + unsqueeze_18 = torch.ops.aten.unsqueeze.default(convert_element_type_313, 3); convert_element_type_313 = None + expand_18 = torch.ops.aten.expand.default(unsqueeze_18, [2, 8192, 1, 4, 128]); 
unsqueeze_18 = None + view_688 = torch.ops.aten.view.default(expand_18, [2, 8192, 4, 128]); expand_18 = None + unsqueeze_19 = torch.ops.aten.unsqueeze.default(view_682, 3); view_682 = None + expand_19 = torch.ops.aten.expand.default(unsqueeze_19, [2, 8192, 1, 4, 128]); unsqueeze_19 = None + view_689 = torch.ops.aten.view.default(expand_19, [2, 8192, 4, 128]); expand_19 = None + permute_102 = torch.ops.aten.permute.default(convert_element_type_312, [0, 2, 1, 3]); convert_element_type_312 = None + permute_103 = torch.ops.aten.permute.default(view_688, [0, 2, 1, 3]); view_688 = None + permute_104 = torch.ops.aten.permute.default(view_689, [0, 2, 1, 3]); view_689 = None + _scaled_dot_product_cudnn_attention_backward_22 = torch.ops.aten._scaled_dot_product_cudnn_attention_backward.default(permute_1077, permute_102, permute_103, permute_104, getitem_449, getitem_450, getitem_455, getitem_456, None, None, None, 8192, 8192, 0.0, True); permute_1077 = permute_102 = permute_103 = permute_104 = getitem_449 = getitem_450 = getitem_455 = getitem_456 = None + getitem_2194 = _scaled_dot_product_cudnn_attention_backward_22[0] + getitem_2195 = _scaled_dot_product_cudnn_attention_backward_22[1] + getitem_2196 = _scaled_dot_product_cudnn_attention_backward_22[2]; _scaled_dot_product_cudnn_attention_backward_22 = None + permute_1078 = torch.ops.aten.permute.default(getitem_2196, [0, 2, 1, 3]); getitem_2196 = None + permute_1079 = torch.ops.aten.permute.default(getitem_2195, [0, 2, 1, 3]); getitem_2195 = None + permute_1080 = torch.ops.aten.permute.default(getitem_2194, [0, 2, 1, 3]); getitem_2194 = None + view_2860 = torch.ops.aten.view.default(permute_1078, [2, 8192, 1, 4, 128]); permute_1078 = None + sum_137 = torch.ops.aten.sum.dim_IntList(view_2860, [3], True); view_2860 = None + squeeze_44 = torch.ops.aten.squeeze.dim(sum_137, 3); sum_137 = None + view_2861 = torch.ops.aten.view.default(permute_1079, [2, 8192, 1, 4, 128]); permute_1079 = None + sum_138 = 
torch.ops.aten.sum.dim_IntList(view_2861, [3], True); view_2861 = None + squeeze_45 = torch.ops.aten.squeeze.dim(sum_138, 3); sum_138 = None + convert_element_type_2291 = torch.ops.prims.convert_element_type.default(squeeze_45, torch.float32); squeeze_45 = None + convert_element_type_2292 = torch.ops.prims.convert_element_type.default(permute_1080, torch.float32); permute_1080 = None + view_2862 = torch.ops.aten.view.default(convert_element_type_2291, [2, 8192, 1, 64, 2]); convert_element_type_2291 = None + view_as_complex_108 = torch.ops.aten.view_as_complex.default(view_2862); view_2862 = None + mul_716 = torch.ops.aten.mul.Tensor(view_as_complex_108, _conj); view_as_complex_108 = None + view_2863 = torch.ops.aten.view.default(convert_element_type_2292, [2, 8192, 4, 64, 2]); convert_element_type_2292 = None + view_as_complex_109 = torch.ops.aten.view_as_complex.default(view_2863); view_2863 = None + mul_717 = torch.ops.aten.mul.Tensor(view_as_complex_109, _conj); view_as_complex_109 = None + view_as_real_108 = torch.ops.aten.view_as_real.default(mul_716); mul_716 = None + view_2864 = torch.ops.aten.view.default(view_as_real_108, [2, 8192, 1, 128]); view_as_real_108 = None + convert_element_type_2293 = torch.ops.prims.convert_element_type.default(view_2864, torch.bfloat16); view_2864 = None + view_as_real_109 = torch.ops.aten.view_as_real.default(mul_717); mul_717 = None + view_2865 = torch.ops.aten.view.default(view_as_real_109, [2, 8192, 4, 128]); view_as_real_109 = None + convert_element_type_2294 = torch.ops.prims.convert_element_type.default(view_2865, torch.bfloat16); view_2865 = None + view_2866 = torch.ops.aten.view.default(squeeze_44, [2, 8192, 128]); squeeze_44 = None + view_2867 = torch.ops.aten.view.default(convert_element_type_2293, [2, 8192, 128]); convert_element_type_2293 = None + view_2868 = torch.ops.aten.view.default(convert_element_type_2294, [2, 8192, 512]); convert_element_type_2294 = None + view_2869 = torch.ops.aten.view.default(view_2866, 
[16384, 128]); view_2866 = None + permute_1081 = torch.ops.aten.permute.default(view_2869, [1, 0]) + mm_543 = torch.ops.aten.mm.default(permute_1081, view_663); permute_1081 = None + convert_element_type_307 = torch.ops.prims.convert_element_type.default(primals_88, torch.bfloat16); primals_88 = None + all_gather_into_tensor_104 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_307, 32, '0'); convert_element_type_307 = None + wait_tensor_123 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_104); all_gather_into_tensor_104 = None + permute_101 = torch.ops.aten.permute.default(wait_tensor_123, [1, 0]); wait_tensor_123 = None + permute_1083 = torch.ops.aten.permute.default(permute_101, [1, 0]); permute_101 = None + mm_544 = torch.ops.aten.mm.default(view_2869, permute_1083); view_2869 = permute_1083 = None + view_2870 = torch.ops.aten.view.default(mm_544, [2, 8192, 4096]); mm_544 = None + convert_element_type_2299 = torch.ops.prims.convert_element_type.default(mm_543, torch.float32); mm_543 = None + reduce_scatter_tensor_316 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2299, 'avg', 32, '0'); convert_element_type_2299 = None + wait_tensor_764 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_316); reduce_scatter_tensor_316 = None + view_2871 = torch.ops.aten.view.default(view_2867, [16384, 128]); view_2867 = None + permute_1085 = torch.ops.aten.permute.default(view_2871, [1, 0]) + mm_545 = torch.ops.aten.mm.default(permute_1085, view_663); permute_1085 = None + permute_1087 = torch.ops.aten.permute.default(permute_100, [1, 0]); permute_100 = None + mm_546 = torch.ops.aten.mm.default(view_2871, permute_1087); view_2871 = permute_1087 = None + view_2872 = torch.ops.aten.view.default(mm_546, [2, 8192, 4096]); mm_546 = None + add_287 = torch.ops.aten.add.Tensor(view_2870, view_2872); view_2870 = view_2872 = None + convert_element_type_2304 = 
torch.ops.prims.convert_element_type.default(mm_545, torch.float32); mm_545 = None + reduce_scatter_tensor_317 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2304, 'avg', 32, '0'); convert_element_type_2304 = None + wait_tensor_765 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_317); reduce_scatter_tensor_317 = None + view_2873 = torch.ops.aten.view.default(view_2868, [16384, 512]); view_2868 = None + permute_1089 = torch.ops.aten.permute.default(view_2873, [1, 0]) + mm_547 = torch.ops.aten.mm.default(permute_1089, view_663); permute_1089 = view_663 = None + convert_element_type_301 = torch.ops.prims.convert_element_type.default(primals_86, torch.bfloat16); primals_86 = None + all_gather_into_tensor_102 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_301, 32, '0'); convert_element_type_301 = None + wait_tensor_121 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_102); all_gather_into_tensor_102 = None + permute_99 = torch.ops.aten.permute.default(wait_tensor_121, [1, 0]); wait_tensor_121 = None + permute_1091 = torch.ops.aten.permute.default(permute_99, [1, 0]); permute_99 = None + mm_548 = torch.ops.aten.mm.default(view_2873, permute_1091); view_2873 = permute_1091 = None + view_2874 = torch.ops.aten.view.default(mm_548, [2, 8192, 4096]); mm_548 = None + add_288 = torch.ops.aten.add.Tensor(add_287, view_2874); add_287 = view_2874 = None + convert_element_type_2309 = torch.ops.prims.convert_element_type.default(mm_547, torch.float32); mm_547 = None + reduce_scatter_tensor_318 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2309, 'avg', 32, '0'); convert_element_type_2309 = None + wait_tensor_766 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_318); reduce_scatter_tensor_318 = None + split_230 = torch.ops.aten.split.Tensor(add_288, 1024, 1); add_288 = None + getitem_2197 = split_230[0] + getitem_2198 
= split_230[1] + getitem_2199 = split_230[2] + getitem_2200 = split_230[3] + getitem_2201 = split_230[4] + getitem_2202 = split_230[5] + getitem_2203 = split_230[6] + getitem_2204 = split_230[7]; split_230 = None + cat_222 = torch.ops.aten.cat.default([getitem_2197, getitem_2198, getitem_2199, getitem_2200, getitem_2201, getitem_2202, getitem_2203, getitem_2204]); getitem_2197 = getitem_2198 = getitem_2199 = getitem_2200 = getitem_2201 = getitem_2202 = getitem_2203 = getitem_2204 = None + reduce_scatter_tensor_319 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_222, 'sum', 8, '1'); cat_222 = None + wait_tensor_767 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_319); reduce_scatter_tensor_319 = None + convert_element_type_2310 = torch.ops.prims.convert_element_type.default(wait_tensor_767, torch.float32); wait_tensor_767 = None + convert_element_type_2312 = torch.ops.prims.convert_element_type.default(wait_tensor_119, torch.float32); wait_tensor_119 = None + mul_718 = torch.ops.aten.mul.Tensor(convert_element_type_2310, convert_element_type_2312); convert_element_type_2312 = None + mul_720 = torch.ops.aten.mul.Tensor(mul_72, mul_718) + sum_139 = torch.ops.aten.sum.dim_IntList(mul_720, [2], True); mul_720 = None + div_46 = torch.ops.aten.div.Tensor(mul_72, 4096) + mul_721 = torch.ops.aten.mul.Tensor(div_46, sum_139); div_46 = sum_139 = None + sub_70 = torch.ops.aten.sub.Tensor(mul_718, mul_721); mul_718 = mul_721 = None + mul_722 = torch.ops.aten.mul.Tensor(sub_70, rsqrt_18); sub_70 = rsqrt_18 = None + mul_723 = torch.ops.aten.mul.Tensor(convert_element_type_2310, mul_72); convert_element_type_2310 = mul_72 = None + sum_140 = torch.ops.aten.sum.dim_IntList(mul_723, [0, 1]); mul_723 = None + convert_element_type_2313 = torch.ops.prims.convert_element_type.default(mul_722, torch.bfloat16); mul_722 = None + convert_element_type_2314 = torch.ops.prims.convert_element_type.default(sum_140, torch.bfloat16); sum_140 = None + 
all_reduce_46 = torch.ops._c10d_functional.all_reduce.default(convert_element_type_2314, 'sum', '1'); convert_element_type_2314 = None + wait_tensor_768 = torch.ops._c10d_functional.wait_tensor.default(all_reduce_46); all_reduce_46 = None + convert_element_type_2315 = torch.ops.prims.convert_element_type.default(wait_tensor_768, torch.float32); wait_tensor_768 = None + reduce_scatter_tensor_320 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2315, 'avg', 32, '0'); convert_element_type_2315 = None + wait_tensor_769 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_320); reduce_scatter_tensor_320 = None + add_289 = torch.ops.aten.add.Tensor(add_286, convert_element_type_2313); add_286 = convert_element_type_2313 = None + all_gather_into_tensor_402 = torch.ops._c10d_functional.all_gather_into_tensor.default(add_289, 8, '1') + wait_tensor_770 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_402); all_gather_into_tensor_402 = None + split_231 = torch.ops.aten.split.Tensor(wait_tensor_770, 2); wait_tensor_770 = None + getitem_2205 = split_231[0] + getitem_2206 = split_231[1] + getitem_2207 = split_231[2] + getitem_2208 = split_231[3] + getitem_2209 = split_231[4] + getitem_2210 = split_231[5] + getitem_2211 = split_231[6] + getitem_2212 = split_231[7]; split_231 = None + cat_223 = torch.ops.aten.cat.default([getitem_2205, getitem_2206, getitem_2207, getitem_2208, getitem_2209, getitem_2210, getitem_2211, getitem_2212], 1); getitem_2205 = getitem_2206 = getitem_2207 = getitem_2208 = getitem_2209 = getitem_2210 = getitem_2211 = getitem_2212 = None + view_2875 = torch.ops.aten.view.default(cat_223, [16384, 4096]); cat_223 = None + permute_1093 = torch.ops.aten.permute.default(view_2875, [1, 0]) + wait_tensor_112 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_17); reduce_scatter_tensor_17 = None + add_33 = torch.ops.aten.add.Tensor(add_31, wait_tensor_112); wait_tensor_112 = 
None + convert_element_type_284 = torch.ops.prims.convert_element_type.default(primals_81, torch.bfloat16); primals_81 = None + all_gather_into_tensor_95 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_284, 32, '0'); convert_element_type_284 = None + wait_tensor_113 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_95); all_gather_into_tensor_95 = None + convert_element_type_285 = torch.ops.prims.convert_element_type.default(add_33, torch.float32); add_33 = None + pow_18 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_285, 2) + mean_17 = torch.ops.aten.mean.dim(pow_18, [2], True); pow_18 = None + add_34 = torch.ops.aten.add.Scalar(mean_17, 1e-05); mean_17 = None + rsqrt_17 = torch.ops.aten.rsqrt.default(add_34); add_34 = None + mul_68 = torch.ops.aten.mul.Tensor(convert_element_type_285, rsqrt_17); convert_element_type_285 = None + mul_69 = torch.ops.aten.mul.Tensor(mul_68, wait_tensor_113) + convert_element_type_286 = torch.ops.prims.convert_element_type.default(mul_69, torch.bfloat16); mul_69 = None + all_gather_into_tensor_96 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_286, 8, '1'); convert_element_type_286 = None + wait_tensor_114 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_96); all_gather_into_tensor_96 = None + split_43 = torch.ops.aten.split.Tensor(wait_tensor_114, 2); wait_tensor_114 = None + getitem_425 = split_43[0] + getitem_426 = split_43[1] + getitem_427 = split_43[2] + getitem_428 = split_43[3] + getitem_429 = split_43[4] + getitem_430 = split_43[5] + getitem_431 = split_43[6] + getitem_432 = split_43[7]; split_43 = None + cat_35 = torch.ops.aten.cat.default([getitem_425, getitem_426, getitem_427, getitem_428, getitem_429, getitem_430, getitem_431, getitem_432], 1); getitem_425 = getitem_426 = getitem_427 = getitem_428 = getitem_429 = getitem_430 = getitem_431 = getitem_432 = None + view_636 = 
torch.ops.aten.view.default(cat_35, [16384, 4096]); cat_35 = None + view_637 = torch.ops.aten.view.default(mm_60, [2, 8192, 1792]); mm_60 = None + convert_element_type_290 = torch.ops.prims.convert_element_type.default(view_637, torch.float32); view_637 = None + sigmoid_8 = torch.ops.aten.sigmoid.default(convert_element_type_290) + mul_70 = torch.ops.aten.mul.Tensor(convert_element_type_290, sigmoid_8); sigmoid_8 = None + convert_element_type_291 = torch.ops.prims.convert_element_type.default(mul_70, torch.bfloat16); mul_70 = None + convert_element_type_292 = torch.ops.prims.convert_element_type.default(primals_83, torch.bfloat16); primals_83 = None + all_gather_into_tensor_98 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_292, 32, '0'); convert_element_type_292 = None + wait_tensor_116 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_98); all_gather_into_tensor_98 = None + permute_97 = torch.ops.aten.permute.default(wait_tensor_116, [1, 0]); wait_tensor_116 = None + mm_61 = torch.ops.aten.mm.default(view_636, permute_97) + view_644 = torch.ops.aten.view.default(mm_61, [2, 8192, 1792]); mm_61 = None + mul_71 = torch.ops.aten.mul.Tensor(convert_element_type_291, view_644) + view_651 = torch.ops.aten.view.default(mul_71, [16384, 1792]); mul_71 = None + mm_549 = torch.ops.aten.mm.default(permute_1093, view_651); permute_1093 = view_651 = None + convert_element_type_295 = torch.ops.prims.convert_element_type.default(primals_84, torch.bfloat16); primals_84 = None + all_gather_into_tensor_99 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_295, 32, '0'); convert_element_type_295 = None + wait_tensor_117 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_99); all_gather_into_tensor_99 = None + permute_98 = torch.ops.aten.permute.default(wait_tensor_117, [1, 0]); wait_tensor_117 = None + permute_1095 = torch.ops.aten.permute.default(permute_98, [1, 0]); permute_98 
= None + mm_550 = torch.ops.aten.mm.default(view_2875, permute_1095); view_2875 = permute_1095 = None + view_2876 = torch.ops.aten.view.default(mm_550, [2, 8192, 1792]); mm_550 = None + convert_element_type_2320 = torch.ops.prims.convert_element_type.default(mm_549, torch.float32); mm_549 = None + reduce_scatter_tensor_321 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2320, 'avg', 32, '0'); convert_element_type_2320 = None + wait_tensor_771 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_321); reduce_scatter_tensor_321 = None + mul_724 = torch.ops.aten.mul.Tensor(view_2876, convert_element_type_291); convert_element_type_291 = None + mul_725 = torch.ops.aten.mul.Tensor(view_2876, view_644); view_2876 = view_644 = None + view_2877 = torch.ops.aten.view.default(mul_724, [16384, 1792]); mul_724 = None + permute_1097 = torch.ops.aten.permute.default(view_2877, [1, 0]) + mm_551 = torch.ops.aten.mm.default(permute_1097, view_636); permute_1097 = None + permute_1099 = torch.ops.aten.permute.default(permute_97, [1, 0]); permute_97 = None + mm_552 = torch.ops.aten.mm.default(view_2877, permute_1099); view_2877 = permute_1099 = None + view_2878 = torch.ops.aten.view.default(mm_552, [2, 8192, 4096]); mm_552 = None + convert_element_type_2325 = torch.ops.prims.convert_element_type.default(mm_551, torch.float32); mm_551 = None + reduce_scatter_tensor_322 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2325, 'avg', 32, '0'); convert_element_type_2325 = None + wait_tensor_772 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_322); reduce_scatter_tensor_322 = None + convert_element_type_2326 = torch.ops.prims.convert_element_type.default(mul_725, torch.float32); mul_725 = None + neg_23 = torch.ops.aten.neg.default(convert_element_type_290) + exp_23 = torch.ops.aten.exp.default(neg_23); neg_23 = None + add_290 = torch.ops.aten.add.Tensor(exp_23, 1); exp_23 = None + 
reciprocal_23 = torch.ops.aten.reciprocal.default(add_290); add_290 = None + mul_726 = torch.ops.aten.mul.Tensor(reciprocal_23, 1); reciprocal_23 = None + mul_727 = torch.ops.aten.mul.Tensor(convert_element_type_2326, mul_726); convert_element_type_2326 = None + sub_71 = torch.ops.aten.sub.Tensor(1, mul_726); mul_726 = None + mul_728 = torch.ops.aten.mul.Tensor(convert_element_type_290, sub_71); convert_element_type_290 = sub_71 = None + add_291 = torch.ops.aten.add.Tensor(mul_728, 1); mul_728 = None + mul_729 = torch.ops.aten.mul.Tensor(mul_727, add_291); mul_727 = add_291 = None + convert_element_type_2328 = torch.ops.prims.convert_element_type.default(mul_729, torch.bfloat16); mul_729 = None + view_2879 = torch.ops.aten.view.default(convert_element_type_2328, [16384, 1792]); convert_element_type_2328 = None + permute_1101 = torch.ops.aten.permute.default(view_2879, [1, 0]) + mm_553 = torch.ops.aten.mm.default(permute_1101, view_636); permute_1101 = view_636 = None + convert_element_type_287 = torch.ops.prims.convert_element_type.default(primals_82, torch.bfloat16); primals_82 = None + all_gather_into_tensor_97 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_287, 32, '0'); convert_element_type_287 = None + wait_tensor_115 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_97); all_gather_into_tensor_97 = None + permute_96 = torch.ops.aten.permute.default(wait_tensor_115, [1, 0]); wait_tensor_115 = None + permute_1103 = torch.ops.aten.permute.default(permute_96, [1, 0]); permute_96 = None + mm_554 = torch.ops.aten.mm.default(view_2879, permute_1103); view_2879 = permute_1103 = None + view_2880 = torch.ops.aten.view.default(mm_554, [2, 8192, 4096]); mm_554 = None + add_292 = torch.ops.aten.add.Tensor(view_2878, view_2880); view_2878 = view_2880 = None + convert_element_type_2333 = torch.ops.prims.convert_element_type.default(mm_553, torch.float32); mm_553 = None + reduce_scatter_tensor_323 = 
torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2333, 'avg', 32, '0'); convert_element_type_2333 = None + wait_tensor_773 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_323); reduce_scatter_tensor_323 = None + split_232 = torch.ops.aten.split.Tensor(add_292, 1024, 1); add_292 = None + getitem_2213 = split_232[0] + getitem_2214 = split_232[1] + getitem_2215 = split_232[2] + getitem_2216 = split_232[3] + getitem_2217 = split_232[4] + getitem_2218 = split_232[5] + getitem_2219 = split_232[6] + getitem_2220 = split_232[7]; split_232 = None + cat_224 = torch.ops.aten.cat.default([getitem_2213, getitem_2214, getitem_2215, getitem_2216, getitem_2217, getitem_2218, getitem_2219, getitem_2220]); getitem_2213 = getitem_2214 = getitem_2215 = getitem_2216 = getitem_2217 = getitem_2218 = getitem_2219 = getitem_2220 = None + reduce_scatter_tensor_324 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_224, 'sum', 8, '1'); cat_224 = None + wait_tensor_774 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_324); reduce_scatter_tensor_324 = None + convert_element_type_2334 = torch.ops.prims.convert_element_type.default(wait_tensor_774, torch.float32); wait_tensor_774 = None + convert_element_type_2336 = torch.ops.prims.convert_element_type.default(wait_tensor_113, torch.float32); wait_tensor_113 = None + mul_730 = torch.ops.aten.mul.Tensor(convert_element_type_2334, convert_element_type_2336); convert_element_type_2336 = None + mul_732 = torch.ops.aten.mul.Tensor(mul_68, mul_730) + sum_141 = torch.ops.aten.sum.dim_IntList(mul_732, [2], True); mul_732 = None + div_47 = torch.ops.aten.div.Tensor(mul_68, 4096) + mul_733 = torch.ops.aten.mul.Tensor(div_47, sum_141); div_47 = sum_141 = None + sub_72 = torch.ops.aten.sub.Tensor(mul_730, mul_733); mul_730 = mul_733 = None + mul_734 = torch.ops.aten.mul.Tensor(sub_72, rsqrt_17); sub_72 = rsqrt_17 = None + mul_735 = 
torch.ops.aten.mul.Tensor(convert_element_type_2334, mul_68); convert_element_type_2334 = mul_68 = None + sum_142 = torch.ops.aten.sum.dim_IntList(mul_735, [0, 1]); mul_735 = None + convert_element_type_2337 = torch.ops.prims.convert_element_type.default(mul_734, torch.bfloat16); mul_734 = None + convert_element_type_2338 = torch.ops.prims.convert_element_type.default(sum_142, torch.bfloat16); sum_142 = None + all_reduce_47 = torch.ops._c10d_functional.all_reduce.default(convert_element_type_2338, 'sum', '1'); convert_element_type_2338 = None + wait_tensor_775 = torch.ops._c10d_functional.wait_tensor.default(all_reduce_47); all_reduce_47 = None + convert_element_type_2339 = torch.ops.prims.convert_element_type.default(wait_tensor_775, torch.float32); wait_tensor_775 = None + reduce_scatter_tensor_325 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2339, 'avg', 32, '0'); convert_element_type_2339 = None + wait_tensor_776 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_325); reduce_scatter_tensor_325 = None + add_293 = torch.ops.aten.add.Tensor(add_289, convert_element_type_2337); add_289 = convert_element_type_2337 = None + all_gather_into_tensor_403 = torch.ops._c10d_functional.all_gather_into_tensor.default(add_293, 8, '1') + wait_tensor_777 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_403); all_gather_into_tensor_403 = None + split_233 = torch.ops.aten.split.Tensor(wait_tensor_777, 2); wait_tensor_777 = None + getitem_2221 = split_233[0] + getitem_2222 = split_233[1] + getitem_2223 = split_233[2] + getitem_2224 = split_233[3] + getitem_2225 = split_233[4] + getitem_2226 = split_233[5] + getitem_2227 = split_233[6] + getitem_2228 = split_233[7]; split_233 = None + cat_225 = torch.ops.aten.cat.default([getitem_2221, getitem_2222, getitem_2223, getitem_2224, getitem_2225, getitem_2226, getitem_2227, getitem_2228], 1); getitem_2221 = getitem_2222 = getitem_2223 = getitem_2224 = 
getitem_2225 = getitem_2226 = getitem_2227 = getitem_2228 = None + view_2881 = torch.ops.aten.view.default(cat_225, [16384, 4096]); cat_225 = None + permute_1105 = torch.ops.aten.permute.default(view_2881, [1, 0]) + permute_94 = torch.ops.aten.permute.default(getitem_408, [0, 2, 1, 3]) + view_618 = torch.ops.aten.view.default(permute_94, [2, 8192, -1]); permute_94 = None + view_624 = torch.ops.aten.view.default(view_618, [16384, 512]); view_618 = None + mm_555 = torch.ops.aten.mm.default(permute_1105, view_624); permute_1105 = view_624 = None + convert_element_type_281 = torch.ops.prims.convert_element_type.default(primals_80, torch.bfloat16); primals_80 = None + all_gather_into_tensor_94 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_281, 32, '0'); convert_element_type_281 = None + wait_tensor_111 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_94); all_gather_into_tensor_94 = None + permute_95 = torch.ops.aten.permute.default(wait_tensor_111, [1, 0]); wait_tensor_111 = None + permute_1107 = torch.ops.aten.permute.default(permute_95, [1, 0]); permute_95 = None + mm_556 = torch.ops.aten.mm.default(view_2881, permute_1107); view_2881 = permute_1107 = None + view_2882 = torch.ops.aten.view.default(mm_556, [2, 8192, 512]); mm_556 = None + convert_element_type_2344 = torch.ops.prims.convert_element_type.default(mm_555, torch.float32); mm_555 = None + reduce_scatter_tensor_326 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2344, 'avg', 32, '0'); convert_element_type_2344 = None + wait_tensor_778 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_326); reduce_scatter_tensor_326 = None + view_2883 = torch.ops.aten.view.default(view_2882, [2, 8192, 4, 128]); view_2882 = None + permute_1109 = torch.ops.aten.permute.default(view_2883, [0, 2, 1, 3]); view_2883 = None + convert_element_type_265 = torch.ops.prims.convert_element_type.default(primals_76, 
torch.bfloat16); primals_76 = None + all_gather_into_tensor_89 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_265, 32, '0'); convert_element_type_265 = None + wait_tensor_106 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_89); all_gather_into_tensor_89 = None + convert_element_type_266 = torch.ops.prims.convert_element_type.default(add_31, torch.float32); add_31 = None + pow_17 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_266, 2) + mean_16 = torch.ops.aten.mean.dim(pow_17, [2], True); pow_17 = None + add_32 = torch.ops.aten.add.Scalar(mean_16, 1e-05); mean_16 = None + rsqrt_16 = torch.ops.aten.rsqrt.default(add_32); add_32 = None + mul_64 = torch.ops.aten.mul.Tensor(convert_element_type_266, rsqrt_16); convert_element_type_266 = None + mul_65 = torch.ops.aten.mul.Tensor(mul_64, wait_tensor_106) + convert_element_type_267 = torch.ops.prims.convert_element_type.default(mul_65, torch.bfloat16); mul_65 = None + all_gather_into_tensor_90 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_267, 8, '1'); convert_element_type_267 = None + wait_tensor_107 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_90); all_gather_into_tensor_90 = None + split_41 = torch.ops.aten.split.Tensor(wait_tensor_107, 2); wait_tensor_107 = None + getitem_400 = split_41[0] + getitem_401 = split_41[1] + getitem_402 = split_41[2] + getitem_403 = split_41[3] + getitem_404 = split_41[4] + getitem_405 = split_41[5] + getitem_406 = split_41[6] + getitem_407 = split_41[7]; split_41 = None + cat_33 = torch.ops.aten.cat.default([getitem_400, getitem_401, getitem_402, getitem_403, getitem_404, getitem_405, getitem_406, getitem_407], 1); getitem_400 = getitem_401 = getitem_402 = getitem_403 = getitem_404 = getitem_405 = getitem_406 = getitem_407 = None + view_591 = torch.ops.aten.view.default(cat_33, [16384, 4096]); cat_33 = None + view_592 = torch.ops.aten.view.default(mm_56, [2, 8192, 
512]); mm_56 = None + convert_element_type_271 = torch.ops.prims.convert_element_type.default(primals_78, torch.bfloat16); primals_78 = None + all_gather_into_tensor_92 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_271, 32, '0'); convert_element_type_271 = None + wait_tensor_109 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_92); all_gather_into_tensor_92 = None + permute_89 = torch.ops.aten.permute.default(wait_tensor_109, [1, 0]); wait_tensor_109 = None + mm_57 = torch.ops.aten.mm.default(view_591, permute_89) + view_599 = torch.ops.aten.view.default(mm_57, [2, 8192, 128]); mm_57 = None + view_606 = torch.ops.aten.view.default(mm_58, [2, 8192, 128]); mm_58 = None + view_608 = torch.ops.aten.view.default(view_592, [2, 8192, -1, 128]); view_592 = None + view_609 = torch.ops.aten.view.default(view_599, [2, 8192, -1, 128]); view_599 = None + view_610 = torch.ops.aten.view.default(view_606, [2, 8192, -1, 128]); view_606 = None + convert_element_type_277 = torch.ops.prims.convert_element_type.default(view_608, torch.float32); view_608 = None + view_611 = torch.ops.aten.view.default(convert_element_type_277, [2, 8192, 4, -1, 2]); convert_element_type_277 = None + view_as_complex_16 = torch.ops.aten.view_as_complex.default(view_611); view_611 = None + convert_element_type_278 = torch.ops.prims.convert_element_type.default(view_609, torch.float32); view_609 = None + view_612 = torch.ops.aten.view.default(convert_element_type_278, [2, 8192, 1, -1, 2]); convert_element_type_278 = None + view_as_complex_17 = torch.ops.aten.view_as_complex.default(view_612); view_612 = None + mul_66 = torch.ops.aten.mul.Tensor(view_as_complex_16, view_37); view_as_complex_16 = None + view_as_real_16 = torch.ops.aten.view_as_real.default(mul_66); mul_66 = None + view_614 = torch.ops.aten.view.default(view_as_real_16, [2, 8192, 4, 128]); view_as_real_16 = None + mul_67 = torch.ops.aten.mul.Tensor(view_as_complex_17, view_37); 
view_as_complex_17 = None + view_as_real_17 = torch.ops.aten.view_as_real.default(mul_67); mul_67 = None + view_615 = torch.ops.aten.view.default(view_as_real_17, [2, 8192, 1, 128]); view_as_real_17 = None + convert_element_type_279 = torch.ops.prims.convert_element_type.default(view_614, torch.bfloat16); view_614 = None + convert_element_type_280 = torch.ops.prims.convert_element_type.default(view_615, torch.bfloat16); view_615 = None + unsqueeze_16 = torch.ops.aten.unsqueeze.default(convert_element_type_280, 3); convert_element_type_280 = None + expand_16 = torch.ops.aten.expand.default(unsqueeze_16, [2, 8192, 1, 4, 128]); unsqueeze_16 = None + view_616 = torch.ops.aten.view.default(expand_16, [2, 8192, 4, 128]); expand_16 = None + unsqueeze_17 = torch.ops.aten.unsqueeze.default(view_610, 3); view_610 = None + expand_17 = torch.ops.aten.expand.default(unsqueeze_17, [2, 8192, 1, 4, 128]); unsqueeze_17 = None + view_617 = torch.ops.aten.view.default(expand_17, [2, 8192, 4, 128]); expand_17 = None + permute_91 = torch.ops.aten.permute.default(convert_element_type_279, [0, 2, 1, 3]); convert_element_type_279 = None + permute_92 = torch.ops.aten.permute.default(view_616, [0, 2, 1, 3]); view_616 = None + permute_93 = torch.ops.aten.permute.default(view_617, [0, 2, 1, 3]); view_617 = None + _scaled_dot_product_cudnn_attention_backward_23 = torch.ops.aten._scaled_dot_product_cudnn_attention_backward.default(permute_1109, permute_91, permute_92, permute_93, getitem_408, getitem_409, getitem_414, getitem_415, None, None, None, 8192, 8192, 0.0, True); permute_1109 = permute_91 = permute_92 = permute_93 = getitem_408 = getitem_409 = getitem_414 = getitem_415 = None + getitem_2229 = _scaled_dot_product_cudnn_attention_backward_23[0] + getitem_2230 = _scaled_dot_product_cudnn_attention_backward_23[1] + getitem_2231 = _scaled_dot_product_cudnn_attention_backward_23[2]; _scaled_dot_product_cudnn_attention_backward_23 = None + permute_1110 = 
torch.ops.aten.permute.default(getitem_2231, [0, 2, 1, 3]); getitem_2231 = None + permute_1111 = torch.ops.aten.permute.default(getitem_2230, [0, 2, 1, 3]); getitem_2230 = None + permute_1112 = torch.ops.aten.permute.default(getitem_2229, [0, 2, 1, 3]); getitem_2229 = None + view_2884 = torch.ops.aten.view.default(permute_1110, [2, 8192, 1, 4, 128]); permute_1110 = None + sum_143 = torch.ops.aten.sum.dim_IntList(view_2884, [3], True); view_2884 = None + squeeze_46 = torch.ops.aten.squeeze.dim(sum_143, 3); sum_143 = None + view_2885 = torch.ops.aten.view.default(permute_1111, [2, 8192, 1, 4, 128]); permute_1111 = None + sum_144 = torch.ops.aten.sum.dim_IntList(view_2885, [3], True); view_2885 = None + squeeze_47 = torch.ops.aten.squeeze.dim(sum_144, 3); sum_144 = None + convert_element_type_2345 = torch.ops.prims.convert_element_type.default(squeeze_47, torch.float32); squeeze_47 = None + convert_element_type_2346 = torch.ops.prims.convert_element_type.default(permute_1112, torch.float32); permute_1112 = None + view_2886 = torch.ops.aten.view.default(convert_element_type_2345, [2, 8192, 1, 64, 2]); convert_element_type_2345 = None + view_as_complex_110 = torch.ops.aten.view_as_complex.default(view_2886); view_2886 = None + mul_736 = torch.ops.aten.mul.Tensor(view_as_complex_110, _conj); view_as_complex_110 = None + view_2887 = torch.ops.aten.view.default(convert_element_type_2346, [2, 8192, 4, 64, 2]); convert_element_type_2346 = None + view_as_complex_111 = torch.ops.aten.view_as_complex.default(view_2887); view_2887 = None + mul_737 = torch.ops.aten.mul.Tensor(view_as_complex_111, _conj); view_as_complex_111 = None + view_as_real_110 = torch.ops.aten.view_as_real.default(mul_736); mul_736 = None + view_2888 = torch.ops.aten.view.default(view_as_real_110, [2, 8192, 1, 128]); view_as_real_110 = None + convert_element_type_2347 = torch.ops.prims.convert_element_type.default(view_2888, torch.bfloat16); view_2888 = None + view_as_real_111 = 
torch.ops.aten.view_as_real.default(mul_737); mul_737 = None + view_2889 = torch.ops.aten.view.default(view_as_real_111, [2, 8192, 4, 128]); view_as_real_111 = None + convert_element_type_2348 = torch.ops.prims.convert_element_type.default(view_2889, torch.bfloat16); view_2889 = None + view_2890 = torch.ops.aten.view.default(squeeze_46, [2, 8192, 128]); squeeze_46 = None + view_2891 = torch.ops.aten.view.default(convert_element_type_2347, [2, 8192, 128]); convert_element_type_2347 = None + view_2892 = torch.ops.aten.view.default(convert_element_type_2348, [2, 8192, 512]); convert_element_type_2348 = None + view_2893 = torch.ops.aten.view.default(view_2890, [16384, 128]); view_2890 = None + permute_1113 = torch.ops.aten.permute.default(view_2893, [1, 0]) + mm_557 = torch.ops.aten.mm.default(permute_1113, view_591); permute_1113 = None + convert_element_type_274 = torch.ops.prims.convert_element_type.default(primals_79, torch.bfloat16); primals_79 = None + all_gather_into_tensor_93 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_274, 32, '0'); convert_element_type_274 = None + wait_tensor_110 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_93); all_gather_into_tensor_93 = None + permute_90 = torch.ops.aten.permute.default(wait_tensor_110, [1, 0]); wait_tensor_110 = None + permute_1115 = torch.ops.aten.permute.default(permute_90, [1, 0]); permute_90 = None + mm_558 = torch.ops.aten.mm.default(view_2893, permute_1115); view_2893 = permute_1115 = None + view_2894 = torch.ops.aten.view.default(mm_558, [2, 8192, 4096]); mm_558 = None + convert_element_type_2353 = torch.ops.prims.convert_element_type.default(mm_557, torch.float32); mm_557 = None + reduce_scatter_tensor_327 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2353, 'avg', 32, '0'); convert_element_type_2353 = None + wait_tensor_779 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_327); 
reduce_scatter_tensor_327 = None + view_2895 = torch.ops.aten.view.default(view_2891, [16384, 128]); view_2891 = None + permute_1117 = torch.ops.aten.permute.default(view_2895, [1, 0]) + mm_559 = torch.ops.aten.mm.default(permute_1117, view_591); permute_1117 = None + permute_1119 = torch.ops.aten.permute.default(permute_89, [1, 0]); permute_89 = None + mm_560 = torch.ops.aten.mm.default(view_2895, permute_1119); view_2895 = permute_1119 = None + view_2896 = torch.ops.aten.view.default(mm_560, [2, 8192, 4096]); mm_560 = None + add_294 = torch.ops.aten.add.Tensor(view_2894, view_2896); view_2894 = view_2896 = None + convert_element_type_2358 = torch.ops.prims.convert_element_type.default(mm_559, torch.float32); mm_559 = None + reduce_scatter_tensor_328 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2358, 'avg', 32, '0'); convert_element_type_2358 = None + wait_tensor_780 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_328); reduce_scatter_tensor_328 = None + view_2897 = torch.ops.aten.view.default(view_2892, [16384, 512]); view_2892 = None + permute_1121 = torch.ops.aten.permute.default(view_2897, [1, 0]) + mm_561 = torch.ops.aten.mm.default(permute_1121, view_591); permute_1121 = view_591 = None + convert_element_type_268 = torch.ops.prims.convert_element_type.default(primals_77, torch.bfloat16); primals_77 = None + all_gather_into_tensor_91 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_268, 32, '0'); convert_element_type_268 = None + wait_tensor_108 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_91); all_gather_into_tensor_91 = None + permute_88 = torch.ops.aten.permute.default(wait_tensor_108, [1, 0]); wait_tensor_108 = None + permute_1123 = torch.ops.aten.permute.default(permute_88, [1, 0]); permute_88 = None + mm_562 = torch.ops.aten.mm.default(view_2897, permute_1123); view_2897 = permute_1123 = None + view_2898 = 
torch.ops.aten.view.default(mm_562, [2, 8192, 4096]); mm_562 = None + add_295 = torch.ops.aten.add.Tensor(add_294, view_2898); add_294 = view_2898 = None + convert_element_type_2363 = torch.ops.prims.convert_element_type.default(mm_561, torch.float32); mm_561 = None + reduce_scatter_tensor_329 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2363, 'avg', 32, '0'); convert_element_type_2363 = None + wait_tensor_781 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_329); reduce_scatter_tensor_329 = None + split_234 = torch.ops.aten.split.Tensor(add_295, 1024, 1); add_295 = None + getitem_2232 = split_234[0] + getitem_2233 = split_234[1] + getitem_2234 = split_234[2] + getitem_2235 = split_234[3] + getitem_2236 = split_234[4] + getitem_2237 = split_234[5] + getitem_2238 = split_234[6] + getitem_2239 = split_234[7]; split_234 = None + cat_226 = torch.ops.aten.cat.default([getitem_2232, getitem_2233, getitem_2234, getitem_2235, getitem_2236, getitem_2237, getitem_2238, getitem_2239]); getitem_2232 = getitem_2233 = getitem_2234 = getitem_2235 = getitem_2236 = getitem_2237 = getitem_2238 = getitem_2239 = None + reduce_scatter_tensor_330 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_226, 'sum', 8, '1'); cat_226 = None + wait_tensor_782 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_330); reduce_scatter_tensor_330 = None + convert_element_type_2364 = torch.ops.prims.convert_element_type.default(wait_tensor_782, torch.float32); wait_tensor_782 = None + convert_element_type_2366 = torch.ops.prims.convert_element_type.default(wait_tensor_106, torch.float32); wait_tensor_106 = None + mul_738 = torch.ops.aten.mul.Tensor(convert_element_type_2364, convert_element_type_2366); convert_element_type_2366 = None + mul_740 = torch.ops.aten.mul.Tensor(mul_64, mul_738) + sum_145 = torch.ops.aten.sum.dim_IntList(mul_740, [2], True); mul_740 = None + div_48 = torch.ops.aten.div.Tensor(mul_64, 4096) 
+ mul_741 = torch.ops.aten.mul.Tensor(div_48, sum_145); div_48 = sum_145 = None + sub_73 = torch.ops.aten.sub.Tensor(mul_738, mul_741); mul_738 = mul_741 = None + mul_742 = torch.ops.aten.mul.Tensor(sub_73, rsqrt_16); sub_73 = rsqrt_16 = None + mul_743 = torch.ops.aten.mul.Tensor(convert_element_type_2364, mul_64); convert_element_type_2364 = mul_64 = None + sum_146 = torch.ops.aten.sum.dim_IntList(mul_743, [0, 1]); mul_743 = None + convert_element_type_2367 = torch.ops.prims.convert_element_type.default(mul_742, torch.bfloat16); mul_742 = None + convert_element_type_2368 = torch.ops.prims.convert_element_type.default(sum_146, torch.bfloat16); sum_146 = None + all_reduce_48 = torch.ops._c10d_functional.all_reduce.default(convert_element_type_2368, 'sum', '1'); convert_element_type_2368 = None + wait_tensor_783 = torch.ops._c10d_functional.wait_tensor.default(all_reduce_48); all_reduce_48 = None + convert_element_type_2369 = torch.ops.prims.convert_element_type.default(wait_tensor_783, torch.float32); wait_tensor_783 = None + reduce_scatter_tensor_331 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2369, 'avg', 32, '0'); convert_element_type_2369 = None + wait_tensor_784 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_331); reduce_scatter_tensor_331 = None + add_296 = torch.ops.aten.add.Tensor(add_293, convert_element_type_2367); add_293 = convert_element_type_2367 = None + all_gather_into_tensor_404 = torch.ops._c10d_functional.all_gather_into_tensor.default(add_296, 8, '1') + wait_tensor_785 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_404); all_gather_into_tensor_404 = None + split_235 = torch.ops.aten.split.Tensor(wait_tensor_785, 2); wait_tensor_785 = None + getitem_2240 = split_235[0] + getitem_2241 = split_235[1] + getitem_2242 = split_235[2] + getitem_2243 = split_235[3] + getitem_2244 = split_235[4] + getitem_2245 = split_235[5] + getitem_2246 = split_235[6] + getitem_2247 = 
split_235[7]; split_235 = None + cat_227 = torch.ops.aten.cat.default([getitem_2240, getitem_2241, getitem_2242, getitem_2243, getitem_2244, getitem_2245, getitem_2246, getitem_2247], 1); getitem_2240 = getitem_2241 = getitem_2242 = getitem_2243 = getitem_2244 = getitem_2245 = getitem_2246 = getitem_2247 = None + view_2899 = torch.ops.aten.view.default(cat_227, [16384, 4096]); cat_227 = None + permute_1125 = torch.ops.aten.permute.default(view_2899, [1, 0]) + wait_tensor_99 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_15); reduce_scatter_tensor_15 = None + add_29 = torch.ops.aten.add.Tensor(add_27, wait_tensor_99); wait_tensor_99 = None + convert_element_type_251 = torch.ops.prims.convert_element_type.default(primals_72, torch.bfloat16); primals_72 = None + all_gather_into_tensor_84 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_251, 32, '0'); convert_element_type_251 = None + wait_tensor_100 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_84); all_gather_into_tensor_84 = None + convert_element_type_252 = torch.ops.prims.convert_element_type.default(add_29, torch.float32); add_29 = None + pow_16 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_252, 2) + mean_15 = torch.ops.aten.mean.dim(pow_16, [2], True); pow_16 = None + add_30 = torch.ops.aten.add.Scalar(mean_15, 1e-05); mean_15 = None + rsqrt_15 = torch.ops.aten.rsqrt.default(add_30); add_30 = None + mul_60 = torch.ops.aten.mul.Tensor(convert_element_type_252, rsqrt_15); convert_element_type_252 = None + mul_61 = torch.ops.aten.mul.Tensor(mul_60, wait_tensor_100) + convert_element_type_253 = torch.ops.prims.convert_element_type.default(mul_61, torch.bfloat16); mul_61 = None + all_gather_into_tensor_85 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_253, 8, '1'); convert_element_type_253 = None + wait_tensor_101 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_85); 
all_gather_into_tensor_85 = None + split_39 = torch.ops.aten.split.Tensor(wait_tensor_101, 2); wait_tensor_101 = None + getitem_384 = split_39[0] + getitem_385 = split_39[1] + getitem_386 = split_39[2] + getitem_387 = split_39[3] + getitem_388 = split_39[4] + getitem_389 = split_39[5] + getitem_390 = split_39[6] + getitem_391 = split_39[7]; split_39 = None + cat_31 = torch.ops.aten.cat.default([getitem_384, getitem_385, getitem_386, getitem_387, getitem_388, getitem_389, getitem_390, getitem_391], 1); getitem_384 = getitem_385 = getitem_386 = getitem_387 = getitem_388 = getitem_389 = getitem_390 = getitem_391 = None + view_564 = torch.ops.aten.view.default(cat_31, [16384, 4096]); cat_31 = None + view_565 = torch.ops.aten.view.default(mm_53, [2, 8192, 1792]); mm_53 = None + convert_element_type_257 = torch.ops.prims.convert_element_type.default(view_565, torch.float32); view_565 = None + sigmoid_7 = torch.ops.aten.sigmoid.default(convert_element_type_257) + mul_62 = torch.ops.aten.mul.Tensor(convert_element_type_257, sigmoid_7); sigmoid_7 = None + convert_element_type_258 = torch.ops.prims.convert_element_type.default(mul_62, torch.bfloat16); mul_62 = None + convert_element_type_259 = torch.ops.prims.convert_element_type.default(primals_74, torch.bfloat16); primals_74 = None + all_gather_into_tensor_87 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_259, 32, '0'); convert_element_type_259 = None + wait_tensor_103 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_87); all_gather_into_tensor_87 = None + permute_86 = torch.ops.aten.permute.default(wait_tensor_103, [1, 0]); wait_tensor_103 = None + mm_54 = torch.ops.aten.mm.default(view_564, permute_86) + view_572 = torch.ops.aten.view.default(mm_54, [2, 8192, 1792]); mm_54 = None + mul_63 = torch.ops.aten.mul.Tensor(convert_element_type_258, view_572) + view_579 = torch.ops.aten.view.default(mul_63, [16384, 1792]); mul_63 = None + mm_563 = 
torch.ops.aten.mm.default(permute_1125, view_579); permute_1125 = view_579 = None + convert_element_type_262 = torch.ops.prims.convert_element_type.default(primals_75, torch.bfloat16); primals_75 = None + all_gather_into_tensor_88 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_262, 32, '0'); convert_element_type_262 = None + wait_tensor_104 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_88); all_gather_into_tensor_88 = None + permute_87 = torch.ops.aten.permute.default(wait_tensor_104, [1, 0]); wait_tensor_104 = None + permute_1127 = torch.ops.aten.permute.default(permute_87, [1, 0]); permute_87 = None + mm_564 = torch.ops.aten.mm.default(view_2899, permute_1127); view_2899 = permute_1127 = None + view_2900 = torch.ops.aten.view.default(mm_564, [2, 8192, 1792]); mm_564 = None + convert_element_type_2374 = torch.ops.prims.convert_element_type.default(mm_563, torch.float32); mm_563 = None + reduce_scatter_tensor_332 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2374, 'avg', 32, '0'); convert_element_type_2374 = None + wait_tensor_786 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_332); reduce_scatter_tensor_332 = None + mul_744 = torch.ops.aten.mul.Tensor(view_2900, convert_element_type_258); convert_element_type_258 = None + mul_745 = torch.ops.aten.mul.Tensor(view_2900, view_572); view_2900 = view_572 = None + view_2901 = torch.ops.aten.view.default(mul_744, [16384, 1792]); mul_744 = None + permute_1129 = torch.ops.aten.permute.default(view_2901, [1, 0]) + mm_565 = torch.ops.aten.mm.default(permute_1129, view_564); permute_1129 = None + permute_1131 = torch.ops.aten.permute.default(permute_86, [1, 0]); permute_86 = None + mm_566 = torch.ops.aten.mm.default(view_2901, permute_1131); view_2901 = permute_1131 = None + view_2902 = torch.ops.aten.view.default(mm_566, [2, 8192, 4096]); mm_566 = None + convert_element_type_2379 = 
torch.ops.prims.convert_element_type.default(mm_565, torch.float32); mm_565 = None + reduce_scatter_tensor_333 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2379, 'avg', 32, '0'); convert_element_type_2379 = None + wait_tensor_787 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_333); reduce_scatter_tensor_333 = None + convert_element_type_2380 = torch.ops.prims.convert_element_type.default(mul_745, torch.float32); mul_745 = None + neg_24 = torch.ops.aten.neg.default(convert_element_type_257) + exp_24 = torch.ops.aten.exp.default(neg_24); neg_24 = None + add_297 = torch.ops.aten.add.Tensor(exp_24, 1); exp_24 = None + reciprocal_24 = torch.ops.aten.reciprocal.default(add_297); add_297 = None + mul_746 = torch.ops.aten.mul.Tensor(reciprocal_24, 1); reciprocal_24 = None + mul_747 = torch.ops.aten.mul.Tensor(convert_element_type_2380, mul_746); convert_element_type_2380 = None + sub_74 = torch.ops.aten.sub.Tensor(1, mul_746); mul_746 = None + mul_748 = torch.ops.aten.mul.Tensor(convert_element_type_257, sub_74); convert_element_type_257 = sub_74 = None + add_298 = torch.ops.aten.add.Tensor(mul_748, 1); mul_748 = None + mul_749 = torch.ops.aten.mul.Tensor(mul_747, add_298); mul_747 = add_298 = None + convert_element_type_2382 = torch.ops.prims.convert_element_type.default(mul_749, torch.bfloat16); mul_749 = None + view_2903 = torch.ops.aten.view.default(convert_element_type_2382, [16384, 1792]); convert_element_type_2382 = None + permute_1133 = torch.ops.aten.permute.default(view_2903, [1, 0]) + mm_567 = torch.ops.aten.mm.default(permute_1133, view_564); permute_1133 = view_564 = None + convert_element_type_254 = torch.ops.prims.convert_element_type.default(primals_73, torch.bfloat16); primals_73 = None + all_gather_into_tensor_86 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_254, 32, '0'); convert_element_type_254 = None + wait_tensor_102 = 
torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_86); all_gather_into_tensor_86 = None + permute_85 = torch.ops.aten.permute.default(wait_tensor_102, [1, 0]); wait_tensor_102 = None + permute_1135 = torch.ops.aten.permute.default(permute_85, [1, 0]); permute_85 = None + mm_568 = torch.ops.aten.mm.default(view_2903, permute_1135); view_2903 = permute_1135 = None + view_2904 = torch.ops.aten.view.default(mm_568, [2, 8192, 4096]); mm_568 = None + add_299 = torch.ops.aten.add.Tensor(view_2902, view_2904); view_2902 = view_2904 = None + convert_element_type_2387 = torch.ops.prims.convert_element_type.default(mm_567, torch.float32); mm_567 = None + reduce_scatter_tensor_334 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2387, 'avg', 32, '0'); convert_element_type_2387 = None + wait_tensor_788 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_334); reduce_scatter_tensor_334 = None + split_236 = torch.ops.aten.split.Tensor(add_299, 1024, 1); add_299 = None + getitem_2248 = split_236[0] + getitem_2249 = split_236[1] + getitem_2250 = split_236[2] + getitem_2251 = split_236[3] + getitem_2252 = split_236[4] + getitem_2253 = split_236[5] + getitem_2254 = split_236[6] + getitem_2255 = split_236[7]; split_236 = None + cat_228 = torch.ops.aten.cat.default([getitem_2248, getitem_2249, getitem_2250, getitem_2251, getitem_2252, getitem_2253, getitem_2254, getitem_2255]); getitem_2248 = getitem_2249 = getitem_2250 = getitem_2251 = getitem_2252 = getitem_2253 = getitem_2254 = getitem_2255 = None + reduce_scatter_tensor_335 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_228, 'sum', 8, '1'); cat_228 = None + wait_tensor_789 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_335); reduce_scatter_tensor_335 = None + convert_element_type_2388 = torch.ops.prims.convert_element_type.default(wait_tensor_789, torch.float32); wait_tensor_789 = None + convert_element_type_2390 = 
torch.ops.prims.convert_element_type.default(wait_tensor_100, torch.float32); wait_tensor_100 = None + mul_750 = torch.ops.aten.mul.Tensor(convert_element_type_2388, convert_element_type_2390); convert_element_type_2390 = None + mul_752 = torch.ops.aten.mul.Tensor(mul_60, mul_750) + sum_147 = torch.ops.aten.sum.dim_IntList(mul_752, [2], True); mul_752 = None + div_49 = torch.ops.aten.div.Tensor(mul_60, 4096) + mul_753 = torch.ops.aten.mul.Tensor(div_49, sum_147); div_49 = sum_147 = None + sub_75 = torch.ops.aten.sub.Tensor(mul_750, mul_753); mul_750 = mul_753 = None + mul_754 = torch.ops.aten.mul.Tensor(sub_75, rsqrt_15); sub_75 = rsqrt_15 = None + mul_755 = torch.ops.aten.mul.Tensor(convert_element_type_2388, mul_60); convert_element_type_2388 = mul_60 = None + sum_148 = torch.ops.aten.sum.dim_IntList(mul_755, [0, 1]); mul_755 = None + convert_element_type_2391 = torch.ops.prims.convert_element_type.default(mul_754, torch.bfloat16); mul_754 = None + convert_element_type_2392 = torch.ops.prims.convert_element_type.default(sum_148, torch.bfloat16); sum_148 = None + all_reduce_49 = torch.ops._c10d_functional.all_reduce.default(convert_element_type_2392, 'sum', '1'); convert_element_type_2392 = None + wait_tensor_790 = torch.ops._c10d_functional.wait_tensor.default(all_reduce_49); all_reduce_49 = None + convert_element_type_2393 = torch.ops.prims.convert_element_type.default(wait_tensor_790, torch.float32); wait_tensor_790 = None + reduce_scatter_tensor_336 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2393, 'avg', 32, '0'); convert_element_type_2393 = None + wait_tensor_791 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_336); reduce_scatter_tensor_336 = None + add_300 = torch.ops.aten.add.Tensor(add_296, convert_element_type_2391); add_296 = convert_element_type_2391 = None + all_gather_into_tensor_405 = torch.ops._c10d_functional.all_gather_into_tensor.default(add_300, 8, '1') + wait_tensor_792 = 
torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_405); all_gather_into_tensor_405 = None + split_237 = torch.ops.aten.split.Tensor(wait_tensor_792, 2); wait_tensor_792 = None + getitem_2256 = split_237[0] + getitem_2257 = split_237[1] + getitem_2258 = split_237[2] + getitem_2259 = split_237[3] + getitem_2260 = split_237[4] + getitem_2261 = split_237[5] + getitem_2262 = split_237[6] + getitem_2263 = split_237[7]; split_237 = None + cat_229 = torch.ops.aten.cat.default([getitem_2256, getitem_2257, getitem_2258, getitem_2259, getitem_2260, getitem_2261, getitem_2262, getitem_2263], 1); getitem_2256 = getitem_2257 = getitem_2258 = getitem_2259 = getitem_2260 = getitem_2261 = getitem_2262 = getitem_2263 = None + view_2905 = torch.ops.aten.view.default(cat_229, [16384, 4096]); cat_229 = None + permute_1137 = torch.ops.aten.permute.default(view_2905, [1, 0]) + permute_83 = torch.ops.aten.permute.default(getitem_367, [0, 2, 1, 3]) + view_546 = torch.ops.aten.view.default(permute_83, [2, 8192, -1]); permute_83 = None + view_552 = torch.ops.aten.view.default(view_546, [16384, 512]); view_546 = None + mm_569 = torch.ops.aten.mm.default(permute_1137, view_552); permute_1137 = view_552 = None + convert_element_type_248 = torch.ops.prims.convert_element_type.default(primals_71, torch.bfloat16); primals_71 = None + all_gather_into_tensor_83 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_248, 32, '0'); convert_element_type_248 = None + wait_tensor_98 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_83); all_gather_into_tensor_83 = None + permute_84 = torch.ops.aten.permute.default(wait_tensor_98, [1, 0]); wait_tensor_98 = None + permute_1139 = torch.ops.aten.permute.default(permute_84, [1, 0]); permute_84 = None + mm_570 = torch.ops.aten.mm.default(view_2905, permute_1139); view_2905 = permute_1139 = None + view_2906 = torch.ops.aten.view.default(mm_570, [2, 8192, 512]); mm_570 = None + 
convert_element_type_2398 = torch.ops.prims.convert_element_type.default(mm_569, torch.float32); mm_569 = None + reduce_scatter_tensor_337 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2398, 'avg', 32, '0'); convert_element_type_2398 = None + wait_tensor_793 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_337); reduce_scatter_tensor_337 = None + view_2907 = torch.ops.aten.view.default(view_2906, [2, 8192, 4, 128]); view_2906 = None + permute_1141 = torch.ops.aten.permute.default(view_2907, [0, 2, 1, 3]); view_2907 = None + convert_element_type_232 = torch.ops.prims.convert_element_type.default(primals_67, torch.bfloat16); primals_67 = None + all_gather_into_tensor_78 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_232, 32, '0'); convert_element_type_232 = None + wait_tensor_93 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_78); all_gather_into_tensor_78 = None + convert_element_type_233 = torch.ops.prims.convert_element_type.default(add_27, torch.float32); add_27 = None + pow_15 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_233, 2) + mean_14 = torch.ops.aten.mean.dim(pow_15, [2], True); pow_15 = None + add_28 = torch.ops.aten.add.Scalar(mean_14, 1e-05); mean_14 = None + rsqrt_14 = torch.ops.aten.rsqrt.default(add_28); add_28 = None + mul_56 = torch.ops.aten.mul.Tensor(convert_element_type_233, rsqrt_14); convert_element_type_233 = None + mul_57 = torch.ops.aten.mul.Tensor(mul_56, wait_tensor_93) + convert_element_type_234 = torch.ops.prims.convert_element_type.default(mul_57, torch.bfloat16); mul_57 = None + all_gather_into_tensor_79 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_234, 8, '1'); convert_element_type_234 = None + wait_tensor_94 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_79); all_gather_into_tensor_79 = None + split_37 = torch.ops.aten.split.Tensor(wait_tensor_94, 2); 
wait_tensor_94 = None + getitem_359 = split_37[0] + getitem_360 = split_37[1] + getitem_361 = split_37[2] + getitem_362 = split_37[3] + getitem_363 = split_37[4] + getitem_364 = split_37[5] + getitem_365 = split_37[6] + getitem_366 = split_37[7]; split_37 = None + cat_29 = torch.ops.aten.cat.default([getitem_359, getitem_360, getitem_361, getitem_362, getitem_363, getitem_364, getitem_365, getitem_366], 1); getitem_359 = getitem_360 = getitem_361 = getitem_362 = getitem_363 = getitem_364 = getitem_365 = getitem_366 = None + view_519 = torch.ops.aten.view.default(cat_29, [16384, 4096]); cat_29 = None + view_520 = torch.ops.aten.view.default(mm_49, [2, 8192, 512]); mm_49 = None + convert_element_type_238 = torch.ops.prims.convert_element_type.default(primals_69, torch.bfloat16); primals_69 = None + all_gather_into_tensor_81 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_238, 32, '0'); convert_element_type_238 = None + wait_tensor_96 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_81); all_gather_into_tensor_81 = None + permute_78 = torch.ops.aten.permute.default(wait_tensor_96, [1, 0]); wait_tensor_96 = None + mm_50 = torch.ops.aten.mm.default(view_519, permute_78) + view_527 = torch.ops.aten.view.default(mm_50, [2, 8192, 128]); mm_50 = None + view_534 = torch.ops.aten.view.default(mm_51, [2, 8192, 128]); mm_51 = None + view_536 = torch.ops.aten.view.default(view_520, [2, 8192, -1, 128]); view_520 = None + view_537 = torch.ops.aten.view.default(view_527, [2, 8192, -1, 128]); view_527 = None + view_538 = torch.ops.aten.view.default(view_534, [2, 8192, -1, 128]); view_534 = None + convert_element_type_244 = torch.ops.prims.convert_element_type.default(view_536, torch.float32); view_536 = None + view_539 = torch.ops.aten.view.default(convert_element_type_244, [2, 8192, 4, -1, 2]); convert_element_type_244 = None + view_as_complex_14 = torch.ops.aten.view_as_complex.default(view_539); view_539 = None + 
convert_element_type_245 = torch.ops.prims.convert_element_type.default(view_537, torch.float32); view_537 = None + view_540 = torch.ops.aten.view.default(convert_element_type_245, [2, 8192, 1, -1, 2]); convert_element_type_245 = None + view_as_complex_15 = torch.ops.aten.view_as_complex.default(view_540); view_540 = None + mul_58 = torch.ops.aten.mul.Tensor(view_as_complex_14, view_37); view_as_complex_14 = None + view_as_real_14 = torch.ops.aten.view_as_real.default(mul_58); mul_58 = None + view_542 = torch.ops.aten.view.default(view_as_real_14, [2, 8192, 4, 128]); view_as_real_14 = None + mul_59 = torch.ops.aten.mul.Tensor(view_as_complex_15, view_37); view_as_complex_15 = None + view_as_real_15 = torch.ops.aten.view_as_real.default(mul_59); mul_59 = None + view_543 = torch.ops.aten.view.default(view_as_real_15, [2, 8192, 1, 128]); view_as_real_15 = None + convert_element_type_246 = torch.ops.prims.convert_element_type.default(view_542, torch.bfloat16); view_542 = None + convert_element_type_247 = torch.ops.prims.convert_element_type.default(view_543, torch.bfloat16); view_543 = None + unsqueeze_14 = torch.ops.aten.unsqueeze.default(convert_element_type_247, 3); convert_element_type_247 = None + expand_14 = torch.ops.aten.expand.default(unsqueeze_14, [2, 8192, 1, 4, 128]); unsqueeze_14 = None + view_544 = torch.ops.aten.view.default(expand_14, [2, 8192, 4, 128]); expand_14 = None + unsqueeze_15 = torch.ops.aten.unsqueeze.default(view_538, 3); view_538 = None + expand_15 = torch.ops.aten.expand.default(unsqueeze_15, [2, 8192, 1, 4, 128]); unsqueeze_15 = None + view_545 = torch.ops.aten.view.default(expand_15, [2, 8192, 4, 128]); expand_15 = None + permute_80 = torch.ops.aten.permute.default(convert_element_type_246, [0, 2, 1, 3]); convert_element_type_246 = None + permute_81 = torch.ops.aten.permute.default(view_544, [0, 2, 1, 3]); view_544 = None + permute_82 = torch.ops.aten.permute.default(view_545, [0, 2, 1, 3]); view_545 = None + 
_scaled_dot_product_cudnn_attention_backward_24 = torch.ops.aten._scaled_dot_product_cudnn_attention_backward.default(permute_1141, permute_80, permute_81, permute_82, getitem_367, getitem_368, getitem_373, getitem_374, None, None, None, 8192, 8192, 0.0, True); permute_1141 = permute_80 = permute_81 = permute_82 = getitem_367 = getitem_368 = getitem_373 = getitem_374 = None + getitem_2264 = _scaled_dot_product_cudnn_attention_backward_24[0] + getitem_2265 = _scaled_dot_product_cudnn_attention_backward_24[1] + getitem_2266 = _scaled_dot_product_cudnn_attention_backward_24[2]; _scaled_dot_product_cudnn_attention_backward_24 = None + permute_1142 = torch.ops.aten.permute.default(getitem_2266, [0, 2, 1, 3]); getitem_2266 = None + permute_1143 = torch.ops.aten.permute.default(getitem_2265, [0, 2, 1, 3]); getitem_2265 = None + permute_1144 = torch.ops.aten.permute.default(getitem_2264, [0, 2, 1, 3]); getitem_2264 = None + view_2908 = torch.ops.aten.view.default(permute_1142, [2, 8192, 1, 4, 128]); permute_1142 = None + sum_149 = torch.ops.aten.sum.dim_IntList(view_2908, [3], True); view_2908 = None + squeeze_48 = torch.ops.aten.squeeze.dim(sum_149, 3); sum_149 = None + view_2909 = torch.ops.aten.view.default(permute_1143, [2, 8192, 1, 4, 128]); permute_1143 = None + sum_150 = torch.ops.aten.sum.dim_IntList(view_2909, [3], True); view_2909 = None + squeeze_49 = torch.ops.aten.squeeze.dim(sum_150, 3); sum_150 = None + convert_element_type_2399 = torch.ops.prims.convert_element_type.default(squeeze_49, torch.float32); squeeze_49 = None + convert_element_type_2400 = torch.ops.prims.convert_element_type.default(permute_1144, torch.float32); permute_1144 = None + view_2910 = torch.ops.aten.view.default(convert_element_type_2399, [2, 8192, 1, 64, 2]); convert_element_type_2399 = None + view_as_complex_112 = torch.ops.aten.view_as_complex.default(view_2910); view_2910 = None + mul_756 = torch.ops.aten.mul.Tensor(view_as_complex_112, _conj); view_as_complex_112 = None + view_2911 
= torch.ops.aten.view.default(convert_element_type_2400, [2, 8192, 4, 64, 2]); convert_element_type_2400 = None + view_as_complex_113 = torch.ops.aten.view_as_complex.default(view_2911); view_2911 = None + mul_757 = torch.ops.aten.mul.Tensor(view_as_complex_113, _conj); view_as_complex_113 = None + view_as_real_112 = torch.ops.aten.view_as_real.default(mul_756); mul_756 = None + view_2912 = torch.ops.aten.view.default(view_as_real_112, [2, 8192, 1, 128]); view_as_real_112 = None + convert_element_type_2401 = torch.ops.prims.convert_element_type.default(view_2912, torch.bfloat16); view_2912 = None + view_as_real_113 = torch.ops.aten.view_as_real.default(mul_757); mul_757 = None + view_2913 = torch.ops.aten.view.default(view_as_real_113, [2, 8192, 4, 128]); view_as_real_113 = None + convert_element_type_2402 = torch.ops.prims.convert_element_type.default(view_2913, torch.bfloat16); view_2913 = None + view_2914 = torch.ops.aten.view.default(squeeze_48, [2, 8192, 128]); squeeze_48 = None + view_2915 = torch.ops.aten.view.default(convert_element_type_2401, [2, 8192, 128]); convert_element_type_2401 = None + view_2916 = torch.ops.aten.view.default(convert_element_type_2402, [2, 8192, 512]); convert_element_type_2402 = None + view_2917 = torch.ops.aten.view.default(view_2914, [16384, 128]); view_2914 = None + permute_1145 = torch.ops.aten.permute.default(view_2917, [1, 0]) + mm_571 = torch.ops.aten.mm.default(permute_1145, view_519); permute_1145 = None + convert_element_type_241 = torch.ops.prims.convert_element_type.default(primals_70, torch.bfloat16); primals_70 = None + all_gather_into_tensor_82 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_241, 32, '0'); convert_element_type_241 = None + wait_tensor_97 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_82); all_gather_into_tensor_82 = None + permute_79 = torch.ops.aten.permute.default(wait_tensor_97, [1, 0]); wait_tensor_97 = None + permute_1147 = 
torch.ops.aten.permute.default(permute_79, [1, 0]); permute_79 = None + mm_572 = torch.ops.aten.mm.default(view_2917, permute_1147); view_2917 = permute_1147 = None + view_2918 = torch.ops.aten.view.default(mm_572, [2, 8192, 4096]); mm_572 = None + convert_element_type_2407 = torch.ops.prims.convert_element_type.default(mm_571, torch.float32); mm_571 = None + reduce_scatter_tensor_338 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2407, 'avg', 32, '0'); convert_element_type_2407 = None + wait_tensor_794 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_338); reduce_scatter_tensor_338 = None + view_2919 = torch.ops.aten.view.default(view_2915, [16384, 128]); view_2915 = None + permute_1149 = torch.ops.aten.permute.default(view_2919, [1, 0]) + mm_573 = torch.ops.aten.mm.default(permute_1149, view_519); permute_1149 = None + permute_1151 = torch.ops.aten.permute.default(permute_78, [1, 0]); permute_78 = None + mm_574 = torch.ops.aten.mm.default(view_2919, permute_1151); view_2919 = permute_1151 = None + view_2920 = torch.ops.aten.view.default(mm_574, [2, 8192, 4096]); mm_574 = None + add_301 = torch.ops.aten.add.Tensor(view_2918, view_2920); view_2918 = view_2920 = None + convert_element_type_2412 = torch.ops.prims.convert_element_type.default(mm_573, torch.float32); mm_573 = None + reduce_scatter_tensor_339 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2412, 'avg', 32, '0'); convert_element_type_2412 = None + wait_tensor_795 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_339); reduce_scatter_tensor_339 = None + view_2921 = torch.ops.aten.view.default(view_2916, [16384, 512]); view_2916 = None + permute_1153 = torch.ops.aten.permute.default(view_2921, [1, 0]) + mm_575 = torch.ops.aten.mm.default(permute_1153, view_519); permute_1153 = view_519 = None + convert_element_type_235 = torch.ops.prims.convert_element_type.default(primals_68, torch.bfloat16); 
primals_68 = None + all_gather_into_tensor_80 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_235, 32, '0'); convert_element_type_235 = None + wait_tensor_95 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_80); all_gather_into_tensor_80 = None + permute_77 = torch.ops.aten.permute.default(wait_tensor_95, [1, 0]); wait_tensor_95 = None + permute_1155 = torch.ops.aten.permute.default(permute_77, [1, 0]); permute_77 = None + mm_576 = torch.ops.aten.mm.default(view_2921, permute_1155); view_2921 = permute_1155 = None + view_2922 = torch.ops.aten.view.default(mm_576, [2, 8192, 4096]); mm_576 = None + add_302 = torch.ops.aten.add.Tensor(add_301, view_2922); add_301 = view_2922 = None + convert_element_type_2417 = torch.ops.prims.convert_element_type.default(mm_575, torch.float32); mm_575 = None + reduce_scatter_tensor_340 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2417, 'avg', 32, '0'); convert_element_type_2417 = None + wait_tensor_796 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_340); reduce_scatter_tensor_340 = None + split_238 = torch.ops.aten.split.Tensor(add_302, 1024, 1); add_302 = None + getitem_2267 = split_238[0] + getitem_2268 = split_238[1] + getitem_2269 = split_238[2] + getitem_2270 = split_238[3] + getitem_2271 = split_238[4] + getitem_2272 = split_238[5] + getitem_2273 = split_238[6] + getitem_2274 = split_238[7]; split_238 = None + cat_230 = torch.ops.aten.cat.default([getitem_2267, getitem_2268, getitem_2269, getitem_2270, getitem_2271, getitem_2272, getitem_2273, getitem_2274]); getitem_2267 = getitem_2268 = getitem_2269 = getitem_2270 = getitem_2271 = getitem_2272 = getitem_2273 = getitem_2274 = None + reduce_scatter_tensor_341 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_230, 'sum', 8, '1'); cat_230 = None + wait_tensor_797 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_341); 
reduce_scatter_tensor_341 = None + convert_element_type_2418 = torch.ops.prims.convert_element_type.default(wait_tensor_797, torch.float32); wait_tensor_797 = None + convert_element_type_2420 = torch.ops.prims.convert_element_type.default(wait_tensor_93, torch.float32); wait_tensor_93 = None + mul_758 = torch.ops.aten.mul.Tensor(convert_element_type_2418, convert_element_type_2420); convert_element_type_2420 = None + mul_760 = torch.ops.aten.mul.Tensor(mul_56, mul_758) + sum_151 = torch.ops.aten.sum.dim_IntList(mul_760, [2], True); mul_760 = None + div_50 = torch.ops.aten.div.Tensor(mul_56, 4096) + mul_761 = torch.ops.aten.mul.Tensor(div_50, sum_151); div_50 = sum_151 = None + sub_76 = torch.ops.aten.sub.Tensor(mul_758, mul_761); mul_758 = mul_761 = None + mul_762 = torch.ops.aten.mul.Tensor(sub_76, rsqrt_14); sub_76 = rsqrt_14 = None + mul_763 = torch.ops.aten.mul.Tensor(convert_element_type_2418, mul_56); convert_element_type_2418 = mul_56 = None + sum_152 = torch.ops.aten.sum.dim_IntList(mul_763, [0, 1]); mul_763 = None + convert_element_type_2421 = torch.ops.prims.convert_element_type.default(mul_762, torch.bfloat16); mul_762 = None + convert_element_type_2422 = torch.ops.prims.convert_element_type.default(sum_152, torch.bfloat16); sum_152 = None + all_reduce_50 = torch.ops._c10d_functional.all_reduce.default(convert_element_type_2422, 'sum', '1'); convert_element_type_2422 = None + wait_tensor_798 = torch.ops._c10d_functional.wait_tensor.default(all_reduce_50); all_reduce_50 = None + convert_element_type_2423 = torch.ops.prims.convert_element_type.default(wait_tensor_798, torch.float32); wait_tensor_798 = None + reduce_scatter_tensor_342 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2423, 'avg', 32, '0'); convert_element_type_2423 = None + wait_tensor_799 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_342); reduce_scatter_tensor_342 = None + add_303 = torch.ops.aten.add.Tensor(add_300, 
convert_element_type_2421); add_300 = convert_element_type_2421 = None + all_gather_into_tensor_406 = torch.ops._c10d_functional.all_gather_into_tensor.default(add_303, 8, '1') + wait_tensor_800 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_406); all_gather_into_tensor_406 = None + split_239 = torch.ops.aten.split.Tensor(wait_tensor_800, 2); wait_tensor_800 = None + getitem_2275 = split_239[0] + getitem_2276 = split_239[1] + getitem_2277 = split_239[2] + getitem_2278 = split_239[3] + getitem_2279 = split_239[4] + getitem_2280 = split_239[5] + getitem_2281 = split_239[6] + getitem_2282 = split_239[7]; split_239 = None + cat_231 = torch.ops.aten.cat.default([getitem_2275, getitem_2276, getitem_2277, getitem_2278, getitem_2279, getitem_2280, getitem_2281, getitem_2282], 1); getitem_2275 = getitem_2276 = getitem_2277 = getitem_2278 = getitem_2279 = getitem_2280 = getitem_2281 = getitem_2282 = None + view_2923 = torch.ops.aten.view.default(cat_231, [16384, 4096]); cat_231 = None + permute_1157 = torch.ops.aten.permute.default(view_2923, [1, 0]) + wait_tensor_86 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_13); reduce_scatter_tensor_13 = None + add_25 = torch.ops.aten.add.Tensor(add_23, wait_tensor_86); wait_tensor_86 = None + convert_element_type_218 = torch.ops.prims.convert_element_type.default(primals_63, torch.bfloat16); primals_63 = None + all_gather_into_tensor_73 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_218, 32, '0'); convert_element_type_218 = None + wait_tensor_87 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_73); all_gather_into_tensor_73 = None + convert_element_type_219 = torch.ops.prims.convert_element_type.default(add_25, torch.float32); add_25 = None + pow_14 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_219, 2) + mean_13 = torch.ops.aten.mean.dim(pow_14, [2], True); pow_14 = None + add_26 = torch.ops.aten.add.Scalar(mean_13, 
1e-05); mean_13 = None + rsqrt_13 = torch.ops.aten.rsqrt.default(add_26); add_26 = None + mul_52 = torch.ops.aten.mul.Tensor(convert_element_type_219, rsqrt_13); convert_element_type_219 = None + mul_53 = torch.ops.aten.mul.Tensor(mul_52, wait_tensor_87) + convert_element_type_220 = torch.ops.prims.convert_element_type.default(mul_53, torch.bfloat16); mul_53 = None + all_gather_into_tensor_74 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_220, 8, '1'); convert_element_type_220 = None + wait_tensor_88 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_74); all_gather_into_tensor_74 = None + split_35 = torch.ops.aten.split.Tensor(wait_tensor_88, 2); wait_tensor_88 = None + getitem_343 = split_35[0] + getitem_344 = split_35[1] + getitem_345 = split_35[2] + getitem_346 = split_35[3] + getitem_347 = split_35[4] + getitem_348 = split_35[5] + getitem_349 = split_35[6] + getitem_350 = split_35[7]; split_35 = None + cat_27 = torch.ops.aten.cat.default([getitem_343, getitem_344, getitem_345, getitem_346, getitem_347, getitem_348, getitem_349, getitem_350], 1); getitem_343 = getitem_344 = getitem_345 = getitem_346 = getitem_347 = getitem_348 = getitem_349 = getitem_350 = None + view_492 = torch.ops.aten.view.default(cat_27, [16384, 4096]); cat_27 = None + view_493 = torch.ops.aten.view.default(mm_46, [2, 8192, 1792]); mm_46 = None + convert_element_type_224 = torch.ops.prims.convert_element_type.default(view_493, torch.float32); view_493 = None + sigmoid_6 = torch.ops.aten.sigmoid.default(convert_element_type_224) + mul_54 = torch.ops.aten.mul.Tensor(convert_element_type_224, sigmoid_6); sigmoid_6 = None + convert_element_type_225 = torch.ops.prims.convert_element_type.default(mul_54, torch.bfloat16); mul_54 = None + convert_element_type_226 = torch.ops.prims.convert_element_type.default(primals_65, torch.bfloat16); primals_65 = None + all_gather_into_tensor_76 = 
torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_226, 32, '0'); convert_element_type_226 = None + wait_tensor_90 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_76); all_gather_into_tensor_76 = None + permute_75 = torch.ops.aten.permute.default(wait_tensor_90, [1, 0]); wait_tensor_90 = None + mm_47 = torch.ops.aten.mm.default(view_492, permute_75) + view_500 = torch.ops.aten.view.default(mm_47, [2, 8192, 1792]); mm_47 = None + mul_55 = torch.ops.aten.mul.Tensor(convert_element_type_225, view_500) + view_507 = torch.ops.aten.view.default(mul_55, [16384, 1792]); mul_55 = None + mm_577 = torch.ops.aten.mm.default(permute_1157, view_507); permute_1157 = view_507 = None + convert_element_type_229 = torch.ops.prims.convert_element_type.default(primals_66, torch.bfloat16); primals_66 = None + all_gather_into_tensor_77 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_229, 32, '0'); convert_element_type_229 = None + wait_tensor_91 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_77); all_gather_into_tensor_77 = None + permute_76 = torch.ops.aten.permute.default(wait_tensor_91, [1, 0]); wait_tensor_91 = None + permute_1159 = torch.ops.aten.permute.default(permute_76, [1, 0]); permute_76 = None + mm_578 = torch.ops.aten.mm.default(view_2923, permute_1159); view_2923 = permute_1159 = None + view_2924 = torch.ops.aten.view.default(mm_578, [2, 8192, 1792]); mm_578 = None + convert_element_type_2428 = torch.ops.prims.convert_element_type.default(mm_577, torch.float32); mm_577 = None + reduce_scatter_tensor_343 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2428, 'avg', 32, '0'); convert_element_type_2428 = None + wait_tensor_801 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_343); reduce_scatter_tensor_343 = None + mul_764 = torch.ops.aten.mul.Tensor(view_2924, convert_element_type_225); convert_element_type_225 = 
None + mul_765 = torch.ops.aten.mul.Tensor(view_2924, view_500); view_2924 = view_500 = None + view_2925 = torch.ops.aten.view.default(mul_764, [16384, 1792]); mul_764 = None + permute_1161 = torch.ops.aten.permute.default(view_2925, [1, 0]) + mm_579 = torch.ops.aten.mm.default(permute_1161, view_492); permute_1161 = None + permute_1163 = torch.ops.aten.permute.default(permute_75, [1, 0]); permute_75 = None + mm_580 = torch.ops.aten.mm.default(view_2925, permute_1163); view_2925 = permute_1163 = None + view_2926 = torch.ops.aten.view.default(mm_580, [2, 8192, 4096]); mm_580 = None + convert_element_type_2433 = torch.ops.prims.convert_element_type.default(mm_579, torch.float32); mm_579 = None + reduce_scatter_tensor_344 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2433, 'avg', 32, '0'); convert_element_type_2433 = None + wait_tensor_802 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_344); reduce_scatter_tensor_344 = None + convert_element_type_2434 = torch.ops.prims.convert_element_type.default(mul_765, torch.float32); mul_765 = None + neg_25 = torch.ops.aten.neg.default(convert_element_type_224) + exp_25 = torch.ops.aten.exp.default(neg_25); neg_25 = None + add_304 = torch.ops.aten.add.Tensor(exp_25, 1); exp_25 = None + reciprocal_25 = torch.ops.aten.reciprocal.default(add_304); add_304 = None + mul_766 = torch.ops.aten.mul.Tensor(reciprocal_25, 1); reciprocal_25 = None + mul_767 = torch.ops.aten.mul.Tensor(convert_element_type_2434, mul_766); convert_element_type_2434 = None + sub_77 = torch.ops.aten.sub.Tensor(1, mul_766); mul_766 = None + mul_768 = torch.ops.aten.mul.Tensor(convert_element_type_224, sub_77); convert_element_type_224 = sub_77 = None + add_305 = torch.ops.aten.add.Tensor(mul_768, 1); mul_768 = None + mul_769 = torch.ops.aten.mul.Tensor(mul_767, add_305); mul_767 = add_305 = None + convert_element_type_2436 = torch.ops.prims.convert_element_type.default(mul_769, torch.bfloat16); mul_769 = 
None + view_2927 = torch.ops.aten.view.default(convert_element_type_2436, [16384, 1792]); convert_element_type_2436 = None + permute_1165 = torch.ops.aten.permute.default(view_2927, [1, 0]) + mm_581 = torch.ops.aten.mm.default(permute_1165, view_492); permute_1165 = view_492 = None + convert_element_type_221 = torch.ops.prims.convert_element_type.default(primals_64, torch.bfloat16); primals_64 = None + all_gather_into_tensor_75 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_221, 32, '0'); convert_element_type_221 = None + wait_tensor_89 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_75); all_gather_into_tensor_75 = None + permute_74 = torch.ops.aten.permute.default(wait_tensor_89, [1, 0]); wait_tensor_89 = None + permute_1167 = torch.ops.aten.permute.default(permute_74, [1, 0]); permute_74 = None + mm_582 = torch.ops.aten.mm.default(view_2927, permute_1167); view_2927 = permute_1167 = None + view_2928 = torch.ops.aten.view.default(mm_582, [2, 8192, 4096]); mm_582 = None + add_306 = torch.ops.aten.add.Tensor(view_2926, view_2928); view_2926 = view_2928 = None + convert_element_type_2441 = torch.ops.prims.convert_element_type.default(mm_581, torch.float32); mm_581 = None + reduce_scatter_tensor_345 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2441, 'avg', 32, '0'); convert_element_type_2441 = None + wait_tensor_803 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_345); reduce_scatter_tensor_345 = None + split_240 = torch.ops.aten.split.Tensor(add_306, 1024, 1); add_306 = None + getitem_2283 = split_240[0] + getitem_2284 = split_240[1] + getitem_2285 = split_240[2] + getitem_2286 = split_240[3] + getitem_2287 = split_240[4] + getitem_2288 = split_240[5] + getitem_2289 = split_240[6] + getitem_2290 = split_240[7]; split_240 = None + cat_232 = torch.ops.aten.cat.default([getitem_2283, getitem_2284, getitem_2285, getitem_2286, getitem_2287, getitem_2288, 
getitem_2289, getitem_2290]); getitem_2283 = getitem_2284 = getitem_2285 = getitem_2286 = getitem_2287 = getitem_2288 = getitem_2289 = getitem_2290 = None + reduce_scatter_tensor_346 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_232, 'sum', 8, '1'); cat_232 = None + wait_tensor_804 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_346); reduce_scatter_tensor_346 = None + convert_element_type_2442 = torch.ops.prims.convert_element_type.default(wait_tensor_804, torch.float32); wait_tensor_804 = None + convert_element_type_2444 = torch.ops.prims.convert_element_type.default(wait_tensor_87, torch.float32); wait_tensor_87 = None + mul_770 = torch.ops.aten.mul.Tensor(convert_element_type_2442, convert_element_type_2444); convert_element_type_2444 = None + mul_772 = torch.ops.aten.mul.Tensor(mul_52, mul_770) + sum_153 = torch.ops.aten.sum.dim_IntList(mul_772, [2], True); mul_772 = None + div_51 = torch.ops.aten.div.Tensor(mul_52, 4096) + mul_773 = torch.ops.aten.mul.Tensor(div_51, sum_153); div_51 = sum_153 = None + sub_78 = torch.ops.aten.sub.Tensor(mul_770, mul_773); mul_770 = mul_773 = None + mul_774 = torch.ops.aten.mul.Tensor(sub_78, rsqrt_13); sub_78 = rsqrt_13 = None + mul_775 = torch.ops.aten.mul.Tensor(convert_element_type_2442, mul_52); convert_element_type_2442 = mul_52 = None + sum_154 = torch.ops.aten.sum.dim_IntList(mul_775, [0, 1]); mul_775 = None + convert_element_type_2445 = torch.ops.prims.convert_element_type.default(mul_774, torch.bfloat16); mul_774 = None + convert_element_type_2446 = torch.ops.prims.convert_element_type.default(sum_154, torch.bfloat16); sum_154 = None + all_reduce_51 = torch.ops._c10d_functional.all_reduce.default(convert_element_type_2446, 'sum', '1'); convert_element_type_2446 = None + wait_tensor_805 = torch.ops._c10d_functional.wait_tensor.default(all_reduce_51); all_reduce_51 = None + convert_element_type_2447 = torch.ops.prims.convert_element_type.default(wait_tensor_805, torch.float32); 
wait_tensor_805 = None + reduce_scatter_tensor_347 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2447, 'avg', 32, '0'); convert_element_type_2447 = None + wait_tensor_806 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_347); reduce_scatter_tensor_347 = None + add_307 = torch.ops.aten.add.Tensor(add_303, convert_element_type_2445); add_303 = convert_element_type_2445 = None + all_gather_into_tensor_407 = torch.ops._c10d_functional.all_gather_into_tensor.default(add_307, 8, '1') + wait_tensor_807 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_407); all_gather_into_tensor_407 = None + split_241 = torch.ops.aten.split.Tensor(wait_tensor_807, 2); wait_tensor_807 = None + getitem_2291 = split_241[0] + getitem_2292 = split_241[1] + getitem_2293 = split_241[2] + getitem_2294 = split_241[3] + getitem_2295 = split_241[4] + getitem_2296 = split_241[5] + getitem_2297 = split_241[6] + getitem_2298 = split_241[7]; split_241 = None + cat_233 = torch.ops.aten.cat.default([getitem_2291, getitem_2292, getitem_2293, getitem_2294, getitem_2295, getitem_2296, getitem_2297, getitem_2298], 1); getitem_2291 = getitem_2292 = getitem_2293 = getitem_2294 = getitem_2295 = getitem_2296 = getitem_2297 = getitem_2298 = None + view_2929 = torch.ops.aten.view.default(cat_233, [16384, 4096]); cat_233 = None + permute_1169 = torch.ops.aten.permute.default(view_2929, [1, 0]) + permute_72 = torch.ops.aten.permute.default(getitem_326, [0, 2, 1, 3]) + view_474 = torch.ops.aten.view.default(permute_72, [2, 8192, -1]); permute_72 = None + view_480 = torch.ops.aten.view.default(view_474, [16384, 512]); view_474 = None + mm_583 = torch.ops.aten.mm.default(permute_1169, view_480); permute_1169 = view_480 = None + convert_element_type_215 = torch.ops.prims.convert_element_type.default(primals_62, torch.bfloat16); primals_62 = None + all_gather_into_tensor_72 = 
torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_215, 32, '0'); convert_element_type_215 = None + wait_tensor_85 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_72); all_gather_into_tensor_72 = None + permute_73 = torch.ops.aten.permute.default(wait_tensor_85, [1, 0]); wait_tensor_85 = None + permute_1171 = torch.ops.aten.permute.default(permute_73, [1, 0]); permute_73 = None + mm_584 = torch.ops.aten.mm.default(view_2929, permute_1171); view_2929 = permute_1171 = None + view_2930 = torch.ops.aten.view.default(mm_584, [2, 8192, 512]); mm_584 = None + convert_element_type_2452 = torch.ops.prims.convert_element_type.default(mm_583, torch.float32); mm_583 = None + reduce_scatter_tensor_348 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2452, 'avg', 32, '0'); convert_element_type_2452 = None + wait_tensor_808 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_348); reduce_scatter_tensor_348 = None + view_2931 = torch.ops.aten.view.default(view_2930, [2, 8192, 4, 128]); view_2930 = None + permute_1173 = torch.ops.aten.permute.default(view_2931, [0, 2, 1, 3]); view_2931 = None + convert_element_type_199 = torch.ops.prims.convert_element_type.default(primals_58, torch.bfloat16); primals_58 = None + all_gather_into_tensor_67 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_199, 32, '0'); convert_element_type_199 = None + wait_tensor_80 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_67); all_gather_into_tensor_67 = None + convert_element_type_200 = torch.ops.prims.convert_element_type.default(add_23, torch.float32); add_23 = None + pow_13 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_200, 2) + mean_12 = torch.ops.aten.mean.dim(pow_13, [2], True); pow_13 = None + add_24 = torch.ops.aten.add.Scalar(mean_12, 1e-05); mean_12 = None + rsqrt_12 = torch.ops.aten.rsqrt.default(add_24); add_24 = None + mul_48 = 
torch.ops.aten.mul.Tensor(convert_element_type_200, rsqrt_12); convert_element_type_200 = None + mul_49 = torch.ops.aten.mul.Tensor(mul_48, wait_tensor_80) + convert_element_type_201 = torch.ops.prims.convert_element_type.default(mul_49, torch.bfloat16); mul_49 = None + all_gather_into_tensor_68 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_201, 8, '1'); convert_element_type_201 = None + wait_tensor_81 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_68); all_gather_into_tensor_68 = None + split_33 = torch.ops.aten.split.Tensor(wait_tensor_81, 2); wait_tensor_81 = None + getitem_318 = split_33[0] + getitem_319 = split_33[1] + getitem_320 = split_33[2] + getitem_321 = split_33[3] + getitem_322 = split_33[4] + getitem_323 = split_33[5] + getitem_324 = split_33[6] + getitem_325 = split_33[7]; split_33 = None + cat_25 = torch.ops.aten.cat.default([getitem_318, getitem_319, getitem_320, getitem_321, getitem_322, getitem_323, getitem_324, getitem_325], 1); getitem_318 = getitem_319 = getitem_320 = getitem_321 = getitem_322 = getitem_323 = getitem_324 = getitem_325 = None + view_447 = torch.ops.aten.view.default(cat_25, [16384, 4096]); cat_25 = None + view_448 = torch.ops.aten.view.default(mm_42, [2, 8192, 512]); mm_42 = None + convert_element_type_205 = torch.ops.prims.convert_element_type.default(primals_60, torch.bfloat16); primals_60 = None + all_gather_into_tensor_70 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_205, 32, '0'); convert_element_type_205 = None + wait_tensor_83 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_70); all_gather_into_tensor_70 = None + permute_67 = torch.ops.aten.permute.default(wait_tensor_83, [1, 0]); wait_tensor_83 = None + mm_43 = torch.ops.aten.mm.default(view_447, permute_67) + view_455 = torch.ops.aten.view.default(mm_43, [2, 8192, 128]); mm_43 = None + view_462 = torch.ops.aten.view.default(mm_44, [2, 8192, 128]); 
mm_44 = None + view_464 = torch.ops.aten.view.default(view_448, [2, 8192, -1, 128]); view_448 = None + view_465 = torch.ops.aten.view.default(view_455, [2, 8192, -1, 128]); view_455 = None + view_466 = torch.ops.aten.view.default(view_462, [2, 8192, -1, 128]); view_462 = None + convert_element_type_211 = torch.ops.prims.convert_element_type.default(view_464, torch.float32); view_464 = None + view_467 = torch.ops.aten.view.default(convert_element_type_211, [2, 8192, 4, -1, 2]); convert_element_type_211 = None + view_as_complex_12 = torch.ops.aten.view_as_complex.default(view_467); view_467 = None + convert_element_type_212 = torch.ops.prims.convert_element_type.default(view_465, torch.float32); view_465 = None + view_468 = torch.ops.aten.view.default(convert_element_type_212, [2, 8192, 1, -1, 2]); convert_element_type_212 = None + view_as_complex_13 = torch.ops.aten.view_as_complex.default(view_468); view_468 = None + mul_50 = torch.ops.aten.mul.Tensor(view_as_complex_12, view_37); view_as_complex_12 = None + view_as_real_12 = torch.ops.aten.view_as_real.default(mul_50); mul_50 = None + view_470 = torch.ops.aten.view.default(view_as_real_12, [2, 8192, 4, 128]); view_as_real_12 = None + mul_51 = torch.ops.aten.mul.Tensor(view_as_complex_13, view_37); view_as_complex_13 = None + view_as_real_13 = torch.ops.aten.view_as_real.default(mul_51); mul_51 = None + view_471 = torch.ops.aten.view.default(view_as_real_13, [2, 8192, 1, 128]); view_as_real_13 = None + convert_element_type_213 = torch.ops.prims.convert_element_type.default(view_470, torch.bfloat16); view_470 = None + convert_element_type_214 = torch.ops.prims.convert_element_type.default(view_471, torch.bfloat16); view_471 = None + unsqueeze_12 = torch.ops.aten.unsqueeze.default(convert_element_type_214, 3); convert_element_type_214 = None + expand_12 = torch.ops.aten.expand.default(unsqueeze_12, [2, 8192, 1, 4, 128]); unsqueeze_12 = None + view_472 = torch.ops.aten.view.default(expand_12, [2, 8192, 4, 128]); 
expand_12 = None + unsqueeze_13 = torch.ops.aten.unsqueeze.default(view_466, 3); view_466 = None + expand_13 = torch.ops.aten.expand.default(unsqueeze_13, [2, 8192, 1, 4, 128]); unsqueeze_13 = None + view_473 = torch.ops.aten.view.default(expand_13, [2, 8192, 4, 128]); expand_13 = None + permute_69 = torch.ops.aten.permute.default(convert_element_type_213, [0, 2, 1, 3]); convert_element_type_213 = None + permute_70 = torch.ops.aten.permute.default(view_472, [0, 2, 1, 3]); view_472 = None + permute_71 = torch.ops.aten.permute.default(view_473, [0, 2, 1, 3]); view_473 = None + _scaled_dot_product_cudnn_attention_backward_25 = torch.ops.aten._scaled_dot_product_cudnn_attention_backward.default(permute_1173, permute_69, permute_70, permute_71, getitem_326, getitem_327, getitem_332, getitem_333, None, None, None, 8192, 8192, 0.0, True); permute_1173 = permute_69 = permute_70 = permute_71 = getitem_326 = getitem_327 = getitem_332 = getitem_333 = None + getitem_2299 = _scaled_dot_product_cudnn_attention_backward_25[0] + getitem_2300 = _scaled_dot_product_cudnn_attention_backward_25[1] + getitem_2301 = _scaled_dot_product_cudnn_attention_backward_25[2]; _scaled_dot_product_cudnn_attention_backward_25 = None + permute_1174 = torch.ops.aten.permute.default(getitem_2301, [0, 2, 1, 3]); getitem_2301 = None + permute_1175 = torch.ops.aten.permute.default(getitem_2300, [0, 2, 1, 3]); getitem_2300 = None + permute_1176 = torch.ops.aten.permute.default(getitem_2299, [0, 2, 1, 3]); getitem_2299 = None + view_2932 = torch.ops.aten.view.default(permute_1174, [2, 8192, 1, 4, 128]); permute_1174 = None + sum_155 = torch.ops.aten.sum.dim_IntList(view_2932, [3], True); view_2932 = None + squeeze_50 = torch.ops.aten.squeeze.dim(sum_155, 3); sum_155 = None + view_2933 = torch.ops.aten.view.default(permute_1175, [2, 8192, 1, 4, 128]); permute_1175 = None + sum_156 = torch.ops.aten.sum.dim_IntList(view_2933, [3], True); view_2933 = None + squeeze_51 = torch.ops.aten.squeeze.dim(sum_156, 3); 
sum_156 = None + convert_element_type_2453 = torch.ops.prims.convert_element_type.default(squeeze_51, torch.float32); squeeze_51 = None + convert_element_type_2454 = torch.ops.prims.convert_element_type.default(permute_1176, torch.float32); permute_1176 = None + view_2934 = torch.ops.aten.view.default(convert_element_type_2453, [2, 8192, 1, 64, 2]); convert_element_type_2453 = None + view_as_complex_114 = torch.ops.aten.view_as_complex.default(view_2934); view_2934 = None + mul_776 = torch.ops.aten.mul.Tensor(view_as_complex_114, _conj); view_as_complex_114 = None + view_2935 = torch.ops.aten.view.default(convert_element_type_2454, [2, 8192, 4, 64, 2]); convert_element_type_2454 = None + view_as_complex_115 = torch.ops.aten.view_as_complex.default(view_2935); view_2935 = None + mul_777 = torch.ops.aten.mul.Tensor(view_as_complex_115, _conj); view_as_complex_115 = None + view_as_real_114 = torch.ops.aten.view_as_real.default(mul_776); mul_776 = None + view_2936 = torch.ops.aten.view.default(view_as_real_114, [2, 8192, 1, 128]); view_as_real_114 = None + convert_element_type_2455 = torch.ops.prims.convert_element_type.default(view_2936, torch.bfloat16); view_2936 = None + view_as_real_115 = torch.ops.aten.view_as_real.default(mul_777); mul_777 = None + view_2937 = torch.ops.aten.view.default(view_as_real_115, [2, 8192, 4, 128]); view_as_real_115 = None + convert_element_type_2456 = torch.ops.prims.convert_element_type.default(view_2937, torch.bfloat16); view_2937 = None + view_2938 = torch.ops.aten.view.default(squeeze_50, [2, 8192, 128]); squeeze_50 = None + view_2939 = torch.ops.aten.view.default(convert_element_type_2455, [2, 8192, 128]); convert_element_type_2455 = None + view_2940 = torch.ops.aten.view.default(convert_element_type_2456, [2, 8192, 512]); convert_element_type_2456 = None + view_2941 = torch.ops.aten.view.default(view_2938, [16384, 128]); view_2938 = None + permute_1177 = torch.ops.aten.permute.default(view_2941, [1, 0]) + mm_585 = 
torch.ops.aten.mm.default(permute_1177, view_447); permute_1177 = None + convert_element_type_208 = torch.ops.prims.convert_element_type.default(primals_61, torch.bfloat16); primals_61 = None + all_gather_into_tensor_71 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_208, 32, '0'); convert_element_type_208 = None + wait_tensor_84 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_71); all_gather_into_tensor_71 = None + permute_68 = torch.ops.aten.permute.default(wait_tensor_84, [1, 0]); wait_tensor_84 = None + permute_1179 = torch.ops.aten.permute.default(permute_68, [1, 0]); permute_68 = None + mm_586 = torch.ops.aten.mm.default(view_2941, permute_1179); view_2941 = permute_1179 = None + view_2942 = torch.ops.aten.view.default(mm_586, [2, 8192, 4096]); mm_586 = None + convert_element_type_2461 = torch.ops.prims.convert_element_type.default(mm_585, torch.float32); mm_585 = None + reduce_scatter_tensor_349 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2461, 'avg', 32, '0'); convert_element_type_2461 = None + wait_tensor_809 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_349); reduce_scatter_tensor_349 = None + view_2943 = torch.ops.aten.view.default(view_2939, [16384, 128]); view_2939 = None + permute_1181 = torch.ops.aten.permute.default(view_2943, [1, 0]) + mm_587 = torch.ops.aten.mm.default(permute_1181, view_447); permute_1181 = None + permute_1183 = torch.ops.aten.permute.default(permute_67, [1, 0]); permute_67 = None + mm_588 = torch.ops.aten.mm.default(view_2943, permute_1183); view_2943 = permute_1183 = None + view_2944 = torch.ops.aten.view.default(mm_588, [2, 8192, 4096]); mm_588 = None + add_308 = torch.ops.aten.add.Tensor(view_2942, view_2944); view_2942 = view_2944 = None + convert_element_type_2466 = torch.ops.prims.convert_element_type.default(mm_587, torch.float32); mm_587 = None + reduce_scatter_tensor_350 = 
torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2466, 'avg', 32, '0'); convert_element_type_2466 = None + wait_tensor_810 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_350); reduce_scatter_tensor_350 = None + view_2945 = torch.ops.aten.view.default(view_2940, [16384, 512]); view_2940 = None + permute_1185 = torch.ops.aten.permute.default(view_2945, [1, 0]) + mm_589 = torch.ops.aten.mm.default(permute_1185, view_447); permute_1185 = view_447 = None + convert_element_type_202 = torch.ops.prims.convert_element_type.default(primals_59, torch.bfloat16); primals_59 = None + all_gather_into_tensor_69 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_202, 32, '0'); convert_element_type_202 = None + wait_tensor_82 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_69); all_gather_into_tensor_69 = None + permute_66 = torch.ops.aten.permute.default(wait_tensor_82, [1, 0]); wait_tensor_82 = None + permute_1187 = torch.ops.aten.permute.default(permute_66, [1, 0]); permute_66 = None + mm_590 = torch.ops.aten.mm.default(view_2945, permute_1187); view_2945 = permute_1187 = None + view_2946 = torch.ops.aten.view.default(mm_590, [2, 8192, 4096]); mm_590 = None + add_309 = torch.ops.aten.add.Tensor(add_308, view_2946); add_308 = view_2946 = None + convert_element_type_2471 = torch.ops.prims.convert_element_type.default(mm_589, torch.float32); mm_589 = None + reduce_scatter_tensor_351 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2471, 'avg', 32, '0'); convert_element_type_2471 = None + wait_tensor_811 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_351); reduce_scatter_tensor_351 = None + split_242 = torch.ops.aten.split.Tensor(add_309, 1024, 1); add_309 = None + getitem_2302 = split_242[0] + getitem_2303 = split_242[1] + getitem_2304 = split_242[2] + getitem_2305 = split_242[3] + getitem_2306 = split_242[4] + getitem_2307 
= split_242[5] + getitem_2308 = split_242[6] + getitem_2309 = split_242[7]; split_242 = None + cat_234 = torch.ops.aten.cat.default([getitem_2302, getitem_2303, getitem_2304, getitem_2305, getitem_2306, getitem_2307, getitem_2308, getitem_2309]); getitem_2302 = getitem_2303 = getitem_2304 = getitem_2305 = getitem_2306 = getitem_2307 = getitem_2308 = getitem_2309 = None + reduce_scatter_tensor_352 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_234, 'sum', 8, '1'); cat_234 = None + wait_tensor_812 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_352); reduce_scatter_tensor_352 = None + convert_element_type_2472 = torch.ops.prims.convert_element_type.default(wait_tensor_812, torch.float32); wait_tensor_812 = None + convert_element_type_2474 = torch.ops.prims.convert_element_type.default(wait_tensor_80, torch.float32); wait_tensor_80 = None + mul_778 = torch.ops.aten.mul.Tensor(convert_element_type_2472, convert_element_type_2474); convert_element_type_2474 = None + mul_780 = torch.ops.aten.mul.Tensor(mul_48, mul_778) + sum_157 = torch.ops.aten.sum.dim_IntList(mul_780, [2], True); mul_780 = None + div_52 = torch.ops.aten.div.Tensor(mul_48, 4096) + mul_781 = torch.ops.aten.mul.Tensor(div_52, sum_157); div_52 = sum_157 = None + sub_79 = torch.ops.aten.sub.Tensor(mul_778, mul_781); mul_778 = mul_781 = None + mul_782 = torch.ops.aten.mul.Tensor(sub_79, rsqrt_12); sub_79 = rsqrt_12 = None + mul_783 = torch.ops.aten.mul.Tensor(convert_element_type_2472, mul_48); convert_element_type_2472 = mul_48 = None + sum_158 = torch.ops.aten.sum.dim_IntList(mul_783, [0, 1]); mul_783 = None + convert_element_type_2475 = torch.ops.prims.convert_element_type.default(mul_782, torch.bfloat16); mul_782 = None + convert_element_type_2476 = torch.ops.prims.convert_element_type.default(sum_158, torch.bfloat16); sum_158 = None + all_reduce_52 = torch.ops._c10d_functional.all_reduce.default(convert_element_type_2476, 'sum', '1'); convert_element_type_2476 = 
None + wait_tensor_813 = torch.ops._c10d_functional.wait_tensor.default(all_reduce_52); all_reduce_52 = None + convert_element_type_2477 = torch.ops.prims.convert_element_type.default(wait_tensor_813, torch.float32); wait_tensor_813 = None + reduce_scatter_tensor_353 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2477, 'avg', 32, '0'); convert_element_type_2477 = None + wait_tensor_814 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_353); reduce_scatter_tensor_353 = None + add_310 = torch.ops.aten.add.Tensor(add_307, convert_element_type_2475); add_307 = convert_element_type_2475 = None + all_gather_into_tensor_408 = torch.ops._c10d_functional.all_gather_into_tensor.default(add_310, 8, '1') + wait_tensor_815 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_408); all_gather_into_tensor_408 = None + split_243 = torch.ops.aten.split.Tensor(wait_tensor_815, 2); wait_tensor_815 = None + getitem_2310 = split_243[0] + getitem_2311 = split_243[1] + getitem_2312 = split_243[2] + getitem_2313 = split_243[3] + getitem_2314 = split_243[4] + getitem_2315 = split_243[5] + getitem_2316 = split_243[6] + getitem_2317 = split_243[7]; split_243 = None + cat_235 = torch.ops.aten.cat.default([getitem_2310, getitem_2311, getitem_2312, getitem_2313, getitem_2314, getitem_2315, getitem_2316, getitem_2317], 1); getitem_2310 = getitem_2311 = getitem_2312 = getitem_2313 = getitem_2314 = getitem_2315 = getitem_2316 = getitem_2317 = None + view_2947 = torch.ops.aten.view.default(cat_235, [16384, 4096]); cat_235 = None + permute_1189 = torch.ops.aten.permute.default(view_2947, [1, 0]) + wait_tensor_73 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_11); reduce_scatter_tensor_11 = None + add_21 = torch.ops.aten.add.Tensor(add_19, wait_tensor_73); wait_tensor_73 = None + convert_element_type_185 = torch.ops.prims.convert_element_type.default(primals_54, torch.bfloat16); primals_54 = None + 
all_gather_into_tensor_62 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_185, 32, '0'); convert_element_type_185 = None + wait_tensor_74 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_62); all_gather_into_tensor_62 = None + convert_element_type_186 = torch.ops.prims.convert_element_type.default(add_21, torch.float32); add_21 = None + pow_12 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_186, 2) + mean_11 = torch.ops.aten.mean.dim(pow_12, [2], True); pow_12 = None + add_22 = torch.ops.aten.add.Scalar(mean_11, 1e-05); mean_11 = None + rsqrt_11 = torch.ops.aten.rsqrt.default(add_22); add_22 = None + mul_44 = torch.ops.aten.mul.Tensor(convert_element_type_186, rsqrt_11); convert_element_type_186 = None + mul_45 = torch.ops.aten.mul.Tensor(mul_44, wait_tensor_74) + convert_element_type_187 = torch.ops.prims.convert_element_type.default(mul_45, torch.bfloat16); mul_45 = None + all_gather_into_tensor_63 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_187, 8, '1'); convert_element_type_187 = None + wait_tensor_75 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_63); all_gather_into_tensor_63 = None + split_31 = torch.ops.aten.split.Tensor(wait_tensor_75, 2); wait_tensor_75 = None + getitem_302 = split_31[0] + getitem_303 = split_31[1] + getitem_304 = split_31[2] + getitem_305 = split_31[3] + getitem_306 = split_31[4] + getitem_307 = split_31[5] + getitem_308 = split_31[6] + getitem_309 = split_31[7]; split_31 = None + cat_23 = torch.ops.aten.cat.default([getitem_302, getitem_303, getitem_304, getitem_305, getitem_306, getitem_307, getitem_308, getitem_309], 1); getitem_302 = getitem_303 = getitem_304 = getitem_305 = getitem_306 = getitem_307 = getitem_308 = getitem_309 = None + view_420 = torch.ops.aten.view.default(cat_23, [16384, 4096]); cat_23 = None + view_421 = torch.ops.aten.view.default(mm_39, [2, 8192, 1792]); mm_39 = None + 
convert_element_type_191 = torch.ops.prims.convert_element_type.default(view_421, torch.float32); view_421 = None + sigmoid_5 = torch.ops.aten.sigmoid.default(convert_element_type_191) + mul_46 = torch.ops.aten.mul.Tensor(convert_element_type_191, sigmoid_5); sigmoid_5 = None + convert_element_type_192 = torch.ops.prims.convert_element_type.default(mul_46, torch.bfloat16); mul_46 = None + convert_element_type_193 = torch.ops.prims.convert_element_type.default(primals_56, torch.bfloat16); primals_56 = None + all_gather_into_tensor_65 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_193, 32, '0'); convert_element_type_193 = None + wait_tensor_77 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_65); all_gather_into_tensor_65 = None + permute_64 = torch.ops.aten.permute.default(wait_tensor_77, [1, 0]); wait_tensor_77 = None + mm_40 = torch.ops.aten.mm.default(view_420, permute_64) + view_428 = torch.ops.aten.view.default(mm_40, [2, 8192, 1792]); mm_40 = None + mul_47 = torch.ops.aten.mul.Tensor(convert_element_type_192, view_428) + view_435 = torch.ops.aten.view.default(mul_47, [16384, 1792]); mul_47 = None + mm_591 = torch.ops.aten.mm.default(permute_1189, view_435); permute_1189 = view_435 = None + convert_element_type_196 = torch.ops.prims.convert_element_type.default(primals_57, torch.bfloat16); primals_57 = None + all_gather_into_tensor_66 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_196, 32, '0'); convert_element_type_196 = None + wait_tensor_78 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_66); all_gather_into_tensor_66 = None + permute_65 = torch.ops.aten.permute.default(wait_tensor_78, [1, 0]); wait_tensor_78 = None + permute_1191 = torch.ops.aten.permute.default(permute_65, [1, 0]); permute_65 = None + mm_592 = torch.ops.aten.mm.default(view_2947, permute_1191); view_2947 = permute_1191 = None + view_2948 = torch.ops.aten.view.default(mm_592, 
[2, 8192, 1792]); mm_592 = None + convert_element_type_2482 = torch.ops.prims.convert_element_type.default(mm_591, torch.float32); mm_591 = None + reduce_scatter_tensor_354 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2482, 'avg', 32, '0'); convert_element_type_2482 = None + wait_tensor_816 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_354); reduce_scatter_tensor_354 = None + mul_784 = torch.ops.aten.mul.Tensor(view_2948, convert_element_type_192); convert_element_type_192 = None + mul_785 = torch.ops.aten.mul.Tensor(view_2948, view_428); view_2948 = view_428 = None + view_2949 = torch.ops.aten.view.default(mul_784, [16384, 1792]); mul_784 = None + permute_1193 = torch.ops.aten.permute.default(view_2949, [1, 0]) + mm_593 = torch.ops.aten.mm.default(permute_1193, view_420); permute_1193 = None + permute_1195 = torch.ops.aten.permute.default(permute_64, [1, 0]); permute_64 = None + mm_594 = torch.ops.aten.mm.default(view_2949, permute_1195); view_2949 = permute_1195 = None + view_2950 = torch.ops.aten.view.default(mm_594, [2, 8192, 4096]); mm_594 = None + convert_element_type_2487 = torch.ops.prims.convert_element_type.default(mm_593, torch.float32); mm_593 = None + reduce_scatter_tensor_355 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2487, 'avg', 32, '0'); convert_element_type_2487 = None + wait_tensor_817 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_355); reduce_scatter_tensor_355 = None + convert_element_type_2488 = torch.ops.prims.convert_element_type.default(mul_785, torch.float32); mul_785 = None + neg_26 = torch.ops.aten.neg.default(convert_element_type_191) + exp_26 = torch.ops.aten.exp.default(neg_26); neg_26 = None + add_311 = torch.ops.aten.add.Tensor(exp_26, 1); exp_26 = None + reciprocal_26 = torch.ops.aten.reciprocal.default(add_311); add_311 = None + mul_786 = torch.ops.aten.mul.Tensor(reciprocal_26, 1); reciprocal_26 = None + 
mul_787 = torch.ops.aten.mul.Tensor(convert_element_type_2488, mul_786); convert_element_type_2488 = None + sub_80 = torch.ops.aten.sub.Tensor(1, mul_786); mul_786 = None + mul_788 = torch.ops.aten.mul.Tensor(convert_element_type_191, sub_80); convert_element_type_191 = sub_80 = None + add_312 = torch.ops.aten.add.Tensor(mul_788, 1); mul_788 = None + mul_789 = torch.ops.aten.mul.Tensor(mul_787, add_312); mul_787 = add_312 = None + convert_element_type_2490 = torch.ops.prims.convert_element_type.default(mul_789, torch.bfloat16); mul_789 = None + view_2951 = torch.ops.aten.view.default(convert_element_type_2490, [16384, 1792]); convert_element_type_2490 = None + permute_1197 = torch.ops.aten.permute.default(view_2951, [1, 0]) + mm_595 = torch.ops.aten.mm.default(permute_1197, view_420); permute_1197 = view_420 = None + convert_element_type_188 = torch.ops.prims.convert_element_type.default(primals_55, torch.bfloat16); primals_55 = None + all_gather_into_tensor_64 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_188, 32, '0'); convert_element_type_188 = None + wait_tensor_76 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_64); all_gather_into_tensor_64 = None + permute_63 = torch.ops.aten.permute.default(wait_tensor_76, [1, 0]); wait_tensor_76 = None + permute_1199 = torch.ops.aten.permute.default(permute_63, [1, 0]); permute_63 = None + mm_596 = torch.ops.aten.mm.default(view_2951, permute_1199); view_2951 = permute_1199 = None + view_2952 = torch.ops.aten.view.default(mm_596, [2, 8192, 4096]); mm_596 = None + add_313 = torch.ops.aten.add.Tensor(view_2950, view_2952); view_2950 = view_2952 = None + convert_element_type_2495 = torch.ops.prims.convert_element_type.default(mm_595, torch.float32); mm_595 = None + reduce_scatter_tensor_356 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2495, 'avg', 32, '0'); convert_element_type_2495 = None + wait_tensor_818 = 
torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_356); reduce_scatter_tensor_356 = None + split_244 = torch.ops.aten.split.Tensor(add_313, 1024, 1); add_313 = None + getitem_2318 = split_244[0] + getitem_2319 = split_244[1] + getitem_2320 = split_244[2] + getitem_2321 = split_244[3] + getitem_2322 = split_244[4] + getitem_2323 = split_244[5] + getitem_2324 = split_244[6] + getitem_2325 = split_244[7]; split_244 = None + cat_236 = torch.ops.aten.cat.default([getitem_2318, getitem_2319, getitem_2320, getitem_2321, getitem_2322, getitem_2323, getitem_2324, getitem_2325]); getitem_2318 = getitem_2319 = getitem_2320 = getitem_2321 = getitem_2322 = getitem_2323 = getitem_2324 = getitem_2325 = None + reduce_scatter_tensor_357 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_236, 'sum', 8, '1'); cat_236 = None + wait_tensor_819 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_357); reduce_scatter_tensor_357 = None + convert_element_type_2496 = torch.ops.prims.convert_element_type.default(wait_tensor_819, torch.float32); wait_tensor_819 = None + convert_element_type_2498 = torch.ops.prims.convert_element_type.default(wait_tensor_74, torch.float32); wait_tensor_74 = None + mul_790 = torch.ops.aten.mul.Tensor(convert_element_type_2496, convert_element_type_2498); convert_element_type_2498 = None + mul_792 = torch.ops.aten.mul.Tensor(mul_44, mul_790) + sum_159 = torch.ops.aten.sum.dim_IntList(mul_792, [2], True); mul_792 = None + div_53 = torch.ops.aten.div.Tensor(mul_44, 4096) + mul_793 = torch.ops.aten.mul.Tensor(div_53, sum_159); div_53 = sum_159 = None + sub_81 = torch.ops.aten.sub.Tensor(mul_790, mul_793); mul_790 = mul_793 = None + mul_794 = torch.ops.aten.mul.Tensor(sub_81, rsqrt_11); sub_81 = rsqrt_11 = None + mul_795 = torch.ops.aten.mul.Tensor(convert_element_type_2496, mul_44); convert_element_type_2496 = mul_44 = None + sum_160 = torch.ops.aten.sum.dim_IntList(mul_795, [0, 1]); mul_795 = None + 
convert_element_type_2499 = torch.ops.prims.convert_element_type.default(mul_794, torch.bfloat16); mul_794 = None + convert_element_type_2500 = torch.ops.prims.convert_element_type.default(sum_160, torch.bfloat16); sum_160 = None + all_reduce_53 = torch.ops._c10d_functional.all_reduce.default(convert_element_type_2500, 'sum', '1'); convert_element_type_2500 = None + wait_tensor_820 = torch.ops._c10d_functional.wait_tensor.default(all_reduce_53); all_reduce_53 = None + convert_element_type_2501 = torch.ops.prims.convert_element_type.default(wait_tensor_820, torch.float32); wait_tensor_820 = None + reduce_scatter_tensor_358 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2501, 'avg', 32, '0'); convert_element_type_2501 = None + wait_tensor_821 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_358); reduce_scatter_tensor_358 = None + add_314 = torch.ops.aten.add.Tensor(add_310, convert_element_type_2499); add_310 = convert_element_type_2499 = None + all_gather_into_tensor_409 = torch.ops._c10d_functional.all_gather_into_tensor.default(add_314, 8, '1') + wait_tensor_822 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_409); all_gather_into_tensor_409 = None + split_245 = torch.ops.aten.split.Tensor(wait_tensor_822, 2); wait_tensor_822 = None + getitem_2326 = split_245[0] + getitem_2327 = split_245[1] + getitem_2328 = split_245[2] + getitem_2329 = split_245[3] + getitem_2330 = split_245[4] + getitem_2331 = split_245[5] + getitem_2332 = split_245[6] + getitem_2333 = split_245[7]; split_245 = None + cat_237 = torch.ops.aten.cat.default([getitem_2326, getitem_2327, getitem_2328, getitem_2329, getitem_2330, getitem_2331, getitem_2332, getitem_2333], 1); getitem_2326 = getitem_2327 = getitem_2328 = getitem_2329 = getitem_2330 = getitem_2331 = getitem_2332 = getitem_2333 = None + view_2953 = torch.ops.aten.view.default(cat_237, [16384, 4096]); cat_237 = None + permute_1201 = 
torch.ops.aten.permute.default(view_2953, [1, 0]) + permute_61 = torch.ops.aten.permute.default(getitem_285, [0, 2, 1, 3]) + view_402 = torch.ops.aten.view.default(permute_61, [2, 8192, -1]); permute_61 = None + view_408 = torch.ops.aten.view.default(view_402, [16384, 512]); view_402 = None + mm_597 = torch.ops.aten.mm.default(permute_1201, view_408); permute_1201 = view_408 = None + convert_element_type_182 = torch.ops.prims.convert_element_type.default(primals_53, torch.bfloat16); primals_53 = None + all_gather_into_tensor_61 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_182, 32, '0'); convert_element_type_182 = None + wait_tensor_72 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_61); all_gather_into_tensor_61 = None + permute_62 = torch.ops.aten.permute.default(wait_tensor_72, [1, 0]); wait_tensor_72 = None + permute_1203 = torch.ops.aten.permute.default(permute_62, [1, 0]); permute_62 = None + mm_598 = torch.ops.aten.mm.default(view_2953, permute_1203); view_2953 = permute_1203 = None + view_2954 = torch.ops.aten.view.default(mm_598, [2, 8192, 512]); mm_598 = None + convert_element_type_2506 = torch.ops.prims.convert_element_type.default(mm_597, torch.float32); mm_597 = None + reduce_scatter_tensor_359 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2506, 'avg', 32, '0'); convert_element_type_2506 = None + wait_tensor_823 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_359); reduce_scatter_tensor_359 = None + view_2955 = torch.ops.aten.view.default(view_2954, [2, 8192, 4, 128]); view_2954 = None + permute_1205 = torch.ops.aten.permute.default(view_2955, [0, 2, 1, 3]); view_2955 = None + convert_element_type_166 = torch.ops.prims.convert_element_type.default(primals_49, torch.bfloat16); primals_49 = None + all_gather_into_tensor_56 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_166, 32, '0'); 
convert_element_type_166 = None + wait_tensor_67 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_56); all_gather_into_tensor_56 = None + convert_element_type_167 = torch.ops.prims.convert_element_type.default(add_19, torch.float32); add_19 = None + pow_11 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_167, 2) + mean_10 = torch.ops.aten.mean.dim(pow_11, [2], True); pow_11 = None + add_20 = torch.ops.aten.add.Scalar(mean_10, 1e-05); mean_10 = None + rsqrt_10 = torch.ops.aten.rsqrt.default(add_20); add_20 = None + mul_40 = torch.ops.aten.mul.Tensor(convert_element_type_167, rsqrt_10); convert_element_type_167 = None + mul_41 = torch.ops.aten.mul.Tensor(mul_40, wait_tensor_67) + convert_element_type_168 = torch.ops.prims.convert_element_type.default(mul_41, torch.bfloat16); mul_41 = None + all_gather_into_tensor_57 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_168, 8, '1'); convert_element_type_168 = None + wait_tensor_68 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_57); all_gather_into_tensor_57 = None + split_29 = torch.ops.aten.split.Tensor(wait_tensor_68, 2); wait_tensor_68 = None + getitem_277 = split_29[0] + getitem_278 = split_29[1] + getitem_279 = split_29[2] + getitem_280 = split_29[3] + getitem_281 = split_29[4] + getitem_282 = split_29[5] + getitem_283 = split_29[6] + getitem_284 = split_29[7]; split_29 = None + cat_21 = torch.ops.aten.cat.default([getitem_277, getitem_278, getitem_279, getitem_280, getitem_281, getitem_282, getitem_283, getitem_284], 1); getitem_277 = getitem_278 = getitem_279 = getitem_280 = getitem_281 = getitem_282 = getitem_283 = getitem_284 = None + view_375 = torch.ops.aten.view.default(cat_21, [16384, 4096]); cat_21 = None + view_376 = torch.ops.aten.view.default(mm_35, [2, 8192, 512]); mm_35 = None + convert_element_type_172 = torch.ops.prims.convert_element_type.default(primals_51, torch.bfloat16); primals_51 = None + 
all_gather_into_tensor_59 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_172, 32, '0'); convert_element_type_172 = None + wait_tensor_70 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_59); all_gather_into_tensor_59 = None + permute_56 = torch.ops.aten.permute.default(wait_tensor_70, [1, 0]); wait_tensor_70 = None + mm_36 = torch.ops.aten.mm.default(view_375, permute_56) + view_383 = torch.ops.aten.view.default(mm_36, [2, 8192, 128]); mm_36 = None + view_390 = torch.ops.aten.view.default(mm_37, [2, 8192, 128]); mm_37 = None + view_392 = torch.ops.aten.view.default(view_376, [2, 8192, -1, 128]); view_376 = None + view_393 = torch.ops.aten.view.default(view_383, [2, 8192, -1, 128]); view_383 = None + view_394 = torch.ops.aten.view.default(view_390, [2, 8192, -1, 128]); view_390 = None + convert_element_type_178 = torch.ops.prims.convert_element_type.default(view_392, torch.float32); view_392 = None + view_395 = torch.ops.aten.view.default(convert_element_type_178, [2, 8192, 4, -1, 2]); convert_element_type_178 = None + view_as_complex_10 = torch.ops.aten.view_as_complex.default(view_395); view_395 = None + convert_element_type_179 = torch.ops.prims.convert_element_type.default(view_393, torch.float32); view_393 = None + view_396 = torch.ops.aten.view.default(convert_element_type_179, [2, 8192, 1, -1, 2]); convert_element_type_179 = None + view_as_complex_11 = torch.ops.aten.view_as_complex.default(view_396); view_396 = None + mul_42 = torch.ops.aten.mul.Tensor(view_as_complex_10, view_37); view_as_complex_10 = None + view_as_real_10 = torch.ops.aten.view_as_real.default(mul_42); mul_42 = None + view_398 = torch.ops.aten.view.default(view_as_real_10, [2, 8192, 4, 128]); view_as_real_10 = None + mul_43 = torch.ops.aten.mul.Tensor(view_as_complex_11, view_37); view_as_complex_11 = None + view_as_real_11 = torch.ops.aten.view_as_real.default(mul_43); mul_43 = None + view_399 = 
torch.ops.aten.view.default(view_as_real_11, [2, 8192, 1, 128]); view_as_real_11 = None + convert_element_type_180 = torch.ops.prims.convert_element_type.default(view_398, torch.bfloat16); view_398 = None + convert_element_type_181 = torch.ops.prims.convert_element_type.default(view_399, torch.bfloat16); view_399 = None + unsqueeze_10 = torch.ops.aten.unsqueeze.default(convert_element_type_181, 3); convert_element_type_181 = None + expand_10 = torch.ops.aten.expand.default(unsqueeze_10, [2, 8192, 1, 4, 128]); unsqueeze_10 = None + view_400 = torch.ops.aten.view.default(expand_10, [2, 8192, 4, 128]); expand_10 = None + unsqueeze_11 = torch.ops.aten.unsqueeze.default(view_394, 3); view_394 = None + expand_11 = torch.ops.aten.expand.default(unsqueeze_11, [2, 8192, 1, 4, 128]); unsqueeze_11 = None + view_401 = torch.ops.aten.view.default(expand_11, [2, 8192, 4, 128]); expand_11 = None + permute_58 = torch.ops.aten.permute.default(convert_element_type_180, [0, 2, 1, 3]); convert_element_type_180 = None + permute_59 = torch.ops.aten.permute.default(view_400, [0, 2, 1, 3]); view_400 = None + permute_60 = torch.ops.aten.permute.default(view_401, [0, 2, 1, 3]); view_401 = None + _scaled_dot_product_cudnn_attention_backward_26 = torch.ops.aten._scaled_dot_product_cudnn_attention_backward.default(permute_1205, permute_58, permute_59, permute_60, getitem_285, getitem_286, getitem_291, getitem_292, None, None, None, 8192, 8192, 0.0, True); permute_1205 = permute_58 = permute_59 = permute_60 = getitem_285 = getitem_286 = getitem_291 = getitem_292 = None + getitem_2334 = _scaled_dot_product_cudnn_attention_backward_26[0] + getitem_2335 = _scaled_dot_product_cudnn_attention_backward_26[1] + getitem_2336 = _scaled_dot_product_cudnn_attention_backward_26[2]; _scaled_dot_product_cudnn_attention_backward_26 = None + permute_1206 = torch.ops.aten.permute.default(getitem_2336, [0, 2, 1, 3]); getitem_2336 = None + permute_1207 = torch.ops.aten.permute.default(getitem_2335, [0, 2, 1, 3]); 
getitem_2335 = None + permute_1208 = torch.ops.aten.permute.default(getitem_2334, [0, 2, 1, 3]); getitem_2334 = None + view_2956 = torch.ops.aten.view.default(permute_1206, [2, 8192, 1, 4, 128]); permute_1206 = None + sum_161 = torch.ops.aten.sum.dim_IntList(view_2956, [3], True); view_2956 = None + squeeze_52 = torch.ops.aten.squeeze.dim(sum_161, 3); sum_161 = None + view_2957 = torch.ops.aten.view.default(permute_1207, [2, 8192, 1, 4, 128]); permute_1207 = None + sum_162 = torch.ops.aten.sum.dim_IntList(view_2957, [3], True); view_2957 = None + squeeze_53 = torch.ops.aten.squeeze.dim(sum_162, 3); sum_162 = None + convert_element_type_2507 = torch.ops.prims.convert_element_type.default(squeeze_53, torch.float32); squeeze_53 = None + convert_element_type_2508 = torch.ops.prims.convert_element_type.default(permute_1208, torch.float32); permute_1208 = None + view_2958 = torch.ops.aten.view.default(convert_element_type_2507, [2, 8192, 1, 64, 2]); convert_element_type_2507 = None + view_as_complex_116 = torch.ops.aten.view_as_complex.default(view_2958); view_2958 = None + mul_796 = torch.ops.aten.mul.Tensor(view_as_complex_116, _conj); view_as_complex_116 = None + view_2959 = torch.ops.aten.view.default(convert_element_type_2508, [2, 8192, 4, 64, 2]); convert_element_type_2508 = None + view_as_complex_117 = torch.ops.aten.view_as_complex.default(view_2959); view_2959 = None + mul_797 = torch.ops.aten.mul.Tensor(view_as_complex_117, _conj); view_as_complex_117 = None + view_as_real_116 = torch.ops.aten.view_as_real.default(mul_796); mul_796 = None + view_2960 = torch.ops.aten.view.default(view_as_real_116, [2, 8192, 1, 128]); view_as_real_116 = None + convert_element_type_2509 = torch.ops.prims.convert_element_type.default(view_2960, torch.bfloat16); view_2960 = None + view_as_real_117 = torch.ops.aten.view_as_real.default(mul_797); mul_797 = None + view_2961 = torch.ops.aten.view.default(view_as_real_117, [2, 8192, 4, 128]); view_as_real_117 = None + 
convert_element_type_2510 = torch.ops.prims.convert_element_type.default(view_2961, torch.bfloat16); view_2961 = None + view_2962 = torch.ops.aten.view.default(squeeze_52, [2, 8192, 128]); squeeze_52 = None + view_2963 = torch.ops.aten.view.default(convert_element_type_2509, [2, 8192, 128]); convert_element_type_2509 = None + view_2964 = torch.ops.aten.view.default(convert_element_type_2510, [2, 8192, 512]); convert_element_type_2510 = None + view_2965 = torch.ops.aten.view.default(view_2962, [16384, 128]); view_2962 = None + permute_1209 = torch.ops.aten.permute.default(view_2965, [1, 0]) + mm_599 = torch.ops.aten.mm.default(permute_1209, view_375); permute_1209 = None + convert_element_type_175 = torch.ops.prims.convert_element_type.default(primals_52, torch.bfloat16); primals_52 = None + all_gather_into_tensor_60 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_175, 32, '0'); convert_element_type_175 = None + wait_tensor_71 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_60); all_gather_into_tensor_60 = None + permute_57 = torch.ops.aten.permute.default(wait_tensor_71, [1, 0]); wait_tensor_71 = None + permute_1211 = torch.ops.aten.permute.default(permute_57, [1, 0]); permute_57 = None + mm_600 = torch.ops.aten.mm.default(view_2965, permute_1211); view_2965 = permute_1211 = None + view_2966 = torch.ops.aten.view.default(mm_600, [2, 8192, 4096]); mm_600 = None + convert_element_type_2515 = torch.ops.prims.convert_element_type.default(mm_599, torch.float32); mm_599 = None + reduce_scatter_tensor_360 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2515, 'avg', 32, '0'); convert_element_type_2515 = None + wait_tensor_824 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_360); reduce_scatter_tensor_360 = None + view_2967 = torch.ops.aten.view.default(view_2963, [16384, 128]); view_2963 = None + permute_1213 = torch.ops.aten.permute.default(view_2967, [1, 0]) + 
mm_601 = torch.ops.aten.mm.default(permute_1213, view_375); permute_1213 = None + permute_1215 = torch.ops.aten.permute.default(permute_56, [1, 0]); permute_56 = None + mm_602 = torch.ops.aten.mm.default(view_2967, permute_1215); view_2967 = permute_1215 = None + view_2968 = torch.ops.aten.view.default(mm_602, [2, 8192, 4096]); mm_602 = None + add_315 = torch.ops.aten.add.Tensor(view_2966, view_2968); view_2966 = view_2968 = None + convert_element_type_2520 = torch.ops.prims.convert_element_type.default(mm_601, torch.float32); mm_601 = None + reduce_scatter_tensor_361 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2520, 'avg', 32, '0'); convert_element_type_2520 = None + wait_tensor_825 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_361); reduce_scatter_tensor_361 = None + view_2969 = torch.ops.aten.view.default(view_2964, [16384, 512]); view_2964 = None + permute_1217 = torch.ops.aten.permute.default(view_2969, [1, 0]) + mm_603 = torch.ops.aten.mm.default(permute_1217, view_375); permute_1217 = view_375 = None + convert_element_type_169 = torch.ops.prims.convert_element_type.default(primals_50, torch.bfloat16); primals_50 = None + all_gather_into_tensor_58 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_169, 32, '0'); convert_element_type_169 = None + wait_tensor_69 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_58); all_gather_into_tensor_58 = None + permute_55 = torch.ops.aten.permute.default(wait_tensor_69, [1, 0]); wait_tensor_69 = None + permute_1219 = torch.ops.aten.permute.default(permute_55, [1, 0]); permute_55 = None + mm_604 = torch.ops.aten.mm.default(view_2969, permute_1219); view_2969 = permute_1219 = None + view_2970 = torch.ops.aten.view.default(mm_604, [2, 8192, 4096]); mm_604 = None + add_316 = torch.ops.aten.add.Tensor(add_315, view_2970); add_315 = view_2970 = None + convert_element_type_2525 = 
torch.ops.prims.convert_element_type.default(mm_603, torch.float32); mm_603 = None + reduce_scatter_tensor_362 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2525, 'avg', 32, '0'); convert_element_type_2525 = None + wait_tensor_826 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_362); reduce_scatter_tensor_362 = None + split_246 = torch.ops.aten.split.Tensor(add_316, 1024, 1); add_316 = None + getitem_2337 = split_246[0] + getitem_2338 = split_246[1] + getitem_2339 = split_246[2] + getitem_2340 = split_246[3] + getitem_2341 = split_246[4] + getitem_2342 = split_246[5] + getitem_2343 = split_246[6] + getitem_2344 = split_246[7]; split_246 = None + cat_238 = torch.ops.aten.cat.default([getitem_2337, getitem_2338, getitem_2339, getitem_2340, getitem_2341, getitem_2342, getitem_2343, getitem_2344]); getitem_2337 = getitem_2338 = getitem_2339 = getitem_2340 = getitem_2341 = getitem_2342 = getitem_2343 = getitem_2344 = None + reduce_scatter_tensor_363 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_238, 'sum', 8, '1'); cat_238 = None + wait_tensor_827 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_363); reduce_scatter_tensor_363 = None + convert_element_type_2526 = torch.ops.prims.convert_element_type.default(wait_tensor_827, torch.float32); wait_tensor_827 = None + convert_element_type_2528 = torch.ops.prims.convert_element_type.default(wait_tensor_67, torch.float32); wait_tensor_67 = None + mul_798 = torch.ops.aten.mul.Tensor(convert_element_type_2526, convert_element_type_2528); convert_element_type_2528 = None + mul_800 = torch.ops.aten.mul.Tensor(mul_40, mul_798) + sum_163 = torch.ops.aten.sum.dim_IntList(mul_800, [2], True); mul_800 = None + div_54 = torch.ops.aten.div.Tensor(mul_40, 4096) + mul_801 = torch.ops.aten.mul.Tensor(div_54, sum_163); div_54 = sum_163 = None + sub_82 = torch.ops.aten.sub.Tensor(mul_798, mul_801); mul_798 = mul_801 = None + mul_802 = 
torch.ops.aten.mul.Tensor(sub_82, rsqrt_10); sub_82 = rsqrt_10 = None + mul_803 = torch.ops.aten.mul.Tensor(convert_element_type_2526, mul_40); convert_element_type_2526 = mul_40 = None + sum_164 = torch.ops.aten.sum.dim_IntList(mul_803, [0, 1]); mul_803 = None + convert_element_type_2529 = torch.ops.prims.convert_element_type.default(mul_802, torch.bfloat16); mul_802 = None + convert_element_type_2530 = torch.ops.prims.convert_element_type.default(sum_164, torch.bfloat16); sum_164 = None + all_reduce_54 = torch.ops._c10d_functional.all_reduce.default(convert_element_type_2530, 'sum', '1'); convert_element_type_2530 = None + wait_tensor_828 = torch.ops._c10d_functional.wait_tensor.default(all_reduce_54); all_reduce_54 = None + convert_element_type_2531 = torch.ops.prims.convert_element_type.default(wait_tensor_828, torch.float32); wait_tensor_828 = None + reduce_scatter_tensor_364 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2531, 'avg', 32, '0'); convert_element_type_2531 = None + wait_tensor_829 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_364); reduce_scatter_tensor_364 = None + add_317 = torch.ops.aten.add.Tensor(add_314, convert_element_type_2529); add_314 = convert_element_type_2529 = None + all_gather_into_tensor_410 = torch.ops._c10d_functional.all_gather_into_tensor.default(add_317, 8, '1') + wait_tensor_830 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_410); all_gather_into_tensor_410 = None + split_247 = torch.ops.aten.split.Tensor(wait_tensor_830, 2); wait_tensor_830 = None + getitem_2345 = split_247[0] + getitem_2346 = split_247[1] + getitem_2347 = split_247[2] + getitem_2348 = split_247[3] + getitem_2349 = split_247[4] + getitem_2350 = split_247[5] + getitem_2351 = split_247[6] + getitem_2352 = split_247[7]; split_247 = None + cat_239 = torch.ops.aten.cat.default([getitem_2345, getitem_2346, getitem_2347, getitem_2348, getitem_2349, getitem_2350, getitem_2351, 
getitem_2352], 1); getitem_2345 = getitem_2346 = getitem_2347 = getitem_2348 = getitem_2349 = getitem_2350 = getitem_2351 = getitem_2352 = None + view_2971 = torch.ops.aten.view.default(cat_239, [16384, 4096]); cat_239 = None + permute_1221 = torch.ops.aten.permute.default(view_2971, [1, 0]) + wait_tensor_60 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_9); reduce_scatter_tensor_9 = None + add_17 = torch.ops.aten.add.Tensor(add_15, wait_tensor_60); wait_tensor_60 = None + convert_element_type_152 = torch.ops.prims.convert_element_type.default(primals_45, torch.bfloat16); primals_45 = None + all_gather_into_tensor_51 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_152, 32, '0'); convert_element_type_152 = None + wait_tensor_61 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_51); all_gather_into_tensor_51 = None + convert_element_type_153 = torch.ops.prims.convert_element_type.default(add_17, torch.float32); add_17 = None + pow_10 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_153, 2) + mean_9 = torch.ops.aten.mean.dim(pow_10, [2], True); pow_10 = None + add_18 = torch.ops.aten.add.Scalar(mean_9, 1e-05); mean_9 = None + rsqrt_9 = torch.ops.aten.rsqrt.default(add_18); add_18 = None + mul_36 = torch.ops.aten.mul.Tensor(convert_element_type_153, rsqrt_9); convert_element_type_153 = None + mul_37 = torch.ops.aten.mul.Tensor(mul_36, wait_tensor_61) + convert_element_type_154 = torch.ops.prims.convert_element_type.default(mul_37, torch.bfloat16); mul_37 = None + all_gather_into_tensor_52 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_154, 8, '1'); convert_element_type_154 = None + wait_tensor_62 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_52); all_gather_into_tensor_52 = None + split_27 = torch.ops.aten.split.Tensor(wait_tensor_62, 2); wait_tensor_62 = None + getitem_261 = split_27[0] + getitem_262 = split_27[1] + 
getitem_263 = split_27[2] + getitem_264 = split_27[3] + getitem_265 = split_27[4] + getitem_266 = split_27[5] + getitem_267 = split_27[6] + getitem_268 = split_27[7]; split_27 = None + cat_19 = torch.ops.aten.cat.default([getitem_261, getitem_262, getitem_263, getitem_264, getitem_265, getitem_266, getitem_267, getitem_268], 1); getitem_261 = getitem_262 = getitem_263 = getitem_264 = getitem_265 = getitem_266 = getitem_267 = getitem_268 = None + view_348 = torch.ops.aten.view.default(cat_19, [16384, 4096]); cat_19 = None + view_349 = torch.ops.aten.view.default(mm_32, [2, 8192, 1792]); mm_32 = None + convert_element_type_158 = torch.ops.prims.convert_element_type.default(view_349, torch.float32); view_349 = None + sigmoid_4 = torch.ops.aten.sigmoid.default(convert_element_type_158) + mul_38 = torch.ops.aten.mul.Tensor(convert_element_type_158, sigmoid_4); sigmoid_4 = None + convert_element_type_159 = torch.ops.prims.convert_element_type.default(mul_38, torch.bfloat16); mul_38 = None + convert_element_type_160 = torch.ops.prims.convert_element_type.default(primals_47, torch.bfloat16); primals_47 = None + all_gather_into_tensor_54 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_160, 32, '0'); convert_element_type_160 = None + wait_tensor_64 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_54); all_gather_into_tensor_54 = None + permute_53 = torch.ops.aten.permute.default(wait_tensor_64, [1, 0]); wait_tensor_64 = None + mm_33 = torch.ops.aten.mm.default(view_348, permute_53) + view_356 = torch.ops.aten.view.default(mm_33, [2, 8192, 1792]); mm_33 = None + mul_39 = torch.ops.aten.mul.Tensor(convert_element_type_159, view_356) + view_363 = torch.ops.aten.view.default(mul_39, [16384, 1792]); mul_39 = None + mm_605 = torch.ops.aten.mm.default(permute_1221, view_363); permute_1221 = view_363 = None + convert_element_type_163 = torch.ops.prims.convert_element_type.default(primals_48, torch.bfloat16); primals_48 = None + 
all_gather_into_tensor_55 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_163, 32, '0'); convert_element_type_163 = None + wait_tensor_65 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_55); all_gather_into_tensor_55 = None + permute_54 = torch.ops.aten.permute.default(wait_tensor_65, [1, 0]); wait_tensor_65 = None + permute_1223 = torch.ops.aten.permute.default(permute_54, [1, 0]); permute_54 = None + mm_606 = torch.ops.aten.mm.default(view_2971, permute_1223); view_2971 = permute_1223 = None + view_2972 = torch.ops.aten.view.default(mm_606, [2, 8192, 1792]); mm_606 = None + convert_element_type_2536 = torch.ops.prims.convert_element_type.default(mm_605, torch.float32); mm_605 = None + reduce_scatter_tensor_365 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2536, 'avg', 32, '0'); convert_element_type_2536 = None + wait_tensor_831 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_365); reduce_scatter_tensor_365 = None + mul_804 = torch.ops.aten.mul.Tensor(view_2972, convert_element_type_159); convert_element_type_159 = None + mul_805 = torch.ops.aten.mul.Tensor(view_2972, view_356); view_2972 = view_356 = None + view_2973 = torch.ops.aten.view.default(mul_804, [16384, 1792]); mul_804 = None + permute_1225 = torch.ops.aten.permute.default(view_2973, [1, 0]) + mm_607 = torch.ops.aten.mm.default(permute_1225, view_348); permute_1225 = None + permute_1227 = torch.ops.aten.permute.default(permute_53, [1, 0]); permute_53 = None + mm_608 = torch.ops.aten.mm.default(view_2973, permute_1227); view_2973 = permute_1227 = None + view_2974 = torch.ops.aten.view.default(mm_608, [2, 8192, 4096]); mm_608 = None + convert_element_type_2541 = torch.ops.prims.convert_element_type.default(mm_607, torch.float32); mm_607 = None + reduce_scatter_tensor_366 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2541, 'avg', 32, '0'); 
convert_element_type_2541 = None + wait_tensor_832 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_366); reduce_scatter_tensor_366 = None + convert_element_type_2542 = torch.ops.prims.convert_element_type.default(mul_805, torch.float32); mul_805 = None + neg_27 = torch.ops.aten.neg.default(convert_element_type_158) + exp_27 = torch.ops.aten.exp.default(neg_27); neg_27 = None + add_318 = torch.ops.aten.add.Tensor(exp_27, 1); exp_27 = None + reciprocal_27 = torch.ops.aten.reciprocal.default(add_318); add_318 = None + mul_806 = torch.ops.aten.mul.Tensor(reciprocal_27, 1); reciprocal_27 = None + mul_807 = torch.ops.aten.mul.Tensor(convert_element_type_2542, mul_806); convert_element_type_2542 = None + sub_83 = torch.ops.aten.sub.Tensor(1, mul_806); mul_806 = None + mul_808 = torch.ops.aten.mul.Tensor(convert_element_type_158, sub_83); convert_element_type_158 = sub_83 = None + add_319 = torch.ops.aten.add.Tensor(mul_808, 1); mul_808 = None + mul_809 = torch.ops.aten.mul.Tensor(mul_807, add_319); mul_807 = add_319 = None + convert_element_type_2544 = torch.ops.prims.convert_element_type.default(mul_809, torch.bfloat16); mul_809 = None + view_2975 = torch.ops.aten.view.default(convert_element_type_2544, [16384, 1792]); convert_element_type_2544 = None + permute_1229 = torch.ops.aten.permute.default(view_2975, [1, 0]) + mm_609 = torch.ops.aten.mm.default(permute_1229, view_348); permute_1229 = view_348 = None + convert_element_type_155 = torch.ops.prims.convert_element_type.default(primals_46, torch.bfloat16); primals_46 = None + all_gather_into_tensor_53 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_155, 32, '0'); convert_element_type_155 = None + wait_tensor_63 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_53); all_gather_into_tensor_53 = None + permute_52 = torch.ops.aten.permute.default(wait_tensor_63, [1, 0]); wait_tensor_63 = None + permute_1231 = 
torch.ops.aten.permute.default(permute_52, [1, 0]); permute_52 = None + mm_610 = torch.ops.aten.mm.default(view_2975, permute_1231); view_2975 = permute_1231 = None + view_2976 = torch.ops.aten.view.default(mm_610, [2, 8192, 4096]); mm_610 = None + add_320 = torch.ops.aten.add.Tensor(view_2974, view_2976); view_2974 = view_2976 = None + convert_element_type_2549 = torch.ops.prims.convert_element_type.default(mm_609, torch.float32); mm_609 = None + reduce_scatter_tensor_367 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2549, 'avg', 32, '0'); convert_element_type_2549 = None + wait_tensor_833 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_367); reduce_scatter_tensor_367 = None + split_248 = torch.ops.aten.split.Tensor(add_320, 1024, 1); add_320 = None + getitem_2353 = split_248[0] + getitem_2354 = split_248[1] + getitem_2355 = split_248[2] + getitem_2356 = split_248[3] + getitem_2357 = split_248[4] + getitem_2358 = split_248[5] + getitem_2359 = split_248[6] + getitem_2360 = split_248[7]; split_248 = None + cat_240 = torch.ops.aten.cat.default([getitem_2353, getitem_2354, getitem_2355, getitem_2356, getitem_2357, getitem_2358, getitem_2359, getitem_2360]); getitem_2353 = getitem_2354 = getitem_2355 = getitem_2356 = getitem_2357 = getitem_2358 = getitem_2359 = getitem_2360 = None + reduce_scatter_tensor_368 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_240, 'sum', 8, '1'); cat_240 = None + wait_tensor_834 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_368); reduce_scatter_tensor_368 = None + convert_element_type_2550 = torch.ops.prims.convert_element_type.default(wait_tensor_834, torch.float32); wait_tensor_834 = None + convert_element_type_2552 = torch.ops.prims.convert_element_type.default(wait_tensor_61, torch.float32); wait_tensor_61 = None + mul_810 = torch.ops.aten.mul.Tensor(convert_element_type_2550, convert_element_type_2552); convert_element_type_2552 = None + 
mul_812 = torch.ops.aten.mul.Tensor(mul_36, mul_810) + sum_165 = torch.ops.aten.sum.dim_IntList(mul_812, [2], True); mul_812 = None + div_55 = torch.ops.aten.div.Tensor(mul_36, 4096) + mul_813 = torch.ops.aten.mul.Tensor(div_55, sum_165); div_55 = sum_165 = None + sub_84 = torch.ops.aten.sub.Tensor(mul_810, mul_813); mul_810 = mul_813 = None + mul_814 = torch.ops.aten.mul.Tensor(sub_84, rsqrt_9); sub_84 = rsqrt_9 = None + mul_815 = torch.ops.aten.mul.Tensor(convert_element_type_2550, mul_36); convert_element_type_2550 = mul_36 = None + sum_166 = torch.ops.aten.sum.dim_IntList(mul_815, [0, 1]); mul_815 = None + convert_element_type_2553 = torch.ops.prims.convert_element_type.default(mul_814, torch.bfloat16); mul_814 = None + convert_element_type_2554 = torch.ops.prims.convert_element_type.default(sum_166, torch.bfloat16); sum_166 = None + all_reduce_55 = torch.ops._c10d_functional.all_reduce.default(convert_element_type_2554, 'sum', '1'); convert_element_type_2554 = None + wait_tensor_835 = torch.ops._c10d_functional.wait_tensor.default(all_reduce_55); all_reduce_55 = None + convert_element_type_2555 = torch.ops.prims.convert_element_type.default(wait_tensor_835, torch.float32); wait_tensor_835 = None + reduce_scatter_tensor_369 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2555, 'avg', 32, '0'); convert_element_type_2555 = None + wait_tensor_836 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_369); reduce_scatter_tensor_369 = None + add_321 = torch.ops.aten.add.Tensor(add_317, convert_element_type_2553); add_317 = convert_element_type_2553 = None + all_gather_into_tensor_411 = torch.ops._c10d_functional.all_gather_into_tensor.default(add_321, 8, '1') + wait_tensor_837 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_411); all_gather_into_tensor_411 = None + split_249 = torch.ops.aten.split.Tensor(wait_tensor_837, 2); wait_tensor_837 = None + getitem_2361 = split_249[0] + getitem_2362 = 
split_249[1] + getitem_2363 = split_249[2] + getitem_2364 = split_249[3] + getitem_2365 = split_249[4] + getitem_2366 = split_249[5] + getitem_2367 = split_249[6] + getitem_2368 = split_249[7]; split_249 = None + cat_241 = torch.ops.aten.cat.default([getitem_2361, getitem_2362, getitem_2363, getitem_2364, getitem_2365, getitem_2366, getitem_2367, getitem_2368], 1); getitem_2361 = getitem_2362 = getitem_2363 = getitem_2364 = getitem_2365 = getitem_2366 = getitem_2367 = getitem_2368 = None + view_2977 = torch.ops.aten.view.default(cat_241, [16384, 4096]); cat_241 = None + permute_1233 = torch.ops.aten.permute.default(view_2977, [1, 0]) + permute_50 = torch.ops.aten.permute.default(getitem_244, [0, 2, 1, 3]) + view_330 = torch.ops.aten.view.default(permute_50, [2, 8192, -1]); permute_50 = None + view_336 = torch.ops.aten.view.default(view_330, [16384, 512]); view_330 = None + mm_611 = torch.ops.aten.mm.default(permute_1233, view_336); permute_1233 = view_336 = None + convert_element_type_149 = torch.ops.prims.convert_element_type.default(primals_44, torch.bfloat16); primals_44 = None + all_gather_into_tensor_50 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_149, 32, '0'); convert_element_type_149 = None + wait_tensor_59 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_50); all_gather_into_tensor_50 = None + permute_51 = torch.ops.aten.permute.default(wait_tensor_59, [1, 0]); wait_tensor_59 = None + permute_1235 = torch.ops.aten.permute.default(permute_51, [1, 0]); permute_51 = None + mm_612 = torch.ops.aten.mm.default(view_2977, permute_1235); view_2977 = permute_1235 = None + view_2978 = torch.ops.aten.view.default(mm_612, [2, 8192, 512]); mm_612 = None + convert_element_type_2560 = torch.ops.prims.convert_element_type.default(mm_611, torch.float32); mm_611 = None + reduce_scatter_tensor_370 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2560, 'avg', 32, '0'); 
convert_element_type_2560 = None + wait_tensor_838 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_370); reduce_scatter_tensor_370 = None + view_2979 = torch.ops.aten.view.default(view_2978, [2, 8192, 4, 128]); view_2978 = None + permute_1237 = torch.ops.aten.permute.default(view_2979, [0, 2, 1, 3]); view_2979 = None + convert_element_type_133 = torch.ops.prims.convert_element_type.default(primals_40, torch.bfloat16); primals_40 = None + all_gather_into_tensor_45 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_133, 32, '0'); convert_element_type_133 = None + wait_tensor_54 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_45); all_gather_into_tensor_45 = None + convert_element_type_134 = torch.ops.prims.convert_element_type.default(add_15, torch.float32); add_15 = None + pow_9 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_134, 2) + mean_8 = torch.ops.aten.mean.dim(pow_9, [2], True); pow_9 = None + add_16 = torch.ops.aten.add.Scalar(mean_8, 1e-05); mean_8 = None + rsqrt_8 = torch.ops.aten.rsqrt.default(add_16); add_16 = None + mul_32 = torch.ops.aten.mul.Tensor(convert_element_type_134, rsqrt_8); convert_element_type_134 = None + mul_33 = torch.ops.aten.mul.Tensor(mul_32, wait_tensor_54) + convert_element_type_135 = torch.ops.prims.convert_element_type.default(mul_33, torch.bfloat16); mul_33 = None + all_gather_into_tensor_46 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_135, 8, '1'); convert_element_type_135 = None + wait_tensor_55 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_46); all_gather_into_tensor_46 = None + split_25 = torch.ops.aten.split.Tensor(wait_tensor_55, 2); wait_tensor_55 = None + getitem_236 = split_25[0] + getitem_237 = split_25[1] + getitem_238 = split_25[2] + getitem_239 = split_25[3] + getitem_240 = split_25[4] + getitem_241 = split_25[5] + getitem_242 = split_25[6] + getitem_243 = split_25[7]; 
split_25 = None + cat_17 = torch.ops.aten.cat.default([getitem_236, getitem_237, getitem_238, getitem_239, getitem_240, getitem_241, getitem_242, getitem_243], 1); getitem_236 = getitem_237 = getitem_238 = getitem_239 = getitem_240 = getitem_241 = getitem_242 = getitem_243 = None + view_303 = torch.ops.aten.view.default(cat_17, [16384, 4096]); cat_17 = None + view_304 = torch.ops.aten.view.default(mm_28, [2, 8192, 512]); mm_28 = None + convert_element_type_139 = torch.ops.prims.convert_element_type.default(primals_42, torch.bfloat16); primals_42 = None + all_gather_into_tensor_48 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_139, 32, '0'); convert_element_type_139 = None + wait_tensor_57 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_48); all_gather_into_tensor_48 = None + permute_45 = torch.ops.aten.permute.default(wait_tensor_57, [1, 0]); wait_tensor_57 = None + mm_29 = torch.ops.aten.mm.default(view_303, permute_45) + view_311 = torch.ops.aten.view.default(mm_29, [2, 8192, 128]); mm_29 = None + view_318 = torch.ops.aten.view.default(mm_30, [2, 8192, 128]); mm_30 = None + view_320 = torch.ops.aten.view.default(view_304, [2, 8192, -1, 128]); view_304 = None + view_321 = torch.ops.aten.view.default(view_311, [2, 8192, -1, 128]); view_311 = None + view_322 = torch.ops.aten.view.default(view_318, [2, 8192, -1, 128]); view_318 = None + convert_element_type_145 = torch.ops.prims.convert_element_type.default(view_320, torch.float32); view_320 = None + view_323 = torch.ops.aten.view.default(convert_element_type_145, [2, 8192, 4, -1, 2]); convert_element_type_145 = None + view_as_complex_8 = torch.ops.aten.view_as_complex.default(view_323); view_323 = None + convert_element_type_146 = torch.ops.prims.convert_element_type.default(view_321, torch.float32); view_321 = None + view_324 = torch.ops.aten.view.default(convert_element_type_146, [2, 8192, 1, -1, 2]); convert_element_type_146 = None + view_as_complex_9 = 
torch.ops.aten.view_as_complex.default(view_324); view_324 = None + mul_34 = torch.ops.aten.mul.Tensor(view_as_complex_8, view_37); view_as_complex_8 = None + view_as_real_8 = torch.ops.aten.view_as_real.default(mul_34); mul_34 = None + view_326 = torch.ops.aten.view.default(view_as_real_8, [2, 8192, 4, 128]); view_as_real_8 = None + mul_35 = torch.ops.aten.mul.Tensor(view_as_complex_9, view_37); view_as_complex_9 = None + view_as_real_9 = torch.ops.aten.view_as_real.default(mul_35); mul_35 = None + view_327 = torch.ops.aten.view.default(view_as_real_9, [2, 8192, 1, 128]); view_as_real_9 = None + convert_element_type_147 = torch.ops.prims.convert_element_type.default(view_326, torch.bfloat16); view_326 = None + convert_element_type_148 = torch.ops.prims.convert_element_type.default(view_327, torch.bfloat16); view_327 = None + unsqueeze_8 = torch.ops.aten.unsqueeze.default(convert_element_type_148, 3); convert_element_type_148 = None + expand_8 = torch.ops.aten.expand.default(unsqueeze_8, [2, 8192, 1, 4, 128]); unsqueeze_8 = None + view_328 = torch.ops.aten.view.default(expand_8, [2, 8192, 4, 128]); expand_8 = None + unsqueeze_9 = torch.ops.aten.unsqueeze.default(view_322, 3); view_322 = None + expand_9 = torch.ops.aten.expand.default(unsqueeze_9, [2, 8192, 1, 4, 128]); unsqueeze_9 = None + view_329 = torch.ops.aten.view.default(expand_9, [2, 8192, 4, 128]); expand_9 = None + permute_47 = torch.ops.aten.permute.default(convert_element_type_147, [0, 2, 1, 3]); convert_element_type_147 = None + permute_48 = torch.ops.aten.permute.default(view_328, [0, 2, 1, 3]); view_328 = None + permute_49 = torch.ops.aten.permute.default(view_329, [0, 2, 1, 3]); view_329 = None + _scaled_dot_product_cudnn_attention_backward_27 = torch.ops.aten._scaled_dot_product_cudnn_attention_backward.default(permute_1237, permute_47, permute_48, permute_49, getitem_244, getitem_245, getitem_250, getitem_251, None, None, None, 8192, 8192, 0.0, True); permute_1237 = permute_47 = permute_48 = 
permute_49 = getitem_244 = getitem_245 = getitem_250 = getitem_251 = None + getitem_2369 = _scaled_dot_product_cudnn_attention_backward_27[0] + getitem_2370 = _scaled_dot_product_cudnn_attention_backward_27[1] + getitem_2371 = _scaled_dot_product_cudnn_attention_backward_27[2]; _scaled_dot_product_cudnn_attention_backward_27 = None + permute_1238 = torch.ops.aten.permute.default(getitem_2371, [0, 2, 1, 3]); getitem_2371 = None + permute_1239 = torch.ops.aten.permute.default(getitem_2370, [0, 2, 1, 3]); getitem_2370 = None + permute_1240 = torch.ops.aten.permute.default(getitem_2369, [0, 2, 1, 3]); getitem_2369 = None + view_2980 = torch.ops.aten.view.default(permute_1238, [2, 8192, 1, 4, 128]); permute_1238 = None + sum_167 = torch.ops.aten.sum.dim_IntList(view_2980, [3], True); view_2980 = None + squeeze_54 = torch.ops.aten.squeeze.dim(sum_167, 3); sum_167 = None + view_2981 = torch.ops.aten.view.default(permute_1239, [2, 8192, 1, 4, 128]); permute_1239 = None + sum_168 = torch.ops.aten.sum.dim_IntList(view_2981, [3], True); view_2981 = None + squeeze_55 = torch.ops.aten.squeeze.dim(sum_168, 3); sum_168 = None + convert_element_type_2561 = torch.ops.prims.convert_element_type.default(squeeze_55, torch.float32); squeeze_55 = None + convert_element_type_2562 = torch.ops.prims.convert_element_type.default(permute_1240, torch.float32); permute_1240 = None + view_2982 = torch.ops.aten.view.default(convert_element_type_2561, [2, 8192, 1, 64, 2]); convert_element_type_2561 = None + view_as_complex_118 = torch.ops.aten.view_as_complex.default(view_2982); view_2982 = None + mul_816 = torch.ops.aten.mul.Tensor(view_as_complex_118, _conj); view_as_complex_118 = None + view_2983 = torch.ops.aten.view.default(convert_element_type_2562, [2, 8192, 4, 64, 2]); convert_element_type_2562 = None + view_as_complex_119 = torch.ops.aten.view_as_complex.default(view_2983); view_2983 = None + mul_817 = torch.ops.aten.mul.Tensor(view_as_complex_119, _conj); view_as_complex_119 = None + 
view_as_real_118 = torch.ops.aten.view_as_real.default(mul_816); mul_816 = None + view_2984 = torch.ops.aten.view.default(view_as_real_118, [2, 8192, 1, 128]); view_as_real_118 = None + convert_element_type_2563 = torch.ops.prims.convert_element_type.default(view_2984, torch.bfloat16); view_2984 = None + view_as_real_119 = torch.ops.aten.view_as_real.default(mul_817); mul_817 = None + view_2985 = torch.ops.aten.view.default(view_as_real_119, [2, 8192, 4, 128]); view_as_real_119 = None + convert_element_type_2564 = torch.ops.prims.convert_element_type.default(view_2985, torch.bfloat16); view_2985 = None + view_2986 = torch.ops.aten.view.default(squeeze_54, [2, 8192, 128]); squeeze_54 = None + view_2987 = torch.ops.aten.view.default(convert_element_type_2563, [2, 8192, 128]); convert_element_type_2563 = None + view_2988 = torch.ops.aten.view.default(convert_element_type_2564, [2, 8192, 512]); convert_element_type_2564 = None + view_2989 = torch.ops.aten.view.default(view_2986, [16384, 128]); view_2986 = None + permute_1241 = torch.ops.aten.permute.default(view_2989, [1, 0]) + mm_613 = torch.ops.aten.mm.default(permute_1241, view_303); permute_1241 = None + convert_element_type_142 = torch.ops.prims.convert_element_type.default(primals_43, torch.bfloat16); primals_43 = None + all_gather_into_tensor_49 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_142, 32, '0'); convert_element_type_142 = None + wait_tensor_58 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_49); all_gather_into_tensor_49 = None + permute_46 = torch.ops.aten.permute.default(wait_tensor_58, [1, 0]); wait_tensor_58 = None + permute_1243 = torch.ops.aten.permute.default(permute_46, [1, 0]); permute_46 = None + mm_614 = torch.ops.aten.mm.default(view_2989, permute_1243); view_2989 = permute_1243 = None + view_2990 = torch.ops.aten.view.default(mm_614, [2, 8192, 4096]); mm_614 = None + convert_element_type_2569 = 
torch.ops.prims.convert_element_type.default(mm_613, torch.float32); mm_613 = None + reduce_scatter_tensor_371 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2569, 'avg', 32, '0'); convert_element_type_2569 = None + wait_tensor_839 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_371); reduce_scatter_tensor_371 = None + view_2991 = torch.ops.aten.view.default(view_2987, [16384, 128]); view_2987 = None + permute_1245 = torch.ops.aten.permute.default(view_2991, [1, 0]) + mm_615 = torch.ops.aten.mm.default(permute_1245, view_303); permute_1245 = None + permute_1247 = torch.ops.aten.permute.default(permute_45, [1, 0]); permute_45 = None + mm_616 = torch.ops.aten.mm.default(view_2991, permute_1247); view_2991 = permute_1247 = None + view_2992 = torch.ops.aten.view.default(mm_616, [2, 8192, 4096]); mm_616 = None + add_322 = torch.ops.aten.add.Tensor(view_2990, view_2992); view_2990 = view_2992 = None + convert_element_type_2574 = torch.ops.prims.convert_element_type.default(mm_615, torch.float32); mm_615 = None + reduce_scatter_tensor_372 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2574, 'avg', 32, '0'); convert_element_type_2574 = None + wait_tensor_840 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_372); reduce_scatter_tensor_372 = None + view_2993 = torch.ops.aten.view.default(view_2988, [16384, 512]); view_2988 = None + permute_1249 = torch.ops.aten.permute.default(view_2993, [1, 0]) + mm_617 = torch.ops.aten.mm.default(permute_1249, view_303); permute_1249 = view_303 = None + convert_element_type_136 = torch.ops.prims.convert_element_type.default(primals_41, torch.bfloat16); primals_41 = None + all_gather_into_tensor_47 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_136, 32, '0'); convert_element_type_136 = None + wait_tensor_56 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_47); 
all_gather_into_tensor_47 = None + permute_44 = torch.ops.aten.permute.default(wait_tensor_56, [1, 0]); wait_tensor_56 = None + permute_1251 = torch.ops.aten.permute.default(permute_44, [1, 0]); permute_44 = None + mm_618 = torch.ops.aten.mm.default(view_2993, permute_1251); view_2993 = permute_1251 = None + view_2994 = torch.ops.aten.view.default(mm_618, [2, 8192, 4096]); mm_618 = None + add_323 = torch.ops.aten.add.Tensor(add_322, view_2994); add_322 = view_2994 = None + convert_element_type_2579 = torch.ops.prims.convert_element_type.default(mm_617, torch.float32); mm_617 = None + reduce_scatter_tensor_373 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2579, 'avg', 32, '0'); convert_element_type_2579 = None + wait_tensor_841 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_373); reduce_scatter_tensor_373 = None + split_250 = torch.ops.aten.split.Tensor(add_323, 1024, 1); add_323 = None + getitem_2372 = split_250[0] + getitem_2373 = split_250[1] + getitem_2374 = split_250[2] + getitem_2375 = split_250[3] + getitem_2376 = split_250[4] + getitem_2377 = split_250[5] + getitem_2378 = split_250[6] + getitem_2379 = split_250[7]; split_250 = None + cat_242 = torch.ops.aten.cat.default([getitem_2372, getitem_2373, getitem_2374, getitem_2375, getitem_2376, getitem_2377, getitem_2378, getitem_2379]); getitem_2372 = getitem_2373 = getitem_2374 = getitem_2375 = getitem_2376 = getitem_2377 = getitem_2378 = getitem_2379 = None + reduce_scatter_tensor_374 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_242, 'sum', 8, '1'); cat_242 = None + wait_tensor_842 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_374); reduce_scatter_tensor_374 = None + convert_element_type_2580 = torch.ops.prims.convert_element_type.default(wait_tensor_842, torch.float32); wait_tensor_842 = None + convert_element_type_2582 = torch.ops.prims.convert_element_type.default(wait_tensor_54, torch.float32); 
wait_tensor_54 = None + mul_818 = torch.ops.aten.mul.Tensor(convert_element_type_2580, convert_element_type_2582); convert_element_type_2582 = None + mul_820 = torch.ops.aten.mul.Tensor(mul_32, mul_818) + sum_169 = torch.ops.aten.sum.dim_IntList(mul_820, [2], True); mul_820 = None + div_56 = torch.ops.aten.div.Tensor(mul_32, 4096) + mul_821 = torch.ops.aten.mul.Tensor(div_56, sum_169); div_56 = sum_169 = None + sub_85 = torch.ops.aten.sub.Tensor(mul_818, mul_821); mul_818 = mul_821 = None + mul_822 = torch.ops.aten.mul.Tensor(sub_85, rsqrt_8); sub_85 = rsqrt_8 = None + mul_823 = torch.ops.aten.mul.Tensor(convert_element_type_2580, mul_32); convert_element_type_2580 = mul_32 = None + sum_170 = torch.ops.aten.sum.dim_IntList(mul_823, [0, 1]); mul_823 = None + convert_element_type_2583 = torch.ops.prims.convert_element_type.default(mul_822, torch.bfloat16); mul_822 = None + convert_element_type_2584 = torch.ops.prims.convert_element_type.default(sum_170, torch.bfloat16); sum_170 = None + all_reduce_56 = torch.ops._c10d_functional.all_reduce.default(convert_element_type_2584, 'sum', '1'); convert_element_type_2584 = None + wait_tensor_843 = torch.ops._c10d_functional.wait_tensor.default(all_reduce_56); all_reduce_56 = None + convert_element_type_2585 = torch.ops.prims.convert_element_type.default(wait_tensor_843, torch.float32); wait_tensor_843 = None + reduce_scatter_tensor_375 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2585, 'avg', 32, '0'); convert_element_type_2585 = None + wait_tensor_844 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_375); reduce_scatter_tensor_375 = None + add_324 = torch.ops.aten.add.Tensor(add_321, convert_element_type_2583); add_321 = convert_element_type_2583 = None + all_gather_into_tensor_412 = torch.ops._c10d_functional.all_gather_into_tensor.default(add_324, 8, '1') + wait_tensor_845 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_412); 
all_gather_into_tensor_412 = None + split_251 = torch.ops.aten.split.Tensor(wait_tensor_845, 2); wait_tensor_845 = None + getitem_2380 = split_251[0] + getitem_2381 = split_251[1] + getitem_2382 = split_251[2] + getitem_2383 = split_251[3] + getitem_2384 = split_251[4] + getitem_2385 = split_251[5] + getitem_2386 = split_251[6] + getitem_2387 = split_251[7]; split_251 = None + cat_243 = torch.ops.aten.cat.default([getitem_2380, getitem_2381, getitem_2382, getitem_2383, getitem_2384, getitem_2385, getitem_2386, getitem_2387], 1); getitem_2380 = getitem_2381 = getitem_2382 = getitem_2383 = getitem_2384 = getitem_2385 = getitem_2386 = getitem_2387 = None + view_2995 = torch.ops.aten.view.default(cat_243, [16384, 4096]); cat_243 = None + permute_1253 = torch.ops.aten.permute.default(view_2995, [1, 0]) + wait_tensor_47 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_7); reduce_scatter_tensor_7 = None + add_13 = torch.ops.aten.add.Tensor(add_11, wait_tensor_47); wait_tensor_47 = None + convert_element_type_119 = torch.ops.prims.convert_element_type.default(primals_36, torch.bfloat16); primals_36 = None + all_gather_into_tensor_40 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_119, 32, '0'); convert_element_type_119 = None + wait_tensor_48 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_40); all_gather_into_tensor_40 = None + convert_element_type_120 = torch.ops.prims.convert_element_type.default(add_13, torch.float32); add_13 = None + pow_8 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_120, 2) + mean_7 = torch.ops.aten.mean.dim(pow_8, [2], True); pow_8 = None + add_14 = torch.ops.aten.add.Scalar(mean_7, 1e-05); mean_7 = None + rsqrt_7 = torch.ops.aten.rsqrt.default(add_14); add_14 = None + mul_28 = torch.ops.aten.mul.Tensor(convert_element_type_120, rsqrt_7); convert_element_type_120 = None + mul_29 = torch.ops.aten.mul.Tensor(mul_28, wait_tensor_48) + convert_element_type_121 = 
torch.ops.prims.convert_element_type.default(mul_29, torch.bfloat16); mul_29 = None + all_gather_into_tensor_41 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_121, 8, '1'); convert_element_type_121 = None + wait_tensor_49 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_41); all_gather_into_tensor_41 = None + split_23 = torch.ops.aten.split.Tensor(wait_tensor_49, 2); wait_tensor_49 = None + getitem_220 = split_23[0] + getitem_221 = split_23[1] + getitem_222 = split_23[2] + getitem_223 = split_23[3] + getitem_224 = split_23[4] + getitem_225 = split_23[5] + getitem_226 = split_23[6] + getitem_227 = split_23[7]; split_23 = None + cat_15 = torch.ops.aten.cat.default([getitem_220, getitem_221, getitem_222, getitem_223, getitem_224, getitem_225, getitem_226, getitem_227], 1); getitem_220 = getitem_221 = getitem_222 = getitem_223 = getitem_224 = getitem_225 = getitem_226 = getitem_227 = None + view_276 = torch.ops.aten.view.default(cat_15, [16384, 4096]); cat_15 = None + view_277 = torch.ops.aten.view.default(mm_25, [2, 8192, 1792]); mm_25 = None + convert_element_type_125 = torch.ops.prims.convert_element_type.default(view_277, torch.float32); view_277 = None + sigmoid_3 = torch.ops.aten.sigmoid.default(convert_element_type_125) + mul_30 = torch.ops.aten.mul.Tensor(convert_element_type_125, sigmoid_3); sigmoid_3 = None + convert_element_type_126 = torch.ops.prims.convert_element_type.default(mul_30, torch.bfloat16); mul_30 = None + convert_element_type_127 = torch.ops.prims.convert_element_type.default(primals_38, torch.bfloat16); primals_38 = None + all_gather_into_tensor_43 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_127, 32, '0'); convert_element_type_127 = None + wait_tensor_51 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_43); all_gather_into_tensor_43 = None + permute_42 = torch.ops.aten.permute.default(wait_tensor_51, [1, 0]); wait_tensor_51 = 
None + mm_26 = torch.ops.aten.mm.default(view_276, permute_42) + view_284 = torch.ops.aten.view.default(mm_26, [2, 8192, 1792]); mm_26 = None + mul_31 = torch.ops.aten.mul.Tensor(convert_element_type_126, view_284) + view_291 = torch.ops.aten.view.default(mul_31, [16384, 1792]); mul_31 = None + mm_619 = torch.ops.aten.mm.default(permute_1253, view_291); permute_1253 = view_291 = None + convert_element_type_130 = torch.ops.prims.convert_element_type.default(primals_39, torch.bfloat16); primals_39 = None + all_gather_into_tensor_44 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_130, 32, '0'); convert_element_type_130 = None + wait_tensor_52 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_44); all_gather_into_tensor_44 = None + permute_43 = torch.ops.aten.permute.default(wait_tensor_52, [1, 0]); wait_tensor_52 = None + permute_1255 = torch.ops.aten.permute.default(permute_43, [1, 0]); permute_43 = None + mm_620 = torch.ops.aten.mm.default(view_2995, permute_1255); view_2995 = permute_1255 = None + view_2996 = torch.ops.aten.view.default(mm_620, [2, 8192, 1792]); mm_620 = None + convert_element_type_2590 = torch.ops.prims.convert_element_type.default(mm_619, torch.float32); mm_619 = None + reduce_scatter_tensor_376 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2590, 'avg', 32, '0'); convert_element_type_2590 = None + wait_tensor_846 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_376); reduce_scatter_tensor_376 = None + mul_824 = torch.ops.aten.mul.Tensor(view_2996, convert_element_type_126); convert_element_type_126 = None + mul_825 = torch.ops.aten.mul.Tensor(view_2996, view_284); view_2996 = view_284 = None + view_2997 = torch.ops.aten.view.default(mul_824, [16384, 1792]); mul_824 = None + permute_1257 = torch.ops.aten.permute.default(view_2997, [1, 0]) + mm_621 = torch.ops.aten.mm.default(permute_1257, view_276); permute_1257 = None + permute_1259 = 
torch.ops.aten.permute.default(permute_42, [1, 0]); permute_42 = None + mm_622 = torch.ops.aten.mm.default(view_2997, permute_1259); view_2997 = permute_1259 = None + view_2998 = torch.ops.aten.view.default(mm_622, [2, 8192, 4096]); mm_622 = None + convert_element_type_2595 = torch.ops.prims.convert_element_type.default(mm_621, torch.float32); mm_621 = None + reduce_scatter_tensor_377 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2595, 'avg', 32, '0'); convert_element_type_2595 = None + wait_tensor_847 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_377); reduce_scatter_tensor_377 = None + convert_element_type_2596 = torch.ops.prims.convert_element_type.default(mul_825, torch.float32); mul_825 = None + neg_28 = torch.ops.aten.neg.default(convert_element_type_125) + exp_28 = torch.ops.aten.exp.default(neg_28); neg_28 = None + add_325 = torch.ops.aten.add.Tensor(exp_28, 1); exp_28 = None + reciprocal_28 = torch.ops.aten.reciprocal.default(add_325); add_325 = None + mul_826 = torch.ops.aten.mul.Tensor(reciprocal_28, 1); reciprocal_28 = None + mul_827 = torch.ops.aten.mul.Tensor(convert_element_type_2596, mul_826); convert_element_type_2596 = None + sub_86 = torch.ops.aten.sub.Tensor(1, mul_826); mul_826 = None + mul_828 = torch.ops.aten.mul.Tensor(convert_element_type_125, sub_86); convert_element_type_125 = sub_86 = None + add_326 = torch.ops.aten.add.Tensor(mul_828, 1); mul_828 = None + mul_829 = torch.ops.aten.mul.Tensor(mul_827, add_326); mul_827 = add_326 = None + convert_element_type_2598 = torch.ops.prims.convert_element_type.default(mul_829, torch.bfloat16); mul_829 = None + view_2999 = torch.ops.aten.view.default(convert_element_type_2598, [16384, 1792]); convert_element_type_2598 = None + permute_1261 = torch.ops.aten.permute.default(view_2999, [1, 0]) + mm_623 = torch.ops.aten.mm.default(permute_1261, view_276); permute_1261 = view_276 = None + convert_element_type_122 = 
torch.ops.prims.convert_element_type.default(primals_37, torch.bfloat16); primals_37 = None + all_gather_into_tensor_42 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_122, 32, '0'); convert_element_type_122 = None + wait_tensor_50 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_42); all_gather_into_tensor_42 = None + permute_41 = torch.ops.aten.permute.default(wait_tensor_50, [1, 0]); wait_tensor_50 = None + permute_1263 = torch.ops.aten.permute.default(permute_41, [1, 0]); permute_41 = None + mm_624 = torch.ops.aten.mm.default(view_2999, permute_1263); view_2999 = permute_1263 = None + view_3000 = torch.ops.aten.view.default(mm_624, [2, 8192, 4096]); mm_624 = None + add_327 = torch.ops.aten.add.Tensor(view_2998, view_3000); view_2998 = view_3000 = None + convert_element_type_2603 = torch.ops.prims.convert_element_type.default(mm_623, torch.float32); mm_623 = None + reduce_scatter_tensor_378 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2603, 'avg', 32, '0'); convert_element_type_2603 = None + wait_tensor_848 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_378); reduce_scatter_tensor_378 = None + split_252 = torch.ops.aten.split.Tensor(add_327, 1024, 1); add_327 = None + getitem_2388 = split_252[0] + getitem_2389 = split_252[1] + getitem_2390 = split_252[2] + getitem_2391 = split_252[3] + getitem_2392 = split_252[4] + getitem_2393 = split_252[5] + getitem_2394 = split_252[6] + getitem_2395 = split_252[7]; split_252 = None + cat_244 = torch.ops.aten.cat.default([getitem_2388, getitem_2389, getitem_2390, getitem_2391, getitem_2392, getitem_2393, getitem_2394, getitem_2395]); getitem_2388 = getitem_2389 = getitem_2390 = getitem_2391 = getitem_2392 = getitem_2393 = getitem_2394 = getitem_2395 = None + reduce_scatter_tensor_379 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_244, 'sum', 8, '1'); cat_244 = None + wait_tensor_849 = 
torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_379); reduce_scatter_tensor_379 = None + convert_element_type_2604 = torch.ops.prims.convert_element_type.default(wait_tensor_849, torch.float32); wait_tensor_849 = None + convert_element_type_2606 = torch.ops.prims.convert_element_type.default(wait_tensor_48, torch.float32); wait_tensor_48 = None + mul_830 = torch.ops.aten.mul.Tensor(convert_element_type_2604, convert_element_type_2606); convert_element_type_2606 = None + mul_832 = torch.ops.aten.mul.Tensor(mul_28, mul_830) + sum_171 = torch.ops.aten.sum.dim_IntList(mul_832, [2], True); mul_832 = None + div_57 = torch.ops.aten.div.Tensor(mul_28, 4096) + mul_833 = torch.ops.aten.mul.Tensor(div_57, sum_171); div_57 = sum_171 = None + sub_87 = torch.ops.aten.sub.Tensor(mul_830, mul_833); mul_830 = mul_833 = None + mul_834 = torch.ops.aten.mul.Tensor(sub_87, rsqrt_7); sub_87 = rsqrt_7 = None + mul_835 = torch.ops.aten.mul.Tensor(convert_element_type_2604, mul_28); convert_element_type_2604 = mul_28 = None + sum_172 = torch.ops.aten.sum.dim_IntList(mul_835, [0, 1]); mul_835 = None + convert_element_type_2607 = torch.ops.prims.convert_element_type.default(mul_834, torch.bfloat16); mul_834 = None + convert_element_type_2608 = torch.ops.prims.convert_element_type.default(sum_172, torch.bfloat16); sum_172 = None + all_reduce_57 = torch.ops._c10d_functional.all_reduce.default(convert_element_type_2608, 'sum', '1'); convert_element_type_2608 = None + wait_tensor_850 = torch.ops._c10d_functional.wait_tensor.default(all_reduce_57); all_reduce_57 = None + convert_element_type_2609 = torch.ops.prims.convert_element_type.default(wait_tensor_850, torch.float32); wait_tensor_850 = None + reduce_scatter_tensor_380 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2609, 'avg', 32, '0'); convert_element_type_2609 = None + wait_tensor_851 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_380); 
reduce_scatter_tensor_380 = None + add_328 = torch.ops.aten.add.Tensor(add_324, convert_element_type_2607); add_324 = convert_element_type_2607 = None + all_gather_into_tensor_413 = torch.ops._c10d_functional.all_gather_into_tensor.default(add_328, 8, '1') + wait_tensor_852 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_413); all_gather_into_tensor_413 = None + split_253 = torch.ops.aten.split.Tensor(wait_tensor_852, 2); wait_tensor_852 = None + getitem_2396 = split_253[0] + getitem_2397 = split_253[1] + getitem_2398 = split_253[2] + getitem_2399 = split_253[3] + getitem_2400 = split_253[4] + getitem_2401 = split_253[5] + getitem_2402 = split_253[6] + getitem_2403 = split_253[7]; split_253 = None + cat_245 = torch.ops.aten.cat.default([getitem_2396, getitem_2397, getitem_2398, getitem_2399, getitem_2400, getitem_2401, getitem_2402, getitem_2403], 1); getitem_2396 = getitem_2397 = getitem_2398 = getitem_2399 = getitem_2400 = getitem_2401 = getitem_2402 = getitem_2403 = None + view_3001 = torch.ops.aten.view.default(cat_245, [16384, 4096]); cat_245 = None + permute_1265 = torch.ops.aten.permute.default(view_3001, [1, 0]) + permute_39 = torch.ops.aten.permute.default(getitem_203, [0, 2, 1, 3]) + view_258 = torch.ops.aten.view.default(permute_39, [2, 8192, -1]); permute_39 = None + view_264 = torch.ops.aten.view.default(view_258, [16384, 512]); view_258 = None + mm_625 = torch.ops.aten.mm.default(permute_1265, view_264); permute_1265 = view_264 = None + convert_element_type_116 = torch.ops.prims.convert_element_type.default(primals_35, torch.bfloat16); primals_35 = None + all_gather_into_tensor_39 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_116, 32, '0'); convert_element_type_116 = None + wait_tensor_46 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_39); all_gather_into_tensor_39 = None + permute_40 = torch.ops.aten.permute.default(wait_tensor_46, [1, 0]); wait_tensor_46 = None + 
permute_1267 = torch.ops.aten.permute.default(permute_40, [1, 0]); permute_40 = None + mm_626 = torch.ops.aten.mm.default(view_3001, permute_1267); view_3001 = permute_1267 = None + view_3002 = torch.ops.aten.view.default(mm_626, [2, 8192, 512]); mm_626 = None + convert_element_type_2614 = torch.ops.prims.convert_element_type.default(mm_625, torch.float32); mm_625 = None + reduce_scatter_tensor_381 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2614, 'avg', 32, '0'); convert_element_type_2614 = None + wait_tensor_853 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_381); reduce_scatter_tensor_381 = None + view_3003 = torch.ops.aten.view.default(view_3002, [2, 8192, 4, 128]); view_3002 = None + permute_1269 = torch.ops.aten.permute.default(view_3003, [0, 2, 1, 3]); view_3003 = None + convert_element_type_100 = torch.ops.prims.convert_element_type.default(primals_31, torch.bfloat16); primals_31 = None + all_gather_into_tensor_34 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_100, 32, '0'); convert_element_type_100 = None + wait_tensor_41 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_34); all_gather_into_tensor_34 = None + convert_element_type_101 = torch.ops.prims.convert_element_type.default(add_11, torch.float32); add_11 = None + pow_7 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_101, 2) + mean_6 = torch.ops.aten.mean.dim(pow_7, [2], True); pow_7 = None + add_12 = torch.ops.aten.add.Scalar(mean_6, 1e-05); mean_6 = None + rsqrt_6 = torch.ops.aten.rsqrt.default(add_12); add_12 = None + mul_24 = torch.ops.aten.mul.Tensor(convert_element_type_101, rsqrt_6); convert_element_type_101 = None + mul_25 = torch.ops.aten.mul.Tensor(mul_24, wait_tensor_41) + convert_element_type_102 = torch.ops.prims.convert_element_type.default(mul_25, torch.bfloat16); mul_25 = None + all_gather_into_tensor_35 = 
torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_102, 8, '1'); convert_element_type_102 = None + wait_tensor_42 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_35); all_gather_into_tensor_35 = None + split_21 = torch.ops.aten.split.Tensor(wait_tensor_42, 2); wait_tensor_42 = None + getitem_195 = split_21[0] + getitem_196 = split_21[1] + getitem_197 = split_21[2] + getitem_198 = split_21[3] + getitem_199 = split_21[4] + getitem_200 = split_21[5] + getitem_201 = split_21[6] + getitem_202 = split_21[7]; split_21 = None + cat_13 = torch.ops.aten.cat.default([getitem_195, getitem_196, getitem_197, getitem_198, getitem_199, getitem_200, getitem_201, getitem_202], 1); getitem_195 = getitem_196 = getitem_197 = getitem_198 = getitem_199 = getitem_200 = getitem_201 = getitem_202 = None + view_231 = torch.ops.aten.view.default(cat_13, [16384, 4096]); cat_13 = None + view_232 = torch.ops.aten.view.default(mm_21, [2, 8192, 512]); mm_21 = None + convert_element_type_106 = torch.ops.prims.convert_element_type.default(primals_33, torch.bfloat16); primals_33 = None + all_gather_into_tensor_37 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_106, 32, '0'); convert_element_type_106 = None + wait_tensor_44 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_37); all_gather_into_tensor_37 = None + permute_34 = torch.ops.aten.permute.default(wait_tensor_44, [1, 0]); wait_tensor_44 = None + mm_22 = torch.ops.aten.mm.default(view_231, permute_34) + view_239 = torch.ops.aten.view.default(mm_22, [2, 8192, 128]); mm_22 = None + view_246 = torch.ops.aten.view.default(mm_23, [2, 8192, 128]); mm_23 = None + view_248 = torch.ops.aten.view.default(view_232, [2, 8192, -1, 128]); view_232 = None + view_249 = torch.ops.aten.view.default(view_239, [2, 8192, -1, 128]); view_239 = None + view_250 = torch.ops.aten.view.default(view_246, [2, 8192, -1, 128]); view_246 = None + convert_element_type_112 
= torch.ops.prims.convert_element_type.default(view_248, torch.float32); view_248 = None + view_251 = torch.ops.aten.view.default(convert_element_type_112, [2, 8192, 4, -1, 2]); convert_element_type_112 = None + view_as_complex_6 = torch.ops.aten.view_as_complex.default(view_251); view_251 = None + convert_element_type_113 = torch.ops.prims.convert_element_type.default(view_249, torch.float32); view_249 = None + view_252 = torch.ops.aten.view.default(convert_element_type_113, [2, 8192, 1, -1, 2]); convert_element_type_113 = None + view_as_complex_7 = torch.ops.aten.view_as_complex.default(view_252); view_252 = None + mul_26 = torch.ops.aten.mul.Tensor(view_as_complex_6, view_37); view_as_complex_6 = None + view_as_real_6 = torch.ops.aten.view_as_real.default(mul_26); mul_26 = None + view_254 = torch.ops.aten.view.default(view_as_real_6, [2, 8192, 4, 128]); view_as_real_6 = None + mul_27 = torch.ops.aten.mul.Tensor(view_as_complex_7, view_37); view_as_complex_7 = None + view_as_real_7 = torch.ops.aten.view_as_real.default(mul_27); mul_27 = None + view_255 = torch.ops.aten.view.default(view_as_real_7, [2, 8192, 1, 128]); view_as_real_7 = None + convert_element_type_114 = torch.ops.prims.convert_element_type.default(view_254, torch.bfloat16); view_254 = None + convert_element_type_115 = torch.ops.prims.convert_element_type.default(view_255, torch.bfloat16); view_255 = None + unsqueeze_6 = torch.ops.aten.unsqueeze.default(convert_element_type_115, 3); convert_element_type_115 = None + expand_6 = torch.ops.aten.expand.default(unsqueeze_6, [2, 8192, 1, 4, 128]); unsqueeze_6 = None + view_256 = torch.ops.aten.view.default(expand_6, [2, 8192, 4, 128]); expand_6 = None + unsqueeze_7 = torch.ops.aten.unsqueeze.default(view_250, 3); view_250 = None + expand_7 = torch.ops.aten.expand.default(unsqueeze_7, [2, 8192, 1, 4, 128]); unsqueeze_7 = None + view_257 = torch.ops.aten.view.default(expand_7, [2, 8192, 4, 128]); expand_7 = None + permute_36 = 
torch.ops.aten.permute.default(convert_element_type_114, [0, 2, 1, 3]); convert_element_type_114 = None + permute_37 = torch.ops.aten.permute.default(view_256, [0, 2, 1, 3]); view_256 = None + permute_38 = torch.ops.aten.permute.default(view_257, [0, 2, 1, 3]); view_257 = None + _scaled_dot_product_cudnn_attention_backward_28 = torch.ops.aten._scaled_dot_product_cudnn_attention_backward.default(permute_1269, permute_36, permute_37, permute_38, getitem_203, getitem_204, getitem_209, getitem_210, None, None, None, 8192, 8192, 0.0, True); permute_1269 = permute_36 = permute_37 = permute_38 = getitem_203 = getitem_204 = getitem_209 = getitem_210 = None + getitem_2404 = _scaled_dot_product_cudnn_attention_backward_28[0] + getitem_2405 = _scaled_dot_product_cudnn_attention_backward_28[1] + getitem_2406 = _scaled_dot_product_cudnn_attention_backward_28[2]; _scaled_dot_product_cudnn_attention_backward_28 = None + permute_1270 = torch.ops.aten.permute.default(getitem_2406, [0, 2, 1, 3]); getitem_2406 = None + permute_1271 = torch.ops.aten.permute.default(getitem_2405, [0, 2, 1, 3]); getitem_2405 = None + permute_1272 = torch.ops.aten.permute.default(getitem_2404, [0, 2, 1, 3]); getitem_2404 = None + view_3004 = torch.ops.aten.view.default(permute_1270, [2, 8192, 1, 4, 128]); permute_1270 = None + sum_173 = torch.ops.aten.sum.dim_IntList(view_3004, [3], True); view_3004 = None + squeeze_56 = torch.ops.aten.squeeze.dim(sum_173, 3); sum_173 = None + view_3005 = torch.ops.aten.view.default(permute_1271, [2, 8192, 1, 4, 128]); permute_1271 = None + sum_174 = torch.ops.aten.sum.dim_IntList(view_3005, [3], True); view_3005 = None + squeeze_57 = torch.ops.aten.squeeze.dim(sum_174, 3); sum_174 = None + convert_element_type_2615 = torch.ops.prims.convert_element_type.default(squeeze_57, torch.float32); squeeze_57 = None + convert_element_type_2616 = torch.ops.prims.convert_element_type.default(permute_1272, torch.float32); permute_1272 = None + view_3006 = 
torch.ops.aten.view.default(convert_element_type_2615, [2, 8192, 1, 64, 2]); convert_element_type_2615 = None + view_as_complex_120 = torch.ops.aten.view_as_complex.default(view_3006); view_3006 = None + mul_836 = torch.ops.aten.mul.Tensor(view_as_complex_120, _conj); view_as_complex_120 = None + view_3007 = torch.ops.aten.view.default(convert_element_type_2616, [2, 8192, 4, 64, 2]); convert_element_type_2616 = None + view_as_complex_121 = torch.ops.aten.view_as_complex.default(view_3007); view_3007 = None + mul_837 = torch.ops.aten.mul.Tensor(view_as_complex_121, _conj); view_as_complex_121 = None + view_as_real_120 = torch.ops.aten.view_as_real.default(mul_836); mul_836 = None + view_3008 = torch.ops.aten.view.default(view_as_real_120, [2, 8192, 1, 128]); view_as_real_120 = None + convert_element_type_2617 = torch.ops.prims.convert_element_type.default(view_3008, torch.bfloat16); view_3008 = None + view_as_real_121 = torch.ops.aten.view_as_real.default(mul_837); mul_837 = None + view_3009 = torch.ops.aten.view.default(view_as_real_121, [2, 8192, 4, 128]); view_as_real_121 = None + convert_element_type_2618 = torch.ops.prims.convert_element_type.default(view_3009, torch.bfloat16); view_3009 = None + view_3010 = torch.ops.aten.view.default(squeeze_56, [2, 8192, 128]); squeeze_56 = None + view_3011 = torch.ops.aten.view.default(convert_element_type_2617, [2, 8192, 128]); convert_element_type_2617 = None + view_3012 = torch.ops.aten.view.default(convert_element_type_2618, [2, 8192, 512]); convert_element_type_2618 = None + view_3013 = torch.ops.aten.view.default(view_3010, [16384, 128]); view_3010 = None + permute_1273 = torch.ops.aten.permute.default(view_3013, [1, 0]) + mm_627 = torch.ops.aten.mm.default(permute_1273, view_231); permute_1273 = None + convert_element_type_109 = torch.ops.prims.convert_element_type.default(primals_34, torch.bfloat16); primals_34 = None + all_gather_into_tensor_38 = 
torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_109, 32, '0'); convert_element_type_109 = None + wait_tensor_45 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_38); all_gather_into_tensor_38 = None + permute_35 = torch.ops.aten.permute.default(wait_tensor_45, [1, 0]); wait_tensor_45 = None + permute_1275 = torch.ops.aten.permute.default(permute_35, [1, 0]); permute_35 = None + mm_628 = torch.ops.aten.mm.default(view_3013, permute_1275); view_3013 = permute_1275 = None + view_3014 = torch.ops.aten.view.default(mm_628, [2, 8192, 4096]); mm_628 = None + convert_element_type_2623 = torch.ops.prims.convert_element_type.default(mm_627, torch.float32); mm_627 = None + reduce_scatter_tensor_382 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2623, 'avg', 32, '0'); convert_element_type_2623 = None + wait_tensor_854 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_382); reduce_scatter_tensor_382 = None + view_3015 = torch.ops.aten.view.default(view_3011, [16384, 128]); view_3011 = None + permute_1277 = torch.ops.aten.permute.default(view_3015, [1, 0]) + mm_629 = torch.ops.aten.mm.default(permute_1277, view_231); permute_1277 = None + permute_1279 = torch.ops.aten.permute.default(permute_34, [1, 0]); permute_34 = None + mm_630 = torch.ops.aten.mm.default(view_3015, permute_1279); view_3015 = permute_1279 = None + view_3016 = torch.ops.aten.view.default(mm_630, [2, 8192, 4096]); mm_630 = None + add_329 = torch.ops.aten.add.Tensor(view_3014, view_3016); view_3014 = view_3016 = None + convert_element_type_2628 = torch.ops.prims.convert_element_type.default(mm_629, torch.float32); mm_629 = None + reduce_scatter_tensor_383 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2628, 'avg', 32, '0'); convert_element_type_2628 = None + wait_tensor_855 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_383); 
reduce_scatter_tensor_383 = None + view_3017 = torch.ops.aten.view.default(view_3012, [16384, 512]); view_3012 = None + permute_1281 = torch.ops.aten.permute.default(view_3017, [1, 0]) + mm_631 = torch.ops.aten.mm.default(permute_1281, view_231); permute_1281 = view_231 = None + convert_element_type_103 = torch.ops.prims.convert_element_type.default(primals_32, torch.bfloat16); primals_32 = None + all_gather_into_tensor_36 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_103, 32, '0'); convert_element_type_103 = None + wait_tensor_43 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_36); all_gather_into_tensor_36 = None + permute_33 = torch.ops.aten.permute.default(wait_tensor_43, [1, 0]); wait_tensor_43 = None + permute_1283 = torch.ops.aten.permute.default(permute_33, [1, 0]); permute_33 = None + mm_632 = torch.ops.aten.mm.default(view_3017, permute_1283); view_3017 = permute_1283 = None + view_3018 = torch.ops.aten.view.default(mm_632, [2, 8192, 4096]); mm_632 = None + add_330 = torch.ops.aten.add.Tensor(add_329, view_3018); add_329 = view_3018 = None + convert_element_type_2633 = torch.ops.prims.convert_element_type.default(mm_631, torch.float32); mm_631 = None + reduce_scatter_tensor_384 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2633, 'avg', 32, '0'); convert_element_type_2633 = None + wait_tensor_856 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_384); reduce_scatter_tensor_384 = None + split_254 = torch.ops.aten.split.Tensor(add_330, 1024, 1); add_330 = None + getitem_2407 = split_254[0] + getitem_2408 = split_254[1] + getitem_2409 = split_254[2] + getitem_2410 = split_254[3] + getitem_2411 = split_254[4] + getitem_2412 = split_254[5] + getitem_2413 = split_254[6] + getitem_2414 = split_254[7]; split_254 = None + cat_246 = torch.ops.aten.cat.default([getitem_2407, getitem_2408, getitem_2409, getitem_2410, getitem_2411, getitem_2412, 
getitem_2413, getitem_2414]); getitem_2407 = getitem_2408 = getitem_2409 = getitem_2410 = getitem_2411 = getitem_2412 = getitem_2413 = getitem_2414 = None + reduce_scatter_tensor_385 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_246, 'sum', 8, '1'); cat_246 = None + wait_tensor_857 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_385); reduce_scatter_tensor_385 = None + convert_element_type_2634 = torch.ops.prims.convert_element_type.default(wait_tensor_857, torch.float32); wait_tensor_857 = None + convert_element_type_2636 = torch.ops.prims.convert_element_type.default(wait_tensor_41, torch.float32); wait_tensor_41 = None + mul_838 = torch.ops.aten.mul.Tensor(convert_element_type_2634, convert_element_type_2636); convert_element_type_2636 = None + mul_840 = torch.ops.aten.mul.Tensor(mul_24, mul_838) + sum_175 = torch.ops.aten.sum.dim_IntList(mul_840, [2], True); mul_840 = None + div_58 = torch.ops.aten.div.Tensor(mul_24, 4096) + mul_841 = torch.ops.aten.mul.Tensor(div_58, sum_175); div_58 = sum_175 = None + sub_88 = torch.ops.aten.sub.Tensor(mul_838, mul_841); mul_838 = mul_841 = None + mul_842 = torch.ops.aten.mul.Tensor(sub_88, rsqrt_6); sub_88 = rsqrt_6 = None + mul_843 = torch.ops.aten.mul.Tensor(convert_element_type_2634, mul_24); convert_element_type_2634 = mul_24 = None + sum_176 = torch.ops.aten.sum.dim_IntList(mul_843, [0, 1]); mul_843 = None + convert_element_type_2637 = torch.ops.prims.convert_element_type.default(mul_842, torch.bfloat16); mul_842 = None + convert_element_type_2638 = torch.ops.prims.convert_element_type.default(sum_176, torch.bfloat16); sum_176 = None + all_reduce_58 = torch.ops._c10d_functional.all_reduce.default(convert_element_type_2638, 'sum', '1'); convert_element_type_2638 = None + wait_tensor_858 = torch.ops._c10d_functional.wait_tensor.default(all_reduce_58); all_reduce_58 = None + convert_element_type_2639 = torch.ops.prims.convert_element_type.default(wait_tensor_858, torch.float32); 
wait_tensor_858 = None + reduce_scatter_tensor_386 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2639, 'avg', 32, '0'); convert_element_type_2639 = None + wait_tensor_859 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_386); reduce_scatter_tensor_386 = None + add_331 = torch.ops.aten.add.Tensor(add_328, convert_element_type_2637); add_328 = convert_element_type_2637 = None + all_gather_into_tensor_414 = torch.ops._c10d_functional.all_gather_into_tensor.default(add_331, 8, '1') + wait_tensor_860 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_414); all_gather_into_tensor_414 = None + split_255 = torch.ops.aten.split.Tensor(wait_tensor_860, 2); wait_tensor_860 = None + getitem_2415 = split_255[0] + getitem_2416 = split_255[1] + getitem_2417 = split_255[2] + getitem_2418 = split_255[3] + getitem_2419 = split_255[4] + getitem_2420 = split_255[5] + getitem_2421 = split_255[6] + getitem_2422 = split_255[7]; split_255 = None + cat_247 = torch.ops.aten.cat.default([getitem_2415, getitem_2416, getitem_2417, getitem_2418, getitem_2419, getitem_2420, getitem_2421, getitem_2422], 1); getitem_2415 = getitem_2416 = getitem_2417 = getitem_2418 = getitem_2419 = getitem_2420 = getitem_2421 = getitem_2422 = None + view_3019 = torch.ops.aten.view.default(cat_247, [16384, 4096]); cat_247 = None + permute_1285 = torch.ops.aten.permute.default(view_3019, [1, 0]) + wait_tensor_34 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_5); reduce_scatter_tensor_5 = None + add_9 = torch.ops.aten.add.Tensor(add_7, wait_tensor_34); wait_tensor_34 = None + convert_element_type_86 = torch.ops.prims.convert_element_type.default(primals_27, torch.bfloat16); primals_27 = None + all_gather_into_tensor_29 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_86, 32, '0'); convert_element_type_86 = None + wait_tensor_35 = 
torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_29); all_gather_into_tensor_29 = None + convert_element_type_87 = torch.ops.prims.convert_element_type.default(add_9, torch.float32); add_9 = None + pow_6 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_87, 2) + mean_5 = torch.ops.aten.mean.dim(pow_6, [2], True); pow_6 = None + add_10 = torch.ops.aten.add.Scalar(mean_5, 1e-05); mean_5 = None + rsqrt_5 = torch.ops.aten.rsqrt.default(add_10); add_10 = None + mul_20 = torch.ops.aten.mul.Tensor(convert_element_type_87, rsqrt_5); convert_element_type_87 = None + mul_21 = torch.ops.aten.mul.Tensor(mul_20, wait_tensor_35) + convert_element_type_88 = torch.ops.prims.convert_element_type.default(mul_21, torch.bfloat16); mul_21 = None + all_gather_into_tensor_30 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_88, 8, '1'); convert_element_type_88 = None + wait_tensor_36 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_30); all_gather_into_tensor_30 = None + split_19 = torch.ops.aten.split.Tensor(wait_tensor_36, 2); wait_tensor_36 = None + getitem_179 = split_19[0] + getitem_180 = split_19[1] + getitem_181 = split_19[2] + getitem_182 = split_19[3] + getitem_183 = split_19[4] + getitem_184 = split_19[5] + getitem_185 = split_19[6] + getitem_186 = split_19[7]; split_19 = None + cat_11 = torch.ops.aten.cat.default([getitem_179, getitem_180, getitem_181, getitem_182, getitem_183, getitem_184, getitem_185, getitem_186], 1); getitem_179 = getitem_180 = getitem_181 = getitem_182 = getitem_183 = getitem_184 = getitem_185 = getitem_186 = None + view_204 = torch.ops.aten.view.default(cat_11, [16384, 4096]); cat_11 = None + view_205 = torch.ops.aten.view.default(mm_18, [2, 8192, 1792]); mm_18 = None + convert_element_type_92 = torch.ops.prims.convert_element_type.default(view_205, torch.float32); view_205 = None + sigmoid_2 = torch.ops.aten.sigmoid.default(convert_element_type_92) + mul_22 = 
torch.ops.aten.mul.Tensor(convert_element_type_92, sigmoid_2); sigmoid_2 = None + convert_element_type_93 = torch.ops.prims.convert_element_type.default(mul_22, torch.bfloat16); mul_22 = None + convert_element_type_94 = torch.ops.prims.convert_element_type.default(primals_29, torch.bfloat16); primals_29 = None + all_gather_into_tensor_32 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_94, 32, '0'); convert_element_type_94 = None + wait_tensor_38 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_32); all_gather_into_tensor_32 = None + permute_31 = torch.ops.aten.permute.default(wait_tensor_38, [1, 0]); wait_tensor_38 = None + mm_19 = torch.ops.aten.mm.default(view_204, permute_31) + view_212 = torch.ops.aten.view.default(mm_19, [2, 8192, 1792]); mm_19 = None + mul_23 = torch.ops.aten.mul.Tensor(convert_element_type_93, view_212) + view_219 = torch.ops.aten.view.default(mul_23, [16384, 1792]); mul_23 = None + mm_633 = torch.ops.aten.mm.default(permute_1285, view_219); permute_1285 = view_219 = None + convert_element_type_97 = torch.ops.prims.convert_element_type.default(primals_30, torch.bfloat16); primals_30 = None + all_gather_into_tensor_33 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_97, 32, '0'); convert_element_type_97 = None + wait_tensor_39 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_33); all_gather_into_tensor_33 = None + permute_32 = torch.ops.aten.permute.default(wait_tensor_39, [1, 0]); wait_tensor_39 = None + permute_1287 = torch.ops.aten.permute.default(permute_32, [1, 0]); permute_32 = None + mm_634 = torch.ops.aten.mm.default(view_3019, permute_1287); view_3019 = permute_1287 = None + view_3020 = torch.ops.aten.view.default(mm_634, [2, 8192, 1792]); mm_634 = None + convert_element_type_2644 = torch.ops.prims.convert_element_type.default(mm_633, torch.float32); mm_633 = None + reduce_scatter_tensor_387 = 
torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2644, 'avg', 32, '0'); convert_element_type_2644 = None + wait_tensor_861 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_387); reduce_scatter_tensor_387 = None + mul_844 = torch.ops.aten.mul.Tensor(view_3020, convert_element_type_93); convert_element_type_93 = None + mul_845 = torch.ops.aten.mul.Tensor(view_3020, view_212); view_3020 = view_212 = None + view_3021 = torch.ops.aten.view.default(mul_844, [16384, 1792]); mul_844 = None + permute_1289 = torch.ops.aten.permute.default(view_3021, [1, 0]) + mm_635 = torch.ops.aten.mm.default(permute_1289, view_204); permute_1289 = None + permute_1291 = torch.ops.aten.permute.default(permute_31, [1, 0]); permute_31 = None + mm_636 = torch.ops.aten.mm.default(view_3021, permute_1291); view_3021 = permute_1291 = None + view_3022 = torch.ops.aten.view.default(mm_636, [2, 8192, 4096]); mm_636 = None + convert_element_type_2649 = torch.ops.prims.convert_element_type.default(mm_635, torch.float32); mm_635 = None + reduce_scatter_tensor_388 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2649, 'avg', 32, '0'); convert_element_type_2649 = None + wait_tensor_862 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_388); reduce_scatter_tensor_388 = None + convert_element_type_2650 = torch.ops.prims.convert_element_type.default(mul_845, torch.float32); mul_845 = None + neg_29 = torch.ops.aten.neg.default(convert_element_type_92) + exp_29 = torch.ops.aten.exp.default(neg_29); neg_29 = None + add_332 = torch.ops.aten.add.Tensor(exp_29, 1); exp_29 = None + reciprocal_29 = torch.ops.aten.reciprocal.default(add_332); add_332 = None + mul_846 = torch.ops.aten.mul.Tensor(reciprocal_29, 1); reciprocal_29 = None + mul_847 = torch.ops.aten.mul.Tensor(convert_element_type_2650, mul_846); convert_element_type_2650 = None + sub_89 = torch.ops.aten.sub.Tensor(1, mul_846); mul_846 = None + mul_848 = 
torch.ops.aten.mul.Tensor(convert_element_type_92, sub_89); convert_element_type_92 = sub_89 = None + add_333 = torch.ops.aten.add.Tensor(mul_848, 1); mul_848 = None + mul_849 = torch.ops.aten.mul.Tensor(mul_847, add_333); mul_847 = add_333 = None + convert_element_type_2652 = torch.ops.prims.convert_element_type.default(mul_849, torch.bfloat16); mul_849 = None + view_3023 = torch.ops.aten.view.default(convert_element_type_2652, [16384, 1792]); convert_element_type_2652 = None + permute_1293 = torch.ops.aten.permute.default(view_3023, [1, 0]) + mm_637 = torch.ops.aten.mm.default(permute_1293, view_204); permute_1293 = view_204 = None + convert_element_type_89 = torch.ops.prims.convert_element_type.default(primals_28, torch.bfloat16); primals_28 = None + all_gather_into_tensor_31 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_89, 32, '0'); convert_element_type_89 = None + wait_tensor_37 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_31); all_gather_into_tensor_31 = None + permute_30 = torch.ops.aten.permute.default(wait_tensor_37, [1, 0]); wait_tensor_37 = None + permute_1295 = torch.ops.aten.permute.default(permute_30, [1, 0]); permute_30 = None + mm_638 = torch.ops.aten.mm.default(view_3023, permute_1295); view_3023 = permute_1295 = None + view_3024 = torch.ops.aten.view.default(mm_638, [2, 8192, 4096]); mm_638 = None + add_334 = torch.ops.aten.add.Tensor(view_3022, view_3024); view_3022 = view_3024 = None + convert_element_type_2657 = torch.ops.prims.convert_element_type.default(mm_637, torch.float32); mm_637 = None + reduce_scatter_tensor_389 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2657, 'avg', 32, '0'); convert_element_type_2657 = None + wait_tensor_863 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_389); reduce_scatter_tensor_389 = None + split_256 = torch.ops.aten.split.Tensor(add_334, 1024, 1); add_334 = None + getitem_2423 = 
split_256[0] + getitem_2424 = split_256[1] + getitem_2425 = split_256[2] + getitem_2426 = split_256[3] + getitem_2427 = split_256[4] + getitem_2428 = split_256[5] + getitem_2429 = split_256[6] + getitem_2430 = split_256[7]; split_256 = None + cat_248 = torch.ops.aten.cat.default([getitem_2423, getitem_2424, getitem_2425, getitem_2426, getitem_2427, getitem_2428, getitem_2429, getitem_2430]); getitem_2423 = getitem_2424 = getitem_2425 = getitem_2426 = getitem_2427 = getitem_2428 = getitem_2429 = getitem_2430 = None + reduce_scatter_tensor_390 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_248, 'sum', 8, '1'); cat_248 = None + wait_tensor_864 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_390); reduce_scatter_tensor_390 = None + convert_element_type_2658 = torch.ops.prims.convert_element_type.default(wait_tensor_864, torch.float32); wait_tensor_864 = None + convert_element_type_2660 = torch.ops.prims.convert_element_type.default(wait_tensor_35, torch.float32); wait_tensor_35 = None + mul_850 = torch.ops.aten.mul.Tensor(convert_element_type_2658, convert_element_type_2660); convert_element_type_2660 = None + mul_852 = torch.ops.aten.mul.Tensor(mul_20, mul_850) + sum_177 = torch.ops.aten.sum.dim_IntList(mul_852, [2], True); mul_852 = None + div_59 = torch.ops.aten.div.Tensor(mul_20, 4096) + mul_853 = torch.ops.aten.mul.Tensor(div_59, sum_177); div_59 = sum_177 = None + sub_90 = torch.ops.aten.sub.Tensor(mul_850, mul_853); mul_850 = mul_853 = None + mul_854 = torch.ops.aten.mul.Tensor(sub_90, rsqrt_5); sub_90 = rsqrt_5 = None + mul_855 = torch.ops.aten.mul.Tensor(convert_element_type_2658, mul_20); convert_element_type_2658 = mul_20 = None + sum_178 = torch.ops.aten.sum.dim_IntList(mul_855, [0, 1]); mul_855 = None + convert_element_type_2661 = torch.ops.prims.convert_element_type.default(mul_854, torch.bfloat16); mul_854 = None + convert_element_type_2662 = torch.ops.prims.convert_element_type.default(sum_178, torch.bfloat16); 
sum_178 = None + all_reduce_59 = torch.ops._c10d_functional.all_reduce.default(convert_element_type_2662, 'sum', '1'); convert_element_type_2662 = None + wait_tensor_865 = torch.ops._c10d_functional.wait_tensor.default(all_reduce_59); all_reduce_59 = None + convert_element_type_2663 = torch.ops.prims.convert_element_type.default(wait_tensor_865, torch.float32); wait_tensor_865 = None + reduce_scatter_tensor_391 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2663, 'avg', 32, '0'); convert_element_type_2663 = None + wait_tensor_866 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_391); reduce_scatter_tensor_391 = None + add_335 = torch.ops.aten.add.Tensor(add_331, convert_element_type_2661); add_331 = convert_element_type_2661 = None + all_gather_into_tensor_415 = torch.ops._c10d_functional.all_gather_into_tensor.default(add_335, 8, '1') + wait_tensor_867 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_415); all_gather_into_tensor_415 = None + split_257 = torch.ops.aten.split.Tensor(wait_tensor_867, 2); wait_tensor_867 = None + getitem_2431 = split_257[0] + getitem_2432 = split_257[1] + getitem_2433 = split_257[2] + getitem_2434 = split_257[3] + getitem_2435 = split_257[4] + getitem_2436 = split_257[5] + getitem_2437 = split_257[6] + getitem_2438 = split_257[7]; split_257 = None + cat_249 = torch.ops.aten.cat.default([getitem_2431, getitem_2432, getitem_2433, getitem_2434, getitem_2435, getitem_2436, getitem_2437, getitem_2438], 1); getitem_2431 = getitem_2432 = getitem_2433 = getitem_2434 = getitem_2435 = getitem_2436 = getitem_2437 = getitem_2438 = None + view_3025 = torch.ops.aten.view.default(cat_249, [16384, 4096]); cat_249 = None + permute_1297 = torch.ops.aten.permute.default(view_3025, [1, 0]) + permute_28 = torch.ops.aten.permute.default(getitem_162, [0, 2, 1, 3]) + view_186 = torch.ops.aten.view.default(permute_28, [2, 8192, -1]); permute_28 = None + view_192 = 
torch.ops.aten.view.default(view_186, [16384, 512]); view_186 = None + mm_639 = torch.ops.aten.mm.default(permute_1297, view_192); permute_1297 = view_192 = None + convert_element_type_83 = torch.ops.prims.convert_element_type.default(primals_26, torch.bfloat16); primals_26 = None + all_gather_into_tensor_28 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_83, 32, '0'); convert_element_type_83 = None + wait_tensor_33 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_28); all_gather_into_tensor_28 = None + permute_29 = torch.ops.aten.permute.default(wait_tensor_33, [1, 0]); wait_tensor_33 = None + permute_1299 = torch.ops.aten.permute.default(permute_29, [1, 0]); permute_29 = None + mm_640 = torch.ops.aten.mm.default(view_3025, permute_1299); view_3025 = permute_1299 = None + view_3026 = torch.ops.aten.view.default(mm_640, [2, 8192, 512]); mm_640 = None + convert_element_type_2668 = torch.ops.prims.convert_element_type.default(mm_639, torch.float32); mm_639 = None + reduce_scatter_tensor_392 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2668, 'avg', 32, '0'); convert_element_type_2668 = None + wait_tensor_868 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_392); reduce_scatter_tensor_392 = None + view_3027 = torch.ops.aten.view.default(view_3026, [2, 8192, 4, 128]); view_3026 = None + permute_1301 = torch.ops.aten.permute.default(view_3027, [0, 2, 1, 3]); view_3027 = None + convert_element_type_67 = torch.ops.prims.convert_element_type.default(primals_22, torch.bfloat16); primals_22 = None + all_gather_into_tensor_23 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_67, 32, '0'); convert_element_type_67 = None + wait_tensor_28 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_23); all_gather_into_tensor_23 = None + convert_element_type_68 = torch.ops.prims.convert_element_type.default(add_7, 
torch.float32); add_7 = None + pow_5 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_68, 2) + mean_4 = torch.ops.aten.mean.dim(pow_5, [2], True); pow_5 = None + add_8 = torch.ops.aten.add.Scalar(mean_4, 1e-05); mean_4 = None + rsqrt_4 = torch.ops.aten.rsqrt.default(add_8); add_8 = None + mul_16 = torch.ops.aten.mul.Tensor(convert_element_type_68, rsqrt_4); convert_element_type_68 = None + mul_17 = torch.ops.aten.mul.Tensor(mul_16, wait_tensor_28) + convert_element_type_69 = torch.ops.prims.convert_element_type.default(mul_17, torch.bfloat16); mul_17 = None + all_gather_into_tensor_24 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_69, 8, '1'); convert_element_type_69 = None + wait_tensor_29 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_24); all_gather_into_tensor_24 = None + split_17 = torch.ops.aten.split.Tensor(wait_tensor_29, 2); wait_tensor_29 = None + getitem_154 = split_17[0] + getitem_155 = split_17[1] + getitem_156 = split_17[2] + getitem_157 = split_17[3] + getitem_158 = split_17[4] + getitem_159 = split_17[5] + getitem_160 = split_17[6] + getitem_161 = split_17[7]; split_17 = None + cat_9 = torch.ops.aten.cat.default([getitem_154, getitem_155, getitem_156, getitem_157, getitem_158, getitem_159, getitem_160, getitem_161], 1); getitem_154 = getitem_155 = getitem_156 = getitem_157 = getitem_158 = getitem_159 = getitem_160 = getitem_161 = None + view_159 = torch.ops.aten.view.default(cat_9, [16384, 4096]); cat_9 = None + view_160 = torch.ops.aten.view.default(mm_14, [2, 8192, 512]); mm_14 = None + convert_element_type_73 = torch.ops.prims.convert_element_type.default(primals_24, torch.bfloat16); primals_24 = None + all_gather_into_tensor_26 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_73, 32, '0'); convert_element_type_73 = None + wait_tensor_31 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_26); all_gather_into_tensor_26 = None + 
permute_23 = torch.ops.aten.permute.default(wait_tensor_31, [1, 0]); wait_tensor_31 = None + mm_15 = torch.ops.aten.mm.default(view_159, permute_23) + view_167 = torch.ops.aten.view.default(mm_15, [2, 8192, 128]); mm_15 = None + view_174 = torch.ops.aten.view.default(mm_16, [2, 8192, 128]); mm_16 = None + view_176 = torch.ops.aten.view.default(view_160, [2, 8192, -1, 128]); view_160 = None + view_177 = torch.ops.aten.view.default(view_167, [2, 8192, -1, 128]); view_167 = None + view_178 = torch.ops.aten.view.default(view_174, [2, 8192, -1, 128]); view_174 = None + convert_element_type_79 = torch.ops.prims.convert_element_type.default(view_176, torch.float32); view_176 = None + view_179 = torch.ops.aten.view.default(convert_element_type_79, [2, 8192, 4, -1, 2]); convert_element_type_79 = None + view_as_complex_4 = torch.ops.aten.view_as_complex.default(view_179); view_179 = None + convert_element_type_80 = torch.ops.prims.convert_element_type.default(view_177, torch.float32); view_177 = None + view_180 = torch.ops.aten.view.default(convert_element_type_80, [2, 8192, 1, -1, 2]); convert_element_type_80 = None + view_as_complex_5 = torch.ops.aten.view_as_complex.default(view_180); view_180 = None + mul_18 = torch.ops.aten.mul.Tensor(view_as_complex_4, view_37); view_as_complex_4 = None + view_as_real_4 = torch.ops.aten.view_as_real.default(mul_18); mul_18 = None + view_182 = torch.ops.aten.view.default(view_as_real_4, [2, 8192, 4, 128]); view_as_real_4 = None + mul_19 = torch.ops.aten.mul.Tensor(view_as_complex_5, view_37); view_as_complex_5 = None + view_as_real_5 = torch.ops.aten.view_as_real.default(mul_19); mul_19 = None + view_183 = torch.ops.aten.view.default(view_as_real_5, [2, 8192, 1, 128]); view_as_real_5 = None + convert_element_type_81 = torch.ops.prims.convert_element_type.default(view_182, torch.bfloat16); view_182 = None + convert_element_type_82 = torch.ops.prims.convert_element_type.default(view_183, torch.bfloat16); view_183 = None + unsqueeze_4 = 
torch.ops.aten.unsqueeze.default(convert_element_type_82, 3); convert_element_type_82 = None + expand_4 = torch.ops.aten.expand.default(unsqueeze_4, [2, 8192, 1, 4, 128]); unsqueeze_4 = None + view_184 = torch.ops.aten.view.default(expand_4, [2, 8192, 4, 128]); expand_4 = None + unsqueeze_5 = torch.ops.aten.unsqueeze.default(view_178, 3); view_178 = None + expand_5 = torch.ops.aten.expand.default(unsqueeze_5, [2, 8192, 1, 4, 128]); unsqueeze_5 = None + view_185 = torch.ops.aten.view.default(expand_5, [2, 8192, 4, 128]); expand_5 = None + permute_25 = torch.ops.aten.permute.default(convert_element_type_81, [0, 2, 1, 3]); convert_element_type_81 = None + permute_26 = torch.ops.aten.permute.default(view_184, [0, 2, 1, 3]); view_184 = None + permute_27 = torch.ops.aten.permute.default(view_185, [0, 2, 1, 3]); view_185 = None + _scaled_dot_product_cudnn_attention_backward_29 = torch.ops.aten._scaled_dot_product_cudnn_attention_backward.default(permute_1301, permute_25, permute_26, permute_27, getitem_162, getitem_163, getitem_168, getitem_169, None, None, None, 8192, 8192, 0.0, True); permute_1301 = permute_25 = permute_26 = permute_27 = getitem_162 = getitem_163 = getitem_168 = getitem_169 = None + getitem_2439 = _scaled_dot_product_cudnn_attention_backward_29[0] + getitem_2440 = _scaled_dot_product_cudnn_attention_backward_29[1] + getitem_2441 = _scaled_dot_product_cudnn_attention_backward_29[2]; _scaled_dot_product_cudnn_attention_backward_29 = None + permute_1302 = torch.ops.aten.permute.default(getitem_2441, [0, 2, 1, 3]); getitem_2441 = None + permute_1303 = torch.ops.aten.permute.default(getitem_2440, [0, 2, 1, 3]); getitem_2440 = None + permute_1304 = torch.ops.aten.permute.default(getitem_2439, [0, 2, 1, 3]); getitem_2439 = None + view_3028 = torch.ops.aten.view.default(permute_1302, [2, 8192, 1, 4, 128]); permute_1302 = None + sum_179 = torch.ops.aten.sum.dim_IntList(view_3028, [3], True); view_3028 = None + squeeze_58 = torch.ops.aten.squeeze.dim(sum_179, 3); 
sum_179 = None + view_3029 = torch.ops.aten.view.default(permute_1303, [2, 8192, 1, 4, 128]); permute_1303 = None + sum_180 = torch.ops.aten.sum.dim_IntList(view_3029, [3], True); view_3029 = None + squeeze_59 = torch.ops.aten.squeeze.dim(sum_180, 3); sum_180 = None + convert_element_type_2669 = torch.ops.prims.convert_element_type.default(squeeze_59, torch.float32); squeeze_59 = None + convert_element_type_2670 = torch.ops.prims.convert_element_type.default(permute_1304, torch.float32); permute_1304 = None + view_3030 = torch.ops.aten.view.default(convert_element_type_2669, [2, 8192, 1, 64, 2]); convert_element_type_2669 = None + view_as_complex_122 = torch.ops.aten.view_as_complex.default(view_3030); view_3030 = None + mul_856 = torch.ops.aten.mul.Tensor(view_as_complex_122, _conj); view_as_complex_122 = None + view_3031 = torch.ops.aten.view.default(convert_element_type_2670, [2, 8192, 4, 64, 2]); convert_element_type_2670 = None + view_as_complex_123 = torch.ops.aten.view_as_complex.default(view_3031); view_3031 = None + mul_857 = torch.ops.aten.mul.Tensor(view_as_complex_123, _conj); view_as_complex_123 = None + view_as_real_122 = torch.ops.aten.view_as_real.default(mul_856); mul_856 = None + view_3032 = torch.ops.aten.view.default(view_as_real_122, [2, 8192, 1, 128]); view_as_real_122 = None + convert_element_type_2671 = torch.ops.prims.convert_element_type.default(view_3032, torch.bfloat16); view_3032 = None + view_as_real_123 = torch.ops.aten.view_as_real.default(mul_857); mul_857 = None + view_3033 = torch.ops.aten.view.default(view_as_real_123, [2, 8192, 4, 128]); view_as_real_123 = None + convert_element_type_2672 = torch.ops.prims.convert_element_type.default(view_3033, torch.bfloat16); view_3033 = None + view_3034 = torch.ops.aten.view.default(squeeze_58, [2, 8192, 128]); squeeze_58 = None + view_3035 = torch.ops.aten.view.default(convert_element_type_2671, [2, 8192, 128]); convert_element_type_2671 = None + view_3036 = 
torch.ops.aten.view.default(convert_element_type_2672, [2, 8192, 512]); convert_element_type_2672 = None + view_3037 = torch.ops.aten.view.default(view_3034, [16384, 128]); view_3034 = None + permute_1305 = torch.ops.aten.permute.default(view_3037, [1, 0]) + mm_641 = torch.ops.aten.mm.default(permute_1305, view_159); permute_1305 = None + convert_element_type_76 = torch.ops.prims.convert_element_type.default(primals_25, torch.bfloat16); primals_25 = None + all_gather_into_tensor_27 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_76, 32, '0'); convert_element_type_76 = None + wait_tensor_32 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_27); all_gather_into_tensor_27 = None + permute_24 = torch.ops.aten.permute.default(wait_tensor_32, [1, 0]); wait_tensor_32 = None + permute_1307 = torch.ops.aten.permute.default(permute_24, [1, 0]); permute_24 = None + mm_642 = torch.ops.aten.mm.default(view_3037, permute_1307); view_3037 = permute_1307 = None + view_3038 = torch.ops.aten.view.default(mm_642, [2, 8192, 4096]); mm_642 = None + convert_element_type_2677 = torch.ops.prims.convert_element_type.default(mm_641, torch.float32); mm_641 = None + reduce_scatter_tensor_393 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2677, 'avg', 32, '0'); convert_element_type_2677 = None + wait_tensor_869 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_393); reduce_scatter_tensor_393 = None + view_3039 = torch.ops.aten.view.default(view_3035, [16384, 128]); view_3035 = None + permute_1309 = torch.ops.aten.permute.default(view_3039, [1, 0]) + mm_643 = torch.ops.aten.mm.default(permute_1309, view_159); permute_1309 = None + permute_1311 = torch.ops.aten.permute.default(permute_23, [1, 0]); permute_23 = None + mm_644 = torch.ops.aten.mm.default(view_3039, permute_1311); view_3039 = permute_1311 = None + view_3040 = torch.ops.aten.view.default(mm_644, [2, 8192, 4096]); mm_644 = None 
+ add_336 = torch.ops.aten.add.Tensor(view_3038, view_3040); view_3038 = view_3040 = None + convert_element_type_2682 = torch.ops.prims.convert_element_type.default(mm_643, torch.float32); mm_643 = None + reduce_scatter_tensor_394 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2682, 'avg', 32, '0'); convert_element_type_2682 = None + wait_tensor_870 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_394); reduce_scatter_tensor_394 = None + view_3041 = torch.ops.aten.view.default(view_3036, [16384, 512]); view_3036 = None + permute_1313 = torch.ops.aten.permute.default(view_3041, [1, 0]) + mm_645 = torch.ops.aten.mm.default(permute_1313, view_159); permute_1313 = view_159 = None + convert_element_type_70 = torch.ops.prims.convert_element_type.default(primals_23, torch.bfloat16); primals_23 = None + all_gather_into_tensor_25 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_70, 32, '0'); convert_element_type_70 = None + wait_tensor_30 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_25); all_gather_into_tensor_25 = None + permute_22 = torch.ops.aten.permute.default(wait_tensor_30, [1, 0]); wait_tensor_30 = None + permute_1315 = torch.ops.aten.permute.default(permute_22, [1, 0]); permute_22 = None + mm_646 = torch.ops.aten.mm.default(view_3041, permute_1315); view_3041 = permute_1315 = None + view_3042 = torch.ops.aten.view.default(mm_646, [2, 8192, 4096]); mm_646 = None + add_337 = torch.ops.aten.add.Tensor(add_336, view_3042); add_336 = view_3042 = None + convert_element_type_2687 = torch.ops.prims.convert_element_type.default(mm_645, torch.float32); mm_645 = None + reduce_scatter_tensor_395 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2687, 'avg', 32, '0'); convert_element_type_2687 = None + wait_tensor_871 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_395); reduce_scatter_tensor_395 = None + 
split_258 = torch.ops.aten.split.Tensor(add_337, 1024, 1); add_337 = None + getitem_2442 = split_258[0] + getitem_2443 = split_258[1] + getitem_2444 = split_258[2] + getitem_2445 = split_258[3] + getitem_2446 = split_258[4] + getitem_2447 = split_258[5] + getitem_2448 = split_258[6] + getitem_2449 = split_258[7]; split_258 = None + cat_250 = torch.ops.aten.cat.default([getitem_2442, getitem_2443, getitem_2444, getitem_2445, getitem_2446, getitem_2447, getitem_2448, getitem_2449]); getitem_2442 = getitem_2443 = getitem_2444 = getitem_2445 = getitem_2446 = getitem_2447 = getitem_2448 = getitem_2449 = None + reduce_scatter_tensor_396 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_250, 'sum', 8, '1'); cat_250 = None + wait_tensor_872 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_396); reduce_scatter_tensor_396 = None + convert_element_type_2688 = torch.ops.prims.convert_element_type.default(wait_tensor_872, torch.float32); wait_tensor_872 = None + convert_element_type_2690 = torch.ops.prims.convert_element_type.default(wait_tensor_28, torch.float32); wait_tensor_28 = None + mul_858 = torch.ops.aten.mul.Tensor(convert_element_type_2688, convert_element_type_2690); convert_element_type_2690 = None + mul_860 = torch.ops.aten.mul.Tensor(mul_16, mul_858) + sum_181 = torch.ops.aten.sum.dim_IntList(mul_860, [2], True); mul_860 = None + div_60 = torch.ops.aten.div.Tensor(mul_16, 4096) + mul_861 = torch.ops.aten.mul.Tensor(div_60, sum_181); div_60 = sum_181 = None + sub_91 = torch.ops.aten.sub.Tensor(mul_858, mul_861); mul_858 = mul_861 = None + mul_862 = torch.ops.aten.mul.Tensor(sub_91, rsqrt_4); sub_91 = rsqrt_4 = None + mul_863 = torch.ops.aten.mul.Tensor(convert_element_type_2688, mul_16); convert_element_type_2688 = mul_16 = None + sum_182 = torch.ops.aten.sum.dim_IntList(mul_863, [0, 1]); mul_863 = None + convert_element_type_2691 = torch.ops.prims.convert_element_type.default(mul_862, torch.bfloat16); mul_862 = None + 
convert_element_type_2692 = torch.ops.prims.convert_element_type.default(sum_182, torch.bfloat16); sum_182 = None + all_reduce_60 = torch.ops._c10d_functional.all_reduce.default(convert_element_type_2692, 'sum', '1'); convert_element_type_2692 = None + wait_tensor_873 = torch.ops._c10d_functional.wait_tensor.default(all_reduce_60); all_reduce_60 = None + convert_element_type_2693 = torch.ops.prims.convert_element_type.default(wait_tensor_873, torch.float32); wait_tensor_873 = None + reduce_scatter_tensor_397 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2693, 'avg', 32, '0'); convert_element_type_2693 = None + wait_tensor_874 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_397); reduce_scatter_tensor_397 = None + add_338 = torch.ops.aten.add.Tensor(add_335, convert_element_type_2691); add_335 = convert_element_type_2691 = None + all_gather_into_tensor_416 = torch.ops._c10d_functional.all_gather_into_tensor.default(add_338, 8, '1') + wait_tensor_875 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_416); all_gather_into_tensor_416 = None + split_259 = torch.ops.aten.split.Tensor(wait_tensor_875, 2); wait_tensor_875 = None + getitem_2450 = split_259[0] + getitem_2451 = split_259[1] + getitem_2452 = split_259[2] + getitem_2453 = split_259[3] + getitem_2454 = split_259[4] + getitem_2455 = split_259[5] + getitem_2456 = split_259[6] + getitem_2457 = split_259[7]; split_259 = None + cat_251 = torch.ops.aten.cat.default([getitem_2450, getitem_2451, getitem_2452, getitem_2453, getitem_2454, getitem_2455, getitem_2456, getitem_2457], 1); getitem_2450 = getitem_2451 = getitem_2452 = getitem_2453 = getitem_2454 = getitem_2455 = getitem_2456 = getitem_2457 = None + view_3043 = torch.ops.aten.view.default(cat_251, [16384, 4096]); cat_251 = None + permute_1317 = torch.ops.aten.permute.default(view_3043, [1, 0]) + wait_tensor_21 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_3); 
reduce_scatter_tensor_3 = None + add_5 = torch.ops.aten.add.Tensor(add_3, wait_tensor_21); wait_tensor_21 = None + convert_element_type_53 = torch.ops.prims.convert_element_type.default(primals_18, torch.bfloat16); primals_18 = None + all_gather_into_tensor_18 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_53, 32, '0'); convert_element_type_53 = None + wait_tensor_22 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_18); all_gather_into_tensor_18 = None + convert_element_type_54 = torch.ops.prims.convert_element_type.default(add_5, torch.float32); add_5 = None + pow_4 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_54, 2) + mean_3 = torch.ops.aten.mean.dim(pow_4, [2], True); pow_4 = None + add_6 = torch.ops.aten.add.Scalar(mean_3, 1e-05); mean_3 = None + rsqrt_3 = torch.ops.aten.rsqrt.default(add_6); add_6 = None + mul_12 = torch.ops.aten.mul.Tensor(convert_element_type_54, rsqrt_3); convert_element_type_54 = None + mul_13 = torch.ops.aten.mul.Tensor(mul_12, wait_tensor_22) + convert_element_type_55 = torch.ops.prims.convert_element_type.default(mul_13, torch.bfloat16); mul_13 = None + all_gather_into_tensor_19 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_55, 8, '1'); convert_element_type_55 = None + wait_tensor_23 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_19); all_gather_into_tensor_19 = None + split_15 = torch.ops.aten.split.Tensor(wait_tensor_23, 2); wait_tensor_23 = None + getitem_138 = split_15[0] + getitem_139 = split_15[1] + getitem_140 = split_15[2] + getitem_141 = split_15[3] + getitem_142 = split_15[4] + getitem_143 = split_15[5] + getitem_144 = split_15[6] + getitem_145 = split_15[7]; split_15 = None + cat_7 = torch.ops.aten.cat.default([getitem_138, getitem_139, getitem_140, getitem_141, getitem_142, getitem_143, getitem_144, getitem_145], 1); getitem_138 = getitem_139 = getitem_140 = getitem_141 = getitem_142 = getitem_143 = 
getitem_144 = getitem_145 = None + view_132 = torch.ops.aten.view.default(cat_7, [16384, 4096]); cat_7 = None + view_133 = torch.ops.aten.view.default(mm_11, [2, 8192, 1792]); mm_11 = None + convert_element_type_59 = torch.ops.prims.convert_element_type.default(view_133, torch.float32); view_133 = None + sigmoid_1 = torch.ops.aten.sigmoid.default(convert_element_type_59) + mul_14 = torch.ops.aten.mul.Tensor(convert_element_type_59, sigmoid_1); sigmoid_1 = None + convert_element_type_60 = torch.ops.prims.convert_element_type.default(mul_14, torch.bfloat16); mul_14 = None + convert_element_type_61 = torch.ops.prims.convert_element_type.default(primals_20, torch.bfloat16); primals_20 = None + all_gather_into_tensor_21 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_61, 32, '0'); convert_element_type_61 = None + wait_tensor_25 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_21); all_gather_into_tensor_21 = None + permute_20 = torch.ops.aten.permute.default(wait_tensor_25, [1, 0]); wait_tensor_25 = None + mm_12 = torch.ops.aten.mm.default(view_132, permute_20) + view_140 = torch.ops.aten.view.default(mm_12, [2, 8192, 1792]); mm_12 = None + mul_15 = torch.ops.aten.mul.Tensor(convert_element_type_60, view_140) + view_147 = torch.ops.aten.view.default(mul_15, [16384, 1792]); mul_15 = None + mm_647 = torch.ops.aten.mm.default(permute_1317, view_147); permute_1317 = view_147 = None + convert_element_type_64 = torch.ops.prims.convert_element_type.default(primals_21, torch.bfloat16); primals_21 = None + all_gather_into_tensor_22 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_64, 32, '0'); convert_element_type_64 = None + wait_tensor_26 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_22); all_gather_into_tensor_22 = None + permute_21 = torch.ops.aten.permute.default(wait_tensor_26, [1, 0]); wait_tensor_26 = None + permute_1319 = 
torch.ops.aten.permute.default(permute_21, [1, 0]); permute_21 = None + mm_648 = torch.ops.aten.mm.default(view_3043, permute_1319); view_3043 = permute_1319 = None + view_3044 = torch.ops.aten.view.default(mm_648, [2, 8192, 1792]); mm_648 = None + convert_element_type_2698 = torch.ops.prims.convert_element_type.default(mm_647, torch.float32); mm_647 = None + reduce_scatter_tensor_398 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2698, 'avg', 32, '0'); convert_element_type_2698 = None + wait_tensor_876 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_398); reduce_scatter_tensor_398 = None + mul_864 = torch.ops.aten.mul.Tensor(view_3044, convert_element_type_60); convert_element_type_60 = None + mul_865 = torch.ops.aten.mul.Tensor(view_3044, view_140); view_3044 = view_140 = None + view_3045 = torch.ops.aten.view.default(mul_864, [16384, 1792]); mul_864 = None + permute_1321 = torch.ops.aten.permute.default(view_3045, [1, 0]) + mm_649 = torch.ops.aten.mm.default(permute_1321, view_132); permute_1321 = None + permute_1323 = torch.ops.aten.permute.default(permute_20, [1, 0]); permute_20 = None + mm_650 = torch.ops.aten.mm.default(view_3045, permute_1323); view_3045 = permute_1323 = None + view_3046 = torch.ops.aten.view.default(mm_650, [2, 8192, 4096]); mm_650 = None + convert_element_type_2703 = torch.ops.prims.convert_element_type.default(mm_649, torch.float32); mm_649 = None + reduce_scatter_tensor_399 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2703, 'avg', 32, '0'); convert_element_type_2703 = None + wait_tensor_877 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_399); reduce_scatter_tensor_399 = None + convert_element_type_2704 = torch.ops.prims.convert_element_type.default(mul_865, torch.float32); mul_865 = None + neg_30 = torch.ops.aten.neg.default(convert_element_type_59) + exp_30 = torch.ops.aten.exp.default(neg_30); neg_30 = None + add_339 = 
torch.ops.aten.add.Tensor(exp_30, 1); exp_30 = None + reciprocal_30 = torch.ops.aten.reciprocal.default(add_339); add_339 = None + mul_866 = torch.ops.aten.mul.Tensor(reciprocal_30, 1); reciprocal_30 = None + mul_867 = torch.ops.aten.mul.Tensor(convert_element_type_2704, mul_866); convert_element_type_2704 = None + sub_92 = torch.ops.aten.sub.Tensor(1, mul_866); mul_866 = None + mul_868 = torch.ops.aten.mul.Tensor(convert_element_type_59, sub_92); convert_element_type_59 = sub_92 = None + add_340 = torch.ops.aten.add.Tensor(mul_868, 1); mul_868 = None + mul_869 = torch.ops.aten.mul.Tensor(mul_867, add_340); mul_867 = add_340 = None + convert_element_type_2706 = torch.ops.prims.convert_element_type.default(mul_869, torch.bfloat16); mul_869 = None + view_3047 = torch.ops.aten.view.default(convert_element_type_2706, [16384, 1792]); convert_element_type_2706 = None + permute_1325 = torch.ops.aten.permute.default(view_3047, [1, 0]) + mm_651 = torch.ops.aten.mm.default(permute_1325, view_132); permute_1325 = view_132 = None + convert_element_type_56 = torch.ops.prims.convert_element_type.default(primals_19, torch.bfloat16); primals_19 = None + all_gather_into_tensor_20 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_56, 32, '0'); convert_element_type_56 = None + wait_tensor_24 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_20); all_gather_into_tensor_20 = None + permute_19 = torch.ops.aten.permute.default(wait_tensor_24, [1, 0]); wait_tensor_24 = None + permute_1327 = torch.ops.aten.permute.default(permute_19, [1, 0]); permute_19 = None + mm_652 = torch.ops.aten.mm.default(view_3047, permute_1327); view_3047 = permute_1327 = None + view_3048 = torch.ops.aten.view.default(mm_652, [2, 8192, 4096]); mm_652 = None + add_341 = torch.ops.aten.add.Tensor(view_3046, view_3048); view_3046 = view_3048 = None + convert_element_type_2711 = torch.ops.prims.convert_element_type.default(mm_651, torch.float32); mm_651 = None + 
reduce_scatter_tensor_400 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2711, 'avg', 32, '0'); convert_element_type_2711 = None + wait_tensor_878 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_400); reduce_scatter_tensor_400 = None + split_260 = torch.ops.aten.split.Tensor(add_341, 1024, 1); add_341 = None + getitem_2458 = split_260[0] + getitem_2459 = split_260[1] + getitem_2460 = split_260[2] + getitem_2461 = split_260[3] + getitem_2462 = split_260[4] + getitem_2463 = split_260[5] + getitem_2464 = split_260[6] + getitem_2465 = split_260[7]; split_260 = None + cat_252 = torch.ops.aten.cat.default([getitem_2458, getitem_2459, getitem_2460, getitem_2461, getitem_2462, getitem_2463, getitem_2464, getitem_2465]); getitem_2458 = getitem_2459 = getitem_2460 = getitem_2461 = getitem_2462 = getitem_2463 = getitem_2464 = getitem_2465 = None + reduce_scatter_tensor_401 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_252, 'sum', 8, '1'); cat_252 = None + wait_tensor_879 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_401); reduce_scatter_tensor_401 = None + convert_element_type_2712 = torch.ops.prims.convert_element_type.default(wait_tensor_879, torch.float32); wait_tensor_879 = None + convert_element_type_2714 = torch.ops.prims.convert_element_type.default(wait_tensor_22, torch.float32); wait_tensor_22 = None + mul_870 = torch.ops.aten.mul.Tensor(convert_element_type_2712, convert_element_type_2714); convert_element_type_2714 = None + mul_872 = torch.ops.aten.mul.Tensor(mul_12, mul_870) + sum_183 = torch.ops.aten.sum.dim_IntList(mul_872, [2], True); mul_872 = None + div_61 = torch.ops.aten.div.Tensor(mul_12, 4096) + mul_873 = torch.ops.aten.mul.Tensor(div_61, sum_183); div_61 = sum_183 = None + sub_93 = torch.ops.aten.sub.Tensor(mul_870, mul_873); mul_870 = mul_873 = None + mul_874 = torch.ops.aten.mul.Tensor(sub_93, rsqrt_3); sub_93 = rsqrt_3 = None + mul_875 = 
torch.ops.aten.mul.Tensor(convert_element_type_2712, mul_12); convert_element_type_2712 = mul_12 = None + sum_184 = torch.ops.aten.sum.dim_IntList(mul_875, [0, 1]); mul_875 = None + convert_element_type_2715 = torch.ops.prims.convert_element_type.default(mul_874, torch.bfloat16); mul_874 = None + convert_element_type_2716 = torch.ops.prims.convert_element_type.default(sum_184, torch.bfloat16); sum_184 = None + all_reduce_61 = torch.ops._c10d_functional.all_reduce.default(convert_element_type_2716, 'sum', '1'); convert_element_type_2716 = None + wait_tensor_880 = torch.ops._c10d_functional.wait_tensor.default(all_reduce_61); all_reduce_61 = None + convert_element_type_2717 = torch.ops.prims.convert_element_type.default(wait_tensor_880, torch.float32); wait_tensor_880 = None + reduce_scatter_tensor_402 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2717, 'avg', 32, '0'); convert_element_type_2717 = None + wait_tensor_881 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_402); reduce_scatter_tensor_402 = None + add_342 = torch.ops.aten.add.Tensor(add_338, convert_element_type_2715); add_338 = convert_element_type_2715 = None + all_gather_into_tensor_417 = torch.ops._c10d_functional.all_gather_into_tensor.default(add_342, 8, '1') + wait_tensor_882 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_417); all_gather_into_tensor_417 = None + split_261 = torch.ops.aten.split.Tensor(wait_tensor_882, 2); wait_tensor_882 = None + getitem_2466 = split_261[0] + getitem_2467 = split_261[1] + getitem_2468 = split_261[2] + getitem_2469 = split_261[3] + getitem_2470 = split_261[4] + getitem_2471 = split_261[5] + getitem_2472 = split_261[6] + getitem_2473 = split_261[7]; split_261 = None + cat_253 = torch.ops.aten.cat.default([getitem_2466, getitem_2467, getitem_2468, getitem_2469, getitem_2470, getitem_2471, getitem_2472, getitem_2473], 1); getitem_2466 = getitem_2467 = getitem_2468 = getitem_2469 = 
getitem_2470 = getitem_2471 = getitem_2472 = getitem_2473 = None + view_3049 = torch.ops.aten.view.default(cat_253, [16384, 4096]); cat_253 = None + permute_1329 = torch.ops.aten.permute.default(view_3049, [1, 0]) + permute_17 = torch.ops.aten.permute.default(getitem_121, [0, 2, 1, 3]) + view_114 = torch.ops.aten.view.default(permute_17, [2, 8192, -1]); permute_17 = None + view_120 = torch.ops.aten.view.default(view_114, [16384, 512]); view_114 = None + mm_653 = torch.ops.aten.mm.default(permute_1329, view_120); permute_1329 = view_120 = None + convert_element_type_50 = torch.ops.prims.convert_element_type.default(primals_17, torch.bfloat16); primals_17 = None + all_gather_into_tensor_17 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_50, 32, '0'); convert_element_type_50 = None + wait_tensor_20 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_17); all_gather_into_tensor_17 = None + permute_18 = torch.ops.aten.permute.default(wait_tensor_20, [1, 0]); wait_tensor_20 = None + permute_1331 = torch.ops.aten.permute.default(permute_18, [1, 0]); permute_18 = None + mm_654 = torch.ops.aten.mm.default(view_3049, permute_1331); view_3049 = permute_1331 = None + view_3050 = torch.ops.aten.view.default(mm_654, [2, 8192, 512]); mm_654 = None + convert_element_type_2722 = torch.ops.prims.convert_element_type.default(mm_653, torch.float32); mm_653 = None + reduce_scatter_tensor_403 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2722, 'avg', 32, '0'); convert_element_type_2722 = None + wait_tensor_883 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_403); reduce_scatter_tensor_403 = None + view_3051 = torch.ops.aten.view.default(view_3050, [2, 8192, 4, 128]); view_3050 = None + permute_1333 = torch.ops.aten.permute.default(view_3051, [0, 2, 1, 3]); view_3051 = None + convert_element_type_34 = torch.ops.prims.convert_element_type.default(primals_13, torch.bfloat16); 
primals_13 = None + all_gather_into_tensor_12 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_34, 32, '0'); convert_element_type_34 = None + wait_tensor_15 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_12); all_gather_into_tensor_12 = None + convert_element_type_35 = torch.ops.prims.convert_element_type.default(add_3, torch.float32); add_3 = None + pow_3 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_35, 2) + mean_2 = torch.ops.aten.mean.dim(pow_3, [2], True); pow_3 = None + add_4 = torch.ops.aten.add.Scalar(mean_2, 1e-05); mean_2 = None + rsqrt_2 = torch.ops.aten.rsqrt.default(add_4); add_4 = None + mul_8 = torch.ops.aten.mul.Tensor(convert_element_type_35, rsqrt_2); convert_element_type_35 = None + mul_9 = torch.ops.aten.mul.Tensor(mul_8, wait_tensor_15) + convert_element_type_36 = torch.ops.prims.convert_element_type.default(mul_9, torch.bfloat16); mul_9 = None + all_gather_into_tensor_13 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_36, 8, '1'); convert_element_type_36 = None + wait_tensor_16 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_13); all_gather_into_tensor_13 = None + split_13 = torch.ops.aten.split.Tensor(wait_tensor_16, 2); wait_tensor_16 = None + getitem_113 = split_13[0] + getitem_114 = split_13[1] + getitem_115 = split_13[2] + getitem_116 = split_13[3] + getitem_117 = split_13[4] + getitem_118 = split_13[5] + getitem_119 = split_13[6] + getitem_120 = split_13[7]; split_13 = None + cat_5 = torch.ops.aten.cat.default([getitem_113, getitem_114, getitem_115, getitem_116, getitem_117, getitem_118, getitem_119, getitem_120], 1); getitem_113 = getitem_114 = getitem_115 = getitem_116 = getitem_117 = getitem_118 = getitem_119 = getitem_120 = None + view_87 = torch.ops.aten.view.default(cat_5, [16384, 4096]); cat_5 = None + view_88 = torch.ops.aten.view.default(mm_7, [2, 8192, 512]); mm_7 = None + convert_element_type_40 = 
torch.ops.prims.convert_element_type.default(primals_15, torch.bfloat16); primals_15 = None + all_gather_into_tensor_15 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_40, 32, '0'); convert_element_type_40 = None + wait_tensor_18 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_15); all_gather_into_tensor_15 = None + permute_12 = torch.ops.aten.permute.default(wait_tensor_18, [1, 0]); wait_tensor_18 = None + mm_8 = torch.ops.aten.mm.default(view_87, permute_12) + view_95 = torch.ops.aten.view.default(mm_8, [2, 8192, 128]); mm_8 = None + view_102 = torch.ops.aten.view.default(mm_9, [2, 8192, 128]); mm_9 = None + view_104 = torch.ops.aten.view.default(view_88, [2, 8192, -1, 128]); view_88 = None + view_105 = torch.ops.aten.view.default(view_95, [2, 8192, -1, 128]); view_95 = None + view_106 = torch.ops.aten.view.default(view_102, [2, 8192, -1, 128]); view_102 = None + convert_element_type_46 = torch.ops.prims.convert_element_type.default(view_104, torch.float32); view_104 = None + view_107 = torch.ops.aten.view.default(convert_element_type_46, [2, 8192, 4, -1, 2]); convert_element_type_46 = None + view_as_complex_2 = torch.ops.aten.view_as_complex.default(view_107); view_107 = None + convert_element_type_47 = torch.ops.prims.convert_element_type.default(view_105, torch.float32); view_105 = None + view_108 = torch.ops.aten.view.default(convert_element_type_47, [2, 8192, 1, -1, 2]); convert_element_type_47 = None + view_as_complex_3 = torch.ops.aten.view_as_complex.default(view_108); view_108 = None + mul_10 = torch.ops.aten.mul.Tensor(view_as_complex_2, view_37); view_as_complex_2 = None + view_as_real_2 = torch.ops.aten.view_as_real.default(mul_10); mul_10 = None + view_110 = torch.ops.aten.view.default(view_as_real_2, [2, 8192, 4, 128]); view_as_real_2 = None + mul_11 = torch.ops.aten.mul.Tensor(view_as_complex_3, view_37); view_as_complex_3 = None + view_as_real_3 = torch.ops.aten.view_as_real.default(mul_11); 
mul_11 = None + view_111 = torch.ops.aten.view.default(view_as_real_3, [2, 8192, 1, 128]); view_as_real_3 = None + convert_element_type_48 = torch.ops.prims.convert_element_type.default(view_110, torch.bfloat16); view_110 = None + convert_element_type_49 = torch.ops.prims.convert_element_type.default(view_111, torch.bfloat16); view_111 = None + unsqueeze_2 = torch.ops.aten.unsqueeze.default(convert_element_type_49, 3); convert_element_type_49 = None + expand_2 = torch.ops.aten.expand.default(unsqueeze_2, [2, 8192, 1, 4, 128]); unsqueeze_2 = None + view_112 = torch.ops.aten.view.default(expand_2, [2, 8192, 4, 128]); expand_2 = None + unsqueeze_3 = torch.ops.aten.unsqueeze.default(view_106, 3); view_106 = None + expand_3 = torch.ops.aten.expand.default(unsqueeze_3, [2, 8192, 1, 4, 128]); unsqueeze_3 = None + view_113 = torch.ops.aten.view.default(expand_3, [2, 8192, 4, 128]); expand_3 = None + permute_14 = torch.ops.aten.permute.default(convert_element_type_48, [0, 2, 1, 3]); convert_element_type_48 = None + permute_15 = torch.ops.aten.permute.default(view_112, [0, 2, 1, 3]); view_112 = None + permute_16 = torch.ops.aten.permute.default(view_113, [0, 2, 1, 3]); view_113 = None + _scaled_dot_product_cudnn_attention_backward_30 = torch.ops.aten._scaled_dot_product_cudnn_attention_backward.default(permute_1333, permute_14, permute_15, permute_16, getitem_121, getitem_122, getitem_127, getitem_128, None, None, None, 8192, 8192, 0.0, True); permute_1333 = permute_14 = permute_15 = permute_16 = getitem_121 = getitem_122 = getitem_127 = getitem_128 = None + getitem_2474 = _scaled_dot_product_cudnn_attention_backward_30[0] + getitem_2475 = _scaled_dot_product_cudnn_attention_backward_30[1] + getitem_2476 = _scaled_dot_product_cudnn_attention_backward_30[2]; _scaled_dot_product_cudnn_attention_backward_30 = None + permute_1334 = torch.ops.aten.permute.default(getitem_2476, [0, 2, 1, 3]); getitem_2476 = None + permute_1335 = torch.ops.aten.permute.default(getitem_2475, [0, 2, 
1, 3]); getitem_2475 = None + permute_1336 = torch.ops.aten.permute.default(getitem_2474, [0, 2, 1, 3]); getitem_2474 = None + view_3052 = torch.ops.aten.view.default(permute_1334, [2, 8192, 1, 4, 128]); permute_1334 = None + sum_185 = torch.ops.aten.sum.dim_IntList(view_3052, [3], True); view_3052 = None + squeeze_60 = torch.ops.aten.squeeze.dim(sum_185, 3); sum_185 = None + view_3053 = torch.ops.aten.view.default(permute_1335, [2, 8192, 1, 4, 128]); permute_1335 = None + sum_186 = torch.ops.aten.sum.dim_IntList(view_3053, [3], True); view_3053 = None + squeeze_61 = torch.ops.aten.squeeze.dim(sum_186, 3); sum_186 = None + convert_element_type_2723 = torch.ops.prims.convert_element_type.default(squeeze_61, torch.float32); squeeze_61 = None + convert_element_type_2724 = torch.ops.prims.convert_element_type.default(permute_1336, torch.float32); permute_1336 = None + view_3054 = torch.ops.aten.view.default(convert_element_type_2723, [2, 8192, 1, 64, 2]); convert_element_type_2723 = None + view_as_complex_124 = torch.ops.aten.view_as_complex.default(view_3054); view_3054 = None + mul_876 = torch.ops.aten.mul.Tensor(view_as_complex_124, _conj); view_as_complex_124 = None + view_3055 = torch.ops.aten.view.default(convert_element_type_2724, [2, 8192, 4, 64, 2]); convert_element_type_2724 = None + view_as_complex_125 = torch.ops.aten.view_as_complex.default(view_3055); view_3055 = None + mul_877 = torch.ops.aten.mul.Tensor(view_as_complex_125, _conj); view_as_complex_125 = None + view_as_real_124 = torch.ops.aten.view_as_real.default(mul_876); mul_876 = None + view_3056 = torch.ops.aten.view.default(view_as_real_124, [2, 8192, 1, 128]); view_as_real_124 = None + convert_element_type_2725 = torch.ops.prims.convert_element_type.default(view_3056, torch.bfloat16); view_3056 = None + view_as_real_125 = torch.ops.aten.view_as_real.default(mul_877); mul_877 = None + view_3057 = torch.ops.aten.view.default(view_as_real_125, [2, 8192, 4, 128]); view_as_real_125 = None + 
convert_element_type_2726 = torch.ops.prims.convert_element_type.default(view_3057, torch.bfloat16); view_3057 = None + view_3058 = torch.ops.aten.view.default(squeeze_60, [2, 8192, 128]); squeeze_60 = None + view_3059 = torch.ops.aten.view.default(convert_element_type_2725, [2, 8192, 128]); convert_element_type_2725 = None + view_3060 = torch.ops.aten.view.default(convert_element_type_2726, [2, 8192, 512]); convert_element_type_2726 = None + view_3061 = torch.ops.aten.view.default(view_3058, [16384, 128]); view_3058 = None + permute_1337 = torch.ops.aten.permute.default(view_3061, [1, 0]) + mm_655 = torch.ops.aten.mm.default(permute_1337, view_87); permute_1337 = None + convert_element_type_43 = torch.ops.prims.convert_element_type.default(primals_16, torch.bfloat16); primals_16 = None + all_gather_into_tensor_16 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_43, 32, '0'); convert_element_type_43 = None + wait_tensor_19 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_16); all_gather_into_tensor_16 = None + permute_13 = torch.ops.aten.permute.default(wait_tensor_19, [1, 0]); wait_tensor_19 = None + permute_1339 = torch.ops.aten.permute.default(permute_13, [1, 0]); permute_13 = None + mm_656 = torch.ops.aten.mm.default(view_3061, permute_1339); view_3061 = permute_1339 = None + view_3062 = torch.ops.aten.view.default(mm_656, [2, 8192, 4096]); mm_656 = None + convert_element_type_2731 = torch.ops.prims.convert_element_type.default(mm_655, torch.float32); mm_655 = None + reduce_scatter_tensor_404 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2731, 'avg', 32, '0'); convert_element_type_2731 = None + wait_tensor_884 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_404); reduce_scatter_tensor_404 = None + view_3063 = torch.ops.aten.view.default(view_3059, [16384, 128]); view_3059 = None + permute_1341 = torch.ops.aten.permute.default(view_3063, [1, 0]) + 
mm_657 = torch.ops.aten.mm.default(permute_1341, view_87); permute_1341 = None + permute_1343 = torch.ops.aten.permute.default(permute_12, [1, 0]); permute_12 = None + mm_658 = torch.ops.aten.mm.default(view_3063, permute_1343); view_3063 = permute_1343 = None + view_3064 = torch.ops.aten.view.default(mm_658, [2, 8192, 4096]); mm_658 = None + add_343 = torch.ops.aten.add.Tensor(view_3062, view_3064); view_3062 = view_3064 = None + convert_element_type_2736 = torch.ops.prims.convert_element_type.default(mm_657, torch.float32); mm_657 = None + reduce_scatter_tensor_405 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2736, 'avg', 32, '0'); convert_element_type_2736 = None + wait_tensor_885 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_405); reduce_scatter_tensor_405 = None + view_3065 = torch.ops.aten.view.default(view_3060, [16384, 512]); view_3060 = None + permute_1345 = torch.ops.aten.permute.default(view_3065, [1, 0]) + mm_659 = torch.ops.aten.mm.default(permute_1345, view_87); permute_1345 = view_87 = None + convert_element_type_37 = torch.ops.prims.convert_element_type.default(primals_14, torch.bfloat16); primals_14 = None + all_gather_into_tensor_14 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_37, 32, '0'); convert_element_type_37 = None + wait_tensor_17 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_14); all_gather_into_tensor_14 = None + permute_11 = torch.ops.aten.permute.default(wait_tensor_17, [1, 0]); wait_tensor_17 = None + permute_1347 = torch.ops.aten.permute.default(permute_11, [1, 0]); permute_11 = None + mm_660 = torch.ops.aten.mm.default(view_3065, permute_1347); view_3065 = permute_1347 = None + view_3066 = torch.ops.aten.view.default(mm_660, [2, 8192, 4096]); mm_660 = None + add_344 = torch.ops.aten.add.Tensor(add_343, view_3066); add_343 = view_3066 = None + convert_element_type_2741 = 
torch.ops.prims.convert_element_type.default(mm_659, torch.float32); mm_659 = None + reduce_scatter_tensor_406 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2741, 'avg', 32, '0'); convert_element_type_2741 = None + wait_tensor_886 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_406); reduce_scatter_tensor_406 = None + split_262 = torch.ops.aten.split.Tensor(add_344, 1024, 1); add_344 = None + getitem_2477 = split_262[0] + getitem_2478 = split_262[1] + getitem_2479 = split_262[2] + getitem_2480 = split_262[3] + getitem_2481 = split_262[4] + getitem_2482 = split_262[5] + getitem_2483 = split_262[6] + getitem_2484 = split_262[7]; split_262 = None + cat_254 = torch.ops.aten.cat.default([getitem_2477, getitem_2478, getitem_2479, getitem_2480, getitem_2481, getitem_2482, getitem_2483, getitem_2484]); getitem_2477 = getitem_2478 = getitem_2479 = getitem_2480 = getitem_2481 = getitem_2482 = getitem_2483 = getitem_2484 = None + reduce_scatter_tensor_407 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_254, 'sum', 8, '1'); cat_254 = None + wait_tensor_887 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_407); reduce_scatter_tensor_407 = None + convert_element_type_2742 = torch.ops.prims.convert_element_type.default(wait_tensor_887, torch.float32); wait_tensor_887 = None + convert_element_type_2744 = torch.ops.prims.convert_element_type.default(wait_tensor_15, torch.float32); wait_tensor_15 = None + mul_878 = torch.ops.aten.mul.Tensor(convert_element_type_2742, convert_element_type_2744); convert_element_type_2744 = None + mul_880 = torch.ops.aten.mul.Tensor(mul_8, mul_878) + sum_187 = torch.ops.aten.sum.dim_IntList(mul_880, [2], True); mul_880 = None + div_62 = torch.ops.aten.div.Tensor(mul_8, 4096) + mul_881 = torch.ops.aten.mul.Tensor(div_62, sum_187); div_62 = sum_187 = None + sub_94 = torch.ops.aten.sub.Tensor(mul_878, mul_881); mul_878 = mul_881 = None + mul_882 = 
torch.ops.aten.mul.Tensor(sub_94, rsqrt_2); sub_94 = rsqrt_2 = None + mul_883 = torch.ops.aten.mul.Tensor(convert_element_type_2742, mul_8); convert_element_type_2742 = mul_8 = None + sum_188 = torch.ops.aten.sum.dim_IntList(mul_883, [0, 1]); mul_883 = None + convert_element_type_2745 = torch.ops.prims.convert_element_type.default(mul_882, torch.bfloat16); mul_882 = None + convert_element_type_2746 = torch.ops.prims.convert_element_type.default(sum_188, torch.bfloat16); sum_188 = None + all_reduce_62 = torch.ops._c10d_functional.all_reduce.default(convert_element_type_2746, 'sum', '1'); convert_element_type_2746 = None + wait_tensor_888 = torch.ops._c10d_functional.wait_tensor.default(all_reduce_62); all_reduce_62 = None + convert_element_type_2747 = torch.ops.prims.convert_element_type.default(wait_tensor_888, torch.float32); wait_tensor_888 = None + reduce_scatter_tensor_408 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2747, 'avg', 32, '0'); convert_element_type_2747 = None + wait_tensor_889 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_408); reduce_scatter_tensor_408 = None + add_345 = torch.ops.aten.add.Tensor(add_342, convert_element_type_2745); add_342 = convert_element_type_2745 = None + all_gather_into_tensor_418 = torch.ops._c10d_functional.all_gather_into_tensor.default(add_345, 8, '1') + wait_tensor_890 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_418); all_gather_into_tensor_418 = None + split_263 = torch.ops.aten.split.Tensor(wait_tensor_890, 2); wait_tensor_890 = None + getitem_2485 = split_263[0] + getitem_2486 = split_263[1] + getitem_2487 = split_263[2] + getitem_2488 = split_263[3] + getitem_2489 = split_263[4] + getitem_2490 = split_263[5] + getitem_2491 = split_263[6] + getitem_2492 = split_263[7]; split_263 = None + cat_255 = torch.ops.aten.cat.default([getitem_2485, getitem_2486, getitem_2487, getitem_2488, getitem_2489, getitem_2490, getitem_2491, 
getitem_2492], 1); getitem_2485 = getitem_2486 = getitem_2487 = getitem_2488 = getitem_2489 = getitem_2490 = getitem_2491 = getitem_2492 = None + view_3067 = torch.ops.aten.view.default(cat_255, [16384, 4096]); cat_255 = None + permute_1349 = torch.ops.aten.permute.default(view_3067, [1, 0]) + wait_tensor_8 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_1); reduce_scatter_tensor_1 = None + add_1 = torch.ops.aten.add.Tensor(wait_tensor_1, wait_tensor_8); wait_tensor_8 = None + convert_element_type_20 = torch.ops.prims.convert_element_type.default(primals_9, torch.bfloat16); primals_9 = None + all_gather_into_tensor_7 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_20, 32, '0'); convert_element_type_20 = None + wait_tensor_9 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_7); all_gather_into_tensor_7 = None + convert_element_type_21 = torch.ops.prims.convert_element_type.default(add_1, torch.float32); add_1 = None + pow_2 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_21, 2) + mean_1 = torch.ops.aten.mean.dim(pow_2, [2], True); pow_2 = None + add_2 = torch.ops.aten.add.Scalar(mean_1, 1e-05); mean_1 = None + rsqrt_1 = torch.ops.aten.rsqrt.default(add_2); add_2 = None + mul_4 = torch.ops.aten.mul.Tensor(convert_element_type_21, rsqrt_1); convert_element_type_21 = None + mul_5 = torch.ops.aten.mul.Tensor(mul_4, wait_tensor_9) + convert_element_type_22 = torch.ops.prims.convert_element_type.default(mul_5, torch.bfloat16); mul_5 = None + all_gather_into_tensor_8 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_22, 8, '1'); convert_element_type_22 = None + wait_tensor_10 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_8); all_gather_into_tensor_8 = None + split_11 = torch.ops.aten.split.Tensor(wait_tensor_10, 2); wait_tensor_10 = None + getitem_97 = split_11[0] + getitem_98 = split_11[1] + getitem_99 = split_11[2] + getitem_100 = 
split_11[3] + getitem_101 = split_11[4] + getitem_102 = split_11[5] + getitem_103 = split_11[6] + getitem_104 = split_11[7]; split_11 = None + cat_3 = torch.ops.aten.cat.default([getitem_97, getitem_98, getitem_99, getitem_100, getitem_101, getitem_102, getitem_103, getitem_104], 1); getitem_97 = getitem_98 = getitem_99 = getitem_100 = getitem_101 = getitem_102 = getitem_103 = getitem_104 = None + view_60 = torch.ops.aten.view.default(cat_3, [16384, 4096]); cat_3 = None + view_61 = torch.ops.aten.view.default(mm_4, [2, 8192, 1792]); mm_4 = None + convert_element_type_26 = torch.ops.prims.convert_element_type.default(view_61, torch.float32); view_61 = None + sigmoid = torch.ops.aten.sigmoid.default(convert_element_type_26) + mul_6 = torch.ops.aten.mul.Tensor(convert_element_type_26, sigmoid); sigmoid = None + convert_element_type_27 = torch.ops.prims.convert_element_type.default(mul_6, torch.bfloat16); mul_6 = None + convert_element_type_28 = torch.ops.prims.convert_element_type.default(primals_11, torch.bfloat16); primals_11 = None + all_gather_into_tensor_10 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_28, 32, '0'); convert_element_type_28 = None + wait_tensor_12 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_10); all_gather_into_tensor_10 = None + permute_9 = torch.ops.aten.permute.default(wait_tensor_12, [1, 0]); wait_tensor_12 = None + mm_5 = torch.ops.aten.mm.default(view_60, permute_9) + view_68 = torch.ops.aten.view.default(mm_5, [2, 8192, 1792]); mm_5 = None + mul_7 = torch.ops.aten.mul.Tensor(convert_element_type_27, view_68) + view_75 = torch.ops.aten.view.default(mul_7, [16384, 1792]); mul_7 = None + mm_661 = torch.ops.aten.mm.default(permute_1349, view_75); permute_1349 = view_75 = None + convert_element_type_31 = torch.ops.prims.convert_element_type.default(primals_12, torch.bfloat16); primals_12 = None + all_gather_into_tensor_11 = 
torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_31, 32, '0'); convert_element_type_31 = None + wait_tensor_13 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_11); all_gather_into_tensor_11 = None + permute_10 = torch.ops.aten.permute.default(wait_tensor_13, [1, 0]); wait_tensor_13 = None + permute_1351 = torch.ops.aten.permute.default(permute_10, [1, 0]); permute_10 = None + mm_662 = torch.ops.aten.mm.default(view_3067, permute_1351); view_3067 = permute_1351 = None + view_3068 = torch.ops.aten.view.default(mm_662, [2, 8192, 1792]); mm_662 = None + convert_element_type_2752 = torch.ops.prims.convert_element_type.default(mm_661, torch.float32); mm_661 = None + reduce_scatter_tensor_409 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2752, 'avg', 32, '0'); convert_element_type_2752 = None + wait_tensor_891 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_409); reduce_scatter_tensor_409 = None + mul_884 = torch.ops.aten.mul.Tensor(view_3068, convert_element_type_27); convert_element_type_27 = None + mul_885 = torch.ops.aten.mul.Tensor(view_3068, view_68); view_3068 = view_68 = None + view_3069 = torch.ops.aten.view.default(mul_884, [16384, 1792]); mul_884 = None + permute_1353 = torch.ops.aten.permute.default(view_3069, [1, 0]) + mm_663 = torch.ops.aten.mm.default(permute_1353, view_60); permute_1353 = None + permute_1355 = torch.ops.aten.permute.default(permute_9, [1, 0]); permute_9 = None + mm_664 = torch.ops.aten.mm.default(view_3069, permute_1355); view_3069 = permute_1355 = None + view_3070 = torch.ops.aten.view.default(mm_664, [2, 8192, 4096]); mm_664 = None + convert_element_type_2757 = torch.ops.prims.convert_element_type.default(mm_663, torch.float32); mm_663 = None + reduce_scatter_tensor_410 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2757, 'avg', 32, '0'); convert_element_type_2757 = None + wait_tensor_892 = 
torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_410); reduce_scatter_tensor_410 = None + convert_element_type_2758 = torch.ops.prims.convert_element_type.default(mul_885, torch.float32); mul_885 = None + neg_31 = torch.ops.aten.neg.default(convert_element_type_26) + exp_31 = torch.ops.aten.exp.default(neg_31); neg_31 = None + add_346 = torch.ops.aten.add.Tensor(exp_31, 1); exp_31 = None + reciprocal_31 = torch.ops.aten.reciprocal.default(add_346); add_346 = None + mul_886 = torch.ops.aten.mul.Tensor(reciprocal_31, 1); reciprocal_31 = None + mul_887 = torch.ops.aten.mul.Tensor(convert_element_type_2758, mul_886); convert_element_type_2758 = None + sub_95 = torch.ops.aten.sub.Tensor(1, mul_886); mul_886 = None + mul_888 = torch.ops.aten.mul.Tensor(convert_element_type_26, sub_95); convert_element_type_26 = sub_95 = None + add_347 = torch.ops.aten.add.Tensor(mul_888, 1); mul_888 = None + mul_889 = torch.ops.aten.mul.Tensor(mul_887, add_347); mul_887 = add_347 = None + convert_element_type_2760 = torch.ops.prims.convert_element_type.default(mul_889, torch.bfloat16); mul_889 = None + view_3071 = torch.ops.aten.view.default(convert_element_type_2760, [16384, 1792]); convert_element_type_2760 = None + permute_1357 = torch.ops.aten.permute.default(view_3071, [1, 0]) + mm_665 = torch.ops.aten.mm.default(permute_1357, view_60); permute_1357 = view_60 = None + convert_element_type_23 = torch.ops.prims.convert_element_type.default(primals_10, torch.bfloat16); primals_10 = None + all_gather_into_tensor_9 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_23, 32, '0'); convert_element_type_23 = None + wait_tensor_11 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_9); all_gather_into_tensor_9 = None + permute_8 = torch.ops.aten.permute.default(wait_tensor_11, [1, 0]); wait_tensor_11 = None + permute_1359 = torch.ops.aten.permute.default(permute_8, [1, 0]); permute_8 = None + mm_666 = 
torch.ops.aten.mm.default(view_3071, permute_1359); view_3071 = permute_1359 = None + view_3072 = torch.ops.aten.view.default(mm_666, [2, 8192, 4096]); mm_666 = None + add_348 = torch.ops.aten.add.Tensor(view_3070, view_3072); view_3070 = view_3072 = None + convert_element_type_2765 = torch.ops.prims.convert_element_type.default(mm_665, torch.float32); mm_665 = None + reduce_scatter_tensor_411 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2765, 'avg', 32, '0'); convert_element_type_2765 = None + wait_tensor_893 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_411); reduce_scatter_tensor_411 = None + split_264 = torch.ops.aten.split.Tensor(add_348, 1024, 1); add_348 = None + getitem_2493 = split_264[0] + getitem_2494 = split_264[1] + getitem_2495 = split_264[2] + getitem_2496 = split_264[3] + getitem_2497 = split_264[4] + getitem_2498 = split_264[5] + getitem_2499 = split_264[6] + getitem_2500 = split_264[7]; split_264 = None + cat_256 = torch.ops.aten.cat.default([getitem_2493, getitem_2494, getitem_2495, getitem_2496, getitem_2497, getitem_2498, getitem_2499, getitem_2500]); getitem_2493 = getitem_2494 = getitem_2495 = getitem_2496 = getitem_2497 = getitem_2498 = getitem_2499 = getitem_2500 = None + reduce_scatter_tensor_412 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_256, 'sum', 8, '1'); cat_256 = None + wait_tensor_894 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_412); reduce_scatter_tensor_412 = None + convert_element_type_2766 = torch.ops.prims.convert_element_type.default(wait_tensor_894, torch.float32); wait_tensor_894 = None + convert_element_type_2768 = torch.ops.prims.convert_element_type.default(wait_tensor_9, torch.float32); wait_tensor_9 = None + mul_890 = torch.ops.aten.mul.Tensor(convert_element_type_2766, convert_element_type_2768); convert_element_type_2768 = None + mul_892 = torch.ops.aten.mul.Tensor(mul_4, mul_890) + sum_189 = 
torch.ops.aten.sum.dim_IntList(mul_892, [2], True); mul_892 = None + div_63 = torch.ops.aten.div.Tensor(mul_4, 4096) + mul_893 = torch.ops.aten.mul.Tensor(div_63, sum_189); div_63 = sum_189 = None + sub_96 = torch.ops.aten.sub.Tensor(mul_890, mul_893); mul_890 = mul_893 = None + mul_894 = torch.ops.aten.mul.Tensor(sub_96, rsqrt_1); sub_96 = rsqrt_1 = None + mul_895 = torch.ops.aten.mul.Tensor(convert_element_type_2766, mul_4); convert_element_type_2766 = mul_4 = None + sum_190 = torch.ops.aten.sum.dim_IntList(mul_895, [0, 1]); mul_895 = None + convert_element_type_2769 = torch.ops.prims.convert_element_type.default(mul_894, torch.bfloat16); mul_894 = None + convert_element_type_2770 = torch.ops.prims.convert_element_type.default(sum_190, torch.bfloat16); sum_190 = None + all_reduce_63 = torch.ops._c10d_functional.all_reduce.default(convert_element_type_2770, 'sum', '1'); convert_element_type_2770 = None + wait_tensor_895 = torch.ops._c10d_functional.wait_tensor.default(all_reduce_63); all_reduce_63 = None + convert_element_type_2771 = torch.ops.prims.convert_element_type.default(wait_tensor_895, torch.float32); wait_tensor_895 = None + reduce_scatter_tensor_413 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2771, 'avg', 32, '0'); convert_element_type_2771 = None + wait_tensor_896 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_413); reduce_scatter_tensor_413 = None + add_349 = torch.ops.aten.add.Tensor(add_345, convert_element_type_2769); add_345 = convert_element_type_2769 = None + all_gather_into_tensor_419 = torch.ops._c10d_functional.all_gather_into_tensor.default(add_349, 8, '1') + wait_tensor_897 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_419); all_gather_into_tensor_419 = None + split_265 = torch.ops.aten.split.Tensor(wait_tensor_897, 2); wait_tensor_897 = None + getitem_2501 = split_265[0] + getitem_2502 = split_265[1] + getitem_2503 = split_265[2] + getitem_2504 = 
split_265[3] + getitem_2505 = split_265[4] + getitem_2506 = split_265[5] + getitem_2507 = split_265[6] + getitem_2508 = split_265[7]; split_265 = None + cat_257 = torch.ops.aten.cat.default([getitem_2501, getitem_2502, getitem_2503, getitem_2504, getitem_2505, getitem_2506, getitem_2507, getitem_2508], 1); getitem_2501 = getitem_2502 = getitem_2503 = getitem_2504 = getitem_2505 = getitem_2506 = getitem_2507 = getitem_2508 = None + view_3073 = torch.ops.aten.view.default(cat_257, [16384, 4096]); cat_257 = None + permute_1361 = torch.ops.aten.permute.default(view_3073, [1, 0]) + permute_6 = torch.ops.aten.permute.default(getitem_80, [0, 2, 1, 3]) + view_42 = torch.ops.aten.view.default(permute_6, [2, 8192, -1]); permute_6 = None + view_48 = torch.ops.aten.view.default(view_42, [16384, 512]); view_42 = None + mm_667 = torch.ops.aten.mm.default(permute_1361, view_48); permute_1361 = view_48 = None + convert_element_type_17 = torch.ops.prims.convert_element_type.default(primals_8, torch.bfloat16); primals_8 = None + all_gather_into_tensor_6 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_17, 32, '0'); convert_element_type_17 = None + wait_tensor_7 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_6); all_gather_into_tensor_6 = None + permute_7 = torch.ops.aten.permute.default(wait_tensor_7, [1, 0]); wait_tensor_7 = None + permute_1363 = torch.ops.aten.permute.default(permute_7, [1, 0]); permute_7 = None + mm_668 = torch.ops.aten.mm.default(view_3073, permute_1363); view_3073 = permute_1363 = None + view_3074 = torch.ops.aten.view.default(mm_668, [2, 8192, 512]); mm_668 = None + convert_element_type_2776 = torch.ops.prims.convert_element_type.default(mm_667, torch.float32); mm_667 = None + reduce_scatter_tensor_414 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2776, 'avg', 32, '0'); convert_element_type_2776 = None + wait_tensor_898 = 
torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_414); reduce_scatter_tensor_414 = None + view_3075 = torch.ops.aten.view.default(view_3074, [2, 8192, 4, 128]); view_3074 = None + permute_1365 = torch.ops.aten.permute.default(view_3075, [0, 2, 1, 3]); view_3075 = None + convert_element_type_1 = torch.ops.prims.convert_element_type.default(primals_4, torch.bfloat16); primals_4 = None + all_gather_into_tensor_1 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1, 32, '0'); convert_element_type_1 = None + wait_tensor_2 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_1); all_gather_into_tensor_1 = None + convert_element_type_2 = torch.ops.prims.convert_element_type.default(wait_tensor_1, torch.float32); wait_tensor_1 = None + pow_1 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_2, 2) + mean = torch.ops.aten.mean.dim(pow_1, [2], True); pow_1 = None + add = torch.ops.aten.add.Scalar(mean, 1e-05); mean = None + rsqrt = torch.ops.aten.rsqrt.default(add); add = None + mul = torch.ops.aten.mul.Tensor(convert_element_type_2, rsqrt); convert_element_type_2 = None + mul_1 = torch.ops.aten.mul.Tensor(mul, wait_tensor_2) + convert_element_type_3 = torch.ops.prims.convert_element_type.default(mul_1, torch.bfloat16); mul_1 = None + all_gather_into_tensor_2 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_3, 8, '1'); convert_element_type_3 = None + wait_tensor_3 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_2); all_gather_into_tensor_2 = None + split_9 = torch.ops.aten.split.Tensor(wait_tensor_3, 2); wait_tensor_3 = None + getitem_72 = split_9[0] + getitem_73 = split_9[1] + getitem_74 = split_9[2] + getitem_75 = split_9[3] + getitem_76 = split_9[4] + getitem_77 = split_9[5] + getitem_78 = split_9[6] + getitem_79 = split_9[7]; split_9 = None + cat_1 = torch.ops.aten.cat.default([getitem_72, getitem_73, getitem_74, getitem_75, getitem_76, 
getitem_77, getitem_78, getitem_79], 1); getitem_72 = getitem_73 = getitem_74 = getitem_75 = getitem_76 = getitem_77 = getitem_78 = getitem_79 = None + view_15 = torch.ops.aten.view.default(cat_1, [16384, 4096]); cat_1 = None + view_16 = torch.ops.aten.view.default(mm, [2, 8192, 512]); mm = None + convert_element_type_7 = torch.ops.prims.convert_element_type.default(primals_6, torch.bfloat16); primals_6 = None + all_gather_into_tensor_4 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_7, 32, '0'); convert_element_type_7 = None + wait_tensor_5 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_4); all_gather_into_tensor_4 = None + permute_1 = torch.ops.aten.permute.default(wait_tensor_5, [1, 0]); wait_tensor_5 = None + mm_1 = torch.ops.aten.mm.default(view_15, permute_1) + view_23 = torch.ops.aten.view.default(mm_1, [2, 8192, 128]); mm_1 = None + view_30 = torch.ops.aten.view.default(mm_2, [2, 8192, 128]); mm_2 = None + view_32 = torch.ops.aten.view.default(view_16, [2, 8192, -1, 128]); view_16 = None + view_33 = torch.ops.aten.view.default(view_23, [2, 8192, -1, 128]); view_23 = None + view_34 = torch.ops.aten.view.default(view_30, [2, 8192, -1, 128]); view_30 = None + convert_element_type_13 = torch.ops.prims.convert_element_type.default(view_32, torch.float32); view_32 = None + view_35 = torch.ops.aten.view.default(convert_element_type_13, [2, 8192, 4, -1, 2]); convert_element_type_13 = None + view_as_complex = torch.ops.aten.view_as_complex.default(view_35); view_35 = None + convert_element_type_14 = torch.ops.prims.convert_element_type.default(view_33, torch.float32); view_33 = None + view_36 = torch.ops.aten.view.default(convert_element_type_14, [2, 8192, 1, -1, 2]); convert_element_type_14 = None + view_as_complex_1 = torch.ops.aten.view_as_complex.default(view_36); view_36 = None + mul_2 = torch.ops.aten.mul.Tensor(view_as_complex, view_37); view_as_complex = None + view_as_real = 
torch.ops.aten.view_as_real.default(mul_2); mul_2 = None + view_38 = torch.ops.aten.view.default(view_as_real, [2, 8192, 4, 128]); view_as_real = None + mul_3 = torch.ops.aten.mul.Tensor(view_as_complex_1, view_37); view_as_complex_1 = view_37 = None + view_as_real_1 = torch.ops.aten.view_as_real.default(mul_3); mul_3 = None + view_39 = torch.ops.aten.view.default(view_as_real_1, [2, 8192, 1, 128]); view_as_real_1 = None + convert_element_type_15 = torch.ops.prims.convert_element_type.default(view_38, torch.bfloat16); view_38 = None + convert_element_type_16 = torch.ops.prims.convert_element_type.default(view_39, torch.bfloat16); view_39 = None + unsqueeze = torch.ops.aten.unsqueeze.default(convert_element_type_16, 3); convert_element_type_16 = None + expand = torch.ops.aten.expand.default(unsqueeze, [2, 8192, 1, 4, 128]); unsqueeze = None + view_40 = torch.ops.aten.view.default(expand, [2, 8192, 4, 128]); expand = None + unsqueeze_1 = torch.ops.aten.unsqueeze.default(view_34, 3); view_34 = None + expand_1 = torch.ops.aten.expand.default(unsqueeze_1, [2, 8192, 1, 4, 128]); unsqueeze_1 = None + view_41 = torch.ops.aten.view.default(expand_1, [2, 8192, 4, 128]); expand_1 = None + permute_3 = torch.ops.aten.permute.default(convert_element_type_15, [0, 2, 1, 3]); convert_element_type_15 = None + permute_4 = torch.ops.aten.permute.default(view_40, [0, 2, 1, 3]); view_40 = None + permute_5 = torch.ops.aten.permute.default(view_41, [0, 2, 1, 3]); view_41 = None + _scaled_dot_product_cudnn_attention_backward_31 = torch.ops.aten._scaled_dot_product_cudnn_attention_backward.default(permute_1365, permute_3, permute_4, permute_5, getitem_80, getitem_81, getitem_86, getitem_87, None, None, None, 8192, 8192, 0.0, True); permute_1365 = permute_3 = permute_4 = permute_5 = getitem_80 = getitem_81 = getitem_86 = getitem_87 = None + getitem_2509 = _scaled_dot_product_cudnn_attention_backward_31[0] + getitem_2510 = _scaled_dot_product_cudnn_attention_backward_31[1] + getitem_2511 = 
_scaled_dot_product_cudnn_attention_backward_31[2]; _scaled_dot_product_cudnn_attention_backward_31 = None + permute_1366 = torch.ops.aten.permute.default(getitem_2511, [0, 2, 1, 3]); getitem_2511 = None + permute_1367 = torch.ops.aten.permute.default(getitem_2510, [0, 2, 1, 3]); getitem_2510 = None + permute_1368 = torch.ops.aten.permute.default(getitem_2509, [0, 2, 1, 3]); getitem_2509 = None + view_3076 = torch.ops.aten.view.default(permute_1366, [2, 8192, 1, 4, 128]); permute_1366 = None + sum_191 = torch.ops.aten.sum.dim_IntList(view_3076, [3], True); view_3076 = None + squeeze_62 = torch.ops.aten.squeeze.dim(sum_191, 3); sum_191 = None + view_3077 = torch.ops.aten.view.default(permute_1367, [2, 8192, 1, 4, 128]); permute_1367 = None + sum_192 = torch.ops.aten.sum.dim_IntList(view_3077, [3], True); view_3077 = None + squeeze_63 = torch.ops.aten.squeeze.dim(sum_192, 3); sum_192 = None + convert_element_type_2777 = torch.ops.prims.convert_element_type.default(squeeze_63, torch.float32); squeeze_63 = None + convert_element_type_2778 = torch.ops.prims.convert_element_type.default(permute_1368, torch.float32); permute_1368 = None + view_3078 = torch.ops.aten.view.default(convert_element_type_2777, [2, 8192, 1, 64, 2]); convert_element_type_2777 = None + view_as_complex_126 = torch.ops.aten.view_as_complex.default(view_3078); view_3078 = None + mul_896 = torch.ops.aten.mul.Tensor(view_as_complex_126, _conj); view_as_complex_126 = None + view_3079 = torch.ops.aten.view.default(convert_element_type_2778, [2, 8192, 4, 64, 2]); convert_element_type_2778 = None + view_as_complex_127 = torch.ops.aten.view_as_complex.default(view_3079); view_3079 = None + mul_897 = torch.ops.aten.mul.Tensor(view_as_complex_127, _conj); view_as_complex_127 = _conj = None + view_as_real_126 = torch.ops.aten.view_as_real.default(mul_896); mul_896 = None + view_3080 = torch.ops.aten.view.default(view_as_real_126, [2, 8192, 1, 128]); view_as_real_126 = None + convert_element_type_2779 = 
torch.ops.prims.convert_element_type.default(view_3080, torch.bfloat16); view_3080 = None + view_as_real_127 = torch.ops.aten.view_as_real.default(mul_897); mul_897 = None + view_3081 = torch.ops.aten.view.default(view_as_real_127, [2, 8192, 4, 128]); view_as_real_127 = None + convert_element_type_2780 = torch.ops.prims.convert_element_type.default(view_3081, torch.bfloat16); view_3081 = None + view_3082 = torch.ops.aten.view.default(squeeze_62, [2, 8192, 128]); squeeze_62 = None + view_3083 = torch.ops.aten.view.default(convert_element_type_2779, [2, 8192, 128]); convert_element_type_2779 = None + view_3084 = torch.ops.aten.view.default(convert_element_type_2780, [2, 8192, 512]); convert_element_type_2780 = None + view_3085 = torch.ops.aten.view.default(view_3082, [16384, 128]); view_3082 = None + permute_1369 = torch.ops.aten.permute.default(view_3085, [1, 0]) + mm_669 = torch.ops.aten.mm.default(permute_1369, view_15); permute_1369 = None + convert_element_type_10 = torch.ops.prims.convert_element_type.default(primals_7, torch.bfloat16); primals_7 = None + all_gather_into_tensor_5 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_10, 32, '0'); convert_element_type_10 = None + wait_tensor_6 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_5); all_gather_into_tensor_5 = None + permute_2 = torch.ops.aten.permute.default(wait_tensor_6, [1, 0]); wait_tensor_6 = None + permute_1371 = torch.ops.aten.permute.default(permute_2, [1, 0]); permute_2 = None + mm_670 = torch.ops.aten.mm.default(view_3085, permute_1371); view_3085 = permute_1371 = None + view_3086 = torch.ops.aten.view.default(mm_670, [2, 8192, 4096]); mm_670 = None + convert_element_type_2785 = torch.ops.prims.convert_element_type.default(mm_669, torch.float32); mm_669 = None + reduce_scatter_tensor_415 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2785, 'avg', 32, '0'); convert_element_type_2785 = None + wait_tensor_899 = 
torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_415); reduce_scatter_tensor_415 = None + view_3087 = torch.ops.aten.view.default(view_3083, [16384, 128]); view_3083 = None + permute_1373 = torch.ops.aten.permute.default(view_3087, [1, 0]) + mm_671 = torch.ops.aten.mm.default(permute_1373, view_15); permute_1373 = None + permute_1375 = torch.ops.aten.permute.default(permute_1, [1, 0]); permute_1 = None + mm_672 = torch.ops.aten.mm.default(view_3087, permute_1375); view_3087 = permute_1375 = None + view_3088 = torch.ops.aten.view.default(mm_672, [2, 8192, 4096]); mm_672 = None + add_350 = torch.ops.aten.add.Tensor(view_3086, view_3088); view_3086 = view_3088 = None + convert_element_type_2790 = torch.ops.prims.convert_element_type.default(mm_671, torch.float32); mm_671 = None + reduce_scatter_tensor_416 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2790, 'avg', 32, '0'); convert_element_type_2790 = None + wait_tensor_900 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_416); reduce_scatter_tensor_416 = None + view_3089 = torch.ops.aten.view.default(view_3084, [16384, 512]); view_3084 = None + permute_1377 = torch.ops.aten.permute.default(view_3089, [1, 0]) + mm_673 = torch.ops.aten.mm.default(permute_1377, view_15); permute_1377 = view_15 = None + convert_element_type_4 = torch.ops.prims.convert_element_type.default(primals_5, torch.bfloat16); primals_5 = None + all_gather_into_tensor_3 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_4, 32, '0'); convert_element_type_4 = None + wait_tensor_4 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_3); all_gather_into_tensor_3 = None + permute = torch.ops.aten.permute.default(wait_tensor_4, [1, 0]); wait_tensor_4 = None + permute_1379 = torch.ops.aten.permute.default(permute, [1, 0]); permute = None + mm_674 = torch.ops.aten.mm.default(view_3089, permute_1379); view_3089 = permute_1379 = None + 
view_3090 = torch.ops.aten.view.default(mm_674, [2, 8192, 4096]); mm_674 = None + add_351 = torch.ops.aten.add.Tensor(add_350, view_3090); add_350 = view_3090 = None + convert_element_type_2795 = torch.ops.prims.convert_element_type.default(mm_673, torch.float32); mm_673 = None + reduce_scatter_tensor_417 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2795, 'avg', 32, '0'); convert_element_type_2795 = None + wait_tensor_901 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_417); reduce_scatter_tensor_417 = None + split_266 = torch.ops.aten.split.Tensor(add_351, 1024, 1); add_351 = None + getitem_2512 = split_266[0] + getitem_2513 = split_266[1] + getitem_2514 = split_266[2] + getitem_2515 = split_266[3] + getitem_2516 = split_266[4] + getitem_2517 = split_266[5] + getitem_2518 = split_266[6] + getitem_2519 = split_266[7]; split_266 = None + cat_258 = torch.ops.aten.cat.default([getitem_2512, getitem_2513, getitem_2514, getitem_2515, getitem_2516, getitem_2517, getitem_2518, getitem_2519]); getitem_2512 = getitem_2513 = getitem_2514 = getitem_2515 = getitem_2516 = getitem_2517 = getitem_2518 = getitem_2519 = None + reduce_scatter_tensor_418 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_258, 'sum', 8, '1'); cat_258 = None + wait_tensor_902 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_418); reduce_scatter_tensor_418 = None + convert_element_type_2796 = torch.ops.prims.convert_element_type.default(wait_tensor_902, torch.float32); wait_tensor_902 = None + convert_element_type_2798 = torch.ops.prims.convert_element_type.default(wait_tensor_2, torch.float32); wait_tensor_2 = None + mul_898 = torch.ops.aten.mul.Tensor(convert_element_type_2796, convert_element_type_2798); convert_element_type_2798 = None + mul_900 = torch.ops.aten.mul.Tensor(mul, mul_898) + sum_193 = torch.ops.aten.sum.dim_IntList(mul_900, [2], True); mul_900 = None + div_64 = torch.ops.aten.div.Tensor(mul, 
4096) + mul_901 = torch.ops.aten.mul.Tensor(div_64, sum_193); div_64 = sum_193 = None + sub_97 = torch.ops.aten.sub.Tensor(mul_898, mul_901); mul_898 = mul_901 = None + mul_902 = torch.ops.aten.mul.Tensor(sub_97, rsqrt); sub_97 = rsqrt = None + mul_903 = torch.ops.aten.mul.Tensor(convert_element_type_2796, mul); convert_element_type_2796 = mul = None + sum_194 = torch.ops.aten.sum.dim_IntList(mul_903, [0, 1]); mul_903 = None + convert_element_type_2799 = torch.ops.prims.convert_element_type.default(mul_902, torch.bfloat16); mul_902 = None + convert_element_type_2800 = torch.ops.prims.convert_element_type.default(sum_194, torch.bfloat16); sum_194 = None + all_reduce_64 = torch.ops._c10d_functional.all_reduce.default(convert_element_type_2800, 'sum', '1'); convert_element_type_2800 = None + wait_tensor_903 = torch.ops._c10d_functional.wait_tensor.default(all_reduce_64); all_reduce_64 = None + convert_element_type_2801 = torch.ops.prims.convert_element_type.default(wait_tensor_903, torch.float32); wait_tensor_903 = None + reduce_scatter_tensor_419 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2801, 'avg', 32, '0'); convert_element_type_2801 = None + wait_tensor_904 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_419); reduce_scatter_tensor_419 = None + add_352 = torch.ops.aten.add.Tensor(add_349, convert_element_type_2799); add_349 = convert_element_type_2799 = None + all_gather_into_tensor_420 = torch.ops._c10d_functional.all_gather_into_tensor.default(add_352, 8, '1'); add_352 = None + wait_tensor_905 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_420); all_gather_into_tensor_420 = None + split_267 = torch.ops.aten.split.Tensor(wait_tensor_905, 2); wait_tensor_905 = None + getitem_2520 = split_267[0] + getitem_2521 = split_267[1] + getitem_2522 = split_267[2] + getitem_2523 = split_267[3] + getitem_2524 = split_267[4] + getitem_2525 = split_267[5] + getitem_2526 = split_267[6] + 
getitem_2527 = split_267[7]; split_267 = None + cat_259 = torch.ops.aten.cat.default([getitem_2520, getitem_2521, getitem_2522, getitem_2523, getitem_2524, getitem_2525, getitem_2526, getitem_2527], 1); getitem_2520 = getitem_2521 = getitem_2522 = getitem_2523 = getitem_2524 = getitem_2525 = getitem_2526 = getitem_2527 = None + convert_element_type_2802 = torch.ops.prims.convert_element_type.default(cat_259, torch.float32); cat_259 = None + eq = torch.ops.aten.eq.Scalar(primals_1, -1) + unsqueeze_64 = torch.ops.aten.unsqueeze.default(eq, -1); eq = None + full_default_2 = torch.ops.aten.full.default([], 0.0, dtype = torch.float32, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False) + where = torch.ops.aten.where.self(unsqueeze_64, full_default_2, convert_element_type_2802); unsqueeze_64 = full_default_2 = convert_element_type_2802 = None + full_default_3 = torch.ops.aten.full.default([128256, 4096], 0, dtype = torch.float32, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False) + index_put_2 = torch.ops.aten.index_put.default(full_default_3, [primals_1], where, True); full_default_3 = primals_1 = where = None + convert_element_type_2803 = torch.ops.prims.convert_element_type.default(index_put_2, torch.bfloat16); index_put_2 = None + split_268 = torch.ops.aten.split.Tensor(convert_element_type_2803, 16032); convert_element_type_2803 = None + getitem_2528 = split_268[0]; split_268 = None + convert_element_type_2804 = torch.ops.prims.convert_element_type.default(getitem_2528, torch.float32); getitem_2528 = None + reduce_scatter_tensor_420 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2804, 'avg', 32, '0'); convert_element_type_2804 = None + wait_tensor_906 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_420); reduce_scatter_tensor_420 = None + return (None, wait_tensor_906, None, wait_tensor_904, wait_tensor_901, wait_tensor_900, wait_tensor_899, 
wait_tensor_898, wait_tensor_896, wait_tensor_893, wait_tensor_892, wait_tensor_891, wait_tensor_889, wait_tensor_886, wait_tensor_885, wait_tensor_884, wait_tensor_883, wait_tensor_881, wait_tensor_878, wait_tensor_877, wait_tensor_876, wait_tensor_874, wait_tensor_871, wait_tensor_870, wait_tensor_869, wait_tensor_868, wait_tensor_866, wait_tensor_863, wait_tensor_862, wait_tensor_861, wait_tensor_859, wait_tensor_856, wait_tensor_855, wait_tensor_854, wait_tensor_853, wait_tensor_851, wait_tensor_848, wait_tensor_847, wait_tensor_846, wait_tensor_844, wait_tensor_841, wait_tensor_840, wait_tensor_839, wait_tensor_838, wait_tensor_836, wait_tensor_833, wait_tensor_832, wait_tensor_831, wait_tensor_829, wait_tensor_826, wait_tensor_825, wait_tensor_824, wait_tensor_823, wait_tensor_821, wait_tensor_818, wait_tensor_817, wait_tensor_816, wait_tensor_814, wait_tensor_811, wait_tensor_810, wait_tensor_809, wait_tensor_808, wait_tensor_806, wait_tensor_803, wait_tensor_802, wait_tensor_801, wait_tensor_799, wait_tensor_796, wait_tensor_795, wait_tensor_794, wait_tensor_793, wait_tensor_791, wait_tensor_788, wait_tensor_787, wait_tensor_786, wait_tensor_784, wait_tensor_781, wait_tensor_780, wait_tensor_779, wait_tensor_778, wait_tensor_776, wait_tensor_773, wait_tensor_772, wait_tensor_771, wait_tensor_769, wait_tensor_766, wait_tensor_765, wait_tensor_764, wait_tensor_763, wait_tensor_761, wait_tensor_758, wait_tensor_757, wait_tensor_756, wait_tensor_754, wait_tensor_751, wait_tensor_750, wait_tensor_749, wait_tensor_748, wait_tensor_746, wait_tensor_743, wait_tensor_742, wait_tensor_741, wait_tensor_739, wait_tensor_736, wait_tensor_735, wait_tensor_734, wait_tensor_733, wait_tensor_731, wait_tensor_728, wait_tensor_727, wait_tensor_726, wait_tensor_724, wait_tensor_721, wait_tensor_720, wait_tensor_719, wait_tensor_718, wait_tensor_716, wait_tensor_713, wait_tensor_712, wait_tensor_711, wait_tensor_709, wait_tensor_706, wait_tensor_705, wait_tensor_704, 
wait_tensor_703, wait_tensor_701, wait_tensor_698, wait_tensor_697, wait_tensor_696, wait_tensor_694, wait_tensor_691, wait_tensor_690, wait_tensor_689, wait_tensor_688, wait_tensor_686, wait_tensor_683, wait_tensor_682, wait_tensor_681, wait_tensor_679, wait_tensor_676, wait_tensor_675, wait_tensor_674, wait_tensor_673, wait_tensor_671, wait_tensor_668, wait_tensor_667, wait_tensor_666, wait_tensor_664, wait_tensor_661, wait_tensor_660, wait_tensor_659, wait_tensor_658, wait_tensor_656, wait_tensor_653, wait_tensor_652, wait_tensor_651, wait_tensor_649, wait_tensor_646, wait_tensor_645, wait_tensor_644, wait_tensor_643, wait_tensor_641, wait_tensor_638, wait_tensor_637, wait_tensor_636, wait_tensor_634, wait_tensor_631, wait_tensor_630, wait_tensor_629, wait_tensor_628, wait_tensor_626, wait_tensor_623, wait_tensor_622, wait_tensor_621, wait_tensor_619, wait_tensor_616, wait_tensor_615, wait_tensor_614, wait_tensor_613, wait_tensor_611, wait_tensor_608, wait_tensor_607, wait_tensor_606, wait_tensor_604, wait_tensor_601, wait_tensor_600, wait_tensor_599, wait_tensor_598, wait_tensor_596, wait_tensor_593, wait_tensor_592, wait_tensor_591, wait_tensor_589, wait_tensor_586, wait_tensor_585, wait_tensor_584, wait_tensor_583, wait_tensor_581, wait_tensor_578, wait_tensor_577, wait_tensor_576, wait_tensor_574, wait_tensor_571, wait_tensor_570, wait_tensor_569, wait_tensor_568, wait_tensor_566, wait_tensor_563, wait_tensor_562, wait_tensor_561, wait_tensor_559, wait_tensor_556, wait_tensor_555, wait_tensor_554, wait_tensor_553, wait_tensor_551, wait_tensor_548, wait_tensor_547, wait_tensor_546, wait_tensor_544, wait_tensor_541, wait_tensor_540, wait_tensor_539, wait_tensor_538, wait_tensor_536, wait_tensor_533, wait_tensor_532, wait_tensor_531, wait_tensor_529, wait_tensor_526, wait_tensor_525, wait_tensor_524, wait_tensor_523, wait_tensor_521, wait_tensor_518, wait_tensor_517, wait_tensor_516, wait_tensor_514, wait_tensor_511, wait_tensor_510, wait_tensor_509, 
wait_tensor_508, wait_tensor_506, wait_tensor_503, wait_tensor_502, wait_tensor_501, wait_tensor_499, wait_tensor_496, wait_tensor_495, wait_tensor_494, wait_tensor_493, wait_tensor_491, wait_tensor_488, wait_tensor_487, wait_tensor_486, wait_tensor_484, wait_tensor_481, wait_tensor_480, wait_tensor_479, wait_tensor_478, wait_tensor_476, wait_tensor_473, wait_tensor_472, wait_tensor_471, wait_tensor_469, wait_tensor_466, wait_tensor_465, wait_tensor_464, wait_tensor_463, wait_tensor_461, wait_tensor_458, wait_tensor_457, wait_tensor_456, wait_tensor_454, wait_tensor_451, wait_tensor_450, wait_tensor_449, wait_tensor_448, wait_tensor_446, wait_tensor_443, wait_tensor_442, wait_tensor_441, wait_tensor_439, wait_tensor_436, wait_tensor_435, wait_tensor_434, wait_tensor_433, wait_tensor_431, wait_tensor_428, wait_tensor_427, wait_tensor_426, wait_tensor_424, wait_tensor_421) + +def load_args(reader): + buf0 = reader.storage(None, 131072, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf0, (2, 8192), dtype=torch.int64, is_leaf=True) # primals_1 + buf1 = reader.storage(None, 8208384, device=device(type='cuda', index=0)) + reader.tensor(buf1, (501, 4096), is_leaf=True) # primals_2 + buf2 = reader.storage(None, 4194304, device=device(type='cuda', index=0), dtype_hint=torch.complex64) + reader.tensor(buf2, (8192, 64), dtype=torch.complex64, is_leaf=True) # primals_3 + buf3 = reader.storage(None, 512, device=device(type='cuda', index=0)) + reader.tensor(buf3, (128,), is_leaf=True) # primals_4 + buf4 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf4, (16, 4096), is_leaf=True) # primals_5 + buf5 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf5, (4, 4096), is_leaf=True) # primals_6 + buf6 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf6, (4, 4096), is_leaf=True) # primals_7 + buf7 = reader.storage(None, 262144, 
device=device(type='cuda', index=0)) + reader.tensor(buf7, (128, 512), is_leaf=True) # primals_8 + buf8 = reader.storage(None, 512, device=device(type='cuda', index=0)) + reader.tensor(buf8, (128,), is_leaf=True) # primals_9 + buf9 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf9, (56, 4096), is_leaf=True) # primals_10 + buf10 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf10, (56, 4096), is_leaf=True) # primals_11 + buf11 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf11, (128, 1792), is_leaf=True) # primals_12 + buf12 = reader.storage(None, 512, device=device(type='cuda', index=0)) + reader.tensor(buf12, (128,), is_leaf=True) # primals_13 + buf13 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf13, (16, 4096), is_leaf=True) # primals_14 + buf14 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf14, (4, 4096), is_leaf=True) # primals_15 + buf15 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf15, (4, 4096), is_leaf=True) # primals_16 + buf16 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf16, (128, 512), is_leaf=True) # primals_17 + buf17 = reader.storage(None, 512, device=device(type='cuda', index=0)) + reader.tensor(buf17, (128,), is_leaf=True) # primals_18 + buf18 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf18, (56, 4096), is_leaf=True) # primals_19 + buf19 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf19, (56, 4096), is_leaf=True) # primals_20 + buf20 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf20, (128, 1792), is_leaf=True) # primals_21 + buf21 = reader.storage(None, 512, device=device(type='cuda', index=0)) + reader.tensor(buf21, (128,), is_leaf=True) # primals_22 + 
buf22 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf22, (16, 4096), is_leaf=True) # primals_23 + buf23 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf23, (4, 4096), is_leaf=True) # primals_24 + buf24 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf24, (4, 4096), is_leaf=True) # primals_25 + buf25 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf25, (128, 512), is_leaf=True) # primals_26 + buf26 = reader.storage(None, 512, device=device(type='cuda', index=0)) + reader.tensor(buf26, (128,), is_leaf=True) # primals_27 + buf27 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf27, (56, 4096), is_leaf=True) # primals_28 + buf28 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf28, (56, 4096), is_leaf=True) # primals_29 + buf29 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf29, (128, 1792), is_leaf=True) # primals_30 + buf30 = reader.storage(None, 512, device=device(type='cuda', index=0)) + reader.tensor(buf30, (128,), is_leaf=True) # primals_31 + buf31 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf31, (16, 4096), is_leaf=True) # primals_32 + buf32 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf32, (4, 4096), is_leaf=True) # primals_33 + buf33 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf33, (4, 4096), is_leaf=True) # primals_34 + buf34 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf34, (128, 512), is_leaf=True) # primals_35 + buf35 = reader.storage(None, 512, device=device(type='cuda', index=0)) + reader.tensor(buf35, (128,), is_leaf=True) # primals_36 + buf36 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + 
reader.tensor(buf36, (56, 4096), is_leaf=True) # primals_37 + buf37 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf37, (56, 4096), is_leaf=True) # primals_38 + buf38 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf38, (128, 1792), is_leaf=True) # primals_39 + buf39 = reader.storage(None, 512, device=device(type='cuda', index=0)) + reader.tensor(buf39, (128,), is_leaf=True) # primals_40 + buf40 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf40, (16, 4096), is_leaf=True) # primals_41 + buf41 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf41, (4, 4096), is_leaf=True) # primals_42 + buf42 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf42, (4, 4096), is_leaf=True) # primals_43 + buf43 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf43, (128, 512), is_leaf=True) # primals_44 + buf44 = reader.storage(None, 512, device=device(type='cuda', index=0)) + reader.tensor(buf44, (128,), is_leaf=True) # primals_45 + buf45 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf45, (56, 4096), is_leaf=True) # primals_46 + buf46 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf46, (56, 4096), is_leaf=True) # primals_47 + buf47 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf47, (128, 1792), is_leaf=True) # primals_48 + buf48 = reader.storage(None, 512, device=device(type='cuda', index=0)) + reader.tensor(buf48, (128,), is_leaf=True) # primals_49 + buf49 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf49, (16, 4096), is_leaf=True) # primals_50 + buf50 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf50, (4, 4096), is_leaf=True) # primals_51 + buf51 = reader.storage(None, 
65536, device=device(type='cuda', index=0)) + reader.tensor(buf51, (4, 4096), is_leaf=True) # primals_52 + buf52 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf52, (128, 512), is_leaf=True) # primals_53 + buf53 = reader.storage(None, 512, device=device(type='cuda', index=0)) + reader.tensor(buf53, (128,), is_leaf=True) # primals_54 + buf54 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf54, (56, 4096), is_leaf=True) # primals_55 + buf55 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf55, (56, 4096), is_leaf=True) # primals_56 + buf56 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf56, (128, 1792), is_leaf=True) # primals_57 + buf57 = reader.storage(None, 512, device=device(type='cuda', index=0)) + reader.tensor(buf57, (128,), is_leaf=True) # primals_58 + buf58 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf58, (16, 4096), is_leaf=True) # primals_59 + buf59 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf59, (4, 4096), is_leaf=True) # primals_60 + buf60 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf60, (4, 4096), is_leaf=True) # primals_61 + buf61 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf61, (128, 512), is_leaf=True) # primals_62 + buf62 = reader.storage(None, 512, device=device(type='cuda', index=0)) + reader.tensor(buf62, (128,), is_leaf=True) # primals_63 + buf63 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf63, (56, 4096), is_leaf=True) # primals_64 + buf64 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf64, (56, 4096), is_leaf=True) # primals_65 + buf65 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf65, (128, 1792), 
is_leaf=True) # primals_66 + buf66 = reader.storage(None, 512, device=device(type='cuda', index=0)) + reader.tensor(buf66, (128,), is_leaf=True) # primals_67 + buf67 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf67, (16, 4096), is_leaf=True) # primals_68 + buf68 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf68, (4, 4096), is_leaf=True) # primals_69 + buf69 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf69, (4, 4096), is_leaf=True) # primals_70 + buf70 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf70, (128, 512), is_leaf=True) # primals_71 + buf71 = reader.storage(None, 512, device=device(type='cuda', index=0)) + reader.tensor(buf71, (128,), is_leaf=True) # primals_72 + buf72 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf72, (56, 4096), is_leaf=True) # primals_73 + buf73 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf73, (56, 4096), is_leaf=True) # primals_74 + buf74 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf74, (128, 1792), is_leaf=True) # primals_75 + buf75 = reader.storage(None, 512, device=device(type='cuda', index=0)) + reader.tensor(buf75, (128,), is_leaf=True) # primals_76 + buf76 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf76, (16, 4096), is_leaf=True) # primals_77 + buf77 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf77, (4, 4096), is_leaf=True) # primals_78 + buf78 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf78, (4, 4096), is_leaf=True) # primals_79 + buf79 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf79, (128, 512), is_leaf=True) # primals_80 + buf80 = reader.storage(None, 512, device=device(type='cuda', 
index=0)) + reader.tensor(buf80, (128,), is_leaf=True) # primals_81 + buf81 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf81, (56, 4096), is_leaf=True) # primals_82 + buf82 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf82, (56, 4096), is_leaf=True) # primals_83 + buf83 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf83, (128, 1792), is_leaf=True) # primals_84 + buf84 = reader.storage(None, 512, device=device(type='cuda', index=0)) + reader.tensor(buf84, (128,), is_leaf=True) # primals_85 + buf85 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf85, (16, 4096), is_leaf=True) # primals_86 + buf86 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf86, (4, 4096), is_leaf=True) # primals_87 + buf87 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf87, (4, 4096), is_leaf=True) # primals_88 + buf88 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf88, (128, 512), is_leaf=True) # primals_89 + buf89 = reader.storage(None, 512, device=device(type='cuda', index=0)) + reader.tensor(buf89, (128,), is_leaf=True) # primals_90 + buf90 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf90, (56, 4096), is_leaf=True) # primals_91 + buf91 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf91, (56, 4096), is_leaf=True) # primals_92 + buf92 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf92, (128, 1792), is_leaf=True) # primals_93 + buf93 = reader.storage(None, 512, device=device(type='cuda', index=0)) + reader.tensor(buf93, (128,), is_leaf=True) # primals_94 + buf94 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf94, (16, 4096), is_leaf=True) # primals_95 + buf95 = 
reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf95, (4, 4096), is_leaf=True) # primals_96 + buf96 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf96, (4, 4096), is_leaf=True) # primals_97 + buf97 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf97, (128, 512), is_leaf=True) # primals_98 + buf98 = reader.storage(None, 512, device=device(type='cuda', index=0)) + reader.tensor(buf98, (128,), is_leaf=True) # primals_99 + buf99 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf99, (56, 4096), is_leaf=True) # primals_100 + buf100 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf100, (56, 4096), is_leaf=True) # primals_101 + buf101 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf101, (128, 1792), is_leaf=True) # primals_102 + buf102 = reader.storage(None, 512, device=device(type='cuda', index=0)) + reader.tensor(buf102, (128,), is_leaf=True) # primals_103 + buf103 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf103, (16, 4096), is_leaf=True) # primals_104 + buf104 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf104, (4, 4096), is_leaf=True) # primals_105 + buf105 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf105, (4, 4096), is_leaf=True) # primals_106 + buf106 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf106, (128, 512), is_leaf=True) # primals_107 + buf107 = reader.storage(None, 512, device=device(type='cuda', index=0)) + reader.tensor(buf107, (128,), is_leaf=True) # primals_108 + buf108 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf108, (56, 4096), is_leaf=True) # primals_109 + buf109 = reader.storage(None, 917504, device=device(type='cuda', index=0)) 
+ reader.tensor(buf109, (56, 4096), is_leaf=True) # primals_110 + buf110 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf110, (128, 1792), is_leaf=True) # primals_111 + buf111 = reader.storage(None, 512, device=device(type='cuda', index=0)) + reader.tensor(buf111, (128,), is_leaf=True) # primals_112 + buf112 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf112, (16, 4096), is_leaf=True) # primals_113 + buf113 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf113, (4, 4096), is_leaf=True) # primals_114 + buf114 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf114, (4, 4096), is_leaf=True) # primals_115 + buf115 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf115, (128, 512), is_leaf=True) # primals_116 + buf116 = reader.storage(None, 512, device=device(type='cuda', index=0)) + reader.tensor(buf116, (128,), is_leaf=True) # primals_117 + buf117 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf117, (56, 4096), is_leaf=True) # primals_118 + buf118 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf118, (56, 4096), is_leaf=True) # primals_119 + buf119 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf119, (128, 1792), is_leaf=True) # primals_120 + buf120 = reader.storage(None, 512, device=device(type='cuda', index=0)) + reader.tensor(buf120, (128,), is_leaf=True) # primals_121 + buf121 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf121, (16, 4096), is_leaf=True) # primals_122 + buf122 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf122, (4, 4096), is_leaf=True) # primals_123 + buf123 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf123, (4, 4096), is_leaf=True) # 
primals_124 + buf124 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf124, (128, 512), is_leaf=True) # primals_125 + buf125 = reader.storage(None, 512, device=device(type='cuda', index=0)) + reader.tensor(buf125, (128,), is_leaf=True) # primals_126 + buf126 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf126, (56, 4096), is_leaf=True) # primals_127 + buf127 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf127, (56, 4096), is_leaf=True) # primals_128 + buf128 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf128, (128, 1792), is_leaf=True) # primals_129 + buf129 = reader.storage(None, 512, device=device(type='cuda', index=0)) + reader.tensor(buf129, (128,), is_leaf=True) # primals_130 + buf130 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf130, (16, 4096), is_leaf=True) # primals_131 + buf131 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf131, (4, 4096), is_leaf=True) # primals_132 + buf132 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf132, (4, 4096), is_leaf=True) # primals_133 + buf133 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf133, (128, 512), is_leaf=True) # primals_134 + buf134 = reader.storage(None, 512, device=device(type='cuda', index=0)) + reader.tensor(buf134, (128,), is_leaf=True) # primals_135 + buf135 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf135, (56, 4096), is_leaf=True) # primals_136 + buf136 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf136, (56, 4096), is_leaf=True) # primals_137 + buf137 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf137, (128, 1792), is_leaf=True) # primals_138 + buf138 = reader.storage(None, 512, 
device=device(type='cuda', index=0)) + reader.tensor(buf138, (128,), is_leaf=True) # primals_139 + buf139 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf139, (16, 4096), is_leaf=True) # primals_140 + buf140 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf140, (4, 4096), is_leaf=True) # primals_141 + buf141 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf141, (4, 4096), is_leaf=True) # primals_142 + buf142 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf142, (128, 512), is_leaf=True) # primals_143 + buf143 = reader.storage(None, 512, device=device(type='cuda', index=0)) + reader.tensor(buf143, (128,), is_leaf=True) # primals_144 + buf144 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf144, (56, 4096), is_leaf=True) # primals_145 + buf145 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf145, (56, 4096), is_leaf=True) # primals_146 + buf146 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf146, (128, 1792), is_leaf=True) # primals_147 + buf147 = reader.storage(None, 512, device=device(type='cuda', index=0)) + reader.tensor(buf147, (128,), is_leaf=True) # primals_148 + buf148 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf148, (16, 4096), is_leaf=True) # primals_149 + buf149 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf149, (4, 4096), is_leaf=True) # primals_150 + buf150 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf150, (4, 4096), is_leaf=True) # primals_151 + buf151 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf151, (128, 512), is_leaf=True) # primals_152 + buf152 = reader.storage(None, 512, device=device(type='cuda', index=0)) + 
reader.tensor(buf152, (128,), is_leaf=True) # primals_153 + buf153 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf153, (56, 4096), is_leaf=True) # primals_154 + buf154 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf154, (56, 4096), is_leaf=True) # primals_155 + buf155 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf155, (128, 1792), is_leaf=True) # primals_156 + buf156 = reader.storage(None, 512, device=device(type='cuda', index=0)) + reader.tensor(buf156, (128,), is_leaf=True) # primals_157 + buf157 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf157, (16, 4096), is_leaf=True) # primals_158 + buf158 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf158, (4, 4096), is_leaf=True) # primals_159 + buf159 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf159, (4, 4096), is_leaf=True) # primals_160 + buf160 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf160, (128, 512), is_leaf=True) # primals_161 + buf161 = reader.storage(None, 512, device=device(type='cuda', index=0)) + reader.tensor(buf161, (128,), is_leaf=True) # primals_162 + buf162 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf162, (56, 4096), is_leaf=True) # primals_163 + buf163 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf163, (56, 4096), is_leaf=True) # primals_164 + buf164 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf164, (128, 1792), is_leaf=True) # primals_165 + buf165 = reader.storage(None, 512, device=device(type='cuda', index=0)) + reader.tensor(buf165, (128,), is_leaf=True) # primals_166 + buf166 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf166, (16, 4096), is_leaf=True) # 
primals_167 + buf167 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf167, (4, 4096), is_leaf=True) # primals_168 + buf168 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf168, (4, 4096), is_leaf=True) # primals_169 + buf169 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf169, (128, 512), is_leaf=True) # primals_170 + buf170 = reader.storage(None, 512, device=device(type='cuda', index=0)) + reader.tensor(buf170, (128,), is_leaf=True) # primals_171 + buf171 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf171, (56, 4096), is_leaf=True) # primals_172 + buf172 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf172, (56, 4096), is_leaf=True) # primals_173 + buf173 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf173, (128, 1792), is_leaf=True) # primals_174 + buf174 = reader.storage(None, 512, device=device(type='cuda', index=0)) + reader.tensor(buf174, (128,), is_leaf=True) # primals_175 + buf175 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf175, (16, 4096), is_leaf=True) # primals_176 + buf176 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf176, (4, 4096), is_leaf=True) # primals_177 + buf177 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf177, (4, 4096), is_leaf=True) # primals_178 + buf178 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf178, (128, 512), is_leaf=True) # primals_179 + buf179 = reader.storage(None, 512, device=device(type='cuda', index=0)) + reader.tensor(buf179, (128,), is_leaf=True) # primals_180 + buf180 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf180, (56, 4096), is_leaf=True) # primals_181 + buf181 = reader.storage(None, 917504, 
device=device(type='cuda', index=0)) + reader.tensor(buf181, (56, 4096), is_leaf=True) # primals_182 + buf182 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf182, (128, 1792), is_leaf=True) # primals_183 + buf183 = reader.storage(None, 512, device=device(type='cuda', index=0)) + reader.tensor(buf183, (128,), is_leaf=True) # primals_184 + buf184 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf184, (16, 4096), is_leaf=True) # primals_185 + buf185 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf185, (4, 4096), is_leaf=True) # primals_186 + buf186 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf186, (4, 4096), is_leaf=True) # primals_187 + buf187 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf187, (128, 512), is_leaf=True) # primals_188 + buf188 = reader.storage(None, 512, device=device(type='cuda', index=0)) + reader.tensor(buf188, (128,), is_leaf=True) # primals_189 + buf189 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf189, (56, 4096), is_leaf=True) # primals_190 + buf190 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf190, (56, 4096), is_leaf=True) # primals_191 + buf191 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf191, (128, 1792), is_leaf=True) # primals_192 + buf192 = reader.storage(None, 512, device=device(type='cuda', index=0)) + reader.tensor(buf192, (128,), is_leaf=True) # primals_193 + buf193 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf193, (16, 4096), is_leaf=True) # primals_194 + buf194 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf194, (4, 4096), is_leaf=True) # primals_195 + buf195 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + 
reader.tensor(buf195, (4, 4096), is_leaf=True) # primals_196 + buf196 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf196, (128, 512), is_leaf=True) # primals_197 + buf197 = reader.storage(None, 512, device=device(type='cuda', index=0)) + reader.tensor(buf197, (128,), is_leaf=True) # primals_198 + buf198 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf198, (56, 4096), is_leaf=True) # primals_199 + buf199 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf199, (56, 4096), is_leaf=True) # primals_200 + buf200 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf200, (128, 1792), is_leaf=True) # primals_201 + buf201 = reader.storage(None, 512, device=device(type='cuda', index=0)) + reader.tensor(buf201, (128,), is_leaf=True) # primals_202 + buf202 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf202, (16, 4096), is_leaf=True) # primals_203 + buf203 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf203, (4, 4096), is_leaf=True) # primals_204 + buf204 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf204, (4, 4096), is_leaf=True) # primals_205 + buf205 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf205, (128, 512), is_leaf=True) # primals_206 + buf206 = reader.storage(None, 512, device=device(type='cuda', index=0)) + reader.tensor(buf206, (128,), is_leaf=True) # primals_207 + buf207 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf207, (56, 4096), is_leaf=True) # primals_208 + buf208 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf208, (56, 4096), is_leaf=True) # primals_209 + buf209 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf209, (128, 1792), is_leaf=True) # 
primals_210 + buf210 = reader.storage(None, 512, device=device(type='cuda', index=0)) + reader.tensor(buf210, (128,), is_leaf=True) # primals_211 + buf211 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf211, (16, 4096), is_leaf=True) # primals_212 + buf212 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf212, (4, 4096), is_leaf=True) # primals_213 + buf213 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf213, (4, 4096), is_leaf=True) # primals_214 + buf214 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf214, (128, 512), is_leaf=True) # primals_215 + buf215 = reader.storage(None, 512, device=device(type='cuda', index=0)) + reader.tensor(buf215, (128,), is_leaf=True) # primals_216 + buf216 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf216, (56, 4096), is_leaf=True) # primals_217 + buf217 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf217, (56, 4096), is_leaf=True) # primals_218 + buf218 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf218, (128, 1792), is_leaf=True) # primals_219 + buf219 = reader.storage(None, 512, device=device(type='cuda', index=0)) + reader.tensor(buf219, (128,), is_leaf=True) # primals_220 + buf220 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf220, (16, 4096), is_leaf=True) # primals_221 + buf221 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf221, (4, 4096), is_leaf=True) # primals_222 + buf222 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf222, (4, 4096), is_leaf=True) # primals_223 + buf223 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf223, (128, 512), is_leaf=True) # primals_224 + buf224 = reader.storage(None, 512, 
device=device(type='cuda', index=0)) + reader.tensor(buf224, (128,), is_leaf=True) # primals_225 + buf225 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf225, (56, 4096), is_leaf=True) # primals_226 + buf226 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf226, (56, 4096), is_leaf=True) # primals_227 + buf227 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf227, (128, 1792), is_leaf=True) # primals_228 + buf228 = reader.storage(None, 512, device=device(type='cuda', index=0)) + reader.tensor(buf228, (128,), is_leaf=True) # primals_229 + buf229 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf229, (16, 4096), is_leaf=True) # primals_230 + buf230 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf230, (4, 4096), is_leaf=True) # primals_231 + buf231 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf231, (4, 4096), is_leaf=True) # primals_232 + buf232 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf232, (128, 512), is_leaf=True) # primals_233 + buf233 = reader.storage(None, 512, device=device(type='cuda', index=0)) + reader.tensor(buf233, (128,), is_leaf=True) # primals_234 + buf234 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf234, (56, 4096), is_leaf=True) # primals_235 + buf235 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf235, (56, 4096), is_leaf=True) # primals_236 + buf236 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf236, (128, 1792), is_leaf=True) # primals_237 + buf237 = reader.storage(None, 512, device=device(type='cuda', index=0)) + reader.tensor(buf237, (128,), is_leaf=True) # primals_238 + buf238 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + 
reader.tensor(buf238, (16, 4096), is_leaf=True) # primals_239 + buf239 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf239, (4, 4096), is_leaf=True) # primals_240 + buf240 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf240, (4, 4096), is_leaf=True) # primals_241 + buf241 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf241, (128, 512), is_leaf=True) # primals_242 + buf242 = reader.storage(None, 512, device=device(type='cuda', index=0)) + reader.tensor(buf242, (128,), is_leaf=True) # primals_243 + buf243 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf243, (56, 4096), is_leaf=True) # primals_244 + buf244 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf244, (56, 4096), is_leaf=True) # primals_245 + buf245 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf245, (128, 1792), is_leaf=True) # primals_246 + buf246 = reader.storage(None, 512, device=device(type='cuda', index=0)) + reader.tensor(buf246, (128,), is_leaf=True) # primals_247 + buf247 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf247, (16, 4096), is_leaf=True) # primals_248 + buf248 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf248, (4, 4096), is_leaf=True) # primals_249 + buf249 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf249, (4, 4096), is_leaf=True) # primals_250 + buf250 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf250, (128, 512), is_leaf=True) # primals_251 + buf251 = reader.storage(None, 512, device=device(type='cuda', index=0)) + reader.tensor(buf251, (128,), is_leaf=True) # primals_252 + buf252 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf252, (56, 4096), is_leaf=True) # 
primals_253 + buf253 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf253, (56, 4096), is_leaf=True) # primals_254 + buf254 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf254, (128, 1792), is_leaf=True) # primals_255 + buf255 = reader.storage(None, 512, device=device(type='cuda', index=0)) + reader.tensor(buf255, (128,), is_leaf=True) # primals_256 + buf256 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf256, (16, 4096), is_leaf=True) # primals_257 + buf257 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf257, (4, 4096), is_leaf=True) # primals_258 + buf258 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf258, (4, 4096), is_leaf=True) # primals_259 + buf259 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf259, (128, 512), is_leaf=True) # primals_260 + buf260 = reader.storage(None, 512, device=device(type='cuda', index=0)) + reader.tensor(buf260, (128,), is_leaf=True) # primals_261 + buf261 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf261, (56, 4096), is_leaf=True) # primals_262 + buf262 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf262, (56, 4096), is_leaf=True) # primals_263 + buf263 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf263, (128, 1792), is_leaf=True) # primals_264 + buf264 = reader.storage(None, 512, device=device(type='cuda', index=0)) + reader.tensor(buf264, (128,), is_leaf=True) # primals_265 + buf265 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf265, (16, 4096), is_leaf=True) # primals_266 + buf266 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf266, (4, 4096), is_leaf=True) # primals_267 + buf267 = reader.storage(None, 65536, 
device=device(type='cuda', index=0)) + reader.tensor(buf267, (4, 4096), is_leaf=True) # primals_268 + buf268 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf268, (128, 512), is_leaf=True) # primals_269 + buf269 = reader.storage(None, 512, device=device(type='cuda', index=0)) + reader.tensor(buf269, (128,), is_leaf=True) # primals_270 + buf270 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf270, (56, 4096), is_leaf=True) # primals_271 + buf271 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf271, (56, 4096), is_leaf=True) # primals_272 + buf272 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf272, (128, 1792), is_leaf=True) # primals_273 + buf273 = reader.storage(None, 512, device=device(type='cuda', index=0)) + reader.tensor(buf273, (128,), is_leaf=True) # primals_274 + buf274 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf274, (16, 4096), is_leaf=True) # primals_275 + buf275 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf275, (4, 4096), is_leaf=True) # primals_276 + buf276 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf276, (4, 4096), is_leaf=True) # primals_277 + buf277 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf277, (128, 512), is_leaf=True) # primals_278 + buf278 = reader.storage(None, 512, device=device(type='cuda', index=0)) + reader.tensor(buf278, (128,), is_leaf=True) # primals_279 + buf279 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf279, (56, 4096), is_leaf=True) # primals_280 + buf280 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf280, (56, 4096), is_leaf=True) # primals_281 + buf281 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + 
reader.tensor(buf281, (128, 1792), is_leaf=True) # primals_282 + buf282 = reader.storage(None, 512, device=device(type='cuda', index=0)) + reader.tensor(buf282, (128,), is_leaf=True) # primals_283 + buf283 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf283, (16, 4096), is_leaf=True) # primals_284 + buf284 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf284, (4, 4096), is_leaf=True) # primals_285 + buf285 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf285, (4, 4096), is_leaf=True) # primals_286 + buf286 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf286, (128, 512), is_leaf=True) # primals_287 + buf287 = reader.storage(None, 512, device=device(type='cuda', index=0)) + reader.tensor(buf287, (128,), is_leaf=True) # primals_288 + buf288 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf288, (56, 4096), is_leaf=True) # primals_289 + buf289 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf289, (56, 4096), is_leaf=True) # primals_290 + buf290 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf290, (128, 1792), is_leaf=True) # primals_291 + buf291 = reader.storage(None, 512, device=device(type='cuda', index=0)) + reader.tensor(buf291, (128,), is_leaf=True) # primals_292 + buf292 = reader.storage(None, 8208384, device=device(type='cuda', index=0)) + reader.tensor(buf292, (501, 4096), is_leaf=True) # primals_293 + buf293 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf293, (2, 1024, 4096), dtype=torch.bfloat16, is_leaf=True) # wait_tensor_1 + buf294 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf294, (16384, 512), dtype=torch.bfloat16, is_leaf=True) # mm + buf295 = 
reader.storage(None, 4194304, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf295, (16384, 128), dtype=torch.bfloat16, is_leaf=True) # mm_2 + buf296 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf296, (2, 4, 8192, 128), (4194304, 128, 512, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_80 + buf297 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf297, (2, 4, 8192, 1), is_leaf=True) # getitem_81 + buf298 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf298, (), dtype=torch.int64, is_leaf=True) # getitem_86 + buf299 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf299, (), dtype=torch.int64, is_leaf=True) # getitem_87 + buf300 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf300, (2, 1024, 4096), dtype=torch.bfloat16, is_leaf=True) # reduce_scatter_tensor_1 + buf301 = reader.storage(None, 58720256, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf301, (16384, 1792), dtype=torch.bfloat16, is_leaf=True) # mm_4 + buf302 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf302, (2, 1024, 4096), dtype=torch.bfloat16, is_leaf=True) # add_3 + buf303 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf303, (16384, 512), dtype=torch.bfloat16, is_leaf=True) # mm_7 + buf304 = reader.storage(None, 4194304, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf304, (16384, 128), dtype=torch.bfloat16, is_leaf=True) # mm_9 + buf305 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf305, (2, 4, 8192, 128), 
(4194304, 128, 512, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_121 + buf306 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf306, (2, 4, 8192, 1), is_leaf=True) # getitem_122 + buf307 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf307, (), dtype=torch.int64, is_leaf=True) # getitem_127 + buf308 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf308, (), dtype=torch.int64, is_leaf=True) # getitem_128 + buf309 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf309, (2, 1024, 4096), dtype=torch.bfloat16, is_leaf=True) # reduce_scatter_tensor_3 + buf310 = reader.storage(None, 58720256, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf310, (16384, 1792), dtype=torch.bfloat16, is_leaf=True) # mm_11 + buf311 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf311, (2, 1024, 4096), dtype=torch.bfloat16, is_leaf=True) # add_7 + buf312 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf312, (16384, 512), dtype=torch.bfloat16, is_leaf=True) # mm_14 + buf313 = reader.storage(None, 4194304, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf313, (16384, 128), dtype=torch.bfloat16, is_leaf=True) # mm_16 + buf314 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf314, (2, 4, 8192, 128), (4194304, 128, 512, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_162 + buf315 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf315, (2, 4, 8192, 1), is_leaf=True) # getitem_163 + buf316 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + 
reader.tensor(buf316, (), dtype=torch.int64, is_leaf=True) # getitem_168 + buf317 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf317, (), dtype=torch.int64, is_leaf=True) # getitem_169 + buf318 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf318, (2, 1024, 4096), dtype=torch.bfloat16, is_leaf=True) # reduce_scatter_tensor_5 + buf319 = reader.storage(None, 58720256, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf319, (16384, 1792), dtype=torch.bfloat16, is_leaf=True) # mm_18 + buf320 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf320, (2, 1024, 4096), dtype=torch.bfloat16, is_leaf=True) # add_11 + buf321 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf321, (16384, 512), dtype=torch.bfloat16, is_leaf=True) # mm_21 + buf322 = reader.storage(None, 4194304, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf322, (16384, 128), dtype=torch.bfloat16, is_leaf=True) # mm_23 + buf323 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf323, (2, 4, 8192, 128), (4194304, 128, 512, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_203 + buf324 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf324, (2, 4, 8192, 1), is_leaf=True) # getitem_204 + buf325 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf325, (), dtype=torch.int64, is_leaf=True) # getitem_209 + buf326 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf326, (), dtype=torch.int64, is_leaf=True) # getitem_210 + buf327 = reader.storage(None, 16777216, device=device(type='cuda', index=0), 
dtype_hint=torch.bfloat16) + reader.tensor(buf327, (2, 1024, 4096), dtype=torch.bfloat16, is_leaf=True) # reduce_scatter_tensor_7 + buf328 = reader.storage(None, 58720256, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf328, (16384, 1792), dtype=torch.bfloat16, is_leaf=True) # mm_25 + buf329 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf329, (2, 1024, 4096), dtype=torch.bfloat16, is_leaf=True) # add_15 + buf330 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf330, (16384, 512), dtype=torch.bfloat16, is_leaf=True) # mm_28 + buf331 = reader.storage(None, 4194304, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf331, (16384, 128), dtype=torch.bfloat16, is_leaf=True) # mm_30 + buf332 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf332, (2, 4, 8192, 128), (4194304, 128, 512, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_244 + buf333 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf333, (2, 4, 8192, 1), is_leaf=True) # getitem_245 + buf334 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf334, (), dtype=torch.int64, is_leaf=True) # getitem_250 + buf335 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf335, (), dtype=torch.int64, is_leaf=True) # getitem_251 + buf336 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf336, (2, 1024, 4096), dtype=torch.bfloat16, is_leaf=True) # reduce_scatter_tensor_9 + buf337 = reader.storage(None, 58720256, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf337, (16384, 1792), dtype=torch.bfloat16, is_leaf=True) # mm_32 + 
buf338 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf338, (2, 1024, 4096), dtype=torch.bfloat16, is_leaf=True) # add_19 + buf339 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf339, (16384, 512), dtype=torch.bfloat16, is_leaf=True) # mm_35 + buf340 = reader.storage(None, 4194304, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf340, (16384, 128), dtype=torch.bfloat16, is_leaf=True) # mm_37 + buf341 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf341, (2, 4, 8192, 128), (4194304, 128, 512, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_285 + buf342 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf342, (2, 4, 8192, 1), is_leaf=True) # getitem_286 + buf343 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf343, (), dtype=torch.int64, is_leaf=True) # getitem_291 + buf344 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf344, (), dtype=torch.int64, is_leaf=True) # getitem_292 + buf345 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf345, (2, 1024, 4096), dtype=torch.bfloat16, is_leaf=True) # reduce_scatter_tensor_11 + buf346 = reader.storage(None, 58720256, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf346, (16384, 1792), dtype=torch.bfloat16, is_leaf=True) # mm_39 + buf347 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf347, (2, 1024, 4096), dtype=torch.bfloat16, is_leaf=True) # add_23 + buf348 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf348, 
(16384, 512), dtype=torch.bfloat16, is_leaf=True) # mm_42 + buf349 = reader.storage(None, 4194304, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf349, (16384, 128), dtype=torch.bfloat16, is_leaf=True) # mm_44 + buf350 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf350, (2, 4, 8192, 128), (4194304, 128, 512, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_326 + buf351 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf351, (2, 4, 8192, 1), is_leaf=True) # getitem_327 + buf352 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf352, (), dtype=torch.int64, is_leaf=True) # getitem_332 + buf353 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf353, (), dtype=torch.int64, is_leaf=True) # getitem_333 + buf354 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf354, (2, 1024, 4096), dtype=torch.bfloat16, is_leaf=True) # reduce_scatter_tensor_13 + buf355 = reader.storage(None, 58720256, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf355, (16384, 1792), dtype=torch.bfloat16, is_leaf=True) # mm_46 + buf356 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf356, (2, 1024, 4096), dtype=torch.bfloat16, is_leaf=True) # add_27 + buf357 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf357, (16384, 512), dtype=torch.bfloat16, is_leaf=True) # mm_49 + buf358 = reader.storage(None, 4194304, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf358, (16384, 128), dtype=torch.bfloat16, is_leaf=True) # mm_51 + buf359 = reader.storage(None, 16777216, device=device(type='cuda', 
index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf359, (2, 4, 8192, 128), (4194304, 128, 512, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_367 + buf360 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf360, (2, 4, 8192, 1), is_leaf=True) # getitem_368 + buf361 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf361, (), dtype=torch.int64, is_leaf=True) # getitem_373 + buf362 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf362, (), dtype=torch.int64, is_leaf=True) # getitem_374 + buf363 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf363, (2, 1024, 4096), dtype=torch.bfloat16, is_leaf=True) # reduce_scatter_tensor_15 + buf364 = reader.storage(None, 58720256, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf364, (16384, 1792), dtype=torch.bfloat16, is_leaf=True) # mm_53 + buf365 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf365, (2, 1024, 4096), dtype=torch.bfloat16, is_leaf=True) # add_31 + buf366 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf366, (16384, 512), dtype=torch.bfloat16, is_leaf=True) # mm_56 + buf367 = reader.storage(None, 4194304, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf367, (16384, 128), dtype=torch.bfloat16, is_leaf=True) # mm_58 + buf368 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf368, (2, 4, 8192, 128), (4194304, 128, 512, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_408 + buf369 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf369, (2, 4, 8192, 1), is_leaf=True) # getitem_409 + buf370 = 
reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf370, (), dtype=torch.int64, is_leaf=True) # getitem_414 + buf371 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf371, (), dtype=torch.int64, is_leaf=True) # getitem_415 + buf372 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf372, (2, 1024, 4096), dtype=torch.bfloat16, is_leaf=True) # reduce_scatter_tensor_17 + buf373 = reader.storage(None, 58720256, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf373, (16384, 1792), dtype=torch.bfloat16, is_leaf=True) # mm_60 + buf374 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf374, (2, 1024, 4096), dtype=torch.bfloat16, is_leaf=True) # add_35 + buf375 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf375, (16384, 512), dtype=torch.bfloat16, is_leaf=True) # mm_63 + buf376 = reader.storage(None, 4194304, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf376, (16384, 128), dtype=torch.bfloat16, is_leaf=True) # mm_65 + buf377 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf377, (2, 4, 8192, 128), (4194304, 128, 512, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_449 + buf378 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf378, (2, 4, 8192, 1), is_leaf=True) # getitem_450 + buf379 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf379, (), dtype=torch.int64, is_leaf=True) # getitem_455 + buf380 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf380, (), dtype=torch.int64, is_leaf=True) # 
getitem_456 + buf381 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf381, (2, 1024, 4096), dtype=torch.bfloat16, is_leaf=True) # reduce_scatter_tensor_19 + buf382 = reader.storage(None, 58720256, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf382, (16384, 1792), dtype=torch.bfloat16, is_leaf=True) # mm_67 + buf383 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf383, (2, 1024, 4096), dtype=torch.bfloat16, is_leaf=True) # add_39 + buf384 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf384, (16384, 512), dtype=torch.bfloat16, is_leaf=True) # mm_70 + buf385 = reader.storage(None, 4194304, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf385, (16384, 128), dtype=torch.bfloat16, is_leaf=True) # mm_72 + buf386 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf386, (2, 4, 8192, 128), (4194304, 128, 512, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_490 + buf387 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf387, (2, 4, 8192, 1), is_leaf=True) # getitem_491 + buf388 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf388, (), dtype=torch.int64, is_leaf=True) # getitem_496 + buf389 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf389, (), dtype=torch.int64, is_leaf=True) # getitem_497 + buf390 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf390, (2, 1024, 4096), dtype=torch.bfloat16, is_leaf=True) # reduce_scatter_tensor_21 + buf391 = reader.storage(None, 58720256, device=device(type='cuda', index=0), 
dtype_hint=torch.bfloat16) + reader.tensor(buf391, (16384, 1792), dtype=torch.bfloat16, is_leaf=True) # mm_74 + buf392 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf392, (2, 1024, 4096), dtype=torch.bfloat16, is_leaf=True) # add_43 + buf393 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf393, (16384, 512), dtype=torch.bfloat16, is_leaf=True) # mm_77 + buf394 = reader.storage(None, 4194304, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf394, (16384, 128), dtype=torch.bfloat16, is_leaf=True) # mm_79 + buf395 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf395, (2, 4, 8192, 128), (4194304, 128, 512, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_531 + buf396 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf396, (2, 4, 8192, 1), is_leaf=True) # getitem_532 + buf397 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf397, (), dtype=torch.int64, is_leaf=True) # getitem_537 + buf398 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf398, (), dtype=torch.int64, is_leaf=True) # getitem_538 + buf399 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf399, (2, 1024, 4096), dtype=torch.bfloat16, is_leaf=True) # reduce_scatter_tensor_23 + buf400 = reader.storage(None, 58720256, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf400, (16384, 1792), dtype=torch.bfloat16, is_leaf=True) # mm_81 + buf401 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf401, (2, 1024, 4096), dtype=torch.bfloat16, is_leaf=True) # add_47 + buf402 = 
reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf402, (16384, 512), dtype=torch.bfloat16, is_leaf=True) # mm_84 + buf403 = reader.storage(None, 4194304, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf403, (16384, 128), dtype=torch.bfloat16, is_leaf=True) # mm_86 + buf404 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf404, (2, 4, 8192, 128), (4194304, 128, 512, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_572 + buf405 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf405, (2, 4, 8192, 1), is_leaf=True) # getitem_573 + buf406 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf406, (), dtype=torch.int64, is_leaf=True) # getitem_578 + buf407 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf407, (), dtype=torch.int64, is_leaf=True) # getitem_579 + buf408 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf408, (2, 1024, 4096), dtype=torch.bfloat16, is_leaf=True) # reduce_scatter_tensor_25 + buf409 = reader.storage(None, 58720256, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf409, (16384, 1792), dtype=torch.bfloat16, is_leaf=True) # mm_88 + buf410 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf410, (2, 1024, 4096), dtype=torch.bfloat16, is_leaf=True) # add_51 + buf411 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf411, (16384, 512), dtype=torch.bfloat16, is_leaf=True) # mm_91 + buf412 = reader.storage(None, 4194304, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf412, (16384, 128), 
dtype=torch.bfloat16, is_leaf=True) # mm_93 + buf413 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf413, (2, 4, 8192, 128), (4194304, 128, 512, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_613 + buf414 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf414, (2, 4, 8192, 1), is_leaf=True) # getitem_614 + buf415 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf415, (), dtype=torch.int64, is_leaf=True) # getitem_619 + buf416 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf416, (), dtype=torch.int64, is_leaf=True) # getitem_620 + buf417 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf417, (2, 1024, 4096), dtype=torch.bfloat16, is_leaf=True) # reduce_scatter_tensor_27 + buf418 = reader.storage(None, 58720256, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf418, (16384, 1792), dtype=torch.bfloat16, is_leaf=True) # mm_95 + buf419 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf419, (2, 1024, 4096), dtype=torch.bfloat16, is_leaf=True) # add_55 + buf420 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf420, (16384, 512), dtype=torch.bfloat16, is_leaf=True) # mm_98 + buf421 = reader.storage(None, 4194304, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf421, (16384, 128), dtype=torch.bfloat16, is_leaf=True) # mm_100 + buf422 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf422, (2, 4, 8192, 128), (4194304, 128, 512, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_654 + buf423 = reader.storage(None, 262144, 
device=device(type='cuda', index=0)) + reader.tensor(buf423, (2, 4, 8192, 1), is_leaf=True) # getitem_655 + buf424 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf424, (), dtype=torch.int64, is_leaf=True) # getitem_660 + buf425 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf425, (), dtype=torch.int64, is_leaf=True) # getitem_661 + buf426 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf426, (2, 1024, 4096), dtype=torch.bfloat16, is_leaf=True) # reduce_scatter_tensor_29 + buf427 = reader.storage(None, 58720256, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf427, (16384, 1792), dtype=torch.bfloat16, is_leaf=True) # mm_102 + buf428 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf428, (2, 1024, 4096), dtype=torch.bfloat16, is_leaf=True) # add_59 + buf429 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf429, (16384, 512), dtype=torch.bfloat16, is_leaf=True) # mm_105 + buf430 = reader.storage(None, 4194304, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf430, (16384, 128), dtype=torch.bfloat16, is_leaf=True) # mm_107 + buf431 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf431, (2, 4, 8192, 128), (4194304, 128, 512, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_695 + buf432 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf432, (2, 4, 8192, 1), is_leaf=True) # getitem_696 + buf433 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf433, (), dtype=torch.int64, is_leaf=True) # getitem_701 + buf434 = reader.storage(None, 8, 
device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf434, (), dtype=torch.int64, is_leaf=True) # getitem_702 + buf435 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf435, (2, 1024, 4096), dtype=torch.bfloat16, is_leaf=True) # reduce_scatter_tensor_31 + buf436 = reader.storage(None, 58720256, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf436, (16384, 1792), dtype=torch.bfloat16, is_leaf=True) # mm_109 + buf437 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf437, (2, 1024, 4096), dtype=torch.bfloat16, is_leaf=True) # add_63 + buf438 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf438, (16384, 512), dtype=torch.bfloat16, is_leaf=True) # mm_112 + buf439 = reader.storage(None, 4194304, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf439, (16384, 128), dtype=torch.bfloat16, is_leaf=True) # mm_114 + buf440 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf440, (2, 4, 8192, 128), (4194304, 128, 512, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_736 + buf441 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf441, (2, 4, 8192, 1), is_leaf=True) # getitem_737 + buf442 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf442, (), dtype=torch.int64, is_leaf=True) # getitem_742 + buf443 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf443, (), dtype=torch.int64, is_leaf=True) # getitem_743 + buf444 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf444, (2, 1024, 4096), dtype=torch.bfloat16, is_leaf=True) 
# reduce_scatter_tensor_33 + buf445 = reader.storage(None, 58720256, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf445, (16384, 1792), dtype=torch.bfloat16, is_leaf=True) # mm_116 + buf446 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf446, (2, 1024, 4096), dtype=torch.bfloat16, is_leaf=True) # add_67 + buf447 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf447, (16384, 512), dtype=torch.bfloat16, is_leaf=True) # mm_119 + buf448 = reader.storage(None, 4194304, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf448, (16384, 128), dtype=torch.bfloat16, is_leaf=True) # mm_121 + buf449 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf449, (2, 4, 8192, 128), (4194304, 128, 512, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_777 + buf450 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf450, (2, 4, 8192, 1), is_leaf=True) # getitem_778 + buf451 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf451, (), dtype=torch.int64, is_leaf=True) # getitem_783 + buf452 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf452, (), dtype=torch.int64, is_leaf=True) # getitem_784 + buf453 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf453, (2, 1024, 4096), dtype=torch.bfloat16, is_leaf=True) # reduce_scatter_tensor_35 + buf454 = reader.storage(None, 58720256, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf454, (16384, 1792), dtype=torch.bfloat16, is_leaf=True) # mm_123 + buf455 = reader.storage(None, 16777216, device=device(type='cuda', index=0), 
dtype_hint=torch.bfloat16) + reader.tensor(buf455, (2, 1024, 4096), dtype=torch.bfloat16, is_leaf=True) # add_71 + buf456 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf456, (16384, 512), dtype=torch.bfloat16, is_leaf=True) # mm_126 + buf457 = reader.storage(None, 4194304, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf457, (16384, 128), dtype=torch.bfloat16, is_leaf=True) # mm_128 + buf458 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf458, (2, 4, 8192, 128), (4194304, 128, 512, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_818 + buf459 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf459, (2, 4, 8192, 1), is_leaf=True) # getitem_819 + buf460 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf460, (), dtype=torch.int64, is_leaf=True) # getitem_824 + buf461 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf461, (), dtype=torch.int64, is_leaf=True) # getitem_825 + buf462 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf462, (2, 1024, 4096), dtype=torch.bfloat16, is_leaf=True) # reduce_scatter_tensor_37 + buf463 = reader.storage(None, 58720256, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf463, (16384, 1792), dtype=torch.bfloat16, is_leaf=True) # mm_130 + buf464 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf464, (2, 1024, 4096), dtype=torch.bfloat16, is_leaf=True) # add_75 + buf465 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf465, (16384, 512), dtype=torch.bfloat16, is_leaf=True) # mm_133 + buf466 = 
reader.storage(None, 4194304, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf466, (16384, 128), dtype=torch.bfloat16, is_leaf=True) # mm_135 + buf467 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf467, (2, 4, 8192, 128), (4194304, 128, 512, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_859 + buf468 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf468, (2, 4, 8192, 1), is_leaf=True) # getitem_860 + buf469 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf469, (), dtype=torch.int64, is_leaf=True) # getitem_865 + buf470 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf470, (), dtype=torch.int64, is_leaf=True) # getitem_866 + buf471 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf471, (2, 1024, 4096), dtype=torch.bfloat16, is_leaf=True) # reduce_scatter_tensor_39 + buf472 = reader.storage(None, 58720256, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf472, (16384, 1792), dtype=torch.bfloat16, is_leaf=True) # mm_137 + buf473 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf473, (2, 1024, 4096), dtype=torch.bfloat16, is_leaf=True) # add_79 + buf474 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf474, (16384, 512), dtype=torch.bfloat16, is_leaf=True) # mm_140 + buf475 = reader.storage(None, 4194304, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf475, (16384, 128), dtype=torch.bfloat16, is_leaf=True) # mm_142 + buf476 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf476, (2, 4, 8192, 
128), (4194304, 128, 512, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_900 + buf477 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf477, (2, 4, 8192, 1), is_leaf=True) # getitem_901 + buf478 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf478, (), dtype=torch.int64, is_leaf=True) # getitem_906 + buf479 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf479, (), dtype=torch.int64, is_leaf=True) # getitem_907 + buf480 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf480, (2, 1024, 4096), dtype=torch.bfloat16, is_leaf=True) # reduce_scatter_tensor_41 + buf481 = reader.storage(None, 58720256, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf481, (16384, 1792), dtype=torch.bfloat16, is_leaf=True) # mm_144 + buf482 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf482, (2, 1024, 4096), dtype=torch.bfloat16, is_leaf=True) # add_83 + buf483 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf483, (16384, 512), dtype=torch.bfloat16, is_leaf=True) # mm_147 + buf484 = reader.storage(None, 4194304, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf484, (16384, 128), dtype=torch.bfloat16, is_leaf=True) # mm_149 + buf485 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf485, (2, 4, 8192, 128), (4194304, 128, 512, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_941 + buf486 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf486, (2, 4, 8192, 1), is_leaf=True) # getitem_942 + buf487 = reader.storage(None, 8, device=device(type='cuda', index=0), 
dtype_hint=torch.int64) + reader.tensor(buf487, (), dtype=torch.int64, is_leaf=True) # getitem_947 + buf488 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf488, (), dtype=torch.int64, is_leaf=True) # getitem_948 + buf489 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf489, (2, 1024, 4096), dtype=torch.bfloat16, is_leaf=True) # reduce_scatter_tensor_43 + buf490 = reader.storage(None, 58720256, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf490, (16384, 1792), dtype=torch.bfloat16, is_leaf=True) # mm_151 + buf491 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf491, (2, 1024, 4096), dtype=torch.bfloat16, is_leaf=True) # add_87 + buf492 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf492, (16384, 512), dtype=torch.bfloat16, is_leaf=True) # mm_154 + buf493 = reader.storage(None, 4194304, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf493, (16384, 128), dtype=torch.bfloat16, is_leaf=True) # mm_156 + buf494 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf494, (2, 4, 8192, 128), (4194304, 128, 512, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_982 + buf495 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf495, (2, 4, 8192, 1), is_leaf=True) # getitem_983 + buf496 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf496, (), dtype=torch.int64, is_leaf=True) # getitem_988 + buf497 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf497, (), dtype=torch.int64, is_leaf=True) # getitem_989 + buf498 = reader.storage(None, 16777216, 
device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf498, (2, 1024, 4096), dtype=torch.bfloat16, is_leaf=True) # reduce_scatter_tensor_45 + buf499 = reader.storage(None, 58720256, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf499, (16384, 1792), dtype=torch.bfloat16, is_leaf=True) # mm_158 + buf500 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf500, (2, 1024, 4096), dtype=torch.bfloat16, is_leaf=True) # add_91 + buf501 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf501, (16384, 512), dtype=torch.bfloat16, is_leaf=True) # mm_161 + buf502 = reader.storage(None, 4194304, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf502, (16384, 128), dtype=torch.bfloat16, is_leaf=True) # mm_163 + buf503 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf503, (2, 4, 8192, 128), (4194304, 128, 512, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_1023 + buf504 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf504, (2, 4, 8192, 1), is_leaf=True) # getitem_1024 + buf505 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf505, (), dtype=torch.int64, is_leaf=True) # getitem_1029 + buf506 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf506, (), dtype=torch.int64, is_leaf=True) # getitem_1030 + buf507 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf507, (2, 1024, 4096), dtype=torch.bfloat16, is_leaf=True) # reduce_scatter_tensor_47 + buf508 = reader.storage(None, 58720256, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf508, (16384, 1792), 
dtype=torch.bfloat16, is_leaf=True) # mm_165 + buf509 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf509, (2, 1024, 4096), dtype=torch.bfloat16, is_leaf=True) # add_95 + buf510 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf510, (16384, 512), dtype=torch.bfloat16, is_leaf=True) # mm_168 + buf511 = reader.storage(None, 4194304, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf511, (16384, 128), dtype=torch.bfloat16, is_leaf=True) # mm_170 + buf512 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf512, (2, 4, 8192, 128), (4194304, 128, 512, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_1064 + buf513 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf513, (2, 4, 8192, 1), is_leaf=True) # getitem_1065 + buf514 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf514, (), dtype=torch.int64, is_leaf=True) # getitem_1070 + buf515 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf515, (), dtype=torch.int64, is_leaf=True) # getitem_1071 + buf516 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf516, (2, 1024, 4096), dtype=torch.bfloat16, is_leaf=True) # reduce_scatter_tensor_49 + buf517 = reader.storage(None, 58720256, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf517, (16384, 1792), dtype=torch.bfloat16, is_leaf=True) # mm_172 + buf518 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf518, (2, 1024, 4096), dtype=torch.bfloat16, is_leaf=True) # add_99 + buf519 = reader.storage(None, 16777216, device=device(type='cuda', index=0), 
dtype_hint=torch.bfloat16) + reader.tensor(buf519, (16384, 512), dtype=torch.bfloat16, is_leaf=True) # mm_175 + buf520 = reader.storage(None, 4194304, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf520, (16384, 128), dtype=torch.bfloat16, is_leaf=True) # mm_177 + buf521 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf521, (2, 4, 8192, 128), (4194304, 128, 512, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_1105 + buf522 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf522, (2, 4, 8192, 1), is_leaf=True) # getitem_1106 + buf523 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf523, (), dtype=torch.int64, is_leaf=True) # getitem_1111 + buf524 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf524, (), dtype=torch.int64, is_leaf=True) # getitem_1112 + buf525 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf525, (2, 1024, 4096), dtype=torch.bfloat16, is_leaf=True) # reduce_scatter_tensor_51 + buf526 = reader.storage(None, 58720256, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf526, (16384, 1792), dtype=torch.bfloat16, is_leaf=True) # mm_179 + buf527 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf527, (2, 1024, 4096), dtype=torch.bfloat16, is_leaf=True) # add_103 + buf528 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf528, (16384, 512), dtype=torch.bfloat16, is_leaf=True) # mm_182 + buf529 = reader.storage(None, 4194304, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf529, (16384, 128), dtype=torch.bfloat16, is_leaf=True) # mm_184 + buf530 = 
reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf530, (2, 4, 8192, 128), (4194304, 128, 512, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_1146 + buf531 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf531, (2, 4, 8192, 1), is_leaf=True) # getitem_1147 + buf532 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf532, (), dtype=torch.int64, is_leaf=True) # getitem_1152 + buf533 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf533, (), dtype=torch.int64, is_leaf=True) # getitem_1153 + buf534 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf534, (2, 1024, 4096), dtype=torch.bfloat16, is_leaf=True) # reduce_scatter_tensor_53 + buf535 = reader.storage(None, 58720256, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf535, (16384, 1792), dtype=torch.bfloat16, is_leaf=True) # mm_186 + buf536 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf536, (2, 1024, 4096), dtype=torch.bfloat16, is_leaf=True) # add_107 + buf537 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf537, (16384, 512), dtype=torch.bfloat16, is_leaf=True) # mm_189 + buf538 = reader.storage(None, 4194304, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf538, (16384, 128), dtype=torch.bfloat16, is_leaf=True) # mm_191 + buf539 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf539, (2, 4, 8192, 128), (4194304, 128, 512, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_1187 + buf540 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf540, 
(2, 4, 8192, 1), is_leaf=True) # getitem_1188 + buf541 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf541, (), dtype=torch.int64, is_leaf=True) # getitem_1193 + buf542 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf542, (), dtype=torch.int64, is_leaf=True) # getitem_1194 + buf543 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf543, (2, 1024, 4096), dtype=torch.bfloat16, is_leaf=True) # reduce_scatter_tensor_55 + buf544 = reader.storage(None, 58720256, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf544, (16384, 1792), dtype=torch.bfloat16, is_leaf=True) # mm_193 + buf545 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf545, (2, 1024, 4096), dtype=torch.bfloat16, is_leaf=True) # add_111 + buf546 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf546, (16384, 512), dtype=torch.bfloat16, is_leaf=True) # mm_196 + buf547 = reader.storage(None, 4194304, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf547, (16384, 128), dtype=torch.bfloat16, is_leaf=True) # mm_198 + buf548 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf548, (2, 4, 8192, 128), (4194304, 128, 512, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_1228 + buf549 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf549, (2, 4, 8192, 1), is_leaf=True) # getitem_1229 + buf550 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf550, (), dtype=torch.int64, is_leaf=True) # getitem_1234 + buf551 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) 
+ reader.tensor(buf551, (), dtype=torch.int64, is_leaf=True) # getitem_1235 + buf552 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf552, (2, 1024, 4096), dtype=torch.bfloat16, is_leaf=True) # reduce_scatter_tensor_57 + buf553 = reader.storage(None, 58720256, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf553, (16384, 1792), dtype=torch.bfloat16, is_leaf=True) # mm_200 + buf554 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf554, (2, 1024, 4096), dtype=torch.bfloat16, is_leaf=True) # add_115 + buf555 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf555, (16384, 512), dtype=torch.bfloat16, is_leaf=True) # mm_203 + buf556 = reader.storage(None, 4194304, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf556, (16384, 128), dtype=torch.bfloat16, is_leaf=True) # mm_205 + buf557 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf557, (2, 4, 8192, 128), (4194304, 128, 512, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_1269 + buf558 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf558, (2, 4, 8192, 1), is_leaf=True) # getitem_1270 + buf559 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf559, (), dtype=torch.int64, is_leaf=True) # getitem_1275 + buf560 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf560, (), dtype=torch.int64, is_leaf=True) # getitem_1276 + buf561 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf561, (2, 1024, 4096), dtype=torch.bfloat16, is_leaf=True) # reduce_scatter_tensor_59 + buf562 = 
reader.storage(None, 58720256, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf562, (16384, 1792), dtype=torch.bfloat16, is_leaf=True) # mm_207 + buf563 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf563, (2, 1024, 4096), dtype=torch.bfloat16, is_leaf=True) # add_119 + buf564 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf564, (16384, 512), dtype=torch.bfloat16, is_leaf=True) # mm_210 + buf565 = reader.storage(None, 4194304, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf565, (16384, 128), dtype=torch.bfloat16, is_leaf=True) # mm_212 + buf566 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf566, (2, 4, 8192, 128), (4194304, 128, 512, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_1310 + buf567 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf567, (2, 4, 8192, 1), is_leaf=True) # getitem_1311 + buf568 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf568, (), dtype=torch.int64, is_leaf=True) # getitem_1316 + buf569 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf569, (), dtype=torch.int64, is_leaf=True) # getitem_1317 + buf570 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf570, (2, 1024, 4096), dtype=torch.bfloat16, is_leaf=True) # reduce_scatter_tensor_61 + buf571 = reader.storage(None, 58720256, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf571, (16384, 1792), dtype=torch.bfloat16, is_leaf=True) # mm_214 + buf572 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf572, (2, 
1024, 4096), dtype=torch.bfloat16, is_leaf=True) # add_123 + buf573 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf573, (16384, 512), dtype=torch.bfloat16, is_leaf=True) # mm_217 + buf574 = reader.storage(None, 4194304, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf574, (16384, 128), dtype=torch.bfloat16, is_leaf=True) # mm_219 + buf575 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf575, (2, 4, 8192, 128), (4194304, 128, 512, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_1351 + buf576 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf576, (2, 4, 8192, 1), is_leaf=True) # getitem_1352 + buf577 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf577, (), dtype=torch.int64, is_leaf=True) # getitem_1357 + buf578 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf578, (), dtype=torch.int64, is_leaf=True) # getitem_1358 + buf579 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf579, (2, 1024, 4096), dtype=torch.bfloat16, is_leaf=True) # reduce_scatter_tensor_63 + buf580 = reader.storage(None, 58720256, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf580, (16384, 1792), dtype=torch.bfloat16, is_leaf=True) # mm_221 + buf581 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf581, (2, 1024, 4096), dtype=torch.bfloat16, is_leaf=True) # reduce_scatter_tensor_64 + buf582 = reader.storage(None, 8192, device=device(type='cuda', index=0)) + reader.tensor(buf582, (2, 1024, 1), is_leaf=True) # rsqrt_64 + buf583 = reader.storage(None, 134217728, device=device(type='cuda', index=0), 
dtype_hint=torch.bfloat16) + reader.tensor(buf583, (16384, 4096), dtype=torch.bfloat16, is_leaf=True) # view_2319 + buf584 = reader.storage(None, 525336576, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf584, (2, 8192, 16032), dtype=torch.bfloat16, is_leaf=True) # tangents_1 +load_args._version = 0 + +def get_mesh_sizes(): + return 32, 8 + +def get_colls_estimations_file(): + return "colls32_8.table" + +def get_pg_names(): + return "0", "1" diff --git a/autoparallel/tools/overlap_simulator/repro_llama3_8b_bw_64_1d_32layers.py b/autoparallel/tools/overlap_simulator/repro_llama3_8b_bw_64_1d_32layers.py new file mode 100644 index 00000000..754c23f8 --- /dev/null +++ b/autoparallel/tools/overlap_simulator/repro_llama3_8b_bw_64_1d_32layers.py @@ -0,0 +1,8953 @@ +import torch +from torch.nn import * +from torch import tensor, device + + +class Repro(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + + def forward(self, primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8, primals_9, primals_10, primals_11, primals_12, primals_13, primals_14, primals_15, primals_16, primals_17, primals_18, primals_19, primals_20, primals_21, primals_22, primals_23, primals_24, primals_25, primals_26, primals_27, primals_28, primals_29, primals_30, primals_31, primals_32, primals_33, primals_34, primals_35, primals_36, primals_37, primals_38, primals_39, primals_40, primals_41, primals_42, primals_43, primals_44, primals_45, primals_46, primals_47, primals_48, primals_49, primals_50, primals_51, primals_52, primals_53, primals_54, primals_55, primals_56, primals_57, primals_58, primals_59, primals_60, primals_61, primals_62, primals_63, primals_64, primals_65, primals_66, primals_67, primals_68, primals_69, primals_70, primals_71, primals_72, primals_73, primals_74, primals_75, primals_76, primals_77, primals_78, primals_79, primals_80, primals_81, primals_82, primals_83, primals_84, primals_85, 
primals_86, primals_87, primals_88, primals_89, primals_90, primals_91, primals_92, primals_93, primals_94, primals_95, primals_96, primals_97, primals_98, primals_99, primals_100, primals_101, primals_102, primals_103, primals_104, primals_105, primals_106, primals_107, primals_108, primals_109, primals_110, primals_111, primals_112, primals_113, primals_114, primals_115, primals_116, primals_117, primals_118, primals_119, primals_120, primals_121, primals_122, primals_123, primals_124, primals_125, primals_126, primals_127, primals_128, primals_129, primals_130, primals_131, primals_132, primals_133, primals_134, primals_135, primals_136, primals_137, primals_138, primals_139, primals_140, primals_141, primals_142, primals_143, primals_144, primals_145, primals_146, primals_147, primals_148, primals_149, primals_150, primals_151, primals_152, primals_153, primals_154, primals_155, primals_156, primals_157, primals_158, primals_159, primals_160, primals_161, primals_162, primals_163, primals_164, primals_165, primals_166, primals_167, primals_168, primals_169, primals_170, primals_171, primals_172, primals_173, primals_174, primals_175, primals_176, primals_177, primals_178, primals_179, primals_180, primals_181, primals_182, primals_183, primals_184, primals_185, primals_186, primals_187, primals_188, primals_189, primals_190, primals_191, primals_192, primals_193, primals_194, primals_195, primals_196, primals_197, primals_198, primals_199, primals_200, primals_201, primals_202, primals_203, primals_204, primals_205, primals_206, primals_207, primals_208, primals_209, primals_210, primals_211, primals_212, primals_213, primals_214, primals_215, primals_216, primals_217, primals_218, primals_219, primals_220, primals_221, primals_222, primals_223, primals_224, primals_225, primals_226, primals_227, primals_228, primals_229, primals_230, primals_231, primals_232, primals_233, primals_234, primals_235, primals_236, primals_237, primals_238, primals_239, 
primals_240, primals_241, primals_242, primals_243, primals_244, primals_245, primals_246, primals_247, primals_248, primals_249, primals_250, primals_251, primals_252, primals_253, primals_254, primals_255, primals_256, primals_257, primals_258, primals_259, primals_260, primals_261, primals_262, primals_263, primals_264, primals_265, primals_266, primals_267, primals_268, primals_269, primals_270, primals_271, primals_272, primals_273, primals_274, primals_275, primals_276, primals_277, primals_278, primals_279, primals_280, primals_281, primals_282, primals_283, primals_284, primals_285, primals_286, primals_287, primals_288, primals_289, primals_290, primals_291, primals_292, primals_293, embedding, mm, mm_2, getitem, getitem_1, getitem_6, getitem_7, mm_4, add_3, mm_7, mm_9, getitem_9, getitem_10, getitem_15, getitem_16, mm_11, add_7, mm_14, mm_16, getitem_18, getitem_19, getitem_24, getitem_25, mm_18, add_11, mm_21, mm_23, getitem_27, getitem_28, getitem_33, getitem_34, mm_25, add_15, mm_28, mm_30, getitem_36, getitem_37, getitem_42, getitem_43, mm_32, add_19, mm_35, mm_37, getitem_45, getitem_46, getitem_51, getitem_52, mm_39, add_23, mm_42, mm_44, getitem_54, getitem_55, getitem_60, getitem_61, mm_46, add_27, mm_49, mm_51, getitem_63, getitem_64, getitem_69, getitem_70, mm_53, add_31, mm_56, mm_58, getitem_72, getitem_73, getitem_78, getitem_79, mm_60, add_35, mm_63, mm_65, getitem_81, getitem_82, getitem_87, getitem_88, mm_67, add_39, mm_70, mm_72, getitem_90, getitem_91, getitem_96, getitem_97, mm_74, add_43, mm_77, mm_79, getitem_99, getitem_100, getitem_105, getitem_106, mm_81, add_47, mm_84, mm_86, getitem_108, getitem_109, getitem_114, getitem_115, mm_88, add_51, mm_91, mm_93, getitem_117, getitem_118, getitem_123, getitem_124, mm_95, add_55, mm_98, mm_100, getitem_126, getitem_127, getitem_132, getitem_133, mm_102, add_59, mm_105, mm_107, getitem_135, getitem_136, getitem_141, getitem_142, mm_109, add_63, mm_112, mm_114, getitem_144, getitem_145, 
getitem_150, getitem_151, mm_116, add_67, mm_119, mm_121, getitem_153, getitem_154, getitem_159, getitem_160, mm_123, add_71, mm_126, mm_128, getitem_162, getitem_163, getitem_168, getitem_169, mm_130, add_75, mm_133, mm_135, getitem_171, getitem_172, getitem_177, getitem_178, mm_137, add_79, mm_140, mm_142, getitem_180, getitem_181, getitem_186, getitem_187, mm_144, add_83, mm_147, mm_149, getitem_189, getitem_190, getitem_195, getitem_196, mm_151, add_87, mm_154, mm_156, getitem_198, getitem_199, getitem_204, getitem_205, mm_158, add_91, mm_161, mm_163, getitem_207, getitem_208, getitem_213, getitem_214, mm_165, add_95, mm_168, mm_170, getitem_216, getitem_217, getitem_222, getitem_223, mm_172, add_99, mm_175, mm_177, getitem_225, getitem_226, getitem_231, getitem_232, mm_179, add_103, mm_182, mm_184, getitem_234, getitem_235, getitem_240, getitem_241, mm_186, add_107, mm_189, mm_191, getitem_243, getitem_244, getitem_249, getitem_250, mm_193, add_111, mm_196, mm_198, getitem_252, getitem_253, getitem_258, getitem_259, mm_200, add_115, mm_203, mm_205, getitem_261, getitem_262, getitem_267, getitem_268, mm_207, add_119, mm_210, mm_212, getitem_270, getitem_271, getitem_276, getitem_277, mm_214, add_123, mm_217, mm_219, getitem_279, getitem_280, getitem_285, getitem_286, mm_221, mm_223, rsqrt_64, view_1091, tangents_1): + view_1093 = torch.ops.aten.view.default(tangents_1, [16384, 128256]); tangents_1 = None + permute_353 = torch.ops.aten.permute.default(view_1093, [1, 0]) + mm_225 = torch.ops.aten.mm.default(permute_353, view_1091); permute_353 = view_1091 = None + convert_element_type_1060 = torch.ops.prims.convert_element_type.default(primals_293, torch.bfloat16); primals_293 = None + all_gather_into_tensor_290 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1060, 64, '0'); convert_element_type_1060 = None + wait_tensor_290 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_290); all_gather_into_tensor_290 = 
None + permute_352 = torch.ops.aten.permute.default(wait_tensor_290, [1, 0]); wait_tensor_290 = None + permute_355 = torch.ops.aten.permute.default(permute_352, [1, 0]); permute_352 = None + mm_226 = torch.ops.aten.mm.default(view_1093, permute_355); view_1093 = permute_355 = None + view_1094 = torch.ops.aten.view.default(mm_226, [2, 8192, 4096]); mm_226 = None + convert_element_type_1067 = torch.ops.prims.convert_element_type.default(mm_225, torch.float32); mm_225 = None + reduce_scatter_tensor = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1067, 'avg', 64, '0'); convert_element_type_1067 = None + wait_tensor_291 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor); reduce_scatter_tensor = None + convert_element_type_1068 = torch.ops.prims.convert_element_type.default(view_1094, torch.float32); view_1094 = None + convert_element_type_1057 = torch.ops.prims.convert_element_type.default(primals_292, torch.bfloat16); primals_292 = None + all_gather_into_tensor_289 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1057, 64, '0'); convert_element_type_1057 = None + wait_tensor_289 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_289); all_gather_into_tensor_289 = None + convert_element_type_1070 = torch.ops.prims.convert_element_type.default(wait_tensor_289, torch.float32); wait_tensor_289 = None + mul_258 = torch.ops.aten.mul.Tensor(convert_element_type_1068, convert_element_type_1070); convert_element_type_1070 = None + permute_347 = torch.ops.aten.permute.default(getitem_279, [0, 2, 1, 3]) + view_1075 = torch.ops.aten.view.default(permute_347, [2, 8192, -1]); permute_347 = None + convert_element_type_1040 = torch.ops.prims.convert_element_type.default(primals_287, torch.bfloat16); primals_287 = None + all_gather_into_tensor_284 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1040, 64, '0'); convert_element_type_1040 = None + 
wait_tensor_284 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_284); all_gather_into_tensor_284 = None + permute_348 = torch.ops.aten.permute.default(wait_tensor_284, [1, 0]); wait_tensor_284 = None + view_1077 = torch.ops.aten.view.default(view_1075, [16384, 4096]); view_1075 = None + mm_220 = torch.ops.aten.mm.default(view_1077, permute_348) + view_1078 = torch.ops.aten.view.default(mm_220, [2, 8192, 4096]); mm_220 = None + add_125 = torch.ops.aten.add.Tensor(add_123, view_1078); view_1078 = None + view_1088 = torch.ops.aten.view.default(mm_223, [2, 8192, 4096]); mm_223 = None + add_127 = torch.ops.aten.add.Tensor(add_125, view_1088); view_1088 = None + convert_element_type_1058 = torch.ops.prims.convert_element_type.default(add_127, torch.float32); add_127 = None + mul_256 = torch.ops.aten.mul.Tensor(convert_element_type_1058, rsqrt_64); convert_element_type_1058 = None + mul_260 = torch.ops.aten.mul.Tensor(mul_256, mul_258) + sum_1 = torch.ops.aten.sum.dim_IntList(mul_260, [2], True); mul_260 = None + div = torch.ops.aten.div.Tensor(mul_256, 4096) + mul_261 = torch.ops.aten.mul.Tensor(div, sum_1); div = sum_1 = None + sub = torch.ops.aten.sub.Tensor(mul_258, mul_261); mul_258 = mul_261 = None + mul_262 = torch.ops.aten.mul.Tensor(sub, rsqrt_64); sub = rsqrt_64 = None + mul_263 = torch.ops.aten.mul.Tensor(convert_element_type_1068, mul_256); convert_element_type_1068 = mul_256 = None + sum_2 = torch.ops.aten.sum.dim_IntList(mul_263, [0, 1]); mul_263 = None + convert_element_type_1071 = torch.ops.prims.convert_element_type.default(mul_262, torch.bfloat16); mul_262 = None + convert_element_type_default_65 = torch.ops.prims.convert_element_type.default(sum_2, torch.float32); sum_2 = None + reduce_scatter_tensor_1 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_65, 'avg', 64, '0'); convert_element_type_default_65 = None + wait_tensor_292 = 
torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_1); reduce_scatter_tensor_1 = None + view_1095 = torch.ops.aten.view.default(convert_element_type_1071, [16384, 4096]) + permute_357 = torch.ops.aten.permute.default(view_1095, [1, 0]) + convert_element_type_1043 = torch.ops.prims.convert_element_type.default(primals_288, torch.bfloat16); primals_288 = None + all_gather_into_tensor_285 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1043, 64, '0'); convert_element_type_1043 = None + wait_tensor_285 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_285); all_gather_into_tensor_285 = None + convert_element_type_1044 = torch.ops.prims.convert_element_type.default(add_125, torch.float32); add_125 = None + pow_64 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_1044, 2) + mean_63 = torch.ops.aten.mean.dim(pow_64, [2], True); pow_64 = None + add_126 = torch.ops.aten.add.Scalar(mean_63, 1e-05); mean_63 = None + rsqrt_63 = torch.ops.aten.rsqrt.default(add_126); add_126 = None + mul_252 = torch.ops.aten.mul.Tensor(convert_element_type_1044, rsqrt_63); convert_element_type_1044 = None + mul_253 = torch.ops.aten.mul.Tensor(mul_252, wait_tensor_285) + convert_element_type_1045 = torch.ops.prims.convert_element_type.default(mul_253, torch.bfloat16); mul_253 = None + view_1081 = torch.ops.aten.view.default(convert_element_type_1045, [16384, 4096]); convert_element_type_1045 = None + view_1082 = torch.ops.aten.view.default(mm_221, [2, 8192, 14336]); mm_221 = None + convert_element_type_1049 = torch.ops.prims.convert_element_type.default(view_1082, torch.float32); view_1082 = None + sigmoid_31 = torch.ops.aten.sigmoid.default(convert_element_type_1049) + mul_254 = torch.ops.aten.mul.Tensor(convert_element_type_1049, sigmoid_31); sigmoid_31 = None + convert_element_type_1050 = torch.ops.prims.convert_element_type.default(mul_254, torch.bfloat16); mul_254 = None + convert_element_type_1051 = 
torch.ops.prims.convert_element_type.default(primals_290, torch.bfloat16); primals_290 = None + all_gather_into_tensor_287 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1051, 64, '0'); convert_element_type_1051 = None + wait_tensor_287 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_287); all_gather_into_tensor_287 = None + permute_350 = torch.ops.aten.permute.default(wait_tensor_287, [1, 0]); wait_tensor_287 = None + mm_222 = torch.ops.aten.mm.default(view_1081, permute_350) + view_1085 = torch.ops.aten.view.default(mm_222, [2, 8192, 14336]); mm_222 = None + mul_255 = torch.ops.aten.mul.Tensor(convert_element_type_1050, view_1085) + view_1087 = torch.ops.aten.view.default(mul_255, [16384, 14336]); mul_255 = None + mm_227 = torch.ops.aten.mm.default(permute_357, view_1087); permute_357 = view_1087 = None + convert_element_type_1054 = torch.ops.prims.convert_element_type.default(primals_291, torch.bfloat16); primals_291 = None + all_gather_into_tensor_288 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1054, 64, '0'); convert_element_type_1054 = None + wait_tensor_288 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_288); all_gather_into_tensor_288 = None + permute_351 = torch.ops.aten.permute.default(wait_tensor_288, [1, 0]); wait_tensor_288 = None + permute_359 = torch.ops.aten.permute.default(permute_351, [1, 0]); permute_351 = None + mm_228 = torch.ops.aten.mm.default(view_1095, permute_359); view_1095 = permute_359 = None + view_1096 = torch.ops.aten.view.default(mm_228, [2, 8192, 14336]); mm_228 = None + convert_element_type_1078 = torch.ops.prims.convert_element_type.default(mm_227, torch.float32); mm_227 = None + reduce_scatter_tensor_2 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1078, 'avg', 64, '0'); convert_element_type_1078 = None + wait_tensor_293 = 
torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_2); reduce_scatter_tensor_2 = None + mul_264 = torch.ops.aten.mul.Tensor(view_1096, convert_element_type_1050); convert_element_type_1050 = None + mul_265 = torch.ops.aten.mul.Tensor(view_1096, view_1085); view_1096 = view_1085 = None + view_1097 = torch.ops.aten.view.default(mul_264, [16384, 14336]); mul_264 = None + permute_361 = torch.ops.aten.permute.default(view_1097, [1, 0]) + mm_229 = torch.ops.aten.mm.default(permute_361, view_1081); permute_361 = None + permute_363 = torch.ops.aten.permute.default(permute_350, [1, 0]); permute_350 = None + mm_230 = torch.ops.aten.mm.default(view_1097, permute_363); view_1097 = permute_363 = None + view_1098 = torch.ops.aten.view.default(mm_230, [2, 8192, 4096]); mm_230 = None + convert_element_type_1083 = torch.ops.prims.convert_element_type.default(mm_229, torch.float32); mm_229 = None + reduce_scatter_tensor_3 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1083, 'avg', 64, '0'); convert_element_type_1083 = None + wait_tensor_294 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_3); reduce_scatter_tensor_3 = None + convert_element_type_1084 = torch.ops.prims.convert_element_type.default(mul_265, torch.float32); mul_265 = None + neg = torch.ops.aten.neg.default(convert_element_type_1049) + exp = torch.ops.aten.exp.default(neg); neg = None + add_129 = torch.ops.aten.add.Tensor(exp, 1); exp = None + reciprocal = torch.ops.aten.reciprocal.default(add_129); add_129 = None + mul_266 = torch.ops.aten.mul.Tensor(reciprocal, 1); reciprocal = None + mul_267 = torch.ops.aten.mul.Tensor(convert_element_type_1084, mul_266); convert_element_type_1084 = None + sub_1 = torch.ops.aten.sub.Tensor(1, mul_266); mul_266 = None + mul_268 = torch.ops.aten.mul.Tensor(convert_element_type_1049, sub_1); convert_element_type_1049 = sub_1 = None + add_130 = torch.ops.aten.add.Tensor(mul_268, 1); mul_268 = None + mul_269 = 
torch.ops.aten.mul.Tensor(mul_267, add_130); mul_267 = add_130 = None + convert_element_type_1086 = torch.ops.prims.convert_element_type.default(mul_269, torch.bfloat16); mul_269 = None + view_1099 = torch.ops.aten.view.default(convert_element_type_1086, [16384, 14336]); convert_element_type_1086 = None + permute_365 = torch.ops.aten.permute.default(view_1099, [1, 0]) + mm_231 = torch.ops.aten.mm.default(permute_365, view_1081); permute_365 = view_1081 = None + convert_element_type_1046 = torch.ops.prims.convert_element_type.default(primals_289, torch.bfloat16); primals_289 = None + all_gather_into_tensor_286 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1046, 64, '0'); convert_element_type_1046 = None + wait_tensor_286 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_286); all_gather_into_tensor_286 = None + permute_349 = torch.ops.aten.permute.default(wait_tensor_286, [1, 0]); wait_tensor_286 = None + permute_367 = torch.ops.aten.permute.default(permute_349, [1, 0]); permute_349 = None + mm_232 = torch.ops.aten.mm.default(view_1099, permute_367); view_1099 = permute_367 = None + view_1100 = torch.ops.aten.view.default(mm_232, [2, 8192, 4096]); mm_232 = None + add_131 = torch.ops.aten.add.Tensor(view_1098, view_1100); view_1098 = view_1100 = None + convert_element_type_1091 = torch.ops.prims.convert_element_type.default(mm_231, torch.float32); mm_231 = None + reduce_scatter_tensor_4 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1091, 'avg', 64, '0'); convert_element_type_1091 = None + wait_tensor_295 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_4); reduce_scatter_tensor_4 = None + convert_element_type_1092 = torch.ops.prims.convert_element_type.default(add_131, torch.float32); add_131 = None + convert_element_type_1094 = torch.ops.prims.convert_element_type.default(wait_tensor_285, torch.float32); wait_tensor_285 = None + mul_270 = 
torch.ops.aten.mul.Tensor(convert_element_type_1092, convert_element_type_1094); convert_element_type_1094 = None + mul_272 = torch.ops.aten.mul.Tensor(mul_252, mul_270) + sum_3 = torch.ops.aten.sum.dim_IntList(mul_272, [2], True); mul_272 = None + div_1 = torch.ops.aten.div.Tensor(mul_252, 4096) + mul_273 = torch.ops.aten.mul.Tensor(div_1, sum_3); div_1 = sum_3 = None + sub_2 = torch.ops.aten.sub.Tensor(mul_270, mul_273); mul_270 = mul_273 = None + mul_274 = torch.ops.aten.mul.Tensor(sub_2, rsqrt_63); sub_2 = rsqrt_63 = None + mul_275 = torch.ops.aten.mul.Tensor(convert_element_type_1092, mul_252); convert_element_type_1092 = mul_252 = None + sum_4 = torch.ops.aten.sum.dim_IntList(mul_275, [0, 1]); mul_275 = None + convert_element_type_1095 = torch.ops.prims.convert_element_type.default(mul_274, torch.bfloat16); mul_274 = None + add_132 = torch.ops.aten.add.Tensor(convert_element_type_1071, convert_element_type_1095); convert_element_type_1071 = convert_element_type_1095 = None + convert_element_type_default_64 = torch.ops.prims.convert_element_type.default(sum_4, torch.float32); sum_4 = None + reduce_scatter_tensor_5 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_64, 'avg', 64, '0'); convert_element_type_default_64 = None + wait_tensor_296 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_5); reduce_scatter_tensor_5 = None + view_1101 = torch.ops.aten.view.default(add_132, [16384, 4096]) + permute_369 = torch.ops.aten.permute.default(view_1101, [1, 0]) + mm_233 = torch.ops.aten.mm.default(permute_369, view_1077); permute_369 = view_1077 = None + permute_371 = torch.ops.aten.permute.default(permute_348, [1, 0]); permute_348 = None + mm_234 = torch.ops.aten.mm.default(view_1101, permute_371); view_1101 = permute_371 = None + view_1102 = torch.ops.aten.view.default(mm_234, [2, 8192, 4096]); mm_234 = None + convert_element_type_1102 = torch.ops.prims.convert_element_type.default(mm_233, torch.float32); 
mm_233 = None + reduce_scatter_tensor_6 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1102, 'avg', 64, '0'); convert_element_type_1102 = None + wait_tensor_297 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_6); reduce_scatter_tensor_6 = None + view_1103 = torch.ops.aten.view.default(view_1102, [2, 8192, 32, 128]); view_1102 = None + permute_373 = torch.ops.aten.permute.default(view_1103, [0, 2, 1, 3]); view_1103 = None + view_16 = torch.ops.aten.view.default(primals_3, [1, 8192, 1, 64]); primals_3 = None + convert_element_type_1024 = torch.ops.prims.convert_element_type.default(primals_283, torch.bfloat16); primals_283 = None + all_gather_into_tensor_280 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1024, 64, '0'); convert_element_type_1024 = None + wait_tensor_280 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_280); all_gather_into_tensor_280 = None + convert_element_type_1025 = torch.ops.prims.convert_element_type.default(add_123, torch.float32); add_123 = None + pow_63 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_1025, 2) + mean_62 = torch.ops.aten.mean.dim(pow_63, [2], True); pow_63 = None + add_124 = torch.ops.aten.add.Scalar(mean_62, 1e-05); mean_62 = None + rsqrt_62 = torch.ops.aten.rsqrt.default(add_124); add_124 = None + mul_248 = torch.ops.aten.mul.Tensor(convert_element_type_1025, rsqrt_62); convert_element_type_1025 = None + mul_249 = torch.ops.aten.mul.Tensor(mul_248, wait_tensor_280) + convert_element_type_1026 = torch.ops.prims.convert_element_type.default(mul_249, torch.bfloat16); mul_249 = None + view_1057 = torch.ops.aten.view.default(convert_element_type_1026, [16384, 4096]); convert_element_type_1026 = None + view_1058 = torch.ops.aten.view.default(mm_217, [2, 8192, 4096]); mm_217 = None + convert_element_type_1030 = torch.ops.prims.convert_element_type.default(primals_285, torch.bfloat16); primals_285 = None + 
all_gather_into_tensor_282 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1030, 64, '0'); convert_element_type_1030 = None + wait_tensor_282 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_282); all_gather_into_tensor_282 = None + permute_342 = torch.ops.aten.permute.default(wait_tensor_282, [1, 0]); wait_tensor_282 = None + mm_218 = torch.ops.aten.mm.default(view_1057, permute_342) + view_1061 = torch.ops.aten.view.default(mm_218, [2, 8192, 1024]); mm_218 = None + view_1064 = torch.ops.aten.view.default(mm_219, [2, 8192, 1024]); mm_219 = None + view_1065 = torch.ops.aten.view.default(view_1058, [2, 8192, -1, 128]); view_1058 = None + view_1066 = torch.ops.aten.view.default(view_1061, [2, 8192, -1, 128]); view_1061 = None + view_1067 = torch.ops.aten.view.default(view_1064, [2, 8192, -1, 128]); view_1064 = None + convert_element_type_1036 = torch.ops.prims.convert_element_type.default(view_1065, torch.float32); view_1065 = None + view_1068 = torch.ops.aten.view.default(convert_element_type_1036, [2, 8192, 32, -1, 2]); convert_element_type_1036 = None + view_as_complex_62 = torch.ops.aten.view_as_complex.default(view_1068); view_1068 = None + convert_element_type_1037 = torch.ops.prims.convert_element_type.default(view_1066, torch.float32); view_1066 = None + view_1069 = torch.ops.aten.view.default(convert_element_type_1037, [2, 8192, 8, -1, 2]); convert_element_type_1037 = None + view_as_complex_63 = torch.ops.aten.view_as_complex.default(view_1069); view_1069 = None + mul_250 = torch.ops.aten.mul.Tensor(view_as_complex_62, view_16); view_as_complex_62 = None + view_as_real_62 = torch.ops.aten.view_as_real.default(mul_250); mul_250 = None + view_1071 = torch.ops.aten.view.default(view_as_real_62, [2, 8192, 32, 128]); view_as_real_62 = None + mul_251 = torch.ops.aten.mul.Tensor(view_as_complex_63, view_16); view_as_complex_63 = None + view_as_real_63 = torch.ops.aten.view_as_real.default(mul_251); mul_251 = 
None + view_1072 = torch.ops.aten.view.default(view_as_real_63, [2, 8192, 8, 128]); view_as_real_63 = None + convert_element_type_1038 = torch.ops.prims.convert_element_type.default(view_1071, torch.bfloat16); view_1071 = None + convert_element_type_1039 = torch.ops.prims.convert_element_type.default(view_1072, torch.bfloat16); view_1072 = None + unsqueeze_62 = torch.ops.aten.unsqueeze.default(convert_element_type_1039, 3); convert_element_type_1039 = None + expand_62 = torch.ops.aten.expand.default(unsqueeze_62, [2, 8192, 8, 4, 128]); unsqueeze_62 = None + clone_62 = torch.ops.aten.clone.default(expand_62, memory_format = torch.contiguous_format); expand_62 = None + view_1073 = torch.ops.aten.view.default(clone_62, [2, 8192, 32, 128]); clone_62 = None + unsqueeze_63 = torch.ops.aten.unsqueeze.default(view_1067, 3); view_1067 = None + expand_63 = torch.ops.aten.expand.default(unsqueeze_63, [2, 8192, 8, 4, 128]); unsqueeze_63 = None + clone_63 = torch.ops.aten.clone.default(expand_63, memory_format = torch.contiguous_format); expand_63 = None + view_1074 = torch.ops.aten.view.default(clone_63, [2, 8192, 32, 128]); clone_63 = None + permute_344 = torch.ops.aten.permute.default(convert_element_type_1038, [0, 2, 1, 3]); convert_element_type_1038 = None + permute_345 = torch.ops.aten.permute.default(view_1073, [0, 2, 1, 3]); view_1073 = None + permute_346 = torch.ops.aten.permute.default(view_1074, [0, 2, 1, 3]); view_1074 = None + _scaled_dot_product_cudnn_attention_backward = torch.ops.aten._scaled_dot_product_cudnn_attention_backward.default(permute_373, permute_344, permute_345, permute_346, getitem_279, getitem_280, getitem_285, getitem_286, None, None, None, 8192, 8192, 0.0, True); permute_373 = permute_344 = permute_345 = permute_346 = getitem_279 = getitem_280 = getitem_285 = getitem_286 = None + getitem_288 = _scaled_dot_product_cudnn_attention_backward[0] + getitem_289 = _scaled_dot_product_cudnn_attention_backward[1] + getitem_290 = 
_scaled_dot_product_cudnn_attention_backward[2]; _scaled_dot_product_cudnn_attention_backward = None + permute_374 = torch.ops.aten.permute.default(getitem_290, [0, 2, 1, 3]); getitem_290 = None + permute_375 = torch.ops.aten.permute.default(getitem_289, [0, 2, 1, 3]); getitem_289 = None + permute_376 = torch.ops.aten.permute.default(getitem_288, [0, 2, 1, 3]); getitem_288 = None + view_1104 = torch.ops.aten.view.default(permute_374, [2, 8192, 8, 4, 128]); permute_374 = None + sum_5 = torch.ops.aten.sum.dim_IntList(view_1104, [3], True); view_1104 = None + squeeze = torch.ops.aten.squeeze.dim(sum_5, 3); sum_5 = None + view_1105 = torch.ops.aten.view.default(permute_375, [2, 8192, 8, 4, 128]); permute_375 = None + sum_6 = torch.ops.aten.sum.dim_IntList(view_1105, [3], True); view_1105 = None + squeeze_1 = torch.ops.aten.squeeze.dim(sum_6, 3); sum_6 = None + convert_element_type_1103 = torch.ops.prims.convert_element_type.default(squeeze_1, torch.float32); squeeze_1 = None + convert_element_type_1104 = torch.ops.prims.convert_element_type.default(permute_376, torch.float32); permute_376 = None + view_1106 = torch.ops.aten.view.default(convert_element_type_1103, [2, 8192, 8, 64, 2]); convert_element_type_1103 = None + view_as_complex_64 = torch.ops.aten.view_as_complex.default(view_1106); view_1106 = None + _conj = torch.ops.aten._conj.default(view_16) + mul_276 = torch.ops.aten.mul.Tensor(view_as_complex_64, _conj); view_as_complex_64 = None + view_1107 = torch.ops.aten.view.default(convert_element_type_1104, [2, 8192, 32, 64, 2]); convert_element_type_1104 = None + view_as_complex_65 = torch.ops.aten.view_as_complex.default(view_1107); view_1107 = None + mul_277 = torch.ops.aten.mul.Tensor(view_as_complex_65, _conj); view_as_complex_65 = None + view_as_real_64 = torch.ops.aten.view_as_real.default(mul_276); mul_276 = None + view_1108 = torch.ops.aten.view.default(view_as_real_64, [2, 8192, 8, 128]); view_as_real_64 = None + convert_element_type_1105 = 
torch.ops.prims.convert_element_type.default(view_1108, torch.bfloat16); view_1108 = None + view_as_real_65 = torch.ops.aten.view_as_real.default(mul_277); mul_277 = None + view_1109 = torch.ops.aten.view.default(view_as_real_65, [2, 8192, 32, 128]); view_as_real_65 = None + convert_element_type_1106 = torch.ops.prims.convert_element_type.default(view_1109, torch.bfloat16); view_1109 = None + view_1110 = torch.ops.aten.view.default(squeeze, [2, 8192, 1024]); squeeze = None + view_1111 = torch.ops.aten.view.default(convert_element_type_1105, [2, 8192, 1024]); convert_element_type_1105 = None + view_1112 = torch.ops.aten.view.default(convert_element_type_1106, [2, 8192, 4096]); convert_element_type_1106 = None + view_1113 = torch.ops.aten.view.default(view_1110, [16384, 1024]); view_1110 = None + permute_377 = torch.ops.aten.permute.default(view_1113, [1, 0]) + mm_235 = torch.ops.aten.mm.default(permute_377, view_1057); permute_377 = None + convert_element_type_1033 = torch.ops.prims.convert_element_type.default(primals_286, torch.bfloat16); primals_286 = None + all_gather_into_tensor_283 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1033, 64, '0'); convert_element_type_1033 = None + wait_tensor_283 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_283); all_gather_into_tensor_283 = None + permute_343 = torch.ops.aten.permute.default(wait_tensor_283, [1, 0]); wait_tensor_283 = None + permute_379 = torch.ops.aten.permute.default(permute_343, [1, 0]); permute_343 = None + mm_236 = torch.ops.aten.mm.default(view_1113, permute_379); view_1113 = permute_379 = None + view_1114 = torch.ops.aten.view.default(mm_236, [2, 8192, 4096]); mm_236 = None + convert_element_type_1111 = torch.ops.prims.convert_element_type.default(mm_235, torch.float32); mm_235 = None + reduce_scatter_tensor_7 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1111, 'avg', 64, '0'); convert_element_type_1111 = None + 
wait_tensor_298 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_7); reduce_scatter_tensor_7 = None + view_1115 = torch.ops.aten.view.default(view_1111, [16384, 1024]); view_1111 = None + permute_381 = torch.ops.aten.permute.default(view_1115, [1, 0]) + mm_237 = torch.ops.aten.mm.default(permute_381, view_1057); permute_381 = None + permute_383 = torch.ops.aten.permute.default(permute_342, [1, 0]); permute_342 = None + mm_238 = torch.ops.aten.mm.default(view_1115, permute_383); view_1115 = permute_383 = None + view_1116 = torch.ops.aten.view.default(mm_238, [2, 8192, 4096]); mm_238 = None + add_133 = torch.ops.aten.add.Tensor(view_1114, view_1116); view_1114 = view_1116 = None + convert_element_type_1116 = torch.ops.prims.convert_element_type.default(mm_237, torch.float32); mm_237 = None + reduce_scatter_tensor_8 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1116, 'avg', 64, '0'); convert_element_type_1116 = None + wait_tensor_299 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_8); reduce_scatter_tensor_8 = None + view_1117 = torch.ops.aten.view.default(view_1112, [16384, 4096]); view_1112 = None + permute_385 = torch.ops.aten.permute.default(view_1117, [1, 0]) + mm_239 = torch.ops.aten.mm.default(permute_385, view_1057); permute_385 = view_1057 = None + convert_element_type_1027 = torch.ops.prims.convert_element_type.default(primals_284, torch.bfloat16); primals_284 = None + all_gather_into_tensor_281 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1027, 64, '0'); convert_element_type_1027 = None + wait_tensor_281 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_281); all_gather_into_tensor_281 = None + permute_341 = torch.ops.aten.permute.default(wait_tensor_281, [1, 0]); wait_tensor_281 = None + permute_387 = torch.ops.aten.permute.default(permute_341, [1, 0]); permute_341 = None + mm_240 = torch.ops.aten.mm.default(view_1117, 
permute_387); view_1117 = permute_387 = None + view_1118 = torch.ops.aten.view.default(mm_240, [2, 8192, 4096]); mm_240 = None + add_134 = torch.ops.aten.add.Tensor(add_133, view_1118); add_133 = view_1118 = None + convert_element_type_1121 = torch.ops.prims.convert_element_type.default(mm_239, torch.float32); mm_239 = None + reduce_scatter_tensor_9 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1121, 'avg', 64, '0'); convert_element_type_1121 = None + wait_tensor_300 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_9); reduce_scatter_tensor_9 = None + convert_element_type_1122 = torch.ops.prims.convert_element_type.default(add_134, torch.float32); add_134 = None + convert_element_type_1124 = torch.ops.prims.convert_element_type.default(wait_tensor_280, torch.float32); wait_tensor_280 = None + mul_278 = torch.ops.aten.mul.Tensor(convert_element_type_1122, convert_element_type_1124); convert_element_type_1124 = None + mul_280 = torch.ops.aten.mul.Tensor(mul_248, mul_278) + sum_7 = torch.ops.aten.sum.dim_IntList(mul_280, [2], True); mul_280 = None + div_2 = torch.ops.aten.div.Tensor(mul_248, 4096) + mul_281 = torch.ops.aten.mul.Tensor(div_2, sum_7); div_2 = sum_7 = None + sub_3 = torch.ops.aten.sub.Tensor(mul_278, mul_281); mul_278 = mul_281 = None + mul_282 = torch.ops.aten.mul.Tensor(sub_3, rsqrt_62); sub_3 = rsqrt_62 = None + mul_283 = torch.ops.aten.mul.Tensor(convert_element_type_1122, mul_248); convert_element_type_1122 = mul_248 = None + sum_8 = torch.ops.aten.sum.dim_IntList(mul_283, [0, 1]); mul_283 = None + convert_element_type_1125 = torch.ops.prims.convert_element_type.default(mul_282, torch.bfloat16); mul_282 = None + add_135 = torch.ops.aten.add.Tensor(add_132, convert_element_type_1125); add_132 = convert_element_type_1125 = None + convert_element_type_default_63 = torch.ops.prims.convert_element_type.default(sum_8, torch.float32); sum_8 = None + reduce_scatter_tensor_10 = 
torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_63, 'avg', 64, '0'); convert_element_type_default_63 = None + wait_tensor_301 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_10); reduce_scatter_tensor_10 = None + view_1119 = torch.ops.aten.view.default(add_135, [16384, 4096]) + permute_389 = torch.ops.aten.permute.default(view_1119, [1, 0]) + permute_336 = torch.ops.aten.permute.default(getitem_270, [0, 2, 1, 3]) + view_1041 = torch.ops.aten.view.default(permute_336, [2, 8192, -1]); permute_336 = None + convert_element_type_1007 = torch.ops.prims.convert_element_type.default(primals_278, torch.bfloat16); primals_278 = None + all_gather_into_tensor_275 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1007, 64, '0'); convert_element_type_1007 = None + wait_tensor_275 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_275); all_gather_into_tensor_275 = None + permute_337 = torch.ops.aten.permute.default(wait_tensor_275, [1, 0]); wait_tensor_275 = None + view_1043 = torch.ops.aten.view.default(view_1041, [16384, 4096]); view_1041 = None + mm_213 = torch.ops.aten.mm.default(view_1043, permute_337) + view_1044 = torch.ops.aten.view.default(mm_213, [2, 8192, 4096]); mm_213 = None + add_121 = torch.ops.aten.add.Tensor(add_119, view_1044); view_1044 = None + convert_element_type_1010 = torch.ops.prims.convert_element_type.default(primals_279, torch.bfloat16); primals_279 = None + all_gather_into_tensor_276 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1010, 64, '0'); convert_element_type_1010 = None + wait_tensor_276 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_276); all_gather_into_tensor_276 = None + convert_element_type_1011 = torch.ops.prims.convert_element_type.default(add_121, torch.float32); add_121 = None + pow_62 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_1011, 2) + mean_61 = 
torch.ops.aten.mean.dim(pow_62, [2], True); pow_62 = None + add_122 = torch.ops.aten.add.Scalar(mean_61, 1e-05); mean_61 = None + rsqrt_61 = torch.ops.aten.rsqrt.default(add_122); add_122 = None + mul_244 = torch.ops.aten.mul.Tensor(convert_element_type_1011, rsqrt_61); convert_element_type_1011 = None + mul_245 = torch.ops.aten.mul.Tensor(mul_244, wait_tensor_276) + convert_element_type_1012 = torch.ops.prims.convert_element_type.default(mul_245, torch.bfloat16); mul_245 = None + view_1047 = torch.ops.aten.view.default(convert_element_type_1012, [16384, 4096]); convert_element_type_1012 = None + view_1048 = torch.ops.aten.view.default(mm_214, [2, 8192, 14336]); mm_214 = None + convert_element_type_1016 = torch.ops.prims.convert_element_type.default(view_1048, torch.float32); view_1048 = None + sigmoid_30 = torch.ops.aten.sigmoid.default(convert_element_type_1016) + mul_246 = torch.ops.aten.mul.Tensor(convert_element_type_1016, sigmoid_30); sigmoid_30 = None + convert_element_type_1017 = torch.ops.prims.convert_element_type.default(mul_246, torch.bfloat16); mul_246 = None + convert_element_type_1018 = torch.ops.prims.convert_element_type.default(primals_281, torch.bfloat16); primals_281 = None + all_gather_into_tensor_278 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1018, 64, '0'); convert_element_type_1018 = None + wait_tensor_278 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_278); all_gather_into_tensor_278 = None + permute_339 = torch.ops.aten.permute.default(wait_tensor_278, [1, 0]); wait_tensor_278 = None + mm_215 = torch.ops.aten.mm.default(view_1047, permute_339) + view_1051 = torch.ops.aten.view.default(mm_215, [2, 8192, 14336]); mm_215 = None + mul_247 = torch.ops.aten.mul.Tensor(convert_element_type_1017, view_1051) + view_1053 = torch.ops.aten.view.default(mul_247, [16384, 14336]); mul_247 = None + mm_241 = torch.ops.aten.mm.default(permute_389, view_1053); permute_389 = view_1053 = None + 
convert_element_type_1021 = torch.ops.prims.convert_element_type.default(primals_282, torch.bfloat16); primals_282 = None + all_gather_into_tensor_279 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1021, 64, '0'); convert_element_type_1021 = None + wait_tensor_279 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_279); all_gather_into_tensor_279 = None + permute_340 = torch.ops.aten.permute.default(wait_tensor_279, [1, 0]); wait_tensor_279 = None + permute_391 = torch.ops.aten.permute.default(permute_340, [1, 0]); permute_340 = None + mm_242 = torch.ops.aten.mm.default(view_1119, permute_391); view_1119 = permute_391 = None + view_1120 = torch.ops.aten.view.default(mm_242, [2, 8192, 14336]); mm_242 = None + convert_element_type_1132 = torch.ops.prims.convert_element_type.default(mm_241, torch.float32); mm_241 = None + reduce_scatter_tensor_11 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1132, 'avg', 64, '0'); convert_element_type_1132 = None + wait_tensor_302 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_11); reduce_scatter_tensor_11 = None + mul_284 = torch.ops.aten.mul.Tensor(view_1120, convert_element_type_1017); convert_element_type_1017 = None + mul_285 = torch.ops.aten.mul.Tensor(view_1120, view_1051); view_1120 = view_1051 = None + view_1121 = torch.ops.aten.view.default(mul_284, [16384, 14336]); mul_284 = None + permute_393 = torch.ops.aten.permute.default(view_1121, [1, 0]) + mm_243 = torch.ops.aten.mm.default(permute_393, view_1047); permute_393 = None + permute_395 = torch.ops.aten.permute.default(permute_339, [1, 0]); permute_339 = None + mm_244 = torch.ops.aten.mm.default(view_1121, permute_395); view_1121 = permute_395 = None + view_1122 = torch.ops.aten.view.default(mm_244, [2, 8192, 4096]); mm_244 = None + convert_element_type_1137 = torch.ops.prims.convert_element_type.default(mm_243, torch.float32); mm_243 = None + 
reduce_scatter_tensor_12 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1137, 'avg', 64, '0'); convert_element_type_1137 = None + wait_tensor_303 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_12); reduce_scatter_tensor_12 = None + convert_element_type_1138 = torch.ops.prims.convert_element_type.default(mul_285, torch.float32); mul_285 = None + neg_1 = torch.ops.aten.neg.default(convert_element_type_1016) + exp_1 = torch.ops.aten.exp.default(neg_1); neg_1 = None + add_136 = torch.ops.aten.add.Tensor(exp_1, 1); exp_1 = None + reciprocal_1 = torch.ops.aten.reciprocal.default(add_136); add_136 = None + mul_286 = torch.ops.aten.mul.Tensor(reciprocal_1, 1); reciprocal_1 = None + mul_287 = torch.ops.aten.mul.Tensor(convert_element_type_1138, mul_286); convert_element_type_1138 = None + sub_4 = torch.ops.aten.sub.Tensor(1, mul_286); mul_286 = None + mul_288 = torch.ops.aten.mul.Tensor(convert_element_type_1016, sub_4); convert_element_type_1016 = sub_4 = None + add_137 = torch.ops.aten.add.Tensor(mul_288, 1); mul_288 = None + mul_289 = torch.ops.aten.mul.Tensor(mul_287, add_137); mul_287 = add_137 = None + convert_element_type_1140 = torch.ops.prims.convert_element_type.default(mul_289, torch.bfloat16); mul_289 = None + view_1123 = torch.ops.aten.view.default(convert_element_type_1140, [16384, 14336]); convert_element_type_1140 = None + permute_397 = torch.ops.aten.permute.default(view_1123, [1, 0]) + mm_245 = torch.ops.aten.mm.default(permute_397, view_1047); permute_397 = view_1047 = None + convert_element_type_1013 = torch.ops.prims.convert_element_type.default(primals_280, torch.bfloat16); primals_280 = None + all_gather_into_tensor_277 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1013, 64, '0'); convert_element_type_1013 = None + wait_tensor_277 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_277); all_gather_into_tensor_277 = None + permute_338 = 
torch.ops.aten.permute.default(wait_tensor_277, [1, 0]); wait_tensor_277 = None + permute_399 = torch.ops.aten.permute.default(permute_338, [1, 0]); permute_338 = None + mm_246 = torch.ops.aten.mm.default(view_1123, permute_399); view_1123 = permute_399 = None + view_1124 = torch.ops.aten.view.default(mm_246, [2, 8192, 4096]); mm_246 = None + add_138 = torch.ops.aten.add.Tensor(view_1122, view_1124); view_1122 = view_1124 = None + convert_element_type_1145 = torch.ops.prims.convert_element_type.default(mm_245, torch.float32); mm_245 = None + reduce_scatter_tensor_13 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1145, 'avg', 64, '0'); convert_element_type_1145 = None + wait_tensor_304 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_13); reduce_scatter_tensor_13 = None + convert_element_type_1146 = torch.ops.prims.convert_element_type.default(add_138, torch.float32); add_138 = None + convert_element_type_1148 = torch.ops.prims.convert_element_type.default(wait_tensor_276, torch.float32); wait_tensor_276 = None + mul_290 = torch.ops.aten.mul.Tensor(convert_element_type_1146, convert_element_type_1148); convert_element_type_1148 = None + mul_292 = torch.ops.aten.mul.Tensor(mul_244, mul_290) + sum_9 = torch.ops.aten.sum.dim_IntList(mul_292, [2], True); mul_292 = None + div_3 = torch.ops.aten.div.Tensor(mul_244, 4096) + mul_293 = torch.ops.aten.mul.Tensor(div_3, sum_9); div_3 = sum_9 = None + sub_5 = torch.ops.aten.sub.Tensor(mul_290, mul_293); mul_290 = mul_293 = None + mul_294 = torch.ops.aten.mul.Tensor(sub_5, rsqrt_61); sub_5 = rsqrt_61 = None + mul_295 = torch.ops.aten.mul.Tensor(convert_element_type_1146, mul_244); convert_element_type_1146 = mul_244 = None + sum_10 = torch.ops.aten.sum.dim_IntList(mul_295, [0, 1]); mul_295 = None + convert_element_type_1149 = torch.ops.prims.convert_element_type.default(mul_294, torch.bfloat16); mul_294 = None + add_139 = torch.ops.aten.add.Tensor(add_135, 
convert_element_type_1149); add_135 = convert_element_type_1149 = None + convert_element_type_default_62 = torch.ops.prims.convert_element_type.default(sum_10, torch.float32); sum_10 = None + reduce_scatter_tensor_14 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_62, 'avg', 64, '0'); convert_element_type_default_62 = None + wait_tensor_305 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_14); reduce_scatter_tensor_14 = None + view_1125 = torch.ops.aten.view.default(add_139, [16384, 4096]) + permute_401 = torch.ops.aten.permute.default(view_1125, [1, 0]) + mm_247 = torch.ops.aten.mm.default(permute_401, view_1043); permute_401 = view_1043 = None + permute_403 = torch.ops.aten.permute.default(permute_337, [1, 0]); permute_337 = None + mm_248 = torch.ops.aten.mm.default(view_1125, permute_403); view_1125 = permute_403 = None + view_1126 = torch.ops.aten.view.default(mm_248, [2, 8192, 4096]); mm_248 = None + convert_element_type_1156 = torch.ops.prims.convert_element_type.default(mm_247, torch.float32); mm_247 = None + reduce_scatter_tensor_15 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1156, 'avg', 64, '0'); convert_element_type_1156 = None + wait_tensor_306 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_15); reduce_scatter_tensor_15 = None + view_1127 = torch.ops.aten.view.default(view_1126, [2, 8192, 32, 128]); view_1126 = None + permute_405 = torch.ops.aten.permute.default(view_1127, [0, 2, 1, 3]); view_1127 = None + convert_element_type_991 = torch.ops.prims.convert_element_type.default(primals_274, torch.bfloat16); primals_274 = None + all_gather_into_tensor_271 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_991, 64, '0'); convert_element_type_991 = None + wait_tensor_271 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_271); all_gather_into_tensor_271 = None + convert_element_type_992 
= torch.ops.prims.convert_element_type.default(add_119, torch.float32); add_119 = None + pow_61 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_992, 2) + mean_60 = torch.ops.aten.mean.dim(pow_61, [2], True); pow_61 = None + add_120 = torch.ops.aten.add.Scalar(mean_60, 1e-05); mean_60 = None + rsqrt_60 = torch.ops.aten.rsqrt.default(add_120); add_120 = None + mul_240 = torch.ops.aten.mul.Tensor(convert_element_type_992, rsqrt_60); convert_element_type_992 = None + mul_241 = torch.ops.aten.mul.Tensor(mul_240, wait_tensor_271) + convert_element_type_993 = torch.ops.prims.convert_element_type.default(mul_241, torch.bfloat16); mul_241 = None + view_1023 = torch.ops.aten.view.default(convert_element_type_993, [16384, 4096]); convert_element_type_993 = None + view_1024 = torch.ops.aten.view.default(mm_210, [2, 8192, 4096]); mm_210 = None + convert_element_type_997 = torch.ops.prims.convert_element_type.default(primals_276, torch.bfloat16); primals_276 = None + all_gather_into_tensor_273 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_997, 64, '0'); convert_element_type_997 = None + wait_tensor_273 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_273); all_gather_into_tensor_273 = None + permute_331 = torch.ops.aten.permute.default(wait_tensor_273, [1, 0]); wait_tensor_273 = None + mm_211 = torch.ops.aten.mm.default(view_1023, permute_331) + view_1027 = torch.ops.aten.view.default(mm_211, [2, 8192, 1024]); mm_211 = None + view_1030 = torch.ops.aten.view.default(mm_212, [2, 8192, 1024]); mm_212 = None + view_1031 = torch.ops.aten.view.default(view_1024, [2, 8192, -1, 128]); view_1024 = None + view_1032 = torch.ops.aten.view.default(view_1027, [2, 8192, -1, 128]); view_1027 = None + view_1033 = torch.ops.aten.view.default(view_1030, [2, 8192, -1, 128]); view_1030 = None + convert_element_type_1003 = torch.ops.prims.convert_element_type.default(view_1031, torch.float32); view_1031 = None + view_1034 = 
torch.ops.aten.view.default(convert_element_type_1003, [2, 8192, 32, -1, 2]); convert_element_type_1003 = None + view_as_complex_60 = torch.ops.aten.view_as_complex.default(view_1034); view_1034 = None + convert_element_type_1004 = torch.ops.prims.convert_element_type.default(view_1032, torch.float32); view_1032 = None + view_1035 = torch.ops.aten.view.default(convert_element_type_1004, [2, 8192, 8, -1, 2]); convert_element_type_1004 = None + view_as_complex_61 = torch.ops.aten.view_as_complex.default(view_1035); view_1035 = None + mul_242 = torch.ops.aten.mul.Tensor(view_as_complex_60, view_16); view_as_complex_60 = None + view_as_real_60 = torch.ops.aten.view_as_real.default(mul_242); mul_242 = None + view_1037 = torch.ops.aten.view.default(view_as_real_60, [2, 8192, 32, 128]); view_as_real_60 = None + mul_243 = torch.ops.aten.mul.Tensor(view_as_complex_61, view_16); view_as_complex_61 = None + view_as_real_61 = torch.ops.aten.view_as_real.default(mul_243); mul_243 = None + view_1038 = torch.ops.aten.view.default(view_as_real_61, [2, 8192, 8, 128]); view_as_real_61 = None + convert_element_type_1005 = torch.ops.prims.convert_element_type.default(view_1037, torch.bfloat16); view_1037 = None + convert_element_type_1006 = torch.ops.prims.convert_element_type.default(view_1038, torch.bfloat16); view_1038 = None + unsqueeze_60 = torch.ops.aten.unsqueeze.default(convert_element_type_1006, 3); convert_element_type_1006 = None + expand_60 = torch.ops.aten.expand.default(unsqueeze_60, [2, 8192, 8, 4, 128]); unsqueeze_60 = None + clone_60 = torch.ops.aten.clone.default(expand_60, memory_format = torch.contiguous_format); expand_60 = None + view_1039 = torch.ops.aten.view.default(clone_60, [2, 8192, 32, 128]); clone_60 = None + unsqueeze_61 = torch.ops.aten.unsqueeze.default(view_1033, 3); view_1033 = None + expand_61 = torch.ops.aten.expand.default(unsqueeze_61, [2, 8192, 8, 4, 128]); unsqueeze_61 = None + clone_61 = torch.ops.aten.clone.default(expand_61, memory_format = 
torch.contiguous_format); expand_61 = None + view_1040 = torch.ops.aten.view.default(clone_61, [2, 8192, 32, 128]); clone_61 = None + permute_333 = torch.ops.aten.permute.default(convert_element_type_1005, [0, 2, 1, 3]); convert_element_type_1005 = None + permute_334 = torch.ops.aten.permute.default(view_1039, [0, 2, 1, 3]); view_1039 = None + permute_335 = torch.ops.aten.permute.default(view_1040, [0, 2, 1, 3]); view_1040 = None + _scaled_dot_product_cudnn_attention_backward_1 = torch.ops.aten._scaled_dot_product_cudnn_attention_backward.default(permute_405, permute_333, permute_334, permute_335, getitem_270, getitem_271, getitem_276, getitem_277, None, None, None, 8192, 8192, 0.0, True); permute_405 = permute_333 = permute_334 = permute_335 = getitem_270 = getitem_271 = getitem_276 = getitem_277 = None + getitem_291 = _scaled_dot_product_cudnn_attention_backward_1[0] + getitem_292 = _scaled_dot_product_cudnn_attention_backward_1[1] + getitem_293 = _scaled_dot_product_cudnn_attention_backward_1[2]; _scaled_dot_product_cudnn_attention_backward_1 = None + permute_406 = torch.ops.aten.permute.default(getitem_293, [0, 2, 1, 3]); getitem_293 = None + permute_407 = torch.ops.aten.permute.default(getitem_292, [0, 2, 1, 3]); getitem_292 = None + permute_408 = torch.ops.aten.permute.default(getitem_291, [0, 2, 1, 3]); getitem_291 = None + view_1128 = torch.ops.aten.view.default(permute_406, [2, 8192, 8, 4, 128]); permute_406 = None + sum_11 = torch.ops.aten.sum.dim_IntList(view_1128, [3], True); view_1128 = None + squeeze_2 = torch.ops.aten.squeeze.dim(sum_11, 3); sum_11 = None + view_1129 = torch.ops.aten.view.default(permute_407, [2, 8192, 8, 4, 128]); permute_407 = None + sum_12 = torch.ops.aten.sum.dim_IntList(view_1129, [3], True); view_1129 = None + squeeze_3 = torch.ops.aten.squeeze.dim(sum_12, 3); sum_12 = None + convert_element_type_1157 = torch.ops.prims.convert_element_type.default(squeeze_3, torch.float32); squeeze_3 = None + convert_element_type_1158 = 
torch.ops.prims.convert_element_type.default(permute_408, torch.float32); permute_408 = None + view_1130 = torch.ops.aten.view.default(convert_element_type_1157, [2, 8192, 8, 64, 2]); convert_element_type_1157 = None + view_as_complex_66 = torch.ops.aten.view_as_complex.default(view_1130); view_1130 = None + mul_296 = torch.ops.aten.mul.Tensor(view_as_complex_66, _conj); view_as_complex_66 = None + view_1131 = torch.ops.aten.view.default(convert_element_type_1158, [2, 8192, 32, 64, 2]); convert_element_type_1158 = None + view_as_complex_67 = torch.ops.aten.view_as_complex.default(view_1131); view_1131 = None + mul_297 = torch.ops.aten.mul.Tensor(view_as_complex_67, _conj); view_as_complex_67 = None + view_as_real_66 = torch.ops.aten.view_as_real.default(mul_296); mul_296 = None + view_1132 = torch.ops.aten.view.default(view_as_real_66, [2, 8192, 8, 128]); view_as_real_66 = None + convert_element_type_1159 = torch.ops.prims.convert_element_type.default(view_1132, torch.bfloat16); view_1132 = None + view_as_real_67 = torch.ops.aten.view_as_real.default(mul_297); mul_297 = None + view_1133 = torch.ops.aten.view.default(view_as_real_67, [2, 8192, 32, 128]); view_as_real_67 = None + convert_element_type_1160 = torch.ops.prims.convert_element_type.default(view_1133, torch.bfloat16); view_1133 = None + view_1134 = torch.ops.aten.view.default(squeeze_2, [2, 8192, 1024]); squeeze_2 = None + view_1135 = torch.ops.aten.view.default(convert_element_type_1159, [2, 8192, 1024]); convert_element_type_1159 = None + view_1136 = torch.ops.aten.view.default(convert_element_type_1160, [2, 8192, 4096]); convert_element_type_1160 = None + view_1137 = torch.ops.aten.view.default(view_1134, [16384, 1024]); view_1134 = None + permute_409 = torch.ops.aten.permute.default(view_1137, [1, 0]) + mm_249 = torch.ops.aten.mm.default(permute_409, view_1023); permute_409 = None + convert_element_type_1000 = torch.ops.prims.convert_element_type.default(primals_277, torch.bfloat16); primals_277 = None 
+ all_gather_into_tensor_274 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1000, 64, '0'); convert_element_type_1000 = None + wait_tensor_274 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_274); all_gather_into_tensor_274 = None + permute_332 = torch.ops.aten.permute.default(wait_tensor_274, [1, 0]); wait_tensor_274 = None + permute_411 = torch.ops.aten.permute.default(permute_332, [1, 0]); permute_332 = None + mm_250 = torch.ops.aten.mm.default(view_1137, permute_411); view_1137 = permute_411 = None + view_1138 = torch.ops.aten.view.default(mm_250, [2, 8192, 4096]); mm_250 = None + convert_element_type_1165 = torch.ops.prims.convert_element_type.default(mm_249, torch.float32); mm_249 = None + reduce_scatter_tensor_16 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1165, 'avg', 64, '0'); convert_element_type_1165 = None + wait_tensor_307 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_16); reduce_scatter_tensor_16 = None + view_1139 = torch.ops.aten.view.default(view_1135, [16384, 1024]); view_1135 = None + permute_413 = torch.ops.aten.permute.default(view_1139, [1, 0]) + mm_251 = torch.ops.aten.mm.default(permute_413, view_1023); permute_413 = None + permute_415 = torch.ops.aten.permute.default(permute_331, [1, 0]); permute_331 = None + mm_252 = torch.ops.aten.mm.default(view_1139, permute_415); view_1139 = permute_415 = None + view_1140 = torch.ops.aten.view.default(mm_252, [2, 8192, 4096]); mm_252 = None + add_140 = torch.ops.aten.add.Tensor(view_1138, view_1140); view_1138 = view_1140 = None + convert_element_type_1170 = torch.ops.prims.convert_element_type.default(mm_251, torch.float32); mm_251 = None + reduce_scatter_tensor_17 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1170, 'avg', 64, '0'); convert_element_type_1170 = None + wait_tensor_308 = 
torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_17); reduce_scatter_tensor_17 = None + view_1141 = torch.ops.aten.view.default(view_1136, [16384, 4096]); view_1136 = None + permute_417 = torch.ops.aten.permute.default(view_1141, [1, 0]) + mm_253 = torch.ops.aten.mm.default(permute_417, view_1023); permute_417 = view_1023 = None + convert_element_type_994 = torch.ops.prims.convert_element_type.default(primals_275, torch.bfloat16); primals_275 = None + all_gather_into_tensor_272 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_994, 64, '0'); convert_element_type_994 = None + wait_tensor_272 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_272); all_gather_into_tensor_272 = None + permute_330 = torch.ops.aten.permute.default(wait_tensor_272, [1, 0]); wait_tensor_272 = None + permute_419 = torch.ops.aten.permute.default(permute_330, [1, 0]); permute_330 = None + mm_254 = torch.ops.aten.mm.default(view_1141, permute_419); view_1141 = permute_419 = None + view_1142 = torch.ops.aten.view.default(mm_254, [2, 8192, 4096]); mm_254 = None + add_141 = torch.ops.aten.add.Tensor(add_140, view_1142); add_140 = view_1142 = None + convert_element_type_1175 = torch.ops.prims.convert_element_type.default(mm_253, torch.float32); mm_253 = None + reduce_scatter_tensor_18 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1175, 'avg', 64, '0'); convert_element_type_1175 = None + wait_tensor_309 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_18); reduce_scatter_tensor_18 = None + convert_element_type_1176 = torch.ops.prims.convert_element_type.default(add_141, torch.float32); add_141 = None + convert_element_type_1178 = torch.ops.prims.convert_element_type.default(wait_tensor_271, torch.float32); wait_tensor_271 = None + mul_298 = torch.ops.aten.mul.Tensor(convert_element_type_1176, convert_element_type_1178); convert_element_type_1178 = None + mul_300 = 
torch.ops.aten.mul.Tensor(mul_240, mul_298) + sum_13 = torch.ops.aten.sum.dim_IntList(mul_300, [2], True); mul_300 = None + div_4 = torch.ops.aten.div.Tensor(mul_240, 4096) + mul_301 = torch.ops.aten.mul.Tensor(div_4, sum_13); div_4 = sum_13 = None + sub_6 = torch.ops.aten.sub.Tensor(mul_298, mul_301); mul_298 = mul_301 = None + mul_302 = torch.ops.aten.mul.Tensor(sub_6, rsqrt_60); sub_6 = rsqrt_60 = None + mul_303 = torch.ops.aten.mul.Tensor(convert_element_type_1176, mul_240); convert_element_type_1176 = mul_240 = None + sum_14 = torch.ops.aten.sum.dim_IntList(mul_303, [0, 1]); mul_303 = None + convert_element_type_1179 = torch.ops.prims.convert_element_type.default(mul_302, torch.bfloat16); mul_302 = None + add_142 = torch.ops.aten.add.Tensor(add_139, convert_element_type_1179); add_139 = convert_element_type_1179 = None + convert_element_type_default_61 = torch.ops.prims.convert_element_type.default(sum_14, torch.float32); sum_14 = None + reduce_scatter_tensor_19 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_61, 'avg', 64, '0'); convert_element_type_default_61 = None + wait_tensor_310 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_19); reduce_scatter_tensor_19 = None + view_1143 = torch.ops.aten.view.default(add_142, [16384, 4096]) + permute_421 = torch.ops.aten.permute.default(view_1143, [1, 0]) + permute_325 = torch.ops.aten.permute.default(getitem_261, [0, 2, 1, 3]) + view_1007 = torch.ops.aten.view.default(permute_325, [2, 8192, -1]); permute_325 = None + convert_element_type_974 = torch.ops.prims.convert_element_type.default(primals_269, torch.bfloat16); primals_269 = None + all_gather_into_tensor_266 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_974, 64, '0'); convert_element_type_974 = None + wait_tensor_266 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_266); all_gather_into_tensor_266 = None + permute_326 = 
torch.ops.aten.permute.default(wait_tensor_266, [1, 0]); wait_tensor_266 = None + view_1009 = torch.ops.aten.view.default(view_1007, [16384, 4096]); view_1007 = None + mm_206 = torch.ops.aten.mm.default(view_1009, permute_326) + view_1010 = torch.ops.aten.view.default(mm_206, [2, 8192, 4096]); mm_206 = None + add_117 = torch.ops.aten.add.Tensor(add_115, view_1010); view_1010 = None + convert_element_type_977 = torch.ops.prims.convert_element_type.default(primals_270, torch.bfloat16); primals_270 = None + all_gather_into_tensor_267 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_977, 64, '0'); convert_element_type_977 = None + wait_tensor_267 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_267); all_gather_into_tensor_267 = None + convert_element_type_978 = torch.ops.prims.convert_element_type.default(add_117, torch.float32); add_117 = None + pow_60 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_978, 2) + mean_59 = torch.ops.aten.mean.dim(pow_60, [2], True); pow_60 = None + add_118 = torch.ops.aten.add.Scalar(mean_59, 1e-05); mean_59 = None + rsqrt_59 = torch.ops.aten.rsqrt.default(add_118); add_118 = None + mul_236 = torch.ops.aten.mul.Tensor(convert_element_type_978, rsqrt_59); convert_element_type_978 = None + mul_237 = torch.ops.aten.mul.Tensor(mul_236, wait_tensor_267) + convert_element_type_979 = torch.ops.prims.convert_element_type.default(mul_237, torch.bfloat16); mul_237 = None + view_1013 = torch.ops.aten.view.default(convert_element_type_979, [16384, 4096]); convert_element_type_979 = None + view_1014 = torch.ops.aten.view.default(mm_207, [2, 8192, 14336]); mm_207 = None + convert_element_type_983 = torch.ops.prims.convert_element_type.default(view_1014, torch.float32); view_1014 = None + sigmoid_29 = torch.ops.aten.sigmoid.default(convert_element_type_983) + mul_238 = torch.ops.aten.mul.Tensor(convert_element_type_983, sigmoid_29); sigmoid_29 = None + convert_element_type_984 = 
torch.ops.prims.convert_element_type.default(mul_238, torch.bfloat16); mul_238 = None + convert_element_type_985 = torch.ops.prims.convert_element_type.default(primals_272, torch.bfloat16); primals_272 = None + all_gather_into_tensor_269 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_985, 64, '0'); convert_element_type_985 = None + wait_tensor_269 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_269); all_gather_into_tensor_269 = None + permute_328 = torch.ops.aten.permute.default(wait_tensor_269, [1, 0]); wait_tensor_269 = None + mm_208 = torch.ops.aten.mm.default(view_1013, permute_328) + view_1017 = torch.ops.aten.view.default(mm_208, [2, 8192, 14336]); mm_208 = None + mul_239 = torch.ops.aten.mul.Tensor(convert_element_type_984, view_1017) + view_1019 = torch.ops.aten.view.default(mul_239, [16384, 14336]); mul_239 = None + mm_255 = torch.ops.aten.mm.default(permute_421, view_1019); permute_421 = view_1019 = None + convert_element_type_988 = torch.ops.prims.convert_element_type.default(primals_273, torch.bfloat16); primals_273 = None + all_gather_into_tensor_270 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_988, 64, '0'); convert_element_type_988 = None + wait_tensor_270 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_270); all_gather_into_tensor_270 = None + permute_329 = torch.ops.aten.permute.default(wait_tensor_270, [1, 0]); wait_tensor_270 = None + permute_423 = torch.ops.aten.permute.default(permute_329, [1, 0]); permute_329 = None + mm_256 = torch.ops.aten.mm.default(view_1143, permute_423); view_1143 = permute_423 = None + view_1144 = torch.ops.aten.view.default(mm_256, [2, 8192, 14336]); mm_256 = None + convert_element_type_1186 = torch.ops.prims.convert_element_type.default(mm_255, torch.float32); mm_255 = None + reduce_scatter_tensor_20 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1186, 'avg', 64, '0'); 
convert_element_type_1186 = None + wait_tensor_311 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_20); reduce_scatter_tensor_20 = None + mul_304 = torch.ops.aten.mul.Tensor(view_1144, convert_element_type_984); convert_element_type_984 = None + mul_305 = torch.ops.aten.mul.Tensor(view_1144, view_1017); view_1144 = view_1017 = None + view_1145 = torch.ops.aten.view.default(mul_304, [16384, 14336]); mul_304 = None + permute_425 = torch.ops.aten.permute.default(view_1145, [1, 0]) + mm_257 = torch.ops.aten.mm.default(permute_425, view_1013); permute_425 = None + permute_427 = torch.ops.aten.permute.default(permute_328, [1, 0]); permute_328 = None + mm_258 = torch.ops.aten.mm.default(view_1145, permute_427); view_1145 = permute_427 = None + view_1146 = torch.ops.aten.view.default(mm_258, [2, 8192, 4096]); mm_258 = None + convert_element_type_1191 = torch.ops.prims.convert_element_type.default(mm_257, torch.float32); mm_257 = None + reduce_scatter_tensor_21 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1191, 'avg', 64, '0'); convert_element_type_1191 = None + wait_tensor_312 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_21); reduce_scatter_tensor_21 = None + convert_element_type_1192 = torch.ops.prims.convert_element_type.default(mul_305, torch.float32); mul_305 = None + neg_2 = torch.ops.aten.neg.default(convert_element_type_983) + exp_2 = torch.ops.aten.exp.default(neg_2); neg_2 = None + add_143 = torch.ops.aten.add.Tensor(exp_2, 1); exp_2 = None + reciprocal_2 = torch.ops.aten.reciprocal.default(add_143); add_143 = None + mul_306 = torch.ops.aten.mul.Tensor(reciprocal_2, 1); reciprocal_2 = None + mul_307 = torch.ops.aten.mul.Tensor(convert_element_type_1192, mul_306); convert_element_type_1192 = None + sub_7 = torch.ops.aten.sub.Tensor(1, mul_306); mul_306 = None + mul_308 = torch.ops.aten.mul.Tensor(convert_element_type_983, sub_7); convert_element_type_983 = sub_7 = None + add_144 = 
torch.ops.aten.add.Tensor(mul_308, 1); mul_308 = None + mul_309 = torch.ops.aten.mul.Tensor(mul_307, add_144); mul_307 = add_144 = None + convert_element_type_1194 = torch.ops.prims.convert_element_type.default(mul_309, torch.bfloat16); mul_309 = None + view_1147 = torch.ops.aten.view.default(convert_element_type_1194, [16384, 14336]); convert_element_type_1194 = None + permute_429 = torch.ops.aten.permute.default(view_1147, [1, 0]) + mm_259 = torch.ops.aten.mm.default(permute_429, view_1013); permute_429 = view_1013 = None + convert_element_type_980 = torch.ops.prims.convert_element_type.default(primals_271, torch.bfloat16); primals_271 = None + all_gather_into_tensor_268 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_980, 64, '0'); convert_element_type_980 = None + wait_tensor_268 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_268); all_gather_into_tensor_268 = None + permute_327 = torch.ops.aten.permute.default(wait_tensor_268, [1, 0]); wait_tensor_268 = None + permute_431 = torch.ops.aten.permute.default(permute_327, [1, 0]); permute_327 = None + mm_260 = torch.ops.aten.mm.default(view_1147, permute_431); view_1147 = permute_431 = None + view_1148 = torch.ops.aten.view.default(mm_260, [2, 8192, 4096]); mm_260 = None + add_145 = torch.ops.aten.add.Tensor(view_1146, view_1148); view_1146 = view_1148 = None + convert_element_type_1199 = torch.ops.prims.convert_element_type.default(mm_259, torch.float32); mm_259 = None + reduce_scatter_tensor_22 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1199, 'avg', 64, '0'); convert_element_type_1199 = None + wait_tensor_313 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_22); reduce_scatter_tensor_22 = None + convert_element_type_1200 = torch.ops.prims.convert_element_type.default(add_145, torch.float32); add_145 = None + convert_element_type_1202 = torch.ops.prims.convert_element_type.default(wait_tensor_267, 
torch.float32); wait_tensor_267 = None + mul_310 = torch.ops.aten.mul.Tensor(convert_element_type_1200, convert_element_type_1202); convert_element_type_1202 = None + mul_312 = torch.ops.aten.mul.Tensor(mul_236, mul_310) + sum_15 = torch.ops.aten.sum.dim_IntList(mul_312, [2], True); mul_312 = None + div_5 = torch.ops.aten.div.Tensor(mul_236, 4096) + mul_313 = torch.ops.aten.mul.Tensor(div_5, sum_15); div_5 = sum_15 = None + sub_8 = torch.ops.aten.sub.Tensor(mul_310, mul_313); mul_310 = mul_313 = None + mul_314 = torch.ops.aten.mul.Tensor(sub_8, rsqrt_59); sub_8 = rsqrt_59 = None + mul_315 = torch.ops.aten.mul.Tensor(convert_element_type_1200, mul_236); convert_element_type_1200 = mul_236 = None + sum_16 = torch.ops.aten.sum.dim_IntList(mul_315, [0, 1]); mul_315 = None + convert_element_type_1203 = torch.ops.prims.convert_element_type.default(mul_314, torch.bfloat16); mul_314 = None + add_146 = torch.ops.aten.add.Tensor(add_142, convert_element_type_1203); add_142 = convert_element_type_1203 = None + convert_element_type_default_60 = torch.ops.prims.convert_element_type.default(sum_16, torch.float32); sum_16 = None + reduce_scatter_tensor_23 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_60, 'avg', 64, '0'); convert_element_type_default_60 = None + wait_tensor_314 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_23); reduce_scatter_tensor_23 = None + view_1149 = torch.ops.aten.view.default(add_146, [16384, 4096]) + permute_433 = torch.ops.aten.permute.default(view_1149, [1, 0]) + mm_261 = torch.ops.aten.mm.default(permute_433, view_1009); permute_433 = view_1009 = None + permute_435 = torch.ops.aten.permute.default(permute_326, [1, 0]); permute_326 = None + mm_262 = torch.ops.aten.mm.default(view_1149, permute_435); view_1149 = permute_435 = None + view_1150 = torch.ops.aten.view.default(mm_262, [2, 8192, 4096]); mm_262 = None + convert_element_type_1210 = 
torch.ops.prims.convert_element_type.default(mm_261, torch.float32); mm_261 = None + reduce_scatter_tensor_24 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1210, 'avg', 64, '0'); convert_element_type_1210 = None + wait_tensor_315 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_24); reduce_scatter_tensor_24 = None + view_1151 = torch.ops.aten.view.default(view_1150, [2, 8192, 32, 128]); view_1150 = None + permute_437 = torch.ops.aten.permute.default(view_1151, [0, 2, 1, 3]); view_1151 = None + convert_element_type_958 = torch.ops.prims.convert_element_type.default(primals_265, torch.bfloat16); primals_265 = None + all_gather_into_tensor_262 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_958, 64, '0'); convert_element_type_958 = None + wait_tensor_262 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_262); all_gather_into_tensor_262 = None + convert_element_type_959 = torch.ops.prims.convert_element_type.default(add_115, torch.float32); add_115 = None + pow_59 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_959, 2) + mean_58 = torch.ops.aten.mean.dim(pow_59, [2], True); pow_59 = None + add_116 = torch.ops.aten.add.Scalar(mean_58, 1e-05); mean_58 = None + rsqrt_58 = torch.ops.aten.rsqrt.default(add_116); add_116 = None + mul_232 = torch.ops.aten.mul.Tensor(convert_element_type_959, rsqrt_58); convert_element_type_959 = None + mul_233 = torch.ops.aten.mul.Tensor(mul_232, wait_tensor_262) + convert_element_type_960 = torch.ops.prims.convert_element_type.default(mul_233, torch.bfloat16); mul_233 = None + view_989 = torch.ops.aten.view.default(convert_element_type_960, [16384, 4096]); convert_element_type_960 = None + view_990 = torch.ops.aten.view.default(mm_203, [2, 8192, 4096]); mm_203 = None + convert_element_type_964 = torch.ops.prims.convert_element_type.default(primals_267, torch.bfloat16); primals_267 = None + all_gather_into_tensor_264 = 
torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_964, 64, '0'); convert_element_type_964 = None + wait_tensor_264 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_264); all_gather_into_tensor_264 = None + permute_320 = torch.ops.aten.permute.default(wait_tensor_264, [1, 0]); wait_tensor_264 = None + mm_204 = torch.ops.aten.mm.default(view_989, permute_320) + view_993 = torch.ops.aten.view.default(mm_204, [2, 8192, 1024]); mm_204 = None + view_996 = torch.ops.aten.view.default(mm_205, [2, 8192, 1024]); mm_205 = None + view_997 = torch.ops.aten.view.default(view_990, [2, 8192, -1, 128]); view_990 = None + view_998 = torch.ops.aten.view.default(view_993, [2, 8192, -1, 128]); view_993 = None + view_999 = torch.ops.aten.view.default(view_996, [2, 8192, -1, 128]); view_996 = None + convert_element_type_970 = torch.ops.prims.convert_element_type.default(view_997, torch.float32); view_997 = None + view_1000 = torch.ops.aten.view.default(convert_element_type_970, [2, 8192, 32, -1, 2]); convert_element_type_970 = None + view_as_complex_58 = torch.ops.aten.view_as_complex.default(view_1000); view_1000 = None + convert_element_type_971 = torch.ops.prims.convert_element_type.default(view_998, torch.float32); view_998 = None + view_1001 = torch.ops.aten.view.default(convert_element_type_971, [2, 8192, 8, -1, 2]); convert_element_type_971 = None + view_as_complex_59 = torch.ops.aten.view_as_complex.default(view_1001); view_1001 = None + mul_234 = torch.ops.aten.mul.Tensor(view_as_complex_58, view_16); view_as_complex_58 = None + view_as_real_58 = torch.ops.aten.view_as_real.default(mul_234); mul_234 = None + view_1003 = torch.ops.aten.view.default(view_as_real_58, [2, 8192, 32, 128]); view_as_real_58 = None + mul_235 = torch.ops.aten.mul.Tensor(view_as_complex_59, view_16); view_as_complex_59 = None + view_as_real_59 = torch.ops.aten.view_as_real.default(mul_235); mul_235 = None + view_1004 = 
torch.ops.aten.view.default(view_as_real_59, [2, 8192, 8, 128]); view_as_real_59 = None + convert_element_type_972 = torch.ops.prims.convert_element_type.default(view_1003, torch.bfloat16); view_1003 = None + convert_element_type_973 = torch.ops.prims.convert_element_type.default(view_1004, torch.bfloat16); view_1004 = None + unsqueeze_58 = torch.ops.aten.unsqueeze.default(convert_element_type_973, 3); convert_element_type_973 = None + expand_58 = torch.ops.aten.expand.default(unsqueeze_58, [2, 8192, 8, 4, 128]); unsqueeze_58 = None + clone_58 = torch.ops.aten.clone.default(expand_58, memory_format = torch.contiguous_format); expand_58 = None + view_1005 = torch.ops.aten.view.default(clone_58, [2, 8192, 32, 128]); clone_58 = None + unsqueeze_59 = torch.ops.aten.unsqueeze.default(view_999, 3); view_999 = None + expand_59 = torch.ops.aten.expand.default(unsqueeze_59, [2, 8192, 8, 4, 128]); unsqueeze_59 = None + clone_59 = torch.ops.aten.clone.default(expand_59, memory_format = torch.contiguous_format); expand_59 = None + view_1006 = torch.ops.aten.view.default(clone_59, [2, 8192, 32, 128]); clone_59 = None + permute_322 = torch.ops.aten.permute.default(convert_element_type_972, [0, 2, 1, 3]); convert_element_type_972 = None + permute_323 = torch.ops.aten.permute.default(view_1005, [0, 2, 1, 3]); view_1005 = None + permute_324 = torch.ops.aten.permute.default(view_1006, [0, 2, 1, 3]); view_1006 = None + _scaled_dot_product_cudnn_attention_backward_2 = torch.ops.aten._scaled_dot_product_cudnn_attention_backward.default(permute_437, permute_322, permute_323, permute_324, getitem_261, getitem_262, getitem_267, getitem_268, None, None, None, 8192, 8192, 0.0, True); permute_437 = permute_322 = permute_323 = permute_324 = getitem_261 = getitem_262 = getitem_267 = getitem_268 = None + getitem_294 = _scaled_dot_product_cudnn_attention_backward_2[0] + getitem_295 = _scaled_dot_product_cudnn_attention_backward_2[1] + getitem_296 = 
_scaled_dot_product_cudnn_attention_backward_2[2]; _scaled_dot_product_cudnn_attention_backward_2 = None + permute_438 = torch.ops.aten.permute.default(getitem_296, [0, 2, 1, 3]); getitem_296 = None + permute_439 = torch.ops.aten.permute.default(getitem_295, [0, 2, 1, 3]); getitem_295 = None + permute_440 = torch.ops.aten.permute.default(getitem_294, [0, 2, 1, 3]); getitem_294 = None + view_1152 = torch.ops.aten.view.default(permute_438, [2, 8192, 8, 4, 128]); permute_438 = None + sum_17 = torch.ops.aten.sum.dim_IntList(view_1152, [3], True); view_1152 = None + squeeze_4 = torch.ops.aten.squeeze.dim(sum_17, 3); sum_17 = None + view_1153 = torch.ops.aten.view.default(permute_439, [2, 8192, 8, 4, 128]); permute_439 = None + sum_18 = torch.ops.aten.sum.dim_IntList(view_1153, [3], True); view_1153 = None + squeeze_5 = torch.ops.aten.squeeze.dim(sum_18, 3); sum_18 = None + convert_element_type_1211 = torch.ops.prims.convert_element_type.default(squeeze_5, torch.float32); squeeze_5 = None + convert_element_type_1212 = torch.ops.prims.convert_element_type.default(permute_440, torch.float32); permute_440 = None + view_1154 = torch.ops.aten.view.default(convert_element_type_1211, [2, 8192, 8, 64, 2]); convert_element_type_1211 = None + view_as_complex_68 = torch.ops.aten.view_as_complex.default(view_1154); view_1154 = None + mul_316 = torch.ops.aten.mul.Tensor(view_as_complex_68, _conj); view_as_complex_68 = None + view_1155 = torch.ops.aten.view.default(convert_element_type_1212, [2, 8192, 32, 64, 2]); convert_element_type_1212 = None + view_as_complex_69 = torch.ops.aten.view_as_complex.default(view_1155); view_1155 = None + mul_317 = torch.ops.aten.mul.Tensor(view_as_complex_69, _conj); view_as_complex_69 = None + view_as_real_68 = torch.ops.aten.view_as_real.default(mul_316); mul_316 = None + view_1156 = torch.ops.aten.view.default(view_as_real_68, [2, 8192, 8, 128]); view_as_real_68 = None + convert_element_type_1213 = 
torch.ops.prims.convert_element_type.default(view_1156, torch.bfloat16); view_1156 = None + view_as_real_69 = torch.ops.aten.view_as_real.default(mul_317); mul_317 = None + view_1157 = torch.ops.aten.view.default(view_as_real_69, [2, 8192, 32, 128]); view_as_real_69 = None + convert_element_type_1214 = torch.ops.prims.convert_element_type.default(view_1157, torch.bfloat16); view_1157 = None + view_1158 = torch.ops.aten.view.default(squeeze_4, [2, 8192, 1024]); squeeze_4 = None + view_1159 = torch.ops.aten.view.default(convert_element_type_1213, [2, 8192, 1024]); convert_element_type_1213 = None + view_1160 = torch.ops.aten.view.default(convert_element_type_1214, [2, 8192, 4096]); convert_element_type_1214 = None + view_1161 = torch.ops.aten.view.default(view_1158, [16384, 1024]); view_1158 = None + permute_441 = torch.ops.aten.permute.default(view_1161, [1, 0]) + mm_263 = torch.ops.aten.mm.default(permute_441, view_989); permute_441 = None + convert_element_type_967 = torch.ops.prims.convert_element_type.default(primals_268, torch.bfloat16); primals_268 = None + all_gather_into_tensor_265 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_967, 64, '0'); convert_element_type_967 = None + wait_tensor_265 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_265); all_gather_into_tensor_265 = None + permute_321 = torch.ops.aten.permute.default(wait_tensor_265, [1, 0]); wait_tensor_265 = None + permute_443 = torch.ops.aten.permute.default(permute_321, [1, 0]); permute_321 = None + mm_264 = torch.ops.aten.mm.default(view_1161, permute_443); view_1161 = permute_443 = None + view_1162 = torch.ops.aten.view.default(mm_264, [2, 8192, 4096]); mm_264 = None + convert_element_type_1219 = torch.ops.prims.convert_element_type.default(mm_263, torch.float32); mm_263 = None + reduce_scatter_tensor_25 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1219, 'avg', 64, '0'); convert_element_type_1219 = None 
+ wait_tensor_316 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_25); reduce_scatter_tensor_25 = None + view_1163 = torch.ops.aten.view.default(view_1159, [16384, 1024]); view_1159 = None + permute_445 = torch.ops.aten.permute.default(view_1163, [1, 0]) + mm_265 = torch.ops.aten.mm.default(permute_445, view_989); permute_445 = None + permute_447 = torch.ops.aten.permute.default(permute_320, [1, 0]); permute_320 = None + mm_266 = torch.ops.aten.mm.default(view_1163, permute_447); view_1163 = permute_447 = None + view_1164 = torch.ops.aten.view.default(mm_266, [2, 8192, 4096]); mm_266 = None + add_147 = torch.ops.aten.add.Tensor(view_1162, view_1164); view_1162 = view_1164 = None + convert_element_type_1224 = torch.ops.prims.convert_element_type.default(mm_265, torch.float32); mm_265 = None + reduce_scatter_tensor_26 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1224, 'avg', 64, '0'); convert_element_type_1224 = None + wait_tensor_317 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_26); reduce_scatter_tensor_26 = None + view_1165 = torch.ops.aten.view.default(view_1160, [16384, 4096]); view_1160 = None + permute_449 = torch.ops.aten.permute.default(view_1165, [1, 0]) + mm_267 = torch.ops.aten.mm.default(permute_449, view_989); permute_449 = view_989 = None + convert_element_type_961 = torch.ops.prims.convert_element_type.default(primals_266, torch.bfloat16); primals_266 = None + all_gather_into_tensor_263 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_961, 64, '0'); convert_element_type_961 = None + wait_tensor_263 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_263); all_gather_into_tensor_263 = None + permute_319 = torch.ops.aten.permute.default(wait_tensor_263, [1, 0]); wait_tensor_263 = None + permute_451 = torch.ops.aten.permute.default(permute_319, [1, 0]); permute_319 = None + mm_268 = torch.ops.aten.mm.default(view_1165, 
permute_451); view_1165 = permute_451 = None + view_1166 = torch.ops.aten.view.default(mm_268, [2, 8192, 4096]); mm_268 = None + add_148 = torch.ops.aten.add.Tensor(add_147, view_1166); add_147 = view_1166 = None + convert_element_type_1229 = torch.ops.prims.convert_element_type.default(mm_267, torch.float32); mm_267 = None + reduce_scatter_tensor_27 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1229, 'avg', 64, '0'); convert_element_type_1229 = None + wait_tensor_318 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_27); reduce_scatter_tensor_27 = None + convert_element_type_1230 = torch.ops.prims.convert_element_type.default(add_148, torch.float32); add_148 = None + convert_element_type_1232 = torch.ops.prims.convert_element_type.default(wait_tensor_262, torch.float32); wait_tensor_262 = None + mul_318 = torch.ops.aten.mul.Tensor(convert_element_type_1230, convert_element_type_1232); convert_element_type_1232 = None + mul_320 = torch.ops.aten.mul.Tensor(mul_232, mul_318) + sum_19 = torch.ops.aten.sum.dim_IntList(mul_320, [2], True); mul_320 = None + div_6 = torch.ops.aten.div.Tensor(mul_232, 4096) + mul_321 = torch.ops.aten.mul.Tensor(div_6, sum_19); div_6 = sum_19 = None + sub_9 = torch.ops.aten.sub.Tensor(mul_318, mul_321); mul_318 = mul_321 = None + mul_322 = torch.ops.aten.mul.Tensor(sub_9, rsqrt_58); sub_9 = rsqrt_58 = None + mul_323 = torch.ops.aten.mul.Tensor(convert_element_type_1230, mul_232); convert_element_type_1230 = mul_232 = None + sum_20 = torch.ops.aten.sum.dim_IntList(mul_323, [0, 1]); mul_323 = None + convert_element_type_1233 = torch.ops.prims.convert_element_type.default(mul_322, torch.bfloat16); mul_322 = None + add_149 = torch.ops.aten.add.Tensor(add_146, convert_element_type_1233); add_146 = convert_element_type_1233 = None + convert_element_type_default_59 = torch.ops.prims.convert_element_type.default(sum_20, torch.float32); sum_20 = None + reduce_scatter_tensor_28 = 
torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_59, 'avg', 64, '0'); convert_element_type_default_59 = None + wait_tensor_319 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_28); reduce_scatter_tensor_28 = None + view_1167 = torch.ops.aten.view.default(add_149, [16384, 4096]) + permute_453 = torch.ops.aten.permute.default(view_1167, [1, 0]) + permute_314 = torch.ops.aten.permute.default(getitem_252, [0, 2, 1, 3]) + view_973 = torch.ops.aten.view.default(permute_314, [2, 8192, -1]); permute_314 = None + convert_element_type_941 = torch.ops.prims.convert_element_type.default(primals_260, torch.bfloat16); primals_260 = None + all_gather_into_tensor_257 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_941, 64, '0'); convert_element_type_941 = None + wait_tensor_257 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_257); all_gather_into_tensor_257 = None + permute_315 = torch.ops.aten.permute.default(wait_tensor_257, [1, 0]); wait_tensor_257 = None + view_975 = torch.ops.aten.view.default(view_973, [16384, 4096]); view_973 = None + mm_199 = torch.ops.aten.mm.default(view_975, permute_315) + view_976 = torch.ops.aten.view.default(mm_199, [2, 8192, 4096]); mm_199 = None + add_113 = torch.ops.aten.add.Tensor(add_111, view_976); view_976 = None + convert_element_type_944 = torch.ops.prims.convert_element_type.default(primals_261, torch.bfloat16); primals_261 = None + all_gather_into_tensor_258 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_944, 64, '0'); convert_element_type_944 = None + wait_tensor_258 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_258); all_gather_into_tensor_258 = None + convert_element_type_945 = torch.ops.prims.convert_element_type.default(add_113, torch.float32); add_113 = None + pow_58 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_945, 2) + mean_57 = 
torch.ops.aten.mean.dim(pow_58, [2], True); pow_58 = None + add_114 = torch.ops.aten.add.Scalar(mean_57, 1e-05); mean_57 = None + rsqrt_57 = torch.ops.aten.rsqrt.default(add_114); add_114 = None + mul_228 = torch.ops.aten.mul.Tensor(convert_element_type_945, rsqrt_57); convert_element_type_945 = None + mul_229 = torch.ops.aten.mul.Tensor(mul_228, wait_tensor_258) + convert_element_type_946 = torch.ops.prims.convert_element_type.default(mul_229, torch.bfloat16); mul_229 = None + view_979 = torch.ops.aten.view.default(convert_element_type_946, [16384, 4096]); convert_element_type_946 = None + view_980 = torch.ops.aten.view.default(mm_200, [2, 8192, 14336]); mm_200 = None + convert_element_type_950 = torch.ops.prims.convert_element_type.default(view_980, torch.float32); view_980 = None + sigmoid_28 = torch.ops.aten.sigmoid.default(convert_element_type_950) + mul_230 = torch.ops.aten.mul.Tensor(convert_element_type_950, sigmoid_28); sigmoid_28 = None + convert_element_type_951 = torch.ops.prims.convert_element_type.default(mul_230, torch.bfloat16); mul_230 = None + convert_element_type_952 = torch.ops.prims.convert_element_type.default(primals_263, torch.bfloat16); primals_263 = None + all_gather_into_tensor_260 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_952, 64, '0'); convert_element_type_952 = None + wait_tensor_260 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_260); all_gather_into_tensor_260 = None + permute_317 = torch.ops.aten.permute.default(wait_tensor_260, [1, 0]); wait_tensor_260 = None + mm_201 = torch.ops.aten.mm.default(view_979, permute_317) + view_983 = torch.ops.aten.view.default(mm_201, [2, 8192, 14336]); mm_201 = None + mul_231 = torch.ops.aten.mul.Tensor(convert_element_type_951, view_983) + view_985 = torch.ops.aten.view.default(mul_231, [16384, 14336]); mul_231 = None + mm_269 = torch.ops.aten.mm.default(permute_453, view_985); permute_453 = view_985 = None + convert_element_type_955 = 
torch.ops.prims.convert_element_type.default(primals_264, torch.bfloat16); primals_264 = None + all_gather_into_tensor_261 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_955, 64, '0'); convert_element_type_955 = None + wait_tensor_261 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_261); all_gather_into_tensor_261 = None + permute_318 = torch.ops.aten.permute.default(wait_tensor_261, [1, 0]); wait_tensor_261 = None + permute_455 = torch.ops.aten.permute.default(permute_318, [1, 0]); permute_318 = None + mm_270 = torch.ops.aten.mm.default(view_1167, permute_455); view_1167 = permute_455 = None + view_1168 = torch.ops.aten.view.default(mm_270, [2, 8192, 14336]); mm_270 = None + convert_element_type_1240 = torch.ops.prims.convert_element_type.default(mm_269, torch.float32); mm_269 = None + reduce_scatter_tensor_29 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1240, 'avg', 64, '0'); convert_element_type_1240 = None + wait_tensor_320 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_29); reduce_scatter_tensor_29 = None + mul_324 = torch.ops.aten.mul.Tensor(view_1168, convert_element_type_951); convert_element_type_951 = None + mul_325 = torch.ops.aten.mul.Tensor(view_1168, view_983); view_1168 = view_983 = None + view_1169 = torch.ops.aten.view.default(mul_324, [16384, 14336]); mul_324 = None + permute_457 = torch.ops.aten.permute.default(view_1169, [1, 0]) + mm_271 = torch.ops.aten.mm.default(permute_457, view_979); permute_457 = None + permute_459 = torch.ops.aten.permute.default(permute_317, [1, 0]); permute_317 = None + mm_272 = torch.ops.aten.mm.default(view_1169, permute_459); view_1169 = permute_459 = None + view_1170 = torch.ops.aten.view.default(mm_272, [2, 8192, 4096]); mm_272 = None + convert_element_type_1245 = torch.ops.prims.convert_element_type.default(mm_271, torch.float32); mm_271 = None + reduce_scatter_tensor_30 = 
torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1245, 'avg', 64, '0'); convert_element_type_1245 = None + wait_tensor_321 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_30); reduce_scatter_tensor_30 = None + convert_element_type_1246 = torch.ops.prims.convert_element_type.default(mul_325, torch.float32); mul_325 = None + neg_3 = torch.ops.aten.neg.default(convert_element_type_950) + exp_3 = torch.ops.aten.exp.default(neg_3); neg_3 = None + add_150 = torch.ops.aten.add.Tensor(exp_3, 1); exp_3 = None + reciprocal_3 = torch.ops.aten.reciprocal.default(add_150); add_150 = None + mul_326 = torch.ops.aten.mul.Tensor(reciprocal_3, 1); reciprocal_3 = None + mul_327 = torch.ops.aten.mul.Tensor(convert_element_type_1246, mul_326); convert_element_type_1246 = None + sub_10 = torch.ops.aten.sub.Tensor(1, mul_326); mul_326 = None + mul_328 = torch.ops.aten.mul.Tensor(convert_element_type_950, sub_10); convert_element_type_950 = sub_10 = None + add_151 = torch.ops.aten.add.Tensor(mul_328, 1); mul_328 = None + mul_329 = torch.ops.aten.mul.Tensor(mul_327, add_151); mul_327 = add_151 = None + convert_element_type_1248 = torch.ops.prims.convert_element_type.default(mul_329, torch.bfloat16); mul_329 = None + view_1171 = torch.ops.aten.view.default(convert_element_type_1248, [16384, 14336]); convert_element_type_1248 = None + permute_461 = torch.ops.aten.permute.default(view_1171, [1, 0]) + mm_273 = torch.ops.aten.mm.default(permute_461, view_979); permute_461 = view_979 = None + convert_element_type_947 = torch.ops.prims.convert_element_type.default(primals_262, torch.bfloat16); primals_262 = None + all_gather_into_tensor_259 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_947, 64, '0'); convert_element_type_947 = None + wait_tensor_259 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_259); all_gather_into_tensor_259 = None + permute_316 = 
torch.ops.aten.permute.default(wait_tensor_259, [1, 0]); wait_tensor_259 = None + permute_463 = torch.ops.aten.permute.default(permute_316, [1, 0]); permute_316 = None + mm_274 = torch.ops.aten.mm.default(view_1171, permute_463); view_1171 = permute_463 = None + view_1172 = torch.ops.aten.view.default(mm_274, [2, 8192, 4096]); mm_274 = None + add_152 = torch.ops.aten.add.Tensor(view_1170, view_1172); view_1170 = view_1172 = None + convert_element_type_1253 = torch.ops.prims.convert_element_type.default(mm_273, torch.float32); mm_273 = None + reduce_scatter_tensor_31 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1253, 'avg', 64, '0'); convert_element_type_1253 = None + wait_tensor_322 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_31); reduce_scatter_tensor_31 = None + convert_element_type_1254 = torch.ops.prims.convert_element_type.default(add_152, torch.float32); add_152 = None + convert_element_type_1256 = torch.ops.prims.convert_element_type.default(wait_tensor_258, torch.float32); wait_tensor_258 = None + mul_330 = torch.ops.aten.mul.Tensor(convert_element_type_1254, convert_element_type_1256); convert_element_type_1256 = None + mul_332 = torch.ops.aten.mul.Tensor(mul_228, mul_330) + sum_21 = torch.ops.aten.sum.dim_IntList(mul_332, [2], True); mul_332 = None + div_7 = torch.ops.aten.div.Tensor(mul_228, 4096) + mul_333 = torch.ops.aten.mul.Tensor(div_7, sum_21); div_7 = sum_21 = None + sub_11 = torch.ops.aten.sub.Tensor(mul_330, mul_333); mul_330 = mul_333 = None + mul_334 = torch.ops.aten.mul.Tensor(sub_11, rsqrt_57); sub_11 = rsqrt_57 = None + mul_335 = torch.ops.aten.mul.Tensor(convert_element_type_1254, mul_228); convert_element_type_1254 = mul_228 = None + sum_22 = torch.ops.aten.sum.dim_IntList(mul_335, [0, 1]); mul_335 = None + convert_element_type_1257 = torch.ops.prims.convert_element_type.default(mul_334, torch.bfloat16); mul_334 = None + add_153 = torch.ops.aten.add.Tensor(add_149, 
convert_element_type_1257); add_149 = convert_element_type_1257 = None + convert_element_type_default_58 = torch.ops.prims.convert_element_type.default(sum_22, torch.float32); sum_22 = None + reduce_scatter_tensor_32 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_58, 'avg', 64, '0'); convert_element_type_default_58 = None + wait_tensor_323 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_32); reduce_scatter_tensor_32 = None + view_1173 = torch.ops.aten.view.default(add_153, [16384, 4096]) + permute_465 = torch.ops.aten.permute.default(view_1173, [1, 0]) + mm_275 = torch.ops.aten.mm.default(permute_465, view_975); permute_465 = view_975 = None + permute_467 = torch.ops.aten.permute.default(permute_315, [1, 0]); permute_315 = None + mm_276 = torch.ops.aten.mm.default(view_1173, permute_467); view_1173 = permute_467 = None + view_1174 = torch.ops.aten.view.default(mm_276, [2, 8192, 4096]); mm_276 = None + convert_element_type_1264 = torch.ops.prims.convert_element_type.default(mm_275, torch.float32); mm_275 = None + reduce_scatter_tensor_33 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1264, 'avg', 64, '0'); convert_element_type_1264 = None + wait_tensor_324 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_33); reduce_scatter_tensor_33 = None + view_1175 = torch.ops.aten.view.default(view_1174, [2, 8192, 32, 128]); view_1174 = None + permute_469 = torch.ops.aten.permute.default(view_1175, [0, 2, 1, 3]); view_1175 = None + convert_element_type_925 = torch.ops.prims.convert_element_type.default(primals_256, torch.bfloat16); primals_256 = None + all_gather_into_tensor_253 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_925, 64, '0'); convert_element_type_925 = None + wait_tensor_253 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_253); all_gather_into_tensor_253 = None + convert_element_type_926 = 
torch.ops.prims.convert_element_type.default(add_111, torch.float32); add_111 = None + pow_57 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_926, 2) + mean_56 = torch.ops.aten.mean.dim(pow_57, [2], True); pow_57 = None + add_112 = torch.ops.aten.add.Scalar(mean_56, 1e-05); mean_56 = None + rsqrt_56 = torch.ops.aten.rsqrt.default(add_112); add_112 = None + mul_224 = torch.ops.aten.mul.Tensor(convert_element_type_926, rsqrt_56); convert_element_type_926 = None + mul_225 = torch.ops.aten.mul.Tensor(mul_224, wait_tensor_253) + convert_element_type_927 = torch.ops.prims.convert_element_type.default(mul_225, torch.bfloat16); mul_225 = None + view_955 = torch.ops.aten.view.default(convert_element_type_927, [16384, 4096]); convert_element_type_927 = None + view_956 = torch.ops.aten.view.default(mm_196, [2, 8192, 4096]); mm_196 = None + convert_element_type_931 = torch.ops.prims.convert_element_type.default(primals_258, torch.bfloat16); primals_258 = None + all_gather_into_tensor_255 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_931, 64, '0'); convert_element_type_931 = None + wait_tensor_255 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_255); all_gather_into_tensor_255 = None + permute_309 = torch.ops.aten.permute.default(wait_tensor_255, [1, 0]); wait_tensor_255 = None + mm_197 = torch.ops.aten.mm.default(view_955, permute_309) + view_959 = torch.ops.aten.view.default(mm_197, [2, 8192, 1024]); mm_197 = None + view_962 = torch.ops.aten.view.default(mm_198, [2, 8192, 1024]); mm_198 = None + view_963 = torch.ops.aten.view.default(view_956, [2, 8192, -1, 128]); view_956 = None + view_964 = torch.ops.aten.view.default(view_959, [2, 8192, -1, 128]); view_959 = None + view_965 = torch.ops.aten.view.default(view_962, [2, 8192, -1, 128]); view_962 = None + convert_element_type_937 = torch.ops.prims.convert_element_type.default(view_963, torch.float32); view_963 = None + view_966 = 
torch.ops.aten.view.default(convert_element_type_937, [2, 8192, 32, -1, 2]); convert_element_type_937 = None + view_as_complex_56 = torch.ops.aten.view_as_complex.default(view_966); view_966 = None + convert_element_type_938 = torch.ops.prims.convert_element_type.default(view_964, torch.float32); view_964 = None + view_967 = torch.ops.aten.view.default(convert_element_type_938, [2, 8192, 8, -1, 2]); convert_element_type_938 = None + view_as_complex_57 = torch.ops.aten.view_as_complex.default(view_967); view_967 = None + mul_226 = torch.ops.aten.mul.Tensor(view_as_complex_56, view_16); view_as_complex_56 = None + view_as_real_56 = torch.ops.aten.view_as_real.default(mul_226); mul_226 = None + view_969 = torch.ops.aten.view.default(view_as_real_56, [2, 8192, 32, 128]); view_as_real_56 = None + mul_227 = torch.ops.aten.mul.Tensor(view_as_complex_57, view_16); view_as_complex_57 = None + view_as_real_57 = torch.ops.aten.view_as_real.default(mul_227); mul_227 = None + view_970 = torch.ops.aten.view.default(view_as_real_57, [2, 8192, 8, 128]); view_as_real_57 = None + convert_element_type_939 = torch.ops.prims.convert_element_type.default(view_969, torch.bfloat16); view_969 = None + convert_element_type_940 = torch.ops.prims.convert_element_type.default(view_970, torch.bfloat16); view_970 = None + unsqueeze_56 = torch.ops.aten.unsqueeze.default(convert_element_type_940, 3); convert_element_type_940 = None + expand_56 = torch.ops.aten.expand.default(unsqueeze_56, [2, 8192, 8, 4, 128]); unsqueeze_56 = None + clone_56 = torch.ops.aten.clone.default(expand_56, memory_format = torch.contiguous_format); expand_56 = None + view_971 = torch.ops.aten.view.default(clone_56, [2, 8192, 32, 128]); clone_56 = None + unsqueeze_57 = torch.ops.aten.unsqueeze.default(view_965, 3); view_965 = None + expand_57 = torch.ops.aten.expand.default(unsqueeze_57, [2, 8192, 8, 4, 128]); unsqueeze_57 = None + clone_57 = torch.ops.aten.clone.default(expand_57, memory_format = torch.contiguous_format); 
expand_57 = None + view_972 = torch.ops.aten.view.default(clone_57, [2, 8192, 32, 128]); clone_57 = None + permute_311 = torch.ops.aten.permute.default(convert_element_type_939, [0, 2, 1, 3]); convert_element_type_939 = None + permute_312 = torch.ops.aten.permute.default(view_971, [0, 2, 1, 3]); view_971 = None + permute_313 = torch.ops.aten.permute.default(view_972, [0, 2, 1, 3]); view_972 = None + _scaled_dot_product_cudnn_attention_backward_3 = torch.ops.aten._scaled_dot_product_cudnn_attention_backward.default(permute_469, permute_311, permute_312, permute_313, getitem_252, getitem_253, getitem_258, getitem_259, None, None, None, 8192, 8192, 0.0, True); permute_469 = permute_311 = permute_312 = permute_313 = getitem_252 = getitem_253 = getitem_258 = getitem_259 = None + getitem_297 = _scaled_dot_product_cudnn_attention_backward_3[0] + getitem_298 = _scaled_dot_product_cudnn_attention_backward_3[1] + getitem_299 = _scaled_dot_product_cudnn_attention_backward_3[2]; _scaled_dot_product_cudnn_attention_backward_3 = None + permute_470 = torch.ops.aten.permute.default(getitem_299, [0, 2, 1, 3]); getitem_299 = None + permute_471 = torch.ops.aten.permute.default(getitem_298, [0, 2, 1, 3]); getitem_298 = None + permute_472 = torch.ops.aten.permute.default(getitem_297, [0, 2, 1, 3]); getitem_297 = None + view_1176 = torch.ops.aten.view.default(permute_470, [2, 8192, 8, 4, 128]); permute_470 = None + sum_23 = torch.ops.aten.sum.dim_IntList(view_1176, [3], True); view_1176 = None + squeeze_6 = torch.ops.aten.squeeze.dim(sum_23, 3); sum_23 = None + view_1177 = torch.ops.aten.view.default(permute_471, [2, 8192, 8, 4, 128]); permute_471 = None + sum_24 = torch.ops.aten.sum.dim_IntList(view_1177, [3], True); view_1177 = None + squeeze_7 = torch.ops.aten.squeeze.dim(sum_24, 3); sum_24 = None + convert_element_type_1265 = torch.ops.prims.convert_element_type.default(squeeze_7, torch.float32); squeeze_7 = None + convert_element_type_1266 = 
torch.ops.prims.convert_element_type.default(permute_472, torch.float32); permute_472 = None + view_1178 = torch.ops.aten.view.default(convert_element_type_1265, [2, 8192, 8, 64, 2]); convert_element_type_1265 = None + view_as_complex_70 = torch.ops.aten.view_as_complex.default(view_1178); view_1178 = None + mul_336 = torch.ops.aten.mul.Tensor(view_as_complex_70, _conj); view_as_complex_70 = None + view_1179 = torch.ops.aten.view.default(convert_element_type_1266, [2, 8192, 32, 64, 2]); convert_element_type_1266 = None + view_as_complex_71 = torch.ops.aten.view_as_complex.default(view_1179); view_1179 = None + mul_337 = torch.ops.aten.mul.Tensor(view_as_complex_71, _conj); view_as_complex_71 = None + view_as_real_70 = torch.ops.aten.view_as_real.default(mul_336); mul_336 = None + view_1180 = torch.ops.aten.view.default(view_as_real_70, [2, 8192, 8, 128]); view_as_real_70 = None + convert_element_type_1267 = torch.ops.prims.convert_element_type.default(view_1180, torch.bfloat16); view_1180 = None + view_as_real_71 = torch.ops.aten.view_as_real.default(mul_337); mul_337 = None + view_1181 = torch.ops.aten.view.default(view_as_real_71, [2, 8192, 32, 128]); view_as_real_71 = None + convert_element_type_1268 = torch.ops.prims.convert_element_type.default(view_1181, torch.bfloat16); view_1181 = None + view_1182 = torch.ops.aten.view.default(squeeze_6, [2, 8192, 1024]); squeeze_6 = None + view_1183 = torch.ops.aten.view.default(convert_element_type_1267, [2, 8192, 1024]); convert_element_type_1267 = None + view_1184 = torch.ops.aten.view.default(convert_element_type_1268, [2, 8192, 4096]); convert_element_type_1268 = None + view_1185 = torch.ops.aten.view.default(view_1182, [16384, 1024]); view_1182 = None + permute_473 = torch.ops.aten.permute.default(view_1185, [1, 0]) + mm_277 = torch.ops.aten.mm.default(permute_473, view_955); permute_473 = None + convert_element_type_934 = torch.ops.prims.convert_element_type.default(primals_259, torch.bfloat16); primals_259 = None + 
all_gather_into_tensor_256 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_934, 64, '0'); convert_element_type_934 = None + wait_tensor_256 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_256); all_gather_into_tensor_256 = None + permute_310 = torch.ops.aten.permute.default(wait_tensor_256, [1, 0]); wait_tensor_256 = None + permute_475 = torch.ops.aten.permute.default(permute_310, [1, 0]); permute_310 = None + mm_278 = torch.ops.aten.mm.default(view_1185, permute_475); view_1185 = permute_475 = None + view_1186 = torch.ops.aten.view.default(mm_278, [2, 8192, 4096]); mm_278 = None + convert_element_type_1273 = torch.ops.prims.convert_element_type.default(mm_277, torch.float32); mm_277 = None + reduce_scatter_tensor_34 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1273, 'avg', 64, '0'); convert_element_type_1273 = None + wait_tensor_325 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_34); reduce_scatter_tensor_34 = None + view_1187 = torch.ops.aten.view.default(view_1183, [16384, 1024]); view_1183 = None + permute_477 = torch.ops.aten.permute.default(view_1187, [1, 0]) + mm_279 = torch.ops.aten.mm.default(permute_477, view_955); permute_477 = None + permute_479 = torch.ops.aten.permute.default(permute_309, [1, 0]); permute_309 = None + mm_280 = torch.ops.aten.mm.default(view_1187, permute_479); view_1187 = permute_479 = None + view_1188 = torch.ops.aten.view.default(mm_280, [2, 8192, 4096]); mm_280 = None + add_154 = torch.ops.aten.add.Tensor(view_1186, view_1188); view_1186 = view_1188 = None + convert_element_type_1278 = torch.ops.prims.convert_element_type.default(mm_279, torch.float32); mm_279 = None + reduce_scatter_tensor_35 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1278, 'avg', 64, '0'); convert_element_type_1278 = None + wait_tensor_326 = 
torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_35); reduce_scatter_tensor_35 = None + view_1189 = torch.ops.aten.view.default(view_1184, [16384, 4096]); view_1184 = None + permute_481 = torch.ops.aten.permute.default(view_1189, [1, 0]) + mm_281 = torch.ops.aten.mm.default(permute_481, view_955); permute_481 = view_955 = None + convert_element_type_928 = torch.ops.prims.convert_element_type.default(primals_257, torch.bfloat16); primals_257 = None + all_gather_into_tensor_254 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_928, 64, '0'); convert_element_type_928 = None + wait_tensor_254 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_254); all_gather_into_tensor_254 = None + permute_308 = torch.ops.aten.permute.default(wait_tensor_254, [1, 0]); wait_tensor_254 = None + permute_483 = torch.ops.aten.permute.default(permute_308, [1, 0]); permute_308 = None + mm_282 = torch.ops.aten.mm.default(view_1189, permute_483); view_1189 = permute_483 = None + view_1190 = torch.ops.aten.view.default(mm_282, [2, 8192, 4096]); mm_282 = None + add_155 = torch.ops.aten.add.Tensor(add_154, view_1190); add_154 = view_1190 = None + convert_element_type_1283 = torch.ops.prims.convert_element_type.default(mm_281, torch.float32); mm_281 = None + reduce_scatter_tensor_36 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1283, 'avg', 64, '0'); convert_element_type_1283 = None + wait_tensor_327 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_36); reduce_scatter_tensor_36 = None + convert_element_type_1284 = torch.ops.prims.convert_element_type.default(add_155, torch.float32); add_155 = None + convert_element_type_1286 = torch.ops.prims.convert_element_type.default(wait_tensor_253, torch.float32); wait_tensor_253 = None + mul_338 = torch.ops.aten.mul.Tensor(convert_element_type_1284, convert_element_type_1286); convert_element_type_1286 = None + mul_340 = 
torch.ops.aten.mul.Tensor(mul_224, mul_338) + sum_25 = torch.ops.aten.sum.dim_IntList(mul_340, [2], True); mul_340 = None + div_8 = torch.ops.aten.div.Tensor(mul_224, 4096) + mul_341 = torch.ops.aten.mul.Tensor(div_8, sum_25); div_8 = sum_25 = None + sub_12 = torch.ops.aten.sub.Tensor(mul_338, mul_341); mul_338 = mul_341 = None + mul_342 = torch.ops.aten.mul.Tensor(sub_12, rsqrt_56); sub_12 = rsqrt_56 = None + mul_343 = torch.ops.aten.mul.Tensor(convert_element_type_1284, mul_224); convert_element_type_1284 = mul_224 = None + sum_26 = torch.ops.aten.sum.dim_IntList(mul_343, [0, 1]); mul_343 = None + convert_element_type_1287 = torch.ops.prims.convert_element_type.default(mul_342, torch.bfloat16); mul_342 = None + add_156 = torch.ops.aten.add.Tensor(add_153, convert_element_type_1287); add_153 = convert_element_type_1287 = None + convert_element_type_default_57 = torch.ops.prims.convert_element_type.default(sum_26, torch.float32); sum_26 = None + reduce_scatter_tensor_37 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_57, 'avg', 64, '0'); convert_element_type_default_57 = None + wait_tensor_328 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_37); reduce_scatter_tensor_37 = None + view_1191 = torch.ops.aten.view.default(add_156, [16384, 4096]) + permute_485 = torch.ops.aten.permute.default(view_1191, [1, 0]) + permute_303 = torch.ops.aten.permute.default(getitem_243, [0, 2, 1, 3]) + view_939 = torch.ops.aten.view.default(permute_303, [2, 8192, -1]); permute_303 = None + convert_element_type_908 = torch.ops.prims.convert_element_type.default(primals_251, torch.bfloat16); primals_251 = None + all_gather_into_tensor_248 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_908, 64, '0'); convert_element_type_908 = None + wait_tensor_248 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_248); all_gather_into_tensor_248 = None + permute_304 = 
torch.ops.aten.permute.default(wait_tensor_248, [1, 0]); wait_tensor_248 = None + view_941 = torch.ops.aten.view.default(view_939, [16384, 4096]); view_939 = None + mm_192 = torch.ops.aten.mm.default(view_941, permute_304) + view_942 = torch.ops.aten.view.default(mm_192, [2, 8192, 4096]); mm_192 = None + add_109 = torch.ops.aten.add.Tensor(add_107, view_942); view_942 = None + convert_element_type_911 = torch.ops.prims.convert_element_type.default(primals_252, torch.bfloat16); primals_252 = None + all_gather_into_tensor_249 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_911, 64, '0'); convert_element_type_911 = None + wait_tensor_249 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_249); all_gather_into_tensor_249 = None + convert_element_type_912 = torch.ops.prims.convert_element_type.default(add_109, torch.float32); add_109 = None + pow_56 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_912, 2) + mean_55 = torch.ops.aten.mean.dim(pow_56, [2], True); pow_56 = None + add_110 = torch.ops.aten.add.Scalar(mean_55, 1e-05); mean_55 = None + rsqrt_55 = torch.ops.aten.rsqrt.default(add_110); add_110 = None + mul_220 = torch.ops.aten.mul.Tensor(convert_element_type_912, rsqrt_55); convert_element_type_912 = None + mul_221 = torch.ops.aten.mul.Tensor(mul_220, wait_tensor_249) + convert_element_type_913 = torch.ops.prims.convert_element_type.default(mul_221, torch.bfloat16); mul_221 = None + view_945 = torch.ops.aten.view.default(convert_element_type_913, [16384, 4096]); convert_element_type_913 = None + view_946 = torch.ops.aten.view.default(mm_193, [2, 8192, 14336]); mm_193 = None + convert_element_type_917 = torch.ops.prims.convert_element_type.default(view_946, torch.float32); view_946 = None + sigmoid_27 = torch.ops.aten.sigmoid.default(convert_element_type_917) + mul_222 = torch.ops.aten.mul.Tensor(convert_element_type_917, sigmoid_27); sigmoid_27 = None + convert_element_type_918 = 
torch.ops.prims.convert_element_type.default(mul_222, torch.bfloat16); mul_222 = None + convert_element_type_919 = torch.ops.prims.convert_element_type.default(primals_254, torch.bfloat16); primals_254 = None + all_gather_into_tensor_251 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_919, 64, '0'); convert_element_type_919 = None + wait_tensor_251 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_251); all_gather_into_tensor_251 = None + permute_306 = torch.ops.aten.permute.default(wait_tensor_251, [1, 0]); wait_tensor_251 = None + mm_194 = torch.ops.aten.mm.default(view_945, permute_306) + view_949 = torch.ops.aten.view.default(mm_194, [2, 8192, 14336]); mm_194 = None + mul_223 = torch.ops.aten.mul.Tensor(convert_element_type_918, view_949) + view_951 = torch.ops.aten.view.default(mul_223, [16384, 14336]); mul_223 = None + mm_283 = torch.ops.aten.mm.default(permute_485, view_951); permute_485 = view_951 = None + convert_element_type_922 = torch.ops.prims.convert_element_type.default(primals_255, torch.bfloat16); primals_255 = None + all_gather_into_tensor_252 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_922, 64, '0'); convert_element_type_922 = None + wait_tensor_252 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_252); all_gather_into_tensor_252 = None + permute_307 = torch.ops.aten.permute.default(wait_tensor_252, [1, 0]); wait_tensor_252 = None + permute_487 = torch.ops.aten.permute.default(permute_307, [1, 0]); permute_307 = None + mm_284 = torch.ops.aten.mm.default(view_1191, permute_487); view_1191 = permute_487 = None + view_1192 = torch.ops.aten.view.default(mm_284, [2, 8192, 14336]); mm_284 = None + convert_element_type_1294 = torch.ops.prims.convert_element_type.default(mm_283, torch.float32); mm_283 = None + reduce_scatter_tensor_38 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1294, 'avg', 64, '0'); 
convert_element_type_1294 = None + wait_tensor_329 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_38); reduce_scatter_tensor_38 = None + mul_344 = torch.ops.aten.mul.Tensor(view_1192, convert_element_type_918); convert_element_type_918 = None + mul_345 = torch.ops.aten.mul.Tensor(view_1192, view_949); view_1192 = view_949 = None + view_1193 = torch.ops.aten.view.default(mul_344, [16384, 14336]); mul_344 = None + permute_489 = torch.ops.aten.permute.default(view_1193, [1, 0]) + mm_285 = torch.ops.aten.mm.default(permute_489, view_945); permute_489 = None + permute_491 = torch.ops.aten.permute.default(permute_306, [1, 0]); permute_306 = None + mm_286 = torch.ops.aten.mm.default(view_1193, permute_491); view_1193 = permute_491 = None + view_1194 = torch.ops.aten.view.default(mm_286, [2, 8192, 4096]); mm_286 = None + convert_element_type_1299 = torch.ops.prims.convert_element_type.default(mm_285, torch.float32); mm_285 = None + reduce_scatter_tensor_39 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1299, 'avg', 64, '0'); convert_element_type_1299 = None + wait_tensor_330 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_39); reduce_scatter_tensor_39 = None + convert_element_type_1300 = torch.ops.prims.convert_element_type.default(mul_345, torch.float32); mul_345 = None + neg_4 = torch.ops.aten.neg.default(convert_element_type_917) + exp_4 = torch.ops.aten.exp.default(neg_4); neg_4 = None + add_157 = torch.ops.aten.add.Tensor(exp_4, 1); exp_4 = None + reciprocal_4 = torch.ops.aten.reciprocal.default(add_157); add_157 = None + mul_346 = torch.ops.aten.mul.Tensor(reciprocal_4, 1); reciprocal_4 = None + mul_347 = torch.ops.aten.mul.Tensor(convert_element_type_1300, mul_346); convert_element_type_1300 = None + sub_13 = torch.ops.aten.sub.Tensor(1, mul_346); mul_346 = None + mul_348 = torch.ops.aten.mul.Tensor(convert_element_type_917, sub_13); convert_element_type_917 = sub_13 = None + add_158 = 
torch.ops.aten.add.Tensor(mul_348, 1); mul_348 = None + mul_349 = torch.ops.aten.mul.Tensor(mul_347, add_158); mul_347 = add_158 = None + convert_element_type_1302 = torch.ops.prims.convert_element_type.default(mul_349, torch.bfloat16); mul_349 = None + view_1195 = torch.ops.aten.view.default(convert_element_type_1302, [16384, 14336]); convert_element_type_1302 = None + permute_493 = torch.ops.aten.permute.default(view_1195, [1, 0]) + mm_287 = torch.ops.aten.mm.default(permute_493, view_945); permute_493 = view_945 = None + convert_element_type_914 = torch.ops.prims.convert_element_type.default(primals_253, torch.bfloat16); primals_253 = None + all_gather_into_tensor_250 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_914, 64, '0'); convert_element_type_914 = None + wait_tensor_250 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_250); all_gather_into_tensor_250 = None + permute_305 = torch.ops.aten.permute.default(wait_tensor_250, [1, 0]); wait_tensor_250 = None + permute_495 = torch.ops.aten.permute.default(permute_305, [1, 0]); permute_305 = None + mm_288 = torch.ops.aten.mm.default(view_1195, permute_495); view_1195 = permute_495 = None + view_1196 = torch.ops.aten.view.default(mm_288, [2, 8192, 4096]); mm_288 = None + add_159 = torch.ops.aten.add.Tensor(view_1194, view_1196); view_1194 = view_1196 = None + convert_element_type_1307 = torch.ops.prims.convert_element_type.default(mm_287, torch.float32); mm_287 = None + reduce_scatter_tensor_40 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1307, 'avg', 64, '0'); convert_element_type_1307 = None + wait_tensor_331 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_40); reduce_scatter_tensor_40 = None + convert_element_type_1308 = torch.ops.prims.convert_element_type.default(add_159, torch.float32); add_159 = None + convert_element_type_1310 = torch.ops.prims.convert_element_type.default(wait_tensor_249, 
torch.float32); wait_tensor_249 = None + mul_350 = torch.ops.aten.mul.Tensor(convert_element_type_1308, convert_element_type_1310); convert_element_type_1310 = None + mul_352 = torch.ops.aten.mul.Tensor(mul_220, mul_350) + sum_27 = torch.ops.aten.sum.dim_IntList(mul_352, [2], True); mul_352 = None + div_9 = torch.ops.aten.div.Tensor(mul_220, 4096) + mul_353 = torch.ops.aten.mul.Tensor(div_9, sum_27); div_9 = sum_27 = None + sub_14 = torch.ops.aten.sub.Tensor(mul_350, mul_353); mul_350 = mul_353 = None + mul_354 = torch.ops.aten.mul.Tensor(sub_14, rsqrt_55); sub_14 = rsqrt_55 = None + mul_355 = torch.ops.aten.mul.Tensor(convert_element_type_1308, mul_220); convert_element_type_1308 = mul_220 = None + sum_28 = torch.ops.aten.sum.dim_IntList(mul_355, [0, 1]); mul_355 = None + convert_element_type_1311 = torch.ops.prims.convert_element_type.default(mul_354, torch.bfloat16); mul_354 = None + add_160 = torch.ops.aten.add.Tensor(add_156, convert_element_type_1311); add_156 = convert_element_type_1311 = None + convert_element_type_default_56 = torch.ops.prims.convert_element_type.default(sum_28, torch.float32); sum_28 = None + reduce_scatter_tensor_41 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_56, 'avg', 64, '0'); convert_element_type_default_56 = None + wait_tensor_332 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_41); reduce_scatter_tensor_41 = None + view_1197 = torch.ops.aten.view.default(add_160, [16384, 4096]) + permute_497 = torch.ops.aten.permute.default(view_1197, [1, 0]) + mm_289 = torch.ops.aten.mm.default(permute_497, view_941); permute_497 = view_941 = None + permute_499 = torch.ops.aten.permute.default(permute_304, [1, 0]); permute_304 = None + mm_290 = torch.ops.aten.mm.default(view_1197, permute_499); view_1197 = permute_499 = None + view_1198 = torch.ops.aten.view.default(mm_290, [2, 8192, 4096]); mm_290 = None + convert_element_type_1318 = 
torch.ops.prims.convert_element_type.default(mm_289, torch.float32); mm_289 = None + reduce_scatter_tensor_42 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1318, 'avg', 64, '0'); convert_element_type_1318 = None + wait_tensor_333 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_42); reduce_scatter_tensor_42 = None + view_1199 = torch.ops.aten.view.default(view_1198, [2, 8192, 32, 128]); view_1198 = None + permute_501 = torch.ops.aten.permute.default(view_1199, [0, 2, 1, 3]); view_1199 = None + convert_element_type_892 = torch.ops.prims.convert_element_type.default(primals_247, torch.bfloat16); primals_247 = None + all_gather_into_tensor_244 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_892, 64, '0'); convert_element_type_892 = None + wait_tensor_244 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_244); all_gather_into_tensor_244 = None + convert_element_type_893 = torch.ops.prims.convert_element_type.default(add_107, torch.float32); add_107 = None + pow_55 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_893, 2) + mean_54 = torch.ops.aten.mean.dim(pow_55, [2], True); pow_55 = None + add_108 = torch.ops.aten.add.Scalar(mean_54, 1e-05); mean_54 = None + rsqrt_54 = torch.ops.aten.rsqrt.default(add_108); add_108 = None + mul_216 = torch.ops.aten.mul.Tensor(convert_element_type_893, rsqrt_54); convert_element_type_893 = None + mul_217 = torch.ops.aten.mul.Tensor(mul_216, wait_tensor_244) + convert_element_type_894 = torch.ops.prims.convert_element_type.default(mul_217, torch.bfloat16); mul_217 = None + view_921 = torch.ops.aten.view.default(convert_element_type_894, [16384, 4096]); convert_element_type_894 = None + view_922 = torch.ops.aten.view.default(mm_189, [2, 8192, 4096]); mm_189 = None + convert_element_type_898 = torch.ops.prims.convert_element_type.default(primals_249, torch.bfloat16); primals_249 = None + all_gather_into_tensor_246 = 
torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_898, 64, '0'); convert_element_type_898 = None + wait_tensor_246 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_246); all_gather_into_tensor_246 = None + permute_298 = torch.ops.aten.permute.default(wait_tensor_246, [1, 0]); wait_tensor_246 = None + mm_190 = torch.ops.aten.mm.default(view_921, permute_298) + view_925 = torch.ops.aten.view.default(mm_190, [2, 8192, 1024]); mm_190 = None + view_928 = torch.ops.aten.view.default(mm_191, [2, 8192, 1024]); mm_191 = None + view_929 = torch.ops.aten.view.default(view_922, [2, 8192, -1, 128]); view_922 = None + view_930 = torch.ops.aten.view.default(view_925, [2, 8192, -1, 128]); view_925 = None + view_931 = torch.ops.aten.view.default(view_928, [2, 8192, -1, 128]); view_928 = None + convert_element_type_904 = torch.ops.prims.convert_element_type.default(view_929, torch.float32); view_929 = None + view_932 = torch.ops.aten.view.default(convert_element_type_904, [2, 8192, 32, -1, 2]); convert_element_type_904 = None + view_as_complex_54 = torch.ops.aten.view_as_complex.default(view_932); view_932 = None + convert_element_type_905 = torch.ops.prims.convert_element_type.default(view_930, torch.float32); view_930 = None + view_933 = torch.ops.aten.view.default(convert_element_type_905, [2, 8192, 8, -1, 2]); convert_element_type_905 = None + view_as_complex_55 = torch.ops.aten.view_as_complex.default(view_933); view_933 = None + mul_218 = torch.ops.aten.mul.Tensor(view_as_complex_54, view_16); view_as_complex_54 = None + view_as_real_54 = torch.ops.aten.view_as_real.default(mul_218); mul_218 = None + view_935 = torch.ops.aten.view.default(view_as_real_54, [2, 8192, 32, 128]); view_as_real_54 = None + mul_219 = torch.ops.aten.mul.Tensor(view_as_complex_55, view_16); view_as_complex_55 = None + view_as_real_55 = torch.ops.aten.view_as_real.default(mul_219); mul_219 = None + view_936 = 
torch.ops.aten.view.default(view_as_real_55, [2, 8192, 8, 128]); view_as_real_55 = None + convert_element_type_906 = torch.ops.prims.convert_element_type.default(view_935, torch.bfloat16); view_935 = None + convert_element_type_907 = torch.ops.prims.convert_element_type.default(view_936, torch.bfloat16); view_936 = None + unsqueeze_54 = torch.ops.aten.unsqueeze.default(convert_element_type_907, 3); convert_element_type_907 = None + expand_54 = torch.ops.aten.expand.default(unsqueeze_54, [2, 8192, 8, 4, 128]); unsqueeze_54 = None + clone_54 = torch.ops.aten.clone.default(expand_54, memory_format = torch.contiguous_format); expand_54 = None + view_937 = torch.ops.aten.view.default(clone_54, [2, 8192, 32, 128]); clone_54 = None + unsqueeze_55 = torch.ops.aten.unsqueeze.default(view_931, 3); view_931 = None + expand_55 = torch.ops.aten.expand.default(unsqueeze_55, [2, 8192, 8, 4, 128]); unsqueeze_55 = None + clone_55 = torch.ops.aten.clone.default(expand_55, memory_format = torch.contiguous_format); expand_55 = None + view_938 = torch.ops.aten.view.default(clone_55, [2, 8192, 32, 128]); clone_55 = None + permute_300 = torch.ops.aten.permute.default(convert_element_type_906, [0, 2, 1, 3]); convert_element_type_906 = None + permute_301 = torch.ops.aten.permute.default(view_937, [0, 2, 1, 3]); view_937 = None + permute_302 = torch.ops.aten.permute.default(view_938, [0, 2, 1, 3]); view_938 = None + _scaled_dot_product_cudnn_attention_backward_4 = torch.ops.aten._scaled_dot_product_cudnn_attention_backward.default(permute_501, permute_300, permute_301, permute_302, getitem_243, getitem_244, getitem_249, getitem_250, None, None, None, 8192, 8192, 0.0, True); permute_501 = permute_300 = permute_301 = permute_302 = getitem_243 = getitem_244 = getitem_249 = getitem_250 = None + getitem_300 = _scaled_dot_product_cudnn_attention_backward_4[0] + getitem_301 = _scaled_dot_product_cudnn_attention_backward_4[1] + getitem_302 = _scaled_dot_product_cudnn_attention_backward_4[2]; 
_scaled_dot_product_cudnn_attention_backward_4 = None + permute_502 = torch.ops.aten.permute.default(getitem_302, [0, 2, 1, 3]); getitem_302 = None + permute_503 = torch.ops.aten.permute.default(getitem_301, [0, 2, 1, 3]); getitem_301 = None + permute_504 = torch.ops.aten.permute.default(getitem_300, [0, 2, 1, 3]); getitem_300 = None + view_1200 = torch.ops.aten.view.default(permute_502, [2, 8192, 8, 4, 128]); permute_502 = None + sum_29 = torch.ops.aten.sum.dim_IntList(view_1200, [3], True); view_1200 = None + squeeze_8 = torch.ops.aten.squeeze.dim(sum_29, 3); sum_29 = None + view_1201 = torch.ops.aten.view.default(permute_503, [2, 8192, 8, 4, 128]); permute_503 = None + sum_30 = torch.ops.aten.sum.dim_IntList(view_1201, [3], True); view_1201 = None + squeeze_9 = torch.ops.aten.squeeze.dim(sum_30, 3); sum_30 = None + convert_element_type_1319 = torch.ops.prims.convert_element_type.default(squeeze_9, torch.float32); squeeze_9 = None + convert_element_type_1320 = torch.ops.prims.convert_element_type.default(permute_504, torch.float32); permute_504 = None + view_1202 = torch.ops.aten.view.default(convert_element_type_1319, [2, 8192, 8, 64, 2]); convert_element_type_1319 = None + view_as_complex_72 = torch.ops.aten.view_as_complex.default(view_1202); view_1202 = None + mul_356 = torch.ops.aten.mul.Tensor(view_as_complex_72, _conj); view_as_complex_72 = None + view_1203 = torch.ops.aten.view.default(convert_element_type_1320, [2, 8192, 32, 64, 2]); convert_element_type_1320 = None + view_as_complex_73 = torch.ops.aten.view_as_complex.default(view_1203); view_1203 = None + mul_357 = torch.ops.aten.mul.Tensor(view_as_complex_73, _conj); view_as_complex_73 = None + view_as_real_72 = torch.ops.aten.view_as_real.default(mul_356); mul_356 = None + view_1204 = torch.ops.aten.view.default(view_as_real_72, [2, 8192, 8, 128]); view_as_real_72 = None + convert_element_type_1321 = torch.ops.prims.convert_element_type.default(view_1204, torch.bfloat16); view_1204 = None + 
view_as_real_73 = torch.ops.aten.view_as_real.default(mul_357); mul_357 = None + view_1205 = torch.ops.aten.view.default(view_as_real_73, [2, 8192, 32, 128]); view_as_real_73 = None + convert_element_type_1322 = torch.ops.prims.convert_element_type.default(view_1205, torch.bfloat16); view_1205 = None + view_1206 = torch.ops.aten.view.default(squeeze_8, [2, 8192, 1024]); squeeze_8 = None + view_1207 = torch.ops.aten.view.default(convert_element_type_1321, [2, 8192, 1024]); convert_element_type_1321 = None + view_1208 = torch.ops.aten.view.default(convert_element_type_1322, [2, 8192, 4096]); convert_element_type_1322 = None + view_1209 = torch.ops.aten.view.default(view_1206, [16384, 1024]); view_1206 = None + permute_505 = torch.ops.aten.permute.default(view_1209, [1, 0]) + mm_291 = torch.ops.aten.mm.default(permute_505, view_921); permute_505 = None + convert_element_type_901 = torch.ops.prims.convert_element_type.default(primals_250, torch.bfloat16); primals_250 = None + all_gather_into_tensor_247 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_901, 64, '0'); convert_element_type_901 = None + wait_tensor_247 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_247); all_gather_into_tensor_247 = None + permute_299 = torch.ops.aten.permute.default(wait_tensor_247, [1, 0]); wait_tensor_247 = None + permute_507 = torch.ops.aten.permute.default(permute_299, [1, 0]); permute_299 = None + mm_292 = torch.ops.aten.mm.default(view_1209, permute_507); view_1209 = permute_507 = None + view_1210 = torch.ops.aten.view.default(mm_292, [2, 8192, 4096]); mm_292 = None + convert_element_type_1327 = torch.ops.prims.convert_element_type.default(mm_291, torch.float32); mm_291 = None + reduce_scatter_tensor_43 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1327, 'avg', 64, '0'); convert_element_type_1327 = None + wait_tensor_334 = 
torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_43); reduce_scatter_tensor_43 = None + view_1211 = torch.ops.aten.view.default(view_1207, [16384, 1024]); view_1207 = None + permute_509 = torch.ops.aten.permute.default(view_1211, [1, 0]) + mm_293 = torch.ops.aten.mm.default(permute_509, view_921); permute_509 = None + permute_511 = torch.ops.aten.permute.default(permute_298, [1, 0]); permute_298 = None + mm_294 = torch.ops.aten.mm.default(view_1211, permute_511); view_1211 = permute_511 = None + view_1212 = torch.ops.aten.view.default(mm_294, [2, 8192, 4096]); mm_294 = None + add_161 = torch.ops.aten.add.Tensor(view_1210, view_1212); view_1210 = view_1212 = None + convert_element_type_1332 = torch.ops.prims.convert_element_type.default(mm_293, torch.float32); mm_293 = None + reduce_scatter_tensor_44 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1332, 'avg', 64, '0'); convert_element_type_1332 = None + wait_tensor_335 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_44); reduce_scatter_tensor_44 = None + view_1213 = torch.ops.aten.view.default(view_1208, [16384, 4096]); view_1208 = None + permute_513 = torch.ops.aten.permute.default(view_1213, [1, 0]) + mm_295 = torch.ops.aten.mm.default(permute_513, view_921); permute_513 = view_921 = None + convert_element_type_895 = torch.ops.prims.convert_element_type.default(primals_248, torch.bfloat16); primals_248 = None + all_gather_into_tensor_245 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_895, 64, '0'); convert_element_type_895 = None + wait_tensor_245 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_245); all_gather_into_tensor_245 = None + permute_297 = torch.ops.aten.permute.default(wait_tensor_245, [1, 0]); wait_tensor_245 = None + permute_515 = torch.ops.aten.permute.default(permute_297, [1, 0]); permute_297 = None + mm_296 = torch.ops.aten.mm.default(view_1213, permute_515); 
view_1213 = permute_515 = None + view_1214 = torch.ops.aten.view.default(mm_296, [2, 8192, 4096]); mm_296 = None + add_162 = torch.ops.aten.add.Tensor(add_161, view_1214); add_161 = view_1214 = None + convert_element_type_1337 = torch.ops.prims.convert_element_type.default(mm_295, torch.float32); mm_295 = None + reduce_scatter_tensor_45 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1337, 'avg', 64, '0'); convert_element_type_1337 = None + wait_tensor_336 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_45); reduce_scatter_tensor_45 = None + convert_element_type_1338 = torch.ops.prims.convert_element_type.default(add_162, torch.float32); add_162 = None + convert_element_type_1340 = torch.ops.prims.convert_element_type.default(wait_tensor_244, torch.float32); wait_tensor_244 = None + mul_358 = torch.ops.aten.mul.Tensor(convert_element_type_1338, convert_element_type_1340); convert_element_type_1340 = None + mul_360 = torch.ops.aten.mul.Tensor(mul_216, mul_358) + sum_31 = torch.ops.aten.sum.dim_IntList(mul_360, [2], True); mul_360 = None + div_10 = torch.ops.aten.div.Tensor(mul_216, 4096) + mul_361 = torch.ops.aten.mul.Tensor(div_10, sum_31); div_10 = sum_31 = None + sub_15 = torch.ops.aten.sub.Tensor(mul_358, mul_361); mul_358 = mul_361 = None + mul_362 = torch.ops.aten.mul.Tensor(sub_15, rsqrt_54); sub_15 = rsqrt_54 = None + mul_363 = torch.ops.aten.mul.Tensor(convert_element_type_1338, mul_216); convert_element_type_1338 = mul_216 = None + sum_32 = torch.ops.aten.sum.dim_IntList(mul_363, [0, 1]); mul_363 = None + convert_element_type_1341 = torch.ops.prims.convert_element_type.default(mul_362, torch.bfloat16); mul_362 = None + add_163 = torch.ops.aten.add.Tensor(add_160, convert_element_type_1341); add_160 = convert_element_type_1341 = None + convert_element_type_default_55 = torch.ops.prims.convert_element_type.default(sum_32, torch.float32); sum_32 = None + reduce_scatter_tensor_46 = 
torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_55, 'avg', 64, '0'); convert_element_type_default_55 = None + wait_tensor_337 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_46); reduce_scatter_tensor_46 = None + view_1215 = torch.ops.aten.view.default(add_163, [16384, 4096]) + permute_517 = torch.ops.aten.permute.default(view_1215, [1, 0]) + permute_292 = torch.ops.aten.permute.default(getitem_234, [0, 2, 1, 3]) + view_905 = torch.ops.aten.view.default(permute_292, [2, 8192, -1]); permute_292 = None + convert_element_type_875 = torch.ops.prims.convert_element_type.default(primals_242, torch.bfloat16); primals_242 = None + all_gather_into_tensor_239 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_875, 64, '0'); convert_element_type_875 = None + wait_tensor_239 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_239); all_gather_into_tensor_239 = None + permute_293 = torch.ops.aten.permute.default(wait_tensor_239, [1, 0]); wait_tensor_239 = None + view_907 = torch.ops.aten.view.default(view_905, [16384, 4096]); view_905 = None + mm_185 = torch.ops.aten.mm.default(view_907, permute_293) + view_908 = torch.ops.aten.view.default(mm_185, [2, 8192, 4096]); mm_185 = None + add_105 = torch.ops.aten.add.Tensor(add_103, view_908); view_908 = None + convert_element_type_878 = torch.ops.prims.convert_element_type.default(primals_243, torch.bfloat16); primals_243 = None + all_gather_into_tensor_240 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_878, 64, '0'); convert_element_type_878 = None + wait_tensor_240 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_240); all_gather_into_tensor_240 = None + convert_element_type_879 = torch.ops.prims.convert_element_type.default(add_105, torch.float32); add_105 = None + pow_54 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_879, 2) + mean_53 = 
torch.ops.aten.mean.dim(pow_54, [2], True); pow_54 = None + add_106 = torch.ops.aten.add.Scalar(mean_53, 1e-05); mean_53 = None + rsqrt_53 = torch.ops.aten.rsqrt.default(add_106); add_106 = None + mul_212 = torch.ops.aten.mul.Tensor(convert_element_type_879, rsqrt_53); convert_element_type_879 = None + mul_213 = torch.ops.aten.mul.Tensor(mul_212, wait_tensor_240) + convert_element_type_880 = torch.ops.prims.convert_element_type.default(mul_213, torch.bfloat16); mul_213 = None + view_911 = torch.ops.aten.view.default(convert_element_type_880, [16384, 4096]); convert_element_type_880 = None + view_912 = torch.ops.aten.view.default(mm_186, [2, 8192, 14336]); mm_186 = None + convert_element_type_884 = torch.ops.prims.convert_element_type.default(view_912, torch.float32); view_912 = None + sigmoid_26 = torch.ops.aten.sigmoid.default(convert_element_type_884) + mul_214 = torch.ops.aten.mul.Tensor(convert_element_type_884, sigmoid_26); sigmoid_26 = None + convert_element_type_885 = torch.ops.prims.convert_element_type.default(mul_214, torch.bfloat16); mul_214 = None + convert_element_type_886 = torch.ops.prims.convert_element_type.default(primals_245, torch.bfloat16); primals_245 = None + all_gather_into_tensor_242 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_886, 64, '0'); convert_element_type_886 = None + wait_tensor_242 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_242); all_gather_into_tensor_242 = None + permute_295 = torch.ops.aten.permute.default(wait_tensor_242, [1, 0]); wait_tensor_242 = None + mm_187 = torch.ops.aten.mm.default(view_911, permute_295) + view_915 = torch.ops.aten.view.default(mm_187, [2, 8192, 14336]); mm_187 = None + mul_215 = torch.ops.aten.mul.Tensor(convert_element_type_885, view_915) + view_917 = torch.ops.aten.view.default(mul_215, [16384, 14336]); mul_215 = None + mm_297 = torch.ops.aten.mm.default(permute_517, view_917); permute_517 = view_917 = None + convert_element_type_889 = 
torch.ops.prims.convert_element_type.default(primals_246, torch.bfloat16); primals_246 = None + all_gather_into_tensor_243 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_889, 64, '0'); convert_element_type_889 = None + wait_tensor_243 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_243); all_gather_into_tensor_243 = None + permute_296 = torch.ops.aten.permute.default(wait_tensor_243, [1, 0]); wait_tensor_243 = None + permute_519 = torch.ops.aten.permute.default(permute_296, [1, 0]); permute_296 = None + mm_298 = torch.ops.aten.mm.default(view_1215, permute_519); view_1215 = permute_519 = None + view_1216 = torch.ops.aten.view.default(mm_298, [2, 8192, 14336]); mm_298 = None + convert_element_type_1348 = torch.ops.prims.convert_element_type.default(mm_297, torch.float32); mm_297 = None + reduce_scatter_tensor_47 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1348, 'avg', 64, '0'); convert_element_type_1348 = None + wait_tensor_338 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_47); reduce_scatter_tensor_47 = None + mul_364 = torch.ops.aten.mul.Tensor(view_1216, convert_element_type_885); convert_element_type_885 = None + mul_365 = torch.ops.aten.mul.Tensor(view_1216, view_915); view_1216 = view_915 = None + view_1217 = torch.ops.aten.view.default(mul_364, [16384, 14336]); mul_364 = None + permute_521 = torch.ops.aten.permute.default(view_1217, [1, 0]) + mm_299 = torch.ops.aten.mm.default(permute_521, view_911); permute_521 = None + permute_523 = torch.ops.aten.permute.default(permute_295, [1, 0]); permute_295 = None + mm_300 = torch.ops.aten.mm.default(view_1217, permute_523); view_1217 = permute_523 = None + view_1218 = torch.ops.aten.view.default(mm_300, [2, 8192, 4096]); mm_300 = None + convert_element_type_1353 = torch.ops.prims.convert_element_type.default(mm_299, torch.float32); mm_299 = None + reduce_scatter_tensor_48 = 
torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1353, 'avg', 64, '0'); convert_element_type_1353 = None + wait_tensor_339 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_48); reduce_scatter_tensor_48 = None + convert_element_type_1354 = torch.ops.prims.convert_element_type.default(mul_365, torch.float32); mul_365 = None + neg_5 = torch.ops.aten.neg.default(convert_element_type_884) + exp_5 = torch.ops.aten.exp.default(neg_5); neg_5 = None + add_164 = torch.ops.aten.add.Tensor(exp_5, 1); exp_5 = None + reciprocal_5 = torch.ops.aten.reciprocal.default(add_164); add_164 = None + mul_366 = torch.ops.aten.mul.Tensor(reciprocal_5, 1); reciprocal_5 = None + mul_367 = torch.ops.aten.mul.Tensor(convert_element_type_1354, mul_366); convert_element_type_1354 = None + sub_16 = torch.ops.aten.sub.Tensor(1, mul_366); mul_366 = None + mul_368 = torch.ops.aten.mul.Tensor(convert_element_type_884, sub_16); convert_element_type_884 = sub_16 = None + add_165 = torch.ops.aten.add.Tensor(mul_368, 1); mul_368 = None + mul_369 = torch.ops.aten.mul.Tensor(mul_367, add_165); mul_367 = add_165 = None + convert_element_type_1356 = torch.ops.prims.convert_element_type.default(mul_369, torch.bfloat16); mul_369 = None + view_1219 = torch.ops.aten.view.default(convert_element_type_1356, [16384, 14336]); convert_element_type_1356 = None + permute_525 = torch.ops.aten.permute.default(view_1219, [1, 0]) + mm_301 = torch.ops.aten.mm.default(permute_525, view_911); permute_525 = view_911 = None + convert_element_type_881 = torch.ops.prims.convert_element_type.default(primals_244, torch.bfloat16); primals_244 = None + all_gather_into_tensor_241 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_881, 64, '0'); convert_element_type_881 = None + wait_tensor_241 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_241); all_gather_into_tensor_241 = None + permute_294 = 
torch.ops.aten.permute.default(wait_tensor_241, [1, 0]); wait_tensor_241 = None + permute_527 = torch.ops.aten.permute.default(permute_294, [1, 0]); permute_294 = None + mm_302 = torch.ops.aten.mm.default(view_1219, permute_527); view_1219 = permute_527 = None + view_1220 = torch.ops.aten.view.default(mm_302, [2, 8192, 4096]); mm_302 = None + add_166 = torch.ops.aten.add.Tensor(view_1218, view_1220); view_1218 = view_1220 = None + convert_element_type_1361 = torch.ops.prims.convert_element_type.default(mm_301, torch.float32); mm_301 = None + reduce_scatter_tensor_49 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1361, 'avg', 64, '0'); convert_element_type_1361 = None + wait_tensor_340 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_49); reduce_scatter_tensor_49 = None + convert_element_type_1362 = torch.ops.prims.convert_element_type.default(add_166, torch.float32); add_166 = None + convert_element_type_1364 = torch.ops.prims.convert_element_type.default(wait_tensor_240, torch.float32); wait_tensor_240 = None + mul_370 = torch.ops.aten.mul.Tensor(convert_element_type_1362, convert_element_type_1364); convert_element_type_1364 = None + mul_372 = torch.ops.aten.mul.Tensor(mul_212, mul_370) + sum_33 = torch.ops.aten.sum.dim_IntList(mul_372, [2], True); mul_372 = None + div_11 = torch.ops.aten.div.Tensor(mul_212, 4096) + mul_373 = torch.ops.aten.mul.Tensor(div_11, sum_33); div_11 = sum_33 = None + sub_17 = torch.ops.aten.sub.Tensor(mul_370, mul_373); mul_370 = mul_373 = None + mul_374 = torch.ops.aten.mul.Tensor(sub_17, rsqrt_53); sub_17 = rsqrt_53 = None + mul_375 = torch.ops.aten.mul.Tensor(convert_element_type_1362, mul_212); convert_element_type_1362 = mul_212 = None + sum_34 = torch.ops.aten.sum.dim_IntList(mul_375, [0, 1]); mul_375 = None + convert_element_type_1365 = torch.ops.prims.convert_element_type.default(mul_374, torch.bfloat16); mul_374 = None + add_167 = torch.ops.aten.add.Tensor(add_163, 
convert_element_type_1365); add_163 = convert_element_type_1365 = None + convert_element_type_default_54 = torch.ops.prims.convert_element_type.default(sum_34, torch.float32); sum_34 = None + reduce_scatter_tensor_50 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_54, 'avg', 64, '0'); convert_element_type_default_54 = None + wait_tensor_341 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_50); reduce_scatter_tensor_50 = None + view_1221 = torch.ops.aten.view.default(add_167, [16384, 4096]) + permute_529 = torch.ops.aten.permute.default(view_1221, [1, 0]) + mm_303 = torch.ops.aten.mm.default(permute_529, view_907); permute_529 = view_907 = None + permute_531 = torch.ops.aten.permute.default(permute_293, [1, 0]); permute_293 = None + mm_304 = torch.ops.aten.mm.default(view_1221, permute_531); view_1221 = permute_531 = None + view_1222 = torch.ops.aten.view.default(mm_304, [2, 8192, 4096]); mm_304 = None + convert_element_type_1372 = torch.ops.prims.convert_element_type.default(mm_303, torch.float32); mm_303 = None + reduce_scatter_tensor_51 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1372, 'avg', 64, '0'); convert_element_type_1372 = None + wait_tensor_342 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_51); reduce_scatter_tensor_51 = None + view_1223 = torch.ops.aten.view.default(view_1222, [2, 8192, 32, 128]); view_1222 = None + permute_533 = torch.ops.aten.permute.default(view_1223, [0, 2, 1, 3]); view_1223 = None + convert_element_type_859 = torch.ops.prims.convert_element_type.default(primals_238, torch.bfloat16); primals_238 = None + all_gather_into_tensor_235 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_859, 64, '0'); convert_element_type_859 = None + wait_tensor_235 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_235); all_gather_into_tensor_235 = None + convert_element_type_860 = 
torch.ops.prims.convert_element_type.default(add_103, torch.float32); add_103 = None + pow_53 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_860, 2) + mean_52 = torch.ops.aten.mean.dim(pow_53, [2], True); pow_53 = None + add_104 = torch.ops.aten.add.Scalar(mean_52, 1e-05); mean_52 = None + rsqrt_52 = torch.ops.aten.rsqrt.default(add_104); add_104 = None + mul_208 = torch.ops.aten.mul.Tensor(convert_element_type_860, rsqrt_52); convert_element_type_860 = None + mul_209 = torch.ops.aten.mul.Tensor(mul_208, wait_tensor_235) + convert_element_type_861 = torch.ops.prims.convert_element_type.default(mul_209, torch.bfloat16); mul_209 = None + view_887 = torch.ops.aten.view.default(convert_element_type_861, [16384, 4096]); convert_element_type_861 = None + view_888 = torch.ops.aten.view.default(mm_182, [2, 8192, 4096]); mm_182 = None + convert_element_type_865 = torch.ops.prims.convert_element_type.default(primals_240, torch.bfloat16); primals_240 = None + all_gather_into_tensor_237 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_865, 64, '0'); convert_element_type_865 = None + wait_tensor_237 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_237); all_gather_into_tensor_237 = None + permute_287 = torch.ops.aten.permute.default(wait_tensor_237, [1, 0]); wait_tensor_237 = None + mm_183 = torch.ops.aten.mm.default(view_887, permute_287) + view_891 = torch.ops.aten.view.default(mm_183, [2, 8192, 1024]); mm_183 = None + view_894 = torch.ops.aten.view.default(mm_184, [2, 8192, 1024]); mm_184 = None + view_895 = torch.ops.aten.view.default(view_888, [2, 8192, -1, 128]); view_888 = None + view_896 = torch.ops.aten.view.default(view_891, [2, 8192, -1, 128]); view_891 = None + view_897 = torch.ops.aten.view.default(view_894, [2, 8192, -1, 128]); view_894 = None + convert_element_type_871 = torch.ops.prims.convert_element_type.default(view_895, torch.float32); view_895 = None + view_898 = 
torch.ops.aten.view.default(convert_element_type_871, [2, 8192, 32, -1, 2]); convert_element_type_871 = None + view_as_complex_52 = torch.ops.aten.view_as_complex.default(view_898); view_898 = None + convert_element_type_872 = torch.ops.prims.convert_element_type.default(view_896, torch.float32); view_896 = None + view_899 = torch.ops.aten.view.default(convert_element_type_872, [2, 8192, 8, -1, 2]); convert_element_type_872 = None + view_as_complex_53 = torch.ops.aten.view_as_complex.default(view_899); view_899 = None + mul_210 = torch.ops.aten.mul.Tensor(view_as_complex_52, view_16); view_as_complex_52 = None + view_as_real_52 = torch.ops.aten.view_as_real.default(mul_210); mul_210 = None + view_901 = torch.ops.aten.view.default(view_as_real_52, [2, 8192, 32, 128]); view_as_real_52 = None + mul_211 = torch.ops.aten.mul.Tensor(view_as_complex_53, view_16); view_as_complex_53 = None + view_as_real_53 = torch.ops.aten.view_as_real.default(mul_211); mul_211 = None + view_902 = torch.ops.aten.view.default(view_as_real_53, [2, 8192, 8, 128]); view_as_real_53 = None + convert_element_type_873 = torch.ops.prims.convert_element_type.default(view_901, torch.bfloat16); view_901 = None + convert_element_type_874 = torch.ops.prims.convert_element_type.default(view_902, torch.bfloat16); view_902 = None + unsqueeze_52 = torch.ops.aten.unsqueeze.default(convert_element_type_874, 3); convert_element_type_874 = None + expand_52 = torch.ops.aten.expand.default(unsqueeze_52, [2, 8192, 8, 4, 128]); unsqueeze_52 = None + clone_52 = torch.ops.aten.clone.default(expand_52, memory_format = torch.contiguous_format); expand_52 = None + view_903 = torch.ops.aten.view.default(clone_52, [2, 8192, 32, 128]); clone_52 = None + unsqueeze_53 = torch.ops.aten.unsqueeze.default(view_897, 3); view_897 = None + expand_53 = torch.ops.aten.expand.default(unsqueeze_53, [2, 8192, 8, 4, 128]); unsqueeze_53 = None + clone_53 = torch.ops.aten.clone.default(expand_53, memory_format = torch.contiguous_format); 
expand_53 = None + view_904 = torch.ops.aten.view.default(clone_53, [2, 8192, 32, 128]); clone_53 = None + permute_289 = torch.ops.aten.permute.default(convert_element_type_873, [0, 2, 1, 3]); convert_element_type_873 = None + permute_290 = torch.ops.aten.permute.default(view_903, [0, 2, 1, 3]); view_903 = None + permute_291 = torch.ops.aten.permute.default(view_904, [0, 2, 1, 3]); view_904 = None + _scaled_dot_product_cudnn_attention_backward_5 = torch.ops.aten._scaled_dot_product_cudnn_attention_backward.default(permute_533, permute_289, permute_290, permute_291, getitem_234, getitem_235, getitem_240, getitem_241, None, None, None, 8192, 8192, 0.0, True); permute_533 = permute_289 = permute_290 = permute_291 = getitem_234 = getitem_235 = getitem_240 = getitem_241 = None + getitem_303 = _scaled_dot_product_cudnn_attention_backward_5[0] + getitem_304 = _scaled_dot_product_cudnn_attention_backward_5[1] + getitem_305 = _scaled_dot_product_cudnn_attention_backward_5[2]; _scaled_dot_product_cudnn_attention_backward_5 = None + permute_534 = torch.ops.aten.permute.default(getitem_305, [0, 2, 1, 3]); getitem_305 = None + permute_535 = torch.ops.aten.permute.default(getitem_304, [0, 2, 1, 3]); getitem_304 = None + permute_536 = torch.ops.aten.permute.default(getitem_303, [0, 2, 1, 3]); getitem_303 = None + view_1224 = torch.ops.aten.view.default(permute_534, [2, 8192, 8, 4, 128]); permute_534 = None + sum_35 = torch.ops.aten.sum.dim_IntList(view_1224, [3], True); view_1224 = None + squeeze_10 = torch.ops.aten.squeeze.dim(sum_35, 3); sum_35 = None + view_1225 = torch.ops.aten.view.default(permute_535, [2, 8192, 8, 4, 128]); permute_535 = None + sum_36 = torch.ops.aten.sum.dim_IntList(view_1225, [3], True); view_1225 = None + squeeze_11 = torch.ops.aten.squeeze.dim(sum_36, 3); sum_36 = None + convert_element_type_1373 = torch.ops.prims.convert_element_type.default(squeeze_11, torch.float32); squeeze_11 = None + convert_element_type_1374 = 
torch.ops.prims.convert_element_type.default(permute_536, torch.float32); permute_536 = None + view_1226 = torch.ops.aten.view.default(convert_element_type_1373, [2, 8192, 8, 64, 2]); convert_element_type_1373 = None + view_as_complex_74 = torch.ops.aten.view_as_complex.default(view_1226); view_1226 = None + mul_376 = torch.ops.aten.mul.Tensor(view_as_complex_74, _conj); view_as_complex_74 = None + view_1227 = torch.ops.aten.view.default(convert_element_type_1374, [2, 8192, 32, 64, 2]); convert_element_type_1374 = None + view_as_complex_75 = torch.ops.aten.view_as_complex.default(view_1227); view_1227 = None + mul_377 = torch.ops.aten.mul.Tensor(view_as_complex_75, _conj); view_as_complex_75 = None + view_as_real_74 = torch.ops.aten.view_as_real.default(mul_376); mul_376 = None + view_1228 = torch.ops.aten.view.default(view_as_real_74, [2, 8192, 8, 128]); view_as_real_74 = None + convert_element_type_1375 = torch.ops.prims.convert_element_type.default(view_1228, torch.bfloat16); view_1228 = None + view_as_real_75 = torch.ops.aten.view_as_real.default(mul_377); mul_377 = None + view_1229 = torch.ops.aten.view.default(view_as_real_75, [2, 8192, 32, 128]); view_as_real_75 = None + convert_element_type_1376 = torch.ops.prims.convert_element_type.default(view_1229, torch.bfloat16); view_1229 = None + view_1230 = torch.ops.aten.view.default(squeeze_10, [2, 8192, 1024]); squeeze_10 = None + view_1231 = torch.ops.aten.view.default(convert_element_type_1375, [2, 8192, 1024]); convert_element_type_1375 = None + view_1232 = torch.ops.aten.view.default(convert_element_type_1376, [2, 8192, 4096]); convert_element_type_1376 = None + view_1233 = torch.ops.aten.view.default(view_1230, [16384, 1024]); view_1230 = None + permute_537 = torch.ops.aten.permute.default(view_1233, [1, 0]) + mm_305 = torch.ops.aten.mm.default(permute_537, view_887); permute_537 = None + convert_element_type_868 = torch.ops.prims.convert_element_type.default(primals_241, torch.bfloat16); primals_241 = None 
+ all_gather_into_tensor_238 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_868, 64, '0'); convert_element_type_868 = None + wait_tensor_238 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_238); all_gather_into_tensor_238 = None + permute_288 = torch.ops.aten.permute.default(wait_tensor_238, [1, 0]); wait_tensor_238 = None + permute_539 = torch.ops.aten.permute.default(permute_288, [1, 0]); permute_288 = None + mm_306 = torch.ops.aten.mm.default(view_1233, permute_539); view_1233 = permute_539 = None + view_1234 = torch.ops.aten.view.default(mm_306, [2, 8192, 4096]); mm_306 = None + convert_element_type_1381 = torch.ops.prims.convert_element_type.default(mm_305, torch.float32); mm_305 = None + reduce_scatter_tensor_52 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1381, 'avg', 64, '0'); convert_element_type_1381 = None + wait_tensor_343 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_52); reduce_scatter_tensor_52 = None + view_1235 = torch.ops.aten.view.default(view_1231, [16384, 1024]); view_1231 = None + permute_541 = torch.ops.aten.permute.default(view_1235, [1, 0]) + mm_307 = torch.ops.aten.mm.default(permute_541, view_887); permute_541 = None + permute_543 = torch.ops.aten.permute.default(permute_287, [1, 0]); permute_287 = None + mm_308 = torch.ops.aten.mm.default(view_1235, permute_543); view_1235 = permute_543 = None + view_1236 = torch.ops.aten.view.default(mm_308, [2, 8192, 4096]); mm_308 = None + add_168 = torch.ops.aten.add.Tensor(view_1234, view_1236); view_1234 = view_1236 = None + convert_element_type_1386 = torch.ops.prims.convert_element_type.default(mm_307, torch.float32); mm_307 = None + reduce_scatter_tensor_53 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1386, 'avg', 64, '0'); convert_element_type_1386 = None + wait_tensor_344 = 
torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_53); reduce_scatter_tensor_53 = None + view_1237 = torch.ops.aten.view.default(view_1232, [16384, 4096]); view_1232 = None + permute_545 = torch.ops.aten.permute.default(view_1237, [1, 0]) + mm_309 = torch.ops.aten.mm.default(permute_545, view_887); permute_545 = view_887 = None + convert_element_type_862 = torch.ops.prims.convert_element_type.default(primals_239, torch.bfloat16); primals_239 = None + all_gather_into_tensor_236 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_862, 64, '0'); convert_element_type_862 = None + wait_tensor_236 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_236); all_gather_into_tensor_236 = None + permute_286 = torch.ops.aten.permute.default(wait_tensor_236, [1, 0]); wait_tensor_236 = None + permute_547 = torch.ops.aten.permute.default(permute_286, [1, 0]); permute_286 = None + mm_310 = torch.ops.aten.mm.default(view_1237, permute_547); view_1237 = permute_547 = None + view_1238 = torch.ops.aten.view.default(mm_310, [2, 8192, 4096]); mm_310 = None + add_169 = torch.ops.aten.add.Tensor(add_168, view_1238); add_168 = view_1238 = None + convert_element_type_1391 = torch.ops.prims.convert_element_type.default(mm_309, torch.float32); mm_309 = None + reduce_scatter_tensor_54 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1391, 'avg', 64, '0'); convert_element_type_1391 = None + wait_tensor_345 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_54); reduce_scatter_tensor_54 = None + convert_element_type_1392 = torch.ops.prims.convert_element_type.default(add_169, torch.float32); add_169 = None + convert_element_type_1394 = torch.ops.prims.convert_element_type.default(wait_tensor_235, torch.float32); wait_tensor_235 = None + mul_378 = torch.ops.aten.mul.Tensor(convert_element_type_1392, convert_element_type_1394); convert_element_type_1394 = None + mul_380 = 
torch.ops.aten.mul.Tensor(mul_208, mul_378) + sum_37 = torch.ops.aten.sum.dim_IntList(mul_380, [2], True); mul_380 = None + div_12 = torch.ops.aten.div.Tensor(mul_208, 4096) + mul_381 = torch.ops.aten.mul.Tensor(div_12, sum_37); div_12 = sum_37 = None + sub_18 = torch.ops.aten.sub.Tensor(mul_378, mul_381); mul_378 = mul_381 = None + mul_382 = torch.ops.aten.mul.Tensor(sub_18, rsqrt_52); sub_18 = rsqrt_52 = None + mul_383 = torch.ops.aten.mul.Tensor(convert_element_type_1392, mul_208); convert_element_type_1392 = mul_208 = None + sum_38 = torch.ops.aten.sum.dim_IntList(mul_383, [0, 1]); mul_383 = None + convert_element_type_1395 = torch.ops.prims.convert_element_type.default(mul_382, torch.bfloat16); mul_382 = None + add_170 = torch.ops.aten.add.Tensor(add_167, convert_element_type_1395); add_167 = convert_element_type_1395 = None + convert_element_type_default_53 = torch.ops.prims.convert_element_type.default(sum_38, torch.float32); sum_38 = None + reduce_scatter_tensor_55 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_53, 'avg', 64, '0'); convert_element_type_default_53 = None + wait_tensor_346 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_55); reduce_scatter_tensor_55 = None + view_1239 = torch.ops.aten.view.default(add_170, [16384, 4096]) + permute_549 = torch.ops.aten.permute.default(view_1239, [1, 0]) + permute_281 = torch.ops.aten.permute.default(getitem_225, [0, 2, 1, 3]) + view_871 = torch.ops.aten.view.default(permute_281, [2, 8192, -1]); permute_281 = None + convert_element_type_842 = torch.ops.prims.convert_element_type.default(primals_233, torch.bfloat16); primals_233 = None + all_gather_into_tensor_230 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_842, 64, '0'); convert_element_type_842 = None + wait_tensor_230 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_230); all_gather_into_tensor_230 = None + permute_282 = 
torch.ops.aten.permute.default(wait_tensor_230, [1, 0]); wait_tensor_230 = None + view_873 = torch.ops.aten.view.default(view_871, [16384, 4096]); view_871 = None + mm_178 = torch.ops.aten.mm.default(view_873, permute_282) + view_874 = torch.ops.aten.view.default(mm_178, [2, 8192, 4096]); mm_178 = None + add_101 = torch.ops.aten.add.Tensor(add_99, view_874); view_874 = None + convert_element_type_845 = torch.ops.prims.convert_element_type.default(primals_234, torch.bfloat16); primals_234 = None + all_gather_into_tensor_231 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_845, 64, '0'); convert_element_type_845 = None + wait_tensor_231 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_231); all_gather_into_tensor_231 = None + convert_element_type_846 = torch.ops.prims.convert_element_type.default(add_101, torch.float32); add_101 = None + pow_52 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_846, 2) + mean_51 = torch.ops.aten.mean.dim(pow_52, [2], True); pow_52 = None + add_102 = torch.ops.aten.add.Scalar(mean_51, 1e-05); mean_51 = None + rsqrt_51 = torch.ops.aten.rsqrt.default(add_102); add_102 = None + mul_204 = torch.ops.aten.mul.Tensor(convert_element_type_846, rsqrt_51); convert_element_type_846 = None + mul_205 = torch.ops.aten.mul.Tensor(mul_204, wait_tensor_231) + convert_element_type_847 = torch.ops.prims.convert_element_type.default(mul_205, torch.bfloat16); mul_205 = None + view_877 = torch.ops.aten.view.default(convert_element_type_847, [16384, 4096]); convert_element_type_847 = None + view_878 = torch.ops.aten.view.default(mm_179, [2, 8192, 14336]); mm_179 = None + convert_element_type_851 = torch.ops.prims.convert_element_type.default(view_878, torch.float32); view_878 = None + sigmoid_25 = torch.ops.aten.sigmoid.default(convert_element_type_851) + mul_206 = torch.ops.aten.mul.Tensor(convert_element_type_851, sigmoid_25); sigmoid_25 = None + convert_element_type_852 = 
torch.ops.prims.convert_element_type.default(mul_206, torch.bfloat16); mul_206 = None + convert_element_type_853 = torch.ops.prims.convert_element_type.default(primals_236, torch.bfloat16); primals_236 = None + all_gather_into_tensor_233 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_853, 64, '0'); convert_element_type_853 = None + wait_tensor_233 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_233); all_gather_into_tensor_233 = None + permute_284 = torch.ops.aten.permute.default(wait_tensor_233, [1, 0]); wait_tensor_233 = None + mm_180 = torch.ops.aten.mm.default(view_877, permute_284) + view_881 = torch.ops.aten.view.default(mm_180, [2, 8192, 14336]); mm_180 = None + mul_207 = torch.ops.aten.mul.Tensor(convert_element_type_852, view_881) + view_883 = torch.ops.aten.view.default(mul_207, [16384, 14336]); mul_207 = None + mm_311 = torch.ops.aten.mm.default(permute_549, view_883); permute_549 = view_883 = None + convert_element_type_856 = torch.ops.prims.convert_element_type.default(primals_237, torch.bfloat16); primals_237 = None + all_gather_into_tensor_234 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_856, 64, '0'); convert_element_type_856 = None + wait_tensor_234 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_234); all_gather_into_tensor_234 = None + permute_285 = torch.ops.aten.permute.default(wait_tensor_234, [1, 0]); wait_tensor_234 = None + permute_551 = torch.ops.aten.permute.default(permute_285, [1, 0]); permute_285 = None + mm_312 = torch.ops.aten.mm.default(view_1239, permute_551); view_1239 = permute_551 = None + view_1240 = torch.ops.aten.view.default(mm_312, [2, 8192, 14336]); mm_312 = None + convert_element_type_1402 = torch.ops.prims.convert_element_type.default(mm_311, torch.float32); mm_311 = None + reduce_scatter_tensor_56 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1402, 'avg', 64, '0'); 
convert_element_type_1402 = None + wait_tensor_347 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_56); reduce_scatter_tensor_56 = None + mul_384 = torch.ops.aten.mul.Tensor(view_1240, convert_element_type_852); convert_element_type_852 = None + mul_385 = torch.ops.aten.mul.Tensor(view_1240, view_881); view_1240 = view_881 = None + view_1241 = torch.ops.aten.view.default(mul_384, [16384, 14336]); mul_384 = None + permute_553 = torch.ops.aten.permute.default(view_1241, [1, 0]) + mm_313 = torch.ops.aten.mm.default(permute_553, view_877); permute_553 = None + permute_555 = torch.ops.aten.permute.default(permute_284, [1, 0]); permute_284 = None + mm_314 = torch.ops.aten.mm.default(view_1241, permute_555); view_1241 = permute_555 = None + view_1242 = torch.ops.aten.view.default(mm_314, [2, 8192, 4096]); mm_314 = None + convert_element_type_1407 = torch.ops.prims.convert_element_type.default(mm_313, torch.float32); mm_313 = None + reduce_scatter_tensor_57 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1407, 'avg', 64, '0'); convert_element_type_1407 = None + wait_tensor_348 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_57); reduce_scatter_tensor_57 = None + convert_element_type_1408 = torch.ops.prims.convert_element_type.default(mul_385, torch.float32); mul_385 = None + neg_6 = torch.ops.aten.neg.default(convert_element_type_851) + exp_6 = torch.ops.aten.exp.default(neg_6); neg_6 = None + add_171 = torch.ops.aten.add.Tensor(exp_6, 1); exp_6 = None + reciprocal_6 = torch.ops.aten.reciprocal.default(add_171); add_171 = None + mul_386 = torch.ops.aten.mul.Tensor(reciprocal_6, 1); reciprocal_6 = None + mul_387 = torch.ops.aten.mul.Tensor(convert_element_type_1408, mul_386); convert_element_type_1408 = None + sub_19 = torch.ops.aten.sub.Tensor(1, mul_386); mul_386 = None + mul_388 = torch.ops.aten.mul.Tensor(convert_element_type_851, sub_19); convert_element_type_851 = sub_19 = None + add_172 = 
torch.ops.aten.add.Tensor(mul_388, 1); mul_388 = None + mul_389 = torch.ops.aten.mul.Tensor(mul_387, add_172); mul_387 = add_172 = None + convert_element_type_1410 = torch.ops.prims.convert_element_type.default(mul_389, torch.bfloat16); mul_389 = None + view_1243 = torch.ops.aten.view.default(convert_element_type_1410, [16384, 14336]); convert_element_type_1410 = None + permute_557 = torch.ops.aten.permute.default(view_1243, [1, 0]) + mm_315 = torch.ops.aten.mm.default(permute_557, view_877); permute_557 = view_877 = None + convert_element_type_848 = torch.ops.prims.convert_element_type.default(primals_235, torch.bfloat16); primals_235 = None + all_gather_into_tensor_232 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_848, 64, '0'); convert_element_type_848 = None + wait_tensor_232 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_232); all_gather_into_tensor_232 = None + permute_283 = torch.ops.aten.permute.default(wait_tensor_232, [1, 0]); wait_tensor_232 = None + permute_559 = torch.ops.aten.permute.default(permute_283, [1, 0]); permute_283 = None + mm_316 = torch.ops.aten.mm.default(view_1243, permute_559); view_1243 = permute_559 = None + view_1244 = torch.ops.aten.view.default(mm_316, [2, 8192, 4096]); mm_316 = None + add_173 = torch.ops.aten.add.Tensor(view_1242, view_1244); view_1242 = view_1244 = None + convert_element_type_1415 = torch.ops.prims.convert_element_type.default(mm_315, torch.float32); mm_315 = None + reduce_scatter_tensor_58 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1415, 'avg', 64, '0'); convert_element_type_1415 = None + wait_tensor_349 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_58); reduce_scatter_tensor_58 = None + convert_element_type_1416 = torch.ops.prims.convert_element_type.default(add_173, torch.float32); add_173 = None + convert_element_type_1418 = torch.ops.prims.convert_element_type.default(wait_tensor_231, 
torch.float32); wait_tensor_231 = None + mul_390 = torch.ops.aten.mul.Tensor(convert_element_type_1416, convert_element_type_1418); convert_element_type_1418 = None + mul_392 = torch.ops.aten.mul.Tensor(mul_204, mul_390) + sum_39 = torch.ops.aten.sum.dim_IntList(mul_392, [2], True); mul_392 = None + div_13 = torch.ops.aten.div.Tensor(mul_204, 4096) + mul_393 = torch.ops.aten.mul.Tensor(div_13, sum_39); div_13 = sum_39 = None + sub_20 = torch.ops.aten.sub.Tensor(mul_390, mul_393); mul_390 = mul_393 = None + mul_394 = torch.ops.aten.mul.Tensor(sub_20, rsqrt_51); sub_20 = rsqrt_51 = None + mul_395 = torch.ops.aten.mul.Tensor(convert_element_type_1416, mul_204); convert_element_type_1416 = mul_204 = None + sum_40 = torch.ops.aten.sum.dim_IntList(mul_395, [0, 1]); mul_395 = None + convert_element_type_1419 = torch.ops.prims.convert_element_type.default(mul_394, torch.bfloat16); mul_394 = None + add_174 = torch.ops.aten.add.Tensor(add_170, convert_element_type_1419); add_170 = convert_element_type_1419 = None + convert_element_type_default_52 = torch.ops.prims.convert_element_type.default(sum_40, torch.float32); sum_40 = None + reduce_scatter_tensor_59 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_52, 'avg', 64, '0'); convert_element_type_default_52 = None + wait_tensor_350 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_59); reduce_scatter_tensor_59 = None + view_1245 = torch.ops.aten.view.default(add_174, [16384, 4096]) + permute_561 = torch.ops.aten.permute.default(view_1245, [1, 0]) + mm_317 = torch.ops.aten.mm.default(permute_561, view_873); permute_561 = view_873 = None + permute_563 = torch.ops.aten.permute.default(permute_282, [1, 0]); permute_282 = None + mm_318 = torch.ops.aten.mm.default(view_1245, permute_563); view_1245 = permute_563 = None + view_1246 = torch.ops.aten.view.default(mm_318, [2, 8192, 4096]); mm_318 = None + convert_element_type_1426 = 
torch.ops.prims.convert_element_type.default(mm_317, torch.float32); mm_317 = None + reduce_scatter_tensor_60 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1426, 'avg', 64, '0'); convert_element_type_1426 = None + wait_tensor_351 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_60); reduce_scatter_tensor_60 = None + view_1247 = torch.ops.aten.view.default(view_1246, [2, 8192, 32, 128]); view_1246 = None + permute_565 = torch.ops.aten.permute.default(view_1247, [0, 2, 1, 3]); view_1247 = None + convert_element_type_826 = torch.ops.prims.convert_element_type.default(primals_229, torch.bfloat16); primals_229 = None + all_gather_into_tensor_226 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_826, 64, '0'); convert_element_type_826 = None + wait_tensor_226 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_226); all_gather_into_tensor_226 = None + convert_element_type_827 = torch.ops.prims.convert_element_type.default(add_99, torch.float32); add_99 = None + pow_51 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_827, 2) + mean_50 = torch.ops.aten.mean.dim(pow_51, [2], True); pow_51 = None + add_100 = torch.ops.aten.add.Scalar(mean_50, 1e-05); mean_50 = None + rsqrt_50 = torch.ops.aten.rsqrt.default(add_100); add_100 = None + mul_200 = torch.ops.aten.mul.Tensor(convert_element_type_827, rsqrt_50); convert_element_type_827 = None + mul_201 = torch.ops.aten.mul.Tensor(mul_200, wait_tensor_226) + convert_element_type_828 = torch.ops.prims.convert_element_type.default(mul_201, torch.bfloat16); mul_201 = None + view_853 = torch.ops.aten.view.default(convert_element_type_828, [16384, 4096]); convert_element_type_828 = None + view_854 = torch.ops.aten.view.default(mm_175, [2, 8192, 4096]); mm_175 = None + convert_element_type_832 = torch.ops.prims.convert_element_type.default(primals_231, torch.bfloat16); primals_231 = None + all_gather_into_tensor_228 = 
torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_832, 64, '0'); convert_element_type_832 = None + wait_tensor_228 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_228); all_gather_into_tensor_228 = None + permute_276 = torch.ops.aten.permute.default(wait_tensor_228, [1, 0]); wait_tensor_228 = None + mm_176 = torch.ops.aten.mm.default(view_853, permute_276) + view_857 = torch.ops.aten.view.default(mm_176, [2, 8192, 1024]); mm_176 = None + view_860 = torch.ops.aten.view.default(mm_177, [2, 8192, 1024]); mm_177 = None + view_861 = torch.ops.aten.view.default(view_854, [2, 8192, -1, 128]); view_854 = None + view_862 = torch.ops.aten.view.default(view_857, [2, 8192, -1, 128]); view_857 = None + view_863 = torch.ops.aten.view.default(view_860, [2, 8192, -1, 128]); view_860 = None + convert_element_type_838 = torch.ops.prims.convert_element_type.default(view_861, torch.float32); view_861 = None + view_864 = torch.ops.aten.view.default(convert_element_type_838, [2, 8192, 32, -1, 2]); convert_element_type_838 = None + view_as_complex_50 = torch.ops.aten.view_as_complex.default(view_864); view_864 = None + convert_element_type_839 = torch.ops.prims.convert_element_type.default(view_862, torch.float32); view_862 = None + view_865 = torch.ops.aten.view.default(convert_element_type_839, [2, 8192, 8, -1, 2]); convert_element_type_839 = None + view_as_complex_51 = torch.ops.aten.view_as_complex.default(view_865); view_865 = None + mul_202 = torch.ops.aten.mul.Tensor(view_as_complex_50, view_16); view_as_complex_50 = None + view_as_real_50 = torch.ops.aten.view_as_real.default(mul_202); mul_202 = None + view_867 = torch.ops.aten.view.default(view_as_real_50, [2, 8192, 32, 128]); view_as_real_50 = None + mul_203 = torch.ops.aten.mul.Tensor(view_as_complex_51, view_16); view_as_complex_51 = None + view_as_real_51 = torch.ops.aten.view_as_real.default(mul_203); mul_203 = None + view_868 = 
torch.ops.aten.view.default(view_as_real_51, [2, 8192, 8, 128]); view_as_real_51 = None + convert_element_type_840 = torch.ops.prims.convert_element_type.default(view_867, torch.bfloat16); view_867 = None + convert_element_type_841 = torch.ops.prims.convert_element_type.default(view_868, torch.bfloat16); view_868 = None + unsqueeze_50 = torch.ops.aten.unsqueeze.default(convert_element_type_841, 3); convert_element_type_841 = None + expand_50 = torch.ops.aten.expand.default(unsqueeze_50, [2, 8192, 8, 4, 128]); unsqueeze_50 = None + clone_50 = torch.ops.aten.clone.default(expand_50, memory_format = torch.contiguous_format); expand_50 = None + view_869 = torch.ops.aten.view.default(clone_50, [2, 8192, 32, 128]); clone_50 = None + unsqueeze_51 = torch.ops.aten.unsqueeze.default(view_863, 3); view_863 = None + expand_51 = torch.ops.aten.expand.default(unsqueeze_51, [2, 8192, 8, 4, 128]); unsqueeze_51 = None + clone_51 = torch.ops.aten.clone.default(expand_51, memory_format = torch.contiguous_format); expand_51 = None + view_870 = torch.ops.aten.view.default(clone_51, [2, 8192, 32, 128]); clone_51 = None + permute_278 = torch.ops.aten.permute.default(convert_element_type_840, [0, 2, 1, 3]); convert_element_type_840 = None + permute_279 = torch.ops.aten.permute.default(view_869, [0, 2, 1, 3]); view_869 = None + permute_280 = torch.ops.aten.permute.default(view_870, [0, 2, 1, 3]); view_870 = None + _scaled_dot_product_cudnn_attention_backward_6 = torch.ops.aten._scaled_dot_product_cudnn_attention_backward.default(permute_565, permute_278, permute_279, permute_280, getitem_225, getitem_226, getitem_231, getitem_232, None, None, None, 8192, 8192, 0.0, True); permute_565 = permute_278 = permute_279 = permute_280 = getitem_225 = getitem_226 = getitem_231 = getitem_232 = None + getitem_306 = _scaled_dot_product_cudnn_attention_backward_6[0] + getitem_307 = _scaled_dot_product_cudnn_attention_backward_6[1] + getitem_308 = _scaled_dot_product_cudnn_attention_backward_6[2]; 
_scaled_dot_product_cudnn_attention_backward_6 = None + permute_566 = torch.ops.aten.permute.default(getitem_308, [0, 2, 1, 3]); getitem_308 = None + permute_567 = torch.ops.aten.permute.default(getitem_307, [0, 2, 1, 3]); getitem_307 = None + permute_568 = torch.ops.aten.permute.default(getitem_306, [0, 2, 1, 3]); getitem_306 = None + view_1248 = torch.ops.aten.view.default(permute_566, [2, 8192, 8, 4, 128]); permute_566 = None + sum_41 = torch.ops.aten.sum.dim_IntList(view_1248, [3], True); view_1248 = None + squeeze_12 = torch.ops.aten.squeeze.dim(sum_41, 3); sum_41 = None + view_1249 = torch.ops.aten.view.default(permute_567, [2, 8192, 8, 4, 128]); permute_567 = None + sum_42 = torch.ops.aten.sum.dim_IntList(view_1249, [3], True); view_1249 = None + squeeze_13 = torch.ops.aten.squeeze.dim(sum_42, 3); sum_42 = None + convert_element_type_1427 = torch.ops.prims.convert_element_type.default(squeeze_13, torch.float32); squeeze_13 = None + convert_element_type_1428 = torch.ops.prims.convert_element_type.default(permute_568, torch.float32); permute_568 = None + view_1250 = torch.ops.aten.view.default(convert_element_type_1427, [2, 8192, 8, 64, 2]); convert_element_type_1427 = None + view_as_complex_76 = torch.ops.aten.view_as_complex.default(view_1250); view_1250 = None + mul_396 = torch.ops.aten.mul.Tensor(view_as_complex_76, _conj); view_as_complex_76 = None + view_1251 = torch.ops.aten.view.default(convert_element_type_1428, [2, 8192, 32, 64, 2]); convert_element_type_1428 = None + view_as_complex_77 = torch.ops.aten.view_as_complex.default(view_1251); view_1251 = None + mul_397 = torch.ops.aten.mul.Tensor(view_as_complex_77, _conj); view_as_complex_77 = None + view_as_real_76 = torch.ops.aten.view_as_real.default(mul_396); mul_396 = None + view_1252 = torch.ops.aten.view.default(view_as_real_76, [2, 8192, 8, 128]); view_as_real_76 = None + convert_element_type_1429 = torch.ops.prims.convert_element_type.default(view_1252, torch.bfloat16); view_1252 = None + 
view_as_real_77 = torch.ops.aten.view_as_real.default(mul_397); mul_397 = None + view_1253 = torch.ops.aten.view.default(view_as_real_77, [2, 8192, 32, 128]); view_as_real_77 = None + convert_element_type_1430 = torch.ops.prims.convert_element_type.default(view_1253, torch.bfloat16); view_1253 = None + view_1254 = torch.ops.aten.view.default(squeeze_12, [2, 8192, 1024]); squeeze_12 = None + view_1255 = torch.ops.aten.view.default(convert_element_type_1429, [2, 8192, 1024]); convert_element_type_1429 = None + view_1256 = torch.ops.aten.view.default(convert_element_type_1430, [2, 8192, 4096]); convert_element_type_1430 = None + view_1257 = torch.ops.aten.view.default(view_1254, [16384, 1024]); view_1254 = None + permute_569 = torch.ops.aten.permute.default(view_1257, [1, 0]) + mm_319 = torch.ops.aten.mm.default(permute_569, view_853); permute_569 = None + convert_element_type_835 = torch.ops.prims.convert_element_type.default(primals_232, torch.bfloat16); primals_232 = None + all_gather_into_tensor_229 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_835, 64, '0'); convert_element_type_835 = None + wait_tensor_229 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_229); all_gather_into_tensor_229 = None + permute_277 = torch.ops.aten.permute.default(wait_tensor_229, [1, 0]); wait_tensor_229 = None + permute_571 = torch.ops.aten.permute.default(permute_277, [1, 0]); permute_277 = None + mm_320 = torch.ops.aten.mm.default(view_1257, permute_571); view_1257 = permute_571 = None + view_1258 = torch.ops.aten.view.default(mm_320, [2, 8192, 4096]); mm_320 = None + convert_element_type_1435 = torch.ops.prims.convert_element_type.default(mm_319, torch.float32); mm_319 = None + reduce_scatter_tensor_61 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1435, 'avg', 64, '0'); convert_element_type_1435 = None + wait_tensor_352 = 
torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_61); reduce_scatter_tensor_61 = None + view_1259 = torch.ops.aten.view.default(view_1255, [16384, 1024]); view_1255 = None + permute_573 = torch.ops.aten.permute.default(view_1259, [1, 0]) + mm_321 = torch.ops.aten.mm.default(permute_573, view_853); permute_573 = None + permute_575 = torch.ops.aten.permute.default(permute_276, [1, 0]); permute_276 = None + mm_322 = torch.ops.aten.mm.default(view_1259, permute_575); view_1259 = permute_575 = None + view_1260 = torch.ops.aten.view.default(mm_322, [2, 8192, 4096]); mm_322 = None + add_175 = torch.ops.aten.add.Tensor(view_1258, view_1260); view_1258 = view_1260 = None + convert_element_type_1440 = torch.ops.prims.convert_element_type.default(mm_321, torch.float32); mm_321 = None + reduce_scatter_tensor_62 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1440, 'avg', 64, '0'); convert_element_type_1440 = None + wait_tensor_353 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_62); reduce_scatter_tensor_62 = None + view_1261 = torch.ops.aten.view.default(view_1256, [16384, 4096]); view_1256 = None + permute_577 = torch.ops.aten.permute.default(view_1261, [1, 0]) + mm_323 = torch.ops.aten.mm.default(permute_577, view_853); permute_577 = view_853 = None + convert_element_type_829 = torch.ops.prims.convert_element_type.default(primals_230, torch.bfloat16); primals_230 = None + all_gather_into_tensor_227 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_829, 64, '0'); convert_element_type_829 = None + wait_tensor_227 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_227); all_gather_into_tensor_227 = None + permute_275 = torch.ops.aten.permute.default(wait_tensor_227, [1, 0]); wait_tensor_227 = None + permute_579 = torch.ops.aten.permute.default(permute_275, [1, 0]); permute_275 = None + mm_324 = torch.ops.aten.mm.default(view_1261, permute_579); 
view_1261 = permute_579 = None + view_1262 = torch.ops.aten.view.default(mm_324, [2, 8192, 4096]); mm_324 = None + add_176 = torch.ops.aten.add.Tensor(add_175, view_1262); add_175 = view_1262 = None + convert_element_type_1445 = torch.ops.prims.convert_element_type.default(mm_323, torch.float32); mm_323 = None + reduce_scatter_tensor_63 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1445, 'avg', 64, '0'); convert_element_type_1445 = None + wait_tensor_354 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_63); reduce_scatter_tensor_63 = None + convert_element_type_1446 = torch.ops.prims.convert_element_type.default(add_176, torch.float32); add_176 = None + convert_element_type_1448 = torch.ops.prims.convert_element_type.default(wait_tensor_226, torch.float32); wait_tensor_226 = None + mul_398 = torch.ops.aten.mul.Tensor(convert_element_type_1446, convert_element_type_1448); convert_element_type_1448 = None + mul_400 = torch.ops.aten.mul.Tensor(mul_200, mul_398) + sum_43 = torch.ops.aten.sum.dim_IntList(mul_400, [2], True); mul_400 = None + div_14 = torch.ops.aten.div.Tensor(mul_200, 4096) + mul_401 = torch.ops.aten.mul.Tensor(div_14, sum_43); div_14 = sum_43 = None + sub_21 = torch.ops.aten.sub.Tensor(mul_398, mul_401); mul_398 = mul_401 = None + mul_402 = torch.ops.aten.mul.Tensor(sub_21, rsqrt_50); sub_21 = rsqrt_50 = None + mul_403 = torch.ops.aten.mul.Tensor(convert_element_type_1446, mul_200); convert_element_type_1446 = mul_200 = None + sum_44 = torch.ops.aten.sum.dim_IntList(mul_403, [0, 1]); mul_403 = None + convert_element_type_1449 = torch.ops.prims.convert_element_type.default(mul_402, torch.bfloat16); mul_402 = None + add_177 = torch.ops.aten.add.Tensor(add_174, convert_element_type_1449); add_174 = convert_element_type_1449 = None + convert_element_type_default_51 = torch.ops.prims.convert_element_type.default(sum_44, torch.float32); sum_44 = None + reduce_scatter_tensor_64 = 
torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_51, 'avg', 64, '0'); convert_element_type_default_51 = None + wait_tensor_355 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_64); reduce_scatter_tensor_64 = None + view_1263 = torch.ops.aten.view.default(add_177, [16384, 4096]) + permute_581 = torch.ops.aten.permute.default(view_1263, [1, 0]) + permute_270 = torch.ops.aten.permute.default(getitem_216, [0, 2, 1, 3]) + view_837 = torch.ops.aten.view.default(permute_270, [2, 8192, -1]); permute_270 = None + convert_element_type_809 = torch.ops.prims.convert_element_type.default(primals_224, torch.bfloat16); primals_224 = None + all_gather_into_tensor_221 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_809, 64, '0'); convert_element_type_809 = None + wait_tensor_221 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_221); all_gather_into_tensor_221 = None + permute_271 = torch.ops.aten.permute.default(wait_tensor_221, [1, 0]); wait_tensor_221 = None + view_839 = torch.ops.aten.view.default(view_837, [16384, 4096]); view_837 = None + mm_171 = torch.ops.aten.mm.default(view_839, permute_271) + view_840 = torch.ops.aten.view.default(mm_171, [2, 8192, 4096]); mm_171 = None + add_97 = torch.ops.aten.add.Tensor(add_95, view_840); view_840 = None + convert_element_type_812 = torch.ops.prims.convert_element_type.default(primals_225, torch.bfloat16); primals_225 = None + all_gather_into_tensor_222 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_812, 64, '0'); convert_element_type_812 = None + wait_tensor_222 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_222); all_gather_into_tensor_222 = None + convert_element_type_813 = torch.ops.prims.convert_element_type.default(add_97, torch.float32); add_97 = None + pow_50 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_813, 2) + mean_49 = 
torch.ops.aten.mean.dim(pow_50, [2], True); pow_50 = None + add_98 = torch.ops.aten.add.Scalar(mean_49, 1e-05); mean_49 = None + rsqrt_49 = torch.ops.aten.rsqrt.default(add_98); add_98 = None + mul_196 = torch.ops.aten.mul.Tensor(convert_element_type_813, rsqrt_49); convert_element_type_813 = None + mul_197 = torch.ops.aten.mul.Tensor(mul_196, wait_tensor_222) + convert_element_type_814 = torch.ops.prims.convert_element_type.default(mul_197, torch.bfloat16); mul_197 = None + view_843 = torch.ops.aten.view.default(convert_element_type_814, [16384, 4096]); convert_element_type_814 = None + view_844 = torch.ops.aten.view.default(mm_172, [2, 8192, 14336]); mm_172 = None + convert_element_type_818 = torch.ops.prims.convert_element_type.default(view_844, torch.float32); view_844 = None + sigmoid_24 = torch.ops.aten.sigmoid.default(convert_element_type_818) + mul_198 = torch.ops.aten.mul.Tensor(convert_element_type_818, sigmoid_24); sigmoid_24 = None + convert_element_type_819 = torch.ops.prims.convert_element_type.default(mul_198, torch.bfloat16); mul_198 = None + convert_element_type_820 = torch.ops.prims.convert_element_type.default(primals_227, torch.bfloat16); primals_227 = None + all_gather_into_tensor_224 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_820, 64, '0'); convert_element_type_820 = None + wait_tensor_224 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_224); all_gather_into_tensor_224 = None + permute_273 = torch.ops.aten.permute.default(wait_tensor_224, [1, 0]); wait_tensor_224 = None + mm_173 = torch.ops.aten.mm.default(view_843, permute_273) + view_847 = torch.ops.aten.view.default(mm_173, [2, 8192, 14336]); mm_173 = None + mul_199 = torch.ops.aten.mul.Tensor(convert_element_type_819, view_847) + view_849 = torch.ops.aten.view.default(mul_199, [16384, 14336]); mul_199 = None + mm_325 = torch.ops.aten.mm.default(permute_581, view_849); permute_581 = view_849 = None + convert_element_type_823 = 
torch.ops.prims.convert_element_type.default(primals_228, torch.bfloat16); primals_228 = None + all_gather_into_tensor_225 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_823, 64, '0'); convert_element_type_823 = None + wait_tensor_225 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_225); all_gather_into_tensor_225 = None + permute_274 = torch.ops.aten.permute.default(wait_tensor_225, [1, 0]); wait_tensor_225 = None + permute_583 = torch.ops.aten.permute.default(permute_274, [1, 0]); permute_274 = None + mm_326 = torch.ops.aten.mm.default(view_1263, permute_583); view_1263 = permute_583 = None + view_1264 = torch.ops.aten.view.default(mm_326, [2, 8192, 14336]); mm_326 = None + convert_element_type_1456 = torch.ops.prims.convert_element_type.default(mm_325, torch.float32); mm_325 = None + reduce_scatter_tensor_65 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1456, 'avg', 64, '0'); convert_element_type_1456 = None + wait_tensor_356 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_65); reduce_scatter_tensor_65 = None + mul_404 = torch.ops.aten.mul.Tensor(view_1264, convert_element_type_819); convert_element_type_819 = None + mul_405 = torch.ops.aten.mul.Tensor(view_1264, view_847); view_1264 = view_847 = None + view_1265 = torch.ops.aten.view.default(mul_404, [16384, 14336]); mul_404 = None + permute_585 = torch.ops.aten.permute.default(view_1265, [1, 0]) + mm_327 = torch.ops.aten.mm.default(permute_585, view_843); permute_585 = None + permute_587 = torch.ops.aten.permute.default(permute_273, [1, 0]); permute_273 = None + mm_328 = torch.ops.aten.mm.default(view_1265, permute_587); view_1265 = permute_587 = None + view_1266 = torch.ops.aten.view.default(mm_328, [2, 8192, 4096]); mm_328 = None + convert_element_type_1461 = torch.ops.prims.convert_element_type.default(mm_327, torch.float32); mm_327 = None + reduce_scatter_tensor_66 = 
torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1461, 'avg', 64, '0'); convert_element_type_1461 = None + wait_tensor_357 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_66); reduce_scatter_tensor_66 = None + convert_element_type_1462 = torch.ops.prims.convert_element_type.default(mul_405, torch.float32); mul_405 = None + neg_7 = torch.ops.aten.neg.default(convert_element_type_818) + exp_7 = torch.ops.aten.exp.default(neg_7); neg_7 = None + add_178 = torch.ops.aten.add.Tensor(exp_7, 1); exp_7 = None + reciprocal_7 = torch.ops.aten.reciprocal.default(add_178); add_178 = None + mul_406 = torch.ops.aten.mul.Tensor(reciprocal_7, 1); reciprocal_7 = None + mul_407 = torch.ops.aten.mul.Tensor(convert_element_type_1462, mul_406); convert_element_type_1462 = None + sub_22 = torch.ops.aten.sub.Tensor(1, mul_406); mul_406 = None + mul_408 = torch.ops.aten.mul.Tensor(convert_element_type_818, sub_22); convert_element_type_818 = sub_22 = None + add_179 = torch.ops.aten.add.Tensor(mul_408, 1); mul_408 = None + mul_409 = torch.ops.aten.mul.Tensor(mul_407, add_179); mul_407 = add_179 = None + convert_element_type_1464 = torch.ops.prims.convert_element_type.default(mul_409, torch.bfloat16); mul_409 = None + view_1267 = torch.ops.aten.view.default(convert_element_type_1464, [16384, 14336]); convert_element_type_1464 = None + permute_589 = torch.ops.aten.permute.default(view_1267, [1, 0]) + mm_329 = torch.ops.aten.mm.default(permute_589, view_843); permute_589 = view_843 = None + convert_element_type_815 = torch.ops.prims.convert_element_type.default(primals_226, torch.bfloat16); primals_226 = None + all_gather_into_tensor_223 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_815, 64, '0'); convert_element_type_815 = None + wait_tensor_223 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_223); all_gather_into_tensor_223 = None + permute_272 = 
torch.ops.aten.permute.default(wait_tensor_223, [1, 0]); wait_tensor_223 = None + permute_591 = torch.ops.aten.permute.default(permute_272, [1, 0]); permute_272 = None + mm_330 = torch.ops.aten.mm.default(view_1267, permute_591); view_1267 = permute_591 = None + view_1268 = torch.ops.aten.view.default(mm_330, [2, 8192, 4096]); mm_330 = None + add_180 = torch.ops.aten.add.Tensor(view_1266, view_1268); view_1266 = view_1268 = None + convert_element_type_1469 = torch.ops.prims.convert_element_type.default(mm_329, torch.float32); mm_329 = None + reduce_scatter_tensor_67 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1469, 'avg', 64, '0'); convert_element_type_1469 = None + wait_tensor_358 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_67); reduce_scatter_tensor_67 = None + convert_element_type_1470 = torch.ops.prims.convert_element_type.default(add_180, torch.float32); add_180 = None + convert_element_type_1472 = torch.ops.prims.convert_element_type.default(wait_tensor_222, torch.float32); wait_tensor_222 = None + mul_410 = torch.ops.aten.mul.Tensor(convert_element_type_1470, convert_element_type_1472); convert_element_type_1472 = None + mul_412 = torch.ops.aten.mul.Tensor(mul_196, mul_410) + sum_45 = torch.ops.aten.sum.dim_IntList(mul_412, [2], True); mul_412 = None + div_15 = torch.ops.aten.div.Tensor(mul_196, 4096) + mul_413 = torch.ops.aten.mul.Tensor(div_15, sum_45); div_15 = sum_45 = None + sub_23 = torch.ops.aten.sub.Tensor(mul_410, mul_413); mul_410 = mul_413 = None + mul_414 = torch.ops.aten.mul.Tensor(sub_23, rsqrt_49); sub_23 = rsqrt_49 = None + mul_415 = torch.ops.aten.mul.Tensor(convert_element_type_1470, mul_196); convert_element_type_1470 = mul_196 = None + sum_46 = torch.ops.aten.sum.dim_IntList(mul_415, [0, 1]); mul_415 = None + convert_element_type_1473 = torch.ops.prims.convert_element_type.default(mul_414, torch.bfloat16); mul_414 = None + add_181 = torch.ops.aten.add.Tensor(add_177, 
convert_element_type_1473); add_177 = convert_element_type_1473 = None + convert_element_type_default_50 = torch.ops.prims.convert_element_type.default(sum_46, torch.float32); sum_46 = None + reduce_scatter_tensor_68 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_50, 'avg', 64, '0'); convert_element_type_default_50 = None + wait_tensor_359 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_68); reduce_scatter_tensor_68 = None + view_1269 = torch.ops.aten.view.default(add_181, [16384, 4096]) + permute_593 = torch.ops.aten.permute.default(view_1269, [1, 0]) + mm_331 = torch.ops.aten.mm.default(permute_593, view_839); permute_593 = view_839 = None + permute_595 = torch.ops.aten.permute.default(permute_271, [1, 0]); permute_271 = None + mm_332 = torch.ops.aten.mm.default(view_1269, permute_595); view_1269 = permute_595 = None + view_1270 = torch.ops.aten.view.default(mm_332, [2, 8192, 4096]); mm_332 = None + convert_element_type_1480 = torch.ops.prims.convert_element_type.default(mm_331, torch.float32); mm_331 = None + reduce_scatter_tensor_69 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1480, 'avg', 64, '0'); convert_element_type_1480 = None + wait_tensor_360 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_69); reduce_scatter_tensor_69 = None + view_1271 = torch.ops.aten.view.default(view_1270, [2, 8192, 32, 128]); view_1270 = None + permute_597 = torch.ops.aten.permute.default(view_1271, [0, 2, 1, 3]); view_1271 = None + convert_element_type_793 = torch.ops.prims.convert_element_type.default(primals_220, torch.bfloat16); primals_220 = None + all_gather_into_tensor_217 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_793, 64, '0'); convert_element_type_793 = None + wait_tensor_217 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_217); all_gather_into_tensor_217 = None + convert_element_type_794 = 
torch.ops.prims.convert_element_type.default(add_95, torch.float32); add_95 = None + pow_49 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_794, 2) + mean_48 = torch.ops.aten.mean.dim(pow_49, [2], True); pow_49 = None + add_96 = torch.ops.aten.add.Scalar(mean_48, 1e-05); mean_48 = None + rsqrt_48 = torch.ops.aten.rsqrt.default(add_96); add_96 = None + mul_192 = torch.ops.aten.mul.Tensor(convert_element_type_794, rsqrt_48); convert_element_type_794 = None + mul_193 = torch.ops.aten.mul.Tensor(mul_192, wait_tensor_217) + convert_element_type_795 = torch.ops.prims.convert_element_type.default(mul_193, torch.bfloat16); mul_193 = None + view_819 = torch.ops.aten.view.default(convert_element_type_795, [16384, 4096]); convert_element_type_795 = None + view_820 = torch.ops.aten.view.default(mm_168, [2, 8192, 4096]); mm_168 = None + convert_element_type_799 = torch.ops.prims.convert_element_type.default(primals_222, torch.bfloat16); primals_222 = None + all_gather_into_tensor_219 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_799, 64, '0'); convert_element_type_799 = None + wait_tensor_219 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_219); all_gather_into_tensor_219 = None + permute_265 = torch.ops.aten.permute.default(wait_tensor_219, [1, 0]); wait_tensor_219 = None + mm_169 = torch.ops.aten.mm.default(view_819, permute_265) + view_823 = torch.ops.aten.view.default(mm_169, [2, 8192, 1024]); mm_169 = None + view_826 = torch.ops.aten.view.default(mm_170, [2, 8192, 1024]); mm_170 = None + view_827 = torch.ops.aten.view.default(view_820, [2, 8192, -1, 128]); view_820 = None + view_828 = torch.ops.aten.view.default(view_823, [2, 8192, -1, 128]); view_823 = None + view_829 = torch.ops.aten.view.default(view_826, [2, 8192, -1, 128]); view_826 = None + convert_element_type_805 = torch.ops.prims.convert_element_type.default(view_827, torch.float32); view_827 = None + view_830 = 
torch.ops.aten.view.default(convert_element_type_805, [2, 8192, 32, -1, 2]); convert_element_type_805 = None + view_as_complex_48 = torch.ops.aten.view_as_complex.default(view_830); view_830 = None + convert_element_type_806 = torch.ops.prims.convert_element_type.default(view_828, torch.float32); view_828 = None + view_831 = torch.ops.aten.view.default(convert_element_type_806, [2, 8192, 8, -1, 2]); convert_element_type_806 = None + view_as_complex_49 = torch.ops.aten.view_as_complex.default(view_831); view_831 = None + mul_194 = torch.ops.aten.mul.Tensor(view_as_complex_48, view_16); view_as_complex_48 = None + view_as_real_48 = torch.ops.aten.view_as_real.default(mul_194); mul_194 = None + view_833 = torch.ops.aten.view.default(view_as_real_48, [2, 8192, 32, 128]); view_as_real_48 = None + mul_195 = torch.ops.aten.mul.Tensor(view_as_complex_49, view_16); view_as_complex_49 = None + view_as_real_49 = torch.ops.aten.view_as_real.default(mul_195); mul_195 = None + view_834 = torch.ops.aten.view.default(view_as_real_49, [2, 8192, 8, 128]); view_as_real_49 = None + convert_element_type_807 = torch.ops.prims.convert_element_type.default(view_833, torch.bfloat16); view_833 = None + convert_element_type_808 = torch.ops.prims.convert_element_type.default(view_834, torch.bfloat16); view_834 = None + unsqueeze_48 = torch.ops.aten.unsqueeze.default(convert_element_type_808, 3); convert_element_type_808 = None + expand_48 = torch.ops.aten.expand.default(unsqueeze_48, [2, 8192, 8, 4, 128]); unsqueeze_48 = None + clone_48 = torch.ops.aten.clone.default(expand_48, memory_format = torch.contiguous_format); expand_48 = None + view_835 = torch.ops.aten.view.default(clone_48, [2, 8192, 32, 128]); clone_48 = None + unsqueeze_49 = torch.ops.aten.unsqueeze.default(view_829, 3); view_829 = None + expand_49 = torch.ops.aten.expand.default(unsqueeze_49, [2, 8192, 8, 4, 128]); unsqueeze_49 = None + clone_49 = torch.ops.aten.clone.default(expand_49, memory_format = torch.contiguous_format); 
expand_49 = None + view_836 = torch.ops.aten.view.default(clone_49, [2, 8192, 32, 128]); clone_49 = None + permute_267 = torch.ops.aten.permute.default(convert_element_type_807, [0, 2, 1, 3]); convert_element_type_807 = None + permute_268 = torch.ops.aten.permute.default(view_835, [0, 2, 1, 3]); view_835 = None + permute_269 = torch.ops.aten.permute.default(view_836, [0, 2, 1, 3]); view_836 = None + _scaled_dot_product_cudnn_attention_backward_7 = torch.ops.aten._scaled_dot_product_cudnn_attention_backward.default(permute_597, permute_267, permute_268, permute_269, getitem_216, getitem_217, getitem_222, getitem_223, None, None, None, 8192, 8192, 0.0, True); permute_597 = permute_267 = permute_268 = permute_269 = getitem_216 = getitem_217 = getitem_222 = getitem_223 = None + getitem_309 = _scaled_dot_product_cudnn_attention_backward_7[0] + getitem_310 = _scaled_dot_product_cudnn_attention_backward_7[1] + getitem_311 = _scaled_dot_product_cudnn_attention_backward_7[2]; _scaled_dot_product_cudnn_attention_backward_7 = None + permute_598 = torch.ops.aten.permute.default(getitem_311, [0, 2, 1, 3]); getitem_311 = None + permute_599 = torch.ops.aten.permute.default(getitem_310, [0, 2, 1, 3]); getitem_310 = None + permute_600 = torch.ops.aten.permute.default(getitem_309, [0, 2, 1, 3]); getitem_309 = None + view_1272 = torch.ops.aten.view.default(permute_598, [2, 8192, 8, 4, 128]); permute_598 = None + sum_47 = torch.ops.aten.sum.dim_IntList(view_1272, [3], True); view_1272 = None + squeeze_14 = torch.ops.aten.squeeze.dim(sum_47, 3); sum_47 = None + view_1273 = torch.ops.aten.view.default(permute_599, [2, 8192, 8, 4, 128]); permute_599 = None + sum_48 = torch.ops.aten.sum.dim_IntList(view_1273, [3], True); view_1273 = None + squeeze_15 = torch.ops.aten.squeeze.dim(sum_48, 3); sum_48 = None + convert_element_type_1481 = torch.ops.prims.convert_element_type.default(squeeze_15, torch.float32); squeeze_15 = None + convert_element_type_1482 = 
torch.ops.prims.convert_element_type.default(permute_600, torch.float32); permute_600 = None + view_1274 = torch.ops.aten.view.default(convert_element_type_1481, [2, 8192, 8, 64, 2]); convert_element_type_1481 = None + view_as_complex_78 = torch.ops.aten.view_as_complex.default(view_1274); view_1274 = None + mul_416 = torch.ops.aten.mul.Tensor(view_as_complex_78, _conj); view_as_complex_78 = None + view_1275 = torch.ops.aten.view.default(convert_element_type_1482, [2, 8192, 32, 64, 2]); convert_element_type_1482 = None + view_as_complex_79 = torch.ops.aten.view_as_complex.default(view_1275); view_1275 = None + mul_417 = torch.ops.aten.mul.Tensor(view_as_complex_79, _conj); view_as_complex_79 = None + view_as_real_78 = torch.ops.aten.view_as_real.default(mul_416); mul_416 = None + view_1276 = torch.ops.aten.view.default(view_as_real_78, [2, 8192, 8, 128]); view_as_real_78 = None + convert_element_type_1483 = torch.ops.prims.convert_element_type.default(view_1276, torch.bfloat16); view_1276 = None + view_as_real_79 = torch.ops.aten.view_as_real.default(mul_417); mul_417 = None + view_1277 = torch.ops.aten.view.default(view_as_real_79, [2, 8192, 32, 128]); view_as_real_79 = None + convert_element_type_1484 = torch.ops.prims.convert_element_type.default(view_1277, torch.bfloat16); view_1277 = None + view_1278 = torch.ops.aten.view.default(squeeze_14, [2, 8192, 1024]); squeeze_14 = None + view_1279 = torch.ops.aten.view.default(convert_element_type_1483, [2, 8192, 1024]); convert_element_type_1483 = None + view_1280 = torch.ops.aten.view.default(convert_element_type_1484, [2, 8192, 4096]); convert_element_type_1484 = None + view_1281 = torch.ops.aten.view.default(view_1278, [16384, 1024]); view_1278 = None + permute_601 = torch.ops.aten.permute.default(view_1281, [1, 0]) + mm_333 = torch.ops.aten.mm.default(permute_601, view_819); permute_601 = None + convert_element_type_802 = torch.ops.prims.convert_element_type.default(primals_223, torch.bfloat16); primals_223 = None 
+ all_gather_into_tensor_220 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_802, 64, '0'); convert_element_type_802 = None + wait_tensor_220 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_220); all_gather_into_tensor_220 = None + permute_266 = torch.ops.aten.permute.default(wait_tensor_220, [1, 0]); wait_tensor_220 = None + permute_603 = torch.ops.aten.permute.default(permute_266, [1, 0]); permute_266 = None + mm_334 = torch.ops.aten.mm.default(view_1281, permute_603); view_1281 = permute_603 = None + view_1282 = torch.ops.aten.view.default(mm_334, [2, 8192, 4096]); mm_334 = None + convert_element_type_1489 = torch.ops.prims.convert_element_type.default(mm_333, torch.float32); mm_333 = None + reduce_scatter_tensor_70 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1489, 'avg', 64, '0'); convert_element_type_1489 = None + wait_tensor_361 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_70); reduce_scatter_tensor_70 = None + view_1283 = torch.ops.aten.view.default(view_1279, [16384, 1024]); view_1279 = None + permute_605 = torch.ops.aten.permute.default(view_1283, [1, 0]) + mm_335 = torch.ops.aten.mm.default(permute_605, view_819); permute_605 = None + permute_607 = torch.ops.aten.permute.default(permute_265, [1, 0]); permute_265 = None + mm_336 = torch.ops.aten.mm.default(view_1283, permute_607); view_1283 = permute_607 = None + view_1284 = torch.ops.aten.view.default(mm_336, [2, 8192, 4096]); mm_336 = None + add_182 = torch.ops.aten.add.Tensor(view_1282, view_1284); view_1282 = view_1284 = None + convert_element_type_1494 = torch.ops.prims.convert_element_type.default(mm_335, torch.float32); mm_335 = None + reduce_scatter_tensor_71 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1494, 'avg', 64, '0'); convert_element_type_1494 = None + wait_tensor_362 = 
torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_71); reduce_scatter_tensor_71 = None + view_1285 = torch.ops.aten.view.default(view_1280, [16384, 4096]); view_1280 = None + permute_609 = torch.ops.aten.permute.default(view_1285, [1, 0]) + mm_337 = torch.ops.aten.mm.default(permute_609, view_819); permute_609 = view_819 = None + convert_element_type_796 = torch.ops.prims.convert_element_type.default(primals_221, torch.bfloat16); primals_221 = None + all_gather_into_tensor_218 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_796, 64, '0'); convert_element_type_796 = None + wait_tensor_218 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_218); all_gather_into_tensor_218 = None + permute_264 = torch.ops.aten.permute.default(wait_tensor_218, [1, 0]); wait_tensor_218 = None + permute_611 = torch.ops.aten.permute.default(permute_264, [1, 0]); permute_264 = None + mm_338 = torch.ops.aten.mm.default(view_1285, permute_611); view_1285 = permute_611 = None + view_1286 = torch.ops.aten.view.default(mm_338, [2, 8192, 4096]); mm_338 = None + add_183 = torch.ops.aten.add.Tensor(add_182, view_1286); add_182 = view_1286 = None + convert_element_type_1499 = torch.ops.prims.convert_element_type.default(mm_337, torch.float32); mm_337 = None + reduce_scatter_tensor_72 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1499, 'avg', 64, '0'); convert_element_type_1499 = None + wait_tensor_363 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_72); reduce_scatter_tensor_72 = None + convert_element_type_1500 = torch.ops.prims.convert_element_type.default(add_183, torch.float32); add_183 = None + convert_element_type_1502 = torch.ops.prims.convert_element_type.default(wait_tensor_217, torch.float32); wait_tensor_217 = None + mul_418 = torch.ops.aten.mul.Tensor(convert_element_type_1500, convert_element_type_1502); convert_element_type_1502 = None + mul_420 = 
torch.ops.aten.mul.Tensor(mul_192, mul_418) + sum_49 = torch.ops.aten.sum.dim_IntList(mul_420, [2], True); mul_420 = None + div_16 = torch.ops.aten.div.Tensor(mul_192, 4096) + mul_421 = torch.ops.aten.mul.Tensor(div_16, sum_49); div_16 = sum_49 = None + sub_24 = torch.ops.aten.sub.Tensor(mul_418, mul_421); mul_418 = mul_421 = None + mul_422 = torch.ops.aten.mul.Tensor(sub_24, rsqrt_48); sub_24 = rsqrt_48 = None + mul_423 = torch.ops.aten.mul.Tensor(convert_element_type_1500, mul_192); convert_element_type_1500 = mul_192 = None + sum_50 = torch.ops.aten.sum.dim_IntList(mul_423, [0, 1]); mul_423 = None + convert_element_type_1503 = torch.ops.prims.convert_element_type.default(mul_422, torch.bfloat16); mul_422 = None + add_184 = torch.ops.aten.add.Tensor(add_181, convert_element_type_1503); add_181 = convert_element_type_1503 = None + convert_element_type_default_49 = torch.ops.prims.convert_element_type.default(sum_50, torch.float32); sum_50 = None + reduce_scatter_tensor_73 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_49, 'avg', 64, '0'); convert_element_type_default_49 = None + wait_tensor_364 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_73); reduce_scatter_tensor_73 = None + view_1287 = torch.ops.aten.view.default(add_184, [16384, 4096]) + permute_613 = torch.ops.aten.permute.default(view_1287, [1, 0]) + permute_259 = torch.ops.aten.permute.default(getitem_207, [0, 2, 1, 3]) + view_803 = torch.ops.aten.view.default(permute_259, [2, 8192, -1]); permute_259 = None + convert_element_type_776 = torch.ops.prims.convert_element_type.default(primals_215, torch.bfloat16); primals_215 = None + all_gather_into_tensor_212 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_776, 64, '0'); convert_element_type_776 = None + wait_tensor_212 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_212); all_gather_into_tensor_212 = None + permute_260 = 
torch.ops.aten.permute.default(wait_tensor_212, [1, 0]); wait_tensor_212 = None + view_805 = torch.ops.aten.view.default(view_803, [16384, 4096]); view_803 = None + mm_164 = torch.ops.aten.mm.default(view_805, permute_260) + view_806 = torch.ops.aten.view.default(mm_164, [2, 8192, 4096]); mm_164 = None + add_93 = torch.ops.aten.add.Tensor(add_91, view_806); view_806 = None + convert_element_type_779 = torch.ops.prims.convert_element_type.default(primals_216, torch.bfloat16); primals_216 = None + all_gather_into_tensor_213 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_779, 64, '0'); convert_element_type_779 = None + wait_tensor_213 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_213); all_gather_into_tensor_213 = None + convert_element_type_780 = torch.ops.prims.convert_element_type.default(add_93, torch.float32); add_93 = None + pow_48 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_780, 2) + mean_47 = torch.ops.aten.mean.dim(pow_48, [2], True); pow_48 = None + add_94 = torch.ops.aten.add.Scalar(mean_47, 1e-05); mean_47 = None + rsqrt_47 = torch.ops.aten.rsqrt.default(add_94); add_94 = None + mul_188 = torch.ops.aten.mul.Tensor(convert_element_type_780, rsqrt_47); convert_element_type_780 = None + mul_189 = torch.ops.aten.mul.Tensor(mul_188, wait_tensor_213) + convert_element_type_781 = torch.ops.prims.convert_element_type.default(mul_189, torch.bfloat16); mul_189 = None + view_809 = torch.ops.aten.view.default(convert_element_type_781, [16384, 4096]); convert_element_type_781 = None + view_810 = torch.ops.aten.view.default(mm_165, [2, 8192, 14336]); mm_165 = None + convert_element_type_785 = torch.ops.prims.convert_element_type.default(view_810, torch.float32); view_810 = None + sigmoid_23 = torch.ops.aten.sigmoid.default(convert_element_type_785) + mul_190 = torch.ops.aten.mul.Tensor(convert_element_type_785, sigmoid_23); sigmoid_23 = None + convert_element_type_786 = 
torch.ops.prims.convert_element_type.default(mul_190, torch.bfloat16); mul_190 = None + convert_element_type_787 = torch.ops.prims.convert_element_type.default(primals_218, torch.bfloat16); primals_218 = None + all_gather_into_tensor_215 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_787, 64, '0'); convert_element_type_787 = None + wait_tensor_215 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_215); all_gather_into_tensor_215 = None + permute_262 = torch.ops.aten.permute.default(wait_tensor_215, [1, 0]); wait_tensor_215 = None + mm_166 = torch.ops.aten.mm.default(view_809, permute_262) + view_813 = torch.ops.aten.view.default(mm_166, [2, 8192, 14336]); mm_166 = None + mul_191 = torch.ops.aten.mul.Tensor(convert_element_type_786, view_813) + view_815 = torch.ops.aten.view.default(mul_191, [16384, 14336]); mul_191 = None + mm_339 = torch.ops.aten.mm.default(permute_613, view_815); permute_613 = view_815 = None + convert_element_type_790 = torch.ops.prims.convert_element_type.default(primals_219, torch.bfloat16); primals_219 = None + all_gather_into_tensor_216 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_790, 64, '0'); convert_element_type_790 = None + wait_tensor_216 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_216); all_gather_into_tensor_216 = None + permute_263 = torch.ops.aten.permute.default(wait_tensor_216, [1, 0]); wait_tensor_216 = None + permute_615 = torch.ops.aten.permute.default(permute_263, [1, 0]); permute_263 = None + mm_340 = torch.ops.aten.mm.default(view_1287, permute_615); view_1287 = permute_615 = None + view_1288 = torch.ops.aten.view.default(mm_340, [2, 8192, 14336]); mm_340 = None + convert_element_type_1510 = torch.ops.prims.convert_element_type.default(mm_339, torch.float32); mm_339 = None + reduce_scatter_tensor_74 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1510, 'avg', 64, '0'); 
convert_element_type_1510 = None + wait_tensor_365 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_74); reduce_scatter_tensor_74 = None + mul_424 = torch.ops.aten.mul.Tensor(view_1288, convert_element_type_786); convert_element_type_786 = None + mul_425 = torch.ops.aten.mul.Tensor(view_1288, view_813); view_1288 = view_813 = None + view_1289 = torch.ops.aten.view.default(mul_424, [16384, 14336]); mul_424 = None + permute_617 = torch.ops.aten.permute.default(view_1289, [1, 0]) + mm_341 = torch.ops.aten.mm.default(permute_617, view_809); permute_617 = None + permute_619 = torch.ops.aten.permute.default(permute_262, [1, 0]); permute_262 = None + mm_342 = torch.ops.aten.mm.default(view_1289, permute_619); view_1289 = permute_619 = None + view_1290 = torch.ops.aten.view.default(mm_342, [2, 8192, 4096]); mm_342 = None + convert_element_type_1515 = torch.ops.prims.convert_element_type.default(mm_341, torch.float32); mm_341 = None + reduce_scatter_tensor_75 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1515, 'avg', 64, '0'); convert_element_type_1515 = None + wait_tensor_366 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_75); reduce_scatter_tensor_75 = None + convert_element_type_1516 = torch.ops.prims.convert_element_type.default(mul_425, torch.float32); mul_425 = None + neg_8 = torch.ops.aten.neg.default(convert_element_type_785) + exp_8 = torch.ops.aten.exp.default(neg_8); neg_8 = None + add_185 = torch.ops.aten.add.Tensor(exp_8, 1); exp_8 = None + reciprocal_8 = torch.ops.aten.reciprocal.default(add_185); add_185 = None + mul_426 = torch.ops.aten.mul.Tensor(reciprocal_8, 1); reciprocal_8 = None + mul_427 = torch.ops.aten.mul.Tensor(convert_element_type_1516, mul_426); convert_element_type_1516 = None + sub_25 = torch.ops.aten.sub.Tensor(1, mul_426); mul_426 = None + mul_428 = torch.ops.aten.mul.Tensor(convert_element_type_785, sub_25); convert_element_type_785 = sub_25 = None + add_186 = 
torch.ops.aten.add.Tensor(mul_428, 1); mul_428 = None + mul_429 = torch.ops.aten.mul.Tensor(mul_427, add_186); mul_427 = add_186 = None + convert_element_type_1518 = torch.ops.prims.convert_element_type.default(mul_429, torch.bfloat16); mul_429 = None + view_1291 = torch.ops.aten.view.default(convert_element_type_1518, [16384, 14336]); convert_element_type_1518 = None + permute_621 = torch.ops.aten.permute.default(view_1291, [1, 0]) + mm_343 = torch.ops.aten.mm.default(permute_621, view_809); permute_621 = view_809 = None + convert_element_type_782 = torch.ops.prims.convert_element_type.default(primals_217, torch.bfloat16); primals_217 = None + all_gather_into_tensor_214 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_782, 64, '0'); convert_element_type_782 = None + wait_tensor_214 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_214); all_gather_into_tensor_214 = None + permute_261 = torch.ops.aten.permute.default(wait_tensor_214, [1, 0]); wait_tensor_214 = None + permute_623 = torch.ops.aten.permute.default(permute_261, [1, 0]); permute_261 = None + mm_344 = torch.ops.aten.mm.default(view_1291, permute_623); view_1291 = permute_623 = None + view_1292 = torch.ops.aten.view.default(mm_344, [2, 8192, 4096]); mm_344 = None + add_187 = torch.ops.aten.add.Tensor(view_1290, view_1292); view_1290 = view_1292 = None + convert_element_type_1523 = torch.ops.prims.convert_element_type.default(mm_343, torch.float32); mm_343 = None + reduce_scatter_tensor_76 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1523, 'avg', 64, '0'); convert_element_type_1523 = None + wait_tensor_367 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_76); reduce_scatter_tensor_76 = None + convert_element_type_1524 = torch.ops.prims.convert_element_type.default(add_187, torch.float32); add_187 = None + convert_element_type_1526 = torch.ops.prims.convert_element_type.default(wait_tensor_213, 
torch.float32); wait_tensor_213 = None + mul_430 = torch.ops.aten.mul.Tensor(convert_element_type_1524, convert_element_type_1526); convert_element_type_1526 = None + mul_432 = torch.ops.aten.mul.Tensor(mul_188, mul_430) + sum_51 = torch.ops.aten.sum.dim_IntList(mul_432, [2], True); mul_432 = None + div_17 = torch.ops.aten.div.Tensor(mul_188, 4096) + mul_433 = torch.ops.aten.mul.Tensor(div_17, sum_51); div_17 = sum_51 = None + sub_26 = torch.ops.aten.sub.Tensor(mul_430, mul_433); mul_430 = mul_433 = None + mul_434 = torch.ops.aten.mul.Tensor(sub_26, rsqrt_47); sub_26 = rsqrt_47 = None + mul_435 = torch.ops.aten.mul.Tensor(convert_element_type_1524, mul_188); convert_element_type_1524 = mul_188 = None + sum_52 = torch.ops.aten.sum.dim_IntList(mul_435, [0, 1]); mul_435 = None + convert_element_type_1527 = torch.ops.prims.convert_element_type.default(mul_434, torch.bfloat16); mul_434 = None + add_188 = torch.ops.aten.add.Tensor(add_184, convert_element_type_1527); add_184 = convert_element_type_1527 = None + convert_element_type_default_48 = torch.ops.prims.convert_element_type.default(sum_52, torch.float32); sum_52 = None + reduce_scatter_tensor_77 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_48, 'avg', 64, '0'); convert_element_type_default_48 = None + wait_tensor_368 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_77); reduce_scatter_tensor_77 = None + view_1293 = torch.ops.aten.view.default(add_188, [16384, 4096]) + permute_625 = torch.ops.aten.permute.default(view_1293, [1, 0]) + mm_345 = torch.ops.aten.mm.default(permute_625, view_805); permute_625 = view_805 = None + permute_627 = torch.ops.aten.permute.default(permute_260, [1, 0]); permute_260 = None + mm_346 = torch.ops.aten.mm.default(view_1293, permute_627); view_1293 = permute_627 = None + view_1294 = torch.ops.aten.view.default(mm_346, [2, 8192, 4096]); mm_346 = None + convert_element_type_1534 = 
torch.ops.prims.convert_element_type.default(mm_345, torch.float32); mm_345 = None + reduce_scatter_tensor_78 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1534, 'avg', 64, '0'); convert_element_type_1534 = None + wait_tensor_369 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_78); reduce_scatter_tensor_78 = None + view_1295 = torch.ops.aten.view.default(view_1294, [2, 8192, 32, 128]); view_1294 = None + permute_629 = torch.ops.aten.permute.default(view_1295, [0, 2, 1, 3]); view_1295 = None + convert_element_type_760 = torch.ops.prims.convert_element_type.default(primals_211, torch.bfloat16); primals_211 = None + all_gather_into_tensor_208 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_760, 64, '0'); convert_element_type_760 = None + wait_tensor_208 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_208); all_gather_into_tensor_208 = None + convert_element_type_761 = torch.ops.prims.convert_element_type.default(add_91, torch.float32); add_91 = None + pow_47 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_761, 2) + mean_46 = torch.ops.aten.mean.dim(pow_47, [2], True); pow_47 = None + add_92 = torch.ops.aten.add.Scalar(mean_46, 1e-05); mean_46 = None + rsqrt_46 = torch.ops.aten.rsqrt.default(add_92); add_92 = None + mul_184 = torch.ops.aten.mul.Tensor(convert_element_type_761, rsqrt_46); convert_element_type_761 = None + mul_185 = torch.ops.aten.mul.Tensor(mul_184, wait_tensor_208) + convert_element_type_762 = torch.ops.prims.convert_element_type.default(mul_185, torch.bfloat16); mul_185 = None + view_785 = torch.ops.aten.view.default(convert_element_type_762, [16384, 4096]); convert_element_type_762 = None + view_786 = torch.ops.aten.view.default(mm_161, [2, 8192, 4096]); mm_161 = None + convert_element_type_766 = torch.ops.prims.convert_element_type.default(primals_213, torch.bfloat16); primals_213 = None + all_gather_into_tensor_210 = 
torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_766, 64, '0'); convert_element_type_766 = None + wait_tensor_210 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_210); all_gather_into_tensor_210 = None + permute_254 = torch.ops.aten.permute.default(wait_tensor_210, [1, 0]); wait_tensor_210 = None + mm_162 = torch.ops.aten.mm.default(view_785, permute_254) + view_789 = torch.ops.aten.view.default(mm_162, [2, 8192, 1024]); mm_162 = None + view_792 = torch.ops.aten.view.default(mm_163, [2, 8192, 1024]); mm_163 = None + view_793 = torch.ops.aten.view.default(view_786, [2, 8192, -1, 128]); view_786 = None + view_794 = torch.ops.aten.view.default(view_789, [2, 8192, -1, 128]); view_789 = None + view_795 = torch.ops.aten.view.default(view_792, [2, 8192, -1, 128]); view_792 = None + convert_element_type_772 = torch.ops.prims.convert_element_type.default(view_793, torch.float32); view_793 = None + view_796 = torch.ops.aten.view.default(convert_element_type_772, [2, 8192, 32, -1, 2]); convert_element_type_772 = None + view_as_complex_46 = torch.ops.aten.view_as_complex.default(view_796); view_796 = None + convert_element_type_773 = torch.ops.prims.convert_element_type.default(view_794, torch.float32); view_794 = None + view_797 = torch.ops.aten.view.default(convert_element_type_773, [2, 8192, 8, -1, 2]); convert_element_type_773 = None + view_as_complex_47 = torch.ops.aten.view_as_complex.default(view_797); view_797 = None + mul_186 = torch.ops.aten.mul.Tensor(view_as_complex_46, view_16); view_as_complex_46 = None + view_as_real_46 = torch.ops.aten.view_as_real.default(mul_186); mul_186 = None + view_799 = torch.ops.aten.view.default(view_as_real_46, [2, 8192, 32, 128]); view_as_real_46 = None + mul_187 = torch.ops.aten.mul.Tensor(view_as_complex_47, view_16); view_as_complex_47 = None + view_as_real_47 = torch.ops.aten.view_as_real.default(mul_187); mul_187 = None + view_800 = 
torch.ops.aten.view.default(view_as_real_47, [2, 8192, 8, 128]); view_as_real_47 = None + convert_element_type_774 = torch.ops.prims.convert_element_type.default(view_799, torch.bfloat16); view_799 = None + convert_element_type_775 = torch.ops.prims.convert_element_type.default(view_800, torch.bfloat16); view_800 = None + unsqueeze_46 = torch.ops.aten.unsqueeze.default(convert_element_type_775, 3); convert_element_type_775 = None + expand_46 = torch.ops.aten.expand.default(unsqueeze_46, [2, 8192, 8, 4, 128]); unsqueeze_46 = None + clone_46 = torch.ops.aten.clone.default(expand_46, memory_format = torch.contiguous_format); expand_46 = None + view_801 = torch.ops.aten.view.default(clone_46, [2, 8192, 32, 128]); clone_46 = None + unsqueeze_47 = torch.ops.aten.unsqueeze.default(view_795, 3); view_795 = None + expand_47 = torch.ops.aten.expand.default(unsqueeze_47, [2, 8192, 8, 4, 128]); unsqueeze_47 = None + clone_47 = torch.ops.aten.clone.default(expand_47, memory_format = torch.contiguous_format); expand_47 = None + view_802 = torch.ops.aten.view.default(clone_47, [2, 8192, 32, 128]); clone_47 = None + permute_256 = torch.ops.aten.permute.default(convert_element_type_774, [0, 2, 1, 3]); convert_element_type_774 = None + permute_257 = torch.ops.aten.permute.default(view_801, [0, 2, 1, 3]); view_801 = None + permute_258 = torch.ops.aten.permute.default(view_802, [0, 2, 1, 3]); view_802 = None + _scaled_dot_product_cudnn_attention_backward_8 = torch.ops.aten._scaled_dot_product_cudnn_attention_backward.default(permute_629, permute_256, permute_257, permute_258, getitem_207, getitem_208, getitem_213, getitem_214, None, None, None, 8192, 8192, 0.0, True); permute_629 = permute_256 = permute_257 = permute_258 = getitem_207 = getitem_208 = getitem_213 = getitem_214 = None + getitem_312 = _scaled_dot_product_cudnn_attention_backward_8[0] + getitem_313 = _scaled_dot_product_cudnn_attention_backward_8[1] + getitem_314 = _scaled_dot_product_cudnn_attention_backward_8[2]; 
_scaled_dot_product_cudnn_attention_backward_8 = None + permute_630 = torch.ops.aten.permute.default(getitem_314, [0, 2, 1, 3]); getitem_314 = None + permute_631 = torch.ops.aten.permute.default(getitem_313, [0, 2, 1, 3]); getitem_313 = None + permute_632 = torch.ops.aten.permute.default(getitem_312, [0, 2, 1, 3]); getitem_312 = None + view_1296 = torch.ops.aten.view.default(permute_630, [2, 8192, 8, 4, 128]); permute_630 = None + sum_53 = torch.ops.aten.sum.dim_IntList(view_1296, [3], True); view_1296 = None + squeeze_16 = torch.ops.aten.squeeze.dim(sum_53, 3); sum_53 = None + view_1297 = torch.ops.aten.view.default(permute_631, [2, 8192, 8, 4, 128]); permute_631 = None + sum_54 = torch.ops.aten.sum.dim_IntList(view_1297, [3], True); view_1297 = None + squeeze_17 = torch.ops.aten.squeeze.dim(sum_54, 3); sum_54 = None + convert_element_type_1535 = torch.ops.prims.convert_element_type.default(squeeze_17, torch.float32); squeeze_17 = None + convert_element_type_1536 = torch.ops.prims.convert_element_type.default(permute_632, torch.float32); permute_632 = None + view_1298 = torch.ops.aten.view.default(convert_element_type_1535, [2, 8192, 8, 64, 2]); convert_element_type_1535 = None + view_as_complex_80 = torch.ops.aten.view_as_complex.default(view_1298); view_1298 = None + mul_436 = torch.ops.aten.mul.Tensor(view_as_complex_80, _conj); view_as_complex_80 = None + view_1299 = torch.ops.aten.view.default(convert_element_type_1536, [2, 8192, 32, 64, 2]); convert_element_type_1536 = None + view_as_complex_81 = torch.ops.aten.view_as_complex.default(view_1299); view_1299 = None + mul_437 = torch.ops.aten.mul.Tensor(view_as_complex_81, _conj); view_as_complex_81 = None + view_as_real_80 = torch.ops.aten.view_as_real.default(mul_436); mul_436 = None + view_1300 = torch.ops.aten.view.default(view_as_real_80, [2, 8192, 8, 128]); view_as_real_80 = None + convert_element_type_1537 = torch.ops.prims.convert_element_type.default(view_1300, torch.bfloat16); view_1300 = None + 
view_as_real_81 = torch.ops.aten.view_as_real.default(mul_437); mul_437 = None + view_1301 = torch.ops.aten.view.default(view_as_real_81, [2, 8192, 32, 128]); view_as_real_81 = None + convert_element_type_1538 = torch.ops.prims.convert_element_type.default(view_1301, torch.bfloat16); view_1301 = None + view_1302 = torch.ops.aten.view.default(squeeze_16, [2, 8192, 1024]); squeeze_16 = None + view_1303 = torch.ops.aten.view.default(convert_element_type_1537, [2, 8192, 1024]); convert_element_type_1537 = None + view_1304 = torch.ops.aten.view.default(convert_element_type_1538, [2, 8192, 4096]); convert_element_type_1538 = None + view_1305 = torch.ops.aten.view.default(view_1302, [16384, 1024]); view_1302 = None + permute_633 = torch.ops.aten.permute.default(view_1305, [1, 0]) + mm_347 = torch.ops.aten.mm.default(permute_633, view_785); permute_633 = None + convert_element_type_769 = torch.ops.prims.convert_element_type.default(primals_214, torch.bfloat16); primals_214 = None + all_gather_into_tensor_211 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_769, 64, '0'); convert_element_type_769 = None + wait_tensor_211 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_211); all_gather_into_tensor_211 = None + permute_255 = torch.ops.aten.permute.default(wait_tensor_211, [1, 0]); wait_tensor_211 = None + permute_635 = torch.ops.aten.permute.default(permute_255, [1, 0]); permute_255 = None + mm_348 = torch.ops.aten.mm.default(view_1305, permute_635); view_1305 = permute_635 = None + view_1306 = torch.ops.aten.view.default(mm_348, [2, 8192, 4096]); mm_348 = None + convert_element_type_1543 = torch.ops.prims.convert_element_type.default(mm_347, torch.float32); mm_347 = None + reduce_scatter_tensor_79 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1543, 'avg', 64, '0'); convert_element_type_1543 = None + wait_tensor_370 = 
torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_79); reduce_scatter_tensor_79 = None + view_1307 = torch.ops.aten.view.default(view_1303, [16384, 1024]); view_1303 = None + permute_637 = torch.ops.aten.permute.default(view_1307, [1, 0]) + mm_349 = torch.ops.aten.mm.default(permute_637, view_785); permute_637 = None + permute_639 = torch.ops.aten.permute.default(permute_254, [1, 0]); permute_254 = None + mm_350 = torch.ops.aten.mm.default(view_1307, permute_639); view_1307 = permute_639 = None + view_1308 = torch.ops.aten.view.default(mm_350, [2, 8192, 4096]); mm_350 = None + add_189 = torch.ops.aten.add.Tensor(view_1306, view_1308); view_1306 = view_1308 = None + convert_element_type_1548 = torch.ops.prims.convert_element_type.default(mm_349, torch.float32); mm_349 = None + reduce_scatter_tensor_80 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1548, 'avg', 64, '0'); convert_element_type_1548 = None + wait_tensor_371 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_80); reduce_scatter_tensor_80 = None + view_1309 = torch.ops.aten.view.default(view_1304, [16384, 4096]); view_1304 = None + permute_641 = torch.ops.aten.permute.default(view_1309, [1, 0]) + mm_351 = torch.ops.aten.mm.default(permute_641, view_785); permute_641 = view_785 = None + convert_element_type_763 = torch.ops.prims.convert_element_type.default(primals_212, torch.bfloat16); primals_212 = None + all_gather_into_tensor_209 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_763, 64, '0'); convert_element_type_763 = None + wait_tensor_209 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_209); all_gather_into_tensor_209 = None + permute_253 = torch.ops.aten.permute.default(wait_tensor_209, [1, 0]); wait_tensor_209 = None + permute_643 = torch.ops.aten.permute.default(permute_253, [1, 0]); permute_253 = None + mm_352 = torch.ops.aten.mm.default(view_1309, permute_643); 
view_1309 = permute_643 = None + view_1310 = torch.ops.aten.view.default(mm_352, [2, 8192, 4096]); mm_352 = None + add_190 = torch.ops.aten.add.Tensor(add_189, view_1310); add_189 = view_1310 = None + convert_element_type_1553 = torch.ops.prims.convert_element_type.default(mm_351, torch.float32); mm_351 = None + reduce_scatter_tensor_81 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1553, 'avg', 64, '0'); convert_element_type_1553 = None + wait_tensor_372 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_81); reduce_scatter_tensor_81 = None + convert_element_type_1554 = torch.ops.prims.convert_element_type.default(add_190, torch.float32); add_190 = None + convert_element_type_1556 = torch.ops.prims.convert_element_type.default(wait_tensor_208, torch.float32); wait_tensor_208 = None + mul_438 = torch.ops.aten.mul.Tensor(convert_element_type_1554, convert_element_type_1556); convert_element_type_1556 = None + mul_440 = torch.ops.aten.mul.Tensor(mul_184, mul_438) + sum_55 = torch.ops.aten.sum.dim_IntList(mul_440, [2], True); mul_440 = None + div_18 = torch.ops.aten.div.Tensor(mul_184, 4096) + mul_441 = torch.ops.aten.mul.Tensor(div_18, sum_55); div_18 = sum_55 = None + sub_27 = torch.ops.aten.sub.Tensor(mul_438, mul_441); mul_438 = mul_441 = None + mul_442 = torch.ops.aten.mul.Tensor(sub_27, rsqrt_46); sub_27 = rsqrt_46 = None + mul_443 = torch.ops.aten.mul.Tensor(convert_element_type_1554, mul_184); convert_element_type_1554 = mul_184 = None + sum_56 = torch.ops.aten.sum.dim_IntList(mul_443, [0, 1]); mul_443 = None + convert_element_type_1557 = torch.ops.prims.convert_element_type.default(mul_442, torch.bfloat16); mul_442 = None + add_191 = torch.ops.aten.add.Tensor(add_188, convert_element_type_1557); add_188 = convert_element_type_1557 = None + convert_element_type_default_47 = torch.ops.prims.convert_element_type.default(sum_56, torch.float32); sum_56 = None + reduce_scatter_tensor_82 = 
torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_47, 'avg', 64, '0'); convert_element_type_default_47 = None + wait_tensor_373 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_82); reduce_scatter_tensor_82 = None + view_1311 = torch.ops.aten.view.default(add_191, [16384, 4096]) + permute_645 = torch.ops.aten.permute.default(view_1311, [1, 0]) + permute_248 = torch.ops.aten.permute.default(getitem_198, [0, 2, 1, 3]) + view_769 = torch.ops.aten.view.default(permute_248, [2, 8192, -1]); permute_248 = None + convert_element_type_743 = torch.ops.prims.convert_element_type.default(primals_206, torch.bfloat16); primals_206 = None + all_gather_into_tensor_203 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_743, 64, '0'); convert_element_type_743 = None + wait_tensor_203 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_203); all_gather_into_tensor_203 = None + permute_249 = torch.ops.aten.permute.default(wait_tensor_203, [1, 0]); wait_tensor_203 = None + view_771 = torch.ops.aten.view.default(view_769, [16384, 4096]); view_769 = None + mm_157 = torch.ops.aten.mm.default(view_771, permute_249) + view_772 = torch.ops.aten.view.default(mm_157, [2, 8192, 4096]); mm_157 = None + add_89 = torch.ops.aten.add.Tensor(add_87, view_772); view_772 = None + convert_element_type_746 = torch.ops.prims.convert_element_type.default(primals_207, torch.bfloat16); primals_207 = None + all_gather_into_tensor_204 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_746, 64, '0'); convert_element_type_746 = None + wait_tensor_204 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_204); all_gather_into_tensor_204 = None + convert_element_type_747 = torch.ops.prims.convert_element_type.default(add_89, torch.float32); add_89 = None + pow_46 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_747, 2) + mean_45 = 
torch.ops.aten.mean.dim(pow_46, [2], True); pow_46 = None + add_90 = torch.ops.aten.add.Scalar(mean_45, 1e-05); mean_45 = None + rsqrt_45 = torch.ops.aten.rsqrt.default(add_90); add_90 = None + mul_180 = torch.ops.aten.mul.Tensor(convert_element_type_747, rsqrt_45); convert_element_type_747 = None + mul_181 = torch.ops.aten.mul.Tensor(mul_180, wait_tensor_204) + convert_element_type_748 = torch.ops.prims.convert_element_type.default(mul_181, torch.bfloat16); mul_181 = None + view_775 = torch.ops.aten.view.default(convert_element_type_748, [16384, 4096]); convert_element_type_748 = None + view_776 = torch.ops.aten.view.default(mm_158, [2, 8192, 14336]); mm_158 = None + convert_element_type_752 = torch.ops.prims.convert_element_type.default(view_776, torch.float32); view_776 = None + sigmoid_22 = torch.ops.aten.sigmoid.default(convert_element_type_752) + mul_182 = torch.ops.aten.mul.Tensor(convert_element_type_752, sigmoid_22); sigmoid_22 = None + convert_element_type_753 = torch.ops.prims.convert_element_type.default(mul_182, torch.bfloat16); mul_182 = None + convert_element_type_754 = torch.ops.prims.convert_element_type.default(primals_209, torch.bfloat16); primals_209 = None + all_gather_into_tensor_206 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_754, 64, '0'); convert_element_type_754 = None + wait_tensor_206 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_206); all_gather_into_tensor_206 = None + permute_251 = torch.ops.aten.permute.default(wait_tensor_206, [1, 0]); wait_tensor_206 = None + mm_159 = torch.ops.aten.mm.default(view_775, permute_251) + view_779 = torch.ops.aten.view.default(mm_159, [2, 8192, 14336]); mm_159 = None + mul_183 = torch.ops.aten.mul.Tensor(convert_element_type_753, view_779) + view_781 = torch.ops.aten.view.default(mul_183, [16384, 14336]); mul_183 = None + mm_353 = torch.ops.aten.mm.default(permute_645, view_781); permute_645 = view_781 = None + convert_element_type_757 = 
torch.ops.prims.convert_element_type.default(primals_210, torch.bfloat16); primals_210 = None + all_gather_into_tensor_207 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_757, 64, '0'); convert_element_type_757 = None + wait_tensor_207 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_207); all_gather_into_tensor_207 = None + permute_252 = torch.ops.aten.permute.default(wait_tensor_207, [1, 0]); wait_tensor_207 = None + permute_647 = torch.ops.aten.permute.default(permute_252, [1, 0]); permute_252 = None + mm_354 = torch.ops.aten.mm.default(view_1311, permute_647); view_1311 = permute_647 = None + view_1312 = torch.ops.aten.view.default(mm_354, [2, 8192, 14336]); mm_354 = None + convert_element_type_1564 = torch.ops.prims.convert_element_type.default(mm_353, torch.float32); mm_353 = None + reduce_scatter_tensor_83 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1564, 'avg', 64, '0'); convert_element_type_1564 = None + wait_tensor_374 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_83); reduce_scatter_tensor_83 = None + mul_444 = torch.ops.aten.mul.Tensor(view_1312, convert_element_type_753); convert_element_type_753 = None + mul_445 = torch.ops.aten.mul.Tensor(view_1312, view_779); view_1312 = view_779 = None + view_1313 = torch.ops.aten.view.default(mul_444, [16384, 14336]); mul_444 = None + permute_649 = torch.ops.aten.permute.default(view_1313, [1, 0]) + mm_355 = torch.ops.aten.mm.default(permute_649, view_775); permute_649 = None + permute_651 = torch.ops.aten.permute.default(permute_251, [1, 0]); permute_251 = None + mm_356 = torch.ops.aten.mm.default(view_1313, permute_651); view_1313 = permute_651 = None + view_1314 = torch.ops.aten.view.default(mm_356, [2, 8192, 4096]); mm_356 = None + convert_element_type_1569 = torch.ops.prims.convert_element_type.default(mm_355, torch.float32); mm_355 = None + reduce_scatter_tensor_84 = 
torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1569, 'avg', 64, '0'); convert_element_type_1569 = None + wait_tensor_375 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_84); reduce_scatter_tensor_84 = None + convert_element_type_1570 = torch.ops.prims.convert_element_type.default(mul_445, torch.float32); mul_445 = None + neg_9 = torch.ops.aten.neg.default(convert_element_type_752) + exp_9 = torch.ops.aten.exp.default(neg_9); neg_9 = None + add_192 = torch.ops.aten.add.Tensor(exp_9, 1); exp_9 = None + reciprocal_9 = torch.ops.aten.reciprocal.default(add_192); add_192 = None + mul_446 = torch.ops.aten.mul.Tensor(reciprocal_9, 1); reciprocal_9 = None + mul_447 = torch.ops.aten.mul.Tensor(convert_element_type_1570, mul_446); convert_element_type_1570 = None + sub_28 = torch.ops.aten.sub.Tensor(1, mul_446); mul_446 = None + mul_448 = torch.ops.aten.mul.Tensor(convert_element_type_752, sub_28); convert_element_type_752 = sub_28 = None + add_193 = torch.ops.aten.add.Tensor(mul_448, 1); mul_448 = None + mul_449 = torch.ops.aten.mul.Tensor(mul_447, add_193); mul_447 = add_193 = None + convert_element_type_1572 = torch.ops.prims.convert_element_type.default(mul_449, torch.bfloat16); mul_449 = None + view_1315 = torch.ops.aten.view.default(convert_element_type_1572, [16384, 14336]); convert_element_type_1572 = None + permute_653 = torch.ops.aten.permute.default(view_1315, [1, 0]) + mm_357 = torch.ops.aten.mm.default(permute_653, view_775); permute_653 = view_775 = None + convert_element_type_749 = torch.ops.prims.convert_element_type.default(primals_208, torch.bfloat16); primals_208 = None + all_gather_into_tensor_205 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_749, 64, '0'); convert_element_type_749 = None + wait_tensor_205 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_205); all_gather_into_tensor_205 = None + permute_250 = 
torch.ops.aten.permute.default(wait_tensor_205, [1, 0]); wait_tensor_205 = None + permute_655 = torch.ops.aten.permute.default(permute_250, [1, 0]); permute_250 = None + mm_358 = torch.ops.aten.mm.default(view_1315, permute_655); view_1315 = permute_655 = None + view_1316 = torch.ops.aten.view.default(mm_358, [2, 8192, 4096]); mm_358 = None + add_194 = torch.ops.aten.add.Tensor(view_1314, view_1316); view_1314 = view_1316 = None + convert_element_type_1577 = torch.ops.prims.convert_element_type.default(mm_357, torch.float32); mm_357 = None + reduce_scatter_tensor_85 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1577, 'avg', 64, '0'); convert_element_type_1577 = None + wait_tensor_376 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_85); reduce_scatter_tensor_85 = None + convert_element_type_1578 = torch.ops.prims.convert_element_type.default(add_194, torch.float32); add_194 = None + convert_element_type_1580 = torch.ops.prims.convert_element_type.default(wait_tensor_204, torch.float32); wait_tensor_204 = None + mul_450 = torch.ops.aten.mul.Tensor(convert_element_type_1578, convert_element_type_1580); convert_element_type_1580 = None + mul_452 = torch.ops.aten.mul.Tensor(mul_180, mul_450) + sum_57 = torch.ops.aten.sum.dim_IntList(mul_452, [2], True); mul_452 = None + div_19 = torch.ops.aten.div.Tensor(mul_180, 4096) + mul_453 = torch.ops.aten.mul.Tensor(div_19, sum_57); div_19 = sum_57 = None + sub_29 = torch.ops.aten.sub.Tensor(mul_450, mul_453); mul_450 = mul_453 = None + mul_454 = torch.ops.aten.mul.Tensor(sub_29, rsqrt_45); sub_29 = rsqrt_45 = None + mul_455 = torch.ops.aten.mul.Tensor(convert_element_type_1578, mul_180); convert_element_type_1578 = mul_180 = None + sum_58 = torch.ops.aten.sum.dim_IntList(mul_455, [0, 1]); mul_455 = None + convert_element_type_1581 = torch.ops.prims.convert_element_type.default(mul_454, torch.bfloat16); mul_454 = None + add_195 = torch.ops.aten.add.Tensor(add_191, 
convert_element_type_1581); add_191 = convert_element_type_1581 = None + convert_element_type_default_46 = torch.ops.prims.convert_element_type.default(sum_58, torch.float32); sum_58 = None + reduce_scatter_tensor_86 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_46, 'avg', 64, '0'); convert_element_type_default_46 = None + wait_tensor_377 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_86); reduce_scatter_tensor_86 = None + view_1317 = torch.ops.aten.view.default(add_195, [16384, 4096]) + permute_657 = torch.ops.aten.permute.default(view_1317, [1, 0]) + mm_359 = torch.ops.aten.mm.default(permute_657, view_771); permute_657 = view_771 = None + permute_659 = torch.ops.aten.permute.default(permute_249, [1, 0]); permute_249 = None + mm_360 = torch.ops.aten.mm.default(view_1317, permute_659); view_1317 = permute_659 = None + view_1318 = torch.ops.aten.view.default(mm_360, [2, 8192, 4096]); mm_360 = None + convert_element_type_1588 = torch.ops.prims.convert_element_type.default(mm_359, torch.float32); mm_359 = None + reduce_scatter_tensor_87 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1588, 'avg', 64, '0'); convert_element_type_1588 = None + wait_tensor_378 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_87); reduce_scatter_tensor_87 = None + view_1319 = torch.ops.aten.view.default(view_1318, [2, 8192, 32, 128]); view_1318 = None + permute_661 = torch.ops.aten.permute.default(view_1319, [0, 2, 1, 3]); view_1319 = None + convert_element_type_727 = torch.ops.prims.convert_element_type.default(primals_202, torch.bfloat16); primals_202 = None + all_gather_into_tensor_199 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_727, 64, '0'); convert_element_type_727 = None + wait_tensor_199 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_199); all_gather_into_tensor_199 = None + convert_element_type_728 = 
torch.ops.prims.convert_element_type.default(add_87, torch.float32); add_87 = None + pow_45 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_728, 2) + mean_44 = torch.ops.aten.mean.dim(pow_45, [2], True); pow_45 = None + add_88 = torch.ops.aten.add.Scalar(mean_44, 1e-05); mean_44 = None + rsqrt_44 = torch.ops.aten.rsqrt.default(add_88); add_88 = None + mul_176 = torch.ops.aten.mul.Tensor(convert_element_type_728, rsqrt_44); convert_element_type_728 = None + mul_177 = torch.ops.aten.mul.Tensor(mul_176, wait_tensor_199) + convert_element_type_729 = torch.ops.prims.convert_element_type.default(mul_177, torch.bfloat16); mul_177 = None + view_751 = torch.ops.aten.view.default(convert_element_type_729, [16384, 4096]); convert_element_type_729 = None + view_752 = torch.ops.aten.view.default(mm_154, [2, 8192, 4096]); mm_154 = None + convert_element_type_733 = torch.ops.prims.convert_element_type.default(primals_204, torch.bfloat16); primals_204 = None + all_gather_into_tensor_201 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_733, 64, '0'); convert_element_type_733 = None + wait_tensor_201 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_201); all_gather_into_tensor_201 = None + permute_243 = torch.ops.aten.permute.default(wait_tensor_201, [1, 0]); wait_tensor_201 = None + mm_155 = torch.ops.aten.mm.default(view_751, permute_243) + view_755 = torch.ops.aten.view.default(mm_155, [2, 8192, 1024]); mm_155 = None + view_758 = torch.ops.aten.view.default(mm_156, [2, 8192, 1024]); mm_156 = None + view_759 = torch.ops.aten.view.default(view_752, [2, 8192, -1, 128]); view_752 = None + view_760 = torch.ops.aten.view.default(view_755, [2, 8192, -1, 128]); view_755 = None + view_761 = torch.ops.aten.view.default(view_758, [2, 8192, -1, 128]); view_758 = None + convert_element_type_739 = torch.ops.prims.convert_element_type.default(view_759, torch.float32); view_759 = None + view_762 = 
torch.ops.aten.view.default(convert_element_type_739, [2, 8192, 32, -1, 2]); convert_element_type_739 = None + view_as_complex_44 = torch.ops.aten.view_as_complex.default(view_762); view_762 = None + convert_element_type_740 = torch.ops.prims.convert_element_type.default(view_760, torch.float32); view_760 = None + view_763 = torch.ops.aten.view.default(convert_element_type_740, [2, 8192, 8, -1, 2]); convert_element_type_740 = None + view_as_complex_45 = torch.ops.aten.view_as_complex.default(view_763); view_763 = None + mul_178 = torch.ops.aten.mul.Tensor(view_as_complex_44, view_16); view_as_complex_44 = None + view_as_real_44 = torch.ops.aten.view_as_real.default(mul_178); mul_178 = None + view_765 = torch.ops.aten.view.default(view_as_real_44, [2, 8192, 32, 128]); view_as_real_44 = None + mul_179 = torch.ops.aten.mul.Tensor(view_as_complex_45, view_16); view_as_complex_45 = None + view_as_real_45 = torch.ops.aten.view_as_real.default(mul_179); mul_179 = None + view_766 = torch.ops.aten.view.default(view_as_real_45, [2, 8192, 8, 128]); view_as_real_45 = None + convert_element_type_741 = torch.ops.prims.convert_element_type.default(view_765, torch.bfloat16); view_765 = None + convert_element_type_742 = torch.ops.prims.convert_element_type.default(view_766, torch.bfloat16); view_766 = None + unsqueeze_44 = torch.ops.aten.unsqueeze.default(convert_element_type_742, 3); convert_element_type_742 = None + expand_44 = torch.ops.aten.expand.default(unsqueeze_44, [2, 8192, 8, 4, 128]); unsqueeze_44 = None + clone_44 = torch.ops.aten.clone.default(expand_44, memory_format = torch.contiguous_format); expand_44 = None + view_767 = torch.ops.aten.view.default(clone_44, [2, 8192, 32, 128]); clone_44 = None + unsqueeze_45 = torch.ops.aten.unsqueeze.default(view_761, 3); view_761 = None + expand_45 = torch.ops.aten.expand.default(unsqueeze_45, [2, 8192, 8, 4, 128]); unsqueeze_45 = None + clone_45 = torch.ops.aten.clone.default(expand_45, memory_format = torch.contiguous_format); 
expand_45 = None + view_768 = torch.ops.aten.view.default(clone_45, [2, 8192, 32, 128]); clone_45 = None + permute_245 = torch.ops.aten.permute.default(convert_element_type_741, [0, 2, 1, 3]); convert_element_type_741 = None + permute_246 = torch.ops.aten.permute.default(view_767, [0, 2, 1, 3]); view_767 = None + permute_247 = torch.ops.aten.permute.default(view_768, [0, 2, 1, 3]); view_768 = None + _scaled_dot_product_cudnn_attention_backward_9 = torch.ops.aten._scaled_dot_product_cudnn_attention_backward.default(permute_661, permute_245, permute_246, permute_247, getitem_198, getitem_199, getitem_204, getitem_205, None, None, None, 8192, 8192, 0.0, True); permute_661 = permute_245 = permute_246 = permute_247 = getitem_198 = getitem_199 = getitem_204 = getitem_205 = None + getitem_315 = _scaled_dot_product_cudnn_attention_backward_9[0] + getitem_316 = _scaled_dot_product_cudnn_attention_backward_9[1] + getitem_317 = _scaled_dot_product_cudnn_attention_backward_9[2]; _scaled_dot_product_cudnn_attention_backward_9 = None + permute_662 = torch.ops.aten.permute.default(getitem_317, [0, 2, 1, 3]); getitem_317 = None + permute_663 = torch.ops.aten.permute.default(getitem_316, [0, 2, 1, 3]); getitem_316 = None + permute_664 = torch.ops.aten.permute.default(getitem_315, [0, 2, 1, 3]); getitem_315 = None + view_1320 = torch.ops.aten.view.default(permute_662, [2, 8192, 8, 4, 128]); permute_662 = None + sum_59 = torch.ops.aten.sum.dim_IntList(view_1320, [3], True); view_1320 = None + squeeze_18 = torch.ops.aten.squeeze.dim(sum_59, 3); sum_59 = None + view_1321 = torch.ops.aten.view.default(permute_663, [2, 8192, 8, 4, 128]); permute_663 = None + sum_60 = torch.ops.aten.sum.dim_IntList(view_1321, [3], True); view_1321 = None + squeeze_19 = torch.ops.aten.squeeze.dim(sum_60, 3); sum_60 = None + convert_element_type_1589 = torch.ops.prims.convert_element_type.default(squeeze_19, torch.float32); squeeze_19 = None + convert_element_type_1590 = 
torch.ops.prims.convert_element_type.default(permute_664, torch.float32); permute_664 = None + view_1322 = torch.ops.aten.view.default(convert_element_type_1589, [2, 8192, 8, 64, 2]); convert_element_type_1589 = None + view_as_complex_82 = torch.ops.aten.view_as_complex.default(view_1322); view_1322 = None + mul_456 = torch.ops.aten.mul.Tensor(view_as_complex_82, _conj); view_as_complex_82 = None + view_1323 = torch.ops.aten.view.default(convert_element_type_1590, [2, 8192, 32, 64, 2]); convert_element_type_1590 = None + view_as_complex_83 = torch.ops.aten.view_as_complex.default(view_1323); view_1323 = None + mul_457 = torch.ops.aten.mul.Tensor(view_as_complex_83, _conj); view_as_complex_83 = None + view_as_real_82 = torch.ops.aten.view_as_real.default(mul_456); mul_456 = None + view_1324 = torch.ops.aten.view.default(view_as_real_82, [2, 8192, 8, 128]); view_as_real_82 = None + convert_element_type_1591 = torch.ops.prims.convert_element_type.default(view_1324, torch.bfloat16); view_1324 = None + view_as_real_83 = torch.ops.aten.view_as_real.default(mul_457); mul_457 = None + view_1325 = torch.ops.aten.view.default(view_as_real_83, [2, 8192, 32, 128]); view_as_real_83 = None + convert_element_type_1592 = torch.ops.prims.convert_element_type.default(view_1325, torch.bfloat16); view_1325 = None + view_1326 = torch.ops.aten.view.default(squeeze_18, [2, 8192, 1024]); squeeze_18 = None + view_1327 = torch.ops.aten.view.default(convert_element_type_1591, [2, 8192, 1024]); convert_element_type_1591 = None + view_1328 = torch.ops.aten.view.default(convert_element_type_1592, [2, 8192, 4096]); convert_element_type_1592 = None + view_1329 = torch.ops.aten.view.default(view_1326, [16384, 1024]); view_1326 = None + permute_665 = torch.ops.aten.permute.default(view_1329, [1, 0]) + mm_361 = torch.ops.aten.mm.default(permute_665, view_751); permute_665 = None + convert_element_type_736 = torch.ops.prims.convert_element_type.default(primals_205, torch.bfloat16); primals_205 = None 
+ all_gather_into_tensor_202 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_736, 64, '0'); convert_element_type_736 = None + wait_tensor_202 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_202); all_gather_into_tensor_202 = None + permute_244 = torch.ops.aten.permute.default(wait_tensor_202, [1, 0]); wait_tensor_202 = None + permute_667 = torch.ops.aten.permute.default(permute_244, [1, 0]); permute_244 = None + mm_362 = torch.ops.aten.mm.default(view_1329, permute_667); view_1329 = permute_667 = None + view_1330 = torch.ops.aten.view.default(mm_362, [2, 8192, 4096]); mm_362 = None + convert_element_type_1597 = torch.ops.prims.convert_element_type.default(mm_361, torch.float32); mm_361 = None + reduce_scatter_tensor_88 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1597, 'avg', 64, '0'); convert_element_type_1597 = None + wait_tensor_379 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_88); reduce_scatter_tensor_88 = None + view_1331 = torch.ops.aten.view.default(view_1327, [16384, 1024]); view_1327 = None + permute_669 = torch.ops.aten.permute.default(view_1331, [1, 0]) + mm_363 = torch.ops.aten.mm.default(permute_669, view_751); permute_669 = None + permute_671 = torch.ops.aten.permute.default(permute_243, [1, 0]); permute_243 = None + mm_364 = torch.ops.aten.mm.default(view_1331, permute_671); view_1331 = permute_671 = None + view_1332 = torch.ops.aten.view.default(mm_364, [2, 8192, 4096]); mm_364 = None + add_196 = torch.ops.aten.add.Tensor(view_1330, view_1332); view_1330 = view_1332 = None + convert_element_type_1602 = torch.ops.prims.convert_element_type.default(mm_363, torch.float32); mm_363 = None + reduce_scatter_tensor_89 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1602, 'avg', 64, '0'); convert_element_type_1602 = None + wait_tensor_380 = 
torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_89); reduce_scatter_tensor_89 = None + view_1333 = torch.ops.aten.view.default(view_1328, [16384, 4096]); view_1328 = None + permute_673 = torch.ops.aten.permute.default(view_1333, [1, 0]) + mm_365 = torch.ops.aten.mm.default(permute_673, view_751); permute_673 = view_751 = None + convert_element_type_730 = torch.ops.prims.convert_element_type.default(primals_203, torch.bfloat16); primals_203 = None + all_gather_into_tensor_200 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_730, 64, '0'); convert_element_type_730 = None + wait_tensor_200 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_200); all_gather_into_tensor_200 = None + permute_242 = torch.ops.aten.permute.default(wait_tensor_200, [1, 0]); wait_tensor_200 = None + permute_675 = torch.ops.aten.permute.default(permute_242, [1, 0]); permute_242 = None + mm_366 = torch.ops.aten.mm.default(view_1333, permute_675); view_1333 = permute_675 = None + view_1334 = torch.ops.aten.view.default(mm_366, [2, 8192, 4096]); mm_366 = None + add_197 = torch.ops.aten.add.Tensor(add_196, view_1334); add_196 = view_1334 = None + convert_element_type_1607 = torch.ops.prims.convert_element_type.default(mm_365, torch.float32); mm_365 = None + reduce_scatter_tensor_90 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1607, 'avg', 64, '0'); convert_element_type_1607 = None + wait_tensor_381 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_90); reduce_scatter_tensor_90 = None + convert_element_type_1608 = torch.ops.prims.convert_element_type.default(add_197, torch.float32); add_197 = None + convert_element_type_1610 = torch.ops.prims.convert_element_type.default(wait_tensor_199, torch.float32); wait_tensor_199 = None + mul_458 = torch.ops.aten.mul.Tensor(convert_element_type_1608, convert_element_type_1610); convert_element_type_1610 = None + mul_460 = 
torch.ops.aten.mul.Tensor(mul_176, mul_458) + sum_61 = torch.ops.aten.sum.dim_IntList(mul_460, [2], True); mul_460 = None + div_20 = torch.ops.aten.div.Tensor(mul_176, 4096) + mul_461 = torch.ops.aten.mul.Tensor(div_20, sum_61); div_20 = sum_61 = None + sub_30 = torch.ops.aten.sub.Tensor(mul_458, mul_461); mul_458 = mul_461 = None + mul_462 = torch.ops.aten.mul.Tensor(sub_30, rsqrt_44); sub_30 = rsqrt_44 = None + mul_463 = torch.ops.aten.mul.Tensor(convert_element_type_1608, mul_176); convert_element_type_1608 = mul_176 = None + sum_62 = torch.ops.aten.sum.dim_IntList(mul_463, [0, 1]); mul_463 = None + convert_element_type_1611 = torch.ops.prims.convert_element_type.default(mul_462, torch.bfloat16); mul_462 = None + add_198 = torch.ops.aten.add.Tensor(add_195, convert_element_type_1611); add_195 = convert_element_type_1611 = None + convert_element_type_default_45 = torch.ops.prims.convert_element_type.default(sum_62, torch.float32); sum_62 = None + reduce_scatter_tensor_91 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_45, 'avg', 64, '0'); convert_element_type_default_45 = None + wait_tensor_382 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_91); reduce_scatter_tensor_91 = None + view_1335 = torch.ops.aten.view.default(add_198, [16384, 4096]) + permute_677 = torch.ops.aten.permute.default(view_1335, [1, 0]) + permute_237 = torch.ops.aten.permute.default(getitem_189, [0, 2, 1, 3]) + view_735 = torch.ops.aten.view.default(permute_237, [2, 8192, -1]); permute_237 = None + convert_element_type_710 = torch.ops.prims.convert_element_type.default(primals_197, torch.bfloat16); primals_197 = None + all_gather_into_tensor_194 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_710, 64, '0'); convert_element_type_710 = None + wait_tensor_194 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_194); all_gather_into_tensor_194 = None + permute_238 = 
torch.ops.aten.permute.default(wait_tensor_194, [1, 0]); wait_tensor_194 = None + view_737 = torch.ops.aten.view.default(view_735, [16384, 4096]); view_735 = None + mm_150 = torch.ops.aten.mm.default(view_737, permute_238) + view_738 = torch.ops.aten.view.default(mm_150, [2, 8192, 4096]); mm_150 = None + add_85 = torch.ops.aten.add.Tensor(add_83, view_738); view_738 = None + convert_element_type_713 = torch.ops.prims.convert_element_type.default(primals_198, torch.bfloat16); primals_198 = None + all_gather_into_tensor_195 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_713, 64, '0'); convert_element_type_713 = None + wait_tensor_195 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_195); all_gather_into_tensor_195 = None + convert_element_type_714 = torch.ops.prims.convert_element_type.default(add_85, torch.float32); add_85 = None + pow_44 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_714, 2) + mean_43 = torch.ops.aten.mean.dim(pow_44, [2], True); pow_44 = None + add_86 = torch.ops.aten.add.Scalar(mean_43, 1e-05); mean_43 = None + rsqrt_43 = torch.ops.aten.rsqrt.default(add_86); add_86 = None + mul_172 = torch.ops.aten.mul.Tensor(convert_element_type_714, rsqrt_43); convert_element_type_714 = None + mul_173 = torch.ops.aten.mul.Tensor(mul_172, wait_tensor_195) + convert_element_type_715 = torch.ops.prims.convert_element_type.default(mul_173, torch.bfloat16); mul_173 = None + view_741 = torch.ops.aten.view.default(convert_element_type_715, [16384, 4096]); convert_element_type_715 = None + view_742 = torch.ops.aten.view.default(mm_151, [2, 8192, 14336]); mm_151 = None + convert_element_type_719 = torch.ops.prims.convert_element_type.default(view_742, torch.float32); view_742 = None + sigmoid_21 = torch.ops.aten.sigmoid.default(convert_element_type_719) + mul_174 = torch.ops.aten.mul.Tensor(convert_element_type_719, sigmoid_21); sigmoid_21 = None + convert_element_type_720 = 
torch.ops.prims.convert_element_type.default(mul_174, torch.bfloat16); mul_174 = None + convert_element_type_721 = torch.ops.prims.convert_element_type.default(primals_200, torch.bfloat16); primals_200 = None + all_gather_into_tensor_197 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_721, 64, '0'); convert_element_type_721 = None + wait_tensor_197 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_197); all_gather_into_tensor_197 = None + permute_240 = torch.ops.aten.permute.default(wait_tensor_197, [1, 0]); wait_tensor_197 = None + mm_152 = torch.ops.aten.mm.default(view_741, permute_240) + view_745 = torch.ops.aten.view.default(mm_152, [2, 8192, 14336]); mm_152 = None + mul_175 = torch.ops.aten.mul.Tensor(convert_element_type_720, view_745) + view_747 = torch.ops.aten.view.default(mul_175, [16384, 14336]); mul_175 = None + mm_367 = torch.ops.aten.mm.default(permute_677, view_747); permute_677 = view_747 = None + convert_element_type_724 = torch.ops.prims.convert_element_type.default(primals_201, torch.bfloat16); primals_201 = None + all_gather_into_tensor_198 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_724, 64, '0'); convert_element_type_724 = None + wait_tensor_198 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_198); all_gather_into_tensor_198 = None + permute_241 = torch.ops.aten.permute.default(wait_tensor_198, [1, 0]); wait_tensor_198 = None + permute_679 = torch.ops.aten.permute.default(permute_241, [1, 0]); permute_241 = None + mm_368 = torch.ops.aten.mm.default(view_1335, permute_679); view_1335 = permute_679 = None + view_1336 = torch.ops.aten.view.default(mm_368, [2, 8192, 14336]); mm_368 = None + convert_element_type_1618 = torch.ops.prims.convert_element_type.default(mm_367, torch.float32); mm_367 = None + reduce_scatter_tensor_92 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1618, 'avg', 64, '0'); 
convert_element_type_1618 = None + wait_tensor_383 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_92); reduce_scatter_tensor_92 = None + mul_464 = torch.ops.aten.mul.Tensor(view_1336, convert_element_type_720); convert_element_type_720 = None + mul_465 = torch.ops.aten.mul.Tensor(view_1336, view_745); view_1336 = view_745 = None + view_1337 = torch.ops.aten.view.default(mul_464, [16384, 14336]); mul_464 = None + permute_681 = torch.ops.aten.permute.default(view_1337, [1, 0]) + mm_369 = torch.ops.aten.mm.default(permute_681, view_741); permute_681 = None + permute_683 = torch.ops.aten.permute.default(permute_240, [1, 0]); permute_240 = None + mm_370 = torch.ops.aten.mm.default(view_1337, permute_683); view_1337 = permute_683 = None + view_1338 = torch.ops.aten.view.default(mm_370, [2, 8192, 4096]); mm_370 = None + convert_element_type_1623 = torch.ops.prims.convert_element_type.default(mm_369, torch.float32); mm_369 = None + reduce_scatter_tensor_93 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1623, 'avg', 64, '0'); convert_element_type_1623 = None + wait_tensor_384 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_93); reduce_scatter_tensor_93 = None + convert_element_type_1624 = torch.ops.prims.convert_element_type.default(mul_465, torch.float32); mul_465 = None + neg_10 = torch.ops.aten.neg.default(convert_element_type_719) + exp_10 = torch.ops.aten.exp.default(neg_10); neg_10 = None + add_199 = torch.ops.aten.add.Tensor(exp_10, 1); exp_10 = None + reciprocal_10 = torch.ops.aten.reciprocal.default(add_199); add_199 = None + mul_466 = torch.ops.aten.mul.Tensor(reciprocal_10, 1); reciprocal_10 = None + mul_467 = torch.ops.aten.mul.Tensor(convert_element_type_1624, mul_466); convert_element_type_1624 = None + sub_31 = torch.ops.aten.sub.Tensor(1, mul_466); mul_466 = None + mul_468 = torch.ops.aten.mul.Tensor(convert_element_type_719, sub_31); convert_element_type_719 = sub_31 = None + 
add_200 = torch.ops.aten.add.Tensor(mul_468, 1); mul_468 = None + mul_469 = torch.ops.aten.mul.Tensor(mul_467, add_200); mul_467 = add_200 = None + convert_element_type_1626 = torch.ops.prims.convert_element_type.default(mul_469, torch.bfloat16); mul_469 = None + view_1339 = torch.ops.aten.view.default(convert_element_type_1626, [16384, 14336]); convert_element_type_1626 = None + permute_685 = torch.ops.aten.permute.default(view_1339, [1, 0]) + mm_371 = torch.ops.aten.mm.default(permute_685, view_741); permute_685 = view_741 = None + convert_element_type_716 = torch.ops.prims.convert_element_type.default(primals_199, torch.bfloat16); primals_199 = None + all_gather_into_tensor_196 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_716, 64, '0'); convert_element_type_716 = None + wait_tensor_196 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_196); all_gather_into_tensor_196 = None + permute_239 = torch.ops.aten.permute.default(wait_tensor_196, [1, 0]); wait_tensor_196 = None + permute_687 = torch.ops.aten.permute.default(permute_239, [1, 0]); permute_239 = None + mm_372 = torch.ops.aten.mm.default(view_1339, permute_687); view_1339 = permute_687 = None + view_1340 = torch.ops.aten.view.default(mm_372, [2, 8192, 4096]); mm_372 = None + add_201 = torch.ops.aten.add.Tensor(view_1338, view_1340); view_1338 = view_1340 = None + convert_element_type_1631 = torch.ops.prims.convert_element_type.default(mm_371, torch.float32); mm_371 = None + reduce_scatter_tensor_94 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1631, 'avg', 64, '0'); convert_element_type_1631 = None + wait_tensor_385 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_94); reduce_scatter_tensor_94 = None + convert_element_type_1632 = torch.ops.prims.convert_element_type.default(add_201, torch.float32); add_201 = None + convert_element_type_1634 = 
torch.ops.prims.convert_element_type.default(wait_tensor_195, torch.float32); wait_tensor_195 = None + mul_470 = torch.ops.aten.mul.Tensor(convert_element_type_1632, convert_element_type_1634); convert_element_type_1634 = None + mul_472 = torch.ops.aten.mul.Tensor(mul_172, mul_470) + sum_63 = torch.ops.aten.sum.dim_IntList(mul_472, [2], True); mul_472 = None + div_21 = torch.ops.aten.div.Tensor(mul_172, 4096) + mul_473 = torch.ops.aten.mul.Tensor(div_21, sum_63); div_21 = sum_63 = None + sub_32 = torch.ops.aten.sub.Tensor(mul_470, mul_473); mul_470 = mul_473 = None + mul_474 = torch.ops.aten.mul.Tensor(sub_32, rsqrt_43); sub_32 = rsqrt_43 = None + mul_475 = torch.ops.aten.mul.Tensor(convert_element_type_1632, mul_172); convert_element_type_1632 = mul_172 = None + sum_64 = torch.ops.aten.sum.dim_IntList(mul_475, [0, 1]); mul_475 = None + convert_element_type_1635 = torch.ops.prims.convert_element_type.default(mul_474, torch.bfloat16); mul_474 = None + add_202 = torch.ops.aten.add.Tensor(add_198, convert_element_type_1635); add_198 = convert_element_type_1635 = None + convert_element_type_default_44 = torch.ops.prims.convert_element_type.default(sum_64, torch.float32); sum_64 = None + reduce_scatter_tensor_95 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_44, 'avg', 64, '0'); convert_element_type_default_44 = None + wait_tensor_386 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_95); reduce_scatter_tensor_95 = None + view_1341 = torch.ops.aten.view.default(add_202, [16384, 4096]) + permute_689 = torch.ops.aten.permute.default(view_1341, [1, 0]) + mm_373 = torch.ops.aten.mm.default(permute_689, view_737); permute_689 = view_737 = None + permute_691 = torch.ops.aten.permute.default(permute_238, [1, 0]); permute_238 = None + mm_374 = torch.ops.aten.mm.default(view_1341, permute_691); view_1341 = permute_691 = None + view_1342 = torch.ops.aten.view.default(mm_374, [2, 8192, 4096]); mm_374 = None + 
convert_element_type_1642 = torch.ops.prims.convert_element_type.default(mm_373, torch.float32); mm_373 = None + reduce_scatter_tensor_96 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1642, 'avg', 64, '0'); convert_element_type_1642 = None + wait_tensor_387 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_96); reduce_scatter_tensor_96 = None + view_1343 = torch.ops.aten.view.default(view_1342, [2, 8192, 32, 128]); view_1342 = None + permute_693 = torch.ops.aten.permute.default(view_1343, [0, 2, 1, 3]); view_1343 = None + convert_element_type_694 = torch.ops.prims.convert_element_type.default(primals_193, torch.bfloat16); primals_193 = None + all_gather_into_tensor_190 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_694, 64, '0'); convert_element_type_694 = None + wait_tensor_190 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_190); all_gather_into_tensor_190 = None + convert_element_type_695 = torch.ops.prims.convert_element_type.default(add_83, torch.float32); add_83 = None + pow_43 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_695, 2) + mean_42 = torch.ops.aten.mean.dim(pow_43, [2], True); pow_43 = None + add_84 = torch.ops.aten.add.Scalar(mean_42, 1e-05); mean_42 = None + rsqrt_42 = torch.ops.aten.rsqrt.default(add_84); add_84 = None + mul_168 = torch.ops.aten.mul.Tensor(convert_element_type_695, rsqrt_42); convert_element_type_695 = None + mul_169 = torch.ops.aten.mul.Tensor(mul_168, wait_tensor_190) + convert_element_type_696 = torch.ops.prims.convert_element_type.default(mul_169, torch.bfloat16); mul_169 = None + view_717 = torch.ops.aten.view.default(convert_element_type_696, [16384, 4096]); convert_element_type_696 = None + view_718 = torch.ops.aten.view.default(mm_147, [2, 8192, 4096]); mm_147 = None + convert_element_type_700 = torch.ops.prims.convert_element_type.default(primals_195, torch.bfloat16); primals_195 = None + 
all_gather_into_tensor_192 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_700, 64, '0'); convert_element_type_700 = None + wait_tensor_192 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_192); all_gather_into_tensor_192 = None + permute_232 = torch.ops.aten.permute.default(wait_tensor_192, [1, 0]); wait_tensor_192 = None + mm_148 = torch.ops.aten.mm.default(view_717, permute_232) + view_721 = torch.ops.aten.view.default(mm_148, [2, 8192, 1024]); mm_148 = None + view_724 = torch.ops.aten.view.default(mm_149, [2, 8192, 1024]); mm_149 = None + view_725 = torch.ops.aten.view.default(view_718, [2, 8192, -1, 128]); view_718 = None + view_726 = torch.ops.aten.view.default(view_721, [2, 8192, -1, 128]); view_721 = None + view_727 = torch.ops.aten.view.default(view_724, [2, 8192, -1, 128]); view_724 = None + convert_element_type_706 = torch.ops.prims.convert_element_type.default(view_725, torch.float32); view_725 = None + view_728 = torch.ops.aten.view.default(convert_element_type_706, [2, 8192, 32, -1, 2]); convert_element_type_706 = None + view_as_complex_42 = torch.ops.aten.view_as_complex.default(view_728); view_728 = None + convert_element_type_707 = torch.ops.prims.convert_element_type.default(view_726, torch.float32); view_726 = None + view_729 = torch.ops.aten.view.default(convert_element_type_707, [2, 8192, 8, -1, 2]); convert_element_type_707 = None + view_as_complex_43 = torch.ops.aten.view_as_complex.default(view_729); view_729 = None + mul_170 = torch.ops.aten.mul.Tensor(view_as_complex_42, view_16); view_as_complex_42 = None + view_as_real_42 = torch.ops.aten.view_as_real.default(mul_170); mul_170 = None + view_731 = torch.ops.aten.view.default(view_as_real_42, [2, 8192, 32, 128]); view_as_real_42 = None + mul_171 = torch.ops.aten.mul.Tensor(view_as_complex_43, view_16); view_as_complex_43 = None + view_as_real_43 = torch.ops.aten.view_as_real.default(mul_171); mul_171 = None + view_732 = 
torch.ops.aten.view.default(view_as_real_43, [2, 8192, 8, 128]); view_as_real_43 = None + convert_element_type_708 = torch.ops.prims.convert_element_type.default(view_731, torch.bfloat16); view_731 = None + convert_element_type_709 = torch.ops.prims.convert_element_type.default(view_732, torch.bfloat16); view_732 = None + unsqueeze_42 = torch.ops.aten.unsqueeze.default(convert_element_type_709, 3); convert_element_type_709 = None + expand_42 = torch.ops.aten.expand.default(unsqueeze_42, [2, 8192, 8, 4, 128]); unsqueeze_42 = None + clone_42 = torch.ops.aten.clone.default(expand_42, memory_format = torch.contiguous_format); expand_42 = None + view_733 = torch.ops.aten.view.default(clone_42, [2, 8192, 32, 128]); clone_42 = None + unsqueeze_43 = torch.ops.aten.unsqueeze.default(view_727, 3); view_727 = None + expand_43 = torch.ops.aten.expand.default(unsqueeze_43, [2, 8192, 8, 4, 128]); unsqueeze_43 = None + clone_43 = torch.ops.aten.clone.default(expand_43, memory_format = torch.contiguous_format); expand_43 = None + view_734 = torch.ops.aten.view.default(clone_43, [2, 8192, 32, 128]); clone_43 = None + permute_234 = torch.ops.aten.permute.default(convert_element_type_708, [0, 2, 1, 3]); convert_element_type_708 = None + permute_235 = torch.ops.aten.permute.default(view_733, [0, 2, 1, 3]); view_733 = None + permute_236 = torch.ops.aten.permute.default(view_734, [0, 2, 1, 3]); view_734 = None + _scaled_dot_product_cudnn_attention_backward_10 = torch.ops.aten._scaled_dot_product_cudnn_attention_backward.default(permute_693, permute_234, permute_235, permute_236, getitem_189, getitem_190, getitem_195, getitem_196, None, None, None, 8192, 8192, 0.0, True); permute_693 = permute_234 = permute_235 = permute_236 = getitem_189 = getitem_190 = getitem_195 = getitem_196 = None + getitem_318 = _scaled_dot_product_cudnn_attention_backward_10[0] + getitem_319 = _scaled_dot_product_cudnn_attention_backward_10[1] + getitem_320 = _scaled_dot_product_cudnn_attention_backward_10[2]; 
_scaled_dot_product_cudnn_attention_backward_10 = None + permute_694 = torch.ops.aten.permute.default(getitem_320, [0, 2, 1, 3]); getitem_320 = None + permute_695 = torch.ops.aten.permute.default(getitem_319, [0, 2, 1, 3]); getitem_319 = None + permute_696 = torch.ops.aten.permute.default(getitem_318, [0, 2, 1, 3]); getitem_318 = None + view_1344 = torch.ops.aten.view.default(permute_694, [2, 8192, 8, 4, 128]); permute_694 = None + sum_65 = torch.ops.aten.sum.dim_IntList(view_1344, [3], True); view_1344 = None + squeeze_20 = torch.ops.aten.squeeze.dim(sum_65, 3); sum_65 = None + view_1345 = torch.ops.aten.view.default(permute_695, [2, 8192, 8, 4, 128]); permute_695 = None + sum_66 = torch.ops.aten.sum.dim_IntList(view_1345, [3], True); view_1345 = None + squeeze_21 = torch.ops.aten.squeeze.dim(sum_66, 3); sum_66 = None + convert_element_type_1643 = torch.ops.prims.convert_element_type.default(squeeze_21, torch.float32); squeeze_21 = None + convert_element_type_1644 = torch.ops.prims.convert_element_type.default(permute_696, torch.float32); permute_696 = None + view_1346 = torch.ops.aten.view.default(convert_element_type_1643, [2, 8192, 8, 64, 2]); convert_element_type_1643 = None + view_as_complex_84 = torch.ops.aten.view_as_complex.default(view_1346); view_1346 = None + mul_476 = torch.ops.aten.mul.Tensor(view_as_complex_84, _conj); view_as_complex_84 = None + view_1347 = torch.ops.aten.view.default(convert_element_type_1644, [2, 8192, 32, 64, 2]); convert_element_type_1644 = None + view_as_complex_85 = torch.ops.aten.view_as_complex.default(view_1347); view_1347 = None + mul_477 = torch.ops.aten.mul.Tensor(view_as_complex_85, _conj); view_as_complex_85 = None + view_as_real_84 = torch.ops.aten.view_as_real.default(mul_476); mul_476 = None + view_1348 = torch.ops.aten.view.default(view_as_real_84, [2, 8192, 8, 128]); view_as_real_84 = None + convert_element_type_1645 = torch.ops.prims.convert_element_type.default(view_1348, torch.bfloat16); view_1348 = None + 
view_as_real_85 = torch.ops.aten.view_as_real.default(mul_477); mul_477 = None + view_1349 = torch.ops.aten.view.default(view_as_real_85, [2, 8192, 32, 128]); view_as_real_85 = None + convert_element_type_1646 = torch.ops.prims.convert_element_type.default(view_1349, torch.bfloat16); view_1349 = None + view_1350 = torch.ops.aten.view.default(squeeze_20, [2, 8192, 1024]); squeeze_20 = None + view_1351 = torch.ops.aten.view.default(convert_element_type_1645, [2, 8192, 1024]); convert_element_type_1645 = None + view_1352 = torch.ops.aten.view.default(convert_element_type_1646, [2, 8192, 4096]); convert_element_type_1646 = None + view_1353 = torch.ops.aten.view.default(view_1350, [16384, 1024]); view_1350 = None + permute_697 = torch.ops.aten.permute.default(view_1353, [1, 0]) + mm_375 = torch.ops.aten.mm.default(permute_697, view_717); permute_697 = None + convert_element_type_703 = torch.ops.prims.convert_element_type.default(primals_196, torch.bfloat16); primals_196 = None + all_gather_into_tensor_193 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_703, 64, '0'); convert_element_type_703 = None + wait_tensor_193 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_193); all_gather_into_tensor_193 = None + permute_233 = torch.ops.aten.permute.default(wait_tensor_193, [1, 0]); wait_tensor_193 = None + permute_699 = torch.ops.aten.permute.default(permute_233, [1, 0]); permute_233 = None + mm_376 = torch.ops.aten.mm.default(view_1353, permute_699); view_1353 = permute_699 = None + view_1354 = torch.ops.aten.view.default(mm_376, [2, 8192, 4096]); mm_376 = None + convert_element_type_1651 = torch.ops.prims.convert_element_type.default(mm_375, torch.float32); mm_375 = None + reduce_scatter_tensor_97 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1651, 'avg', 64, '0'); convert_element_type_1651 = None + wait_tensor_388 = 
torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_97); reduce_scatter_tensor_97 = None + view_1355 = torch.ops.aten.view.default(view_1351, [16384, 1024]); view_1351 = None + permute_701 = torch.ops.aten.permute.default(view_1355, [1, 0]) + mm_377 = torch.ops.aten.mm.default(permute_701, view_717); permute_701 = None + permute_703 = torch.ops.aten.permute.default(permute_232, [1, 0]); permute_232 = None + mm_378 = torch.ops.aten.mm.default(view_1355, permute_703); view_1355 = permute_703 = None + view_1356 = torch.ops.aten.view.default(mm_378, [2, 8192, 4096]); mm_378 = None + add_203 = torch.ops.aten.add.Tensor(view_1354, view_1356); view_1354 = view_1356 = None + convert_element_type_1656 = torch.ops.prims.convert_element_type.default(mm_377, torch.float32); mm_377 = None + reduce_scatter_tensor_98 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1656, 'avg', 64, '0'); convert_element_type_1656 = None + wait_tensor_389 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_98); reduce_scatter_tensor_98 = None + view_1357 = torch.ops.aten.view.default(view_1352, [16384, 4096]); view_1352 = None + permute_705 = torch.ops.aten.permute.default(view_1357, [1, 0]) + mm_379 = torch.ops.aten.mm.default(permute_705, view_717); permute_705 = view_717 = None + convert_element_type_697 = torch.ops.prims.convert_element_type.default(primals_194, torch.bfloat16); primals_194 = None + all_gather_into_tensor_191 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_697, 64, '0'); convert_element_type_697 = None + wait_tensor_191 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_191); all_gather_into_tensor_191 = None + permute_231 = torch.ops.aten.permute.default(wait_tensor_191, [1, 0]); wait_tensor_191 = None + permute_707 = torch.ops.aten.permute.default(permute_231, [1, 0]); permute_231 = None + mm_380 = torch.ops.aten.mm.default(view_1357, permute_707); 
view_1357 = permute_707 = None + view_1358 = torch.ops.aten.view.default(mm_380, [2, 8192, 4096]); mm_380 = None + add_204 = torch.ops.aten.add.Tensor(add_203, view_1358); add_203 = view_1358 = None + convert_element_type_1661 = torch.ops.prims.convert_element_type.default(mm_379, torch.float32); mm_379 = None + reduce_scatter_tensor_99 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1661, 'avg', 64, '0'); convert_element_type_1661 = None + wait_tensor_390 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_99); reduce_scatter_tensor_99 = None + convert_element_type_1662 = torch.ops.prims.convert_element_type.default(add_204, torch.float32); add_204 = None + convert_element_type_1664 = torch.ops.prims.convert_element_type.default(wait_tensor_190, torch.float32); wait_tensor_190 = None + mul_478 = torch.ops.aten.mul.Tensor(convert_element_type_1662, convert_element_type_1664); convert_element_type_1664 = None + mul_480 = torch.ops.aten.mul.Tensor(mul_168, mul_478) + sum_67 = torch.ops.aten.sum.dim_IntList(mul_480, [2], True); mul_480 = None + div_22 = torch.ops.aten.div.Tensor(mul_168, 4096) + mul_481 = torch.ops.aten.mul.Tensor(div_22, sum_67); div_22 = sum_67 = None + sub_33 = torch.ops.aten.sub.Tensor(mul_478, mul_481); mul_478 = mul_481 = None + mul_482 = torch.ops.aten.mul.Tensor(sub_33, rsqrt_42); sub_33 = rsqrt_42 = None + mul_483 = torch.ops.aten.mul.Tensor(convert_element_type_1662, mul_168); convert_element_type_1662 = mul_168 = None + sum_68 = torch.ops.aten.sum.dim_IntList(mul_483, [0, 1]); mul_483 = None + convert_element_type_1665 = torch.ops.prims.convert_element_type.default(mul_482, torch.bfloat16); mul_482 = None + add_205 = torch.ops.aten.add.Tensor(add_202, convert_element_type_1665); add_202 = convert_element_type_1665 = None + convert_element_type_default_43 = torch.ops.prims.convert_element_type.default(sum_68, torch.float32); sum_68 = None + reduce_scatter_tensor_100 = 
torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_43, 'avg', 64, '0'); convert_element_type_default_43 = None + wait_tensor_391 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_100); reduce_scatter_tensor_100 = None + view_1359 = torch.ops.aten.view.default(add_205, [16384, 4096]) + permute_709 = torch.ops.aten.permute.default(view_1359, [1, 0]) + permute_226 = torch.ops.aten.permute.default(getitem_180, [0, 2, 1, 3]) + view_701 = torch.ops.aten.view.default(permute_226, [2, 8192, -1]); permute_226 = None + convert_element_type_677 = torch.ops.prims.convert_element_type.default(primals_188, torch.bfloat16); primals_188 = None + all_gather_into_tensor_185 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_677, 64, '0'); convert_element_type_677 = None + wait_tensor_185 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_185); all_gather_into_tensor_185 = None + permute_227 = torch.ops.aten.permute.default(wait_tensor_185, [1, 0]); wait_tensor_185 = None + view_703 = torch.ops.aten.view.default(view_701, [16384, 4096]); view_701 = None + mm_143 = torch.ops.aten.mm.default(view_703, permute_227) + view_704 = torch.ops.aten.view.default(mm_143, [2, 8192, 4096]); mm_143 = None + add_81 = torch.ops.aten.add.Tensor(add_79, view_704); view_704 = None + convert_element_type_680 = torch.ops.prims.convert_element_type.default(primals_189, torch.bfloat16); primals_189 = None + all_gather_into_tensor_186 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_680, 64, '0'); convert_element_type_680 = None + wait_tensor_186 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_186); all_gather_into_tensor_186 = None + convert_element_type_681 = torch.ops.prims.convert_element_type.default(add_81, torch.float32); add_81 = None + pow_42 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_681, 2) + mean_41 = 
torch.ops.aten.mean.dim(pow_42, [2], True); pow_42 = None + add_82 = torch.ops.aten.add.Scalar(mean_41, 1e-05); mean_41 = None + rsqrt_41 = torch.ops.aten.rsqrt.default(add_82); add_82 = None + mul_164 = torch.ops.aten.mul.Tensor(convert_element_type_681, rsqrt_41); convert_element_type_681 = None + mul_165 = torch.ops.aten.mul.Tensor(mul_164, wait_tensor_186) + convert_element_type_682 = torch.ops.prims.convert_element_type.default(mul_165, torch.bfloat16); mul_165 = None + view_707 = torch.ops.aten.view.default(convert_element_type_682, [16384, 4096]); convert_element_type_682 = None + view_708 = torch.ops.aten.view.default(mm_144, [2, 8192, 14336]); mm_144 = None + convert_element_type_686 = torch.ops.prims.convert_element_type.default(view_708, torch.float32); view_708 = None + sigmoid_20 = torch.ops.aten.sigmoid.default(convert_element_type_686) + mul_166 = torch.ops.aten.mul.Tensor(convert_element_type_686, sigmoid_20); sigmoid_20 = None + convert_element_type_687 = torch.ops.prims.convert_element_type.default(mul_166, torch.bfloat16); mul_166 = None + convert_element_type_688 = torch.ops.prims.convert_element_type.default(primals_191, torch.bfloat16); primals_191 = None + all_gather_into_tensor_188 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_688, 64, '0'); convert_element_type_688 = None + wait_tensor_188 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_188); all_gather_into_tensor_188 = None + permute_229 = torch.ops.aten.permute.default(wait_tensor_188, [1, 0]); wait_tensor_188 = None + mm_145 = torch.ops.aten.mm.default(view_707, permute_229) + view_711 = torch.ops.aten.view.default(mm_145, [2, 8192, 14336]); mm_145 = None + mul_167 = torch.ops.aten.mul.Tensor(convert_element_type_687, view_711) + view_713 = torch.ops.aten.view.default(mul_167, [16384, 14336]); mul_167 = None + mm_381 = torch.ops.aten.mm.default(permute_709, view_713); permute_709 = view_713 = None + convert_element_type_691 = 
torch.ops.prims.convert_element_type.default(primals_192, torch.bfloat16); primals_192 = None + all_gather_into_tensor_189 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_691, 64, '0'); convert_element_type_691 = None + wait_tensor_189 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_189); all_gather_into_tensor_189 = None + permute_230 = torch.ops.aten.permute.default(wait_tensor_189, [1, 0]); wait_tensor_189 = None + permute_711 = torch.ops.aten.permute.default(permute_230, [1, 0]); permute_230 = None + mm_382 = torch.ops.aten.mm.default(view_1359, permute_711); view_1359 = permute_711 = None + view_1360 = torch.ops.aten.view.default(mm_382, [2, 8192, 14336]); mm_382 = None + convert_element_type_1672 = torch.ops.prims.convert_element_type.default(mm_381, torch.float32); mm_381 = None + reduce_scatter_tensor_101 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1672, 'avg', 64, '0'); convert_element_type_1672 = None + wait_tensor_392 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_101); reduce_scatter_tensor_101 = None + mul_484 = torch.ops.aten.mul.Tensor(view_1360, convert_element_type_687); convert_element_type_687 = None + mul_485 = torch.ops.aten.mul.Tensor(view_1360, view_711); view_1360 = view_711 = None + view_1361 = torch.ops.aten.view.default(mul_484, [16384, 14336]); mul_484 = None + permute_713 = torch.ops.aten.permute.default(view_1361, [1, 0]) + mm_383 = torch.ops.aten.mm.default(permute_713, view_707); permute_713 = None + permute_715 = torch.ops.aten.permute.default(permute_229, [1, 0]); permute_229 = None + mm_384 = torch.ops.aten.mm.default(view_1361, permute_715); view_1361 = permute_715 = None + view_1362 = torch.ops.aten.view.default(mm_384, [2, 8192, 4096]); mm_384 = None + convert_element_type_1677 = torch.ops.prims.convert_element_type.default(mm_383, torch.float32); mm_383 = None + reduce_scatter_tensor_102 = 
torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1677, 'avg', 64, '0'); convert_element_type_1677 = None + wait_tensor_393 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_102); reduce_scatter_tensor_102 = None + convert_element_type_1678 = torch.ops.prims.convert_element_type.default(mul_485, torch.float32); mul_485 = None + neg_11 = torch.ops.aten.neg.default(convert_element_type_686) + exp_11 = torch.ops.aten.exp.default(neg_11); neg_11 = None + add_206 = torch.ops.aten.add.Tensor(exp_11, 1); exp_11 = None + reciprocal_11 = torch.ops.aten.reciprocal.default(add_206); add_206 = None + mul_486 = torch.ops.aten.mul.Tensor(reciprocal_11, 1); reciprocal_11 = None + mul_487 = torch.ops.aten.mul.Tensor(convert_element_type_1678, mul_486); convert_element_type_1678 = None + sub_34 = torch.ops.aten.sub.Tensor(1, mul_486); mul_486 = None + mul_488 = torch.ops.aten.mul.Tensor(convert_element_type_686, sub_34); convert_element_type_686 = sub_34 = None + add_207 = torch.ops.aten.add.Tensor(mul_488, 1); mul_488 = None + mul_489 = torch.ops.aten.mul.Tensor(mul_487, add_207); mul_487 = add_207 = None + convert_element_type_1680 = torch.ops.prims.convert_element_type.default(mul_489, torch.bfloat16); mul_489 = None + view_1363 = torch.ops.aten.view.default(convert_element_type_1680, [16384, 14336]); convert_element_type_1680 = None + permute_717 = torch.ops.aten.permute.default(view_1363, [1, 0]) + mm_385 = torch.ops.aten.mm.default(permute_717, view_707); permute_717 = view_707 = None + convert_element_type_683 = torch.ops.prims.convert_element_type.default(primals_190, torch.bfloat16); primals_190 = None + all_gather_into_tensor_187 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_683, 64, '0'); convert_element_type_683 = None + wait_tensor_187 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_187); all_gather_into_tensor_187 = None + permute_228 = 
torch.ops.aten.permute.default(wait_tensor_187, [1, 0]); wait_tensor_187 = None + permute_719 = torch.ops.aten.permute.default(permute_228, [1, 0]); permute_228 = None + mm_386 = torch.ops.aten.mm.default(view_1363, permute_719); view_1363 = permute_719 = None + view_1364 = torch.ops.aten.view.default(mm_386, [2, 8192, 4096]); mm_386 = None + add_208 = torch.ops.aten.add.Tensor(view_1362, view_1364); view_1362 = view_1364 = None + convert_element_type_1685 = torch.ops.prims.convert_element_type.default(mm_385, torch.float32); mm_385 = None + reduce_scatter_tensor_103 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1685, 'avg', 64, '0'); convert_element_type_1685 = None + wait_tensor_394 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_103); reduce_scatter_tensor_103 = None + convert_element_type_1686 = torch.ops.prims.convert_element_type.default(add_208, torch.float32); add_208 = None + convert_element_type_1688 = torch.ops.prims.convert_element_type.default(wait_tensor_186, torch.float32); wait_tensor_186 = None + mul_490 = torch.ops.aten.mul.Tensor(convert_element_type_1686, convert_element_type_1688); convert_element_type_1688 = None + mul_492 = torch.ops.aten.mul.Tensor(mul_164, mul_490) + sum_69 = torch.ops.aten.sum.dim_IntList(mul_492, [2], True); mul_492 = None + div_23 = torch.ops.aten.div.Tensor(mul_164, 4096) + mul_493 = torch.ops.aten.mul.Tensor(div_23, sum_69); div_23 = sum_69 = None + sub_35 = torch.ops.aten.sub.Tensor(mul_490, mul_493); mul_490 = mul_493 = None + mul_494 = torch.ops.aten.mul.Tensor(sub_35, rsqrt_41); sub_35 = rsqrt_41 = None + mul_495 = torch.ops.aten.mul.Tensor(convert_element_type_1686, mul_164); convert_element_type_1686 = mul_164 = None + sum_70 = torch.ops.aten.sum.dim_IntList(mul_495, [0, 1]); mul_495 = None + convert_element_type_1689 = torch.ops.prims.convert_element_type.default(mul_494, torch.bfloat16); mul_494 = None + add_209 = torch.ops.aten.add.Tensor(add_205, 
convert_element_type_1689); add_205 = convert_element_type_1689 = None + convert_element_type_default_42 = torch.ops.prims.convert_element_type.default(sum_70, torch.float32); sum_70 = None + reduce_scatter_tensor_104 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_42, 'avg', 64, '0'); convert_element_type_default_42 = None + wait_tensor_395 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_104); reduce_scatter_tensor_104 = None + view_1365 = torch.ops.aten.view.default(add_209, [16384, 4096]) + permute_721 = torch.ops.aten.permute.default(view_1365, [1, 0]) + mm_387 = torch.ops.aten.mm.default(permute_721, view_703); permute_721 = view_703 = None + permute_723 = torch.ops.aten.permute.default(permute_227, [1, 0]); permute_227 = None + mm_388 = torch.ops.aten.mm.default(view_1365, permute_723); view_1365 = permute_723 = None + view_1366 = torch.ops.aten.view.default(mm_388, [2, 8192, 4096]); mm_388 = None + convert_element_type_1696 = torch.ops.prims.convert_element_type.default(mm_387, torch.float32); mm_387 = None + reduce_scatter_tensor_105 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1696, 'avg', 64, '0'); convert_element_type_1696 = None + wait_tensor_396 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_105); reduce_scatter_tensor_105 = None + view_1367 = torch.ops.aten.view.default(view_1366, [2, 8192, 32, 128]); view_1366 = None + permute_725 = torch.ops.aten.permute.default(view_1367, [0, 2, 1, 3]); view_1367 = None + convert_element_type_661 = torch.ops.prims.convert_element_type.default(primals_184, torch.bfloat16); primals_184 = None + all_gather_into_tensor_181 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_661, 64, '0'); convert_element_type_661 = None + wait_tensor_181 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_181); all_gather_into_tensor_181 = None + 
convert_element_type_662 = torch.ops.prims.convert_element_type.default(add_79, torch.float32); add_79 = None + pow_41 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_662, 2) + mean_40 = torch.ops.aten.mean.dim(pow_41, [2], True); pow_41 = None + add_80 = torch.ops.aten.add.Scalar(mean_40, 1e-05); mean_40 = None + rsqrt_40 = torch.ops.aten.rsqrt.default(add_80); add_80 = None + mul_160 = torch.ops.aten.mul.Tensor(convert_element_type_662, rsqrt_40); convert_element_type_662 = None + mul_161 = torch.ops.aten.mul.Tensor(mul_160, wait_tensor_181) + convert_element_type_663 = torch.ops.prims.convert_element_type.default(mul_161, torch.bfloat16); mul_161 = None + view_683 = torch.ops.aten.view.default(convert_element_type_663, [16384, 4096]); convert_element_type_663 = None + view_684 = torch.ops.aten.view.default(mm_140, [2, 8192, 4096]); mm_140 = None + convert_element_type_667 = torch.ops.prims.convert_element_type.default(primals_186, torch.bfloat16); primals_186 = None + all_gather_into_tensor_183 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_667, 64, '0'); convert_element_type_667 = None + wait_tensor_183 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_183); all_gather_into_tensor_183 = None + permute_221 = torch.ops.aten.permute.default(wait_tensor_183, [1, 0]); wait_tensor_183 = None + mm_141 = torch.ops.aten.mm.default(view_683, permute_221) + view_687 = torch.ops.aten.view.default(mm_141, [2, 8192, 1024]); mm_141 = None + view_690 = torch.ops.aten.view.default(mm_142, [2, 8192, 1024]); mm_142 = None + view_691 = torch.ops.aten.view.default(view_684, [2, 8192, -1, 128]); view_684 = None + view_692 = torch.ops.aten.view.default(view_687, [2, 8192, -1, 128]); view_687 = None + view_693 = torch.ops.aten.view.default(view_690, [2, 8192, -1, 128]); view_690 = None + convert_element_type_673 = torch.ops.prims.convert_element_type.default(view_691, torch.float32); view_691 = None + view_694 = 
torch.ops.aten.view.default(convert_element_type_673, [2, 8192, 32, -1, 2]); convert_element_type_673 = None + view_as_complex_40 = torch.ops.aten.view_as_complex.default(view_694); view_694 = None + convert_element_type_674 = torch.ops.prims.convert_element_type.default(view_692, torch.float32); view_692 = None + view_695 = torch.ops.aten.view.default(convert_element_type_674, [2, 8192, 8, -1, 2]); convert_element_type_674 = None + view_as_complex_41 = torch.ops.aten.view_as_complex.default(view_695); view_695 = None + mul_162 = torch.ops.aten.mul.Tensor(view_as_complex_40, view_16); view_as_complex_40 = None + view_as_real_40 = torch.ops.aten.view_as_real.default(mul_162); mul_162 = None + view_697 = torch.ops.aten.view.default(view_as_real_40, [2, 8192, 32, 128]); view_as_real_40 = None + mul_163 = torch.ops.aten.mul.Tensor(view_as_complex_41, view_16); view_as_complex_41 = None + view_as_real_41 = torch.ops.aten.view_as_real.default(mul_163); mul_163 = None + view_698 = torch.ops.aten.view.default(view_as_real_41, [2, 8192, 8, 128]); view_as_real_41 = None + convert_element_type_675 = torch.ops.prims.convert_element_type.default(view_697, torch.bfloat16); view_697 = None + convert_element_type_676 = torch.ops.prims.convert_element_type.default(view_698, torch.bfloat16); view_698 = None + unsqueeze_40 = torch.ops.aten.unsqueeze.default(convert_element_type_676, 3); convert_element_type_676 = None + expand_40 = torch.ops.aten.expand.default(unsqueeze_40, [2, 8192, 8, 4, 128]); unsqueeze_40 = None + clone_40 = torch.ops.aten.clone.default(expand_40, memory_format = torch.contiguous_format); expand_40 = None + view_699 = torch.ops.aten.view.default(clone_40, [2, 8192, 32, 128]); clone_40 = None + unsqueeze_41 = torch.ops.aten.unsqueeze.default(view_693, 3); view_693 = None + expand_41 = torch.ops.aten.expand.default(unsqueeze_41, [2, 8192, 8, 4, 128]); unsqueeze_41 = None + clone_41 = torch.ops.aten.clone.default(expand_41, memory_format = torch.contiguous_format); 
expand_41 = None + view_700 = torch.ops.aten.view.default(clone_41, [2, 8192, 32, 128]); clone_41 = None + permute_223 = torch.ops.aten.permute.default(convert_element_type_675, [0, 2, 1, 3]); convert_element_type_675 = None + permute_224 = torch.ops.aten.permute.default(view_699, [0, 2, 1, 3]); view_699 = None + permute_225 = torch.ops.aten.permute.default(view_700, [0, 2, 1, 3]); view_700 = None + _scaled_dot_product_cudnn_attention_backward_11 = torch.ops.aten._scaled_dot_product_cudnn_attention_backward.default(permute_725, permute_223, permute_224, permute_225, getitem_180, getitem_181, getitem_186, getitem_187, None, None, None, 8192, 8192, 0.0, True); permute_725 = permute_223 = permute_224 = permute_225 = getitem_180 = getitem_181 = getitem_186 = getitem_187 = None + getitem_321 = _scaled_dot_product_cudnn_attention_backward_11[0] + getitem_322 = _scaled_dot_product_cudnn_attention_backward_11[1] + getitem_323 = _scaled_dot_product_cudnn_attention_backward_11[2]; _scaled_dot_product_cudnn_attention_backward_11 = None + permute_726 = torch.ops.aten.permute.default(getitem_323, [0, 2, 1, 3]); getitem_323 = None + permute_727 = torch.ops.aten.permute.default(getitem_322, [0, 2, 1, 3]); getitem_322 = None + permute_728 = torch.ops.aten.permute.default(getitem_321, [0, 2, 1, 3]); getitem_321 = None + view_1368 = torch.ops.aten.view.default(permute_726, [2, 8192, 8, 4, 128]); permute_726 = None + sum_71 = torch.ops.aten.sum.dim_IntList(view_1368, [3], True); view_1368 = None + squeeze_22 = torch.ops.aten.squeeze.dim(sum_71, 3); sum_71 = None + view_1369 = torch.ops.aten.view.default(permute_727, [2, 8192, 8, 4, 128]); permute_727 = None + sum_72 = torch.ops.aten.sum.dim_IntList(view_1369, [3], True); view_1369 = None + squeeze_23 = torch.ops.aten.squeeze.dim(sum_72, 3); sum_72 = None + convert_element_type_1697 = torch.ops.prims.convert_element_type.default(squeeze_23, torch.float32); squeeze_23 = None + convert_element_type_1698 = 
torch.ops.prims.convert_element_type.default(permute_728, torch.float32); permute_728 = None + view_1370 = torch.ops.aten.view.default(convert_element_type_1697, [2, 8192, 8, 64, 2]); convert_element_type_1697 = None + view_as_complex_86 = torch.ops.aten.view_as_complex.default(view_1370); view_1370 = None + mul_496 = torch.ops.aten.mul.Tensor(view_as_complex_86, _conj); view_as_complex_86 = None + view_1371 = torch.ops.aten.view.default(convert_element_type_1698, [2, 8192, 32, 64, 2]); convert_element_type_1698 = None + view_as_complex_87 = torch.ops.aten.view_as_complex.default(view_1371); view_1371 = None + mul_497 = torch.ops.aten.mul.Tensor(view_as_complex_87, _conj); view_as_complex_87 = None + view_as_real_86 = torch.ops.aten.view_as_real.default(mul_496); mul_496 = None + view_1372 = torch.ops.aten.view.default(view_as_real_86, [2, 8192, 8, 128]); view_as_real_86 = None + convert_element_type_1699 = torch.ops.prims.convert_element_type.default(view_1372, torch.bfloat16); view_1372 = None + view_as_real_87 = torch.ops.aten.view_as_real.default(mul_497); mul_497 = None + view_1373 = torch.ops.aten.view.default(view_as_real_87, [2, 8192, 32, 128]); view_as_real_87 = None + convert_element_type_1700 = torch.ops.prims.convert_element_type.default(view_1373, torch.bfloat16); view_1373 = None + view_1374 = torch.ops.aten.view.default(squeeze_22, [2, 8192, 1024]); squeeze_22 = None + view_1375 = torch.ops.aten.view.default(convert_element_type_1699, [2, 8192, 1024]); convert_element_type_1699 = None + view_1376 = torch.ops.aten.view.default(convert_element_type_1700, [2, 8192, 4096]); convert_element_type_1700 = None + view_1377 = torch.ops.aten.view.default(view_1374, [16384, 1024]); view_1374 = None + permute_729 = torch.ops.aten.permute.default(view_1377, [1, 0]) + mm_389 = torch.ops.aten.mm.default(permute_729, view_683); permute_729 = None + convert_element_type_670 = torch.ops.prims.convert_element_type.default(primals_187, torch.bfloat16); primals_187 = None 
+ all_gather_into_tensor_184 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_670, 64, '0'); convert_element_type_670 = None + wait_tensor_184 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_184); all_gather_into_tensor_184 = None + permute_222 = torch.ops.aten.permute.default(wait_tensor_184, [1, 0]); wait_tensor_184 = None + permute_731 = torch.ops.aten.permute.default(permute_222, [1, 0]); permute_222 = None + mm_390 = torch.ops.aten.mm.default(view_1377, permute_731); view_1377 = permute_731 = None + view_1378 = torch.ops.aten.view.default(mm_390, [2, 8192, 4096]); mm_390 = None + convert_element_type_1705 = torch.ops.prims.convert_element_type.default(mm_389, torch.float32); mm_389 = None + reduce_scatter_tensor_106 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1705, 'avg', 64, '0'); convert_element_type_1705 = None + wait_tensor_397 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_106); reduce_scatter_tensor_106 = None + view_1379 = torch.ops.aten.view.default(view_1375, [16384, 1024]); view_1375 = None + permute_733 = torch.ops.aten.permute.default(view_1379, [1, 0]) + mm_391 = torch.ops.aten.mm.default(permute_733, view_683); permute_733 = None + permute_735 = torch.ops.aten.permute.default(permute_221, [1, 0]); permute_221 = None + mm_392 = torch.ops.aten.mm.default(view_1379, permute_735); view_1379 = permute_735 = None + view_1380 = torch.ops.aten.view.default(mm_392, [2, 8192, 4096]); mm_392 = None + add_210 = torch.ops.aten.add.Tensor(view_1378, view_1380); view_1378 = view_1380 = None + convert_element_type_1710 = torch.ops.prims.convert_element_type.default(mm_391, torch.float32); mm_391 = None + reduce_scatter_tensor_107 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1710, 'avg', 64, '0'); convert_element_type_1710 = None + wait_tensor_398 = 
torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_107); reduce_scatter_tensor_107 = None + view_1381 = torch.ops.aten.view.default(view_1376, [16384, 4096]); view_1376 = None + permute_737 = torch.ops.aten.permute.default(view_1381, [1, 0]) + mm_393 = torch.ops.aten.mm.default(permute_737, view_683); permute_737 = view_683 = None + convert_element_type_664 = torch.ops.prims.convert_element_type.default(primals_185, torch.bfloat16); primals_185 = None + all_gather_into_tensor_182 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_664, 64, '0'); convert_element_type_664 = None + wait_tensor_182 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_182); all_gather_into_tensor_182 = None + permute_220 = torch.ops.aten.permute.default(wait_tensor_182, [1, 0]); wait_tensor_182 = None + permute_739 = torch.ops.aten.permute.default(permute_220, [1, 0]); permute_220 = None + mm_394 = torch.ops.aten.mm.default(view_1381, permute_739); view_1381 = permute_739 = None + view_1382 = torch.ops.aten.view.default(mm_394, [2, 8192, 4096]); mm_394 = None + add_211 = torch.ops.aten.add.Tensor(add_210, view_1382); add_210 = view_1382 = None + convert_element_type_1715 = torch.ops.prims.convert_element_type.default(mm_393, torch.float32); mm_393 = None + reduce_scatter_tensor_108 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1715, 'avg', 64, '0'); convert_element_type_1715 = None + wait_tensor_399 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_108); reduce_scatter_tensor_108 = None + convert_element_type_1716 = torch.ops.prims.convert_element_type.default(add_211, torch.float32); add_211 = None + convert_element_type_1718 = torch.ops.prims.convert_element_type.default(wait_tensor_181, torch.float32); wait_tensor_181 = None + mul_498 = torch.ops.aten.mul.Tensor(convert_element_type_1716, convert_element_type_1718); convert_element_type_1718 = None + mul_500 = 
torch.ops.aten.mul.Tensor(mul_160, mul_498) + sum_73 = torch.ops.aten.sum.dim_IntList(mul_500, [2], True); mul_500 = None + div_24 = torch.ops.aten.div.Tensor(mul_160, 4096) + mul_501 = torch.ops.aten.mul.Tensor(div_24, sum_73); div_24 = sum_73 = None + sub_36 = torch.ops.aten.sub.Tensor(mul_498, mul_501); mul_498 = mul_501 = None + mul_502 = torch.ops.aten.mul.Tensor(sub_36, rsqrt_40); sub_36 = rsqrt_40 = None + mul_503 = torch.ops.aten.mul.Tensor(convert_element_type_1716, mul_160); convert_element_type_1716 = mul_160 = None + sum_74 = torch.ops.aten.sum.dim_IntList(mul_503, [0, 1]); mul_503 = None + convert_element_type_1719 = torch.ops.prims.convert_element_type.default(mul_502, torch.bfloat16); mul_502 = None + add_212 = torch.ops.aten.add.Tensor(add_209, convert_element_type_1719); add_209 = convert_element_type_1719 = None + convert_element_type_default_41 = torch.ops.prims.convert_element_type.default(sum_74, torch.float32); sum_74 = None + reduce_scatter_tensor_109 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_41, 'avg', 64, '0'); convert_element_type_default_41 = None + wait_tensor_400 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_109); reduce_scatter_tensor_109 = None + view_1383 = torch.ops.aten.view.default(add_212, [16384, 4096]) + permute_741 = torch.ops.aten.permute.default(view_1383, [1, 0]) + permute_215 = torch.ops.aten.permute.default(getitem_171, [0, 2, 1, 3]) + view_667 = torch.ops.aten.view.default(permute_215, [2, 8192, -1]); permute_215 = None + convert_element_type_644 = torch.ops.prims.convert_element_type.default(primals_179, torch.bfloat16); primals_179 = None + all_gather_into_tensor_176 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_644, 64, '0'); convert_element_type_644 = None + wait_tensor_176 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_176); all_gather_into_tensor_176 = None + permute_216 = 
torch.ops.aten.permute.default(wait_tensor_176, [1, 0]); wait_tensor_176 = None + view_669 = torch.ops.aten.view.default(view_667, [16384, 4096]); view_667 = None + mm_136 = torch.ops.aten.mm.default(view_669, permute_216) + view_670 = torch.ops.aten.view.default(mm_136, [2, 8192, 4096]); mm_136 = None + add_77 = torch.ops.aten.add.Tensor(add_75, view_670); view_670 = None + convert_element_type_647 = torch.ops.prims.convert_element_type.default(primals_180, torch.bfloat16); primals_180 = None + all_gather_into_tensor_177 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_647, 64, '0'); convert_element_type_647 = None + wait_tensor_177 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_177); all_gather_into_tensor_177 = None + convert_element_type_648 = torch.ops.prims.convert_element_type.default(add_77, torch.float32); add_77 = None + pow_40 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_648, 2) + mean_39 = torch.ops.aten.mean.dim(pow_40, [2], True); pow_40 = None + add_78 = torch.ops.aten.add.Scalar(mean_39, 1e-05); mean_39 = None + rsqrt_39 = torch.ops.aten.rsqrt.default(add_78); add_78 = None + mul_156 = torch.ops.aten.mul.Tensor(convert_element_type_648, rsqrt_39); convert_element_type_648 = None + mul_157 = torch.ops.aten.mul.Tensor(mul_156, wait_tensor_177) + convert_element_type_649 = torch.ops.prims.convert_element_type.default(mul_157, torch.bfloat16); mul_157 = None + view_673 = torch.ops.aten.view.default(convert_element_type_649, [16384, 4096]); convert_element_type_649 = None + view_674 = torch.ops.aten.view.default(mm_137, [2, 8192, 14336]); mm_137 = None + convert_element_type_653 = torch.ops.prims.convert_element_type.default(view_674, torch.float32); view_674 = None + sigmoid_19 = torch.ops.aten.sigmoid.default(convert_element_type_653) + mul_158 = torch.ops.aten.mul.Tensor(convert_element_type_653, sigmoid_19); sigmoid_19 = None + convert_element_type_654 = 
torch.ops.prims.convert_element_type.default(mul_158, torch.bfloat16); mul_158 = None + convert_element_type_655 = torch.ops.prims.convert_element_type.default(primals_182, torch.bfloat16); primals_182 = None + all_gather_into_tensor_179 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_655, 64, '0'); convert_element_type_655 = None + wait_tensor_179 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_179); all_gather_into_tensor_179 = None + permute_218 = torch.ops.aten.permute.default(wait_tensor_179, [1, 0]); wait_tensor_179 = None + mm_138 = torch.ops.aten.mm.default(view_673, permute_218) + view_677 = torch.ops.aten.view.default(mm_138, [2, 8192, 14336]); mm_138 = None + mul_159 = torch.ops.aten.mul.Tensor(convert_element_type_654, view_677) + view_679 = torch.ops.aten.view.default(mul_159, [16384, 14336]); mul_159 = None + mm_395 = torch.ops.aten.mm.default(permute_741, view_679); permute_741 = view_679 = None + convert_element_type_658 = torch.ops.prims.convert_element_type.default(primals_183, torch.bfloat16); primals_183 = None + all_gather_into_tensor_180 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_658, 64, '0'); convert_element_type_658 = None + wait_tensor_180 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_180); all_gather_into_tensor_180 = None + permute_219 = torch.ops.aten.permute.default(wait_tensor_180, [1, 0]); wait_tensor_180 = None + permute_743 = torch.ops.aten.permute.default(permute_219, [1, 0]); permute_219 = None + mm_396 = torch.ops.aten.mm.default(view_1383, permute_743); view_1383 = permute_743 = None + view_1384 = torch.ops.aten.view.default(mm_396, [2, 8192, 14336]); mm_396 = None + convert_element_type_1726 = torch.ops.prims.convert_element_type.default(mm_395, torch.float32); mm_395 = None + reduce_scatter_tensor_110 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1726, 'avg', 64, '0'); 
convert_element_type_1726 = None + wait_tensor_401 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_110); reduce_scatter_tensor_110 = None + mul_504 = torch.ops.aten.mul.Tensor(view_1384, convert_element_type_654); convert_element_type_654 = None + mul_505 = torch.ops.aten.mul.Tensor(view_1384, view_677); view_1384 = view_677 = None + view_1385 = torch.ops.aten.view.default(mul_504, [16384, 14336]); mul_504 = None + permute_745 = torch.ops.aten.permute.default(view_1385, [1, 0]) + mm_397 = torch.ops.aten.mm.default(permute_745, view_673); permute_745 = None + permute_747 = torch.ops.aten.permute.default(permute_218, [1, 0]); permute_218 = None + mm_398 = torch.ops.aten.mm.default(view_1385, permute_747); view_1385 = permute_747 = None + view_1386 = torch.ops.aten.view.default(mm_398, [2, 8192, 4096]); mm_398 = None + convert_element_type_1731 = torch.ops.prims.convert_element_type.default(mm_397, torch.float32); mm_397 = None + reduce_scatter_tensor_111 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1731, 'avg', 64, '0'); convert_element_type_1731 = None + wait_tensor_402 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_111); reduce_scatter_tensor_111 = None + convert_element_type_1732 = torch.ops.prims.convert_element_type.default(mul_505, torch.float32); mul_505 = None + neg_12 = torch.ops.aten.neg.default(convert_element_type_653) + exp_12 = torch.ops.aten.exp.default(neg_12); neg_12 = None + add_213 = torch.ops.aten.add.Tensor(exp_12, 1); exp_12 = None + reciprocal_12 = torch.ops.aten.reciprocal.default(add_213); add_213 = None + mul_506 = torch.ops.aten.mul.Tensor(reciprocal_12, 1); reciprocal_12 = None + mul_507 = torch.ops.aten.mul.Tensor(convert_element_type_1732, mul_506); convert_element_type_1732 = None + sub_37 = torch.ops.aten.sub.Tensor(1, mul_506); mul_506 = None + mul_508 = torch.ops.aten.mul.Tensor(convert_element_type_653, sub_37); convert_element_type_653 = sub_37 = None 
+ add_214 = torch.ops.aten.add.Tensor(mul_508, 1); mul_508 = None + mul_509 = torch.ops.aten.mul.Tensor(mul_507, add_214); mul_507 = add_214 = None + convert_element_type_1734 = torch.ops.prims.convert_element_type.default(mul_509, torch.bfloat16); mul_509 = None + view_1387 = torch.ops.aten.view.default(convert_element_type_1734, [16384, 14336]); convert_element_type_1734 = None + permute_749 = torch.ops.aten.permute.default(view_1387, [1, 0]) + mm_399 = torch.ops.aten.mm.default(permute_749, view_673); permute_749 = view_673 = None + convert_element_type_650 = torch.ops.prims.convert_element_type.default(primals_181, torch.bfloat16); primals_181 = None + all_gather_into_tensor_178 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_650, 64, '0'); convert_element_type_650 = None + wait_tensor_178 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_178); all_gather_into_tensor_178 = None + permute_217 = torch.ops.aten.permute.default(wait_tensor_178, [1, 0]); wait_tensor_178 = None + permute_751 = torch.ops.aten.permute.default(permute_217, [1, 0]); permute_217 = None + mm_400 = torch.ops.aten.mm.default(view_1387, permute_751); view_1387 = permute_751 = None + view_1388 = torch.ops.aten.view.default(mm_400, [2, 8192, 4096]); mm_400 = None + add_215 = torch.ops.aten.add.Tensor(view_1386, view_1388); view_1386 = view_1388 = None + convert_element_type_1739 = torch.ops.prims.convert_element_type.default(mm_399, torch.float32); mm_399 = None + reduce_scatter_tensor_112 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1739, 'avg', 64, '0'); convert_element_type_1739 = None + wait_tensor_403 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_112); reduce_scatter_tensor_112 = None + convert_element_type_1740 = torch.ops.prims.convert_element_type.default(add_215, torch.float32); add_215 = None + convert_element_type_1742 = 
torch.ops.prims.convert_element_type.default(wait_tensor_177, torch.float32); wait_tensor_177 = None + mul_510 = torch.ops.aten.mul.Tensor(convert_element_type_1740, convert_element_type_1742); convert_element_type_1742 = None + mul_512 = torch.ops.aten.mul.Tensor(mul_156, mul_510) + sum_75 = torch.ops.aten.sum.dim_IntList(mul_512, [2], True); mul_512 = None + div_25 = torch.ops.aten.div.Tensor(mul_156, 4096) + mul_513 = torch.ops.aten.mul.Tensor(div_25, sum_75); div_25 = sum_75 = None + sub_38 = torch.ops.aten.sub.Tensor(mul_510, mul_513); mul_510 = mul_513 = None + mul_514 = torch.ops.aten.mul.Tensor(sub_38, rsqrt_39); sub_38 = rsqrt_39 = None + mul_515 = torch.ops.aten.mul.Tensor(convert_element_type_1740, mul_156); convert_element_type_1740 = mul_156 = None + sum_76 = torch.ops.aten.sum.dim_IntList(mul_515, [0, 1]); mul_515 = None + convert_element_type_1743 = torch.ops.prims.convert_element_type.default(mul_514, torch.bfloat16); mul_514 = None + add_216 = torch.ops.aten.add.Tensor(add_212, convert_element_type_1743); add_212 = convert_element_type_1743 = None + convert_element_type_default_40 = torch.ops.prims.convert_element_type.default(sum_76, torch.float32); sum_76 = None + reduce_scatter_tensor_113 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_40, 'avg', 64, '0'); convert_element_type_default_40 = None + wait_tensor_404 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_113); reduce_scatter_tensor_113 = None + view_1389 = torch.ops.aten.view.default(add_216, [16384, 4096]) + permute_753 = torch.ops.aten.permute.default(view_1389, [1, 0]) + mm_401 = torch.ops.aten.mm.default(permute_753, view_669); permute_753 = view_669 = None + permute_755 = torch.ops.aten.permute.default(permute_216, [1, 0]); permute_216 = None + mm_402 = torch.ops.aten.mm.default(view_1389, permute_755); view_1389 = permute_755 = None + view_1390 = torch.ops.aten.view.default(mm_402, [2, 8192, 4096]); mm_402 = None + 
convert_element_type_1750 = torch.ops.prims.convert_element_type.default(mm_401, torch.float32); mm_401 = None + reduce_scatter_tensor_114 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1750, 'avg', 64, '0'); convert_element_type_1750 = None + wait_tensor_405 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_114); reduce_scatter_tensor_114 = None + view_1391 = torch.ops.aten.view.default(view_1390, [2, 8192, 32, 128]); view_1390 = None + permute_757 = torch.ops.aten.permute.default(view_1391, [0, 2, 1, 3]); view_1391 = None + convert_element_type_628 = torch.ops.prims.convert_element_type.default(primals_175, torch.bfloat16); primals_175 = None + all_gather_into_tensor_172 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_628, 64, '0'); convert_element_type_628 = None + wait_tensor_172 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_172); all_gather_into_tensor_172 = None + convert_element_type_629 = torch.ops.prims.convert_element_type.default(add_75, torch.float32); add_75 = None + pow_39 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_629, 2) + mean_38 = torch.ops.aten.mean.dim(pow_39, [2], True); pow_39 = None + add_76 = torch.ops.aten.add.Scalar(mean_38, 1e-05); mean_38 = None + rsqrt_38 = torch.ops.aten.rsqrt.default(add_76); add_76 = None + mul_152 = torch.ops.aten.mul.Tensor(convert_element_type_629, rsqrt_38); convert_element_type_629 = None + mul_153 = torch.ops.aten.mul.Tensor(mul_152, wait_tensor_172) + convert_element_type_630 = torch.ops.prims.convert_element_type.default(mul_153, torch.bfloat16); mul_153 = None + view_649 = torch.ops.aten.view.default(convert_element_type_630, [16384, 4096]); convert_element_type_630 = None + view_650 = torch.ops.aten.view.default(mm_133, [2, 8192, 4096]); mm_133 = None + convert_element_type_634 = torch.ops.prims.convert_element_type.default(primals_177, torch.bfloat16); primals_177 = None + 
all_gather_into_tensor_174 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_634, 64, '0'); convert_element_type_634 = None + wait_tensor_174 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_174); all_gather_into_tensor_174 = None + permute_210 = torch.ops.aten.permute.default(wait_tensor_174, [1, 0]); wait_tensor_174 = None + mm_134 = torch.ops.aten.mm.default(view_649, permute_210) + view_653 = torch.ops.aten.view.default(mm_134, [2, 8192, 1024]); mm_134 = None + view_656 = torch.ops.aten.view.default(mm_135, [2, 8192, 1024]); mm_135 = None + view_657 = torch.ops.aten.view.default(view_650, [2, 8192, -1, 128]); view_650 = None + view_658 = torch.ops.aten.view.default(view_653, [2, 8192, -1, 128]); view_653 = None + view_659 = torch.ops.aten.view.default(view_656, [2, 8192, -1, 128]); view_656 = None + convert_element_type_640 = torch.ops.prims.convert_element_type.default(view_657, torch.float32); view_657 = None + view_660 = torch.ops.aten.view.default(convert_element_type_640, [2, 8192, 32, -1, 2]); convert_element_type_640 = None + view_as_complex_38 = torch.ops.aten.view_as_complex.default(view_660); view_660 = None + convert_element_type_641 = torch.ops.prims.convert_element_type.default(view_658, torch.float32); view_658 = None + view_661 = torch.ops.aten.view.default(convert_element_type_641, [2, 8192, 8, -1, 2]); convert_element_type_641 = None + view_as_complex_39 = torch.ops.aten.view_as_complex.default(view_661); view_661 = None + mul_154 = torch.ops.aten.mul.Tensor(view_as_complex_38, view_16); view_as_complex_38 = None + view_as_real_38 = torch.ops.aten.view_as_real.default(mul_154); mul_154 = None + view_663 = torch.ops.aten.view.default(view_as_real_38, [2, 8192, 32, 128]); view_as_real_38 = None + mul_155 = torch.ops.aten.mul.Tensor(view_as_complex_39, view_16); view_as_complex_39 = None + view_as_real_39 = torch.ops.aten.view_as_real.default(mul_155); mul_155 = None + view_664 = 
torch.ops.aten.view.default(view_as_real_39, [2, 8192, 8, 128]); view_as_real_39 = None + convert_element_type_642 = torch.ops.prims.convert_element_type.default(view_663, torch.bfloat16); view_663 = None + convert_element_type_643 = torch.ops.prims.convert_element_type.default(view_664, torch.bfloat16); view_664 = None + unsqueeze_38 = torch.ops.aten.unsqueeze.default(convert_element_type_643, 3); convert_element_type_643 = None + expand_38 = torch.ops.aten.expand.default(unsqueeze_38, [2, 8192, 8, 4, 128]); unsqueeze_38 = None + clone_38 = torch.ops.aten.clone.default(expand_38, memory_format = torch.contiguous_format); expand_38 = None + view_665 = torch.ops.aten.view.default(clone_38, [2, 8192, 32, 128]); clone_38 = None + unsqueeze_39 = torch.ops.aten.unsqueeze.default(view_659, 3); view_659 = None + expand_39 = torch.ops.aten.expand.default(unsqueeze_39, [2, 8192, 8, 4, 128]); unsqueeze_39 = None + clone_39 = torch.ops.aten.clone.default(expand_39, memory_format = torch.contiguous_format); expand_39 = None + view_666 = torch.ops.aten.view.default(clone_39, [2, 8192, 32, 128]); clone_39 = None + permute_212 = torch.ops.aten.permute.default(convert_element_type_642, [0, 2, 1, 3]); convert_element_type_642 = None + permute_213 = torch.ops.aten.permute.default(view_665, [0, 2, 1, 3]); view_665 = None + permute_214 = torch.ops.aten.permute.default(view_666, [0, 2, 1, 3]); view_666 = None + _scaled_dot_product_cudnn_attention_backward_12 = torch.ops.aten._scaled_dot_product_cudnn_attention_backward.default(permute_757, permute_212, permute_213, permute_214, getitem_171, getitem_172, getitem_177, getitem_178, None, None, None, 8192, 8192, 0.0, True); permute_757 = permute_212 = permute_213 = permute_214 = getitem_171 = getitem_172 = getitem_177 = getitem_178 = None + getitem_324 = _scaled_dot_product_cudnn_attention_backward_12[0] + getitem_325 = _scaled_dot_product_cudnn_attention_backward_12[1] + getitem_326 = _scaled_dot_product_cudnn_attention_backward_12[2]; 
_scaled_dot_product_cudnn_attention_backward_12 = None + permute_758 = torch.ops.aten.permute.default(getitem_326, [0, 2, 1, 3]); getitem_326 = None + permute_759 = torch.ops.aten.permute.default(getitem_325, [0, 2, 1, 3]); getitem_325 = None + permute_760 = torch.ops.aten.permute.default(getitem_324, [0, 2, 1, 3]); getitem_324 = None + view_1392 = torch.ops.aten.view.default(permute_758, [2, 8192, 8, 4, 128]); permute_758 = None + sum_77 = torch.ops.aten.sum.dim_IntList(view_1392, [3], True); view_1392 = None + squeeze_24 = torch.ops.aten.squeeze.dim(sum_77, 3); sum_77 = None + view_1393 = torch.ops.aten.view.default(permute_759, [2, 8192, 8, 4, 128]); permute_759 = None + sum_78 = torch.ops.aten.sum.dim_IntList(view_1393, [3], True); view_1393 = None + squeeze_25 = torch.ops.aten.squeeze.dim(sum_78, 3); sum_78 = None + convert_element_type_1751 = torch.ops.prims.convert_element_type.default(squeeze_25, torch.float32); squeeze_25 = None + convert_element_type_1752 = torch.ops.prims.convert_element_type.default(permute_760, torch.float32); permute_760 = None + view_1394 = torch.ops.aten.view.default(convert_element_type_1751, [2, 8192, 8, 64, 2]); convert_element_type_1751 = None + view_as_complex_88 = torch.ops.aten.view_as_complex.default(view_1394); view_1394 = None + mul_516 = torch.ops.aten.mul.Tensor(view_as_complex_88, _conj); view_as_complex_88 = None + view_1395 = torch.ops.aten.view.default(convert_element_type_1752, [2, 8192, 32, 64, 2]); convert_element_type_1752 = None + view_as_complex_89 = torch.ops.aten.view_as_complex.default(view_1395); view_1395 = None + mul_517 = torch.ops.aten.mul.Tensor(view_as_complex_89, _conj); view_as_complex_89 = None + view_as_real_88 = torch.ops.aten.view_as_real.default(mul_516); mul_516 = None + view_1396 = torch.ops.aten.view.default(view_as_real_88, [2, 8192, 8, 128]); view_as_real_88 = None + convert_element_type_1753 = torch.ops.prims.convert_element_type.default(view_1396, torch.bfloat16); view_1396 = None + 
view_as_real_89 = torch.ops.aten.view_as_real.default(mul_517); mul_517 = None + view_1397 = torch.ops.aten.view.default(view_as_real_89, [2, 8192, 32, 128]); view_as_real_89 = None + convert_element_type_1754 = torch.ops.prims.convert_element_type.default(view_1397, torch.bfloat16); view_1397 = None + view_1398 = torch.ops.aten.view.default(squeeze_24, [2, 8192, 1024]); squeeze_24 = None + view_1399 = torch.ops.aten.view.default(convert_element_type_1753, [2, 8192, 1024]); convert_element_type_1753 = None + view_1400 = torch.ops.aten.view.default(convert_element_type_1754, [2, 8192, 4096]); convert_element_type_1754 = None + view_1401 = torch.ops.aten.view.default(view_1398, [16384, 1024]); view_1398 = None + permute_761 = torch.ops.aten.permute.default(view_1401, [1, 0]) + mm_403 = torch.ops.aten.mm.default(permute_761, view_649); permute_761 = None + convert_element_type_637 = torch.ops.prims.convert_element_type.default(primals_178, torch.bfloat16); primals_178 = None + all_gather_into_tensor_175 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_637, 64, '0'); convert_element_type_637 = None + wait_tensor_175 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_175); all_gather_into_tensor_175 = None + permute_211 = torch.ops.aten.permute.default(wait_tensor_175, [1, 0]); wait_tensor_175 = None + permute_763 = torch.ops.aten.permute.default(permute_211, [1, 0]); permute_211 = None + mm_404 = torch.ops.aten.mm.default(view_1401, permute_763); view_1401 = permute_763 = None + view_1402 = torch.ops.aten.view.default(mm_404, [2, 8192, 4096]); mm_404 = None + convert_element_type_1759 = torch.ops.prims.convert_element_type.default(mm_403, torch.float32); mm_403 = None + reduce_scatter_tensor_115 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1759, 'avg', 64, '0'); convert_element_type_1759 = None + wait_tensor_406 = 
torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_115); reduce_scatter_tensor_115 = None + view_1403 = torch.ops.aten.view.default(view_1399, [16384, 1024]); view_1399 = None + permute_765 = torch.ops.aten.permute.default(view_1403, [1, 0]) + mm_405 = torch.ops.aten.mm.default(permute_765, view_649); permute_765 = None + permute_767 = torch.ops.aten.permute.default(permute_210, [1, 0]); permute_210 = None + mm_406 = torch.ops.aten.mm.default(view_1403, permute_767); view_1403 = permute_767 = None + view_1404 = torch.ops.aten.view.default(mm_406, [2, 8192, 4096]); mm_406 = None + add_217 = torch.ops.aten.add.Tensor(view_1402, view_1404); view_1402 = view_1404 = None + convert_element_type_1764 = torch.ops.prims.convert_element_type.default(mm_405, torch.float32); mm_405 = None + reduce_scatter_tensor_116 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1764, 'avg', 64, '0'); convert_element_type_1764 = None + wait_tensor_407 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_116); reduce_scatter_tensor_116 = None + view_1405 = torch.ops.aten.view.default(view_1400, [16384, 4096]); view_1400 = None + permute_769 = torch.ops.aten.permute.default(view_1405, [1, 0]) + mm_407 = torch.ops.aten.mm.default(permute_769, view_649); permute_769 = view_649 = None + convert_element_type_631 = torch.ops.prims.convert_element_type.default(primals_176, torch.bfloat16); primals_176 = None + all_gather_into_tensor_173 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_631, 64, '0'); convert_element_type_631 = None + wait_tensor_173 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_173); all_gather_into_tensor_173 = None + permute_209 = torch.ops.aten.permute.default(wait_tensor_173, [1, 0]); wait_tensor_173 = None + permute_771 = torch.ops.aten.permute.default(permute_209, [1, 0]); permute_209 = None + mm_408 = torch.ops.aten.mm.default(view_1405, permute_771); 
view_1405 = permute_771 = None + view_1406 = torch.ops.aten.view.default(mm_408, [2, 8192, 4096]); mm_408 = None + add_218 = torch.ops.aten.add.Tensor(add_217, view_1406); add_217 = view_1406 = None + convert_element_type_1769 = torch.ops.prims.convert_element_type.default(mm_407, torch.float32); mm_407 = None + reduce_scatter_tensor_117 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1769, 'avg', 64, '0'); convert_element_type_1769 = None + wait_tensor_408 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_117); reduce_scatter_tensor_117 = None + convert_element_type_1770 = torch.ops.prims.convert_element_type.default(add_218, torch.float32); add_218 = None + convert_element_type_1772 = torch.ops.prims.convert_element_type.default(wait_tensor_172, torch.float32); wait_tensor_172 = None + mul_518 = torch.ops.aten.mul.Tensor(convert_element_type_1770, convert_element_type_1772); convert_element_type_1772 = None + mul_520 = torch.ops.aten.mul.Tensor(mul_152, mul_518) + sum_79 = torch.ops.aten.sum.dim_IntList(mul_520, [2], True); mul_520 = None + div_26 = torch.ops.aten.div.Tensor(mul_152, 4096) + mul_521 = torch.ops.aten.mul.Tensor(div_26, sum_79); div_26 = sum_79 = None + sub_39 = torch.ops.aten.sub.Tensor(mul_518, mul_521); mul_518 = mul_521 = None + mul_522 = torch.ops.aten.mul.Tensor(sub_39, rsqrt_38); sub_39 = rsqrt_38 = None + mul_523 = torch.ops.aten.mul.Tensor(convert_element_type_1770, mul_152); convert_element_type_1770 = mul_152 = None + sum_80 = torch.ops.aten.sum.dim_IntList(mul_523, [0, 1]); mul_523 = None + convert_element_type_1773 = torch.ops.prims.convert_element_type.default(mul_522, torch.bfloat16); mul_522 = None + add_219 = torch.ops.aten.add.Tensor(add_216, convert_element_type_1773); add_216 = convert_element_type_1773 = None + convert_element_type_default_39 = torch.ops.prims.convert_element_type.default(sum_80, torch.float32); sum_80 = None + reduce_scatter_tensor_118 = 
torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_39, 'avg', 64, '0'); convert_element_type_default_39 = None + wait_tensor_409 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_118); reduce_scatter_tensor_118 = None + view_1407 = torch.ops.aten.view.default(add_219, [16384, 4096]) + permute_773 = torch.ops.aten.permute.default(view_1407, [1, 0]) + permute_204 = torch.ops.aten.permute.default(getitem_162, [0, 2, 1, 3]) + view_633 = torch.ops.aten.view.default(permute_204, [2, 8192, -1]); permute_204 = None + convert_element_type_611 = torch.ops.prims.convert_element_type.default(primals_170, torch.bfloat16); primals_170 = None + all_gather_into_tensor_167 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_611, 64, '0'); convert_element_type_611 = None + wait_tensor_167 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_167); all_gather_into_tensor_167 = None + permute_205 = torch.ops.aten.permute.default(wait_tensor_167, [1, 0]); wait_tensor_167 = None + view_635 = torch.ops.aten.view.default(view_633, [16384, 4096]); view_633 = None + mm_129 = torch.ops.aten.mm.default(view_635, permute_205) + view_636 = torch.ops.aten.view.default(mm_129, [2, 8192, 4096]); mm_129 = None + add_73 = torch.ops.aten.add.Tensor(add_71, view_636); view_636 = None + convert_element_type_614 = torch.ops.prims.convert_element_type.default(primals_171, torch.bfloat16); primals_171 = None + all_gather_into_tensor_168 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_614, 64, '0'); convert_element_type_614 = None + wait_tensor_168 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_168); all_gather_into_tensor_168 = None + convert_element_type_615 = torch.ops.prims.convert_element_type.default(add_73, torch.float32); add_73 = None + pow_38 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_615, 2) + mean_37 = 
torch.ops.aten.mean.dim(pow_38, [2], True); pow_38 = None + add_74 = torch.ops.aten.add.Scalar(mean_37, 1e-05); mean_37 = None + rsqrt_37 = torch.ops.aten.rsqrt.default(add_74); add_74 = None + mul_148 = torch.ops.aten.mul.Tensor(convert_element_type_615, rsqrt_37); convert_element_type_615 = None + mul_149 = torch.ops.aten.mul.Tensor(mul_148, wait_tensor_168) + convert_element_type_616 = torch.ops.prims.convert_element_type.default(mul_149, torch.bfloat16); mul_149 = None + view_639 = torch.ops.aten.view.default(convert_element_type_616, [16384, 4096]); convert_element_type_616 = None + view_640 = torch.ops.aten.view.default(mm_130, [2, 8192, 14336]); mm_130 = None + convert_element_type_620 = torch.ops.prims.convert_element_type.default(view_640, torch.float32); view_640 = None + sigmoid_18 = torch.ops.aten.sigmoid.default(convert_element_type_620) + mul_150 = torch.ops.aten.mul.Tensor(convert_element_type_620, sigmoid_18); sigmoid_18 = None + convert_element_type_621 = torch.ops.prims.convert_element_type.default(mul_150, torch.bfloat16); mul_150 = None + convert_element_type_622 = torch.ops.prims.convert_element_type.default(primals_173, torch.bfloat16); primals_173 = None + all_gather_into_tensor_170 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_622, 64, '0'); convert_element_type_622 = None + wait_tensor_170 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_170); all_gather_into_tensor_170 = None + permute_207 = torch.ops.aten.permute.default(wait_tensor_170, [1, 0]); wait_tensor_170 = None + mm_131 = torch.ops.aten.mm.default(view_639, permute_207) + view_643 = torch.ops.aten.view.default(mm_131, [2, 8192, 14336]); mm_131 = None + mul_151 = torch.ops.aten.mul.Tensor(convert_element_type_621, view_643) + view_645 = torch.ops.aten.view.default(mul_151, [16384, 14336]); mul_151 = None + mm_409 = torch.ops.aten.mm.default(permute_773, view_645); permute_773 = view_645 = None + convert_element_type_625 = 
torch.ops.prims.convert_element_type.default(primals_174, torch.bfloat16); primals_174 = None + all_gather_into_tensor_171 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_625, 64, '0'); convert_element_type_625 = None + wait_tensor_171 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_171); all_gather_into_tensor_171 = None + permute_208 = torch.ops.aten.permute.default(wait_tensor_171, [1, 0]); wait_tensor_171 = None + permute_775 = torch.ops.aten.permute.default(permute_208, [1, 0]); permute_208 = None + mm_410 = torch.ops.aten.mm.default(view_1407, permute_775); view_1407 = permute_775 = None + view_1408 = torch.ops.aten.view.default(mm_410, [2, 8192, 14336]); mm_410 = None + convert_element_type_1780 = torch.ops.prims.convert_element_type.default(mm_409, torch.float32); mm_409 = None + reduce_scatter_tensor_119 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1780, 'avg', 64, '0'); convert_element_type_1780 = None + wait_tensor_410 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_119); reduce_scatter_tensor_119 = None + mul_524 = torch.ops.aten.mul.Tensor(view_1408, convert_element_type_621); convert_element_type_621 = None + mul_525 = torch.ops.aten.mul.Tensor(view_1408, view_643); view_1408 = view_643 = None + view_1409 = torch.ops.aten.view.default(mul_524, [16384, 14336]); mul_524 = None + permute_777 = torch.ops.aten.permute.default(view_1409, [1, 0]) + mm_411 = torch.ops.aten.mm.default(permute_777, view_639); permute_777 = None + permute_779 = torch.ops.aten.permute.default(permute_207, [1, 0]); permute_207 = None + mm_412 = torch.ops.aten.mm.default(view_1409, permute_779); view_1409 = permute_779 = None + view_1410 = torch.ops.aten.view.default(mm_412, [2, 8192, 4096]); mm_412 = None + convert_element_type_1785 = torch.ops.prims.convert_element_type.default(mm_411, torch.float32); mm_411 = None + reduce_scatter_tensor_120 = 
torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1785, 'avg', 64, '0'); convert_element_type_1785 = None + wait_tensor_411 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_120); reduce_scatter_tensor_120 = None + convert_element_type_1786 = torch.ops.prims.convert_element_type.default(mul_525, torch.float32); mul_525 = None + neg_13 = torch.ops.aten.neg.default(convert_element_type_620) + exp_13 = torch.ops.aten.exp.default(neg_13); neg_13 = None + add_220 = torch.ops.aten.add.Tensor(exp_13, 1); exp_13 = None + reciprocal_13 = torch.ops.aten.reciprocal.default(add_220); add_220 = None + mul_526 = torch.ops.aten.mul.Tensor(reciprocal_13, 1); reciprocal_13 = None + mul_527 = torch.ops.aten.mul.Tensor(convert_element_type_1786, mul_526); convert_element_type_1786 = None + sub_40 = torch.ops.aten.sub.Tensor(1, mul_526); mul_526 = None + mul_528 = torch.ops.aten.mul.Tensor(convert_element_type_620, sub_40); convert_element_type_620 = sub_40 = None + add_221 = torch.ops.aten.add.Tensor(mul_528, 1); mul_528 = None + mul_529 = torch.ops.aten.mul.Tensor(mul_527, add_221); mul_527 = add_221 = None + convert_element_type_1788 = torch.ops.prims.convert_element_type.default(mul_529, torch.bfloat16); mul_529 = None + view_1411 = torch.ops.aten.view.default(convert_element_type_1788, [16384, 14336]); convert_element_type_1788 = None + permute_781 = torch.ops.aten.permute.default(view_1411, [1, 0]) + mm_413 = torch.ops.aten.mm.default(permute_781, view_639); permute_781 = view_639 = None + convert_element_type_617 = torch.ops.prims.convert_element_type.default(primals_172, torch.bfloat16); primals_172 = None + all_gather_into_tensor_169 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_617, 64, '0'); convert_element_type_617 = None + wait_tensor_169 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_169); all_gather_into_tensor_169 = None + permute_206 = 
torch.ops.aten.permute.default(wait_tensor_169, [1, 0]); wait_tensor_169 = None + permute_783 = torch.ops.aten.permute.default(permute_206, [1, 0]); permute_206 = None + mm_414 = torch.ops.aten.mm.default(view_1411, permute_783); view_1411 = permute_783 = None + view_1412 = torch.ops.aten.view.default(mm_414, [2, 8192, 4096]); mm_414 = None + add_222 = torch.ops.aten.add.Tensor(view_1410, view_1412); view_1410 = view_1412 = None + convert_element_type_1793 = torch.ops.prims.convert_element_type.default(mm_413, torch.float32); mm_413 = None + reduce_scatter_tensor_121 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1793, 'avg', 64, '0'); convert_element_type_1793 = None + wait_tensor_412 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_121); reduce_scatter_tensor_121 = None + convert_element_type_1794 = torch.ops.prims.convert_element_type.default(add_222, torch.float32); add_222 = None + convert_element_type_1796 = torch.ops.prims.convert_element_type.default(wait_tensor_168, torch.float32); wait_tensor_168 = None + mul_530 = torch.ops.aten.mul.Tensor(convert_element_type_1794, convert_element_type_1796); convert_element_type_1796 = None + mul_532 = torch.ops.aten.mul.Tensor(mul_148, mul_530) + sum_81 = torch.ops.aten.sum.dim_IntList(mul_532, [2], True); mul_532 = None + div_27 = torch.ops.aten.div.Tensor(mul_148, 4096) + mul_533 = torch.ops.aten.mul.Tensor(div_27, sum_81); div_27 = sum_81 = None + sub_41 = torch.ops.aten.sub.Tensor(mul_530, mul_533); mul_530 = mul_533 = None + mul_534 = torch.ops.aten.mul.Tensor(sub_41, rsqrt_37); sub_41 = rsqrt_37 = None + mul_535 = torch.ops.aten.mul.Tensor(convert_element_type_1794, mul_148); convert_element_type_1794 = mul_148 = None + sum_82 = torch.ops.aten.sum.dim_IntList(mul_535, [0, 1]); mul_535 = None + convert_element_type_1797 = torch.ops.prims.convert_element_type.default(mul_534, torch.bfloat16); mul_534 = None + add_223 = torch.ops.aten.add.Tensor(add_219, 
convert_element_type_1797); add_219 = convert_element_type_1797 = None + convert_element_type_default_38 = torch.ops.prims.convert_element_type.default(sum_82, torch.float32); sum_82 = None + reduce_scatter_tensor_122 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_38, 'avg', 64, '0'); convert_element_type_default_38 = None + wait_tensor_413 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_122); reduce_scatter_tensor_122 = None + view_1413 = torch.ops.aten.view.default(add_223, [16384, 4096]) + permute_785 = torch.ops.aten.permute.default(view_1413, [1, 0]) + mm_415 = torch.ops.aten.mm.default(permute_785, view_635); permute_785 = view_635 = None + permute_787 = torch.ops.aten.permute.default(permute_205, [1, 0]); permute_205 = None + mm_416 = torch.ops.aten.mm.default(view_1413, permute_787); view_1413 = permute_787 = None + view_1414 = torch.ops.aten.view.default(mm_416, [2, 8192, 4096]); mm_416 = None + convert_element_type_1804 = torch.ops.prims.convert_element_type.default(mm_415, torch.float32); mm_415 = None + reduce_scatter_tensor_123 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1804, 'avg', 64, '0'); convert_element_type_1804 = None + wait_tensor_414 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_123); reduce_scatter_tensor_123 = None + view_1415 = torch.ops.aten.view.default(view_1414, [2, 8192, 32, 128]); view_1414 = None + permute_789 = torch.ops.aten.permute.default(view_1415, [0, 2, 1, 3]); view_1415 = None + convert_element_type_595 = torch.ops.prims.convert_element_type.default(primals_166, torch.bfloat16); primals_166 = None + all_gather_into_tensor_163 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_595, 64, '0'); convert_element_type_595 = None + wait_tensor_163 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_163); all_gather_into_tensor_163 = None + 
convert_element_type_596 = torch.ops.prims.convert_element_type.default(add_71, torch.float32); add_71 = None + pow_37 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_596, 2) + mean_36 = torch.ops.aten.mean.dim(pow_37, [2], True); pow_37 = None + add_72 = torch.ops.aten.add.Scalar(mean_36, 1e-05); mean_36 = None + rsqrt_36 = torch.ops.aten.rsqrt.default(add_72); add_72 = None + mul_144 = torch.ops.aten.mul.Tensor(convert_element_type_596, rsqrt_36); convert_element_type_596 = None + mul_145 = torch.ops.aten.mul.Tensor(mul_144, wait_tensor_163) + convert_element_type_597 = torch.ops.prims.convert_element_type.default(mul_145, torch.bfloat16); mul_145 = None + view_615 = torch.ops.aten.view.default(convert_element_type_597, [16384, 4096]); convert_element_type_597 = None + view_616 = torch.ops.aten.view.default(mm_126, [2, 8192, 4096]); mm_126 = None + convert_element_type_601 = torch.ops.prims.convert_element_type.default(primals_168, torch.bfloat16); primals_168 = None + all_gather_into_tensor_165 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_601, 64, '0'); convert_element_type_601 = None + wait_tensor_165 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_165); all_gather_into_tensor_165 = None + permute_199 = torch.ops.aten.permute.default(wait_tensor_165, [1, 0]); wait_tensor_165 = None + mm_127 = torch.ops.aten.mm.default(view_615, permute_199) + view_619 = torch.ops.aten.view.default(mm_127, [2, 8192, 1024]); mm_127 = None + view_622 = torch.ops.aten.view.default(mm_128, [2, 8192, 1024]); mm_128 = None + view_623 = torch.ops.aten.view.default(view_616, [2, 8192, -1, 128]); view_616 = None + view_624 = torch.ops.aten.view.default(view_619, [2, 8192, -1, 128]); view_619 = None + view_625 = torch.ops.aten.view.default(view_622, [2, 8192, -1, 128]); view_622 = None + convert_element_type_607 = torch.ops.prims.convert_element_type.default(view_623, torch.float32); view_623 = None + view_626 = 
torch.ops.aten.view.default(convert_element_type_607, [2, 8192, 32, -1, 2]); convert_element_type_607 = None + view_as_complex_36 = torch.ops.aten.view_as_complex.default(view_626); view_626 = None + convert_element_type_608 = torch.ops.prims.convert_element_type.default(view_624, torch.float32); view_624 = None + view_627 = torch.ops.aten.view.default(convert_element_type_608, [2, 8192, 8, -1, 2]); convert_element_type_608 = None + view_as_complex_37 = torch.ops.aten.view_as_complex.default(view_627); view_627 = None + mul_146 = torch.ops.aten.mul.Tensor(view_as_complex_36, view_16); view_as_complex_36 = None + view_as_real_36 = torch.ops.aten.view_as_real.default(mul_146); mul_146 = None + view_629 = torch.ops.aten.view.default(view_as_real_36, [2, 8192, 32, 128]); view_as_real_36 = None + mul_147 = torch.ops.aten.mul.Tensor(view_as_complex_37, view_16); view_as_complex_37 = None + view_as_real_37 = torch.ops.aten.view_as_real.default(mul_147); mul_147 = None + view_630 = torch.ops.aten.view.default(view_as_real_37, [2, 8192, 8, 128]); view_as_real_37 = None + convert_element_type_609 = torch.ops.prims.convert_element_type.default(view_629, torch.bfloat16); view_629 = None + convert_element_type_610 = torch.ops.prims.convert_element_type.default(view_630, torch.bfloat16); view_630 = None + unsqueeze_36 = torch.ops.aten.unsqueeze.default(convert_element_type_610, 3); convert_element_type_610 = None + expand_36 = torch.ops.aten.expand.default(unsqueeze_36, [2, 8192, 8, 4, 128]); unsqueeze_36 = None + clone_36 = torch.ops.aten.clone.default(expand_36, memory_format = torch.contiguous_format); expand_36 = None + view_631 = torch.ops.aten.view.default(clone_36, [2, 8192, 32, 128]); clone_36 = None + unsqueeze_37 = torch.ops.aten.unsqueeze.default(view_625, 3); view_625 = None + expand_37 = torch.ops.aten.expand.default(unsqueeze_37, [2, 8192, 8, 4, 128]); unsqueeze_37 = None + clone_37 = torch.ops.aten.clone.default(expand_37, memory_format = torch.contiguous_format); 
expand_37 = None + view_632 = torch.ops.aten.view.default(clone_37, [2, 8192, 32, 128]); clone_37 = None + permute_201 = torch.ops.aten.permute.default(convert_element_type_609, [0, 2, 1, 3]); convert_element_type_609 = None + permute_202 = torch.ops.aten.permute.default(view_631, [0, 2, 1, 3]); view_631 = None + permute_203 = torch.ops.aten.permute.default(view_632, [0, 2, 1, 3]); view_632 = None + _scaled_dot_product_cudnn_attention_backward_13 = torch.ops.aten._scaled_dot_product_cudnn_attention_backward.default(permute_789, permute_201, permute_202, permute_203, getitem_162, getitem_163, getitem_168, getitem_169, None, None, None, 8192, 8192, 0.0, True); permute_789 = permute_201 = permute_202 = permute_203 = getitem_162 = getitem_163 = getitem_168 = getitem_169 = None + getitem_327 = _scaled_dot_product_cudnn_attention_backward_13[0] + getitem_328 = _scaled_dot_product_cudnn_attention_backward_13[1] + getitem_329 = _scaled_dot_product_cudnn_attention_backward_13[2]; _scaled_dot_product_cudnn_attention_backward_13 = None + permute_790 = torch.ops.aten.permute.default(getitem_329, [0, 2, 1, 3]); getitem_329 = None + permute_791 = torch.ops.aten.permute.default(getitem_328, [0, 2, 1, 3]); getitem_328 = None + permute_792 = torch.ops.aten.permute.default(getitem_327, [0, 2, 1, 3]); getitem_327 = None + view_1416 = torch.ops.aten.view.default(permute_790, [2, 8192, 8, 4, 128]); permute_790 = None + sum_83 = torch.ops.aten.sum.dim_IntList(view_1416, [3], True); view_1416 = None + squeeze_26 = torch.ops.aten.squeeze.dim(sum_83, 3); sum_83 = None + view_1417 = torch.ops.aten.view.default(permute_791, [2, 8192, 8, 4, 128]); permute_791 = None + sum_84 = torch.ops.aten.sum.dim_IntList(view_1417, [3], True); view_1417 = None + squeeze_27 = torch.ops.aten.squeeze.dim(sum_84, 3); sum_84 = None + convert_element_type_1805 = torch.ops.prims.convert_element_type.default(squeeze_27, torch.float32); squeeze_27 = None + convert_element_type_1806 = 
torch.ops.prims.convert_element_type.default(permute_792, torch.float32); permute_792 = None + view_1418 = torch.ops.aten.view.default(convert_element_type_1805, [2, 8192, 8, 64, 2]); convert_element_type_1805 = None + view_as_complex_90 = torch.ops.aten.view_as_complex.default(view_1418); view_1418 = None + mul_536 = torch.ops.aten.mul.Tensor(view_as_complex_90, _conj); view_as_complex_90 = None + view_1419 = torch.ops.aten.view.default(convert_element_type_1806, [2, 8192, 32, 64, 2]); convert_element_type_1806 = None + view_as_complex_91 = torch.ops.aten.view_as_complex.default(view_1419); view_1419 = None + mul_537 = torch.ops.aten.mul.Tensor(view_as_complex_91, _conj); view_as_complex_91 = None + view_as_real_90 = torch.ops.aten.view_as_real.default(mul_536); mul_536 = None + view_1420 = torch.ops.aten.view.default(view_as_real_90, [2, 8192, 8, 128]); view_as_real_90 = None + convert_element_type_1807 = torch.ops.prims.convert_element_type.default(view_1420, torch.bfloat16); view_1420 = None + view_as_real_91 = torch.ops.aten.view_as_real.default(mul_537); mul_537 = None + view_1421 = torch.ops.aten.view.default(view_as_real_91, [2, 8192, 32, 128]); view_as_real_91 = None + convert_element_type_1808 = torch.ops.prims.convert_element_type.default(view_1421, torch.bfloat16); view_1421 = None + view_1422 = torch.ops.aten.view.default(squeeze_26, [2, 8192, 1024]); squeeze_26 = None + view_1423 = torch.ops.aten.view.default(convert_element_type_1807, [2, 8192, 1024]); convert_element_type_1807 = None + view_1424 = torch.ops.aten.view.default(convert_element_type_1808, [2, 8192, 4096]); convert_element_type_1808 = None + view_1425 = torch.ops.aten.view.default(view_1422, [16384, 1024]); view_1422 = None + permute_793 = torch.ops.aten.permute.default(view_1425, [1, 0]) + mm_417 = torch.ops.aten.mm.default(permute_793, view_615); permute_793 = None + convert_element_type_604 = torch.ops.prims.convert_element_type.default(primals_169, torch.bfloat16); primals_169 = None 
+ all_gather_into_tensor_166 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_604, 64, '0'); convert_element_type_604 = None + wait_tensor_166 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_166); all_gather_into_tensor_166 = None + permute_200 = torch.ops.aten.permute.default(wait_tensor_166, [1, 0]); wait_tensor_166 = None + permute_795 = torch.ops.aten.permute.default(permute_200, [1, 0]); permute_200 = None + mm_418 = torch.ops.aten.mm.default(view_1425, permute_795); view_1425 = permute_795 = None + view_1426 = torch.ops.aten.view.default(mm_418, [2, 8192, 4096]); mm_418 = None + convert_element_type_1813 = torch.ops.prims.convert_element_type.default(mm_417, torch.float32); mm_417 = None + reduce_scatter_tensor_124 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1813, 'avg', 64, '0'); convert_element_type_1813 = None + wait_tensor_415 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_124); reduce_scatter_tensor_124 = None + view_1427 = torch.ops.aten.view.default(view_1423, [16384, 1024]); view_1423 = None + permute_797 = torch.ops.aten.permute.default(view_1427, [1, 0]) + mm_419 = torch.ops.aten.mm.default(permute_797, view_615); permute_797 = None + permute_799 = torch.ops.aten.permute.default(permute_199, [1, 0]); permute_199 = None + mm_420 = torch.ops.aten.mm.default(view_1427, permute_799); view_1427 = permute_799 = None + view_1428 = torch.ops.aten.view.default(mm_420, [2, 8192, 4096]); mm_420 = None + add_224 = torch.ops.aten.add.Tensor(view_1426, view_1428); view_1426 = view_1428 = None + convert_element_type_1818 = torch.ops.prims.convert_element_type.default(mm_419, torch.float32); mm_419 = None + reduce_scatter_tensor_125 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1818, 'avg', 64, '0'); convert_element_type_1818 = None + wait_tensor_416 = 
torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_125); reduce_scatter_tensor_125 = None + view_1429 = torch.ops.aten.view.default(view_1424, [16384, 4096]); view_1424 = None + permute_801 = torch.ops.aten.permute.default(view_1429, [1, 0]) + mm_421 = torch.ops.aten.mm.default(permute_801, view_615); permute_801 = view_615 = None + convert_element_type_598 = torch.ops.prims.convert_element_type.default(primals_167, torch.bfloat16); primals_167 = None + all_gather_into_tensor_164 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_598, 64, '0'); convert_element_type_598 = None + wait_tensor_164 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_164); all_gather_into_tensor_164 = None + permute_198 = torch.ops.aten.permute.default(wait_tensor_164, [1, 0]); wait_tensor_164 = None + permute_803 = torch.ops.aten.permute.default(permute_198, [1, 0]); permute_198 = None + mm_422 = torch.ops.aten.mm.default(view_1429, permute_803); view_1429 = permute_803 = None + view_1430 = torch.ops.aten.view.default(mm_422, [2, 8192, 4096]); mm_422 = None + add_225 = torch.ops.aten.add.Tensor(add_224, view_1430); add_224 = view_1430 = None + convert_element_type_1823 = torch.ops.prims.convert_element_type.default(mm_421, torch.float32); mm_421 = None + reduce_scatter_tensor_126 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1823, 'avg', 64, '0'); convert_element_type_1823 = None + wait_tensor_417 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_126); reduce_scatter_tensor_126 = None + convert_element_type_1824 = torch.ops.prims.convert_element_type.default(add_225, torch.float32); add_225 = None + convert_element_type_1826 = torch.ops.prims.convert_element_type.default(wait_tensor_163, torch.float32); wait_tensor_163 = None + mul_538 = torch.ops.aten.mul.Tensor(convert_element_type_1824, convert_element_type_1826); convert_element_type_1826 = None + mul_540 = 
torch.ops.aten.mul.Tensor(mul_144, mul_538) + sum_85 = torch.ops.aten.sum.dim_IntList(mul_540, [2], True); mul_540 = None + div_28 = torch.ops.aten.div.Tensor(mul_144, 4096) + mul_541 = torch.ops.aten.mul.Tensor(div_28, sum_85); div_28 = sum_85 = None + sub_42 = torch.ops.aten.sub.Tensor(mul_538, mul_541); mul_538 = mul_541 = None + mul_542 = torch.ops.aten.mul.Tensor(sub_42, rsqrt_36); sub_42 = rsqrt_36 = None + mul_543 = torch.ops.aten.mul.Tensor(convert_element_type_1824, mul_144); convert_element_type_1824 = mul_144 = None + sum_86 = torch.ops.aten.sum.dim_IntList(mul_543, [0, 1]); mul_543 = None + convert_element_type_1827 = torch.ops.prims.convert_element_type.default(mul_542, torch.bfloat16); mul_542 = None + add_226 = torch.ops.aten.add.Tensor(add_223, convert_element_type_1827); add_223 = convert_element_type_1827 = None + convert_element_type_default_37 = torch.ops.prims.convert_element_type.default(sum_86, torch.float32); sum_86 = None + reduce_scatter_tensor_127 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_37, 'avg', 64, '0'); convert_element_type_default_37 = None + wait_tensor_418 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_127); reduce_scatter_tensor_127 = None + view_1431 = torch.ops.aten.view.default(add_226, [16384, 4096]) + permute_805 = torch.ops.aten.permute.default(view_1431, [1, 0]) + permute_193 = torch.ops.aten.permute.default(getitem_153, [0, 2, 1, 3]) + view_599 = torch.ops.aten.view.default(permute_193, [2, 8192, -1]); permute_193 = None + convert_element_type_578 = torch.ops.prims.convert_element_type.default(primals_161, torch.bfloat16); primals_161 = None + all_gather_into_tensor_158 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_578, 64, '0'); convert_element_type_578 = None + wait_tensor_158 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_158); all_gather_into_tensor_158 = None + permute_194 = 
torch.ops.aten.permute.default(wait_tensor_158, [1, 0]); wait_tensor_158 = None + view_601 = torch.ops.aten.view.default(view_599, [16384, 4096]); view_599 = None + mm_122 = torch.ops.aten.mm.default(view_601, permute_194) + view_602 = torch.ops.aten.view.default(mm_122, [2, 8192, 4096]); mm_122 = None + add_69 = torch.ops.aten.add.Tensor(add_67, view_602); view_602 = None + convert_element_type_581 = torch.ops.prims.convert_element_type.default(primals_162, torch.bfloat16); primals_162 = None + all_gather_into_tensor_159 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_581, 64, '0'); convert_element_type_581 = None + wait_tensor_159 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_159); all_gather_into_tensor_159 = None + convert_element_type_582 = torch.ops.prims.convert_element_type.default(add_69, torch.float32); add_69 = None + pow_36 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_582, 2) + mean_35 = torch.ops.aten.mean.dim(pow_36, [2], True); pow_36 = None + add_70 = torch.ops.aten.add.Scalar(mean_35, 1e-05); mean_35 = None + rsqrt_35 = torch.ops.aten.rsqrt.default(add_70); add_70 = None + mul_140 = torch.ops.aten.mul.Tensor(convert_element_type_582, rsqrt_35); convert_element_type_582 = None + mul_141 = torch.ops.aten.mul.Tensor(mul_140, wait_tensor_159) + convert_element_type_583 = torch.ops.prims.convert_element_type.default(mul_141, torch.bfloat16); mul_141 = None + view_605 = torch.ops.aten.view.default(convert_element_type_583, [16384, 4096]); convert_element_type_583 = None + view_606 = torch.ops.aten.view.default(mm_123, [2, 8192, 14336]); mm_123 = None + convert_element_type_587 = torch.ops.prims.convert_element_type.default(view_606, torch.float32); view_606 = None + sigmoid_17 = torch.ops.aten.sigmoid.default(convert_element_type_587) + mul_142 = torch.ops.aten.mul.Tensor(convert_element_type_587, sigmoid_17); sigmoid_17 = None + convert_element_type_588 = 
torch.ops.prims.convert_element_type.default(mul_142, torch.bfloat16); mul_142 = None + convert_element_type_589 = torch.ops.prims.convert_element_type.default(primals_164, torch.bfloat16); primals_164 = None + all_gather_into_tensor_161 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_589, 64, '0'); convert_element_type_589 = None + wait_tensor_161 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_161); all_gather_into_tensor_161 = None + permute_196 = torch.ops.aten.permute.default(wait_tensor_161, [1, 0]); wait_tensor_161 = None + mm_124 = torch.ops.aten.mm.default(view_605, permute_196) + view_609 = torch.ops.aten.view.default(mm_124, [2, 8192, 14336]); mm_124 = None + mul_143 = torch.ops.aten.mul.Tensor(convert_element_type_588, view_609) + view_611 = torch.ops.aten.view.default(mul_143, [16384, 14336]); mul_143 = None + mm_423 = torch.ops.aten.mm.default(permute_805, view_611); permute_805 = view_611 = None + convert_element_type_592 = torch.ops.prims.convert_element_type.default(primals_165, torch.bfloat16); primals_165 = None + all_gather_into_tensor_162 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_592, 64, '0'); convert_element_type_592 = None + wait_tensor_162 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_162); all_gather_into_tensor_162 = None + permute_197 = torch.ops.aten.permute.default(wait_tensor_162, [1, 0]); wait_tensor_162 = None + permute_807 = torch.ops.aten.permute.default(permute_197, [1, 0]); permute_197 = None + mm_424 = torch.ops.aten.mm.default(view_1431, permute_807); view_1431 = permute_807 = None + view_1432 = torch.ops.aten.view.default(mm_424, [2, 8192, 14336]); mm_424 = None + convert_element_type_1834 = torch.ops.prims.convert_element_type.default(mm_423, torch.float32); mm_423 = None + reduce_scatter_tensor_128 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1834, 'avg', 64, '0'); 
convert_element_type_1834 = None + wait_tensor_419 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_128); reduce_scatter_tensor_128 = None + mul_544 = torch.ops.aten.mul.Tensor(view_1432, convert_element_type_588); convert_element_type_588 = None + mul_545 = torch.ops.aten.mul.Tensor(view_1432, view_609); view_1432 = view_609 = None + view_1433 = torch.ops.aten.view.default(mul_544, [16384, 14336]); mul_544 = None + permute_809 = torch.ops.aten.permute.default(view_1433, [1, 0]) + mm_425 = torch.ops.aten.mm.default(permute_809, view_605); permute_809 = None + permute_811 = torch.ops.aten.permute.default(permute_196, [1, 0]); permute_196 = None + mm_426 = torch.ops.aten.mm.default(view_1433, permute_811); view_1433 = permute_811 = None + view_1434 = torch.ops.aten.view.default(mm_426, [2, 8192, 4096]); mm_426 = None + convert_element_type_1839 = torch.ops.prims.convert_element_type.default(mm_425, torch.float32); mm_425 = None + reduce_scatter_tensor_129 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1839, 'avg', 64, '0'); convert_element_type_1839 = None + wait_tensor_420 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_129); reduce_scatter_tensor_129 = None + convert_element_type_1840 = torch.ops.prims.convert_element_type.default(mul_545, torch.float32); mul_545 = None + neg_14 = torch.ops.aten.neg.default(convert_element_type_587) + exp_14 = torch.ops.aten.exp.default(neg_14); neg_14 = None + add_227 = torch.ops.aten.add.Tensor(exp_14, 1); exp_14 = None + reciprocal_14 = torch.ops.aten.reciprocal.default(add_227); add_227 = None + mul_546 = torch.ops.aten.mul.Tensor(reciprocal_14, 1); reciprocal_14 = None + mul_547 = torch.ops.aten.mul.Tensor(convert_element_type_1840, mul_546); convert_element_type_1840 = None + sub_43 = torch.ops.aten.sub.Tensor(1, mul_546); mul_546 = None + mul_548 = torch.ops.aten.mul.Tensor(convert_element_type_587, sub_43); convert_element_type_587 = sub_43 = None 
+ add_228 = torch.ops.aten.add.Tensor(mul_548, 1); mul_548 = None + mul_549 = torch.ops.aten.mul.Tensor(mul_547, add_228); mul_547 = add_228 = None + convert_element_type_1842 = torch.ops.prims.convert_element_type.default(mul_549, torch.bfloat16); mul_549 = None + view_1435 = torch.ops.aten.view.default(convert_element_type_1842, [16384, 14336]); convert_element_type_1842 = None + permute_813 = torch.ops.aten.permute.default(view_1435, [1, 0]) + mm_427 = torch.ops.aten.mm.default(permute_813, view_605); permute_813 = view_605 = None + convert_element_type_584 = torch.ops.prims.convert_element_type.default(primals_163, torch.bfloat16); primals_163 = None + all_gather_into_tensor_160 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_584, 64, '0'); convert_element_type_584 = None + wait_tensor_160 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_160); all_gather_into_tensor_160 = None + permute_195 = torch.ops.aten.permute.default(wait_tensor_160, [1, 0]); wait_tensor_160 = None + permute_815 = torch.ops.aten.permute.default(permute_195, [1, 0]); permute_195 = None + mm_428 = torch.ops.aten.mm.default(view_1435, permute_815); view_1435 = permute_815 = None + view_1436 = torch.ops.aten.view.default(mm_428, [2, 8192, 4096]); mm_428 = None + add_229 = torch.ops.aten.add.Tensor(view_1434, view_1436); view_1434 = view_1436 = None + convert_element_type_1847 = torch.ops.prims.convert_element_type.default(mm_427, torch.float32); mm_427 = None + reduce_scatter_tensor_130 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1847, 'avg', 64, '0'); convert_element_type_1847 = None + wait_tensor_421 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_130); reduce_scatter_tensor_130 = None + convert_element_type_1848 = torch.ops.prims.convert_element_type.default(add_229, torch.float32); add_229 = None + convert_element_type_1850 = 
torch.ops.prims.convert_element_type.default(wait_tensor_159, torch.float32); wait_tensor_159 = None + mul_550 = torch.ops.aten.mul.Tensor(convert_element_type_1848, convert_element_type_1850); convert_element_type_1850 = None + mul_552 = torch.ops.aten.mul.Tensor(mul_140, mul_550) + sum_87 = torch.ops.aten.sum.dim_IntList(mul_552, [2], True); mul_552 = None + div_29 = torch.ops.aten.div.Tensor(mul_140, 4096) + mul_553 = torch.ops.aten.mul.Tensor(div_29, sum_87); div_29 = sum_87 = None + sub_44 = torch.ops.aten.sub.Tensor(mul_550, mul_553); mul_550 = mul_553 = None + mul_554 = torch.ops.aten.mul.Tensor(sub_44, rsqrt_35); sub_44 = rsqrt_35 = None + mul_555 = torch.ops.aten.mul.Tensor(convert_element_type_1848, mul_140); convert_element_type_1848 = mul_140 = None + sum_88 = torch.ops.aten.sum.dim_IntList(mul_555, [0, 1]); mul_555 = None + convert_element_type_1851 = torch.ops.prims.convert_element_type.default(mul_554, torch.bfloat16); mul_554 = None + add_230 = torch.ops.aten.add.Tensor(add_226, convert_element_type_1851); add_226 = convert_element_type_1851 = None + convert_element_type_default_36 = torch.ops.prims.convert_element_type.default(sum_88, torch.float32); sum_88 = None + reduce_scatter_tensor_131 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_36, 'avg', 64, '0'); convert_element_type_default_36 = None + wait_tensor_422 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_131); reduce_scatter_tensor_131 = None + view_1437 = torch.ops.aten.view.default(add_230, [16384, 4096]) + permute_817 = torch.ops.aten.permute.default(view_1437, [1, 0]) + mm_429 = torch.ops.aten.mm.default(permute_817, view_601); permute_817 = view_601 = None + permute_819 = torch.ops.aten.permute.default(permute_194, [1, 0]); permute_194 = None + mm_430 = torch.ops.aten.mm.default(view_1437, permute_819); view_1437 = permute_819 = None + view_1438 = torch.ops.aten.view.default(mm_430, [2, 8192, 4096]); mm_430 = None + 
convert_element_type_1858 = torch.ops.prims.convert_element_type.default(mm_429, torch.float32); mm_429 = None + reduce_scatter_tensor_132 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1858, 'avg', 64, '0'); convert_element_type_1858 = None + wait_tensor_423 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_132); reduce_scatter_tensor_132 = None + view_1439 = torch.ops.aten.view.default(view_1438, [2, 8192, 32, 128]); view_1438 = None + permute_821 = torch.ops.aten.permute.default(view_1439, [0, 2, 1, 3]); view_1439 = None + convert_element_type_562 = torch.ops.prims.convert_element_type.default(primals_157, torch.bfloat16); primals_157 = None + all_gather_into_tensor_154 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_562, 64, '0'); convert_element_type_562 = None + wait_tensor_154 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_154); all_gather_into_tensor_154 = None + convert_element_type_563 = torch.ops.prims.convert_element_type.default(add_67, torch.float32); add_67 = None + pow_35 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_563, 2) + mean_34 = torch.ops.aten.mean.dim(pow_35, [2], True); pow_35 = None + add_68 = torch.ops.aten.add.Scalar(mean_34, 1e-05); mean_34 = None + rsqrt_34 = torch.ops.aten.rsqrt.default(add_68); add_68 = None + mul_136 = torch.ops.aten.mul.Tensor(convert_element_type_563, rsqrt_34); convert_element_type_563 = None + mul_137 = torch.ops.aten.mul.Tensor(mul_136, wait_tensor_154) + convert_element_type_564 = torch.ops.prims.convert_element_type.default(mul_137, torch.bfloat16); mul_137 = None + view_581 = torch.ops.aten.view.default(convert_element_type_564, [16384, 4096]); convert_element_type_564 = None + view_582 = torch.ops.aten.view.default(mm_119, [2, 8192, 4096]); mm_119 = None + convert_element_type_568 = torch.ops.prims.convert_element_type.default(primals_159, torch.bfloat16); primals_159 = None + 
all_gather_into_tensor_156 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_568, 64, '0'); convert_element_type_568 = None + wait_tensor_156 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_156); all_gather_into_tensor_156 = None + permute_188 = torch.ops.aten.permute.default(wait_tensor_156, [1, 0]); wait_tensor_156 = None + mm_120 = torch.ops.aten.mm.default(view_581, permute_188) + view_585 = torch.ops.aten.view.default(mm_120, [2, 8192, 1024]); mm_120 = None + view_588 = torch.ops.aten.view.default(mm_121, [2, 8192, 1024]); mm_121 = None + view_589 = torch.ops.aten.view.default(view_582, [2, 8192, -1, 128]); view_582 = None + view_590 = torch.ops.aten.view.default(view_585, [2, 8192, -1, 128]); view_585 = None + view_591 = torch.ops.aten.view.default(view_588, [2, 8192, -1, 128]); view_588 = None + convert_element_type_574 = torch.ops.prims.convert_element_type.default(view_589, torch.float32); view_589 = None + view_592 = torch.ops.aten.view.default(convert_element_type_574, [2, 8192, 32, -1, 2]); convert_element_type_574 = None + view_as_complex_34 = torch.ops.aten.view_as_complex.default(view_592); view_592 = None + convert_element_type_575 = torch.ops.prims.convert_element_type.default(view_590, torch.float32); view_590 = None + view_593 = torch.ops.aten.view.default(convert_element_type_575, [2, 8192, 8, -1, 2]); convert_element_type_575 = None + view_as_complex_35 = torch.ops.aten.view_as_complex.default(view_593); view_593 = None + mul_138 = torch.ops.aten.mul.Tensor(view_as_complex_34, view_16); view_as_complex_34 = None + view_as_real_34 = torch.ops.aten.view_as_real.default(mul_138); mul_138 = None + view_595 = torch.ops.aten.view.default(view_as_real_34, [2, 8192, 32, 128]); view_as_real_34 = None + mul_139 = torch.ops.aten.mul.Tensor(view_as_complex_35, view_16); view_as_complex_35 = None + view_as_real_35 = torch.ops.aten.view_as_real.default(mul_139); mul_139 = None + view_596 = 
torch.ops.aten.view.default(view_as_real_35, [2, 8192, 8, 128]); view_as_real_35 = None + convert_element_type_576 = torch.ops.prims.convert_element_type.default(view_595, torch.bfloat16); view_595 = None + convert_element_type_577 = torch.ops.prims.convert_element_type.default(view_596, torch.bfloat16); view_596 = None + unsqueeze_34 = torch.ops.aten.unsqueeze.default(convert_element_type_577, 3); convert_element_type_577 = None + expand_34 = torch.ops.aten.expand.default(unsqueeze_34, [2, 8192, 8, 4, 128]); unsqueeze_34 = None + clone_34 = torch.ops.aten.clone.default(expand_34, memory_format = torch.contiguous_format); expand_34 = None + view_597 = torch.ops.aten.view.default(clone_34, [2, 8192, 32, 128]); clone_34 = None + unsqueeze_35 = torch.ops.aten.unsqueeze.default(view_591, 3); view_591 = None + expand_35 = torch.ops.aten.expand.default(unsqueeze_35, [2, 8192, 8, 4, 128]); unsqueeze_35 = None + clone_35 = torch.ops.aten.clone.default(expand_35, memory_format = torch.contiguous_format); expand_35 = None + view_598 = torch.ops.aten.view.default(clone_35, [2, 8192, 32, 128]); clone_35 = None + permute_190 = torch.ops.aten.permute.default(convert_element_type_576, [0, 2, 1, 3]); convert_element_type_576 = None + permute_191 = torch.ops.aten.permute.default(view_597, [0, 2, 1, 3]); view_597 = None + permute_192 = torch.ops.aten.permute.default(view_598, [0, 2, 1, 3]); view_598 = None + _scaled_dot_product_cudnn_attention_backward_14 = torch.ops.aten._scaled_dot_product_cudnn_attention_backward.default(permute_821, permute_190, permute_191, permute_192, getitem_153, getitem_154, getitem_159, getitem_160, None, None, None, 8192, 8192, 0.0, True); permute_821 = permute_190 = permute_191 = permute_192 = getitem_153 = getitem_154 = getitem_159 = getitem_160 = None + getitem_330 = _scaled_dot_product_cudnn_attention_backward_14[0] + getitem_331 = _scaled_dot_product_cudnn_attention_backward_14[1] + getitem_332 = _scaled_dot_product_cudnn_attention_backward_14[2]; 
_scaled_dot_product_cudnn_attention_backward_14 = None + permute_822 = torch.ops.aten.permute.default(getitem_332, [0, 2, 1, 3]); getitem_332 = None + permute_823 = torch.ops.aten.permute.default(getitem_331, [0, 2, 1, 3]); getitem_331 = None + permute_824 = torch.ops.aten.permute.default(getitem_330, [0, 2, 1, 3]); getitem_330 = None + view_1440 = torch.ops.aten.view.default(permute_822, [2, 8192, 8, 4, 128]); permute_822 = None + sum_89 = torch.ops.aten.sum.dim_IntList(view_1440, [3], True); view_1440 = None + squeeze_28 = torch.ops.aten.squeeze.dim(sum_89, 3); sum_89 = None + view_1441 = torch.ops.aten.view.default(permute_823, [2, 8192, 8, 4, 128]); permute_823 = None + sum_90 = torch.ops.aten.sum.dim_IntList(view_1441, [3], True); view_1441 = None + squeeze_29 = torch.ops.aten.squeeze.dim(sum_90, 3); sum_90 = None + convert_element_type_1859 = torch.ops.prims.convert_element_type.default(squeeze_29, torch.float32); squeeze_29 = None + convert_element_type_1860 = torch.ops.prims.convert_element_type.default(permute_824, torch.float32); permute_824 = None + view_1442 = torch.ops.aten.view.default(convert_element_type_1859, [2, 8192, 8, 64, 2]); convert_element_type_1859 = None + view_as_complex_92 = torch.ops.aten.view_as_complex.default(view_1442); view_1442 = None + mul_556 = torch.ops.aten.mul.Tensor(view_as_complex_92, _conj); view_as_complex_92 = None + view_1443 = torch.ops.aten.view.default(convert_element_type_1860, [2, 8192, 32, 64, 2]); convert_element_type_1860 = None + view_as_complex_93 = torch.ops.aten.view_as_complex.default(view_1443); view_1443 = None + mul_557 = torch.ops.aten.mul.Tensor(view_as_complex_93, _conj); view_as_complex_93 = None + view_as_real_92 = torch.ops.aten.view_as_real.default(mul_556); mul_556 = None + view_1444 = torch.ops.aten.view.default(view_as_real_92, [2, 8192, 8, 128]); view_as_real_92 = None + convert_element_type_1861 = torch.ops.prims.convert_element_type.default(view_1444, torch.bfloat16); view_1444 = None + 
view_as_real_93 = torch.ops.aten.view_as_real.default(mul_557); mul_557 = None + view_1445 = torch.ops.aten.view.default(view_as_real_93, [2, 8192, 32, 128]); view_as_real_93 = None + convert_element_type_1862 = torch.ops.prims.convert_element_type.default(view_1445, torch.bfloat16); view_1445 = None + view_1446 = torch.ops.aten.view.default(squeeze_28, [2, 8192, 1024]); squeeze_28 = None + view_1447 = torch.ops.aten.view.default(convert_element_type_1861, [2, 8192, 1024]); convert_element_type_1861 = None + view_1448 = torch.ops.aten.view.default(convert_element_type_1862, [2, 8192, 4096]); convert_element_type_1862 = None + view_1449 = torch.ops.aten.view.default(view_1446, [16384, 1024]); view_1446 = None + permute_825 = torch.ops.aten.permute.default(view_1449, [1, 0]) + mm_431 = torch.ops.aten.mm.default(permute_825, view_581); permute_825 = None + convert_element_type_571 = torch.ops.prims.convert_element_type.default(primals_160, torch.bfloat16); primals_160 = None + all_gather_into_tensor_157 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_571, 64, '0'); convert_element_type_571 = None + wait_tensor_157 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_157); all_gather_into_tensor_157 = None + permute_189 = torch.ops.aten.permute.default(wait_tensor_157, [1, 0]); wait_tensor_157 = None + permute_827 = torch.ops.aten.permute.default(permute_189, [1, 0]); permute_189 = None + mm_432 = torch.ops.aten.mm.default(view_1449, permute_827); view_1449 = permute_827 = None + view_1450 = torch.ops.aten.view.default(mm_432, [2, 8192, 4096]); mm_432 = None + convert_element_type_1867 = torch.ops.prims.convert_element_type.default(mm_431, torch.float32); mm_431 = None + reduce_scatter_tensor_133 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1867, 'avg', 64, '0'); convert_element_type_1867 = None + wait_tensor_424 = 
torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_133); reduce_scatter_tensor_133 = None + view_1451 = torch.ops.aten.view.default(view_1447, [16384, 1024]); view_1447 = None + permute_829 = torch.ops.aten.permute.default(view_1451, [1, 0]) + mm_433 = torch.ops.aten.mm.default(permute_829, view_581); permute_829 = None + permute_831 = torch.ops.aten.permute.default(permute_188, [1, 0]); permute_188 = None + mm_434 = torch.ops.aten.mm.default(view_1451, permute_831); view_1451 = permute_831 = None + view_1452 = torch.ops.aten.view.default(mm_434, [2, 8192, 4096]); mm_434 = None + add_231 = torch.ops.aten.add.Tensor(view_1450, view_1452); view_1450 = view_1452 = None + convert_element_type_1872 = torch.ops.prims.convert_element_type.default(mm_433, torch.float32); mm_433 = None + reduce_scatter_tensor_134 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1872, 'avg', 64, '0'); convert_element_type_1872 = None + wait_tensor_425 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_134); reduce_scatter_tensor_134 = None + view_1453 = torch.ops.aten.view.default(view_1448, [16384, 4096]); view_1448 = None + permute_833 = torch.ops.aten.permute.default(view_1453, [1, 0]) + mm_435 = torch.ops.aten.mm.default(permute_833, view_581); permute_833 = view_581 = None + convert_element_type_565 = torch.ops.prims.convert_element_type.default(primals_158, torch.bfloat16); primals_158 = None + all_gather_into_tensor_155 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_565, 64, '0'); convert_element_type_565 = None + wait_tensor_155 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_155); all_gather_into_tensor_155 = None + permute_187 = torch.ops.aten.permute.default(wait_tensor_155, [1, 0]); wait_tensor_155 = None + permute_835 = torch.ops.aten.permute.default(permute_187, [1, 0]); permute_187 = None + mm_436 = torch.ops.aten.mm.default(view_1453, permute_835); 
view_1453 = permute_835 = None + view_1454 = torch.ops.aten.view.default(mm_436, [2, 8192, 4096]); mm_436 = None + add_232 = torch.ops.aten.add.Tensor(add_231, view_1454); add_231 = view_1454 = None + convert_element_type_1877 = torch.ops.prims.convert_element_type.default(mm_435, torch.float32); mm_435 = None + reduce_scatter_tensor_135 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1877, 'avg', 64, '0'); convert_element_type_1877 = None + wait_tensor_426 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_135); reduce_scatter_tensor_135 = None + convert_element_type_1878 = torch.ops.prims.convert_element_type.default(add_232, torch.float32); add_232 = None + convert_element_type_1880 = torch.ops.prims.convert_element_type.default(wait_tensor_154, torch.float32); wait_tensor_154 = None + mul_558 = torch.ops.aten.mul.Tensor(convert_element_type_1878, convert_element_type_1880); convert_element_type_1880 = None + mul_560 = torch.ops.aten.mul.Tensor(mul_136, mul_558) + sum_91 = torch.ops.aten.sum.dim_IntList(mul_560, [2], True); mul_560 = None + div_30 = torch.ops.aten.div.Tensor(mul_136, 4096) + mul_561 = torch.ops.aten.mul.Tensor(div_30, sum_91); div_30 = sum_91 = None + sub_45 = torch.ops.aten.sub.Tensor(mul_558, mul_561); mul_558 = mul_561 = None + mul_562 = torch.ops.aten.mul.Tensor(sub_45, rsqrt_34); sub_45 = rsqrt_34 = None + mul_563 = torch.ops.aten.mul.Tensor(convert_element_type_1878, mul_136); convert_element_type_1878 = mul_136 = None + sum_92 = torch.ops.aten.sum.dim_IntList(mul_563, [0, 1]); mul_563 = None + convert_element_type_1881 = torch.ops.prims.convert_element_type.default(mul_562, torch.bfloat16); mul_562 = None + add_233 = torch.ops.aten.add.Tensor(add_230, convert_element_type_1881); add_230 = convert_element_type_1881 = None + convert_element_type_default_35 = torch.ops.prims.convert_element_type.default(sum_92, torch.float32); sum_92 = None + reduce_scatter_tensor_136 = 
torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_35, 'avg', 64, '0'); convert_element_type_default_35 = None + wait_tensor_427 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_136); reduce_scatter_tensor_136 = None + view_1455 = torch.ops.aten.view.default(add_233, [16384, 4096]) + permute_837 = torch.ops.aten.permute.default(view_1455, [1, 0]) + permute_182 = torch.ops.aten.permute.default(getitem_144, [0, 2, 1, 3]) + view_565 = torch.ops.aten.view.default(permute_182, [2, 8192, -1]); permute_182 = None + convert_element_type_545 = torch.ops.prims.convert_element_type.default(primals_152, torch.bfloat16); primals_152 = None + all_gather_into_tensor_149 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_545, 64, '0'); convert_element_type_545 = None + wait_tensor_149 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_149); all_gather_into_tensor_149 = None + permute_183 = torch.ops.aten.permute.default(wait_tensor_149, [1, 0]); wait_tensor_149 = None + view_567 = torch.ops.aten.view.default(view_565, [16384, 4096]); view_565 = None + mm_115 = torch.ops.aten.mm.default(view_567, permute_183) + view_568 = torch.ops.aten.view.default(mm_115, [2, 8192, 4096]); mm_115 = None + add_65 = torch.ops.aten.add.Tensor(add_63, view_568); view_568 = None + convert_element_type_548 = torch.ops.prims.convert_element_type.default(primals_153, torch.bfloat16); primals_153 = None + all_gather_into_tensor_150 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_548, 64, '0'); convert_element_type_548 = None + wait_tensor_150 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_150); all_gather_into_tensor_150 = None + convert_element_type_549 = torch.ops.prims.convert_element_type.default(add_65, torch.float32); add_65 = None + pow_34 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_549, 2) + mean_33 = 
torch.ops.aten.mean.dim(pow_34, [2], True); pow_34 = None + add_66 = torch.ops.aten.add.Scalar(mean_33, 1e-05); mean_33 = None + rsqrt_33 = torch.ops.aten.rsqrt.default(add_66); add_66 = None + mul_132 = torch.ops.aten.mul.Tensor(convert_element_type_549, rsqrt_33); convert_element_type_549 = None + mul_133 = torch.ops.aten.mul.Tensor(mul_132, wait_tensor_150) + convert_element_type_550 = torch.ops.prims.convert_element_type.default(mul_133, torch.bfloat16); mul_133 = None + view_571 = torch.ops.aten.view.default(convert_element_type_550, [16384, 4096]); convert_element_type_550 = None + view_572 = torch.ops.aten.view.default(mm_116, [2, 8192, 14336]); mm_116 = None + convert_element_type_554 = torch.ops.prims.convert_element_type.default(view_572, torch.float32); view_572 = None + sigmoid_16 = torch.ops.aten.sigmoid.default(convert_element_type_554) + mul_134 = torch.ops.aten.mul.Tensor(convert_element_type_554, sigmoid_16); sigmoid_16 = None + convert_element_type_555 = torch.ops.prims.convert_element_type.default(mul_134, torch.bfloat16); mul_134 = None + convert_element_type_556 = torch.ops.prims.convert_element_type.default(primals_155, torch.bfloat16); primals_155 = None + all_gather_into_tensor_152 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_556, 64, '0'); convert_element_type_556 = None + wait_tensor_152 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_152); all_gather_into_tensor_152 = None + permute_185 = torch.ops.aten.permute.default(wait_tensor_152, [1, 0]); wait_tensor_152 = None + mm_117 = torch.ops.aten.mm.default(view_571, permute_185) + view_575 = torch.ops.aten.view.default(mm_117, [2, 8192, 14336]); mm_117 = None + mul_135 = torch.ops.aten.mul.Tensor(convert_element_type_555, view_575) + view_577 = torch.ops.aten.view.default(mul_135, [16384, 14336]); mul_135 = None + mm_437 = torch.ops.aten.mm.default(permute_837, view_577); permute_837 = view_577 = None + convert_element_type_559 = 
torch.ops.prims.convert_element_type.default(primals_156, torch.bfloat16); primals_156 = None + all_gather_into_tensor_153 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_559, 64, '0'); convert_element_type_559 = None + wait_tensor_153 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_153); all_gather_into_tensor_153 = None + permute_186 = torch.ops.aten.permute.default(wait_tensor_153, [1, 0]); wait_tensor_153 = None + permute_839 = torch.ops.aten.permute.default(permute_186, [1, 0]); permute_186 = None + mm_438 = torch.ops.aten.mm.default(view_1455, permute_839); view_1455 = permute_839 = None + view_1456 = torch.ops.aten.view.default(mm_438, [2, 8192, 14336]); mm_438 = None + convert_element_type_1888 = torch.ops.prims.convert_element_type.default(mm_437, torch.float32); mm_437 = None + reduce_scatter_tensor_137 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1888, 'avg', 64, '0'); convert_element_type_1888 = None + wait_tensor_428 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_137); reduce_scatter_tensor_137 = None + mul_564 = torch.ops.aten.mul.Tensor(view_1456, convert_element_type_555); convert_element_type_555 = None + mul_565 = torch.ops.aten.mul.Tensor(view_1456, view_575); view_1456 = view_575 = None + view_1457 = torch.ops.aten.view.default(mul_564, [16384, 14336]); mul_564 = None + permute_841 = torch.ops.aten.permute.default(view_1457, [1, 0]) + mm_439 = torch.ops.aten.mm.default(permute_841, view_571); permute_841 = None + permute_843 = torch.ops.aten.permute.default(permute_185, [1, 0]); permute_185 = None + mm_440 = torch.ops.aten.mm.default(view_1457, permute_843); view_1457 = permute_843 = None + view_1458 = torch.ops.aten.view.default(mm_440, [2, 8192, 4096]); mm_440 = None + convert_element_type_1893 = torch.ops.prims.convert_element_type.default(mm_439, torch.float32); mm_439 = None + reduce_scatter_tensor_138 = 
torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1893, 'avg', 64, '0'); convert_element_type_1893 = None + wait_tensor_429 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_138); reduce_scatter_tensor_138 = None + convert_element_type_1894 = torch.ops.prims.convert_element_type.default(mul_565, torch.float32); mul_565 = None + neg_15 = torch.ops.aten.neg.default(convert_element_type_554) + exp_15 = torch.ops.aten.exp.default(neg_15); neg_15 = None + add_234 = torch.ops.aten.add.Tensor(exp_15, 1); exp_15 = None + reciprocal_15 = torch.ops.aten.reciprocal.default(add_234); add_234 = None + mul_566 = torch.ops.aten.mul.Tensor(reciprocal_15, 1); reciprocal_15 = None + mul_567 = torch.ops.aten.mul.Tensor(convert_element_type_1894, mul_566); convert_element_type_1894 = None + sub_46 = torch.ops.aten.sub.Tensor(1, mul_566); mul_566 = None + mul_568 = torch.ops.aten.mul.Tensor(convert_element_type_554, sub_46); convert_element_type_554 = sub_46 = None + add_235 = torch.ops.aten.add.Tensor(mul_568, 1); mul_568 = None + mul_569 = torch.ops.aten.mul.Tensor(mul_567, add_235); mul_567 = add_235 = None + convert_element_type_1896 = torch.ops.prims.convert_element_type.default(mul_569, torch.bfloat16); mul_569 = None + view_1459 = torch.ops.aten.view.default(convert_element_type_1896, [16384, 14336]); convert_element_type_1896 = None + permute_845 = torch.ops.aten.permute.default(view_1459, [1, 0]) + mm_441 = torch.ops.aten.mm.default(permute_845, view_571); permute_845 = view_571 = None + convert_element_type_551 = torch.ops.prims.convert_element_type.default(primals_154, torch.bfloat16); primals_154 = None + all_gather_into_tensor_151 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_551, 64, '0'); convert_element_type_551 = None + wait_tensor_151 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_151); all_gather_into_tensor_151 = None + permute_184 = 
torch.ops.aten.permute.default(wait_tensor_151, [1, 0]); wait_tensor_151 = None + permute_847 = torch.ops.aten.permute.default(permute_184, [1, 0]); permute_184 = None + mm_442 = torch.ops.aten.mm.default(view_1459, permute_847); view_1459 = permute_847 = None + view_1460 = torch.ops.aten.view.default(mm_442, [2, 8192, 4096]); mm_442 = None + add_236 = torch.ops.aten.add.Tensor(view_1458, view_1460); view_1458 = view_1460 = None + convert_element_type_1901 = torch.ops.prims.convert_element_type.default(mm_441, torch.float32); mm_441 = None + reduce_scatter_tensor_139 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1901, 'avg', 64, '0'); convert_element_type_1901 = None + wait_tensor_430 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_139); reduce_scatter_tensor_139 = None + convert_element_type_1902 = torch.ops.prims.convert_element_type.default(add_236, torch.float32); add_236 = None + convert_element_type_1904 = torch.ops.prims.convert_element_type.default(wait_tensor_150, torch.float32); wait_tensor_150 = None + mul_570 = torch.ops.aten.mul.Tensor(convert_element_type_1902, convert_element_type_1904); convert_element_type_1904 = None + mul_572 = torch.ops.aten.mul.Tensor(mul_132, mul_570) + sum_93 = torch.ops.aten.sum.dim_IntList(mul_572, [2], True); mul_572 = None + div_31 = torch.ops.aten.div.Tensor(mul_132, 4096) + mul_573 = torch.ops.aten.mul.Tensor(div_31, sum_93); div_31 = sum_93 = None + sub_47 = torch.ops.aten.sub.Tensor(mul_570, mul_573); mul_570 = mul_573 = None + mul_574 = torch.ops.aten.mul.Tensor(sub_47, rsqrt_33); sub_47 = rsqrt_33 = None + mul_575 = torch.ops.aten.mul.Tensor(convert_element_type_1902, mul_132); convert_element_type_1902 = mul_132 = None + sum_94 = torch.ops.aten.sum.dim_IntList(mul_575, [0, 1]); mul_575 = None + convert_element_type_1905 = torch.ops.prims.convert_element_type.default(mul_574, torch.bfloat16); mul_574 = None + add_237 = torch.ops.aten.add.Tensor(add_233, 
convert_element_type_1905); add_233 = convert_element_type_1905 = None + convert_element_type_default_34 = torch.ops.prims.convert_element_type.default(sum_94, torch.float32); sum_94 = None + reduce_scatter_tensor_140 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_34, 'avg', 64, '0'); convert_element_type_default_34 = None + wait_tensor_431 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_140); reduce_scatter_tensor_140 = None + view_1461 = torch.ops.aten.view.default(add_237, [16384, 4096]) + permute_849 = torch.ops.aten.permute.default(view_1461, [1, 0]) + mm_443 = torch.ops.aten.mm.default(permute_849, view_567); permute_849 = view_567 = None + permute_851 = torch.ops.aten.permute.default(permute_183, [1, 0]); permute_183 = None + mm_444 = torch.ops.aten.mm.default(view_1461, permute_851); view_1461 = permute_851 = None + view_1462 = torch.ops.aten.view.default(mm_444, [2, 8192, 4096]); mm_444 = None + convert_element_type_1912 = torch.ops.prims.convert_element_type.default(mm_443, torch.float32); mm_443 = None + reduce_scatter_tensor_141 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1912, 'avg', 64, '0'); convert_element_type_1912 = None + wait_tensor_432 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_141); reduce_scatter_tensor_141 = None + view_1463 = torch.ops.aten.view.default(view_1462, [2, 8192, 32, 128]); view_1462 = None + permute_853 = torch.ops.aten.permute.default(view_1463, [0, 2, 1, 3]); view_1463 = None + convert_element_type_529 = torch.ops.prims.convert_element_type.default(primals_148, torch.bfloat16); primals_148 = None + all_gather_into_tensor_145 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_529, 64, '0'); convert_element_type_529 = None + wait_tensor_145 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_145); all_gather_into_tensor_145 = None + 
convert_element_type_530 = torch.ops.prims.convert_element_type.default(add_63, torch.float32); add_63 = None + pow_33 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_530, 2) + mean_32 = torch.ops.aten.mean.dim(pow_33, [2], True); pow_33 = None + add_64 = torch.ops.aten.add.Scalar(mean_32, 1e-05); mean_32 = None + rsqrt_32 = torch.ops.aten.rsqrt.default(add_64); add_64 = None + mul_128 = torch.ops.aten.mul.Tensor(convert_element_type_530, rsqrt_32); convert_element_type_530 = None + mul_129 = torch.ops.aten.mul.Tensor(mul_128, wait_tensor_145) + convert_element_type_531 = torch.ops.prims.convert_element_type.default(mul_129, torch.bfloat16); mul_129 = None + view_547 = torch.ops.aten.view.default(convert_element_type_531, [16384, 4096]); convert_element_type_531 = None + view_548 = torch.ops.aten.view.default(mm_112, [2, 8192, 4096]); mm_112 = None + convert_element_type_535 = torch.ops.prims.convert_element_type.default(primals_150, torch.bfloat16); primals_150 = None + all_gather_into_tensor_147 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_535, 64, '0'); convert_element_type_535 = None + wait_tensor_147 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_147); all_gather_into_tensor_147 = None + permute_177 = torch.ops.aten.permute.default(wait_tensor_147, [1, 0]); wait_tensor_147 = None + mm_113 = torch.ops.aten.mm.default(view_547, permute_177) + view_551 = torch.ops.aten.view.default(mm_113, [2, 8192, 1024]); mm_113 = None + view_554 = torch.ops.aten.view.default(mm_114, [2, 8192, 1024]); mm_114 = None + view_555 = torch.ops.aten.view.default(view_548, [2, 8192, -1, 128]); view_548 = None + view_556 = torch.ops.aten.view.default(view_551, [2, 8192, -1, 128]); view_551 = None + view_557 = torch.ops.aten.view.default(view_554, [2, 8192, -1, 128]); view_554 = None + convert_element_type_541 = torch.ops.prims.convert_element_type.default(view_555, torch.float32); view_555 = None + view_558 = 
torch.ops.aten.view.default(convert_element_type_541, [2, 8192, 32, -1, 2]); convert_element_type_541 = None + view_as_complex_32 = torch.ops.aten.view_as_complex.default(view_558); view_558 = None + convert_element_type_542 = torch.ops.prims.convert_element_type.default(view_556, torch.float32); view_556 = None + view_559 = torch.ops.aten.view.default(convert_element_type_542, [2, 8192, 8, -1, 2]); convert_element_type_542 = None + view_as_complex_33 = torch.ops.aten.view_as_complex.default(view_559); view_559 = None + mul_130 = torch.ops.aten.mul.Tensor(view_as_complex_32, view_16); view_as_complex_32 = None + view_as_real_32 = torch.ops.aten.view_as_real.default(mul_130); mul_130 = None + view_561 = torch.ops.aten.view.default(view_as_real_32, [2, 8192, 32, 128]); view_as_real_32 = None + mul_131 = torch.ops.aten.mul.Tensor(view_as_complex_33, view_16); view_as_complex_33 = None + view_as_real_33 = torch.ops.aten.view_as_real.default(mul_131); mul_131 = None + view_562 = torch.ops.aten.view.default(view_as_real_33, [2, 8192, 8, 128]); view_as_real_33 = None + convert_element_type_543 = torch.ops.prims.convert_element_type.default(view_561, torch.bfloat16); view_561 = None + convert_element_type_544 = torch.ops.prims.convert_element_type.default(view_562, torch.bfloat16); view_562 = None + unsqueeze_32 = torch.ops.aten.unsqueeze.default(convert_element_type_544, 3); convert_element_type_544 = None + expand_32 = torch.ops.aten.expand.default(unsqueeze_32, [2, 8192, 8, 4, 128]); unsqueeze_32 = None + clone_32 = torch.ops.aten.clone.default(expand_32, memory_format = torch.contiguous_format); expand_32 = None + view_563 = torch.ops.aten.view.default(clone_32, [2, 8192, 32, 128]); clone_32 = None + unsqueeze_33 = torch.ops.aten.unsqueeze.default(view_557, 3); view_557 = None + expand_33 = torch.ops.aten.expand.default(unsqueeze_33, [2, 8192, 8, 4, 128]); unsqueeze_33 = None + clone_33 = torch.ops.aten.clone.default(expand_33, memory_format = torch.contiguous_format); 
expand_33 = None + view_564 = torch.ops.aten.view.default(clone_33, [2, 8192, 32, 128]); clone_33 = None + permute_179 = torch.ops.aten.permute.default(convert_element_type_543, [0, 2, 1, 3]); convert_element_type_543 = None + permute_180 = torch.ops.aten.permute.default(view_563, [0, 2, 1, 3]); view_563 = None + permute_181 = torch.ops.aten.permute.default(view_564, [0, 2, 1, 3]); view_564 = None + _scaled_dot_product_cudnn_attention_backward_15 = torch.ops.aten._scaled_dot_product_cudnn_attention_backward.default(permute_853, permute_179, permute_180, permute_181, getitem_144, getitem_145, getitem_150, getitem_151, None, None, None, 8192, 8192, 0.0, True); permute_853 = permute_179 = permute_180 = permute_181 = getitem_144 = getitem_145 = getitem_150 = getitem_151 = None + getitem_333 = _scaled_dot_product_cudnn_attention_backward_15[0] + getitem_334 = _scaled_dot_product_cudnn_attention_backward_15[1] + getitem_335 = _scaled_dot_product_cudnn_attention_backward_15[2]; _scaled_dot_product_cudnn_attention_backward_15 = None + permute_854 = torch.ops.aten.permute.default(getitem_335, [0, 2, 1, 3]); getitem_335 = None + permute_855 = torch.ops.aten.permute.default(getitem_334, [0, 2, 1, 3]); getitem_334 = None + permute_856 = torch.ops.aten.permute.default(getitem_333, [0, 2, 1, 3]); getitem_333 = None + view_1464 = torch.ops.aten.view.default(permute_854, [2, 8192, 8, 4, 128]); permute_854 = None + sum_95 = torch.ops.aten.sum.dim_IntList(view_1464, [3], True); view_1464 = None + squeeze_30 = torch.ops.aten.squeeze.dim(sum_95, 3); sum_95 = None + view_1465 = torch.ops.aten.view.default(permute_855, [2, 8192, 8, 4, 128]); permute_855 = None + sum_96 = torch.ops.aten.sum.dim_IntList(view_1465, [3], True); view_1465 = None + squeeze_31 = torch.ops.aten.squeeze.dim(sum_96, 3); sum_96 = None + convert_element_type_1913 = torch.ops.prims.convert_element_type.default(squeeze_31, torch.float32); squeeze_31 = None + convert_element_type_1914 = 
torch.ops.prims.convert_element_type.default(permute_856, torch.float32); permute_856 = None + view_1466 = torch.ops.aten.view.default(convert_element_type_1913, [2, 8192, 8, 64, 2]); convert_element_type_1913 = None + view_as_complex_94 = torch.ops.aten.view_as_complex.default(view_1466); view_1466 = None + mul_576 = torch.ops.aten.mul.Tensor(view_as_complex_94, _conj); view_as_complex_94 = None + view_1467 = torch.ops.aten.view.default(convert_element_type_1914, [2, 8192, 32, 64, 2]); convert_element_type_1914 = None + view_as_complex_95 = torch.ops.aten.view_as_complex.default(view_1467); view_1467 = None + mul_577 = torch.ops.aten.mul.Tensor(view_as_complex_95, _conj); view_as_complex_95 = None + view_as_real_94 = torch.ops.aten.view_as_real.default(mul_576); mul_576 = None + view_1468 = torch.ops.aten.view.default(view_as_real_94, [2, 8192, 8, 128]); view_as_real_94 = None + convert_element_type_1915 = torch.ops.prims.convert_element_type.default(view_1468, torch.bfloat16); view_1468 = None + view_as_real_95 = torch.ops.aten.view_as_real.default(mul_577); mul_577 = None + view_1469 = torch.ops.aten.view.default(view_as_real_95, [2, 8192, 32, 128]); view_as_real_95 = None + convert_element_type_1916 = torch.ops.prims.convert_element_type.default(view_1469, torch.bfloat16); view_1469 = None + view_1470 = torch.ops.aten.view.default(squeeze_30, [2, 8192, 1024]); squeeze_30 = None + view_1471 = torch.ops.aten.view.default(convert_element_type_1915, [2, 8192, 1024]); convert_element_type_1915 = None + view_1472 = torch.ops.aten.view.default(convert_element_type_1916, [2, 8192, 4096]); convert_element_type_1916 = None + view_1473 = torch.ops.aten.view.default(view_1470, [16384, 1024]); view_1470 = None + permute_857 = torch.ops.aten.permute.default(view_1473, [1, 0]) + mm_445 = torch.ops.aten.mm.default(permute_857, view_547); permute_857 = None + convert_element_type_538 = torch.ops.prims.convert_element_type.default(primals_151, torch.bfloat16); primals_151 = None 
+ all_gather_into_tensor_148 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_538, 64, '0'); convert_element_type_538 = None + wait_tensor_148 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_148); all_gather_into_tensor_148 = None + permute_178 = torch.ops.aten.permute.default(wait_tensor_148, [1, 0]); wait_tensor_148 = None + permute_859 = torch.ops.aten.permute.default(permute_178, [1, 0]); permute_178 = None + mm_446 = torch.ops.aten.mm.default(view_1473, permute_859); view_1473 = permute_859 = None + view_1474 = torch.ops.aten.view.default(mm_446, [2, 8192, 4096]); mm_446 = None + convert_element_type_1921 = torch.ops.prims.convert_element_type.default(mm_445, torch.float32); mm_445 = None + reduce_scatter_tensor_142 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1921, 'avg', 64, '0'); convert_element_type_1921 = None + wait_tensor_433 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_142); reduce_scatter_tensor_142 = None + view_1475 = torch.ops.aten.view.default(view_1471, [16384, 1024]); view_1471 = None + permute_861 = torch.ops.aten.permute.default(view_1475, [1, 0]) + mm_447 = torch.ops.aten.mm.default(permute_861, view_547); permute_861 = None + permute_863 = torch.ops.aten.permute.default(permute_177, [1, 0]); permute_177 = None + mm_448 = torch.ops.aten.mm.default(view_1475, permute_863); view_1475 = permute_863 = None + view_1476 = torch.ops.aten.view.default(mm_448, [2, 8192, 4096]); mm_448 = None + add_238 = torch.ops.aten.add.Tensor(view_1474, view_1476); view_1474 = view_1476 = None + convert_element_type_1926 = torch.ops.prims.convert_element_type.default(mm_447, torch.float32); mm_447 = None + reduce_scatter_tensor_143 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1926, 'avg', 64, '0'); convert_element_type_1926 = None + wait_tensor_434 = 
torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_143); reduce_scatter_tensor_143 = None + view_1477 = torch.ops.aten.view.default(view_1472, [16384, 4096]); view_1472 = None + permute_865 = torch.ops.aten.permute.default(view_1477, [1, 0]) + mm_449 = torch.ops.aten.mm.default(permute_865, view_547); permute_865 = view_547 = None + convert_element_type_532 = torch.ops.prims.convert_element_type.default(primals_149, torch.bfloat16); primals_149 = None + all_gather_into_tensor_146 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_532, 64, '0'); convert_element_type_532 = None + wait_tensor_146 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_146); all_gather_into_tensor_146 = None + permute_176 = torch.ops.aten.permute.default(wait_tensor_146, [1, 0]); wait_tensor_146 = None + permute_867 = torch.ops.aten.permute.default(permute_176, [1, 0]); permute_176 = None + mm_450 = torch.ops.aten.mm.default(view_1477, permute_867); view_1477 = permute_867 = None + view_1478 = torch.ops.aten.view.default(mm_450, [2, 8192, 4096]); mm_450 = None + add_239 = torch.ops.aten.add.Tensor(add_238, view_1478); add_238 = view_1478 = None + convert_element_type_1931 = torch.ops.prims.convert_element_type.default(mm_449, torch.float32); mm_449 = None + reduce_scatter_tensor_144 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1931, 'avg', 64, '0'); convert_element_type_1931 = None + wait_tensor_435 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_144); reduce_scatter_tensor_144 = None + convert_element_type_1932 = torch.ops.prims.convert_element_type.default(add_239, torch.float32); add_239 = None + convert_element_type_1934 = torch.ops.prims.convert_element_type.default(wait_tensor_145, torch.float32); wait_tensor_145 = None + mul_578 = torch.ops.aten.mul.Tensor(convert_element_type_1932, convert_element_type_1934); convert_element_type_1934 = None + mul_580 = 
torch.ops.aten.mul.Tensor(mul_128, mul_578) + sum_97 = torch.ops.aten.sum.dim_IntList(mul_580, [2], True); mul_580 = None + div_32 = torch.ops.aten.div.Tensor(mul_128, 4096) + mul_581 = torch.ops.aten.mul.Tensor(div_32, sum_97); div_32 = sum_97 = None + sub_48 = torch.ops.aten.sub.Tensor(mul_578, mul_581); mul_578 = mul_581 = None + mul_582 = torch.ops.aten.mul.Tensor(sub_48, rsqrt_32); sub_48 = rsqrt_32 = None + mul_583 = torch.ops.aten.mul.Tensor(convert_element_type_1932, mul_128); convert_element_type_1932 = mul_128 = None + sum_98 = torch.ops.aten.sum.dim_IntList(mul_583, [0, 1]); mul_583 = None + convert_element_type_1935 = torch.ops.prims.convert_element_type.default(mul_582, torch.bfloat16); mul_582 = None + add_240 = torch.ops.aten.add.Tensor(add_237, convert_element_type_1935); add_237 = convert_element_type_1935 = None + convert_element_type_default_33 = torch.ops.prims.convert_element_type.default(sum_98, torch.float32); sum_98 = None + reduce_scatter_tensor_145 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_33, 'avg', 64, '0'); convert_element_type_default_33 = None + wait_tensor_436 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_145); reduce_scatter_tensor_145 = None + view_1479 = torch.ops.aten.view.default(add_240, [16384, 4096]) + permute_869 = torch.ops.aten.permute.default(view_1479, [1, 0]) + permute_171 = torch.ops.aten.permute.default(getitem_135, [0, 2, 1, 3]) + view_531 = torch.ops.aten.view.default(permute_171, [2, 8192, -1]); permute_171 = None + convert_element_type_512 = torch.ops.prims.convert_element_type.default(primals_143, torch.bfloat16); primals_143 = None + all_gather_into_tensor_140 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_512, 64, '0'); convert_element_type_512 = None + wait_tensor_140 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_140); all_gather_into_tensor_140 = None + permute_172 = 
torch.ops.aten.permute.default(wait_tensor_140, [1, 0]); wait_tensor_140 = None + view_533 = torch.ops.aten.view.default(view_531, [16384, 4096]); view_531 = None + mm_108 = torch.ops.aten.mm.default(view_533, permute_172) + view_534 = torch.ops.aten.view.default(mm_108, [2, 8192, 4096]); mm_108 = None + add_61 = torch.ops.aten.add.Tensor(add_59, view_534); view_534 = None + convert_element_type_515 = torch.ops.prims.convert_element_type.default(primals_144, torch.bfloat16); primals_144 = None + all_gather_into_tensor_141 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_515, 64, '0'); convert_element_type_515 = None + wait_tensor_141 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_141); all_gather_into_tensor_141 = None + convert_element_type_516 = torch.ops.prims.convert_element_type.default(add_61, torch.float32); add_61 = None + pow_32 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_516, 2) + mean_31 = torch.ops.aten.mean.dim(pow_32, [2], True); pow_32 = None + add_62 = torch.ops.aten.add.Scalar(mean_31, 1e-05); mean_31 = None + rsqrt_31 = torch.ops.aten.rsqrt.default(add_62); add_62 = None + mul_124 = torch.ops.aten.mul.Tensor(convert_element_type_516, rsqrt_31); convert_element_type_516 = None + mul_125 = torch.ops.aten.mul.Tensor(mul_124, wait_tensor_141) + convert_element_type_517 = torch.ops.prims.convert_element_type.default(mul_125, torch.bfloat16); mul_125 = None + view_537 = torch.ops.aten.view.default(convert_element_type_517, [16384, 4096]); convert_element_type_517 = None + view_538 = torch.ops.aten.view.default(mm_109, [2, 8192, 14336]); mm_109 = None + convert_element_type_521 = torch.ops.prims.convert_element_type.default(view_538, torch.float32); view_538 = None + sigmoid_15 = torch.ops.aten.sigmoid.default(convert_element_type_521) + mul_126 = torch.ops.aten.mul.Tensor(convert_element_type_521, sigmoid_15); sigmoid_15 = None + convert_element_type_522 = 
torch.ops.prims.convert_element_type.default(mul_126, torch.bfloat16); mul_126 = None + convert_element_type_523 = torch.ops.prims.convert_element_type.default(primals_146, torch.bfloat16); primals_146 = None + all_gather_into_tensor_143 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_523, 64, '0'); convert_element_type_523 = None + wait_tensor_143 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_143); all_gather_into_tensor_143 = None + permute_174 = torch.ops.aten.permute.default(wait_tensor_143, [1, 0]); wait_tensor_143 = None + mm_110 = torch.ops.aten.mm.default(view_537, permute_174) + view_541 = torch.ops.aten.view.default(mm_110, [2, 8192, 14336]); mm_110 = None + mul_127 = torch.ops.aten.mul.Tensor(convert_element_type_522, view_541) + view_543 = torch.ops.aten.view.default(mul_127, [16384, 14336]); mul_127 = None + mm_451 = torch.ops.aten.mm.default(permute_869, view_543); permute_869 = view_543 = None + convert_element_type_526 = torch.ops.prims.convert_element_type.default(primals_147, torch.bfloat16); primals_147 = None + all_gather_into_tensor_144 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_526, 64, '0'); convert_element_type_526 = None + wait_tensor_144 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_144); all_gather_into_tensor_144 = None + permute_175 = torch.ops.aten.permute.default(wait_tensor_144, [1, 0]); wait_tensor_144 = None + permute_871 = torch.ops.aten.permute.default(permute_175, [1, 0]); permute_175 = None + mm_452 = torch.ops.aten.mm.default(view_1479, permute_871); view_1479 = permute_871 = None + view_1480 = torch.ops.aten.view.default(mm_452, [2, 8192, 14336]); mm_452 = None + convert_element_type_1942 = torch.ops.prims.convert_element_type.default(mm_451, torch.float32); mm_451 = None + reduce_scatter_tensor_146 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1942, 'avg', 64, '0'); 
convert_element_type_1942 = None + wait_tensor_437 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_146); reduce_scatter_tensor_146 = None + mul_584 = torch.ops.aten.mul.Tensor(view_1480, convert_element_type_522); convert_element_type_522 = None + mul_585 = torch.ops.aten.mul.Tensor(view_1480, view_541); view_1480 = view_541 = None + view_1481 = torch.ops.aten.view.default(mul_584, [16384, 14336]); mul_584 = None + permute_873 = torch.ops.aten.permute.default(view_1481, [1, 0]) + mm_453 = torch.ops.aten.mm.default(permute_873, view_537); permute_873 = None + permute_875 = torch.ops.aten.permute.default(permute_174, [1, 0]); permute_174 = None + mm_454 = torch.ops.aten.mm.default(view_1481, permute_875); view_1481 = permute_875 = None + view_1482 = torch.ops.aten.view.default(mm_454, [2, 8192, 4096]); mm_454 = None + convert_element_type_1947 = torch.ops.prims.convert_element_type.default(mm_453, torch.float32); mm_453 = None + reduce_scatter_tensor_147 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1947, 'avg', 64, '0'); convert_element_type_1947 = None + wait_tensor_438 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_147); reduce_scatter_tensor_147 = None + convert_element_type_1948 = torch.ops.prims.convert_element_type.default(mul_585, torch.float32); mul_585 = None + neg_16 = torch.ops.aten.neg.default(convert_element_type_521) + exp_16 = torch.ops.aten.exp.default(neg_16); neg_16 = None + add_241 = torch.ops.aten.add.Tensor(exp_16, 1); exp_16 = None + reciprocal_16 = torch.ops.aten.reciprocal.default(add_241); add_241 = None + mul_586 = torch.ops.aten.mul.Tensor(reciprocal_16, 1); reciprocal_16 = None + mul_587 = torch.ops.aten.mul.Tensor(convert_element_type_1948, mul_586); convert_element_type_1948 = None + sub_49 = torch.ops.aten.sub.Tensor(1, mul_586); mul_586 = None + mul_588 = torch.ops.aten.mul.Tensor(convert_element_type_521, sub_49); convert_element_type_521 = sub_49 = None 
+ add_242 = torch.ops.aten.add.Tensor(mul_588, 1); mul_588 = None + mul_589 = torch.ops.aten.mul.Tensor(mul_587, add_242); mul_587 = add_242 = None + convert_element_type_1950 = torch.ops.prims.convert_element_type.default(mul_589, torch.bfloat16); mul_589 = None + view_1483 = torch.ops.aten.view.default(convert_element_type_1950, [16384, 14336]); convert_element_type_1950 = None + permute_877 = torch.ops.aten.permute.default(view_1483, [1, 0]) + mm_455 = torch.ops.aten.mm.default(permute_877, view_537); permute_877 = view_537 = None + convert_element_type_518 = torch.ops.prims.convert_element_type.default(primals_145, torch.bfloat16); primals_145 = None + all_gather_into_tensor_142 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_518, 64, '0'); convert_element_type_518 = None + wait_tensor_142 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_142); all_gather_into_tensor_142 = None + permute_173 = torch.ops.aten.permute.default(wait_tensor_142, [1, 0]); wait_tensor_142 = None + permute_879 = torch.ops.aten.permute.default(permute_173, [1, 0]); permute_173 = None + mm_456 = torch.ops.aten.mm.default(view_1483, permute_879); view_1483 = permute_879 = None + view_1484 = torch.ops.aten.view.default(mm_456, [2, 8192, 4096]); mm_456 = None + add_243 = torch.ops.aten.add.Tensor(view_1482, view_1484); view_1482 = view_1484 = None + convert_element_type_1955 = torch.ops.prims.convert_element_type.default(mm_455, torch.float32); mm_455 = None + reduce_scatter_tensor_148 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1955, 'avg', 64, '0'); convert_element_type_1955 = None + wait_tensor_439 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_148); reduce_scatter_tensor_148 = None + convert_element_type_1956 = torch.ops.prims.convert_element_type.default(add_243, torch.float32); add_243 = None + convert_element_type_1958 = 
torch.ops.prims.convert_element_type.default(wait_tensor_141, torch.float32); wait_tensor_141 = None + mul_590 = torch.ops.aten.mul.Tensor(convert_element_type_1956, convert_element_type_1958); convert_element_type_1958 = None + mul_592 = torch.ops.aten.mul.Tensor(mul_124, mul_590) + sum_99 = torch.ops.aten.sum.dim_IntList(mul_592, [2], True); mul_592 = None + div_33 = torch.ops.aten.div.Tensor(mul_124, 4096) + mul_593 = torch.ops.aten.mul.Tensor(div_33, sum_99); div_33 = sum_99 = None + sub_50 = torch.ops.aten.sub.Tensor(mul_590, mul_593); mul_590 = mul_593 = None + mul_594 = torch.ops.aten.mul.Tensor(sub_50, rsqrt_31); sub_50 = rsqrt_31 = None + mul_595 = torch.ops.aten.mul.Tensor(convert_element_type_1956, mul_124); convert_element_type_1956 = mul_124 = None + sum_100 = torch.ops.aten.sum.dim_IntList(mul_595, [0, 1]); mul_595 = None + convert_element_type_1959 = torch.ops.prims.convert_element_type.default(mul_594, torch.bfloat16); mul_594 = None + add_244 = torch.ops.aten.add.Tensor(add_240, convert_element_type_1959); add_240 = convert_element_type_1959 = None + convert_element_type_default_32 = torch.ops.prims.convert_element_type.default(sum_100, torch.float32); sum_100 = None + reduce_scatter_tensor_149 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_32, 'avg', 64, '0'); convert_element_type_default_32 = None + wait_tensor_440 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_149); reduce_scatter_tensor_149 = None + view_1485 = torch.ops.aten.view.default(add_244, [16384, 4096]) + permute_881 = torch.ops.aten.permute.default(view_1485, [1, 0]) + mm_457 = torch.ops.aten.mm.default(permute_881, view_533); permute_881 = view_533 = None + permute_883 = torch.ops.aten.permute.default(permute_172, [1, 0]); permute_172 = None + mm_458 = torch.ops.aten.mm.default(view_1485, permute_883); view_1485 = permute_883 = None + view_1486 = torch.ops.aten.view.default(mm_458, [2, 8192, 4096]); mm_458 = None + 
convert_element_type_1966 = torch.ops.prims.convert_element_type.default(mm_457, torch.float32); mm_457 = None + reduce_scatter_tensor_150 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1966, 'avg', 64, '0'); convert_element_type_1966 = None + wait_tensor_441 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_150); reduce_scatter_tensor_150 = None + view_1487 = torch.ops.aten.view.default(view_1486, [2, 8192, 32, 128]); view_1486 = None + permute_885 = torch.ops.aten.permute.default(view_1487, [0, 2, 1, 3]); view_1487 = None + convert_element_type_496 = torch.ops.prims.convert_element_type.default(primals_139, torch.bfloat16); primals_139 = None + all_gather_into_tensor_136 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_496, 64, '0'); convert_element_type_496 = None + wait_tensor_136 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_136); all_gather_into_tensor_136 = None + convert_element_type_497 = torch.ops.prims.convert_element_type.default(add_59, torch.float32); add_59 = None + pow_31 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_497, 2) + mean_30 = torch.ops.aten.mean.dim(pow_31, [2], True); pow_31 = None + add_60 = torch.ops.aten.add.Scalar(mean_30, 1e-05); mean_30 = None + rsqrt_30 = torch.ops.aten.rsqrt.default(add_60); add_60 = None + mul_120 = torch.ops.aten.mul.Tensor(convert_element_type_497, rsqrt_30); convert_element_type_497 = None + mul_121 = torch.ops.aten.mul.Tensor(mul_120, wait_tensor_136) + convert_element_type_498 = torch.ops.prims.convert_element_type.default(mul_121, torch.bfloat16); mul_121 = None + view_513 = torch.ops.aten.view.default(convert_element_type_498, [16384, 4096]); convert_element_type_498 = None + view_514 = torch.ops.aten.view.default(mm_105, [2, 8192, 4096]); mm_105 = None + convert_element_type_502 = torch.ops.prims.convert_element_type.default(primals_141, torch.bfloat16); primals_141 = None + 
all_gather_into_tensor_138 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_502, 64, '0'); convert_element_type_502 = None + wait_tensor_138 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_138); all_gather_into_tensor_138 = None + permute_166 = torch.ops.aten.permute.default(wait_tensor_138, [1, 0]); wait_tensor_138 = None + mm_106 = torch.ops.aten.mm.default(view_513, permute_166) + view_517 = torch.ops.aten.view.default(mm_106, [2, 8192, 1024]); mm_106 = None + view_520 = torch.ops.aten.view.default(mm_107, [2, 8192, 1024]); mm_107 = None + view_521 = torch.ops.aten.view.default(view_514, [2, 8192, -1, 128]); view_514 = None + view_522 = torch.ops.aten.view.default(view_517, [2, 8192, -1, 128]); view_517 = None + view_523 = torch.ops.aten.view.default(view_520, [2, 8192, -1, 128]); view_520 = None + convert_element_type_508 = torch.ops.prims.convert_element_type.default(view_521, torch.float32); view_521 = None + view_524 = torch.ops.aten.view.default(convert_element_type_508, [2, 8192, 32, -1, 2]); convert_element_type_508 = None + view_as_complex_30 = torch.ops.aten.view_as_complex.default(view_524); view_524 = None + convert_element_type_509 = torch.ops.prims.convert_element_type.default(view_522, torch.float32); view_522 = None + view_525 = torch.ops.aten.view.default(convert_element_type_509, [2, 8192, 8, -1, 2]); convert_element_type_509 = None + view_as_complex_31 = torch.ops.aten.view_as_complex.default(view_525); view_525 = None + mul_122 = torch.ops.aten.mul.Tensor(view_as_complex_30, view_16); view_as_complex_30 = None + view_as_real_30 = torch.ops.aten.view_as_real.default(mul_122); mul_122 = None + view_527 = torch.ops.aten.view.default(view_as_real_30, [2, 8192, 32, 128]); view_as_real_30 = None + mul_123 = torch.ops.aten.mul.Tensor(view_as_complex_31, view_16); view_as_complex_31 = None + view_as_real_31 = torch.ops.aten.view_as_real.default(mul_123); mul_123 = None + view_528 = 
torch.ops.aten.view.default(view_as_real_31, [2, 8192, 8, 128]); view_as_real_31 = None + convert_element_type_510 = torch.ops.prims.convert_element_type.default(view_527, torch.bfloat16); view_527 = None + convert_element_type_511 = torch.ops.prims.convert_element_type.default(view_528, torch.bfloat16); view_528 = None + unsqueeze_30 = torch.ops.aten.unsqueeze.default(convert_element_type_511, 3); convert_element_type_511 = None + expand_30 = torch.ops.aten.expand.default(unsqueeze_30, [2, 8192, 8, 4, 128]); unsqueeze_30 = None + clone_30 = torch.ops.aten.clone.default(expand_30, memory_format = torch.contiguous_format); expand_30 = None + view_529 = torch.ops.aten.view.default(clone_30, [2, 8192, 32, 128]); clone_30 = None + unsqueeze_31 = torch.ops.aten.unsqueeze.default(view_523, 3); view_523 = None + expand_31 = torch.ops.aten.expand.default(unsqueeze_31, [2, 8192, 8, 4, 128]); unsqueeze_31 = None + clone_31 = torch.ops.aten.clone.default(expand_31, memory_format = torch.contiguous_format); expand_31 = None + view_530 = torch.ops.aten.view.default(clone_31, [2, 8192, 32, 128]); clone_31 = None + permute_168 = torch.ops.aten.permute.default(convert_element_type_510, [0, 2, 1, 3]); convert_element_type_510 = None + permute_169 = torch.ops.aten.permute.default(view_529, [0, 2, 1, 3]); view_529 = None + permute_170 = torch.ops.aten.permute.default(view_530, [0, 2, 1, 3]); view_530 = None + _scaled_dot_product_cudnn_attention_backward_16 = torch.ops.aten._scaled_dot_product_cudnn_attention_backward.default(permute_885, permute_168, permute_169, permute_170, getitem_135, getitem_136, getitem_141, getitem_142, None, None, None, 8192, 8192, 0.0, True); permute_885 = permute_168 = permute_169 = permute_170 = getitem_135 = getitem_136 = getitem_141 = getitem_142 = None + getitem_336 = _scaled_dot_product_cudnn_attention_backward_16[0] + getitem_337 = _scaled_dot_product_cudnn_attention_backward_16[1] + getitem_338 = _scaled_dot_product_cudnn_attention_backward_16[2]; 
_scaled_dot_product_cudnn_attention_backward_16 = None + permute_886 = torch.ops.aten.permute.default(getitem_338, [0, 2, 1, 3]); getitem_338 = None + permute_887 = torch.ops.aten.permute.default(getitem_337, [0, 2, 1, 3]); getitem_337 = None + permute_888 = torch.ops.aten.permute.default(getitem_336, [0, 2, 1, 3]); getitem_336 = None + view_1488 = torch.ops.aten.view.default(permute_886, [2, 8192, 8, 4, 128]); permute_886 = None + sum_101 = torch.ops.aten.sum.dim_IntList(view_1488, [3], True); view_1488 = None + squeeze_32 = torch.ops.aten.squeeze.dim(sum_101, 3); sum_101 = None + view_1489 = torch.ops.aten.view.default(permute_887, [2, 8192, 8, 4, 128]); permute_887 = None + sum_102 = torch.ops.aten.sum.dim_IntList(view_1489, [3], True); view_1489 = None + squeeze_33 = torch.ops.aten.squeeze.dim(sum_102, 3); sum_102 = None + convert_element_type_1967 = torch.ops.prims.convert_element_type.default(squeeze_33, torch.float32); squeeze_33 = None + convert_element_type_1968 = torch.ops.prims.convert_element_type.default(permute_888, torch.float32); permute_888 = None + view_1490 = torch.ops.aten.view.default(convert_element_type_1967, [2, 8192, 8, 64, 2]); convert_element_type_1967 = None + view_as_complex_96 = torch.ops.aten.view_as_complex.default(view_1490); view_1490 = None + mul_596 = torch.ops.aten.mul.Tensor(view_as_complex_96, _conj); view_as_complex_96 = None + view_1491 = torch.ops.aten.view.default(convert_element_type_1968, [2, 8192, 32, 64, 2]); convert_element_type_1968 = None + view_as_complex_97 = torch.ops.aten.view_as_complex.default(view_1491); view_1491 = None + mul_597 = torch.ops.aten.mul.Tensor(view_as_complex_97, _conj); view_as_complex_97 = None + view_as_real_96 = torch.ops.aten.view_as_real.default(mul_596); mul_596 = None + view_1492 = torch.ops.aten.view.default(view_as_real_96, [2, 8192, 8, 128]); view_as_real_96 = None + convert_element_type_1969 = torch.ops.prims.convert_element_type.default(view_1492, torch.bfloat16); view_1492 = None 
+ view_as_real_97 = torch.ops.aten.view_as_real.default(mul_597); mul_597 = None + view_1493 = torch.ops.aten.view.default(view_as_real_97, [2, 8192, 32, 128]); view_as_real_97 = None + convert_element_type_1970 = torch.ops.prims.convert_element_type.default(view_1493, torch.bfloat16); view_1493 = None + view_1494 = torch.ops.aten.view.default(squeeze_32, [2, 8192, 1024]); squeeze_32 = None + view_1495 = torch.ops.aten.view.default(convert_element_type_1969, [2, 8192, 1024]); convert_element_type_1969 = None + view_1496 = torch.ops.aten.view.default(convert_element_type_1970, [2, 8192, 4096]); convert_element_type_1970 = None + view_1497 = torch.ops.aten.view.default(view_1494, [16384, 1024]); view_1494 = None + permute_889 = torch.ops.aten.permute.default(view_1497, [1, 0]) + mm_459 = torch.ops.aten.mm.default(permute_889, view_513); permute_889 = None + convert_element_type_505 = torch.ops.prims.convert_element_type.default(primals_142, torch.bfloat16); primals_142 = None + all_gather_into_tensor_139 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_505, 64, '0'); convert_element_type_505 = None + wait_tensor_139 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_139); all_gather_into_tensor_139 = None + permute_167 = torch.ops.aten.permute.default(wait_tensor_139, [1, 0]); wait_tensor_139 = None + permute_891 = torch.ops.aten.permute.default(permute_167, [1, 0]); permute_167 = None + mm_460 = torch.ops.aten.mm.default(view_1497, permute_891); view_1497 = permute_891 = None + view_1498 = torch.ops.aten.view.default(mm_460, [2, 8192, 4096]); mm_460 = None + convert_element_type_1975 = torch.ops.prims.convert_element_type.default(mm_459, torch.float32); mm_459 = None + reduce_scatter_tensor_151 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1975, 'avg', 64, '0'); convert_element_type_1975 = None + wait_tensor_442 = 
torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_151); reduce_scatter_tensor_151 = None + view_1499 = torch.ops.aten.view.default(view_1495, [16384, 1024]); view_1495 = None + permute_893 = torch.ops.aten.permute.default(view_1499, [1, 0]) + mm_461 = torch.ops.aten.mm.default(permute_893, view_513); permute_893 = None + permute_895 = torch.ops.aten.permute.default(permute_166, [1, 0]); permute_166 = None + mm_462 = torch.ops.aten.mm.default(view_1499, permute_895); view_1499 = permute_895 = None + view_1500 = torch.ops.aten.view.default(mm_462, [2, 8192, 4096]); mm_462 = None + add_245 = torch.ops.aten.add.Tensor(view_1498, view_1500); view_1498 = view_1500 = None + convert_element_type_1980 = torch.ops.prims.convert_element_type.default(mm_461, torch.float32); mm_461 = None + reduce_scatter_tensor_152 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1980, 'avg', 64, '0'); convert_element_type_1980 = None + wait_tensor_443 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_152); reduce_scatter_tensor_152 = None + view_1501 = torch.ops.aten.view.default(view_1496, [16384, 4096]); view_1496 = None + permute_897 = torch.ops.aten.permute.default(view_1501, [1, 0]) + mm_463 = torch.ops.aten.mm.default(permute_897, view_513); permute_897 = view_513 = None + convert_element_type_499 = torch.ops.prims.convert_element_type.default(primals_140, torch.bfloat16); primals_140 = None + all_gather_into_tensor_137 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_499, 64, '0'); convert_element_type_499 = None + wait_tensor_137 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_137); all_gather_into_tensor_137 = None + permute_165 = torch.ops.aten.permute.default(wait_tensor_137, [1, 0]); wait_tensor_137 = None + permute_899 = torch.ops.aten.permute.default(permute_165, [1, 0]); permute_165 = None + mm_464 = torch.ops.aten.mm.default(view_1501, permute_899); 
view_1501 = permute_899 = None + view_1502 = torch.ops.aten.view.default(mm_464, [2, 8192, 4096]); mm_464 = None + add_246 = torch.ops.aten.add.Tensor(add_245, view_1502); add_245 = view_1502 = None + convert_element_type_1985 = torch.ops.prims.convert_element_type.default(mm_463, torch.float32); mm_463 = None + reduce_scatter_tensor_153 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1985, 'avg', 64, '0'); convert_element_type_1985 = None + wait_tensor_444 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_153); reduce_scatter_tensor_153 = None + convert_element_type_1986 = torch.ops.prims.convert_element_type.default(add_246, torch.float32); add_246 = None + convert_element_type_1988 = torch.ops.prims.convert_element_type.default(wait_tensor_136, torch.float32); wait_tensor_136 = None + mul_598 = torch.ops.aten.mul.Tensor(convert_element_type_1986, convert_element_type_1988); convert_element_type_1988 = None + mul_600 = torch.ops.aten.mul.Tensor(mul_120, mul_598) + sum_103 = torch.ops.aten.sum.dim_IntList(mul_600, [2], True); mul_600 = None + div_34 = torch.ops.aten.div.Tensor(mul_120, 4096) + mul_601 = torch.ops.aten.mul.Tensor(div_34, sum_103); div_34 = sum_103 = None + sub_51 = torch.ops.aten.sub.Tensor(mul_598, mul_601); mul_598 = mul_601 = None + mul_602 = torch.ops.aten.mul.Tensor(sub_51, rsqrt_30); sub_51 = rsqrt_30 = None + mul_603 = torch.ops.aten.mul.Tensor(convert_element_type_1986, mul_120); convert_element_type_1986 = mul_120 = None + sum_104 = torch.ops.aten.sum.dim_IntList(mul_603, [0, 1]); mul_603 = None + convert_element_type_1989 = torch.ops.prims.convert_element_type.default(mul_602, torch.bfloat16); mul_602 = None + add_247 = torch.ops.aten.add.Tensor(add_244, convert_element_type_1989); add_244 = convert_element_type_1989 = None + convert_element_type_default_31 = torch.ops.prims.convert_element_type.default(sum_104, torch.float32); sum_104 = None + reduce_scatter_tensor_154 = 
torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_31, 'avg', 64, '0'); convert_element_type_default_31 = None + wait_tensor_445 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_154); reduce_scatter_tensor_154 = None + view_1503 = torch.ops.aten.view.default(add_247, [16384, 4096]) + permute_901 = torch.ops.aten.permute.default(view_1503, [1, 0]) + permute_160 = torch.ops.aten.permute.default(getitem_126, [0, 2, 1, 3]) + view_497 = torch.ops.aten.view.default(permute_160, [2, 8192, -1]); permute_160 = None + convert_element_type_479 = torch.ops.prims.convert_element_type.default(primals_134, torch.bfloat16); primals_134 = None + all_gather_into_tensor_131 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_479, 64, '0'); convert_element_type_479 = None + wait_tensor_131 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_131); all_gather_into_tensor_131 = None + permute_161 = torch.ops.aten.permute.default(wait_tensor_131, [1, 0]); wait_tensor_131 = None + view_499 = torch.ops.aten.view.default(view_497, [16384, 4096]); view_497 = None + mm_101 = torch.ops.aten.mm.default(view_499, permute_161) + view_500 = torch.ops.aten.view.default(mm_101, [2, 8192, 4096]); mm_101 = None + add_57 = torch.ops.aten.add.Tensor(add_55, view_500); view_500 = None + convert_element_type_482 = torch.ops.prims.convert_element_type.default(primals_135, torch.bfloat16); primals_135 = None + all_gather_into_tensor_132 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_482, 64, '0'); convert_element_type_482 = None + wait_tensor_132 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_132); all_gather_into_tensor_132 = None + convert_element_type_483 = torch.ops.prims.convert_element_type.default(add_57, torch.float32); add_57 = None + pow_30 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_483, 2) + mean_29 = 
torch.ops.aten.mean.dim(pow_30, [2], True); pow_30 = None + add_58 = torch.ops.aten.add.Scalar(mean_29, 1e-05); mean_29 = None + rsqrt_29 = torch.ops.aten.rsqrt.default(add_58); add_58 = None + mul_116 = torch.ops.aten.mul.Tensor(convert_element_type_483, rsqrt_29); convert_element_type_483 = None + mul_117 = torch.ops.aten.mul.Tensor(mul_116, wait_tensor_132) + convert_element_type_484 = torch.ops.prims.convert_element_type.default(mul_117, torch.bfloat16); mul_117 = None + view_503 = torch.ops.aten.view.default(convert_element_type_484, [16384, 4096]); convert_element_type_484 = None + view_504 = torch.ops.aten.view.default(mm_102, [2, 8192, 14336]); mm_102 = None + convert_element_type_488 = torch.ops.prims.convert_element_type.default(view_504, torch.float32); view_504 = None + sigmoid_14 = torch.ops.aten.sigmoid.default(convert_element_type_488) + mul_118 = torch.ops.aten.mul.Tensor(convert_element_type_488, sigmoid_14); sigmoid_14 = None + convert_element_type_489 = torch.ops.prims.convert_element_type.default(mul_118, torch.bfloat16); mul_118 = None + convert_element_type_490 = torch.ops.prims.convert_element_type.default(primals_137, torch.bfloat16); primals_137 = None + all_gather_into_tensor_134 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_490, 64, '0'); convert_element_type_490 = None + wait_tensor_134 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_134); all_gather_into_tensor_134 = None + permute_163 = torch.ops.aten.permute.default(wait_tensor_134, [1, 0]); wait_tensor_134 = None + mm_103 = torch.ops.aten.mm.default(view_503, permute_163) + view_507 = torch.ops.aten.view.default(mm_103, [2, 8192, 14336]); mm_103 = None + mul_119 = torch.ops.aten.mul.Tensor(convert_element_type_489, view_507) + view_509 = torch.ops.aten.view.default(mul_119, [16384, 14336]); mul_119 = None + mm_465 = torch.ops.aten.mm.default(permute_901, view_509); permute_901 = view_509 = None + convert_element_type_493 = 
torch.ops.prims.convert_element_type.default(primals_138, torch.bfloat16); primals_138 = None + all_gather_into_tensor_135 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_493, 64, '0'); convert_element_type_493 = None + wait_tensor_135 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_135); all_gather_into_tensor_135 = None + permute_164 = torch.ops.aten.permute.default(wait_tensor_135, [1, 0]); wait_tensor_135 = None + permute_903 = torch.ops.aten.permute.default(permute_164, [1, 0]); permute_164 = None + mm_466 = torch.ops.aten.mm.default(view_1503, permute_903); view_1503 = permute_903 = None + view_1504 = torch.ops.aten.view.default(mm_466, [2, 8192, 14336]); mm_466 = None + convert_element_type_1996 = torch.ops.prims.convert_element_type.default(mm_465, torch.float32); mm_465 = None + reduce_scatter_tensor_155 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1996, 'avg', 64, '0'); convert_element_type_1996 = None + wait_tensor_446 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_155); reduce_scatter_tensor_155 = None + mul_604 = torch.ops.aten.mul.Tensor(view_1504, convert_element_type_489); convert_element_type_489 = None + mul_605 = torch.ops.aten.mul.Tensor(view_1504, view_507); view_1504 = view_507 = None + view_1505 = torch.ops.aten.view.default(mul_604, [16384, 14336]); mul_604 = None + permute_905 = torch.ops.aten.permute.default(view_1505, [1, 0]) + mm_467 = torch.ops.aten.mm.default(permute_905, view_503); permute_905 = None + permute_907 = torch.ops.aten.permute.default(permute_163, [1, 0]); permute_163 = None + mm_468 = torch.ops.aten.mm.default(view_1505, permute_907); view_1505 = permute_907 = None + view_1506 = torch.ops.aten.view.default(mm_468, [2, 8192, 4096]); mm_468 = None + convert_element_type_2001 = torch.ops.prims.convert_element_type.default(mm_467, torch.float32); mm_467 = None + reduce_scatter_tensor_156 = 
torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2001, 'avg', 64, '0'); convert_element_type_2001 = None + wait_tensor_447 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_156); reduce_scatter_tensor_156 = None + convert_element_type_2002 = torch.ops.prims.convert_element_type.default(mul_605, torch.float32); mul_605 = None + neg_17 = torch.ops.aten.neg.default(convert_element_type_488) + exp_17 = torch.ops.aten.exp.default(neg_17); neg_17 = None + add_248 = torch.ops.aten.add.Tensor(exp_17, 1); exp_17 = None + reciprocal_17 = torch.ops.aten.reciprocal.default(add_248); add_248 = None + mul_606 = torch.ops.aten.mul.Tensor(reciprocal_17, 1); reciprocal_17 = None + mul_607 = torch.ops.aten.mul.Tensor(convert_element_type_2002, mul_606); convert_element_type_2002 = None + sub_52 = torch.ops.aten.sub.Tensor(1, mul_606); mul_606 = None + mul_608 = torch.ops.aten.mul.Tensor(convert_element_type_488, sub_52); convert_element_type_488 = sub_52 = None + add_249 = torch.ops.aten.add.Tensor(mul_608, 1); mul_608 = None + mul_609 = torch.ops.aten.mul.Tensor(mul_607, add_249); mul_607 = add_249 = None + convert_element_type_2004 = torch.ops.prims.convert_element_type.default(mul_609, torch.bfloat16); mul_609 = None + view_1507 = torch.ops.aten.view.default(convert_element_type_2004, [16384, 14336]); convert_element_type_2004 = None + permute_909 = torch.ops.aten.permute.default(view_1507, [1, 0]) + mm_469 = torch.ops.aten.mm.default(permute_909, view_503); permute_909 = view_503 = None + convert_element_type_485 = torch.ops.prims.convert_element_type.default(primals_136, torch.bfloat16); primals_136 = None + all_gather_into_tensor_133 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_485, 64, '0'); convert_element_type_485 = None + wait_tensor_133 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_133); all_gather_into_tensor_133 = None + permute_162 = 
torch.ops.aten.permute.default(wait_tensor_133, [1, 0]); wait_tensor_133 = None + permute_911 = torch.ops.aten.permute.default(permute_162, [1, 0]); permute_162 = None + mm_470 = torch.ops.aten.mm.default(view_1507, permute_911); view_1507 = permute_911 = None + view_1508 = torch.ops.aten.view.default(mm_470, [2, 8192, 4096]); mm_470 = None + add_250 = torch.ops.aten.add.Tensor(view_1506, view_1508); view_1506 = view_1508 = None + convert_element_type_2009 = torch.ops.prims.convert_element_type.default(mm_469, torch.float32); mm_469 = None + reduce_scatter_tensor_157 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2009, 'avg', 64, '0'); convert_element_type_2009 = None + wait_tensor_448 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_157); reduce_scatter_tensor_157 = None + convert_element_type_2010 = torch.ops.prims.convert_element_type.default(add_250, torch.float32); add_250 = None + convert_element_type_2012 = torch.ops.prims.convert_element_type.default(wait_tensor_132, torch.float32); wait_tensor_132 = None + mul_610 = torch.ops.aten.mul.Tensor(convert_element_type_2010, convert_element_type_2012); convert_element_type_2012 = None + mul_612 = torch.ops.aten.mul.Tensor(mul_116, mul_610) + sum_105 = torch.ops.aten.sum.dim_IntList(mul_612, [2], True); mul_612 = None + div_35 = torch.ops.aten.div.Tensor(mul_116, 4096) + mul_613 = torch.ops.aten.mul.Tensor(div_35, sum_105); div_35 = sum_105 = None + sub_53 = torch.ops.aten.sub.Tensor(mul_610, mul_613); mul_610 = mul_613 = None + mul_614 = torch.ops.aten.mul.Tensor(sub_53, rsqrt_29); sub_53 = rsqrt_29 = None + mul_615 = torch.ops.aten.mul.Tensor(convert_element_type_2010, mul_116); convert_element_type_2010 = mul_116 = None + sum_106 = torch.ops.aten.sum.dim_IntList(mul_615, [0, 1]); mul_615 = None + convert_element_type_2013 = torch.ops.prims.convert_element_type.default(mul_614, torch.bfloat16); mul_614 = None + add_251 = torch.ops.aten.add.Tensor(add_247, 
convert_element_type_2013); add_247 = convert_element_type_2013 = None + convert_element_type_default_30 = torch.ops.prims.convert_element_type.default(sum_106, torch.float32); sum_106 = None + reduce_scatter_tensor_158 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_30, 'avg', 64, '0'); convert_element_type_default_30 = None + wait_tensor_449 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_158); reduce_scatter_tensor_158 = None + view_1509 = torch.ops.aten.view.default(add_251, [16384, 4096]) + permute_913 = torch.ops.aten.permute.default(view_1509, [1, 0]) + mm_471 = torch.ops.aten.mm.default(permute_913, view_499); permute_913 = view_499 = None + permute_915 = torch.ops.aten.permute.default(permute_161, [1, 0]); permute_161 = None + mm_472 = torch.ops.aten.mm.default(view_1509, permute_915); view_1509 = permute_915 = None + view_1510 = torch.ops.aten.view.default(mm_472, [2, 8192, 4096]); mm_472 = None + convert_element_type_2020 = torch.ops.prims.convert_element_type.default(mm_471, torch.float32); mm_471 = None + reduce_scatter_tensor_159 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2020, 'avg', 64, '0'); convert_element_type_2020 = None + wait_tensor_450 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_159); reduce_scatter_tensor_159 = None + view_1511 = torch.ops.aten.view.default(view_1510, [2, 8192, 32, 128]); view_1510 = None + permute_917 = torch.ops.aten.permute.default(view_1511, [0, 2, 1, 3]); view_1511 = None + convert_element_type_463 = torch.ops.prims.convert_element_type.default(primals_130, torch.bfloat16); primals_130 = None + all_gather_into_tensor_127 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_463, 64, '0'); convert_element_type_463 = None + wait_tensor_127 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_127); all_gather_into_tensor_127 = None + 
convert_element_type_464 = torch.ops.prims.convert_element_type.default(add_55, torch.float32); add_55 = None + pow_29 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_464, 2) + mean_28 = torch.ops.aten.mean.dim(pow_29, [2], True); pow_29 = None + add_56 = torch.ops.aten.add.Scalar(mean_28, 1e-05); mean_28 = None + rsqrt_28 = torch.ops.aten.rsqrt.default(add_56); add_56 = None + mul_112 = torch.ops.aten.mul.Tensor(convert_element_type_464, rsqrt_28); convert_element_type_464 = None + mul_113 = torch.ops.aten.mul.Tensor(mul_112, wait_tensor_127) + convert_element_type_465 = torch.ops.prims.convert_element_type.default(mul_113, torch.bfloat16); mul_113 = None + view_479 = torch.ops.aten.view.default(convert_element_type_465, [16384, 4096]); convert_element_type_465 = None + view_480 = torch.ops.aten.view.default(mm_98, [2, 8192, 4096]); mm_98 = None + convert_element_type_469 = torch.ops.prims.convert_element_type.default(primals_132, torch.bfloat16); primals_132 = None + all_gather_into_tensor_129 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_469, 64, '0'); convert_element_type_469 = None + wait_tensor_129 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_129); all_gather_into_tensor_129 = None + permute_155 = torch.ops.aten.permute.default(wait_tensor_129, [1, 0]); wait_tensor_129 = None + mm_99 = torch.ops.aten.mm.default(view_479, permute_155) + view_483 = torch.ops.aten.view.default(mm_99, [2, 8192, 1024]); mm_99 = None + view_486 = torch.ops.aten.view.default(mm_100, [2, 8192, 1024]); mm_100 = None + view_487 = torch.ops.aten.view.default(view_480, [2, 8192, -1, 128]); view_480 = None + view_488 = torch.ops.aten.view.default(view_483, [2, 8192, -1, 128]); view_483 = None + view_489 = torch.ops.aten.view.default(view_486, [2, 8192, -1, 128]); view_486 = None + convert_element_type_475 = torch.ops.prims.convert_element_type.default(view_487, torch.float32); view_487 = None + view_490 = 
torch.ops.aten.view.default(convert_element_type_475, [2, 8192, 32, -1, 2]); convert_element_type_475 = None + view_as_complex_28 = torch.ops.aten.view_as_complex.default(view_490); view_490 = None + convert_element_type_476 = torch.ops.prims.convert_element_type.default(view_488, torch.float32); view_488 = None + view_491 = torch.ops.aten.view.default(convert_element_type_476, [2, 8192, 8, -1, 2]); convert_element_type_476 = None + view_as_complex_29 = torch.ops.aten.view_as_complex.default(view_491); view_491 = None + mul_114 = torch.ops.aten.mul.Tensor(view_as_complex_28, view_16); view_as_complex_28 = None + view_as_real_28 = torch.ops.aten.view_as_real.default(mul_114); mul_114 = None + view_493 = torch.ops.aten.view.default(view_as_real_28, [2, 8192, 32, 128]); view_as_real_28 = None + mul_115 = torch.ops.aten.mul.Tensor(view_as_complex_29, view_16); view_as_complex_29 = None + view_as_real_29 = torch.ops.aten.view_as_real.default(mul_115); mul_115 = None + view_494 = torch.ops.aten.view.default(view_as_real_29, [2, 8192, 8, 128]); view_as_real_29 = None + convert_element_type_477 = torch.ops.prims.convert_element_type.default(view_493, torch.bfloat16); view_493 = None + convert_element_type_478 = torch.ops.prims.convert_element_type.default(view_494, torch.bfloat16); view_494 = None + unsqueeze_28 = torch.ops.aten.unsqueeze.default(convert_element_type_478, 3); convert_element_type_478 = None + expand_28 = torch.ops.aten.expand.default(unsqueeze_28, [2, 8192, 8, 4, 128]); unsqueeze_28 = None + clone_28 = torch.ops.aten.clone.default(expand_28, memory_format = torch.contiguous_format); expand_28 = None + view_495 = torch.ops.aten.view.default(clone_28, [2, 8192, 32, 128]); clone_28 = None + unsqueeze_29 = torch.ops.aten.unsqueeze.default(view_489, 3); view_489 = None + expand_29 = torch.ops.aten.expand.default(unsqueeze_29, [2, 8192, 8, 4, 128]); unsqueeze_29 = None + clone_29 = torch.ops.aten.clone.default(expand_29, memory_format = torch.contiguous_format); 
expand_29 = None + view_496 = torch.ops.aten.view.default(clone_29, [2, 8192, 32, 128]); clone_29 = None + permute_157 = torch.ops.aten.permute.default(convert_element_type_477, [0, 2, 1, 3]); convert_element_type_477 = None + permute_158 = torch.ops.aten.permute.default(view_495, [0, 2, 1, 3]); view_495 = None + permute_159 = torch.ops.aten.permute.default(view_496, [0, 2, 1, 3]); view_496 = None + _scaled_dot_product_cudnn_attention_backward_17 = torch.ops.aten._scaled_dot_product_cudnn_attention_backward.default(permute_917, permute_157, permute_158, permute_159, getitem_126, getitem_127, getitem_132, getitem_133, None, None, None, 8192, 8192, 0.0, True); permute_917 = permute_157 = permute_158 = permute_159 = getitem_126 = getitem_127 = getitem_132 = getitem_133 = None + getitem_339 = _scaled_dot_product_cudnn_attention_backward_17[0] + getitem_340 = _scaled_dot_product_cudnn_attention_backward_17[1] + getitem_341 = _scaled_dot_product_cudnn_attention_backward_17[2]; _scaled_dot_product_cudnn_attention_backward_17 = None + permute_918 = torch.ops.aten.permute.default(getitem_341, [0, 2, 1, 3]); getitem_341 = None + permute_919 = torch.ops.aten.permute.default(getitem_340, [0, 2, 1, 3]); getitem_340 = None + permute_920 = torch.ops.aten.permute.default(getitem_339, [0, 2, 1, 3]); getitem_339 = None + view_1512 = torch.ops.aten.view.default(permute_918, [2, 8192, 8, 4, 128]); permute_918 = None + sum_107 = torch.ops.aten.sum.dim_IntList(view_1512, [3], True); view_1512 = None + squeeze_34 = torch.ops.aten.squeeze.dim(sum_107, 3); sum_107 = None + view_1513 = torch.ops.aten.view.default(permute_919, [2, 8192, 8, 4, 128]); permute_919 = None + sum_108 = torch.ops.aten.sum.dim_IntList(view_1513, [3], True); view_1513 = None + squeeze_35 = torch.ops.aten.squeeze.dim(sum_108, 3); sum_108 = None + convert_element_type_2021 = torch.ops.prims.convert_element_type.default(squeeze_35, torch.float32); squeeze_35 = None + convert_element_type_2022 = 
torch.ops.prims.convert_element_type.default(permute_920, torch.float32); permute_920 = None + view_1514 = torch.ops.aten.view.default(convert_element_type_2021, [2, 8192, 8, 64, 2]); convert_element_type_2021 = None + view_as_complex_98 = torch.ops.aten.view_as_complex.default(view_1514); view_1514 = None + mul_616 = torch.ops.aten.mul.Tensor(view_as_complex_98, _conj); view_as_complex_98 = None + view_1515 = torch.ops.aten.view.default(convert_element_type_2022, [2, 8192, 32, 64, 2]); convert_element_type_2022 = None + view_as_complex_99 = torch.ops.aten.view_as_complex.default(view_1515); view_1515 = None + mul_617 = torch.ops.aten.mul.Tensor(view_as_complex_99, _conj); view_as_complex_99 = None + view_as_real_98 = torch.ops.aten.view_as_real.default(mul_616); mul_616 = None + view_1516 = torch.ops.aten.view.default(view_as_real_98, [2, 8192, 8, 128]); view_as_real_98 = None + convert_element_type_2023 = torch.ops.prims.convert_element_type.default(view_1516, torch.bfloat16); view_1516 = None + view_as_real_99 = torch.ops.aten.view_as_real.default(mul_617); mul_617 = None + view_1517 = torch.ops.aten.view.default(view_as_real_99, [2, 8192, 32, 128]); view_as_real_99 = None + convert_element_type_2024 = torch.ops.prims.convert_element_type.default(view_1517, torch.bfloat16); view_1517 = None + view_1518 = torch.ops.aten.view.default(squeeze_34, [2, 8192, 1024]); squeeze_34 = None + view_1519 = torch.ops.aten.view.default(convert_element_type_2023, [2, 8192, 1024]); convert_element_type_2023 = None + view_1520 = torch.ops.aten.view.default(convert_element_type_2024, [2, 8192, 4096]); convert_element_type_2024 = None + view_1521 = torch.ops.aten.view.default(view_1518, [16384, 1024]); view_1518 = None + permute_921 = torch.ops.aten.permute.default(view_1521, [1, 0]) + mm_473 = torch.ops.aten.mm.default(permute_921, view_479); permute_921 = None + convert_element_type_472 = torch.ops.prims.convert_element_type.default(primals_133, torch.bfloat16); primals_133 = None 
+ all_gather_into_tensor_130 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_472, 64, '0'); convert_element_type_472 = None + wait_tensor_130 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_130); all_gather_into_tensor_130 = None + permute_156 = torch.ops.aten.permute.default(wait_tensor_130, [1, 0]); wait_tensor_130 = None + permute_923 = torch.ops.aten.permute.default(permute_156, [1, 0]); permute_156 = None + mm_474 = torch.ops.aten.mm.default(view_1521, permute_923); view_1521 = permute_923 = None + view_1522 = torch.ops.aten.view.default(mm_474, [2, 8192, 4096]); mm_474 = None + convert_element_type_2029 = torch.ops.prims.convert_element_type.default(mm_473, torch.float32); mm_473 = None + reduce_scatter_tensor_160 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2029, 'avg', 64, '0'); convert_element_type_2029 = None + wait_tensor_451 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_160); reduce_scatter_tensor_160 = None + view_1523 = torch.ops.aten.view.default(view_1519, [16384, 1024]); view_1519 = None + permute_925 = torch.ops.aten.permute.default(view_1523, [1, 0]) + mm_475 = torch.ops.aten.mm.default(permute_925, view_479); permute_925 = None + permute_927 = torch.ops.aten.permute.default(permute_155, [1, 0]); permute_155 = None + mm_476 = torch.ops.aten.mm.default(view_1523, permute_927); view_1523 = permute_927 = None + view_1524 = torch.ops.aten.view.default(mm_476, [2, 8192, 4096]); mm_476 = None + add_252 = torch.ops.aten.add.Tensor(view_1522, view_1524); view_1522 = view_1524 = None + convert_element_type_2034 = torch.ops.prims.convert_element_type.default(mm_475, torch.float32); mm_475 = None + reduce_scatter_tensor_161 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2034, 'avg', 64, '0'); convert_element_type_2034 = None + wait_tensor_452 = 
torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_161); reduce_scatter_tensor_161 = None + view_1525 = torch.ops.aten.view.default(view_1520, [16384, 4096]); view_1520 = None + permute_929 = torch.ops.aten.permute.default(view_1525, [1, 0]) + mm_477 = torch.ops.aten.mm.default(permute_929, view_479); permute_929 = view_479 = None + convert_element_type_466 = torch.ops.prims.convert_element_type.default(primals_131, torch.bfloat16); primals_131 = None + all_gather_into_tensor_128 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_466, 64, '0'); convert_element_type_466 = None + wait_tensor_128 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_128); all_gather_into_tensor_128 = None + permute_154 = torch.ops.aten.permute.default(wait_tensor_128, [1, 0]); wait_tensor_128 = None + permute_931 = torch.ops.aten.permute.default(permute_154, [1, 0]); permute_154 = None + mm_478 = torch.ops.aten.mm.default(view_1525, permute_931); view_1525 = permute_931 = None + view_1526 = torch.ops.aten.view.default(mm_478, [2, 8192, 4096]); mm_478 = None + add_253 = torch.ops.aten.add.Tensor(add_252, view_1526); add_252 = view_1526 = None + convert_element_type_2039 = torch.ops.prims.convert_element_type.default(mm_477, torch.float32); mm_477 = None + reduce_scatter_tensor_162 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2039, 'avg', 64, '0'); convert_element_type_2039 = None + wait_tensor_453 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_162); reduce_scatter_tensor_162 = None + convert_element_type_2040 = torch.ops.prims.convert_element_type.default(add_253, torch.float32); add_253 = None + convert_element_type_2042 = torch.ops.prims.convert_element_type.default(wait_tensor_127, torch.float32); wait_tensor_127 = None + mul_618 = torch.ops.aten.mul.Tensor(convert_element_type_2040, convert_element_type_2042); convert_element_type_2042 = None + mul_620 = 
torch.ops.aten.mul.Tensor(mul_112, mul_618) + sum_109 = torch.ops.aten.sum.dim_IntList(mul_620, [2], True); mul_620 = None + div_36 = torch.ops.aten.div.Tensor(mul_112, 4096) + mul_621 = torch.ops.aten.mul.Tensor(div_36, sum_109); div_36 = sum_109 = None + sub_54 = torch.ops.aten.sub.Tensor(mul_618, mul_621); mul_618 = mul_621 = None + mul_622 = torch.ops.aten.mul.Tensor(sub_54, rsqrt_28); sub_54 = rsqrt_28 = None + mul_623 = torch.ops.aten.mul.Tensor(convert_element_type_2040, mul_112); convert_element_type_2040 = mul_112 = None + sum_110 = torch.ops.aten.sum.dim_IntList(mul_623, [0, 1]); mul_623 = None + convert_element_type_2043 = torch.ops.prims.convert_element_type.default(mul_622, torch.bfloat16); mul_622 = None + add_254 = torch.ops.aten.add.Tensor(add_251, convert_element_type_2043); add_251 = convert_element_type_2043 = None + convert_element_type_default_29 = torch.ops.prims.convert_element_type.default(sum_110, torch.float32); sum_110 = None + reduce_scatter_tensor_163 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_29, 'avg', 64, '0'); convert_element_type_default_29 = None + wait_tensor_454 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_163); reduce_scatter_tensor_163 = None + view_1527 = torch.ops.aten.view.default(add_254, [16384, 4096]) + permute_933 = torch.ops.aten.permute.default(view_1527, [1, 0]) + permute_149 = torch.ops.aten.permute.default(getitem_117, [0, 2, 1, 3]) + view_463 = torch.ops.aten.view.default(permute_149, [2, 8192, -1]); permute_149 = None + convert_element_type_446 = torch.ops.prims.convert_element_type.default(primals_125, torch.bfloat16); primals_125 = None + all_gather_into_tensor_122 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_446, 64, '0'); convert_element_type_446 = None + wait_tensor_122 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_122); all_gather_into_tensor_122 = None + permute_150 = 
torch.ops.aten.permute.default(wait_tensor_122, [1, 0]); wait_tensor_122 = None + view_465 = torch.ops.aten.view.default(view_463, [16384, 4096]); view_463 = None + mm_94 = torch.ops.aten.mm.default(view_465, permute_150) + view_466 = torch.ops.aten.view.default(mm_94, [2, 8192, 4096]); mm_94 = None + add_53 = torch.ops.aten.add.Tensor(add_51, view_466); view_466 = None + convert_element_type_449 = torch.ops.prims.convert_element_type.default(primals_126, torch.bfloat16); primals_126 = None + all_gather_into_tensor_123 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_449, 64, '0'); convert_element_type_449 = None + wait_tensor_123 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_123); all_gather_into_tensor_123 = None + convert_element_type_450 = torch.ops.prims.convert_element_type.default(add_53, torch.float32); add_53 = None + pow_28 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_450, 2) + mean_27 = torch.ops.aten.mean.dim(pow_28, [2], True); pow_28 = None + add_54 = torch.ops.aten.add.Scalar(mean_27, 1e-05); mean_27 = None + rsqrt_27 = torch.ops.aten.rsqrt.default(add_54); add_54 = None + mul_108 = torch.ops.aten.mul.Tensor(convert_element_type_450, rsqrt_27); convert_element_type_450 = None + mul_109 = torch.ops.aten.mul.Tensor(mul_108, wait_tensor_123) + convert_element_type_451 = torch.ops.prims.convert_element_type.default(mul_109, torch.bfloat16); mul_109 = None + view_469 = torch.ops.aten.view.default(convert_element_type_451, [16384, 4096]); convert_element_type_451 = None + view_470 = torch.ops.aten.view.default(mm_95, [2, 8192, 14336]); mm_95 = None + convert_element_type_455 = torch.ops.prims.convert_element_type.default(view_470, torch.float32); view_470 = None + sigmoid_13 = torch.ops.aten.sigmoid.default(convert_element_type_455) + mul_110 = torch.ops.aten.mul.Tensor(convert_element_type_455, sigmoid_13); sigmoid_13 = None + convert_element_type_456 = 
torch.ops.prims.convert_element_type.default(mul_110, torch.bfloat16); mul_110 = None + convert_element_type_457 = torch.ops.prims.convert_element_type.default(primals_128, torch.bfloat16); primals_128 = None + all_gather_into_tensor_125 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_457, 64, '0'); convert_element_type_457 = None + wait_tensor_125 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_125); all_gather_into_tensor_125 = None + permute_152 = torch.ops.aten.permute.default(wait_tensor_125, [1, 0]); wait_tensor_125 = None + mm_96 = torch.ops.aten.mm.default(view_469, permute_152) + view_473 = torch.ops.aten.view.default(mm_96, [2, 8192, 14336]); mm_96 = None + mul_111 = torch.ops.aten.mul.Tensor(convert_element_type_456, view_473) + view_475 = torch.ops.aten.view.default(mul_111, [16384, 14336]); mul_111 = None + mm_479 = torch.ops.aten.mm.default(permute_933, view_475); permute_933 = view_475 = None + convert_element_type_460 = torch.ops.prims.convert_element_type.default(primals_129, torch.bfloat16); primals_129 = None + all_gather_into_tensor_126 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_460, 64, '0'); convert_element_type_460 = None + wait_tensor_126 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_126); all_gather_into_tensor_126 = None + permute_153 = torch.ops.aten.permute.default(wait_tensor_126, [1, 0]); wait_tensor_126 = None + permute_935 = torch.ops.aten.permute.default(permute_153, [1, 0]); permute_153 = None + mm_480 = torch.ops.aten.mm.default(view_1527, permute_935); view_1527 = permute_935 = None + view_1528 = torch.ops.aten.view.default(mm_480, [2, 8192, 14336]); mm_480 = None + convert_element_type_2050 = torch.ops.prims.convert_element_type.default(mm_479, torch.float32); mm_479 = None + reduce_scatter_tensor_164 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2050, 'avg', 64, '0'); 
convert_element_type_2050 = None + wait_tensor_455 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_164); reduce_scatter_tensor_164 = None + mul_624 = torch.ops.aten.mul.Tensor(view_1528, convert_element_type_456); convert_element_type_456 = None + mul_625 = torch.ops.aten.mul.Tensor(view_1528, view_473); view_1528 = view_473 = None + view_1529 = torch.ops.aten.view.default(mul_624, [16384, 14336]); mul_624 = None + permute_937 = torch.ops.aten.permute.default(view_1529, [1, 0]) + mm_481 = torch.ops.aten.mm.default(permute_937, view_469); permute_937 = None + permute_939 = torch.ops.aten.permute.default(permute_152, [1, 0]); permute_152 = None + mm_482 = torch.ops.aten.mm.default(view_1529, permute_939); view_1529 = permute_939 = None + view_1530 = torch.ops.aten.view.default(mm_482, [2, 8192, 4096]); mm_482 = None + convert_element_type_2055 = torch.ops.prims.convert_element_type.default(mm_481, torch.float32); mm_481 = None + reduce_scatter_tensor_165 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2055, 'avg', 64, '0'); convert_element_type_2055 = None + wait_tensor_456 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_165); reduce_scatter_tensor_165 = None + convert_element_type_2056 = torch.ops.prims.convert_element_type.default(mul_625, torch.float32); mul_625 = None + neg_18 = torch.ops.aten.neg.default(convert_element_type_455) + exp_18 = torch.ops.aten.exp.default(neg_18); neg_18 = None + add_255 = torch.ops.aten.add.Tensor(exp_18, 1); exp_18 = None + reciprocal_18 = torch.ops.aten.reciprocal.default(add_255); add_255 = None + mul_626 = torch.ops.aten.mul.Tensor(reciprocal_18, 1); reciprocal_18 = None + mul_627 = torch.ops.aten.mul.Tensor(convert_element_type_2056, mul_626); convert_element_type_2056 = None + sub_55 = torch.ops.aten.sub.Tensor(1, mul_626); mul_626 = None + mul_628 = torch.ops.aten.mul.Tensor(convert_element_type_455, sub_55); convert_element_type_455 = sub_55 = None 
+ add_256 = torch.ops.aten.add.Tensor(mul_628, 1); mul_628 = None + mul_629 = torch.ops.aten.mul.Tensor(mul_627, add_256); mul_627 = add_256 = None + convert_element_type_2058 = torch.ops.prims.convert_element_type.default(mul_629, torch.bfloat16); mul_629 = None + view_1531 = torch.ops.aten.view.default(convert_element_type_2058, [16384, 14336]); convert_element_type_2058 = None + permute_941 = torch.ops.aten.permute.default(view_1531, [1, 0]) + mm_483 = torch.ops.aten.mm.default(permute_941, view_469); permute_941 = view_469 = None + convert_element_type_452 = torch.ops.prims.convert_element_type.default(primals_127, torch.bfloat16); primals_127 = None + all_gather_into_tensor_124 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_452, 64, '0'); convert_element_type_452 = None + wait_tensor_124 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_124); all_gather_into_tensor_124 = None + permute_151 = torch.ops.aten.permute.default(wait_tensor_124, [1, 0]); wait_tensor_124 = None + permute_943 = torch.ops.aten.permute.default(permute_151, [1, 0]); permute_151 = None + mm_484 = torch.ops.aten.mm.default(view_1531, permute_943); view_1531 = permute_943 = None + view_1532 = torch.ops.aten.view.default(mm_484, [2, 8192, 4096]); mm_484 = None + add_257 = torch.ops.aten.add.Tensor(view_1530, view_1532); view_1530 = view_1532 = None + convert_element_type_2063 = torch.ops.prims.convert_element_type.default(mm_483, torch.float32); mm_483 = None + reduce_scatter_tensor_166 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2063, 'avg', 64, '0'); convert_element_type_2063 = None + wait_tensor_457 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_166); reduce_scatter_tensor_166 = None + convert_element_type_2064 = torch.ops.prims.convert_element_type.default(add_257, torch.float32); add_257 = None + convert_element_type_2066 = 
torch.ops.prims.convert_element_type.default(wait_tensor_123, torch.float32); wait_tensor_123 = None + mul_630 = torch.ops.aten.mul.Tensor(convert_element_type_2064, convert_element_type_2066); convert_element_type_2066 = None + mul_632 = torch.ops.aten.mul.Tensor(mul_108, mul_630) + sum_111 = torch.ops.aten.sum.dim_IntList(mul_632, [2], True); mul_632 = None + div_37 = torch.ops.aten.div.Tensor(mul_108, 4096) + mul_633 = torch.ops.aten.mul.Tensor(div_37, sum_111); div_37 = sum_111 = None + sub_56 = torch.ops.aten.sub.Tensor(mul_630, mul_633); mul_630 = mul_633 = None + mul_634 = torch.ops.aten.mul.Tensor(sub_56, rsqrt_27); sub_56 = rsqrt_27 = None + mul_635 = torch.ops.aten.mul.Tensor(convert_element_type_2064, mul_108); convert_element_type_2064 = mul_108 = None + sum_112 = torch.ops.aten.sum.dim_IntList(mul_635, [0, 1]); mul_635 = None + convert_element_type_2067 = torch.ops.prims.convert_element_type.default(mul_634, torch.bfloat16); mul_634 = None + add_258 = torch.ops.aten.add.Tensor(add_254, convert_element_type_2067); add_254 = convert_element_type_2067 = None + convert_element_type_default_28 = torch.ops.prims.convert_element_type.default(sum_112, torch.float32); sum_112 = None + reduce_scatter_tensor_167 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_28, 'avg', 64, '0'); convert_element_type_default_28 = None + wait_tensor_458 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_167); reduce_scatter_tensor_167 = None + view_1533 = torch.ops.aten.view.default(add_258, [16384, 4096]) + permute_945 = torch.ops.aten.permute.default(view_1533, [1, 0]) + mm_485 = torch.ops.aten.mm.default(permute_945, view_465); permute_945 = view_465 = None + permute_947 = torch.ops.aten.permute.default(permute_150, [1, 0]); permute_150 = None + mm_486 = torch.ops.aten.mm.default(view_1533, permute_947); view_1533 = permute_947 = None + view_1534 = torch.ops.aten.view.default(mm_486, [2, 8192, 4096]); mm_486 = None + 
convert_element_type_2074 = torch.ops.prims.convert_element_type.default(mm_485, torch.float32); mm_485 = None + reduce_scatter_tensor_168 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2074, 'avg', 64, '0'); convert_element_type_2074 = None + wait_tensor_459 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_168); reduce_scatter_tensor_168 = None + view_1535 = torch.ops.aten.view.default(view_1534, [2, 8192, 32, 128]); view_1534 = None + permute_949 = torch.ops.aten.permute.default(view_1535, [0, 2, 1, 3]); view_1535 = None + convert_element_type_430 = torch.ops.prims.convert_element_type.default(primals_121, torch.bfloat16); primals_121 = None + all_gather_into_tensor_118 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_430, 64, '0'); convert_element_type_430 = None + wait_tensor_118 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_118); all_gather_into_tensor_118 = None + convert_element_type_431 = torch.ops.prims.convert_element_type.default(add_51, torch.float32); add_51 = None + pow_27 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_431, 2) + mean_26 = torch.ops.aten.mean.dim(pow_27, [2], True); pow_27 = None + add_52 = torch.ops.aten.add.Scalar(mean_26, 1e-05); mean_26 = None + rsqrt_26 = torch.ops.aten.rsqrt.default(add_52); add_52 = None + mul_104 = torch.ops.aten.mul.Tensor(convert_element_type_431, rsqrt_26); convert_element_type_431 = None + mul_105 = torch.ops.aten.mul.Tensor(mul_104, wait_tensor_118) + convert_element_type_432 = torch.ops.prims.convert_element_type.default(mul_105, torch.bfloat16); mul_105 = None + view_445 = torch.ops.aten.view.default(convert_element_type_432, [16384, 4096]); convert_element_type_432 = None + view_446 = torch.ops.aten.view.default(mm_91, [2, 8192, 4096]); mm_91 = None + convert_element_type_436 = torch.ops.prims.convert_element_type.default(primals_123, torch.bfloat16); primals_123 = None + 
all_gather_into_tensor_120 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_436, 64, '0'); convert_element_type_436 = None + wait_tensor_120 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_120); all_gather_into_tensor_120 = None + permute_144 = torch.ops.aten.permute.default(wait_tensor_120, [1, 0]); wait_tensor_120 = None + mm_92 = torch.ops.aten.mm.default(view_445, permute_144) + view_449 = torch.ops.aten.view.default(mm_92, [2, 8192, 1024]); mm_92 = None + view_452 = torch.ops.aten.view.default(mm_93, [2, 8192, 1024]); mm_93 = None + view_453 = torch.ops.aten.view.default(view_446, [2, 8192, -1, 128]); view_446 = None + view_454 = torch.ops.aten.view.default(view_449, [2, 8192, -1, 128]); view_449 = None + view_455 = torch.ops.aten.view.default(view_452, [2, 8192, -1, 128]); view_452 = None + convert_element_type_442 = torch.ops.prims.convert_element_type.default(view_453, torch.float32); view_453 = None + view_456 = torch.ops.aten.view.default(convert_element_type_442, [2, 8192, 32, -1, 2]); convert_element_type_442 = None + view_as_complex_26 = torch.ops.aten.view_as_complex.default(view_456); view_456 = None + convert_element_type_443 = torch.ops.prims.convert_element_type.default(view_454, torch.float32); view_454 = None + view_457 = torch.ops.aten.view.default(convert_element_type_443, [2, 8192, 8, -1, 2]); convert_element_type_443 = None + view_as_complex_27 = torch.ops.aten.view_as_complex.default(view_457); view_457 = None + mul_106 = torch.ops.aten.mul.Tensor(view_as_complex_26, view_16); view_as_complex_26 = None + view_as_real_26 = torch.ops.aten.view_as_real.default(mul_106); mul_106 = None + view_459 = torch.ops.aten.view.default(view_as_real_26, [2, 8192, 32, 128]); view_as_real_26 = None + mul_107 = torch.ops.aten.mul.Tensor(view_as_complex_27, view_16); view_as_complex_27 = None + view_as_real_27 = torch.ops.aten.view_as_real.default(mul_107); mul_107 = None + view_460 = 
torch.ops.aten.view.default(view_as_real_27, [2, 8192, 8, 128]); view_as_real_27 = None + convert_element_type_444 = torch.ops.prims.convert_element_type.default(view_459, torch.bfloat16); view_459 = None + convert_element_type_445 = torch.ops.prims.convert_element_type.default(view_460, torch.bfloat16); view_460 = None + unsqueeze_26 = torch.ops.aten.unsqueeze.default(convert_element_type_445, 3); convert_element_type_445 = None + expand_26 = torch.ops.aten.expand.default(unsqueeze_26, [2, 8192, 8, 4, 128]); unsqueeze_26 = None + clone_26 = torch.ops.aten.clone.default(expand_26, memory_format = torch.contiguous_format); expand_26 = None + view_461 = torch.ops.aten.view.default(clone_26, [2, 8192, 32, 128]); clone_26 = None + unsqueeze_27 = torch.ops.aten.unsqueeze.default(view_455, 3); view_455 = None + expand_27 = torch.ops.aten.expand.default(unsqueeze_27, [2, 8192, 8, 4, 128]); unsqueeze_27 = None + clone_27 = torch.ops.aten.clone.default(expand_27, memory_format = torch.contiguous_format); expand_27 = None + view_462 = torch.ops.aten.view.default(clone_27, [2, 8192, 32, 128]); clone_27 = None + permute_146 = torch.ops.aten.permute.default(convert_element_type_444, [0, 2, 1, 3]); convert_element_type_444 = None + permute_147 = torch.ops.aten.permute.default(view_461, [0, 2, 1, 3]); view_461 = None + permute_148 = torch.ops.aten.permute.default(view_462, [0, 2, 1, 3]); view_462 = None + _scaled_dot_product_cudnn_attention_backward_18 = torch.ops.aten._scaled_dot_product_cudnn_attention_backward.default(permute_949, permute_146, permute_147, permute_148, getitem_117, getitem_118, getitem_123, getitem_124, None, None, None, 8192, 8192, 0.0, True); permute_949 = permute_146 = permute_147 = permute_148 = getitem_117 = getitem_118 = getitem_123 = getitem_124 = None + getitem_342 = _scaled_dot_product_cudnn_attention_backward_18[0] + getitem_343 = _scaled_dot_product_cudnn_attention_backward_18[1] + getitem_344 = _scaled_dot_product_cudnn_attention_backward_18[2]; 
_scaled_dot_product_cudnn_attention_backward_18 = None + permute_950 = torch.ops.aten.permute.default(getitem_344, [0, 2, 1, 3]); getitem_344 = None + permute_951 = torch.ops.aten.permute.default(getitem_343, [0, 2, 1, 3]); getitem_343 = None + permute_952 = torch.ops.aten.permute.default(getitem_342, [0, 2, 1, 3]); getitem_342 = None + view_1536 = torch.ops.aten.view.default(permute_950, [2, 8192, 8, 4, 128]); permute_950 = None + sum_113 = torch.ops.aten.sum.dim_IntList(view_1536, [3], True); view_1536 = None + squeeze_36 = torch.ops.aten.squeeze.dim(sum_113, 3); sum_113 = None + view_1537 = torch.ops.aten.view.default(permute_951, [2, 8192, 8, 4, 128]); permute_951 = None + sum_114 = torch.ops.aten.sum.dim_IntList(view_1537, [3], True); view_1537 = None + squeeze_37 = torch.ops.aten.squeeze.dim(sum_114, 3); sum_114 = None + convert_element_type_2075 = torch.ops.prims.convert_element_type.default(squeeze_37, torch.float32); squeeze_37 = None + convert_element_type_2076 = torch.ops.prims.convert_element_type.default(permute_952, torch.float32); permute_952 = None + view_1538 = torch.ops.aten.view.default(convert_element_type_2075, [2, 8192, 8, 64, 2]); convert_element_type_2075 = None + view_as_complex_100 = torch.ops.aten.view_as_complex.default(view_1538); view_1538 = None + mul_636 = torch.ops.aten.mul.Tensor(view_as_complex_100, _conj); view_as_complex_100 = None + view_1539 = torch.ops.aten.view.default(convert_element_type_2076, [2, 8192, 32, 64, 2]); convert_element_type_2076 = None + view_as_complex_101 = torch.ops.aten.view_as_complex.default(view_1539); view_1539 = None + mul_637 = torch.ops.aten.mul.Tensor(view_as_complex_101, _conj); view_as_complex_101 = None + view_as_real_100 = torch.ops.aten.view_as_real.default(mul_636); mul_636 = None + view_1540 = torch.ops.aten.view.default(view_as_real_100, [2, 8192, 8, 128]); view_as_real_100 = None + convert_element_type_2077 = torch.ops.prims.convert_element_type.default(view_1540, torch.bfloat16); 
view_1540 = None + view_as_real_101 = torch.ops.aten.view_as_real.default(mul_637); mul_637 = None + view_1541 = torch.ops.aten.view.default(view_as_real_101, [2, 8192, 32, 128]); view_as_real_101 = None + convert_element_type_2078 = torch.ops.prims.convert_element_type.default(view_1541, torch.bfloat16); view_1541 = None + view_1542 = torch.ops.aten.view.default(squeeze_36, [2, 8192, 1024]); squeeze_36 = None + view_1543 = torch.ops.aten.view.default(convert_element_type_2077, [2, 8192, 1024]); convert_element_type_2077 = None + view_1544 = torch.ops.aten.view.default(convert_element_type_2078, [2, 8192, 4096]); convert_element_type_2078 = None + view_1545 = torch.ops.aten.view.default(view_1542, [16384, 1024]); view_1542 = None + permute_953 = torch.ops.aten.permute.default(view_1545, [1, 0]) + mm_487 = torch.ops.aten.mm.default(permute_953, view_445); permute_953 = None + convert_element_type_439 = torch.ops.prims.convert_element_type.default(primals_124, torch.bfloat16); primals_124 = None + all_gather_into_tensor_121 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_439, 64, '0'); convert_element_type_439 = None + wait_tensor_121 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_121); all_gather_into_tensor_121 = None + permute_145 = torch.ops.aten.permute.default(wait_tensor_121, [1, 0]); wait_tensor_121 = None + permute_955 = torch.ops.aten.permute.default(permute_145, [1, 0]); permute_145 = None + mm_488 = torch.ops.aten.mm.default(view_1545, permute_955); view_1545 = permute_955 = None + view_1546 = torch.ops.aten.view.default(mm_488, [2, 8192, 4096]); mm_488 = None + convert_element_type_2083 = torch.ops.prims.convert_element_type.default(mm_487, torch.float32); mm_487 = None + reduce_scatter_tensor_169 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2083, 'avg', 64, '0'); convert_element_type_2083 = None + wait_tensor_460 = 
torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_169); reduce_scatter_tensor_169 = None + view_1547 = torch.ops.aten.view.default(view_1543, [16384, 1024]); view_1543 = None + permute_957 = torch.ops.aten.permute.default(view_1547, [1, 0]) + mm_489 = torch.ops.aten.mm.default(permute_957, view_445); permute_957 = None + permute_959 = torch.ops.aten.permute.default(permute_144, [1, 0]); permute_144 = None + mm_490 = torch.ops.aten.mm.default(view_1547, permute_959); view_1547 = permute_959 = None + view_1548 = torch.ops.aten.view.default(mm_490, [2, 8192, 4096]); mm_490 = None + add_259 = torch.ops.aten.add.Tensor(view_1546, view_1548); view_1546 = view_1548 = None + convert_element_type_2088 = torch.ops.prims.convert_element_type.default(mm_489, torch.float32); mm_489 = None + reduce_scatter_tensor_170 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2088, 'avg', 64, '0'); convert_element_type_2088 = None + wait_tensor_461 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_170); reduce_scatter_tensor_170 = None + view_1549 = torch.ops.aten.view.default(view_1544, [16384, 4096]); view_1544 = None + permute_961 = torch.ops.aten.permute.default(view_1549, [1, 0]) + mm_491 = torch.ops.aten.mm.default(permute_961, view_445); permute_961 = view_445 = None + convert_element_type_433 = torch.ops.prims.convert_element_type.default(primals_122, torch.bfloat16); primals_122 = None + all_gather_into_tensor_119 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_433, 64, '0'); convert_element_type_433 = None + wait_tensor_119 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_119); all_gather_into_tensor_119 = None + permute_143 = torch.ops.aten.permute.default(wait_tensor_119, [1, 0]); wait_tensor_119 = None + permute_963 = torch.ops.aten.permute.default(permute_143, [1, 0]); permute_143 = None + mm_492 = torch.ops.aten.mm.default(view_1549, permute_963); 
view_1549 = permute_963 = None + view_1550 = torch.ops.aten.view.default(mm_492, [2, 8192, 4096]); mm_492 = None + add_260 = torch.ops.aten.add.Tensor(add_259, view_1550); add_259 = view_1550 = None + convert_element_type_2093 = torch.ops.prims.convert_element_type.default(mm_491, torch.float32); mm_491 = None + reduce_scatter_tensor_171 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2093, 'avg', 64, '0'); convert_element_type_2093 = None + wait_tensor_462 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_171); reduce_scatter_tensor_171 = None + convert_element_type_2094 = torch.ops.prims.convert_element_type.default(add_260, torch.float32); add_260 = None + convert_element_type_2096 = torch.ops.prims.convert_element_type.default(wait_tensor_118, torch.float32); wait_tensor_118 = None + mul_638 = torch.ops.aten.mul.Tensor(convert_element_type_2094, convert_element_type_2096); convert_element_type_2096 = None + mul_640 = torch.ops.aten.mul.Tensor(mul_104, mul_638) + sum_115 = torch.ops.aten.sum.dim_IntList(mul_640, [2], True); mul_640 = None + div_38 = torch.ops.aten.div.Tensor(mul_104, 4096) + mul_641 = torch.ops.aten.mul.Tensor(div_38, sum_115); div_38 = sum_115 = None + sub_57 = torch.ops.aten.sub.Tensor(mul_638, mul_641); mul_638 = mul_641 = None + mul_642 = torch.ops.aten.mul.Tensor(sub_57, rsqrt_26); sub_57 = rsqrt_26 = None + mul_643 = torch.ops.aten.mul.Tensor(convert_element_type_2094, mul_104); convert_element_type_2094 = mul_104 = None + sum_116 = torch.ops.aten.sum.dim_IntList(mul_643, [0, 1]); mul_643 = None + convert_element_type_2097 = torch.ops.prims.convert_element_type.default(mul_642, torch.bfloat16); mul_642 = None + add_261 = torch.ops.aten.add.Tensor(add_258, convert_element_type_2097); add_258 = convert_element_type_2097 = None + convert_element_type_default_27 = torch.ops.prims.convert_element_type.default(sum_116, torch.float32); sum_116 = None + reduce_scatter_tensor_172 = 
torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_27, 'avg', 64, '0'); convert_element_type_default_27 = None + wait_tensor_463 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_172); reduce_scatter_tensor_172 = None + view_1551 = torch.ops.aten.view.default(add_261, [16384, 4096]) + permute_965 = torch.ops.aten.permute.default(view_1551, [1, 0]) + permute_138 = torch.ops.aten.permute.default(getitem_108, [0, 2, 1, 3]) + view_429 = torch.ops.aten.view.default(permute_138, [2, 8192, -1]); permute_138 = None + convert_element_type_413 = torch.ops.prims.convert_element_type.default(primals_116, torch.bfloat16); primals_116 = None + all_gather_into_tensor_113 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_413, 64, '0'); convert_element_type_413 = None + wait_tensor_113 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_113); all_gather_into_tensor_113 = None + permute_139 = torch.ops.aten.permute.default(wait_tensor_113, [1, 0]); wait_tensor_113 = None + view_431 = torch.ops.aten.view.default(view_429, [16384, 4096]); view_429 = None + mm_87 = torch.ops.aten.mm.default(view_431, permute_139) + view_432 = torch.ops.aten.view.default(mm_87, [2, 8192, 4096]); mm_87 = None + add_49 = torch.ops.aten.add.Tensor(add_47, view_432); view_432 = None + convert_element_type_416 = torch.ops.prims.convert_element_type.default(primals_117, torch.bfloat16); primals_117 = None + all_gather_into_tensor_114 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_416, 64, '0'); convert_element_type_416 = None + wait_tensor_114 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_114); all_gather_into_tensor_114 = None + convert_element_type_417 = torch.ops.prims.convert_element_type.default(add_49, torch.float32); add_49 = None + pow_26 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_417, 2) + mean_25 = 
torch.ops.aten.mean.dim(pow_26, [2], True); pow_26 = None + add_50 = torch.ops.aten.add.Scalar(mean_25, 1e-05); mean_25 = None + rsqrt_25 = torch.ops.aten.rsqrt.default(add_50); add_50 = None + mul_100 = torch.ops.aten.mul.Tensor(convert_element_type_417, rsqrt_25); convert_element_type_417 = None + mul_101 = torch.ops.aten.mul.Tensor(mul_100, wait_tensor_114) + convert_element_type_418 = torch.ops.prims.convert_element_type.default(mul_101, torch.bfloat16); mul_101 = None + view_435 = torch.ops.aten.view.default(convert_element_type_418, [16384, 4096]); convert_element_type_418 = None + view_436 = torch.ops.aten.view.default(mm_88, [2, 8192, 14336]); mm_88 = None + convert_element_type_422 = torch.ops.prims.convert_element_type.default(view_436, torch.float32); view_436 = None + sigmoid_12 = torch.ops.aten.sigmoid.default(convert_element_type_422) + mul_102 = torch.ops.aten.mul.Tensor(convert_element_type_422, sigmoid_12); sigmoid_12 = None + convert_element_type_423 = torch.ops.prims.convert_element_type.default(mul_102, torch.bfloat16); mul_102 = None + convert_element_type_424 = torch.ops.prims.convert_element_type.default(primals_119, torch.bfloat16); primals_119 = None + all_gather_into_tensor_116 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_424, 64, '0'); convert_element_type_424 = None + wait_tensor_116 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_116); all_gather_into_tensor_116 = None + permute_141 = torch.ops.aten.permute.default(wait_tensor_116, [1, 0]); wait_tensor_116 = None + mm_89 = torch.ops.aten.mm.default(view_435, permute_141) + view_439 = torch.ops.aten.view.default(mm_89, [2, 8192, 14336]); mm_89 = None + mul_103 = torch.ops.aten.mul.Tensor(convert_element_type_423, view_439) + view_441 = torch.ops.aten.view.default(mul_103, [16384, 14336]); mul_103 = None + mm_493 = torch.ops.aten.mm.default(permute_965, view_441); permute_965 = view_441 = None + convert_element_type_427 = 
torch.ops.prims.convert_element_type.default(primals_120, torch.bfloat16); primals_120 = None + all_gather_into_tensor_117 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_427, 64, '0'); convert_element_type_427 = None + wait_tensor_117 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_117); all_gather_into_tensor_117 = None + permute_142 = torch.ops.aten.permute.default(wait_tensor_117, [1, 0]); wait_tensor_117 = None + permute_967 = torch.ops.aten.permute.default(permute_142, [1, 0]); permute_142 = None + mm_494 = torch.ops.aten.mm.default(view_1551, permute_967); view_1551 = permute_967 = None + view_1552 = torch.ops.aten.view.default(mm_494, [2, 8192, 14336]); mm_494 = None + convert_element_type_2104 = torch.ops.prims.convert_element_type.default(mm_493, torch.float32); mm_493 = None + reduce_scatter_tensor_173 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2104, 'avg', 64, '0'); convert_element_type_2104 = None + wait_tensor_464 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_173); reduce_scatter_tensor_173 = None + mul_644 = torch.ops.aten.mul.Tensor(view_1552, convert_element_type_423); convert_element_type_423 = None + mul_645 = torch.ops.aten.mul.Tensor(view_1552, view_439); view_1552 = view_439 = None + view_1553 = torch.ops.aten.view.default(mul_644, [16384, 14336]); mul_644 = None + permute_969 = torch.ops.aten.permute.default(view_1553, [1, 0]) + mm_495 = torch.ops.aten.mm.default(permute_969, view_435); permute_969 = None + permute_971 = torch.ops.aten.permute.default(permute_141, [1, 0]); permute_141 = None + mm_496 = torch.ops.aten.mm.default(view_1553, permute_971); view_1553 = permute_971 = None + view_1554 = torch.ops.aten.view.default(mm_496, [2, 8192, 4096]); mm_496 = None + convert_element_type_2109 = torch.ops.prims.convert_element_type.default(mm_495, torch.float32); mm_495 = None + reduce_scatter_tensor_174 = 
torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2109, 'avg', 64, '0'); convert_element_type_2109 = None + wait_tensor_465 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_174); reduce_scatter_tensor_174 = None + convert_element_type_2110 = torch.ops.prims.convert_element_type.default(mul_645, torch.float32); mul_645 = None + neg_19 = torch.ops.aten.neg.default(convert_element_type_422) + exp_19 = torch.ops.aten.exp.default(neg_19); neg_19 = None + add_262 = torch.ops.aten.add.Tensor(exp_19, 1); exp_19 = None + reciprocal_19 = torch.ops.aten.reciprocal.default(add_262); add_262 = None + mul_646 = torch.ops.aten.mul.Tensor(reciprocal_19, 1); reciprocal_19 = None + mul_647 = torch.ops.aten.mul.Tensor(convert_element_type_2110, mul_646); convert_element_type_2110 = None + sub_58 = torch.ops.aten.sub.Tensor(1, mul_646); mul_646 = None + mul_648 = torch.ops.aten.mul.Tensor(convert_element_type_422, sub_58); convert_element_type_422 = sub_58 = None + add_263 = torch.ops.aten.add.Tensor(mul_648, 1); mul_648 = None + mul_649 = torch.ops.aten.mul.Tensor(mul_647, add_263); mul_647 = add_263 = None + convert_element_type_2112 = torch.ops.prims.convert_element_type.default(mul_649, torch.bfloat16); mul_649 = None + view_1555 = torch.ops.aten.view.default(convert_element_type_2112, [16384, 14336]); convert_element_type_2112 = None + permute_973 = torch.ops.aten.permute.default(view_1555, [1, 0]) + mm_497 = torch.ops.aten.mm.default(permute_973, view_435); permute_973 = view_435 = None + convert_element_type_419 = torch.ops.prims.convert_element_type.default(primals_118, torch.bfloat16); primals_118 = None + all_gather_into_tensor_115 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_419, 64, '0'); convert_element_type_419 = None + wait_tensor_115 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_115); all_gather_into_tensor_115 = None + permute_140 = 
torch.ops.aten.permute.default(wait_tensor_115, [1, 0]); wait_tensor_115 = None + permute_975 = torch.ops.aten.permute.default(permute_140, [1, 0]); permute_140 = None + mm_498 = torch.ops.aten.mm.default(view_1555, permute_975); view_1555 = permute_975 = None + view_1556 = torch.ops.aten.view.default(mm_498, [2, 8192, 4096]); mm_498 = None + add_264 = torch.ops.aten.add.Tensor(view_1554, view_1556); view_1554 = view_1556 = None + convert_element_type_2117 = torch.ops.prims.convert_element_type.default(mm_497, torch.float32); mm_497 = None + reduce_scatter_tensor_175 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2117, 'avg', 64, '0'); convert_element_type_2117 = None + wait_tensor_466 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_175); reduce_scatter_tensor_175 = None + convert_element_type_2118 = torch.ops.prims.convert_element_type.default(add_264, torch.float32); add_264 = None + convert_element_type_2120 = torch.ops.prims.convert_element_type.default(wait_tensor_114, torch.float32); wait_tensor_114 = None + mul_650 = torch.ops.aten.mul.Tensor(convert_element_type_2118, convert_element_type_2120); convert_element_type_2120 = None + mul_652 = torch.ops.aten.mul.Tensor(mul_100, mul_650) + sum_117 = torch.ops.aten.sum.dim_IntList(mul_652, [2], True); mul_652 = None + div_39 = torch.ops.aten.div.Tensor(mul_100, 4096) + mul_653 = torch.ops.aten.mul.Tensor(div_39, sum_117); div_39 = sum_117 = None + sub_59 = torch.ops.aten.sub.Tensor(mul_650, mul_653); mul_650 = mul_653 = None + mul_654 = torch.ops.aten.mul.Tensor(sub_59, rsqrt_25); sub_59 = rsqrt_25 = None + mul_655 = torch.ops.aten.mul.Tensor(convert_element_type_2118, mul_100); convert_element_type_2118 = mul_100 = None + sum_118 = torch.ops.aten.sum.dim_IntList(mul_655, [0, 1]); mul_655 = None + convert_element_type_2121 = torch.ops.prims.convert_element_type.default(mul_654, torch.bfloat16); mul_654 = None + add_265 = torch.ops.aten.add.Tensor(add_261, 
convert_element_type_2121); add_261 = convert_element_type_2121 = None + convert_element_type_default_26 = torch.ops.prims.convert_element_type.default(sum_118, torch.float32); sum_118 = None + reduce_scatter_tensor_176 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_26, 'avg', 64, '0'); convert_element_type_default_26 = None + wait_tensor_467 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_176); reduce_scatter_tensor_176 = None + view_1557 = torch.ops.aten.view.default(add_265, [16384, 4096]) + permute_977 = torch.ops.aten.permute.default(view_1557, [1, 0]) + mm_499 = torch.ops.aten.mm.default(permute_977, view_431); permute_977 = view_431 = None + permute_979 = torch.ops.aten.permute.default(permute_139, [1, 0]); permute_139 = None + mm_500 = torch.ops.aten.mm.default(view_1557, permute_979); view_1557 = permute_979 = None + view_1558 = torch.ops.aten.view.default(mm_500, [2, 8192, 4096]); mm_500 = None + convert_element_type_2128 = torch.ops.prims.convert_element_type.default(mm_499, torch.float32); mm_499 = None + reduce_scatter_tensor_177 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2128, 'avg', 64, '0'); convert_element_type_2128 = None + wait_tensor_468 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_177); reduce_scatter_tensor_177 = None + view_1559 = torch.ops.aten.view.default(view_1558, [2, 8192, 32, 128]); view_1558 = None + permute_981 = torch.ops.aten.permute.default(view_1559, [0, 2, 1, 3]); view_1559 = None + convert_element_type_397 = torch.ops.prims.convert_element_type.default(primals_112, torch.bfloat16); primals_112 = None + all_gather_into_tensor_109 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_397, 64, '0'); convert_element_type_397 = None + wait_tensor_109 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_109); all_gather_into_tensor_109 = None + 
convert_element_type_398 = torch.ops.prims.convert_element_type.default(add_47, torch.float32); add_47 = None + pow_25 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_398, 2) + mean_24 = torch.ops.aten.mean.dim(pow_25, [2], True); pow_25 = None + add_48 = torch.ops.aten.add.Scalar(mean_24, 1e-05); mean_24 = None + rsqrt_24 = torch.ops.aten.rsqrt.default(add_48); add_48 = None + mul_96 = torch.ops.aten.mul.Tensor(convert_element_type_398, rsqrt_24); convert_element_type_398 = None + mul_97 = torch.ops.aten.mul.Tensor(mul_96, wait_tensor_109) + convert_element_type_399 = torch.ops.prims.convert_element_type.default(mul_97, torch.bfloat16); mul_97 = None + view_411 = torch.ops.aten.view.default(convert_element_type_399, [16384, 4096]); convert_element_type_399 = None + view_412 = torch.ops.aten.view.default(mm_84, [2, 8192, 4096]); mm_84 = None + convert_element_type_403 = torch.ops.prims.convert_element_type.default(primals_114, torch.bfloat16); primals_114 = None + all_gather_into_tensor_111 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_403, 64, '0'); convert_element_type_403 = None + wait_tensor_111 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_111); all_gather_into_tensor_111 = None + permute_133 = torch.ops.aten.permute.default(wait_tensor_111, [1, 0]); wait_tensor_111 = None + mm_85 = torch.ops.aten.mm.default(view_411, permute_133) + view_415 = torch.ops.aten.view.default(mm_85, [2, 8192, 1024]); mm_85 = None + view_418 = torch.ops.aten.view.default(mm_86, [2, 8192, 1024]); mm_86 = None + view_419 = torch.ops.aten.view.default(view_412, [2, 8192, -1, 128]); view_412 = None + view_420 = torch.ops.aten.view.default(view_415, [2, 8192, -1, 128]); view_415 = None + view_421 = torch.ops.aten.view.default(view_418, [2, 8192, -1, 128]); view_418 = None + convert_element_type_409 = torch.ops.prims.convert_element_type.default(view_419, torch.float32); view_419 = None + view_422 = 
torch.ops.aten.view.default(convert_element_type_409, [2, 8192, 32, -1, 2]); convert_element_type_409 = None + view_as_complex_24 = torch.ops.aten.view_as_complex.default(view_422); view_422 = None + convert_element_type_410 = torch.ops.prims.convert_element_type.default(view_420, torch.float32); view_420 = None + view_423 = torch.ops.aten.view.default(convert_element_type_410, [2, 8192, 8, -1, 2]); convert_element_type_410 = None + view_as_complex_25 = torch.ops.aten.view_as_complex.default(view_423); view_423 = None + mul_98 = torch.ops.aten.mul.Tensor(view_as_complex_24, view_16); view_as_complex_24 = None + view_as_real_24 = torch.ops.aten.view_as_real.default(mul_98); mul_98 = None + view_425 = torch.ops.aten.view.default(view_as_real_24, [2, 8192, 32, 128]); view_as_real_24 = None + mul_99 = torch.ops.aten.mul.Tensor(view_as_complex_25, view_16); view_as_complex_25 = None + view_as_real_25 = torch.ops.aten.view_as_real.default(mul_99); mul_99 = None + view_426 = torch.ops.aten.view.default(view_as_real_25, [2, 8192, 8, 128]); view_as_real_25 = None + convert_element_type_411 = torch.ops.prims.convert_element_type.default(view_425, torch.bfloat16); view_425 = None + convert_element_type_412 = torch.ops.prims.convert_element_type.default(view_426, torch.bfloat16); view_426 = None + unsqueeze_24 = torch.ops.aten.unsqueeze.default(convert_element_type_412, 3); convert_element_type_412 = None + expand_24 = torch.ops.aten.expand.default(unsqueeze_24, [2, 8192, 8, 4, 128]); unsqueeze_24 = None + clone_24 = torch.ops.aten.clone.default(expand_24, memory_format = torch.contiguous_format); expand_24 = None + view_427 = torch.ops.aten.view.default(clone_24, [2, 8192, 32, 128]); clone_24 = None + unsqueeze_25 = torch.ops.aten.unsqueeze.default(view_421, 3); view_421 = None + expand_25 = torch.ops.aten.expand.default(unsqueeze_25, [2, 8192, 8, 4, 128]); unsqueeze_25 = None + clone_25 = torch.ops.aten.clone.default(expand_25, memory_format = torch.contiguous_format); 
expand_25 = None + view_428 = torch.ops.aten.view.default(clone_25, [2, 8192, 32, 128]); clone_25 = None + permute_135 = torch.ops.aten.permute.default(convert_element_type_411, [0, 2, 1, 3]); convert_element_type_411 = None + permute_136 = torch.ops.aten.permute.default(view_427, [0, 2, 1, 3]); view_427 = None + permute_137 = torch.ops.aten.permute.default(view_428, [0, 2, 1, 3]); view_428 = None + _scaled_dot_product_cudnn_attention_backward_19 = torch.ops.aten._scaled_dot_product_cudnn_attention_backward.default(permute_981, permute_135, permute_136, permute_137, getitem_108, getitem_109, getitem_114, getitem_115, None, None, None, 8192, 8192, 0.0, True); permute_981 = permute_135 = permute_136 = permute_137 = getitem_108 = getitem_109 = getitem_114 = getitem_115 = None + getitem_345 = _scaled_dot_product_cudnn_attention_backward_19[0] + getitem_346 = _scaled_dot_product_cudnn_attention_backward_19[1] + getitem_347 = _scaled_dot_product_cudnn_attention_backward_19[2]; _scaled_dot_product_cudnn_attention_backward_19 = None + permute_982 = torch.ops.aten.permute.default(getitem_347, [0, 2, 1, 3]); getitem_347 = None + permute_983 = torch.ops.aten.permute.default(getitem_346, [0, 2, 1, 3]); getitem_346 = None + permute_984 = torch.ops.aten.permute.default(getitem_345, [0, 2, 1, 3]); getitem_345 = None + view_1560 = torch.ops.aten.view.default(permute_982, [2, 8192, 8, 4, 128]); permute_982 = None + sum_119 = torch.ops.aten.sum.dim_IntList(view_1560, [3], True); view_1560 = None + squeeze_38 = torch.ops.aten.squeeze.dim(sum_119, 3); sum_119 = None + view_1561 = torch.ops.aten.view.default(permute_983, [2, 8192, 8, 4, 128]); permute_983 = None + sum_120 = torch.ops.aten.sum.dim_IntList(view_1561, [3], True); view_1561 = None + squeeze_39 = torch.ops.aten.squeeze.dim(sum_120, 3); sum_120 = None + convert_element_type_2129 = torch.ops.prims.convert_element_type.default(squeeze_39, torch.float32); squeeze_39 = None + convert_element_type_2130 = 
torch.ops.prims.convert_element_type.default(permute_984, torch.float32); permute_984 = None + view_1562 = torch.ops.aten.view.default(convert_element_type_2129, [2, 8192, 8, 64, 2]); convert_element_type_2129 = None + view_as_complex_102 = torch.ops.aten.view_as_complex.default(view_1562); view_1562 = None + mul_656 = torch.ops.aten.mul.Tensor(view_as_complex_102, _conj); view_as_complex_102 = None + view_1563 = torch.ops.aten.view.default(convert_element_type_2130, [2, 8192, 32, 64, 2]); convert_element_type_2130 = None + view_as_complex_103 = torch.ops.aten.view_as_complex.default(view_1563); view_1563 = None + mul_657 = torch.ops.aten.mul.Tensor(view_as_complex_103, _conj); view_as_complex_103 = None + view_as_real_102 = torch.ops.aten.view_as_real.default(mul_656); mul_656 = None + view_1564 = torch.ops.aten.view.default(view_as_real_102, [2, 8192, 8, 128]); view_as_real_102 = None + convert_element_type_2131 = torch.ops.prims.convert_element_type.default(view_1564, torch.bfloat16); view_1564 = None + view_as_real_103 = torch.ops.aten.view_as_real.default(mul_657); mul_657 = None + view_1565 = torch.ops.aten.view.default(view_as_real_103, [2, 8192, 32, 128]); view_as_real_103 = None + convert_element_type_2132 = torch.ops.prims.convert_element_type.default(view_1565, torch.bfloat16); view_1565 = None + view_1566 = torch.ops.aten.view.default(squeeze_38, [2, 8192, 1024]); squeeze_38 = None + view_1567 = torch.ops.aten.view.default(convert_element_type_2131, [2, 8192, 1024]); convert_element_type_2131 = None + view_1568 = torch.ops.aten.view.default(convert_element_type_2132, [2, 8192, 4096]); convert_element_type_2132 = None + view_1569 = torch.ops.aten.view.default(view_1566, [16384, 1024]); view_1566 = None + permute_985 = torch.ops.aten.permute.default(view_1569, [1, 0]) + mm_501 = torch.ops.aten.mm.default(permute_985, view_411); permute_985 = None + convert_element_type_406 = torch.ops.prims.convert_element_type.default(primals_115, torch.bfloat16); 
primals_115 = None + all_gather_into_tensor_112 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_406, 64, '0'); convert_element_type_406 = None + wait_tensor_112 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_112); all_gather_into_tensor_112 = None + permute_134 = torch.ops.aten.permute.default(wait_tensor_112, [1, 0]); wait_tensor_112 = None + permute_987 = torch.ops.aten.permute.default(permute_134, [1, 0]); permute_134 = None + mm_502 = torch.ops.aten.mm.default(view_1569, permute_987); view_1569 = permute_987 = None + view_1570 = torch.ops.aten.view.default(mm_502, [2, 8192, 4096]); mm_502 = None + convert_element_type_2137 = torch.ops.prims.convert_element_type.default(mm_501, torch.float32); mm_501 = None + reduce_scatter_tensor_178 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2137, 'avg', 64, '0'); convert_element_type_2137 = None + wait_tensor_469 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_178); reduce_scatter_tensor_178 = None + view_1571 = torch.ops.aten.view.default(view_1567, [16384, 1024]); view_1567 = None + permute_989 = torch.ops.aten.permute.default(view_1571, [1, 0]) + mm_503 = torch.ops.aten.mm.default(permute_989, view_411); permute_989 = None + permute_991 = torch.ops.aten.permute.default(permute_133, [1, 0]); permute_133 = None + mm_504 = torch.ops.aten.mm.default(view_1571, permute_991); view_1571 = permute_991 = None + view_1572 = torch.ops.aten.view.default(mm_504, [2, 8192, 4096]); mm_504 = None + add_266 = torch.ops.aten.add.Tensor(view_1570, view_1572); view_1570 = view_1572 = None + convert_element_type_2142 = torch.ops.prims.convert_element_type.default(mm_503, torch.float32); mm_503 = None + reduce_scatter_tensor_179 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2142, 'avg', 64, '0'); convert_element_type_2142 = None + wait_tensor_470 = 
torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_179); reduce_scatter_tensor_179 = None + view_1573 = torch.ops.aten.view.default(view_1568, [16384, 4096]); view_1568 = None + permute_993 = torch.ops.aten.permute.default(view_1573, [1, 0]) + mm_505 = torch.ops.aten.mm.default(permute_993, view_411); permute_993 = view_411 = None + convert_element_type_400 = torch.ops.prims.convert_element_type.default(primals_113, torch.bfloat16); primals_113 = None + all_gather_into_tensor_110 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_400, 64, '0'); convert_element_type_400 = None + wait_tensor_110 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_110); all_gather_into_tensor_110 = None + permute_132 = torch.ops.aten.permute.default(wait_tensor_110, [1, 0]); wait_tensor_110 = None + permute_995 = torch.ops.aten.permute.default(permute_132, [1, 0]); permute_132 = None + mm_506 = torch.ops.aten.mm.default(view_1573, permute_995); view_1573 = permute_995 = None + view_1574 = torch.ops.aten.view.default(mm_506, [2, 8192, 4096]); mm_506 = None + add_267 = torch.ops.aten.add.Tensor(add_266, view_1574); add_266 = view_1574 = None + convert_element_type_2147 = torch.ops.prims.convert_element_type.default(mm_505, torch.float32); mm_505 = None + reduce_scatter_tensor_180 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2147, 'avg', 64, '0'); convert_element_type_2147 = None + wait_tensor_471 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_180); reduce_scatter_tensor_180 = None + convert_element_type_2148 = torch.ops.prims.convert_element_type.default(add_267, torch.float32); add_267 = None + convert_element_type_2150 = torch.ops.prims.convert_element_type.default(wait_tensor_109, torch.float32); wait_tensor_109 = None + mul_658 = torch.ops.aten.mul.Tensor(convert_element_type_2148, convert_element_type_2150); convert_element_type_2150 = None + mul_660 = 
torch.ops.aten.mul.Tensor(mul_96, mul_658) + sum_121 = torch.ops.aten.sum.dim_IntList(mul_660, [2], True); mul_660 = None + div_40 = torch.ops.aten.div.Tensor(mul_96, 4096) + mul_661 = torch.ops.aten.mul.Tensor(div_40, sum_121); div_40 = sum_121 = None + sub_60 = torch.ops.aten.sub.Tensor(mul_658, mul_661); mul_658 = mul_661 = None + mul_662 = torch.ops.aten.mul.Tensor(sub_60, rsqrt_24); sub_60 = rsqrt_24 = None + mul_663 = torch.ops.aten.mul.Tensor(convert_element_type_2148, mul_96); convert_element_type_2148 = mul_96 = None + sum_122 = torch.ops.aten.sum.dim_IntList(mul_663, [0, 1]); mul_663 = None + convert_element_type_2151 = torch.ops.prims.convert_element_type.default(mul_662, torch.bfloat16); mul_662 = None + add_268 = torch.ops.aten.add.Tensor(add_265, convert_element_type_2151); add_265 = convert_element_type_2151 = None + convert_element_type_default_25 = torch.ops.prims.convert_element_type.default(sum_122, torch.float32); sum_122 = None + reduce_scatter_tensor_181 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_25, 'avg', 64, '0'); convert_element_type_default_25 = None + wait_tensor_472 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_181); reduce_scatter_tensor_181 = None + view_1575 = torch.ops.aten.view.default(add_268, [16384, 4096]) + permute_997 = torch.ops.aten.permute.default(view_1575, [1, 0]) + permute_127 = torch.ops.aten.permute.default(getitem_99, [0, 2, 1, 3]) + view_395 = torch.ops.aten.view.default(permute_127, [2, 8192, -1]); permute_127 = None + convert_element_type_380 = torch.ops.prims.convert_element_type.default(primals_107, torch.bfloat16); primals_107 = None + all_gather_into_tensor_104 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_380, 64, '0'); convert_element_type_380 = None + wait_tensor_104 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_104); all_gather_into_tensor_104 = None + permute_128 = 
torch.ops.aten.permute.default(wait_tensor_104, [1, 0]); wait_tensor_104 = None + view_397 = torch.ops.aten.view.default(view_395, [16384, 4096]); view_395 = None + mm_80 = torch.ops.aten.mm.default(view_397, permute_128) + view_398 = torch.ops.aten.view.default(mm_80, [2, 8192, 4096]); mm_80 = None + add_45 = torch.ops.aten.add.Tensor(add_43, view_398); view_398 = None + convert_element_type_383 = torch.ops.prims.convert_element_type.default(primals_108, torch.bfloat16); primals_108 = None + all_gather_into_tensor_105 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_383, 64, '0'); convert_element_type_383 = None + wait_tensor_105 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_105); all_gather_into_tensor_105 = None + convert_element_type_384 = torch.ops.prims.convert_element_type.default(add_45, torch.float32); add_45 = None + pow_24 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_384, 2) + mean_23 = torch.ops.aten.mean.dim(pow_24, [2], True); pow_24 = None + add_46 = torch.ops.aten.add.Scalar(mean_23, 1e-05); mean_23 = None + rsqrt_23 = torch.ops.aten.rsqrt.default(add_46); add_46 = None + mul_92 = torch.ops.aten.mul.Tensor(convert_element_type_384, rsqrt_23); convert_element_type_384 = None + mul_93 = torch.ops.aten.mul.Tensor(mul_92, wait_tensor_105) + convert_element_type_385 = torch.ops.prims.convert_element_type.default(mul_93, torch.bfloat16); mul_93 = None + view_401 = torch.ops.aten.view.default(convert_element_type_385, [16384, 4096]); convert_element_type_385 = None + view_402 = torch.ops.aten.view.default(mm_81, [2, 8192, 14336]); mm_81 = None + convert_element_type_389 = torch.ops.prims.convert_element_type.default(view_402, torch.float32); view_402 = None + sigmoid_11 = torch.ops.aten.sigmoid.default(convert_element_type_389) + mul_94 = torch.ops.aten.mul.Tensor(convert_element_type_389, sigmoid_11); sigmoid_11 = None + convert_element_type_390 = 
torch.ops.prims.convert_element_type.default(mul_94, torch.bfloat16); mul_94 = None + convert_element_type_391 = torch.ops.prims.convert_element_type.default(primals_110, torch.bfloat16); primals_110 = None + all_gather_into_tensor_107 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_391, 64, '0'); convert_element_type_391 = None + wait_tensor_107 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_107); all_gather_into_tensor_107 = None + permute_130 = torch.ops.aten.permute.default(wait_tensor_107, [1, 0]); wait_tensor_107 = None + mm_82 = torch.ops.aten.mm.default(view_401, permute_130) + view_405 = torch.ops.aten.view.default(mm_82, [2, 8192, 14336]); mm_82 = None + mul_95 = torch.ops.aten.mul.Tensor(convert_element_type_390, view_405) + view_407 = torch.ops.aten.view.default(mul_95, [16384, 14336]); mul_95 = None + mm_507 = torch.ops.aten.mm.default(permute_997, view_407); permute_997 = view_407 = None + convert_element_type_394 = torch.ops.prims.convert_element_type.default(primals_111, torch.bfloat16); primals_111 = None + all_gather_into_tensor_108 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_394, 64, '0'); convert_element_type_394 = None + wait_tensor_108 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_108); all_gather_into_tensor_108 = None + permute_131 = torch.ops.aten.permute.default(wait_tensor_108, [1, 0]); wait_tensor_108 = None + permute_999 = torch.ops.aten.permute.default(permute_131, [1, 0]); permute_131 = None + mm_508 = torch.ops.aten.mm.default(view_1575, permute_999); view_1575 = permute_999 = None + view_1576 = torch.ops.aten.view.default(mm_508, [2, 8192, 14336]); mm_508 = None + convert_element_type_2158 = torch.ops.prims.convert_element_type.default(mm_507, torch.float32); mm_507 = None + reduce_scatter_tensor_182 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2158, 'avg', 64, '0'); 
convert_element_type_2158 = None + wait_tensor_473 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_182); reduce_scatter_tensor_182 = None + mul_664 = torch.ops.aten.mul.Tensor(view_1576, convert_element_type_390); convert_element_type_390 = None + mul_665 = torch.ops.aten.mul.Tensor(view_1576, view_405); view_1576 = view_405 = None + view_1577 = torch.ops.aten.view.default(mul_664, [16384, 14336]); mul_664 = None + permute_1001 = torch.ops.aten.permute.default(view_1577, [1, 0]) + mm_509 = torch.ops.aten.mm.default(permute_1001, view_401); permute_1001 = None + permute_1003 = torch.ops.aten.permute.default(permute_130, [1, 0]); permute_130 = None + mm_510 = torch.ops.aten.mm.default(view_1577, permute_1003); view_1577 = permute_1003 = None + view_1578 = torch.ops.aten.view.default(mm_510, [2, 8192, 4096]); mm_510 = None + convert_element_type_2163 = torch.ops.prims.convert_element_type.default(mm_509, torch.float32); mm_509 = None + reduce_scatter_tensor_183 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2163, 'avg', 64, '0'); convert_element_type_2163 = None + wait_tensor_474 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_183); reduce_scatter_tensor_183 = None + convert_element_type_2164 = torch.ops.prims.convert_element_type.default(mul_665, torch.float32); mul_665 = None + neg_20 = torch.ops.aten.neg.default(convert_element_type_389) + exp_20 = torch.ops.aten.exp.default(neg_20); neg_20 = None + add_269 = torch.ops.aten.add.Tensor(exp_20, 1); exp_20 = None + reciprocal_20 = torch.ops.aten.reciprocal.default(add_269); add_269 = None + mul_666 = torch.ops.aten.mul.Tensor(reciprocal_20, 1); reciprocal_20 = None + mul_667 = torch.ops.aten.mul.Tensor(convert_element_type_2164, mul_666); convert_element_type_2164 = None + sub_61 = torch.ops.aten.sub.Tensor(1, mul_666); mul_666 = None + mul_668 = torch.ops.aten.mul.Tensor(convert_element_type_389, sub_61); convert_element_type_389 = sub_61 
= None + add_270 = torch.ops.aten.add.Tensor(mul_668, 1); mul_668 = None + mul_669 = torch.ops.aten.mul.Tensor(mul_667, add_270); mul_667 = add_270 = None + convert_element_type_2166 = torch.ops.prims.convert_element_type.default(mul_669, torch.bfloat16); mul_669 = None + view_1579 = torch.ops.aten.view.default(convert_element_type_2166, [16384, 14336]); convert_element_type_2166 = None + permute_1005 = torch.ops.aten.permute.default(view_1579, [1, 0]) + mm_511 = torch.ops.aten.mm.default(permute_1005, view_401); permute_1005 = view_401 = None + convert_element_type_386 = torch.ops.prims.convert_element_type.default(primals_109, torch.bfloat16); primals_109 = None + all_gather_into_tensor_106 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_386, 64, '0'); convert_element_type_386 = None + wait_tensor_106 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_106); all_gather_into_tensor_106 = None + permute_129 = torch.ops.aten.permute.default(wait_tensor_106, [1, 0]); wait_tensor_106 = None + permute_1007 = torch.ops.aten.permute.default(permute_129, [1, 0]); permute_129 = None + mm_512 = torch.ops.aten.mm.default(view_1579, permute_1007); view_1579 = permute_1007 = None + view_1580 = torch.ops.aten.view.default(mm_512, [2, 8192, 4096]); mm_512 = None + add_271 = torch.ops.aten.add.Tensor(view_1578, view_1580); view_1578 = view_1580 = None + convert_element_type_2171 = torch.ops.prims.convert_element_type.default(mm_511, torch.float32); mm_511 = None + reduce_scatter_tensor_184 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2171, 'avg', 64, '0'); convert_element_type_2171 = None + wait_tensor_475 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_184); reduce_scatter_tensor_184 = None + convert_element_type_2172 = torch.ops.prims.convert_element_type.default(add_271, torch.float32); add_271 = None + convert_element_type_2174 = 
torch.ops.prims.convert_element_type.default(wait_tensor_105, torch.float32); wait_tensor_105 = None + mul_670 = torch.ops.aten.mul.Tensor(convert_element_type_2172, convert_element_type_2174); convert_element_type_2174 = None + mul_672 = torch.ops.aten.mul.Tensor(mul_92, mul_670) + sum_123 = torch.ops.aten.sum.dim_IntList(mul_672, [2], True); mul_672 = None + div_41 = torch.ops.aten.div.Tensor(mul_92, 4096) + mul_673 = torch.ops.aten.mul.Tensor(div_41, sum_123); div_41 = sum_123 = None + sub_62 = torch.ops.aten.sub.Tensor(mul_670, mul_673); mul_670 = mul_673 = None + mul_674 = torch.ops.aten.mul.Tensor(sub_62, rsqrt_23); sub_62 = rsqrt_23 = None + mul_675 = torch.ops.aten.mul.Tensor(convert_element_type_2172, mul_92); convert_element_type_2172 = mul_92 = None + sum_124 = torch.ops.aten.sum.dim_IntList(mul_675, [0, 1]); mul_675 = None + convert_element_type_2175 = torch.ops.prims.convert_element_type.default(mul_674, torch.bfloat16); mul_674 = None + add_272 = torch.ops.aten.add.Tensor(add_268, convert_element_type_2175); add_268 = convert_element_type_2175 = None + convert_element_type_default_24 = torch.ops.prims.convert_element_type.default(sum_124, torch.float32); sum_124 = None + reduce_scatter_tensor_185 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_24, 'avg', 64, '0'); convert_element_type_default_24 = None + wait_tensor_476 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_185); reduce_scatter_tensor_185 = None + view_1581 = torch.ops.aten.view.default(add_272, [16384, 4096]) + permute_1009 = torch.ops.aten.permute.default(view_1581, [1, 0]) + mm_513 = torch.ops.aten.mm.default(permute_1009, view_397); permute_1009 = view_397 = None + permute_1011 = torch.ops.aten.permute.default(permute_128, [1, 0]); permute_128 = None + mm_514 = torch.ops.aten.mm.default(view_1581, permute_1011); view_1581 = permute_1011 = None + view_1582 = torch.ops.aten.view.default(mm_514, [2, 8192, 4096]); mm_514 = None + 
convert_element_type_2182 = torch.ops.prims.convert_element_type.default(mm_513, torch.float32); mm_513 = None + reduce_scatter_tensor_186 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2182, 'avg', 64, '0'); convert_element_type_2182 = None + wait_tensor_477 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_186); reduce_scatter_tensor_186 = None + view_1583 = torch.ops.aten.view.default(view_1582, [2, 8192, 32, 128]); view_1582 = None + permute_1013 = torch.ops.aten.permute.default(view_1583, [0, 2, 1, 3]); view_1583 = None + convert_element_type_364 = torch.ops.prims.convert_element_type.default(primals_103, torch.bfloat16); primals_103 = None + all_gather_into_tensor_100 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_364, 64, '0'); convert_element_type_364 = None + wait_tensor_100 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_100); all_gather_into_tensor_100 = None + convert_element_type_365 = torch.ops.prims.convert_element_type.default(add_43, torch.float32); add_43 = None + pow_23 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_365, 2) + mean_22 = torch.ops.aten.mean.dim(pow_23, [2], True); pow_23 = None + add_44 = torch.ops.aten.add.Scalar(mean_22, 1e-05); mean_22 = None + rsqrt_22 = torch.ops.aten.rsqrt.default(add_44); add_44 = None + mul_88 = torch.ops.aten.mul.Tensor(convert_element_type_365, rsqrt_22); convert_element_type_365 = None + mul_89 = torch.ops.aten.mul.Tensor(mul_88, wait_tensor_100) + convert_element_type_366 = torch.ops.prims.convert_element_type.default(mul_89, torch.bfloat16); mul_89 = None + view_377 = torch.ops.aten.view.default(convert_element_type_366, [16384, 4096]); convert_element_type_366 = None + view_378 = torch.ops.aten.view.default(mm_77, [2, 8192, 4096]); mm_77 = None + convert_element_type_370 = torch.ops.prims.convert_element_type.default(primals_105, torch.bfloat16); primals_105 = None + 
all_gather_into_tensor_102 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_370, 64, '0'); convert_element_type_370 = None + wait_tensor_102 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_102); all_gather_into_tensor_102 = None + permute_122 = torch.ops.aten.permute.default(wait_tensor_102, [1, 0]); wait_tensor_102 = None + mm_78 = torch.ops.aten.mm.default(view_377, permute_122) + view_381 = torch.ops.aten.view.default(mm_78, [2, 8192, 1024]); mm_78 = None + view_384 = torch.ops.aten.view.default(mm_79, [2, 8192, 1024]); mm_79 = None + view_385 = torch.ops.aten.view.default(view_378, [2, 8192, -1, 128]); view_378 = None + view_386 = torch.ops.aten.view.default(view_381, [2, 8192, -1, 128]); view_381 = None + view_387 = torch.ops.aten.view.default(view_384, [2, 8192, -1, 128]); view_384 = None + convert_element_type_376 = torch.ops.prims.convert_element_type.default(view_385, torch.float32); view_385 = None + view_388 = torch.ops.aten.view.default(convert_element_type_376, [2, 8192, 32, -1, 2]); convert_element_type_376 = None + view_as_complex_22 = torch.ops.aten.view_as_complex.default(view_388); view_388 = None + convert_element_type_377 = torch.ops.prims.convert_element_type.default(view_386, torch.float32); view_386 = None + view_389 = torch.ops.aten.view.default(convert_element_type_377, [2, 8192, 8, -1, 2]); convert_element_type_377 = None + view_as_complex_23 = torch.ops.aten.view_as_complex.default(view_389); view_389 = None + mul_90 = torch.ops.aten.mul.Tensor(view_as_complex_22, view_16); view_as_complex_22 = None + view_as_real_22 = torch.ops.aten.view_as_real.default(mul_90); mul_90 = None + view_391 = torch.ops.aten.view.default(view_as_real_22, [2, 8192, 32, 128]); view_as_real_22 = None + mul_91 = torch.ops.aten.mul.Tensor(view_as_complex_23, view_16); view_as_complex_23 = None + view_as_real_23 = torch.ops.aten.view_as_real.default(mul_91); mul_91 = None + view_392 = 
torch.ops.aten.view.default(view_as_real_23, [2, 8192, 8, 128]); view_as_real_23 = None + convert_element_type_378 = torch.ops.prims.convert_element_type.default(view_391, torch.bfloat16); view_391 = None + convert_element_type_379 = torch.ops.prims.convert_element_type.default(view_392, torch.bfloat16); view_392 = None + unsqueeze_22 = torch.ops.aten.unsqueeze.default(convert_element_type_379, 3); convert_element_type_379 = None + expand_22 = torch.ops.aten.expand.default(unsqueeze_22, [2, 8192, 8, 4, 128]); unsqueeze_22 = None + clone_22 = torch.ops.aten.clone.default(expand_22, memory_format = torch.contiguous_format); expand_22 = None + view_393 = torch.ops.aten.view.default(clone_22, [2, 8192, 32, 128]); clone_22 = None + unsqueeze_23 = torch.ops.aten.unsqueeze.default(view_387, 3); view_387 = None + expand_23 = torch.ops.aten.expand.default(unsqueeze_23, [2, 8192, 8, 4, 128]); unsqueeze_23 = None + clone_23 = torch.ops.aten.clone.default(expand_23, memory_format = torch.contiguous_format); expand_23 = None + view_394 = torch.ops.aten.view.default(clone_23, [2, 8192, 32, 128]); clone_23 = None + permute_124 = torch.ops.aten.permute.default(convert_element_type_378, [0, 2, 1, 3]); convert_element_type_378 = None + permute_125 = torch.ops.aten.permute.default(view_393, [0, 2, 1, 3]); view_393 = None + permute_126 = torch.ops.aten.permute.default(view_394, [0, 2, 1, 3]); view_394 = None + _scaled_dot_product_cudnn_attention_backward_20 = torch.ops.aten._scaled_dot_product_cudnn_attention_backward.default(permute_1013, permute_124, permute_125, permute_126, getitem_99, getitem_100, getitem_105, getitem_106, None, None, None, 8192, 8192, 0.0, True); permute_1013 = permute_124 = permute_125 = permute_126 = getitem_99 = getitem_100 = getitem_105 = getitem_106 = None + getitem_348 = _scaled_dot_product_cudnn_attention_backward_20[0] + getitem_349 = _scaled_dot_product_cudnn_attention_backward_20[1] + getitem_350 = _scaled_dot_product_cudnn_attention_backward_20[2]; 
_scaled_dot_product_cudnn_attention_backward_20 = None + permute_1014 = torch.ops.aten.permute.default(getitem_350, [0, 2, 1, 3]); getitem_350 = None + permute_1015 = torch.ops.aten.permute.default(getitem_349, [0, 2, 1, 3]); getitem_349 = None + permute_1016 = torch.ops.aten.permute.default(getitem_348, [0, 2, 1, 3]); getitem_348 = None + view_1584 = torch.ops.aten.view.default(permute_1014, [2, 8192, 8, 4, 128]); permute_1014 = None + sum_125 = torch.ops.aten.sum.dim_IntList(view_1584, [3], True); view_1584 = None + squeeze_40 = torch.ops.aten.squeeze.dim(sum_125, 3); sum_125 = None + view_1585 = torch.ops.aten.view.default(permute_1015, [2, 8192, 8, 4, 128]); permute_1015 = None + sum_126 = torch.ops.aten.sum.dim_IntList(view_1585, [3], True); view_1585 = None + squeeze_41 = torch.ops.aten.squeeze.dim(sum_126, 3); sum_126 = None + convert_element_type_2183 = torch.ops.prims.convert_element_type.default(squeeze_41, torch.float32); squeeze_41 = None + convert_element_type_2184 = torch.ops.prims.convert_element_type.default(permute_1016, torch.float32); permute_1016 = None + view_1586 = torch.ops.aten.view.default(convert_element_type_2183, [2, 8192, 8, 64, 2]); convert_element_type_2183 = None + view_as_complex_104 = torch.ops.aten.view_as_complex.default(view_1586); view_1586 = None + mul_676 = torch.ops.aten.mul.Tensor(view_as_complex_104, _conj); view_as_complex_104 = None + view_1587 = torch.ops.aten.view.default(convert_element_type_2184, [2, 8192, 32, 64, 2]); convert_element_type_2184 = None + view_as_complex_105 = torch.ops.aten.view_as_complex.default(view_1587); view_1587 = None + mul_677 = torch.ops.aten.mul.Tensor(view_as_complex_105, _conj); view_as_complex_105 = None + view_as_real_104 = torch.ops.aten.view_as_real.default(mul_676); mul_676 = None + view_1588 = torch.ops.aten.view.default(view_as_real_104, [2, 8192, 8, 128]); view_as_real_104 = None + convert_element_type_2185 = torch.ops.prims.convert_element_type.default(view_1588, torch.bfloat16); 
view_1588 = None + view_as_real_105 = torch.ops.aten.view_as_real.default(mul_677); mul_677 = None + view_1589 = torch.ops.aten.view.default(view_as_real_105, [2, 8192, 32, 128]); view_as_real_105 = None + convert_element_type_2186 = torch.ops.prims.convert_element_type.default(view_1589, torch.bfloat16); view_1589 = None + view_1590 = torch.ops.aten.view.default(squeeze_40, [2, 8192, 1024]); squeeze_40 = None + view_1591 = torch.ops.aten.view.default(convert_element_type_2185, [2, 8192, 1024]); convert_element_type_2185 = None + view_1592 = torch.ops.aten.view.default(convert_element_type_2186, [2, 8192, 4096]); convert_element_type_2186 = None + view_1593 = torch.ops.aten.view.default(view_1590, [16384, 1024]); view_1590 = None + permute_1017 = torch.ops.aten.permute.default(view_1593, [1, 0]) + mm_515 = torch.ops.aten.mm.default(permute_1017, view_377); permute_1017 = None + convert_element_type_373 = torch.ops.prims.convert_element_type.default(primals_106, torch.bfloat16); primals_106 = None + all_gather_into_tensor_103 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_373, 64, '0'); convert_element_type_373 = None + wait_tensor_103 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_103); all_gather_into_tensor_103 = None + permute_123 = torch.ops.aten.permute.default(wait_tensor_103, [1, 0]); wait_tensor_103 = None + permute_1019 = torch.ops.aten.permute.default(permute_123, [1, 0]); permute_123 = None + mm_516 = torch.ops.aten.mm.default(view_1593, permute_1019); view_1593 = permute_1019 = None + view_1594 = torch.ops.aten.view.default(mm_516, [2, 8192, 4096]); mm_516 = None + convert_element_type_2191 = torch.ops.prims.convert_element_type.default(mm_515, torch.float32); mm_515 = None + reduce_scatter_tensor_187 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2191, 'avg', 64, '0'); convert_element_type_2191 = None + wait_tensor_478 = 
torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_187); reduce_scatter_tensor_187 = None + view_1595 = torch.ops.aten.view.default(view_1591, [16384, 1024]); view_1591 = None + permute_1021 = torch.ops.aten.permute.default(view_1595, [1, 0]) + mm_517 = torch.ops.aten.mm.default(permute_1021, view_377); permute_1021 = None + permute_1023 = torch.ops.aten.permute.default(permute_122, [1, 0]); permute_122 = None + mm_518 = torch.ops.aten.mm.default(view_1595, permute_1023); view_1595 = permute_1023 = None + view_1596 = torch.ops.aten.view.default(mm_518, [2, 8192, 4096]); mm_518 = None + add_273 = torch.ops.aten.add.Tensor(view_1594, view_1596); view_1594 = view_1596 = None + convert_element_type_2196 = torch.ops.prims.convert_element_type.default(mm_517, torch.float32); mm_517 = None + reduce_scatter_tensor_188 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2196, 'avg', 64, '0'); convert_element_type_2196 = None + wait_tensor_479 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_188); reduce_scatter_tensor_188 = None + view_1597 = torch.ops.aten.view.default(view_1592, [16384, 4096]); view_1592 = None + permute_1025 = torch.ops.aten.permute.default(view_1597, [1, 0]) + mm_519 = torch.ops.aten.mm.default(permute_1025, view_377); permute_1025 = view_377 = None + convert_element_type_367 = torch.ops.prims.convert_element_type.default(primals_104, torch.bfloat16); primals_104 = None + all_gather_into_tensor_101 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_367, 64, '0'); convert_element_type_367 = None + wait_tensor_101 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_101); all_gather_into_tensor_101 = None + permute_121 = torch.ops.aten.permute.default(wait_tensor_101, [1, 0]); wait_tensor_101 = None + permute_1027 = torch.ops.aten.permute.default(permute_121, [1, 0]); permute_121 = None + mm_520 = torch.ops.aten.mm.default(view_1597, 
permute_1027); view_1597 = permute_1027 = None + view_1598 = torch.ops.aten.view.default(mm_520, [2, 8192, 4096]); mm_520 = None + add_274 = torch.ops.aten.add.Tensor(add_273, view_1598); add_273 = view_1598 = None + convert_element_type_2201 = torch.ops.prims.convert_element_type.default(mm_519, torch.float32); mm_519 = None + reduce_scatter_tensor_189 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2201, 'avg', 64, '0'); convert_element_type_2201 = None + wait_tensor_480 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_189); reduce_scatter_tensor_189 = None + convert_element_type_2202 = torch.ops.prims.convert_element_type.default(add_274, torch.float32); add_274 = None + convert_element_type_2204 = torch.ops.prims.convert_element_type.default(wait_tensor_100, torch.float32); wait_tensor_100 = None + mul_678 = torch.ops.aten.mul.Tensor(convert_element_type_2202, convert_element_type_2204); convert_element_type_2204 = None + mul_680 = torch.ops.aten.mul.Tensor(mul_88, mul_678) + sum_127 = torch.ops.aten.sum.dim_IntList(mul_680, [2], True); mul_680 = None + div_42 = torch.ops.aten.div.Tensor(mul_88, 4096) + mul_681 = torch.ops.aten.mul.Tensor(div_42, sum_127); div_42 = sum_127 = None + sub_63 = torch.ops.aten.sub.Tensor(mul_678, mul_681); mul_678 = mul_681 = None + mul_682 = torch.ops.aten.mul.Tensor(sub_63, rsqrt_22); sub_63 = rsqrt_22 = None + mul_683 = torch.ops.aten.mul.Tensor(convert_element_type_2202, mul_88); convert_element_type_2202 = mul_88 = None + sum_128 = torch.ops.aten.sum.dim_IntList(mul_683, [0, 1]); mul_683 = None + convert_element_type_2205 = torch.ops.prims.convert_element_type.default(mul_682, torch.bfloat16); mul_682 = None + add_275 = torch.ops.aten.add.Tensor(add_272, convert_element_type_2205); add_272 = convert_element_type_2205 = None + convert_element_type_default_23 = torch.ops.prims.convert_element_type.default(sum_128, torch.float32); sum_128 = None + reduce_scatter_tensor_190 = 
torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_23, 'avg', 64, '0'); convert_element_type_default_23 = None + wait_tensor_481 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_190); reduce_scatter_tensor_190 = None + view_1599 = torch.ops.aten.view.default(add_275, [16384, 4096]) + permute_1029 = torch.ops.aten.permute.default(view_1599, [1, 0]) + permute_116 = torch.ops.aten.permute.default(getitem_90, [0, 2, 1, 3]) + view_361 = torch.ops.aten.view.default(permute_116, [2, 8192, -1]); permute_116 = None + convert_element_type_347 = torch.ops.prims.convert_element_type.default(primals_98, torch.bfloat16); primals_98 = None + all_gather_into_tensor_95 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_347, 64, '0'); convert_element_type_347 = None + wait_tensor_95 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_95); all_gather_into_tensor_95 = None + permute_117 = torch.ops.aten.permute.default(wait_tensor_95, [1, 0]); wait_tensor_95 = None + view_363 = torch.ops.aten.view.default(view_361, [16384, 4096]); view_361 = None + mm_73 = torch.ops.aten.mm.default(view_363, permute_117) + view_364 = torch.ops.aten.view.default(mm_73, [2, 8192, 4096]); mm_73 = None + add_41 = torch.ops.aten.add.Tensor(add_39, view_364); view_364 = None + convert_element_type_350 = torch.ops.prims.convert_element_type.default(primals_99, torch.bfloat16); primals_99 = None + all_gather_into_tensor_96 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_350, 64, '0'); convert_element_type_350 = None + wait_tensor_96 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_96); all_gather_into_tensor_96 = None + convert_element_type_351 = torch.ops.prims.convert_element_type.default(add_41, torch.float32); add_41 = None + pow_22 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_351, 2) + mean_21 = torch.ops.aten.mean.dim(pow_22, [2], 
True); pow_22 = None + add_42 = torch.ops.aten.add.Scalar(mean_21, 1e-05); mean_21 = None + rsqrt_21 = torch.ops.aten.rsqrt.default(add_42); add_42 = None + mul_84 = torch.ops.aten.mul.Tensor(convert_element_type_351, rsqrt_21); convert_element_type_351 = None + mul_85 = torch.ops.aten.mul.Tensor(mul_84, wait_tensor_96) + convert_element_type_352 = torch.ops.prims.convert_element_type.default(mul_85, torch.bfloat16); mul_85 = None + view_367 = torch.ops.aten.view.default(convert_element_type_352, [16384, 4096]); convert_element_type_352 = None + view_368 = torch.ops.aten.view.default(mm_74, [2, 8192, 14336]); mm_74 = None + convert_element_type_356 = torch.ops.prims.convert_element_type.default(view_368, torch.float32); view_368 = None + sigmoid_10 = torch.ops.aten.sigmoid.default(convert_element_type_356) + mul_86 = torch.ops.aten.mul.Tensor(convert_element_type_356, sigmoid_10); sigmoid_10 = None + convert_element_type_357 = torch.ops.prims.convert_element_type.default(mul_86, torch.bfloat16); mul_86 = None + convert_element_type_358 = torch.ops.prims.convert_element_type.default(primals_101, torch.bfloat16); primals_101 = None + all_gather_into_tensor_98 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_358, 64, '0'); convert_element_type_358 = None + wait_tensor_98 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_98); all_gather_into_tensor_98 = None + permute_119 = torch.ops.aten.permute.default(wait_tensor_98, [1, 0]); wait_tensor_98 = None + mm_75 = torch.ops.aten.mm.default(view_367, permute_119) + view_371 = torch.ops.aten.view.default(mm_75, [2, 8192, 14336]); mm_75 = None + mul_87 = torch.ops.aten.mul.Tensor(convert_element_type_357, view_371) + view_373 = torch.ops.aten.view.default(mul_87, [16384, 14336]); mul_87 = None + mm_521 = torch.ops.aten.mm.default(permute_1029, view_373); permute_1029 = view_373 = None + convert_element_type_361 = torch.ops.prims.convert_element_type.default(primals_102, 
torch.bfloat16); primals_102 = None + all_gather_into_tensor_99 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_361, 64, '0'); convert_element_type_361 = None + wait_tensor_99 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_99); all_gather_into_tensor_99 = None + permute_120 = torch.ops.aten.permute.default(wait_tensor_99, [1, 0]); wait_tensor_99 = None + permute_1031 = torch.ops.aten.permute.default(permute_120, [1, 0]); permute_120 = None + mm_522 = torch.ops.aten.mm.default(view_1599, permute_1031); view_1599 = permute_1031 = None + view_1600 = torch.ops.aten.view.default(mm_522, [2, 8192, 14336]); mm_522 = None + convert_element_type_2212 = torch.ops.prims.convert_element_type.default(mm_521, torch.float32); mm_521 = None + reduce_scatter_tensor_191 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2212, 'avg', 64, '0'); convert_element_type_2212 = None + wait_tensor_482 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_191); reduce_scatter_tensor_191 = None + mul_684 = torch.ops.aten.mul.Tensor(view_1600, convert_element_type_357); convert_element_type_357 = None + mul_685 = torch.ops.aten.mul.Tensor(view_1600, view_371); view_1600 = view_371 = None + view_1601 = torch.ops.aten.view.default(mul_684, [16384, 14336]); mul_684 = None + permute_1033 = torch.ops.aten.permute.default(view_1601, [1, 0]) + mm_523 = torch.ops.aten.mm.default(permute_1033, view_367); permute_1033 = None + permute_1035 = torch.ops.aten.permute.default(permute_119, [1, 0]); permute_119 = None + mm_524 = torch.ops.aten.mm.default(view_1601, permute_1035); view_1601 = permute_1035 = None + view_1602 = torch.ops.aten.view.default(mm_524, [2, 8192, 4096]); mm_524 = None + convert_element_type_2217 = torch.ops.prims.convert_element_type.default(mm_523, torch.float32); mm_523 = None + reduce_scatter_tensor_192 = 
torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2217, 'avg', 64, '0'); convert_element_type_2217 = None + wait_tensor_483 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_192); reduce_scatter_tensor_192 = None + convert_element_type_2218 = torch.ops.prims.convert_element_type.default(mul_685, torch.float32); mul_685 = None + neg_21 = torch.ops.aten.neg.default(convert_element_type_356) + exp_21 = torch.ops.aten.exp.default(neg_21); neg_21 = None + add_276 = torch.ops.aten.add.Tensor(exp_21, 1); exp_21 = None + reciprocal_21 = torch.ops.aten.reciprocal.default(add_276); add_276 = None + mul_686 = torch.ops.aten.mul.Tensor(reciprocal_21, 1); reciprocal_21 = None + mul_687 = torch.ops.aten.mul.Tensor(convert_element_type_2218, mul_686); convert_element_type_2218 = None + sub_64 = torch.ops.aten.sub.Tensor(1, mul_686); mul_686 = None + mul_688 = torch.ops.aten.mul.Tensor(convert_element_type_356, sub_64); convert_element_type_356 = sub_64 = None + add_277 = torch.ops.aten.add.Tensor(mul_688, 1); mul_688 = None + mul_689 = torch.ops.aten.mul.Tensor(mul_687, add_277); mul_687 = add_277 = None + convert_element_type_2220 = torch.ops.prims.convert_element_type.default(mul_689, torch.bfloat16); mul_689 = None + view_1603 = torch.ops.aten.view.default(convert_element_type_2220, [16384, 14336]); convert_element_type_2220 = None + permute_1037 = torch.ops.aten.permute.default(view_1603, [1, 0]) + mm_525 = torch.ops.aten.mm.default(permute_1037, view_367); permute_1037 = view_367 = None + convert_element_type_353 = torch.ops.prims.convert_element_type.default(primals_100, torch.bfloat16); primals_100 = None + all_gather_into_tensor_97 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_353, 64, '0'); convert_element_type_353 = None + wait_tensor_97 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_97); all_gather_into_tensor_97 = None + permute_118 = 
torch.ops.aten.permute.default(wait_tensor_97, [1, 0]); wait_tensor_97 = None + permute_1039 = torch.ops.aten.permute.default(permute_118, [1, 0]); permute_118 = None + mm_526 = torch.ops.aten.mm.default(view_1603, permute_1039); view_1603 = permute_1039 = None + view_1604 = torch.ops.aten.view.default(mm_526, [2, 8192, 4096]); mm_526 = None + add_278 = torch.ops.aten.add.Tensor(view_1602, view_1604); view_1602 = view_1604 = None + convert_element_type_2225 = torch.ops.prims.convert_element_type.default(mm_525, torch.float32); mm_525 = None + reduce_scatter_tensor_193 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2225, 'avg', 64, '0'); convert_element_type_2225 = None + wait_tensor_484 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_193); reduce_scatter_tensor_193 = None + convert_element_type_2226 = torch.ops.prims.convert_element_type.default(add_278, torch.float32); add_278 = None + convert_element_type_2228 = torch.ops.prims.convert_element_type.default(wait_tensor_96, torch.float32); wait_tensor_96 = None + mul_690 = torch.ops.aten.mul.Tensor(convert_element_type_2226, convert_element_type_2228); convert_element_type_2228 = None + mul_692 = torch.ops.aten.mul.Tensor(mul_84, mul_690) + sum_129 = torch.ops.aten.sum.dim_IntList(mul_692, [2], True); mul_692 = None + div_43 = torch.ops.aten.div.Tensor(mul_84, 4096) + mul_693 = torch.ops.aten.mul.Tensor(div_43, sum_129); div_43 = sum_129 = None + sub_65 = torch.ops.aten.sub.Tensor(mul_690, mul_693); mul_690 = mul_693 = None + mul_694 = torch.ops.aten.mul.Tensor(sub_65, rsqrt_21); sub_65 = rsqrt_21 = None + mul_695 = torch.ops.aten.mul.Tensor(convert_element_type_2226, mul_84); convert_element_type_2226 = mul_84 = None + sum_130 = torch.ops.aten.sum.dim_IntList(mul_695, [0, 1]); mul_695 = None + convert_element_type_2229 = torch.ops.prims.convert_element_type.default(mul_694, torch.bfloat16); mul_694 = None + add_279 = torch.ops.aten.add.Tensor(add_275, 
convert_element_type_2229); add_275 = convert_element_type_2229 = None + convert_element_type_default_22 = torch.ops.prims.convert_element_type.default(sum_130, torch.float32); sum_130 = None + reduce_scatter_tensor_194 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_22, 'avg', 64, '0'); convert_element_type_default_22 = None + wait_tensor_485 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_194); reduce_scatter_tensor_194 = None + view_1605 = torch.ops.aten.view.default(add_279, [16384, 4096]) + permute_1041 = torch.ops.aten.permute.default(view_1605, [1, 0]) + mm_527 = torch.ops.aten.mm.default(permute_1041, view_363); permute_1041 = view_363 = None + permute_1043 = torch.ops.aten.permute.default(permute_117, [1, 0]); permute_117 = None + mm_528 = torch.ops.aten.mm.default(view_1605, permute_1043); view_1605 = permute_1043 = None + view_1606 = torch.ops.aten.view.default(mm_528, [2, 8192, 4096]); mm_528 = None + convert_element_type_2236 = torch.ops.prims.convert_element_type.default(mm_527, torch.float32); mm_527 = None + reduce_scatter_tensor_195 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2236, 'avg', 64, '0'); convert_element_type_2236 = None + wait_tensor_486 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_195); reduce_scatter_tensor_195 = None + view_1607 = torch.ops.aten.view.default(view_1606, [2, 8192, 32, 128]); view_1606 = None + permute_1045 = torch.ops.aten.permute.default(view_1607, [0, 2, 1, 3]); view_1607 = None + convert_element_type_331 = torch.ops.prims.convert_element_type.default(primals_94, torch.bfloat16); primals_94 = None + all_gather_into_tensor_91 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_331, 64, '0'); convert_element_type_331 = None + wait_tensor_91 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_91); all_gather_into_tensor_91 = None + 
convert_element_type_332 = torch.ops.prims.convert_element_type.default(add_39, torch.float32); add_39 = None + pow_21 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_332, 2) + mean_20 = torch.ops.aten.mean.dim(pow_21, [2], True); pow_21 = None + add_40 = torch.ops.aten.add.Scalar(mean_20, 1e-05); mean_20 = None + rsqrt_20 = torch.ops.aten.rsqrt.default(add_40); add_40 = None + mul_80 = torch.ops.aten.mul.Tensor(convert_element_type_332, rsqrt_20); convert_element_type_332 = None + mul_81 = torch.ops.aten.mul.Tensor(mul_80, wait_tensor_91) + convert_element_type_333 = torch.ops.prims.convert_element_type.default(mul_81, torch.bfloat16); mul_81 = None + view_343 = torch.ops.aten.view.default(convert_element_type_333, [16384, 4096]); convert_element_type_333 = None + view_344 = torch.ops.aten.view.default(mm_70, [2, 8192, 4096]); mm_70 = None + convert_element_type_337 = torch.ops.prims.convert_element_type.default(primals_96, torch.bfloat16); primals_96 = None + all_gather_into_tensor_93 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_337, 64, '0'); convert_element_type_337 = None + wait_tensor_93 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_93); all_gather_into_tensor_93 = None + permute_111 = torch.ops.aten.permute.default(wait_tensor_93, [1, 0]); wait_tensor_93 = None + mm_71 = torch.ops.aten.mm.default(view_343, permute_111) + view_347 = torch.ops.aten.view.default(mm_71, [2, 8192, 1024]); mm_71 = None + view_350 = torch.ops.aten.view.default(mm_72, [2, 8192, 1024]); mm_72 = None + view_351 = torch.ops.aten.view.default(view_344, [2, 8192, -1, 128]); view_344 = None + view_352 = torch.ops.aten.view.default(view_347, [2, 8192, -1, 128]); view_347 = None + view_353 = torch.ops.aten.view.default(view_350, [2, 8192, -1, 128]); view_350 = None + convert_element_type_343 = torch.ops.prims.convert_element_type.default(view_351, torch.float32); view_351 = None + view_354 = 
torch.ops.aten.view.default(convert_element_type_343, [2, 8192, 32, -1, 2]); convert_element_type_343 = None + view_as_complex_20 = torch.ops.aten.view_as_complex.default(view_354); view_354 = None + convert_element_type_344 = torch.ops.prims.convert_element_type.default(view_352, torch.float32); view_352 = None + view_355 = torch.ops.aten.view.default(convert_element_type_344, [2, 8192, 8, -1, 2]); convert_element_type_344 = None + view_as_complex_21 = torch.ops.aten.view_as_complex.default(view_355); view_355 = None + mul_82 = torch.ops.aten.mul.Tensor(view_as_complex_20, view_16); view_as_complex_20 = None + view_as_real_20 = torch.ops.aten.view_as_real.default(mul_82); mul_82 = None + view_357 = torch.ops.aten.view.default(view_as_real_20, [2, 8192, 32, 128]); view_as_real_20 = None + mul_83 = torch.ops.aten.mul.Tensor(view_as_complex_21, view_16); view_as_complex_21 = None + view_as_real_21 = torch.ops.aten.view_as_real.default(mul_83); mul_83 = None + view_358 = torch.ops.aten.view.default(view_as_real_21, [2, 8192, 8, 128]); view_as_real_21 = None + convert_element_type_345 = torch.ops.prims.convert_element_type.default(view_357, torch.bfloat16); view_357 = None + convert_element_type_346 = torch.ops.prims.convert_element_type.default(view_358, torch.bfloat16); view_358 = None + unsqueeze_20 = torch.ops.aten.unsqueeze.default(convert_element_type_346, 3); convert_element_type_346 = None + expand_20 = torch.ops.aten.expand.default(unsqueeze_20, [2, 8192, 8, 4, 128]); unsqueeze_20 = None + clone_20 = torch.ops.aten.clone.default(expand_20, memory_format = torch.contiguous_format); expand_20 = None + view_359 = torch.ops.aten.view.default(clone_20, [2, 8192, 32, 128]); clone_20 = None + unsqueeze_21 = torch.ops.aten.unsqueeze.default(view_353, 3); view_353 = None + expand_21 = torch.ops.aten.expand.default(unsqueeze_21, [2, 8192, 8, 4, 128]); unsqueeze_21 = None + clone_21 = torch.ops.aten.clone.default(expand_21, memory_format = torch.contiguous_format); 
expand_21 = None + view_360 = torch.ops.aten.view.default(clone_21, [2, 8192, 32, 128]); clone_21 = None + permute_113 = torch.ops.aten.permute.default(convert_element_type_345, [0, 2, 1, 3]); convert_element_type_345 = None + permute_114 = torch.ops.aten.permute.default(view_359, [0, 2, 1, 3]); view_359 = None + permute_115 = torch.ops.aten.permute.default(view_360, [0, 2, 1, 3]); view_360 = None + _scaled_dot_product_cudnn_attention_backward_21 = torch.ops.aten._scaled_dot_product_cudnn_attention_backward.default(permute_1045, permute_113, permute_114, permute_115, getitem_90, getitem_91, getitem_96, getitem_97, None, None, None, 8192, 8192, 0.0, True); permute_1045 = permute_113 = permute_114 = permute_115 = getitem_90 = getitem_91 = getitem_96 = getitem_97 = None + getitem_351 = _scaled_dot_product_cudnn_attention_backward_21[0] + getitem_352 = _scaled_dot_product_cudnn_attention_backward_21[1] + getitem_353 = _scaled_dot_product_cudnn_attention_backward_21[2]; _scaled_dot_product_cudnn_attention_backward_21 = None + permute_1046 = torch.ops.aten.permute.default(getitem_353, [0, 2, 1, 3]); getitem_353 = None + permute_1047 = torch.ops.aten.permute.default(getitem_352, [0, 2, 1, 3]); getitem_352 = None + permute_1048 = torch.ops.aten.permute.default(getitem_351, [0, 2, 1, 3]); getitem_351 = None + view_1608 = torch.ops.aten.view.default(permute_1046, [2, 8192, 8, 4, 128]); permute_1046 = None + sum_131 = torch.ops.aten.sum.dim_IntList(view_1608, [3], True); view_1608 = None + squeeze_42 = torch.ops.aten.squeeze.dim(sum_131, 3); sum_131 = None + view_1609 = torch.ops.aten.view.default(permute_1047, [2, 8192, 8, 4, 128]); permute_1047 = None + sum_132 = torch.ops.aten.sum.dim_IntList(view_1609, [3], True); view_1609 = None + squeeze_43 = torch.ops.aten.squeeze.dim(sum_132, 3); sum_132 = None + convert_element_type_2237 = torch.ops.prims.convert_element_type.default(squeeze_43, torch.float32); squeeze_43 = None + convert_element_type_2238 = 
torch.ops.prims.convert_element_type.default(permute_1048, torch.float32); permute_1048 = None + view_1610 = torch.ops.aten.view.default(convert_element_type_2237, [2, 8192, 8, 64, 2]); convert_element_type_2237 = None + view_as_complex_106 = torch.ops.aten.view_as_complex.default(view_1610); view_1610 = None + mul_696 = torch.ops.aten.mul.Tensor(view_as_complex_106, _conj); view_as_complex_106 = None + view_1611 = torch.ops.aten.view.default(convert_element_type_2238, [2, 8192, 32, 64, 2]); convert_element_type_2238 = None + view_as_complex_107 = torch.ops.aten.view_as_complex.default(view_1611); view_1611 = None + mul_697 = torch.ops.aten.mul.Tensor(view_as_complex_107, _conj); view_as_complex_107 = None + view_as_real_106 = torch.ops.aten.view_as_real.default(mul_696); mul_696 = None + view_1612 = torch.ops.aten.view.default(view_as_real_106, [2, 8192, 8, 128]); view_as_real_106 = None + convert_element_type_2239 = torch.ops.prims.convert_element_type.default(view_1612, torch.bfloat16); view_1612 = None + view_as_real_107 = torch.ops.aten.view_as_real.default(mul_697); mul_697 = None + view_1613 = torch.ops.aten.view.default(view_as_real_107, [2, 8192, 32, 128]); view_as_real_107 = None + convert_element_type_2240 = torch.ops.prims.convert_element_type.default(view_1613, torch.bfloat16); view_1613 = None + view_1614 = torch.ops.aten.view.default(squeeze_42, [2, 8192, 1024]); squeeze_42 = None + view_1615 = torch.ops.aten.view.default(convert_element_type_2239, [2, 8192, 1024]); convert_element_type_2239 = None + view_1616 = torch.ops.aten.view.default(convert_element_type_2240, [2, 8192, 4096]); convert_element_type_2240 = None + view_1617 = torch.ops.aten.view.default(view_1614, [16384, 1024]); view_1614 = None + permute_1049 = torch.ops.aten.permute.default(view_1617, [1, 0]) + mm_529 = torch.ops.aten.mm.default(permute_1049, view_343); permute_1049 = None + convert_element_type_340 = torch.ops.prims.convert_element_type.default(primals_97, torch.bfloat16); 
primals_97 = None + all_gather_into_tensor_94 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_340, 64, '0'); convert_element_type_340 = None + wait_tensor_94 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_94); all_gather_into_tensor_94 = None + permute_112 = torch.ops.aten.permute.default(wait_tensor_94, [1, 0]); wait_tensor_94 = None + permute_1051 = torch.ops.aten.permute.default(permute_112, [1, 0]); permute_112 = None + mm_530 = torch.ops.aten.mm.default(view_1617, permute_1051); view_1617 = permute_1051 = None + view_1618 = torch.ops.aten.view.default(mm_530, [2, 8192, 4096]); mm_530 = None + convert_element_type_2245 = torch.ops.prims.convert_element_type.default(mm_529, torch.float32); mm_529 = None + reduce_scatter_tensor_196 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2245, 'avg', 64, '0'); convert_element_type_2245 = None + wait_tensor_487 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_196); reduce_scatter_tensor_196 = None + view_1619 = torch.ops.aten.view.default(view_1615, [16384, 1024]); view_1615 = None + permute_1053 = torch.ops.aten.permute.default(view_1619, [1, 0]) + mm_531 = torch.ops.aten.mm.default(permute_1053, view_343); permute_1053 = None + permute_1055 = torch.ops.aten.permute.default(permute_111, [1, 0]); permute_111 = None + mm_532 = torch.ops.aten.mm.default(view_1619, permute_1055); view_1619 = permute_1055 = None + view_1620 = torch.ops.aten.view.default(mm_532, [2, 8192, 4096]); mm_532 = None + add_280 = torch.ops.aten.add.Tensor(view_1618, view_1620); view_1618 = view_1620 = None + convert_element_type_2250 = torch.ops.prims.convert_element_type.default(mm_531, torch.float32); mm_531 = None + reduce_scatter_tensor_197 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2250, 'avg', 64, '0'); convert_element_type_2250 = None + wait_tensor_488 = 
torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_197); reduce_scatter_tensor_197 = None + view_1621 = torch.ops.aten.view.default(view_1616, [16384, 4096]); view_1616 = None + permute_1057 = torch.ops.aten.permute.default(view_1621, [1, 0]) + mm_533 = torch.ops.aten.mm.default(permute_1057, view_343); permute_1057 = view_343 = None + convert_element_type_334 = torch.ops.prims.convert_element_type.default(primals_95, torch.bfloat16); primals_95 = None + all_gather_into_tensor_92 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_334, 64, '0'); convert_element_type_334 = None + wait_tensor_92 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_92); all_gather_into_tensor_92 = None + permute_110 = torch.ops.aten.permute.default(wait_tensor_92, [1, 0]); wait_tensor_92 = None + permute_1059 = torch.ops.aten.permute.default(permute_110, [1, 0]); permute_110 = None + mm_534 = torch.ops.aten.mm.default(view_1621, permute_1059); view_1621 = permute_1059 = None + view_1622 = torch.ops.aten.view.default(mm_534, [2, 8192, 4096]); mm_534 = None + add_281 = torch.ops.aten.add.Tensor(add_280, view_1622); add_280 = view_1622 = None + convert_element_type_2255 = torch.ops.prims.convert_element_type.default(mm_533, torch.float32); mm_533 = None + reduce_scatter_tensor_198 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2255, 'avg', 64, '0'); convert_element_type_2255 = None + wait_tensor_489 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_198); reduce_scatter_tensor_198 = None + convert_element_type_2256 = torch.ops.prims.convert_element_type.default(add_281, torch.float32); add_281 = None + convert_element_type_2258 = torch.ops.prims.convert_element_type.default(wait_tensor_91, torch.float32); wait_tensor_91 = None + mul_698 = torch.ops.aten.mul.Tensor(convert_element_type_2256, convert_element_type_2258); convert_element_type_2258 = None + mul_700 = 
torch.ops.aten.mul.Tensor(mul_80, mul_698) + sum_133 = torch.ops.aten.sum.dim_IntList(mul_700, [2], True); mul_700 = None + div_44 = torch.ops.aten.div.Tensor(mul_80, 4096) + mul_701 = torch.ops.aten.mul.Tensor(div_44, sum_133); div_44 = sum_133 = None + sub_66 = torch.ops.aten.sub.Tensor(mul_698, mul_701); mul_698 = mul_701 = None + mul_702 = torch.ops.aten.mul.Tensor(sub_66, rsqrt_20); sub_66 = rsqrt_20 = None + mul_703 = torch.ops.aten.mul.Tensor(convert_element_type_2256, mul_80); convert_element_type_2256 = mul_80 = None + sum_134 = torch.ops.aten.sum.dim_IntList(mul_703, [0, 1]); mul_703 = None + convert_element_type_2259 = torch.ops.prims.convert_element_type.default(mul_702, torch.bfloat16); mul_702 = None + add_282 = torch.ops.aten.add.Tensor(add_279, convert_element_type_2259); add_279 = convert_element_type_2259 = None + convert_element_type_default_21 = torch.ops.prims.convert_element_type.default(sum_134, torch.float32); sum_134 = None + reduce_scatter_tensor_199 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_21, 'avg', 64, '0'); convert_element_type_default_21 = None + wait_tensor_490 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_199); reduce_scatter_tensor_199 = None + view_1623 = torch.ops.aten.view.default(add_282, [16384, 4096]) + permute_1061 = torch.ops.aten.permute.default(view_1623, [1, 0]) + permute_105 = torch.ops.aten.permute.default(getitem_81, [0, 2, 1, 3]) + view_327 = torch.ops.aten.view.default(permute_105, [2, 8192, -1]); permute_105 = None + convert_element_type_314 = torch.ops.prims.convert_element_type.default(primals_89, torch.bfloat16); primals_89 = None + all_gather_into_tensor_86 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_314, 64, '0'); convert_element_type_314 = None + wait_tensor_86 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_86); all_gather_into_tensor_86 = None + permute_106 = 
torch.ops.aten.permute.default(wait_tensor_86, [1, 0]); wait_tensor_86 = None + view_329 = torch.ops.aten.view.default(view_327, [16384, 4096]); view_327 = None + mm_66 = torch.ops.aten.mm.default(view_329, permute_106) + view_330 = torch.ops.aten.view.default(mm_66, [2, 8192, 4096]); mm_66 = None + add_37 = torch.ops.aten.add.Tensor(add_35, view_330); view_330 = None + convert_element_type_317 = torch.ops.prims.convert_element_type.default(primals_90, torch.bfloat16); primals_90 = None + all_gather_into_tensor_87 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_317, 64, '0'); convert_element_type_317 = None + wait_tensor_87 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_87); all_gather_into_tensor_87 = None + convert_element_type_318 = torch.ops.prims.convert_element_type.default(add_37, torch.float32); add_37 = None + pow_20 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_318, 2) + mean_19 = torch.ops.aten.mean.dim(pow_20, [2], True); pow_20 = None + add_38 = torch.ops.aten.add.Scalar(mean_19, 1e-05); mean_19 = None + rsqrt_19 = torch.ops.aten.rsqrt.default(add_38); add_38 = None + mul_76 = torch.ops.aten.mul.Tensor(convert_element_type_318, rsqrt_19); convert_element_type_318 = None + mul_77 = torch.ops.aten.mul.Tensor(mul_76, wait_tensor_87) + convert_element_type_319 = torch.ops.prims.convert_element_type.default(mul_77, torch.bfloat16); mul_77 = None + view_333 = torch.ops.aten.view.default(convert_element_type_319, [16384, 4096]); convert_element_type_319 = None + view_334 = torch.ops.aten.view.default(mm_67, [2, 8192, 14336]); mm_67 = None + convert_element_type_323 = torch.ops.prims.convert_element_type.default(view_334, torch.float32); view_334 = None + sigmoid_9 = torch.ops.aten.sigmoid.default(convert_element_type_323) + mul_78 = torch.ops.aten.mul.Tensor(convert_element_type_323, sigmoid_9); sigmoid_9 = None + convert_element_type_324 = torch.ops.prims.convert_element_type.default(mul_78, 
torch.bfloat16); mul_78 = None + convert_element_type_325 = torch.ops.prims.convert_element_type.default(primals_92, torch.bfloat16); primals_92 = None + all_gather_into_tensor_89 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_325, 64, '0'); convert_element_type_325 = None + wait_tensor_89 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_89); all_gather_into_tensor_89 = None + permute_108 = torch.ops.aten.permute.default(wait_tensor_89, [1, 0]); wait_tensor_89 = None + mm_68 = torch.ops.aten.mm.default(view_333, permute_108) + view_337 = torch.ops.aten.view.default(mm_68, [2, 8192, 14336]); mm_68 = None + mul_79 = torch.ops.aten.mul.Tensor(convert_element_type_324, view_337) + view_339 = torch.ops.aten.view.default(mul_79, [16384, 14336]); mul_79 = None + mm_535 = torch.ops.aten.mm.default(permute_1061, view_339); permute_1061 = view_339 = None + convert_element_type_328 = torch.ops.prims.convert_element_type.default(primals_93, torch.bfloat16); primals_93 = None + all_gather_into_tensor_90 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_328, 64, '0'); convert_element_type_328 = None + wait_tensor_90 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_90); all_gather_into_tensor_90 = None + permute_109 = torch.ops.aten.permute.default(wait_tensor_90, [1, 0]); wait_tensor_90 = None + permute_1063 = torch.ops.aten.permute.default(permute_109, [1, 0]); permute_109 = None + mm_536 = torch.ops.aten.mm.default(view_1623, permute_1063); view_1623 = permute_1063 = None + view_1624 = torch.ops.aten.view.default(mm_536, [2, 8192, 14336]); mm_536 = None + convert_element_type_2266 = torch.ops.prims.convert_element_type.default(mm_535, torch.float32); mm_535 = None + reduce_scatter_tensor_200 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2266, 'avg', 64, '0'); convert_element_type_2266 = None + wait_tensor_491 = 
torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_200); reduce_scatter_tensor_200 = None + mul_704 = torch.ops.aten.mul.Tensor(view_1624, convert_element_type_324); convert_element_type_324 = None + mul_705 = torch.ops.aten.mul.Tensor(view_1624, view_337); view_1624 = view_337 = None + view_1625 = torch.ops.aten.view.default(mul_704, [16384, 14336]); mul_704 = None + permute_1065 = torch.ops.aten.permute.default(view_1625, [1, 0]) + mm_537 = torch.ops.aten.mm.default(permute_1065, view_333); permute_1065 = None + permute_1067 = torch.ops.aten.permute.default(permute_108, [1, 0]); permute_108 = None + mm_538 = torch.ops.aten.mm.default(view_1625, permute_1067); view_1625 = permute_1067 = None + view_1626 = torch.ops.aten.view.default(mm_538, [2, 8192, 4096]); mm_538 = None + convert_element_type_2271 = torch.ops.prims.convert_element_type.default(mm_537, torch.float32); mm_537 = None + reduce_scatter_tensor_201 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2271, 'avg', 64, '0'); convert_element_type_2271 = None + wait_tensor_492 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_201); reduce_scatter_tensor_201 = None + convert_element_type_2272 = torch.ops.prims.convert_element_type.default(mul_705, torch.float32); mul_705 = None + neg_22 = torch.ops.aten.neg.default(convert_element_type_323) + exp_22 = torch.ops.aten.exp.default(neg_22); neg_22 = None + add_283 = torch.ops.aten.add.Tensor(exp_22, 1); exp_22 = None + reciprocal_22 = torch.ops.aten.reciprocal.default(add_283); add_283 = None + mul_706 = torch.ops.aten.mul.Tensor(reciprocal_22, 1); reciprocal_22 = None + mul_707 = torch.ops.aten.mul.Tensor(convert_element_type_2272, mul_706); convert_element_type_2272 = None + sub_67 = torch.ops.aten.sub.Tensor(1, mul_706); mul_706 = None + mul_708 = torch.ops.aten.mul.Tensor(convert_element_type_323, sub_67); convert_element_type_323 = sub_67 = None + add_284 = torch.ops.aten.add.Tensor(mul_708, 
1); mul_708 = None + mul_709 = torch.ops.aten.mul.Tensor(mul_707, add_284); mul_707 = add_284 = None + convert_element_type_2274 = torch.ops.prims.convert_element_type.default(mul_709, torch.bfloat16); mul_709 = None + view_1627 = torch.ops.aten.view.default(convert_element_type_2274, [16384, 14336]); convert_element_type_2274 = None + permute_1069 = torch.ops.aten.permute.default(view_1627, [1, 0]) + mm_539 = torch.ops.aten.mm.default(permute_1069, view_333); permute_1069 = view_333 = None + convert_element_type_320 = torch.ops.prims.convert_element_type.default(primals_91, torch.bfloat16); primals_91 = None + all_gather_into_tensor_88 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_320, 64, '0'); convert_element_type_320 = None + wait_tensor_88 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_88); all_gather_into_tensor_88 = None + permute_107 = torch.ops.aten.permute.default(wait_tensor_88, [1, 0]); wait_tensor_88 = None + permute_1071 = torch.ops.aten.permute.default(permute_107, [1, 0]); permute_107 = None + mm_540 = torch.ops.aten.mm.default(view_1627, permute_1071); view_1627 = permute_1071 = None + view_1628 = torch.ops.aten.view.default(mm_540, [2, 8192, 4096]); mm_540 = None + add_285 = torch.ops.aten.add.Tensor(view_1626, view_1628); view_1626 = view_1628 = None + convert_element_type_2279 = torch.ops.prims.convert_element_type.default(mm_539, torch.float32); mm_539 = None + reduce_scatter_tensor_202 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2279, 'avg', 64, '0'); convert_element_type_2279 = None + wait_tensor_493 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_202); reduce_scatter_tensor_202 = None + convert_element_type_2280 = torch.ops.prims.convert_element_type.default(add_285, torch.float32); add_285 = None + convert_element_type_2282 = torch.ops.prims.convert_element_type.default(wait_tensor_87, torch.float32); wait_tensor_87 = None 
+ mul_710 = torch.ops.aten.mul.Tensor(convert_element_type_2280, convert_element_type_2282); convert_element_type_2282 = None + mul_712 = torch.ops.aten.mul.Tensor(mul_76, mul_710) + sum_135 = torch.ops.aten.sum.dim_IntList(mul_712, [2], True); mul_712 = None + div_45 = torch.ops.aten.div.Tensor(mul_76, 4096) + mul_713 = torch.ops.aten.mul.Tensor(div_45, sum_135); div_45 = sum_135 = None + sub_68 = torch.ops.aten.sub.Tensor(mul_710, mul_713); mul_710 = mul_713 = None + mul_714 = torch.ops.aten.mul.Tensor(sub_68, rsqrt_19); sub_68 = rsqrt_19 = None + mul_715 = torch.ops.aten.mul.Tensor(convert_element_type_2280, mul_76); convert_element_type_2280 = mul_76 = None + sum_136 = torch.ops.aten.sum.dim_IntList(mul_715, [0, 1]); mul_715 = None + convert_element_type_2283 = torch.ops.prims.convert_element_type.default(mul_714, torch.bfloat16); mul_714 = None + add_286 = torch.ops.aten.add.Tensor(add_282, convert_element_type_2283); add_282 = convert_element_type_2283 = None + convert_element_type_default_20 = torch.ops.prims.convert_element_type.default(sum_136, torch.float32); sum_136 = None + reduce_scatter_tensor_203 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_20, 'avg', 64, '0'); convert_element_type_default_20 = None + wait_tensor_494 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_203); reduce_scatter_tensor_203 = None + view_1629 = torch.ops.aten.view.default(add_286, [16384, 4096]) + permute_1073 = torch.ops.aten.permute.default(view_1629, [1, 0]) + mm_541 = torch.ops.aten.mm.default(permute_1073, view_329); permute_1073 = view_329 = None + permute_1075 = torch.ops.aten.permute.default(permute_106, [1, 0]); permute_106 = None + mm_542 = torch.ops.aten.mm.default(view_1629, permute_1075); view_1629 = permute_1075 = None + view_1630 = torch.ops.aten.view.default(mm_542, [2, 8192, 4096]); mm_542 = None + convert_element_type_2290 = torch.ops.prims.convert_element_type.default(mm_541, torch.float32); 
mm_541 = None + reduce_scatter_tensor_204 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2290, 'avg', 64, '0'); convert_element_type_2290 = None + wait_tensor_495 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_204); reduce_scatter_tensor_204 = None + view_1631 = torch.ops.aten.view.default(view_1630, [2, 8192, 32, 128]); view_1630 = None + permute_1077 = torch.ops.aten.permute.default(view_1631, [0, 2, 1, 3]); view_1631 = None + convert_element_type_298 = torch.ops.prims.convert_element_type.default(primals_85, torch.bfloat16); primals_85 = None + all_gather_into_tensor_82 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_298, 64, '0'); convert_element_type_298 = None + wait_tensor_82 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_82); all_gather_into_tensor_82 = None + convert_element_type_299 = torch.ops.prims.convert_element_type.default(add_35, torch.float32); add_35 = None + pow_19 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_299, 2) + mean_18 = torch.ops.aten.mean.dim(pow_19, [2], True); pow_19 = None + add_36 = torch.ops.aten.add.Scalar(mean_18, 1e-05); mean_18 = None + rsqrt_18 = torch.ops.aten.rsqrt.default(add_36); add_36 = None + mul_72 = torch.ops.aten.mul.Tensor(convert_element_type_299, rsqrt_18); convert_element_type_299 = None + mul_73 = torch.ops.aten.mul.Tensor(mul_72, wait_tensor_82) + convert_element_type_300 = torch.ops.prims.convert_element_type.default(mul_73, torch.bfloat16); mul_73 = None + view_309 = torch.ops.aten.view.default(convert_element_type_300, [16384, 4096]); convert_element_type_300 = None + view_310 = torch.ops.aten.view.default(mm_63, [2, 8192, 4096]); mm_63 = None + convert_element_type_304 = torch.ops.prims.convert_element_type.default(primals_87, torch.bfloat16); primals_87 = None + all_gather_into_tensor_84 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_304, 64, '0'); 
convert_element_type_304 = None + wait_tensor_84 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_84); all_gather_into_tensor_84 = None + permute_100 = torch.ops.aten.permute.default(wait_tensor_84, [1, 0]); wait_tensor_84 = None + mm_64 = torch.ops.aten.mm.default(view_309, permute_100) + view_313 = torch.ops.aten.view.default(mm_64, [2, 8192, 1024]); mm_64 = None + view_316 = torch.ops.aten.view.default(mm_65, [2, 8192, 1024]); mm_65 = None + view_317 = torch.ops.aten.view.default(view_310, [2, 8192, -1, 128]); view_310 = None + view_318 = torch.ops.aten.view.default(view_313, [2, 8192, -1, 128]); view_313 = None + view_319 = torch.ops.aten.view.default(view_316, [2, 8192, -1, 128]); view_316 = None + convert_element_type_310 = torch.ops.prims.convert_element_type.default(view_317, torch.float32); view_317 = None + view_320 = torch.ops.aten.view.default(convert_element_type_310, [2, 8192, 32, -1, 2]); convert_element_type_310 = None + view_as_complex_18 = torch.ops.aten.view_as_complex.default(view_320); view_320 = None + convert_element_type_311 = torch.ops.prims.convert_element_type.default(view_318, torch.float32); view_318 = None + view_321 = torch.ops.aten.view.default(convert_element_type_311, [2, 8192, 8, -1, 2]); convert_element_type_311 = None + view_as_complex_19 = torch.ops.aten.view_as_complex.default(view_321); view_321 = None + mul_74 = torch.ops.aten.mul.Tensor(view_as_complex_18, view_16); view_as_complex_18 = None + view_as_real_18 = torch.ops.aten.view_as_real.default(mul_74); mul_74 = None + view_323 = torch.ops.aten.view.default(view_as_real_18, [2, 8192, 32, 128]); view_as_real_18 = None + mul_75 = torch.ops.aten.mul.Tensor(view_as_complex_19, view_16); view_as_complex_19 = None + view_as_real_19 = torch.ops.aten.view_as_real.default(mul_75); mul_75 = None + view_324 = torch.ops.aten.view.default(view_as_real_19, [2, 8192, 8, 128]); view_as_real_19 = None + convert_element_type_312 = 
torch.ops.prims.convert_element_type.default(view_323, torch.bfloat16); view_323 = None + convert_element_type_313 = torch.ops.prims.convert_element_type.default(view_324, torch.bfloat16); view_324 = None + unsqueeze_18 = torch.ops.aten.unsqueeze.default(convert_element_type_313, 3); convert_element_type_313 = None + expand_18 = torch.ops.aten.expand.default(unsqueeze_18, [2, 8192, 8, 4, 128]); unsqueeze_18 = None + clone_18 = torch.ops.aten.clone.default(expand_18, memory_format = torch.contiguous_format); expand_18 = None + view_325 = torch.ops.aten.view.default(clone_18, [2, 8192, 32, 128]); clone_18 = None + unsqueeze_19 = torch.ops.aten.unsqueeze.default(view_319, 3); view_319 = None + expand_19 = torch.ops.aten.expand.default(unsqueeze_19, [2, 8192, 8, 4, 128]); unsqueeze_19 = None + clone_19 = torch.ops.aten.clone.default(expand_19, memory_format = torch.contiguous_format); expand_19 = None + view_326 = torch.ops.aten.view.default(clone_19, [2, 8192, 32, 128]); clone_19 = None + permute_102 = torch.ops.aten.permute.default(convert_element_type_312, [0, 2, 1, 3]); convert_element_type_312 = None + permute_103 = torch.ops.aten.permute.default(view_325, [0, 2, 1, 3]); view_325 = None + permute_104 = torch.ops.aten.permute.default(view_326, [0, 2, 1, 3]); view_326 = None + _scaled_dot_product_cudnn_attention_backward_22 = torch.ops.aten._scaled_dot_product_cudnn_attention_backward.default(permute_1077, permute_102, permute_103, permute_104, getitem_81, getitem_82, getitem_87, getitem_88, None, None, None, 8192, 8192, 0.0, True); permute_1077 = permute_102 = permute_103 = permute_104 = getitem_81 = getitem_82 = getitem_87 = getitem_88 = None + getitem_354 = _scaled_dot_product_cudnn_attention_backward_22[0] + getitem_355 = _scaled_dot_product_cudnn_attention_backward_22[1] + getitem_356 = _scaled_dot_product_cudnn_attention_backward_22[2]; _scaled_dot_product_cudnn_attention_backward_22 = None + permute_1078 = torch.ops.aten.permute.default(getitem_356, [0, 2, 1, 
3]); getitem_356 = None + permute_1079 = torch.ops.aten.permute.default(getitem_355, [0, 2, 1, 3]); getitem_355 = None + permute_1080 = torch.ops.aten.permute.default(getitem_354, [0, 2, 1, 3]); getitem_354 = None + view_1632 = torch.ops.aten.view.default(permute_1078, [2, 8192, 8, 4, 128]); permute_1078 = None + sum_137 = torch.ops.aten.sum.dim_IntList(view_1632, [3], True); view_1632 = None + squeeze_44 = torch.ops.aten.squeeze.dim(sum_137, 3); sum_137 = None + view_1633 = torch.ops.aten.view.default(permute_1079, [2, 8192, 8, 4, 128]); permute_1079 = None + sum_138 = torch.ops.aten.sum.dim_IntList(view_1633, [3], True); view_1633 = None + squeeze_45 = torch.ops.aten.squeeze.dim(sum_138, 3); sum_138 = None + convert_element_type_2291 = torch.ops.prims.convert_element_type.default(squeeze_45, torch.float32); squeeze_45 = None + convert_element_type_2292 = torch.ops.prims.convert_element_type.default(permute_1080, torch.float32); permute_1080 = None + view_1634 = torch.ops.aten.view.default(convert_element_type_2291, [2, 8192, 8, 64, 2]); convert_element_type_2291 = None + view_as_complex_108 = torch.ops.aten.view_as_complex.default(view_1634); view_1634 = None + mul_716 = torch.ops.aten.mul.Tensor(view_as_complex_108, _conj); view_as_complex_108 = None + view_1635 = torch.ops.aten.view.default(convert_element_type_2292, [2, 8192, 32, 64, 2]); convert_element_type_2292 = None + view_as_complex_109 = torch.ops.aten.view_as_complex.default(view_1635); view_1635 = None + mul_717 = torch.ops.aten.mul.Tensor(view_as_complex_109, _conj); view_as_complex_109 = None + view_as_real_108 = torch.ops.aten.view_as_real.default(mul_716); mul_716 = None + view_1636 = torch.ops.aten.view.default(view_as_real_108, [2, 8192, 8, 128]); view_as_real_108 = None + convert_element_type_2293 = torch.ops.prims.convert_element_type.default(view_1636, torch.bfloat16); view_1636 = None + view_as_real_109 = torch.ops.aten.view_as_real.default(mul_717); mul_717 = None + view_1637 = 
torch.ops.aten.view.default(view_as_real_109, [2, 8192, 32, 128]); view_as_real_109 = None + convert_element_type_2294 = torch.ops.prims.convert_element_type.default(view_1637, torch.bfloat16); view_1637 = None + view_1638 = torch.ops.aten.view.default(squeeze_44, [2, 8192, 1024]); squeeze_44 = None + view_1639 = torch.ops.aten.view.default(convert_element_type_2293, [2, 8192, 1024]); convert_element_type_2293 = None + view_1640 = torch.ops.aten.view.default(convert_element_type_2294, [2, 8192, 4096]); convert_element_type_2294 = None + view_1641 = torch.ops.aten.view.default(view_1638, [16384, 1024]); view_1638 = None + permute_1081 = torch.ops.aten.permute.default(view_1641, [1, 0]) + mm_543 = torch.ops.aten.mm.default(permute_1081, view_309); permute_1081 = None + convert_element_type_307 = torch.ops.prims.convert_element_type.default(primals_88, torch.bfloat16); primals_88 = None + all_gather_into_tensor_85 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_307, 64, '0'); convert_element_type_307 = None + wait_tensor_85 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_85); all_gather_into_tensor_85 = None + permute_101 = torch.ops.aten.permute.default(wait_tensor_85, [1, 0]); wait_tensor_85 = None + permute_1083 = torch.ops.aten.permute.default(permute_101, [1, 0]); permute_101 = None + mm_544 = torch.ops.aten.mm.default(view_1641, permute_1083); view_1641 = permute_1083 = None + view_1642 = torch.ops.aten.view.default(mm_544, [2, 8192, 4096]); mm_544 = None + convert_element_type_2299 = torch.ops.prims.convert_element_type.default(mm_543, torch.float32); mm_543 = None + reduce_scatter_tensor_205 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2299, 'avg', 64, '0'); convert_element_type_2299 = None + wait_tensor_496 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_205); reduce_scatter_tensor_205 = None + view_1643 = torch.ops.aten.view.default(view_1639, 
[16384, 1024]); view_1639 = None + permute_1085 = torch.ops.aten.permute.default(view_1643, [1, 0]) + mm_545 = torch.ops.aten.mm.default(permute_1085, view_309); permute_1085 = None + permute_1087 = torch.ops.aten.permute.default(permute_100, [1, 0]); permute_100 = None + mm_546 = torch.ops.aten.mm.default(view_1643, permute_1087); view_1643 = permute_1087 = None + view_1644 = torch.ops.aten.view.default(mm_546, [2, 8192, 4096]); mm_546 = None + add_287 = torch.ops.aten.add.Tensor(view_1642, view_1644); view_1642 = view_1644 = None + convert_element_type_2304 = torch.ops.prims.convert_element_type.default(mm_545, torch.float32); mm_545 = None + reduce_scatter_tensor_206 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2304, 'avg', 64, '0'); convert_element_type_2304 = None + wait_tensor_497 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_206); reduce_scatter_tensor_206 = None + view_1645 = torch.ops.aten.view.default(view_1640, [16384, 4096]); view_1640 = None + permute_1089 = torch.ops.aten.permute.default(view_1645, [1, 0]) + mm_547 = torch.ops.aten.mm.default(permute_1089, view_309); permute_1089 = view_309 = None + convert_element_type_301 = torch.ops.prims.convert_element_type.default(primals_86, torch.bfloat16); primals_86 = None + all_gather_into_tensor_83 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_301, 64, '0'); convert_element_type_301 = None + wait_tensor_83 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_83); all_gather_into_tensor_83 = None + permute_99 = torch.ops.aten.permute.default(wait_tensor_83, [1, 0]); wait_tensor_83 = None + permute_1091 = torch.ops.aten.permute.default(permute_99, [1, 0]); permute_99 = None + mm_548 = torch.ops.aten.mm.default(view_1645, permute_1091); view_1645 = permute_1091 = None + view_1646 = torch.ops.aten.view.default(mm_548, [2, 8192, 4096]); mm_548 = None + add_288 = torch.ops.aten.add.Tensor(add_287, 
view_1646); add_287 = view_1646 = None + convert_element_type_2309 = torch.ops.prims.convert_element_type.default(mm_547, torch.float32); mm_547 = None + reduce_scatter_tensor_207 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2309, 'avg', 64, '0'); convert_element_type_2309 = None + wait_tensor_498 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_207); reduce_scatter_tensor_207 = None + convert_element_type_2310 = torch.ops.prims.convert_element_type.default(add_288, torch.float32); add_288 = None + convert_element_type_2312 = torch.ops.prims.convert_element_type.default(wait_tensor_82, torch.float32); wait_tensor_82 = None + mul_718 = torch.ops.aten.mul.Tensor(convert_element_type_2310, convert_element_type_2312); convert_element_type_2312 = None + mul_720 = torch.ops.aten.mul.Tensor(mul_72, mul_718) + sum_139 = torch.ops.aten.sum.dim_IntList(mul_720, [2], True); mul_720 = None + div_46 = torch.ops.aten.div.Tensor(mul_72, 4096) + mul_721 = torch.ops.aten.mul.Tensor(div_46, sum_139); div_46 = sum_139 = None + sub_69 = torch.ops.aten.sub.Tensor(mul_718, mul_721); mul_718 = mul_721 = None + mul_722 = torch.ops.aten.mul.Tensor(sub_69, rsqrt_18); sub_69 = rsqrt_18 = None + mul_723 = torch.ops.aten.mul.Tensor(convert_element_type_2310, mul_72); convert_element_type_2310 = mul_72 = None + sum_140 = torch.ops.aten.sum.dim_IntList(mul_723, [0, 1]); mul_723 = None + convert_element_type_2313 = torch.ops.prims.convert_element_type.default(mul_722, torch.bfloat16); mul_722 = None + add_289 = torch.ops.aten.add.Tensor(add_286, convert_element_type_2313); add_286 = convert_element_type_2313 = None + convert_element_type_default_19 = torch.ops.prims.convert_element_type.default(sum_140, torch.float32); sum_140 = None + reduce_scatter_tensor_208 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_19, 'avg', 64, '0'); convert_element_type_default_19 = None + wait_tensor_499 = 
torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_208); reduce_scatter_tensor_208 = None + view_1647 = torch.ops.aten.view.default(add_289, [16384, 4096]) + permute_1093 = torch.ops.aten.permute.default(view_1647, [1, 0]) + permute_94 = torch.ops.aten.permute.default(getitem_72, [0, 2, 1, 3]) + view_293 = torch.ops.aten.view.default(permute_94, [2, 8192, -1]); permute_94 = None + convert_element_type_281 = torch.ops.prims.convert_element_type.default(primals_80, torch.bfloat16); primals_80 = None + all_gather_into_tensor_77 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_281, 64, '0'); convert_element_type_281 = None + wait_tensor_77 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_77); all_gather_into_tensor_77 = None + permute_95 = torch.ops.aten.permute.default(wait_tensor_77, [1, 0]); wait_tensor_77 = None + view_295 = torch.ops.aten.view.default(view_293, [16384, 4096]); view_293 = None + mm_59 = torch.ops.aten.mm.default(view_295, permute_95) + view_296 = torch.ops.aten.view.default(mm_59, [2, 8192, 4096]); mm_59 = None + add_33 = torch.ops.aten.add.Tensor(add_31, view_296); view_296 = None + convert_element_type_284 = torch.ops.prims.convert_element_type.default(primals_81, torch.bfloat16); primals_81 = None + all_gather_into_tensor_78 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_284, 64, '0'); convert_element_type_284 = None + wait_tensor_78 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_78); all_gather_into_tensor_78 = None + convert_element_type_285 = torch.ops.prims.convert_element_type.default(add_33, torch.float32); add_33 = None + pow_18 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_285, 2) + mean_17 = torch.ops.aten.mean.dim(pow_18, [2], True); pow_18 = None + add_34 = torch.ops.aten.add.Scalar(mean_17, 1e-05); mean_17 = None + rsqrt_17 = torch.ops.aten.rsqrt.default(add_34); add_34 = None + mul_68 = 
torch.ops.aten.mul.Tensor(convert_element_type_285, rsqrt_17); convert_element_type_285 = None + mul_69 = torch.ops.aten.mul.Tensor(mul_68, wait_tensor_78) + convert_element_type_286 = torch.ops.prims.convert_element_type.default(mul_69, torch.bfloat16); mul_69 = None + view_299 = torch.ops.aten.view.default(convert_element_type_286, [16384, 4096]); convert_element_type_286 = None + view_300 = torch.ops.aten.view.default(mm_60, [2, 8192, 14336]); mm_60 = None + convert_element_type_290 = torch.ops.prims.convert_element_type.default(view_300, torch.float32); view_300 = None + sigmoid_8 = torch.ops.aten.sigmoid.default(convert_element_type_290) + mul_70 = torch.ops.aten.mul.Tensor(convert_element_type_290, sigmoid_8); sigmoid_8 = None + convert_element_type_291 = torch.ops.prims.convert_element_type.default(mul_70, torch.bfloat16); mul_70 = None + convert_element_type_292 = torch.ops.prims.convert_element_type.default(primals_83, torch.bfloat16); primals_83 = None + all_gather_into_tensor_80 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_292, 64, '0'); convert_element_type_292 = None + wait_tensor_80 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_80); all_gather_into_tensor_80 = None + permute_97 = torch.ops.aten.permute.default(wait_tensor_80, [1, 0]); wait_tensor_80 = None + mm_61 = torch.ops.aten.mm.default(view_299, permute_97) + view_303 = torch.ops.aten.view.default(mm_61, [2, 8192, 14336]); mm_61 = None + mul_71 = torch.ops.aten.mul.Tensor(convert_element_type_291, view_303) + view_305 = torch.ops.aten.view.default(mul_71, [16384, 14336]); mul_71 = None + mm_549 = torch.ops.aten.mm.default(permute_1093, view_305); permute_1093 = view_305 = None + convert_element_type_295 = torch.ops.prims.convert_element_type.default(primals_84, torch.bfloat16); primals_84 = None + all_gather_into_tensor_81 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_295, 64, '0'); 
convert_element_type_295 = None + wait_tensor_81 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_81); all_gather_into_tensor_81 = None + permute_98 = torch.ops.aten.permute.default(wait_tensor_81, [1, 0]); wait_tensor_81 = None + permute_1095 = torch.ops.aten.permute.default(permute_98, [1, 0]); permute_98 = None + mm_550 = torch.ops.aten.mm.default(view_1647, permute_1095); view_1647 = permute_1095 = None + view_1648 = torch.ops.aten.view.default(mm_550, [2, 8192, 14336]); mm_550 = None + convert_element_type_2320 = torch.ops.prims.convert_element_type.default(mm_549, torch.float32); mm_549 = None + reduce_scatter_tensor_209 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2320, 'avg', 64, '0'); convert_element_type_2320 = None + wait_tensor_500 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_209); reduce_scatter_tensor_209 = None + mul_724 = torch.ops.aten.mul.Tensor(view_1648, convert_element_type_291); convert_element_type_291 = None + mul_725 = torch.ops.aten.mul.Tensor(view_1648, view_303); view_1648 = view_303 = None + view_1649 = torch.ops.aten.view.default(mul_724, [16384, 14336]); mul_724 = None + permute_1097 = torch.ops.aten.permute.default(view_1649, [1, 0]) + mm_551 = torch.ops.aten.mm.default(permute_1097, view_299); permute_1097 = None + permute_1099 = torch.ops.aten.permute.default(permute_97, [1, 0]); permute_97 = None + mm_552 = torch.ops.aten.mm.default(view_1649, permute_1099); view_1649 = permute_1099 = None + view_1650 = torch.ops.aten.view.default(mm_552, [2, 8192, 4096]); mm_552 = None + convert_element_type_2325 = torch.ops.prims.convert_element_type.default(mm_551, torch.float32); mm_551 = None + reduce_scatter_tensor_210 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2325, 'avg', 64, '0'); convert_element_type_2325 = None + wait_tensor_501 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_210); 
reduce_scatter_tensor_210 = None + convert_element_type_2326 = torch.ops.prims.convert_element_type.default(mul_725, torch.float32); mul_725 = None + neg_23 = torch.ops.aten.neg.default(convert_element_type_290) + exp_23 = torch.ops.aten.exp.default(neg_23); neg_23 = None + add_290 = torch.ops.aten.add.Tensor(exp_23, 1); exp_23 = None + reciprocal_23 = torch.ops.aten.reciprocal.default(add_290); add_290 = None + mul_726 = torch.ops.aten.mul.Tensor(reciprocal_23, 1); reciprocal_23 = None + mul_727 = torch.ops.aten.mul.Tensor(convert_element_type_2326, mul_726); convert_element_type_2326 = None + sub_70 = torch.ops.aten.sub.Tensor(1, mul_726); mul_726 = None + mul_728 = torch.ops.aten.mul.Tensor(convert_element_type_290, sub_70); convert_element_type_290 = sub_70 = None + add_291 = torch.ops.aten.add.Tensor(mul_728, 1); mul_728 = None + mul_729 = torch.ops.aten.mul.Tensor(mul_727, add_291); mul_727 = add_291 = None + convert_element_type_2328 = torch.ops.prims.convert_element_type.default(mul_729, torch.bfloat16); mul_729 = None + view_1651 = torch.ops.aten.view.default(convert_element_type_2328, [16384, 14336]); convert_element_type_2328 = None + permute_1101 = torch.ops.aten.permute.default(view_1651, [1, 0]) + mm_553 = torch.ops.aten.mm.default(permute_1101, view_299); permute_1101 = view_299 = None + convert_element_type_287 = torch.ops.prims.convert_element_type.default(primals_82, torch.bfloat16); primals_82 = None + all_gather_into_tensor_79 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_287, 64, '0'); convert_element_type_287 = None + wait_tensor_79 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_79); all_gather_into_tensor_79 = None + permute_96 = torch.ops.aten.permute.default(wait_tensor_79, [1, 0]); wait_tensor_79 = None + permute_1103 = torch.ops.aten.permute.default(permute_96, [1, 0]); permute_96 = None + mm_554 = torch.ops.aten.mm.default(view_1651, permute_1103); view_1651 = permute_1103 = 
None + view_1652 = torch.ops.aten.view.default(mm_554, [2, 8192, 4096]); mm_554 = None + add_292 = torch.ops.aten.add.Tensor(view_1650, view_1652); view_1650 = view_1652 = None + convert_element_type_2333 = torch.ops.prims.convert_element_type.default(mm_553, torch.float32); mm_553 = None + reduce_scatter_tensor_211 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2333, 'avg', 64, '0'); convert_element_type_2333 = None + wait_tensor_502 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_211); reduce_scatter_tensor_211 = None + convert_element_type_2334 = torch.ops.prims.convert_element_type.default(add_292, torch.float32); add_292 = None + convert_element_type_2336 = torch.ops.prims.convert_element_type.default(wait_tensor_78, torch.float32); wait_tensor_78 = None + mul_730 = torch.ops.aten.mul.Tensor(convert_element_type_2334, convert_element_type_2336); convert_element_type_2336 = None + mul_732 = torch.ops.aten.mul.Tensor(mul_68, mul_730) + sum_141 = torch.ops.aten.sum.dim_IntList(mul_732, [2], True); mul_732 = None + div_47 = torch.ops.aten.div.Tensor(mul_68, 4096) + mul_733 = torch.ops.aten.mul.Tensor(div_47, sum_141); div_47 = sum_141 = None + sub_71 = torch.ops.aten.sub.Tensor(mul_730, mul_733); mul_730 = mul_733 = None + mul_734 = torch.ops.aten.mul.Tensor(sub_71, rsqrt_17); sub_71 = rsqrt_17 = None + mul_735 = torch.ops.aten.mul.Tensor(convert_element_type_2334, mul_68); convert_element_type_2334 = mul_68 = None + sum_142 = torch.ops.aten.sum.dim_IntList(mul_735, [0, 1]); mul_735 = None + convert_element_type_2337 = torch.ops.prims.convert_element_type.default(mul_734, torch.bfloat16); mul_734 = None + add_293 = torch.ops.aten.add.Tensor(add_289, convert_element_type_2337); add_289 = convert_element_type_2337 = None + convert_element_type_default_18 = torch.ops.prims.convert_element_type.default(sum_142, torch.float32); sum_142 = None + reduce_scatter_tensor_212 = 
torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_18, 'avg', 64, '0'); convert_element_type_default_18 = None + wait_tensor_503 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_212); reduce_scatter_tensor_212 = None + view_1653 = torch.ops.aten.view.default(add_293, [16384, 4096]) + permute_1105 = torch.ops.aten.permute.default(view_1653, [1, 0]) + mm_555 = torch.ops.aten.mm.default(permute_1105, view_295); permute_1105 = view_295 = None + permute_1107 = torch.ops.aten.permute.default(permute_95, [1, 0]); permute_95 = None + mm_556 = torch.ops.aten.mm.default(view_1653, permute_1107); view_1653 = permute_1107 = None + view_1654 = torch.ops.aten.view.default(mm_556, [2, 8192, 4096]); mm_556 = None + convert_element_type_2344 = torch.ops.prims.convert_element_type.default(mm_555, torch.float32); mm_555 = None + reduce_scatter_tensor_213 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2344, 'avg', 64, '0'); convert_element_type_2344 = None + wait_tensor_504 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_213); reduce_scatter_tensor_213 = None + view_1655 = torch.ops.aten.view.default(view_1654, [2, 8192, 32, 128]); view_1654 = None + permute_1109 = torch.ops.aten.permute.default(view_1655, [0, 2, 1, 3]); view_1655 = None + convert_element_type_265 = torch.ops.prims.convert_element_type.default(primals_76, torch.bfloat16); primals_76 = None + all_gather_into_tensor_73 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_265, 64, '0'); convert_element_type_265 = None + wait_tensor_73 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_73); all_gather_into_tensor_73 = None + convert_element_type_266 = torch.ops.prims.convert_element_type.default(add_31, torch.float32); add_31 = None + pow_17 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_266, 2) + mean_16 = torch.ops.aten.mean.dim(pow_17, [2], True); 
pow_17 = None + add_32 = torch.ops.aten.add.Scalar(mean_16, 1e-05); mean_16 = None + rsqrt_16 = torch.ops.aten.rsqrt.default(add_32); add_32 = None + mul_64 = torch.ops.aten.mul.Tensor(convert_element_type_266, rsqrt_16); convert_element_type_266 = None + mul_65 = torch.ops.aten.mul.Tensor(mul_64, wait_tensor_73) + convert_element_type_267 = torch.ops.prims.convert_element_type.default(mul_65, torch.bfloat16); mul_65 = None + view_275 = torch.ops.aten.view.default(convert_element_type_267, [16384, 4096]); convert_element_type_267 = None + view_276 = torch.ops.aten.view.default(mm_56, [2, 8192, 4096]); mm_56 = None + convert_element_type_271 = torch.ops.prims.convert_element_type.default(primals_78, torch.bfloat16); primals_78 = None + all_gather_into_tensor_75 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_271, 64, '0'); convert_element_type_271 = None + wait_tensor_75 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_75); all_gather_into_tensor_75 = None + permute_89 = torch.ops.aten.permute.default(wait_tensor_75, [1, 0]); wait_tensor_75 = None + mm_57 = torch.ops.aten.mm.default(view_275, permute_89) + view_279 = torch.ops.aten.view.default(mm_57, [2, 8192, 1024]); mm_57 = None + view_282 = torch.ops.aten.view.default(mm_58, [2, 8192, 1024]); mm_58 = None + view_283 = torch.ops.aten.view.default(view_276, [2, 8192, -1, 128]); view_276 = None + view_284 = torch.ops.aten.view.default(view_279, [2, 8192, -1, 128]); view_279 = None + view_285 = torch.ops.aten.view.default(view_282, [2, 8192, -1, 128]); view_282 = None + convert_element_type_277 = torch.ops.prims.convert_element_type.default(view_283, torch.float32); view_283 = None + view_286 = torch.ops.aten.view.default(convert_element_type_277, [2, 8192, 32, -1, 2]); convert_element_type_277 = None + view_as_complex_16 = torch.ops.aten.view_as_complex.default(view_286); view_286 = None + convert_element_type_278 = 
torch.ops.prims.convert_element_type.default(view_284, torch.float32); view_284 = None + view_287 = torch.ops.aten.view.default(convert_element_type_278, [2, 8192, 8, -1, 2]); convert_element_type_278 = None + view_as_complex_17 = torch.ops.aten.view_as_complex.default(view_287); view_287 = None + mul_66 = torch.ops.aten.mul.Tensor(view_as_complex_16, view_16); view_as_complex_16 = None + view_as_real_16 = torch.ops.aten.view_as_real.default(mul_66); mul_66 = None + view_289 = torch.ops.aten.view.default(view_as_real_16, [2, 8192, 32, 128]); view_as_real_16 = None + mul_67 = torch.ops.aten.mul.Tensor(view_as_complex_17, view_16); view_as_complex_17 = None + view_as_real_17 = torch.ops.aten.view_as_real.default(mul_67); mul_67 = None + view_290 = torch.ops.aten.view.default(view_as_real_17, [2, 8192, 8, 128]); view_as_real_17 = None + convert_element_type_279 = torch.ops.prims.convert_element_type.default(view_289, torch.bfloat16); view_289 = None + convert_element_type_280 = torch.ops.prims.convert_element_type.default(view_290, torch.bfloat16); view_290 = None + unsqueeze_16 = torch.ops.aten.unsqueeze.default(convert_element_type_280, 3); convert_element_type_280 = None + expand_16 = torch.ops.aten.expand.default(unsqueeze_16, [2, 8192, 8, 4, 128]); unsqueeze_16 = None + clone_16 = torch.ops.aten.clone.default(expand_16, memory_format = torch.contiguous_format); expand_16 = None + view_291 = torch.ops.aten.view.default(clone_16, [2, 8192, 32, 128]); clone_16 = None + unsqueeze_17 = torch.ops.aten.unsqueeze.default(view_285, 3); view_285 = None + expand_17 = torch.ops.aten.expand.default(unsqueeze_17, [2, 8192, 8, 4, 128]); unsqueeze_17 = None + clone_17 = torch.ops.aten.clone.default(expand_17, memory_format = torch.contiguous_format); expand_17 = None + view_292 = torch.ops.aten.view.default(clone_17, [2, 8192, 32, 128]); clone_17 = None + permute_91 = torch.ops.aten.permute.default(convert_element_type_279, [0, 2, 1, 3]); convert_element_type_279 = None + 
permute_92 = torch.ops.aten.permute.default(view_291, [0, 2, 1, 3]); view_291 = None + permute_93 = torch.ops.aten.permute.default(view_292, [0, 2, 1, 3]); view_292 = None + _scaled_dot_product_cudnn_attention_backward_23 = torch.ops.aten._scaled_dot_product_cudnn_attention_backward.default(permute_1109, permute_91, permute_92, permute_93, getitem_72, getitem_73, getitem_78, getitem_79, None, None, None, 8192, 8192, 0.0, True); permute_1109 = permute_91 = permute_92 = permute_93 = getitem_72 = getitem_73 = getitem_78 = getitem_79 = None + getitem_357 = _scaled_dot_product_cudnn_attention_backward_23[0] + getitem_358 = _scaled_dot_product_cudnn_attention_backward_23[1] + getitem_359 = _scaled_dot_product_cudnn_attention_backward_23[2]; _scaled_dot_product_cudnn_attention_backward_23 = None + permute_1110 = torch.ops.aten.permute.default(getitem_359, [0, 2, 1, 3]); getitem_359 = None + permute_1111 = torch.ops.aten.permute.default(getitem_358, [0, 2, 1, 3]); getitem_358 = None + permute_1112 = torch.ops.aten.permute.default(getitem_357, [0, 2, 1, 3]); getitem_357 = None + view_1656 = torch.ops.aten.view.default(permute_1110, [2, 8192, 8, 4, 128]); permute_1110 = None + sum_143 = torch.ops.aten.sum.dim_IntList(view_1656, [3], True); view_1656 = None + squeeze_46 = torch.ops.aten.squeeze.dim(sum_143, 3); sum_143 = None + view_1657 = torch.ops.aten.view.default(permute_1111, [2, 8192, 8, 4, 128]); permute_1111 = None + sum_144 = torch.ops.aten.sum.dim_IntList(view_1657, [3], True); view_1657 = None + squeeze_47 = torch.ops.aten.squeeze.dim(sum_144, 3); sum_144 = None + convert_element_type_2345 = torch.ops.prims.convert_element_type.default(squeeze_47, torch.float32); squeeze_47 = None + convert_element_type_2346 = torch.ops.prims.convert_element_type.default(permute_1112, torch.float32); permute_1112 = None + view_1658 = torch.ops.aten.view.default(convert_element_type_2345, [2, 8192, 8, 64, 2]); convert_element_type_2345 = None + view_as_complex_110 = 
torch.ops.aten.view_as_complex.default(view_1658); view_1658 = None + mul_736 = torch.ops.aten.mul.Tensor(view_as_complex_110, _conj); view_as_complex_110 = None + view_1659 = torch.ops.aten.view.default(convert_element_type_2346, [2, 8192, 32, 64, 2]); convert_element_type_2346 = None + view_as_complex_111 = torch.ops.aten.view_as_complex.default(view_1659); view_1659 = None + mul_737 = torch.ops.aten.mul.Tensor(view_as_complex_111, _conj); view_as_complex_111 = None + view_as_real_110 = torch.ops.aten.view_as_real.default(mul_736); mul_736 = None + view_1660 = torch.ops.aten.view.default(view_as_real_110, [2, 8192, 8, 128]); view_as_real_110 = None + convert_element_type_2347 = torch.ops.prims.convert_element_type.default(view_1660, torch.bfloat16); view_1660 = None + view_as_real_111 = torch.ops.aten.view_as_real.default(mul_737); mul_737 = None + view_1661 = torch.ops.aten.view.default(view_as_real_111, [2, 8192, 32, 128]); view_as_real_111 = None + convert_element_type_2348 = torch.ops.prims.convert_element_type.default(view_1661, torch.bfloat16); view_1661 = None + view_1662 = torch.ops.aten.view.default(squeeze_46, [2, 8192, 1024]); squeeze_46 = None + view_1663 = torch.ops.aten.view.default(convert_element_type_2347, [2, 8192, 1024]); convert_element_type_2347 = None + view_1664 = torch.ops.aten.view.default(convert_element_type_2348, [2, 8192, 4096]); convert_element_type_2348 = None + view_1665 = torch.ops.aten.view.default(view_1662, [16384, 1024]); view_1662 = None + permute_1113 = torch.ops.aten.permute.default(view_1665, [1, 0]) + mm_557 = torch.ops.aten.mm.default(permute_1113, view_275); permute_1113 = None + convert_element_type_274 = torch.ops.prims.convert_element_type.default(primals_79, torch.bfloat16); primals_79 = None + all_gather_into_tensor_76 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_274, 64, '0'); convert_element_type_274 = None + wait_tensor_76 = 
torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_76); all_gather_into_tensor_76 = None + permute_90 = torch.ops.aten.permute.default(wait_tensor_76, [1, 0]); wait_tensor_76 = None + permute_1115 = torch.ops.aten.permute.default(permute_90, [1, 0]); permute_90 = None + mm_558 = torch.ops.aten.mm.default(view_1665, permute_1115); view_1665 = permute_1115 = None + view_1666 = torch.ops.aten.view.default(mm_558, [2, 8192, 4096]); mm_558 = None + convert_element_type_2353 = torch.ops.prims.convert_element_type.default(mm_557, torch.float32); mm_557 = None + reduce_scatter_tensor_214 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2353, 'avg', 64, '0'); convert_element_type_2353 = None + wait_tensor_505 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_214); reduce_scatter_tensor_214 = None + view_1667 = torch.ops.aten.view.default(view_1663, [16384, 1024]); view_1663 = None + permute_1117 = torch.ops.aten.permute.default(view_1667, [1, 0]) + mm_559 = torch.ops.aten.mm.default(permute_1117, view_275); permute_1117 = None + permute_1119 = torch.ops.aten.permute.default(permute_89, [1, 0]); permute_89 = None + mm_560 = torch.ops.aten.mm.default(view_1667, permute_1119); view_1667 = permute_1119 = None + view_1668 = torch.ops.aten.view.default(mm_560, [2, 8192, 4096]); mm_560 = None + add_294 = torch.ops.aten.add.Tensor(view_1666, view_1668); view_1666 = view_1668 = None + convert_element_type_2358 = torch.ops.prims.convert_element_type.default(mm_559, torch.float32); mm_559 = None + reduce_scatter_tensor_215 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2358, 'avg', 64, '0'); convert_element_type_2358 = None + wait_tensor_506 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_215); reduce_scatter_tensor_215 = None + view_1669 = torch.ops.aten.view.default(view_1664, [16384, 4096]); view_1664 = None + permute_1121 = 
torch.ops.aten.permute.default(view_1669, [1, 0]) + mm_561 = torch.ops.aten.mm.default(permute_1121, view_275); permute_1121 = view_275 = None + convert_element_type_268 = torch.ops.prims.convert_element_type.default(primals_77, torch.bfloat16); primals_77 = None + all_gather_into_tensor_74 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_268, 64, '0'); convert_element_type_268 = None + wait_tensor_74 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_74); all_gather_into_tensor_74 = None + permute_88 = torch.ops.aten.permute.default(wait_tensor_74, [1, 0]); wait_tensor_74 = None + permute_1123 = torch.ops.aten.permute.default(permute_88, [1, 0]); permute_88 = None + mm_562 = torch.ops.aten.mm.default(view_1669, permute_1123); view_1669 = permute_1123 = None + view_1670 = torch.ops.aten.view.default(mm_562, [2, 8192, 4096]); mm_562 = None + add_295 = torch.ops.aten.add.Tensor(add_294, view_1670); add_294 = view_1670 = None + convert_element_type_2363 = torch.ops.prims.convert_element_type.default(mm_561, torch.float32); mm_561 = None + reduce_scatter_tensor_216 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2363, 'avg', 64, '0'); convert_element_type_2363 = None + wait_tensor_507 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_216); reduce_scatter_tensor_216 = None + convert_element_type_2364 = torch.ops.prims.convert_element_type.default(add_295, torch.float32); add_295 = None + convert_element_type_2366 = torch.ops.prims.convert_element_type.default(wait_tensor_73, torch.float32); wait_tensor_73 = None + mul_738 = torch.ops.aten.mul.Tensor(convert_element_type_2364, convert_element_type_2366); convert_element_type_2366 = None + mul_740 = torch.ops.aten.mul.Tensor(mul_64, mul_738) + sum_145 = torch.ops.aten.sum.dim_IntList(mul_740, [2], True); mul_740 = None + div_48 = torch.ops.aten.div.Tensor(mul_64, 4096) + mul_741 = torch.ops.aten.mul.Tensor(div_48, 
sum_145); div_48 = sum_145 = None + sub_72 = torch.ops.aten.sub.Tensor(mul_738, mul_741); mul_738 = mul_741 = None + mul_742 = torch.ops.aten.mul.Tensor(sub_72, rsqrt_16); sub_72 = rsqrt_16 = None + mul_743 = torch.ops.aten.mul.Tensor(convert_element_type_2364, mul_64); convert_element_type_2364 = mul_64 = None + sum_146 = torch.ops.aten.sum.dim_IntList(mul_743, [0, 1]); mul_743 = None + convert_element_type_2367 = torch.ops.prims.convert_element_type.default(mul_742, torch.bfloat16); mul_742 = None + add_296 = torch.ops.aten.add.Tensor(add_293, convert_element_type_2367); add_293 = convert_element_type_2367 = None + convert_element_type_default_17 = torch.ops.prims.convert_element_type.default(sum_146, torch.float32); sum_146 = None + reduce_scatter_tensor_217 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_17, 'avg', 64, '0'); convert_element_type_default_17 = None + wait_tensor_508 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_217); reduce_scatter_tensor_217 = None + view_1671 = torch.ops.aten.view.default(add_296, [16384, 4096]) + permute_1125 = torch.ops.aten.permute.default(view_1671, [1, 0]) + permute_83 = torch.ops.aten.permute.default(getitem_63, [0, 2, 1, 3]) + view_259 = torch.ops.aten.view.default(permute_83, [2, 8192, -1]); permute_83 = None + convert_element_type_248 = torch.ops.prims.convert_element_type.default(primals_71, torch.bfloat16); primals_71 = None + all_gather_into_tensor_68 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_248, 64, '0'); convert_element_type_248 = None + wait_tensor_68 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_68); all_gather_into_tensor_68 = None + permute_84 = torch.ops.aten.permute.default(wait_tensor_68, [1, 0]); wait_tensor_68 = None + view_261 = torch.ops.aten.view.default(view_259, [16384, 4096]); view_259 = None + mm_52 = torch.ops.aten.mm.default(view_261, permute_84) + view_262 = 
torch.ops.aten.view.default(mm_52, [2, 8192, 4096]); mm_52 = None + add_29 = torch.ops.aten.add.Tensor(add_27, view_262); view_262 = None + convert_element_type_251 = torch.ops.prims.convert_element_type.default(primals_72, torch.bfloat16); primals_72 = None + all_gather_into_tensor_69 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_251, 64, '0'); convert_element_type_251 = None + wait_tensor_69 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_69); all_gather_into_tensor_69 = None + convert_element_type_252 = torch.ops.prims.convert_element_type.default(add_29, torch.float32); add_29 = None + pow_16 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_252, 2) + mean_15 = torch.ops.aten.mean.dim(pow_16, [2], True); pow_16 = None + add_30 = torch.ops.aten.add.Scalar(mean_15, 1e-05); mean_15 = None + rsqrt_15 = torch.ops.aten.rsqrt.default(add_30); add_30 = None + mul_60 = torch.ops.aten.mul.Tensor(convert_element_type_252, rsqrt_15); convert_element_type_252 = None + mul_61 = torch.ops.aten.mul.Tensor(mul_60, wait_tensor_69) + convert_element_type_253 = torch.ops.prims.convert_element_type.default(mul_61, torch.bfloat16); mul_61 = None + view_265 = torch.ops.aten.view.default(convert_element_type_253, [16384, 4096]); convert_element_type_253 = None + view_266 = torch.ops.aten.view.default(mm_53, [2, 8192, 14336]); mm_53 = None + convert_element_type_257 = torch.ops.prims.convert_element_type.default(view_266, torch.float32); view_266 = None + sigmoid_7 = torch.ops.aten.sigmoid.default(convert_element_type_257) + mul_62 = torch.ops.aten.mul.Tensor(convert_element_type_257, sigmoid_7); sigmoid_7 = None + convert_element_type_258 = torch.ops.prims.convert_element_type.default(mul_62, torch.bfloat16); mul_62 = None + convert_element_type_259 = torch.ops.prims.convert_element_type.default(primals_74, torch.bfloat16); primals_74 = None + all_gather_into_tensor_71 = 
torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_259, 64, '0'); convert_element_type_259 = None + wait_tensor_71 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_71); all_gather_into_tensor_71 = None + permute_86 = torch.ops.aten.permute.default(wait_tensor_71, [1, 0]); wait_tensor_71 = None + mm_54 = torch.ops.aten.mm.default(view_265, permute_86) + view_269 = torch.ops.aten.view.default(mm_54, [2, 8192, 14336]); mm_54 = None + mul_63 = torch.ops.aten.mul.Tensor(convert_element_type_258, view_269) + view_271 = torch.ops.aten.view.default(mul_63, [16384, 14336]); mul_63 = None + mm_563 = torch.ops.aten.mm.default(permute_1125, view_271); permute_1125 = view_271 = None + convert_element_type_262 = torch.ops.prims.convert_element_type.default(primals_75, torch.bfloat16); primals_75 = None + all_gather_into_tensor_72 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_262, 64, '0'); convert_element_type_262 = None + wait_tensor_72 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_72); all_gather_into_tensor_72 = None + permute_87 = torch.ops.aten.permute.default(wait_tensor_72, [1, 0]); wait_tensor_72 = None + permute_1127 = torch.ops.aten.permute.default(permute_87, [1, 0]); permute_87 = None + mm_564 = torch.ops.aten.mm.default(view_1671, permute_1127); view_1671 = permute_1127 = None + view_1672 = torch.ops.aten.view.default(mm_564, [2, 8192, 14336]); mm_564 = None + convert_element_type_2374 = torch.ops.prims.convert_element_type.default(mm_563, torch.float32); mm_563 = None + reduce_scatter_tensor_218 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2374, 'avg', 64, '0'); convert_element_type_2374 = None + wait_tensor_509 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_218); reduce_scatter_tensor_218 = None + mul_744 = torch.ops.aten.mul.Tensor(view_1672, convert_element_type_258); convert_element_type_258 = 
None + mul_745 = torch.ops.aten.mul.Tensor(view_1672, view_269); view_1672 = view_269 = None + view_1673 = torch.ops.aten.view.default(mul_744, [16384, 14336]); mul_744 = None + permute_1129 = torch.ops.aten.permute.default(view_1673, [1, 0]) + mm_565 = torch.ops.aten.mm.default(permute_1129, view_265); permute_1129 = None + permute_1131 = torch.ops.aten.permute.default(permute_86, [1, 0]); permute_86 = None + mm_566 = torch.ops.aten.mm.default(view_1673, permute_1131); view_1673 = permute_1131 = None + view_1674 = torch.ops.aten.view.default(mm_566, [2, 8192, 4096]); mm_566 = None + convert_element_type_2379 = torch.ops.prims.convert_element_type.default(mm_565, torch.float32); mm_565 = None + reduce_scatter_tensor_219 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2379, 'avg', 64, '0'); convert_element_type_2379 = None + wait_tensor_510 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_219); reduce_scatter_tensor_219 = None + convert_element_type_2380 = torch.ops.prims.convert_element_type.default(mul_745, torch.float32); mul_745 = None + neg_24 = torch.ops.aten.neg.default(convert_element_type_257) + exp_24 = torch.ops.aten.exp.default(neg_24); neg_24 = None + add_297 = torch.ops.aten.add.Tensor(exp_24, 1); exp_24 = None + reciprocal_24 = torch.ops.aten.reciprocal.default(add_297); add_297 = None + mul_746 = torch.ops.aten.mul.Tensor(reciprocal_24, 1); reciprocal_24 = None + mul_747 = torch.ops.aten.mul.Tensor(convert_element_type_2380, mul_746); convert_element_type_2380 = None + sub_73 = torch.ops.aten.sub.Tensor(1, mul_746); mul_746 = None + mul_748 = torch.ops.aten.mul.Tensor(convert_element_type_257, sub_73); convert_element_type_257 = sub_73 = None + add_298 = torch.ops.aten.add.Tensor(mul_748, 1); mul_748 = None + mul_749 = torch.ops.aten.mul.Tensor(mul_747, add_298); mul_747 = add_298 = None + convert_element_type_2382 = torch.ops.prims.convert_element_type.default(mul_749, torch.bfloat16); mul_749 = 
None + view_1675 = torch.ops.aten.view.default(convert_element_type_2382, [16384, 14336]); convert_element_type_2382 = None + permute_1133 = torch.ops.aten.permute.default(view_1675, [1, 0]) + mm_567 = torch.ops.aten.mm.default(permute_1133, view_265); permute_1133 = view_265 = None + convert_element_type_254 = torch.ops.prims.convert_element_type.default(primals_73, torch.bfloat16); primals_73 = None + all_gather_into_tensor_70 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_254, 64, '0'); convert_element_type_254 = None + wait_tensor_70 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_70); all_gather_into_tensor_70 = None + permute_85 = torch.ops.aten.permute.default(wait_tensor_70, [1, 0]); wait_tensor_70 = None + permute_1135 = torch.ops.aten.permute.default(permute_85, [1, 0]); permute_85 = None + mm_568 = torch.ops.aten.mm.default(view_1675, permute_1135); view_1675 = permute_1135 = None + view_1676 = torch.ops.aten.view.default(mm_568, [2, 8192, 4096]); mm_568 = None + add_299 = torch.ops.aten.add.Tensor(view_1674, view_1676); view_1674 = view_1676 = None + convert_element_type_2387 = torch.ops.prims.convert_element_type.default(mm_567, torch.float32); mm_567 = None + reduce_scatter_tensor_220 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2387, 'avg', 64, '0'); convert_element_type_2387 = None + wait_tensor_511 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_220); reduce_scatter_tensor_220 = None + convert_element_type_2388 = torch.ops.prims.convert_element_type.default(add_299, torch.float32); add_299 = None + convert_element_type_2390 = torch.ops.prims.convert_element_type.default(wait_tensor_69, torch.float32); wait_tensor_69 = None + mul_750 = torch.ops.aten.mul.Tensor(convert_element_type_2388, convert_element_type_2390); convert_element_type_2390 = None + mul_752 = torch.ops.aten.mul.Tensor(mul_60, mul_750) + sum_147 = 
torch.ops.aten.sum.dim_IntList(mul_752, [2], True); mul_752 = None + div_49 = torch.ops.aten.div.Tensor(mul_60, 4096) + mul_753 = torch.ops.aten.mul.Tensor(div_49, sum_147); div_49 = sum_147 = None + sub_74 = torch.ops.aten.sub.Tensor(mul_750, mul_753); mul_750 = mul_753 = None + mul_754 = torch.ops.aten.mul.Tensor(sub_74, rsqrt_15); sub_74 = rsqrt_15 = None + mul_755 = torch.ops.aten.mul.Tensor(convert_element_type_2388, mul_60); convert_element_type_2388 = mul_60 = None + sum_148 = torch.ops.aten.sum.dim_IntList(mul_755, [0, 1]); mul_755 = None + convert_element_type_2391 = torch.ops.prims.convert_element_type.default(mul_754, torch.bfloat16); mul_754 = None + add_300 = torch.ops.aten.add.Tensor(add_296, convert_element_type_2391); add_296 = convert_element_type_2391 = None + convert_element_type_default_16 = torch.ops.prims.convert_element_type.default(sum_148, torch.float32); sum_148 = None + reduce_scatter_tensor_221 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_16, 'avg', 64, '0'); convert_element_type_default_16 = None + wait_tensor_512 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_221); reduce_scatter_tensor_221 = None + view_1677 = torch.ops.aten.view.default(add_300, [16384, 4096]) + permute_1137 = torch.ops.aten.permute.default(view_1677, [1, 0]) + mm_569 = torch.ops.aten.mm.default(permute_1137, view_261); permute_1137 = view_261 = None + permute_1139 = torch.ops.aten.permute.default(permute_84, [1, 0]); permute_84 = None + mm_570 = torch.ops.aten.mm.default(view_1677, permute_1139); view_1677 = permute_1139 = None + view_1678 = torch.ops.aten.view.default(mm_570, [2, 8192, 4096]); mm_570 = None + convert_element_type_2398 = torch.ops.prims.convert_element_type.default(mm_569, torch.float32); mm_569 = None + reduce_scatter_tensor_222 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2398, 'avg', 64, '0'); convert_element_type_2398 = None + wait_tensor_513 = 
torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_222); reduce_scatter_tensor_222 = None + view_1679 = torch.ops.aten.view.default(view_1678, [2, 8192, 32, 128]); view_1678 = None + permute_1141 = torch.ops.aten.permute.default(view_1679, [0, 2, 1, 3]); view_1679 = None + convert_element_type_232 = torch.ops.prims.convert_element_type.default(primals_67, torch.bfloat16); primals_67 = None + all_gather_into_tensor_64 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_232, 64, '0'); convert_element_type_232 = None + wait_tensor_64 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_64); all_gather_into_tensor_64 = None + convert_element_type_233 = torch.ops.prims.convert_element_type.default(add_27, torch.float32); add_27 = None + pow_15 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_233, 2) + mean_14 = torch.ops.aten.mean.dim(pow_15, [2], True); pow_15 = None + add_28 = torch.ops.aten.add.Scalar(mean_14, 1e-05); mean_14 = None + rsqrt_14 = torch.ops.aten.rsqrt.default(add_28); add_28 = None + mul_56 = torch.ops.aten.mul.Tensor(convert_element_type_233, rsqrt_14); convert_element_type_233 = None + mul_57 = torch.ops.aten.mul.Tensor(mul_56, wait_tensor_64) + convert_element_type_234 = torch.ops.prims.convert_element_type.default(mul_57, torch.bfloat16); mul_57 = None + view_241 = torch.ops.aten.view.default(convert_element_type_234, [16384, 4096]); convert_element_type_234 = None + view_242 = torch.ops.aten.view.default(mm_49, [2, 8192, 4096]); mm_49 = None + convert_element_type_238 = torch.ops.prims.convert_element_type.default(primals_69, torch.bfloat16); primals_69 = None + all_gather_into_tensor_66 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_238, 64, '0'); convert_element_type_238 = None + wait_tensor_66 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_66); all_gather_into_tensor_66 = None + permute_78 = 
torch.ops.aten.permute.default(wait_tensor_66, [1, 0]); wait_tensor_66 = None + mm_50 = torch.ops.aten.mm.default(view_241, permute_78) + view_245 = torch.ops.aten.view.default(mm_50, [2, 8192, 1024]); mm_50 = None + view_248 = torch.ops.aten.view.default(mm_51, [2, 8192, 1024]); mm_51 = None + view_249 = torch.ops.aten.view.default(view_242, [2, 8192, -1, 128]); view_242 = None + view_250 = torch.ops.aten.view.default(view_245, [2, 8192, -1, 128]); view_245 = None + view_251 = torch.ops.aten.view.default(view_248, [2, 8192, -1, 128]); view_248 = None + convert_element_type_244 = torch.ops.prims.convert_element_type.default(view_249, torch.float32); view_249 = None + view_252 = torch.ops.aten.view.default(convert_element_type_244, [2, 8192, 32, -1, 2]); convert_element_type_244 = None + view_as_complex_14 = torch.ops.aten.view_as_complex.default(view_252); view_252 = None + convert_element_type_245 = torch.ops.prims.convert_element_type.default(view_250, torch.float32); view_250 = None + view_253 = torch.ops.aten.view.default(convert_element_type_245, [2, 8192, 8, -1, 2]); convert_element_type_245 = None + view_as_complex_15 = torch.ops.aten.view_as_complex.default(view_253); view_253 = None + mul_58 = torch.ops.aten.mul.Tensor(view_as_complex_14, view_16); view_as_complex_14 = None + view_as_real_14 = torch.ops.aten.view_as_real.default(mul_58); mul_58 = None + view_255 = torch.ops.aten.view.default(view_as_real_14, [2, 8192, 32, 128]); view_as_real_14 = None + mul_59 = torch.ops.aten.mul.Tensor(view_as_complex_15, view_16); view_as_complex_15 = None + view_as_real_15 = torch.ops.aten.view_as_real.default(mul_59); mul_59 = None + view_256 = torch.ops.aten.view.default(view_as_real_15, [2, 8192, 8, 128]); view_as_real_15 = None + convert_element_type_246 = torch.ops.prims.convert_element_type.default(view_255, torch.bfloat16); view_255 = None + convert_element_type_247 = torch.ops.prims.convert_element_type.default(view_256, torch.bfloat16); view_256 = None + 
unsqueeze_14 = torch.ops.aten.unsqueeze.default(convert_element_type_247, 3); convert_element_type_247 = None + expand_14 = torch.ops.aten.expand.default(unsqueeze_14, [2, 8192, 8, 4, 128]); unsqueeze_14 = None + clone_14 = torch.ops.aten.clone.default(expand_14, memory_format = torch.contiguous_format); expand_14 = None + view_257 = torch.ops.aten.view.default(clone_14, [2, 8192, 32, 128]); clone_14 = None + unsqueeze_15 = torch.ops.aten.unsqueeze.default(view_251, 3); view_251 = None + expand_15 = torch.ops.aten.expand.default(unsqueeze_15, [2, 8192, 8, 4, 128]); unsqueeze_15 = None + clone_15 = torch.ops.aten.clone.default(expand_15, memory_format = torch.contiguous_format); expand_15 = None + view_258 = torch.ops.aten.view.default(clone_15, [2, 8192, 32, 128]); clone_15 = None + permute_80 = torch.ops.aten.permute.default(convert_element_type_246, [0, 2, 1, 3]); convert_element_type_246 = None + permute_81 = torch.ops.aten.permute.default(view_257, [0, 2, 1, 3]); view_257 = None + permute_82 = torch.ops.aten.permute.default(view_258, [0, 2, 1, 3]); view_258 = None + _scaled_dot_product_cudnn_attention_backward_24 = torch.ops.aten._scaled_dot_product_cudnn_attention_backward.default(permute_1141, permute_80, permute_81, permute_82, getitem_63, getitem_64, getitem_69, getitem_70, None, None, None, 8192, 8192, 0.0, True); permute_1141 = permute_80 = permute_81 = permute_82 = getitem_63 = getitem_64 = getitem_69 = getitem_70 = None + getitem_360 = _scaled_dot_product_cudnn_attention_backward_24[0] + getitem_361 = _scaled_dot_product_cudnn_attention_backward_24[1] + getitem_362 = _scaled_dot_product_cudnn_attention_backward_24[2]; _scaled_dot_product_cudnn_attention_backward_24 = None + permute_1142 = torch.ops.aten.permute.default(getitem_362, [0, 2, 1, 3]); getitem_362 = None + permute_1143 = torch.ops.aten.permute.default(getitem_361, [0, 2, 1, 3]); getitem_361 = None + permute_1144 = torch.ops.aten.permute.default(getitem_360, [0, 2, 1, 3]); getitem_360 = None + 
view_1680 = torch.ops.aten.view.default(permute_1142, [2, 8192, 8, 4, 128]); permute_1142 = None + sum_149 = torch.ops.aten.sum.dim_IntList(view_1680, [3], True); view_1680 = None + squeeze_48 = torch.ops.aten.squeeze.dim(sum_149, 3); sum_149 = None + view_1681 = torch.ops.aten.view.default(permute_1143, [2, 8192, 8, 4, 128]); permute_1143 = None + sum_150 = torch.ops.aten.sum.dim_IntList(view_1681, [3], True); view_1681 = None + squeeze_49 = torch.ops.aten.squeeze.dim(sum_150, 3); sum_150 = None + convert_element_type_2399 = torch.ops.prims.convert_element_type.default(squeeze_49, torch.float32); squeeze_49 = None + convert_element_type_2400 = torch.ops.prims.convert_element_type.default(permute_1144, torch.float32); permute_1144 = None + view_1682 = torch.ops.aten.view.default(convert_element_type_2399, [2, 8192, 8, 64, 2]); convert_element_type_2399 = None + view_as_complex_112 = torch.ops.aten.view_as_complex.default(view_1682); view_1682 = None + mul_756 = torch.ops.aten.mul.Tensor(view_as_complex_112, _conj); view_as_complex_112 = None + view_1683 = torch.ops.aten.view.default(convert_element_type_2400, [2, 8192, 32, 64, 2]); convert_element_type_2400 = None + view_as_complex_113 = torch.ops.aten.view_as_complex.default(view_1683); view_1683 = None + mul_757 = torch.ops.aten.mul.Tensor(view_as_complex_113, _conj); view_as_complex_113 = None + view_as_real_112 = torch.ops.aten.view_as_real.default(mul_756); mul_756 = None + view_1684 = torch.ops.aten.view.default(view_as_real_112, [2, 8192, 8, 128]); view_as_real_112 = None + convert_element_type_2401 = torch.ops.prims.convert_element_type.default(view_1684, torch.bfloat16); view_1684 = None + view_as_real_113 = torch.ops.aten.view_as_real.default(mul_757); mul_757 = None + view_1685 = torch.ops.aten.view.default(view_as_real_113, [2, 8192, 32, 128]); view_as_real_113 = None + convert_element_type_2402 = torch.ops.prims.convert_element_type.default(view_1685, torch.bfloat16); view_1685 = None + view_1686 = 
torch.ops.aten.view.default(squeeze_48, [2, 8192, 1024]); squeeze_48 = None + view_1687 = torch.ops.aten.view.default(convert_element_type_2401, [2, 8192, 1024]); convert_element_type_2401 = None + view_1688 = torch.ops.aten.view.default(convert_element_type_2402, [2, 8192, 4096]); convert_element_type_2402 = None + view_1689 = torch.ops.aten.view.default(view_1686, [16384, 1024]); view_1686 = None + permute_1145 = torch.ops.aten.permute.default(view_1689, [1, 0]) + mm_571 = torch.ops.aten.mm.default(permute_1145, view_241); permute_1145 = None + convert_element_type_241 = torch.ops.prims.convert_element_type.default(primals_70, torch.bfloat16); primals_70 = None + all_gather_into_tensor_67 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_241, 64, '0'); convert_element_type_241 = None + wait_tensor_67 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_67); all_gather_into_tensor_67 = None + permute_79 = torch.ops.aten.permute.default(wait_tensor_67, [1, 0]); wait_tensor_67 = None + permute_1147 = torch.ops.aten.permute.default(permute_79, [1, 0]); permute_79 = None + mm_572 = torch.ops.aten.mm.default(view_1689, permute_1147); view_1689 = permute_1147 = None + view_1690 = torch.ops.aten.view.default(mm_572, [2, 8192, 4096]); mm_572 = None + convert_element_type_2407 = torch.ops.prims.convert_element_type.default(mm_571, torch.float32); mm_571 = None + reduce_scatter_tensor_223 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2407, 'avg', 64, '0'); convert_element_type_2407 = None + wait_tensor_514 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_223); reduce_scatter_tensor_223 = None + view_1691 = torch.ops.aten.view.default(view_1687, [16384, 1024]); view_1687 = None + permute_1149 = torch.ops.aten.permute.default(view_1691, [1, 0]) + mm_573 = torch.ops.aten.mm.default(permute_1149, view_241); permute_1149 = None + permute_1151 = 
torch.ops.aten.permute.default(permute_78, [1, 0]); permute_78 = None + mm_574 = torch.ops.aten.mm.default(view_1691, permute_1151); view_1691 = permute_1151 = None + view_1692 = torch.ops.aten.view.default(mm_574, [2, 8192, 4096]); mm_574 = None + add_301 = torch.ops.aten.add.Tensor(view_1690, view_1692); view_1690 = view_1692 = None + convert_element_type_2412 = torch.ops.prims.convert_element_type.default(mm_573, torch.float32); mm_573 = None + reduce_scatter_tensor_224 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2412, 'avg', 64, '0'); convert_element_type_2412 = None + wait_tensor_515 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_224); reduce_scatter_tensor_224 = None + view_1693 = torch.ops.aten.view.default(view_1688, [16384, 4096]); view_1688 = None + permute_1153 = torch.ops.aten.permute.default(view_1693, [1, 0]) + mm_575 = torch.ops.aten.mm.default(permute_1153, view_241); permute_1153 = view_241 = None + convert_element_type_235 = torch.ops.prims.convert_element_type.default(primals_68, torch.bfloat16); primals_68 = None + all_gather_into_tensor_65 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_235, 64, '0'); convert_element_type_235 = None + wait_tensor_65 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_65); all_gather_into_tensor_65 = None + permute_77 = torch.ops.aten.permute.default(wait_tensor_65, [1, 0]); wait_tensor_65 = None + permute_1155 = torch.ops.aten.permute.default(permute_77, [1, 0]); permute_77 = None + mm_576 = torch.ops.aten.mm.default(view_1693, permute_1155); view_1693 = permute_1155 = None + view_1694 = torch.ops.aten.view.default(mm_576, [2, 8192, 4096]); mm_576 = None + add_302 = torch.ops.aten.add.Tensor(add_301, view_1694); add_301 = view_1694 = None + convert_element_type_2417 = torch.ops.prims.convert_element_type.default(mm_575, torch.float32); mm_575 = None + reduce_scatter_tensor_225 = 
torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2417, 'avg', 64, '0'); convert_element_type_2417 = None + wait_tensor_516 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_225); reduce_scatter_tensor_225 = None + convert_element_type_2418 = torch.ops.prims.convert_element_type.default(add_302, torch.float32); add_302 = None + convert_element_type_2420 = torch.ops.prims.convert_element_type.default(wait_tensor_64, torch.float32); wait_tensor_64 = None + mul_758 = torch.ops.aten.mul.Tensor(convert_element_type_2418, convert_element_type_2420); convert_element_type_2420 = None + mul_760 = torch.ops.aten.mul.Tensor(mul_56, mul_758) + sum_151 = torch.ops.aten.sum.dim_IntList(mul_760, [2], True); mul_760 = None + div_50 = torch.ops.aten.div.Tensor(mul_56, 4096) + mul_761 = torch.ops.aten.mul.Tensor(div_50, sum_151); div_50 = sum_151 = None + sub_75 = torch.ops.aten.sub.Tensor(mul_758, mul_761); mul_758 = mul_761 = None + mul_762 = torch.ops.aten.mul.Tensor(sub_75, rsqrt_14); sub_75 = rsqrt_14 = None + mul_763 = torch.ops.aten.mul.Tensor(convert_element_type_2418, mul_56); convert_element_type_2418 = mul_56 = None + sum_152 = torch.ops.aten.sum.dim_IntList(mul_763, [0, 1]); mul_763 = None + convert_element_type_2421 = torch.ops.prims.convert_element_type.default(mul_762, torch.bfloat16); mul_762 = None + add_303 = torch.ops.aten.add.Tensor(add_300, convert_element_type_2421); add_300 = convert_element_type_2421 = None + convert_element_type_default_15 = torch.ops.prims.convert_element_type.default(sum_152, torch.float32); sum_152 = None + reduce_scatter_tensor_226 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_15, 'avg', 64, '0'); convert_element_type_default_15 = None + wait_tensor_517 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_226); reduce_scatter_tensor_226 = None + view_1695 = torch.ops.aten.view.default(add_303, [16384, 4096]) + permute_1157 = 
torch.ops.aten.permute.default(view_1695, [1, 0]) + permute_72 = torch.ops.aten.permute.default(getitem_54, [0, 2, 1, 3]) + view_225 = torch.ops.aten.view.default(permute_72, [2, 8192, -1]); permute_72 = None + convert_element_type_215 = torch.ops.prims.convert_element_type.default(primals_62, torch.bfloat16); primals_62 = None + all_gather_into_tensor_59 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_215, 64, '0'); convert_element_type_215 = None + wait_tensor_59 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_59); all_gather_into_tensor_59 = None + permute_73 = torch.ops.aten.permute.default(wait_tensor_59, [1, 0]); wait_tensor_59 = None + view_227 = torch.ops.aten.view.default(view_225, [16384, 4096]); view_225 = None + mm_45 = torch.ops.aten.mm.default(view_227, permute_73) + view_228 = torch.ops.aten.view.default(mm_45, [2, 8192, 4096]); mm_45 = None + add_25 = torch.ops.aten.add.Tensor(add_23, view_228); view_228 = None + convert_element_type_218 = torch.ops.prims.convert_element_type.default(primals_63, torch.bfloat16); primals_63 = None + all_gather_into_tensor_60 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_218, 64, '0'); convert_element_type_218 = None + wait_tensor_60 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_60); all_gather_into_tensor_60 = None + convert_element_type_219 = torch.ops.prims.convert_element_type.default(add_25, torch.float32); add_25 = None + pow_14 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_219, 2) + mean_13 = torch.ops.aten.mean.dim(pow_14, [2], True); pow_14 = None + add_26 = torch.ops.aten.add.Scalar(mean_13, 1e-05); mean_13 = None + rsqrt_13 = torch.ops.aten.rsqrt.default(add_26); add_26 = None + mul_52 = torch.ops.aten.mul.Tensor(convert_element_type_219, rsqrt_13); convert_element_type_219 = None + mul_53 = torch.ops.aten.mul.Tensor(mul_52, wait_tensor_60) + convert_element_type_220 = 
torch.ops.prims.convert_element_type.default(mul_53, torch.bfloat16); mul_53 = None + view_231 = torch.ops.aten.view.default(convert_element_type_220, [16384, 4096]); convert_element_type_220 = None + view_232 = torch.ops.aten.view.default(mm_46, [2, 8192, 14336]); mm_46 = None + convert_element_type_224 = torch.ops.prims.convert_element_type.default(view_232, torch.float32); view_232 = None + sigmoid_6 = torch.ops.aten.sigmoid.default(convert_element_type_224) + mul_54 = torch.ops.aten.mul.Tensor(convert_element_type_224, sigmoid_6); sigmoid_6 = None + convert_element_type_225 = torch.ops.prims.convert_element_type.default(mul_54, torch.bfloat16); mul_54 = None + convert_element_type_226 = torch.ops.prims.convert_element_type.default(primals_65, torch.bfloat16); primals_65 = None + all_gather_into_tensor_62 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_226, 64, '0'); convert_element_type_226 = None + wait_tensor_62 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_62); all_gather_into_tensor_62 = None + permute_75 = torch.ops.aten.permute.default(wait_tensor_62, [1, 0]); wait_tensor_62 = None + mm_47 = torch.ops.aten.mm.default(view_231, permute_75) + view_235 = torch.ops.aten.view.default(mm_47, [2, 8192, 14336]); mm_47 = None + mul_55 = torch.ops.aten.mul.Tensor(convert_element_type_225, view_235) + view_237 = torch.ops.aten.view.default(mul_55, [16384, 14336]); mul_55 = None + mm_577 = torch.ops.aten.mm.default(permute_1157, view_237); permute_1157 = view_237 = None + convert_element_type_229 = torch.ops.prims.convert_element_type.default(primals_66, torch.bfloat16); primals_66 = None + all_gather_into_tensor_63 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_229, 64, '0'); convert_element_type_229 = None + wait_tensor_63 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_63); all_gather_into_tensor_63 = None + permute_76 = 
torch.ops.aten.permute.default(wait_tensor_63, [1, 0]); wait_tensor_63 = None + permute_1159 = torch.ops.aten.permute.default(permute_76, [1, 0]); permute_76 = None + mm_578 = torch.ops.aten.mm.default(view_1695, permute_1159); view_1695 = permute_1159 = None + view_1696 = torch.ops.aten.view.default(mm_578, [2, 8192, 14336]); mm_578 = None + convert_element_type_2428 = torch.ops.prims.convert_element_type.default(mm_577, torch.float32); mm_577 = None + reduce_scatter_tensor_227 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2428, 'avg', 64, '0'); convert_element_type_2428 = None + wait_tensor_518 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_227); reduce_scatter_tensor_227 = None + mul_764 = torch.ops.aten.mul.Tensor(view_1696, convert_element_type_225); convert_element_type_225 = None + mul_765 = torch.ops.aten.mul.Tensor(view_1696, view_235); view_1696 = view_235 = None + view_1697 = torch.ops.aten.view.default(mul_764, [16384, 14336]); mul_764 = None + permute_1161 = torch.ops.aten.permute.default(view_1697, [1, 0]) + mm_579 = torch.ops.aten.mm.default(permute_1161, view_231); permute_1161 = None + permute_1163 = torch.ops.aten.permute.default(permute_75, [1, 0]); permute_75 = None + mm_580 = torch.ops.aten.mm.default(view_1697, permute_1163); view_1697 = permute_1163 = None + view_1698 = torch.ops.aten.view.default(mm_580, [2, 8192, 4096]); mm_580 = None + convert_element_type_2433 = torch.ops.prims.convert_element_type.default(mm_579, torch.float32); mm_579 = None + reduce_scatter_tensor_228 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2433, 'avg', 64, '0'); convert_element_type_2433 = None + wait_tensor_519 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_228); reduce_scatter_tensor_228 = None + convert_element_type_2434 = torch.ops.prims.convert_element_type.default(mul_765, torch.float32); mul_765 = None + neg_25 = 
torch.ops.aten.neg.default(convert_element_type_224) + exp_25 = torch.ops.aten.exp.default(neg_25); neg_25 = None + add_304 = torch.ops.aten.add.Tensor(exp_25, 1); exp_25 = None + reciprocal_25 = torch.ops.aten.reciprocal.default(add_304); add_304 = None + mul_766 = torch.ops.aten.mul.Tensor(reciprocal_25, 1); reciprocal_25 = None + mul_767 = torch.ops.aten.mul.Tensor(convert_element_type_2434, mul_766); convert_element_type_2434 = None + sub_76 = torch.ops.aten.sub.Tensor(1, mul_766); mul_766 = None + mul_768 = torch.ops.aten.mul.Tensor(convert_element_type_224, sub_76); convert_element_type_224 = sub_76 = None + add_305 = torch.ops.aten.add.Tensor(mul_768, 1); mul_768 = None + mul_769 = torch.ops.aten.mul.Tensor(mul_767, add_305); mul_767 = add_305 = None + convert_element_type_2436 = torch.ops.prims.convert_element_type.default(mul_769, torch.bfloat16); mul_769 = None + view_1699 = torch.ops.aten.view.default(convert_element_type_2436, [16384, 14336]); convert_element_type_2436 = None + permute_1165 = torch.ops.aten.permute.default(view_1699, [1, 0]) + mm_581 = torch.ops.aten.mm.default(permute_1165, view_231); permute_1165 = view_231 = None + convert_element_type_221 = torch.ops.prims.convert_element_type.default(primals_64, torch.bfloat16); primals_64 = None + all_gather_into_tensor_61 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_221, 64, '0'); convert_element_type_221 = None + wait_tensor_61 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_61); all_gather_into_tensor_61 = None + permute_74 = torch.ops.aten.permute.default(wait_tensor_61, [1, 0]); wait_tensor_61 = None + permute_1167 = torch.ops.aten.permute.default(permute_74, [1, 0]); permute_74 = None + mm_582 = torch.ops.aten.mm.default(view_1699, permute_1167); view_1699 = permute_1167 = None + view_1700 = torch.ops.aten.view.default(mm_582, [2, 8192, 4096]); mm_582 = None + add_306 = torch.ops.aten.add.Tensor(view_1698, view_1700); view_1698 = 
view_1700 = None + convert_element_type_2441 = torch.ops.prims.convert_element_type.default(mm_581, torch.float32); mm_581 = None + reduce_scatter_tensor_229 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2441, 'avg', 64, '0'); convert_element_type_2441 = None + wait_tensor_520 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_229); reduce_scatter_tensor_229 = None + convert_element_type_2442 = torch.ops.prims.convert_element_type.default(add_306, torch.float32); add_306 = None + convert_element_type_2444 = torch.ops.prims.convert_element_type.default(wait_tensor_60, torch.float32); wait_tensor_60 = None + mul_770 = torch.ops.aten.mul.Tensor(convert_element_type_2442, convert_element_type_2444); convert_element_type_2444 = None + mul_772 = torch.ops.aten.mul.Tensor(mul_52, mul_770) + sum_153 = torch.ops.aten.sum.dim_IntList(mul_772, [2], True); mul_772 = None + div_51 = torch.ops.aten.div.Tensor(mul_52, 4096) + mul_773 = torch.ops.aten.mul.Tensor(div_51, sum_153); div_51 = sum_153 = None + sub_77 = torch.ops.aten.sub.Tensor(mul_770, mul_773); mul_770 = mul_773 = None + mul_774 = torch.ops.aten.mul.Tensor(sub_77, rsqrt_13); sub_77 = rsqrt_13 = None + mul_775 = torch.ops.aten.mul.Tensor(convert_element_type_2442, mul_52); convert_element_type_2442 = mul_52 = None + sum_154 = torch.ops.aten.sum.dim_IntList(mul_775, [0, 1]); mul_775 = None + convert_element_type_2445 = torch.ops.prims.convert_element_type.default(mul_774, torch.bfloat16); mul_774 = None + add_307 = torch.ops.aten.add.Tensor(add_303, convert_element_type_2445); add_303 = convert_element_type_2445 = None + convert_element_type_default_14 = torch.ops.prims.convert_element_type.default(sum_154, torch.float32); sum_154 = None + reduce_scatter_tensor_230 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_14, 'avg', 64, '0'); convert_element_type_default_14 = None + wait_tensor_521 = 
torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_230); reduce_scatter_tensor_230 = None + view_1701 = torch.ops.aten.view.default(add_307, [16384, 4096]) + permute_1169 = torch.ops.aten.permute.default(view_1701, [1, 0]) + mm_583 = torch.ops.aten.mm.default(permute_1169, view_227); permute_1169 = view_227 = None + permute_1171 = torch.ops.aten.permute.default(permute_73, [1, 0]); permute_73 = None + mm_584 = torch.ops.aten.mm.default(view_1701, permute_1171); view_1701 = permute_1171 = None + view_1702 = torch.ops.aten.view.default(mm_584, [2, 8192, 4096]); mm_584 = None + convert_element_type_2452 = torch.ops.prims.convert_element_type.default(mm_583, torch.float32); mm_583 = None + reduce_scatter_tensor_231 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2452, 'avg', 64, '0'); convert_element_type_2452 = None + wait_tensor_522 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_231); reduce_scatter_tensor_231 = None + view_1703 = torch.ops.aten.view.default(view_1702, [2, 8192, 32, 128]); view_1702 = None + permute_1173 = torch.ops.aten.permute.default(view_1703, [0, 2, 1, 3]); view_1703 = None + convert_element_type_199 = torch.ops.prims.convert_element_type.default(primals_58, torch.bfloat16); primals_58 = None + all_gather_into_tensor_55 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_199, 64, '0'); convert_element_type_199 = None + wait_tensor_55 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_55); all_gather_into_tensor_55 = None + convert_element_type_200 = torch.ops.prims.convert_element_type.default(add_23, torch.float32); add_23 = None + pow_13 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_200, 2) + mean_12 = torch.ops.aten.mean.dim(pow_13, [2], True); pow_13 = None + add_24 = torch.ops.aten.add.Scalar(mean_12, 1e-05); mean_12 = None + rsqrt_12 = torch.ops.aten.rsqrt.default(add_24); add_24 = None + mul_48 = 
torch.ops.aten.mul.Tensor(convert_element_type_200, rsqrt_12); convert_element_type_200 = None + mul_49 = torch.ops.aten.mul.Tensor(mul_48, wait_tensor_55) + convert_element_type_201 = torch.ops.prims.convert_element_type.default(mul_49, torch.bfloat16); mul_49 = None + view_207 = torch.ops.aten.view.default(convert_element_type_201, [16384, 4096]); convert_element_type_201 = None + view_208 = torch.ops.aten.view.default(mm_42, [2, 8192, 4096]); mm_42 = None + convert_element_type_205 = torch.ops.prims.convert_element_type.default(primals_60, torch.bfloat16); primals_60 = None + all_gather_into_tensor_57 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_205, 64, '0'); convert_element_type_205 = None + wait_tensor_57 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_57); all_gather_into_tensor_57 = None + permute_67 = torch.ops.aten.permute.default(wait_tensor_57, [1, 0]); wait_tensor_57 = None + mm_43 = torch.ops.aten.mm.default(view_207, permute_67) + view_211 = torch.ops.aten.view.default(mm_43, [2, 8192, 1024]); mm_43 = None + view_214 = torch.ops.aten.view.default(mm_44, [2, 8192, 1024]); mm_44 = None + view_215 = torch.ops.aten.view.default(view_208, [2, 8192, -1, 128]); view_208 = None + view_216 = torch.ops.aten.view.default(view_211, [2, 8192, -1, 128]); view_211 = None + view_217 = torch.ops.aten.view.default(view_214, [2, 8192, -1, 128]); view_214 = None + convert_element_type_211 = torch.ops.prims.convert_element_type.default(view_215, torch.float32); view_215 = None + view_218 = torch.ops.aten.view.default(convert_element_type_211, [2, 8192, 32, -1, 2]); convert_element_type_211 = None + view_as_complex_12 = torch.ops.aten.view_as_complex.default(view_218); view_218 = None + convert_element_type_212 = torch.ops.prims.convert_element_type.default(view_216, torch.float32); view_216 = None + view_219 = torch.ops.aten.view.default(convert_element_type_212, [2, 8192, 8, -1, 2]); convert_element_type_212 = 
None + view_as_complex_13 = torch.ops.aten.view_as_complex.default(view_219); view_219 = None + mul_50 = torch.ops.aten.mul.Tensor(view_as_complex_12, view_16); view_as_complex_12 = None + view_as_real_12 = torch.ops.aten.view_as_real.default(mul_50); mul_50 = None + view_221 = torch.ops.aten.view.default(view_as_real_12, [2, 8192, 32, 128]); view_as_real_12 = None + mul_51 = torch.ops.aten.mul.Tensor(view_as_complex_13, view_16); view_as_complex_13 = None + view_as_real_13 = torch.ops.aten.view_as_real.default(mul_51); mul_51 = None + view_222 = torch.ops.aten.view.default(view_as_real_13, [2, 8192, 8, 128]); view_as_real_13 = None + convert_element_type_213 = torch.ops.prims.convert_element_type.default(view_221, torch.bfloat16); view_221 = None + convert_element_type_214 = torch.ops.prims.convert_element_type.default(view_222, torch.bfloat16); view_222 = None + unsqueeze_12 = torch.ops.aten.unsqueeze.default(convert_element_type_214, 3); convert_element_type_214 = None + expand_12 = torch.ops.aten.expand.default(unsqueeze_12, [2, 8192, 8, 4, 128]); unsqueeze_12 = None + clone_12 = torch.ops.aten.clone.default(expand_12, memory_format = torch.contiguous_format); expand_12 = None + view_223 = torch.ops.aten.view.default(clone_12, [2, 8192, 32, 128]); clone_12 = None + unsqueeze_13 = torch.ops.aten.unsqueeze.default(view_217, 3); view_217 = None + expand_13 = torch.ops.aten.expand.default(unsqueeze_13, [2, 8192, 8, 4, 128]); unsqueeze_13 = None + clone_13 = torch.ops.aten.clone.default(expand_13, memory_format = torch.contiguous_format); expand_13 = None + view_224 = torch.ops.aten.view.default(clone_13, [2, 8192, 32, 128]); clone_13 = None + permute_69 = torch.ops.aten.permute.default(convert_element_type_213, [0, 2, 1, 3]); convert_element_type_213 = None + permute_70 = torch.ops.aten.permute.default(view_223, [0, 2, 1, 3]); view_223 = None + permute_71 = torch.ops.aten.permute.default(view_224, [0, 2, 1, 3]); view_224 = None + 
_scaled_dot_product_cudnn_attention_backward_25 = torch.ops.aten._scaled_dot_product_cudnn_attention_backward.default(permute_1173, permute_69, permute_70, permute_71, getitem_54, getitem_55, getitem_60, getitem_61, None, None, None, 8192, 8192, 0.0, True); permute_1173 = permute_69 = permute_70 = permute_71 = getitem_54 = getitem_55 = getitem_60 = getitem_61 = None + getitem_363 = _scaled_dot_product_cudnn_attention_backward_25[0] + getitem_364 = _scaled_dot_product_cudnn_attention_backward_25[1] + getitem_365 = _scaled_dot_product_cudnn_attention_backward_25[2]; _scaled_dot_product_cudnn_attention_backward_25 = None + permute_1174 = torch.ops.aten.permute.default(getitem_365, [0, 2, 1, 3]); getitem_365 = None + permute_1175 = torch.ops.aten.permute.default(getitem_364, [0, 2, 1, 3]); getitem_364 = None + permute_1176 = torch.ops.aten.permute.default(getitem_363, [0, 2, 1, 3]); getitem_363 = None + view_1704 = torch.ops.aten.view.default(permute_1174, [2, 8192, 8, 4, 128]); permute_1174 = None + sum_155 = torch.ops.aten.sum.dim_IntList(view_1704, [3], True); view_1704 = None + squeeze_50 = torch.ops.aten.squeeze.dim(sum_155, 3); sum_155 = None + view_1705 = torch.ops.aten.view.default(permute_1175, [2, 8192, 8, 4, 128]); permute_1175 = None + sum_156 = torch.ops.aten.sum.dim_IntList(view_1705, [3], True); view_1705 = None + squeeze_51 = torch.ops.aten.squeeze.dim(sum_156, 3); sum_156 = None + convert_element_type_2453 = torch.ops.prims.convert_element_type.default(squeeze_51, torch.float32); squeeze_51 = None + convert_element_type_2454 = torch.ops.prims.convert_element_type.default(permute_1176, torch.float32); permute_1176 = None + view_1706 = torch.ops.aten.view.default(convert_element_type_2453, [2, 8192, 8, 64, 2]); convert_element_type_2453 = None + view_as_complex_114 = torch.ops.aten.view_as_complex.default(view_1706); view_1706 = None + mul_776 = torch.ops.aten.mul.Tensor(view_as_complex_114, _conj); view_as_complex_114 = None + view_1707 = 
torch.ops.aten.view.default(convert_element_type_2454, [2, 8192, 32, 64, 2]); convert_element_type_2454 = None + view_as_complex_115 = torch.ops.aten.view_as_complex.default(view_1707); view_1707 = None + mul_777 = torch.ops.aten.mul.Tensor(view_as_complex_115, _conj); view_as_complex_115 = None + view_as_real_114 = torch.ops.aten.view_as_real.default(mul_776); mul_776 = None + view_1708 = torch.ops.aten.view.default(view_as_real_114, [2, 8192, 8, 128]); view_as_real_114 = None + convert_element_type_2455 = torch.ops.prims.convert_element_type.default(view_1708, torch.bfloat16); view_1708 = None + view_as_real_115 = torch.ops.aten.view_as_real.default(mul_777); mul_777 = None + view_1709 = torch.ops.aten.view.default(view_as_real_115, [2, 8192, 32, 128]); view_as_real_115 = None + convert_element_type_2456 = torch.ops.prims.convert_element_type.default(view_1709, torch.bfloat16); view_1709 = None + view_1710 = torch.ops.aten.view.default(squeeze_50, [2, 8192, 1024]); squeeze_50 = None + view_1711 = torch.ops.aten.view.default(convert_element_type_2455, [2, 8192, 1024]); convert_element_type_2455 = None + view_1712 = torch.ops.aten.view.default(convert_element_type_2456, [2, 8192, 4096]); convert_element_type_2456 = None + view_1713 = torch.ops.aten.view.default(view_1710, [16384, 1024]); view_1710 = None + permute_1177 = torch.ops.aten.permute.default(view_1713, [1, 0]) + mm_585 = torch.ops.aten.mm.default(permute_1177, view_207); permute_1177 = None + convert_element_type_208 = torch.ops.prims.convert_element_type.default(primals_61, torch.bfloat16); primals_61 = None + all_gather_into_tensor_58 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_208, 64, '0'); convert_element_type_208 = None + wait_tensor_58 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_58); all_gather_into_tensor_58 = None + permute_68 = torch.ops.aten.permute.default(wait_tensor_58, [1, 0]); wait_tensor_58 = None + permute_1179 = 
torch.ops.aten.permute.default(permute_68, [1, 0]); permute_68 = None + mm_586 = torch.ops.aten.mm.default(view_1713, permute_1179); view_1713 = permute_1179 = None + view_1714 = torch.ops.aten.view.default(mm_586, [2, 8192, 4096]); mm_586 = None + convert_element_type_2461 = torch.ops.prims.convert_element_type.default(mm_585, torch.float32); mm_585 = None + reduce_scatter_tensor_232 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2461, 'avg', 64, '0'); convert_element_type_2461 = None + wait_tensor_523 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_232); reduce_scatter_tensor_232 = None + view_1715 = torch.ops.aten.view.default(view_1711, [16384, 1024]); view_1711 = None + permute_1181 = torch.ops.aten.permute.default(view_1715, [1, 0]) + mm_587 = torch.ops.aten.mm.default(permute_1181, view_207); permute_1181 = None + permute_1183 = torch.ops.aten.permute.default(permute_67, [1, 0]); permute_67 = None + mm_588 = torch.ops.aten.mm.default(view_1715, permute_1183); view_1715 = permute_1183 = None + view_1716 = torch.ops.aten.view.default(mm_588, [2, 8192, 4096]); mm_588 = None + add_308 = torch.ops.aten.add.Tensor(view_1714, view_1716); view_1714 = view_1716 = None + convert_element_type_2466 = torch.ops.prims.convert_element_type.default(mm_587, torch.float32); mm_587 = None + reduce_scatter_tensor_233 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2466, 'avg', 64, '0'); convert_element_type_2466 = None + wait_tensor_524 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_233); reduce_scatter_tensor_233 = None + view_1717 = torch.ops.aten.view.default(view_1712, [16384, 4096]); view_1712 = None + permute_1185 = torch.ops.aten.permute.default(view_1717, [1, 0]) + mm_589 = torch.ops.aten.mm.default(permute_1185, view_207); permute_1185 = view_207 = None + convert_element_type_202 = torch.ops.prims.convert_element_type.default(primals_59, torch.bfloat16); 
primals_59 = None + all_gather_into_tensor_56 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_202, 64, '0'); convert_element_type_202 = None + wait_tensor_56 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_56); all_gather_into_tensor_56 = None + permute_66 = torch.ops.aten.permute.default(wait_tensor_56, [1, 0]); wait_tensor_56 = None + permute_1187 = torch.ops.aten.permute.default(permute_66, [1, 0]); permute_66 = None + mm_590 = torch.ops.aten.mm.default(view_1717, permute_1187); view_1717 = permute_1187 = None + view_1718 = torch.ops.aten.view.default(mm_590, [2, 8192, 4096]); mm_590 = None + add_309 = torch.ops.aten.add.Tensor(add_308, view_1718); add_308 = view_1718 = None + convert_element_type_2471 = torch.ops.prims.convert_element_type.default(mm_589, torch.float32); mm_589 = None + reduce_scatter_tensor_234 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2471, 'avg', 64, '0'); convert_element_type_2471 = None + wait_tensor_525 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_234); reduce_scatter_tensor_234 = None + convert_element_type_2472 = torch.ops.prims.convert_element_type.default(add_309, torch.float32); add_309 = None + convert_element_type_2474 = torch.ops.prims.convert_element_type.default(wait_tensor_55, torch.float32); wait_tensor_55 = None + mul_778 = torch.ops.aten.mul.Tensor(convert_element_type_2472, convert_element_type_2474); convert_element_type_2474 = None + mul_780 = torch.ops.aten.mul.Tensor(mul_48, mul_778) + sum_157 = torch.ops.aten.sum.dim_IntList(mul_780, [2], True); mul_780 = None + div_52 = torch.ops.aten.div.Tensor(mul_48, 4096) + mul_781 = torch.ops.aten.mul.Tensor(div_52, sum_157); div_52 = sum_157 = None + sub_78 = torch.ops.aten.sub.Tensor(mul_778, mul_781); mul_778 = mul_781 = None + mul_782 = torch.ops.aten.mul.Tensor(sub_78, rsqrt_12); sub_78 = rsqrt_12 = None + mul_783 = 
torch.ops.aten.mul.Tensor(convert_element_type_2472, mul_48); convert_element_type_2472 = mul_48 = None + sum_158 = torch.ops.aten.sum.dim_IntList(mul_783, [0, 1]); mul_783 = None + convert_element_type_2475 = torch.ops.prims.convert_element_type.default(mul_782, torch.bfloat16); mul_782 = None + add_310 = torch.ops.aten.add.Tensor(add_307, convert_element_type_2475); add_307 = convert_element_type_2475 = None + convert_element_type_default_13 = torch.ops.prims.convert_element_type.default(sum_158, torch.float32); sum_158 = None + reduce_scatter_tensor_235 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_13, 'avg', 64, '0'); convert_element_type_default_13 = None + wait_tensor_526 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_235); reduce_scatter_tensor_235 = None + view_1719 = torch.ops.aten.view.default(add_310, [16384, 4096]) + permute_1189 = torch.ops.aten.permute.default(view_1719, [1, 0]) + permute_61 = torch.ops.aten.permute.default(getitem_45, [0, 2, 1, 3]) + view_191 = torch.ops.aten.view.default(permute_61, [2, 8192, -1]); permute_61 = None + convert_element_type_182 = torch.ops.prims.convert_element_type.default(primals_53, torch.bfloat16); primals_53 = None + all_gather_into_tensor_50 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_182, 64, '0'); convert_element_type_182 = None + wait_tensor_50 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_50); all_gather_into_tensor_50 = None + permute_62 = torch.ops.aten.permute.default(wait_tensor_50, [1, 0]); wait_tensor_50 = None + view_193 = torch.ops.aten.view.default(view_191, [16384, 4096]); view_191 = None + mm_38 = torch.ops.aten.mm.default(view_193, permute_62) + view_194 = torch.ops.aten.view.default(mm_38, [2, 8192, 4096]); mm_38 = None + add_21 = torch.ops.aten.add.Tensor(add_19, view_194); view_194 = None + convert_element_type_185 = 
torch.ops.prims.convert_element_type.default(primals_54, torch.bfloat16); primals_54 = None + all_gather_into_tensor_51 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_185, 64, '0'); convert_element_type_185 = None + wait_tensor_51 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_51); all_gather_into_tensor_51 = None + convert_element_type_186 = torch.ops.prims.convert_element_type.default(add_21, torch.float32); add_21 = None + pow_12 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_186, 2) + mean_11 = torch.ops.aten.mean.dim(pow_12, [2], True); pow_12 = None + add_22 = torch.ops.aten.add.Scalar(mean_11, 1e-05); mean_11 = None + rsqrt_11 = torch.ops.aten.rsqrt.default(add_22); add_22 = None + mul_44 = torch.ops.aten.mul.Tensor(convert_element_type_186, rsqrt_11); convert_element_type_186 = None + mul_45 = torch.ops.aten.mul.Tensor(mul_44, wait_tensor_51) + convert_element_type_187 = torch.ops.prims.convert_element_type.default(mul_45, torch.bfloat16); mul_45 = None + view_197 = torch.ops.aten.view.default(convert_element_type_187, [16384, 4096]); convert_element_type_187 = None + view_198 = torch.ops.aten.view.default(mm_39, [2, 8192, 14336]); mm_39 = None + convert_element_type_191 = torch.ops.prims.convert_element_type.default(view_198, torch.float32); view_198 = None + sigmoid_5 = torch.ops.aten.sigmoid.default(convert_element_type_191) + mul_46 = torch.ops.aten.mul.Tensor(convert_element_type_191, sigmoid_5); sigmoid_5 = None + convert_element_type_192 = torch.ops.prims.convert_element_type.default(mul_46, torch.bfloat16); mul_46 = None + convert_element_type_193 = torch.ops.prims.convert_element_type.default(primals_56, torch.bfloat16); primals_56 = None + all_gather_into_tensor_53 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_193, 64, '0'); convert_element_type_193 = None + wait_tensor_53 = 
torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_53); all_gather_into_tensor_53 = None + permute_64 = torch.ops.aten.permute.default(wait_tensor_53, [1, 0]); wait_tensor_53 = None + mm_40 = torch.ops.aten.mm.default(view_197, permute_64) + view_201 = torch.ops.aten.view.default(mm_40, [2, 8192, 14336]); mm_40 = None + mul_47 = torch.ops.aten.mul.Tensor(convert_element_type_192, view_201) + view_203 = torch.ops.aten.view.default(mul_47, [16384, 14336]); mul_47 = None + mm_591 = torch.ops.aten.mm.default(permute_1189, view_203); permute_1189 = view_203 = None + convert_element_type_196 = torch.ops.prims.convert_element_type.default(primals_57, torch.bfloat16); primals_57 = None + all_gather_into_tensor_54 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_196, 64, '0'); convert_element_type_196 = None + wait_tensor_54 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_54); all_gather_into_tensor_54 = None + permute_65 = torch.ops.aten.permute.default(wait_tensor_54, [1, 0]); wait_tensor_54 = None + permute_1191 = torch.ops.aten.permute.default(permute_65, [1, 0]); permute_65 = None + mm_592 = torch.ops.aten.mm.default(view_1719, permute_1191); view_1719 = permute_1191 = None + view_1720 = torch.ops.aten.view.default(mm_592, [2, 8192, 14336]); mm_592 = None + convert_element_type_2482 = torch.ops.prims.convert_element_type.default(mm_591, torch.float32); mm_591 = None + reduce_scatter_tensor_236 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2482, 'avg', 64, '0'); convert_element_type_2482 = None + wait_tensor_527 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_236); reduce_scatter_tensor_236 = None + mul_784 = torch.ops.aten.mul.Tensor(view_1720, convert_element_type_192); convert_element_type_192 = None + mul_785 = torch.ops.aten.mul.Tensor(view_1720, view_201); view_1720 = view_201 = None + view_1721 = torch.ops.aten.view.default(mul_784, 
[16384, 14336]); mul_784 = None + permute_1193 = torch.ops.aten.permute.default(view_1721, [1, 0]) + mm_593 = torch.ops.aten.mm.default(permute_1193, view_197); permute_1193 = None + permute_1195 = torch.ops.aten.permute.default(permute_64, [1, 0]); permute_64 = None + mm_594 = torch.ops.aten.mm.default(view_1721, permute_1195); view_1721 = permute_1195 = None + view_1722 = torch.ops.aten.view.default(mm_594, [2, 8192, 4096]); mm_594 = None + convert_element_type_2487 = torch.ops.prims.convert_element_type.default(mm_593, torch.float32); mm_593 = None + reduce_scatter_tensor_237 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2487, 'avg', 64, '0'); convert_element_type_2487 = None + wait_tensor_528 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_237); reduce_scatter_tensor_237 = None + convert_element_type_2488 = torch.ops.prims.convert_element_type.default(mul_785, torch.float32); mul_785 = None + neg_26 = torch.ops.aten.neg.default(convert_element_type_191) + exp_26 = torch.ops.aten.exp.default(neg_26); neg_26 = None + add_311 = torch.ops.aten.add.Tensor(exp_26, 1); exp_26 = None + reciprocal_26 = torch.ops.aten.reciprocal.default(add_311); add_311 = None + mul_786 = torch.ops.aten.mul.Tensor(reciprocal_26, 1); reciprocal_26 = None + mul_787 = torch.ops.aten.mul.Tensor(convert_element_type_2488, mul_786); convert_element_type_2488 = None + sub_79 = torch.ops.aten.sub.Tensor(1, mul_786); mul_786 = None + mul_788 = torch.ops.aten.mul.Tensor(convert_element_type_191, sub_79); convert_element_type_191 = sub_79 = None + add_312 = torch.ops.aten.add.Tensor(mul_788, 1); mul_788 = None + mul_789 = torch.ops.aten.mul.Tensor(mul_787, add_312); mul_787 = add_312 = None + convert_element_type_2490 = torch.ops.prims.convert_element_type.default(mul_789, torch.bfloat16); mul_789 = None + view_1723 = torch.ops.aten.view.default(convert_element_type_2490, [16384, 14336]); convert_element_type_2490 = None + permute_1197 = 
torch.ops.aten.permute.default(view_1723, [1, 0]) + mm_595 = torch.ops.aten.mm.default(permute_1197, view_197); permute_1197 = view_197 = None + convert_element_type_188 = torch.ops.prims.convert_element_type.default(primals_55, torch.bfloat16); primals_55 = None + all_gather_into_tensor_52 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_188, 64, '0'); convert_element_type_188 = None + wait_tensor_52 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_52); all_gather_into_tensor_52 = None + permute_63 = torch.ops.aten.permute.default(wait_tensor_52, [1, 0]); wait_tensor_52 = None + permute_1199 = torch.ops.aten.permute.default(permute_63, [1, 0]); permute_63 = None + mm_596 = torch.ops.aten.mm.default(view_1723, permute_1199); view_1723 = permute_1199 = None + view_1724 = torch.ops.aten.view.default(mm_596, [2, 8192, 4096]); mm_596 = None + add_313 = torch.ops.aten.add.Tensor(view_1722, view_1724); view_1722 = view_1724 = None + convert_element_type_2495 = torch.ops.prims.convert_element_type.default(mm_595, torch.float32); mm_595 = None + reduce_scatter_tensor_238 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2495, 'avg', 64, '0'); convert_element_type_2495 = None + wait_tensor_529 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_238); reduce_scatter_tensor_238 = None + convert_element_type_2496 = torch.ops.prims.convert_element_type.default(add_313, torch.float32); add_313 = None + convert_element_type_2498 = torch.ops.prims.convert_element_type.default(wait_tensor_51, torch.float32); wait_tensor_51 = None + mul_790 = torch.ops.aten.mul.Tensor(convert_element_type_2496, convert_element_type_2498); convert_element_type_2498 = None + mul_792 = torch.ops.aten.mul.Tensor(mul_44, mul_790) + sum_159 = torch.ops.aten.sum.dim_IntList(mul_792, [2], True); mul_792 = None + div_53 = torch.ops.aten.div.Tensor(mul_44, 4096) + mul_793 = torch.ops.aten.mul.Tensor(div_53, 
sum_159); div_53 = sum_159 = None + sub_80 = torch.ops.aten.sub.Tensor(mul_790, mul_793); mul_790 = mul_793 = None + mul_794 = torch.ops.aten.mul.Tensor(sub_80, rsqrt_11); sub_80 = rsqrt_11 = None + mul_795 = torch.ops.aten.mul.Tensor(convert_element_type_2496, mul_44); convert_element_type_2496 = mul_44 = None + sum_160 = torch.ops.aten.sum.dim_IntList(mul_795, [0, 1]); mul_795 = None + convert_element_type_2499 = torch.ops.prims.convert_element_type.default(mul_794, torch.bfloat16); mul_794 = None + add_314 = torch.ops.aten.add.Tensor(add_310, convert_element_type_2499); add_310 = convert_element_type_2499 = None + convert_element_type_default_12 = torch.ops.prims.convert_element_type.default(sum_160, torch.float32); sum_160 = None + reduce_scatter_tensor_239 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_12, 'avg', 64, '0'); convert_element_type_default_12 = None + wait_tensor_530 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_239); reduce_scatter_tensor_239 = None + view_1725 = torch.ops.aten.view.default(add_314, [16384, 4096]) + permute_1201 = torch.ops.aten.permute.default(view_1725, [1, 0]) + mm_597 = torch.ops.aten.mm.default(permute_1201, view_193); permute_1201 = view_193 = None + permute_1203 = torch.ops.aten.permute.default(permute_62, [1, 0]); permute_62 = None + mm_598 = torch.ops.aten.mm.default(view_1725, permute_1203); view_1725 = permute_1203 = None + view_1726 = torch.ops.aten.view.default(mm_598, [2, 8192, 4096]); mm_598 = None + convert_element_type_2506 = torch.ops.prims.convert_element_type.default(mm_597, torch.float32); mm_597 = None + reduce_scatter_tensor_240 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2506, 'avg', 64, '0'); convert_element_type_2506 = None + wait_tensor_531 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_240); reduce_scatter_tensor_240 = None + view_1727 = torch.ops.aten.view.default(view_1726, [2, 
8192, 32, 128]); view_1726 = None + permute_1205 = torch.ops.aten.permute.default(view_1727, [0, 2, 1, 3]); view_1727 = None + convert_element_type_166 = torch.ops.prims.convert_element_type.default(primals_49, torch.bfloat16); primals_49 = None + all_gather_into_tensor_46 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_166, 64, '0'); convert_element_type_166 = None + wait_tensor_46 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_46); all_gather_into_tensor_46 = None + convert_element_type_167 = torch.ops.prims.convert_element_type.default(add_19, torch.float32); add_19 = None + pow_11 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_167, 2) + mean_10 = torch.ops.aten.mean.dim(pow_11, [2], True); pow_11 = None + add_20 = torch.ops.aten.add.Scalar(mean_10, 1e-05); mean_10 = None + rsqrt_10 = torch.ops.aten.rsqrt.default(add_20); add_20 = None + mul_40 = torch.ops.aten.mul.Tensor(convert_element_type_167, rsqrt_10); convert_element_type_167 = None + mul_41 = torch.ops.aten.mul.Tensor(mul_40, wait_tensor_46) + convert_element_type_168 = torch.ops.prims.convert_element_type.default(mul_41, torch.bfloat16); mul_41 = None + view_173 = torch.ops.aten.view.default(convert_element_type_168, [16384, 4096]); convert_element_type_168 = None + view_174 = torch.ops.aten.view.default(mm_35, [2, 8192, 4096]); mm_35 = None + convert_element_type_172 = torch.ops.prims.convert_element_type.default(primals_51, torch.bfloat16); primals_51 = None + all_gather_into_tensor_48 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_172, 64, '0'); convert_element_type_172 = None + wait_tensor_48 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_48); all_gather_into_tensor_48 = None + permute_56 = torch.ops.aten.permute.default(wait_tensor_48, [1, 0]); wait_tensor_48 = None + mm_36 = torch.ops.aten.mm.default(view_173, permute_56) + view_177 = torch.ops.aten.view.default(mm_36, [2, 
8192, 1024]); mm_36 = None + view_180 = torch.ops.aten.view.default(mm_37, [2, 8192, 1024]); mm_37 = None + view_181 = torch.ops.aten.view.default(view_174, [2, 8192, -1, 128]); view_174 = None + view_182 = torch.ops.aten.view.default(view_177, [2, 8192, -1, 128]); view_177 = None + view_183 = torch.ops.aten.view.default(view_180, [2, 8192, -1, 128]); view_180 = None + convert_element_type_178 = torch.ops.prims.convert_element_type.default(view_181, torch.float32); view_181 = None + view_184 = torch.ops.aten.view.default(convert_element_type_178, [2, 8192, 32, -1, 2]); convert_element_type_178 = None + view_as_complex_10 = torch.ops.aten.view_as_complex.default(view_184); view_184 = None + convert_element_type_179 = torch.ops.prims.convert_element_type.default(view_182, torch.float32); view_182 = None + view_185 = torch.ops.aten.view.default(convert_element_type_179, [2, 8192, 8, -1, 2]); convert_element_type_179 = None + view_as_complex_11 = torch.ops.aten.view_as_complex.default(view_185); view_185 = None + mul_42 = torch.ops.aten.mul.Tensor(view_as_complex_10, view_16); view_as_complex_10 = None + view_as_real_10 = torch.ops.aten.view_as_real.default(mul_42); mul_42 = None + view_187 = torch.ops.aten.view.default(view_as_real_10, [2, 8192, 32, 128]); view_as_real_10 = None + mul_43 = torch.ops.aten.mul.Tensor(view_as_complex_11, view_16); view_as_complex_11 = None + view_as_real_11 = torch.ops.aten.view_as_real.default(mul_43); mul_43 = None + view_188 = torch.ops.aten.view.default(view_as_real_11, [2, 8192, 8, 128]); view_as_real_11 = None + convert_element_type_180 = torch.ops.prims.convert_element_type.default(view_187, torch.bfloat16); view_187 = None + convert_element_type_181 = torch.ops.prims.convert_element_type.default(view_188, torch.bfloat16); view_188 = None + unsqueeze_10 = torch.ops.aten.unsqueeze.default(convert_element_type_181, 3); convert_element_type_181 = None + expand_10 = torch.ops.aten.expand.default(unsqueeze_10, [2, 8192, 8, 4, 128]); 
unsqueeze_10 = None + clone_10 = torch.ops.aten.clone.default(expand_10, memory_format = torch.contiguous_format); expand_10 = None + view_189 = torch.ops.aten.view.default(clone_10, [2, 8192, 32, 128]); clone_10 = None + unsqueeze_11 = torch.ops.aten.unsqueeze.default(view_183, 3); view_183 = None + expand_11 = torch.ops.aten.expand.default(unsqueeze_11, [2, 8192, 8, 4, 128]); unsqueeze_11 = None + clone_11 = torch.ops.aten.clone.default(expand_11, memory_format = torch.contiguous_format); expand_11 = None + view_190 = torch.ops.aten.view.default(clone_11, [2, 8192, 32, 128]); clone_11 = None + permute_58 = torch.ops.aten.permute.default(convert_element_type_180, [0, 2, 1, 3]); convert_element_type_180 = None + permute_59 = torch.ops.aten.permute.default(view_189, [0, 2, 1, 3]); view_189 = None + permute_60 = torch.ops.aten.permute.default(view_190, [0, 2, 1, 3]); view_190 = None + _scaled_dot_product_cudnn_attention_backward_26 = torch.ops.aten._scaled_dot_product_cudnn_attention_backward.default(permute_1205, permute_58, permute_59, permute_60, getitem_45, getitem_46, getitem_51, getitem_52, None, None, None, 8192, 8192, 0.0, True); permute_1205 = permute_58 = permute_59 = permute_60 = getitem_45 = getitem_46 = getitem_51 = getitem_52 = None + getitem_366 = _scaled_dot_product_cudnn_attention_backward_26[0] + getitem_367 = _scaled_dot_product_cudnn_attention_backward_26[1] + getitem_368 = _scaled_dot_product_cudnn_attention_backward_26[2]; _scaled_dot_product_cudnn_attention_backward_26 = None + permute_1206 = torch.ops.aten.permute.default(getitem_368, [0, 2, 1, 3]); getitem_368 = None + permute_1207 = torch.ops.aten.permute.default(getitem_367, [0, 2, 1, 3]); getitem_367 = None + permute_1208 = torch.ops.aten.permute.default(getitem_366, [0, 2, 1, 3]); getitem_366 = None + view_1728 = torch.ops.aten.view.default(permute_1206, [2, 8192, 8, 4, 128]); permute_1206 = None + sum_161 = torch.ops.aten.sum.dim_IntList(view_1728, [3], True); view_1728 = None + 
squeeze_52 = torch.ops.aten.squeeze.dim(sum_161, 3); sum_161 = None + view_1729 = torch.ops.aten.view.default(permute_1207, [2, 8192, 8, 4, 128]); permute_1207 = None + sum_162 = torch.ops.aten.sum.dim_IntList(view_1729, [3], True); view_1729 = None + squeeze_53 = torch.ops.aten.squeeze.dim(sum_162, 3); sum_162 = None + convert_element_type_2507 = torch.ops.prims.convert_element_type.default(squeeze_53, torch.float32); squeeze_53 = None + convert_element_type_2508 = torch.ops.prims.convert_element_type.default(permute_1208, torch.float32); permute_1208 = None + view_1730 = torch.ops.aten.view.default(convert_element_type_2507, [2, 8192, 8, 64, 2]); convert_element_type_2507 = None + view_as_complex_116 = torch.ops.aten.view_as_complex.default(view_1730); view_1730 = None + mul_796 = torch.ops.aten.mul.Tensor(view_as_complex_116, _conj); view_as_complex_116 = None + view_1731 = torch.ops.aten.view.default(convert_element_type_2508, [2, 8192, 32, 64, 2]); convert_element_type_2508 = None + view_as_complex_117 = torch.ops.aten.view_as_complex.default(view_1731); view_1731 = None + mul_797 = torch.ops.aten.mul.Tensor(view_as_complex_117, _conj); view_as_complex_117 = None + view_as_real_116 = torch.ops.aten.view_as_real.default(mul_796); mul_796 = None + view_1732 = torch.ops.aten.view.default(view_as_real_116, [2, 8192, 8, 128]); view_as_real_116 = None + convert_element_type_2509 = torch.ops.prims.convert_element_type.default(view_1732, torch.bfloat16); view_1732 = None + view_as_real_117 = torch.ops.aten.view_as_real.default(mul_797); mul_797 = None + view_1733 = torch.ops.aten.view.default(view_as_real_117, [2, 8192, 32, 128]); view_as_real_117 = None + convert_element_type_2510 = torch.ops.prims.convert_element_type.default(view_1733, torch.bfloat16); view_1733 = None + view_1734 = torch.ops.aten.view.default(squeeze_52, [2, 8192, 1024]); squeeze_52 = None + view_1735 = torch.ops.aten.view.default(convert_element_type_2509, [2, 8192, 1024]); 
convert_element_type_2509 = None + view_1736 = torch.ops.aten.view.default(convert_element_type_2510, [2, 8192, 4096]); convert_element_type_2510 = None + view_1737 = torch.ops.aten.view.default(view_1734, [16384, 1024]); view_1734 = None + permute_1209 = torch.ops.aten.permute.default(view_1737, [1, 0]) + mm_599 = torch.ops.aten.mm.default(permute_1209, view_173); permute_1209 = None + convert_element_type_175 = torch.ops.prims.convert_element_type.default(primals_52, torch.bfloat16); primals_52 = None + all_gather_into_tensor_49 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_175, 64, '0'); convert_element_type_175 = None + wait_tensor_49 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_49); all_gather_into_tensor_49 = None + permute_57 = torch.ops.aten.permute.default(wait_tensor_49, [1, 0]); wait_tensor_49 = None + permute_1211 = torch.ops.aten.permute.default(permute_57, [1, 0]); permute_57 = None + mm_600 = torch.ops.aten.mm.default(view_1737, permute_1211); view_1737 = permute_1211 = None + view_1738 = torch.ops.aten.view.default(mm_600, [2, 8192, 4096]); mm_600 = None + convert_element_type_2515 = torch.ops.prims.convert_element_type.default(mm_599, torch.float32); mm_599 = None + reduce_scatter_tensor_241 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2515, 'avg', 64, '0'); convert_element_type_2515 = None + wait_tensor_532 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_241); reduce_scatter_tensor_241 = None + view_1739 = torch.ops.aten.view.default(view_1735, [16384, 1024]); view_1735 = None + permute_1213 = torch.ops.aten.permute.default(view_1739, [1, 0]) + mm_601 = torch.ops.aten.mm.default(permute_1213, view_173); permute_1213 = None + permute_1215 = torch.ops.aten.permute.default(permute_56, [1, 0]); permute_56 = None + mm_602 = torch.ops.aten.mm.default(view_1739, permute_1215); view_1739 = permute_1215 = None + view_1740 = 
torch.ops.aten.view.default(mm_602, [2, 8192, 4096]); mm_602 = None + add_315 = torch.ops.aten.add.Tensor(view_1738, view_1740); view_1738 = view_1740 = None + convert_element_type_2520 = torch.ops.prims.convert_element_type.default(mm_601, torch.float32); mm_601 = None + reduce_scatter_tensor_242 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2520, 'avg', 64, '0'); convert_element_type_2520 = None + wait_tensor_533 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_242); reduce_scatter_tensor_242 = None + view_1741 = torch.ops.aten.view.default(view_1736, [16384, 4096]); view_1736 = None + permute_1217 = torch.ops.aten.permute.default(view_1741, [1, 0]) + mm_603 = torch.ops.aten.mm.default(permute_1217, view_173); permute_1217 = view_173 = None + convert_element_type_169 = torch.ops.prims.convert_element_type.default(primals_50, torch.bfloat16); primals_50 = None + all_gather_into_tensor_47 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_169, 64, '0'); convert_element_type_169 = None + wait_tensor_47 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_47); all_gather_into_tensor_47 = None + permute_55 = torch.ops.aten.permute.default(wait_tensor_47, [1, 0]); wait_tensor_47 = None + permute_1219 = torch.ops.aten.permute.default(permute_55, [1, 0]); permute_55 = None + mm_604 = torch.ops.aten.mm.default(view_1741, permute_1219); view_1741 = permute_1219 = None + view_1742 = torch.ops.aten.view.default(mm_604, [2, 8192, 4096]); mm_604 = None + add_316 = torch.ops.aten.add.Tensor(add_315, view_1742); add_315 = view_1742 = None + convert_element_type_2525 = torch.ops.prims.convert_element_type.default(mm_603, torch.float32); mm_603 = None + reduce_scatter_tensor_243 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2525, 'avg', 64, '0'); convert_element_type_2525 = None + wait_tensor_534 = 
torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_243); reduce_scatter_tensor_243 = None + convert_element_type_2526 = torch.ops.prims.convert_element_type.default(add_316, torch.float32); add_316 = None + convert_element_type_2528 = torch.ops.prims.convert_element_type.default(wait_tensor_46, torch.float32); wait_tensor_46 = None + mul_798 = torch.ops.aten.mul.Tensor(convert_element_type_2526, convert_element_type_2528); convert_element_type_2528 = None + mul_800 = torch.ops.aten.mul.Tensor(mul_40, mul_798) + sum_163 = torch.ops.aten.sum.dim_IntList(mul_800, [2], True); mul_800 = None + div_54 = torch.ops.aten.div.Tensor(mul_40, 4096) + mul_801 = torch.ops.aten.mul.Tensor(div_54, sum_163); div_54 = sum_163 = None + sub_81 = torch.ops.aten.sub.Tensor(mul_798, mul_801); mul_798 = mul_801 = None + mul_802 = torch.ops.aten.mul.Tensor(sub_81, rsqrt_10); sub_81 = rsqrt_10 = None + mul_803 = torch.ops.aten.mul.Tensor(convert_element_type_2526, mul_40); convert_element_type_2526 = mul_40 = None + sum_164 = torch.ops.aten.sum.dim_IntList(mul_803, [0, 1]); mul_803 = None + convert_element_type_2529 = torch.ops.prims.convert_element_type.default(mul_802, torch.bfloat16); mul_802 = None + add_317 = torch.ops.aten.add.Tensor(add_314, convert_element_type_2529); add_314 = convert_element_type_2529 = None + convert_element_type_default_11 = torch.ops.prims.convert_element_type.default(sum_164, torch.float32); sum_164 = None + reduce_scatter_tensor_244 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_11, 'avg', 64, '0'); convert_element_type_default_11 = None + wait_tensor_535 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_244); reduce_scatter_tensor_244 = None + view_1743 = torch.ops.aten.view.default(add_317, [16384, 4096]) + permute_1221 = torch.ops.aten.permute.default(view_1743, [1, 0]) + permute_50 = torch.ops.aten.permute.default(getitem_36, [0, 2, 1, 3]) + view_157 = 
torch.ops.aten.view.default(permute_50, [2, 8192, -1]); permute_50 = None + convert_element_type_149 = torch.ops.prims.convert_element_type.default(primals_44, torch.bfloat16); primals_44 = None + all_gather_into_tensor_41 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_149, 64, '0'); convert_element_type_149 = None + wait_tensor_41 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_41); all_gather_into_tensor_41 = None + permute_51 = torch.ops.aten.permute.default(wait_tensor_41, [1, 0]); wait_tensor_41 = None + view_159 = torch.ops.aten.view.default(view_157, [16384, 4096]); view_157 = None + mm_31 = torch.ops.aten.mm.default(view_159, permute_51) + view_160 = torch.ops.aten.view.default(mm_31, [2, 8192, 4096]); mm_31 = None + add_17 = torch.ops.aten.add.Tensor(add_15, view_160); view_160 = None + convert_element_type_152 = torch.ops.prims.convert_element_type.default(primals_45, torch.bfloat16); primals_45 = None + all_gather_into_tensor_42 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_152, 64, '0'); convert_element_type_152 = None + wait_tensor_42 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_42); all_gather_into_tensor_42 = None + convert_element_type_153 = torch.ops.prims.convert_element_type.default(add_17, torch.float32); add_17 = None + pow_10 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_153, 2) + mean_9 = torch.ops.aten.mean.dim(pow_10, [2], True); pow_10 = None + add_18 = torch.ops.aten.add.Scalar(mean_9, 1e-05); mean_9 = None + rsqrt_9 = torch.ops.aten.rsqrt.default(add_18); add_18 = None + mul_36 = torch.ops.aten.mul.Tensor(convert_element_type_153, rsqrt_9); convert_element_type_153 = None + mul_37 = torch.ops.aten.mul.Tensor(mul_36, wait_tensor_42) + convert_element_type_154 = torch.ops.prims.convert_element_type.default(mul_37, torch.bfloat16); mul_37 = None + view_163 = torch.ops.aten.view.default(convert_element_type_154, 
[16384, 4096]); convert_element_type_154 = None + view_164 = torch.ops.aten.view.default(mm_32, [2, 8192, 14336]); mm_32 = None + convert_element_type_158 = torch.ops.prims.convert_element_type.default(view_164, torch.float32); view_164 = None + sigmoid_4 = torch.ops.aten.sigmoid.default(convert_element_type_158) + mul_38 = torch.ops.aten.mul.Tensor(convert_element_type_158, sigmoid_4); sigmoid_4 = None + convert_element_type_159 = torch.ops.prims.convert_element_type.default(mul_38, torch.bfloat16); mul_38 = None + convert_element_type_160 = torch.ops.prims.convert_element_type.default(primals_47, torch.bfloat16); primals_47 = None + all_gather_into_tensor_44 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_160, 64, '0'); convert_element_type_160 = None + wait_tensor_44 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_44); all_gather_into_tensor_44 = None + permute_53 = torch.ops.aten.permute.default(wait_tensor_44, [1, 0]); wait_tensor_44 = None + mm_33 = torch.ops.aten.mm.default(view_163, permute_53) + view_167 = torch.ops.aten.view.default(mm_33, [2, 8192, 14336]); mm_33 = None + mul_39 = torch.ops.aten.mul.Tensor(convert_element_type_159, view_167) + view_169 = torch.ops.aten.view.default(mul_39, [16384, 14336]); mul_39 = None + mm_605 = torch.ops.aten.mm.default(permute_1221, view_169); permute_1221 = view_169 = None + convert_element_type_163 = torch.ops.prims.convert_element_type.default(primals_48, torch.bfloat16); primals_48 = None + all_gather_into_tensor_45 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_163, 64, '0'); convert_element_type_163 = None + wait_tensor_45 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_45); all_gather_into_tensor_45 = None + permute_54 = torch.ops.aten.permute.default(wait_tensor_45, [1, 0]); wait_tensor_45 = None + permute_1223 = torch.ops.aten.permute.default(permute_54, [1, 0]); permute_54 = None + mm_606 = 
torch.ops.aten.mm.default(view_1743, permute_1223); view_1743 = permute_1223 = None + view_1744 = torch.ops.aten.view.default(mm_606, [2, 8192, 14336]); mm_606 = None + convert_element_type_2536 = torch.ops.prims.convert_element_type.default(mm_605, torch.float32); mm_605 = None + reduce_scatter_tensor_245 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2536, 'avg', 64, '0'); convert_element_type_2536 = None + wait_tensor_536 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_245); reduce_scatter_tensor_245 = None + mul_804 = torch.ops.aten.mul.Tensor(view_1744, convert_element_type_159); convert_element_type_159 = None + mul_805 = torch.ops.aten.mul.Tensor(view_1744, view_167); view_1744 = view_167 = None + view_1745 = torch.ops.aten.view.default(mul_804, [16384, 14336]); mul_804 = None + permute_1225 = torch.ops.aten.permute.default(view_1745, [1, 0]) + mm_607 = torch.ops.aten.mm.default(permute_1225, view_163); permute_1225 = None + permute_1227 = torch.ops.aten.permute.default(permute_53, [1, 0]); permute_53 = None + mm_608 = torch.ops.aten.mm.default(view_1745, permute_1227); view_1745 = permute_1227 = None + view_1746 = torch.ops.aten.view.default(mm_608, [2, 8192, 4096]); mm_608 = None + convert_element_type_2541 = torch.ops.prims.convert_element_type.default(mm_607, torch.float32); mm_607 = None + reduce_scatter_tensor_246 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2541, 'avg', 64, '0'); convert_element_type_2541 = None + wait_tensor_537 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_246); reduce_scatter_tensor_246 = None + convert_element_type_2542 = torch.ops.prims.convert_element_type.default(mul_805, torch.float32); mul_805 = None + neg_27 = torch.ops.aten.neg.default(convert_element_type_158) + exp_27 = torch.ops.aten.exp.default(neg_27); neg_27 = None + add_318 = torch.ops.aten.add.Tensor(exp_27, 1); exp_27 = None + reciprocal_27 = 
torch.ops.aten.reciprocal.default(add_318); add_318 = None + mul_806 = torch.ops.aten.mul.Tensor(reciprocal_27, 1); reciprocal_27 = None + mul_807 = torch.ops.aten.mul.Tensor(convert_element_type_2542, mul_806); convert_element_type_2542 = None + sub_82 = torch.ops.aten.sub.Tensor(1, mul_806); mul_806 = None + mul_808 = torch.ops.aten.mul.Tensor(convert_element_type_158, sub_82); convert_element_type_158 = sub_82 = None + add_319 = torch.ops.aten.add.Tensor(mul_808, 1); mul_808 = None + mul_809 = torch.ops.aten.mul.Tensor(mul_807, add_319); mul_807 = add_319 = None + convert_element_type_2544 = torch.ops.prims.convert_element_type.default(mul_809, torch.bfloat16); mul_809 = None + view_1747 = torch.ops.aten.view.default(convert_element_type_2544, [16384, 14336]); convert_element_type_2544 = None + permute_1229 = torch.ops.aten.permute.default(view_1747, [1, 0]) + mm_609 = torch.ops.aten.mm.default(permute_1229, view_163); permute_1229 = view_163 = None + convert_element_type_155 = torch.ops.prims.convert_element_type.default(primals_46, torch.bfloat16); primals_46 = None + all_gather_into_tensor_43 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_155, 64, '0'); convert_element_type_155 = None + wait_tensor_43 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_43); all_gather_into_tensor_43 = None + permute_52 = torch.ops.aten.permute.default(wait_tensor_43, [1, 0]); wait_tensor_43 = None + permute_1231 = torch.ops.aten.permute.default(permute_52, [1, 0]); permute_52 = None + mm_610 = torch.ops.aten.mm.default(view_1747, permute_1231); view_1747 = permute_1231 = None + view_1748 = torch.ops.aten.view.default(mm_610, [2, 8192, 4096]); mm_610 = None + add_320 = torch.ops.aten.add.Tensor(view_1746, view_1748); view_1746 = view_1748 = None + convert_element_type_2549 = torch.ops.prims.convert_element_type.default(mm_609, torch.float32); mm_609 = None + reduce_scatter_tensor_247 = 
torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2549, 'avg', 64, '0'); convert_element_type_2549 = None + wait_tensor_538 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_247); reduce_scatter_tensor_247 = None + convert_element_type_2550 = torch.ops.prims.convert_element_type.default(add_320, torch.float32); add_320 = None + convert_element_type_2552 = torch.ops.prims.convert_element_type.default(wait_tensor_42, torch.float32); wait_tensor_42 = None + mul_810 = torch.ops.aten.mul.Tensor(convert_element_type_2550, convert_element_type_2552); convert_element_type_2552 = None + mul_812 = torch.ops.aten.mul.Tensor(mul_36, mul_810) + sum_165 = torch.ops.aten.sum.dim_IntList(mul_812, [2], True); mul_812 = None + div_55 = torch.ops.aten.div.Tensor(mul_36, 4096) + mul_813 = torch.ops.aten.mul.Tensor(div_55, sum_165); div_55 = sum_165 = None + sub_83 = torch.ops.aten.sub.Tensor(mul_810, mul_813); mul_810 = mul_813 = None + mul_814 = torch.ops.aten.mul.Tensor(sub_83, rsqrt_9); sub_83 = rsqrt_9 = None + mul_815 = torch.ops.aten.mul.Tensor(convert_element_type_2550, mul_36); convert_element_type_2550 = mul_36 = None + sum_166 = torch.ops.aten.sum.dim_IntList(mul_815, [0, 1]); mul_815 = None + convert_element_type_2553 = torch.ops.prims.convert_element_type.default(mul_814, torch.bfloat16); mul_814 = None + add_321 = torch.ops.aten.add.Tensor(add_317, convert_element_type_2553); add_317 = convert_element_type_2553 = None + convert_element_type_default_10 = torch.ops.prims.convert_element_type.default(sum_166, torch.float32); sum_166 = None + reduce_scatter_tensor_248 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_10, 'avg', 64, '0'); convert_element_type_default_10 = None + wait_tensor_539 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_248); reduce_scatter_tensor_248 = None + view_1749 = torch.ops.aten.view.default(add_321, [16384, 4096]) + permute_1233 = 
torch.ops.aten.permute.default(view_1749, [1, 0]) + mm_611 = torch.ops.aten.mm.default(permute_1233, view_159); permute_1233 = view_159 = None + permute_1235 = torch.ops.aten.permute.default(permute_51, [1, 0]); permute_51 = None + mm_612 = torch.ops.aten.mm.default(view_1749, permute_1235); view_1749 = permute_1235 = None + view_1750 = torch.ops.aten.view.default(mm_612, [2, 8192, 4096]); mm_612 = None + convert_element_type_2560 = torch.ops.prims.convert_element_type.default(mm_611, torch.float32); mm_611 = None + reduce_scatter_tensor_249 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2560, 'avg', 64, '0'); convert_element_type_2560 = None + wait_tensor_540 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_249); reduce_scatter_tensor_249 = None + view_1751 = torch.ops.aten.view.default(view_1750, [2, 8192, 32, 128]); view_1750 = None + permute_1237 = torch.ops.aten.permute.default(view_1751, [0, 2, 1, 3]); view_1751 = None + convert_element_type_133 = torch.ops.prims.convert_element_type.default(primals_40, torch.bfloat16); primals_40 = None + all_gather_into_tensor_37 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_133, 64, '0'); convert_element_type_133 = None + wait_tensor_37 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_37); all_gather_into_tensor_37 = None + convert_element_type_134 = torch.ops.prims.convert_element_type.default(add_15, torch.float32); add_15 = None + pow_9 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_134, 2) + mean_8 = torch.ops.aten.mean.dim(pow_9, [2], True); pow_9 = None + add_16 = torch.ops.aten.add.Scalar(mean_8, 1e-05); mean_8 = None + rsqrt_8 = torch.ops.aten.rsqrt.default(add_16); add_16 = None + mul_32 = torch.ops.aten.mul.Tensor(convert_element_type_134, rsqrt_8); convert_element_type_134 = None + mul_33 = torch.ops.aten.mul.Tensor(mul_32, wait_tensor_37) + convert_element_type_135 = 
torch.ops.prims.convert_element_type.default(mul_33, torch.bfloat16); mul_33 = None + view_139 = torch.ops.aten.view.default(convert_element_type_135, [16384, 4096]); convert_element_type_135 = None + view_140 = torch.ops.aten.view.default(mm_28, [2, 8192, 4096]); mm_28 = None + convert_element_type_139 = torch.ops.prims.convert_element_type.default(primals_42, torch.bfloat16); primals_42 = None + all_gather_into_tensor_39 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_139, 64, '0'); convert_element_type_139 = None + wait_tensor_39 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_39); all_gather_into_tensor_39 = None + permute_45 = torch.ops.aten.permute.default(wait_tensor_39, [1, 0]); wait_tensor_39 = None + mm_29 = torch.ops.aten.mm.default(view_139, permute_45) + view_143 = torch.ops.aten.view.default(mm_29, [2, 8192, 1024]); mm_29 = None + view_146 = torch.ops.aten.view.default(mm_30, [2, 8192, 1024]); mm_30 = None + view_147 = torch.ops.aten.view.default(view_140, [2, 8192, -1, 128]); view_140 = None + view_148 = torch.ops.aten.view.default(view_143, [2, 8192, -1, 128]); view_143 = None + view_149 = torch.ops.aten.view.default(view_146, [2, 8192, -1, 128]); view_146 = None + convert_element_type_145 = torch.ops.prims.convert_element_type.default(view_147, torch.float32); view_147 = None + view_150 = torch.ops.aten.view.default(convert_element_type_145, [2, 8192, 32, -1, 2]); convert_element_type_145 = None + view_as_complex_8 = torch.ops.aten.view_as_complex.default(view_150); view_150 = None + convert_element_type_146 = torch.ops.prims.convert_element_type.default(view_148, torch.float32); view_148 = None + view_151 = torch.ops.aten.view.default(convert_element_type_146, [2, 8192, 8, -1, 2]); convert_element_type_146 = None + view_as_complex_9 = torch.ops.aten.view_as_complex.default(view_151); view_151 = None + mul_34 = torch.ops.aten.mul.Tensor(view_as_complex_8, view_16); view_as_complex_8 = None + 
view_as_real_8 = torch.ops.aten.view_as_real.default(mul_34); mul_34 = None + view_153 = torch.ops.aten.view.default(view_as_real_8, [2, 8192, 32, 128]); view_as_real_8 = None + mul_35 = torch.ops.aten.mul.Tensor(view_as_complex_9, view_16); view_as_complex_9 = None + view_as_real_9 = torch.ops.aten.view_as_real.default(mul_35); mul_35 = None + view_154 = torch.ops.aten.view.default(view_as_real_9, [2, 8192, 8, 128]); view_as_real_9 = None + convert_element_type_147 = torch.ops.prims.convert_element_type.default(view_153, torch.bfloat16); view_153 = None + convert_element_type_148 = torch.ops.prims.convert_element_type.default(view_154, torch.bfloat16); view_154 = None + unsqueeze_8 = torch.ops.aten.unsqueeze.default(convert_element_type_148, 3); convert_element_type_148 = None + expand_8 = torch.ops.aten.expand.default(unsqueeze_8, [2, 8192, 8, 4, 128]); unsqueeze_8 = None + clone_8 = torch.ops.aten.clone.default(expand_8, memory_format = torch.contiguous_format); expand_8 = None + view_155 = torch.ops.aten.view.default(clone_8, [2, 8192, 32, 128]); clone_8 = None + unsqueeze_9 = torch.ops.aten.unsqueeze.default(view_149, 3); view_149 = None + expand_9 = torch.ops.aten.expand.default(unsqueeze_9, [2, 8192, 8, 4, 128]); unsqueeze_9 = None + clone_9 = torch.ops.aten.clone.default(expand_9, memory_format = torch.contiguous_format); expand_9 = None + view_156 = torch.ops.aten.view.default(clone_9, [2, 8192, 32, 128]); clone_9 = None + permute_47 = torch.ops.aten.permute.default(convert_element_type_147, [0, 2, 1, 3]); convert_element_type_147 = None + permute_48 = torch.ops.aten.permute.default(view_155, [0, 2, 1, 3]); view_155 = None + permute_49 = torch.ops.aten.permute.default(view_156, [0, 2, 1, 3]); view_156 = None + _scaled_dot_product_cudnn_attention_backward_27 = torch.ops.aten._scaled_dot_product_cudnn_attention_backward.default(permute_1237, permute_47, permute_48, permute_49, getitem_36, getitem_37, getitem_42, getitem_43, None, None, None, 8192, 8192, 0.0, 
True); permute_1237 = permute_47 = permute_48 = permute_49 = getitem_36 = getitem_37 = getitem_42 = getitem_43 = None + getitem_369 = _scaled_dot_product_cudnn_attention_backward_27[0] + getitem_370 = _scaled_dot_product_cudnn_attention_backward_27[1] + getitem_371 = _scaled_dot_product_cudnn_attention_backward_27[2]; _scaled_dot_product_cudnn_attention_backward_27 = None + permute_1238 = torch.ops.aten.permute.default(getitem_371, [0, 2, 1, 3]); getitem_371 = None + permute_1239 = torch.ops.aten.permute.default(getitem_370, [0, 2, 1, 3]); getitem_370 = None + permute_1240 = torch.ops.aten.permute.default(getitem_369, [0, 2, 1, 3]); getitem_369 = None + view_1752 = torch.ops.aten.view.default(permute_1238, [2, 8192, 8, 4, 128]); permute_1238 = None + sum_167 = torch.ops.aten.sum.dim_IntList(view_1752, [3], True); view_1752 = None + squeeze_54 = torch.ops.aten.squeeze.dim(sum_167, 3); sum_167 = None + view_1753 = torch.ops.aten.view.default(permute_1239, [2, 8192, 8, 4, 128]); permute_1239 = None + sum_168 = torch.ops.aten.sum.dim_IntList(view_1753, [3], True); view_1753 = None + squeeze_55 = torch.ops.aten.squeeze.dim(sum_168, 3); sum_168 = None + convert_element_type_2561 = torch.ops.prims.convert_element_type.default(squeeze_55, torch.float32); squeeze_55 = None + convert_element_type_2562 = torch.ops.prims.convert_element_type.default(permute_1240, torch.float32); permute_1240 = None + view_1754 = torch.ops.aten.view.default(convert_element_type_2561, [2, 8192, 8, 64, 2]); convert_element_type_2561 = None + view_as_complex_118 = torch.ops.aten.view_as_complex.default(view_1754); view_1754 = None + mul_816 = torch.ops.aten.mul.Tensor(view_as_complex_118, _conj); view_as_complex_118 = None + view_1755 = torch.ops.aten.view.default(convert_element_type_2562, [2, 8192, 32, 64, 2]); convert_element_type_2562 = None + view_as_complex_119 = torch.ops.aten.view_as_complex.default(view_1755); view_1755 = None + mul_817 = torch.ops.aten.mul.Tensor(view_as_complex_119, 
_conj); view_as_complex_119 = None + view_as_real_118 = torch.ops.aten.view_as_real.default(mul_816); mul_816 = None + view_1756 = torch.ops.aten.view.default(view_as_real_118, [2, 8192, 8, 128]); view_as_real_118 = None + convert_element_type_2563 = torch.ops.prims.convert_element_type.default(view_1756, torch.bfloat16); view_1756 = None + view_as_real_119 = torch.ops.aten.view_as_real.default(mul_817); mul_817 = None + view_1757 = torch.ops.aten.view.default(view_as_real_119, [2, 8192, 32, 128]); view_as_real_119 = None + convert_element_type_2564 = torch.ops.prims.convert_element_type.default(view_1757, torch.bfloat16); view_1757 = None + view_1758 = torch.ops.aten.view.default(squeeze_54, [2, 8192, 1024]); squeeze_54 = None + view_1759 = torch.ops.aten.view.default(convert_element_type_2563, [2, 8192, 1024]); convert_element_type_2563 = None + view_1760 = torch.ops.aten.view.default(convert_element_type_2564, [2, 8192, 4096]); convert_element_type_2564 = None + view_1761 = torch.ops.aten.view.default(view_1758, [16384, 1024]); view_1758 = None + permute_1241 = torch.ops.aten.permute.default(view_1761, [1, 0]) + mm_613 = torch.ops.aten.mm.default(permute_1241, view_139); permute_1241 = None + convert_element_type_142 = torch.ops.prims.convert_element_type.default(primals_43, torch.bfloat16); primals_43 = None + all_gather_into_tensor_40 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_142, 64, '0'); convert_element_type_142 = None + wait_tensor_40 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_40); all_gather_into_tensor_40 = None + permute_46 = torch.ops.aten.permute.default(wait_tensor_40, [1, 0]); wait_tensor_40 = None + permute_1243 = torch.ops.aten.permute.default(permute_46, [1, 0]); permute_46 = None + mm_614 = torch.ops.aten.mm.default(view_1761, permute_1243); view_1761 = permute_1243 = None + view_1762 = torch.ops.aten.view.default(mm_614, [2, 8192, 4096]); mm_614 = None + 
convert_element_type_2569 = torch.ops.prims.convert_element_type.default(mm_613, torch.float32); mm_613 = None + reduce_scatter_tensor_250 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2569, 'avg', 64, '0'); convert_element_type_2569 = None + wait_tensor_541 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_250); reduce_scatter_tensor_250 = None + view_1763 = torch.ops.aten.view.default(view_1759, [16384, 1024]); view_1759 = None + permute_1245 = torch.ops.aten.permute.default(view_1763, [1, 0]) + mm_615 = torch.ops.aten.mm.default(permute_1245, view_139); permute_1245 = None + permute_1247 = torch.ops.aten.permute.default(permute_45, [1, 0]); permute_45 = None + mm_616 = torch.ops.aten.mm.default(view_1763, permute_1247); view_1763 = permute_1247 = None + view_1764 = torch.ops.aten.view.default(mm_616, [2, 8192, 4096]); mm_616 = None + add_322 = torch.ops.aten.add.Tensor(view_1762, view_1764); view_1762 = view_1764 = None + convert_element_type_2574 = torch.ops.prims.convert_element_type.default(mm_615, torch.float32); mm_615 = None + reduce_scatter_tensor_251 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2574, 'avg', 64, '0'); convert_element_type_2574 = None + wait_tensor_542 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_251); reduce_scatter_tensor_251 = None + view_1765 = torch.ops.aten.view.default(view_1760, [16384, 4096]); view_1760 = None + permute_1249 = torch.ops.aten.permute.default(view_1765, [1, 0]) + mm_617 = torch.ops.aten.mm.default(permute_1249, view_139); permute_1249 = view_139 = None + convert_element_type_136 = torch.ops.prims.convert_element_type.default(primals_41, torch.bfloat16); primals_41 = None + all_gather_into_tensor_38 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_136, 64, '0'); convert_element_type_136 = None + wait_tensor_38 = 
torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_38); all_gather_into_tensor_38 = None + permute_44 = torch.ops.aten.permute.default(wait_tensor_38, [1, 0]); wait_tensor_38 = None + permute_1251 = torch.ops.aten.permute.default(permute_44, [1, 0]); permute_44 = None + mm_618 = torch.ops.aten.mm.default(view_1765, permute_1251); view_1765 = permute_1251 = None + view_1766 = torch.ops.aten.view.default(mm_618, [2, 8192, 4096]); mm_618 = None + add_323 = torch.ops.aten.add.Tensor(add_322, view_1766); add_322 = view_1766 = None + convert_element_type_2579 = torch.ops.prims.convert_element_type.default(mm_617, torch.float32); mm_617 = None + reduce_scatter_tensor_252 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2579, 'avg', 64, '0'); convert_element_type_2579 = None + wait_tensor_543 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_252); reduce_scatter_tensor_252 = None + convert_element_type_2580 = torch.ops.prims.convert_element_type.default(add_323, torch.float32); add_323 = None + convert_element_type_2582 = torch.ops.prims.convert_element_type.default(wait_tensor_37, torch.float32); wait_tensor_37 = None + mul_818 = torch.ops.aten.mul.Tensor(convert_element_type_2580, convert_element_type_2582); convert_element_type_2582 = None + mul_820 = torch.ops.aten.mul.Tensor(mul_32, mul_818) + sum_169 = torch.ops.aten.sum.dim_IntList(mul_820, [2], True); mul_820 = None + div_56 = torch.ops.aten.div.Tensor(mul_32, 4096) + mul_821 = torch.ops.aten.mul.Tensor(div_56, sum_169); div_56 = sum_169 = None + sub_84 = torch.ops.aten.sub.Tensor(mul_818, mul_821); mul_818 = mul_821 = None + mul_822 = torch.ops.aten.mul.Tensor(sub_84, rsqrt_8); sub_84 = rsqrt_8 = None + mul_823 = torch.ops.aten.mul.Tensor(convert_element_type_2580, mul_32); convert_element_type_2580 = mul_32 = None + sum_170 = torch.ops.aten.sum.dim_IntList(mul_823, [0, 1]); mul_823 = None + convert_element_type_2583 = 
torch.ops.prims.convert_element_type.default(mul_822, torch.bfloat16); mul_822 = None + add_324 = torch.ops.aten.add.Tensor(add_321, convert_element_type_2583); add_321 = convert_element_type_2583 = None + convert_element_type_default_9 = torch.ops.prims.convert_element_type.default(sum_170, torch.float32); sum_170 = None + reduce_scatter_tensor_253 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_9, 'avg', 64, '0'); convert_element_type_default_9 = None + wait_tensor_544 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_253); reduce_scatter_tensor_253 = None + view_1767 = torch.ops.aten.view.default(add_324, [16384, 4096]) + permute_1253 = torch.ops.aten.permute.default(view_1767, [1, 0]) + permute_39 = torch.ops.aten.permute.default(getitem_27, [0, 2, 1, 3]) + view_123 = torch.ops.aten.view.default(permute_39, [2, 8192, -1]); permute_39 = None + convert_element_type_116 = torch.ops.prims.convert_element_type.default(primals_35, torch.bfloat16); primals_35 = None + all_gather_into_tensor_32 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_116, 64, '0'); convert_element_type_116 = None + wait_tensor_32 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_32); all_gather_into_tensor_32 = None + permute_40 = torch.ops.aten.permute.default(wait_tensor_32, [1, 0]); wait_tensor_32 = None + view_125 = torch.ops.aten.view.default(view_123, [16384, 4096]); view_123 = None + mm_24 = torch.ops.aten.mm.default(view_125, permute_40) + view_126 = torch.ops.aten.view.default(mm_24, [2, 8192, 4096]); mm_24 = None + add_13 = torch.ops.aten.add.Tensor(add_11, view_126); view_126 = None + convert_element_type_119 = torch.ops.prims.convert_element_type.default(primals_36, torch.bfloat16); primals_36 = None + all_gather_into_tensor_33 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_119, 64, '0'); convert_element_type_119 = None + 
wait_tensor_33 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_33); all_gather_into_tensor_33 = None + convert_element_type_120 = torch.ops.prims.convert_element_type.default(add_13, torch.float32); add_13 = None + pow_8 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_120, 2) + mean_7 = torch.ops.aten.mean.dim(pow_8, [2], True); pow_8 = None + add_14 = torch.ops.aten.add.Scalar(mean_7, 1e-05); mean_7 = None + rsqrt_7 = torch.ops.aten.rsqrt.default(add_14); add_14 = None + mul_28 = torch.ops.aten.mul.Tensor(convert_element_type_120, rsqrt_7); convert_element_type_120 = None + mul_29 = torch.ops.aten.mul.Tensor(mul_28, wait_tensor_33) + convert_element_type_121 = torch.ops.prims.convert_element_type.default(mul_29, torch.bfloat16); mul_29 = None + view_129 = torch.ops.aten.view.default(convert_element_type_121, [16384, 4096]); convert_element_type_121 = None + view_130 = torch.ops.aten.view.default(mm_25, [2, 8192, 14336]); mm_25 = None + convert_element_type_125 = torch.ops.prims.convert_element_type.default(view_130, torch.float32); view_130 = None + sigmoid_3 = torch.ops.aten.sigmoid.default(convert_element_type_125) + mul_30 = torch.ops.aten.mul.Tensor(convert_element_type_125, sigmoid_3); sigmoid_3 = None + convert_element_type_126 = torch.ops.prims.convert_element_type.default(mul_30, torch.bfloat16); mul_30 = None + convert_element_type_127 = torch.ops.prims.convert_element_type.default(primals_38, torch.bfloat16); primals_38 = None + all_gather_into_tensor_35 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_127, 64, '0'); convert_element_type_127 = None + wait_tensor_35 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_35); all_gather_into_tensor_35 = None + permute_42 = torch.ops.aten.permute.default(wait_tensor_35, [1, 0]); wait_tensor_35 = None + mm_26 = torch.ops.aten.mm.default(view_129, permute_42) + view_133 = torch.ops.aten.view.default(mm_26, [2, 8192, 14336]); 
mm_26 = None + mul_31 = torch.ops.aten.mul.Tensor(convert_element_type_126, view_133) + view_135 = torch.ops.aten.view.default(mul_31, [16384, 14336]); mul_31 = None + mm_619 = torch.ops.aten.mm.default(permute_1253, view_135); permute_1253 = view_135 = None + convert_element_type_130 = torch.ops.prims.convert_element_type.default(primals_39, torch.bfloat16); primals_39 = None + all_gather_into_tensor_36 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_130, 64, '0'); convert_element_type_130 = None + wait_tensor_36 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_36); all_gather_into_tensor_36 = None + permute_43 = torch.ops.aten.permute.default(wait_tensor_36, [1, 0]); wait_tensor_36 = None + permute_1255 = torch.ops.aten.permute.default(permute_43, [1, 0]); permute_43 = None + mm_620 = torch.ops.aten.mm.default(view_1767, permute_1255); view_1767 = permute_1255 = None + view_1768 = torch.ops.aten.view.default(mm_620, [2, 8192, 14336]); mm_620 = None + convert_element_type_2590 = torch.ops.prims.convert_element_type.default(mm_619, torch.float32); mm_619 = None + reduce_scatter_tensor_254 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2590, 'avg', 64, '0'); convert_element_type_2590 = None + wait_tensor_545 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_254); reduce_scatter_tensor_254 = None + mul_824 = torch.ops.aten.mul.Tensor(view_1768, convert_element_type_126); convert_element_type_126 = None + mul_825 = torch.ops.aten.mul.Tensor(view_1768, view_133); view_1768 = view_133 = None + view_1769 = torch.ops.aten.view.default(mul_824, [16384, 14336]); mul_824 = None + permute_1257 = torch.ops.aten.permute.default(view_1769, [1, 0]) + mm_621 = torch.ops.aten.mm.default(permute_1257, view_129); permute_1257 = None + permute_1259 = torch.ops.aten.permute.default(permute_42, [1, 0]); permute_42 = None + mm_622 = torch.ops.aten.mm.default(view_1769, 
permute_1259); view_1769 = permute_1259 = None + view_1770 = torch.ops.aten.view.default(mm_622, [2, 8192, 4096]); mm_622 = None + convert_element_type_2595 = torch.ops.prims.convert_element_type.default(mm_621, torch.float32); mm_621 = None + reduce_scatter_tensor_255 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2595, 'avg', 64, '0'); convert_element_type_2595 = None + wait_tensor_546 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_255); reduce_scatter_tensor_255 = None + convert_element_type_2596 = torch.ops.prims.convert_element_type.default(mul_825, torch.float32); mul_825 = None + neg_28 = torch.ops.aten.neg.default(convert_element_type_125) + exp_28 = torch.ops.aten.exp.default(neg_28); neg_28 = None + add_325 = torch.ops.aten.add.Tensor(exp_28, 1); exp_28 = None + reciprocal_28 = torch.ops.aten.reciprocal.default(add_325); add_325 = None + mul_826 = torch.ops.aten.mul.Tensor(reciprocal_28, 1); reciprocal_28 = None + mul_827 = torch.ops.aten.mul.Tensor(convert_element_type_2596, mul_826); convert_element_type_2596 = None + sub_85 = torch.ops.aten.sub.Tensor(1, mul_826); mul_826 = None + mul_828 = torch.ops.aten.mul.Tensor(convert_element_type_125, sub_85); convert_element_type_125 = sub_85 = None + add_326 = torch.ops.aten.add.Tensor(mul_828, 1); mul_828 = None + mul_829 = torch.ops.aten.mul.Tensor(mul_827, add_326); mul_827 = add_326 = None + convert_element_type_2598 = torch.ops.prims.convert_element_type.default(mul_829, torch.bfloat16); mul_829 = None + view_1771 = torch.ops.aten.view.default(convert_element_type_2598, [16384, 14336]); convert_element_type_2598 = None + permute_1261 = torch.ops.aten.permute.default(view_1771, [1, 0]) + mm_623 = torch.ops.aten.mm.default(permute_1261, view_129); permute_1261 = view_129 = None + convert_element_type_122 = torch.ops.prims.convert_element_type.default(primals_37, torch.bfloat16); primals_37 = None + all_gather_into_tensor_34 = 
torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_122, 64, '0'); convert_element_type_122 = None + wait_tensor_34 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_34); all_gather_into_tensor_34 = None + permute_41 = torch.ops.aten.permute.default(wait_tensor_34, [1, 0]); wait_tensor_34 = None + permute_1263 = torch.ops.aten.permute.default(permute_41, [1, 0]); permute_41 = None + mm_624 = torch.ops.aten.mm.default(view_1771, permute_1263); view_1771 = permute_1263 = None + view_1772 = torch.ops.aten.view.default(mm_624, [2, 8192, 4096]); mm_624 = None + add_327 = torch.ops.aten.add.Tensor(view_1770, view_1772); view_1770 = view_1772 = None + convert_element_type_2603 = torch.ops.prims.convert_element_type.default(mm_623, torch.float32); mm_623 = None + reduce_scatter_tensor_256 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2603, 'avg', 64, '0'); convert_element_type_2603 = None + wait_tensor_547 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_256); reduce_scatter_tensor_256 = None + convert_element_type_2604 = torch.ops.prims.convert_element_type.default(add_327, torch.float32); add_327 = None + convert_element_type_2606 = torch.ops.prims.convert_element_type.default(wait_tensor_33, torch.float32); wait_tensor_33 = None + mul_830 = torch.ops.aten.mul.Tensor(convert_element_type_2604, convert_element_type_2606); convert_element_type_2606 = None + mul_832 = torch.ops.aten.mul.Tensor(mul_28, mul_830) + sum_171 = torch.ops.aten.sum.dim_IntList(mul_832, [2], True); mul_832 = None + div_57 = torch.ops.aten.div.Tensor(mul_28, 4096) + mul_833 = torch.ops.aten.mul.Tensor(div_57, sum_171); div_57 = sum_171 = None + sub_86 = torch.ops.aten.sub.Tensor(mul_830, mul_833); mul_830 = mul_833 = None + mul_834 = torch.ops.aten.mul.Tensor(sub_86, rsqrt_7); sub_86 = rsqrt_7 = None + mul_835 = torch.ops.aten.mul.Tensor(convert_element_type_2604, mul_28); 
convert_element_type_2604 = mul_28 = None + sum_172 = torch.ops.aten.sum.dim_IntList(mul_835, [0, 1]); mul_835 = None + convert_element_type_2607 = torch.ops.prims.convert_element_type.default(mul_834, torch.bfloat16); mul_834 = None + add_328 = torch.ops.aten.add.Tensor(add_324, convert_element_type_2607); add_324 = convert_element_type_2607 = None + convert_element_type_default_8 = torch.ops.prims.convert_element_type.default(sum_172, torch.float32); sum_172 = None + reduce_scatter_tensor_257 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_8, 'avg', 64, '0'); convert_element_type_default_8 = None + wait_tensor_548 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_257); reduce_scatter_tensor_257 = None + view_1773 = torch.ops.aten.view.default(add_328, [16384, 4096]) + permute_1265 = torch.ops.aten.permute.default(view_1773, [1, 0]) + mm_625 = torch.ops.aten.mm.default(permute_1265, view_125); permute_1265 = view_125 = None + permute_1267 = torch.ops.aten.permute.default(permute_40, [1, 0]); permute_40 = None + mm_626 = torch.ops.aten.mm.default(view_1773, permute_1267); view_1773 = permute_1267 = None + view_1774 = torch.ops.aten.view.default(mm_626, [2, 8192, 4096]); mm_626 = None + convert_element_type_2614 = torch.ops.prims.convert_element_type.default(mm_625, torch.float32); mm_625 = None + reduce_scatter_tensor_258 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2614, 'avg', 64, '0'); convert_element_type_2614 = None + wait_tensor_549 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_258); reduce_scatter_tensor_258 = None + view_1775 = torch.ops.aten.view.default(view_1774, [2, 8192, 32, 128]); view_1774 = None + permute_1269 = torch.ops.aten.permute.default(view_1775, [0, 2, 1, 3]); view_1775 = None + convert_element_type_100 = torch.ops.prims.convert_element_type.default(primals_31, torch.bfloat16); primals_31 = None + all_gather_into_tensor_28 
= torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_100, 64, '0'); convert_element_type_100 = None + wait_tensor_28 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_28); all_gather_into_tensor_28 = None + convert_element_type_101 = torch.ops.prims.convert_element_type.default(add_11, torch.float32); add_11 = None + pow_7 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_101, 2) + mean_6 = torch.ops.aten.mean.dim(pow_7, [2], True); pow_7 = None + add_12 = torch.ops.aten.add.Scalar(mean_6, 1e-05); mean_6 = None + rsqrt_6 = torch.ops.aten.rsqrt.default(add_12); add_12 = None + mul_24 = torch.ops.aten.mul.Tensor(convert_element_type_101, rsqrt_6); convert_element_type_101 = None + mul_25 = torch.ops.aten.mul.Tensor(mul_24, wait_tensor_28) + convert_element_type_102 = torch.ops.prims.convert_element_type.default(mul_25, torch.bfloat16); mul_25 = None + view_105 = torch.ops.aten.view.default(convert_element_type_102, [16384, 4096]); convert_element_type_102 = None + view_106 = torch.ops.aten.view.default(mm_21, [2, 8192, 4096]); mm_21 = None + convert_element_type_106 = torch.ops.prims.convert_element_type.default(primals_33, torch.bfloat16); primals_33 = None + all_gather_into_tensor_30 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_106, 64, '0'); convert_element_type_106 = None + wait_tensor_30 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_30); all_gather_into_tensor_30 = None + permute_34 = torch.ops.aten.permute.default(wait_tensor_30, [1, 0]); wait_tensor_30 = None + mm_22 = torch.ops.aten.mm.default(view_105, permute_34) + view_109 = torch.ops.aten.view.default(mm_22, [2, 8192, 1024]); mm_22 = None + view_112 = torch.ops.aten.view.default(mm_23, [2, 8192, 1024]); mm_23 = None + view_113 = torch.ops.aten.view.default(view_106, [2, 8192, -1, 128]); view_106 = None + view_114 = torch.ops.aten.view.default(view_109, [2, 8192, -1, 128]); view_109 = None 
+ view_115 = torch.ops.aten.view.default(view_112, [2, 8192, -1, 128]); view_112 = None + convert_element_type_112 = torch.ops.prims.convert_element_type.default(view_113, torch.float32); view_113 = None + view_116 = torch.ops.aten.view.default(convert_element_type_112, [2, 8192, 32, -1, 2]); convert_element_type_112 = None + view_as_complex_6 = torch.ops.aten.view_as_complex.default(view_116); view_116 = None + convert_element_type_113 = torch.ops.prims.convert_element_type.default(view_114, torch.float32); view_114 = None + view_117 = torch.ops.aten.view.default(convert_element_type_113, [2, 8192, 8, -1, 2]); convert_element_type_113 = None + view_as_complex_7 = torch.ops.aten.view_as_complex.default(view_117); view_117 = None + mul_26 = torch.ops.aten.mul.Tensor(view_as_complex_6, view_16); view_as_complex_6 = None + view_as_real_6 = torch.ops.aten.view_as_real.default(mul_26); mul_26 = None + view_119 = torch.ops.aten.view.default(view_as_real_6, [2, 8192, 32, 128]); view_as_real_6 = None + mul_27 = torch.ops.aten.mul.Tensor(view_as_complex_7, view_16); view_as_complex_7 = None + view_as_real_7 = torch.ops.aten.view_as_real.default(mul_27); mul_27 = None + view_120 = torch.ops.aten.view.default(view_as_real_7, [2, 8192, 8, 128]); view_as_real_7 = None + convert_element_type_114 = torch.ops.prims.convert_element_type.default(view_119, torch.bfloat16); view_119 = None + convert_element_type_115 = torch.ops.prims.convert_element_type.default(view_120, torch.bfloat16); view_120 = None + unsqueeze_6 = torch.ops.aten.unsqueeze.default(convert_element_type_115, 3); convert_element_type_115 = None + expand_6 = torch.ops.aten.expand.default(unsqueeze_6, [2, 8192, 8, 4, 128]); unsqueeze_6 = None + clone_6 = torch.ops.aten.clone.default(expand_6, memory_format = torch.contiguous_format); expand_6 = None + view_121 = torch.ops.aten.view.default(clone_6, [2, 8192, 32, 128]); clone_6 = None + unsqueeze_7 = torch.ops.aten.unsqueeze.default(view_115, 3); view_115 = None + 
expand_7 = torch.ops.aten.expand.default(unsqueeze_7, [2, 8192, 8, 4, 128]); unsqueeze_7 = None + clone_7 = torch.ops.aten.clone.default(expand_7, memory_format = torch.contiguous_format); expand_7 = None + view_122 = torch.ops.aten.view.default(clone_7, [2, 8192, 32, 128]); clone_7 = None + permute_36 = torch.ops.aten.permute.default(convert_element_type_114, [0, 2, 1, 3]); convert_element_type_114 = None + permute_37 = torch.ops.aten.permute.default(view_121, [0, 2, 1, 3]); view_121 = None + permute_38 = torch.ops.aten.permute.default(view_122, [0, 2, 1, 3]); view_122 = None + _scaled_dot_product_cudnn_attention_backward_28 = torch.ops.aten._scaled_dot_product_cudnn_attention_backward.default(permute_1269, permute_36, permute_37, permute_38, getitem_27, getitem_28, getitem_33, getitem_34, None, None, None, 8192, 8192, 0.0, True); permute_1269 = permute_36 = permute_37 = permute_38 = getitem_27 = getitem_28 = getitem_33 = getitem_34 = None + getitem_372 = _scaled_dot_product_cudnn_attention_backward_28[0] + getitem_373 = _scaled_dot_product_cudnn_attention_backward_28[1] + getitem_374 = _scaled_dot_product_cudnn_attention_backward_28[2]; _scaled_dot_product_cudnn_attention_backward_28 = None + permute_1270 = torch.ops.aten.permute.default(getitem_374, [0, 2, 1, 3]); getitem_374 = None + permute_1271 = torch.ops.aten.permute.default(getitem_373, [0, 2, 1, 3]); getitem_373 = None + permute_1272 = torch.ops.aten.permute.default(getitem_372, [0, 2, 1, 3]); getitem_372 = None + view_1776 = torch.ops.aten.view.default(permute_1270, [2, 8192, 8, 4, 128]); permute_1270 = None + sum_173 = torch.ops.aten.sum.dim_IntList(view_1776, [3], True); view_1776 = None + squeeze_56 = torch.ops.aten.squeeze.dim(sum_173, 3); sum_173 = None + view_1777 = torch.ops.aten.view.default(permute_1271, [2, 8192, 8, 4, 128]); permute_1271 = None + sum_174 = torch.ops.aten.sum.dim_IntList(view_1777, [3], True); view_1777 = None + squeeze_57 = torch.ops.aten.squeeze.dim(sum_174, 3); sum_174 = 
None + convert_element_type_2615 = torch.ops.prims.convert_element_type.default(squeeze_57, torch.float32); squeeze_57 = None + convert_element_type_2616 = torch.ops.prims.convert_element_type.default(permute_1272, torch.float32); permute_1272 = None + view_1778 = torch.ops.aten.view.default(convert_element_type_2615, [2, 8192, 8, 64, 2]); convert_element_type_2615 = None + view_as_complex_120 = torch.ops.aten.view_as_complex.default(view_1778); view_1778 = None + mul_836 = torch.ops.aten.mul.Tensor(view_as_complex_120, _conj); view_as_complex_120 = None + view_1779 = torch.ops.aten.view.default(convert_element_type_2616, [2, 8192, 32, 64, 2]); convert_element_type_2616 = None + view_as_complex_121 = torch.ops.aten.view_as_complex.default(view_1779); view_1779 = None + mul_837 = torch.ops.aten.mul.Tensor(view_as_complex_121, _conj); view_as_complex_121 = None + view_as_real_120 = torch.ops.aten.view_as_real.default(mul_836); mul_836 = None + view_1780 = torch.ops.aten.view.default(view_as_real_120, [2, 8192, 8, 128]); view_as_real_120 = None + convert_element_type_2617 = torch.ops.prims.convert_element_type.default(view_1780, torch.bfloat16); view_1780 = None + view_as_real_121 = torch.ops.aten.view_as_real.default(mul_837); mul_837 = None + view_1781 = torch.ops.aten.view.default(view_as_real_121, [2, 8192, 32, 128]); view_as_real_121 = None + convert_element_type_2618 = torch.ops.prims.convert_element_type.default(view_1781, torch.bfloat16); view_1781 = None + view_1782 = torch.ops.aten.view.default(squeeze_56, [2, 8192, 1024]); squeeze_56 = None + view_1783 = torch.ops.aten.view.default(convert_element_type_2617, [2, 8192, 1024]); convert_element_type_2617 = None + view_1784 = torch.ops.aten.view.default(convert_element_type_2618, [2, 8192, 4096]); convert_element_type_2618 = None + view_1785 = torch.ops.aten.view.default(view_1782, [16384, 1024]); view_1782 = None + permute_1273 = torch.ops.aten.permute.default(view_1785, [1, 0]) + mm_627 = 
torch.ops.aten.mm.default(permute_1273, view_105); permute_1273 = None + convert_element_type_109 = torch.ops.prims.convert_element_type.default(primals_34, torch.bfloat16); primals_34 = None + all_gather_into_tensor_31 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_109, 64, '0'); convert_element_type_109 = None + wait_tensor_31 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_31); all_gather_into_tensor_31 = None + permute_35 = torch.ops.aten.permute.default(wait_tensor_31, [1, 0]); wait_tensor_31 = None + permute_1275 = torch.ops.aten.permute.default(permute_35, [1, 0]); permute_35 = None + mm_628 = torch.ops.aten.mm.default(view_1785, permute_1275); view_1785 = permute_1275 = None + view_1786 = torch.ops.aten.view.default(mm_628, [2, 8192, 4096]); mm_628 = None + convert_element_type_2623 = torch.ops.prims.convert_element_type.default(mm_627, torch.float32); mm_627 = None + reduce_scatter_tensor_259 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2623, 'avg', 64, '0'); convert_element_type_2623 = None + wait_tensor_550 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_259); reduce_scatter_tensor_259 = None + view_1787 = torch.ops.aten.view.default(view_1783, [16384, 1024]); view_1783 = None + permute_1277 = torch.ops.aten.permute.default(view_1787, [1, 0]) + mm_629 = torch.ops.aten.mm.default(permute_1277, view_105); permute_1277 = None + permute_1279 = torch.ops.aten.permute.default(permute_34, [1, 0]); permute_34 = None + mm_630 = torch.ops.aten.mm.default(view_1787, permute_1279); view_1787 = permute_1279 = None + view_1788 = torch.ops.aten.view.default(mm_630, [2, 8192, 4096]); mm_630 = None + add_329 = torch.ops.aten.add.Tensor(view_1786, view_1788); view_1786 = view_1788 = None + convert_element_type_2628 = torch.ops.prims.convert_element_type.default(mm_629, torch.float32); mm_629 = None + reduce_scatter_tensor_260 = 
torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2628, 'avg', 64, '0'); convert_element_type_2628 = None + wait_tensor_551 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_260); reduce_scatter_tensor_260 = None + view_1789 = torch.ops.aten.view.default(view_1784, [16384, 4096]); view_1784 = None + permute_1281 = torch.ops.aten.permute.default(view_1789, [1, 0]) + mm_631 = torch.ops.aten.mm.default(permute_1281, view_105); permute_1281 = view_105 = None + convert_element_type_103 = torch.ops.prims.convert_element_type.default(primals_32, torch.bfloat16); primals_32 = None + all_gather_into_tensor_29 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_103, 64, '0'); convert_element_type_103 = None + wait_tensor_29 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_29); all_gather_into_tensor_29 = None + permute_33 = torch.ops.aten.permute.default(wait_tensor_29, [1, 0]); wait_tensor_29 = None + permute_1283 = torch.ops.aten.permute.default(permute_33, [1, 0]); permute_33 = None + mm_632 = torch.ops.aten.mm.default(view_1789, permute_1283); view_1789 = permute_1283 = None + view_1790 = torch.ops.aten.view.default(mm_632, [2, 8192, 4096]); mm_632 = None + add_330 = torch.ops.aten.add.Tensor(add_329, view_1790); add_329 = view_1790 = None + convert_element_type_2633 = torch.ops.prims.convert_element_type.default(mm_631, torch.float32); mm_631 = None + reduce_scatter_tensor_261 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2633, 'avg', 64, '0'); convert_element_type_2633 = None + wait_tensor_552 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_261); reduce_scatter_tensor_261 = None + convert_element_type_2634 = torch.ops.prims.convert_element_type.default(add_330, torch.float32); add_330 = None + convert_element_type_2636 = torch.ops.prims.convert_element_type.default(wait_tensor_28, torch.float32); wait_tensor_28 = 
None + mul_838 = torch.ops.aten.mul.Tensor(convert_element_type_2634, convert_element_type_2636); convert_element_type_2636 = None + mul_840 = torch.ops.aten.mul.Tensor(mul_24, mul_838) + sum_175 = torch.ops.aten.sum.dim_IntList(mul_840, [2], True); mul_840 = None + div_58 = torch.ops.aten.div.Tensor(mul_24, 4096) + mul_841 = torch.ops.aten.mul.Tensor(div_58, sum_175); div_58 = sum_175 = None + sub_87 = torch.ops.aten.sub.Tensor(mul_838, mul_841); mul_838 = mul_841 = None + mul_842 = torch.ops.aten.mul.Tensor(sub_87, rsqrt_6); sub_87 = rsqrt_6 = None + mul_843 = torch.ops.aten.mul.Tensor(convert_element_type_2634, mul_24); convert_element_type_2634 = mul_24 = None + sum_176 = torch.ops.aten.sum.dim_IntList(mul_843, [0, 1]); mul_843 = None + convert_element_type_2637 = torch.ops.prims.convert_element_type.default(mul_842, torch.bfloat16); mul_842 = None + add_331 = torch.ops.aten.add.Tensor(add_328, convert_element_type_2637); add_328 = convert_element_type_2637 = None + convert_element_type_default_7 = torch.ops.prims.convert_element_type.default(sum_176, torch.float32); sum_176 = None + reduce_scatter_tensor_262 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_7, 'avg', 64, '0'); convert_element_type_default_7 = None + wait_tensor_553 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_262); reduce_scatter_tensor_262 = None + view_1791 = torch.ops.aten.view.default(add_331, [16384, 4096]) + permute_1285 = torch.ops.aten.permute.default(view_1791, [1, 0]) + permute_28 = torch.ops.aten.permute.default(getitem_18, [0, 2, 1, 3]) + view_89 = torch.ops.aten.view.default(permute_28, [2, 8192, -1]); permute_28 = None + convert_element_type_83 = torch.ops.prims.convert_element_type.default(primals_26, torch.bfloat16); primals_26 = None + all_gather_into_tensor_23 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_83, 64, '0'); convert_element_type_83 = None + wait_tensor_23 = 
torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_23); all_gather_into_tensor_23 = None + permute_29 = torch.ops.aten.permute.default(wait_tensor_23, [1, 0]); wait_tensor_23 = None + view_91 = torch.ops.aten.view.default(view_89, [16384, 4096]); view_89 = None + mm_17 = torch.ops.aten.mm.default(view_91, permute_29) + view_92 = torch.ops.aten.view.default(mm_17, [2, 8192, 4096]); mm_17 = None + add_9 = torch.ops.aten.add.Tensor(add_7, view_92); view_92 = None + convert_element_type_86 = torch.ops.prims.convert_element_type.default(primals_27, torch.bfloat16); primals_27 = None + all_gather_into_tensor_24 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_86, 64, '0'); convert_element_type_86 = None + wait_tensor_24 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_24); all_gather_into_tensor_24 = None + convert_element_type_87 = torch.ops.prims.convert_element_type.default(add_9, torch.float32); add_9 = None + pow_6 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_87, 2) + mean_5 = torch.ops.aten.mean.dim(pow_6, [2], True); pow_6 = None + add_10 = torch.ops.aten.add.Scalar(mean_5, 1e-05); mean_5 = None + rsqrt_5 = torch.ops.aten.rsqrt.default(add_10); add_10 = None + mul_20 = torch.ops.aten.mul.Tensor(convert_element_type_87, rsqrt_5); convert_element_type_87 = None + mul_21 = torch.ops.aten.mul.Tensor(mul_20, wait_tensor_24) + convert_element_type_88 = torch.ops.prims.convert_element_type.default(mul_21, torch.bfloat16); mul_21 = None + view_95 = torch.ops.aten.view.default(convert_element_type_88, [16384, 4096]); convert_element_type_88 = None + view_96 = torch.ops.aten.view.default(mm_18, [2, 8192, 14336]); mm_18 = None + convert_element_type_92 = torch.ops.prims.convert_element_type.default(view_96, torch.float32); view_96 = None + sigmoid_2 = torch.ops.aten.sigmoid.default(convert_element_type_92) + mul_22 = torch.ops.aten.mul.Tensor(convert_element_type_92, sigmoid_2); sigmoid_2 = 
None + convert_element_type_93 = torch.ops.prims.convert_element_type.default(mul_22, torch.bfloat16); mul_22 = None + convert_element_type_94 = torch.ops.prims.convert_element_type.default(primals_29, torch.bfloat16); primals_29 = None + all_gather_into_tensor_26 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_94, 64, '0'); convert_element_type_94 = None + wait_tensor_26 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_26); all_gather_into_tensor_26 = None + permute_31 = torch.ops.aten.permute.default(wait_tensor_26, [1, 0]); wait_tensor_26 = None + mm_19 = torch.ops.aten.mm.default(view_95, permute_31) + view_99 = torch.ops.aten.view.default(mm_19, [2, 8192, 14336]); mm_19 = None + mul_23 = torch.ops.aten.mul.Tensor(convert_element_type_93, view_99) + view_101 = torch.ops.aten.view.default(mul_23, [16384, 14336]); mul_23 = None + mm_633 = torch.ops.aten.mm.default(permute_1285, view_101); permute_1285 = view_101 = None + convert_element_type_97 = torch.ops.prims.convert_element_type.default(primals_30, torch.bfloat16); primals_30 = None + all_gather_into_tensor_27 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_97, 64, '0'); convert_element_type_97 = None + wait_tensor_27 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_27); all_gather_into_tensor_27 = None + permute_32 = torch.ops.aten.permute.default(wait_tensor_27, [1, 0]); wait_tensor_27 = None + permute_1287 = torch.ops.aten.permute.default(permute_32, [1, 0]); permute_32 = None + mm_634 = torch.ops.aten.mm.default(view_1791, permute_1287); view_1791 = permute_1287 = None + view_1792 = torch.ops.aten.view.default(mm_634, [2, 8192, 14336]); mm_634 = None + convert_element_type_2644 = torch.ops.prims.convert_element_type.default(mm_633, torch.float32); mm_633 = None + reduce_scatter_tensor_263 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2644, 'avg', 64, '0'); 
convert_element_type_2644 = None + wait_tensor_554 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_263); reduce_scatter_tensor_263 = None + mul_844 = torch.ops.aten.mul.Tensor(view_1792, convert_element_type_93); convert_element_type_93 = None + mul_845 = torch.ops.aten.mul.Tensor(view_1792, view_99); view_1792 = view_99 = None + view_1793 = torch.ops.aten.view.default(mul_844, [16384, 14336]); mul_844 = None + permute_1289 = torch.ops.aten.permute.default(view_1793, [1, 0]) + mm_635 = torch.ops.aten.mm.default(permute_1289, view_95); permute_1289 = None + permute_1291 = torch.ops.aten.permute.default(permute_31, [1, 0]); permute_31 = None + mm_636 = torch.ops.aten.mm.default(view_1793, permute_1291); view_1793 = permute_1291 = None + view_1794 = torch.ops.aten.view.default(mm_636, [2, 8192, 4096]); mm_636 = None + convert_element_type_2649 = torch.ops.prims.convert_element_type.default(mm_635, torch.float32); mm_635 = None + reduce_scatter_tensor_264 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2649, 'avg', 64, '0'); convert_element_type_2649 = None + wait_tensor_555 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_264); reduce_scatter_tensor_264 = None + convert_element_type_2650 = torch.ops.prims.convert_element_type.default(mul_845, torch.float32); mul_845 = None + neg_29 = torch.ops.aten.neg.default(convert_element_type_92) + exp_29 = torch.ops.aten.exp.default(neg_29); neg_29 = None + add_332 = torch.ops.aten.add.Tensor(exp_29, 1); exp_29 = None + reciprocal_29 = torch.ops.aten.reciprocal.default(add_332); add_332 = None + mul_846 = torch.ops.aten.mul.Tensor(reciprocal_29, 1); reciprocal_29 = None + mul_847 = torch.ops.aten.mul.Tensor(convert_element_type_2650, mul_846); convert_element_type_2650 = None + sub_88 = torch.ops.aten.sub.Tensor(1, mul_846); mul_846 = None + mul_848 = torch.ops.aten.mul.Tensor(convert_element_type_92, sub_88); convert_element_type_92 = sub_88 = None + 
add_333 = torch.ops.aten.add.Tensor(mul_848, 1); mul_848 = None + mul_849 = torch.ops.aten.mul.Tensor(mul_847, add_333); mul_847 = add_333 = None + convert_element_type_2652 = torch.ops.prims.convert_element_type.default(mul_849, torch.bfloat16); mul_849 = None + view_1795 = torch.ops.aten.view.default(convert_element_type_2652, [16384, 14336]); convert_element_type_2652 = None + permute_1293 = torch.ops.aten.permute.default(view_1795, [1, 0]) + mm_637 = torch.ops.aten.mm.default(permute_1293, view_95); permute_1293 = view_95 = None + convert_element_type_89 = torch.ops.prims.convert_element_type.default(primals_28, torch.bfloat16); primals_28 = None + all_gather_into_tensor_25 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_89, 64, '0'); convert_element_type_89 = None + wait_tensor_25 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_25); all_gather_into_tensor_25 = None + permute_30 = torch.ops.aten.permute.default(wait_tensor_25, [1, 0]); wait_tensor_25 = None + permute_1295 = torch.ops.aten.permute.default(permute_30, [1, 0]); permute_30 = None + mm_638 = torch.ops.aten.mm.default(view_1795, permute_1295); view_1795 = permute_1295 = None + view_1796 = torch.ops.aten.view.default(mm_638, [2, 8192, 4096]); mm_638 = None + add_334 = torch.ops.aten.add.Tensor(view_1794, view_1796); view_1794 = view_1796 = None + convert_element_type_2657 = torch.ops.prims.convert_element_type.default(mm_637, torch.float32); mm_637 = None + reduce_scatter_tensor_265 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2657, 'avg', 64, '0'); convert_element_type_2657 = None + wait_tensor_556 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_265); reduce_scatter_tensor_265 = None + convert_element_type_2658 = torch.ops.prims.convert_element_type.default(add_334, torch.float32); add_334 = None + convert_element_type_2660 = torch.ops.prims.convert_element_type.default(wait_tensor_24, 
torch.float32); wait_tensor_24 = None + mul_850 = torch.ops.aten.mul.Tensor(convert_element_type_2658, convert_element_type_2660); convert_element_type_2660 = None + mul_852 = torch.ops.aten.mul.Tensor(mul_20, mul_850) + sum_177 = torch.ops.aten.sum.dim_IntList(mul_852, [2], True); mul_852 = None + div_59 = torch.ops.aten.div.Tensor(mul_20, 4096) + mul_853 = torch.ops.aten.mul.Tensor(div_59, sum_177); div_59 = sum_177 = None + sub_89 = torch.ops.aten.sub.Tensor(mul_850, mul_853); mul_850 = mul_853 = None + mul_854 = torch.ops.aten.mul.Tensor(sub_89, rsqrt_5); sub_89 = rsqrt_5 = None + mul_855 = torch.ops.aten.mul.Tensor(convert_element_type_2658, mul_20); convert_element_type_2658 = mul_20 = None + sum_178 = torch.ops.aten.sum.dim_IntList(mul_855, [0, 1]); mul_855 = None + convert_element_type_2661 = torch.ops.prims.convert_element_type.default(mul_854, torch.bfloat16); mul_854 = None + add_335 = torch.ops.aten.add.Tensor(add_331, convert_element_type_2661); add_331 = convert_element_type_2661 = None + convert_element_type_default_6 = torch.ops.prims.convert_element_type.default(sum_178, torch.float32); sum_178 = None + reduce_scatter_tensor_266 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_6, 'avg', 64, '0'); convert_element_type_default_6 = None + wait_tensor_557 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_266); reduce_scatter_tensor_266 = None + view_1797 = torch.ops.aten.view.default(add_335, [16384, 4096]) + permute_1297 = torch.ops.aten.permute.default(view_1797, [1, 0]) + mm_639 = torch.ops.aten.mm.default(permute_1297, view_91); permute_1297 = view_91 = None + permute_1299 = torch.ops.aten.permute.default(permute_29, [1, 0]); permute_29 = None + mm_640 = torch.ops.aten.mm.default(view_1797, permute_1299); view_1797 = permute_1299 = None + view_1798 = torch.ops.aten.view.default(mm_640, [2, 8192, 4096]); mm_640 = None + convert_element_type_2668 = 
torch.ops.prims.convert_element_type.default(mm_639, torch.float32); mm_639 = None + reduce_scatter_tensor_267 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2668, 'avg', 64, '0'); convert_element_type_2668 = None + wait_tensor_558 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_267); reduce_scatter_tensor_267 = None + view_1799 = torch.ops.aten.view.default(view_1798, [2, 8192, 32, 128]); view_1798 = None + permute_1301 = torch.ops.aten.permute.default(view_1799, [0, 2, 1, 3]); view_1799 = None + convert_element_type_67 = torch.ops.prims.convert_element_type.default(primals_22, torch.bfloat16); primals_22 = None + all_gather_into_tensor_19 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_67, 64, '0'); convert_element_type_67 = None + wait_tensor_19 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_19); all_gather_into_tensor_19 = None + convert_element_type_68 = torch.ops.prims.convert_element_type.default(add_7, torch.float32); add_7 = None + pow_5 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_68, 2) + mean_4 = torch.ops.aten.mean.dim(pow_5, [2], True); pow_5 = None + add_8 = torch.ops.aten.add.Scalar(mean_4, 1e-05); mean_4 = None + rsqrt_4 = torch.ops.aten.rsqrt.default(add_8); add_8 = None + mul_16 = torch.ops.aten.mul.Tensor(convert_element_type_68, rsqrt_4); convert_element_type_68 = None + mul_17 = torch.ops.aten.mul.Tensor(mul_16, wait_tensor_19) + convert_element_type_69 = torch.ops.prims.convert_element_type.default(mul_17, torch.bfloat16); mul_17 = None + view_71 = torch.ops.aten.view.default(convert_element_type_69, [16384, 4096]); convert_element_type_69 = None + view_72 = torch.ops.aten.view.default(mm_14, [2, 8192, 4096]); mm_14 = None + convert_element_type_73 = torch.ops.prims.convert_element_type.default(primals_24, torch.bfloat16); primals_24 = None + all_gather_into_tensor_21 = 
torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_73, 64, '0'); convert_element_type_73 = None + wait_tensor_21 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_21); all_gather_into_tensor_21 = None + permute_23 = torch.ops.aten.permute.default(wait_tensor_21, [1, 0]); wait_tensor_21 = None + mm_15 = torch.ops.aten.mm.default(view_71, permute_23) + view_75 = torch.ops.aten.view.default(mm_15, [2, 8192, 1024]); mm_15 = None + view_78 = torch.ops.aten.view.default(mm_16, [2, 8192, 1024]); mm_16 = None + view_79 = torch.ops.aten.view.default(view_72, [2, 8192, -1, 128]); view_72 = None + view_80 = torch.ops.aten.view.default(view_75, [2, 8192, -1, 128]); view_75 = None + view_81 = torch.ops.aten.view.default(view_78, [2, 8192, -1, 128]); view_78 = None + convert_element_type_79 = torch.ops.prims.convert_element_type.default(view_79, torch.float32); view_79 = None + view_82 = torch.ops.aten.view.default(convert_element_type_79, [2, 8192, 32, -1, 2]); convert_element_type_79 = None + view_as_complex_4 = torch.ops.aten.view_as_complex.default(view_82); view_82 = None + convert_element_type_80 = torch.ops.prims.convert_element_type.default(view_80, torch.float32); view_80 = None + view_83 = torch.ops.aten.view.default(convert_element_type_80, [2, 8192, 8, -1, 2]); convert_element_type_80 = None + view_as_complex_5 = torch.ops.aten.view_as_complex.default(view_83); view_83 = None + mul_18 = torch.ops.aten.mul.Tensor(view_as_complex_4, view_16); view_as_complex_4 = None + view_as_real_4 = torch.ops.aten.view_as_real.default(mul_18); mul_18 = None + view_85 = torch.ops.aten.view.default(view_as_real_4, [2, 8192, 32, 128]); view_as_real_4 = None + mul_19 = torch.ops.aten.mul.Tensor(view_as_complex_5, view_16); view_as_complex_5 = None + view_as_real_5 = torch.ops.aten.view_as_real.default(mul_19); mul_19 = None + view_86 = torch.ops.aten.view.default(view_as_real_5, [2, 8192, 8, 128]); view_as_real_5 = None + 
convert_element_type_81 = torch.ops.prims.convert_element_type.default(view_85, torch.bfloat16); view_85 = None + convert_element_type_82 = torch.ops.prims.convert_element_type.default(view_86, torch.bfloat16); view_86 = None + unsqueeze_4 = torch.ops.aten.unsqueeze.default(convert_element_type_82, 3); convert_element_type_82 = None + expand_4 = torch.ops.aten.expand.default(unsqueeze_4, [2, 8192, 8, 4, 128]); unsqueeze_4 = None + clone_4 = torch.ops.aten.clone.default(expand_4, memory_format = torch.contiguous_format); expand_4 = None + view_87 = torch.ops.aten.view.default(clone_4, [2, 8192, 32, 128]); clone_4 = None + unsqueeze_5 = torch.ops.aten.unsqueeze.default(view_81, 3); view_81 = None + expand_5 = torch.ops.aten.expand.default(unsqueeze_5, [2, 8192, 8, 4, 128]); unsqueeze_5 = None + clone_5 = torch.ops.aten.clone.default(expand_5, memory_format = torch.contiguous_format); expand_5 = None + view_88 = torch.ops.aten.view.default(clone_5, [2, 8192, 32, 128]); clone_5 = None + permute_25 = torch.ops.aten.permute.default(convert_element_type_81, [0, 2, 1, 3]); convert_element_type_81 = None + permute_26 = torch.ops.aten.permute.default(view_87, [0, 2, 1, 3]); view_87 = None + permute_27 = torch.ops.aten.permute.default(view_88, [0, 2, 1, 3]); view_88 = None + _scaled_dot_product_cudnn_attention_backward_29 = torch.ops.aten._scaled_dot_product_cudnn_attention_backward.default(permute_1301, permute_25, permute_26, permute_27, getitem_18, getitem_19, getitem_24, getitem_25, None, None, None, 8192, 8192, 0.0, True); permute_1301 = permute_25 = permute_26 = permute_27 = getitem_18 = getitem_19 = getitem_24 = getitem_25 = None + getitem_375 = _scaled_dot_product_cudnn_attention_backward_29[0] + getitem_376 = _scaled_dot_product_cudnn_attention_backward_29[1] + getitem_377 = _scaled_dot_product_cudnn_attention_backward_29[2]; _scaled_dot_product_cudnn_attention_backward_29 = None + permute_1302 = torch.ops.aten.permute.default(getitem_377, [0, 2, 1, 3]); getitem_377 
= None + permute_1303 = torch.ops.aten.permute.default(getitem_376, [0, 2, 1, 3]); getitem_376 = None + permute_1304 = torch.ops.aten.permute.default(getitem_375, [0, 2, 1, 3]); getitem_375 = None + view_1800 = torch.ops.aten.view.default(permute_1302, [2, 8192, 8, 4, 128]); permute_1302 = None + sum_179 = torch.ops.aten.sum.dim_IntList(view_1800, [3], True); view_1800 = None + squeeze_58 = torch.ops.aten.squeeze.dim(sum_179, 3); sum_179 = None + view_1801 = torch.ops.aten.view.default(permute_1303, [2, 8192, 8, 4, 128]); permute_1303 = None + sum_180 = torch.ops.aten.sum.dim_IntList(view_1801, [3], True); view_1801 = None + squeeze_59 = torch.ops.aten.squeeze.dim(sum_180, 3); sum_180 = None + convert_element_type_2669 = torch.ops.prims.convert_element_type.default(squeeze_59, torch.float32); squeeze_59 = None + convert_element_type_2670 = torch.ops.prims.convert_element_type.default(permute_1304, torch.float32); permute_1304 = None + view_1802 = torch.ops.aten.view.default(convert_element_type_2669, [2, 8192, 8, 64, 2]); convert_element_type_2669 = None + view_as_complex_122 = torch.ops.aten.view_as_complex.default(view_1802); view_1802 = None + mul_856 = torch.ops.aten.mul.Tensor(view_as_complex_122, _conj); view_as_complex_122 = None + view_1803 = torch.ops.aten.view.default(convert_element_type_2670, [2, 8192, 32, 64, 2]); convert_element_type_2670 = None + view_as_complex_123 = torch.ops.aten.view_as_complex.default(view_1803); view_1803 = None + mul_857 = torch.ops.aten.mul.Tensor(view_as_complex_123, _conj); view_as_complex_123 = None + view_as_real_122 = torch.ops.aten.view_as_real.default(mul_856); mul_856 = None + view_1804 = torch.ops.aten.view.default(view_as_real_122, [2, 8192, 8, 128]); view_as_real_122 = None + convert_element_type_2671 = torch.ops.prims.convert_element_type.default(view_1804, torch.bfloat16); view_1804 = None + view_as_real_123 = torch.ops.aten.view_as_real.default(mul_857); mul_857 = None + view_1805 = 
torch.ops.aten.view.default(view_as_real_123, [2, 8192, 32, 128]); view_as_real_123 = None + convert_element_type_2672 = torch.ops.prims.convert_element_type.default(view_1805, torch.bfloat16); view_1805 = None + view_1806 = torch.ops.aten.view.default(squeeze_58, [2, 8192, 1024]); squeeze_58 = None + view_1807 = torch.ops.aten.view.default(convert_element_type_2671, [2, 8192, 1024]); convert_element_type_2671 = None + view_1808 = torch.ops.aten.view.default(convert_element_type_2672, [2, 8192, 4096]); convert_element_type_2672 = None + view_1809 = torch.ops.aten.view.default(view_1806, [16384, 1024]); view_1806 = None + permute_1305 = torch.ops.aten.permute.default(view_1809, [1, 0]) + mm_641 = torch.ops.aten.mm.default(permute_1305, view_71); permute_1305 = None + convert_element_type_76 = torch.ops.prims.convert_element_type.default(primals_25, torch.bfloat16); primals_25 = None + all_gather_into_tensor_22 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_76, 64, '0'); convert_element_type_76 = None + wait_tensor_22 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_22); all_gather_into_tensor_22 = None + permute_24 = torch.ops.aten.permute.default(wait_tensor_22, [1, 0]); wait_tensor_22 = None + permute_1307 = torch.ops.aten.permute.default(permute_24, [1, 0]); permute_24 = None + mm_642 = torch.ops.aten.mm.default(view_1809, permute_1307); view_1809 = permute_1307 = None + view_1810 = torch.ops.aten.view.default(mm_642, [2, 8192, 4096]); mm_642 = None + convert_element_type_2677 = torch.ops.prims.convert_element_type.default(mm_641, torch.float32); mm_641 = None + reduce_scatter_tensor_268 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2677, 'avg', 64, '0'); convert_element_type_2677 = None + wait_tensor_559 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_268); reduce_scatter_tensor_268 = None + view_1811 = torch.ops.aten.view.default(view_1807, [16384, 
1024]); view_1807 = None + permute_1309 = torch.ops.aten.permute.default(view_1811, [1, 0]) + mm_643 = torch.ops.aten.mm.default(permute_1309, view_71); permute_1309 = None + permute_1311 = torch.ops.aten.permute.default(permute_23, [1, 0]); permute_23 = None + mm_644 = torch.ops.aten.mm.default(view_1811, permute_1311); view_1811 = permute_1311 = None + view_1812 = torch.ops.aten.view.default(mm_644, [2, 8192, 4096]); mm_644 = None + add_336 = torch.ops.aten.add.Tensor(view_1810, view_1812); view_1810 = view_1812 = None + convert_element_type_2682 = torch.ops.prims.convert_element_type.default(mm_643, torch.float32); mm_643 = None + reduce_scatter_tensor_269 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2682, 'avg', 64, '0'); convert_element_type_2682 = None + wait_tensor_560 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_269); reduce_scatter_tensor_269 = None + view_1813 = torch.ops.aten.view.default(view_1808, [16384, 4096]); view_1808 = None + permute_1313 = torch.ops.aten.permute.default(view_1813, [1, 0]) + mm_645 = torch.ops.aten.mm.default(permute_1313, view_71); permute_1313 = view_71 = None + convert_element_type_70 = torch.ops.prims.convert_element_type.default(primals_23, torch.bfloat16); primals_23 = None + all_gather_into_tensor_20 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_70, 64, '0'); convert_element_type_70 = None + wait_tensor_20 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_20); all_gather_into_tensor_20 = None + permute_22 = torch.ops.aten.permute.default(wait_tensor_20, [1, 0]); wait_tensor_20 = None + permute_1315 = torch.ops.aten.permute.default(permute_22, [1, 0]); permute_22 = None + mm_646 = torch.ops.aten.mm.default(view_1813, permute_1315); view_1813 = permute_1315 = None + view_1814 = torch.ops.aten.view.default(mm_646, [2, 8192, 4096]); mm_646 = None + add_337 = torch.ops.aten.add.Tensor(add_336, view_1814); 
add_336 = view_1814 = None + convert_element_type_2687 = torch.ops.prims.convert_element_type.default(mm_645, torch.float32); mm_645 = None + reduce_scatter_tensor_270 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2687, 'avg', 64, '0'); convert_element_type_2687 = None + wait_tensor_561 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_270); reduce_scatter_tensor_270 = None + convert_element_type_2688 = torch.ops.prims.convert_element_type.default(add_337, torch.float32); add_337 = None + convert_element_type_2690 = torch.ops.prims.convert_element_type.default(wait_tensor_19, torch.float32); wait_tensor_19 = None + mul_858 = torch.ops.aten.mul.Tensor(convert_element_type_2688, convert_element_type_2690); convert_element_type_2690 = None + mul_860 = torch.ops.aten.mul.Tensor(mul_16, mul_858) + sum_181 = torch.ops.aten.sum.dim_IntList(mul_860, [2], True); mul_860 = None + div_60 = torch.ops.aten.div.Tensor(mul_16, 4096) + mul_861 = torch.ops.aten.mul.Tensor(div_60, sum_181); div_60 = sum_181 = None + sub_90 = torch.ops.aten.sub.Tensor(mul_858, mul_861); mul_858 = mul_861 = None + mul_862 = torch.ops.aten.mul.Tensor(sub_90, rsqrt_4); sub_90 = rsqrt_4 = None + mul_863 = torch.ops.aten.mul.Tensor(convert_element_type_2688, mul_16); convert_element_type_2688 = mul_16 = None + sum_182 = torch.ops.aten.sum.dim_IntList(mul_863, [0, 1]); mul_863 = None + convert_element_type_2691 = torch.ops.prims.convert_element_type.default(mul_862, torch.bfloat16); mul_862 = None + add_338 = torch.ops.aten.add.Tensor(add_335, convert_element_type_2691); add_335 = convert_element_type_2691 = None + convert_element_type_default_5 = torch.ops.prims.convert_element_type.default(sum_182, torch.float32); sum_182 = None + reduce_scatter_tensor_271 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_5, 'avg', 64, '0'); convert_element_type_default_5 = None + wait_tensor_562 = 
torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_271); reduce_scatter_tensor_271 = None + view_1815 = torch.ops.aten.view.default(add_338, [16384, 4096]) + permute_1317 = torch.ops.aten.permute.default(view_1815, [1, 0]) + permute_17 = torch.ops.aten.permute.default(getitem_9, [0, 2, 1, 3]) + view_55 = torch.ops.aten.view.default(permute_17, [2, 8192, -1]); permute_17 = None + convert_element_type_50 = torch.ops.prims.convert_element_type.default(primals_17, torch.bfloat16); primals_17 = None + all_gather_into_tensor_14 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_50, 64, '0'); convert_element_type_50 = None + wait_tensor_14 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_14); all_gather_into_tensor_14 = None + permute_18 = torch.ops.aten.permute.default(wait_tensor_14, [1, 0]); wait_tensor_14 = None + view_57 = torch.ops.aten.view.default(view_55, [16384, 4096]); view_55 = None + mm_10 = torch.ops.aten.mm.default(view_57, permute_18) + view_58 = torch.ops.aten.view.default(mm_10, [2, 8192, 4096]); mm_10 = None + add_5 = torch.ops.aten.add.Tensor(add_3, view_58); view_58 = None + convert_element_type_53 = torch.ops.prims.convert_element_type.default(primals_18, torch.bfloat16); primals_18 = None + all_gather_into_tensor_15 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_53, 64, '0'); convert_element_type_53 = None + wait_tensor_15 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_15); all_gather_into_tensor_15 = None + convert_element_type_54 = torch.ops.prims.convert_element_type.default(add_5, torch.float32); add_5 = None + pow_4 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_54, 2) + mean_3 = torch.ops.aten.mean.dim(pow_4, [2], True); pow_4 = None + add_6 = torch.ops.aten.add.Scalar(mean_3, 1e-05); mean_3 = None + rsqrt_3 = torch.ops.aten.rsqrt.default(add_6); add_6 = None + mul_12 = 
torch.ops.aten.mul.Tensor(convert_element_type_54, rsqrt_3); convert_element_type_54 = None + mul_13 = torch.ops.aten.mul.Tensor(mul_12, wait_tensor_15) + convert_element_type_55 = torch.ops.prims.convert_element_type.default(mul_13, torch.bfloat16); mul_13 = None + view_61 = torch.ops.aten.view.default(convert_element_type_55, [16384, 4096]); convert_element_type_55 = None + view_62 = torch.ops.aten.view.default(mm_11, [2, 8192, 14336]); mm_11 = None + convert_element_type_59 = torch.ops.prims.convert_element_type.default(view_62, torch.float32); view_62 = None + sigmoid_1 = torch.ops.aten.sigmoid.default(convert_element_type_59) + mul_14 = torch.ops.aten.mul.Tensor(convert_element_type_59, sigmoid_1); sigmoid_1 = None + convert_element_type_60 = torch.ops.prims.convert_element_type.default(mul_14, torch.bfloat16); mul_14 = None + convert_element_type_61 = torch.ops.prims.convert_element_type.default(primals_20, torch.bfloat16); primals_20 = None + all_gather_into_tensor_17 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_61, 64, '0'); convert_element_type_61 = None + wait_tensor_17 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_17); all_gather_into_tensor_17 = None + permute_20 = torch.ops.aten.permute.default(wait_tensor_17, [1, 0]); wait_tensor_17 = None + mm_12 = torch.ops.aten.mm.default(view_61, permute_20) + view_65 = torch.ops.aten.view.default(mm_12, [2, 8192, 14336]); mm_12 = None + mul_15 = torch.ops.aten.mul.Tensor(convert_element_type_60, view_65) + view_67 = torch.ops.aten.view.default(mul_15, [16384, 14336]); mul_15 = None + mm_647 = torch.ops.aten.mm.default(permute_1317, view_67); permute_1317 = view_67 = None + convert_element_type_64 = torch.ops.prims.convert_element_type.default(primals_21, torch.bfloat16); primals_21 = None + all_gather_into_tensor_18 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_64, 64, '0'); convert_element_type_64 = None + 
wait_tensor_18 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_18); all_gather_into_tensor_18 = None + permute_21 = torch.ops.aten.permute.default(wait_tensor_18, [1, 0]); wait_tensor_18 = None + permute_1319 = torch.ops.aten.permute.default(permute_21, [1, 0]); permute_21 = None + mm_648 = torch.ops.aten.mm.default(view_1815, permute_1319); view_1815 = permute_1319 = None + view_1816 = torch.ops.aten.view.default(mm_648, [2, 8192, 14336]); mm_648 = None + convert_element_type_2698 = torch.ops.prims.convert_element_type.default(mm_647, torch.float32); mm_647 = None + reduce_scatter_tensor_272 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2698, 'avg', 64, '0'); convert_element_type_2698 = None + wait_tensor_563 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_272); reduce_scatter_tensor_272 = None + mul_864 = torch.ops.aten.mul.Tensor(view_1816, convert_element_type_60); convert_element_type_60 = None + mul_865 = torch.ops.aten.mul.Tensor(view_1816, view_65); view_1816 = view_65 = None + view_1817 = torch.ops.aten.view.default(mul_864, [16384, 14336]); mul_864 = None + permute_1321 = torch.ops.aten.permute.default(view_1817, [1, 0]) + mm_649 = torch.ops.aten.mm.default(permute_1321, view_61); permute_1321 = None + permute_1323 = torch.ops.aten.permute.default(permute_20, [1, 0]); permute_20 = None + mm_650 = torch.ops.aten.mm.default(view_1817, permute_1323); view_1817 = permute_1323 = None + view_1818 = torch.ops.aten.view.default(mm_650, [2, 8192, 4096]); mm_650 = None + convert_element_type_2703 = torch.ops.prims.convert_element_type.default(mm_649, torch.float32); mm_649 = None + reduce_scatter_tensor_273 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2703, 'avg', 64, '0'); convert_element_type_2703 = None + wait_tensor_564 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_273); reduce_scatter_tensor_273 = None + 
convert_element_type_2704 = torch.ops.prims.convert_element_type.default(mul_865, torch.float32); mul_865 = None + neg_30 = torch.ops.aten.neg.default(convert_element_type_59) + exp_30 = torch.ops.aten.exp.default(neg_30); neg_30 = None + add_339 = torch.ops.aten.add.Tensor(exp_30, 1); exp_30 = None + reciprocal_30 = torch.ops.aten.reciprocal.default(add_339); add_339 = None + mul_866 = torch.ops.aten.mul.Tensor(reciprocal_30, 1); reciprocal_30 = None + mul_867 = torch.ops.aten.mul.Tensor(convert_element_type_2704, mul_866); convert_element_type_2704 = None + sub_91 = torch.ops.aten.sub.Tensor(1, mul_866); mul_866 = None + mul_868 = torch.ops.aten.mul.Tensor(convert_element_type_59, sub_91); convert_element_type_59 = sub_91 = None + add_340 = torch.ops.aten.add.Tensor(mul_868, 1); mul_868 = None + mul_869 = torch.ops.aten.mul.Tensor(mul_867, add_340); mul_867 = add_340 = None + convert_element_type_2706 = torch.ops.prims.convert_element_type.default(mul_869, torch.bfloat16); mul_869 = None + view_1819 = torch.ops.aten.view.default(convert_element_type_2706, [16384, 14336]); convert_element_type_2706 = None + permute_1325 = torch.ops.aten.permute.default(view_1819, [1, 0]) + mm_651 = torch.ops.aten.mm.default(permute_1325, view_61); permute_1325 = view_61 = None + convert_element_type_56 = torch.ops.prims.convert_element_type.default(primals_19, torch.bfloat16); primals_19 = None + all_gather_into_tensor_16 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_56, 64, '0'); convert_element_type_56 = None + wait_tensor_16 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_16); all_gather_into_tensor_16 = None + permute_19 = torch.ops.aten.permute.default(wait_tensor_16, [1, 0]); wait_tensor_16 = None + permute_1327 = torch.ops.aten.permute.default(permute_19, [1, 0]); permute_19 = None + mm_652 = torch.ops.aten.mm.default(view_1819, permute_1327); view_1819 = permute_1327 = None + view_1820 = 
torch.ops.aten.view.default(mm_652, [2, 8192, 4096]); mm_652 = None + add_341 = torch.ops.aten.add.Tensor(view_1818, view_1820); view_1818 = view_1820 = None + convert_element_type_2711 = torch.ops.prims.convert_element_type.default(mm_651, torch.float32); mm_651 = None + reduce_scatter_tensor_274 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2711, 'avg', 64, '0'); convert_element_type_2711 = None + wait_tensor_565 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_274); reduce_scatter_tensor_274 = None + convert_element_type_2712 = torch.ops.prims.convert_element_type.default(add_341, torch.float32); add_341 = None + convert_element_type_2714 = torch.ops.prims.convert_element_type.default(wait_tensor_15, torch.float32); wait_tensor_15 = None + mul_870 = torch.ops.aten.mul.Tensor(convert_element_type_2712, convert_element_type_2714); convert_element_type_2714 = None + mul_872 = torch.ops.aten.mul.Tensor(mul_12, mul_870) + sum_183 = torch.ops.aten.sum.dim_IntList(mul_872, [2], True); mul_872 = None + div_61 = torch.ops.aten.div.Tensor(mul_12, 4096) + mul_873 = torch.ops.aten.mul.Tensor(div_61, sum_183); div_61 = sum_183 = None + sub_92 = torch.ops.aten.sub.Tensor(mul_870, mul_873); mul_870 = mul_873 = None + mul_874 = torch.ops.aten.mul.Tensor(sub_92, rsqrt_3); sub_92 = rsqrt_3 = None + mul_875 = torch.ops.aten.mul.Tensor(convert_element_type_2712, mul_12); convert_element_type_2712 = mul_12 = None + sum_184 = torch.ops.aten.sum.dim_IntList(mul_875, [0, 1]); mul_875 = None + convert_element_type_2715 = torch.ops.prims.convert_element_type.default(mul_874, torch.bfloat16); mul_874 = None + add_342 = torch.ops.aten.add.Tensor(add_338, convert_element_type_2715); add_338 = convert_element_type_2715 = None + convert_element_type_default_4 = torch.ops.prims.convert_element_type.default(sum_184, torch.float32); sum_184 = None + reduce_scatter_tensor_275 = 
torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_4, 'avg', 64, '0'); convert_element_type_default_4 = None + wait_tensor_566 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_275); reduce_scatter_tensor_275 = None + view_1821 = torch.ops.aten.view.default(add_342, [16384, 4096]) + permute_1329 = torch.ops.aten.permute.default(view_1821, [1, 0]) + mm_653 = torch.ops.aten.mm.default(permute_1329, view_57); permute_1329 = view_57 = None + permute_1331 = torch.ops.aten.permute.default(permute_18, [1, 0]); permute_18 = None + mm_654 = torch.ops.aten.mm.default(view_1821, permute_1331); view_1821 = permute_1331 = None + view_1822 = torch.ops.aten.view.default(mm_654, [2, 8192, 4096]); mm_654 = None + convert_element_type_2722 = torch.ops.prims.convert_element_type.default(mm_653, torch.float32); mm_653 = None + reduce_scatter_tensor_276 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2722, 'avg', 64, '0'); convert_element_type_2722 = None + wait_tensor_567 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_276); reduce_scatter_tensor_276 = None + view_1823 = torch.ops.aten.view.default(view_1822, [2, 8192, 32, 128]); view_1822 = None + permute_1333 = torch.ops.aten.permute.default(view_1823, [0, 2, 1, 3]); view_1823 = None + convert_element_type_34 = torch.ops.prims.convert_element_type.default(primals_13, torch.bfloat16); primals_13 = None + all_gather_into_tensor_10 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_34, 64, '0'); convert_element_type_34 = None + wait_tensor_10 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_10); all_gather_into_tensor_10 = None + convert_element_type_35 = torch.ops.prims.convert_element_type.default(add_3, torch.float32); add_3 = None + pow_3 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_35, 2) + mean_2 = torch.ops.aten.mean.dim(pow_3, [2], True); pow_3 = None + 
add_4 = torch.ops.aten.add.Scalar(mean_2, 1e-05); mean_2 = None + rsqrt_2 = torch.ops.aten.rsqrt.default(add_4); add_4 = None + mul_8 = torch.ops.aten.mul.Tensor(convert_element_type_35, rsqrt_2); convert_element_type_35 = None + mul_9 = torch.ops.aten.mul.Tensor(mul_8, wait_tensor_10) + convert_element_type_36 = torch.ops.prims.convert_element_type.default(mul_9, torch.bfloat16); mul_9 = None + view_37 = torch.ops.aten.view.default(convert_element_type_36, [16384, 4096]); convert_element_type_36 = None + view_38 = torch.ops.aten.view.default(mm_7, [2, 8192, 4096]); mm_7 = None + convert_element_type_40 = torch.ops.prims.convert_element_type.default(primals_15, torch.bfloat16); primals_15 = None + all_gather_into_tensor_12 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_40, 64, '0'); convert_element_type_40 = None + wait_tensor_12 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_12); all_gather_into_tensor_12 = None + permute_12 = torch.ops.aten.permute.default(wait_tensor_12, [1, 0]); wait_tensor_12 = None + mm_8 = torch.ops.aten.mm.default(view_37, permute_12) + view_41 = torch.ops.aten.view.default(mm_8, [2, 8192, 1024]); mm_8 = None + view_44 = torch.ops.aten.view.default(mm_9, [2, 8192, 1024]); mm_9 = None + view_45 = torch.ops.aten.view.default(view_38, [2, 8192, -1, 128]); view_38 = None + view_46 = torch.ops.aten.view.default(view_41, [2, 8192, -1, 128]); view_41 = None + view_47 = torch.ops.aten.view.default(view_44, [2, 8192, -1, 128]); view_44 = None + convert_element_type_46 = torch.ops.prims.convert_element_type.default(view_45, torch.float32); view_45 = None + view_48 = torch.ops.aten.view.default(convert_element_type_46, [2, 8192, 32, -1, 2]); convert_element_type_46 = None + view_as_complex_2 = torch.ops.aten.view_as_complex.default(view_48); view_48 = None + convert_element_type_47 = torch.ops.prims.convert_element_type.default(view_46, torch.float32); view_46 = None + view_49 = 
torch.ops.aten.view.default(convert_element_type_47, [2, 8192, 8, -1, 2]); convert_element_type_47 = None + view_as_complex_3 = torch.ops.aten.view_as_complex.default(view_49); view_49 = None + mul_10 = torch.ops.aten.mul.Tensor(view_as_complex_2, view_16); view_as_complex_2 = None + view_as_real_2 = torch.ops.aten.view_as_real.default(mul_10); mul_10 = None + view_51 = torch.ops.aten.view.default(view_as_real_2, [2, 8192, 32, 128]); view_as_real_2 = None + mul_11 = torch.ops.aten.mul.Tensor(view_as_complex_3, view_16); view_as_complex_3 = None + view_as_real_3 = torch.ops.aten.view_as_real.default(mul_11); mul_11 = None + view_52 = torch.ops.aten.view.default(view_as_real_3, [2, 8192, 8, 128]); view_as_real_3 = None + convert_element_type_48 = torch.ops.prims.convert_element_type.default(view_51, torch.bfloat16); view_51 = None + convert_element_type_49 = torch.ops.prims.convert_element_type.default(view_52, torch.bfloat16); view_52 = None + unsqueeze_2 = torch.ops.aten.unsqueeze.default(convert_element_type_49, 3); convert_element_type_49 = None + expand_2 = torch.ops.aten.expand.default(unsqueeze_2, [2, 8192, 8, 4, 128]); unsqueeze_2 = None + clone_2 = torch.ops.aten.clone.default(expand_2, memory_format = torch.contiguous_format); expand_2 = None + view_53 = torch.ops.aten.view.default(clone_2, [2, 8192, 32, 128]); clone_2 = None + unsqueeze_3 = torch.ops.aten.unsqueeze.default(view_47, 3); view_47 = None + expand_3 = torch.ops.aten.expand.default(unsqueeze_3, [2, 8192, 8, 4, 128]); unsqueeze_3 = None + clone_3 = torch.ops.aten.clone.default(expand_3, memory_format = torch.contiguous_format); expand_3 = None + view_54 = torch.ops.aten.view.default(clone_3, [2, 8192, 32, 128]); clone_3 = None + permute_14 = torch.ops.aten.permute.default(convert_element_type_48, [0, 2, 1, 3]); convert_element_type_48 = None + permute_15 = torch.ops.aten.permute.default(view_53, [0, 2, 1, 3]); view_53 = None + permute_16 = torch.ops.aten.permute.default(view_54, [0, 2, 1, 3]); 
view_54 = None + _scaled_dot_product_cudnn_attention_backward_30 = torch.ops.aten._scaled_dot_product_cudnn_attention_backward.default(permute_1333, permute_14, permute_15, permute_16, getitem_9, getitem_10, getitem_15, getitem_16, None, None, None, 8192, 8192, 0.0, True); permute_1333 = permute_14 = permute_15 = permute_16 = getitem_9 = getitem_10 = getitem_15 = getitem_16 = None + getitem_378 = _scaled_dot_product_cudnn_attention_backward_30[0] + getitem_379 = _scaled_dot_product_cudnn_attention_backward_30[1] + getitem_380 = _scaled_dot_product_cudnn_attention_backward_30[2]; _scaled_dot_product_cudnn_attention_backward_30 = None + permute_1334 = torch.ops.aten.permute.default(getitem_380, [0, 2, 1, 3]); getitem_380 = None + permute_1335 = torch.ops.aten.permute.default(getitem_379, [0, 2, 1, 3]); getitem_379 = None + permute_1336 = torch.ops.aten.permute.default(getitem_378, [0, 2, 1, 3]); getitem_378 = None + view_1824 = torch.ops.aten.view.default(permute_1334, [2, 8192, 8, 4, 128]); permute_1334 = None + sum_185 = torch.ops.aten.sum.dim_IntList(view_1824, [3], True); view_1824 = None + squeeze_60 = torch.ops.aten.squeeze.dim(sum_185, 3); sum_185 = None + view_1825 = torch.ops.aten.view.default(permute_1335, [2, 8192, 8, 4, 128]); permute_1335 = None + sum_186 = torch.ops.aten.sum.dim_IntList(view_1825, [3], True); view_1825 = None + squeeze_61 = torch.ops.aten.squeeze.dim(sum_186, 3); sum_186 = None + convert_element_type_2723 = torch.ops.prims.convert_element_type.default(squeeze_61, torch.float32); squeeze_61 = None + convert_element_type_2724 = torch.ops.prims.convert_element_type.default(permute_1336, torch.float32); permute_1336 = None + view_1826 = torch.ops.aten.view.default(convert_element_type_2723, [2, 8192, 8, 64, 2]); convert_element_type_2723 = None + view_as_complex_124 = torch.ops.aten.view_as_complex.default(view_1826); view_1826 = None + mul_876 = torch.ops.aten.mul.Tensor(view_as_complex_124, _conj); view_as_complex_124 = None + view_1827 = 
torch.ops.aten.view.default(convert_element_type_2724, [2, 8192, 32, 64, 2]); convert_element_type_2724 = None + view_as_complex_125 = torch.ops.aten.view_as_complex.default(view_1827); view_1827 = None + mul_877 = torch.ops.aten.mul.Tensor(view_as_complex_125, _conj); view_as_complex_125 = None + view_as_real_124 = torch.ops.aten.view_as_real.default(mul_876); mul_876 = None + view_1828 = torch.ops.aten.view.default(view_as_real_124, [2, 8192, 8, 128]); view_as_real_124 = None + convert_element_type_2725 = torch.ops.prims.convert_element_type.default(view_1828, torch.bfloat16); view_1828 = None + view_as_real_125 = torch.ops.aten.view_as_real.default(mul_877); mul_877 = None + view_1829 = torch.ops.aten.view.default(view_as_real_125, [2, 8192, 32, 128]); view_as_real_125 = None + convert_element_type_2726 = torch.ops.prims.convert_element_type.default(view_1829, torch.bfloat16); view_1829 = None + view_1830 = torch.ops.aten.view.default(squeeze_60, [2, 8192, 1024]); squeeze_60 = None + view_1831 = torch.ops.aten.view.default(convert_element_type_2725, [2, 8192, 1024]); convert_element_type_2725 = None + view_1832 = torch.ops.aten.view.default(convert_element_type_2726, [2, 8192, 4096]); convert_element_type_2726 = None + view_1833 = torch.ops.aten.view.default(view_1830, [16384, 1024]); view_1830 = None + permute_1337 = torch.ops.aten.permute.default(view_1833, [1, 0]) + mm_655 = torch.ops.aten.mm.default(permute_1337, view_37); permute_1337 = None + convert_element_type_43 = torch.ops.prims.convert_element_type.default(primals_16, torch.bfloat16); primals_16 = None + all_gather_into_tensor_13 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_43, 64, '0'); convert_element_type_43 = None + wait_tensor_13 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_13); all_gather_into_tensor_13 = None + permute_13 = torch.ops.aten.permute.default(wait_tensor_13, [1, 0]); wait_tensor_13 = None + permute_1339 = 
torch.ops.aten.permute.default(permute_13, [1, 0]); permute_13 = None + mm_656 = torch.ops.aten.mm.default(view_1833, permute_1339); view_1833 = permute_1339 = None + view_1834 = torch.ops.aten.view.default(mm_656, [2, 8192, 4096]); mm_656 = None + convert_element_type_2731 = torch.ops.prims.convert_element_type.default(mm_655, torch.float32); mm_655 = None + reduce_scatter_tensor_277 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2731, 'avg', 64, '0'); convert_element_type_2731 = None + wait_tensor_568 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_277); reduce_scatter_tensor_277 = None + view_1835 = torch.ops.aten.view.default(view_1831, [16384, 1024]); view_1831 = None + permute_1341 = torch.ops.aten.permute.default(view_1835, [1, 0]) + mm_657 = torch.ops.aten.mm.default(permute_1341, view_37); permute_1341 = None + permute_1343 = torch.ops.aten.permute.default(permute_12, [1, 0]); permute_12 = None + mm_658 = torch.ops.aten.mm.default(view_1835, permute_1343); view_1835 = permute_1343 = None + view_1836 = torch.ops.aten.view.default(mm_658, [2, 8192, 4096]); mm_658 = None + add_343 = torch.ops.aten.add.Tensor(view_1834, view_1836); view_1834 = view_1836 = None + convert_element_type_2736 = torch.ops.prims.convert_element_type.default(mm_657, torch.float32); mm_657 = None + reduce_scatter_tensor_278 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2736, 'avg', 64, '0'); convert_element_type_2736 = None + wait_tensor_569 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_278); reduce_scatter_tensor_278 = None + view_1837 = torch.ops.aten.view.default(view_1832, [16384, 4096]); view_1832 = None + permute_1345 = torch.ops.aten.permute.default(view_1837, [1, 0]) + mm_659 = torch.ops.aten.mm.default(permute_1345, view_37); permute_1345 = view_37 = None + convert_element_type_37 = torch.ops.prims.convert_element_type.default(primals_14, torch.bfloat16); 
primals_14 = None + all_gather_into_tensor_11 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_37, 64, '0'); convert_element_type_37 = None + wait_tensor_11 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_11); all_gather_into_tensor_11 = None + permute_11 = torch.ops.aten.permute.default(wait_tensor_11, [1, 0]); wait_tensor_11 = None + permute_1347 = torch.ops.aten.permute.default(permute_11, [1, 0]); permute_11 = None + mm_660 = torch.ops.aten.mm.default(view_1837, permute_1347); view_1837 = permute_1347 = None + view_1838 = torch.ops.aten.view.default(mm_660, [2, 8192, 4096]); mm_660 = None + add_344 = torch.ops.aten.add.Tensor(add_343, view_1838); add_343 = view_1838 = None + convert_element_type_2741 = torch.ops.prims.convert_element_type.default(mm_659, torch.float32); mm_659 = None + reduce_scatter_tensor_279 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2741, 'avg', 64, '0'); convert_element_type_2741 = None + wait_tensor_570 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_279); reduce_scatter_tensor_279 = None + convert_element_type_2742 = torch.ops.prims.convert_element_type.default(add_344, torch.float32); add_344 = None + convert_element_type_2744 = torch.ops.prims.convert_element_type.default(wait_tensor_10, torch.float32); wait_tensor_10 = None + mul_878 = torch.ops.aten.mul.Tensor(convert_element_type_2742, convert_element_type_2744); convert_element_type_2744 = None + mul_880 = torch.ops.aten.mul.Tensor(mul_8, mul_878) + sum_187 = torch.ops.aten.sum.dim_IntList(mul_880, [2], True); mul_880 = None + div_62 = torch.ops.aten.div.Tensor(mul_8, 4096) + mul_881 = torch.ops.aten.mul.Tensor(div_62, sum_187); div_62 = sum_187 = None + sub_93 = torch.ops.aten.sub.Tensor(mul_878, mul_881); mul_878 = mul_881 = None + mul_882 = torch.ops.aten.mul.Tensor(sub_93, rsqrt_2); sub_93 = rsqrt_2 = None + mul_883 = 
torch.ops.aten.mul.Tensor(convert_element_type_2742, mul_8); convert_element_type_2742 = mul_8 = None + sum_188 = torch.ops.aten.sum.dim_IntList(mul_883, [0, 1]); mul_883 = None + convert_element_type_2745 = torch.ops.prims.convert_element_type.default(mul_882, torch.bfloat16); mul_882 = None + add_345 = torch.ops.aten.add.Tensor(add_342, convert_element_type_2745); add_342 = convert_element_type_2745 = None + convert_element_type_default_3 = torch.ops.prims.convert_element_type.default(sum_188, torch.float32); sum_188 = None + reduce_scatter_tensor_280 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_3, 'avg', 64, '0'); convert_element_type_default_3 = None + wait_tensor_571 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_280); reduce_scatter_tensor_280 = None + view_1839 = torch.ops.aten.view.default(add_345, [16384, 4096]) + permute_1349 = torch.ops.aten.permute.default(view_1839, [1, 0]) + permute_6 = torch.ops.aten.permute.default(getitem, [0, 2, 1, 3]) + view_21 = torch.ops.aten.view.default(permute_6, [2, 8192, -1]); permute_6 = None + convert_element_type_17 = torch.ops.prims.convert_element_type.default(primals_8, torch.bfloat16); primals_8 = None + all_gather_into_tensor_5 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_17, 64, '0'); convert_element_type_17 = None + wait_tensor_5 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_5); all_gather_into_tensor_5 = None + permute_7 = torch.ops.aten.permute.default(wait_tensor_5, [1, 0]); wait_tensor_5 = None + view_23 = torch.ops.aten.view.default(view_21, [16384, 4096]); view_21 = None + mm_3 = torch.ops.aten.mm.default(view_23, permute_7) + view_24 = torch.ops.aten.view.default(mm_3, [2, 8192, 4096]); mm_3 = None + add_1 = torch.ops.aten.add.Tensor(embedding, view_24); view_24 = None + convert_element_type_20 = torch.ops.prims.convert_element_type.default(primals_9, torch.bfloat16); primals_9 = 
None + all_gather_into_tensor_6 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_20, 64, '0'); convert_element_type_20 = None + wait_tensor_6 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_6); all_gather_into_tensor_6 = None + convert_element_type_21 = torch.ops.prims.convert_element_type.default(add_1, torch.float32); add_1 = None + pow_2 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_21, 2) + mean_1 = torch.ops.aten.mean.dim(pow_2, [2], True); pow_2 = None + add_2 = torch.ops.aten.add.Scalar(mean_1, 1e-05); mean_1 = None + rsqrt_1 = torch.ops.aten.rsqrt.default(add_2); add_2 = None + mul_4 = torch.ops.aten.mul.Tensor(convert_element_type_21, rsqrt_1); convert_element_type_21 = None + mul_5 = torch.ops.aten.mul.Tensor(mul_4, wait_tensor_6) + convert_element_type_22 = torch.ops.prims.convert_element_type.default(mul_5, torch.bfloat16); mul_5 = None + view_27 = torch.ops.aten.view.default(convert_element_type_22, [16384, 4096]); convert_element_type_22 = None + view_28 = torch.ops.aten.view.default(mm_4, [2, 8192, 14336]); mm_4 = None + convert_element_type_26 = torch.ops.prims.convert_element_type.default(view_28, torch.float32); view_28 = None + sigmoid = torch.ops.aten.sigmoid.default(convert_element_type_26) + mul_6 = torch.ops.aten.mul.Tensor(convert_element_type_26, sigmoid); sigmoid = None + convert_element_type_27 = torch.ops.prims.convert_element_type.default(mul_6, torch.bfloat16); mul_6 = None + convert_element_type_28 = torch.ops.prims.convert_element_type.default(primals_11, torch.bfloat16); primals_11 = None + all_gather_into_tensor_8 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_28, 64, '0'); convert_element_type_28 = None + wait_tensor_8 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_8); all_gather_into_tensor_8 = None + permute_9 = torch.ops.aten.permute.default(wait_tensor_8, [1, 0]); wait_tensor_8 = None + mm_5 = 
torch.ops.aten.mm.default(view_27, permute_9) + view_31 = torch.ops.aten.view.default(mm_5, [2, 8192, 14336]); mm_5 = None + mul_7 = torch.ops.aten.mul.Tensor(convert_element_type_27, view_31) + view_33 = torch.ops.aten.view.default(mul_7, [16384, 14336]); mul_7 = None + mm_661 = torch.ops.aten.mm.default(permute_1349, view_33); permute_1349 = view_33 = None + convert_element_type_31 = torch.ops.prims.convert_element_type.default(primals_12, torch.bfloat16); primals_12 = None + all_gather_into_tensor_9 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_31, 64, '0'); convert_element_type_31 = None + wait_tensor_9 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_9); all_gather_into_tensor_9 = None + permute_10 = torch.ops.aten.permute.default(wait_tensor_9, [1, 0]); wait_tensor_9 = None + permute_1351 = torch.ops.aten.permute.default(permute_10, [1, 0]); permute_10 = None + mm_662 = torch.ops.aten.mm.default(view_1839, permute_1351); view_1839 = permute_1351 = None + view_1840 = torch.ops.aten.view.default(mm_662, [2, 8192, 14336]); mm_662 = None + convert_element_type_2752 = torch.ops.prims.convert_element_type.default(mm_661, torch.float32); mm_661 = None + reduce_scatter_tensor_281 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2752, 'avg', 64, '0'); convert_element_type_2752 = None + wait_tensor_572 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_281); reduce_scatter_tensor_281 = None + mul_884 = torch.ops.aten.mul.Tensor(view_1840, convert_element_type_27); convert_element_type_27 = None + mul_885 = torch.ops.aten.mul.Tensor(view_1840, view_31); view_1840 = view_31 = None + view_1841 = torch.ops.aten.view.default(mul_884, [16384, 14336]); mul_884 = None + permute_1353 = torch.ops.aten.permute.default(view_1841, [1, 0]) + mm_663 = torch.ops.aten.mm.default(permute_1353, view_27); permute_1353 = None + permute_1355 = 
torch.ops.aten.permute.default(permute_9, [1, 0]); permute_9 = None + mm_664 = torch.ops.aten.mm.default(view_1841, permute_1355); view_1841 = permute_1355 = None + view_1842 = torch.ops.aten.view.default(mm_664, [2, 8192, 4096]); mm_664 = None + convert_element_type_2757 = torch.ops.prims.convert_element_type.default(mm_663, torch.float32); mm_663 = None + reduce_scatter_tensor_282 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2757, 'avg', 64, '0'); convert_element_type_2757 = None + wait_tensor_573 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_282); reduce_scatter_tensor_282 = None + convert_element_type_2758 = torch.ops.prims.convert_element_type.default(mul_885, torch.float32); mul_885 = None + neg_31 = torch.ops.aten.neg.default(convert_element_type_26) + exp_31 = torch.ops.aten.exp.default(neg_31); neg_31 = None + add_346 = torch.ops.aten.add.Tensor(exp_31, 1); exp_31 = None + reciprocal_31 = torch.ops.aten.reciprocal.default(add_346); add_346 = None + mul_886 = torch.ops.aten.mul.Tensor(reciprocal_31, 1); reciprocal_31 = None + mul_887 = torch.ops.aten.mul.Tensor(convert_element_type_2758, mul_886); convert_element_type_2758 = None + sub_94 = torch.ops.aten.sub.Tensor(1, mul_886); mul_886 = None + mul_888 = torch.ops.aten.mul.Tensor(convert_element_type_26, sub_94); convert_element_type_26 = sub_94 = None + add_347 = torch.ops.aten.add.Tensor(mul_888, 1); mul_888 = None + mul_889 = torch.ops.aten.mul.Tensor(mul_887, add_347); mul_887 = add_347 = None + convert_element_type_2760 = torch.ops.prims.convert_element_type.default(mul_889, torch.bfloat16); mul_889 = None + view_1843 = torch.ops.aten.view.default(convert_element_type_2760, [16384, 14336]); convert_element_type_2760 = None + permute_1357 = torch.ops.aten.permute.default(view_1843, [1, 0]) + mm_665 = torch.ops.aten.mm.default(permute_1357, view_27); permute_1357 = view_27 = None + convert_element_type_23 = 
torch.ops.prims.convert_element_type.default(primals_10, torch.bfloat16); primals_10 = None + all_gather_into_tensor_7 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_23, 64, '0'); convert_element_type_23 = None + wait_tensor_7 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_7); all_gather_into_tensor_7 = None + permute_8 = torch.ops.aten.permute.default(wait_tensor_7, [1, 0]); wait_tensor_7 = None + permute_1359 = torch.ops.aten.permute.default(permute_8, [1, 0]); permute_8 = None + mm_666 = torch.ops.aten.mm.default(view_1843, permute_1359); view_1843 = permute_1359 = None + view_1844 = torch.ops.aten.view.default(mm_666, [2, 8192, 4096]); mm_666 = None + add_348 = torch.ops.aten.add.Tensor(view_1842, view_1844); view_1842 = view_1844 = None + convert_element_type_2765 = torch.ops.prims.convert_element_type.default(mm_665, torch.float32); mm_665 = None + reduce_scatter_tensor_283 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2765, 'avg', 64, '0'); convert_element_type_2765 = None + wait_tensor_574 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_283); reduce_scatter_tensor_283 = None + convert_element_type_2766 = torch.ops.prims.convert_element_type.default(add_348, torch.float32); add_348 = None + convert_element_type_2768 = torch.ops.prims.convert_element_type.default(wait_tensor_6, torch.float32); wait_tensor_6 = None + mul_890 = torch.ops.aten.mul.Tensor(convert_element_type_2766, convert_element_type_2768); convert_element_type_2768 = None + mul_892 = torch.ops.aten.mul.Tensor(mul_4, mul_890) + sum_189 = torch.ops.aten.sum.dim_IntList(mul_892, [2], True); mul_892 = None + div_63 = torch.ops.aten.div.Tensor(mul_4, 4096) + mul_893 = torch.ops.aten.mul.Tensor(div_63, sum_189); div_63 = sum_189 = None + sub_95 = torch.ops.aten.sub.Tensor(mul_890, mul_893); mul_890 = mul_893 = None + mul_894 = torch.ops.aten.mul.Tensor(sub_95, rsqrt_1); sub_95 = 
rsqrt_1 = None + mul_895 = torch.ops.aten.mul.Tensor(convert_element_type_2766, mul_4); convert_element_type_2766 = mul_4 = None + sum_190 = torch.ops.aten.sum.dim_IntList(mul_895, [0, 1]); mul_895 = None + convert_element_type_2769 = torch.ops.prims.convert_element_type.default(mul_894, torch.bfloat16); mul_894 = None + add_349 = torch.ops.aten.add.Tensor(add_345, convert_element_type_2769); add_345 = convert_element_type_2769 = None + convert_element_type_default_2 = torch.ops.prims.convert_element_type.default(sum_190, torch.float32); sum_190 = None + reduce_scatter_tensor_284 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_2, 'avg', 64, '0'); convert_element_type_default_2 = None + wait_tensor_575 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_284); reduce_scatter_tensor_284 = None + view_1845 = torch.ops.aten.view.default(add_349, [16384, 4096]) + permute_1361 = torch.ops.aten.permute.default(view_1845, [1, 0]) + mm_667 = torch.ops.aten.mm.default(permute_1361, view_23); permute_1361 = view_23 = None + permute_1363 = torch.ops.aten.permute.default(permute_7, [1, 0]); permute_7 = None + mm_668 = torch.ops.aten.mm.default(view_1845, permute_1363); view_1845 = permute_1363 = None + view_1846 = torch.ops.aten.view.default(mm_668, [2, 8192, 4096]); mm_668 = None + convert_element_type_2776 = torch.ops.prims.convert_element_type.default(mm_667, torch.float32); mm_667 = None + reduce_scatter_tensor_285 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2776, 'avg', 64, '0'); convert_element_type_2776 = None + wait_tensor_576 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_285); reduce_scatter_tensor_285 = None + view_1847 = torch.ops.aten.view.default(view_1846, [2, 8192, 32, 128]); view_1846 = None + permute_1365 = torch.ops.aten.permute.default(view_1847, [0, 2, 1, 3]); view_1847 = None + convert_element_type_1 = 
torch.ops.prims.convert_element_type.default(primals_4, torch.bfloat16); primals_4 = None + all_gather_into_tensor_1 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1, 64, '0'); convert_element_type_1 = None + wait_tensor_1 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_1); all_gather_into_tensor_1 = None + convert_element_type_2 = torch.ops.prims.convert_element_type.default(embedding, torch.float32); embedding = None + pow_1 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_2, 2) + mean = torch.ops.aten.mean.dim(pow_1, [2], True); pow_1 = None + add = torch.ops.aten.add.Scalar(mean, 1e-05); mean = None + rsqrt = torch.ops.aten.rsqrt.default(add); add = None + mul = torch.ops.aten.mul.Tensor(convert_element_type_2, rsqrt); convert_element_type_2 = None + mul_1 = torch.ops.aten.mul.Tensor(mul, wait_tensor_1) + convert_element_type_3 = torch.ops.prims.convert_element_type.default(mul_1, torch.bfloat16); mul_1 = None + view_3 = torch.ops.aten.view.default(convert_element_type_3, [16384, 4096]); convert_element_type_3 = None + view_4 = torch.ops.aten.view.default(mm, [2, 8192, 4096]); mm = None + convert_element_type_7 = torch.ops.prims.convert_element_type.default(primals_6, torch.bfloat16); primals_6 = None + all_gather_into_tensor_3 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_7, 64, '0'); convert_element_type_7 = None + wait_tensor_3 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_3); all_gather_into_tensor_3 = None + permute_1 = torch.ops.aten.permute.default(wait_tensor_3, [1, 0]); wait_tensor_3 = None + mm_1 = torch.ops.aten.mm.default(view_3, permute_1) + view_7 = torch.ops.aten.view.default(mm_1, [2, 8192, 1024]); mm_1 = None + view_10 = torch.ops.aten.view.default(mm_2, [2, 8192, 1024]); mm_2 = None + view_11 = torch.ops.aten.view.default(view_4, [2, 8192, -1, 128]); view_4 = None + view_12 = torch.ops.aten.view.default(view_7, [2, 
8192, -1, 128]); view_7 = None + view_13 = torch.ops.aten.view.default(view_10, [2, 8192, -1, 128]); view_10 = None + convert_element_type_13 = torch.ops.prims.convert_element_type.default(view_11, torch.float32); view_11 = None + view_14 = torch.ops.aten.view.default(convert_element_type_13, [2, 8192, 32, -1, 2]); convert_element_type_13 = None + view_as_complex = torch.ops.aten.view_as_complex.default(view_14); view_14 = None + convert_element_type_14 = torch.ops.prims.convert_element_type.default(view_12, torch.float32); view_12 = None + view_15 = torch.ops.aten.view.default(convert_element_type_14, [2, 8192, 8, -1, 2]); convert_element_type_14 = None + view_as_complex_1 = torch.ops.aten.view_as_complex.default(view_15); view_15 = None + mul_2 = torch.ops.aten.mul.Tensor(view_as_complex, view_16); view_as_complex = None + view_as_real = torch.ops.aten.view_as_real.default(mul_2); mul_2 = None + view_17 = torch.ops.aten.view.default(view_as_real, [2, 8192, 32, 128]); view_as_real = None + mul_3 = torch.ops.aten.mul.Tensor(view_as_complex_1, view_16); view_as_complex_1 = view_16 = None + view_as_real_1 = torch.ops.aten.view_as_real.default(mul_3); mul_3 = None + view_18 = torch.ops.aten.view.default(view_as_real_1, [2, 8192, 8, 128]); view_as_real_1 = None + convert_element_type_15 = torch.ops.prims.convert_element_type.default(view_17, torch.bfloat16); view_17 = None + convert_element_type_16 = torch.ops.prims.convert_element_type.default(view_18, torch.bfloat16); view_18 = None + unsqueeze = torch.ops.aten.unsqueeze.default(convert_element_type_16, 3); convert_element_type_16 = None + expand = torch.ops.aten.expand.default(unsqueeze, [2, 8192, 8, 4, 128]); unsqueeze = None + clone = torch.ops.aten.clone.default(expand, memory_format = torch.contiguous_format); expand = None + view_19 = torch.ops.aten.view.default(clone, [2, 8192, 32, 128]); clone = None + unsqueeze_1 = torch.ops.aten.unsqueeze.default(view_13, 3); view_13 = None + expand_1 = 
torch.ops.aten.expand.default(unsqueeze_1, [2, 8192, 8, 4, 128]); unsqueeze_1 = None + clone_1 = torch.ops.aten.clone.default(expand_1, memory_format = torch.contiguous_format); expand_1 = None + view_20 = torch.ops.aten.view.default(clone_1, [2, 8192, 32, 128]); clone_1 = None + permute_3 = torch.ops.aten.permute.default(convert_element_type_15, [0, 2, 1, 3]); convert_element_type_15 = None + permute_4 = torch.ops.aten.permute.default(view_19, [0, 2, 1, 3]); view_19 = None + permute_5 = torch.ops.aten.permute.default(view_20, [0, 2, 1, 3]); view_20 = None + _scaled_dot_product_cudnn_attention_backward_31 = torch.ops.aten._scaled_dot_product_cudnn_attention_backward.default(permute_1365, permute_3, permute_4, permute_5, getitem, getitem_1, getitem_6, getitem_7, None, None, None, 8192, 8192, 0.0, True); permute_1365 = permute_3 = permute_4 = permute_5 = getitem = getitem_1 = getitem_6 = getitem_7 = None + getitem_381 = _scaled_dot_product_cudnn_attention_backward_31[0] + getitem_382 = _scaled_dot_product_cudnn_attention_backward_31[1] + getitem_383 = _scaled_dot_product_cudnn_attention_backward_31[2]; _scaled_dot_product_cudnn_attention_backward_31 = None + permute_1366 = torch.ops.aten.permute.default(getitem_383, [0, 2, 1, 3]); getitem_383 = None + permute_1367 = torch.ops.aten.permute.default(getitem_382, [0, 2, 1, 3]); getitem_382 = None + permute_1368 = torch.ops.aten.permute.default(getitem_381, [0, 2, 1, 3]); getitem_381 = None + view_1848 = torch.ops.aten.view.default(permute_1366, [2, 8192, 8, 4, 128]); permute_1366 = None + sum_191 = torch.ops.aten.sum.dim_IntList(view_1848, [3], True); view_1848 = None + squeeze_62 = torch.ops.aten.squeeze.dim(sum_191, 3); sum_191 = None + view_1849 = torch.ops.aten.view.default(permute_1367, [2, 8192, 8, 4, 128]); permute_1367 = None + sum_192 = torch.ops.aten.sum.dim_IntList(view_1849, [3], True); view_1849 = None + squeeze_63 = torch.ops.aten.squeeze.dim(sum_192, 3); sum_192 = None + convert_element_type_2777 = 
torch.ops.prims.convert_element_type.default(squeeze_63, torch.float32); squeeze_63 = None + convert_element_type_2778 = torch.ops.prims.convert_element_type.default(permute_1368, torch.float32); permute_1368 = None + view_1850 = torch.ops.aten.view.default(convert_element_type_2777, [2, 8192, 8, 64, 2]); convert_element_type_2777 = None + view_as_complex_126 = torch.ops.aten.view_as_complex.default(view_1850); view_1850 = None + mul_896 = torch.ops.aten.mul.Tensor(view_as_complex_126, _conj); view_as_complex_126 = None + view_1851 = torch.ops.aten.view.default(convert_element_type_2778, [2, 8192, 32, 64, 2]); convert_element_type_2778 = None + view_as_complex_127 = torch.ops.aten.view_as_complex.default(view_1851); view_1851 = None + mul_897 = torch.ops.aten.mul.Tensor(view_as_complex_127, _conj); view_as_complex_127 = _conj = None + view_as_real_126 = torch.ops.aten.view_as_real.default(mul_896); mul_896 = None + view_1852 = torch.ops.aten.view.default(view_as_real_126, [2, 8192, 8, 128]); view_as_real_126 = None + convert_element_type_2779 = torch.ops.prims.convert_element_type.default(view_1852, torch.bfloat16); view_1852 = None + view_as_real_127 = torch.ops.aten.view_as_real.default(mul_897); mul_897 = None + view_1853 = torch.ops.aten.view.default(view_as_real_127, [2, 8192, 32, 128]); view_as_real_127 = None + convert_element_type_2780 = torch.ops.prims.convert_element_type.default(view_1853, torch.bfloat16); view_1853 = None + view_1854 = torch.ops.aten.view.default(squeeze_62, [2, 8192, 1024]); squeeze_62 = None + view_1855 = torch.ops.aten.view.default(convert_element_type_2779, [2, 8192, 1024]); convert_element_type_2779 = None + view_1856 = torch.ops.aten.view.default(convert_element_type_2780, [2, 8192, 4096]); convert_element_type_2780 = None + view_1857 = torch.ops.aten.view.default(view_1854, [16384, 1024]); view_1854 = None + permute_1369 = torch.ops.aten.permute.default(view_1857, [1, 0]) + mm_669 = torch.ops.aten.mm.default(permute_1369, 
view_3); permute_1369 = None + convert_element_type_10 = torch.ops.prims.convert_element_type.default(primals_7, torch.bfloat16); primals_7 = None + all_gather_into_tensor_4 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_10, 64, '0'); convert_element_type_10 = None + wait_tensor_4 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_4); all_gather_into_tensor_4 = None + permute_2 = torch.ops.aten.permute.default(wait_tensor_4, [1, 0]); wait_tensor_4 = None + permute_1371 = torch.ops.aten.permute.default(permute_2, [1, 0]); permute_2 = None + mm_670 = torch.ops.aten.mm.default(view_1857, permute_1371); view_1857 = permute_1371 = None + view_1858 = torch.ops.aten.view.default(mm_670, [2, 8192, 4096]); mm_670 = None + convert_element_type_2785 = torch.ops.prims.convert_element_type.default(mm_669, torch.float32); mm_669 = None + reduce_scatter_tensor_286 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2785, 'avg', 64, '0'); convert_element_type_2785 = None + wait_tensor_577 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_286); reduce_scatter_tensor_286 = None + view_1859 = torch.ops.aten.view.default(view_1855, [16384, 1024]); view_1855 = None + permute_1373 = torch.ops.aten.permute.default(view_1859, [1, 0]) + mm_671 = torch.ops.aten.mm.default(permute_1373, view_3); permute_1373 = None + permute_1375 = torch.ops.aten.permute.default(permute_1, [1, 0]); permute_1 = None + mm_672 = torch.ops.aten.mm.default(view_1859, permute_1375); view_1859 = permute_1375 = None + view_1860 = torch.ops.aten.view.default(mm_672, [2, 8192, 4096]); mm_672 = None + add_350 = torch.ops.aten.add.Tensor(view_1858, view_1860); view_1858 = view_1860 = None + convert_element_type_2790 = torch.ops.prims.convert_element_type.default(mm_671, torch.float32); mm_671 = None + reduce_scatter_tensor_287 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2790, 'avg', 
64, '0'); convert_element_type_2790 = None + wait_tensor_578 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_287); reduce_scatter_tensor_287 = None + view_1861 = torch.ops.aten.view.default(view_1856, [16384, 4096]); view_1856 = None + permute_1377 = torch.ops.aten.permute.default(view_1861, [1, 0]) + mm_673 = torch.ops.aten.mm.default(permute_1377, view_3); permute_1377 = view_3 = None + convert_element_type_4 = torch.ops.prims.convert_element_type.default(primals_5, torch.bfloat16); primals_5 = None + all_gather_into_tensor_2 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_4, 64, '0'); convert_element_type_4 = None + wait_tensor_2 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_2); all_gather_into_tensor_2 = None + permute = torch.ops.aten.permute.default(wait_tensor_2, [1, 0]); wait_tensor_2 = None + permute_1379 = torch.ops.aten.permute.default(permute, [1, 0]); permute = None + mm_674 = torch.ops.aten.mm.default(view_1861, permute_1379); view_1861 = permute_1379 = None + view_1862 = torch.ops.aten.view.default(mm_674, [2, 8192, 4096]); mm_674 = None + add_351 = torch.ops.aten.add.Tensor(add_350, view_1862); add_350 = view_1862 = None + convert_element_type_2795 = torch.ops.prims.convert_element_type.default(mm_673, torch.float32); mm_673 = None + reduce_scatter_tensor_288 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2795, 'avg', 64, '0'); convert_element_type_2795 = None + wait_tensor_579 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_288); reduce_scatter_tensor_288 = None + convert_element_type_2796 = torch.ops.prims.convert_element_type.default(add_351, torch.float32); add_351 = None + convert_element_type_2798 = torch.ops.prims.convert_element_type.default(wait_tensor_1, torch.float32); wait_tensor_1 = None + mul_898 = torch.ops.aten.mul.Tensor(convert_element_type_2796, convert_element_type_2798); 
convert_element_type_2798 = None + mul_900 = torch.ops.aten.mul.Tensor(mul, mul_898) + sum_193 = torch.ops.aten.sum.dim_IntList(mul_900, [2], True); mul_900 = None + div_64 = torch.ops.aten.div.Tensor(mul, 4096) + mul_901 = torch.ops.aten.mul.Tensor(div_64, sum_193); div_64 = sum_193 = None + sub_96 = torch.ops.aten.sub.Tensor(mul_898, mul_901); mul_898 = mul_901 = None + mul_902 = torch.ops.aten.mul.Tensor(sub_96, rsqrt); sub_96 = rsqrt = None + mul_903 = torch.ops.aten.mul.Tensor(convert_element_type_2796, mul); convert_element_type_2796 = mul = None + sum_194 = torch.ops.aten.sum.dim_IntList(mul_903, [0, 1]); mul_903 = None + convert_element_type_2799 = torch.ops.prims.convert_element_type.default(mul_902, torch.bfloat16); mul_902 = None + add_352 = torch.ops.aten.add.Tensor(add_349, convert_element_type_2799); add_349 = convert_element_type_2799 = None + convert_element_type_default_1 = torch.ops.prims.convert_element_type.default(sum_194, torch.float32); sum_194 = None + reduce_scatter_tensor_289 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_1, 'avg', 64, '0'); convert_element_type_default_1 = None + wait_tensor_580 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_289); reduce_scatter_tensor_289 = None + convert_element_type_2802 = torch.ops.prims.convert_element_type.default(add_352, torch.float32); add_352 = None + eq = torch.ops.aten.eq.Scalar(primals_2, -1) + unsqueeze_64 = torch.ops.aten.unsqueeze.default(eq, -1); eq = None + full_default = torch.ops.aten.full.default([], 0.0, dtype = torch.float32, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False) + where = torch.ops.aten.where.self(unsqueeze_64, full_default, convert_element_type_2802); unsqueeze_64 = full_default = convert_element_type_2802 = None + full_default_1 = torch.ops.aten.full.default([128256, 4096], 0, dtype = torch.float32, layout = torch.strided, device = device(type='cuda', index=0), pin_memory 
= False) + index_put = torch.ops.aten.index_put.default(full_default_1, [primals_2], where, True); full_default_1 = primals_2 = where = None + convert_element_type_default = torch.ops.prims.convert_element_type.default(index_put, torch.float32); index_put = None + reduce_scatter_tensor_290 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default, 'avg', 64, '0'); convert_element_type_default = None + wait_tensor_581 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_290); reduce_scatter_tensor_290 = None + return (wait_tensor_581, None, None, wait_tensor_580, wait_tensor_579, wait_tensor_578, wait_tensor_577, wait_tensor_576, wait_tensor_575, wait_tensor_574, wait_tensor_573, wait_tensor_572, wait_tensor_571, wait_tensor_570, wait_tensor_569, wait_tensor_568, wait_tensor_567, wait_tensor_566, wait_tensor_565, wait_tensor_564, wait_tensor_563, wait_tensor_562, wait_tensor_561, wait_tensor_560, wait_tensor_559, wait_tensor_558, wait_tensor_557, wait_tensor_556, wait_tensor_555, wait_tensor_554, wait_tensor_553, wait_tensor_552, wait_tensor_551, wait_tensor_550, wait_tensor_549, wait_tensor_548, wait_tensor_547, wait_tensor_546, wait_tensor_545, wait_tensor_544, wait_tensor_543, wait_tensor_542, wait_tensor_541, wait_tensor_540, wait_tensor_539, wait_tensor_538, wait_tensor_537, wait_tensor_536, wait_tensor_535, wait_tensor_534, wait_tensor_533, wait_tensor_532, wait_tensor_531, wait_tensor_530, wait_tensor_529, wait_tensor_528, wait_tensor_527, wait_tensor_526, wait_tensor_525, wait_tensor_524, wait_tensor_523, wait_tensor_522, wait_tensor_521, wait_tensor_520, wait_tensor_519, wait_tensor_518, wait_tensor_517, wait_tensor_516, wait_tensor_515, wait_tensor_514, wait_tensor_513, wait_tensor_512, wait_tensor_511, wait_tensor_510, wait_tensor_509, wait_tensor_508, wait_tensor_507, wait_tensor_506, wait_tensor_505, wait_tensor_504, wait_tensor_503, wait_tensor_502, wait_tensor_501, wait_tensor_500, wait_tensor_499, 
wait_tensor_498, wait_tensor_497, wait_tensor_496, wait_tensor_495, wait_tensor_494, wait_tensor_493, wait_tensor_492, wait_tensor_491, wait_tensor_490, wait_tensor_489, wait_tensor_488, wait_tensor_487, wait_tensor_486, wait_tensor_485, wait_tensor_484, wait_tensor_483, wait_tensor_482, wait_tensor_481, wait_tensor_480, wait_tensor_479, wait_tensor_478, wait_tensor_477, wait_tensor_476, wait_tensor_475, wait_tensor_474, wait_tensor_473, wait_tensor_472, wait_tensor_471, wait_tensor_470, wait_tensor_469, wait_tensor_468, wait_tensor_467, wait_tensor_466, wait_tensor_465, wait_tensor_464, wait_tensor_463, wait_tensor_462, wait_tensor_461, wait_tensor_460, wait_tensor_459, wait_tensor_458, wait_tensor_457, wait_tensor_456, wait_tensor_455, wait_tensor_454, wait_tensor_453, wait_tensor_452, wait_tensor_451, wait_tensor_450, wait_tensor_449, wait_tensor_448, wait_tensor_447, wait_tensor_446, wait_tensor_445, wait_tensor_444, wait_tensor_443, wait_tensor_442, wait_tensor_441, wait_tensor_440, wait_tensor_439, wait_tensor_438, wait_tensor_437, wait_tensor_436, wait_tensor_435, wait_tensor_434, wait_tensor_433, wait_tensor_432, wait_tensor_431, wait_tensor_430, wait_tensor_429, wait_tensor_428, wait_tensor_427, wait_tensor_426, wait_tensor_425, wait_tensor_424, wait_tensor_423, wait_tensor_422, wait_tensor_421, wait_tensor_420, wait_tensor_419, wait_tensor_418, wait_tensor_417, wait_tensor_416, wait_tensor_415, wait_tensor_414, wait_tensor_413, wait_tensor_412, wait_tensor_411, wait_tensor_410, wait_tensor_409, wait_tensor_408, wait_tensor_407, wait_tensor_406, wait_tensor_405, wait_tensor_404, wait_tensor_403, wait_tensor_402, wait_tensor_401, wait_tensor_400, wait_tensor_399, wait_tensor_398, wait_tensor_397, wait_tensor_396, wait_tensor_395, wait_tensor_394, wait_tensor_393, wait_tensor_392, wait_tensor_391, wait_tensor_390, wait_tensor_389, wait_tensor_388, wait_tensor_387, wait_tensor_386, wait_tensor_385, wait_tensor_384, wait_tensor_383, wait_tensor_382, 
wait_tensor_381, wait_tensor_380, wait_tensor_379, wait_tensor_378, wait_tensor_377, wait_tensor_376, wait_tensor_375, wait_tensor_374, wait_tensor_373, wait_tensor_372, wait_tensor_371, wait_tensor_370, wait_tensor_369, wait_tensor_368, wait_tensor_367, wait_tensor_366, wait_tensor_365, wait_tensor_364, wait_tensor_363, wait_tensor_362, wait_tensor_361, wait_tensor_360, wait_tensor_359, wait_tensor_358, wait_tensor_357, wait_tensor_356, wait_tensor_355, wait_tensor_354, wait_tensor_353, wait_tensor_352, wait_tensor_351, wait_tensor_350, wait_tensor_349, wait_tensor_348, wait_tensor_347, wait_tensor_346, wait_tensor_345, wait_tensor_344, wait_tensor_343, wait_tensor_342, wait_tensor_341, wait_tensor_340, wait_tensor_339, wait_tensor_338, wait_tensor_337, wait_tensor_336, wait_tensor_335, wait_tensor_334, wait_tensor_333, wait_tensor_332, wait_tensor_331, wait_tensor_330, wait_tensor_329, wait_tensor_328, wait_tensor_327, wait_tensor_326, wait_tensor_325, wait_tensor_324, wait_tensor_323, wait_tensor_322, wait_tensor_321, wait_tensor_320, wait_tensor_319, wait_tensor_318, wait_tensor_317, wait_tensor_316, wait_tensor_315, wait_tensor_314, wait_tensor_313, wait_tensor_312, wait_tensor_311, wait_tensor_310, wait_tensor_309, wait_tensor_308, wait_tensor_307, wait_tensor_306, wait_tensor_305, wait_tensor_304, wait_tensor_303, wait_tensor_302, wait_tensor_301, wait_tensor_300, wait_tensor_299, wait_tensor_298, wait_tensor_297, wait_tensor_296, wait_tensor_295, wait_tensor_294, wait_tensor_293, wait_tensor_292, wait_tensor_291) + +def load_args(reader): + buf0 = reader.storage(None, 32833536, device=device(type='cuda', index=0)) + reader.tensor(buf0, (2004, 4096), is_leaf=True) # primals_1 + buf1 = reader.storage(None, 131072, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf1, (2, 8192), dtype=torch.int64, is_leaf=True) # primals_2 + buf2 = reader.storage(None, 4194304, device=device(type='cuda', index=0), dtype_hint=torch.complex64) + 
reader.tensor(buf2, (8192, 64), dtype=torch.complex64, is_leaf=True) # primals_3 + buf3 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf3, (64,), is_leaf=True) # primals_4 + buf4 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf4, (64, 4096), is_leaf=True) # primals_5 + buf5 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf5, (16, 4096), is_leaf=True) # primals_6 + buf6 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf6, (16, 4096), is_leaf=True) # primals_7 + buf7 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf7, (64, 4096), is_leaf=True) # primals_8 + buf8 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf8, (64,), is_leaf=True) # primals_9 + buf9 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf9, (224, 4096), is_leaf=True) # primals_10 + buf10 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf10, (224, 4096), is_leaf=True) # primals_11 + buf11 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf11, (64, 14336), is_leaf=True) # primals_12 + buf12 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf12, (64,), is_leaf=True) # primals_13 + buf13 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf13, (64, 4096), is_leaf=True) # primals_14 + buf14 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf14, (16, 4096), is_leaf=True) # primals_15 + buf15 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf15, (16, 4096), is_leaf=True) # primals_16 + buf16 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf16, (64, 4096), is_leaf=True) # primals_17 + buf17 = 
reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf17, (64,), is_leaf=True) # primals_18 + buf18 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf18, (224, 4096), is_leaf=True) # primals_19 + buf19 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf19, (224, 4096), is_leaf=True) # primals_20 + buf20 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf20, (64, 14336), is_leaf=True) # primals_21 + buf21 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf21, (64,), is_leaf=True) # primals_22 + buf22 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf22, (64, 4096), is_leaf=True) # primals_23 + buf23 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf23, (16, 4096), is_leaf=True) # primals_24 + buf24 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf24, (16, 4096), is_leaf=True) # primals_25 + buf25 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf25, (64, 4096), is_leaf=True) # primals_26 + buf26 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf26, (64,), is_leaf=True) # primals_27 + buf27 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf27, (224, 4096), is_leaf=True) # primals_28 + buf28 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf28, (224, 4096), is_leaf=True) # primals_29 + buf29 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf29, (64, 14336), is_leaf=True) # primals_30 + buf30 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf30, (64,), is_leaf=True) # primals_31 + buf31 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + 
reader.tensor(buf31, (64, 4096), is_leaf=True) # primals_32 + buf32 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf32, (16, 4096), is_leaf=True) # primals_33 + buf33 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf33, (16, 4096), is_leaf=True) # primals_34 + buf34 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf34, (64, 4096), is_leaf=True) # primals_35 + buf35 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf35, (64,), is_leaf=True) # primals_36 + buf36 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf36, (224, 4096), is_leaf=True) # primals_37 + buf37 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf37, (224, 4096), is_leaf=True) # primals_38 + buf38 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf38, (64, 14336), is_leaf=True) # primals_39 + buf39 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf39, (64,), is_leaf=True) # primals_40 + buf40 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf40, (64, 4096), is_leaf=True) # primals_41 + buf41 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf41, (16, 4096), is_leaf=True) # primals_42 + buf42 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf42, (16, 4096), is_leaf=True) # primals_43 + buf43 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf43, (64, 4096), is_leaf=True) # primals_44 + buf44 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf44, (64,), is_leaf=True) # primals_45 + buf45 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf45, (224, 4096), is_leaf=True) # primals_46 + buf46 = 
reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf46, (224, 4096), is_leaf=True) # primals_47 + buf47 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf47, (64, 14336), is_leaf=True) # primals_48 + buf48 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf48, (64,), is_leaf=True) # primals_49 + buf49 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf49, (64, 4096), is_leaf=True) # primals_50 + buf50 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf50, (16, 4096), is_leaf=True) # primals_51 + buf51 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf51, (16, 4096), is_leaf=True) # primals_52 + buf52 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf52, (64, 4096), is_leaf=True) # primals_53 + buf53 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf53, (64,), is_leaf=True) # primals_54 + buf54 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf54, (224, 4096), is_leaf=True) # primals_55 + buf55 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf55, (224, 4096), is_leaf=True) # primals_56 + buf56 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf56, (64, 14336), is_leaf=True) # primals_57 + buf57 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf57, (64,), is_leaf=True) # primals_58 + buf58 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf58, (64, 4096), is_leaf=True) # primals_59 + buf59 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf59, (16, 4096), is_leaf=True) # primals_60 + buf60 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + 
reader.tensor(buf60, (16, 4096), is_leaf=True) # primals_61 + buf61 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf61, (64, 4096), is_leaf=True) # primals_62 + buf62 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf62, (64,), is_leaf=True) # primals_63 + buf63 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf63, (224, 4096), is_leaf=True) # primals_64 + buf64 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf64, (224, 4096), is_leaf=True) # primals_65 + buf65 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf65, (64, 14336), is_leaf=True) # primals_66 + buf66 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf66, (64,), is_leaf=True) # primals_67 + buf67 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf67, (64, 4096), is_leaf=True) # primals_68 + buf68 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf68, (16, 4096), is_leaf=True) # primals_69 + buf69 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf69, (16, 4096), is_leaf=True) # primals_70 + buf70 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf70, (64, 4096), is_leaf=True) # primals_71 + buf71 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf71, (64,), is_leaf=True) # primals_72 + buf72 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf72, (224, 4096), is_leaf=True) # primals_73 + buf73 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf73, (224, 4096), is_leaf=True) # primals_74 + buf74 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf74, (64, 14336), is_leaf=True) # primals_75 + buf75 = 
reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf75, (64,), is_leaf=True) # primals_76 + buf76 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf76, (64, 4096), is_leaf=True) # primals_77 + buf77 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf77, (16, 4096), is_leaf=True) # primals_78 + buf78 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf78, (16, 4096), is_leaf=True) # primals_79 + buf79 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf79, (64, 4096), is_leaf=True) # primals_80 + buf80 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf80, (64,), is_leaf=True) # primals_81 + buf81 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf81, (224, 4096), is_leaf=True) # primals_82 + buf82 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf82, (224, 4096), is_leaf=True) # primals_83 + buf83 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf83, (64, 14336), is_leaf=True) # primals_84 + buf84 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf84, (64,), is_leaf=True) # primals_85 + buf85 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf85, (64, 4096), is_leaf=True) # primals_86 + buf86 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf86, (16, 4096), is_leaf=True) # primals_87 + buf87 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf87, (16, 4096), is_leaf=True) # primals_88 + buf88 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf88, (64, 4096), is_leaf=True) # primals_89 + buf89 = reader.storage(None, 256, device=device(type='cuda', index=0)) + 
reader.tensor(buf89, (64,), is_leaf=True) # primals_90 + buf90 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf90, (224, 4096), is_leaf=True) # primals_91 + buf91 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf91, (224, 4096), is_leaf=True) # primals_92 + buf92 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf92, (64, 14336), is_leaf=True) # primals_93 + buf93 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf93, (64,), is_leaf=True) # primals_94 + buf94 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf94, (64, 4096), is_leaf=True) # primals_95 + buf95 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf95, (16, 4096), is_leaf=True) # primals_96 + buf96 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf96, (16, 4096), is_leaf=True) # primals_97 + buf97 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf97, (64, 4096), is_leaf=True) # primals_98 + buf98 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf98, (64,), is_leaf=True) # primals_99 + buf99 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf99, (224, 4096), is_leaf=True) # primals_100 + buf100 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf100, (224, 4096), is_leaf=True) # primals_101 + buf101 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf101, (64, 14336), is_leaf=True) # primals_102 + buf102 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf102, (64,), is_leaf=True) # primals_103 + buf103 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf103, (64, 4096), is_leaf=True) # primals_104 + buf104 
= reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf104, (16, 4096), is_leaf=True) # primals_105 + buf105 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf105, (16, 4096), is_leaf=True) # primals_106 + buf106 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf106, (64, 4096), is_leaf=True) # primals_107 + buf107 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf107, (64,), is_leaf=True) # primals_108 + buf108 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf108, (224, 4096), is_leaf=True) # primals_109 + buf109 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf109, (224, 4096), is_leaf=True) # primals_110 + buf110 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf110, (64, 14336), is_leaf=True) # primals_111 + buf111 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf111, (64,), is_leaf=True) # primals_112 + buf112 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf112, (64, 4096), is_leaf=True) # primals_113 + buf113 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf113, (16, 4096), is_leaf=True) # primals_114 + buf114 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf114, (16, 4096), is_leaf=True) # primals_115 + buf115 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf115, (64, 4096), is_leaf=True) # primals_116 + buf116 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf116, (64,), is_leaf=True) # primals_117 + buf117 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf117, (224, 4096), is_leaf=True) # primals_118 + buf118 = reader.storage(None, 3670016, 
device=device(type='cuda', index=0)) + reader.tensor(buf118, (224, 4096), is_leaf=True) # primals_119 + buf119 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf119, (64, 14336), is_leaf=True) # primals_120 + buf120 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf120, (64,), is_leaf=True) # primals_121 + buf121 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf121, (64, 4096), is_leaf=True) # primals_122 + buf122 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf122, (16, 4096), is_leaf=True) # primals_123 + buf123 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf123, (16, 4096), is_leaf=True) # primals_124 + buf124 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf124, (64, 4096), is_leaf=True) # primals_125 + buf125 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf125, (64,), is_leaf=True) # primals_126 + buf126 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf126, (224, 4096), is_leaf=True) # primals_127 + buf127 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf127, (224, 4096), is_leaf=True) # primals_128 + buf128 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf128, (64, 14336), is_leaf=True) # primals_129 + buf129 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf129, (64,), is_leaf=True) # primals_130 + buf130 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf130, (64, 4096), is_leaf=True) # primals_131 + buf131 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf131, (16, 4096), is_leaf=True) # primals_132 + buf132 = reader.storage(None, 262144, device=device(type='cuda', index=0)) 
+ reader.tensor(buf132, (16, 4096), is_leaf=True) # primals_133 + buf133 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf133, (64, 4096), is_leaf=True) # primals_134 + buf134 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf134, (64,), is_leaf=True) # primals_135 + buf135 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf135, (224, 4096), is_leaf=True) # primals_136 + buf136 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf136, (224, 4096), is_leaf=True) # primals_137 + buf137 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf137, (64, 14336), is_leaf=True) # primals_138 + buf138 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf138, (64,), is_leaf=True) # primals_139 + buf139 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf139, (64, 4096), is_leaf=True) # primals_140 + buf140 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf140, (16, 4096), is_leaf=True) # primals_141 + buf141 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf141, (16, 4096), is_leaf=True) # primals_142 + buf142 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf142, (64, 4096), is_leaf=True) # primals_143 + buf143 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf143, (64,), is_leaf=True) # primals_144 + buf144 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf144, (224, 4096), is_leaf=True) # primals_145 + buf145 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf145, (224, 4096), is_leaf=True) # primals_146 + buf146 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf146, (64, 
14336), is_leaf=True) # primals_147 + buf147 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf147, (64,), is_leaf=True) # primals_148 + buf148 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf148, (64, 4096), is_leaf=True) # primals_149 + buf149 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf149, (16, 4096), is_leaf=True) # primals_150 + buf150 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf150, (16, 4096), is_leaf=True) # primals_151 + buf151 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf151, (64, 4096), is_leaf=True) # primals_152 + buf152 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf152, (64,), is_leaf=True) # primals_153 + buf153 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf153, (224, 4096), is_leaf=True) # primals_154 + buf154 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf154, (224, 4096), is_leaf=True) # primals_155 + buf155 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf155, (64, 14336), is_leaf=True) # primals_156 + buf156 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf156, (64,), is_leaf=True) # primals_157 + buf157 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf157, (64, 4096), is_leaf=True) # primals_158 + buf158 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf158, (16, 4096), is_leaf=True) # primals_159 + buf159 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf159, (16, 4096), is_leaf=True) # primals_160 + buf160 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf160, (64, 4096), is_leaf=True) # primals_161 + 
buf161 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf161, (64,), is_leaf=True) # primals_162 + buf162 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf162, (224, 4096), is_leaf=True) # primals_163 + buf163 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf163, (224, 4096), is_leaf=True) # primals_164 + buf164 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf164, (64, 14336), is_leaf=True) # primals_165 + buf165 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf165, (64,), is_leaf=True) # primals_166 + buf166 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf166, (64, 4096), is_leaf=True) # primals_167 + buf167 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf167, (16, 4096), is_leaf=True) # primals_168 + buf168 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf168, (16, 4096), is_leaf=True) # primals_169 + buf169 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf169, (64, 4096), is_leaf=True) # primals_170 + buf170 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf170, (64,), is_leaf=True) # primals_171 + buf171 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf171, (224, 4096), is_leaf=True) # primals_172 + buf172 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf172, (224, 4096), is_leaf=True) # primals_173 + buf173 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf173, (64, 14336), is_leaf=True) # primals_174 + buf174 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf174, (64,), is_leaf=True) # primals_175 + buf175 = reader.storage(None, 1048576, 
device=device(type='cuda', index=0)) + reader.tensor(buf175, (64, 4096), is_leaf=True) # primals_176 + buf176 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf176, (16, 4096), is_leaf=True) # primals_177 + buf177 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf177, (16, 4096), is_leaf=True) # primals_178 + buf178 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf178, (64, 4096), is_leaf=True) # primals_179 + buf179 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf179, (64,), is_leaf=True) # primals_180 + buf180 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf180, (224, 4096), is_leaf=True) # primals_181 + buf181 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf181, (224, 4096), is_leaf=True) # primals_182 + buf182 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf182, (64, 14336), is_leaf=True) # primals_183 + buf183 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf183, (64,), is_leaf=True) # primals_184 + buf184 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf184, (64, 4096), is_leaf=True) # primals_185 + buf185 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf185, (16, 4096), is_leaf=True) # primals_186 + buf186 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf186, (16, 4096), is_leaf=True) # primals_187 + buf187 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf187, (64, 4096), is_leaf=True) # primals_188 + buf188 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf188, (64,), is_leaf=True) # primals_189 + buf189 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + 
reader.tensor(buf189, (224, 4096), is_leaf=True) # primals_190 + buf190 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf190, (224, 4096), is_leaf=True) # primals_191 + buf191 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf191, (64, 14336), is_leaf=True) # primals_192 + buf192 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf192, (64,), is_leaf=True) # primals_193 + buf193 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf193, (64, 4096), is_leaf=True) # primals_194 + buf194 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf194, (16, 4096), is_leaf=True) # primals_195 + buf195 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf195, (16, 4096), is_leaf=True) # primals_196 + buf196 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf196, (64, 4096), is_leaf=True) # primals_197 + buf197 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf197, (64,), is_leaf=True) # primals_198 + buf198 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf198, (224, 4096), is_leaf=True) # primals_199 + buf199 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf199, (224, 4096), is_leaf=True) # primals_200 + buf200 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf200, (64, 14336), is_leaf=True) # primals_201 + buf201 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf201, (64,), is_leaf=True) # primals_202 + buf202 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf202, (64, 4096), is_leaf=True) # primals_203 + buf203 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf203, (16, 4096), 
is_leaf=True) # primals_204 + buf204 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf204, (16, 4096), is_leaf=True) # primals_205 + buf205 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf205, (64, 4096), is_leaf=True) # primals_206 + buf206 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf206, (64,), is_leaf=True) # primals_207 + buf207 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf207, (224, 4096), is_leaf=True) # primals_208 + buf208 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf208, (224, 4096), is_leaf=True) # primals_209 + buf209 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf209, (64, 14336), is_leaf=True) # primals_210 + buf210 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf210, (64,), is_leaf=True) # primals_211 + buf211 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf211, (64, 4096), is_leaf=True) # primals_212 + buf212 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf212, (16, 4096), is_leaf=True) # primals_213 + buf213 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf213, (16, 4096), is_leaf=True) # primals_214 + buf214 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf214, (64, 4096), is_leaf=True) # primals_215 + buf215 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf215, (64,), is_leaf=True) # primals_216 + buf216 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf216, (224, 4096), is_leaf=True) # primals_217 + buf217 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf217, (224, 4096), is_leaf=True) # primals_218 + buf218 
= reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf218, (64, 14336), is_leaf=True) # primals_219 + buf219 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf219, (64,), is_leaf=True) # primals_220 + buf220 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf220, (64, 4096), is_leaf=True) # primals_221 + buf221 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf221, (16, 4096), is_leaf=True) # primals_222 + buf222 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf222, (16, 4096), is_leaf=True) # primals_223 + buf223 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf223, (64, 4096), is_leaf=True) # primals_224 + buf224 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf224, (64,), is_leaf=True) # primals_225 + buf225 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf225, (224, 4096), is_leaf=True) # primals_226 + buf226 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf226, (224, 4096), is_leaf=True) # primals_227 + buf227 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf227, (64, 14336), is_leaf=True) # primals_228 + buf228 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf228, (64,), is_leaf=True) # primals_229 + buf229 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf229, (64, 4096), is_leaf=True) # primals_230 + buf230 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf230, (16, 4096), is_leaf=True) # primals_231 + buf231 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf231, (16, 4096), is_leaf=True) # primals_232 + buf232 = reader.storage(None, 1048576, 
device=device(type='cuda', index=0)) + reader.tensor(buf232, (64, 4096), is_leaf=True) # primals_233 + buf233 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf233, (64,), is_leaf=True) # primals_234 + buf234 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf234, (224, 4096), is_leaf=True) # primals_235 + buf235 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf235, (224, 4096), is_leaf=True) # primals_236 + buf236 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf236, (64, 14336), is_leaf=True) # primals_237 + buf237 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf237, (64,), is_leaf=True) # primals_238 + buf238 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf238, (64, 4096), is_leaf=True) # primals_239 + buf239 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf239, (16, 4096), is_leaf=True) # primals_240 + buf240 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf240, (16, 4096), is_leaf=True) # primals_241 + buf241 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf241, (64, 4096), is_leaf=True) # primals_242 + buf242 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf242, (64,), is_leaf=True) # primals_243 + buf243 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf243, (224, 4096), is_leaf=True) # primals_244 + buf244 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf244, (224, 4096), is_leaf=True) # primals_245 + buf245 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf245, (64, 14336), is_leaf=True) # primals_246 + buf246 = reader.storage(None, 256, device=device(type='cuda', index=0)) + 
reader.tensor(buf246, (64,), is_leaf=True) # primals_247 + buf247 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf247, (64, 4096), is_leaf=True) # primals_248 + buf248 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf248, (16, 4096), is_leaf=True) # primals_249 + buf249 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf249, (16, 4096), is_leaf=True) # primals_250 + buf250 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf250, (64, 4096), is_leaf=True) # primals_251 + buf251 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf251, (64,), is_leaf=True) # primals_252 + buf252 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf252, (224, 4096), is_leaf=True) # primals_253 + buf253 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf253, (224, 4096), is_leaf=True) # primals_254 + buf254 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf254, (64, 14336), is_leaf=True) # primals_255 + buf255 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf255, (64,), is_leaf=True) # primals_256 + buf256 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf256, (64, 4096), is_leaf=True) # primals_257 + buf257 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf257, (16, 4096), is_leaf=True) # primals_258 + buf258 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf258, (16, 4096), is_leaf=True) # primals_259 + buf259 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf259, (64, 4096), is_leaf=True) # primals_260 + buf260 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf260, (64,), 
is_leaf=True) # primals_261 + buf261 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf261, (224, 4096), is_leaf=True) # primals_262 + buf262 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf262, (224, 4096), is_leaf=True) # primals_263 + buf263 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf263, (64, 14336), is_leaf=True) # primals_264 + buf264 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf264, (64,), is_leaf=True) # primals_265 + buf265 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf265, (64, 4096), is_leaf=True) # primals_266 + buf266 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf266, (16, 4096), is_leaf=True) # primals_267 + buf267 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf267, (16, 4096), is_leaf=True) # primals_268 + buf268 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf268, (64, 4096), is_leaf=True) # primals_269 + buf269 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf269, (64,), is_leaf=True) # primals_270 + buf270 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf270, (224, 4096), is_leaf=True) # primals_271 + buf271 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf271, (224, 4096), is_leaf=True) # primals_272 + buf272 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf272, (64, 14336), is_leaf=True) # primals_273 + buf273 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf273, (64,), is_leaf=True) # primals_274 + buf274 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf274, (64, 4096), is_leaf=True) # primals_275 + 
buf275 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf275, (16, 4096), is_leaf=True) # primals_276 + buf276 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf276, (16, 4096), is_leaf=True) # primals_277 + buf277 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf277, (64, 4096), is_leaf=True) # primals_278 + buf278 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf278, (64,), is_leaf=True) # primals_279 + buf279 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf279, (224, 4096), is_leaf=True) # primals_280 + buf280 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf280, (224, 4096), is_leaf=True) # primals_281 + buf281 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf281, (64, 14336), is_leaf=True) # primals_282 + buf282 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf282, (64,), is_leaf=True) # primals_283 + buf283 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf283, (64, 4096), is_leaf=True) # primals_284 + buf284 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf284, (16, 4096), is_leaf=True) # primals_285 + buf285 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf285, (16, 4096), is_leaf=True) # primals_286 + buf286 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf286, (64, 4096), is_leaf=True) # primals_287 + buf287 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf287, (64,), is_leaf=True) # primals_288 + buf288 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf288, (224, 4096), is_leaf=True) # primals_289 + buf289 = reader.storage(None, 3670016, 
device=device(type='cuda', index=0)) + reader.tensor(buf289, (224, 4096), is_leaf=True) # primals_290 + buf290 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf290, (64, 14336), is_leaf=True) # primals_291 + buf291 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf291, (64,), is_leaf=True) # primals_292 + buf292 = reader.storage(None, 32833536, device=device(type='cuda', index=0)) + reader.tensor(buf292, (2004, 4096), is_leaf=True) # primals_293 + buf293 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf293, (2, 8192, 4096), dtype=torch.bfloat16, is_leaf=True) # embedding + buf294 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf294, (16384, 4096), dtype=torch.bfloat16, is_leaf=True) # mm + buf295 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf295, (16384, 1024), dtype=torch.bfloat16, is_leaf=True) # mm_2 + buf296 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf296, (2, 32, 8192, 128), (33554432, 128, 4096, 1), dtype=torch.bfloat16, is_leaf=True) # getitem + buf297 = reader.storage(None, 2097152, device=device(type='cuda', index=0)) + reader.tensor(buf297, (2, 32, 8192, 1), is_leaf=True) # getitem_1 + buf298 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf298, (), dtype=torch.int64, is_leaf=True) # getitem_6 + buf299 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf299, (), dtype=torch.int64, is_leaf=True) # getitem_7 + buf300 = reader.storage(None, 469762048, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf300, (16384, 14336), dtype=torch.bfloat16, is_leaf=True) # mm_4 + buf301 
= reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf301, (2, 8192, 4096), dtype=torch.bfloat16, is_leaf=True) # add_3 + buf302 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf302, (16384, 4096), dtype=torch.bfloat16, is_leaf=True) # mm_7 + buf303 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf303, (16384, 1024), dtype=torch.bfloat16, is_leaf=True) # mm_9 + buf304 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf304, (2, 32, 8192, 128), (33554432, 128, 4096, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_9 + buf305 = reader.storage(None, 2097152, device=device(type='cuda', index=0)) + reader.tensor(buf305, (2, 32, 8192, 1), is_leaf=True) # getitem_10 + buf306 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf306, (), dtype=torch.int64, is_leaf=True) # getitem_15 + buf307 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf307, (), dtype=torch.int64, is_leaf=True) # getitem_16 + buf308 = reader.storage(None, 469762048, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf308, (16384, 14336), dtype=torch.bfloat16, is_leaf=True) # mm_11 + buf309 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf309, (2, 8192, 4096), dtype=torch.bfloat16, is_leaf=True) # add_7 + buf310 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf310, (16384, 4096), dtype=torch.bfloat16, is_leaf=True) # mm_14 + buf311 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf311, (16384, 1024), 
dtype=torch.bfloat16, is_leaf=True) # mm_16 + buf312 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf312, (2, 32, 8192, 128), (33554432, 128, 4096, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_18 + buf313 = reader.storage(None, 2097152, device=device(type='cuda', index=0)) + reader.tensor(buf313, (2, 32, 8192, 1), is_leaf=True) # getitem_19 + buf314 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf314, (), dtype=torch.int64, is_leaf=True) # getitem_24 + buf315 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf315, (), dtype=torch.int64, is_leaf=True) # getitem_25 + buf316 = reader.storage(None, 469762048, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf316, (16384, 14336), dtype=torch.bfloat16, is_leaf=True) # mm_18 + buf317 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf317, (2, 8192, 4096), dtype=torch.bfloat16, is_leaf=True) # add_11 + buf318 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf318, (16384, 4096), dtype=torch.bfloat16, is_leaf=True) # mm_21 + buf319 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf319, (16384, 1024), dtype=torch.bfloat16, is_leaf=True) # mm_23 + buf320 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf320, (2, 32, 8192, 128), (33554432, 128, 4096, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_27 + buf321 = reader.storage(None, 2097152, device=device(type='cuda', index=0)) + reader.tensor(buf321, (2, 32, 8192, 1), is_leaf=True) # getitem_28 + buf322 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + 
reader.tensor(buf322, (), dtype=torch.int64, is_leaf=True) # getitem_33 + buf323 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf323, (), dtype=torch.int64, is_leaf=True) # getitem_34 + buf324 = reader.storage(None, 469762048, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf324, (16384, 14336), dtype=torch.bfloat16, is_leaf=True) # mm_25 + buf325 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf325, (2, 8192, 4096), dtype=torch.bfloat16, is_leaf=True) # add_15 + buf326 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf326, (16384, 4096), dtype=torch.bfloat16, is_leaf=True) # mm_28 + buf327 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf327, (16384, 1024), dtype=torch.bfloat16, is_leaf=True) # mm_30 + buf328 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf328, (2, 32, 8192, 128), (33554432, 128, 4096, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_36 + buf329 = reader.storage(None, 2097152, device=device(type='cuda', index=0)) + reader.tensor(buf329, (2, 32, 8192, 1), is_leaf=True) # getitem_37 + buf330 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf330, (), dtype=torch.int64, is_leaf=True) # getitem_42 + buf331 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf331, (), dtype=torch.int64, is_leaf=True) # getitem_43 + buf332 = reader.storage(None, 469762048, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf332, (16384, 14336), dtype=torch.bfloat16, is_leaf=True) # mm_32 + buf333 = reader.storage(None, 134217728, device=device(type='cuda', index=0), 
dtype_hint=torch.bfloat16) + reader.tensor(buf333, (2, 8192, 4096), dtype=torch.bfloat16, is_leaf=True) # add_19 + buf334 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf334, (16384, 4096), dtype=torch.bfloat16, is_leaf=True) # mm_35 + buf335 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf335, (16384, 1024), dtype=torch.bfloat16, is_leaf=True) # mm_37 + buf336 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf336, (2, 32, 8192, 128), (33554432, 128, 4096, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_45 + buf337 = reader.storage(None, 2097152, device=device(type='cuda', index=0)) + reader.tensor(buf337, (2, 32, 8192, 1), is_leaf=True) # getitem_46 + buf338 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf338, (), dtype=torch.int64, is_leaf=True) # getitem_51 + buf339 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf339, (), dtype=torch.int64, is_leaf=True) # getitem_52 + buf340 = reader.storage(None, 469762048, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf340, (16384, 14336), dtype=torch.bfloat16, is_leaf=True) # mm_39 + buf341 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf341, (2, 8192, 4096), dtype=torch.bfloat16, is_leaf=True) # add_23 + buf342 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf342, (16384, 4096), dtype=torch.bfloat16, is_leaf=True) # mm_42 + buf343 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf343, (16384, 1024), dtype=torch.bfloat16, is_leaf=True) # mm_44 + buf344 = reader.storage(None, 
134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf344, (2, 32, 8192, 128), (33554432, 128, 4096, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_54 + buf345 = reader.storage(None, 2097152, device=device(type='cuda', index=0)) + reader.tensor(buf345, (2, 32, 8192, 1), is_leaf=True) # getitem_55 + buf346 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf346, (), dtype=torch.int64, is_leaf=True) # getitem_60 + buf347 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf347, (), dtype=torch.int64, is_leaf=True) # getitem_61 + buf348 = reader.storage(None, 469762048, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf348, (16384, 14336), dtype=torch.bfloat16, is_leaf=True) # mm_46 + buf349 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf349, (2, 8192, 4096), dtype=torch.bfloat16, is_leaf=True) # add_27 + buf350 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf350, (16384, 4096), dtype=torch.bfloat16, is_leaf=True) # mm_49 + buf351 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf351, (16384, 1024), dtype=torch.bfloat16, is_leaf=True) # mm_51 + buf352 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf352, (2, 32, 8192, 128), (33554432, 128, 4096, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_63 + buf353 = reader.storage(None, 2097152, device=device(type='cuda', index=0)) + reader.tensor(buf353, (2, 32, 8192, 1), is_leaf=True) # getitem_64 + buf354 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf354, (), dtype=torch.int64, is_leaf=True) # getitem_69 + buf355 = 
reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf355, (), dtype=torch.int64, is_leaf=True) # getitem_70 + buf356 = reader.storage(None, 469762048, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf356, (16384, 14336), dtype=torch.bfloat16, is_leaf=True) # mm_53 + buf357 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf357, (2, 8192, 4096), dtype=torch.bfloat16, is_leaf=True) # add_31 + buf358 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf358, (16384, 4096), dtype=torch.bfloat16, is_leaf=True) # mm_56 + buf359 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf359, (16384, 1024), dtype=torch.bfloat16, is_leaf=True) # mm_58 + buf360 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf360, (2, 32, 8192, 128), (33554432, 128, 4096, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_72 + buf361 = reader.storage(None, 2097152, device=device(type='cuda', index=0)) + reader.tensor(buf361, (2, 32, 8192, 1), is_leaf=True) # getitem_73 + buf362 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf362, (), dtype=torch.int64, is_leaf=True) # getitem_78 + buf363 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf363, (), dtype=torch.int64, is_leaf=True) # getitem_79 + buf364 = reader.storage(None, 469762048, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf364, (16384, 14336), dtype=torch.bfloat16, is_leaf=True) # mm_60 + buf365 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf365, (2, 8192, 4096), dtype=torch.bfloat16, 
is_leaf=True) # add_35 + buf366 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf366, (16384, 4096), dtype=torch.bfloat16, is_leaf=True) # mm_63 + buf367 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf367, (16384, 1024), dtype=torch.bfloat16, is_leaf=True) # mm_65 + buf368 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf368, (2, 32, 8192, 128), (33554432, 128, 4096, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_81 + buf369 = reader.storage(None, 2097152, device=device(type='cuda', index=0)) + reader.tensor(buf369, (2, 32, 8192, 1), is_leaf=True) # getitem_82 + buf370 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf370, (), dtype=torch.int64, is_leaf=True) # getitem_87 + buf371 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf371, (), dtype=torch.int64, is_leaf=True) # getitem_88 + buf372 = reader.storage(None, 469762048, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf372, (16384, 14336), dtype=torch.bfloat16, is_leaf=True) # mm_67 + buf373 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf373, (2, 8192, 4096), dtype=torch.bfloat16, is_leaf=True) # add_39 + buf374 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf374, (16384, 4096), dtype=torch.bfloat16, is_leaf=True) # mm_70 + buf375 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf375, (16384, 1024), dtype=torch.bfloat16, is_leaf=True) # mm_72 + buf376 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + 
reader.tensor(buf376, (2, 32, 8192, 128), (33554432, 128, 4096, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_90 + buf377 = reader.storage(None, 2097152, device=device(type='cuda', index=0)) + reader.tensor(buf377, (2, 32, 8192, 1), is_leaf=True) # getitem_91 + buf378 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf378, (), dtype=torch.int64, is_leaf=True) # getitem_96 + buf379 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf379, (), dtype=torch.int64, is_leaf=True) # getitem_97 + buf380 = reader.storage(None, 469762048, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf380, (16384, 14336), dtype=torch.bfloat16, is_leaf=True) # mm_74 + buf381 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf381, (2, 8192, 4096), dtype=torch.bfloat16, is_leaf=True) # add_43 + buf382 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf382, (16384, 4096), dtype=torch.bfloat16, is_leaf=True) # mm_77 + buf383 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf383, (16384, 1024), dtype=torch.bfloat16, is_leaf=True) # mm_79 + buf384 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf384, (2, 32, 8192, 128), (33554432, 128, 4096, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_99 + buf385 = reader.storage(None, 2097152, device=device(type='cuda', index=0)) + reader.tensor(buf385, (2, 32, 8192, 1), is_leaf=True) # getitem_100 + buf386 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf386, (), dtype=torch.int64, is_leaf=True) # getitem_105 + buf387 = reader.storage(None, 8, device=device(type='cuda', index=0), 
dtype_hint=torch.int64) + reader.tensor(buf387, (), dtype=torch.int64, is_leaf=True) # getitem_106 + buf388 = reader.storage(None, 469762048, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf388, (16384, 14336), dtype=torch.bfloat16, is_leaf=True) # mm_81 + buf389 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf389, (2, 8192, 4096), dtype=torch.bfloat16, is_leaf=True) # add_47 + buf390 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf390, (16384, 4096), dtype=torch.bfloat16, is_leaf=True) # mm_84 + buf391 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf391, (16384, 1024), dtype=torch.bfloat16, is_leaf=True) # mm_86 + buf392 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf392, (2, 32, 8192, 128), (33554432, 128, 4096, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_108 + buf393 = reader.storage(None, 2097152, device=device(type='cuda', index=0)) + reader.tensor(buf393, (2, 32, 8192, 1), is_leaf=True) # getitem_109 + buf394 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf394, (), dtype=torch.int64, is_leaf=True) # getitem_114 + buf395 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf395, (), dtype=torch.int64, is_leaf=True) # getitem_115 + buf396 = reader.storage(None, 469762048, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf396, (16384, 14336), dtype=torch.bfloat16, is_leaf=True) # mm_88 + buf397 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf397, (2, 8192, 4096), dtype=torch.bfloat16, is_leaf=True) # add_51 + buf398 = reader.storage(None, 
134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf398, (16384, 4096), dtype=torch.bfloat16, is_leaf=True) # mm_91 + buf399 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf399, (16384, 1024), dtype=torch.bfloat16, is_leaf=True) # mm_93 + buf400 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf400, (2, 32, 8192, 128), (33554432, 128, 4096, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_117 + buf401 = reader.storage(None, 2097152, device=device(type='cuda', index=0)) + reader.tensor(buf401, (2, 32, 8192, 1), is_leaf=True) # getitem_118 + buf402 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf402, (), dtype=torch.int64, is_leaf=True) # getitem_123 + buf403 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf403, (), dtype=torch.int64, is_leaf=True) # getitem_124 + buf404 = reader.storage(None, 469762048, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf404, (16384, 14336), dtype=torch.bfloat16, is_leaf=True) # mm_95 + buf405 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf405, (2, 8192, 4096), dtype=torch.bfloat16, is_leaf=True) # add_55 + buf406 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf406, (16384, 4096), dtype=torch.bfloat16, is_leaf=True) # mm_98 + buf407 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf407, (16384, 1024), dtype=torch.bfloat16, is_leaf=True) # mm_100 + buf408 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf408, (2, 32, 8192, 128), (33554432, 128, 
4096, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_126 + buf409 = reader.storage(None, 2097152, device=device(type='cuda', index=0)) + reader.tensor(buf409, (2, 32, 8192, 1), is_leaf=True) # getitem_127 + buf410 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf410, (), dtype=torch.int64, is_leaf=True) # getitem_132 + buf411 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf411, (), dtype=torch.int64, is_leaf=True) # getitem_133 + buf412 = reader.storage(None, 469762048, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf412, (16384, 14336), dtype=torch.bfloat16, is_leaf=True) # mm_102 + buf413 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf413, (2, 8192, 4096), dtype=torch.bfloat16, is_leaf=True) # add_59 + buf414 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf414, (16384, 4096), dtype=torch.bfloat16, is_leaf=True) # mm_105 + buf415 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf415, (16384, 1024), dtype=torch.bfloat16, is_leaf=True) # mm_107 + buf416 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf416, (2, 32, 8192, 128), (33554432, 128, 4096, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_135 + buf417 = reader.storage(None, 2097152, device=device(type='cuda', index=0)) + reader.tensor(buf417, (2, 32, 8192, 1), is_leaf=True) # getitem_136 + buf418 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf418, (), dtype=torch.int64, is_leaf=True) # getitem_141 + buf419 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf419, (), 
dtype=torch.int64, is_leaf=True) # getitem_142 + buf420 = reader.storage(None, 469762048, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf420, (16384, 14336), dtype=torch.bfloat16, is_leaf=True) # mm_109 + buf421 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf421, (2, 8192, 4096), dtype=torch.bfloat16, is_leaf=True) # add_63 + buf422 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf422, (16384, 4096), dtype=torch.bfloat16, is_leaf=True) # mm_112 + buf423 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf423, (16384, 1024), dtype=torch.bfloat16, is_leaf=True) # mm_114 + buf424 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf424, (2, 32, 8192, 128), (33554432, 128, 4096, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_144 + buf425 = reader.storage(None, 2097152, device=device(type='cuda', index=0)) + reader.tensor(buf425, (2, 32, 8192, 1), is_leaf=True) # getitem_145 + buf426 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf426, (), dtype=torch.int64, is_leaf=True) # getitem_150 + buf427 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf427, (), dtype=torch.int64, is_leaf=True) # getitem_151 + buf428 = reader.storage(None, 469762048, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf428, (16384, 14336), dtype=torch.bfloat16, is_leaf=True) # mm_116 + buf429 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf429, (2, 8192, 4096), dtype=torch.bfloat16, is_leaf=True) # add_67 + buf430 = reader.storage(None, 134217728, device=device(type='cuda', index=0), 
dtype_hint=torch.bfloat16) + reader.tensor(buf430, (16384, 4096), dtype=torch.bfloat16, is_leaf=True) # mm_119 + buf431 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf431, (16384, 1024), dtype=torch.bfloat16, is_leaf=True) # mm_121 + buf432 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf432, (2, 32, 8192, 128), (33554432, 128, 4096, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_153 + buf433 = reader.storage(None, 2097152, device=device(type='cuda', index=0)) + reader.tensor(buf433, (2, 32, 8192, 1), is_leaf=True) # getitem_154 + buf434 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf434, (), dtype=torch.int64, is_leaf=True) # getitem_159 + buf435 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf435, (), dtype=torch.int64, is_leaf=True) # getitem_160 + buf436 = reader.storage(None, 469762048, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf436, (16384, 14336), dtype=torch.bfloat16, is_leaf=True) # mm_123 + buf437 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf437, (2, 8192, 4096), dtype=torch.bfloat16, is_leaf=True) # add_71 + buf438 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf438, (16384, 4096), dtype=torch.bfloat16, is_leaf=True) # mm_126 + buf439 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf439, (16384, 1024), dtype=torch.bfloat16, is_leaf=True) # mm_128 + buf440 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf440, (2, 32, 8192, 128), (33554432, 128, 4096, 1), dtype=torch.bfloat16, is_leaf=True) # 
getitem_162 + buf441 = reader.storage(None, 2097152, device=device(type='cuda', index=0)) + reader.tensor(buf441, (2, 32, 8192, 1), is_leaf=True) # getitem_163 + buf442 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf442, (), dtype=torch.int64, is_leaf=True) # getitem_168 + buf443 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf443, (), dtype=torch.int64, is_leaf=True) # getitem_169 + buf444 = reader.storage(None, 469762048, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf444, (16384, 14336), dtype=torch.bfloat16, is_leaf=True) # mm_130 + buf445 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf445, (2, 8192, 4096), dtype=torch.bfloat16, is_leaf=True) # add_75 + buf446 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf446, (16384, 4096), dtype=torch.bfloat16, is_leaf=True) # mm_133 + buf447 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf447, (16384, 1024), dtype=torch.bfloat16, is_leaf=True) # mm_135 + buf448 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf448, (2, 32, 8192, 128), (33554432, 128, 4096, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_171 + buf449 = reader.storage(None, 2097152, device=device(type='cuda', index=0)) + reader.tensor(buf449, (2, 32, 8192, 1), is_leaf=True) # getitem_172 + buf450 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf450, (), dtype=torch.int64, is_leaf=True) # getitem_177 + buf451 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf451, (), dtype=torch.int64, is_leaf=True) # getitem_178 + buf452 = 
reader.storage(None, 469762048, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf452, (16384, 14336), dtype=torch.bfloat16, is_leaf=True) # mm_137 + buf453 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf453, (2, 8192, 4096), dtype=torch.bfloat16, is_leaf=True) # add_79 + buf454 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf454, (16384, 4096), dtype=torch.bfloat16, is_leaf=True) # mm_140 + buf455 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf455, (16384, 1024), dtype=torch.bfloat16, is_leaf=True) # mm_142 + buf456 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf456, (2, 32, 8192, 128), (33554432, 128, 4096, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_180 + buf457 = reader.storage(None, 2097152, device=device(type='cuda', index=0)) + reader.tensor(buf457, (2, 32, 8192, 1), is_leaf=True) # getitem_181 + buf458 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf458, (), dtype=torch.int64, is_leaf=True) # getitem_186 + buf459 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf459, (), dtype=torch.int64, is_leaf=True) # getitem_187 + buf460 = reader.storage(None, 469762048, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf460, (16384, 14336), dtype=torch.bfloat16, is_leaf=True) # mm_144 + buf461 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf461, (2, 8192, 4096), dtype=torch.bfloat16, is_leaf=True) # add_83 + buf462 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf462, (16384, 
4096), dtype=torch.bfloat16, is_leaf=True) # mm_147 + buf463 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf463, (16384, 1024), dtype=torch.bfloat16, is_leaf=True) # mm_149 + buf464 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf464, (2, 32, 8192, 128), (33554432, 128, 4096, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_189 + buf465 = reader.storage(None, 2097152, device=device(type='cuda', index=0)) + reader.tensor(buf465, (2, 32, 8192, 1), is_leaf=True) # getitem_190 + buf466 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf466, (), dtype=torch.int64, is_leaf=True) # getitem_195 + buf467 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf467, (), dtype=torch.int64, is_leaf=True) # getitem_196 + buf468 = reader.storage(None, 469762048, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf468, (16384, 14336), dtype=torch.bfloat16, is_leaf=True) # mm_151 + buf469 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf469, (2, 8192, 4096), dtype=torch.bfloat16, is_leaf=True) # add_87 + buf470 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf470, (16384, 4096), dtype=torch.bfloat16, is_leaf=True) # mm_154 + buf471 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf471, (16384, 1024), dtype=torch.bfloat16, is_leaf=True) # mm_156 + buf472 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf472, (2, 32, 8192, 128), (33554432, 128, 4096, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_198 + buf473 = reader.storage(None, 2097152, 
device=device(type='cuda', index=0)) + reader.tensor(buf473, (2, 32, 8192, 1), is_leaf=True) # getitem_199 + buf474 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf474, (), dtype=torch.int64, is_leaf=True) # getitem_204 + buf475 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf475, (), dtype=torch.int64, is_leaf=True) # getitem_205 + buf476 = reader.storage(None, 469762048, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf476, (16384, 14336), dtype=torch.bfloat16, is_leaf=True) # mm_158 + buf477 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf477, (2, 8192, 4096), dtype=torch.bfloat16, is_leaf=True) # add_91 + buf478 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf478, (16384, 4096), dtype=torch.bfloat16, is_leaf=True) # mm_161 + buf479 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf479, (16384, 1024), dtype=torch.bfloat16, is_leaf=True) # mm_163 + buf480 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf480, (2, 32, 8192, 128), (33554432, 128, 4096, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_207 + buf481 = reader.storage(None, 2097152, device=device(type='cuda', index=0)) + reader.tensor(buf481, (2, 32, 8192, 1), is_leaf=True) # getitem_208 + buf482 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf482, (), dtype=torch.int64, is_leaf=True) # getitem_213 + buf483 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf483, (), dtype=torch.int64, is_leaf=True) # getitem_214 + buf484 = reader.storage(None, 469762048, 
device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf484, (16384, 14336), dtype=torch.bfloat16, is_leaf=True) # mm_165 + buf485 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf485, (2, 8192, 4096), dtype=torch.bfloat16, is_leaf=True) # add_95 + buf486 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf486, (16384, 4096), dtype=torch.bfloat16, is_leaf=True) # mm_168 + buf487 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf487, (16384, 1024), dtype=torch.bfloat16, is_leaf=True) # mm_170 + buf488 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf488, (2, 32, 8192, 128), (33554432, 128, 4096, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_216 + buf489 = reader.storage(None, 2097152, device=device(type='cuda', index=0)) + reader.tensor(buf489, (2, 32, 8192, 1), is_leaf=True) # getitem_217 + buf490 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf490, (), dtype=torch.int64, is_leaf=True) # getitem_222 + buf491 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf491, (), dtype=torch.int64, is_leaf=True) # getitem_223 + buf492 = reader.storage(None, 469762048, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf492, (16384, 14336), dtype=torch.bfloat16, is_leaf=True) # mm_172 + buf493 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf493, (2, 8192, 4096), dtype=torch.bfloat16, is_leaf=True) # add_99 + buf494 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf494, (16384, 4096), dtype=torch.bfloat16, 
is_leaf=True) # mm_175 + buf495 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf495, (16384, 1024), dtype=torch.bfloat16, is_leaf=True) # mm_177 + buf496 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf496, (2, 32, 8192, 128), (33554432, 128, 4096, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_225 + buf497 = reader.storage(None, 2097152, device=device(type='cuda', index=0)) + reader.tensor(buf497, (2, 32, 8192, 1), is_leaf=True) # getitem_226 + buf498 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf498, (), dtype=torch.int64, is_leaf=True) # getitem_231 + buf499 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf499, (), dtype=torch.int64, is_leaf=True) # getitem_232 + buf500 = reader.storage(None, 469762048, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf500, (16384, 14336), dtype=torch.bfloat16, is_leaf=True) # mm_179 + buf501 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf501, (2, 8192, 4096), dtype=torch.bfloat16, is_leaf=True) # add_103 + buf502 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf502, (16384, 4096), dtype=torch.bfloat16, is_leaf=True) # mm_182 + buf503 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf503, (16384, 1024), dtype=torch.bfloat16, is_leaf=True) # mm_184 + buf504 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf504, (2, 32, 8192, 128), (33554432, 128, 4096, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_234 + buf505 = reader.storage(None, 2097152, device=device(type='cuda', 
index=0)) + reader.tensor(buf505, (2, 32, 8192, 1), is_leaf=True) # getitem_235 + buf506 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf506, (), dtype=torch.int64, is_leaf=True) # getitem_240 + buf507 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf507, (), dtype=torch.int64, is_leaf=True) # getitem_241 + buf508 = reader.storage(None, 469762048, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf508, (16384, 14336), dtype=torch.bfloat16, is_leaf=True) # mm_186 + buf509 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf509, (2, 8192, 4096), dtype=torch.bfloat16, is_leaf=True) # add_107 + buf510 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf510, (16384, 4096), dtype=torch.bfloat16, is_leaf=True) # mm_189 + buf511 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf511, (16384, 1024), dtype=torch.bfloat16, is_leaf=True) # mm_191 + buf512 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf512, (2, 32, 8192, 128), (33554432, 128, 4096, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_243 + buf513 = reader.storage(None, 2097152, device=device(type='cuda', index=0)) + reader.tensor(buf513, (2, 32, 8192, 1), is_leaf=True) # getitem_244 + buf514 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf514, (), dtype=torch.int64, is_leaf=True) # getitem_249 + buf515 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf515, (), dtype=torch.int64, is_leaf=True) # getitem_250 + buf516 = reader.storage(None, 469762048, device=device(type='cuda', index=0), 
dtype_hint=torch.bfloat16) + reader.tensor(buf516, (16384, 14336), dtype=torch.bfloat16, is_leaf=True) # mm_193 + buf517 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf517, (2, 8192, 4096), dtype=torch.bfloat16, is_leaf=True) # add_111 + buf518 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf518, (16384, 4096), dtype=torch.bfloat16, is_leaf=True) # mm_196 + buf519 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf519, (16384, 1024), dtype=torch.bfloat16, is_leaf=True) # mm_198 + buf520 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf520, (2, 32, 8192, 128), (33554432, 128, 4096, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_252 + buf521 = reader.storage(None, 2097152, device=device(type='cuda', index=0)) + reader.tensor(buf521, (2, 32, 8192, 1), is_leaf=True) # getitem_253 + buf522 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf522, (), dtype=torch.int64, is_leaf=True) # getitem_258 + buf523 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf523, (), dtype=torch.int64, is_leaf=True) # getitem_259 + buf524 = reader.storage(None, 469762048, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf524, (16384, 14336), dtype=torch.bfloat16, is_leaf=True) # mm_200 + buf525 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf525, (2, 8192, 4096), dtype=torch.bfloat16, is_leaf=True) # add_115 + buf526 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf526, (16384, 4096), dtype=torch.bfloat16, is_leaf=True) # mm_203 + buf527 = 
reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf527, (16384, 1024), dtype=torch.bfloat16, is_leaf=True) # mm_205 + buf528 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf528, (2, 32, 8192, 128), (33554432, 128, 4096, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_261 + buf529 = reader.storage(None, 2097152, device=device(type='cuda', index=0)) + reader.tensor(buf529, (2, 32, 8192, 1), is_leaf=True) # getitem_262 + buf530 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf530, (), dtype=torch.int64, is_leaf=True) # getitem_267 + buf531 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf531, (), dtype=torch.int64, is_leaf=True) # getitem_268 + buf532 = reader.storage(None, 469762048, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf532, (16384, 14336), dtype=torch.bfloat16, is_leaf=True) # mm_207 + buf533 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf533, (2, 8192, 4096), dtype=torch.bfloat16, is_leaf=True) # add_119 + buf534 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf534, (16384, 4096), dtype=torch.bfloat16, is_leaf=True) # mm_210 + buf535 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf535, (16384, 1024), dtype=torch.bfloat16, is_leaf=True) # mm_212 + buf536 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf536, (2, 32, 8192, 128), (33554432, 128, 4096, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_270 + buf537 = reader.storage(None, 2097152, device=device(type='cuda', index=0)) + reader.tensor(buf537, (2, 32, 
8192, 1), is_leaf=True) # getitem_271 + buf538 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf538, (), dtype=torch.int64, is_leaf=True) # getitem_276 + buf539 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf539, (), dtype=torch.int64, is_leaf=True) # getitem_277 + buf540 = reader.storage(None, 469762048, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf540, (16384, 14336), dtype=torch.bfloat16, is_leaf=True) # mm_214 + buf541 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf541, (2, 8192, 4096), dtype=torch.bfloat16, is_leaf=True) # add_123 + buf542 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf542, (16384, 4096), dtype=torch.bfloat16, is_leaf=True) # mm_217 + buf543 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf543, (16384, 1024), dtype=torch.bfloat16, is_leaf=True) # mm_219 + buf544 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf544, (2, 32, 8192, 128), (33554432, 128, 4096, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_279 + buf545 = reader.storage(None, 2097152, device=device(type='cuda', index=0)) + reader.tensor(buf545, (2, 32, 8192, 1), is_leaf=True) # getitem_280 + buf546 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf546, (), dtype=torch.int64, is_leaf=True) # getitem_285 + buf547 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf547, (), dtype=torch.int64, is_leaf=True) # getitem_286 + buf548 = reader.storage(None, 469762048, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf548, 
(16384, 14336), dtype=torch.bfloat16, is_leaf=True) # mm_221 + buf549 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf549, (16384, 4096), dtype=torch.bfloat16, is_leaf=True) # mm_223 + buf550 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf550, (2, 8192, 1), is_leaf=True) # rsqrt_64 + buf551 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf551, (16384, 4096), dtype=torch.bfloat16, is_leaf=True) # view_1091 + buf552 = reader.storage(None, 4202692608, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf552, (2, 8192, 128256), dtype=torch.bfloat16, is_leaf=True) # tangents_1 + +load_args._version = 0 + +def get_mesh_sizes(): + return 64, + +def get_colls_estimations_file(): + return "colls8_8.table" + +def get_pg_names(): + return "0", diff --git a/autoparallel/tools/overlap_simulator/repro_llama3_8b_bw_64_2d_32layers.py b/autoparallel/tools/overlap_simulator/repro_llama3_8b_bw_64_2d_32layers.py new file mode 100644 index 00000000..ba7a2a32 --- /dev/null +++ b/autoparallel/tools/overlap_simulator/repro_llama3_8b_bw_64_2d_32layers.py @@ -0,0 +1,5783 @@ +import torch +from torch.nn import * +from torch import tensor, device + + +class Repro(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + + def forward(self, primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8, primals_9, primals_10, primals_11, primals_12, primals_13, primals_14, primals_15, primals_16, primals_17, primals_18, primals_19, primals_20, primals_21, primals_22, primals_23, primals_24, primals_25, primals_26, primals_27, primals_28, primals_29, primals_30, primals_31, primals_32, primals_33, primals_34, primals_35, primals_36, primals_37, primals_38, primals_39, primals_40, primals_41, primals_42, primals_43, primals_44, primals_45, primals_46, 
primals_47, primals_48, primals_49, primals_50, primals_51, primals_52, primals_53, primals_54, primals_55, primals_56, primals_57, primals_58, primals_59, primals_60, primals_61, primals_62, primals_63, primals_64, primals_65, primals_66, primals_67, primals_68, primals_69, primals_70, primals_71, primals_72, primals_73, primals_74, primals_75, primals_76, primals_77, primals_78, primals_79, primals_80, primals_81, primals_82, primals_83, primals_84, primals_85, primals_86, primals_87, primals_88, primals_89, primals_90, primals_91, primals_92, primals_93, primals_94, primals_95, primals_96, primals_97, primals_98, primals_99, primals_100, primals_101, primals_102, primals_103, primals_104, primals_105, primals_106, primals_107, primals_108, primals_109, primals_110, primals_111, primals_112, primals_113, primals_114, primals_115, primals_116, primals_117, primals_118, primals_119, primals_120, primals_121, primals_122, primals_123, primals_124, primals_125, primals_126, primals_127, primals_128, primals_129, primals_130, primals_131, primals_132, primals_133, primals_134, primals_135, primals_136, primals_137, primals_138, primals_139, primals_140, primals_141, primals_142, primals_143, primals_144, primals_145, primals_146, primals_147, primals_148, primals_149, wait_tensor_1, mm, mm_2, getitem_80, getitem_81, getitem_86, getitem_87, reduce_scatter_tensor_1, mm_4, add_3, mm_7, mm_9, getitem_121, getitem_122, getitem_127, getitem_128, reduce_scatter_tensor_3, mm_11, add_7, mm_14, mm_16, getitem_162, getitem_163, getitem_168, getitem_169, reduce_scatter_tensor_5, mm_18, add_11, mm_21, mm_23, getitem_203, getitem_204, getitem_209, getitem_210, reduce_scatter_tensor_7, mm_25, add_15, mm_28, mm_30, getitem_244, getitem_245, getitem_250, getitem_251, reduce_scatter_tensor_9, mm_32, add_19, mm_35, mm_37, getitem_285, getitem_286, getitem_291, getitem_292, reduce_scatter_tensor_11, mm_39, add_23, mm_42, mm_44, getitem_326, getitem_327, getitem_332, getitem_333, 
reduce_scatter_tensor_13, mm_46, add_27, mm_49, mm_51, getitem_367, getitem_368, getitem_373, getitem_374, reduce_scatter_tensor_15, mm_53, add_31, mm_56, mm_58, getitem_408, getitem_409, getitem_414, getitem_415, reduce_scatter_tensor_17, mm_60, add_35, mm_63, mm_65, getitem_449, getitem_450, getitem_455, getitem_456, reduce_scatter_tensor_19, mm_67, add_39, mm_70, mm_72, getitem_490, getitem_491, getitem_496, getitem_497, reduce_scatter_tensor_21, mm_74, add_43, mm_77, mm_79, getitem_531, getitem_532, getitem_537, getitem_538, reduce_scatter_tensor_23, mm_81, add_47, mm_84, mm_86, getitem_572, getitem_573, getitem_578, getitem_579, reduce_scatter_tensor_25, mm_88, add_51, mm_91, mm_93, getitem_613, getitem_614, getitem_619, getitem_620, reduce_scatter_tensor_27, mm_95, add_55, mm_98, mm_100, getitem_654, getitem_655, getitem_660, getitem_661, reduce_scatter_tensor_29, mm_102, add_59, mm_105, mm_107, getitem_695, getitem_696, getitem_701, getitem_702, reduce_scatter_tensor_31, mm_109, reduce_scatter_tensor_32, rsqrt_32, view_1167, tangents_1): + view_1169 = torch.ops.aten.view.default(tangents_1, [16384, 16032]); tangents_1 = None + permute_177 = torch.ops.aten.permute.default(view_1169, [1, 0]) + mm_113 = torch.ops.aten.mm.default(permute_177, view_1167); permute_177 = view_1167 = None + convert_element_type_532 = torch.ops.prims.convert_element_type.default(primals_149, torch.bfloat16); primals_149 = None + all_gather_into_tensor_179 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_532, 8, '0'); convert_element_type_532 = None + wait_tensor_212 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_179); all_gather_into_tensor_179 = None + permute_176 = torch.ops.aten.permute.default(wait_tensor_212, [1, 0]); wait_tensor_212 = None + permute_179 = torch.ops.aten.permute.default(permute_176, [1, 0]); permute_176 = None + mm_114 = torch.ops.aten.mm.default(view_1169, permute_179); view_1169 = permute_179 = None + 
view_1170 = torch.ops.aten.view.default(mm_114, [2, 8192, 4096]); mm_114 = None + convert_element_type_539 = torch.ops.prims.convert_element_type.default(mm_113, torch.float32); mm_113 = None + reduce_scatter_tensor_33 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_539, 'avg', 8, '0'); convert_element_type_539 = None + wait_tensor_213 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_33); reduce_scatter_tensor_33 = None + split_74 = torch.ops.aten.split.Tensor(view_1170, 1024, 1); view_1170 = None + getitem_736 = split_74[0] + getitem_737 = split_74[1] + getitem_738 = split_74[2] + getitem_739 = split_74[3] + getitem_740 = split_74[4] + getitem_741 = split_74[5] + getitem_742 = split_74[6] + getitem_743 = split_74[7]; split_74 = None + cat_66 = torch.ops.aten.cat.default([getitem_736, getitem_737, getitem_738, getitem_739, getitem_740, getitem_741, getitem_742, getitem_743]); getitem_736 = getitem_737 = getitem_738 = getitem_739 = getitem_740 = getitem_741 = getitem_742 = getitem_743 = None + reduce_scatter_tensor_34 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_66, 'sum', 8, '1'); cat_66 = None + wait_tensor_214 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_34); reduce_scatter_tensor_34 = None + convert_element_type_540 = torch.ops.prims.convert_element_type.default(wait_tensor_214, torch.float32); wait_tensor_214 = None + convert_element_type_529 = torch.ops.prims.convert_element_type.default(primals_148, torch.bfloat16); primals_148 = None + all_gather_into_tensor_177 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_529, 8, '0'); convert_element_type_529 = None + wait_tensor_210 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_177); all_gather_into_tensor_177 = None + convert_element_type_542 = torch.ops.prims.convert_element_type.default(wait_tensor_210, torch.float32); wait_tensor_210 = None + mul_130 = 
torch.ops.aten.mul.Tensor(convert_element_type_540, convert_element_type_542); convert_element_type_542 = None + wait_tensor_203 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_31); reduce_scatter_tensor_31 = None + add_61 = torch.ops.aten.add.Tensor(add_59, wait_tensor_203); wait_tensor_203 = None + wait_tensor_209 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_32); reduce_scatter_tensor_32 = None + add_63 = torch.ops.aten.add.Tensor(add_61, wait_tensor_209); wait_tensor_209 = None + convert_element_type_530 = torch.ops.prims.convert_element_type.default(add_63, torch.float32); add_63 = None + mul_128 = torch.ops.aten.mul.Tensor(convert_element_type_530, rsqrt_32); convert_element_type_530 = None + mul_132 = torch.ops.aten.mul.Tensor(mul_128, mul_130) + sum_1 = torch.ops.aten.sum.dim_IntList(mul_132, [2], True); mul_132 = None + div = torch.ops.aten.div.Tensor(mul_128, 4096) + mul_133 = torch.ops.aten.mul.Tensor(div, sum_1); div = sum_1 = None + sub_1 = torch.ops.aten.sub.Tensor(mul_130, mul_133); mul_130 = mul_133 = None + mul_134 = torch.ops.aten.mul.Tensor(sub_1, rsqrt_32); sub_1 = rsqrt_32 = None + mul_135 = torch.ops.aten.mul.Tensor(convert_element_type_540, mul_128); convert_element_type_540 = mul_128 = None + sum_2 = torch.ops.aten.sum.dim_IntList(mul_135, [0, 1]); mul_135 = None + convert_element_type_543 = torch.ops.prims.convert_element_type.default(mul_134, torch.bfloat16); mul_134 = None + convert_element_type_544 = torch.ops.prims.convert_element_type.default(sum_2, torch.bfloat16); sum_2 = None + all_reduce = torch.ops._c10d_functional.all_reduce.default(convert_element_type_544, 'sum', '1'); convert_element_type_544 = None + wait_tensor_215 = torch.ops._c10d_functional.wait_tensor.default(all_reduce); all_reduce = None + convert_element_type_545 = torch.ops.prims.convert_element_type.default(wait_tensor_215, torch.float32); wait_tensor_215 = None + reduce_scatter_tensor_35 = 
torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_545, 'avg', 8, '0'); convert_element_type_545 = None + wait_tensor_216 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_35); reduce_scatter_tensor_35 = None + all_gather_into_tensor_180 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_543, 8, '1') + wait_tensor_217 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_180); all_gather_into_tensor_180 = None + split_75 = torch.ops.aten.split.Tensor(wait_tensor_217, 2); wait_tensor_217 = None + getitem_744 = split_75[0] + getitem_745 = split_75[1] + getitem_746 = split_75[2] + getitem_747 = split_75[3] + getitem_748 = split_75[4] + getitem_749 = split_75[5] + getitem_750 = split_75[6] + getitem_751 = split_75[7]; split_75 = None + cat_67 = torch.ops.aten.cat.default([getitem_744, getitem_745, getitem_746, getitem_747, getitem_748, getitem_749, getitem_750, getitem_751], 1); getitem_744 = getitem_745 = getitem_746 = getitem_747 = getitem_748 = getitem_749 = getitem_750 = getitem_751 = None + view_1171 = torch.ops.aten.view.default(cat_67, [16384, 4096]); cat_67 = None + permute_181 = torch.ops.aten.permute.default(view_1171, [1, 0]) + convert_element_type_515 = torch.ops.prims.convert_element_type.default(primals_144, torch.bfloat16); primals_144 = None + all_gather_into_tensor_172 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_515, 8, '0'); convert_element_type_515 = None + wait_tensor_204 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_172); all_gather_into_tensor_172 = None + convert_element_type_516 = torch.ops.prims.convert_element_type.default(add_61, torch.float32); add_61 = None + pow_32 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_516, 2) + mean_31 = torch.ops.aten.mean.dim(pow_32, [2], True); pow_32 = None + add_62 = torch.ops.aten.add.Scalar(mean_31, 1e-05); mean_31 = None + rsqrt_31 = 
torch.ops.aten.rsqrt.default(add_62); add_62 = None + mul_124 = torch.ops.aten.mul.Tensor(convert_element_type_516, rsqrt_31); convert_element_type_516 = None + mul_125 = torch.ops.aten.mul.Tensor(mul_124, wait_tensor_204) + convert_element_type_517 = torch.ops.prims.convert_element_type.default(mul_125, torch.bfloat16); mul_125 = None + all_gather_into_tensor_173 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_517, 8, '1'); convert_element_type_517 = None + wait_tensor_205 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_173); all_gather_into_tensor_173 = None + split_71 = torch.ops.aten.split.Tensor(wait_tensor_205, 2); wait_tensor_205 = None + getitem_712 = split_71[0] + getitem_713 = split_71[1] + getitem_714 = split_71[2] + getitem_715 = split_71[3] + getitem_716 = split_71[4] + getitem_717 = split_71[5] + getitem_718 = split_71[6] + getitem_719 = split_71[7]; split_71 = None + cat_63 = torch.ops.aten.cat.default([getitem_712, getitem_713, getitem_714, getitem_715, getitem_716, getitem_717, getitem_718, getitem_719], 1); getitem_712 = getitem_713 = getitem_714 = getitem_715 = getitem_716 = getitem_717 = getitem_718 = getitem_719 = None + view_1140 = torch.ops.aten.view.default(cat_63, [16384, 4096]); cat_63 = None + view_1141 = torch.ops.aten.view.default(mm_109, [2, 8192, 1792]); mm_109 = None + convert_element_type_521 = torch.ops.prims.convert_element_type.default(view_1141, torch.float32); view_1141 = None + sigmoid_15 = torch.ops.aten.sigmoid.default(convert_element_type_521) + mul_126 = torch.ops.aten.mul.Tensor(convert_element_type_521, sigmoid_15); sigmoid_15 = None + convert_element_type_522 = torch.ops.prims.convert_element_type.default(mul_126, torch.bfloat16); mul_126 = None + convert_element_type_523 = torch.ops.prims.convert_element_type.default(primals_146, torch.bfloat16); primals_146 = None + all_gather_into_tensor_175 = 
torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_523, 8, '0'); convert_element_type_523 = None + wait_tensor_207 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_175); all_gather_into_tensor_175 = None + permute_174 = torch.ops.aten.permute.default(wait_tensor_207, [1, 0]); wait_tensor_207 = None + mm_110 = torch.ops.aten.mm.default(view_1140, permute_174) + view_1148 = torch.ops.aten.view.default(mm_110, [2, 8192, 1792]); mm_110 = None + mul_127 = torch.ops.aten.mul.Tensor(convert_element_type_522, view_1148) + view_1155 = torch.ops.aten.view.default(mul_127, [16384, 1792]); mul_127 = None + mm_115 = torch.ops.aten.mm.default(permute_181, view_1155); permute_181 = view_1155 = None + convert_element_type_526 = torch.ops.prims.convert_element_type.default(primals_147, torch.bfloat16); primals_147 = None + all_gather_into_tensor_176 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_526, 8, '0'); convert_element_type_526 = None + wait_tensor_208 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_176); all_gather_into_tensor_176 = None + permute_175 = torch.ops.aten.permute.default(wait_tensor_208, [1, 0]); wait_tensor_208 = None + permute_183 = torch.ops.aten.permute.default(permute_175, [1, 0]); permute_175 = None + mm_116 = torch.ops.aten.mm.default(view_1171, permute_183); view_1171 = permute_183 = None + view_1172 = torch.ops.aten.view.default(mm_116, [2, 8192, 1792]); mm_116 = None + convert_element_type_550 = torch.ops.prims.convert_element_type.default(mm_115, torch.float32); mm_115 = None + reduce_scatter_tensor_36 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_550, 'avg', 8, '0'); convert_element_type_550 = None + wait_tensor_218 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_36); reduce_scatter_tensor_36 = None + mul_136 = torch.ops.aten.mul.Tensor(view_1172, convert_element_type_522); 
convert_element_type_522 = None + mul_137 = torch.ops.aten.mul.Tensor(view_1172, view_1148); view_1172 = view_1148 = None + view_1173 = torch.ops.aten.view.default(mul_136, [16384, 1792]); mul_136 = None + permute_185 = torch.ops.aten.permute.default(view_1173, [1, 0]) + mm_117 = torch.ops.aten.mm.default(permute_185, view_1140); permute_185 = None + permute_187 = torch.ops.aten.permute.default(permute_174, [1, 0]); permute_174 = None + mm_118 = torch.ops.aten.mm.default(view_1173, permute_187); view_1173 = permute_187 = None + view_1174 = torch.ops.aten.view.default(mm_118, [2, 8192, 4096]); mm_118 = None + convert_element_type_555 = torch.ops.prims.convert_element_type.default(mm_117, torch.float32); mm_117 = None + reduce_scatter_tensor_37 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_555, 'avg', 8, '0'); convert_element_type_555 = None + wait_tensor_219 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_37); reduce_scatter_tensor_37 = None + convert_element_type_556 = torch.ops.prims.convert_element_type.default(mul_137, torch.float32); mul_137 = None + neg = torch.ops.aten.neg.default(convert_element_type_521) + exp = torch.ops.aten.exp.default(neg); neg = None + add_65 = torch.ops.aten.add.Tensor(exp, 1); exp = None + reciprocal = torch.ops.aten.reciprocal.default(add_65); add_65 = None + mul_138 = torch.ops.aten.mul.Tensor(reciprocal, 1); reciprocal = None + mul_139 = torch.ops.aten.mul.Tensor(convert_element_type_556, mul_138); convert_element_type_556 = None + sub_2 = torch.ops.aten.sub.Tensor(1, mul_138); mul_138 = None + mul_140 = torch.ops.aten.mul.Tensor(convert_element_type_521, sub_2); convert_element_type_521 = sub_2 = None + add_66 = torch.ops.aten.add.Tensor(mul_140, 1); mul_140 = None + mul_141 = torch.ops.aten.mul.Tensor(mul_139, add_66); mul_139 = add_66 = None + convert_element_type_558 = torch.ops.prims.convert_element_type.default(mul_141, torch.bfloat16); mul_141 = None + view_1175 = 
torch.ops.aten.view.default(convert_element_type_558, [16384, 1792]); convert_element_type_558 = None + permute_189 = torch.ops.aten.permute.default(view_1175, [1, 0]) + mm_119 = torch.ops.aten.mm.default(permute_189, view_1140); permute_189 = view_1140 = None + convert_element_type_518 = torch.ops.prims.convert_element_type.default(primals_145, torch.bfloat16); primals_145 = None + all_gather_into_tensor_174 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_518, 8, '0'); convert_element_type_518 = None + wait_tensor_206 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_174); all_gather_into_tensor_174 = None + permute_173 = torch.ops.aten.permute.default(wait_tensor_206, [1, 0]); wait_tensor_206 = None + permute_191 = torch.ops.aten.permute.default(permute_173, [1, 0]); permute_173 = None + mm_120 = torch.ops.aten.mm.default(view_1175, permute_191); view_1175 = permute_191 = None + view_1176 = torch.ops.aten.view.default(mm_120, [2, 8192, 4096]); mm_120 = None + add_67 = torch.ops.aten.add.Tensor(view_1174, view_1176); view_1174 = view_1176 = None + convert_element_type_563 = torch.ops.prims.convert_element_type.default(mm_119, torch.float32); mm_119 = None + reduce_scatter_tensor_38 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_563, 'avg', 8, '0'); convert_element_type_563 = None + wait_tensor_220 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_38); reduce_scatter_tensor_38 = None + split_76 = torch.ops.aten.split.Tensor(add_67, 1024, 1); add_67 = None + getitem_752 = split_76[0] + getitem_753 = split_76[1] + getitem_754 = split_76[2] + getitem_755 = split_76[3] + getitem_756 = split_76[4] + getitem_757 = split_76[5] + getitem_758 = split_76[6] + getitem_759 = split_76[7]; split_76 = None + cat_68 = torch.ops.aten.cat.default([getitem_752, getitem_753, getitem_754, getitem_755, getitem_756, getitem_757, getitem_758, getitem_759]); getitem_752 = 
getitem_753 = getitem_754 = getitem_755 = getitem_756 = getitem_757 = getitem_758 = getitem_759 = None + reduce_scatter_tensor_39 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_68, 'sum', 8, '1'); cat_68 = None + wait_tensor_221 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_39); reduce_scatter_tensor_39 = None + convert_element_type_564 = torch.ops.prims.convert_element_type.default(wait_tensor_221, torch.float32); wait_tensor_221 = None + convert_element_type_566 = torch.ops.prims.convert_element_type.default(wait_tensor_204, torch.float32); wait_tensor_204 = None + mul_142 = torch.ops.aten.mul.Tensor(convert_element_type_564, convert_element_type_566); convert_element_type_566 = None + mul_144 = torch.ops.aten.mul.Tensor(mul_124, mul_142) + sum_3 = torch.ops.aten.sum.dim_IntList(mul_144, [2], True); mul_144 = None + div_1 = torch.ops.aten.div.Tensor(mul_124, 4096) + mul_145 = torch.ops.aten.mul.Tensor(div_1, sum_3); div_1 = sum_3 = None + sub_3 = torch.ops.aten.sub.Tensor(mul_142, mul_145); mul_142 = mul_145 = None + mul_146 = torch.ops.aten.mul.Tensor(sub_3, rsqrt_31); sub_3 = rsqrt_31 = None + mul_147 = torch.ops.aten.mul.Tensor(convert_element_type_564, mul_124); convert_element_type_564 = mul_124 = None + sum_4 = torch.ops.aten.sum.dim_IntList(mul_147, [0, 1]); mul_147 = None + convert_element_type_567 = torch.ops.prims.convert_element_type.default(mul_146, torch.bfloat16); mul_146 = None + convert_element_type_568 = torch.ops.prims.convert_element_type.default(sum_4, torch.bfloat16); sum_4 = None + all_reduce_1 = torch.ops._c10d_functional.all_reduce.default(convert_element_type_568, 'sum', '1'); convert_element_type_568 = None + wait_tensor_222 = torch.ops._c10d_functional.wait_tensor.default(all_reduce_1); all_reduce_1 = None + convert_element_type_569 = torch.ops.prims.convert_element_type.default(wait_tensor_222, torch.float32); wait_tensor_222 = None + reduce_scatter_tensor_40 = 
torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_569, 'avg', 8, '0'); convert_element_type_569 = None + wait_tensor_223 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_40); reduce_scatter_tensor_40 = None + add_68 = torch.ops.aten.add.Tensor(convert_element_type_543, convert_element_type_567); convert_element_type_543 = convert_element_type_567 = None + all_gather_into_tensor_181 = torch.ops._c10d_functional.all_gather_into_tensor.default(add_68, 8, '1') + wait_tensor_224 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_181); all_gather_into_tensor_181 = None + split_77 = torch.ops.aten.split.Tensor(wait_tensor_224, 2); wait_tensor_224 = None + getitem_760 = split_77[0] + getitem_761 = split_77[1] + getitem_762 = split_77[2] + getitem_763 = split_77[3] + getitem_764 = split_77[4] + getitem_765 = split_77[5] + getitem_766 = split_77[6] + getitem_767 = split_77[7]; split_77 = None + cat_69 = torch.ops.aten.cat.default([getitem_760, getitem_761, getitem_762, getitem_763, getitem_764, getitem_765, getitem_766, getitem_767], 1); getitem_760 = getitem_761 = getitem_762 = getitem_763 = getitem_764 = getitem_765 = getitem_766 = getitem_767 = None + view_1177 = torch.ops.aten.view.default(cat_69, [16384, 4096]); cat_69 = None + permute_193 = torch.ops.aten.permute.default(view_1177, [1, 0]) + permute_171 = torch.ops.aten.permute.default(getitem_695, [0, 2, 1, 3]) + view_1122 = torch.ops.aten.view.default(permute_171, [2, 8192, -1]); permute_171 = None + view_1128 = torch.ops.aten.view.default(view_1122, [16384, 512]); view_1122 = None + mm_121 = torch.ops.aten.mm.default(permute_193, view_1128); permute_193 = view_1128 = None + convert_element_type_512 = torch.ops.prims.convert_element_type.default(primals_143, torch.bfloat16); primals_143 = None + all_gather_into_tensor_171 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_512, 8, '0'); convert_element_type_512 = None 
+ wait_tensor_202 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_171); all_gather_into_tensor_171 = None + permute_172 = torch.ops.aten.permute.default(wait_tensor_202, [1, 0]); wait_tensor_202 = None + permute_195 = torch.ops.aten.permute.default(permute_172, [1, 0]); permute_172 = None + mm_122 = torch.ops.aten.mm.default(view_1177, permute_195); view_1177 = permute_195 = None + view_1178 = torch.ops.aten.view.default(mm_122, [2, 8192, 512]); mm_122 = None + convert_element_type_574 = torch.ops.prims.convert_element_type.default(mm_121, torch.float32); mm_121 = None + reduce_scatter_tensor_41 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_574, 'avg', 8, '0'); convert_element_type_574 = None + wait_tensor_225 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_41); reduce_scatter_tensor_41 = None + view_1179 = torch.ops.aten.view.default(view_1178, [2, 8192, 4, 128]); view_1178 = None + permute_197 = torch.ops.aten.permute.default(view_1179, [0, 2, 1, 3]); view_1179 = None + view_37 = torch.ops.aten.view.default(primals_3, [1, 8192, 1, 64]); primals_3 = None + convert_element_type_496 = torch.ops.prims.convert_element_type.default(primals_139, torch.bfloat16); primals_139 = None + all_gather_into_tensor_166 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_496, 8, '0'); convert_element_type_496 = None + wait_tensor_197 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_166); all_gather_into_tensor_166 = None + convert_element_type_497 = torch.ops.prims.convert_element_type.default(add_59, torch.float32); add_59 = None + pow_31 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_497, 2) + mean_30 = torch.ops.aten.mean.dim(pow_31, [2], True); pow_31 = None + add_60 = torch.ops.aten.add.Scalar(mean_30, 1e-05); mean_30 = None + rsqrt_30 = torch.ops.aten.rsqrt.default(add_60); add_60 = None + mul_120 = 
torch.ops.aten.mul.Tensor(convert_element_type_497, rsqrt_30); convert_element_type_497 = None + mul_121 = torch.ops.aten.mul.Tensor(mul_120, wait_tensor_197) + convert_element_type_498 = torch.ops.prims.convert_element_type.default(mul_121, torch.bfloat16); mul_121 = None + all_gather_into_tensor_167 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_498, 8, '1'); convert_element_type_498 = None + wait_tensor_198 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_167); all_gather_into_tensor_167 = None + split_69 = torch.ops.aten.split.Tensor(wait_tensor_198, 2); wait_tensor_198 = None + getitem_687 = split_69[0] + getitem_688 = split_69[1] + getitem_689 = split_69[2] + getitem_690 = split_69[3] + getitem_691 = split_69[4] + getitem_692 = split_69[5] + getitem_693 = split_69[6] + getitem_694 = split_69[7]; split_69 = None + cat_61 = torch.ops.aten.cat.default([getitem_687, getitem_688, getitem_689, getitem_690, getitem_691, getitem_692, getitem_693, getitem_694], 1); getitem_687 = getitem_688 = getitem_689 = getitem_690 = getitem_691 = getitem_692 = getitem_693 = getitem_694 = None + view_1095 = torch.ops.aten.view.default(cat_61, [16384, 4096]); cat_61 = None + view_1096 = torch.ops.aten.view.default(mm_105, [2, 8192, 512]); mm_105 = None + convert_element_type_502 = torch.ops.prims.convert_element_type.default(primals_141, torch.bfloat16); primals_141 = None + all_gather_into_tensor_169 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_502, 8, '0'); convert_element_type_502 = None + wait_tensor_200 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_169); all_gather_into_tensor_169 = None + permute_166 = torch.ops.aten.permute.default(wait_tensor_200, [1, 0]); wait_tensor_200 = None + mm_106 = torch.ops.aten.mm.default(view_1095, permute_166) + view_1103 = torch.ops.aten.view.default(mm_106, [2, 8192, 128]); mm_106 = None + view_1110 = 
torch.ops.aten.view.default(mm_107, [2, 8192, 128]); mm_107 = None + view_1112 = torch.ops.aten.view.default(view_1096, [2, 8192, -1, 128]); view_1096 = None + view_1113 = torch.ops.aten.view.default(view_1103, [2, 8192, -1, 128]); view_1103 = None + view_1114 = torch.ops.aten.view.default(view_1110, [2, 8192, -1, 128]); view_1110 = None + convert_element_type_508 = torch.ops.prims.convert_element_type.default(view_1112, torch.float32); view_1112 = None + view_1115 = torch.ops.aten.view.default(convert_element_type_508, [2, 8192, 4, -1, 2]); convert_element_type_508 = None + view_as_complex_30 = torch.ops.aten.view_as_complex.default(view_1115); view_1115 = None + convert_element_type_509 = torch.ops.prims.convert_element_type.default(view_1113, torch.float32); view_1113 = None + view_1116 = torch.ops.aten.view.default(convert_element_type_509, [2, 8192, 1, -1, 2]); convert_element_type_509 = None + view_as_complex_31 = torch.ops.aten.view_as_complex.default(view_1116); view_1116 = None + mul_122 = torch.ops.aten.mul.Tensor(view_as_complex_30, view_37); view_as_complex_30 = None + view_as_real_30 = torch.ops.aten.view_as_real.default(mul_122); mul_122 = None + view_1118 = torch.ops.aten.view.default(view_as_real_30, [2, 8192, 4, 128]); view_as_real_30 = None + mul_123 = torch.ops.aten.mul.Tensor(view_as_complex_31, view_37); view_as_complex_31 = None + view_as_real_31 = torch.ops.aten.view_as_real.default(mul_123); mul_123 = None + view_1119 = torch.ops.aten.view.default(view_as_real_31, [2, 8192, 1, 128]); view_as_real_31 = None + convert_element_type_510 = torch.ops.prims.convert_element_type.default(view_1118, torch.bfloat16); view_1118 = None + convert_element_type_511 = torch.ops.prims.convert_element_type.default(view_1119, torch.bfloat16); view_1119 = None + unsqueeze_30 = torch.ops.aten.unsqueeze.default(convert_element_type_511, 3); convert_element_type_511 = None + expand_30 = torch.ops.aten.expand.default(unsqueeze_30, [2, 8192, 1, 4, 128]); unsqueeze_30 
= None + view_1120 = torch.ops.aten.view.default(expand_30, [2, 8192, 4, 128]); expand_30 = None + unsqueeze_31 = torch.ops.aten.unsqueeze.default(view_1114, 3); view_1114 = None + expand_31 = torch.ops.aten.expand.default(unsqueeze_31, [2, 8192, 1, 4, 128]); unsqueeze_31 = None + view_1121 = torch.ops.aten.view.default(expand_31, [2, 8192, 4, 128]); expand_31 = None + permute_168 = torch.ops.aten.permute.default(convert_element_type_510, [0, 2, 1, 3]); convert_element_type_510 = None + permute_169 = torch.ops.aten.permute.default(view_1120, [0, 2, 1, 3]); view_1120 = None + permute_170 = torch.ops.aten.permute.default(view_1121, [0, 2, 1, 3]); view_1121 = None + _scaled_dot_product_cudnn_attention_backward = torch.ops.aten._scaled_dot_product_cudnn_attention_backward.default(permute_197, permute_168, permute_169, permute_170, getitem_695, getitem_696, getitem_701, getitem_702, None, None, None, 8192, 8192, 0.0, True); permute_197 = permute_168 = permute_169 = permute_170 = getitem_695 = getitem_696 = getitem_701 = getitem_702 = None + getitem_768 = _scaled_dot_product_cudnn_attention_backward[0] + getitem_769 = _scaled_dot_product_cudnn_attention_backward[1] + getitem_770 = _scaled_dot_product_cudnn_attention_backward[2]; _scaled_dot_product_cudnn_attention_backward = None + permute_198 = torch.ops.aten.permute.default(getitem_770, [0, 2, 1, 3]); getitem_770 = None + permute_199 = torch.ops.aten.permute.default(getitem_769, [0, 2, 1, 3]); getitem_769 = None + permute_200 = torch.ops.aten.permute.default(getitem_768, [0, 2, 1, 3]); getitem_768 = None + view_1180 = torch.ops.aten.view.default(permute_198, [2, 8192, 1, 4, 128]); permute_198 = None + sum_5 = torch.ops.aten.sum.dim_IntList(view_1180, [3], True); view_1180 = None + squeeze = torch.ops.aten.squeeze.dim(sum_5, 3); sum_5 = None + view_1181 = torch.ops.aten.view.default(permute_199, [2, 8192, 1, 4, 128]); permute_199 = None + sum_6 = torch.ops.aten.sum.dim_IntList(view_1181, [3], True); view_1181 = None + 
squeeze_1 = torch.ops.aten.squeeze.dim(sum_6, 3); sum_6 = None + convert_element_type_575 = torch.ops.prims.convert_element_type.default(squeeze_1, torch.float32); squeeze_1 = None + convert_element_type_576 = torch.ops.prims.convert_element_type.default(permute_200, torch.float32); permute_200 = None + view_1182 = torch.ops.aten.view.default(convert_element_type_575, [2, 8192, 1, 64, 2]); convert_element_type_575 = None + view_as_complex_32 = torch.ops.aten.view_as_complex.default(view_1182); view_1182 = None + _conj = torch.ops.aten._conj.default(view_37) + mul_148 = torch.ops.aten.mul.Tensor(view_as_complex_32, _conj); view_as_complex_32 = None + view_1183 = torch.ops.aten.view.default(convert_element_type_576, [2, 8192, 4, 64, 2]); convert_element_type_576 = None + view_as_complex_33 = torch.ops.aten.view_as_complex.default(view_1183); view_1183 = None + mul_149 = torch.ops.aten.mul.Tensor(view_as_complex_33, _conj); view_as_complex_33 = None + view_as_real_32 = torch.ops.aten.view_as_real.default(mul_148); mul_148 = None + view_1184 = torch.ops.aten.view.default(view_as_real_32, [2, 8192, 1, 128]); view_as_real_32 = None + convert_element_type_577 = torch.ops.prims.convert_element_type.default(view_1184, torch.bfloat16); view_1184 = None + view_as_real_33 = torch.ops.aten.view_as_real.default(mul_149); mul_149 = None + view_1185 = torch.ops.aten.view.default(view_as_real_33, [2, 8192, 4, 128]); view_as_real_33 = None + convert_element_type_578 = torch.ops.prims.convert_element_type.default(view_1185, torch.bfloat16); view_1185 = None + view_1186 = torch.ops.aten.view.default(squeeze, [2, 8192, 128]); squeeze = None + view_1187 = torch.ops.aten.view.default(convert_element_type_577, [2, 8192, 128]); convert_element_type_577 = None + view_1188 = torch.ops.aten.view.default(convert_element_type_578, [2, 8192, 512]); convert_element_type_578 = None + view_1189 = torch.ops.aten.view.default(view_1186, [16384, 128]); view_1186 = None + permute_201 = 
torch.ops.aten.permute.default(view_1189, [1, 0]) + mm_123 = torch.ops.aten.mm.default(permute_201, view_1095); permute_201 = None + convert_element_type_505 = torch.ops.prims.convert_element_type.default(primals_142, torch.bfloat16); primals_142 = None + all_gather_into_tensor_170 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_505, 8, '0'); convert_element_type_505 = None + wait_tensor_201 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_170); all_gather_into_tensor_170 = None + permute_167 = torch.ops.aten.permute.default(wait_tensor_201, [1, 0]); wait_tensor_201 = None + permute_203 = torch.ops.aten.permute.default(permute_167, [1, 0]); permute_167 = None + mm_124 = torch.ops.aten.mm.default(view_1189, permute_203); view_1189 = permute_203 = None + view_1190 = torch.ops.aten.view.default(mm_124, [2, 8192, 4096]); mm_124 = None + convert_element_type_583 = torch.ops.prims.convert_element_type.default(mm_123, torch.float32); mm_123 = None + reduce_scatter_tensor_42 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_583, 'avg', 8, '0'); convert_element_type_583 = None + wait_tensor_226 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_42); reduce_scatter_tensor_42 = None + view_1191 = torch.ops.aten.view.default(view_1187, [16384, 128]); view_1187 = None + permute_205 = torch.ops.aten.permute.default(view_1191, [1, 0]) + mm_125 = torch.ops.aten.mm.default(permute_205, view_1095); permute_205 = None + permute_207 = torch.ops.aten.permute.default(permute_166, [1, 0]); permute_166 = None + mm_126 = torch.ops.aten.mm.default(view_1191, permute_207); view_1191 = permute_207 = None + view_1192 = torch.ops.aten.view.default(mm_126, [2, 8192, 4096]); mm_126 = None + add_69 = torch.ops.aten.add.Tensor(view_1190, view_1192); view_1190 = view_1192 = None + convert_element_type_588 = torch.ops.prims.convert_element_type.default(mm_125, torch.float32); mm_125 = None + 
reduce_scatter_tensor_43 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_588, 'avg', 8, '0'); convert_element_type_588 = None + wait_tensor_227 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_43); reduce_scatter_tensor_43 = None + view_1193 = torch.ops.aten.view.default(view_1188, [16384, 512]); view_1188 = None + permute_209 = torch.ops.aten.permute.default(view_1193, [1, 0]) + mm_127 = torch.ops.aten.mm.default(permute_209, view_1095); permute_209 = view_1095 = None + convert_element_type_499 = torch.ops.prims.convert_element_type.default(primals_140, torch.bfloat16); primals_140 = None + all_gather_into_tensor_168 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_499, 8, '0'); convert_element_type_499 = None + wait_tensor_199 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_168); all_gather_into_tensor_168 = None + permute_165 = torch.ops.aten.permute.default(wait_tensor_199, [1, 0]); wait_tensor_199 = None + permute_211 = torch.ops.aten.permute.default(permute_165, [1, 0]); permute_165 = None + mm_128 = torch.ops.aten.mm.default(view_1193, permute_211); view_1193 = permute_211 = None + view_1194 = torch.ops.aten.view.default(mm_128, [2, 8192, 4096]); mm_128 = None + add_70 = torch.ops.aten.add.Tensor(add_69, view_1194); add_69 = view_1194 = None + convert_element_type_593 = torch.ops.prims.convert_element_type.default(mm_127, torch.float32); mm_127 = None + reduce_scatter_tensor_44 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_593, 'avg', 8, '0'); convert_element_type_593 = None + wait_tensor_228 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_44); reduce_scatter_tensor_44 = None + split_78 = torch.ops.aten.split.Tensor(add_70, 1024, 1); add_70 = None + getitem_771 = split_78[0] + getitem_772 = split_78[1] + getitem_773 = split_78[2] + getitem_774 = split_78[3] + getitem_775 = split_78[4] + 
getitem_776 = split_78[5] + getitem_777 = split_78[6] + getitem_778 = split_78[7]; split_78 = None + cat_70 = torch.ops.aten.cat.default([getitem_771, getitem_772, getitem_773, getitem_774, getitem_775, getitem_776, getitem_777, getitem_778]); getitem_771 = getitem_772 = getitem_773 = getitem_774 = getitem_775 = getitem_776 = getitem_777 = getitem_778 = None + reduce_scatter_tensor_45 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_70, 'sum', 8, '1'); cat_70 = None + wait_tensor_229 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_45); reduce_scatter_tensor_45 = None + convert_element_type_594 = torch.ops.prims.convert_element_type.default(wait_tensor_229, torch.float32); wait_tensor_229 = None + convert_element_type_596 = torch.ops.prims.convert_element_type.default(wait_tensor_197, torch.float32); wait_tensor_197 = None + mul_150 = torch.ops.aten.mul.Tensor(convert_element_type_594, convert_element_type_596); convert_element_type_596 = None + mul_152 = torch.ops.aten.mul.Tensor(mul_120, mul_150) + sum_7 = torch.ops.aten.sum.dim_IntList(mul_152, [2], True); mul_152 = None + div_2 = torch.ops.aten.div.Tensor(mul_120, 4096) + mul_153 = torch.ops.aten.mul.Tensor(div_2, sum_7); div_2 = sum_7 = None + sub_4 = torch.ops.aten.sub.Tensor(mul_150, mul_153); mul_150 = mul_153 = None + mul_154 = torch.ops.aten.mul.Tensor(sub_4, rsqrt_30); sub_4 = rsqrt_30 = None + mul_155 = torch.ops.aten.mul.Tensor(convert_element_type_594, mul_120); convert_element_type_594 = mul_120 = None + sum_8 = torch.ops.aten.sum.dim_IntList(mul_155, [0, 1]); mul_155 = None + convert_element_type_597 = torch.ops.prims.convert_element_type.default(mul_154, torch.bfloat16); mul_154 = None + convert_element_type_598 = torch.ops.prims.convert_element_type.default(sum_8, torch.bfloat16); sum_8 = None + all_reduce_2 = torch.ops._c10d_functional.all_reduce.default(convert_element_type_598, 'sum', '1'); convert_element_type_598 = None + wait_tensor_230 = 
torch.ops._c10d_functional.wait_tensor.default(all_reduce_2); all_reduce_2 = None + convert_element_type_599 = torch.ops.prims.convert_element_type.default(wait_tensor_230, torch.float32); wait_tensor_230 = None + reduce_scatter_tensor_46 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_599, 'avg', 8, '0'); convert_element_type_599 = None + wait_tensor_231 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_46); reduce_scatter_tensor_46 = None + add_71 = torch.ops.aten.add.Tensor(add_68, convert_element_type_597); add_68 = convert_element_type_597 = None + all_gather_into_tensor_182 = torch.ops._c10d_functional.all_gather_into_tensor.default(add_71, 8, '1') + wait_tensor_232 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_182); all_gather_into_tensor_182 = None + split_79 = torch.ops.aten.split.Tensor(wait_tensor_232, 2); wait_tensor_232 = None + getitem_779 = split_79[0] + getitem_780 = split_79[1] + getitem_781 = split_79[2] + getitem_782 = split_79[3] + getitem_783 = split_79[4] + getitem_784 = split_79[5] + getitem_785 = split_79[6] + getitem_786 = split_79[7]; split_79 = None + cat_71 = torch.ops.aten.cat.default([getitem_779, getitem_780, getitem_781, getitem_782, getitem_783, getitem_784, getitem_785, getitem_786], 1); getitem_779 = getitem_780 = getitem_781 = getitem_782 = getitem_783 = getitem_784 = getitem_785 = getitem_786 = None + view_1195 = torch.ops.aten.view.default(cat_71, [16384, 4096]); cat_71 = None + permute_213 = torch.ops.aten.permute.default(view_1195, [1, 0]) + wait_tensor_190 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_29); reduce_scatter_tensor_29 = None + add_57 = torch.ops.aten.add.Tensor(add_55, wait_tensor_190); wait_tensor_190 = None + convert_element_type_482 = torch.ops.prims.convert_element_type.default(primals_135, torch.bfloat16); primals_135 = None + all_gather_into_tensor_161 = 
torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_482, 8, '0'); convert_element_type_482 = None + wait_tensor_191 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_161); all_gather_into_tensor_161 = None + convert_element_type_483 = torch.ops.prims.convert_element_type.default(add_57, torch.float32); add_57 = None + pow_30 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_483, 2) + mean_29 = torch.ops.aten.mean.dim(pow_30, [2], True); pow_30 = None + add_58 = torch.ops.aten.add.Scalar(mean_29, 1e-05); mean_29 = None + rsqrt_29 = torch.ops.aten.rsqrt.default(add_58); add_58 = None + mul_116 = torch.ops.aten.mul.Tensor(convert_element_type_483, rsqrt_29); convert_element_type_483 = None + mul_117 = torch.ops.aten.mul.Tensor(mul_116, wait_tensor_191) + convert_element_type_484 = torch.ops.prims.convert_element_type.default(mul_117, torch.bfloat16); mul_117 = None + all_gather_into_tensor_162 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_484, 8, '1'); convert_element_type_484 = None + wait_tensor_192 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_162); all_gather_into_tensor_162 = None + split_67 = torch.ops.aten.split.Tensor(wait_tensor_192, 2); wait_tensor_192 = None + getitem_671 = split_67[0] + getitem_672 = split_67[1] + getitem_673 = split_67[2] + getitem_674 = split_67[3] + getitem_675 = split_67[4] + getitem_676 = split_67[5] + getitem_677 = split_67[6] + getitem_678 = split_67[7]; split_67 = None + cat_59 = torch.ops.aten.cat.default([getitem_671, getitem_672, getitem_673, getitem_674, getitem_675, getitem_676, getitem_677, getitem_678], 1); getitem_671 = getitem_672 = getitem_673 = getitem_674 = getitem_675 = getitem_676 = getitem_677 = getitem_678 = None + view_1068 = torch.ops.aten.view.default(cat_59, [16384, 4096]); cat_59 = None + view_1069 = torch.ops.aten.view.default(mm_102, [2, 8192, 1792]); mm_102 = None + convert_element_type_488 = 
torch.ops.prims.convert_element_type.default(view_1069, torch.float32); view_1069 = None + sigmoid_14 = torch.ops.aten.sigmoid.default(convert_element_type_488) + mul_118 = torch.ops.aten.mul.Tensor(convert_element_type_488, sigmoid_14); sigmoid_14 = None + convert_element_type_489 = torch.ops.prims.convert_element_type.default(mul_118, torch.bfloat16); mul_118 = None + convert_element_type_490 = torch.ops.prims.convert_element_type.default(primals_137, torch.bfloat16); primals_137 = None + all_gather_into_tensor_164 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_490, 8, '0'); convert_element_type_490 = None + wait_tensor_194 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_164); all_gather_into_tensor_164 = None + permute_163 = torch.ops.aten.permute.default(wait_tensor_194, [1, 0]); wait_tensor_194 = None + mm_103 = torch.ops.aten.mm.default(view_1068, permute_163) + view_1076 = torch.ops.aten.view.default(mm_103, [2, 8192, 1792]); mm_103 = None + mul_119 = torch.ops.aten.mul.Tensor(convert_element_type_489, view_1076) + view_1083 = torch.ops.aten.view.default(mul_119, [16384, 1792]); mul_119 = None + mm_129 = torch.ops.aten.mm.default(permute_213, view_1083); permute_213 = view_1083 = None + convert_element_type_493 = torch.ops.prims.convert_element_type.default(primals_138, torch.bfloat16); primals_138 = None + all_gather_into_tensor_165 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_493, 8, '0'); convert_element_type_493 = None + wait_tensor_195 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_165); all_gather_into_tensor_165 = None + permute_164 = torch.ops.aten.permute.default(wait_tensor_195, [1, 0]); wait_tensor_195 = None + permute_215 = torch.ops.aten.permute.default(permute_164, [1, 0]); permute_164 = None + mm_130 = torch.ops.aten.mm.default(view_1195, permute_215); view_1195 = permute_215 = None + view_1196 = 
torch.ops.aten.view.default(mm_130, [2, 8192, 1792]); mm_130 = None + convert_element_type_604 = torch.ops.prims.convert_element_type.default(mm_129, torch.float32); mm_129 = None + reduce_scatter_tensor_47 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_604, 'avg', 8, '0'); convert_element_type_604 = None + wait_tensor_233 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_47); reduce_scatter_tensor_47 = None + mul_156 = torch.ops.aten.mul.Tensor(view_1196, convert_element_type_489); convert_element_type_489 = None + mul_157 = torch.ops.aten.mul.Tensor(view_1196, view_1076); view_1196 = view_1076 = None + view_1197 = torch.ops.aten.view.default(mul_156, [16384, 1792]); mul_156 = None + permute_217 = torch.ops.aten.permute.default(view_1197, [1, 0]) + mm_131 = torch.ops.aten.mm.default(permute_217, view_1068); permute_217 = None + permute_219 = torch.ops.aten.permute.default(permute_163, [1, 0]); permute_163 = None + mm_132 = torch.ops.aten.mm.default(view_1197, permute_219); view_1197 = permute_219 = None + view_1198 = torch.ops.aten.view.default(mm_132, [2, 8192, 4096]); mm_132 = None + convert_element_type_609 = torch.ops.prims.convert_element_type.default(mm_131, torch.float32); mm_131 = None + reduce_scatter_tensor_48 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_609, 'avg', 8, '0'); convert_element_type_609 = None + wait_tensor_234 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_48); reduce_scatter_tensor_48 = None + convert_element_type_610 = torch.ops.prims.convert_element_type.default(mul_157, torch.float32); mul_157 = None + neg_1 = torch.ops.aten.neg.default(convert_element_type_488) + exp_1 = torch.ops.aten.exp.default(neg_1); neg_1 = None + add_72 = torch.ops.aten.add.Tensor(exp_1, 1); exp_1 = None + reciprocal_1 = torch.ops.aten.reciprocal.default(add_72); add_72 = None + mul_158 = torch.ops.aten.mul.Tensor(reciprocal_1, 1); reciprocal_1 = None 
+ mul_159 = torch.ops.aten.mul.Tensor(convert_element_type_610, mul_158); convert_element_type_610 = None + sub_5 = torch.ops.aten.sub.Tensor(1, mul_158); mul_158 = None + mul_160 = torch.ops.aten.mul.Tensor(convert_element_type_488, sub_5); convert_element_type_488 = sub_5 = None + add_73 = torch.ops.aten.add.Tensor(mul_160, 1); mul_160 = None + mul_161 = torch.ops.aten.mul.Tensor(mul_159, add_73); mul_159 = add_73 = None + convert_element_type_612 = torch.ops.prims.convert_element_type.default(mul_161, torch.bfloat16); mul_161 = None + view_1199 = torch.ops.aten.view.default(convert_element_type_612, [16384, 1792]); convert_element_type_612 = None + permute_221 = torch.ops.aten.permute.default(view_1199, [1, 0]) + mm_133 = torch.ops.aten.mm.default(permute_221, view_1068); permute_221 = view_1068 = None + convert_element_type_485 = torch.ops.prims.convert_element_type.default(primals_136, torch.bfloat16); primals_136 = None + all_gather_into_tensor_163 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_485, 8, '0'); convert_element_type_485 = None + wait_tensor_193 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_163); all_gather_into_tensor_163 = None + permute_162 = torch.ops.aten.permute.default(wait_tensor_193, [1, 0]); wait_tensor_193 = None + permute_223 = torch.ops.aten.permute.default(permute_162, [1, 0]); permute_162 = None + mm_134 = torch.ops.aten.mm.default(view_1199, permute_223); view_1199 = permute_223 = None + view_1200 = torch.ops.aten.view.default(mm_134, [2, 8192, 4096]); mm_134 = None + add_74 = torch.ops.aten.add.Tensor(view_1198, view_1200); view_1198 = view_1200 = None + convert_element_type_617 = torch.ops.prims.convert_element_type.default(mm_133, torch.float32); mm_133 = None + reduce_scatter_tensor_49 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_617, 'avg', 8, '0'); convert_element_type_617 = None + wait_tensor_235 = 
torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_49); reduce_scatter_tensor_49 = None + split_80 = torch.ops.aten.split.Tensor(add_74, 1024, 1); add_74 = None + getitem_787 = split_80[0] + getitem_788 = split_80[1] + getitem_789 = split_80[2] + getitem_790 = split_80[3] + getitem_791 = split_80[4] + getitem_792 = split_80[5] + getitem_793 = split_80[6] + getitem_794 = split_80[7]; split_80 = None + cat_72 = torch.ops.aten.cat.default([getitem_787, getitem_788, getitem_789, getitem_790, getitem_791, getitem_792, getitem_793, getitem_794]); getitem_787 = getitem_788 = getitem_789 = getitem_790 = getitem_791 = getitem_792 = getitem_793 = getitem_794 = None + reduce_scatter_tensor_50 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_72, 'sum', 8, '1'); cat_72 = None + wait_tensor_236 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_50); reduce_scatter_tensor_50 = None + convert_element_type_618 = torch.ops.prims.convert_element_type.default(wait_tensor_236, torch.float32); wait_tensor_236 = None + convert_element_type_620 = torch.ops.prims.convert_element_type.default(wait_tensor_191, torch.float32); wait_tensor_191 = None + mul_162 = torch.ops.aten.mul.Tensor(convert_element_type_618, convert_element_type_620); convert_element_type_620 = None + mul_164 = torch.ops.aten.mul.Tensor(mul_116, mul_162) + sum_9 = torch.ops.aten.sum.dim_IntList(mul_164, [2], True); mul_164 = None + div_3 = torch.ops.aten.div.Tensor(mul_116, 4096) + mul_165 = torch.ops.aten.mul.Tensor(div_3, sum_9); div_3 = sum_9 = None + sub_6 = torch.ops.aten.sub.Tensor(mul_162, mul_165); mul_162 = mul_165 = None + mul_166 = torch.ops.aten.mul.Tensor(sub_6, rsqrt_29); sub_6 = rsqrt_29 = None + mul_167 = torch.ops.aten.mul.Tensor(convert_element_type_618, mul_116); convert_element_type_618 = mul_116 = None + sum_10 = torch.ops.aten.sum.dim_IntList(mul_167, [0, 1]); mul_167 = None + convert_element_type_621 = 
torch.ops.prims.convert_element_type.default(mul_166, torch.bfloat16); mul_166 = None + convert_element_type_622 = torch.ops.prims.convert_element_type.default(sum_10, torch.bfloat16); sum_10 = None + all_reduce_3 = torch.ops._c10d_functional.all_reduce.default(convert_element_type_622, 'sum', '1'); convert_element_type_622 = None + wait_tensor_237 = torch.ops._c10d_functional.wait_tensor.default(all_reduce_3); all_reduce_3 = None + convert_element_type_623 = torch.ops.prims.convert_element_type.default(wait_tensor_237, torch.float32); wait_tensor_237 = None + reduce_scatter_tensor_51 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_623, 'avg', 8, '0'); convert_element_type_623 = None + wait_tensor_238 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_51); reduce_scatter_tensor_51 = None + add_75 = torch.ops.aten.add.Tensor(add_71, convert_element_type_621); add_71 = convert_element_type_621 = None + all_gather_into_tensor_183 = torch.ops._c10d_functional.all_gather_into_tensor.default(add_75, 8, '1') + wait_tensor_239 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_183); all_gather_into_tensor_183 = None + split_81 = torch.ops.aten.split.Tensor(wait_tensor_239, 2); wait_tensor_239 = None + getitem_795 = split_81[0] + getitem_796 = split_81[1] + getitem_797 = split_81[2] + getitem_798 = split_81[3] + getitem_799 = split_81[4] + getitem_800 = split_81[5] + getitem_801 = split_81[6] + getitem_802 = split_81[7]; split_81 = None + cat_73 = torch.ops.aten.cat.default([getitem_795, getitem_796, getitem_797, getitem_798, getitem_799, getitem_800, getitem_801, getitem_802], 1); getitem_795 = getitem_796 = getitem_797 = getitem_798 = getitem_799 = getitem_800 = getitem_801 = getitem_802 = None + view_1201 = torch.ops.aten.view.default(cat_73, [16384, 4096]); cat_73 = None + permute_225 = torch.ops.aten.permute.default(view_1201, [1, 0]) + permute_160 = torch.ops.aten.permute.default(getitem_654, [0, 
2, 1, 3]) + view_1050 = torch.ops.aten.view.default(permute_160, [2, 8192, -1]); permute_160 = None + view_1056 = torch.ops.aten.view.default(view_1050, [16384, 512]); view_1050 = None + mm_135 = torch.ops.aten.mm.default(permute_225, view_1056); permute_225 = view_1056 = None + convert_element_type_479 = torch.ops.prims.convert_element_type.default(primals_134, torch.bfloat16); primals_134 = None + all_gather_into_tensor_160 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_479, 8, '0'); convert_element_type_479 = None + wait_tensor_189 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_160); all_gather_into_tensor_160 = None + permute_161 = torch.ops.aten.permute.default(wait_tensor_189, [1, 0]); wait_tensor_189 = None + permute_227 = torch.ops.aten.permute.default(permute_161, [1, 0]); permute_161 = None + mm_136 = torch.ops.aten.mm.default(view_1201, permute_227); view_1201 = permute_227 = None + view_1202 = torch.ops.aten.view.default(mm_136, [2, 8192, 512]); mm_136 = None + convert_element_type_628 = torch.ops.prims.convert_element_type.default(mm_135, torch.float32); mm_135 = None + reduce_scatter_tensor_52 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_628, 'avg', 8, '0'); convert_element_type_628 = None + wait_tensor_240 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_52); reduce_scatter_tensor_52 = None + view_1203 = torch.ops.aten.view.default(view_1202, [2, 8192, 4, 128]); view_1202 = None + permute_229 = torch.ops.aten.permute.default(view_1203, [0, 2, 1, 3]); view_1203 = None + convert_element_type_463 = torch.ops.prims.convert_element_type.default(primals_130, torch.bfloat16); primals_130 = None + all_gather_into_tensor_155 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_463, 8, '0'); convert_element_type_463 = None + wait_tensor_184 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_155); 
all_gather_into_tensor_155 = None + convert_element_type_464 = torch.ops.prims.convert_element_type.default(add_55, torch.float32); add_55 = None + pow_29 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_464, 2) + mean_28 = torch.ops.aten.mean.dim(pow_29, [2], True); pow_29 = None + add_56 = torch.ops.aten.add.Scalar(mean_28, 1e-05); mean_28 = None + rsqrt_28 = torch.ops.aten.rsqrt.default(add_56); add_56 = None + mul_112 = torch.ops.aten.mul.Tensor(convert_element_type_464, rsqrt_28); convert_element_type_464 = None + mul_113 = torch.ops.aten.mul.Tensor(mul_112, wait_tensor_184) + convert_element_type_465 = torch.ops.prims.convert_element_type.default(mul_113, torch.bfloat16); mul_113 = None + all_gather_into_tensor_156 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_465, 8, '1'); convert_element_type_465 = None + wait_tensor_185 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_156); all_gather_into_tensor_156 = None + split_65 = torch.ops.aten.split.Tensor(wait_tensor_185, 2); wait_tensor_185 = None + getitem_646 = split_65[0] + getitem_647 = split_65[1] + getitem_648 = split_65[2] + getitem_649 = split_65[3] + getitem_650 = split_65[4] + getitem_651 = split_65[5] + getitem_652 = split_65[6] + getitem_653 = split_65[7]; split_65 = None + cat_57 = torch.ops.aten.cat.default([getitem_646, getitem_647, getitem_648, getitem_649, getitem_650, getitem_651, getitem_652, getitem_653], 1); getitem_646 = getitem_647 = getitem_648 = getitem_649 = getitem_650 = getitem_651 = getitem_652 = getitem_653 = None + view_1023 = torch.ops.aten.view.default(cat_57, [16384, 4096]); cat_57 = None + view_1024 = torch.ops.aten.view.default(mm_98, [2, 8192, 512]); mm_98 = None + convert_element_type_469 = torch.ops.prims.convert_element_type.default(primals_132, torch.bfloat16); primals_132 = None + all_gather_into_tensor_158 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_469, 8, '0'); 
convert_element_type_469 = None + wait_tensor_187 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_158); all_gather_into_tensor_158 = None + permute_155 = torch.ops.aten.permute.default(wait_tensor_187, [1, 0]); wait_tensor_187 = None + mm_99 = torch.ops.aten.mm.default(view_1023, permute_155) + view_1031 = torch.ops.aten.view.default(mm_99, [2, 8192, 128]); mm_99 = None + view_1038 = torch.ops.aten.view.default(mm_100, [2, 8192, 128]); mm_100 = None + view_1040 = torch.ops.aten.view.default(view_1024, [2, 8192, -1, 128]); view_1024 = None + view_1041 = torch.ops.aten.view.default(view_1031, [2, 8192, -1, 128]); view_1031 = None + view_1042 = torch.ops.aten.view.default(view_1038, [2, 8192, -1, 128]); view_1038 = None + convert_element_type_475 = torch.ops.prims.convert_element_type.default(view_1040, torch.float32); view_1040 = None + view_1043 = torch.ops.aten.view.default(convert_element_type_475, [2, 8192, 4, -1, 2]); convert_element_type_475 = None + view_as_complex_28 = torch.ops.aten.view_as_complex.default(view_1043); view_1043 = None + convert_element_type_476 = torch.ops.prims.convert_element_type.default(view_1041, torch.float32); view_1041 = None + view_1044 = torch.ops.aten.view.default(convert_element_type_476, [2, 8192, 1, -1, 2]); convert_element_type_476 = None + view_as_complex_29 = torch.ops.aten.view_as_complex.default(view_1044); view_1044 = None + mul_114 = torch.ops.aten.mul.Tensor(view_as_complex_28, view_37); view_as_complex_28 = None + view_as_real_28 = torch.ops.aten.view_as_real.default(mul_114); mul_114 = None + view_1046 = torch.ops.aten.view.default(view_as_real_28, [2, 8192, 4, 128]); view_as_real_28 = None + mul_115 = torch.ops.aten.mul.Tensor(view_as_complex_29, view_37); view_as_complex_29 = None + view_as_real_29 = torch.ops.aten.view_as_real.default(mul_115); mul_115 = None + view_1047 = torch.ops.aten.view.default(view_as_real_29, [2, 8192, 1, 128]); view_as_real_29 = None + convert_element_type_477 = 
torch.ops.prims.convert_element_type.default(view_1046, torch.bfloat16); view_1046 = None + convert_element_type_478 = torch.ops.prims.convert_element_type.default(view_1047, torch.bfloat16); view_1047 = None + unsqueeze_28 = torch.ops.aten.unsqueeze.default(convert_element_type_478, 3); convert_element_type_478 = None + expand_28 = torch.ops.aten.expand.default(unsqueeze_28, [2, 8192, 1, 4, 128]); unsqueeze_28 = None + view_1048 = torch.ops.aten.view.default(expand_28, [2, 8192, 4, 128]); expand_28 = None + unsqueeze_29 = torch.ops.aten.unsqueeze.default(view_1042, 3); view_1042 = None + expand_29 = torch.ops.aten.expand.default(unsqueeze_29, [2, 8192, 1, 4, 128]); unsqueeze_29 = None + view_1049 = torch.ops.aten.view.default(expand_29, [2, 8192, 4, 128]); expand_29 = None + permute_157 = torch.ops.aten.permute.default(convert_element_type_477, [0, 2, 1, 3]); convert_element_type_477 = None + permute_158 = torch.ops.aten.permute.default(view_1048, [0, 2, 1, 3]); view_1048 = None + permute_159 = torch.ops.aten.permute.default(view_1049, [0, 2, 1, 3]); view_1049 = None + _scaled_dot_product_cudnn_attention_backward_1 = torch.ops.aten._scaled_dot_product_cudnn_attention_backward.default(permute_229, permute_157, permute_158, permute_159, getitem_654, getitem_655, getitem_660, getitem_661, None, None, None, 8192, 8192, 0.0, True); permute_229 = permute_157 = permute_158 = permute_159 = getitem_654 = getitem_655 = getitem_660 = getitem_661 = None + getitem_803 = _scaled_dot_product_cudnn_attention_backward_1[0] + getitem_804 = _scaled_dot_product_cudnn_attention_backward_1[1] + getitem_805 = _scaled_dot_product_cudnn_attention_backward_1[2]; _scaled_dot_product_cudnn_attention_backward_1 = None + permute_230 = torch.ops.aten.permute.default(getitem_805, [0, 2, 1, 3]); getitem_805 = None + permute_231 = torch.ops.aten.permute.default(getitem_804, [0, 2, 1, 3]); getitem_804 = None + permute_232 = torch.ops.aten.permute.default(getitem_803, [0, 2, 1, 3]); getitem_803 = 
None + view_1204 = torch.ops.aten.view.default(permute_230, [2, 8192, 1, 4, 128]); permute_230 = None + sum_11 = torch.ops.aten.sum.dim_IntList(view_1204, [3], True); view_1204 = None + squeeze_2 = torch.ops.aten.squeeze.dim(sum_11, 3); sum_11 = None + view_1205 = torch.ops.aten.view.default(permute_231, [2, 8192, 1, 4, 128]); permute_231 = None + sum_12 = torch.ops.aten.sum.dim_IntList(view_1205, [3], True); view_1205 = None + squeeze_3 = torch.ops.aten.squeeze.dim(sum_12, 3); sum_12 = None + convert_element_type_629 = torch.ops.prims.convert_element_type.default(squeeze_3, torch.float32); squeeze_3 = None + convert_element_type_630 = torch.ops.prims.convert_element_type.default(permute_232, torch.float32); permute_232 = None + view_1206 = torch.ops.aten.view.default(convert_element_type_629, [2, 8192, 1, 64, 2]); convert_element_type_629 = None + view_as_complex_34 = torch.ops.aten.view_as_complex.default(view_1206); view_1206 = None + mul_168 = torch.ops.aten.mul.Tensor(view_as_complex_34, _conj); view_as_complex_34 = None + view_1207 = torch.ops.aten.view.default(convert_element_type_630, [2, 8192, 4, 64, 2]); convert_element_type_630 = None + view_as_complex_35 = torch.ops.aten.view_as_complex.default(view_1207); view_1207 = None + mul_169 = torch.ops.aten.mul.Tensor(view_as_complex_35, _conj); view_as_complex_35 = None + view_as_real_34 = torch.ops.aten.view_as_real.default(mul_168); mul_168 = None + view_1208 = torch.ops.aten.view.default(view_as_real_34, [2, 8192, 1, 128]); view_as_real_34 = None + convert_element_type_631 = torch.ops.prims.convert_element_type.default(view_1208, torch.bfloat16); view_1208 = None + view_as_real_35 = torch.ops.aten.view_as_real.default(mul_169); mul_169 = None + view_1209 = torch.ops.aten.view.default(view_as_real_35, [2, 8192, 4, 128]); view_as_real_35 = None + convert_element_type_632 = torch.ops.prims.convert_element_type.default(view_1209, torch.bfloat16); view_1209 = None + view_1210 = 
torch.ops.aten.view.default(squeeze_2, [2, 8192, 128]); squeeze_2 = None + view_1211 = torch.ops.aten.view.default(convert_element_type_631, [2, 8192, 128]); convert_element_type_631 = None + view_1212 = torch.ops.aten.view.default(convert_element_type_632, [2, 8192, 512]); convert_element_type_632 = None + view_1213 = torch.ops.aten.view.default(view_1210, [16384, 128]); view_1210 = None + permute_233 = torch.ops.aten.permute.default(view_1213, [1, 0]) + mm_137 = torch.ops.aten.mm.default(permute_233, view_1023); permute_233 = None + convert_element_type_472 = torch.ops.prims.convert_element_type.default(primals_133, torch.bfloat16); primals_133 = None + all_gather_into_tensor_159 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_472, 8, '0'); convert_element_type_472 = None + wait_tensor_188 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_159); all_gather_into_tensor_159 = None + permute_156 = torch.ops.aten.permute.default(wait_tensor_188, [1, 0]); wait_tensor_188 = None + permute_235 = torch.ops.aten.permute.default(permute_156, [1, 0]); permute_156 = None + mm_138 = torch.ops.aten.mm.default(view_1213, permute_235); view_1213 = permute_235 = None + view_1214 = torch.ops.aten.view.default(mm_138, [2, 8192, 4096]); mm_138 = None + convert_element_type_637 = torch.ops.prims.convert_element_type.default(mm_137, torch.float32); mm_137 = None + reduce_scatter_tensor_53 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_637, 'avg', 8, '0'); convert_element_type_637 = None + wait_tensor_241 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_53); reduce_scatter_tensor_53 = None + view_1215 = torch.ops.aten.view.default(view_1211, [16384, 128]); view_1211 = None + permute_237 = torch.ops.aten.permute.default(view_1215, [1, 0]) + mm_139 = torch.ops.aten.mm.default(permute_237, view_1023); permute_237 = None + permute_239 = torch.ops.aten.permute.default(permute_155, 
[1, 0]); permute_155 = None + mm_140 = torch.ops.aten.mm.default(view_1215, permute_239); view_1215 = permute_239 = None + view_1216 = torch.ops.aten.view.default(mm_140, [2, 8192, 4096]); mm_140 = None + add_76 = torch.ops.aten.add.Tensor(view_1214, view_1216); view_1214 = view_1216 = None + convert_element_type_642 = torch.ops.prims.convert_element_type.default(mm_139, torch.float32); mm_139 = None + reduce_scatter_tensor_54 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_642, 'avg', 8, '0'); convert_element_type_642 = None + wait_tensor_242 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_54); reduce_scatter_tensor_54 = None + view_1217 = torch.ops.aten.view.default(view_1212, [16384, 512]); view_1212 = None + permute_241 = torch.ops.aten.permute.default(view_1217, [1, 0]) + mm_141 = torch.ops.aten.mm.default(permute_241, view_1023); permute_241 = view_1023 = None + convert_element_type_466 = torch.ops.prims.convert_element_type.default(primals_131, torch.bfloat16); primals_131 = None + all_gather_into_tensor_157 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_466, 8, '0'); convert_element_type_466 = None + wait_tensor_186 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_157); all_gather_into_tensor_157 = None + permute_154 = torch.ops.aten.permute.default(wait_tensor_186, [1, 0]); wait_tensor_186 = None + permute_243 = torch.ops.aten.permute.default(permute_154, [1, 0]); permute_154 = None + mm_142 = torch.ops.aten.mm.default(view_1217, permute_243); view_1217 = permute_243 = None + view_1218 = torch.ops.aten.view.default(mm_142, [2, 8192, 4096]); mm_142 = None + add_77 = torch.ops.aten.add.Tensor(add_76, view_1218); add_76 = view_1218 = None + convert_element_type_647 = torch.ops.prims.convert_element_type.default(mm_141, torch.float32); mm_141 = None + reduce_scatter_tensor_55 = 
torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_647, 'avg', 8, '0'); convert_element_type_647 = None + wait_tensor_243 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_55); reduce_scatter_tensor_55 = None + split_82 = torch.ops.aten.split.Tensor(add_77, 1024, 1); add_77 = None + getitem_806 = split_82[0] + getitem_807 = split_82[1] + getitem_808 = split_82[2] + getitem_809 = split_82[3] + getitem_810 = split_82[4] + getitem_811 = split_82[5] + getitem_812 = split_82[6] + getitem_813 = split_82[7]; split_82 = None + cat_74 = torch.ops.aten.cat.default([getitem_806, getitem_807, getitem_808, getitem_809, getitem_810, getitem_811, getitem_812, getitem_813]); getitem_806 = getitem_807 = getitem_808 = getitem_809 = getitem_810 = getitem_811 = getitem_812 = getitem_813 = None + reduce_scatter_tensor_56 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_74, 'sum', 8, '1'); cat_74 = None + wait_tensor_244 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_56); reduce_scatter_tensor_56 = None + convert_element_type_648 = torch.ops.prims.convert_element_type.default(wait_tensor_244, torch.float32); wait_tensor_244 = None + convert_element_type_650 = torch.ops.prims.convert_element_type.default(wait_tensor_184, torch.float32); wait_tensor_184 = None + mul_170 = torch.ops.aten.mul.Tensor(convert_element_type_648, convert_element_type_650); convert_element_type_650 = None + mul_172 = torch.ops.aten.mul.Tensor(mul_112, mul_170) + sum_13 = torch.ops.aten.sum.dim_IntList(mul_172, [2], True); mul_172 = None + div_4 = torch.ops.aten.div.Tensor(mul_112, 4096) + mul_173 = torch.ops.aten.mul.Tensor(div_4, sum_13); div_4 = sum_13 = None + sub_7 = torch.ops.aten.sub.Tensor(mul_170, mul_173); mul_170 = mul_173 = None + mul_174 = torch.ops.aten.mul.Tensor(sub_7, rsqrt_28); sub_7 = rsqrt_28 = None + mul_175 = torch.ops.aten.mul.Tensor(convert_element_type_648, mul_112); convert_element_type_648 = mul_112 = 
None + sum_14 = torch.ops.aten.sum.dim_IntList(mul_175, [0, 1]); mul_175 = None + convert_element_type_651 = torch.ops.prims.convert_element_type.default(mul_174, torch.bfloat16); mul_174 = None + convert_element_type_652 = torch.ops.prims.convert_element_type.default(sum_14, torch.bfloat16); sum_14 = None + all_reduce_4 = torch.ops._c10d_functional.all_reduce.default(convert_element_type_652, 'sum', '1'); convert_element_type_652 = None + wait_tensor_245 = torch.ops._c10d_functional.wait_tensor.default(all_reduce_4); all_reduce_4 = None + convert_element_type_653 = torch.ops.prims.convert_element_type.default(wait_tensor_245, torch.float32); wait_tensor_245 = None + reduce_scatter_tensor_57 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_653, 'avg', 8, '0'); convert_element_type_653 = None + wait_tensor_246 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_57); reduce_scatter_tensor_57 = None + add_78 = torch.ops.aten.add.Tensor(add_75, convert_element_type_651); add_75 = convert_element_type_651 = None + all_gather_into_tensor_184 = torch.ops._c10d_functional.all_gather_into_tensor.default(add_78, 8, '1') + wait_tensor_247 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_184); all_gather_into_tensor_184 = None + split_83 = torch.ops.aten.split.Tensor(wait_tensor_247, 2); wait_tensor_247 = None + getitem_814 = split_83[0] + getitem_815 = split_83[1] + getitem_816 = split_83[2] + getitem_817 = split_83[3] + getitem_818 = split_83[4] + getitem_819 = split_83[5] + getitem_820 = split_83[6] + getitem_821 = split_83[7]; split_83 = None + cat_75 = torch.ops.aten.cat.default([getitem_814, getitem_815, getitem_816, getitem_817, getitem_818, getitem_819, getitem_820, getitem_821], 1); getitem_814 = getitem_815 = getitem_816 = getitem_817 = getitem_818 = getitem_819 = getitem_820 = getitem_821 = None + view_1219 = torch.ops.aten.view.default(cat_75, [16384, 4096]); cat_75 = None + permute_245 = 
torch.ops.aten.permute.default(view_1219, [1, 0]) + wait_tensor_177 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_27); reduce_scatter_tensor_27 = None + add_53 = torch.ops.aten.add.Tensor(add_51, wait_tensor_177); wait_tensor_177 = None + convert_element_type_449 = torch.ops.prims.convert_element_type.default(primals_126, torch.bfloat16); primals_126 = None + all_gather_into_tensor_150 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_449, 8, '0'); convert_element_type_449 = None + wait_tensor_178 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_150); all_gather_into_tensor_150 = None + convert_element_type_450 = torch.ops.prims.convert_element_type.default(add_53, torch.float32); add_53 = None + pow_28 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_450, 2) + mean_27 = torch.ops.aten.mean.dim(pow_28, [2], True); pow_28 = None + add_54 = torch.ops.aten.add.Scalar(mean_27, 1e-05); mean_27 = None + rsqrt_27 = torch.ops.aten.rsqrt.default(add_54); add_54 = None + mul_108 = torch.ops.aten.mul.Tensor(convert_element_type_450, rsqrt_27); convert_element_type_450 = None + mul_109 = torch.ops.aten.mul.Tensor(mul_108, wait_tensor_178) + convert_element_type_451 = torch.ops.prims.convert_element_type.default(mul_109, torch.bfloat16); mul_109 = None + all_gather_into_tensor_151 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_451, 8, '1'); convert_element_type_451 = None + wait_tensor_179 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_151); all_gather_into_tensor_151 = None + split_63 = torch.ops.aten.split.Tensor(wait_tensor_179, 2); wait_tensor_179 = None + getitem_630 = split_63[0] + getitem_631 = split_63[1] + getitem_632 = split_63[2] + getitem_633 = split_63[3] + getitem_634 = split_63[4] + getitem_635 = split_63[5] + getitem_636 = split_63[6] + getitem_637 = split_63[7]; split_63 = None + cat_55 = 
torch.ops.aten.cat.default([getitem_630, getitem_631, getitem_632, getitem_633, getitem_634, getitem_635, getitem_636, getitem_637], 1); getitem_630 = getitem_631 = getitem_632 = getitem_633 = getitem_634 = getitem_635 = getitem_636 = getitem_637 = None + view_996 = torch.ops.aten.view.default(cat_55, [16384, 4096]); cat_55 = None + view_997 = torch.ops.aten.view.default(mm_95, [2, 8192, 1792]); mm_95 = None + convert_element_type_455 = torch.ops.prims.convert_element_type.default(view_997, torch.float32); view_997 = None + sigmoid_13 = torch.ops.aten.sigmoid.default(convert_element_type_455) + mul_110 = torch.ops.aten.mul.Tensor(convert_element_type_455, sigmoid_13); sigmoid_13 = None + convert_element_type_456 = torch.ops.prims.convert_element_type.default(mul_110, torch.bfloat16); mul_110 = None + convert_element_type_457 = torch.ops.prims.convert_element_type.default(primals_128, torch.bfloat16); primals_128 = None + all_gather_into_tensor_153 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_457, 8, '0'); convert_element_type_457 = None + wait_tensor_181 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_153); all_gather_into_tensor_153 = None + permute_152 = torch.ops.aten.permute.default(wait_tensor_181, [1, 0]); wait_tensor_181 = None + mm_96 = torch.ops.aten.mm.default(view_996, permute_152) + view_1004 = torch.ops.aten.view.default(mm_96, [2, 8192, 1792]); mm_96 = None + mul_111 = torch.ops.aten.mul.Tensor(convert_element_type_456, view_1004) + view_1011 = torch.ops.aten.view.default(mul_111, [16384, 1792]); mul_111 = None + mm_143 = torch.ops.aten.mm.default(permute_245, view_1011); permute_245 = view_1011 = None + convert_element_type_460 = torch.ops.prims.convert_element_type.default(primals_129, torch.bfloat16); primals_129 = None + all_gather_into_tensor_154 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_460, 8, '0'); convert_element_type_460 = None + 
wait_tensor_182 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_154); all_gather_into_tensor_154 = None + permute_153 = torch.ops.aten.permute.default(wait_tensor_182, [1, 0]); wait_tensor_182 = None + permute_247 = torch.ops.aten.permute.default(permute_153, [1, 0]); permute_153 = None + mm_144 = torch.ops.aten.mm.default(view_1219, permute_247); view_1219 = permute_247 = None + view_1220 = torch.ops.aten.view.default(mm_144, [2, 8192, 1792]); mm_144 = None + convert_element_type_658 = torch.ops.prims.convert_element_type.default(mm_143, torch.float32); mm_143 = None + reduce_scatter_tensor_58 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_658, 'avg', 8, '0'); convert_element_type_658 = None + wait_tensor_248 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_58); reduce_scatter_tensor_58 = None + mul_176 = torch.ops.aten.mul.Tensor(view_1220, convert_element_type_456); convert_element_type_456 = None + mul_177 = torch.ops.aten.mul.Tensor(view_1220, view_1004); view_1220 = view_1004 = None + view_1221 = torch.ops.aten.view.default(mul_176, [16384, 1792]); mul_176 = None + permute_249 = torch.ops.aten.permute.default(view_1221, [1, 0]) + mm_145 = torch.ops.aten.mm.default(permute_249, view_996); permute_249 = None + permute_251 = torch.ops.aten.permute.default(permute_152, [1, 0]); permute_152 = None + mm_146 = torch.ops.aten.mm.default(view_1221, permute_251); view_1221 = permute_251 = None + view_1222 = torch.ops.aten.view.default(mm_146, [2, 8192, 4096]); mm_146 = None + convert_element_type_663 = torch.ops.prims.convert_element_type.default(mm_145, torch.float32); mm_145 = None + reduce_scatter_tensor_59 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_663, 'avg', 8, '0'); convert_element_type_663 = None + wait_tensor_249 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_59); reduce_scatter_tensor_59 = None + convert_element_type_664 
= torch.ops.prims.convert_element_type.default(mul_177, torch.float32); mul_177 = None + neg_2 = torch.ops.aten.neg.default(convert_element_type_455) + exp_2 = torch.ops.aten.exp.default(neg_2); neg_2 = None + add_79 = torch.ops.aten.add.Tensor(exp_2, 1); exp_2 = None + reciprocal_2 = torch.ops.aten.reciprocal.default(add_79); add_79 = None + mul_178 = torch.ops.aten.mul.Tensor(reciprocal_2, 1); reciprocal_2 = None + mul_179 = torch.ops.aten.mul.Tensor(convert_element_type_664, mul_178); convert_element_type_664 = None + sub_8 = torch.ops.aten.sub.Tensor(1, mul_178); mul_178 = None + mul_180 = torch.ops.aten.mul.Tensor(convert_element_type_455, sub_8); convert_element_type_455 = sub_8 = None + add_80 = torch.ops.aten.add.Tensor(mul_180, 1); mul_180 = None + mul_181 = torch.ops.aten.mul.Tensor(mul_179, add_80); mul_179 = add_80 = None + convert_element_type_666 = torch.ops.prims.convert_element_type.default(mul_181, torch.bfloat16); mul_181 = None + view_1223 = torch.ops.aten.view.default(convert_element_type_666, [16384, 1792]); convert_element_type_666 = None + permute_253 = torch.ops.aten.permute.default(view_1223, [1, 0]) + mm_147 = torch.ops.aten.mm.default(permute_253, view_996); permute_253 = view_996 = None + convert_element_type_452 = torch.ops.prims.convert_element_type.default(primals_127, torch.bfloat16); primals_127 = None + all_gather_into_tensor_152 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_452, 8, '0'); convert_element_type_452 = None + wait_tensor_180 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_152); all_gather_into_tensor_152 = None + permute_151 = torch.ops.aten.permute.default(wait_tensor_180, [1, 0]); wait_tensor_180 = None + permute_255 = torch.ops.aten.permute.default(permute_151, [1, 0]); permute_151 = None + mm_148 = torch.ops.aten.mm.default(view_1223, permute_255); view_1223 = permute_255 = None + view_1224 = torch.ops.aten.view.default(mm_148, [2, 8192, 4096]); mm_148 = 
None + add_81 = torch.ops.aten.add.Tensor(view_1222, view_1224); view_1222 = view_1224 = None + convert_element_type_671 = torch.ops.prims.convert_element_type.default(mm_147, torch.float32); mm_147 = None + reduce_scatter_tensor_60 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_671, 'avg', 8, '0'); convert_element_type_671 = None + wait_tensor_250 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_60); reduce_scatter_tensor_60 = None + split_84 = torch.ops.aten.split.Tensor(add_81, 1024, 1); add_81 = None + getitem_822 = split_84[0] + getitem_823 = split_84[1] + getitem_824 = split_84[2] + getitem_825 = split_84[3] + getitem_826 = split_84[4] + getitem_827 = split_84[5] + getitem_828 = split_84[6] + getitem_829 = split_84[7]; split_84 = None + cat_76 = torch.ops.aten.cat.default([getitem_822, getitem_823, getitem_824, getitem_825, getitem_826, getitem_827, getitem_828, getitem_829]); getitem_822 = getitem_823 = getitem_824 = getitem_825 = getitem_826 = getitem_827 = getitem_828 = getitem_829 = None + reduce_scatter_tensor_61 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_76, 'sum', 8, '1'); cat_76 = None + wait_tensor_251 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_61); reduce_scatter_tensor_61 = None + convert_element_type_672 = torch.ops.prims.convert_element_type.default(wait_tensor_251, torch.float32); wait_tensor_251 = None + convert_element_type_674 = torch.ops.prims.convert_element_type.default(wait_tensor_178, torch.float32); wait_tensor_178 = None + mul_182 = torch.ops.aten.mul.Tensor(convert_element_type_672, convert_element_type_674); convert_element_type_674 = None + mul_184 = torch.ops.aten.mul.Tensor(mul_108, mul_182) + sum_15 = torch.ops.aten.sum.dim_IntList(mul_184, [2], True); mul_184 = None + div_5 = torch.ops.aten.div.Tensor(mul_108, 4096) + mul_185 = torch.ops.aten.mul.Tensor(div_5, sum_15); div_5 = sum_15 = None + sub_9 = 
torch.ops.aten.sub.Tensor(mul_182, mul_185); mul_182 = mul_185 = None + mul_186 = torch.ops.aten.mul.Tensor(sub_9, rsqrt_27); sub_9 = rsqrt_27 = None + mul_187 = torch.ops.aten.mul.Tensor(convert_element_type_672, mul_108); convert_element_type_672 = mul_108 = None + sum_16 = torch.ops.aten.sum.dim_IntList(mul_187, [0, 1]); mul_187 = None + convert_element_type_675 = torch.ops.prims.convert_element_type.default(mul_186, torch.bfloat16); mul_186 = None + convert_element_type_676 = torch.ops.prims.convert_element_type.default(sum_16, torch.bfloat16); sum_16 = None + all_reduce_5 = torch.ops._c10d_functional.all_reduce.default(convert_element_type_676, 'sum', '1'); convert_element_type_676 = None + wait_tensor_252 = torch.ops._c10d_functional.wait_tensor.default(all_reduce_5); all_reduce_5 = None + convert_element_type_677 = torch.ops.prims.convert_element_type.default(wait_tensor_252, torch.float32); wait_tensor_252 = None + reduce_scatter_tensor_62 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_677, 'avg', 8, '0'); convert_element_type_677 = None + wait_tensor_253 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_62); reduce_scatter_tensor_62 = None + add_82 = torch.ops.aten.add.Tensor(add_78, convert_element_type_675); add_78 = convert_element_type_675 = None + all_gather_into_tensor_185 = torch.ops._c10d_functional.all_gather_into_tensor.default(add_82, 8, '1') + wait_tensor_254 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_185); all_gather_into_tensor_185 = None + split_85 = torch.ops.aten.split.Tensor(wait_tensor_254, 2); wait_tensor_254 = None + getitem_830 = split_85[0] + getitem_831 = split_85[1] + getitem_832 = split_85[2] + getitem_833 = split_85[3] + getitem_834 = split_85[4] + getitem_835 = split_85[5] + getitem_836 = split_85[6] + getitem_837 = split_85[7]; split_85 = None + cat_77 = torch.ops.aten.cat.default([getitem_830, getitem_831, getitem_832, getitem_833, getitem_834, 
getitem_835, getitem_836, getitem_837], 1); getitem_830 = getitem_831 = getitem_832 = getitem_833 = getitem_834 = getitem_835 = getitem_836 = getitem_837 = None + view_1225 = torch.ops.aten.view.default(cat_77, [16384, 4096]); cat_77 = None + permute_257 = torch.ops.aten.permute.default(view_1225, [1, 0]) + permute_149 = torch.ops.aten.permute.default(getitem_613, [0, 2, 1, 3]) + view_978 = torch.ops.aten.view.default(permute_149, [2, 8192, -1]); permute_149 = None + view_984 = torch.ops.aten.view.default(view_978, [16384, 512]); view_978 = None + mm_149 = torch.ops.aten.mm.default(permute_257, view_984); permute_257 = view_984 = None + convert_element_type_446 = torch.ops.prims.convert_element_type.default(primals_125, torch.bfloat16); primals_125 = None + all_gather_into_tensor_149 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_446, 8, '0'); convert_element_type_446 = None + wait_tensor_176 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_149); all_gather_into_tensor_149 = None + permute_150 = torch.ops.aten.permute.default(wait_tensor_176, [1, 0]); wait_tensor_176 = None + permute_259 = torch.ops.aten.permute.default(permute_150, [1, 0]); permute_150 = None + mm_150 = torch.ops.aten.mm.default(view_1225, permute_259); view_1225 = permute_259 = None + view_1226 = torch.ops.aten.view.default(mm_150, [2, 8192, 512]); mm_150 = None + convert_element_type_682 = torch.ops.prims.convert_element_type.default(mm_149, torch.float32); mm_149 = None + reduce_scatter_tensor_63 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_682, 'avg', 8, '0'); convert_element_type_682 = None + wait_tensor_255 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_63); reduce_scatter_tensor_63 = None + view_1227 = torch.ops.aten.view.default(view_1226, [2, 8192, 4, 128]); view_1226 = None + permute_261 = torch.ops.aten.permute.default(view_1227, [0, 2, 1, 3]); view_1227 = None + 
convert_element_type_430 = torch.ops.prims.convert_element_type.default(primals_121, torch.bfloat16); primals_121 = None + all_gather_into_tensor_144 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_430, 8, '0'); convert_element_type_430 = None + wait_tensor_171 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_144); all_gather_into_tensor_144 = None + convert_element_type_431 = torch.ops.prims.convert_element_type.default(add_51, torch.float32); add_51 = None + pow_27 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_431, 2) + mean_26 = torch.ops.aten.mean.dim(pow_27, [2], True); pow_27 = None + add_52 = torch.ops.aten.add.Scalar(mean_26, 1e-05); mean_26 = None + rsqrt_26 = torch.ops.aten.rsqrt.default(add_52); add_52 = None + mul_104 = torch.ops.aten.mul.Tensor(convert_element_type_431, rsqrt_26); convert_element_type_431 = None + mul_105 = torch.ops.aten.mul.Tensor(mul_104, wait_tensor_171) + convert_element_type_432 = torch.ops.prims.convert_element_type.default(mul_105, torch.bfloat16); mul_105 = None + all_gather_into_tensor_145 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_432, 8, '1'); convert_element_type_432 = None + wait_tensor_172 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_145); all_gather_into_tensor_145 = None + split_61 = torch.ops.aten.split.Tensor(wait_tensor_172, 2); wait_tensor_172 = None + getitem_605 = split_61[0] + getitem_606 = split_61[1] + getitem_607 = split_61[2] + getitem_608 = split_61[3] + getitem_609 = split_61[4] + getitem_610 = split_61[5] + getitem_611 = split_61[6] + getitem_612 = split_61[7]; split_61 = None + cat_53 = torch.ops.aten.cat.default([getitem_605, getitem_606, getitem_607, getitem_608, getitem_609, getitem_610, getitem_611, getitem_612], 1); getitem_605 = getitem_606 = getitem_607 = getitem_608 = getitem_609 = getitem_610 = getitem_611 = getitem_612 = None + view_951 = 
torch.ops.aten.view.default(cat_53, [16384, 4096]); cat_53 = None + view_952 = torch.ops.aten.view.default(mm_91, [2, 8192, 512]); mm_91 = None + convert_element_type_436 = torch.ops.prims.convert_element_type.default(primals_123, torch.bfloat16); primals_123 = None + all_gather_into_tensor_147 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_436, 8, '0'); convert_element_type_436 = None + wait_tensor_174 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_147); all_gather_into_tensor_147 = None + permute_144 = torch.ops.aten.permute.default(wait_tensor_174, [1, 0]); wait_tensor_174 = None + mm_92 = torch.ops.aten.mm.default(view_951, permute_144) + view_959 = torch.ops.aten.view.default(mm_92, [2, 8192, 128]); mm_92 = None + view_966 = torch.ops.aten.view.default(mm_93, [2, 8192, 128]); mm_93 = None + view_968 = torch.ops.aten.view.default(view_952, [2, 8192, -1, 128]); view_952 = None + view_969 = torch.ops.aten.view.default(view_959, [2, 8192, -1, 128]); view_959 = None + view_970 = torch.ops.aten.view.default(view_966, [2, 8192, -1, 128]); view_966 = None + convert_element_type_442 = torch.ops.prims.convert_element_type.default(view_968, torch.float32); view_968 = None + view_971 = torch.ops.aten.view.default(convert_element_type_442, [2, 8192, 4, -1, 2]); convert_element_type_442 = None + view_as_complex_26 = torch.ops.aten.view_as_complex.default(view_971); view_971 = None + convert_element_type_443 = torch.ops.prims.convert_element_type.default(view_969, torch.float32); view_969 = None + view_972 = torch.ops.aten.view.default(convert_element_type_443, [2, 8192, 1, -1, 2]); convert_element_type_443 = None + view_as_complex_27 = torch.ops.aten.view_as_complex.default(view_972); view_972 = None + mul_106 = torch.ops.aten.mul.Tensor(view_as_complex_26, view_37); view_as_complex_26 = None + view_as_real_26 = torch.ops.aten.view_as_real.default(mul_106); mul_106 = None + view_974 = 
torch.ops.aten.view.default(view_as_real_26, [2, 8192, 4, 128]); view_as_real_26 = None + mul_107 = torch.ops.aten.mul.Tensor(view_as_complex_27, view_37); view_as_complex_27 = None + view_as_real_27 = torch.ops.aten.view_as_real.default(mul_107); mul_107 = None + view_975 = torch.ops.aten.view.default(view_as_real_27, [2, 8192, 1, 128]); view_as_real_27 = None + convert_element_type_444 = torch.ops.prims.convert_element_type.default(view_974, torch.bfloat16); view_974 = None + convert_element_type_445 = torch.ops.prims.convert_element_type.default(view_975, torch.bfloat16); view_975 = None + unsqueeze_26 = torch.ops.aten.unsqueeze.default(convert_element_type_445, 3); convert_element_type_445 = None + expand_26 = torch.ops.aten.expand.default(unsqueeze_26, [2, 8192, 1, 4, 128]); unsqueeze_26 = None + view_976 = torch.ops.aten.view.default(expand_26, [2, 8192, 4, 128]); expand_26 = None + unsqueeze_27 = torch.ops.aten.unsqueeze.default(view_970, 3); view_970 = None + expand_27 = torch.ops.aten.expand.default(unsqueeze_27, [2, 8192, 1, 4, 128]); unsqueeze_27 = None + view_977 = torch.ops.aten.view.default(expand_27, [2, 8192, 4, 128]); expand_27 = None + permute_146 = torch.ops.aten.permute.default(convert_element_type_444, [0, 2, 1, 3]); convert_element_type_444 = None + permute_147 = torch.ops.aten.permute.default(view_976, [0, 2, 1, 3]); view_976 = None + permute_148 = torch.ops.aten.permute.default(view_977, [0, 2, 1, 3]); view_977 = None + _scaled_dot_product_cudnn_attention_backward_2 = torch.ops.aten._scaled_dot_product_cudnn_attention_backward.default(permute_261, permute_146, permute_147, permute_148, getitem_613, getitem_614, getitem_619, getitem_620, None, None, None, 8192, 8192, 0.0, True); permute_261 = permute_146 = permute_147 = permute_148 = getitem_613 = getitem_614 = getitem_619 = getitem_620 = None + getitem_838 = _scaled_dot_product_cudnn_attention_backward_2[0] + getitem_839 = _scaled_dot_product_cudnn_attention_backward_2[1] + getitem_840 = 
_scaled_dot_product_cudnn_attention_backward_2[2]; _scaled_dot_product_cudnn_attention_backward_2 = None + permute_262 = torch.ops.aten.permute.default(getitem_840, [0, 2, 1, 3]); getitem_840 = None + permute_263 = torch.ops.aten.permute.default(getitem_839, [0, 2, 1, 3]); getitem_839 = None + permute_264 = torch.ops.aten.permute.default(getitem_838, [0, 2, 1, 3]); getitem_838 = None + view_1228 = torch.ops.aten.view.default(permute_262, [2, 8192, 1, 4, 128]); permute_262 = None + sum_17 = torch.ops.aten.sum.dim_IntList(view_1228, [3], True); view_1228 = None + squeeze_4 = torch.ops.aten.squeeze.dim(sum_17, 3); sum_17 = None + view_1229 = torch.ops.aten.view.default(permute_263, [2, 8192, 1, 4, 128]); permute_263 = None + sum_18 = torch.ops.aten.sum.dim_IntList(view_1229, [3], True); view_1229 = None + squeeze_5 = torch.ops.aten.squeeze.dim(sum_18, 3); sum_18 = None + convert_element_type_683 = torch.ops.prims.convert_element_type.default(squeeze_5, torch.float32); squeeze_5 = None + convert_element_type_684 = torch.ops.prims.convert_element_type.default(permute_264, torch.float32); permute_264 = None + view_1230 = torch.ops.aten.view.default(convert_element_type_683, [2, 8192, 1, 64, 2]); convert_element_type_683 = None + view_as_complex_36 = torch.ops.aten.view_as_complex.default(view_1230); view_1230 = None + mul_188 = torch.ops.aten.mul.Tensor(view_as_complex_36, _conj); view_as_complex_36 = None + view_1231 = torch.ops.aten.view.default(convert_element_type_684, [2, 8192, 4, 64, 2]); convert_element_type_684 = None + view_as_complex_37 = torch.ops.aten.view_as_complex.default(view_1231); view_1231 = None + mul_189 = torch.ops.aten.mul.Tensor(view_as_complex_37, _conj); view_as_complex_37 = None + view_as_real_36 = torch.ops.aten.view_as_real.default(mul_188); mul_188 = None + view_1232 = torch.ops.aten.view.default(view_as_real_36, [2, 8192, 1, 128]); view_as_real_36 = None + convert_element_type_685 = torch.ops.prims.convert_element_type.default(view_1232, 
torch.bfloat16); view_1232 = None + view_as_real_37 = torch.ops.aten.view_as_real.default(mul_189); mul_189 = None + view_1233 = torch.ops.aten.view.default(view_as_real_37, [2, 8192, 4, 128]); view_as_real_37 = None + convert_element_type_686 = torch.ops.prims.convert_element_type.default(view_1233, torch.bfloat16); view_1233 = None + view_1234 = torch.ops.aten.view.default(squeeze_4, [2, 8192, 128]); squeeze_4 = None + view_1235 = torch.ops.aten.view.default(convert_element_type_685, [2, 8192, 128]); convert_element_type_685 = None + view_1236 = torch.ops.aten.view.default(convert_element_type_686, [2, 8192, 512]); convert_element_type_686 = None + view_1237 = torch.ops.aten.view.default(view_1234, [16384, 128]); view_1234 = None + permute_265 = torch.ops.aten.permute.default(view_1237, [1, 0]) + mm_151 = torch.ops.aten.mm.default(permute_265, view_951); permute_265 = None + convert_element_type_439 = torch.ops.prims.convert_element_type.default(primals_124, torch.bfloat16); primals_124 = None + all_gather_into_tensor_148 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_439, 8, '0'); convert_element_type_439 = None + wait_tensor_175 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_148); all_gather_into_tensor_148 = None + permute_145 = torch.ops.aten.permute.default(wait_tensor_175, [1, 0]); wait_tensor_175 = None + permute_267 = torch.ops.aten.permute.default(permute_145, [1, 0]); permute_145 = None + mm_152 = torch.ops.aten.mm.default(view_1237, permute_267); view_1237 = permute_267 = None + view_1238 = torch.ops.aten.view.default(mm_152, [2, 8192, 4096]); mm_152 = None + convert_element_type_691 = torch.ops.prims.convert_element_type.default(mm_151, torch.float32); mm_151 = None + reduce_scatter_tensor_64 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_691, 'avg', 8, '0'); convert_element_type_691 = None + wait_tensor_256 = 
torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_64); reduce_scatter_tensor_64 = None + view_1239 = torch.ops.aten.view.default(view_1235, [16384, 128]); view_1235 = None + permute_269 = torch.ops.aten.permute.default(view_1239, [1, 0]) + mm_153 = torch.ops.aten.mm.default(permute_269, view_951); permute_269 = None + permute_271 = torch.ops.aten.permute.default(permute_144, [1, 0]); permute_144 = None + mm_154 = torch.ops.aten.mm.default(view_1239, permute_271); view_1239 = permute_271 = None + view_1240 = torch.ops.aten.view.default(mm_154, [2, 8192, 4096]); mm_154 = None + add_83 = torch.ops.aten.add.Tensor(view_1238, view_1240); view_1238 = view_1240 = None + convert_element_type_696 = torch.ops.prims.convert_element_type.default(mm_153, torch.float32); mm_153 = None + reduce_scatter_tensor_65 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_696, 'avg', 8, '0'); convert_element_type_696 = None + wait_tensor_257 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_65); reduce_scatter_tensor_65 = None + view_1241 = torch.ops.aten.view.default(view_1236, [16384, 512]); view_1236 = None + permute_273 = torch.ops.aten.permute.default(view_1241, [1, 0]) + mm_155 = torch.ops.aten.mm.default(permute_273, view_951); permute_273 = view_951 = None + convert_element_type_433 = torch.ops.prims.convert_element_type.default(primals_122, torch.bfloat16); primals_122 = None + all_gather_into_tensor_146 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_433, 8, '0'); convert_element_type_433 = None + wait_tensor_173 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_146); all_gather_into_tensor_146 = None + permute_143 = torch.ops.aten.permute.default(wait_tensor_173, [1, 0]); wait_tensor_173 = None + permute_275 = torch.ops.aten.permute.default(permute_143, [1, 0]); permute_143 = None + mm_156 = torch.ops.aten.mm.default(view_1241, permute_275); view_1241 = 
permute_275 = None + view_1242 = torch.ops.aten.view.default(mm_156, [2, 8192, 4096]); mm_156 = None + add_84 = torch.ops.aten.add.Tensor(add_83, view_1242); add_83 = view_1242 = None + convert_element_type_701 = torch.ops.prims.convert_element_type.default(mm_155, torch.float32); mm_155 = None + reduce_scatter_tensor_66 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_701, 'avg', 8, '0'); convert_element_type_701 = None + wait_tensor_258 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_66); reduce_scatter_tensor_66 = None + split_86 = torch.ops.aten.split.Tensor(add_84, 1024, 1); add_84 = None + getitem_841 = split_86[0] + getitem_842 = split_86[1] + getitem_843 = split_86[2] + getitem_844 = split_86[3] + getitem_845 = split_86[4] + getitem_846 = split_86[5] + getitem_847 = split_86[6] + getitem_848 = split_86[7]; split_86 = None + cat_78 = torch.ops.aten.cat.default([getitem_841, getitem_842, getitem_843, getitem_844, getitem_845, getitem_846, getitem_847, getitem_848]); getitem_841 = getitem_842 = getitem_843 = getitem_844 = getitem_845 = getitem_846 = getitem_847 = getitem_848 = None + reduce_scatter_tensor_67 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_78, 'sum', 8, '1'); cat_78 = None + wait_tensor_259 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_67); reduce_scatter_tensor_67 = None + convert_element_type_702 = torch.ops.prims.convert_element_type.default(wait_tensor_259, torch.float32); wait_tensor_259 = None + convert_element_type_704 = torch.ops.prims.convert_element_type.default(wait_tensor_171, torch.float32); wait_tensor_171 = None + mul_190 = torch.ops.aten.mul.Tensor(convert_element_type_702, convert_element_type_704); convert_element_type_704 = None + mul_192 = torch.ops.aten.mul.Tensor(mul_104, mul_190) + sum_19 = torch.ops.aten.sum.dim_IntList(mul_192, [2], True); mul_192 = None + div_6 = torch.ops.aten.div.Tensor(mul_104, 4096) + mul_193 = 
torch.ops.aten.mul.Tensor(div_6, sum_19); div_6 = sum_19 = None + sub_10 = torch.ops.aten.sub.Tensor(mul_190, mul_193); mul_190 = mul_193 = None + mul_194 = torch.ops.aten.mul.Tensor(sub_10, rsqrt_26); sub_10 = rsqrt_26 = None + mul_195 = torch.ops.aten.mul.Tensor(convert_element_type_702, mul_104); convert_element_type_702 = mul_104 = None + sum_20 = torch.ops.aten.sum.dim_IntList(mul_195, [0, 1]); mul_195 = None + convert_element_type_705 = torch.ops.prims.convert_element_type.default(mul_194, torch.bfloat16); mul_194 = None + convert_element_type_706 = torch.ops.prims.convert_element_type.default(sum_20, torch.bfloat16); sum_20 = None + all_reduce_6 = torch.ops._c10d_functional.all_reduce.default(convert_element_type_706, 'sum', '1'); convert_element_type_706 = None + wait_tensor_260 = torch.ops._c10d_functional.wait_tensor.default(all_reduce_6); all_reduce_6 = None + convert_element_type_707 = torch.ops.prims.convert_element_type.default(wait_tensor_260, torch.float32); wait_tensor_260 = None + reduce_scatter_tensor_68 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_707, 'avg', 8, '0'); convert_element_type_707 = None + wait_tensor_261 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_68); reduce_scatter_tensor_68 = None + add_85 = torch.ops.aten.add.Tensor(add_82, convert_element_type_705); add_82 = convert_element_type_705 = None + all_gather_into_tensor_186 = torch.ops._c10d_functional.all_gather_into_tensor.default(add_85, 8, '1') + wait_tensor_262 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_186); all_gather_into_tensor_186 = None + split_87 = torch.ops.aten.split.Tensor(wait_tensor_262, 2); wait_tensor_262 = None + getitem_849 = split_87[0] + getitem_850 = split_87[1] + getitem_851 = split_87[2] + getitem_852 = split_87[3] + getitem_853 = split_87[4] + getitem_854 = split_87[5] + getitem_855 = split_87[6] + getitem_856 = split_87[7]; split_87 = None + cat_79 = 
torch.ops.aten.cat.default([getitem_849, getitem_850, getitem_851, getitem_852, getitem_853, getitem_854, getitem_855, getitem_856], 1); getitem_849 = getitem_850 = getitem_851 = getitem_852 = getitem_853 = getitem_854 = getitem_855 = getitem_856 = None + view_1243 = torch.ops.aten.view.default(cat_79, [16384, 4096]); cat_79 = None + permute_277 = torch.ops.aten.permute.default(view_1243, [1, 0]) + wait_tensor_164 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_25); reduce_scatter_tensor_25 = None + add_49 = torch.ops.aten.add.Tensor(add_47, wait_tensor_164); wait_tensor_164 = None + convert_element_type_416 = torch.ops.prims.convert_element_type.default(primals_117, torch.bfloat16); primals_117 = None + all_gather_into_tensor_139 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_416, 8, '0'); convert_element_type_416 = None + wait_tensor_165 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_139); all_gather_into_tensor_139 = None + convert_element_type_417 = torch.ops.prims.convert_element_type.default(add_49, torch.float32); add_49 = None + pow_26 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_417, 2) + mean_25 = torch.ops.aten.mean.dim(pow_26, [2], True); pow_26 = None + add_50 = torch.ops.aten.add.Scalar(mean_25, 1e-05); mean_25 = None + rsqrt_25 = torch.ops.aten.rsqrt.default(add_50); add_50 = None + mul_100 = torch.ops.aten.mul.Tensor(convert_element_type_417, rsqrt_25); convert_element_type_417 = None + mul_101 = torch.ops.aten.mul.Tensor(mul_100, wait_tensor_165) + convert_element_type_418 = torch.ops.prims.convert_element_type.default(mul_101, torch.bfloat16); mul_101 = None + all_gather_into_tensor_140 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_418, 8, '1'); convert_element_type_418 = None + wait_tensor_166 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_140); all_gather_into_tensor_140 = None + split_59 = 
torch.ops.aten.split.Tensor(wait_tensor_166, 2); wait_tensor_166 = None + getitem_589 = split_59[0] + getitem_590 = split_59[1] + getitem_591 = split_59[2] + getitem_592 = split_59[3] + getitem_593 = split_59[4] + getitem_594 = split_59[5] + getitem_595 = split_59[6] + getitem_596 = split_59[7]; split_59 = None + cat_51 = torch.ops.aten.cat.default([getitem_589, getitem_590, getitem_591, getitem_592, getitem_593, getitem_594, getitem_595, getitem_596], 1); getitem_589 = getitem_590 = getitem_591 = getitem_592 = getitem_593 = getitem_594 = getitem_595 = getitem_596 = None + view_924 = torch.ops.aten.view.default(cat_51, [16384, 4096]); cat_51 = None + view_925 = torch.ops.aten.view.default(mm_88, [2, 8192, 1792]); mm_88 = None + convert_element_type_422 = torch.ops.prims.convert_element_type.default(view_925, torch.float32); view_925 = None + sigmoid_12 = torch.ops.aten.sigmoid.default(convert_element_type_422) + mul_102 = torch.ops.aten.mul.Tensor(convert_element_type_422, sigmoid_12); sigmoid_12 = None + convert_element_type_423 = torch.ops.prims.convert_element_type.default(mul_102, torch.bfloat16); mul_102 = None + convert_element_type_424 = torch.ops.prims.convert_element_type.default(primals_119, torch.bfloat16); primals_119 = None + all_gather_into_tensor_142 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_424, 8, '0'); convert_element_type_424 = None + wait_tensor_168 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_142); all_gather_into_tensor_142 = None + permute_141 = torch.ops.aten.permute.default(wait_tensor_168, [1, 0]); wait_tensor_168 = None + mm_89 = torch.ops.aten.mm.default(view_924, permute_141) + view_932 = torch.ops.aten.view.default(mm_89, [2, 8192, 1792]); mm_89 = None + mul_103 = torch.ops.aten.mul.Tensor(convert_element_type_423, view_932) + view_939 = torch.ops.aten.view.default(mul_103, [16384, 1792]); mul_103 = None + mm_157 = torch.ops.aten.mm.default(permute_277, view_939); 
permute_277 = view_939 = None + convert_element_type_427 = torch.ops.prims.convert_element_type.default(primals_120, torch.bfloat16); primals_120 = None + all_gather_into_tensor_143 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_427, 8, '0'); convert_element_type_427 = None + wait_tensor_169 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_143); all_gather_into_tensor_143 = None + permute_142 = torch.ops.aten.permute.default(wait_tensor_169, [1, 0]); wait_tensor_169 = None + permute_279 = torch.ops.aten.permute.default(permute_142, [1, 0]); permute_142 = None + mm_158 = torch.ops.aten.mm.default(view_1243, permute_279); view_1243 = permute_279 = None + view_1244 = torch.ops.aten.view.default(mm_158, [2, 8192, 1792]); mm_158 = None + convert_element_type_712 = torch.ops.prims.convert_element_type.default(mm_157, torch.float32); mm_157 = None + reduce_scatter_tensor_69 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_712, 'avg', 8, '0'); convert_element_type_712 = None + wait_tensor_263 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_69); reduce_scatter_tensor_69 = None + mul_196 = torch.ops.aten.mul.Tensor(view_1244, convert_element_type_423); convert_element_type_423 = None + mul_197 = torch.ops.aten.mul.Tensor(view_1244, view_932); view_1244 = view_932 = None + view_1245 = torch.ops.aten.view.default(mul_196, [16384, 1792]); mul_196 = None + permute_281 = torch.ops.aten.permute.default(view_1245, [1, 0]) + mm_159 = torch.ops.aten.mm.default(permute_281, view_924); permute_281 = None + permute_283 = torch.ops.aten.permute.default(permute_141, [1, 0]); permute_141 = None + mm_160 = torch.ops.aten.mm.default(view_1245, permute_283); view_1245 = permute_283 = None + view_1246 = torch.ops.aten.view.default(mm_160, [2, 8192, 4096]); mm_160 = None + convert_element_type_717 = torch.ops.prims.convert_element_type.default(mm_159, torch.float32); mm_159 = None + 
reduce_scatter_tensor_70 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_717, 'avg', 8, '0'); convert_element_type_717 = None + wait_tensor_264 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_70); reduce_scatter_tensor_70 = None + convert_element_type_718 = torch.ops.prims.convert_element_type.default(mul_197, torch.float32); mul_197 = None + neg_3 = torch.ops.aten.neg.default(convert_element_type_422) + exp_3 = torch.ops.aten.exp.default(neg_3); neg_3 = None + add_86 = torch.ops.aten.add.Tensor(exp_3, 1); exp_3 = None + reciprocal_3 = torch.ops.aten.reciprocal.default(add_86); add_86 = None + mul_198 = torch.ops.aten.mul.Tensor(reciprocal_3, 1); reciprocal_3 = None + mul_199 = torch.ops.aten.mul.Tensor(convert_element_type_718, mul_198); convert_element_type_718 = None + sub_11 = torch.ops.aten.sub.Tensor(1, mul_198); mul_198 = None + mul_200 = torch.ops.aten.mul.Tensor(convert_element_type_422, sub_11); convert_element_type_422 = sub_11 = None + add_87 = torch.ops.aten.add.Tensor(mul_200, 1); mul_200 = None + mul_201 = torch.ops.aten.mul.Tensor(mul_199, add_87); mul_199 = add_87 = None + convert_element_type_720 = torch.ops.prims.convert_element_type.default(mul_201, torch.bfloat16); mul_201 = None + view_1247 = torch.ops.aten.view.default(convert_element_type_720, [16384, 1792]); convert_element_type_720 = None + permute_285 = torch.ops.aten.permute.default(view_1247, [1, 0]) + mm_161 = torch.ops.aten.mm.default(permute_285, view_924); permute_285 = view_924 = None + convert_element_type_419 = torch.ops.prims.convert_element_type.default(primals_118, torch.bfloat16); primals_118 = None + all_gather_into_tensor_141 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_419, 8, '0'); convert_element_type_419 = None + wait_tensor_167 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_141); all_gather_into_tensor_141 = None + permute_140 = 
torch.ops.aten.permute.default(wait_tensor_167, [1, 0]); wait_tensor_167 = None + permute_287 = torch.ops.aten.permute.default(permute_140, [1, 0]); permute_140 = None + mm_162 = torch.ops.aten.mm.default(view_1247, permute_287); view_1247 = permute_287 = None + view_1248 = torch.ops.aten.view.default(mm_162, [2, 8192, 4096]); mm_162 = None + add_88 = torch.ops.aten.add.Tensor(view_1246, view_1248); view_1246 = view_1248 = None + convert_element_type_725 = torch.ops.prims.convert_element_type.default(mm_161, torch.float32); mm_161 = None + reduce_scatter_tensor_71 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_725, 'avg', 8, '0'); convert_element_type_725 = None + wait_tensor_265 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_71); reduce_scatter_tensor_71 = None + split_88 = torch.ops.aten.split.Tensor(add_88, 1024, 1); add_88 = None + getitem_857 = split_88[0] + getitem_858 = split_88[1] + getitem_859 = split_88[2] + getitem_860 = split_88[3] + getitem_861 = split_88[4] + getitem_862 = split_88[5] + getitem_863 = split_88[6] + getitem_864 = split_88[7]; split_88 = None + cat_80 = torch.ops.aten.cat.default([getitem_857, getitem_858, getitem_859, getitem_860, getitem_861, getitem_862, getitem_863, getitem_864]); getitem_857 = getitem_858 = getitem_859 = getitem_860 = getitem_861 = getitem_862 = getitem_863 = getitem_864 = None + reduce_scatter_tensor_72 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_80, 'sum', 8, '1'); cat_80 = None + wait_tensor_266 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_72); reduce_scatter_tensor_72 = None + convert_element_type_726 = torch.ops.prims.convert_element_type.default(wait_tensor_266, torch.float32); wait_tensor_266 = None + convert_element_type_728 = torch.ops.prims.convert_element_type.default(wait_tensor_165, torch.float32); wait_tensor_165 = None + mul_202 = torch.ops.aten.mul.Tensor(convert_element_type_726, 
convert_element_type_728); convert_element_type_728 = None + mul_204 = torch.ops.aten.mul.Tensor(mul_100, mul_202) + sum_21 = torch.ops.aten.sum.dim_IntList(mul_204, [2], True); mul_204 = None + div_7 = torch.ops.aten.div.Tensor(mul_100, 4096) + mul_205 = torch.ops.aten.mul.Tensor(div_7, sum_21); div_7 = sum_21 = None + sub_12 = torch.ops.aten.sub.Tensor(mul_202, mul_205); mul_202 = mul_205 = None + mul_206 = torch.ops.aten.mul.Tensor(sub_12, rsqrt_25); sub_12 = rsqrt_25 = None + mul_207 = torch.ops.aten.mul.Tensor(convert_element_type_726, mul_100); convert_element_type_726 = mul_100 = None + sum_22 = torch.ops.aten.sum.dim_IntList(mul_207, [0, 1]); mul_207 = None + convert_element_type_729 = torch.ops.prims.convert_element_type.default(mul_206, torch.bfloat16); mul_206 = None + convert_element_type_730 = torch.ops.prims.convert_element_type.default(sum_22, torch.bfloat16); sum_22 = None + all_reduce_7 = torch.ops._c10d_functional.all_reduce.default(convert_element_type_730, 'sum', '1'); convert_element_type_730 = None + wait_tensor_267 = torch.ops._c10d_functional.wait_tensor.default(all_reduce_7); all_reduce_7 = None + convert_element_type_731 = torch.ops.prims.convert_element_type.default(wait_tensor_267, torch.float32); wait_tensor_267 = None + reduce_scatter_tensor_73 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_731, 'avg', 8, '0'); convert_element_type_731 = None + wait_tensor_268 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_73); reduce_scatter_tensor_73 = None + add_89 = torch.ops.aten.add.Tensor(add_85, convert_element_type_729); add_85 = convert_element_type_729 = None + all_gather_into_tensor_187 = torch.ops._c10d_functional.all_gather_into_tensor.default(add_89, 8, '1') + wait_tensor_269 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_187); all_gather_into_tensor_187 = None + split_89 = torch.ops.aten.split.Tensor(wait_tensor_269, 2); wait_tensor_269 = None + 
getitem_865 = split_89[0] + getitem_866 = split_89[1] + getitem_867 = split_89[2] + getitem_868 = split_89[3] + getitem_869 = split_89[4] + getitem_870 = split_89[5] + getitem_871 = split_89[6] + getitem_872 = split_89[7]; split_89 = None + cat_81 = torch.ops.aten.cat.default([getitem_865, getitem_866, getitem_867, getitem_868, getitem_869, getitem_870, getitem_871, getitem_872], 1); getitem_865 = getitem_866 = getitem_867 = getitem_868 = getitem_869 = getitem_870 = getitem_871 = getitem_872 = None + view_1249 = torch.ops.aten.view.default(cat_81, [16384, 4096]); cat_81 = None + permute_289 = torch.ops.aten.permute.default(view_1249, [1, 0]) + permute_138 = torch.ops.aten.permute.default(getitem_572, [0, 2, 1, 3]) + view_906 = torch.ops.aten.view.default(permute_138, [2, 8192, -1]); permute_138 = None + view_912 = torch.ops.aten.view.default(view_906, [16384, 512]); view_906 = None + mm_163 = torch.ops.aten.mm.default(permute_289, view_912); permute_289 = view_912 = None + convert_element_type_413 = torch.ops.prims.convert_element_type.default(primals_116, torch.bfloat16); primals_116 = None + all_gather_into_tensor_138 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_413, 8, '0'); convert_element_type_413 = None + wait_tensor_163 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_138); all_gather_into_tensor_138 = None + permute_139 = torch.ops.aten.permute.default(wait_tensor_163, [1, 0]); wait_tensor_163 = None + permute_291 = torch.ops.aten.permute.default(permute_139, [1, 0]); permute_139 = None + mm_164 = torch.ops.aten.mm.default(view_1249, permute_291); view_1249 = permute_291 = None + view_1250 = torch.ops.aten.view.default(mm_164, [2, 8192, 512]); mm_164 = None + convert_element_type_736 = torch.ops.prims.convert_element_type.default(mm_163, torch.float32); mm_163 = None + reduce_scatter_tensor_74 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_736, 'avg', 8, '0'); 
convert_element_type_736 = None + wait_tensor_270 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_74); reduce_scatter_tensor_74 = None + view_1251 = torch.ops.aten.view.default(view_1250, [2, 8192, 4, 128]); view_1250 = None + permute_293 = torch.ops.aten.permute.default(view_1251, [0, 2, 1, 3]); view_1251 = None + convert_element_type_397 = torch.ops.prims.convert_element_type.default(primals_112, torch.bfloat16); primals_112 = None + all_gather_into_tensor_133 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_397, 8, '0'); convert_element_type_397 = None + wait_tensor_158 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_133); all_gather_into_tensor_133 = None + convert_element_type_398 = torch.ops.prims.convert_element_type.default(add_47, torch.float32); add_47 = None + pow_25 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_398, 2) + mean_24 = torch.ops.aten.mean.dim(pow_25, [2], True); pow_25 = None + add_48 = torch.ops.aten.add.Scalar(mean_24, 1e-05); mean_24 = None + rsqrt_24 = torch.ops.aten.rsqrt.default(add_48); add_48 = None + mul_96 = torch.ops.aten.mul.Tensor(convert_element_type_398, rsqrt_24); convert_element_type_398 = None + mul_97 = torch.ops.aten.mul.Tensor(mul_96, wait_tensor_158) + convert_element_type_399 = torch.ops.prims.convert_element_type.default(mul_97, torch.bfloat16); mul_97 = None + all_gather_into_tensor_134 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_399, 8, '1'); convert_element_type_399 = None + wait_tensor_159 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_134); all_gather_into_tensor_134 = None + split_57 = torch.ops.aten.split.Tensor(wait_tensor_159, 2); wait_tensor_159 = None + getitem_564 = split_57[0] + getitem_565 = split_57[1] + getitem_566 = split_57[2] + getitem_567 = split_57[3] + getitem_568 = split_57[4] + getitem_569 = split_57[5] + getitem_570 = split_57[6] + getitem_571 = 
split_57[7]; split_57 = None + cat_49 = torch.ops.aten.cat.default([getitem_564, getitem_565, getitem_566, getitem_567, getitem_568, getitem_569, getitem_570, getitem_571], 1); getitem_564 = getitem_565 = getitem_566 = getitem_567 = getitem_568 = getitem_569 = getitem_570 = getitem_571 = None + view_879 = torch.ops.aten.view.default(cat_49, [16384, 4096]); cat_49 = None + view_880 = torch.ops.aten.view.default(mm_84, [2, 8192, 512]); mm_84 = None + convert_element_type_403 = torch.ops.prims.convert_element_type.default(primals_114, torch.bfloat16); primals_114 = None + all_gather_into_tensor_136 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_403, 8, '0'); convert_element_type_403 = None + wait_tensor_161 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_136); all_gather_into_tensor_136 = None + permute_133 = torch.ops.aten.permute.default(wait_tensor_161, [1, 0]); wait_tensor_161 = None + mm_85 = torch.ops.aten.mm.default(view_879, permute_133) + view_887 = torch.ops.aten.view.default(mm_85, [2, 8192, 128]); mm_85 = None + view_894 = torch.ops.aten.view.default(mm_86, [2, 8192, 128]); mm_86 = None + view_896 = torch.ops.aten.view.default(view_880, [2, 8192, -1, 128]); view_880 = None + view_897 = torch.ops.aten.view.default(view_887, [2, 8192, -1, 128]); view_887 = None + view_898 = torch.ops.aten.view.default(view_894, [2, 8192, -1, 128]); view_894 = None + convert_element_type_409 = torch.ops.prims.convert_element_type.default(view_896, torch.float32); view_896 = None + view_899 = torch.ops.aten.view.default(convert_element_type_409, [2, 8192, 4, -1, 2]); convert_element_type_409 = None + view_as_complex_24 = torch.ops.aten.view_as_complex.default(view_899); view_899 = None + convert_element_type_410 = torch.ops.prims.convert_element_type.default(view_897, torch.float32); view_897 = None + view_900 = torch.ops.aten.view.default(convert_element_type_410, [2, 8192, 1, -1, 2]); convert_element_type_410 = None + 
view_as_complex_25 = torch.ops.aten.view_as_complex.default(view_900); view_900 = None + mul_98 = torch.ops.aten.mul.Tensor(view_as_complex_24, view_37); view_as_complex_24 = None + view_as_real_24 = torch.ops.aten.view_as_real.default(mul_98); mul_98 = None + view_902 = torch.ops.aten.view.default(view_as_real_24, [2, 8192, 4, 128]); view_as_real_24 = None + mul_99 = torch.ops.aten.mul.Tensor(view_as_complex_25, view_37); view_as_complex_25 = None + view_as_real_25 = torch.ops.aten.view_as_real.default(mul_99); mul_99 = None + view_903 = torch.ops.aten.view.default(view_as_real_25, [2, 8192, 1, 128]); view_as_real_25 = None + convert_element_type_411 = torch.ops.prims.convert_element_type.default(view_902, torch.bfloat16); view_902 = None + convert_element_type_412 = torch.ops.prims.convert_element_type.default(view_903, torch.bfloat16); view_903 = None + unsqueeze_24 = torch.ops.aten.unsqueeze.default(convert_element_type_412, 3); convert_element_type_412 = None + expand_24 = torch.ops.aten.expand.default(unsqueeze_24, [2, 8192, 1, 4, 128]); unsqueeze_24 = None + view_904 = torch.ops.aten.view.default(expand_24, [2, 8192, 4, 128]); expand_24 = None + unsqueeze_25 = torch.ops.aten.unsqueeze.default(view_898, 3); view_898 = None + expand_25 = torch.ops.aten.expand.default(unsqueeze_25, [2, 8192, 1, 4, 128]); unsqueeze_25 = None + view_905 = torch.ops.aten.view.default(expand_25, [2, 8192, 4, 128]); expand_25 = None + permute_135 = torch.ops.aten.permute.default(convert_element_type_411, [0, 2, 1, 3]); convert_element_type_411 = None + permute_136 = torch.ops.aten.permute.default(view_904, [0, 2, 1, 3]); view_904 = None + permute_137 = torch.ops.aten.permute.default(view_905, [0, 2, 1, 3]); view_905 = None + _scaled_dot_product_cudnn_attention_backward_3 = torch.ops.aten._scaled_dot_product_cudnn_attention_backward.default(permute_293, permute_135, permute_136, permute_137, getitem_572, getitem_573, getitem_578, getitem_579, None, None, None, 8192, 8192, 0.0, True); 
permute_293 = permute_135 = permute_136 = permute_137 = getitem_572 = getitem_573 = getitem_578 = getitem_579 = None + getitem_873 = _scaled_dot_product_cudnn_attention_backward_3[0] + getitem_874 = _scaled_dot_product_cudnn_attention_backward_3[1] + getitem_875 = _scaled_dot_product_cudnn_attention_backward_3[2]; _scaled_dot_product_cudnn_attention_backward_3 = None + permute_294 = torch.ops.aten.permute.default(getitem_875, [0, 2, 1, 3]); getitem_875 = None + permute_295 = torch.ops.aten.permute.default(getitem_874, [0, 2, 1, 3]); getitem_874 = None + permute_296 = torch.ops.aten.permute.default(getitem_873, [0, 2, 1, 3]); getitem_873 = None + view_1252 = torch.ops.aten.view.default(permute_294, [2, 8192, 1, 4, 128]); permute_294 = None + sum_23 = torch.ops.aten.sum.dim_IntList(view_1252, [3], True); view_1252 = None + squeeze_6 = torch.ops.aten.squeeze.dim(sum_23, 3); sum_23 = None + view_1253 = torch.ops.aten.view.default(permute_295, [2, 8192, 1, 4, 128]); permute_295 = None + sum_24 = torch.ops.aten.sum.dim_IntList(view_1253, [3], True); view_1253 = None + squeeze_7 = torch.ops.aten.squeeze.dim(sum_24, 3); sum_24 = None + convert_element_type_737 = torch.ops.prims.convert_element_type.default(squeeze_7, torch.float32); squeeze_7 = None + convert_element_type_738 = torch.ops.prims.convert_element_type.default(permute_296, torch.float32); permute_296 = None + view_1254 = torch.ops.aten.view.default(convert_element_type_737, [2, 8192, 1, 64, 2]); convert_element_type_737 = None + view_as_complex_38 = torch.ops.aten.view_as_complex.default(view_1254); view_1254 = None + mul_208 = torch.ops.aten.mul.Tensor(view_as_complex_38, _conj); view_as_complex_38 = None + view_1255 = torch.ops.aten.view.default(convert_element_type_738, [2, 8192, 4, 64, 2]); convert_element_type_738 = None + view_as_complex_39 = torch.ops.aten.view_as_complex.default(view_1255); view_1255 = None + mul_209 = torch.ops.aten.mul.Tensor(view_as_complex_39, _conj); view_as_complex_39 = None + 
view_as_real_38 = torch.ops.aten.view_as_real.default(mul_208); mul_208 = None + view_1256 = torch.ops.aten.view.default(view_as_real_38, [2, 8192, 1, 128]); view_as_real_38 = None + convert_element_type_739 = torch.ops.prims.convert_element_type.default(view_1256, torch.bfloat16); view_1256 = None + view_as_real_39 = torch.ops.aten.view_as_real.default(mul_209); mul_209 = None + view_1257 = torch.ops.aten.view.default(view_as_real_39, [2, 8192, 4, 128]); view_as_real_39 = None + convert_element_type_740 = torch.ops.prims.convert_element_type.default(view_1257, torch.bfloat16); view_1257 = None + view_1258 = torch.ops.aten.view.default(squeeze_6, [2, 8192, 128]); squeeze_6 = None + view_1259 = torch.ops.aten.view.default(convert_element_type_739, [2, 8192, 128]); convert_element_type_739 = None + view_1260 = torch.ops.aten.view.default(convert_element_type_740, [2, 8192, 512]); convert_element_type_740 = None + view_1261 = torch.ops.aten.view.default(view_1258, [16384, 128]); view_1258 = None + permute_297 = torch.ops.aten.permute.default(view_1261, [1, 0]) + mm_165 = torch.ops.aten.mm.default(permute_297, view_879); permute_297 = None + convert_element_type_406 = torch.ops.prims.convert_element_type.default(primals_115, torch.bfloat16); primals_115 = None + all_gather_into_tensor_137 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_406, 8, '0'); convert_element_type_406 = None + wait_tensor_162 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_137); all_gather_into_tensor_137 = None + permute_134 = torch.ops.aten.permute.default(wait_tensor_162, [1, 0]); wait_tensor_162 = None + permute_299 = torch.ops.aten.permute.default(permute_134, [1, 0]); permute_134 = None + mm_166 = torch.ops.aten.mm.default(view_1261, permute_299); view_1261 = permute_299 = None + view_1262 = torch.ops.aten.view.default(mm_166, [2, 8192, 4096]); mm_166 = None + convert_element_type_745 = 
torch.ops.prims.convert_element_type.default(mm_165, torch.float32); mm_165 = None + reduce_scatter_tensor_75 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_745, 'avg', 8, '0'); convert_element_type_745 = None + wait_tensor_271 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_75); reduce_scatter_tensor_75 = None + view_1263 = torch.ops.aten.view.default(view_1259, [16384, 128]); view_1259 = None + permute_301 = torch.ops.aten.permute.default(view_1263, [1, 0]) + mm_167 = torch.ops.aten.mm.default(permute_301, view_879); permute_301 = None + permute_303 = torch.ops.aten.permute.default(permute_133, [1, 0]); permute_133 = None + mm_168 = torch.ops.aten.mm.default(view_1263, permute_303); view_1263 = permute_303 = None + view_1264 = torch.ops.aten.view.default(mm_168, [2, 8192, 4096]); mm_168 = None + add_90 = torch.ops.aten.add.Tensor(view_1262, view_1264); view_1262 = view_1264 = None + convert_element_type_750 = torch.ops.prims.convert_element_type.default(mm_167, torch.float32); mm_167 = None + reduce_scatter_tensor_76 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_750, 'avg', 8, '0'); convert_element_type_750 = None + wait_tensor_272 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_76); reduce_scatter_tensor_76 = None + view_1265 = torch.ops.aten.view.default(view_1260, [16384, 512]); view_1260 = None + permute_305 = torch.ops.aten.permute.default(view_1265, [1, 0]) + mm_169 = torch.ops.aten.mm.default(permute_305, view_879); permute_305 = view_879 = None + convert_element_type_400 = torch.ops.prims.convert_element_type.default(primals_113, torch.bfloat16); primals_113 = None + all_gather_into_tensor_135 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_400, 8, '0'); convert_element_type_400 = None + wait_tensor_160 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_135); all_gather_into_tensor_135 = 
None + permute_132 = torch.ops.aten.permute.default(wait_tensor_160, [1, 0]); wait_tensor_160 = None + permute_307 = torch.ops.aten.permute.default(permute_132, [1, 0]); permute_132 = None + mm_170 = torch.ops.aten.mm.default(view_1265, permute_307); view_1265 = permute_307 = None + view_1266 = torch.ops.aten.view.default(mm_170, [2, 8192, 4096]); mm_170 = None + add_91 = torch.ops.aten.add.Tensor(add_90, view_1266); add_90 = view_1266 = None + convert_element_type_755 = torch.ops.prims.convert_element_type.default(mm_169, torch.float32); mm_169 = None + reduce_scatter_tensor_77 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_755, 'avg', 8, '0'); convert_element_type_755 = None + wait_tensor_273 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_77); reduce_scatter_tensor_77 = None + split_90 = torch.ops.aten.split.Tensor(add_91, 1024, 1); add_91 = None + getitem_876 = split_90[0] + getitem_877 = split_90[1] + getitem_878 = split_90[2] + getitem_879 = split_90[3] + getitem_880 = split_90[4] + getitem_881 = split_90[5] + getitem_882 = split_90[6] + getitem_883 = split_90[7]; split_90 = None + cat_82 = torch.ops.aten.cat.default([getitem_876, getitem_877, getitem_878, getitem_879, getitem_880, getitem_881, getitem_882, getitem_883]); getitem_876 = getitem_877 = getitem_878 = getitem_879 = getitem_880 = getitem_881 = getitem_882 = getitem_883 = None + reduce_scatter_tensor_78 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_82, 'sum', 8, '1'); cat_82 = None + wait_tensor_274 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_78); reduce_scatter_tensor_78 = None + convert_element_type_756 = torch.ops.prims.convert_element_type.default(wait_tensor_274, torch.float32); wait_tensor_274 = None + convert_element_type_758 = torch.ops.prims.convert_element_type.default(wait_tensor_158, torch.float32); wait_tensor_158 = None + mul_210 = torch.ops.aten.mul.Tensor(convert_element_type_756, 
convert_element_type_758); convert_element_type_758 = None + mul_212 = torch.ops.aten.mul.Tensor(mul_96, mul_210) + sum_25 = torch.ops.aten.sum.dim_IntList(mul_212, [2], True); mul_212 = None + div_8 = torch.ops.aten.div.Tensor(mul_96, 4096) + mul_213 = torch.ops.aten.mul.Tensor(div_8, sum_25); div_8 = sum_25 = None + sub_13 = torch.ops.aten.sub.Tensor(mul_210, mul_213); mul_210 = mul_213 = None + mul_214 = torch.ops.aten.mul.Tensor(sub_13, rsqrt_24); sub_13 = rsqrt_24 = None + mul_215 = torch.ops.aten.mul.Tensor(convert_element_type_756, mul_96); convert_element_type_756 = mul_96 = None + sum_26 = torch.ops.aten.sum.dim_IntList(mul_215, [0, 1]); mul_215 = None + convert_element_type_759 = torch.ops.prims.convert_element_type.default(mul_214, torch.bfloat16); mul_214 = None + convert_element_type_760 = torch.ops.prims.convert_element_type.default(sum_26, torch.bfloat16); sum_26 = None + all_reduce_8 = torch.ops._c10d_functional.all_reduce.default(convert_element_type_760, 'sum', '1'); convert_element_type_760 = None + wait_tensor_275 = torch.ops._c10d_functional.wait_tensor.default(all_reduce_8); all_reduce_8 = None + convert_element_type_761 = torch.ops.prims.convert_element_type.default(wait_tensor_275, torch.float32); wait_tensor_275 = None + reduce_scatter_tensor_79 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_761, 'avg', 8, '0'); convert_element_type_761 = None + wait_tensor_276 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_79); reduce_scatter_tensor_79 = None + add_92 = torch.ops.aten.add.Tensor(add_89, convert_element_type_759); add_89 = convert_element_type_759 = None + all_gather_into_tensor_188 = torch.ops._c10d_functional.all_gather_into_tensor.default(add_92, 8, '1') + wait_tensor_277 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_188); all_gather_into_tensor_188 = None + split_91 = torch.ops.aten.split.Tensor(wait_tensor_277, 2); wait_tensor_277 = None + getitem_884 = 
split_91[0] + getitem_885 = split_91[1] + getitem_886 = split_91[2] + getitem_887 = split_91[3] + getitem_888 = split_91[4] + getitem_889 = split_91[5] + getitem_890 = split_91[6] + getitem_891 = split_91[7]; split_91 = None + cat_83 = torch.ops.aten.cat.default([getitem_884, getitem_885, getitem_886, getitem_887, getitem_888, getitem_889, getitem_890, getitem_891], 1); getitem_884 = getitem_885 = getitem_886 = getitem_887 = getitem_888 = getitem_889 = getitem_890 = getitem_891 = None + view_1267 = torch.ops.aten.view.default(cat_83, [16384, 4096]); cat_83 = None + permute_309 = torch.ops.aten.permute.default(view_1267, [1, 0]) + wait_tensor_151 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_23); reduce_scatter_tensor_23 = None + add_45 = torch.ops.aten.add.Tensor(add_43, wait_tensor_151); wait_tensor_151 = None + convert_element_type_383 = torch.ops.prims.convert_element_type.default(primals_108, torch.bfloat16); primals_108 = None + all_gather_into_tensor_128 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_383, 8, '0'); convert_element_type_383 = None + wait_tensor_152 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_128); all_gather_into_tensor_128 = None + convert_element_type_384 = torch.ops.prims.convert_element_type.default(add_45, torch.float32); add_45 = None + pow_24 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_384, 2) + mean_23 = torch.ops.aten.mean.dim(pow_24, [2], True); pow_24 = None + add_46 = torch.ops.aten.add.Scalar(mean_23, 1e-05); mean_23 = None + rsqrt_23 = torch.ops.aten.rsqrt.default(add_46); add_46 = None + mul_92 = torch.ops.aten.mul.Tensor(convert_element_type_384, rsqrt_23); convert_element_type_384 = None + mul_93 = torch.ops.aten.mul.Tensor(mul_92, wait_tensor_152) + convert_element_type_385 = torch.ops.prims.convert_element_type.default(mul_93, torch.bfloat16); mul_93 = None + all_gather_into_tensor_129 = 
torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_385, 8, '1'); convert_element_type_385 = None + wait_tensor_153 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_129); all_gather_into_tensor_129 = None + split_55 = torch.ops.aten.split.Tensor(wait_tensor_153, 2); wait_tensor_153 = None + getitem_548 = split_55[0] + getitem_549 = split_55[1] + getitem_550 = split_55[2] + getitem_551 = split_55[3] + getitem_552 = split_55[4] + getitem_553 = split_55[5] + getitem_554 = split_55[6] + getitem_555 = split_55[7]; split_55 = None + cat_47 = torch.ops.aten.cat.default([getitem_548, getitem_549, getitem_550, getitem_551, getitem_552, getitem_553, getitem_554, getitem_555], 1); getitem_548 = getitem_549 = getitem_550 = getitem_551 = getitem_552 = getitem_553 = getitem_554 = getitem_555 = None + view_852 = torch.ops.aten.view.default(cat_47, [16384, 4096]); cat_47 = None + view_853 = torch.ops.aten.view.default(mm_81, [2, 8192, 1792]); mm_81 = None + convert_element_type_389 = torch.ops.prims.convert_element_type.default(view_853, torch.float32); view_853 = None + sigmoid_11 = torch.ops.aten.sigmoid.default(convert_element_type_389) + mul_94 = torch.ops.aten.mul.Tensor(convert_element_type_389, sigmoid_11); sigmoid_11 = None + convert_element_type_390 = torch.ops.prims.convert_element_type.default(mul_94, torch.bfloat16); mul_94 = None + convert_element_type_391 = torch.ops.prims.convert_element_type.default(primals_110, torch.bfloat16); primals_110 = None + all_gather_into_tensor_131 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_391, 8, '0'); convert_element_type_391 = None + wait_tensor_155 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_131); all_gather_into_tensor_131 = None + permute_130 = torch.ops.aten.permute.default(wait_tensor_155, [1, 0]); wait_tensor_155 = None + mm_82 = torch.ops.aten.mm.default(view_852, permute_130) + view_860 = 
torch.ops.aten.view.default(mm_82, [2, 8192, 1792]); mm_82 = None + mul_95 = torch.ops.aten.mul.Tensor(convert_element_type_390, view_860) + view_867 = torch.ops.aten.view.default(mul_95, [16384, 1792]); mul_95 = None + mm_171 = torch.ops.aten.mm.default(permute_309, view_867); permute_309 = view_867 = None + convert_element_type_394 = torch.ops.prims.convert_element_type.default(primals_111, torch.bfloat16); primals_111 = None + all_gather_into_tensor_132 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_394, 8, '0'); convert_element_type_394 = None + wait_tensor_156 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_132); all_gather_into_tensor_132 = None + permute_131 = torch.ops.aten.permute.default(wait_tensor_156, [1, 0]); wait_tensor_156 = None + permute_311 = torch.ops.aten.permute.default(permute_131, [1, 0]); permute_131 = None + mm_172 = torch.ops.aten.mm.default(view_1267, permute_311); view_1267 = permute_311 = None + view_1268 = torch.ops.aten.view.default(mm_172, [2, 8192, 1792]); mm_172 = None + convert_element_type_766 = torch.ops.prims.convert_element_type.default(mm_171, torch.float32); mm_171 = None + reduce_scatter_tensor_80 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_766, 'avg', 8, '0'); convert_element_type_766 = None + wait_tensor_278 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_80); reduce_scatter_tensor_80 = None + mul_216 = torch.ops.aten.mul.Tensor(view_1268, convert_element_type_390); convert_element_type_390 = None + mul_217 = torch.ops.aten.mul.Tensor(view_1268, view_860); view_1268 = view_860 = None + view_1269 = torch.ops.aten.view.default(mul_216, [16384, 1792]); mul_216 = None + permute_313 = torch.ops.aten.permute.default(view_1269, [1, 0]) + mm_173 = torch.ops.aten.mm.default(permute_313, view_852); permute_313 = None + permute_315 = torch.ops.aten.permute.default(permute_130, [1, 0]); permute_130 = None + mm_174 = 
torch.ops.aten.mm.default(view_1269, permute_315); view_1269 = permute_315 = None + view_1270 = torch.ops.aten.view.default(mm_174, [2, 8192, 4096]); mm_174 = None + convert_element_type_771 = torch.ops.prims.convert_element_type.default(mm_173, torch.float32); mm_173 = None + reduce_scatter_tensor_81 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_771, 'avg', 8, '0'); convert_element_type_771 = None + wait_tensor_279 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_81); reduce_scatter_tensor_81 = None + convert_element_type_772 = torch.ops.prims.convert_element_type.default(mul_217, torch.float32); mul_217 = None + neg_4 = torch.ops.aten.neg.default(convert_element_type_389) + exp_4 = torch.ops.aten.exp.default(neg_4); neg_4 = None + add_93 = torch.ops.aten.add.Tensor(exp_4, 1); exp_4 = None + reciprocal_4 = torch.ops.aten.reciprocal.default(add_93); add_93 = None + mul_218 = torch.ops.aten.mul.Tensor(reciprocal_4, 1); reciprocal_4 = None + mul_219 = torch.ops.aten.mul.Tensor(convert_element_type_772, mul_218); convert_element_type_772 = None + sub_14 = torch.ops.aten.sub.Tensor(1, mul_218); mul_218 = None + mul_220 = torch.ops.aten.mul.Tensor(convert_element_type_389, sub_14); convert_element_type_389 = sub_14 = None + add_94 = torch.ops.aten.add.Tensor(mul_220, 1); mul_220 = None + mul_221 = torch.ops.aten.mul.Tensor(mul_219, add_94); mul_219 = add_94 = None + convert_element_type_774 = torch.ops.prims.convert_element_type.default(mul_221, torch.bfloat16); mul_221 = None + view_1271 = torch.ops.aten.view.default(convert_element_type_774, [16384, 1792]); convert_element_type_774 = None + permute_317 = torch.ops.aten.permute.default(view_1271, [1, 0]) + mm_175 = torch.ops.aten.mm.default(permute_317, view_852); permute_317 = view_852 = None + convert_element_type_386 = torch.ops.prims.convert_element_type.default(primals_109, torch.bfloat16); primals_109 = None + all_gather_into_tensor_130 = 
torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_386, 8, '0'); convert_element_type_386 = None + wait_tensor_154 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_130); all_gather_into_tensor_130 = None + permute_129 = torch.ops.aten.permute.default(wait_tensor_154, [1, 0]); wait_tensor_154 = None + permute_319 = torch.ops.aten.permute.default(permute_129, [1, 0]); permute_129 = None + mm_176 = torch.ops.aten.mm.default(view_1271, permute_319); view_1271 = permute_319 = None + view_1272 = torch.ops.aten.view.default(mm_176, [2, 8192, 4096]); mm_176 = None + add_95 = torch.ops.aten.add.Tensor(view_1270, view_1272); view_1270 = view_1272 = None + convert_element_type_779 = torch.ops.prims.convert_element_type.default(mm_175, torch.float32); mm_175 = None + reduce_scatter_tensor_82 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_779, 'avg', 8, '0'); convert_element_type_779 = None + wait_tensor_280 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_82); reduce_scatter_tensor_82 = None + split_92 = torch.ops.aten.split.Tensor(add_95, 1024, 1); add_95 = None + getitem_892 = split_92[0] + getitem_893 = split_92[1] + getitem_894 = split_92[2] + getitem_895 = split_92[3] + getitem_896 = split_92[4] + getitem_897 = split_92[5] + getitem_898 = split_92[6] + getitem_899 = split_92[7]; split_92 = None + cat_84 = torch.ops.aten.cat.default([getitem_892, getitem_893, getitem_894, getitem_895, getitem_896, getitem_897, getitem_898, getitem_899]); getitem_892 = getitem_893 = getitem_894 = getitem_895 = getitem_896 = getitem_897 = getitem_898 = getitem_899 = None + reduce_scatter_tensor_83 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_84, 'sum', 8, '1'); cat_84 = None + wait_tensor_281 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_83); reduce_scatter_tensor_83 = None + convert_element_type_780 = 
torch.ops.prims.convert_element_type.default(wait_tensor_281, torch.float32); wait_tensor_281 = None + convert_element_type_782 = torch.ops.prims.convert_element_type.default(wait_tensor_152, torch.float32); wait_tensor_152 = None + mul_222 = torch.ops.aten.mul.Tensor(convert_element_type_780, convert_element_type_782); convert_element_type_782 = None + mul_224 = torch.ops.aten.mul.Tensor(mul_92, mul_222) + sum_27 = torch.ops.aten.sum.dim_IntList(mul_224, [2], True); mul_224 = None + div_9 = torch.ops.aten.div.Tensor(mul_92, 4096) + mul_225 = torch.ops.aten.mul.Tensor(div_9, sum_27); div_9 = sum_27 = None + sub_15 = torch.ops.aten.sub.Tensor(mul_222, mul_225); mul_222 = mul_225 = None + mul_226 = torch.ops.aten.mul.Tensor(sub_15, rsqrt_23); sub_15 = rsqrt_23 = None + mul_227 = torch.ops.aten.mul.Tensor(convert_element_type_780, mul_92); convert_element_type_780 = mul_92 = None + sum_28 = torch.ops.aten.sum.dim_IntList(mul_227, [0, 1]); mul_227 = None + convert_element_type_783 = torch.ops.prims.convert_element_type.default(mul_226, torch.bfloat16); mul_226 = None + convert_element_type_784 = torch.ops.prims.convert_element_type.default(sum_28, torch.bfloat16); sum_28 = None + all_reduce_9 = torch.ops._c10d_functional.all_reduce.default(convert_element_type_784, 'sum', '1'); convert_element_type_784 = None + wait_tensor_282 = torch.ops._c10d_functional.wait_tensor.default(all_reduce_9); all_reduce_9 = None + convert_element_type_785 = torch.ops.prims.convert_element_type.default(wait_tensor_282, torch.float32); wait_tensor_282 = None + reduce_scatter_tensor_84 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_785, 'avg', 8, '0'); convert_element_type_785 = None + wait_tensor_283 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_84); reduce_scatter_tensor_84 = None + add_96 = torch.ops.aten.add.Tensor(add_92, convert_element_type_783); add_92 = convert_element_type_783 = None + all_gather_into_tensor_189 = 
torch.ops._c10d_functional.all_gather_into_tensor.default(add_96, 8, '1') + wait_tensor_284 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_189); all_gather_into_tensor_189 = None + split_93 = torch.ops.aten.split.Tensor(wait_tensor_284, 2); wait_tensor_284 = None + getitem_900 = split_93[0] + getitem_901 = split_93[1] + getitem_902 = split_93[2] + getitem_903 = split_93[3] + getitem_904 = split_93[4] + getitem_905 = split_93[5] + getitem_906 = split_93[6] + getitem_907 = split_93[7]; split_93 = None + cat_85 = torch.ops.aten.cat.default([getitem_900, getitem_901, getitem_902, getitem_903, getitem_904, getitem_905, getitem_906, getitem_907], 1); getitem_900 = getitem_901 = getitem_902 = getitem_903 = getitem_904 = getitem_905 = getitem_906 = getitem_907 = None + view_1273 = torch.ops.aten.view.default(cat_85, [16384, 4096]); cat_85 = None + permute_321 = torch.ops.aten.permute.default(view_1273, [1, 0]) + permute_127 = torch.ops.aten.permute.default(getitem_531, [0, 2, 1, 3]) + view_834 = torch.ops.aten.view.default(permute_127, [2, 8192, -1]); permute_127 = None + view_840 = torch.ops.aten.view.default(view_834, [16384, 512]); view_834 = None + mm_177 = torch.ops.aten.mm.default(permute_321, view_840); permute_321 = view_840 = None + convert_element_type_380 = torch.ops.prims.convert_element_type.default(primals_107, torch.bfloat16); primals_107 = None + all_gather_into_tensor_127 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_380, 8, '0'); convert_element_type_380 = None + wait_tensor_150 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_127); all_gather_into_tensor_127 = None + permute_128 = torch.ops.aten.permute.default(wait_tensor_150, [1, 0]); wait_tensor_150 = None + permute_323 = torch.ops.aten.permute.default(permute_128, [1, 0]); permute_128 = None + mm_178 = torch.ops.aten.mm.default(view_1273, permute_323); view_1273 = permute_323 = None + view_1274 = 
torch.ops.aten.view.default(mm_178, [2, 8192, 512]); mm_178 = None + convert_element_type_790 = torch.ops.prims.convert_element_type.default(mm_177, torch.float32); mm_177 = None + reduce_scatter_tensor_85 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_790, 'avg', 8, '0'); convert_element_type_790 = None + wait_tensor_285 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_85); reduce_scatter_tensor_85 = None + view_1275 = torch.ops.aten.view.default(view_1274, [2, 8192, 4, 128]); view_1274 = None + permute_325 = torch.ops.aten.permute.default(view_1275, [0, 2, 1, 3]); view_1275 = None + convert_element_type_364 = torch.ops.prims.convert_element_type.default(primals_103, torch.bfloat16); primals_103 = None + all_gather_into_tensor_122 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_364, 8, '0'); convert_element_type_364 = None + wait_tensor_145 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_122); all_gather_into_tensor_122 = None + convert_element_type_365 = torch.ops.prims.convert_element_type.default(add_43, torch.float32); add_43 = None + pow_23 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_365, 2) + mean_22 = torch.ops.aten.mean.dim(pow_23, [2], True); pow_23 = None + add_44 = torch.ops.aten.add.Scalar(mean_22, 1e-05); mean_22 = None + rsqrt_22 = torch.ops.aten.rsqrt.default(add_44); add_44 = None + mul_88 = torch.ops.aten.mul.Tensor(convert_element_type_365, rsqrt_22); convert_element_type_365 = None + mul_89 = torch.ops.aten.mul.Tensor(mul_88, wait_tensor_145) + convert_element_type_366 = torch.ops.prims.convert_element_type.default(mul_89, torch.bfloat16); mul_89 = None + all_gather_into_tensor_123 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_366, 8, '1'); convert_element_type_366 = None + wait_tensor_146 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_123); 
all_gather_into_tensor_123 = None + split_53 = torch.ops.aten.split.Tensor(wait_tensor_146, 2); wait_tensor_146 = None + getitem_523 = split_53[0] + getitem_524 = split_53[1] + getitem_525 = split_53[2] + getitem_526 = split_53[3] + getitem_527 = split_53[4] + getitem_528 = split_53[5] + getitem_529 = split_53[6] + getitem_530 = split_53[7]; split_53 = None + cat_45 = torch.ops.aten.cat.default([getitem_523, getitem_524, getitem_525, getitem_526, getitem_527, getitem_528, getitem_529, getitem_530], 1); getitem_523 = getitem_524 = getitem_525 = getitem_526 = getitem_527 = getitem_528 = getitem_529 = getitem_530 = None + view_807 = torch.ops.aten.view.default(cat_45, [16384, 4096]); cat_45 = None + view_808 = torch.ops.aten.view.default(mm_77, [2, 8192, 512]); mm_77 = None + convert_element_type_370 = torch.ops.prims.convert_element_type.default(primals_105, torch.bfloat16); primals_105 = None + all_gather_into_tensor_125 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_370, 8, '0'); convert_element_type_370 = None + wait_tensor_148 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_125); all_gather_into_tensor_125 = None + permute_122 = torch.ops.aten.permute.default(wait_tensor_148, [1, 0]); wait_tensor_148 = None + mm_78 = torch.ops.aten.mm.default(view_807, permute_122) + view_815 = torch.ops.aten.view.default(mm_78, [2, 8192, 128]); mm_78 = None + view_822 = torch.ops.aten.view.default(mm_79, [2, 8192, 128]); mm_79 = None + view_824 = torch.ops.aten.view.default(view_808, [2, 8192, -1, 128]); view_808 = None + view_825 = torch.ops.aten.view.default(view_815, [2, 8192, -1, 128]); view_815 = None + view_826 = torch.ops.aten.view.default(view_822, [2, 8192, -1, 128]); view_822 = None + convert_element_type_376 = torch.ops.prims.convert_element_type.default(view_824, torch.float32); view_824 = None + view_827 = torch.ops.aten.view.default(convert_element_type_376, [2, 8192, 4, -1, 2]); convert_element_type_376 = 
None + view_as_complex_22 = torch.ops.aten.view_as_complex.default(view_827); view_827 = None + convert_element_type_377 = torch.ops.prims.convert_element_type.default(view_825, torch.float32); view_825 = None + view_828 = torch.ops.aten.view.default(convert_element_type_377, [2, 8192, 1, -1, 2]); convert_element_type_377 = None + view_as_complex_23 = torch.ops.aten.view_as_complex.default(view_828); view_828 = None + mul_90 = torch.ops.aten.mul.Tensor(view_as_complex_22, view_37); view_as_complex_22 = None + view_as_real_22 = torch.ops.aten.view_as_real.default(mul_90); mul_90 = None + view_830 = torch.ops.aten.view.default(view_as_real_22, [2, 8192, 4, 128]); view_as_real_22 = None + mul_91 = torch.ops.aten.mul.Tensor(view_as_complex_23, view_37); view_as_complex_23 = None + view_as_real_23 = torch.ops.aten.view_as_real.default(mul_91); mul_91 = None + view_831 = torch.ops.aten.view.default(view_as_real_23, [2, 8192, 1, 128]); view_as_real_23 = None + convert_element_type_378 = torch.ops.prims.convert_element_type.default(view_830, torch.bfloat16); view_830 = None + convert_element_type_379 = torch.ops.prims.convert_element_type.default(view_831, torch.bfloat16); view_831 = None + unsqueeze_22 = torch.ops.aten.unsqueeze.default(convert_element_type_379, 3); convert_element_type_379 = None + expand_22 = torch.ops.aten.expand.default(unsqueeze_22, [2, 8192, 1, 4, 128]); unsqueeze_22 = None + view_832 = torch.ops.aten.view.default(expand_22, [2, 8192, 4, 128]); expand_22 = None + unsqueeze_23 = torch.ops.aten.unsqueeze.default(view_826, 3); view_826 = None + expand_23 = torch.ops.aten.expand.default(unsqueeze_23, [2, 8192, 1, 4, 128]); unsqueeze_23 = None + view_833 = torch.ops.aten.view.default(expand_23, [2, 8192, 4, 128]); expand_23 = None + permute_124 = torch.ops.aten.permute.default(convert_element_type_378, [0, 2, 1, 3]); convert_element_type_378 = None + permute_125 = torch.ops.aten.permute.default(view_832, [0, 2, 1, 3]); view_832 = None + permute_126 = 
torch.ops.aten.permute.default(view_833, [0, 2, 1, 3]); view_833 = None + _scaled_dot_product_cudnn_attention_backward_4 = torch.ops.aten._scaled_dot_product_cudnn_attention_backward.default(permute_325, permute_124, permute_125, permute_126, getitem_531, getitem_532, getitem_537, getitem_538, None, None, None, 8192, 8192, 0.0, True); permute_325 = permute_124 = permute_125 = permute_126 = getitem_531 = getitem_532 = getitem_537 = getitem_538 = None + getitem_908 = _scaled_dot_product_cudnn_attention_backward_4[0] + getitem_909 = _scaled_dot_product_cudnn_attention_backward_4[1] + getitem_910 = _scaled_dot_product_cudnn_attention_backward_4[2]; _scaled_dot_product_cudnn_attention_backward_4 = None + permute_326 = torch.ops.aten.permute.default(getitem_910, [0, 2, 1, 3]); getitem_910 = None + permute_327 = torch.ops.aten.permute.default(getitem_909, [0, 2, 1, 3]); getitem_909 = None + permute_328 = torch.ops.aten.permute.default(getitem_908, [0, 2, 1, 3]); getitem_908 = None + view_1276 = torch.ops.aten.view.default(permute_326, [2, 8192, 1, 4, 128]); permute_326 = None + sum_29 = torch.ops.aten.sum.dim_IntList(view_1276, [3], True); view_1276 = None + squeeze_8 = torch.ops.aten.squeeze.dim(sum_29, 3); sum_29 = None + view_1277 = torch.ops.aten.view.default(permute_327, [2, 8192, 1, 4, 128]); permute_327 = None + sum_30 = torch.ops.aten.sum.dim_IntList(view_1277, [3], True); view_1277 = None + squeeze_9 = torch.ops.aten.squeeze.dim(sum_30, 3); sum_30 = None + convert_element_type_791 = torch.ops.prims.convert_element_type.default(squeeze_9, torch.float32); squeeze_9 = None + convert_element_type_792 = torch.ops.prims.convert_element_type.default(permute_328, torch.float32); permute_328 = None + view_1278 = torch.ops.aten.view.default(convert_element_type_791, [2, 8192, 1, 64, 2]); convert_element_type_791 = None + view_as_complex_40 = torch.ops.aten.view_as_complex.default(view_1278); view_1278 = None + mul_228 = torch.ops.aten.mul.Tensor(view_as_complex_40, _conj); 
view_as_complex_40 = None + view_1279 = torch.ops.aten.view.default(convert_element_type_792, [2, 8192, 4, 64, 2]); convert_element_type_792 = None + view_as_complex_41 = torch.ops.aten.view_as_complex.default(view_1279); view_1279 = None + mul_229 = torch.ops.aten.mul.Tensor(view_as_complex_41, _conj); view_as_complex_41 = None + view_as_real_40 = torch.ops.aten.view_as_real.default(mul_228); mul_228 = None + view_1280 = torch.ops.aten.view.default(view_as_real_40, [2, 8192, 1, 128]); view_as_real_40 = None + convert_element_type_793 = torch.ops.prims.convert_element_type.default(view_1280, torch.bfloat16); view_1280 = None + view_as_real_41 = torch.ops.aten.view_as_real.default(mul_229); mul_229 = None + view_1281 = torch.ops.aten.view.default(view_as_real_41, [2, 8192, 4, 128]); view_as_real_41 = None + convert_element_type_794 = torch.ops.prims.convert_element_type.default(view_1281, torch.bfloat16); view_1281 = None + view_1282 = torch.ops.aten.view.default(squeeze_8, [2, 8192, 128]); squeeze_8 = None + view_1283 = torch.ops.aten.view.default(convert_element_type_793, [2, 8192, 128]); convert_element_type_793 = None + view_1284 = torch.ops.aten.view.default(convert_element_type_794, [2, 8192, 512]); convert_element_type_794 = None + view_1285 = torch.ops.aten.view.default(view_1282, [16384, 128]); view_1282 = None + permute_329 = torch.ops.aten.permute.default(view_1285, [1, 0]) + mm_179 = torch.ops.aten.mm.default(permute_329, view_807); permute_329 = None + convert_element_type_373 = torch.ops.prims.convert_element_type.default(primals_106, torch.bfloat16); primals_106 = None + all_gather_into_tensor_126 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_373, 8, '0'); convert_element_type_373 = None + wait_tensor_149 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_126); all_gather_into_tensor_126 = None + permute_123 = torch.ops.aten.permute.default(wait_tensor_149, [1, 0]); wait_tensor_149 = None + 
permute_331 = torch.ops.aten.permute.default(permute_123, [1, 0]); permute_123 = None + mm_180 = torch.ops.aten.mm.default(view_1285, permute_331); view_1285 = permute_331 = None + view_1286 = torch.ops.aten.view.default(mm_180, [2, 8192, 4096]); mm_180 = None + convert_element_type_799 = torch.ops.prims.convert_element_type.default(mm_179, torch.float32); mm_179 = None + reduce_scatter_tensor_86 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_799, 'avg', 8, '0'); convert_element_type_799 = None + wait_tensor_286 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_86); reduce_scatter_tensor_86 = None + view_1287 = torch.ops.aten.view.default(view_1283, [16384, 128]); view_1283 = None + permute_333 = torch.ops.aten.permute.default(view_1287, [1, 0]) + mm_181 = torch.ops.aten.mm.default(permute_333, view_807); permute_333 = None + permute_335 = torch.ops.aten.permute.default(permute_122, [1, 0]); permute_122 = None + mm_182 = torch.ops.aten.mm.default(view_1287, permute_335); view_1287 = permute_335 = None + view_1288 = torch.ops.aten.view.default(mm_182, [2, 8192, 4096]); mm_182 = None + add_97 = torch.ops.aten.add.Tensor(view_1286, view_1288); view_1286 = view_1288 = None + convert_element_type_804 = torch.ops.prims.convert_element_type.default(mm_181, torch.float32); mm_181 = None + reduce_scatter_tensor_87 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_804, 'avg', 8, '0'); convert_element_type_804 = None + wait_tensor_287 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_87); reduce_scatter_tensor_87 = None + view_1289 = torch.ops.aten.view.default(view_1284, [16384, 512]); view_1284 = None + permute_337 = torch.ops.aten.permute.default(view_1289, [1, 0]) + mm_183 = torch.ops.aten.mm.default(permute_337, view_807); permute_337 = view_807 = None + convert_element_type_367 = torch.ops.prims.convert_element_type.default(primals_104, torch.bfloat16); primals_104 = 
None + all_gather_into_tensor_124 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_367, 8, '0'); convert_element_type_367 = None + wait_tensor_147 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_124); all_gather_into_tensor_124 = None + permute_121 = torch.ops.aten.permute.default(wait_tensor_147, [1, 0]); wait_tensor_147 = None + permute_339 = torch.ops.aten.permute.default(permute_121, [1, 0]); permute_121 = None + mm_184 = torch.ops.aten.mm.default(view_1289, permute_339); view_1289 = permute_339 = None + view_1290 = torch.ops.aten.view.default(mm_184, [2, 8192, 4096]); mm_184 = None + add_98 = torch.ops.aten.add.Tensor(add_97, view_1290); add_97 = view_1290 = None + convert_element_type_809 = torch.ops.prims.convert_element_type.default(mm_183, torch.float32); mm_183 = None + reduce_scatter_tensor_88 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_809, 'avg', 8, '0'); convert_element_type_809 = None + wait_tensor_288 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_88); reduce_scatter_tensor_88 = None + split_94 = torch.ops.aten.split.Tensor(add_98, 1024, 1); add_98 = None + getitem_911 = split_94[0] + getitem_912 = split_94[1] + getitem_913 = split_94[2] + getitem_914 = split_94[3] + getitem_915 = split_94[4] + getitem_916 = split_94[5] + getitem_917 = split_94[6] + getitem_918 = split_94[7]; split_94 = None + cat_86 = torch.ops.aten.cat.default([getitem_911, getitem_912, getitem_913, getitem_914, getitem_915, getitem_916, getitem_917, getitem_918]); getitem_911 = getitem_912 = getitem_913 = getitem_914 = getitem_915 = getitem_916 = getitem_917 = getitem_918 = None + reduce_scatter_tensor_89 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_86, 'sum', 8, '1'); cat_86 = None + wait_tensor_289 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_89); reduce_scatter_tensor_89 = None + convert_element_type_810 = 
torch.ops.prims.convert_element_type.default(wait_tensor_289, torch.float32); wait_tensor_289 = None + convert_element_type_812 = torch.ops.prims.convert_element_type.default(wait_tensor_145, torch.float32); wait_tensor_145 = None + mul_230 = torch.ops.aten.mul.Tensor(convert_element_type_810, convert_element_type_812); convert_element_type_812 = None + mul_232 = torch.ops.aten.mul.Tensor(mul_88, mul_230) + sum_31 = torch.ops.aten.sum.dim_IntList(mul_232, [2], True); mul_232 = None + div_10 = torch.ops.aten.div.Tensor(mul_88, 4096) + mul_233 = torch.ops.aten.mul.Tensor(div_10, sum_31); div_10 = sum_31 = None + sub_16 = torch.ops.aten.sub.Tensor(mul_230, mul_233); mul_230 = mul_233 = None + mul_234 = torch.ops.aten.mul.Tensor(sub_16, rsqrt_22); sub_16 = rsqrt_22 = None + mul_235 = torch.ops.aten.mul.Tensor(convert_element_type_810, mul_88); convert_element_type_810 = mul_88 = None + sum_32 = torch.ops.aten.sum.dim_IntList(mul_235, [0, 1]); mul_235 = None + convert_element_type_813 = torch.ops.prims.convert_element_type.default(mul_234, torch.bfloat16); mul_234 = None + convert_element_type_814 = torch.ops.prims.convert_element_type.default(sum_32, torch.bfloat16); sum_32 = None + all_reduce_10 = torch.ops._c10d_functional.all_reduce.default(convert_element_type_814, 'sum', '1'); convert_element_type_814 = None + wait_tensor_290 = torch.ops._c10d_functional.wait_tensor.default(all_reduce_10); all_reduce_10 = None + convert_element_type_815 = torch.ops.prims.convert_element_type.default(wait_tensor_290, torch.float32); wait_tensor_290 = None + reduce_scatter_tensor_90 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_815, 'avg', 8, '0'); convert_element_type_815 = None + wait_tensor_291 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_90); reduce_scatter_tensor_90 = None + add_99 = torch.ops.aten.add.Tensor(add_96, convert_element_type_813); add_96 = convert_element_type_813 = None + all_gather_into_tensor_190 = 
torch.ops._c10d_functional.all_gather_into_tensor.default(add_99, 8, '1') + wait_tensor_292 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_190); all_gather_into_tensor_190 = None + split_95 = torch.ops.aten.split.Tensor(wait_tensor_292, 2); wait_tensor_292 = None + getitem_919 = split_95[0] + getitem_920 = split_95[1] + getitem_921 = split_95[2] + getitem_922 = split_95[3] + getitem_923 = split_95[4] + getitem_924 = split_95[5] + getitem_925 = split_95[6] + getitem_926 = split_95[7]; split_95 = None + cat_87 = torch.ops.aten.cat.default([getitem_919, getitem_920, getitem_921, getitem_922, getitem_923, getitem_924, getitem_925, getitem_926], 1); getitem_919 = getitem_920 = getitem_921 = getitem_922 = getitem_923 = getitem_924 = getitem_925 = getitem_926 = None + view_1291 = torch.ops.aten.view.default(cat_87, [16384, 4096]); cat_87 = None + permute_341 = torch.ops.aten.permute.default(view_1291, [1, 0]) + wait_tensor_138 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_21); reduce_scatter_tensor_21 = None + add_41 = torch.ops.aten.add.Tensor(add_39, wait_tensor_138); wait_tensor_138 = None + convert_element_type_350 = torch.ops.prims.convert_element_type.default(primals_99, torch.bfloat16); primals_99 = None + all_gather_into_tensor_117 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_350, 8, '0'); convert_element_type_350 = None + wait_tensor_139 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_117); all_gather_into_tensor_117 = None + convert_element_type_351 = torch.ops.prims.convert_element_type.default(add_41, torch.float32); add_41 = None + pow_22 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_351, 2) + mean_21 = torch.ops.aten.mean.dim(pow_22, [2], True); pow_22 = None + add_42 = torch.ops.aten.add.Scalar(mean_21, 1e-05); mean_21 = None + rsqrt_21 = torch.ops.aten.rsqrt.default(add_42); add_42 = None + mul_84 = 
torch.ops.aten.mul.Tensor(convert_element_type_351, rsqrt_21); convert_element_type_351 = None + mul_85 = torch.ops.aten.mul.Tensor(mul_84, wait_tensor_139) + convert_element_type_352 = torch.ops.prims.convert_element_type.default(mul_85, torch.bfloat16); mul_85 = None + all_gather_into_tensor_118 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_352, 8, '1'); convert_element_type_352 = None + wait_tensor_140 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_118); all_gather_into_tensor_118 = None + split_51 = torch.ops.aten.split.Tensor(wait_tensor_140, 2); wait_tensor_140 = None + getitem_507 = split_51[0] + getitem_508 = split_51[1] + getitem_509 = split_51[2] + getitem_510 = split_51[3] + getitem_511 = split_51[4] + getitem_512 = split_51[5] + getitem_513 = split_51[6] + getitem_514 = split_51[7]; split_51 = None + cat_43 = torch.ops.aten.cat.default([getitem_507, getitem_508, getitem_509, getitem_510, getitem_511, getitem_512, getitem_513, getitem_514], 1); getitem_507 = getitem_508 = getitem_509 = getitem_510 = getitem_511 = getitem_512 = getitem_513 = getitem_514 = None + view_780 = torch.ops.aten.view.default(cat_43, [16384, 4096]); cat_43 = None + view_781 = torch.ops.aten.view.default(mm_74, [2, 8192, 1792]); mm_74 = None + convert_element_type_356 = torch.ops.prims.convert_element_type.default(view_781, torch.float32); view_781 = None + sigmoid_10 = torch.ops.aten.sigmoid.default(convert_element_type_356) + mul_86 = torch.ops.aten.mul.Tensor(convert_element_type_356, sigmoid_10); sigmoid_10 = None + convert_element_type_357 = torch.ops.prims.convert_element_type.default(mul_86, torch.bfloat16); mul_86 = None + convert_element_type_358 = torch.ops.prims.convert_element_type.default(primals_101, torch.bfloat16); primals_101 = None + all_gather_into_tensor_120 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_358, 8, '0'); convert_element_type_358 = None + wait_tensor_142 = 
torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_120); all_gather_into_tensor_120 = None + permute_119 = torch.ops.aten.permute.default(wait_tensor_142, [1, 0]); wait_tensor_142 = None + mm_75 = torch.ops.aten.mm.default(view_780, permute_119) + view_788 = torch.ops.aten.view.default(mm_75, [2, 8192, 1792]); mm_75 = None + mul_87 = torch.ops.aten.mul.Tensor(convert_element_type_357, view_788) + view_795 = torch.ops.aten.view.default(mul_87, [16384, 1792]); mul_87 = None + mm_185 = torch.ops.aten.mm.default(permute_341, view_795); permute_341 = view_795 = None + convert_element_type_361 = torch.ops.prims.convert_element_type.default(primals_102, torch.bfloat16); primals_102 = None + all_gather_into_tensor_121 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_361, 8, '0'); convert_element_type_361 = None + wait_tensor_143 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_121); all_gather_into_tensor_121 = None + permute_120 = torch.ops.aten.permute.default(wait_tensor_143, [1, 0]); wait_tensor_143 = None + permute_343 = torch.ops.aten.permute.default(permute_120, [1, 0]); permute_120 = None + mm_186 = torch.ops.aten.mm.default(view_1291, permute_343); view_1291 = permute_343 = None + view_1292 = torch.ops.aten.view.default(mm_186, [2, 8192, 1792]); mm_186 = None + convert_element_type_820 = torch.ops.prims.convert_element_type.default(mm_185, torch.float32); mm_185 = None + reduce_scatter_tensor_91 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_820, 'avg', 8, '0'); convert_element_type_820 = None + wait_tensor_293 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_91); reduce_scatter_tensor_91 = None + mul_236 = torch.ops.aten.mul.Tensor(view_1292, convert_element_type_357); convert_element_type_357 = None + mul_237 = torch.ops.aten.mul.Tensor(view_1292, view_788); view_1292 = view_788 = None + view_1293 = torch.ops.aten.view.default(mul_236, 
[16384, 1792]); mul_236 = None + permute_345 = torch.ops.aten.permute.default(view_1293, [1, 0]) + mm_187 = torch.ops.aten.mm.default(permute_345, view_780); permute_345 = None + permute_347 = torch.ops.aten.permute.default(permute_119, [1, 0]); permute_119 = None + mm_188 = torch.ops.aten.mm.default(view_1293, permute_347); view_1293 = permute_347 = None + view_1294 = torch.ops.aten.view.default(mm_188, [2, 8192, 4096]); mm_188 = None + convert_element_type_825 = torch.ops.prims.convert_element_type.default(mm_187, torch.float32); mm_187 = None + reduce_scatter_tensor_92 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_825, 'avg', 8, '0'); convert_element_type_825 = None + wait_tensor_294 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_92); reduce_scatter_tensor_92 = None + convert_element_type_826 = torch.ops.prims.convert_element_type.default(mul_237, torch.float32); mul_237 = None + neg_5 = torch.ops.aten.neg.default(convert_element_type_356) + exp_5 = torch.ops.aten.exp.default(neg_5); neg_5 = None + add_100 = torch.ops.aten.add.Tensor(exp_5, 1); exp_5 = None + reciprocal_5 = torch.ops.aten.reciprocal.default(add_100); add_100 = None + mul_238 = torch.ops.aten.mul.Tensor(reciprocal_5, 1); reciprocal_5 = None + mul_239 = torch.ops.aten.mul.Tensor(convert_element_type_826, mul_238); convert_element_type_826 = None + sub_17 = torch.ops.aten.sub.Tensor(1, mul_238); mul_238 = None + mul_240 = torch.ops.aten.mul.Tensor(convert_element_type_356, sub_17); convert_element_type_356 = sub_17 = None + add_101 = torch.ops.aten.add.Tensor(mul_240, 1); mul_240 = None + mul_241 = torch.ops.aten.mul.Tensor(mul_239, add_101); mul_239 = add_101 = None + convert_element_type_828 = torch.ops.prims.convert_element_type.default(mul_241, torch.bfloat16); mul_241 = None + view_1295 = torch.ops.aten.view.default(convert_element_type_828, [16384, 1792]); convert_element_type_828 = None + permute_349 = 
torch.ops.aten.permute.default(view_1295, [1, 0]) + mm_189 = torch.ops.aten.mm.default(permute_349, view_780); permute_349 = view_780 = None + convert_element_type_353 = torch.ops.prims.convert_element_type.default(primals_100, torch.bfloat16); primals_100 = None + all_gather_into_tensor_119 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_353, 8, '0'); convert_element_type_353 = None + wait_tensor_141 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_119); all_gather_into_tensor_119 = None + permute_118 = torch.ops.aten.permute.default(wait_tensor_141, [1, 0]); wait_tensor_141 = None + permute_351 = torch.ops.aten.permute.default(permute_118, [1, 0]); permute_118 = None + mm_190 = torch.ops.aten.mm.default(view_1295, permute_351); view_1295 = permute_351 = None + view_1296 = torch.ops.aten.view.default(mm_190, [2, 8192, 4096]); mm_190 = None + add_102 = torch.ops.aten.add.Tensor(view_1294, view_1296); view_1294 = view_1296 = None + convert_element_type_833 = torch.ops.prims.convert_element_type.default(mm_189, torch.float32); mm_189 = None + reduce_scatter_tensor_93 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_833, 'avg', 8, '0'); convert_element_type_833 = None + wait_tensor_295 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_93); reduce_scatter_tensor_93 = None + split_96 = torch.ops.aten.split.Tensor(add_102, 1024, 1); add_102 = None + getitem_927 = split_96[0] + getitem_928 = split_96[1] + getitem_929 = split_96[2] + getitem_930 = split_96[3] + getitem_931 = split_96[4] + getitem_932 = split_96[5] + getitem_933 = split_96[6] + getitem_934 = split_96[7]; split_96 = None + cat_88 = torch.ops.aten.cat.default([getitem_927, getitem_928, getitem_929, getitem_930, getitem_931, getitem_932, getitem_933, getitem_934]); getitem_927 = getitem_928 = getitem_929 = getitem_930 = getitem_931 = getitem_932 = getitem_933 = getitem_934 = None + 
reduce_scatter_tensor_94 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_88, 'sum', 8, '1'); cat_88 = None + wait_tensor_296 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_94); reduce_scatter_tensor_94 = None + convert_element_type_834 = torch.ops.prims.convert_element_type.default(wait_tensor_296, torch.float32); wait_tensor_296 = None + convert_element_type_836 = torch.ops.prims.convert_element_type.default(wait_tensor_139, torch.float32); wait_tensor_139 = None + mul_242 = torch.ops.aten.mul.Tensor(convert_element_type_834, convert_element_type_836); convert_element_type_836 = None + mul_244 = torch.ops.aten.mul.Tensor(mul_84, mul_242) + sum_33 = torch.ops.aten.sum.dim_IntList(mul_244, [2], True); mul_244 = None + div_11 = torch.ops.aten.div.Tensor(mul_84, 4096) + mul_245 = torch.ops.aten.mul.Tensor(div_11, sum_33); div_11 = sum_33 = None + sub_18 = torch.ops.aten.sub.Tensor(mul_242, mul_245); mul_242 = mul_245 = None + mul_246 = torch.ops.aten.mul.Tensor(sub_18, rsqrt_21); sub_18 = rsqrt_21 = None + mul_247 = torch.ops.aten.mul.Tensor(convert_element_type_834, mul_84); convert_element_type_834 = mul_84 = None + sum_34 = torch.ops.aten.sum.dim_IntList(mul_247, [0, 1]); mul_247 = None + convert_element_type_837 = torch.ops.prims.convert_element_type.default(mul_246, torch.bfloat16); mul_246 = None + convert_element_type_838 = torch.ops.prims.convert_element_type.default(sum_34, torch.bfloat16); sum_34 = None + all_reduce_11 = torch.ops._c10d_functional.all_reduce.default(convert_element_type_838, 'sum', '1'); convert_element_type_838 = None + wait_tensor_297 = torch.ops._c10d_functional.wait_tensor.default(all_reduce_11); all_reduce_11 = None + convert_element_type_839 = torch.ops.prims.convert_element_type.default(wait_tensor_297, torch.float32); wait_tensor_297 = None + reduce_scatter_tensor_95 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_839, 'avg', 8, '0'); convert_element_type_839 = 
None + wait_tensor_298 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_95); reduce_scatter_tensor_95 = None + add_103 = torch.ops.aten.add.Tensor(add_99, convert_element_type_837); add_99 = convert_element_type_837 = None + all_gather_into_tensor_191 = torch.ops._c10d_functional.all_gather_into_tensor.default(add_103, 8, '1') + wait_tensor_299 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_191); all_gather_into_tensor_191 = None + split_97 = torch.ops.aten.split.Tensor(wait_tensor_299, 2); wait_tensor_299 = None + getitem_935 = split_97[0] + getitem_936 = split_97[1] + getitem_937 = split_97[2] + getitem_938 = split_97[3] + getitem_939 = split_97[4] + getitem_940 = split_97[5] + getitem_941 = split_97[6] + getitem_942 = split_97[7]; split_97 = None + cat_89 = torch.ops.aten.cat.default([getitem_935, getitem_936, getitem_937, getitem_938, getitem_939, getitem_940, getitem_941, getitem_942], 1); getitem_935 = getitem_936 = getitem_937 = getitem_938 = getitem_939 = getitem_940 = getitem_941 = getitem_942 = None + view_1297 = torch.ops.aten.view.default(cat_89, [16384, 4096]); cat_89 = None + permute_353 = torch.ops.aten.permute.default(view_1297, [1, 0]) + permute_116 = torch.ops.aten.permute.default(getitem_490, [0, 2, 1, 3]) + view_762 = torch.ops.aten.view.default(permute_116, [2, 8192, -1]); permute_116 = None + view_768 = torch.ops.aten.view.default(view_762, [16384, 512]); view_762 = None + mm_191 = torch.ops.aten.mm.default(permute_353, view_768); permute_353 = view_768 = None + convert_element_type_347 = torch.ops.prims.convert_element_type.default(primals_98, torch.bfloat16); primals_98 = None + all_gather_into_tensor_116 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_347, 8, '0'); convert_element_type_347 = None + wait_tensor_137 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_116); all_gather_into_tensor_116 = None + permute_117 = 
torch.ops.aten.permute.default(wait_tensor_137, [1, 0]); wait_tensor_137 = None + permute_355 = torch.ops.aten.permute.default(permute_117, [1, 0]); permute_117 = None + mm_192 = torch.ops.aten.mm.default(view_1297, permute_355); view_1297 = permute_355 = None + view_1298 = torch.ops.aten.view.default(mm_192, [2, 8192, 512]); mm_192 = None + convert_element_type_844 = torch.ops.prims.convert_element_type.default(mm_191, torch.float32); mm_191 = None + reduce_scatter_tensor_96 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_844, 'avg', 8, '0'); convert_element_type_844 = None + wait_tensor_300 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_96); reduce_scatter_tensor_96 = None + view_1299 = torch.ops.aten.view.default(view_1298, [2, 8192, 4, 128]); view_1298 = None + permute_357 = torch.ops.aten.permute.default(view_1299, [0, 2, 1, 3]); view_1299 = None + convert_element_type_331 = torch.ops.prims.convert_element_type.default(primals_94, torch.bfloat16); primals_94 = None + all_gather_into_tensor_111 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_331, 8, '0'); convert_element_type_331 = None + wait_tensor_132 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_111); all_gather_into_tensor_111 = None + convert_element_type_332 = torch.ops.prims.convert_element_type.default(add_39, torch.float32); add_39 = None + pow_21 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_332, 2) + mean_20 = torch.ops.aten.mean.dim(pow_21, [2], True); pow_21 = None + add_40 = torch.ops.aten.add.Scalar(mean_20, 1e-05); mean_20 = None + rsqrt_20 = torch.ops.aten.rsqrt.default(add_40); add_40 = None + mul_80 = torch.ops.aten.mul.Tensor(convert_element_type_332, rsqrt_20); convert_element_type_332 = None + mul_81 = torch.ops.aten.mul.Tensor(mul_80, wait_tensor_132) + convert_element_type_333 = torch.ops.prims.convert_element_type.default(mul_81, torch.bfloat16); mul_81 = None + 
all_gather_into_tensor_112 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_333, 8, '1'); convert_element_type_333 = None + wait_tensor_133 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_112); all_gather_into_tensor_112 = None + split_49 = torch.ops.aten.split.Tensor(wait_tensor_133, 2); wait_tensor_133 = None + getitem_482 = split_49[0] + getitem_483 = split_49[1] + getitem_484 = split_49[2] + getitem_485 = split_49[3] + getitem_486 = split_49[4] + getitem_487 = split_49[5] + getitem_488 = split_49[6] + getitem_489 = split_49[7]; split_49 = None + cat_41 = torch.ops.aten.cat.default([getitem_482, getitem_483, getitem_484, getitem_485, getitem_486, getitem_487, getitem_488, getitem_489], 1); getitem_482 = getitem_483 = getitem_484 = getitem_485 = getitem_486 = getitem_487 = getitem_488 = getitem_489 = None + view_735 = torch.ops.aten.view.default(cat_41, [16384, 4096]); cat_41 = None + view_736 = torch.ops.aten.view.default(mm_70, [2, 8192, 512]); mm_70 = None + convert_element_type_337 = torch.ops.prims.convert_element_type.default(primals_96, torch.bfloat16); primals_96 = None + all_gather_into_tensor_114 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_337, 8, '0'); convert_element_type_337 = None + wait_tensor_135 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_114); all_gather_into_tensor_114 = None + permute_111 = torch.ops.aten.permute.default(wait_tensor_135, [1, 0]); wait_tensor_135 = None + mm_71 = torch.ops.aten.mm.default(view_735, permute_111) + view_743 = torch.ops.aten.view.default(mm_71, [2, 8192, 128]); mm_71 = None + view_750 = torch.ops.aten.view.default(mm_72, [2, 8192, 128]); mm_72 = None + view_752 = torch.ops.aten.view.default(view_736, [2, 8192, -1, 128]); view_736 = None + view_753 = torch.ops.aten.view.default(view_743, [2, 8192, -1, 128]); view_743 = None + view_754 = torch.ops.aten.view.default(view_750, [2, 8192, -1, 128]); 
view_750 = None + convert_element_type_343 = torch.ops.prims.convert_element_type.default(view_752, torch.float32); view_752 = None + view_755 = torch.ops.aten.view.default(convert_element_type_343, [2, 8192, 4, -1, 2]); convert_element_type_343 = None + view_as_complex_20 = torch.ops.aten.view_as_complex.default(view_755); view_755 = None + convert_element_type_344 = torch.ops.prims.convert_element_type.default(view_753, torch.float32); view_753 = None + view_756 = torch.ops.aten.view.default(convert_element_type_344, [2, 8192, 1, -1, 2]); convert_element_type_344 = None + view_as_complex_21 = torch.ops.aten.view_as_complex.default(view_756); view_756 = None + mul_82 = torch.ops.aten.mul.Tensor(view_as_complex_20, view_37); view_as_complex_20 = None + view_as_real_20 = torch.ops.aten.view_as_real.default(mul_82); mul_82 = None + view_758 = torch.ops.aten.view.default(view_as_real_20, [2, 8192, 4, 128]); view_as_real_20 = None + mul_83 = torch.ops.aten.mul.Tensor(view_as_complex_21, view_37); view_as_complex_21 = None + view_as_real_21 = torch.ops.aten.view_as_real.default(mul_83); mul_83 = None + view_759 = torch.ops.aten.view.default(view_as_real_21, [2, 8192, 1, 128]); view_as_real_21 = None + convert_element_type_345 = torch.ops.prims.convert_element_type.default(view_758, torch.bfloat16); view_758 = None + convert_element_type_346 = torch.ops.prims.convert_element_type.default(view_759, torch.bfloat16); view_759 = None + unsqueeze_20 = torch.ops.aten.unsqueeze.default(convert_element_type_346, 3); convert_element_type_346 = None + expand_20 = torch.ops.aten.expand.default(unsqueeze_20, [2, 8192, 1, 4, 128]); unsqueeze_20 = None + view_760 = torch.ops.aten.view.default(expand_20, [2, 8192, 4, 128]); expand_20 = None + unsqueeze_21 = torch.ops.aten.unsqueeze.default(view_754, 3); view_754 = None + expand_21 = torch.ops.aten.expand.default(unsqueeze_21, [2, 8192, 1, 4, 128]); unsqueeze_21 = None + view_761 = torch.ops.aten.view.default(expand_21, [2, 8192, 4, 
128]); expand_21 = None + permute_113 = torch.ops.aten.permute.default(convert_element_type_345, [0, 2, 1, 3]); convert_element_type_345 = None + permute_114 = torch.ops.aten.permute.default(view_760, [0, 2, 1, 3]); view_760 = None + permute_115 = torch.ops.aten.permute.default(view_761, [0, 2, 1, 3]); view_761 = None + _scaled_dot_product_cudnn_attention_backward_5 = torch.ops.aten._scaled_dot_product_cudnn_attention_backward.default(permute_357, permute_113, permute_114, permute_115, getitem_490, getitem_491, getitem_496, getitem_497, None, None, None, 8192, 8192, 0.0, True); permute_357 = permute_113 = permute_114 = permute_115 = getitem_490 = getitem_491 = getitem_496 = getitem_497 = None + getitem_943 = _scaled_dot_product_cudnn_attention_backward_5[0] + getitem_944 = _scaled_dot_product_cudnn_attention_backward_5[1] + getitem_945 = _scaled_dot_product_cudnn_attention_backward_5[2]; _scaled_dot_product_cudnn_attention_backward_5 = None + permute_358 = torch.ops.aten.permute.default(getitem_945, [0, 2, 1, 3]); getitem_945 = None + permute_359 = torch.ops.aten.permute.default(getitem_944, [0, 2, 1, 3]); getitem_944 = None + permute_360 = torch.ops.aten.permute.default(getitem_943, [0, 2, 1, 3]); getitem_943 = None + view_1300 = torch.ops.aten.view.default(permute_358, [2, 8192, 1, 4, 128]); permute_358 = None + sum_35 = torch.ops.aten.sum.dim_IntList(view_1300, [3], True); view_1300 = None + squeeze_10 = torch.ops.aten.squeeze.dim(sum_35, 3); sum_35 = None + view_1301 = torch.ops.aten.view.default(permute_359, [2, 8192, 1, 4, 128]); permute_359 = None + sum_36 = torch.ops.aten.sum.dim_IntList(view_1301, [3], True); view_1301 = None + squeeze_11 = torch.ops.aten.squeeze.dim(sum_36, 3); sum_36 = None + convert_element_type_845 = torch.ops.prims.convert_element_type.default(squeeze_11, torch.float32); squeeze_11 = None + convert_element_type_846 = torch.ops.prims.convert_element_type.default(permute_360, torch.float32); permute_360 = None + view_1302 = 
torch.ops.aten.view.default(convert_element_type_845, [2, 8192, 1, 64, 2]); convert_element_type_845 = None + view_as_complex_42 = torch.ops.aten.view_as_complex.default(view_1302); view_1302 = None + mul_248 = torch.ops.aten.mul.Tensor(view_as_complex_42, _conj); view_as_complex_42 = None + view_1303 = torch.ops.aten.view.default(convert_element_type_846, [2, 8192, 4, 64, 2]); convert_element_type_846 = None + view_as_complex_43 = torch.ops.aten.view_as_complex.default(view_1303); view_1303 = None + mul_249 = torch.ops.aten.mul.Tensor(view_as_complex_43, _conj); view_as_complex_43 = None + view_as_real_42 = torch.ops.aten.view_as_real.default(mul_248); mul_248 = None + view_1304 = torch.ops.aten.view.default(view_as_real_42, [2, 8192, 1, 128]); view_as_real_42 = None + convert_element_type_847 = torch.ops.prims.convert_element_type.default(view_1304, torch.bfloat16); view_1304 = None + view_as_real_43 = torch.ops.aten.view_as_real.default(mul_249); mul_249 = None + view_1305 = torch.ops.aten.view.default(view_as_real_43, [2, 8192, 4, 128]); view_as_real_43 = None + convert_element_type_848 = torch.ops.prims.convert_element_type.default(view_1305, torch.bfloat16); view_1305 = None + view_1306 = torch.ops.aten.view.default(squeeze_10, [2, 8192, 128]); squeeze_10 = None + view_1307 = torch.ops.aten.view.default(convert_element_type_847, [2, 8192, 128]); convert_element_type_847 = None + view_1308 = torch.ops.aten.view.default(convert_element_type_848, [2, 8192, 512]); convert_element_type_848 = None + view_1309 = torch.ops.aten.view.default(view_1306, [16384, 128]); view_1306 = None + permute_361 = torch.ops.aten.permute.default(view_1309, [1, 0]) + mm_193 = torch.ops.aten.mm.default(permute_361, view_735); permute_361 = None + convert_element_type_340 = torch.ops.prims.convert_element_type.default(primals_97, torch.bfloat16); primals_97 = None + all_gather_into_tensor_115 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_340, 8, '0'); 
convert_element_type_340 = None + wait_tensor_136 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_115); all_gather_into_tensor_115 = None + permute_112 = torch.ops.aten.permute.default(wait_tensor_136, [1, 0]); wait_tensor_136 = None + permute_363 = torch.ops.aten.permute.default(permute_112, [1, 0]); permute_112 = None + mm_194 = torch.ops.aten.mm.default(view_1309, permute_363); view_1309 = permute_363 = None + view_1310 = torch.ops.aten.view.default(mm_194, [2, 8192, 4096]); mm_194 = None + convert_element_type_853 = torch.ops.prims.convert_element_type.default(mm_193, torch.float32); mm_193 = None + reduce_scatter_tensor_97 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_853, 'avg', 8, '0'); convert_element_type_853 = None + wait_tensor_301 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_97); reduce_scatter_tensor_97 = None + view_1311 = torch.ops.aten.view.default(view_1307, [16384, 128]); view_1307 = None + permute_365 = torch.ops.aten.permute.default(view_1311, [1, 0]) + mm_195 = torch.ops.aten.mm.default(permute_365, view_735); permute_365 = None + permute_367 = torch.ops.aten.permute.default(permute_111, [1, 0]); permute_111 = None + mm_196 = torch.ops.aten.mm.default(view_1311, permute_367); view_1311 = permute_367 = None + view_1312 = torch.ops.aten.view.default(mm_196, [2, 8192, 4096]); mm_196 = None + add_104 = torch.ops.aten.add.Tensor(view_1310, view_1312); view_1310 = view_1312 = None + convert_element_type_858 = torch.ops.prims.convert_element_type.default(mm_195, torch.float32); mm_195 = None + reduce_scatter_tensor_98 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_858, 'avg', 8, '0'); convert_element_type_858 = None + wait_tensor_302 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_98); reduce_scatter_tensor_98 = None + view_1313 = torch.ops.aten.view.default(view_1308, [16384, 512]); view_1308 = None + permute_369 
= torch.ops.aten.permute.default(view_1313, [1, 0]) + mm_197 = torch.ops.aten.mm.default(permute_369, view_735); permute_369 = view_735 = None + convert_element_type_334 = torch.ops.prims.convert_element_type.default(primals_95, torch.bfloat16); primals_95 = None + all_gather_into_tensor_113 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_334, 8, '0'); convert_element_type_334 = None + wait_tensor_134 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_113); all_gather_into_tensor_113 = None + permute_110 = torch.ops.aten.permute.default(wait_tensor_134, [1, 0]); wait_tensor_134 = None + permute_371 = torch.ops.aten.permute.default(permute_110, [1, 0]); permute_110 = None + mm_198 = torch.ops.aten.mm.default(view_1313, permute_371); view_1313 = permute_371 = None + view_1314 = torch.ops.aten.view.default(mm_198, [2, 8192, 4096]); mm_198 = None + add_105 = torch.ops.aten.add.Tensor(add_104, view_1314); add_104 = view_1314 = None + convert_element_type_863 = torch.ops.prims.convert_element_type.default(mm_197, torch.float32); mm_197 = None + reduce_scatter_tensor_99 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_863, 'avg', 8, '0'); convert_element_type_863 = None + wait_tensor_303 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_99); reduce_scatter_tensor_99 = None + split_98 = torch.ops.aten.split.Tensor(add_105, 1024, 1); add_105 = None + getitem_946 = split_98[0] + getitem_947 = split_98[1] + getitem_948 = split_98[2] + getitem_949 = split_98[3] + getitem_950 = split_98[4] + getitem_951 = split_98[5] + getitem_952 = split_98[6] + getitem_953 = split_98[7]; split_98 = None + cat_90 = torch.ops.aten.cat.default([getitem_946, getitem_947, getitem_948, getitem_949, getitem_950, getitem_951, getitem_952, getitem_953]); getitem_946 = getitem_947 = getitem_948 = getitem_949 = getitem_950 = getitem_951 = getitem_952 = getitem_953 = None + reduce_scatter_tensor_100 
= torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_90, 'sum', 8, '1'); cat_90 = None + wait_tensor_304 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_100); reduce_scatter_tensor_100 = None + convert_element_type_864 = torch.ops.prims.convert_element_type.default(wait_tensor_304, torch.float32); wait_tensor_304 = None + convert_element_type_866 = torch.ops.prims.convert_element_type.default(wait_tensor_132, torch.float32); wait_tensor_132 = None + mul_250 = torch.ops.aten.mul.Tensor(convert_element_type_864, convert_element_type_866); convert_element_type_866 = None + mul_252 = torch.ops.aten.mul.Tensor(mul_80, mul_250) + sum_37 = torch.ops.aten.sum.dim_IntList(mul_252, [2], True); mul_252 = None + div_12 = torch.ops.aten.div.Tensor(mul_80, 4096) + mul_253 = torch.ops.aten.mul.Tensor(div_12, sum_37); div_12 = sum_37 = None + sub_19 = torch.ops.aten.sub.Tensor(mul_250, mul_253); mul_250 = mul_253 = None + mul_254 = torch.ops.aten.mul.Tensor(sub_19, rsqrt_20); sub_19 = rsqrt_20 = None + mul_255 = torch.ops.aten.mul.Tensor(convert_element_type_864, mul_80); convert_element_type_864 = mul_80 = None + sum_38 = torch.ops.aten.sum.dim_IntList(mul_255, [0, 1]); mul_255 = None + convert_element_type_867 = torch.ops.prims.convert_element_type.default(mul_254, torch.bfloat16); mul_254 = None + convert_element_type_868 = torch.ops.prims.convert_element_type.default(sum_38, torch.bfloat16); sum_38 = None + all_reduce_12 = torch.ops._c10d_functional.all_reduce.default(convert_element_type_868, 'sum', '1'); convert_element_type_868 = None + wait_tensor_305 = torch.ops._c10d_functional.wait_tensor.default(all_reduce_12); all_reduce_12 = None + convert_element_type_869 = torch.ops.prims.convert_element_type.default(wait_tensor_305, torch.float32); wait_tensor_305 = None + reduce_scatter_tensor_101 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_869, 'avg', 8, '0'); convert_element_type_869 = None + wait_tensor_306 = 
torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_101); reduce_scatter_tensor_101 = None + add_106 = torch.ops.aten.add.Tensor(add_103, convert_element_type_867); add_103 = convert_element_type_867 = None + all_gather_into_tensor_192 = torch.ops._c10d_functional.all_gather_into_tensor.default(add_106, 8, '1') + wait_tensor_307 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_192); all_gather_into_tensor_192 = None + split_99 = torch.ops.aten.split.Tensor(wait_tensor_307, 2); wait_tensor_307 = None + getitem_954 = split_99[0] + getitem_955 = split_99[1] + getitem_956 = split_99[2] + getitem_957 = split_99[3] + getitem_958 = split_99[4] + getitem_959 = split_99[5] + getitem_960 = split_99[6] + getitem_961 = split_99[7]; split_99 = None + cat_91 = torch.ops.aten.cat.default([getitem_954, getitem_955, getitem_956, getitem_957, getitem_958, getitem_959, getitem_960, getitem_961], 1); getitem_954 = getitem_955 = getitem_956 = getitem_957 = getitem_958 = getitem_959 = getitem_960 = getitem_961 = None + view_1315 = torch.ops.aten.view.default(cat_91, [16384, 4096]); cat_91 = None + permute_373 = torch.ops.aten.permute.default(view_1315, [1, 0]) + wait_tensor_125 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_19); reduce_scatter_tensor_19 = None + add_37 = torch.ops.aten.add.Tensor(add_35, wait_tensor_125); wait_tensor_125 = None + convert_element_type_317 = torch.ops.prims.convert_element_type.default(primals_90, torch.bfloat16); primals_90 = None + all_gather_into_tensor_106 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_317, 8, '0'); convert_element_type_317 = None + wait_tensor_126 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_106); all_gather_into_tensor_106 = None + convert_element_type_318 = torch.ops.prims.convert_element_type.default(add_37, torch.float32); add_37 = None + pow_20 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_318, 
2) + mean_19 = torch.ops.aten.mean.dim(pow_20, [2], True); pow_20 = None + add_38 = torch.ops.aten.add.Scalar(mean_19, 1e-05); mean_19 = None + rsqrt_19 = torch.ops.aten.rsqrt.default(add_38); add_38 = None + mul_76 = torch.ops.aten.mul.Tensor(convert_element_type_318, rsqrt_19); convert_element_type_318 = None + mul_77 = torch.ops.aten.mul.Tensor(mul_76, wait_tensor_126) + convert_element_type_319 = torch.ops.prims.convert_element_type.default(mul_77, torch.bfloat16); mul_77 = None + all_gather_into_tensor_107 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_319, 8, '1'); convert_element_type_319 = None + wait_tensor_127 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_107); all_gather_into_tensor_107 = None + split_47 = torch.ops.aten.split.Tensor(wait_tensor_127, 2); wait_tensor_127 = None + getitem_466 = split_47[0] + getitem_467 = split_47[1] + getitem_468 = split_47[2] + getitem_469 = split_47[3] + getitem_470 = split_47[4] + getitem_471 = split_47[5] + getitem_472 = split_47[6] + getitem_473 = split_47[7]; split_47 = None + cat_39 = torch.ops.aten.cat.default([getitem_466, getitem_467, getitem_468, getitem_469, getitem_470, getitem_471, getitem_472, getitem_473], 1); getitem_466 = getitem_467 = getitem_468 = getitem_469 = getitem_470 = getitem_471 = getitem_472 = getitem_473 = None + view_708 = torch.ops.aten.view.default(cat_39, [16384, 4096]); cat_39 = None + view_709 = torch.ops.aten.view.default(mm_67, [2, 8192, 1792]); mm_67 = None + convert_element_type_323 = torch.ops.prims.convert_element_type.default(view_709, torch.float32); view_709 = None + sigmoid_9 = torch.ops.aten.sigmoid.default(convert_element_type_323) + mul_78 = torch.ops.aten.mul.Tensor(convert_element_type_323, sigmoid_9); sigmoid_9 = None + convert_element_type_324 = torch.ops.prims.convert_element_type.default(mul_78, torch.bfloat16); mul_78 = None + convert_element_type_325 = 
torch.ops.prims.convert_element_type.default(primals_92, torch.bfloat16); primals_92 = None + all_gather_into_tensor_109 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_325, 8, '0'); convert_element_type_325 = None + wait_tensor_129 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_109); all_gather_into_tensor_109 = None + permute_108 = torch.ops.aten.permute.default(wait_tensor_129, [1, 0]); wait_tensor_129 = None + mm_68 = torch.ops.aten.mm.default(view_708, permute_108) + view_716 = torch.ops.aten.view.default(mm_68, [2, 8192, 1792]); mm_68 = None + mul_79 = torch.ops.aten.mul.Tensor(convert_element_type_324, view_716) + view_723 = torch.ops.aten.view.default(mul_79, [16384, 1792]); mul_79 = None + mm_199 = torch.ops.aten.mm.default(permute_373, view_723); permute_373 = view_723 = None + convert_element_type_328 = torch.ops.prims.convert_element_type.default(primals_93, torch.bfloat16); primals_93 = None + all_gather_into_tensor_110 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_328, 8, '0'); convert_element_type_328 = None + wait_tensor_130 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_110); all_gather_into_tensor_110 = None + permute_109 = torch.ops.aten.permute.default(wait_tensor_130, [1, 0]); wait_tensor_130 = None + permute_375 = torch.ops.aten.permute.default(permute_109, [1, 0]); permute_109 = None + mm_200 = torch.ops.aten.mm.default(view_1315, permute_375); view_1315 = permute_375 = None + view_1316 = torch.ops.aten.view.default(mm_200, [2, 8192, 1792]); mm_200 = None + convert_element_type_874 = torch.ops.prims.convert_element_type.default(mm_199, torch.float32); mm_199 = None + reduce_scatter_tensor_102 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_874, 'avg', 8, '0'); convert_element_type_874 = None + wait_tensor_308 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_102); 
reduce_scatter_tensor_102 = None + mul_256 = torch.ops.aten.mul.Tensor(view_1316, convert_element_type_324); convert_element_type_324 = None + mul_257 = torch.ops.aten.mul.Tensor(view_1316, view_716); view_1316 = view_716 = None + view_1317 = torch.ops.aten.view.default(mul_256, [16384, 1792]); mul_256 = None + permute_377 = torch.ops.aten.permute.default(view_1317, [1, 0]) + mm_201 = torch.ops.aten.mm.default(permute_377, view_708); permute_377 = None + permute_379 = torch.ops.aten.permute.default(permute_108, [1, 0]); permute_108 = None + mm_202 = torch.ops.aten.mm.default(view_1317, permute_379); view_1317 = permute_379 = None + view_1318 = torch.ops.aten.view.default(mm_202, [2, 8192, 4096]); mm_202 = None + convert_element_type_879 = torch.ops.prims.convert_element_type.default(mm_201, torch.float32); mm_201 = None + reduce_scatter_tensor_103 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_879, 'avg', 8, '0'); convert_element_type_879 = None + wait_tensor_309 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_103); reduce_scatter_tensor_103 = None + convert_element_type_880 = torch.ops.prims.convert_element_type.default(mul_257, torch.float32); mul_257 = None + neg_6 = torch.ops.aten.neg.default(convert_element_type_323) + exp_6 = torch.ops.aten.exp.default(neg_6); neg_6 = None + add_107 = torch.ops.aten.add.Tensor(exp_6, 1); exp_6 = None + reciprocal_6 = torch.ops.aten.reciprocal.default(add_107); add_107 = None + mul_258 = torch.ops.aten.mul.Tensor(reciprocal_6, 1); reciprocal_6 = None + mul_259 = torch.ops.aten.mul.Tensor(convert_element_type_880, mul_258); convert_element_type_880 = None + sub_20 = torch.ops.aten.sub.Tensor(1, mul_258); mul_258 = None + mul_260 = torch.ops.aten.mul.Tensor(convert_element_type_323, sub_20); convert_element_type_323 = sub_20 = None + add_108 = torch.ops.aten.add.Tensor(mul_260, 1); mul_260 = None + mul_261 = torch.ops.aten.mul.Tensor(mul_259, add_108); mul_259 = add_108 = 
None + convert_element_type_882 = torch.ops.prims.convert_element_type.default(mul_261, torch.bfloat16); mul_261 = None + view_1319 = torch.ops.aten.view.default(convert_element_type_882, [16384, 1792]); convert_element_type_882 = None + permute_381 = torch.ops.aten.permute.default(view_1319, [1, 0]) + mm_203 = torch.ops.aten.mm.default(permute_381, view_708); permute_381 = view_708 = None + convert_element_type_320 = torch.ops.prims.convert_element_type.default(primals_91, torch.bfloat16); primals_91 = None + all_gather_into_tensor_108 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_320, 8, '0'); convert_element_type_320 = None + wait_tensor_128 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_108); all_gather_into_tensor_108 = None + permute_107 = torch.ops.aten.permute.default(wait_tensor_128, [1, 0]); wait_tensor_128 = None + permute_383 = torch.ops.aten.permute.default(permute_107, [1, 0]); permute_107 = None + mm_204 = torch.ops.aten.mm.default(view_1319, permute_383); view_1319 = permute_383 = None + view_1320 = torch.ops.aten.view.default(mm_204, [2, 8192, 4096]); mm_204 = None + add_109 = torch.ops.aten.add.Tensor(view_1318, view_1320); view_1318 = view_1320 = None + convert_element_type_887 = torch.ops.prims.convert_element_type.default(mm_203, torch.float32); mm_203 = None + reduce_scatter_tensor_104 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_887, 'avg', 8, '0'); convert_element_type_887 = None + wait_tensor_310 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_104); reduce_scatter_tensor_104 = None + split_100 = torch.ops.aten.split.Tensor(add_109, 1024, 1); add_109 = None + getitem_962 = split_100[0] + getitem_963 = split_100[1] + getitem_964 = split_100[2] + getitem_965 = split_100[3] + getitem_966 = split_100[4] + getitem_967 = split_100[5] + getitem_968 = split_100[6] + getitem_969 = split_100[7]; split_100 = None + cat_92 = 
torch.ops.aten.cat.default([getitem_962, getitem_963, getitem_964, getitem_965, getitem_966, getitem_967, getitem_968, getitem_969]); getitem_962 = getitem_963 = getitem_964 = getitem_965 = getitem_966 = getitem_967 = getitem_968 = getitem_969 = None + reduce_scatter_tensor_105 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_92, 'sum', 8, '1'); cat_92 = None + wait_tensor_311 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_105); reduce_scatter_tensor_105 = None + convert_element_type_888 = torch.ops.prims.convert_element_type.default(wait_tensor_311, torch.float32); wait_tensor_311 = None + convert_element_type_890 = torch.ops.prims.convert_element_type.default(wait_tensor_126, torch.float32); wait_tensor_126 = None + mul_262 = torch.ops.aten.mul.Tensor(convert_element_type_888, convert_element_type_890); convert_element_type_890 = None + mul_264 = torch.ops.aten.mul.Tensor(mul_76, mul_262) + sum_39 = torch.ops.aten.sum.dim_IntList(mul_264, [2], True); mul_264 = None + div_13 = torch.ops.aten.div.Tensor(mul_76, 4096) + mul_265 = torch.ops.aten.mul.Tensor(div_13, sum_39); div_13 = sum_39 = None + sub_21 = torch.ops.aten.sub.Tensor(mul_262, mul_265); mul_262 = mul_265 = None + mul_266 = torch.ops.aten.mul.Tensor(sub_21, rsqrt_19); sub_21 = rsqrt_19 = None + mul_267 = torch.ops.aten.mul.Tensor(convert_element_type_888, mul_76); convert_element_type_888 = mul_76 = None + sum_40 = torch.ops.aten.sum.dim_IntList(mul_267, [0, 1]); mul_267 = None + convert_element_type_891 = torch.ops.prims.convert_element_type.default(mul_266, torch.bfloat16); mul_266 = None + convert_element_type_892 = torch.ops.prims.convert_element_type.default(sum_40, torch.bfloat16); sum_40 = None + all_reduce_13 = torch.ops._c10d_functional.all_reduce.default(convert_element_type_892, 'sum', '1'); convert_element_type_892 = None + wait_tensor_312 = torch.ops._c10d_functional.wait_tensor.default(all_reduce_13); all_reduce_13 = None + convert_element_type_893 = 
torch.ops.prims.convert_element_type.default(wait_tensor_312, torch.float32); wait_tensor_312 = None + reduce_scatter_tensor_106 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_893, 'avg', 8, '0'); convert_element_type_893 = None + wait_tensor_313 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_106); reduce_scatter_tensor_106 = None + add_110 = torch.ops.aten.add.Tensor(add_106, convert_element_type_891); add_106 = convert_element_type_891 = None + all_gather_into_tensor_193 = torch.ops._c10d_functional.all_gather_into_tensor.default(add_110, 8, '1') + wait_tensor_314 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_193); all_gather_into_tensor_193 = None + split_101 = torch.ops.aten.split.Tensor(wait_tensor_314, 2); wait_tensor_314 = None + getitem_970 = split_101[0] + getitem_971 = split_101[1] + getitem_972 = split_101[2] + getitem_973 = split_101[3] + getitem_974 = split_101[4] + getitem_975 = split_101[5] + getitem_976 = split_101[6] + getitem_977 = split_101[7]; split_101 = None + cat_93 = torch.ops.aten.cat.default([getitem_970, getitem_971, getitem_972, getitem_973, getitem_974, getitem_975, getitem_976, getitem_977], 1); getitem_970 = getitem_971 = getitem_972 = getitem_973 = getitem_974 = getitem_975 = getitem_976 = getitem_977 = None + view_1321 = torch.ops.aten.view.default(cat_93, [16384, 4096]); cat_93 = None + permute_385 = torch.ops.aten.permute.default(view_1321, [1, 0]) + permute_105 = torch.ops.aten.permute.default(getitem_449, [0, 2, 1, 3]) + view_690 = torch.ops.aten.view.default(permute_105, [2, 8192, -1]); permute_105 = None + view_696 = torch.ops.aten.view.default(view_690, [16384, 512]); view_690 = None + mm_205 = torch.ops.aten.mm.default(permute_385, view_696); permute_385 = view_696 = None + convert_element_type_314 = torch.ops.prims.convert_element_type.default(primals_89, torch.bfloat16); primals_89 = None + all_gather_into_tensor_105 = 
torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_314, 8, '0'); convert_element_type_314 = None + wait_tensor_124 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_105); all_gather_into_tensor_105 = None + permute_106 = torch.ops.aten.permute.default(wait_tensor_124, [1, 0]); wait_tensor_124 = None + permute_387 = torch.ops.aten.permute.default(permute_106, [1, 0]); permute_106 = None + mm_206 = torch.ops.aten.mm.default(view_1321, permute_387); view_1321 = permute_387 = None + view_1322 = torch.ops.aten.view.default(mm_206, [2, 8192, 512]); mm_206 = None + convert_element_type_898 = torch.ops.prims.convert_element_type.default(mm_205, torch.float32); mm_205 = None + reduce_scatter_tensor_107 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_898, 'avg', 8, '0'); convert_element_type_898 = None + wait_tensor_315 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_107); reduce_scatter_tensor_107 = None + view_1323 = torch.ops.aten.view.default(view_1322, [2, 8192, 4, 128]); view_1322 = None + permute_389 = torch.ops.aten.permute.default(view_1323, [0, 2, 1, 3]); view_1323 = None + convert_element_type_298 = torch.ops.prims.convert_element_type.default(primals_85, torch.bfloat16); primals_85 = None + all_gather_into_tensor_100 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_298, 8, '0'); convert_element_type_298 = None + wait_tensor_119 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_100); all_gather_into_tensor_100 = None + convert_element_type_299 = torch.ops.prims.convert_element_type.default(add_35, torch.float32); add_35 = None + pow_19 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_299, 2) + mean_18 = torch.ops.aten.mean.dim(pow_19, [2], True); pow_19 = None + add_36 = torch.ops.aten.add.Scalar(mean_18, 1e-05); mean_18 = None + rsqrt_18 = torch.ops.aten.rsqrt.default(add_36); add_36 = None + mul_72 
= torch.ops.aten.mul.Tensor(convert_element_type_299, rsqrt_18); convert_element_type_299 = None + mul_73 = torch.ops.aten.mul.Tensor(mul_72, wait_tensor_119) + convert_element_type_300 = torch.ops.prims.convert_element_type.default(mul_73, torch.bfloat16); mul_73 = None + all_gather_into_tensor_101 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_300, 8, '1'); convert_element_type_300 = None + wait_tensor_120 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_101); all_gather_into_tensor_101 = None + split_45 = torch.ops.aten.split.Tensor(wait_tensor_120, 2); wait_tensor_120 = None + getitem_441 = split_45[0] + getitem_442 = split_45[1] + getitem_443 = split_45[2] + getitem_444 = split_45[3] + getitem_445 = split_45[4] + getitem_446 = split_45[5] + getitem_447 = split_45[6] + getitem_448 = split_45[7]; split_45 = None + cat_37 = torch.ops.aten.cat.default([getitem_441, getitem_442, getitem_443, getitem_444, getitem_445, getitem_446, getitem_447, getitem_448], 1); getitem_441 = getitem_442 = getitem_443 = getitem_444 = getitem_445 = getitem_446 = getitem_447 = getitem_448 = None + view_663 = torch.ops.aten.view.default(cat_37, [16384, 4096]); cat_37 = None + view_664 = torch.ops.aten.view.default(mm_63, [2, 8192, 512]); mm_63 = None + convert_element_type_304 = torch.ops.prims.convert_element_type.default(primals_87, torch.bfloat16); primals_87 = None + all_gather_into_tensor_103 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_304, 8, '0'); convert_element_type_304 = None + wait_tensor_122 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_103); all_gather_into_tensor_103 = None + permute_100 = torch.ops.aten.permute.default(wait_tensor_122, [1, 0]); wait_tensor_122 = None + mm_64 = torch.ops.aten.mm.default(view_663, permute_100) + view_671 = torch.ops.aten.view.default(mm_64, [2, 8192, 128]); mm_64 = None + view_678 = torch.ops.aten.view.default(mm_65, [2, 
8192, 128]); mm_65 = None + view_680 = torch.ops.aten.view.default(view_664, [2, 8192, -1, 128]); view_664 = None + view_681 = torch.ops.aten.view.default(view_671, [2, 8192, -1, 128]); view_671 = None + view_682 = torch.ops.aten.view.default(view_678, [2, 8192, -1, 128]); view_678 = None + convert_element_type_310 = torch.ops.prims.convert_element_type.default(view_680, torch.float32); view_680 = None + view_683 = torch.ops.aten.view.default(convert_element_type_310, [2, 8192, 4, -1, 2]); convert_element_type_310 = None + view_as_complex_18 = torch.ops.aten.view_as_complex.default(view_683); view_683 = None + convert_element_type_311 = torch.ops.prims.convert_element_type.default(view_681, torch.float32); view_681 = None + view_684 = torch.ops.aten.view.default(convert_element_type_311, [2, 8192, 1, -1, 2]); convert_element_type_311 = None + view_as_complex_19 = torch.ops.aten.view_as_complex.default(view_684); view_684 = None + mul_74 = torch.ops.aten.mul.Tensor(view_as_complex_18, view_37); view_as_complex_18 = None + view_as_real_18 = torch.ops.aten.view_as_real.default(mul_74); mul_74 = None + view_686 = torch.ops.aten.view.default(view_as_real_18, [2, 8192, 4, 128]); view_as_real_18 = None + mul_75 = torch.ops.aten.mul.Tensor(view_as_complex_19, view_37); view_as_complex_19 = None + view_as_real_19 = torch.ops.aten.view_as_real.default(mul_75); mul_75 = None + view_687 = torch.ops.aten.view.default(view_as_real_19, [2, 8192, 1, 128]); view_as_real_19 = None + convert_element_type_312 = torch.ops.prims.convert_element_type.default(view_686, torch.bfloat16); view_686 = None + convert_element_type_313 = torch.ops.prims.convert_element_type.default(view_687, torch.bfloat16); view_687 = None + unsqueeze_18 = torch.ops.aten.unsqueeze.default(convert_element_type_313, 3); convert_element_type_313 = None + expand_18 = torch.ops.aten.expand.default(unsqueeze_18, [2, 8192, 1, 4, 128]); unsqueeze_18 = None + view_688 = torch.ops.aten.view.default(expand_18, [2, 8192, 4, 
128]); expand_18 = None + unsqueeze_19 = torch.ops.aten.unsqueeze.default(view_682, 3); view_682 = None + expand_19 = torch.ops.aten.expand.default(unsqueeze_19, [2, 8192, 1, 4, 128]); unsqueeze_19 = None + view_689 = torch.ops.aten.view.default(expand_19, [2, 8192, 4, 128]); expand_19 = None + permute_102 = torch.ops.aten.permute.default(convert_element_type_312, [0, 2, 1, 3]); convert_element_type_312 = None + permute_103 = torch.ops.aten.permute.default(view_688, [0, 2, 1, 3]); view_688 = None + permute_104 = torch.ops.aten.permute.default(view_689, [0, 2, 1, 3]); view_689 = None + _scaled_dot_product_cudnn_attention_backward_6 = torch.ops.aten._scaled_dot_product_cudnn_attention_backward.default(permute_389, permute_102, permute_103, permute_104, getitem_449, getitem_450, getitem_455, getitem_456, None, None, None, 8192, 8192, 0.0, True); permute_389 = permute_102 = permute_103 = permute_104 = getitem_449 = getitem_450 = getitem_455 = getitem_456 = None + getitem_978 = _scaled_dot_product_cudnn_attention_backward_6[0] + getitem_979 = _scaled_dot_product_cudnn_attention_backward_6[1] + getitem_980 = _scaled_dot_product_cudnn_attention_backward_6[2]; _scaled_dot_product_cudnn_attention_backward_6 = None + permute_390 = torch.ops.aten.permute.default(getitem_980, [0, 2, 1, 3]); getitem_980 = None + permute_391 = torch.ops.aten.permute.default(getitem_979, [0, 2, 1, 3]); getitem_979 = None + permute_392 = torch.ops.aten.permute.default(getitem_978, [0, 2, 1, 3]); getitem_978 = None + view_1324 = torch.ops.aten.view.default(permute_390, [2, 8192, 1, 4, 128]); permute_390 = None + sum_41 = torch.ops.aten.sum.dim_IntList(view_1324, [3], True); view_1324 = None + squeeze_12 = torch.ops.aten.squeeze.dim(sum_41, 3); sum_41 = None + view_1325 = torch.ops.aten.view.default(permute_391, [2, 8192, 1, 4, 128]); permute_391 = None + sum_42 = torch.ops.aten.sum.dim_IntList(view_1325, [3], True); view_1325 = None + squeeze_13 = torch.ops.aten.squeeze.dim(sum_42, 3); sum_42 = 
None + convert_element_type_899 = torch.ops.prims.convert_element_type.default(squeeze_13, torch.float32); squeeze_13 = None + convert_element_type_900 = torch.ops.prims.convert_element_type.default(permute_392, torch.float32); permute_392 = None + view_1326 = torch.ops.aten.view.default(convert_element_type_899, [2, 8192, 1, 64, 2]); convert_element_type_899 = None + view_as_complex_44 = torch.ops.aten.view_as_complex.default(view_1326); view_1326 = None + mul_268 = torch.ops.aten.mul.Tensor(view_as_complex_44, _conj); view_as_complex_44 = None + view_1327 = torch.ops.aten.view.default(convert_element_type_900, [2, 8192, 4, 64, 2]); convert_element_type_900 = None + view_as_complex_45 = torch.ops.aten.view_as_complex.default(view_1327); view_1327 = None + mul_269 = torch.ops.aten.mul.Tensor(view_as_complex_45, _conj); view_as_complex_45 = None + view_as_real_44 = torch.ops.aten.view_as_real.default(mul_268); mul_268 = None + view_1328 = torch.ops.aten.view.default(view_as_real_44, [2, 8192, 1, 128]); view_as_real_44 = None + convert_element_type_901 = torch.ops.prims.convert_element_type.default(view_1328, torch.bfloat16); view_1328 = None + view_as_real_45 = torch.ops.aten.view_as_real.default(mul_269); mul_269 = None + view_1329 = torch.ops.aten.view.default(view_as_real_45, [2, 8192, 4, 128]); view_as_real_45 = None + convert_element_type_902 = torch.ops.prims.convert_element_type.default(view_1329, torch.bfloat16); view_1329 = None + view_1330 = torch.ops.aten.view.default(squeeze_12, [2, 8192, 128]); squeeze_12 = None + view_1331 = torch.ops.aten.view.default(convert_element_type_901, [2, 8192, 128]); convert_element_type_901 = None + view_1332 = torch.ops.aten.view.default(convert_element_type_902, [2, 8192, 512]); convert_element_type_902 = None + view_1333 = torch.ops.aten.view.default(view_1330, [16384, 128]); view_1330 = None + permute_393 = torch.ops.aten.permute.default(view_1333, [1, 0]) + mm_207 = torch.ops.aten.mm.default(permute_393, view_663); 
permute_393 = None + convert_element_type_307 = torch.ops.prims.convert_element_type.default(primals_88, torch.bfloat16); primals_88 = None + all_gather_into_tensor_104 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_307, 8, '0'); convert_element_type_307 = None + wait_tensor_123 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_104); all_gather_into_tensor_104 = None + permute_101 = torch.ops.aten.permute.default(wait_tensor_123, [1, 0]); wait_tensor_123 = None + permute_395 = torch.ops.aten.permute.default(permute_101, [1, 0]); permute_101 = None + mm_208 = torch.ops.aten.mm.default(view_1333, permute_395); view_1333 = permute_395 = None + view_1334 = torch.ops.aten.view.default(mm_208, [2, 8192, 4096]); mm_208 = None + convert_element_type_907 = torch.ops.prims.convert_element_type.default(mm_207, torch.float32); mm_207 = None + reduce_scatter_tensor_108 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_907, 'avg', 8, '0'); convert_element_type_907 = None + wait_tensor_316 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_108); reduce_scatter_tensor_108 = None + view_1335 = torch.ops.aten.view.default(view_1331, [16384, 128]); view_1331 = None + permute_397 = torch.ops.aten.permute.default(view_1335, [1, 0]) + mm_209 = torch.ops.aten.mm.default(permute_397, view_663); permute_397 = None + permute_399 = torch.ops.aten.permute.default(permute_100, [1, 0]); permute_100 = None + mm_210 = torch.ops.aten.mm.default(view_1335, permute_399); view_1335 = permute_399 = None + view_1336 = torch.ops.aten.view.default(mm_210, [2, 8192, 4096]); mm_210 = None + add_111 = torch.ops.aten.add.Tensor(view_1334, view_1336); view_1334 = view_1336 = None + convert_element_type_912 = torch.ops.prims.convert_element_type.default(mm_209, torch.float32); mm_209 = None + reduce_scatter_tensor_109 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_912, 
'avg', 8, '0'); convert_element_type_912 = None + wait_tensor_317 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_109); reduce_scatter_tensor_109 = None + view_1337 = torch.ops.aten.view.default(view_1332, [16384, 512]); view_1332 = None + permute_401 = torch.ops.aten.permute.default(view_1337, [1, 0]) + mm_211 = torch.ops.aten.mm.default(permute_401, view_663); permute_401 = view_663 = None + convert_element_type_301 = torch.ops.prims.convert_element_type.default(primals_86, torch.bfloat16); primals_86 = None + all_gather_into_tensor_102 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_301, 8, '0'); convert_element_type_301 = None + wait_tensor_121 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_102); all_gather_into_tensor_102 = None + permute_99 = torch.ops.aten.permute.default(wait_tensor_121, [1, 0]); wait_tensor_121 = None + permute_403 = torch.ops.aten.permute.default(permute_99, [1, 0]); permute_99 = None + mm_212 = torch.ops.aten.mm.default(view_1337, permute_403); view_1337 = permute_403 = None + view_1338 = torch.ops.aten.view.default(mm_212, [2, 8192, 4096]); mm_212 = None + add_112 = torch.ops.aten.add.Tensor(add_111, view_1338); add_111 = view_1338 = None + convert_element_type_917 = torch.ops.prims.convert_element_type.default(mm_211, torch.float32); mm_211 = None + reduce_scatter_tensor_110 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_917, 'avg', 8, '0'); convert_element_type_917 = None + wait_tensor_318 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_110); reduce_scatter_tensor_110 = None + split_102 = torch.ops.aten.split.Tensor(add_112, 1024, 1); add_112 = None + getitem_981 = split_102[0] + getitem_982 = split_102[1] + getitem_983 = split_102[2] + getitem_984 = split_102[3] + getitem_985 = split_102[4] + getitem_986 = split_102[5] + getitem_987 = split_102[6] + getitem_988 = split_102[7]; split_102 = None + 
cat_94 = torch.ops.aten.cat.default([getitem_981, getitem_982, getitem_983, getitem_984, getitem_985, getitem_986, getitem_987, getitem_988]); getitem_981 = getitem_982 = getitem_983 = getitem_984 = getitem_985 = getitem_986 = getitem_987 = getitem_988 = None + reduce_scatter_tensor_111 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_94, 'sum', 8, '1'); cat_94 = None + wait_tensor_319 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_111); reduce_scatter_tensor_111 = None + convert_element_type_918 = torch.ops.prims.convert_element_type.default(wait_tensor_319, torch.float32); wait_tensor_319 = None + convert_element_type_920 = torch.ops.prims.convert_element_type.default(wait_tensor_119, torch.float32); wait_tensor_119 = None + mul_270 = torch.ops.aten.mul.Tensor(convert_element_type_918, convert_element_type_920); convert_element_type_920 = None + mul_272 = torch.ops.aten.mul.Tensor(mul_72, mul_270) + sum_43 = torch.ops.aten.sum.dim_IntList(mul_272, [2], True); mul_272 = None + div_14 = torch.ops.aten.div.Tensor(mul_72, 4096) + mul_273 = torch.ops.aten.mul.Tensor(div_14, sum_43); div_14 = sum_43 = None + sub_22 = torch.ops.aten.sub.Tensor(mul_270, mul_273); mul_270 = mul_273 = None + mul_274 = torch.ops.aten.mul.Tensor(sub_22, rsqrt_18); sub_22 = rsqrt_18 = None + mul_275 = torch.ops.aten.mul.Tensor(convert_element_type_918, mul_72); convert_element_type_918 = mul_72 = None + sum_44 = torch.ops.aten.sum.dim_IntList(mul_275, [0, 1]); mul_275 = None + convert_element_type_921 = torch.ops.prims.convert_element_type.default(mul_274, torch.bfloat16); mul_274 = None + convert_element_type_922 = torch.ops.prims.convert_element_type.default(sum_44, torch.bfloat16); sum_44 = None + all_reduce_14 = torch.ops._c10d_functional.all_reduce.default(convert_element_type_922, 'sum', '1'); convert_element_type_922 = None + wait_tensor_320 = torch.ops._c10d_functional.wait_tensor.default(all_reduce_14); all_reduce_14 = None + 
convert_element_type_923 = torch.ops.prims.convert_element_type.default(wait_tensor_320, torch.float32); wait_tensor_320 = None + reduce_scatter_tensor_112 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_923, 'avg', 8, '0'); convert_element_type_923 = None + wait_tensor_321 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_112); reduce_scatter_tensor_112 = None + add_113 = torch.ops.aten.add.Tensor(add_110, convert_element_type_921); add_110 = convert_element_type_921 = None + all_gather_into_tensor_194 = torch.ops._c10d_functional.all_gather_into_tensor.default(add_113, 8, '1') + wait_tensor_322 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_194); all_gather_into_tensor_194 = None + split_103 = torch.ops.aten.split.Tensor(wait_tensor_322, 2); wait_tensor_322 = None + getitem_989 = split_103[0] + getitem_990 = split_103[1] + getitem_991 = split_103[2] + getitem_992 = split_103[3] + getitem_993 = split_103[4] + getitem_994 = split_103[5] + getitem_995 = split_103[6] + getitem_996 = split_103[7]; split_103 = None + cat_95 = torch.ops.aten.cat.default([getitem_989, getitem_990, getitem_991, getitem_992, getitem_993, getitem_994, getitem_995, getitem_996], 1); getitem_989 = getitem_990 = getitem_991 = getitem_992 = getitem_993 = getitem_994 = getitem_995 = getitem_996 = None + view_1339 = torch.ops.aten.view.default(cat_95, [16384, 4096]); cat_95 = None + permute_405 = torch.ops.aten.permute.default(view_1339, [1, 0]) + wait_tensor_112 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_17); reduce_scatter_tensor_17 = None + add_33 = torch.ops.aten.add.Tensor(add_31, wait_tensor_112); wait_tensor_112 = None + convert_element_type_284 = torch.ops.prims.convert_element_type.default(primals_81, torch.bfloat16); primals_81 = None + all_gather_into_tensor_95 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_284, 8, '0'); convert_element_type_284 = 
None + wait_tensor_113 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_95); all_gather_into_tensor_95 = None + convert_element_type_285 = torch.ops.prims.convert_element_type.default(add_33, torch.float32); add_33 = None + pow_18 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_285, 2) + mean_17 = torch.ops.aten.mean.dim(pow_18, [2], True); pow_18 = None + add_34 = torch.ops.aten.add.Scalar(mean_17, 1e-05); mean_17 = None + rsqrt_17 = torch.ops.aten.rsqrt.default(add_34); add_34 = None + mul_68 = torch.ops.aten.mul.Tensor(convert_element_type_285, rsqrt_17); convert_element_type_285 = None + mul_69 = torch.ops.aten.mul.Tensor(mul_68, wait_tensor_113) + convert_element_type_286 = torch.ops.prims.convert_element_type.default(mul_69, torch.bfloat16); mul_69 = None + all_gather_into_tensor_96 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_286, 8, '1'); convert_element_type_286 = None + wait_tensor_114 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_96); all_gather_into_tensor_96 = None + split_43 = torch.ops.aten.split.Tensor(wait_tensor_114, 2); wait_tensor_114 = None + getitem_425 = split_43[0] + getitem_426 = split_43[1] + getitem_427 = split_43[2] + getitem_428 = split_43[3] + getitem_429 = split_43[4] + getitem_430 = split_43[5] + getitem_431 = split_43[6] + getitem_432 = split_43[7]; split_43 = None + cat_35 = torch.ops.aten.cat.default([getitem_425, getitem_426, getitem_427, getitem_428, getitem_429, getitem_430, getitem_431, getitem_432], 1); getitem_425 = getitem_426 = getitem_427 = getitem_428 = getitem_429 = getitem_430 = getitem_431 = getitem_432 = None + view_636 = torch.ops.aten.view.default(cat_35, [16384, 4096]); cat_35 = None + view_637 = torch.ops.aten.view.default(mm_60, [2, 8192, 1792]); mm_60 = None + convert_element_type_290 = torch.ops.prims.convert_element_type.default(view_637, torch.float32); view_637 = None + sigmoid_8 = 
torch.ops.aten.sigmoid.default(convert_element_type_290) + mul_70 = torch.ops.aten.mul.Tensor(convert_element_type_290, sigmoid_8); sigmoid_8 = None + convert_element_type_291 = torch.ops.prims.convert_element_type.default(mul_70, torch.bfloat16); mul_70 = None + convert_element_type_292 = torch.ops.prims.convert_element_type.default(primals_83, torch.bfloat16); primals_83 = None + all_gather_into_tensor_98 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_292, 8, '0'); convert_element_type_292 = None + wait_tensor_116 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_98); all_gather_into_tensor_98 = None + permute_97 = torch.ops.aten.permute.default(wait_tensor_116, [1, 0]); wait_tensor_116 = None + mm_61 = torch.ops.aten.mm.default(view_636, permute_97) + view_644 = torch.ops.aten.view.default(mm_61, [2, 8192, 1792]); mm_61 = None + mul_71 = torch.ops.aten.mul.Tensor(convert_element_type_291, view_644) + view_651 = torch.ops.aten.view.default(mul_71, [16384, 1792]); mul_71 = None + mm_213 = torch.ops.aten.mm.default(permute_405, view_651); permute_405 = view_651 = None + convert_element_type_295 = torch.ops.prims.convert_element_type.default(primals_84, torch.bfloat16); primals_84 = None + all_gather_into_tensor_99 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_295, 8, '0'); convert_element_type_295 = None + wait_tensor_117 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_99); all_gather_into_tensor_99 = None + permute_98 = torch.ops.aten.permute.default(wait_tensor_117, [1, 0]); wait_tensor_117 = None + permute_407 = torch.ops.aten.permute.default(permute_98, [1, 0]); permute_98 = None + mm_214 = torch.ops.aten.mm.default(view_1339, permute_407); view_1339 = permute_407 = None + view_1340 = torch.ops.aten.view.default(mm_214, [2, 8192, 1792]); mm_214 = None + convert_element_type_928 = torch.ops.prims.convert_element_type.default(mm_213, torch.float32); 
mm_213 = None + reduce_scatter_tensor_113 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_928, 'avg', 8, '0'); convert_element_type_928 = None + wait_tensor_323 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_113); reduce_scatter_tensor_113 = None + mul_276 = torch.ops.aten.mul.Tensor(view_1340, convert_element_type_291); convert_element_type_291 = None + mul_277 = torch.ops.aten.mul.Tensor(view_1340, view_644); view_1340 = view_644 = None + view_1341 = torch.ops.aten.view.default(mul_276, [16384, 1792]); mul_276 = None + permute_409 = torch.ops.aten.permute.default(view_1341, [1, 0]) + mm_215 = torch.ops.aten.mm.default(permute_409, view_636); permute_409 = None + permute_411 = torch.ops.aten.permute.default(permute_97, [1, 0]); permute_97 = None + mm_216 = torch.ops.aten.mm.default(view_1341, permute_411); view_1341 = permute_411 = None + view_1342 = torch.ops.aten.view.default(mm_216, [2, 8192, 4096]); mm_216 = None + convert_element_type_933 = torch.ops.prims.convert_element_type.default(mm_215, torch.float32); mm_215 = None + reduce_scatter_tensor_114 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_933, 'avg', 8, '0'); convert_element_type_933 = None + wait_tensor_324 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_114); reduce_scatter_tensor_114 = None + convert_element_type_934 = torch.ops.prims.convert_element_type.default(mul_277, torch.float32); mul_277 = None + neg_7 = torch.ops.aten.neg.default(convert_element_type_290) + exp_7 = torch.ops.aten.exp.default(neg_7); neg_7 = None + add_114 = torch.ops.aten.add.Tensor(exp_7, 1); exp_7 = None + reciprocal_7 = torch.ops.aten.reciprocal.default(add_114); add_114 = None + mul_278 = torch.ops.aten.mul.Tensor(reciprocal_7, 1); reciprocal_7 = None + mul_279 = torch.ops.aten.mul.Tensor(convert_element_type_934, mul_278); convert_element_type_934 = None + sub_23 = torch.ops.aten.sub.Tensor(1, mul_278); 
mul_278 = None + mul_280 = torch.ops.aten.mul.Tensor(convert_element_type_290, sub_23); convert_element_type_290 = sub_23 = None + add_115 = torch.ops.aten.add.Tensor(mul_280, 1); mul_280 = None + mul_281 = torch.ops.aten.mul.Tensor(mul_279, add_115); mul_279 = add_115 = None + convert_element_type_936 = torch.ops.prims.convert_element_type.default(mul_281, torch.bfloat16); mul_281 = None + view_1343 = torch.ops.aten.view.default(convert_element_type_936, [16384, 1792]); convert_element_type_936 = None + permute_413 = torch.ops.aten.permute.default(view_1343, [1, 0]) + mm_217 = torch.ops.aten.mm.default(permute_413, view_636); permute_413 = view_636 = None + convert_element_type_287 = torch.ops.prims.convert_element_type.default(primals_82, torch.bfloat16); primals_82 = None + all_gather_into_tensor_97 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_287, 8, '0'); convert_element_type_287 = None + wait_tensor_115 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_97); all_gather_into_tensor_97 = None + permute_96 = torch.ops.aten.permute.default(wait_tensor_115, [1, 0]); wait_tensor_115 = None + permute_415 = torch.ops.aten.permute.default(permute_96, [1, 0]); permute_96 = None + mm_218 = torch.ops.aten.mm.default(view_1343, permute_415); view_1343 = permute_415 = None + view_1344 = torch.ops.aten.view.default(mm_218, [2, 8192, 4096]); mm_218 = None + add_116 = torch.ops.aten.add.Tensor(view_1342, view_1344); view_1342 = view_1344 = None + convert_element_type_941 = torch.ops.prims.convert_element_type.default(mm_217, torch.float32); mm_217 = None + reduce_scatter_tensor_115 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_941, 'avg', 8, '0'); convert_element_type_941 = None + wait_tensor_325 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_115); reduce_scatter_tensor_115 = None + split_104 = torch.ops.aten.split.Tensor(add_116, 1024, 1); add_116 = None + 
getitem_997 = split_104[0] + getitem_998 = split_104[1] + getitem_999 = split_104[2] + getitem_1000 = split_104[3] + getitem_1001 = split_104[4] + getitem_1002 = split_104[5] + getitem_1003 = split_104[6] + getitem_1004 = split_104[7]; split_104 = None + cat_96 = torch.ops.aten.cat.default([getitem_997, getitem_998, getitem_999, getitem_1000, getitem_1001, getitem_1002, getitem_1003, getitem_1004]); getitem_997 = getitem_998 = getitem_999 = getitem_1000 = getitem_1001 = getitem_1002 = getitem_1003 = getitem_1004 = None + reduce_scatter_tensor_116 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_96, 'sum', 8, '1'); cat_96 = None + wait_tensor_326 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_116); reduce_scatter_tensor_116 = None + convert_element_type_942 = torch.ops.prims.convert_element_type.default(wait_tensor_326, torch.float32); wait_tensor_326 = None + convert_element_type_944 = torch.ops.prims.convert_element_type.default(wait_tensor_113, torch.float32); wait_tensor_113 = None + mul_282 = torch.ops.aten.mul.Tensor(convert_element_type_942, convert_element_type_944); convert_element_type_944 = None + mul_284 = torch.ops.aten.mul.Tensor(mul_68, mul_282) + sum_45 = torch.ops.aten.sum.dim_IntList(mul_284, [2], True); mul_284 = None + div_15 = torch.ops.aten.div.Tensor(mul_68, 4096) + mul_285 = torch.ops.aten.mul.Tensor(div_15, sum_45); div_15 = sum_45 = None + sub_24 = torch.ops.aten.sub.Tensor(mul_282, mul_285); mul_282 = mul_285 = None + mul_286 = torch.ops.aten.mul.Tensor(sub_24, rsqrt_17); sub_24 = rsqrt_17 = None + mul_287 = torch.ops.aten.mul.Tensor(convert_element_type_942, mul_68); convert_element_type_942 = mul_68 = None + sum_46 = torch.ops.aten.sum.dim_IntList(mul_287, [0, 1]); mul_287 = None + convert_element_type_945 = torch.ops.prims.convert_element_type.default(mul_286, torch.bfloat16); mul_286 = None + convert_element_type_946 = torch.ops.prims.convert_element_type.default(sum_46, torch.bfloat16); sum_46 = 
None + all_reduce_15 = torch.ops._c10d_functional.all_reduce.default(convert_element_type_946, 'sum', '1'); convert_element_type_946 = None + wait_tensor_327 = torch.ops._c10d_functional.wait_tensor.default(all_reduce_15); all_reduce_15 = None + convert_element_type_947 = torch.ops.prims.convert_element_type.default(wait_tensor_327, torch.float32); wait_tensor_327 = None + reduce_scatter_tensor_117 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_947, 'avg', 8, '0'); convert_element_type_947 = None + wait_tensor_328 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_117); reduce_scatter_tensor_117 = None + add_117 = torch.ops.aten.add.Tensor(add_113, convert_element_type_945); add_113 = convert_element_type_945 = None + all_gather_into_tensor_195 = torch.ops._c10d_functional.all_gather_into_tensor.default(add_117, 8, '1') + wait_tensor_329 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_195); all_gather_into_tensor_195 = None + split_105 = torch.ops.aten.split.Tensor(wait_tensor_329, 2); wait_tensor_329 = None + getitem_1005 = split_105[0] + getitem_1006 = split_105[1] + getitem_1007 = split_105[2] + getitem_1008 = split_105[3] + getitem_1009 = split_105[4] + getitem_1010 = split_105[5] + getitem_1011 = split_105[6] + getitem_1012 = split_105[7]; split_105 = None + cat_97 = torch.ops.aten.cat.default([getitem_1005, getitem_1006, getitem_1007, getitem_1008, getitem_1009, getitem_1010, getitem_1011, getitem_1012], 1); getitem_1005 = getitem_1006 = getitem_1007 = getitem_1008 = getitem_1009 = getitem_1010 = getitem_1011 = getitem_1012 = None + view_1345 = torch.ops.aten.view.default(cat_97, [16384, 4096]); cat_97 = None + permute_417 = torch.ops.aten.permute.default(view_1345, [1, 0]) + permute_94 = torch.ops.aten.permute.default(getitem_408, [0, 2, 1, 3]) + view_618 = torch.ops.aten.view.default(permute_94, [2, 8192, -1]); permute_94 = None + view_624 = torch.ops.aten.view.default(view_618, 
[16384, 512]); view_618 = None + mm_219 = torch.ops.aten.mm.default(permute_417, view_624); permute_417 = view_624 = None + convert_element_type_281 = torch.ops.prims.convert_element_type.default(primals_80, torch.bfloat16); primals_80 = None + all_gather_into_tensor_94 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_281, 8, '0'); convert_element_type_281 = None + wait_tensor_111 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_94); all_gather_into_tensor_94 = None + permute_95 = torch.ops.aten.permute.default(wait_tensor_111, [1, 0]); wait_tensor_111 = None + permute_419 = torch.ops.aten.permute.default(permute_95, [1, 0]); permute_95 = None + mm_220 = torch.ops.aten.mm.default(view_1345, permute_419); view_1345 = permute_419 = None + view_1346 = torch.ops.aten.view.default(mm_220, [2, 8192, 512]); mm_220 = None + convert_element_type_952 = torch.ops.prims.convert_element_type.default(mm_219, torch.float32); mm_219 = None + reduce_scatter_tensor_118 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_952, 'avg', 8, '0'); convert_element_type_952 = None + wait_tensor_330 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_118); reduce_scatter_tensor_118 = None + view_1347 = torch.ops.aten.view.default(view_1346, [2, 8192, 4, 128]); view_1346 = None + permute_421 = torch.ops.aten.permute.default(view_1347, [0, 2, 1, 3]); view_1347 = None + convert_element_type_265 = torch.ops.prims.convert_element_type.default(primals_76, torch.bfloat16); primals_76 = None + all_gather_into_tensor_89 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_265, 8, '0'); convert_element_type_265 = None + wait_tensor_106 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_89); all_gather_into_tensor_89 = None + convert_element_type_266 = torch.ops.prims.convert_element_type.default(add_31, torch.float32); add_31 = None + pow_17 = 
torch.ops.aten.pow.Tensor_Scalar(convert_element_type_266, 2) + mean_16 = torch.ops.aten.mean.dim(pow_17, [2], True); pow_17 = None + add_32 = torch.ops.aten.add.Scalar(mean_16, 1e-05); mean_16 = None + rsqrt_16 = torch.ops.aten.rsqrt.default(add_32); add_32 = None + mul_64 = torch.ops.aten.mul.Tensor(convert_element_type_266, rsqrt_16); convert_element_type_266 = None + mul_65 = torch.ops.aten.mul.Tensor(mul_64, wait_tensor_106) + convert_element_type_267 = torch.ops.prims.convert_element_type.default(mul_65, torch.bfloat16); mul_65 = None + all_gather_into_tensor_90 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_267, 8, '1'); convert_element_type_267 = None + wait_tensor_107 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_90); all_gather_into_tensor_90 = None + split_41 = torch.ops.aten.split.Tensor(wait_tensor_107, 2); wait_tensor_107 = None + getitem_400 = split_41[0] + getitem_401 = split_41[1] + getitem_402 = split_41[2] + getitem_403 = split_41[3] + getitem_404 = split_41[4] + getitem_405 = split_41[5] + getitem_406 = split_41[6] + getitem_407 = split_41[7]; split_41 = None + cat_33 = torch.ops.aten.cat.default([getitem_400, getitem_401, getitem_402, getitem_403, getitem_404, getitem_405, getitem_406, getitem_407], 1); getitem_400 = getitem_401 = getitem_402 = getitem_403 = getitem_404 = getitem_405 = getitem_406 = getitem_407 = None + view_591 = torch.ops.aten.view.default(cat_33, [16384, 4096]); cat_33 = None + view_592 = torch.ops.aten.view.default(mm_56, [2, 8192, 512]); mm_56 = None + convert_element_type_271 = torch.ops.prims.convert_element_type.default(primals_78, torch.bfloat16); primals_78 = None + all_gather_into_tensor_92 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_271, 8, '0'); convert_element_type_271 = None + wait_tensor_109 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_92); all_gather_into_tensor_92 = None + permute_89 = 
torch.ops.aten.permute.default(wait_tensor_109, [1, 0]); wait_tensor_109 = None + mm_57 = torch.ops.aten.mm.default(view_591, permute_89) + view_599 = torch.ops.aten.view.default(mm_57, [2, 8192, 128]); mm_57 = None + view_606 = torch.ops.aten.view.default(mm_58, [2, 8192, 128]); mm_58 = None + view_608 = torch.ops.aten.view.default(view_592, [2, 8192, -1, 128]); view_592 = None + view_609 = torch.ops.aten.view.default(view_599, [2, 8192, -1, 128]); view_599 = None + view_610 = torch.ops.aten.view.default(view_606, [2, 8192, -1, 128]); view_606 = None + convert_element_type_277 = torch.ops.prims.convert_element_type.default(view_608, torch.float32); view_608 = None + view_611 = torch.ops.aten.view.default(convert_element_type_277, [2, 8192, 4, -1, 2]); convert_element_type_277 = None + view_as_complex_16 = torch.ops.aten.view_as_complex.default(view_611); view_611 = None + convert_element_type_278 = torch.ops.prims.convert_element_type.default(view_609, torch.float32); view_609 = None + view_612 = torch.ops.aten.view.default(convert_element_type_278, [2, 8192, 1, -1, 2]); convert_element_type_278 = None + view_as_complex_17 = torch.ops.aten.view_as_complex.default(view_612); view_612 = None + mul_66 = torch.ops.aten.mul.Tensor(view_as_complex_16, view_37); view_as_complex_16 = None + view_as_real_16 = torch.ops.aten.view_as_real.default(mul_66); mul_66 = None + view_614 = torch.ops.aten.view.default(view_as_real_16, [2, 8192, 4, 128]); view_as_real_16 = None + mul_67 = torch.ops.aten.mul.Tensor(view_as_complex_17, view_37); view_as_complex_17 = None + view_as_real_17 = torch.ops.aten.view_as_real.default(mul_67); mul_67 = None + view_615 = torch.ops.aten.view.default(view_as_real_17, [2, 8192, 1, 128]); view_as_real_17 = None + convert_element_type_279 = torch.ops.prims.convert_element_type.default(view_614, torch.bfloat16); view_614 = None + convert_element_type_280 = torch.ops.prims.convert_element_type.default(view_615, torch.bfloat16); view_615 = None + 
unsqueeze_16 = torch.ops.aten.unsqueeze.default(convert_element_type_280, 3); convert_element_type_280 = None + expand_16 = torch.ops.aten.expand.default(unsqueeze_16, [2, 8192, 1, 4, 128]); unsqueeze_16 = None + view_616 = torch.ops.aten.view.default(expand_16, [2, 8192, 4, 128]); expand_16 = None + unsqueeze_17 = torch.ops.aten.unsqueeze.default(view_610, 3); view_610 = None + expand_17 = torch.ops.aten.expand.default(unsqueeze_17, [2, 8192, 1, 4, 128]); unsqueeze_17 = None + view_617 = torch.ops.aten.view.default(expand_17, [2, 8192, 4, 128]); expand_17 = None + permute_91 = torch.ops.aten.permute.default(convert_element_type_279, [0, 2, 1, 3]); convert_element_type_279 = None + permute_92 = torch.ops.aten.permute.default(view_616, [0, 2, 1, 3]); view_616 = None + permute_93 = torch.ops.aten.permute.default(view_617, [0, 2, 1, 3]); view_617 = None + _scaled_dot_product_cudnn_attention_backward_7 = torch.ops.aten._scaled_dot_product_cudnn_attention_backward.default(permute_421, permute_91, permute_92, permute_93, getitem_408, getitem_409, getitem_414, getitem_415, None, None, None, 8192, 8192, 0.0, True); permute_421 = permute_91 = permute_92 = permute_93 = getitem_408 = getitem_409 = getitem_414 = getitem_415 = None + getitem_1013 = _scaled_dot_product_cudnn_attention_backward_7[0] + getitem_1014 = _scaled_dot_product_cudnn_attention_backward_7[1] + getitem_1015 = _scaled_dot_product_cudnn_attention_backward_7[2]; _scaled_dot_product_cudnn_attention_backward_7 = None + permute_422 = torch.ops.aten.permute.default(getitem_1015, [0, 2, 1, 3]); getitem_1015 = None + permute_423 = torch.ops.aten.permute.default(getitem_1014, [0, 2, 1, 3]); getitem_1014 = None + permute_424 = torch.ops.aten.permute.default(getitem_1013, [0, 2, 1, 3]); getitem_1013 = None + view_1348 = torch.ops.aten.view.default(permute_422, [2, 8192, 1, 4, 128]); permute_422 = None + sum_47 = torch.ops.aten.sum.dim_IntList(view_1348, [3], True); view_1348 = None + squeeze_14 = 
torch.ops.aten.squeeze.dim(sum_47, 3); sum_47 = None + view_1349 = torch.ops.aten.view.default(permute_423, [2, 8192, 1, 4, 128]); permute_423 = None + sum_48 = torch.ops.aten.sum.dim_IntList(view_1349, [3], True); view_1349 = None + squeeze_15 = torch.ops.aten.squeeze.dim(sum_48, 3); sum_48 = None + convert_element_type_953 = torch.ops.prims.convert_element_type.default(squeeze_15, torch.float32); squeeze_15 = None + convert_element_type_954 = torch.ops.prims.convert_element_type.default(permute_424, torch.float32); permute_424 = None + view_1350 = torch.ops.aten.view.default(convert_element_type_953, [2, 8192, 1, 64, 2]); convert_element_type_953 = None + view_as_complex_46 = torch.ops.aten.view_as_complex.default(view_1350); view_1350 = None + mul_288 = torch.ops.aten.mul.Tensor(view_as_complex_46, _conj); view_as_complex_46 = None + view_1351 = torch.ops.aten.view.default(convert_element_type_954, [2, 8192, 4, 64, 2]); convert_element_type_954 = None + view_as_complex_47 = torch.ops.aten.view_as_complex.default(view_1351); view_1351 = None + mul_289 = torch.ops.aten.mul.Tensor(view_as_complex_47, _conj); view_as_complex_47 = None + view_as_real_46 = torch.ops.aten.view_as_real.default(mul_288); mul_288 = None + view_1352 = torch.ops.aten.view.default(view_as_real_46, [2, 8192, 1, 128]); view_as_real_46 = None + convert_element_type_955 = torch.ops.prims.convert_element_type.default(view_1352, torch.bfloat16); view_1352 = None + view_as_real_47 = torch.ops.aten.view_as_real.default(mul_289); mul_289 = None + view_1353 = torch.ops.aten.view.default(view_as_real_47, [2, 8192, 4, 128]); view_as_real_47 = None + convert_element_type_956 = torch.ops.prims.convert_element_type.default(view_1353, torch.bfloat16); view_1353 = None + view_1354 = torch.ops.aten.view.default(squeeze_14, [2, 8192, 128]); squeeze_14 = None + view_1355 = torch.ops.aten.view.default(convert_element_type_955, [2, 8192, 128]); convert_element_type_955 = None + view_1356 = 
torch.ops.aten.view.default(convert_element_type_956, [2, 8192, 512]); convert_element_type_956 = None + view_1357 = torch.ops.aten.view.default(view_1354, [16384, 128]); view_1354 = None + permute_425 = torch.ops.aten.permute.default(view_1357, [1, 0]) + mm_221 = torch.ops.aten.mm.default(permute_425, view_591); permute_425 = None + convert_element_type_274 = torch.ops.prims.convert_element_type.default(primals_79, torch.bfloat16); primals_79 = None + all_gather_into_tensor_93 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_274, 8, '0'); convert_element_type_274 = None + wait_tensor_110 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_93); all_gather_into_tensor_93 = None + permute_90 = torch.ops.aten.permute.default(wait_tensor_110, [1, 0]); wait_tensor_110 = None + permute_427 = torch.ops.aten.permute.default(permute_90, [1, 0]); permute_90 = None + mm_222 = torch.ops.aten.mm.default(view_1357, permute_427); view_1357 = permute_427 = None + view_1358 = torch.ops.aten.view.default(mm_222, [2, 8192, 4096]); mm_222 = None + convert_element_type_961 = torch.ops.prims.convert_element_type.default(mm_221, torch.float32); mm_221 = None + reduce_scatter_tensor_119 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_961, 'avg', 8, '0'); convert_element_type_961 = None + wait_tensor_331 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_119); reduce_scatter_tensor_119 = None + view_1359 = torch.ops.aten.view.default(view_1355, [16384, 128]); view_1355 = None + permute_429 = torch.ops.aten.permute.default(view_1359, [1, 0]) + mm_223 = torch.ops.aten.mm.default(permute_429, view_591); permute_429 = None + permute_431 = torch.ops.aten.permute.default(permute_89, [1, 0]); permute_89 = None + mm_224 = torch.ops.aten.mm.default(view_1359, permute_431); view_1359 = permute_431 = None + view_1360 = torch.ops.aten.view.default(mm_224, [2, 8192, 4096]); mm_224 = None + add_118 = 
torch.ops.aten.add.Tensor(view_1358, view_1360); view_1358 = view_1360 = None + convert_element_type_966 = torch.ops.prims.convert_element_type.default(mm_223, torch.float32); mm_223 = None + reduce_scatter_tensor_120 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_966, 'avg', 8, '0'); convert_element_type_966 = None + wait_tensor_332 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_120); reduce_scatter_tensor_120 = None + view_1361 = torch.ops.aten.view.default(view_1356, [16384, 512]); view_1356 = None + permute_433 = torch.ops.aten.permute.default(view_1361, [1, 0]) + mm_225 = torch.ops.aten.mm.default(permute_433, view_591); permute_433 = view_591 = None + convert_element_type_268 = torch.ops.prims.convert_element_type.default(primals_77, torch.bfloat16); primals_77 = None + all_gather_into_tensor_91 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_268, 8, '0'); convert_element_type_268 = None + wait_tensor_108 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_91); all_gather_into_tensor_91 = None + permute_88 = torch.ops.aten.permute.default(wait_tensor_108, [1, 0]); wait_tensor_108 = None + permute_435 = torch.ops.aten.permute.default(permute_88, [1, 0]); permute_88 = None + mm_226 = torch.ops.aten.mm.default(view_1361, permute_435); view_1361 = permute_435 = None + view_1362 = torch.ops.aten.view.default(mm_226, [2, 8192, 4096]); mm_226 = None + add_119 = torch.ops.aten.add.Tensor(add_118, view_1362); add_118 = view_1362 = None + convert_element_type_971 = torch.ops.prims.convert_element_type.default(mm_225, torch.float32); mm_225 = None + reduce_scatter_tensor_121 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_971, 'avg', 8, '0'); convert_element_type_971 = None + wait_tensor_333 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_121); reduce_scatter_tensor_121 = None + split_106 = 
torch.ops.aten.split.Tensor(add_119, 1024, 1); add_119 = None + getitem_1016 = split_106[0] + getitem_1017 = split_106[1] + getitem_1018 = split_106[2] + getitem_1019 = split_106[3] + getitem_1020 = split_106[4] + getitem_1021 = split_106[5] + getitem_1022 = split_106[6] + getitem_1023 = split_106[7]; split_106 = None + cat_98 = torch.ops.aten.cat.default([getitem_1016, getitem_1017, getitem_1018, getitem_1019, getitem_1020, getitem_1021, getitem_1022, getitem_1023]); getitem_1016 = getitem_1017 = getitem_1018 = getitem_1019 = getitem_1020 = getitem_1021 = getitem_1022 = getitem_1023 = None + reduce_scatter_tensor_122 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_98, 'sum', 8, '1'); cat_98 = None + wait_tensor_334 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_122); reduce_scatter_tensor_122 = None + convert_element_type_972 = torch.ops.prims.convert_element_type.default(wait_tensor_334, torch.float32); wait_tensor_334 = None + convert_element_type_974 = torch.ops.prims.convert_element_type.default(wait_tensor_106, torch.float32); wait_tensor_106 = None + mul_290 = torch.ops.aten.mul.Tensor(convert_element_type_972, convert_element_type_974); convert_element_type_974 = None + mul_292 = torch.ops.aten.mul.Tensor(mul_64, mul_290) + sum_49 = torch.ops.aten.sum.dim_IntList(mul_292, [2], True); mul_292 = None + div_16 = torch.ops.aten.div.Tensor(mul_64, 4096) + mul_293 = torch.ops.aten.mul.Tensor(div_16, sum_49); div_16 = sum_49 = None + sub_25 = torch.ops.aten.sub.Tensor(mul_290, mul_293); mul_290 = mul_293 = None + mul_294 = torch.ops.aten.mul.Tensor(sub_25, rsqrt_16); sub_25 = rsqrt_16 = None + mul_295 = torch.ops.aten.mul.Tensor(convert_element_type_972, mul_64); convert_element_type_972 = mul_64 = None + sum_50 = torch.ops.aten.sum.dim_IntList(mul_295, [0, 1]); mul_295 = None + convert_element_type_975 = torch.ops.prims.convert_element_type.default(mul_294, torch.bfloat16); mul_294 = None + convert_element_type_976 = 
torch.ops.prims.convert_element_type.default(sum_50, torch.bfloat16); sum_50 = None + all_reduce_16 = torch.ops._c10d_functional.all_reduce.default(convert_element_type_976, 'sum', '1'); convert_element_type_976 = None + wait_tensor_335 = torch.ops._c10d_functional.wait_tensor.default(all_reduce_16); all_reduce_16 = None + convert_element_type_977 = torch.ops.prims.convert_element_type.default(wait_tensor_335, torch.float32); wait_tensor_335 = None + reduce_scatter_tensor_123 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_977, 'avg', 8, '0'); convert_element_type_977 = None + wait_tensor_336 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_123); reduce_scatter_tensor_123 = None + add_120 = torch.ops.aten.add.Tensor(add_117, convert_element_type_975); add_117 = convert_element_type_975 = None + all_gather_into_tensor_196 = torch.ops._c10d_functional.all_gather_into_tensor.default(add_120, 8, '1') + wait_tensor_337 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_196); all_gather_into_tensor_196 = None + split_107 = torch.ops.aten.split.Tensor(wait_tensor_337, 2); wait_tensor_337 = None + getitem_1024 = split_107[0] + getitem_1025 = split_107[1] + getitem_1026 = split_107[2] + getitem_1027 = split_107[3] + getitem_1028 = split_107[4] + getitem_1029 = split_107[5] + getitem_1030 = split_107[6] + getitem_1031 = split_107[7]; split_107 = None + cat_99 = torch.ops.aten.cat.default([getitem_1024, getitem_1025, getitem_1026, getitem_1027, getitem_1028, getitem_1029, getitem_1030, getitem_1031], 1); getitem_1024 = getitem_1025 = getitem_1026 = getitem_1027 = getitem_1028 = getitem_1029 = getitem_1030 = getitem_1031 = None + view_1363 = torch.ops.aten.view.default(cat_99, [16384, 4096]); cat_99 = None + permute_437 = torch.ops.aten.permute.default(view_1363, [1, 0]) + wait_tensor_99 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_15); reduce_scatter_tensor_15 = None + add_29 
= torch.ops.aten.add.Tensor(add_27, wait_tensor_99); wait_tensor_99 = None + convert_element_type_251 = torch.ops.prims.convert_element_type.default(primals_72, torch.bfloat16); primals_72 = None + all_gather_into_tensor_84 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_251, 8, '0'); convert_element_type_251 = None + wait_tensor_100 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_84); all_gather_into_tensor_84 = None + convert_element_type_252 = torch.ops.prims.convert_element_type.default(add_29, torch.float32); add_29 = None + pow_16 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_252, 2) + mean_15 = torch.ops.aten.mean.dim(pow_16, [2], True); pow_16 = None + add_30 = torch.ops.aten.add.Scalar(mean_15, 1e-05); mean_15 = None + rsqrt_15 = torch.ops.aten.rsqrt.default(add_30); add_30 = None + mul_60 = torch.ops.aten.mul.Tensor(convert_element_type_252, rsqrt_15); convert_element_type_252 = None + mul_61 = torch.ops.aten.mul.Tensor(mul_60, wait_tensor_100) + convert_element_type_253 = torch.ops.prims.convert_element_type.default(mul_61, torch.bfloat16); mul_61 = None + all_gather_into_tensor_85 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_253, 8, '1'); convert_element_type_253 = None + wait_tensor_101 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_85); all_gather_into_tensor_85 = None + split_39 = torch.ops.aten.split.Tensor(wait_tensor_101, 2); wait_tensor_101 = None + getitem_384 = split_39[0] + getitem_385 = split_39[1] + getitem_386 = split_39[2] + getitem_387 = split_39[3] + getitem_388 = split_39[4] + getitem_389 = split_39[5] + getitem_390 = split_39[6] + getitem_391 = split_39[7]; split_39 = None + cat_31 = torch.ops.aten.cat.default([getitem_384, getitem_385, getitem_386, getitem_387, getitem_388, getitem_389, getitem_390, getitem_391], 1); getitem_384 = getitem_385 = getitem_386 = getitem_387 = getitem_388 = getitem_389 = 
getitem_390 = getitem_391 = None + view_564 = torch.ops.aten.view.default(cat_31, [16384, 4096]); cat_31 = None + view_565 = torch.ops.aten.view.default(mm_53, [2, 8192, 1792]); mm_53 = None + convert_element_type_257 = torch.ops.prims.convert_element_type.default(view_565, torch.float32); view_565 = None + sigmoid_7 = torch.ops.aten.sigmoid.default(convert_element_type_257) + mul_62 = torch.ops.aten.mul.Tensor(convert_element_type_257, sigmoid_7); sigmoid_7 = None + convert_element_type_258 = torch.ops.prims.convert_element_type.default(mul_62, torch.bfloat16); mul_62 = None + convert_element_type_259 = torch.ops.prims.convert_element_type.default(primals_74, torch.bfloat16); primals_74 = None + all_gather_into_tensor_87 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_259, 8, '0'); convert_element_type_259 = None + wait_tensor_103 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_87); all_gather_into_tensor_87 = None + permute_86 = torch.ops.aten.permute.default(wait_tensor_103, [1, 0]); wait_tensor_103 = None + mm_54 = torch.ops.aten.mm.default(view_564, permute_86) + view_572 = torch.ops.aten.view.default(mm_54, [2, 8192, 1792]); mm_54 = None + mul_63 = torch.ops.aten.mul.Tensor(convert_element_type_258, view_572) + view_579 = torch.ops.aten.view.default(mul_63, [16384, 1792]); mul_63 = None + mm_227 = torch.ops.aten.mm.default(permute_437, view_579); permute_437 = view_579 = None + convert_element_type_262 = torch.ops.prims.convert_element_type.default(primals_75, torch.bfloat16); primals_75 = None + all_gather_into_tensor_88 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_262, 8, '0'); convert_element_type_262 = None + wait_tensor_104 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_88); all_gather_into_tensor_88 = None + permute_87 = torch.ops.aten.permute.default(wait_tensor_104, [1, 0]); wait_tensor_104 = None + permute_439 = 
torch.ops.aten.permute.default(permute_87, [1, 0]); permute_87 = None + mm_228 = torch.ops.aten.mm.default(view_1363, permute_439); view_1363 = permute_439 = None + view_1364 = torch.ops.aten.view.default(mm_228, [2, 8192, 1792]); mm_228 = None + convert_element_type_982 = torch.ops.prims.convert_element_type.default(mm_227, torch.float32); mm_227 = None + reduce_scatter_tensor_124 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_982, 'avg', 8, '0'); convert_element_type_982 = None + wait_tensor_338 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_124); reduce_scatter_tensor_124 = None + mul_296 = torch.ops.aten.mul.Tensor(view_1364, convert_element_type_258); convert_element_type_258 = None + mul_297 = torch.ops.aten.mul.Tensor(view_1364, view_572); view_1364 = view_572 = None + view_1365 = torch.ops.aten.view.default(mul_296, [16384, 1792]); mul_296 = None + permute_441 = torch.ops.aten.permute.default(view_1365, [1, 0]) + mm_229 = torch.ops.aten.mm.default(permute_441, view_564); permute_441 = None + permute_443 = torch.ops.aten.permute.default(permute_86, [1, 0]); permute_86 = None + mm_230 = torch.ops.aten.mm.default(view_1365, permute_443); view_1365 = permute_443 = None + view_1366 = torch.ops.aten.view.default(mm_230, [2, 8192, 4096]); mm_230 = None + convert_element_type_987 = torch.ops.prims.convert_element_type.default(mm_229, torch.float32); mm_229 = None + reduce_scatter_tensor_125 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_987, 'avg', 8, '0'); convert_element_type_987 = None + wait_tensor_339 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_125); reduce_scatter_tensor_125 = None + convert_element_type_988 = torch.ops.prims.convert_element_type.default(mul_297, torch.float32); mul_297 = None + neg_8 = torch.ops.aten.neg.default(convert_element_type_257) + exp_8 = torch.ops.aten.exp.default(neg_8); neg_8 = None + add_121 = 
torch.ops.aten.add.Tensor(exp_8, 1); exp_8 = None + reciprocal_8 = torch.ops.aten.reciprocal.default(add_121); add_121 = None + mul_298 = torch.ops.aten.mul.Tensor(reciprocal_8, 1); reciprocal_8 = None + mul_299 = torch.ops.aten.mul.Tensor(convert_element_type_988, mul_298); convert_element_type_988 = None + sub_26 = torch.ops.aten.sub.Tensor(1, mul_298); mul_298 = None + mul_300 = torch.ops.aten.mul.Tensor(convert_element_type_257, sub_26); convert_element_type_257 = sub_26 = None + add_122 = torch.ops.aten.add.Tensor(mul_300, 1); mul_300 = None + mul_301 = torch.ops.aten.mul.Tensor(mul_299, add_122); mul_299 = add_122 = None + convert_element_type_990 = torch.ops.prims.convert_element_type.default(mul_301, torch.bfloat16); mul_301 = None + view_1367 = torch.ops.aten.view.default(convert_element_type_990, [16384, 1792]); convert_element_type_990 = None + permute_445 = torch.ops.aten.permute.default(view_1367, [1, 0]) + mm_231 = torch.ops.aten.mm.default(permute_445, view_564); permute_445 = view_564 = None + convert_element_type_254 = torch.ops.prims.convert_element_type.default(primals_73, torch.bfloat16); primals_73 = None + all_gather_into_tensor_86 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_254, 8, '0'); convert_element_type_254 = None + wait_tensor_102 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_86); all_gather_into_tensor_86 = None + permute_85 = torch.ops.aten.permute.default(wait_tensor_102, [1, 0]); wait_tensor_102 = None + permute_447 = torch.ops.aten.permute.default(permute_85, [1, 0]); permute_85 = None + mm_232 = torch.ops.aten.mm.default(view_1367, permute_447); view_1367 = permute_447 = None + view_1368 = torch.ops.aten.view.default(mm_232, [2, 8192, 4096]); mm_232 = None + add_123 = torch.ops.aten.add.Tensor(view_1366, view_1368); view_1366 = view_1368 = None + convert_element_type_995 = torch.ops.prims.convert_element_type.default(mm_231, torch.float32); mm_231 = None + 
reduce_scatter_tensor_126 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_995, 'avg', 8, '0'); convert_element_type_995 = None + wait_tensor_340 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_126); reduce_scatter_tensor_126 = None + split_108 = torch.ops.aten.split.Tensor(add_123, 1024, 1); add_123 = None + getitem_1032 = split_108[0] + getitem_1033 = split_108[1] + getitem_1034 = split_108[2] + getitem_1035 = split_108[3] + getitem_1036 = split_108[4] + getitem_1037 = split_108[5] + getitem_1038 = split_108[6] + getitem_1039 = split_108[7]; split_108 = None + cat_100 = torch.ops.aten.cat.default([getitem_1032, getitem_1033, getitem_1034, getitem_1035, getitem_1036, getitem_1037, getitem_1038, getitem_1039]); getitem_1032 = getitem_1033 = getitem_1034 = getitem_1035 = getitem_1036 = getitem_1037 = getitem_1038 = getitem_1039 = None + reduce_scatter_tensor_127 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_100, 'sum', 8, '1'); cat_100 = None + wait_tensor_341 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_127); reduce_scatter_tensor_127 = None + convert_element_type_996 = torch.ops.prims.convert_element_type.default(wait_tensor_341, torch.float32); wait_tensor_341 = None + convert_element_type_998 = torch.ops.prims.convert_element_type.default(wait_tensor_100, torch.float32); wait_tensor_100 = None + mul_302 = torch.ops.aten.mul.Tensor(convert_element_type_996, convert_element_type_998); convert_element_type_998 = None + mul_304 = torch.ops.aten.mul.Tensor(mul_60, mul_302) + sum_51 = torch.ops.aten.sum.dim_IntList(mul_304, [2], True); mul_304 = None + div_17 = torch.ops.aten.div.Tensor(mul_60, 4096) + mul_305 = torch.ops.aten.mul.Tensor(div_17, sum_51); div_17 = sum_51 = None + sub_27 = torch.ops.aten.sub.Tensor(mul_302, mul_305); mul_302 = mul_305 = None + mul_306 = torch.ops.aten.mul.Tensor(sub_27, rsqrt_15); sub_27 = rsqrt_15 = None + mul_307 = 
torch.ops.aten.mul.Tensor(convert_element_type_996, mul_60); convert_element_type_996 = mul_60 = None + sum_52 = torch.ops.aten.sum.dim_IntList(mul_307, [0, 1]); mul_307 = None + convert_element_type_999 = torch.ops.prims.convert_element_type.default(mul_306, torch.bfloat16); mul_306 = None + convert_element_type_1000 = torch.ops.prims.convert_element_type.default(sum_52, torch.bfloat16); sum_52 = None + all_reduce_17 = torch.ops._c10d_functional.all_reduce.default(convert_element_type_1000, 'sum', '1'); convert_element_type_1000 = None + wait_tensor_342 = torch.ops._c10d_functional.wait_tensor.default(all_reduce_17); all_reduce_17 = None + convert_element_type_1001 = torch.ops.prims.convert_element_type.default(wait_tensor_342, torch.float32); wait_tensor_342 = None + reduce_scatter_tensor_128 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1001, 'avg', 8, '0'); convert_element_type_1001 = None + wait_tensor_343 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_128); reduce_scatter_tensor_128 = None + add_124 = torch.ops.aten.add.Tensor(add_120, convert_element_type_999); add_120 = convert_element_type_999 = None + all_gather_into_tensor_197 = torch.ops._c10d_functional.all_gather_into_tensor.default(add_124, 8, '1') + wait_tensor_344 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_197); all_gather_into_tensor_197 = None + split_109 = torch.ops.aten.split.Tensor(wait_tensor_344, 2); wait_tensor_344 = None + getitem_1040 = split_109[0] + getitem_1041 = split_109[1] + getitem_1042 = split_109[2] + getitem_1043 = split_109[3] + getitem_1044 = split_109[4] + getitem_1045 = split_109[5] + getitem_1046 = split_109[6] + getitem_1047 = split_109[7]; split_109 = None + cat_101 = torch.ops.aten.cat.default([getitem_1040, getitem_1041, getitem_1042, getitem_1043, getitem_1044, getitem_1045, getitem_1046, getitem_1047], 1); getitem_1040 = getitem_1041 = getitem_1042 = getitem_1043 = getitem_1044 = 
getitem_1045 = getitem_1046 = getitem_1047 = None + view_1369 = torch.ops.aten.view.default(cat_101, [16384, 4096]); cat_101 = None + permute_449 = torch.ops.aten.permute.default(view_1369, [1, 0]) + permute_83 = torch.ops.aten.permute.default(getitem_367, [0, 2, 1, 3]) + view_546 = torch.ops.aten.view.default(permute_83, [2, 8192, -1]); permute_83 = None + view_552 = torch.ops.aten.view.default(view_546, [16384, 512]); view_546 = None + mm_233 = torch.ops.aten.mm.default(permute_449, view_552); permute_449 = view_552 = None + convert_element_type_248 = torch.ops.prims.convert_element_type.default(primals_71, torch.bfloat16); primals_71 = None + all_gather_into_tensor_83 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_248, 8, '0'); convert_element_type_248 = None + wait_tensor_98 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_83); all_gather_into_tensor_83 = None + permute_84 = torch.ops.aten.permute.default(wait_tensor_98, [1, 0]); wait_tensor_98 = None + permute_451 = torch.ops.aten.permute.default(permute_84, [1, 0]); permute_84 = None + mm_234 = torch.ops.aten.mm.default(view_1369, permute_451); view_1369 = permute_451 = None + view_1370 = torch.ops.aten.view.default(mm_234, [2, 8192, 512]); mm_234 = None + convert_element_type_1006 = torch.ops.prims.convert_element_type.default(mm_233, torch.float32); mm_233 = None + reduce_scatter_tensor_129 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1006, 'avg', 8, '0'); convert_element_type_1006 = None + wait_tensor_345 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_129); reduce_scatter_tensor_129 = None + view_1371 = torch.ops.aten.view.default(view_1370, [2, 8192, 4, 128]); view_1370 = None + permute_453 = torch.ops.aten.permute.default(view_1371, [0, 2, 1, 3]); view_1371 = None + convert_element_type_232 = torch.ops.prims.convert_element_type.default(primals_67, torch.bfloat16); primals_67 = None + 
all_gather_into_tensor_78 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_232, 8, '0'); convert_element_type_232 = None + wait_tensor_93 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_78); all_gather_into_tensor_78 = None + convert_element_type_233 = torch.ops.prims.convert_element_type.default(add_27, torch.float32); add_27 = None + pow_15 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_233, 2) + mean_14 = torch.ops.aten.mean.dim(pow_15, [2], True); pow_15 = None + add_28 = torch.ops.aten.add.Scalar(mean_14, 1e-05); mean_14 = None + rsqrt_14 = torch.ops.aten.rsqrt.default(add_28); add_28 = None + mul_56 = torch.ops.aten.mul.Tensor(convert_element_type_233, rsqrt_14); convert_element_type_233 = None + mul_57 = torch.ops.aten.mul.Tensor(mul_56, wait_tensor_93) + convert_element_type_234 = torch.ops.prims.convert_element_type.default(mul_57, torch.bfloat16); mul_57 = None + all_gather_into_tensor_79 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_234, 8, '1'); convert_element_type_234 = None + wait_tensor_94 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_79); all_gather_into_tensor_79 = None + split_37 = torch.ops.aten.split.Tensor(wait_tensor_94, 2); wait_tensor_94 = None + getitem_359 = split_37[0] + getitem_360 = split_37[1] + getitem_361 = split_37[2] + getitem_362 = split_37[3] + getitem_363 = split_37[4] + getitem_364 = split_37[5] + getitem_365 = split_37[6] + getitem_366 = split_37[7]; split_37 = None + cat_29 = torch.ops.aten.cat.default([getitem_359, getitem_360, getitem_361, getitem_362, getitem_363, getitem_364, getitem_365, getitem_366], 1); getitem_359 = getitem_360 = getitem_361 = getitem_362 = getitem_363 = getitem_364 = getitem_365 = getitem_366 = None + view_519 = torch.ops.aten.view.default(cat_29, [16384, 4096]); cat_29 = None + view_520 = torch.ops.aten.view.default(mm_49, [2, 8192, 512]); mm_49 = None + 
convert_element_type_238 = torch.ops.prims.convert_element_type.default(primals_69, torch.bfloat16); primals_69 = None + all_gather_into_tensor_81 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_238, 8, '0'); convert_element_type_238 = None + wait_tensor_96 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_81); all_gather_into_tensor_81 = None + permute_78 = torch.ops.aten.permute.default(wait_tensor_96, [1, 0]); wait_tensor_96 = None + mm_50 = torch.ops.aten.mm.default(view_519, permute_78) + view_527 = torch.ops.aten.view.default(mm_50, [2, 8192, 128]); mm_50 = None + view_534 = torch.ops.aten.view.default(mm_51, [2, 8192, 128]); mm_51 = None + view_536 = torch.ops.aten.view.default(view_520, [2, 8192, -1, 128]); view_520 = None + view_537 = torch.ops.aten.view.default(view_527, [2, 8192, -1, 128]); view_527 = None + view_538 = torch.ops.aten.view.default(view_534, [2, 8192, -1, 128]); view_534 = None + convert_element_type_244 = torch.ops.prims.convert_element_type.default(view_536, torch.float32); view_536 = None + view_539 = torch.ops.aten.view.default(convert_element_type_244, [2, 8192, 4, -1, 2]); convert_element_type_244 = None + view_as_complex_14 = torch.ops.aten.view_as_complex.default(view_539); view_539 = None + convert_element_type_245 = torch.ops.prims.convert_element_type.default(view_537, torch.float32); view_537 = None + view_540 = torch.ops.aten.view.default(convert_element_type_245, [2, 8192, 1, -1, 2]); convert_element_type_245 = None + view_as_complex_15 = torch.ops.aten.view_as_complex.default(view_540); view_540 = None + mul_58 = torch.ops.aten.mul.Tensor(view_as_complex_14, view_37); view_as_complex_14 = None + view_as_real_14 = torch.ops.aten.view_as_real.default(mul_58); mul_58 = None + view_542 = torch.ops.aten.view.default(view_as_real_14, [2, 8192, 4, 128]); view_as_real_14 = None + mul_59 = torch.ops.aten.mul.Tensor(view_as_complex_15, view_37); view_as_complex_15 = None + 
view_as_real_15 = torch.ops.aten.view_as_real.default(mul_59); mul_59 = None + view_543 = torch.ops.aten.view.default(view_as_real_15, [2, 8192, 1, 128]); view_as_real_15 = None + convert_element_type_246 = torch.ops.prims.convert_element_type.default(view_542, torch.bfloat16); view_542 = None + convert_element_type_247 = torch.ops.prims.convert_element_type.default(view_543, torch.bfloat16); view_543 = None + unsqueeze_14 = torch.ops.aten.unsqueeze.default(convert_element_type_247, 3); convert_element_type_247 = None + expand_14 = torch.ops.aten.expand.default(unsqueeze_14, [2, 8192, 1, 4, 128]); unsqueeze_14 = None + view_544 = torch.ops.aten.view.default(expand_14, [2, 8192, 4, 128]); expand_14 = None + unsqueeze_15 = torch.ops.aten.unsqueeze.default(view_538, 3); view_538 = None + expand_15 = torch.ops.aten.expand.default(unsqueeze_15, [2, 8192, 1, 4, 128]); unsqueeze_15 = None + view_545 = torch.ops.aten.view.default(expand_15, [2, 8192, 4, 128]); expand_15 = None + permute_80 = torch.ops.aten.permute.default(convert_element_type_246, [0, 2, 1, 3]); convert_element_type_246 = None + permute_81 = torch.ops.aten.permute.default(view_544, [0, 2, 1, 3]); view_544 = None + permute_82 = torch.ops.aten.permute.default(view_545, [0, 2, 1, 3]); view_545 = None + _scaled_dot_product_cudnn_attention_backward_8 = torch.ops.aten._scaled_dot_product_cudnn_attention_backward.default(permute_453, permute_80, permute_81, permute_82, getitem_367, getitem_368, getitem_373, getitem_374, None, None, None, 8192, 8192, 0.0, True); permute_453 = permute_80 = permute_81 = permute_82 = getitem_367 = getitem_368 = getitem_373 = getitem_374 = None + getitem_1048 = _scaled_dot_product_cudnn_attention_backward_8[0] + getitem_1049 = _scaled_dot_product_cudnn_attention_backward_8[1] + getitem_1050 = _scaled_dot_product_cudnn_attention_backward_8[2]; _scaled_dot_product_cudnn_attention_backward_8 = None + permute_454 = torch.ops.aten.permute.default(getitem_1050, [0, 2, 1, 3]); getitem_1050 = 
None + permute_455 = torch.ops.aten.permute.default(getitem_1049, [0, 2, 1, 3]); getitem_1049 = None + permute_456 = torch.ops.aten.permute.default(getitem_1048, [0, 2, 1, 3]); getitem_1048 = None + view_1372 = torch.ops.aten.view.default(permute_454, [2, 8192, 1, 4, 128]); permute_454 = None + sum_53 = torch.ops.aten.sum.dim_IntList(view_1372, [3], True); view_1372 = None + squeeze_16 = torch.ops.aten.squeeze.dim(sum_53, 3); sum_53 = None + view_1373 = torch.ops.aten.view.default(permute_455, [2, 8192, 1, 4, 128]); permute_455 = None + sum_54 = torch.ops.aten.sum.dim_IntList(view_1373, [3], True); view_1373 = None + squeeze_17 = torch.ops.aten.squeeze.dim(sum_54, 3); sum_54 = None + convert_element_type_1007 = torch.ops.prims.convert_element_type.default(squeeze_17, torch.float32); squeeze_17 = None + convert_element_type_1008 = torch.ops.prims.convert_element_type.default(permute_456, torch.float32); permute_456 = None + view_1374 = torch.ops.aten.view.default(convert_element_type_1007, [2, 8192, 1, 64, 2]); convert_element_type_1007 = None + view_as_complex_48 = torch.ops.aten.view_as_complex.default(view_1374); view_1374 = None + mul_308 = torch.ops.aten.mul.Tensor(view_as_complex_48, _conj); view_as_complex_48 = None + view_1375 = torch.ops.aten.view.default(convert_element_type_1008, [2, 8192, 4, 64, 2]); convert_element_type_1008 = None + view_as_complex_49 = torch.ops.aten.view_as_complex.default(view_1375); view_1375 = None + mul_309 = torch.ops.aten.mul.Tensor(view_as_complex_49, _conj); view_as_complex_49 = None + view_as_real_48 = torch.ops.aten.view_as_real.default(mul_308); mul_308 = None + view_1376 = torch.ops.aten.view.default(view_as_real_48, [2, 8192, 1, 128]); view_as_real_48 = None + convert_element_type_1009 = torch.ops.prims.convert_element_type.default(view_1376, torch.bfloat16); view_1376 = None + view_as_real_49 = torch.ops.aten.view_as_real.default(mul_309); mul_309 = None + view_1377 = torch.ops.aten.view.default(view_as_real_49, [2, 
8192, 4, 128]); view_as_real_49 = None + convert_element_type_1010 = torch.ops.prims.convert_element_type.default(view_1377, torch.bfloat16); view_1377 = None + view_1378 = torch.ops.aten.view.default(squeeze_16, [2, 8192, 128]); squeeze_16 = None + view_1379 = torch.ops.aten.view.default(convert_element_type_1009, [2, 8192, 128]); convert_element_type_1009 = None + view_1380 = torch.ops.aten.view.default(convert_element_type_1010, [2, 8192, 512]); convert_element_type_1010 = None + view_1381 = torch.ops.aten.view.default(view_1378, [16384, 128]); view_1378 = None + permute_457 = torch.ops.aten.permute.default(view_1381, [1, 0]) + mm_235 = torch.ops.aten.mm.default(permute_457, view_519); permute_457 = None + convert_element_type_241 = torch.ops.prims.convert_element_type.default(primals_70, torch.bfloat16); primals_70 = None + all_gather_into_tensor_82 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_241, 8, '0'); convert_element_type_241 = None + wait_tensor_97 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_82); all_gather_into_tensor_82 = None + permute_79 = torch.ops.aten.permute.default(wait_tensor_97, [1, 0]); wait_tensor_97 = None + permute_459 = torch.ops.aten.permute.default(permute_79, [1, 0]); permute_79 = None + mm_236 = torch.ops.aten.mm.default(view_1381, permute_459); view_1381 = permute_459 = None + view_1382 = torch.ops.aten.view.default(mm_236, [2, 8192, 4096]); mm_236 = None + convert_element_type_1015 = torch.ops.prims.convert_element_type.default(mm_235, torch.float32); mm_235 = None + reduce_scatter_tensor_130 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1015, 'avg', 8, '0'); convert_element_type_1015 = None + wait_tensor_346 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_130); reduce_scatter_tensor_130 = None + view_1383 = torch.ops.aten.view.default(view_1379, [16384, 128]); view_1379 = None + permute_461 = 
torch.ops.aten.permute.default(view_1383, [1, 0]) + mm_237 = torch.ops.aten.mm.default(permute_461, view_519); permute_461 = None + permute_463 = torch.ops.aten.permute.default(permute_78, [1, 0]); permute_78 = None + mm_238 = torch.ops.aten.mm.default(view_1383, permute_463); view_1383 = permute_463 = None + view_1384 = torch.ops.aten.view.default(mm_238, [2, 8192, 4096]); mm_238 = None + add_125 = torch.ops.aten.add.Tensor(view_1382, view_1384); view_1382 = view_1384 = None + convert_element_type_1020 = torch.ops.prims.convert_element_type.default(mm_237, torch.float32); mm_237 = None + reduce_scatter_tensor_131 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1020, 'avg', 8, '0'); convert_element_type_1020 = None + wait_tensor_347 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_131); reduce_scatter_tensor_131 = None + view_1385 = torch.ops.aten.view.default(view_1380, [16384, 512]); view_1380 = None + permute_465 = torch.ops.aten.permute.default(view_1385, [1, 0]) + mm_239 = torch.ops.aten.mm.default(permute_465, view_519); permute_465 = view_519 = None + convert_element_type_235 = torch.ops.prims.convert_element_type.default(primals_68, torch.bfloat16); primals_68 = None + all_gather_into_tensor_80 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_235, 8, '0'); convert_element_type_235 = None + wait_tensor_95 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_80); all_gather_into_tensor_80 = None + permute_77 = torch.ops.aten.permute.default(wait_tensor_95, [1, 0]); wait_tensor_95 = None + permute_467 = torch.ops.aten.permute.default(permute_77, [1, 0]); permute_77 = None + mm_240 = torch.ops.aten.mm.default(view_1385, permute_467); view_1385 = permute_467 = None + view_1386 = torch.ops.aten.view.default(mm_240, [2, 8192, 4096]); mm_240 = None + add_126 = torch.ops.aten.add.Tensor(add_125, view_1386); add_125 = view_1386 = None + convert_element_type_1025 = 
torch.ops.prims.convert_element_type.default(mm_239, torch.float32); mm_239 = None + reduce_scatter_tensor_132 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1025, 'avg', 8, '0'); convert_element_type_1025 = None + wait_tensor_348 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_132); reduce_scatter_tensor_132 = None + split_110 = torch.ops.aten.split.Tensor(add_126, 1024, 1); add_126 = None + getitem_1051 = split_110[0] + getitem_1052 = split_110[1] + getitem_1053 = split_110[2] + getitem_1054 = split_110[3] + getitem_1055 = split_110[4] + getitem_1056 = split_110[5] + getitem_1057 = split_110[6] + getitem_1058 = split_110[7]; split_110 = None + cat_102 = torch.ops.aten.cat.default([getitem_1051, getitem_1052, getitem_1053, getitem_1054, getitem_1055, getitem_1056, getitem_1057, getitem_1058]); getitem_1051 = getitem_1052 = getitem_1053 = getitem_1054 = getitem_1055 = getitem_1056 = getitem_1057 = getitem_1058 = None + reduce_scatter_tensor_133 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_102, 'sum', 8, '1'); cat_102 = None + wait_tensor_349 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_133); reduce_scatter_tensor_133 = None + convert_element_type_1026 = torch.ops.prims.convert_element_type.default(wait_tensor_349, torch.float32); wait_tensor_349 = None + convert_element_type_1028 = torch.ops.prims.convert_element_type.default(wait_tensor_93, torch.float32); wait_tensor_93 = None + mul_310 = torch.ops.aten.mul.Tensor(convert_element_type_1026, convert_element_type_1028); convert_element_type_1028 = None + mul_312 = torch.ops.aten.mul.Tensor(mul_56, mul_310) + sum_55 = torch.ops.aten.sum.dim_IntList(mul_312, [2], True); mul_312 = None + div_18 = torch.ops.aten.div.Tensor(mul_56, 4096) + mul_313 = torch.ops.aten.mul.Tensor(div_18, sum_55); div_18 = sum_55 = None + sub_28 = torch.ops.aten.sub.Tensor(mul_310, mul_313); mul_310 = mul_313 = None + mul_314 = 
torch.ops.aten.mul.Tensor(sub_28, rsqrt_14); sub_28 = rsqrt_14 = None + mul_315 = torch.ops.aten.mul.Tensor(convert_element_type_1026, mul_56); convert_element_type_1026 = mul_56 = None + sum_56 = torch.ops.aten.sum.dim_IntList(mul_315, [0, 1]); mul_315 = None + convert_element_type_1029 = torch.ops.prims.convert_element_type.default(mul_314, torch.bfloat16); mul_314 = None + convert_element_type_1030 = torch.ops.prims.convert_element_type.default(sum_56, torch.bfloat16); sum_56 = None + all_reduce_18 = torch.ops._c10d_functional.all_reduce.default(convert_element_type_1030, 'sum', '1'); convert_element_type_1030 = None + wait_tensor_350 = torch.ops._c10d_functional.wait_tensor.default(all_reduce_18); all_reduce_18 = None + convert_element_type_1031 = torch.ops.prims.convert_element_type.default(wait_tensor_350, torch.float32); wait_tensor_350 = None + reduce_scatter_tensor_134 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1031, 'avg', 8, '0'); convert_element_type_1031 = None + wait_tensor_351 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_134); reduce_scatter_tensor_134 = None + add_127 = torch.ops.aten.add.Tensor(add_124, convert_element_type_1029); add_124 = convert_element_type_1029 = None + all_gather_into_tensor_198 = torch.ops._c10d_functional.all_gather_into_tensor.default(add_127, 8, '1') + wait_tensor_352 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_198); all_gather_into_tensor_198 = None + split_111 = torch.ops.aten.split.Tensor(wait_tensor_352, 2); wait_tensor_352 = None + getitem_1059 = split_111[0] + getitem_1060 = split_111[1] + getitem_1061 = split_111[2] + getitem_1062 = split_111[3] + getitem_1063 = split_111[4] + getitem_1064 = split_111[5] + getitem_1065 = split_111[6] + getitem_1066 = split_111[7]; split_111 = None + cat_103 = torch.ops.aten.cat.default([getitem_1059, getitem_1060, getitem_1061, getitem_1062, getitem_1063, getitem_1064, getitem_1065, 
getitem_1066], 1); getitem_1059 = getitem_1060 = getitem_1061 = getitem_1062 = getitem_1063 = getitem_1064 = getitem_1065 = getitem_1066 = None + view_1387 = torch.ops.aten.view.default(cat_103, [16384, 4096]); cat_103 = None + permute_469 = torch.ops.aten.permute.default(view_1387, [1, 0]) + wait_tensor_86 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_13); reduce_scatter_tensor_13 = None + add_25 = torch.ops.aten.add.Tensor(add_23, wait_tensor_86); wait_tensor_86 = None + convert_element_type_218 = torch.ops.prims.convert_element_type.default(primals_63, torch.bfloat16); primals_63 = None + all_gather_into_tensor_73 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_218, 8, '0'); convert_element_type_218 = None + wait_tensor_87 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_73); all_gather_into_tensor_73 = None + convert_element_type_219 = torch.ops.prims.convert_element_type.default(add_25, torch.float32); add_25 = None + pow_14 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_219, 2) + mean_13 = torch.ops.aten.mean.dim(pow_14, [2], True); pow_14 = None + add_26 = torch.ops.aten.add.Scalar(mean_13, 1e-05); mean_13 = None + rsqrt_13 = torch.ops.aten.rsqrt.default(add_26); add_26 = None + mul_52 = torch.ops.aten.mul.Tensor(convert_element_type_219, rsqrt_13); convert_element_type_219 = None + mul_53 = torch.ops.aten.mul.Tensor(mul_52, wait_tensor_87) + convert_element_type_220 = torch.ops.prims.convert_element_type.default(mul_53, torch.bfloat16); mul_53 = None + all_gather_into_tensor_74 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_220, 8, '1'); convert_element_type_220 = None + wait_tensor_88 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_74); all_gather_into_tensor_74 = None + split_35 = torch.ops.aten.split.Tensor(wait_tensor_88, 2); wait_tensor_88 = None + getitem_343 = split_35[0] + getitem_344 = split_35[1] + 
getitem_345 = split_35[2] + getitem_346 = split_35[3] + getitem_347 = split_35[4] + getitem_348 = split_35[5] + getitem_349 = split_35[6] + getitem_350 = split_35[7]; split_35 = None + cat_27 = torch.ops.aten.cat.default([getitem_343, getitem_344, getitem_345, getitem_346, getitem_347, getitem_348, getitem_349, getitem_350], 1); getitem_343 = getitem_344 = getitem_345 = getitem_346 = getitem_347 = getitem_348 = getitem_349 = getitem_350 = None + view_492 = torch.ops.aten.view.default(cat_27, [16384, 4096]); cat_27 = None + view_493 = torch.ops.aten.view.default(mm_46, [2, 8192, 1792]); mm_46 = None + convert_element_type_224 = torch.ops.prims.convert_element_type.default(view_493, torch.float32); view_493 = None + sigmoid_6 = torch.ops.aten.sigmoid.default(convert_element_type_224) + mul_54 = torch.ops.aten.mul.Tensor(convert_element_type_224, sigmoid_6); sigmoid_6 = None + convert_element_type_225 = torch.ops.prims.convert_element_type.default(mul_54, torch.bfloat16); mul_54 = None + convert_element_type_226 = torch.ops.prims.convert_element_type.default(primals_65, torch.bfloat16); primals_65 = None + all_gather_into_tensor_76 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_226, 8, '0'); convert_element_type_226 = None + wait_tensor_90 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_76); all_gather_into_tensor_76 = None + permute_75 = torch.ops.aten.permute.default(wait_tensor_90, [1, 0]); wait_tensor_90 = None + mm_47 = torch.ops.aten.mm.default(view_492, permute_75) + view_500 = torch.ops.aten.view.default(mm_47, [2, 8192, 1792]); mm_47 = None + mul_55 = torch.ops.aten.mul.Tensor(convert_element_type_225, view_500) + view_507 = torch.ops.aten.view.default(mul_55, [16384, 1792]); mul_55 = None + mm_241 = torch.ops.aten.mm.default(permute_469, view_507); permute_469 = view_507 = None + convert_element_type_229 = torch.ops.prims.convert_element_type.default(primals_66, torch.bfloat16); primals_66 = None + 
all_gather_into_tensor_77 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_229, 8, '0'); convert_element_type_229 = None + wait_tensor_91 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_77); all_gather_into_tensor_77 = None + permute_76 = torch.ops.aten.permute.default(wait_tensor_91, [1, 0]); wait_tensor_91 = None + permute_471 = torch.ops.aten.permute.default(permute_76, [1, 0]); permute_76 = None + mm_242 = torch.ops.aten.mm.default(view_1387, permute_471); view_1387 = permute_471 = None + view_1388 = torch.ops.aten.view.default(mm_242, [2, 8192, 1792]); mm_242 = None + convert_element_type_1036 = torch.ops.prims.convert_element_type.default(mm_241, torch.float32); mm_241 = None + reduce_scatter_tensor_135 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1036, 'avg', 8, '0'); convert_element_type_1036 = None + wait_tensor_353 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_135); reduce_scatter_tensor_135 = None + mul_316 = torch.ops.aten.mul.Tensor(view_1388, convert_element_type_225); convert_element_type_225 = None + mul_317 = torch.ops.aten.mul.Tensor(view_1388, view_500); view_1388 = view_500 = None + view_1389 = torch.ops.aten.view.default(mul_316, [16384, 1792]); mul_316 = None + permute_473 = torch.ops.aten.permute.default(view_1389, [1, 0]) + mm_243 = torch.ops.aten.mm.default(permute_473, view_492); permute_473 = None + permute_475 = torch.ops.aten.permute.default(permute_75, [1, 0]); permute_75 = None + mm_244 = torch.ops.aten.mm.default(view_1389, permute_475); view_1389 = permute_475 = None + view_1390 = torch.ops.aten.view.default(mm_244, [2, 8192, 4096]); mm_244 = None + convert_element_type_1041 = torch.ops.prims.convert_element_type.default(mm_243, torch.float32); mm_243 = None + reduce_scatter_tensor_136 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1041, 'avg', 8, '0'); convert_element_type_1041 = None 
+ wait_tensor_354 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_136); reduce_scatter_tensor_136 = None + convert_element_type_1042 = torch.ops.prims.convert_element_type.default(mul_317, torch.float32); mul_317 = None + neg_9 = torch.ops.aten.neg.default(convert_element_type_224) + exp_9 = torch.ops.aten.exp.default(neg_9); neg_9 = None + add_128 = torch.ops.aten.add.Tensor(exp_9, 1); exp_9 = None + reciprocal_9 = torch.ops.aten.reciprocal.default(add_128); add_128 = None + mul_318 = torch.ops.aten.mul.Tensor(reciprocal_9, 1); reciprocal_9 = None + mul_319 = torch.ops.aten.mul.Tensor(convert_element_type_1042, mul_318); convert_element_type_1042 = None + sub_29 = torch.ops.aten.sub.Tensor(1, mul_318); mul_318 = None + mul_320 = torch.ops.aten.mul.Tensor(convert_element_type_224, sub_29); convert_element_type_224 = sub_29 = None + add_129 = torch.ops.aten.add.Tensor(mul_320, 1); mul_320 = None + mul_321 = torch.ops.aten.mul.Tensor(mul_319, add_129); mul_319 = add_129 = None + convert_element_type_1044 = torch.ops.prims.convert_element_type.default(mul_321, torch.bfloat16); mul_321 = None + view_1391 = torch.ops.aten.view.default(convert_element_type_1044, [16384, 1792]); convert_element_type_1044 = None + permute_477 = torch.ops.aten.permute.default(view_1391, [1, 0]) + mm_245 = torch.ops.aten.mm.default(permute_477, view_492); permute_477 = view_492 = None + convert_element_type_221 = torch.ops.prims.convert_element_type.default(primals_64, torch.bfloat16); primals_64 = None + all_gather_into_tensor_75 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_221, 8, '0'); convert_element_type_221 = None + wait_tensor_89 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_75); all_gather_into_tensor_75 = None + permute_74 = torch.ops.aten.permute.default(wait_tensor_89, [1, 0]); wait_tensor_89 = None + permute_479 = torch.ops.aten.permute.default(permute_74, [1, 0]); permute_74 = None + mm_246 = 
torch.ops.aten.mm.default(view_1391, permute_479); view_1391 = permute_479 = None + view_1392 = torch.ops.aten.view.default(mm_246, [2, 8192, 4096]); mm_246 = None + add_130 = torch.ops.aten.add.Tensor(view_1390, view_1392); view_1390 = view_1392 = None + convert_element_type_1049 = torch.ops.prims.convert_element_type.default(mm_245, torch.float32); mm_245 = None + reduce_scatter_tensor_137 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1049, 'avg', 8, '0'); convert_element_type_1049 = None + wait_tensor_355 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_137); reduce_scatter_tensor_137 = None + split_112 = torch.ops.aten.split.Tensor(add_130, 1024, 1); add_130 = None + getitem_1067 = split_112[0] + getitem_1068 = split_112[1] + getitem_1069 = split_112[2] + getitem_1070 = split_112[3] + getitem_1071 = split_112[4] + getitem_1072 = split_112[5] + getitem_1073 = split_112[6] + getitem_1074 = split_112[7]; split_112 = None + cat_104 = torch.ops.aten.cat.default([getitem_1067, getitem_1068, getitem_1069, getitem_1070, getitem_1071, getitem_1072, getitem_1073, getitem_1074]); getitem_1067 = getitem_1068 = getitem_1069 = getitem_1070 = getitem_1071 = getitem_1072 = getitem_1073 = getitem_1074 = None + reduce_scatter_tensor_138 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_104, 'sum', 8, '1'); cat_104 = None + wait_tensor_356 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_138); reduce_scatter_tensor_138 = None + convert_element_type_1050 = torch.ops.prims.convert_element_type.default(wait_tensor_356, torch.float32); wait_tensor_356 = None + convert_element_type_1052 = torch.ops.prims.convert_element_type.default(wait_tensor_87, torch.float32); wait_tensor_87 = None + mul_322 = torch.ops.aten.mul.Tensor(convert_element_type_1050, convert_element_type_1052); convert_element_type_1052 = None + mul_324 = torch.ops.aten.mul.Tensor(mul_52, mul_322) + sum_57 = 
torch.ops.aten.sum.dim_IntList(mul_324, [2], True); mul_324 = None + div_19 = torch.ops.aten.div.Tensor(mul_52, 4096) + mul_325 = torch.ops.aten.mul.Tensor(div_19, sum_57); div_19 = sum_57 = None + sub_30 = torch.ops.aten.sub.Tensor(mul_322, mul_325); mul_322 = mul_325 = None + mul_326 = torch.ops.aten.mul.Tensor(sub_30, rsqrt_13); sub_30 = rsqrt_13 = None + mul_327 = torch.ops.aten.mul.Tensor(convert_element_type_1050, mul_52); convert_element_type_1050 = mul_52 = None + sum_58 = torch.ops.aten.sum.dim_IntList(mul_327, [0, 1]); mul_327 = None + convert_element_type_1053 = torch.ops.prims.convert_element_type.default(mul_326, torch.bfloat16); mul_326 = None + convert_element_type_1054 = torch.ops.prims.convert_element_type.default(sum_58, torch.bfloat16); sum_58 = None + all_reduce_19 = torch.ops._c10d_functional.all_reduce.default(convert_element_type_1054, 'sum', '1'); convert_element_type_1054 = None + wait_tensor_357 = torch.ops._c10d_functional.wait_tensor.default(all_reduce_19); all_reduce_19 = None + convert_element_type_1055 = torch.ops.prims.convert_element_type.default(wait_tensor_357, torch.float32); wait_tensor_357 = None + reduce_scatter_tensor_139 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1055, 'avg', 8, '0'); convert_element_type_1055 = None + wait_tensor_358 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_139); reduce_scatter_tensor_139 = None + add_131 = torch.ops.aten.add.Tensor(add_127, convert_element_type_1053); add_127 = convert_element_type_1053 = None + all_gather_into_tensor_199 = torch.ops._c10d_functional.all_gather_into_tensor.default(add_131, 8, '1') + wait_tensor_359 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_199); all_gather_into_tensor_199 = None + split_113 = torch.ops.aten.split.Tensor(wait_tensor_359, 2); wait_tensor_359 = None + getitem_1075 = split_113[0] + getitem_1076 = split_113[1] + getitem_1077 = split_113[2] + getitem_1078 = 
split_113[3] + getitem_1079 = split_113[4] + getitem_1080 = split_113[5] + getitem_1081 = split_113[6] + getitem_1082 = split_113[7]; split_113 = None + cat_105 = torch.ops.aten.cat.default([getitem_1075, getitem_1076, getitem_1077, getitem_1078, getitem_1079, getitem_1080, getitem_1081, getitem_1082], 1); getitem_1075 = getitem_1076 = getitem_1077 = getitem_1078 = getitem_1079 = getitem_1080 = getitem_1081 = getitem_1082 = None + view_1393 = torch.ops.aten.view.default(cat_105, [16384, 4096]); cat_105 = None + permute_481 = torch.ops.aten.permute.default(view_1393, [1, 0]) + permute_72 = torch.ops.aten.permute.default(getitem_326, [0, 2, 1, 3]) + view_474 = torch.ops.aten.view.default(permute_72, [2, 8192, -1]); permute_72 = None + view_480 = torch.ops.aten.view.default(view_474, [16384, 512]); view_474 = None + mm_247 = torch.ops.aten.mm.default(permute_481, view_480); permute_481 = view_480 = None + convert_element_type_215 = torch.ops.prims.convert_element_type.default(primals_62, torch.bfloat16); primals_62 = None + all_gather_into_tensor_72 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_215, 8, '0'); convert_element_type_215 = None + wait_tensor_85 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_72); all_gather_into_tensor_72 = None + permute_73 = torch.ops.aten.permute.default(wait_tensor_85, [1, 0]); wait_tensor_85 = None + permute_483 = torch.ops.aten.permute.default(permute_73, [1, 0]); permute_73 = None + mm_248 = torch.ops.aten.mm.default(view_1393, permute_483); view_1393 = permute_483 = None + view_1394 = torch.ops.aten.view.default(mm_248, [2, 8192, 512]); mm_248 = None + convert_element_type_1060 = torch.ops.prims.convert_element_type.default(mm_247, torch.float32); mm_247 = None + reduce_scatter_tensor_140 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1060, 'avg', 8, '0'); convert_element_type_1060 = None + wait_tensor_360 = 
torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_140); reduce_scatter_tensor_140 = None + view_1395 = torch.ops.aten.view.default(view_1394, [2, 8192, 4, 128]); view_1394 = None + permute_485 = torch.ops.aten.permute.default(view_1395, [0, 2, 1, 3]); view_1395 = None + convert_element_type_199 = torch.ops.prims.convert_element_type.default(primals_58, torch.bfloat16); primals_58 = None + all_gather_into_tensor_67 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_199, 8, '0'); convert_element_type_199 = None + wait_tensor_80 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_67); all_gather_into_tensor_67 = None + convert_element_type_200 = torch.ops.prims.convert_element_type.default(add_23, torch.float32); add_23 = None + pow_13 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_200, 2) + mean_12 = torch.ops.aten.mean.dim(pow_13, [2], True); pow_13 = None + add_24 = torch.ops.aten.add.Scalar(mean_12, 1e-05); mean_12 = None + rsqrt_12 = torch.ops.aten.rsqrt.default(add_24); add_24 = None + mul_48 = torch.ops.aten.mul.Tensor(convert_element_type_200, rsqrt_12); convert_element_type_200 = None + mul_49 = torch.ops.aten.mul.Tensor(mul_48, wait_tensor_80) + convert_element_type_201 = torch.ops.prims.convert_element_type.default(mul_49, torch.bfloat16); mul_49 = None + all_gather_into_tensor_68 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_201, 8, '1'); convert_element_type_201 = None + wait_tensor_81 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_68); all_gather_into_tensor_68 = None + split_33 = torch.ops.aten.split.Tensor(wait_tensor_81, 2); wait_tensor_81 = None + getitem_318 = split_33[0] + getitem_319 = split_33[1] + getitem_320 = split_33[2] + getitem_321 = split_33[3] + getitem_322 = split_33[4] + getitem_323 = split_33[5] + getitem_324 = split_33[6] + getitem_325 = split_33[7]; split_33 = None + cat_25 = 
torch.ops.aten.cat.default([getitem_318, getitem_319, getitem_320, getitem_321, getitem_322, getitem_323, getitem_324, getitem_325], 1); getitem_318 = getitem_319 = getitem_320 = getitem_321 = getitem_322 = getitem_323 = getitem_324 = getitem_325 = None + view_447 = torch.ops.aten.view.default(cat_25, [16384, 4096]); cat_25 = None + view_448 = torch.ops.aten.view.default(mm_42, [2, 8192, 512]); mm_42 = None + convert_element_type_205 = torch.ops.prims.convert_element_type.default(primals_60, torch.bfloat16); primals_60 = None + all_gather_into_tensor_70 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_205, 8, '0'); convert_element_type_205 = None + wait_tensor_83 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_70); all_gather_into_tensor_70 = None + permute_67 = torch.ops.aten.permute.default(wait_tensor_83, [1, 0]); wait_tensor_83 = None + mm_43 = torch.ops.aten.mm.default(view_447, permute_67) + view_455 = torch.ops.aten.view.default(mm_43, [2, 8192, 128]); mm_43 = None + view_462 = torch.ops.aten.view.default(mm_44, [2, 8192, 128]); mm_44 = None + view_464 = torch.ops.aten.view.default(view_448, [2, 8192, -1, 128]); view_448 = None + view_465 = torch.ops.aten.view.default(view_455, [2, 8192, -1, 128]); view_455 = None + view_466 = torch.ops.aten.view.default(view_462, [2, 8192, -1, 128]); view_462 = None + convert_element_type_211 = torch.ops.prims.convert_element_type.default(view_464, torch.float32); view_464 = None + view_467 = torch.ops.aten.view.default(convert_element_type_211, [2, 8192, 4, -1, 2]); convert_element_type_211 = None + view_as_complex_12 = torch.ops.aten.view_as_complex.default(view_467); view_467 = None + convert_element_type_212 = torch.ops.prims.convert_element_type.default(view_465, torch.float32); view_465 = None + view_468 = torch.ops.aten.view.default(convert_element_type_212, [2, 8192, 1, -1, 2]); convert_element_type_212 = None + view_as_complex_13 = 
torch.ops.aten.view_as_complex.default(view_468); view_468 = None + mul_50 = torch.ops.aten.mul.Tensor(view_as_complex_12, view_37); view_as_complex_12 = None + view_as_real_12 = torch.ops.aten.view_as_real.default(mul_50); mul_50 = None + view_470 = torch.ops.aten.view.default(view_as_real_12, [2, 8192, 4, 128]); view_as_real_12 = None + mul_51 = torch.ops.aten.mul.Tensor(view_as_complex_13, view_37); view_as_complex_13 = None + view_as_real_13 = torch.ops.aten.view_as_real.default(mul_51); mul_51 = None + view_471 = torch.ops.aten.view.default(view_as_real_13, [2, 8192, 1, 128]); view_as_real_13 = None + convert_element_type_213 = torch.ops.prims.convert_element_type.default(view_470, torch.bfloat16); view_470 = None + convert_element_type_214 = torch.ops.prims.convert_element_type.default(view_471, torch.bfloat16); view_471 = None + unsqueeze_12 = torch.ops.aten.unsqueeze.default(convert_element_type_214, 3); convert_element_type_214 = None + expand_12 = torch.ops.aten.expand.default(unsqueeze_12, [2, 8192, 1, 4, 128]); unsqueeze_12 = None + view_472 = torch.ops.aten.view.default(expand_12, [2, 8192, 4, 128]); expand_12 = None + unsqueeze_13 = torch.ops.aten.unsqueeze.default(view_466, 3); view_466 = None + expand_13 = torch.ops.aten.expand.default(unsqueeze_13, [2, 8192, 1, 4, 128]); unsqueeze_13 = None + view_473 = torch.ops.aten.view.default(expand_13, [2, 8192, 4, 128]); expand_13 = None + permute_69 = torch.ops.aten.permute.default(convert_element_type_213, [0, 2, 1, 3]); convert_element_type_213 = None + permute_70 = torch.ops.aten.permute.default(view_472, [0, 2, 1, 3]); view_472 = None + permute_71 = torch.ops.aten.permute.default(view_473, [0, 2, 1, 3]); view_473 = None + _scaled_dot_product_cudnn_attention_backward_9 = torch.ops.aten._scaled_dot_product_cudnn_attention_backward.default(permute_485, permute_69, permute_70, permute_71, getitem_326, getitem_327, getitem_332, getitem_333, None, None, None, 8192, 8192, 0.0, True); permute_485 = permute_69 = 
permute_70 = permute_71 = getitem_326 = getitem_327 = getitem_332 = getitem_333 = None + getitem_1083 = _scaled_dot_product_cudnn_attention_backward_9[0] + getitem_1084 = _scaled_dot_product_cudnn_attention_backward_9[1] + getitem_1085 = _scaled_dot_product_cudnn_attention_backward_9[2]; _scaled_dot_product_cudnn_attention_backward_9 = None + permute_486 = torch.ops.aten.permute.default(getitem_1085, [0, 2, 1, 3]); getitem_1085 = None + permute_487 = torch.ops.aten.permute.default(getitem_1084, [0, 2, 1, 3]); getitem_1084 = None + permute_488 = torch.ops.aten.permute.default(getitem_1083, [0, 2, 1, 3]); getitem_1083 = None + view_1396 = torch.ops.aten.view.default(permute_486, [2, 8192, 1, 4, 128]); permute_486 = None + sum_59 = torch.ops.aten.sum.dim_IntList(view_1396, [3], True); view_1396 = None + squeeze_18 = torch.ops.aten.squeeze.dim(sum_59, 3); sum_59 = None + view_1397 = torch.ops.aten.view.default(permute_487, [2, 8192, 1, 4, 128]); permute_487 = None + sum_60 = torch.ops.aten.sum.dim_IntList(view_1397, [3], True); view_1397 = None + squeeze_19 = torch.ops.aten.squeeze.dim(sum_60, 3); sum_60 = None + convert_element_type_1061 = torch.ops.prims.convert_element_type.default(squeeze_19, torch.float32); squeeze_19 = None + convert_element_type_1062 = torch.ops.prims.convert_element_type.default(permute_488, torch.float32); permute_488 = None + view_1398 = torch.ops.aten.view.default(convert_element_type_1061, [2, 8192, 1, 64, 2]); convert_element_type_1061 = None + view_as_complex_50 = torch.ops.aten.view_as_complex.default(view_1398); view_1398 = None + mul_328 = torch.ops.aten.mul.Tensor(view_as_complex_50, _conj); view_as_complex_50 = None + view_1399 = torch.ops.aten.view.default(convert_element_type_1062, [2, 8192, 4, 64, 2]); convert_element_type_1062 = None + view_as_complex_51 = torch.ops.aten.view_as_complex.default(view_1399); view_1399 = None + mul_329 = torch.ops.aten.mul.Tensor(view_as_complex_51, _conj); view_as_complex_51 = None + 
view_as_real_50 = torch.ops.aten.view_as_real.default(mul_328); mul_328 = None + view_1400 = torch.ops.aten.view.default(view_as_real_50, [2, 8192, 1, 128]); view_as_real_50 = None + convert_element_type_1063 = torch.ops.prims.convert_element_type.default(view_1400, torch.bfloat16); view_1400 = None + view_as_real_51 = torch.ops.aten.view_as_real.default(mul_329); mul_329 = None + view_1401 = torch.ops.aten.view.default(view_as_real_51, [2, 8192, 4, 128]); view_as_real_51 = None + convert_element_type_1064 = torch.ops.prims.convert_element_type.default(view_1401, torch.bfloat16); view_1401 = None + view_1402 = torch.ops.aten.view.default(squeeze_18, [2, 8192, 128]); squeeze_18 = None + view_1403 = torch.ops.aten.view.default(convert_element_type_1063, [2, 8192, 128]); convert_element_type_1063 = None + view_1404 = torch.ops.aten.view.default(convert_element_type_1064, [2, 8192, 512]); convert_element_type_1064 = None + view_1405 = torch.ops.aten.view.default(view_1402, [16384, 128]); view_1402 = None + permute_489 = torch.ops.aten.permute.default(view_1405, [1, 0]) + mm_249 = torch.ops.aten.mm.default(permute_489, view_447); permute_489 = None + convert_element_type_208 = torch.ops.prims.convert_element_type.default(primals_61, torch.bfloat16); primals_61 = None + all_gather_into_tensor_71 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_208, 8, '0'); convert_element_type_208 = None + wait_tensor_84 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_71); all_gather_into_tensor_71 = None + permute_68 = torch.ops.aten.permute.default(wait_tensor_84, [1, 0]); wait_tensor_84 = None + permute_491 = torch.ops.aten.permute.default(permute_68, [1, 0]); permute_68 = None + mm_250 = torch.ops.aten.mm.default(view_1405, permute_491); view_1405 = permute_491 = None + view_1406 = torch.ops.aten.view.default(mm_250, [2, 8192, 4096]); mm_250 = None + convert_element_type_1069 = 
torch.ops.prims.convert_element_type.default(mm_249, torch.float32); mm_249 = None + reduce_scatter_tensor_141 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1069, 'avg', 8, '0'); convert_element_type_1069 = None + wait_tensor_361 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_141); reduce_scatter_tensor_141 = None + view_1407 = torch.ops.aten.view.default(view_1403, [16384, 128]); view_1403 = None + permute_493 = torch.ops.aten.permute.default(view_1407, [1, 0]) + mm_251 = torch.ops.aten.mm.default(permute_493, view_447); permute_493 = None + permute_495 = torch.ops.aten.permute.default(permute_67, [1, 0]); permute_67 = None + mm_252 = torch.ops.aten.mm.default(view_1407, permute_495); view_1407 = permute_495 = None + view_1408 = torch.ops.aten.view.default(mm_252, [2, 8192, 4096]); mm_252 = None + add_132 = torch.ops.aten.add.Tensor(view_1406, view_1408); view_1406 = view_1408 = None + convert_element_type_1074 = torch.ops.prims.convert_element_type.default(mm_251, torch.float32); mm_251 = None + reduce_scatter_tensor_142 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1074, 'avg', 8, '0'); convert_element_type_1074 = None + wait_tensor_362 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_142); reduce_scatter_tensor_142 = None + view_1409 = torch.ops.aten.view.default(view_1404, [16384, 512]); view_1404 = None + permute_497 = torch.ops.aten.permute.default(view_1409, [1, 0]) + mm_253 = torch.ops.aten.mm.default(permute_497, view_447); permute_497 = view_447 = None + convert_element_type_202 = torch.ops.prims.convert_element_type.default(primals_59, torch.bfloat16); primals_59 = None + all_gather_into_tensor_69 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_202, 8, '0'); convert_element_type_202 = None + wait_tensor_82 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_69); all_gather_into_tensor_69 = 
None + permute_66 = torch.ops.aten.permute.default(wait_tensor_82, [1, 0]); wait_tensor_82 = None + permute_499 = torch.ops.aten.permute.default(permute_66, [1, 0]); permute_66 = None + mm_254 = torch.ops.aten.mm.default(view_1409, permute_499); view_1409 = permute_499 = None + view_1410 = torch.ops.aten.view.default(mm_254, [2, 8192, 4096]); mm_254 = None + add_133 = torch.ops.aten.add.Tensor(add_132, view_1410); add_132 = view_1410 = None + convert_element_type_1079 = torch.ops.prims.convert_element_type.default(mm_253, torch.float32); mm_253 = None + reduce_scatter_tensor_143 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1079, 'avg', 8, '0'); convert_element_type_1079 = None + wait_tensor_363 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_143); reduce_scatter_tensor_143 = None + split_114 = torch.ops.aten.split.Tensor(add_133, 1024, 1); add_133 = None + getitem_1086 = split_114[0] + getitem_1087 = split_114[1] + getitem_1088 = split_114[2] + getitem_1089 = split_114[3] + getitem_1090 = split_114[4] + getitem_1091 = split_114[5] + getitem_1092 = split_114[6] + getitem_1093 = split_114[7]; split_114 = None + cat_106 = torch.ops.aten.cat.default([getitem_1086, getitem_1087, getitem_1088, getitem_1089, getitem_1090, getitem_1091, getitem_1092, getitem_1093]); getitem_1086 = getitem_1087 = getitem_1088 = getitem_1089 = getitem_1090 = getitem_1091 = getitem_1092 = getitem_1093 = None + reduce_scatter_tensor_144 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_106, 'sum', 8, '1'); cat_106 = None + wait_tensor_364 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_144); reduce_scatter_tensor_144 = None + convert_element_type_1080 = torch.ops.prims.convert_element_type.default(wait_tensor_364, torch.float32); wait_tensor_364 = None + convert_element_type_1082 = torch.ops.prims.convert_element_type.default(wait_tensor_80, torch.float32); wait_tensor_80 = None + mul_330 = 
torch.ops.aten.mul.Tensor(convert_element_type_1080, convert_element_type_1082); convert_element_type_1082 = None + mul_332 = torch.ops.aten.mul.Tensor(mul_48, mul_330) + sum_61 = torch.ops.aten.sum.dim_IntList(mul_332, [2], True); mul_332 = None + div_20 = torch.ops.aten.div.Tensor(mul_48, 4096) + mul_333 = torch.ops.aten.mul.Tensor(div_20, sum_61); div_20 = sum_61 = None + sub_31 = torch.ops.aten.sub.Tensor(mul_330, mul_333); mul_330 = mul_333 = None + mul_334 = torch.ops.aten.mul.Tensor(sub_31, rsqrt_12); sub_31 = rsqrt_12 = None + mul_335 = torch.ops.aten.mul.Tensor(convert_element_type_1080, mul_48); convert_element_type_1080 = mul_48 = None + sum_62 = torch.ops.aten.sum.dim_IntList(mul_335, [0, 1]); mul_335 = None + convert_element_type_1083 = torch.ops.prims.convert_element_type.default(mul_334, torch.bfloat16); mul_334 = None + convert_element_type_1084 = torch.ops.prims.convert_element_type.default(sum_62, torch.bfloat16); sum_62 = None + all_reduce_20 = torch.ops._c10d_functional.all_reduce.default(convert_element_type_1084, 'sum', '1'); convert_element_type_1084 = None + wait_tensor_365 = torch.ops._c10d_functional.wait_tensor.default(all_reduce_20); all_reduce_20 = None + convert_element_type_1085 = torch.ops.prims.convert_element_type.default(wait_tensor_365, torch.float32); wait_tensor_365 = None + reduce_scatter_tensor_145 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1085, 'avg', 8, '0'); convert_element_type_1085 = None + wait_tensor_366 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_145); reduce_scatter_tensor_145 = None + add_134 = torch.ops.aten.add.Tensor(add_131, convert_element_type_1083); add_131 = convert_element_type_1083 = None + all_gather_into_tensor_200 = torch.ops._c10d_functional.all_gather_into_tensor.default(add_134, 8, '1') + wait_tensor_367 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_200); all_gather_into_tensor_200 = None + split_115 = 
torch.ops.aten.split.Tensor(wait_tensor_367, 2); wait_tensor_367 = None + getitem_1094 = split_115[0] + getitem_1095 = split_115[1] + getitem_1096 = split_115[2] + getitem_1097 = split_115[3] + getitem_1098 = split_115[4] + getitem_1099 = split_115[5] + getitem_1100 = split_115[6] + getitem_1101 = split_115[7]; split_115 = None + cat_107 = torch.ops.aten.cat.default([getitem_1094, getitem_1095, getitem_1096, getitem_1097, getitem_1098, getitem_1099, getitem_1100, getitem_1101], 1); getitem_1094 = getitem_1095 = getitem_1096 = getitem_1097 = getitem_1098 = getitem_1099 = getitem_1100 = getitem_1101 = None + view_1411 = torch.ops.aten.view.default(cat_107, [16384, 4096]); cat_107 = None + permute_501 = torch.ops.aten.permute.default(view_1411, [1, 0]) + wait_tensor_73 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_11); reduce_scatter_tensor_11 = None + add_21 = torch.ops.aten.add.Tensor(add_19, wait_tensor_73); wait_tensor_73 = None + convert_element_type_185 = torch.ops.prims.convert_element_type.default(primals_54, torch.bfloat16); primals_54 = None + all_gather_into_tensor_62 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_185, 8, '0'); convert_element_type_185 = None + wait_tensor_74 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_62); all_gather_into_tensor_62 = None + convert_element_type_186 = torch.ops.prims.convert_element_type.default(add_21, torch.float32); add_21 = None + pow_12 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_186, 2) + mean_11 = torch.ops.aten.mean.dim(pow_12, [2], True); pow_12 = None + add_22 = torch.ops.aten.add.Scalar(mean_11, 1e-05); mean_11 = None + rsqrt_11 = torch.ops.aten.rsqrt.default(add_22); add_22 = None + mul_44 = torch.ops.aten.mul.Tensor(convert_element_type_186, rsqrt_11); convert_element_type_186 = None + mul_45 = torch.ops.aten.mul.Tensor(mul_44, wait_tensor_74) + convert_element_type_187 = 
torch.ops.prims.convert_element_type.default(mul_45, torch.bfloat16); mul_45 = None + all_gather_into_tensor_63 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_187, 8, '1'); convert_element_type_187 = None + wait_tensor_75 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_63); all_gather_into_tensor_63 = None + split_31 = torch.ops.aten.split.Tensor(wait_tensor_75, 2); wait_tensor_75 = None + getitem_302 = split_31[0] + getitem_303 = split_31[1] + getitem_304 = split_31[2] + getitem_305 = split_31[3] + getitem_306 = split_31[4] + getitem_307 = split_31[5] + getitem_308 = split_31[6] + getitem_309 = split_31[7]; split_31 = None + cat_23 = torch.ops.aten.cat.default([getitem_302, getitem_303, getitem_304, getitem_305, getitem_306, getitem_307, getitem_308, getitem_309], 1); getitem_302 = getitem_303 = getitem_304 = getitem_305 = getitem_306 = getitem_307 = getitem_308 = getitem_309 = None + view_420 = torch.ops.aten.view.default(cat_23, [16384, 4096]); cat_23 = None + view_421 = torch.ops.aten.view.default(mm_39, [2, 8192, 1792]); mm_39 = None + convert_element_type_191 = torch.ops.prims.convert_element_type.default(view_421, torch.float32); view_421 = None + sigmoid_5 = torch.ops.aten.sigmoid.default(convert_element_type_191) + mul_46 = torch.ops.aten.mul.Tensor(convert_element_type_191, sigmoid_5); sigmoid_5 = None + convert_element_type_192 = torch.ops.prims.convert_element_type.default(mul_46, torch.bfloat16); mul_46 = None + convert_element_type_193 = torch.ops.prims.convert_element_type.default(primals_56, torch.bfloat16); primals_56 = None + all_gather_into_tensor_65 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_193, 8, '0'); convert_element_type_193 = None + wait_tensor_77 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_65); all_gather_into_tensor_65 = None + permute_64 = torch.ops.aten.permute.default(wait_tensor_77, [1, 0]); wait_tensor_77 = 
None + mm_40 = torch.ops.aten.mm.default(view_420, permute_64) + view_428 = torch.ops.aten.view.default(mm_40, [2, 8192, 1792]); mm_40 = None + mul_47 = torch.ops.aten.mul.Tensor(convert_element_type_192, view_428) + view_435 = torch.ops.aten.view.default(mul_47, [16384, 1792]); mul_47 = None + mm_255 = torch.ops.aten.mm.default(permute_501, view_435); permute_501 = view_435 = None + convert_element_type_196 = torch.ops.prims.convert_element_type.default(primals_57, torch.bfloat16); primals_57 = None + all_gather_into_tensor_66 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_196, 8, '0'); convert_element_type_196 = None + wait_tensor_78 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_66); all_gather_into_tensor_66 = None + permute_65 = torch.ops.aten.permute.default(wait_tensor_78, [1, 0]); wait_tensor_78 = None + permute_503 = torch.ops.aten.permute.default(permute_65, [1, 0]); permute_65 = None + mm_256 = torch.ops.aten.mm.default(view_1411, permute_503); view_1411 = permute_503 = None + view_1412 = torch.ops.aten.view.default(mm_256, [2, 8192, 1792]); mm_256 = None + convert_element_type_1090 = torch.ops.prims.convert_element_type.default(mm_255, torch.float32); mm_255 = None + reduce_scatter_tensor_146 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1090, 'avg', 8, '0'); convert_element_type_1090 = None + wait_tensor_368 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_146); reduce_scatter_tensor_146 = None + mul_336 = torch.ops.aten.mul.Tensor(view_1412, convert_element_type_192); convert_element_type_192 = None + mul_337 = torch.ops.aten.mul.Tensor(view_1412, view_428); view_1412 = view_428 = None + view_1413 = torch.ops.aten.view.default(mul_336, [16384, 1792]); mul_336 = None + permute_505 = torch.ops.aten.permute.default(view_1413, [1, 0]) + mm_257 = torch.ops.aten.mm.default(permute_505, view_420); permute_505 = None + permute_507 = 
torch.ops.aten.permute.default(permute_64, [1, 0]); permute_64 = None + mm_258 = torch.ops.aten.mm.default(view_1413, permute_507); view_1413 = permute_507 = None + view_1414 = torch.ops.aten.view.default(mm_258, [2, 8192, 4096]); mm_258 = None + convert_element_type_1095 = torch.ops.prims.convert_element_type.default(mm_257, torch.float32); mm_257 = None + reduce_scatter_tensor_147 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1095, 'avg', 8, '0'); convert_element_type_1095 = None + wait_tensor_369 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_147); reduce_scatter_tensor_147 = None + convert_element_type_1096 = torch.ops.prims.convert_element_type.default(mul_337, torch.float32); mul_337 = None + neg_10 = torch.ops.aten.neg.default(convert_element_type_191) + exp_10 = torch.ops.aten.exp.default(neg_10); neg_10 = None + add_135 = torch.ops.aten.add.Tensor(exp_10, 1); exp_10 = None + reciprocal_10 = torch.ops.aten.reciprocal.default(add_135); add_135 = None + mul_338 = torch.ops.aten.mul.Tensor(reciprocal_10, 1); reciprocal_10 = None + mul_339 = torch.ops.aten.mul.Tensor(convert_element_type_1096, mul_338); convert_element_type_1096 = None + sub_32 = torch.ops.aten.sub.Tensor(1, mul_338); mul_338 = None + mul_340 = torch.ops.aten.mul.Tensor(convert_element_type_191, sub_32); convert_element_type_191 = sub_32 = None + add_136 = torch.ops.aten.add.Tensor(mul_340, 1); mul_340 = None + mul_341 = torch.ops.aten.mul.Tensor(mul_339, add_136); mul_339 = add_136 = None + convert_element_type_1098 = torch.ops.prims.convert_element_type.default(mul_341, torch.bfloat16); mul_341 = None + view_1415 = torch.ops.aten.view.default(convert_element_type_1098, [16384, 1792]); convert_element_type_1098 = None + permute_509 = torch.ops.aten.permute.default(view_1415, [1, 0]) + mm_259 = torch.ops.aten.mm.default(permute_509, view_420); permute_509 = view_420 = None + convert_element_type_188 = 
torch.ops.prims.convert_element_type.default(primals_55, torch.bfloat16); primals_55 = None + all_gather_into_tensor_64 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_188, 8, '0'); convert_element_type_188 = None + wait_tensor_76 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_64); all_gather_into_tensor_64 = None + permute_63 = torch.ops.aten.permute.default(wait_tensor_76, [1, 0]); wait_tensor_76 = None + permute_511 = torch.ops.aten.permute.default(permute_63, [1, 0]); permute_63 = None + mm_260 = torch.ops.aten.mm.default(view_1415, permute_511); view_1415 = permute_511 = None + view_1416 = torch.ops.aten.view.default(mm_260, [2, 8192, 4096]); mm_260 = None + add_137 = torch.ops.aten.add.Tensor(view_1414, view_1416); view_1414 = view_1416 = None + convert_element_type_1103 = torch.ops.prims.convert_element_type.default(mm_259, torch.float32); mm_259 = None + reduce_scatter_tensor_148 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1103, 'avg', 8, '0'); convert_element_type_1103 = None + wait_tensor_370 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_148); reduce_scatter_tensor_148 = None + split_116 = torch.ops.aten.split.Tensor(add_137, 1024, 1); add_137 = None + getitem_1102 = split_116[0] + getitem_1103 = split_116[1] + getitem_1104 = split_116[2] + getitem_1105 = split_116[3] + getitem_1106 = split_116[4] + getitem_1107 = split_116[5] + getitem_1108 = split_116[6] + getitem_1109 = split_116[7]; split_116 = None + cat_108 = torch.ops.aten.cat.default([getitem_1102, getitem_1103, getitem_1104, getitem_1105, getitem_1106, getitem_1107, getitem_1108, getitem_1109]); getitem_1102 = getitem_1103 = getitem_1104 = getitem_1105 = getitem_1106 = getitem_1107 = getitem_1108 = getitem_1109 = None + reduce_scatter_tensor_149 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_108, 'sum', 8, '1'); cat_108 = None + wait_tensor_371 = 
torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_149); reduce_scatter_tensor_149 = None + convert_element_type_1104 = torch.ops.prims.convert_element_type.default(wait_tensor_371, torch.float32); wait_tensor_371 = None + convert_element_type_1106 = torch.ops.prims.convert_element_type.default(wait_tensor_74, torch.float32); wait_tensor_74 = None + mul_342 = torch.ops.aten.mul.Tensor(convert_element_type_1104, convert_element_type_1106); convert_element_type_1106 = None + mul_344 = torch.ops.aten.mul.Tensor(mul_44, mul_342) + sum_63 = torch.ops.aten.sum.dim_IntList(mul_344, [2], True); mul_344 = None + div_21 = torch.ops.aten.div.Tensor(mul_44, 4096) + mul_345 = torch.ops.aten.mul.Tensor(div_21, sum_63); div_21 = sum_63 = None + sub_33 = torch.ops.aten.sub.Tensor(mul_342, mul_345); mul_342 = mul_345 = None + mul_346 = torch.ops.aten.mul.Tensor(sub_33, rsqrt_11); sub_33 = rsqrt_11 = None + mul_347 = torch.ops.aten.mul.Tensor(convert_element_type_1104, mul_44); convert_element_type_1104 = mul_44 = None + sum_64 = torch.ops.aten.sum.dim_IntList(mul_347, [0, 1]); mul_347 = None + convert_element_type_1107 = torch.ops.prims.convert_element_type.default(mul_346, torch.bfloat16); mul_346 = None + convert_element_type_1108 = torch.ops.prims.convert_element_type.default(sum_64, torch.bfloat16); sum_64 = None + all_reduce_21 = torch.ops._c10d_functional.all_reduce.default(convert_element_type_1108, 'sum', '1'); convert_element_type_1108 = None + wait_tensor_372 = torch.ops._c10d_functional.wait_tensor.default(all_reduce_21); all_reduce_21 = None + convert_element_type_1109 = torch.ops.prims.convert_element_type.default(wait_tensor_372, torch.float32); wait_tensor_372 = None + reduce_scatter_tensor_150 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1109, 'avg', 8, '0'); convert_element_type_1109 = None + wait_tensor_373 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_150); reduce_scatter_tensor_150 = 
None + add_138 = torch.ops.aten.add.Tensor(add_134, convert_element_type_1107); add_134 = convert_element_type_1107 = None + all_gather_into_tensor_201 = torch.ops._c10d_functional.all_gather_into_tensor.default(add_138, 8, '1') + wait_tensor_374 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_201); all_gather_into_tensor_201 = None + split_117 = torch.ops.aten.split.Tensor(wait_tensor_374, 2); wait_tensor_374 = None + getitem_1110 = split_117[0] + getitem_1111 = split_117[1] + getitem_1112 = split_117[2] + getitem_1113 = split_117[3] + getitem_1114 = split_117[4] + getitem_1115 = split_117[5] + getitem_1116 = split_117[6] + getitem_1117 = split_117[7]; split_117 = None + cat_109 = torch.ops.aten.cat.default([getitem_1110, getitem_1111, getitem_1112, getitem_1113, getitem_1114, getitem_1115, getitem_1116, getitem_1117], 1); getitem_1110 = getitem_1111 = getitem_1112 = getitem_1113 = getitem_1114 = getitem_1115 = getitem_1116 = getitem_1117 = None + view_1417 = torch.ops.aten.view.default(cat_109, [16384, 4096]); cat_109 = None + permute_513 = torch.ops.aten.permute.default(view_1417, [1, 0]) + permute_61 = torch.ops.aten.permute.default(getitem_285, [0, 2, 1, 3]) + view_402 = torch.ops.aten.view.default(permute_61, [2, 8192, -1]); permute_61 = None + view_408 = torch.ops.aten.view.default(view_402, [16384, 512]); view_402 = None + mm_261 = torch.ops.aten.mm.default(permute_513, view_408); permute_513 = view_408 = None + convert_element_type_182 = torch.ops.prims.convert_element_type.default(primals_53, torch.bfloat16); primals_53 = None + all_gather_into_tensor_61 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_182, 8, '0'); convert_element_type_182 = None + wait_tensor_72 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_61); all_gather_into_tensor_61 = None + permute_62 = torch.ops.aten.permute.default(wait_tensor_72, [1, 0]); wait_tensor_72 = None + permute_515 = 
torch.ops.aten.permute.default(permute_62, [1, 0]); permute_62 = None + mm_262 = torch.ops.aten.mm.default(view_1417, permute_515); view_1417 = permute_515 = None + view_1418 = torch.ops.aten.view.default(mm_262, [2, 8192, 512]); mm_262 = None + convert_element_type_1114 = torch.ops.prims.convert_element_type.default(mm_261, torch.float32); mm_261 = None + reduce_scatter_tensor_151 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1114, 'avg', 8, '0'); convert_element_type_1114 = None + wait_tensor_375 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_151); reduce_scatter_tensor_151 = None + view_1419 = torch.ops.aten.view.default(view_1418, [2, 8192, 4, 128]); view_1418 = None + permute_517 = torch.ops.aten.permute.default(view_1419, [0, 2, 1, 3]); view_1419 = None + convert_element_type_166 = torch.ops.prims.convert_element_type.default(primals_49, torch.bfloat16); primals_49 = None + all_gather_into_tensor_56 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_166, 8, '0'); convert_element_type_166 = None + wait_tensor_67 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_56); all_gather_into_tensor_56 = None + convert_element_type_167 = torch.ops.prims.convert_element_type.default(add_19, torch.float32); add_19 = None + pow_11 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_167, 2) + mean_10 = torch.ops.aten.mean.dim(pow_11, [2], True); pow_11 = None + add_20 = torch.ops.aten.add.Scalar(mean_10, 1e-05); mean_10 = None + rsqrt_10 = torch.ops.aten.rsqrt.default(add_20); add_20 = None + mul_40 = torch.ops.aten.mul.Tensor(convert_element_type_167, rsqrt_10); convert_element_type_167 = None + mul_41 = torch.ops.aten.mul.Tensor(mul_40, wait_tensor_67) + convert_element_type_168 = torch.ops.prims.convert_element_type.default(mul_41, torch.bfloat16); mul_41 = None + all_gather_into_tensor_57 = 
torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_168, 8, '1'); convert_element_type_168 = None + wait_tensor_68 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_57); all_gather_into_tensor_57 = None + split_29 = torch.ops.aten.split.Tensor(wait_tensor_68, 2); wait_tensor_68 = None + getitem_277 = split_29[0] + getitem_278 = split_29[1] + getitem_279 = split_29[2] + getitem_280 = split_29[3] + getitem_281 = split_29[4] + getitem_282 = split_29[5] + getitem_283 = split_29[6] + getitem_284 = split_29[7]; split_29 = None + cat_21 = torch.ops.aten.cat.default([getitem_277, getitem_278, getitem_279, getitem_280, getitem_281, getitem_282, getitem_283, getitem_284], 1); getitem_277 = getitem_278 = getitem_279 = getitem_280 = getitem_281 = getitem_282 = getitem_283 = getitem_284 = None + view_375 = torch.ops.aten.view.default(cat_21, [16384, 4096]); cat_21 = None + view_376 = torch.ops.aten.view.default(mm_35, [2, 8192, 512]); mm_35 = None + convert_element_type_172 = torch.ops.prims.convert_element_type.default(primals_51, torch.bfloat16); primals_51 = None + all_gather_into_tensor_59 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_172, 8, '0'); convert_element_type_172 = None + wait_tensor_70 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_59); all_gather_into_tensor_59 = None + permute_56 = torch.ops.aten.permute.default(wait_tensor_70, [1, 0]); wait_tensor_70 = None + mm_36 = torch.ops.aten.mm.default(view_375, permute_56) + view_383 = torch.ops.aten.view.default(mm_36, [2, 8192, 128]); mm_36 = None + view_390 = torch.ops.aten.view.default(mm_37, [2, 8192, 128]); mm_37 = None + view_392 = torch.ops.aten.view.default(view_376, [2, 8192, -1, 128]); view_376 = None + view_393 = torch.ops.aten.view.default(view_383, [2, 8192, -1, 128]); view_383 = None + view_394 = torch.ops.aten.view.default(view_390, [2, 8192, -1, 128]); view_390 = None + convert_element_type_178 
= torch.ops.prims.convert_element_type.default(view_392, torch.float32); view_392 = None + view_395 = torch.ops.aten.view.default(convert_element_type_178, [2, 8192, 4, -1, 2]); convert_element_type_178 = None + view_as_complex_10 = torch.ops.aten.view_as_complex.default(view_395); view_395 = None + convert_element_type_179 = torch.ops.prims.convert_element_type.default(view_393, torch.float32); view_393 = None + view_396 = torch.ops.aten.view.default(convert_element_type_179, [2, 8192, 1, -1, 2]); convert_element_type_179 = None + view_as_complex_11 = torch.ops.aten.view_as_complex.default(view_396); view_396 = None + mul_42 = torch.ops.aten.mul.Tensor(view_as_complex_10, view_37); view_as_complex_10 = None + view_as_real_10 = torch.ops.aten.view_as_real.default(mul_42); mul_42 = None + view_398 = torch.ops.aten.view.default(view_as_real_10, [2, 8192, 4, 128]); view_as_real_10 = None + mul_43 = torch.ops.aten.mul.Tensor(view_as_complex_11, view_37); view_as_complex_11 = None + view_as_real_11 = torch.ops.aten.view_as_real.default(mul_43); mul_43 = None + view_399 = torch.ops.aten.view.default(view_as_real_11, [2, 8192, 1, 128]); view_as_real_11 = None + convert_element_type_180 = torch.ops.prims.convert_element_type.default(view_398, torch.bfloat16); view_398 = None + convert_element_type_181 = torch.ops.prims.convert_element_type.default(view_399, torch.bfloat16); view_399 = None + unsqueeze_10 = torch.ops.aten.unsqueeze.default(convert_element_type_181, 3); convert_element_type_181 = None + expand_10 = torch.ops.aten.expand.default(unsqueeze_10, [2, 8192, 1, 4, 128]); unsqueeze_10 = None + view_400 = torch.ops.aten.view.default(expand_10, [2, 8192, 4, 128]); expand_10 = None + unsqueeze_11 = torch.ops.aten.unsqueeze.default(view_394, 3); view_394 = None + expand_11 = torch.ops.aten.expand.default(unsqueeze_11, [2, 8192, 1, 4, 128]); unsqueeze_11 = None + view_401 = torch.ops.aten.view.default(expand_11, [2, 8192, 4, 128]); expand_11 = None + permute_58 = 
torch.ops.aten.permute.default(convert_element_type_180, [0, 2, 1, 3]); convert_element_type_180 = None + permute_59 = torch.ops.aten.permute.default(view_400, [0, 2, 1, 3]); view_400 = None + permute_60 = torch.ops.aten.permute.default(view_401, [0, 2, 1, 3]); view_401 = None + _scaled_dot_product_cudnn_attention_backward_10 = torch.ops.aten._scaled_dot_product_cudnn_attention_backward.default(permute_517, permute_58, permute_59, permute_60, getitem_285, getitem_286, getitem_291, getitem_292, None, None, None, 8192, 8192, 0.0, True); permute_517 = permute_58 = permute_59 = permute_60 = getitem_285 = getitem_286 = getitem_291 = getitem_292 = None + getitem_1118 = _scaled_dot_product_cudnn_attention_backward_10[0] + getitem_1119 = _scaled_dot_product_cudnn_attention_backward_10[1] + getitem_1120 = _scaled_dot_product_cudnn_attention_backward_10[2]; _scaled_dot_product_cudnn_attention_backward_10 = None + permute_518 = torch.ops.aten.permute.default(getitem_1120, [0, 2, 1, 3]); getitem_1120 = None + permute_519 = torch.ops.aten.permute.default(getitem_1119, [0, 2, 1, 3]); getitem_1119 = None + permute_520 = torch.ops.aten.permute.default(getitem_1118, [0, 2, 1, 3]); getitem_1118 = None + view_1420 = torch.ops.aten.view.default(permute_518, [2, 8192, 1, 4, 128]); permute_518 = None + sum_65 = torch.ops.aten.sum.dim_IntList(view_1420, [3], True); view_1420 = None + squeeze_20 = torch.ops.aten.squeeze.dim(sum_65, 3); sum_65 = None + view_1421 = torch.ops.aten.view.default(permute_519, [2, 8192, 1, 4, 128]); permute_519 = None + sum_66 = torch.ops.aten.sum.dim_IntList(view_1421, [3], True); view_1421 = None + squeeze_21 = torch.ops.aten.squeeze.dim(sum_66, 3); sum_66 = None + convert_element_type_1115 = torch.ops.prims.convert_element_type.default(squeeze_21, torch.float32); squeeze_21 = None + convert_element_type_1116 = torch.ops.prims.convert_element_type.default(permute_520, torch.float32); permute_520 = None + view_1422 = 
torch.ops.aten.view.default(convert_element_type_1115, [2, 8192, 1, 64, 2]); convert_element_type_1115 = None + view_as_complex_52 = torch.ops.aten.view_as_complex.default(view_1422); view_1422 = None + mul_348 = torch.ops.aten.mul.Tensor(view_as_complex_52, _conj); view_as_complex_52 = None + view_1423 = torch.ops.aten.view.default(convert_element_type_1116, [2, 8192, 4, 64, 2]); convert_element_type_1116 = None + view_as_complex_53 = torch.ops.aten.view_as_complex.default(view_1423); view_1423 = None + mul_349 = torch.ops.aten.mul.Tensor(view_as_complex_53, _conj); view_as_complex_53 = None + view_as_real_52 = torch.ops.aten.view_as_real.default(mul_348); mul_348 = None + view_1424 = torch.ops.aten.view.default(view_as_real_52, [2, 8192, 1, 128]); view_as_real_52 = None + convert_element_type_1117 = torch.ops.prims.convert_element_type.default(view_1424, torch.bfloat16); view_1424 = None + view_as_real_53 = torch.ops.aten.view_as_real.default(mul_349); mul_349 = None + view_1425 = torch.ops.aten.view.default(view_as_real_53, [2, 8192, 4, 128]); view_as_real_53 = None + convert_element_type_1118 = torch.ops.prims.convert_element_type.default(view_1425, torch.bfloat16); view_1425 = None + view_1426 = torch.ops.aten.view.default(squeeze_20, [2, 8192, 128]); squeeze_20 = None + view_1427 = torch.ops.aten.view.default(convert_element_type_1117, [2, 8192, 128]); convert_element_type_1117 = None + view_1428 = torch.ops.aten.view.default(convert_element_type_1118, [2, 8192, 512]); convert_element_type_1118 = None + view_1429 = torch.ops.aten.view.default(view_1426, [16384, 128]); view_1426 = None + permute_521 = torch.ops.aten.permute.default(view_1429, [1, 0]) + mm_263 = torch.ops.aten.mm.default(permute_521, view_375); permute_521 = None + convert_element_type_175 = torch.ops.prims.convert_element_type.default(primals_52, torch.bfloat16); primals_52 = None + all_gather_into_tensor_60 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_175, 
8, '0'); convert_element_type_175 = None + wait_tensor_71 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_60); all_gather_into_tensor_60 = None + permute_57 = torch.ops.aten.permute.default(wait_tensor_71, [1, 0]); wait_tensor_71 = None + permute_523 = torch.ops.aten.permute.default(permute_57, [1, 0]); permute_57 = None + mm_264 = torch.ops.aten.mm.default(view_1429, permute_523); view_1429 = permute_523 = None + view_1430 = torch.ops.aten.view.default(mm_264, [2, 8192, 4096]); mm_264 = None + convert_element_type_1123 = torch.ops.prims.convert_element_type.default(mm_263, torch.float32); mm_263 = None + reduce_scatter_tensor_152 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1123, 'avg', 8, '0'); convert_element_type_1123 = None + wait_tensor_376 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_152); reduce_scatter_tensor_152 = None + view_1431 = torch.ops.aten.view.default(view_1427, [16384, 128]); view_1427 = None + permute_525 = torch.ops.aten.permute.default(view_1431, [1, 0]) + mm_265 = torch.ops.aten.mm.default(permute_525, view_375); permute_525 = None + permute_527 = torch.ops.aten.permute.default(permute_56, [1, 0]); permute_56 = None + mm_266 = torch.ops.aten.mm.default(view_1431, permute_527); view_1431 = permute_527 = None + view_1432 = torch.ops.aten.view.default(mm_266, [2, 8192, 4096]); mm_266 = None + add_139 = torch.ops.aten.add.Tensor(view_1430, view_1432); view_1430 = view_1432 = None + convert_element_type_1128 = torch.ops.prims.convert_element_type.default(mm_265, torch.float32); mm_265 = None + reduce_scatter_tensor_153 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1128, 'avg', 8, '0'); convert_element_type_1128 = None + wait_tensor_377 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_153); reduce_scatter_tensor_153 = None + view_1433 = torch.ops.aten.view.default(view_1428, [16384, 512]); view_1428 = None + 
permute_529 = torch.ops.aten.permute.default(view_1433, [1, 0]) + mm_267 = torch.ops.aten.mm.default(permute_529, view_375); permute_529 = view_375 = None + convert_element_type_169 = torch.ops.prims.convert_element_type.default(primals_50, torch.bfloat16); primals_50 = None + all_gather_into_tensor_58 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_169, 8, '0'); convert_element_type_169 = None + wait_tensor_69 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_58); all_gather_into_tensor_58 = None + permute_55 = torch.ops.aten.permute.default(wait_tensor_69, [1, 0]); wait_tensor_69 = None + permute_531 = torch.ops.aten.permute.default(permute_55, [1, 0]); permute_55 = None + mm_268 = torch.ops.aten.mm.default(view_1433, permute_531); view_1433 = permute_531 = None + view_1434 = torch.ops.aten.view.default(mm_268, [2, 8192, 4096]); mm_268 = None + add_140 = torch.ops.aten.add.Tensor(add_139, view_1434); add_139 = view_1434 = None + convert_element_type_1133 = torch.ops.prims.convert_element_type.default(mm_267, torch.float32); mm_267 = None + reduce_scatter_tensor_154 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1133, 'avg', 8, '0'); convert_element_type_1133 = None + wait_tensor_378 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_154); reduce_scatter_tensor_154 = None + split_118 = torch.ops.aten.split.Tensor(add_140, 1024, 1); add_140 = None + getitem_1121 = split_118[0] + getitem_1122 = split_118[1] + getitem_1123 = split_118[2] + getitem_1124 = split_118[3] + getitem_1125 = split_118[4] + getitem_1126 = split_118[5] + getitem_1127 = split_118[6] + getitem_1128 = split_118[7]; split_118 = None + cat_110 = torch.ops.aten.cat.default([getitem_1121, getitem_1122, getitem_1123, getitem_1124, getitem_1125, getitem_1126, getitem_1127, getitem_1128]); getitem_1121 = getitem_1122 = getitem_1123 = getitem_1124 = getitem_1125 = getitem_1126 = getitem_1127 = 
getitem_1128 = None + reduce_scatter_tensor_155 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_110, 'sum', 8, '1'); cat_110 = None + wait_tensor_379 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_155); reduce_scatter_tensor_155 = None + convert_element_type_1134 = torch.ops.prims.convert_element_type.default(wait_tensor_379, torch.float32); wait_tensor_379 = None + convert_element_type_1136 = torch.ops.prims.convert_element_type.default(wait_tensor_67, torch.float32); wait_tensor_67 = None + mul_350 = torch.ops.aten.mul.Tensor(convert_element_type_1134, convert_element_type_1136); convert_element_type_1136 = None + mul_352 = torch.ops.aten.mul.Tensor(mul_40, mul_350) + sum_67 = torch.ops.aten.sum.dim_IntList(mul_352, [2], True); mul_352 = None + div_22 = torch.ops.aten.div.Tensor(mul_40, 4096) + mul_353 = torch.ops.aten.mul.Tensor(div_22, sum_67); div_22 = sum_67 = None + sub_34 = torch.ops.aten.sub.Tensor(mul_350, mul_353); mul_350 = mul_353 = None + mul_354 = torch.ops.aten.mul.Tensor(sub_34, rsqrt_10); sub_34 = rsqrt_10 = None + mul_355 = torch.ops.aten.mul.Tensor(convert_element_type_1134, mul_40); convert_element_type_1134 = mul_40 = None + sum_68 = torch.ops.aten.sum.dim_IntList(mul_355, [0, 1]); mul_355 = None + convert_element_type_1137 = torch.ops.prims.convert_element_type.default(mul_354, torch.bfloat16); mul_354 = None + convert_element_type_1138 = torch.ops.prims.convert_element_type.default(sum_68, torch.bfloat16); sum_68 = None + all_reduce_22 = torch.ops._c10d_functional.all_reduce.default(convert_element_type_1138, 'sum', '1'); convert_element_type_1138 = None + wait_tensor_380 = torch.ops._c10d_functional.wait_tensor.default(all_reduce_22); all_reduce_22 = None + convert_element_type_1139 = torch.ops.prims.convert_element_type.default(wait_tensor_380, torch.float32); wait_tensor_380 = None + reduce_scatter_tensor_156 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1139, 'avg', 
8, '0'); convert_element_type_1139 = None + wait_tensor_381 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_156); reduce_scatter_tensor_156 = None + add_141 = torch.ops.aten.add.Tensor(add_138, convert_element_type_1137); add_138 = convert_element_type_1137 = None + all_gather_into_tensor_202 = torch.ops._c10d_functional.all_gather_into_tensor.default(add_141, 8, '1') + wait_tensor_382 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_202); all_gather_into_tensor_202 = None + split_119 = torch.ops.aten.split.Tensor(wait_tensor_382, 2); wait_tensor_382 = None + getitem_1129 = split_119[0] + getitem_1130 = split_119[1] + getitem_1131 = split_119[2] + getitem_1132 = split_119[3] + getitem_1133 = split_119[4] + getitem_1134 = split_119[5] + getitem_1135 = split_119[6] + getitem_1136 = split_119[7]; split_119 = None + cat_111 = torch.ops.aten.cat.default([getitem_1129, getitem_1130, getitem_1131, getitem_1132, getitem_1133, getitem_1134, getitem_1135, getitem_1136], 1); getitem_1129 = getitem_1130 = getitem_1131 = getitem_1132 = getitem_1133 = getitem_1134 = getitem_1135 = getitem_1136 = None + view_1435 = torch.ops.aten.view.default(cat_111, [16384, 4096]); cat_111 = None + permute_533 = torch.ops.aten.permute.default(view_1435, [1, 0]) + wait_tensor_60 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_9); reduce_scatter_tensor_9 = None + add_17 = torch.ops.aten.add.Tensor(add_15, wait_tensor_60); wait_tensor_60 = None + convert_element_type_152 = torch.ops.prims.convert_element_type.default(primals_45, torch.bfloat16); primals_45 = None + all_gather_into_tensor_51 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_152, 8, '0'); convert_element_type_152 = None + wait_tensor_61 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_51); all_gather_into_tensor_51 = None + convert_element_type_153 = torch.ops.prims.convert_element_type.default(add_17, 
torch.float32); add_17 = None + pow_10 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_153, 2) + mean_9 = torch.ops.aten.mean.dim(pow_10, [2], True); pow_10 = None + add_18 = torch.ops.aten.add.Scalar(mean_9, 1e-05); mean_9 = None + rsqrt_9 = torch.ops.aten.rsqrt.default(add_18); add_18 = None + mul_36 = torch.ops.aten.mul.Tensor(convert_element_type_153, rsqrt_9); convert_element_type_153 = None + mul_37 = torch.ops.aten.mul.Tensor(mul_36, wait_tensor_61) + convert_element_type_154 = torch.ops.prims.convert_element_type.default(mul_37, torch.bfloat16); mul_37 = None + all_gather_into_tensor_52 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_154, 8, '1'); convert_element_type_154 = None + wait_tensor_62 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_52); all_gather_into_tensor_52 = None + split_27 = torch.ops.aten.split.Tensor(wait_tensor_62, 2); wait_tensor_62 = None + getitem_261 = split_27[0] + getitem_262 = split_27[1] + getitem_263 = split_27[2] + getitem_264 = split_27[3] + getitem_265 = split_27[4] + getitem_266 = split_27[5] + getitem_267 = split_27[6] + getitem_268 = split_27[7]; split_27 = None + cat_19 = torch.ops.aten.cat.default([getitem_261, getitem_262, getitem_263, getitem_264, getitem_265, getitem_266, getitem_267, getitem_268], 1); getitem_261 = getitem_262 = getitem_263 = getitem_264 = getitem_265 = getitem_266 = getitem_267 = getitem_268 = None + view_348 = torch.ops.aten.view.default(cat_19, [16384, 4096]); cat_19 = None + view_349 = torch.ops.aten.view.default(mm_32, [2, 8192, 1792]); mm_32 = None + convert_element_type_158 = torch.ops.prims.convert_element_type.default(view_349, torch.float32); view_349 = None + sigmoid_4 = torch.ops.aten.sigmoid.default(convert_element_type_158) + mul_38 = torch.ops.aten.mul.Tensor(convert_element_type_158, sigmoid_4); sigmoid_4 = None + convert_element_type_159 = torch.ops.prims.convert_element_type.default(mul_38, torch.bfloat16); mul_38 = 
None + convert_element_type_160 = torch.ops.prims.convert_element_type.default(primals_47, torch.bfloat16); primals_47 = None + all_gather_into_tensor_54 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_160, 8, '0'); convert_element_type_160 = None + wait_tensor_64 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_54); all_gather_into_tensor_54 = None + permute_53 = torch.ops.aten.permute.default(wait_tensor_64, [1, 0]); wait_tensor_64 = None + mm_33 = torch.ops.aten.mm.default(view_348, permute_53) + view_356 = torch.ops.aten.view.default(mm_33, [2, 8192, 1792]); mm_33 = None + mul_39 = torch.ops.aten.mul.Tensor(convert_element_type_159, view_356) + view_363 = torch.ops.aten.view.default(mul_39, [16384, 1792]); mul_39 = None + mm_269 = torch.ops.aten.mm.default(permute_533, view_363); permute_533 = view_363 = None + convert_element_type_163 = torch.ops.prims.convert_element_type.default(primals_48, torch.bfloat16); primals_48 = None + all_gather_into_tensor_55 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_163, 8, '0'); convert_element_type_163 = None + wait_tensor_65 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_55); all_gather_into_tensor_55 = None + permute_54 = torch.ops.aten.permute.default(wait_tensor_65, [1, 0]); wait_tensor_65 = None + permute_535 = torch.ops.aten.permute.default(permute_54, [1, 0]); permute_54 = None + mm_270 = torch.ops.aten.mm.default(view_1435, permute_535); view_1435 = permute_535 = None + view_1436 = torch.ops.aten.view.default(mm_270, [2, 8192, 1792]); mm_270 = None + convert_element_type_1144 = torch.ops.prims.convert_element_type.default(mm_269, torch.float32); mm_269 = None + reduce_scatter_tensor_157 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1144, 'avg', 8, '0'); convert_element_type_1144 = None + wait_tensor_383 = 
torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_157); reduce_scatter_tensor_157 = None + mul_356 = torch.ops.aten.mul.Tensor(view_1436, convert_element_type_159); convert_element_type_159 = None + mul_357 = torch.ops.aten.mul.Tensor(view_1436, view_356); view_1436 = view_356 = None + view_1437 = torch.ops.aten.view.default(mul_356, [16384, 1792]); mul_356 = None + permute_537 = torch.ops.aten.permute.default(view_1437, [1, 0]) + mm_271 = torch.ops.aten.mm.default(permute_537, view_348); permute_537 = None + permute_539 = torch.ops.aten.permute.default(permute_53, [1, 0]); permute_53 = None + mm_272 = torch.ops.aten.mm.default(view_1437, permute_539); view_1437 = permute_539 = None + view_1438 = torch.ops.aten.view.default(mm_272, [2, 8192, 4096]); mm_272 = None + convert_element_type_1149 = torch.ops.prims.convert_element_type.default(mm_271, torch.float32); mm_271 = None + reduce_scatter_tensor_158 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1149, 'avg', 8, '0'); convert_element_type_1149 = None + wait_tensor_384 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_158); reduce_scatter_tensor_158 = None + convert_element_type_1150 = torch.ops.prims.convert_element_type.default(mul_357, torch.float32); mul_357 = None + neg_11 = torch.ops.aten.neg.default(convert_element_type_158) + exp_11 = torch.ops.aten.exp.default(neg_11); neg_11 = None + add_142 = torch.ops.aten.add.Tensor(exp_11, 1); exp_11 = None + reciprocal_11 = torch.ops.aten.reciprocal.default(add_142); add_142 = None + mul_358 = torch.ops.aten.mul.Tensor(reciprocal_11, 1); reciprocal_11 = None + mul_359 = torch.ops.aten.mul.Tensor(convert_element_type_1150, mul_358); convert_element_type_1150 = None + sub_35 = torch.ops.aten.sub.Tensor(1, mul_358); mul_358 = None + mul_360 = torch.ops.aten.mul.Tensor(convert_element_type_158, sub_35); convert_element_type_158 = sub_35 = None + add_143 = torch.ops.aten.add.Tensor(mul_360, 1); 
mul_360 = None + mul_361 = torch.ops.aten.mul.Tensor(mul_359, add_143); mul_359 = add_143 = None + convert_element_type_1152 = torch.ops.prims.convert_element_type.default(mul_361, torch.bfloat16); mul_361 = None + view_1439 = torch.ops.aten.view.default(convert_element_type_1152, [16384, 1792]); convert_element_type_1152 = None + permute_541 = torch.ops.aten.permute.default(view_1439, [1, 0]) + mm_273 = torch.ops.aten.mm.default(permute_541, view_348); permute_541 = view_348 = None + convert_element_type_155 = torch.ops.prims.convert_element_type.default(primals_46, torch.bfloat16); primals_46 = None + all_gather_into_tensor_53 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_155, 8, '0'); convert_element_type_155 = None + wait_tensor_63 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_53); all_gather_into_tensor_53 = None + permute_52 = torch.ops.aten.permute.default(wait_tensor_63, [1, 0]); wait_tensor_63 = None + permute_543 = torch.ops.aten.permute.default(permute_52, [1, 0]); permute_52 = None + mm_274 = torch.ops.aten.mm.default(view_1439, permute_543); view_1439 = permute_543 = None + view_1440 = torch.ops.aten.view.default(mm_274, [2, 8192, 4096]); mm_274 = None + add_144 = torch.ops.aten.add.Tensor(view_1438, view_1440); view_1438 = view_1440 = None + convert_element_type_1157 = torch.ops.prims.convert_element_type.default(mm_273, torch.float32); mm_273 = None + reduce_scatter_tensor_159 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1157, 'avg', 8, '0'); convert_element_type_1157 = None + wait_tensor_385 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_159); reduce_scatter_tensor_159 = None + split_120 = torch.ops.aten.split.Tensor(add_144, 1024, 1); add_144 = None + getitem_1137 = split_120[0] + getitem_1138 = split_120[1] + getitem_1139 = split_120[2] + getitem_1140 = split_120[3] + getitem_1141 = split_120[4] + getitem_1142 = split_120[5] + 
getitem_1143 = split_120[6] + getitem_1144 = split_120[7]; split_120 = None + cat_112 = torch.ops.aten.cat.default([getitem_1137, getitem_1138, getitem_1139, getitem_1140, getitem_1141, getitem_1142, getitem_1143, getitem_1144]); getitem_1137 = getitem_1138 = getitem_1139 = getitem_1140 = getitem_1141 = getitem_1142 = getitem_1143 = getitem_1144 = None + reduce_scatter_tensor_160 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_112, 'sum', 8, '1'); cat_112 = None + wait_tensor_386 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_160); reduce_scatter_tensor_160 = None + convert_element_type_1158 = torch.ops.prims.convert_element_type.default(wait_tensor_386, torch.float32); wait_tensor_386 = None + convert_element_type_1160 = torch.ops.prims.convert_element_type.default(wait_tensor_61, torch.float32); wait_tensor_61 = None + mul_362 = torch.ops.aten.mul.Tensor(convert_element_type_1158, convert_element_type_1160); convert_element_type_1160 = None + mul_364 = torch.ops.aten.mul.Tensor(mul_36, mul_362) + sum_69 = torch.ops.aten.sum.dim_IntList(mul_364, [2], True); mul_364 = None + div_23 = torch.ops.aten.div.Tensor(mul_36, 4096) + mul_365 = torch.ops.aten.mul.Tensor(div_23, sum_69); div_23 = sum_69 = None + sub_36 = torch.ops.aten.sub.Tensor(mul_362, mul_365); mul_362 = mul_365 = None + mul_366 = torch.ops.aten.mul.Tensor(sub_36, rsqrt_9); sub_36 = rsqrt_9 = None + mul_367 = torch.ops.aten.mul.Tensor(convert_element_type_1158, mul_36); convert_element_type_1158 = mul_36 = None + sum_70 = torch.ops.aten.sum.dim_IntList(mul_367, [0, 1]); mul_367 = None + convert_element_type_1161 = torch.ops.prims.convert_element_type.default(mul_366, torch.bfloat16); mul_366 = None + convert_element_type_1162 = torch.ops.prims.convert_element_type.default(sum_70, torch.bfloat16); sum_70 = None + all_reduce_23 = torch.ops._c10d_functional.all_reduce.default(convert_element_type_1162, 'sum', '1'); convert_element_type_1162 = None + wait_tensor_387 = 
torch.ops._c10d_functional.wait_tensor.default(all_reduce_23); all_reduce_23 = None + convert_element_type_1163 = torch.ops.prims.convert_element_type.default(wait_tensor_387, torch.float32); wait_tensor_387 = None + reduce_scatter_tensor_161 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1163, 'avg', 8, '0'); convert_element_type_1163 = None + wait_tensor_388 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_161); reduce_scatter_tensor_161 = None + add_145 = torch.ops.aten.add.Tensor(add_141, convert_element_type_1161); add_141 = convert_element_type_1161 = None + all_gather_into_tensor_203 = torch.ops._c10d_functional.all_gather_into_tensor.default(add_145, 8, '1') + wait_tensor_389 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_203); all_gather_into_tensor_203 = None + split_121 = torch.ops.aten.split.Tensor(wait_tensor_389, 2); wait_tensor_389 = None + getitem_1145 = split_121[0] + getitem_1146 = split_121[1] + getitem_1147 = split_121[2] + getitem_1148 = split_121[3] + getitem_1149 = split_121[4] + getitem_1150 = split_121[5] + getitem_1151 = split_121[6] + getitem_1152 = split_121[7]; split_121 = None + cat_113 = torch.ops.aten.cat.default([getitem_1145, getitem_1146, getitem_1147, getitem_1148, getitem_1149, getitem_1150, getitem_1151, getitem_1152], 1); getitem_1145 = getitem_1146 = getitem_1147 = getitem_1148 = getitem_1149 = getitem_1150 = getitem_1151 = getitem_1152 = None + view_1441 = torch.ops.aten.view.default(cat_113, [16384, 4096]); cat_113 = None + permute_545 = torch.ops.aten.permute.default(view_1441, [1, 0]) + permute_50 = torch.ops.aten.permute.default(getitem_244, [0, 2, 1, 3]) + view_330 = torch.ops.aten.view.default(permute_50, [2, 8192, -1]); permute_50 = None + view_336 = torch.ops.aten.view.default(view_330, [16384, 512]); view_330 = None + mm_275 = torch.ops.aten.mm.default(permute_545, view_336); permute_545 = view_336 = None + convert_element_type_149 = 
torch.ops.prims.convert_element_type.default(primals_44, torch.bfloat16); primals_44 = None + all_gather_into_tensor_50 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_149, 8, '0'); convert_element_type_149 = None + wait_tensor_59 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_50); all_gather_into_tensor_50 = None + permute_51 = torch.ops.aten.permute.default(wait_tensor_59, [1, 0]); wait_tensor_59 = None + permute_547 = torch.ops.aten.permute.default(permute_51, [1, 0]); permute_51 = None + mm_276 = torch.ops.aten.mm.default(view_1441, permute_547); view_1441 = permute_547 = None + view_1442 = torch.ops.aten.view.default(mm_276, [2, 8192, 512]); mm_276 = None + convert_element_type_1168 = torch.ops.prims.convert_element_type.default(mm_275, torch.float32); mm_275 = None + reduce_scatter_tensor_162 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1168, 'avg', 8, '0'); convert_element_type_1168 = None + wait_tensor_390 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_162); reduce_scatter_tensor_162 = None + view_1443 = torch.ops.aten.view.default(view_1442, [2, 8192, 4, 128]); view_1442 = None + permute_549 = torch.ops.aten.permute.default(view_1443, [0, 2, 1, 3]); view_1443 = None + convert_element_type_133 = torch.ops.prims.convert_element_type.default(primals_40, torch.bfloat16); primals_40 = None + all_gather_into_tensor_45 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_133, 8, '0'); convert_element_type_133 = None + wait_tensor_54 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_45); all_gather_into_tensor_45 = None + convert_element_type_134 = torch.ops.prims.convert_element_type.default(add_15, torch.float32); add_15 = None + pow_9 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_134, 2) + mean_8 = torch.ops.aten.mean.dim(pow_9, [2], True); pow_9 = None + add_16 = 
torch.ops.aten.add.Scalar(mean_8, 1e-05); mean_8 = None + rsqrt_8 = torch.ops.aten.rsqrt.default(add_16); add_16 = None + mul_32 = torch.ops.aten.mul.Tensor(convert_element_type_134, rsqrt_8); convert_element_type_134 = None + mul_33 = torch.ops.aten.mul.Tensor(mul_32, wait_tensor_54) + convert_element_type_135 = torch.ops.prims.convert_element_type.default(mul_33, torch.bfloat16); mul_33 = None + all_gather_into_tensor_46 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_135, 8, '1'); convert_element_type_135 = None + wait_tensor_55 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_46); all_gather_into_tensor_46 = None + split_25 = torch.ops.aten.split.Tensor(wait_tensor_55, 2); wait_tensor_55 = None + getitem_236 = split_25[0] + getitem_237 = split_25[1] + getitem_238 = split_25[2] + getitem_239 = split_25[3] + getitem_240 = split_25[4] + getitem_241 = split_25[5] + getitem_242 = split_25[6] + getitem_243 = split_25[7]; split_25 = None + cat_17 = torch.ops.aten.cat.default([getitem_236, getitem_237, getitem_238, getitem_239, getitem_240, getitem_241, getitem_242, getitem_243], 1); getitem_236 = getitem_237 = getitem_238 = getitem_239 = getitem_240 = getitem_241 = getitem_242 = getitem_243 = None + view_303 = torch.ops.aten.view.default(cat_17, [16384, 4096]); cat_17 = None + view_304 = torch.ops.aten.view.default(mm_28, [2, 8192, 512]); mm_28 = None + convert_element_type_139 = torch.ops.prims.convert_element_type.default(primals_42, torch.bfloat16); primals_42 = None + all_gather_into_tensor_48 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_139, 8, '0'); convert_element_type_139 = None + wait_tensor_57 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_48); all_gather_into_tensor_48 = None + permute_45 = torch.ops.aten.permute.default(wait_tensor_57, [1, 0]); wait_tensor_57 = None + mm_29 = torch.ops.aten.mm.default(view_303, permute_45) + view_311 = 
torch.ops.aten.view.default(mm_29, [2, 8192, 128]); mm_29 = None + view_318 = torch.ops.aten.view.default(mm_30, [2, 8192, 128]); mm_30 = None + view_320 = torch.ops.aten.view.default(view_304, [2, 8192, -1, 128]); view_304 = None + view_321 = torch.ops.aten.view.default(view_311, [2, 8192, -1, 128]); view_311 = None + view_322 = torch.ops.aten.view.default(view_318, [2, 8192, -1, 128]); view_318 = None + convert_element_type_145 = torch.ops.prims.convert_element_type.default(view_320, torch.float32); view_320 = None + view_323 = torch.ops.aten.view.default(convert_element_type_145, [2, 8192, 4, -1, 2]); convert_element_type_145 = None + view_as_complex_8 = torch.ops.aten.view_as_complex.default(view_323); view_323 = None + convert_element_type_146 = torch.ops.prims.convert_element_type.default(view_321, torch.float32); view_321 = None + view_324 = torch.ops.aten.view.default(convert_element_type_146, [2, 8192, 1, -1, 2]); convert_element_type_146 = None + view_as_complex_9 = torch.ops.aten.view_as_complex.default(view_324); view_324 = None + mul_34 = torch.ops.aten.mul.Tensor(view_as_complex_8, view_37); view_as_complex_8 = None + view_as_real_8 = torch.ops.aten.view_as_real.default(mul_34); mul_34 = None + view_326 = torch.ops.aten.view.default(view_as_real_8, [2, 8192, 4, 128]); view_as_real_8 = None + mul_35 = torch.ops.aten.mul.Tensor(view_as_complex_9, view_37); view_as_complex_9 = None + view_as_real_9 = torch.ops.aten.view_as_real.default(mul_35); mul_35 = None + view_327 = torch.ops.aten.view.default(view_as_real_9, [2, 8192, 1, 128]); view_as_real_9 = None + convert_element_type_147 = torch.ops.prims.convert_element_type.default(view_326, torch.bfloat16); view_326 = None + convert_element_type_148 = torch.ops.prims.convert_element_type.default(view_327, torch.bfloat16); view_327 = None + unsqueeze_8 = torch.ops.aten.unsqueeze.default(convert_element_type_148, 3); convert_element_type_148 = None + expand_8 = torch.ops.aten.expand.default(unsqueeze_8, [2, 
8192, 1, 4, 128]); unsqueeze_8 = None + view_328 = torch.ops.aten.view.default(expand_8, [2, 8192, 4, 128]); expand_8 = None + unsqueeze_9 = torch.ops.aten.unsqueeze.default(view_322, 3); view_322 = None + expand_9 = torch.ops.aten.expand.default(unsqueeze_9, [2, 8192, 1, 4, 128]); unsqueeze_9 = None + view_329 = torch.ops.aten.view.default(expand_9, [2, 8192, 4, 128]); expand_9 = None + permute_47 = torch.ops.aten.permute.default(convert_element_type_147, [0, 2, 1, 3]); convert_element_type_147 = None + permute_48 = torch.ops.aten.permute.default(view_328, [0, 2, 1, 3]); view_328 = None + permute_49 = torch.ops.aten.permute.default(view_329, [0, 2, 1, 3]); view_329 = None + _scaled_dot_product_cudnn_attention_backward_11 = torch.ops.aten._scaled_dot_product_cudnn_attention_backward.default(permute_549, permute_47, permute_48, permute_49, getitem_244, getitem_245, getitem_250, getitem_251, None, None, None, 8192, 8192, 0.0, True); permute_549 = permute_47 = permute_48 = permute_49 = getitem_244 = getitem_245 = getitem_250 = getitem_251 = None + getitem_1153 = _scaled_dot_product_cudnn_attention_backward_11[0] + getitem_1154 = _scaled_dot_product_cudnn_attention_backward_11[1] + getitem_1155 = _scaled_dot_product_cudnn_attention_backward_11[2]; _scaled_dot_product_cudnn_attention_backward_11 = None + permute_550 = torch.ops.aten.permute.default(getitem_1155, [0, 2, 1, 3]); getitem_1155 = None + permute_551 = torch.ops.aten.permute.default(getitem_1154, [0, 2, 1, 3]); getitem_1154 = None + permute_552 = torch.ops.aten.permute.default(getitem_1153, [0, 2, 1, 3]); getitem_1153 = None + view_1444 = torch.ops.aten.view.default(permute_550, [2, 8192, 1, 4, 128]); permute_550 = None + sum_71 = torch.ops.aten.sum.dim_IntList(view_1444, [3], True); view_1444 = None + squeeze_22 = torch.ops.aten.squeeze.dim(sum_71, 3); sum_71 = None + view_1445 = torch.ops.aten.view.default(permute_551, [2, 8192, 1, 4, 128]); permute_551 = None + sum_72 = 
torch.ops.aten.sum.dim_IntList(view_1445, [3], True); view_1445 = None + squeeze_23 = torch.ops.aten.squeeze.dim(sum_72, 3); sum_72 = None + convert_element_type_1169 = torch.ops.prims.convert_element_type.default(squeeze_23, torch.float32); squeeze_23 = None + convert_element_type_1170 = torch.ops.prims.convert_element_type.default(permute_552, torch.float32); permute_552 = None + view_1446 = torch.ops.aten.view.default(convert_element_type_1169, [2, 8192, 1, 64, 2]); convert_element_type_1169 = None + view_as_complex_54 = torch.ops.aten.view_as_complex.default(view_1446); view_1446 = None + mul_368 = torch.ops.aten.mul.Tensor(view_as_complex_54, _conj); view_as_complex_54 = None + view_1447 = torch.ops.aten.view.default(convert_element_type_1170, [2, 8192, 4, 64, 2]); convert_element_type_1170 = None + view_as_complex_55 = torch.ops.aten.view_as_complex.default(view_1447); view_1447 = None + mul_369 = torch.ops.aten.mul.Tensor(view_as_complex_55, _conj); view_as_complex_55 = None + view_as_real_54 = torch.ops.aten.view_as_real.default(mul_368); mul_368 = None + view_1448 = torch.ops.aten.view.default(view_as_real_54, [2, 8192, 1, 128]); view_as_real_54 = None + convert_element_type_1171 = torch.ops.prims.convert_element_type.default(view_1448, torch.bfloat16); view_1448 = None + view_as_real_55 = torch.ops.aten.view_as_real.default(mul_369); mul_369 = None + view_1449 = torch.ops.aten.view.default(view_as_real_55, [2, 8192, 4, 128]); view_as_real_55 = None + convert_element_type_1172 = torch.ops.prims.convert_element_type.default(view_1449, torch.bfloat16); view_1449 = None + view_1450 = torch.ops.aten.view.default(squeeze_22, [2, 8192, 128]); squeeze_22 = None + view_1451 = torch.ops.aten.view.default(convert_element_type_1171, [2, 8192, 128]); convert_element_type_1171 = None + view_1452 = torch.ops.aten.view.default(convert_element_type_1172, [2, 8192, 512]); convert_element_type_1172 = None + view_1453 = torch.ops.aten.view.default(view_1450, [16384, 128]); 
view_1450 = None + permute_553 = torch.ops.aten.permute.default(view_1453, [1, 0]) + mm_277 = torch.ops.aten.mm.default(permute_553, view_303); permute_553 = None + convert_element_type_142 = torch.ops.prims.convert_element_type.default(primals_43, torch.bfloat16); primals_43 = None + all_gather_into_tensor_49 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_142, 8, '0'); convert_element_type_142 = None + wait_tensor_58 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_49); all_gather_into_tensor_49 = None + permute_46 = torch.ops.aten.permute.default(wait_tensor_58, [1, 0]); wait_tensor_58 = None + permute_555 = torch.ops.aten.permute.default(permute_46, [1, 0]); permute_46 = None + mm_278 = torch.ops.aten.mm.default(view_1453, permute_555); view_1453 = permute_555 = None + view_1454 = torch.ops.aten.view.default(mm_278, [2, 8192, 4096]); mm_278 = None + convert_element_type_1177 = torch.ops.prims.convert_element_type.default(mm_277, torch.float32); mm_277 = None + reduce_scatter_tensor_163 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1177, 'avg', 8, '0'); convert_element_type_1177 = None + wait_tensor_391 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_163); reduce_scatter_tensor_163 = None + view_1455 = torch.ops.aten.view.default(view_1451, [16384, 128]); view_1451 = None + permute_557 = torch.ops.aten.permute.default(view_1455, [1, 0]) + mm_279 = torch.ops.aten.mm.default(permute_557, view_303); permute_557 = None + permute_559 = torch.ops.aten.permute.default(permute_45, [1, 0]); permute_45 = None + mm_280 = torch.ops.aten.mm.default(view_1455, permute_559); view_1455 = permute_559 = None + view_1456 = torch.ops.aten.view.default(mm_280, [2, 8192, 4096]); mm_280 = None + add_146 = torch.ops.aten.add.Tensor(view_1454, view_1456); view_1454 = view_1456 = None + convert_element_type_1182 = torch.ops.prims.convert_element_type.default(mm_279, 
torch.float32); mm_279 = None + reduce_scatter_tensor_164 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1182, 'avg', 8, '0'); convert_element_type_1182 = None + wait_tensor_392 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_164); reduce_scatter_tensor_164 = None + view_1457 = torch.ops.aten.view.default(view_1452, [16384, 512]); view_1452 = None + permute_561 = torch.ops.aten.permute.default(view_1457, [1, 0]) + mm_281 = torch.ops.aten.mm.default(permute_561, view_303); permute_561 = view_303 = None + convert_element_type_136 = torch.ops.prims.convert_element_type.default(primals_41, torch.bfloat16); primals_41 = None + all_gather_into_tensor_47 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_136, 8, '0'); convert_element_type_136 = None + wait_tensor_56 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_47); all_gather_into_tensor_47 = None + permute_44 = torch.ops.aten.permute.default(wait_tensor_56, [1, 0]); wait_tensor_56 = None + permute_563 = torch.ops.aten.permute.default(permute_44, [1, 0]); permute_44 = None + mm_282 = torch.ops.aten.mm.default(view_1457, permute_563); view_1457 = permute_563 = None + view_1458 = torch.ops.aten.view.default(mm_282, [2, 8192, 4096]); mm_282 = None + add_147 = torch.ops.aten.add.Tensor(add_146, view_1458); add_146 = view_1458 = None + convert_element_type_1187 = torch.ops.prims.convert_element_type.default(mm_281, torch.float32); mm_281 = None + reduce_scatter_tensor_165 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1187, 'avg', 8, '0'); convert_element_type_1187 = None + wait_tensor_393 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_165); reduce_scatter_tensor_165 = None + split_122 = torch.ops.aten.split.Tensor(add_147, 1024, 1); add_147 = None + getitem_1156 = split_122[0] + getitem_1157 = split_122[1] + getitem_1158 = split_122[2] + getitem_1159 = 
split_122[3] + getitem_1160 = split_122[4] + getitem_1161 = split_122[5] + getitem_1162 = split_122[6] + getitem_1163 = split_122[7]; split_122 = None + cat_114 = torch.ops.aten.cat.default([getitem_1156, getitem_1157, getitem_1158, getitem_1159, getitem_1160, getitem_1161, getitem_1162, getitem_1163]); getitem_1156 = getitem_1157 = getitem_1158 = getitem_1159 = getitem_1160 = getitem_1161 = getitem_1162 = getitem_1163 = None + reduce_scatter_tensor_166 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_114, 'sum', 8, '1'); cat_114 = None + wait_tensor_394 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_166); reduce_scatter_tensor_166 = None + convert_element_type_1188 = torch.ops.prims.convert_element_type.default(wait_tensor_394, torch.float32); wait_tensor_394 = None + convert_element_type_1190 = torch.ops.prims.convert_element_type.default(wait_tensor_54, torch.float32); wait_tensor_54 = None + mul_370 = torch.ops.aten.mul.Tensor(convert_element_type_1188, convert_element_type_1190); convert_element_type_1190 = None + mul_372 = torch.ops.aten.mul.Tensor(mul_32, mul_370) + sum_73 = torch.ops.aten.sum.dim_IntList(mul_372, [2], True); mul_372 = None + div_24 = torch.ops.aten.div.Tensor(mul_32, 4096) + mul_373 = torch.ops.aten.mul.Tensor(div_24, sum_73); div_24 = sum_73 = None + sub_37 = torch.ops.aten.sub.Tensor(mul_370, mul_373); mul_370 = mul_373 = None + mul_374 = torch.ops.aten.mul.Tensor(sub_37, rsqrt_8); sub_37 = rsqrt_8 = None + mul_375 = torch.ops.aten.mul.Tensor(convert_element_type_1188, mul_32); convert_element_type_1188 = mul_32 = None + sum_74 = torch.ops.aten.sum.dim_IntList(mul_375, [0, 1]); mul_375 = None + convert_element_type_1191 = torch.ops.prims.convert_element_type.default(mul_374, torch.bfloat16); mul_374 = None + convert_element_type_1192 = torch.ops.prims.convert_element_type.default(sum_74, torch.bfloat16); sum_74 = None + all_reduce_24 = 
torch.ops._c10d_functional.all_reduce.default(convert_element_type_1192, 'sum', '1'); convert_element_type_1192 = None + wait_tensor_395 = torch.ops._c10d_functional.wait_tensor.default(all_reduce_24); all_reduce_24 = None + convert_element_type_1193 = torch.ops.prims.convert_element_type.default(wait_tensor_395, torch.float32); wait_tensor_395 = None + reduce_scatter_tensor_167 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1193, 'avg', 8, '0'); convert_element_type_1193 = None + wait_tensor_396 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_167); reduce_scatter_tensor_167 = None + add_148 = torch.ops.aten.add.Tensor(add_145, convert_element_type_1191); add_145 = convert_element_type_1191 = None + all_gather_into_tensor_204 = torch.ops._c10d_functional.all_gather_into_tensor.default(add_148, 8, '1') + wait_tensor_397 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_204); all_gather_into_tensor_204 = None + split_123 = torch.ops.aten.split.Tensor(wait_tensor_397, 2); wait_tensor_397 = None + getitem_1164 = split_123[0] + getitem_1165 = split_123[1] + getitem_1166 = split_123[2] + getitem_1167 = split_123[3] + getitem_1168 = split_123[4] + getitem_1169 = split_123[5] + getitem_1170 = split_123[6] + getitem_1171 = split_123[7]; split_123 = None + cat_115 = torch.ops.aten.cat.default([getitem_1164, getitem_1165, getitem_1166, getitem_1167, getitem_1168, getitem_1169, getitem_1170, getitem_1171], 1); getitem_1164 = getitem_1165 = getitem_1166 = getitem_1167 = getitem_1168 = getitem_1169 = getitem_1170 = getitem_1171 = None + view_1459 = torch.ops.aten.view.default(cat_115, [16384, 4096]); cat_115 = None + permute_565 = torch.ops.aten.permute.default(view_1459, [1, 0]) + wait_tensor_47 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_7); reduce_scatter_tensor_7 = None + add_13 = torch.ops.aten.add.Tensor(add_11, wait_tensor_47); wait_tensor_47 = None + 
convert_element_type_119 = torch.ops.prims.convert_element_type.default(primals_36, torch.bfloat16); primals_36 = None + all_gather_into_tensor_40 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_119, 8, '0'); convert_element_type_119 = None + wait_tensor_48 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_40); all_gather_into_tensor_40 = None + convert_element_type_120 = torch.ops.prims.convert_element_type.default(add_13, torch.float32); add_13 = None + pow_8 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_120, 2) + mean_7 = torch.ops.aten.mean.dim(pow_8, [2], True); pow_8 = None + add_14 = torch.ops.aten.add.Scalar(mean_7, 1e-05); mean_7 = None + rsqrt_7 = torch.ops.aten.rsqrt.default(add_14); add_14 = None + mul_28 = torch.ops.aten.mul.Tensor(convert_element_type_120, rsqrt_7); convert_element_type_120 = None + mul_29 = torch.ops.aten.mul.Tensor(mul_28, wait_tensor_48) + convert_element_type_121 = torch.ops.prims.convert_element_type.default(mul_29, torch.bfloat16); mul_29 = None + all_gather_into_tensor_41 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_121, 8, '1'); convert_element_type_121 = None + wait_tensor_49 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_41); all_gather_into_tensor_41 = None + split_23 = torch.ops.aten.split.Tensor(wait_tensor_49, 2); wait_tensor_49 = None + getitem_220 = split_23[0] + getitem_221 = split_23[1] + getitem_222 = split_23[2] + getitem_223 = split_23[3] + getitem_224 = split_23[4] + getitem_225 = split_23[5] + getitem_226 = split_23[6] + getitem_227 = split_23[7]; split_23 = None + cat_15 = torch.ops.aten.cat.default([getitem_220, getitem_221, getitem_222, getitem_223, getitem_224, getitem_225, getitem_226, getitem_227], 1); getitem_220 = getitem_221 = getitem_222 = getitem_223 = getitem_224 = getitem_225 = getitem_226 = getitem_227 = None + view_276 = torch.ops.aten.view.default(cat_15, [16384, 4096]); 
cat_15 = None + view_277 = torch.ops.aten.view.default(mm_25, [2, 8192, 1792]); mm_25 = None + convert_element_type_125 = torch.ops.prims.convert_element_type.default(view_277, torch.float32); view_277 = None + sigmoid_3 = torch.ops.aten.sigmoid.default(convert_element_type_125) + mul_30 = torch.ops.aten.mul.Tensor(convert_element_type_125, sigmoid_3); sigmoid_3 = None + convert_element_type_126 = torch.ops.prims.convert_element_type.default(mul_30, torch.bfloat16); mul_30 = None + convert_element_type_127 = torch.ops.prims.convert_element_type.default(primals_38, torch.bfloat16); primals_38 = None + all_gather_into_tensor_43 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_127, 8, '0'); convert_element_type_127 = None + wait_tensor_51 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_43); all_gather_into_tensor_43 = None + permute_42 = torch.ops.aten.permute.default(wait_tensor_51, [1, 0]); wait_tensor_51 = None + mm_26 = torch.ops.aten.mm.default(view_276, permute_42) + view_284 = torch.ops.aten.view.default(mm_26, [2, 8192, 1792]); mm_26 = None + mul_31 = torch.ops.aten.mul.Tensor(convert_element_type_126, view_284) + view_291 = torch.ops.aten.view.default(mul_31, [16384, 1792]); mul_31 = None + mm_283 = torch.ops.aten.mm.default(permute_565, view_291); permute_565 = view_291 = None + convert_element_type_130 = torch.ops.prims.convert_element_type.default(primals_39, torch.bfloat16); primals_39 = None + all_gather_into_tensor_44 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_130, 8, '0'); convert_element_type_130 = None + wait_tensor_52 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_44); all_gather_into_tensor_44 = None + permute_43 = torch.ops.aten.permute.default(wait_tensor_52, [1, 0]); wait_tensor_52 = None + permute_567 = torch.ops.aten.permute.default(permute_43, [1, 0]); permute_43 = None + mm_284 = torch.ops.aten.mm.default(view_1459, 
permute_567); view_1459 = permute_567 = None + view_1460 = torch.ops.aten.view.default(mm_284, [2, 8192, 1792]); mm_284 = None + convert_element_type_1198 = torch.ops.prims.convert_element_type.default(mm_283, torch.float32); mm_283 = None + reduce_scatter_tensor_168 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1198, 'avg', 8, '0'); convert_element_type_1198 = None + wait_tensor_398 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_168); reduce_scatter_tensor_168 = None + mul_376 = torch.ops.aten.mul.Tensor(view_1460, convert_element_type_126); convert_element_type_126 = None + mul_377 = torch.ops.aten.mul.Tensor(view_1460, view_284); view_1460 = view_284 = None + view_1461 = torch.ops.aten.view.default(mul_376, [16384, 1792]); mul_376 = None + permute_569 = torch.ops.aten.permute.default(view_1461, [1, 0]) + mm_285 = torch.ops.aten.mm.default(permute_569, view_276); permute_569 = None + permute_571 = torch.ops.aten.permute.default(permute_42, [1, 0]); permute_42 = None + mm_286 = torch.ops.aten.mm.default(view_1461, permute_571); view_1461 = permute_571 = None + view_1462 = torch.ops.aten.view.default(mm_286, [2, 8192, 4096]); mm_286 = None + convert_element_type_1203 = torch.ops.prims.convert_element_type.default(mm_285, torch.float32); mm_285 = None + reduce_scatter_tensor_169 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1203, 'avg', 8, '0'); convert_element_type_1203 = None + wait_tensor_399 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_169); reduce_scatter_tensor_169 = None + convert_element_type_1204 = torch.ops.prims.convert_element_type.default(mul_377, torch.float32); mul_377 = None + neg_12 = torch.ops.aten.neg.default(convert_element_type_125) + exp_12 = torch.ops.aten.exp.default(neg_12); neg_12 = None + add_149 = torch.ops.aten.add.Tensor(exp_12, 1); exp_12 = None + reciprocal_12 = torch.ops.aten.reciprocal.default(add_149); add_149 = 
None + mul_378 = torch.ops.aten.mul.Tensor(reciprocal_12, 1); reciprocal_12 = None + mul_379 = torch.ops.aten.mul.Tensor(convert_element_type_1204, mul_378); convert_element_type_1204 = None + sub_38 = torch.ops.aten.sub.Tensor(1, mul_378); mul_378 = None + mul_380 = torch.ops.aten.mul.Tensor(convert_element_type_125, sub_38); convert_element_type_125 = sub_38 = None + add_150 = torch.ops.aten.add.Tensor(mul_380, 1); mul_380 = None + mul_381 = torch.ops.aten.mul.Tensor(mul_379, add_150); mul_379 = add_150 = None + convert_element_type_1206 = torch.ops.prims.convert_element_type.default(mul_381, torch.bfloat16); mul_381 = None + view_1463 = torch.ops.aten.view.default(convert_element_type_1206, [16384, 1792]); convert_element_type_1206 = None + permute_573 = torch.ops.aten.permute.default(view_1463, [1, 0]) + mm_287 = torch.ops.aten.mm.default(permute_573, view_276); permute_573 = view_276 = None + convert_element_type_122 = torch.ops.prims.convert_element_type.default(primals_37, torch.bfloat16); primals_37 = None + all_gather_into_tensor_42 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_122, 8, '0'); convert_element_type_122 = None + wait_tensor_50 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_42); all_gather_into_tensor_42 = None + permute_41 = torch.ops.aten.permute.default(wait_tensor_50, [1, 0]); wait_tensor_50 = None + permute_575 = torch.ops.aten.permute.default(permute_41, [1, 0]); permute_41 = None + mm_288 = torch.ops.aten.mm.default(view_1463, permute_575); view_1463 = permute_575 = None + view_1464 = torch.ops.aten.view.default(mm_288, [2, 8192, 4096]); mm_288 = None + add_151 = torch.ops.aten.add.Tensor(view_1462, view_1464); view_1462 = view_1464 = None + convert_element_type_1211 = torch.ops.prims.convert_element_type.default(mm_287, torch.float32); mm_287 = None + reduce_scatter_tensor_170 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1211, 'avg', 8, '0'); 
convert_element_type_1211 = None + wait_tensor_400 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_170); reduce_scatter_tensor_170 = None + split_124 = torch.ops.aten.split.Tensor(add_151, 1024, 1); add_151 = None + getitem_1172 = split_124[0] + getitem_1173 = split_124[1] + getitem_1174 = split_124[2] + getitem_1175 = split_124[3] + getitem_1176 = split_124[4] + getitem_1177 = split_124[5] + getitem_1178 = split_124[6] + getitem_1179 = split_124[7]; split_124 = None + cat_116 = torch.ops.aten.cat.default([getitem_1172, getitem_1173, getitem_1174, getitem_1175, getitem_1176, getitem_1177, getitem_1178, getitem_1179]); getitem_1172 = getitem_1173 = getitem_1174 = getitem_1175 = getitem_1176 = getitem_1177 = getitem_1178 = getitem_1179 = None + reduce_scatter_tensor_171 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_116, 'sum', 8, '1'); cat_116 = None + wait_tensor_401 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_171); reduce_scatter_tensor_171 = None + convert_element_type_1212 = torch.ops.prims.convert_element_type.default(wait_tensor_401, torch.float32); wait_tensor_401 = None + convert_element_type_1214 = torch.ops.prims.convert_element_type.default(wait_tensor_48, torch.float32); wait_tensor_48 = None + mul_382 = torch.ops.aten.mul.Tensor(convert_element_type_1212, convert_element_type_1214); convert_element_type_1214 = None + mul_384 = torch.ops.aten.mul.Tensor(mul_28, mul_382) + sum_75 = torch.ops.aten.sum.dim_IntList(mul_384, [2], True); mul_384 = None + div_25 = torch.ops.aten.div.Tensor(mul_28, 4096) + mul_385 = torch.ops.aten.mul.Tensor(div_25, sum_75); div_25 = sum_75 = None + sub_39 = torch.ops.aten.sub.Tensor(mul_382, mul_385); mul_382 = mul_385 = None + mul_386 = torch.ops.aten.mul.Tensor(sub_39, rsqrt_7); sub_39 = rsqrt_7 = None + mul_387 = torch.ops.aten.mul.Tensor(convert_element_type_1212, mul_28); convert_element_type_1212 = mul_28 = None + sum_76 = 
torch.ops.aten.sum.dim_IntList(mul_387, [0, 1]); mul_387 = None + convert_element_type_1215 = torch.ops.prims.convert_element_type.default(mul_386, torch.bfloat16); mul_386 = None + convert_element_type_1216 = torch.ops.prims.convert_element_type.default(sum_76, torch.bfloat16); sum_76 = None + all_reduce_25 = torch.ops._c10d_functional.all_reduce.default(convert_element_type_1216, 'sum', '1'); convert_element_type_1216 = None + wait_tensor_402 = torch.ops._c10d_functional.wait_tensor.default(all_reduce_25); all_reduce_25 = None + convert_element_type_1217 = torch.ops.prims.convert_element_type.default(wait_tensor_402, torch.float32); wait_tensor_402 = None + reduce_scatter_tensor_172 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1217, 'avg', 8, '0'); convert_element_type_1217 = None + wait_tensor_403 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_172); reduce_scatter_tensor_172 = None + add_152 = torch.ops.aten.add.Tensor(add_148, convert_element_type_1215); add_148 = convert_element_type_1215 = None + all_gather_into_tensor_205 = torch.ops._c10d_functional.all_gather_into_tensor.default(add_152, 8, '1') + wait_tensor_404 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_205); all_gather_into_tensor_205 = None + split_125 = torch.ops.aten.split.Tensor(wait_tensor_404, 2); wait_tensor_404 = None + getitem_1180 = split_125[0] + getitem_1181 = split_125[1] + getitem_1182 = split_125[2] + getitem_1183 = split_125[3] + getitem_1184 = split_125[4] + getitem_1185 = split_125[5] + getitem_1186 = split_125[6] + getitem_1187 = split_125[7]; split_125 = None + cat_117 = torch.ops.aten.cat.default([getitem_1180, getitem_1181, getitem_1182, getitem_1183, getitem_1184, getitem_1185, getitem_1186, getitem_1187], 1); getitem_1180 = getitem_1181 = getitem_1182 = getitem_1183 = getitem_1184 = getitem_1185 = getitem_1186 = getitem_1187 = None + view_1465 = torch.ops.aten.view.default(cat_117, [16384, 
4096]); cat_117 = None + permute_577 = torch.ops.aten.permute.default(view_1465, [1, 0]) + permute_39 = torch.ops.aten.permute.default(getitem_203, [0, 2, 1, 3]) + view_258 = torch.ops.aten.view.default(permute_39, [2, 8192, -1]); permute_39 = None + view_264 = torch.ops.aten.view.default(view_258, [16384, 512]); view_258 = None + mm_289 = torch.ops.aten.mm.default(permute_577, view_264); permute_577 = view_264 = None + convert_element_type_116 = torch.ops.prims.convert_element_type.default(primals_35, torch.bfloat16); primals_35 = None + all_gather_into_tensor_39 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_116, 8, '0'); convert_element_type_116 = None + wait_tensor_46 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_39); all_gather_into_tensor_39 = None + permute_40 = torch.ops.aten.permute.default(wait_tensor_46, [1, 0]); wait_tensor_46 = None + permute_579 = torch.ops.aten.permute.default(permute_40, [1, 0]); permute_40 = None + mm_290 = torch.ops.aten.mm.default(view_1465, permute_579); view_1465 = permute_579 = None + view_1466 = torch.ops.aten.view.default(mm_290, [2, 8192, 512]); mm_290 = None + convert_element_type_1222 = torch.ops.prims.convert_element_type.default(mm_289, torch.float32); mm_289 = None + reduce_scatter_tensor_173 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1222, 'avg', 8, '0'); convert_element_type_1222 = None + wait_tensor_405 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_173); reduce_scatter_tensor_173 = None + view_1467 = torch.ops.aten.view.default(view_1466, [2, 8192, 4, 128]); view_1466 = None + permute_581 = torch.ops.aten.permute.default(view_1467, [0, 2, 1, 3]); view_1467 = None + convert_element_type_100 = torch.ops.prims.convert_element_type.default(primals_31, torch.bfloat16); primals_31 = None + all_gather_into_tensor_34 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_100, 8, 
'0'); convert_element_type_100 = None + wait_tensor_41 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_34); all_gather_into_tensor_34 = None + convert_element_type_101 = torch.ops.prims.convert_element_type.default(add_11, torch.float32); add_11 = None + pow_7 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_101, 2) + mean_6 = torch.ops.aten.mean.dim(pow_7, [2], True); pow_7 = None + add_12 = torch.ops.aten.add.Scalar(mean_6, 1e-05); mean_6 = None + rsqrt_6 = torch.ops.aten.rsqrt.default(add_12); add_12 = None + mul_24 = torch.ops.aten.mul.Tensor(convert_element_type_101, rsqrt_6); convert_element_type_101 = None + mul_25 = torch.ops.aten.mul.Tensor(mul_24, wait_tensor_41) + convert_element_type_102 = torch.ops.prims.convert_element_type.default(mul_25, torch.bfloat16); mul_25 = None + all_gather_into_tensor_35 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_102, 8, '1'); convert_element_type_102 = None + wait_tensor_42 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_35); all_gather_into_tensor_35 = None + split_21 = torch.ops.aten.split.Tensor(wait_tensor_42, 2); wait_tensor_42 = None + getitem_195 = split_21[0] + getitem_196 = split_21[1] + getitem_197 = split_21[2] + getitem_198 = split_21[3] + getitem_199 = split_21[4] + getitem_200 = split_21[5] + getitem_201 = split_21[6] + getitem_202 = split_21[7]; split_21 = None + cat_13 = torch.ops.aten.cat.default([getitem_195, getitem_196, getitem_197, getitem_198, getitem_199, getitem_200, getitem_201, getitem_202], 1); getitem_195 = getitem_196 = getitem_197 = getitem_198 = getitem_199 = getitem_200 = getitem_201 = getitem_202 = None + view_231 = torch.ops.aten.view.default(cat_13, [16384, 4096]); cat_13 = None + view_232 = torch.ops.aten.view.default(mm_21, [2, 8192, 512]); mm_21 = None + convert_element_type_106 = torch.ops.prims.convert_element_type.default(primals_33, torch.bfloat16); primals_33 = None + 
all_gather_into_tensor_37 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_106, 8, '0'); convert_element_type_106 = None + wait_tensor_44 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_37); all_gather_into_tensor_37 = None + permute_34 = torch.ops.aten.permute.default(wait_tensor_44, [1, 0]); wait_tensor_44 = None + mm_22 = torch.ops.aten.mm.default(view_231, permute_34) + view_239 = torch.ops.aten.view.default(mm_22, [2, 8192, 128]); mm_22 = None + view_246 = torch.ops.aten.view.default(mm_23, [2, 8192, 128]); mm_23 = None + view_248 = torch.ops.aten.view.default(view_232, [2, 8192, -1, 128]); view_232 = None + view_249 = torch.ops.aten.view.default(view_239, [2, 8192, -1, 128]); view_239 = None + view_250 = torch.ops.aten.view.default(view_246, [2, 8192, -1, 128]); view_246 = None + convert_element_type_112 = torch.ops.prims.convert_element_type.default(view_248, torch.float32); view_248 = None + view_251 = torch.ops.aten.view.default(convert_element_type_112, [2, 8192, 4, -1, 2]); convert_element_type_112 = None + view_as_complex_6 = torch.ops.aten.view_as_complex.default(view_251); view_251 = None + convert_element_type_113 = torch.ops.prims.convert_element_type.default(view_249, torch.float32); view_249 = None + view_252 = torch.ops.aten.view.default(convert_element_type_113, [2, 8192, 1, -1, 2]); convert_element_type_113 = None + view_as_complex_7 = torch.ops.aten.view_as_complex.default(view_252); view_252 = None + mul_26 = torch.ops.aten.mul.Tensor(view_as_complex_6, view_37); view_as_complex_6 = None + view_as_real_6 = torch.ops.aten.view_as_real.default(mul_26); mul_26 = None + view_254 = torch.ops.aten.view.default(view_as_real_6, [2, 8192, 4, 128]); view_as_real_6 = None + mul_27 = torch.ops.aten.mul.Tensor(view_as_complex_7, view_37); view_as_complex_7 = None + view_as_real_7 = torch.ops.aten.view_as_real.default(mul_27); mul_27 = None + view_255 = torch.ops.aten.view.default(view_as_real_7, [2, 
8192, 1, 128]); view_as_real_7 = None + convert_element_type_114 = torch.ops.prims.convert_element_type.default(view_254, torch.bfloat16); view_254 = None + convert_element_type_115 = torch.ops.prims.convert_element_type.default(view_255, torch.bfloat16); view_255 = None + unsqueeze_6 = torch.ops.aten.unsqueeze.default(convert_element_type_115, 3); convert_element_type_115 = None + expand_6 = torch.ops.aten.expand.default(unsqueeze_6, [2, 8192, 1, 4, 128]); unsqueeze_6 = None + view_256 = torch.ops.aten.view.default(expand_6, [2, 8192, 4, 128]); expand_6 = None + unsqueeze_7 = torch.ops.aten.unsqueeze.default(view_250, 3); view_250 = None + expand_7 = torch.ops.aten.expand.default(unsqueeze_7, [2, 8192, 1, 4, 128]); unsqueeze_7 = None + view_257 = torch.ops.aten.view.default(expand_7, [2, 8192, 4, 128]); expand_7 = None + permute_36 = torch.ops.aten.permute.default(convert_element_type_114, [0, 2, 1, 3]); convert_element_type_114 = None + permute_37 = torch.ops.aten.permute.default(view_256, [0, 2, 1, 3]); view_256 = None + permute_38 = torch.ops.aten.permute.default(view_257, [0, 2, 1, 3]); view_257 = None + _scaled_dot_product_cudnn_attention_backward_12 = torch.ops.aten._scaled_dot_product_cudnn_attention_backward.default(permute_581, permute_36, permute_37, permute_38, getitem_203, getitem_204, getitem_209, getitem_210, None, None, None, 8192, 8192, 0.0, True); permute_581 = permute_36 = permute_37 = permute_38 = getitem_203 = getitem_204 = getitem_209 = getitem_210 = None + getitem_1188 = _scaled_dot_product_cudnn_attention_backward_12[0] + getitem_1189 = _scaled_dot_product_cudnn_attention_backward_12[1] + getitem_1190 = _scaled_dot_product_cudnn_attention_backward_12[2]; _scaled_dot_product_cudnn_attention_backward_12 = None + permute_582 = torch.ops.aten.permute.default(getitem_1190, [0, 2, 1, 3]); getitem_1190 = None + permute_583 = torch.ops.aten.permute.default(getitem_1189, [0, 2, 1, 3]); getitem_1189 = None + permute_584 = 
torch.ops.aten.permute.default(getitem_1188, [0, 2, 1, 3]); getitem_1188 = None + view_1468 = torch.ops.aten.view.default(permute_582, [2, 8192, 1, 4, 128]); permute_582 = None + sum_77 = torch.ops.aten.sum.dim_IntList(view_1468, [3], True); view_1468 = None + squeeze_24 = torch.ops.aten.squeeze.dim(sum_77, 3); sum_77 = None + view_1469 = torch.ops.aten.view.default(permute_583, [2, 8192, 1, 4, 128]); permute_583 = None + sum_78 = torch.ops.aten.sum.dim_IntList(view_1469, [3], True); view_1469 = None + squeeze_25 = torch.ops.aten.squeeze.dim(sum_78, 3); sum_78 = None + convert_element_type_1223 = torch.ops.prims.convert_element_type.default(squeeze_25, torch.float32); squeeze_25 = None + convert_element_type_1224 = torch.ops.prims.convert_element_type.default(permute_584, torch.float32); permute_584 = None + view_1470 = torch.ops.aten.view.default(convert_element_type_1223, [2, 8192, 1, 64, 2]); convert_element_type_1223 = None + view_as_complex_56 = torch.ops.aten.view_as_complex.default(view_1470); view_1470 = None + mul_388 = torch.ops.aten.mul.Tensor(view_as_complex_56, _conj); view_as_complex_56 = None + view_1471 = torch.ops.aten.view.default(convert_element_type_1224, [2, 8192, 4, 64, 2]); convert_element_type_1224 = None + view_as_complex_57 = torch.ops.aten.view_as_complex.default(view_1471); view_1471 = None + mul_389 = torch.ops.aten.mul.Tensor(view_as_complex_57, _conj); view_as_complex_57 = None + view_as_real_56 = torch.ops.aten.view_as_real.default(mul_388); mul_388 = None + view_1472 = torch.ops.aten.view.default(view_as_real_56, [2, 8192, 1, 128]); view_as_real_56 = None + convert_element_type_1225 = torch.ops.prims.convert_element_type.default(view_1472, torch.bfloat16); view_1472 = None + view_as_real_57 = torch.ops.aten.view_as_real.default(mul_389); mul_389 = None + view_1473 = torch.ops.aten.view.default(view_as_real_57, [2, 8192, 4, 128]); view_as_real_57 = None + convert_element_type_1226 = 
torch.ops.prims.convert_element_type.default(view_1473, torch.bfloat16); view_1473 = None + view_1474 = torch.ops.aten.view.default(squeeze_24, [2, 8192, 128]); squeeze_24 = None + view_1475 = torch.ops.aten.view.default(convert_element_type_1225, [2, 8192, 128]); convert_element_type_1225 = None + view_1476 = torch.ops.aten.view.default(convert_element_type_1226, [2, 8192, 512]); convert_element_type_1226 = None + view_1477 = torch.ops.aten.view.default(view_1474, [16384, 128]); view_1474 = None + permute_585 = torch.ops.aten.permute.default(view_1477, [1, 0]) + mm_291 = torch.ops.aten.mm.default(permute_585, view_231); permute_585 = None + convert_element_type_109 = torch.ops.prims.convert_element_type.default(primals_34, torch.bfloat16); primals_34 = None + all_gather_into_tensor_38 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_109, 8, '0'); convert_element_type_109 = None + wait_tensor_45 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_38); all_gather_into_tensor_38 = None + permute_35 = torch.ops.aten.permute.default(wait_tensor_45, [1, 0]); wait_tensor_45 = None + permute_587 = torch.ops.aten.permute.default(permute_35, [1, 0]); permute_35 = None + mm_292 = torch.ops.aten.mm.default(view_1477, permute_587); view_1477 = permute_587 = None + view_1478 = torch.ops.aten.view.default(mm_292, [2, 8192, 4096]); mm_292 = None + convert_element_type_1231 = torch.ops.prims.convert_element_type.default(mm_291, torch.float32); mm_291 = None + reduce_scatter_tensor_174 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1231, 'avg', 8, '0'); convert_element_type_1231 = None + wait_tensor_406 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_174); reduce_scatter_tensor_174 = None + view_1479 = torch.ops.aten.view.default(view_1475, [16384, 128]); view_1475 = None + permute_589 = torch.ops.aten.permute.default(view_1479, [1, 0]) + mm_293 = 
torch.ops.aten.mm.default(permute_589, view_231); permute_589 = None + permute_591 = torch.ops.aten.permute.default(permute_34, [1, 0]); permute_34 = None + mm_294 = torch.ops.aten.mm.default(view_1479, permute_591); view_1479 = permute_591 = None + view_1480 = torch.ops.aten.view.default(mm_294, [2, 8192, 4096]); mm_294 = None + add_153 = torch.ops.aten.add.Tensor(view_1478, view_1480); view_1478 = view_1480 = None + convert_element_type_1236 = torch.ops.prims.convert_element_type.default(mm_293, torch.float32); mm_293 = None + reduce_scatter_tensor_175 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1236, 'avg', 8, '0'); convert_element_type_1236 = None + wait_tensor_407 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_175); reduce_scatter_tensor_175 = None + view_1481 = torch.ops.aten.view.default(view_1476, [16384, 512]); view_1476 = None + permute_593 = torch.ops.aten.permute.default(view_1481, [1, 0]) + mm_295 = torch.ops.aten.mm.default(permute_593, view_231); permute_593 = view_231 = None + convert_element_type_103 = torch.ops.prims.convert_element_type.default(primals_32, torch.bfloat16); primals_32 = None + all_gather_into_tensor_36 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_103, 8, '0'); convert_element_type_103 = None + wait_tensor_43 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_36); all_gather_into_tensor_36 = None + permute_33 = torch.ops.aten.permute.default(wait_tensor_43, [1, 0]); wait_tensor_43 = None + permute_595 = torch.ops.aten.permute.default(permute_33, [1, 0]); permute_33 = None + mm_296 = torch.ops.aten.mm.default(view_1481, permute_595); view_1481 = permute_595 = None + view_1482 = torch.ops.aten.view.default(mm_296, [2, 8192, 4096]); mm_296 = None + add_154 = torch.ops.aten.add.Tensor(add_153, view_1482); add_153 = view_1482 = None + convert_element_type_1241 = torch.ops.prims.convert_element_type.default(mm_295, 
torch.float32); mm_295 = None + reduce_scatter_tensor_176 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1241, 'avg', 8, '0'); convert_element_type_1241 = None + wait_tensor_408 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_176); reduce_scatter_tensor_176 = None + split_126 = torch.ops.aten.split.Tensor(add_154, 1024, 1); add_154 = None + getitem_1191 = split_126[0] + getitem_1192 = split_126[1] + getitem_1193 = split_126[2] + getitem_1194 = split_126[3] + getitem_1195 = split_126[4] + getitem_1196 = split_126[5] + getitem_1197 = split_126[6] + getitem_1198 = split_126[7]; split_126 = None + cat_118 = torch.ops.aten.cat.default([getitem_1191, getitem_1192, getitem_1193, getitem_1194, getitem_1195, getitem_1196, getitem_1197, getitem_1198]); getitem_1191 = getitem_1192 = getitem_1193 = getitem_1194 = getitem_1195 = getitem_1196 = getitem_1197 = getitem_1198 = None + reduce_scatter_tensor_177 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_118, 'sum', 8, '1'); cat_118 = None + wait_tensor_409 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_177); reduce_scatter_tensor_177 = None + convert_element_type_1242 = torch.ops.prims.convert_element_type.default(wait_tensor_409, torch.float32); wait_tensor_409 = None + convert_element_type_1244 = torch.ops.prims.convert_element_type.default(wait_tensor_41, torch.float32); wait_tensor_41 = None + mul_390 = torch.ops.aten.mul.Tensor(convert_element_type_1242, convert_element_type_1244); convert_element_type_1244 = None + mul_392 = torch.ops.aten.mul.Tensor(mul_24, mul_390) + sum_79 = torch.ops.aten.sum.dim_IntList(mul_392, [2], True); mul_392 = None + div_26 = torch.ops.aten.div.Tensor(mul_24, 4096) + mul_393 = torch.ops.aten.mul.Tensor(div_26, sum_79); div_26 = sum_79 = None + sub_40 = torch.ops.aten.sub.Tensor(mul_390, mul_393); mul_390 = mul_393 = None + mul_394 = torch.ops.aten.mul.Tensor(sub_40, rsqrt_6); sub_40 = rsqrt_6 = None + 
mul_395 = torch.ops.aten.mul.Tensor(convert_element_type_1242, mul_24); convert_element_type_1242 = mul_24 = None + sum_80 = torch.ops.aten.sum.dim_IntList(mul_395, [0, 1]); mul_395 = None + convert_element_type_1245 = torch.ops.prims.convert_element_type.default(mul_394, torch.bfloat16); mul_394 = None + convert_element_type_1246 = torch.ops.prims.convert_element_type.default(sum_80, torch.bfloat16); sum_80 = None + all_reduce_26 = torch.ops._c10d_functional.all_reduce.default(convert_element_type_1246, 'sum', '1'); convert_element_type_1246 = None + wait_tensor_410 = torch.ops._c10d_functional.wait_tensor.default(all_reduce_26); all_reduce_26 = None + convert_element_type_1247 = torch.ops.prims.convert_element_type.default(wait_tensor_410, torch.float32); wait_tensor_410 = None + reduce_scatter_tensor_178 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1247, 'avg', 8, '0'); convert_element_type_1247 = None + wait_tensor_411 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_178); reduce_scatter_tensor_178 = None + add_155 = torch.ops.aten.add.Tensor(add_152, convert_element_type_1245); add_152 = convert_element_type_1245 = None + all_gather_into_tensor_206 = torch.ops._c10d_functional.all_gather_into_tensor.default(add_155, 8, '1') + wait_tensor_412 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_206); all_gather_into_tensor_206 = None + split_127 = torch.ops.aten.split.Tensor(wait_tensor_412, 2); wait_tensor_412 = None + getitem_1199 = split_127[0] + getitem_1200 = split_127[1] + getitem_1201 = split_127[2] + getitem_1202 = split_127[3] + getitem_1203 = split_127[4] + getitem_1204 = split_127[5] + getitem_1205 = split_127[6] + getitem_1206 = split_127[7]; split_127 = None + cat_119 = torch.ops.aten.cat.default([getitem_1199, getitem_1200, getitem_1201, getitem_1202, getitem_1203, getitem_1204, getitem_1205, getitem_1206], 1); getitem_1199 = getitem_1200 = getitem_1201 = getitem_1202 = 
getitem_1203 = getitem_1204 = getitem_1205 = getitem_1206 = None + view_1483 = torch.ops.aten.view.default(cat_119, [16384, 4096]); cat_119 = None + permute_597 = torch.ops.aten.permute.default(view_1483, [1, 0]) + wait_tensor_34 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_5); reduce_scatter_tensor_5 = None + add_9 = torch.ops.aten.add.Tensor(add_7, wait_tensor_34); wait_tensor_34 = None + convert_element_type_86 = torch.ops.prims.convert_element_type.default(primals_27, torch.bfloat16); primals_27 = None + all_gather_into_tensor_29 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_86, 8, '0'); convert_element_type_86 = None + wait_tensor_35 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_29); all_gather_into_tensor_29 = None + convert_element_type_87 = torch.ops.prims.convert_element_type.default(add_9, torch.float32); add_9 = None + pow_6 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_87, 2) + mean_5 = torch.ops.aten.mean.dim(pow_6, [2], True); pow_6 = None + add_10 = torch.ops.aten.add.Scalar(mean_5, 1e-05); mean_5 = None + rsqrt_5 = torch.ops.aten.rsqrt.default(add_10); add_10 = None + mul_20 = torch.ops.aten.mul.Tensor(convert_element_type_87, rsqrt_5); convert_element_type_87 = None + mul_21 = torch.ops.aten.mul.Tensor(mul_20, wait_tensor_35) + convert_element_type_88 = torch.ops.prims.convert_element_type.default(mul_21, torch.bfloat16); mul_21 = None + all_gather_into_tensor_30 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_88, 8, '1'); convert_element_type_88 = None + wait_tensor_36 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_30); all_gather_into_tensor_30 = None + split_19 = torch.ops.aten.split.Tensor(wait_tensor_36, 2); wait_tensor_36 = None + getitem_179 = split_19[0] + getitem_180 = split_19[1] + getitem_181 = split_19[2] + getitem_182 = split_19[3] + getitem_183 = split_19[4] + getitem_184 = 
split_19[5] + getitem_185 = split_19[6] + getitem_186 = split_19[7]; split_19 = None + cat_11 = torch.ops.aten.cat.default([getitem_179, getitem_180, getitem_181, getitem_182, getitem_183, getitem_184, getitem_185, getitem_186], 1); getitem_179 = getitem_180 = getitem_181 = getitem_182 = getitem_183 = getitem_184 = getitem_185 = getitem_186 = None + view_204 = torch.ops.aten.view.default(cat_11, [16384, 4096]); cat_11 = None + view_205 = torch.ops.aten.view.default(mm_18, [2, 8192, 1792]); mm_18 = None + convert_element_type_92 = torch.ops.prims.convert_element_type.default(view_205, torch.float32); view_205 = None + sigmoid_2 = torch.ops.aten.sigmoid.default(convert_element_type_92) + mul_22 = torch.ops.aten.mul.Tensor(convert_element_type_92, sigmoid_2); sigmoid_2 = None + convert_element_type_93 = torch.ops.prims.convert_element_type.default(mul_22, torch.bfloat16); mul_22 = None + convert_element_type_94 = torch.ops.prims.convert_element_type.default(primals_29, torch.bfloat16); primals_29 = None + all_gather_into_tensor_32 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_94, 8, '0'); convert_element_type_94 = None + wait_tensor_38 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_32); all_gather_into_tensor_32 = None + permute_31 = torch.ops.aten.permute.default(wait_tensor_38, [1, 0]); wait_tensor_38 = None + mm_19 = torch.ops.aten.mm.default(view_204, permute_31) + view_212 = torch.ops.aten.view.default(mm_19, [2, 8192, 1792]); mm_19 = None + mul_23 = torch.ops.aten.mul.Tensor(convert_element_type_93, view_212) + view_219 = torch.ops.aten.view.default(mul_23, [16384, 1792]); mul_23 = None + mm_297 = torch.ops.aten.mm.default(permute_597, view_219); permute_597 = view_219 = None + convert_element_type_97 = torch.ops.prims.convert_element_type.default(primals_30, torch.bfloat16); primals_30 = None + all_gather_into_tensor_33 = 
torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_97, 8, '0'); convert_element_type_97 = None + wait_tensor_39 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_33); all_gather_into_tensor_33 = None + permute_32 = torch.ops.aten.permute.default(wait_tensor_39, [1, 0]); wait_tensor_39 = None + permute_599 = torch.ops.aten.permute.default(permute_32, [1, 0]); permute_32 = None + mm_298 = torch.ops.aten.mm.default(view_1483, permute_599); view_1483 = permute_599 = None + view_1484 = torch.ops.aten.view.default(mm_298, [2, 8192, 1792]); mm_298 = None + convert_element_type_1252 = torch.ops.prims.convert_element_type.default(mm_297, torch.float32); mm_297 = None + reduce_scatter_tensor_179 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1252, 'avg', 8, '0'); convert_element_type_1252 = None + wait_tensor_413 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_179); reduce_scatter_tensor_179 = None + mul_396 = torch.ops.aten.mul.Tensor(view_1484, convert_element_type_93); convert_element_type_93 = None + mul_397 = torch.ops.aten.mul.Tensor(view_1484, view_212); view_1484 = view_212 = None + view_1485 = torch.ops.aten.view.default(mul_396, [16384, 1792]); mul_396 = None + permute_601 = torch.ops.aten.permute.default(view_1485, [1, 0]) + mm_299 = torch.ops.aten.mm.default(permute_601, view_204); permute_601 = None + permute_603 = torch.ops.aten.permute.default(permute_31, [1, 0]); permute_31 = None + mm_300 = torch.ops.aten.mm.default(view_1485, permute_603); view_1485 = permute_603 = None + view_1486 = torch.ops.aten.view.default(mm_300, [2, 8192, 4096]); mm_300 = None + convert_element_type_1257 = torch.ops.prims.convert_element_type.default(mm_299, torch.float32); mm_299 = None + reduce_scatter_tensor_180 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1257, 'avg', 8, '0'); convert_element_type_1257 = None + wait_tensor_414 = 
torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_180); reduce_scatter_tensor_180 = None + convert_element_type_1258 = torch.ops.prims.convert_element_type.default(mul_397, torch.float32); mul_397 = None + neg_13 = torch.ops.aten.neg.default(convert_element_type_92) + exp_13 = torch.ops.aten.exp.default(neg_13); neg_13 = None + add_156 = torch.ops.aten.add.Tensor(exp_13, 1); exp_13 = None + reciprocal_13 = torch.ops.aten.reciprocal.default(add_156); add_156 = None + mul_398 = torch.ops.aten.mul.Tensor(reciprocal_13, 1); reciprocal_13 = None + mul_399 = torch.ops.aten.mul.Tensor(convert_element_type_1258, mul_398); convert_element_type_1258 = None + sub_41 = torch.ops.aten.sub.Tensor(1, mul_398); mul_398 = None + mul_400 = torch.ops.aten.mul.Tensor(convert_element_type_92, sub_41); convert_element_type_92 = sub_41 = None + add_157 = torch.ops.aten.add.Tensor(mul_400, 1); mul_400 = None + mul_401 = torch.ops.aten.mul.Tensor(mul_399, add_157); mul_399 = add_157 = None + convert_element_type_1260 = torch.ops.prims.convert_element_type.default(mul_401, torch.bfloat16); mul_401 = None + view_1487 = torch.ops.aten.view.default(convert_element_type_1260, [16384, 1792]); convert_element_type_1260 = None + permute_605 = torch.ops.aten.permute.default(view_1487, [1, 0]) + mm_301 = torch.ops.aten.mm.default(permute_605, view_204); permute_605 = view_204 = None + convert_element_type_89 = torch.ops.prims.convert_element_type.default(primals_28, torch.bfloat16); primals_28 = None + all_gather_into_tensor_31 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_89, 8, '0'); convert_element_type_89 = None + wait_tensor_37 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_31); all_gather_into_tensor_31 = None + permute_30 = torch.ops.aten.permute.default(wait_tensor_37, [1, 0]); wait_tensor_37 = None + permute_607 = torch.ops.aten.permute.default(permute_30, [1, 0]); permute_30 = None + mm_302 = 
torch.ops.aten.mm.default(view_1487, permute_607); view_1487 = permute_607 = None + view_1488 = torch.ops.aten.view.default(mm_302, [2, 8192, 4096]); mm_302 = None + add_158 = torch.ops.aten.add.Tensor(view_1486, view_1488); view_1486 = view_1488 = None + convert_element_type_1265 = torch.ops.prims.convert_element_type.default(mm_301, torch.float32); mm_301 = None + reduce_scatter_tensor_181 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1265, 'avg', 8, '0'); convert_element_type_1265 = None + wait_tensor_415 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_181); reduce_scatter_tensor_181 = None + split_128 = torch.ops.aten.split.Tensor(add_158, 1024, 1); add_158 = None + getitem_1207 = split_128[0] + getitem_1208 = split_128[1] + getitem_1209 = split_128[2] + getitem_1210 = split_128[3] + getitem_1211 = split_128[4] + getitem_1212 = split_128[5] + getitem_1213 = split_128[6] + getitem_1214 = split_128[7]; split_128 = None + cat_120 = torch.ops.aten.cat.default([getitem_1207, getitem_1208, getitem_1209, getitem_1210, getitem_1211, getitem_1212, getitem_1213, getitem_1214]); getitem_1207 = getitem_1208 = getitem_1209 = getitem_1210 = getitem_1211 = getitem_1212 = getitem_1213 = getitem_1214 = None + reduce_scatter_tensor_182 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_120, 'sum', 8, '1'); cat_120 = None + wait_tensor_416 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_182); reduce_scatter_tensor_182 = None + convert_element_type_1266 = torch.ops.prims.convert_element_type.default(wait_tensor_416, torch.float32); wait_tensor_416 = None + convert_element_type_1268 = torch.ops.prims.convert_element_type.default(wait_tensor_35, torch.float32); wait_tensor_35 = None + mul_402 = torch.ops.aten.mul.Tensor(convert_element_type_1266, convert_element_type_1268); convert_element_type_1268 = None + mul_404 = torch.ops.aten.mul.Tensor(mul_20, mul_402) + sum_81 = 
torch.ops.aten.sum.dim_IntList(mul_404, [2], True); mul_404 = None + div_27 = torch.ops.aten.div.Tensor(mul_20, 4096) + mul_405 = torch.ops.aten.mul.Tensor(div_27, sum_81); div_27 = sum_81 = None + sub_42 = torch.ops.aten.sub.Tensor(mul_402, mul_405); mul_402 = mul_405 = None + mul_406 = torch.ops.aten.mul.Tensor(sub_42, rsqrt_5); sub_42 = rsqrt_5 = None + mul_407 = torch.ops.aten.mul.Tensor(convert_element_type_1266, mul_20); convert_element_type_1266 = mul_20 = None + sum_82 = torch.ops.aten.sum.dim_IntList(mul_407, [0, 1]); mul_407 = None + convert_element_type_1269 = torch.ops.prims.convert_element_type.default(mul_406, torch.bfloat16); mul_406 = None + convert_element_type_1270 = torch.ops.prims.convert_element_type.default(sum_82, torch.bfloat16); sum_82 = None + all_reduce_27 = torch.ops._c10d_functional.all_reduce.default(convert_element_type_1270, 'sum', '1'); convert_element_type_1270 = None + wait_tensor_417 = torch.ops._c10d_functional.wait_tensor.default(all_reduce_27); all_reduce_27 = None + convert_element_type_1271 = torch.ops.prims.convert_element_type.default(wait_tensor_417, torch.float32); wait_tensor_417 = None + reduce_scatter_tensor_183 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1271, 'avg', 8, '0'); convert_element_type_1271 = None + wait_tensor_418 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_183); reduce_scatter_tensor_183 = None + add_159 = torch.ops.aten.add.Tensor(add_155, convert_element_type_1269); add_155 = convert_element_type_1269 = None + all_gather_into_tensor_207 = torch.ops._c10d_functional.all_gather_into_tensor.default(add_159, 8, '1') + wait_tensor_419 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_207); all_gather_into_tensor_207 = None + split_129 = torch.ops.aten.split.Tensor(wait_tensor_419, 2); wait_tensor_419 = None + getitem_1215 = split_129[0] + getitem_1216 = split_129[1] + getitem_1217 = split_129[2] + getitem_1218 = 
split_129[3] + getitem_1219 = split_129[4] + getitem_1220 = split_129[5] + getitem_1221 = split_129[6] + getitem_1222 = split_129[7]; split_129 = None + cat_121 = torch.ops.aten.cat.default([getitem_1215, getitem_1216, getitem_1217, getitem_1218, getitem_1219, getitem_1220, getitem_1221, getitem_1222], 1); getitem_1215 = getitem_1216 = getitem_1217 = getitem_1218 = getitem_1219 = getitem_1220 = getitem_1221 = getitem_1222 = None + view_1489 = torch.ops.aten.view.default(cat_121, [16384, 4096]); cat_121 = None + permute_609 = torch.ops.aten.permute.default(view_1489, [1, 0]) + permute_28 = torch.ops.aten.permute.default(getitem_162, [0, 2, 1, 3]) + view_186 = torch.ops.aten.view.default(permute_28, [2, 8192, -1]); permute_28 = None + view_192 = torch.ops.aten.view.default(view_186, [16384, 512]); view_186 = None + mm_303 = torch.ops.aten.mm.default(permute_609, view_192); permute_609 = view_192 = None + convert_element_type_83 = torch.ops.prims.convert_element_type.default(primals_26, torch.bfloat16); primals_26 = None + all_gather_into_tensor_28 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_83, 8, '0'); convert_element_type_83 = None + wait_tensor_33 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_28); all_gather_into_tensor_28 = None + permute_29 = torch.ops.aten.permute.default(wait_tensor_33, [1, 0]); wait_tensor_33 = None + permute_611 = torch.ops.aten.permute.default(permute_29, [1, 0]); permute_29 = None + mm_304 = torch.ops.aten.mm.default(view_1489, permute_611); view_1489 = permute_611 = None + view_1490 = torch.ops.aten.view.default(mm_304, [2, 8192, 512]); mm_304 = None + convert_element_type_1276 = torch.ops.prims.convert_element_type.default(mm_303, torch.float32); mm_303 = None + reduce_scatter_tensor_184 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1276, 'avg', 8, '0'); convert_element_type_1276 = None + wait_tensor_420 = 
torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_184); reduce_scatter_tensor_184 = None + view_1491 = torch.ops.aten.view.default(view_1490, [2, 8192, 4, 128]); view_1490 = None + permute_613 = torch.ops.aten.permute.default(view_1491, [0, 2, 1, 3]); view_1491 = None + convert_element_type_67 = torch.ops.prims.convert_element_type.default(primals_22, torch.bfloat16); primals_22 = None + all_gather_into_tensor_23 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_67, 8, '0'); convert_element_type_67 = None + wait_tensor_28 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_23); all_gather_into_tensor_23 = None + convert_element_type_68 = torch.ops.prims.convert_element_type.default(add_7, torch.float32); add_7 = None + pow_5 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_68, 2) + mean_4 = torch.ops.aten.mean.dim(pow_5, [2], True); pow_5 = None + add_8 = torch.ops.aten.add.Scalar(mean_4, 1e-05); mean_4 = None + rsqrt_4 = torch.ops.aten.rsqrt.default(add_8); add_8 = None + mul_16 = torch.ops.aten.mul.Tensor(convert_element_type_68, rsqrt_4); convert_element_type_68 = None + mul_17 = torch.ops.aten.mul.Tensor(mul_16, wait_tensor_28) + convert_element_type_69 = torch.ops.prims.convert_element_type.default(mul_17, torch.bfloat16); mul_17 = None + all_gather_into_tensor_24 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_69, 8, '1'); convert_element_type_69 = None + wait_tensor_29 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_24); all_gather_into_tensor_24 = None + split_17 = torch.ops.aten.split.Tensor(wait_tensor_29, 2); wait_tensor_29 = None + getitem_154 = split_17[0] + getitem_155 = split_17[1] + getitem_156 = split_17[2] + getitem_157 = split_17[3] + getitem_158 = split_17[4] + getitem_159 = split_17[5] + getitem_160 = split_17[6] + getitem_161 = split_17[7]; split_17 = None + cat_9 = torch.ops.aten.cat.default([getitem_154, 
getitem_155, getitem_156, getitem_157, getitem_158, getitem_159, getitem_160, getitem_161], 1); getitem_154 = getitem_155 = getitem_156 = getitem_157 = getitem_158 = getitem_159 = getitem_160 = getitem_161 = None + view_159 = torch.ops.aten.view.default(cat_9, [16384, 4096]); cat_9 = None + view_160 = torch.ops.aten.view.default(mm_14, [2, 8192, 512]); mm_14 = None + convert_element_type_73 = torch.ops.prims.convert_element_type.default(primals_24, torch.bfloat16); primals_24 = None + all_gather_into_tensor_26 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_73, 8, '0'); convert_element_type_73 = None + wait_tensor_31 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_26); all_gather_into_tensor_26 = None + permute_23 = torch.ops.aten.permute.default(wait_tensor_31, [1, 0]); wait_tensor_31 = None + mm_15 = torch.ops.aten.mm.default(view_159, permute_23) + view_167 = torch.ops.aten.view.default(mm_15, [2, 8192, 128]); mm_15 = None + view_174 = torch.ops.aten.view.default(mm_16, [2, 8192, 128]); mm_16 = None + view_176 = torch.ops.aten.view.default(view_160, [2, 8192, -1, 128]); view_160 = None + view_177 = torch.ops.aten.view.default(view_167, [2, 8192, -1, 128]); view_167 = None + view_178 = torch.ops.aten.view.default(view_174, [2, 8192, -1, 128]); view_174 = None + convert_element_type_79 = torch.ops.prims.convert_element_type.default(view_176, torch.float32); view_176 = None + view_179 = torch.ops.aten.view.default(convert_element_type_79, [2, 8192, 4, -1, 2]); convert_element_type_79 = None + view_as_complex_4 = torch.ops.aten.view_as_complex.default(view_179); view_179 = None + convert_element_type_80 = torch.ops.prims.convert_element_type.default(view_177, torch.float32); view_177 = None + view_180 = torch.ops.aten.view.default(convert_element_type_80, [2, 8192, 1, -1, 2]); convert_element_type_80 = None + view_as_complex_5 = torch.ops.aten.view_as_complex.default(view_180); view_180 = None + mul_18 = 
torch.ops.aten.mul.Tensor(view_as_complex_4, view_37); view_as_complex_4 = None + view_as_real_4 = torch.ops.aten.view_as_real.default(mul_18); mul_18 = None + view_182 = torch.ops.aten.view.default(view_as_real_4, [2, 8192, 4, 128]); view_as_real_4 = None + mul_19 = torch.ops.aten.mul.Tensor(view_as_complex_5, view_37); view_as_complex_5 = None + view_as_real_5 = torch.ops.aten.view_as_real.default(mul_19); mul_19 = None + view_183 = torch.ops.aten.view.default(view_as_real_5, [2, 8192, 1, 128]); view_as_real_5 = None + convert_element_type_81 = torch.ops.prims.convert_element_type.default(view_182, torch.bfloat16); view_182 = None + convert_element_type_82 = torch.ops.prims.convert_element_type.default(view_183, torch.bfloat16); view_183 = None + unsqueeze_4 = torch.ops.aten.unsqueeze.default(convert_element_type_82, 3); convert_element_type_82 = None + expand_4 = torch.ops.aten.expand.default(unsqueeze_4, [2, 8192, 1, 4, 128]); unsqueeze_4 = None + view_184 = torch.ops.aten.view.default(expand_4, [2, 8192, 4, 128]); expand_4 = None + unsqueeze_5 = torch.ops.aten.unsqueeze.default(view_178, 3); view_178 = None + expand_5 = torch.ops.aten.expand.default(unsqueeze_5, [2, 8192, 1, 4, 128]); unsqueeze_5 = None + view_185 = torch.ops.aten.view.default(expand_5, [2, 8192, 4, 128]); expand_5 = None + permute_25 = torch.ops.aten.permute.default(convert_element_type_81, [0, 2, 1, 3]); convert_element_type_81 = None + permute_26 = torch.ops.aten.permute.default(view_184, [0, 2, 1, 3]); view_184 = None + permute_27 = torch.ops.aten.permute.default(view_185, [0, 2, 1, 3]); view_185 = None + _scaled_dot_product_cudnn_attention_backward_13 = torch.ops.aten._scaled_dot_product_cudnn_attention_backward.default(permute_613, permute_25, permute_26, permute_27, getitem_162, getitem_163, getitem_168, getitem_169, None, None, None, 8192, 8192, 0.0, True); permute_613 = permute_25 = permute_26 = permute_27 = getitem_162 = getitem_163 = getitem_168 = getitem_169 = None + getitem_1223 = 
_scaled_dot_product_cudnn_attention_backward_13[0] + getitem_1224 = _scaled_dot_product_cudnn_attention_backward_13[1] + getitem_1225 = _scaled_dot_product_cudnn_attention_backward_13[2]; _scaled_dot_product_cudnn_attention_backward_13 = None + permute_614 = torch.ops.aten.permute.default(getitem_1225, [0, 2, 1, 3]); getitem_1225 = None + permute_615 = torch.ops.aten.permute.default(getitem_1224, [0, 2, 1, 3]); getitem_1224 = None + permute_616 = torch.ops.aten.permute.default(getitem_1223, [0, 2, 1, 3]); getitem_1223 = None + view_1492 = torch.ops.aten.view.default(permute_614, [2, 8192, 1, 4, 128]); permute_614 = None + sum_83 = torch.ops.aten.sum.dim_IntList(view_1492, [3], True); view_1492 = None + squeeze_26 = torch.ops.aten.squeeze.dim(sum_83, 3); sum_83 = None + view_1493 = torch.ops.aten.view.default(permute_615, [2, 8192, 1, 4, 128]); permute_615 = None + sum_84 = torch.ops.aten.sum.dim_IntList(view_1493, [3], True); view_1493 = None + squeeze_27 = torch.ops.aten.squeeze.dim(sum_84, 3); sum_84 = None + convert_element_type_1277 = torch.ops.prims.convert_element_type.default(squeeze_27, torch.float32); squeeze_27 = None + convert_element_type_1278 = torch.ops.prims.convert_element_type.default(permute_616, torch.float32); permute_616 = None + view_1494 = torch.ops.aten.view.default(convert_element_type_1277, [2, 8192, 1, 64, 2]); convert_element_type_1277 = None + view_as_complex_58 = torch.ops.aten.view_as_complex.default(view_1494); view_1494 = None + mul_408 = torch.ops.aten.mul.Tensor(view_as_complex_58, _conj); view_as_complex_58 = None + view_1495 = torch.ops.aten.view.default(convert_element_type_1278, [2, 8192, 4, 64, 2]); convert_element_type_1278 = None + view_as_complex_59 = torch.ops.aten.view_as_complex.default(view_1495); view_1495 = None + mul_409 = torch.ops.aten.mul.Tensor(view_as_complex_59, _conj); view_as_complex_59 = None + view_as_real_58 = torch.ops.aten.view_as_real.default(mul_408); mul_408 = None + view_1496 = 
torch.ops.aten.view.default(view_as_real_58, [2, 8192, 1, 128]); view_as_real_58 = None + convert_element_type_1279 = torch.ops.prims.convert_element_type.default(view_1496, torch.bfloat16); view_1496 = None + view_as_real_59 = torch.ops.aten.view_as_real.default(mul_409); mul_409 = None + view_1497 = torch.ops.aten.view.default(view_as_real_59, [2, 8192, 4, 128]); view_as_real_59 = None + convert_element_type_1280 = torch.ops.prims.convert_element_type.default(view_1497, torch.bfloat16); view_1497 = None + view_1498 = torch.ops.aten.view.default(squeeze_26, [2, 8192, 128]); squeeze_26 = None + view_1499 = torch.ops.aten.view.default(convert_element_type_1279, [2, 8192, 128]); convert_element_type_1279 = None + view_1500 = torch.ops.aten.view.default(convert_element_type_1280, [2, 8192, 512]); convert_element_type_1280 = None + view_1501 = torch.ops.aten.view.default(view_1498, [16384, 128]); view_1498 = None + permute_617 = torch.ops.aten.permute.default(view_1501, [1, 0]) + mm_305 = torch.ops.aten.mm.default(permute_617, view_159); permute_617 = None + convert_element_type_76 = torch.ops.prims.convert_element_type.default(primals_25, torch.bfloat16); primals_25 = None + all_gather_into_tensor_27 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_76, 8, '0'); convert_element_type_76 = None + wait_tensor_32 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_27); all_gather_into_tensor_27 = None + permute_24 = torch.ops.aten.permute.default(wait_tensor_32, [1, 0]); wait_tensor_32 = None + permute_619 = torch.ops.aten.permute.default(permute_24, [1, 0]); permute_24 = None + mm_306 = torch.ops.aten.mm.default(view_1501, permute_619); view_1501 = permute_619 = None + view_1502 = torch.ops.aten.view.default(mm_306, [2, 8192, 4096]); mm_306 = None + convert_element_type_1285 = torch.ops.prims.convert_element_type.default(mm_305, torch.float32); mm_305 = None + reduce_scatter_tensor_185 = 
torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1285, 'avg', 8, '0'); convert_element_type_1285 = None + wait_tensor_421 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_185); reduce_scatter_tensor_185 = None + view_1503 = torch.ops.aten.view.default(view_1499, [16384, 128]); view_1499 = None + permute_621 = torch.ops.aten.permute.default(view_1503, [1, 0]) + mm_307 = torch.ops.aten.mm.default(permute_621, view_159); permute_621 = None + permute_623 = torch.ops.aten.permute.default(permute_23, [1, 0]); permute_23 = None + mm_308 = torch.ops.aten.mm.default(view_1503, permute_623); view_1503 = permute_623 = None + view_1504 = torch.ops.aten.view.default(mm_308, [2, 8192, 4096]); mm_308 = None + add_160 = torch.ops.aten.add.Tensor(view_1502, view_1504); view_1502 = view_1504 = None + convert_element_type_1290 = torch.ops.prims.convert_element_type.default(mm_307, torch.float32); mm_307 = None + reduce_scatter_tensor_186 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1290, 'avg', 8, '0'); convert_element_type_1290 = None + wait_tensor_422 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_186); reduce_scatter_tensor_186 = None + view_1505 = torch.ops.aten.view.default(view_1500, [16384, 512]); view_1500 = None + permute_625 = torch.ops.aten.permute.default(view_1505, [1, 0]) + mm_309 = torch.ops.aten.mm.default(permute_625, view_159); permute_625 = view_159 = None + convert_element_type_70 = torch.ops.prims.convert_element_type.default(primals_23, torch.bfloat16); primals_23 = None + all_gather_into_tensor_25 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_70, 8, '0'); convert_element_type_70 = None + wait_tensor_30 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_25); all_gather_into_tensor_25 = None + permute_22 = torch.ops.aten.permute.default(wait_tensor_30, [1, 0]); wait_tensor_30 = None + permute_627 = 
torch.ops.aten.permute.default(permute_22, [1, 0]); permute_22 = None + mm_310 = torch.ops.aten.mm.default(view_1505, permute_627); view_1505 = permute_627 = None + view_1506 = torch.ops.aten.view.default(mm_310, [2, 8192, 4096]); mm_310 = None + add_161 = torch.ops.aten.add.Tensor(add_160, view_1506); add_160 = view_1506 = None + convert_element_type_1295 = torch.ops.prims.convert_element_type.default(mm_309, torch.float32); mm_309 = None + reduce_scatter_tensor_187 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1295, 'avg', 8, '0'); convert_element_type_1295 = None + wait_tensor_423 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_187); reduce_scatter_tensor_187 = None + split_130 = torch.ops.aten.split.Tensor(add_161, 1024, 1); add_161 = None + getitem_1226 = split_130[0] + getitem_1227 = split_130[1] + getitem_1228 = split_130[2] + getitem_1229 = split_130[3] + getitem_1230 = split_130[4] + getitem_1231 = split_130[5] + getitem_1232 = split_130[6] + getitem_1233 = split_130[7]; split_130 = None + cat_122 = torch.ops.aten.cat.default([getitem_1226, getitem_1227, getitem_1228, getitem_1229, getitem_1230, getitem_1231, getitem_1232, getitem_1233]); getitem_1226 = getitem_1227 = getitem_1228 = getitem_1229 = getitem_1230 = getitem_1231 = getitem_1232 = getitem_1233 = None + reduce_scatter_tensor_188 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_122, 'sum', 8, '1'); cat_122 = None + wait_tensor_424 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_188); reduce_scatter_tensor_188 = None + convert_element_type_1296 = torch.ops.prims.convert_element_type.default(wait_tensor_424, torch.float32); wait_tensor_424 = None + convert_element_type_1298 = torch.ops.prims.convert_element_type.default(wait_tensor_28, torch.float32); wait_tensor_28 = None + mul_410 = torch.ops.aten.mul.Tensor(convert_element_type_1296, convert_element_type_1298); convert_element_type_1298 = None + mul_412 
= torch.ops.aten.mul.Tensor(mul_16, mul_410) + sum_85 = torch.ops.aten.sum.dim_IntList(mul_412, [2], True); mul_412 = None + div_28 = torch.ops.aten.div.Tensor(mul_16, 4096) + mul_413 = torch.ops.aten.mul.Tensor(div_28, sum_85); div_28 = sum_85 = None + sub_43 = torch.ops.aten.sub.Tensor(mul_410, mul_413); mul_410 = mul_413 = None + mul_414 = torch.ops.aten.mul.Tensor(sub_43, rsqrt_4); sub_43 = rsqrt_4 = None + mul_415 = torch.ops.aten.mul.Tensor(convert_element_type_1296, mul_16); convert_element_type_1296 = mul_16 = None + sum_86 = torch.ops.aten.sum.dim_IntList(mul_415, [0, 1]); mul_415 = None + convert_element_type_1299 = torch.ops.prims.convert_element_type.default(mul_414, torch.bfloat16); mul_414 = None + convert_element_type_1300 = torch.ops.prims.convert_element_type.default(sum_86, torch.bfloat16); sum_86 = None + all_reduce_28 = torch.ops._c10d_functional.all_reduce.default(convert_element_type_1300, 'sum', '1'); convert_element_type_1300 = None + wait_tensor_425 = torch.ops._c10d_functional.wait_tensor.default(all_reduce_28); all_reduce_28 = None + convert_element_type_1301 = torch.ops.prims.convert_element_type.default(wait_tensor_425, torch.float32); wait_tensor_425 = None + reduce_scatter_tensor_189 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1301, 'avg', 8, '0'); convert_element_type_1301 = None + wait_tensor_426 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_189); reduce_scatter_tensor_189 = None + add_162 = torch.ops.aten.add.Tensor(add_159, convert_element_type_1299); add_159 = convert_element_type_1299 = None + all_gather_into_tensor_208 = torch.ops._c10d_functional.all_gather_into_tensor.default(add_162, 8, '1') + wait_tensor_427 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_208); all_gather_into_tensor_208 = None + split_131 = torch.ops.aten.split.Tensor(wait_tensor_427, 2); wait_tensor_427 = None + getitem_1234 = split_131[0] + getitem_1235 = split_131[1] + 
getitem_1236 = split_131[2] + getitem_1237 = split_131[3] + getitem_1238 = split_131[4] + getitem_1239 = split_131[5] + getitem_1240 = split_131[6] + getitem_1241 = split_131[7]; split_131 = None + cat_123 = torch.ops.aten.cat.default([getitem_1234, getitem_1235, getitem_1236, getitem_1237, getitem_1238, getitem_1239, getitem_1240, getitem_1241], 1); getitem_1234 = getitem_1235 = getitem_1236 = getitem_1237 = getitem_1238 = getitem_1239 = getitem_1240 = getitem_1241 = None + view_1507 = torch.ops.aten.view.default(cat_123, [16384, 4096]); cat_123 = None + permute_629 = torch.ops.aten.permute.default(view_1507, [1, 0]) + wait_tensor_21 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_3); reduce_scatter_tensor_3 = None + add_5 = torch.ops.aten.add.Tensor(add_3, wait_tensor_21); wait_tensor_21 = None + convert_element_type_53 = torch.ops.prims.convert_element_type.default(primals_18, torch.bfloat16); primals_18 = None + all_gather_into_tensor_18 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_53, 8, '0'); convert_element_type_53 = None + wait_tensor_22 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_18); all_gather_into_tensor_18 = None + convert_element_type_54 = torch.ops.prims.convert_element_type.default(add_5, torch.float32); add_5 = None + pow_4 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_54, 2) + mean_3 = torch.ops.aten.mean.dim(pow_4, [2], True); pow_4 = None + add_6 = torch.ops.aten.add.Scalar(mean_3, 1e-05); mean_3 = None + rsqrt_3 = torch.ops.aten.rsqrt.default(add_6); add_6 = None + mul_12 = torch.ops.aten.mul.Tensor(convert_element_type_54, rsqrt_3); convert_element_type_54 = None + mul_13 = torch.ops.aten.mul.Tensor(mul_12, wait_tensor_22) + convert_element_type_55 = torch.ops.prims.convert_element_type.default(mul_13, torch.bfloat16); mul_13 = None + all_gather_into_tensor_19 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_55, 8, 
'1'); convert_element_type_55 = None + wait_tensor_23 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_19); all_gather_into_tensor_19 = None + split_15 = torch.ops.aten.split.Tensor(wait_tensor_23, 2); wait_tensor_23 = None + getitem_138 = split_15[0] + getitem_139 = split_15[1] + getitem_140 = split_15[2] + getitem_141 = split_15[3] + getitem_142 = split_15[4] + getitem_143 = split_15[5] + getitem_144 = split_15[6] + getitem_145 = split_15[7]; split_15 = None + cat_7 = torch.ops.aten.cat.default([getitem_138, getitem_139, getitem_140, getitem_141, getitem_142, getitem_143, getitem_144, getitem_145], 1); getitem_138 = getitem_139 = getitem_140 = getitem_141 = getitem_142 = getitem_143 = getitem_144 = getitem_145 = None + view_132 = torch.ops.aten.view.default(cat_7, [16384, 4096]); cat_7 = None + view_133 = torch.ops.aten.view.default(mm_11, [2, 8192, 1792]); mm_11 = None + convert_element_type_59 = torch.ops.prims.convert_element_type.default(view_133, torch.float32); view_133 = None + sigmoid_1 = torch.ops.aten.sigmoid.default(convert_element_type_59) + mul_14 = torch.ops.aten.mul.Tensor(convert_element_type_59, sigmoid_1); sigmoid_1 = None + convert_element_type_60 = torch.ops.prims.convert_element_type.default(mul_14, torch.bfloat16); mul_14 = None + convert_element_type_61 = torch.ops.prims.convert_element_type.default(primals_20, torch.bfloat16); primals_20 = None + all_gather_into_tensor_21 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_61, 8, '0'); convert_element_type_61 = None + wait_tensor_25 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_21); all_gather_into_tensor_21 = None + permute_20 = torch.ops.aten.permute.default(wait_tensor_25, [1, 0]); wait_tensor_25 = None + mm_12 = torch.ops.aten.mm.default(view_132, permute_20) + view_140 = torch.ops.aten.view.default(mm_12, [2, 8192, 1792]); mm_12 = None + mul_15 = torch.ops.aten.mul.Tensor(convert_element_type_60, view_140) + 
view_147 = torch.ops.aten.view.default(mul_15, [16384, 1792]); mul_15 = None + mm_311 = torch.ops.aten.mm.default(permute_629, view_147); permute_629 = view_147 = None + convert_element_type_64 = torch.ops.prims.convert_element_type.default(primals_21, torch.bfloat16); primals_21 = None + all_gather_into_tensor_22 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_64, 8, '0'); convert_element_type_64 = None + wait_tensor_26 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_22); all_gather_into_tensor_22 = None + permute_21 = torch.ops.aten.permute.default(wait_tensor_26, [1, 0]); wait_tensor_26 = None + permute_631 = torch.ops.aten.permute.default(permute_21, [1, 0]); permute_21 = None + mm_312 = torch.ops.aten.mm.default(view_1507, permute_631); view_1507 = permute_631 = None + view_1508 = torch.ops.aten.view.default(mm_312, [2, 8192, 1792]); mm_312 = None + convert_element_type_1306 = torch.ops.prims.convert_element_type.default(mm_311, torch.float32); mm_311 = None + reduce_scatter_tensor_190 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1306, 'avg', 8, '0'); convert_element_type_1306 = None + wait_tensor_428 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_190); reduce_scatter_tensor_190 = None + mul_416 = torch.ops.aten.mul.Tensor(view_1508, convert_element_type_60); convert_element_type_60 = None + mul_417 = torch.ops.aten.mul.Tensor(view_1508, view_140); view_1508 = view_140 = None + view_1509 = torch.ops.aten.view.default(mul_416, [16384, 1792]); mul_416 = None + permute_633 = torch.ops.aten.permute.default(view_1509, [1, 0]) + mm_313 = torch.ops.aten.mm.default(permute_633, view_132); permute_633 = None + permute_635 = torch.ops.aten.permute.default(permute_20, [1, 0]); permute_20 = None + mm_314 = torch.ops.aten.mm.default(view_1509, permute_635); view_1509 = permute_635 = None + view_1510 = torch.ops.aten.view.default(mm_314, [2, 8192, 4096]); 
mm_314 = None + convert_element_type_1311 = torch.ops.prims.convert_element_type.default(mm_313, torch.float32); mm_313 = None + reduce_scatter_tensor_191 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1311, 'avg', 8, '0'); convert_element_type_1311 = None + wait_tensor_429 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_191); reduce_scatter_tensor_191 = None + convert_element_type_1312 = torch.ops.prims.convert_element_type.default(mul_417, torch.float32); mul_417 = None + neg_14 = torch.ops.aten.neg.default(convert_element_type_59) + exp_14 = torch.ops.aten.exp.default(neg_14); neg_14 = None + add_163 = torch.ops.aten.add.Tensor(exp_14, 1); exp_14 = None + reciprocal_14 = torch.ops.aten.reciprocal.default(add_163); add_163 = None + mul_418 = torch.ops.aten.mul.Tensor(reciprocal_14, 1); reciprocal_14 = None + mul_419 = torch.ops.aten.mul.Tensor(convert_element_type_1312, mul_418); convert_element_type_1312 = None + sub_44 = torch.ops.aten.sub.Tensor(1, mul_418); mul_418 = None + mul_420 = torch.ops.aten.mul.Tensor(convert_element_type_59, sub_44); convert_element_type_59 = sub_44 = None + add_164 = torch.ops.aten.add.Tensor(mul_420, 1); mul_420 = None + mul_421 = torch.ops.aten.mul.Tensor(mul_419, add_164); mul_419 = add_164 = None + convert_element_type_1314 = torch.ops.prims.convert_element_type.default(mul_421, torch.bfloat16); mul_421 = None + view_1511 = torch.ops.aten.view.default(convert_element_type_1314, [16384, 1792]); convert_element_type_1314 = None + permute_637 = torch.ops.aten.permute.default(view_1511, [1, 0]) + mm_315 = torch.ops.aten.mm.default(permute_637, view_132); permute_637 = view_132 = None + convert_element_type_56 = torch.ops.prims.convert_element_type.default(primals_19, torch.bfloat16); primals_19 = None + all_gather_into_tensor_20 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_56, 8, '0'); convert_element_type_56 = None + wait_tensor_24 = 
torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_20); all_gather_into_tensor_20 = None + permute_19 = torch.ops.aten.permute.default(wait_tensor_24, [1, 0]); wait_tensor_24 = None + permute_639 = torch.ops.aten.permute.default(permute_19, [1, 0]); permute_19 = None + mm_316 = torch.ops.aten.mm.default(view_1511, permute_639); view_1511 = permute_639 = None + view_1512 = torch.ops.aten.view.default(mm_316, [2, 8192, 4096]); mm_316 = None + add_165 = torch.ops.aten.add.Tensor(view_1510, view_1512); view_1510 = view_1512 = None + convert_element_type_1319 = torch.ops.prims.convert_element_type.default(mm_315, torch.float32); mm_315 = None + reduce_scatter_tensor_192 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1319, 'avg', 8, '0'); convert_element_type_1319 = None + wait_tensor_430 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_192); reduce_scatter_tensor_192 = None + split_132 = torch.ops.aten.split.Tensor(add_165, 1024, 1); add_165 = None + getitem_1242 = split_132[0] + getitem_1243 = split_132[1] + getitem_1244 = split_132[2] + getitem_1245 = split_132[3] + getitem_1246 = split_132[4] + getitem_1247 = split_132[5] + getitem_1248 = split_132[6] + getitem_1249 = split_132[7]; split_132 = None + cat_124 = torch.ops.aten.cat.default([getitem_1242, getitem_1243, getitem_1244, getitem_1245, getitem_1246, getitem_1247, getitem_1248, getitem_1249]); getitem_1242 = getitem_1243 = getitem_1244 = getitem_1245 = getitem_1246 = getitem_1247 = getitem_1248 = getitem_1249 = None + reduce_scatter_tensor_193 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_124, 'sum', 8, '1'); cat_124 = None + wait_tensor_431 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_193); reduce_scatter_tensor_193 = None + convert_element_type_1320 = torch.ops.prims.convert_element_type.default(wait_tensor_431, torch.float32); wait_tensor_431 = None + convert_element_type_1322 = 
torch.ops.prims.convert_element_type.default(wait_tensor_22, torch.float32); wait_tensor_22 = None + mul_422 = torch.ops.aten.mul.Tensor(convert_element_type_1320, convert_element_type_1322); convert_element_type_1322 = None + mul_424 = torch.ops.aten.mul.Tensor(mul_12, mul_422) + sum_87 = torch.ops.aten.sum.dim_IntList(mul_424, [2], True); mul_424 = None + div_29 = torch.ops.aten.div.Tensor(mul_12, 4096) + mul_425 = torch.ops.aten.mul.Tensor(div_29, sum_87); div_29 = sum_87 = None + sub_45 = torch.ops.aten.sub.Tensor(mul_422, mul_425); mul_422 = mul_425 = None + mul_426 = torch.ops.aten.mul.Tensor(sub_45, rsqrt_3); sub_45 = rsqrt_3 = None + mul_427 = torch.ops.aten.mul.Tensor(convert_element_type_1320, mul_12); convert_element_type_1320 = mul_12 = None + sum_88 = torch.ops.aten.sum.dim_IntList(mul_427, [0, 1]); mul_427 = None + convert_element_type_1323 = torch.ops.prims.convert_element_type.default(mul_426, torch.bfloat16); mul_426 = None + convert_element_type_1324 = torch.ops.prims.convert_element_type.default(sum_88, torch.bfloat16); sum_88 = None + all_reduce_29 = torch.ops._c10d_functional.all_reduce.default(convert_element_type_1324, 'sum', '1'); convert_element_type_1324 = None + wait_tensor_432 = torch.ops._c10d_functional.wait_tensor.default(all_reduce_29); all_reduce_29 = None + convert_element_type_1325 = torch.ops.prims.convert_element_type.default(wait_tensor_432, torch.float32); wait_tensor_432 = None + reduce_scatter_tensor_194 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1325, 'avg', 8, '0'); convert_element_type_1325 = None + wait_tensor_433 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_194); reduce_scatter_tensor_194 = None + add_166 = torch.ops.aten.add.Tensor(add_162, convert_element_type_1323); add_162 = convert_element_type_1323 = None + all_gather_into_tensor_209 = torch.ops._c10d_functional.all_gather_into_tensor.default(add_166, 8, '1') + wait_tensor_434 = 
torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_209); all_gather_into_tensor_209 = None + split_133 = torch.ops.aten.split.Tensor(wait_tensor_434, 2); wait_tensor_434 = None + getitem_1250 = split_133[0] + getitem_1251 = split_133[1] + getitem_1252 = split_133[2] + getitem_1253 = split_133[3] + getitem_1254 = split_133[4] + getitem_1255 = split_133[5] + getitem_1256 = split_133[6] + getitem_1257 = split_133[7]; split_133 = None + cat_125 = torch.ops.aten.cat.default([getitem_1250, getitem_1251, getitem_1252, getitem_1253, getitem_1254, getitem_1255, getitem_1256, getitem_1257], 1); getitem_1250 = getitem_1251 = getitem_1252 = getitem_1253 = getitem_1254 = getitem_1255 = getitem_1256 = getitem_1257 = None + view_1513 = torch.ops.aten.view.default(cat_125, [16384, 4096]); cat_125 = None + permute_641 = torch.ops.aten.permute.default(view_1513, [1, 0]) + permute_17 = torch.ops.aten.permute.default(getitem_121, [0, 2, 1, 3]) + view_114 = torch.ops.aten.view.default(permute_17, [2, 8192, -1]); permute_17 = None + view_120 = torch.ops.aten.view.default(view_114, [16384, 512]); view_114 = None + mm_317 = torch.ops.aten.mm.default(permute_641, view_120); permute_641 = view_120 = None + convert_element_type_50 = torch.ops.prims.convert_element_type.default(primals_17, torch.bfloat16); primals_17 = None + all_gather_into_tensor_17 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_50, 8, '0'); convert_element_type_50 = None + wait_tensor_20 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_17); all_gather_into_tensor_17 = None + permute_18 = torch.ops.aten.permute.default(wait_tensor_20, [1, 0]); wait_tensor_20 = None + permute_643 = torch.ops.aten.permute.default(permute_18, [1, 0]); permute_18 = None + mm_318 = torch.ops.aten.mm.default(view_1513, permute_643); view_1513 = permute_643 = None + view_1514 = torch.ops.aten.view.default(mm_318, [2, 8192, 512]); mm_318 = None + convert_element_type_1330 = 
torch.ops.prims.convert_element_type.default(mm_317, torch.float32); mm_317 = None + reduce_scatter_tensor_195 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1330, 'avg', 8, '0'); convert_element_type_1330 = None + wait_tensor_435 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_195); reduce_scatter_tensor_195 = None + view_1515 = torch.ops.aten.view.default(view_1514, [2, 8192, 4, 128]); view_1514 = None + permute_645 = torch.ops.aten.permute.default(view_1515, [0, 2, 1, 3]); view_1515 = None + convert_element_type_34 = torch.ops.prims.convert_element_type.default(primals_13, torch.bfloat16); primals_13 = None + all_gather_into_tensor_12 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_34, 8, '0'); convert_element_type_34 = None + wait_tensor_15 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_12); all_gather_into_tensor_12 = None + convert_element_type_35 = torch.ops.prims.convert_element_type.default(add_3, torch.float32); add_3 = None + pow_3 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_35, 2) + mean_2 = torch.ops.aten.mean.dim(pow_3, [2], True); pow_3 = None + add_4 = torch.ops.aten.add.Scalar(mean_2, 1e-05); mean_2 = None + rsqrt_2 = torch.ops.aten.rsqrt.default(add_4); add_4 = None + mul_8 = torch.ops.aten.mul.Tensor(convert_element_type_35, rsqrt_2); convert_element_type_35 = None + mul_9 = torch.ops.aten.mul.Tensor(mul_8, wait_tensor_15) + convert_element_type_36 = torch.ops.prims.convert_element_type.default(mul_9, torch.bfloat16); mul_9 = None + all_gather_into_tensor_13 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_36, 8, '1'); convert_element_type_36 = None + wait_tensor_16 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_13); all_gather_into_tensor_13 = None + split_13 = torch.ops.aten.split.Tensor(wait_tensor_16, 2); wait_tensor_16 = None + getitem_113 = split_13[0] + 
getitem_114 = split_13[1] + getitem_115 = split_13[2] + getitem_116 = split_13[3] + getitem_117 = split_13[4] + getitem_118 = split_13[5] + getitem_119 = split_13[6] + getitem_120 = split_13[7]; split_13 = None + cat_5 = torch.ops.aten.cat.default([getitem_113, getitem_114, getitem_115, getitem_116, getitem_117, getitem_118, getitem_119, getitem_120], 1); getitem_113 = getitem_114 = getitem_115 = getitem_116 = getitem_117 = getitem_118 = getitem_119 = getitem_120 = None + view_87 = torch.ops.aten.view.default(cat_5, [16384, 4096]); cat_5 = None + view_88 = torch.ops.aten.view.default(mm_7, [2, 8192, 512]); mm_7 = None + convert_element_type_40 = torch.ops.prims.convert_element_type.default(primals_15, torch.bfloat16); primals_15 = None + all_gather_into_tensor_15 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_40, 8, '0'); convert_element_type_40 = None + wait_tensor_18 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_15); all_gather_into_tensor_15 = None + permute_12 = torch.ops.aten.permute.default(wait_tensor_18, [1, 0]); wait_tensor_18 = None + mm_8 = torch.ops.aten.mm.default(view_87, permute_12) + view_95 = torch.ops.aten.view.default(mm_8, [2, 8192, 128]); mm_8 = None + view_102 = torch.ops.aten.view.default(mm_9, [2, 8192, 128]); mm_9 = None + view_104 = torch.ops.aten.view.default(view_88, [2, 8192, -1, 128]); view_88 = None + view_105 = torch.ops.aten.view.default(view_95, [2, 8192, -1, 128]); view_95 = None + view_106 = torch.ops.aten.view.default(view_102, [2, 8192, -1, 128]); view_102 = None + convert_element_type_46 = torch.ops.prims.convert_element_type.default(view_104, torch.float32); view_104 = None + view_107 = torch.ops.aten.view.default(convert_element_type_46, [2, 8192, 4, -1, 2]); convert_element_type_46 = None + view_as_complex_2 = torch.ops.aten.view_as_complex.default(view_107); view_107 = None + convert_element_type_47 = torch.ops.prims.convert_element_type.default(view_105, 
torch.float32); view_105 = None + view_108 = torch.ops.aten.view.default(convert_element_type_47, [2, 8192, 1, -1, 2]); convert_element_type_47 = None + view_as_complex_3 = torch.ops.aten.view_as_complex.default(view_108); view_108 = None + mul_10 = torch.ops.aten.mul.Tensor(view_as_complex_2, view_37); view_as_complex_2 = None + view_as_real_2 = torch.ops.aten.view_as_real.default(mul_10); mul_10 = None + view_110 = torch.ops.aten.view.default(view_as_real_2, [2, 8192, 4, 128]); view_as_real_2 = None + mul_11 = torch.ops.aten.mul.Tensor(view_as_complex_3, view_37); view_as_complex_3 = None + view_as_real_3 = torch.ops.aten.view_as_real.default(mul_11); mul_11 = None + view_111 = torch.ops.aten.view.default(view_as_real_3, [2, 8192, 1, 128]); view_as_real_3 = None + convert_element_type_48 = torch.ops.prims.convert_element_type.default(view_110, torch.bfloat16); view_110 = None + convert_element_type_49 = torch.ops.prims.convert_element_type.default(view_111, torch.bfloat16); view_111 = None + unsqueeze_2 = torch.ops.aten.unsqueeze.default(convert_element_type_49, 3); convert_element_type_49 = None + expand_2 = torch.ops.aten.expand.default(unsqueeze_2, [2, 8192, 1, 4, 128]); unsqueeze_2 = None + view_112 = torch.ops.aten.view.default(expand_2, [2, 8192, 4, 128]); expand_2 = None + unsqueeze_3 = torch.ops.aten.unsqueeze.default(view_106, 3); view_106 = None + expand_3 = torch.ops.aten.expand.default(unsqueeze_3, [2, 8192, 1, 4, 128]); unsqueeze_3 = None + view_113 = torch.ops.aten.view.default(expand_3, [2, 8192, 4, 128]); expand_3 = None + permute_14 = torch.ops.aten.permute.default(convert_element_type_48, [0, 2, 1, 3]); convert_element_type_48 = None + permute_15 = torch.ops.aten.permute.default(view_112, [0, 2, 1, 3]); view_112 = None + permute_16 = torch.ops.aten.permute.default(view_113, [0, 2, 1, 3]); view_113 = None + _scaled_dot_product_cudnn_attention_backward_14 = torch.ops.aten._scaled_dot_product_cudnn_attention_backward.default(permute_645, 
permute_14, permute_15, permute_16, getitem_121, getitem_122, getitem_127, getitem_128, None, None, None, 8192, 8192, 0.0, True); permute_645 = permute_14 = permute_15 = permute_16 = getitem_121 = getitem_122 = getitem_127 = getitem_128 = None + getitem_1258 = _scaled_dot_product_cudnn_attention_backward_14[0] + getitem_1259 = _scaled_dot_product_cudnn_attention_backward_14[1] + getitem_1260 = _scaled_dot_product_cudnn_attention_backward_14[2]; _scaled_dot_product_cudnn_attention_backward_14 = None + permute_646 = torch.ops.aten.permute.default(getitem_1260, [0, 2, 1, 3]); getitem_1260 = None + permute_647 = torch.ops.aten.permute.default(getitem_1259, [0, 2, 1, 3]); getitem_1259 = None + permute_648 = torch.ops.aten.permute.default(getitem_1258, [0, 2, 1, 3]); getitem_1258 = None + view_1516 = torch.ops.aten.view.default(permute_646, [2, 8192, 1, 4, 128]); permute_646 = None + sum_89 = torch.ops.aten.sum.dim_IntList(view_1516, [3], True); view_1516 = None + squeeze_28 = torch.ops.aten.squeeze.dim(sum_89, 3); sum_89 = None + view_1517 = torch.ops.aten.view.default(permute_647, [2, 8192, 1, 4, 128]); permute_647 = None + sum_90 = torch.ops.aten.sum.dim_IntList(view_1517, [3], True); view_1517 = None + squeeze_29 = torch.ops.aten.squeeze.dim(sum_90, 3); sum_90 = None + convert_element_type_1331 = torch.ops.prims.convert_element_type.default(squeeze_29, torch.float32); squeeze_29 = None + convert_element_type_1332 = torch.ops.prims.convert_element_type.default(permute_648, torch.float32); permute_648 = None + view_1518 = torch.ops.aten.view.default(convert_element_type_1331, [2, 8192, 1, 64, 2]); convert_element_type_1331 = None + view_as_complex_60 = torch.ops.aten.view_as_complex.default(view_1518); view_1518 = None + mul_428 = torch.ops.aten.mul.Tensor(view_as_complex_60, _conj); view_as_complex_60 = None + view_1519 = torch.ops.aten.view.default(convert_element_type_1332, [2, 8192, 4, 64, 2]); convert_element_type_1332 = None + view_as_complex_61 = 
torch.ops.aten.view_as_complex.default(view_1519); view_1519 = None + mul_429 = torch.ops.aten.mul.Tensor(view_as_complex_61, _conj); view_as_complex_61 = None + view_as_real_60 = torch.ops.aten.view_as_real.default(mul_428); mul_428 = None + view_1520 = torch.ops.aten.view.default(view_as_real_60, [2, 8192, 1, 128]); view_as_real_60 = None + convert_element_type_1333 = torch.ops.prims.convert_element_type.default(view_1520, torch.bfloat16); view_1520 = None + view_as_real_61 = torch.ops.aten.view_as_real.default(mul_429); mul_429 = None + view_1521 = torch.ops.aten.view.default(view_as_real_61, [2, 8192, 4, 128]); view_as_real_61 = None + convert_element_type_1334 = torch.ops.prims.convert_element_type.default(view_1521, torch.bfloat16); view_1521 = None + view_1522 = torch.ops.aten.view.default(squeeze_28, [2, 8192, 128]); squeeze_28 = None + view_1523 = torch.ops.aten.view.default(convert_element_type_1333, [2, 8192, 128]); convert_element_type_1333 = None + view_1524 = torch.ops.aten.view.default(convert_element_type_1334, [2, 8192, 512]); convert_element_type_1334 = None + view_1525 = torch.ops.aten.view.default(view_1522, [16384, 128]); view_1522 = None + permute_649 = torch.ops.aten.permute.default(view_1525, [1, 0]) + mm_319 = torch.ops.aten.mm.default(permute_649, view_87); permute_649 = None + convert_element_type_43 = torch.ops.prims.convert_element_type.default(primals_16, torch.bfloat16); primals_16 = None + all_gather_into_tensor_16 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_43, 8, '0'); convert_element_type_43 = None + wait_tensor_19 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_16); all_gather_into_tensor_16 = None + permute_13 = torch.ops.aten.permute.default(wait_tensor_19, [1, 0]); wait_tensor_19 = None + permute_651 = torch.ops.aten.permute.default(permute_13, [1, 0]); permute_13 = None + mm_320 = torch.ops.aten.mm.default(view_1525, permute_651); view_1525 = permute_651 = None + 
view_1526 = torch.ops.aten.view.default(mm_320, [2, 8192, 4096]); mm_320 = None + convert_element_type_1339 = torch.ops.prims.convert_element_type.default(mm_319, torch.float32); mm_319 = None + reduce_scatter_tensor_196 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1339, 'avg', 8, '0'); convert_element_type_1339 = None + wait_tensor_436 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_196); reduce_scatter_tensor_196 = None + view_1527 = torch.ops.aten.view.default(view_1523, [16384, 128]); view_1523 = None + permute_653 = torch.ops.aten.permute.default(view_1527, [1, 0]) + mm_321 = torch.ops.aten.mm.default(permute_653, view_87); permute_653 = None + permute_655 = torch.ops.aten.permute.default(permute_12, [1, 0]); permute_12 = None + mm_322 = torch.ops.aten.mm.default(view_1527, permute_655); view_1527 = permute_655 = None + view_1528 = torch.ops.aten.view.default(mm_322, [2, 8192, 4096]); mm_322 = None + add_167 = torch.ops.aten.add.Tensor(view_1526, view_1528); view_1526 = view_1528 = None + convert_element_type_1344 = torch.ops.prims.convert_element_type.default(mm_321, torch.float32); mm_321 = None + reduce_scatter_tensor_197 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1344, 'avg', 8, '0'); convert_element_type_1344 = None + wait_tensor_437 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_197); reduce_scatter_tensor_197 = None + view_1529 = torch.ops.aten.view.default(view_1524, [16384, 512]); view_1524 = None + permute_657 = torch.ops.aten.permute.default(view_1529, [1, 0]) + mm_323 = torch.ops.aten.mm.default(permute_657, view_87); permute_657 = view_87 = None + convert_element_type_37 = torch.ops.prims.convert_element_type.default(primals_14, torch.bfloat16); primals_14 = None + all_gather_into_tensor_14 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_37, 8, '0'); convert_element_type_37 = None + wait_tensor_17 
= torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_14); all_gather_into_tensor_14 = None + permute_11 = torch.ops.aten.permute.default(wait_tensor_17, [1, 0]); wait_tensor_17 = None + permute_659 = torch.ops.aten.permute.default(permute_11, [1, 0]); permute_11 = None + mm_324 = torch.ops.aten.mm.default(view_1529, permute_659); view_1529 = permute_659 = None + view_1530 = torch.ops.aten.view.default(mm_324, [2, 8192, 4096]); mm_324 = None + add_168 = torch.ops.aten.add.Tensor(add_167, view_1530); add_167 = view_1530 = None + convert_element_type_1349 = torch.ops.prims.convert_element_type.default(mm_323, torch.float32); mm_323 = None + reduce_scatter_tensor_198 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1349, 'avg', 8, '0'); convert_element_type_1349 = None + wait_tensor_438 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_198); reduce_scatter_tensor_198 = None + split_134 = torch.ops.aten.split.Tensor(add_168, 1024, 1); add_168 = None + getitem_1261 = split_134[0] + getitem_1262 = split_134[1] + getitem_1263 = split_134[2] + getitem_1264 = split_134[3] + getitem_1265 = split_134[4] + getitem_1266 = split_134[5] + getitem_1267 = split_134[6] + getitem_1268 = split_134[7]; split_134 = None + cat_126 = torch.ops.aten.cat.default([getitem_1261, getitem_1262, getitem_1263, getitem_1264, getitem_1265, getitem_1266, getitem_1267, getitem_1268]); getitem_1261 = getitem_1262 = getitem_1263 = getitem_1264 = getitem_1265 = getitem_1266 = getitem_1267 = getitem_1268 = None + reduce_scatter_tensor_199 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_126, 'sum', 8, '1'); cat_126 = None + wait_tensor_439 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_199); reduce_scatter_tensor_199 = None + convert_element_type_1350 = torch.ops.prims.convert_element_type.default(wait_tensor_439, torch.float32); wait_tensor_439 = None + convert_element_type_1352 = 
torch.ops.prims.convert_element_type.default(wait_tensor_15, torch.float32); wait_tensor_15 = None + mul_430 = torch.ops.aten.mul.Tensor(convert_element_type_1350, convert_element_type_1352); convert_element_type_1352 = None + mul_432 = torch.ops.aten.mul.Tensor(mul_8, mul_430) + sum_91 = torch.ops.aten.sum.dim_IntList(mul_432, [2], True); mul_432 = None + div_30 = torch.ops.aten.div.Tensor(mul_8, 4096) + mul_433 = torch.ops.aten.mul.Tensor(div_30, sum_91); div_30 = sum_91 = None + sub_46 = torch.ops.aten.sub.Tensor(mul_430, mul_433); mul_430 = mul_433 = None + mul_434 = torch.ops.aten.mul.Tensor(sub_46, rsqrt_2); sub_46 = rsqrt_2 = None + mul_435 = torch.ops.aten.mul.Tensor(convert_element_type_1350, mul_8); convert_element_type_1350 = mul_8 = None + sum_92 = torch.ops.aten.sum.dim_IntList(mul_435, [0, 1]); mul_435 = None + convert_element_type_1353 = torch.ops.prims.convert_element_type.default(mul_434, torch.bfloat16); mul_434 = None + convert_element_type_1354 = torch.ops.prims.convert_element_type.default(sum_92, torch.bfloat16); sum_92 = None + all_reduce_30 = torch.ops._c10d_functional.all_reduce.default(convert_element_type_1354, 'sum', '1'); convert_element_type_1354 = None + wait_tensor_440 = torch.ops._c10d_functional.wait_tensor.default(all_reduce_30); all_reduce_30 = None + convert_element_type_1355 = torch.ops.prims.convert_element_type.default(wait_tensor_440, torch.float32); wait_tensor_440 = None + reduce_scatter_tensor_200 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1355, 'avg', 8, '0'); convert_element_type_1355 = None + wait_tensor_441 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_200); reduce_scatter_tensor_200 = None + add_169 = torch.ops.aten.add.Tensor(add_166, convert_element_type_1353); add_166 = convert_element_type_1353 = None + all_gather_into_tensor_210 = torch.ops._c10d_functional.all_gather_into_tensor.default(add_169, 8, '1') + wait_tensor_442 = 
torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_210); all_gather_into_tensor_210 = None + split_135 = torch.ops.aten.split.Tensor(wait_tensor_442, 2); wait_tensor_442 = None + getitem_1269 = split_135[0] + getitem_1270 = split_135[1] + getitem_1271 = split_135[2] + getitem_1272 = split_135[3] + getitem_1273 = split_135[4] + getitem_1274 = split_135[5] + getitem_1275 = split_135[6] + getitem_1276 = split_135[7]; split_135 = None + cat_127 = torch.ops.aten.cat.default([getitem_1269, getitem_1270, getitem_1271, getitem_1272, getitem_1273, getitem_1274, getitem_1275, getitem_1276], 1); getitem_1269 = getitem_1270 = getitem_1271 = getitem_1272 = getitem_1273 = getitem_1274 = getitem_1275 = getitem_1276 = None + view_1531 = torch.ops.aten.view.default(cat_127, [16384, 4096]); cat_127 = None + permute_661 = torch.ops.aten.permute.default(view_1531, [1, 0]) + wait_tensor_8 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_1); reduce_scatter_tensor_1 = None + add_1 = torch.ops.aten.add.Tensor(wait_tensor_1, wait_tensor_8); wait_tensor_8 = None + convert_element_type_20 = torch.ops.prims.convert_element_type.default(primals_9, torch.bfloat16); primals_9 = None + all_gather_into_tensor_7 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_20, 8, '0'); convert_element_type_20 = None + wait_tensor_9 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_7); all_gather_into_tensor_7 = None + convert_element_type_21 = torch.ops.prims.convert_element_type.default(add_1, torch.float32); add_1 = None + pow_2 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_21, 2) + mean_1 = torch.ops.aten.mean.dim(pow_2, [2], True); pow_2 = None + add_2 = torch.ops.aten.add.Scalar(mean_1, 1e-05); mean_1 = None + rsqrt_1 = torch.ops.aten.rsqrt.default(add_2); add_2 = None + mul_4 = torch.ops.aten.mul.Tensor(convert_element_type_21, rsqrt_1); convert_element_type_21 = None + mul_5 = 
torch.ops.aten.mul.Tensor(mul_4, wait_tensor_9) + convert_element_type_22 = torch.ops.prims.convert_element_type.default(mul_5, torch.bfloat16); mul_5 = None + all_gather_into_tensor_8 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_22, 8, '1'); convert_element_type_22 = None + wait_tensor_10 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_8); all_gather_into_tensor_8 = None + split_11 = torch.ops.aten.split.Tensor(wait_tensor_10, 2); wait_tensor_10 = None + getitem_97 = split_11[0] + getitem_98 = split_11[1] + getitem_99 = split_11[2] + getitem_100 = split_11[3] + getitem_101 = split_11[4] + getitem_102 = split_11[5] + getitem_103 = split_11[6] + getitem_104 = split_11[7]; split_11 = None + cat_3 = torch.ops.aten.cat.default([getitem_97, getitem_98, getitem_99, getitem_100, getitem_101, getitem_102, getitem_103, getitem_104], 1); getitem_97 = getitem_98 = getitem_99 = getitem_100 = getitem_101 = getitem_102 = getitem_103 = getitem_104 = None + view_60 = torch.ops.aten.view.default(cat_3, [16384, 4096]); cat_3 = None + view_61 = torch.ops.aten.view.default(mm_4, [2, 8192, 1792]); mm_4 = None + convert_element_type_26 = torch.ops.prims.convert_element_type.default(view_61, torch.float32); view_61 = None + sigmoid = torch.ops.aten.sigmoid.default(convert_element_type_26) + mul_6 = torch.ops.aten.mul.Tensor(convert_element_type_26, sigmoid); sigmoid = None + convert_element_type_27 = torch.ops.prims.convert_element_type.default(mul_6, torch.bfloat16); mul_6 = None + convert_element_type_28 = torch.ops.prims.convert_element_type.default(primals_11, torch.bfloat16); primals_11 = None + all_gather_into_tensor_10 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_28, 8, '0'); convert_element_type_28 = None + wait_tensor_12 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_10); all_gather_into_tensor_10 = None + permute_9 = 
torch.ops.aten.permute.default(wait_tensor_12, [1, 0]); wait_tensor_12 = None + mm_5 = torch.ops.aten.mm.default(view_60, permute_9) + view_68 = torch.ops.aten.view.default(mm_5, [2, 8192, 1792]); mm_5 = None + mul_7 = torch.ops.aten.mul.Tensor(convert_element_type_27, view_68) + view_75 = torch.ops.aten.view.default(mul_7, [16384, 1792]); mul_7 = None + mm_325 = torch.ops.aten.mm.default(permute_661, view_75); permute_661 = view_75 = None + convert_element_type_31 = torch.ops.prims.convert_element_type.default(primals_12, torch.bfloat16); primals_12 = None + all_gather_into_tensor_11 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_31, 8, '0'); convert_element_type_31 = None + wait_tensor_13 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_11); all_gather_into_tensor_11 = None + permute_10 = torch.ops.aten.permute.default(wait_tensor_13, [1, 0]); wait_tensor_13 = None + permute_663 = torch.ops.aten.permute.default(permute_10, [1, 0]); permute_10 = None + mm_326 = torch.ops.aten.mm.default(view_1531, permute_663); view_1531 = permute_663 = None + view_1532 = torch.ops.aten.view.default(mm_326, [2, 8192, 1792]); mm_326 = None + convert_element_type_1360 = torch.ops.prims.convert_element_type.default(mm_325, torch.float32); mm_325 = None + reduce_scatter_tensor_201 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1360, 'avg', 8, '0'); convert_element_type_1360 = None + wait_tensor_443 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_201); reduce_scatter_tensor_201 = None + mul_436 = torch.ops.aten.mul.Tensor(view_1532, convert_element_type_27); convert_element_type_27 = None + mul_437 = torch.ops.aten.mul.Tensor(view_1532, view_68); view_1532 = view_68 = None + view_1533 = torch.ops.aten.view.default(mul_436, [16384, 1792]); mul_436 = None + permute_665 = torch.ops.aten.permute.default(view_1533, [1, 0]) + mm_327 = torch.ops.aten.mm.default(permute_665, 
view_60); permute_665 = None + permute_667 = torch.ops.aten.permute.default(permute_9, [1, 0]); permute_9 = None + mm_328 = torch.ops.aten.mm.default(view_1533, permute_667); view_1533 = permute_667 = None + view_1534 = torch.ops.aten.view.default(mm_328, [2, 8192, 4096]); mm_328 = None + convert_element_type_1365 = torch.ops.prims.convert_element_type.default(mm_327, torch.float32); mm_327 = None + reduce_scatter_tensor_202 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1365, 'avg', 8, '0'); convert_element_type_1365 = None + wait_tensor_444 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_202); reduce_scatter_tensor_202 = None + convert_element_type_1366 = torch.ops.prims.convert_element_type.default(mul_437, torch.float32); mul_437 = None + neg_15 = torch.ops.aten.neg.default(convert_element_type_26) + exp_15 = torch.ops.aten.exp.default(neg_15); neg_15 = None + add_170 = torch.ops.aten.add.Tensor(exp_15, 1); exp_15 = None + reciprocal_15 = torch.ops.aten.reciprocal.default(add_170); add_170 = None + mul_438 = torch.ops.aten.mul.Tensor(reciprocal_15, 1); reciprocal_15 = None + mul_439 = torch.ops.aten.mul.Tensor(convert_element_type_1366, mul_438); convert_element_type_1366 = None + sub_47 = torch.ops.aten.sub.Tensor(1, mul_438); mul_438 = None + mul_440 = torch.ops.aten.mul.Tensor(convert_element_type_26, sub_47); convert_element_type_26 = sub_47 = None + add_171 = torch.ops.aten.add.Tensor(mul_440, 1); mul_440 = None + mul_441 = torch.ops.aten.mul.Tensor(mul_439, add_171); mul_439 = add_171 = None + convert_element_type_1368 = torch.ops.prims.convert_element_type.default(mul_441, torch.bfloat16); mul_441 = None + view_1535 = torch.ops.aten.view.default(convert_element_type_1368, [16384, 1792]); convert_element_type_1368 = None + permute_669 = torch.ops.aten.permute.default(view_1535, [1, 0]) + mm_329 = torch.ops.aten.mm.default(permute_669, view_60); permute_669 = view_60 = None + convert_element_type_23 = 
torch.ops.prims.convert_element_type.default(primals_10, torch.bfloat16); primals_10 = None + all_gather_into_tensor_9 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_23, 8, '0'); convert_element_type_23 = None + wait_tensor_11 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_9); all_gather_into_tensor_9 = None + permute_8 = torch.ops.aten.permute.default(wait_tensor_11, [1, 0]); wait_tensor_11 = None + permute_671 = torch.ops.aten.permute.default(permute_8, [1, 0]); permute_8 = None + mm_330 = torch.ops.aten.mm.default(view_1535, permute_671); view_1535 = permute_671 = None + view_1536 = torch.ops.aten.view.default(mm_330, [2, 8192, 4096]); mm_330 = None + add_172 = torch.ops.aten.add.Tensor(view_1534, view_1536); view_1534 = view_1536 = None + convert_element_type_1373 = torch.ops.prims.convert_element_type.default(mm_329, torch.float32); mm_329 = None + reduce_scatter_tensor_203 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1373, 'avg', 8, '0'); convert_element_type_1373 = None + wait_tensor_445 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_203); reduce_scatter_tensor_203 = None + split_136 = torch.ops.aten.split.Tensor(add_172, 1024, 1); add_172 = None + getitem_1277 = split_136[0] + getitem_1278 = split_136[1] + getitem_1279 = split_136[2] + getitem_1280 = split_136[3] + getitem_1281 = split_136[4] + getitem_1282 = split_136[5] + getitem_1283 = split_136[6] + getitem_1284 = split_136[7]; split_136 = None + cat_128 = torch.ops.aten.cat.default([getitem_1277, getitem_1278, getitem_1279, getitem_1280, getitem_1281, getitem_1282, getitem_1283, getitem_1284]); getitem_1277 = getitem_1278 = getitem_1279 = getitem_1280 = getitem_1281 = getitem_1282 = getitem_1283 = getitem_1284 = None + reduce_scatter_tensor_204 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_128, 'sum', 8, '1'); cat_128 = None + wait_tensor_446 = 
torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_204); reduce_scatter_tensor_204 = None + convert_element_type_1374 = torch.ops.prims.convert_element_type.default(wait_tensor_446, torch.float32); wait_tensor_446 = None + convert_element_type_1376 = torch.ops.prims.convert_element_type.default(wait_tensor_9, torch.float32); wait_tensor_9 = None + mul_442 = torch.ops.aten.mul.Tensor(convert_element_type_1374, convert_element_type_1376); convert_element_type_1376 = None + mul_444 = torch.ops.aten.mul.Tensor(mul_4, mul_442) + sum_93 = torch.ops.aten.sum.dim_IntList(mul_444, [2], True); mul_444 = None + div_31 = torch.ops.aten.div.Tensor(mul_4, 4096) + mul_445 = torch.ops.aten.mul.Tensor(div_31, sum_93); div_31 = sum_93 = None + sub_48 = torch.ops.aten.sub.Tensor(mul_442, mul_445); mul_442 = mul_445 = None + mul_446 = torch.ops.aten.mul.Tensor(sub_48, rsqrt_1); sub_48 = rsqrt_1 = None + mul_447 = torch.ops.aten.mul.Tensor(convert_element_type_1374, mul_4); convert_element_type_1374 = mul_4 = None + sum_94 = torch.ops.aten.sum.dim_IntList(mul_447, [0, 1]); mul_447 = None + convert_element_type_1377 = torch.ops.prims.convert_element_type.default(mul_446, torch.bfloat16); mul_446 = None + convert_element_type_1378 = torch.ops.prims.convert_element_type.default(sum_94, torch.bfloat16); sum_94 = None + all_reduce_31 = torch.ops._c10d_functional.all_reduce.default(convert_element_type_1378, 'sum', '1'); convert_element_type_1378 = None + wait_tensor_447 = torch.ops._c10d_functional.wait_tensor.default(all_reduce_31); all_reduce_31 = None + convert_element_type_1379 = torch.ops.prims.convert_element_type.default(wait_tensor_447, torch.float32); wait_tensor_447 = None + reduce_scatter_tensor_205 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1379, 'avg', 8, '0'); convert_element_type_1379 = None + wait_tensor_448 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_205); reduce_scatter_tensor_205 = None + 
add_173 = torch.ops.aten.add.Tensor(add_169, convert_element_type_1377); add_169 = convert_element_type_1377 = None + all_gather_into_tensor_211 = torch.ops._c10d_functional.all_gather_into_tensor.default(add_173, 8, '1') + wait_tensor_449 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_211); all_gather_into_tensor_211 = None + split_137 = torch.ops.aten.split.Tensor(wait_tensor_449, 2); wait_tensor_449 = None + getitem_1285 = split_137[0] + getitem_1286 = split_137[1] + getitem_1287 = split_137[2] + getitem_1288 = split_137[3] + getitem_1289 = split_137[4] + getitem_1290 = split_137[5] + getitem_1291 = split_137[6] + getitem_1292 = split_137[7]; split_137 = None + cat_129 = torch.ops.aten.cat.default([getitem_1285, getitem_1286, getitem_1287, getitem_1288, getitem_1289, getitem_1290, getitem_1291, getitem_1292], 1); getitem_1285 = getitem_1286 = getitem_1287 = getitem_1288 = getitem_1289 = getitem_1290 = getitem_1291 = getitem_1292 = None + view_1537 = torch.ops.aten.view.default(cat_129, [16384, 4096]); cat_129 = None + permute_673 = torch.ops.aten.permute.default(view_1537, [1, 0]) + permute_6 = torch.ops.aten.permute.default(getitem_80, [0, 2, 1, 3]) + view_42 = torch.ops.aten.view.default(permute_6, [2, 8192, -1]); permute_6 = None + view_48 = torch.ops.aten.view.default(view_42, [16384, 512]); view_42 = None + mm_331 = torch.ops.aten.mm.default(permute_673, view_48); permute_673 = view_48 = None + convert_element_type_17 = torch.ops.prims.convert_element_type.default(primals_8, torch.bfloat16); primals_8 = None + all_gather_into_tensor_6 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_17, 8, '0'); convert_element_type_17 = None + wait_tensor_7 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_6); all_gather_into_tensor_6 = None + permute_7 = torch.ops.aten.permute.default(wait_tensor_7, [1, 0]); wait_tensor_7 = None + permute_675 = torch.ops.aten.permute.default(permute_7, [1, 0]); 
permute_7 = None + mm_332 = torch.ops.aten.mm.default(view_1537, permute_675); view_1537 = permute_675 = None + view_1538 = torch.ops.aten.view.default(mm_332, [2, 8192, 512]); mm_332 = None + convert_element_type_1384 = torch.ops.prims.convert_element_type.default(mm_331, torch.float32); mm_331 = None + reduce_scatter_tensor_206 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1384, 'avg', 8, '0'); convert_element_type_1384 = None + wait_tensor_450 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_206); reduce_scatter_tensor_206 = None + view_1539 = torch.ops.aten.view.default(view_1538, [2, 8192, 4, 128]); view_1538 = None + permute_677 = torch.ops.aten.permute.default(view_1539, [0, 2, 1, 3]); view_1539 = None + convert_element_type_1 = torch.ops.prims.convert_element_type.default(primals_4, torch.bfloat16); primals_4 = None + all_gather_into_tensor_1 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1, 8, '0'); convert_element_type_1 = None + wait_tensor_2 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_1); all_gather_into_tensor_1 = None + convert_element_type_2 = torch.ops.prims.convert_element_type.default(wait_tensor_1, torch.float32); wait_tensor_1 = None + pow_1 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_2, 2) + mean = torch.ops.aten.mean.dim(pow_1, [2], True); pow_1 = None + add = torch.ops.aten.add.Scalar(mean, 1e-05); mean = None + rsqrt = torch.ops.aten.rsqrt.default(add); add = None + mul = torch.ops.aten.mul.Tensor(convert_element_type_2, rsqrt); convert_element_type_2 = None + mul_1 = torch.ops.aten.mul.Tensor(mul, wait_tensor_2) + convert_element_type_3 = torch.ops.prims.convert_element_type.default(mul_1, torch.bfloat16); mul_1 = None + all_gather_into_tensor_2 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_3, 8, '1'); convert_element_type_3 = None + wait_tensor_3 = 
torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_2); all_gather_into_tensor_2 = None + split_9 = torch.ops.aten.split.Tensor(wait_tensor_3, 2); wait_tensor_3 = None + getitem_72 = split_9[0] + getitem_73 = split_9[1] + getitem_74 = split_9[2] + getitem_75 = split_9[3] + getitem_76 = split_9[4] + getitem_77 = split_9[5] + getitem_78 = split_9[6] + getitem_79 = split_9[7]; split_9 = None + cat_1 = torch.ops.aten.cat.default([getitem_72, getitem_73, getitem_74, getitem_75, getitem_76, getitem_77, getitem_78, getitem_79], 1); getitem_72 = getitem_73 = getitem_74 = getitem_75 = getitem_76 = getitem_77 = getitem_78 = getitem_79 = None + view_15 = torch.ops.aten.view.default(cat_1, [16384, 4096]); cat_1 = None + view_16 = torch.ops.aten.view.default(mm, [2, 8192, 512]); mm = None + convert_element_type_7 = torch.ops.prims.convert_element_type.default(primals_6, torch.bfloat16); primals_6 = None + all_gather_into_tensor_4 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_7, 8, '0'); convert_element_type_7 = None + wait_tensor_5 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_4); all_gather_into_tensor_4 = None + permute_1 = torch.ops.aten.permute.default(wait_tensor_5, [1, 0]); wait_tensor_5 = None + mm_1 = torch.ops.aten.mm.default(view_15, permute_1) + view_23 = torch.ops.aten.view.default(mm_1, [2, 8192, 128]); mm_1 = None + view_30 = torch.ops.aten.view.default(mm_2, [2, 8192, 128]); mm_2 = None + view_32 = torch.ops.aten.view.default(view_16, [2, 8192, -1, 128]); view_16 = None + view_33 = torch.ops.aten.view.default(view_23, [2, 8192, -1, 128]); view_23 = None + view_34 = torch.ops.aten.view.default(view_30, [2, 8192, -1, 128]); view_30 = None + convert_element_type_13 = torch.ops.prims.convert_element_type.default(view_32, torch.float32); view_32 = None + view_35 = torch.ops.aten.view.default(convert_element_type_13, [2, 8192, 4, -1, 2]); convert_element_type_13 = None + view_as_complex = 
torch.ops.aten.view_as_complex.default(view_35); view_35 = None + convert_element_type_14 = torch.ops.prims.convert_element_type.default(view_33, torch.float32); view_33 = None + view_36 = torch.ops.aten.view.default(convert_element_type_14, [2, 8192, 1, -1, 2]); convert_element_type_14 = None + view_as_complex_1 = torch.ops.aten.view_as_complex.default(view_36); view_36 = None + mul_2 = torch.ops.aten.mul.Tensor(view_as_complex, view_37); view_as_complex = None + view_as_real = torch.ops.aten.view_as_real.default(mul_2); mul_2 = None + view_38 = torch.ops.aten.view.default(view_as_real, [2, 8192, 4, 128]); view_as_real = None + mul_3 = torch.ops.aten.mul.Tensor(view_as_complex_1, view_37); view_as_complex_1 = view_37 = None + view_as_real_1 = torch.ops.aten.view_as_real.default(mul_3); mul_3 = None + view_39 = torch.ops.aten.view.default(view_as_real_1, [2, 8192, 1, 128]); view_as_real_1 = None + convert_element_type_15 = torch.ops.prims.convert_element_type.default(view_38, torch.bfloat16); view_38 = None + convert_element_type_16 = torch.ops.prims.convert_element_type.default(view_39, torch.bfloat16); view_39 = None + unsqueeze = torch.ops.aten.unsqueeze.default(convert_element_type_16, 3); convert_element_type_16 = None + expand = torch.ops.aten.expand.default(unsqueeze, [2, 8192, 1, 4, 128]); unsqueeze = None + view_40 = torch.ops.aten.view.default(expand, [2, 8192, 4, 128]); expand = None + unsqueeze_1 = torch.ops.aten.unsqueeze.default(view_34, 3); view_34 = None + expand_1 = torch.ops.aten.expand.default(unsqueeze_1, [2, 8192, 1, 4, 128]); unsqueeze_1 = None + view_41 = torch.ops.aten.view.default(expand_1, [2, 8192, 4, 128]); expand_1 = None + permute_3 = torch.ops.aten.permute.default(convert_element_type_15, [0, 2, 1, 3]); convert_element_type_15 = None + permute_4 = torch.ops.aten.permute.default(view_40, [0, 2, 1, 3]); view_40 = None + permute_5 = torch.ops.aten.permute.default(view_41, [0, 2, 1, 3]); view_41 = None + 
_scaled_dot_product_cudnn_attention_backward_15 = torch.ops.aten._scaled_dot_product_cudnn_attention_backward.default(permute_677, permute_3, permute_4, permute_5, getitem_80, getitem_81, getitem_86, getitem_87, None, None, None, 8192, 8192, 0.0, True); permute_677 = permute_3 = permute_4 = permute_5 = getitem_80 = getitem_81 = getitem_86 = getitem_87 = None + getitem_1293 = _scaled_dot_product_cudnn_attention_backward_15[0] + getitem_1294 = _scaled_dot_product_cudnn_attention_backward_15[1] + getitem_1295 = _scaled_dot_product_cudnn_attention_backward_15[2]; _scaled_dot_product_cudnn_attention_backward_15 = None + permute_678 = torch.ops.aten.permute.default(getitem_1295, [0, 2, 1, 3]); getitem_1295 = None + permute_679 = torch.ops.aten.permute.default(getitem_1294, [0, 2, 1, 3]); getitem_1294 = None + permute_680 = torch.ops.aten.permute.default(getitem_1293, [0, 2, 1, 3]); getitem_1293 = None + view_1540 = torch.ops.aten.view.default(permute_678, [2, 8192, 1, 4, 128]); permute_678 = None + sum_95 = torch.ops.aten.sum.dim_IntList(view_1540, [3], True); view_1540 = None + squeeze_30 = torch.ops.aten.squeeze.dim(sum_95, 3); sum_95 = None + view_1541 = torch.ops.aten.view.default(permute_679, [2, 8192, 1, 4, 128]); permute_679 = None + sum_96 = torch.ops.aten.sum.dim_IntList(view_1541, [3], True); view_1541 = None + squeeze_31 = torch.ops.aten.squeeze.dim(sum_96, 3); sum_96 = None + convert_element_type_1385 = torch.ops.prims.convert_element_type.default(squeeze_31, torch.float32); squeeze_31 = None + convert_element_type_1386 = torch.ops.prims.convert_element_type.default(permute_680, torch.float32); permute_680 = None + view_1542 = torch.ops.aten.view.default(convert_element_type_1385, [2, 8192, 1, 64, 2]); convert_element_type_1385 = None + view_as_complex_62 = torch.ops.aten.view_as_complex.default(view_1542); view_1542 = None + mul_448 = torch.ops.aten.mul.Tensor(view_as_complex_62, _conj); view_as_complex_62 = None + view_1543 = 
torch.ops.aten.view.default(convert_element_type_1386, [2, 8192, 4, 64, 2]); convert_element_type_1386 = None + view_as_complex_63 = torch.ops.aten.view_as_complex.default(view_1543); view_1543 = None + mul_449 = torch.ops.aten.mul.Tensor(view_as_complex_63, _conj); view_as_complex_63 = _conj = None + view_as_real_62 = torch.ops.aten.view_as_real.default(mul_448); mul_448 = None + view_1544 = torch.ops.aten.view.default(view_as_real_62, [2, 8192, 1, 128]); view_as_real_62 = None + convert_element_type_1387 = torch.ops.prims.convert_element_type.default(view_1544, torch.bfloat16); view_1544 = None + view_as_real_63 = torch.ops.aten.view_as_real.default(mul_449); mul_449 = None + view_1545 = torch.ops.aten.view.default(view_as_real_63, [2, 8192, 4, 128]); view_as_real_63 = None + convert_element_type_1388 = torch.ops.prims.convert_element_type.default(view_1545, torch.bfloat16); view_1545 = None + view_1546 = torch.ops.aten.view.default(squeeze_30, [2, 8192, 128]); squeeze_30 = None + view_1547 = torch.ops.aten.view.default(convert_element_type_1387, [2, 8192, 128]); convert_element_type_1387 = None + view_1548 = torch.ops.aten.view.default(convert_element_type_1388, [2, 8192, 512]); convert_element_type_1388 = None + view_1549 = torch.ops.aten.view.default(view_1546, [16384, 128]); view_1546 = None + permute_681 = torch.ops.aten.permute.default(view_1549, [1, 0]) + mm_333 = torch.ops.aten.mm.default(permute_681, view_15); permute_681 = None + convert_element_type_10 = torch.ops.prims.convert_element_type.default(primals_7, torch.bfloat16); primals_7 = None + all_gather_into_tensor_5 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_10, 8, '0'); convert_element_type_10 = None + wait_tensor_6 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_5); all_gather_into_tensor_5 = None + permute_2 = torch.ops.aten.permute.default(wait_tensor_6, [1, 0]); wait_tensor_6 = None + permute_683 = 
torch.ops.aten.permute.default(permute_2, [1, 0]); permute_2 = None + mm_334 = torch.ops.aten.mm.default(view_1549, permute_683); view_1549 = permute_683 = None + view_1550 = torch.ops.aten.view.default(mm_334, [2, 8192, 4096]); mm_334 = None + convert_element_type_1393 = torch.ops.prims.convert_element_type.default(mm_333, torch.float32); mm_333 = None + reduce_scatter_tensor_207 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1393, 'avg', 8, '0'); convert_element_type_1393 = None + wait_tensor_451 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_207); reduce_scatter_tensor_207 = None + view_1551 = torch.ops.aten.view.default(view_1547, [16384, 128]); view_1547 = None + permute_685 = torch.ops.aten.permute.default(view_1551, [1, 0]) + mm_335 = torch.ops.aten.mm.default(permute_685, view_15); permute_685 = None + permute_687 = torch.ops.aten.permute.default(permute_1, [1, 0]); permute_1 = None + mm_336 = torch.ops.aten.mm.default(view_1551, permute_687); view_1551 = permute_687 = None + view_1552 = torch.ops.aten.view.default(mm_336, [2, 8192, 4096]); mm_336 = None + add_174 = torch.ops.aten.add.Tensor(view_1550, view_1552); view_1550 = view_1552 = None + convert_element_type_1398 = torch.ops.prims.convert_element_type.default(mm_335, torch.float32); mm_335 = None + reduce_scatter_tensor_208 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1398, 'avg', 8, '0'); convert_element_type_1398 = None + wait_tensor_452 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_208); reduce_scatter_tensor_208 = None + view_1553 = torch.ops.aten.view.default(view_1548, [16384, 512]); view_1548 = None + permute_689 = torch.ops.aten.permute.default(view_1553, [1, 0]) + mm_337 = torch.ops.aten.mm.default(permute_689, view_15); permute_689 = view_15 = None + convert_element_type_4 = torch.ops.prims.convert_element_type.default(primals_5, torch.bfloat16); primals_5 = None + 
all_gather_into_tensor_3 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_4, 8, '0'); convert_element_type_4 = None + wait_tensor_4 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_3); all_gather_into_tensor_3 = None + permute = torch.ops.aten.permute.default(wait_tensor_4, [1, 0]); wait_tensor_4 = None + permute_691 = torch.ops.aten.permute.default(permute, [1, 0]); permute = None + mm_338 = torch.ops.aten.mm.default(view_1553, permute_691); view_1553 = permute_691 = None + view_1554 = torch.ops.aten.view.default(mm_338, [2, 8192, 4096]); mm_338 = None + add_175 = torch.ops.aten.add.Tensor(add_174, view_1554); add_174 = view_1554 = None + convert_element_type_1403 = torch.ops.prims.convert_element_type.default(mm_337, torch.float32); mm_337 = None + reduce_scatter_tensor_209 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1403, 'avg', 8, '0'); convert_element_type_1403 = None + wait_tensor_453 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_209); reduce_scatter_tensor_209 = None + split_138 = torch.ops.aten.split.Tensor(add_175, 1024, 1); add_175 = None + getitem_1296 = split_138[0] + getitem_1297 = split_138[1] + getitem_1298 = split_138[2] + getitem_1299 = split_138[3] + getitem_1300 = split_138[4] + getitem_1301 = split_138[5] + getitem_1302 = split_138[6] + getitem_1303 = split_138[7]; split_138 = None + cat_130 = torch.ops.aten.cat.default([getitem_1296, getitem_1297, getitem_1298, getitem_1299, getitem_1300, getitem_1301, getitem_1302, getitem_1303]); getitem_1296 = getitem_1297 = getitem_1298 = getitem_1299 = getitem_1300 = getitem_1301 = getitem_1302 = getitem_1303 = None + reduce_scatter_tensor_210 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_130, 'sum', 8, '1'); cat_130 = None + wait_tensor_454 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_210); reduce_scatter_tensor_210 = None + 
convert_element_type_1404 = torch.ops.prims.convert_element_type.default(wait_tensor_454, torch.float32); wait_tensor_454 = None + convert_element_type_1406 = torch.ops.prims.convert_element_type.default(wait_tensor_2, torch.float32); wait_tensor_2 = None + mul_450 = torch.ops.aten.mul.Tensor(convert_element_type_1404, convert_element_type_1406); convert_element_type_1406 = None + mul_452 = torch.ops.aten.mul.Tensor(mul, mul_450) + sum_97 = torch.ops.aten.sum.dim_IntList(mul_452, [2], True); mul_452 = None + div_32 = torch.ops.aten.div.Tensor(mul, 4096) + mul_453 = torch.ops.aten.mul.Tensor(div_32, sum_97); div_32 = sum_97 = None + sub_49 = torch.ops.aten.sub.Tensor(mul_450, mul_453); mul_450 = mul_453 = None + mul_454 = torch.ops.aten.mul.Tensor(sub_49, rsqrt); sub_49 = rsqrt = None + mul_455 = torch.ops.aten.mul.Tensor(convert_element_type_1404, mul); convert_element_type_1404 = mul = None + sum_98 = torch.ops.aten.sum.dim_IntList(mul_455, [0, 1]); mul_455 = None + convert_element_type_1407 = torch.ops.prims.convert_element_type.default(mul_454, torch.bfloat16); mul_454 = None + convert_element_type_1408 = torch.ops.prims.convert_element_type.default(sum_98, torch.bfloat16); sum_98 = None + all_reduce_32 = torch.ops._c10d_functional.all_reduce.default(convert_element_type_1408, 'sum', '1'); convert_element_type_1408 = None + wait_tensor_455 = torch.ops._c10d_functional.wait_tensor.default(all_reduce_32); all_reduce_32 = None + convert_element_type_1409 = torch.ops.prims.convert_element_type.default(wait_tensor_455, torch.float32); wait_tensor_455 = None + reduce_scatter_tensor_211 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1409, 'avg', 8, '0'); convert_element_type_1409 = None + wait_tensor_456 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_211); reduce_scatter_tensor_211 = None + add_176 = torch.ops.aten.add.Tensor(add_173, convert_element_type_1407); add_173 = convert_element_type_1407 = None + 
all_gather_into_tensor_212 = torch.ops._c10d_functional.all_gather_into_tensor.default(add_176, 8, '1'); add_176 = None + wait_tensor_457 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_212); all_gather_into_tensor_212 = None + split_139 = torch.ops.aten.split.Tensor(wait_tensor_457, 2); wait_tensor_457 = None + getitem_1304 = split_139[0] + getitem_1305 = split_139[1] + getitem_1306 = split_139[2] + getitem_1307 = split_139[3] + getitem_1308 = split_139[4] + getitem_1309 = split_139[5] + getitem_1310 = split_139[6] + getitem_1311 = split_139[7]; split_139 = None + cat_131 = torch.ops.aten.cat.default([getitem_1304, getitem_1305, getitem_1306, getitem_1307, getitem_1308, getitem_1309, getitem_1310, getitem_1311], 1); getitem_1304 = getitem_1305 = getitem_1306 = getitem_1307 = getitem_1308 = getitem_1309 = getitem_1310 = getitem_1311 = None + convert_element_type_1410 = torch.ops.prims.convert_element_type.default(cat_131, torch.float32); cat_131 = None + eq = torch.ops.aten.eq.Scalar(primals_1, -1) + unsqueeze_32 = torch.ops.aten.unsqueeze.default(eq, -1); eq = None + full_default_2 = torch.ops.aten.full.default([], 0.0, dtype = torch.float32, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False) + where = torch.ops.aten.where.self(unsqueeze_32, full_default_2, convert_element_type_1410); unsqueeze_32 = full_default_2 = convert_element_type_1410 = None + full_default_3 = torch.ops.aten.full.default([128256, 4096], 0, dtype = torch.float32, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False) + index_put_2 = torch.ops.aten.index_put.default(full_default_3, [primals_1], where, True); full_default_3 = primals_1 = where = None + convert_element_type_1411 = torch.ops.prims.convert_element_type.default(index_put_2, torch.bfloat16); index_put_2 = None + split_140 = torch.ops.aten.split.Tensor(convert_element_type_1411, 16032); convert_element_type_1411 = None + getitem_1312 = split_140[0]; 
split_140 = None + convert_element_type_1412 = torch.ops.prims.convert_element_type.default(getitem_1312, torch.float32); getitem_1312 = None + reduce_scatter_tensor_212 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1412, 'avg', 8, '0'); convert_element_type_1412 = None + wait_tensor_458 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_212); reduce_scatter_tensor_212 = None + return (None, wait_tensor_458, None, wait_tensor_456, wait_tensor_453, wait_tensor_452, wait_tensor_451, wait_tensor_450, wait_tensor_448, wait_tensor_445, wait_tensor_444, wait_tensor_443, wait_tensor_441, wait_tensor_438, wait_tensor_437, wait_tensor_436, wait_tensor_435, wait_tensor_433, wait_tensor_430, wait_tensor_429, wait_tensor_428, wait_tensor_426, wait_tensor_423, wait_tensor_422, wait_tensor_421, wait_tensor_420, wait_tensor_418, wait_tensor_415, wait_tensor_414, wait_tensor_413, wait_tensor_411, wait_tensor_408, wait_tensor_407, wait_tensor_406, wait_tensor_405, wait_tensor_403, wait_tensor_400, wait_tensor_399, wait_tensor_398, wait_tensor_396, wait_tensor_393, wait_tensor_392, wait_tensor_391, wait_tensor_390, wait_tensor_388, wait_tensor_385, wait_tensor_384, wait_tensor_383, wait_tensor_381, wait_tensor_378, wait_tensor_377, wait_tensor_376, wait_tensor_375, wait_tensor_373, wait_tensor_370, wait_tensor_369, wait_tensor_368, wait_tensor_366, wait_tensor_363, wait_tensor_362, wait_tensor_361, wait_tensor_360, wait_tensor_358, wait_tensor_355, wait_tensor_354, wait_tensor_353, wait_tensor_351, wait_tensor_348, wait_tensor_347, wait_tensor_346, wait_tensor_345, wait_tensor_343, wait_tensor_340, wait_tensor_339, wait_tensor_338, wait_tensor_336, wait_tensor_333, wait_tensor_332, wait_tensor_331, wait_tensor_330, wait_tensor_328, wait_tensor_325, wait_tensor_324, wait_tensor_323, wait_tensor_321, wait_tensor_318, wait_tensor_317, wait_tensor_316, wait_tensor_315, wait_tensor_313, wait_tensor_310, wait_tensor_309, 
wait_tensor_308, wait_tensor_306, wait_tensor_303, wait_tensor_302, wait_tensor_301, wait_tensor_300, wait_tensor_298, wait_tensor_295, wait_tensor_294, wait_tensor_293, wait_tensor_291, wait_tensor_288, wait_tensor_287, wait_tensor_286, wait_tensor_285, wait_tensor_283, wait_tensor_280, wait_tensor_279, wait_tensor_278, wait_tensor_276, wait_tensor_273, wait_tensor_272, wait_tensor_271, wait_tensor_270, wait_tensor_268, wait_tensor_265, wait_tensor_264, wait_tensor_263, wait_tensor_261, wait_tensor_258, wait_tensor_257, wait_tensor_256, wait_tensor_255, wait_tensor_253, wait_tensor_250, wait_tensor_249, wait_tensor_248, wait_tensor_246, wait_tensor_243, wait_tensor_242, wait_tensor_241, wait_tensor_240, wait_tensor_238, wait_tensor_235, wait_tensor_234, wait_tensor_233, wait_tensor_231, wait_tensor_228, wait_tensor_227, wait_tensor_226, wait_tensor_225, wait_tensor_223, wait_tensor_220, wait_tensor_219, wait_tensor_218, wait_tensor_216, wait_tensor_213) + +def load_args(reader): + buf0 = reader.storage(None, 131072, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf0, (2, 8192), dtype=torch.int64, is_leaf=True) # primals_1 + buf1 = reader.storage(None, 32833536, device=device(type='cuda', index=0)) + reader.tensor(buf1, (2004, 4096), is_leaf=True) # primals_2 + buf2 = reader.storage(None, 4194304, device=device(type='cuda', index=0), dtype_hint=torch.complex64) + reader.tensor(buf2, (8192, 64), dtype=torch.complex64, is_leaf=True) # primals_3 + buf3 = reader.storage(None, 2048, device=device(type='cuda', index=0)) + reader.tensor(buf3, (512,), is_leaf=True) # primals_4 + buf4 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf4, (64, 4096), is_leaf=True) # primals_5 + buf5 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf5, (16, 4096), is_leaf=True) # primals_6 + buf6 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf6, (16, 
4096), is_leaf=True) # primals_7 + buf7 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf7, (512, 512), is_leaf=True) # primals_8 + buf8 = reader.storage(None, 2048, device=device(type='cuda', index=0)) + reader.tensor(buf8, (512,), is_leaf=True) # primals_9 + buf9 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf9, (224, 4096), is_leaf=True) # primals_10 + buf10 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf10, (224, 4096), is_leaf=True) # primals_11 + buf11 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf11, (512, 1792), is_leaf=True) # primals_12 + buf12 = reader.storage(None, 2048, device=device(type='cuda', index=0)) + reader.tensor(buf12, (512,), is_leaf=True) # primals_13 + buf13 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf13, (64, 4096), is_leaf=True) # primals_14 + buf14 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf14, (16, 4096), is_leaf=True) # primals_15 + buf15 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf15, (16, 4096), is_leaf=True) # primals_16 + buf16 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf16, (512, 512), is_leaf=True) # primals_17 + buf17 = reader.storage(None, 2048, device=device(type='cuda', index=0)) + reader.tensor(buf17, (512,), is_leaf=True) # primals_18 + buf18 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf18, (224, 4096), is_leaf=True) # primals_19 + buf19 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf19, (224, 4096), is_leaf=True) # primals_20 + buf20 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf20, (512, 1792), is_leaf=True) # primals_21 + buf21 = reader.storage(None, 2048, 
device=device(type='cuda', index=0)) + reader.tensor(buf21, (512,), is_leaf=True) # primals_22 + buf22 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf22, (64, 4096), is_leaf=True) # primals_23 + buf23 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf23, (16, 4096), is_leaf=True) # primals_24 + buf24 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf24, (16, 4096), is_leaf=True) # primals_25 + buf25 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf25, (512, 512), is_leaf=True) # primals_26 + buf26 = reader.storage(None, 2048, device=device(type='cuda', index=0)) + reader.tensor(buf26, (512,), is_leaf=True) # primals_27 + buf27 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf27, (224, 4096), is_leaf=True) # primals_28 + buf28 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf28, (224, 4096), is_leaf=True) # primals_29 + buf29 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf29, (512, 1792), is_leaf=True) # primals_30 + buf30 = reader.storage(None, 2048, device=device(type='cuda', index=0)) + reader.tensor(buf30, (512,), is_leaf=True) # primals_31 + buf31 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf31, (64, 4096), is_leaf=True) # primals_32 + buf32 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf32, (16, 4096), is_leaf=True) # primals_33 + buf33 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf33, (16, 4096), is_leaf=True) # primals_34 + buf34 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf34, (512, 512), is_leaf=True) # primals_35 + buf35 = reader.storage(None, 2048, device=device(type='cuda', index=0)) + reader.tensor(buf35, (512,), 
is_leaf=True) # primals_36 + buf36 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf36, (224, 4096), is_leaf=True) # primals_37 + buf37 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf37, (224, 4096), is_leaf=True) # primals_38 + buf38 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf38, (512, 1792), is_leaf=True) # primals_39 + buf39 = reader.storage(None, 2048, device=device(type='cuda', index=0)) + reader.tensor(buf39, (512,), is_leaf=True) # primals_40 + buf40 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf40, (64, 4096), is_leaf=True) # primals_41 + buf41 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf41, (16, 4096), is_leaf=True) # primals_42 + buf42 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf42, (16, 4096), is_leaf=True) # primals_43 + buf43 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf43, (512, 512), is_leaf=True) # primals_44 + buf44 = reader.storage(None, 2048, device=device(type='cuda', index=0)) + reader.tensor(buf44, (512,), is_leaf=True) # primals_45 + buf45 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf45, (224, 4096), is_leaf=True) # primals_46 + buf46 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf46, (224, 4096), is_leaf=True) # primals_47 + buf47 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf47, (512, 1792), is_leaf=True) # primals_48 + buf48 = reader.storage(None, 2048, device=device(type='cuda', index=0)) + reader.tensor(buf48, (512,), is_leaf=True) # primals_49 + buf49 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf49, (64, 4096), is_leaf=True) # primals_50 + buf50 = reader.storage(None, 262144, 
device=device(type='cuda', index=0)) + reader.tensor(buf50, (16, 4096), is_leaf=True) # primals_51 + buf51 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf51, (16, 4096), is_leaf=True) # primals_52 + buf52 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf52, (512, 512), is_leaf=True) # primals_53 + buf53 = reader.storage(None, 2048, device=device(type='cuda', index=0)) + reader.tensor(buf53, (512,), is_leaf=True) # primals_54 + buf54 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf54, (224, 4096), is_leaf=True) # primals_55 + buf55 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf55, (224, 4096), is_leaf=True) # primals_56 + buf56 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf56, (512, 1792), is_leaf=True) # primals_57 + buf57 = reader.storage(None, 2048, device=device(type='cuda', index=0)) + reader.tensor(buf57, (512,), is_leaf=True) # primals_58 + buf58 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf58, (64, 4096), is_leaf=True) # primals_59 + buf59 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf59, (16, 4096), is_leaf=True) # primals_60 + buf60 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf60, (16, 4096), is_leaf=True) # primals_61 + buf61 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf61, (512, 512), is_leaf=True) # primals_62 + buf62 = reader.storage(None, 2048, device=device(type='cuda', index=0)) + reader.tensor(buf62, (512,), is_leaf=True) # primals_63 + buf63 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf63, (224, 4096), is_leaf=True) # primals_64 + buf64 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf64, (224, 4096), 
is_leaf=True) # primals_65 + buf65 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf65, (512, 1792), is_leaf=True) # primals_66 + buf66 = reader.storage(None, 2048, device=device(type='cuda', index=0)) + reader.tensor(buf66, (512,), is_leaf=True) # primals_67 + buf67 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf67, (64, 4096), is_leaf=True) # primals_68 + buf68 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf68, (16, 4096), is_leaf=True) # primals_69 + buf69 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf69, (16, 4096), is_leaf=True) # primals_70 + buf70 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf70, (512, 512), is_leaf=True) # primals_71 + buf71 = reader.storage(None, 2048, device=device(type='cuda', index=0)) + reader.tensor(buf71, (512,), is_leaf=True) # primals_72 + buf72 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf72, (224, 4096), is_leaf=True) # primals_73 + buf73 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf73, (224, 4096), is_leaf=True) # primals_74 + buf74 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf74, (512, 1792), is_leaf=True) # primals_75 + buf75 = reader.storage(None, 2048, device=device(type='cuda', index=0)) + reader.tensor(buf75, (512,), is_leaf=True) # primals_76 + buf76 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf76, (64, 4096), is_leaf=True) # primals_77 + buf77 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf77, (16, 4096), is_leaf=True) # primals_78 + buf78 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf78, (16, 4096), is_leaf=True) # primals_79 + buf79 = reader.storage(None, 1048576, 
device=device(type='cuda', index=0)) + reader.tensor(buf79, (512, 512), is_leaf=True) # primals_80 + buf80 = reader.storage(None, 2048, device=device(type='cuda', index=0)) + reader.tensor(buf80, (512,), is_leaf=True) # primals_81 + buf81 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf81, (224, 4096), is_leaf=True) # primals_82 + buf82 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf82, (224, 4096), is_leaf=True) # primals_83 + buf83 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf83, (512, 1792), is_leaf=True) # primals_84 + buf84 = reader.storage(None, 2048, device=device(type='cuda', index=0)) + reader.tensor(buf84, (512,), is_leaf=True) # primals_85 + buf85 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf85, (64, 4096), is_leaf=True) # primals_86 + buf86 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf86, (16, 4096), is_leaf=True) # primals_87 + buf87 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf87, (16, 4096), is_leaf=True) # primals_88 + buf88 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf88, (512, 512), is_leaf=True) # primals_89 + buf89 = reader.storage(None, 2048, device=device(type='cuda', index=0)) + reader.tensor(buf89, (512,), is_leaf=True) # primals_90 + buf90 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf90, (224, 4096), is_leaf=True) # primals_91 + buf91 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf91, (224, 4096), is_leaf=True) # primals_92 + buf92 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf92, (512, 1792), is_leaf=True) # primals_93 + buf93 = reader.storage(None, 2048, device=device(type='cuda', index=0)) + reader.tensor(buf93, (512,), 
is_leaf=True) # primals_94 + buf94 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf94, (64, 4096), is_leaf=True) # primals_95 + buf95 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf95, (16, 4096), is_leaf=True) # primals_96 + buf96 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf96, (16, 4096), is_leaf=True) # primals_97 + buf97 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf97, (512, 512), is_leaf=True) # primals_98 + buf98 = reader.storage(None, 2048, device=device(type='cuda', index=0)) + reader.tensor(buf98, (512,), is_leaf=True) # primals_99 + buf99 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf99, (224, 4096), is_leaf=True) # primals_100 + buf100 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf100, (224, 4096), is_leaf=True) # primals_101 + buf101 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf101, (512, 1792), is_leaf=True) # primals_102 + buf102 = reader.storage(None, 2048, device=device(type='cuda', index=0)) + reader.tensor(buf102, (512,), is_leaf=True) # primals_103 + buf103 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf103, (64, 4096), is_leaf=True) # primals_104 + buf104 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf104, (16, 4096), is_leaf=True) # primals_105 + buf105 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf105, (16, 4096), is_leaf=True) # primals_106 + buf106 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf106, (512, 512), is_leaf=True) # primals_107 + buf107 = reader.storage(None, 2048, device=device(type='cuda', index=0)) + reader.tensor(buf107, (512,), is_leaf=True) # primals_108 + buf108 = 
reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf108, (224, 4096), is_leaf=True) # primals_109 + buf109 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf109, (224, 4096), is_leaf=True) # primals_110 + buf110 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf110, (512, 1792), is_leaf=True) # primals_111 + buf111 = reader.storage(None, 2048, device=device(type='cuda', index=0)) + reader.tensor(buf111, (512,), is_leaf=True) # primals_112 + buf112 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf112, (64, 4096), is_leaf=True) # primals_113 + buf113 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf113, (16, 4096), is_leaf=True) # primals_114 + buf114 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf114, (16, 4096), is_leaf=True) # primals_115 + buf115 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf115, (512, 512), is_leaf=True) # primals_116 + buf116 = reader.storage(None, 2048, device=device(type='cuda', index=0)) + reader.tensor(buf116, (512,), is_leaf=True) # primals_117 + buf117 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf117, (224, 4096), is_leaf=True) # primals_118 + buf118 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf118, (224, 4096), is_leaf=True) # primals_119 + buf119 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf119, (512, 1792), is_leaf=True) # primals_120 + buf120 = reader.storage(None, 2048, device=device(type='cuda', index=0)) + reader.tensor(buf120, (512,), is_leaf=True) # primals_121 + buf121 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf121, (64, 4096), is_leaf=True) # primals_122 + buf122 = reader.storage(None, 262144, 
device=device(type='cuda', index=0)) + reader.tensor(buf122, (16, 4096), is_leaf=True) # primals_123 + buf123 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf123, (16, 4096), is_leaf=True) # primals_124 + buf124 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf124, (512, 512), is_leaf=True) # primals_125 + buf125 = reader.storage(None, 2048, device=device(type='cuda', index=0)) + reader.tensor(buf125, (512,), is_leaf=True) # primals_126 + buf126 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf126, (224, 4096), is_leaf=True) # primals_127 + buf127 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf127, (224, 4096), is_leaf=True) # primals_128 + buf128 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf128, (512, 1792), is_leaf=True) # primals_129 + buf129 = reader.storage(None, 2048, device=device(type='cuda', index=0)) + reader.tensor(buf129, (512,), is_leaf=True) # primals_130 + buf130 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf130, (64, 4096), is_leaf=True) # primals_131 + buf131 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf131, (16, 4096), is_leaf=True) # primals_132 + buf132 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf132, (16, 4096), is_leaf=True) # primals_133 + buf133 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf133, (512, 512), is_leaf=True) # primals_134 + buf134 = reader.storage(None, 2048, device=device(type='cuda', index=0)) + reader.tensor(buf134, (512,), is_leaf=True) # primals_135 + buf135 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf135, (224, 4096), is_leaf=True) # primals_136 + buf136 = reader.storage(None, 3670016, device=device(type='cuda', 
index=0)) + reader.tensor(buf136, (224, 4096), is_leaf=True) # primals_137 + buf137 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf137, (512, 1792), is_leaf=True) # primals_138 + buf138 = reader.storage(None, 2048, device=device(type='cuda', index=0)) + reader.tensor(buf138, (512,), is_leaf=True) # primals_139 + buf139 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf139, (64, 4096), is_leaf=True) # primals_140 + buf140 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf140, (16, 4096), is_leaf=True) # primals_141 + buf141 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf141, (16, 4096), is_leaf=True) # primals_142 + buf142 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf142, (512, 512), is_leaf=True) # primals_143 + buf143 = reader.storage(None, 2048, device=device(type='cuda', index=0)) + reader.tensor(buf143, (512,), is_leaf=True) # primals_144 + buf144 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf144, (224, 4096), is_leaf=True) # primals_145 + buf145 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf145, (224, 4096), is_leaf=True) # primals_146 + buf146 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf146, (512, 1792), is_leaf=True) # primals_147 + buf147 = reader.storage(None, 2048, device=device(type='cuda', index=0)) + reader.tensor(buf147, (512,), is_leaf=True) # primals_148 + buf148 = reader.storage(None, 32833536, device=device(type='cuda', index=0)) + reader.tensor(buf148, (2004, 4096), is_leaf=True) # primals_149 + buf149 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf149, (2, 1024, 4096), dtype=torch.bfloat16, is_leaf=True) # wait_tensor_1 + buf150 = reader.storage(None, 
16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf150, (16384, 512), dtype=torch.bfloat16, is_leaf=True) # mm + buf151 = reader.storage(None, 4194304, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf151, (16384, 128), dtype=torch.bfloat16, is_leaf=True) # mm_2 + buf152 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf152, (2, 4, 8192, 128), (4194304, 128, 512, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_80 + buf153 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf153, (2, 4, 8192, 1), is_leaf=True) # getitem_81 + buf154 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf154, (), dtype=torch.int64, is_leaf=True) # getitem_86 + buf155 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf155, (), dtype=torch.int64, is_leaf=True) # getitem_87 + buf156 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf156, (2, 1024, 4096), dtype=torch.bfloat16, is_leaf=True) # reduce_scatter_tensor_1 + buf157 = reader.storage(None, 58720256, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf157, (16384, 1792), dtype=torch.bfloat16, is_leaf=True) # mm_4 + buf158 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf158, (2, 1024, 4096), dtype=torch.bfloat16, is_leaf=True) # add_3 + buf159 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf159, (16384, 512), dtype=torch.bfloat16, is_leaf=True) # mm_7 + buf160 = reader.storage(None, 4194304, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf160, (16384, 128), dtype=torch.bfloat16, is_leaf=True) 
# mm_9 + buf161 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf161, (2, 4, 8192, 128), (4194304, 128, 512, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_121 + buf162 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf162, (2, 4, 8192, 1), is_leaf=True) # getitem_122 + buf163 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf163, (), dtype=torch.int64, is_leaf=True) # getitem_127 + buf164 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf164, (), dtype=torch.int64, is_leaf=True) # getitem_128 + buf165 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf165, (2, 1024, 4096), dtype=torch.bfloat16, is_leaf=True) # reduce_scatter_tensor_3 + buf166 = reader.storage(None, 58720256, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf166, (16384, 1792), dtype=torch.bfloat16, is_leaf=True) # mm_11 + buf167 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf167, (2, 1024, 4096), dtype=torch.bfloat16, is_leaf=True) # add_7 + buf168 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf168, (16384, 512), dtype=torch.bfloat16, is_leaf=True) # mm_14 + buf169 = reader.storage(None, 4194304, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf169, (16384, 128), dtype=torch.bfloat16, is_leaf=True) # mm_16 + buf170 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf170, (2, 4, 8192, 128), (4194304, 128, 512, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_162 + buf171 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + 
reader.tensor(buf171, (2, 4, 8192, 1), is_leaf=True) # getitem_163 + buf172 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf172, (), dtype=torch.int64, is_leaf=True) # getitem_168 + buf173 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf173, (), dtype=torch.int64, is_leaf=True) # getitem_169 + buf174 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf174, (2, 1024, 4096), dtype=torch.bfloat16, is_leaf=True) # reduce_scatter_tensor_5 + buf175 = reader.storage(None, 58720256, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf175, (16384, 1792), dtype=torch.bfloat16, is_leaf=True) # mm_18 + buf176 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf176, (2, 1024, 4096), dtype=torch.bfloat16, is_leaf=True) # add_11 + buf177 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf177, (16384, 512), dtype=torch.bfloat16, is_leaf=True) # mm_21 + buf178 = reader.storage(None, 4194304, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf178, (16384, 128), dtype=torch.bfloat16, is_leaf=True) # mm_23 + buf179 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf179, (2, 4, 8192, 128), (4194304, 128, 512, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_203 + buf180 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf180, (2, 4, 8192, 1), is_leaf=True) # getitem_204 + buf181 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf181, (), dtype=torch.int64, is_leaf=True) # getitem_209 + buf182 = reader.storage(None, 8, device=device(type='cuda', index=0), 
dtype_hint=torch.int64) + reader.tensor(buf182, (), dtype=torch.int64, is_leaf=True) # getitem_210 + buf183 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf183, (2, 1024, 4096), dtype=torch.bfloat16, is_leaf=True) # reduce_scatter_tensor_7 + buf184 = reader.storage(None, 58720256, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf184, (16384, 1792), dtype=torch.bfloat16, is_leaf=True) # mm_25 + buf185 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf185, (2, 1024, 4096), dtype=torch.bfloat16, is_leaf=True) # add_15 + buf186 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf186, (16384, 512), dtype=torch.bfloat16, is_leaf=True) # mm_28 + buf187 = reader.storage(None, 4194304, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf187, (16384, 128), dtype=torch.bfloat16, is_leaf=True) # mm_30 + buf188 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf188, (2, 4, 8192, 128), (4194304, 128, 512, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_244 + buf189 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf189, (2, 4, 8192, 1), is_leaf=True) # getitem_245 + buf190 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf190, (), dtype=torch.int64, is_leaf=True) # getitem_250 + buf191 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf191, (), dtype=torch.int64, is_leaf=True) # getitem_251 + buf192 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf192, (2, 1024, 4096), dtype=torch.bfloat16, is_leaf=True) # reduce_scatter_tensor_9 + buf193 = 
reader.storage(None, 58720256, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf193, (16384, 1792), dtype=torch.bfloat16, is_leaf=True) # mm_32 + buf194 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf194, (2, 1024, 4096), dtype=torch.bfloat16, is_leaf=True) # add_19 + buf195 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf195, (16384, 512), dtype=torch.bfloat16, is_leaf=True) # mm_35 + buf196 = reader.storage(None, 4194304, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf196, (16384, 128), dtype=torch.bfloat16, is_leaf=True) # mm_37 + buf197 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf197, (2, 4, 8192, 128), (4194304, 128, 512, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_285 + buf198 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf198, (2, 4, 8192, 1), is_leaf=True) # getitem_286 + buf199 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf199, (), dtype=torch.int64, is_leaf=True) # getitem_291 + buf200 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf200, (), dtype=torch.int64, is_leaf=True) # getitem_292 + buf201 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf201, (2, 1024, 4096), dtype=torch.bfloat16, is_leaf=True) # reduce_scatter_tensor_11 + buf202 = reader.storage(None, 58720256, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf202, (16384, 1792), dtype=torch.bfloat16, is_leaf=True) # mm_39 + buf203 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf203, (2, 1024, 
4096), dtype=torch.bfloat16, is_leaf=True) # add_23 + buf204 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf204, (16384, 512), dtype=torch.bfloat16, is_leaf=True) # mm_42 + buf205 = reader.storage(None, 4194304, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf205, (16384, 128), dtype=torch.bfloat16, is_leaf=True) # mm_44 + buf206 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf206, (2, 4, 8192, 128), (4194304, 128, 512, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_326 + buf207 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf207, (2, 4, 8192, 1), is_leaf=True) # getitem_327 + buf208 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf208, (), dtype=torch.int64, is_leaf=True) # getitem_332 + buf209 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf209, (), dtype=torch.int64, is_leaf=True) # getitem_333 + buf210 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf210, (2, 1024, 4096), dtype=torch.bfloat16, is_leaf=True) # reduce_scatter_tensor_13 + buf211 = reader.storage(None, 58720256, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf211, (16384, 1792), dtype=torch.bfloat16, is_leaf=True) # mm_46 + buf212 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf212, (2, 1024, 4096), dtype=torch.bfloat16, is_leaf=True) # add_27 + buf213 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf213, (16384, 512), dtype=torch.bfloat16, is_leaf=True) # mm_49 + buf214 = reader.storage(None, 4194304, device=device(type='cuda', index=0), 
dtype_hint=torch.bfloat16) + reader.tensor(buf214, (16384, 128), dtype=torch.bfloat16, is_leaf=True) # mm_51 + buf215 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf215, (2, 4, 8192, 128), (4194304, 128, 512, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_367 + buf216 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf216, (2, 4, 8192, 1), is_leaf=True) # getitem_368 + buf217 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf217, (), dtype=torch.int64, is_leaf=True) # getitem_373 + buf218 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf218, (), dtype=torch.int64, is_leaf=True) # getitem_374 + buf219 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf219, (2, 1024, 4096), dtype=torch.bfloat16, is_leaf=True) # reduce_scatter_tensor_15 + buf220 = reader.storage(None, 58720256, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf220, (16384, 1792), dtype=torch.bfloat16, is_leaf=True) # mm_53 + buf221 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf221, (2, 1024, 4096), dtype=torch.bfloat16, is_leaf=True) # add_31 + buf222 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf222, (16384, 512), dtype=torch.bfloat16, is_leaf=True) # mm_56 + buf223 = reader.storage(None, 4194304, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf223, (16384, 128), dtype=torch.bfloat16, is_leaf=True) # mm_58 + buf224 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf224, (2, 4, 8192, 128), (4194304, 128, 512, 1), dtype=torch.bfloat16, is_leaf=True) # 
getitem_408 + buf225 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf225, (2, 4, 8192, 1), is_leaf=True) # getitem_409 + buf226 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf226, (), dtype=torch.int64, is_leaf=True) # getitem_414 + buf227 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf227, (), dtype=torch.int64, is_leaf=True) # getitem_415 + buf228 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf228, (2, 1024, 4096), dtype=torch.bfloat16, is_leaf=True) # reduce_scatter_tensor_17 + buf229 = reader.storage(None, 58720256, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf229, (16384, 1792), dtype=torch.bfloat16, is_leaf=True) # mm_60 + buf230 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf230, (2, 1024, 4096), dtype=torch.bfloat16, is_leaf=True) # add_35 + buf231 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf231, (16384, 512), dtype=torch.bfloat16, is_leaf=True) # mm_63 + buf232 = reader.storage(None, 4194304, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf232, (16384, 128), dtype=torch.bfloat16, is_leaf=True) # mm_65 + buf233 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf233, (2, 4, 8192, 128), (4194304, 128, 512, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_449 + buf234 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf234, (2, 4, 8192, 1), is_leaf=True) # getitem_450 + buf235 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf235, (), dtype=torch.int64, is_leaf=True) # 
getitem_455 + buf236 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf236, (), dtype=torch.int64, is_leaf=True) # getitem_456 + buf237 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf237, (2, 1024, 4096), dtype=torch.bfloat16, is_leaf=True) # reduce_scatter_tensor_19 + buf238 = reader.storage(None, 58720256, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf238, (16384, 1792), dtype=torch.bfloat16, is_leaf=True) # mm_67 + buf239 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf239, (2, 1024, 4096), dtype=torch.bfloat16, is_leaf=True) # add_39 + buf240 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf240, (16384, 512), dtype=torch.bfloat16, is_leaf=True) # mm_70 + buf241 = reader.storage(None, 4194304, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf241, (16384, 128), dtype=torch.bfloat16, is_leaf=True) # mm_72 + buf242 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf242, (2, 4, 8192, 128), (4194304, 128, 512, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_490 + buf243 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf243, (2, 4, 8192, 1), is_leaf=True) # getitem_491 + buf244 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf244, (), dtype=torch.int64, is_leaf=True) # getitem_496 + buf245 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf245, (), dtype=torch.int64, is_leaf=True) # getitem_497 + buf246 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf246, (2, 1024, 
4096), dtype=torch.bfloat16, is_leaf=True) # reduce_scatter_tensor_21 + buf247 = reader.storage(None, 58720256, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf247, (16384, 1792), dtype=torch.bfloat16, is_leaf=True) # mm_74 + buf248 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf248, (2, 1024, 4096), dtype=torch.bfloat16, is_leaf=True) # add_43 + buf249 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf249, (16384, 512), dtype=torch.bfloat16, is_leaf=True) # mm_77 + buf250 = reader.storage(None, 4194304, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf250, (16384, 128), dtype=torch.bfloat16, is_leaf=True) # mm_79 + buf251 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf251, (2, 4, 8192, 128), (4194304, 128, 512, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_531 + buf252 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf252, (2, 4, 8192, 1), is_leaf=True) # getitem_532 + buf253 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf253, (), dtype=torch.int64, is_leaf=True) # getitem_537 + buf254 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf254, (), dtype=torch.int64, is_leaf=True) # getitem_538 + buf255 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf255, (2, 1024, 4096), dtype=torch.bfloat16, is_leaf=True) # reduce_scatter_tensor_23 + buf256 = reader.storage(None, 58720256, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf256, (16384, 1792), dtype=torch.bfloat16, is_leaf=True) # mm_81 + buf257 = reader.storage(None, 16777216, 
device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf257, (2, 1024, 4096), dtype=torch.bfloat16, is_leaf=True) # add_47 + buf258 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf258, (16384, 512), dtype=torch.bfloat16, is_leaf=True) # mm_84 + buf259 = reader.storage(None, 4194304, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf259, (16384, 128), dtype=torch.bfloat16, is_leaf=True) # mm_86 + buf260 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf260, (2, 4, 8192, 128), (4194304, 128, 512, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_572 + buf261 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf261, (2, 4, 8192, 1), is_leaf=True) # getitem_573 + buf262 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf262, (), dtype=torch.int64, is_leaf=True) # getitem_578 + buf263 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf263, (), dtype=torch.int64, is_leaf=True) # getitem_579 + buf264 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf264, (2, 1024, 4096), dtype=torch.bfloat16, is_leaf=True) # reduce_scatter_tensor_25 + buf265 = reader.storage(None, 58720256, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf265, (16384, 1792), dtype=torch.bfloat16, is_leaf=True) # mm_88 + buf266 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf266, (2, 1024, 4096), dtype=torch.bfloat16, is_leaf=True) # add_51 + buf267 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf267, (16384, 512), dtype=torch.bfloat16, 
is_leaf=True) # mm_91 + buf268 = reader.storage(None, 4194304, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf268, (16384, 128), dtype=torch.bfloat16, is_leaf=True) # mm_93 + buf269 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf269, (2, 4, 8192, 128), (4194304, 128, 512, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_613 + buf270 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf270, (2, 4, 8192, 1), is_leaf=True) # getitem_614 + buf271 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf271, (), dtype=torch.int64, is_leaf=True) # getitem_619 + buf272 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf272, (), dtype=torch.int64, is_leaf=True) # getitem_620 + buf273 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf273, (2, 1024, 4096), dtype=torch.bfloat16, is_leaf=True) # reduce_scatter_tensor_27 + buf274 = reader.storage(None, 58720256, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf274, (16384, 1792), dtype=torch.bfloat16, is_leaf=True) # mm_95 + buf275 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf275, (2, 1024, 4096), dtype=torch.bfloat16, is_leaf=True) # add_55 + buf276 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf276, (16384, 512), dtype=torch.bfloat16, is_leaf=True) # mm_98 + buf277 = reader.storage(None, 4194304, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf277, (16384, 128), dtype=torch.bfloat16, is_leaf=True) # mm_100 + buf278 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + 
reader.tensor(buf278, (2, 4, 8192, 128), (4194304, 128, 512, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_654 + buf279 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf279, (2, 4, 8192, 1), is_leaf=True) # getitem_655 + buf280 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf280, (), dtype=torch.int64, is_leaf=True) # getitem_660 + buf281 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf281, (), dtype=torch.int64, is_leaf=True) # getitem_661 + buf282 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf282, (2, 1024, 4096), dtype=torch.bfloat16, is_leaf=True) # reduce_scatter_tensor_29 + buf283 = reader.storage(None, 58720256, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf283, (16384, 1792), dtype=torch.bfloat16, is_leaf=True) # mm_102 + buf284 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf284, (2, 1024, 4096), dtype=torch.bfloat16, is_leaf=True) # add_59 + buf285 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf285, (16384, 512), dtype=torch.bfloat16, is_leaf=True) # mm_105 + buf286 = reader.storage(None, 4194304, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf286, (16384, 128), dtype=torch.bfloat16, is_leaf=True) # mm_107 + buf287 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf287, (2, 4, 8192, 128), (4194304, 128, 512, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_695 + buf288 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf288, (2, 4, 8192, 1), is_leaf=True) # getitem_696 + buf289 = reader.storage(None, 8, 
device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf289, (), dtype=torch.int64, is_leaf=True) # getitem_701 + buf290 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf290, (), dtype=torch.int64, is_leaf=True) # getitem_702 + buf291 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf291, (2, 1024, 4096), dtype=torch.bfloat16, is_leaf=True) # reduce_scatter_tensor_31 + buf292 = reader.storage(None, 58720256, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf292, (16384, 1792), dtype=torch.bfloat16, is_leaf=True) # mm_109 + buf293 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf293, (2, 1024, 4096), dtype=torch.bfloat16, is_leaf=True) # reduce_scatter_tensor_32 + buf294 = reader.storage(None, 8192, device=device(type='cuda', index=0)) + reader.tensor(buf294, (2, 1024, 1), is_leaf=True) # rsqrt_32 + buf295 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf295, (16384, 4096), dtype=torch.bfloat16, is_leaf=True) # view_1167 + buf296 = reader.storage(None, 525336576, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf296, (2, 8192, 16032), dtype=torch.bfloat16, is_leaf=True) # tangents_1 + +load_args._version = 0 + +def get_mesh_sizes(): + return 8, 8 + +def get_colls_estimations_file(): + return "colls8_8.table" + +def get_pg_names(): + return "0", "1" diff --git a/autoparallel/tools/overlap_simulator/repro_llama3_8b_fw_256_1d_32layers.py b/autoparallel/tools/overlap_simulator/repro_llama3_8b_fw_256_1d_32layers.py new file mode 100644 index 00000000..21bc2d7b --- /dev/null +++ b/autoparallel/tools/overlap_simulator/repro_llama3_8b_fw_256_1d_32layers.py @@ -0,0 +1,4153 @@ +import torch +from torch.nn import * +from torch import 
tensor, device + + +class Repro(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + + def forward(self, primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8, primals_9, primals_10, primals_11, primals_12, primals_13, primals_14, primals_15, primals_16, primals_17, primals_18, primals_19, primals_20, primals_21, primals_22, primals_23, primals_24, primals_25, primals_26, primals_27, primals_28, primals_29, primals_30, primals_31, primals_32, primals_33, primals_34, primals_35, primals_36, primals_37, primals_38, primals_39, primals_40, primals_41, primals_42, primals_43, primals_44, primals_45, primals_46, primals_47, primals_48, primals_49, primals_50, primals_51, primals_52, primals_53, primals_54, primals_55, primals_56, primals_57, primals_58, primals_59, primals_60, primals_61, primals_62, primals_63, primals_64, primals_65, primals_66, primals_67, primals_68, primals_69, primals_70, primals_71, primals_72, primals_73, primals_74, primals_75, primals_76, primals_77, primals_78, primals_79, primals_80, primals_81, primals_82, primals_83, primals_84, primals_85, primals_86, primals_87, primals_88, primals_89, primals_90, primals_91, primals_92, primals_93, primals_94, primals_95, primals_96, primals_97, primals_98, primals_99, primals_100, primals_101, primals_102, primals_103, primals_104, primals_105, primals_106, primals_107, primals_108, primals_109, primals_110, primals_111, primals_112, primals_113, primals_114, primals_115, primals_116, primals_117, primals_118, primals_119, primals_120, primals_121, primals_122, primals_123, primals_124, primals_125, primals_126, primals_127, primals_128, primals_129, primals_130, primals_131, primals_132, primals_133, primals_134, primals_135, primals_136, primals_137, primals_138, primals_139, primals_140, primals_141, primals_142, primals_143, primals_144, primals_145, primals_146, primals_147, primals_148, primals_149, primals_150, primals_151, primals_152, 
primals_153, primals_154, primals_155, primals_156, primals_157, primals_158, primals_159, primals_160, primals_161, primals_162, primals_163, primals_164, primals_165, primals_166, primals_167, primals_168, primals_169, primals_170, primals_171, primals_172, primals_173, primals_174, primals_175, primals_176, primals_177, primals_178, primals_179, primals_180, primals_181, primals_182, primals_183, primals_184, primals_185, primals_186, primals_187, primals_188, primals_189, primals_190, primals_191, primals_192, primals_193, primals_194, primals_195, primals_196, primals_197, primals_198, primals_199, primals_200, primals_201, primals_202, primals_203, primals_204, primals_205, primals_206, primals_207, primals_208, primals_209, primals_210, primals_211, primals_212, primals_213, primals_214, primals_215, primals_216, primals_217, primals_218, primals_219, primals_220, primals_221, primals_222, primals_223, primals_224, primals_225, primals_226, primals_227, primals_228, primals_229, primals_230, primals_231, primals_232, primals_233, primals_234, primals_235, primals_236, primals_237, primals_238, primals_239, primals_240, primals_241, primals_242, primals_243, primals_244, primals_245, primals_246, primals_247, primals_248, primals_249, primals_250, primals_251, primals_252, primals_253, primals_254, primals_255, primals_256, primals_257, primals_258, primals_259, primals_260, primals_261, primals_262, primals_263, primals_264, primals_265, primals_266, primals_267, primals_268, primals_269, primals_270, primals_271, primals_272, primals_273, primals_274, primals_275, primals_276, primals_277, primals_278, primals_279, primals_280, primals_281, primals_282, primals_283, primals_284, primals_285, primals_286, primals_287, primals_288, primals_289, primals_290, primals_291, primals_292, primals_293): + convert_element_type = torch.ops.prims.convert_element_type.default(primals_1, torch.bfloat16) + all_gather_into_tensor = 
torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type, 256, '0'); convert_element_type = None + wait_tensor = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor); all_gather_into_tensor = None + embedding = torch.ops.aten.embedding.default(wait_tensor, primals_2); wait_tensor = None + convert_element_type_1 = torch.ops.prims.convert_element_type.default(primals_4, torch.bfloat16) + all_gather_into_tensor_1 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1, 256, '0'); convert_element_type_1 = None + wait_tensor_1 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_1); all_gather_into_tensor_1 = None + convert_element_type_2 = torch.ops.prims.convert_element_type.default(embedding, torch.float32) + pow_1 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_2, 2) + mean = torch.ops.aten.mean.dim(pow_1, [2], True); pow_1 = None + add = torch.ops.aten.add.Scalar(mean, 1e-05); mean = None + rsqrt = torch.ops.aten.rsqrt.default(add); add = None + mul = torch.ops.aten.mul.Tensor(convert_element_type_2, rsqrt); convert_element_type_2 = rsqrt = None + mul_1 = torch.ops.aten.mul.Tensor(mul, wait_tensor_1); mul = wait_tensor_1 = None + convert_element_type_3 = torch.ops.prims.convert_element_type.default(mul_1, torch.bfloat16); mul_1 = None + convert_element_type_4 = torch.ops.prims.convert_element_type.default(primals_5, torch.bfloat16) + all_gather_into_tensor_2 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_4, 256, '0'); convert_element_type_4 = None + wait_tensor_2 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_2); all_gather_into_tensor_2 = None + permute = torch.ops.aten.permute.default(wait_tensor_2, [1, 0]); wait_tensor_2 = None + view_3 = torch.ops.aten.view.default(convert_element_type_3, [16384, 4096]); convert_element_type_3 = None + mm = torch.ops.aten.mm.default(view_3, permute); permute = None + view_4 = 
torch.ops.aten.view.default(mm, [2, 8192, 4096]) + convert_element_type_7 = torch.ops.prims.convert_element_type.default(primals_6, torch.bfloat16) + all_gather_into_tensor_3 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_7, 256, '0'); convert_element_type_7 = None + wait_tensor_3 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_3); all_gather_into_tensor_3 = None + permute_1 = torch.ops.aten.permute.default(wait_tensor_3, [1, 0]); wait_tensor_3 = None + mm_1 = torch.ops.aten.mm.default(view_3, permute_1); permute_1 = None + view_7 = torch.ops.aten.view.default(mm_1, [2, 8192, 1024]); mm_1 = None + convert_element_type_10 = torch.ops.prims.convert_element_type.default(primals_7, torch.bfloat16) + all_gather_into_tensor_4 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_10, 256, '0'); convert_element_type_10 = None + wait_tensor_4 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_4); all_gather_into_tensor_4 = None + permute_2 = torch.ops.aten.permute.default(wait_tensor_4, [1, 0]); wait_tensor_4 = None + mm_2 = torch.ops.aten.mm.default(view_3, permute_2); view_3 = permute_2 = None + view_10 = torch.ops.aten.view.default(mm_2, [2, 8192, 1024]) + view_11 = torch.ops.aten.view.default(view_4, [2, 8192, -1, 128]); view_4 = None + view_12 = torch.ops.aten.view.default(view_7, [2, 8192, -1, 128]); view_7 = None + view_13 = torch.ops.aten.view.default(view_10, [2, 8192, -1, 128]); view_10 = None + convert_element_type_13 = torch.ops.prims.convert_element_type.default(view_11, torch.float32); view_11 = None + view_14 = torch.ops.aten.view.default(convert_element_type_13, [2, 8192, 32, -1, 2]); convert_element_type_13 = None + view_as_complex = torch.ops.aten.view_as_complex.default(view_14); view_14 = None + convert_element_type_14 = torch.ops.prims.convert_element_type.default(view_12, torch.float32); view_12 = None + view_15 = 
torch.ops.aten.view.default(convert_element_type_14, [2, 8192, 8, -1, 2]); convert_element_type_14 = None + view_as_complex_1 = torch.ops.aten.view_as_complex.default(view_15); view_15 = None + view_16 = torch.ops.aten.view.default(primals_3, [1, 8192, 1, 64]) + mul_2 = torch.ops.aten.mul.Tensor(view_as_complex, view_16); view_as_complex = None + view_as_real = torch.ops.aten.view_as_real.default(mul_2); mul_2 = None + view_17 = torch.ops.aten.view.default(view_as_real, [2, 8192, 32, 128]); view_as_real = None + mul_3 = torch.ops.aten.mul.Tensor(view_as_complex_1, view_16); view_as_complex_1 = None + view_as_real_1 = torch.ops.aten.view_as_real.default(mul_3); mul_3 = None + view_18 = torch.ops.aten.view.default(view_as_real_1, [2, 8192, 8, 128]); view_as_real_1 = None + convert_element_type_15 = torch.ops.prims.convert_element_type.default(view_17, torch.bfloat16); view_17 = None + convert_element_type_16 = torch.ops.prims.convert_element_type.default(view_18, torch.bfloat16); view_18 = None + unsqueeze = torch.ops.aten.unsqueeze.default(convert_element_type_16, 3); convert_element_type_16 = None + expand = torch.ops.aten.expand.default(unsqueeze, [2, 8192, 8, 4, 128]); unsqueeze = None + clone = torch.ops.aten.clone.default(expand, memory_format = torch.contiguous_format); expand = None + view_19 = torch.ops.aten.view.default(clone, [2, 8192, 32, 128]); clone = None + unsqueeze_1 = torch.ops.aten.unsqueeze.default(view_13, 3); view_13 = None + expand_1 = torch.ops.aten.expand.default(unsqueeze_1, [2, 8192, 8, 4, 128]); unsqueeze_1 = None + clone_1 = torch.ops.aten.clone.default(expand_1, memory_format = torch.contiguous_format); expand_1 = None + view_20 = torch.ops.aten.view.default(clone_1, [2, 8192, 32, 128]); clone_1 = None + permute_3 = torch.ops.aten.permute.default(convert_element_type_15, [0, 2, 1, 3]); convert_element_type_15 = None + permute_4 = torch.ops.aten.permute.default(view_19, [0, 2, 1, 3]); view_19 = None + permute_5 = 
torch.ops.aten.permute.default(view_20, [0, 2, 1, 3]); view_20 = None + _scaled_dot_product_cudnn_attention = torch.ops.aten._scaled_dot_product_cudnn_attention.default(permute_3, permute_4, permute_5, None, True, 0.0, True); permute_3 = permute_4 = permute_5 = None + getitem = _scaled_dot_product_cudnn_attention[0] + getitem_1 = _scaled_dot_product_cudnn_attention[1] + getitem_6 = _scaled_dot_product_cudnn_attention[6] + getitem_7 = _scaled_dot_product_cudnn_attention[7]; _scaled_dot_product_cudnn_attention = None + permute_6 = torch.ops.aten.permute.default(getitem, [0, 2, 1, 3]) + view_21 = torch.ops.aten.view.default(permute_6, [2, 8192, -1]); permute_6 = None + convert_element_type_17 = torch.ops.prims.convert_element_type.default(primals_8, torch.bfloat16) + all_gather_into_tensor_5 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_17, 256, '0'); convert_element_type_17 = None + wait_tensor_5 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_5); all_gather_into_tensor_5 = None + permute_7 = torch.ops.aten.permute.default(wait_tensor_5, [1, 0]); wait_tensor_5 = None + view_23 = torch.ops.aten.view.default(view_21, [16384, 4096]); view_21 = None + mm_3 = torch.ops.aten.mm.default(view_23, permute_7); view_23 = permute_7 = None + view_24 = torch.ops.aten.view.default(mm_3, [2, 8192, 4096]); mm_3 = None + add_1 = torch.ops.aten.add.Tensor(embedding, view_24); view_24 = None + convert_element_type_20 = torch.ops.prims.convert_element_type.default(primals_9, torch.bfloat16) + all_gather_into_tensor_6 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_20, 256, '0'); convert_element_type_20 = None + wait_tensor_6 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_6); all_gather_into_tensor_6 = None + convert_element_type_21 = torch.ops.prims.convert_element_type.default(add_1, torch.float32) + pow_2 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_21, 2) + 
mean_1 = torch.ops.aten.mean.dim(pow_2, [2], True); pow_2 = None + add_2 = torch.ops.aten.add.Scalar(mean_1, 1e-05); mean_1 = None + rsqrt_1 = torch.ops.aten.rsqrt.default(add_2); add_2 = None + mul_4 = torch.ops.aten.mul.Tensor(convert_element_type_21, rsqrt_1); convert_element_type_21 = rsqrt_1 = None + mul_5 = torch.ops.aten.mul.Tensor(mul_4, wait_tensor_6); mul_4 = wait_tensor_6 = None + convert_element_type_22 = torch.ops.prims.convert_element_type.default(mul_5, torch.bfloat16); mul_5 = None + convert_element_type_23 = torch.ops.prims.convert_element_type.default(primals_10, torch.bfloat16) + all_gather_into_tensor_7 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_23, 256, '0'); convert_element_type_23 = None + wait_tensor_7 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_7); all_gather_into_tensor_7 = None + permute_8 = torch.ops.aten.permute.default(wait_tensor_7, [1, 0]); wait_tensor_7 = None + view_27 = torch.ops.aten.view.default(convert_element_type_22, [16384, 4096]); convert_element_type_22 = None + mm_4 = torch.ops.aten.mm.default(view_27, permute_8); permute_8 = None + view_28 = torch.ops.aten.view.default(mm_4, [2, 8192, 14336]) + convert_element_type_26 = torch.ops.prims.convert_element_type.default(view_28, torch.float32); view_28 = None + sigmoid = torch.ops.aten.sigmoid.default(convert_element_type_26) + mul_6 = torch.ops.aten.mul.Tensor(convert_element_type_26, sigmoid); convert_element_type_26 = sigmoid = None + convert_element_type_27 = torch.ops.prims.convert_element_type.default(mul_6, torch.bfloat16); mul_6 = None + convert_element_type_28 = torch.ops.prims.convert_element_type.default(primals_11, torch.bfloat16) + all_gather_into_tensor_8 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_28, 256, '0'); convert_element_type_28 = None + wait_tensor_8 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_8); all_gather_into_tensor_8 = 
None + permute_9 = torch.ops.aten.permute.default(wait_tensor_8, [1, 0]); wait_tensor_8 = None + mm_5 = torch.ops.aten.mm.default(view_27, permute_9); view_27 = permute_9 = None + view_31 = torch.ops.aten.view.default(mm_5, [2, 8192, 14336]); mm_5 = None + mul_7 = torch.ops.aten.mul.Tensor(convert_element_type_27, view_31); convert_element_type_27 = view_31 = None + convert_element_type_31 = torch.ops.prims.convert_element_type.default(primals_12, torch.bfloat16) + all_gather_into_tensor_9 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_31, 256, '0'); convert_element_type_31 = None + wait_tensor_9 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_9); all_gather_into_tensor_9 = None + permute_10 = torch.ops.aten.permute.default(wait_tensor_9, [1, 0]); wait_tensor_9 = None + view_33 = torch.ops.aten.view.default(mul_7, [16384, 14336]); mul_7 = None + mm_6 = torch.ops.aten.mm.default(view_33, permute_10); view_33 = permute_10 = None + view_34 = torch.ops.aten.view.default(mm_6, [2, 8192, 4096]); mm_6 = None + add_3 = torch.ops.aten.add.Tensor(add_1, view_34); add_1 = view_34 = None + convert_element_type_34 = torch.ops.prims.convert_element_type.default(primals_13, torch.bfloat16) + all_gather_into_tensor_10 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_34, 256, '0'); convert_element_type_34 = None + wait_tensor_10 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_10); all_gather_into_tensor_10 = None + convert_element_type_35 = torch.ops.prims.convert_element_type.default(add_3, torch.float32) + pow_3 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_35, 2) + mean_2 = torch.ops.aten.mean.dim(pow_3, [2], True); pow_3 = None + add_4 = torch.ops.aten.add.Scalar(mean_2, 1e-05); mean_2 = None + rsqrt_2 = torch.ops.aten.rsqrt.default(add_4); add_4 = None + mul_8 = torch.ops.aten.mul.Tensor(convert_element_type_35, rsqrt_2); convert_element_type_35 = rsqrt_2 
= None + mul_9 = torch.ops.aten.mul.Tensor(mul_8, wait_tensor_10); mul_8 = wait_tensor_10 = None + convert_element_type_36 = torch.ops.prims.convert_element_type.default(mul_9, torch.bfloat16); mul_9 = None + convert_element_type_37 = torch.ops.prims.convert_element_type.default(primals_14, torch.bfloat16) + all_gather_into_tensor_11 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_37, 256, '0'); convert_element_type_37 = None + wait_tensor_11 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_11); all_gather_into_tensor_11 = None + permute_11 = torch.ops.aten.permute.default(wait_tensor_11, [1, 0]); wait_tensor_11 = None + view_37 = torch.ops.aten.view.default(convert_element_type_36, [16384, 4096]); convert_element_type_36 = None + mm_7 = torch.ops.aten.mm.default(view_37, permute_11); permute_11 = None + view_38 = torch.ops.aten.view.default(mm_7, [2, 8192, 4096]) + convert_element_type_40 = torch.ops.prims.convert_element_type.default(primals_15, torch.bfloat16) + all_gather_into_tensor_12 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_40, 256, '0'); convert_element_type_40 = None + wait_tensor_12 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_12); all_gather_into_tensor_12 = None + permute_12 = torch.ops.aten.permute.default(wait_tensor_12, [1, 0]); wait_tensor_12 = None + mm_8 = torch.ops.aten.mm.default(view_37, permute_12); permute_12 = None + view_41 = torch.ops.aten.view.default(mm_8, [2, 8192, 1024]); mm_8 = None + convert_element_type_43 = torch.ops.prims.convert_element_type.default(primals_16, torch.bfloat16) + all_gather_into_tensor_13 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_43, 256, '0'); convert_element_type_43 = None + wait_tensor_13 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_13); all_gather_into_tensor_13 = None + permute_13 = 
torch.ops.aten.permute.default(wait_tensor_13, [1, 0]); wait_tensor_13 = None + mm_9 = torch.ops.aten.mm.default(view_37, permute_13); view_37 = permute_13 = None + view_44 = torch.ops.aten.view.default(mm_9, [2, 8192, 1024]) + view_45 = torch.ops.aten.view.default(view_38, [2, 8192, -1, 128]); view_38 = None + view_46 = torch.ops.aten.view.default(view_41, [2, 8192, -1, 128]); view_41 = None + view_47 = torch.ops.aten.view.default(view_44, [2, 8192, -1, 128]); view_44 = None + convert_element_type_46 = torch.ops.prims.convert_element_type.default(view_45, torch.float32); view_45 = None + view_48 = torch.ops.aten.view.default(convert_element_type_46, [2, 8192, 32, -1, 2]); convert_element_type_46 = None + view_as_complex_2 = torch.ops.aten.view_as_complex.default(view_48); view_48 = None + convert_element_type_47 = torch.ops.prims.convert_element_type.default(view_46, torch.float32); view_46 = None + view_49 = torch.ops.aten.view.default(convert_element_type_47, [2, 8192, 8, -1, 2]); convert_element_type_47 = None + view_as_complex_3 = torch.ops.aten.view_as_complex.default(view_49); view_49 = None + mul_10 = torch.ops.aten.mul.Tensor(view_as_complex_2, view_16); view_as_complex_2 = None + view_as_real_2 = torch.ops.aten.view_as_real.default(mul_10); mul_10 = None + view_51 = torch.ops.aten.view.default(view_as_real_2, [2, 8192, 32, 128]); view_as_real_2 = None + mul_11 = torch.ops.aten.mul.Tensor(view_as_complex_3, view_16); view_as_complex_3 = None + view_as_real_3 = torch.ops.aten.view_as_real.default(mul_11); mul_11 = None + view_52 = torch.ops.aten.view.default(view_as_real_3, [2, 8192, 8, 128]); view_as_real_3 = None + convert_element_type_48 = torch.ops.prims.convert_element_type.default(view_51, torch.bfloat16); view_51 = None + convert_element_type_49 = torch.ops.prims.convert_element_type.default(view_52, torch.bfloat16); view_52 = None + unsqueeze_2 = torch.ops.aten.unsqueeze.default(convert_element_type_49, 3); convert_element_type_49 = None + expand_2 
= torch.ops.aten.expand.default(unsqueeze_2, [2, 8192, 8, 4, 128]); unsqueeze_2 = None + clone_2 = torch.ops.aten.clone.default(expand_2, memory_format = torch.contiguous_format); expand_2 = None + view_53 = torch.ops.aten.view.default(clone_2, [2, 8192, 32, 128]); clone_2 = None + unsqueeze_3 = torch.ops.aten.unsqueeze.default(view_47, 3); view_47 = None + expand_3 = torch.ops.aten.expand.default(unsqueeze_3, [2, 8192, 8, 4, 128]); unsqueeze_3 = None + clone_3 = torch.ops.aten.clone.default(expand_3, memory_format = torch.contiguous_format); expand_3 = None + view_54 = torch.ops.aten.view.default(clone_3, [2, 8192, 32, 128]); clone_3 = None + permute_14 = torch.ops.aten.permute.default(convert_element_type_48, [0, 2, 1, 3]); convert_element_type_48 = None + permute_15 = torch.ops.aten.permute.default(view_53, [0, 2, 1, 3]); view_53 = None + permute_16 = torch.ops.aten.permute.default(view_54, [0, 2, 1, 3]); view_54 = None + _scaled_dot_product_cudnn_attention_1 = torch.ops.aten._scaled_dot_product_cudnn_attention.default(permute_14, permute_15, permute_16, None, True, 0.0, True); permute_14 = permute_15 = permute_16 = None + getitem_9 = _scaled_dot_product_cudnn_attention_1[0] + getitem_10 = _scaled_dot_product_cudnn_attention_1[1] + getitem_15 = _scaled_dot_product_cudnn_attention_1[6] + getitem_16 = _scaled_dot_product_cudnn_attention_1[7]; _scaled_dot_product_cudnn_attention_1 = None + permute_17 = torch.ops.aten.permute.default(getitem_9, [0, 2, 1, 3]) + view_55 = torch.ops.aten.view.default(permute_17, [2, 8192, -1]); permute_17 = None + convert_element_type_50 = torch.ops.prims.convert_element_type.default(primals_17, torch.bfloat16) + all_gather_into_tensor_14 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_50, 256, '0'); convert_element_type_50 = None + wait_tensor_14 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_14); all_gather_into_tensor_14 = None + permute_18 = 
torch.ops.aten.permute.default(wait_tensor_14, [1, 0]); wait_tensor_14 = None + view_57 = torch.ops.aten.view.default(view_55, [16384, 4096]); view_55 = None + mm_10 = torch.ops.aten.mm.default(view_57, permute_18); view_57 = permute_18 = None + view_58 = torch.ops.aten.view.default(mm_10, [2, 8192, 4096]); mm_10 = None + add_5 = torch.ops.aten.add.Tensor(add_3, view_58); view_58 = None + convert_element_type_53 = torch.ops.prims.convert_element_type.default(primals_18, torch.bfloat16) + all_gather_into_tensor_15 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_53, 256, '0'); convert_element_type_53 = None + wait_tensor_15 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_15); all_gather_into_tensor_15 = None + convert_element_type_54 = torch.ops.prims.convert_element_type.default(add_5, torch.float32) + pow_4 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_54, 2) + mean_3 = torch.ops.aten.mean.dim(pow_4, [2], True); pow_4 = None + add_6 = torch.ops.aten.add.Scalar(mean_3, 1e-05); mean_3 = None + rsqrt_3 = torch.ops.aten.rsqrt.default(add_6); add_6 = None + mul_12 = torch.ops.aten.mul.Tensor(convert_element_type_54, rsqrt_3); convert_element_type_54 = rsqrt_3 = None + mul_13 = torch.ops.aten.mul.Tensor(mul_12, wait_tensor_15); mul_12 = wait_tensor_15 = None + convert_element_type_55 = torch.ops.prims.convert_element_type.default(mul_13, torch.bfloat16); mul_13 = None + convert_element_type_56 = torch.ops.prims.convert_element_type.default(primals_19, torch.bfloat16) + all_gather_into_tensor_16 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_56, 256, '0'); convert_element_type_56 = None + wait_tensor_16 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_16); all_gather_into_tensor_16 = None + permute_19 = torch.ops.aten.permute.default(wait_tensor_16, [1, 0]); wait_tensor_16 = None + view_61 = torch.ops.aten.view.default(convert_element_type_55, 
[16384, 4096]); convert_element_type_55 = None + mm_11 = torch.ops.aten.mm.default(view_61, permute_19); permute_19 = None + view_62 = torch.ops.aten.view.default(mm_11, [2, 8192, 14336]) + convert_element_type_59 = torch.ops.prims.convert_element_type.default(view_62, torch.float32); view_62 = None + sigmoid_1 = torch.ops.aten.sigmoid.default(convert_element_type_59) + mul_14 = torch.ops.aten.mul.Tensor(convert_element_type_59, sigmoid_1); convert_element_type_59 = sigmoid_1 = None + convert_element_type_60 = torch.ops.prims.convert_element_type.default(mul_14, torch.bfloat16); mul_14 = None + convert_element_type_61 = torch.ops.prims.convert_element_type.default(primals_20, torch.bfloat16) + all_gather_into_tensor_17 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_61, 256, '0'); convert_element_type_61 = None + wait_tensor_17 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_17); all_gather_into_tensor_17 = None + permute_20 = torch.ops.aten.permute.default(wait_tensor_17, [1, 0]); wait_tensor_17 = None + mm_12 = torch.ops.aten.mm.default(view_61, permute_20); view_61 = permute_20 = None + view_65 = torch.ops.aten.view.default(mm_12, [2, 8192, 14336]); mm_12 = None + mul_15 = torch.ops.aten.mul.Tensor(convert_element_type_60, view_65); convert_element_type_60 = view_65 = None + convert_element_type_64 = torch.ops.prims.convert_element_type.default(primals_21, torch.bfloat16) + all_gather_into_tensor_18 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_64, 256, '0'); convert_element_type_64 = None + wait_tensor_18 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_18); all_gather_into_tensor_18 = None + permute_21 = torch.ops.aten.permute.default(wait_tensor_18, [1, 0]); wait_tensor_18 = None + view_67 = torch.ops.aten.view.default(mul_15, [16384, 14336]); mul_15 = None + mm_13 = torch.ops.aten.mm.default(view_67, permute_21); view_67 = permute_21 = None + 
view_68 = torch.ops.aten.view.default(mm_13, [2, 8192, 4096]); mm_13 = None + add_7 = torch.ops.aten.add.Tensor(add_5, view_68); add_5 = view_68 = None + convert_element_type_67 = torch.ops.prims.convert_element_type.default(primals_22, torch.bfloat16) + all_gather_into_tensor_19 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_67, 256, '0'); convert_element_type_67 = None + wait_tensor_19 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_19); all_gather_into_tensor_19 = None + convert_element_type_68 = torch.ops.prims.convert_element_type.default(add_7, torch.float32) + pow_5 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_68, 2) + mean_4 = torch.ops.aten.mean.dim(pow_5, [2], True); pow_5 = None + add_8 = torch.ops.aten.add.Scalar(mean_4, 1e-05); mean_4 = None + rsqrt_4 = torch.ops.aten.rsqrt.default(add_8); add_8 = None + mul_16 = torch.ops.aten.mul.Tensor(convert_element_type_68, rsqrt_4); convert_element_type_68 = rsqrt_4 = None + mul_17 = torch.ops.aten.mul.Tensor(mul_16, wait_tensor_19); mul_16 = wait_tensor_19 = None + convert_element_type_69 = torch.ops.prims.convert_element_type.default(mul_17, torch.bfloat16); mul_17 = None + convert_element_type_70 = torch.ops.prims.convert_element_type.default(primals_23, torch.bfloat16) + all_gather_into_tensor_20 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_70, 256, '0'); convert_element_type_70 = None + wait_tensor_20 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_20); all_gather_into_tensor_20 = None + permute_22 = torch.ops.aten.permute.default(wait_tensor_20, [1, 0]); wait_tensor_20 = None + view_71 = torch.ops.aten.view.default(convert_element_type_69, [16384, 4096]); convert_element_type_69 = None + mm_14 = torch.ops.aten.mm.default(view_71, permute_22); permute_22 = None + view_72 = torch.ops.aten.view.default(mm_14, [2, 8192, 4096]) + convert_element_type_73 = 
torch.ops.prims.convert_element_type.default(primals_24, torch.bfloat16) + all_gather_into_tensor_21 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_73, 256, '0'); convert_element_type_73 = None + wait_tensor_21 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_21); all_gather_into_tensor_21 = None + permute_23 = torch.ops.aten.permute.default(wait_tensor_21, [1, 0]); wait_tensor_21 = None + mm_15 = torch.ops.aten.mm.default(view_71, permute_23); permute_23 = None + view_75 = torch.ops.aten.view.default(mm_15, [2, 8192, 1024]); mm_15 = None + convert_element_type_76 = torch.ops.prims.convert_element_type.default(primals_25, torch.bfloat16) + all_gather_into_tensor_22 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_76, 256, '0'); convert_element_type_76 = None + wait_tensor_22 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_22); all_gather_into_tensor_22 = None + permute_24 = torch.ops.aten.permute.default(wait_tensor_22, [1, 0]); wait_tensor_22 = None + mm_16 = torch.ops.aten.mm.default(view_71, permute_24); view_71 = permute_24 = None + view_78 = torch.ops.aten.view.default(mm_16, [2, 8192, 1024]) + view_79 = torch.ops.aten.view.default(view_72, [2, 8192, -1, 128]); view_72 = None + view_80 = torch.ops.aten.view.default(view_75, [2, 8192, -1, 128]); view_75 = None + view_81 = torch.ops.aten.view.default(view_78, [2, 8192, -1, 128]); view_78 = None + convert_element_type_79 = torch.ops.prims.convert_element_type.default(view_79, torch.float32); view_79 = None + view_82 = torch.ops.aten.view.default(convert_element_type_79, [2, 8192, 32, -1, 2]); convert_element_type_79 = None + view_as_complex_4 = torch.ops.aten.view_as_complex.default(view_82); view_82 = None + convert_element_type_80 = torch.ops.prims.convert_element_type.default(view_80, torch.float32); view_80 = None + view_83 = torch.ops.aten.view.default(convert_element_type_80, [2, 8192, 8, -1, 
2]); convert_element_type_80 = None + view_as_complex_5 = torch.ops.aten.view_as_complex.default(view_83); view_83 = None + mul_18 = torch.ops.aten.mul.Tensor(view_as_complex_4, view_16); view_as_complex_4 = None + view_as_real_4 = torch.ops.aten.view_as_real.default(mul_18); mul_18 = None + view_85 = torch.ops.aten.view.default(view_as_real_4, [2, 8192, 32, 128]); view_as_real_4 = None + mul_19 = torch.ops.aten.mul.Tensor(view_as_complex_5, view_16); view_as_complex_5 = None + view_as_real_5 = torch.ops.aten.view_as_real.default(mul_19); mul_19 = None + view_86 = torch.ops.aten.view.default(view_as_real_5, [2, 8192, 8, 128]); view_as_real_5 = None + convert_element_type_81 = torch.ops.prims.convert_element_type.default(view_85, torch.bfloat16); view_85 = None + convert_element_type_82 = torch.ops.prims.convert_element_type.default(view_86, torch.bfloat16); view_86 = None + unsqueeze_4 = torch.ops.aten.unsqueeze.default(convert_element_type_82, 3); convert_element_type_82 = None + expand_4 = torch.ops.aten.expand.default(unsqueeze_4, [2, 8192, 8, 4, 128]); unsqueeze_4 = None + clone_4 = torch.ops.aten.clone.default(expand_4, memory_format = torch.contiguous_format); expand_4 = None + view_87 = torch.ops.aten.view.default(clone_4, [2, 8192, 32, 128]); clone_4 = None + unsqueeze_5 = torch.ops.aten.unsqueeze.default(view_81, 3); view_81 = None + expand_5 = torch.ops.aten.expand.default(unsqueeze_5, [2, 8192, 8, 4, 128]); unsqueeze_5 = None + clone_5 = torch.ops.aten.clone.default(expand_5, memory_format = torch.contiguous_format); expand_5 = None + view_88 = torch.ops.aten.view.default(clone_5, [2, 8192, 32, 128]); clone_5 = None + permute_25 = torch.ops.aten.permute.default(convert_element_type_81, [0, 2, 1, 3]); convert_element_type_81 = None + permute_26 = torch.ops.aten.permute.default(view_87, [0, 2, 1, 3]); view_87 = None + permute_27 = torch.ops.aten.permute.default(view_88, [0, 2, 1, 3]); view_88 = None + _scaled_dot_product_cudnn_attention_2 = 
torch.ops.aten._scaled_dot_product_cudnn_attention.default(permute_25, permute_26, permute_27, None, True, 0.0, True); permute_25 = permute_26 = permute_27 = None + getitem_18 = _scaled_dot_product_cudnn_attention_2[0] + getitem_19 = _scaled_dot_product_cudnn_attention_2[1] + getitem_24 = _scaled_dot_product_cudnn_attention_2[6] + getitem_25 = _scaled_dot_product_cudnn_attention_2[7]; _scaled_dot_product_cudnn_attention_2 = None + permute_28 = torch.ops.aten.permute.default(getitem_18, [0, 2, 1, 3]) + view_89 = torch.ops.aten.view.default(permute_28, [2, 8192, -1]); permute_28 = None + convert_element_type_83 = torch.ops.prims.convert_element_type.default(primals_26, torch.bfloat16) + all_gather_into_tensor_23 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_83, 256, '0'); convert_element_type_83 = None + wait_tensor_23 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_23); all_gather_into_tensor_23 = None + permute_29 = torch.ops.aten.permute.default(wait_tensor_23, [1, 0]); wait_tensor_23 = None + view_91 = torch.ops.aten.view.default(view_89, [16384, 4096]); view_89 = None + mm_17 = torch.ops.aten.mm.default(view_91, permute_29); view_91 = permute_29 = None + view_92 = torch.ops.aten.view.default(mm_17, [2, 8192, 4096]); mm_17 = None + add_9 = torch.ops.aten.add.Tensor(add_7, view_92); view_92 = None + convert_element_type_86 = torch.ops.prims.convert_element_type.default(primals_27, torch.bfloat16) + all_gather_into_tensor_24 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_86, 256, '0'); convert_element_type_86 = None + wait_tensor_24 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_24); all_gather_into_tensor_24 = None + convert_element_type_87 = torch.ops.prims.convert_element_type.default(add_9, torch.float32) + pow_6 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_87, 2) + mean_5 = torch.ops.aten.mean.dim(pow_6, [2], True); pow_6 = None + 
add_10 = torch.ops.aten.add.Scalar(mean_5, 1e-05); mean_5 = None + rsqrt_5 = torch.ops.aten.rsqrt.default(add_10); add_10 = None + mul_20 = torch.ops.aten.mul.Tensor(convert_element_type_87, rsqrt_5); convert_element_type_87 = rsqrt_5 = None + mul_21 = torch.ops.aten.mul.Tensor(mul_20, wait_tensor_24); mul_20 = wait_tensor_24 = None + convert_element_type_88 = torch.ops.prims.convert_element_type.default(mul_21, torch.bfloat16); mul_21 = None + convert_element_type_89 = torch.ops.prims.convert_element_type.default(primals_28, torch.bfloat16) + all_gather_into_tensor_25 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_89, 256, '0'); convert_element_type_89 = None + wait_tensor_25 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_25); all_gather_into_tensor_25 = None + permute_30 = torch.ops.aten.permute.default(wait_tensor_25, [1, 0]); wait_tensor_25 = None + view_95 = torch.ops.aten.view.default(convert_element_type_88, [16384, 4096]); convert_element_type_88 = None + mm_18 = torch.ops.aten.mm.default(view_95, permute_30); permute_30 = None + view_96 = torch.ops.aten.view.default(mm_18, [2, 8192, 14336]) + convert_element_type_92 = torch.ops.prims.convert_element_type.default(view_96, torch.float32); view_96 = None + sigmoid_2 = torch.ops.aten.sigmoid.default(convert_element_type_92) + mul_22 = torch.ops.aten.mul.Tensor(convert_element_type_92, sigmoid_2); convert_element_type_92 = sigmoid_2 = None + convert_element_type_93 = torch.ops.prims.convert_element_type.default(mul_22, torch.bfloat16); mul_22 = None + convert_element_type_94 = torch.ops.prims.convert_element_type.default(primals_29, torch.bfloat16) + all_gather_into_tensor_26 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_94, 256, '0'); convert_element_type_94 = None + wait_tensor_26 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_26); all_gather_into_tensor_26 = None + permute_31 = 
torch.ops.aten.permute.default(wait_tensor_26, [1, 0]); wait_tensor_26 = None + mm_19 = torch.ops.aten.mm.default(view_95, permute_31); view_95 = permute_31 = None + view_99 = torch.ops.aten.view.default(mm_19, [2, 8192, 14336]); mm_19 = None + mul_23 = torch.ops.aten.mul.Tensor(convert_element_type_93, view_99); convert_element_type_93 = view_99 = None + convert_element_type_97 = torch.ops.prims.convert_element_type.default(primals_30, torch.bfloat16) + all_gather_into_tensor_27 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_97, 256, '0'); convert_element_type_97 = None + wait_tensor_27 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_27); all_gather_into_tensor_27 = None + permute_32 = torch.ops.aten.permute.default(wait_tensor_27, [1, 0]); wait_tensor_27 = None + view_101 = torch.ops.aten.view.default(mul_23, [16384, 14336]); mul_23 = None + mm_20 = torch.ops.aten.mm.default(view_101, permute_32); view_101 = permute_32 = None + view_102 = torch.ops.aten.view.default(mm_20, [2, 8192, 4096]); mm_20 = None + add_11 = torch.ops.aten.add.Tensor(add_9, view_102); add_9 = view_102 = None + convert_element_type_100 = torch.ops.prims.convert_element_type.default(primals_31, torch.bfloat16) + all_gather_into_tensor_28 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_100, 256, '0'); convert_element_type_100 = None + wait_tensor_28 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_28); all_gather_into_tensor_28 = None + convert_element_type_101 = torch.ops.prims.convert_element_type.default(add_11, torch.float32) + pow_7 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_101, 2) + mean_6 = torch.ops.aten.mean.dim(pow_7, [2], True); pow_7 = None + add_12 = torch.ops.aten.add.Scalar(mean_6, 1e-05); mean_6 = None + rsqrt_6 = torch.ops.aten.rsqrt.default(add_12); add_12 = None + mul_24 = torch.ops.aten.mul.Tensor(convert_element_type_101, rsqrt_6); 
convert_element_type_101 = rsqrt_6 = None + mul_25 = torch.ops.aten.mul.Tensor(mul_24, wait_tensor_28); mul_24 = wait_tensor_28 = None + convert_element_type_102 = torch.ops.prims.convert_element_type.default(mul_25, torch.bfloat16); mul_25 = None + convert_element_type_103 = torch.ops.prims.convert_element_type.default(primals_32, torch.bfloat16) + all_gather_into_tensor_29 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_103, 256, '0'); convert_element_type_103 = None + wait_tensor_29 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_29); all_gather_into_tensor_29 = None + permute_33 = torch.ops.aten.permute.default(wait_tensor_29, [1, 0]); wait_tensor_29 = None + view_105 = torch.ops.aten.view.default(convert_element_type_102, [16384, 4096]); convert_element_type_102 = None + mm_21 = torch.ops.aten.mm.default(view_105, permute_33); permute_33 = None + view_106 = torch.ops.aten.view.default(mm_21, [2, 8192, 4096]) + convert_element_type_106 = torch.ops.prims.convert_element_type.default(primals_33, torch.bfloat16) + all_gather_into_tensor_30 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_106, 256, '0'); convert_element_type_106 = None + wait_tensor_30 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_30); all_gather_into_tensor_30 = None + permute_34 = torch.ops.aten.permute.default(wait_tensor_30, [1, 0]); wait_tensor_30 = None + mm_22 = torch.ops.aten.mm.default(view_105, permute_34); permute_34 = None + view_109 = torch.ops.aten.view.default(mm_22, [2, 8192, 1024]); mm_22 = None + convert_element_type_109 = torch.ops.prims.convert_element_type.default(primals_34, torch.bfloat16) + all_gather_into_tensor_31 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_109, 256, '0'); convert_element_type_109 = None + wait_tensor_31 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_31); all_gather_into_tensor_31 = 
None + permute_35 = torch.ops.aten.permute.default(wait_tensor_31, [1, 0]); wait_tensor_31 = None + mm_23 = torch.ops.aten.mm.default(view_105, permute_35); view_105 = permute_35 = None + view_112 = torch.ops.aten.view.default(mm_23, [2, 8192, 1024]) + view_113 = torch.ops.aten.view.default(view_106, [2, 8192, -1, 128]); view_106 = None + view_114 = torch.ops.aten.view.default(view_109, [2, 8192, -1, 128]); view_109 = None + view_115 = torch.ops.aten.view.default(view_112, [2, 8192, -1, 128]); view_112 = None + convert_element_type_112 = torch.ops.prims.convert_element_type.default(view_113, torch.float32); view_113 = None + view_116 = torch.ops.aten.view.default(convert_element_type_112, [2, 8192, 32, -1, 2]); convert_element_type_112 = None + view_as_complex_6 = torch.ops.aten.view_as_complex.default(view_116); view_116 = None + convert_element_type_113 = torch.ops.prims.convert_element_type.default(view_114, torch.float32); view_114 = None + view_117 = torch.ops.aten.view.default(convert_element_type_113, [2, 8192, 8, -1, 2]); convert_element_type_113 = None + view_as_complex_7 = torch.ops.aten.view_as_complex.default(view_117); view_117 = None + mul_26 = torch.ops.aten.mul.Tensor(view_as_complex_6, view_16); view_as_complex_6 = None + view_as_real_6 = torch.ops.aten.view_as_real.default(mul_26); mul_26 = None + view_119 = torch.ops.aten.view.default(view_as_real_6, [2, 8192, 32, 128]); view_as_real_6 = None + mul_27 = torch.ops.aten.mul.Tensor(view_as_complex_7, view_16); view_as_complex_7 = None + view_as_real_7 = torch.ops.aten.view_as_real.default(mul_27); mul_27 = None + view_120 = torch.ops.aten.view.default(view_as_real_7, [2, 8192, 8, 128]); view_as_real_7 = None + convert_element_type_114 = torch.ops.prims.convert_element_type.default(view_119, torch.bfloat16); view_119 = None + convert_element_type_115 = torch.ops.prims.convert_element_type.default(view_120, torch.bfloat16); view_120 = None + unsqueeze_6 = 
torch.ops.aten.unsqueeze.default(convert_element_type_115, 3); convert_element_type_115 = None + expand_6 = torch.ops.aten.expand.default(unsqueeze_6, [2, 8192, 8, 4, 128]); unsqueeze_6 = None + clone_6 = torch.ops.aten.clone.default(expand_6, memory_format = torch.contiguous_format); expand_6 = None + view_121 = torch.ops.aten.view.default(clone_6, [2, 8192, 32, 128]); clone_6 = None + unsqueeze_7 = torch.ops.aten.unsqueeze.default(view_115, 3); view_115 = None + expand_7 = torch.ops.aten.expand.default(unsqueeze_7, [2, 8192, 8, 4, 128]); unsqueeze_7 = None + clone_7 = torch.ops.aten.clone.default(expand_7, memory_format = torch.contiguous_format); expand_7 = None + view_122 = torch.ops.aten.view.default(clone_7, [2, 8192, 32, 128]); clone_7 = None + permute_36 = torch.ops.aten.permute.default(convert_element_type_114, [0, 2, 1, 3]); convert_element_type_114 = None + permute_37 = torch.ops.aten.permute.default(view_121, [0, 2, 1, 3]); view_121 = None + permute_38 = torch.ops.aten.permute.default(view_122, [0, 2, 1, 3]); view_122 = None + _scaled_dot_product_cudnn_attention_3 = torch.ops.aten._scaled_dot_product_cudnn_attention.default(permute_36, permute_37, permute_38, None, True, 0.0, True); permute_36 = permute_37 = permute_38 = None + getitem_27 = _scaled_dot_product_cudnn_attention_3[0] + getitem_28 = _scaled_dot_product_cudnn_attention_3[1] + getitem_33 = _scaled_dot_product_cudnn_attention_3[6] + getitem_34 = _scaled_dot_product_cudnn_attention_3[7]; _scaled_dot_product_cudnn_attention_3 = None + permute_39 = torch.ops.aten.permute.default(getitem_27, [0, 2, 1, 3]) + view_123 = torch.ops.aten.view.default(permute_39, [2, 8192, -1]); permute_39 = None + convert_element_type_116 = torch.ops.prims.convert_element_type.default(primals_35, torch.bfloat16) + all_gather_into_tensor_32 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_116, 256, '0'); convert_element_type_116 = None + wait_tensor_32 = 
torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_32); all_gather_into_tensor_32 = None + permute_40 = torch.ops.aten.permute.default(wait_tensor_32, [1, 0]); wait_tensor_32 = None + view_125 = torch.ops.aten.view.default(view_123, [16384, 4096]); view_123 = None + mm_24 = torch.ops.aten.mm.default(view_125, permute_40); view_125 = permute_40 = None + view_126 = torch.ops.aten.view.default(mm_24, [2, 8192, 4096]); mm_24 = None + add_13 = torch.ops.aten.add.Tensor(add_11, view_126); view_126 = None + convert_element_type_119 = torch.ops.prims.convert_element_type.default(primals_36, torch.bfloat16) + all_gather_into_tensor_33 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_119, 256, '0'); convert_element_type_119 = None + wait_tensor_33 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_33); all_gather_into_tensor_33 = None + convert_element_type_120 = torch.ops.prims.convert_element_type.default(add_13, torch.float32) + pow_8 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_120, 2) + mean_7 = torch.ops.aten.mean.dim(pow_8, [2], True); pow_8 = None + add_14 = torch.ops.aten.add.Scalar(mean_7, 1e-05); mean_7 = None + rsqrt_7 = torch.ops.aten.rsqrt.default(add_14); add_14 = None + mul_28 = torch.ops.aten.mul.Tensor(convert_element_type_120, rsqrt_7); convert_element_type_120 = rsqrt_7 = None + mul_29 = torch.ops.aten.mul.Tensor(mul_28, wait_tensor_33); mul_28 = wait_tensor_33 = None + convert_element_type_121 = torch.ops.prims.convert_element_type.default(mul_29, torch.bfloat16); mul_29 = None + convert_element_type_122 = torch.ops.prims.convert_element_type.default(primals_37, torch.bfloat16) + all_gather_into_tensor_34 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_122, 256, '0'); convert_element_type_122 = None + wait_tensor_34 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_34); all_gather_into_tensor_34 = None + permute_41 = 
torch.ops.aten.permute.default(wait_tensor_34, [1, 0]); wait_tensor_34 = None + view_129 = torch.ops.aten.view.default(convert_element_type_121, [16384, 4096]); convert_element_type_121 = None + mm_25 = torch.ops.aten.mm.default(view_129, permute_41); permute_41 = None + view_130 = torch.ops.aten.view.default(mm_25, [2, 8192, 14336]) + convert_element_type_125 = torch.ops.prims.convert_element_type.default(view_130, torch.float32); view_130 = None + sigmoid_3 = torch.ops.aten.sigmoid.default(convert_element_type_125) + mul_30 = torch.ops.aten.mul.Tensor(convert_element_type_125, sigmoid_3); convert_element_type_125 = sigmoid_3 = None + convert_element_type_126 = torch.ops.prims.convert_element_type.default(mul_30, torch.bfloat16); mul_30 = None + convert_element_type_127 = torch.ops.prims.convert_element_type.default(primals_38, torch.bfloat16) + all_gather_into_tensor_35 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_127, 256, '0'); convert_element_type_127 = None + wait_tensor_35 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_35); all_gather_into_tensor_35 = None + permute_42 = torch.ops.aten.permute.default(wait_tensor_35, [1, 0]); wait_tensor_35 = None + mm_26 = torch.ops.aten.mm.default(view_129, permute_42); view_129 = permute_42 = None + view_133 = torch.ops.aten.view.default(mm_26, [2, 8192, 14336]); mm_26 = None + mul_31 = torch.ops.aten.mul.Tensor(convert_element_type_126, view_133); convert_element_type_126 = view_133 = None + convert_element_type_130 = torch.ops.prims.convert_element_type.default(primals_39, torch.bfloat16) + all_gather_into_tensor_36 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_130, 256, '0'); convert_element_type_130 = None + wait_tensor_36 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_36); all_gather_into_tensor_36 = None + permute_43 = torch.ops.aten.permute.default(wait_tensor_36, [1, 0]); wait_tensor_36 = None + 
view_135 = torch.ops.aten.view.default(mul_31, [16384, 14336]); mul_31 = None + mm_27 = torch.ops.aten.mm.default(view_135, permute_43); view_135 = permute_43 = None + view_136 = torch.ops.aten.view.default(mm_27, [2, 8192, 4096]); mm_27 = None + add_15 = torch.ops.aten.add.Tensor(add_13, view_136); add_13 = view_136 = None + convert_element_type_133 = torch.ops.prims.convert_element_type.default(primals_40, torch.bfloat16) + all_gather_into_tensor_37 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_133, 256, '0'); convert_element_type_133 = None + wait_tensor_37 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_37); all_gather_into_tensor_37 = None + convert_element_type_134 = torch.ops.prims.convert_element_type.default(add_15, torch.float32) + pow_9 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_134, 2) + mean_8 = torch.ops.aten.mean.dim(pow_9, [2], True); pow_9 = None + add_16 = torch.ops.aten.add.Scalar(mean_8, 1e-05); mean_8 = None + rsqrt_8 = torch.ops.aten.rsqrt.default(add_16); add_16 = None + mul_32 = torch.ops.aten.mul.Tensor(convert_element_type_134, rsqrt_8); convert_element_type_134 = rsqrt_8 = None + mul_33 = torch.ops.aten.mul.Tensor(mul_32, wait_tensor_37); mul_32 = wait_tensor_37 = None + convert_element_type_135 = torch.ops.prims.convert_element_type.default(mul_33, torch.bfloat16); mul_33 = None + convert_element_type_136 = torch.ops.prims.convert_element_type.default(primals_41, torch.bfloat16) + all_gather_into_tensor_38 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_136, 256, '0'); convert_element_type_136 = None + wait_tensor_38 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_38); all_gather_into_tensor_38 = None + permute_44 = torch.ops.aten.permute.default(wait_tensor_38, [1, 0]); wait_tensor_38 = None + view_139 = torch.ops.aten.view.default(convert_element_type_135, [16384, 4096]); convert_element_type_135 = None + 
mm_28 = torch.ops.aten.mm.default(view_139, permute_44); permute_44 = None + view_140 = torch.ops.aten.view.default(mm_28, [2, 8192, 4096]) + convert_element_type_139 = torch.ops.prims.convert_element_type.default(primals_42, torch.bfloat16) + all_gather_into_tensor_39 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_139, 256, '0'); convert_element_type_139 = None + wait_tensor_39 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_39); all_gather_into_tensor_39 = None + permute_45 = torch.ops.aten.permute.default(wait_tensor_39, [1, 0]); wait_tensor_39 = None + mm_29 = torch.ops.aten.mm.default(view_139, permute_45); permute_45 = None + view_143 = torch.ops.aten.view.default(mm_29, [2, 8192, 1024]); mm_29 = None + convert_element_type_142 = torch.ops.prims.convert_element_type.default(primals_43, torch.bfloat16) + all_gather_into_tensor_40 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_142, 256, '0'); convert_element_type_142 = None + wait_tensor_40 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_40); all_gather_into_tensor_40 = None + permute_46 = torch.ops.aten.permute.default(wait_tensor_40, [1, 0]); wait_tensor_40 = None + mm_30 = torch.ops.aten.mm.default(view_139, permute_46); view_139 = permute_46 = None + view_146 = torch.ops.aten.view.default(mm_30, [2, 8192, 1024]) + view_147 = torch.ops.aten.view.default(view_140, [2, 8192, -1, 128]); view_140 = None + view_148 = torch.ops.aten.view.default(view_143, [2, 8192, -1, 128]); view_143 = None + view_149 = torch.ops.aten.view.default(view_146, [2, 8192, -1, 128]); view_146 = None + convert_element_type_145 = torch.ops.prims.convert_element_type.default(view_147, torch.float32); view_147 = None + view_150 = torch.ops.aten.view.default(convert_element_type_145, [2, 8192, 32, -1, 2]); convert_element_type_145 = None + view_as_complex_8 = torch.ops.aten.view_as_complex.default(view_150); view_150 = None + 
convert_element_type_146 = torch.ops.prims.convert_element_type.default(view_148, torch.float32); view_148 = None + view_151 = torch.ops.aten.view.default(convert_element_type_146, [2, 8192, 8, -1, 2]); convert_element_type_146 = None + view_as_complex_9 = torch.ops.aten.view_as_complex.default(view_151); view_151 = None + mul_34 = torch.ops.aten.mul.Tensor(view_as_complex_8, view_16); view_as_complex_8 = None + view_as_real_8 = torch.ops.aten.view_as_real.default(mul_34); mul_34 = None + view_153 = torch.ops.aten.view.default(view_as_real_8, [2, 8192, 32, 128]); view_as_real_8 = None + mul_35 = torch.ops.aten.mul.Tensor(view_as_complex_9, view_16); view_as_complex_9 = None + view_as_real_9 = torch.ops.aten.view_as_real.default(mul_35); mul_35 = None + view_154 = torch.ops.aten.view.default(view_as_real_9, [2, 8192, 8, 128]); view_as_real_9 = None + convert_element_type_147 = torch.ops.prims.convert_element_type.default(view_153, torch.bfloat16); view_153 = None + convert_element_type_148 = torch.ops.prims.convert_element_type.default(view_154, torch.bfloat16); view_154 = None + unsqueeze_8 = torch.ops.aten.unsqueeze.default(convert_element_type_148, 3); convert_element_type_148 = None + expand_8 = torch.ops.aten.expand.default(unsqueeze_8, [2, 8192, 8, 4, 128]); unsqueeze_8 = None + clone_8 = torch.ops.aten.clone.default(expand_8, memory_format = torch.contiguous_format); expand_8 = None + view_155 = torch.ops.aten.view.default(clone_8, [2, 8192, 32, 128]); clone_8 = None + unsqueeze_9 = torch.ops.aten.unsqueeze.default(view_149, 3); view_149 = None + expand_9 = torch.ops.aten.expand.default(unsqueeze_9, [2, 8192, 8, 4, 128]); unsqueeze_9 = None + clone_9 = torch.ops.aten.clone.default(expand_9, memory_format = torch.contiguous_format); expand_9 = None + view_156 = torch.ops.aten.view.default(clone_9, [2, 8192, 32, 128]); clone_9 = None + permute_47 = torch.ops.aten.permute.default(convert_element_type_147, [0, 2, 1, 3]); convert_element_type_147 = None + 
permute_48 = torch.ops.aten.permute.default(view_155, [0, 2, 1, 3]); view_155 = None + permute_49 = torch.ops.aten.permute.default(view_156, [0, 2, 1, 3]); view_156 = None + _scaled_dot_product_cudnn_attention_4 = torch.ops.aten._scaled_dot_product_cudnn_attention.default(permute_47, permute_48, permute_49, None, True, 0.0, True); permute_47 = permute_48 = permute_49 = None + getitem_36 = _scaled_dot_product_cudnn_attention_4[0] + getitem_37 = _scaled_dot_product_cudnn_attention_4[1] + getitem_42 = _scaled_dot_product_cudnn_attention_4[6] + getitem_43 = _scaled_dot_product_cudnn_attention_4[7]; _scaled_dot_product_cudnn_attention_4 = None + permute_50 = torch.ops.aten.permute.default(getitem_36, [0, 2, 1, 3]) + view_157 = torch.ops.aten.view.default(permute_50, [2, 8192, -1]); permute_50 = None + convert_element_type_149 = torch.ops.prims.convert_element_type.default(primals_44, torch.bfloat16) + all_gather_into_tensor_41 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_149, 256, '0'); convert_element_type_149 = None + wait_tensor_41 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_41); all_gather_into_tensor_41 = None + permute_51 = torch.ops.aten.permute.default(wait_tensor_41, [1, 0]); wait_tensor_41 = None + view_159 = torch.ops.aten.view.default(view_157, [16384, 4096]); view_157 = None + mm_31 = torch.ops.aten.mm.default(view_159, permute_51); view_159 = permute_51 = None + view_160 = torch.ops.aten.view.default(mm_31, [2, 8192, 4096]); mm_31 = None + add_17 = torch.ops.aten.add.Tensor(add_15, view_160); view_160 = None + convert_element_type_152 = torch.ops.prims.convert_element_type.default(primals_45, torch.bfloat16) + all_gather_into_tensor_42 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_152, 256, '0'); convert_element_type_152 = None + wait_tensor_42 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_42); all_gather_into_tensor_42 = None + 
convert_element_type_153 = torch.ops.prims.convert_element_type.default(add_17, torch.float32) + pow_10 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_153, 2) + mean_9 = torch.ops.aten.mean.dim(pow_10, [2], True); pow_10 = None + add_18 = torch.ops.aten.add.Scalar(mean_9, 1e-05); mean_9 = None + rsqrt_9 = torch.ops.aten.rsqrt.default(add_18); add_18 = None + mul_36 = torch.ops.aten.mul.Tensor(convert_element_type_153, rsqrt_9); convert_element_type_153 = rsqrt_9 = None + mul_37 = torch.ops.aten.mul.Tensor(mul_36, wait_tensor_42); mul_36 = wait_tensor_42 = None + convert_element_type_154 = torch.ops.prims.convert_element_type.default(mul_37, torch.bfloat16); mul_37 = None + convert_element_type_155 = torch.ops.prims.convert_element_type.default(primals_46, torch.bfloat16) + all_gather_into_tensor_43 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_155, 256, '0'); convert_element_type_155 = None + wait_tensor_43 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_43); all_gather_into_tensor_43 = None + permute_52 = torch.ops.aten.permute.default(wait_tensor_43, [1, 0]); wait_tensor_43 = None + view_163 = torch.ops.aten.view.default(convert_element_type_154, [16384, 4096]); convert_element_type_154 = None + mm_32 = torch.ops.aten.mm.default(view_163, permute_52); permute_52 = None + view_164 = torch.ops.aten.view.default(mm_32, [2, 8192, 14336]) + convert_element_type_158 = torch.ops.prims.convert_element_type.default(view_164, torch.float32); view_164 = None + sigmoid_4 = torch.ops.aten.sigmoid.default(convert_element_type_158) + mul_38 = torch.ops.aten.mul.Tensor(convert_element_type_158, sigmoid_4); convert_element_type_158 = sigmoid_4 = None + convert_element_type_159 = torch.ops.prims.convert_element_type.default(mul_38, torch.bfloat16); mul_38 = None + convert_element_type_160 = torch.ops.prims.convert_element_type.default(primals_47, torch.bfloat16) + all_gather_into_tensor_44 = 
torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_160, 256, '0'); convert_element_type_160 = None + wait_tensor_44 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_44); all_gather_into_tensor_44 = None + permute_53 = torch.ops.aten.permute.default(wait_tensor_44, [1, 0]); wait_tensor_44 = None + mm_33 = torch.ops.aten.mm.default(view_163, permute_53); view_163 = permute_53 = None + view_167 = torch.ops.aten.view.default(mm_33, [2, 8192, 14336]); mm_33 = None + mul_39 = torch.ops.aten.mul.Tensor(convert_element_type_159, view_167); convert_element_type_159 = view_167 = None + convert_element_type_163 = torch.ops.prims.convert_element_type.default(primals_48, torch.bfloat16) + all_gather_into_tensor_45 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_163, 256, '0'); convert_element_type_163 = None + wait_tensor_45 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_45); all_gather_into_tensor_45 = None + permute_54 = torch.ops.aten.permute.default(wait_tensor_45, [1, 0]); wait_tensor_45 = None + view_169 = torch.ops.aten.view.default(mul_39, [16384, 14336]); mul_39 = None + mm_34 = torch.ops.aten.mm.default(view_169, permute_54); view_169 = permute_54 = None + view_170 = torch.ops.aten.view.default(mm_34, [2, 8192, 4096]); mm_34 = None + add_19 = torch.ops.aten.add.Tensor(add_17, view_170); add_17 = view_170 = None + convert_element_type_166 = torch.ops.prims.convert_element_type.default(primals_49, torch.bfloat16) + all_gather_into_tensor_46 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_166, 256, '0'); convert_element_type_166 = None + wait_tensor_46 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_46); all_gather_into_tensor_46 = None + convert_element_type_167 = torch.ops.prims.convert_element_type.default(add_19, torch.float32) + pow_11 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_167, 2) + 
mean_10 = torch.ops.aten.mean.dim(pow_11, [2], True); pow_11 = None + add_20 = torch.ops.aten.add.Scalar(mean_10, 1e-05); mean_10 = None + rsqrt_10 = torch.ops.aten.rsqrt.default(add_20); add_20 = None + mul_40 = torch.ops.aten.mul.Tensor(convert_element_type_167, rsqrt_10); convert_element_type_167 = rsqrt_10 = None + mul_41 = torch.ops.aten.mul.Tensor(mul_40, wait_tensor_46); mul_40 = wait_tensor_46 = None + convert_element_type_168 = torch.ops.prims.convert_element_type.default(mul_41, torch.bfloat16); mul_41 = None + convert_element_type_169 = torch.ops.prims.convert_element_type.default(primals_50, torch.bfloat16) + all_gather_into_tensor_47 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_169, 256, '0'); convert_element_type_169 = None + wait_tensor_47 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_47); all_gather_into_tensor_47 = None + permute_55 = torch.ops.aten.permute.default(wait_tensor_47, [1, 0]); wait_tensor_47 = None + view_173 = torch.ops.aten.view.default(convert_element_type_168, [16384, 4096]); convert_element_type_168 = None + mm_35 = torch.ops.aten.mm.default(view_173, permute_55); permute_55 = None + view_174 = torch.ops.aten.view.default(mm_35, [2, 8192, 4096]) + convert_element_type_172 = torch.ops.prims.convert_element_type.default(primals_51, torch.bfloat16) + all_gather_into_tensor_48 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_172, 256, '0'); convert_element_type_172 = None + wait_tensor_48 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_48); all_gather_into_tensor_48 = None + permute_56 = torch.ops.aten.permute.default(wait_tensor_48, [1, 0]); wait_tensor_48 = None + mm_36 = torch.ops.aten.mm.default(view_173, permute_56); permute_56 = None + view_177 = torch.ops.aten.view.default(mm_36, [2, 8192, 1024]); mm_36 = None + convert_element_type_175 = torch.ops.prims.convert_element_type.default(primals_52, torch.bfloat16) + 
all_gather_into_tensor_49 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_175, 256, '0'); convert_element_type_175 = None + wait_tensor_49 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_49); all_gather_into_tensor_49 = None + permute_57 = torch.ops.aten.permute.default(wait_tensor_49, [1, 0]); wait_tensor_49 = None + mm_37 = torch.ops.aten.mm.default(view_173, permute_57); view_173 = permute_57 = None + view_180 = torch.ops.aten.view.default(mm_37, [2, 8192, 1024]) + view_181 = torch.ops.aten.view.default(view_174, [2, 8192, -1, 128]); view_174 = None + view_182 = torch.ops.aten.view.default(view_177, [2, 8192, -1, 128]); view_177 = None + view_183 = torch.ops.aten.view.default(view_180, [2, 8192, -1, 128]); view_180 = None + convert_element_type_178 = torch.ops.prims.convert_element_type.default(view_181, torch.float32); view_181 = None + view_184 = torch.ops.aten.view.default(convert_element_type_178, [2, 8192, 32, -1, 2]); convert_element_type_178 = None + view_as_complex_10 = torch.ops.aten.view_as_complex.default(view_184); view_184 = None + convert_element_type_179 = torch.ops.prims.convert_element_type.default(view_182, torch.float32); view_182 = None + view_185 = torch.ops.aten.view.default(convert_element_type_179, [2, 8192, 8, -1, 2]); convert_element_type_179 = None + view_as_complex_11 = torch.ops.aten.view_as_complex.default(view_185); view_185 = None + mul_42 = torch.ops.aten.mul.Tensor(view_as_complex_10, view_16); view_as_complex_10 = None + view_as_real_10 = torch.ops.aten.view_as_real.default(mul_42); mul_42 = None + view_187 = torch.ops.aten.view.default(view_as_real_10, [2, 8192, 32, 128]); view_as_real_10 = None + mul_43 = torch.ops.aten.mul.Tensor(view_as_complex_11, view_16); view_as_complex_11 = None + view_as_real_11 = torch.ops.aten.view_as_real.default(mul_43); mul_43 = None + view_188 = torch.ops.aten.view.default(view_as_real_11, [2, 8192, 8, 128]); view_as_real_11 = None + 
convert_element_type_180 = torch.ops.prims.convert_element_type.default(view_187, torch.bfloat16); view_187 = None + convert_element_type_181 = torch.ops.prims.convert_element_type.default(view_188, torch.bfloat16); view_188 = None + unsqueeze_10 = torch.ops.aten.unsqueeze.default(convert_element_type_181, 3); convert_element_type_181 = None + expand_10 = torch.ops.aten.expand.default(unsqueeze_10, [2, 8192, 8, 4, 128]); unsqueeze_10 = None + clone_10 = torch.ops.aten.clone.default(expand_10, memory_format = torch.contiguous_format); expand_10 = None + view_189 = torch.ops.aten.view.default(clone_10, [2, 8192, 32, 128]); clone_10 = None + unsqueeze_11 = torch.ops.aten.unsqueeze.default(view_183, 3); view_183 = None + expand_11 = torch.ops.aten.expand.default(unsqueeze_11, [2, 8192, 8, 4, 128]); unsqueeze_11 = None + clone_11 = torch.ops.aten.clone.default(expand_11, memory_format = torch.contiguous_format); expand_11 = None + view_190 = torch.ops.aten.view.default(clone_11, [2, 8192, 32, 128]); clone_11 = None + permute_58 = torch.ops.aten.permute.default(convert_element_type_180, [0, 2, 1, 3]); convert_element_type_180 = None + permute_59 = torch.ops.aten.permute.default(view_189, [0, 2, 1, 3]); view_189 = None + permute_60 = torch.ops.aten.permute.default(view_190, [0, 2, 1, 3]); view_190 = None + _scaled_dot_product_cudnn_attention_5 = torch.ops.aten._scaled_dot_product_cudnn_attention.default(permute_58, permute_59, permute_60, None, True, 0.0, True); permute_58 = permute_59 = permute_60 = None + getitem_45 = _scaled_dot_product_cudnn_attention_5[0] + getitem_46 = _scaled_dot_product_cudnn_attention_5[1] + getitem_51 = _scaled_dot_product_cudnn_attention_5[6] + getitem_52 = _scaled_dot_product_cudnn_attention_5[7]; _scaled_dot_product_cudnn_attention_5 = None + permute_61 = torch.ops.aten.permute.default(getitem_45, [0, 2, 1, 3]) + view_191 = torch.ops.aten.view.default(permute_61, [2, 8192, -1]); permute_61 = None + convert_element_type_182 = 
torch.ops.prims.convert_element_type.default(primals_53, torch.bfloat16) + all_gather_into_tensor_50 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_182, 256, '0'); convert_element_type_182 = None + wait_tensor_50 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_50); all_gather_into_tensor_50 = None + permute_62 = torch.ops.aten.permute.default(wait_tensor_50, [1, 0]); wait_tensor_50 = None + view_193 = torch.ops.aten.view.default(view_191, [16384, 4096]); view_191 = None + mm_38 = torch.ops.aten.mm.default(view_193, permute_62); view_193 = permute_62 = None + view_194 = torch.ops.aten.view.default(mm_38, [2, 8192, 4096]); mm_38 = None + add_21 = torch.ops.aten.add.Tensor(add_19, view_194); view_194 = None + convert_element_type_185 = torch.ops.prims.convert_element_type.default(primals_54, torch.bfloat16) + all_gather_into_tensor_51 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_185, 256, '0'); convert_element_type_185 = None + wait_tensor_51 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_51); all_gather_into_tensor_51 = None + convert_element_type_186 = torch.ops.prims.convert_element_type.default(add_21, torch.float32) + pow_12 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_186, 2) + mean_11 = torch.ops.aten.mean.dim(pow_12, [2], True); pow_12 = None + add_22 = torch.ops.aten.add.Scalar(mean_11, 1e-05); mean_11 = None + rsqrt_11 = torch.ops.aten.rsqrt.default(add_22); add_22 = None + mul_44 = torch.ops.aten.mul.Tensor(convert_element_type_186, rsqrt_11); convert_element_type_186 = rsqrt_11 = None + mul_45 = torch.ops.aten.mul.Tensor(mul_44, wait_tensor_51); mul_44 = wait_tensor_51 = None + convert_element_type_187 = torch.ops.prims.convert_element_type.default(mul_45, torch.bfloat16); mul_45 = None + convert_element_type_188 = torch.ops.prims.convert_element_type.default(primals_55, torch.bfloat16) + all_gather_into_tensor_52 = 
torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_188, 256, '0'); convert_element_type_188 = None + wait_tensor_52 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_52); all_gather_into_tensor_52 = None + permute_63 = torch.ops.aten.permute.default(wait_tensor_52, [1, 0]); wait_tensor_52 = None + view_197 = torch.ops.aten.view.default(convert_element_type_187, [16384, 4096]); convert_element_type_187 = None + mm_39 = torch.ops.aten.mm.default(view_197, permute_63); permute_63 = None + view_198 = torch.ops.aten.view.default(mm_39, [2, 8192, 14336]) + convert_element_type_191 = torch.ops.prims.convert_element_type.default(view_198, torch.float32); view_198 = None + sigmoid_5 = torch.ops.aten.sigmoid.default(convert_element_type_191) + mul_46 = torch.ops.aten.mul.Tensor(convert_element_type_191, sigmoid_5); convert_element_type_191 = sigmoid_5 = None + convert_element_type_192 = torch.ops.prims.convert_element_type.default(mul_46, torch.bfloat16); mul_46 = None + convert_element_type_193 = torch.ops.prims.convert_element_type.default(primals_56, torch.bfloat16) + all_gather_into_tensor_53 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_193, 256, '0'); convert_element_type_193 = None + wait_tensor_53 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_53); all_gather_into_tensor_53 = None + permute_64 = torch.ops.aten.permute.default(wait_tensor_53, [1, 0]); wait_tensor_53 = None + mm_40 = torch.ops.aten.mm.default(view_197, permute_64); view_197 = permute_64 = None + view_201 = torch.ops.aten.view.default(mm_40, [2, 8192, 14336]); mm_40 = None + mul_47 = torch.ops.aten.mul.Tensor(convert_element_type_192, view_201); convert_element_type_192 = view_201 = None + convert_element_type_196 = torch.ops.prims.convert_element_type.default(primals_57, torch.bfloat16) + all_gather_into_tensor_54 = 
torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_196, 256, '0'); convert_element_type_196 = None + wait_tensor_54 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_54); all_gather_into_tensor_54 = None + permute_65 = torch.ops.aten.permute.default(wait_tensor_54, [1, 0]); wait_tensor_54 = None + view_203 = torch.ops.aten.view.default(mul_47, [16384, 14336]); mul_47 = None + mm_41 = torch.ops.aten.mm.default(view_203, permute_65); view_203 = permute_65 = None + view_204 = torch.ops.aten.view.default(mm_41, [2, 8192, 4096]); mm_41 = None + add_23 = torch.ops.aten.add.Tensor(add_21, view_204); add_21 = view_204 = None + convert_element_type_199 = torch.ops.prims.convert_element_type.default(primals_58, torch.bfloat16) + all_gather_into_tensor_55 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_199, 256, '0'); convert_element_type_199 = None + wait_tensor_55 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_55); all_gather_into_tensor_55 = None + convert_element_type_200 = torch.ops.prims.convert_element_type.default(add_23, torch.float32) + pow_13 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_200, 2) + mean_12 = torch.ops.aten.mean.dim(pow_13, [2], True); pow_13 = None + add_24 = torch.ops.aten.add.Scalar(mean_12, 1e-05); mean_12 = None + rsqrt_12 = torch.ops.aten.rsqrt.default(add_24); add_24 = None + mul_48 = torch.ops.aten.mul.Tensor(convert_element_type_200, rsqrt_12); convert_element_type_200 = rsqrt_12 = None + mul_49 = torch.ops.aten.mul.Tensor(mul_48, wait_tensor_55); mul_48 = wait_tensor_55 = None + convert_element_type_201 = torch.ops.prims.convert_element_type.default(mul_49, torch.bfloat16); mul_49 = None + convert_element_type_202 = torch.ops.prims.convert_element_type.default(primals_59, torch.bfloat16) + all_gather_into_tensor_56 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_202, 256, '0'); 
convert_element_type_202 = None + wait_tensor_56 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_56); all_gather_into_tensor_56 = None + permute_66 = torch.ops.aten.permute.default(wait_tensor_56, [1, 0]); wait_tensor_56 = None + view_207 = torch.ops.aten.view.default(convert_element_type_201, [16384, 4096]); convert_element_type_201 = None + mm_42 = torch.ops.aten.mm.default(view_207, permute_66); permute_66 = None + view_208 = torch.ops.aten.view.default(mm_42, [2, 8192, 4096]) + convert_element_type_205 = torch.ops.prims.convert_element_type.default(primals_60, torch.bfloat16) + all_gather_into_tensor_57 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_205, 256, '0'); convert_element_type_205 = None + wait_tensor_57 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_57); all_gather_into_tensor_57 = None + permute_67 = torch.ops.aten.permute.default(wait_tensor_57, [1, 0]); wait_tensor_57 = None + mm_43 = torch.ops.aten.mm.default(view_207, permute_67); permute_67 = None + view_211 = torch.ops.aten.view.default(mm_43, [2, 8192, 1024]); mm_43 = None + convert_element_type_208 = torch.ops.prims.convert_element_type.default(primals_61, torch.bfloat16) + all_gather_into_tensor_58 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_208, 256, '0'); convert_element_type_208 = None + wait_tensor_58 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_58); all_gather_into_tensor_58 = None + permute_68 = torch.ops.aten.permute.default(wait_tensor_58, [1, 0]); wait_tensor_58 = None + mm_44 = torch.ops.aten.mm.default(view_207, permute_68); view_207 = permute_68 = None + view_214 = torch.ops.aten.view.default(mm_44, [2, 8192, 1024]) + view_215 = torch.ops.aten.view.default(view_208, [2, 8192, -1, 128]); view_208 = None + view_216 = torch.ops.aten.view.default(view_211, [2, 8192, -1, 128]); view_211 = None + view_217 = 
torch.ops.aten.view.default(view_214, [2, 8192, -1, 128]); view_214 = None + convert_element_type_211 = torch.ops.prims.convert_element_type.default(view_215, torch.float32); view_215 = None + view_218 = torch.ops.aten.view.default(convert_element_type_211, [2, 8192, 32, -1, 2]); convert_element_type_211 = None + view_as_complex_12 = torch.ops.aten.view_as_complex.default(view_218); view_218 = None + convert_element_type_212 = torch.ops.prims.convert_element_type.default(view_216, torch.float32); view_216 = None + view_219 = torch.ops.aten.view.default(convert_element_type_212, [2, 8192, 8, -1, 2]); convert_element_type_212 = None + view_as_complex_13 = torch.ops.aten.view_as_complex.default(view_219); view_219 = None + mul_50 = torch.ops.aten.mul.Tensor(view_as_complex_12, view_16); view_as_complex_12 = None + view_as_real_12 = torch.ops.aten.view_as_real.default(mul_50); mul_50 = None + view_221 = torch.ops.aten.view.default(view_as_real_12, [2, 8192, 32, 128]); view_as_real_12 = None + mul_51 = torch.ops.aten.mul.Tensor(view_as_complex_13, view_16); view_as_complex_13 = None + view_as_real_13 = torch.ops.aten.view_as_real.default(mul_51); mul_51 = None + view_222 = torch.ops.aten.view.default(view_as_real_13, [2, 8192, 8, 128]); view_as_real_13 = None + convert_element_type_213 = torch.ops.prims.convert_element_type.default(view_221, torch.bfloat16); view_221 = None + convert_element_type_214 = torch.ops.prims.convert_element_type.default(view_222, torch.bfloat16); view_222 = None + unsqueeze_12 = torch.ops.aten.unsqueeze.default(convert_element_type_214, 3); convert_element_type_214 = None + expand_12 = torch.ops.aten.expand.default(unsqueeze_12, [2, 8192, 8, 4, 128]); unsqueeze_12 = None + clone_12 = torch.ops.aten.clone.default(expand_12, memory_format = torch.contiguous_format); expand_12 = None + view_223 = torch.ops.aten.view.default(clone_12, [2, 8192, 32, 128]); clone_12 = None + unsqueeze_13 = torch.ops.aten.unsqueeze.default(view_217, 3); view_217 = 
None + expand_13 = torch.ops.aten.expand.default(unsqueeze_13, [2, 8192, 8, 4, 128]); unsqueeze_13 = None + clone_13 = torch.ops.aten.clone.default(expand_13, memory_format = torch.contiguous_format); expand_13 = None + view_224 = torch.ops.aten.view.default(clone_13, [2, 8192, 32, 128]); clone_13 = None + permute_69 = torch.ops.aten.permute.default(convert_element_type_213, [0, 2, 1, 3]); convert_element_type_213 = None + permute_70 = torch.ops.aten.permute.default(view_223, [0, 2, 1, 3]); view_223 = None + permute_71 = torch.ops.aten.permute.default(view_224, [0, 2, 1, 3]); view_224 = None + _scaled_dot_product_cudnn_attention_6 = torch.ops.aten._scaled_dot_product_cudnn_attention.default(permute_69, permute_70, permute_71, None, True, 0.0, True); permute_69 = permute_70 = permute_71 = None + getitem_54 = _scaled_dot_product_cudnn_attention_6[0] + getitem_55 = _scaled_dot_product_cudnn_attention_6[1] + getitem_60 = _scaled_dot_product_cudnn_attention_6[6] + getitem_61 = _scaled_dot_product_cudnn_attention_6[7]; _scaled_dot_product_cudnn_attention_6 = None + permute_72 = torch.ops.aten.permute.default(getitem_54, [0, 2, 1, 3]) + view_225 = torch.ops.aten.view.default(permute_72, [2, 8192, -1]); permute_72 = None + convert_element_type_215 = torch.ops.prims.convert_element_type.default(primals_62, torch.bfloat16) + all_gather_into_tensor_59 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_215, 256, '0'); convert_element_type_215 = None + wait_tensor_59 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_59); all_gather_into_tensor_59 = None + permute_73 = torch.ops.aten.permute.default(wait_tensor_59, [1, 0]); wait_tensor_59 = None + view_227 = torch.ops.aten.view.default(view_225, [16384, 4096]); view_225 = None + mm_45 = torch.ops.aten.mm.default(view_227, permute_73); view_227 = permute_73 = None + view_228 = torch.ops.aten.view.default(mm_45, [2, 8192, 4096]); mm_45 = None + add_25 = 
torch.ops.aten.add.Tensor(add_23, view_228); view_228 = None + convert_element_type_218 = torch.ops.prims.convert_element_type.default(primals_63, torch.bfloat16) + all_gather_into_tensor_60 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_218, 256, '0'); convert_element_type_218 = None + wait_tensor_60 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_60); all_gather_into_tensor_60 = None + convert_element_type_219 = torch.ops.prims.convert_element_type.default(add_25, torch.float32) + pow_14 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_219, 2) + mean_13 = torch.ops.aten.mean.dim(pow_14, [2], True); pow_14 = None + add_26 = torch.ops.aten.add.Scalar(mean_13, 1e-05); mean_13 = None + rsqrt_13 = torch.ops.aten.rsqrt.default(add_26); add_26 = None + mul_52 = torch.ops.aten.mul.Tensor(convert_element_type_219, rsqrt_13); convert_element_type_219 = rsqrt_13 = None + mul_53 = torch.ops.aten.mul.Tensor(mul_52, wait_tensor_60); mul_52 = wait_tensor_60 = None + convert_element_type_220 = torch.ops.prims.convert_element_type.default(mul_53, torch.bfloat16); mul_53 = None + convert_element_type_221 = torch.ops.prims.convert_element_type.default(primals_64, torch.bfloat16) + all_gather_into_tensor_61 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_221, 256, '0'); convert_element_type_221 = None + wait_tensor_61 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_61); all_gather_into_tensor_61 = None + permute_74 = torch.ops.aten.permute.default(wait_tensor_61, [1, 0]); wait_tensor_61 = None + view_231 = torch.ops.aten.view.default(convert_element_type_220, [16384, 4096]); convert_element_type_220 = None + mm_46 = torch.ops.aten.mm.default(view_231, permute_74); permute_74 = None + view_232 = torch.ops.aten.view.default(mm_46, [2, 8192, 14336]) + convert_element_type_224 = torch.ops.prims.convert_element_type.default(view_232, torch.float32); view_232 = None + 
sigmoid_6 = torch.ops.aten.sigmoid.default(convert_element_type_224) + mul_54 = torch.ops.aten.mul.Tensor(convert_element_type_224, sigmoid_6); convert_element_type_224 = sigmoid_6 = None + convert_element_type_225 = torch.ops.prims.convert_element_type.default(mul_54, torch.bfloat16); mul_54 = None + convert_element_type_226 = torch.ops.prims.convert_element_type.default(primals_65, torch.bfloat16) + all_gather_into_tensor_62 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_226, 256, '0'); convert_element_type_226 = None + wait_tensor_62 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_62); all_gather_into_tensor_62 = None + permute_75 = torch.ops.aten.permute.default(wait_tensor_62, [1, 0]); wait_tensor_62 = None + mm_47 = torch.ops.aten.mm.default(view_231, permute_75); view_231 = permute_75 = None + view_235 = torch.ops.aten.view.default(mm_47, [2, 8192, 14336]); mm_47 = None + mul_55 = torch.ops.aten.mul.Tensor(convert_element_type_225, view_235); convert_element_type_225 = view_235 = None + convert_element_type_229 = torch.ops.prims.convert_element_type.default(primals_66, torch.bfloat16) + all_gather_into_tensor_63 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_229, 256, '0'); convert_element_type_229 = None + wait_tensor_63 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_63); all_gather_into_tensor_63 = None + permute_76 = torch.ops.aten.permute.default(wait_tensor_63, [1, 0]); wait_tensor_63 = None + view_237 = torch.ops.aten.view.default(mul_55, [16384, 14336]); mul_55 = None + mm_48 = torch.ops.aten.mm.default(view_237, permute_76); view_237 = permute_76 = None + view_238 = torch.ops.aten.view.default(mm_48, [2, 8192, 4096]); mm_48 = None + add_27 = torch.ops.aten.add.Tensor(add_25, view_238); add_25 = view_238 = None + convert_element_type_232 = torch.ops.prims.convert_element_type.default(primals_67, torch.bfloat16) + 
all_gather_into_tensor_64 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_232, 256, '0'); convert_element_type_232 = None + wait_tensor_64 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_64); all_gather_into_tensor_64 = None + convert_element_type_233 = torch.ops.prims.convert_element_type.default(add_27, torch.float32) + pow_15 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_233, 2) + mean_14 = torch.ops.aten.mean.dim(pow_15, [2], True); pow_15 = None + add_28 = torch.ops.aten.add.Scalar(mean_14, 1e-05); mean_14 = None + rsqrt_14 = torch.ops.aten.rsqrt.default(add_28); add_28 = None + mul_56 = torch.ops.aten.mul.Tensor(convert_element_type_233, rsqrt_14); convert_element_type_233 = rsqrt_14 = None + mul_57 = torch.ops.aten.mul.Tensor(mul_56, wait_tensor_64); mul_56 = wait_tensor_64 = None + convert_element_type_234 = torch.ops.prims.convert_element_type.default(mul_57, torch.bfloat16); mul_57 = None + convert_element_type_235 = torch.ops.prims.convert_element_type.default(primals_68, torch.bfloat16) + all_gather_into_tensor_65 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_235, 256, '0'); convert_element_type_235 = None + wait_tensor_65 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_65); all_gather_into_tensor_65 = None + permute_77 = torch.ops.aten.permute.default(wait_tensor_65, [1, 0]); wait_tensor_65 = None + view_241 = torch.ops.aten.view.default(convert_element_type_234, [16384, 4096]); convert_element_type_234 = None + mm_49 = torch.ops.aten.mm.default(view_241, permute_77); permute_77 = None + view_242 = torch.ops.aten.view.default(mm_49, [2, 8192, 4096]) + convert_element_type_238 = torch.ops.prims.convert_element_type.default(primals_69, torch.bfloat16) + all_gather_into_tensor_66 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_238, 256, '0'); convert_element_type_238 = None + wait_tensor_66 = 
torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_66); all_gather_into_tensor_66 = None + permute_78 = torch.ops.aten.permute.default(wait_tensor_66, [1, 0]); wait_tensor_66 = None + mm_50 = torch.ops.aten.mm.default(view_241, permute_78); permute_78 = None + view_245 = torch.ops.aten.view.default(mm_50, [2, 8192, 1024]); mm_50 = None + convert_element_type_241 = torch.ops.prims.convert_element_type.default(primals_70, torch.bfloat16) + all_gather_into_tensor_67 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_241, 256, '0'); convert_element_type_241 = None + wait_tensor_67 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_67); all_gather_into_tensor_67 = None + permute_79 = torch.ops.aten.permute.default(wait_tensor_67, [1, 0]); wait_tensor_67 = None + mm_51 = torch.ops.aten.mm.default(view_241, permute_79); view_241 = permute_79 = None + view_248 = torch.ops.aten.view.default(mm_51, [2, 8192, 1024]) + view_249 = torch.ops.aten.view.default(view_242, [2, 8192, -1, 128]); view_242 = None + view_250 = torch.ops.aten.view.default(view_245, [2, 8192, -1, 128]); view_245 = None + view_251 = torch.ops.aten.view.default(view_248, [2, 8192, -1, 128]); view_248 = None + convert_element_type_244 = torch.ops.prims.convert_element_type.default(view_249, torch.float32); view_249 = None + view_252 = torch.ops.aten.view.default(convert_element_type_244, [2, 8192, 32, -1, 2]); convert_element_type_244 = None + view_as_complex_14 = torch.ops.aten.view_as_complex.default(view_252); view_252 = None + convert_element_type_245 = torch.ops.prims.convert_element_type.default(view_250, torch.float32); view_250 = None + view_253 = torch.ops.aten.view.default(convert_element_type_245, [2, 8192, 8, -1, 2]); convert_element_type_245 = None + view_as_complex_15 = torch.ops.aten.view_as_complex.default(view_253); view_253 = None + mul_58 = torch.ops.aten.mul.Tensor(view_as_complex_14, view_16); view_as_complex_14 = None 
+ view_as_real_14 = torch.ops.aten.view_as_real.default(mul_58); mul_58 = None + view_255 = torch.ops.aten.view.default(view_as_real_14, [2, 8192, 32, 128]); view_as_real_14 = None + mul_59 = torch.ops.aten.mul.Tensor(view_as_complex_15, view_16); view_as_complex_15 = None + view_as_real_15 = torch.ops.aten.view_as_real.default(mul_59); mul_59 = None + view_256 = torch.ops.aten.view.default(view_as_real_15, [2, 8192, 8, 128]); view_as_real_15 = None + convert_element_type_246 = torch.ops.prims.convert_element_type.default(view_255, torch.bfloat16); view_255 = None + convert_element_type_247 = torch.ops.prims.convert_element_type.default(view_256, torch.bfloat16); view_256 = None + unsqueeze_14 = torch.ops.aten.unsqueeze.default(convert_element_type_247, 3); convert_element_type_247 = None + expand_14 = torch.ops.aten.expand.default(unsqueeze_14, [2, 8192, 8, 4, 128]); unsqueeze_14 = None + clone_14 = torch.ops.aten.clone.default(expand_14, memory_format = torch.contiguous_format); expand_14 = None + view_257 = torch.ops.aten.view.default(clone_14, [2, 8192, 32, 128]); clone_14 = None + unsqueeze_15 = torch.ops.aten.unsqueeze.default(view_251, 3); view_251 = None + expand_15 = torch.ops.aten.expand.default(unsqueeze_15, [2, 8192, 8, 4, 128]); unsqueeze_15 = None + clone_15 = torch.ops.aten.clone.default(expand_15, memory_format = torch.contiguous_format); expand_15 = None + view_258 = torch.ops.aten.view.default(clone_15, [2, 8192, 32, 128]); clone_15 = None + permute_80 = torch.ops.aten.permute.default(convert_element_type_246, [0, 2, 1, 3]); convert_element_type_246 = None + permute_81 = torch.ops.aten.permute.default(view_257, [0, 2, 1, 3]); view_257 = None + permute_82 = torch.ops.aten.permute.default(view_258, [0, 2, 1, 3]); view_258 = None + _scaled_dot_product_cudnn_attention_7 = torch.ops.aten._scaled_dot_product_cudnn_attention.default(permute_80, permute_81, permute_82, None, True, 0.0, True); permute_80 = permute_81 = permute_82 = None + getitem_63 = 
_scaled_dot_product_cudnn_attention_7[0] + getitem_64 = _scaled_dot_product_cudnn_attention_7[1] + getitem_69 = _scaled_dot_product_cudnn_attention_7[6] + getitem_70 = _scaled_dot_product_cudnn_attention_7[7]; _scaled_dot_product_cudnn_attention_7 = None + permute_83 = torch.ops.aten.permute.default(getitem_63, [0, 2, 1, 3]) + view_259 = torch.ops.aten.view.default(permute_83, [2, 8192, -1]); permute_83 = None + convert_element_type_248 = torch.ops.prims.convert_element_type.default(primals_71, torch.bfloat16) + all_gather_into_tensor_68 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_248, 256, '0'); convert_element_type_248 = None + wait_tensor_68 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_68); all_gather_into_tensor_68 = None + permute_84 = torch.ops.aten.permute.default(wait_tensor_68, [1, 0]); wait_tensor_68 = None + view_261 = torch.ops.aten.view.default(view_259, [16384, 4096]); view_259 = None + mm_52 = torch.ops.aten.mm.default(view_261, permute_84); view_261 = permute_84 = None + view_262 = torch.ops.aten.view.default(mm_52, [2, 8192, 4096]); mm_52 = None + add_29 = torch.ops.aten.add.Tensor(add_27, view_262); view_262 = None + convert_element_type_251 = torch.ops.prims.convert_element_type.default(primals_72, torch.bfloat16) + all_gather_into_tensor_69 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_251, 256, '0'); convert_element_type_251 = None + wait_tensor_69 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_69); all_gather_into_tensor_69 = None + convert_element_type_252 = torch.ops.prims.convert_element_type.default(add_29, torch.float32) + pow_16 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_252, 2) + mean_15 = torch.ops.aten.mean.dim(pow_16, [2], True); pow_16 = None + add_30 = torch.ops.aten.add.Scalar(mean_15, 1e-05); mean_15 = None + rsqrt_15 = torch.ops.aten.rsqrt.default(add_30); add_30 = None + mul_60 = 
torch.ops.aten.mul.Tensor(convert_element_type_252, rsqrt_15); convert_element_type_252 = rsqrt_15 = None + mul_61 = torch.ops.aten.mul.Tensor(mul_60, wait_tensor_69); mul_60 = wait_tensor_69 = None + convert_element_type_253 = torch.ops.prims.convert_element_type.default(mul_61, torch.bfloat16); mul_61 = None + convert_element_type_254 = torch.ops.prims.convert_element_type.default(primals_73, torch.bfloat16) + all_gather_into_tensor_70 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_254, 256, '0'); convert_element_type_254 = None + wait_tensor_70 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_70); all_gather_into_tensor_70 = None + permute_85 = torch.ops.aten.permute.default(wait_tensor_70, [1, 0]); wait_tensor_70 = None + view_265 = torch.ops.aten.view.default(convert_element_type_253, [16384, 4096]); convert_element_type_253 = None + mm_53 = torch.ops.aten.mm.default(view_265, permute_85); permute_85 = None + view_266 = torch.ops.aten.view.default(mm_53, [2, 8192, 14336]) + convert_element_type_257 = torch.ops.prims.convert_element_type.default(view_266, torch.float32); view_266 = None + sigmoid_7 = torch.ops.aten.sigmoid.default(convert_element_type_257) + mul_62 = torch.ops.aten.mul.Tensor(convert_element_type_257, sigmoid_7); convert_element_type_257 = sigmoid_7 = None + convert_element_type_258 = torch.ops.prims.convert_element_type.default(mul_62, torch.bfloat16); mul_62 = None + convert_element_type_259 = torch.ops.prims.convert_element_type.default(primals_74, torch.bfloat16) + all_gather_into_tensor_71 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_259, 256, '0'); convert_element_type_259 = None + wait_tensor_71 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_71); all_gather_into_tensor_71 = None + permute_86 = torch.ops.aten.permute.default(wait_tensor_71, [1, 0]); wait_tensor_71 = None + mm_54 = torch.ops.aten.mm.default(view_265, 
permute_86); view_265 = permute_86 = None + view_269 = torch.ops.aten.view.default(mm_54, [2, 8192, 14336]); mm_54 = None + mul_63 = torch.ops.aten.mul.Tensor(convert_element_type_258, view_269); convert_element_type_258 = view_269 = None + convert_element_type_262 = torch.ops.prims.convert_element_type.default(primals_75, torch.bfloat16) + all_gather_into_tensor_72 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_262, 256, '0'); convert_element_type_262 = None + wait_tensor_72 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_72); all_gather_into_tensor_72 = None + permute_87 = torch.ops.aten.permute.default(wait_tensor_72, [1, 0]); wait_tensor_72 = None + view_271 = torch.ops.aten.view.default(mul_63, [16384, 14336]); mul_63 = None + mm_55 = torch.ops.aten.mm.default(view_271, permute_87); view_271 = permute_87 = None + view_272 = torch.ops.aten.view.default(mm_55, [2, 8192, 4096]); mm_55 = None + add_31 = torch.ops.aten.add.Tensor(add_29, view_272); add_29 = view_272 = None + convert_element_type_265 = torch.ops.prims.convert_element_type.default(primals_76, torch.bfloat16) + all_gather_into_tensor_73 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_265, 256, '0'); convert_element_type_265 = None + wait_tensor_73 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_73); all_gather_into_tensor_73 = None + convert_element_type_266 = torch.ops.prims.convert_element_type.default(add_31, torch.float32) + pow_17 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_266, 2) + mean_16 = torch.ops.aten.mean.dim(pow_17, [2], True); pow_17 = None + add_32 = torch.ops.aten.add.Scalar(mean_16, 1e-05); mean_16 = None + rsqrt_16 = torch.ops.aten.rsqrt.default(add_32); add_32 = None + mul_64 = torch.ops.aten.mul.Tensor(convert_element_type_266, rsqrt_16); convert_element_type_266 = rsqrt_16 = None + mul_65 = torch.ops.aten.mul.Tensor(mul_64, wait_tensor_73); mul_64 = 
wait_tensor_73 = None + convert_element_type_267 = torch.ops.prims.convert_element_type.default(mul_65, torch.bfloat16); mul_65 = None + convert_element_type_268 = torch.ops.prims.convert_element_type.default(primals_77, torch.bfloat16) + all_gather_into_tensor_74 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_268, 256, '0'); convert_element_type_268 = None + wait_tensor_74 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_74); all_gather_into_tensor_74 = None + permute_88 = torch.ops.aten.permute.default(wait_tensor_74, [1, 0]); wait_tensor_74 = None + view_275 = torch.ops.aten.view.default(convert_element_type_267, [16384, 4096]); convert_element_type_267 = None + mm_56 = torch.ops.aten.mm.default(view_275, permute_88); permute_88 = None + view_276 = torch.ops.aten.view.default(mm_56, [2, 8192, 4096]) + convert_element_type_271 = torch.ops.prims.convert_element_type.default(primals_78, torch.bfloat16) + all_gather_into_tensor_75 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_271, 256, '0'); convert_element_type_271 = None + wait_tensor_75 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_75); all_gather_into_tensor_75 = None + permute_89 = torch.ops.aten.permute.default(wait_tensor_75, [1, 0]); wait_tensor_75 = None + mm_57 = torch.ops.aten.mm.default(view_275, permute_89); permute_89 = None + view_279 = torch.ops.aten.view.default(mm_57, [2, 8192, 1024]); mm_57 = None + convert_element_type_274 = torch.ops.prims.convert_element_type.default(primals_79, torch.bfloat16) + all_gather_into_tensor_76 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_274, 256, '0'); convert_element_type_274 = None + wait_tensor_76 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_76); all_gather_into_tensor_76 = None + permute_90 = torch.ops.aten.permute.default(wait_tensor_76, [1, 0]); wait_tensor_76 = None + mm_58 = 
torch.ops.aten.mm.default(view_275, permute_90); view_275 = permute_90 = None + view_282 = torch.ops.aten.view.default(mm_58, [2, 8192, 1024]) + view_283 = torch.ops.aten.view.default(view_276, [2, 8192, -1, 128]); view_276 = None + view_284 = torch.ops.aten.view.default(view_279, [2, 8192, -1, 128]); view_279 = None + view_285 = torch.ops.aten.view.default(view_282, [2, 8192, -1, 128]); view_282 = None + convert_element_type_277 = torch.ops.prims.convert_element_type.default(view_283, torch.float32); view_283 = None + view_286 = torch.ops.aten.view.default(convert_element_type_277, [2, 8192, 32, -1, 2]); convert_element_type_277 = None + view_as_complex_16 = torch.ops.aten.view_as_complex.default(view_286); view_286 = None + convert_element_type_278 = torch.ops.prims.convert_element_type.default(view_284, torch.float32); view_284 = None + view_287 = torch.ops.aten.view.default(convert_element_type_278, [2, 8192, 8, -1, 2]); convert_element_type_278 = None + view_as_complex_17 = torch.ops.aten.view_as_complex.default(view_287); view_287 = None + mul_66 = torch.ops.aten.mul.Tensor(view_as_complex_16, view_16); view_as_complex_16 = None + view_as_real_16 = torch.ops.aten.view_as_real.default(mul_66); mul_66 = None + view_289 = torch.ops.aten.view.default(view_as_real_16, [2, 8192, 32, 128]); view_as_real_16 = None + mul_67 = torch.ops.aten.mul.Tensor(view_as_complex_17, view_16); view_as_complex_17 = None + view_as_real_17 = torch.ops.aten.view_as_real.default(mul_67); mul_67 = None + view_290 = torch.ops.aten.view.default(view_as_real_17, [2, 8192, 8, 128]); view_as_real_17 = None + convert_element_type_279 = torch.ops.prims.convert_element_type.default(view_289, torch.bfloat16); view_289 = None + convert_element_type_280 = torch.ops.prims.convert_element_type.default(view_290, torch.bfloat16); view_290 = None + unsqueeze_16 = torch.ops.aten.unsqueeze.default(convert_element_type_280, 3); convert_element_type_280 = None + expand_16 = 
torch.ops.aten.expand.default(unsqueeze_16, [2, 8192, 8, 4, 128]); unsqueeze_16 = None + clone_16 = torch.ops.aten.clone.default(expand_16, memory_format = torch.contiguous_format); expand_16 = None + view_291 = torch.ops.aten.view.default(clone_16, [2, 8192, 32, 128]); clone_16 = None + unsqueeze_17 = torch.ops.aten.unsqueeze.default(view_285, 3); view_285 = None + expand_17 = torch.ops.aten.expand.default(unsqueeze_17, [2, 8192, 8, 4, 128]); unsqueeze_17 = None + clone_17 = torch.ops.aten.clone.default(expand_17, memory_format = torch.contiguous_format); expand_17 = None + view_292 = torch.ops.aten.view.default(clone_17, [2, 8192, 32, 128]); clone_17 = None + permute_91 = torch.ops.aten.permute.default(convert_element_type_279, [0, 2, 1, 3]); convert_element_type_279 = None + permute_92 = torch.ops.aten.permute.default(view_291, [0, 2, 1, 3]); view_291 = None + permute_93 = torch.ops.aten.permute.default(view_292, [0, 2, 1, 3]); view_292 = None + _scaled_dot_product_cudnn_attention_8 = torch.ops.aten._scaled_dot_product_cudnn_attention.default(permute_91, permute_92, permute_93, None, True, 0.0, True); permute_91 = permute_92 = permute_93 = None + getitem_72 = _scaled_dot_product_cudnn_attention_8[0] + getitem_73 = _scaled_dot_product_cudnn_attention_8[1] + getitem_78 = _scaled_dot_product_cudnn_attention_8[6] + getitem_79 = _scaled_dot_product_cudnn_attention_8[7]; _scaled_dot_product_cudnn_attention_8 = None + permute_94 = torch.ops.aten.permute.default(getitem_72, [0, 2, 1, 3]) + view_293 = torch.ops.aten.view.default(permute_94, [2, 8192, -1]); permute_94 = None + convert_element_type_281 = torch.ops.prims.convert_element_type.default(primals_80, torch.bfloat16) + all_gather_into_tensor_77 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_281, 256, '0'); convert_element_type_281 = None + wait_tensor_77 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_77); all_gather_into_tensor_77 = None + permute_95 = 
torch.ops.aten.permute.default(wait_tensor_77, [1, 0]); wait_tensor_77 = None + view_295 = torch.ops.aten.view.default(view_293, [16384, 4096]); view_293 = None + mm_59 = torch.ops.aten.mm.default(view_295, permute_95); view_295 = permute_95 = None + view_296 = torch.ops.aten.view.default(mm_59, [2, 8192, 4096]); mm_59 = None + add_33 = torch.ops.aten.add.Tensor(add_31, view_296); view_296 = None + convert_element_type_284 = torch.ops.prims.convert_element_type.default(primals_81, torch.bfloat16) + all_gather_into_tensor_78 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_284, 256, '0'); convert_element_type_284 = None + wait_tensor_78 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_78); all_gather_into_tensor_78 = None + convert_element_type_285 = torch.ops.prims.convert_element_type.default(add_33, torch.float32) + pow_18 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_285, 2) + mean_17 = torch.ops.aten.mean.dim(pow_18, [2], True); pow_18 = None + add_34 = torch.ops.aten.add.Scalar(mean_17, 1e-05); mean_17 = None + rsqrt_17 = torch.ops.aten.rsqrt.default(add_34); add_34 = None + mul_68 = torch.ops.aten.mul.Tensor(convert_element_type_285, rsqrt_17); convert_element_type_285 = rsqrt_17 = None + mul_69 = torch.ops.aten.mul.Tensor(mul_68, wait_tensor_78); mul_68 = wait_tensor_78 = None + convert_element_type_286 = torch.ops.prims.convert_element_type.default(mul_69, torch.bfloat16); mul_69 = None + convert_element_type_287 = torch.ops.prims.convert_element_type.default(primals_82, torch.bfloat16) + all_gather_into_tensor_79 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_287, 256, '0'); convert_element_type_287 = None + wait_tensor_79 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_79); all_gather_into_tensor_79 = None + permute_96 = torch.ops.aten.permute.default(wait_tensor_79, [1, 0]); wait_tensor_79 = None + view_299 = 
torch.ops.aten.view.default(convert_element_type_286, [16384, 4096]); convert_element_type_286 = None + mm_60 = torch.ops.aten.mm.default(view_299, permute_96); permute_96 = None + view_300 = torch.ops.aten.view.default(mm_60, [2, 8192, 14336]) + convert_element_type_290 = torch.ops.prims.convert_element_type.default(view_300, torch.float32); view_300 = None + sigmoid_8 = torch.ops.aten.sigmoid.default(convert_element_type_290) + mul_70 = torch.ops.aten.mul.Tensor(convert_element_type_290, sigmoid_8); convert_element_type_290 = sigmoid_8 = None + convert_element_type_291 = torch.ops.prims.convert_element_type.default(mul_70, torch.bfloat16); mul_70 = None + convert_element_type_292 = torch.ops.prims.convert_element_type.default(primals_83, torch.bfloat16) + all_gather_into_tensor_80 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_292, 256, '0'); convert_element_type_292 = None + wait_tensor_80 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_80); all_gather_into_tensor_80 = None + permute_97 = torch.ops.aten.permute.default(wait_tensor_80, [1, 0]); wait_tensor_80 = None + mm_61 = torch.ops.aten.mm.default(view_299, permute_97); view_299 = permute_97 = None + view_303 = torch.ops.aten.view.default(mm_61, [2, 8192, 14336]); mm_61 = None + mul_71 = torch.ops.aten.mul.Tensor(convert_element_type_291, view_303); convert_element_type_291 = view_303 = None + convert_element_type_295 = torch.ops.prims.convert_element_type.default(primals_84, torch.bfloat16) + all_gather_into_tensor_81 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_295, 256, '0'); convert_element_type_295 = None + wait_tensor_81 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_81); all_gather_into_tensor_81 = None + permute_98 = torch.ops.aten.permute.default(wait_tensor_81, [1, 0]); wait_tensor_81 = None + view_305 = torch.ops.aten.view.default(mul_71, [16384, 14336]); mul_71 = None + mm_62 = 
torch.ops.aten.mm.default(view_305, permute_98); view_305 = permute_98 = None + view_306 = torch.ops.aten.view.default(mm_62, [2, 8192, 4096]); mm_62 = None + add_35 = torch.ops.aten.add.Tensor(add_33, view_306); add_33 = view_306 = None + convert_element_type_298 = torch.ops.prims.convert_element_type.default(primals_85, torch.bfloat16) + all_gather_into_tensor_82 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_298, 256, '0'); convert_element_type_298 = None + wait_tensor_82 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_82); all_gather_into_tensor_82 = None + convert_element_type_299 = torch.ops.prims.convert_element_type.default(add_35, torch.float32) + pow_19 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_299, 2) + mean_18 = torch.ops.aten.mean.dim(pow_19, [2], True); pow_19 = None + add_36 = torch.ops.aten.add.Scalar(mean_18, 1e-05); mean_18 = None + rsqrt_18 = torch.ops.aten.rsqrt.default(add_36); add_36 = None + mul_72 = torch.ops.aten.mul.Tensor(convert_element_type_299, rsqrt_18); convert_element_type_299 = rsqrt_18 = None + mul_73 = torch.ops.aten.mul.Tensor(mul_72, wait_tensor_82); mul_72 = wait_tensor_82 = None + convert_element_type_300 = torch.ops.prims.convert_element_type.default(mul_73, torch.bfloat16); mul_73 = None + convert_element_type_301 = torch.ops.prims.convert_element_type.default(primals_86, torch.bfloat16) + all_gather_into_tensor_83 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_301, 256, '0'); convert_element_type_301 = None + wait_tensor_83 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_83); all_gather_into_tensor_83 = None + permute_99 = torch.ops.aten.permute.default(wait_tensor_83, [1, 0]); wait_tensor_83 = None + view_309 = torch.ops.aten.view.default(convert_element_type_300, [16384, 4096]); convert_element_type_300 = None + mm_63 = torch.ops.aten.mm.default(view_309, permute_99); permute_99 = None + 
view_310 = torch.ops.aten.view.default(mm_63, [2, 8192, 4096]) + convert_element_type_304 = torch.ops.prims.convert_element_type.default(primals_87, torch.bfloat16) + all_gather_into_tensor_84 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_304, 256, '0'); convert_element_type_304 = None + wait_tensor_84 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_84); all_gather_into_tensor_84 = None + permute_100 = torch.ops.aten.permute.default(wait_tensor_84, [1, 0]); wait_tensor_84 = None + mm_64 = torch.ops.aten.mm.default(view_309, permute_100); permute_100 = None + view_313 = torch.ops.aten.view.default(mm_64, [2, 8192, 1024]); mm_64 = None + convert_element_type_307 = torch.ops.prims.convert_element_type.default(primals_88, torch.bfloat16) + all_gather_into_tensor_85 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_307, 256, '0'); convert_element_type_307 = None + wait_tensor_85 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_85); all_gather_into_tensor_85 = None + permute_101 = torch.ops.aten.permute.default(wait_tensor_85, [1, 0]); wait_tensor_85 = None + mm_65 = torch.ops.aten.mm.default(view_309, permute_101); view_309 = permute_101 = None + view_316 = torch.ops.aten.view.default(mm_65, [2, 8192, 1024]) + view_317 = torch.ops.aten.view.default(view_310, [2, 8192, -1, 128]); view_310 = None + view_318 = torch.ops.aten.view.default(view_313, [2, 8192, -1, 128]); view_313 = None + view_319 = torch.ops.aten.view.default(view_316, [2, 8192, -1, 128]); view_316 = None + convert_element_type_310 = torch.ops.prims.convert_element_type.default(view_317, torch.float32); view_317 = None + view_320 = torch.ops.aten.view.default(convert_element_type_310, [2, 8192, 32, -1, 2]); convert_element_type_310 = None + view_as_complex_18 = torch.ops.aten.view_as_complex.default(view_320); view_320 = None + convert_element_type_311 = 
torch.ops.prims.convert_element_type.default(view_318, torch.float32); view_318 = None + view_321 = torch.ops.aten.view.default(convert_element_type_311, [2, 8192, 8, -1, 2]); convert_element_type_311 = None + view_as_complex_19 = torch.ops.aten.view_as_complex.default(view_321); view_321 = None + mul_74 = torch.ops.aten.mul.Tensor(view_as_complex_18, view_16); view_as_complex_18 = None + view_as_real_18 = torch.ops.aten.view_as_real.default(mul_74); mul_74 = None + view_323 = torch.ops.aten.view.default(view_as_real_18, [2, 8192, 32, 128]); view_as_real_18 = None + mul_75 = torch.ops.aten.mul.Tensor(view_as_complex_19, view_16); view_as_complex_19 = None + view_as_real_19 = torch.ops.aten.view_as_real.default(mul_75); mul_75 = None + view_324 = torch.ops.aten.view.default(view_as_real_19, [2, 8192, 8, 128]); view_as_real_19 = None + convert_element_type_312 = torch.ops.prims.convert_element_type.default(view_323, torch.bfloat16); view_323 = None + convert_element_type_313 = torch.ops.prims.convert_element_type.default(view_324, torch.bfloat16); view_324 = None + unsqueeze_18 = torch.ops.aten.unsqueeze.default(convert_element_type_313, 3); convert_element_type_313 = None + expand_18 = torch.ops.aten.expand.default(unsqueeze_18, [2, 8192, 8, 4, 128]); unsqueeze_18 = None + clone_18 = torch.ops.aten.clone.default(expand_18, memory_format = torch.contiguous_format); expand_18 = None + view_325 = torch.ops.aten.view.default(clone_18, [2, 8192, 32, 128]); clone_18 = None + unsqueeze_19 = torch.ops.aten.unsqueeze.default(view_319, 3); view_319 = None + expand_19 = torch.ops.aten.expand.default(unsqueeze_19, [2, 8192, 8, 4, 128]); unsqueeze_19 = None + clone_19 = torch.ops.aten.clone.default(expand_19, memory_format = torch.contiguous_format); expand_19 = None + view_326 = torch.ops.aten.view.default(clone_19, [2, 8192, 32, 128]); clone_19 = None + permute_102 = torch.ops.aten.permute.default(convert_element_type_312, [0, 2, 1, 3]); convert_element_type_312 = None + 
permute_103 = torch.ops.aten.permute.default(view_325, [0, 2, 1, 3]); view_325 = None + permute_104 = torch.ops.aten.permute.default(view_326, [0, 2, 1, 3]); view_326 = None + _scaled_dot_product_cudnn_attention_9 = torch.ops.aten._scaled_dot_product_cudnn_attention.default(permute_102, permute_103, permute_104, None, True, 0.0, True); permute_102 = permute_103 = permute_104 = None + getitem_81 = _scaled_dot_product_cudnn_attention_9[0] + getitem_82 = _scaled_dot_product_cudnn_attention_9[1] + getitem_87 = _scaled_dot_product_cudnn_attention_9[6] + getitem_88 = _scaled_dot_product_cudnn_attention_9[7]; _scaled_dot_product_cudnn_attention_9 = None + permute_105 = torch.ops.aten.permute.default(getitem_81, [0, 2, 1, 3]) + view_327 = torch.ops.aten.view.default(permute_105, [2, 8192, -1]); permute_105 = None + convert_element_type_314 = torch.ops.prims.convert_element_type.default(primals_89, torch.bfloat16) + all_gather_into_tensor_86 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_314, 256, '0'); convert_element_type_314 = None + wait_tensor_86 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_86); all_gather_into_tensor_86 = None + permute_106 = torch.ops.aten.permute.default(wait_tensor_86, [1, 0]); wait_tensor_86 = None + view_329 = torch.ops.aten.view.default(view_327, [16384, 4096]); view_327 = None + mm_66 = torch.ops.aten.mm.default(view_329, permute_106); view_329 = permute_106 = None + view_330 = torch.ops.aten.view.default(mm_66, [2, 8192, 4096]); mm_66 = None + add_37 = torch.ops.aten.add.Tensor(add_35, view_330); view_330 = None + convert_element_type_317 = torch.ops.prims.convert_element_type.default(primals_90, torch.bfloat16) + all_gather_into_tensor_87 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_317, 256, '0'); convert_element_type_317 = None + wait_tensor_87 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_87); all_gather_into_tensor_87 
= None + convert_element_type_318 = torch.ops.prims.convert_element_type.default(add_37, torch.float32) + pow_20 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_318, 2) + mean_19 = torch.ops.aten.mean.dim(pow_20, [2], True); pow_20 = None + add_38 = torch.ops.aten.add.Scalar(mean_19, 1e-05); mean_19 = None + rsqrt_19 = torch.ops.aten.rsqrt.default(add_38); add_38 = None + mul_76 = torch.ops.aten.mul.Tensor(convert_element_type_318, rsqrt_19); convert_element_type_318 = rsqrt_19 = None + mul_77 = torch.ops.aten.mul.Tensor(mul_76, wait_tensor_87); mul_76 = wait_tensor_87 = None + convert_element_type_319 = torch.ops.prims.convert_element_type.default(mul_77, torch.bfloat16); mul_77 = None + convert_element_type_320 = torch.ops.prims.convert_element_type.default(primals_91, torch.bfloat16) + all_gather_into_tensor_88 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_320, 256, '0'); convert_element_type_320 = None + wait_tensor_88 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_88); all_gather_into_tensor_88 = None + permute_107 = torch.ops.aten.permute.default(wait_tensor_88, [1, 0]); wait_tensor_88 = None + view_333 = torch.ops.aten.view.default(convert_element_type_319, [16384, 4096]); convert_element_type_319 = None + mm_67 = torch.ops.aten.mm.default(view_333, permute_107); permute_107 = None + view_334 = torch.ops.aten.view.default(mm_67, [2, 8192, 14336]) + convert_element_type_323 = torch.ops.prims.convert_element_type.default(view_334, torch.float32); view_334 = None + sigmoid_9 = torch.ops.aten.sigmoid.default(convert_element_type_323) + mul_78 = torch.ops.aten.mul.Tensor(convert_element_type_323, sigmoid_9); convert_element_type_323 = sigmoid_9 = None + convert_element_type_324 = torch.ops.prims.convert_element_type.default(mul_78, torch.bfloat16); mul_78 = None + convert_element_type_325 = torch.ops.prims.convert_element_type.default(primals_92, torch.bfloat16) + all_gather_into_tensor_89 = 
torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_325, 256, '0'); convert_element_type_325 = None + wait_tensor_89 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_89); all_gather_into_tensor_89 = None + permute_108 = torch.ops.aten.permute.default(wait_tensor_89, [1, 0]); wait_tensor_89 = None + mm_68 = torch.ops.aten.mm.default(view_333, permute_108); view_333 = permute_108 = None + view_337 = torch.ops.aten.view.default(mm_68, [2, 8192, 14336]); mm_68 = None + mul_79 = torch.ops.aten.mul.Tensor(convert_element_type_324, view_337); convert_element_type_324 = view_337 = None + convert_element_type_328 = torch.ops.prims.convert_element_type.default(primals_93, torch.bfloat16) + all_gather_into_tensor_90 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_328, 256, '0'); convert_element_type_328 = None + wait_tensor_90 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_90); all_gather_into_tensor_90 = None + permute_109 = torch.ops.aten.permute.default(wait_tensor_90, [1, 0]); wait_tensor_90 = None + view_339 = torch.ops.aten.view.default(mul_79, [16384, 14336]); mul_79 = None + mm_69 = torch.ops.aten.mm.default(view_339, permute_109); view_339 = permute_109 = None + view_340 = torch.ops.aten.view.default(mm_69, [2, 8192, 4096]); mm_69 = None + add_39 = torch.ops.aten.add.Tensor(add_37, view_340); add_37 = view_340 = None + convert_element_type_331 = torch.ops.prims.convert_element_type.default(primals_94, torch.bfloat16) + all_gather_into_tensor_91 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_331, 256, '0'); convert_element_type_331 = None + wait_tensor_91 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_91); all_gather_into_tensor_91 = None + convert_element_type_332 = torch.ops.prims.convert_element_type.default(add_39, torch.float32) + pow_21 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_332, 
2) + mean_20 = torch.ops.aten.mean.dim(pow_21, [2], True); pow_21 = None + add_40 = torch.ops.aten.add.Scalar(mean_20, 1e-05); mean_20 = None + rsqrt_20 = torch.ops.aten.rsqrt.default(add_40); add_40 = None + mul_80 = torch.ops.aten.mul.Tensor(convert_element_type_332, rsqrt_20); convert_element_type_332 = rsqrt_20 = None + mul_81 = torch.ops.aten.mul.Tensor(mul_80, wait_tensor_91); mul_80 = wait_tensor_91 = None + convert_element_type_333 = torch.ops.prims.convert_element_type.default(mul_81, torch.bfloat16); mul_81 = None + convert_element_type_334 = torch.ops.prims.convert_element_type.default(primals_95, torch.bfloat16) + all_gather_into_tensor_92 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_334, 256, '0'); convert_element_type_334 = None + wait_tensor_92 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_92); all_gather_into_tensor_92 = None + permute_110 = torch.ops.aten.permute.default(wait_tensor_92, [1, 0]); wait_tensor_92 = None + view_343 = torch.ops.aten.view.default(convert_element_type_333, [16384, 4096]); convert_element_type_333 = None + mm_70 = torch.ops.aten.mm.default(view_343, permute_110); permute_110 = None + view_344 = torch.ops.aten.view.default(mm_70, [2, 8192, 4096]) + convert_element_type_337 = torch.ops.prims.convert_element_type.default(primals_96, torch.bfloat16) + all_gather_into_tensor_93 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_337, 256, '0'); convert_element_type_337 = None + wait_tensor_93 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_93); all_gather_into_tensor_93 = None + permute_111 = torch.ops.aten.permute.default(wait_tensor_93, [1, 0]); wait_tensor_93 = None + mm_71 = torch.ops.aten.mm.default(view_343, permute_111); permute_111 = None + view_347 = torch.ops.aten.view.default(mm_71, [2, 8192, 1024]); mm_71 = None + convert_element_type_340 = torch.ops.prims.convert_element_type.default(primals_97, 
torch.bfloat16) + all_gather_into_tensor_94 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_340, 256, '0'); convert_element_type_340 = None + wait_tensor_94 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_94); all_gather_into_tensor_94 = None + permute_112 = torch.ops.aten.permute.default(wait_tensor_94, [1, 0]); wait_tensor_94 = None + mm_72 = torch.ops.aten.mm.default(view_343, permute_112); view_343 = permute_112 = None + view_350 = torch.ops.aten.view.default(mm_72, [2, 8192, 1024]) + view_351 = torch.ops.aten.view.default(view_344, [2, 8192, -1, 128]); view_344 = None + view_352 = torch.ops.aten.view.default(view_347, [2, 8192, -1, 128]); view_347 = None + view_353 = torch.ops.aten.view.default(view_350, [2, 8192, -1, 128]); view_350 = None + convert_element_type_343 = torch.ops.prims.convert_element_type.default(view_351, torch.float32); view_351 = None + view_354 = torch.ops.aten.view.default(convert_element_type_343, [2, 8192, 32, -1, 2]); convert_element_type_343 = None + view_as_complex_20 = torch.ops.aten.view_as_complex.default(view_354); view_354 = None + convert_element_type_344 = torch.ops.prims.convert_element_type.default(view_352, torch.float32); view_352 = None + view_355 = torch.ops.aten.view.default(convert_element_type_344, [2, 8192, 8, -1, 2]); convert_element_type_344 = None + view_as_complex_21 = torch.ops.aten.view_as_complex.default(view_355); view_355 = None + mul_82 = torch.ops.aten.mul.Tensor(view_as_complex_20, view_16); view_as_complex_20 = None + view_as_real_20 = torch.ops.aten.view_as_real.default(mul_82); mul_82 = None + view_357 = torch.ops.aten.view.default(view_as_real_20, [2, 8192, 32, 128]); view_as_real_20 = None + mul_83 = torch.ops.aten.mul.Tensor(view_as_complex_21, view_16); view_as_complex_21 = None + view_as_real_21 = torch.ops.aten.view_as_real.default(mul_83); mul_83 = None + view_358 = torch.ops.aten.view.default(view_as_real_21, [2, 8192, 8, 128]); 
view_as_real_21 = None + convert_element_type_345 = torch.ops.prims.convert_element_type.default(view_357, torch.bfloat16); view_357 = None + convert_element_type_346 = torch.ops.prims.convert_element_type.default(view_358, torch.bfloat16); view_358 = None + unsqueeze_20 = torch.ops.aten.unsqueeze.default(convert_element_type_346, 3); convert_element_type_346 = None + expand_20 = torch.ops.aten.expand.default(unsqueeze_20, [2, 8192, 8, 4, 128]); unsqueeze_20 = None + clone_20 = torch.ops.aten.clone.default(expand_20, memory_format = torch.contiguous_format); expand_20 = None + view_359 = torch.ops.aten.view.default(clone_20, [2, 8192, 32, 128]); clone_20 = None + unsqueeze_21 = torch.ops.aten.unsqueeze.default(view_353, 3); view_353 = None + expand_21 = torch.ops.aten.expand.default(unsqueeze_21, [2, 8192, 8, 4, 128]); unsqueeze_21 = None + clone_21 = torch.ops.aten.clone.default(expand_21, memory_format = torch.contiguous_format); expand_21 = None + view_360 = torch.ops.aten.view.default(clone_21, [2, 8192, 32, 128]); clone_21 = None + permute_113 = torch.ops.aten.permute.default(convert_element_type_345, [0, 2, 1, 3]); convert_element_type_345 = None + permute_114 = torch.ops.aten.permute.default(view_359, [0, 2, 1, 3]); view_359 = None + permute_115 = torch.ops.aten.permute.default(view_360, [0, 2, 1, 3]); view_360 = None + _scaled_dot_product_cudnn_attention_10 = torch.ops.aten._scaled_dot_product_cudnn_attention.default(permute_113, permute_114, permute_115, None, True, 0.0, True); permute_113 = permute_114 = permute_115 = None + getitem_90 = _scaled_dot_product_cudnn_attention_10[0] + getitem_91 = _scaled_dot_product_cudnn_attention_10[1] + getitem_96 = _scaled_dot_product_cudnn_attention_10[6] + getitem_97 = _scaled_dot_product_cudnn_attention_10[7]; _scaled_dot_product_cudnn_attention_10 = None + permute_116 = torch.ops.aten.permute.default(getitem_90, [0, 2, 1, 3]) + view_361 = torch.ops.aten.view.default(permute_116, [2, 8192, -1]); permute_116 = None + 
convert_element_type_347 = torch.ops.prims.convert_element_type.default(primals_98, torch.bfloat16) + all_gather_into_tensor_95 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_347, 256, '0'); convert_element_type_347 = None + wait_tensor_95 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_95); all_gather_into_tensor_95 = None + permute_117 = torch.ops.aten.permute.default(wait_tensor_95, [1, 0]); wait_tensor_95 = None + view_363 = torch.ops.aten.view.default(view_361, [16384, 4096]); view_361 = None + mm_73 = torch.ops.aten.mm.default(view_363, permute_117); view_363 = permute_117 = None + view_364 = torch.ops.aten.view.default(mm_73, [2, 8192, 4096]); mm_73 = None + add_41 = torch.ops.aten.add.Tensor(add_39, view_364); view_364 = None + convert_element_type_350 = torch.ops.prims.convert_element_type.default(primals_99, torch.bfloat16) + all_gather_into_tensor_96 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_350, 256, '0'); convert_element_type_350 = None + wait_tensor_96 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_96); all_gather_into_tensor_96 = None + convert_element_type_351 = torch.ops.prims.convert_element_type.default(add_41, torch.float32) + pow_22 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_351, 2) + mean_21 = torch.ops.aten.mean.dim(pow_22, [2], True); pow_22 = None + add_42 = torch.ops.aten.add.Scalar(mean_21, 1e-05); mean_21 = None + rsqrt_21 = torch.ops.aten.rsqrt.default(add_42); add_42 = None + mul_84 = torch.ops.aten.mul.Tensor(convert_element_type_351, rsqrt_21); convert_element_type_351 = rsqrt_21 = None + mul_85 = torch.ops.aten.mul.Tensor(mul_84, wait_tensor_96); mul_84 = wait_tensor_96 = None + convert_element_type_352 = torch.ops.prims.convert_element_type.default(mul_85, torch.bfloat16); mul_85 = None + convert_element_type_353 = torch.ops.prims.convert_element_type.default(primals_100, torch.bfloat16) + 
all_gather_into_tensor_97 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_353, 256, '0'); convert_element_type_353 = None + wait_tensor_97 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_97); all_gather_into_tensor_97 = None + permute_118 = torch.ops.aten.permute.default(wait_tensor_97, [1, 0]); wait_tensor_97 = None + view_367 = torch.ops.aten.view.default(convert_element_type_352, [16384, 4096]); convert_element_type_352 = None + mm_74 = torch.ops.aten.mm.default(view_367, permute_118); permute_118 = None + view_368 = torch.ops.aten.view.default(mm_74, [2, 8192, 14336]) + convert_element_type_356 = torch.ops.prims.convert_element_type.default(view_368, torch.float32); view_368 = None + sigmoid_10 = torch.ops.aten.sigmoid.default(convert_element_type_356) + mul_86 = torch.ops.aten.mul.Tensor(convert_element_type_356, sigmoid_10); convert_element_type_356 = sigmoid_10 = None + convert_element_type_357 = torch.ops.prims.convert_element_type.default(mul_86, torch.bfloat16); mul_86 = None + convert_element_type_358 = torch.ops.prims.convert_element_type.default(primals_101, torch.bfloat16) + all_gather_into_tensor_98 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_358, 256, '0'); convert_element_type_358 = None + wait_tensor_98 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_98); all_gather_into_tensor_98 = None + permute_119 = torch.ops.aten.permute.default(wait_tensor_98, [1, 0]); wait_tensor_98 = None + mm_75 = torch.ops.aten.mm.default(view_367, permute_119); view_367 = permute_119 = None + view_371 = torch.ops.aten.view.default(mm_75, [2, 8192, 14336]); mm_75 = None + mul_87 = torch.ops.aten.mul.Tensor(convert_element_type_357, view_371); convert_element_type_357 = view_371 = None + convert_element_type_361 = torch.ops.prims.convert_element_type.default(primals_102, torch.bfloat16) + all_gather_into_tensor_99 = 
torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_361, 256, '0'); convert_element_type_361 = None + wait_tensor_99 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_99); all_gather_into_tensor_99 = None + permute_120 = torch.ops.aten.permute.default(wait_tensor_99, [1, 0]); wait_tensor_99 = None + view_373 = torch.ops.aten.view.default(mul_87, [16384, 14336]); mul_87 = None + mm_76 = torch.ops.aten.mm.default(view_373, permute_120); view_373 = permute_120 = None + view_374 = torch.ops.aten.view.default(mm_76, [2, 8192, 4096]); mm_76 = None + add_43 = torch.ops.aten.add.Tensor(add_41, view_374); add_41 = view_374 = None + convert_element_type_364 = torch.ops.prims.convert_element_type.default(primals_103, torch.bfloat16) + all_gather_into_tensor_100 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_364, 256, '0'); convert_element_type_364 = None + wait_tensor_100 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_100); all_gather_into_tensor_100 = None + convert_element_type_365 = torch.ops.prims.convert_element_type.default(add_43, torch.float32) + pow_23 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_365, 2) + mean_22 = torch.ops.aten.mean.dim(pow_23, [2], True); pow_23 = None + add_44 = torch.ops.aten.add.Scalar(mean_22, 1e-05); mean_22 = None + rsqrt_22 = torch.ops.aten.rsqrt.default(add_44); add_44 = None + mul_88 = torch.ops.aten.mul.Tensor(convert_element_type_365, rsqrt_22); convert_element_type_365 = rsqrt_22 = None + mul_89 = torch.ops.aten.mul.Tensor(mul_88, wait_tensor_100); mul_88 = wait_tensor_100 = None + convert_element_type_366 = torch.ops.prims.convert_element_type.default(mul_89, torch.bfloat16); mul_89 = None + convert_element_type_367 = torch.ops.prims.convert_element_type.default(primals_104, torch.bfloat16) + all_gather_into_tensor_101 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_367, 256, '0'); 
convert_element_type_367 = None + wait_tensor_101 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_101); all_gather_into_tensor_101 = None + permute_121 = torch.ops.aten.permute.default(wait_tensor_101, [1, 0]); wait_tensor_101 = None + view_377 = torch.ops.aten.view.default(convert_element_type_366, [16384, 4096]); convert_element_type_366 = None + mm_77 = torch.ops.aten.mm.default(view_377, permute_121); permute_121 = None + view_378 = torch.ops.aten.view.default(mm_77, [2, 8192, 4096]) + convert_element_type_370 = torch.ops.prims.convert_element_type.default(primals_105, torch.bfloat16) + all_gather_into_tensor_102 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_370, 256, '0'); convert_element_type_370 = None + wait_tensor_102 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_102); all_gather_into_tensor_102 = None + permute_122 = torch.ops.aten.permute.default(wait_tensor_102, [1, 0]); wait_tensor_102 = None + mm_78 = torch.ops.aten.mm.default(view_377, permute_122); permute_122 = None + view_381 = torch.ops.aten.view.default(mm_78, [2, 8192, 1024]); mm_78 = None + convert_element_type_373 = torch.ops.prims.convert_element_type.default(primals_106, torch.bfloat16) + all_gather_into_tensor_103 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_373, 256, '0'); convert_element_type_373 = None + wait_tensor_103 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_103); all_gather_into_tensor_103 = None + permute_123 = torch.ops.aten.permute.default(wait_tensor_103, [1, 0]); wait_tensor_103 = None + mm_79 = torch.ops.aten.mm.default(view_377, permute_123); view_377 = permute_123 = None + view_384 = torch.ops.aten.view.default(mm_79, [2, 8192, 1024]) + view_385 = torch.ops.aten.view.default(view_378, [2, 8192, -1, 128]); view_378 = None + view_386 = torch.ops.aten.view.default(view_381, [2, 8192, -1, 128]); view_381 = None + view_387 = 
torch.ops.aten.view.default(view_384, [2, 8192, -1, 128]); view_384 = None + convert_element_type_376 = torch.ops.prims.convert_element_type.default(view_385, torch.float32); view_385 = None + view_388 = torch.ops.aten.view.default(convert_element_type_376, [2, 8192, 32, -1, 2]); convert_element_type_376 = None + view_as_complex_22 = torch.ops.aten.view_as_complex.default(view_388); view_388 = None + convert_element_type_377 = torch.ops.prims.convert_element_type.default(view_386, torch.float32); view_386 = None + view_389 = torch.ops.aten.view.default(convert_element_type_377, [2, 8192, 8, -1, 2]); convert_element_type_377 = None + view_as_complex_23 = torch.ops.aten.view_as_complex.default(view_389); view_389 = None + mul_90 = torch.ops.aten.mul.Tensor(view_as_complex_22, view_16); view_as_complex_22 = None + view_as_real_22 = torch.ops.aten.view_as_real.default(mul_90); mul_90 = None + view_391 = torch.ops.aten.view.default(view_as_real_22, [2, 8192, 32, 128]); view_as_real_22 = None + mul_91 = torch.ops.aten.mul.Tensor(view_as_complex_23, view_16); view_as_complex_23 = None + view_as_real_23 = torch.ops.aten.view_as_real.default(mul_91); mul_91 = None + view_392 = torch.ops.aten.view.default(view_as_real_23, [2, 8192, 8, 128]); view_as_real_23 = None + convert_element_type_378 = torch.ops.prims.convert_element_type.default(view_391, torch.bfloat16); view_391 = None + convert_element_type_379 = torch.ops.prims.convert_element_type.default(view_392, torch.bfloat16); view_392 = None + unsqueeze_22 = torch.ops.aten.unsqueeze.default(convert_element_type_379, 3); convert_element_type_379 = None + expand_22 = torch.ops.aten.expand.default(unsqueeze_22, [2, 8192, 8, 4, 128]); unsqueeze_22 = None + clone_22 = torch.ops.aten.clone.default(expand_22, memory_format = torch.contiguous_format); expand_22 = None + view_393 = torch.ops.aten.view.default(clone_22, [2, 8192, 32, 128]); clone_22 = None + unsqueeze_23 = torch.ops.aten.unsqueeze.default(view_387, 3); view_387 = 
None + expand_23 = torch.ops.aten.expand.default(unsqueeze_23, [2, 8192, 8, 4, 128]); unsqueeze_23 = None + clone_23 = torch.ops.aten.clone.default(expand_23, memory_format = torch.contiguous_format); expand_23 = None + view_394 = torch.ops.aten.view.default(clone_23, [2, 8192, 32, 128]); clone_23 = None + permute_124 = torch.ops.aten.permute.default(convert_element_type_378, [0, 2, 1, 3]); convert_element_type_378 = None + permute_125 = torch.ops.aten.permute.default(view_393, [0, 2, 1, 3]); view_393 = None + permute_126 = torch.ops.aten.permute.default(view_394, [0, 2, 1, 3]); view_394 = None + _scaled_dot_product_cudnn_attention_11 = torch.ops.aten._scaled_dot_product_cudnn_attention.default(permute_124, permute_125, permute_126, None, True, 0.0, True); permute_124 = permute_125 = permute_126 = None + getitem_99 = _scaled_dot_product_cudnn_attention_11[0] + getitem_100 = _scaled_dot_product_cudnn_attention_11[1] + getitem_105 = _scaled_dot_product_cudnn_attention_11[6] + getitem_106 = _scaled_dot_product_cudnn_attention_11[7]; _scaled_dot_product_cudnn_attention_11 = None + permute_127 = torch.ops.aten.permute.default(getitem_99, [0, 2, 1, 3]) + view_395 = torch.ops.aten.view.default(permute_127, [2, 8192, -1]); permute_127 = None + convert_element_type_380 = torch.ops.prims.convert_element_type.default(primals_107, torch.bfloat16) + all_gather_into_tensor_104 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_380, 256, '0'); convert_element_type_380 = None + wait_tensor_104 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_104); all_gather_into_tensor_104 = None + permute_128 = torch.ops.aten.permute.default(wait_tensor_104, [1, 0]); wait_tensor_104 = None + view_397 = torch.ops.aten.view.default(view_395, [16384, 4096]); view_395 = None + mm_80 = torch.ops.aten.mm.default(view_397, permute_128); view_397 = permute_128 = None + view_398 = torch.ops.aten.view.default(mm_80, [2, 8192, 4096]); mm_80 = None + 
add_45 = torch.ops.aten.add.Tensor(add_43, view_398); view_398 = None + convert_element_type_383 = torch.ops.prims.convert_element_type.default(primals_108, torch.bfloat16) + all_gather_into_tensor_105 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_383, 256, '0'); convert_element_type_383 = None + wait_tensor_105 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_105); all_gather_into_tensor_105 = None + convert_element_type_384 = torch.ops.prims.convert_element_type.default(add_45, torch.float32) + pow_24 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_384, 2) + mean_23 = torch.ops.aten.mean.dim(pow_24, [2], True); pow_24 = None + add_46 = torch.ops.aten.add.Scalar(mean_23, 1e-05); mean_23 = None + rsqrt_23 = torch.ops.aten.rsqrt.default(add_46); add_46 = None + mul_92 = torch.ops.aten.mul.Tensor(convert_element_type_384, rsqrt_23); convert_element_type_384 = rsqrt_23 = None + mul_93 = torch.ops.aten.mul.Tensor(mul_92, wait_tensor_105); mul_92 = wait_tensor_105 = None + convert_element_type_385 = torch.ops.prims.convert_element_type.default(mul_93, torch.bfloat16); mul_93 = None + convert_element_type_386 = torch.ops.prims.convert_element_type.default(primals_109, torch.bfloat16) + all_gather_into_tensor_106 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_386, 256, '0'); convert_element_type_386 = None + wait_tensor_106 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_106); all_gather_into_tensor_106 = None + permute_129 = torch.ops.aten.permute.default(wait_tensor_106, [1, 0]); wait_tensor_106 = None + view_401 = torch.ops.aten.view.default(convert_element_type_385, [16384, 4096]); convert_element_type_385 = None + mm_81 = torch.ops.aten.mm.default(view_401, permute_129); permute_129 = None + view_402 = torch.ops.aten.view.default(mm_81, [2, 8192, 14336]) + convert_element_type_389 = torch.ops.prims.convert_element_type.default(view_402, 
torch.float32); view_402 = None + sigmoid_11 = torch.ops.aten.sigmoid.default(convert_element_type_389) + mul_94 = torch.ops.aten.mul.Tensor(convert_element_type_389, sigmoid_11); convert_element_type_389 = sigmoid_11 = None + convert_element_type_390 = torch.ops.prims.convert_element_type.default(mul_94, torch.bfloat16); mul_94 = None + convert_element_type_391 = torch.ops.prims.convert_element_type.default(primals_110, torch.bfloat16) + all_gather_into_tensor_107 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_391, 256, '0'); convert_element_type_391 = None + wait_tensor_107 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_107); all_gather_into_tensor_107 = None + permute_130 = torch.ops.aten.permute.default(wait_tensor_107, [1, 0]); wait_tensor_107 = None + mm_82 = torch.ops.aten.mm.default(view_401, permute_130); view_401 = permute_130 = None + view_405 = torch.ops.aten.view.default(mm_82, [2, 8192, 14336]); mm_82 = None + mul_95 = torch.ops.aten.mul.Tensor(convert_element_type_390, view_405); convert_element_type_390 = view_405 = None + convert_element_type_394 = torch.ops.prims.convert_element_type.default(primals_111, torch.bfloat16) + all_gather_into_tensor_108 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_394, 256, '0'); convert_element_type_394 = None + wait_tensor_108 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_108); all_gather_into_tensor_108 = None + permute_131 = torch.ops.aten.permute.default(wait_tensor_108, [1, 0]); wait_tensor_108 = None + view_407 = torch.ops.aten.view.default(mul_95, [16384, 14336]); mul_95 = None + mm_83 = torch.ops.aten.mm.default(view_407, permute_131); view_407 = permute_131 = None + view_408 = torch.ops.aten.view.default(mm_83, [2, 8192, 4096]); mm_83 = None + add_47 = torch.ops.aten.add.Tensor(add_45, view_408); add_45 = view_408 = None + convert_element_type_397 = 
torch.ops.prims.convert_element_type.default(primals_112, torch.bfloat16) + all_gather_into_tensor_109 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_397, 256, '0'); convert_element_type_397 = None + wait_tensor_109 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_109); all_gather_into_tensor_109 = None + convert_element_type_398 = torch.ops.prims.convert_element_type.default(add_47, torch.float32) + pow_25 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_398, 2) + mean_24 = torch.ops.aten.mean.dim(pow_25, [2], True); pow_25 = None + add_48 = torch.ops.aten.add.Scalar(mean_24, 1e-05); mean_24 = None + rsqrt_24 = torch.ops.aten.rsqrt.default(add_48); add_48 = None + mul_96 = torch.ops.aten.mul.Tensor(convert_element_type_398, rsqrt_24); convert_element_type_398 = rsqrt_24 = None + mul_97 = torch.ops.aten.mul.Tensor(mul_96, wait_tensor_109); mul_96 = wait_tensor_109 = None + convert_element_type_399 = torch.ops.prims.convert_element_type.default(mul_97, torch.bfloat16); mul_97 = None + convert_element_type_400 = torch.ops.prims.convert_element_type.default(primals_113, torch.bfloat16) + all_gather_into_tensor_110 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_400, 256, '0'); convert_element_type_400 = None + wait_tensor_110 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_110); all_gather_into_tensor_110 = None + permute_132 = torch.ops.aten.permute.default(wait_tensor_110, [1, 0]); wait_tensor_110 = None + view_411 = torch.ops.aten.view.default(convert_element_type_399, [16384, 4096]); convert_element_type_399 = None + mm_84 = torch.ops.aten.mm.default(view_411, permute_132); permute_132 = None + view_412 = torch.ops.aten.view.default(mm_84, [2, 8192, 4096]) + convert_element_type_403 = torch.ops.prims.convert_element_type.default(primals_114, torch.bfloat16) + all_gather_into_tensor_111 = 
torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_403, 256, '0'); convert_element_type_403 = None + wait_tensor_111 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_111); all_gather_into_tensor_111 = None + permute_133 = torch.ops.aten.permute.default(wait_tensor_111, [1, 0]); wait_tensor_111 = None + mm_85 = torch.ops.aten.mm.default(view_411, permute_133); permute_133 = None + view_415 = torch.ops.aten.view.default(mm_85, [2, 8192, 1024]); mm_85 = None + convert_element_type_406 = torch.ops.prims.convert_element_type.default(primals_115, torch.bfloat16) + all_gather_into_tensor_112 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_406, 256, '0'); convert_element_type_406 = None + wait_tensor_112 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_112); all_gather_into_tensor_112 = None + permute_134 = torch.ops.aten.permute.default(wait_tensor_112, [1, 0]); wait_tensor_112 = None + mm_86 = torch.ops.aten.mm.default(view_411, permute_134); view_411 = permute_134 = None + view_418 = torch.ops.aten.view.default(mm_86, [2, 8192, 1024]) + view_419 = torch.ops.aten.view.default(view_412, [2, 8192, -1, 128]); view_412 = None + view_420 = torch.ops.aten.view.default(view_415, [2, 8192, -1, 128]); view_415 = None + view_421 = torch.ops.aten.view.default(view_418, [2, 8192, -1, 128]); view_418 = None + convert_element_type_409 = torch.ops.prims.convert_element_type.default(view_419, torch.float32); view_419 = None + view_422 = torch.ops.aten.view.default(convert_element_type_409, [2, 8192, 32, -1, 2]); convert_element_type_409 = None + view_as_complex_24 = torch.ops.aten.view_as_complex.default(view_422); view_422 = None + convert_element_type_410 = torch.ops.prims.convert_element_type.default(view_420, torch.float32); view_420 = None + view_423 = torch.ops.aten.view.default(convert_element_type_410, [2, 8192, 8, -1, 2]); convert_element_type_410 = None + 
view_as_complex_25 = torch.ops.aten.view_as_complex.default(view_423); view_423 = None + mul_98 = torch.ops.aten.mul.Tensor(view_as_complex_24, view_16); view_as_complex_24 = None + view_as_real_24 = torch.ops.aten.view_as_real.default(mul_98); mul_98 = None + view_425 = torch.ops.aten.view.default(view_as_real_24, [2, 8192, 32, 128]); view_as_real_24 = None + mul_99 = torch.ops.aten.mul.Tensor(view_as_complex_25, view_16); view_as_complex_25 = None + view_as_real_25 = torch.ops.aten.view_as_real.default(mul_99); mul_99 = None + view_426 = torch.ops.aten.view.default(view_as_real_25, [2, 8192, 8, 128]); view_as_real_25 = None + convert_element_type_411 = torch.ops.prims.convert_element_type.default(view_425, torch.bfloat16); view_425 = None + convert_element_type_412 = torch.ops.prims.convert_element_type.default(view_426, torch.bfloat16); view_426 = None + unsqueeze_24 = torch.ops.aten.unsqueeze.default(convert_element_type_412, 3); convert_element_type_412 = None + expand_24 = torch.ops.aten.expand.default(unsqueeze_24, [2, 8192, 8, 4, 128]); unsqueeze_24 = None + clone_24 = torch.ops.aten.clone.default(expand_24, memory_format = torch.contiguous_format); expand_24 = None + view_427 = torch.ops.aten.view.default(clone_24, [2, 8192, 32, 128]); clone_24 = None + unsqueeze_25 = torch.ops.aten.unsqueeze.default(view_421, 3); view_421 = None + expand_25 = torch.ops.aten.expand.default(unsqueeze_25, [2, 8192, 8, 4, 128]); unsqueeze_25 = None + clone_25 = torch.ops.aten.clone.default(expand_25, memory_format = torch.contiguous_format); expand_25 = None + view_428 = torch.ops.aten.view.default(clone_25, [2, 8192, 32, 128]); clone_25 = None + permute_135 = torch.ops.aten.permute.default(convert_element_type_411, [0, 2, 1, 3]); convert_element_type_411 = None + permute_136 = torch.ops.aten.permute.default(view_427, [0, 2, 1, 3]); view_427 = None + permute_137 = torch.ops.aten.permute.default(view_428, [0, 2, 1, 3]); view_428 = None + _scaled_dot_product_cudnn_attention_12 
= torch.ops.aten._scaled_dot_product_cudnn_attention.default(permute_135, permute_136, permute_137, None, True, 0.0, True); permute_135 = permute_136 = permute_137 = None + getitem_108 = _scaled_dot_product_cudnn_attention_12[0] + getitem_109 = _scaled_dot_product_cudnn_attention_12[1] + getitem_114 = _scaled_dot_product_cudnn_attention_12[6] + getitem_115 = _scaled_dot_product_cudnn_attention_12[7]; _scaled_dot_product_cudnn_attention_12 = None + permute_138 = torch.ops.aten.permute.default(getitem_108, [0, 2, 1, 3]) + view_429 = torch.ops.aten.view.default(permute_138, [2, 8192, -1]); permute_138 = None + convert_element_type_413 = torch.ops.prims.convert_element_type.default(primals_116, torch.bfloat16) + all_gather_into_tensor_113 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_413, 256, '0'); convert_element_type_413 = None + wait_tensor_113 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_113); all_gather_into_tensor_113 = None + permute_139 = torch.ops.aten.permute.default(wait_tensor_113, [1, 0]); wait_tensor_113 = None + view_431 = torch.ops.aten.view.default(view_429, [16384, 4096]); view_429 = None + mm_87 = torch.ops.aten.mm.default(view_431, permute_139); view_431 = permute_139 = None + view_432 = torch.ops.aten.view.default(mm_87, [2, 8192, 4096]); mm_87 = None + add_49 = torch.ops.aten.add.Tensor(add_47, view_432); view_432 = None + convert_element_type_416 = torch.ops.prims.convert_element_type.default(primals_117, torch.bfloat16) + all_gather_into_tensor_114 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_416, 256, '0'); convert_element_type_416 = None + wait_tensor_114 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_114); all_gather_into_tensor_114 = None + convert_element_type_417 = torch.ops.prims.convert_element_type.default(add_49, torch.float32) + pow_26 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_417, 2) + mean_25 = 
torch.ops.aten.mean.dim(pow_26, [2], True); pow_26 = None + add_50 = torch.ops.aten.add.Scalar(mean_25, 1e-05); mean_25 = None + rsqrt_25 = torch.ops.aten.rsqrt.default(add_50); add_50 = None + mul_100 = torch.ops.aten.mul.Tensor(convert_element_type_417, rsqrt_25); convert_element_type_417 = rsqrt_25 = None + mul_101 = torch.ops.aten.mul.Tensor(mul_100, wait_tensor_114); mul_100 = wait_tensor_114 = None + convert_element_type_418 = torch.ops.prims.convert_element_type.default(mul_101, torch.bfloat16); mul_101 = None + convert_element_type_419 = torch.ops.prims.convert_element_type.default(primals_118, torch.bfloat16) + all_gather_into_tensor_115 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_419, 256, '0'); convert_element_type_419 = None + wait_tensor_115 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_115); all_gather_into_tensor_115 = None + permute_140 = torch.ops.aten.permute.default(wait_tensor_115, [1, 0]); wait_tensor_115 = None + view_435 = torch.ops.aten.view.default(convert_element_type_418, [16384, 4096]); convert_element_type_418 = None + mm_88 = torch.ops.aten.mm.default(view_435, permute_140); permute_140 = None + view_436 = torch.ops.aten.view.default(mm_88, [2, 8192, 14336]) + convert_element_type_422 = torch.ops.prims.convert_element_type.default(view_436, torch.float32); view_436 = None + sigmoid_12 = torch.ops.aten.sigmoid.default(convert_element_type_422) + mul_102 = torch.ops.aten.mul.Tensor(convert_element_type_422, sigmoid_12); convert_element_type_422 = sigmoid_12 = None + convert_element_type_423 = torch.ops.prims.convert_element_type.default(mul_102, torch.bfloat16); mul_102 = None + convert_element_type_424 = torch.ops.prims.convert_element_type.default(primals_119, torch.bfloat16) + all_gather_into_tensor_116 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_424, 256, '0'); convert_element_type_424 = None + wait_tensor_116 = 
torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_116); all_gather_into_tensor_116 = None + permute_141 = torch.ops.aten.permute.default(wait_tensor_116, [1, 0]); wait_tensor_116 = None + mm_89 = torch.ops.aten.mm.default(view_435, permute_141); view_435 = permute_141 = None + view_439 = torch.ops.aten.view.default(mm_89, [2, 8192, 14336]); mm_89 = None + mul_103 = torch.ops.aten.mul.Tensor(convert_element_type_423, view_439); convert_element_type_423 = view_439 = None + convert_element_type_427 = torch.ops.prims.convert_element_type.default(primals_120, torch.bfloat16) + all_gather_into_tensor_117 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_427, 256, '0'); convert_element_type_427 = None + wait_tensor_117 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_117); all_gather_into_tensor_117 = None + permute_142 = torch.ops.aten.permute.default(wait_tensor_117, [1, 0]); wait_tensor_117 = None + view_441 = torch.ops.aten.view.default(mul_103, [16384, 14336]); mul_103 = None + mm_90 = torch.ops.aten.mm.default(view_441, permute_142); view_441 = permute_142 = None + view_442 = torch.ops.aten.view.default(mm_90, [2, 8192, 4096]); mm_90 = None + add_51 = torch.ops.aten.add.Tensor(add_49, view_442); add_49 = view_442 = None + convert_element_type_430 = torch.ops.prims.convert_element_type.default(primals_121, torch.bfloat16) + all_gather_into_tensor_118 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_430, 256, '0'); convert_element_type_430 = None + wait_tensor_118 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_118); all_gather_into_tensor_118 = None + convert_element_type_431 = torch.ops.prims.convert_element_type.default(add_51, torch.float32) + pow_27 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_431, 2) + mean_26 = torch.ops.aten.mean.dim(pow_27, [2], True); pow_27 = None + add_52 = torch.ops.aten.add.Scalar(mean_26, 1e-05); 
mean_26 = None + rsqrt_26 = torch.ops.aten.rsqrt.default(add_52); add_52 = None + mul_104 = torch.ops.aten.mul.Tensor(convert_element_type_431, rsqrt_26); convert_element_type_431 = rsqrt_26 = None + mul_105 = torch.ops.aten.mul.Tensor(mul_104, wait_tensor_118); mul_104 = wait_tensor_118 = None + convert_element_type_432 = torch.ops.prims.convert_element_type.default(mul_105, torch.bfloat16); mul_105 = None + convert_element_type_433 = torch.ops.prims.convert_element_type.default(primals_122, torch.bfloat16) + all_gather_into_tensor_119 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_433, 256, '0'); convert_element_type_433 = None + wait_tensor_119 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_119); all_gather_into_tensor_119 = None + permute_143 = torch.ops.aten.permute.default(wait_tensor_119, [1, 0]); wait_tensor_119 = None + view_445 = torch.ops.aten.view.default(convert_element_type_432, [16384, 4096]); convert_element_type_432 = None + mm_91 = torch.ops.aten.mm.default(view_445, permute_143); permute_143 = None + view_446 = torch.ops.aten.view.default(mm_91, [2, 8192, 4096]) + convert_element_type_436 = torch.ops.prims.convert_element_type.default(primals_123, torch.bfloat16) + all_gather_into_tensor_120 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_436, 256, '0'); convert_element_type_436 = None + wait_tensor_120 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_120); all_gather_into_tensor_120 = None + permute_144 = torch.ops.aten.permute.default(wait_tensor_120, [1, 0]); wait_tensor_120 = None + mm_92 = torch.ops.aten.mm.default(view_445, permute_144); permute_144 = None + view_449 = torch.ops.aten.view.default(mm_92, [2, 8192, 1024]); mm_92 = None + convert_element_type_439 = torch.ops.prims.convert_element_type.default(primals_124, torch.bfloat16) + all_gather_into_tensor_121 = 
torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_439, 256, '0'); convert_element_type_439 = None + wait_tensor_121 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_121); all_gather_into_tensor_121 = None + permute_145 = torch.ops.aten.permute.default(wait_tensor_121, [1, 0]); wait_tensor_121 = None + mm_93 = torch.ops.aten.mm.default(view_445, permute_145); view_445 = permute_145 = None + view_452 = torch.ops.aten.view.default(mm_93, [2, 8192, 1024]) + view_453 = torch.ops.aten.view.default(view_446, [2, 8192, -1, 128]); view_446 = None + view_454 = torch.ops.aten.view.default(view_449, [2, 8192, -1, 128]); view_449 = None + view_455 = torch.ops.aten.view.default(view_452, [2, 8192, -1, 128]); view_452 = None + convert_element_type_442 = torch.ops.prims.convert_element_type.default(view_453, torch.float32); view_453 = None + view_456 = torch.ops.aten.view.default(convert_element_type_442, [2, 8192, 32, -1, 2]); convert_element_type_442 = None + view_as_complex_26 = torch.ops.aten.view_as_complex.default(view_456); view_456 = None + convert_element_type_443 = torch.ops.prims.convert_element_type.default(view_454, torch.float32); view_454 = None + view_457 = torch.ops.aten.view.default(convert_element_type_443, [2, 8192, 8, -1, 2]); convert_element_type_443 = None + view_as_complex_27 = torch.ops.aten.view_as_complex.default(view_457); view_457 = None + mul_106 = torch.ops.aten.mul.Tensor(view_as_complex_26, view_16); view_as_complex_26 = None + view_as_real_26 = torch.ops.aten.view_as_real.default(mul_106); mul_106 = None + view_459 = torch.ops.aten.view.default(view_as_real_26, [2, 8192, 32, 128]); view_as_real_26 = None + mul_107 = torch.ops.aten.mul.Tensor(view_as_complex_27, view_16); view_as_complex_27 = None + view_as_real_27 = torch.ops.aten.view_as_real.default(mul_107); mul_107 = None + view_460 = torch.ops.aten.view.default(view_as_real_27, [2, 8192, 8, 128]); view_as_real_27 = None + 
convert_element_type_444 = torch.ops.prims.convert_element_type.default(view_459, torch.bfloat16); view_459 = None + convert_element_type_445 = torch.ops.prims.convert_element_type.default(view_460, torch.bfloat16); view_460 = None + unsqueeze_26 = torch.ops.aten.unsqueeze.default(convert_element_type_445, 3); convert_element_type_445 = None + expand_26 = torch.ops.aten.expand.default(unsqueeze_26, [2, 8192, 8, 4, 128]); unsqueeze_26 = None + clone_26 = torch.ops.aten.clone.default(expand_26, memory_format = torch.contiguous_format); expand_26 = None + view_461 = torch.ops.aten.view.default(clone_26, [2, 8192, 32, 128]); clone_26 = None + unsqueeze_27 = torch.ops.aten.unsqueeze.default(view_455, 3); view_455 = None + expand_27 = torch.ops.aten.expand.default(unsqueeze_27, [2, 8192, 8, 4, 128]); unsqueeze_27 = None + clone_27 = torch.ops.aten.clone.default(expand_27, memory_format = torch.contiguous_format); expand_27 = None + view_462 = torch.ops.aten.view.default(clone_27, [2, 8192, 32, 128]); clone_27 = None + permute_146 = torch.ops.aten.permute.default(convert_element_type_444, [0, 2, 1, 3]); convert_element_type_444 = None + permute_147 = torch.ops.aten.permute.default(view_461, [0, 2, 1, 3]); view_461 = None + permute_148 = torch.ops.aten.permute.default(view_462, [0, 2, 1, 3]); view_462 = None + _scaled_dot_product_cudnn_attention_13 = torch.ops.aten._scaled_dot_product_cudnn_attention.default(permute_146, permute_147, permute_148, None, True, 0.0, True); permute_146 = permute_147 = permute_148 = None + getitem_117 = _scaled_dot_product_cudnn_attention_13[0] + getitem_118 = _scaled_dot_product_cudnn_attention_13[1] + getitem_123 = _scaled_dot_product_cudnn_attention_13[6] + getitem_124 = _scaled_dot_product_cudnn_attention_13[7]; _scaled_dot_product_cudnn_attention_13 = None + permute_149 = torch.ops.aten.permute.default(getitem_117, [0, 2, 1, 3]) + view_463 = torch.ops.aten.view.default(permute_149, [2, 8192, -1]); permute_149 = None + 
convert_element_type_446 = torch.ops.prims.convert_element_type.default(primals_125, torch.bfloat16) + all_gather_into_tensor_122 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_446, 256, '0'); convert_element_type_446 = None + wait_tensor_122 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_122); all_gather_into_tensor_122 = None + permute_150 = torch.ops.aten.permute.default(wait_tensor_122, [1, 0]); wait_tensor_122 = None + view_465 = torch.ops.aten.view.default(view_463, [16384, 4096]); view_463 = None + mm_94 = torch.ops.aten.mm.default(view_465, permute_150); view_465 = permute_150 = None + view_466 = torch.ops.aten.view.default(mm_94, [2, 8192, 4096]); mm_94 = None + add_53 = torch.ops.aten.add.Tensor(add_51, view_466); view_466 = None + convert_element_type_449 = torch.ops.prims.convert_element_type.default(primals_126, torch.bfloat16) + all_gather_into_tensor_123 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_449, 256, '0'); convert_element_type_449 = None + wait_tensor_123 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_123); all_gather_into_tensor_123 = None + convert_element_type_450 = torch.ops.prims.convert_element_type.default(add_53, torch.float32) + pow_28 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_450, 2) + mean_27 = torch.ops.aten.mean.dim(pow_28, [2], True); pow_28 = None + add_54 = torch.ops.aten.add.Scalar(mean_27, 1e-05); mean_27 = None + rsqrt_27 = torch.ops.aten.rsqrt.default(add_54); add_54 = None + mul_108 = torch.ops.aten.mul.Tensor(convert_element_type_450, rsqrt_27); convert_element_type_450 = rsqrt_27 = None + mul_109 = torch.ops.aten.mul.Tensor(mul_108, wait_tensor_123); mul_108 = wait_tensor_123 = None + convert_element_type_451 = torch.ops.prims.convert_element_type.default(mul_109, torch.bfloat16); mul_109 = None + convert_element_type_452 = torch.ops.prims.convert_element_type.default(primals_127, 
torch.bfloat16) + all_gather_into_tensor_124 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_452, 256, '0'); convert_element_type_452 = None + wait_tensor_124 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_124); all_gather_into_tensor_124 = None + permute_151 = torch.ops.aten.permute.default(wait_tensor_124, [1, 0]); wait_tensor_124 = None + view_469 = torch.ops.aten.view.default(convert_element_type_451, [16384, 4096]); convert_element_type_451 = None + mm_95 = torch.ops.aten.mm.default(view_469, permute_151); permute_151 = None + view_470 = torch.ops.aten.view.default(mm_95, [2, 8192, 14336]) + convert_element_type_455 = torch.ops.prims.convert_element_type.default(view_470, torch.float32); view_470 = None + sigmoid_13 = torch.ops.aten.sigmoid.default(convert_element_type_455) + mul_110 = torch.ops.aten.mul.Tensor(convert_element_type_455, sigmoid_13); convert_element_type_455 = sigmoid_13 = None + convert_element_type_456 = torch.ops.prims.convert_element_type.default(mul_110, torch.bfloat16); mul_110 = None + convert_element_type_457 = torch.ops.prims.convert_element_type.default(primals_128, torch.bfloat16) + all_gather_into_tensor_125 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_457, 256, '0'); convert_element_type_457 = None + wait_tensor_125 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_125); all_gather_into_tensor_125 = None + permute_152 = torch.ops.aten.permute.default(wait_tensor_125, [1, 0]); wait_tensor_125 = None + mm_96 = torch.ops.aten.mm.default(view_469, permute_152); view_469 = permute_152 = None + view_473 = torch.ops.aten.view.default(mm_96, [2, 8192, 14336]); mm_96 = None + mul_111 = torch.ops.aten.mul.Tensor(convert_element_type_456, view_473); convert_element_type_456 = view_473 = None + convert_element_type_460 = torch.ops.prims.convert_element_type.default(primals_129, torch.bfloat16) + all_gather_into_tensor_126 = 
torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_460, 256, '0'); convert_element_type_460 = None + wait_tensor_126 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_126); all_gather_into_tensor_126 = None + permute_153 = torch.ops.aten.permute.default(wait_tensor_126, [1, 0]); wait_tensor_126 = None + view_475 = torch.ops.aten.view.default(mul_111, [16384, 14336]); mul_111 = None + mm_97 = torch.ops.aten.mm.default(view_475, permute_153); view_475 = permute_153 = None + view_476 = torch.ops.aten.view.default(mm_97, [2, 8192, 4096]); mm_97 = None + add_55 = torch.ops.aten.add.Tensor(add_53, view_476); add_53 = view_476 = None + convert_element_type_463 = torch.ops.prims.convert_element_type.default(primals_130, torch.bfloat16) + all_gather_into_tensor_127 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_463, 256, '0'); convert_element_type_463 = None + wait_tensor_127 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_127); all_gather_into_tensor_127 = None + convert_element_type_464 = torch.ops.prims.convert_element_type.default(add_55, torch.float32) + pow_29 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_464, 2) + mean_28 = torch.ops.aten.mean.dim(pow_29, [2], True); pow_29 = None + add_56 = torch.ops.aten.add.Scalar(mean_28, 1e-05); mean_28 = None + rsqrt_28 = torch.ops.aten.rsqrt.default(add_56); add_56 = None + mul_112 = torch.ops.aten.mul.Tensor(convert_element_type_464, rsqrt_28); convert_element_type_464 = rsqrt_28 = None + mul_113 = torch.ops.aten.mul.Tensor(mul_112, wait_tensor_127); mul_112 = wait_tensor_127 = None + convert_element_type_465 = torch.ops.prims.convert_element_type.default(mul_113, torch.bfloat16); mul_113 = None + convert_element_type_466 = torch.ops.prims.convert_element_type.default(primals_131, torch.bfloat16) + all_gather_into_tensor_128 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_466, 
256, '0'); convert_element_type_466 = None + wait_tensor_128 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_128); all_gather_into_tensor_128 = None + permute_154 = torch.ops.aten.permute.default(wait_tensor_128, [1, 0]); wait_tensor_128 = None + view_479 = torch.ops.aten.view.default(convert_element_type_465, [16384, 4096]); convert_element_type_465 = None + mm_98 = torch.ops.aten.mm.default(view_479, permute_154); permute_154 = None + view_480 = torch.ops.aten.view.default(mm_98, [2, 8192, 4096]) + convert_element_type_469 = torch.ops.prims.convert_element_type.default(primals_132, torch.bfloat16) + all_gather_into_tensor_129 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_469, 256, '0'); convert_element_type_469 = None + wait_tensor_129 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_129); all_gather_into_tensor_129 = None + permute_155 = torch.ops.aten.permute.default(wait_tensor_129, [1, 0]); wait_tensor_129 = None + mm_99 = torch.ops.aten.mm.default(view_479, permute_155); permute_155 = None + view_483 = torch.ops.aten.view.default(mm_99, [2, 8192, 1024]); mm_99 = None + convert_element_type_472 = torch.ops.prims.convert_element_type.default(primals_133, torch.bfloat16) + all_gather_into_tensor_130 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_472, 256, '0'); convert_element_type_472 = None + wait_tensor_130 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_130); all_gather_into_tensor_130 = None + permute_156 = torch.ops.aten.permute.default(wait_tensor_130, [1, 0]); wait_tensor_130 = None + mm_100 = torch.ops.aten.mm.default(view_479, permute_156); view_479 = permute_156 = None + view_486 = torch.ops.aten.view.default(mm_100, [2, 8192, 1024]) + view_487 = torch.ops.aten.view.default(view_480, [2, 8192, -1, 128]); view_480 = None + view_488 = torch.ops.aten.view.default(view_483, [2, 8192, -1, 128]); view_483 = None + 
view_489 = torch.ops.aten.view.default(view_486, [2, 8192, -1, 128]); view_486 = None + convert_element_type_475 = torch.ops.prims.convert_element_type.default(view_487, torch.float32); view_487 = None + view_490 = torch.ops.aten.view.default(convert_element_type_475, [2, 8192, 32, -1, 2]); convert_element_type_475 = None + view_as_complex_28 = torch.ops.aten.view_as_complex.default(view_490); view_490 = None + convert_element_type_476 = torch.ops.prims.convert_element_type.default(view_488, torch.float32); view_488 = None + view_491 = torch.ops.aten.view.default(convert_element_type_476, [2, 8192, 8, -1, 2]); convert_element_type_476 = None + view_as_complex_29 = torch.ops.aten.view_as_complex.default(view_491); view_491 = None + mul_114 = torch.ops.aten.mul.Tensor(view_as_complex_28, view_16); view_as_complex_28 = None + view_as_real_28 = torch.ops.aten.view_as_real.default(mul_114); mul_114 = None + view_493 = torch.ops.aten.view.default(view_as_real_28, [2, 8192, 32, 128]); view_as_real_28 = None + mul_115 = torch.ops.aten.mul.Tensor(view_as_complex_29, view_16); view_as_complex_29 = None + view_as_real_29 = torch.ops.aten.view_as_real.default(mul_115); mul_115 = None + view_494 = torch.ops.aten.view.default(view_as_real_29, [2, 8192, 8, 128]); view_as_real_29 = None + convert_element_type_477 = torch.ops.prims.convert_element_type.default(view_493, torch.bfloat16); view_493 = None + convert_element_type_478 = torch.ops.prims.convert_element_type.default(view_494, torch.bfloat16); view_494 = None + unsqueeze_28 = torch.ops.aten.unsqueeze.default(convert_element_type_478, 3); convert_element_type_478 = None + expand_28 = torch.ops.aten.expand.default(unsqueeze_28, [2, 8192, 8, 4, 128]); unsqueeze_28 = None + clone_28 = torch.ops.aten.clone.default(expand_28, memory_format = torch.contiguous_format); expand_28 = None + view_495 = torch.ops.aten.view.default(clone_28, [2, 8192, 32, 128]); clone_28 = None + unsqueeze_29 = torch.ops.aten.unsqueeze.default(view_489, 
3); view_489 = None + expand_29 = torch.ops.aten.expand.default(unsqueeze_29, [2, 8192, 8, 4, 128]); unsqueeze_29 = None + clone_29 = torch.ops.aten.clone.default(expand_29, memory_format = torch.contiguous_format); expand_29 = None + view_496 = torch.ops.aten.view.default(clone_29, [2, 8192, 32, 128]); clone_29 = None + permute_157 = torch.ops.aten.permute.default(convert_element_type_477, [0, 2, 1, 3]); convert_element_type_477 = None + permute_158 = torch.ops.aten.permute.default(view_495, [0, 2, 1, 3]); view_495 = None + permute_159 = torch.ops.aten.permute.default(view_496, [0, 2, 1, 3]); view_496 = None + _scaled_dot_product_cudnn_attention_14 = torch.ops.aten._scaled_dot_product_cudnn_attention.default(permute_157, permute_158, permute_159, None, True, 0.0, True); permute_157 = permute_158 = permute_159 = None + getitem_126 = _scaled_dot_product_cudnn_attention_14[0] + getitem_127 = _scaled_dot_product_cudnn_attention_14[1] + getitem_132 = _scaled_dot_product_cudnn_attention_14[6] + getitem_133 = _scaled_dot_product_cudnn_attention_14[7]; _scaled_dot_product_cudnn_attention_14 = None + permute_160 = torch.ops.aten.permute.default(getitem_126, [0, 2, 1, 3]) + view_497 = torch.ops.aten.view.default(permute_160, [2, 8192, -1]); permute_160 = None + convert_element_type_479 = torch.ops.prims.convert_element_type.default(primals_134, torch.bfloat16) + all_gather_into_tensor_131 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_479, 256, '0'); convert_element_type_479 = None + wait_tensor_131 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_131); all_gather_into_tensor_131 = None + permute_161 = torch.ops.aten.permute.default(wait_tensor_131, [1, 0]); wait_tensor_131 = None + view_499 = torch.ops.aten.view.default(view_497, [16384, 4096]); view_497 = None + mm_101 = torch.ops.aten.mm.default(view_499, permute_161); view_499 = permute_161 = None + view_500 = torch.ops.aten.view.default(mm_101, [2, 8192, 4096]); 
mm_101 = None + add_57 = torch.ops.aten.add.Tensor(add_55, view_500); view_500 = None + convert_element_type_482 = torch.ops.prims.convert_element_type.default(primals_135, torch.bfloat16) + all_gather_into_tensor_132 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_482, 256, '0'); convert_element_type_482 = None + wait_tensor_132 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_132); all_gather_into_tensor_132 = None + convert_element_type_483 = torch.ops.prims.convert_element_type.default(add_57, torch.float32) + pow_30 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_483, 2) + mean_29 = torch.ops.aten.mean.dim(pow_30, [2], True); pow_30 = None + add_58 = torch.ops.aten.add.Scalar(mean_29, 1e-05); mean_29 = None + rsqrt_29 = torch.ops.aten.rsqrt.default(add_58); add_58 = None + mul_116 = torch.ops.aten.mul.Tensor(convert_element_type_483, rsqrt_29); convert_element_type_483 = rsqrt_29 = None + mul_117 = torch.ops.aten.mul.Tensor(mul_116, wait_tensor_132); mul_116 = wait_tensor_132 = None + convert_element_type_484 = torch.ops.prims.convert_element_type.default(mul_117, torch.bfloat16); mul_117 = None + convert_element_type_485 = torch.ops.prims.convert_element_type.default(primals_136, torch.bfloat16) + all_gather_into_tensor_133 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_485, 256, '0'); convert_element_type_485 = None + wait_tensor_133 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_133); all_gather_into_tensor_133 = None + permute_162 = torch.ops.aten.permute.default(wait_tensor_133, [1, 0]); wait_tensor_133 = None + view_503 = torch.ops.aten.view.default(convert_element_type_484, [16384, 4096]); convert_element_type_484 = None + mm_102 = torch.ops.aten.mm.default(view_503, permute_162); permute_162 = None + view_504 = torch.ops.aten.view.default(mm_102, [2, 8192, 14336]) + convert_element_type_488 = 
torch.ops.prims.convert_element_type.default(view_504, torch.float32); view_504 = None + sigmoid_14 = torch.ops.aten.sigmoid.default(convert_element_type_488) + mul_118 = torch.ops.aten.mul.Tensor(convert_element_type_488, sigmoid_14); convert_element_type_488 = sigmoid_14 = None + convert_element_type_489 = torch.ops.prims.convert_element_type.default(mul_118, torch.bfloat16); mul_118 = None + convert_element_type_490 = torch.ops.prims.convert_element_type.default(primals_137, torch.bfloat16) + all_gather_into_tensor_134 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_490, 256, '0'); convert_element_type_490 = None + wait_tensor_134 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_134); all_gather_into_tensor_134 = None + permute_163 = torch.ops.aten.permute.default(wait_tensor_134, [1, 0]); wait_tensor_134 = None + mm_103 = torch.ops.aten.mm.default(view_503, permute_163); view_503 = permute_163 = None + view_507 = torch.ops.aten.view.default(mm_103, [2, 8192, 14336]); mm_103 = None + mul_119 = torch.ops.aten.mul.Tensor(convert_element_type_489, view_507); convert_element_type_489 = view_507 = None + convert_element_type_493 = torch.ops.prims.convert_element_type.default(primals_138, torch.bfloat16) + all_gather_into_tensor_135 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_493, 256, '0'); convert_element_type_493 = None + wait_tensor_135 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_135); all_gather_into_tensor_135 = None + permute_164 = torch.ops.aten.permute.default(wait_tensor_135, [1, 0]); wait_tensor_135 = None + view_509 = torch.ops.aten.view.default(mul_119, [16384, 14336]); mul_119 = None + mm_104 = torch.ops.aten.mm.default(view_509, permute_164); view_509 = permute_164 = None + view_510 = torch.ops.aten.view.default(mm_104, [2, 8192, 4096]); mm_104 = None + add_59 = torch.ops.aten.add.Tensor(add_57, view_510); add_57 = view_510 = None + 
convert_element_type_496 = torch.ops.prims.convert_element_type.default(primals_139, torch.bfloat16) + all_gather_into_tensor_136 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_496, 256, '0'); convert_element_type_496 = None + wait_tensor_136 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_136); all_gather_into_tensor_136 = None + convert_element_type_497 = torch.ops.prims.convert_element_type.default(add_59, torch.float32) + pow_31 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_497, 2) + mean_30 = torch.ops.aten.mean.dim(pow_31, [2], True); pow_31 = None + add_60 = torch.ops.aten.add.Scalar(mean_30, 1e-05); mean_30 = None + rsqrt_30 = torch.ops.aten.rsqrt.default(add_60); add_60 = None + mul_120 = torch.ops.aten.mul.Tensor(convert_element_type_497, rsqrt_30); convert_element_type_497 = rsqrt_30 = None + mul_121 = torch.ops.aten.mul.Tensor(mul_120, wait_tensor_136); mul_120 = wait_tensor_136 = None + convert_element_type_498 = torch.ops.prims.convert_element_type.default(mul_121, torch.bfloat16); mul_121 = None + convert_element_type_499 = torch.ops.prims.convert_element_type.default(primals_140, torch.bfloat16) + all_gather_into_tensor_137 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_499, 256, '0'); convert_element_type_499 = None + wait_tensor_137 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_137); all_gather_into_tensor_137 = None + permute_165 = torch.ops.aten.permute.default(wait_tensor_137, [1, 0]); wait_tensor_137 = None + view_513 = torch.ops.aten.view.default(convert_element_type_498, [16384, 4096]); convert_element_type_498 = None + mm_105 = torch.ops.aten.mm.default(view_513, permute_165); permute_165 = None + view_514 = torch.ops.aten.view.default(mm_105, [2, 8192, 4096]) + convert_element_type_502 = torch.ops.prims.convert_element_type.default(primals_141, torch.bfloat16) + all_gather_into_tensor_138 = 
torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_502, 256, '0'); convert_element_type_502 = None + wait_tensor_138 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_138); all_gather_into_tensor_138 = None + permute_166 = torch.ops.aten.permute.default(wait_tensor_138, [1, 0]); wait_tensor_138 = None + mm_106 = torch.ops.aten.mm.default(view_513, permute_166); permute_166 = None + view_517 = torch.ops.aten.view.default(mm_106, [2, 8192, 1024]); mm_106 = None + convert_element_type_505 = torch.ops.prims.convert_element_type.default(primals_142, torch.bfloat16) + all_gather_into_tensor_139 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_505, 256, '0'); convert_element_type_505 = None + wait_tensor_139 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_139); all_gather_into_tensor_139 = None + permute_167 = torch.ops.aten.permute.default(wait_tensor_139, [1, 0]); wait_tensor_139 = None + mm_107 = torch.ops.aten.mm.default(view_513, permute_167); view_513 = permute_167 = None + view_520 = torch.ops.aten.view.default(mm_107, [2, 8192, 1024]) + view_521 = torch.ops.aten.view.default(view_514, [2, 8192, -1, 128]); view_514 = None + view_522 = torch.ops.aten.view.default(view_517, [2, 8192, -1, 128]); view_517 = None + view_523 = torch.ops.aten.view.default(view_520, [2, 8192, -1, 128]); view_520 = None + convert_element_type_508 = torch.ops.prims.convert_element_type.default(view_521, torch.float32); view_521 = None + view_524 = torch.ops.aten.view.default(convert_element_type_508, [2, 8192, 32, -1, 2]); convert_element_type_508 = None + view_as_complex_30 = torch.ops.aten.view_as_complex.default(view_524); view_524 = None + convert_element_type_509 = torch.ops.prims.convert_element_type.default(view_522, torch.float32); view_522 = None + view_525 = torch.ops.aten.view.default(convert_element_type_509, [2, 8192, 8, -1, 2]); convert_element_type_509 = None + 
view_as_complex_31 = torch.ops.aten.view_as_complex.default(view_525); view_525 = None + mul_122 = torch.ops.aten.mul.Tensor(view_as_complex_30, view_16); view_as_complex_30 = None + view_as_real_30 = torch.ops.aten.view_as_real.default(mul_122); mul_122 = None + view_527 = torch.ops.aten.view.default(view_as_real_30, [2, 8192, 32, 128]); view_as_real_30 = None + mul_123 = torch.ops.aten.mul.Tensor(view_as_complex_31, view_16); view_as_complex_31 = None + view_as_real_31 = torch.ops.aten.view_as_real.default(mul_123); mul_123 = None + view_528 = torch.ops.aten.view.default(view_as_real_31, [2, 8192, 8, 128]); view_as_real_31 = None + convert_element_type_510 = torch.ops.prims.convert_element_type.default(view_527, torch.bfloat16); view_527 = None + convert_element_type_511 = torch.ops.prims.convert_element_type.default(view_528, torch.bfloat16); view_528 = None + unsqueeze_30 = torch.ops.aten.unsqueeze.default(convert_element_type_511, 3); convert_element_type_511 = None + expand_30 = torch.ops.aten.expand.default(unsqueeze_30, [2, 8192, 8, 4, 128]); unsqueeze_30 = None + clone_30 = torch.ops.aten.clone.default(expand_30, memory_format = torch.contiguous_format); expand_30 = None + view_529 = torch.ops.aten.view.default(clone_30, [2, 8192, 32, 128]); clone_30 = None + unsqueeze_31 = torch.ops.aten.unsqueeze.default(view_523, 3); view_523 = None + expand_31 = torch.ops.aten.expand.default(unsqueeze_31, [2, 8192, 8, 4, 128]); unsqueeze_31 = None + clone_31 = torch.ops.aten.clone.default(expand_31, memory_format = torch.contiguous_format); expand_31 = None + view_530 = torch.ops.aten.view.default(clone_31, [2, 8192, 32, 128]); clone_31 = None + permute_168 = torch.ops.aten.permute.default(convert_element_type_510, [0, 2, 1, 3]); convert_element_type_510 = None + permute_169 = torch.ops.aten.permute.default(view_529, [0, 2, 1, 3]); view_529 = None + permute_170 = torch.ops.aten.permute.default(view_530, [0, 2, 1, 3]); view_530 = None + 
_scaled_dot_product_cudnn_attention_15 = torch.ops.aten._scaled_dot_product_cudnn_attention.default(permute_168, permute_169, permute_170, None, True, 0.0, True); permute_168 = permute_169 = permute_170 = None + getitem_135 = _scaled_dot_product_cudnn_attention_15[0] + getitem_136 = _scaled_dot_product_cudnn_attention_15[1] + getitem_141 = _scaled_dot_product_cudnn_attention_15[6] + getitem_142 = _scaled_dot_product_cudnn_attention_15[7]; _scaled_dot_product_cudnn_attention_15 = None + permute_171 = torch.ops.aten.permute.default(getitem_135, [0, 2, 1, 3]) + view_531 = torch.ops.aten.view.default(permute_171, [2, 8192, -1]); permute_171 = None + convert_element_type_512 = torch.ops.prims.convert_element_type.default(primals_143, torch.bfloat16) + all_gather_into_tensor_140 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_512, 256, '0'); convert_element_type_512 = None + wait_tensor_140 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_140); all_gather_into_tensor_140 = None + permute_172 = torch.ops.aten.permute.default(wait_tensor_140, [1, 0]); wait_tensor_140 = None + view_533 = torch.ops.aten.view.default(view_531, [16384, 4096]); view_531 = None + mm_108 = torch.ops.aten.mm.default(view_533, permute_172); view_533 = permute_172 = None + view_534 = torch.ops.aten.view.default(mm_108, [2, 8192, 4096]); mm_108 = None + add_61 = torch.ops.aten.add.Tensor(add_59, view_534); view_534 = None + convert_element_type_515 = torch.ops.prims.convert_element_type.default(primals_144, torch.bfloat16) + all_gather_into_tensor_141 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_515, 256, '0'); convert_element_type_515 = None + wait_tensor_141 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_141); all_gather_into_tensor_141 = None + convert_element_type_516 = torch.ops.prims.convert_element_type.default(add_61, torch.float32) + pow_32 = 
torch.ops.aten.pow.Tensor_Scalar(convert_element_type_516, 2) + mean_31 = torch.ops.aten.mean.dim(pow_32, [2], True); pow_32 = None + add_62 = torch.ops.aten.add.Scalar(mean_31, 1e-05); mean_31 = None + rsqrt_31 = torch.ops.aten.rsqrt.default(add_62); add_62 = None + mul_124 = torch.ops.aten.mul.Tensor(convert_element_type_516, rsqrt_31); convert_element_type_516 = rsqrt_31 = None + mul_125 = torch.ops.aten.mul.Tensor(mul_124, wait_tensor_141); mul_124 = wait_tensor_141 = None + convert_element_type_517 = torch.ops.prims.convert_element_type.default(mul_125, torch.bfloat16); mul_125 = None + convert_element_type_518 = torch.ops.prims.convert_element_type.default(primals_145, torch.bfloat16) + all_gather_into_tensor_142 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_518, 256, '0'); convert_element_type_518 = None + wait_tensor_142 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_142); all_gather_into_tensor_142 = None + permute_173 = torch.ops.aten.permute.default(wait_tensor_142, [1, 0]); wait_tensor_142 = None + view_537 = torch.ops.aten.view.default(convert_element_type_517, [16384, 4096]); convert_element_type_517 = None + mm_109 = torch.ops.aten.mm.default(view_537, permute_173); permute_173 = None + view_538 = torch.ops.aten.view.default(mm_109, [2, 8192, 14336]) + convert_element_type_521 = torch.ops.prims.convert_element_type.default(view_538, torch.float32); view_538 = None + sigmoid_15 = torch.ops.aten.sigmoid.default(convert_element_type_521) + mul_126 = torch.ops.aten.mul.Tensor(convert_element_type_521, sigmoid_15); convert_element_type_521 = sigmoid_15 = None + convert_element_type_522 = torch.ops.prims.convert_element_type.default(mul_126, torch.bfloat16); mul_126 = None + convert_element_type_523 = torch.ops.prims.convert_element_type.default(primals_146, torch.bfloat16) + all_gather_into_tensor_143 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_523, 256, '0'); 
convert_element_type_523 = None + wait_tensor_143 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_143); all_gather_into_tensor_143 = None + permute_174 = torch.ops.aten.permute.default(wait_tensor_143, [1, 0]); wait_tensor_143 = None + mm_110 = torch.ops.aten.mm.default(view_537, permute_174); view_537 = permute_174 = None + view_541 = torch.ops.aten.view.default(mm_110, [2, 8192, 14336]); mm_110 = None + mul_127 = torch.ops.aten.mul.Tensor(convert_element_type_522, view_541); convert_element_type_522 = view_541 = None + convert_element_type_526 = torch.ops.prims.convert_element_type.default(primals_147, torch.bfloat16) + all_gather_into_tensor_144 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_526, 256, '0'); convert_element_type_526 = None + wait_tensor_144 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_144); all_gather_into_tensor_144 = None + permute_175 = torch.ops.aten.permute.default(wait_tensor_144, [1, 0]); wait_tensor_144 = None + view_543 = torch.ops.aten.view.default(mul_127, [16384, 14336]); mul_127 = None + mm_111 = torch.ops.aten.mm.default(view_543, permute_175); view_543 = permute_175 = None + view_544 = torch.ops.aten.view.default(mm_111, [2, 8192, 4096]); mm_111 = None + add_63 = torch.ops.aten.add.Tensor(add_61, view_544); add_61 = view_544 = None + convert_element_type_529 = torch.ops.prims.convert_element_type.default(primals_148, torch.bfloat16) + all_gather_into_tensor_145 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_529, 256, '0'); convert_element_type_529 = None + wait_tensor_145 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_145); all_gather_into_tensor_145 = None + convert_element_type_530 = torch.ops.prims.convert_element_type.default(add_63, torch.float32) + pow_33 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_530, 2) + mean_32 = torch.ops.aten.mean.dim(pow_33, [2], True); pow_33 = 
None + add_64 = torch.ops.aten.add.Scalar(mean_32, 1e-05); mean_32 = None + rsqrt_32 = torch.ops.aten.rsqrt.default(add_64); add_64 = None + mul_128 = torch.ops.aten.mul.Tensor(convert_element_type_530, rsqrt_32); convert_element_type_530 = rsqrt_32 = None + mul_129 = torch.ops.aten.mul.Tensor(mul_128, wait_tensor_145); mul_128 = wait_tensor_145 = None + convert_element_type_531 = torch.ops.prims.convert_element_type.default(mul_129, torch.bfloat16); mul_129 = None + convert_element_type_532 = torch.ops.prims.convert_element_type.default(primals_149, torch.bfloat16) + all_gather_into_tensor_146 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_532, 256, '0'); convert_element_type_532 = None + wait_tensor_146 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_146); all_gather_into_tensor_146 = None + permute_176 = torch.ops.aten.permute.default(wait_tensor_146, [1, 0]); wait_tensor_146 = None + view_547 = torch.ops.aten.view.default(convert_element_type_531, [16384, 4096]); convert_element_type_531 = None + mm_112 = torch.ops.aten.mm.default(view_547, permute_176); permute_176 = None + view_548 = torch.ops.aten.view.default(mm_112, [2, 8192, 4096]) + convert_element_type_535 = torch.ops.prims.convert_element_type.default(primals_150, torch.bfloat16) + all_gather_into_tensor_147 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_535, 256, '0'); convert_element_type_535 = None + wait_tensor_147 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_147); all_gather_into_tensor_147 = None + permute_177 = torch.ops.aten.permute.default(wait_tensor_147, [1, 0]); wait_tensor_147 = None + mm_113 = torch.ops.aten.mm.default(view_547, permute_177); permute_177 = None + view_551 = torch.ops.aten.view.default(mm_113, [2, 8192, 1024]); mm_113 = None + convert_element_type_538 = torch.ops.prims.convert_element_type.default(primals_151, torch.bfloat16) + all_gather_into_tensor_148 = 
torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_538, 256, '0'); convert_element_type_538 = None + wait_tensor_148 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_148); all_gather_into_tensor_148 = None + permute_178 = torch.ops.aten.permute.default(wait_tensor_148, [1, 0]); wait_tensor_148 = None + mm_114 = torch.ops.aten.mm.default(view_547, permute_178); view_547 = permute_178 = None + view_554 = torch.ops.aten.view.default(mm_114, [2, 8192, 1024]) + view_555 = torch.ops.aten.view.default(view_548, [2, 8192, -1, 128]); view_548 = None + view_556 = torch.ops.aten.view.default(view_551, [2, 8192, -1, 128]); view_551 = None + view_557 = torch.ops.aten.view.default(view_554, [2, 8192, -1, 128]); view_554 = None + convert_element_type_541 = torch.ops.prims.convert_element_type.default(view_555, torch.float32); view_555 = None + view_558 = torch.ops.aten.view.default(convert_element_type_541, [2, 8192, 32, -1, 2]); convert_element_type_541 = None + view_as_complex_32 = torch.ops.aten.view_as_complex.default(view_558); view_558 = None + convert_element_type_542 = torch.ops.prims.convert_element_type.default(view_556, torch.float32); view_556 = None + view_559 = torch.ops.aten.view.default(convert_element_type_542, [2, 8192, 8, -1, 2]); convert_element_type_542 = None + view_as_complex_33 = torch.ops.aten.view_as_complex.default(view_559); view_559 = None + mul_130 = torch.ops.aten.mul.Tensor(view_as_complex_32, view_16); view_as_complex_32 = None + view_as_real_32 = torch.ops.aten.view_as_real.default(mul_130); mul_130 = None + view_561 = torch.ops.aten.view.default(view_as_real_32, [2, 8192, 32, 128]); view_as_real_32 = None + mul_131 = torch.ops.aten.mul.Tensor(view_as_complex_33, view_16); view_as_complex_33 = None + view_as_real_33 = torch.ops.aten.view_as_real.default(mul_131); mul_131 = None + view_562 = torch.ops.aten.view.default(view_as_real_33, [2, 8192, 8, 128]); view_as_real_33 = None + 
convert_element_type_543 = torch.ops.prims.convert_element_type.default(view_561, torch.bfloat16); view_561 = None + convert_element_type_544 = torch.ops.prims.convert_element_type.default(view_562, torch.bfloat16); view_562 = None + unsqueeze_32 = torch.ops.aten.unsqueeze.default(convert_element_type_544, 3); convert_element_type_544 = None + expand_32 = torch.ops.aten.expand.default(unsqueeze_32, [2, 8192, 8, 4, 128]); unsqueeze_32 = None + clone_32 = torch.ops.aten.clone.default(expand_32, memory_format = torch.contiguous_format); expand_32 = None + view_563 = torch.ops.aten.view.default(clone_32, [2, 8192, 32, 128]); clone_32 = None + unsqueeze_33 = torch.ops.aten.unsqueeze.default(view_557, 3); view_557 = None + expand_33 = torch.ops.aten.expand.default(unsqueeze_33, [2, 8192, 8, 4, 128]); unsqueeze_33 = None + clone_33 = torch.ops.aten.clone.default(expand_33, memory_format = torch.contiguous_format); expand_33 = None + view_564 = torch.ops.aten.view.default(clone_33, [2, 8192, 32, 128]); clone_33 = None + permute_179 = torch.ops.aten.permute.default(convert_element_type_543, [0, 2, 1, 3]); convert_element_type_543 = None + permute_180 = torch.ops.aten.permute.default(view_563, [0, 2, 1, 3]); view_563 = None + permute_181 = torch.ops.aten.permute.default(view_564, [0, 2, 1, 3]); view_564 = None + _scaled_dot_product_cudnn_attention_16 = torch.ops.aten._scaled_dot_product_cudnn_attention.default(permute_179, permute_180, permute_181, None, True, 0.0, True); permute_179 = permute_180 = permute_181 = None + getitem_144 = _scaled_dot_product_cudnn_attention_16[0] + getitem_145 = _scaled_dot_product_cudnn_attention_16[1] + getitem_150 = _scaled_dot_product_cudnn_attention_16[6] + getitem_151 = _scaled_dot_product_cudnn_attention_16[7]; _scaled_dot_product_cudnn_attention_16 = None + permute_182 = torch.ops.aten.permute.default(getitem_144, [0, 2, 1, 3]) + view_565 = torch.ops.aten.view.default(permute_182, [2, 8192, -1]); permute_182 = None + 
convert_element_type_545 = torch.ops.prims.convert_element_type.default(primals_152, torch.bfloat16) + all_gather_into_tensor_149 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_545, 256, '0'); convert_element_type_545 = None + wait_tensor_149 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_149); all_gather_into_tensor_149 = None + permute_183 = torch.ops.aten.permute.default(wait_tensor_149, [1, 0]); wait_tensor_149 = None + view_567 = torch.ops.aten.view.default(view_565, [16384, 4096]); view_565 = None + mm_115 = torch.ops.aten.mm.default(view_567, permute_183); view_567 = permute_183 = None + view_568 = torch.ops.aten.view.default(mm_115, [2, 8192, 4096]); mm_115 = None + add_65 = torch.ops.aten.add.Tensor(add_63, view_568); view_568 = None + convert_element_type_548 = torch.ops.prims.convert_element_type.default(primals_153, torch.bfloat16) + all_gather_into_tensor_150 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_548, 256, '0'); convert_element_type_548 = None + wait_tensor_150 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_150); all_gather_into_tensor_150 = None + convert_element_type_549 = torch.ops.prims.convert_element_type.default(add_65, torch.float32) + pow_34 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_549, 2) + mean_33 = torch.ops.aten.mean.dim(pow_34, [2], True); pow_34 = None + add_66 = torch.ops.aten.add.Scalar(mean_33, 1e-05); mean_33 = None + rsqrt_33 = torch.ops.aten.rsqrt.default(add_66); add_66 = None + mul_132 = torch.ops.aten.mul.Tensor(convert_element_type_549, rsqrt_33); convert_element_type_549 = rsqrt_33 = None + mul_133 = torch.ops.aten.mul.Tensor(mul_132, wait_tensor_150); mul_132 = wait_tensor_150 = None + convert_element_type_550 = torch.ops.prims.convert_element_type.default(mul_133, torch.bfloat16); mul_133 = None + convert_element_type_551 = torch.ops.prims.convert_element_type.default(primals_154, 
torch.bfloat16) + all_gather_into_tensor_151 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_551, 256, '0'); convert_element_type_551 = None + wait_tensor_151 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_151); all_gather_into_tensor_151 = None + permute_184 = torch.ops.aten.permute.default(wait_tensor_151, [1, 0]); wait_tensor_151 = None + view_571 = torch.ops.aten.view.default(convert_element_type_550, [16384, 4096]); convert_element_type_550 = None + mm_116 = torch.ops.aten.mm.default(view_571, permute_184); permute_184 = None + view_572 = torch.ops.aten.view.default(mm_116, [2, 8192, 14336]) + convert_element_type_554 = torch.ops.prims.convert_element_type.default(view_572, torch.float32); view_572 = None + sigmoid_16 = torch.ops.aten.sigmoid.default(convert_element_type_554) + mul_134 = torch.ops.aten.mul.Tensor(convert_element_type_554, sigmoid_16); convert_element_type_554 = sigmoid_16 = None + convert_element_type_555 = torch.ops.prims.convert_element_type.default(mul_134, torch.bfloat16); mul_134 = None + convert_element_type_556 = torch.ops.prims.convert_element_type.default(primals_155, torch.bfloat16) + all_gather_into_tensor_152 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_556, 256, '0'); convert_element_type_556 = None + wait_tensor_152 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_152); all_gather_into_tensor_152 = None + permute_185 = torch.ops.aten.permute.default(wait_tensor_152, [1, 0]); wait_tensor_152 = None + mm_117 = torch.ops.aten.mm.default(view_571, permute_185); view_571 = permute_185 = None + view_575 = torch.ops.aten.view.default(mm_117, [2, 8192, 14336]); mm_117 = None + mul_135 = torch.ops.aten.mul.Tensor(convert_element_type_555, view_575); convert_element_type_555 = view_575 = None + convert_element_type_559 = torch.ops.prims.convert_element_type.default(primals_156, torch.bfloat16) + all_gather_into_tensor_153 = 
torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_559, 256, '0'); convert_element_type_559 = None + wait_tensor_153 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_153); all_gather_into_tensor_153 = None + permute_186 = torch.ops.aten.permute.default(wait_tensor_153, [1, 0]); wait_tensor_153 = None + view_577 = torch.ops.aten.view.default(mul_135, [16384, 14336]); mul_135 = None + mm_118 = torch.ops.aten.mm.default(view_577, permute_186); view_577 = permute_186 = None + view_578 = torch.ops.aten.view.default(mm_118, [2, 8192, 4096]); mm_118 = None + add_67 = torch.ops.aten.add.Tensor(add_65, view_578); add_65 = view_578 = None + convert_element_type_562 = torch.ops.prims.convert_element_type.default(primals_157, torch.bfloat16) + all_gather_into_tensor_154 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_562, 256, '0'); convert_element_type_562 = None + wait_tensor_154 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_154); all_gather_into_tensor_154 = None + convert_element_type_563 = torch.ops.prims.convert_element_type.default(add_67, torch.float32) + pow_35 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_563, 2) + mean_34 = torch.ops.aten.mean.dim(pow_35, [2], True); pow_35 = None + add_68 = torch.ops.aten.add.Scalar(mean_34, 1e-05); mean_34 = None + rsqrt_34 = torch.ops.aten.rsqrt.default(add_68); add_68 = None + mul_136 = torch.ops.aten.mul.Tensor(convert_element_type_563, rsqrt_34); convert_element_type_563 = rsqrt_34 = None + mul_137 = torch.ops.aten.mul.Tensor(mul_136, wait_tensor_154); mul_136 = wait_tensor_154 = None + convert_element_type_564 = torch.ops.prims.convert_element_type.default(mul_137, torch.bfloat16); mul_137 = None + convert_element_type_565 = torch.ops.prims.convert_element_type.default(primals_158, torch.bfloat16) + all_gather_into_tensor_155 = 
torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_565, 256, '0'); convert_element_type_565 = None + wait_tensor_155 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_155); all_gather_into_tensor_155 = None + permute_187 = torch.ops.aten.permute.default(wait_tensor_155, [1, 0]); wait_tensor_155 = None + view_581 = torch.ops.aten.view.default(convert_element_type_564, [16384, 4096]); convert_element_type_564 = None + mm_119 = torch.ops.aten.mm.default(view_581, permute_187); permute_187 = None + view_582 = torch.ops.aten.view.default(mm_119, [2, 8192, 4096]) + convert_element_type_568 = torch.ops.prims.convert_element_type.default(primals_159, torch.bfloat16) + all_gather_into_tensor_156 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_568, 256, '0'); convert_element_type_568 = None + wait_tensor_156 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_156); all_gather_into_tensor_156 = None + permute_188 = torch.ops.aten.permute.default(wait_tensor_156, [1, 0]); wait_tensor_156 = None + mm_120 = torch.ops.aten.mm.default(view_581, permute_188); permute_188 = None + view_585 = torch.ops.aten.view.default(mm_120, [2, 8192, 1024]); mm_120 = None + convert_element_type_571 = torch.ops.prims.convert_element_type.default(primals_160, torch.bfloat16) + all_gather_into_tensor_157 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_571, 256, '0'); convert_element_type_571 = None + wait_tensor_157 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_157); all_gather_into_tensor_157 = None + permute_189 = torch.ops.aten.permute.default(wait_tensor_157, [1, 0]); wait_tensor_157 = None + mm_121 = torch.ops.aten.mm.default(view_581, permute_189); view_581 = permute_189 = None + view_588 = torch.ops.aten.view.default(mm_121, [2, 8192, 1024]) + view_589 = torch.ops.aten.view.default(view_582, [2, 8192, -1, 128]); view_582 = None + 
view_590 = torch.ops.aten.view.default(view_585, [2, 8192, -1, 128]); view_585 = None + view_591 = torch.ops.aten.view.default(view_588, [2, 8192, -1, 128]); view_588 = None + convert_element_type_574 = torch.ops.prims.convert_element_type.default(view_589, torch.float32); view_589 = None + view_592 = torch.ops.aten.view.default(convert_element_type_574, [2, 8192, 32, -1, 2]); convert_element_type_574 = None + view_as_complex_34 = torch.ops.aten.view_as_complex.default(view_592); view_592 = None + convert_element_type_575 = torch.ops.prims.convert_element_type.default(view_590, torch.float32); view_590 = None + view_593 = torch.ops.aten.view.default(convert_element_type_575, [2, 8192, 8, -1, 2]); convert_element_type_575 = None + view_as_complex_35 = torch.ops.aten.view_as_complex.default(view_593); view_593 = None + mul_138 = torch.ops.aten.mul.Tensor(view_as_complex_34, view_16); view_as_complex_34 = None + view_as_real_34 = torch.ops.aten.view_as_real.default(mul_138); mul_138 = None + view_595 = torch.ops.aten.view.default(view_as_real_34, [2, 8192, 32, 128]); view_as_real_34 = None + mul_139 = torch.ops.aten.mul.Tensor(view_as_complex_35, view_16); view_as_complex_35 = None + view_as_real_35 = torch.ops.aten.view_as_real.default(mul_139); mul_139 = None + view_596 = torch.ops.aten.view.default(view_as_real_35, [2, 8192, 8, 128]); view_as_real_35 = None + convert_element_type_576 = torch.ops.prims.convert_element_type.default(view_595, torch.bfloat16); view_595 = None + convert_element_type_577 = torch.ops.prims.convert_element_type.default(view_596, torch.bfloat16); view_596 = None + unsqueeze_34 = torch.ops.aten.unsqueeze.default(convert_element_type_577, 3); convert_element_type_577 = None + expand_34 = torch.ops.aten.expand.default(unsqueeze_34, [2, 8192, 8, 4, 128]); unsqueeze_34 = None + clone_34 = torch.ops.aten.clone.default(expand_34, memory_format = torch.contiguous_format); expand_34 = None + view_597 = torch.ops.aten.view.default(clone_34, [2, 8192, 
32, 128]); clone_34 = None + unsqueeze_35 = torch.ops.aten.unsqueeze.default(view_591, 3); view_591 = None + expand_35 = torch.ops.aten.expand.default(unsqueeze_35, [2, 8192, 8, 4, 128]); unsqueeze_35 = None + clone_35 = torch.ops.aten.clone.default(expand_35, memory_format = torch.contiguous_format); expand_35 = None + view_598 = torch.ops.aten.view.default(clone_35, [2, 8192, 32, 128]); clone_35 = None + permute_190 = torch.ops.aten.permute.default(convert_element_type_576, [0, 2, 1, 3]); convert_element_type_576 = None + permute_191 = torch.ops.aten.permute.default(view_597, [0, 2, 1, 3]); view_597 = None + permute_192 = torch.ops.aten.permute.default(view_598, [0, 2, 1, 3]); view_598 = None + _scaled_dot_product_cudnn_attention_17 = torch.ops.aten._scaled_dot_product_cudnn_attention.default(permute_190, permute_191, permute_192, None, True, 0.0, True); permute_190 = permute_191 = permute_192 = None + getitem_153 = _scaled_dot_product_cudnn_attention_17[0] + getitem_154 = _scaled_dot_product_cudnn_attention_17[1] + getitem_159 = _scaled_dot_product_cudnn_attention_17[6] + getitem_160 = _scaled_dot_product_cudnn_attention_17[7]; _scaled_dot_product_cudnn_attention_17 = None + permute_193 = torch.ops.aten.permute.default(getitem_153, [0, 2, 1, 3]) + view_599 = torch.ops.aten.view.default(permute_193, [2, 8192, -1]); permute_193 = None + convert_element_type_578 = torch.ops.prims.convert_element_type.default(primals_161, torch.bfloat16) + all_gather_into_tensor_158 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_578, 256, '0'); convert_element_type_578 = None + wait_tensor_158 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_158); all_gather_into_tensor_158 = None + permute_194 = torch.ops.aten.permute.default(wait_tensor_158, [1, 0]); wait_tensor_158 = None + view_601 = torch.ops.aten.view.default(view_599, [16384, 4096]); view_599 = None + mm_122 = torch.ops.aten.mm.default(view_601, permute_194); view_601 = 
permute_194 = None + view_602 = torch.ops.aten.view.default(mm_122, [2, 8192, 4096]); mm_122 = None + add_69 = torch.ops.aten.add.Tensor(add_67, view_602); view_602 = None + convert_element_type_581 = torch.ops.prims.convert_element_type.default(primals_162, torch.bfloat16) + all_gather_into_tensor_159 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_581, 256, '0'); convert_element_type_581 = None + wait_tensor_159 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_159); all_gather_into_tensor_159 = None + convert_element_type_582 = torch.ops.prims.convert_element_type.default(add_69, torch.float32) + pow_36 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_582, 2) + mean_35 = torch.ops.aten.mean.dim(pow_36, [2], True); pow_36 = None + add_70 = torch.ops.aten.add.Scalar(mean_35, 1e-05); mean_35 = None + rsqrt_35 = torch.ops.aten.rsqrt.default(add_70); add_70 = None + mul_140 = torch.ops.aten.mul.Tensor(convert_element_type_582, rsqrt_35); convert_element_type_582 = rsqrt_35 = None + mul_141 = torch.ops.aten.mul.Tensor(mul_140, wait_tensor_159); mul_140 = wait_tensor_159 = None + convert_element_type_583 = torch.ops.prims.convert_element_type.default(mul_141, torch.bfloat16); mul_141 = None + convert_element_type_584 = torch.ops.prims.convert_element_type.default(primals_163, torch.bfloat16) + all_gather_into_tensor_160 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_584, 256, '0'); convert_element_type_584 = None + wait_tensor_160 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_160); all_gather_into_tensor_160 = None + permute_195 = torch.ops.aten.permute.default(wait_tensor_160, [1, 0]); wait_tensor_160 = None + view_605 = torch.ops.aten.view.default(convert_element_type_583, [16384, 4096]); convert_element_type_583 = None + mm_123 = torch.ops.aten.mm.default(view_605, permute_195); permute_195 = None + view_606 = torch.ops.aten.view.default(mm_123, 
[2, 8192, 14336]) + convert_element_type_587 = torch.ops.prims.convert_element_type.default(view_606, torch.float32); view_606 = None + sigmoid_17 = torch.ops.aten.sigmoid.default(convert_element_type_587) + mul_142 = torch.ops.aten.mul.Tensor(convert_element_type_587, sigmoid_17); convert_element_type_587 = sigmoid_17 = None + convert_element_type_588 = torch.ops.prims.convert_element_type.default(mul_142, torch.bfloat16); mul_142 = None + convert_element_type_589 = torch.ops.prims.convert_element_type.default(primals_164, torch.bfloat16) + all_gather_into_tensor_161 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_589, 256, '0'); convert_element_type_589 = None + wait_tensor_161 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_161); all_gather_into_tensor_161 = None + permute_196 = torch.ops.aten.permute.default(wait_tensor_161, [1, 0]); wait_tensor_161 = None + mm_124 = torch.ops.aten.mm.default(view_605, permute_196); view_605 = permute_196 = None + view_609 = torch.ops.aten.view.default(mm_124, [2, 8192, 14336]); mm_124 = None + mul_143 = torch.ops.aten.mul.Tensor(convert_element_type_588, view_609); convert_element_type_588 = view_609 = None + convert_element_type_592 = torch.ops.prims.convert_element_type.default(primals_165, torch.bfloat16) + all_gather_into_tensor_162 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_592, 256, '0'); convert_element_type_592 = None + wait_tensor_162 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_162); all_gather_into_tensor_162 = None + permute_197 = torch.ops.aten.permute.default(wait_tensor_162, [1, 0]); wait_tensor_162 = None + view_611 = torch.ops.aten.view.default(mul_143, [16384, 14336]); mul_143 = None + mm_125 = torch.ops.aten.mm.default(view_611, permute_197); view_611 = permute_197 = None + view_612 = torch.ops.aten.view.default(mm_125, [2, 8192, 4096]); mm_125 = None + add_71 = 
torch.ops.aten.add.Tensor(add_69, view_612); add_69 = view_612 = None + convert_element_type_595 = torch.ops.prims.convert_element_type.default(primals_166, torch.bfloat16) + all_gather_into_tensor_163 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_595, 256, '0'); convert_element_type_595 = None + wait_tensor_163 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_163); all_gather_into_tensor_163 = None + convert_element_type_596 = torch.ops.prims.convert_element_type.default(add_71, torch.float32) + pow_37 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_596, 2) + mean_36 = torch.ops.aten.mean.dim(pow_37, [2], True); pow_37 = None + add_72 = torch.ops.aten.add.Scalar(mean_36, 1e-05); mean_36 = None + rsqrt_36 = torch.ops.aten.rsqrt.default(add_72); add_72 = None + mul_144 = torch.ops.aten.mul.Tensor(convert_element_type_596, rsqrt_36); convert_element_type_596 = rsqrt_36 = None + mul_145 = torch.ops.aten.mul.Tensor(mul_144, wait_tensor_163); mul_144 = wait_tensor_163 = None + convert_element_type_597 = torch.ops.prims.convert_element_type.default(mul_145, torch.bfloat16); mul_145 = None + convert_element_type_598 = torch.ops.prims.convert_element_type.default(primals_167, torch.bfloat16) + all_gather_into_tensor_164 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_598, 256, '0'); convert_element_type_598 = None + wait_tensor_164 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_164); all_gather_into_tensor_164 = None + permute_198 = torch.ops.aten.permute.default(wait_tensor_164, [1, 0]); wait_tensor_164 = None + view_615 = torch.ops.aten.view.default(convert_element_type_597, [16384, 4096]); convert_element_type_597 = None + mm_126 = torch.ops.aten.mm.default(view_615, permute_198); permute_198 = None + view_616 = torch.ops.aten.view.default(mm_126, [2, 8192, 4096]) + convert_element_type_601 = 
torch.ops.prims.convert_element_type.default(primals_168, torch.bfloat16) + all_gather_into_tensor_165 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_601, 256, '0'); convert_element_type_601 = None + wait_tensor_165 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_165); all_gather_into_tensor_165 = None + permute_199 = torch.ops.aten.permute.default(wait_tensor_165, [1, 0]); wait_tensor_165 = None + mm_127 = torch.ops.aten.mm.default(view_615, permute_199); permute_199 = None + view_619 = torch.ops.aten.view.default(mm_127, [2, 8192, 1024]); mm_127 = None + convert_element_type_604 = torch.ops.prims.convert_element_type.default(primals_169, torch.bfloat16) + all_gather_into_tensor_166 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_604, 256, '0'); convert_element_type_604 = None + wait_tensor_166 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_166); all_gather_into_tensor_166 = None + permute_200 = torch.ops.aten.permute.default(wait_tensor_166, [1, 0]); wait_tensor_166 = None + mm_128 = torch.ops.aten.mm.default(view_615, permute_200); view_615 = permute_200 = None + view_622 = torch.ops.aten.view.default(mm_128, [2, 8192, 1024]) + view_623 = torch.ops.aten.view.default(view_616, [2, 8192, -1, 128]); view_616 = None + view_624 = torch.ops.aten.view.default(view_619, [2, 8192, -1, 128]); view_619 = None + view_625 = torch.ops.aten.view.default(view_622, [2, 8192, -1, 128]); view_622 = None + convert_element_type_607 = torch.ops.prims.convert_element_type.default(view_623, torch.float32); view_623 = None + view_626 = torch.ops.aten.view.default(convert_element_type_607, [2, 8192, 32, -1, 2]); convert_element_type_607 = None + view_as_complex_36 = torch.ops.aten.view_as_complex.default(view_626); view_626 = None + convert_element_type_608 = torch.ops.prims.convert_element_type.default(view_624, torch.float32); view_624 = None + view_627 = 
torch.ops.aten.view.default(convert_element_type_608, [2, 8192, 8, -1, 2]); convert_element_type_608 = None + view_as_complex_37 = torch.ops.aten.view_as_complex.default(view_627); view_627 = None + mul_146 = torch.ops.aten.mul.Tensor(view_as_complex_36, view_16); view_as_complex_36 = None + view_as_real_36 = torch.ops.aten.view_as_real.default(mul_146); mul_146 = None + view_629 = torch.ops.aten.view.default(view_as_real_36, [2, 8192, 32, 128]); view_as_real_36 = None + mul_147 = torch.ops.aten.mul.Tensor(view_as_complex_37, view_16); view_as_complex_37 = None + view_as_real_37 = torch.ops.aten.view_as_real.default(mul_147); mul_147 = None + view_630 = torch.ops.aten.view.default(view_as_real_37, [2, 8192, 8, 128]); view_as_real_37 = None + convert_element_type_609 = torch.ops.prims.convert_element_type.default(view_629, torch.bfloat16); view_629 = None + convert_element_type_610 = torch.ops.prims.convert_element_type.default(view_630, torch.bfloat16); view_630 = None + unsqueeze_36 = torch.ops.aten.unsqueeze.default(convert_element_type_610, 3); convert_element_type_610 = None + expand_36 = torch.ops.aten.expand.default(unsqueeze_36, [2, 8192, 8, 4, 128]); unsqueeze_36 = None + clone_36 = torch.ops.aten.clone.default(expand_36, memory_format = torch.contiguous_format); expand_36 = None + view_631 = torch.ops.aten.view.default(clone_36, [2, 8192, 32, 128]); clone_36 = None + unsqueeze_37 = torch.ops.aten.unsqueeze.default(view_625, 3); view_625 = None + expand_37 = torch.ops.aten.expand.default(unsqueeze_37, [2, 8192, 8, 4, 128]); unsqueeze_37 = None + clone_37 = torch.ops.aten.clone.default(expand_37, memory_format = torch.contiguous_format); expand_37 = None + view_632 = torch.ops.aten.view.default(clone_37, [2, 8192, 32, 128]); clone_37 = None + permute_201 = torch.ops.aten.permute.default(convert_element_type_609, [0, 2, 1, 3]); convert_element_type_609 = None + permute_202 = torch.ops.aten.permute.default(view_631, [0, 2, 1, 3]); view_631 = None + permute_203 
= torch.ops.aten.permute.default(view_632, [0, 2, 1, 3]); view_632 = None + _scaled_dot_product_cudnn_attention_18 = torch.ops.aten._scaled_dot_product_cudnn_attention.default(permute_201, permute_202, permute_203, None, True, 0.0, True); permute_201 = permute_202 = permute_203 = None + getitem_162 = _scaled_dot_product_cudnn_attention_18[0] + getitem_163 = _scaled_dot_product_cudnn_attention_18[1] + getitem_168 = _scaled_dot_product_cudnn_attention_18[6] + getitem_169 = _scaled_dot_product_cudnn_attention_18[7]; _scaled_dot_product_cudnn_attention_18 = None + permute_204 = torch.ops.aten.permute.default(getitem_162, [0, 2, 1, 3]) + view_633 = torch.ops.aten.view.default(permute_204, [2, 8192, -1]); permute_204 = None + convert_element_type_611 = torch.ops.prims.convert_element_type.default(primals_170, torch.bfloat16) + all_gather_into_tensor_167 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_611, 256, '0'); convert_element_type_611 = None + wait_tensor_167 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_167); all_gather_into_tensor_167 = None + permute_205 = torch.ops.aten.permute.default(wait_tensor_167, [1, 0]); wait_tensor_167 = None + view_635 = torch.ops.aten.view.default(view_633, [16384, 4096]); view_633 = None + mm_129 = torch.ops.aten.mm.default(view_635, permute_205); view_635 = permute_205 = None + view_636 = torch.ops.aten.view.default(mm_129, [2, 8192, 4096]); mm_129 = None + add_73 = torch.ops.aten.add.Tensor(add_71, view_636); view_636 = None + convert_element_type_614 = torch.ops.prims.convert_element_type.default(primals_171, torch.bfloat16) + all_gather_into_tensor_168 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_614, 256, '0'); convert_element_type_614 = None + wait_tensor_168 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_168); all_gather_into_tensor_168 = None + convert_element_type_615 = 
torch.ops.prims.convert_element_type.default(add_73, torch.float32) + pow_38 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_615, 2) + mean_37 = torch.ops.aten.mean.dim(pow_38, [2], True); pow_38 = None + add_74 = torch.ops.aten.add.Scalar(mean_37, 1e-05); mean_37 = None + rsqrt_37 = torch.ops.aten.rsqrt.default(add_74); add_74 = None + mul_148 = torch.ops.aten.mul.Tensor(convert_element_type_615, rsqrt_37); convert_element_type_615 = rsqrt_37 = None + mul_149 = torch.ops.aten.mul.Tensor(mul_148, wait_tensor_168); mul_148 = wait_tensor_168 = None + convert_element_type_616 = torch.ops.prims.convert_element_type.default(mul_149, torch.bfloat16); mul_149 = None + convert_element_type_617 = torch.ops.prims.convert_element_type.default(primals_172, torch.bfloat16) + all_gather_into_tensor_169 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_617, 256, '0'); convert_element_type_617 = None + wait_tensor_169 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_169); all_gather_into_tensor_169 = None + permute_206 = torch.ops.aten.permute.default(wait_tensor_169, [1, 0]); wait_tensor_169 = None + view_639 = torch.ops.aten.view.default(convert_element_type_616, [16384, 4096]); convert_element_type_616 = None + mm_130 = torch.ops.aten.mm.default(view_639, permute_206); permute_206 = None + view_640 = torch.ops.aten.view.default(mm_130, [2, 8192, 14336]) + convert_element_type_620 = torch.ops.prims.convert_element_type.default(view_640, torch.float32); view_640 = None + sigmoid_18 = torch.ops.aten.sigmoid.default(convert_element_type_620) + mul_150 = torch.ops.aten.mul.Tensor(convert_element_type_620, sigmoid_18); convert_element_type_620 = sigmoid_18 = None + convert_element_type_621 = torch.ops.prims.convert_element_type.default(mul_150, torch.bfloat16); mul_150 = None + convert_element_type_622 = torch.ops.prims.convert_element_type.default(primals_173, torch.bfloat16) + all_gather_into_tensor_170 = 
torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_622, 256, '0'); convert_element_type_622 = None + wait_tensor_170 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_170); all_gather_into_tensor_170 = None + permute_207 = torch.ops.aten.permute.default(wait_tensor_170, [1, 0]); wait_tensor_170 = None + mm_131 = torch.ops.aten.mm.default(view_639, permute_207); view_639 = permute_207 = None + view_643 = torch.ops.aten.view.default(mm_131, [2, 8192, 14336]); mm_131 = None + mul_151 = torch.ops.aten.mul.Tensor(convert_element_type_621, view_643); convert_element_type_621 = view_643 = None + convert_element_type_625 = torch.ops.prims.convert_element_type.default(primals_174, torch.bfloat16) + all_gather_into_tensor_171 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_625, 256, '0'); convert_element_type_625 = None + wait_tensor_171 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_171); all_gather_into_tensor_171 = None + permute_208 = torch.ops.aten.permute.default(wait_tensor_171, [1, 0]); wait_tensor_171 = None + view_645 = torch.ops.aten.view.default(mul_151, [16384, 14336]); mul_151 = None + mm_132 = torch.ops.aten.mm.default(view_645, permute_208); view_645 = permute_208 = None + view_646 = torch.ops.aten.view.default(mm_132, [2, 8192, 4096]); mm_132 = None + add_75 = torch.ops.aten.add.Tensor(add_73, view_646); add_73 = view_646 = None + convert_element_type_628 = torch.ops.prims.convert_element_type.default(primals_175, torch.bfloat16) + all_gather_into_tensor_172 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_628, 256, '0'); convert_element_type_628 = None + wait_tensor_172 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_172); all_gather_into_tensor_172 = None + convert_element_type_629 = torch.ops.prims.convert_element_type.default(add_75, torch.float32) + pow_39 = 
torch.ops.aten.pow.Tensor_Scalar(convert_element_type_629, 2) + mean_38 = torch.ops.aten.mean.dim(pow_39, [2], True); pow_39 = None + add_76 = torch.ops.aten.add.Scalar(mean_38, 1e-05); mean_38 = None + rsqrt_38 = torch.ops.aten.rsqrt.default(add_76); add_76 = None + mul_152 = torch.ops.aten.mul.Tensor(convert_element_type_629, rsqrt_38); convert_element_type_629 = rsqrt_38 = None + mul_153 = torch.ops.aten.mul.Tensor(mul_152, wait_tensor_172); mul_152 = wait_tensor_172 = None + convert_element_type_630 = torch.ops.prims.convert_element_type.default(mul_153, torch.bfloat16); mul_153 = None + convert_element_type_631 = torch.ops.prims.convert_element_type.default(primals_176, torch.bfloat16) + all_gather_into_tensor_173 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_631, 256, '0'); convert_element_type_631 = None + wait_tensor_173 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_173); all_gather_into_tensor_173 = None + permute_209 = torch.ops.aten.permute.default(wait_tensor_173, [1, 0]); wait_tensor_173 = None + view_649 = torch.ops.aten.view.default(convert_element_type_630, [16384, 4096]); convert_element_type_630 = None + mm_133 = torch.ops.aten.mm.default(view_649, permute_209); permute_209 = None + view_650 = torch.ops.aten.view.default(mm_133, [2, 8192, 4096]) + convert_element_type_634 = torch.ops.prims.convert_element_type.default(primals_177, torch.bfloat16) + all_gather_into_tensor_174 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_634, 256, '0'); convert_element_type_634 = None + wait_tensor_174 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_174); all_gather_into_tensor_174 = None + permute_210 = torch.ops.aten.permute.default(wait_tensor_174, [1, 0]); wait_tensor_174 = None + mm_134 = torch.ops.aten.mm.default(view_649, permute_210); permute_210 = None + view_653 = torch.ops.aten.view.default(mm_134, [2, 8192, 1024]); mm_134 = None + 
convert_element_type_637 = torch.ops.prims.convert_element_type.default(primals_178, torch.bfloat16) + all_gather_into_tensor_175 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_637, 256, '0'); convert_element_type_637 = None + wait_tensor_175 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_175); all_gather_into_tensor_175 = None + permute_211 = torch.ops.aten.permute.default(wait_tensor_175, [1, 0]); wait_tensor_175 = None + mm_135 = torch.ops.aten.mm.default(view_649, permute_211); view_649 = permute_211 = None + view_656 = torch.ops.aten.view.default(mm_135, [2, 8192, 1024]) + view_657 = torch.ops.aten.view.default(view_650, [2, 8192, -1, 128]); view_650 = None + view_658 = torch.ops.aten.view.default(view_653, [2, 8192, -1, 128]); view_653 = None + view_659 = torch.ops.aten.view.default(view_656, [2, 8192, -1, 128]); view_656 = None + convert_element_type_640 = torch.ops.prims.convert_element_type.default(view_657, torch.float32); view_657 = None + view_660 = torch.ops.aten.view.default(convert_element_type_640, [2, 8192, 32, -1, 2]); convert_element_type_640 = None + view_as_complex_38 = torch.ops.aten.view_as_complex.default(view_660); view_660 = None + convert_element_type_641 = torch.ops.prims.convert_element_type.default(view_658, torch.float32); view_658 = None + view_661 = torch.ops.aten.view.default(convert_element_type_641, [2, 8192, 8, -1, 2]); convert_element_type_641 = None + view_as_complex_39 = torch.ops.aten.view_as_complex.default(view_661); view_661 = None + mul_154 = torch.ops.aten.mul.Tensor(view_as_complex_38, view_16); view_as_complex_38 = None + view_as_real_38 = torch.ops.aten.view_as_real.default(mul_154); mul_154 = None + view_663 = torch.ops.aten.view.default(view_as_real_38, [2, 8192, 32, 128]); view_as_real_38 = None + mul_155 = torch.ops.aten.mul.Tensor(view_as_complex_39, view_16); view_as_complex_39 = None + view_as_real_39 = torch.ops.aten.view_as_real.default(mul_155); 
mul_155 = None + view_664 = torch.ops.aten.view.default(view_as_real_39, [2, 8192, 8, 128]); view_as_real_39 = None + convert_element_type_642 = torch.ops.prims.convert_element_type.default(view_663, torch.bfloat16); view_663 = None + convert_element_type_643 = torch.ops.prims.convert_element_type.default(view_664, torch.bfloat16); view_664 = None + unsqueeze_38 = torch.ops.aten.unsqueeze.default(convert_element_type_643, 3); convert_element_type_643 = None + expand_38 = torch.ops.aten.expand.default(unsqueeze_38, [2, 8192, 8, 4, 128]); unsqueeze_38 = None + clone_38 = torch.ops.aten.clone.default(expand_38, memory_format = torch.contiguous_format); expand_38 = None + view_665 = torch.ops.aten.view.default(clone_38, [2, 8192, 32, 128]); clone_38 = None + unsqueeze_39 = torch.ops.aten.unsqueeze.default(view_659, 3); view_659 = None + expand_39 = torch.ops.aten.expand.default(unsqueeze_39, [2, 8192, 8, 4, 128]); unsqueeze_39 = None + clone_39 = torch.ops.aten.clone.default(expand_39, memory_format = torch.contiguous_format); expand_39 = None + view_666 = torch.ops.aten.view.default(clone_39, [2, 8192, 32, 128]); clone_39 = None + permute_212 = torch.ops.aten.permute.default(convert_element_type_642, [0, 2, 1, 3]); convert_element_type_642 = None + permute_213 = torch.ops.aten.permute.default(view_665, [0, 2, 1, 3]); view_665 = None + permute_214 = torch.ops.aten.permute.default(view_666, [0, 2, 1, 3]); view_666 = None + _scaled_dot_product_cudnn_attention_19 = torch.ops.aten._scaled_dot_product_cudnn_attention.default(permute_212, permute_213, permute_214, None, True, 0.0, True); permute_212 = permute_213 = permute_214 = None + getitem_171 = _scaled_dot_product_cudnn_attention_19[0] + getitem_172 = _scaled_dot_product_cudnn_attention_19[1] + getitem_177 = _scaled_dot_product_cudnn_attention_19[6] + getitem_178 = _scaled_dot_product_cudnn_attention_19[7]; _scaled_dot_product_cudnn_attention_19 = None + permute_215 = torch.ops.aten.permute.default(getitem_171, [0, 2, 
1, 3]) + view_667 = torch.ops.aten.view.default(permute_215, [2, 8192, -1]); permute_215 = None + convert_element_type_644 = torch.ops.prims.convert_element_type.default(primals_179, torch.bfloat16) + all_gather_into_tensor_176 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_644, 256, '0'); convert_element_type_644 = None + wait_tensor_176 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_176); all_gather_into_tensor_176 = None + permute_216 = torch.ops.aten.permute.default(wait_tensor_176, [1, 0]); wait_tensor_176 = None + view_669 = torch.ops.aten.view.default(view_667, [16384, 4096]); view_667 = None + mm_136 = torch.ops.aten.mm.default(view_669, permute_216); view_669 = permute_216 = None + view_670 = torch.ops.aten.view.default(mm_136, [2, 8192, 4096]); mm_136 = None + add_77 = torch.ops.aten.add.Tensor(add_75, view_670); view_670 = None + convert_element_type_647 = torch.ops.prims.convert_element_type.default(primals_180, torch.bfloat16) + all_gather_into_tensor_177 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_647, 256, '0'); convert_element_type_647 = None + wait_tensor_177 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_177); all_gather_into_tensor_177 = None + convert_element_type_648 = torch.ops.prims.convert_element_type.default(add_77, torch.float32) + pow_40 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_648, 2) + mean_39 = torch.ops.aten.mean.dim(pow_40, [2], True); pow_40 = None + add_78 = torch.ops.aten.add.Scalar(mean_39, 1e-05); mean_39 = None + rsqrt_39 = torch.ops.aten.rsqrt.default(add_78); add_78 = None + mul_156 = torch.ops.aten.mul.Tensor(convert_element_type_648, rsqrt_39); convert_element_type_648 = rsqrt_39 = None + mul_157 = torch.ops.aten.mul.Tensor(mul_156, wait_tensor_177); mul_156 = wait_tensor_177 = None + convert_element_type_649 = torch.ops.prims.convert_element_type.default(mul_157, torch.bfloat16); mul_157 
= None + convert_element_type_650 = torch.ops.prims.convert_element_type.default(primals_181, torch.bfloat16) + all_gather_into_tensor_178 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_650, 256, '0'); convert_element_type_650 = None + wait_tensor_178 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_178); all_gather_into_tensor_178 = None + permute_217 = torch.ops.aten.permute.default(wait_tensor_178, [1, 0]); wait_tensor_178 = None + view_673 = torch.ops.aten.view.default(convert_element_type_649, [16384, 4096]); convert_element_type_649 = None + mm_137 = torch.ops.aten.mm.default(view_673, permute_217); permute_217 = None + view_674 = torch.ops.aten.view.default(mm_137, [2, 8192, 14336]) + convert_element_type_653 = torch.ops.prims.convert_element_type.default(view_674, torch.float32); view_674 = None + sigmoid_19 = torch.ops.aten.sigmoid.default(convert_element_type_653) + mul_158 = torch.ops.aten.mul.Tensor(convert_element_type_653, sigmoid_19); convert_element_type_653 = sigmoid_19 = None + convert_element_type_654 = torch.ops.prims.convert_element_type.default(mul_158, torch.bfloat16); mul_158 = None + convert_element_type_655 = torch.ops.prims.convert_element_type.default(primals_182, torch.bfloat16) + all_gather_into_tensor_179 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_655, 256, '0'); convert_element_type_655 = None + wait_tensor_179 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_179); all_gather_into_tensor_179 = None + permute_218 = torch.ops.aten.permute.default(wait_tensor_179, [1, 0]); wait_tensor_179 = None + mm_138 = torch.ops.aten.mm.default(view_673, permute_218); view_673 = permute_218 = None + view_677 = torch.ops.aten.view.default(mm_138, [2, 8192, 14336]); mm_138 = None + mul_159 = torch.ops.aten.mul.Tensor(convert_element_type_654, view_677); convert_element_type_654 = view_677 = None + convert_element_type_658 = 
torch.ops.prims.convert_element_type.default(primals_183, torch.bfloat16) + all_gather_into_tensor_180 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_658, 256, '0'); convert_element_type_658 = None + wait_tensor_180 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_180); all_gather_into_tensor_180 = None + permute_219 = torch.ops.aten.permute.default(wait_tensor_180, [1, 0]); wait_tensor_180 = None + view_679 = torch.ops.aten.view.default(mul_159, [16384, 14336]); mul_159 = None + mm_139 = torch.ops.aten.mm.default(view_679, permute_219); view_679 = permute_219 = None + view_680 = torch.ops.aten.view.default(mm_139, [2, 8192, 4096]); mm_139 = None + add_79 = torch.ops.aten.add.Tensor(add_77, view_680); add_77 = view_680 = None + convert_element_type_661 = torch.ops.prims.convert_element_type.default(primals_184, torch.bfloat16) + all_gather_into_tensor_181 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_661, 256, '0'); convert_element_type_661 = None + wait_tensor_181 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_181); all_gather_into_tensor_181 = None + convert_element_type_662 = torch.ops.prims.convert_element_type.default(add_79, torch.float32) + pow_41 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_662, 2) + mean_40 = torch.ops.aten.mean.dim(pow_41, [2], True); pow_41 = None + add_80 = torch.ops.aten.add.Scalar(mean_40, 1e-05); mean_40 = None + rsqrt_40 = torch.ops.aten.rsqrt.default(add_80); add_80 = None + mul_160 = torch.ops.aten.mul.Tensor(convert_element_type_662, rsqrt_40); convert_element_type_662 = rsqrt_40 = None + mul_161 = torch.ops.aten.mul.Tensor(mul_160, wait_tensor_181); mul_160 = wait_tensor_181 = None + convert_element_type_663 = torch.ops.prims.convert_element_type.default(mul_161, torch.bfloat16); mul_161 = None + convert_element_type_664 = torch.ops.prims.convert_element_type.default(primals_185, torch.bfloat16) + 
all_gather_into_tensor_182 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_664, 256, '0'); convert_element_type_664 = None + wait_tensor_182 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_182); all_gather_into_tensor_182 = None + permute_220 = torch.ops.aten.permute.default(wait_tensor_182, [1, 0]); wait_tensor_182 = None + view_683 = torch.ops.aten.view.default(convert_element_type_663, [16384, 4096]); convert_element_type_663 = None + mm_140 = torch.ops.aten.mm.default(view_683, permute_220); permute_220 = None + view_684 = torch.ops.aten.view.default(mm_140, [2, 8192, 4096]) + convert_element_type_667 = torch.ops.prims.convert_element_type.default(primals_186, torch.bfloat16) + all_gather_into_tensor_183 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_667, 256, '0'); convert_element_type_667 = None + wait_tensor_183 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_183); all_gather_into_tensor_183 = None + permute_221 = torch.ops.aten.permute.default(wait_tensor_183, [1, 0]); wait_tensor_183 = None + mm_141 = torch.ops.aten.mm.default(view_683, permute_221); permute_221 = None + view_687 = torch.ops.aten.view.default(mm_141, [2, 8192, 1024]); mm_141 = None + convert_element_type_670 = torch.ops.prims.convert_element_type.default(primals_187, torch.bfloat16) + all_gather_into_tensor_184 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_670, 256, '0'); convert_element_type_670 = None + wait_tensor_184 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_184); all_gather_into_tensor_184 = None + permute_222 = torch.ops.aten.permute.default(wait_tensor_184, [1, 0]); wait_tensor_184 = None + mm_142 = torch.ops.aten.mm.default(view_683, permute_222); view_683 = permute_222 = None + view_690 = torch.ops.aten.view.default(mm_142, [2, 8192, 1024]) + view_691 = torch.ops.aten.view.default(view_684, [2, 8192, -1, 
128]); view_684 = None + view_692 = torch.ops.aten.view.default(view_687, [2, 8192, -1, 128]); view_687 = None + view_693 = torch.ops.aten.view.default(view_690, [2, 8192, -1, 128]); view_690 = None + convert_element_type_673 = torch.ops.prims.convert_element_type.default(view_691, torch.float32); view_691 = None + view_694 = torch.ops.aten.view.default(convert_element_type_673, [2, 8192, 32, -1, 2]); convert_element_type_673 = None + view_as_complex_40 = torch.ops.aten.view_as_complex.default(view_694); view_694 = None + convert_element_type_674 = torch.ops.prims.convert_element_type.default(view_692, torch.float32); view_692 = None + view_695 = torch.ops.aten.view.default(convert_element_type_674, [2, 8192, 8, -1, 2]); convert_element_type_674 = None + view_as_complex_41 = torch.ops.aten.view_as_complex.default(view_695); view_695 = None + mul_162 = torch.ops.aten.mul.Tensor(view_as_complex_40, view_16); view_as_complex_40 = None + view_as_real_40 = torch.ops.aten.view_as_real.default(mul_162); mul_162 = None + view_697 = torch.ops.aten.view.default(view_as_real_40, [2, 8192, 32, 128]); view_as_real_40 = None + mul_163 = torch.ops.aten.mul.Tensor(view_as_complex_41, view_16); view_as_complex_41 = None + view_as_real_41 = torch.ops.aten.view_as_real.default(mul_163); mul_163 = None + view_698 = torch.ops.aten.view.default(view_as_real_41, [2, 8192, 8, 128]); view_as_real_41 = None + convert_element_type_675 = torch.ops.prims.convert_element_type.default(view_697, torch.bfloat16); view_697 = None + convert_element_type_676 = torch.ops.prims.convert_element_type.default(view_698, torch.bfloat16); view_698 = None + unsqueeze_40 = torch.ops.aten.unsqueeze.default(convert_element_type_676, 3); convert_element_type_676 = None + expand_40 = torch.ops.aten.expand.default(unsqueeze_40, [2, 8192, 8, 4, 128]); unsqueeze_40 = None + clone_40 = torch.ops.aten.clone.default(expand_40, memory_format = torch.contiguous_format); expand_40 = None + view_699 = 
torch.ops.aten.view.default(clone_40, [2, 8192, 32, 128]); clone_40 = None + unsqueeze_41 = torch.ops.aten.unsqueeze.default(view_693, 3); view_693 = None + expand_41 = torch.ops.aten.expand.default(unsqueeze_41, [2, 8192, 8, 4, 128]); unsqueeze_41 = None + clone_41 = torch.ops.aten.clone.default(expand_41, memory_format = torch.contiguous_format); expand_41 = None + view_700 = torch.ops.aten.view.default(clone_41, [2, 8192, 32, 128]); clone_41 = None + permute_223 = torch.ops.aten.permute.default(convert_element_type_675, [0, 2, 1, 3]); convert_element_type_675 = None + permute_224 = torch.ops.aten.permute.default(view_699, [0, 2, 1, 3]); view_699 = None + permute_225 = torch.ops.aten.permute.default(view_700, [0, 2, 1, 3]); view_700 = None + _scaled_dot_product_cudnn_attention_20 = torch.ops.aten._scaled_dot_product_cudnn_attention.default(permute_223, permute_224, permute_225, None, True, 0.0, True); permute_223 = permute_224 = permute_225 = None + getitem_180 = _scaled_dot_product_cudnn_attention_20[0] + getitem_181 = _scaled_dot_product_cudnn_attention_20[1] + getitem_186 = _scaled_dot_product_cudnn_attention_20[6] + getitem_187 = _scaled_dot_product_cudnn_attention_20[7]; _scaled_dot_product_cudnn_attention_20 = None + permute_226 = torch.ops.aten.permute.default(getitem_180, [0, 2, 1, 3]) + view_701 = torch.ops.aten.view.default(permute_226, [2, 8192, -1]); permute_226 = None + convert_element_type_677 = torch.ops.prims.convert_element_type.default(primals_188, torch.bfloat16) + all_gather_into_tensor_185 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_677, 256, '0'); convert_element_type_677 = None + wait_tensor_185 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_185); all_gather_into_tensor_185 = None + permute_227 = torch.ops.aten.permute.default(wait_tensor_185, [1, 0]); wait_tensor_185 = None + view_703 = torch.ops.aten.view.default(view_701, [16384, 4096]); view_701 = None + mm_143 = 
torch.ops.aten.mm.default(view_703, permute_227); view_703 = permute_227 = None + view_704 = torch.ops.aten.view.default(mm_143, [2, 8192, 4096]); mm_143 = None + add_81 = torch.ops.aten.add.Tensor(add_79, view_704); view_704 = None + convert_element_type_680 = torch.ops.prims.convert_element_type.default(primals_189, torch.bfloat16) + all_gather_into_tensor_186 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_680, 256, '0'); convert_element_type_680 = None + wait_tensor_186 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_186); all_gather_into_tensor_186 = None + convert_element_type_681 = torch.ops.prims.convert_element_type.default(add_81, torch.float32) + pow_42 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_681, 2) + mean_41 = torch.ops.aten.mean.dim(pow_42, [2], True); pow_42 = None + add_82 = torch.ops.aten.add.Scalar(mean_41, 1e-05); mean_41 = None + rsqrt_41 = torch.ops.aten.rsqrt.default(add_82); add_82 = None + mul_164 = torch.ops.aten.mul.Tensor(convert_element_type_681, rsqrt_41); convert_element_type_681 = rsqrt_41 = None + mul_165 = torch.ops.aten.mul.Tensor(mul_164, wait_tensor_186); mul_164 = wait_tensor_186 = None + convert_element_type_682 = torch.ops.prims.convert_element_type.default(mul_165, torch.bfloat16); mul_165 = None + convert_element_type_683 = torch.ops.prims.convert_element_type.default(primals_190, torch.bfloat16) + all_gather_into_tensor_187 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_683, 256, '0'); convert_element_type_683 = None + wait_tensor_187 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_187); all_gather_into_tensor_187 = None + permute_228 = torch.ops.aten.permute.default(wait_tensor_187, [1, 0]); wait_tensor_187 = None + view_707 = torch.ops.aten.view.default(convert_element_type_682, [16384, 4096]); convert_element_type_682 = None + mm_144 = torch.ops.aten.mm.default(view_707, permute_228); 
permute_228 = None + view_708 = torch.ops.aten.view.default(mm_144, [2, 8192, 14336]) + convert_element_type_686 = torch.ops.prims.convert_element_type.default(view_708, torch.float32); view_708 = None + sigmoid_20 = torch.ops.aten.sigmoid.default(convert_element_type_686) + mul_166 = torch.ops.aten.mul.Tensor(convert_element_type_686, sigmoid_20); convert_element_type_686 = sigmoid_20 = None + convert_element_type_687 = torch.ops.prims.convert_element_type.default(mul_166, torch.bfloat16); mul_166 = None + convert_element_type_688 = torch.ops.prims.convert_element_type.default(primals_191, torch.bfloat16) + all_gather_into_tensor_188 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_688, 256, '0'); convert_element_type_688 = None + wait_tensor_188 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_188); all_gather_into_tensor_188 = None + permute_229 = torch.ops.aten.permute.default(wait_tensor_188, [1, 0]); wait_tensor_188 = None + mm_145 = torch.ops.aten.mm.default(view_707, permute_229); view_707 = permute_229 = None + view_711 = torch.ops.aten.view.default(mm_145, [2, 8192, 14336]); mm_145 = None + mul_167 = torch.ops.aten.mul.Tensor(convert_element_type_687, view_711); convert_element_type_687 = view_711 = None + convert_element_type_691 = torch.ops.prims.convert_element_type.default(primals_192, torch.bfloat16) + all_gather_into_tensor_189 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_691, 256, '0'); convert_element_type_691 = None + wait_tensor_189 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_189); all_gather_into_tensor_189 = None + permute_230 = torch.ops.aten.permute.default(wait_tensor_189, [1, 0]); wait_tensor_189 = None + view_713 = torch.ops.aten.view.default(mul_167, [16384, 14336]); mul_167 = None + mm_146 = torch.ops.aten.mm.default(view_713, permute_230); view_713 = permute_230 = None + view_714 = torch.ops.aten.view.default(mm_146, 
[2, 8192, 4096]); mm_146 = None + add_83 = torch.ops.aten.add.Tensor(add_81, view_714); add_81 = view_714 = None + convert_element_type_694 = torch.ops.prims.convert_element_type.default(primals_193, torch.bfloat16) + all_gather_into_tensor_190 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_694, 256, '0'); convert_element_type_694 = None + wait_tensor_190 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_190); all_gather_into_tensor_190 = None + convert_element_type_695 = torch.ops.prims.convert_element_type.default(add_83, torch.float32) + pow_43 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_695, 2) + mean_42 = torch.ops.aten.mean.dim(pow_43, [2], True); pow_43 = None + add_84 = torch.ops.aten.add.Scalar(mean_42, 1e-05); mean_42 = None + rsqrt_42 = torch.ops.aten.rsqrt.default(add_84); add_84 = None + mul_168 = torch.ops.aten.mul.Tensor(convert_element_type_695, rsqrt_42); convert_element_type_695 = rsqrt_42 = None + mul_169 = torch.ops.aten.mul.Tensor(mul_168, wait_tensor_190); mul_168 = wait_tensor_190 = None + convert_element_type_696 = torch.ops.prims.convert_element_type.default(mul_169, torch.bfloat16); mul_169 = None + convert_element_type_697 = torch.ops.prims.convert_element_type.default(primals_194, torch.bfloat16) + all_gather_into_tensor_191 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_697, 256, '0'); convert_element_type_697 = None + wait_tensor_191 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_191); all_gather_into_tensor_191 = None + permute_231 = torch.ops.aten.permute.default(wait_tensor_191, [1, 0]); wait_tensor_191 = None + view_717 = torch.ops.aten.view.default(convert_element_type_696, [16384, 4096]); convert_element_type_696 = None + mm_147 = torch.ops.aten.mm.default(view_717, permute_231); permute_231 = None + view_718 = torch.ops.aten.view.default(mm_147, [2, 8192, 4096]) + convert_element_type_700 = 
torch.ops.prims.convert_element_type.default(primals_195, torch.bfloat16) + all_gather_into_tensor_192 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_700, 256, '0'); convert_element_type_700 = None + wait_tensor_192 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_192); all_gather_into_tensor_192 = None + permute_232 = torch.ops.aten.permute.default(wait_tensor_192, [1, 0]); wait_tensor_192 = None + mm_148 = torch.ops.aten.mm.default(view_717, permute_232); permute_232 = None + view_721 = torch.ops.aten.view.default(mm_148, [2, 8192, 1024]); mm_148 = None + convert_element_type_703 = torch.ops.prims.convert_element_type.default(primals_196, torch.bfloat16) + all_gather_into_tensor_193 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_703, 256, '0'); convert_element_type_703 = None + wait_tensor_193 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_193); all_gather_into_tensor_193 = None + permute_233 = torch.ops.aten.permute.default(wait_tensor_193, [1, 0]); wait_tensor_193 = None + mm_149 = torch.ops.aten.mm.default(view_717, permute_233); view_717 = permute_233 = None + view_724 = torch.ops.aten.view.default(mm_149, [2, 8192, 1024]) + view_725 = torch.ops.aten.view.default(view_718, [2, 8192, -1, 128]); view_718 = None + view_726 = torch.ops.aten.view.default(view_721, [2, 8192, -1, 128]); view_721 = None + view_727 = torch.ops.aten.view.default(view_724, [2, 8192, -1, 128]); view_724 = None + convert_element_type_706 = torch.ops.prims.convert_element_type.default(view_725, torch.float32); view_725 = None + view_728 = torch.ops.aten.view.default(convert_element_type_706, [2, 8192, 32, -1, 2]); convert_element_type_706 = None + view_as_complex_42 = torch.ops.aten.view_as_complex.default(view_728); view_728 = None + convert_element_type_707 = torch.ops.prims.convert_element_type.default(view_726, torch.float32); view_726 = None + view_729 = 
torch.ops.aten.view.default(convert_element_type_707, [2, 8192, 8, -1, 2]); convert_element_type_707 = None + view_as_complex_43 = torch.ops.aten.view_as_complex.default(view_729); view_729 = None + mul_170 = torch.ops.aten.mul.Tensor(view_as_complex_42, view_16); view_as_complex_42 = None + view_as_real_42 = torch.ops.aten.view_as_real.default(mul_170); mul_170 = None + view_731 = torch.ops.aten.view.default(view_as_real_42, [2, 8192, 32, 128]); view_as_real_42 = None + mul_171 = torch.ops.aten.mul.Tensor(view_as_complex_43, view_16); view_as_complex_43 = None + view_as_real_43 = torch.ops.aten.view_as_real.default(mul_171); mul_171 = None + view_732 = torch.ops.aten.view.default(view_as_real_43, [2, 8192, 8, 128]); view_as_real_43 = None + convert_element_type_708 = torch.ops.prims.convert_element_type.default(view_731, torch.bfloat16); view_731 = None + convert_element_type_709 = torch.ops.prims.convert_element_type.default(view_732, torch.bfloat16); view_732 = None + unsqueeze_42 = torch.ops.aten.unsqueeze.default(convert_element_type_709, 3); convert_element_type_709 = None + expand_42 = torch.ops.aten.expand.default(unsqueeze_42, [2, 8192, 8, 4, 128]); unsqueeze_42 = None + clone_42 = torch.ops.aten.clone.default(expand_42, memory_format = torch.contiguous_format); expand_42 = None + view_733 = torch.ops.aten.view.default(clone_42, [2, 8192, 32, 128]); clone_42 = None + unsqueeze_43 = torch.ops.aten.unsqueeze.default(view_727, 3); view_727 = None + expand_43 = torch.ops.aten.expand.default(unsqueeze_43, [2, 8192, 8, 4, 128]); unsqueeze_43 = None + clone_43 = torch.ops.aten.clone.default(expand_43, memory_format = torch.contiguous_format); expand_43 = None + view_734 = torch.ops.aten.view.default(clone_43, [2, 8192, 32, 128]); clone_43 = None + permute_234 = torch.ops.aten.permute.default(convert_element_type_708, [0, 2, 1, 3]); convert_element_type_708 = None + permute_235 = torch.ops.aten.permute.default(view_733, [0, 2, 1, 3]); view_733 = None + permute_236 
= torch.ops.aten.permute.default(view_734, [0, 2, 1, 3]); view_734 = None + _scaled_dot_product_cudnn_attention_21 = torch.ops.aten._scaled_dot_product_cudnn_attention.default(permute_234, permute_235, permute_236, None, True, 0.0, True); permute_234 = permute_235 = permute_236 = None + getitem_189 = _scaled_dot_product_cudnn_attention_21[0] + getitem_190 = _scaled_dot_product_cudnn_attention_21[1] + getitem_195 = _scaled_dot_product_cudnn_attention_21[6] + getitem_196 = _scaled_dot_product_cudnn_attention_21[7]; _scaled_dot_product_cudnn_attention_21 = None + permute_237 = torch.ops.aten.permute.default(getitem_189, [0, 2, 1, 3]) + view_735 = torch.ops.aten.view.default(permute_237, [2, 8192, -1]); permute_237 = None + convert_element_type_710 = torch.ops.prims.convert_element_type.default(primals_197, torch.bfloat16) + all_gather_into_tensor_194 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_710, 256, '0'); convert_element_type_710 = None + wait_tensor_194 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_194); all_gather_into_tensor_194 = None + permute_238 = torch.ops.aten.permute.default(wait_tensor_194, [1, 0]); wait_tensor_194 = None + view_737 = torch.ops.aten.view.default(view_735, [16384, 4096]); view_735 = None + mm_150 = torch.ops.aten.mm.default(view_737, permute_238); view_737 = permute_238 = None + view_738 = torch.ops.aten.view.default(mm_150, [2, 8192, 4096]); mm_150 = None + add_85 = torch.ops.aten.add.Tensor(add_83, view_738); view_738 = None + convert_element_type_713 = torch.ops.prims.convert_element_type.default(primals_198, torch.bfloat16) + all_gather_into_tensor_195 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_713, 256, '0'); convert_element_type_713 = None + wait_tensor_195 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_195); all_gather_into_tensor_195 = None + convert_element_type_714 = 
torch.ops.prims.convert_element_type.default(add_85, torch.float32) + pow_44 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_714, 2) + mean_43 = torch.ops.aten.mean.dim(pow_44, [2], True); pow_44 = None + add_86 = torch.ops.aten.add.Scalar(mean_43, 1e-05); mean_43 = None + rsqrt_43 = torch.ops.aten.rsqrt.default(add_86); add_86 = None + mul_172 = torch.ops.aten.mul.Tensor(convert_element_type_714, rsqrt_43); convert_element_type_714 = rsqrt_43 = None + mul_173 = torch.ops.aten.mul.Tensor(mul_172, wait_tensor_195); mul_172 = wait_tensor_195 = None + convert_element_type_715 = torch.ops.prims.convert_element_type.default(mul_173, torch.bfloat16); mul_173 = None + convert_element_type_716 = torch.ops.prims.convert_element_type.default(primals_199, torch.bfloat16) + all_gather_into_tensor_196 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_716, 256, '0'); convert_element_type_716 = None + wait_tensor_196 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_196); all_gather_into_tensor_196 = None + permute_239 = torch.ops.aten.permute.default(wait_tensor_196, [1, 0]); wait_tensor_196 = None + view_741 = torch.ops.aten.view.default(convert_element_type_715, [16384, 4096]); convert_element_type_715 = None + mm_151 = torch.ops.aten.mm.default(view_741, permute_239); permute_239 = None + view_742 = torch.ops.aten.view.default(mm_151, [2, 8192, 14336]) + convert_element_type_719 = torch.ops.prims.convert_element_type.default(view_742, torch.float32); view_742 = None + sigmoid_21 = torch.ops.aten.sigmoid.default(convert_element_type_719) + mul_174 = torch.ops.aten.mul.Tensor(convert_element_type_719, sigmoid_21); convert_element_type_719 = sigmoid_21 = None + convert_element_type_720 = torch.ops.prims.convert_element_type.default(mul_174, torch.bfloat16); mul_174 = None + convert_element_type_721 = torch.ops.prims.convert_element_type.default(primals_200, torch.bfloat16) + all_gather_into_tensor_197 = 
torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_721, 256, '0'); convert_element_type_721 = None + wait_tensor_197 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_197); all_gather_into_tensor_197 = None + permute_240 = torch.ops.aten.permute.default(wait_tensor_197, [1, 0]); wait_tensor_197 = None + mm_152 = torch.ops.aten.mm.default(view_741, permute_240); view_741 = permute_240 = None + view_745 = torch.ops.aten.view.default(mm_152, [2, 8192, 14336]); mm_152 = None + mul_175 = torch.ops.aten.mul.Tensor(convert_element_type_720, view_745); convert_element_type_720 = view_745 = None + convert_element_type_724 = torch.ops.prims.convert_element_type.default(primals_201, torch.bfloat16) + all_gather_into_tensor_198 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_724, 256, '0'); convert_element_type_724 = None + wait_tensor_198 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_198); all_gather_into_tensor_198 = None + permute_241 = torch.ops.aten.permute.default(wait_tensor_198, [1, 0]); wait_tensor_198 = None + view_747 = torch.ops.aten.view.default(mul_175, [16384, 14336]); mul_175 = None + mm_153 = torch.ops.aten.mm.default(view_747, permute_241); view_747 = permute_241 = None + view_748 = torch.ops.aten.view.default(mm_153, [2, 8192, 4096]); mm_153 = None + add_87 = torch.ops.aten.add.Tensor(add_85, view_748); add_85 = view_748 = None + convert_element_type_727 = torch.ops.prims.convert_element_type.default(primals_202, torch.bfloat16) + all_gather_into_tensor_199 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_727, 256, '0'); convert_element_type_727 = None + wait_tensor_199 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_199); all_gather_into_tensor_199 = None + convert_element_type_728 = torch.ops.prims.convert_element_type.default(add_87, torch.float32) + pow_45 = 
torch.ops.aten.pow.Tensor_Scalar(convert_element_type_728, 2) + mean_44 = torch.ops.aten.mean.dim(pow_45, [2], True); pow_45 = None + add_88 = torch.ops.aten.add.Scalar(mean_44, 1e-05); mean_44 = None + rsqrt_44 = torch.ops.aten.rsqrt.default(add_88); add_88 = None + mul_176 = torch.ops.aten.mul.Tensor(convert_element_type_728, rsqrt_44); convert_element_type_728 = rsqrt_44 = None + mul_177 = torch.ops.aten.mul.Tensor(mul_176, wait_tensor_199); mul_176 = wait_tensor_199 = None + convert_element_type_729 = torch.ops.prims.convert_element_type.default(mul_177, torch.bfloat16); mul_177 = None + convert_element_type_730 = torch.ops.prims.convert_element_type.default(primals_203, torch.bfloat16) + all_gather_into_tensor_200 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_730, 256, '0'); convert_element_type_730 = None + wait_tensor_200 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_200); all_gather_into_tensor_200 = None + permute_242 = torch.ops.aten.permute.default(wait_tensor_200, [1, 0]); wait_tensor_200 = None + view_751 = torch.ops.aten.view.default(convert_element_type_729, [16384, 4096]); convert_element_type_729 = None + mm_154 = torch.ops.aten.mm.default(view_751, permute_242); permute_242 = None + view_752 = torch.ops.aten.view.default(mm_154, [2, 8192, 4096]) + convert_element_type_733 = torch.ops.prims.convert_element_type.default(primals_204, torch.bfloat16) + all_gather_into_tensor_201 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_733, 256, '0'); convert_element_type_733 = None + wait_tensor_201 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_201); all_gather_into_tensor_201 = None + permute_243 = torch.ops.aten.permute.default(wait_tensor_201, [1, 0]); wait_tensor_201 = None + mm_155 = torch.ops.aten.mm.default(view_751, permute_243); permute_243 = None + view_755 = torch.ops.aten.view.default(mm_155, [2, 8192, 1024]); mm_155 = None + 
convert_element_type_736 = torch.ops.prims.convert_element_type.default(primals_205, torch.bfloat16) + all_gather_into_tensor_202 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_736, 256, '0'); convert_element_type_736 = None + wait_tensor_202 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_202); all_gather_into_tensor_202 = None + permute_244 = torch.ops.aten.permute.default(wait_tensor_202, [1, 0]); wait_tensor_202 = None + mm_156 = torch.ops.aten.mm.default(view_751, permute_244); view_751 = permute_244 = None + view_758 = torch.ops.aten.view.default(mm_156, [2, 8192, 1024]) + view_759 = torch.ops.aten.view.default(view_752, [2, 8192, -1, 128]); view_752 = None + view_760 = torch.ops.aten.view.default(view_755, [2, 8192, -1, 128]); view_755 = None + view_761 = torch.ops.aten.view.default(view_758, [2, 8192, -1, 128]); view_758 = None + convert_element_type_739 = torch.ops.prims.convert_element_type.default(view_759, torch.float32); view_759 = None + view_762 = torch.ops.aten.view.default(convert_element_type_739, [2, 8192, 32, -1, 2]); convert_element_type_739 = None + view_as_complex_44 = torch.ops.aten.view_as_complex.default(view_762); view_762 = None + convert_element_type_740 = torch.ops.prims.convert_element_type.default(view_760, torch.float32); view_760 = None + view_763 = torch.ops.aten.view.default(convert_element_type_740, [2, 8192, 8, -1, 2]); convert_element_type_740 = None + view_as_complex_45 = torch.ops.aten.view_as_complex.default(view_763); view_763 = None + mul_178 = torch.ops.aten.mul.Tensor(view_as_complex_44, view_16); view_as_complex_44 = None + view_as_real_44 = torch.ops.aten.view_as_real.default(mul_178); mul_178 = None + view_765 = torch.ops.aten.view.default(view_as_real_44, [2, 8192, 32, 128]); view_as_real_44 = None + mul_179 = torch.ops.aten.mul.Tensor(view_as_complex_45, view_16); view_as_complex_45 = None + view_as_real_45 = torch.ops.aten.view_as_real.default(mul_179); 
mul_179 = None + view_766 = torch.ops.aten.view.default(view_as_real_45, [2, 8192, 8, 128]); view_as_real_45 = None + convert_element_type_741 = torch.ops.prims.convert_element_type.default(view_765, torch.bfloat16); view_765 = None + convert_element_type_742 = torch.ops.prims.convert_element_type.default(view_766, torch.bfloat16); view_766 = None + unsqueeze_44 = torch.ops.aten.unsqueeze.default(convert_element_type_742, 3); convert_element_type_742 = None + expand_44 = torch.ops.aten.expand.default(unsqueeze_44, [2, 8192, 8, 4, 128]); unsqueeze_44 = None + clone_44 = torch.ops.aten.clone.default(expand_44, memory_format = torch.contiguous_format); expand_44 = None + view_767 = torch.ops.aten.view.default(clone_44, [2, 8192, 32, 128]); clone_44 = None + unsqueeze_45 = torch.ops.aten.unsqueeze.default(view_761, 3); view_761 = None + expand_45 = torch.ops.aten.expand.default(unsqueeze_45, [2, 8192, 8, 4, 128]); unsqueeze_45 = None + clone_45 = torch.ops.aten.clone.default(expand_45, memory_format = torch.contiguous_format); expand_45 = None + view_768 = torch.ops.aten.view.default(clone_45, [2, 8192, 32, 128]); clone_45 = None + permute_245 = torch.ops.aten.permute.default(convert_element_type_741, [0, 2, 1, 3]); convert_element_type_741 = None + permute_246 = torch.ops.aten.permute.default(view_767, [0, 2, 1, 3]); view_767 = None + permute_247 = torch.ops.aten.permute.default(view_768, [0, 2, 1, 3]); view_768 = None + _scaled_dot_product_cudnn_attention_22 = torch.ops.aten._scaled_dot_product_cudnn_attention.default(permute_245, permute_246, permute_247, None, True, 0.0, True); permute_245 = permute_246 = permute_247 = None + getitem_198 = _scaled_dot_product_cudnn_attention_22[0] + getitem_199 = _scaled_dot_product_cudnn_attention_22[1] + getitem_204 = _scaled_dot_product_cudnn_attention_22[6] + getitem_205 = _scaled_dot_product_cudnn_attention_22[7]; _scaled_dot_product_cudnn_attention_22 = None + permute_248 = torch.ops.aten.permute.default(getitem_198, [0, 2, 
1, 3]) + view_769 = torch.ops.aten.view.default(permute_248, [2, 8192, -1]); permute_248 = None + convert_element_type_743 = torch.ops.prims.convert_element_type.default(primals_206, torch.bfloat16) + all_gather_into_tensor_203 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_743, 256, '0'); convert_element_type_743 = None + wait_tensor_203 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_203); all_gather_into_tensor_203 = None + permute_249 = torch.ops.aten.permute.default(wait_tensor_203, [1, 0]); wait_tensor_203 = None + view_771 = torch.ops.aten.view.default(view_769, [16384, 4096]); view_769 = None + mm_157 = torch.ops.aten.mm.default(view_771, permute_249); view_771 = permute_249 = None + view_772 = torch.ops.aten.view.default(mm_157, [2, 8192, 4096]); mm_157 = None + add_89 = torch.ops.aten.add.Tensor(add_87, view_772); view_772 = None + convert_element_type_746 = torch.ops.prims.convert_element_type.default(primals_207, torch.bfloat16) + all_gather_into_tensor_204 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_746, 256, '0'); convert_element_type_746 = None + wait_tensor_204 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_204); all_gather_into_tensor_204 = None + convert_element_type_747 = torch.ops.prims.convert_element_type.default(add_89, torch.float32) + pow_46 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_747, 2) + mean_45 = torch.ops.aten.mean.dim(pow_46, [2], True); pow_46 = None + add_90 = torch.ops.aten.add.Scalar(mean_45, 1e-05); mean_45 = None + rsqrt_45 = torch.ops.aten.rsqrt.default(add_90); add_90 = None + mul_180 = torch.ops.aten.mul.Tensor(convert_element_type_747, rsqrt_45); convert_element_type_747 = rsqrt_45 = None + mul_181 = torch.ops.aten.mul.Tensor(mul_180, wait_tensor_204); mul_180 = wait_tensor_204 = None + convert_element_type_748 = torch.ops.prims.convert_element_type.default(mul_181, torch.bfloat16); mul_181 
= None + convert_element_type_749 = torch.ops.prims.convert_element_type.default(primals_208, torch.bfloat16) + all_gather_into_tensor_205 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_749, 256, '0'); convert_element_type_749 = None + wait_tensor_205 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_205); all_gather_into_tensor_205 = None + permute_250 = torch.ops.aten.permute.default(wait_tensor_205, [1, 0]); wait_tensor_205 = None + view_775 = torch.ops.aten.view.default(convert_element_type_748, [16384, 4096]); convert_element_type_748 = None + mm_158 = torch.ops.aten.mm.default(view_775, permute_250); permute_250 = None + view_776 = torch.ops.aten.view.default(mm_158, [2, 8192, 14336]) + convert_element_type_752 = torch.ops.prims.convert_element_type.default(view_776, torch.float32); view_776 = None + sigmoid_22 = torch.ops.aten.sigmoid.default(convert_element_type_752) + mul_182 = torch.ops.aten.mul.Tensor(convert_element_type_752, sigmoid_22); convert_element_type_752 = sigmoid_22 = None + convert_element_type_753 = torch.ops.prims.convert_element_type.default(mul_182, torch.bfloat16); mul_182 = None + convert_element_type_754 = torch.ops.prims.convert_element_type.default(primals_209, torch.bfloat16) + all_gather_into_tensor_206 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_754, 256, '0'); convert_element_type_754 = None + wait_tensor_206 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_206); all_gather_into_tensor_206 = None + permute_251 = torch.ops.aten.permute.default(wait_tensor_206, [1, 0]); wait_tensor_206 = None + mm_159 = torch.ops.aten.mm.default(view_775, permute_251); view_775 = permute_251 = None + view_779 = torch.ops.aten.view.default(mm_159, [2, 8192, 14336]); mm_159 = None + mul_183 = torch.ops.aten.mul.Tensor(convert_element_type_753, view_779); convert_element_type_753 = view_779 = None + convert_element_type_757 = 
torch.ops.prims.convert_element_type.default(primals_210, torch.bfloat16) + all_gather_into_tensor_207 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_757, 256, '0'); convert_element_type_757 = None + wait_tensor_207 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_207); all_gather_into_tensor_207 = None + permute_252 = torch.ops.aten.permute.default(wait_tensor_207, [1, 0]); wait_tensor_207 = None + view_781 = torch.ops.aten.view.default(mul_183, [16384, 14336]); mul_183 = None + mm_160 = torch.ops.aten.mm.default(view_781, permute_252); view_781 = permute_252 = None + view_782 = torch.ops.aten.view.default(mm_160, [2, 8192, 4096]); mm_160 = None + add_91 = torch.ops.aten.add.Tensor(add_89, view_782); add_89 = view_782 = None + convert_element_type_760 = torch.ops.prims.convert_element_type.default(primals_211, torch.bfloat16) + all_gather_into_tensor_208 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_760, 256, '0'); convert_element_type_760 = None + wait_tensor_208 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_208); all_gather_into_tensor_208 = None + convert_element_type_761 = torch.ops.prims.convert_element_type.default(add_91, torch.float32) + pow_47 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_761, 2) + mean_46 = torch.ops.aten.mean.dim(pow_47, [2], True); pow_47 = None + add_92 = torch.ops.aten.add.Scalar(mean_46, 1e-05); mean_46 = None + rsqrt_46 = torch.ops.aten.rsqrt.default(add_92); add_92 = None + mul_184 = torch.ops.aten.mul.Tensor(convert_element_type_761, rsqrt_46); convert_element_type_761 = rsqrt_46 = None + mul_185 = torch.ops.aten.mul.Tensor(mul_184, wait_tensor_208); mul_184 = wait_tensor_208 = None + convert_element_type_762 = torch.ops.prims.convert_element_type.default(mul_185, torch.bfloat16); mul_185 = None + convert_element_type_763 = torch.ops.prims.convert_element_type.default(primals_212, torch.bfloat16) + 
all_gather_into_tensor_209 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_763, 256, '0'); convert_element_type_763 = None + wait_tensor_209 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_209); all_gather_into_tensor_209 = None + permute_253 = torch.ops.aten.permute.default(wait_tensor_209, [1, 0]); wait_tensor_209 = None + view_785 = torch.ops.aten.view.default(convert_element_type_762, [16384, 4096]); convert_element_type_762 = None + mm_161 = torch.ops.aten.mm.default(view_785, permute_253); permute_253 = None + view_786 = torch.ops.aten.view.default(mm_161, [2, 8192, 4096]) + convert_element_type_766 = torch.ops.prims.convert_element_type.default(primals_213, torch.bfloat16) + all_gather_into_tensor_210 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_766, 256, '0'); convert_element_type_766 = None + wait_tensor_210 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_210); all_gather_into_tensor_210 = None + permute_254 = torch.ops.aten.permute.default(wait_tensor_210, [1, 0]); wait_tensor_210 = None + mm_162 = torch.ops.aten.mm.default(view_785, permute_254); permute_254 = None + view_789 = torch.ops.aten.view.default(mm_162, [2, 8192, 1024]); mm_162 = None + convert_element_type_769 = torch.ops.prims.convert_element_type.default(primals_214, torch.bfloat16) + all_gather_into_tensor_211 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_769, 256, '0'); convert_element_type_769 = None + wait_tensor_211 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_211); all_gather_into_tensor_211 = None + permute_255 = torch.ops.aten.permute.default(wait_tensor_211, [1, 0]); wait_tensor_211 = None + mm_163 = torch.ops.aten.mm.default(view_785, permute_255); view_785 = permute_255 = None + view_792 = torch.ops.aten.view.default(mm_163, [2, 8192, 1024]) + view_793 = torch.ops.aten.view.default(view_786, [2, 8192, -1, 
128]); view_786 = None + view_794 = torch.ops.aten.view.default(view_789, [2, 8192, -1, 128]); view_789 = None + view_795 = torch.ops.aten.view.default(view_792, [2, 8192, -1, 128]); view_792 = None + convert_element_type_772 = torch.ops.prims.convert_element_type.default(view_793, torch.float32); view_793 = None + view_796 = torch.ops.aten.view.default(convert_element_type_772, [2, 8192, 32, -1, 2]); convert_element_type_772 = None + view_as_complex_46 = torch.ops.aten.view_as_complex.default(view_796); view_796 = None + convert_element_type_773 = torch.ops.prims.convert_element_type.default(view_794, torch.float32); view_794 = None + view_797 = torch.ops.aten.view.default(convert_element_type_773, [2, 8192, 8, -1, 2]); convert_element_type_773 = None + view_as_complex_47 = torch.ops.aten.view_as_complex.default(view_797); view_797 = None + mul_186 = torch.ops.aten.mul.Tensor(view_as_complex_46, view_16); view_as_complex_46 = None + view_as_real_46 = torch.ops.aten.view_as_real.default(mul_186); mul_186 = None + view_799 = torch.ops.aten.view.default(view_as_real_46, [2, 8192, 32, 128]); view_as_real_46 = None + mul_187 = torch.ops.aten.mul.Tensor(view_as_complex_47, view_16); view_as_complex_47 = None + view_as_real_47 = torch.ops.aten.view_as_real.default(mul_187); mul_187 = None + view_800 = torch.ops.aten.view.default(view_as_real_47, [2, 8192, 8, 128]); view_as_real_47 = None + convert_element_type_774 = torch.ops.prims.convert_element_type.default(view_799, torch.bfloat16); view_799 = None + convert_element_type_775 = torch.ops.prims.convert_element_type.default(view_800, torch.bfloat16); view_800 = None + unsqueeze_46 = torch.ops.aten.unsqueeze.default(convert_element_type_775, 3); convert_element_type_775 = None + expand_46 = torch.ops.aten.expand.default(unsqueeze_46, [2, 8192, 8, 4, 128]); unsqueeze_46 = None + clone_46 = torch.ops.aten.clone.default(expand_46, memory_format = torch.contiguous_format); expand_46 = None + view_801 = 
torch.ops.aten.view.default(clone_46, [2, 8192, 32, 128]); clone_46 = None + unsqueeze_47 = torch.ops.aten.unsqueeze.default(view_795, 3); view_795 = None + expand_47 = torch.ops.aten.expand.default(unsqueeze_47, [2, 8192, 8, 4, 128]); unsqueeze_47 = None + clone_47 = torch.ops.aten.clone.default(expand_47, memory_format = torch.contiguous_format); expand_47 = None + view_802 = torch.ops.aten.view.default(clone_47, [2, 8192, 32, 128]); clone_47 = None + permute_256 = torch.ops.aten.permute.default(convert_element_type_774, [0, 2, 1, 3]); convert_element_type_774 = None + permute_257 = torch.ops.aten.permute.default(view_801, [0, 2, 1, 3]); view_801 = None + permute_258 = torch.ops.aten.permute.default(view_802, [0, 2, 1, 3]); view_802 = None + _scaled_dot_product_cudnn_attention_23 = torch.ops.aten._scaled_dot_product_cudnn_attention.default(permute_256, permute_257, permute_258, None, True, 0.0, True); permute_256 = permute_257 = permute_258 = None + getitem_207 = _scaled_dot_product_cudnn_attention_23[0] + getitem_208 = _scaled_dot_product_cudnn_attention_23[1] + getitem_213 = _scaled_dot_product_cudnn_attention_23[6] + getitem_214 = _scaled_dot_product_cudnn_attention_23[7]; _scaled_dot_product_cudnn_attention_23 = None + permute_259 = torch.ops.aten.permute.default(getitem_207, [0, 2, 1, 3]) + view_803 = torch.ops.aten.view.default(permute_259, [2, 8192, -1]); permute_259 = None + convert_element_type_776 = torch.ops.prims.convert_element_type.default(primals_215, torch.bfloat16) + all_gather_into_tensor_212 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_776, 256, '0'); convert_element_type_776 = None + wait_tensor_212 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_212); all_gather_into_tensor_212 = None + permute_260 = torch.ops.aten.permute.default(wait_tensor_212, [1, 0]); wait_tensor_212 = None + view_805 = torch.ops.aten.view.default(view_803, [16384, 4096]); view_803 = None + mm_164 = 
torch.ops.aten.mm.default(view_805, permute_260); view_805 = permute_260 = None + view_806 = torch.ops.aten.view.default(mm_164, [2, 8192, 4096]); mm_164 = None + add_93 = torch.ops.aten.add.Tensor(add_91, view_806); view_806 = None + convert_element_type_779 = torch.ops.prims.convert_element_type.default(primals_216, torch.bfloat16) + all_gather_into_tensor_213 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_779, 256, '0'); convert_element_type_779 = None + wait_tensor_213 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_213); all_gather_into_tensor_213 = None + convert_element_type_780 = torch.ops.prims.convert_element_type.default(add_93, torch.float32) + pow_48 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_780, 2) + mean_47 = torch.ops.aten.mean.dim(pow_48, [2], True); pow_48 = None + add_94 = torch.ops.aten.add.Scalar(mean_47, 1e-05); mean_47 = None + rsqrt_47 = torch.ops.aten.rsqrt.default(add_94); add_94 = None + mul_188 = torch.ops.aten.mul.Tensor(convert_element_type_780, rsqrt_47); convert_element_type_780 = rsqrt_47 = None + mul_189 = torch.ops.aten.mul.Tensor(mul_188, wait_tensor_213); mul_188 = wait_tensor_213 = None + convert_element_type_781 = torch.ops.prims.convert_element_type.default(mul_189, torch.bfloat16); mul_189 = None + convert_element_type_782 = torch.ops.prims.convert_element_type.default(primals_217, torch.bfloat16) + all_gather_into_tensor_214 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_782, 256, '0'); convert_element_type_782 = None + wait_tensor_214 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_214); all_gather_into_tensor_214 = None + permute_261 = torch.ops.aten.permute.default(wait_tensor_214, [1, 0]); wait_tensor_214 = None + view_809 = torch.ops.aten.view.default(convert_element_type_781, [16384, 4096]); convert_element_type_781 = None + mm_165 = torch.ops.aten.mm.default(view_809, permute_261); 
permute_261 = None + view_810 = torch.ops.aten.view.default(mm_165, [2, 8192, 14336]) + convert_element_type_785 = torch.ops.prims.convert_element_type.default(view_810, torch.float32); view_810 = None + sigmoid_23 = torch.ops.aten.sigmoid.default(convert_element_type_785) + mul_190 = torch.ops.aten.mul.Tensor(convert_element_type_785, sigmoid_23); convert_element_type_785 = sigmoid_23 = None + convert_element_type_786 = torch.ops.prims.convert_element_type.default(mul_190, torch.bfloat16); mul_190 = None + convert_element_type_787 = torch.ops.prims.convert_element_type.default(primals_218, torch.bfloat16) + all_gather_into_tensor_215 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_787, 256, '0'); convert_element_type_787 = None + wait_tensor_215 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_215); all_gather_into_tensor_215 = None + permute_262 = torch.ops.aten.permute.default(wait_tensor_215, [1, 0]); wait_tensor_215 = None + mm_166 = torch.ops.aten.mm.default(view_809, permute_262); view_809 = permute_262 = None + view_813 = torch.ops.aten.view.default(mm_166, [2, 8192, 14336]); mm_166 = None + mul_191 = torch.ops.aten.mul.Tensor(convert_element_type_786, view_813); convert_element_type_786 = view_813 = None + convert_element_type_790 = torch.ops.prims.convert_element_type.default(primals_219, torch.bfloat16) + all_gather_into_tensor_216 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_790, 256, '0'); convert_element_type_790 = None + wait_tensor_216 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_216); all_gather_into_tensor_216 = None + permute_263 = torch.ops.aten.permute.default(wait_tensor_216, [1, 0]); wait_tensor_216 = None + view_815 = torch.ops.aten.view.default(mul_191, [16384, 14336]); mul_191 = None + mm_167 = torch.ops.aten.mm.default(view_815, permute_263); view_815 = permute_263 = None + view_816 = torch.ops.aten.view.default(mm_167, 
[2, 8192, 4096]); mm_167 = None + add_95 = torch.ops.aten.add.Tensor(add_93, view_816); add_93 = view_816 = None + convert_element_type_793 = torch.ops.prims.convert_element_type.default(primals_220, torch.bfloat16) + all_gather_into_tensor_217 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_793, 256, '0'); convert_element_type_793 = None + wait_tensor_217 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_217); all_gather_into_tensor_217 = None + convert_element_type_794 = torch.ops.prims.convert_element_type.default(add_95, torch.float32) + pow_49 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_794, 2) + mean_48 = torch.ops.aten.mean.dim(pow_49, [2], True); pow_49 = None + add_96 = torch.ops.aten.add.Scalar(mean_48, 1e-05); mean_48 = None + rsqrt_48 = torch.ops.aten.rsqrt.default(add_96); add_96 = None + mul_192 = torch.ops.aten.mul.Tensor(convert_element_type_794, rsqrt_48); convert_element_type_794 = rsqrt_48 = None + mul_193 = torch.ops.aten.mul.Tensor(mul_192, wait_tensor_217); mul_192 = wait_tensor_217 = None + convert_element_type_795 = torch.ops.prims.convert_element_type.default(mul_193, torch.bfloat16); mul_193 = None + convert_element_type_796 = torch.ops.prims.convert_element_type.default(primals_221, torch.bfloat16) + all_gather_into_tensor_218 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_796, 256, '0'); convert_element_type_796 = None + wait_tensor_218 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_218); all_gather_into_tensor_218 = None + permute_264 = torch.ops.aten.permute.default(wait_tensor_218, [1, 0]); wait_tensor_218 = None + view_819 = torch.ops.aten.view.default(convert_element_type_795, [16384, 4096]); convert_element_type_795 = None + mm_168 = torch.ops.aten.mm.default(view_819, permute_264); permute_264 = None + view_820 = torch.ops.aten.view.default(mm_168, [2, 8192, 4096]) + convert_element_type_799 = 
torch.ops.prims.convert_element_type.default(primals_222, torch.bfloat16) + all_gather_into_tensor_219 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_799, 256, '0'); convert_element_type_799 = None + wait_tensor_219 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_219); all_gather_into_tensor_219 = None + permute_265 = torch.ops.aten.permute.default(wait_tensor_219, [1, 0]); wait_tensor_219 = None + mm_169 = torch.ops.aten.mm.default(view_819, permute_265); permute_265 = None + view_823 = torch.ops.aten.view.default(mm_169, [2, 8192, 1024]); mm_169 = None + convert_element_type_802 = torch.ops.prims.convert_element_type.default(primals_223, torch.bfloat16) + all_gather_into_tensor_220 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_802, 256, '0'); convert_element_type_802 = None + wait_tensor_220 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_220); all_gather_into_tensor_220 = None + permute_266 = torch.ops.aten.permute.default(wait_tensor_220, [1, 0]); wait_tensor_220 = None + mm_170 = torch.ops.aten.mm.default(view_819, permute_266); view_819 = permute_266 = None + view_826 = torch.ops.aten.view.default(mm_170, [2, 8192, 1024]) + view_827 = torch.ops.aten.view.default(view_820, [2, 8192, -1, 128]); view_820 = None + view_828 = torch.ops.aten.view.default(view_823, [2, 8192, -1, 128]); view_823 = None + view_829 = torch.ops.aten.view.default(view_826, [2, 8192, -1, 128]); view_826 = None + convert_element_type_805 = torch.ops.prims.convert_element_type.default(view_827, torch.float32); view_827 = None + view_830 = torch.ops.aten.view.default(convert_element_type_805, [2, 8192, 32, -1, 2]); convert_element_type_805 = None + view_as_complex_48 = torch.ops.aten.view_as_complex.default(view_830); view_830 = None + convert_element_type_806 = torch.ops.prims.convert_element_type.default(view_828, torch.float32); view_828 = None + view_831 = 
torch.ops.aten.view.default(convert_element_type_806, [2, 8192, 8, -1, 2]); convert_element_type_806 = None + view_as_complex_49 = torch.ops.aten.view_as_complex.default(view_831); view_831 = None + mul_194 = torch.ops.aten.mul.Tensor(view_as_complex_48, view_16); view_as_complex_48 = None + view_as_real_48 = torch.ops.aten.view_as_real.default(mul_194); mul_194 = None + view_833 = torch.ops.aten.view.default(view_as_real_48, [2, 8192, 32, 128]); view_as_real_48 = None + mul_195 = torch.ops.aten.mul.Tensor(view_as_complex_49, view_16); view_as_complex_49 = None + view_as_real_49 = torch.ops.aten.view_as_real.default(mul_195); mul_195 = None + view_834 = torch.ops.aten.view.default(view_as_real_49, [2, 8192, 8, 128]); view_as_real_49 = None + convert_element_type_807 = torch.ops.prims.convert_element_type.default(view_833, torch.bfloat16); view_833 = None + convert_element_type_808 = torch.ops.prims.convert_element_type.default(view_834, torch.bfloat16); view_834 = None + unsqueeze_48 = torch.ops.aten.unsqueeze.default(convert_element_type_808, 3); convert_element_type_808 = None + expand_48 = torch.ops.aten.expand.default(unsqueeze_48, [2, 8192, 8, 4, 128]); unsqueeze_48 = None + clone_48 = torch.ops.aten.clone.default(expand_48, memory_format = torch.contiguous_format); expand_48 = None + view_835 = torch.ops.aten.view.default(clone_48, [2, 8192, 32, 128]); clone_48 = None + unsqueeze_49 = torch.ops.aten.unsqueeze.default(view_829, 3); view_829 = None + expand_49 = torch.ops.aten.expand.default(unsqueeze_49, [2, 8192, 8, 4, 128]); unsqueeze_49 = None + clone_49 = torch.ops.aten.clone.default(expand_49, memory_format = torch.contiguous_format); expand_49 = None + view_836 = torch.ops.aten.view.default(clone_49, [2, 8192, 32, 128]); clone_49 = None + permute_267 = torch.ops.aten.permute.default(convert_element_type_807, [0, 2, 1, 3]); convert_element_type_807 = None + permute_268 = torch.ops.aten.permute.default(view_835, [0, 2, 1, 3]); view_835 = None + permute_269 
= torch.ops.aten.permute.default(view_836, [0, 2, 1, 3]); view_836 = None + _scaled_dot_product_cudnn_attention_24 = torch.ops.aten._scaled_dot_product_cudnn_attention.default(permute_267, permute_268, permute_269, None, True, 0.0, True); permute_267 = permute_268 = permute_269 = None + getitem_216 = _scaled_dot_product_cudnn_attention_24[0] + getitem_217 = _scaled_dot_product_cudnn_attention_24[1] + getitem_222 = _scaled_dot_product_cudnn_attention_24[6] + getitem_223 = _scaled_dot_product_cudnn_attention_24[7]; _scaled_dot_product_cudnn_attention_24 = None + permute_270 = torch.ops.aten.permute.default(getitem_216, [0, 2, 1, 3]) + view_837 = torch.ops.aten.view.default(permute_270, [2, 8192, -1]); permute_270 = None + convert_element_type_809 = torch.ops.prims.convert_element_type.default(primals_224, torch.bfloat16) + all_gather_into_tensor_221 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_809, 256, '0'); convert_element_type_809 = None + wait_tensor_221 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_221); all_gather_into_tensor_221 = None + permute_271 = torch.ops.aten.permute.default(wait_tensor_221, [1, 0]); wait_tensor_221 = None + view_839 = torch.ops.aten.view.default(view_837, [16384, 4096]); view_837 = None + mm_171 = torch.ops.aten.mm.default(view_839, permute_271); view_839 = permute_271 = None + view_840 = torch.ops.aten.view.default(mm_171, [2, 8192, 4096]); mm_171 = None + add_97 = torch.ops.aten.add.Tensor(add_95, view_840); view_840 = None + convert_element_type_812 = torch.ops.prims.convert_element_type.default(primals_225, torch.bfloat16) + all_gather_into_tensor_222 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_812, 256, '0'); convert_element_type_812 = None + wait_tensor_222 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_222); all_gather_into_tensor_222 = None + convert_element_type_813 = 
torch.ops.prims.convert_element_type.default(add_97, torch.float32) + pow_50 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_813, 2) + mean_49 = torch.ops.aten.mean.dim(pow_50, [2], True); pow_50 = None + add_98 = torch.ops.aten.add.Scalar(mean_49, 1e-05); mean_49 = None + rsqrt_49 = torch.ops.aten.rsqrt.default(add_98); add_98 = None + mul_196 = torch.ops.aten.mul.Tensor(convert_element_type_813, rsqrt_49); convert_element_type_813 = rsqrt_49 = None + mul_197 = torch.ops.aten.mul.Tensor(mul_196, wait_tensor_222); mul_196 = wait_tensor_222 = None + convert_element_type_814 = torch.ops.prims.convert_element_type.default(mul_197, torch.bfloat16); mul_197 = None + convert_element_type_815 = torch.ops.prims.convert_element_type.default(primals_226, torch.bfloat16) + all_gather_into_tensor_223 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_815, 256, '0'); convert_element_type_815 = None + wait_tensor_223 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_223); all_gather_into_tensor_223 = None + permute_272 = torch.ops.aten.permute.default(wait_tensor_223, [1, 0]); wait_tensor_223 = None + view_843 = torch.ops.aten.view.default(convert_element_type_814, [16384, 4096]); convert_element_type_814 = None + mm_172 = torch.ops.aten.mm.default(view_843, permute_272); permute_272 = None + view_844 = torch.ops.aten.view.default(mm_172, [2, 8192, 14336]) + convert_element_type_818 = torch.ops.prims.convert_element_type.default(view_844, torch.float32); view_844 = None + sigmoid_24 = torch.ops.aten.sigmoid.default(convert_element_type_818) + mul_198 = torch.ops.aten.mul.Tensor(convert_element_type_818, sigmoid_24); convert_element_type_818 = sigmoid_24 = None + convert_element_type_819 = torch.ops.prims.convert_element_type.default(mul_198, torch.bfloat16); mul_198 = None + convert_element_type_820 = torch.ops.prims.convert_element_type.default(primals_227, torch.bfloat16) + all_gather_into_tensor_224 = 
torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_820, 256, '0'); convert_element_type_820 = None + wait_tensor_224 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_224); all_gather_into_tensor_224 = None + permute_273 = torch.ops.aten.permute.default(wait_tensor_224, [1, 0]); wait_tensor_224 = None + mm_173 = torch.ops.aten.mm.default(view_843, permute_273); view_843 = permute_273 = None + view_847 = torch.ops.aten.view.default(mm_173, [2, 8192, 14336]); mm_173 = None + mul_199 = torch.ops.aten.mul.Tensor(convert_element_type_819, view_847); convert_element_type_819 = view_847 = None + convert_element_type_823 = torch.ops.prims.convert_element_type.default(primals_228, torch.bfloat16) + all_gather_into_tensor_225 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_823, 256, '0'); convert_element_type_823 = None + wait_tensor_225 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_225); all_gather_into_tensor_225 = None + permute_274 = torch.ops.aten.permute.default(wait_tensor_225, [1, 0]); wait_tensor_225 = None + view_849 = torch.ops.aten.view.default(mul_199, [16384, 14336]); mul_199 = None + mm_174 = torch.ops.aten.mm.default(view_849, permute_274); view_849 = permute_274 = None + view_850 = torch.ops.aten.view.default(mm_174, [2, 8192, 4096]); mm_174 = None + add_99 = torch.ops.aten.add.Tensor(add_97, view_850); add_97 = view_850 = None + convert_element_type_826 = torch.ops.prims.convert_element_type.default(primals_229, torch.bfloat16) + all_gather_into_tensor_226 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_826, 256, '0'); convert_element_type_826 = None + wait_tensor_226 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_226); all_gather_into_tensor_226 = None + convert_element_type_827 = torch.ops.prims.convert_element_type.default(add_99, torch.float32) + pow_51 = 
torch.ops.aten.pow.Tensor_Scalar(convert_element_type_827, 2) + mean_50 = torch.ops.aten.mean.dim(pow_51, [2], True); pow_51 = None + add_100 = torch.ops.aten.add.Scalar(mean_50, 1e-05); mean_50 = None + rsqrt_50 = torch.ops.aten.rsqrt.default(add_100); add_100 = None + mul_200 = torch.ops.aten.mul.Tensor(convert_element_type_827, rsqrt_50); convert_element_type_827 = rsqrt_50 = None + mul_201 = torch.ops.aten.mul.Tensor(mul_200, wait_tensor_226); mul_200 = wait_tensor_226 = None + convert_element_type_828 = torch.ops.prims.convert_element_type.default(mul_201, torch.bfloat16); mul_201 = None + convert_element_type_829 = torch.ops.prims.convert_element_type.default(primals_230, torch.bfloat16) + all_gather_into_tensor_227 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_829, 256, '0'); convert_element_type_829 = None + wait_tensor_227 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_227); all_gather_into_tensor_227 = None + permute_275 = torch.ops.aten.permute.default(wait_tensor_227, [1, 0]); wait_tensor_227 = None + view_853 = torch.ops.aten.view.default(convert_element_type_828, [16384, 4096]); convert_element_type_828 = None + mm_175 = torch.ops.aten.mm.default(view_853, permute_275); permute_275 = None + view_854 = torch.ops.aten.view.default(mm_175, [2, 8192, 4096]) + convert_element_type_832 = torch.ops.prims.convert_element_type.default(primals_231, torch.bfloat16) + all_gather_into_tensor_228 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_832, 256, '0'); convert_element_type_832 = None + wait_tensor_228 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_228); all_gather_into_tensor_228 = None + permute_276 = torch.ops.aten.permute.default(wait_tensor_228, [1, 0]); wait_tensor_228 = None + mm_176 = torch.ops.aten.mm.default(view_853, permute_276); permute_276 = None + view_857 = torch.ops.aten.view.default(mm_176, [2, 8192, 1024]); mm_176 = None + 
convert_element_type_835 = torch.ops.prims.convert_element_type.default(primals_232, torch.bfloat16) + all_gather_into_tensor_229 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_835, 256, '0'); convert_element_type_835 = None + wait_tensor_229 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_229); all_gather_into_tensor_229 = None + permute_277 = torch.ops.aten.permute.default(wait_tensor_229, [1, 0]); wait_tensor_229 = None + mm_177 = torch.ops.aten.mm.default(view_853, permute_277); view_853 = permute_277 = None + view_860 = torch.ops.aten.view.default(mm_177, [2, 8192, 1024]) + view_861 = torch.ops.aten.view.default(view_854, [2, 8192, -1, 128]); view_854 = None + view_862 = torch.ops.aten.view.default(view_857, [2, 8192, -1, 128]); view_857 = None + view_863 = torch.ops.aten.view.default(view_860, [2, 8192, -1, 128]); view_860 = None + convert_element_type_838 = torch.ops.prims.convert_element_type.default(view_861, torch.float32); view_861 = None + view_864 = torch.ops.aten.view.default(convert_element_type_838, [2, 8192, 32, -1, 2]); convert_element_type_838 = None + view_as_complex_50 = torch.ops.aten.view_as_complex.default(view_864); view_864 = None + convert_element_type_839 = torch.ops.prims.convert_element_type.default(view_862, torch.float32); view_862 = None + view_865 = torch.ops.aten.view.default(convert_element_type_839, [2, 8192, 8, -1, 2]); convert_element_type_839 = None + view_as_complex_51 = torch.ops.aten.view_as_complex.default(view_865); view_865 = None + mul_202 = torch.ops.aten.mul.Tensor(view_as_complex_50, view_16); view_as_complex_50 = None + view_as_real_50 = torch.ops.aten.view_as_real.default(mul_202); mul_202 = None + view_867 = torch.ops.aten.view.default(view_as_real_50, [2, 8192, 32, 128]); view_as_real_50 = None + mul_203 = torch.ops.aten.mul.Tensor(view_as_complex_51, view_16); view_as_complex_51 = None + view_as_real_51 = torch.ops.aten.view_as_real.default(mul_203); 
mul_203 = None + view_868 = torch.ops.aten.view.default(view_as_real_51, [2, 8192, 8, 128]); view_as_real_51 = None + convert_element_type_840 = torch.ops.prims.convert_element_type.default(view_867, torch.bfloat16); view_867 = None + convert_element_type_841 = torch.ops.prims.convert_element_type.default(view_868, torch.bfloat16); view_868 = None + unsqueeze_50 = torch.ops.aten.unsqueeze.default(convert_element_type_841, 3); convert_element_type_841 = None + expand_50 = torch.ops.aten.expand.default(unsqueeze_50, [2, 8192, 8, 4, 128]); unsqueeze_50 = None + clone_50 = torch.ops.aten.clone.default(expand_50, memory_format = torch.contiguous_format); expand_50 = None + view_869 = torch.ops.aten.view.default(clone_50, [2, 8192, 32, 128]); clone_50 = None + unsqueeze_51 = torch.ops.aten.unsqueeze.default(view_863, 3); view_863 = None + expand_51 = torch.ops.aten.expand.default(unsqueeze_51, [2, 8192, 8, 4, 128]); unsqueeze_51 = None + clone_51 = torch.ops.aten.clone.default(expand_51, memory_format = torch.contiguous_format); expand_51 = None + view_870 = torch.ops.aten.view.default(clone_51, [2, 8192, 32, 128]); clone_51 = None + permute_278 = torch.ops.aten.permute.default(convert_element_type_840, [0, 2, 1, 3]); convert_element_type_840 = None + permute_279 = torch.ops.aten.permute.default(view_869, [0, 2, 1, 3]); view_869 = None + permute_280 = torch.ops.aten.permute.default(view_870, [0, 2, 1, 3]); view_870 = None + _scaled_dot_product_cudnn_attention_25 = torch.ops.aten._scaled_dot_product_cudnn_attention.default(permute_278, permute_279, permute_280, None, True, 0.0, True); permute_278 = permute_279 = permute_280 = None + getitem_225 = _scaled_dot_product_cudnn_attention_25[0] + getitem_226 = _scaled_dot_product_cudnn_attention_25[1] + getitem_231 = _scaled_dot_product_cudnn_attention_25[6] + getitem_232 = _scaled_dot_product_cudnn_attention_25[7]; _scaled_dot_product_cudnn_attention_25 = None + permute_281 = torch.ops.aten.permute.default(getitem_225, [0, 2, 
1, 3]) + view_871 = torch.ops.aten.view.default(permute_281, [2, 8192, -1]); permute_281 = None + convert_element_type_842 = torch.ops.prims.convert_element_type.default(primals_233, torch.bfloat16) + all_gather_into_tensor_230 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_842, 256, '0'); convert_element_type_842 = None + wait_tensor_230 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_230); all_gather_into_tensor_230 = None + permute_282 = torch.ops.aten.permute.default(wait_tensor_230, [1, 0]); wait_tensor_230 = None + view_873 = torch.ops.aten.view.default(view_871, [16384, 4096]); view_871 = None + mm_178 = torch.ops.aten.mm.default(view_873, permute_282); view_873 = permute_282 = None + view_874 = torch.ops.aten.view.default(mm_178, [2, 8192, 4096]); mm_178 = None + add_101 = torch.ops.aten.add.Tensor(add_99, view_874); view_874 = None + convert_element_type_845 = torch.ops.prims.convert_element_type.default(primals_234, torch.bfloat16) + all_gather_into_tensor_231 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_845, 256, '0'); convert_element_type_845 = None + wait_tensor_231 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_231); all_gather_into_tensor_231 = None + convert_element_type_846 = torch.ops.prims.convert_element_type.default(add_101, torch.float32) + pow_52 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_846, 2) + mean_51 = torch.ops.aten.mean.dim(pow_52, [2], True); pow_52 = None + add_102 = torch.ops.aten.add.Scalar(mean_51, 1e-05); mean_51 = None + rsqrt_51 = torch.ops.aten.rsqrt.default(add_102); add_102 = None + mul_204 = torch.ops.aten.mul.Tensor(convert_element_type_846, rsqrt_51); convert_element_type_846 = rsqrt_51 = None + mul_205 = torch.ops.aten.mul.Tensor(mul_204, wait_tensor_231); mul_204 = wait_tensor_231 = None + convert_element_type_847 = torch.ops.prims.convert_element_type.default(mul_205, torch.bfloat16); 
mul_205 = None + convert_element_type_848 = torch.ops.prims.convert_element_type.default(primals_235, torch.bfloat16) + all_gather_into_tensor_232 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_848, 256, '0'); convert_element_type_848 = None + wait_tensor_232 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_232); all_gather_into_tensor_232 = None + permute_283 = torch.ops.aten.permute.default(wait_tensor_232, [1, 0]); wait_tensor_232 = None + view_877 = torch.ops.aten.view.default(convert_element_type_847, [16384, 4096]); convert_element_type_847 = None + mm_179 = torch.ops.aten.mm.default(view_877, permute_283); permute_283 = None + view_878 = torch.ops.aten.view.default(mm_179, [2, 8192, 14336]) + convert_element_type_851 = torch.ops.prims.convert_element_type.default(view_878, torch.float32); view_878 = None + sigmoid_25 = torch.ops.aten.sigmoid.default(convert_element_type_851) + mul_206 = torch.ops.aten.mul.Tensor(convert_element_type_851, sigmoid_25); convert_element_type_851 = sigmoid_25 = None + convert_element_type_852 = torch.ops.prims.convert_element_type.default(mul_206, torch.bfloat16); mul_206 = None + convert_element_type_853 = torch.ops.prims.convert_element_type.default(primals_236, torch.bfloat16) + all_gather_into_tensor_233 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_853, 256, '0'); convert_element_type_853 = None + wait_tensor_233 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_233); all_gather_into_tensor_233 = None + permute_284 = torch.ops.aten.permute.default(wait_tensor_233, [1, 0]); wait_tensor_233 = None + mm_180 = torch.ops.aten.mm.default(view_877, permute_284); view_877 = permute_284 = None + view_881 = torch.ops.aten.view.default(mm_180, [2, 8192, 14336]); mm_180 = None + mul_207 = torch.ops.aten.mul.Tensor(convert_element_type_852, view_881); convert_element_type_852 = view_881 = None + convert_element_type_856 = 
torch.ops.prims.convert_element_type.default(primals_237, torch.bfloat16) + all_gather_into_tensor_234 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_856, 256, '0'); convert_element_type_856 = None + wait_tensor_234 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_234); all_gather_into_tensor_234 = None + permute_285 = torch.ops.aten.permute.default(wait_tensor_234, [1, 0]); wait_tensor_234 = None + view_883 = torch.ops.aten.view.default(mul_207, [16384, 14336]); mul_207 = None + mm_181 = torch.ops.aten.mm.default(view_883, permute_285); view_883 = permute_285 = None + view_884 = torch.ops.aten.view.default(mm_181, [2, 8192, 4096]); mm_181 = None + add_103 = torch.ops.aten.add.Tensor(add_101, view_884); add_101 = view_884 = None + convert_element_type_859 = torch.ops.prims.convert_element_type.default(primals_238, torch.bfloat16) + all_gather_into_tensor_235 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_859, 256, '0'); convert_element_type_859 = None + wait_tensor_235 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_235); all_gather_into_tensor_235 = None + convert_element_type_860 = torch.ops.prims.convert_element_type.default(add_103, torch.float32) + pow_53 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_860, 2) + mean_52 = torch.ops.aten.mean.dim(pow_53, [2], True); pow_53 = None + add_104 = torch.ops.aten.add.Scalar(mean_52, 1e-05); mean_52 = None + rsqrt_52 = torch.ops.aten.rsqrt.default(add_104); add_104 = None + mul_208 = torch.ops.aten.mul.Tensor(convert_element_type_860, rsqrt_52); convert_element_type_860 = rsqrt_52 = None + mul_209 = torch.ops.aten.mul.Tensor(mul_208, wait_tensor_235); mul_208 = wait_tensor_235 = None + convert_element_type_861 = torch.ops.prims.convert_element_type.default(mul_209, torch.bfloat16); mul_209 = None + convert_element_type_862 = torch.ops.prims.convert_element_type.default(primals_239, torch.bfloat16) 
+ all_gather_into_tensor_236 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_862, 256, '0'); convert_element_type_862 = None + wait_tensor_236 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_236); all_gather_into_tensor_236 = None + permute_286 = torch.ops.aten.permute.default(wait_tensor_236, [1, 0]); wait_tensor_236 = None + view_887 = torch.ops.aten.view.default(convert_element_type_861, [16384, 4096]); convert_element_type_861 = None + mm_182 = torch.ops.aten.mm.default(view_887, permute_286); permute_286 = None + view_888 = torch.ops.aten.view.default(mm_182, [2, 8192, 4096]) + convert_element_type_865 = torch.ops.prims.convert_element_type.default(primals_240, torch.bfloat16) + all_gather_into_tensor_237 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_865, 256, '0'); convert_element_type_865 = None + wait_tensor_237 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_237); all_gather_into_tensor_237 = None + permute_287 = torch.ops.aten.permute.default(wait_tensor_237, [1, 0]); wait_tensor_237 = None + mm_183 = torch.ops.aten.mm.default(view_887, permute_287); permute_287 = None + view_891 = torch.ops.aten.view.default(mm_183, [2, 8192, 1024]); mm_183 = None + convert_element_type_868 = torch.ops.prims.convert_element_type.default(primals_241, torch.bfloat16) + all_gather_into_tensor_238 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_868, 256, '0'); convert_element_type_868 = None + wait_tensor_238 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_238); all_gather_into_tensor_238 = None + permute_288 = torch.ops.aten.permute.default(wait_tensor_238, [1, 0]); wait_tensor_238 = None + mm_184 = torch.ops.aten.mm.default(view_887, permute_288); view_887 = permute_288 = None + view_894 = torch.ops.aten.view.default(mm_184, [2, 8192, 1024]) + view_895 = torch.ops.aten.view.default(view_888, [2, 8192, 
-1, 128]); view_888 = None + view_896 = torch.ops.aten.view.default(view_891, [2, 8192, -1, 128]); view_891 = None + view_897 = torch.ops.aten.view.default(view_894, [2, 8192, -1, 128]); view_894 = None + convert_element_type_871 = torch.ops.prims.convert_element_type.default(view_895, torch.float32); view_895 = None + view_898 = torch.ops.aten.view.default(convert_element_type_871, [2, 8192, 32, -1, 2]); convert_element_type_871 = None + view_as_complex_52 = torch.ops.aten.view_as_complex.default(view_898); view_898 = None + convert_element_type_872 = torch.ops.prims.convert_element_type.default(view_896, torch.float32); view_896 = None + view_899 = torch.ops.aten.view.default(convert_element_type_872, [2, 8192, 8, -1, 2]); convert_element_type_872 = None + view_as_complex_53 = torch.ops.aten.view_as_complex.default(view_899); view_899 = None + mul_210 = torch.ops.aten.mul.Tensor(view_as_complex_52, view_16); view_as_complex_52 = None + view_as_real_52 = torch.ops.aten.view_as_real.default(mul_210); mul_210 = None + view_901 = torch.ops.aten.view.default(view_as_real_52, [2, 8192, 32, 128]); view_as_real_52 = None + mul_211 = torch.ops.aten.mul.Tensor(view_as_complex_53, view_16); view_as_complex_53 = None + view_as_real_53 = torch.ops.aten.view_as_real.default(mul_211); mul_211 = None + view_902 = torch.ops.aten.view.default(view_as_real_53, [2, 8192, 8, 128]); view_as_real_53 = None + convert_element_type_873 = torch.ops.prims.convert_element_type.default(view_901, torch.bfloat16); view_901 = None + convert_element_type_874 = torch.ops.prims.convert_element_type.default(view_902, torch.bfloat16); view_902 = None + unsqueeze_52 = torch.ops.aten.unsqueeze.default(convert_element_type_874, 3); convert_element_type_874 = None + expand_52 = torch.ops.aten.expand.default(unsqueeze_52, [2, 8192, 8, 4, 128]); unsqueeze_52 = None + clone_52 = torch.ops.aten.clone.default(expand_52, memory_format = torch.contiguous_format); expand_52 = None + view_903 = 
torch.ops.aten.view.default(clone_52, [2, 8192, 32, 128]); clone_52 = None + unsqueeze_53 = torch.ops.aten.unsqueeze.default(view_897, 3); view_897 = None + expand_53 = torch.ops.aten.expand.default(unsqueeze_53, [2, 8192, 8, 4, 128]); unsqueeze_53 = None + clone_53 = torch.ops.aten.clone.default(expand_53, memory_format = torch.contiguous_format); expand_53 = None + view_904 = torch.ops.aten.view.default(clone_53, [2, 8192, 32, 128]); clone_53 = None + permute_289 = torch.ops.aten.permute.default(convert_element_type_873, [0, 2, 1, 3]); convert_element_type_873 = None + permute_290 = torch.ops.aten.permute.default(view_903, [0, 2, 1, 3]); view_903 = None + permute_291 = torch.ops.aten.permute.default(view_904, [0, 2, 1, 3]); view_904 = None + _scaled_dot_product_cudnn_attention_26 = torch.ops.aten._scaled_dot_product_cudnn_attention.default(permute_289, permute_290, permute_291, None, True, 0.0, True); permute_289 = permute_290 = permute_291 = None + getitem_234 = _scaled_dot_product_cudnn_attention_26[0] + getitem_235 = _scaled_dot_product_cudnn_attention_26[1] + getitem_240 = _scaled_dot_product_cudnn_attention_26[6] + getitem_241 = _scaled_dot_product_cudnn_attention_26[7]; _scaled_dot_product_cudnn_attention_26 = None + permute_292 = torch.ops.aten.permute.default(getitem_234, [0, 2, 1, 3]) + view_905 = torch.ops.aten.view.default(permute_292, [2, 8192, -1]); permute_292 = None + convert_element_type_875 = torch.ops.prims.convert_element_type.default(primals_242, torch.bfloat16) + all_gather_into_tensor_239 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_875, 256, '0'); convert_element_type_875 = None + wait_tensor_239 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_239); all_gather_into_tensor_239 = None + permute_293 = torch.ops.aten.permute.default(wait_tensor_239, [1, 0]); wait_tensor_239 = None + view_907 = torch.ops.aten.view.default(view_905, [16384, 4096]); view_905 = None + mm_185 = 
torch.ops.aten.mm.default(view_907, permute_293); view_907 = permute_293 = None + view_908 = torch.ops.aten.view.default(mm_185, [2, 8192, 4096]); mm_185 = None + add_105 = torch.ops.aten.add.Tensor(add_103, view_908); view_908 = None + convert_element_type_878 = torch.ops.prims.convert_element_type.default(primals_243, torch.bfloat16) + all_gather_into_tensor_240 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_878, 256, '0'); convert_element_type_878 = None + wait_tensor_240 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_240); all_gather_into_tensor_240 = None + convert_element_type_879 = torch.ops.prims.convert_element_type.default(add_105, torch.float32) + pow_54 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_879, 2) + mean_53 = torch.ops.aten.mean.dim(pow_54, [2], True); pow_54 = None + add_106 = torch.ops.aten.add.Scalar(mean_53, 1e-05); mean_53 = None + rsqrt_53 = torch.ops.aten.rsqrt.default(add_106); add_106 = None + mul_212 = torch.ops.aten.mul.Tensor(convert_element_type_879, rsqrt_53); convert_element_type_879 = rsqrt_53 = None + mul_213 = torch.ops.aten.mul.Tensor(mul_212, wait_tensor_240); mul_212 = wait_tensor_240 = None + convert_element_type_880 = torch.ops.prims.convert_element_type.default(mul_213, torch.bfloat16); mul_213 = None + convert_element_type_881 = torch.ops.prims.convert_element_type.default(primals_244, torch.bfloat16) + all_gather_into_tensor_241 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_881, 256, '0'); convert_element_type_881 = None + wait_tensor_241 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_241); all_gather_into_tensor_241 = None + permute_294 = torch.ops.aten.permute.default(wait_tensor_241, [1, 0]); wait_tensor_241 = None + view_911 = torch.ops.aten.view.default(convert_element_type_880, [16384, 4096]); convert_element_type_880 = None + mm_186 = torch.ops.aten.mm.default(view_911, permute_294); 
permute_294 = None + view_912 = torch.ops.aten.view.default(mm_186, [2, 8192, 14336]) + convert_element_type_884 = torch.ops.prims.convert_element_type.default(view_912, torch.float32); view_912 = None + sigmoid_26 = torch.ops.aten.sigmoid.default(convert_element_type_884) + mul_214 = torch.ops.aten.mul.Tensor(convert_element_type_884, sigmoid_26); convert_element_type_884 = sigmoid_26 = None + convert_element_type_885 = torch.ops.prims.convert_element_type.default(mul_214, torch.bfloat16); mul_214 = None + convert_element_type_886 = torch.ops.prims.convert_element_type.default(primals_245, torch.bfloat16) + all_gather_into_tensor_242 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_886, 256, '0'); convert_element_type_886 = None + wait_tensor_242 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_242); all_gather_into_tensor_242 = None + permute_295 = torch.ops.aten.permute.default(wait_tensor_242, [1, 0]); wait_tensor_242 = None + mm_187 = torch.ops.aten.mm.default(view_911, permute_295); view_911 = permute_295 = None + view_915 = torch.ops.aten.view.default(mm_187, [2, 8192, 14336]); mm_187 = None + mul_215 = torch.ops.aten.mul.Tensor(convert_element_type_885, view_915); convert_element_type_885 = view_915 = None + convert_element_type_889 = torch.ops.prims.convert_element_type.default(primals_246, torch.bfloat16) + all_gather_into_tensor_243 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_889, 256, '0'); convert_element_type_889 = None + wait_tensor_243 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_243); all_gather_into_tensor_243 = None + permute_296 = torch.ops.aten.permute.default(wait_tensor_243, [1, 0]); wait_tensor_243 = None + view_917 = torch.ops.aten.view.default(mul_215, [16384, 14336]); mul_215 = None + mm_188 = torch.ops.aten.mm.default(view_917, permute_296); view_917 = permute_296 = None + view_918 = torch.ops.aten.view.default(mm_188, 
[2, 8192, 4096]); mm_188 = None + add_107 = torch.ops.aten.add.Tensor(add_105, view_918); add_105 = view_918 = None + convert_element_type_892 = torch.ops.prims.convert_element_type.default(primals_247, torch.bfloat16) + all_gather_into_tensor_244 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_892, 256, '0'); convert_element_type_892 = None + wait_tensor_244 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_244); all_gather_into_tensor_244 = None + convert_element_type_893 = torch.ops.prims.convert_element_type.default(add_107, torch.float32) + pow_55 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_893, 2) + mean_54 = torch.ops.aten.mean.dim(pow_55, [2], True); pow_55 = None + add_108 = torch.ops.aten.add.Scalar(mean_54, 1e-05); mean_54 = None + rsqrt_54 = torch.ops.aten.rsqrt.default(add_108); add_108 = None + mul_216 = torch.ops.aten.mul.Tensor(convert_element_type_893, rsqrt_54); convert_element_type_893 = rsqrt_54 = None + mul_217 = torch.ops.aten.mul.Tensor(mul_216, wait_tensor_244); mul_216 = wait_tensor_244 = None + convert_element_type_894 = torch.ops.prims.convert_element_type.default(mul_217, torch.bfloat16); mul_217 = None + convert_element_type_895 = torch.ops.prims.convert_element_type.default(primals_248, torch.bfloat16) + all_gather_into_tensor_245 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_895, 256, '0'); convert_element_type_895 = None + wait_tensor_245 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_245); all_gather_into_tensor_245 = None + permute_297 = torch.ops.aten.permute.default(wait_tensor_245, [1, 0]); wait_tensor_245 = None + view_921 = torch.ops.aten.view.default(convert_element_type_894, [16384, 4096]); convert_element_type_894 = None + mm_189 = torch.ops.aten.mm.default(view_921, permute_297); permute_297 = None + view_922 = torch.ops.aten.view.default(mm_189, [2, 8192, 4096]) + convert_element_type_898 = 
torch.ops.prims.convert_element_type.default(primals_249, torch.bfloat16) + all_gather_into_tensor_246 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_898, 256, '0'); convert_element_type_898 = None + wait_tensor_246 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_246); all_gather_into_tensor_246 = None + permute_298 = torch.ops.aten.permute.default(wait_tensor_246, [1, 0]); wait_tensor_246 = None + mm_190 = torch.ops.aten.mm.default(view_921, permute_298); permute_298 = None + view_925 = torch.ops.aten.view.default(mm_190, [2, 8192, 1024]); mm_190 = None + convert_element_type_901 = torch.ops.prims.convert_element_type.default(primals_250, torch.bfloat16) + all_gather_into_tensor_247 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_901, 256, '0'); convert_element_type_901 = None + wait_tensor_247 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_247); all_gather_into_tensor_247 = None + permute_299 = torch.ops.aten.permute.default(wait_tensor_247, [1, 0]); wait_tensor_247 = None + mm_191 = torch.ops.aten.mm.default(view_921, permute_299); view_921 = permute_299 = None + view_928 = torch.ops.aten.view.default(mm_191, [2, 8192, 1024]) + view_929 = torch.ops.aten.view.default(view_922, [2, 8192, -1, 128]); view_922 = None + view_930 = torch.ops.aten.view.default(view_925, [2, 8192, -1, 128]); view_925 = None + view_931 = torch.ops.aten.view.default(view_928, [2, 8192, -1, 128]); view_928 = None + convert_element_type_904 = torch.ops.prims.convert_element_type.default(view_929, torch.float32); view_929 = None + view_932 = torch.ops.aten.view.default(convert_element_type_904, [2, 8192, 32, -1, 2]); convert_element_type_904 = None + view_as_complex_54 = torch.ops.aten.view_as_complex.default(view_932); view_932 = None + convert_element_type_905 = torch.ops.prims.convert_element_type.default(view_930, torch.float32); view_930 = None + view_933 = 
torch.ops.aten.view.default(convert_element_type_905, [2, 8192, 8, -1, 2]); convert_element_type_905 = None + view_as_complex_55 = torch.ops.aten.view_as_complex.default(view_933); view_933 = None + mul_218 = torch.ops.aten.mul.Tensor(view_as_complex_54, view_16); view_as_complex_54 = None + view_as_real_54 = torch.ops.aten.view_as_real.default(mul_218); mul_218 = None + view_935 = torch.ops.aten.view.default(view_as_real_54, [2, 8192, 32, 128]); view_as_real_54 = None + mul_219 = torch.ops.aten.mul.Tensor(view_as_complex_55, view_16); view_as_complex_55 = None + view_as_real_55 = torch.ops.aten.view_as_real.default(mul_219); mul_219 = None + view_936 = torch.ops.aten.view.default(view_as_real_55, [2, 8192, 8, 128]); view_as_real_55 = None + convert_element_type_906 = torch.ops.prims.convert_element_type.default(view_935, torch.bfloat16); view_935 = None + convert_element_type_907 = torch.ops.prims.convert_element_type.default(view_936, torch.bfloat16); view_936 = None + unsqueeze_54 = torch.ops.aten.unsqueeze.default(convert_element_type_907, 3); convert_element_type_907 = None + expand_54 = torch.ops.aten.expand.default(unsqueeze_54, [2, 8192, 8, 4, 128]); unsqueeze_54 = None + clone_54 = torch.ops.aten.clone.default(expand_54, memory_format = torch.contiguous_format); expand_54 = None + view_937 = torch.ops.aten.view.default(clone_54, [2, 8192, 32, 128]); clone_54 = None + unsqueeze_55 = torch.ops.aten.unsqueeze.default(view_931, 3); view_931 = None + expand_55 = torch.ops.aten.expand.default(unsqueeze_55, [2, 8192, 8, 4, 128]); unsqueeze_55 = None + clone_55 = torch.ops.aten.clone.default(expand_55, memory_format = torch.contiguous_format); expand_55 = None + view_938 = torch.ops.aten.view.default(clone_55, [2, 8192, 32, 128]); clone_55 = None + permute_300 = torch.ops.aten.permute.default(convert_element_type_906, [0, 2, 1, 3]); convert_element_type_906 = None + permute_301 = torch.ops.aten.permute.default(view_937, [0, 2, 1, 3]); view_937 = None + permute_302 
= torch.ops.aten.permute.default(view_938, [0, 2, 1, 3]); view_938 = None + _scaled_dot_product_cudnn_attention_27 = torch.ops.aten._scaled_dot_product_cudnn_attention.default(permute_300, permute_301, permute_302, None, True, 0.0, True); permute_300 = permute_301 = permute_302 = None + getitem_243 = _scaled_dot_product_cudnn_attention_27[0] + getitem_244 = _scaled_dot_product_cudnn_attention_27[1] + getitem_249 = _scaled_dot_product_cudnn_attention_27[6] + getitem_250 = _scaled_dot_product_cudnn_attention_27[7]; _scaled_dot_product_cudnn_attention_27 = None + permute_303 = torch.ops.aten.permute.default(getitem_243, [0, 2, 1, 3]) + view_939 = torch.ops.aten.view.default(permute_303, [2, 8192, -1]); permute_303 = None + convert_element_type_908 = torch.ops.prims.convert_element_type.default(primals_251, torch.bfloat16) + all_gather_into_tensor_248 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_908, 256, '0'); convert_element_type_908 = None + wait_tensor_248 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_248); all_gather_into_tensor_248 = None + permute_304 = torch.ops.aten.permute.default(wait_tensor_248, [1, 0]); wait_tensor_248 = None + view_941 = torch.ops.aten.view.default(view_939, [16384, 4096]); view_939 = None + mm_192 = torch.ops.aten.mm.default(view_941, permute_304); view_941 = permute_304 = None + view_942 = torch.ops.aten.view.default(mm_192, [2, 8192, 4096]); mm_192 = None + add_109 = torch.ops.aten.add.Tensor(add_107, view_942); view_942 = None + convert_element_type_911 = torch.ops.prims.convert_element_type.default(primals_252, torch.bfloat16) + all_gather_into_tensor_249 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_911, 256, '0'); convert_element_type_911 = None + wait_tensor_249 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_249); all_gather_into_tensor_249 = None + convert_element_type_912 = 
torch.ops.prims.convert_element_type.default(add_109, torch.float32) + pow_56 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_912, 2) + mean_55 = torch.ops.aten.mean.dim(pow_56, [2], True); pow_56 = None + add_110 = torch.ops.aten.add.Scalar(mean_55, 1e-05); mean_55 = None + rsqrt_55 = torch.ops.aten.rsqrt.default(add_110); add_110 = None + mul_220 = torch.ops.aten.mul.Tensor(convert_element_type_912, rsqrt_55); convert_element_type_912 = rsqrt_55 = None + mul_221 = torch.ops.aten.mul.Tensor(mul_220, wait_tensor_249); mul_220 = wait_tensor_249 = None + convert_element_type_913 = torch.ops.prims.convert_element_type.default(mul_221, torch.bfloat16); mul_221 = None + convert_element_type_914 = torch.ops.prims.convert_element_type.default(primals_253, torch.bfloat16) + all_gather_into_tensor_250 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_914, 256, '0'); convert_element_type_914 = None + wait_tensor_250 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_250); all_gather_into_tensor_250 = None + permute_305 = torch.ops.aten.permute.default(wait_tensor_250, [1, 0]); wait_tensor_250 = None + view_945 = torch.ops.aten.view.default(convert_element_type_913, [16384, 4096]); convert_element_type_913 = None + mm_193 = torch.ops.aten.mm.default(view_945, permute_305); permute_305 = None + view_946 = torch.ops.aten.view.default(mm_193, [2, 8192, 14336]) + convert_element_type_917 = torch.ops.prims.convert_element_type.default(view_946, torch.float32); view_946 = None + sigmoid_27 = torch.ops.aten.sigmoid.default(convert_element_type_917) + mul_222 = torch.ops.aten.mul.Tensor(convert_element_type_917, sigmoid_27); convert_element_type_917 = sigmoid_27 = None + convert_element_type_918 = torch.ops.prims.convert_element_type.default(mul_222, torch.bfloat16); mul_222 = None + convert_element_type_919 = torch.ops.prims.convert_element_type.default(primals_254, torch.bfloat16) + all_gather_into_tensor_251 = 
torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_919, 256, '0'); convert_element_type_919 = None + wait_tensor_251 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_251); all_gather_into_tensor_251 = None + permute_306 = torch.ops.aten.permute.default(wait_tensor_251, [1, 0]); wait_tensor_251 = None + mm_194 = torch.ops.aten.mm.default(view_945, permute_306); view_945 = permute_306 = None + view_949 = torch.ops.aten.view.default(mm_194, [2, 8192, 14336]); mm_194 = None + mul_223 = torch.ops.aten.mul.Tensor(convert_element_type_918, view_949); convert_element_type_918 = view_949 = None + convert_element_type_922 = torch.ops.prims.convert_element_type.default(primals_255, torch.bfloat16) + all_gather_into_tensor_252 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_922, 256, '0'); convert_element_type_922 = None + wait_tensor_252 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_252); all_gather_into_tensor_252 = None + permute_307 = torch.ops.aten.permute.default(wait_tensor_252, [1, 0]); wait_tensor_252 = None + view_951 = torch.ops.aten.view.default(mul_223, [16384, 14336]); mul_223 = None + mm_195 = torch.ops.aten.mm.default(view_951, permute_307); view_951 = permute_307 = None + view_952 = torch.ops.aten.view.default(mm_195, [2, 8192, 4096]); mm_195 = None + add_111 = torch.ops.aten.add.Tensor(add_109, view_952); add_109 = view_952 = None + convert_element_type_925 = torch.ops.prims.convert_element_type.default(primals_256, torch.bfloat16) + all_gather_into_tensor_253 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_925, 256, '0'); convert_element_type_925 = None + wait_tensor_253 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_253); all_gather_into_tensor_253 = None + convert_element_type_926 = torch.ops.prims.convert_element_type.default(add_111, torch.float32) + pow_57 = 
torch.ops.aten.pow.Tensor_Scalar(convert_element_type_926, 2) + mean_56 = torch.ops.aten.mean.dim(pow_57, [2], True); pow_57 = None + add_112 = torch.ops.aten.add.Scalar(mean_56, 1e-05); mean_56 = None + rsqrt_56 = torch.ops.aten.rsqrt.default(add_112); add_112 = None + mul_224 = torch.ops.aten.mul.Tensor(convert_element_type_926, rsqrt_56); convert_element_type_926 = rsqrt_56 = None + mul_225 = torch.ops.aten.mul.Tensor(mul_224, wait_tensor_253); mul_224 = wait_tensor_253 = None + convert_element_type_927 = torch.ops.prims.convert_element_type.default(mul_225, torch.bfloat16); mul_225 = None + convert_element_type_928 = torch.ops.prims.convert_element_type.default(primals_257, torch.bfloat16) + all_gather_into_tensor_254 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_928, 256, '0'); convert_element_type_928 = None + wait_tensor_254 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_254); all_gather_into_tensor_254 = None + permute_308 = torch.ops.aten.permute.default(wait_tensor_254, [1, 0]); wait_tensor_254 = None + view_955 = torch.ops.aten.view.default(convert_element_type_927, [16384, 4096]); convert_element_type_927 = None + mm_196 = torch.ops.aten.mm.default(view_955, permute_308); permute_308 = None + view_956 = torch.ops.aten.view.default(mm_196, [2, 8192, 4096]) + convert_element_type_931 = torch.ops.prims.convert_element_type.default(primals_258, torch.bfloat16) + all_gather_into_tensor_255 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_931, 256, '0'); convert_element_type_931 = None + wait_tensor_255 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_255); all_gather_into_tensor_255 = None + permute_309 = torch.ops.aten.permute.default(wait_tensor_255, [1, 0]); wait_tensor_255 = None + mm_197 = torch.ops.aten.mm.default(view_955, permute_309); permute_309 = None + view_959 = torch.ops.aten.view.default(mm_197, [2, 8192, 1024]); mm_197 = None + 
convert_element_type_934 = torch.ops.prims.convert_element_type.default(primals_259, torch.bfloat16) + all_gather_into_tensor_256 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_934, 256, '0'); convert_element_type_934 = None + wait_tensor_256 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_256); all_gather_into_tensor_256 = None + permute_310 = torch.ops.aten.permute.default(wait_tensor_256, [1, 0]); wait_tensor_256 = None + mm_198 = torch.ops.aten.mm.default(view_955, permute_310); view_955 = permute_310 = None + view_962 = torch.ops.aten.view.default(mm_198, [2, 8192, 1024]) + view_963 = torch.ops.aten.view.default(view_956, [2, 8192, -1, 128]); view_956 = None + view_964 = torch.ops.aten.view.default(view_959, [2, 8192, -1, 128]); view_959 = None + view_965 = torch.ops.aten.view.default(view_962, [2, 8192, -1, 128]); view_962 = None + convert_element_type_937 = torch.ops.prims.convert_element_type.default(view_963, torch.float32); view_963 = None + view_966 = torch.ops.aten.view.default(convert_element_type_937, [2, 8192, 32, -1, 2]); convert_element_type_937 = None + view_as_complex_56 = torch.ops.aten.view_as_complex.default(view_966); view_966 = None + convert_element_type_938 = torch.ops.prims.convert_element_type.default(view_964, torch.float32); view_964 = None + view_967 = torch.ops.aten.view.default(convert_element_type_938, [2, 8192, 8, -1, 2]); convert_element_type_938 = None + view_as_complex_57 = torch.ops.aten.view_as_complex.default(view_967); view_967 = None + mul_226 = torch.ops.aten.mul.Tensor(view_as_complex_56, view_16); view_as_complex_56 = None + view_as_real_56 = torch.ops.aten.view_as_real.default(mul_226); mul_226 = None + view_969 = torch.ops.aten.view.default(view_as_real_56, [2, 8192, 32, 128]); view_as_real_56 = None + mul_227 = torch.ops.aten.mul.Tensor(view_as_complex_57, view_16); view_as_complex_57 = None + view_as_real_57 = torch.ops.aten.view_as_real.default(mul_227); 
mul_227 = None + view_970 = torch.ops.aten.view.default(view_as_real_57, [2, 8192, 8, 128]); view_as_real_57 = None + convert_element_type_939 = torch.ops.prims.convert_element_type.default(view_969, torch.bfloat16); view_969 = None + convert_element_type_940 = torch.ops.prims.convert_element_type.default(view_970, torch.bfloat16); view_970 = None + unsqueeze_56 = torch.ops.aten.unsqueeze.default(convert_element_type_940, 3); convert_element_type_940 = None + expand_56 = torch.ops.aten.expand.default(unsqueeze_56, [2, 8192, 8, 4, 128]); unsqueeze_56 = None + clone_56 = torch.ops.aten.clone.default(expand_56, memory_format = torch.contiguous_format); expand_56 = None + view_971 = torch.ops.aten.view.default(clone_56, [2, 8192, 32, 128]); clone_56 = None + unsqueeze_57 = torch.ops.aten.unsqueeze.default(view_965, 3); view_965 = None + expand_57 = torch.ops.aten.expand.default(unsqueeze_57, [2, 8192, 8, 4, 128]); unsqueeze_57 = None + clone_57 = torch.ops.aten.clone.default(expand_57, memory_format = torch.contiguous_format); expand_57 = None + view_972 = torch.ops.aten.view.default(clone_57, [2, 8192, 32, 128]); clone_57 = None + permute_311 = torch.ops.aten.permute.default(convert_element_type_939, [0, 2, 1, 3]); convert_element_type_939 = None + permute_312 = torch.ops.aten.permute.default(view_971, [0, 2, 1, 3]); view_971 = None + permute_313 = torch.ops.aten.permute.default(view_972, [0, 2, 1, 3]); view_972 = None + _scaled_dot_product_cudnn_attention_28 = torch.ops.aten._scaled_dot_product_cudnn_attention.default(permute_311, permute_312, permute_313, None, True, 0.0, True); permute_311 = permute_312 = permute_313 = None + getitem_252 = _scaled_dot_product_cudnn_attention_28[0] + getitem_253 = _scaled_dot_product_cudnn_attention_28[1] + getitem_258 = _scaled_dot_product_cudnn_attention_28[6] + getitem_259 = _scaled_dot_product_cudnn_attention_28[7]; _scaled_dot_product_cudnn_attention_28 = None + permute_314 = torch.ops.aten.permute.default(getitem_252, [0, 2, 
1, 3]) + view_973 = torch.ops.aten.view.default(permute_314, [2, 8192, -1]); permute_314 = None + convert_element_type_941 = torch.ops.prims.convert_element_type.default(primals_260, torch.bfloat16) + all_gather_into_tensor_257 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_941, 256, '0'); convert_element_type_941 = None + wait_tensor_257 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_257); all_gather_into_tensor_257 = None + permute_315 = torch.ops.aten.permute.default(wait_tensor_257, [1, 0]); wait_tensor_257 = None + view_975 = torch.ops.aten.view.default(view_973, [16384, 4096]); view_973 = None + mm_199 = torch.ops.aten.mm.default(view_975, permute_315); view_975 = permute_315 = None + view_976 = torch.ops.aten.view.default(mm_199, [2, 8192, 4096]); mm_199 = None + add_113 = torch.ops.aten.add.Tensor(add_111, view_976); view_976 = None + convert_element_type_944 = torch.ops.prims.convert_element_type.default(primals_261, torch.bfloat16) + all_gather_into_tensor_258 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_944, 256, '0'); convert_element_type_944 = None + wait_tensor_258 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_258); all_gather_into_tensor_258 = None + convert_element_type_945 = torch.ops.prims.convert_element_type.default(add_113, torch.float32) + pow_58 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_945, 2) + mean_57 = torch.ops.aten.mean.dim(pow_58, [2], True); pow_58 = None + add_114 = torch.ops.aten.add.Scalar(mean_57, 1e-05); mean_57 = None + rsqrt_57 = torch.ops.aten.rsqrt.default(add_114); add_114 = None + mul_228 = torch.ops.aten.mul.Tensor(convert_element_type_945, rsqrt_57); convert_element_type_945 = rsqrt_57 = None + mul_229 = torch.ops.aten.mul.Tensor(mul_228, wait_tensor_258); mul_228 = wait_tensor_258 = None + convert_element_type_946 = torch.ops.prims.convert_element_type.default(mul_229, torch.bfloat16); 
mul_229 = None + convert_element_type_947 = torch.ops.prims.convert_element_type.default(primals_262, torch.bfloat16) + all_gather_into_tensor_259 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_947, 256, '0'); convert_element_type_947 = None + wait_tensor_259 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_259); all_gather_into_tensor_259 = None + permute_316 = torch.ops.aten.permute.default(wait_tensor_259, [1, 0]); wait_tensor_259 = None + view_979 = torch.ops.aten.view.default(convert_element_type_946, [16384, 4096]); convert_element_type_946 = None + mm_200 = torch.ops.aten.mm.default(view_979, permute_316); permute_316 = None + view_980 = torch.ops.aten.view.default(mm_200, [2, 8192, 14336]) + convert_element_type_950 = torch.ops.prims.convert_element_type.default(view_980, torch.float32); view_980 = None + sigmoid_28 = torch.ops.aten.sigmoid.default(convert_element_type_950) + mul_230 = torch.ops.aten.mul.Tensor(convert_element_type_950, sigmoid_28); convert_element_type_950 = sigmoid_28 = None + convert_element_type_951 = torch.ops.prims.convert_element_type.default(mul_230, torch.bfloat16); mul_230 = None + convert_element_type_952 = torch.ops.prims.convert_element_type.default(primals_263, torch.bfloat16) + all_gather_into_tensor_260 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_952, 256, '0'); convert_element_type_952 = None + wait_tensor_260 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_260); all_gather_into_tensor_260 = None + permute_317 = torch.ops.aten.permute.default(wait_tensor_260, [1, 0]); wait_tensor_260 = None + mm_201 = torch.ops.aten.mm.default(view_979, permute_317); view_979 = permute_317 = None + view_983 = torch.ops.aten.view.default(mm_201, [2, 8192, 14336]); mm_201 = None + mul_231 = torch.ops.aten.mul.Tensor(convert_element_type_951, view_983); convert_element_type_951 = view_983 = None + convert_element_type_955 = 
torch.ops.prims.convert_element_type.default(primals_264, torch.bfloat16) + all_gather_into_tensor_261 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_955, 256, '0'); convert_element_type_955 = None + wait_tensor_261 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_261); all_gather_into_tensor_261 = None + permute_318 = torch.ops.aten.permute.default(wait_tensor_261, [1, 0]); wait_tensor_261 = None + view_985 = torch.ops.aten.view.default(mul_231, [16384, 14336]); mul_231 = None + mm_202 = torch.ops.aten.mm.default(view_985, permute_318); view_985 = permute_318 = None + view_986 = torch.ops.aten.view.default(mm_202, [2, 8192, 4096]); mm_202 = None + add_115 = torch.ops.aten.add.Tensor(add_113, view_986); add_113 = view_986 = None + convert_element_type_958 = torch.ops.prims.convert_element_type.default(primals_265, torch.bfloat16) + all_gather_into_tensor_262 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_958, 256, '0'); convert_element_type_958 = None + wait_tensor_262 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_262); all_gather_into_tensor_262 = None + convert_element_type_959 = torch.ops.prims.convert_element_type.default(add_115, torch.float32) + pow_59 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_959, 2) + mean_58 = torch.ops.aten.mean.dim(pow_59, [2], True); pow_59 = None + add_116 = torch.ops.aten.add.Scalar(mean_58, 1e-05); mean_58 = None + rsqrt_58 = torch.ops.aten.rsqrt.default(add_116); add_116 = None + mul_232 = torch.ops.aten.mul.Tensor(convert_element_type_959, rsqrt_58); convert_element_type_959 = rsqrt_58 = None + mul_233 = torch.ops.aten.mul.Tensor(mul_232, wait_tensor_262); mul_232 = wait_tensor_262 = None + convert_element_type_960 = torch.ops.prims.convert_element_type.default(mul_233, torch.bfloat16); mul_233 = None + convert_element_type_961 = torch.ops.prims.convert_element_type.default(primals_266, torch.bfloat16) 
+ all_gather_into_tensor_263 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_961, 256, '0'); convert_element_type_961 = None + wait_tensor_263 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_263); all_gather_into_tensor_263 = None + permute_319 = torch.ops.aten.permute.default(wait_tensor_263, [1, 0]); wait_tensor_263 = None + view_989 = torch.ops.aten.view.default(convert_element_type_960, [16384, 4096]); convert_element_type_960 = None + mm_203 = torch.ops.aten.mm.default(view_989, permute_319); permute_319 = None + view_990 = torch.ops.aten.view.default(mm_203, [2, 8192, 4096]) + convert_element_type_964 = torch.ops.prims.convert_element_type.default(primals_267, torch.bfloat16) + all_gather_into_tensor_264 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_964, 256, '0'); convert_element_type_964 = None + wait_tensor_264 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_264); all_gather_into_tensor_264 = None + permute_320 = torch.ops.aten.permute.default(wait_tensor_264, [1, 0]); wait_tensor_264 = None + mm_204 = torch.ops.aten.mm.default(view_989, permute_320); permute_320 = None + view_993 = torch.ops.aten.view.default(mm_204, [2, 8192, 1024]); mm_204 = None + convert_element_type_967 = torch.ops.prims.convert_element_type.default(primals_268, torch.bfloat16) + all_gather_into_tensor_265 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_967, 256, '0'); convert_element_type_967 = None + wait_tensor_265 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_265); all_gather_into_tensor_265 = None + permute_321 = torch.ops.aten.permute.default(wait_tensor_265, [1, 0]); wait_tensor_265 = None + mm_205 = torch.ops.aten.mm.default(view_989, permute_321); view_989 = permute_321 = None + view_996 = torch.ops.aten.view.default(mm_205, [2, 8192, 1024]) + view_997 = torch.ops.aten.view.default(view_990, [2, 8192, 
-1, 128]); view_990 = None + view_998 = torch.ops.aten.view.default(view_993, [2, 8192, -1, 128]); view_993 = None + view_999 = torch.ops.aten.view.default(view_996, [2, 8192, -1, 128]); view_996 = None + convert_element_type_970 = torch.ops.prims.convert_element_type.default(view_997, torch.float32); view_997 = None + view_1000 = torch.ops.aten.view.default(convert_element_type_970, [2, 8192, 32, -1, 2]); convert_element_type_970 = None + view_as_complex_58 = torch.ops.aten.view_as_complex.default(view_1000); view_1000 = None + convert_element_type_971 = torch.ops.prims.convert_element_type.default(view_998, torch.float32); view_998 = None + view_1001 = torch.ops.aten.view.default(convert_element_type_971, [2, 8192, 8, -1, 2]); convert_element_type_971 = None + view_as_complex_59 = torch.ops.aten.view_as_complex.default(view_1001); view_1001 = None + mul_234 = torch.ops.aten.mul.Tensor(view_as_complex_58, view_16); view_as_complex_58 = None + view_as_real_58 = torch.ops.aten.view_as_real.default(mul_234); mul_234 = None + view_1003 = torch.ops.aten.view.default(view_as_real_58, [2, 8192, 32, 128]); view_as_real_58 = None + mul_235 = torch.ops.aten.mul.Tensor(view_as_complex_59, view_16); view_as_complex_59 = None + view_as_real_59 = torch.ops.aten.view_as_real.default(mul_235); mul_235 = None + view_1004 = torch.ops.aten.view.default(view_as_real_59, [2, 8192, 8, 128]); view_as_real_59 = None + convert_element_type_972 = torch.ops.prims.convert_element_type.default(view_1003, torch.bfloat16); view_1003 = None + convert_element_type_973 = torch.ops.prims.convert_element_type.default(view_1004, torch.bfloat16); view_1004 = None + unsqueeze_58 = torch.ops.aten.unsqueeze.default(convert_element_type_973, 3); convert_element_type_973 = None + expand_58 = torch.ops.aten.expand.default(unsqueeze_58, [2, 8192, 8, 4, 128]); unsqueeze_58 = None + clone_58 = torch.ops.aten.clone.default(expand_58, memory_format = torch.contiguous_format); expand_58 = None + view_1005 = 
torch.ops.aten.view.default(clone_58, [2, 8192, 32, 128]); clone_58 = None + unsqueeze_59 = torch.ops.aten.unsqueeze.default(view_999, 3); view_999 = None + expand_59 = torch.ops.aten.expand.default(unsqueeze_59, [2, 8192, 8, 4, 128]); unsqueeze_59 = None + clone_59 = torch.ops.aten.clone.default(expand_59, memory_format = torch.contiguous_format); expand_59 = None + view_1006 = torch.ops.aten.view.default(clone_59, [2, 8192, 32, 128]); clone_59 = None + permute_322 = torch.ops.aten.permute.default(convert_element_type_972, [0, 2, 1, 3]); convert_element_type_972 = None + permute_323 = torch.ops.aten.permute.default(view_1005, [0, 2, 1, 3]); view_1005 = None + permute_324 = torch.ops.aten.permute.default(view_1006, [0, 2, 1, 3]); view_1006 = None + _scaled_dot_product_cudnn_attention_29 = torch.ops.aten._scaled_dot_product_cudnn_attention.default(permute_322, permute_323, permute_324, None, True, 0.0, True); permute_322 = permute_323 = permute_324 = None + getitem_261 = _scaled_dot_product_cudnn_attention_29[0] + getitem_262 = _scaled_dot_product_cudnn_attention_29[1] + getitem_267 = _scaled_dot_product_cudnn_attention_29[6] + getitem_268 = _scaled_dot_product_cudnn_attention_29[7]; _scaled_dot_product_cudnn_attention_29 = None + permute_325 = torch.ops.aten.permute.default(getitem_261, [0, 2, 1, 3]) + view_1007 = torch.ops.aten.view.default(permute_325, [2, 8192, -1]); permute_325 = None + convert_element_type_974 = torch.ops.prims.convert_element_type.default(primals_269, torch.bfloat16) + all_gather_into_tensor_266 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_974, 256, '0'); convert_element_type_974 = None + wait_tensor_266 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_266); all_gather_into_tensor_266 = None + permute_326 = torch.ops.aten.permute.default(wait_tensor_266, [1, 0]); wait_tensor_266 = None + view_1009 = torch.ops.aten.view.default(view_1007, [16384, 4096]); view_1007 = None + mm_206 = 
torch.ops.aten.mm.default(view_1009, permute_326); view_1009 = permute_326 = None + view_1010 = torch.ops.aten.view.default(mm_206, [2, 8192, 4096]); mm_206 = None + add_117 = torch.ops.aten.add.Tensor(add_115, view_1010); view_1010 = None + convert_element_type_977 = torch.ops.prims.convert_element_type.default(primals_270, torch.bfloat16) + all_gather_into_tensor_267 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_977, 256, '0'); convert_element_type_977 = None + wait_tensor_267 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_267); all_gather_into_tensor_267 = None + convert_element_type_978 = torch.ops.prims.convert_element_type.default(add_117, torch.float32) + pow_60 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_978, 2) + mean_59 = torch.ops.aten.mean.dim(pow_60, [2], True); pow_60 = None + add_118 = torch.ops.aten.add.Scalar(mean_59, 1e-05); mean_59 = None + rsqrt_59 = torch.ops.aten.rsqrt.default(add_118); add_118 = None + mul_236 = torch.ops.aten.mul.Tensor(convert_element_type_978, rsqrt_59); convert_element_type_978 = rsqrt_59 = None + mul_237 = torch.ops.aten.mul.Tensor(mul_236, wait_tensor_267); mul_236 = wait_tensor_267 = None + convert_element_type_979 = torch.ops.prims.convert_element_type.default(mul_237, torch.bfloat16); mul_237 = None + convert_element_type_980 = torch.ops.prims.convert_element_type.default(primals_271, torch.bfloat16) + all_gather_into_tensor_268 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_980, 256, '0'); convert_element_type_980 = None + wait_tensor_268 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_268); all_gather_into_tensor_268 = None + permute_327 = torch.ops.aten.permute.default(wait_tensor_268, [1, 0]); wait_tensor_268 = None + view_1013 = torch.ops.aten.view.default(convert_element_type_979, [16384, 4096]); convert_element_type_979 = None + mm_207 = torch.ops.aten.mm.default(view_1013, 
permute_327); permute_327 = None + view_1014 = torch.ops.aten.view.default(mm_207, [2, 8192, 14336]) + convert_element_type_983 = torch.ops.prims.convert_element_type.default(view_1014, torch.float32); view_1014 = None + sigmoid_29 = torch.ops.aten.sigmoid.default(convert_element_type_983) + mul_238 = torch.ops.aten.mul.Tensor(convert_element_type_983, sigmoid_29); convert_element_type_983 = sigmoid_29 = None + convert_element_type_984 = torch.ops.prims.convert_element_type.default(mul_238, torch.bfloat16); mul_238 = None + convert_element_type_985 = torch.ops.prims.convert_element_type.default(primals_272, torch.bfloat16) + all_gather_into_tensor_269 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_985, 256, '0'); convert_element_type_985 = None + wait_tensor_269 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_269); all_gather_into_tensor_269 = None + permute_328 = torch.ops.aten.permute.default(wait_tensor_269, [1, 0]); wait_tensor_269 = None + mm_208 = torch.ops.aten.mm.default(view_1013, permute_328); view_1013 = permute_328 = None + view_1017 = torch.ops.aten.view.default(mm_208, [2, 8192, 14336]); mm_208 = None + mul_239 = torch.ops.aten.mul.Tensor(convert_element_type_984, view_1017); convert_element_type_984 = view_1017 = None + convert_element_type_988 = torch.ops.prims.convert_element_type.default(primals_273, torch.bfloat16) + all_gather_into_tensor_270 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_988, 256, '0'); convert_element_type_988 = None + wait_tensor_270 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_270); all_gather_into_tensor_270 = None + permute_329 = torch.ops.aten.permute.default(wait_tensor_270, [1, 0]); wait_tensor_270 = None + view_1019 = torch.ops.aten.view.default(mul_239, [16384, 14336]); mul_239 = None + mm_209 = torch.ops.aten.mm.default(view_1019, permute_329); view_1019 = permute_329 = None + view_1020 = 
torch.ops.aten.view.default(mm_209, [2, 8192, 4096]); mm_209 = None + add_119 = torch.ops.aten.add.Tensor(add_117, view_1020); add_117 = view_1020 = None + convert_element_type_991 = torch.ops.prims.convert_element_type.default(primals_274, torch.bfloat16) + all_gather_into_tensor_271 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_991, 256, '0'); convert_element_type_991 = None + wait_tensor_271 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_271); all_gather_into_tensor_271 = None + convert_element_type_992 = torch.ops.prims.convert_element_type.default(add_119, torch.float32) + pow_61 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_992, 2) + mean_60 = torch.ops.aten.mean.dim(pow_61, [2], True); pow_61 = None + add_120 = torch.ops.aten.add.Scalar(mean_60, 1e-05); mean_60 = None + rsqrt_60 = torch.ops.aten.rsqrt.default(add_120); add_120 = None + mul_240 = torch.ops.aten.mul.Tensor(convert_element_type_992, rsqrt_60); convert_element_type_992 = rsqrt_60 = None + mul_241 = torch.ops.aten.mul.Tensor(mul_240, wait_tensor_271); mul_240 = wait_tensor_271 = None + convert_element_type_993 = torch.ops.prims.convert_element_type.default(mul_241, torch.bfloat16); mul_241 = None + convert_element_type_994 = torch.ops.prims.convert_element_type.default(primals_275, torch.bfloat16) + all_gather_into_tensor_272 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_994, 256, '0'); convert_element_type_994 = None + wait_tensor_272 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_272); all_gather_into_tensor_272 = None + permute_330 = torch.ops.aten.permute.default(wait_tensor_272, [1, 0]); wait_tensor_272 = None + view_1023 = torch.ops.aten.view.default(convert_element_type_993, [16384, 4096]); convert_element_type_993 = None + mm_210 = torch.ops.aten.mm.default(view_1023, permute_330); permute_330 = None + view_1024 = torch.ops.aten.view.default(mm_210, [2, 8192, 
4096]) + convert_element_type_997 = torch.ops.prims.convert_element_type.default(primals_276, torch.bfloat16) + all_gather_into_tensor_273 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_997, 256, '0'); convert_element_type_997 = None + wait_tensor_273 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_273); all_gather_into_tensor_273 = None + permute_331 = torch.ops.aten.permute.default(wait_tensor_273, [1, 0]); wait_tensor_273 = None + mm_211 = torch.ops.aten.mm.default(view_1023, permute_331); permute_331 = None + view_1027 = torch.ops.aten.view.default(mm_211, [2, 8192, 1024]); mm_211 = None + convert_element_type_1000 = torch.ops.prims.convert_element_type.default(primals_277, torch.bfloat16) + all_gather_into_tensor_274 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1000, 256, '0'); convert_element_type_1000 = None + wait_tensor_274 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_274); all_gather_into_tensor_274 = None + permute_332 = torch.ops.aten.permute.default(wait_tensor_274, [1, 0]); wait_tensor_274 = None + mm_212 = torch.ops.aten.mm.default(view_1023, permute_332); view_1023 = permute_332 = None + view_1030 = torch.ops.aten.view.default(mm_212, [2, 8192, 1024]) + view_1031 = torch.ops.aten.view.default(view_1024, [2, 8192, -1, 128]); view_1024 = None + view_1032 = torch.ops.aten.view.default(view_1027, [2, 8192, -1, 128]); view_1027 = None + view_1033 = torch.ops.aten.view.default(view_1030, [2, 8192, -1, 128]); view_1030 = None + convert_element_type_1003 = torch.ops.prims.convert_element_type.default(view_1031, torch.float32); view_1031 = None + view_1034 = torch.ops.aten.view.default(convert_element_type_1003, [2, 8192, 32, -1, 2]); convert_element_type_1003 = None + view_as_complex_60 = torch.ops.aten.view_as_complex.default(view_1034); view_1034 = None + convert_element_type_1004 = 
torch.ops.prims.convert_element_type.default(view_1032, torch.float32); view_1032 = None + view_1035 = torch.ops.aten.view.default(convert_element_type_1004, [2, 8192, 8, -1, 2]); convert_element_type_1004 = None + view_as_complex_61 = torch.ops.aten.view_as_complex.default(view_1035); view_1035 = None + mul_242 = torch.ops.aten.mul.Tensor(view_as_complex_60, view_16); view_as_complex_60 = None + view_as_real_60 = torch.ops.aten.view_as_real.default(mul_242); mul_242 = None + view_1037 = torch.ops.aten.view.default(view_as_real_60, [2, 8192, 32, 128]); view_as_real_60 = None + mul_243 = torch.ops.aten.mul.Tensor(view_as_complex_61, view_16); view_as_complex_61 = None + view_as_real_61 = torch.ops.aten.view_as_real.default(mul_243); mul_243 = None + view_1038 = torch.ops.aten.view.default(view_as_real_61, [2, 8192, 8, 128]); view_as_real_61 = None + convert_element_type_1005 = torch.ops.prims.convert_element_type.default(view_1037, torch.bfloat16); view_1037 = None + convert_element_type_1006 = torch.ops.prims.convert_element_type.default(view_1038, torch.bfloat16); view_1038 = None + unsqueeze_60 = torch.ops.aten.unsqueeze.default(convert_element_type_1006, 3); convert_element_type_1006 = None + expand_60 = torch.ops.aten.expand.default(unsqueeze_60, [2, 8192, 8, 4, 128]); unsqueeze_60 = None + clone_60 = torch.ops.aten.clone.default(expand_60, memory_format = torch.contiguous_format); expand_60 = None + view_1039 = torch.ops.aten.view.default(clone_60, [2, 8192, 32, 128]); clone_60 = None + unsqueeze_61 = torch.ops.aten.unsqueeze.default(view_1033, 3); view_1033 = None + expand_61 = torch.ops.aten.expand.default(unsqueeze_61, [2, 8192, 8, 4, 128]); unsqueeze_61 = None + clone_61 = torch.ops.aten.clone.default(expand_61, memory_format = torch.contiguous_format); expand_61 = None + view_1040 = torch.ops.aten.view.default(clone_61, [2, 8192, 32, 128]); clone_61 = None + permute_333 = torch.ops.aten.permute.default(convert_element_type_1005, [0, 2, 1, 3]); 
convert_element_type_1005 = None + permute_334 = torch.ops.aten.permute.default(view_1039, [0, 2, 1, 3]); view_1039 = None + permute_335 = torch.ops.aten.permute.default(view_1040, [0, 2, 1, 3]); view_1040 = None + _scaled_dot_product_cudnn_attention_30 = torch.ops.aten._scaled_dot_product_cudnn_attention.default(permute_333, permute_334, permute_335, None, True, 0.0, True); permute_333 = permute_334 = permute_335 = None + getitem_270 = _scaled_dot_product_cudnn_attention_30[0] + getitem_271 = _scaled_dot_product_cudnn_attention_30[1] + getitem_276 = _scaled_dot_product_cudnn_attention_30[6] + getitem_277 = _scaled_dot_product_cudnn_attention_30[7]; _scaled_dot_product_cudnn_attention_30 = None + permute_336 = torch.ops.aten.permute.default(getitem_270, [0, 2, 1, 3]) + view_1041 = torch.ops.aten.view.default(permute_336, [2, 8192, -1]); permute_336 = None + convert_element_type_1007 = torch.ops.prims.convert_element_type.default(primals_278, torch.bfloat16) + all_gather_into_tensor_275 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1007, 256, '0'); convert_element_type_1007 = None + wait_tensor_275 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_275); all_gather_into_tensor_275 = None + permute_337 = torch.ops.aten.permute.default(wait_tensor_275, [1, 0]); wait_tensor_275 = None + view_1043 = torch.ops.aten.view.default(view_1041, [16384, 4096]); view_1041 = None + mm_213 = torch.ops.aten.mm.default(view_1043, permute_337); view_1043 = permute_337 = None + view_1044 = torch.ops.aten.view.default(mm_213, [2, 8192, 4096]); mm_213 = None + add_121 = torch.ops.aten.add.Tensor(add_119, view_1044); view_1044 = None + convert_element_type_1010 = torch.ops.prims.convert_element_type.default(primals_279, torch.bfloat16) + all_gather_into_tensor_276 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1010, 256, '0'); convert_element_type_1010 = None + wait_tensor_276 = 
torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_276); all_gather_into_tensor_276 = None + convert_element_type_1011 = torch.ops.prims.convert_element_type.default(add_121, torch.float32) + pow_62 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_1011, 2) + mean_61 = torch.ops.aten.mean.dim(pow_62, [2], True); pow_62 = None + add_122 = torch.ops.aten.add.Scalar(mean_61, 1e-05); mean_61 = None + rsqrt_61 = torch.ops.aten.rsqrt.default(add_122); add_122 = None + mul_244 = torch.ops.aten.mul.Tensor(convert_element_type_1011, rsqrt_61); convert_element_type_1011 = rsqrt_61 = None + mul_245 = torch.ops.aten.mul.Tensor(mul_244, wait_tensor_276); mul_244 = wait_tensor_276 = None + convert_element_type_1012 = torch.ops.prims.convert_element_type.default(mul_245, torch.bfloat16); mul_245 = None + convert_element_type_1013 = torch.ops.prims.convert_element_type.default(primals_280, torch.bfloat16) + all_gather_into_tensor_277 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1013, 256, '0'); convert_element_type_1013 = None + wait_tensor_277 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_277); all_gather_into_tensor_277 = None + permute_338 = torch.ops.aten.permute.default(wait_tensor_277, [1, 0]); wait_tensor_277 = None + view_1047 = torch.ops.aten.view.default(convert_element_type_1012, [16384, 4096]); convert_element_type_1012 = None + mm_214 = torch.ops.aten.mm.default(view_1047, permute_338); permute_338 = None + view_1048 = torch.ops.aten.view.default(mm_214, [2, 8192, 14336]) + convert_element_type_1016 = torch.ops.prims.convert_element_type.default(view_1048, torch.float32); view_1048 = None + sigmoid_30 = torch.ops.aten.sigmoid.default(convert_element_type_1016) + mul_246 = torch.ops.aten.mul.Tensor(convert_element_type_1016, sigmoid_30); convert_element_type_1016 = sigmoid_30 = None + convert_element_type_1017 = torch.ops.prims.convert_element_type.default(mul_246, torch.bfloat16); 
mul_246 = None + convert_element_type_1018 = torch.ops.prims.convert_element_type.default(primals_281, torch.bfloat16) + all_gather_into_tensor_278 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1018, 256, '0'); convert_element_type_1018 = None + wait_tensor_278 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_278); all_gather_into_tensor_278 = None + permute_339 = torch.ops.aten.permute.default(wait_tensor_278, [1, 0]); wait_tensor_278 = None + mm_215 = torch.ops.aten.mm.default(view_1047, permute_339); view_1047 = permute_339 = None + view_1051 = torch.ops.aten.view.default(mm_215, [2, 8192, 14336]); mm_215 = None + mul_247 = torch.ops.aten.mul.Tensor(convert_element_type_1017, view_1051); convert_element_type_1017 = view_1051 = None + convert_element_type_1021 = torch.ops.prims.convert_element_type.default(primals_282, torch.bfloat16) + all_gather_into_tensor_279 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1021, 256, '0'); convert_element_type_1021 = None + wait_tensor_279 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_279); all_gather_into_tensor_279 = None + permute_340 = torch.ops.aten.permute.default(wait_tensor_279, [1, 0]); wait_tensor_279 = None + view_1053 = torch.ops.aten.view.default(mul_247, [16384, 14336]); mul_247 = None + mm_216 = torch.ops.aten.mm.default(view_1053, permute_340); view_1053 = permute_340 = None + view_1054 = torch.ops.aten.view.default(mm_216, [2, 8192, 4096]); mm_216 = None + add_123 = torch.ops.aten.add.Tensor(add_121, view_1054); add_121 = view_1054 = None + convert_element_type_1024 = torch.ops.prims.convert_element_type.default(primals_283, torch.bfloat16) + all_gather_into_tensor_280 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1024, 256, '0'); convert_element_type_1024 = None + wait_tensor_280 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_280); 
all_gather_into_tensor_280 = None + convert_element_type_1025 = torch.ops.prims.convert_element_type.default(add_123, torch.float32) + pow_63 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_1025, 2) + mean_62 = torch.ops.aten.mean.dim(pow_63, [2], True); pow_63 = None + add_124 = torch.ops.aten.add.Scalar(mean_62, 1e-05); mean_62 = None + rsqrt_62 = torch.ops.aten.rsqrt.default(add_124); add_124 = None + mul_248 = torch.ops.aten.mul.Tensor(convert_element_type_1025, rsqrt_62); convert_element_type_1025 = rsqrt_62 = None + mul_249 = torch.ops.aten.mul.Tensor(mul_248, wait_tensor_280); mul_248 = wait_tensor_280 = None + convert_element_type_1026 = torch.ops.prims.convert_element_type.default(mul_249, torch.bfloat16); mul_249 = None + convert_element_type_1027 = torch.ops.prims.convert_element_type.default(primals_284, torch.bfloat16) + all_gather_into_tensor_281 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1027, 256, '0'); convert_element_type_1027 = None + wait_tensor_281 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_281); all_gather_into_tensor_281 = None + permute_341 = torch.ops.aten.permute.default(wait_tensor_281, [1, 0]); wait_tensor_281 = None + view_1057 = torch.ops.aten.view.default(convert_element_type_1026, [16384, 4096]); convert_element_type_1026 = None + mm_217 = torch.ops.aten.mm.default(view_1057, permute_341); permute_341 = None + view_1058 = torch.ops.aten.view.default(mm_217, [2, 8192, 4096]) + convert_element_type_1030 = torch.ops.prims.convert_element_type.default(primals_285, torch.bfloat16) + all_gather_into_tensor_282 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1030, 256, '0'); convert_element_type_1030 = None + wait_tensor_282 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_282); all_gather_into_tensor_282 = None + permute_342 = torch.ops.aten.permute.default(wait_tensor_282, [1, 0]); wait_tensor_282 = None + 
mm_218 = torch.ops.aten.mm.default(view_1057, permute_342); permute_342 = None + view_1061 = torch.ops.aten.view.default(mm_218, [2, 8192, 1024]); mm_218 = None + convert_element_type_1033 = torch.ops.prims.convert_element_type.default(primals_286, torch.bfloat16) + all_gather_into_tensor_283 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1033, 256, '0'); convert_element_type_1033 = None + wait_tensor_283 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_283); all_gather_into_tensor_283 = None + permute_343 = torch.ops.aten.permute.default(wait_tensor_283, [1, 0]); wait_tensor_283 = None + mm_219 = torch.ops.aten.mm.default(view_1057, permute_343); view_1057 = permute_343 = None + view_1064 = torch.ops.aten.view.default(mm_219, [2, 8192, 1024]) + view_1065 = torch.ops.aten.view.default(view_1058, [2, 8192, -1, 128]); view_1058 = None + view_1066 = torch.ops.aten.view.default(view_1061, [2, 8192, -1, 128]); view_1061 = None + view_1067 = torch.ops.aten.view.default(view_1064, [2, 8192, -1, 128]); view_1064 = None + convert_element_type_1036 = torch.ops.prims.convert_element_type.default(view_1065, torch.float32); view_1065 = None + view_1068 = torch.ops.aten.view.default(convert_element_type_1036, [2, 8192, 32, -1, 2]); convert_element_type_1036 = None + view_as_complex_62 = torch.ops.aten.view_as_complex.default(view_1068); view_1068 = None + convert_element_type_1037 = torch.ops.prims.convert_element_type.default(view_1066, torch.float32); view_1066 = None + view_1069 = torch.ops.aten.view.default(convert_element_type_1037, [2, 8192, 8, -1, 2]); convert_element_type_1037 = None + view_as_complex_63 = torch.ops.aten.view_as_complex.default(view_1069); view_1069 = None + mul_250 = torch.ops.aten.mul.Tensor(view_as_complex_62, view_16); view_as_complex_62 = None + view_as_real_62 = torch.ops.aten.view_as_real.default(mul_250); mul_250 = None + view_1071 = torch.ops.aten.view.default(view_as_real_62, [2, 8192, 
32, 128]); view_as_real_62 = None + mul_251 = torch.ops.aten.mul.Tensor(view_as_complex_63, view_16); view_as_complex_63 = view_16 = None + view_as_real_63 = torch.ops.aten.view_as_real.default(mul_251); mul_251 = None + view_1072 = torch.ops.aten.view.default(view_as_real_63, [2, 8192, 8, 128]); view_as_real_63 = None + convert_element_type_1038 = torch.ops.prims.convert_element_type.default(view_1071, torch.bfloat16); view_1071 = None + convert_element_type_1039 = torch.ops.prims.convert_element_type.default(view_1072, torch.bfloat16); view_1072 = None + unsqueeze_62 = torch.ops.aten.unsqueeze.default(convert_element_type_1039, 3); convert_element_type_1039 = None + expand_62 = torch.ops.aten.expand.default(unsqueeze_62, [2, 8192, 8, 4, 128]); unsqueeze_62 = None + clone_62 = torch.ops.aten.clone.default(expand_62, memory_format = torch.contiguous_format); expand_62 = None + view_1073 = torch.ops.aten.view.default(clone_62, [2, 8192, 32, 128]); clone_62 = None + unsqueeze_63 = torch.ops.aten.unsqueeze.default(view_1067, 3); view_1067 = None + expand_63 = torch.ops.aten.expand.default(unsqueeze_63, [2, 8192, 8, 4, 128]); unsqueeze_63 = None + clone_63 = torch.ops.aten.clone.default(expand_63, memory_format = torch.contiguous_format); expand_63 = None + view_1074 = torch.ops.aten.view.default(clone_63, [2, 8192, 32, 128]); clone_63 = None + permute_344 = torch.ops.aten.permute.default(convert_element_type_1038, [0, 2, 1, 3]); convert_element_type_1038 = None + permute_345 = torch.ops.aten.permute.default(view_1073, [0, 2, 1, 3]); view_1073 = None + permute_346 = torch.ops.aten.permute.default(view_1074, [0, 2, 1, 3]); view_1074 = None + _scaled_dot_product_cudnn_attention_31 = torch.ops.aten._scaled_dot_product_cudnn_attention.default(permute_344, permute_345, permute_346, None, True, 0.0, True); permute_344 = permute_345 = permute_346 = None + getitem_279 = _scaled_dot_product_cudnn_attention_31[0] + getitem_280 = _scaled_dot_product_cudnn_attention_31[1] + 
getitem_285 = _scaled_dot_product_cudnn_attention_31[6] + getitem_286 = _scaled_dot_product_cudnn_attention_31[7]; _scaled_dot_product_cudnn_attention_31 = None + permute_347 = torch.ops.aten.permute.default(getitem_279, [0, 2, 1, 3]) + view_1075 = torch.ops.aten.view.default(permute_347, [2, 8192, -1]); permute_347 = None + convert_element_type_1040 = torch.ops.prims.convert_element_type.default(primals_287, torch.bfloat16) + all_gather_into_tensor_284 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1040, 256, '0'); convert_element_type_1040 = None + wait_tensor_284 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_284); all_gather_into_tensor_284 = None + permute_348 = torch.ops.aten.permute.default(wait_tensor_284, [1, 0]); wait_tensor_284 = None + view_1077 = torch.ops.aten.view.default(view_1075, [16384, 4096]); view_1075 = None + mm_220 = torch.ops.aten.mm.default(view_1077, permute_348); view_1077 = permute_348 = None + view_1078 = torch.ops.aten.view.default(mm_220, [2, 8192, 4096]); mm_220 = None + add_125 = torch.ops.aten.add.Tensor(add_123, view_1078); view_1078 = None + convert_element_type_1043 = torch.ops.prims.convert_element_type.default(primals_288, torch.bfloat16) + all_gather_into_tensor_285 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1043, 256, '0'); convert_element_type_1043 = None + wait_tensor_285 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_285); all_gather_into_tensor_285 = None + convert_element_type_1044 = torch.ops.prims.convert_element_type.default(add_125, torch.float32) + pow_64 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_1044, 2) + mean_63 = torch.ops.aten.mean.dim(pow_64, [2], True); pow_64 = None + add_126 = torch.ops.aten.add.Scalar(mean_63, 1e-05); mean_63 = None + rsqrt_63 = torch.ops.aten.rsqrt.default(add_126); add_126 = None + mul_252 = torch.ops.aten.mul.Tensor(convert_element_type_1044, 
rsqrt_63); convert_element_type_1044 = rsqrt_63 = None + mul_253 = torch.ops.aten.mul.Tensor(mul_252, wait_tensor_285); mul_252 = wait_tensor_285 = None + convert_element_type_1045 = torch.ops.prims.convert_element_type.default(mul_253, torch.bfloat16); mul_253 = None + convert_element_type_1046 = torch.ops.prims.convert_element_type.default(primals_289, torch.bfloat16) + all_gather_into_tensor_286 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1046, 256, '0'); convert_element_type_1046 = None + wait_tensor_286 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_286); all_gather_into_tensor_286 = None + permute_349 = torch.ops.aten.permute.default(wait_tensor_286, [1, 0]); wait_tensor_286 = None + view_1081 = torch.ops.aten.view.default(convert_element_type_1045, [16384, 4096]); convert_element_type_1045 = None + mm_221 = torch.ops.aten.mm.default(view_1081, permute_349); permute_349 = None + view_1082 = torch.ops.aten.view.default(mm_221, [2, 8192, 14336]) + convert_element_type_1049 = torch.ops.prims.convert_element_type.default(view_1082, torch.float32); view_1082 = None + sigmoid_31 = torch.ops.aten.sigmoid.default(convert_element_type_1049) + mul_254 = torch.ops.aten.mul.Tensor(convert_element_type_1049, sigmoid_31); convert_element_type_1049 = sigmoid_31 = None + convert_element_type_1050 = torch.ops.prims.convert_element_type.default(mul_254, torch.bfloat16); mul_254 = None + convert_element_type_1051 = torch.ops.prims.convert_element_type.default(primals_290, torch.bfloat16) + all_gather_into_tensor_287 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1051, 256, '0'); convert_element_type_1051 = None + wait_tensor_287 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_287); all_gather_into_tensor_287 = None + permute_350 = torch.ops.aten.permute.default(wait_tensor_287, [1, 0]); wait_tensor_287 = None + mm_222 = torch.ops.aten.mm.default(view_1081, 
permute_350); view_1081 = permute_350 = None + view_1085 = torch.ops.aten.view.default(mm_222, [2, 8192, 14336]); mm_222 = None + mul_255 = torch.ops.aten.mul.Tensor(convert_element_type_1050, view_1085); convert_element_type_1050 = view_1085 = None + convert_element_type_1054 = torch.ops.prims.convert_element_type.default(primals_291, torch.bfloat16) + all_gather_into_tensor_288 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1054, 256, '0'); convert_element_type_1054 = None + wait_tensor_288 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_288); all_gather_into_tensor_288 = None + permute_351 = torch.ops.aten.permute.default(wait_tensor_288, [1, 0]); wait_tensor_288 = None + view_1087 = torch.ops.aten.view.default(mul_255, [16384, 14336]); mul_255 = None + mm_223 = torch.ops.aten.mm.default(view_1087, permute_351); view_1087 = permute_351 = None + view_1088 = torch.ops.aten.view.default(mm_223, [2, 8192, 4096]) + add_127 = torch.ops.aten.add.Tensor(add_125, view_1088); add_125 = view_1088 = None + convert_element_type_1057 = torch.ops.prims.convert_element_type.default(primals_292, torch.bfloat16) + all_gather_into_tensor_289 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1057, 256, '0'); convert_element_type_1057 = None + wait_tensor_289 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_289); all_gather_into_tensor_289 = None + convert_element_type_1058 = torch.ops.prims.convert_element_type.default(add_127, torch.float32); add_127 = None + pow_65 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_1058, 2) + mean_64 = torch.ops.aten.mean.dim(pow_65, [2], True); pow_65 = None + add_128 = torch.ops.aten.add.Scalar(mean_64, 1e-05); mean_64 = None + rsqrt_64 = torch.ops.aten.rsqrt.default(add_128); add_128 = None + mul_256 = torch.ops.aten.mul.Tensor(convert_element_type_1058, rsqrt_64); convert_element_type_1058 = None + mul_257 = 
torch.ops.aten.mul.Tensor(mul_256, wait_tensor_289); mul_256 = wait_tensor_289 = None + convert_element_type_1059 = torch.ops.prims.convert_element_type.default(mul_257, torch.bfloat16); mul_257 = None + convert_element_type_1060 = torch.ops.prims.convert_element_type.default(primals_293, torch.bfloat16) + all_gather_into_tensor_290 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1060, 256, '0'); convert_element_type_1060 = None + wait_tensor_290 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_290); all_gather_into_tensor_290 = None + permute_352 = torch.ops.aten.permute.default(wait_tensor_290, [1, 0]); wait_tensor_290 = None + view_1091 = torch.ops.aten.view.default(convert_element_type_1059, [16384, 4096]); convert_element_type_1059 = None + mm_224 = torch.ops.aten.mm.default(view_1091, permute_352); permute_352 = None + view_1092 = torch.ops.aten.view.default(mm_224, [2, 8192, 128256]); mm_224 = None + return (view_1092, primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8, primals_9, primals_10, primals_11, primals_12, primals_13, primals_14, primals_15, primals_16, primals_17, primals_18, primals_19, primals_20, primals_21, primals_22, primals_23, primals_24, primals_25, primals_26, primals_27, primals_28, primals_29, primals_30, primals_31, primals_32, primals_33, primals_34, primals_35, primals_36, primals_37, primals_38, primals_39, primals_40, primals_41, primals_42, primals_43, primals_44, primals_45, primals_46, primals_47, primals_48, primals_49, primals_50, primals_51, primals_52, primals_53, primals_54, primals_55, primals_56, primals_57, primals_58, primals_59, primals_60, primals_61, primals_62, primals_63, primals_64, primals_65, primals_66, primals_67, primals_68, primals_69, primals_70, primals_71, primals_72, primals_73, primals_74, primals_75, primals_76, primals_77, primals_78, primals_79, primals_80, primals_81, primals_82, primals_83, primals_84, 
primals_85, primals_86, primals_87, primals_88, primals_89, primals_90, primals_91, primals_92, primals_93, primals_94, primals_95, primals_96, primals_97, primals_98, primals_99, primals_100, primals_101, primals_102, primals_103, primals_104, primals_105, primals_106, primals_107, primals_108, primals_109, primals_110, primals_111, primals_112, primals_113, primals_114, primals_115, primals_116, primals_117, primals_118, primals_119, primals_120, primals_121, primals_122, primals_123, primals_124, primals_125, primals_126, primals_127, primals_128, primals_129, primals_130, primals_131, primals_132, primals_133, primals_134, primals_135, primals_136, primals_137, primals_138, primals_139, primals_140, primals_141, primals_142, primals_143, primals_144, primals_145, primals_146, primals_147, primals_148, primals_149, primals_150, primals_151, primals_152, primals_153, primals_154, primals_155, primals_156, primals_157, primals_158, primals_159, primals_160, primals_161, primals_162, primals_163, primals_164, primals_165, primals_166, primals_167, primals_168, primals_169, primals_170, primals_171, primals_172, primals_173, primals_174, primals_175, primals_176, primals_177, primals_178, primals_179, primals_180, primals_181, primals_182, primals_183, primals_184, primals_185, primals_186, primals_187, primals_188, primals_189, primals_190, primals_191, primals_192, primals_193, primals_194, primals_195, primals_196, primals_197, primals_198, primals_199, primals_200, primals_201, primals_202, primals_203, primals_204, primals_205, primals_206, primals_207, primals_208, primals_209, primals_210, primals_211, primals_212, primals_213, primals_214, primals_215, primals_216, primals_217, primals_218, primals_219, primals_220, primals_221, primals_222, primals_223, primals_224, primals_225, primals_226, primals_227, primals_228, primals_229, primals_230, primals_231, primals_232, primals_233, primals_234, primals_235, primals_236, primals_237, primals_238, primals_239, 
primals_240, primals_241, primals_242, primals_243, primals_244, primals_245, primals_246, primals_247, primals_248, primals_249, primals_250, primals_251, primals_252, primals_253, primals_254, primals_255, primals_256, primals_257, primals_258, primals_259, primals_260, primals_261, primals_262, primals_263, primals_264, primals_265, primals_266, primals_267, primals_268, primals_269, primals_270, primals_271, primals_272, primals_273, primals_274, primals_275, primals_276, primals_277, primals_278, primals_279, primals_280, primals_281, primals_282, primals_283, primals_284, primals_285, primals_286, primals_287, primals_288, primals_289, primals_290, primals_291, primals_292, primals_293, embedding, mm, mm_2, getitem, getitem_1, getitem_6, getitem_7, mm_4, add_3, mm_7, mm_9, getitem_9, getitem_10, getitem_15, getitem_16, mm_11, add_7, mm_14, mm_16, getitem_18, getitem_19, getitem_24, getitem_25, mm_18, add_11, mm_21, mm_23, getitem_27, getitem_28, getitem_33, getitem_34, mm_25, add_15, mm_28, mm_30, getitem_36, getitem_37, getitem_42, getitem_43, mm_32, add_19, mm_35, mm_37, getitem_45, getitem_46, getitem_51, getitem_52, mm_39, add_23, mm_42, mm_44, getitem_54, getitem_55, getitem_60, getitem_61, mm_46, add_27, mm_49, mm_51, getitem_63, getitem_64, getitem_69, getitem_70, mm_53, add_31, mm_56, mm_58, getitem_72, getitem_73, getitem_78, getitem_79, mm_60, add_35, mm_63, mm_65, getitem_81, getitem_82, getitem_87, getitem_88, mm_67, add_39, mm_70, mm_72, getitem_90, getitem_91, getitem_96, getitem_97, mm_74, add_43, mm_77, mm_79, getitem_99, getitem_100, getitem_105, getitem_106, mm_81, add_47, mm_84, mm_86, getitem_108, getitem_109, getitem_114, getitem_115, mm_88, add_51, mm_91, mm_93, getitem_117, getitem_118, getitem_123, getitem_124, mm_95, add_55, mm_98, mm_100, getitem_126, getitem_127, getitem_132, getitem_133, mm_102, add_59, mm_105, mm_107, getitem_135, getitem_136, getitem_141, getitem_142, mm_109, add_63, mm_112, mm_114, getitem_144, getitem_145, 
getitem_150, getitem_151, mm_116, add_67, mm_119, mm_121, getitem_153, getitem_154, getitem_159, getitem_160, mm_123, add_71, mm_126, mm_128, getitem_162, getitem_163, getitem_168, getitem_169, mm_130, add_75, mm_133, mm_135, getitem_171, getitem_172, getitem_177, getitem_178, mm_137, add_79, mm_140, mm_142, getitem_180, getitem_181, getitem_186, getitem_187, mm_144, add_83, mm_147, mm_149, getitem_189, getitem_190, getitem_195, getitem_196, mm_151, add_87, mm_154, mm_156, getitem_198, getitem_199, getitem_204, getitem_205, mm_158, add_91, mm_161, mm_163, getitem_207, getitem_208, getitem_213, getitem_214, mm_165, add_95, mm_168, mm_170, getitem_216, getitem_217, getitem_222, getitem_223, mm_172, add_99, mm_175, mm_177, getitem_225, getitem_226, getitem_231, getitem_232, mm_179, add_103, mm_182, mm_184, getitem_234, getitem_235, getitem_240, getitem_241, mm_186, add_107, mm_189, mm_191, getitem_243, getitem_244, getitem_249, getitem_250, mm_193, add_111, mm_196, mm_198, getitem_252, getitem_253, getitem_258, getitem_259, mm_200, add_115, mm_203, mm_205, getitem_261, getitem_262, getitem_267, getitem_268, mm_207, add_119, mm_210, mm_212, getitem_270, getitem_271, getitem_276, getitem_277, mm_214, add_123, mm_217, mm_219, getitem_279, getitem_280, getitem_285, getitem_286, mm_221, mm_223, rsqrt_64, view_1091) + +def load_args(reader): + buf0 = reader.storage(None, 8208384, device=device(type='cuda', index=0)) + reader.tensor(buf0, (501, 4096), is_leaf=True) # primals_1 + buf1 = reader.storage(None, 131072, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf1, (2, 8192), dtype=torch.int64, is_leaf=True) # primals_2 + buf2 = reader.storage(None, 4194304, device=device(type='cuda', index=0), dtype_hint=torch.complex64) + reader.tensor(buf2, (8192, 64), dtype=torch.complex64, is_leaf=True) # primals_3 + buf3 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf3, (16,), is_leaf=True) # primals_4 + buf4 = 
reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf4, (16, 4096), is_leaf=True) # primals_5 + buf5 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf5, (4, 4096), is_leaf=True) # primals_6 + buf6 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf6, (4, 4096), is_leaf=True) # primals_7 + buf7 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf7, (16, 4096), is_leaf=True) # primals_8 + buf8 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf8, (16,), is_leaf=True) # primals_9 + buf9 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf9, (56, 4096), is_leaf=True) # primals_10 + buf10 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf10, (56, 4096), is_leaf=True) # primals_11 + buf11 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf11, (16, 14336), is_leaf=True) # primals_12 + buf12 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf12, (16,), is_leaf=True) # primals_13 + buf13 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf13, (16, 4096), is_leaf=True) # primals_14 + buf14 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf14, (4, 4096), is_leaf=True) # primals_15 + buf15 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf15, (4, 4096), is_leaf=True) # primals_16 + buf16 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf16, (16, 4096), is_leaf=True) # primals_17 + buf17 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf17, (16,), is_leaf=True) # primals_18 + buf18 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf18, (56, 4096), is_leaf=True) # 
primals_19 + buf19 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf19, (56, 4096), is_leaf=True) # primals_20 + buf20 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf20, (16, 14336), is_leaf=True) # primals_21 + buf21 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf21, (16,), is_leaf=True) # primals_22 + buf22 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf22, (16, 4096), is_leaf=True) # primals_23 + buf23 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf23, (4, 4096), is_leaf=True) # primals_24 + buf24 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf24, (4, 4096), is_leaf=True) # primals_25 + buf25 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf25, (16, 4096), is_leaf=True) # primals_26 + buf26 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf26, (16,), is_leaf=True) # primals_27 + buf27 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf27, (56, 4096), is_leaf=True) # primals_28 + buf28 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf28, (56, 4096), is_leaf=True) # primals_29 + buf29 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf29, (16, 14336), is_leaf=True) # primals_30 + buf30 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf30, (16,), is_leaf=True) # primals_31 + buf31 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf31, (16, 4096), is_leaf=True) # primals_32 + buf32 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf32, (4, 4096), is_leaf=True) # primals_33 + buf33 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + 
reader.tensor(buf33, (4, 4096), is_leaf=True) # primals_34 + buf34 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf34, (16, 4096), is_leaf=True) # primals_35 + buf35 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf35, (16,), is_leaf=True) # primals_36 + buf36 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf36, (56, 4096), is_leaf=True) # primals_37 + buf37 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf37, (56, 4096), is_leaf=True) # primals_38 + buf38 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf38, (16, 14336), is_leaf=True) # primals_39 + buf39 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf39, (16,), is_leaf=True) # primals_40 + buf40 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf40, (16, 4096), is_leaf=True) # primals_41 + buf41 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf41, (4, 4096), is_leaf=True) # primals_42 + buf42 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf42, (4, 4096), is_leaf=True) # primals_43 + buf43 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf43, (16, 4096), is_leaf=True) # primals_44 + buf44 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf44, (16,), is_leaf=True) # primals_45 + buf45 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf45, (56, 4096), is_leaf=True) # primals_46 + buf46 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf46, (56, 4096), is_leaf=True) # primals_47 + buf47 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf47, (16, 14336), is_leaf=True) # primals_48 + buf48 = reader.storage(None, 64, 
device=device(type='cuda', index=0)) + reader.tensor(buf48, (16,), is_leaf=True) # primals_49 + buf49 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf49, (16, 4096), is_leaf=True) # primals_50 + buf50 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf50, (4, 4096), is_leaf=True) # primals_51 + buf51 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf51, (4, 4096), is_leaf=True) # primals_52 + buf52 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf52, (16, 4096), is_leaf=True) # primals_53 + buf53 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf53, (16,), is_leaf=True) # primals_54 + buf54 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf54, (56, 4096), is_leaf=True) # primals_55 + buf55 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf55, (56, 4096), is_leaf=True) # primals_56 + buf56 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf56, (16, 14336), is_leaf=True) # primals_57 + buf57 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf57, (16,), is_leaf=True) # primals_58 + buf58 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf58, (16, 4096), is_leaf=True) # primals_59 + buf59 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf59, (4, 4096), is_leaf=True) # primals_60 + buf60 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf60, (4, 4096), is_leaf=True) # primals_61 + buf61 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf61, (16, 4096), is_leaf=True) # primals_62 + buf62 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf62, (16,), is_leaf=True) # primals_63 + buf63 
= reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf63, (56, 4096), is_leaf=True) # primals_64 + buf64 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf64, (56, 4096), is_leaf=True) # primals_65 + buf65 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf65, (16, 14336), is_leaf=True) # primals_66 + buf66 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf66, (16,), is_leaf=True) # primals_67 + buf67 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf67, (16, 4096), is_leaf=True) # primals_68 + buf68 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf68, (4, 4096), is_leaf=True) # primals_69 + buf69 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf69, (4, 4096), is_leaf=True) # primals_70 + buf70 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf70, (16, 4096), is_leaf=True) # primals_71 + buf71 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf71, (16,), is_leaf=True) # primals_72 + buf72 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf72, (56, 4096), is_leaf=True) # primals_73 + buf73 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf73, (56, 4096), is_leaf=True) # primals_74 + buf74 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf74, (16, 14336), is_leaf=True) # primals_75 + buf75 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf75, (16,), is_leaf=True) # primals_76 + buf76 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf76, (16, 4096), is_leaf=True) # primals_77 + buf77 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf77, (4, 
4096), is_leaf=True) # primals_78 + buf78 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf78, (4, 4096), is_leaf=True) # primals_79 + buf79 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf79, (16, 4096), is_leaf=True) # primals_80 + buf80 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf80, (16,), is_leaf=True) # primals_81 + buf81 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf81, (56, 4096), is_leaf=True) # primals_82 + buf82 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf82, (56, 4096), is_leaf=True) # primals_83 + buf83 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf83, (16, 14336), is_leaf=True) # primals_84 + buf84 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf84, (16,), is_leaf=True) # primals_85 + buf85 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf85, (16, 4096), is_leaf=True) # primals_86 + buf86 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf86, (4, 4096), is_leaf=True) # primals_87 + buf87 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf87, (4, 4096), is_leaf=True) # primals_88 + buf88 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf88, (16, 4096), is_leaf=True) # primals_89 + buf89 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf89, (16,), is_leaf=True) # primals_90 + buf90 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf90, (56, 4096), is_leaf=True) # primals_91 + buf91 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf91, (56, 4096), is_leaf=True) # primals_92 + buf92 = reader.storage(None, 917504, device=device(type='cuda', 
index=0)) + reader.tensor(buf92, (16, 14336), is_leaf=True) # primals_93 + buf93 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf93, (16,), is_leaf=True) # primals_94 + buf94 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf94, (16, 4096), is_leaf=True) # primals_95 + buf95 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf95, (4, 4096), is_leaf=True) # primals_96 + buf96 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf96, (4, 4096), is_leaf=True) # primals_97 + buf97 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf97, (16, 4096), is_leaf=True) # primals_98 + buf98 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf98, (16,), is_leaf=True) # primals_99 + buf99 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf99, (56, 4096), is_leaf=True) # primals_100 + buf100 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf100, (56, 4096), is_leaf=True) # primals_101 + buf101 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf101, (16, 14336), is_leaf=True) # primals_102 + buf102 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf102, (16,), is_leaf=True) # primals_103 + buf103 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf103, (16, 4096), is_leaf=True) # primals_104 + buf104 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf104, (4, 4096), is_leaf=True) # primals_105 + buf105 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf105, (4, 4096), is_leaf=True) # primals_106 + buf106 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf106, (16, 4096), is_leaf=True) # primals_107 + 
buf107 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf107, (16,), is_leaf=True) # primals_108 + buf108 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf108, (56, 4096), is_leaf=True) # primals_109 + buf109 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf109, (56, 4096), is_leaf=True) # primals_110 + buf110 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf110, (16, 14336), is_leaf=True) # primals_111 + buf111 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf111, (16,), is_leaf=True) # primals_112 + buf112 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf112, (16, 4096), is_leaf=True) # primals_113 + buf113 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf113, (4, 4096), is_leaf=True) # primals_114 + buf114 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf114, (4, 4096), is_leaf=True) # primals_115 + buf115 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf115, (16, 4096), is_leaf=True) # primals_116 + buf116 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf116, (16,), is_leaf=True) # primals_117 + buf117 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf117, (56, 4096), is_leaf=True) # primals_118 + buf118 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf118, (56, 4096), is_leaf=True) # primals_119 + buf119 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf119, (16, 14336), is_leaf=True) # primals_120 + buf120 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf120, (16,), is_leaf=True) # primals_121 + buf121 = reader.storage(None, 262144, 
device=device(type='cuda', index=0)) + reader.tensor(buf121, (16, 4096), is_leaf=True) # primals_122 + buf122 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf122, (4, 4096), is_leaf=True) # primals_123 + buf123 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf123, (4, 4096), is_leaf=True) # primals_124 + buf124 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf124, (16, 4096), is_leaf=True) # primals_125 + buf125 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf125, (16,), is_leaf=True) # primals_126 + buf126 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf126, (56, 4096), is_leaf=True) # primals_127 + buf127 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf127, (56, 4096), is_leaf=True) # primals_128 + buf128 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf128, (16, 14336), is_leaf=True) # primals_129 + buf129 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf129, (16,), is_leaf=True) # primals_130 + buf130 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf130, (16, 4096), is_leaf=True) # primals_131 + buf131 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf131, (4, 4096), is_leaf=True) # primals_132 + buf132 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf132, (4, 4096), is_leaf=True) # primals_133 + buf133 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf133, (16, 4096), is_leaf=True) # primals_134 + buf134 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf134, (16,), is_leaf=True) # primals_135 + buf135 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + 
reader.tensor(buf135, (56, 4096), is_leaf=True) # primals_136 + buf136 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf136, (56, 4096), is_leaf=True) # primals_137 + buf137 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf137, (16, 14336), is_leaf=True) # primals_138 + buf138 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf138, (16,), is_leaf=True) # primals_139 + buf139 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf139, (16, 4096), is_leaf=True) # primals_140 + buf140 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf140, (4, 4096), is_leaf=True) # primals_141 + buf141 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf141, (4, 4096), is_leaf=True) # primals_142 + buf142 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf142, (16, 4096), is_leaf=True) # primals_143 + buf143 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf143, (16,), is_leaf=True) # primals_144 + buf144 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf144, (56, 4096), is_leaf=True) # primals_145 + buf145 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf145, (56, 4096), is_leaf=True) # primals_146 + buf146 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf146, (16, 14336), is_leaf=True) # primals_147 + buf147 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf147, (16,), is_leaf=True) # primals_148 + buf148 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf148, (16, 4096), is_leaf=True) # primals_149 + buf149 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf149, (4, 4096), is_leaf=True) # 
primals_150 + buf150 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf150, (4, 4096), is_leaf=True) # primals_151 + buf151 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf151, (16, 4096), is_leaf=True) # primals_152 + buf152 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf152, (16,), is_leaf=True) # primals_153 + buf153 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf153, (56, 4096), is_leaf=True) # primals_154 + buf154 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf154, (56, 4096), is_leaf=True) # primals_155 + buf155 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf155, (16, 14336), is_leaf=True) # primals_156 + buf156 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf156, (16,), is_leaf=True) # primals_157 + buf157 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf157, (16, 4096), is_leaf=True) # primals_158 + buf158 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf158, (4, 4096), is_leaf=True) # primals_159 + buf159 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf159, (4, 4096), is_leaf=True) # primals_160 + buf160 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf160, (16, 4096), is_leaf=True) # primals_161 + buf161 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf161, (16,), is_leaf=True) # primals_162 + buf162 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf162, (56, 4096), is_leaf=True) # primals_163 + buf163 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf163, (56, 4096), is_leaf=True) # primals_164 + buf164 = reader.storage(None, 917504, 
device=device(type='cuda', index=0)) + reader.tensor(buf164, (16, 14336), is_leaf=True) # primals_165 + buf165 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf165, (16,), is_leaf=True) # primals_166 + buf166 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf166, (16, 4096), is_leaf=True) # primals_167 + buf167 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf167, (4, 4096), is_leaf=True) # primals_168 + buf168 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf168, (4, 4096), is_leaf=True) # primals_169 + buf169 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf169, (16, 4096), is_leaf=True) # primals_170 + buf170 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf170, (16,), is_leaf=True) # primals_171 + buf171 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf171, (56, 4096), is_leaf=True) # primals_172 + buf172 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf172, (56, 4096), is_leaf=True) # primals_173 + buf173 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf173, (16, 14336), is_leaf=True) # primals_174 + buf174 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf174, (16,), is_leaf=True) # primals_175 + buf175 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf175, (16, 4096), is_leaf=True) # primals_176 + buf176 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf176, (4, 4096), is_leaf=True) # primals_177 + buf177 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf177, (4, 4096), is_leaf=True) # primals_178 + buf178 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + 
reader.tensor(buf178, (16, 4096), is_leaf=True) # primals_179 + buf179 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf179, (16,), is_leaf=True) # primals_180 + buf180 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf180, (56, 4096), is_leaf=True) # primals_181 + buf181 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf181, (56, 4096), is_leaf=True) # primals_182 + buf182 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf182, (16, 14336), is_leaf=True) # primals_183 + buf183 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf183, (16,), is_leaf=True) # primals_184 + buf184 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf184, (16, 4096), is_leaf=True) # primals_185 + buf185 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf185, (4, 4096), is_leaf=True) # primals_186 + buf186 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf186, (4, 4096), is_leaf=True) # primals_187 + buf187 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf187, (16, 4096), is_leaf=True) # primals_188 + buf188 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf188, (16,), is_leaf=True) # primals_189 + buf189 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf189, (56, 4096), is_leaf=True) # primals_190 + buf190 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf190, (56, 4096), is_leaf=True) # primals_191 + buf191 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf191, (16, 14336), is_leaf=True) # primals_192 + buf192 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf192, (16,), is_leaf=True) # primals_193 + 
buf193 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf193, (16, 4096), is_leaf=True) # primals_194 + buf194 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf194, (4, 4096), is_leaf=True) # primals_195 + buf195 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf195, (4, 4096), is_leaf=True) # primals_196 + buf196 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf196, (16, 4096), is_leaf=True) # primals_197 + buf197 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf197, (16,), is_leaf=True) # primals_198 + buf198 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf198, (56, 4096), is_leaf=True) # primals_199 + buf199 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf199, (56, 4096), is_leaf=True) # primals_200 + buf200 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf200, (16, 14336), is_leaf=True) # primals_201 + buf201 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf201, (16,), is_leaf=True) # primals_202 + buf202 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf202, (16, 4096), is_leaf=True) # primals_203 + buf203 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf203, (4, 4096), is_leaf=True) # primals_204 + buf204 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf204, (4, 4096), is_leaf=True) # primals_205 + buf205 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf205, (16, 4096), is_leaf=True) # primals_206 + buf206 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf206, (16,), is_leaf=True) # primals_207 + buf207 = reader.storage(None, 917504, 
device=device(type='cuda', index=0)) + reader.tensor(buf207, (56, 4096), is_leaf=True) # primals_208 + buf208 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf208, (56, 4096), is_leaf=True) # primals_209 + buf209 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf209, (16, 14336), is_leaf=True) # primals_210 + buf210 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf210, (16,), is_leaf=True) # primals_211 + buf211 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf211, (16, 4096), is_leaf=True) # primals_212 + buf212 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf212, (4, 4096), is_leaf=True) # primals_213 + buf213 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf213, (4, 4096), is_leaf=True) # primals_214 + buf214 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf214, (16, 4096), is_leaf=True) # primals_215 + buf215 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf215, (16,), is_leaf=True) # primals_216 + buf216 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf216, (56, 4096), is_leaf=True) # primals_217 + buf217 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf217, (56, 4096), is_leaf=True) # primals_218 + buf218 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf218, (16, 14336), is_leaf=True) # primals_219 + buf219 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf219, (16,), is_leaf=True) # primals_220 + buf220 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf220, (16, 4096), is_leaf=True) # primals_221 + buf221 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + 
reader.tensor(buf221, (4, 4096), is_leaf=True) # primals_222 + buf222 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf222, (4, 4096), is_leaf=True) # primals_223 + buf223 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf223, (16, 4096), is_leaf=True) # primals_224 + buf224 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf224, (16,), is_leaf=True) # primals_225 + buf225 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf225, (56, 4096), is_leaf=True) # primals_226 + buf226 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf226, (56, 4096), is_leaf=True) # primals_227 + buf227 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf227, (16, 14336), is_leaf=True) # primals_228 + buf228 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf228, (16,), is_leaf=True) # primals_229 + buf229 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf229, (16, 4096), is_leaf=True) # primals_230 + buf230 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf230, (4, 4096), is_leaf=True) # primals_231 + buf231 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf231, (4, 4096), is_leaf=True) # primals_232 + buf232 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf232, (16, 4096), is_leaf=True) # primals_233 + buf233 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf233, (16,), is_leaf=True) # primals_234 + buf234 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf234, (56, 4096), is_leaf=True) # primals_235 + buf235 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf235, (56, 4096), is_leaf=True) # 
primals_236 + buf236 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf236, (16, 14336), is_leaf=True) # primals_237 + buf237 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf237, (16,), is_leaf=True) # primals_238 + buf238 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf238, (16, 4096), is_leaf=True) # primals_239 + buf239 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf239, (4, 4096), is_leaf=True) # primals_240 + buf240 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf240, (4, 4096), is_leaf=True) # primals_241 + buf241 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf241, (16, 4096), is_leaf=True) # primals_242 + buf242 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf242, (16,), is_leaf=True) # primals_243 + buf243 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf243, (56, 4096), is_leaf=True) # primals_244 + buf244 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf244, (56, 4096), is_leaf=True) # primals_245 + buf245 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf245, (16, 14336), is_leaf=True) # primals_246 + buf246 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf246, (16,), is_leaf=True) # primals_247 + buf247 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf247, (16, 4096), is_leaf=True) # primals_248 + buf248 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf248, (4, 4096), is_leaf=True) # primals_249 + buf249 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf249, (4, 4096), is_leaf=True) # primals_250 + buf250 = reader.storage(None, 262144, 
device=device(type='cuda', index=0)) + reader.tensor(buf250, (16, 4096), is_leaf=True) # primals_251 + buf251 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf251, (16,), is_leaf=True) # primals_252 + buf252 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf252, (56, 4096), is_leaf=True) # primals_253 + buf253 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf253, (56, 4096), is_leaf=True) # primals_254 + buf254 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf254, (16, 14336), is_leaf=True) # primals_255 + buf255 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf255, (16,), is_leaf=True) # primals_256 + buf256 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf256, (16, 4096), is_leaf=True) # primals_257 + buf257 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf257, (4, 4096), is_leaf=True) # primals_258 + buf258 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf258, (4, 4096), is_leaf=True) # primals_259 + buf259 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf259, (16, 4096), is_leaf=True) # primals_260 + buf260 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf260, (16,), is_leaf=True) # primals_261 + buf261 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf261, (56, 4096), is_leaf=True) # primals_262 + buf262 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf262, (56, 4096), is_leaf=True) # primals_263 + buf263 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf263, (16, 14336), is_leaf=True) # primals_264 + buf264 = reader.storage(None, 64, device=device(type='cuda', index=0)) + 
reader.tensor(buf264, (16,), is_leaf=True) # primals_265 + buf265 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf265, (16, 4096), is_leaf=True) # primals_266 + buf266 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf266, (4, 4096), is_leaf=True) # primals_267 + buf267 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf267, (4, 4096), is_leaf=True) # primals_268 + buf268 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf268, (16, 4096), is_leaf=True) # primals_269 + buf269 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf269, (16,), is_leaf=True) # primals_270 + buf270 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf270, (56, 4096), is_leaf=True) # primals_271 + buf271 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf271, (56, 4096), is_leaf=True) # primals_272 + buf272 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf272, (16, 14336), is_leaf=True) # primals_273 + buf273 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf273, (16,), is_leaf=True) # primals_274 + buf274 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf274, (16, 4096), is_leaf=True) # primals_275 + buf275 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf275, (4, 4096), is_leaf=True) # primals_276 + buf276 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf276, (4, 4096), is_leaf=True) # primals_277 + buf277 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf277, (16, 4096), is_leaf=True) # primals_278 + buf278 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf278, (16,), is_leaf=True) # primals_279 + 
buf279 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf279, (56, 4096), is_leaf=True) # primals_280 + buf280 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf280, (56, 4096), is_leaf=True) # primals_281 + buf281 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf281, (16, 14336), is_leaf=True) # primals_282 + buf282 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf282, (16,), is_leaf=True) # primals_283 + buf283 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf283, (16, 4096), is_leaf=True) # primals_284 + buf284 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf284, (4, 4096), is_leaf=True) # primals_285 + buf285 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf285, (4, 4096), is_leaf=True) # primals_286 + buf286 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf286, (16, 4096), is_leaf=True) # primals_287 + buf287 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf287, (16,), is_leaf=True) # primals_288 + buf288 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf288, (56, 4096), is_leaf=True) # primals_289 + buf289 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf289, (56, 4096), is_leaf=True) # primals_290 + buf290 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf290, (16, 14336), is_leaf=True) # primals_291 + buf291 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf291, (16,), is_leaf=True) # primals_292 + buf292 = reader.storage(None, 8208384, device=device(type='cuda', index=0)) + reader.tensor(buf292, (501, 4096), is_leaf=True) # primals_293 + +load_args._version = 0 + +def get_mesh_sizes(): + 
return 256, + +def get_colls_estimations_file(): + return "colls32_8.table" + +def get_pg_names(): + return "0", diff --git a/autoparallel/tools/overlap_simulator/repro_llama3_8b_fw_256_2d_32layers.py b/autoparallel/tools/overlap_simulator/repro_llama3_8b_fw_256_2d_32layers.py new file mode 100644 index 00000000..b581794d --- /dev/null +++ b/autoparallel/tools/overlap_simulator/repro_llama3_8b_fw_256_2d_32layers.py @@ -0,0 +1,5658 @@ +import torch +from torch.nn import * +from torch import tensor, device + + +class Repro(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + + def forward(self, primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8, primals_9, primals_10, primals_11, primals_12, primals_13, primals_14, primals_15, primals_16, primals_17, primals_18, primals_19, primals_20, primals_21, primals_22, primals_23, primals_24, primals_25, primals_26, primals_27, primals_28, primals_29, primals_30, primals_31, primals_32, primals_33, primals_34, primals_35, primals_36, primals_37, primals_38, primals_39, primals_40, primals_41, primals_42, primals_43, primals_44, primals_45, primals_46, primals_47, primals_48, primals_49, primals_50, primals_51, primals_52, primals_53, primals_54, primals_55, primals_56, primals_57, primals_58, primals_59, primals_60, primals_61, primals_62, primals_63, primals_64, primals_65, primals_66, primals_67, primals_68, primals_69, primals_70, primals_71, primals_72, primals_73, primals_74, primals_75, primals_76, primals_77, primals_78, primals_79, primals_80, primals_81, primals_82, primals_83, primals_84, primals_85, primals_86, primals_87, primals_88, primals_89, primals_90, primals_91, primals_92, primals_93, primals_94, primals_95, primals_96, primals_97, primals_98, primals_99, primals_100, primals_101, primals_102, primals_103, primals_104, primals_105, primals_106, primals_107, primals_108, primals_109, primals_110, primals_111, primals_112, primals_113, primals_114, 
primals_115, primals_116, primals_117, primals_118, primals_119, primals_120, primals_121, primals_122, primals_123, primals_124, primals_125, primals_126, primals_127, primals_128, primals_129, primals_130, primals_131, primals_132, primals_133, primals_134, primals_135, primals_136, primals_137, primals_138, primals_139, primals_140, primals_141, primals_142, primals_143, primals_144, primals_145, primals_146, primals_147, primals_148, primals_149, primals_150, primals_151, primals_152, primals_153, primals_154, primals_155, primals_156, primals_157, primals_158, primals_159, primals_160, primals_161, primals_162, primals_163, primals_164, primals_165, primals_166, primals_167, primals_168, primals_169, primals_170, primals_171, primals_172, primals_173, primals_174, primals_175, primals_176, primals_177, primals_178, primals_179, primals_180, primals_181, primals_182, primals_183, primals_184, primals_185, primals_186, primals_187, primals_188, primals_189, primals_190, primals_191, primals_192, primals_193, primals_194, primals_195, primals_196, primals_197, primals_198, primals_199, primals_200, primals_201, primals_202, primals_203, primals_204, primals_205, primals_206, primals_207, primals_208, primals_209, primals_210, primals_211, primals_212, primals_213, primals_214, primals_215, primals_216, primals_217, primals_218, primals_219, primals_220, primals_221, primals_222, primals_223, primals_224, primals_225, primals_226, primals_227, primals_228, primals_229, primals_230, primals_231, primals_232, primals_233, primals_234, primals_235, primals_236, primals_237, primals_238, primals_239, primals_240, primals_241, primals_242, primals_243, primals_244, primals_245, primals_246, primals_247, primals_248, primals_249, primals_250, primals_251, primals_252, primals_253, primals_254, primals_255, primals_256, primals_257, primals_258, primals_259, primals_260, primals_261, primals_262, primals_263, primals_264, primals_265, primals_266, primals_267, 
primals_268, primals_269, primals_270, primals_271, primals_272, primals_273, primals_274, primals_275, primals_276, primals_277, primals_278, primals_279, primals_280, primals_281, primals_282, primals_283, primals_284, primals_285, primals_286, primals_287, primals_288, primals_289, primals_290, primals_291, primals_292, primals_293): + convert_element_type = torch.ops.prims.convert_element_type.default(primals_2, torch.bfloat16) + all_gather_into_tensor = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type, 32, '0'); convert_element_type = None + wait_tensor = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor); all_gather_into_tensor = None + lt = torch.ops.aten.lt.Scalar(primals_1, 0) + ge = torch.ops.aten.ge.Scalar(primals_1, 16032) + bitwise_or = torch.ops.aten.bitwise_or.Tensor(lt, ge); lt = ge = None + sub = torch.ops.aten.sub.Tensor(primals_1, 0) + full_default = torch.ops.aten.full.default([], 0, dtype = torch.int64, layout = torch.strided, device = device(type='cpu'), pin_memory = False) + index_put = torch.ops.aten.index_put.default(sub, [bitwise_or], full_default); sub = full_default = None + embedding = torch.ops.aten.embedding.default(wait_tensor, index_put); wait_tensor = index_put = None + full_default_1 = torch.ops.aten.full.default([], 0.0, dtype = torch.bfloat16, layout = torch.strided, device = device(type='cpu'), pin_memory = False) + index_put_1 = torch.ops.aten.index_put.default(embedding, [bitwise_or], full_default_1); embedding = bitwise_or = full_default_1 = None + split_1 = torch.ops.aten.split.Tensor(index_put_1, 1024, 1); index_put_1 = None + getitem_8 = split_1[0] + getitem_17 = split_1[1] + getitem_26 = split_1[2] + getitem_35 = split_1[3] + getitem_44 = split_1[4] + getitem_53 = split_1[5] + getitem_62 = split_1[6] + getitem_71 = split_1[7]; split_1 = None + cat = torch.ops.aten.cat.default([getitem_8, getitem_17, getitem_26, getitem_35, getitem_44, getitem_53, getitem_62, getitem_71]); 
getitem_8 = getitem_17 = getitem_26 = getitem_35 = getitem_44 = getitem_53 = getitem_62 = getitem_71 = None + reduce_scatter_tensor = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat, 'sum', 8, '1'); cat = None + wait_tensor_1 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor); reduce_scatter_tensor = None + convert_element_type_1 = torch.ops.prims.convert_element_type.default(primals_4, torch.bfloat16) + all_gather_into_tensor_1 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1, 32, '0'); convert_element_type_1 = None + wait_tensor_2 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_1); all_gather_into_tensor_1 = None + convert_element_type_2 = torch.ops.prims.convert_element_type.default(wait_tensor_1, torch.float32) + pow_1 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_2, 2) + mean = torch.ops.aten.mean.dim(pow_1, [2], True); pow_1 = None + add = torch.ops.aten.add.Scalar(mean, 1e-05); mean = None + rsqrt = torch.ops.aten.rsqrt.default(add); add = None + mul = torch.ops.aten.mul.Tensor(convert_element_type_2, rsqrt); convert_element_type_2 = rsqrt = None + mul_1 = torch.ops.aten.mul.Tensor(mul, wait_tensor_2); mul = wait_tensor_2 = None + convert_element_type_3 = torch.ops.prims.convert_element_type.default(mul_1, torch.bfloat16); mul_1 = None + all_gather_into_tensor_2 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_3, 8, '1'); convert_element_type_3 = None + wait_tensor_3 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_2); all_gather_into_tensor_2 = None + split_9 = torch.ops.aten.split.Tensor(wait_tensor_3, 2); wait_tensor_3 = None + getitem_72 = split_9[0] + getitem_73 = split_9[1] + getitem_74 = split_9[2] + getitem_75 = split_9[3] + getitem_76 = split_9[4] + getitem_77 = split_9[5] + getitem_78 = split_9[6] + getitem_79 = split_9[7]; split_9 = None + cat_1 = torch.ops.aten.cat.default([getitem_72, 
getitem_73, getitem_74, getitem_75, getitem_76, getitem_77, getitem_78, getitem_79], 1); getitem_72 = getitem_73 = getitem_74 = getitem_75 = getitem_76 = getitem_77 = getitem_78 = getitem_79 = None + convert_element_type_4 = torch.ops.prims.convert_element_type.default(primals_5, torch.bfloat16) + all_gather_into_tensor_3 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_4, 32, '0'); convert_element_type_4 = None + wait_tensor_4 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_3); all_gather_into_tensor_3 = None + permute = torch.ops.aten.permute.default(wait_tensor_4, [1, 0]); wait_tensor_4 = None + view_15 = torch.ops.aten.view.default(cat_1, [16384, 4096]); cat_1 = None + mm = torch.ops.aten.mm.default(view_15, permute); permute = None + view_16 = torch.ops.aten.view.default(mm, [2, 8192, 512]) + convert_element_type_7 = torch.ops.prims.convert_element_type.default(primals_6, torch.bfloat16) + all_gather_into_tensor_4 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_7, 32, '0'); convert_element_type_7 = None + wait_tensor_5 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_4); all_gather_into_tensor_4 = None + permute_1 = torch.ops.aten.permute.default(wait_tensor_5, [1, 0]); wait_tensor_5 = None + mm_1 = torch.ops.aten.mm.default(view_15, permute_1); permute_1 = None + view_23 = torch.ops.aten.view.default(mm_1, [2, 8192, 128]); mm_1 = None + convert_element_type_10 = torch.ops.prims.convert_element_type.default(primals_7, torch.bfloat16) + all_gather_into_tensor_5 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_10, 32, '0'); convert_element_type_10 = None + wait_tensor_6 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_5); all_gather_into_tensor_5 = None + permute_2 = torch.ops.aten.permute.default(wait_tensor_6, [1, 0]); wait_tensor_6 = None + mm_2 = torch.ops.aten.mm.default(view_15, permute_2); 
view_15 = permute_2 = None + view_30 = torch.ops.aten.view.default(mm_2, [2, 8192, 128]) + view_32 = torch.ops.aten.view.default(view_16, [2, 8192, -1, 128]); view_16 = None + view_33 = torch.ops.aten.view.default(view_23, [2, 8192, -1, 128]); view_23 = None + view_34 = torch.ops.aten.view.default(view_30, [2, 8192, -1, 128]); view_30 = None + convert_element_type_13 = torch.ops.prims.convert_element_type.default(view_32, torch.float32); view_32 = None + view_35 = torch.ops.aten.view.default(convert_element_type_13, [2, 8192, 4, -1, 2]); convert_element_type_13 = None + view_as_complex = torch.ops.aten.view_as_complex.default(view_35); view_35 = None + convert_element_type_14 = torch.ops.prims.convert_element_type.default(view_33, torch.float32); view_33 = None + view_36 = torch.ops.aten.view.default(convert_element_type_14, [2, 8192, 1, -1, 2]); convert_element_type_14 = None + view_as_complex_1 = torch.ops.aten.view_as_complex.default(view_36); view_36 = None + view_37 = torch.ops.aten.view.default(primals_3, [1, 8192, 1, 64]) + mul_2 = torch.ops.aten.mul.Tensor(view_as_complex, view_37); view_as_complex = None + view_as_real = torch.ops.aten.view_as_real.default(mul_2); mul_2 = None + view_38 = torch.ops.aten.view.default(view_as_real, [2, 8192, 4, 128]); view_as_real = None + mul_3 = torch.ops.aten.mul.Tensor(view_as_complex_1, view_37); view_as_complex_1 = None + view_as_real_1 = torch.ops.aten.view_as_real.default(mul_3); mul_3 = None + view_39 = torch.ops.aten.view.default(view_as_real_1, [2, 8192, 1, 128]); view_as_real_1 = None + convert_element_type_15 = torch.ops.prims.convert_element_type.default(view_38, torch.bfloat16); view_38 = None + convert_element_type_16 = torch.ops.prims.convert_element_type.default(view_39, torch.bfloat16); view_39 = None + unsqueeze = torch.ops.aten.unsqueeze.default(convert_element_type_16, 3); convert_element_type_16 = None + expand = torch.ops.aten.expand.default(unsqueeze, [2, 8192, 1, 4, 128]); unsqueeze = None + view_40 
= torch.ops.aten.view.default(expand, [2, 8192, 4, 128]); expand = None + unsqueeze_1 = torch.ops.aten.unsqueeze.default(view_34, 3); view_34 = None + expand_1 = torch.ops.aten.expand.default(unsqueeze_1, [2, 8192, 1, 4, 128]); unsqueeze_1 = None + view_41 = torch.ops.aten.view.default(expand_1, [2, 8192, 4, 128]); expand_1 = None + permute_3 = torch.ops.aten.permute.default(convert_element_type_15, [0, 2, 1, 3]); convert_element_type_15 = None + permute_4 = torch.ops.aten.permute.default(view_40, [0, 2, 1, 3]); view_40 = None + permute_5 = torch.ops.aten.permute.default(view_41, [0, 2, 1, 3]); view_41 = None + _scaled_dot_product_cudnn_attention = torch.ops.aten._scaled_dot_product_cudnn_attention.default(permute_3, permute_4, permute_5, None, True, 0.0, True); permute_3 = permute_4 = permute_5 = None + getitem_80 = _scaled_dot_product_cudnn_attention[0] + getitem_81 = _scaled_dot_product_cudnn_attention[1] + getitem_86 = _scaled_dot_product_cudnn_attention[6] + getitem_87 = _scaled_dot_product_cudnn_attention[7]; _scaled_dot_product_cudnn_attention = None + permute_6 = torch.ops.aten.permute.default(getitem_80, [0, 2, 1, 3]) + view_42 = torch.ops.aten.view.default(permute_6, [2, 8192, -1]); permute_6 = None + convert_element_type_17 = torch.ops.prims.convert_element_type.default(primals_8, torch.bfloat16) + all_gather_into_tensor_6 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_17, 32, '0'); convert_element_type_17 = None + wait_tensor_7 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_6); all_gather_into_tensor_6 = None + permute_7 = torch.ops.aten.permute.default(wait_tensor_7, [1, 0]); wait_tensor_7 = None + view_48 = torch.ops.aten.view.default(view_42, [16384, 512]); view_42 = None + mm_3 = torch.ops.aten.mm.default(view_48, permute_7); view_48 = permute_7 = None + view_49 = torch.ops.aten.view.default(mm_3, [2, 8192, 4096]); mm_3 = None + split_10 = torch.ops.aten.split.Tensor(view_49, 1024, 1); 
view_49 = None + getitem_89 = split_10[0] + getitem_90 = split_10[1] + getitem_91 = split_10[2] + getitem_92 = split_10[3] + getitem_93 = split_10[4] + getitem_94 = split_10[5] + getitem_95 = split_10[6] + getitem_96 = split_10[7]; split_10 = None + cat_2 = torch.ops.aten.cat.default([getitem_89, getitem_90, getitem_91, getitem_92, getitem_93, getitem_94, getitem_95, getitem_96]); getitem_89 = getitem_90 = getitem_91 = getitem_92 = getitem_93 = getitem_94 = getitem_95 = getitem_96 = None + reduce_scatter_tensor_1 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_2, 'sum', 8, '1'); cat_2 = None + wait_tensor_8 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_1) + add_1 = torch.ops.aten.add.Tensor(wait_tensor_1, wait_tensor_8); wait_tensor_8 = None + convert_element_type_20 = torch.ops.prims.convert_element_type.default(primals_9, torch.bfloat16) + all_gather_into_tensor_7 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_20, 32, '0'); convert_element_type_20 = None + wait_tensor_9 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_7); all_gather_into_tensor_7 = None + convert_element_type_21 = torch.ops.prims.convert_element_type.default(add_1, torch.float32) + pow_2 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_21, 2) + mean_1 = torch.ops.aten.mean.dim(pow_2, [2], True); pow_2 = None + add_2 = torch.ops.aten.add.Scalar(mean_1, 1e-05); mean_1 = None + rsqrt_1 = torch.ops.aten.rsqrt.default(add_2); add_2 = None + mul_4 = torch.ops.aten.mul.Tensor(convert_element_type_21, rsqrt_1); convert_element_type_21 = rsqrt_1 = None + mul_5 = torch.ops.aten.mul.Tensor(mul_4, wait_tensor_9); mul_4 = wait_tensor_9 = None + convert_element_type_22 = torch.ops.prims.convert_element_type.default(mul_5, torch.bfloat16); mul_5 = None + all_gather_into_tensor_8 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_22, 8, '1'); convert_element_type_22 = None + 
wait_tensor_10 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_8); all_gather_into_tensor_8 = None + split_11 = torch.ops.aten.split.Tensor(wait_tensor_10, 2); wait_tensor_10 = None + getitem_97 = split_11[0] + getitem_98 = split_11[1] + getitem_99 = split_11[2] + getitem_100 = split_11[3] + getitem_101 = split_11[4] + getitem_102 = split_11[5] + getitem_103 = split_11[6] + getitem_104 = split_11[7]; split_11 = None + cat_3 = torch.ops.aten.cat.default([getitem_97, getitem_98, getitem_99, getitem_100, getitem_101, getitem_102, getitem_103, getitem_104], 1); getitem_97 = getitem_98 = getitem_99 = getitem_100 = getitem_101 = getitem_102 = getitem_103 = getitem_104 = None + convert_element_type_23 = torch.ops.prims.convert_element_type.default(primals_10, torch.bfloat16) + all_gather_into_tensor_9 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_23, 32, '0'); convert_element_type_23 = None + wait_tensor_11 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_9); all_gather_into_tensor_9 = None + permute_8 = torch.ops.aten.permute.default(wait_tensor_11, [1, 0]); wait_tensor_11 = None + view_60 = torch.ops.aten.view.default(cat_3, [16384, 4096]); cat_3 = None + mm_4 = torch.ops.aten.mm.default(view_60, permute_8); permute_8 = None + view_61 = torch.ops.aten.view.default(mm_4, [2, 8192, 1792]) + convert_element_type_26 = torch.ops.prims.convert_element_type.default(view_61, torch.float32); view_61 = None + sigmoid = torch.ops.aten.sigmoid.default(convert_element_type_26) + mul_6 = torch.ops.aten.mul.Tensor(convert_element_type_26, sigmoid); convert_element_type_26 = sigmoid = None + convert_element_type_27 = torch.ops.prims.convert_element_type.default(mul_6, torch.bfloat16); mul_6 = None + convert_element_type_28 = torch.ops.prims.convert_element_type.default(primals_11, torch.bfloat16) + all_gather_into_tensor_10 = 
torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_28, 32, '0'); convert_element_type_28 = None + wait_tensor_12 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_10); all_gather_into_tensor_10 = None + permute_9 = torch.ops.aten.permute.default(wait_tensor_12, [1, 0]); wait_tensor_12 = None + mm_5 = torch.ops.aten.mm.default(view_60, permute_9); view_60 = permute_9 = None + view_68 = torch.ops.aten.view.default(mm_5, [2, 8192, 1792]); mm_5 = None + mul_7 = torch.ops.aten.mul.Tensor(convert_element_type_27, view_68); convert_element_type_27 = view_68 = None + convert_element_type_31 = torch.ops.prims.convert_element_type.default(primals_12, torch.bfloat16) + all_gather_into_tensor_11 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_31, 32, '0'); convert_element_type_31 = None + wait_tensor_13 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_11); all_gather_into_tensor_11 = None + permute_10 = torch.ops.aten.permute.default(wait_tensor_13, [1, 0]); wait_tensor_13 = None + view_75 = torch.ops.aten.view.default(mul_7, [16384, 1792]); mul_7 = None + mm_6 = torch.ops.aten.mm.default(view_75, permute_10); view_75 = permute_10 = None + view_76 = torch.ops.aten.view.default(mm_6, [2, 8192, 4096]); mm_6 = None + split_12 = torch.ops.aten.split.Tensor(view_76, 1024, 1); view_76 = None + getitem_105 = split_12[0] + getitem_106 = split_12[1] + getitem_107 = split_12[2] + getitem_108 = split_12[3] + getitem_109 = split_12[4] + getitem_110 = split_12[5] + getitem_111 = split_12[6] + getitem_112 = split_12[7]; split_12 = None + cat_4 = torch.ops.aten.cat.default([getitem_105, getitem_106, getitem_107, getitem_108, getitem_109, getitem_110, getitem_111, getitem_112]); getitem_105 = getitem_106 = getitem_107 = getitem_108 = getitem_109 = getitem_110 = getitem_111 = getitem_112 = None + reduce_scatter_tensor_2 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_4, 
'sum', 8, '1'); cat_4 = None + wait_tensor_14 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_2); reduce_scatter_tensor_2 = None + add_3 = torch.ops.aten.add.Tensor(add_1, wait_tensor_14); add_1 = wait_tensor_14 = None + convert_element_type_34 = torch.ops.prims.convert_element_type.default(primals_13, torch.bfloat16) + all_gather_into_tensor_12 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_34, 32, '0'); convert_element_type_34 = None + wait_tensor_15 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_12); all_gather_into_tensor_12 = None + convert_element_type_35 = torch.ops.prims.convert_element_type.default(add_3, torch.float32) + pow_3 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_35, 2) + mean_2 = torch.ops.aten.mean.dim(pow_3, [2], True); pow_3 = None + add_4 = torch.ops.aten.add.Scalar(mean_2, 1e-05); mean_2 = None + rsqrt_2 = torch.ops.aten.rsqrt.default(add_4); add_4 = None + mul_8 = torch.ops.aten.mul.Tensor(convert_element_type_35, rsqrt_2); convert_element_type_35 = rsqrt_2 = None + mul_9 = torch.ops.aten.mul.Tensor(mul_8, wait_tensor_15); mul_8 = wait_tensor_15 = None + convert_element_type_36 = torch.ops.prims.convert_element_type.default(mul_9, torch.bfloat16); mul_9 = None + all_gather_into_tensor_13 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_36, 8, '1'); convert_element_type_36 = None + wait_tensor_16 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_13); all_gather_into_tensor_13 = None + split_13 = torch.ops.aten.split.Tensor(wait_tensor_16, 2); wait_tensor_16 = None + getitem_113 = split_13[0] + getitem_114 = split_13[1] + getitem_115 = split_13[2] + getitem_116 = split_13[3] + getitem_117 = split_13[4] + getitem_118 = split_13[5] + getitem_119 = split_13[6] + getitem_120 = split_13[7]; split_13 = None + cat_5 = torch.ops.aten.cat.default([getitem_113, getitem_114, getitem_115, getitem_116, 
getitem_117, getitem_118, getitem_119, getitem_120], 1); getitem_113 = getitem_114 = getitem_115 = getitem_116 = getitem_117 = getitem_118 = getitem_119 = getitem_120 = None + convert_element_type_37 = torch.ops.prims.convert_element_type.default(primals_14, torch.bfloat16) + all_gather_into_tensor_14 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_37, 32, '0'); convert_element_type_37 = None + wait_tensor_17 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_14); all_gather_into_tensor_14 = None + permute_11 = torch.ops.aten.permute.default(wait_tensor_17, [1, 0]); wait_tensor_17 = None + view_87 = torch.ops.aten.view.default(cat_5, [16384, 4096]); cat_5 = None + mm_7 = torch.ops.aten.mm.default(view_87, permute_11); permute_11 = None + view_88 = torch.ops.aten.view.default(mm_7, [2, 8192, 512]) + convert_element_type_40 = torch.ops.prims.convert_element_type.default(primals_15, torch.bfloat16) + all_gather_into_tensor_15 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_40, 32, '0'); convert_element_type_40 = None + wait_tensor_18 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_15); all_gather_into_tensor_15 = None + permute_12 = torch.ops.aten.permute.default(wait_tensor_18, [1, 0]); wait_tensor_18 = None + mm_8 = torch.ops.aten.mm.default(view_87, permute_12); permute_12 = None + view_95 = torch.ops.aten.view.default(mm_8, [2, 8192, 128]); mm_8 = None + convert_element_type_43 = torch.ops.prims.convert_element_type.default(primals_16, torch.bfloat16) + all_gather_into_tensor_16 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_43, 32, '0'); convert_element_type_43 = None + wait_tensor_19 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_16); all_gather_into_tensor_16 = None + permute_13 = torch.ops.aten.permute.default(wait_tensor_19, [1, 0]); wait_tensor_19 = None + mm_9 = 
torch.ops.aten.mm.default(view_87, permute_13); view_87 = permute_13 = None + view_102 = torch.ops.aten.view.default(mm_9, [2, 8192, 128]) + view_104 = torch.ops.aten.view.default(view_88, [2, 8192, -1, 128]); view_88 = None + view_105 = torch.ops.aten.view.default(view_95, [2, 8192, -1, 128]); view_95 = None + view_106 = torch.ops.aten.view.default(view_102, [2, 8192, -1, 128]); view_102 = None + convert_element_type_46 = torch.ops.prims.convert_element_type.default(view_104, torch.float32); view_104 = None + view_107 = torch.ops.aten.view.default(convert_element_type_46, [2, 8192, 4, -1, 2]); convert_element_type_46 = None + view_as_complex_2 = torch.ops.aten.view_as_complex.default(view_107); view_107 = None + convert_element_type_47 = torch.ops.prims.convert_element_type.default(view_105, torch.float32); view_105 = None + view_108 = torch.ops.aten.view.default(convert_element_type_47, [2, 8192, 1, -1, 2]); convert_element_type_47 = None + view_as_complex_3 = torch.ops.aten.view_as_complex.default(view_108); view_108 = None + mul_10 = torch.ops.aten.mul.Tensor(view_as_complex_2, view_37); view_as_complex_2 = None + view_as_real_2 = torch.ops.aten.view_as_real.default(mul_10); mul_10 = None + view_110 = torch.ops.aten.view.default(view_as_real_2, [2, 8192, 4, 128]); view_as_real_2 = None + mul_11 = torch.ops.aten.mul.Tensor(view_as_complex_3, view_37); view_as_complex_3 = None + view_as_real_3 = torch.ops.aten.view_as_real.default(mul_11); mul_11 = None + view_111 = torch.ops.aten.view.default(view_as_real_3, [2, 8192, 1, 128]); view_as_real_3 = None + convert_element_type_48 = torch.ops.prims.convert_element_type.default(view_110, torch.bfloat16); view_110 = None + convert_element_type_49 = torch.ops.prims.convert_element_type.default(view_111, torch.bfloat16); view_111 = None + unsqueeze_2 = torch.ops.aten.unsqueeze.default(convert_element_type_49, 3); convert_element_type_49 = None + expand_2 = torch.ops.aten.expand.default(unsqueeze_2, [2, 8192, 1, 4, 128]); 
unsqueeze_2 = None + view_112 = torch.ops.aten.view.default(expand_2, [2, 8192, 4, 128]); expand_2 = None + unsqueeze_3 = torch.ops.aten.unsqueeze.default(view_106, 3); view_106 = None + expand_3 = torch.ops.aten.expand.default(unsqueeze_3, [2, 8192, 1, 4, 128]); unsqueeze_3 = None + view_113 = torch.ops.aten.view.default(expand_3, [2, 8192, 4, 128]); expand_3 = None + permute_14 = torch.ops.aten.permute.default(convert_element_type_48, [0, 2, 1, 3]); convert_element_type_48 = None + permute_15 = torch.ops.aten.permute.default(view_112, [0, 2, 1, 3]); view_112 = None + permute_16 = torch.ops.aten.permute.default(view_113, [0, 2, 1, 3]); view_113 = None + _scaled_dot_product_cudnn_attention_1 = torch.ops.aten._scaled_dot_product_cudnn_attention.default(permute_14, permute_15, permute_16, None, True, 0.0, True); permute_14 = permute_15 = permute_16 = None + getitem_121 = _scaled_dot_product_cudnn_attention_1[0] + getitem_122 = _scaled_dot_product_cudnn_attention_1[1] + getitem_127 = _scaled_dot_product_cudnn_attention_1[6] + getitem_128 = _scaled_dot_product_cudnn_attention_1[7]; _scaled_dot_product_cudnn_attention_1 = None + permute_17 = torch.ops.aten.permute.default(getitem_121, [0, 2, 1, 3]) + view_114 = torch.ops.aten.view.default(permute_17, [2, 8192, -1]); permute_17 = None + convert_element_type_50 = torch.ops.prims.convert_element_type.default(primals_17, torch.bfloat16) + all_gather_into_tensor_17 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_50, 32, '0'); convert_element_type_50 = None + wait_tensor_20 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_17); all_gather_into_tensor_17 = None + permute_18 = torch.ops.aten.permute.default(wait_tensor_20, [1, 0]); wait_tensor_20 = None + view_120 = torch.ops.aten.view.default(view_114, [16384, 512]); view_114 = None + mm_10 = torch.ops.aten.mm.default(view_120, permute_18); view_120 = permute_18 = None + view_121 = torch.ops.aten.view.default(mm_10, [2, 
8192, 4096]); mm_10 = None + split_14 = torch.ops.aten.split.Tensor(view_121, 1024, 1); view_121 = None + getitem_130 = split_14[0] + getitem_131 = split_14[1] + getitem_132 = split_14[2] + getitem_133 = split_14[3] + getitem_134 = split_14[4] + getitem_135 = split_14[5] + getitem_136 = split_14[6] + getitem_137 = split_14[7]; split_14 = None + cat_6 = torch.ops.aten.cat.default([getitem_130, getitem_131, getitem_132, getitem_133, getitem_134, getitem_135, getitem_136, getitem_137]); getitem_130 = getitem_131 = getitem_132 = getitem_133 = getitem_134 = getitem_135 = getitem_136 = getitem_137 = None + reduce_scatter_tensor_3 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_6, 'sum', 8, '1'); cat_6 = None + wait_tensor_21 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_3) + add_5 = torch.ops.aten.add.Tensor(add_3, wait_tensor_21); wait_tensor_21 = None + convert_element_type_53 = torch.ops.prims.convert_element_type.default(primals_18, torch.bfloat16) + all_gather_into_tensor_18 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_53, 32, '0'); convert_element_type_53 = None + wait_tensor_22 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_18); all_gather_into_tensor_18 = None + convert_element_type_54 = torch.ops.prims.convert_element_type.default(add_5, torch.float32) + pow_4 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_54, 2) + mean_3 = torch.ops.aten.mean.dim(pow_4, [2], True); pow_4 = None + add_6 = torch.ops.aten.add.Scalar(mean_3, 1e-05); mean_3 = None + rsqrt_3 = torch.ops.aten.rsqrt.default(add_6); add_6 = None + mul_12 = torch.ops.aten.mul.Tensor(convert_element_type_54, rsqrt_3); convert_element_type_54 = rsqrt_3 = None + mul_13 = torch.ops.aten.mul.Tensor(mul_12, wait_tensor_22); mul_12 = wait_tensor_22 = None + convert_element_type_55 = torch.ops.prims.convert_element_type.default(mul_13, torch.bfloat16); mul_13 = None + all_gather_into_tensor_19 = 
torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_55, 8, '1'); convert_element_type_55 = None + wait_tensor_23 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_19); all_gather_into_tensor_19 = None + split_15 = torch.ops.aten.split.Tensor(wait_tensor_23, 2); wait_tensor_23 = None + getitem_138 = split_15[0] + getitem_139 = split_15[1] + getitem_140 = split_15[2] + getitem_141 = split_15[3] + getitem_142 = split_15[4] + getitem_143 = split_15[5] + getitem_144 = split_15[6] + getitem_145 = split_15[7]; split_15 = None + cat_7 = torch.ops.aten.cat.default([getitem_138, getitem_139, getitem_140, getitem_141, getitem_142, getitem_143, getitem_144, getitem_145], 1); getitem_138 = getitem_139 = getitem_140 = getitem_141 = getitem_142 = getitem_143 = getitem_144 = getitem_145 = None + convert_element_type_56 = torch.ops.prims.convert_element_type.default(primals_19, torch.bfloat16) + all_gather_into_tensor_20 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_56, 32, '0'); convert_element_type_56 = None + wait_tensor_24 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_20); all_gather_into_tensor_20 = None + permute_19 = torch.ops.aten.permute.default(wait_tensor_24, [1, 0]); wait_tensor_24 = None + view_132 = torch.ops.aten.view.default(cat_7, [16384, 4096]); cat_7 = None + mm_11 = torch.ops.aten.mm.default(view_132, permute_19); permute_19 = None + view_133 = torch.ops.aten.view.default(mm_11, [2, 8192, 1792]) + convert_element_type_59 = torch.ops.prims.convert_element_type.default(view_133, torch.float32); view_133 = None + sigmoid_1 = torch.ops.aten.sigmoid.default(convert_element_type_59) + mul_14 = torch.ops.aten.mul.Tensor(convert_element_type_59, sigmoid_1); convert_element_type_59 = sigmoid_1 = None + convert_element_type_60 = torch.ops.prims.convert_element_type.default(mul_14, torch.bfloat16); mul_14 = None + convert_element_type_61 = 
torch.ops.prims.convert_element_type.default(primals_20, torch.bfloat16) + all_gather_into_tensor_21 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_61, 32, '0'); convert_element_type_61 = None + wait_tensor_25 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_21); all_gather_into_tensor_21 = None + permute_20 = torch.ops.aten.permute.default(wait_tensor_25, [1, 0]); wait_tensor_25 = None + mm_12 = torch.ops.aten.mm.default(view_132, permute_20); view_132 = permute_20 = None + view_140 = torch.ops.aten.view.default(mm_12, [2, 8192, 1792]); mm_12 = None + mul_15 = torch.ops.aten.mul.Tensor(convert_element_type_60, view_140); convert_element_type_60 = view_140 = None + convert_element_type_64 = torch.ops.prims.convert_element_type.default(primals_21, torch.bfloat16) + all_gather_into_tensor_22 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_64, 32, '0'); convert_element_type_64 = None + wait_tensor_26 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_22); all_gather_into_tensor_22 = None + permute_21 = torch.ops.aten.permute.default(wait_tensor_26, [1, 0]); wait_tensor_26 = None + view_147 = torch.ops.aten.view.default(mul_15, [16384, 1792]); mul_15 = None + mm_13 = torch.ops.aten.mm.default(view_147, permute_21); view_147 = permute_21 = None + view_148 = torch.ops.aten.view.default(mm_13, [2, 8192, 4096]); mm_13 = None + split_16 = torch.ops.aten.split.Tensor(view_148, 1024, 1); view_148 = None + getitem_146 = split_16[0] + getitem_147 = split_16[1] + getitem_148 = split_16[2] + getitem_149 = split_16[3] + getitem_150 = split_16[4] + getitem_151 = split_16[5] + getitem_152 = split_16[6] + getitem_153 = split_16[7]; split_16 = None + cat_8 = torch.ops.aten.cat.default([getitem_146, getitem_147, getitem_148, getitem_149, getitem_150, getitem_151, getitem_152, getitem_153]); getitem_146 = getitem_147 = getitem_148 = getitem_149 = getitem_150 = getitem_151 = 
getitem_152 = getitem_153 = None + reduce_scatter_tensor_4 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_8, 'sum', 8, '1'); cat_8 = None + wait_tensor_27 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_4); reduce_scatter_tensor_4 = None + add_7 = torch.ops.aten.add.Tensor(add_5, wait_tensor_27); add_5 = wait_tensor_27 = None + convert_element_type_67 = torch.ops.prims.convert_element_type.default(primals_22, torch.bfloat16) + all_gather_into_tensor_23 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_67, 32, '0'); convert_element_type_67 = None + wait_tensor_28 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_23); all_gather_into_tensor_23 = None + convert_element_type_68 = torch.ops.prims.convert_element_type.default(add_7, torch.float32) + pow_5 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_68, 2) + mean_4 = torch.ops.aten.mean.dim(pow_5, [2], True); pow_5 = None + add_8 = torch.ops.aten.add.Scalar(mean_4, 1e-05); mean_4 = None + rsqrt_4 = torch.ops.aten.rsqrt.default(add_8); add_8 = None + mul_16 = torch.ops.aten.mul.Tensor(convert_element_type_68, rsqrt_4); convert_element_type_68 = rsqrt_4 = None + mul_17 = torch.ops.aten.mul.Tensor(mul_16, wait_tensor_28); mul_16 = wait_tensor_28 = None + convert_element_type_69 = torch.ops.prims.convert_element_type.default(mul_17, torch.bfloat16); mul_17 = None + all_gather_into_tensor_24 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_69, 8, '1'); convert_element_type_69 = None + wait_tensor_29 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_24); all_gather_into_tensor_24 = None + split_17 = torch.ops.aten.split.Tensor(wait_tensor_29, 2); wait_tensor_29 = None + getitem_154 = split_17[0] + getitem_155 = split_17[1] + getitem_156 = split_17[2] + getitem_157 = split_17[3] + getitem_158 = split_17[4] + getitem_159 = split_17[5] + getitem_160 = split_17[6] + 
getitem_161 = split_17[7]; split_17 = None + cat_9 = torch.ops.aten.cat.default([getitem_154, getitem_155, getitem_156, getitem_157, getitem_158, getitem_159, getitem_160, getitem_161], 1); getitem_154 = getitem_155 = getitem_156 = getitem_157 = getitem_158 = getitem_159 = getitem_160 = getitem_161 = None + convert_element_type_70 = torch.ops.prims.convert_element_type.default(primals_23, torch.bfloat16) + all_gather_into_tensor_25 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_70, 32, '0'); convert_element_type_70 = None + wait_tensor_30 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_25); all_gather_into_tensor_25 = None + permute_22 = torch.ops.aten.permute.default(wait_tensor_30, [1, 0]); wait_tensor_30 = None + view_159 = torch.ops.aten.view.default(cat_9, [16384, 4096]); cat_9 = None + mm_14 = torch.ops.aten.mm.default(view_159, permute_22); permute_22 = None + view_160 = torch.ops.aten.view.default(mm_14, [2, 8192, 512]) + convert_element_type_73 = torch.ops.prims.convert_element_type.default(primals_24, torch.bfloat16) + all_gather_into_tensor_26 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_73, 32, '0'); convert_element_type_73 = None + wait_tensor_31 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_26); all_gather_into_tensor_26 = None + permute_23 = torch.ops.aten.permute.default(wait_tensor_31, [1, 0]); wait_tensor_31 = None + mm_15 = torch.ops.aten.mm.default(view_159, permute_23); permute_23 = None + view_167 = torch.ops.aten.view.default(mm_15, [2, 8192, 128]); mm_15 = None + convert_element_type_76 = torch.ops.prims.convert_element_type.default(primals_25, torch.bfloat16) + all_gather_into_tensor_27 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_76, 32, '0'); convert_element_type_76 = None + wait_tensor_32 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_27); 
all_gather_into_tensor_27 = None + permute_24 = torch.ops.aten.permute.default(wait_tensor_32, [1, 0]); wait_tensor_32 = None + mm_16 = torch.ops.aten.mm.default(view_159, permute_24); view_159 = permute_24 = None + view_174 = torch.ops.aten.view.default(mm_16, [2, 8192, 128]) + view_176 = torch.ops.aten.view.default(view_160, [2, 8192, -1, 128]); view_160 = None + view_177 = torch.ops.aten.view.default(view_167, [2, 8192, -1, 128]); view_167 = None + view_178 = torch.ops.aten.view.default(view_174, [2, 8192, -1, 128]); view_174 = None + convert_element_type_79 = torch.ops.prims.convert_element_type.default(view_176, torch.float32); view_176 = None + view_179 = torch.ops.aten.view.default(convert_element_type_79, [2, 8192, 4, -1, 2]); convert_element_type_79 = None + view_as_complex_4 = torch.ops.aten.view_as_complex.default(view_179); view_179 = None + convert_element_type_80 = torch.ops.prims.convert_element_type.default(view_177, torch.float32); view_177 = None + view_180 = torch.ops.aten.view.default(convert_element_type_80, [2, 8192, 1, -1, 2]); convert_element_type_80 = None + view_as_complex_5 = torch.ops.aten.view_as_complex.default(view_180); view_180 = None + mul_18 = torch.ops.aten.mul.Tensor(view_as_complex_4, view_37); view_as_complex_4 = None + view_as_real_4 = torch.ops.aten.view_as_real.default(mul_18); mul_18 = None + view_182 = torch.ops.aten.view.default(view_as_real_4, [2, 8192, 4, 128]); view_as_real_4 = None + mul_19 = torch.ops.aten.mul.Tensor(view_as_complex_5, view_37); view_as_complex_5 = None + view_as_real_5 = torch.ops.aten.view_as_real.default(mul_19); mul_19 = None + view_183 = torch.ops.aten.view.default(view_as_real_5, [2, 8192, 1, 128]); view_as_real_5 = None + convert_element_type_81 = torch.ops.prims.convert_element_type.default(view_182, torch.bfloat16); view_182 = None + convert_element_type_82 = torch.ops.prims.convert_element_type.default(view_183, torch.bfloat16); view_183 = None + unsqueeze_4 = 
torch.ops.aten.unsqueeze.default(convert_element_type_82, 3); convert_element_type_82 = None + expand_4 = torch.ops.aten.expand.default(unsqueeze_4, [2, 8192, 1, 4, 128]); unsqueeze_4 = None + view_184 = torch.ops.aten.view.default(expand_4, [2, 8192, 4, 128]); expand_4 = None + unsqueeze_5 = torch.ops.aten.unsqueeze.default(view_178, 3); view_178 = None + expand_5 = torch.ops.aten.expand.default(unsqueeze_5, [2, 8192, 1, 4, 128]); unsqueeze_5 = None + view_185 = torch.ops.aten.view.default(expand_5, [2, 8192, 4, 128]); expand_5 = None + permute_25 = torch.ops.aten.permute.default(convert_element_type_81, [0, 2, 1, 3]); convert_element_type_81 = None + permute_26 = torch.ops.aten.permute.default(view_184, [0, 2, 1, 3]); view_184 = None + permute_27 = torch.ops.aten.permute.default(view_185, [0, 2, 1, 3]); view_185 = None + _scaled_dot_product_cudnn_attention_2 = torch.ops.aten._scaled_dot_product_cudnn_attention.default(permute_25, permute_26, permute_27, None, True, 0.0, True); permute_25 = permute_26 = permute_27 = None + getitem_162 = _scaled_dot_product_cudnn_attention_2[0] + getitem_163 = _scaled_dot_product_cudnn_attention_2[1] + getitem_168 = _scaled_dot_product_cudnn_attention_2[6] + getitem_169 = _scaled_dot_product_cudnn_attention_2[7]; _scaled_dot_product_cudnn_attention_2 = None + permute_28 = torch.ops.aten.permute.default(getitem_162, [0, 2, 1, 3]) + view_186 = torch.ops.aten.view.default(permute_28, [2, 8192, -1]); permute_28 = None + convert_element_type_83 = torch.ops.prims.convert_element_type.default(primals_26, torch.bfloat16) + all_gather_into_tensor_28 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_83, 32, '0'); convert_element_type_83 = None + wait_tensor_33 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_28); all_gather_into_tensor_28 = None + permute_29 = torch.ops.aten.permute.default(wait_tensor_33, [1, 0]); wait_tensor_33 = None + view_192 = torch.ops.aten.view.default(view_186, 
[16384, 512]); view_186 = None + mm_17 = torch.ops.aten.mm.default(view_192, permute_29); view_192 = permute_29 = None + view_193 = torch.ops.aten.view.default(mm_17, [2, 8192, 4096]); mm_17 = None + split_18 = torch.ops.aten.split.Tensor(view_193, 1024, 1); view_193 = None + getitem_171 = split_18[0] + getitem_172 = split_18[1] + getitem_173 = split_18[2] + getitem_174 = split_18[3] + getitem_175 = split_18[4] + getitem_176 = split_18[5] + getitem_177 = split_18[6] + getitem_178 = split_18[7]; split_18 = None + cat_10 = torch.ops.aten.cat.default([getitem_171, getitem_172, getitem_173, getitem_174, getitem_175, getitem_176, getitem_177, getitem_178]); getitem_171 = getitem_172 = getitem_173 = getitem_174 = getitem_175 = getitem_176 = getitem_177 = getitem_178 = None + reduce_scatter_tensor_5 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_10, 'sum', 8, '1'); cat_10 = None + wait_tensor_34 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_5) + add_9 = torch.ops.aten.add.Tensor(add_7, wait_tensor_34); wait_tensor_34 = None + convert_element_type_86 = torch.ops.prims.convert_element_type.default(primals_27, torch.bfloat16) + all_gather_into_tensor_29 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_86, 32, '0'); convert_element_type_86 = None + wait_tensor_35 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_29); all_gather_into_tensor_29 = None + convert_element_type_87 = torch.ops.prims.convert_element_type.default(add_9, torch.float32) + pow_6 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_87, 2) + mean_5 = torch.ops.aten.mean.dim(pow_6, [2], True); pow_6 = None + add_10 = torch.ops.aten.add.Scalar(mean_5, 1e-05); mean_5 = None + rsqrt_5 = torch.ops.aten.rsqrt.default(add_10); add_10 = None + mul_20 = torch.ops.aten.mul.Tensor(convert_element_type_87, rsqrt_5); convert_element_type_87 = rsqrt_5 = None + mul_21 = torch.ops.aten.mul.Tensor(mul_20, wait_tensor_35); 
mul_20 = wait_tensor_35 = None + convert_element_type_88 = torch.ops.prims.convert_element_type.default(mul_21, torch.bfloat16); mul_21 = None + all_gather_into_tensor_30 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_88, 8, '1'); convert_element_type_88 = None + wait_tensor_36 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_30); all_gather_into_tensor_30 = None + split_19 = torch.ops.aten.split.Tensor(wait_tensor_36, 2); wait_tensor_36 = None + getitem_179 = split_19[0] + getitem_180 = split_19[1] + getitem_181 = split_19[2] + getitem_182 = split_19[3] + getitem_183 = split_19[4] + getitem_184 = split_19[5] + getitem_185 = split_19[6] + getitem_186 = split_19[7]; split_19 = None + cat_11 = torch.ops.aten.cat.default([getitem_179, getitem_180, getitem_181, getitem_182, getitem_183, getitem_184, getitem_185, getitem_186], 1); getitem_179 = getitem_180 = getitem_181 = getitem_182 = getitem_183 = getitem_184 = getitem_185 = getitem_186 = None + convert_element_type_89 = torch.ops.prims.convert_element_type.default(primals_28, torch.bfloat16) + all_gather_into_tensor_31 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_89, 32, '0'); convert_element_type_89 = None + wait_tensor_37 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_31); all_gather_into_tensor_31 = None + permute_30 = torch.ops.aten.permute.default(wait_tensor_37, [1, 0]); wait_tensor_37 = None + view_204 = torch.ops.aten.view.default(cat_11, [16384, 4096]); cat_11 = None + mm_18 = torch.ops.aten.mm.default(view_204, permute_30); permute_30 = None + view_205 = torch.ops.aten.view.default(mm_18, [2, 8192, 1792]) + convert_element_type_92 = torch.ops.prims.convert_element_type.default(view_205, torch.float32); view_205 = None + sigmoid_2 = torch.ops.aten.sigmoid.default(convert_element_type_92) + mul_22 = torch.ops.aten.mul.Tensor(convert_element_type_92, sigmoid_2); convert_element_type_92 = 
sigmoid_2 = None + convert_element_type_93 = torch.ops.prims.convert_element_type.default(mul_22, torch.bfloat16); mul_22 = None + convert_element_type_94 = torch.ops.prims.convert_element_type.default(primals_29, torch.bfloat16) + all_gather_into_tensor_32 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_94, 32, '0'); convert_element_type_94 = None + wait_tensor_38 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_32); all_gather_into_tensor_32 = None + permute_31 = torch.ops.aten.permute.default(wait_tensor_38, [1, 0]); wait_tensor_38 = None + mm_19 = torch.ops.aten.mm.default(view_204, permute_31); view_204 = permute_31 = None + view_212 = torch.ops.aten.view.default(mm_19, [2, 8192, 1792]); mm_19 = None + mul_23 = torch.ops.aten.mul.Tensor(convert_element_type_93, view_212); convert_element_type_93 = view_212 = None + convert_element_type_97 = torch.ops.prims.convert_element_type.default(primals_30, torch.bfloat16) + all_gather_into_tensor_33 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_97, 32, '0'); convert_element_type_97 = None + wait_tensor_39 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_33); all_gather_into_tensor_33 = None + permute_32 = torch.ops.aten.permute.default(wait_tensor_39, [1, 0]); wait_tensor_39 = None + view_219 = torch.ops.aten.view.default(mul_23, [16384, 1792]); mul_23 = None + mm_20 = torch.ops.aten.mm.default(view_219, permute_32); view_219 = permute_32 = None + view_220 = torch.ops.aten.view.default(mm_20, [2, 8192, 4096]); mm_20 = None + split_20 = torch.ops.aten.split.Tensor(view_220, 1024, 1); view_220 = None + getitem_187 = split_20[0] + getitem_188 = split_20[1] + getitem_189 = split_20[2] + getitem_190 = split_20[3] + getitem_191 = split_20[4] + getitem_192 = split_20[5] + getitem_193 = split_20[6] + getitem_194 = split_20[7]; split_20 = None + cat_12 = torch.ops.aten.cat.default([getitem_187, getitem_188, 
getitem_189, getitem_190, getitem_191, getitem_192, getitem_193, getitem_194]); getitem_187 = getitem_188 = getitem_189 = getitem_190 = getitem_191 = getitem_192 = getitem_193 = getitem_194 = None + reduce_scatter_tensor_6 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_12, 'sum', 8, '1'); cat_12 = None + wait_tensor_40 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_6); reduce_scatter_tensor_6 = None + add_11 = torch.ops.aten.add.Tensor(add_9, wait_tensor_40); add_9 = wait_tensor_40 = None + convert_element_type_100 = torch.ops.prims.convert_element_type.default(primals_31, torch.bfloat16) + all_gather_into_tensor_34 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_100, 32, '0'); convert_element_type_100 = None + wait_tensor_41 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_34); all_gather_into_tensor_34 = None + convert_element_type_101 = torch.ops.prims.convert_element_type.default(add_11, torch.float32) + pow_7 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_101, 2) + mean_6 = torch.ops.aten.mean.dim(pow_7, [2], True); pow_7 = None + add_12 = torch.ops.aten.add.Scalar(mean_6, 1e-05); mean_6 = None + rsqrt_6 = torch.ops.aten.rsqrt.default(add_12); add_12 = None + mul_24 = torch.ops.aten.mul.Tensor(convert_element_type_101, rsqrt_6); convert_element_type_101 = rsqrt_6 = None + mul_25 = torch.ops.aten.mul.Tensor(mul_24, wait_tensor_41); mul_24 = wait_tensor_41 = None + convert_element_type_102 = torch.ops.prims.convert_element_type.default(mul_25, torch.bfloat16); mul_25 = None + all_gather_into_tensor_35 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_102, 8, '1'); convert_element_type_102 = None + wait_tensor_42 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_35); all_gather_into_tensor_35 = None + split_21 = torch.ops.aten.split.Tensor(wait_tensor_42, 2); wait_tensor_42 = None + getitem_195 = split_21[0] 
+ getitem_196 = split_21[1] + getitem_197 = split_21[2] + getitem_198 = split_21[3] + getitem_199 = split_21[4] + getitem_200 = split_21[5] + getitem_201 = split_21[6] + getitem_202 = split_21[7]; split_21 = None + cat_13 = torch.ops.aten.cat.default([getitem_195, getitem_196, getitem_197, getitem_198, getitem_199, getitem_200, getitem_201, getitem_202], 1); getitem_195 = getitem_196 = getitem_197 = getitem_198 = getitem_199 = getitem_200 = getitem_201 = getitem_202 = None + convert_element_type_103 = torch.ops.prims.convert_element_type.default(primals_32, torch.bfloat16) + all_gather_into_tensor_36 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_103, 32, '0'); convert_element_type_103 = None + wait_tensor_43 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_36); all_gather_into_tensor_36 = None + permute_33 = torch.ops.aten.permute.default(wait_tensor_43, [1, 0]); wait_tensor_43 = None + view_231 = torch.ops.aten.view.default(cat_13, [16384, 4096]); cat_13 = None + mm_21 = torch.ops.aten.mm.default(view_231, permute_33); permute_33 = None + view_232 = torch.ops.aten.view.default(mm_21, [2, 8192, 512]) + convert_element_type_106 = torch.ops.prims.convert_element_type.default(primals_33, torch.bfloat16) + all_gather_into_tensor_37 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_106, 32, '0'); convert_element_type_106 = None + wait_tensor_44 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_37); all_gather_into_tensor_37 = None + permute_34 = torch.ops.aten.permute.default(wait_tensor_44, [1, 0]); wait_tensor_44 = None + mm_22 = torch.ops.aten.mm.default(view_231, permute_34); permute_34 = None + view_239 = torch.ops.aten.view.default(mm_22, [2, 8192, 128]); mm_22 = None + convert_element_type_109 = torch.ops.prims.convert_element_type.default(primals_34, torch.bfloat16) + all_gather_into_tensor_38 = 
torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_109, 32, '0'); convert_element_type_109 = None + wait_tensor_45 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_38); all_gather_into_tensor_38 = None + permute_35 = torch.ops.aten.permute.default(wait_tensor_45, [1, 0]); wait_tensor_45 = None + mm_23 = torch.ops.aten.mm.default(view_231, permute_35); view_231 = permute_35 = None + view_246 = torch.ops.aten.view.default(mm_23, [2, 8192, 128]) + view_248 = torch.ops.aten.view.default(view_232, [2, 8192, -1, 128]); view_232 = None + view_249 = torch.ops.aten.view.default(view_239, [2, 8192, -1, 128]); view_239 = None + view_250 = torch.ops.aten.view.default(view_246, [2, 8192, -1, 128]); view_246 = None + convert_element_type_112 = torch.ops.prims.convert_element_type.default(view_248, torch.float32); view_248 = None + view_251 = torch.ops.aten.view.default(convert_element_type_112, [2, 8192, 4, -1, 2]); convert_element_type_112 = None + view_as_complex_6 = torch.ops.aten.view_as_complex.default(view_251); view_251 = None + convert_element_type_113 = torch.ops.prims.convert_element_type.default(view_249, torch.float32); view_249 = None + view_252 = torch.ops.aten.view.default(convert_element_type_113, [2, 8192, 1, -1, 2]); convert_element_type_113 = None + view_as_complex_7 = torch.ops.aten.view_as_complex.default(view_252); view_252 = None + mul_26 = torch.ops.aten.mul.Tensor(view_as_complex_6, view_37); view_as_complex_6 = None + view_as_real_6 = torch.ops.aten.view_as_real.default(mul_26); mul_26 = None + view_254 = torch.ops.aten.view.default(view_as_real_6, [2, 8192, 4, 128]); view_as_real_6 = None + mul_27 = torch.ops.aten.mul.Tensor(view_as_complex_7, view_37); view_as_complex_7 = None + view_as_real_7 = torch.ops.aten.view_as_real.default(mul_27); mul_27 = None + view_255 = torch.ops.aten.view.default(view_as_real_7, [2, 8192, 1, 128]); view_as_real_7 = None + convert_element_type_114 = 
torch.ops.prims.convert_element_type.default(view_254, torch.bfloat16); view_254 = None + convert_element_type_115 = torch.ops.prims.convert_element_type.default(view_255, torch.bfloat16); view_255 = None + unsqueeze_6 = torch.ops.aten.unsqueeze.default(convert_element_type_115, 3); convert_element_type_115 = None + expand_6 = torch.ops.aten.expand.default(unsqueeze_6, [2, 8192, 1, 4, 128]); unsqueeze_6 = None + view_256 = torch.ops.aten.view.default(expand_6, [2, 8192, 4, 128]); expand_6 = None + unsqueeze_7 = torch.ops.aten.unsqueeze.default(view_250, 3); view_250 = None + expand_7 = torch.ops.aten.expand.default(unsqueeze_7, [2, 8192, 1, 4, 128]); unsqueeze_7 = None + view_257 = torch.ops.aten.view.default(expand_7, [2, 8192, 4, 128]); expand_7 = None + permute_36 = torch.ops.aten.permute.default(convert_element_type_114, [0, 2, 1, 3]); convert_element_type_114 = None + permute_37 = torch.ops.aten.permute.default(view_256, [0, 2, 1, 3]); view_256 = None + permute_38 = torch.ops.aten.permute.default(view_257, [0, 2, 1, 3]); view_257 = None + _scaled_dot_product_cudnn_attention_3 = torch.ops.aten._scaled_dot_product_cudnn_attention.default(permute_36, permute_37, permute_38, None, True, 0.0, True); permute_36 = permute_37 = permute_38 = None + getitem_203 = _scaled_dot_product_cudnn_attention_3[0] + getitem_204 = _scaled_dot_product_cudnn_attention_3[1] + getitem_209 = _scaled_dot_product_cudnn_attention_3[6] + getitem_210 = _scaled_dot_product_cudnn_attention_3[7]; _scaled_dot_product_cudnn_attention_3 = None + permute_39 = torch.ops.aten.permute.default(getitem_203, [0, 2, 1, 3]) + view_258 = torch.ops.aten.view.default(permute_39, [2, 8192, -1]); permute_39 = None + convert_element_type_116 = torch.ops.prims.convert_element_type.default(primals_35, torch.bfloat16) + all_gather_into_tensor_39 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_116, 32, '0'); convert_element_type_116 = None + wait_tensor_46 = 
torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_39); all_gather_into_tensor_39 = None + permute_40 = torch.ops.aten.permute.default(wait_tensor_46, [1, 0]); wait_tensor_46 = None + view_264 = torch.ops.aten.view.default(view_258, [16384, 512]); view_258 = None + mm_24 = torch.ops.aten.mm.default(view_264, permute_40); view_264 = permute_40 = None + view_265 = torch.ops.aten.view.default(mm_24, [2, 8192, 4096]); mm_24 = None + split_22 = torch.ops.aten.split.Tensor(view_265, 1024, 1); view_265 = None + getitem_212 = split_22[0] + getitem_213 = split_22[1] + getitem_214 = split_22[2] + getitem_215 = split_22[3] + getitem_216 = split_22[4] + getitem_217 = split_22[5] + getitem_218 = split_22[6] + getitem_219 = split_22[7]; split_22 = None + cat_14 = torch.ops.aten.cat.default([getitem_212, getitem_213, getitem_214, getitem_215, getitem_216, getitem_217, getitem_218, getitem_219]); getitem_212 = getitem_213 = getitem_214 = getitem_215 = getitem_216 = getitem_217 = getitem_218 = getitem_219 = None + reduce_scatter_tensor_7 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_14, 'sum', 8, '1'); cat_14 = None + wait_tensor_47 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_7) + add_13 = torch.ops.aten.add.Tensor(add_11, wait_tensor_47); wait_tensor_47 = None + convert_element_type_119 = torch.ops.prims.convert_element_type.default(primals_36, torch.bfloat16) + all_gather_into_tensor_40 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_119, 32, '0'); convert_element_type_119 = None + wait_tensor_48 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_40); all_gather_into_tensor_40 = None + convert_element_type_120 = torch.ops.prims.convert_element_type.default(add_13, torch.float32) + pow_8 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_120, 2) + mean_7 = torch.ops.aten.mean.dim(pow_8, [2], True); pow_8 = None + add_14 = torch.ops.aten.add.Scalar(mean_7, 
1e-05); mean_7 = None + rsqrt_7 = torch.ops.aten.rsqrt.default(add_14); add_14 = None + mul_28 = torch.ops.aten.mul.Tensor(convert_element_type_120, rsqrt_7); convert_element_type_120 = rsqrt_7 = None + mul_29 = torch.ops.aten.mul.Tensor(mul_28, wait_tensor_48); mul_28 = wait_tensor_48 = None + convert_element_type_121 = torch.ops.prims.convert_element_type.default(mul_29, torch.bfloat16); mul_29 = None + all_gather_into_tensor_41 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_121, 8, '1'); convert_element_type_121 = None + wait_tensor_49 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_41); all_gather_into_tensor_41 = None + split_23 = torch.ops.aten.split.Tensor(wait_tensor_49, 2); wait_tensor_49 = None + getitem_220 = split_23[0] + getitem_221 = split_23[1] + getitem_222 = split_23[2] + getitem_223 = split_23[3] + getitem_224 = split_23[4] + getitem_225 = split_23[5] + getitem_226 = split_23[6] + getitem_227 = split_23[7]; split_23 = None + cat_15 = torch.ops.aten.cat.default([getitem_220, getitem_221, getitem_222, getitem_223, getitem_224, getitem_225, getitem_226, getitem_227], 1); getitem_220 = getitem_221 = getitem_222 = getitem_223 = getitem_224 = getitem_225 = getitem_226 = getitem_227 = None + convert_element_type_122 = torch.ops.prims.convert_element_type.default(primals_37, torch.bfloat16) + all_gather_into_tensor_42 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_122, 32, '0'); convert_element_type_122 = None + wait_tensor_50 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_42); all_gather_into_tensor_42 = None + permute_41 = torch.ops.aten.permute.default(wait_tensor_50, [1, 0]); wait_tensor_50 = None + view_276 = torch.ops.aten.view.default(cat_15, [16384, 4096]); cat_15 = None + mm_25 = torch.ops.aten.mm.default(view_276, permute_41); permute_41 = None + view_277 = torch.ops.aten.view.default(mm_25, [2, 8192, 1792]) + 
convert_element_type_125 = torch.ops.prims.convert_element_type.default(view_277, torch.float32); view_277 = None + sigmoid_3 = torch.ops.aten.sigmoid.default(convert_element_type_125) + mul_30 = torch.ops.aten.mul.Tensor(convert_element_type_125, sigmoid_3); convert_element_type_125 = sigmoid_3 = None + convert_element_type_126 = torch.ops.prims.convert_element_type.default(mul_30, torch.bfloat16); mul_30 = None + convert_element_type_127 = torch.ops.prims.convert_element_type.default(primals_38, torch.bfloat16) + all_gather_into_tensor_43 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_127, 32, '0'); convert_element_type_127 = None + wait_tensor_51 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_43); all_gather_into_tensor_43 = None + permute_42 = torch.ops.aten.permute.default(wait_tensor_51, [1, 0]); wait_tensor_51 = None + mm_26 = torch.ops.aten.mm.default(view_276, permute_42); view_276 = permute_42 = None + view_284 = torch.ops.aten.view.default(mm_26, [2, 8192, 1792]); mm_26 = None + mul_31 = torch.ops.aten.mul.Tensor(convert_element_type_126, view_284); convert_element_type_126 = view_284 = None + convert_element_type_130 = torch.ops.prims.convert_element_type.default(primals_39, torch.bfloat16) + all_gather_into_tensor_44 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_130, 32, '0'); convert_element_type_130 = None + wait_tensor_52 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_44); all_gather_into_tensor_44 = None + permute_43 = torch.ops.aten.permute.default(wait_tensor_52, [1, 0]); wait_tensor_52 = None + view_291 = torch.ops.aten.view.default(mul_31, [16384, 1792]); mul_31 = None + mm_27 = torch.ops.aten.mm.default(view_291, permute_43); view_291 = permute_43 = None + view_292 = torch.ops.aten.view.default(mm_27, [2, 8192, 4096]); mm_27 = None + split_24 = torch.ops.aten.split.Tensor(view_292, 1024, 1); view_292 = None + getitem_228 = 
split_24[0] + getitem_229 = split_24[1] + getitem_230 = split_24[2] + getitem_231 = split_24[3] + getitem_232 = split_24[4] + getitem_233 = split_24[5] + getitem_234 = split_24[6] + getitem_235 = split_24[7]; split_24 = None + cat_16 = torch.ops.aten.cat.default([getitem_228, getitem_229, getitem_230, getitem_231, getitem_232, getitem_233, getitem_234, getitem_235]); getitem_228 = getitem_229 = getitem_230 = getitem_231 = getitem_232 = getitem_233 = getitem_234 = getitem_235 = None + reduce_scatter_tensor_8 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_16, 'sum', 8, '1'); cat_16 = None + wait_tensor_53 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_8); reduce_scatter_tensor_8 = None + add_15 = torch.ops.aten.add.Tensor(add_13, wait_tensor_53); add_13 = wait_tensor_53 = None + convert_element_type_133 = torch.ops.prims.convert_element_type.default(primals_40, torch.bfloat16) + all_gather_into_tensor_45 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_133, 32, '0'); convert_element_type_133 = None + wait_tensor_54 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_45); all_gather_into_tensor_45 = None + convert_element_type_134 = torch.ops.prims.convert_element_type.default(add_15, torch.float32) + pow_9 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_134, 2) + mean_8 = torch.ops.aten.mean.dim(pow_9, [2], True); pow_9 = None + add_16 = torch.ops.aten.add.Scalar(mean_8, 1e-05); mean_8 = None + rsqrt_8 = torch.ops.aten.rsqrt.default(add_16); add_16 = None + mul_32 = torch.ops.aten.mul.Tensor(convert_element_type_134, rsqrt_8); convert_element_type_134 = rsqrt_8 = None + mul_33 = torch.ops.aten.mul.Tensor(mul_32, wait_tensor_54); mul_32 = wait_tensor_54 = None + convert_element_type_135 = torch.ops.prims.convert_element_type.default(mul_33, torch.bfloat16); mul_33 = None + all_gather_into_tensor_46 = 
torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_135, 8, '1'); convert_element_type_135 = None + wait_tensor_55 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_46); all_gather_into_tensor_46 = None + split_25 = torch.ops.aten.split.Tensor(wait_tensor_55, 2); wait_tensor_55 = None + getitem_236 = split_25[0] + getitem_237 = split_25[1] + getitem_238 = split_25[2] + getitem_239 = split_25[3] + getitem_240 = split_25[4] + getitem_241 = split_25[5] + getitem_242 = split_25[6] + getitem_243 = split_25[7]; split_25 = None + cat_17 = torch.ops.aten.cat.default([getitem_236, getitem_237, getitem_238, getitem_239, getitem_240, getitem_241, getitem_242, getitem_243], 1); getitem_236 = getitem_237 = getitem_238 = getitem_239 = getitem_240 = getitem_241 = getitem_242 = getitem_243 = None + convert_element_type_136 = torch.ops.prims.convert_element_type.default(primals_41, torch.bfloat16) + all_gather_into_tensor_47 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_136, 32, '0'); convert_element_type_136 = None + wait_tensor_56 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_47); all_gather_into_tensor_47 = None + permute_44 = torch.ops.aten.permute.default(wait_tensor_56, [1, 0]); wait_tensor_56 = None + view_303 = torch.ops.aten.view.default(cat_17, [16384, 4096]); cat_17 = None + mm_28 = torch.ops.aten.mm.default(view_303, permute_44); permute_44 = None + view_304 = torch.ops.aten.view.default(mm_28, [2, 8192, 512]) + convert_element_type_139 = torch.ops.prims.convert_element_type.default(primals_42, torch.bfloat16) + all_gather_into_tensor_48 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_139, 32, '0'); convert_element_type_139 = None + wait_tensor_57 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_48); all_gather_into_tensor_48 = None + permute_45 = torch.ops.aten.permute.default(wait_tensor_57, [1, 0]); 
wait_tensor_57 = None + mm_29 = torch.ops.aten.mm.default(view_303, permute_45); permute_45 = None + view_311 = torch.ops.aten.view.default(mm_29, [2, 8192, 128]); mm_29 = None + convert_element_type_142 = torch.ops.prims.convert_element_type.default(primals_43, torch.bfloat16) + all_gather_into_tensor_49 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_142, 32, '0'); convert_element_type_142 = None + wait_tensor_58 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_49); all_gather_into_tensor_49 = None + permute_46 = torch.ops.aten.permute.default(wait_tensor_58, [1, 0]); wait_tensor_58 = None + mm_30 = torch.ops.aten.mm.default(view_303, permute_46); view_303 = permute_46 = None + view_318 = torch.ops.aten.view.default(mm_30, [2, 8192, 128]) + view_320 = torch.ops.aten.view.default(view_304, [2, 8192, -1, 128]); view_304 = None + view_321 = torch.ops.aten.view.default(view_311, [2, 8192, -1, 128]); view_311 = None + view_322 = torch.ops.aten.view.default(view_318, [2, 8192, -1, 128]); view_318 = None + convert_element_type_145 = torch.ops.prims.convert_element_type.default(view_320, torch.float32); view_320 = None + view_323 = torch.ops.aten.view.default(convert_element_type_145, [2, 8192, 4, -1, 2]); convert_element_type_145 = None + view_as_complex_8 = torch.ops.aten.view_as_complex.default(view_323); view_323 = None + convert_element_type_146 = torch.ops.prims.convert_element_type.default(view_321, torch.float32); view_321 = None + view_324 = torch.ops.aten.view.default(convert_element_type_146, [2, 8192, 1, -1, 2]); convert_element_type_146 = None + view_as_complex_9 = torch.ops.aten.view_as_complex.default(view_324); view_324 = None + mul_34 = torch.ops.aten.mul.Tensor(view_as_complex_8, view_37); view_as_complex_8 = None + view_as_real_8 = torch.ops.aten.view_as_real.default(mul_34); mul_34 = None + view_326 = torch.ops.aten.view.default(view_as_real_8, [2, 8192, 4, 128]); view_as_real_8 = None + mul_35 = 
torch.ops.aten.mul.Tensor(view_as_complex_9, view_37); view_as_complex_9 = None + view_as_real_9 = torch.ops.aten.view_as_real.default(mul_35); mul_35 = None + view_327 = torch.ops.aten.view.default(view_as_real_9, [2, 8192, 1, 128]); view_as_real_9 = None + convert_element_type_147 = torch.ops.prims.convert_element_type.default(view_326, torch.bfloat16); view_326 = None + convert_element_type_148 = torch.ops.prims.convert_element_type.default(view_327, torch.bfloat16); view_327 = None + unsqueeze_8 = torch.ops.aten.unsqueeze.default(convert_element_type_148, 3); convert_element_type_148 = None + expand_8 = torch.ops.aten.expand.default(unsqueeze_8, [2, 8192, 1, 4, 128]); unsqueeze_8 = None + view_328 = torch.ops.aten.view.default(expand_8, [2, 8192, 4, 128]); expand_8 = None + unsqueeze_9 = torch.ops.aten.unsqueeze.default(view_322, 3); view_322 = None + expand_9 = torch.ops.aten.expand.default(unsqueeze_9, [2, 8192, 1, 4, 128]); unsqueeze_9 = None + view_329 = torch.ops.aten.view.default(expand_9, [2, 8192, 4, 128]); expand_9 = None + permute_47 = torch.ops.aten.permute.default(convert_element_type_147, [0, 2, 1, 3]); convert_element_type_147 = None + permute_48 = torch.ops.aten.permute.default(view_328, [0, 2, 1, 3]); view_328 = None + permute_49 = torch.ops.aten.permute.default(view_329, [0, 2, 1, 3]); view_329 = None + _scaled_dot_product_cudnn_attention_4 = torch.ops.aten._scaled_dot_product_cudnn_attention.default(permute_47, permute_48, permute_49, None, True, 0.0, True); permute_47 = permute_48 = permute_49 = None + getitem_244 = _scaled_dot_product_cudnn_attention_4[0] + getitem_245 = _scaled_dot_product_cudnn_attention_4[1] + getitem_250 = _scaled_dot_product_cudnn_attention_4[6] + getitem_251 = _scaled_dot_product_cudnn_attention_4[7]; _scaled_dot_product_cudnn_attention_4 = None + permute_50 = torch.ops.aten.permute.default(getitem_244, [0, 2, 1, 3]) + view_330 = torch.ops.aten.view.default(permute_50, [2, 8192, -1]); permute_50 = None + 
convert_element_type_149 = torch.ops.prims.convert_element_type.default(primals_44, torch.bfloat16) + all_gather_into_tensor_50 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_149, 32, '0'); convert_element_type_149 = None + wait_tensor_59 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_50); all_gather_into_tensor_50 = None + permute_51 = torch.ops.aten.permute.default(wait_tensor_59, [1, 0]); wait_tensor_59 = None + view_336 = torch.ops.aten.view.default(view_330, [16384, 512]); view_330 = None + mm_31 = torch.ops.aten.mm.default(view_336, permute_51); view_336 = permute_51 = None + view_337 = torch.ops.aten.view.default(mm_31, [2, 8192, 4096]); mm_31 = None + split_26 = torch.ops.aten.split.Tensor(view_337, 1024, 1); view_337 = None + getitem_253 = split_26[0] + getitem_254 = split_26[1] + getitem_255 = split_26[2] + getitem_256 = split_26[3] + getitem_257 = split_26[4] + getitem_258 = split_26[5] + getitem_259 = split_26[6] + getitem_260 = split_26[7]; split_26 = None + cat_18 = torch.ops.aten.cat.default([getitem_253, getitem_254, getitem_255, getitem_256, getitem_257, getitem_258, getitem_259, getitem_260]); getitem_253 = getitem_254 = getitem_255 = getitem_256 = getitem_257 = getitem_258 = getitem_259 = getitem_260 = None + reduce_scatter_tensor_9 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_18, 'sum', 8, '1'); cat_18 = None + wait_tensor_60 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_9) + add_17 = torch.ops.aten.add.Tensor(add_15, wait_tensor_60); wait_tensor_60 = None + convert_element_type_152 = torch.ops.prims.convert_element_type.default(primals_45, torch.bfloat16) + all_gather_into_tensor_51 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_152, 32, '0'); convert_element_type_152 = None + wait_tensor_61 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_51); all_gather_into_tensor_51 = None + 
convert_element_type_153 = torch.ops.prims.convert_element_type.default(add_17, torch.float32) + pow_10 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_153, 2) + mean_9 = torch.ops.aten.mean.dim(pow_10, [2], True); pow_10 = None + add_18 = torch.ops.aten.add.Scalar(mean_9, 1e-05); mean_9 = None + rsqrt_9 = torch.ops.aten.rsqrt.default(add_18); add_18 = None + mul_36 = torch.ops.aten.mul.Tensor(convert_element_type_153, rsqrt_9); convert_element_type_153 = rsqrt_9 = None + mul_37 = torch.ops.aten.mul.Tensor(mul_36, wait_tensor_61); mul_36 = wait_tensor_61 = None + convert_element_type_154 = torch.ops.prims.convert_element_type.default(mul_37, torch.bfloat16); mul_37 = None + all_gather_into_tensor_52 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_154, 8, '1'); convert_element_type_154 = None + wait_tensor_62 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_52); all_gather_into_tensor_52 = None + split_27 = torch.ops.aten.split.Tensor(wait_tensor_62, 2); wait_tensor_62 = None + getitem_261 = split_27[0] + getitem_262 = split_27[1] + getitem_263 = split_27[2] + getitem_264 = split_27[3] + getitem_265 = split_27[4] + getitem_266 = split_27[5] + getitem_267 = split_27[6] + getitem_268 = split_27[7]; split_27 = None + cat_19 = torch.ops.aten.cat.default([getitem_261, getitem_262, getitem_263, getitem_264, getitem_265, getitem_266, getitem_267, getitem_268], 1); getitem_261 = getitem_262 = getitem_263 = getitem_264 = getitem_265 = getitem_266 = getitem_267 = getitem_268 = None + convert_element_type_155 = torch.ops.prims.convert_element_type.default(primals_46, torch.bfloat16) + all_gather_into_tensor_53 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_155, 32, '0'); convert_element_type_155 = None + wait_tensor_63 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_53); all_gather_into_tensor_53 = None + permute_52 = 
torch.ops.aten.permute.default(wait_tensor_63, [1, 0]); wait_tensor_63 = None + view_348 = torch.ops.aten.view.default(cat_19, [16384, 4096]); cat_19 = None + mm_32 = torch.ops.aten.mm.default(view_348, permute_52); permute_52 = None + view_349 = torch.ops.aten.view.default(mm_32, [2, 8192, 1792]) + convert_element_type_158 = torch.ops.prims.convert_element_type.default(view_349, torch.float32); view_349 = None + sigmoid_4 = torch.ops.aten.sigmoid.default(convert_element_type_158) + mul_38 = torch.ops.aten.mul.Tensor(convert_element_type_158, sigmoid_4); convert_element_type_158 = sigmoid_4 = None + convert_element_type_159 = torch.ops.prims.convert_element_type.default(mul_38, torch.bfloat16); mul_38 = None + convert_element_type_160 = torch.ops.prims.convert_element_type.default(primals_47, torch.bfloat16) + all_gather_into_tensor_54 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_160, 32, '0'); convert_element_type_160 = None + wait_tensor_64 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_54); all_gather_into_tensor_54 = None + permute_53 = torch.ops.aten.permute.default(wait_tensor_64, [1, 0]); wait_tensor_64 = None + mm_33 = torch.ops.aten.mm.default(view_348, permute_53); view_348 = permute_53 = None + view_356 = torch.ops.aten.view.default(mm_33, [2, 8192, 1792]); mm_33 = None + mul_39 = torch.ops.aten.mul.Tensor(convert_element_type_159, view_356); convert_element_type_159 = view_356 = None + convert_element_type_163 = torch.ops.prims.convert_element_type.default(primals_48, torch.bfloat16) + all_gather_into_tensor_55 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_163, 32, '0'); convert_element_type_163 = None + wait_tensor_65 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_55); all_gather_into_tensor_55 = None + permute_54 = torch.ops.aten.permute.default(wait_tensor_65, [1, 0]); wait_tensor_65 = None + view_363 = 
torch.ops.aten.view.default(mul_39, [16384, 1792]); mul_39 = None + mm_34 = torch.ops.aten.mm.default(view_363, permute_54); view_363 = permute_54 = None + view_364 = torch.ops.aten.view.default(mm_34, [2, 8192, 4096]); mm_34 = None + split_28 = torch.ops.aten.split.Tensor(view_364, 1024, 1); view_364 = None + getitem_269 = split_28[0] + getitem_270 = split_28[1] + getitem_271 = split_28[2] + getitem_272 = split_28[3] + getitem_273 = split_28[4] + getitem_274 = split_28[5] + getitem_275 = split_28[6] + getitem_276 = split_28[7]; split_28 = None + cat_20 = torch.ops.aten.cat.default([getitem_269, getitem_270, getitem_271, getitem_272, getitem_273, getitem_274, getitem_275, getitem_276]); getitem_269 = getitem_270 = getitem_271 = getitem_272 = getitem_273 = getitem_274 = getitem_275 = getitem_276 = None + reduce_scatter_tensor_10 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_20, 'sum', 8, '1'); cat_20 = None + wait_tensor_66 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_10); reduce_scatter_tensor_10 = None + add_19 = torch.ops.aten.add.Tensor(add_17, wait_tensor_66); add_17 = wait_tensor_66 = None + convert_element_type_166 = torch.ops.prims.convert_element_type.default(primals_49, torch.bfloat16) + all_gather_into_tensor_56 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_166, 32, '0'); convert_element_type_166 = None + wait_tensor_67 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_56); all_gather_into_tensor_56 = None + convert_element_type_167 = torch.ops.prims.convert_element_type.default(add_19, torch.float32) + pow_11 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_167, 2) + mean_10 = torch.ops.aten.mean.dim(pow_11, [2], True); pow_11 = None + add_20 = torch.ops.aten.add.Scalar(mean_10, 1e-05); mean_10 = None + rsqrt_10 = torch.ops.aten.rsqrt.default(add_20); add_20 = None + mul_40 = torch.ops.aten.mul.Tensor(convert_element_type_167, rsqrt_10); 
convert_element_type_167 = rsqrt_10 = None + mul_41 = torch.ops.aten.mul.Tensor(mul_40, wait_tensor_67); mul_40 = wait_tensor_67 = None + convert_element_type_168 = torch.ops.prims.convert_element_type.default(mul_41, torch.bfloat16); mul_41 = None + all_gather_into_tensor_57 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_168, 8, '1'); convert_element_type_168 = None + wait_tensor_68 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_57); all_gather_into_tensor_57 = None + split_29 = torch.ops.aten.split.Tensor(wait_tensor_68, 2); wait_tensor_68 = None + getitem_277 = split_29[0] + getitem_278 = split_29[1] + getitem_279 = split_29[2] + getitem_280 = split_29[3] + getitem_281 = split_29[4] + getitem_282 = split_29[5] + getitem_283 = split_29[6] + getitem_284 = split_29[7]; split_29 = None + cat_21 = torch.ops.aten.cat.default([getitem_277, getitem_278, getitem_279, getitem_280, getitem_281, getitem_282, getitem_283, getitem_284], 1); getitem_277 = getitem_278 = getitem_279 = getitem_280 = getitem_281 = getitem_282 = getitem_283 = getitem_284 = None + convert_element_type_169 = torch.ops.prims.convert_element_type.default(primals_50, torch.bfloat16) + all_gather_into_tensor_58 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_169, 32, '0'); convert_element_type_169 = None + wait_tensor_69 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_58); all_gather_into_tensor_58 = None + permute_55 = torch.ops.aten.permute.default(wait_tensor_69, [1, 0]); wait_tensor_69 = None + view_375 = torch.ops.aten.view.default(cat_21, [16384, 4096]); cat_21 = None + mm_35 = torch.ops.aten.mm.default(view_375, permute_55); permute_55 = None + view_376 = torch.ops.aten.view.default(mm_35, [2, 8192, 512]) + convert_element_type_172 = torch.ops.prims.convert_element_type.default(primals_51, torch.bfloat16) + all_gather_into_tensor_59 = 
torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_172, 32, '0'); convert_element_type_172 = None + wait_tensor_70 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_59); all_gather_into_tensor_59 = None + permute_56 = torch.ops.aten.permute.default(wait_tensor_70, [1, 0]); wait_tensor_70 = None + mm_36 = torch.ops.aten.mm.default(view_375, permute_56); permute_56 = None + view_383 = torch.ops.aten.view.default(mm_36, [2, 8192, 128]); mm_36 = None + convert_element_type_175 = torch.ops.prims.convert_element_type.default(primals_52, torch.bfloat16) + all_gather_into_tensor_60 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_175, 32, '0'); convert_element_type_175 = None + wait_tensor_71 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_60); all_gather_into_tensor_60 = None + permute_57 = torch.ops.aten.permute.default(wait_tensor_71, [1, 0]); wait_tensor_71 = None + mm_37 = torch.ops.aten.mm.default(view_375, permute_57); view_375 = permute_57 = None + view_390 = torch.ops.aten.view.default(mm_37, [2, 8192, 128]) + view_392 = torch.ops.aten.view.default(view_376, [2, 8192, -1, 128]); view_376 = None + view_393 = torch.ops.aten.view.default(view_383, [2, 8192, -1, 128]); view_383 = None + view_394 = torch.ops.aten.view.default(view_390, [2, 8192, -1, 128]); view_390 = None + convert_element_type_178 = torch.ops.prims.convert_element_type.default(view_392, torch.float32); view_392 = None + view_395 = torch.ops.aten.view.default(convert_element_type_178, [2, 8192, 4, -1, 2]); convert_element_type_178 = None + view_as_complex_10 = torch.ops.aten.view_as_complex.default(view_395); view_395 = None + convert_element_type_179 = torch.ops.prims.convert_element_type.default(view_393, torch.float32); view_393 = None + view_396 = torch.ops.aten.view.default(convert_element_type_179, [2, 8192, 1, -1, 2]); convert_element_type_179 = None + view_as_complex_11 = 
torch.ops.aten.view_as_complex.default(view_396); view_396 = None + mul_42 = torch.ops.aten.mul.Tensor(view_as_complex_10, view_37); view_as_complex_10 = None + view_as_real_10 = torch.ops.aten.view_as_real.default(mul_42); mul_42 = None + view_398 = torch.ops.aten.view.default(view_as_real_10, [2, 8192, 4, 128]); view_as_real_10 = None + mul_43 = torch.ops.aten.mul.Tensor(view_as_complex_11, view_37); view_as_complex_11 = None + view_as_real_11 = torch.ops.aten.view_as_real.default(mul_43); mul_43 = None + view_399 = torch.ops.aten.view.default(view_as_real_11, [2, 8192, 1, 128]); view_as_real_11 = None + convert_element_type_180 = torch.ops.prims.convert_element_type.default(view_398, torch.bfloat16); view_398 = None + convert_element_type_181 = torch.ops.prims.convert_element_type.default(view_399, torch.bfloat16); view_399 = None + unsqueeze_10 = torch.ops.aten.unsqueeze.default(convert_element_type_181, 3); convert_element_type_181 = None + expand_10 = torch.ops.aten.expand.default(unsqueeze_10, [2, 8192, 1, 4, 128]); unsqueeze_10 = None + view_400 = torch.ops.aten.view.default(expand_10, [2, 8192, 4, 128]); expand_10 = None + unsqueeze_11 = torch.ops.aten.unsqueeze.default(view_394, 3); view_394 = None + expand_11 = torch.ops.aten.expand.default(unsqueeze_11, [2, 8192, 1, 4, 128]); unsqueeze_11 = None + view_401 = torch.ops.aten.view.default(expand_11, [2, 8192, 4, 128]); expand_11 = None + permute_58 = torch.ops.aten.permute.default(convert_element_type_180, [0, 2, 1, 3]); convert_element_type_180 = None + permute_59 = torch.ops.aten.permute.default(view_400, [0, 2, 1, 3]); view_400 = None + permute_60 = torch.ops.aten.permute.default(view_401, [0, 2, 1, 3]); view_401 = None + _scaled_dot_product_cudnn_attention_5 = torch.ops.aten._scaled_dot_product_cudnn_attention.default(permute_58, permute_59, permute_60, None, True, 0.0, True); permute_58 = permute_59 = permute_60 = None + getitem_285 = _scaled_dot_product_cudnn_attention_5[0] + getitem_286 = 
_scaled_dot_product_cudnn_attention_5[1] + getitem_291 = _scaled_dot_product_cudnn_attention_5[6] + getitem_292 = _scaled_dot_product_cudnn_attention_5[7]; _scaled_dot_product_cudnn_attention_5 = None + permute_61 = torch.ops.aten.permute.default(getitem_285, [0, 2, 1, 3]) + view_402 = torch.ops.aten.view.default(permute_61, [2, 8192, -1]); permute_61 = None + convert_element_type_182 = torch.ops.prims.convert_element_type.default(primals_53, torch.bfloat16) + all_gather_into_tensor_61 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_182, 32, '0'); convert_element_type_182 = None + wait_tensor_72 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_61); all_gather_into_tensor_61 = None + permute_62 = torch.ops.aten.permute.default(wait_tensor_72, [1, 0]); wait_tensor_72 = None + view_408 = torch.ops.aten.view.default(view_402, [16384, 512]); view_402 = None + mm_38 = torch.ops.aten.mm.default(view_408, permute_62); view_408 = permute_62 = None + view_409 = torch.ops.aten.view.default(mm_38, [2, 8192, 4096]); mm_38 = None + split_30 = torch.ops.aten.split.Tensor(view_409, 1024, 1); view_409 = None + getitem_294 = split_30[0] + getitem_295 = split_30[1] + getitem_296 = split_30[2] + getitem_297 = split_30[3] + getitem_298 = split_30[4] + getitem_299 = split_30[5] + getitem_300 = split_30[6] + getitem_301 = split_30[7]; split_30 = None + cat_22 = torch.ops.aten.cat.default([getitem_294, getitem_295, getitem_296, getitem_297, getitem_298, getitem_299, getitem_300, getitem_301]); getitem_294 = getitem_295 = getitem_296 = getitem_297 = getitem_298 = getitem_299 = getitem_300 = getitem_301 = None + reduce_scatter_tensor_11 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_22, 'sum', 8, '1'); cat_22 = None + wait_tensor_73 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_11) + add_21 = torch.ops.aten.add.Tensor(add_19, wait_tensor_73); wait_tensor_73 = None + convert_element_type_185 = 
torch.ops.prims.convert_element_type.default(primals_54, torch.bfloat16) + all_gather_into_tensor_62 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_185, 32, '0'); convert_element_type_185 = None + wait_tensor_74 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_62); all_gather_into_tensor_62 = None + convert_element_type_186 = torch.ops.prims.convert_element_type.default(add_21, torch.float32) + pow_12 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_186, 2) + mean_11 = torch.ops.aten.mean.dim(pow_12, [2], True); pow_12 = None + add_22 = torch.ops.aten.add.Scalar(mean_11, 1e-05); mean_11 = None + rsqrt_11 = torch.ops.aten.rsqrt.default(add_22); add_22 = None + mul_44 = torch.ops.aten.mul.Tensor(convert_element_type_186, rsqrt_11); convert_element_type_186 = rsqrt_11 = None + mul_45 = torch.ops.aten.mul.Tensor(mul_44, wait_tensor_74); mul_44 = wait_tensor_74 = None + convert_element_type_187 = torch.ops.prims.convert_element_type.default(mul_45, torch.bfloat16); mul_45 = None + all_gather_into_tensor_63 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_187, 8, '1'); convert_element_type_187 = None + wait_tensor_75 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_63); all_gather_into_tensor_63 = None + split_31 = torch.ops.aten.split.Tensor(wait_tensor_75, 2); wait_tensor_75 = None + getitem_302 = split_31[0] + getitem_303 = split_31[1] + getitem_304 = split_31[2] + getitem_305 = split_31[3] + getitem_306 = split_31[4] + getitem_307 = split_31[5] + getitem_308 = split_31[6] + getitem_309 = split_31[7]; split_31 = None + cat_23 = torch.ops.aten.cat.default([getitem_302, getitem_303, getitem_304, getitem_305, getitem_306, getitem_307, getitem_308, getitem_309], 1); getitem_302 = getitem_303 = getitem_304 = getitem_305 = getitem_306 = getitem_307 = getitem_308 = getitem_309 = None + convert_element_type_188 = 
torch.ops.prims.convert_element_type.default(primals_55, torch.bfloat16) + all_gather_into_tensor_64 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_188, 32, '0'); convert_element_type_188 = None + wait_tensor_76 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_64); all_gather_into_tensor_64 = None + permute_63 = torch.ops.aten.permute.default(wait_tensor_76, [1, 0]); wait_tensor_76 = None + view_420 = torch.ops.aten.view.default(cat_23, [16384, 4096]); cat_23 = None + mm_39 = torch.ops.aten.mm.default(view_420, permute_63); permute_63 = None + view_421 = torch.ops.aten.view.default(mm_39, [2, 8192, 1792]) + convert_element_type_191 = torch.ops.prims.convert_element_type.default(view_421, torch.float32); view_421 = None + sigmoid_5 = torch.ops.aten.sigmoid.default(convert_element_type_191) + mul_46 = torch.ops.aten.mul.Tensor(convert_element_type_191, sigmoid_5); convert_element_type_191 = sigmoid_5 = None + convert_element_type_192 = torch.ops.prims.convert_element_type.default(mul_46, torch.bfloat16); mul_46 = None + convert_element_type_193 = torch.ops.prims.convert_element_type.default(primals_56, torch.bfloat16) + all_gather_into_tensor_65 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_193, 32, '0'); convert_element_type_193 = None + wait_tensor_77 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_65); all_gather_into_tensor_65 = None + permute_64 = torch.ops.aten.permute.default(wait_tensor_77, [1, 0]); wait_tensor_77 = None + mm_40 = torch.ops.aten.mm.default(view_420, permute_64); view_420 = permute_64 = None + view_428 = torch.ops.aten.view.default(mm_40, [2, 8192, 1792]); mm_40 = None + mul_47 = torch.ops.aten.mul.Tensor(convert_element_type_192, view_428); convert_element_type_192 = view_428 = None + convert_element_type_196 = torch.ops.prims.convert_element_type.default(primals_57, torch.bfloat16) + all_gather_into_tensor_66 = 
torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_196, 32, '0'); convert_element_type_196 = None + wait_tensor_78 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_66); all_gather_into_tensor_66 = None + permute_65 = torch.ops.aten.permute.default(wait_tensor_78, [1, 0]); wait_tensor_78 = None + view_435 = torch.ops.aten.view.default(mul_47, [16384, 1792]); mul_47 = None + mm_41 = torch.ops.aten.mm.default(view_435, permute_65); view_435 = permute_65 = None + view_436 = torch.ops.aten.view.default(mm_41, [2, 8192, 4096]); mm_41 = None + split_32 = torch.ops.aten.split.Tensor(view_436, 1024, 1); view_436 = None + getitem_310 = split_32[0] + getitem_311 = split_32[1] + getitem_312 = split_32[2] + getitem_313 = split_32[3] + getitem_314 = split_32[4] + getitem_315 = split_32[5] + getitem_316 = split_32[6] + getitem_317 = split_32[7]; split_32 = None + cat_24 = torch.ops.aten.cat.default([getitem_310, getitem_311, getitem_312, getitem_313, getitem_314, getitem_315, getitem_316, getitem_317]); getitem_310 = getitem_311 = getitem_312 = getitem_313 = getitem_314 = getitem_315 = getitem_316 = getitem_317 = None + reduce_scatter_tensor_12 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_24, 'sum', 8, '1'); cat_24 = None + wait_tensor_79 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_12); reduce_scatter_tensor_12 = None + add_23 = torch.ops.aten.add.Tensor(add_21, wait_tensor_79); add_21 = wait_tensor_79 = None + convert_element_type_199 = torch.ops.prims.convert_element_type.default(primals_58, torch.bfloat16) + all_gather_into_tensor_67 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_199, 32, '0'); convert_element_type_199 = None + wait_tensor_80 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_67); all_gather_into_tensor_67 = None + convert_element_type_200 = torch.ops.prims.convert_element_type.default(add_23, torch.float32) + 
pow_13 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_200, 2) + mean_12 = torch.ops.aten.mean.dim(pow_13, [2], True); pow_13 = None + add_24 = torch.ops.aten.add.Scalar(mean_12, 1e-05); mean_12 = None + rsqrt_12 = torch.ops.aten.rsqrt.default(add_24); add_24 = None + mul_48 = torch.ops.aten.mul.Tensor(convert_element_type_200, rsqrt_12); convert_element_type_200 = rsqrt_12 = None + mul_49 = torch.ops.aten.mul.Tensor(mul_48, wait_tensor_80); mul_48 = wait_tensor_80 = None + convert_element_type_201 = torch.ops.prims.convert_element_type.default(mul_49, torch.bfloat16); mul_49 = None + all_gather_into_tensor_68 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_201, 8, '1'); convert_element_type_201 = None + wait_tensor_81 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_68); all_gather_into_tensor_68 = None + split_33 = torch.ops.aten.split.Tensor(wait_tensor_81, 2); wait_tensor_81 = None + getitem_318 = split_33[0] + getitem_319 = split_33[1] + getitem_320 = split_33[2] + getitem_321 = split_33[3] + getitem_322 = split_33[4] + getitem_323 = split_33[5] + getitem_324 = split_33[6] + getitem_325 = split_33[7]; split_33 = None + cat_25 = torch.ops.aten.cat.default([getitem_318, getitem_319, getitem_320, getitem_321, getitem_322, getitem_323, getitem_324, getitem_325], 1); getitem_318 = getitem_319 = getitem_320 = getitem_321 = getitem_322 = getitem_323 = getitem_324 = getitem_325 = None + convert_element_type_202 = torch.ops.prims.convert_element_type.default(primals_59, torch.bfloat16) + all_gather_into_tensor_69 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_202, 32, '0'); convert_element_type_202 = None + wait_tensor_82 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_69); all_gather_into_tensor_69 = None + permute_66 = torch.ops.aten.permute.default(wait_tensor_82, [1, 0]); wait_tensor_82 = None + view_447 = torch.ops.aten.view.default(cat_25, 
[16384, 4096]); cat_25 = None + mm_42 = torch.ops.aten.mm.default(view_447, permute_66); permute_66 = None + view_448 = torch.ops.aten.view.default(mm_42, [2, 8192, 512]) + convert_element_type_205 = torch.ops.prims.convert_element_type.default(primals_60, torch.bfloat16) + all_gather_into_tensor_70 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_205, 32, '0'); convert_element_type_205 = None + wait_tensor_83 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_70); all_gather_into_tensor_70 = None + permute_67 = torch.ops.aten.permute.default(wait_tensor_83, [1, 0]); wait_tensor_83 = None + mm_43 = torch.ops.aten.mm.default(view_447, permute_67); permute_67 = None + view_455 = torch.ops.aten.view.default(mm_43, [2, 8192, 128]); mm_43 = None + convert_element_type_208 = torch.ops.prims.convert_element_type.default(primals_61, torch.bfloat16) + all_gather_into_tensor_71 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_208, 32, '0'); convert_element_type_208 = None + wait_tensor_84 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_71); all_gather_into_tensor_71 = None + permute_68 = torch.ops.aten.permute.default(wait_tensor_84, [1, 0]); wait_tensor_84 = None + mm_44 = torch.ops.aten.mm.default(view_447, permute_68); view_447 = permute_68 = None + view_462 = torch.ops.aten.view.default(mm_44, [2, 8192, 128]) + view_464 = torch.ops.aten.view.default(view_448, [2, 8192, -1, 128]); view_448 = None + view_465 = torch.ops.aten.view.default(view_455, [2, 8192, -1, 128]); view_455 = None + view_466 = torch.ops.aten.view.default(view_462, [2, 8192, -1, 128]); view_462 = None + convert_element_type_211 = torch.ops.prims.convert_element_type.default(view_464, torch.float32); view_464 = None + view_467 = torch.ops.aten.view.default(convert_element_type_211, [2, 8192, 4, -1, 2]); convert_element_type_211 = None + view_as_complex_12 = 
torch.ops.aten.view_as_complex.default(view_467); view_467 = None + convert_element_type_212 = torch.ops.prims.convert_element_type.default(view_465, torch.float32); view_465 = None + view_468 = torch.ops.aten.view.default(convert_element_type_212, [2, 8192, 1, -1, 2]); convert_element_type_212 = None + view_as_complex_13 = torch.ops.aten.view_as_complex.default(view_468); view_468 = None + mul_50 = torch.ops.aten.mul.Tensor(view_as_complex_12, view_37); view_as_complex_12 = None + view_as_real_12 = torch.ops.aten.view_as_real.default(mul_50); mul_50 = None + view_470 = torch.ops.aten.view.default(view_as_real_12, [2, 8192, 4, 128]); view_as_real_12 = None + mul_51 = torch.ops.aten.mul.Tensor(view_as_complex_13, view_37); view_as_complex_13 = None + view_as_real_13 = torch.ops.aten.view_as_real.default(mul_51); mul_51 = None + view_471 = torch.ops.aten.view.default(view_as_real_13, [2, 8192, 1, 128]); view_as_real_13 = None + convert_element_type_213 = torch.ops.prims.convert_element_type.default(view_470, torch.bfloat16); view_470 = None + convert_element_type_214 = torch.ops.prims.convert_element_type.default(view_471, torch.bfloat16); view_471 = None + unsqueeze_12 = torch.ops.aten.unsqueeze.default(convert_element_type_214, 3); convert_element_type_214 = None + expand_12 = torch.ops.aten.expand.default(unsqueeze_12, [2, 8192, 1, 4, 128]); unsqueeze_12 = None + view_472 = torch.ops.aten.view.default(expand_12, [2, 8192, 4, 128]); expand_12 = None + unsqueeze_13 = torch.ops.aten.unsqueeze.default(view_466, 3); view_466 = None + expand_13 = torch.ops.aten.expand.default(unsqueeze_13, [2, 8192, 1, 4, 128]); unsqueeze_13 = None + view_473 = torch.ops.aten.view.default(expand_13, [2, 8192, 4, 128]); expand_13 = None + permute_69 = torch.ops.aten.permute.default(convert_element_type_213, [0, 2, 1, 3]); convert_element_type_213 = None + permute_70 = torch.ops.aten.permute.default(view_472, [0, 2, 1, 3]); view_472 = None + permute_71 = 
torch.ops.aten.permute.default(view_473, [0, 2, 1, 3]); view_473 = None + _scaled_dot_product_cudnn_attention_6 = torch.ops.aten._scaled_dot_product_cudnn_attention.default(permute_69, permute_70, permute_71, None, True, 0.0, True); permute_69 = permute_70 = permute_71 = None + getitem_326 = _scaled_dot_product_cudnn_attention_6[0] + getitem_327 = _scaled_dot_product_cudnn_attention_6[1] + getitem_332 = _scaled_dot_product_cudnn_attention_6[6] + getitem_333 = _scaled_dot_product_cudnn_attention_6[7]; _scaled_dot_product_cudnn_attention_6 = None + permute_72 = torch.ops.aten.permute.default(getitem_326, [0, 2, 1, 3]) + view_474 = torch.ops.aten.view.default(permute_72, [2, 8192, -1]); permute_72 = None + convert_element_type_215 = torch.ops.prims.convert_element_type.default(primals_62, torch.bfloat16) + all_gather_into_tensor_72 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_215, 32, '0'); convert_element_type_215 = None + wait_tensor_85 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_72); all_gather_into_tensor_72 = None + permute_73 = torch.ops.aten.permute.default(wait_tensor_85, [1, 0]); wait_tensor_85 = None + view_480 = torch.ops.aten.view.default(view_474, [16384, 512]); view_474 = None + mm_45 = torch.ops.aten.mm.default(view_480, permute_73); view_480 = permute_73 = None + view_481 = torch.ops.aten.view.default(mm_45, [2, 8192, 4096]); mm_45 = None + split_34 = torch.ops.aten.split.Tensor(view_481, 1024, 1); view_481 = None + getitem_335 = split_34[0] + getitem_336 = split_34[1] + getitem_337 = split_34[2] + getitem_338 = split_34[3] + getitem_339 = split_34[4] + getitem_340 = split_34[5] + getitem_341 = split_34[6] + getitem_342 = split_34[7]; split_34 = None + cat_26 = torch.ops.aten.cat.default([getitem_335, getitem_336, getitem_337, getitem_338, getitem_339, getitem_340, getitem_341, getitem_342]); getitem_335 = getitem_336 = getitem_337 = getitem_338 = getitem_339 = getitem_340 = getitem_341 = 
getitem_342 = None + reduce_scatter_tensor_13 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_26, 'sum', 8, '1'); cat_26 = None + wait_tensor_86 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_13) + add_25 = torch.ops.aten.add.Tensor(add_23, wait_tensor_86); wait_tensor_86 = None + convert_element_type_218 = torch.ops.prims.convert_element_type.default(primals_63, torch.bfloat16) + all_gather_into_tensor_73 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_218, 32, '0'); convert_element_type_218 = None + wait_tensor_87 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_73); all_gather_into_tensor_73 = None + convert_element_type_219 = torch.ops.prims.convert_element_type.default(add_25, torch.float32) + pow_14 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_219, 2) + mean_13 = torch.ops.aten.mean.dim(pow_14, [2], True); pow_14 = None + add_26 = torch.ops.aten.add.Scalar(mean_13, 1e-05); mean_13 = None + rsqrt_13 = torch.ops.aten.rsqrt.default(add_26); add_26 = None + mul_52 = torch.ops.aten.mul.Tensor(convert_element_type_219, rsqrt_13); convert_element_type_219 = rsqrt_13 = None + mul_53 = torch.ops.aten.mul.Tensor(mul_52, wait_tensor_87); mul_52 = wait_tensor_87 = None + convert_element_type_220 = torch.ops.prims.convert_element_type.default(mul_53, torch.bfloat16); mul_53 = None + all_gather_into_tensor_74 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_220, 8, '1'); convert_element_type_220 = None + wait_tensor_88 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_74); all_gather_into_tensor_74 = None + split_35 = torch.ops.aten.split.Tensor(wait_tensor_88, 2); wait_tensor_88 = None + getitem_343 = split_35[0] + getitem_344 = split_35[1] + getitem_345 = split_35[2] + getitem_346 = split_35[3] + getitem_347 = split_35[4] + getitem_348 = split_35[5] + getitem_349 = split_35[6] + getitem_350 = split_35[7]; split_35 
= None + cat_27 = torch.ops.aten.cat.default([getitem_343, getitem_344, getitem_345, getitem_346, getitem_347, getitem_348, getitem_349, getitem_350], 1); getitem_343 = getitem_344 = getitem_345 = getitem_346 = getitem_347 = getitem_348 = getitem_349 = getitem_350 = None + convert_element_type_221 = torch.ops.prims.convert_element_type.default(primals_64, torch.bfloat16) + all_gather_into_tensor_75 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_221, 32, '0'); convert_element_type_221 = None + wait_tensor_89 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_75); all_gather_into_tensor_75 = None + permute_74 = torch.ops.aten.permute.default(wait_tensor_89, [1, 0]); wait_tensor_89 = None + view_492 = torch.ops.aten.view.default(cat_27, [16384, 4096]); cat_27 = None + mm_46 = torch.ops.aten.mm.default(view_492, permute_74); permute_74 = None + view_493 = torch.ops.aten.view.default(mm_46, [2, 8192, 1792]) + convert_element_type_224 = torch.ops.prims.convert_element_type.default(view_493, torch.float32); view_493 = None + sigmoid_6 = torch.ops.aten.sigmoid.default(convert_element_type_224) + mul_54 = torch.ops.aten.mul.Tensor(convert_element_type_224, sigmoid_6); convert_element_type_224 = sigmoid_6 = None + convert_element_type_225 = torch.ops.prims.convert_element_type.default(mul_54, torch.bfloat16); mul_54 = None + convert_element_type_226 = torch.ops.prims.convert_element_type.default(primals_65, torch.bfloat16) + all_gather_into_tensor_76 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_226, 32, '0'); convert_element_type_226 = None + wait_tensor_90 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_76); all_gather_into_tensor_76 = None + permute_75 = torch.ops.aten.permute.default(wait_tensor_90, [1, 0]); wait_tensor_90 = None + mm_47 = torch.ops.aten.mm.default(view_492, permute_75); view_492 = permute_75 = None + view_500 = 
torch.ops.aten.view.default(mm_47, [2, 8192, 1792]); mm_47 = None + mul_55 = torch.ops.aten.mul.Tensor(convert_element_type_225, view_500); convert_element_type_225 = view_500 = None + convert_element_type_229 = torch.ops.prims.convert_element_type.default(primals_66, torch.bfloat16) + all_gather_into_tensor_77 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_229, 32, '0'); convert_element_type_229 = None + wait_tensor_91 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_77); all_gather_into_tensor_77 = None + permute_76 = torch.ops.aten.permute.default(wait_tensor_91, [1, 0]); wait_tensor_91 = None + view_507 = torch.ops.aten.view.default(mul_55, [16384, 1792]); mul_55 = None + mm_48 = torch.ops.aten.mm.default(view_507, permute_76); view_507 = permute_76 = None + view_508 = torch.ops.aten.view.default(mm_48, [2, 8192, 4096]); mm_48 = None + split_36 = torch.ops.aten.split.Tensor(view_508, 1024, 1); view_508 = None + getitem_351 = split_36[0] + getitem_352 = split_36[1] + getitem_353 = split_36[2] + getitem_354 = split_36[3] + getitem_355 = split_36[4] + getitem_356 = split_36[5] + getitem_357 = split_36[6] + getitem_358 = split_36[7]; split_36 = None + cat_28 = torch.ops.aten.cat.default([getitem_351, getitem_352, getitem_353, getitem_354, getitem_355, getitem_356, getitem_357, getitem_358]); getitem_351 = getitem_352 = getitem_353 = getitem_354 = getitem_355 = getitem_356 = getitem_357 = getitem_358 = None + reduce_scatter_tensor_14 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_28, 'sum', 8, '1'); cat_28 = None + wait_tensor_92 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_14); reduce_scatter_tensor_14 = None + add_27 = torch.ops.aten.add.Tensor(add_25, wait_tensor_92); add_25 = wait_tensor_92 = None + convert_element_type_232 = torch.ops.prims.convert_element_type.default(primals_67, torch.bfloat16) + all_gather_into_tensor_78 = 
torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_232, 32, '0'); convert_element_type_232 = None + wait_tensor_93 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_78); all_gather_into_tensor_78 = None + convert_element_type_233 = torch.ops.prims.convert_element_type.default(add_27, torch.float32) + pow_15 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_233, 2) + mean_14 = torch.ops.aten.mean.dim(pow_15, [2], True); pow_15 = None + add_28 = torch.ops.aten.add.Scalar(mean_14, 1e-05); mean_14 = None + rsqrt_14 = torch.ops.aten.rsqrt.default(add_28); add_28 = None + mul_56 = torch.ops.aten.mul.Tensor(convert_element_type_233, rsqrt_14); convert_element_type_233 = rsqrt_14 = None + mul_57 = torch.ops.aten.mul.Tensor(mul_56, wait_tensor_93); mul_56 = wait_tensor_93 = None + convert_element_type_234 = torch.ops.prims.convert_element_type.default(mul_57, torch.bfloat16); mul_57 = None + all_gather_into_tensor_79 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_234, 8, '1'); convert_element_type_234 = None + wait_tensor_94 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_79); all_gather_into_tensor_79 = None + split_37 = torch.ops.aten.split.Tensor(wait_tensor_94, 2); wait_tensor_94 = None + getitem_359 = split_37[0] + getitem_360 = split_37[1] + getitem_361 = split_37[2] + getitem_362 = split_37[3] + getitem_363 = split_37[4] + getitem_364 = split_37[5] + getitem_365 = split_37[6] + getitem_366 = split_37[7]; split_37 = None + cat_29 = torch.ops.aten.cat.default([getitem_359, getitem_360, getitem_361, getitem_362, getitem_363, getitem_364, getitem_365, getitem_366], 1); getitem_359 = getitem_360 = getitem_361 = getitem_362 = getitem_363 = getitem_364 = getitem_365 = getitem_366 = None + convert_element_type_235 = torch.ops.prims.convert_element_type.default(primals_68, torch.bfloat16) + all_gather_into_tensor_80 = 
torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_235, 32, '0'); convert_element_type_235 = None + wait_tensor_95 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_80); all_gather_into_tensor_80 = None + permute_77 = torch.ops.aten.permute.default(wait_tensor_95, [1, 0]); wait_tensor_95 = None + view_519 = torch.ops.aten.view.default(cat_29, [16384, 4096]); cat_29 = None + mm_49 = torch.ops.aten.mm.default(view_519, permute_77); permute_77 = None + view_520 = torch.ops.aten.view.default(mm_49, [2, 8192, 512]) + convert_element_type_238 = torch.ops.prims.convert_element_type.default(primals_69, torch.bfloat16) + all_gather_into_tensor_81 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_238, 32, '0'); convert_element_type_238 = None + wait_tensor_96 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_81); all_gather_into_tensor_81 = None + permute_78 = torch.ops.aten.permute.default(wait_tensor_96, [1, 0]); wait_tensor_96 = None + mm_50 = torch.ops.aten.mm.default(view_519, permute_78); permute_78 = None + view_527 = torch.ops.aten.view.default(mm_50, [2, 8192, 128]); mm_50 = None + convert_element_type_241 = torch.ops.prims.convert_element_type.default(primals_70, torch.bfloat16) + all_gather_into_tensor_82 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_241, 32, '0'); convert_element_type_241 = None + wait_tensor_97 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_82); all_gather_into_tensor_82 = None + permute_79 = torch.ops.aten.permute.default(wait_tensor_97, [1, 0]); wait_tensor_97 = None + mm_51 = torch.ops.aten.mm.default(view_519, permute_79); view_519 = permute_79 = None + view_534 = torch.ops.aten.view.default(mm_51, [2, 8192, 128]) + view_536 = torch.ops.aten.view.default(view_520, [2, 8192, -1, 128]); view_520 = None + view_537 = torch.ops.aten.view.default(view_527, [2, 8192, -1, 128]); view_527 = 
None + view_538 = torch.ops.aten.view.default(view_534, [2, 8192, -1, 128]); view_534 = None + convert_element_type_244 = torch.ops.prims.convert_element_type.default(view_536, torch.float32); view_536 = None + view_539 = torch.ops.aten.view.default(convert_element_type_244, [2, 8192, 4, -1, 2]); convert_element_type_244 = None + view_as_complex_14 = torch.ops.aten.view_as_complex.default(view_539); view_539 = None + convert_element_type_245 = torch.ops.prims.convert_element_type.default(view_537, torch.float32); view_537 = None + view_540 = torch.ops.aten.view.default(convert_element_type_245, [2, 8192, 1, -1, 2]); convert_element_type_245 = None + view_as_complex_15 = torch.ops.aten.view_as_complex.default(view_540); view_540 = None + mul_58 = torch.ops.aten.mul.Tensor(view_as_complex_14, view_37); view_as_complex_14 = None + view_as_real_14 = torch.ops.aten.view_as_real.default(mul_58); mul_58 = None + view_542 = torch.ops.aten.view.default(view_as_real_14, [2, 8192, 4, 128]); view_as_real_14 = None + mul_59 = torch.ops.aten.mul.Tensor(view_as_complex_15, view_37); view_as_complex_15 = None + view_as_real_15 = torch.ops.aten.view_as_real.default(mul_59); mul_59 = None + view_543 = torch.ops.aten.view.default(view_as_real_15, [2, 8192, 1, 128]); view_as_real_15 = None + convert_element_type_246 = torch.ops.prims.convert_element_type.default(view_542, torch.bfloat16); view_542 = None + convert_element_type_247 = torch.ops.prims.convert_element_type.default(view_543, torch.bfloat16); view_543 = None + unsqueeze_14 = torch.ops.aten.unsqueeze.default(convert_element_type_247, 3); convert_element_type_247 = None + expand_14 = torch.ops.aten.expand.default(unsqueeze_14, [2, 8192, 1, 4, 128]); unsqueeze_14 = None + view_544 = torch.ops.aten.view.default(expand_14, [2, 8192, 4, 128]); expand_14 = None + unsqueeze_15 = torch.ops.aten.unsqueeze.default(view_538, 3); view_538 = None + expand_15 = torch.ops.aten.expand.default(unsqueeze_15, [2, 8192, 1, 4, 128]); 
unsqueeze_15 = None + view_545 = torch.ops.aten.view.default(expand_15, [2, 8192, 4, 128]); expand_15 = None + permute_80 = torch.ops.aten.permute.default(convert_element_type_246, [0, 2, 1, 3]); convert_element_type_246 = None + permute_81 = torch.ops.aten.permute.default(view_544, [0, 2, 1, 3]); view_544 = None + permute_82 = torch.ops.aten.permute.default(view_545, [0, 2, 1, 3]); view_545 = None + _scaled_dot_product_cudnn_attention_7 = torch.ops.aten._scaled_dot_product_cudnn_attention.default(permute_80, permute_81, permute_82, None, True, 0.0, True); permute_80 = permute_81 = permute_82 = None + getitem_367 = _scaled_dot_product_cudnn_attention_7[0] + getitem_368 = _scaled_dot_product_cudnn_attention_7[1] + getitem_373 = _scaled_dot_product_cudnn_attention_7[6] + getitem_374 = _scaled_dot_product_cudnn_attention_7[7]; _scaled_dot_product_cudnn_attention_7 = None + permute_83 = torch.ops.aten.permute.default(getitem_367, [0, 2, 1, 3]) + view_546 = torch.ops.aten.view.default(permute_83, [2, 8192, -1]); permute_83 = None + convert_element_type_248 = torch.ops.prims.convert_element_type.default(primals_71, torch.bfloat16) + all_gather_into_tensor_83 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_248, 32, '0'); convert_element_type_248 = None + wait_tensor_98 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_83); all_gather_into_tensor_83 = None + permute_84 = torch.ops.aten.permute.default(wait_tensor_98, [1, 0]); wait_tensor_98 = None + view_552 = torch.ops.aten.view.default(view_546, [16384, 512]); view_546 = None + mm_52 = torch.ops.aten.mm.default(view_552, permute_84); view_552 = permute_84 = None + view_553 = torch.ops.aten.view.default(mm_52, [2, 8192, 4096]); mm_52 = None + split_38 = torch.ops.aten.split.Tensor(view_553, 1024, 1); view_553 = None + getitem_376 = split_38[0] + getitem_377 = split_38[1] + getitem_378 = split_38[2] + getitem_379 = split_38[3] + getitem_380 = split_38[4] + getitem_381 
= split_38[5] + getitem_382 = split_38[6] + getitem_383 = split_38[7]; split_38 = None + cat_30 = torch.ops.aten.cat.default([getitem_376, getitem_377, getitem_378, getitem_379, getitem_380, getitem_381, getitem_382, getitem_383]); getitem_376 = getitem_377 = getitem_378 = getitem_379 = getitem_380 = getitem_381 = getitem_382 = getitem_383 = None + reduce_scatter_tensor_15 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_30, 'sum', 8, '1'); cat_30 = None + wait_tensor_99 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_15) + add_29 = torch.ops.aten.add.Tensor(add_27, wait_tensor_99); wait_tensor_99 = None + convert_element_type_251 = torch.ops.prims.convert_element_type.default(primals_72, torch.bfloat16) + all_gather_into_tensor_84 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_251, 32, '0'); convert_element_type_251 = None + wait_tensor_100 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_84); all_gather_into_tensor_84 = None + convert_element_type_252 = torch.ops.prims.convert_element_type.default(add_29, torch.float32) + pow_16 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_252, 2) + mean_15 = torch.ops.aten.mean.dim(pow_16, [2], True); pow_16 = None + add_30 = torch.ops.aten.add.Scalar(mean_15, 1e-05); mean_15 = None + rsqrt_15 = torch.ops.aten.rsqrt.default(add_30); add_30 = None + mul_60 = torch.ops.aten.mul.Tensor(convert_element_type_252, rsqrt_15); convert_element_type_252 = rsqrt_15 = None + mul_61 = torch.ops.aten.mul.Tensor(mul_60, wait_tensor_100); mul_60 = wait_tensor_100 = None + convert_element_type_253 = torch.ops.prims.convert_element_type.default(mul_61, torch.bfloat16); mul_61 = None + all_gather_into_tensor_85 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_253, 8, '1'); convert_element_type_253 = None + wait_tensor_101 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_85); 
all_gather_into_tensor_85 = None + split_39 = torch.ops.aten.split.Tensor(wait_tensor_101, 2); wait_tensor_101 = None + getitem_384 = split_39[0] + getitem_385 = split_39[1] + getitem_386 = split_39[2] + getitem_387 = split_39[3] + getitem_388 = split_39[4] + getitem_389 = split_39[5] + getitem_390 = split_39[6] + getitem_391 = split_39[7]; split_39 = None + cat_31 = torch.ops.aten.cat.default([getitem_384, getitem_385, getitem_386, getitem_387, getitem_388, getitem_389, getitem_390, getitem_391], 1); getitem_384 = getitem_385 = getitem_386 = getitem_387 = getitem_388 = getitem_389 = getitem_390 = getitem_391 = None + convert_element_type_254 = torch.ops.prims.convert_element_type.default(primals_73, torch.bfloat16) + all_gather_into_tensor_86 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_254, 32, '0'); convert_element_type_254 = None + wait_tensor_102 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_86); all_gather_into_tensor_86 = None + permute_85 = torch.ops.aten.permute.default(wait_tensor_102, [1, 0]); wait_tensor_102 = None + view_564 = torch.ops.aten.view.default(cat_31, [16384, 4096]); cat_31 = None + mm_53 = torch.ops.aten.mm.default(view_564, permute_85); permute_85 = None + view_565 = torch.ops.aten.view.default(mm_53, [2, 8192, 1792]) + convert_element_type_257 = torch.ops.prims.convert_element_type.default(view_565, torch.float32); view_565 = None + sigmoid_7 = torch.ops.aten.sigmoid.default(convert_element_type_257) + mul_62 = torch.ops.aten.mul.Tensor(convert_element_type_257, sigmoid_7); convert_element_type_257 = sigmoid_7 = None + convert_element_type_258 = torch.ops.prims.convert_element_type.default(mul_62, torch.bfloat16); mul_62 = None + convert_element_type_259 = torch.ops.prims.convert_element_type.default(primals_74, torch.bfloat16) + all_gather_into_tensor_87 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_259, 32, '0'); convert_element_type_259 = 
None + wait_tensor_103 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_87); all_gather_into_tensor_87 = None + permute_86 = torch.ops.aten.permute.default(wait_tensor_103, [1, 0]); wait_tensor_103 = None + mm_54 = torch.ops.aten.mm.default(view_564, permute_86); view_564 = permute_86 = None + view_572 = torch.ops.aten.view.default(mm_54, [2, 8192, 1792]); mm_54 = None + mul_63 = torch.ops.aten.mul.Tensor(convert_element_type_258, view_572); convert_element_type_258 = view_572 = None + convert_element_type_262 = torch.ops.prims.convert_element_type.default(primals_75, torch.bfloat16) + all_gather_into_tensor_88 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_262, 32, '0'); convert_element_type_262 = None + wait_tensor_104 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_88); all_gather_into_tensor_88 = None + permute_87 = torch.ops.aten.permute.default(wait_tensor_104, [1, 0]); wait_tensor_104 = None + view_579 = torch.ops.aten.view.default(mul_63, [16384, 1792]); mul_63 = None + mm_55 = torch.ops.aten.mm.default(view_579, permute_87); view_579 = permute_87 = None + view_580 = torch.ops.aten.view.default(mm_55, [2, 8192, 4096]); mm_55 = None + split_40 = torch.ops.aten.split.Tensor(view_580, 1024, 1); view_580 = None + getitem_392 = split_40[0] + getitem_393 = split_40[1] + getitem_394 = split_40[2] + getitem_395 = split_40[3] + getitem_396 = split_40[4] + getitem_397 = split_40[5] + getitem_398 = split_40[6] + getitem_399 = split_40[7]; split_40 = None + cat_32 = torch.ops.aten.cat.default([getitem_392, getitem_393, getitem_394, getitem_395, getitem_396, getitem_397, getitem_398, getitem_399]); getitem_392 = getitem_393 = getitem_394 = getitem_395 = getitem_396 = getitem_397 = getitem_398 = getitem_399 = None + reduce_scatter_tensor_16 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_32, 'sum', 8, '1'); cat_32 = None + wait_tensor_105 = 
torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_16); reduce_scatter_tensor_16 = None + add_31 = torch.ops.aten.add.Tensor(add_29, wait_tensor_105); add_29 = wait_tensor_105 = None + convert_element_type_265 = torch.ops.prims.convert_element_type.default(primals_76, torch.bfloat16) + all_gather_into_tensor_89 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_265, 32, '0'); convert_element_type_265 = None + wait_tensor_106 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_89); all_gather_into_tensor_89 = None + convert_element_type_266 = torch.ops.prims.convert_element_type.default(add_31, torch.float32) + pow_17 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_266, 2) + mean_16 = torch.ops.aten.mean.dim(pow_17, [2], True); pow_17 = None + add_32 = torch.ops.aten.add.Scalar(mean_16, 1e-05); mean_16 = None + rsqrt_16 = torch.ops.aten.rsqrt.default(add_32); add_32 = None + mul_64 = torch.ops.aten.mul.Tensor(convert_element_type_266, rsqrt_16); convert_element_type_266 = rsqrt_16 = None + mul_65 = torch.ops.aten.mul.Tensor(mul_64, wait_tensor_106); mul_64 = wait_tensor_106 = None + convert_element_type_267 = torch.ops.prims.convert_element_type.default(mul_65, torch.bfloat16); mul_65 = None + all_gather_into_tensor_90 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_267, 8, '1'); convert_element_type_267 = None + wait_tensor_107 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_90); all_gather_into_tensor_90 = None + split_41 = torch.ops.aten.split.Tensor(wait_tensor_107, 2); wait_tensor_107 = None + getitem_400 = split_41[0] + getitem_401 = split_41[1] + getitem_402 = split_41[2] + getitem_403 = split_41[3] + getitem_404 = split_41[4] + getitem_405 = split_41[5] + getitem_406 = split_41[6] + getitem_407 = split_41[7]; split_41 = None + cat_33 = torch.ops.aten.cat.default([getitem_400, getitem_401, getitem_402, getitem_403, getitem_404, 
getitem_405, getitem_406, getitem_407], 1); getitem_400 = getitem_401 = getitem_402 = getitem_403 = getitem_404 = getitem_405 = getitem_406 = getitem_407 = None + convert_element_type_268 = torch.ops.prims.convert_element_type.default(primals_77, torch.bfloat16) + all_gather_into_tensor_91 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_268, 32, '0'); convert_element_type_268 = None + wait_tensor_108 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_91); all_gather_into_tensor_91 = None + permute_88 = torch.ops.aten.permute.default(wait_tensor_108, [1, 0]); wait_tensor_108 = None + view_591 = torch.ops.aten.view.default(cat_33, [16384, 4096]); cat_33 = None + mm_56 = torch.ops.aten.mm.default(view_591, permute_88); permute_88 = None + view_592 = torch.ops.aten.view.default(mm_56, [2, 8192, 512]) + convert_element_type_271 = torch.ops.prims.convert_element_type.default(primals_78, torch.bfloat16) + all_gather_into_tensor_92 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_271, 32, '0'); convert_element_type_271 = None + wait_tensor_109 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_92); all_gather_into_tensor_92 = None + permute_89 = torch.ops.aten.permute.default(wait_tensor_109, [1, 0]); wait_tensor_109 = None + mm_57 = torch.ops.aten.mm.default(view_591, permute_89); permute_89 = None + view_599 = torch.ops.aten.view.default(mm_57, [2, 8192, 128]); mm_57 = None + convert_element_type_274 = torch.ops.prims.convert_element_type.default(primals_79, torch.bfloat16) + all_gather_into_tensor_93 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_274, 32, '0'); convert_element_type_274 = None + wait_tensor_110 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_93); all_gather_into_tensor_93 = None + permute_90 = torch.ops.aten.permute.default(wait_tensor_110, [1, 0]); wait_tensor_110 = None + mm_58 = 
torch.ops.aten.mm.default(view_591, permute_90); view_591 = permute_90 = None + view_606 = torch.ops.aten.view.default(mm_58, [2, 8192, 128]) + view_608 = torch.ops.aten.view.default(view_592, [2, 8192, -1, 128]); view_592 = None + view_609 = torch.ops.aten.view.default(view_599, [2, 8192, -1, 128]); view_599 = None + view_610 = torch.ops.aten.view.default(view_606, [2, 8192, -1, 128]); view_606 = None + convert_element_type_277 = torch.ops.prims.convert_element_type.default(view_608, torch.float32); view_608 = None + view_611 = torch.ops.aten.view.default(convert_element_type_277, [2, 8192, 4, -1, 2]); convert_element_type_277 = None + view_as_complex_16 = torch.ops.aten.view_as_complex.default(view_611); view_611 = None + convert_element_type_278 = torch.ops.prims.convert_element_type.default(view_609, torch.float32); view_609 = None + view_612 = torch.ops.aten.view.default(convert_element_type_278, [2, 8192, 1, -1, 2]); convert_element_type_278 = None + view_as_complex_17 = torch.ops.aten.view_as_complex.default(view_612); view_612 = None + mul_66 = torch.ops.aten.mul.Tensor(view_as_complex_16, view_37); view_as_complex_16 = None + view_as_real_16 = torch.ops.aten.view_as_real.default(mul_66); mul_66 = None + view_614 = torch.ops.aten.view.default(view_as_real_16, [2, 8192, 4, 128]); view_as_real_16 = None + mul_67 = torch.ops.aten.mul.Tensor(view_as_complex_17, view_37); view_as_complex_17 = None + view_as_real_17 = torch.ops.aten.view_as_real.default(mul_67); mul_67 = None + view_615 = torch.ops.aten.view.default(view_as_real_17, [2, 8192, 1, 128]); view_as_real_17 = None + convert_element_type_279 = torch.ops.prims.convert_element_type.default(view_614, torch.bfloat16); view_614 = None + convert_element_type_280 = torch.ops.prims.convert_element_type.default(view_615, torch.bfloat16); view_615 = None + unsqueeze_16 = torch.ops.aten.unsqueeze.default(convert_element_type_280, 3); convert_element_type_280 = None + expand_16 = 
torch.ops.aten.expand.default(unsqueeze_16, [2, 8192, 1, 4, 128]); unsqueeze_16 = None + view_616 = torch.ops.aten.view.default(expand_16, [2, 8192, 4, 128]); expand_16 = None + unsqueeze_17 = torch.ops.aten.unsqueeze.default(view_610, 3); view_610 = None + expand_17 = torch.ops.aten.expand.default(unsqueeze_17, [2, 8192, 1, 4, 128]); unsqueeze_17 = None + view_617 = torch.ops.aten.view.default(expand_17, [2, 8192, 4, 128]); expand_17 = None + permute_91 = torch.ops.aten.permute.default(convert_element_type_279, [0, 2, 1, 3]); convert_element_type_279 = None + permute_92 = torch.ops.aten.permute.default(view_616, [0, 2, 1, 3]); view_616 = None + permute_93 = torch.ops.aten.permute.default(view_617, [0, 2, 1, 3]); view_617 = None + _scaled_dot_product_cudnn_attention_8 = torch.ops.aten._scaled_dot_product_cudnn_attention.default(permute_91, permute_92, permute_93, None, True, 0.0, True); permute_91 = permute_92 = permute_93 = None + getitem_408 = _scaled_dot_product_cudnn_attention_8[0] + getitem_409 = _scaled_dot_product_cudnn_attention_8[1] + getitem_414 = _scaled_dot_product_cudnn_attention_8[6] + getitem_415 = _scaled_dot_product_cudnn_attention_8[7]; _scaled_dot_product_cudnn_attention_8 = None + permute_94 = torch.ops.aten.permute.default(getitem_408, [0, 2, 1, 3]) + view_618 = torch.ops.aten.view.default(permute_94, [2, 8192, -1]); permute_94 = None + convert_element_type_281 = torch.ops.prims.convert_element_type.default(primals_80, torch.bfloat16) + all_gather_into_tensor_94 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_281, 32, '0'); convert_element_type_281 = None + wait_tensor_111 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_94); all_gather_into_tensor_94 = None + permute_95 = torch.ops.aten.permute.default(wait_tensor_111, [1, 0]); wait_tensor_111 = None + view_624 = torch.ops.aten.view.default(view_618, [16384, 512]); view_618 = None + mm_59 = torch.ops.aten.mm.default(view_624, permute_95); 
view_624 = permute_95 = None + view_625 = torch.ops.aten.view.default(mm_59, [2, 8192, 4096]); mm_59 = None + split_42 = torch.ops.aten.split.Tensor(view_625, 1024, 1); view_625 = None + getitem_417 = split_42[0] + getitem_418 = split_42[1] + getitem_419 = split_42[2] + getitem_420 = split_42[3] + getitem_421 = split_42[4] + getitem_422 = split_42[5] + getitem_423 = split_42[6] + getitem_424 = split_42[7]; split_42 = None + cat_34 = torch.ops.aten.cat.default([getitem_417, getitem_418, getitem_419, getitem_420, getitem_421, getitem_422, getitem_423, getitem_424]); getitem_417 = getitem_418 = getitem_419 = getitem_420 = getitem_421 = getitem_422 = getitem_423 = getitem_424 = None + reduce_scatter_tensor_17 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_34, 'sum', 8, '1'); cat_34 = None + wait_tensor_112 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_17) + add_33 = torch.ops.aten.add.Tensor(add_31, wait_tensor_112); wait_tensor_112 = None + convert_element_type_284 = torch.ops.prims.convert_element_type.default(primals_81, torch.bfloat16) + all_gather_into_tensor_95 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_284, 32, '0'); convert_element_type_284 = None + wait_tensor_113 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_95); all_gather_into_tensor_95 = None + convert_element_type_285 = torch.ops.prims.convert_element_type.default(add_33, torch.float32) + pow_18 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_285, 2) + mean_17 = torch.ops.aten.mean.dim(pow_18, [2], True); pow_18 = None + add_34 = torch.ops.aten.add.Scalar(mean_17, 1e-05); mean_17 = None + rsqrt_17 = torch.ops.aten.rsqrt.default(add_34); add_34 = None + mul_68 = torch.ops.aten.mul.Tensor(convert_element_type_285, rsqrt_17); convert_element_type_285 = rsqrt_17 = None + mul_69 = torch.ops.aten.mul.Tensor(mul_68, wait_tensor_113); mul_68 = wait_tensor_113 = None + convert_element_type_286 = 
torch.ops.prims.convert_element_type.default(mul_69, torch.bfloat16); mul_69 = None + all_gather_into_tensor_96 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_286, 8, '1'); convert_element_type_286 = None + wait_tensor_114 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_96); all_gather_into_tensor_96 = None + split_43 = torch.ops.aten.split.Tensor(wait_tensor_114, 2); wait_tensor_114 = None + getitem_425 = split_43[0] + getitem_426 = split_43[1] + getitem_427 = split_43[2] + getitem_428 = split_43[3] + getitem_429 = split_43[4] + getitem_430 = split_43[5] + getitem_431 = split_43[6] + getitem_432 = split_43[7]; split_43 = None + cat_35 = torch.ops.aten.cat.default([getitem_425, getitem_426, getitem_427, getitem_428, getitem_429, getitem_430, getitem_431, getitem_432], 1); getitem_425 = getitem_426 = getitem_427 = getitem_428 = getitem_429 = getitem_430 = getitem_431 = getitem_432 = None + convert_element_type_287 = torch.ops.prims.convert_element_type.default(primals_82, torch.bfloat16) + all_gather_into_tensor_97 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_287, 32, '0'); convert_element_type_287 = None + wait_tensor_115 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_97); all_gather_into_tensor_97 = None + permute_96 = torch.ops.aten.permute.default(wait_tensor_115, [1, 0]); wait_tensor_115 = None + view_636 = torch.ops.aten.view.default(cat_35, [16384, 4096]); cat_35 = None + mm_60 = torch.ops.aten.mm.default(view_636, permute_96); permute_96 = None + view_637 = torch.ops.aten.view.default(mm_60, [2, 8192, 1792]) + convert_element_type_290 = torch.ops.prims.convert_element_type.default(view_637, torch.float32); view_637 = None + sigmoid_8 = torch.ops.aten.sigmoid.default(convert_element_type_290) + mul_70 = torch.ops.aten.mul.Tensor(convert_element_type_290, sigmoid_8); convert_element_type_290 = sigmoid_8 = None + convert_element_type_291 = 
torch.ops.prims.convert_element_type.default(mul_70, torch.bfloat16); mul_70 = None + convert_element_type_292 = torch.ops.prims.convert_element_type.default(primals_83, torch.bfloat16) + all_gather_into_tensor_98 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_292, 32, '0'); convert_element_type_292 = None + wait_tensor_116 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_98); all_gather_into_tensor_98 = None + permute_97 = torch.ops.aten.permute.default(wait_tensor_116, [1, 0]); wait_tensor_116 = None + mm_61 = torch.ops.aten.mm.default(view_636, permute_97); view_636 = permute_97 = None + view_644 = torch.ops.aten.view.default(mm_61, [2, 8192, 1792]); mm_61 = None + mul_71 = torch.ops.aten.mul.Tensor(convert_element_type_291, view_644); convert_element_type_291 = view_644 = None + convert_element_type_295 = torch.ops.prims.convert_element_type.default(primals_84, torch.bfloat16) + all_gather_into_tensor_99 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_295, 32, '0'); convert_element_type_295 = None + wait_tensor_117 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_99); all_gather_into_tensor_99 = None + permute_98 = torch.ops.aten.permute.default(wait_tensor_117, [1, 0]); wait_tensor_117 = None + view_651 = torch.ops.aten.view.default(mul_71, [16384, 1792]); mul_71 = None + mm_62 = torch.ops.aten.mm.default(view_651, permute_98); view_651 = permute_98 = None + view_652 = torch.ops.aten.view.default(mm_62, [2, 8192, 4096]); mm_62 = None + split_44 = torch.ops.aten.split.Tensor(view_652, 1024, 1); view_652 = None + getitem_433 = split_44[0] + getitem_434 = split_44[1] + getitem_435 = split_44[2] + getitem_436 = split_44[3] + getitem_437 = split_44[4] + getitem_438 = split_44[5] + getitem_439 = split_44[6] + getitem_440 = split_44[7]; split_44 = None + cat_36 = torch.ops.aten.cat.default([getitem_433, getitem_434, getitem_435, getitem_436, getitem_437, 
getitem_438, getitem_439, getitem_440]); getitem_433 = getitem_434 = getitem_435 = getitem_436 = getitem_437 = getitem_438 = getitem_439 = getitem_440 = None + reduce_scatter_tensor_18 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_36, 'sum', 8, '1'); cat_36 = None + wait_tensor_118 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_18); reduce_scatter_tensor_18 = None + add_35 = torch.ops.aten.add.Tensor(add_33, wait_tensor_118); add_33 = wait_tensor_118 = None + convert_element_type_298 = torch.ops.prims.convert_element_type.default(primals_85, torch.bfloat16) + all_gather_into_tensor_100 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_298, 32, '0'); convert_element_type_298 = None + wait_tensor_119 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_100); all_gather_into_tensor_100 = None + convert_element_type_299 = torch.ops.prims.convert_element_type.default(add_35, torch.float32) + pow_19 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_299, 2) + mean_18 = torch.ops.aten.mean.dim(pow_19, [2], True); pow_19 = None + add_36 = torch.ops.aten.add.Scalar(mean_18, 1e-05); mean_18 = None + rsqrt_18 = torch.ops.aten.rsqrt.default(add_36); add_36 = None + mul_72 = torch.ops.aten.mul.Tensor(convert_element_type_299, rsqrt_18); convert_element_type_299 = rsqrt_18 = None + mul_73 = torch.ops.aten.mul.Tensor(mul_72, wait_tensor_119); mul_72 = wait_tensor_119 = None + convert_element_type_300 = torch.ops.prims.convert_element_type.default(mul_73, torch.bfloat16); mul_73 = None + all_gather_into_tensor_101 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_300, 8, '1'); convert_element_type_300 = None + wait_tensor_120 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_101); all_gather_into_tensor_101 = None + split_45 = torch.ops.aten.split.Tensor(wait_tensor_120, 2); wait_tensor_120 = None + getitem_441 = split_45[0] + 
getitem_442 = split_45[1] + getitem_443 = split_45[2] + getitem_444 = split_45[3] + getitem_445 = split_45[4] + getitem_446 = split_45[5] + getitem_447 = split_45[6] + getitem_448 = split_45[7]; split_45 = None + cat_37 = torch.ops.aten.cat.default([getitem_441, getitem_442, getitem_443, getitem_444, getitem_445, getitem_446, getitem_447, getitem_448], 1); getitem_441 = getitem_442 = getitem_443 = getitem_444 = getitem_445 = getitem_446 = getitem_447 = getitem_448 = None + convert_element_type_301 = torch.ops.prims.convert_element_type.default(primals_86, torch.bfloat16) + all_gather_into_tensor_102 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_301, 32, '0'); convert_element_type_301 = None + wait_tensor_121 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_102); all_gather_into_tensor_102 = None + permute_99 = torch.ops.aten.permute.default(wait_tensor_121, [1, 0]); wait_tensor_121 = None + view_663 = torch.ops.aten.view.default(cat_37, [16384, 4096]); cat_37 = None + mm_63 = torch.ops.aten.mm.default(view_663, permute_99); permute_99 = None + view_664 = torch.ops.aten.view.default(mm_63, [2, 8192, 512]) + convert_element_type_304 = torch.ops.prims.convert_element_type.default(primals_87, torch.bfloat16) + all_gather_into_tensor_103 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_304, 32, '0'); convert_element_type_304 = None + wait_tensor_122 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_103); all_gather_into_tensor_103 = None + permute_100 = torch.ops.aten.permute.default(wait_tensor_122, [1, 0]); wait_tensor_122 = None + mm_64 = torch.ops.aten.mm.default(view_663, permute_100); permute_100 = None + view_671 = torch.ops.aten.view.default(mm_64, [2, 8192, 128]); mm_64 = None + convert_element_type_307 = torch.ops.prims.convert_element_type.default(primals_88, torch.bfloat16) + all_gather_into_tensor_104 = 
torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_307, 32, '0'); convert_element_type_307 = None + wait_tensor_123 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_104); all_gather_into_tensor_104 = None + permute_101 = torch.ops.aten.permute.default(wait_tensor_123, [1, 0]); wait_tensor_123 = None + mm_65 = torch.ops.aten.mm.default(view_663, permute_101); view_663 = permute_101 = None + view_678 = torch.ops.aten.view.default(mm_65, [2, 8192, 128]) + view_680 = torch.ops.aten.view.default(view_664, [2, 8192, -1, 128]); view_664 = None + view_681 = torch.ops.aten.view.default(view_671, [2, 8192, -1, 128]); view_671 = None + view_682 = torch.ops.aten.view.default(view_678, [2, 8192, -1, 128]); view_678 = None + convert_element_type_310 = torch.ops.prims.convert_element_type.default(view_680, torch.float32); view_680 = None + view_683 = torch.ops.aten.view.default(convert_element_type_310, [2, 8192, 4, -1, 2]); convert_element_type_310 = None + view_as_complex_18 = torch.ops.aten.view_as_complex.default(view_683); view_683 = None + convert_element_type_311 = torch.ops.prims.convert_element_type.default(view_681, torch.float32); view_681 = None + view_684 = torch.ops.aten.view.default(convert_element_type_311, [2, 8192, 1, -1, 2]); convert_element_type_311 = None + view_as_complex_19 = torch.ops.aten.view_as_complex.default(view_684); view_684 = None + mul_74 = torch.ops.aten.mul.Tensor(view_as_complex_18, view_37); view_as_complex_18 = None + view_as_real_18 = torch.ops.aten.view_as_real.default(mul_74); mul_74 = None + view_686 = torch.ops.aten.view.default(view_as_real_18, [2, 8192, 4, 128]); view_as_real_18 = None + mul_75 = torch.ops.aten.mul.Tensor(view_as_complex_19, view_37); view_as_complex_19 = None + view_as_real_19 = torch.ops.aten.view_as_real.default(mul_75); mul_75 = None + view_687 = torch.ops.aten.view.default(view_as_real_19, [2, 8192, 1, 128]); view_as_real_19 = None + convert_element_type_312 = 
torch.ops.prims.convert_element_type.default(view_686, torch.bfloat16); view_686 = None + convert_element_type_313 = torch.ops.prims.convert_element_type.default(view_687, torch.bfloat16); view_687 = None + unsqueeze_18 = torch.ops.aten.unsqueeze.default(convert_element_type_313, 3); convert_element_type_313 = None + expand_18 = torch.ops.aten.expand.default(unsqueeze_18, [2, 8192, 1, 4, 128]); unsqueeze_18 = None + view_688 = torch.ops.aten.view.default(expand_18, [2, 8192, 4, 128]); expand_18 = None + unsqueeze_19 = torch.ops.aten.unsqueeze.default(view_682, 3); view_682 = None + expand_19 = torch.ops.aten.expand.default(unsqueeze_19, [2, 8192, 1, 4, 128]); unsqueeze_19 = None + view_689 = torch.ops.aten.view.default(expand_19, [2, 8192, 4, 128]); expand_19 = None + permute_102 = torch.ops.aten.permute.default(convert_element_type_312, [0, 2, 1, 3]); convert_element_type_312 = None + permute_103 = torch.ops.aten.permute.default(view_688, [0, 2, 1, 3]); view_688 = None + permute_104 = torch.ops.aten.permute.default(view_689, [0, 2, 1, 3]); view_689 = None + _scaled_dot_product_cudnn_attention_9 = torch.ops.aten._scaled_dot_product_cudnn_attention.default(permute_102, permute_103, permute_104, None, True, 0.0, True); permute_102 = permute_103 = permute_104 = None + getitem_449 = _scaled_dot_product_cudnn_attention_9[0] + getitem_450 = _scaled_dot_product_cudnn_attention_9[1] + getitem_455 = _scaled_dot_product_cudnn_attention_9[6] + getitem_456 = _scaled_dot_product_cudnn_attention_9[7]; _scaled_dot_product_cudnn_attention_9 = None + permute_105 = torch.ops.aten.permute.default(getitem_449, [0, 2, 1, 3]) + view_690 = torch.ops.aten.view.default(permute_105, [2, 8192, -1]); permute_105 = None + convert_element_type_314 = torch.ops.prims.convert_element_type.default(primals_89, torch.bfloat16) + all_gather_into_tensor_105 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_314, 32, '0'); convert_element_type_314 = None + wait_tensor_124 = 
torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_105); all_gather_into_tensor_105 = None + permute_106 = torch.ops.aten.permute.default(wait_tensor_124, [1, 0]); wait_tensor_124 = None + view_696 = torch.ops.aten.view.default(view_690, [16384, 512]); view_690 = None + mm_66 = torch.ops.aten.mm.default(view_696, permute_106); view_696 = permute_106 = None + view_697 = torch.ops.aten.view.default(mm_66, [2, 8192, 4096]); mm_66 = None + split_46 = torch.ops.aten.split.Tensor(view_697, 1024, 1); view_697 = None + getitem_458 = split_46[0] + getitem_459 = split_46[1] + getitem_460 = split_46[2] + getitem_461 = split_46[3] + getitem_462 = split_46[4] + getitem_463 = split_46[5] + getitem_464 = split_46[6] + getitem_465 = split_46[7]; split_46 = None + cat_38 = torch.ops.aten.cat.default([getitem_458, getitem_459, getitem_460, getitem_461, getitem_462, getitem_463, getitem_464, getitem_465]); getitem_458 = getitem_459 = getitem_460 = getitem_461 = getitem_462 = getitem_463 = getitem_464 = getitem_465 = None + reduce_scatter_tensor_19 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_38, 'sum', 8, '1'); cat_38 = None + wait_tensor_125 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_19) + add_37 = torch.ops.aten.add.Tensor(add_35, wait_tensor_125); wait_tensor_125 = None + convert_element_type_317 = torch.ops.prims.convert_element_type.default(primals_90, torch.bfloat16) + all_gather_into_tensor_106 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_317, 32, '0'); convert_element_type_317 = None + wait_tensor_126 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_106); all_gather_into_tensor_106 = None + convert_element_type_318 = torch.ops.prims.convert_element_type.default(add_37, torch.float32) + pow_20 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_318, 2) + mean_19 = torch.ops.aten.mean.dim(pow_20, [2], True); pow_20 = None + add_38 = 
torch.ops.aten.add.Scalar(mean_19, 1e-05); mean_19 = None + rsqrt_19 = torch.ops.aten.rsqrt.default(add_38); add_38 = None + mul_76 = torch.ops.aten.mul.Tensor(convert_element_type_318, rsqrt_19); convert_element_type_318 = rsqrt_19 = None + mul_77 = torch.ops.aten.mul.Tensor(mul_76, wait_tensor_126); mul_76 = wait_tensor_126 = None + convert_element_type_319 = torch.ops.prims.convert_element_type.default(mul_77, torch.bfloat16); mul_77 = None + all_gather_into_tensor_107 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_319, 8, '1'); convert_element_type_319 = None + wait_tensor_127 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_107); all_gather_into_tensor_107 = None + split_47 = torch.ops.aten.split.Tensor(wait_tensor_127, 2); wait_tensor_127 = None + getitem_466 = split_47[0] + getitem_467 = split_47[1] + getitem_468 = split_47[2] + getitem_469 = split_47[3] + getitem_470 = split_47[4] + getitem_471 = split_47[5] + getitem_472 = split_47[6] + getitem_473 = split_47[7]; split_47 = None + cat_39 = torch.ops.aten.cat.default([getitem_466, getitem_467, getitem_468, getitem_469, getitem_470, getitem_471, getitem_472, getitem_473], 1); getitem_466 = getitem_467 = getitem_468 = getitem_469 = getitem_470 = getitem_471 = getitem_472 = getitem_473 = None + convert_element_type_320 = torch.ops.prims.convert_element_type.default(primals_91, torch.bfloat16) + all_gather_into_tensor_108 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_320, 32, '0'); convert_element_type_320 = None + wait_tensor_128 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_108); all_gather_into_tensor_108 = None + permute_107 = torch.ops.aten.permute.default(wait_tensor_128, [1, 0]); wait_tensor_128 = None + view_708 = torch.ops.aten.view.default(cat_39, [16384, 4096]); cat_39 = None + mm_67 = torch.ops.aten.mm.default(view_708, permute_107); permute_107 = None + view_709 = 
torch.ops.aten.view.default(mm_67, [2, 8192, 1792]) + convert_element_type_323 = torch.ops.prims.convert_element_type.default(view_709, torch.float32); view_709 = None + sigmoid_9 = torch.ops.aten.sigmoid.default(convert_element_type_323) + mul_78 = torch.ops.aten.mul.Tensor(convert_element_type_323, sigmoid_9); convert_element_type_323 = sigmoid_9 = None + convert_element_type_324 = torch.ops.prims.convert_element_type.default(mul_78, torch.bfloat16); mul_78 = None + convert_element_type_325 = torch.ops.prims.convert_element_type.default(primals_92, torch.bfloat16) + all_gather_into_tensor_109 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_325, 32, '0'); convert_element_type_325 = None + wait_tensor_129 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_109); all_gather_into_tensor_109 = None + permute_108 = torch.ops.aten.permute.default(wait_tensor_129, [1, 0]); wait_tensor_129 = None + mm_68 = torch.ops.aten.mm.default(view_708, permute_108); view_708 = permute_108 = None + view_716 = torch.ops.aten.view.default(mm_68, [2, 8192, 1792]); mm_68 = None + mul_79 = torch.ops.aten.mul.Tensor(convert_element_type_324, view_716); convert_element_type_324 = view_716 = None + convert_element_type_328 = torch.ops.prims.convert_element_type.default(primals_93, torch.bfloat16) + all_gather_into_tensor_110 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_328, 32, '0'); convert_element_type_328 = None + wait_tensor_130 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_110); all_gather_into_tensor_110 = None + permute_109 = torch.ops.aten.permute.default(wait_tensor_130, [1, 0]); wait_tensor_130 = None + view_723 = torch.ops.aten.view.default(mul_79, [16384, 1792]); mul_79 = None + mm_69 = torch.ops.aten.mm.default(view_723, permute_109); view_723 = permute_109 = None + view_724 = torch.ops.aten.view.default(mm_69, [2, 8192, 4096]); mm_69 = None + split_48 = 
torch.ops.aten.split.Tensor(view_724, 1024, 1); view_724 = None + getitem_474 = split_48[0] + getitem_475 = split_48[1] + getitem_476 = split_48[2] + getitem_477 = split_48[3] + getitem_478 = split_48[4] + getitem_479 = split_48[5] + getitem_480 = split_48[6] + getitem_481 = split_48[7]; split_48 = None + cat_40 = torch.ops.aten.cat.default([getitem_474, getitem_475, getitem_476, getitem_477, getitem_478, getitem_479, getitem_480, getitem_481]); getitem_474 = getitem_475 = getitem_476 = getitem_477 = getitem_478 = getitem_479 = getitem_480 = getitem_481 = None + reduce_scatter_tensor_20 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_40, 'sum', 8, '1'); cat_40 = None + wait_tensor_131 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_20); reduce_scatter_tensor_20 = None + add_39 = torch.ops.aten.add.Tensor(add_37, wait_tensor_131); add_37 = wait_tensor_131 = None + convert_element_type_331 = torch.ops.prims.convert_element_type.default(primals_94, torch.bfloat16) + all_gather_into_tensor_111 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_331, 32, '0'); convert_element_type_331 = None + wait_tensor_132 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_111); all_gather_into_tensor_111 = None + convert_element_type_332 = torch.ops.prims.convert_element_type.default(add_39, torch.float32) + pow_21 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_332, 2) + mean_20 = torch.ops.aten.mean.dim(pow_21, [2], True); pow_21 = None + add_40 = torch.ops.aten.add.Scalar(mean_20, 1e-05); mean_20 = None + rsqrt_20 = torch.ops.aten.rsqrt.default(add_40); add_40 = None + mul_80 = torch.ops.aten.mul.Tensor(convert_element_type_332, rsqrt_20); convert_element_type_332 = rsqrt_20 = None + mul_81 = torch.ops.aten.mul.Tensor(mul_80, wait_tensor_132); mul_80 = wait_tensor_132 = None + convert_element_type_333 = torch.ops.prims.convert_element_type.default(mul_81, torch.bfloat16); mul_81 = None 
+ all_gather_into_tensor_112 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_333, 8, '1'); convert_element_type_333 = None + wait_tensor_133 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_112); all_gather_into_tensor_112 = None + split_49 = torch.ops.aten.split.Tensor(wait_tensor_133, 2); wait_tensor_133 = None + getitem_482 = split_49[0] + getitem_483 = split_49[1] + getitem_484 = split_49[2] + getitem_485 = split_49[3] + getitem_486 = split_49[4] + getitem_487 = split_49[5] + getitem_488 = split_49[6] + getitem_489 = split_49[7]; split_49 = None + cat_41 = torch.ops.aten.cat.default([getitem_482, getitem_483, getitem_484, getitem_485, getitem_486, getitem_487, getitem_488, getitem_489], 1); getitem_482 = getitem_483 = getitem_484 = getitem_485 = getitem_486 = getitem_487 = getitem_488 = getitem_489 = None + convert_element_type_334 = torch.ops.prims.convert_element_type.default(primals_95, torch.bfloat16) + all_gather_into_tensor_113 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_334, 32, '0'); convert_element_type_334 = None + wait_tensor_134 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_113); all_gather_into_tensor_113 = None + permute_110 = torch.ops.aten.permute.default(wait_tensor_134, [1, 0]); wait_tensor_134 = None + view_735 = torch.ops.aten.view.default(cat_41, [16384, 4096]); cat_41 = None + mm_70 = torch.ops.aten.mm.default(view_735, permute_110); permute_110 = None + view_736 = torch.ops.aten.view.default(mm_70, [2, 8192, 512]) + convert_element_type_337 = torch.ops.prims.convert_element_type.default(primals_96, torch.bfloat16) + all_gather_into_tensor_114 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_337, 32, '0'); convert_element_type_337 = None + wait_tensor_135 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_114); all_gather_into_tensor_114 = None + permute_111 = 
torch.ops.aten.permute.default(wait_tensor_135, [1, 0]); wait_tensor_135 = None + mm_71 = torch.ops.aten.mm.default(view_735, permute_111); permute_111 = None + view_743 = torch.ops.aten.view.default(mm_71, [2, 8192, 128]); mm_71 = None + convert_element_type_340 = torch.ops.prims.convert_element_type.default(primals_97, torch.bfloat16) + all_gather_into_tensor_115 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_340, 32, '0'); convert_element_type_340 = None + wait_tensor_136 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_115); all_gather_into_tensor_115 = None + permute_112 = torch.ops.aten.permute.default(wait_tensor_136, [1, 0]); wait_tensor_136 = None + mm_72 = torch.ops.aten.mm.default(view_735, permute_112); view_735 = permute_112 = None + view_750 = torch.ops.aten.view.default(mm_72, [2, 8192, 128]) + view_752 = torch.ops.aten.view.default(view_736, [2, 8192, -1, 128]); view_736 = None + view_753 = torch.ops.aten.view.default(view_743, [2, 8192, -1, 128]); view_743 = None + view_754 = torch.ops.aten.view.default(view_750, [2, 8192, -1, 128]); view_750 = None + convert_element_type_343 = torch.ops.prims.convert_element_type.default(view_752, torch.float32); view_752 = None + view_755 = torch.ops.aten.view.default(convert_element_type_343, [2, 8192, 4, -1, 2]); convert_element_type_343 = None + view_as_complex_20 = torch.ops.aten.view_as_complex.default(view_755); view_755 = None + convert_element_type_344 = torch.ops.prims.convert_element_type.default(view_753, torch.float32); view_753 = None + view_756 = torch.ops.aten.view.default(convert_element_type_344, [2, 8192, 1, -1, 2]); convert_element_type_344 = None + view_as_complex_21 = torch.ops.aten.view_as_complex.default(view_756); view_756 = None + mul_82 = torch.ops.aten.mul.Tensor(view_as_complex_20, view_37); view_as_complex_20 = None + view_as_real_20 = torch.ops.aten.view_as_real.default(mul_82); mul_82 = None + view_758 = 
torch.ops.aten.view.default(view_as_real_20, [2, 8192, 4, 128]); view_as_real_20 = None + mul_83 = torch.ops.aten.mul.Tensor(view_as_complex_21, view_37); view_as_complex_21 = None + view_as_real_21 = torch.ops.aten.view_as_real.default(mul_83); mul_83 = None + view_759 = torch.ops.aten.view.default(view_as_real_21, [2, 8192, 1, 128]); view_as_real_21 = None + convert_element_type_345 = torch.ops.prims.convert_element_type.default(view_758, torch.bfloat16); view_758 = None + convert_element_type_346 = torch.ops.prims.convert_element_type.default(view_759, torch.bfloat16); view_759 = None + unsqueeze_20 = torch.ops.aten.unsqueeze.default(convert_element_type_346, 3); convert_element_type_346 = None + expand_20 = torch.ops.aten.expand.default(unsqueeze_20, [2, 8192, 1, 4, 128]); unsqueeze_20 = None + view_760 = torch.ops.aten.view.default(expand_20, [2, 8192, 4, 128]); expand_20 = None + unsqueeze_21 = torch.ops.aten.unsqueeze.default(view_754, 3); view_754 = None + expand_21 = torch.ops.aten.expand.default(unsqueeze_21, [2, 8192, 1, 4, 128]); unsqueeze_21 = None + view_761 = torch.ops.aten.view.default(expand_21, [2, 8192, 4, 128]); expand_21 = None + permute_113 = torch.ops.aten.permute.default(convert_element_type_345, [0, 2, 1, 3]); convert_element_type_345 = None + permute_114 = torch.ops.aten.permute.default(view_760, [0, 2, 1, 3]); view_760 = None + permute_115 = torch.ops.aten.permute.default(view_761, [0, 2, 1, 3]); view_761 = None + _scaled_dot_product_cudnn_attention_10 = torch.ops.aten._scaled_dot_product_cudnn_attention.default(permute_113, permute_114, permute_115, None, True, 0.0, True); permute_113 = permute_114 = permute_115 = None + getitem_490 = _scaled_dot_product_cudnn_attention_10[0] + getitem_491 = _scaled_dot_product_cudnn_attention_10[1] + getitem_496 = _scaled_dot_product_cudnn_attention_10[6] + getitem_497 = _scaled_dot_product_cudnn_attention_10[7]; _scaled_dot_product_cudnn_attention_10 = None + permute_116 = 
torch.ops.aten.permute.default(getitem_490, [0, 2, 1, 3]) + view_762 = torch.ops.aten.view.default(permute_116, [2, 8192, -1]); permute_116 = None + convert_element_type_347 = torch.ops.prims.convert_element_type.default(primals_98, torch.bfloat16) + all_gather_into_tensor_116 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_347, 32, '0'); convert_element_type_347 = None + wait_tensor_137 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_116); all_gather_into_tensor_116 = None + permute_117 = torch.ops.aten.permute.default(wait_tensor_137, [1, 0]); wait_tensor_137 = None + view_768 = torch.ops.aten.view.default(view_762, [16384, 512]); view_762 = None + mm_73 = torch.ops.aten.mm.default(view_768, permute_117); view_768 = permute_117 = None + view_769 = torch.ops.aten.view.default(mm_73, [2, 8192, 4096]); mm_73 = None + split_50 = torch.ops.aten.split.Tensor(view_769, 1024, 1); view_769 = None + getitem_499 = split_50[0] + getitem_500 = split_50[1] + getitem_501 = split_50[2] + getitem_502 = split_50[3] + getitem_503 = split_50[4] + getitem_504 = split_50[5] + getitem_505 = split_50[6] + getitem_506 = split_50[7]; split_50 = None + cat_42 = torch.ops.aten.cat.default([getitem_499, getitem_500, getitem_501, getitem_502, getitem_503, getitem_504, getitem_505, getitem_506]); getitem_499 = getitem_500 = getitem_501 = getitem_502 = getitem_503 = getitem_504 = getitem_505 = getitem_506 = None + reduce_scatter_tensor_21 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_42, 'sum', 8, '1'); cat_42 = None + wait_tensor_138 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_21) + add_41 = torch.ops.aten.add.Tensor(add_39, wait_tensor_138); wait_tensor_138 = None + convert_element_type_350 = torch.ops.prims.convert_element_type.default(primals_99, torch.bfloat16) + all_gather_into_tensor_117 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_350, 32, '0'); 
convert_element_type_350 = None + wait_tensor_139 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_117); all_gather_into_tensor_117 = None + convert_element_type_351 = torch.ops.prims.convert_element_type.default(add_41, torch.float32) + pow_22 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_351, 2) + mean_21 = torch.ops.aten.mean.dim(pow_22, [2], True); pow_22 = None + add_42 = torch.ops.aten.add.Scalar(mean_21, 1e-05); mean_21 = None + rsqrt_21 = torch.ops.aten.rsqrt.default(add_42); add_42 = None + mul_84 = torch.ops.aten.mul.Tensor(convert_element_type_351, rsqrt_21); convert_element_type_351 = rsqrt_21 = None + mul_85 = torch.ops.aten.mul.Tensor(mul_84, wait_tensor_139); mul_84 = wait_tensor_139 = None + convert_element_type_352 = torch.ops.prims.convert_element_type.default(mul_85, torch.bfloat16); mul_85 = None + all_gather_into_tensor_118 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_352, 8, '1'); convert_element_type_352 = None + wait_tensor_140 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_118); all_gather_into_tensor_118 = None + split_51 = torch.ops.aten.split.Tensor(wait_tensor_140, 2); wait_tensor_140 = None + getitem_507 = split_51[0] + getitem_508 = split_51[1] + getitem_509 = split_51[2] + getitem_510 = split_51[3] + getitem_511 = split_51[4] + getitem_512 = split_51[5] + getitem_513 = split_51[6] + getitem_514 = split_51[7]; split_51 = None + cat_43 = torch.ops.aten.cat.default([getitem_507, getitem_508, getitem_509, getitem_510, getitem_511, getitem_512, getitem_513, getitem_514], 1); getitem_507 = getitem_508 = getitem_509 = getitem_510 = getitem_511 = getitem_512 = getitem_513 = getitem_514 = None + convert_element_type_353 = torch.ops.prims.convert_element_type.default(primals_100, torch.bfloat16) + all_gather_into_tensor_119 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_353, 32, '0'); convert_element_type_353 = None + 
wait_tensor_141 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_119); all_gather_into_tensor_119 = None + permute_118 = torch.ops.aten.permute.default(wait_tensor_141, [1, 0]); wait_tensor_141 = None + view_780 = torch.ops.aten.view.default(cat_43, [16384, 4096]); cat_43 = None + mm_74 = torch.ops.aten.mm.default(view_780, permute_118); permute_118 = None + view_781 = torch.ops.aten.view.default(mm_74, [2, 8192, 1792]) + convert_element_type_356 = torch.ops.prims.convert_element_type.default(view_781, torch.float32); view_781 = None + sigmoid_10 = torch.ops.aten.sigmoid.default(convert_element_type_356) + mul_86 = torch.ops.aten.mul.Tensor(convert_element_type_356, sigmoid_10); convert_element_type_356 = sigmoid_10 = None + convert_element_type_357 = torch.ops.prims.convert_element_type.default(mul_86, torch.bfloat16); mul_86 = None + convert_element_type_358 = torch.ops.prims.convert_element_type.default(primals_101, torch.bfloat16) + all_gather_into_tensor_120 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_358, 32, '0'); convert_element_type_358 = None + wait_tensor_142 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_120); all_gather_into_tensor_120 = None + permute_119 = torch.ops.aten.permute.default(wait_tensor_142, [1, 0]); wait_tensor_142 = None + mm_75 = torch.ops.aten.mm.default(view_780, permute_119); view_780 = permute_119 = None + view_788 = torch.ops.aten.view.default(mm_75, [2, 8192, 1792]); mm_75 = None + mul_87 = torch.ops.aten.mul.Tensor(convert_element_type_357, view_788); convert_element_type_357 = view_788 = None + convert_element_type_361 = torch.ops.prims.convert_element_type.default(primals_102, torch.bfloat16) + all_gather_into_tensor_121 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_361, 32, '0'); convert_element_type_361 = None + wait_tensor_143 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_121); 
all_gather_into_tensor_121 = None + permute_120 = torch.ops.aten.permute.default(wait_tensor_143, [1, 0]); wait_tensor_143 = None + view_795 = torch.ops.aten.view.default(mul_87, [16384, 1792]); mul_87 = None + mm_76 = torch.ops.aten.mm.default(view_795, permute_120); view_795 = permute_120 = None + view_796 = torch.ops.aten.view.default(mm_76, [2, 8192, 4096]); mm_76 = None + split_52 = torch.ops.aten.split.Tensor(view_796, 1024, 1); view_796 = None + getitem_515 = split_52[0] + getitem_516 = split_52[1] + getitem_517 = split_52[2] + getitem_518 = split_52[3] + getitem_519 = split_52[4] + getitem_520 = split_52[5] + getitem_521 = split_52[6] + getitem_522 = split_52[7]; split_52 = None + cat_44 = torch.ops.aten.cat.default([getitem_515, getitem_516, getitem_517, getitem_518, getitem_519, getitem_520, getitem_521, getitem_522]); getitem_515 = getitem_516 = getitem_517 = getitem_518 = getitem_519 = getitem_520 = getitem_521 = getitem_522 = None + reduce_scatter_tensor_22 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_44, 'sum', 8, '1'); cat_44 = None + wait_tensor_144 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_22); reduce_scatter_tensor_22 = None + add_43 = torch.ops.aten.add.Tensor(add_41, wait_tensor_144); add_41 = wait_tensor_144 = None + convert_element_type_364 = torch.ops.prims.convert_element_type.default(primals_103, torch.bfloat16) + all_gather_into_tensor_122 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_364, 32, '0'); convert_element_type_364 = None + wait_tensor_145 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_122); all_gather_into_tensor_122 = None + convert_element_type_365 = torch.ops.prims.convert_element_type.default(add_43, torch.float32) + pow_23 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_365, 2) + mean_22 = torch.ops.aten.mean.dim(pow_23, [2], True); pow_23 = None + add_44 = torch.ops.aten.add.Scalar(mean_22, 1e-05); mean_22 = 
None + rsqrt_22 = torch.ops.aten.rsqrt.default(add_44); add_44 = None + mul_88 = torch.ops.aten.mul.Tensor(convert_element_type_365, rsqrt_22); convert_element_type_365 = rsqrt_22 = None + mul_89 = torch.ops.aten.mul.Tensor(mul_88, wait_tensor_145); mul_88 = wait_tensor_145 = None + convert_element_type_366 = torch.ops.prims.convert_element_type.default(mul_89, torch.bfloat16); mul_89 = None + all_gather_into_tensor_123 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_366, 8, '1'); convert_element_type_366 = None + wait_tensor_146 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_123); all_gather_into_tensor_123 = None + split_53 = torch.ops.aten.split.Tensor(wait_tensor_146, 2); wait_tensor_146 = None + getitem_523 = split_53[0] + getitem_524 = split_53[1] + getitem_525 = split_53[2] + getitem_526 = split_53[3] + getitem_527 = split_53[4] + getitem_528 = split_53[5] + getitem_529 = split_53[6] + getitem_530 = split_53[7]; split_53 = None + cat_45 = torch.ops.aten.cat.default([getitem_523, getitem_524, getitem_525, getitem_526, getitem_527, getitem_528, getitem_529, getitem_530], 1); getitem_523 = getitem_524 = getitem_525 = getitem_526 = getitem_527 = getitem_528 = getitem_529 = getitem_530 = None + convert_element_type_367 = torch.ops.prims.convert_element_type.default(primals_104, torch.bfloat16) + all_gather_into_tensor_124 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_367, 32, '0'); convert_element_type_367 = None + wait_tensor_147 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_124); all_gather_into_tensor_124 = None + permute_121 = torch.ops.aten.permute.default(wait_tensor_147, [1, 0]); wait_tensor_147 = None + view_807 = torch.ops.aten.view.default(cat_45, [16384, 4096]); cat_45 = None + mm_77 = torch.ops.aten.mm.default(view_807, permute_121); permute_121 = None + view_808 = torch.ops.aten.view.default(mm_77, [2, 8192, 512]) + 
convert_element_type_370 = torch.ops.prims.convert_element_type.default(primals_105, torch.bfloat16) + all_gather_into_tensor_125 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_370, 32, '0'); convert_element_type_370 = None + wait_tensor_148 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_125); all_gather_into_tensor_125 = None + permute_122 = torch.ops.aten.permute.default(wait_tensor_148, [1, 0]); wait_tensor_148 = None + mm_78 = torch.ops.aten.mm.default(view_807, permute_122); permute_122 = None + view_815 = torch.ops.aten.view.default(mm_78, [2, 8192, 128]); mm_78 = None + convert_element_type_373 = torch.ops.prims.convert_element_type.default(primals_106, torch.bfloat16) + all_gather_into_tensor_126 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_373, 32, '0'); convert_element_type_373 = None + wait_tensor_149 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_126); all_gather_into_tensor_126 = None + permute_123 = torch.ops.aten.permute.default(wait_tensor_149, [1, 0]); wait_tensor_149 = None + mm_79 = torch.ops.aten.mm.default(view_807, permute_123); view_807 = permute_123 = None + view_822 = torch.ops.aten.view.default(mm_79, [2, 8192, 128]) + view_824 = torch.ops.aten.view.default(view_808, [2, 8192, -1, 128]); view_808 = None + view_825 = torch.ops.aten.view.default(view_815, [2, 8192, -1, 128]); view_815 = None + view_826 = torch.ops.aten.view.default(view_822, [2, 8192, -1, 128]); view_822 = None + convert_element_type_376 = torch.ops.prims.convert_element_type.default(view_824, torch.float32); view_824 = None + view_827 = torch.ops.aten.view.default(convert_element_type_376, [2, 8192, 4, -1, 2]); convert_element_type_376 = None + view_as_complex_22 = torch.ops.aten.view_as_complex.default(view_827); view_827 = None + convert_element_type_377 = torch.ops.prims.convert_element_type.default(view_825, torch.float32); view_825 = None + view_828 = 
torch.ops.aten.view.default(convert_element_type_377, [2, 8192, 1, -1, 2]); convert_element_type_377 = None + view_as_complex_23 = torch.ops.aten.view_as_complex.default(view_828); view_828 = None + mul_90 = torch.ops.aten.mul.Tensor(view_as_complex_22, view_37); view_as_complex_22 = None + view_as_real_22 = torch.ops.aten.view_as_real.default(mul_90); mul_90 = None + view_830 = torch.ops.aten.view.default(view_as_real_22, [2, 8192, 4, 128]); view_as_real_22 = None + mul_91 = torch.ops.aten.mul.Tensor(view_as_complex_23, view_37); view_as_complex_23 = None + view_as_real_23 = torch.ops.aten.view_as_real.default(mul_91); mul_91 = None + view_831 = torch.ops.aten.view.default(view_as_real_23, [2, 8192, 1, 128]); view_as_real_23 = None + convert_element_type_378 = torch.ops.prims.convert_element_type.default(view_830, torch.bfloat16); view_830 = None + convert_element_type_379 = torch.ops.prims.convert_element_type.default(view_831, torch.bfloat16); view_831 = None + unsqueeze_22 = torch.ops.aten.unsqueeze.default(convert_element_type_379, 3); convert_element_type_379 = None + expand_22 = torch.ops.aten.expand.default(unsqueeze_22, [2, 8192, 1, 4, 128]); unsqueeze_22 = None + view_832 = torch.ops.aten.view.default(expand_22, [2, 8192, 4, 128]); expand_22 = None + unsqueeze_23 = torch.ops.aten.unsqueeze.default(view_826, 3); view_826 = None + expand_23 = torch.ops.aten.expand.default(unsqueeze_23, [2, 8192, 1, 4, 128]); unsqueeze_23 = None + view_833 = torch.ops.aten.view.default(expand_23, [2, 8192, 4, 128]); expand_23 = None + permute_124 = torch.ops.aten.permute.default(convert_element_type_378, [0, 2, 1, 3]); convert_element_type_378 = None + permute_125 = torch.ops.aten.permute.default(view_832, [0, 2, 1, 3]); view_832 = None + permute_126 = torch.ops.aten.permute.default(view_833, [0, 2, 1, 3]); view_833 = None + _scaled_dot_product_cudnn_attention_11 = torch.ops.aten._scaled_dot_product_cudnn_attention.default(permute_124, permute_125, permute_126, None, True, 
0.0, True); permute_124 = permute_125 = permute_126 = None + getitem_531 = _scaled_dot_product_cudnn_attention_11[0] + getitem_532 = _scaled_dot_product_cudnn_attention_11[1] + getitem_537 = _scaled_dot_product_cudnn_attention_11[6] + getitem_538 = _scaled_dot_product_cudnn_attention_11[7]; _scaled_dot_product_cudnn_attention_11 = None + permute_127 = torch.ops.aten.permute.default(getitem_531, [0, 2, 1, 3]) + view_834 = torch.ops.aten.view.default(permute_127, [2, 8192, -1]); permute_127 = None + convert_element_type_380 = torch.ops.prims.convert_element_type.default(primals_107, torch.bfloat16) + all_gather_into_tensor_127 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_380, 32, '0'); convert_element_type_380 = None + wait_tensor_150 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_127); all_gather_into_tensor_127 = None + permute_128 = torch.ops.aten.permute.default(wait_tensor_150, [1, 0]); wait_tensor_150 = None + view_840 = torch.ops.aten.view.default(view_834, [16384, 512]); view_834 = None + mm_80 = torch.ops.aten.mm.default(view_840, permute_128); view_840 = permute_128 = None + view_841 = torch.ops.aten.view.default(mm_80, [2, 8192, 4096]); mm_80 = None + split_54 = torch.ops.aten.split.Tensor(view_841, 1024, 1); view_841 = None + getitem_540 = split_54[0] + getitem_541 = split_54[1] + getitem_542 = split_54[2] + getitem_543 = split_54[3] + getitem_544 = split_54[4] + getitem_545 = split_54[5] + getitem_546 = split_54[6] + getitem_547 = split_54[7]; split_54 = None + cat_46 = torch.ops.aten.cat.default([getitem_540, getitem_541, getitem_542, getitem_543, getitem_544, getitem_545, getitem_546, getitem_547]); getitem_540 = getitem_541 = getitem_542 = getitem_543 = getitem_544 = getitem_545 = getitem_546 = getitem_547 = None + reduce_scatter_tensor_23 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_46, 'sum', 8, '1'); cat_46 = None + wait_tensor_151 = 
torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_23) + add_45 = torch.ops.aten.add.Tensor(add_43, wait_tensor_151); wait_tensor_151 = None + convert_element_type_383 = torch.ops.prims.convert_element_type.default(primals_108, torch.bfloat16) + all_gather_into_tensor_128 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_383, 32, '0'); convert_element_type_383 = None + wait_tensor_152 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_128); all_gather_into_tensor_128 = None + convert_element_type_384 = torch.ops.prims.convert_element_type.default(add_45, torch.float32) + pow_24 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_384, 2) + mean_23 = torch.ops.aten.mean.dim(pow_24, [2], True); pow_24 = None + add_46 = torch.ops.aten.add.Scalar(mean_23, 1e-05); mean_23 = None + rsqrt_23 = torch.ops.aten.rsqrt.default(add_46); add_46 = None + mul_92 = torch.ops.aten.mul.Tensor(convert_element_type_384, rsqrt_23); convert_element_type_384 = rsqrt_23 = None + mul_93 = torch.ops.aten.mul.Tensor(mul_92, wait_tensor_152); mul_92 = wait_tensor_152 = None + convert_element_type_385 = torch.ops.prims.convert_element_type.default(mul_93, torch.bfloat16); mul_93 = None + all_gather_into_tensor_129 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_385, 8, '1'); convert_element_type_385 = None + wait_tensor_153 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_129); all_gather_into_tensor_129 = None + split_55 = torch.ops.aten.split.Tensor(wait_tensor_153, 2); wait_tensor_153 = None + getitem_548 = split_55[0] + getitem_549 = split_55[1] + getitem_550 = split_55[2] + getitem_551 = split_55[3] + getitem_552 = split_55[4] + getitem_553 = split_55[5] + getitem_554 = split_55[6] + getitem_555 = split_55[7]; split_55 = None + cat_47 = torch.ops.aten.cat.default([getitem_548, getitem_549, getitem_550, getitem_551, getitem_552, getitem_553, getitem_554, 
getitem_555], 1); getitem_548 = getitem_549 = getitem_550 = getitem_551 = getitem_552 = getitem_553 = getitem_554 = getitem_555 = None + convert_element_type_386 = torch.ops.prims.convert_element_type.default(primals_109, torch.bfloat16) + all_gather_into_tensor_130 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_386, 32, '0'); convert_element_type_386 = None + wait_tensor_154 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_130); all_gather_into_tensor_130 = None + permute_129 = torch.ops.aten.permute.default(wait_tensor_154, [1, 0]); wait_tensor_154 = None + view_852 = torch.ops.aten.view.default(cat_47, [16384, 4096]); cat_47 = None + mm_81 = torch.ops.aten.mm.default(view_852, permute_129); permute_129 = None + view_853 = torch.ops.aten.view.default(mm_81, [2, 8192, 1792]) + convert_element_type_389 = torch.ops.prims.convert_element_type.default(view_853, torch.float32); view_853 = None + sigmoid_11 = torch.ops.aten.sigmoid.default(convert_element_type_389) + mul_94 = torch.ops.aten.mul.Tensor(convert_element_type_389, sigmoid_11); convert_element_type_389 = sigmoid_11 = None + convert_element_type_390 = torch.ops.prims.convert_element_type.default(mul_94, torch.bfloat16); mul_94 = None + convert_element_type_391 = torch.ops.prims.convert_element_type.default(primals_110, torch.bfloat16) + all_gather_into_tensor_131 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_391, 32, '0'); convert_element_type_391 = None + wait_tensor_155 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_131); all_gather_into_tensor_131 = None + permute_130 = torch.ops.aten.permute.default(wait_tensor_155, [1, 0]); wait_tensor_155 = None + mm_82 = torch.ops.aten.mm.default(view_852, permute_130); view_852 = permute_130 = None + view_860 = torch.ops.aten.view.default(mm_82, [2, 8192, 1792]); mm_82 = None + mul_95 = torch.ops.aten.mul.Tensor(convert_element_type_390, view_860); 
convert_element_type_390 = view_860 = None + convert_element_type_394 = torch.ops.prims.convert_element_type.default(primals_111, torch.bfloat16) + all_gather_into_tensor_132 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_394, 32, '0'); convert_element_type_394 = None + wait_tensor_156 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_132); all_gather_into_tensor_132 = None + permute_131 = torch.ops.aten.permute.default(wait_tensor_156, [1, 0]); wait_tensor_156 = None + view_867 = torch.ops.aten.view.default(mul_95, [16384, 1792]); mul_95 = None + mm_83 = torch.ops.aten.mm.default(view_867, permute_131); view_867 = permute_131 = None + view_868 = torch.ops.aten.view.default(mm_83, [2, 8192, 4096]); mm_83 = None + split_56 = torch.ops.aten.split.Tensor(view_868, 1024, 1); view_868 = None + getitem_556 = split_56[0] + getitem_557 = split_56[1] + getitem_558 = split_56[2] + getitem_559 = split_56[3] + getitem_560 = split_56[4] + getitem_561 = split_56[5] + getitem_562 = split_56[6] + getitem_563 = split_56[7]; split_56 = None + cat_48 = torch.ops.aten.cat.default([getitem_556, getitem_557, getitem_558, getitem_559, getitem_560, getitem_561, getitem_562, getitem_563]); getitem_556 = getitem_557 = getitem_558 = getitem_559 = getitem_560 = getitem_561 = getitem_562 = getitem_563 = None + reduce_scatter_tensor_24 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_48, 'sum', 8, '1'); cat_48 = None + wait_tensor_157 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_24); reduce_scatter_tensor_24 = None + add_47 = torch.ops.aten.add.Tensor(add_45, wait_tensor_157); add_45 = wait_tensor_157 = None + convert_element_type_397 = torch.ops.prims.convert_element_type.default(primals_112, torch.bfloat16) + all_gather_into_tensor_133 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_397, 32, '0'); convert_element_type_397 = None + wait_tensor_158 = 
torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_133); all_gather_into_tensor_133 = None + convert_element_type_398 = torch.ops.prims.convert_element_type.default(add_47, torch.float32) + pow_25 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_398, 2) + mean_24 = torch.ops.aten.mean.dim(pow_25, [2], True); pow_25 = None + add_48 = torch.ops.aten.add.Scalar(mean_24, 1e-05); mean_24 = None + rsqrt_24 = torch.ops.aten.rsqrt.default(add_48); add_48 = None + mul_96 = torch.ops.aten.mul.Tensor(convert_element_type_398, rsqrt_24); convert_element_type_398 = rsqrt_24 = None + mul_97 = torch.ops.aten.mul.Tensor(mul_96, wait_tensor_158); mul_96 = wait_tensor_158 = None + convert_element_type_399 = torch.ops.prims.convert_element_type.default(mul_97, torch.bfloat16); mul_97 = None + all_gather_into_tensor_134 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_399, 8, '1'); convert_element_type_399 = None + wait_tensor_159 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_134); all_gather_into_tensor_134 = None + split_57 = torch.ops.aten.split.Tensor(wait_tensor_159, 2); wait_tensor_159 = None + getitem_564 = split_57[0] + getitem_565 = split_57[1] + getitem_566 = split_57[2] + getitem_567 = split_57[3] + getitem_568 = split_57[4] + getitem_569 = split_57[5] + getitem_570 = split_57[6] + getitem_571 = split_57[7]; split_57 = None + cat_49 = torch.ops.aten.cat.default([getitem_564, getitem_565, getitem_566, getitem_567, getitem_568, getitem_569, getitem_570, getitem_571], 1); getitem_564 = getitem_565 = getitem_566 = getitem_567 = getitem_568 = getitem_569 = getitem_570 = getitem_571 = None + convert_element_type_400 = torch.ops.prims.convert_element_type.default(primals_113, torch.bfloat16) + all_gather_into_tensor_135 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_400, 32, '0'); convert_element_type_400 = None + wait_tensor_160 = 
torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_135); all_gather_into_tensor_135 = None + permute_132 = torch.ops.aten.permute.default(wait_tensor_160, [1, 0]); wait_tensor_160 = None + view_879 = torch.ops.aten.view.default(cat_49, [16384, 4096]); cat_49 = None + mm_84 = torch.ops.aten.mm.default(view_879, permute_132); permute_132 = None + view_880 = torch.ops.aten.view.default(mm_84, [2, 8192, 512]) + convert_element_type_403 = torch.ops.prims.convert_element_type.default(primals_114, torch.bfloat16) + all_gather_into_tensor_136 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_403, 32, '0'); convert_element_type_403 = None + wait_tensor_161 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_136); all_gather_into_tensor_136 = None + permute_133 = torch.ops.aten.permute.default(wait_tensor_161, [1, 0]); wait_tensor_161 = None + mm_85 = torch.ops.aten.mm.default(view_879, permute_133); permute_133 = None + view_887 = torch.ops.aten.view.default(mm_85, [2, 8192, 128]); mm_85 = None + convert_element_type_406 = torch.ops.prims.convert_element_type.default(primals_115, torch.bfloat16) + all_gather_into_tensor_137 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_406, 32, '0'); convert_element_type_406 = None + wait_tensor_162 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_137); all_gather_into_tensor_137 = None + permute_134 = torch.ops.aten.permute.default(wait_tensor_162, [1, 0]); wait_tensor_162 = None + mm_86 = torch.ops.aten.mm.default(view_879, permute_134); view_879 = permute_134 = None + view_894 = torch.ops.aten.view.default(mm_86, [2, 8192, 128]) + view_896 = torch.ops.aten.view.default(view_880, [2, 8192, -1, 128]); view_880 = None + view_897 = torch.ops.aten.view.default(view_887, [2, 8192, -1, 128]); view_887 = None + view_898 = torch.ops.aten.view.default(view_894, [2, 8192, -1, 128]); view_894 = None + 
convert_element_type_409 = torch.ops.prims.convert_element_type.default(view_896, torch.float32); view_896 = None + view_899 = torch.ops.aten.view.default(convert_element_type_409, [2, 8192, 4, -1, 2]); convert_element_type_409 = None + view_as_complex_24 = torch.ops.aten.view_as_complex.default(view_899); view_899 = None + convert_element_type_410 = torch.ops.prims.convert_element_type.default(view_897, torch.float32); view_897 = None + view_900 = torch.ops.aten.view.default(convert_element_type_410, [2, 8192, 1, -1, 2]); convert_element_type_410 = None + view_as_complex_25 = torch.ops.aten.view_as_complex.default(view_900); view_900 = None + mul_98 = torch.ops.aten.mul.Tensor(view_as_complex_24, view_37); view_as_complex_24 = None + view_as_real_24 = torch.ops.aten.view_as_real.default(mul_98); mul_98 = None + view_902 = torch.ops.aten.view.default(view_as_real_24, [2, 8192, 4, 128]); view_as_real_24 = None + mul_99 = torch.ops.aten.mul.Tensor(view_as_complex_25, view_37); view_as_complex_25 = None + view_as_real_25 = torch.ops.aten.view_as_real.default(mul_99); mul_99 = None + view_903 = torch.ops.aten.view.default(view_as_real_25, [2, 8192, 1, 128]); view_as_real_25 = None + convert_element_type_411 = torch.ops.prims.convert_element_type.default(view_902, torch.bfloat16); view_902 = None + convert_element_type_412 = torch.ops.prims.convert_element_type.default(view_903, torch.bfloat16); view_903 = None + unsqueeze_24 = torch.ops.aten.unsqueeze.default(convert_element_type_412, 3); convert_element_type_412 = None + expand_24 = torch.ops.aten.expand.default(unsqueeze_24, [2, 8192, 1, 4, 128]); unsqueeze_24 = None + view_904 = torch.ops.aten.view.default(expand_24, [2, 8192, 4, 128]); expand_24 = None + unsqueeze_25 = torch.ops.aten.unsqueeze.default(view_898, 3); view_898 = None + expand_25 = torch.ops.aten.expand.default(unsqueeze_25, [2, 8192, 1, 4, 128]); unsqueeze_25 = None + view_905 = torch.ops.aten.view.default(expand_25, [2, 8192, 4, 128]); expand_25 = 
None + permute_135 = torch.ops.aten.permute.default(convert_element_type_411, [0, 2, 1, 3]); convert_element_type_411 = None + permute_136 = torch.ops.aten.permute.default(view_904, [0, 2, 1, 3]); view_904 = None + permute_137 = torch.ops.aten.permute.default(view_905, [0, 2, 1, 3]); view_905 = None + _scaled_dot_product_cudnn_attention_12 = torch.ops.aten._scaled_dot_product_cudnn_attention.default(permute_135, permute_136, permute_137, None, True, 0.0, True); permute_135 = permute_136 = permute_137 = None + getitem_572 = _scaled_dot_product_cudnn_attention_12[0] + getitem_573 = _scaled_dot_product_cudnn_attention_12[1] + getitem_578 = _scaled_dot_product_cudnn_attention_12[6] + getitem_579 = _scaled_dot_product_cudnn_attention_12[7]; _scaled_dot_product_cudnn_attention_12 = None + permute_138 = torch.ops.aten.permute.default(getitem_572, [0, 2, 1, 3]) + view_906 = torch.ops.aten.view.default(permute_138, [2, 8192, -1]); permute_138 = None + convert_element_type_413 = torch.ops.prims.convert_element_type.default(primals_116, torch.bfloat16) + all_gather_into_tensor_138 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_413, 32, '0'); convert_element_type_413 = None + wait_tensor_163 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_138); all_gather_into_tensor_138 = None + permute_139 = torch.ops.aten.permute.default(wait_tensor_163, [1, 0]); wait_tensor_163 = None + view_912 = torch.ops.aten.view.default(view_906, [16384, 512]); view_906 = None + mm_87 = torch.ops.aten.mm.default(view_912, permute_139); view_912 = permute_139 = None + view_913 = torch.ops.aten.view.default(mm_87, [2, 8192, 4096]); mm_87 = None + split_58 = torch.ops.aten.split.Tensor(view_913, 1024, 1); view_913 = None + getitem_581 = split_58[0] + getitem_582 = split_58[1] + getitem_583 = split_58[2] + getitem_584 = split_58[3] + getitem_585 = split_58[4] + getitem_586 = split_58[5] + getitem_587 = split_58[6] + getitem_588 = split_58[7]; 
split_58 = None + cat_50 = torch.ops.aten.cat.default([getitem_581, getitem_582, getitem_583, getitem_584, getitem_585, getitem_586, getitem_587, getitem_588]); getitem_581 = getitem_582 = getitem_583 = getitem_584 = getitem_585 = getitem_586 = getitem_587 = getitem_588 = None + reduce_scatter_tensor_25 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_50, 'sum', 8, '1'); cat_50 = None + wait_tensor_164 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_25) + add_49 = torch.ops.aten.add.Tensor(add_47, wait_tensor_164); wait_tensor_164 = None + convert_element_type_416 = torch.ops.prims.convert_element_type.default(primals_117, torch.bfloat16) + all_gather_into_tensor_139 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_416, 32, '0'); convert_element_type_416 = None + wait_tensor_165 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_139); all_gather_into_tensor_139 = None + convert_element_type_417 = torch.ops.prims.convert_element_type.default(add_49, torch.float32) + pow_26 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_417, 2) + mean_25 = torch.ops.aten.mean.dim(pow_26, [2], True); pow_26 = None + add_50 = torch.ops.aten.add.Scalar(mean_25, 1e-05); mean_25 = None + rsqrt_25 = torch.ops.aten.rsqrt.default(add_50); add_50 = None + mul_100 = torch.ops.aten.mul.Tensor(convert_element_type_417, rsqrt_25); convert_element_type_417 = rsqrt_25 = None + mul_101 = torch.ops.aten.mul.Tensor(mul_100, wait_tensor_165); mul_100 = wait_tensor_165 = None + convert_element_type_418 = torch.ops.prims.convert_element_type.default(mul_101, torch.bfloat16); mul_101 = None + all_gather_into_tensor_140 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_418, 8, '1'); convert_element_type_418 = None + wait_tensor_166 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_140); all_gather_into_tensor_140 = None + split_59 = 
torch.ops.aten.split.Tensor(wait_tensor_166, 2); wait_tensor_166 = None + getitem_589 = split_59[0] + getitem_590 = split_59[1] + getitem_591 = split_59[2] + getitem_592 = split_59[3] + getitem_593 = split_59[4] + getitem_594 = split_59[5] + getitem_595 = split_59[6] + getitem_596 = split_59[7]; split_59 = None + cat_51 = torch.ops.aten.cat.default([getitem_589, getitem_590, getitem_591, getitem_592, getitem_593, getitem_594, getitem_595, getitem_596], 1); getitem_589 = getitem_590 = getitem_591 = getitem_592 = getitem_593 = getitem_594 = getitem_595 = getitem_596 = None + convert_element_type_419 = torch.ops.prims.convert_element_type.default(primals_118, torch.bfloat16) + all_gather_into_tensor_141 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_419, 32, '0'); convert_element_type_419 = None + wait_tensor_167 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_141); all_gather_into_tensor_141 = None + permute_140 = torch.ops.aten.permute.default(wait_tensor_167, [1, 0]); wait_tensor_167 = None + view_924 = torch.ops.aten.view.default(cat_51, [16384, 4096]); cat_51 = None + mm_88 = torch.ops.aten.mm.default(view_924, permute_140); permute_140 = None + view_925 = torch.ops.aten.view.default(mm_88, [2, 8192, 1792]) + convert_element_type_422 = torch.ops.prims.convert_element_type.default(view_925, torch.float32); view_925 = None + sigmoid_12 = torch.ops.aten.sigmoid.default(convert_element_type_422) + mul_102 = torch.ops.aten.mul.Tensor(convert_element_type_422, sigmoid_12); convert_element_type_422 = sigmoid_12 = None + convert_element_type_423 = torch.ops.prims.convert_element_type.default(mul_102, torch.bfloat16); mul_102 = None + convert_element_type_424 = torch.ops.prims.convert_element_type.default(primals_119, torch.bfloat16) + all_gather_into_tensor_142 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_424, 32, '0'); convert_element_type_424 = None + wait_tensor_168 = 
torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_142); all_gather_into_tensor_142 = None + permute_141 = torch.ops.aten.permute.default(wait_tensor_168, [1, 0]); wait_tensor_168 = None + mm_89 = torch.ops.aten.mm.default(view_924, permute_141); view_924 = permute_141 = None + view_932 = torch.ops.aten.view.default(mm_89, [2, 8192, 1792]); mm_89 = None + mul_103 = torch.ops.aten.mul.Tensor(convert_element_type_423, view_932); convert_element_type_423 = view_932 = None + convert_element_type_427 = torch.ops.prims.convert_element_type.default(primals_120, torch.bfloat16) + all_gather_into_tensor_143 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_427, 32, '0'); convert_element_type_427 = None + wait_tensor_169 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_143); all_gather_into_tensor_143 = None + permute_142 = torch.ops.aten.permute.default(wait_tensor_169, [1, 0]); wait_tensor_169 = None + view_939 = torch.ops.aten.view.default(mul_103, [16384, 1792]); mul_103 = None + mm_90 = torch.ops.aten.mm.default(view_939, permute_142); view_939 = permute_142 = None + view_940 = torch.ops.aten.view.default(mm_90, [2, 8192, 4096]); mm_90 = None + split_60 = torch.ops.aten.split.Tensor(view_940, 1024, 1); view_940 = None + getitem_597 = split_60[0] + getitem_598 = split_60[1] + getitem_599 = split_60[2] + getitem_600 = split_60[3] + getitem_601 = split_60[4] + getitem_602 = split_60[5] + getitem_603 = split_60[6] + getitem_604 = split_60[7]; split_60 = None + cat_52 = torch.ops.aten.cat.default([getitem_597, getitem_598, getitem_599, getitem_600, getitem_601, getitem_602, getitem_603, getitem_604]); getitem_597 = getitem_598 = getitem_599 = getitem_600 = getitem_601 = getitem_602 = getitem_603 = getitem_604 = None + reduce_scatter_tensor_26 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_52, 'sum', 8, '1'); cat_52 = None + wait_tensor_170 = 
torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_26); reduce_scatter_tensor_26 = None + add_51 = torch.ops.aten.add.Tensor(add_49, wait_tensor_170); add_49 = wait_tensor_170 = None + convert_element_type_430 = torch.ops.prims.convert_element_type.default(primals_121, torch.bfloat16) + all_gather_into_tensor_144 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_430, 32, '0'); convert_element_type_430 = None + wait_tensor_171 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_144); all_gather_into_tensor_144 = None + convert_element_type_431 = torch.ops.prims.convert_element_type.default(add_51, torch.float32) + pow_27 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_431, 2) + mean_26 = torch.ops.aten.mean.dim(pow_27, [2], True); pow_27 = None + add_52 = torch.ops.aten.add.Scalar(mean_26, 1e-05); mean_26 = None + rsqrt_26 = torch.ops.aten.rsqrt.default(add_52); add_52 = None + mul_104 = torch.ops.aten.mul.Tensor(convert_element_type_431, rsqrt_26); convert_element_type_431 = rsqrt_26 = None + mul_105 = torch.ops.aten.mul.Tensor(mul_104, wait_tensor_171); mul_104 = wait_tensor_171 = None + convert_element_type_432 = torch.ops.prims.convert_element_type.default(mul_105, torch.bfloat16); mul_105 = None + all_gather_into_tensor_145 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_432, 8, '1'); convert_element_type_432 = None + wait_tensor_172 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_145); all_gather_into_tensor_145 = None + split_61 = torch.ops.aten.split.Tensor(wait_tensor_172, 2); wait_tensor_172 = None + getitem_605 = split_61[0] + getitem_606 = split_61[1] + getitem_607 = split_61[2] + getitem_608 = split_61[3] + getitem_609 = split_61[4] + getitem_610 = split_61[5] + getitem_611 = split_61[6] + getitem_612 = split_61[7]; split_61 = None + cat_53 = torch.ops.aten.cat.default([getitem_605, getitem_606, getitem_607, getitem_608, 
getitem_609, getitem_610, getitem_611, getitem_612], 1); getitem_605 = getitem_606 = getitem_607 = getitem_608 = getitem_609 = getitem_610 = getitem_611 = getitem_612 = None + convert_element_type_433 = torch.ops.prims.convert_element_type.default(primals_122, torch.bfloat16) + all_gather_into_tensor_146 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_433, 32, '0'); convert_element_type_433 = None + wait_tensor_173 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_146); all_gather_into_tensor_146 = None + permute_143 = torch.ops.aten.permute.default(wait_tensor_173, [1, 0]); wait_tensor_173 = None + view_951 = torch.ops.aten.view.default(cat_53, [16384, 4096]); cat_53 = None + mm_91 = torch.ops.aten.mm.default(view_951, permute_143); permute_143 = None + view_952 = torch.ops.aten.view.default(mm_91, [2, 8192, 512]) + convert_element_type_436 = torch.ops.prims.convert_element_type.default(primals_123, torch.bfloat16) + all_gather_into_tensor_147 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_436, 32, '0'); convert_element_type_436 = None + wait_tensor_174 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_147); all_gather_into_tensor_147 = None + permute_144 = torch.ops.aten.permute.default(wait_tensor_174, [1, 0]); wait_tensor_174 = None + mm_92 = torch.ops.aten.mm.default(view_951, permute_144); permute_144 = None + view_959 = torch.ops.aten.view.default(mm_92, [2, 8192, 128]); mm_92 = None + convert_element_type_439 = torch.ops.prims.convert_element_type.default(primals_124, torch.bfloat16) + all_gather_into_tensor_148 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_439, 32, '0'); convert_element_type_439 = None + wait_tensor_175 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_148); all_gather_into_tensor_148 = None + permute_145 = torch.ops.aten.permute.default(wait_tensor_175, [1, 0]); 
wait_tensor_175 = None + mm_93 = torch.ops.aten.mm.default(view_951, permute_145); view_951 = permute_145 = None + view_966 = torch.ops.aten.view.default(mm_93, [2, 8192, 128]) + view_968 = torch.ops.aten.view.default(view_952, [2, 8192, -1, 128]); view_952 = None + view_969 = torch.ops.aten.view.default(view_959, [2, 8192, -1, 128]); view_959 = None + view_970 = torch.ops.aten.view.default(view_966, [2, 8192, -1, 128]); view_966 = None + convert_element_type_442 = torch.ops.prims.convert_element_type.default(view_968, torch.float32); view_968 = None + view_971 = torch.ops.aten.view.default(convert_element_type_442, [2, 8192, 4, -1, 2]); convert_element_type_442 = None + view_as_complex_26 = torch.ops.aten.view_as_complex.default(view_971); view_971 = None + convert_element_type_443 = torch.ops.prims.convert_element_type.default(view_969, torch.float32); view_969 = None + view_972 = torch.ops.aten.view.default(convert_element_type_443, [2, 8192, 1, -1, 2]); convert_element_type_443 = None + view_as_complex_27 = torch.ops.aten.view_as_complex.default(view_972); view_972 = None + mul_106 = torch.ops.aten.mul.Tensor(view_as_complex_26, view_37); view_as_complex_26 = None + view_as_real_26 = torch.ops.aten.view_as_real.default(mul_106); mul_106 = None + view_974 = torch.ops.aten.view.default(view_as_real_26, [2, 8192, 4, 128]); view_as_real_26 = None + mul_107 = torch.ops.aten.mul.Tensor(view_as_complex_27, view_37); view_as_complex_27 = None + view_as_real_27 = torch.ops.aten.view_as_real.default(mul_107); mul_107 = None + view_975 = torch.ops.aten.view.default(view_as_real_27, [2, 8192, 1, 128]); view_as_real_27 = None + convert_element_type_444 = torch.ops.prims.convert_element_type.default(view_974, torch.bfloat16); view_974 = None + convert_element_type_445 = torch.ops.prims.convert_element_type.default(view_975, torch.bfloat16); view_975 = None + unsqueeze_26 = torch.ops.aten.unsqueeze.default(convert_element_type_445, 3); convert_element_type_445 = None + 
expand_26 = torch.ops.aten.expand.default(unsqueeze_26, [2, 8192, 1, 4, 128]); unsqueeze_26 = None + view_976 = torch.ops.aten.view.default(expand_26, [2, 8192, 4, 128]); expand_26 = None + unsqueeze_27 = torch.ops.aten.unsqueeze.default(view_970, 3); view_970 = None + expand_27 = torch.ops.aten.expand.default(unsqueeze_27, [2, 8192, 1, 4, 128]); unsqueeze_27 = None + view_977 = torch.ops.aten.view.default(expand_27, [2, 8192, 4, 128]); expand_27 = None + permute_146 = torch.ops.aten.permute.default(convert_element_type_444, [0, 2, 1, 3]); convert_element_type_444 = None + permute_147 = torch.ops.aten.permute.default(view_976, [0, 2, 1, 3]); view_976 = None + permute_148 = torch.ops.aten.permute.default(view_977, [0, 2, 1, 3]); view_977 = None + _scaled_dot_product_cudnn_attention_13 = torch.ops.aten._scaled_dot_product_cudnn_attention.default(permute_146, permute_147, permute_148, None, True, 0.0, True); permute_146 = permute_147 = permute_148 = None + getitem_613 = _scaled_dot_product_cudnn_attention_13[0] + getitem_614 = _scaled_dot_product_cudnn_attention_13[1] + getitem_619 = _scaled_dot_product_cudnn_attention_13[6] + getitem_620 = _scaled_dot_product_cudnn_attention_13[7]; _scaled_dot_product_cudnn_attention_13 = None + permute_149 = torch.ops.aten.permute.default(getitem_613, [0, 2, 1, 3]) + view_978 = torch.ops.aten.view.default(permute_149, [2, 8192, -1]); permute_149 = None + convert_element_type_446 = torch.ops.prims.convert_element_type.default(primals_125, torch.bfloat16) + all_gather_into_tensor_149 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_446, 32, '0'); convert_element_type_446 = None + wait_tensor_176 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_149); all_gather_into_tensor_149 = None + permute_150 = torch.ops.aten.permute.default(wait_tensor_176, [1, 0]); wait_tensor_176 = None + view_984 = torch.ops.aten.view.default(view_978, [16384, 512]); view_978 = None + mm_94 = 
torch.ops.aten.mm.default(view_984, permute_150); view_984 = permute_150 = None + view_985 = torch.ops.aten.view.default(mm_94, [2, 8192, 4096]); mm_94 = None + split_62 = torch.ops.aten.split.Tensor(view_985, 1024, 1); view_985 = None + getitem_622 = split_62[0] + getitem_623 = split_62[1] + getitem_624 = split_62[2] + getitem_625 = split_62[3] + getitem_626 = split_62[4] + getitem_627 = split_62[5] + getitem_628 = split_62[6] + getitem_629 = split_62[7]; split_62 = None + cat_54 = torch.ops.aten.cat.default([getitem_622, getitem_623, getitem_624, getitem_625, getitem_626, getitem_627, getitem_628, getitem_629]); getitem_622 = getitem_623 = getitem_624 = getitem_625 = getitem_626 = getitem_627 = getitem_628 = getitem_629 = None + reduce_scatter_tensor_27 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_54, 'sum', 8, '1'); cat_54 = None + wait_tensor_177 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_27) + add_53 = torch.ops.aten.add.Tensor(add_51, wait_tensor_177); wait_tensor_177 = None + convert_element_type_449 = torch.ops.prims.convert_element_type.default(primals_126, torch.bfloat16) + all_gather_into_tensor_150 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_449, 32, '0'); convert_element_type_449 = None + wait_tensor_178 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_150); all_gather_into_tensor_150 = None + convert_element_type_450 = torch.ops.prims.convert_element_type.default(add_53, torch.float32) + pow_28 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_450, 2) + mean_27 = torch.ops.aten.mean.dim(pow_28, [2], True); pow_28 = None + add_54 = torch.ops.aten.add.Scalar(mean_27, 1e-05); mean_27 = None + rsqrt_27 = torch.ops.aten.rsqrt.default(add_54); add_54 = None + mul_108 = torch.ops.aten.mul.Tensor(convert_element_type_450, rsqrt_27); convert_element_type_450 = rsqrt_27 = None + mul_109 = torch.ops.aten.mul.Tensor(mul_108, wait_tensor_178); mul_108 = 
wait_tensor_178 = None + convert_element_type_451 = torch.ops.prims.convert_element_type.default(mul_109, torch.bfloat16); mul_109 = None + all_gather_into_tensor_151 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_451, 8, '1'); convert_element_type_451 = None + wait_tensor_179 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_151); all_gather_into_tensor_151 = None + split_63 = torch.ops.aten.split.Tensor(wait_tensor_179, 2); wait_tensor_179 = None + getitem_630 = split_63[0] + getitem_631 = split_63[1] + getitem_632 = split_63[2] + getitem_633 = split_63[3] + getitem_634 = split_63[4] + getitem_635 = split_63[5] + getitem_636 = split_63[6] + getitem_637 = split_63[7]; split_63 = None + cat_55 = torch.ops.aten.cat.default([getitem_630, getitem_631, getitem_632, getitem_633, getitem_634, getitem_635, getitem_636, getitem_637], 1); getitem_630 = getitem_631 = getitem_632 = getitem_633 = getitem_634 = getitem_635 = getitem_636 = getitem_637 = None + convert_element_type_452 = torch.ops.prims.convert_element_type.default(primals_127, torch.bfloat16) + all_gather_into_tensor_152 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_452, 32, '0'); convert_element_type_452 = None + wait_tensor_180 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_152); all_gather_into_tensor_152 = None + permute_151 = torch.ops.aten.permute.default(wait_tensor_180, [1, 0]); wait_tensor_180 = None + view_996 = torch.ops.aten.view.default(cat_55, [16384, 4096]); cat_55 = None + mm_95 = torch.ops.aten.mm.default(view_996, permute_151); permute_151 = None + view_997 = torch.ops.aten.view.default(mm_95, [2, 8192, 1792]) + convert_element_type_455 = torch.ops.prims.convert_element_type.default(view_997, torch.float32); view_997 = None + sigmoid_13 = torch.ops.aten.sigmoid.default(convert_element_type_455) + mul_110 = torch.ops.aten.mul.Tensor(convert_element_type_455, sigmoid_13); 
convert_element_type_455 = sigmoid_13 = None + convert_element_type_456 = torch.ops.prims.convert_element_type.default(mul_110, torch.bfloat16); mul_110 = None + convert_element_type_457 = torch.ops.prims.convert_element_type.default(primals_128, torch.bfloat16) + all_gather_into_tensor_153 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_457, 32, '0'); convert_element_type_457 = None + wait_tensor_181 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_153); all_gather_into_tensor_153 = None + permute_152 = torch.ops.aten.permute.default(wait_tensor_181, [1, 0]); wait_tensor_181 = None + mm_96 = torch.ops.aten.mm.default(view_996, permute_152); view_996 = permute_152 = None + view_1004 = torch.ops.aten.view.default(mm_96, [2, 8192, 1792]); mm_96 = None + mul_111 = torch.ops.aten.mul.Tensor(convert_element_type_456, view_1004); convert_element_type_456 = view_1004 = None + convert_element_type_460 = torch.ops.prims.convert_element_type.default(primals_129, torch.bfloat16) + all_gather_into_tensor_154 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_460, 32, '0'); convert_element_type_460 = None + wait_tensor_182 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_154); all_gather_into_tensor_154 = None + permute_153 = torch.ops.aten.permute.default(wait_tensor_182, [1, 0]); wait_tensor_182 = None + view_1011 = torch.ops.aten.view.default(mul_111, [16384, 1792]); mul_111 = None + mm_97 = torch.ops.aten.mm.default(view_1011, permute_153); view_1011 = permute_153 = None + view_1012 = torch.ops.aten.view.default(mm_97, [2, 8192, 4096]); mm_97 = None + split_64 = torch.ops.aten.split.Tensor(view_1012, 1024, 1); view_1012 = None + getitem_638 = split_64[0] + getitem_639 = split_64[1] + getitem_640 = split_64[2] + getitem_641 = split_64[3] + getitem_642 = split_64[4] + getitem_643 = split_64[5] + getitem_644 = split_64[6] + getitem_645 = split_64[7]; split_64 = None + 
cat_56 = torch.ops.aten.cat.default([getitem_638, getitem_639, getitem_640, getitem_641, getitem_642, getitem_643, getitem_644, getitem_645]); getitem_638 = getitem_639 = getitem_640 = getitem_641 = getitem_642 = getitem_643 = getitem_644 = getitem_645 = None + reduce_scatter_tensor_28 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_56, 'sum', 8, '1'); cat_56 = None + wait_tensor_183 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_28); reduce_scatter_tensor_28 = None + add_55 = torch.ops.aten.add.Tensor(add_53, wait_tensor_183); add_53 = wait_tensor_183 = None + convert_element_type_463 = torch.ops.prims.convert_element_type.default(primals_130, torch.bfloat16) + all_gather_into_tensor_155 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_463, 32, '0'); convert_element_type_463 = None + wait_tensor_184 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_155); all_gather_into_tensor_155 = None + convert_element_type_464 = torch.ops.prims.convert_element_type.default(add_55, torch.float32) + pow_29 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_464, 2) + mean_28 = torch.ops.aten.mean.dim(pow_29, [2], True); pow_29 = None + add_56 = torch.ops.aten.add.Scalar(mean_28, 1e-05); mean_28 = None + rsqrt_28 = torch.ops.aten.rsqrt.default(add_56); add_56 = None + mul_112 = torch.ops.aten.mul.Tensor(convert_element_type_464, rsqrt_28); convert_element_type_464 = rsqrt_28 = None + mul_113 = torch.ops.aten.mul.Tensor(mul_112, wait_tensor_184); mul_112 = wait_tensor_184 = None + convert_element_type_465 = torch.ops.prims.convert_element_type.default(mul_113, torch.bfloat16); mul_113 = None + all_gather_into_tensor_156 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_465, 8, '1'); convert_element_type_465 = None + wait_tensor_185 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_156); all_gather_into_tensor_156 = None + split_65 = 
torch.ops.aten.split.Tensor(wait_tensor_185, 2); wait_tensor_185 = None + getitem_646 = split_65[0] + getitem_647 = split_65[1] + getitem_648 = split_65[2] + getitem_649 = split_65[3] + getitem_650 = split_65[4] + getitem_651 = split_65[5] + getitem_652 = split_65[6] + getitem_653 = split_65[7]; split_65 = None + cat_57 = torch.ops.aten.cat.default([getitem_646, getitem_647, getitem_648, getitem_649, getitem_650, getitem_651, getitem_652, getitem_653], 1); getitem_646 = getitem_647 = getitem_648 = getitem_649 = getitem_650 = getitem_651 = getitem_652 = getitem_653 = None + convert_element_type_466 = torch.ops.prims.convert_element_type.default(primals_131, torch.bfloat16) + all_gather_into_tensor_157 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_466, 32, '0'); convert_element_type_466 = None + wait_tensor_186 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_157); all_gather_into_tensor_157 = None + permute_154 = torch.ops.aten.permute.default(wait_tensor_186, [1, 0]); wait_tensor_186 = None + view_1023 = torch.ops.aten.view.default(cat_57, [16384, 4096]); cat_57 = None + mm_98 = torch.ops.aten.mm.default(view_1023, permute_154); permute_154 = None + view_1024 = torch.ops.aten.view.default(mm_98, [2, 8192, 512]) + convert_element_type_469 = torch.ops.prims.convert_element_type.default(primals_132, torch.bfloat16) + all_gather_into_tensor_158 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_469, 32, '0'); convert_element_type_469 = None + wait_tensor_187 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_158); all_gather_into_tensor_158 = None + permute_155 = torch.ops.aten.permute.default(wait_tensor_187, [1, 0]); wait_tensor_187 = None + mm_99 = torch.ops.aten.mm.default(view_1023, permute_155); permute_155 = None + view_1031 = torch.ops.aten.view.default(mm_99, [2, 8192, 128]); mm_99 = None + convert_element_type_472 = 
torch.ops.prims.convert_element_type.default(primals_133, torch.bfloat16) + all_gather_into_tensor_159 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_472, 32, '0'); convert_element_type_472 = None + wait_tensor_188 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_159); all_gather_into_tensor_159 = None + permute_156 = torch.ops.aten.permute.default(wait_tensor_188, [1, 0]); wait_tensor_188 = None + mm_100 = torch.ops.aten.mm.default(view_1023, permute_156); view_1023 = permute_156 = None + view_1038 = torch.ops.aten.view.default(mm_100, [2, 8192, 128]) + view_1040 = torch.ops.aten.view.default(view_1024, [2, 8192, -1, 128]); view_1024 = None + view_1041 = torch.ops.aten.view.default(view_1031, [2, 8192, -1, 128]); view_1031 = None + view_1042 = torch.ops.aten.view.default(view_1038, [2, 8192, -1, 128]); view_1038 = None + convert_element_type_475 = torch.ops.prims.convert_element_type.default(view_1040, torch.float32); view_1040 = None + view_1043 = torch.ops.aten.view.default(convert_element_type_475, [2, 8192, 4, -1, 2]); convert_element_type_475 = None + view_as_complex_28 = torch.ops.aten.view_as_complex.default(view_1043); view_1043 = None + convert_element_type_476 = torch.ops.prims.convert_element_type.default(view_1041, torch.float32); view_1041 = None + view_1044 = torch.ops.aten.view.default(convert_element_type_476, [2, 8192, 1, -1, 2]); convert_element_type_476 = None + view_as_complex_29 = torch.ops.aten.view_as_complex.default(view_1044); view_1044 = None + mul_114 = torch.ops.aten.mul.Tensor(view_as_complex_28, view_37); view_as_complex_28 = None + view_as_real_28 = torch.ops.aten.view_as_real.default(mul_114); mul_114 = None + view_1046 = torch.ops.aten.view.default(view_as_real_28, [2, 8192, 4, 128]); view_as_real_28 = None + mul_115 = torch.ops.aten.mul.Tensor(view_as_complex_29, view_37); view_as_complex_29 = None + view_as_real_29 = torch.ops.aten.view_as_real.default(mul_115); mul_115 = 
None + view_1047 = torch.ops.aten.view.default(view_as_real_29, [2, 8192, 1, 128]); view_as_real_29 = None + convert_element_type_477 = torch.ops.prims.convert_element_type.default(view_1046, torch.bfloat16); view_1046 = None + convert_element_type_478 = torch.ops.prims.convert_element_type.default(view_1047, torch.bfloat16); view_1047 = None + unsqueeze_28 = torch.ops.aten.unsqueeze.default(convert_element_type_478, 3); convert_element_type_478 = None + expand_28 = torch.ops.aten.expand.default(unsqueeze_28, [2, 8192, 1, 4, 128]); unsqueeze_28 = None + view_1048 = torch.ops.aten.view.default(expand_28, [2, 8192, 4, 128]); expand_28 = None + unsqueeze_29 = torch.ops.aten.unsqueeze.default(view_1042, 3); view_1042 = None + expand_29 = torch.ops.aten.expand.default(unsqueeze_29, [2, 8192, 1, 4, 128]); unsqueeze_29 = None + view_1049 = torch.ops.aten.view.default(expand_29, [2, 8192, 4, 128]); expand_29 = None + permute_157 = torch.ops.aten.permute.default(convert_element_type_477, [0, 2, 1, 3]); convert_element_type_477 = None + permute_158 = torch.ops.aten.permute.default(view_1048, [0, 2, 1, 3]); view_1048 = None + permute_159 = torch.ops.aten.permute.default(view_1049, [0, 2, 1, 3]); view_1049 = None + _scaled_dot_product_cudnn_attention_14 = torch.ops.aten._scaled_dot_product_cudnn_attention.default(permute_157, permute_158, permute_159, None, True, 0.0, True); permute_157 = permute_158 = permute_159 = None + getitem_654 = _scaled_dot_product_cudnn_attention_14[0] + getitem_655 = _scaled_dot_product_cudnn_attention_14[1] + getitem_660 = _scaled_dot_product_cudnn_attention_14[6] + getitem_661 = _scaled_dot_product_cudnn_attention_14[7]; _scaled_dot_product_cudnn_attention_14 = None + permute_160 = torch.ops.aten.permute.default(getitem_654, [0, 2, 1, 3]) + view_1050 = torch.ops.aten.view.default(permute_160, [2, 8192, -1]); permute_160 = None + convert_element_type_479 = torch.ops.prims.convert_element_type.default(primals_134, torch.bfloat16) + 
all_gather_into_tensor_160 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_479, 32, '0'); convert_element_type_479 = None + wait_tensor_189 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_160); all_gather_into_tensor_160 = None + permute_161 = torch.ops.aten.permute.default(wait_tensor_189, [1, 0]); wait_tensor_189 = None + view_1056 = torch.ops.aten.view.default(view_1050, [16384, 512]); view_1050 = None + mm_101 = torch.ops.aten.mm.default(view_1056, permute_161); view_1056 = permute_161 = None + view_1057 = torch.ops.aten.view.default(mm_101, [2, 8192, 4096]); mm_101 = None + split_66 = torch.ops.aten.split.Tensor(view_1057, 1024, 1); view_1057 = None + getitem_663 = split_66[0] + getitem_664 = split_66[1] + getitem_665 = split_66[2] + getitem_666 = split_66[3] + getitem_667 = split_66[4] + getitem_668 = split_66[5] + getitem_669 = split_66[6] + getitem_670 = split_66[7]; split_66 = None + cat_58 = torch.ops.aten.cat.default([getitem_663, getitem_664, getitem_665, getitem_666, getitem_667, getitem_668, getitem_669, getitem_670]); getitem_663 = getitem_664 = getitem_665 = getitem_666 = getitem_667 = getitem_668 = getitem_669 = getitem_670 = None + reduce_scatter_tensor_29 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_58, 'sum', 8, '1'); cat_58 = None + wait_tensor_190 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_29) + add_57 = torch.ops.aten.add.Tensor(add_55, wait_tensor_190); wait_tensor_190 = None + convert_element_type_482 = torch.ops.prims.convert_element_type.default(primals_135, torch.bfloat16) + all_gather_into_tensor_161 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_482, 32, '0'); convert_element_type_482 = None + wait_tensor_191 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_161); all_gather_into_tensor_161 = None + convert_element_type_483 = torch.ops.prims.convert_element_type.default(add_57, 
torch.float32) + pow_30 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_483, 2) + mean_29 = torch.ops.aten.mean.dim(pow_30, [2], True); pow_30 = None + add_58 = torch.ops.aten.add.Scalar(mean_29, 1e-05); mean_29 = None + rsqrt_29 = torch.ops.aten.rsqrt.default(add_58); add_58 = None + mul_116 = torch.ops.aten.mul.Tensor(convert_element_type_483, rsqrt_29); convert_element_type_483 = rsqrt_29 = None + mul_117 = torch.ops.aten.mul.Tensor(mul_116, wait_tensor_191); mul_116 = wait_tensor_191 = None + convert_element_type_484 = torch.ops.prims.convert_element_type.default(mul_117, torch.bfloat16); mul_117 = None + all_gather_into_tensor_162 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_484, 8, '1'); convert_element_type_484 = None + wait_tensor_192 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_162); all_gather_into_tensor_162 = None + split_67 = torch.ops.aten.split.Tensor(wait_tensor_192, 2); wait_tensor_192 = None + getitem_671 = split_67[0] + getitem_672 = split_67[1] + getitem_673 = split_67[2] + getitem_674 = split_67[3] + getitem_675 = split_67[4] + getitem_676 = split_67[5] + getitem_677 = split_67[6] + getitem_678 = split_67[7]; split_67 = None + cat_59 = torch.ops.aten.cat.default([getitem_671, getitem_672, getitem_673, getitem_674, getitem_675, getitem_676, getitem_677, getitem_678], 1); getitem_671 = getitem_672 = getitem_673 = getitem_674 = getitem_675 = getitem_676 = getitem_677 = getitem_678 = None + convert_element_type_485 = torch.ops.prims.convert_element_type.default(primals_136, torch.bfloat16) + all_gather_into_tensor_163 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_485, 32, '0'); convert_element_type_485 = None + wait_tensor_193 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_163); all_gather_into_tensor_163 = None + permute_162 = torch.ops.aten.permute.default(wait_tensor_193, [1, 0]); wait_tensor_193 = None + view_1068 = 
torch.ops.aten.view.default(cat_59, [16384, 4096]); cat_59 = None + mm_102 = torch.ops.aten.mm.default(view_1068, permute_162); permute_162 = None + view_1069 = torch.ops.aten.view.default(mm_102, [2, 8192, 1792]) + convert_element_type_488 = torch.ops.prims.convert_element_type.default(view_1069, torch.float32); view_1069 = None + sigmoid_14 = torch.ops.aten.sigmoid.default(convert_element_type_488) + mul_118 = torch.ops.aten.mul.Tensor(convert_element_type_488, sigmoid_14); convert_element_type_488 = sigmoid_14 = None + convert_element_type_489 = torch.ops.prims.convert_element_type.default(mul_118, torch.bfloat16); mul_118 = None + convert_element_type_490 = torch.ops.prims.convert_element_type.default(primals_137, torch.bfloat16) + all_gather_into_tensor_164 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_490, 32, '0'); convert_element_type_490 = None + wait_tensor_194 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_164); all_gather_into_tensor_164 = None + permute_163 = torch.ops.aten.permute.default(wait_tensor_194, [1, 0]); wait_tensor_194 = None + mm_103 = torch.ops.aten.mm.default(view_1068, permute_163); view_1068 = permute_163 = None + view_1076 = torch.ops.aten.view.default(mm_103, [2, 8192, 1792]); mm_103 = None + mul_119 = torch.ops.aten.mul.Tensor(convert_element_type_489, view_1076); convert_element_type_489 = view_1076 = None + convert_element_type_493 = torch.ops.prims.convert_element_type.default(primals_138, torch.bfloat16) + all_gather_into_tensor_165 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_493, 32, '0'); convert_element_type_493 = None + wait_tensor_195 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_165); all_gather_into_tensor_165 = None + permute_164 = torch.ops.aten.permute.default(wait_tensor_195, [1, 0]); wait_tensor_195 = None + view_1083 = torch.ops.aten.view.default(mul_119, [16384, 1792]); mul_119 = None + mm_104 
= torch.ops.aten.mm.default(view_1083, permute_164); view_1083 = permute_164 = None + view_1084 = torch.ops.aten.view.default(mm_104, [2, 8192, 4096]); mm_104 = None + split_68 = torch.ops.aten.split.Tensor(view_1084, 1024, 1); view_1084 = None + getitem_679 = split_68[0] + getitem_680 = split_68[1] + getitem_681 = split_68[2] + getitem_682 = split_68[3] + getitem_683 = split_68[4] + getitem_684 = split_68[5] + getitem_685 = split_68[6] + getitem_686 = split_68[7]; split_68 = None + cat_60 = torch.ops.aten.cat.default([getitem_679, getitem_680, getitem_681, getitem_682, getitem_683, getitem_684, getitem_685, getitem_686]); getitem_679 = getitem_680 = getitem_681 = getitem_682 = getitem_683 = getitem_684 = getitem_685 = getitem_686 = None + reduce_scatter_tensor_30 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_60, 'sum', 8, '1'); cat_60 = None + wait_tensor_196 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_30); reduce_scatter_tensor_30 = None + add_59 = torch.ops.aten.add.Tensor(add_57, wait_tensor_196); add_57 = wait_tensor_196 = None + convert_element_type_496 = torch.ops.prims.convert_element_type.default(primals_139, torch.bfloat16) + all_gather_into_tensor_166 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_496, 32, '0'); convert_element_type_496 = None + wait_tensor_197 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_166); all_gather_into_tensor_166 = None + convert_element_type_497 = torch.ops.prims.convert_element_type.default(add_59, torch.float32) + pow_31 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_497, 2) + mean_30 = torch.ops.aten.mean.dim(pow_31, [2], True); pow_31 = None + add_60 = torch.ops.aten.add.Scalar(mean_30, 1e-05); mean_30 = None + rsqrt_30 = torch.ops.aten.rsqrt.default(add_60); add_60 = None + mul_120 = torch.ops.aten.mul.Tensor(convert_element_type_497, rsqrt_30); convert_element_type_497 = rsqrt_30 = None + mul_121 = 
torch.ops.aten.mul.Tensor(mul_120, wait_tensor_197); mul_120 = wait_tensor_197 = None + convert_element_type_498 = torch.ops.prims.convert_element_type.default(mul_121, torch.bfloat16); mul_121 = None + all_gather_into_tensor_167 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_498, 8, '1'); convert_element_type_498 = None + wait_tensor_198 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_167); all_gather_into_tensor_167 = None + split_69 = torch.ops.aten.split.Tensor(wait_tensor_198, 2); wait_tensor_198 = None + getitem_687 = split_69[0] + getitem_688 = split_69[1] + getitem_689 = split_69[2] + getitem_690 = split_69[3] + getitem_691 = split_69[4] + getitem_692 = split_69[5] + getitem_693 = split_69[6] + getitem_694 = split_69[7]; split_69 = None + cat_61 = torch.ops.aten.cat.default([getitem_687, getitem_688, getitem_689, getitem_690, getitem_691, getitem_692, getitem_693, getitem_694], 1); getitem_687 = getitem_688 = getitem_689 = getitem_690 = getitem_691 = getitem_692 = getitem_693 = getitem_694 = None + convert_element_type_499 = torch.ops.prims.convert_element_type.default(primals_140, torch.bfloat16) + all_gather_into_tensor_168 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_499, 32, '0'); convert_element_type_499 = None + wait_tensor_199 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_168); all_gather_into_tensor_168 = None + permute_165 = torch.ops.aten.permute.default(wait_tensor_199, [1, 0]); wait_tensor_199 = None + view_1095 = torch.ops.aten.view.default(cat_61, [16384, 4096]); cat_61 = None + mm_105 = torch.ops.aten.mm.default(view_1095, permute_165); permute_165 = None + view_1096 = torch.ops.aten.view.default(mm_105, [2, 8192, 512]) + convert_element_type_502 = torch.ops.prims.convert_element_type.default(primals_141, torch.bfloat16) + all_gather_into_tensor_169 = 
torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_502, 32, '0'); convert_element_type_502 = None + wait_tensor_200 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_169); all_gather_into_tensor_169 = None + permute_166 = torch.ops.aten.permute.default(wait_tensor_200, [1, 0]); wait_tensor_200 = None + mm_106 = torch.ops.aten.mm.default(view_1095, permute_166); permute_166 = None + view_1103 = torch.ops.aten.view.default(mm_106, [2, 8192, 128]); mm_106 = None + convert_element_type_505 = torch.ops.prims.convert_element_type.default(primals_142, torch.bfloat16) + all_gather_into_tensor_170 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_505, 32, '0'); convert_element_type_505 = None + wait_tensor_201 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_170); all_gather_into_tensor_170 = None + permute_167 = torch.ops.aten.permute.default(wait_tensor_201, [1, 0]); wait_tensor_201 = None + mm_107 = torch.ops.aten.mm.default(view_1095, permute_167); view_1095 = permute_167 = None + view_1110 = torch.ops.aten.view.default(mm_107, [2, 8192, 128]) + view_1112 = torch.ops.aten.view.default(view_1096, [2, 8192, -1, 128]); view_1096 = None + view_1113 = torch.ops.aten.view.default(view_1103, [2, 8192, -1, 128]); view_1103 = None + view_1114 = torch.ops.aten.view.default(view_1110, [2, 8192, -1, 128]); view_1110 = None + convert_element_type_508 = torch.ops.prims.convert_element_type.default(view_1112, torch.float32); view_1112 = None + view_1115 = torch.ops.aten.view.default(convert_element_type_508, [2, 8192, 4, -1, 2]); convert_element_type_508 = None + view_as_complex_30 = torch.ops.aten.view_as_complex.default(view_1115); view_1115 = None + convert_element_type_509 = torch.ops.prims.convert_element_type.default(view_1113, torch.float32); view_1113 = None + view_1116 = torch.ops.aten.view.default(convert_element_type_509, [2, 8192, 1, -1, 2]); convert_element_type_509 = 
None + view_as_complex_31 = torch.ops.aten.view_as_complex.default(view_1116); view_1116 = None + mul_122 = torch.ops.aten.mul.Tensor(view_as_complex_30, view_37); view_as_complex_30 = None + view_as_real_30 = torch.ops.aten.view_as_real.default(mul_122); mul_122 = None + view_1118 = torch.ops.aten.view.default(view_as_real_30, [2, 8192, 4, 128]); view_as_real_30 = None + mul_123 = torch.ops.aten.mul.Tensor(view_as_complex_31, view_37); view_as_complex_31 = None + view_as_real_31 = torch.ops.aten.view_as_real.default(mul_123); mul_123 = None + view_1119 = torch.ops.aten.view.default(view_as_real_31, [2, 8192, 1, 128]); view_as_real_31 = None + convert_element_type_510 = torch.ops.prims.convert_element_type.default(view_1118, torch.bfloat16); view_1118 = None + convert_element_type_511 = torch.ops.prims.convert_element_type.default(view_1119, torch.bfloat16); view_1119 = None + unsqueeze_30 = torch.ops.aten.unsqueeze.default(convert_element_type_511, 3); convert_element_type_511 = None + expand_30 = torch.ops.aten.expand.default(unsqueeze_30, [2, 8192, 1, 4, 128]); unsqueeze_30 = None + view_1120 = torch.ops.aten.view.default(expand_30, [2, 8192, 4, 128]); expand_30 = None + unsqueeze_31 = torch.ops.aten.unsqueeze.default(view_1114, 3); view_1114 = None + expand_31 = torch.ops.aten.expand.default(unsqueeze_31, [2, 8192, 1, 4, 128]); unsqueeze_31 = None + view_1121 = torch.ops.aten.view.default(expand_31, [2, 8192, 4, 128]); expand_31 = None + permute_168 = torch.ops.aten.permute.default(convert_element_type_510, [0, 2, 1, 3]); convert_element_type_510 = None + permute_169 = torch.ops.aten.permute.default(view_1120, [0, 2, 1, 3]); view_1120 = None + permute_170 = torch.ops.aten.permute.default(view_1121, [0, 2, 1, 3]); view_1121 = None + _scaled_dot_product_cudnn_attention_15 = torch.ops.aten._scaled_dot_product_cudnn_attention.default(permute_168, permute_169, permute_170, None, True, 0.0, True); permute_168 = permute_169 = permute_170 = None + getitem_695 = 
_scaled_dot_product_cudnn_attention_15[0] + getitem_696 = _scaled_dot_product_cudnn_attention_15[1] + getitem_701 = _scaled_dot_product_cudnn_attention_15[6] + getitem_702 = _scaled_dot_product_cudnn_attention_15[7]; _scaled_dot_product_cudnn_attention_15 = None + permute_171 = torch.ops.aten.permute.default(getitem_695, [0, 2, 1, 3]) + view_1122 = torch.ops.aten.view.default(permute_171, [2, 8192, -1]); permute_171 = None + convert_element_type_512 = torch.ops.prims.convert_element_type.default(primals_143, torch.bfloat16) + all_gather_into_tensor_171 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_512, 32, '0'); convert_element_type_512 = None + wait_tensor_202 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_171); all_gather_into_tensor_171 = None + permute_172 = torch.ops.aten.permute.default(wait_tensor_202, [1, 0]); wait_tensor_202 = None + view_1128 = torch.ops.aten.view.default(view_1122, [16384, 512]); view_1122 = None + mm_108 = torch.ops.aten.mm.default(view_1128, permute_172); view_1128 = permute_172 = None + view_1129 = torch.ops.aten.view.default(mm_108, [2, 8192, 4096]); mm_108 = None + split_70 = torch.ops.aten.split.Tensor(view_1129, 1024, 1); view_1129 = None + getitem_704 = split_70[0] + getitem_705 = split_70[1] + getitem_706 = split_70[2] + getitem_707 = split_70[3] + getitem_708 = split_70[4] + getitem_709 = split_70[5] + getitem_710 = split_70[6] + getitem_711 = split_70[7]; split_70 = None + cat_62 = torch.ops.aten.cat.default([getitem_704, getitem_705, getitem_706, getitem_707, getitem_708, getitem_709, getitem_710, getitem_711]); getitem_704 = getitem_705 = getitem_706 = getitem_707 = getitem_708 = getitem_709 = getitem_710 = getitem_711 = None + reduce_scatter_tensor_31 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_62, 'sum', 8, '1'); cat_62 = None + wait_tensor_203 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_31) + add_61 = 
torch.ops.aten.add.Tensor(add_59, wait_tensor_203); wait_tensor_203 = None + convert_element_type_515 = torch.ops.prims.convert_element_type.default(primals_144, torch.bfloat16) + all_gather_into_tensor_172 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_515, 32, '0'); convert_element_type_515 = None + wait_tensor_204 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_172); all_gather_into_tensor_172 = None + convert_element_type_516 = torch.ops.prims.convert_element_type.default(add_61, torch.float32) + pow_32 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_516, 2) + mean_31 = torch.ops.aten.mean.dim(pow_32, [2], True); pow_32 = None + add_62 = torch.ops.aten.add.Scalar(mean_31, 1e-05); mean_31 = None + rsqrt_31 = torch.ops.aten.rsqrt.default(add_62); add_62 = None + mul_124 = torch.ops.aten.mul.Tensor(convert_element_type_516, rsqrt_31); convert_element_type_516 = rsqrt_31 = None + mul_125 = torch.ops.aten.mul.Tensor(mul_124, wait_tensor_204); mul_124 = wait_tensor_204 = None + convert_element_type_517 = torch.ops.prims.convert_element_type.default(mul_125, torch.bfloat16); mul_125 = None + all_gather_into_tensor_173 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_517, 8, '1'); convert_element_type_517 = None + wait_tensor_205 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_173); all_gather_into_tensor_173 = None + split_71 = torch.ops.aten.split.Tensor(wait_tensor_205, 2); wait_tensor_205 = None + getitem_712 = split_71[0] + getitem_713 = split_71[1] + getitem_714 = split_71[2] + getitem_715 = split_71[3] + getitem_716 = split_71[4] + getitem_717 = split_71[5] + getitem_718 = split_71[6] + getitem_719 = split_71[7]; split_71 = None + cat_63 = torch.ops.aten.cat.default([getitem_712, getitem_713, getitem_714, getitem_715, getitem_716, getitem_717, getitem_718, getitem_719], 1); getitem_712 = getitem_713 = getitem_714 = getitem_715 = getitem_716 = 
getitem_717 = getitem_718 = getitem_719 = None + convert_element_type_518 = torch.ops.prims.convert_element_type.default(primals_145, torch.bfloat16) + all_gather_into_tensor_174 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_518, 32, '0'); convert_element_type_518 = None + wait_tensor_206 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_174); all_gather_into_tensor_174 = None + permute_173 = torch.ops.aten.permute.default(wait_tensor_206, [1, 0]); wait_tensor_206 = None + view_1140 = torch.ops.aten.view.default(cat_63, [16384, 4096]); cat_63 = None + mm_109 = torch.ops.aten.mm.default(view_1140, permute_173); permute_173 = None + view_1141 = torch.ops.aten.view.default(mm_109, [2, 8192, 1792]) + convert_element_type_521 = torch.ops.prims.convert_element_type.default(view_1141, torch.float32); view_1141 = None + sigmoid_15 = torch.ops.aten.sigmoid.default(convert_element_type_521) + mul_126 = torch.ops.aten.mul.Tensor(convert_element_type_521, sigmoid_15); convert_element_type_521 = sigmoid_15 = None + convert_element_type_522 = torch.ops.prims.convert_element_type.default(mul_126, torch.bfloat16); mul_126 = None + convert_element_type_523 = torch.ops.prims.convert_element_type.default(primals_146, torch.bfloat16) + all_gather_into_tensor_175 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_523, 32, '0'); convert_element_type_523 = None + wait_tensor_207 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_175); all_gather_into_tensor_175 = None + permute_174 = torch.ops.aten.permute.default(wait_tensor_207, [1, 0]); wait_tensor_207 = None + mm_110 = torch.ops.aten.mm.default(view_1140, permute_174); view_1140 = permute_174 = None + view_1148 = torch.ops.aten.view.default(mm_110, [2, 8192, 1792]); mm_110 = None + mul_127 = torch.ops.aten.mul.Tensor(convert_element_type_522, view_1148); convert_element_type_522 = view_1148 = None + convert_element_type_526 = 
torch.ops.prims.convert_element_type.default(primals_147, torch.bfloat16) + all_gather_into_tensor_176 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_526, 32, '0'); convert_element_type_526 = None + wait_tensor_208 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_176); all_gather_into_tensor_176 = None + permute_175 = torch.ops.aten.permute.default(wait_tensor_208, [1, 0]); wait_tensor_208 = None + view_1155 = torch.ops.aten.view.default(mul_127, [16384, 1792]); mul_127 = None + mm_111 = torch.ops.aten.mm.default(view_1155, permute_175); view_1155 = permute_175 = None + view_1156 = torch.ops.aten.view.default(mm_111, [2, 8192, 4096]); mm_111 = None + split_72 = torch.ops.aten.split.Tensor(view_1156, 1024, 1); view_1156 = None + getitem_720 = split_72[0] + getitem_721 = split_72[1] + getitem_722 = split_72[2] + getitem_723 = split_72[3] + getitem_724 = split_72[4] + getitem_725 = split_72[5] + getitem_726 = split_72[6] + getitem_727 = split_72[7]; split_72 = None + cat_64 = torch.ops.aten.cat.default([getitem_720, getitem_721, getitem_722, getitem_723, getitem_724, getitem_725, getitem_726, getitem_727]); getitem_720 = getitem_721 = getitem_722 = getitem_723 = getitem_724 = getitem_725 = getitem_726 = getitem_727 = None + reduce_scatter_tensor_32 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_64, 'sum', 8, '1'); cat_64 = None + wait_tensor_209 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_32); reduce_scatter_tensor_32 = None + add_63 = torch.ops.aten.add.Tensor(add_61, wait_tensor_209); add_61 = wait_tensor_209 = None + convert_element_type_529 = torch.ops.prims.convert_element_type.default(primals_148, torch.bfloat16) + all_gather_into_tensor_177 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_529, 32, '0'); convert_element_type_529 = None + wait_tensor_210 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_177); 
all_gather_into_tensor_177 = None + convert_element_type_530 = torch.ops.prims.convert_element_type.default(add_63, torch.float32) + pow_33 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_530, 2) + mean_32 = torch.ops.aten.mean.dim(pow_33, [2], True); pow_33 = None + add_64 = torch.ops.aten.add.Scalar(mean_32, 1e-05); mean_32 = None + rsqrt_32 = torch.ops.aten.rsqrt.default(add_64); add_64 = None + mul_128 = torch.ops.aten.mul.Tensor(convert_element_type_530, rsqrt_32); convert_element_type_530 = rsqrt_32 = None + mul_129 = torch.ops.aten.mul.Tensor(mul_128, wait_tensor_210); mul_128 = wait_tensor_210 = None + convert_element_type_531 = torch.ops.prims.convert_element_type.default(mul_129, torch.bfloat16); mul_129 = None + all_gather_into_tensor_178 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_531, 8, '1'); convert_element_type_531 = None + wait_tensor_211 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_178); all_gather_into_tensor_178 = None + split_73 = torch.ops.aten.split.Tensor(wait_tensor_211, 2); wait_tensor_211 = None + getitem_728 = split_73[0] + getitem_729 = split_73[1] + getitem_730 = split_73[2] + getitem_731 = split_73[3] + getitem_732 = split_73[4] + getitem_733 = split_73[5] + getitem_734 = split_73[6] + getitem_735 = split_73[7]; split_73 = None + cat_65 = torch.ops.aten.cat.default([getitem_728, getitem_729, getitem_730, getitem_731, getitem_732, getitem_733, getitem_734, getitem_735], 1); getitem_728 = getitem_729 = getitem_730 = getitem_731 = getitem_732 = getitem_733 = getitem_734 = getitem_735 = None + convert_element_type_532 = torch.ops.prims.convert_element_type.default(primals_149, torch.bfloat16) + all_gather_into_tensor_179 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_532, 32, '0'); convert_element_type_532 = None + wait_tensor_212 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_179); all_gather_into_tensor_179 
= None + permute_176 = torch.ops.aten.permute.default(wait_tensor_212, [1, 0]); wait_tensor_212 = None + view_1167 = torch.ops.aten.view.default(cat_65, [16384, 4096]); cat_65 = None + mm_112 = torch.ops.aten.mm.default(view_1167, permute_176); permute_176 = None + view_1168 = torch.ops.aten.view.default(mm_112, [2, 8192, 512]) + convert_element_type_535 = torch.ops.prims.convert_element_type.default(primals_150, torch.bfloat16) + all_gather_into_tensor_180 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_535, 32, '0'); convert_element_type_535 = None + wait_tensor_213 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_180); all_gather_into_tensor_180 = None + permute_177 = torch.ops.aten.permute.default(wait_tensor_213, [1, 0]); wait_tensor_213 = None + mm_113 = torch.ops.aten.mm.default(view_1167, permute_177); permute_177 = None + view_1175 = torch.ops.aten.view.default(mm_113, [2, 8192, 128]); mm_113 = None + convert_element_type_538 = torch.ops.prims.convert_element_type.default(primals_151, torch.bfloat16) + all_gather_into_tensor_181 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_538, 32, '0'); convert_element_type_538 = None + wait_tensor_214 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_181); all_gather_into_tensor_181 = None + permute_178 = torch.ops.aten.permute.default(wait_tensor_214, [1, 0]); wait_tensor_214 = None + mm_114 = torch.ops.aten.mm.default(view_1167, permute_178); view_1167 = permute_178 = None + view_1182 = torch.ops.aten.view.default(mm_114, [2, 8192, 128]) + view_1184 = torch.ops.aten.view.default(view_1168, [2, 8192, -1, 128]); view_1168 = None + view_1185 = torch.ops.aten.view.default(view_1175, [2, 8192, -1, 128]); view_1175 = None + view_1186 = torch.ops.aten.view.default(view_1182, [2, 8192, -1, 128]); view_1182 = None + convert_element_type_541 = torch.ops.prims.convert_element_type.default(view_1184, torch.float32); 
view_1184 = None + view_1187 = torch.ops.aten.view.default(convert_element_type_541, [2, 8192, 4, -1, 2]); convert_element_type_541 = None + view_as_complex_32 = torch.ops.aten.view_as_complex.default(view_1187); view_1187 = None + convert_element_type_542 = torch.ops.prims.convert_element_type.default(view_1185, torch.float32); view_1185 = None + view_1188 = torch.ops.aten.view.default(convert_element_type_542, [2, 8192, 1, -1, 2]); convert_element_type_542 = None + view_as_complex_33 = torch.ops.aten.view_as_complex.default(view_1188); view_1188 = None + mul_130 = torch.ops.aten.mul.Tensor(view_as_complex_32, view_37); view_as_complex_32 = None + view_as_real_32 = torch.ops.aten.view_as_real.default(mul_130); mul_130 = None + view_1190 = torch.ops.aten.view.default(view_as_real_32, [2, 8192, 4, 128]); view_as_real_32 = None + mul_131 = torch.ops.aten.mul.Tensor(view_as_complex_33, view_37); view_as_complex_33 = None + view_as_real_33 = torch.ops.aten.view_as_real.default(mul_131); mul_131 = None + view_1191 = torch.ops.aten.view.default(view_as_real_33, [2, 8192, 1, 128]); view_as_real_33 = None + convert_element_type_543 = torch.ops.prims.convert_element_type.default(view_1190, torch.bfloat16); view_1190 = None + convert_element_type_544 = torch.ops.prims.convert_element_type.default(view_1191, torch.bfloat16); view_1191 = None + unsqueeze_32 = torch.ops.aten.unsqueeze.default(convert_element_type_544, 3); convert_element_type_544 = None + expand_32 = torch.ops.aten.expand.default(unsqueeze_32, [2, 8192, 1, 4, 128]); unsqueeze_32 = None + view_1192 = torch.ops.aten.view.default(expand_32, [2, 8192, 4, 128]); expand_32 = None + unsqueeze_33 = torch.ops.aten.unsqueeze.default(view_1186, 3); view_1186 = None + expand_33 = torch.ops.aten.expand.default(unsqueeze_33, [2, 8192, 1, 4, 128]); unsqueeze_33 = None + view_1193 = torch.ops.aten.view.default(expand_33, [2, 8192, 4, 128]); expand_33 = None + permute_179 = 
torch.ops.aten.permute.default(convert_element_type_543, [0, 2, 1, 3]); convert_element_type_543 = None + permute_180 = torch.ops.aten.permute.default(view_1192, [0, 2, 1, 3]); view_1192 = None + permute_181 = torch.ops.aten.permute.default(view_1193, [0, 2, 1, 3]); view_1193 = None + _scaled_dot_product_cudnn_attention_16 = torch.ops.aten._scaled_dot_product_cudnn_attention.default(permute_179, permute_180, permute_181, None, True, 0.0, True); permute_179 = permute_180 = permute_181 = None + getitem_736 = _scaled_dot_product_cudnn_attention_16[0] + getitem_737 = _scaled_dot_product_cudnn_attention_16[1] + getitem_742 = _scaled_dot_product_cudnn_attention_16[6] + getitem_743 = _scaled_dot_product_cudnn_attention_16[7]; _scaled_dot_product_cudnn_attention_16 = None + permute_182 = torch.ops.aten.permute.default(getitem_736, [0, 2, 1, 3]) + view_1194 = torch.ops.aten.view.default(permute_182, [2, 8192, -1]); permute_182 = None + convert_element_type_545 = torch.ops.prims.convert_element_type.default(primals_152, torch.bfloat16) + all_gather_into_tensor_182 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_545, 32, '0'); convert_element_type_545 = None + wait_tensor_215 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_182); all_gather_into_tensor_182 = None + permute_183 = torch.ops.aten.permute.default(wait_tensor_215, [1, 0]); wait_tensor_215 = None + view_1200 = torch.ops.aten.view.default(view_1194, [16384, 512]); view_1194 = None + mm_115 = torch.ops.aten.mm.default(view_1200, permute_183); view_1200 = permute_183 = None + view_1201 = torch.ops.aten.view.default(mm_115, [2, 8192, 4096]); mm_115 = None + split_74 = torch.ops.aten.split.Tensor(view_1201, 1024, 1); view_1201 = None + getitem_745 = split_74[0] + getitem_746 = split_74[1] + getitem_747 = split_74[2] + getitem_748 = split_74[3] + getitem_749 = split_74[4] + getitem_750 = split_74[5] + getitem_751 = split_74[6] + getitem_752 = split_74[7]; split_74 = 
None + cat_66 = torch.ops.aten.cat.default([getitem_745, getitem_746, getitem_747, getitem_748, getitem_749, getitem_750, getitem_751, getitem_752]); getitem_745 = getitem_746 = getitem_747 = getitem_748 = getitem_749 = getitem_750 = getitem_751 = getitem_752 = None + reduce_scatter_tensor_33 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_66, 'sum', 8, '1'); cat_66 = None + wait_tensor_216 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_33) + add_65 = torch.ops.aten.add.Tensor(add_63, wait_tensor_216); wait_tensor_216 = None + convert_element_type_548 = torch.ops.prims.convert_element_type.default(primals_153, torch.bfloat16) + all_gather_into_tensor_183 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_548, 32, '0'); convert_element_type_548 = None + wait_tensor_217 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_183); all_gather_into_tensor_183 = None + convert_element_type_549 = torch.ops.prims.convert_element_type.default(add_65, torch.float32) + pow_34 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_549, 2) + mean_33 = torch.ops.aten.mean.dim(pow_34, [2], True); pow_34 = None + add_66 = torch.ops.aten.add.Scalar(mean_33, 1e-05); mean_33 = None + rsqrt_33 = torch.ops.aten.rsqrt.default(add_66); add_66 = None + mul_132 = torch.ops.aten.mul.Tensor(convert_element_type_549, rsqrt_33); convert_element_type_549 = rsqrt_33 = None + mul_133 = torch.ops.aten.mul.Tensor(mul_132, wait_tensor_217); mul_132 = wait_tensor_217 = None + convert_element_type_550 = torch.ops.prims.convert_element_type.default(mul_133, torch.bfloat16); mul_133 = None + all_gather_into_tensor_184 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_550, 8, '1'); convert_element_type_550 = None + wait_tensor_218 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_184); all_gather_into_tensor_184 = None + split_75 = 
torch.ops.aten.split.Tensor(wait_tensor_218, 2); wait_tensor_218 = None + getitem_753 = split_75[0] + getitem_754 = split_75[1] + getitem_755 = split_75[2] + getitem_756 = split_75[3] + getitem_757 = split_75[4] + getitem_758 = split_75[5] + getitem_759 = split_75[6] + getitem_760 = split_75[7]; split_75 = None + cat_67 = torch.ops.aten.cat.default([getitem_753, getitem_754, getitem_755, getitem_756, getitem_757, getitem_758, getitem_759, getitem_760], 1); getitem_753 = getitem_754 = getitem_755 = getitem_756 = getitem_757 = getitem_758 = getitem_759 = getitem_760 = None + convert_element_type_551 = torch.ops.prims.convert_element_type.default(primals_154, torch.bfloat16) + all_gather_into_tensor_185 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_551, 32, '0'); convert_element_type_551 = None + wait_tensor_219 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_185); all_gather_into_tensor_185 = None + permute_184 = torch.ops.aten.permute.default(wait_tensor_219, [1, 0]); wait_tensor_219 = None + view_1212 = torch.ops.aten.view.default(cat_67, [16384, 4096]); cat_67 = None + mm_116 = torch.ops.aten.mm.default(view_1212, permute_184); permute_184 = None + view_1213 = torch.ops.aten.view.default(mm_116, [2, 8192, 1792]) + convert_element_type_554 = torch.ops.prims.convert_element_type.default(view_1213, torch.float32); view_1213 = None + sigmoid_16 = torch.ops.aten.sigmoid.default(convert_element_type_554) + mul_134 = torch.ops.aten.mul.Tensor(convert_element_type_554, sigmoid_16); convert_element_type_554 = sigmoid_16 = None + convert_element_type_555 = torch.ops.prims.convert_element_type.default(mul_134, torch.bfloat16); mul_134 = None + convert_element_type_556 = torch.ops.prims.convert_element_type.default(primals_155, torch.bfloat16) + all_gather_into_tensor_186 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_556, 32, '0'); convert_element_type_556 = None + wait_tensor_220 = 
torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_186); all_gather_into_tensor_186 = None + permute_185 = torch.ops.aten.permute.default(wait_tensor_220, [1, 0]); wait_tensor_220 = None + mm_117 = torch.ops.aten.mm.default(view_1212, permute_185); view_1212 = permute_185 = None + view_1220 = torch.ops.aten.view.default(mm_117, [2, 8192, 1792]); mm_117 = None + mul_135 = torch.ops.aten.mul.Tensor(convert_element_type_555, view_1220); convert_element_type_555 = view_1220 = None + convert_element_type_559 = torch.ops.prims.convert_element_type.default(primals_156, torch.bfloat16) + all_gather_into_tensor_187 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_559, 32, '0'); convert_element_type_559 = None + wait_tensor_221 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_187); all_gather_into_tensor_187 = None + permute_186 = torch.ops.aten.permute.default(wait_tensor_221, [1, 0]); wait_tensor_221 = None + view_1227 = torch.ops.aten.view.default(mul_135, [16384, 1792]); mul_135 = None + mm_118 = torch.ops.aten.mm.default(view_1227, permute_186); view_1227 = permute_186 = None + view_1228 = torch.ops.aten.view.default(mm_118, [2, 8192, 4096]); mm_118 = None + split_76 = torch.ops.aten.split.Tensor(view_1228, 1024, 1); view_1228 = None + getitem_761 = split_76[0] + getitem_762 = split_76[1] + getitem_763 = split_76[2] + getitem_764 = split_76[3] + getitem_765 = split_76[4] + getitem_766 = split_76[5] + getitem_767 = split_76[6] + getitem_768 = split_76[7]; split_76 = None + cat_68 = torch.ops.aten.cat.default([getitem_761, getitem_762, getitem_763, getitem_764, getitem_765, getitem_766, getitem_767, getitem_768]); getitem_761 = getitem_762 = getitem_763 = getitem_764 = getitem_765 = getitem_766 = getitem_767 = getitem_768 = None + reduce_scatter_tensor_34 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_68, 'sum', 8, '1'); cat_68 = None + wait_tensor_222 = 
torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_34); reduce_scatter_tensor_34 = None + add_67 = torch.ops.aten.add.Tensor(add_65, wait_tensor_222); add_65 = wait_tensor_222 = None + convert_element_type_562 = torch.ops.prims.convert_element_type.default(primals_157, torch.bfloat16) + all_gather_into_tensor_188 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_562, 32, '0'); convert_element_type_562 = None + wait_tensor_223 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_188); all_gather_into_tensor_188 = None + convert_element_type_563 = torch.ops.prims.convert_element_type.default(add_67, torch.float32) + pow_35 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_563, 2) + mean_34 = torch.ops.aten.mean.dim(pow_35, [2], True); pow_35 = None + add_68 = torch.ops.aten.add.Scalar(mean_34, 1e-05); mean_34 = None + rsqrt_34 = torch.ops.aten.rsqrt.default(add_68); add_68 = None + mul_136 = torch.ops.aten.mul.Tensor(convert_element_type_563, rsqrt_34); convert_element_type_563 = rsqrt_34 = None + mul_137 = torch.ops.aten.mul.Tensor(mul_136, wait_tensor_223); mul_136 = wait_tensor_223 = None + convert_element_type_564 = torch.ops.prims.convert_element_type.default(mul_137, torch.bfloat16); mul_137 = None + all_gather_into_tensor_189 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_564, 8, '1'); convert_element_type_564 = None + wait_tensor_224 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_189); all_gather_into_tensor_189 = None + split_77 = torch.ops.aten.split.Tensor(wait_tensor_224, 2); wait_tensor_224 = None + getitem_769 = split_77[0] + getitem_770 = split_77[1] + getitem_771 = split_77[2] + getitem_772 = split_77[3] + getitem_773 = split_77[4] + getitem_774 = split_77[5] + getitem_775 = split_77[6] + getitem_776 = split_77[7]; split_77 = None + cat_69 = torch.ops.aten.cat.default([getitem_769, getitem_770, getitem_771, getitem_772, 
getitem_773, getitem_774, getitem_775, getitem_776], 1); getitem_769 = getitem_770 = getitem_771 = getitem_772 = getitem_773 = getitem_774 = getitem_775 = getitem_776 = None + convert_element_type_565 = torch.ops.prims.convert_element_type.default(primals_158, torch.bfloat16) + all_gather_into_tensor_190 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_565, 32, '0'); convert_element_type_565 = None + wait_tensor_225 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_190); all_gather_into_tensor_190 = None + permute_187 = torch.ops.aten.permute.default(wait_tensor_225, [1, 0]); wait_tensor_225 = None + view_1239 = torch.ops.aten.view.default(cat_69, [16384, 4096]); cat_69 = None + mm_119 = torch.ops.aten.mm.default(view_1239, permute_187); permute_187 = None + view_1240 = torch.ops.aten.view.default(mm_119, [2, 8192, 512]) + convert_element_type_568 = torch.ops.prims.convert_element_type.default(primals_159, torch.bfloat16) + all_gather_into_tensor_191 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_568, 32, '0'); convert_element_type_568 = None + wait_tensor_226 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_191); all_gather_into_tensor_191 = None + permute_188 = torch.ops.aten.permute.default(wait_tensor_226, [1, 0]); wait_tensor_226 = None + mm_120 = torch.ops.aten.mm.default(view_1239, permute_188); permute_188 = None + view_1247 = torch.ops.aten.view.default(mm_120, [2, 8192, 128]); mm_120 = None + convert_element_type_571 = torch.ops.prims.convert_element_type.default(primals_160, torch.bfloat16) + all_gather_into_tensor_192 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_571, 32, '0'); convert_element_type_571 = None + wait_tensor_227 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_192); all_gather_into_tensor_192 = None + permute_189 = torch.ops.aten.permute.default(wait_tensor_227, [1, 0]); 
wait_tensor_227 = None + mm_121 = torch.ops.aten.mm.default(view_1239, permute_189); view_1239 = permute_189 = None + view_1254 = torch.ops.aten.view.default(mm_121, [2, 8192, 128]) + view_1256 = torch.ops.aten.view.default(view_1240, [2, 8192, -1, 128]); view_1240 = None + view_1257 = torch.ops.aten.view.default(view_1247, [2, 8192, -1, 128]); view_1247 = None + view_1258 = torch.ops.aten.view.default(view_1254, [2, 8192, -1, 128]); view_1254 = None + convert_element_type_574 = torch.ops.prims.convert_element_type.default(view_1256, torch.float32); view_1256 = None + view_1259 = torch.ops.aten.view.default(convert_element_type_574, [2, 8192, 4, -1, 2]); convert_element_type_574 = None + view_as_complex_34 = torch.ops.aten.view_as_complex.default(view_1259); view_1259 = None + convert_element_type_575 = torch.ops.prims.convert_element_type.default(view_1257, torch.float32); view_1257 = None + view_1260 = torch.ops.aten.view.default(convert_element_type_575, [2, 8192, 1, -1, 2]); convert_element_type_575 = None + view_as_complex_35 = torch.ops.aten.view_as_complex.default(view_1260); view_1260 = None + mul_138 = torch.ops.aten.mul.Tensor(view_as_complex_34, view_37); view_as_complex_34 = None + view_as_real_34 = torch.ops.aten.view_as_real.default(mul_138); mul_138 = None + view_1262 = torch.ops.aten.view.default(view_as_real_34, [2, 8192, 4, 128]); view_as_real_34 = None + mul_139 = torch.ops.aten.mul.Tensor(view_as_complex_35, view_37); view_as_complex_35 = None + view_as_real_35 = torch.ops.aten.view_as_real.default(mul_139); mul_139 = None + view_1263 = torch.ops.aten.view.default(view_as_real_35, [2, 8192, 1, 128]); view_as_real_35 = None + convert_element_type_576 = torch.ops.prims.convert_element_type.default(view_1262, torch.bfloat16); view_1262 = None + convert_element_type_577 = torch.ops.prims.convert_element_type.default(view_1263, torch.bfloat16); view_1263 = None + unsqueeze_34 = torch.ops.aten.unsqueeze.default(convert_element_type_577, 3); 
convert_element_type_577 = None + expand_34 = torch.ops.aten.expand.default(unsqueeze_34, [2, 8192, 1, 4, 128]); unsqueeze_34 = None + view_1264 = torch.ops.aten.view.default(expand_34, [2, 8192, 4, 128]); expand_34 = None + unsqueeze_35 = torch.ops.aten.unsqueeze.default(view_1258, 3); view_1258 = None + expand_35 = torch.ops.aten.expand.default(unsqueeze_35, [2, 8192, 1, 4, 128]); unsqueeze_35 = None + view_1265 = torch.ops.aten.view.default(expand_35, [2, 8192, 4, 128]); expand_35 = None + permute_190 = torch.ops.aten.permute.default(convert_element_type_576, [0, 2, 1, 3]); convert_element_type_576 = None + permute_191 = torch.ops.aten.permute.default(view_1264, [0, 2, 1, 3]); view_1264 = None + permute_192 = torch.ops.aten.permute.default(view_1265, [0, 2, 1, 3]); view_1265 = None + _scaled_dot_product_cudnn_attention_17 = torch.ops.aten._scaled_dot_product_cudnn_attention.default(permute_190, permute_191, permute_192, None, True, 0.0, True); permute_190 = permute_191 = permute_192 = None + getitem_777 = _scaled_dot_product_cudnn_attention_17[0] + getitem_778 = _scaled_dot_product_cudnn_attention_17[1] + getitem_783 = _scaled_dot_product_cudnn_attention_17[6] + getitem_784 = _scaled_dot_product_cudnn_attention_17[7]; _scaled_dot_product_cudnn_attention_17 = None + permute_193 = torch.ops.aten.permute.default(getitem_777, [0, 2, 1, 3]) + view_1266 = torch.ops.aten.view.default(permute_193, [2, 8192, -1]); permute_193 = None + convert_element_type_578 = torch.ops.prims.convert_element_type.default(primals_161, torch.bfloat16) + all_gather_into_tensor_193 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_578, 32, '0'); convert_element_type_578 = None + wait_tensor_228 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_193); all_gather_into_tensor_193 = None + permute_194 = torch.ops.aten.permute.default(wait_tensor_228, [1, 0]); wait_tensor_228 = None + view_1272 = torch.ops.aten.view.default(view_1266, [16384, 
512]); view_1266 = None + mm_122 = torch.ops.aten.mm.default(view_1272, permute_194); view_1272 = permute_194 = None + view_1273 = torch.ops.aten.view.default(mm_122, [2, 8192, 4096]); mm_122 = None + split_78 = torch.ops.aten.split.Tensor(view_1273, 1024, 1); view_1273 = None + getitem_786 = split_78[0] + getitem_787 = split_78[1] + getitem_788 = split_78[2] + getitem_789 = split_78[3] + getitem_790 = split_78[4] + getitem_791 = split_78[5] + getitem_792 = split_78[6] + getitem_793 = split_78[7]; split_78 = None + cat_70 = torch.ops.aten.cat.default([getitem_786, getitem_787, getitem_788, getitem_789, getitem_790, getitem_791, getitem_792, getitem_793]); getitem_786 = getitem_787 = getitem_788 = getitem_789 = getitem_790 = getitem_791 = getitem_792 = getitem_793 = None + reduce_scatter_tensor_35 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_70, 'sum', 8, '1'); cat_70 = None + wait_tensor_229 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_35) + add_69 = torch.ops.aten.add.Tensor(add_67, wait_tensor_229); wait_tensor_229 = None + convert_element_type_581 = torch.ops.prims.convert_element_type.default(primals_162, torch.bfloat16) + all_gather_into_tensor_194 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_581, 32, '0'); convert_element_type_581 = None + wait_tensor_230 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_194); all_gather_into_tensor_194 = None + convert_element_type_582 = torch.ops.prims.convert_element_type.default(add_69, torch.float32) + pow_36 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_582, 2) + mean_35 = torch.ops.aten.mean.dim(pow_36, [2], True); pow_36 = None + add_70 = torch.ops.aten.add.Scalar(mean_35, 1e-05); mean_35 = None + rsqrt_35 = torch.ops.aten.rsqrt.default(add_70); add_70 = None + mul_140 = torch.ops.aten.mul.Tensor(convert_element_type_582, rsqrt_35); convert_element_type_582 = rsqrt_35 = None + mul_141 = 
torch.ops.aten.mul.Tensor(mul_140, wait_tensor_230); mul_140 = wait_tensor_230 = None + convert_element_type_583 = torch.ops.prims.convert_element_type.default(mul_141, torch.bfloat16); mul_141 = None + all_gather_into_tensor_195 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_583, 8, '1'); convert_element_type_583 = None + wait_tensor_231 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_195); all_gather_into_tensor_195 = None + split_79 = torch.ops.aten.split.Tensor(wait_tensor_231, 2); wait_tensor_231 = None + getitem_794 = split_79[0] + getitem_795 = split_79[1] + getitem_796 = split_79[2] + getitem_797 = split_79[3] + getitem_798 = split_79[4] + getitem_799 = split_79[5] + getitem_800 = split_79[6] + getitem_801 = split_79[7]; split_79 = None + cat_71 = torch.ops.aten.cat.default([getitem_794, getitem_795, getitem_796, getitem_797, getitem_798, getitem_799, getitem_800, getitem_801], 1); getitem_794 = getitem_795 = getitem_796 = getitem_797 = getitem_798 = getitem_799 = getitem_800 = getitem_801 = None + convert_element_type_584 = torch.ops.prims.convert_element_type.default(primals_163, torch.bfloat16) + all_gather_into_tensor_196 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_584, 32, '0'); convert_element_type_584 = None + wait_tensor_232 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_196); all_gather_into_tensor_196 = None + permute_195 = torch.ops.aten.permute.default(wait_tensor_232, [1, 0]); wait_tensor_232 = None + view_1284 = torch.ops.aten.view.default(cat_71, [16384, 4096]); cat_71 = None + mm_123 = torch.ops.aten.mm.default(view_1284, permute_195); permute_195 = None + view_1285 = torch.ops.aten.view.default(mm_123, [2, 8192, 1792]) + convert_element_type_587 = torch.ops.prims.convert_element_type.default(view_1285, torch.float32); view_1285 = None + sigmoid_17 = torch.ops.aten.sigmoid.default(convert_element_type_587) + mul_142 = 
torch.ops.aten.mul.Tensor(convert_element_type_587, sigmoid_17); convert_element_type_587 = sigmoid_17 = None + convert_element_type_588 = torch.ops.prims.convert_element_type.default(mul_142, torch.bfloat16); mul_142 = None + convert_element_type_589 = torch.ops.prims.convert_element_type.default(primals_164, torch.bfloat16) + all_gather_into_tensor_197 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_589, 32, '0'); convert_element_type_589 = None + wait_tensor_233 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_197); all_gather_into_tensor_197 = None + permute_196 = torch.ops.aten.permute.default(wait_tensor_233, [1, 0]); wait_tensor_233 = None + mm_124 = torch.ops.aten.mm.default(view_1284, permute_196); view_1284 = permute_196 = None + view_1292 = torch.ops.aten.view.default(mm_124, [2, 8192, 1792]); mm_124 = None + mul_143 = torch.ops.aten.mul.Tensor(convert_element_type_588, view_1292); convert_element_type_588 = view_1292 = None + convert_element_type_592 = torch.ops.prims.convert_element_type.default(primals_165, torch.bfloat16) + all_gather_into_tensor_198 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_592, 32, '0'); convert_element_type_592 = None + wait_tensor_234 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_198); all_gather_into_tensor_198 = None + permute_197 = torch.ops.aten.permute.default(wait_tensor_234, [1, 0]); wait_tensor_234 = None + view_1299 = torch.ops.aten.view.default(mul_143, [16384, 1792]); mul_143 = None + mm_125 = torch.ops.aten.mm.default(view_1299, permute_197); view_1299 = permute_197 = None + view_1300 = torch.ops.aten.view.default(mm_125, [2, 8192, 4096]); mm_125 = None + split_80 = torch.ops.aten.split.Tensor(view_1300, 1024, 1); view_1300 = None + getitem_802 = split_80[0] + getitem_803 = split_80[1] + getitem_804 = split_80[2] + getitem_805 = split_80[3] + getitem_806 = split_80[4] + getitem_807 = split_80[5] + 
getitem_808 = split_80[6] + getitem_809 = split_80[7]; split_80 = None + cat_72 = torch.ops.aten.cat.default([getitem_802, getitem_803, getitem_804, getitem_805, getitem_806, getitem_807, getitem_808, getitem_809]); getitem_802 = getitem_803 = getitem_804 = getitem_805 = getitem_806 = getitem_807 = getitem_808 = getitem_809 = None + reduce_scatter_tensor_36 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_72, 'sum', 8, '1'); cat_72 = None + wait_tensor_235 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_36); reduce_scatter_tensor_36 = None + add_71 = torch.ops.aten.add.Tensor(add_69, wait_tensor_235); add_69 = wait_tensor_235 = None + convert_element_type_595 = torch.ops.prims.convert_element_type.default(primals_166, torch.bfloat16) + all_gather_into_tensor_199 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_595, 32, '0'); convert_element_type_595 = None + wait_tensor_236 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_199); all_gather_into_tensor_199 = None + convert_element_type_596 = torch.ops.prims.convert_element_type.default(add_71, torch.float32) + pow_37 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_596, 2) + mean_36 = torch.ops.aten.mean.dim(pow_37, [2], True); pow_37 = None + add_72 = torch.ops.aten.add.Scalar(mean_36, 1e-05); mean_36 = None + rsqrt_36 = torch.ops.aten.rsqrt.default(add_72); add_72 = None + mul_144 = torch.ops.aten.mul.Tensor(convert_element_type_596, rsqrt_36); convert_element_type_596 = rsqrt_36 = None + mul_145 = torch.ops.aten.mul.Tensor(mul_144, wait_tensor_236); mul_144 = wait_tensor_236 = None + convert_element_type_597 = torch.ops.prims.convert_element_type.default(mul_145, torch.bfloat16); mul_145 = None + all_gather_into_tensor_200 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_597, 8, '1'); convert_element_type_597 = None + wait_tensor_237 = 
torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_200); all_gather_into_tensor_200 = None + split_81 = torch.ops.aten.split.Tensor(wait_tensor_237, 2); wait_tensor_237 = None + getitem_810 = split_81[0] + getitem_811 = split_81[1] + getitem_812 = split_81[2] + getitem_813 = split_81[3] + getitem_814 = split_81[4] + getitem_815 = split_81[5] + getitem_816 = split_81[6] + getitem_817 = split_81[7]; split_81 = None + cat_73 = torch.ops.aten.cat.default([getitem_810, getitem_811, getitem_812, getitem_813, getitem_814, getitem_815, getitem_816, getitem_817], 1); getitem_810 = getitem_811 = getitem_812 = getitem_813 = getitem_814 = getitem_815 = getitem_816 = getitem_817 = None + convert_element_type_598 = torch.ops.prims.convert_element_type.default(primals_167, torch.bfloat16) + all_gather_into_tensor_201 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_598, 32, '0'); convert_element_type_598 = None + wait_tensor_238 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_201); all_gather_into_tensor_201 = None + permute_198 = torch.ops.aten.permute.default(wait_tensor_238, [1, 0]); wait_tensor_238 = None + view_1311 = torch.ops.aten.view.default(cat_73, [16384, 4096]); cat_73 = None + mm_126 = torch.ops.aten.mm.default(view_1311, permute_198); permute_198 = None + view_1312 = torch.ops.aten.view.default(mm_126, [2, 8192, 512]) + convert_element_type_601 = torch.ops.prims.convert_element_type.default(primals_168, torch.bfloat16) + all_gather_into_tensor_202 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_601, 32, '0'); convert_element_type_601 = None + wait_tensor_239 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_202); all_gather_into_tensor_202 = None + permute_199 = torch.ops.aten.permute.default(wait_tensor_239, [1, 0]); wait_tensor_239 = None + mm_127 = torch.ops.aten.mm.default(view_1311, permute_199); permute_199 = None + view_1319 = 
torch.ops.aten.view.default(mm_127, [2, 8192, 128]); mm_127 = None + convert_element_type_604 = torch.ops.prims.convert_element_type.default(primals_169, torch.bfloat16) + all_gather_into_tensor_203 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_604, 32, '0'); convert_element_type_604 = None + wait_tensor_240 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_203); all_gather_into_tensor_203 = None + permute_200 = torch.ops.aten.permute.default(wait_tensor_240, [1, 0]); wait_tensor_240 = None + mm_128 = torch.ops.aten.mm.default(view_1311, permute_200); view_1311 = permute_200 = None + view_1326 = torch.ops.aten.view.default(mm_128, [2, 8192, 128]) + view_1328 = torch.ops.aten.view.default(view_1312, [2, 8192, -1, 128]); view_1312 = None + view_1329 = torch.ops.aten.view.default(view_1319, [2, 8192, -1, 128]); view_1319 = None + view_1330 = torch.ops.aten.view.default(view_1326, [2, 8192, -1, 128]); view_1326 = None + convert_element_type_607 = torch.ops.prims.convert_element_type.default(view_1328, torch.float32); view_1328 = None + view_1331 = torch.ops.aten.view.default(convert_element_type_607, [2, 8192, 4, -1, 2]); convert_element_type_607 = None + view_as_complex_36 = torch.ops.aten.view_as_complex.default(view_1331); view_1331 = None + convert_element_type_608 = torch.ops.prims.convert_element_type.default(view_1329, torch.float32); view_1329 = None + view_1332 = torch.ops.aten.view.default(convert_element_type_608, [2, 8192, 1, -1, 2]); convert_element_type_608 = None + view_as_complex_37 = torch.ops.aten.view_as_complex.default(view_1332); view_1332 = None + mul_146 = torch.ops.aten.mul.Tensor(view_as_complex_36, view_37); view_as_complex_36 = None + view_as_real_36 = torch.ops.aten.view_as_real.default(mul_146); mul_146 = None + view_1334 = torch.ops.aten.view.default(view_as_real_36, [2, 8192, 4, 128]); view_as_real_36 = None + mul_147 = torch.ops.aten.mul.Tensor(view_as_complex_37, view_37); 
view_as_complex_37 = None + view_as_real_37 = torch.ops.aten.view_as_real.default(mul_147); mul_147 = None + view_1335 = torch.ops.aten.view.default(view_as_real_37, [2, 8192, 1, 128]); view_as_real_37 = None + convert_element_type_609 = torch.ops.prims.convert_element_type.default(view_1334, torch.bfloat16); view_1334 = None + convert_element_type_610 = torch.ops.prims.convert_element_type.default(view_1335, torch.bfloat16); view_1335 = None + unsqueeze_36 = torch.ops.aten.unsqueeze.default(convert_element_type_610, 3); convert_element_type_610 = None + expand_36 = torch.ops.aten.expand.default(unsqueeze_36, [2, 8192, 1, 4, 128]); unsqueeze_36 = None + view_1336 = torch.ops.aten.view.default(expand_36, [2, 8192, 4, 128]); expand_36 = None + unsqueeze_37 = torch.ops.aten.unsqueeze.default(view_1330, 3); view_1330 = None + expand_37 = torch.ops.aten.expand.default(unsqueeze_37, [2, 8192, 1, 4, 128]); unsqueeze_37 = None + view_1337 = torch.ops.aten.view.default(expand_37, [2, 8192, 4, 128]); expand_37 = None + permute_201 = torch.ops.aten.permute.default(convert_element_type_609, [0, 2, 1, 3]); convert_element_type_609 = None + permute_202 = torch.ops.aten.permute.default(view_1336, [0, 2, 1, 3]); view_1336 = None + permute_203 = torch.ops.aten.permute.default(view_1337, [0, 2, 1, 3]); view_1337 = None + _scaled_dot_product_cudnn_attention_18 = torch.ops.aten._scaled_dot_product_cudnn_attention.default(permute_201, permute_202, permute_203, None, True, 0.0, True); permute_201 = permute_202 = permute_203 = None + getitem_818 = _scaled_dot_product_cudnn_attention_18[0] + getitem_819 = _scaled_dot_product_cudnn_attention_18[1] + getitem_824 = _scaled_dot_product_cudnn_attention_18[6] + getitem_825 = _scaled_dot_product_cudnn_attention_18[7]; _scaled_dot_product_cudnn_attention_18 = None + permute_204 = torch.ops.aten.permute.default(getitem_818, [0, 2, 1, 3]) + view_1338 = torch.ops.aten.view.default(permute_204, [2, 8192, -1]); permute_204 = None + 
convert_element_type_611 = torch.ops.prims.convert_element_type.default(primals_170, torch.bfloat16) + all_gather_into_tensor_204 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_611, 32, '0'); convert_element_type_611 = None + wait_tensor_241 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_204); all_gather_into_tensor_204 = None + permute_205 = torch.ops.aten.permute.default(wait_tensor_241, [1, 0]); wait_tensor_241 = None + view_1344 = torch.ops.aten.view.default(view_1338, [16384, 512]); view_1338 = None + mm_129 = torch.ops.aten.mm.default(view_1344, permute_205); view_1344 = permute_205 = None + view_1345 = torch.ops.aten.view.default(mm_129, [2, 8192, 4096]); mm_129 = None + split_82 = torch.ops.aten.split.Tensor(view_1345, 1024, 1); view_1345 = None + getitem_827 = split_82[0] + getitem_828 = split_82[1] + getitem_829 = split_82[2] + getitem_830 = split_82[3] + getitem_831 = split_82[4] + getitem_832 = split_82[5] + getitem_833 = split_82[6] + getitem_834 = split_82[7]; split_82 = None + cat_74 = torch.ops.aten.cat.default([getitem_827, getitem_828, getitem_829, getitem_830, getitem_831, getitem_832, getitem_833, getitem_834]); getitem_827 = getitem_828 = getitem_829 = getitem_830 = getitem_831 = getitem_832 = getitem_833 = getitem_834 = None + reduce_scatter_tensor_37 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_74, 'sum', 8, '1'); cat_74 = None + wait_tensor_242 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_37) + add_73 = torch.ops.aten.add.Tensor(add_71, wait_tensor_242); wait_tensor_242 = None + convert_element_type_614 = torch.ops.prims.convert_element_type.default(primals_171, torch.bfloat16) + all_gather_into_tensor_205 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_614, 32, '0'); convert_element_type_614 = None + wait_tensor_243 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_205); 
all_gather_into_tensor_205 = None + convert_element_type_615 = torch.ops.prims.convert_element_type.default(add_73, torch.float32) + pow_38 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_615, 2) + mean_37 = torch.ops.aten.mean.dim(pow_38, [2], True); pow_38 = None + add_74 = torch.ops.aten.add.Scalar(mean_37, 1e-05); mean_37 = None + rsqrt_37 = torch.ops.aten.rsqrt.default(add_74); add_74 = None + mul_148 = torch.ops.aten.mul.Tensor(convert_element_type_615, rsqrt_37); convert_element_type_615 = rsqrt_37 = None + mul_149 = torch.ops.aten.mul.Tensor(mul_148, wait_tensor_243); mul_148 = wait_tensor_243 = None + convert_element_type_616 = torch.ops.prims.convert_element_type.default(mul_149, torch.bfloat16); mul_149 = None + all_gather_into_tensor_206 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_616, 8, '1'); convert_element_type_616 = None + wait_tensor_244 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_206); all_gather_into_tensor_206 = None + split_83 = torch.ops.aten.split.Tensor(wait_tensor_244, 2); wait_tensor_244 = None + getitem_835 = split_83[0] + getitem_836 = split_83[1] + getitem_837 = split_83[2] + getitem_838 = split_83[3] + getitem_839 = split_83[4] + getitem_840 = split_83[5] + getitem_841 = split_83[6] + getitem_842 = split_83[7]; split_83 = None + cat_75 = torch.ops.aten.cat.default([getitem_835, getitem_836, getitem_837, getitem_838, getitem_839, getitem_840, getitem_841, getitem_842], 1); getitem_835 = getitem_836 = getitem_837 = getitem_838 = getitem_839 = getitem_840 = getitem_841 = getitem_842 = None + convert_element_type_617 = torch.ops.prims.convert_element_type.default(primals_172, torch.bfloat16) + all_gather_into_tensor_207 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_617, 32, '0'); convert_element_type_617 = None + wait_tensor_245 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_207); all_gather_into_tensor_207 
= None + permute_206 = torch.ops.aten.permute.default(wait_tensor_245, [1, 0]); wait_tensor_245 = None + view_1356 = torch.ops.aten.view.default(cat_75, [16384, 4096]); cat_75 = None + mm_130 = torch.ops.aten.mm.default(view_1356, permute_206); permute_206 = None + view_1357 = torch.ops.aten.view.default(mm_130, [2, 8192, 1792]) + convert_element_type_620 = torch.ops.prims.convert_element_type.default(view_1357, torch.float32); view_1357 = None + sigmoid_18 = torch.ops.aten.sigmoid.default(convert_element_type_620) + mul_150 = torch.ops.aten.mul.Tensor(convert_element_type_620, sigmoid_18); convert_element_type_620 = sigmoid_18 = None + convert_element_type_621 = torch.ops.prims.convert_element_type.default(mul_150, torch.bfloat16); mul_150 = None + convert_element_type_622 = torch.ops.prims.convert_element_type.default(primals_173, torch.bfloat16) + all_gather_into_tensor_208 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_622, 32, '0'); convert_element_type_622 = None + wait_tensor_246 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_208); all_gather_into_tensor_208 = None + permute_207 = torch.ops.aten.permute.default(wait_tensor_246, [1, 0]); wait_tensor_246 = None + mm_131 = torch.ops.aten.mm.default(view_1356, permute_207); view_1356 = permute_207 = None + view_1364 = torch.ops.aten.view.default(mm_131, [2, 8192, 1792]); mm_131 = None + mul_151 = torch.ops.aten.mul.Tensor(convert_element_type_621, view_1364); convert_element_type_621 = view_1364 = None + convert_element_type_625 = torch.ops.prims.convert_element_type.default(primals_174, torch.bfloat16) + all_gather_into_tensor_209 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_625, 32, '0'); convert_element_type_625 = None + wait_tensor_247 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_209); all_gather_into_tensor_209 = None + permute_208 = torch.ops.aten.permute.default(wait_tensor_247, [1, 
0]); wait_tensor_247 = None + view_1371 = torch.ops.aten.view.default(mul_151, [16384, 1792]); mul_151 = None + mm_132 = torch.ops.aten.mm.default(view_1371, permute_208); view_1371 = permute_208 = None + view_1372 = torch.ops.aten.view.default(mm_132, [2, 8192, 4096]); mm_132 = None + split_84 = torch.ops.aten.split.Tensor(view_1372, 1024, 1); view_1372 = None + getitem_843 = split_84[0] + getitem_844 = split_84[1] + getitem_845 = split_84[2] + getitem_846 = split_84[3] + getitem_847 = split_84[4] + getitem_848 = split_84[5] + getitem_849 = split_84[6] + getitem_850 = split_84[7]; split_84 = None + cat_76 = torch.ops.aten.cat.default([getitem_843, getitem_844, getitem_845, getitem_846, getitem_847, getitem_848, getitem_849, getitem_850]); getitem_843 = getitem_844 = getitem_845 = getitem_846 = getitem_847 = getitem_848 = getitem_849 = getitem_850 = None + reduce_scatter_tensor_38 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_76, 'sum', 8, '1'); cat_76 = None + wait_tensor_248 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_38); reduce_scatter_tensor_38 = None + add_75 = torch.ops.aten.add.Tensor(add_73, wait_tensor_248); add_73 = wait_tensor_248 = None + convert_element_type_628 = torch.ops.prims.convert_element_type.default(primals_175, torch.bfloat16) + all_gather_into_tensor_210 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_628, 32, '0'); convert_element_type_628 = None + wait_tensor_249 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_210); all_gather_into_tensor_210 = None + convert_element_type_629 = torch.ops.prims.convert_element_type.default(add_75, torch.float32) + pow_39 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_629, 2) + mean_38 = torch.ops.aten.mean.dim(pow_39, [2], True); pow_39 = None + add_76 = torch.ops.aten.add.Scalar(mean_38, 1e-05); mean_38 = None + rsqrt_38 = torch.ops.aten.rsqrt.default(add_76); add_76 = None + mul_152 = 
torch.ops.aten.mul.Tensor(convert_element_type_629, rsqrt_38); convert_element_type_629 = rsqrt_38 = None + mul_153 = torch.ops.aten.mul.Tensor(mul_152, wait_tensor_249); mul_152 = wait_tensor_249 = None + convert_element_type_630 = torch.ops.prims.convert_element_type.default(mul_153, torch.bfloat16); mul_153 = None + all_gather_into_tensor_211 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_630, 8, '1'); convert_element_type_630 = None + wait_tensor_250 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_211); all_gather_into_tensor_211 = None + split_85 = torch.ops.aten.split.Tensor(wait_tensor_250, 2); wait_tensor_250 = None + getitem_851 = split_85[0] + getitem_852 = split_85[1] + getitem_853 = split_85[2] + getitem_854 = split_85[3] + getitem_855 = split_85[4] + getitem_856 = split_85[5] + getitem_857 = split_85[6] + getitem_858 = split_85[7]; split_85 = None + cat_77 = torch.ops.aten.cat.default([getitem_851, getitem_852, getitem_853, getitem_854, getitem_855, getitem_856, getitem_857, getitem_858], 1); getitem_851 = getitem_852 = getitem_853 = getitem_854 = getitem_855 = getitem_856 = getitem_857 = getitem_858 = None + convert_element_type_631 = torch.ops.prims.convert_element_type.default(primals_176, torch.bfloat16) + all_gather_into_tensor_212 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_631, 32, '0'); convert_element_type_631 = None + wait_tensor_251 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_212); all_gather_into_tensor_212 = None + permute_209 = torch.ops.aten.permute.default(wait_tensor_251, [1, 0]); wait_tensor_251 = None + view_1383 = torch.ops.aten.view.default(cat_77, [16384, 4096]); cat_77 = None + mm_133 = torch.ops.aten.mm.default(view_1383, permute_209); permute_209 = None + view_1384 = torch.ops.aten.view.default(mm_133, [2, 8192, 512]) + convert_element_type_634 = torch.ops.prims.convert_element_type.default(primals_177, 
torch.bfloat16) + all_gather_into_tensor_213 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_634, 32, '0'); convert_element_type_634 = None + wait_tensor_252 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_213); all_gather_into_tensor_213 = None + permute_210 = torch.ops.aten.permute.default(wait_tensor_252, [1, 0]); wait_tensor_252 = None + mm_134 = torch.ops.aten.mm.default(view_1383, permute_210); permute_210 = None + view_1391 = torch.ops.aten.view.default(mm_134, [2, 8192, 128]); mm_134 = None + convert_element_type_637 = torch.ops.prims.convert_element_type.default(primals_178, torch.bfloat16) + all_gather_into_tensor_214 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_637, 32, '0'); convert_element_type_637 = None + wait_tensor_253 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_214); all_gather_into_tensor_214 = None + permute_211 = torch.ops.aten.permute.default(wait_tensor_253, [1, 0]); wait_tensor_253 = None + mm_135 = torch.ops.aten.mm.default(view_1383, permute_211); view_1383 = permute_211 = None + view_1398 = torch.ops.aten.view.default(mm_135, [2, 8192, 128]) + view_1400 = torch.ops.aten.view.default(view_1384, [2, 8192, -1, 128]); view_1384 = None + view_1401 = torch.ops.aten.view.default(view_1391, [2, 8192, -1, 128]); view_1391 = None + view_1402 = torch.ops.aten.view.default(view_1398, [2, 8192, -1, 128]); view_1398 = None + convert_element_type_640 = torch.ops.prims.convert_element_type.default(view_1400, torch.float32); view_1400 = None + view_1403 = torch.ops.aten.view.default(convert_element_type_640, [2, 8192, 4, -1, 2]); convert_element_type_640 = None + view_as_complex_38 = torch.ops.aten.view_as_complex.default(view_1403); view_1403 = None + convert_element_type_641 = torch.ops.prims.convert_element_type.default(view_1401, torch.float32); view_1401 = None + view_1404 = torch.ops.aten.view.default(convert_element_type_641, [2, 
8192, 1, -1, 2]); convert_element_type_641 = None + view_as_complex_39 = torch.ops.aten.view_as_complex.default(view_1404); view_1404 = None + mul_154 = torch.ops.aten.mul.Tensor(view_as_complex_38, view_37); view_as_complex_38 = None + view_as_real_38 = torch.ops.aten.view_as_real.default(mul_154); mul_154 = None + view_1406 = torch.ops.aten.view.default(view_as_real_38, [2, 8192, 4, 128]); view_as_real_38 = None + mul_155 = torch.ops.aten.mul.Tensor(view_as_complex_39, view_37); view_as_complex_39 = None + view_as_real_39 = torch.ops.aten.view_as_real.default(mul_155); mul_155 = None + view_1407 = torch.ops.aten.view.default(view_as_real_39, [2, 8192, 1, 128]); view_as_real_39 = None + convert_element_type_642 = torch.ops.prims.convert_element_type.default(view_1406, torch.bfloat16); view_1406 = None + convert_element_type_643 = torch.ops.prims.convert_element_type.default(view_1407, torch.bfloat16); view_1407 = None + unsqueeze_38 = torch.ops.aten.unsqueeze.default(convert_element_type_643, 3); convert_element_type_643 = None + expand_38 = torch.ops.aten.expand.default(unsqueeze_38, [2, 8192, 1, 4, 128]); unsqueeze_38 = None + view_1408 = torch.ops.aten.view.default(expand_38, [2, 8192, 4, 128]); expand_38 = None + unsqueeze_39 = torch.ops.aten.unsqueeze.default(view_1402, 3); view_1402 = None + expand_39 = torch.ops.aten.expand.default(unsqueeze_39, [2, 8192, 1, 4, 128]); unsqueeze_39 = None + view_1409 = torch.ops.aten.view.default(expand_39, [2, 8192, 4, 128]); expand_39 = None + permute_212 = torch.ops.aten.permute.default(convert_element_type_642, [0, 2, 1, 3]); convert_element_type_642 = None + permute_213 = torch.ops.aten.permute.default(view_1408, [0, 2, 1, 3]); view_1408 = None + permute_214 = torch.ops.aten.permute.default(view_1409, [0, 2, 1, 3]); view_1409 = None + _scaled_dot_product_cudnn_attention_19 = torch.ops.aten._scaled_dot_product_cudnn_attention.default(permute_212, permute_213, permute_214, None, True, 0.0, True); permute_212 = permute_213 
= permute_214 = None + getitem_859 = _scaled_dot_product_cudnn_attention_19[0] + getitem_860 = _scaled_dot_product_cudnn_attention_19[1] + getitem_865 = _scaled_dot_product_cudnn_attention_19[6] + getitem_866 = _scaled_dot_product_cudnn_attention_19[7]; _scaled_dot_product_cudnn_attention_19 = None + permute_215 = torch.ops.aten.permute.default(getitem_859, [0, 2, 1, 3]) + view_1410 = torch.ops.aten.view.default(permute_215, [2, 8192, -1]); permute_215 = None + convert_element_type_644 = torch.ops.prims.convert_element_type.default(primals_179, torch.bfloat16) + all_gather_into_tensor_215 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_644, 32, '0'); convert_element_type_644 = None + wait_tensor_254 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_215); all_gather_into_tensor_215 = None + permute_216 = torch.ops.aten.permute.default(wait_tensor_254, [1, 0]); wait_tensor_254 = None + view_1416 = torch.ops.aten.view.default(view_1410, [16384, 512]); view_1410 = None + mm_136 = torch.ops.aten.mm.default(view_1416, permute_216); view_1416 = permute_216 = None + view_1417 = torch.ops.aten.view.default(mm_136, [2, 8192, 4096]); mm_136 = None + split_86 = torch.ops.aten.split.Tensor(view_1417, 1024, 1); view_1417 = None + getitem_868 = split_86[0] + getitem_869 = split_86[1] + getitem_870 = split_86[2] + getitem_871 = split_86[3] + getitem_872 = split_86[4] + getitem_873 = split_86[5] + getitem_874 = split_86[6] + getitem_875 = split_86[7]; split_86 = None + cat_78 = torch.ops.aten.cat.default([getitem_868, getitem_869, getitem_870, getitem_871, getitem_872, getitem_873, getitem_874, getitem_875]); getitem_868 = getitem_869 = getitem_870 = getitem_871 = getitem_872 = getitem_873 = getitem_874 = getitem_875 = None + reduce_scatter_tensor_39 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_78, 'sum', 8, '1'); cat_78 = None + wait_tensor_255 = 
torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_39) + add_77 = torch.ops.aten.add.Tensor(add_75, wait_tensor_255); wait_tensor_255 = None + convert_element_type_647 = torch.ops.prims.convert_element_type.default(primals_180, torch.bfloat16) + all_gather_into_tensor_216 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_647, 32, '0'); convert_element_type_647 = None + wait_tensor_256 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_216); all_gather_into_tensor_216 = None + convert_element_type_648 = torch.ops.prims.convert_element_type.default(add_77, torch.float32) + pow_40 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_648, 2) + mean_39 = torch.ops.aten.mean.dim(pow_40, [2], True); pow_40 = None + add_78 = torch.ops.aten.add.Scalar(mean_39, 1e-05); mean_39 = None + rsqrt_39 = torch.ops.aten.rsqrt.default(add_78); add_78 = None + mul_156 = torch.ops.aten.mul.Tensor(convert_element_type_648, rsqrt_39); convert_element_type_648 = rsqrt_39 = None + mul_157 = torch.ops.aten.mul.Tensor(mul_156, wait_tensor_256); mul_156 = wait_tensor_256 = None + convert_element_type_649 = torch.ops.prims.convert_element_type.default(mul_157, torch.bfloat16); mul_157 = None + all_gather_into_tensor_217 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_649, 8, '1'); convert_element_type_649 = None + wait_tensor_257 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_217); all_gather_into_tensor_217 = None + split_87 = torch.ops.aten.split.Tensor(wait_tensor_257, 2); wait_tensor_257 = None + getitem_876 = split_87[0] + getitem_877 = split_87[1] + getitem_878 = split_87[2] + getitem_879 = split_87[3] + getitem_880 = split_87[4] + getitem_881 = split_87[5] + getitem_882 = split_87[6] + getitem_883 = split_87[7]; split_87 = None + cat_79 = torch.ops.aten.cat.default([getitem_876, getitem_877, getitem_878, getitem_879, getitem_880, getitem_881, getitem_882, 
getitem_883], 1); getitem_876 = getitem_877 = getitem_878 = getitem_879 = getitem_880 = getitem_881 = getitem_882 = getitem_883 = None + convert_element_type_650 = torch.ops.prims.convert_element_type.default(primals_181, torch.bfloat16) + all_gather_into_tensor_218 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_650, 32, '0'); convert_element_type_650 = None + wait_tensor_258 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_218); all_gather_into_tensor_218 = None + permute_217 = torch.ops.aten.permute.default(wait_tensor_258, [1, 0]); wait_tensor_258 = None + view_1428 = torch.ops.aten.view.default(cat_79, [16384, 4096]); cat_79 = None + mm_137 = torch.ops.aten.mm.default(view_1428, permute_217); permute_217 = None + view_1429 = torch.ops.aten.view.default(mm_137, [2, 8192, 1792]) + convert_element_type_653 = torch.ops.prims.convert_element_type.default(view_1429, torch.float32); view_1429 = None + sigmoid_19 = torch.ops.aten.sigmoid.default(convert_element_type_653) + mul_158 = torch.ops.aten.mul.Tensor(convert_element_type_653, sigmoid_19); convert_element_type_653 = sigmoid_19 = None + convert_element_type_654 = torch.ops.prims.convert_element_type.default(mul_158, torch.bfloat16); mul_158 = None + convert_element_type_655 = torch.ops.prims.convert_element_type.default(primals_182, torch.bfloat16) + all_gather_into_tensor_219 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_655, 32, '0'); convert_element_type_655 = None + wait_tensor_259 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_219); all_gather_into_tensor_219 = None + permute_218 = torch.ops.aten.permute.default(wait_tensor_259, [1, 0]); wait_tensor_259 = None + mm_138 = torch.ops.aten.mm.default(view_1428, permute_218); view_1428 = permute_218 = None + view_1436 = torch.ops.aten.view.default(mm_138, [2, 8192, 1792]); mm_138 = None + mul_159 = 
torch.ops.aten.mul.Tensor(convert_element_type_654, view_1436); convert_element_type_654 = view_1436 = None + convert_element_type_658 = torch.ops.prims.convert_element_type.default(primals_183, torch.bfloat16) + all_gather_into_tensor_220 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_658, 32, '0'); convert_element_type_658 = None + wait_tensor_260 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_220); all_gather_into_tensor_220 = None + permute_219 = torch.ops.aten.permute.default(wait_tensor_260, [1, 0]); wait_tensor_260 = None + view_1443 = torch.ops.aten.view.default(mul_159, [16384, 1792]); mul_159 = None + mm_139 = torch.ops.aten.mm.default(view_1443, permute_219); view_1443 = permute_219 = None + view_1444 = torch.ops.aten.view.default(mm_139, [2, 8192, 4096]); mm_139 = None + split_88 = torch.ops.aten.split.Tensor(view_1444, 1024, 1); view_1444 = None + getitem_884 = split_88[0] + getitem_885 = split_88[1] + getitem_886 = split_88[2] + getitem_887 = split_88[3] + getitem_888 = split_88[4] + getitem_889 = split_88[5] + getitem_890 = split_88[6] + getitem_891 = split_88[7]; split_88 = None + cat_80 = torch.ops.aten.cat.default([getitem_884, getitem_885, getitem_886, getitem_887, getitem_888, getitem_889, getitem_890, getitem_891]); getitem_884 = getitem_885 = getitem_886 = getitem_887 = getitem_888 = getitem_889 = getitem_890 = getitem_891 = None + reduce_scatter_tensor_40 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_80, 'sum', 8, '1'); cat_80 = None + wait_tensor_261 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_40); reduce_scatter_tensor_40 = None + add_79 = torch.ops.aten.add.Tensor(add_77, wait_tensor_261); add_77 = wait_tensor_261 = None + convert_element_type_661 = torch.ops.prims.convert_element_type.default(primals_184, torch.bfloat16) + all_gather_into_tensor_221 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_661, 32, 
'0'); convert_element_type_661 = None + wait_tensor_262 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_221); all_gather_into_tensor_221 = None + convert_element_type_662 = torch.ops.prims.convert_element_type.default(add_79, torch.float32) + pow_41 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_662, 2) + mean_40 = torch.ops.aten.mean.dim(pow_41, [2], True); pow_41 = None + add_80 = torch.ops.aten.add.Scalar(mean_40, 1e-05); mean_40 = None + rsqrt_40 = torch.ops.aten.rsqrt.default(add_80); add_80 = None + mul_160 = torch.ops.aten.mul.Tensor(convert_element_type_662, rsqrt_40); convert_element_type_662 = rsqrt_40 = None + mul_161 = torch.ops.aten.mul.Tensor(mul_160, wait_tensor_262); mul_160 = wait_tensor_262 = None + convert_element_type_663 = torch.ops.prims.convert_element_type.default(mul_161, torch.bfloat16); mul_161 = None + all_gather_into_tensor_222 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_663, 8, '1'); convert_element_type_663 = None + wait_tensor_263 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_222); all_gather_into_tensor_222 = None + split_89 = torch.ops.aten.split.Tensor(wait_tensor_263, 2); wait_tensor_263 = None + getitem_892 = split_89[0] + getitem_893 = split_89[1] + getitem_894 = split_89[2] + getitem_895 = split_89[3] + getitem_896 = split_89[4] + getitem_897 = split_89[5] + getitem_898 = split_89[6] + getitem_899 = split_89[7]; split_89 = None + cat_81 = torch.ops.aten.cat.default([getitem_892, getitem_893, getitem_894, getitem_895, getitem_896, getitem_897, getitem_898, getitem_899], 1); getitem_892 = getitem_893 = getitem_894 = getitem_895 = getitem_896 = getitem_897 = getitem_898 = getitem_899 = None + convert_element_type_664 = torch.ops.prims.convert_element_type.default(primals_185, torch.bfloat16) + all_gather_into_tensor_223 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_664, 32, '0'); 
convert_element_type_664 = None + wait_tensor_264 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_223); all_gather_into_tensor_223 = None + permute_220 = torch.ops.aten.permute.default(wait_tensor_264, [1, 0]); wait_tensor_264 = None + view_1455 = torch.ops.aten.view.default(cat_81, [16384, 4096]); cat_81 = None + mm_140 = torch.ops.aten.mm.default(view_1455, permute_220); permute_220 = None + view_1456 = torch.ops.aten.view.default(mm_140, [2, 8192, 512]) + convert_element_type_667 = torch.ops.prims.convert_element_type.default(primals_186, torch.bfloat16) + all_gather_into_tensor_224 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_667, 32, '0'); convert_element_type_667 = None + wait_tensor_265 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_224); all_gather_into_tensor_224 = None + permute_221 = torch.ops.aten.permute.default(wait_tensor_265, [1, 0]); wait_tensor_265 = None + mm_141 = torch.ops.aten.mm.default(view_1455, permute_221); permute_221 = None + view_1463 = torch.ops.aten.view.default(mm_141, [2, 8192, 128]); mm_141 = None + convert_element_type_670 = torch.ops.prims.convert_element_type.default(primals_187, torch.bfloat16) + all_gather_into_tensor_225 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_670, 32, '0'); convert_element_type_670 = None + wait_tensor_266 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_225); all_gather_into_tensor_225 = None + permute_222 = torch.ops.aten.permute.default(wait_tensor_266, [1, 0]); wait_tensor_266 = None + mm_142 = torch.ops.aten.mm.default(view_1455, permute_222); view_1455 = permute_222 = None + view_1470 = torch.ops.aten.view.default(mm_142, [2, 8192, 128]) + view_1472 = torch.ops.aten.view.default(view_1456, [2, 8192, -1, 128]); view_1456 = None + view_1473 = torch.ops.aten.view.default(view_1463, [2, 8192, -1, 128]); view_1463 = None + view_1474 = 
torch.ops.aten.view.default(view_1470, [2, 8192, -1, 128]); view_1470 = None + convert_element_type_673 = torch.ops.prims.convert_element_type.default(view_1472, torch.float32); view_1472 = None + view_1475 = torch.ops.aten.view.default(convert_element_type_673, [2, 8192, 4, -1, 2]); convert_element_type_673 = None + view_as_complex_40 = torch.ops.aten.view_as_complex.default(view_1475); view_1475 = None + convert_element_type_674 = torch.ops.prims.convert_element_type.default(view_1473, torch.float32); view_1473 = None + view_1476 = torch.ops.aten.view.default(convert_element_type_674, [2, 8192, 1, -1, 2]); convert_element_type_674 = None + view_as_complex_41 = torch.ops.aten.view_as_complex.default(view_1476); view_1476 = None + mul_162 = torch.ops.aten.mul.Tensor(view_as_complex_40, view_37); view_as_complex_40 = None + view_as_real_40 = torch.ops.aten.view_as_real.default(mul_162); mul_162 = None + view_1478 = torch.ops.aten.view.default(view_as_real_40, [2, 8192, 4, 128]); view_as_real_40 = None + mul_163 = torch.ops.aten.mul.Tensor(view_as_complex_41, view_37); view_as_complex_41 = None + view_as_real_41 = torch.ops.aten.view_as_real.default(mul_163); mul_163 = None + view_1479 = torch.ops.aten.view.default(view_as_real_41, [2, 8192, 1, 128]); view_as_real_41 = None + convert_element_type_675 = torch.ops.prims.convert_element_type.default(view_1478, torch.bfloat16); view_1478 = None + convert_element_type_676 = torch.ops.prims.convert_element_type.default(view_1479, torch.bfloat16); view_1479 = None + unsqueeze_40 = torch.ops.aten.unsqueeze.default(convert_element_type_676, 3); convert_element_type_676 = None + expand_40 = torch.ops.aten.expand.default(unsqueeze_40, [2, 8192, 1, 4, 128]); unsqueeze_40 = None + view_1480 = torch.ops.aten.view.default(expand_40, [2, 8192, 4, 128]); expand_40 = None + unsqueeze_41 = torch.ops.aten.unsqueeze.default(view_1474, 3); view_1474 = None + expand_41 = torch.ops.aten.expand.default(unsqueeze_41, [2, 8192, 1, 4, 128]); 
unsqueeze_41 = None + view_1481 = torch.ops.aten.view.default(expand_41, [2, 8192, 4, 128]); expand_41 = None + permute_223 = torch.ops.aten.permute.default(convert_element_type_675, [0, 2, 1, 3]); convert_element_type_675 = None + permute_224 = torch.ops.aten.permute.default(view_1480, [0, 2, 1, 3]); view_1480 = None + permute_225 = torch.ops.aten.permute.default(view_1481, [0, 2, 1, 3]); view_1481 = None + _scaled_dot_product_cudnn_attention_20 = torch.ops.aten._scaled_dot_product_cudnn_attention.default(permute_223, permute_224, permute_225, None, True, 0.0, True); permute_223 = permute_224 = permute_225 = None + getitem_900 = _scaled_dot_product_cudnn_attention_20[0] + getitem_901 = _scaled_dot_product_cudnn_attention_20[1] + getitem_906 = _scaled_dot_product_cudnn_attention_20[6] + getitem_907 = _scaled_dot_product_cudnn_attention_20[7]; _scaled_dot_product_cudnn_attention_20 = None + permute_226 = torch.ops.aten.permute.default(getitem_900, [0, 2, 1, 3]) + view_1482 = torch.ops.aten.view.default(permute_226, [2, 8192, -1]); permute_226 = None + convert_element_type_677 = torch.ops.prims.convert_element_type.default(primals_188, torch.bfloat16) + all_gather_into_tensor_226 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_677, 32, '0'); convert_element_type_677 = None + wait_tensor_267 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_226); all_gather_into_tensor_226 = None + permute_227 = torch.ops.aten.permute.default(wait_tensor_267, [1, 0]); wait_tensor_267 = None + view_1488 = torch.ops.aten.view.default(view_1482, [16384, 512]); view_1482 = None + mm_143 = torch.ops.aten.mm.default(view_1488, permute_227); view_1488 = permute_227 = None + view_1489 = torch.ops.aten.view.default(mm_143, [2, 8192, 4096]); mm_143 = None + split_90 = torch.ops.aten.split.Tensor(view_1489, 1024, 1); view_1489 = None + getitem_909 = split_90[0] + getitem_910 = split_90[1] + getitem_911 = split_90[2] + getitem_912 = 
split_90[3] + getitem_913 = split_90[4] + getitem_914 = split_90[5] + getitem_915 = split_90[6] + getitem_916 = split_90[7]; split_90 = None + cat_82 = torch.ops.aten.cat.default([getitem_909, getitem_910, getitem_911, getitem_912, getitem_913, getitem_914, getitem_915, getitem_916]); getitem_909 = getitem_910 = getitem_911 = getitem_912 = getitem_913 = getitem_914 = getitem_915 = getitem_916 = None + reduce_scatter_tensor_41 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_82, 'sum', 8, '1'); cat_82 = None + wait_tensor_268 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_41) + add_81 = torch.ops.aten.add.Tensor(add_79, wait_tensor_268); wait_tensor_268 = None + convert_element_type_680 = torch.ops.prims.convert_element_type.default(primals_189, torch.bfloat16) + all_gather_into_tensor_227 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_680, 32, '0'); convert_element_type_680 = None + wait_tensor_269 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_227); all_gather_into_tensor_227 = None + convert_element_type_681 = torch.ops.prims.convert_element_type.default(add_81, torch.float32) + pow_42 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_681, 2) + mean_41 = torch.ops.aten.mean.dim(pow_42, [2], True); pow_42 = None + add_82 = torch.ops.aten.add.Scalar(mean_41, 1e-05); mean_41 = None + rsqrt_41 = torch.ops.aten.rsqrt.default(add_82); add_82 = None + mul_164 = torch.ops.aten.mul.Tensor(convert_element_type_681, rsqrt_41); convert_element_type_681 = rsqrt_41 = None + mul_165 = torch.ops.aten.mul.Tensor(mul_164, wait_tensor_269); mul_164 = wait_tensor_269 = None + convert_element_type_682 = torch.ops.prims.convert_element_type.default(mul_165, torch.bfloat16); mul_165 = None + all_gather_into_tensor_228 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_682, 8, '1'); convert_element_type_682 = None + wait_tensor_270 = 
torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_228); all_gather_into_tensor_228 = None + split_91 = torch.ops.aten.split.Tensor(wait_tensor_270, 2); wait_tensor_270 = None + getitem_917 = split_91[0] + getitem_918 = split_91[1] + getitem_919 = split_91[2] + getitem_920 = split_91[3] + getitem_921 = split_91[4] + getitem_922 = split_91[5] + getitem_923 = split_91[6] + getitem_924 = split_91[7]; split_91 = None + cat_83 = torch.ops.aten.cat.default([getitem_917, getitem_918, getitem_919, getitem_920, getitem_921, getitem_922, getitem_923, getitem_924], 1); getitem_917 = getitem_918 = getitem_919 = getitem_920 = getitem_921 = getitem_922 = getitem_923 = getitem_924 = None + convert_element_type_683 = torch.ops.prims.convert_element_type.default(primals_190, torch.bfloat16) + all_gather_into_tensor_229 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_683, 32, '0'); convert_element_type_683 = None + wait_tensor_271 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_229); all_gather_into_tensor_229 = None + permute_228 = torch.ops.aten.permute.default(wait_tensor_271, [1, 0]); wait_tensor_271 = None + view_1500 = torch.ops.aten.view.default(cat_83, [16384, 4096]); cat_83 = None + mm_144 = torch.ops.aten.mm.default(view_1500, permute_228); permute_228 = None + view_1501 = torch.ops.aten.view.default(mm_144, [2, 8192, 1792]) + convert_element_type_686 = torch.ops.prims.convert_element_type.default(view_1501, torch.float32); view_1501 = None + sigmoid_20 = torch.ops.aten.sigmoid.default(convert_element_type_686) + mul_166 = torch.ops.aten.mul.Tensor(convert_element_type_686, sigmoid_20); convert_element_type_686 = sigmoid_20 = None + convert_element_type_687 = torch.ops.prims.convert_element_type.default(mul_166, torch.bfloat16); mul_166 = None + convert_element_type_688 = torch.ops.prims.convert_element_type.default(primals_191, torch.bfloat16) + all_gather_into_tensor_230 = 
torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_688, 32, '0'); convert_element_type_688 = None + wait_tensor_272 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_230); all_gather_into_tensor_230 = None + permute_229 = torch.ops.aten.permute.default(wait_tensor_272, [1, 0]); wait_tensor_272 = None + mm_145 = torch.ops.aten.mm.default(view_1500, permute_229); view_1500 = permute_229 = None + view_1508 = torch.ops.aten.view.default(mm_145, [2, 8192, 1792]); mm_145 = None + mul_167 = torch.ops.aten.mul.Tensor(convert_element_type_687, view_1508); convert_element_type_687 = view_1508 = None + convert_element_type_691 = torch.ops.prims.convert_element_type.default(primals_192, torch.bfloat16) + all_gather_into_tensor_231 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_691, 32, '0'); convert_element_type_691 = None + wait_tensor_273 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_231); all_gather_into_tensor_231 = None + permute_230 = torch.ops.aten.permute.default(wait_tensor_273, [1, 0]); wait_tensor_273 = None + view_1515 = torch.ops.aten.view.default(mul_167, [16384, 1792]); mul_167 = None + mm_146 = torch.ops.aten.mm.default(view_1515, permute_230); view_1515 = permute_230 = None + view_1516 = torch.ops.aten.view.default(mm_146, [2, 8192, 4096]); mm_146 = None + split_92 = torch.ops.aten.split.Tensor(view_1516, 1024, 1); view_1516 = None + getitem_925 = split_92[0] + getitem_926 = split_92[1] + getitem_927 = split_92[2] + getitem_928 = split_92[3] + getitem_929 = split_92[4] + getitem_930 = split_92[5] + getitem_931 = split_92[6] + getitem_932 = split_92[7]; split_92 = None + cat_84 = torch.ops.aten.cat.default([getitem_925, getitem_926, getitem_927, getitem_928, getitem_929, getitem_930, getitem_931, getitem_932]); getitem_925 = getitem_926 = getitem_927 = getitem_928 = getitem_929 = getitem_930 = getitem_931 = getitem_932 = None + reduce_scatter_tensor_42 
= torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_84, 'sum', 8, '1'); cat_84 = None + wait_tensor_274 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_42); reduce_scatter_tensor_42 = None + add_83 = torch.ops.aten.add.Tensor(add_81, wait_tensor_274); add_81 = wait_tensor_274 = None + convert_element_type_694 = torch.ops.prims.convert_element_type.default(primals_193, torch.bfloat16) + all_gather_into_tensor_232 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_694, 32, '0'); convert_element_type_694 = None + wait_tensor_275 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_232); all_gather_into_tensor_232 = None + convert_element_type_695 = torch.ops.prims.convert_element_type.default(add_83, torch.float32) + pow_43 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_695, 2) + mean_42 = torch.ops.aten.mean.dim(pow_43, [2], True); pow_43 = None + add_84 = torch.ops.aten.add.Scalar(mean_42, 1e-05); mean_42 = None + rsqrt_42 = torch.ops.aten.rsqrt.default(add_84); add_84 = None + mul_168 = torch.ops.aten.mul.Tensor(convert_element_type_695, rsqrt_42); convert_element_type_695 = rsqrt_42 = None + mul_169 = torch.ops.aten.mul.Tensor(mul_168, wait_tensor_275); mul_168 = wait_tensor_275 = None + convert_element_type_696 = torch.ops.prims.convert_element_type.default(mul_169, torch.bfloat16); mul_169 = None + all_gather_into_tensor_233 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_696, 8, '1'); convert_element_type_696 = None + wait_tensor_276 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_233); all_gather_into_tensor_233 = None + split_93 = torch.ops.aten.split.Tensor(wait_tensor_276, 2); wait_tensor_276 = None + getitem_933 = split_93[0] + getitem_934 = split_93[1] + getitem_935 = split_93[2] + getitem_936 = split_93[3] + getitem_937 = split_93[4] + getitem_938 = split_93[5] + getitem_939 = split_93[6] + getitem_940 = 
split_93[7]; split_93 = None + cat_85 = torch.ops.aten.cat.default([getitem_933, getitem_934, getitem_935, getitem_936, getitem_937, getitem_938, getitem_939, getitem_940], 1); getitem_933 = getitem_934 = getitem_935 = getitem_936 = getitem_937 = getitem_938 = getitem_939 = getitem_940 = None + convert_element_type_697 = torch.ops.prims.convert_element_type.default(primals_194, torch.bfloat16) + all_gather_into_tensor_234 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_697, 32, '0'); convert_element_type_697 = None + wait_tensor_277 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_234); all_gather_into_tensor_234 = None + permute_231 = torch.ops.aten.permute.default(wait_tensor_277, [1, 0]); wait_tensor_277 = None + view_1527 = torch.ops.aten.view.default(cat_85, [16384, 4096]); cat_85 = None + mm_147 = torch.ops.aten.mm.default(view_1527, permute_231); permute_231 = None + view_1528 = torch.ops.aten.view.default(mm_147, [2, 8192, 512]) + convert_element_type_700 = torch.ops.prims.convert_element_type.default(primals_195, torch.bfloat16) + all_gather_into_tensor_235 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_700, 32, '0'); convert_element_type_700 = None + wait_tensor_278 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_235); all_gather_into_tensor_235 = None + permute_232 = torch.ops.aten.permute.default(wait_tensor_278, [1, 0]); wait_tensor_278 = None + mm_148 = torch.ops.aten.mm.default(view_1527, permute_232); permute_232 = None + view_1535 = torch.ops.aten.view.default(mm_148, [2, 8192, 128]); mm_148 = None + convert_element_type_703 = torch.ops.prims.convert_element_type.default(primals_196, torch.bfloat16) + all_gather_into_tensor_236 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_703, 32, '0'); convert_element_type_703 = None + wait_tensor_279 = 
torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_236); all_gather_into_tensor_236 = None + permute_233 = torch.ops.aten.permute.default(wait_tensor_279, [1, 0]); wait_tensor_279 = None + mm_149 = torch.ops.aten.mm.default(view_1527, permute_233); view_1527 = permute_233 = None + view_1542 = torch.ops.aten.view.default(mm_149, [2, 8192, 128]) + view_1544 = torch.ops.aten.view.default(view_1528, [2, 8192, -1, 128]); view_1528 = None + view_1545 = torch.ops.aten.view.default(view_1535, [2, 8192, -1, 128]); view_1535 = None + view_1546 = torch.ops.aten.view.default(view_1542, [2, 8192, -1, 128]); view_1542 = None + convert_element_type_706 = torch.ops.prims.convert_element_type.default(view_1544, torch.float32); view_1544 = None + view_1547 = torch.ops.aten.view.default(convert_element_type_706, [2, 8192, 4, -1, 2]); convert_element_type_706 = None + view_as_complex_42 = torch.ops.aten.view_as_complex.default(view_1547); view_1547 = None + convert_element_type_707 = torch.ops.prims.convert_element_type.default(view_1545, torch.float32); view_1545 = None + view_1548 = torch.ops.aten.view.default(convert_element_type_707, [2, 8192, 1, -1, 2]); convert_element_type_707 = None + view_as_complex_43 = torch.ops.aten.view_as_complex.default(view_1548); view_1548 = None + mul_170 = torch.ops.aten.mul.Tensor(view_as_complex_42, view_37); view_as_complex_42 = None + view_as_real_42 = torch.ops.aten.view_as_real.default(mul_170); mul_170 = None + view_1550 = torch.ops.aten.view.default(view_as_real_42, [2, 8192, 4, 128]); view_as_real_42 = None + mul_171 = torch.ops.aten.mul.Tensor(view_as_complex_43, view_37); view_as_complex_43 = None + view_as_real_43 = torch.ops.aten.view_as_real.default(mul_171); mul_171 = None + view_1551 = torch.ops.aten.view.default(view_as_real_43, [2, 8192, 1, 128]); view_as_real_43 = None + convert_element_type_708 = torch.ops.prims.convert_element_type.default(view_1550, torch.bfloat16); view_1550 = None + convert_element_type_709 
= torch.ops.prims.convert_element_type.default(view_1551, torch.bfloat16); view_1551 = None + unsqueeze_42 = torch.ops.aten.unsqueeze.default(convert_element_type_709, 3); convert_element_type_709 = None + expand_42 = torch.ops.aten.expand.default(unsqueeze_42, [2, 8192, 1, 4, 128]); unsqueeze_42 = None + view_1552 = torch.ops.aten.view.default(expand_42, [2, 8192, 4, 128]); expand_42 = None + unsqueeze_43 = torch.ops.aten.unsqueeze.default(view_1546, 3); view_1546 = None + expand_43 = torch.ops.aten.expand.default(unsqueeze_43, [2, 8192, 1, 4, 128]); unsqueeze_43 = None + view_1553 = torch.ops.aten.view.default(expand_43, [2, 8192, 4, 128]); expand_43 = None + permute_234 = torch.ops.aten.permute.default(convert_element_type_708, [0, 2, 1, 3]); convert_element_type_708 = None + permute_235 = torch.ops.aten.permute.default(view_1552, [0, 2, 1, 3]); view_1552 = None + permute_236 = torch.ops.aten.permute.default(view_1553, [0, 2, 1, 3]); view_1553 = None + _scaled_dot_product_cudnn_attention_21 = torch.ops.aten._scaled_dot_product_cudnn_attention.default(permute_234, permute_235, permute_236, None, True, 0.0, True); permute_234 = permute_235 = permute_236 = None + getitem_941 = _scaled_dot_product_cudnn_attention_21[0] + getitem_942 = _scaled_dot_product_cudnn_attention_21[1] + getitem_947 = _scaled_dot_product_cudnn_attention_21[6] + getitem_948 = _scaled_dot_product_cudnn_attention_21[7]; _scaled_dot_product_cudnn_attention_21 = None + permute_237 = torch.ops.aten.permute.default(getitem_941, [0, 2, 1, 3]) + view_1554 = torch.ops.aten.view.default(permute_237, [2, 8192, -1]); permute_237 = None + convert_element_type_710 = torch.ops.prims.convert_element_type.default(primals_197, torch.bfloat16) + all_gather_into_tensor_237 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_710, 32, '0'); convert_element_type_710 = None + wait_tensor_280 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_237); 
all_gather_into_tensor_237 = None + permute_238 = torch.ops.aten.permute.default(wait_tensor_280, [1, 0]); wait_tensor_280 = None + view_1560 = torch.ops.aten.view.default(view_1554, [16384, 512]); view_1554 = None + mm_150 = torch.ops.aten.mm.default(view_1560, permute_238); view_1560 = permute_238 = None + view_1561 = torch.ops.aten.view.default(mm_150, [2, 8192, 4096]); mm_150 = None + split_94 = torch.ops.aten.split.Tensor(view_1561, 1024, 1); view_1561 = None + getitem_950 = split_94[0] + getitem_951 = split_94[1] + getitem_952 = split_94[2] + getitem_953 = split_94[3] + getitem_954 = split_94[4] + getitem_955 = split_94[5] + getitem_956 = split_94[6] + getitem_957 = split_94[7]; split_94 = None + cat_86 = torch.ops.aten.cat.default([getitem_950, getitem_951, getitem_952, getitem_953, getitem_954, getitem_955, getitem_956, getitem_957]); getitem_950 = getitem_951 = getitem_952 = getitem_953 = getitem_954 = getitem_955 = getitem_956 = getitem_957 = None + reduce_scatter_tensor_43 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_86, 'sum', 8, '1'); cat_86 = None + wait_tensor_281 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_43) + add_85 = torch.ops.aten.add.Tensor(add_83, wait_tensor_281); wait_tensor_281 = None + convert_element_type_713 = torch.ops.prims.convert_element_type.default(primals_198, torch.bfloat16) + all_gather_into_tensor_238 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_713, 32, '0'); convert_element_type_713 = None + wait_tensor_282 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_238); all_gather_into_tensor_238 = None + convert_element_type_714 = torch.ops.prims.convert_element_type.default(add_85, torch.float32) + pow_44 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_714, 2) + mean_43 = torch.ops.aten.mean.dim(pow_44, [2], True); pow_44 = None + add_86 = torch.ops.aten.add.Scalar(mean_43, 1e-05); mean_43 = None + rsqrt_43 = 
torch.ops.aten.rsqrt.default(add_86); add_86 = None + mul_172 = torch.ops.aten.mul.Tensor(convert_element_type_714, rsqrt_43); convert_element_type_714 = rsqrt_43 = None + mul_173 = torch.ops.aten.mul.Tensor(mul_172, wait_tensor_282); mul_172 = wait_tensor_282 = None + convert_element_type_715 = torch.ops.prims.convert_element_type.default(mul_173, torch.bfloat16); mul_173 = None + all_gather_into_tensor_239 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_715, 8, '1'); convert_element_type_715 = None + wait_tensor_283 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_239); all_gather_into_tensor_239 = None + split_95 = torch.ops.aten.split.Tensor(wait_tensor_283, 2); wait_tensor_283 = None + getitem_958 = split_95[0] + getitem_959 = split_95[1] + getitem_960 = split_95[2] + getitem_961 = split_95[3] + getitem_962 = split_95[4] + getitem_963 = split_95[5] + getitem_964 = split_95[6] + getitem_965 = split_95[7]; split_95 = None + cat_87 = torch.ops.aten.cat.default([getitem_958, getitem_959, getitem_960, getitem_961, getitem_962, getitem_963, getitem_964, getitem_965], 1); getitem_958 = getitem_959 = getitem_960 = getitem_961 = getitem_962 = getitem_963 = getitem_964 = getitem_965 = None + convert_element_type_716 = torch.ops.prims.convert_element_type.default(primals_199, torch.bfloat16) + all_gather_into_tensor_240 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_716, 32, '0'); convert_element_type_716 = None + wait_tensor_284 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_240); all_gather_into_tensor_240 = None + permute_239 = torch.ops.aten.permute.default(wait_tensor_284, [1, 0]); wait_tensor_284 = None + view_1572 = torch.ops.aten.view.default(cat_87, [16384, 4096]); cat_87 = None + mm_151 = torch.ops.aten.mm.default(view_1572, permute_239); permute_239 = None + view_1573 = torch.ops.aten.view.default(mm_151, [2, 8192, 1792]) + 
convert_element_type_719 = torch.ops.prims.convert_element_type.default(view_1573, torch.float32); view_1573 = None + sigmoid_21 = torch.ops.aten.sigmoid.default(convert_element_type_719) + mul_174 = torch.ops.aten.mul.Tensor(convert_element_type_719, sigmoid_21); convert_element_type_719 = sigmoid_21 = None + convert_element_type_720 = torch.ops.prims.convert_element_type.default(mul_174, torch.bfloat16); mul_174 = None + convert_element_type_721 = torch.ops.prims.convert_element_type.default(primals_200, torch.bfloat16) + all_gather_into_tensor_241 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_721, 32, '0'); convert_element_type_721 = None + wait_tensor_285 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_241); all_gather_into_tensor_241 = None + permute_240 = torch.ops.aten.permute.default(wait_tensor_285, [1, 0]); wait_tensor_285 = None + mm_152 = torch.ops.aten.mm.default(view_1572, permute_240); view_1572 = permute_240 = None + view_1580 = torch.ops.aten.view.default(mm_152, [2, 8192, 1792]); mm_152 = None + mul_175 = torch.ops.aten.mul.Tensor(convert_element_type_720, view_1580); convert_element_type_720 = view_1580 = None + convert_element_type_724 = torch.ops.prims.convert_element_type.default(primals_201, torch.bfloat16) + all_gather_into_tensor_242 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_724, 32, '0'); convert_element_type_724 = None + wait_tensor_286 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_242); all_gather_into_tensor_242 = None + permute_241 = torch.ops.aten.permute.default(wait_tensor_286, [1, 0]); wait_tensor_286 = None + view_1587 = torch.ops.aten.view.default(mul_175, [16384, 1792]); mul_175 = None + mm_153 = torch.ops.aten.mm.default(view_1587, permute_241); view_1587 = permute_241 = None + view_1588 = torch.ops.aten.view.default(mm_153, [2, 8192, 4096]); mm_153 = None + split_96 = 
torch.ops.aten.split.Tensor(view_1588, 1024, 1); view_1588 = None + getitem_966 = split_96[0] + getitem_967 = split_96[1] + getitem_968 = split_96[2] + getitem_969 = split_96[3] + getitem_970 = split_96[4] + getitem_971 = split_96[5] + getitem_972 = split_96[6] + getitem_973 = split_96[7]; split_96 = None + cat_88 = torch.ops.aten.cat.default([getitem_966, getitem_967, getitem_968, getitem_969, getitem_970, getitem_971, getitem_972, getitem_973]); getitem_966 = getitem_967 = getitem_968 = getitem_969 = getitem_970 = getitem_971 = getitem_972 = getitem_973 = None + reduce_scatter_tensor_44 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_88, 'sum', 8, '1'); cat_88 = None + wait_tensor_287 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_44); reduce_scatter_tensor_44 = None + add_87 = torch.ops.aten.add.Tensor(add_85, wait_tensor_287); add_85 = wait_tensor_287 = None + convert_element_type_727 = torch.ops.prims.convert_element_type.default(primals_202, torch.bfloat16) + all_gather_into_tensor_243 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_727, 32, '0'); convert_element_type_727 = None + wait_tensor_288 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_243); all_gather_into_tensor_243 = None + convert_element_type_728 = torch.ops.prims.convert_element_type.default(add_87, torch.float32) + pow_45 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_728, 2) + mean_44 = torch.ops.aten.mean.dim(pow_45, [2], True); pow_45 = None + add_88 = torch.ops.aten.add.Scalar(mean_44, 1e-05); mean_44 = None + rsqrt_44 = torch.ops.aten.rsqrt.default(add_88); add_88 = None + mul_176 = torch.ops.aten.mul.Tensor(convert_element_type_728, rsqrt_44); convert_element_type_728 = rsqrt_44 = None + mul_177 = torch.ops.aten.mul.Tensor(mul_176, wait_tensor_288); mul_176 = wait_tensor_288 = None + convert_element_type_729 = torch.ops.prims.convert_element_type.default(mul_177, torch.bfloat16); 
mul_177 = None + all_gather_into_tensor_244 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_729, 8, '1'); convert_element_type_729 = None + wait_tensor_289 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_244); all_gather_into_tensor_244 = None + split_97 = torch.ops.aten.split.Tensor(wait_tensor_289, 2); wait_tensor_289 = None + getitem_974 = split_97[0] + getitem_975 = split_97[1] + getitem_976 = split_97[2] + getitem_977 = split_97[3] + getitem_978 = split_97[4] + getitem_979 = split_97[5] + getitem_980 = split_97[6] + getitem_981 = split_97[7]; split_97 = None + cat_89 = torch.ops.aten.cat.default([getitem_974, getitem_975, getitem_976, getitem_977, getitem_978, getitem_979, getitem_980, getitem_981], 1); getitem_974 = getitem_975 = getitem_976 = getitem_977 = getitem_978 = getitem_979 = getitem_980 = getitem_981 = None + convert_element_type_730 = torch.ops.prims.convert_element_type.default(primals_203, torch.bfloat16) + all_gather_into_tensor_245 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_730, 32, '0'); convert_element_type_730 = None + wait_tensor_290 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_245); all_gather_into_tensor_245 = None + permute_242 = torch.ops.aten.permute.default(wait_tensor_290, [1, 0]); wait_tensor_290 = None + view_1599 = torch.ops.aten.view.default(cat_89, [16384, 4096]); cat_89 = None + mm_154 = torch.ops.aten.mm.default(view_1599, permute_242); permute_242 = None + view_1600 = torch.ops.aten.view.default(mm_154, [2, 8192, 512]) + convert_element_type_733 = torch.ops.prims.convert_element_type.default(primals_204, torch.bfloat16) + all_gather_into_tensor_246 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_733, 32, '0'); convert_element_type_733 = None + wait_tensor_291 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_246); all_gather_into_tensor_246 = None + 
permute_243 = torch.ops.aten.permute.default(wait_tensor_291, [1, 0]); wait_tensor_291 = None + mm_155 = torch.ops.aten.mm.default(view_1599, permute_243); permute_243 = None + view_1607 = torch.ops.aten.view.default(mm_155, [2, 8192, 128]); mm_155 = None + convert_element_type_736 = torch.ops.prims.convert_element_type.default(primals_205, torch.bfloat16) + all_gather_into_tensor_247 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_736, 32, '0'); convert_element_type_736 = None + wait_tensor_292 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_247); all_gather_into_tensor_247 = None + permute_244 = torch.ops.aten.permute.default(wait_tensor_292, [1, 0]); wait_tensor_292 = None + mm_156 = torch.ops.aten.mm.default(view_1599, permute_244); view_1599 = permute_244 = None + view_1614 = torch.ops.aten.view.default(mm_156, [2, 8192, 128]) + view_1616 = torch.ops.aten.view.default(view_1600, [2, 8192, -1, 128]); view_1600 = None + view_1617 = torch.ops.aten.view.default(view_1607, [2, 8192, -1, 128]); view_1607 = None + view_1618 = torch.ops.aten.view.default(view_1614, [2, 8192, -1, 128]); view_1614 = None + convert_element_type_739 = torch.ops.prims.convert_element_type.default(view_1616, torch.float32); view_1616 = None + view_1619 = torch.ops.aten.view.default(convert_element_type_739, [2, 8192, 4, -1, 2]); convert_element_type_739 = None + view_as_complex_44 = torch.ops.aten.view_as_complex.default(view_1619); view_1619 = None + convert_element_type_740 = torch.ops.prims.convert_element_type.default(view_1617, torch.float32); view_1617 = None + view_1620 = torch.ops.aten.view.default(convert_element_type_740, [2, 8192, 1, -1, 2]); convert_element_type_740 = None + view_as_complex_45 = torch.ops.aten.view_as_complex.default(view_1620); view_1620 = None + mul_178 = torch.ops.aten.mul.Tensor(view_as_complex_44, view_37); view_as_complex_44 = None + view_as_real_44 = torch.ops.aten.view_as_real.default(mul_178); 
mul_178 = None + view_1622 = torch.ops.aten.view.default(view_as_real_44, [2, 8192, 4, 128]); view_as_real_44 = None + mul_179 = torch.ops.aten.mul.Tensor(view_as_complex_45, view_37); view_as_complex_45 = None + view_as_real_45 = torch.ops.aten.view_as_real.default(mul_179); mul_179 = None + view_1623 = torch.ops.aten.view.default(view_as_real_45, [2, 8192, 1, 128]); view_as_real_45 = None + convert_element_type_741 = torch.ops.prims.convert_element_type.default(view_1622, torch.bfloat16); view_1622 = None + convert_element_type_742 = torch.ops.prims.convert_element_type.default(view_1623, torch.bfloat16); view_1623 = None + unsqueeze_44 = torch.ops.aten.unsqueeze.default(convert_element_type_742, 3); convert_element_type_742 = None + expand_44 = torch.ops.aten.expand.default(unsqueeze_44, [2, 8192, 1, 4, 128]); unsqueeze_44 = None + view_1624 = torch.ops.aten.view.default(expand_44, [2, 8192, 4, 128]); expand_44 = None + unsqueeze_45 = torch.ops.aten.unsqueeze.default(view_1618, 3); view_1618 = None + expand_45 = torch.ops.aten.expand.default(unsqueeze_45, [2, 8192, 1, 4, 128]); unsqueeze_45 = None + view_1625 = torch.ops.aten.view.default(expand_45, [2, 8192, 4, 128]); expand_45 = None + permute_245 = torch.ops.aten.permute.default(convert_element_type_741, [0, 2, 1, 3]); convert_element_type_741 = None + permute_246 = torch.ops.aten.permute.default(view_1624, [0, 2, 1, 3]); view_1624 = None + permute_247 = torch.ops.aten.permute.default(view_1625, [0, 2, 1, 3]); view_1625 = None + _scaled_dot_product_cudnn_attention_22 = torch.ops.aten._scaled_dot_product_cudnn_attention.default(permute_245, permute_246, permute_247, None, True, 0.0, True); permute_245 = permute_246 = permute_247 = None + getitem_982 = _scaled_dot_product_cudnn_attention_22[0] + getitem_983 = _scaled_dot_product_cudnn_attention_22[1] + getitem_988 = _scaled_dot_product_cudnn_attention_22[6] + getitem_989 = _scaled_dot_product_cudnn_attention_22[7]; _scaled_dot_product_cudnn_attention_22 = None 
+ permute_248 = torch.ops.aten.permute.default(getitem_982, [0, 2, 1, 3]) + view_1626 = torch.ops.aten.view.default(permute_248, [2, 8192, -1]); permute_248 = None + convert_element_type_743 = torch.ops.prims.convert_element_type.default(primals_206, torch.bfloat16) + all_gather_into_tensor_248 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_743, 32, '0'); convert_element_type_743 = None + wait_tensor_293 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_248); all_gather_into_tensor_248 = None + permute_249 = torch.ops.aten.permute.default(wait_tensor_293, [1, 0]); wait_tensor_293 = None + view_1632 = torch.ops.aten.view.default(view_1626, [16384, 512]); view_1626 = None + mm_157 = torch.ops.aten.mm.default(view_1632, permute_249); view_1632 = permute_249 = None + view_1633 = torch.ops.aten.view.default(mm_157, [2, 8192, 4096]); mm_157 = None + split_98 = torch.ops.aten.split.Tensor(view_1633, 1024, 1); view_1633 = None + getitem_991 = split_98[0] + getitem_992 = split_98[1] + getitem_993 = split_98[2] + getitem_994 = split_98[3] + getitem_995 = split_98[4] + getitem_996 = split_98[5] + getitem_997 = split_98[6] + getitem_998 = split_98[7]; split_98 = None + cat_90 = torch.ops.aten.cat.default([getitem_991, getitem_992, getitem_993, getitem_994, getitem_995, getitem_996, getitem_997, getitem_998]); getitem_991 = getitem_992 = getitem_993 = getitem_994 = getitem_995 = getitem_996 = getitem_997 = getitem_998 = None + reduce_scatter_tensor_45 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_90, 'sum', 8, '1'); cat_90 = None + wait_tensor_294 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_45) + add_89 = torch.ops.aten.add.Tensor(add_87, wait_tensor_294); wait_tensor_294 = None + convert_element_type_746 = torch.ops.prims.convert_element_type.default(primals_207, torch.bfloat16) + all_gather_into_tensor_249 = 
torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_746, 32, '0'); convert_element_type_746 = None + wait_tensor_295 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_249); all_gather_into_tensor_249 = None + convert_element_type_747 = torch.ops.prims.convert_element_type.default(add_89, torch.float32) + pow_46 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_747, 2) + mean_45 = torch.ops.aten.mean.dim(pow_46, [2], True); pow_46 = None + add_90 = torch.ops.aten.add.Scalar(mean_45, 1e-05); mean_45 = None + rsqrt_45 = torch.ops.aten.rsqrt.default(add_90); add_90 = None + mul_180 = torch.ops.aten.mul.Tensor(convert_element_type_747, rsqrt_45); convert_element_type_747 = rsqrt_45 = None + mul_181 = torch.ops.aten.mul.Tensor(mul_180, wait_tensor_295); mul_180 = wait_tensor_295 = None + convert_element_type_748 = torch.ops.prims.convert_element_type.default(mul_181, torch.bfloat16); mul_181 = None + all_gather_into_tensor_250 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_748, 8, '1'); convert_element_type_748 = None + wait_tensor_296 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_250); all_gather_into_tensor_250 = None + split_99 = torch.ops.aten.split.Tensor(wait_tensor_296, 2); wait_tensor_296 = None + getitem_999 = split_99[0] + getitem_1000 = split_99[1] + getitem_1001 = split_99[2] + getitem_1002 = split_99[3] + getitem_1003 = split_99[4] + getitem_1004 = split_99[5] + getitem_1005 = split_99[6] + getitem_1006 = split_99[7]; split_99 = None + cat_91 = torch.ops.aten.cat.default([getitem_999, getitem_1000, getitem_1001, getitem_1002, getitem_1003, getitem_1004, getitem_1005, getitem_1006], 1); getitem_999 = getitem_1000 = getitem_1001 = getitem_1002 = getitem_1003 = getitem_1004 = getitem_1005 = getitem_1006 = None + convert_element_type_749 = torch.ops.prims.convert_element_type.default(primals_208, torch.bfloat16) + all_gather_into_tensor_251 = 
torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_749, 32, '0'); convert_element_type_749 = None + wait_tensor_297 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_251); all_gather_into_tensor_251 = None + permute_250 = torch.ops.aten.permute.default(wait_tensor_297, [1, 0]); wait_tensor_297 = None + view_1644 = torch.ops.aten.view.default(cat_91, [16384, 4096]); cat_91 = None + mm_158 = torch.ops.aten.mm.default(view_1644, permute_250); permute_250 = None + view_1645 = torch.ops.aten.view.default(mm_158, [2, 8192, 1792]) + convert_element_type_752 = torch.ops.prims.convert_element_type.default(view_1645, torch.float32); view_1645 = None + sigmoid_22 = torch.ops.aten.sigmoid.default(convert_element_type_752) + mul_182 = torch.ops.aten.mul.Tensor(convert_element_type_752, sigmoid_22); convert_element_type_752 = sigmoid_22 = None + convert_element_type_753 = torch.ops.prims.convert_element_type.default(mul_182, torch.bfloat16); mul_182 = None + convert_element_type_754 = torch.ops.prims.convert_element_type.default(primals_209, torch.bfloat16) + all_gather_into_tensor_252 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_754, 32, '0'); convert_element_type_754 = None + wait_tensor_298 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_252); all_gather_into_tensor_252 = None + permute_251 = torch.ops.aten.permute.default(wait_tensor_298, [1, 0]); wait_tensor_298 = None + mm_159 = torch.ops.aten.mm.default(view_1644, permute_251); view_1644 = permute_251 = None + view_1652 = torch.ops.aten.view.default(mm_159, [2, 8192, 1792]); mm_159 = None + mul_183 = torch.ops.aten.mul.Tensor(convert_element_type_753, view_1652); convert_element_type_753 = view_1652 = None + convert_element_type_757 = torch.ops.prims.convert_element_type.default(primals_210, torch.bfloat16) + all_gather_into_tensor_253 = 
torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_757, 32, '0'); convert_element_type_757 = None + wait_tensor_299 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_253); all_gather_into_tensor_253 = None + permute_252 = torch.ops.aten.permute.default(wait_tensor_299, [1, 0]); wait_tensor_299 = None + view_1659 = torch.ops.aten.view.default(mul_183, [16384, 1792]); mul_183 = None + mm_160 = torch.ops.aten.mm.default(view_1659, permute_252); view_1659 = permute_252 = None + view_1660 = torch.ops.aten.view.default(mm_160, [2, 8192, 4096]); mm_160 = None + split_100 = torch.ops.aten.split.Tensor(view_1660, 1024, 1); view_1660 = None + getitem_1007 = split_100[0] + getitem_1008 = split_100[1] + getitem_1009 = split_100[2] + getitem_1010 = split_100[3] + getitem_1011 = split_100[4] + getitem_1012 = split_100[5] + getitem_1013 = split_100[6] + getitem_1014 = split_100[7]; split_100 = None + cat_92 = torch.ops.aten.cat.default([getitem_1007, getitem_1008, getitem_1009, getitem_1010, getitem_1011, getitem_1012, getitem_1013, getitem_1014]); getitem_1007 = getitem_1008 = getitem_1009 = getitem_1010 = getitem_1011 = getitem_1012 = getitem_1013 = getitem_1014 = None + reduce_scatter_tensor_46 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_92, 'sum', 8, '1'); cat_92 = None + wait_tensor_300 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_46); reduce_scatter_tensor_46 = None + add_91 = torch.ops.aten.add.Tensor(add_89, wait_tensor_300); add_89 = wait_tensor_300 = None + convert_element_type_760 = torch.ops.prims.convert_element_type.default(primals_211, torch.bfloat16) + all_gather_into_tensor_254 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_760, 32, '0'); convert_element_type_760 = None + wait_tensor_301 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_254); all_gather_into_tensor_254 = None + convert_element_type_761 = 
torch.ops.prims.convert_element_type.default(add_91, torch.float32) + pow_47 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_761, 2) + mean_46 = torch.ops.aten.mean.dim(pow_47, [2], True); pow_47 = None + add_92 = torch.ops.aten.add.Scalar(mean_46, 1e-05); mean_46 = None + rsqrt_46 = torch.ops.aten.rsqrt.default(add_92); add_92 = None + mul_184 = torch.ops.aten.mul.Tensor(convert_element_type_761, rsqrt_46); convert_element_type_761 = rsqrt_46 = None + mul_185 = torch.ops.aten.mul.Tensor(mul_184, wait_tensor_301); mul_184 = wait_tensor_301 = None + convert_element_type_762 = torch.ops.prims.convert_element_type.default(mul_185, torch.bfloat16); mul_185 = None + all_gather_into_tensor_255 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_762, 8, '1'); convert_element_type_762 = None + wait_tensor_302 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_255); all_gather_into_tensor_255 = None + split_101 = torch.ops.aten.split.Tensor(wait_tensor_302, 2); wait_tensor_302 = None + getitem_1015 = split_101[0] + getitem_1016 = split_101[1] + getitem_1017 = split_101[2] + getitem_1018 = split_101[3] + getitem_1019 = split_101[4] + getitem_1020 = split_101[5] + getitem_1021 = split_101[6] + getitem_1022 = split_101[7]; split_101 = None + cat_93 = torch.ops.aten.cat.default([getitem_1015, getitem_1016, getitem_1017, getitem_1018, getitem_1019, getitem_1020, getitem_1021, getitem_1022], 1); getitem_1015 = getitem_1016 = getitem_1017 = getitem_1018 = getitem_1019 = getitem_1020 = getitem_1021 = getitem_1022 = None + convert_element_type_763 = torch.ops.prims.convert_element_type.default(primals_212, torch.bfloat16) + all_gather_into_tensor_256 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_763, 32, '0'); convert_element_type_763 = None + wait_tensor_303 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_256); all_gather_into_tensor_256 = None + permute_253 = 
torch.ops.aten.permute.default(wait_tensor_303, [1, 0]); wait_tensor_303 = None + view_1671 = torch.ops.aten.view.default(cat_93, [16384, 4096]); cat_93 = None + mm_161 = torch.ops.aten.mm.default(view_1671, permute_253); permute_253 = None + view_1672 = torch.ops.aten.view.default(mm_161, [2, 8192, 512]) + convert_element_type_766 = torch.ops.prims.convert_element_type.default(primals_213, torch.bfloat16) + all_gather_into_tensor_257 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_766, 32, '0'); convert_element_type_766 = None + wait_tensor_304 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_257); all_gather_into_tensor_257 = None + permute_254 = torch.ops.aten.permute.default(wait_tensor_304, [1, 0]); wait_tensor_304 = None + mm_162 = torch.ops.aten.mm.default(view_1671, permute_254); permute_254 = None + view_1679 = torch.ops.aten.view.default(mm_162, [2, 8192, 128]); mm_162 = None + convert_element_type_769 = torch.ops.prims.convert_element_type.default(primals_214, torch.bfloat16) + all_gather_into_tensor_258 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_769, 32, '0'); convert_element_type_769 = None + wait_tensor_305 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_258); all_gather_into_tensor_258 = None + permute_255 = torch.ops.aten.permute.default(wait_tensor_305, [1, 0]); wait_tensor_305 = None + mm_163 = torch.ops.aten.mm.default(view_1671, permute_255); view_1671 = permute_255 = None + view_1686 = torch.ops.aten.view.default(mm_163, [2, 8192, 128]) + view_1688 = torch.ops.aten.view.default(view_1672, [2, 8192, -1, 128]); view_1672 = None + view_1689 = torch.ops.aten.view.default(view_1679, [2, 8192, -1, 128]); view_1679 = None + view_1690 = torch.ops.aten.view.default(view_1686, [2, 8192, -1, 128]); view_1686 = None + convert_element_type_772 = torch.ops.prims.convert_element_type.default(view_1688, torch.float32); view_1688 = None + 
view_1691 = torch.ops.aten.view.default(convert_element_type_772, [2, 8192, 4, -1, 2]); convert_element_type_772 = None + view_as_complex_46 = torch.ops.aten.view_as_complex.default(view_1691); view_1691 = None + convert_element_type_773 = torch.ops.prims.convert_element_type.default(view_1689, torch.float32); view_1689 = None + view_1692 = torch.ops.aten.view.default(convert_element_type_773, [2, 8192, 1, -1, 2]); convert_element_type_773 = None + view_as_complex_47 = torch.ops.aten.view_as_complex.default(view_1692); view_1692 = None + mul_186 = torch.ops.aten.mul.Tensor(view_as_complex_46, view_37); view_as_complex_46 = None + view_as_real_46 = torch.ops.aten.view_as_real.default(mul_186); mul_186 = None + view_1694 = torch.ops.aten.view.default(view_as_real_46, [2, 8192, 4, 128]); view_as_real_46 = None + mul_187 = torch.ops.aten.mul.Tensor(view_as_complex_47, view_37); view_as_complex_47 = None + view_as_real_47 = torch.ops.aten.view_as_real.default(mul_187); mul_187 = None + view_1695 = torch.ops.aten.view.default(view_as_real_47, [2, 8192, 1, 128]); view_as_real_47 = None + convert_element_type_774 = torch.ops.prims.convert_element_type.default(view_1694, torch.bfloat16); view_1694 = None + convert_element_type_775 = torch.ops.prims.convert_element_type.default(view_1695, torch.bfloat16); view_1695 = None + unsqueeze_46 = torch.ops.aten.unsqueeze.default(convert_element_type_775, 3); convert_element_type_775 = None + expand_46 = torch.ops.aten.expand.default(unsqueeze_46, [2, 8192, 1, 4, 128]); unsqueeze_46 = None + view_1696 = torch.ops.aten.view.default(expand_46, [2, 8192, 4, 128]); expand_46 = None + unsqueeze_47 = torch.ops.aten.unsqueeze.default(view_1690, 3); view_1690 = None + expand_47 = torch.ops.aten.expand.default(unsqueeze_47, [2, 8192, 1, 4, 128]); unsqueeze_47 = None + view_1697 = torch.ops.aten.view.default(expand_47, [2, 8192, 4, 128]); expand_47 = None + permute_256 = torch.ops.aten.permute.default(convert_element_type_774, [0, 2, 1, 3]); 
convert_element_type_774 = None + permute_257 = torch.ops.aten.permute.default(view_1696, [0, 2, 1, 3]); view_1696 = None + permute_258 = torch.ops.aten.permute.default(view_1697, [0, 2, 1, 3]); view_1697 = None + _scaled_dot_product_cudnn_attention_23 = torch.ops.aten._scaled_dot_product_cudnn_attention.default(permute_256, permute_257, permute_258, None, True, 0.0, True); permute_256 = permute_257 = permute_258 = None + getitem_1023 = _scaled_dot_product_cudnn_attention_23[0] + getitem_1024 = _scaled_dot_product_cudnn_attention_23[1] + getitem_1029 = _scaled_dot_product_cudnn_attention_23[6] + getitem_1030 = _scaled_dot_product_cudnn_attention_23[7]; _scaled_dot_product_cudnn_attention_23 = None + permute_259 = torch.ops.aten.permute.default(getitem_1023, [0, 2, 1, 3]) + view_1698 = torch.ops.aten.view.default(permute_259, [2, 8192, -1]); permute_259 = None + convert_element_type_776 = torch.ops.prims.convert_element_type.default(primals_215, torch.bfloat16) + all_gather_into_tensor_259 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_776, 32, '0'); convert_element_type_776 = None + wait_tensor_306 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_259); all_gather_into_tensor_259 = None + permute_260 = torch.ops.aten.permute.default(wait_tensor_306, [1, 0]); wait_tensor_306 = None + view_1704 = torch.ops.aten.view.default(view_1698, [16384, 512]); view_1698 = None + mm_164 = torch.ops.aten.mm.default(view_1704, permute_260); view_1704 = permute_260 = None + view_1705 = torch.ops.aten.view.default(mm_164, [2, 8192, 4096]); mm_164 = None + split_102 = torch.ops.aten.split.Tensor(view_1705, 1024, 1); view_1705 = None + getitem_1032 = split_102[0] + getitem_1033 = split_102[1] + getitem_1034 = split_102[2] + getitem_1035 = split_102[3] + getitem_1036 = split_102[4] + getitem_1037 = split_102[5] + getitem_1038 = split_102[6] + getitem_1039 = split_102[7]; split_102 = None + cat_94 = 
torch.ops.aten.cat.default([getitem_1032, getitem_1033, getitem_1034, getitem_1035, getitem_1036, getitem_1037, getitem_1038, getitem_1039]); getitem_1032 = getitem_1033 = getitem_1034 = getitem_1035 = getitem_1036 = getitem_1037 = getitem_1038 = getitem_1039 = None + reduce_scatter_tensor_47 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_94, 'sum', 8, '1'); cat_94 = None + wait_tensor_307 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_47) + add_93 = torch.ops.aten.add.Tensor(add_91, wait_tensor_307); wait_tensor_307 = None + convert_element_type_779 = torch.ops.prims.convert_element_type.default(primals_216, torch.bfloat16) + all_gather_into_tensor_260 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_779, 32, '0'); convert_element_type_779 = None + wait_tensor_308 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_260); all_gather_into_tensor_260 = None + convert_element_type_780 = torch.ops.prims.convert_element_type.default(add_93, torch.float32) + pow_48 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_780, 2) + mean_47 = torch.ops.aten.mean.dim(pow_48, [2], True); pow_48 = None + add_94 = torch.ops.aten.add.Scalar(mean_47, 1e-05); mean_47 = None + rsqrt_47 = torch.ops.aten.rsqrt.default(add_94); add_94 = None + mul_188 = torch.ops.aten.mul.Tensor(convert_element_type_780, rsqrt_47); convert_element_type_780 = rsqrt_47 = None + mul_189 = torch.ops.aten.mul.Tensor(mul_188, wait_tensor_308); mul_188 = wait_tensor_308 = None + convert_element_type_781 = torch.ops.prims.convert_element_type.default(mul_189, torch.bfloat16); mul_189 = None + all_gather_into_tensor_261 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_781, 8, '1'); convert_element_type_781 = None + wait_tensor_309 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_261); all_gather_into_tensor_261 = None + split_103 = 
torch.ops.aten.split.Tensor(wait_tensor_309, 2); wait_tensor_309 = None + getitem_1040 = split_103[0] + getitem_1041 = split_103[1] + getitem_1042 = split_103[2] + getitem_1043 = split_103[3] + getitem_1044 = split_103[4] + getitem_1045 = split_103[5] + getitem_1046 = split_103[6] + getitem_1047 = split_103[7]; split_103 = None + cat_95 = torch.ops.aten.cat.default([getitem_1040, getitem_1041, getitem_1042, getitem_1043, getitem_1044, getitem_1045, getitem_1046, getitem_1047], 1); getitem_1040 = getitem_1041 = getitem_1042 = getitem_1043 = getitem_1044 = getitem_1045 = getitem_1046 = getitem_1047 = None + convert_element_type_782 = torch.ops.prims.convert_element_type.default(primals_217, torch.bfloat16) + all_gather_into_tensor_262 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_782, 32, '0'); convert_element_type_782 = None + wait_tensor_310 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_262); all_gather_into_tensor_262 = None + permute_261 = torch.ops.aten.permute.default(wait_tensor_310, [1, 0]); wait_tensor_310 = None + view_1716 = torch.ops.aten.view.default(cat_95, [16384, 4096]); cat_95 = None + mm_165 = torch.ops.aten.mm.default(view_1716, permute_261); permute_261 = None + view_1717 = torch.ops.aten.view.default(mm_165, [2, 8192, 1792]) + convert_element_type_785 = torch.ops.prims.convert_element_type.default(view_1717, torch.float32); view_1717 = None + sigmoid_23 = torch.ops.aten.sigmoid.default(convert_element_type_785) + mul_190 = torch.ops.aten.mul.Tensor(convert_element_type_785, sigmoid_23); convert_element_type_785 = sigmoid_23 = None + convert_element_type_786 = torch.ops.prims.convert_element_type.default(mul_190, torch.bfloat16); mul_190 = None + convert_element_type_787 = torch.ops.prims.convert_element_type.default(primals_218, torch.bfloat16) + all_gather_into_tensor_263 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_787, 32, '0'); 
convert_element_type_787 = None + wait_tensor_311 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_263); all_gather_into_tensor_263 = None + permute_262 = torch.ops.aten.permute.default(wait_tensor_311, [1, 0]); wait_tensor_311 = None + mm_166 = torch.ops.aten.mm.default(view_1716, permute_262); view_1716 = permute_262 = None + view_1724 = torch.ops.aten.view.default(mm_166, [2, 8192, 1792]); mm_166 = None + mul_191 = torch.ops.aten.mul.Tensor(convert_element_type_786, view_1724); convert_element_type_786 = view_1724 = None + convert_element_type_790 = torch.ops.prims.convert_element_type.default(primals_219, torch.bfloat16) + all_gather_into_tensor_264 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_790, 32, '0'); convert_element_type_790 = None + wait_tensor_312 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_264); all_gather_into_tensor_264 = None + permute_263 = torch.ops.aten.permute.default(wait_tensor_312, [1, 0]); wait_tensor_312 = None + view_1731 = torch.ops.aten.view.default(mul_191, [16384, 1792]); mul_191 = None + mm_167 = torch.ops.aten.mm.default(view_1731, permute_263); view_1731 = permute_263 = None + view_1732 = torch.ops.aten.view.default(mm_167, [2, 8192, 4096]); mm_167 = None + split_104 = torch.ops.aten.split.Tensor(view_1732, 1024, 1); view_1732 = None + getitem_1048 = split_104[0] + getitem_1049 = split_104[1] + getitem_1050 = split_104[2] + getitem_1051 = split_104[3] + getitem_1052 = split_104[4] + getitem_1053 = split_104[5] + getitem_1054 = split_104[6] + getitem_1055 = split_104[7]; split_104 = None + cat_96 = torch.ops.aten.cat.default([getitem_1048, getitem_1049, getitem_1050, getitem_1051, getitem_1052, getitem_1053, getitem_1054, getitem_1055]); getitem_1048 = getitem_1049 = getitem_1050 = getitem_1051 = getitem_1052 = getitem_1053 = getitem_1054 = getitem_1055 = None + reduce_scatter_tensor_48 = 
torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_96, 'sum', 8, '1'); cat_96 = None + wait_tensor_313 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_48); reduce_scatter_tensor_48 = None + add_95 = torch.ops.aten.add.Tensor(add_93, wait_tensor_313); add_93 = wait_tensor_313 = None + convert_element_type_793 = torch.ops.prims.convert_element_type.default(primals_220, torch.bfloat16) + all_gather_into_tensor_265 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_793, 32, '0'); convert_element_type_793 = None + wait_tensor_314 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_265); all_gather_into_tensor_265 = None + convert_element_type_794 = torch.ops.prims.convert_element_type.default(add_95, torch.float32) + pow_49 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_794, 2) + mean_48 = torch.ops.aten.mean.dim(pow_49, [2], True); pow_49 = None + add_96 = torch.ops.aten.add.Scalar(mean_48, 1e-05); mean_48 = None + rsqrt_48 = torch.ops.aten.rsqrt.default(add_96); add_96 = None + mul_192 = torch.ops.aten.mul.Tensor(convert_element_type_794, rsqrt_48); convert_element_type_794 = rsqrt_48 = None + mul_193 = torch.ops.aten.mul.Tensor(mul_192, wait_tensor_314); mul_192 = wait_tensor_314 = None + convert_element_type_795 = torch.ops.prims.convert_element_type.default(mul_193, torch.bfloat16); mul_193 = None + all_gather_into_tensor_266 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_795, 8, '1'); convert_element_type_795 = None + wait_tensor_315 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_266); all_gather_into_tensor_266 = None + split_105 = torch.ops.aten.split.Tensor(wait_tensor_315, 2); wait_tensor_315 = None + getitem_1056 = split_105[0] + getitem_1057 = split_105[1] + getitem_1058 = split_105[2] + getitem_1059 = split_105[3] + getitem_1060 = split_105[4] + getitem_1061 = split_105[5] + getitem_1062 = split_105[6] + 
getitem_1063 = split_105[7]; split_105 = None + cat_97 = torch.ops.aten.cat.default([getitem_1056, getitem_1057, getitem_1058, getitem_1059, getitem_1060, getitem_1061, getitem_1062, getitem_1063], 1); getitem_1056 = getitem_1057 = getitem_1058 = getitem_1059 = getitem_1060 = getitem_1061 = getitem_1062 = getitem_1063 = None + convert_element_type_796 = torch.ops.prims.convert_element_type.default(primals_221, torch.bfloat16) + all_gather_into_tensor_267 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_796, 32, '0'); convert_element_type_796 = None + wait_tensor_316 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_267); all_gather_into_tensor_267 = None + permute_264 = torch.ops.aten.permute.default(wait_tensor_316, [1, 0]); wait_tensor_316 = None + view_1743 = torch.ops.aten.view.default(cat_97, [16384, 4096]); cat_97 = None + mm_168 = torch.ops.aten.mm.default(view_1743, permute_264); permute_264 = None + view_1744 = torch.ops.aten.view.default(mm_168, [2, 8192, 512]) + convert_element_type_799 = torch.ops.prims.convert_element_type.default(primals_222, torch.bfloat16) + all_gather_into_tensor_268 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_799, 32, '0'); convert_element_type_799 = None + wait_tensor_317 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_268); all_gather_into_tensor_268 = None + permute_265 = torch.ops.aten.permute.default(wait_tensor_317, [1, 0]); wait_tensor_317 = None + mm_169 = torch.ops.aten.mm.default(view_1743, permute_265); permute_265 = None + view_1751 = torch.ops.aten.view.default(mm_169, [2, 8192, 128]); mm_169 = None + convert_element_type_802 = torch.ops.prims.convert_element_type.default(primals_223, torch.bfloat16) + all_gather_into_tensor_269 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_802, 32, '0'); convert_element_type_802 = None + wait_tensor_318 = 
torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_269); all_gather_into_tensor_269 = None + permute_266 = torch.ops.aten.permute.default(wait_tensor_318, [1, 0]); wait_tensor_318 = None + mm_170 = torch.ops.aten.mm.default(view_1743, permute_266); view_1743 = permute_266 = None + view_1758 = torch.ops.aten.view.default(mm_170, [2, 8192, 128]) + view_1760 = torch.ops.aten.view.default(view_1744, [2, 8192, -1, 128]); view_1744 = None + view_1761 = torch.ops.aten.view.default(view_1751, [2, 8192, -1, 128]); view_1751 = None + view_1762 = torch.ops.aten.view.default(view_1758, [2, 8192, -1, 128]); view_1758 = None + convert_element_type_805 = torch.ops.prims.convert_element_type.default(view_1760, torch.float32); view_1760 = None + view_1763 = torch.ops.aten.view.default(convert_element_type_805, [2, 8192, 4, -1, 2]); convert_element_type_805 = None + view_as_complex_48 = torch.ops.aten.view_as_complex.default(view_1763); view_1763 = None + convert_element_type_806 = torch.ops.prims.convert_element_type.default(view_1761, torch.float32); view_1761 = None + view_1764 = torch.ops.aten.view.default(convert_element_type_806, [2, 8192, 1, -1, 2]); convert_element_type_806 = None + view_as_complex_49 = torch.ops.aten.view_as_complex.default(view_1764); view_1764 = None + mul_194 = torch.ops.aten.mul.Tensor(view_as_complex_48, view_37); view_as_complex_48 = None + view_as_real_48 = torch.ops.aten.view_as_real.default(mul_194); mul_194 = None + view_1766 = torch.ops.aten.view.default(view_as_real_48, [2, 8192, 4, 128]); view_as_real_48 = None + mul_195 = torch.ops.aten.mul.Tensor(view_as_complex_49, view_37); view_as_complex_49 = None + view_as_real_49 = torch.ops.aten.view_as_real.default(mul_195); mul_195 = None + view_1767 = torch.ops.aten.view.default(view_as_real_49, [2, 8192, 1, 128]); view_as_real_49 = None + convert_element_type_807 = torch.ops.prims.convert_element_type.default(view_1766, torch.bfloat16); view_1766 = None + convert_element_type_808 
= torch.ops.prims.convert_element_type.default(view_1767, torch.bfloat16); view_1767 = None + unsqueeze_48 = torch.ops.aten.unsqueeze.default(convert_element_type_808, 3); convert_element_type_808 = None + expand_48 = torch.ops.aten.expand.default(unsqueeze_48, [2, 8192, 1, 4, 128]); unsqueeze_48 = None + view_1768 = torch.ops.aten.view.default(expand_48, [2, 8192, 4, 128]); expand_48 = None + unsqueeze_49 = torch.ops.aten.unsqueeze.default(view_1762, 3); view_1762 = None + expand_49 = torch.ops.aten.expand.default(unsqueeze_49, [2, 8192, 1, 4, 128]); unsqueeze_49 = None + view_1769 = torch.ops.aten.view.default(expand_49, [2, 8192, 4, 128]); expand_49 = None + permute_267 = torch.ops.aten.permute.default(convert_element_type_807, [0, 2, 1, 3]); convert_element_type_807 = None + permute_268 = torch.ops.aten.permute.default(view_1768, [0, 2, 1, 3]); view_1768 = None + permute_269 = torch.ops.aten.permute.default(view_1769, [0, 2, 1, 3]); view_1769 = None + _scaled_dot_product_cudnn_attention_24 = torch.ops.aten._scaled_dot_product_cudnn_attention.default(permute_267, permute_268, permute_269, None, True, 0.0, True); permute_267 = permute_268 = permute_269 = None + getitem_1064 = _scaled_dot_product_cudnn_attention_24[0] + getitem_1065 = _scaled_dot_product_cudnn_attention_24[1] + getitem_1070 = _scaled_dot_product_cudnn_attention_24[6] + getitem_1071 = _scaled_dot_product_cudnn_attention_24[7]; _scaled_dot_product_cudnn_attention_24 = None + permute_270 = torch.ops.aten.permute.default(getitem_1064, [0, 2, 1, 3]) + view_1770 = torch.ops.aten.view.default(permute_270, [2, 8192, -1]); permute_270 = None + convert_element_type_809 = torch.ops.prims.convert_element_type.default(primals_224, torch.bfloat16) + all_gather_into_tensor_270 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_809, 32, '0'); convert_element_type_809 = None + wait_tensor_319 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_270); 
all_gather_into_tensor_270 = None + permute_271 = torch.ops.aten.permute.default(wait_tensor_319, [1, 0]); wait_tensor_319 = None + view_1776 = torch.ops.aten.view.default(view_1770, [16384, 512]); view_1770 = None + mm_171 = torch.ops.aten.mm.default(view_1776, permute_271); view_1776 = permute_271 = None + view_1777 = torch.ops.aten.view.default(mm_171, [2, 8192, 4096]); mm_171 = None + split_106 = torch.ops.aten.split.Tensor(view_1777, 1024, 1); view_1777 = None + getitem_1073 = split_106[0] + getitem_1074 = split_106[1] + getitem_1075 = split_106[2] + getitem_1076 = split_106[3] + getitem_1077 = split_106[4] + getitem_1078 = split_106[5] + getitem_1079 = split_106[6] + getitem_1080 = split_106[7]; split_106 = None + cat_98 = torch.ops.aten.cat.default([getitem_1073, getitem_1074, getitem_1075, getitem_1076, getitem_1077, getitem_1078, getitem_1079, getitem_1080]); getitem_1073 = getitem_1074 = getitem_1075 = getitem_1076 = getitem_1077 = getitem_1078 = getitem_1079 = getitem_1080 = None + reduce_scatter_tensor_49 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_98, 'sum', 8, '1'); cat_98 = None + wait_tensor_320 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_49) + add_97 = torch.ops.aten.add.Tensor(add_95, wait_tensor_320); wait_tensor_320 = None + convert_element_type_812 = torch.ops.prims.convert_element_type.default(primals_225, torch.bfloat16) + all_gather_into_tensor_271 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_812, 32, '0'); convert_element_type_812 = None + wait_tensor_321 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_271); all_gather_into_tensor_271 = None + convert_element_type_813 = torch.ops.prims.convert_element_type.default(add_97, torch.float32) + pow_50 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_813, 2) + mean_49 = torch.ops.aten.mean.dim(pow_50, [2], True); pow_50 = None + add_98 = torch.ops.aten.add.Scalar(mean_49, 1e-05); 
mean_49 = None + rsqrt_49 = torch.ops.aten.rsqrt.default(add_98); add_98 = None + mul_196 = torch.ops.aten.mul.Tensor(convert_element_type_813, rsqrt_49); convert_element_type_813 = rsqrt_49 = None + mul_197 = torch.ops.aten.mul.Tensor(mul_196, wait_tensor_321); mul_196 = wait_tensor_321 = None + convert_element_type_814 = torch.ops.prims.convert_element_type.default(mul_197, torch.bfloat16); mul_197 = None + all_gather_into_tensor_272 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_814, 8, '1'); convert_element_type_814 = None + wait_tensor_322 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_272); all_gather_into_tensor_272 = None + split_107 = torch.ops.aten.split.Tensor(wait_tensor_322, 2); wait_tensor_322 = None + getitem_1081 = split_107[0] + getitem_1082 = split_107[1] + getitem_1083 = split_107[2] + getitem_1084 = split_107[3] + getitem_1085 = split_107[4] + getitem_1086 = split_107[5] + getitem_1087 = split_107[6] + getitem_1088 = split_107[7]; split_107 = None + cat_99 = torch.ops.aten.cat.default([getitem_1081, getitem_1082, getitem_1083, getitem_1084, getitem_1085, getitem_1086, getitem_1087, getitem_1088], 1); getitem_1081 = getitem_1082 = getitem_1083 = getitem_1084 = getitem_1085 = getitem_1086 = getitem_1087 = getitem_1088 = None + convert_element_type_815 = torch.ops.prims.convert_element_type.default(primals_226, torch.bfloat16) + all_gather_into_tensor_273 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_815, 32, '0'); convert_element_type_815 = None + wait_tensor_323 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_273); all_gather_into_tensor_273 = None + permute_272 = torch.ops.aten.permute.default(wait_tensor_323, [1, 0]); wait_tensor_323 = None + view_1788 = torch.ops.aten.view.default(cat_99, [16384, 4096]); cat_99 = None + mm_172 = torch.ops.aten.mm.default(view_1788, permute_272); permute_272 = None + view_1789 = 
torch.ops.aten.view.default(mm_172, [2, 8192, 1792]) + convert_element_type_818 = torch.ops.prims.convert_element_type.default(view_1789, torch.float32); view_1789 = None + sigmoid_24 = torch.ops.aten.sigmoid.default(convert_element_type_818) + mul_198 = torch.ops.aten.mul.Tensor(convert_element_type_818, sigmoid_24); convert_element_type_818 = sigmoid_24 = None + convert_element_type_819 = torch.ops.prims.convert_element_type.default(mul_198, torch.bfloat16); mul_198 = None + convert_element_type_820 = torch.ops.prims.convert_element_type.default(primals_227, torch.bfloat16) + all_gather_into_tensor_274 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_820, 32, '0'); convert_element_type_820 = None + wait_tensor_324 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_274); all_gather_into_tensor_274 = None + permute_273 = torch.ops.aten.permute.default(wait_tensor_324, [1, 0]); wait_tensor_324 = None + mm_173 = torch.ops.aten.mm.default(view_1788, permute_273); view_1788 = permute_273 = None + view_1796 = torch.ops.aten.view.default(mm_173, [2, 8192, 1792]); mm_173 = None + mul_199 = torch.ops.aten.mul.Tensor(convert_element_type_819, view_1796); convert_element_type_819 = view_1796 = None + convert_element_type_823 = torch.ops.prims.convert_element_type.default(primals_228, torch.bfloat16) + all_gather_into_tensor_275 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_823, 32, '0'); convert_element_type_823 = None + wait_tensor_325 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_275); all_gather_into_tensor_275 = None + permute_274 = torch.ops.aten.permute.default(wait_tensor_325, [1, 0]); wait_tensor_325 = None + view_1803 = torch.ops.aten.view.default(mul_199, [16384, 1792]); mul_199 = None + mm_174 = torch.ops.aten.mm.default(view_1803, permute_274); view_1803 = permute_274 = None + view_1804 = torch.ops.aten.view.default(mm_174, [2, 8192, 4096]); mm_174 = 
None + split_108 = torch.ops.aten.split.Tensor(view_1804, 1024, 1); view_1804 = None + getitem_1089 = split_108[0] + getitem_1090 = split_108[1] + getitem_1091 = split_108[2] + getitem_1092 = split_108[3] + getitem_1093 = split_108[4] + getitem_1094 = split_108[5] + getitem_1095 = split_108[6] + getitem_1096 = split_108[7]; split_108 = None + cat_100 = torch.ops.aten.cat.default([getitem_1089, getitem_1090, getitem_1091, getitem_1092, getitem_1093, getitem_1094, getitem_1095, getitem_1096]); getitem_1089 = getitem_1090 = getitem_1091 = getitem_1092 = getitem_1093 = getitem_1094 = getitem_1095 = getitem_1096 = None + reduce_scatter_tensor_50 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_100, 'sum', 8, '1'); cat_100 = None + wait_tensor_326 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_50); reduce_scatter_tensor_50 = None + add_99 = torch.ops.aten.add.Tensor(add_97, wait_tensor_326); add_97 = wait_tensor_326 = None + convert_element_type_826 = torch.ops.prims.convert_element_type.default(primals_229, torch.bfloat16) + all_gather_into_tensor_276 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_826, 32, '0'); convert_element_type_826 = None + wait_tensor_327 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_276); all_gather_into_tensor_276 = None + convert_element_type_827 = torch.ops.prims.convert_element_type.default(add_99, torch.float32) + pow_51 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_827, 2) + mean_50 = torch.ops.aten.mean.dim(pow_51, [2], True); pow_51 = None + add_100 = torch.ops.aten.add.Scalar(mean_50, 1e-05); mean_50 = None + rsqrt_50 = torch.ops.aten.rsqrt.default(add_100); add_100 = None + mul_200 = torch.ops.aten.mul.Tensor(convert_element_type_827, rsqrt_50); convert_element_type_827 = rsqrt_50 = None + mul_201 = torch.ops.aten.mul.Tensor(mul_200, wait_tensor_327); mul_200 = wait_tensor_327 = None + convert_element_type_828 = 
torch.ops.prims.convert_element_type.default(mul_201, torch.bfloat16); mul_201 = None + all_gather_into_tensor_277 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_828, 8, '1'); convert_element_type_828 = None + wait_tensor_328 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_277); all_gather_into_tensor_277 = None + split_109 = torch.ops.aten.split.Tensor(wait_tensor_328, 2); wait_tensor_328 = None + getitem_1097 = split_109[0] + getitem_1098 = split_109[1] + getitem_1099 = split_109[2] + getitem_1100 = split_109[3] + getitem_1101 = split_109[4] + getitem_1102 = split_109[5] + getitem_1103 = split_109[6] + getitem_1104 = split_109[7]; split_109 = None + cat_101 = torch.ops.aten.cat.default([getitem_1097, getitem_1098, getitem_1099, getitem_1100, getitem_1101, getitem_1102, getitem_1103, getitem_1104], 1); getitem_1097 = getitem_1098 = getitem_1099 = getitem_1100 = getitem_1101 = getitem_1102 = getitem_1103 = getitem_1104 = None + convert_element_type_829 = torch.ops.prims.convert_element_type.default(primals_230, torch.bfloat16) + all_gather_into_tensor_278 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_829, 32, '0'); convert_element_type_829 = None + wait_tensor_329 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_278); all_gather_into_tensor_278 = None + permute_275 = torch.ops.aten.permute.default(wait_tensor_329, [1, 0]); wait_tensor_329 = None + view_1815 = torch.ops.aten.view.default(cat_101, [16384, 4096]); cat_101 = None + mm_175 = torch.ops.aten.mm.default(view_1815, permute_275); permute_275 = None + view_1816 = torch.ops.aten.view.default(mm_175, [2, 8192, 512]) + convert_element_type_832 = torch.ops.prims.convert_element_type.default(primals_231, torch.bfloat16) + all_gather_into_tensor_279 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_832, 32, '0'); convert_element_type_832 = None + wait_tensor_330 = 
torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_279); all_gather_into_tensor_279 = None + permute_276 = torch.ops.aten.permute.default(wait_tensor_330, [1, 0]); wait_tensor_330 = None + mm_176 = torch.ops.aten.mm.default(view_1815, permute_276); permute_276 = None + view_1823 = torch.ops.aten.view.default(mm_176, [2, 8192, 128]); mm_176 = None + convert_element_type_835 = torch.ops.prims.convert_element_type.default(primals_232, torch.bfloat16) + all_gather_into_tensor_280 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_835, 32, '0'); convert_element_type_835 = None + wait_tensor_331 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_280); all_gather_into_tensor_280 = None + permute_277 = torch.ops.aten.permute.default(wait_tensor_331, [1, 0]); wait_tensor_331 = None + mm_177 = torch.ops.aten.mm.default(view_1815, permute_277); view_1815 = permute_277 = None + view_1830 = torch.ops.aten.view.default(mm_177, [2, 8192, 128]) + view_1832 = torch.ops.aten.view.default(view_1816, [2, 8192, -1, 128]); view_1816 = None + view_1833 = torch.ops.aten.view.default(view_1823, [2, 8192, -1, 128]); view_1823 = None + view_1834 = torch.ops.aten.view.default(view_1830, [2, 8192, -1, 128]); view_1830 = None + convert_element_type_838 = torch.ops.prims.convert_element_type.default(view_1832, torch.float32); view_1832 = None + view_1835 = torch.ops.aten.view.default(convert_element_type_838, [2, 8192, 4, -1, 2]); convert_element_type_838 = None + view_as_complex_50 = torch.ops.aten.view_as_complex.default(view_1835); view_1835 = None + convert_element_type_839 = torch.ops.prims.convert_element_type.default(view_1833, torch.float32); view_1833 = None + view_1836 = torch.ops.aten.view.default(convert_element_type_839, [2, 8192, 1, -1, 2]); convert_element_type_839 = None + view_as_complex_51 = torch.ops.aten.view_as_complex.default(view_1836); view_1836 = None + mul_202 = 
torch.ops.aten.mul.Tensor(view_as_complex_50, view_37); view_as_complex_50 = None + view_as_real_50 = torch.ops.aten.view_as_real.default(mul_202); mul_202 = None + view_1838 = torch.ops.aten.view.default(view_as_real_50, [2, 8192, 4, 128]); view_as_real_50 = None + mul_203 = torch.ops.aten.mul.Tensor(view_as_complex_51, view_37); view_as_complex_51 = None + view_as_real_51 = torch.ops.aten.view_as_real.default(mul_203); mul_203 = None + view_1839 = torch.ops.aten.view.default(view_as_real_51, [2, 8192, 1, 128]); view_as_real_51 = None + convert_element_type_840 = torch.ops.prims.convert_element_type.default(view_1838, torch.bfloat16); view_1838 = None + convert_element_type_841 = torch.ops.prims.convert_element_type.default(view_1839, torch.bfloat16); view_1839 = None + unsqueeze_50 = torch.ops.aten.unsqueeze.default(convert_element_type_841, 3); convert_element_type_841 = None + expand_50 = torch.ops.aten.expand.default(unsqueeze_50, [2, 8192, 1, 4, 128]); unsqueeze_50 = None + view_1840 = torch.ops.aten.view.default(expand_50, [2, 8192, 4, 128]); expand_50 = None + unsqueeze_51 = torch.ops.aten.unsqueeze.default(view_1834, 3); view_1834 = None + expand_51 = torch.ops.aten.expand.default(unsqueeze_51, [2, 8192, 1, 4, 128]); unsqueeze_51 = None + view_1841 = torch.ops.aten.view.default(expand_51, [2, 8192, 4, 128]); expand_51 = None + permute_278 = torch.ops.aten.permute.default(convert_element_type_840, [0, 2, 1, 3]); convert_element_type_840 = None + permute_279 = torch.ops.aten.permute.default(view_1840, [0, 2, 1, 3]); view_1840 = None + permute_280 = torch.ops.aten.permute.default(view_1841, [0, 2, 1, 3]); view_1841 = None + _scaled_dot_product_cudnn_attention_25 = torch.ops.aten._scaled_dot_product_cudnn_attention.default(permute_278, permute_279, permute_280, None, True, 0.0, True); permute_278 = permute_279 = permute_280 = None + getitem_1105 = _scaled_dot_product_cudnn_attention_25[0] + getitem_1106 = _scaled_dot_product_cudnn_attention_25[1] + 
getitem_1111 = _scaled_dot_product_cudnn_attention_25[6] + getitem_1112 = _scaled_dot_product_cudnn_attention_25[7]; _scaled_dot_product_cudnn_attention_25 = None + permute_281 = torch.ops.aten.permute.default(getitem_1105, [0, 2, 1, 3]) + view_1842 = torch.ops.aten.view.default(permute_281, [2, 8192, -1]); permute_281 = None + convert_element_type_842 = torch.ops.prims.convert_element_type.default(primals_233, torch.bfloat16) + all_gather_into_tensor_281 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_842, 32, '0'); convert_element_type_842 = None + wait_tensor_332 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_281); all_gather_into_tensor_281 = None + permute_282 = torch.ops.aten.permute.default(wait_tensor_332, [1, 0]); wait_tensor_332 = None + view_1848 = torch.ops.aten.view.default(view_1842, [16384, 512]); view_1842 = None + mm_178 = torch.ops.aten.mm.default(view_1848, permute_282); view_1848 = permute_282 = None + view_1849 = torch.ops.aten.view.default(mm_178, [2, 8192, 4096]); mm_178 = None + split_110 = torch.ops.aten.split.Tensor(view_1849, 1024, 1); view_1849 = None + getitem_1114 = split_110[0] + getitem_1115 = split_110[1] + getitem_1116 = split_110[2] + getitem_1117 = split_110[3] + getitem_1118 = split_110[4] + getitem_1119 = split_110[5] + getitem_1120 = split_110[6] + getitem_1121 = split_110[7]; split_110 = None + cat_102 = torch.ops.aten.cat.default([getitem_1114, getitem_1115, getitem_1116, getitem_1117, getitem_1118, getitem_1119, getitem_1120, getitem_1121]); getitem_1114 = getitem_1115 = getitem_1116 = getitem_1117 = getitem_1118 = getitem_1119 = getitem_1120 = getitem_1121 = None + reduce_scatter_tensor_51 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_102, 'sum', 8, '1'); cat_102 = None + wait_tensor_333 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_51) + add_101 = torch.ops.aten.add.Tensor(add_99, wait_tensor_333); wait_tensor_333 = None + 
convert_element_type_845 = torch.ops.prims.convert_element_type.default(primals_234, torch.bfloat16) + all_gather_into_tensor_282 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_845, 32, '0'); convert_element_type_845 = None + wait_tensor_334 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_282); all_gather_into_tensor_282 = None + convert_element_type_846 = torch.ops.prims.convert_element_type.default(add_101, torch.float32) + pow_52 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_846, 2) + mean_51 = torch.ops.aten.mean.dim(pow_52, [2], True); pow_52 = None + add_102 = torch.ops.aten.add.Scalar(mean_51, 1e-05); mean_51 = None + rsqrt_51 = torch.ops.aten.rsqrt.default(add_102); add_102 = None + mul_204 = torch.ops.aten.mul.Tensor(convert_element_type_846, rsqrt_51); convert_element_type_846 = rsqrt_51 = None + mul_205 = torch.ops.aten.mul.Tensor(mul_204, wait_tensor_334); mul_204 = wait_tensor_334 = None + convert_element_type_847 = torch.ops.prims.convert_element_type.default(mul_205, torch.bfloat16); mul_205 = None + all_gather_into_tensor_283 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_847, 8, '1'); convert_element_type_847 = None + wait_tensor_335 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_283); all_gather_into_tensor_283 = None + split_111 = torch.ops.aten.split.Tensor(wait_tensor_335, 2); wait_tensor_335 = None + getitem_1122 = split_111[0] + getitem_1123 = split_111[1] + getitem_1124 = split_111[2] + getitem_1125 = split_111[3] + getitem_1126 = split_111[4] + getitem_1127 = split_111[5] + getitem_1128 = split_111[6] + getitem_1129 = split_111[7]; split_111 = None + cat_103 = torch.ops.aten.cat.default([getitem_1122, getitem_1123, getitem_1124, getitem_1125, getitem_1126, getitem_1127, getitem_1128, getitem_1129], 1); getitem_1122 = getitem_1123 = getitem_1124 = getitem_1125 = getitem_1126 = getitem_1127 = getitem_1128 = 
getitem_1129 = None + convert_element_type_848 = torch.ops.prims.convert_element_type.default(primals_235, torch.bfloat16) + all_gather_into_tensor_284 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_848, 32, '0'); convert_element_type_848 = None + wait_tensor_336 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_284); all_gather_into_tensor_284 = None + permute_283 = torch.ops.aten.permute.default(wait_tensor_336, [1, 0]); wait_tensor_336 = None + view_1860 = torch.ops.aten.view.default(cat_103, [16384, 4096]); cat_103 = None + mm_179 = torch.ops.aten.mm.default(view_1860, permute_283); permute_283 = None + view_1861 = torch.ops.aten.view.default(mm_179, [2, 8192, 1792]) + convert_element_type_851 = torch.ops.prims.convert_element_type.default(view_1861, torch.float32); view_1861 = None + sigmoid_25 = torch.ops.aten.sigmoid.default(convert_element_type_851) + mul_206 = torch.ops.aten.mul.Tensor(convert_element_type_851, sigmoid_25); convert_element_type_851 = sigmoid_25 = None + convert_element_type_852 = torch.ops.prims.convert_element_type.default(mul_206, torch.bfloat16); mul_206 = None + convert_element_type_853 = torch.ops.prims.convert_element_type.default(primals_236, torch.bfloat16) + all_gather_into_tensor_285 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_853, 32, '0'); convert_element_type_853 = None + wait_tensor_337 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_285); all_gather_into_tensor_285 = None + permute_284 = torch.ops.aten.permute.default(wait_tensor_337, [1, 0]); wait_tensor_337 = None + mm_180 = torch.ops.aten.mm.default(view_1860, permute_284); view_1860 = permute_284 = None + view_1868 = torch.ops.aten.view.default(mm_180, [2, 8192, 1792]); mm_180 = None + mul_207 = torch.ops.aten.mul.Tensor(convert_element_type_852, view_1868); convert_element_type_852 = view_1868 = None + convert_element_type_856 = 
torch.ops.prims.convert_element_type.default(primals_237, torch.bfloat16) + all_gather_into_tensor_286 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_856, 32, '0'); convert_element_type_856 = None + wait_tensor_338 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_286); all_gather_into_tensor_286 = None + permute_285 = torch.ops.aten.permute.default(wait_tensor_338, [1, 0]); wait_tensor_338 = None + view_1875 = torch.ops.aten.view.default(mul_207, [16384, 1792]); mul_207 = None + mm_181 = torch.ops.aten.mm.default(view_1875, permute_285); view_1875 = permute_285 = None + view_1876 = torch.ops.aten.view.default(mm_181, [2, 8192, 4096]); mm_181 = None + split_112 = torch.ops.aten.split.Tensor(view_1876, 1024, 1); view_1876 = None + getitem_1130 = split_112[0] + getitem_1131 = split_112[1] + getitem_1132 = split_112[2] + getitem_1133 = split_112[3] + getitem_1134 = split_112[4] + getitem_1135 = split_112[5] + getitem_1136 = split_112[6] + getitem_1137 = split_112[7]; split_112 = None + cat_104 = torch.ops.aten.cat.default([getitem_1130, getitem_1131, getitem_1132, getitem_1133, getitem_1134, getitem_1135, getitem_1136, getitem_1137]); getitem_1130 = getitem_1131 = getitem_1132 = getitem_1133 = getitem_1134 = getitem_1135 = getitem_1136 = getitem_1137 = None + reduce_scatter_tensor_52 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_104, 'sum', 8, '1'); cat_104 = None + wait_tensor_339 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_52); reduce_scatter_tensor_52 = None + add_103 = torch.ops.aten.add.Tensor(add_101, wait_tensor_339); add_101 = wait_tensor_339 = None + convert_element_type_859 = torch.ops.prims.convert_element_type.default(primals_238, torch.bfloat16) + all_gather_into_tensor_287 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_859, 32, '0'); convert_element_type_859 = None + wait_tensor_340 = 
torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_287); all_gather_into_tensor_287 = None + convert_element_type_860 = torch.ops.prims.convert_element_type.default(add_103, torch.float32) + pow_53 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_860, 2) + mean_52 = torch.ops.aten.mean.dim(pow_53, [2], True); pow_53 = None + add_104 = torch.ops.aten.add.Scalar(mean_52, 1e-05); mean_52 = None + rsqrt_52 = torch.ops.aten.rsqrt.default(add_104); add_104 = None + mul_208 = torch.ops.aten.mul.Tensor(convert_element_type_860, rsqrt_52); convert_element_type_860 = rsqrt_52 = None + mul_209 = torch.ops.aten.mul.Tensor(mul_208, wait_tensor_340); mul_208 = wait_tensor_340 = None + convert_element_type_861 = torch.ops.prims.convert_element_type.default(mul_209, torch.bfloat16); mul_209 = None + all_gather_into_tensor_288 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_861, 8, '1'); convert_element_type_861 = None + wait_tensor_341 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_288); all_gather_into_tensor_288 = None + split_113 = torch.ops.aten.split.Tensor(wait_tensor_341, 2); wait_tensor_341 = None + getitem_1138 = split_113[0] + getitem_1139 = split_113[1] + getitem_1140 = split_113[2] + getitem_1141 = split_113[3] + getitem_1142 = split_113[4] + getitem_1143 = split_113[5] + getitem_1144 = split_113[6] + getitem_1145 = split_113[7]; split_113 = None + cat_105 = torch.ops.aten.cat.default([getitem_1138, getitem_1139, getitem_1140, getitem_1141, getitem_1142, getitem_1143, getitem_1144, getitem_1145], 1); getitem_1138 = getitem_1139 = getitem_1140 = getitem_1141 = getitem_1142 = getitem_1143 = getitem_1144 = getitem_1145 = None + convert_element_type_862 = torch.ops.prims.convert_element_type.default(primals_239, torch.bfloat16) + all_gather_into_tensor_289 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_862, 32, '0'); convert_element_type_862 = None + 
wait_tensor_342 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_289); all_gather_into_tensor_289 = None + permute_286 = torch.ops.aten.permute.default(wait_tensor_342, [1, 0]); wait_tensor_342 = None + view_1887 = torch.ops.aten.view.default(cat_105, [16384, 4096]); cat_105 = None + mm_182 = torch.ops.aten.mm.default(view_1887, permute_286); permute_286 = None + view_1888 = torch.ops.aten.view.default(mm_182, [2, 8192, 512]) + convert_element_type_865 = torch.ops.prims.convert_element_type.default(primals_240, torch.bfloat16) + all_gather_into_tensor_290 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_865, 32, '0'); convert_element_type_865 = None + wait_tensor_343 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_290); all_gather_into_tensor_290 = None + permute_287 = torch.ops.aten.permute.default(wait_tensor_343, [1, 0]); wait_tensor_343 = None + mm_183 = torch.ops.aten.mm.default(view_1887, permute_287); permute_287 = None + view_1895 = torch.ops.aten.view.default(mm_183, [2, 8192, 128]); mm_183 = None + convert_element_type_868 = torch.ops.prims.convert_element_type.default(primals_241, torch.bfloat16) + all_gather_into_tensor_291 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_868, 32, '0'); convert_element_type_868 = None + wait_tensor_344 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_291); all_gather_into_tensor_291 = None + permute_288 = torch.ops.aten.permute.default(wait_tensor_344, [1, 0]); wait_tensor_344 = None + mm_184 = torch.ops.aten.mm.default(view_1887, permute_288); view_1887 = permute_288 = None + view_1902 = torch.ops.aten.view.default(mm_184, [2, 8192, 128]) + view_1904 = torch.ops.aten.view.default(view_1888, [2, 8192, -1, 128]); view_1888 = None + view_1905 = torch.ops.aten.view.default(view_1895, [2, 8192, -1, 128]); view_1895 = None + view_1906 = torch.ops.aten.view.default(view_1902, [2, 8192, -1, 
128]); view_1902 = None + convert_element_type_871 = torch.ops.prims.convert_element_type.default(view_1904, torch.float32); view_1904 = None + view_1907 = torch.ops.aten.view.default(convert_element_type_871, [2, 8192, 4, -1, 2]); convert_element_type_871 = None + view_as_complex_52 = torch.ops.aten.view_as_complex.default(view_1907); view_1907 = None + convert_element_type_872 = torch.ops.prims.convert_element_type.default(view_1905, torch.float32); view_1905 = None + view_1908 = torch.ops.aten.view.default(convert_element_type_872, [2, 8192, 1, -1, 2]); convert_element_type_872 = None + view_as_complex_53 = torch.ops.aten.view_as_complex.default(view_1908); view_1908 = None + mul_210 = torch.ops.aten.mul.Tensor(view_as_complex_52, view_37); view_as_complex_52 = None + view_as_real_52 = torch.ops.aten.view_as_real.default(mul_210); mul_210 = None + view_1910 = torch.ops.aten.view.default(view_as_real_52, [2, 8192, 4, 128]); view_as_real_52 = None + mul_211 = torch.ops.aten.mul.Tensor(view_as_complex_53, view_37); view_as_complex_53 = None + view_as_real_53 = torch.ops.aten.view_as_real.default(mul_211); mul_211 = None + view_1911 = torch.ops.aten.view.default(view_as_real_53, [2, 8192, 1, 128]); view_as_real_53 = None + convert_element_type_873 = torch.ops.prims.convert_element_type.default(view_1910, torch.bfloat16); view_1910 = None + convert_element_type_874 = torch.ops.prims.convert_element_type.default(view_1911, torch.bfloat16); view_1911 = None + unsqueeze_52 = torch.ops.aten.unsqueeze.default(convert_element_type_874, 3); convert_element_type_874 = None + expand_52 = torch.ops.aten.expand.default(unsqueeze_52, [2, 8192, 1, 4, 128]); unsqueeze_52 = None + view_1912 = torch.ops.aten.view.default(expand_52, [2, 8192, 4, 128]); expand_52 = None + unsqueeze_53 = torch.ops.aten.unsqueeze.default(view_1906, 3); view_1906 = None + expand_53 = torch.ops.aten.expand.default(unsqueeze_53, [2, 8192, 1, 4, 128]); unsqueeze_53 = None + view_1913 = 
torch.ops.aten.view.default(expand_53, [2, 8192, 4, 128]); expand_53 = None + permute_289 = torch.ops.aten.permute.default(convert_element_type_873, [0, 2, 1, 3]); convert_element_type_873 = None + permute_290 = torch.ops.aten.permute.default(view_1912, [0, 2, 1, 3]); view_1912 = None + permute_291 = torch.ops.aten.permute.default(view_1913, [0, 2, 1, 3]); view_1913 = None + _scaled_dot_product_cudnn_attention_26 = torch.ops.aten._scaled_dot_product_cudnn_attention.default(permute_289, permute_290, permute_291, None, True, 0.0, True); permute_289 = permute_290 = permute_291 = None + getitem_1146 = _scaled_dot_product_cudnn_attention_26[0] + getitem_1147 = _scaled_dot_product_cudnn_attention_26[1] + getitem_1152 = _scaled_dot_product_cudnn_attention_26[6] + getitem_1153 = _scaled_dot_product_cudnn_attention_26[7]; _scaled_dot_product_cudnn_attention_26 = None + permute_292 = torch.ops.aten.permute.default(getitem_1146, [0, 2, 1, 3]) + view_1914 = torch.ops.aten.view.default(permute_292, [2, 8192, -1]); permute_292 = None + convert_element_type_875 = torch.ops.prims.convert_element_type.default(primals_242, torch.bfloat16) + all_gather_into_tensor_292 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_875, 32, '0'); convert_element_type_875 = None + wait_tensor_345 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_292); all_gather_into_tensor_292 = None + permute_293 = torch.ops.aten.permute.default(wait_tensor_345, [1, 0]); wait_tensor_345 = None + view_1920 = torch.ops.aten.view.default(view_1914, [16384, 512]); view_1914 = None + mm_185 = torch.ops.aten.mm.default(view_1920, permute_293); view_1920 = permute_293 = None + view_1921 = torch.ops.aten.view.default(mm_185, [2, 8192, 4096]); mm_185 = None + split_114 = torch.ops.aten.split.Tensor(view_1921, 1024, 1); view_1921 = None + getitem_1155 = split_114[0] + getitem_1156 = split_114[1] + getitem_1157 = split_114[2] + getitem_1158 = split_114[3] + getitem_1159 = 
split_114[4] + getitem_1160 = split_114[5] + getitem_1161 = split_114[6] + getitem_1162 = split_114[7]; split_114 = None + cat_106 = torch.ops.aten.cat.default([getitem_1155, getitem_1156, getitem_1157, getitem_1158, getitem_1159, getitem_1160, getitem_1161, getitem_1162]); getitem_1155 = getitem_1156 = getitem_1157 = getitem_1158 = getitem_1159 = getitem_1160 = getitem_1161 = getitem_1162 = None + reduce_scatter_tensor_53 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_106, 'sum', 8, '1'); cat_106 = None + wait_tensor_346 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_53) + add_105 = torch.ops.aten.add.Tensor(add_103, wait_tensor_346); wait_tensor_346 = None + convert_element_type_878 = torch.ops.prims.convert_element_type.default(primals_243, torch.bfloat16) + all_gather_into_tensor_293 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_878, 32, '0'); convert_element_type_878 = None + wait_tensor_347 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_293); all_gather_into_tensor_293 = None + convert_element_type_879 = torch.ops.prims.convert_element_type.default(add_105, torch.float32) + pow_54 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_879, 2) + mean_53 = torch.ops.aten.mean.dim(pow_54, [2], True); pow_54 = None + add_106 = torch.ops.aten.add.Scalar(mean_53, 1e-05); mean_53 = None + rsqrt_53 = torch.ops.aten.rsqrt.default(add_106); add_106 = None + mul_212 = torch.ops.aten.mul.Tensor(convert_element_type_879, rsqrt_53); convert_element_type_879 = rsqrt_53 = None + mul_213 = torch.ops.aten.mul.Tensor(mul_212, wait_tensor_347); mul_212 = wait_tensor_347 = None + convert_element_type_880 = torch.ops.prims.convert_element_type.default(mul_213, torch.bfloat16); mul_213 = None + all_gather_into_tensor_294 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_880, 8, '1'); convert_element_type_880 = None + wait_tensor_348 = 
torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_294); all_gather_into_tensor_294 = None + split_115 = torch.ops.aten.split.Tensor(wait_tensor_348, 2); wait_tensor_348 = None + getitem_1163 = split_115[0] + getitem_1164 = split_115[1] + getitem_1165 = split_115[2] + getitem_1166 = split_115[3] + getitem_1167 = split_115[4] + getitem_1168 = split_115[5] + getitem_1169 = split_115[6] + getitem_1170 = split_115[7]; split_115 = None + cat_107 = torch.ops.aten.cat.default([getitem_1163, getitem_1164, getitem_1165, getitem_1166, getitem_1167, getitem_1168, getitem_1169, getitem_1170], 1); getitem_1163 = getitem_1164 = getitem_1165 = getitem_1166 = getitem_1167 = getitem_1168 = getitem_1169 = getitem_1170 = None + convert_element_type_881 = torch.ops.prims.convert_element_type.default(primals_244, torch.bfloat16) + all_gather_into_tensor_295 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_881, 32, '0'); convert_element_type_881 = None + wait_tensor_349 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_295); all_gather_into_tensor_295 = None + permute_294 = torch.ops.aten.permute.default(wait_tensor_349, [1, 0]); wait_tensor_349 = None + view_1932 = torch.ops.aten.view.default(cat_107, [16384, 4096]); cat_107 = None + mm_186 = torch.ops.aten.mm.default(view_1932, permute_294); permute_294 = None + view_1933 = torch.ops.aten.view.default(mm_186, [2, 8192, 1792]) + convert_element_type_884 = torch.ops.prims.convert_element_type.default(view_1933, torch.float32); view_1933 = None + sigmoid_26 = torch.ops.aten.sigmoid.default(convert_element_type_884) + mul_214 = torch.ops.aten.mul.Tensor(convert_element_type_884, sigmoid_26); convert_element_type_884 = sigmoid_26 = None + convert_element_type_885 = torch.ops.prims.convert_element_type.default(mul_214, torch.bfloat16); mul_214 = None + convert_element_type_886 = torch.ops.prims.convert_element_type.default(primals_245, torch.bfloat16) + 
all_gather_into_tensor_296 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_886, 32, '0'); convert_element_type_886 = None + wait_tensor_350 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_296); all_gather_into_tensor_296 = None + permute_295 = torch.ops.aten.permute.default(wait_tensor_350, [1, 0]); wait_tensor_350 = None + mm_187 = torch.ops.aten.mm.default(view_1932, permute_295); view_1932 = permute_295 = None + view_1940 = torch.ops.aten.view.default(mm_187, [2, 8192, 1792]); mm_187 = None + mul_215 = torch.ops.aten.mul.Tensor(convert_element_type_885, view_1940); convert_element_type_885 = view_1940 = None + convert_element_type_889 = torch.ops.prims.convert_element_type.default(primals_246, torch.bfloat16) + all_gather_into_tensor_297 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_889, 32, '0'); convert_element_type_889 = None + wait_tensor_351 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_297); all_gather_into_tensor_297 = None + permute_296 = torch.ops.aten.permute.default(wait_tensor_351, [1, 0]); wait_tensor_351 = None + view_1947 = torch.ops.aten.view.default(mul_215, [16384, 1792]); mul_215 = None + mm_188 = torch.ops.aten.mm.default(view_1947, permute_296); view_1947 = permute_296 = None + view_1948 = torch.ops.aten.view.default(mm_188, [2, 8192, 4096]); mm_188 = None + split_116 = torch.ops.aten.split.Tensor(view_1948, 1024, 1); view_1948 = None + getitem_1171 = split_116[0] + getitem_1172 = split_116[1] + getitem_1173 = split_116[2] + getitem_1174 = split_116[3] + getitem_1175 = split_116[4] + getitem_1176 = split_116[5] + getitem_1177 = split_116[6] + getitem_1178 = split_116[7]; split_116 = None + cat_108 = torch.ops.aten.cat.default([getitem_1171, getitem_1172, getitem_1173, getitem_1174, getitem_1175, getitem_1176, getitem_1177, getitem_1178]); getitem_1171 = getitem_1172 = getitem_1173 = getitem_1174 = getitem_1175 = getitem_1176 
= getitem_1177 = getitem_1178 = None + reduce_scatter_tensor_54 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_108, 'sum', 8, '1'); cat_108 = None + wait_tensor_352 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_54); reduce_scatter_tensor_54 = None + add_107 = torch.ops.aten.add.Tensor(add_105, wait_tensor_352); add_105 = wait_tensor_352 = None + convert_element_type_892 = torch.ops.prims.convert_element_type.default(primals_247, torch.bfloat16) + all_gather_into_tensor_298 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_892, 32, '0'); convert_element_type_892 = None + wait_tensor_353 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_298); all_gather_into_tensor_298 = None + convert_element_type_893 = torch.ops.prims.convert_element_type.default(add_107, torch.float32) + pow_55 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_893, 2) + mean_54 = torch.ops.aten.mean.dim(pow_55, [2], True); pow_55 = None + add_108 = torch.ops.aten.add.Scalar(mean_54, 1e-05); mean_54 = None + rsqrt_54 = torch.ops.aten.rsqrt.default(add_108); add_108 = None + mul_216 = torch.ops.aten.mul.Tensor(convert_element_type_893, rsqrt_54); convert_element_type_893 = rsqrt_54 = None + mul_217 = torch.ops.aten.mul.Tensor(mul_216, wait_tensor_353); mul_216 = wait_tensor_353 = None + convert_element_type_894 = torch.ops.prims.convert_element_type.default(mul_217, torch.bfloat16); mul_217 = None + all_gather_into_tensor_299 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_894, 8, '1'); convert_element_type_894 = None + wait_tensor_354 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_299); all_gather_into_tensor_299 = None + split_117 = torch.ops.aten.split.Tensor(wait_tensor_354, 2); wait_tensor_354 = None + getitem_1179 = split_117[0] + getitem_1180 = split_117[1] + getitem_1181 = split_117[2] + getitem_1182 = split_117[3] + getitem_1183 = 
split_117[4] + getitem_1184 = split_117[5] + getitem_1185 = split_117[6] + getitem_1186 = split_117[7]; split_117 = None + cat_109 = torch.ops.aten.cat.default([getitem_1179, getitem_1180, getitem_1181, getitem_1182, getitem_1183, getitem_1184, getitem_1185, getitem_1186], 1); getitem_1179 = getitem_1180 = getitem_1181 = getitem_1182 = getitem_1183 = getitem_1184 = getitem_1185 = getitem_1186 = None + convert_element_type_895 = torch.ops.prims.convert_element_type.default(primals_248, torch.bfloat16) + all_gather_into_tensor_300 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_895, 32, '0'); convert_element_type_895 = None + wait_tensor_355 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_300); all_gather_into_tensor_300 = None + permute_297 = torch.ops.aten.permute.default(wait_tensor_355, [1, 0]); wait_tensor_355 = None + view_1959 = torch.ops.aten.view.default(cat_109, [16384, 4096]); cat_109 = None + mm_189 = torch.ops.aten.mm.default(view_1959, permute_297); permute_297 = None + view_1960 = torch.ops.aten.view.default(mm_189, [2, 8192, 512]) + convert_element_type_898 = torch.ops.prims.convert_element_type.default(primals_249, torch.bfloat16) + all_gather_into_tensor_301 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_898, 32, '0'); convert_element_type_898 = None + wait_tensor_356 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_301); all_gather_into_tensor_301 = None + permute_298 = torch.ops.aten.permute.default(wait_tensor_356, [1, 0]); wait_tensor_356 = None + mm_190 = torch.ops.aten.mm.default(view_1959, permute_298); permute_298 = None + view_1967 = torch.ops.aten.view.default(mm_190, [2, 8192, 128]); mm_190 = None + convert_element_type_901 = torch.ops.prims.convert_element_type.default(primals_250, torch.bfloat16) + all_gather_into_tensor_302 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_901, 32, '0'); 
convert_element_type_901 = None + wait_tensor_357 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_302); all_gather_into_tensor_302 = None + permute_299 = torch.ops.aten.permute.default(wait_tensor_357, [1, 0]); wait_tensor_357 = None + mm_191 = torch.ops.aten.mm.default(view_1959, permute_299); view_1959 = permute_299 = None + view_1974 = torch.ops.aten.view.default(mm_191, [2, 8192, 128]) + view_1976 = torch.ops.aten.view.default(view_1960, [2, 8192, -1, 128]); view_1960 = None + view_1977 = torch.ops.aten.view.default(view_1967, [2, 8192, -1, 128]); view_1967 = None + view_1978 = torch.ops.aten.view.default(view_1974, [2, 8192, -1, 128]); view_1974 = None + convert_element_type_904 = torch.ops.prims.convert_element_type.default(view_1976, torch.float32); view_1976 = None + view_1979 = torch.ops.aten.view.default(convert_element_type_904, [2, 8192, 4, -1, 2]); convert_element_type_904 = None + view_as_complex_54 = torch.ops.aten.view_as_complex.default(view_1979); view_1979 = None + convert_element_type_905 = torch.ops.prims.convert_element_type.default(view_1977, torch.float32); view_1977 = None + view_1980 = torch.ops.aten.view.default(convert_element_type_905, [2, 8192, 1, -1, 2]); convert_element_type_905 = None + view_as_complex_55 = torch.ops.aten.view_as_complex.default(view_1980); view_1980 = None + mul_218 = torch.ops.aten.mul.Tensor(view_as_complex_54, view_37); view_as_complex_54 = None + view_as_real_54 = torch.ops.aten.view_as_real.default(mul_218); mul_218 = None + view_1982 = torch.ops.aten.view.default(view_as_real_54, [2, 8192, 4, 128]); view_as_real_54 = None + mul_219 = torch.ops.aten.mul.Tensor(view_as_complex_55, view_37); view_as_complex_55 = None + view_as_real_55 = torch.ops.aten.view_as_real.default(mul_219); mul_219 = None + view_1983 = torch.ops.aten.view.default(view_as_real_55, [2, 8192, 1, 128]); view_as_real_55 = None + convert_element_type_906 = torch.ops.prims.convert_element_type.default(view_1982, 
torch.bfloat16); view_1982 = None + convert_element_type_907 = torch.ops.prims.convert_element_type.default(view_1983, torch.bfloat16); view_1983 = None + unsqueeze_54 = torch.ops.aten.unsqueeze.default(convert_element_type_907, 3); convert_element_type_907 = None + expand_54 = torch.ops.aten.expand.default(unsqueeze_54, [2, 8192, 1, 4, 128]); unsqueeze_54 = None + view_1984 = torch.ops.aten.view.default(expand_54, [2, 8192, 4, 128]); expand_54 = None + unsqueeze_55 = torch.ops.aten.unsqueeze.default(view_1978, 3); view_1978 = None + expand_55 = torch.ops.aten.expand.default(unsqueeze_55, [2, 8192, 1, 4, 128]); unsqueeze_55 = None + view_1985 = torch.ops.aten.view.default(expand_55, [2, 8192, 4, 128]); expand_55 = None + permute_300 = torch.ops.aten.permute.default(convert_element_type_906, [0, 2, 1, 3]); convert_element_type_906 = None + permute_301 = torch.ops.aten.permute.default(view_1984, [0, 2, 1, 3]); view_1984 = None + permute_302 = torch.ops.aten.permute.default(view_1985, [0, 2, 1, 3]); view_1985 = None + _scaled_dot_product_cudnn_attention_27 = torch.ops.aten._scaled_dot_product_cudnn_attention.default(permute_300, permute_301, permute_302, None, True, 0.0, True); permute_300 = permute_301 = permute_302 = None + getitem_1187 = _scaled_dot_product_cudnn_attention_27[0] + getitem_1188 = _scaled_dot_product_cudnn_attention_27[1] + getitem_1193 = _scaled_dot_product_cudnn_attention_27[6] + getitem_1194 = _scaled_dot_product_cudnn_attention_27[7]; _scaled_dot_product_cudnn_attention_27 = None + permute_303 = torch.ops.aten.permute.default(getitem_1187, [0, 2, 1, 3]) + view_1986 = torch.ops.aten.view.default(permute_303, [2, 8192, -1]); permute_303 = None + convert_element_type_908 = torch.ops.prims.convert_element_type.default(primals_251, torch.bfloat16) + all_gather_into_tensor_303 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_908, 32, '0'); convert_element_type_908 = None + wait_tensor_358 = 
torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_303); all_gather_into_tensor_303 = None + permute_304 = torch.ops.aten.permute.default(wait_tensor_358, [1, 0]); wait_tensor_358 = None + view_1992 = torch.ops.aten.view.default(view_1986, [16384, 512]); view_1986 = None + mm_192 = torch.ops.aten.mm.default(view_1992, permute_304); view_1992 = permute_304 = None + view_1993 = torch.ops.aten.view.default(mm_192, [2, 8192, 4096]); mm_192 = None + split_118 = torch.ops.aten.split.Tensor(view_1993, 1024, 1); view_1993 = None + getitem_1196 = split_118[0] + getitem_1197 = split_118[1] + getitem_1198 = split_118[2] + getitem_1199 = split_118[3] + getitem_1200 = split_118[4] + getitem_1201 = split_118[5] + getitem_1202 = split_118[6] + getitem_1203 = split_118[7]; split_118 = None + cat_110 = torch.ops.aten.cat.default([getitem_1196, getitem_1197, getitem_1198, getitem_1199, getitem_1200, getitem_1201, getitem_1202, getitem_1203]); getitem_1196 = getitem_1197 = getitem_1198 = getitem_1199 = getitem_1200 = getitem_1201 = getitem_1202 = getitem_1203 = None + reduce_scatter_tensor_55 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_110, 'sum', 8, '1'); cat_110 = None + wait_tensor_359 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_55) + add_109 = torch.ops.aten.add.Tensor(add_107, wait_tensor_359); wait_tensor_359 = None + convert_element_type_911 = torch.ops.prims.convert_element_type.default(primals_252, torch.bfloat16) + all_gather_into_tensor_304 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_911, 32, '0'); convert_element_type_911 = None + wait_tensor_360 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_304); all_gather_into_tensor_304 = None + convert_element_type_912 = torch.ops.prims.convert_element_type.default(add_109, torch.float32) + pow_56 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_912, 2) + mean_55 = torch.ops.aten.mean.dim(pow_56, [2], 
True); pow_56 = None + add_110 = torch.ops.aten.add.Scalar(mean_55, 1e-05); mean_55 = None + rsqrt_55 = torch.ops.aten.rsqrt.default(add_110); add_110 = None + mul_220 = torch.ops.aten.mul.Tensor(convert_element_type_912, rsqrt_55); convert_element_type_912 = rsqrt_55 = None + mul_221 = torch.ops.aten.mul.Tensor(mul_220, wait_tensor_360); mul_220 = wait_tensor_360 = None + convert_element_type_913 = torch.ops.prims.convert_element_type.default(mul_221, torch.bfloat16); mul_221 = None + all_gather_into_tensor_305 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_913, 8, '1'); convert_element_type_913 = None + wait_tensor_361 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_305); all_gather_into_tensor_305 = None + split_119 = torch.ops.aten.split.Tensor(wait_tensor_361, 2); wait_tensor_361 = None + getitem_1204 = split_119[0] + getitem_1205 = split_119[1] + getitem_1206 = split_119[2] + getitem_1207 = split_119[3] + getitem_1208 = split_119[4] + getitem_1209 = split_119[5] + getitem_1210 = split_119[6] + getitem_1211 = split_119[7]; split_119 = None + cat_111 = torch.ops.aten.cat.default([getitem_1204, getitem_1205, getitem_1206, getitem_1207, getitem_1208, getitem_1209, getitem_1210, getitem_1211], 1); getitem_1204 = getitem_1205 = getitem_1206 = getitem_1207 = getitem_1208 = getitem_1209 = getitem_1210 = getitem_1211 = None + convert_element_type_914 = torch.ops.prims.convert_element_type.default(primals_253, torch.bfloat16) + all_gather_into_tensor_306 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_914, 32, '0'); convert_element_type_914 = None + wait_tensor_362 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_306); all_gather_into_tensor_306 = None + permute_305 = torch.ops.aten.permute.default(wait_tensor_362, [1, 0]); wait_tensor_362 = None + view_2004 = torch.ops.aten.view.default(cat_111, [16384, 4096]); cat_111 = None + mm_193 = 
torch.ops.aten.mm.default(view_2004, permute_305); permute_305 = None + view_2005 = torch.ops.aten.view.default(mm_193, [2, 8192, 1792]) + convert_element_type_917 = torch.ops.prims.convert_element_type.default(view_2005, torch.float32); view_2005 = None + sigmoid_27 = torch.ops.aten.sigmoid.default(convert_element_type_917) + mul_222 = torch.ops.aten.mul.Tensor(convert_element_type_917, sigmoid_27); convert_element_type_917 = sigmoid_27 = None + convert_element_type_918 = torch.ops.prims.convert_element_type.default(mul_222, torch.bfloat16); mul_222 = None + convert_element_type_919 = torch.ops.prims.convert_element_type.default(primals_254, torch.bfloat16) + all_gather_into_tensor_307 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_919, 32, '0'); convert_element_type_919 = None + wait_tensor_363 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_307); all_gather_into_tensor_307 = None + permute_306 = torch.ops.aten.permute.default(wait_tensor_363, [1, 0]); wait_tensor_363 = None + mm_194 = torch.ops.aten.mm.default(view_2004, permute_306); view_2004 = permute_306 = None + view_2012 = torch.ops.aten.view.default(mm_194, [2, 8192, 1792]); mm_194 = None + mul_223 = torch.ops.aten.mul.Tensor(convert_element_type_918, view_2012); convert_element_type_918 = view_2012 = None + convert_element_type_922 = torch.ops.prims.convert_element_type.default(primals_255, torch.bfloat16) + all_gather_into_tensor_308 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_922, 32, '0'); convert_element_type_922 = None + wait_tensor_364 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_308); all_gather_into_tensor_308 = None + permute_307 = torch.ops.aten.permute.default(wait_tensor_364, [1, 0]); wait_tensor_364 = None + view_2019 = torch.ops.aten.view.default(mul_223, [16384, 1792]); mul_223 = None + mm_195 = torch.ops.aten.mm.default(view_2019, permute_307); view_2019 = permute_307 
= None + view_2020 = torch.ops.aten.view.default(mm_195, [2, 8192, 4096]); mm_195 = None + split_120 = torch.ops.aten.split.Tensor(view_2020, 1024, 1); view_2020 = None + getitem_1212 = split_120[0] + getitem_1213 = split_120[1] + getitem_1214 = split_120[2] + getitem_1215 = split_120[3] + getitem_1216 = split_120[4] + getitem_1217 = split_120[5] + getitem_1218 = split_120[6] + getitem_1219 = split_120[7]; split_120 = None + cat_112 = torch.ops.aten.cat.default([getitem_1212, getitem_1213, getitem_1214, getitem_1215, getitem_1216, getitem_1217, getitem_1218, getitem_1219]); getitem_1212 = getitem_1213 = getitem_1214 = getitem_1215 = getitem_1216 = getitem_1217 = getitem_1218 = getitem_1219 = None + reduce_scatter_tensor_56 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_112, 'sum', 8, '1'); cat_112 = None + wait_tensor_365 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_56); reduce_scatter_tensor_56 = None + add_111 = torch.ops.aten.add.Tensor(add_109, wait_tensor_365); add_109 = wait_tensor_365 = None + convert_element_type_925 = torch.ops.prims.convert_element_type.default(primals_256, torch.bfloat16) + all_gather_into_tensor_309 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_925, 32, '0'); convert_element_type_925 = None + wait_tensor_366 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_309); all_gather_into_tensor_309 = None + convert_element_type_926 = torch.ops.prims.convert_element_type.default(add_111, torch.float32) + pow_57 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_926, 2) + mean_56 = torch.ops.aten.mean.dim(pow_57, [2], True); pow_57 = None + add_112 = torch.ops.aten.add.Scalar(mean_56, 1e-05); mean_56 = None + rsqrt_56 = torch.ops.aten.rsqrt.default(add_112); add_112 = None + mul_224 = torch.ops.aten.mul.Tensor(convert_element_type_926, rsqrt_56); convert_element_type_926 = rsqrt_56 = None + mul_225 = torch.ops.aten.mul.Tensor(mul_224, 
wait_tensor_366); mul_224 = wait_tensor_366 = None + convert_element_type_927 = torch.ops.prims.convert_element_type.default(mul_225, torch.bfloat16); mul_225 = None + all_gather_into_tensor_310 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_927, 8, '1'); convert_element_type_927 = None + wait_tensor_367 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_310); all_gather_into_tensor_310 = None + split_121 = torch.ops.aten.split.Tensor(wait_tensor_367, 2); wait_tensor_367 = None + getitem_1220 = split_121[0] + getitem_1221 = split_121[1] + getitem_1222 = split_121[2] + getitem_1223 = split_121[3] + getitem_1224 = split_121[4] + getitem_1225 = split_121[5] + getitem_1226 = split_121[6] + getitem_1227 = split_121[7]; split_121 = None + cat_113 = torch.ops.aten.cat.default([getitem_1220, getitem_1221, getitem_1222, getitem_1223, getitem_1224, getitem_1225, getitem_1226, getitem_1227], 1); getitem_1220 = getitem_1221 = getitem_1222 = getitem_1223 = getitem_1224 = getitem_1225 = getitem_1226 = getitem_1227 = None + convert_element_type_928 = torch.ops.prims.convert_element_type.default(primals_257, torch.bfloat16) + all_gather_into_tensor_311 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_928, 32, '0'); convert_element_type_928 = None + wait_tensor_368 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_311); all_gather_into_tensor_311 = None + permute_308 = torch.ops.aten.permute.default(wait_tensor_368, [1, 0]); wait_tensor_368 = None + view_2031 = torch.ops.aten.view.default(cat_113, [16384, 4096]); cat_113 = None + mm_196 = torch.ops.aten.mm.default(view_2031, permute_308); permute_308 = None + view_2032 = torch.ops.aten.view.default(mm_196, [2, 8192, 512]) + convert_element_type_931 = torch.ops.prims.convert_element_type.default(primals_258, torch.bfloat16) + all_gather_into_tensor_312 = 
torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_931, 32, '0'); convert_element_type_931 = None + wait_tensor_369 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_312); all_gather_into_tensor_312 = None + permute_309 = torch.ops.aten.permute.default(wait_tensor_369, [1, 0]); wait_tensor_369 = None + mm_197 = torch.ops.aten.mm.default(view_2031, permute_309); permute_309 = None + view_2039 = torch.ops.aten.view.default(mm_197, [2, 8192, 128]); mm_197 = None + convert_element_type_934 = torch.ops.prims.convert_element_type.default(primals_259, torch.bfloat16) + all_gather_into_tensor_313 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_934, 32, '0'); convert_element_type_934 = None + wait_tensor_370 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_313); all_gather_into_tensor_313 = None + permute_310 = torch.ops.aten.permute.default(wait_tensor_370, [1, 0]); wait_tensor_370 = None + mm_198 = torch.ops.aten.mm.default(view_2031, permute_310); view_2031 = permute_310 = None + view_2046 = torch.ops.aten.view.default(mm_198, [2, 8192, 128]) + view_2048 = torch.ops.aten.view.default(view_2032, [2, 8192, -1, 128]); view_2032 = None + view_2049 = torch.ops.aten.view.default(view_2039, [2, 8192, -1, 128]); view_2039 = None + view_2050 = torch.ops.aten.view.default(view_2046, [2, 8192, -1, 128]); view_2046 = None + convert_element_type_937 = torch.ops.prims.convert_element_type.default(view_2048, torch.float32); view_2048 = None + view_2051 = torch.ops.aten.view.default(convert_element_type_937, [2, 8192, 4, -1, 2]); convert_element_type_937 = None + view_as_complex_56 = torch.ops.aten.view_as_complex.default(view_2051); view_2051 = None + convert_element_type_938 = torch.ops.prims.convert_element_type.default(view_2049, torch.float32); view_2049 = None + view_2052 = torch.ops.aten.view.default(convert_element_type_938, [2, 8192, 1, -1, 2]); convert_element_type_938 = 
None + view_as_complex_57 = torch.ops.aten.view_as_complex.default(view_2052); view_2052 = None + mul_226 = torch.ops.aten.mul.Tensor(view_as_complex_56, view_37); view_as_complex_56 = None + view_as_real_56 = torch.ops.aten.view_as_real.default(mul_226); mul_226 = None + view_2054 = torch.ops.aten.view.default(view_as_real_56, [2, 8192, 4, 128]); view_as_real_56 = None + mul_227 = torch.ops.aten.mul.Tensor(view_as_complex_57, view_37); view_as_complex_57 = None + view_as_real_57 = torch.ops.aten.view_as_real.default(mul_227); mul_227 = None + view_2055 = torch.ops.aten.view.default(view_as_real_57, [2, 8192, 1, 128]); view_as_real_57 = None + convert_element_type_939 = torch.ops.prims.convert_element_type.default(view_2054, torch.bfloat16); view_2054 = None + convert_element_type_940 = torch.ops.prims.convert_element_type.default(view_2055, torch.bfloat16); view_2055 = None + unsqueeze_56 = torch.ops.aten.unsqueeze.default(convert_element_type_940, 3); convert_element_type_940 = None + expand_56 = torch.ops.aten.expand.default(unsqueeze_56, [2, 8192, 1, 4, 128]); unsqueeze_56 = None + view_2056 = torch.ops.aten.view.default(expand_56, [2, 8192, 4, 128]); expand_56 = None + unsqueeze_57 = torch.ops.aten.unsqueeze.default(view_2050, 3); view_2050 = None + expand_57 = torch.ops.aten.expand.default(unsqueeze_57, [2, 8192, 1, 4, 128]); unsqueeze_57 = None + view_2057 = torch.ops.aten.view.default(expand_57, [2, 8192, 4, 128]); expand_57 = None + permute_311 = torch.ops.aten.permute.default(convert_element_type_939, [0, 2, 1, 3]); convert_element_type_939 = None + permute_312 = torch.ops.aten.permute.default(view_2056, [0, 2, 1, 3]); view_2056 = None + permute_313 = torch.ops.aten.permute.default(view_2057, [0, 2, 1, 3]); view_2057 = None + _scaled_dot_product_cudnn_attention_28 = torch.ops.aten._scaled_dot_product_cudnn_attention.default(permute_311, permute_312, permute_313, None, True, 0.0, True); permute_311 = permute_312 = permute_313 = None + getitem_1228 = 
_scaled_dot_product_cudnn_attention_28[0] + getitem_1229 = _scaled_dot_product_cudnn_attention_28[1] + getitem_1234 = _scaled_dot_product_cudnn_attention_28[6] + getitem_1235 = _scaled_dot_product_cudnn_attention_28[7]; _scaled_dot_product_cudnn_attention_28 = None + permute_314 = torch.ops.aten.permute.default(getitem_1228, [0, 2, 1, 3]) + view_2058 = torch.ops.aten.view.default(permute_314, [2, 8192, -1]); permute_314 = None + convert_element_type_941 = torch.ops.prims.convert_element_type.default(primals_260, torch.bfloat16) + all_gather_into_tensor_314 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_941, 32, '0'); convert_element_type_941 = None + wait_tensor_371 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_314); all_gather_into_tensor_314 = None + permute_315 = torch.ops.aten.permute.default(wait_tensor_371, [1, 0]); wait_tensor_371 = None + view_2064 = torch.ops.aten.view.default(view_2058, [16384, 512]); view_2058 = None + mm_199 = torch.ops.aten.mm.default(view_2064, permute_315); view_2064 = permute_315 = None + view_2065 = torch.ops.aten.view.default(mm_199, [2, 8192, 4096]); mm_199 = None + split_122 = torch.ops.aten.split.Tensor(view_2065, 1024, 1); view_2065 = None + getitem_1237 = split_122[0] + getitem_1238 = split_122[1] + getitem_1239 = split_122[2] + getitem_1240 = split_122[3] + getitem_1241 = split_122[4] + getitem_1242 = split_122[5] + getitem_1243 = split_122[6] + getitem_1244 = split_122[7]; split_122 = None + cat_114 = torch.ops.aten.cat.default([getitem_1237, getitem_1238, getitem_1239, getitem_1240, getitem_1241, getitem_1242, getitem_1243, getitem_1244]); getitem_1237 = getitem_1238 = getitem_1239 = getitem_1240 = getitem_1241 = getitem_1242 = getitem_1243 = getitem_1244 = None + reduce_scatter_tensor_57 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_114, 'sum', 8, '1'); cat_114 = None + wait_tensor_372 = 
torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_57) + add_113 = torch.ops.aten.add.Tensor(add_111, wait_tensor_372); wait_tensor_372 = None + convert_element_type_944 = torch.ops.prims.convert_element_type.default(primals_261, torch.bfloat16) + all_gather_into_tensor_315 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_944, 32, '0'); convert_element_type_944 = None + wait_tensor_373 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_315); all_gather_into_tensor_315 = None + convert_element_type_945 = torch.ops.prims.convert_element_type.default(add_113, torch.float32) + pow_58 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_945, 2) + mean_57 = torch.ops.aten.mean.dim(pow_58, [2], True); pow_58 = None + add_114 = torch.ops.aten.add.Scalar(mean_57, 1e-05); mean_57 = None + rsqrt_57 = torch.ops.aten.rsqrt.default(add_114); add_114 = None + mul_228 = torch.ops.aten.mul.Tensor(convert_element_type_945, rsqrt_57); convert_element_type_945 = rsqrt_57 = None + mul_229 = torch.ops.aten.mul.Tensor(mul_228, wait_tensor_373); mul_228 = wait_tensor_373 = None + convert_element_type_946 = torch.ops.prims.convert_element_type.default(mul_229, torch.bfloat16); mul_229 = None + all_gather_into_tensor_316 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_946, 8, '1'); convert_element_type_946 = None + wait_tensor_374 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_316); all_gather_into_tensor_316 = None + split_123 = torch.ops.aten.split.Tensor(wait_tensor_374, 2); wait_tensor_374 = None + getitem_1245 = split_123[0] + getitem_1246 = split_123[1] + getitem_1247 = split_123[2] + getitem_1248 = split_123[3] + getitem_1249 = split_123[4] + getitem_1250 = split_123[5] + getitem_1251 = split_123[6] + getitem_1252 = split_123[7]; split_123 = None + cat_115 = torch.ops.aten.cat.default([getitem_1245, getitem_1246, getitem_1247, getitem_1248, getitem_1249, 
getitem_1250, getitem_1251, getitem_1252], 1); getitem_1245 = getitem_1246 = getitem_1247 = getitem_1248 = getitem_1249 = getitem_1250 = getitem_1251 = getitem_1252 = None + convert_element_type_947 = torch.ops.prims.convert_element_type.default(primals_262, torch.bfloat16) + all_gather_into_tensor_317 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_947, 32, '0'); convert_element_type_947 = None + wait_tensor_375 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_317); all_gather_into_tensor_317 = None + permute_316 = torch.ops.aten.permute.default(wait_tensor_375, [1, 0]); wait_tensor_375 = None + view_2076 = torch.ops.aten.view.default(cat_115, [16384, 4096]); cat_115 = None + mm_200 = torch.ops.aten.mm.default(view_2076, permute_316); permute_316 = None + view_2077 = torch.ops.aten.view.default(mm_200, [2, 8192, 1792]) + convert_element_type_950 = torch.ops.prims.convert_element_type.default(view_2077, torch.float32); view_2077 = None + sigmoid_28 = torch.ops.aten.sigmoid.default(convert_element_type_950) + mul_230 = torch.ops.aten.mul.Tensor(convert_element_type_950, sigmoid_28); convert_element_type_950 = sigmoid_28 = None + convert_element_type_951 = torch.ops.prims.convert_element_type.default(mul_230, torch.bfloat16); mul_230 = None + convert_element_type_952 = torch.ops.prims.convert_element_type.default(primals_263, torch.bfloat16) + all_gather_into_tensor_318 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_952, 32, '0'); convert_element_type_952 = None + wait_tensor_376 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_318); all_gather_into_tensor_318 = None + permute_317 = torch.ops.aten.permute.default(wait_tensor_376, [1, 0]); wait_tensor_376 = None + mm_201 = torch.ops.aten.mm.default(view_2076, permute_317); view_2076 = permute_317 = None + view_2084 = torch.ops.aten.view.default(mm_201, [2, 8192, 1792]); mm_201 = None + mul_231 = 
torch.ops.aten.mul.Tensor(convert_element_type_951, view_2084); convert_element_type_951 = view_2084 = None + convert_element_type_955 = torch.ops.prims.convert_element_type.default(primals_264, torch.bfloat16) + all_gather_into_tensor_319 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_955, 32, '0'); convert_element_type_955 = None + wait_tensor_377 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_319); all_gather_into_tensor_319 = None + permute_318 = torch.ops.aten.permute.default(wait_tensor_377, [1, 0]); wait_tensor_377 = None + view_2091 = torch.ops.aten.view.default(mul_231, [16384, 1792]); mul_231 = None + mm_202 = torch.ops.aten.mm.default(view_2091, permute_318); view_2091 = permute_318 = None + view_2092 = torch.ops.aten.view.default(mm_202, [2, 8192, 4096]); mm_202 = None + split_124 = torch.ops.aten.split.Tensor(view_2092, 1024, 1); view_2092 = None + getitem_1253 = split_124[0] + getitem_1254 = split_124[1] + getitem_1255 = split_124[2] + getitem_1256 = split_124[3] + getitem_1257 = split_124[4] + getitem_1258 = split_124[5] + getitem_1259 = split_124[6] + getitem_1260 = split_124[7]; split_124 = None + cat_116 = torch.ops.aten.cat.default([getitem_1253, getitem_1254, getitem_1255, getitem_1256, getitem_1257, getitem_1258, getitem_1259, getitem_1260]); getitem_1253 = getitem_1254 = getitem_1255 = getitem_1256 = getitem_1257 = getitem_1258 = getitem_1259 = getitem_1260 = None + reduce_scatter_tensor_58 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_116, 'sum', 8, '1'); cat_116 = None + wait_tensor_378 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_58); reduce_scatter_tensor_58 = None + add_115 = torch.ops.aten.add.Tensor(add_113, wait_tensor_378); add_113 = wait_tensor_378 = None + convert_element_type_958 = torch.ops.prims.convert_element_type.default(primals_265, torch.bfloat16) + all_gather_into_tensor_320 = 
torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_958, 32, '0'); convert_element_type_958 = None + wait_tensor_379 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_320); all_gather_into_tensor_320 = None + convert_element_type_959 = torch.ops.prims.convert_element_type.default(add_115, torch.float32) + pow_59 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_959, 2) + mean_58 = torch.ops.aten.mean.dim(pow_59, [2], True); pow_59 = None + add_116 = torch.ops.aten.add.Scalar(mean_58, 1e-05); mean_58 = None + rsqrt_58 = torch.ops.aten.rsqrt.default(add_116); add_116 = None + mul_232 = torch.ops.aten.mul.Tensor(convert_element_type_959, rsqrt_58); convert_element_type_959 = rsqrt_58 = None + mul_233 = torch.ops.aten.mul.Tensor(mul_232, wait_tensor_379); mul_232 = wait_tensor_379 = None + convert_element_type_960 = torch.ops.prims.convert_element_type.default(mul_233, torch.bfloat16); mul_233 = None + all_gather_into_tensor_321 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_960, 8, '1'); convert_element_type_960 = None + wait_tensor_380 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_321); all_gather_into_tensor_321 = None + split_125 = torch.ops.aten.split.Tensor(wait_tensor_380, 2); wait_tensor_380 = None + getitem_1261 = split_125[0] + getitem_1262 = split_125[1] + getitem_1263 = split_125[2] + getitem_1264 = split_125[3] + getitem_1265 = split_125[4] + getitem_1266 = split_125[5] + getitem_1267 = split_125[6] + getitem_1268 = split_125[7]; split_125 = None + cat_117 = torch.ops.aten.cat.default([getitem_1261, getitem_1262, getitem_1263, getitem_1264, getitem_1265, getitem_1266, getitem_1267, getitem_1268], 1); getitem_1261 = getitem_1262 = getitem_1263 = getitem_1264 = getitem_1265 = getitem_1266 = getitem_1267 = getitem_1268 = None + convert_element_type_961 = torch.ops.prims.convert_element_type.default(primals_266, torch.bfloat16) + 
all_gather_into_tensor_322 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_961, 32, '0'); convert_element_type_961 = None + wait_tensor_381 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_322); all_gather_into_tensor_322 = None + permute_319 = torch.ops.aten.permute.default(wait_tensor_381, [1, 0]); wait_tensor_381 = None + view_2103 = torch.ops.aten.view.default(cat_117, [16384, 4096]); cat_117 = None + mm_203 = torch.ops.aten.mm.default(view_2103, permute_319); permute_319 = None + view_2104 = torch.ops.aten.view.default(mm_203, [2, 8192, 512]) + convert_element_type_964 = torch.ops.prims.convert_element_type.default(primals_267, torch.bfloat16) + all_gather_into_tensor_323 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_964, 32, '0'); convert_element_type_964 = None + wait_tensor_382 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_323); all_gather_into_tensor_323 = None + permute_320 = torch.ops.aten.permute.default(wait_tensor_382, [1, 0]); wait_tensor_382 = None + mm_204 = torch.ops.aten.mm.default(view_2103, permute_320); permute_320 = None + view_2111 = torch.ops.aten.view.default(mm_204, [2, 8192, 128]); mm_204 = None + convert_element_type_967 = torch.ops.prims.convert_element_type.default(primals_268, torch.bfloat16) + all_gather_into_tensor_324 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_967, 32, '0'); convert_element_type_967 = None + wait_tensor_383 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_324); all_gather_into_tensor_324 = None + permute_321 = torch.ops.aten.permute.default(wait_tensor_383, [1, 0]); wait_tensor_383 = None + mm_205 = torch.ops.aten.mm.default(view_2103, permute_321); view_2103 = permute_321 = None + view_2118 = torch.ops.aten.view.default(mm_205, [2, 8192, 128]) + view_2120 = torch.ops.aten.view.default(view_2104, [2, 8192, -1, 128]); view_2104 = None + 
view_2121 = torch.ops.aten.view.default(view_2111, [2, 8192, -1, 128]); view_2111 = None + view_2122 = torch.ops.aten.view.default(view_2118, [2, 8192, -1, 128]); view_2118 = None + convert_element_type_970 = torch.ops.prims.convert_element_type.default(view_2120, torch.float32); view_2120 = None + view_2123 = torch.ops.aten.view.default(convert_element_type_970, [2, 8192, 4, -1, 2]); convert_element_type_970 = None + view_as_complex_58 = torch.ops.aten.view_as_complex.default(view_2123); view_2123 = None + convert_element_type_971 = torch.ops.prims.convert_element_type.default(view_2121, torch.float32); view_2121 = None + view_2124 = torch.ops.aten.view.default(convert_element_type_971, [2, 8192, 1, -1, 2]); convert_element_type_971 = None + view_as_complex_59 = torch.ops.aten.view_as_complex.default(view_2124); view_2124 = None + mul_234 = torch.ops.aten.mul.Tensor(view_as_complex_58, view_37); view_as_complex_58 = None + view_as_real_58 = torch.ops.aten.view_as_real.default(mul_234); mul_234 = None + view_2126 = torch.ops.aten.view.default(view_as_real_58, [2, 8192, 4, 128]); view_as_real_58 = None + mul_235 = torch.ops.aten.mul.Tensor(view_as_complex_59, view_37); view_as_complex_59 = None + view_as_real_59 = torch.ops.aten.view_as_real.default(mul_235); mul_235 = None + view_2127 = torch.ops.aten.view.default(view_as_real_59, [2, 8192, 1, 128]); view_as_real_59 = None + convert_element_type_972 = torch.ops.prims.convert_element_type.default(view_2126, torch.bfloat16); view_2126 = None + convert_element_type_973 = torch.ops.prims.convert_element_type.default(view_2127, torch.bfloat16); view_2127 = None + unsqueeze_58 = torch.ops.aten.unsqueeze.default(convert_element_type_973, 3); convert_element_type_973 = None + expand_58 = torch.ops.aten.expand.default(unsqueeze_58, [2, 8192, 1, 4, 128]); unsqueeze_58 = None + view_2128 = torch.ops.aten.view.default(expand_58, [2, 8192, 4, 128]); expand_58 = None + unsqueeze_59 = torch.ops.aten.unsqueeze.default(view_2122, 
3); view_2122 = None + expand_59 = torch.ops.aten.expand.default(unsqueeze_59, [2, 8192, 1, 4, 128]); unsqueeze_59 = None + view_2129 = torch.ops.aten.view.default(expand_59, [2, 8192, 4, 128]); expand_59 = None + permute_322 = torch.ops.aten.permute.default(convert_element_type_972, [0, 2, 1, 3]); convert_element_type_972 = None + permute_323 = torch.ops.aten.permute.default(view_2128, [0, 2, 1, 3]); view_2128 = None + permute_324 = torch.ops.aten.permute.default(view_2129, [0, 2, 1, 3]); view_2129 = None + _scaled_dot_product_cudnn_attention_29 = torch.ops.aten._scaled_dot_product_cudnn_attention.default(permute_322, permute_323, permute_324, None, True, 0.0, True); permute_322 = permute_323 = permute_324 = None + getitem_1269 = _scaled_dot_product_cudnn_attention_29[0] + getitem_1270 = _scaled_dot_product_cudnn_attention_29[1] + getitem_1275 = _scaled_dot_product_cudnn_attention_29[6] + getitem_1276 = _scaled_dot_product_cudnn_attention_29[7]; _scaled_dot_product_cudnn_attention_29 = None + permute_325 = torch.ops.aten.permute.default(getitem_1269, [0, 2, 1, 3]) + view_2130 = torch.ops.aten.view.default(permute_325, [2, 8192, -1]); permute_325 = None + convert_element_type_974 = torch.ops.prims.convert_element_type.default(primals_269, torch.bfloat16) + all_gather_into_tensor_325 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_974, 32, '0'); convert_element_type_974 = None + wait_tensor_384 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_325); all_gather_into_tensor_325 = None + permute_326 = torch.ops.aten.permute.default(wait_tensor_384, [1, 0]); wait_tensor_384 = None + view_2136 = torch.ops.aten.view.default(view_2130, [16384, 512]); view_2130 = None + mm_206 = torch.ops.aten.mm.default(view_2136, permute_326); view_2136 = permute_326 = None + view_2137 = torch.ops.aten.view.default(mm_206, [2, 8192, 4096]); mm_206 = None + split_126 = torch.ops.aten.split.Tensor(view_2137, 1024, 1); view_2137 = None + 
getitem_1278 = split_126[0] + getitem_1279 = split_126[1] + getitem_1280 = split_126[2] + getitem_1281 = split_126[3] + getitem_1282 = split_126[4] + getitem_1283 = split_126[5] + getitem_1284 = split_126[6] + getitem_1285 = split_126[7]; split_126 = None + cat_118 = torch.ops.aten.cat.default([getitem_1278, getitem_1279, getitem_1280, getitem_1281, getitem_1282, getitem_1283, getitem_1284, getitem_1285]); getitem_1278 = getitem_1279 = getitem_1280 = getitem_1281 = getitem_1282 = getitem_1283 = getitem_1284 = getitem_1285 = None + reduce_scatter_tensor_59 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_118, 'sum', 8, '1'); cat_118 = None + wait_tensor_385 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_59) + add_117 = torch.ops.aten.add.Tensor(add_115, wait_tensor_385); wait_tensor_385 = None + convert_element_type_977 = torch.ops.prims.convert_element_type.default(primals_270, torch.bfloat16) + all_gather_into_tensor_326 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_977, 32, '0'); convert_element_type_977 = None + wait_tensor_386 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_326); all_gather_into_tensor_326 = None + convert_element_type_978 = torch.ops.prims.convert_element_type.default(add_117, torch.float32) + pow_60 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_978, 2) + mean_59 = torch.ops.aten.mean.dim(pow_60, [2], True); pow_60 = None + add_118 = torch.ops.aten.add.Scalar(mean_59, 1e-05); mean_59 = None + rsqrt_59 = torch.ops.aten.rsqrt.default(add_118); add_118 = None + mul_236 = torch.ops.aten.mul.Tensor(convert_element_type_978, rsqrt_59); convert_element_type_978 = rsqrt_59 = None + mul_237 = torch.ops.aten.mul.Tensor(mul_236, wait_tensor_386); mul_236 = wait_tensor_386 = None + convert_element_type_979 = torch.ops.prims.convert_element_type.default(mul_237, torch.bfloat16); mul_237 = None + all_gather_into_tensor_327 = 
torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_979, 8, '1'); convert_element_type_979 = None + wait_tensor_387 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_327); all_gather_into_tensor_327 = None + split_127 = torch.ops.aten.split.Tensor(wait_tensor_387, 2); wait_tensor_387 = None + getitem_1286 = split_127[0] + getitem_1287 = split_127[1] + getitem_1288 = split_127[2] + getitem_1289 = split_127[3] + getitem_1290 = split_127[4] + getitem_1291 = split_127[5] + getitem_1292 = split_127[6] + getitem_1293 = split_127[7]; split_127 = None + cat_119 = torch.ops.aten.cat.default([getitem_1286, getitem_1287, getitem_1288, getitem_1289, getitem_1290, getitem_1291, getitem_1292, getitem_1293], 1); getitem_1286 = getitem_1287 = getitem_1288 = getitem_1289 = getitem_1290 = getitem_1291 = getitem_1292 = getitem_1293 = None + convert_element_type_980 = torch.ops.prims.convert_element_type.default(primals_271, torch.bfloat16) + all_gather_into_tensor_328 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_980, 32, '0'); convert_element_type_980 = None + wait_tensor_388 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_328); all_gather_into_tensor_328 = None + permute_327 = torch.ops.aten.permute.default(wait_tensor_388, [1, 0]); wait_tensor_388 = None + view_2148 = torch.ops.aten.view.default(cat_119, [16384, 4096]); cat_119 = None + mm_207 = torch.ops.aten.mm.default(view_2148, permute_327); permute_327 = None + view_2149 = torch.ops.aten.view.default(mm_207, [2, 8192, 1792]) + convert_element_type_983 = torch.ops.prims.convert_element_type.default(view_2149, torch.float32); view_2149 = None + sigmoid_29 = torch.ops.aten.sigmoid.default(convert_element_type_983) + mul_238 = torch.ops.aten.mul.Tensor(convert_element_type_983, sigmoid_29); convert_element_type_983 = sigmoid_29 = None + convert_element_type_984 = torch.ops.prims.convert_element_type.default(mul_238, 
torch.bfloat16); mul_238 = None + convert_element_type_985 = torch.ops.prims.convert_element_type.default(primals_272, torch.bfloat16) + all_gather_into_tensor_329 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_985, 32, '0'); convert_element_type_985 = None + wait_tensor_389 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_329); all_gather_into_tensor_329 = None + permute_328 = torch.ops.aten.permute.default(wait_tensor_389, [1, 0]); wait_tensor_389 = None + mm_208 = torch.ops.aten.mm.default(view_2148, permute_328); view_2148 = permute_328 = None + view_2156 = torch.ops.aten.view.default(mm_208, [2, 8192, 1792]); mm_208 = None + mul_239 = torch.ops.aten.mul.Tensor(convert_element_type_984, view_2156); convert_element_type_984 = view_2156 = None + convert_element_type_988 = torch.ops.prims.convert_element_type.default(primals_273, torch.bfloat16) + all_gather_into_tensor_330 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_988, 32, '0'); convert_element_type_988 = None + wait_tensor_390 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_330); all_gather_into_tensor_330 = None + permute_329 = torch.ops.aten.permute.default(wait_tensor_390, [1, 0]); wait_tensor_390 = None + view_2163 = torch.ops.aten.view.default(mul_239, [16384, 1792]); mul_239 = None + mm_209 = torch.ops.aten.mm.default(view_2163, permute_329); view_2163 = permute_329 = None + view_2164 = torch.ops.aten.view.default(mm_209, [2, 8192, 4096]); mm_209 = None + split_128 = torch.ops.aten.split.Tensor(view_2164, 1024, 1); view_2164 = None + getitem_1294 = split_128[0] + getitem_1295 = split_128[1] + getitem_1296 = split_128[2] + getitem_1297 = split_128[3] + getitem_1298 = split_128[4] + getitem_1299 = split_128[5] + getitem_1300 = split_128[6] + getitem_1301 = split_128[7]; split_128 = None + cat_120 = torch.ops.aten.cat.default([getitem_1294, getitem_1295, getitem_1296, getitem_1297, 
getitem_1298, getitem_1299, getitem_1300, getitem_1301]); getitem_1294 = getitem_1295 = getitem_1296 = getitem_1297 = getitem_1298 = getitem_1299 = getitem_1300 = getitem_1301 = None + reduce_scatter_tensor_60 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_120, 'sum', 8, '1'); cat_120 = None + wait_tensor_391 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_60); reduce_scatter_tensor_60 = None + add_119 = torch.ops.aten.add.Tensor(add_117, wait_tensor_391); add_117 = wait_tensor_391 = None + convert_element_type_991 = torch.ops.prims.convert_element_type.default(primals_274, torch.bfloat16) + all_gather_into_tensor_331 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_991, 32, '0'); convert_element_type_991 = None + wait_tensor_392 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_331); all_gather_into_tensor_331 = None + convert_element_type_992 = torch.ops.prims.convert_element_type.default(add_119, torch.float32) + pow_61 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_992, 2) + mean_60 = torch.ops.aten.mean.dim(pow_61, [2], True); pow_61 = None + add_120 = torch.ops.aten.add.Scalar(mean_60, 1e-05); mean_60 = None + rsqrt_60 = torch.ops.aten.rsqrt.default(add_120); add_120 = None + mul_240 = torch.ops.aten.mul.Tensor(convert_element_type_992, rsqrt_60); convert_element_type_992 = rsqrt_60 = None + mul_241 = torch.ops.aten.mul.Tensor(mul_240, wait_tensor_392); mul_240 = wait_tensor_392 = None + convert_element_type_993 = torch.ops.prims.convert_element_type.default(mul_241, torch.bfloat16); mul_241 = None + all_gather_into_tensor_332 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_993, 8, '1'); convert_element_type_993 = None + wait_tensor_393 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_332); all_gather_into_tensor_332 = None + split_129 = torch.ops.aten.split.Tensor(wait_tensor_393, 2); wait_tensor_393 = 
None + getitem_1302 = split_129[0] + getitem_1303 = split_129[1] + getitem_1304 = split_129[2] + getitem_1305 = split_129[3] + getitem_1306 = split_129[4] + getitem_1307 = split_129[5] + getitem_1308 = split_129[6] + getitem_1309 = split_129[7]; split_129 = None + cat_121 = torch.ops.aten.cat.default([getitem_1302, getitem_1303, getitem_1304, getitem_1305, getitem_1306, getitem_1307, getitem_1308, getitem_1309], 1); getitem_1302 = getitem_1303 = getitem_1304 = getitem_1305 = getitem_1306 = getitem_1307 = getitem_1308 = getitem_1309 = None + convert_element_type_994 = torch.ops.prims.convert_element_type.default(primals_275, torch.bfloat16) + all_gather_into_tensor_333 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_994, 32, '0'); convert_element_type_994 = None + wait_tensor_394 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_333); all_gather_into_tensor_333 = None + permute_330 = torch.ops.aten.permute.default(wait_tensor_394, [1, 0]); wait_tensor_394 = None + view_2175 = torch.ops.aten.view.default(cat_121, [16384, 4096]); cat_121 = None + mm_210 = torch.ops.aten.mm.default(view_2175, permute_330); permute_330 = None + view_2176 = torch.ops.aten.view.default(mm_210, [2, 8192, 512]) + convert_element_type_997 = torch.ops.prims.convert_element_type.default(primals_276, torch.bfloat16) + all_gather_into_tensor_334 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_997, 32, '0'); convert_element_type_997 = None + wait_tensor_395 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_334); all_gather_into_tensor_334 = None + permute_331 = torch.ops.aten.permute.default(wait_tensor_395, [1, 0]); wait_tensor_395 = None + mm_211 = torch.ops.aten.mm.default(view_2175, permute_331); permute_331 = None + view_2183 = torch.ops.aten.view.default(mm_211, [2, 8192, 128]); mm_211 = None + convert_element_type_1000 = torch.ops.prims.convert_element_type.default(primals_277, 
torch.bfloat16) + all_gather_into_tensor_335 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1000, 32, '0'); convert_element_type_1000 = None + wait_tensor_396 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_335); all_gather_into_tensor_335 = None + permute_332 = torch.ops.aten.permute.default(wait_tensor_396, [1, 0]); wait_tensor_396 = None + mm_212 = torch.ops.aten.mm.default(view_2175, permute_332); view_2175 = permute_332 = None + view_2190 = torch.ops.aten.view.default(mm_212, [2, 8192, 128]) + view_2192 = torch.ops.aten.view.default(view_2176, [2, 8192, -1, 128]); view_2176 = None + view_2193 = torch.ops.aten.view.default(view_2183, [2, 8192, -1, 128]); view_2183 = None + view_2194 = torch.ops.aten.view.default(view_2190, [2, 8192, -1, 128]); view_2190 = None + convert_element_type_1003 = torch.ops.prims.convert_element_type.default(view_2192, torch.float32); view_2192 = None + view_2195 = torch.ops.aten.view.default(convert_element_type_1003, [2, 8192, 4, -1, 2]); convert_element_type_1003 = None + view_as_complex_60 = torch.ops.aten.view_as_complex.default(view_2195); view_2195 = None + convert_element_type_1004 = torch.ops.prims.convert_element_type.default(view_2193, torch.float32); view_2193 = None + view_2196 = torch.ops.aten.view.default(convert_element_type_1004, [2, 8192, 1, -1, 2]); convert_element_type_1004 = None + view_as_complex_61 = torch.ops.aten.view_as_complex.default(view_2196); view_2196 = None + mul_242 = torch.ops.aten.mul.Tensor(view_as_complex_60, view_37); view_as_complex_60 = None + view_as_real_60 = torch.ops.aten.view_as_real.default(mul_242); mul_242 = None + view_2198 = torch.ops.aten.view.default(view_as_real_60, [2, 8192, 4, 128]); view_as_real_60 = None + mul_243 = torch.ops.aten.mul.Tensor(view_as_complex_61, view_37); view_as_complex_61 = None + view_as_real_61 = torch.ops.aten.view_as_real.default(mul_243); mul_243 = None + view_2199 = 
torch.ops.aten.view.default(view_as_real_61, [2, 8192, 1, 128]); view_as_real_61 = None + convert_element_type_1005 = torch.ops.prims.convert_element_type.default(view_2198, torch.bfloat16); view_2198 = None + convert_element_type_1006 = torch.ops.prims.convert_element_type.default(view_2199, torch.bfloat16); view_2199 = None + unsqueeze_60 = torch.ops.aten.unsqueeze.default(convert_element_type_1006, 3); convert_element_type_1006 = None + expand_60 = torch.ops.aten.expand.default(unsqueeze_60, [2, 8192, 1, 4, 128]); unsqueeze_60 = None + view_2200 = torch.ops.aten.view.default(expand_60, [2, 8192, 4, 128]); expand_60 = None + unsqueeze_61 = torch.ops.aten.unsqueeze.default(view_2194, 3); view_2194 = None + expand_61 = torch.ops.aten.expand.default(unsqueeze_61, [2, 8192, 1, 4, 128]); unsqueeze_61 = None + view_2201 = torch.ops.aten.view.default(expand_61, [2, 8192, 4, 128]); expand_61 = None + permute_333 = torch.ops.aten.permute.default(convert_element_type_1005, [0, 2, 1, 3]); convert_element_type_1005 = None + permute_334 = torch.ops.aten.permute.default(view_2200, [0, 2, 1, 3]); view_2200 = None + permute_335 = torch.ops.aten.permute.default(view_2201, [0, 2, 1, 3]); view_2201 = None + _scaled_dot_product_cudnn_attention_30 = torch.ops.aten._scaled_dot_product_cudnn_attention.default(permute_333, permute_334, permute_335, None, True, 0.0, True); permute_333 = permute_334 = permute_335 = None + getitem_1310 = _scaled_dot_product_cudnn_attention_30[0] + getitem_1311 = _scaled_dot_product_cudnn_attention_30[1] + getitem_1316 = _scaled_dot_product_cudnn_attention_30[6] + getitem_1317 = _scaled_dot_product_cudnn_attention_30[7]; _scaled_dot_product_cudnn_attention_30 = None + permute_336 = torch.ops.aten.permute.default(getitem_1310, [0, 2, 1, 3]) + view_2202 = torch.ops.aten.view.default(permute_336, [2, 8192, -1]); permute_336 = None + convert_element_type_1007 = torch.ops.prims.convert_element_type.default(primals_278, torch.bfloat16) + 
all_gather_into_tensor_336 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1007, 32, '0'); convert_element_type_1007 = None + wait_tensor_397 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_336); all_gather_into_tensor_336 = None + permute_337 = torch.ops.aten.permute.default(wait_tensor_397, [1, 0]); wait_tensor_397 = None + view_2208 = torch.ops.aten.view.default(view_2202, [16384, 512]); view_2202 = None + mm_213 = torch.ops.aten.mm.default(view_2208, permute_337); view_2208 = permute_337 = None + view_2209 = torch.ops.aten.view.default(mm_213, [2, 8192, 4096]); mm_213 = None + split_130 = torch.ops.aten.split.Tensor(view_2209, 1024, 1); view_2209 = None + getitem_1319 = split_130[0] + getitem_1320 = split_130[1] + getitem_1321 = split_130[2] + getitem_1322 = split_130[3] + getitem_1323 = split_130[4] + getitem_1324 = split_130[5] + getitem_1325 = split_130[6] + getitem_1326 = split_130[7]; split_130 = None + cat_122 = torch.ops.aten.cat.default([getitem_1319, getitem_1320, getitem_1321, getitem_1322, getitem_1323, getitem_1324, getitem_1325, getitem_1326]); getitem_1319 = getitem_1320 = getitem_1321 = getitem_1322 = getitem_1323 = getitem_1324 = getitem_1325 = getitem_1326 = None + reduce_scatter_tensor_61 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_122, 'sum', 8, '1'); cat_122 = None + wait_tensor_398 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_61) + add_121 = torch.ops.aten.add.Tensor(add_119, wait_tensor_398); wait_tensor_398 = None + convert_element_type_1010 = torch.ops.prims.convert_element_type.default(primals_279, torch.bfloat16) + all_gather_into_tensor_337 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1010, 32, '0'); convert_element_type_1010 = None + wait_tensor_399 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_337); all_gather_into_tensor_337 = None + convert_element_type_1011 = 
torch.ops.prims.convert_element_type.default(add_121, torch.float32) + pow_62 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_1011, 2) + mean_61 = torch.ops.aten.mean.dim(pow_62, [2], True); pow_62 = None + add_122 = torch.ops.aten.add.Scalar(mean_61, 1e-05); mean_61 = None + rsqrt_61 = torch.ops.aten.rsqrt.default(add_122); add_122 = None + mul_244 = torch.ops.aten.mul.Tensor(convert_element_type_1011, rsqrt_61); convert_element_type_1011 = rsqrt_61 = None + mul_245 = torch.ops.aten.mul.Tensor(mul_244, wait_tensor_399); mul_244 = wait_tensor_399 = None + convert_element_type_1012 = torch.ops.prims.convert_element_type.default(mul_245, torch.bfloat16); mul_245 = None + all_gather_into_tensor_338 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1012, 8, '1'); convert_element_type_1012 = None + wait_tensor_400 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_338); all_gather_into_tensor_338 = None + split_131 = torch.ops.aten.split.Tensor(wait_tensor_400, 2); wait_tensor_400 = None + getitem_1327 = split_131[0] + getitem_1328 = split_131[1] + getitem_1329 = split_131[2] + getitem_1330 = split_131[3] + getitem_1331 = split_131[4] + getitem_1332 = split_131[5] + getitem_1333 = split_131[6] + getitem_1334 = split_131[7]; split_131 = None + cat_123 = torch.ops.aten.cat.default([getitem_1327, getitem_1328, getitem_1329, getitem_1330, getitem_1331, getitem_1332, getitem_1333, getitem_1334], 1); getitem_1327 = getitem_1328 = getitem_1329 = getitem_1330 = getitem_1331 = getitem_1332 = getitem_1333 = getitem_1334 = None + convert_element_type_1013 = torch.ops.prims.convert_element_type.default(primals_280, torch.bfloat16) + all_gather_into_tensor_339 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1013, 32, '0'); convert_element_type_1013 = None + wait_tensor_401 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_339); all_gather_into_tensor_339 = None + 
permute_338 = torch.ops.aten.permute.default(wait_tensor_401, [1, 0]); wait_tensor_401 = None + view_2220 = torch.ops.aten.view.default(cat_123, [16384, 4096]); cat_123 = None + mm_214 = torch.ops.aten.mm.default(view_2220, permute_338); permute_338 = None + view_2221 = torch.ops.aten.view.default(mm_214, [2, 8192, 1792]) + convert_element_type_1016 = torch.ops.prims.convert_element_type.default(view_2221, torch.float32); view_2221 = None + sigmoid_30 = torch.ops.aten.sigmoid.default(convert_element_type_1016) + mul_246 = torch.ops.aten.mul.Tensor(convert_element_type_1016, sigmoid_30); convert_element_type_1016 = sigmoid_30 = None + convert_element_type_1017 = torch.ops.prims.convert_element_type.default(mul_246, torch.bfloat16); mul_246 = None + convert_element_type_1018 = torch.ops.prims.convert_element_type.default(primals_281, torch.bfloat16) + all_gather_into_tensor_340 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1018, 32, '0'); convert_element_type_1018 = None + wait_tensor_402 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_340); all_gather_into_tensor_340 = None + permute_339 = torch.ops.aten.permute.default(wait_tensor_402, [1, 0]); wait_tensor_402 = None + mm_215 = torch.ops.aten.mm.default(view_2220, permute_339); view_2220 = permute_339 = None + view_2228 = torch.ops.aten.view.default(mm_215, [2, 8192, 1792]); mm_215 = None + mul_247 = torch.ops.aten.mul.Tensor(convert_element_type_1017, view_2228); convert_element_type_1017 = view_2228 = None + convert_element_type_1021 = torch.ops.prims.convert_element_type.default(primals_282, torch.bfloat16) + all_gather_into_tensor_341 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1021, 32, '0'); convert_element_type_1021 = None + wait_tensor_403 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_341); all_gather_into_tensor_341 = None + permute_340 = torch.ops.aten.permute.default(wait_tensor_403, 
[1, 0]); wait_tensor_403 = None + view_2235 = torch.ops.aten.view.default(mul_247, [16384, 1792]); mul_247 = None + mm_216 = torch.ops.aten.mm.default(view_2235, permute_340); view_2235 = permute_340 = None + view_2236 = torch.ops.aten.view.default(mm_216, [2, 8192, 4096]); mm_216 = None + split_132 = torch.ops.aten.split.Tensor(view_2236, 1024, 1); view_2236 = None + getitem_1335 = split_132[0] + getitem_1336 = split_132[1] + getitem_1337 = split_132[2] + getitem_1338 = split_132[3] + getitem_1339 = split_132[4] + getitem_1340 = split_132[5] + getitem_1341 = split_132[6] + getitem_1342 = split_132[7]; split_132 = None + cat_124 = torch.ops.aten.cat.default([getitem_1335, getitem_1336, getitem_1337, getitem_1338, getitem_1339, getitem_1340, getitem_1341, getitem_1342]); getitem_1335 = getitem_1336 = getitem_1337 = getitem_1338 = getitem_1339 = getitem_1340 = getitem_1341 = getitem_1342 = None + reduce_scatter_tensor_62 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_124, 'sum', 8, '1'); cat_124 = None + wait_tensor_404 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_62); reduce_scatter_tensor_62 = None + add_123 = torch.ops.aten.add.Tensor(add_121, wait_tensor_404); add_121 = wait_tensor_404 = None + convert_element_type_1024 = torch.ops.prims.convert_element_type.default(primals_283, torch.bfloat16) + all_gather_into_tensor_342 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1024, 32, '0'); convert_element_type_1024 = None + wait_tensor_405 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_342); all_gather_into_tensor_342 = None + convert_element_type_1025 = torch.ops.prims.convert_element_type.default(add_123, torch.float32) + pow_63 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_1025, 2) + mean_62 = torch.ops.aten.mean.dim(pow_63, [2], True); pow_63 = None + add_124 = torch.ops.aten.add.Scalar(mean_62, 1e-05); mean_62 = None + rsqrt_62 = 
torch.ops.aten.rsqrt.default(add_124); add_124 = None + mul_248 = torch.ops.aten.mul.Tensor(convert_element_type_1025, rsqrt_62); convert_element_type_1025 = rsqrt_62 = None + mul_249 = torch.ops.aten.mul.Tensor(mul_248, wait_tensor_405); mul_248 = wait_tensor_405 = None + convert_element_type_1026 = torch.ops.prims.convert_element_type.default(mul_249, torch.bfloat16); mul_249 = None + all_gather_into_tensor_343 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1026, 8, '1'); convert_element_type_1026 = None + wait_tensor_406 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_343); all_gather_into_tensor_343 = None + split_133 = torch.ops.aten.split.Tensor(wait_tensor_406, 2); wait_tensor_406 = None + getitem_1343 = split_133[0] + getitem_1344 = split_133[1] + getitem_1345 = split_133[2] + getitem_1346 = split_133[3] + getitem_1347 = split_133[4] + getitem_1348 = split_133[5] + getitem_1349 = split_133[6] + getitem_1350 = split_133[7]; split_133 = None + cat_125 = torch.ops.aten.cat.default([getitem_1343, getitem_1344, getitem_1345, getitem_1346, getitem_1347, getitem_1348, getitem_1349, getitem_1350], 1); getitem_1343 = getitem_1344 = getitem_1345 = getitem_1346 = getitem_1347 = getitem_1348 = getitem_1349 = getitem_1350 = None + convert_element_type_1027 = torch.ops.prims.convert_element_type.default(primals_284, torch.bfloat16) + all_gather_into_tensor_344 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1027, 32, '0'); convert_element_type_1027 = None + wait_tensor_407 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_344); all_gather_into_tensor_344 = None + permute_341 = torch.ops.aten.permute.default(wait_tensor_407, [1, 0]); wait_tensor_407 = None + view_2247 = torch.ops.aten.view.default(cat_125, [16384, 4096]); cat_125 = None + mm_217 = torch.ops.aten.mm.default(view_2247, permute_341); permute_341 = None + view_2248 = 
torch.ops.aten.view.default(mm_217, [2, 8192, 512]) + convert_element_type_1030 = torch.ops.prims.convert_element_type.default(primals_285, torch.bfloat16) + all_gather_into_tensor_345 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1030, 32, '0'); convert_element_type_1030 = None + wait_tensor_408 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_345); all_gather_into_tensor_345 = None + permute_342 = torch.ops.aten.permute.default(wait_tensor_408, [1, 0]); wait_tensor_408 = None + mm_218 = torch.ops.aten.mm.default(view_2247, permute_342); permute_342 = None + view_2255 = torch.ops.aten.view.default(mm_218, [2, 8192, 128]); mm_218 = None + convert_element_type_1033 = torch.ops.prims.convert_element_type.default(primals_286, torch.bfloat16) + all_gather_into_tensor_346 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1033, 32, '0'); convert_element_type_1033 = None + wait_tensor_409 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_346); all_gather_into_tensor_346 = None + permute_343 = torch.ops.aten.permute.default(wait_tensor_409, [1, 0]); wait_tensor_409 = None + mm_219 = torch.ops.aten.mm.default(view_2247, permute_343); view_2247 = permute_343 = None + view_2262 = torch.ops.aten.view.default(mm_219, [2, 8192, 128]) + view_2264 = torch.ops.aten.view.default(view_2248, [2, 8192, -1, 128]); view_2248 = None + view_2265 = torch.ops.aten.view.default(view_2255, [2, 8192, -1, 128]); view_2255 = None + view_2266 = torch.ops.aten.view.default(view_2262, [2, 8192, -1, 128]); view_2262 = None + convert_element_type_1036 = torch.ops.prims.convert_element_type.default(view_2264, torch.float32); view_2264 = None + view_2267 = torch.ops.aten.view.default(convert_element_type_1036, [2, 8192, 4, -1, 2]); convert_element_type_1036 = None + view_as_complex_62 = torch.ops.aten.view_as_complex.default(view_2267); view_2267 = None + convert_element_type_1037 = 
torch.ops.prims.convert_element_type.default(view_2265, torch.float32); view_2265 = None + view_2268 = torch.ops.aten.view.default(convert_element_type_1037, [2, 8192, 1, -1, 2]); convert_element_type_1037 = None + view_as_complex_63 = torch.ops.aten.view_as_complex.default(view_2268); view_2268 = None + mul_250 = torch.ops.aten.mul.Tensor(view_as_complex_62, view_37); view_as_complex_62 = None + view_as_real_62 = torch.ops.aten.view_as_real.default(mul_250); mul_250 = None + view_2270 = torch.ops.aten.view.default(view_as_real_62, [2, 8192, 4, 128]); view_as_real_62 = None + mul_251 = torch.ops.aten.mul.Tensor(view_as_complex_63, view_37); view_as_complex_63 = view_37 = None + view_as_real_63 = torch.ops.aten.view_as_real.default(mul_251); mul_251 = None + view_2271 = torch.ops.aten.view.default(view_as_real_63, [2, 8192, 1, 128]); view_as_real_63 = None + convert_element_type_1038 = torch.ops.prims.convert_element_type.default(view_2270, torch.bfloat16); view_2270 = None + convert_element_type_1039 = torch.ops.prims.convert_element_type.default(view_2271, torch.bfloat16); view_2271 = None + unsqueeze_62 = torch.ops.aten.unsqueeze.default(convert_element_type_1039, 3); convert_element_type_1039 = None + expand_62 = torch.ops.aten.expand.default(unsqueeze_62, [2, 8192, 1, 4, 128]); unsqueeze_62 = None + view_2272 = torch.ops.aten.view.default(expand_62, [2, 8192, 4, 128]); expand_62 = None + unsqueeze_63 = torch.ops.aten.unsqueeze.default(view_2266, 3); view_2266 = None + expand_63 = torch.ops.aten.expand.default(unsqueeze_63, [2, 8192, 1, 4, 128]); unsqueeze_63 = None + view_2273 = torch.ops.aten.view.default(expand_63, [2, 8192, 4, 128]); expand_63 = None + permute_344 = torch.ops.aten.permute.default(convert_element_type_1038, [0, 2, 1, 3]); convert_element_type_1038 = None + permute_345 = torch.ops.aten.permute.default(view_2272, [0, 2, 1, 3]); view_2272 = None + permute_346 = torch.ops.aten.permute.default(view_2273, [0, 2, 1, 3]); view_2273 = None + 
_scaled_dot_product_cudnn_attention_31 = torch.ops.aten._scaled_dot_product_cudnn_attention.default(permute_344, permute_345, permute_346, None, True, 0.0, True); permute_344 = permute_345 = permute_346 = None + getitem_1351 = _scaled_dot_product_cudnn_attention_31[0] + getitem_1352 = _scaled_dot_product_cudnn_attention_31[1] + getitem_1357 = _scaled_dot_product_cudnn_attention_31[6] + getitem_1358 = _scaled_dot_product_cudnn_attention_31[7]; _scaled_dot_product_cudnn_attention_31 = None + permute_347 = torch.ops.aten.permute.default(getitem_1351, [0, 2, 1, 3]) + view_2274 = torch.ops.aten.view.default(permute_347, [2, 8192, -1]); permute_347 = None + convert_element_type_1040 = torch.ops.prims.convert_element_type.default(primals_287, torch.bfloat16) + all_gather_into_tensor_347 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1040, 32, '0'); convert_element_type_1040 = None + wait_tensor_410 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_347); all_gather_into_tensor_347 = None + permute_348 = torch.ops.aten.permute.default(wait_tensor_410, [1, 0]); wait_tensor_410 = None + view_2280 = torch.ops.aten.view.default(view_2274, [16384, 512]); view_2274 = None + mm_220 = torch.ops.aten.mm.default(view_2280, permute_348); view_2280 = permute_348 = None + view_2281 = torch.ops.aten.view.default(mm_220, [2, 8192, 4096]); mm_220 = None + split_134 = torch.ops.aten.split.Tensor(view_2281, 1024, 1); view_2281 = None + getitem_1360 = split_134[0] + getitem_1361 = split_134[1] + getitem_1362 = split_134[2] + getitem_1363 = split_134[3] + getitem_1364 = split_134[4] + getitem_1365 = split_134[5] + getitem_1366 = split_134[6] + getitem_1367 = split_134[7]; split_134 = None + cat_126 = torch.ops.aten.cat.default([getitem_1360, getitem_1361, getitem_1362, getitem_1363, getitem_1364, getitem_1365, getitem_1366, getitem_1367]); getitem_1360 = getitem_1361 = getitem_1362 = getitem_1363 = getitem_1364 = getitem_1365 = 
getitem_1366 = getitem_1367 = None + reduce_scatter_tensor_63 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_126, 'sum', 8, '1'); cat_126 = None + wait_tensor_411 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_63) + add_125 = torch.ops.aten.add.Tensor(add_123, wait_tensor_411); wait_tensor_411 = None + convert_element_type_1043 = torch.ops.prims.convert_element_type.default(primals_288, torch.bfloat16) + all_gather_into_tensor_348 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1043, 32, '0'); convert_element_type_1043 = None + wait_tensor_412 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_348); all_gather_into_tensor_348 = None + convert_element_type_1044 = torch.ops.prims.convert_element_type.default(add_125, torch.float32) + pow_64 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_1044, 2) + mean_63 = torch.ops.aten.mean.dim(pow_64, [2], True); pow_64 = None + add_126 = torch.ops.aten.add.Scalar(mean_63, 1e-05); mean_63 = None + rsqrt_63 = torch.ops.aten.rsqrt.default(add_126); add_126 = None + mul_252 = torch.ops.aten.mul.Tensor(convert_element_type_1044, rsqrt_63); convert_element_type_1044 = rsqrt_63 = None + mul_253 = torch.ops.aten.mul.Tensor(mul_252, wait_tensor_412); mul_252 = wait_tensor_412 = None + convert_element_type_1045 = torch.ops.prims.convert_element_type.default(mul_253, torch.bfloat16); mul_253 = None + all_gather_into_tensor_349 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1045, 8, '1'); convert_element_type_1045 = None + wait_tensor_413 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_349); all_gather_into_tensor_349 = None + split_135 = torch.ops.aten.split.Tensor(wait_tensor_413, 2); wait_tensor_413 = None + getitem_1368 = split_135[0] + getitem_1369 = split_135[1] + getitem_1370 = split_135[2] + getitem_1371 = split_135[3] + getitem_1372 = split_135[4] + getitem_1373 = 
split_135[5] + getitem_1374 = split_135[6] + getitem_1375 = split_135[7]; split_135 = None + cat_127 = torch.ops.aten.cat.default([getitem_1368, getitem_1369, getitem_1370, getitem_1371, getitem_1372, getitem_1373, getitem_1374, getitem_1375], 1); getitem_1368 = getitem_1369 = getitem_1370 = getitem_1371 = getitem_1372 = getitem_1373 = getitem_1374 = getitem_1375 = None + convert_element_type_1046 = torch.ops.prims.convert_element_type.default(primals_289, torch.bfloat16) + all_gather_into_tensor_350 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1046, 32, '0'); convert_element_type_1046 = None + wait_tensor_414 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_350); all_gather_into_tensor_350 = None + permute_349 = torch.ops.aten.permute.default(wait_tensor_414, [1, 0]); wait_tensor_414 = None + view_2292 = torch.ops.aten.view.default(cat_127, [16384, 4096]); cat_127 = None + mm_221 = torch.ops.aten.mm.default(view_2292, permute_349); permute_349 = None + view_2293 = torch.ops.aten.view.default(mm_221, [2, 8192, 1792]) + convert_element_type_1049 = torch.ops.prims.convert_element_type.default(view_2293, torch.float32); view_2293 = None + sigmoid_31 = torch.ops.aten.sigmoid.default(convert_element_type_1049) + mul_254 = torch.ops.aten.mul.Tensor(convert_element_type_1049, sigmoid_31); convert_element_type_1049 = sigmoid_31 = None + convert_element_type_1050 = torch.ops.prims.convert_element_type.default(mul_254, torch.bfloat16); mul_254 = None + convert_element_type_1051 = torch.ops.prims.convert_element_type.default(primals_290, torch.bfloat16) + all_gather_into_tensor_351 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1051, 32, '0'); convert_element_type_1051 = None + wait_tensor_415 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_351); all_gather_into_tensor_351 = None + permute_350 = torch.ops.aten.permute.default(wait_tensor_415, [1, 0]); 
wait_tensor_415 = None + mm_222 = torch.ops.aten.mm.default(view_2292, permute_350); view_2292 = permute_350 = None + view_2300 = torch.ops.aten.view.default(mm_222, [2, 8192, 1792]); mm_222 = None + mul_255 = torch.ops.aten.mul.Tensor(convert_element_type_1050, view_2300); convert_element_type_1050 = view_2300 = None + convert_element_type_1054 = torch.ops.prims.convert_element_type.default(primals_291, torch.bfloat16) + all_gather_into_tensor_352 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1054, 32, '0'); convert_element_type_1054 = None + wait_tensor_416 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_352); all_gather_into_tensor_352 = None + permute_351 = torch.ops.aten.permute.default(wait_tensor_416, [1, 0]); wait_tensor_416 = None + view_2307 = torch.ops.aten.view.default(mul_255, [16384, 1792]); mul_255 = None + mm_223 = torch.ops.aten.mm.default(view_2307, permute_351); view_2307 = permute_351 = None + view_2308 = torch.ops.aten.view.default(mm_223, [2, 8192, 4096]); mm_223 = None + split_136 = torch.ops.aten.split.Tensor(view_2308, 1024, 1); view_2308 = None + getitem_1376 = split_136[0] + getitem_1377 = split_136[1] + getitem_1378 = split_136[2] + getitem_1379 = split_136[3] + getitem_1380 = split_136[4] + getitem_1381 = split_136[5] + getitem_1382 = split_136[6] + getitem_1383 = split_136[7]; split_136 = None + cat_128 = torch.ops.aten.cat.default([getitem_1376, getitem_1377, getitem_1378, getitem_1379, getitem_1380, getitem_1381, getitem_1382, getitem_1383]); getitem_1376 = getitem_1377 = getitem_1378 = getitem_1379 = getitem_1380 = getitem_1381 = getitem_1382 = getitem_1383 = None + reduce_scatter_tensor_64 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_128, 'sum', 8, '1'); cat_128 = None + wait_tensor_417 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_64) + add_127 = torch.ops.aten.add.Tensor(add_125, wait_tensor_417); add_125 = wait_tensor_417 = None 
+ convert_element_type_1057 = torch.ops.prims.convert_element_type.default(primals_292, torch.bfloat16) + all_gather_into_tensor_353 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1057, 32, '0'); convert_element_type_1057 = None + wait_tensor_418 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_353); all_gather_into_tensor_353 = None + convert_element_type_1058 = torch.ops.prims.convert_element_type.default(add_127, torch.float32); add_127 = None + pow_65 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_1058, 2) + mean_64 = torch.ops.aten.mean.dim(pow_65, [2], True); pow_65 = None + add_128 = torch.ops.aten.add.Scalar(mean_64, 1e-05); mean_64 = None + rsqrt_64 = torch.ops.aten.rsqrt.default(add_128); add_128 = None + mul_256 = torch.ops.aten.mul.Tensor(convert_element_type_1058, rsqrt_64); convert_element_type_1058 = None + mul_257 = torch.ops.aten.mul.Tensor(mul_256, wait_tensor_418); mul_256 = wait_tensor_418 = None + convert_element_type_1059 = torch.ops.prims.convert_element_type.default(mul_257, torch.bfloat16); mul_257 = None + all_gather_into_tensor_354 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1059, 8, '1'); convert_element_type_1059 = None + wait_tensor_419 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_354); all_gather_into_tensor_354 = None + split_137 = torch.ops.aten.split.Tensor(wait_tensor_419, 2); wait_tensor_419 = None + getitem_1384 = split_137[0] + getitem_1385 = split_137[1] + getitem_1386 = split_137[2] + getitem_1387 = split_137[3] + getitem_1388 = split_137[4] + getitem_1389 = split_137[5] + getitem_1390 = split_137[6] + getitem_1391 = split_137[7]; split_137 = None + cat_129 = torch.ops.aten.cat.default([getitem_1384, getitem_1385, getitem_1386, getitem_1387, getitem_1388, getitem_1389, getitem_1390, getitem_1391], 1); getitem_1384 = getitem_1385 = getitem_1386 = getitem_1387 = getitem_1388 = getitem_1389 = 
getitem_1390 = getitem_1391 = None + convert_element_type_1060 = torch.ops.prims.convert_element_type.default(primals_293, torch.bfloat16) + all_gather_into_tensor_355 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1060, 32, '0'); convert_element_type_1060 = None + wait_tensor_420 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_355); all_gather_into_tensor_355 = None + permute_352 = torch.ops.aten.permute.default(wait_tensor_420, [1, 0]); wait_tensor_420 = None + view_2319 = torch.ops.aten.view.default(cat_129, [16384, 4096]); cat_129 = None + mm_224 = torch.ops.aten.mm.default(view_2319, permute_352); permute_352 = None + view_2320 = torch.ops.aten.view.default(mm_224, [2, 8192, 16032]); mm_224 = None + return (view_2320, primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8, primals_9, primals_10, primals_11, primals_12, primals_13, primals_14, primals_15, primals_16, primals_17, primals_18, primals_19, primals_20, primals_21, primals_22, primals_23, primals_24, primals_25, primals_26, primals_27, primals_28, primals_29, primals_30, primals_31, primals_32, primals_33, primals_34, primals_35, primals_36, primals_37, primals_38, primals_39, primals_40, primals_41, primals_42, primals_43, primals_44, primals_45, primals_46, primals_47, primals_48, primals_49, primals_50, primals_51, primals_52, primals_53, primals_54, primals_55, primals_56, primals_57, primals_58, primals_59, primals_60, primals_61, primals_62, primals_63, primals_64, primals_65, primals_66, primals_67, primals_68, primals_69, primals_70, primals_71, primals_72, primals_73, primals_74, primals_75, primals_76, primals_77, primals_78, primals_79, primals_80, primals_81, primals_82, primals_83, primals_84, primals_85, primals_86, primals_87, primals_88, primals_89, primals_90, primals_91, primals_92, primals_93, primals_94, primals_95, primals_96, primals_97, primals_98, primals_99, primals_100, primals_101, 
primals_102, primals_103, primals_104, primals_105, primals_106, primals_107, primals_108, primals_109, primals_110, primals_111, primals_112, primals_113, primals_114, primals_115, primals_116, primals_117, primals_118, primals_119, primals_120, primals_121, primals_122, primals_123, primals_124, primals_125, primals_126, primals_127, primals_128, primals_129, primals_130, primals_131, primals_132, primals_133, primals_134, primals_135, primals_136, primals_137, primals_138, primals_139, primals_140, primals_141, primals_142, primals_143, primals_144, primals_145, primals_146, primals_147, primals_148, primals_149, primals_150, primals_151, primals_152, primals_153, primals_154, primals_155, primals_156, primals_157, primals_158, primals_159, primals_160, primals_161, primals_162, primals_163, primals_164, primals_165, primals_166, primals_167, primals_168, primals_169, primals_170, primals_171, primals_172, primals_173, primals_174, primals_175, primals_176, primals_177, primals_178, primals_179, primals_180, primals_181, primals_182, primals_183, primals_184, primals_185, primals_186, primals_187, primals_188, primals_189, primals_190, primals_191, primals_192, primals_193, primals_194, primals_195, primals_196, primals_197, primals_198, primals_199, primals_200, primals_201, primals_202, primals_203, primals_204, primals_205, primals_206, primals_207, primals_208, primals_209, primals_210, primals_211, primals_212, primals_213, primals_214, primals_215, primals_216, primals_217, primals_218, primals_219, primals_220, primals_221, primals_222, primals_223, primals_224, primals_225, primals_226, primals_227, primals_228, primals_229, primals_230, primals_231, primals_232, primals_233, primals_234, primals_235, primals_236, primals_237, primals_238, primals_239, primals_240, primals_241, primals_242, primals_243, primals_244, primals_245, primals_246, primals_247, primals_248, primals_249, primals_250, primals_251, primals_252, primals_253, primals_254, 
primals_255, primals_256, primals_257, primals_258, primals_259, primals_260, primals_261, primals_262, primals_263, primals_264, primals_265, primals_266, primals_267, primals_268, primals_269, primals_270, primals_271, primals_272, primals_273, primals_274, primals_275, primals_276, primals_277, primals_278, primals_279, primals_280, primals_281, primals_282, primals_283, primals_284, primals_285, primals_286, primals_287, primals_288, primals_289, primals_290, primals_291, primals_292, primals_293, wait_tensor_1, mm, mm_2, getitem_80, getitem_81, getitem_86, getitem_87, reduce_scatter_tensor_1, mm_4, add_3, mm_7, mm_9, getitem_121, getitem_122, getitem_127, getitem_128, reduce_scatter_tensor_3, mm_11, add_7, mm_14, mm_16, getitem_162, getitem_163, getitem_168, getitem_169, reduce_scatter_tensor_5, mm_18, add_11, mm_21, mm_23, getitem_203, getitem_204, getitem_209, getitem_210, reduce_scatter_tensor_7, mm_25, add_15, mm_28, mm_30, getitem_244, getitem_245, getitem_250, getitem_251, reduce_scatter_tensor_9, mm_32, add_19, mm_35, mm_37, getitem_285, getitem_286, getitem_291, getitem_292, reduce_scatter_tensor_11, mm_39, add_23, mm_42, mm_44, getitem_326, getitem_327, getitem_332, getitem_333, reduce_scatter_tensor_13, mm_46, add_27, mm_49, mm_51, getitem_367, getitem_368, getitem_373, getitem_374, reduce_scatter_tensor_15, mm_53, add_31, mm_56, mm_58, getitem_408, getitem_409, getitem_414, getitem_415, reduce_scatter_tensor_17, mm_60, add_35, mm_63, mm_65, getitem_449, getitem_450, getitem_455, getitem_456, reduce_scatter_tensor_19, mm_67, add_39, mm_70, mm_72, getitem_490, getitem_491, getitem_496, getitem_497, reduce_scatter_tensor_21, mm_74, add_43, mm_77, mm_79, getitem_531, getitem_532, getitem_537, getitem_538, reduce_scatter_tensor_23, mm_81, add_47, mm_84, mm_86, getitem_572, getitem_573, getitem_578, getitem_579, reduce_scatter_tensor_25, mm_88, add_51, mm_91, mm_93, getitem_613, getitem_614, getitem_619, getitem_620, reduce_scatter_tensor_27, mm_95, 
add_55, mm_98, mm_100, getitem_654, getitem_655, getitem_660, getitem_661, reduce_scatter_tensor_29, mm_102, add_59, mm_105, mm_107, getitem_695, getitem_696, getitem_701, getitem_702, reduce_scatter_tensor_31, mm_109, add_63, mm_112, mm_114, getitem_736, getitem_737, getitem_742, getitem_743, reduce_scatter_tensor_33, mm_116, add_67, mm_119, mm_121, getitem_777, getitem_778, getitem_783, getitem_784, reduce_scatter_tensor_35, mm_123, add_71, mm_126, mm_128, getitem_818, getitem_819, getitem_824, getitem_825, reduce_scatter_tensor_37, mm_130, add_75, mm_133, mm_135, getitem_859, getitem_860, getitem_865, getitem_866, reduce_scatter_tensor_39, mm_137, add_79, mm_140, mm_142, getitem_900, getitem_901, getitem_906, getitem_907, reduce_scatter_tensor_41, mm_144, add_83, mm_147, mm_149, getitem_941, getitem_942, getitem_947, getitem_948, reduce_scatter_tensor_43, mm_151, add_87, mm_154, mm_156, getitem_982, getitem_983, getitem_988, getitem_989, reduce_scatter_tensor_45, mm_158, add_91, mm_161, mm_163, getitem_1023, getitem_1024, getitem_1029, getitem_1030, reduce_scatter_tensor_47, mm_165, add_95, mm_168, mm_170, getitem_1064, getitem_1065, getitem_1070, getitem_1071, reduce_scatter_tensor_49, mm_172, add_99, mm_175, mm_177, getitem_1105, getitem_1106, getitem_1111, getitem_1112, reduce_scatter_tensor_51, mm_179, add_103, mm_182, mm_184, getitem_1146, getitem_1147, getitem_1152, getitem_1153, reduce_scatter_tensor_53, mm_186, add_107, mm_189, mm_191, getitem_1187, getitem_1188, getitem_1193, getitem_1194, reduce_scatter_tensor_55, mm_193, add_111, mm_196, mm_198, getitem_1228, getitem_1229, getitem_1234, getitem_1235, reduce_scatter_tensor_57, mm_200, add_115, mm_203, mm_205, getitem_1269, getitem_1270, getitem_1275, getitem_1276, reduce_scatter_tensor_59, mm_207, add_119, mm_210, mm_212, getitem_1310, getitem_1311, getitem_1316, getitem_1317, reduce_scatter_tensor_61, mm_214, add_123, mm_217, mm_219, getitem_1351, getitem_1352, getitem_1357, getitem_1358, 
reduce_scatter_tensor_63, mm_221, reduce_scatter_tensor_64, rsqrt_64, view_2319) + +def load_args(reader): + buf0 = reader.storage(None, 131072, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf0, (2, 8192), dtype=torch.int64, is_leaf=True) # primals_1 + buf1 = reader.storage(None, 8208384, device=device(type='cuda', index=0)) + reader.tensor(buf1, (501, 4096), is_leaf=True) # primals_2 + buf2 = reader.storage(None, 4194304, device=device(type='cuda', index=0), dtype_hint=torch.complex64) + reader.tensor(buf2, (8192, 64), dtype=torch.complex64, is_leaf=True) # primals_3 + buf3 = reader.storage(None, 512, device=device(type='cuda', index=0)) + reader.tensor(buf3, (128,), is_leaf=True) # primals_4 + buf4 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf4, (16, 4096), is_leaf=True) # primals_5 + buf5 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf5, (4, 4096), is_leaf=True) # primals_6 + buf6 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf6, (4, 4096), is_leaf=True) # primals_7 + buf7 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf7, (128, 512), is_leaf=True) # primals_8 + buf8 = reader.storage(None, 512, device=device(type='cuda', index=0)) + reader.tensor(buf8, (128,), is_leaf=True) # primals_9 + buf9 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf9, (56, 4096), is_leaf=True) # primals_10 + buf10 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf10, (56, 4096), is_leaf=True) # primals_11 + buf11 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf11, (128, 1792), is_leaf=True) # primals_12 + buf12 = reader.storage(None, 512, device=device(type='cuda', index=0)) + reader.tensor(buf12, (128,), is_leaf=True) # primals_13 + buf13 = reader.storage(None, 262144, 
device=device(type='cuda', index=0)) + reader.tensor(buf13, (16, 4096), is_leaf=True) # primals_14 + buf14 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf14, (4, 4096), is_leaf=True) # primals_15 + buf15 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf15, (4, 4096), is_leaf=True) # primals_16 + buf16 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf16, (128, 512), is_leaf=True) # primals_17 + buf17 = reader.storage(None, 512, device=device(type='cuda', index=0)) + reader.tensor(buf17, (128,), is_leaf=True) # primals_18 + buf18 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf18, (56, 4096), is_leaf=True) # primals_19 + buf19 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf19, (56, 4096), is_leaf=True) # primals_20 + buf20 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf20, (128, 1792), is_leaf=True) # primals_21 + buf21 = reader.storage(None, 512, device=device(type='cuda', index=0)) + reader.tensor(buf21, (128,), is_leaf=True) # primals_22 + buf22 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf22, (16, 4096), is_leaf=True) # primals_23 + buf23 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf23, (4, 4096), is_leaf=True) # primals_24 + buf24 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf24, (4, 4096), is_leaf=True) # primals_25 + buf25 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf25, (128, 512), is_leaf=True) # primals_26 + buf26 = reader.storage(None, 512, device=device(type='cuda', index=0)) + reader.tensor(buf26, (128,), is_leaf=True) # primals_27 + buf27 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf27, (56, 4096), is_leaf=True) # 
primals_28 + buf28 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf28, (56, 4096), is_leaf=True) # primals_29 + buf29 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf29, (128, 1792), is_leaf=True) # primals_30 + buf30 = reader.storage(None, 512, device=device(type='cuda', index=0)) + reader.tensor(buf30, (128,), is_leaf=True) # primals_31 + buf31 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf31, (16, 4096), is_leaf=True) # primals_32 + buf32 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf32, (4, 4096), is_leaf=True) # primals_33 + buf33 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf33, (4, 4096), is_leaf=True) # primals_34 + buf34 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf34, (128, 512), is_leaf=True) # primals_35 + buf35 = reader.storage(None, 512, device=device(type='cuda', index=0)) + reader.tensor(buf35, (128,), is_leaf=True) # primals_36 + buf36 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf36, (56, 4096), is_leaf=True) # primals_37 + buf37 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf37, (56, 4096), is_leaf=True) # primals_38 + buf38 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf38, (128, 1792), is_leaf=True) # primals_39 + buf39 = reader.storage(None, 512, device=device(type='cuda', index=0)) + reader.tensor(buf39, (128,), is_leaf=True) # primals_40 + buf40 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf40, (16, 4096), is_leaf=True) # primals_41 + buf41 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf41, (4, 4096), is_leaf=True) # primals_42 + buf42 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + 
reader.tensor(buf42, (4, 4096), is_leaf=True) # primals_43 + buf43 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf43, (128, 512), is_leaf=True) # primals_44 + buf44 = reader.storage(None, 512, device=device(type='cuda', index=0)) + reader.tensor(buf44, (128,), is_leaf=True) # primals_45 + buf45 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf45, (56, 4096), is_leaf=True) # primals_46 + buf46 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf46, (56, 4096), is_leaf=True) # primals_47 + buf47 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf47, (128, 1792), is_leaf=True) # primals_48 + buf48 = reader.storage(None, 512, device=device(type='cuda', index=0)) + reader.tensor(buf48, (128,), is_leaf=True) # primals_49 + buf49 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf49, (16, 4096), is_leaf=True) # primals_50 + buf50 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf50, (4, 4096), is_leaf=True) # primals_51 + buf51 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf51, (4, 4096), is_leaf=True) # primals_52 + buf52 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf52, (128, 512), is_leaf=True) # primals_53 + buf53 = reader.storage(None, 512, device=device(type='cuda', index=0)) + reader.tensor(buf53, (128,), is_leaf=True) # primals_54 + buf54 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf54, (56, 4096), is_leaf=True) # primals_55 + buf55 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf55, (56, 4096), is_leaf=True) # primals_56 + buf56 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf56, (128, 1792), is_leaf=True) # primals_57 + buf57 = reader.storage(None, 
512, device=device(type='cuda', index=0)) + reader.tensor(buf57, (128,), is_leaf=True) # primals_58 + buf58 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf58, (16, 4096), is_leaf=True) # primals_59 + buf59 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf59, (4, 4096), is_leaf=True) # primals_60 + buf60 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf60, (4, 4096), is_leaf=True) # primals_61 + buf61 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf61, (128, 512), is_leaf=True) # primals_62 + buf62 = reader.storage(None, 512, device=device(type='cuda', index=0)) + reader.tensor(buf62, (128,), is_leaf=True) # primals_63 + buf63 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf63, (56, 4096), is_leaf=True) # primals_64 + buf64 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf64, (56, 4096), is_leaf=True) # primals_65 + buf65 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf65, (128, 1792), is_leaf=True) # primals_66 + buf66 = reader.storage(None, 512, device=device(type='cuda', index=0)) + reader.tensor(buf66, (128,), is_leaf=True) # primals_67 + buf67 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf67, (16, 4096), is_leaf=True) # primals_68 + buf68 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf68, (4, 4096), is_leaf=True) # primals_69 + buf69 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf69, (4, 4096), is_leaf=True) # primals_70 + buf70 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf70, (128, 512), is_leaf=True) # primals_71 + buf71 = reader.storage(None, 512, device=device(type='cuda', index=0)) + reader.tensor(buf71, (128,), is_leaf=True) # 
primals_72 + buf72 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf72, (56, 4096), is_leaf=True) # primals_73 + buf73 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf73, (56, 4096), is_leaf=True) # primals_74 + buf74 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf74, (128, 1792), is_leaf=True) # primals_75 + buf75 = reader.storage(None, 512, device=device(type='cuda', index=0)) + reader.tensor(buf75, (128,), is_leaf=True) # primals_76 + buf76 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf76, (16, 4096), is_leaf=True) # primals_77 + buf77 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf77, (4, 4096), is_leaf=True) # primals_78 + buf78 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf78, (4, 4096), is_leaf=True) # primals_79 + buf79 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf79, (128, 512), is_leaf=True) # primals_80 + buf80 = reader.storage(None, 512, device=device(type='cuda', index=0)) + reader.tensor(buf80, (128,), is_leaf=True) # primals_81 + buf81 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf81, (56, 4096), is_leaf=True) # primals_82 + buf82 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf82, (56, 4096), is_leaf=True) # primals_83 + buf83 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf83, (128, 1792), is_leaf=True) # primals_84 + buf84 = reader.storage(None, 512, device=device(type='cuda', index=0)) + reader.tensor(buf84, (128,), is_leaf=True) # primals_85 + buf85 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf85, (16, 4096), is_leaf=True) # primals_86 + buf86 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + 
reader.tensor(buf86, (4, 4096), is_leaf=True) # primals_87 + buf87 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf87, (4, 4096), is_leaf=True) # primals_88 + buf88 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf88, (128, 512), is_leaf=True) # primals_89 + buf89 = reader.storage(None, 512, device=device(type='cuda', index=0)) + reader.tensor(buf89, (128,), is_leaf=True) # primals_90 + buf90 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf90, (56, 4096), is_leaf=True) # primals_91 + buf91 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf91, (56, 4096), is_leaf=True) # primals_92 + buf92 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf92, (128, 1792), is_leaf=True) # primals_93 + buf93 = reader.storage(None, 512, device=device(type='cuda', index=0)) + reader.tensor(buf93, (128,), is_leaf=True) # primals_94 + buf94 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf94, (16, 4096), is_leaf=True) # primals_95 + buf95 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf95, (4, 4096), is_leaf=True) # primals_96 + buf96 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf96, (4, 4096), is_leaf=True) # primals_97 + buf97 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf97, (128, 512), is_leaf=True) # primals_98 + buf98 = reader.storage(None, 512, device=device(type='cuda', index=0)) + reader.tensor(buf98, (128,), is_leaf=True) # primals_99 + buf99 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf99, (56, 4096), is_leaf=True) # primals_100 + buf100 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf100, (56, 4096), is_leaf=True) # primals_101 + buf101 = 
reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf101, (128, 1792), is_leaf=True) # primals_102 + buf102 = reader.storage(None, 512, device=device(type='cuda', index=0)) + reader.tensor(buf102, (128,), is_leaf=True) # primals_103 + buf103 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf103, (16, 4096), is_leaf=True) # primals_104 + buf104 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf104, (4, 4096), is_leaf=True) # primals_105 + buf105 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf105, (4, 4096), is_leaf=True) # primals_106 + buf106 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf106, (128, 512), is_leaf=True) # primals_107 + buf107 = reader.storage(None, 512, device=device(type='cuda', index=0)) + reader.tensor(buf107, (128,), is_leaf=True) # primals_108 + buf108 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf108, (56, 4096), is_leaf=True) # primals_109 + buf109 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf109, (56, 4096), is_leaf=True) # primals_110 + buf110 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf110, (128, 1792), is_leaf=True) # primals_111 + buf111 = reader.storage(None, 512, device=device(type='cuda', index=0)) + reader.tensor(buf111, (128,), is_leaf=True) # primals_112 + buf112 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf112, (16, 4096), is_leaf=True) # primals_113 + buf113 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf113, (4, 4096), is_leaf=True) # primals_114 + buf114 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf114, (4, 4096), is_leaf=True) # primals_115 + buf115 = reader.storage(None, 262144, 
device=device(type='cuda', index=0)) + reader.tensor(buf115, (128, 512), is_leaf=True) # primals_116 + buf116 = reader.storage(None, 512, device=device(type='cuda', index=0)) + reader.tensor(buf116, (128,), is_leaf=True) # primals_117 + buf117 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf117, (56, 4096), is_leaf=True) # primals_118 + buf118 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf118, (56, 4096), is_leaf=True) # primals_119 + buf119 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf119, (128, 1792), is_leaf=True) # primals_120 + buf120 = reader.storage(None, 512, device=device(type='cuda', index=0)) + reader.tensor(buf120, (128,), is_leaf=True) # primals_121 + buf121 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf121, (16, 4096), is_leaf=True) # primals_122 + buf122 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf122, (4, 4096), is_leaf=True) # primals_123 + buf123 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf123, (4, 4096), is_leaf=True) # primals_124 + buf124 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf124, (128, 512), is_leaf=True) # primals_125 + buf125 = reader.storage(None, 512, device=device(type='cuda', index=0)) + reader.tensor(buf125, (128,), is_leaf=True) # primals_126 + buf126 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf126, (56, 4096), is_leaf=True) # primals_127 + buf127 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf127, (56, 4096), is_leaf=True) # primals_128 + buf128 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf128, (128, 1792), is_leaf=True) # primals_129 + buf129 = reader.storage(None, 512, device=device(type='cuda', index=0)) + 
reader.tensor(buf129, (128,), is_leaf=True) # primals_130 + buf130 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf130, (16, 4096), is_leaf=True) # primals_131 + buf131 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf131, (4, 4096), is_leaf=True) # primals_132 + buf132 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf132, (4, 4096), is_leaf=True) # primals_133 + buf133 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf133, (128, 512), is_leaf=True) # primals_134 + buf134 = reader.storage(None, 512, device=device(type='cuda', index=0)) + reader.tensor(buf134, (128,), is_leaf=True) # primals_135 + buf135 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf135, (56, 4096), is_leaf=True) # primals_136 + buf136 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf136, (56, 4096), is_leaf=True) # primals_137 + buf137 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf137, (128, 1792), is_leaf=True) # primals_138 + buf138 = reader.storage(None, 512, device=device(type='cuda', index=0)) + reader.tensor(buf138, (128,), is_leaf=True) # primals_139 + buf139 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf139, (16, 4096), is_leaf=True) # primals_140 + buf140 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf140, (4, 4096), is_leaf=True) # primals_141 + buf141 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf141, (4, 4096), is_leaf=True) # primals_142 + buf142 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf142, (128, 512), is_leaf=True) # primals_143 + buf143 = reader.storage(None, 512, device=device(type='cuda', index=0)) + reader.tensor(buf143, (128,), is_leaf=True) # 
primals_144 + buf144 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf144, (56, 4096), is_leaf=True) # primals_145 + buf145 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf145, (56, 4096), is_leaf=True) # primals_146 + buf146 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf146, (128, 1792), is_leaf=True) # primals_147 + buf147 = reader.storage(None, 512, device=device(type='cuda', index=0)) + reader.tensor(buf147, (128,), is_leaf=True) # primals_148 + buf148 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf148, (16, 4096), is_leaf=True) # primals_149 + buf149 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf149, (4, 4096), is_leaf=True) # primals_150 + buf150 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf150, (4, 4096), is_leaf=True) # primals_151 + buf151 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf151, (128, 512), is_leaf=True) # primals_152 + buf152 = reader.storage(None, 512, device=device(type='cuda', index=0)) + reader.tensor(buf152, (128,), is_leaf=True) # primals_153 + buf153 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf153, (56, 4096), is_leaf=True) # primals_154 + buf154 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf154, (56, 4096), is_leaf=True) # primals_155 + buf155 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf155, (128, 1792), is_leaf=True) # primals_156 + buf156 = reader.storage(None, 512, device=device(type='cuda', index=0)) + reader.tensor(buf156, (128,), is_leaf=True) # primals_157 + buf157 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf157, (16, 4096), is_leaf=True) # primals_158 + buf158 = reader.storage(None, 
65536, device=device(type='cuda', index=0)) + reader.tensor(buf158, (4, 4096), is_leaf=True) # primals_159 + buf159 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf159, (4, 4096), is_leaf=True) # primals_160 + buf160 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf160, (128, 512), is_leaf=True) # primals_161 + buf161 = reader.storage(None, 512, device=device(type='cuda', index=0)) + reader.tensor(buf161, (128,), is_leaf=True) # primals_162 + buf162 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf162, (56, 4096), is_leaf=True) # primals_163 + buf163 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf163, (56, 4096), is_leaf=True) # primals_164 + buf164 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf164, (128, 1792), is_leaf=True) # primals_165 + buf165 = reader.storage(None, 512, device=device(type='cuda', index=0)) + reader.tensor(buf165, (128,), is_leaf=True) # primals_166 + buf166 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf166, (16, 4096), is_leaf=True) # primals_167 + buf167 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf167, (4, 4096), is_leaf=True) # primals_168 + buf168 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf168, (4, 4096), is_leaf=True) # primals_169 + buf169 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf169, (128, 512), is_leaf=True) # primals_170 + buf170 = reader.storage(None, 512, device=device(type='cuda', index=0)) + reader.tensor(buf170, (128,), is_leaf=True) # primals_171 + buf171 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf171, (56, 4096), is_leaf=True) # primals_172 + buf172 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + 
reader.tensor(buf172, (56, 4096), is_leaf=True) # primals_173 + buf173 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf173, (128, 1792), is_leaf=True) # primals_174 + buf174 = reader.storage(None, 512, device=device(type='cuda', index=0)) + reader.tensor(buf174, (128,), is_leaf=True) # primals_175 + buf175 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf175, (16, 4096), is_leaf=True) # primals_176 + buf176 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf176, (4, 4096), is_leaf=True) # primals_177 + buf177 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf177, (4, 4096), is_leaf=True) # primals_178 + buf178 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf178, (128, 512), is_leaf=True) # primals_179 + buf179 = reader.storage(None, 512, device=device(type='cuda', index=0)) + reader.tensor(buf179, (128,), is_leaf=True) # primals_180 + buf180 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf180, (56, 4096), is_leaf=True) # primals_181 + buf181 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf181, (56, 4096), is_leaf=True) # primals_182 + buf182 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf182, (128, 1792), is_leaf=True) # primals_183 + buf183 = reader.storage(None, 512, device=device(type='cuda', index=0)) + reader.tensor(buf183, (128,), is_leaf=True) # primals_184 + buf184 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf184, (16, 4096), is_leaf=True) # primals_185 + buf185 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf185, (4, 4096), is_leaf=True) # primals_186 + buf186 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf186, (4, 4096), is_leaf=True) # 
primals_187 + buf187 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf187, (128, 512), is_leaf=True) # primals_188 + buf188 = reader.storage(None, 512, device=device(type='cuda', index=0)) + reader.tensor(buf188, (128,), is_leaf=True) # primals_189 + buf189 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf189, (56, 4096), is_leaf=True) # primals_190 + buf190 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf190, (56, 4096), is_leaf=True) # primals_191 + buf191 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf191, (128, 1792), is_leaf=True) # primals_192 + buf192 = reader.storage(None, 512, device=device(type='cuda', index=0)) + reader.tensor(buf192, (128,), is_leaf=True) # primals_193 + buf193 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf193, (16, 4096), is_leaf=True) # primals_194 + buf194 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf194, (4, 4096), is_leaf=True) # primals_195 + buf195 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf195, (4, 4096), is_leaf=True) # primals_196 + buf196 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf196, (128, 512), is_leaf=True) # primals_197 + buf197 = reader.storage(None, 512, device=device(type='cuda', index=0)) + reader.tensor(buf197, (128,), is_leaf=True) # primals_198 + buf198 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf198, (56, 4096), is_leaf=True) # primals_199 + buf199 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf199, (56, 4096), is_leaf=True) # primals_200 + buf200 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf200, (128, 1792), is_leaf=True) # primals_201 + buf201 = reader.storage(None, 512, 
device=device(type='cuda', index=0)) + reader.tensor(buf201, (128,), is_leaf=True) # primals_202 + buf202 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf202, (16, 4096), is_leaf=True) # primals_203 + buf203 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf203, (4, 4096), is_leaf=True) # primals_204 + buf204 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf204, (4, 4096), is_leaf=True) # primals_205 + buf205 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf205, (128, 512), is_leaf=True) # primals_206 + buf206 = reader.storage(None, 512, device=device(type='cuda', index=0)) + reader.tensor(buf206, (128,), is_leaf=True) # primals_207 + buf207 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf207, (56, 4096), is_leaf=True) # primals_208 + buf208 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf208, (56, 4096), is_leaf=True) # primals_209 + buf209 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf209, (128, 1792), is_leaf=True) # primals_210 + buf210 = reader.storage(None, 512, device=device(type='cuda', index=0)) + reader.tensor(buf210, (128,), is_leaf=True) # primals_211 + buf211 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf211, (16, 4096), is_leaf=True) # primals_212 + buf212 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf212, (4, 4096), is_leaf=True) # primals_213 + buf213 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf213, (4, 4096), is_leaf=True) # primals_214 + buf214 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf214, (128, 512), is_leaf=True) # primals_215 + buf215 = reader.storage(None, 512, device=device(type='cuda', index=0)) + 
reader.tensor(buf215, (128,), is_leaf=True) # primals_216 + buf216 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf216, (56, 4096), is_leaf=True) # primals_217 + buf217 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf217, (56, 4096), is_leaf=True) # primals_218 + buf218 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf218, (128, 1792), is_leaf=True) # primals_219 + buf219 = reader.storage(None, 512, device=device(type='cuda', index=0)) + reader.tensor(buf219, (128,), is_leaf=True) # primals_220 + buf220 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf220, (16, 4096), is_leaf=True) # primals_221 + buf221 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf221, (4, 4096), is_leaf=True) # primals_222 + buf222 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf222, (4, 4096), is_leaf=True) # primals_223 + buf223 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf223, (128, 512), is_leaf=True) # primals_224 + buf224 = reader.storage(None, 512, device=device(type='cuda', index=0)) + reader.tensor(buf224, (128,), is_leaf=True) # primals_225 + buf225 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf225, (56, 4096), is_leaf=True) # primals_226 + buf226 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf226, (56, 4096), is_leaf=True) # primals_227 + buf227 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf227, (128, 1792), is_leaf=True) # primals_228 + buf228 = reader.storage(None, 512, device=device(type='cuda', index=0)) + reader.tensor(buf228, (128,), is_leaf=True) # primals_229 + buf229 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf229, (16, 4096), is_leaf=True) # 
primals_230 + buf230 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf230, (4, 4096), is_leaf=True) # primals_231 + buf231 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf231, (4, 4096), is_leaf=True) # primals_232 + buf232 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf232, (128, 512), is_leaf=True) # primals_233 + buf233 = reader.storage(None, 512, device=device(type='cuda', index=0)) + reader.tensor(buf233, (128,), is_leaf=True) # primals_234 + buf234 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf234, (56, 4096), is_leaf=True) # primals_235 + buf235 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf235, (56, 4096), is_leaf=True) # primals_236 + buf236 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf236, (128, 1792), is_leaf=True) # primals_237 + buf237 = reader.storage(None, 512, device=device(type='cuda', index=0)) + reader.tensor(buf237, (128,), is_leaf=True) # primals_238 + buf238 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf238, (16, 4096), is_leaf=True) # primals_239 + buf239 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf239, (4, 4096), is_leaf=True) # primals_240 + buf240 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf240, (4, 4096), is_leaf=True) # primals_241 + buf241 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf241, (128, 512), is_leaf=True) # primals_242 + buf242 = reader.storage(None, 512, device=device(type='cuda', index=0)) + reader.tensor(buf242, (128,), is_leaf=True) # primals_243 + buf243 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf243, (56, 4096), is_leaf=True) # primals_244 + buf244 = reader.storage(None, 917504, 
device=device(type='cuda', index=0)) + reader.tensor(buf244, (56, 4096), is_leaf=True) # primals_245 + buf245 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf245, (128, 1792), is_leaf=True) # primals_246 + buf246 = reader.storage(None, 512, device=device(type='cuda', index=0)) + reader.tensor(buf246, (128,), is_leaf=True) # primals_247 + buf247 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf247, (16, 4096), is_leaf=True) # primals_248 + buf248 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf248, (4, 4096), is_leaf=True) # primals_249 + buf249 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf249, (4, 4096), is_leaf=True) # primals_250 + buf250 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf250, (128, 512), is_leaf=True) # primals_251 + buf251 = reader.storage(None, 512, device=device(type='cuda', index=0)) + reader.tensor(buf251, (128,), is_leaf=True) # primals_252 + buf252 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf252, (56, 4096), is_leaf=True) # primals_253 + buf253 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf253, (56, 4096), is_leaf=True) # primals_254 + buf254 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf254, (128, 1792), is_leaf=True) # primals_255 + buf255 = reader.storage(None, 512, device=device(type='cuda', index=0)) + reader.tensor(buf255, (128,), is_leaf=True) # primals_256 + buf256 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf256, (16, 4096), is_leaf=True) # primals_257 + buf257 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf257, (4, 4096), is_leaf=True) # primals_258 + buf258 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + 
reader.tensor(buf258, (4, 4096), is_leaf=True) # primals_259 + buf259 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf259, (128, 512), is_leaf=True) # primals_260 + buf260 = reader.storage(None, 512, device=device(type='cuda', index=0)) + reader.tensor(buf260, (128,), is_leaf=True) # primals_261 + buf261 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf261, (56, 4096), is_leaf=True) # primals_262 + buf262 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf262, (56, 4096), is_leaf=True) # primals_263 + buf263 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf263, (128, 1792), is_leaf=True) # primals_264 + buf264 = reader.storage(None, 512, device=device(type='cuda', index=0)) + reader.tensor(buf264, (128,), is_leaf=True) # primals_265 + buf265 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf265, (16, 4096), is_leaf=True) # primals_266 + buf266 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf266, (4, 4096), is_leaf=True) # primals_267 + buf267 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf267, (4, 4096), is_leaf=True) # primals_268 + buf268 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf268, (128, 512), is_leaf=True) # primals_269 + buf269 = reader.storage(None, 512, device=device(type='cuda', index=0)) + reader.tensor(buf269, (128,), is_leaf=True) # primals_270 + buf270 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf270, (56, 4096), is_leaf=True) # primals_271 + buf271 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf271, (56, 4096), is_leaf=True) # primals_272 + buf272 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf272, (128, 1792), is_leaf=True) # 
primals_273 + buf273 = reader.storage(None, 512, device=device(type='cuda', index=0)) + reader.tensor(buf273, (128,), is_leaf=True) # primals_274 + buf274 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf274, (16, 4096), is_leaf=True) # primals_275 + buf275 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf275, (4, 4096), is_leaf=True) # primals_276 + buf276 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf276, (4, 4096), is_leaf=True) # primals_277 + buf277 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf277, (128, 512), is_leaf=True) # primals_278 + buf278 = reader.storage(None, 512, device=device(type='cuda', index=0)) + reader.tensor(buf278, (128,), is_leaf=True) # primals_279 + buf279 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf279, (56, 4096), is_leaf=True) # primals_280 + buf280 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf280, (56, 4096), is_leaf=True) # primals_281 + buf281 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf281, (128, 1792), is_leaf=True) # primals_282 + buf282 = reader.storage(None, 512, device=device(type='cuda', index=0)) + reader.tensor(buf282, (128,), is_leaf=True) # primals_283 + buf283 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf283, (16, 4096), is_leaf=True) # primals_284 + buf284 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf284, (4, 4096), is_leaf=True) # primals_285 + buf285 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf285, (4, 4096), is_leaf=True) # primals_286 + buf286 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf286, (128, 512), is_leaf=True) # primals_287 + buf287 = reader.storage(None, 512, 
device=device(type='cuda', index=0)) + reader.tensor(buf287, (128,), is_leaf=True) # primals_288 + buf288 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf288, (56, 4096), is_leaf=True) # primals_289 + buf289 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf289, (56, 4096), is_leaf=True) # primals_290 + buf290 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf290, (128, 1792), is_leaf=True) # primals_291 + buf291 = reader.storage(None, 512, device=device(type='cuda', index=0)) + reader.tensor(buf291, (128,), is_leaf=True) # primals_292 + buf292 = reader.storage(None, 8208384, device=device(type='cuda', index=0)) + reader.tensor(buf292, (501, 4096), is_leaf=True) # primals_293 + +load_args._version = 0 + +def get_mesh_sizes(): + return 32, 8 + +def get_colls_estimations_file(): + return "colls32_8.table" + +def get_pg_names(): + return "0", "1" + diff --git a/autoparallel/tools/overlap_simulator/repro_llama3_8b_fw_64_1d_32layers.py b/autoparallel/tools/overlap_simulator/repro_llama3_8b_fw_64_1d_32layers.py new file mode 100644 index 00000000..562c25ca --- /dev/null +++ b/autoparallel/tools/overlap_simulator/repro_llama3_8b_fw_64_1d_32layers.py @@ -0,0 +1,4153 @@ +import torch +from torch.nn import * +from torch import tensor, device + + +class Repro(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + + def forward(self, primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8, primals_9, primals_10, primals_11, primals_12, primals_13, primals_14, primals_15, primals_16, primals_17, primals_18, primals_19, primals_20, primals_21, primals_22, primals_23, primals_24, primals_25, primals_26, primals_27, primals_28, primals_29, primals_30, primals_31, primals_32, primals_33, primals_34, primals_35, primals_36, primals_37, primals_38, primals_39, primals_40, primals_41, primals_42, primals_43, primals_44, 
primals_45, primals_46, primals_47, primals_48, primals_49, primals_50, primals_51, primals_52, primals_53, primals_54, primals_55, primals_56, primals_57, primals_58, primals_59, primals_60, primals_61, primals_62, primals_63, primals_64, primals_65, primals_66, primals_67, primals_68, primals_69, primals_70, primals_71, primals_72, primals_73, primals_74, primals_75, primals_76, primals_77, primals_78, primals_79, primals_80, primals_81, primals_82, primals_83, primals_84, primals_85, primals_86, primals_87, primals_88, primals_89, primals_90, primals_91, primals_92, primals_93, primals_94, primals_95, primals_96, primals_97, primals_98, primals_99, primals_100, primals_101, primals_102, primals_103, primals_104, primals_105, primals_106, primals_107, primals_108, primals_109, primals_110, primals_111, primals_112, primals_113, primals_114, primals_115, primals_116, primals_117, primals_118, primals_119, primals_120, primals_121, primals_122, primals_123, primals_124, primals_125, primals_126, primals_127, primals_128, primals_129, primals_130, primals_131, primals_132, primals_133, primals_134, primals_135, primals_136, primals_137, primals_138, primals_139, primals_140, primals_141, primals_142, primals_143, primals_144, primals_145, primals_146, primals_147, primals_148, primals_149, primals_150, primals_151, primals_152, primals_153, primals_154, primals_155, primals_156, primals_157, primals_158, primals_159, primals_160, primals_161, primals_162, primals_163, primals_164, primals_165, primals_166, primals_167, primals_168, primals_169, primals_170, primals_171, primals_172, primals_173, primals_174, primals_175, primals_176, primals_177, primals_178, primals_179, primals_180, primals_181, primals_182, primals_183, primals_184, primals_185, primals_186, primals_187, primals_188, primals_189, primals_190, primals_191, primals_192, primals_193, primals_194, primals_195, primals_196, primals_197, primals_198, primals_199, primals_200, primals_201, primals_202, 
primals_203, primals_204, primals_205, primals_206, primals_207, primals_208, primals_209, primals_210, primals_211, primals_212, primals_213, primals_214, primals_215, primals_216, primals_217, primals_218, primals_219, primals_220, primals_221, primals_222, primals_223, primals_224, primals_225, primals_226, primals_227, primals_228, primals_229, primals_230, primals_231, primals_232, primals_233, primals_234, primals_235, primals_236, primals_237, primals_238, primals_239, primals_240, primals_241, primals_242, primals_243, primals_244, primals_245, primals_246, primals_247, primals_248, primals_249, primals_250, primals_251, primals_252, primals_253, primals_254, primals_255, primals_256, primals_257, primals_258, primals_259, primals_260, primals_261, primals_262, primals_263, primals_264, primals_265, primals_266, primals_267, primals_268, primals_269, primals_270, primals_271, primals_272, primals_273, primals_274, primals_275, primals_276, primals_277, primals_278, primals_279, primals_280, primals_281, primals_282, primals_283, primals_284, primals_285, primals_286, primals_287, primals_288, primals_289, primals_290, primals_291, primals_292, primals_293): + convert_element_type = torch.ops.prims.convert_element_type.default(primals_1, torch.bfloat16) + all_gather_into_tensor = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type, 64, '0'); convert_element_type = None + wait_tensor = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor); all_gather_into_tensor = None + embedding = torch.ops.aten.embedding.default(wait_tensor, primals_2); wait_tensor = None + convert_element_type_1 = torch.ops.prims.convert_element_type.default(primals_4, torch.bfloat16) + all_gather_into_tensor_1 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1, 64, '0'); convert_element_type_1 = None + wait_tensor_1 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_1); 
all_gather_into_tensor_1 = None + convert_element_type_2 = torch.ops.prims.convert_element_type.default(embedding, torch.float32) + pow_1 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_2, 2) + mean = torch.ops.aten.mean.dim(pow_1, [2], True); pow_1 = None + add = torch.ops.aten.add.Scalar(mean, 1e-05); mean = None + rsqrt = torch.ops.aten.rsqrt.default(add); add = None + mul = torch.ops.aten.mul.Tensor(convert_element_type_2, rsqrt); convert_element_type_2 = rsqrt = None + mul_1 = torch.ops.aten.mul.Tensor(mul, wait_tensor_1); mul = wait_tensor_1 = None + convert_element_type_3 = torch.ops.prims.convert_element_type.default(mul_1, torch.bfloat16); mul_1 = None + convert_element_type_4 = torch.ops.prims.convert_element_type.default(primals_5, torch.bfloat16) + all_gather_into_tensor_2 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_4, 64, '0'); convert_element_type_4 = None + wait_tensor_2 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_2); all_gather_into_tensor_2 = None + permute = torch.ops.aten.permute.default(wait_tensor_2, [1, 0]); wait_tensor_2 = None + view_3 = torch.ops.aten.view.default(convert_element_type_3, [16384, 4096]); convert_element_type_3 = None + mm = torch.ops.aten.mm.default(view_3, permute); permute = None + view_4 = torch.ops.aten.view.default(mm, [2, 8192, 4096]) + convert_element_type_7 = torch.ops.prims.convert_element_type.default(primals_6, torch.bfloat16) + all_gather_into_tensor_3 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_7, 64, '0'); convert_element_type_7 = None + wait_tensor_3 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_3); all_gather_into_tensor_3 = None + permute_1 = torch.ops.aten.permute.default(wait_tensor_3, [1, 0]); wait_tensor_3 = None + mm_1 = torch.ops.aten.mm.default(view_3, permute_1); permute_1 = None + view_7 = torch.ops.aten.view.default(mm_1, [2, 8192, 1024]); mm_1 = None + 
convert_element_type_10 = torch.ops.prims.convert_element_type.default(primals_7, torch.bfloat16) + all_gather_into_tensor_4 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_10, 64, '0'); convert_element_type_10 = None + wait_tensor_4 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_4); all_gather_into_tensor_4 = None + permute_2 = torch.ops.aten.permute.default(wait_tensor_4, [1, 0]); wait_tensor_4 = None + mm_2 = torch.ops.aten.mm.default(view_3, permute_2); view_3 = permute_2 = None + view_10 = torch.ops.aten.view.default(mm_2, [2, 8192, 1024]) + view_11 = torch.ops.aten.view.default(view_4, [2, 8192, -1, 128]); view_4 = None + view_12 = torch.ops.aten.view.default(view_7, [2, 8192, -1, 128]); view_7 = None + view_13 = torch.ops.aten.view.default(view_10, [2, 8192, -1, 128]); view_10 = None + convert_element_type_13 = torch.ops.prims.convert_element_type.default(view_11, torch.float32); view_11 = None + view_14 = torch.ops.aten.view.default(convert_element_type_13, [2, 8192, 32, -1, 2]); convert_element_type_13 = None + view_as_complex = torch.ops.aten.view_as_complex.default(view_14); view_14 = None + convert_element_type_14 = torch.ops.prims.convert_element_type.default(view_12, torch.float32); view_12 = None + view_15 = torch.ops.aten.view.default(convert_element_type_14, [2, 8192, 8, -1, 2]); convert_element_type_14 = None + view_as_complex_1 = torch.ops.aten.view_as_complex.default(view_15); view_15 = None + view_16 = torch.ops.aten.view.default(primals_3, [1, 8192, 1, 64]) + mul_2 = torch.ops.aten.mul.Tensor(view_as_complex, view_16); view_as_complex = None + view_as_real = torch.ops.aten.view_as_real.default(mul_2); mul_2 = None + view_17 = torch.ops.aten.view.default(view_as_real, [2, 8192, 32, 128]); view_as_real = None + mul_3 = torch.ops.aten.mul.Tensor(view_as_complex_1, view_16); view_as_complex_1 = None + view_as_real_1 = torch.ops.aten.view_as_real.default(mul_3); mul_3 = None + view_18 = 
torch.ops.aten.view.default(view_as_real_1, [2, 8192, 8, 128]); view_as_real_1 = None + convert_element_type_15 = torch.ops.prims.convert_element_type.default(view_17, torch.bfloat16); view_17 = None + convert_element_type_16 = torch.ops.prims.convert_element_type.default(view_18, torch.bfloat16); view_18 = None + unsqueeze = torch.ops.aten.unsqueeze.default(convert_element_type_16, 3); convert_element_type_16 = None + expand = torch.ops.aten.expand.default(unsqueeze, [2, 8192, 8, 4, 128]); unsqueeze = None + clone = torch.ops.aten.clone.default(expand, memory_format = torch.contiguous_format); expand = None + view_19 = torch.ops.aten.view.default(clone, [2, 8192, 32, 128]); clone = None + unsqueeze_1 = torch.ops.aten.unsqueeze.default(view_13, 3); view_13 = None + expand_1 = torch.ops.aten.expand.default(unsqueeze_1, [2, 8192, 8, 4, 128]); unsqueeze_1 = None + clone_1 = torch.ops.aten.clone.default(expand_1, memory_format = torch.contiguous_format); expand_1 = None + view_20 = torch.ops.aten.view.default(clone_1, [2, 8192, 32, 128]); clone_1 = None + permute_3 = torch.ops.aten.permute.default(convert_element_type_15, [0, 2, 1, 3]); convert_element_type_15 = None + permute_4 = torch.ops.aten.permute.default(view_19, [0, 2, 1, 3]); view_19 = None + permute_5 = torch.ops.aten.permute.default(view_20, [0, 2, 1, 3]); view_20 = None + _scaled_dot_product_cudnn_attention = torch.ops.aten._scaled_dot_product_cudnn_attention.default(permute_3, permute_4, permute_5, None, True, 0.0, True); permute_3 = permute_4 = permute_5 = None + getitem = _scaled_dot_product_cudnn_attention[0] + getitem_1 = _scaled_dot_product_cudnn_attention[1] + getitem_6 = _scaled_dot_product_cudnn_attention[6] + getitem_7 = _scaled_dot_product_cudnn_attention[7]; _scaled_dot_product_cudnn_attention = None + permute_6 = torch.ops.aten.permute.default(getitem, [0, 2, 1, 3]) + view_21 = torch.ops.aten.view.default(permute_6, [2, 8192, -1]); permute_6 = None + convert_element_type_17 = 
torch.ops.prims.convert_element_type.default(primals_8, torch.bfloat16) + all_gather_into_tensor_5 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_17, 64, '0'); convert_element_type_17 = None + wait_tensor_5 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_5); all_gather_into_tensor_5 = None + permute_7 = torch.ops.aten.permute.default(wait_tensor_5, [1, 0]); wait_tensor_5 = None + view_23 = torch.ops.aten.view.default(view_21, [16384, 4096]); view_21 = None + mm_3 = torch.ops.aten.mm.default(view_23, permute_7); view_23 = permute_7 = None + view_24 = torch.ops.aten.view.default(mm_3, [2, 8192, 4096]); mm_3 = None + add_1 = torch.ops.aten.add.Tensor(embedding, view_24); view_24 = None + convert_element_type_20 = torch.ops.prims.convert_element_type.default(primals_9, torch.bfloat16) + all_gather_into_tensor_6 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_20, 64, '0'); convert_element_type_20 = None + wait_tensor_6 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_6); all_gather_into_tensor_6 = None + convert_element_type_21 = torch.ops.prims.convert_element_type.default(add_1, torch.float32) + pow_2 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_21, 2) + mean_1 = torch.ops.aten.mean.dim(pow_2, [2], True); pow_2 = None + add_2 = torch.ops.aten.add.Scalar(mean_1, 1e-05); mean_1 = None + rsqrt_1 = torch.ops.aten.rsqrt.default(add_2); add_2 = None + mul_4 = torch.ops.aten.mul.Tensor(convert_element_type_21, rsqrt_1); convert_element_type_21 = rsqrt_1 = None + mul_5 = torch.ops.aten.mul.Tensor(mul_4, wait_tensor_6); mul_4 = wait_tensor_6 = None + convert_element_type_22 = torch.ops.prims.convert_element_type.default(mul_5, torch.bfloat16); mul_5 = None + convert_element_type_23 = torch.ops.prims.convert_element_type.default(primals_10, torch.bfloat16) + all_gather_into_tensor_7 = 
torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_23, 64, '0'); convert_element_type_23 = None + wait_tensor_7 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_7); all_gather_into_tensor_7 = None + permute_8 = torch.ops.aten.permute.default(wait_tensor_7, [1, 0]); wait_tensor_7 = None + view_27 = torch.ops.aten.view.default(convert_element_type_22, [16384, 4096]); convert_element_type_22 = None + mm_4 = torch.ops.aten.mm.default(view_27, permute_8); permute_8 = None + view_28 = torch.ops.aten.view.default(mm_4, [2, 8192, 14336]) + convert_element_type_26 = torch.ops.prims.convert_element_type.default(view_28, torch.float32); view_28 = None + sigmoid = torch.ops.aten.sigmoid.default(convert_element_type_26) + mul_6 = torch.ops.aten.mul.Tensor(convert_element_type_26, sigmoid); convert_element_type_26 = sigmoid = None + convert_element_type_27 = torch.ops.prims.convert_element_type.default(mul_6, torch.bfloat16); mul_6 = None + convert_element_type_28 = torch.ops.prims.convert_element_type.default(primals_11, torch.bfloat16) + all_gather_into_tensor_8 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_28, 64, '0'); convert_element_type_28 = None + wait_tensor_8 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_8); all_gather_into_tensor_8 = None + permute_9 = torch.ops.aten.permute.default(wait_tensor_8, [1, 0]); wait_tensor_8 = None + mm_5 = torch.ops.aten.mm.default(view_27, permute_9); view_27 = permute_9 = None + view_31 = torch.ops.aten.view.default(mm_5, [2, 8192, 14336]); mm_5 = None + mul_7 = torch.ops.aten.mul.Tensor(convert_element_type_27, view_31); convert_element_type_27 = view_31 = None + convert_element_type_31 = torch.ops.prims.convert_element_type.default(primals_12, torch.bfloat16) + all_gather_into_tensor_9 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_31, 64, '0'); convert_element_type_31 = None + wait_tensor_9 
= torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_9); all_gather_into_tensor_9 = None + permute_10 = torch.ops.aten.permute.default(wait_tensor_9, [1, 0]); wait_tensor_9 = None + view_33 = torch.ops.aten.view.default(mul_7, [16384, 14336]); mul_7 = None + mm_6 = torch.ops.aten.mm.default(view_33, permute_10); view_33 = permute_10 = None + view_34 = torch.ops.aten.view.default(mm_6, [2, 8192, 4096]); mm_6 = None + add_3 = torch.ops.aten.add.Tensor(add_1, view_34); add_1 = view_34 = None + convert_element_type_34 = torch.ops.prims.convert_element_type.default(primals_13, torch.bfloat16) + all_gather_into_tensor_10 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_34, 64, '0'); convert_element_type_34 = None + wait_tensor_10 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_10); all_gather_into_tensor_10 = None + convert_element_type_35 = torch.ops.prims.convert_element_type.default(add_3, torch.float32) + pow_3 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_35, 2) + mean_2 = torch.ops.aten.mean.dim(pow_3, [2], True); pow_3 = None + add_4 = torch.ops.aten.add.Scalar(mean_2, 1e-05); mean_2 = None + rsqrt_2 = torch.ops.aten.rsqrt.default(add_4); add_4 = None + mul_8 = torch.ops.aten.mul.Tensor(convert_element_type_35, rsqrt_2); convert_element_type_35 = rsqrt_2 = None + mul_9 = torch.ops.aten.mul.Tensor(mul_8, wait_tensor_10); mul_8 = wait_tensor_10 = None + convert_element_type_36 = torch.ops.prims.convert_element_type.default(mul_9, torch.bfloat16); mul_9 = None + convert_element_type_37 = torch.ops.prims.convert_element_type.default(primals_14, torch.bfloat16) + all_gather_into_tensor_11 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_37, 64, '0'); convert_element_type_37 = None + wait_tensor_11 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_11); all_gather_into_tensor_11 = None + permute_11 = 
torch.ops.aten.permute.default(wait_tensor_11, [1, 0]); wait_tensor_11 = None + view_37 = torch.ops.aten.view.default(convert_element_type_36, [16384, 4096]); convert_element_type_36 = None + mm_7 = torch.ops.aten.mm.default(view_37, permute_11); permute_11 = None + view_38 = torch.ops.aten.view.default(mm_7, [2, 8192, 4096]) + convert_element_type_40 = torch.ops.prims.convert_element_type.default(primals_15, torch.bfloat16) + all_gather_into_tensor_12 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_40, 64, '0'); convert_element_type_40 = None + wait_tensor_12 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_12); all_gather_into_tensor_12 = None + permute_12 = torch.ops.aten.permute.default(wait_tensor_12, [1, 0]); wait_tensor_12 = None + mm_8 = torch.ops.aten.mm.default(view_37, permute_12); permute_12 = None + view_41 = torch.ops.aten.view.default(mm_8, [2, 8192, 1024]); mm_8 = None + convert_element_type_43 = torch.ops.prims.convert_element_type.default(primals_16, torch.bfloat16) + all_gather_into_tensor_13 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_43, 64, '0'); convert_element_type_43 = None + wait_tensor_13 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_13); all_gather_into_tensor_13 = None + permute_13 = torch.ops.aten.permute.default(wait_tensor_13, [1, 0]); wait_tensor_13 = None + mm_9 = torch.ops.aten.mm.default(view_37, permute_13); view_37 = permute_13 = None + view_44 = torch.ops.aten.view.default(mm_9, [2, 8192, 1024]) + view_45 = torch.ops.aten.view.default(view_38, [2, 8192, -1, 128]); view_38 = None + view_46 = torch.ops.aten.view.default(view_41, [2, 8192, -1, 128]); view_41 = None + view_47 = torch.ops.aten.view.default(view_44, [2, 8192, -1, 128]); view_44 = None + convert_element_type_46 = torch.ops.prims.convert_element_type.default(view_45, torch.float32); view_45 = None + view_48 = 
torch.ops.aten.view.default(convert_element_type_46, [2, 8192, 32, -1, 2]); convert_element_type_46 = None + view_as_complex_2 = torch.ops.aten.view_as_complex.default(view_48); view_48 = None + convert_element_type_47 = torch.ops.prims.convert_element_type.default(view_46, torch.float32); view_46 = None + view_49 = torch.ops.aten.view.default(convert_element_type_47, [2, 8192, 8, -1, 2]); convert_element_type_47 = None + view_as_complex_3 = torch.ops.aten.view_as_complex.default(view_49); view_49 = None + mul_10 = torch.ops.aten.mul.Tensor(view_as_complex_2, view_16); view_as_complex_2 = None + view_as_real_2 = torch.ops.aten.view_as_real.default(mul_10); mul_10 = None + view_51 = torch.ops.aten.view.default(view_as_real_2, [2, 8192, 32, 128]); view_as_real_2 = None + mul_11 = torch.ops.aten.mul.Tensor(view_as_complex_3, view_16); view_as_complex_3 = None + view_as_real_3 = torch.ops.aten.view_as_real.default(mul_11); mul_11 = None + view_52 = torch.ops.aten.view.default(view_as_real_3, [2, 8192, 8, 128]); view_as_real_3 = None + convert_element_type_48 = torch.ops.prims.convert_element_type.default(view_51, torch.bfloat16); view_51 = None + convert_element_type_49 = torch.ops.prims.convert_element_type.default(view_52, torch.bfloat16); view_52 = None + unsqueeze_2 = torch.ops.aten.unsqueeze.default(convert_element_type_49, 3); convert_element_type_49 = None + expand_2 = torch.ops.aten.expand.default(unsqueeze_2, [2, 8192, 8, 4, 128]); unsqueeze_2 = None + clone_2 = torch.ops.aten.clone.default(expand_2, memory_format = torch.contiguous_format); expand_2 = None + view_53 = torch.ops.aten.view.default(clone_2, [2, 8192, 32, 128]); clone_2 = None + unsqueeze_3 = torch.ops.aten.unsqueeze.default(view_47, 3); view_47 = None + expand_3 = torch.ops.aten.expand.default(unsqueeze_3, [2, 8192, 8, 4, 128]); unsqueeze_3 = None + clone_3 = torch.ops.aten.clone.default(expand_3, memory_format = torch.contiguous_format); expand_3 = None + view_54 = 
torch.ops.aten.view.default(clone_3, [2, 8192, 32, 128]); clone_3 = None + permute_14 = torch.ops.aten.permute.default(convert_element_type_48, [0, 2, 1, 3]); convert_element_type_48 = None + permute_15 = torch.ops.aten.permute.default(view_53, [0, 2, 1, 3]); view_53 = None + permute_16 = torch.ops.aten.permute.default(view_54, [0, 2, 1, 3]); view_54 = None + _scaled_dot_product_cudnn_attention_1 = torch.ops.aten._scaled_dot_product_cudnn_attention.default(permute_14, permute_15, permute_16, None, True, 0.0, True); permute_14 = permute_15 = permute_16 = None + getitem_9 = _scaled_dot_product_cudnn_attention_1[0] + getitem_10 = _scaled_dot_product_cudnn_attention_1[1] + getitem_15 = _scaled_dot_product_cudnn_attention_1[6] + getitem_16 = _scaled_dot_product_cudnn_attention_1[7]; _scaled_dot_product_cudnn_attention_1 = None + permute_17 = torch.ops.aten.permute.default(getitem_9, [0, 2, 1, 3]) + view_55 = torch.ops.aten.view.default(permute_17, [2, 8192, -1]); permute_17 = None + convert_element_type_50 = torch.ops.prims.convert_element_type.default(primals_17, torch.bfloat16) + all_gather_into_tensor_14 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_50, 64, '0'); convert_element_type_50 = None + wait_tensor_14 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_14); all_gather_into_tensor_14 = None + permute_18 = torch.ops.aten.permute.default(wait_tensor_14, [1, 0]); wait_tensor_14 = None + view_57 = torch.ops.aten.view.default(view_55, [16384, 4096]); view_55 = None + mm_10 = torch.ops.aten.mm.default(view_57, permute_18); view_57 = permute_18 = None + view_58 = torch.ops.aten.view.default(mm_10, [2, 8192, 4096]); mm_10 = None + add_5 = torch.ops.aten.add.Tensor(add_3, view_58); view_58 = None + convert_element_type_53 = torch.ops.prims.convert_element_type.default(primals_18, torch.bfloat16) + all_gather_into_tensor_15 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_53, 64, 
'0'); convert_element_type_53 = None + wait_tensor_15 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_15); all_gather_into_tensor_15 = None + convert_element_type_54 = torch.ops.prims.convert_element_type.default(add_5, torch.float32) + pow_4 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_54, 2) + mean_3 = torch.ops.aten.mean.dim(pow_4, [2], True); pow_4 = None + add_6 = torch.ops.aten.add.Scalar(mean_3, 1e-05); mean_3 = None + rsqrt_3 = torch.ops.aten.rsqrt.default(add_6); add_6 = None + mul_12 = torch.ops.aten.mul.Tensor(convert_element_type_54, rsqrt_3); convert_element_type_54 = rsqrt_3 = None + mul_13 = torch.ops.aten.mul.Tensor(mul_12, wait_tensor_15); mul_12 = wait_tensor_15 = None + convert_element_type_55 = torch.ops.prims.convert_element_type.default(mul_13, torch.bfloat16); mul_13 = None + convert_element_type_56 = torch.ops.prims.convert_element_type.default(primals_19, torch.bfloat16) + all_gather_into_tensor_16 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_56, 64, '0'); convert_element_type_56 = None + wait_tensor_16 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_16); all_gather_into_tensor_16 = None + permute_19 = torch.ops.aten.permute.default(wait_tensor_16, [1, 0]); wait_tensor_16 = None + view_61 = torch.ops.aten.view.default(convert_element_type_55, [16384, 4096]); convert_element_type_55 = None + mm_11 = torch.ops.aten.mm.default(view_61, permute_19); permute_19 = None + view_62 = torch.ops.aten.view.default(mm_11, [2, 8192, 14336]) + convert_element_type_59 = torch.ops.prims.convert_element_type.default(view_62, torch.float32); view_62 = None + sigmoid_1 = torch.ops.aten.sigmoid.default(convert_element_type_59) + mul_14 = torch.ops.aten.mul.Tensor(convert_element_type_59, sigmoid_1); convert_element_type_59 = sigmoid_1 = None + convert_element_type_60 = torch.ops.prims.convert_element_type.default(mul_14, torch.bfloat16); mul_14 = None + 
convert_element_type_61 = torch.ops.prims.convert_element_type.default(primals_20, torch.bfloat16) + all_gather_into_tensor_17 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_61, 64, '0'); convert_element_type_61 = None + wait_tensor_17 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_17); all_gather_into_tensor_17 = None + permute_20 = torch.ops.aten.permute.default(wait_tensor_17, [1, 0]); wait_tensor_17 = None + mm_12 = torch.ops.aten.mm.default(view_61, permute_20); view_61 = permute_20 = None + view_65 = torch.ops.aten.view.default(mm_12, [2, 8192, 14336]); mm_12 = None + mul_15 = torch.ops.aten.mul.Tensor(convert_element_type_60, view_65); convert_element_type_60 = view_65 = None + convert_element_type_64 = torch.ops.prims.convert_element_type.default(primals_21, torch.bfloat16) + all_gather_into_tensor_18 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_64, 64, '0'); convert_element_type_64 = None + wait_tensor_18 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_18); all_gather_into_tensor_18 = None + permute_21 = torch.ops.aten.permute.default(wait_tensor_18, [1, 0]); wait_tensor_18 = None + view_67 = torch.ops.aten.view.default(mul_15, [16384, 14336]); mul_15 = None + mm_13 = torch.ops.aten.mm.default(view_67, permute_21); view_67 = permute_21 = None + view_68 = torch.ops.aten.view.default(mm_13, [2, 8192, 4096]); mm_13 = None + add_7 = torch.ops.aten.add.Tensor(add_5, view_68); add_5 = view_68 = None + convert_element_type_67 = torch.ops.prims.convert_element_type.default(primals_22, torch.bfloat16) + all_gather_into_tensor_19 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_67, 64, '0'); convert_element_type_67 = None + wait_tensor_19 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_19); all_gather_into_tensor_19 = None + convert_element_type_68 = 
torch.ops.prims.convert_element_type.default(add_7, torch.float32) + pow_5 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_68, 2) + mean_4 = torch.ops.aten.mean.dim(pow_5, [2], True); pow_5 = None + add_8 = torch.ops.aten.add.Scalar(mean_4, 1e-05); mean_4 = None + rsqrt_4 = torch.ops.aten.rsqrt.default(add_8); add_8 = None + mul_16 = torch.ops.aten.mul.Tensor(convert_element_type_68, rsqrt_4); convert_element_type_68 = rsqrt_4 = None + mul_17 = torch.ops.aten.mul.Tensor(mul_16, wait_tensor_19); mul_16 = wait_tensor_19 = None + convert_element_type_69 = torch.ops.prims.convert_element_type.default(mul_17, torch.bfloat16); mul_17 = None + convert_element_type_70 = torch.ops.prims.convert_element_type.default(primals_23, torch.bfloat16) + all_gather_into_tensor_20 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_70, 64, '0'); convert_element_type_70 = None + wait_tensor_20 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_20); all_gather_into_tensor_20 = None + permute_22 = torch.ops.aten.permute.default(wait_tensor_20, [1, 0]); wait_tensor_20 = None + view_71 = torch.ops.aten.view.default(convert_element_type_69, [16384, 4096]); convert_element_type_69 = None + mm_14 = torch.ops.aten.mm.default(view_71, permute_22); permute_22 = None + view_72 = torch.ops.aten.view.default(mm_14, [2, 8192, 4096]) + convert_element_type_73 = torch.ops.prims.convert_element_type.default(primals_24, torch.bfloat16) + all_gather_into_tensor_21 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_73, 64, '0'); convert_element_type_73 = None + wait_tensor_21 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_21); all_gather_into_tensor_21 = None + permute_23 = torch.ops.aten.permute.default(wait_tensor_21, [1, 0]); wait_tensor_21 = None + mm_15 = torch.ops.aten.mm.default(view_71, permute_23); permute_23 = None + view_75 = torch.ops.aten.view.default(mm_15, [2, 8192, 1024]); mm_15 
= None + convert_element_type_76 = torch.ops.prims.convert_element_type.default(primals_25, torch.bfloat16) + all_gather_into_tensor_22 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_76, 64, '0'); convert_element_type_76 = None + wait_tensor_22 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_22); all_gather_into_tensor_22 = None + permute_24 = torch.ops.aten.permute.default(wait_tensor_22, [1, 0]); wait_tensor_22 = None + mm_16 = torch.ops.aten.mm.default(view_71, permute_24); view_71 = permute_24 = None + view_78 = torch.ops.aten.view.default(mm_16, [2, 8192, 1024]) + view_79 = torch.ops.aten.view.default(view_72, [2, 8192, -1, 128]); view_72 = None + view_80 = torch.ops.aten.view.default(view_75, [2, 8192, -1, 128]); view_75 = None + view_81 = torch.ops.aten.view.default(view_78, [2, 8192, -1, 128]); view_78 = None + convert_element_type_79 = torch.ops.prims.convert_element_type.default(view_79, torch.float32); view_79 = None + view_82 = torch.ops.aten.view.default(convert_element_type_79, [2, 8192, 32, -1, 2]); convert_element_type_79 = None + view_as_complex_4 = torch.ops.aten.view_as_complex.default(view_82); view_82 = None + convert_element_type_80 = torch.ops.prims.convert_element_type.default(view_80, torch.float32); view_80 = None + view_83 = torch.ops.aten.view.default(convert_element_type_80, [2, 8192, 8, -1, 2]); convert_element_type_80 = None + view_as_complex_5 = torch.ops.aten.view_as_complex.default(view_83); view_83 = None + mul_18 = torch.ops.aten.mul.Tensor(view_as_complex_4, view_16); view_as_complex_4 = None + view_as_real_4 = torch.ops.aten.view_as_real.default(mul_18); mul_18 = None + view_85 = torch.ops.aten.view.default(view_as_real_4, [2, 8192, 32, 128]); view_as_real_4 = None + mul_19 = torch.ops.aten.mul.Tensor(view_as_complex_5, view_16); view_as_complex_5 = None + view_as_real_5 = torch.ops.aten.view_as_real.default(mul_19); mul_19 = None + view_86 = 
torch.ops.aten.view.default(view_as_real_5, [2, 8192, 8, 128]); view_as_real_5 = None + convert_element_type_81 = torch.ops.prims.convert_element_type.default(view_85, torch.bfloat16); view_85 = None + convert_element_type_82 = torch.ops.prims.convert_element_type.default(view_86, torch.bfloat16); view_86 = None + unsqueeze_4 = torch.ops.aten.unsqueeze.default(convert_element_type_82, 3); convert_element_type_82 = None + expand_4 = torch.ops.aten.expand.default(unsqueeze_4, [2, 8192, 8, 4, 128]); unsqueeze_4 = None + clone_4 = torch.ops.aten.clone.default(expand_4, memory_format = torch.contiguous_format); expand_4 = None + view_87 = torch.ops.aten.view.default(clone_4, [2, 8192, 32, 128]); clone_4 = None + unsqueeze_5 = torch.ops.aten.unsqueeze.default(view_81, 3); view_81 = None + expand_5 = torch.ops.aten.expand.default(unsqueeze_5, [2, 8192, 8, 4, 128]); unsqueeze_5 = None + clone_5 = torch.ops.aten.clone.default(expand_5, memory_format = torch.contiguous_format); expand_5 = None + view_88 = torch.ops.aten.view.default(clone_5, [2, 8192, 32, 128]); clone_5 = None + permute_25 = torch.ops.aten.permute.default(convert_element_type_81, [0, 2, 1, 3]); convert_element_type_81 = None + permute_26 = torch.ops.aten.permute.default(view_87, [0, 2, 1, 3]); view_87 = None + permute_27 = torch.ops.aten.permute.default(view_88, [0, 2, 1, 3]); view_88 = None + _scaled_dot_product_cudnn_attention_2 = torch.ops.aten._scaled_dot_product_cudnn_attention.default(permute_25, permute_26, permute_27, None, True, 0.0, True); permute_25 = permute_26 = permute_27 = None + getitem_18 = _scaled_dot_product_cudnn_attention_2[0] + getitem_19 = _scaled_dot_product_cudnn_attention_2[1] + getitem_24 = _scaled_dot_product_cudnn_attention_2[6] + getitem_25 = _scaled_dot_product_cudnn_attention_2[7]; _scaled_dot_product_cudnn_attention_2 = None + permute_28 = torch.ops.aten.permute.default(getitem_18, [0, 2, 1, 3]) + view_89 = torch.ops.aten.view.default(permute_28, [2, 8192, -1]); permute_28 = 
None + convert_element_type_83 = torch.ops.prims.convert_element_type.default(primals_26, torch.bfloat16) + all_gather_into_tensor_23 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_83, 64, '0'); convert_element_type_83 = None + wait_tensor_23 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_23); all_gather_into_tensor_23 = None + permute_29 = torch.ops.aten.permute.default(wait_tensor_23, [1, 0]); wait_tensor_23 = None + view_91 = torch.ops.aten.view.default(view_89, [16384, 4096]); view_89 = None + mm_17 = torch.ops.aten.mm.default(view_91, permute_29); view_91 = permute_29 = None + view_92 = torch.ops.aten.view.default(mm_17, [2, 8192, 4096]); mm_17 = None + add_9 = torch.ops.aten.add.Tensor(add_7, view_92); view_92 = None + convert_element_type_86 = torch.ops.prims.convert_element_type.default(primals_27, torch.bfloat16) + all_gather_into_tensor_24 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_86, 64, '0'); convert_element_type_86 = None + wait_tensor_24 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_24); all_gather_into_tensor_24 = None + convert_element_type_87 = torch.ops.prims.convert_element_type.default(add_9, torch.float32) + pow_6 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_87, 2) + mean_5 = torch.ops.aten.mean.dim(pow_6, [2], True); pow_6 = None + add_10 = torch.ops.aten.add.Scalar(mean_5, 1e-05); mean_5 = None + rsqrt_5 = torch.ops.aten.rsqrt.default(add_10); add_10 = None + mul_20 = torch.ops.aten.mul.Tensor(convert_element_type_87, rsqrt_5); convert_element_type_87 = rsqrt_5 = None + mul_21 = torch.ops.aten.mul.Tensor(mul_20, wait_tensor_24); mul_20 = wait_tensor_24 = None + convert_element_type_88 = torch.ops.prims.convert_element_type.default(mul_21, torch.bfloat16); mul_21 = None + convert_element_type_89 = torch.ops.prims.convert_element_type.default(primals_28, torch.bfloat16) + all_gather_into_tensor_25 = 
torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_89, 64, '0'); convert_element_type_89 = None + wait_tensor_25 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_25); all_gather_into_tensor_25 = None + permute_30 = torch.ops.aten.permute.default(wait_tensor_25, [1, 0]); wait_tensor_25 = None + view_95 = torch.ops.aten.view.default(convert_element_type_88, [16384, 4096]); convert_element_type_88 = None + mm_18 = torch.ops.aten.mm.default(view_95, permute_30); permute_30 = None + view_96 = torch.ops.aten.view.default(mm_18, [2, 8192, 14336]) + convert_element_type_92 = torch.ops.prims.convert_element_type.default(view_96, torch.float32); view_96 = None + sigmoid_2 = torch.ops.aten.sigmoid.default(convert_element_type_92) + mul_22 = torch.ops.aten.mul.Tensor(convert_element_type_92, sigmoid_2); convert_element_type_92 = sigmoid_2 = None + convert_element_type_93 = torch.ops.prims.convert_element_type.default(mul_22, torch.bfloat16); mul_22 = None + convert_element_type_94 = torch.ops.prims.convert_element_type.default(primals_29, torch.bfloat16) + all_gather_into_tensor_26 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_94, 64, '0'); convert_element_type_94 = None + wait_tensor_26 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_26); all_gather_into_tensor_26 = None + permute_31 = torch.ops.aten.permute.default(wait_tensor_26, [1, 0]); wait_tensor_26 = None + mm_19 = torch.ops.aten.mm.default(view_95, permute_31); view_95 = permute_31 = None + view_99 = torch.ops.aten.view.default(mm_19, [2, 8192, 14336]); mm_19 = None + mul_23 = torch.ops.aten.mul.Tensor(convert_element_type_93, view_99); convert_element_type_93 = view_99 = None + convert_element_type_97 = torch.ops.prims.convert_element_type.default(primals_30, torch.bfloat16) + all_gather_into_tensor_27 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_97, 64, '0'); 
convert_element_type_97 = None + wait_tensor_27 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_27); all_gather_into_tensor_27 = None + permute_32 = torch.ops.aten.permute.default(wait_tensor_27, [1, 0]); wait_tensor_27 = None + view_101 = torch.ops.aten.view.default(mul_23, [16384, 14336]); mul_23 = None + mm_20 = torch.ops.aten.mm.default(view_101, permute_32); view_101 = permute_32 = None + view_102 = torch.ops.aten.view.default(mm_20, [2, 8192, 4096]); mm_20 = None + add_11 = torch.ops.aten.add.Tensor(add_9, view_102); add_9 = view_102 = None + convert_element_type_100 = torch.ops.prims.convert_element_type.default(primals_31, torch.bfloat16) + all_gather_into_tensor_28 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_100, 64, '0'); convert_element_type_100 = None + wait_tensor_28 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_28); all_gather_into_tensor_28 = None + convert_element_type_101 = torch.ops.prims.convert_element_type.default(add_11, torch.float32) + pow_7 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_101, 2) + mean_6 = torch.ops.aten.mean.dim(pow_7, [2], True); pow_7 = None + add_12 = torch.ops.aten.add.Scalar(mean_6, 1e-05); mean_6 = None + rsqrt_6 = torch.ops.aten.rsqrt.default(add_12); add_12 = None + mul_24 = torch.ops.aten.mul.Tensor(convert_element_type_101, rsqrt_6); convert_element_type_101 = rsqrt_6 = None + mul_25 = torch.ops.aten.mul.Tensor(mul_24, wait_tensor_28); mul_24 = wait_tensor_28 = None + convert_element_type_102 = torch.ops.prims.convert_element_type.default(mul_25, torch.bfloat16); mul_25 = None + convert_element_type_103 = torch.ops.prims.convert_element_type.default(primals_32, torch.bfloat16) + all_gather_into_tensor_29 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_103, 64, '0'); convert_element_type_103 = None + wait_tensor_29 = 
torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_29); all_gather_into_tensor_29 = None + permute_33 = torch.ops.aten.permute.default(wait_tensor_29, [1, 0]); wait_tensor_29 = None + view_105 = torch.ops.aten.view.default(convert_element_type_102, [16384, 4096]); convert_element_type_102 = None + mm_21 = torch.ops.aten.mm.default(view_105, permute_33); permute_33 = None + view_106 = torch.ops.aten.view.default(mm_21, [2, 8192, 4096]) + convert_element_type_106 = torch.ops.prims.convert_element_type.default(primals_33, torch.bfloat16) + all_gather_into_tensor_30 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_106, 64, '0'); convert_element_type_106 = None + wait_tensor_30 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_30); all_gather_into_tensor_30 = None + permute_34 = torch.ops.aten.permute.default(wait_tensor_30, [1, 0]); wait_tensor_30 = None + mm_22 = torch.ops.aten.mm.default(view_105, permute_34); permute_34 = None + view_109 = torch.ops.aten.view.default(mm_22, [2, 8192, 1024]); mm_22 = None + convert_element_type_109 = torch.ops.prims.convert_element_type.default(primals_34, torch.bfloat16) + all_gather_into_tensor_31 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_109, 64, '0'); convert_element_type_109 = None + wait_tensor_31 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_31); all_gather_into_tensor_31 = None + permute_35 = torch.ops.aten.permute.default(wait_tensor_31, [1, 0]); wait_tensor_31 = None + mm_23 = torch.ops.aten.mm.default(view_105, permute_35); view_105 = permute_35 = None + view_112 = torch.ops.aten.view.default(mm_23, [2, 8192, 1024]) + view_113 = torch.ops.aten.view.default(view_106, [2, 8192, -1, 128]); view_106 = None + view_114 = torch.ops.aten.view.default(view_109, [2, 8192, -1, 128]); view_109 = None + view_115 = torch.ops.aten.view.default(view_112, [2, 8192, -1, 128]); view_112 = None + 
convert_element_type_112 = torch.ops.prims.convert_element_type.default(view_113, torch.float32); view_113 = None + view_116 = torch.ops.aten.view.default(convert_element_type_112, [2, 8192, 32, -1, 2]); convert_element_type_112 = None + view_as_complex_6 = torch.ops.aten.view_as_complex.default(view_116); view_116 = None + convert_element_type_113 = torch.ops.prims.convert_element_type.default(view_114, torch.float32); view_114 = None + view_117 = torch.ops.aten.view.default(convert_element_type_113, [2, 8192, 8, -1, 2]); convert_element_type_113 = None + view_as_complex_7 = torch.ops.aten.view_as_complex.default(view_117); view_117 = None + mul_26 = torch.ops.aten.mul.Tensor(view_as_complex_6, view_16); view_as_complex_6 = None + view_as_real_6 = torch.ops.aten.view_as_real.default(mul_26); mul_26 = None + view_119 = torch.ops.aten.view.default(view_as_real_6, [2, 8192, 32, 128]); view_as_real_6 = None + mul_27 = torch.ops.aten.mul.Tensor(view_as_complex_7, view_16); view_as_complex_7 = None + view_as_real_7 = torch.ops.aten.view_as_real.default(mul_27); mul_27 = None + view_120 = torch.ops.aten.view.default(view_as_real_7, [2, 8192, 8, 128]); view_as_real_7 = None + convert_element_type_114 = torch.ops.prims.convert_element_type.default(view_119, torch.bfloat16); view_119 = None + convert_element_type_115 = torch.ops.prims.convert_element_type.default(view_120, torch.bfloat16); view_120 = None + unsqueeze_6 = torch.ops.aten.unsqueeze.default(convert_element_type_115, 3); convert_element_type_115 = None + expand_6 = torch.ops.aten.expand.default(unsqueeze_6, [2, 8192, 8, 4, 128]); unsqueeze_6 = None + clone_6 = torch.ops.aten.clone.default(expand_6, memory_format = torch.contiguous_format); expand_6 = None + view_121 = torch.ops.aten.view.default(clone_6, [2, 8192, 32, 128]); clone_6 = None + unsqueeze_7 = torch.ops.aten.unsqueeze.default(view_115, 3); view_115 = None + expand_7 = torch.ops.aten.expand.default(unsqueeze_7, [2, 8192, 8, 4, 128]); unsqueeze_7 = 
None + clone_7 = torch.ops.aten.clone.default(expand_7, memory_format = torch.contiguous_format); expand_7 = None + view_122 = torch.ops.aten.view.default(clone_7, [2, 8192, 32, 128]); clone_7 = None + permute_36 = torch.ops.aten.permute.default(convert_element_type_114, [0, 2, 1, 3]); convert_element_type_114 = None + permute_37 = torch.ops.aten.permute.default(view_121, [0, 2, 1, 3]); view_121 = None + permute_38 = torch.ops.aten.permute.default(view_122, [0, 2, 1, 3]); view_122 = None + _scaled_dot_product_cudnn_attention_3 = torch.ops.aten._scaled_dot_product_cudnn_attention.default(permute_36, permute_37, permute_38, None, True, 0.0, True); permute_36 = permute_37 = permute_38 = None + getitem_27 = _scaled_dot_product_cudnn_attention_3[0] + getitem_28 = _scaled_dot_product_cudnn_attention_3[1] + getitem_33 = _scaled_dot_product_cudnn_attention_3[6] + getitem_34 = _scaled_dot_product_cudnn_attention_3[7]; _scaled_dot_product_cudnn_attention_3 = None + permute_39 = torch.ops.aten.permute.default(getitem_27, [0, 2, 1, 3]) + view_123 = torch.ops.aten.view.default(permute_39, [2, 8192, -1]); permute_39 = None + convert_element_type_116 = torch.ops.prims.convert_element_type.default(primals_35, torch.bfloat16) + all_gather_into_tensor_32 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_116, 64, '0'); convert_element_type_116 = None + wait_tensor_32 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_32); all_gather_into_tensor_32 = None + permute_40 = torch.ops.aten.permute.default(wait_tensor_32, [1, 0]); wait_tensor_32 = None + view_125 = torch.ops.aten.view.default(view_123, [16384, 4096]); view_123 = None + mm_24 = torch.ops.aten.mm.default(view_125, permute_40); view_125 = permute_40 = None + view_126 = torch.ops.aten.view.default(mm_24, [2, 8192, 4096]); mm_24 = None + add_13 = torch.ops.aten.add.Tensor(add_11, view_126); view_126 = None + convert_element_type_119 = 
torch.ops.prims.convert_element_type.default(primals_36, torch.bfloat16) + all_gather_into_tensor_33 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_119, 64, '0'); convert_element_type_119 = None + wait_tensor_33 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_33); all_gather_into_tensor_33 = None + convert_element_type_120 = torch.ops.prims.convert_element_type.default(add_13, torch.float32) + pow_8 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_120, 2) + mean_7 = torch.ops.aten.mean.dim(pow_8, [2], True); pow_8 = None + add_14 = torch.ops.aten.add.Scalar(mean_7, 1e-05); mean_7 = None + rsqrt_7 = torch.ops.aten.rsqrt.default(add_14); add_14 = None + mul_28 = torch.ops.aten.mul.Tensor(convert_element_type_120, rsqrt_7); convert_element_type_120 = rsqrt_7 = None + mul_29 = torch.ops.aten.mul.Tensor(mul_28, wait_tensor_33); mul_28 = wait_tensor_33 = None + convert_element_type_121 = torch.ops.prims.convert_element_type.default(mul_29, torch.bfloat16); mul_29 = None + convert_element_type_122 = torch.ops.prims.convert_element_type.default(primals_37, torch.bfloat16) + all_gather_into_tensor_34 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_122, 64, '0'); convert_element_type_122 = None + wait_tensor_34 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_34); all_gather_into_tensor_34 = None + permute_41 = torch.ops.aten.permute.default(wait_tensor_34, [1, 0]); wait_tensor_34 = None + view_129 = torch.ops.aten.view.default(convert_element_type_121, [16384, 4096]); convert_element_type_121 = None + mm_25 = torch.ops.aten.mm.default(view_129, permute_41); permute_41 = None + view_130 = torch.ops.aten.view.default(mm_25, [2, 8192, 14336]) + convert_element_type_125 = torch.ops.prims.convert_element_type.default(view_130, torch.float32); view_130 = None + sigmoid_3 = torch.ops.aten.sigmoid.default(convert_element_type_125) + mul_30 = 
torch.ops.aten.mul.Tensor(convert_element_type_125, sigmoid_3); convert_element_type_125 = sigmoid_3 = None + convert_element_type_126 = torch.ops.prims.convert_element_type.default(mul_30, torch.bfloat16); mul_30 = None + convert_element_type_127 = torch.ops.prims.convert_element_type.default(primals_38, torch.bfloat16) + all_gather_into_tensor_35 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_127, 64, '0'); convert_element_type_127 = None + wait_tensor_35 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_35); all_gather_into_tensor_35 = None + permute_42 = torch.ops.aten.permute.default(wait_tensor_35, [1, 0]); wait_tensor_35 = None + mm_26 = torch.ops.aten.mm.default(view_129, permute_42); view_129 = permute_42 = None + view_133 = torch.ops.aten.view.default(mm_26, [2, 8192, 14336]); mm_26 = None + mul_31 = torch.ops.aten.mul.Tensor(convert_element_type_126, view_133); convert_element_type_126 = view_133 = None + convert_element_type_130 = torch.ops.prims.convert_element_type.default(primals_39, torch.bfloat16) + all_gather_into_tensor_36 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_130, 64, '0'); convert_element_type_130 = None + wait_tensor_36 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_36); all_gather_into_tensor_36 = None + permute_43 = torch.ops.aten.permute.default(wait_tensor_36, [1, 0]); wait_tensor_36 = None + view_135 = torch.ops.aten.view.default(mul_31, [16384, 14336]); mul_31 = None + mm_27 = torch.ops.aten.mm.default(view_135, permute_43); view_135 = permute_43 = None + view_136 = torch.ops.aten.view.default(mm_27, [2, 8192, 4096]); mm_27 = None + add_15 = torch.ops.aten.add.Tensor(add_13, view_136); add_13 = view_136 = None + convert_element_type_133 = torch.ops.prims.convert_element_type.default(primals_40, torch.bfloat16) + all_gather_into_tensor_37 = 
torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_133, 64, '0'); convert_element_type_133 = None + wait_tensor_37 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_37); all_gather_into_tensor_37 = None + convert_element_type_134 = torch.ops.prims.convert_element_type.default(add_15, torch.float32) + pow_9 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_134, 2) + mean_8 = torch.ops.aten.mean.dim(pow_9, [2], True); pow_9 = None + add_16 = torch.ops.aten.add.Scalar(mean_8, 1e-05); mean_8 = None + rsqrt_8 = torch.ops.aten.rsqrt.default(add_16); add_16 = None + mul_32 = torch.ops.aten.mul.Tensor(convert_element_type_134, rsqrt_8); convert_element_type_134 = rsqrt_8 = None + mul_33 = torch.ops.aten.mul.Tensor(mul_32, wait_tensor_37); mul_32 = wait_tensor_37 = None + convert_element_type_135 = torch.ops.prims.convert_element_type.default(mul_33, torch.bfloat16); mul_33 = None + convert_element_type_136 = torch.ops.prims.convert_element_type.default(primals_41, torch.bfloat16) + all_gather_into_tensor_38 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_136, 64, '0'); convert_element_type_136 = None + wait_tensor_38 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_38); all_gather_into_tensor_38 = None + permute_44 = torch.ops.aten.permute.default(wait_tensor_38, [1, 0]); wait_tensor_38 = None + view_139 = torch.ops.aten.view.default(convert_element_type_135, [16384, 4096]); convert_element_type_135 = None + mm_28 = torch.ops.aten.mm.default(view_139, permute_44); permute_44 = None + view_140 = torch.ops.aten.view.default(mm_28, [2, 8192, 4096]) + convert_element_type_139 = torch.ops.prims.convert_element_type.default(primals_42, torch.bfloat16) + all_gather_into_tensor_39 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_139, 64, '0'); convert_element_type_139 = None + wait_tensor_39 = 
torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_39); all_gather_into_tensor_39 = None + permute_45 = torch.ops.aten.permute.default(wait_tensor_39, [1, 0]); wait_tensor_39 = None + mm_29 = torch.ops.aten.mm.default(view_139, permute_45); permute_45 = None + view_143 = torch.ops.aten.view.default(mm_29, [2, 8192, 1024]); mm_29 = None + convert_element_type_142 = torch.ops.prims.convert_element_type.default(primals_43, torch.bfloat16) + all_gather_into_tensor_40 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_142, 64, '0'); convert_element_type_142 = None + wait_tensor_40 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_40); all_gather_into_tensor_40 = None + permute_46 = torch.ops.aten.permute.default(wait_tensor_40, [1, 0]); wait_tensor_40 = None + mm_30 = torch.ops.aten.mm.default(view_139, permute_46); view_139 = permute_46 = None + view_146 = torch.ops.aten.view.default(mm_30, [2, 8192, 1024]) + view_147 = torch.ops.aten.view.default(view_140, [2, 8192, -1, 128]); view_140 = None + view_148 = torch.ops.aten.view.default(view_143, [2, 8192, -1, 128]); view_143 = None + view_149 = torch.ops.aten.view.default(view_146, [2, 8192, -1, 128]); view_146 = None + convert_element_type_145 = torch.ops.prims.convert_element_type.default(view_147, torch.float32); view_147 = None + view_150 = torch.ops.aten.view.default(convert_element_type_145, [2, 8192, 32, -1, 2]); convert_element_type_145 = None + view_as_complex_8 = torch.ops.aten.view_as_complex.default(view_150); view_150 = None + convert_element_type_146 = torch.ops.prims.convert_element_type.default(view_148, torch.float32); view_148 = None + view_151 = torch.ops.aten.view.default(convert_element_type_146, [2, 8192, 8, -1, 2]); convert_element_type_146 = None + view_as_complex_9 = torch.ops.aten.view_as_complex.default(view_151); view_151 = None + mul_34 = torch.ops.aten.mul.Tensor(view_as_complex_8, view_16); view_as_complex_8 = None + 
view_as_real_8 = torch.ops.aten.view_as_real.default(mul_34); mul_34 = None + view_153 = torch.ops.aten.view.default(view_as_real_8, [2, 8192, 32, 128]); view_as_real_8 = None + mul_35 = torch.ops.aten.mul.Tensor(view_as_complex_9, view_16); view_as_complex_9 = None + view_as_real_9 = torch.ops.aten.view_as_real.default(mul_35); mul_35 = None + view_154 = torch.ops.aten.view.default(view_as_real_9, [2, 8192, 8, 128]); view_as_real_9 = None + convert_element_type_147 = torch.ops.prims.convert_element_type.default(view_153, torch.bfloat16); view_153 = None + convert_element_type_148 = torch.ops.prims.convert_element_type.default(view_154, torch.bfloat16); view_154 = None + unsqueeze_8 = torch.ops.aten.unsqueeze.default(convert_element_type_148, 3); convert_element_type_148 = None + expand_8 = torch.ops.aten.expand.default(unsqueeze_8, [2, 8192, 8, 4, 128]); unsqueeze_8 = None + clone_8 = torch.ops.aten.clone.default(expand_8, memory_format = torch.contiguous_format); expand_8 = None + view_155 = torch.ops.aten.view.default(clone_8, [2, 8192, 32, 128]); clone_8 = None + unsqueeze_9 = torch.ops.aten.unsqueeze.default(view_149, 3); view_149 = None + expand_9 = torch.ops.aten.expand.default(unsqueeze_9, [2, 8192, 8, 4, 128]); unsqueeze_9 = None + clone_9 = torch.ops.aten.clone.default(expand_9, memory_format = torch.contiguous_format); expand_9 = None + view_156 = torch.ops.aten.view.default(clone_9, [2, 8192, 32, 128]); clone_9 = None + permute_47 = torch.ops.aten.permute.default(convert_element_type_147, [0, 2, 1, 3]); convert_element_type_147 = None + permute_48 = torch.ops.aten.permute.default(view_155, [0, 2, 1, 3]); view_155 = None + permute_49 = torch.ops.aten.permute.default(view_156, [0, 2, 1, 3]); view_156 = None + _scaled_dot_product_cudnn_attention_4 = torch.ops.aten._scaled_dot_product_cudnn_attention.default(permute_47, permute_48, permute_49, None, True, 0.0, True); permute_47 = permute_48 = permute_49 = None + getitem_36 = 
_scaled_dot_product_cudnn_attention_4[0] + getitem_37 = _scaled_dot_product_cudnn_attention_4[1] + getitem_42 = _scaled_dot_product_cudnn_attention_4[6] + getitem_43 = _scaled_dot_product_cudnn_attention_4[7]; _scaled_dot_product_cudnn_attention_4 = None + permute_50 = torch.ops.aten.permute.default(getitem_36, [0, 2, 1, 3]) + view_157 = torch.ops.aten.view.default(permute_50, [2, 8192, -1]); permute_50 = None + convert_element_type_149 = torch.ops.prims.convert_element_type.default(primals_44, torch.bfloat16) + all_gather_into_tensor_41 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_149, 64, '0'); convert_element_type_149 = None + wait_tensor_41 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_41); all_gather_into_tensor_41 = None + permute_51 = torch.ops.aten.permute.default(wait_tensor_41, [1, 0]); wait_tensor_41 = None + view_159 = torch.ops.aten.view.default(view_157, [16384, 4096]); view_157 = None + mm_31 = torch.ops.aten.mm.default(view_159, permute_51); view_159 = permute_51 = None + view_160 = torch.ops.aten.view.default(mm_31, [2, 8192, 4096]); mm_31 = None + add_17 = torch.ops.aten.add.Tensor(add_15, view_160); view_160 = None + convert_element_type_152 = torch.ops.prims.convert_element_type.default(primals_45, torch.bfloat16) + all_gather_into_tensor_42 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_152, 64, '0'); convert_element_type_152 = None + wait_tensor_42 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_42); all_gather_into_tensor_42 = None + convert_element_type_153 = torch.ops.prims.convert_element_type.default(add_17, torch.float32) + pow_10 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_153, 2) + mean_9 = torch.ops.aten.mean.dim(pow_10, [2], True); pow_10 = None + add_18 = torch.ops.aten.add.Scalar(mean_9, 1e-05); mean_9 = None + rsqrt_9 = torch.ops.aten.rsqrt.default(add_18); add_18 = None + mul_36 = 
torch.ops.aten.mul.Tensor(convert_element_type_153, rsqrt_9); convert_element_type_153 = rsqrt_9 = None + mul_37 = torch.ops.aten.mul.Tensor(mul_36, wait_tensor_42); mul_36 = wait_tensor_42 = None + convert_element_type_154 = torch.ops.prims.convert_element_type.default(mul_37, torch.bfloat16); mul_37 = None + convert_element_type_155 = torch.ops.prims.convert_element_type.default(primals_46, torch.bfloat16) + all_gather_into_tensor_43 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_155, 64, '0'); convert_element_type_155 = None + wait_tensor_43 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_43); all_gather_into_tensor_43 = None + permute_52 = torch.ops.aten.permute.default(wait_tensor_43, [1, 0]); wait_tensor_43 = None + view_163 = torch.ops.aten.view.default(convert_element_type_154, [16384, 4096]); convert_element_type_154 = None + mm_32 = torch.ops.aten.mm.default(view_163, permute_52); permute_52 = None + view_164 = torch.ops.aten.view.default(mm_32, [2, 8192, 14336]) + convert_element_type_158 = torch.ops.prims.convert_element_type.default(view_164, torch.float32); view_164 = None + sigmoid_4 = torch.ops.aten.sigmoid.default(convert_element_type_158) + mul_38 = torch.ops.aten.mul.Tensor(convert_element_type_158, sigmoid_4); convert_element_type_158 = sigmoid_4 = None + convert_element_type_159 = torch.ops.prims.convert_element_type.default(mul_38, torch.bfloat16); mul_38 = None + convert_element_type_160 = torch.ops.prims.convert_element_type.default(primals_47, torch.bfloat16) + all_gather_into_tensor_44 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_160, 64, '0'); convert_element_type_160 = None + wait_tensor_44 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_44); all_gather_into_tensor_44 = None + permute_53 = torch.ops.aten.permute.default(wait_tensor_44, [1, 0]); wait_tensor_44 = None + mm_33 = torch.ops.aten.mm.default(view_163, 
permute_53); view_163 = permute_53 = None + view_167 = torch.ops.aten.view.default(mm_33, [2, 8192, 14336]); mm_33 = None + mul_39 = torch.ops.aten.mul.Tensor(convert_element_type_159, view_167); convert_element_type_159 = view_167 = None + convert_element_type_163 = torch.ops.prims.convert_element_type.default(primals_48, torch.bfloat16) + all_gather_into_tensor_45 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_163, 64, '0'); convert_element_type_163 = None + wait_tensor_45 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_45); all_gather_into_tensor_45 = None + permute_54 = torch.ops.aten.permute.default(wait_tensor_45, [1, 0]); wait_tensor_45 = None + view_169 = torch.ops.aten.view.default(mul_39, [16384, 14336]); mul_39 = None + mm_34 = torch.ops.aten.mm.default(view_169, permute_54); view_169 = permute_54 = None + view_170 = torch.ops.aten.view.default(mm_34, [2, 8192, 4096]); mm_34 = None + add_19 = torch.ops.aten.add.Tensor(add_17, view_170); add_17 = view_170 = None + convert_element_type_166 = torch.ops.prims.convert_element_type.default(primals_49, torch.bfloat16) + all_gather_into_tensor_46 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_166, 64, '0'); convert_element_type_166 = None + wait_tensor_46 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_46); all_gather_into_tensor_46 = None + convert_element_type_167 = torch.ops.prims.convert_element_type.default(add_19, torch.float32) + pow_11 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_167, 2) + mean_10 = torch.ops.aten.mean.dim(pow_11, [2], True); pow_11 = None + add_20 = torch.ops.aten.add.Scalar(mean_10, 1e-05); mean_10 = None + rsqrt_10 = torch.ops.aten.rsqrt.default(add_20); add_20 = None + mul_40 = torch.ops.aten.mul.Tensor(convert_element_type_167, rsqrt_10); convert_element_type_167 = rsqrt_10 = None + mul_41 = torch.ops.aten.mul.Tensor(mul_40, wait_tensor_46); mul_40 = 
wait_tensor_46 = None + convert_element_type_168 = torch.ops.prims.convert_element_type.default(mul_41, torch.bfloat16); mul_41 = None + convert_element_type_169 = torch.ops.prims.convert_element_type.default(primals_50, torch.bfloat16) + all_gather_into_tensor_47 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_169, 64, '0'); convert_element_type_169 = None + wait_tensor_47 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_47); all_gather_into_tensor_47 = None + permute_55 = torch.ops.aten.permute.default(wait_tensor_47, [1, 0]); wait_tensor_47 = None + view_173 = torch.ops.aten.view.default(convert_element_type_168, [16384, 4096]); convert_element_type_168 = None + mm_35 = torch.ops.aten.mm.default(view_173, permute_55); permute_55 = None + view_174 = torch.ops.aten.view.default(mm_35, [2, 8192, 4096]) + convert_element_type_172 = torch.ops.prims.convert_element_type.default(primals_51, torch.bfloat16) + all_gather_into_tensor_48 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_172, 64, '0'); convert_element_type_172 = None + wait_tensor_48 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_48); all_gather_into_tensor_48 = None + permute_56 = torch.ops.aten.permute.default(wait_tensor_48, [1, 0]); wait_tensor_48 = None + mm_36 = torch.ops.aten.mm.default(view_173, permute_56); permute_56 = None + view_177 = torch.ops.aten.view.default(mm_36, [2, 8192, 1024]); mm_36 = None + convert_element_type_175 = torch.ops.prims.convert_element_type.default(primals_52, torch.bfloat16) + all_gather_into_tensor_49 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_175, 64, '0'); convert_element_type_175 = None + wait_tensor_49 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_49); all_gather_into_tensor_49 = None + permute_57 = torch.ops.aten.permute.default(wait_tensor_49, [1, 0]); wait_tensor_49 = None + mm_37 = 
torch.ops.aten.mm.default(view_173, permute_57); view_173 = permute_57 = None + view_180 = torch.ops.aten.view.default(mm_37, [2, 8192, 1024]) + view_181 = torch.ops.aten.view.default(view_174, [2, 8192, -1, 128]); view_174 = None + view_182 = torch.ops.aten.view.default(view_177, [2, 8192, -1, 128]); view_177 = None + view_183 = torch.ops.aten.view.default(view_180, [2, 8192, -1, 128]); view_180 = None + convert_element_type_178 = torch.ops.prims.convert_element_type.default(view_181, torch.float32); view_181 = None + view_184 = torch.ops.aten.view.default(convert_element_type_178, [2, 8192, 32, -1, 2]); convert_element_type_178 = None + view_as_complex_10 = torch.ops.aten.view_as_complex.default(view_184); view_184 = None + convert_element_type_179 = torch.ops.prims.convert_element_type.default(view_182, torch.float32); view_182 = None + view_185 = torch.ops.aten.view.default(convert_element_type_179, [2, 8192, 8, -1, 2]); convert_element_type_179 = None + view_as_complex_11 = torch.ops.aten.view_as_complex.default(view_185); view_185 = None + mul_42 = torch.ops.aten.mul.Tensor(view_as_complex_10, view_16); view_as_complex_10 = None + view_as_real_10 = torch.ops.aten.view_as_real.default(mul_42); mul_42 = None + view_187 = torch.ops.aten.view.default(view_as_real_10, [2, 8192, 32, 128]); view_as_real_10 = None + mul_43 = torch.ops.aten.mul.Tensor(view_as_complex_11, view_16); view_as_complex_11 = None + view_as_real_11 = torch.ops.aten.view_as_real.default(mul_43); mul_43 = None + view_188 = torch.ops.aten.view.default(view_as_real_11, [2, 8192, 8, 128]); view_as_real_11 = None + convert_element_type_180 = torch.ops.prims.convert_element_type.default(view_187, torch.bfloat16); view_187 = None + convert_element_type_181 = torch.ops.prims.convert_element_type.default(view_188, torch.bfloat16); view_188 = None + unsqueeze_10 = torch.ops.aten.unsqueeze.default(convert_element_type_181, 3); convert_element_type_181 = None + expand_10 = 
torch.ops.aten.expand.default(unsqueeze_10, [2, 8192, 8, 4, 128]); unsqueeze_10 = None + clone_10 = torch.ops.aten.clone.default(expand_10, memory_format = torch.contiguous_format); expand_10 = None + view_189 = torch.ops.aten.view.default(clone_10, [2, 8192, 32, 128]); clone_10 = None + unsqueeze_11 = torch.ops.aten.unsqueeze.default(view_183, 3); view_183 = None + expand_11 = torch.ops.aten.expand.default(unsqueeze_11, [2, 8192, 8, 4, 128]); unsqueeze_11 = None + clone_11 = torch.ops.aten.clone.default(expand_11, memory_format = torch.contiguous_format); expand_11 = None + view_190 = torch.ops.aten.view.default(clone_11, [2, 8192, 32, 128]); clone_11 = None + permute_58 = torch.ops.aten.permute.default(convert_element_type_180, [0, 2, 1, 3]); convert_element_type_180 = None + permute_59 = torch.ops.aten.permute.default(view_189, [0, 2, 1, 3]); view_189 = None + permute_60 = torch.ops.aten.permute.default(view_190, [0, 2, 1, 3]); view_190 = None + _scaled_dot_product_cudnn_attention_5 = torch.ops.aten._scaled_dot_product_cudnn_attention.default(permute_58, permute_59, permute_60, None, True, 0.0, True); permute_58 = permute_59 = permute_60 = None + getitem_45 = _scaled_dot_product_cudnn_attention_5[0] + getitem_46 = _scaled_dot_product_cudnn_attention_5[1] + getitem_51 = _scaled_dot_product_cudnn_attention_5[6] + getitem_52 = _scaled_dot_product_cudnn_attention_5[7]; _scaled_dot_product_cudnn_attention_5 = None + permute_61 = torch.ops.aten.permute.default(getitem_45, [0, 2, 1, 3]) + view_191 = torch.ops.aten.view.default(permute_61, [2, 8192, -1]); permute_61 = None + convert_element_type_182 = torch.ops.prims.convert_element_type.default(primals_53, torch.bfloat16) + all_gather_into_tensor_50 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_182, 64, '0'); convert_element_type_182 = None + wait_tensor_50 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_50); all_gather_into_tensor_50 = None + permute_62 = 
torch.ops.aten.permute.default(wait_tensor_50, [1, 0]); wait_tensor_50 = None + view_193 = torch.ops.aten.view.default(view_191, [16384, 4096]); view_191 = None + mm_38 = torch.ops.aten.mm.default(view_193, permute_62); view_193 = permute_62 = None + view_194 = torch.ops.aten.view.default(mm_38, [2, 8192, 4096]); mm_38 = None + add_21 = torch.ops.aten.add.Tensor(add_19, view_194); view_194 = None + convert_element_type_185 = torch.ops.prims.convert_element_type.default(primals_54, torch.bfloat16) + all_gather_into_tensor_51 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_185, 64, '0'); convert_element_type_185 = None + wait_tensor_51 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_51); all_gather_into_tensor_51 = None + convert_element_type_186 = torch.ops.prims.convert_element_type.default(add_21, torch.float32) + pow_12 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_186, 2) + mean_11 = torch.ops.aten.mean.dim(pow_12, [2], True); pow_12 = None + add_22 = torch.ops.aten.add.Scalar(mean_11, 1e-05); mean_11 = None + rsqrt_11 = torch.ops.aten.rsqrt.default(add_22); add_22 = None + mul_44 = torch.ops.aten.mul.Tensor(convert_element_type_186, rsqrt_11); convert_element_type_186 = rsqrt_11 = None + mul_45 = torch.ops.aten.mul.Tensor(mul_44, wait_tensor_51); mul_44 = wait_tensor_51 = None + convert_element_type_187 = torch.ops.prims.convert_element_type.default(mul_45, torch.bfloat16); mul_45 = None + convert_element_type_188 = torch.ops.prims.convert_element_type.default(primals_55, torch.bfloat16) + all_gather_into_tensor_52 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_188, 64, '0'); convert_element_type_188 = None + wait_tensor_52 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_52); all_gather_into_tensor_52 = None + permute_63 = torch.ops.aten.permute.default(wait_tensor_52, [1, 0]); wait_tensor_52 = None + view_197 = 
torch.ops.aten.view.default(convert_element_type_187, [16384, 4096]); convert_element_type_187 = None + mm_39 = torch.ops.aten.mm.default(view_197, permute_63); permute_63 = None + view_198 = torch.ops.aten.view.default(mm_39, [2, 8192, 14336]) + convert_element_type_191 = torch.ops.prims.convert_element_type.default(view_198, torch.float32); view_198 = None + sigmoid_5 = torch.ops.aten.sigmoid.default(convert_element_type_191) + mul_46 = torch.ops.aten.mul.Tensor(convert_element_type_191, sigmoid_5); convert_element_type_191 = sigmoid_5 = None + convert_element_type_192 = torch.ops.prims.convert_element_type.default(mul_46, torch.bfloat16); mul_46 = None + convert_element_type_193 = torch.ops.prims.convert_element_type.default(primals_56, torch.bfloat16) + all_gather_into_tensor_53 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_193, 64, '0'); convert_element_type_193 = None + wait_tensor_53 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_53); all_gather_into_tensor_53 = None + permute_64 = torch.ops.aten.permute.default(wait_tensor_53, [1, 0]); wait_tensor_53 = None + mm_40 = torch.ops.aten.mm.default(view_197, permute_64); view_197 = permute_64 = None + view_201 = torch.ops.aten.view.default(mm_40, [2, 8192, 14336]); mm_40 = None + mul_47 = torch.ops.aten.mul.Tensor(convert_element_type_192, view_201); convert_element_type_192 = view_201 = None + convert_element_type_196 = torch.ops.prims.convert_element_type.default(primals_57, torch.bfloat16) + all_gather_into_tensor_54 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_196, 64, '0'); convert_element_type_196 = None + wait_tensor_54 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_54); all_gather_into_tensor_54 = None + permute_65 = torch.ops.aten.permute.default(wait_tensor_54, [1, 0]); wait_tensor_54 = None + view_203 = torch.ops.aten.view.default(mul_47, [16384, 14336]); mul_47 = None + mm_41 = 
torch.ops.aten.mm.default(view_203, permute_65); view_203 = permute_65 = None + view_204 = torch.ops.aten.view.default(mm_41, [2, 8192, 4096]); mm_41 = None + add_23 = torch.ops.aten.add.Tensor(add_21, view_204); add_21 = view_204 = None + convert_element_type_199 = torch.ops.prims.convert_element_type.default(primals_58, torch.bfloat16) + all_gather_into_tensor_55 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_199, 64, '0'); convert_element_type_199 = None + wait_tensor_55 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_55); all_gather_into_tensor_55 = None + convert_element_type_200 = torch.ops.prims.convert_element_type.default(add_23, torch.float32) + pow_13 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_200, 2) + mean_12 = torch.ops.aten.mean.dim(pow_13, [2], True); pow_13 = None + add_24 = torch.ops.aten.add.Scalar(mean_12, 1e-05); mean_12 = None + rsqrt_12 = torch.ops.aten.rsqrt.default(add_24); add_24 = None + mul_48 = torch.ops.aten.mul.Tensor(convert_element_type_200, rsqrt_12); convert_element_type_200 = rsqrt_12 = None + mul_49 = torch.ops.aten.mul.Tensor(mul_48, wait_tensor_55); mul_48 = wait_tensor_55 = None + convert_element_type_201 = torch.ops.prims.convert_element_type.default(mul_49, torch.bfloat16); mul_49 = None + convert_element_type_202 = torch.ops.prims.convert_element_type.default(primals_59, torch.bfloat16) + all_gather_into_tensor_56 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_202, 64, '0'); convert_element_type_202 = None + wait_tensor_56 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_56); all_gather_into_tensor_56 = None + permute_66 = torch.ops.aten.permute.default(wait_tensor_56, [1, 0]); wait_tensor_56 = None + view_207 = torch.ops.aten.view.default(convert_element_type_201, [16384, 4096]); convert_element_type_201 = None + mm_42 = torch.ops.aten.mm.default(view_207, permute_66); permute_66 = None + 
view_208 = torch.ops.aten.view.default(mm_42, [2, 8192, 4096]) + convert_element_type_205 = torch.ops.prims.convert_element_type.default(primals_60, torch.bfloat16) + all_gather_into_tensor_57 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_205, 64, '0'); convert_element_type_205 = None + wait_tensor_57 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_57); all_gather_into_tensor_57 = None + permute_67 = torch.ops.aten.permute.default(wait_tensor_57, [1, 0]); wait_tensor_57 = None + mm_43 = torch.ops.aten.mm.default(view_207, permute_67); permute_67 = None + view_211 = torch.ops.aten.view.default(mm_43, [2, 8192, 1024]); mm_43 = None + convert_element_type_208 = torch.ops.prims.convert_element_type.default(primals_61, torch.bfloat16) + all_gather_into_tensor_58 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_208, 64, '0'); convert_element_type_208 = None + wait_tensor_58 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_58); all_gather_into_tensor_58 = None + permute_68 = torch.ops.aten.permute.default(wait_tensor_58, [1, 0]); wait_tensor_58 = None + mm_44 = torch.ops.aten.mm.default(view_207, permute_68); view_207 = permute_68 = None + view_214 = torch.ops.aten.view.default(mm_44, [2, 8192, 1024]) + view_215 = torch.ops.aten.view.default(view_208, [2, 8192, -1, 128]); view_208 = None + view_216 = torch.ops.aten.view.default(view_211, [2, 8192, -1, 128]); view_211 = None + view_217 = torch.ops.aten.view.default(view_214, [2, 8192, -1, 128]); view_214 = None + convert_element_type_211 = torch.ops.prims.convert_element_type.default(view_215, torch.float32); view_215 = None + view_218 = torch.ops.aten.view.default(convert_element_type_211, [2, 8192, 32, -1, 2]); convert_element_type_211 = None + view_as_complex_12 = torch.ops.aten.view_as_complex.default(view_218); view_218 = None + convert_element_type_212 = 
torch.ops.prims.convert_element_type.default(view_216, torch.float32); view_216 = None + view_219 = torch.ops.aten.view.default(convert_element_type_212, [2, 8192, 8, -1, 2]); convert_element_type_212 = None + view_as_complex_13 = torch.ops.aten.view_as_complex.default(view_219); view_219 = None + mul_50 = torch.ops.aten.mul.Tensor(view_as_complex_12, view_16); view_as_complex_12 = None + view_as_real_12 = torch.ops.aten.view_as_real.default(mul_50); mul_50 = None + view_221 = torch.ops.aten.view.default(view_as_real_12, [2, 8192, 32, 128]); view_as_real_12 = None + mul_51 = torch.ops.aten.mul.Tensor(view_as_complex_13, view_16); view_as_complex_13 = None + view_as_real_13 = torch.ops.aten.view_as_real.default(mul_51); mul_51 = None + view_222 = torch.ops.aten.view.default(view_as_real_13, [2, 8192, 8, 128]); view_as_real_13 = None + convert_element_type_213 = torch.ops.prims.convert_element_type.default(view_221, torch.bfloat16); view_221 = None + convert_element_type_214 = torch.ops.prims.convert_element_type.default(view_222, torch.bfloat16); view_222 = None + unsqueeze_12 = torch.ops.aten.unsqueeze.default(convert_element_type_214, 3); convert_element_type_214 = None + expand_12 = torch.ops.aten.expand.default(unsqueeze_12, [2, 8192, 8, 4, 128]); unsqueeze_12 = None + clone_12 = torch.ops.aten.clone.default(expand_12, memory_format = torch.contiguous_format); expand_12 = None + view_223 = torch.ops.aten.view.default(clone_12, [2, 8192, 32, 128]); clone_12 = None + unsqueeze_13 = torch.ops.aten.unsqueeze.default(view_217, 3); view_217 = None + expand_13 = torch.ops.aten.expand.default(unsqueeze_13, [2, 8192, 8, 4, 128]); unsqueeze_13 = None + clone_13 = torch.ops.aten.clone.default(expand_13, memory_format = torch.contiguous_format); expand_13 = None + view_224 = torch.ops.aten.view.default(clone_13, [2, 8192, 32, 128]); clone_13 = None + permute_69 = torch.ops.aten.permute.default(convert_element_type_213, [0, 2, 1, 3]); convert_element_type_213 = None + 
permute_70 = torch.ops.aten.permute.default(view_223, [0, 2, 1, 3]); view_223 = None + permute_71 = torch.ops.aten.permute.default(view_224, [0, 2, 1, 3]); view_224 = None + _scaled_dot_product_cudnn_attention_6 = torch.ops.aten._scaled_dot_product_cudnn_attention.default(permute_69, permute_70, permute_71, None, True, 0.0, True); permute_69 = permute_70 = permute_71 = None + getitem_54 = _scaled_dot_product_cudnn_attention_6[0] + getitem_55 = _scaled_dot_product_cudnn_attention_6[1] + getitem_60 = _scaled_dot_product_cudnn_attention_6[6] + getitem_61 = _scaled_dot_product_cudnn_attention_6[7]; _scaled_dot_product_cudnn_attention_6 = None + permute_72 = torch.ops.aten.permute.default(getitem_54, [0, 2, 1, 3]) + view_225 = torch.ops.aten.view.default(permute_72, [2, 8192, -1]); permute_72 = None + convert_element_type_215 = torch.ops.prims.convert_element_type.default(primals_62, torch.bfloat16) + all_gather_into_tensor_59 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_215, 64, '0'); convert_element_type_215 = None + wait_tensor_59 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_59); all_gather_into_tensor_59 = None + permute_73 = torch.ops.aten.permute.default(wait_tensor_59, [1, 0]); wait_tensor_59 = None + view_227 = torch.ops.aten.view.default(view_225, [16384, 4096]); view_225 = None + mm_45 = torch.ops.aten.mm.default(view_227, permute_73); view_227 = permute_73 = None + view_228 = torch.ops.aten.view.default(mm_45, [2, 8192, 4096]); mm_45 = None + add_25 = torch.ops.aten.add.Tensor(add_23, view_228); view_228 = None + convert_element_type_218 = torch.ops.prims.convert_element_type.default(primals_63, torch.bfloat16) + all_gather_into_tensor_60 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_218, 64, '0'); convert_element_type_218 = None + wait_tensor_60 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_60); all_gather_into_tensor_60 = None + 
convert_element_type_219 = torch.ops.prims.convert_element_type.default(add_25, torch.float32) + pow_14 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_219, 2) + mean_13 = torch.ops.aten.mean.dim(pow_14, [2], True); pow_14 = None + add_26 = torch.ops.aten.add.Scalar(mean_13, 1e-05); mean_13 = None + rsqrt_13 = torch.ops.aten.rsqrt.default(add_26); add_26 = None + mul_52 = torch.ops.aten.mul.Tensor(convert_element_type_219, rsqrt_13); convert_element_type_219 = rsqrt_13 = None + mul_53 = torch.ops.aten.mul.Tensor(mul_52, wait_tensor_60); mul_52 = wait_tensor_60 = None + convert_element_type_220 = torch.ops.prims.convert_element_type.default(mul_53, torch.bfloat16); mul_53 = None + convert_element_type_221 = torch.ops.prims.convert_element_type.default(primals_64, torch.bfloat16) + all_gather_into_tensor_61 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_221, 64, '0'); convert_element_type_221 = None + wait_tensor_61 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_61); all_gather_into_tensor_61 = None + permute_74 = torch.ops.aten.permute.default(wait_tensor_61, [1, 0]); wait_tensor_61 = None + view_231 = torch.ops.aten.view.default(convert_element_type_220, [16384, 4096]); convert_element_type_220 = None + mm_46 = torch.ops.aten.mm.default(view_231, permute_74); permute_74 = None + view_232 = torch.ops.aten.view.default(mm_46, [2, 8192, 14336]) + convert_element_type_224 = torch.ops.prims.convert_element_type.default(view_232, torch.float32); view_232 = None + sigmoid_6 = torch.ops.aten.sigmoid.default(convert_element_type_224) + mul_54 = torch.ops.aten.mul.Tensor(convert_element_type_224, sigmoid_6); convert_element_type_224 = sigmoid_6 = None + convert_element_type_225 = torch.ops.prims.convert_element_type.default(mul_54, torch.bfloat16); mul_54 = None + convert_element_type_226 = torch.ops.prims.convert_element_type.default(primals_65, torch.bfloat16) + all_gather_into_tensor_62 = 
torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_226, 64, '0'); convert_element_type_226 = None + wait_tensor_62 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_62); all_gather_into_tensor_62 = None + permute_75 = torch.ops.aten.permute.default(wait_tensor_62, [1, 0]); wait_tensor_62 = None + mm_47 = torch.ops.aten.mm.default(view_231, permute_75); view_231 = permute_75 = None + view_235 = torch.ops.aten.view.default(mm_47, [2, 8192, 14336]); mm_47 = None + mul_55 = torch.ops.aten.mul.Tensor(convert_element_type_225, view_235); convert_element_type_225 = view_235 = None + convert_element_type_229 = torch.ops.prims.convert_element_type.default(primals_66, torch.bfloat16) + all_gather_into_tensor_63 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_229, 64, '0'); convert_element_type_229 = None + wait_tensor_63 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_63); all_gather_into_tensor_63 = None + permute_76 = torch.ops.aten.permute.default(wait_tensor_63, [1, 0]); wait_tensor_63 = None + view_237 = torch.ops.aten.view.default(mul_55, [16384, 14336]); mul_55 = None + mm_48 = torch.ops.aten.mm.default(view_237, permute_76); view_237 = permute_76 = None + view_238 = torch.ops.aten.view.default(mm_48, [2, 8192, 4096]); mm_48 = None + add_27 = torch.ops.aten.add.Tensor(add_25, view_238); add_25 = view_238 = None + convert_element_type_232 = torch.ops.prims.convert_element_type.default(primals_67, torch.bfloat16) + all_gather_into_tensor_64 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_232, 64, '0'); convert_element_type_232 = None + wait_tensor_64 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_64); all_gather_into_tensor_64 = None + convert_element_type_233 = torch.ops.prims.convert_element_type.default(add_27, torch.float32) + pow_15 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_233, 2) + 
mean_14 = torch.ops.aten.mean.dim(pow_15, [2], True); pow_15 = None + add_28 = torch.ops.aten.add.Scalar(mean_14, 1e-05); mean_14 = None + rsqrt_14 = torch.ops.aten.rsqrt.default(add_28); add_28 = None + mul_56 = torch.ops.aten.mul.Tensor(convert_element_type_233, rsqrt_14); convert_element_type_233 = rsqrt_14 = None + mul_57 = torch.ops.aten.mul.Tensor(mul_56, wait_tensor_64); mul_56 = wait_tensor_64 = None + convert_element_type_234 = torch.ops.prims.convert_element_type.default(mul_57, torch.bfloat16); mul_57 = None + convert_element_type_235 = torch.ops.prims.convert_element_type.default(primals_68, torch.bfloat16) + all_gather_into_tensor_65 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_235, 64, '0'); convert_element_type_235 = None + wait_tensor_65 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_65); all_gather_into_tensor_65 = None + permute_77 = torch.ops.aten.permute.default(wait_tensor_65, [1, 0]); wait_tensor_65 = None + view_241 = torch.ops.aten.view.default(convert_element_type_234, [16384, 4096]); convert_element_type_234 = None + mm_49 = torch.ops.aten.mm.default(view_241, permute_77); permute_77 = None + view_242 = torch.ops.aten.view.default(mm_49, [2, 8192, 4096]) + convert_element_type_238 = torch.ops.prims.convert_element_type.default(primals_69, torch.bfloat16) + all_gather_into_tensor_66 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_238, 64, '0'); convert_element_type_238 = None + wait_tensor_66 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_66); all_gather_into_tensor_66 = None + permute_78 = torch.ops.aten.permute.default(wait_tensor_66, [1, 0]); wait_tensor_66 = None + mm_50 = torch.ops.aten.mm.default(view_241, permute_78); permute_78 = None + view_245 = torch.ops.aten.view.default(mm_50, [2, 8192, 1024]); mm_50 = None + convert_element_type_241 = torch.ops.prims.convert_element_type.default(primals_70, torch.bfloat16) + 
all_gather_into_tensor_67 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_241, 64, '0'); convert_element_type_241 = None + wait_tensor_67 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_67); all_gather_into_tensor_67 = None + permute_79 = torch.ops.aten.permute.default(wait_tensor_67, [1, 0]); wait_tensor_67 = None + mm_51 = torch.ops.aten.mm.default(view_241, permute_79); view_241 = permute_79 = None + view_248 = torch.ops.aten.view.default(mm_51, [2, 8192, 1024]) + view_249 = torch.ops.aten.view.default(view_242, [2, 8192, -1, 128]); view_242 = None + view_250 = torch.ops.aten.view.default(view_245, [2, 8192, -1, 128]); view_245 = None + view_251 = torch.ops.aten.view.default(view_248, [2, 8192, -1, 128]); view_248 = None + convert_element_type_244 = torch.ops.prims.convert_element_type.default(view_249, torch.float32); view_249 = None + view_252 = torch.ops.aten.view.default(convert_element_type_244, [2, 8192, 32, -1, 2]); convert_element_type_244 = None + view_as_complex_14 = torch.ops.aten.view_as_complex.default(view_252); view_252 = None + convert_element_type_245 = torch.ops.prims.convert_element_type.default(view_250, torch.float32); view_250 = None + view_253 = torch.ops.aten.view.default(convert_element_type_245, [2, 8192, 8, -1, 2]); convert_element_type_245 = None + view_as_complex_15 = torch.ops.aten.view_as_complex.default(view_253); view_253 = None + mul_58 = torch.ops.aten.mul.Tensor(view_as_complex_14, view_16); view_as_complex_14 = None + view_as_real_14 = torch.ops.aten.view_as_real.default(mul_58); mul_58 = None + view_255 = torch.ops.aten.view.default(view_as_real_14, [2, 8192, 32, 128]); view_as_real_14 = None + mul_59 = torch.ops.aten.mul.Tensor(view_as_complex_15, view_16); view_as_complex_15 = None + view_as_real_15 = torch.ops.aten.view_as_real.default(mul_59); mul_59 = None + view_256 = torch.ops.aten.view.default(view_as_real_15, [2, 8192, 8, 128]); view_as_real_15 = None + 
convert_element_type_246 = torch.ops.prims.convert_element_type.default(view_255, torch.bfloat16); view_255 = None + convert_element_type_247 = torch.ops.prims.convert_element_type.default(view_256, torch.bfloat16); view_256 = None + unsqueeze_14 = torch.ops.aten.unsqueeze.default(convert_element_type_247, 3); convert_element_type_247 = None + expand_14 = torch.ops.aten.expand.default(unsqueeze_14, [2, 8192, 8, 4, 128]); unsqueeze_14 = None + clone_14 = torch.ops.aten.clone.default(expand_14, memory_format = torch.contiguous_format); expand_14 = None + view_257 = torch.ops.aten.view.default(clone_14, [2, 8192, 32, 128]); clone_14 = None + unsqueeze_15 = torch.ops.aten.unsqueeze.default(view_251, 3); view_251 = None + expand_15 = torch.ops.aten.expand.default(unsqueeze_15, [2, 8192, 8, 4, 128]); unsqueeze_15 = None + clone_15 = torch.ops.aten.clone.default(expand_15, memory_format = torch.contiguous_format); expand_15 = None + view_258 = torch.ops.aten.view.default(clone_15, [2, 8192, 32, 128]); clone_15 = None + permute_80 = torch.ops.aten.permute.default(convert_element_type_246, [0, 2, 1, 3]); convert_element_type_246 = None + permute_81 = torch.ops.aten.permute.default(view_257, [0, 2, 1, 3]); view_257 = None + permute_82 = torch.ops.aten.permute.default(view_258, [0, 2, 1, 3]); view_258 = None + _scaled_dot_product_cudnn_attention_7 = torch.ops.aten._scaled_dot_product_cudnn_attention.default(permute_80, permute_81, permute_82, None, True, 0.0, True); permute_80 = permute_81 = permute_82 = None + getitem_63 = _scaled_dot_product_cudnn_attention_7[0] + getitem_64 = _scaled_dot_product_cudnn_attention_7[1] + getitem_69 = _scaled_dot_product_cudnn_attention_7[6] + getitem_70 = _scaled_dot_product_cudnn_attention_7[7]; _scaled_dot_product_cudnn_attention_7 = None + permute_83 = torch.ops.aten.permute.default(getitem_63, [0, 2, 1, 3]) + view_259 = torch.ops.aten.view.default(permute_83, [2, 8192, -1]); permute_83 = None + convert_element_type_248 = 
torch.ops.prims.convert_element_type.default(primals_71, torch.bfloat16) + all_gather_into_tensor_68 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_248, 64, '0'); convert_element_type_248 = None + wait_tensor_68 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_68); all_gather_into_tensor_68 = None + permute_84 = torch.ops.aten.permute.default(wait_tensor_68, [1, 0]); wait_tensor_68 = None + view_261 = torch.ops.aten.view.default(view_259, [16384, 4096]); view_259 = None + mm_52 = torch.ops.aten.mm.default(view_261, permute_84); view_261 = permute_84 = None + view_262 = torch.ops.aten.view.default(mm_52, [2, 8192, 4096]); mm_52 = None + add_29 = torch.ops.aten.add.Tensor(add_27, view_262); view_262 = None + convert_element_type_251 = torch.ops.prims.convert_element_type.default(primals_72, torch.bfloat16) + all_gather_into_tensor_69 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_251, 64, '0'); convert_element_type_251 = None + wait_tensor_69 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_69); all_gather_into_tensor_69 = None + convert_element_type_252 = torch.ops.prims.convert_element_type.default(add_29, torch.float32) + pow_16 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_252, 2) + mean_15 = torch.ops.aten.mean.dim(pow_16, [2], True); pow_16 = None + add_30 = torch.ops.aten.add.Scalar(mean_15, 1e-05); mean_15 = None + rsqrt_15 = torch.ops.aten.rsqrt.default(add_30); add_30 = None + mul_60 = torch.ops.aten.mul.Tensor(convert_element_type_252, rsqrt_15); convert_element_type_252 = rsqrt_15 = None + mul_61 = torch.ops.aten.mul.Tensor(mul_60, wait_tensor_69); mul_60 = wait_tensor_69 = None + convert_element_type_253 = torch.ops.prims.convert_element_type.default(mul_61, torch.bfloat16); mul_61 = None + convert_element_type_254 = torch.ops.prims.convert_element_type.default(primals_73, torch.bfloat16) + all_gather_into_tensor_70 = 
torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_254, 64, '0'); convert_element_type_254 = None + wait_tensor_70 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_70); all_gather_into_tensor_70 = None + permute_85 = torch.ops.aten.permute.default(wait_tensor_70, [1, 0]); wait_tensor_70 = None + view_265 = torch.ops.aten.view.default(convert_element_type_253, [16384, 4096]); convert_element_type_253 = None + mm_53 = torch.ops.aten.mm.default(view_265, permute_85); permute_85 = None + view_266 = torch.ops.aten.view.default(mm_53, [2, 8192, 14336]) + convert_element_type_257 = torch.ops.prims.convert_element_type.default(view_266, torch.float32); view_266 = None + sigmoid_7 = torch.ops.aten.sigmoid.default(convert_element_type_257) + mul_62 = torch.ops.aten.mul.Tensor(convert_element_type_257, sigmoid_7); convert_element_type_257 = sigmoid_7 = None + convert_element_type_258 = torch.ops.prims.convert_element_type.default(mul_62, torch.bfloat16); mul_62 = None + convert_element_type_259 = torch.ops.prims.convert_element_type.default(primals_74, torch.bfloat16) + all_gather_into_tensor_71 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_259, 64, '0'); convert_element_type_259 = None + wait_tensor_71 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_71); all_gather_into_tensor_71 = None + permute_86 = torch.ops.aten.permute.default(wait_tensor_71, [1, 0]); wait_tensor_71 = None + mm_54 = torch.ops.aten.mm.default(view_265, permute_86); view_265 = permute_86 = None + view_269 = torch.ops.aten.view.default(mm_54, [2, 8192, 14336]); mm_54 = None + mul_63 = torch.ops.aten.mul.Tensor(convert_element_type_258, view_269); convert_element_type_258 = view_269 = None + convert_element_type_262 = torch.ops.prims.convert_element_type.default(primals_75, torch.bfloat16) + all_gather_into_tensor_72 = 
torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_262, 64, '0'); convert_element_type_262 = None + wait_tensor_72 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_72); all_gather_into_tensor_72 = None + permute_87 = torch.ops.aten.permute.default(wait_tensor_72, [1, 0]); wait_tensor_72 = None + view_271 = torch.ops.aten.view.default(mul_63, [16384, 14336]); mul_63 = None + mm_55 = torch.ops.aten.mm.default(view_271, permute_87); view_271 = permute_87 = None + view_272 = torch.ops.aten.view.default(mm_55, [2, 8192, 4096]); mm_55 = None + add_31 = torch.ops.aten.add.Tensor(add_29, view_272); add_29 = view_272 = None + convert_element_type_265 = torch.ops.prims.convert_element_type.default(primals_76, torch.bfloat16) + all_gather_into_tensor_73 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_265, 64, '0'); convert_element_type_265 = None + wait_tensor_73 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_73); all_gather_into_tensor_73 = None + convert_element_type_266 = torch.ops.prims.convert_element_type.default(add_31, torch.float32) + pow_17 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_266, 2) + mean_16 = torch.ops.aten.mean.dim(pow_17, [2], True); pow_17 = None + add_32 = torch.ops.aten.add.Scalar(mean_16, 1e-05); mean_16 = None + rsqrt_16 = torch.ops.aten.rsqrt.default(add_32); add_32 = None + mul_64 = torch.ops.aten.mul.Tensor(convert_element_type_266, rsqrt_16); convert_element_type_266 = rsqrt_16 = None + mul_65 = torch.ops.aten.mul.Tensor(mul_64, wait_tensor_73); mul_64 = wait_tensor_73 = None + convert_element_type_267 = torch.ops.prims.convert_element_type.default(mul_65, torch.bfloat16); mul_65 = None + convert_element_type_268 = torch.ops.prims.convert_element_type.default(primals_77, torch.bfloat16) + all_gather_into_tensor_74 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_268, 64, '0'); 
convert_element_type_268 = None + wait_tensor_74 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_74); all_gather_into_tensor_74 = None + permute_88 = torch.ops.aten.permute.default(wait_tensor_74, [1, 0]); wait_tensor_74 = None + view_275 = torch.ops.aten.view.default(convert_element_type_267, [16384, 4096]); convert_element_type_267 = None + mm_56 = torch.ops.aten.mm.default(view_275, permute_88); permute_88 = None + view_276 = torch.ops.aten.view.default(mm_56, [2, 8192, 4096]) + convert_element_type_271 = torch.ops.prims.convert_element_type.default(primals_78, torch.bfloat16) + all_gather_into_tensor_75 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_271, 64, '0'); convert_element_type_271 = None + wait_tensor_75 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_75); all_gather_into_tensor_75 = None + permute_89 = torch.ops.aten.permute.default(wait_tensor_75, [1, 0]); wait_tensor_75 = None + mm_57 = torch.ops.aten.mm.default(view_275, permute_89); permute_89 = None + view_279 = torch.ops.aten.view.default(mm_57, [2, 8192, 1024]); mm_57 = None + convert_element_type_274 = torch.ops.prims.convert_element_type.default(primals_79, torch.bfloat16) + all_gather_into_tensor_76 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_274, 64, '0'); convert_element_type_274 = None + wait_tensor_76 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_76); all_gather_into_tensor_76 = None + permute_90 = torch.ops.aten.permute.default(wait_tensor_76, [1, 0]); wait_tensor_76 = None + mm_58 = torch.ops.aten.mm.default(view_275, permute_90); view_275 = permute_90 = None + view_282 = torch.ops.aten.view.default(mm_58, [2, 8192, 1024]) + view_283 = torch.ops.aten.view.default(view_276, [2, 8192, -1, 128]); view_276 = None + view_284 = torch.ops.aten.view.default(view_279, [2, 8192, -1, 128]); view_279 = None + view_285 = 
torch.ops.aten.view.default(view_282, [2, 8192, -1, 128]); view_282 = None + convert_element_type_277 = torch.ops.prims.convert_element_type.default(view_283, torch.float32); view_283 = None + view_286 = torch.ops.aten.view.default(convert_element_type_277, [2, 8192, 32, -1, 2]); convert_element_type_277 = None + view_as_complex_16 = torch.ops.aten.view_as_complex.default(view_286); view_286 = None + convert_element_type_278 = torch.ops.prims.convert_element_type.default(view_284, torch.float32); view_284 = None + view_287 = torch.ops.aten.view.default(convert_element_type_278, [2, 8192, 8, -1, 2]); convert_element_type_278 = None + view_as_complex_17 = torch.ops.aten.view_as_complex.default(view_287); view_287 = None + mul_66 = torch.ops.aten.mul.Tensor(view_as_complex_16, view_16); view_as_complex_16 = None + view_as_real_16 = torch.ops.aten.view_as_real.default(mul_66); mul_66 = None + view_289 = torch.ops.aten.view.default(view_as_real_16, [2, 8192, 32, 128]); view_as_real_16 = None + mul_67 = torch.ops.aten.mul.Tensor(view_as_complex_17, view_16); view_as_complex_17 = None + view_as_real_17 = torch.ops.aten.view_as_real.default(mul_67); mul_67 = None + view_290 = torch.ops.aten.view.default(view_as_real_17, [2, 8192, 8, 128]); view_as_real_17 = None + convert_element_type_279 = torch.ops.prims.convert_element_type.default(view_289, torch.bfloat16); view_289 = None + convert_element_type_280 = torch.ops.prims.convert_element_type.default(view_290, torch.bfloat16); view_290 = None + unsqueeze_16 = torch.ops.aten.unsqueeze.default(convert_element_type_280, 3); convert_element_type_280 = None + expand_16 = torch.ops.aten.expand.default(unsqueeze_16, [2, 8192, 8, 4, 128]); unsqueeze_16 = None + clone_16 = torch.ops.aten.clone.default(expand_16, memory_format = torch.contiguous_format); expand_16 = None + view_291 = torch.ops.aten.view.default(clone_16, [2, 8192, 32, 128]); clone_16 = None + unsqueeze_17 = torch.ops.aten.unsqueeze.default(view_285, 3); view_285 = 
None + expand_17 = torch.ops.aten.expand.default(unsqueeze_17, [2, 8192, 8, 4, 128]); unsqueeze_17 = None + clone_17 = torch.ops.aten.clone.default(expand_17, memory_format = torch.contiguous_format); expand_17 = None + view_292 = torch.ops.aten.view.default(clone_17, [2, 8192, 32, 128]); clone_17 = None + permute_91 = torch.ops.aten.permute.default(convert_element_type_279, [0, 2, 1, 3]); convert_element_type_279 = None + permute_92 = torch.ops.aten.permute.default(view_291, [0, 2, 1, 3]); view_291 = None + permute_93 = torch.ops.aten.permute.default(view_292, [0, 2, 1, 3]); view_292 = None + _scaled_dot_product_cudnn_attention_8 = torch.ops.aten._scaled_dot_product_cudnn_attention.default(permute_91, permute_92, permute_93, None, True, 0.0, True); permute_91 = permute_92 = permute_93 = None + getitem_72 = _scaled_dot_product_cudnn_attention_8[0] + getitem_73 = _scaled_dot_product_cudnn_attention_8[1] + getitem_78 = _scaled_dot_product_cudnn_attention_8[6] + getitem_79 = _scaled_dot_product_cudnn_attention_8[7]; _scaled_dot_product_cudnn_attention_8 = None + permute_94 = torch.ops.aten.permute.default(getitem_72, [0, 2, 1, 3]) + view_293 = torch.ops.aten.view.default(permute_94, [2, 8192, -1]); permute_94 = None + convert_element_type_281 = torch.ops.prims.convert_element_type.default(primals_80, torch.bfloat16) + all_gather_into_tensor_77 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_281, 64, '0'); convert_element_type_281 = None + wait_tensor_77 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_77); all_gather_into_tensor_77 = None + permute_95 = torch.ops.aten.permute.default(wait_tensor_77, [1, 0]); wait_tensor_77 = None + view_295 = torch.ops.aten.view.default(view_293, [16384, 4096]); view_293 = None + mm_59 = torch.ops.aten.mm.default(view_295, permute_95); view_295 = permute_95 = None + view_296 = torch.ops.aten.view.default(mm_59, [2, 8192, 4096]); mm_59 = None + add_33 = 
torch.ops.aten.add.Tensor(add_31, view_296); view_296 = None + convert_element_type_284 = torch.ops.prims.convert_element_type.default(primals_81, torch.bfloat16) + all_gather_into_tensor_78 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_284, 64, '0'); convert_element_type_284 = None + wait_tensor_78 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_78); all_gather_into_tensor_78 = None + convert_element_type_285 = torch.ops.prims.convert_element_type.default(add_33, torch.float32) + pow_18 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_285, 2) + mean_17 = torch.ops.aten.mean.dim(pow_18, [2], True); pow_18 = None + add_34 = torch.ops.aten.add.Scalar(mean_17, 1e-05); mean_17 = None + rsqrt_17 = torch.ops.aten.rsqrt.default(add_34); add_34 = None + mul_68 = torch.ops.aten.mul.Tensor(convert_element_type_285, rsqrt_17); convert_element_type_285 = rsqrt_17 = None + mul_69 = torch.ops.aten.mul.Tensor(mul_68, wait_tensor_78); mul_68 = wait_tensor_78 = None + convert_element_type_286 = torch.ops.prims.convert_element_type.default(mul_69, torch.bfloat16); mul_69 = None + convert_element_type_287 = torch.ops.prims.convert_element_type.default(primals_82, torch.bfloat16) + all_gather_into_tensor_79 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_287, 64, '0'); convert_element_type_287 = None + wait_tensor_79 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_79); all_gather_into_tensor_79 = None + permute_96 = torch.ops.aten.permute.default(wait_tensor_79, [1, 0]); wait_tensor_79 = None + view_299 = torch.ops.aten.view.default(convert_element_type_286, [16384, 4096]); convert_element_type_286 = None + mm_60 = torch.ops.aten.mm.default(view_299, permute_96); permute_96 = None + view_300 = torch.ops.aten.view.default(mm_60, [2, 8192, 14336]) + convert_element_type_290 = torch.ops.prims.convert_element_type.default(view_300, torch.float32); view_300 = None + 
sigmoid_8 = torch.ops.aten.sigmoid.default(convert_element_type_290) + mul_70 = torch.ops.aten.mul.Tensor(convert_element_type_290, sigmoid_8); convert_element_type_290 = sigmoid_8 = None + convert_element_type_291 = torch.ops.prims.convert_element_type.default(mul_70, torch.bfloat16); mul_70 = None + convert_element_type_292 = torch.ops.prims.convert_element_type.default(primals_83, torch.bfloat16) + all_gather_into_tensor_80 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_292, 64, '0'); convert_element_type_292 = None + wait_tensor_80 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_80); all_gather_into_tensor_80 = None + permute_97 = torch.ops.aten.permute.default(wait_tensor_80, [1, 0]); wait_tensor_80 = None + mm_61 = torch.ops.aten.mm.default(view_299, permute_97); view_299 = permute_97 = None + view_303 = torch.ops.aten.view.default(mm_61, [2, 8192, 14336]); mm_61 = None + mul_71 = torch.ops.aten.mul.Tensor(convert_element_type_291, view_303); convert_element_type_291 = view_303 = None + convert_element_type_295 = torch.ops.prims.convert_element_type.default(primals_84, torch.bfloat16) + all_gather_into_tensor_81 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_295, 64, '0'); convert_element_type_295 = None + wait_tensor_81 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_81); all_gather_into_tensor_81 = None + permute_98 = torch.ops.aten.permute.default(wait_tensor_81, [1, 0]); wait_tensor_81 = None + view_305 = torch.ops.aten.view.default(mul_71, [16384, 14336]); mul_71 = None + mm_62 = torch.ops.aten.mm.default(view_305, permute_98); view_305 = permute_98 = None + view_306 = torch.ops.aten.view.default(mm_62, [2, 8192, 4096]); mm_62 = None + add_35 = torch.ops.aten.add.Tensor(add_33, view_306); add_33 = view_306 = None + convert_element_type_298 = torch.ops.prims.convert_element_type.default(primals_85, torch.bfloat16) + all_gather_into_tensor_82 
= torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_298, 64, '0'); convert_element_type_298 = None + wait_tensor_82 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_82); all_gather_into_tensor_82 = None + convert_element_type_299 = torch.ops.prims.convert_element_type.default(add_35, torch.float32) + pow_19 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_299, 2) + mean_18 = torch.ops.aten.mean.dim(pow_19, [2], True); pow_19 = None + add_36 = torch.ops.aten.add.Scalar(mean_18, 1e-05); mean_18 = None + rsqrt_18 = torch.ops.aten.rsqrt.default(add_36); add_36 = None + mul_72 = torch.ops.aten.mul.Tensor(convert_element_type_299, rsqrt_18); convert_element_type_299 = rsqrt_18 = None + mul_73 = torch.ops.aten.mul.Tensor(mul_72, wait_tensor_82); mul_72 = wait_tensor_82 = None + convert_element_type_300 = torch.ops.prims.convert_element_type.default(mul_73, torch.bfloat16); mul_73 = None + convert_element_type_301 = torch.ops.prims.convert_element_type.default(primals_86, torch.bfloat16) + all_gather_into_tensor_83 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_301, 64, '0'); convert_element_type_301 = None + wait_tensor_83 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_83); all_gather_into_tensor_83 = None + permute_99 = torch.ops.aten.permute.default(wait_tensor_83, [1, 0]); wait_tensor_83 = None + view_309 = torch.ops.aten.view.default(convert_element_type_300, [16384, 4096]); convert_element_type_300 = None + mm_63 = torch.ops.aten.mm.default(view_309, permute_99); permute_99 = None + view_310 = torch.ops.aten.view.default(mm_63, [2, 8192, 4096]) + convert_element_type_304 = torch.ops.prims.convert_element_type.default(primals_87, torch.bfloat16) + all_gather_into_tensor_84 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_304, 64, '0'); convert_element_type_304 = None + wait_tensor_84 = 
torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_84); all_gather_into_tensor_84 = None + permute_100 = torch.ops.aten.permute.default(wait_tensor_84, [1, 0]); wait_tensor_84 = None + mm_64 = torch.ops.aten.mm.default(view_309, permute_100); permute_100 = None + view_313 = torch.ops.aten.view.default(mm_64, [2, 8192, 1024]); mm_64 = None + convert_element_type_307 = torch.ops.prims.convert_element_type.default(primals_88, torch.bfloat16) + all_gather_into_tensor_85 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_307, 64, '0'); convert_element_type_307 = None + wait_tensor_85 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_85); all_gather_into_tensor_85 = None + permute_101 = torch.ops.aten.permute.default(wait_tensor_85, [1, 0]); wait_tensor_85 = None + mm_65 = torch.ops.aten.mm.default(view_309, permute_101); view_309 = permute_101 = None + view_316 = torch.ops.aten.view.default(mm_65, [2, 8192, 1024]) + view_317 = torch.ops.aten.view.default(view_310, [2, 8192, -1, 128]); view_310 = None + view_318 = torch.ops.aten.view.default(view_313, [2, 8192, -1, 128]); view_313 = None + view_319 = torch.ops.aten.view.default(view_316, [2, 8192, -1, 128]); view_316 = None + convert_element_type_310 = torch.ops.prims.convert_element_type.default(view_317, torch.float32); view_317 = None + view_320 = torch.ops.aten.view.default(convert_element_type_310, [2, 8192, 32, -1, 2]); convert_element_type_310 = None + view_as_complex_18 = torch.ops.aten.view_as_complex.default(view_320); view_320 = None + convert_element_type_311 = torch.ops.prims.convert_element_type.default(view_318, torch.float32); view_318 = None + view_321 = torch.ops.aten.view.default(convert_element_type_311, [2, 8192, 8, -1, 2]); convert_element_type_311 = None + view_as_complex_19 = torch.ops.aten.view_as_complex.default(view_321); view_321 = None + mul_74 = torch.ops.aten.mul.Tensor(view_as_complex_18, view_16); view_as_complex_18 = 
None + view_as_real_18 = torch.ops.aten.view_as_real.default(mul_74); mul_74 = None + view_323 = torch.ops.aten.view.default(view_as_real_18, [2, 8192, 32, 128]); view_as_real_18 = None + mul_75 = torch.ops.aten.mul.Tensor(view_as_complex_19, view_16); view_as_complex_19 = None + view_as_real_19 = torch.ops.aten.view_as_real.default(mul_75); mul_75 = None + view_324 = torch.ops.aten.view.default(view_as_real_19, [2, 8192, 8, 128]); view_as_real_19 = None + convert_element_type_312 = torch.ops.prims.convert_element_type.default(view_323, torch.bfloat16); view_323 = None + convert_element_type_313 = torch.ops.prims.convert_element_type.default(view_324, torch.bfloat16); view_324 = None + unsqueeze_18 = torch.ops.aten.unsqueeze.default(convert_element_type_313, 3); convert_element_type_313 = None + expand_18 = torch.ops.aten.expand.default(unsqueeze_18, [2, 8192, 8, 4, 128]); unsqueeze_18 = None + clone_18 = torch.ops.aten.clone.default(expand_18, memory_format = torch.contiguous_format); expand_18 = None + view_325 = torch.ops.aten.view.default(clone_18, [2, 8192, 32, 128]); clone_18 = None + unsqueeze_19 = torch.ops.aten.unsqueeze.default(view_319, 3); view_319 = None + expand_19 = torch.ops.aten.expand.default(unsqueeze_19, [2, 8192, 8, 4, 128]); unsqueeze_19 = None + clone_19 = torch.ops.aten.clone.default(expand_19, memory_format = torch.contiguous_format); expand_19 = None + view_326 = torch.ops.aten.view.default(clone_19, [2, 8192, 32, 128]); clone_19 = None + permute_102 = torch.ops.aten.permute.default(convert_element_type_312, [0, 2, 1, 3]); convert_element_type_312 = None + permute_103 = torch.ops.aten.permute.default(view_325, [0, 2, 1, 3]); view_325 = None + permute_104 = torch.ops.aten.permute.default(view_326, [0, 2, 1, 3]); view_326 = None + _scaled_dot_product_cudnn_attention_9 = torch.ops.aten._scaled_dot_product_cudnn_attention.default(permute_102, permute_103, permute_104, None, True, 0.0, True); permute_102 = permute_103 = permute_104 = None + 
getitem_81 = _scaled_dot_product_cudnn_attention_9[0] + getitem_82 = _scaled_dot_product_cudnn_attention_9[1] + getitem_87 = _scaled_dot_product_cudnn_attention_9[6] + getitem_88 = _scaled_dot_product_cudnn_attention_9[7]; _scaled_dot_product_cudnn_attention_9 = None + permute_105 = torch.ops.aten.permute.default(getitem_81, [0, 2, 1, 3]) + view_327 = torch.ops.aten.view.default(permute_105, [2, 8192, -1]); permute_105 = None + convert_element_type_314 = torch.ops.prims.convert_element_type.default(primals_89, torch.bfloat16) + all_gather_into_tensor_86 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_314, 64, '0'); convert_element_type_314 = None + wait_tensor_86 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_86); all_gather_into_tensor_86 = None + permute_106 = torch.ops.aten.permute.default(wait_tensor_86, [1, 0]); wait_tensor_86 = None + view_329 = torch.ops.aten.view.default(view_327, [16384, 4096]); view_327 = None + mm_66 = torch.ops.aten.mm.default(view_329, permute_106); view_329 = permute_106 = None + view_330 = torch.ops.aten.view.default(mm_66, [2, 8192, 4096]); mm_66 = None + add_37 = torch.ops.aten.add.Tensor(add_35, view_330); view_330 = None + convert_element_type_317 = torch.ops.prims.convert_element_type.default(primals_90, torch.bfloat16) + all_gather_into_tensor_87 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_317, 64, '0'); convert_element_type_317 = None + wait_tensor_87 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_87); all_gather_into_tensor_87 = None + convert_element_type_318 = torch.ops.prims.convert_element_type.default(add_37, torch.float32) + pow_20 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_318, 2) + mean_19 = torch.ops.aten.mean.dim(pow_20, [2], True); pow_20 = None + add_38 = torch.ops.aten.add.Scalar(mean_19, 1e-05); mean_19 = None + rsqrt_19 = torch.ops.aten.rsqrt.default(add_38); add_38 = None + 
mul_76 = torch.ops.aten.mul.Tensor(convert_element_type_318, rsqrt_19); convert_element_type_318 = rsqrt_19 = None + mul_77 = torch.ops.aten.mul.Tensor(mul_76, wait_tensor_87); mul_76 = wait_tensor_87 = None + convert_element_type_319 = torch.ops.prims.convert_element_type.default(mul_77, torch.bfloat16); mul_77 = None + convert_element_type_320 = torch.ops.prims.convert_element_type.default(primals_91, torch.bfloat16) + all_gather_into_tensor_88 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_320, 64, '0'); convert_element_type_320 = None + wait_tensor_88 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_88); all_gather_into_tensor_88 = None + permute_107 = torch.ops.aten.permute.default(wait_tensor_88, [1, 0]); wait_tensor_88 = None + view_333 = torch.ops.aten.view.default(convert_element_type_319, [16384, 4096]); convert_element_type_319 = None + mm_67 = torch.ops.aten.mm.default(view_333, permute_107); permute_107 = None + view_334 = torch.ops.aten.view.default(mm_67, [2, 8192, 14336]) + convert_element_type_323 = torch.ops.prims.convert_element_type.default(view_334, torch.float32); view_334 = None + sigmoid_9 = torch.ops.aten.sigmoid.default(convert_element_type_323) + mul_78 = torch.ops.aten.mul.Tensor(convert_element_type_323, sigmoid_9); convert_element_type_323 = sigmoid_9 = None + convert_element_type_324 = torch.ops.prims.convert_element_type.default(mul_78, torch.bfloat16); mul_78 = None + convert_element_type_325 = torch.ops.prims.convert_element_type.default(primals_92, torch.bfloat16) + all_gather_into_tensor_89 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_325, 64, '0'); convert_element_type_325 = None + wait_tensor_89 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_89); all_gather_into_tensor_89 = None + permute_108 = torch.ops.aten.permute.default(wait_tensor_89, [1, 0]); wait_tensor_89 = None + mm_68 = 
torch.ops.aten.mm.default(view_333, permute_108); view_333 = permute_108 = None + view_337 = torch.ops.aten.view.default(mm_68, [2, 8192, 14336]); mm_68 = None + mul_79 = torch.ops.aten.mul.Tensor(convert_element_type_324, view_337); convert_element_type_324 = view_337 = None + convert_element_type_328 = torch.ops.prims.convert_element_type.default(primals_93, torch.bfloat16) + all_gather_into_tensor_90 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_328, 64, '0'); convert_element_type_328 = None + wait_tensor_90 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_90); all_gather_into_tensor_90 = None + permute_109 = torch.ops.aten.permute.default(wait_tensor_90, [1, 0]); wait_tensor_90 = None + view_339 = torch.ops.aten.view.default(mul_79, [16384, 14336]); mul_79 = None + mm_69 = torch.ops.aten.mm.default(view_339, permute_109); view_339 = permute_109 = None + view_340 = torch.ops.aten.view.default(mm_69, [2, 8192, 4096]); mm_69 = None + add_39 = torch.ops.aten.add.Tensor(add_37, view_340); add_37 = view_340 = None + convert_element_type_331 = torch.ops.prims.convert_element_type.default(primals_94, torch.bfloat16) + all_gather_into_tensor_91 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_331, 64, '0'); convert_element_type_331 = None + wait_tensor_91 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_91); all_gather_into_tensor_91 = None + convert_element_type_332 = torch.ops.prims.convert_element_type.default(add_39, torch.float32) + pow_21 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_332, 2) + mean_20 = torch.ops.aten.mean.dim(pow_21, [2], True); pow_21 = None + add_40 = torch.ops.aten.add.Scalar(mean_20, 1e-05); mean_20 = None + rsqrt_20 = torch.ops.aten.rsqrt.default(add_40); add_40 = None + mul_80 = torch.ops.aten.mul.Tensor(convert_element_type_332, rsqrt_20); convert_element_type_332 = rsqrt_20 = None + mul_81 = 
torch.ops.aten.mul.Tensor(mul_80, wait_tensor_91); mul_80 = wait_tensor_91 = None + convert_element_type_333 = torch.ops.prims.convert_element_type.default(mul_81, torch.bfloat16); mul_81 = None + convert_element_type_334 = torch.ops.prims.convert_element_type.default(primals_95, torch.bfloat16) + all_gather_into_tensor_92 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_334, 64, '0'); convert_element_type_334 = None + wait_tensor_92 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_92); all_gather_into_tensor_92 = None + permute_110 = torch.ops.aten.permute.default(wait_tensor_92, [1, 0]); wait_tensor_92 = None + view_343 = torch.ops.aten.view.default(convert_element_type_333, [16384, 4096]); convert_element_type_333 = None + mm_70 = torch.ops.aten.mm.default(view_343, permute_110); permute_110 = None + view_344 = torch.ops.aten.view.default(mm_70, [2, 8192, 4096]) + convert_element_type_337 = torch.ops.prims.convert_element_type.default(primals_96, torch.bfloat16) + all_gather_into_tensor_93 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_337, 64, '0'); convert_element_type_337 = None + wait_tensor_93 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_93); all_gather_into_tensor_93 = None + permute_111 = torch.ops.aten.permute.default(wait_tensor_93, [1, 0]); wait_tensor_93 = None + mm_71 = torch.ops.aten.mm.default(view_343, permute_111); permute_111 = None + view_347 = torch.ops.aten.view.default(mm_71, [2, 8192, 1024]); mm_71 = None + convert_element_type_340 = torch.ops.prims.convert_element_type.default(primals_97, torch.bfloat16) + all_gather_into_tensor_94 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_340, 64, '0'); convert_element_type_340 = None + wait_tensor_94 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_94); all_gather_into_tensor_94 = None + permute_112 = 
torch.ops.aten.permute.default(wait_tensor_94, [1, 0]); wait_tensor_94 = None + mm_72 = torch.ops.aten.mm.default(view_343, permute_112); view_343 = permute_112 = None + view_350 = torch.ops.aten.view.default(mm_72, [2, 8192, 1024]) + view_351 = torch.ops.aten.view.default(view_344, [2, 8192, -1, 128]); view_344 = None + view_352 = torch.ops.aten.view.default(view_347, [2, 8192, -1, 128]); view_347 = None + view_353 = torch.ops.aten.view.default(view_350, [2, 8192, -1, 128]); view_350 = None + convert_element_type_343 = torch.ops.prims.convert_element_type.default(view_351, torch.float32); view_351 = None + view_354 = torch.ops.aten.view.default(convert_element_type_343, [2, 8192, 32, -1, 2]); convert_element_type_343 = None + view_as_complex_20 = torch.ops.aten.view_as_complex.default(view_354); view_354 = None + convert_element_type_344 = torch.ops.prims.convert_element_type.default(view_352, torch.float32); view_352 = None + view_355 = torch.ops.aten.view.default(convert_element_type_344, [2, 8192, 8, -1, 2]); convert_element_type_344 = None + view_as_complex_21 = torch.ops.aten.view_as_complex.default(view_355); view_355 = None + mul_82 = torch.ops.aten.mul.Tensor(view_as_complex_20, view_16); view_as_complex_20 = None + view_as_real_20 = torch.ops.aten.view_as_real.default(mul_82); mul_82 = None + view_357 = torch.ops.aten.view.default(view_as_real_20, [2, 8192, 32, 128]); view_as_real_20 = None + mul_83 = torch.ops.aten.mul.Tensor(view_as_complex_21, view_16); view_as_complex_21 = None + view_as_real_21 = torch.ops.aten.view_as_real.default(mul_83); mul_83 = None + view_358 = torch.ops.aten.view.default(view_as_real_21, [2, 8192, 8, 128]); view_as_real_21 = None + convert_element_type_345 = torch.ops.prims.convert_element_type.default(view_357, torch.bfloat16); view_357 = None + convert_element_type_346 = torch.ops.prims.convert_element_type.default(view_358, torch.bfloat16); view_358 = None + unsqueeze_20 = 
torch.ops.aten.unsqueeze.default(convert_element_type_346, 3); convert_element_type_346 = None + expand_20 = torch.ops.aten.expand.default(unsqueeze_20, [2, 8192, 8, 4, 128]); unsqueeze_20 = None + clone_20 = torch.ops.aten.clone.default(expand_20, memory_format = torch.contiguous_format); expand_20 = None + view_359 = torch.ops.aten.view.default(clone_20, [2, 8192, 32, 128]); clone_20 = None + unsqueeze_21 = torch.ops.aten.unsqueeze.default(view_353, 3); view_353 = None + expand_21 = torch.ops.aten.expand.default(unsqueeze_21, [2, 8192, 8, 4, 128]); unsqueeze_21 = None + clone_21 = torch.ops.aten.clone.default(expand_21, memory_format = torch.contiguous_format); expand_21 = None + view_360 = torch.ops.aten.view.default(clone_21, [2, 8192, 32, 128]); clone_21 = None + permute_113 = torch.ops.aten.permute.default(convert_element_type_345, [0, 2, 1, 3]); convert_element_type_345 = None + permute_114 = torch.ops.aten.permute.default(view_359, [0, 2, 1, 3]); view_359 = None + permute_115 = torch.ops.aten.permute.default(view_360, [0, 2, 1, 3]); view_360 = None + _scaled_dot_product_cudnn_attention_10 = torch.ops.aten._scaled_dot_product_cudnn_attention.default(permute_113, permute_114, permute_115, None, True, 0.0, True); permute_113 = permute_114 = permute_115 = None + getitem_90 = _scaled_dot_product_cudnn_attention_10[0] + getitem_91 = _scaled_dot_product_cudnn_attention_10[1] + getitem_96 = _scaled_dot_product_cudnn_attention_10[6] + getitem_97 = _scaled_dot_product_cudnn_attention_10[7]; _scaled_dot_product_cudnn_attention_10 = None + permute_116 = torch.ops.aten.permute.default(getitem_90, [0, 2, 1, 3]) + view_361 = torch.ops.aten.view.default(permute_116, [2, 8192, -1]); permute_116 = None + convert_element_type_347 = torch.ops.prims.convert_element_type.default(primals_98, torch.bfloat16) + all_gather_into_tensor_95 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_347, 64, '0'); convert_element_type_347 = None + wait_tensor_95 = 
torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_95); all_gather_into_tensor_95 = None + permute_117 = torch.ops.aten.permute.default(wait_tensor_95, [1, 0]); wait_tensor_95 = None + view_363 = torch.ops.aten.view.default(view_361, [16384, 4096]); view_361 = None + mm_73 = torch.ops.aten.mm.default(view_363, permute_117); view_363 = permute_117 = None + view_364 = torch.ops.aten.view.default(mm_73, [2, 8192, 4096]); mm_73 = None + add_41 = torch.ops.aten.add.Tensor(add_39, view_364); view_364 = None + convert_element_type_350 = torch.ops.prims.convert_element_type.default(primals_99, torch.bfloat16) + all_gather_into_tensor_96 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_350, 64, '0'); convert_element_type_350 = None + wait_tensor_96 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_96); all_gather_into_tensor_96 = None + convert_element_type_351 = torch.ops.prims.convert_element_type.default(add_41, torch.float32) + pow_22 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_351, 2) + mean_21 = torch.ops.aten.mean.dim(pow_22, [2], True); pow_22 = None + add_42 = torch.ops.aten.add.Scalar(mean_21, 1e-05); mean_21 = None + rsqrt_21 = torch.ops.aten.rsqrt.default(add_42); add_42 = None + mul_84 = torch.ops.aten.mul.Tensor(convert_element_type_351, rsqrt_21); convert_element_type_351 = rsqrt_21 = None + mul_85 = torch.ops.aten.mul.Tensor(mul_84, wait_tensor_96); mul_84 = wait_tensor_96 = None + convert_element_type_352 = torch.ops.prims.convert_element_type.default(mul_85, torch.bfloat16); mul_85 = None + convert_element_type_353 = torch.ops.prims.convert_element_type.default(primals_100, torch.bfloat16) + all_gather_into_tensor_97 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_353, 64, '0'); convert_element_type_353 = None + wait_tensor_97 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_97); all_gather_into_tensor_97 = None + 
permute_118 = torch.ops.aten.permute.default(wait_tensor_97, [1, 0]); wait_tensor_97 = None + view_367 = torch.ops.aten.view.default(convert_element_type_352, [16384, 4096]); convert_element_type_352 = None + mm_74 = torch.ops.aten.mm.default(view_367, permute_118); permute_118 = None + view_368 = torch.ops.aten.view.default(mm_74, [2, 8192, 14336]) + convert_element_type_356 = torch.ops.prims.convert_element_type.default(view_368, torch.float32); view_368 = None + sigmoid_10 = torch.ops.aten.sigmoid.default(convert_element_type_356) + mul_86 = torch.ops.aten.mul.Tensor(convert_element_type_356, sigmoid_10); convert_element_type_356 = sigmoid_10 = None + convert_element_type_357 = torch.ops.prims.convert_element_type.default(mul_86, torch.bfloat16); mul_86 = None + convert_element_type_358 = torch.ops.prims.convert_element_type.default(primals_101, torch.bfloat16) + all_gather_into_tensor_98 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_358, 64, '0'); convert_element_type_358 = None + wait_tensor_98 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_98); all_gather_into_tensor_98 = None + permute_119 = torch.ops.aten.permute.default(wait_tensor_98, [1, 0]); wait_tensor_98 = None + mm_75 = torch.ops.aten.mm.default(view_367, permute_119); view_367 = permute_119 = None + view_371 = torch.ops.aten.view.default(mm_75, [2, 8192, 14336]); mm_75 = None + mul_87 = torch.ops.aten.mul.Tensor(convert_element_type_357, view_371); convert_element_type_357 = view_371 = None + convert_element_type_361 = torch.ops.prims.convert_element_type.default(primals_102, torch.bfloat16) + all_gather_into_tensor_99 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_361, 64, '0'); convert_element_type_361 = None + wait_tensor_99 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_99); all_gather_into_tensor_99 = None + permute_120 = torch.ops.aten.permute.default(wait_tensor_99, [1, 0]); 
wait_tensor_99 = None + view_373 = torch.ops.aten.view.default(mul_87, [16384, 14336]); mul_87 = None + mm_76 = torch.ops.aten.mm.default(view_373, permute_120); view_373 = permute_120 = None + view_374 = torch.ops.aten.view.default(mm_76, [2, 8192, 4096]); mm_76 = None + add_43 = torch.ops.aten.add.Tensor(add_41, view_374); add_41 = view_374 = None + convert_element_type_364 = torch.ops.prims.convert_element_type.default(primals_103, torch.bfloat16) + all_gather_into_tensor_100 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_364, 64, '0'); convert_element_type_364 = None + wait_tensor_100 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_100); all_gather_into_tensor_100 = None + convert_element_type_365 = torch.ops.prims.convert_element_type.default(add_43, torch.float32) + pow_23 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_365, 2) + mean_22 = torch.ops.aten.mean.dim(pow_23, [2], True); pow_23 = None + add_44 = torch.ops.aten.add.Scalar(mean_22, 1e-05); mean_22 = None + rsqrt_22 = torch.ops.aten.rsqrt.default(add_44); add_44 = None + mul_88 = torch.ops.aten.mul.Tensor(convert_element_type_365, rsqrt_22); convert_element_type_365 = rsqrt_22 = None + mul_89 = torch.ops.aten.mul.Tensor(mul_88, wait_tensor_100); mul_88 = wait_tensor_100 = None + convert_element_type_366 = torch.ops.prims.convert_element_type.default(mul_89, torch.bfloat16); mul_89 = None + convert_element_type_367 = torch.ops.prims.convert_element_type.default(primals_104, torch.bfloat16) + all_gather_into_tensor_101 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_367, 64, '0'); convert_element_type_367 = None + wait_tensor_101 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_101); all_gather_into_tensor_101 = None + permute_121 = torch.ops.aten.permute.default(wait_tensor_101, [1, 0]); wait_tensor_101 = None + view_377 = torch.ops.aten.view.default(convert_element_type_366, 
[16384, 4096]); convert_element_type_366 = None + mm_77 = torch.ops.aten.mm.default(view_377, permute_121); permute_121 = None + view_378 = torch.ops.aten.view.default(mm_77, [2, 8192, 4096]) + convert_element_type_370 = torch.ops.prims.convert_element_type.default(primals_105, torch.bfloat16) + all_gather_into_tensor_102 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_370, 64, '0'); convert_element_type_370 = None + wait_tensor_102 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_102); all_gather_into_tensor_102 = None + permute_122 = torch.ops.aten.permute.default(wait_tensor_102, [1, 0]); wait_tensor_102 = None + mm_78 = torch.ops.aten.mm.default(view_377, permute_122); permute_122 = None + view_381 = torch.ops.aten.view.default(mm_78, [2, 8192, 1024]); mm_78 = None + convert_element_type_373 = torch.ops.prims.convert_element_type.default(primals_106, torch.bfloat16) + all_gather_into_tensor_103 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_373, 64, '0'); convert_element_type_373 = None + wait_tensor_103 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_103); all_gather_into_tensor_103 = None + permute_123 = torch.ops.aten.permute.default(wait_tensor_103, [1, 0]); wait_tensor_103 = None + mm_79 = torch.ops.aten.mm.default(view_377, permute_123); view_377 = permute_123 = None + view_384 = torch.ops.aten.view.default(mm_79, [2, 8192, 1024]) + view_385 = torch.ops.aten.view.default(view_378, [2, 8192, -1, 128]); view_378 = None + view_386 = torch.ops.aten.view.default(view_381, [2, 8192, -1, 128]); view_381 = None + view_387 = torch.ops.aten.view.default(view_384, [2, 8192, -1, 128]); view_384 = None + convert_element_type_376 = torch.ops.prims.convert_element_type.default(view_385, torch.float32); view_385 = None + view_388 = torch.ops.aten.view.default(convert_element_type_376, [2, 8192, 32, -1, 2]); convert_element_type_376 = None + view_as_complex_22 
= torch.ops.aten.view_as_complex.default(view_388); view_388 = None + convert_element_type_377 = torch.ops.prims.convert_element_type.default(view_386, torch.float32); view_386 = None + view_389 = torch.ops.aten.view.default(convert_element_type_377, [2, 8192, 8, -1, 2]); convert_element_type_377 = None + view_as_complex_23 = torch.ops.aten.view_as_complex.default(view_389); view_389 = None + mul_90 = torch.ops.aten.mul.Tensor(view_as_complex_22, view_16); view_as_complex_22 = None + view_as_real_22 = torch.ops.aten.view_as_real.default(mul_90); mul_90 = None + view_391 = torch.ops.aten.view.default(view_as_real_22, [2, 8192, 32, 128]); view_as_real_22 = None + mul_91 = torch.ops.aten.mul.Tensor(view_as_complex_23, view_16); view_as_complex_23 = None + view_as_real_23 = torch.ops.aten.view_as_real.default(mul_91); mul_91 = None + view_392 = torch.ops.aten.view.default(view_as_real_23, [2, 8192, 8, 128]); view_as_real_23 = None + convert_element_type_378 = torch.ops.prims.convert_element_type.default(view_391, torch.bfloat16); view_391 = None + convert_element_type_379 = torch.ops.prims.convert_element_type.default(view_392, torch.bfloat16); view_392 = None + unsqueeze_22 = torch.ops.aten.unsqueeze.default(convert_element_type_379, 3); convert_element_type_379 = None + expand_22 = torch.ops.aten.expand.default(unsqueeze_22, [2, 8192, 8, 4, 128]); unsqueeze_22 = None + clone_22 = torch.ops.aten.clone.default(expand_22, memory_format = torch.contiguous_format); expand_22 = None + view_393 = torch.ops.aten.view.default(clone_22, [2, 8192, 32, 128]); clone_22 = None + unsqueeze_23 = torch.ops.aten.unsqueeze.default(view_387, 3); view_387 = None + expand_23 = torch.ops.aten.expand.default(unsqueeze_23, [2, 8192, 8, 4, 128]); unsqueeze_23 = None + clone_23 = torch.ops.aten.clone.default(expand_23, memory_format = torch.contiguous_format); expand_23 = None + view_394 = torch.ops.aten.view.default(clone_23, [2, 8192, 32, 128]); clone_23 = None + permute_124 = 
torch.ops.aten.permute.default(convert_element_type_378, [0, 2, 1, 3]); convert_element_type_378 = None + permute_125 = torch.ops.aten.permute.default(view_393, [0, 2, 1, 3]); view_393 = None + permute_126 = torch.ops.aten.permute.default(view_394, [0, 2, 1, 3]); view_394 = None + _scaled_dot_product_cudnn_attention_11 = torch.ops.aten._scaled_dot_product_cudnn_attention.default(permute_124, permute_125, permute_126, None, True, 0.0, True); permute_124 = permute_125 = permute_126 = None + getitem_99 = _scaled_dot_product_cudnn_attention_11[0] + getitem_100 = _scaled_dot_product_cudnn_attention_11[1] + getitem_105 = _scaled_dot_product_cudnn_attention_11[6] + getitem_106 = _scaled_dot_product_cudnn_attention_11[7]; _scaled_dot_product_cudnn_attention_11 = None + permute_127 = torch.ops.aten.permute.default(getitem_99, [0, 2, 1, 3]) + view_395 = torch.ops.aten.view.default(permute_127, [2, 8192, -1]); permute_127 = None + convert_element_type_380 = torch.ops.prims.convert_element_type.default(primals_107, torch.bfloat16) + all_gather_into_tensor_104 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_380, 64, '0'); convert_element_type_380 = None + wait_tensor_104 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_104); all_gather_into_tensor_104 = None + permute_128 = torch.ops.aten.permute.default(wait_tensor_104, [1, 0]); wait_tensor_104 = None + view_397 = torch.ops.aten.view.default(view_395, [16384, 4096]); view_395 = None + mm_80 = torch.ops.aten.mm.default(view_397, permute_128); view_397 = permute_128 = None + view_398 = torch.ops.aten.view.default(mm_80, [2, 8192, 4096]); mm_80 = None + add_45 = torch.ops.aten.add.Tensor(add_43, view_398); view_398 = None + convert_element_type_383 = torch.ops.prims.convert_element_type.default(primals_108, torch.bfloat16) + all_gather_into_tensor_105 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_383, 64, '0'); convert_element_type_383 = 
None + wait_tensor_105 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_105); all_gather_into_tensor_105 = None + convert_element_type_384 = torch.ops.prims.convert_element_type.default(add_45, torch.float32) + pow_24 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_384, 2) + mean_23 = torch.ops.aten.mean.dim(pow_24, [2], True); pow_24 = None + add_46 = torch.ops.aten.add.Scalar(mean_23, 1e-05); mean_23 = None + rsqrt_23 = torch.ops.aten.rsqrt.default(add_46); add_46 = None + mul_92 = torch.ops.aten.mul.Tensor(convert_element_type_384, rsqrt_23); convert_element_type_384 = rsqrt_23 = None + mul_93 = torch.ops.aten.mul.Tensor(mul_92, wait_tensor_105); mul_92 = wait_tensor_105 = None + convert_element_type_385 = torch.ops.prims.convert_element_type.default(mul_93, torch.bfloat16); mul_93 = None + convert_element_type_386 = torch.ops.prims.convert_element_type.default(primals_109, torch.bfloat16) + all_gather_into_tensor_106 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_386, 64, '0'); convert_element_type_386 = None + wait_tensor_106 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_106); all_gather_into_tensor_106 = None + permute_129 = torch.ops.aten.permute.default(wait_tensor_106, [1, 0]); wait_tensor_106 = None + view_401 = torch.ops.aten.view.default(convert_element_type_385, [16384, 4096]); convert_element_type_385 = None + mm_81 = torch.ops.aten.mm.default(view_401, permute_129); permute_129 = None + view_402 = torch.ops.aten.view.default(mm_81, [2, 8192, 14336]) + convert_element_type_389 = torch.ops.prims.convert_element_type.default(view_402, torch.float32); view_402 = None + sigmoid_11 = torch.ops.aten.sigmoid.default(convert_element_type_389) + mul_94 = torch.ops.aten.mul.Tensor(convert_element_type_389, sigmoid_11); convert_element_type_389 = sigmoid_11 = None + convert_element_type_390 = torch.ops.prims.convert_element_type.default(mul_94, torch.bfloat16); mul_94 = 
None + convert_element_type_391 = torch.ops.prims.convert_element_type.default(primals_110, torch.bfloat16) + all_gather_into_tensor_107 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_391, 64, '0'); convert_element_type_391 = None + wait_tensor_107 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_107); all_gather_into_tensor_107 = None + permute_130 = torch.ops.aten.permute.default(wait_tensor_107, [1, 0]); wait_tensor_107 = None + mm_82 = torch.ops.aten.mm.default(view_401, permute_130); view_401 = permute_130 = None + view_405 = torch.ops.aten.view.default(mm_82, [2, 8192, 14336]); mm_82 = None + mul_95 = torch.ops.aten.mul.Tensor(convert_element_type_390, view_405); convert_element_type_390 = view_405 = None + convert_element_type_394 = torch.ops.prims.convert_element_type.default(primals_111, torch.bfloat16) + all_gather_into_tensor_108 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_394, 64, '0'); convert_element_type_394 = None + wait_tensor_108 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_108); all_gather_into_tensor_108 = None + permute_131 = torch.ops.aten.permute.default(wait_tensor_108, [1, 0]); wait_tensor_108 = None + view_407 = torch.ops.aten.view.default(mul_95, [16384, 14336]); mul_95 = None + mm_83 = torch.ops.aten.mm.default(view_407, permute_131); view_407 = permute_131 = None + view_408 = torch.ops.aten.view.default(mm_83, [2, 8192, 4096]); mm_83 = None + add_47 = torch.ops.aten.add.Tensor(add_45, view_408); add_45 = view_408 = None + convert_element_type_397 = torch.ops.prims.convert_element_type.default(primals_112, torch.bfloat16) + all_gather_into_tensor_109 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_397, 64, '0'); convert_element_type_397 = None + wait_tensor_109 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_109); all_gather_into_tensor_109 = None + 
convert_element_type_398 = torch.ops.prims.convert_element_type.default(add_47, torch.float32) + pow_25 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_398, 2) + mean_24 = torch.ops.aten.mean.dim(pow_25, [2], True); pow_25 = None + add_48 = torch.ops.aten.add.Scalar(mean_24, 1e-05); mean_24 = None + rsqrt_24 = torch.ops.aten.rsqrt.default(add_48); add_48 = None + mul_96 = torch.ops.aten.mul.Tensor(convert_element_type_398, rsqrt_24); convert_element_type_398 = rsqrt_24 = None + mul_97 = torch.ops.aten.mul.Tensor(mul_96, wait_tensor_109); mul_96 = wait_tensor_109 = None + convert_element_type_399 = torch.ops.prims.convert_element_type.default(mul_97, torch.bfloat16); mul_97 = None + convert_element_type_400 = torch.ops.prims.convert_element_type.default(primals_113, torch.bfloat16) + all_gather_into_tensor_110 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_400, 64, '0'); convert_element_type_400 = None + wait_tensor_110 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_110); all_gather_into_tensor_110 = None + permute_132 = torch.ops.aten.permute.default(wait_tensor_110, [1, 0]); wait_tensor_110 = None + view_411 = torch.ops.aten.view.default(convert_element_type_399, [16384, 4096]); convert_element_type_399 = None + mm_84 = torch.ops.aten.mm.default(view_411, permute_132); permute_132 = None + view_412 = torch.ops.aten.view.default(mm_84, [2, 8192, 4096]) + convert_element_type_403 = torch.ops.prims.convert_element_type.default(primals_114, torch.bfloat16) + all_gather_into_tensor_111 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_403, 64, '0'); convert_element_type_403 = None + wait_tensor_111 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_111); all_gather_into_tensor_111 = None + permute_133 = torch.ops.aten.permute.default(wait_tensor_111, [1, 0]); wait_tensor_111 = None + mm_85 = torch.ops.aten.mm.default(view_411, permute_133); permute_133 
= None + view_415 = torch.ops.aten.view.default(mm_85, [2, 8192, 1024]); mm_85 = None + convert_element_type_406 = torch.ops.prims.convert_element_type.default(primals_115, torch.bfloat16) + all_gather_into_tensor_112 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_406, 64, '0'); convert_element_type_406 = None + wait_tensor_112 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_112); all_gather_into_tensor_112 = None + permute_134 = torch.ops.aten.permute.default(wait_tensor_112, [1, 0]); wait_tensor_112 = None + mm_86 = torch.ops.aten.mm.default(view_411, permute_134); view_411 = permute_134 = None + view_418 = torch.ops.aten.view.default(mm_86, [2, 8192, 1024]) + view_419 = torch.ops.aten.view.default(view_412, [2, 8192, -1, 128]); view_412 = None + view_420 = torch.ops.aten.view.default(view_415, [2, 8192, -1, 128]); view_415 = None + view_421 = torch.ops.aten.view.default(view_418, [2, 8192, -1, 128]); view_418 = None + convert_element_type_409 = torch.ops.prims.convert_element_type.default(view_419, torch.float32); view_419 = None + view_422 = torch.ops.aten.view.default(convert_element_type_409, [2, 8192, 32, -1, 2]); convert_element_type_409 = None + view_as_complex_24 = torch.ops.aten.view_as_complex.default(view_422); view_422 = None + convert_element_type_410 = torch.ops.prims.convert_element_type.default(view_420, torch.float32); view_420 = None + view_423 = torch.ops.aten.view.default(convert_element_type_410, [2, 8192, 8, -1, 2]); convert_element_type_410 = None + view_as_complex_25 = torch.ops.aten.view_as_complex.default(view_423); view_423 = None + mul_98 = torch.ops.aten.mul.Tensor(view_as_complex_24, view_16); view_as_complex_24 = None + view_as_real_24 = torch.ops.aten.view_as_real.default(mul_98); mul_98 = None + view_425 = torch.ops.aten.view.default(view_as_real_24, [2, 8192, 32, 128]); view_as_real_24 = None + mul_99 = torch.ops.aten.mul.Tensor(view_as_complex_25, view_16); 
view_as_complex_25 = None + view_as_real_25 = torch.ops.aten.view_as_real.default(mul_99); mul_99 = None + view_426 = torch.ops.aten.view.default(view_as_real_25, [2, 8192, 8, 128]); view_as_real_25 = None + convert_element_type_411 = torch.ops.prims.convert_element_type.default(view_425, torch.bfloat16); view_425 = None + convert_element_type_412 = torch.ops.prims.convert_element_type.default(view_426, torch.bfloat16); view_426 = None + unsqueeze_24 = torch.ops.aten.unsqueeze.default(convert_element_type_412, 3); convert_element_type_412 = None + expand_24 = torch.ops.aten.expand.default(unsqueeze_24, [2, 8192, 8, 4, 128]); unsqueeze_24 = None + clone_24 = torch.ops.aten.clone.default(expand_24, memory_format = torch.contiguous_format); expand_24 = None + view_427 = torch.ops.aten.view.default(clone_24, [2, 8192, 32, 128]); clone_24 = None + unsqueeze_25 = torch.ops.aten.unsqueeze.default(view_421, 3); view_421 = None + expand_25 = torch.ops.aten.expand.default(unsqueeze_25, [2, 8192, 8, 4, 128]); unsqueeze_25 = None + clone_25 = torch.ops.aten.clone.default(expand_25, memory_format = torch.contiguous_format); expand_25 = None + view_428 = torch.ops.aten.view.default(clone_25, [2, 8192, 32, 128]); clone_25 = None + permute_135 = torch.ops.aten.permute.default(convert_element_type_411, [0, 2, 1, 3]); convert_element_type_411 = None + permute_136 = torch.ops.aten.permute.default(view_427, [0, 2, 1, 3]); view_427 = None + permute_137 = torch.ops.aten.permute.default(view_428, [0, 2, 1, 3]); view_428 = None + _scaled_dot_product_cudnn_attention_12 = torch.ops.aten._scaled_dot_product_cudnn_attention.default(permute_135, permute_136, permute_137, None, True, 0.0, True); permute_135 = permute_136 = permute_137 = None + getitem_108 = _scaled_dot_product_cudnn_attention_12[0] + getitem_109 = _scaled_dot_product_cudnn_attention_12[1] + getitem_114 = _scaled_dot_product_cudnn_attention_12[6] + getitem_115 = _scaled_dot_product_cudnn_attention_12[7]; 
_scaled_dot_product_cudnn_attention_12 = None + permute_138 = torch.ops.aten.permute.default(getitem_108, [0, 2, 1, 3]) + view_429 = torch.ops.aten.view.default(permute_138, [2, 8192, -1]); permute_138 = None + convert_element_type_413 = torch.ops.prims.convert_element_type.default(primals_116, torch.bfloat16) + all_gather_into_tensor_113 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_413, 64, '0'); convert_element_type_413 = None + wait_tensor_113 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_113); all_gather_into_tensor_113 = None + permute_139 = torch.ops.aten.permute.default(wait_tensor_113, [1, 0]); wait_tensor_113 = None + view_431 = torch.ops.aten.view.default(view_429, [16384, 4096]); view_429 = None + mm_87 = torch.ops.aten.mm.default(view_431, permute_139); view_431 = permute_139 = None + view_432 = torch.ops.aten.view.default(mm_87, [2, 8192, 4096]); mm_87 = None + add_49 = torch.ops.aten.add.Tensor(add_47, view_432); view_432 = None + convert_element_type_416 = torch.ops.prims.convert_element_type.default(primals_117, torch.bfloat16) + all_gather_into_tensor_114 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_416, 64, '0'); convert_element_type_416 = None + wait_tensor_114 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_114); all_gather_into_tensor_114 = None + convert_element_type_417 = torch.ops.prims.convert_element_type.default(add_49, torch.float32) + pow_26 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_417, 2) + mean_25 = torch.ops.aten.mean.dim(pow_26, [2], True); pow_26 = None + add_50 = torch.ops.aten.add.Scalar(mean_25, 1e-05); mean_25 = None + rsqrt_25 = torch.ops.aten.rsqrt.default(add_50); add_50 = None + mul_100 = torch.ops.aten.mul.Tensor(convert_element_type_417, rsqrt_25); convert_element_type_417 = rsqrt_25 = None + mul_101 = torch.ops.aten.mul.Tensor(mul_100, wait_tensor_114); mul_100 = wait_tensor_114 = None 
+ convert_element_type_418 = torch.ops.prims.convert_element_type.default(mul_101, torch.bfloat16); mul_101 = None + convert_element_type_419 = torch.ops.prims.convert_element_type.default(primals_118, torch.bfloat16) + all_gather_into_tensor_115 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_419, 64, '0'); convert_element_type_419 = None + wait_tensor_115 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_115); all_gather_into_tensor_115 = None + permute_140 = torch.ops.aten.permute.default(wait_tensor_115, [1, 0]); wait_tensor_115 = None + view_435 = torch.ops.aten.view.default(convert_element_type_418, [16384, 4096]); convert_element_type_418 = None + mm_88 = torch.ops.aten.mm.default(view_435, permute_140); permute_140 = None + view_436 = torch.ops.aten.view.default(mm_88, [2, 8192, 14336]) + convert_element_type_422 = torch.ops.prims.convert_element_type.default(view_436, torch.float32); view_436 = None + sigmoid_12 = torch.ops.aten.sigmoid.default(convert_element_type_422) + mul_102 = torch.ops.aten.mul.Tensor(convert_element_type_422, sigmoid_12); convert_element_type_422 = sigmoid_12 = None + convert_element_type_423 = torch.ops.prims.convert_element_type.default(mul_102, torch.bfloat16); mul_102 = None + convert_element_type_424 = torch.ops.prims.convert_element_type.default(primals_119, torch.bfloat16) + all_gather_into_tensor_116 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_424, 64, '0'); convert_element_type_424 = None + wait_tensor_116 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_116); all_gather_into_tensor_116 = None + permute_141 = torch.ops.aten.permute.default(wait_tensor_116, [1, 0]); wait_tensor_116 = None + mm_89 = torch.ops.aten.mm.default(view_435, permute_141); view_435 = permute_141 = None + view_439 = torch.ops.aten.view.default(mm_89, [2, 8192, 14336]); mm_89 = None + mul_103 = 
torch.ops.aten.mul.Tensor(convert_element_type_423, view_439); convert_element_type_423 = view_439 = None + convert_element_type_427 = torch.ops.prims.convert_element_type.default(primals_120, torch.bfloat16) + all_gather_into_tensor_117 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_427, 64, '0'); convert_element_type_427 = None + wait_tensor_117 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_117); all_gather_into_tensor_117 = None + permute_142 = torch.ops.aten.permute.default(wait_tensor_117, [1, 0]); wait_tensor_117 = None + view_441 = torch.ops.aten.view.default(mul_103, [16384, 14336]); mul_103 = None + mm_90 = torch.ops.aten.mm.default(view_441, permute_142); view_441 = permute_142 = None + view_442 = torch.ops.aten.view.default(mm_90, [2, 8192, 4096]); mm_90 = None + add_51 = torch.ops.aten.add.Tensor(add_49, view_442); add_49 = view_442 = None + convert_element_type_430 = torch.ops.prims.convert_element_type.default(primals_121, torch.bfloat16) + all_gather_into_tensor_118 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_430, 64, '0'); convert_element_type_430 = None + wait_tensor_118 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_118); all_gather_into_tensor_118 = None + convert_element_type_431 = torch.ops.prims.convert_element_type.default(add_51, torch.float32) + pow_27 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_431, 2) + mean_26 = torch.ops.aten.mean.dim(pow_27, [2], True); pow_27 = None + add_52 = torch.ops.aten.add.Scalar(mean_26, 1e-05); mean_26 = None + rsqrt_26 = torch.ops.aten.rsqrt.default(add_52); add_52 = None + mul_104 = torch.ops.aten.mul.Tensor(convert_element_type_431, rsqrt_26); convert_element_type_431 = rsqrt_26 = None + mul_105 = torch.ops.aten.mul.Tensor(mul_104, wait_tensor_118); mul_104 = wait_tensor_118 = None + convert_element_type_432 = torch.ops.prims.convert_element_type.default(mul_105, 
torch.bfloat16); mul_105 = None + convert_element_type_433 = torch.ops.prims.convert_element_type.default(primals_122, torch.bfloat16) + all_gather_into_tensor_119 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_433, 64, '0'); convert_element_type_433 = None + wait_tensor_119 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_119); all_gather_into_tensor_119 = None + permute_143 = torch.ops.aten.permute.default(wait_tensor_119, [1, 0]); wait_tensor_119 = None + view_445 = torch.ops.aten.view.default(convert_element_type_432, [16384, 4096]); convert_element_type_432 = None + mm_91 = torch.ops.aten.mm.default(view_445, permute_143); permute_143 = None + view_446 = torch.ops.aten.view.default(mm_91, [2, 8192, 4096]) + convert_element_type_436 = torch.ops.prims.convert_element_type.default(primals_123, torch.bfloat16) + all_gather_into_tensor_120 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_436, 64, '0'); convert_element_type_436 = None + wait_tensor_120 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_120); all_gather_into_tensor_120 = None + permute_144 = torch.ops.aten.permute.default(wait_tensor_120, [1, 0]); wait_tensor_120 = None + mm_92 = torch.ops.aten.mm.default(view_445, permute_144); permute_144 = None + view_449 = torch.ops.aten.view.default(mm_92, [2, 8192, 1024]); mm_92 = None + convert_element_type_439 = torch.ops.prims.convert_element_type.default(primals_124, torch.bfloat16) + all_gather_into_tensor_121 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_439, 64, '0'); convert_element_type_439 = None + wait_tensor_121 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_121); all_gather_into_tensor_121 = None + permute_145 = torch.ops.aten.permute.default(wait_tensor_121, [1, 0]); wait_tensor_121 = None + mm_93 = torch.ops.aten.mm.default(view_445, permute_145); view_445 = permute_145 = None + 
view_452 = torch.ops.aten.view.default(mm_93, [2, 8192, 1024]) + view_453 = torch.ops.aten.view.default(view_446, [2, 8192, -1, 128]); view_446 = None + view_454 = torch.ops.aten.view.default(view_449, [2, 8192, -1, 128]); view_449 = None + view_455 = torch.ops.aten.view.default(view_452, [2, 8192, -1, 128]); view_452 = None + convert_element_type_442 = torch.ops.prims.convert_element_type.default(view_453, torch.float32); view_453 = None + view_456 = torch.ops.aten.view.default(convert_element_type_442, [2, 8192, 32, -1, 2]); convert_element_type_442 = None + view_as_complex_26 = torch.ops.aten.view_as_complex.default(view_456); view_456 = None + convert_element_type_443 = torch.ops.prims.convert_element_type.default(view_454, torch.float32); view_454 = None + view_457 = torch.ops.aten.view.default(convert_element_type_443, [2, 8192, 8, -1, 2]); convert_element_type_443 = None + view_as_complex_27 = torch.ops.aten.view_as_complex.default(view_457); view_457 = None + mul_106 = torch.ops.aten.mul.Tensor(view_as_complex_26, view_16); view_as_complex_26 = None + view_as_real_26 = torch.ops.aten.view_as_real.default(mul_106); mul_106 = None + view_459 = torch.ops.aten.view.default(view_as_real_26, [2, 8192, 32, 128]); view_as_real_26 = None + mul_107 = torch.ops.aten.mul.Tensor(view_as_complex_27, view_16); view_as_complex_27 = None + view_as_real_27 = torch.ops.aten.view_as_real.default(mul_107); mul_107 = None + view_460 = torch.ops.aten.view.default(view_as_real_27, [2, 8192, 8, 128]); view_as_real_27 = None + convert_element_type_444 = torch.ops.prims.convert_element_type.default(view_459, torch.bfloat16); view_459 = None + convert_element_type_445 = torch.ops.prims.convert_element_type.default(view_460, torch.bfloat16); view_460 = None + unsqueeze_26 = torch.ops.aten.unsqueeze.default(convert_element_type_445, 3); convert_element_type_445 = None + expand_26 = torch.ops.aten.expand.default(unsqueeze_26, [2, 8192, 8, 4, 128]); unsqueeze_26 = None + clone_26 = 
torch.ops.aten.clone.default(expand_26, memory_format = torch.contiguous_format); expand_26 = None + view_461 = torch.ops.aten.view.default(clone_26, [2, 8192, 32, 128]); clone_26 = None + unsqueeze_27 = torch.ops.aten.unsqueeze.default(view_455, 3); view_455 = None + expand_27 = torch.ops.aten.expand.default(unsqueeze_27, [2, 8192, 8, 4, 128]); unsqueeze_27 = None + clone_27 = torch.ops.aten.clone.default(expand_27, memory_format = torch.contiguous_format); expand_27 = None + view_462 = torch.ops.aten.view.default(clone_27, [2, 8192, 32, 128]); clone_27 = None + permute_146 = torch.ops.aten.permute.default(convert_element_type_444, [0, 2, 1, 3]); convert_element_type_444 = None + permute_147 = torch.ops.aten.permute.default(view_461, [0, 2, 1, 3]); view_461 = None + permute_148 = torch.ops.aten.permute.default(view_462, [0, 2, 1, 3]); view_462 = None + _scaled_dot_product_cudnn_attention_13 = torch.ops.aten._scaled_dot_product_cudnn_attention.default(permute_146, permute_147, permute_148, None, True, 0.0, True); permute_146 = permute_147 = permute_148 = None + getitem_117 = _scaled_dot_product_cudnn_attention_13[0] + getitem_118 = _scaled_dot_product_cudnn_attention_13[1] + getitem_123 = _scaled_dot_product_cudnn_attention_13[6] + getitem_124 = _scaled_dot_product_cudnn_attention_13[7]; _scaled_dot_product_cudnn_attention_13 = None + permute_149 = torch.ops.aten.permute.default(getitem_117, [0, 2, 1, 3]) + view_463 = torch.ops.aten.view.default(permute_149, [2, 8192, -1]); permute_149 = None + convert_element_type_446 = torch.ops.prims.convert_element_type.default(primals_125, torch.bfloat16) + all_gather_into_tensor_122 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_446, 64, '0'); convert_element_type_446 = None + wait_tensor_122 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_122); all_gather_into_tensor_122 = None + permute_150 = torch.ops.aten.permute.default(wait_tensor_122, [1, 0]); wait_tensor_122 = 
None + view_465 = torch.ops.aten.view.default(view_463, [16384, 4096]); view_463 = None + mm_94 = torch.ops.aten.mm.default(view_465, permute_150); view_465 = permute_150 = None + view_466 = torch.ops.aten.view.default(mm_94, [2, 8192, 4096]); mm_94 = None + add_53 = torch.ops.aten.add.Tensor(add_51, view_466); view_466 = None + convert_element_type_449 = torch.ops.prims.convert_element_type.default(primals_126, torch.bfloat16) + all_gather_into_tensor_123 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_449, 64, '0'); convert_element_type_449 = None + wait_tensor_123 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_123); all_gather_into_tensor_123 = None + convert_element_type_450 = torch.ops.prims.convert_element_type.default(add_53, torch.float32) + pow_28 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_450, 2) + mean_27 = torch.ops.aten.mean.dim(pow_28, [2], True); pow_28 = None + add_54 = torch.ops.aten.add.Scalar(mean_27, 1e-05); mean_27 = None + rsqrt_27 = torch.ops.aten.rsqrt.default(add_54); add_54 = None + mul_108 = torch.ops.aten.mul.Tensor(convert_element_type_450, rsqrt_27); convert_element_type_450 = rsqrt_27 = None + mul_109 = torch.ops.aten.mul.Tensor(mul_108, wait_tensor_123); mul_108 = wait_tensor_123 = None + convert_element_type_451 = torch.ops.prims.convert_element_type.default(mul_109, torch.bfloat16); mul_109 = None + convert_element_type_452 = torch.ops.prims.convert_element_type.default(primals_127, torch.bfloat16) + all_gather_into_tensor_124 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_452, 64, '0'); convert_element_type_452 = None + wait_tensor_124 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_124); all_gather_into_tensor_124 = None + permute_151 = torch.ops.aten.permute.default(wait_tensor_124, [1, 0]); wait_tensor_124 = None + view_469 = torch.ops.aten.view.default(convert_element_type_451, [16384, 4096]); 
convert_element_type_451 = None + mm_95 = torch.ops.aten.mm.default(view_469, permute_151); permute_151 = None + view_470 = torch.ops.aten.view.default(mm_95, [2, 8192, 14336]) + convert_element_type_455 = torch.ops.prims.convert_element_type.default(view_470, torch.float32); view_470 = None + sigmoid_13 = torch.ops.aten.sigmoid.default(convert_element_type_455) + mul_110 = torch.ops.aten.mul.Tensor(convert_element_type_455, sigmoid_13); convert_element_type_455 = sigmoid_13 = None + convert_element_type_456 = torch.ops.prims.convert_element_type.default(mul_110, torch.bfloat16); mul_110 = None + convert_element_type_457 = torch.ops.prims.convert_element_type.default(primals_128, torch.bfloat16) + all_gather_into_tensor_125 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_457, 64, '0'); convert_element_type_457 = None + wait_tensor_125 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_125); all_gather_into_tensor_125 = None + permute_152 = torch.ops.aten.permute.default(wait_tensor_125, [1, 0]); wait_tensor_125 = None + mm_96 = torch.ops.aten.mm.default(view_469, permute_152); view_469 = permute_152 = None + view_473 = torch.ops.aten.view.default(mm_96, [2, 8192, 14336]); mm_96 = None + mul_111 = torch.ops.aten.mul.Tensor(convert_element_type_456, view_473); convert_element_type_456 = view_473 = None + convert_element_type_460 = torch.ops.prims.convert_element_type.default(primals_129, torch.bfloat16) + all_gather_into_tensor_126 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_460, 64, '0'); convert_element_type_460 = None + wait_tensor_126 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_126); all_gather_into_tensor_126 = None + permute_153 = torch.ops.aten.permute.default(wait_tensor_126, [1, 0]); wait_tensor_126 = None + view_475 = torch.ops.aten.view.default(mul_111, [16384, 14336]); mul_111 = None + mm_97 = torch.ops.aten.mm.default(view_475, 
permute_153); view_475 = permute_153 = None + view_476 = torch.ops.aten.view.default(mm_97, [2, 8192, 4096]); mm_97 = None + add_55 = torch.ops.aten.add.Tensor(add_53, view_476); add_53 = view_476 = None + convert_element_type_463 = torch.ops.prims.convert_element_type.default(primals_130, torch.bfloat16) + all_gather_into_tensor_127 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_463, 64, '0'); convert_element_type_463 = None + wait_tensor_127 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_127); all_gather_into_tensor_127 = None + convert_element_type_464 = torch.ops.prims.convert_element_type.default(add_55, torch.float32) + pow_29 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_464, 2) + mean_28 = torch.ops.aten.mean.dim(pow_29, [2], True); pow_29 = None + add_56 = torch.ops.aten.add.Scalar(mean_28, 1e-05); mean_28 = None + rsqrt_28 = torch.ops.aten.rsqrt.default(add_56); add_56 = None + mul_112 = torch.ops.aten.mul.Tensor(convert_element_type_464, rsqrt_28); convert_element_type_464 = rsqrt_28 = None + mul_113 = torch.ops.aten.mul.Tensor(mul_112, wait_tensor_127); mul_112 = wait_tensor_127 = None + convert_element_type_465 = torch.ops.prims.convert_element_type.default(mul_113, torch.bfloat16); mul_113 = None + convert_element_type_466 = torch.ops.prims.convert_element_type.default(primals_131, torch.bfloat16) + all_gather_into_tensor_128 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_466, 64, '0'); convert_element_type_466 = None + wait_tensor_128 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_128); all_gather_into_tensor_128 = None + permute_154 = torch.ops.aten.permute.default(wait_tensor_128, [1, 0]); wait_tensor_128 = None + view_479 = torch.ops.aten.view.default(convert_element_type_465, [16384, 4096]); convert_element_type_465 = None + mm_98 = torch.ops.aten.mm.default(view_479, permute_154); permute_154 = None + view_480 = 
torch.ops.aten.view.default(mm_98, [2, 8192, 4096]) + convert_element_type_469 = torch.ops.prims.convert_element_type.default(primals_132, torch.bfloat16) + all_gather_into_tensor_129 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_469, 64, '0'); convert_element_type_469 = None + wait_tensor_129 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_129); all_gather_into_tensor_129 = None + permute_155 = torch.ops.aten.permute.default(wait_tensor_129, [1, 0]); wait_tensor_129 = None + mm_99 = torch.ops.aten.mm.default(view_479, permute_155); permute_155 = None + view_483 = torch.ops.aten.view.default(mm_99, [2, 8192, 1024]); mm_99 = None + convert_element_type_472 = torch.ops.prims.convert_element_type.default(primals_133, torch.bfloat16) + all_gather_into_tensor_130 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_472, 64, '0'); convert_element_type_472 = None + wait_tensor_130 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_130); all_gather_into_tensor_130 = None + permute_156 = torch.ops.aten.permute.default(wait_tensor_130, [1, 0]); wait_tensor_130 = None + mm_100 = torch.ops.aten.mm.default(view_479, permute_156); view_479 = permute_156 = None + view_486 = torch.ops.aten.view.default(mm_100, [2, 8192, 1024]) + view_487 = torch.ops.aten.view.default(view_480, [2, 8192, -1, 128]); view_480 = None + view_488 = torch.ops.aten.view.default(view_483, [2, 8192, -1, 128]); view_483 = None + view_489 = torch.ops.aten.view.default(view_486, [2, 8192, -1, 128]); view_486 = None + convert_element_type_475 = torch.ops.prims.convert_element_type.default(view_487, torch.float32); view_487 = None + view_490 = torch.ops.aten.view.default(convert_element_type_475, [2, 8192, 32, -1, 2]); convert_element_type_475 = None + view_as_complex_28 = torch.ops.aten.view_as_complex.default(view_490); view_490 = None + convert_element_type_476 = 
torch.ops.prims.convert_element_type.default(view_488, torch.float32); view_488 = None + view_491 = torch.ops.aten.view.default(convert_element_type_476, [2, 8192, 8, -1, 2]); convert_element_type_476 = None + view_as_complex_29 = torch.ops.aten.view_as_complex.default(view_491); view_491 = None + mul_114 = torch.ops.aten.mul.Tensor(view_as_complex_28, view_16); view_as_complex_28 = None + view_as_real_28 = torch.ops.aten.view_as_real.default(mul_114); mul_114 = None + view_493 = torch.ops.aten.view.default(view_as_real_28, [2, 8192, 32, 128]); view_as_real_28 = None + mul_115 = torch.ops.aten.mul.Tensor(view_as_complex_29, view_16); view_as_complex_29 = None + view_as_real_29 = torch.ops.aten.view_as_real.default(mul_115); mul_115 = None + view_494 = torch.ops.aten.view.default(view_as_real_29, [2, 8192, 8, 128]); view_as_real_29 = None + convert_element_type_477 = torch.ops.prims.convert_element_type.default(view_493, torch.bfloat16); view_493 = None + convert_element_type_478 = torch.ops.prims.convert_element_type.default(view_494, torch.bfloat16); view_494 = None + unsqueeze_28 = torch.ops.aten.unsqueeze.default(convert_element_type_478, 3); convert_element_type_478 = None + expand_28 = torch.ops.aten.expand.default(unsqueeze_28, [2, 8192, 8, 4, 128]); unsqueeze_28 = None + clone_28 = torch.ops.aten.clone.default(expand_28, memory_format = torch.contiguous_format); expand_28 = None + view_495 = torch.ops.aten.view.default(clone_28, [2, 8192, 32, 128]); clone_28 = None + unsqueeze_29 = torch.ops.aten.unsqueeze.default(view_489, 3); view_489 = None + expand_29 = torch.ops.aten.expand.default(unsqueeze_29, [2, 8192, 8, 4, 128]); unsqueeze_29 = None + clone_29 = torch.ops.aten.clone.default(expand_29, memory_format = torch.contiguous_format); expand_29 = None + view_496 = torch.ops.aten.view.default(clone_29, [2, 8192, 32, 128]); clone_29 = None + permute_157 = torch.ops.aten.permute.default(convert_element_type_477, [0, 2, 1, 3]); convert_element_type_477 = None + 
permute_158 = torch.ops.aten.permute.default(view_495, [0, 2, 1, 3]); view_495 = None + permute_159 = torch.ops.aten.permute.default(view_496, [0, 2, 1, 3]); view_496 = None + _scaled_dot_product_cudnn_attention_14 = torch.ops.aten._scaled_dot_product_cudnn_attention.default(permute_157, permute_158, permute_159, None, True, 0.0, True); permute_157 = permute_158 = permute_159 = None + getitem_126 = _scaled_dot_product_cudnn_attention_14[0] + getitem_127 = _scaled_dot_product_cudnn_attention_14[1] + getitem_132 = _scaled_dot_product_cudnn_attention_14[6] + getitem_133 = _scaled_dot_product_cudnn_attention_14[7]; _scaled_dot_product_cudnn_attention_14 = None + permute_160 = torch.ops.aten.permute.default(getitem_126, [0, 2, 1, 3]) + view_497 = torch.ops.aten.view.default(permute_160, [2, 8192, -1]); permute_160 = None + convert_element_type_479 = torch.ops.prims.convert_element_type.default(primals_134, torch.bfloat16) + all_gather_into_tensor_131 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_479, 64, '0'); convert_element_type_479 = None + wait_tensor_131 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_131); all_gather_into_tensor_131 = None + permute_161 = torch.ops.aten.permute.default(wait_tensor_131, [1, 0]); wait_tensor_131 = None + view_499 = torch.ops.aten.view.default(view_497, [16384, 4096]); view_497 = None + mm_101 = torch.ops.aten.mm.default(view_499, permute_161); view_499 = permute_161 = None + view_500 = torch.ops.aten.view.default(mm_101, [2, 8192, 4096]); mm_101 = None + add_57 = torch.ops.aten.add.Tensor(add_55, view_500); view_500 = None + convert_element_type_482 = torch.ops.prims.convert_element_type.default(primals_135, torch.bfloat16) + all_gather_into_tensor_132 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_482, 64, '0'); convert_element_type_482 = None + wait_tensor_132 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_132); 
all_gather_into_tensor_132 = None + convert_element_type_483 = torch.ops.prims.convert_element_type.default(add_57, torch.float32) + pow_30 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_483, 2) + mean_29 = torch.ops.aten.mean.dim(pow_30, [2], True); pow_30 = None + add_58 = torch.ops.aten.add.Scalar(mean_29, 1e-05); mean_29 = None + rsqrt_29 = torch.ops.aten.rsqrt.default(add_58); add_58 = None + mul_116 = torch.ops.aten.mul.Tensor(convert_element_type_483, rsqrt_29); convert_element_type_483 = rsqrt_29 = None + mul_117 = torch.ops.aten.mul.Tensor(mul_116, wait_tensor_132); mul_116 = wait_tensor_132 = None + convert_element_type_484 = torch.ops.prims.convert_element_type.default(mul_117, torch.bfloat16); mul_117 = None + convert_element_type_485 = torch.ops.prims.convert_element_type.default(primals_136, torch.bfloat16) + all_gather_into_tensor_133 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_485, 64, '0'); convert_element_type_485 = None + wait_tensor_133 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_133); all_gather_into_tensor_133 = None + permute_162 = torch.ops.aten.permute.default(wait_tensor_133, [1, 0]); wait_tensor_133 = None + view_503 = torch.ops.aten.view.default(convert_element_type_484, [16384, 4096]); convert_element_type_484 = None + mm_102 = torch.ops.aten.mm.default(view_503, permute_162); permute_162 = None + view_504 = torch.ops.aten.view.default(mm_102, [2, 8192, 14336]) + convert_element_type_488 = torch.ops.prims.convert_element_type.default(view_504, torch.float32); view_504 = None + sigmoid_14 = torch.ops.aten.sigmoid.default(convert_element_type_488) + mul_118 = torch.ops.aten.mul.Tensor(convert_element_type_488, sigmoid_14); convert_element_type_488 = sigmoid_14 = None + convert_element_type_489 = torch.ops.prims.convert_element_type.default(mul_118, torch.bfloat16); mul_118 = None + convert_element_type_490 = torch.ops.prims.convert_element_type.default(primals_137, 
torch.bfloat16) + all_gather_into_tensor_134 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_490, 64, '0'); convert_element_type_490 = None + wait_tensor_134 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_134); all_gather_into_tensor_134 = None + permute_163 = torch.ops.aten.permute.default(wait_tensor_134, [1, 0]); wait_tensor_134 = None + mm_103 = torch.ops.aten.mm.default(view_503, permute_163); view_503 = permute_163 = None + view_507 = torch.ops.aten.view.default(mm_103, [2, 8192, 14336]); mm_103 = None + mul_119 = torch.ops.aten.mul.Tensor(convert_element_type_489, view_507); convert_element_type_489 = view_507 = None + convert_element_type_493 = torch.ops.prims.convert_element_type.default(primals_138, torch.bfloat16) + all_gather_into_tensor_135 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_493, 64, '0'); convert_element_type_493 = None + wait_tensor_135 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_135); all_gather_into_tensor_135 = None + permute_164 = torch.ops.aten.permute.default(wait_tensor_135, [1, 0]); wait_tensor_135 = None + view_509 = torch.ops.aten.view.default(mul_119, [16384, 14336]); mul_119 = None + mm_104 = torch.ops.aten.mm.default(view_509, permute_164); view_509 = permute_164 = None + view_510 = torch.ops.aten.view.default(mm_104, [2, 8192, 4096]); mm_104 = None + add_59 = torch.ops.aten.add.Tensor(add_57, view_510); add_57 = view_510 = None + convert_element_type_496 = torch.ops.prims.convert_element_type.default(primals_139, torch.bfloat16) + all_gather_into_tensor_136 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_496, 64, '0'); convert_element_type_496 = None + wait_tensor_136 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_136); all_gather_into_tensor_136 = None + convert_element_type_497 = torch.ops.prims.convert_element_type.default(add_59, torch.float32) + 
pow_31 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_497, 2) + mean_30 = torch.ops.aten.mean.dim(pow_31, [2], True); pow_31 = None + add_60 = torch.ops.aten.add.Scalar(mean_30, 1e-05); mean_30 = None + rsqrt_30 = torch.ops.aten.rsqrt.default(add_60); add_60 = None + mul_120 = torch.ops.aten.mul.Tensor(convert_element_type_497, rsqrt_30); convert_element_type_497 = rsqrt_30 = None + mul_121 = torch.ops.aten.mul.Tensor(mul_120, wait_tensor_136); mul_120 = wait_tensor_136 = None + convert_element_type_498 = torch.ops.prims.convert_element_type.default(mul_121, torch.bfloat16); mul_121 = None + convert_element_type_499 = torch.ops.prims.convert_element_type.default(primals_140, torch.bfloat16) + all_gather_into_tensor_137 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_499, 64, '0'); convert_element_type_499 = None + wait_tensor_137 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_137); all_gather_into_tensor_137 = None + permute_165 = torch.ops.aten.permute.default(wait_tensor_137, [1, 0]); wait_tensor_137 = None + view_513 = torch.ops.aten.view.default(convert_element_type_498, [16384, 4096]); convert_element_type_498 = None + mm_105 = torch.ops.aten.mm.default(view_513, permute_165); permute_165 = None + view_514 = torch.ops.aten.view.default(mm_105, [2, 8192, 4096]) + convert_element_type_502 = torch.ops.prims.convert_element_type.default(primals_141, torch.bfloat16) + all_gather_into_tensor_138 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_502, 64, '0'); convert_element_type_502 = None + wait_tensor_138 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_138); all_gather_into_tensor_138 = None + permute_166 = torch.ops.aten.permute.default(wait_tensor_138, [1, 0]); wait_tensor_138 = None + mm_106 = torch.ops.aten.mm.default(view_513, permute_166); permute_166 = None + view_517 = torch.ops.aten.view.default(mm_106, [2, 8192, 1024]); mm_106 = None 
+ convert_element_type_505 = torch.ops.prims.convert_element_type.default(primals_142, torch.bfloat16) + all_gather_into_tensor_139 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_505, 64, '0'); convert_element_type_505 = None + wait_tensor_139 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_139); all_gather_into_tensor_139 = None + permute_167 = torch.ops.aten.permute.default(wait_tensor_139, [1, 0]); wait_tensor_139 = None + mm_107 = torch.ops.aten.mm.default(view_513, permute_167); view_513 = permute_167 = None + view_520 = torch.ops.aten.view.default(mm_107, [2, 8192, 1024]) + view_521 = torch.ops.aten.view.default(view_514, [2, 8192, -1, 128]); view_514 = None + view_522 = torch.ops.aten.view.default(view_517, [2, 8192, -1, 128]); view_517 = None + view_523 = torch.ops.aten.view.default(view_520, [2, 8192, -1, 128]); view_520 = None + convert_element_type_508 = torch.ops.prims.convert_element_type.default(view_521, torch.float32); view_521 = None + view_524 = torch.ops.aten.view.default(convert_element_type_508, [2, 8192, 32, -1, 2]); convert_element_type_508 = None + view_as_complex_30 = torch.ops.aten.view_as_complex.default(view_524); view_524 = None + convert_element_type_509 = torch.ops.prims.convert_element_type.default(view_522, torch.float32); view_522 = None + view_525 = torch.ops.aten.view.default(convert_element_type_509, [2, 8192, 8, -1, 2]); convert_element_type_509 = None + view_as_complex_31 = torch.ops.aten.view_as_complex.default(view_525); view_525 = None + mul_122 = torch.ops.aten.mul.Tensor(view_as_complex_30, view_16); view_as_complex_30 = None + view_as_real_30 = torch.ops.aten.view_as_real.default(mul_122); mul_122 = None + view_527 = torch.ops.aten.view.default(view_as_real_30, [2, 8192, 32, 128]); view_as_real_30 = None + mul_123 = torch.ops.aten.mul.Tensor(view_as_complex_31, view_16); view_as_complex_31 = None + view_as_real_31 = torch.ops.aten.view_as_real.default(mul_123); 
mul_123 = None + view_528 = torch.ops.aten.view.default(view_as_real_31, [2, 8192, 8, 128]); view_as_real_31 = None + convert_element_type_510 = torch.ops.prims.convert_element_type.default(view_527, torch.bfloat16); view_527 = None + convert_element_type_511 = torch.ops.prims.convert_element_type.default(view_528, torch.bfloat16); view_528 = None + unsqueeze_30 = torch.ops.aten.unsqueeze.default(convert_element_type_511, 3); convert_element_type_511 = None + expand_30 = torch.ops.aten.expand.default(unsqueeze_30, [2, 8192, 8, 4, 128]); unsqueeze_30 = None + clone_30 = torch.ops.aten.clone.default(expand_30, memory_format = torch.contiguous_format); expand_30 = None + view_529 = torch.ops.aten.view.default(clone_30, [2, 8192, 32, 128]); clone_30 = None + unsqueeze_31 = torch.ops.aten.unsqueeze.default(view_523, 3); view_523 = None + expand_31 = torch.ops.aten.expand.default(unsqueeze_31, [2, 8192, 8, 4, 128]); unsqueeze_31 = None + clone_31 = torch.ops.aten.clone.default(expand_31, memory_format = torch.contiguous_format); expand_31 = None + view_530 = torch.ops.aten.view.default(clone_31, [2, 8192, 32, 128]); clone_31 = None + permute_168 = torch.ops.aten.permute.default(convert_element_type_510, [0, 2, 1, 3]); convert_element_type_510 = None + permute_169 = torch.ops.aten.permute.default(view_529, [0, 2, 1, 3]); view_529 = None + permute_170 = torch.ops.aten.permute.default(view_530, [0, 2, 1, 3]); view_530 = None + _scaled_dot_product_cudnn_attention_15 = torch.ops.aten._scaled_dot_product_cudnn_attention.default(permute_168, permute_169, permute_170, None, True, 0.0, True); permute_168 = permute_169 = permute_170 = None + getitem_135 = _scaled_dot_product_cudnn_attention_15[0] + getitem_136 = _scaled_dot_product_cudnn_attention_15[1] + getitem_141 = _scaled_dot_product_cudnn_attention_15[6] + getitem_142 = _scaled_dot_product_cudnn_attention_15[7]; _scaled_dot_product_cudnn_attention_15 = None + permute_171 = torch.ops.aten.permute.default(getitem_135, [0, 2, 
1, 3]) + view_531 = torch.ops.aten.view.default(permute_171, [2, 8192, -1]); permute_171 = None + convert_element_type_512 = torch.ops.prims.convert_element_type.default(primals_143, torch.bfloat16) + all_gather_into_tensor_140 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_512, 64, '0'); convert_element_type_512 = None + wait_tensor_140 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_140); all_gather_into_tensor_140 = None + permute_172 = torch.ops.aten.permute.default(wait_tensor_140, [1, 0]); wait_tensor_140 = None + view_533 = torch.ops.aten.view.default(view_531, [16384, 4096]); view_531 = None + mm_108 = torch.ops.aten.mm.default(view_533, permute_172); view_533 = permute_172 = None + view_534 = torch.ops.aten.view.default(mm_108, [2, 8192, 4096]); mm_108 = None + add_61 = torch.ops.aten.add.Tensor(add_59, view_534); view_534 = None + convert_element_type_515 = torch.ops.prims.convert_element_type.default(primals_144, torch.bfloat16) + all_gather_into_tensor_141 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_515, 64, '0'); convert_element_type_515 = None + wait_tensor_141 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_141); all_gather_into_tensor_141 = None + convert_element_type_516 = torch.ops.prims.convert_element_type.default(add_61, torch.float32) + pow_32 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_516, 2) + mean_31 = torch.ops.aten.mean.dim(pow_32, [2], True); pow_32 = None + add_62 = torch.ops.aten.add.Scalar(mean_31, 1e-05); mean_31 = None + rsqrt_31 = torch.ops.aten.rsqrt.default(add_62); add_62 = None + mul_124 = torch.ops.aten.mul.Tensor(convert_element_type_516, rsqrt_31); convert_element_type_516 = rsqrt_31 = None + mul_125 = torch.ops.aten.mul.Tensor(mul_124, wait_tensor_141); mul_124 = wait_tensor_141 = None + convert_element_type_517 = torch.ops.prims.convert_element_type.default(mul_125, torch.bfloat16); mul_125 = 
None + convert_element_type_518 = torch.ops.prims.convert_element_type.default(primals_145, torch.bfloat16) + all_gather_into_tensor_142 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_518, 64, '0'); convert_element_type_518 = None + wait_tensor_142 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_142); all_gather_into_tensor_142 = None + permute_173 = torch.ops.aten.permute.default(wait_tensor_142, [1, 0]); wait_tensor_142 = None + view_537 = torch.ops.aten.view.default(convert_element_type_517, [16384, 4096]); convert_element_type_517 = None + mm_109 = torch.ops.aten.mm.default(view_537, permute_173); permute_173 = None + view_538 = torch.ops.aten.view.default(mm_109, [2, 8192, 14336]) + convert_element_type_521 = torch.ops.prims.convert_element_type.default(view_538, torch.float32); view_538 = None + sigmoid_15 = torch.ops.aten.sigmoid.default(convert_element_type_521) + mul_126 = torch.ops.aten.mul.Tensor(convert_element_type_521, sigmoid_15); convert_element_type_521 = sigmoid_15 = None + convert_element_type_522 = torch.ops.prims.convert_element_type.default(mul_126, torch.bfloat16); mul_126 = None + convert_element_type_523 = torch.ops.prims.convert_element_type.default(primals_146, torch.bfloat16) + all_gather_into_tensor_143 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_523, 64, '0'); convert_element_type_523 = None + wait_tensor_143 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_143); all_gather_into_tensor_143 = None + permute_174 = torch.ops.aten.permute.default(wait_tensor_143, [1, 0]); wait_tensor_143 = None + mm_110 = torch.ops.aten.mm.default(view_537, permute_174); view_537 = permute_174 = None + view_541 = torch.ops.aten.view.default(mm_110, [2, 8192, 14336]); mm_110 = None + mul_127 = torch.ops.aten.mul.Tensor(convert_element_type_522, view_541); convert_element_type_522 = view_541 = None + convert_element_type_526 = 
torch.ops.prims.convert_element_type.default(primals_147, torch.bfloat16) + all_gather_into_tensor_144 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_526, 64, '0'); convert_element_type_526 = None + wait_tensor_144 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_144); all_gather_into_tensor_144 = None + permute_175 = torch.ops.aten.permute.default(wait_tensor_144, [1, 0]); wait_tensor_144 = None + view_543 = torch.ops.aten.view.default(mul_127, [16384, 14336]); mul_127 = None + mm_111 = torch.ops.aten.mm.default(view_543, permute_175); view_543 = permute_175 = None + view_544 = torch.ops.aten.view.default(mm_111, [2, 8192, 4096]); mm_111 = None + add_63 = torch.ops.aten.add.Tensor(add_61, view_544); add_61 = view_544 = None + convert_element_type_529 = torch.ops.prims.convert_element_type.default(primals_148, torch.bfloat16) + all_gather_into_tensor_145 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_529, 64, '0'); convert_element_type_529 = None + wait_tensor_145 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_145); all_gather_into_tensor_145 = None + convert_element_type_530 = torch.ops.prims.convert_element_type.default(add_63, torch.float32) + pow_33 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_530, 2) + mean_32 = torch.ops.aten.mean.dim(pow_33, [2], True); pow_33 = None + add_64 = torch.ops.aten.add.Scalar(mean_32, 1e-05); mean_32 = None + rsqrt_32 = torch.ops.aten.rsqrt.default(add_64); add_64 = None + mul_128 = torch.ops.aten.mul.Tensor(convert_element_type_530, rsqrt_32); convert_element_type_530 = rsqrt_32 = None + mul_129 = torch.ops.aten.mul.Tensor(mul_128, wait_tensor_145); mul_128 = wait_tensor_145 = None + convert_element_type_531 = torch.ops.prims.convert_element_type.default(mul_129, torch.bfloat16); mul_129 = None + convert_element_type_532 = torch.ops.prims.convert_element_type.default(primals_149, torch.bfloat16) + 
all_gather_into_tensor_146 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_532, 64, '0'); convert_element_type_532 = None + wait_tensor_146 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_146); all_gather_into_tensor_146 = None + permute_176 = torch.ops.aten.permute.default(wait_tensor_146, [1, 0]); wait_tensor_146 = None + view_547 = torch.ops.aten.view.default(convert_element_type_531, [16384, 4096]); convert_element_type_531 = None + mm_112 = torch.ops.aten.mm.default(view_547, permute_176); permute_176 = None + view_548 = torch.ops.aten.view.default(mm_112, [2, 8192, 4096]) + convert_element_type_535 = torch.ops.prims.convert_element_type.default(primals_150, torch.bfloat16) + all_gather_into_tensor_147 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_535, 64, '0'); convert_element_type_535 = None + wait_tensor_147 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_147); all_gather_into_tensor_147 = None + permute_177 = torch.ops.aten.permute.default(wait_tensor_147, [1, 0]); wait_tensor_147 = None + mm_113 = torch.ops.aten.mm.default(view_547, permute_177); permute_177 = None + view_551 = torch.ops.aten.view.default(mm_113, [2, 8192, 1024]); mm_113 = None + convert_element_type_538 = torch.ops.prims.convert_element_type.default(primals_151, torch.bfloat16) + all_gather_into_tensor_148 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_538, 64, '0'); convert_element_type_538 = None + wait_tensor_148 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_148); all_gather_into_tensor_148 = None + permute_178 = torch.ops.aten.permute.default(wait_tensor_148, [1, 0]); wait_tensor_148 = None + mm_114 = torch.ops.aten.mm.default(view_547, permute_178); view_547 = permute_178 = None + view_554 = torch.ops.aten.view.default(mm_114, [2, 8192, 1024]) + view_555 = torch.ops.aten.view.default(view_548, [2, 8192, -1, 
128]); view_548 = None + view_556 = torch.ops.aten.view.default(view_551, [2, 8192, -1, 128]); view_551 = None + view_557 = torch.ops.aten.view.default(view_554, [2, 8192, -1, 128]); view_554 = None + convert_element_type_541 = torch.ops.prims.convert_element_type.default(view_555, torch.float32); view_555 = None + view_558 = torch.ops.aten.view.default(convert_element_type_541, [2, 8192, 32, -1, 2]); convert_element_type_541 = None + view_as_complex_32 = torch.ops.aten.view_as_complex.default(view_558); view_558 = None + convert_element_type_542 = torch.ops.prims.convert_element_type.default(view_556, torch.float32); view_556 = None + view_559 = torch.ops.aten.view.default(convert_element_type_542, [2, 8192, 8, -1, 2]); convert_element_type_542 = None + view_as_complex_33 = torch.ops.aten.view_as_complex.default(view_559); view_559 = None + mul_130 = torch.ops.aten.mul.Tensor(view_as_complex_32, view_16); view_as_complex_32 = None + view_as_real_32 = torch.ops.aten.view_as_real.default(mul_130); mul_130 = None + view_561 = torch.ops.aten.view.default(view_as_real_32, [2, 8192, 32, 128]); view_as_real_32 = None + mul_131 = torch.ops.aten.mul.Tensor(view_as_complex_33, view_16); view_as_complex_33 = None + view_as_real_33 = torch.ops.aten.view_as_real.default(mul_131); mul_131 = None + view_562 = torch.ops.aten.view.default(view_as_real_33, [2, 8192, 8, 128]); view_as_real_33 = None + convert_element_type_543 = torch.ops.prims.convert_element_type.default(view_561, torch.bfloat16); view_561 = None + convert_element_type_544 = torch.ops.prims.convert_element_type.default(view_562, torch.bfloat16); view_562 = None + unsqueeze_32 = torch.ops.aten.unsqueeze.default(convert_element_type_544, 3); convert_element_type_544 = None + expand_32 = torch.ops.aten.expand.default(unsqueeze_32, [2, 8192, 8, 4, 128]); unsqueeze_32 = None + clone_32 = torch.ops.aten.clone.default(expand_32, memory_format = torch.contiguous_format); expand_32 = None + view_563 = 
torch.ops.aten.view.default(clone_32, [2, 8192, 32, 128]); clone_32 = None + unsqueeze_33 = torch.ops.aten.unsqueeze.default(view_557, 3); view_557 = None + expand_33 = torch.ops.aten.expand.default(unsqueeze_33, [2, 8192, 8, 4, 128]); unsqueeze_33 = None + clone_33 = torch.ops.aten.clone.default(expand_33, memory_format = torch.contiguous_format); expand_33 = None + view_564 = torch.ops.aten.view.default(clone_33, [2, 8192, 32, 128]); clone_33 = None + permute_179 = torch.ops.aten.permute.default(convert_element_type_543, [0, 2, 1, 3]); convert_element_type_543 = None + permute_180 = torch.ops.aten.permute.default(view_563, [0, 2, 1, 3]); view_563 = None + permute_181 = torch.ops.aten.permute.default(view_564, [0, 2, 1, 3]); view_564 = None + _scaled_dot_product_cudnn_attention_16 = torch.ops.aten._scaled_dot_product_cudnn_attention.default(permute_179, permute_180, permute_181, None, True, 0.0, True); permute_179 = permute_180 = permute_181 = None + getitem_144 = _scaled_dot_product_cudnn_attention_16[0] + getitem_145 = _scaled_dot_product_cudnn_attention_16[1] + getitem_150 = _scaled_dot_product_cudnn_attention_16[6] + getitem_151 = _scaled_dot_product_cudnn_attention_16[7]; _scaled_dot_product_cudnn_attention_16 = None + permute_182 = torch.ops.aten.permute.default(getitem_144, [0, 2, 1, 3]) + view_565 = torch.ops.aten.view.default(permute_182, [2, 8192, -1]); permute_182 = None + convert_element_type_545 = torch.ops.prims.convert_element_type.default(primals_152, torch.bfloat16) + all_gather_into_tensor_149 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_545, 64, '0'); convert_element_type_545 = None + wait_tensor_149 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_149); all_gather_into_tensor_149 = None + permute_183 = torch.ops.aten.permute.default(wait_tensor_149, [1, 0]); wait_tensor_149 = None + view_567 = torch.ops.aten.view.default(view_565, [16384, 4096]); view_565 = None + mm_115 = 
torch.ops.aten.mm.default(view_567, permute_183); view_567 = permute_183 = None + view_568 = torch.ops.aten.view.default(mm_115, [2, 8192, 4096]); mm_115 = None + add_65 = torch.ops.aten.add.Tensor(add_63, view_568); view_568 = None + convert_element_type_548 = torch.ops.prims.convert_element_type.default(primals_153, torch.bfloat16) + all_gather_into_tensor_150 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_548, 64, '0'); convert_element_type_548 = None + wait_tensor_150 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_150); all_gather_into_tensor_150 = None + convert_element_type_549 = torch.ops.prims.convert_element_type.default(add_65, torch.float32) + pow_34 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_549, 2) + mean_33 = torch.ops.aten.mean.dim(pow_34, [2], True); pow_34 = None + add_66 = torch.ops.aten.add.Scalar(mean_33, 1e-05); mean_33 = None + rsqrt_33 = torch.ops.aten.rsqrt.default(add_66); add_66 = None + mul_132 = torch.ops.aten.mul.Tensor(convert_element_type_549, rsqrt_33); convert_element_type_549 = rsqrt_33 = None + mul_133 = torch.ops.aten.mul.Tensor(mul_132, wait_tensor_150); mul_132 = wait_tensor_150 = None + convert_element_type_550 = torch.ops.prims.convert_element_type.default(mul_133, torch.bfloat16); mul_133 = None + convert_element_type_551 = torch.ops.prims.convert_element_type.default(primals_154, torch.bfloat16) + all_gather_into_tensor_151 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_551, 64, '0'); convert_element_type_551 = None + wait_tensor_151 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_151); all_gather_into_tensor_151 = None + permute_184 = torch.ops.aten.permute.default(wait_tensor_151, [1, 0]); wait_tensor_151 = None + view_571 = torch.ops.aten.view.default(convert_element_type_550, [16384, 4096]); convert_element_type_550 = None + mm_116 = torch.ops.aten.mm.default(view_571, permute_184); 
permute_184 = None + view_572 = torch.ops.aten.view.default(mm_116, [2, 8192, 14336]) + convert_element_type_554 = torch.ops.prims.convert_element_type.default(view_572, torch.float32); view_572 = None + sigmoid_16 = torch.ops.aten.sigmoid.default(convert_element_type_554) + mul_134 = torch.ops.aten.mul.Tensor(convert_element_type_554, sigmoid_16); convert_element_type_554 = sigmoid_16 = None + convert_element_type_555 = torch.ops.prims.convert_element_type.default(mul_134, torch.bfloat16); mul_134 = None + convert_element_type_556 = torch.ops.prims.convert_element_type.default(primals_155, torch.bfloat16) + all_gather_into_tensor_152 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_556, 64, '0'); convert_element_type_556 = None + wait_tensor_152 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_152); all_gather_into_tensor_152 = None + permute_185 = torch.ops.aten.permute.default(wait_tensor_152, [1, 0]); wait_tensor_152 = None + mm_117 = torch.ops.aten.mm.default(view_571, permute_185); view_571 = permute_185 = None + view_575 = torch.ops.aten.view.default(mm_117, [2, 8192, 14336]); mm_117 = None + mul_135 = torch.ops.aten.mul.Tensor(convert_element_type_555, view_575); convert_element_type_555 = view_575 = None + convert_element_type_559 = torch.ops.prims.convert_element_type.default(primals_156, torch.bfloat16) + all_gather_into_tensor_153 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_559, 64, '0'); convert_element_type_559 = None + wait_tensor_153 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_153); all_gather_into_tensor_153 = None + permute_186 = torch.ops.aten.permute.default(wait_tensor_153, [1, 0]); wait_tensor_153 = None + view_577 = torch.ops.aten.view.default(mul_135, [16384, 14336]); mul_135 = None + mm_118 = torch.ops.aten.mm.default(view_577, permute_186); view_577 = permute_186 = None + view_578 = torch.ops.aten.view.default(mm_118, [2, 
8192, 4096]); mm_118 = None + add_67 = torch.ops.aten.add.Tensor(add_65, view_578); add_65 = view_578 = None + convert_element_type_562 = torch.ops.prims.convert_element_type.default(primals_157, torch.bfloat16) + all_gather_into_tensor_154 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_562, 64, '0'); convert_element_type_562 = None + wait_tensor_154 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_154); all_gather_into_tensor_154 = None + convert_element_type_563 = torch.ops.prims.convert_element_type.default(add_67, torch.float32) + pow_35 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_563, 2) + mean_34 = torch.ops.aten.mean.dim(pow_35, [2], True); pow_35 = None + add_68 = torch.ops.aten.add.Scalar(mean_34, 1e-05); mean_34 = None + rsqrt_34 = torch.ops.aten.rsqrt.default(add_68); add_68 = None + mul_136 = torch.ops.aten.mul.Tensor(convert_element_type_563, rsqrt_34); convert_element_type_563 = rsqrt_34 = None + mul_137 = torch.ops.aten.mul.Tensor(mul_136, wait_tensor_154); mul_136 = wait_tensor_154 = None + convert_element_type_564 = torch.ops.prims.convert_element_type.default(mul_137, torch.bfloat16); mul_137 = None + convert_element_type_565 = torch.ops.prims.convert_element_type.default(primals_158, torch.bfloat16) + all_gather_into_tensor_155 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_565, 64, '0'); convert_element_type_565 = None + wait_tensor_155 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_155); all_gather_into_tensor_155 = None + permute_187 = torch.ops.aten.permute.default(wait_tensor_155, [1, 0]); wait_tensor_155 = None + view_581 = torch.ops.aten.view.default(convert_element_type_564, [16384, 4096]); convert_element_type_564 = None + mm_119 = torch.ops.aten.mm.default(view_581, permute_187); permute_187 = None + view_582 = torch.ops.aten.view.default(mm_119, [2, 8192, 4096]) + convert_element_type_568 = 
torch.ops.prims.convert_element_type.default(primals_159, torch.bfloat16) + all_gather_into_tensor_156 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_568, 64, '0'); convert_element_type_568 = None + wait_tensor_156 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_156); all_gather_into_tensor_156 = None + permute_188 = torch.ops.aten.permute.default(wait_tensor_156, [1, 0]); wait_tensor_156 = None + mm_120 = torch.ops.aten.mm.default(view_581, permute_188); permute_188 = None + view_585 = torch.ops.aten.view.default(mm_120, [2, 8192, 1024]); mm_120 = None + convert_element_type_571 = torch.ops.prims.convert_element_type.default(primals_160, torch.bfloat16) + all_gather_into_tensor_157 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_571, 64, '0'); convert_element_type_571 = None + wait_tensor_157 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_157); all_gather_into_tensor_157 = None + permute_189 = torch.ops.aten.permute.default(wait_tensor_157, [1, 0]); wait_tensor_157 = None + mm_121 = torch.ops.aten.mm.default(view_581, permute_189); view_581 = permute_189 = None + view_588 = torch.ops.aten.view.default(mm_121, [2, 8192, 1024]) + view_589 = torch.ops.aten.view.default(view_582, [2, 8192, -1, 128]); view_582 = None + view_590 = torch.ops.aten.view.default(view_585, [2, 8192, -1, 128]); view_585 = None + view_591 = torch.ops.aten.view.default(view_588, [2, 8192, -1, 128]); view_588 = None + convert_element_type_574 = torch.ops.prims.convert_element_type.default(view_589, torch.float32); view_589 = None + view_592 = torch.ops.aten.view.default(convert_element_type_574, [2, 8192, 32, -1, 2]); convert_element_type_574 = None + view_as_complex_34 = torch.ops.aten.view_as_complex.default(view_592); view_592 = None + convert_element_type_575 = torch.ops.prims.convert_element_type.default(view_590, torch.float32); view_590 = None + view_593 = 
torch.ops.aten.view.default(convert_element_type_575, [2, 8192, 8, -1, 2]); convert_element_type_575 = None + view_as_complex_35 = torch.ops.aten.view_as_complex.default(view_593); view_593 = None + mul_138 = torch.ops.aten.mul.Tensor(view_as_complex_34, view_16); view_as_complex_34 = None + view_as_real_34 = torch.ops.aten.view_as_real.default(mul_138); mul_138 = None + view_595 = torch.ops.aten.view.default(view_as_real_34, [2, 8192, 32, 128]); view_as_real_34 = None + mul_139 = torch.ops.aten.mul.Tensor(view_as_complex_35, view_16); view_as_complex_35 = None + view_as_real_35 = torch.ops.aten.view_as_real.default(mul_139); mul_139 = None + view_596 = torch.ops.aten.view.default(view_as_real_35, [2, 8192, 8, 128]); view_as_real_35 = None + convert_element_type_576 = torch.ops.prims.convert_element_type.default(view_595, torch.bfloat16); view_595 = None + convert_element_type_577 = torch.ops.prims.convert_element_type.default(view_596, torch.bfloat16); view_596 = None + unsqueeze_34 = torch.ops.aten.unsqueeze.default(convert_element_type_577, 3); convert_element_type_577 = None + expand_34 = torch.ops.aten.expand.default(unsqueeze_34, [2, 8192, 8, 4, 128]); unsqueeze_34 = None + clone_34 = torch.ops.aten.clone.default(expand_34, memory_format = torch.contiguous_format); expand_34 = None + view_597 = torch.ops.aten.view.default(clone_34, [2, 8192, 32, 128]); clone_34 = None + unsqueeze_35 = torch.ops.aten.unsqueeze.default(view_591, 3); view_591 = None + expand_35 = torch.ops.aten.expand.default(unsqueeze_35, [2, 8192, 8, 4, 128]); unsqueeze_35 = None + clone_35 = torch.ops.aten.clone.default(expand_35, memory_format = torch.contiguous_format); expand_35 = None + view_598 = torch.ops.aten.view.default(clone_35, [2, 8192, 32, 128]); clone_35 = None + permute_190 = torch.ops.aten.permute.default(convert_element_type_576, [0, 2, 1, 3]); convert_element_type_576 = None + permute_191 = torch.ops.aten.permute.default(view_597, [0, 2, 1, 3]); view_597 = None + permute_192 
= torch.ops.aten.permute.default(view_598, [0, 2, 1, 3]); view_598 = None + _scaled_dot_product_cudnn_attention_17 = torch.ops.aten._scaled_dot_product_cudnn_attention.default(permute_190, permute_191, permute_192, None, True, 0.0, True); permute_190 = permute_191 = permute_192 = None + getitem_153 = _scaled_dot_product_cudnn_attention_17[0] + getitem_154 = _scaled_dot_product_cudnn_attention_17[1] + getitem_159 = _scaled_dot_product_cudnn_attention_17[6] + getitem_160 = _scaled_dot_product_cudnn_attention_17[7]; _scaled_dot_product_cudnn_attention_17 = None + permute_193 = torch.ops.aten.permute.default(getitem_153, [0, 2, 1, 3]) + view_599 = torch.ops.aten.view.default(permute_193, [2, 8192, -1]); permute_193 = None + convert_element_type_578 = torch.ops.prims.convert_element_type.default(primals_161, torch.bfloat16) + all_gather_into_tensor_158 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_578, 64, '0'); convert_element_type_578 = None + wait_tensor_158 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_158); all_gather_into_tensor_158 = None + permute_194 = torch.ops.aten.permute.default(wait_tensor_158, [1, 0]); wait_tensor_158 = None + view_601 = torch.ops.aten.view.default(view_599, [16384, 4096]); view_599 = None + mm_122 = torch.ops.aten.mm.default(view_601, permute_194); view_601 = permute_194 = None + view_602 = torch.ops.aten.view.default(mm_122, [2, 8192, 4096]); mm_122 = None + add_69 = torch.ops.aten.add.Tensor(add_67, view_602); view_602 = None + convert_element_type_581 = torch.ops.prims.convert_element_type.default(primals_162, torch.bfloat16) + all_gather_into_tensor_159 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_581, 64, '0'); convert_element_type_581 = None + wait_tensor_159 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_159); all_gather_into_tensor_159 = None + convert_element_type_582 = 
torch.ops.prims.convert_element_type.default(add_69, torch.float32) + pow_36 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_582, 2) + mean_35 = torch.ops.aten.mean.dim(pow_36, [2], True); pow_36 = None + add_70 = torch.ops.aten.add.Scalar(mean_35, 1e-05); mean_35 = None + rsqrt_35 = torch.ops.aten.rsqrt.default(add_70); add_70 = None + mul_140 = torch.ops.aten.mul.Tensor(convert_element_type_582, rsqrt_35); convert_element_type_582 = rsqrt_35 = None + mul_141 = torch.ops.aten.mul.Tensor(mul_140, wait_tensor_159); mul_140 = wait_tensor_159 = None + convert_element_type_583 = torch.ops.prims.convert_element_type.default(mul_141, torch.bfloat16); mul_141 = None + convert_element_type_584 = torch.ops.prims.convert_element_type.default(primals_163, torch.bfloat16) + all_gather_into_tensor_160 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_584, 64, '0'); convert_element_type_584 = None + wait_tensor_160 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_160); all_gather_into_tensor_160 = None + permute_195 = torch.ops.aten.permute.default(wait_tensor_160, [1, 0]); wait_tensor_160 = None + view_605 = torch.ops.aten.view.default(convert_element_type_583, [16384, 4096]); convert_element_type_583 = None + mm_123 = torch.ops.aten.mm.default(view_605, permute_195); permute_195 = None + view_606 = torch.ops.aten.view.default(mm_123, [2, 8192, 14336]) + convert_element_type_587 = torch.ops.prims.convert_element_type.default(view_606, torch.float32); view_606 = None + sigmoid_17 = torch.ops.aten.sigmoid.default(convert_element_type_587) + mul_142 = torch.ops.aten.mul.Tensor(convert_element_type_587, sigmoid_17); convert_element_type_587 = sigmoid_17 = None + convert_element_type_588 = torch.ops.prims.convert_element_type.default(mul_142, torch.bfloat16); mul_142 = None + convert_element_type_589 = torch.ops.prims.convert_element_type.default(primals_164, torch.bfloat16) + all_gather_into_tensor_161 = 
torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_589, 64, '0'); convert_element_type_589 = None + wait_tensor_161 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_161); all_gather_into_tensor_161 = None + permute_196 = torch.ops.aten.permute.default(wait_tensor_161, [1, 0]); wait_tensor_161 = None + mm_124 = torch.ops.aten.mm.default(view_605, permute_196); view_605 = permute_196 = None + view_609 = torch.ops.aten.view.default(mm_124, [2, 8192, 14336]); mm_124 = None + mul_143 = torch.ops.aten.mul.Tensor(convert_element_type_588, view_609); convert_element_type_588 = view_609 = None + convert_element_type_592 = torch.ops.prims.convert_element_type.default(primals_165, torch.bfloat16) + all_gather_into_tensor_162 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_592, 64, '0'); convert_element_type_592 = None + wait_tensor_162 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_162); all_gather_into_tensor_162 = None + permute_197 = torch.ops.aten.permute.default(wait_tensor_162, [1, 0]); wait_tensor_162 = None + view_611 = torch.ops.aten.view.default(mul_143, [16384, 14336]); mul_143 = None + mm_125 = torch.ops.aten.mm.default(view_611, permute_197); view_611 = permute_197 = None + view_612 = torch.ops.aten.view.default(mm_125, [2, 8192, 4096]); mm_125 = None + add_71 = torch.ops.aten.add.Tensor(add_69, view_612); add_69 = view_612 = None + convert_element_type_595 = torch.ops.prims.convert_element_type.default(primals_166, torch.bfloat16) + all_gather_into_tensor_163 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_595, 64, '0'); convert_element_type_595 = None + wait_tensor_163 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_163); all_gather_into_tensor_163 = None + convert_element_type_596 = torch.ops.prims.convert_element_type.default(add_71, torch.float32) + pow_37 = 
torch.ops.aten.pow.Tensor_Scalar(convert_element_type_596, 2) + mean_36 = torch.ops.aten.mean.dim(pow_37, [2], True); pow_37 = None + add_72 = torch.ops.aten.add.Scalar(mean_36, 1e-05); mean_36 = None + rsqrt_36 = torch.ops.aten.rsqrt.default(add_72); add_72 = None + mul_144 = torch.ops.aten.mul.Tensor(convert_element_type_596, rsqrt_36); convert_element_type_596 = rsqrt_36 = None + mul_145 = torch.ops.aten.mul.Tensor(mul_144, wait_tensor_163); mul_144 = wait_tensor_163 = None + convert_element_type_597 = torch.ops.prims.convert_element_type.default(mul_145, torch.bfloat16); mul_145 = None + convert_element_type_598 = torch.ops.prims.convert_element_type.default(primals_167, torch.bfloat16) + all_gather_into_tensor_164 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_598, 64, '0'); convert_element_type_598 = None + wait_tensor_164 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_164); all_gather_into_tensor_164 = None + permute_198 = torch.ops.aten.permute.default(wait_tensor_164, [1, 0]); wait_tensor_164 = None + view_615 = torch.ops.aten.view.default(convert_element_type_597, [16384, 4096]); convert_element_type_597 = None + mm_126 = torch.ops.aten.mm.default(view_615, permute_198); permute_198 = None + view_616 = torch.ops.aten.view.default(mm_126, [2, 8192, 4096]) + convert_element_type_601 = torch.ops.prims.convert_element_type.default(primals_168, torch.bfloat16) + all_gather_into_tensor_165 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_601, 64, '0'); convert_element_type_601 = None + wait_tensor_165 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_165); all_gather_into_tensor_165 = None + permute_199 = torch.ops.aten.permute.default(wait_tensor_165, [1, 0]); wait_tensor_165 = None + mm_127 = torch.ops.aten.mm.default(view_615, permute_199); permute_199 = None + view_619 = torch.ops.aten.view.default(mm_127, [2, 8192, 1024]); mm_127 = None + 
convert_element_type_604 = torch.ops.prims.convert_element_type.default(primals_169, torch.bfloat16) + all_gather_into_tensor_166 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_604, 64, '0'); convert_element_type_604 = None + wait_tensor_166 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_166); all_gather_into_tensor_166 = None + permute_200 = torch.ops.aten.permute.default(wait_tensor_166, [1, 0]); wait_tensor_166 = None + mm_128 = torch.ops.aten.mm.default(view_615, permute_200); view_615 = permute_200 = None + view_622 = torch.ops.aten.view.default(mm_128, [2, 8192, 1024]) + view_623 = torch.ops.aten.view.default(view_616, [2, 8192, -1, 128]); view_616 = None + view_624 = torch.ops.aten.view.default(view_619, [2, 8192, -1, 128]); view_619 = None + view_625 = torch.ops.aten.view.default(view_622, [2, 8192, -1, 128]); view_622 = None + convert_element_type_607 = torch.ops.prims.convert_element_type.default(view_623, torch.float32); view_623 = None + view_626 = torch.ops.aten.view.default(convert_element_type_607, [2, 8192, 32, -1, 2]); convert_element_type_607 = None + view_as_complex_36 = torch.ops.aten.view_as_complex.default(view_626); view_626 = None + convert_element_type_608 = torch.ops.prims.convert_element_type.default(view_624, torch.float32); view_624 = None + view_627 = torch.ops.aten.view.default(convert_element_type_608, [2, 8192, 8, -1, 2]); convert_element_type_608 = None + view_as_complex_37 = torch.ops.aten.view_as_complex.default(view_627); view_627 = None + mul_146 = torch.ops.aten.mul.Tensor(view_as_complex_36, view_16); view_as_complex_36 = None + view_as_real_36 = torch.ops.aten.view_as_real.default(mul_146); mul_146 = None + view_629 = torch.ops.aten.view.default(view_as_real_36, [2, 8192, 32, 128]); view_as_real_36 = None + mul_147 = torch.ops.aten.mul.Tensor(view_as_complex_37, view_16); view_as_complex_37 = None + view_as_real_37 = torch.ops.aten.view_as_real.default(mul_147); 
mul_147 = None + view_630 = torch.ops.aten.view.default(view_as_real_37, [2, 8192, 8, 128]); view_as_real_37 = None + convert_element_type_609 = torch.ops.prims.convert_element_type.default(view_629, torch.bfloat16); view_629 = None + convert_element_type_610 = torch.ops.prims.convert_element_type.default(view_630, torch.bfloat16); view_630 = None + unsqueeze_36 = torch.ops.aten.unsqueeze.default(convert_element_type_610, 3); convert_element_type_610 = None + expand_36 = torch.ops.aten.expand.default(unsqueeze_36, [2, 8192, 8, 4, 128]); unsqueeze_36 = None + clone_36 = torch.ops.aten.clone.default(expand_36, memory_format = torch.contiguous_format); expand_36 = None + view_631 = torch.ops.aten.view.default(clone_36, [2, 8192, 32, 128]); clone_36 = None + unsqueeze_37 = torch.ops.aten.unsqueeze.default(view_625, 3); view_625 = None + expand_37 = torch.ops.aten.expand.default(unsqueeze_37, [2, 8192, 8, 4, 128]); unsqueeze_37 = None + clone_37 = torch.ops.aten.clone.default(expand_37, memory_format = torch.contiguous_format); expand_37 = None + view_632 = torch.ops.aten.view.default(clone_37, [2, 8192, 32, 128]); clone_37 = None + permute_201 = torch.ops.aten.permute.default(convert_element_type_609, [0, 2, 1, 3]); convert_element_type_609 = None + permute_202 = torch.ops.aten.permute.default(view_631, [0, 2, 1, 3]); view_631 = None + permute_203 = torch.ops.aten.permute.default(view_632, [0, 2, 1, 3]); view_632 = None + _scaled_dot_product_cudnn_attention_18 = torch.ops.aten._scaled_dot_product_cudnn_attention.default(permute_201, permute_202, permute_203, None, True, 0.0, True); permute_201 = permute_202 = permute_203 = None + getitem_162 = _scaled_dot_product_cudnn_attention_18[0] + getitem_163 = _scaled_dot_product_cudnn_attention_18[1] + getitem_168 = _scaled_dot_product_cudnn_attention_18[6] + getitem_169 = _scaled_dot_product_cudnn_attention_18[7]; _scaled_dot_product_cudnn_attention_18 = None + permute_204 = torch.ops.aten.permute.default(getitem_162, [0, 2, 
1, 3]) + view_633 = torch.ops.aten.view.default(permute_204, [2, 8192, -1]); permute_204 = None + convert_element_type_611 = torch.ops.prims.convert_element_type.default(primals_170, torch.bfloat16) + all_gather_into_tensor_167 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_611, 64, '0'); convert_element_type_611 = None + wait_tensor_167 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_167); all_gather_into_tensor_167 = None + permute_205 = torch.ops.aten.permute.default(wait_tensor_167, [1, 0]); wait_tensor_167 = None + view_635 = torch.ops.aten.view.default(view_633, [16384, 4096]); view_633 = None + mm_129 = torch.ops.aten.mm.default(view_635, permute_205); view_635 = permute_205 = None + view_636 = torch.ops.aten.view.default(mm_129, [2, 8192, 4096]); mm_129 = None + add_73 = torch.ops.aten.add.Tensor(add_71, view_636); view_636 = None + convert_element_type_614 = torch.ops.prims.convert_element_type.default(primals_171, torch.bfloat16) + all_gather_into_tensor_168 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_614, 64, '0'); convert_element_type_614 = None + wait_tensor_168 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_168); all_gather_into_tensor_168 = None + convert_element_type_615 = torch.ops.prims.convert_element_type.default(add_73, torch.float32) + pow_38 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_615, 2) + mean_37 = torch.ops.aten.mean.dim(pow_38, [2], True); pow_38 = None + add_74 = torch.ops.aten.add.Scalar(mean_37, 1e-05); mean_37 = None + rsqrt_37 = torch.ops.aten.rsqrt.default(add_74); add_74 = None + mul_148 = torch.ops.aten.mul.Tensor(convert_element_type_615, rsqrt_37); convert_element_type_615 = rsqrt_37 = None + mul_149 = torch.ops.aten.mul.Tensor(mul_148, wait_tensor_168); mul_148 = wait_tensor_168 = None + convert_element_type_616 = torch.ops.prims.convert_element_type.default(mul_149, torch.bfloat16); mul_149 = 
None + convert_element_type_617 = torch.ops.prims.convert_element_type.default(primals_172, torch.bfloat16) + all_gather_into_tensor_169 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_617, 64, '0'); convert_element_type_617 = None + wait_tensor_169 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_169); all_gather_into_tensor_169 = None + permute_206 = torch.ops.aten.permute.default(wait_tensor_169, [1, 0]); wait_tensor_169 = None + view_639 = torch.ops.aten.view.default(convert_element_type_616, [16384, 4096]); convert_element_type_616 = None + mm_130 = torch.ops.aten.mm.default(view_639, permute_206); permute_206 = None + view_640 = torch.ops.aten.view.default(mm_130, [2, 8192, 14336]) + convert_element_type_620 = torch.ops.prims.convert_element_type.default(view_640, torch.float32); view_640 = None + sigmoid_18 = torch.ops.aten.sigmoid.default(convert_element_type_620) + mul_150 = torch.ops.aten.mul.Tensor(convert_element_type_620, sigmoid_18); convert_element_type_620 = sigmoid_18 = None + convert_element_type_621 = torch.ops.prims.convert_element_type.default(mul_150, torch.bfloat16); mul_150 = None + convert_element_type_622 = torch.ops.prims.convert_element_type.default(primals_173, torch.bfloat16) + all_gather_into_tensor_170 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_622, 64, '0'); convert_element_type_622 = None + wait_tensor_170 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_170); all_gather_into_tensor_170 = None + permute_207 = torch.ops.aten.permute.default(wait_tensor_170, [1, 0]); wait_tensor_170 = None + mm_131 = torch.ops.aten.mm.default(view_639, permute_207); view_639 = permute_207 = None + view_643 = torch.ops.aten.view.default(mm_131, [2, 8192, 14336]); mm_131 = None + mul_151 = torch.ops.aten.mul.Tensor(convert_element_type_621, view_643); convert_element_type_621 = view_643 = None + convert_element_type_625 = 
torch.ops.prims.convert_element_type.default(primals_174, torch.bfloat16) + all_gather_into_tensor_171 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_625, 64, '0'); convert_element_type_625 = None + wait_tensor_171 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_171); all_gather_into_tensor_171 = None + permute_208 = torch.ops.aten.permute.default(wait_tensor_171, [1, 0]); wait_tensor_171 = None + view_645 = torch.ops.aten.view.default(mul_151, [16384, 14336]); mul_151 = None + mm_132 = torch.ops.aten.mm.default(view_645, permute_208); view_645 = permute_208 = None + view_646 = torch.ops.aten.view.default(mm_132, [2, 8192, 4096]); mm_132 = None + add_75 = torch.ops.aten.add.Tensor(add_73, view_646); add_73 = view_646 = None + convert_element_type_628 = torch.ops.prims.convert_element_type.default(primals_175, torch.bfloat16) + all_gather_into_tensor_172 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_628, 64, '0'); convert_element_type_628 = None + wait_tensor_172 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_172); all_gather_into_tensor_172 = None + convert_element_type_629 = torch.ops.prims.convert_element_type.default(add_75, torch.float32) + pow_39 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_629, 2) + mean_38 = torch.ops.aten.mean.dim(pow_39, [2], True); pow_39 = None + add_76 = torch.ops.aten.add.Scalar(mean_38, 1e-05); mean_38 = None + rsqrt_38 = torch.ops.aten.rsqrt.default(add_76); add_76 = None + mul_152 = torch.ops.aten.mul.Tensor(convert_element_type_629, rsqrt_38); convert_element_type_629 = rsqrt_38 = None + mul_153 = torch.ops.aten.mul.Tensor(mul_152, wait_tensor_172); mul_152 = wait_tensor_172 = None + convert_element_type_630 = torch.ops.prims.convert_element_type.default(mul_153, torch.bfloat16); mul_153 = None + convert_element_type_631 = torch.ops.prims.convert_element_type.default(primals_176, torch.bfloat16) + 
all_gather_into_tensor_173 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_631, 64, '0'); convert_element_type_631 = None + wait_tensor_173 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_173); all_gather_into_tensor_173 = None + permute_209 = torch.ops.aten.permute.default(wait_tensor_173, [1, 0]); wait_tensor_173 = None + view_649 = torch.ops.aten.view.default(convert_element_type_630, [16384, 4096]); convert_element_type_630 = None + mm_133 = torch.ops.aten.mm.default(view_649, permute_209); permute_209 = None + view_650 = torch.ops.aten.view.default(mm_133, [2, 8192, 4096]) + convert_element_type_634 = torch.ops.prims.convert_element_type.default(primals_177, torch.bfloat16) + all_gather_into_tensor_174 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_634, 64, '0'); convert_element_type_634 = None + wait_tensor_174 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_174); all_gather_into_tensor_174 = None + permute_210 = torch.ops.aten.permute.default(wait_tensor_174, [1, 0]); wait_tensor_174 = None + mm_134 = torch.ops.aten.mm.default(view_649, permute_210); permute_210 = None + view_653 = torch.ops.aten.view.default(mm_134, [2, 8192, 1024]); mm_134 = None + convert_element_type_637 = torch.ops.prims.convert_element_type.default(primals_178, torch.bfloat16) + all_gather_into_tensor_175 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_637, 64, '0'); convert_element_type_637 = None + wait_tensor_175 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_175); all_gather_into_tensor_175 = None + permute_211 = torch.ops.aten.permute.default(wait_tensor_175, [1, 0]); wait_tensor_175 = None + mm_135 = torch.ops.aten.mm.default(view_649, permute_211); view_649 = permute_211 = None + view_656 = torch.ops.aten.view.default(mm_135, [2, 8192, 1024]) + view_657 = torch.ops.aten.view.default(view_650, [2, 8192, -1, 
128]); view_650 = None + view_658 = torch.ops.aten.view.default(view_653, [2, 8192, -1, 128]); view_653 = None + view_659 = torch.ops.aten.view.default(view_656, [2, 8192, -1, 128]); view_656 = None + convert_element_type_640 = torch.ops.prims.convert_element_type.default(view_657, torch.float32); view_657 = None + view_660 = torch.ops.aten.view.default(convert_element_type_640, [2, 8192, 32, -1, 2]); convert_element_type_640 = None + view_as_complex_38 = torch.ops.aten.view_as_complex.default(view_660); view_660 = None + convert_element_type_641 = torch.ops.prims.convert_element_type.default(view_658, torch.float32); view_658 = None + view_661 = torch.ops.aten.view.default(convert_element_type_641, [2, 8192, 8, -1, 2]); convert_element_type_641 = None + view_as_complex_39 = torch.ops.aten.view_as_complex.default(view_661); view_661 = None + mul_154 = torch.ops.aten.mul.Tensor(view_as_complex_38, view_16); view_as_complex_38 = None + view_as_real_38 = torch.ops.aten.view_as_real.default(mul_154); mul_154 = None + view_663 = torch.ops.aten.view.default(view_as_real_38, [2, 8192, 32, 128]); view_as_real_38 = None + mul_155 = torch.ops.aten.mul.Tensor(view_as_complex_39, view_16); view_as_complex_39 = None + view_as_real_39 = torch.ops.aten.view_as_real.default(mul_155); mul_155 = None + view_664 = torch.ops.aten.view.default(view_as_real_39, [2, 8192, 8, 128]); view_as_real_39 = None + convert_element_type_642 = torch.ops.prims.convert_element_type.default(view_663, torch.bfloat16); view_663 = None + convert_element_type_643 = torch.ops.prims.convert_element_type.default(view_664, torch.bfloat16); view_664 = None + unsqueeze_38 = torch.ops.aten.unsqueeze.default(convert_element_type_643, 3); convert_element_type_643 = None + expand_38 = torch.ops.aten.expand.default(unsqueeze_38, [2, 8192, 8, 4, 128]); unsqueeze_38 = None + clone_38 = torch.ops.aten.clone.default(expand_38, memory_format = torch.contiguous_format); expand_38 = None + view_665 = 
torch.ops.aten.view.default(clone_38, [2, 8192, 32, 128]); clone_38 = None + unsqueeze_39 = torch.ops.aten.unsqueeze.default(view_659, 3); view_659 = None + expand_39 = torch.ops.aten.expand.default(unsqueeze_39, [2, 8192, 8, 4, 128]); unsqueeze_39 = None + clone_39 = torch.ops.aten.clone.default(expand_39, memory_format = torch.contiguous_format); expand_39 = None + view_666 = torch.ops.aten.view.default(clone_39, [2, 8192, 32, 128]); clone_39 = None + permute_212 = torch.ops.aten.permute.default(convert_element_type_642, [0, 2, 1, 3]); convert_element_type_642 = None + permute_213 = torch.ops.aten.permute.default(view_665, [0, 2, 1, 3]); view_665 = None + permute_214 = torch.ops.aten.permute.default(view_666, [0, 2, 1, 3]); view_666 = None + _scaled_dot_product_cudnn_attention_19 = torch.ops.aten._scaled_dot_product_cudnn_attention.default(permute_212, permute_213, permute_214, None, True, 0.0, True); permute_212 = permute_213 = permute_214 = None + getitem_171 = _scaled_dot_product_cudnn_attention_19[0] + getitem_172 = _scaled_dot_product_cudnn_attention_19[1] + getitem_177 = _scaled_dot_product_cudnn_attention_19[6] + getitem_178 = _scaled_dot_product_cudnn_attention_19[7]; _scaled_dot_product_cudnn_attention_19 = None + permute_215 = torch.ops.aten.permute.default(getitem_171, [0, 2, 1, 3]) + view_667 = torch.ops.aten.view.default(permute_215, [2, 8192, -1]); permute_215 = None + convert_element_type_644 = torch.ops.prims.convert_element_type.default(primals_179, torch.bfloat16) + all_gather_into_tensor_176 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_644, 64, '0'); convert_element_type_644 = None + wait_tensor_176 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_176); all_gather_into_tensor_176 = None + permute_216 = torch.ops.aten.permute.default(wait_tensor_176, [1, 0]); wait_tensor_176 = None + view_669 = torch.ops.aten.view.default(view_667, [16384, 4096]); view_667 = None + mm_136 = 
torch.ops.aten.mm.default(view_669, permute_216); view_669 = permute_216 = None + view_670 = torch.ops.aten.view.default(mm_136, [2, 8192, 4096]); mm_136 = None + add_77 = torch.ops.aten.add.Tensor(add_75, view_670); view_670 = None + convert_element_type_647 = torch.ops.prims.convert_element_type.default(primals_180, torch.bfloat16) + all_gather_into_tensor_177 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_647, 64, '0'); convert_element_type_647 = None + wait_tensor_177 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_177); all_gather_into_tensor_177 = None + convert_element_type_648 = torch.ops.prims.convert_element_type.default(add_77, torch.float32) + pow_40 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_648, 2) + mean_39 = torch.ops.aten.mean.dim(pow_40, [2], True); pow_40 = None + add_78 = torch.ops.aten.add.Scalar(mean_39, 1e-05); mean_39 = None + rsqrt_39 = torch.ops.aten.rsqrt.default(add_78); add_78 = None + mul_156 = torch.ops.aten.mul.Tensor(convert_element_type_648, rsqrt_39); convert_element_type_648 = rsqrt_39 = None + mul_157 = torch.ops.aten.mul.Tensor(mul_156, wait_tensor_177); mul_156 = wait_tensor_177 = None + convert_element_type_649 = torch.ops.prims.convert_element_type.default(mul_157, torch.bfloat16); mul_157 = None + convert_element_type_650 = torch.ops.prims.convert_element_type.default(primals_181, torch.bfloat16) + all_gather_into_tensor_178 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_650, 64, '0'); convert_element_type_650 = None + wait_tensor_178 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_178); all_gather_into_tensor_178 = None + permute_217 = torch.ops.aten.permute.default(wait_tensor_178, [1, 0]); wait_tensor_178 = None + view_673 = torch.ops.aten.view.default(convert_element_type_649, [16384, 4096]); convert_element_type_649 = None + mm_137 = torch.ops.aten.mm.default(view_673, permute_217); 
permute_217 = None + view_674 = torch.ops.aten.view.default(mm_137, [2, 8192, 14336]) + convert_element_type_653 = torch.ops.prims.convert_element_type.default(view_674, torch.float32); view_674 = None + sigmoid_19 = torch.ops.aten.sigmoid.default(convert_element_type_653) + mul_158 = torch.ops.aten.mul.Tensor(convert_element_type_653, sigmoid_19); convert_element_type_653 = sigmoid_19 = None + convert_element_type_654 = torch.ops.prims.convert_element_type.default(mul_158, torch.bfloat16); mul_158 = None + convert_element_type_655 = torch.ops.prims.convert_element_type.default(primals_182, torch.bfloat16) + all_gather_into_tensor_179 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_655, 64, '0'); convert_element_type_655 = None + wait_tensor_179 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_179); all_gather_into_tensor_179 = None + permute_218 = torch.ops.aten.permute.default(wait_tensor_179, [1, 0]); wait_tensor_179 = None + mm_138 = torch.ops.aten.mm.default(view_673, permute_218); view_673 = permute_218 = None + view_677 = torch.ops.aten.view.default(mm_138, [2, 8192, 14336]); mm_138 = None + mul_159 = torch.ops.aten.mul.Tensor(convert_element_type_654, view_677); convert_element_type_654 = view_677 = None + convert_element_type_658 = torch.ops.prims.convert_element_type.default(primals_183, torch.bfloat16) + all_gather_into_tensor_180 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_658, 64, '0'); convert_element_type_658 = None + wait_tensor_180 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_180); all_gather_into_tensor_180 = None + permute_219 = torch.ops.aten.permute.default(wait_tensor_180, [1, 0]); wait_tensor_180 = None + view_679 = torch.ops.aten.view.default(mul_159, [16384, 14336]); mul_159 = None + mm_139 = torch.ops.aten.mm.default(view_679, permute_219); view_679 = permute_219 = None + view_680 = torch.ops.aten.view.default(mm_139, [2, 
8192, 4096]); mm_139 = None + add_79 = torch.ops.aten.add.Tensor(add_77, view_680); add_77 = view_680 = None + convert_element_type_661 = torch.ops.prims.convert_element_type.default(primals_184, torch.bfloat16) + all_gather_into_tensor_181 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_661, 64, '0'); convert_element_type_661 = None + wait_tensor_181 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_181); all_gather_into_tensor_181 = None + convert_element_type_662 = torch.ops.prims.convert_element_type.default(add_79, torch.float32) + pow_41 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_662, 2) + mean_40 = torch.ops.aten.mean.dim(pow_41, [2], True); pow_41 = None + add_80 = torch.ops.aten.add.Scalar(mean_40, 1e-05); mean_40 = None + rsqrt_40 = torch.ops.aten.rsqrt.default(add_80); add_80 = None + mul_160 = torch.ops.aten.mul.Tensor(convert_element_type_662, rsqrt_40); convert_element_type_662 = rsqrt_40 = None + mul_161 = torch.ops.aten.mul.Tensor(mul_160, wait_tensor_181); mul_160 = wait_tensor_181 = None + convert_element_type_663 = torch.ops.prims.convert_element_type.default(mul_161, torch.bfloat16); mul_161 = None + convert_element_type_664 = torch.ops.prims.convert_element_type.default(primals_185, torch.bfloat16) + all_gather_into_tensor_182 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_664, 64, '0'); convert_element_type_664 = None + wait_tensor_182 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_182); all_gather_into_tensor_182 = None + permute_220 = torch.ops.aten.permute.default(wait_tensor_182, [1, 0]); wait_tensor_182 = None + view_683 = torch.ops.aten.view.default(convert_element_type_663, [16384, 4096]); convert_element_type_663 = None + mm_140 = torch.ops.aten.mm.default(view_683, permute_220); permute_220 = None + view_684 = torch.ops.aten.view.default(mm_140, [2, 8192, 4096]) + convert_element_type_667 = 
torch.ops.prims.convert_element_type.default(primals_186, torch.bfloat16) + all_gather_into_tensor_183 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_667, 64, '0'); convert_element_type_667 = None + wait_tensor_183 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_183); all_gather_into_tensor_183 = None + permute_221 = torch.ops.aten.permute.default(wait_tensor_183, [1, 0]); wait_tensor_183 = None + mm_141 = torch.ops.aten.mm.default(view_683, permute_221); permute_221 = None + view_687 = torch.ops.aten.view.default(mm_141, [2, 8192, 1024]); mm_141 = None + convert_element_type_670 = torch.ops.prims.convert_element_type.default(primals_187, torch.bfloat16) + all_gather_into_tensor_184 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_670, 64, '0'); convert_element_type_670 = None + wait_tensor_184 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_184); all_gather_into_tensor_184 = None + permute_222 = torch.ops.aten.permute.default(wait_tensor_184, [1, 0]); wait_tensor_184 = None + mm_142 = torch.ops.aten.mm.default(view_683, permute_222); view_683 = permute_222 = None + view_690 = torch.ops.aten.view.default(mm_142, [2, 8192, 1024]) + view_691 = torch.ops.aten.view.default(view_684, [2, 8192, -1, 128]); view_684 = None + view_692 = torch.ops.aten.view.default(view_687, [2, 8192, -1, 128]); view_687 = None + view_693 = torch.ops.aten.view.default(view_690, [2, 8192, -1, 128]); view_690 = None + convert_element_type_673 = torch.ops.prims.convert_element_type.default(view_691, torch.float32); view_691 = None + view_694 = torch.ops.aten.view.default(convert_element_type_673, [2, 8192, 32, -1, 2]); convert_element_type_673 = None + view_as_complex_40 = torch.ops.aten.view_as_complex.default(view_694); view_694 = None + convert_element_type_674 = torch.ops.prims.convert_element_type.default(view_692, torch.float32); view_692 = None + view_695 = 
torch.ops.aten.view.default(convert_element_type_674, [2, 8192, 8, -1, 2]); convert_element_type_674 = None + view_as_complex_41 = torch.ops.aten.view_as_complex.default(view_695); view_695 = None + mul_162 = torch.ops.aten.mul.Tensor(view_as_complex_40, view_16); view_as_complex_40 = None + view_as_real_40 = torch.ops.aten.view_as_real.default(mul_162); mul_162 = None + view_697 = torch.ops.aten.view.default(view_as_real_40, [2, 8192, 32, 128]); view_as_real_40 = None + mul_163 = torch.ops.aten.mul.Tensor(view_as_complex_41, view_16); view_as_complex_41 = None + view_as_real_41 = torch.ops.aten.view_as_real.default(mul_163); mul_163 = None + view_698 = torch.ops.aten.view.default(view_as_real_41, [2, 8192, 8, 128]); view_as_real_41 = None + convert_element_type_675 = torch.ops.prims.convert_element_type.default(view_697, torch.bfloat16); view_697 = None + convert_element_type_676 = torch.ops.prims.convert_element_type.default(view_698, torch.bfloat16); view_698 = None + unsqueeze_40 = torch.ops.aten.unsqueeze.default(convert_element_type_676, 3); convert_element_type_676 = None + expand_40 = torch.ops.aten.expand.default(unsqueeze_40, [2, 8192, 8, 4, 128]); unsqueeze_40 = None + clone_40 = torch.ops.aten.clone.default(expand_40, memory_format = torch.contiguous_format); expand_40 = None + view_699 = torch.ops.aten.view.default(clone_40, [2, 8192, 32, 128]); clone_40 = None + unsqueeze_41 = torch.ops.aten.unsqueeze.default(view_693, 3); view_693 = None + expand_41 = torch.ops.aten.expand.default(unsqueeze_41, [2, 8192, 8, 4, 128]); unsqueeze_41 = None + clone_41 = torch.ops.aten.clone.default(expand_41, memory_format = torch.contiguous_format); expand_41 = None + view_700 = torch.ops.aten.view.default(clone_41, [2, 8192, 32, 128]); clone_41 = None + permute_223 = torch.ops.aten.permute.default(convert_element_type_675, [0, 2, 1, 3]); convert_element_type_675 = None + permute_224 = torch.ops.aten.permute.default(view_699, [0, 2, 1, 3]); view_699 = None + permute_225 
= torch.ops.aten.permute.default(view_700, [0, 2, 1, 3]); view_700 = None + _scaled_dot_product_cudnn_attention_20 = torch.ops.aten._scaled_dot_product_cudnn_attention.default(permute_223, permute_224, permute_225, None, True, 0.0, True); permute_223 = permute_224 = permute_225 = None + getitem_180 = _scaled_dot_product_cudnn_attention_20[0] + getitem_181 = _scaled_dot_product_cudnn_attention_20[1] + getitem_186 = _scaled_dot_product_cudnn_attention_20[6] + getitem_187 = _scaled_dot_product_cudnn_attention_20[7]; _scaled_dot_product_cudnn_attention_20 = None + permute_226 = torch.ops.aten.permute.default(getitem_180, [0, 2, 1, 3]) + view_701 = torch.ops.aten.view.default(permute_226, [2, 8192, -1]); permute_226 = None + convert_element_type_677 = torch.ops.prims.convert_element_type.default(primals_188, torch.bfloat16) + all_gather_into_tensor_185 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_677, 64, '0'); convert_element_type_677 = None + wait_tensor_185 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_185); all_gather_into_tensor_185 = None + permute_227 = torch.ops.aten.permute.default(wait_tensor_185, [1, 0]); wait_tensor_185 = None + view_703 = torch.ops.aten.view.default(view_701, [16384, 4096]); view_701 = None + mm_143 = torch.ops.aten.mm.default(view_703, permute_227); view_703 = permute_227 = None + view_704 = torch.ops.aten.view.default(mm_143, [2, 8192, 4096]); mm_143 = None + add_81 = torch.ops.aten.add.Tensor(add_79, view_704); view_704 = None + convert_element_type_680 = torch.ops.prims.convert_element_type.default(primals_189, torch.bfloat16) + all_gather_into_tensor_186 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_680, 64, '0'); convert_element_type_680 = None + wait_tensor_186 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_186); all_gather_into_tensor_186 = None + convert_element_type_681 = 
torch.ops.prims.convert_element_type.default(add_81, torch.float32) + pow_42 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_681, 2) + mean_41 = torch.ops.aten.mean.dim(pow_42, [2], True); pow_42 = None + add_82 = torch.ops.aten.add.Scalar(mean_41, 1e-05); mean_41 = None + rsqrt_41 = torch.ops.aten.rsqrt.default(add_82); add_82 = None + mul_164 = torch.ops.aten.mul.Tensor(convert_element_type_681, rsqrt_41); convert_element_type_681 = rsqrt_41 = None + mul_165 = torch.ops.aten.mul.Tensor(mul_164, wait_tensor_186); mul_164 = wait_tensor_186 = None + convert_element_type_682 = torch.ops.prims.convert_element_type.default(mul_165, torch.bfloat16); mul_165 = None + convert_element_type_683 = torch.ops.prims.convert_element_type.default(primals_190, torch.bfloat16) + all_gather_into_tensor_187 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_683, 64, '0'); convert_element_type_683 = None + wait_tensor_187 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_187); all_gather_into_tensor_187 = None + permute_228 = torch.ops.aten.permute.default(wait_tensor_187, [1, 0]); wait_tensor_187 = None + view_707 = torch.ops.aten.view.default(convert_element_type_682, [16384, 4096]); convert_element_type_682 = None + mm_144 = torch.ops.aten.mm.default(view_707, permute_228); permute_228 = None + view_708 = torch.ops.aten.view.default(mm_144, [2, 8192, 14336]) + convert_element_type_686 = torch.ops.prims.convert_element_type.default(view_708, torch.float32); view_708 = None + sigmoid_20 = torch.ops.aten.sigmoid.default(convert_element_type_686) + mul_166 = torch.ops.aten.mul.Tensor(convert_element_type_686, sigmoid_20); convert_element_type_686 = sigmoid_20 = None + convert_element_type_687 = torch.ops.prims.convert_element_type.default(mul_166, torch.bfloat16); mul_166 = None + convert_element_type_688 = torch.ops.prims.convert_element_type.default(primals_191, torch.bfloat16) + all_gather_into_tensor_188 = 
torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_688, 64, '0'); convert_element_type_688 = None + wait_tensor_188 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_188); all_gather_into_tensor_188 = None + permute_229 = torch.ops.aten.permute.default(wait_tensor_188, [1, 0]); wait_tensor_188 = None + mm_145 = torch.ops.aten.mm.default(view_707, permute_229); view_707 = permute_229 = None + view_711 = torch.ops.aten.view.default(mm_145, [2, 8192, 14336]); mm_145 = None + mul_167 = torch.ops.aten.mul.Tensor(convert_element_type_687, view_711); convert_element_type_687 = view_711 = None + convert_element_type_691 = torch.ops.prims.convert_element_type.default(primals_192, torch.bfloat16) + all_gather_into_tensor_189 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_691, 64, '0'); convert_element_type_691 = None + wait_tensor_189 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_189); all_gather_into_tensor_189 = None + permute_230 = torch.ops.aten.permute.default(wait_tensor_189, [1, 0]); wait_tensor_189 = None + view_713 = torch.ops.aten.view.default(mul_167, [16384, 14336]); mul_167 = None + mm_146 = torch.ops.aten.mm.default(view_713, permute_230); view_713 = permute_230 = None + view_714 = torch.ops.aten.view.default(mm_146, [2, 8192, 4096]); mm_146 = None + add_83 = torch.ops.aten.add.Tensor(add_81, view_714); add_81 = view_714 = None + convert_element_type_694 = torch.ops.prims.convert_element_type.default(primals_193, torch.bfloat16) + all_gather_into_tensor_190 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_694, 64, '0'); convert_element_type_694 = None + wait_tensor_190 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_190); all_gather_into_tensor_190 = None + convert_element_type_695 = torch.ops.prims.convert_element_type.default(add_83, torch.float32) + pow_43 = 
torch.ops.aten.pow.Tensor_Scalar(convert_element_type_695, 2) + mean_42 = torch.ops.aten.mean.dim(pow_43, [2], True); pow_43 = None + add_84 = torch.ops.aten.add.Scalar(mean_42, 1e-05); mean_42 = None + rsqrt_42 = torch.ops.aten.rsqrt.default(add_84); add_84 = None + mul_168 = torch.ops.aten.mul.Tensor(convert_element_type_695, rsqrt_42); convert_element_type_695 = rsqrt_42 = None + mul_169 = torch.ops.aten.mul.Tensor(mul_168, wait_tensor_190); mul_168 = wait_tensor_190 = None + convert_element_type_696 = torch.ops.prims.convert_element_type.default(mul_169, torch.bfloat16); mul_169 = None + convert_element_type_697 = torch.ops.prims.convert_element_type.default(primals_194, torch.bfloat16) + all_gather_into_tensor_191 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_697, 64, '0'); convert_element_type_697 = None + wait_tensor_191 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_191); all_gather_into_tensor_191 = None + permute_231 = torch.ops.aten.permute.default(wait_tensor_191, [1, 0]); wait_tensor_191 = None + view_717 = torch.ops.aten.view.default(convert_element_type_696, [16384, 4096]); convert_element_type_696 = None + mm_147 = torch.ops.aten.mm.default(view_717, permute_231); permute_231 = None + view_718 = torch.ops.aten.view.default(mm_147, [2, 8192, 4096]) + convert_element_type_700 = torch.ops.prims.convert_element_type.default(primals_195, torch.bfloat16) + all_gather_into_tensor_192 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_700, 64, '0'); convert_element_type_700 = None + wait_tensor_192 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_192); all_gather_into_tensor_192 = None + permute_232 = torch.ops.aten.permute.default(wait_tensor_192, [1, 0]); wait_tensor_192 = None + mm_148 = torch.ops.aten.mm.default(view_717, permute_232); permute_232 = None + view_721 = torch.ops.aten.view.default(mm_148, [2, 8192, 1024]); mm_148 = None + 
convert_element_type_703 = torch.ops.prims.convert_element_type.default(primals_196, torch.bfloat16) + all_gather_into_tensor_193 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_703, 64, '0'); convert_element_type_703 = None + wait_tensor_193 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_193); all_gather_into_tensor_193 = None + permute_233 = torch.ops.aten.permute.default(wait_tensor_193, [1, 0]); wait_tensor_193 = None + mm_149 = torch.ops.aten.mm.default(view_717, permute_233); view_717 = permute_233 = None + view_724 = torch.ops.aten.view.default(mm_149, [2, 8192, 1024]) + view_725 = torch.ops.aten.view.default(view_718, [2, 8192, -1, 128]); view_718 = None + view_726 = torch.ops.aten.view.default(view_721, [2, 8192, -1, 128]); view_721 = None + view_727 = torch.ops.aten.view.default(view_724, [2, 8192, -1, 128]); view_724 = None + convert_element_type_706 = torch.ops.prims.convert_element_type.default(view_725, torch.float32); view_725 = None + view_728 = torch.ops.aten.view.default(convert_element_type_706, [2, 8192, 32, -1, 2]); convert_element_type_706 = None + view_as_complex_42 = torch.ops.aten.view_as_complex.default(view_728); view_728 = None + convert_element_type_707 = torch.ops.prims.convert_element_type.default(view_726, torch.float32); view_726 = None + view_729 = torch.ops.aten.view.default(convert_element_type_707, [2, 8192, 8, -1, 2]); convert_element_type_707 = None + view_as_complex_43 = torch.ops.aten.view_as_complex.default(view_729); view_729 = None + mul_170 = torch.ops.aten.mul.Tensor(view_as_complex_42, view_16); view_as_complex_42 = None + view_as_real_42 = torch.ops.aten.view_as_real.default(mul_170); mul_170 = None + view_731 = torch.ops.aten.view.default(view_as_real_42, [2, 8192, 32, 128]); view_as_real_42 = None + mul_171 = torch.ops.aten.mul.Tensor(view_as_complex_43, view_16); view_as_complex_43 = None + view_as_real_43 = torch.ops.aten.view_as_real.default(mul_171); 
mul_171 = None + view_732 = torch.ops.aten.view.default(view_as_real_43, [2, 8192, 8, 128]); view_as_real_43 = None + convert_element_type_708 = torch.ops.prims.convert_element_type.default(view_731, torch.bfloat16); view_731 = None + convert_element_type_709 = torch.ops.prims.convert_element_type.default(view_732, torch.bfloat16); view_732 = None + unsqueeze_42 = torch.ops.aten.unsqueeze.default(convert_element_type_709, 3); convert_element_type_709 = None + expand_42 = torch.ops.aten.expand.default(unsqueeze_42, [2, 8192, 8, 4, 128]); unsqueeze_42 = None + clone_42 = torch.ops.aten.clone.default(expand_42, memory_format = torch.contiguous_format); expand_42 = None + view_733 = torch.ops.aten.view.default(clone_42, [2, 8192, 32, 128]); clone_42 = None + unsqueeze_43 = torch.ops.aten.unsqueeze.default(view_727, 3); view_727 = None + expand_43 = torch.ops.aten.expand.default(unsqueeze_43, [2, 8192, 8, 4, 128]); unsqueeze_43 = None + clone_43 = torch.ops.aten.clone.default(expand_43, memory_format = torch.contiguous_format); expand_43 = None + view_734 = torch.ops.aten.view.default(clone_43, [2, 8192, 32, 128]); clone_43 = None + permute_234 = torch.ops.aten.permute.default(convert_element_type_708, [0, 2, 1, 3]); convert_element_type_708 = None + permute_235 = torch.ops.aten.permute.default(view_733, [0, 2, 1, 3]); view_733 = None + permute_236 = torch.ops.aten.permute.default(view_734, [0, 2, 1, 3]); view_734 = None + _scaled_dot_product_cudnn_attention_21 = torch.ops.aten._scaled_dot_product_cudnn_attention.default(permute_234, permute_235, permute_236, None, True, 0.0, True); permute_234 = permute_235 = permute_236 = None + getitem_189 = _scaled_dot_product_cudnn_attention_21[0] + getitem_190 = _scaled_dot_product_cudnn_attention_21[1] + getitem_195 = _scaled_dot_product_cudnn_attention_21[6] + getitem_196 = _scaled_dot_product_cudnn_attention_21[7]; _scaled_dot_product_cudnn_attention_21 = None + permute_237 = torch.ops.aten.permute.default(getitem_189, [0, 2, 
1, 3]) + view_735 = torch.ops.aten.view.default(permute_237, [2, 8192, -1]); permute_237 = None + convert_element_type_710 = torch.ops.prims.convert_element_type.default(primals_197, torch.bfloat16) + all_gather_into_tensor_194 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_710, 64, '0'); convert_element_type_710 = None + wait_tensor_194 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_194); all_gather_into_tensor_194 = None + permute_238 = torch.ops.aten.permute.default(wait_tensor_194, [1, 0]); wait_tensor_194 = None + view_737 = torch.ops.aten.view.default(view_735, [16384, 4096]); view_735 = None + mm_150 = torch.ops.aten.mm.default(view_737, permute_238); view_737 = permute_238 = None + view_738 = torch.ops.aten.view.default(mm_150, [2, 8192, 4096]); mm_150 = None + add_85 = torch.ops.aten.add.Tensor(add_83, view_738); view_738 = None + convert_element_type_713 = torch.ops.prims.convert_element_type.default(primals_198, torch.bfloat16) + all_gather_into_tensor_195 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_713, 64, '0'); convert_element_type_713 = None + wait_tensor_195 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_195); all_gather_into_tensor_195 = None + convert_element_type_714 = torch.ops.prims.convert_element_type.default(add_85, torch.float32) + pow_44 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_714, 2) + mean_43 = torch.ops.aten.mean.dim(pow_44, [2], True); pow_44 = None + add_86 = torch.ops.aten.add.Scalar(mean_43, 1e-05); mean_43 = None + rsqrt_43 = torch.ops.aten.rsqrt.default(add_86); add_86 = None + mul_172 = torch.ops.aten.mul.Tensor(convert_element_type_714, rsqrt_43); convert_element_type_714 = rsqrt_43 = None + mul_173 = torch.ops.aten.mul.Tensor(mul_172, wait_tensor_195); mul_172 = wait_tensor_195 = None + convert_element_type_715 = torch.ops.prims.convert_element_type.default(mul_173, torch.bfloat16); mul_173 = 
None + convert_element_type_716 = torch.ops.prims.convert_element_type.default(primals_199, torch.bfloat16) + all_gather_into_tensor_196 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_716, 64, '0'); convert_element_type_716 = None + wait_tensor_196 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_196); all_gather_into_tensor_196 = None + permute_239 = torch.ops.aten.permute.default(wait_tensor_196, [1, 0]); wait_tensor_196 = None + view_741 = torch.ops.aten.view.default(convert_element_type_715, [16384, 4096]); convert_element_type_715 = None + mm_151 = torch.ops.aten.mm.default(view_741, permute_239); permute_239 = None + view_742 = torch.ops.aten.view.default(mm_151, [2, 8192, 14336]) + convert_element_type_719 = torch.ops.prims.convert_element_type.default(view_742, torch.float32); view_742 = None + sigmoid_21 = torch.ops.aten.sigmoid.default(convert_element_type_719) + mul_174 = torch.ops.aten.mul.Tensor(convert_element_type_719, sigmoid_21); convert_element_type_719 = sigmoid_21 = None + convert_element_type_720 = torch.ops.prims.convert_element_type.default(mul_174, torch.bfloat16); mul_174 = None + convert_element_type_721 = torch.ops.prims.convert_element_type.default(primals_200, torch.bfloat16) + all_gather_into_tensor_197 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_721, 64, '0'); convert_element_type_721 = None + wait_tensor_197 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_197); all_gather_into_tensor_197 = None + permute_240 = torch.ops.aten.permute.default(wait_tensor_197, [1, 0]); wait_tensor_197 = None + mm_152 = torch.ops.aten.mm.default(view_741, permute_240); view_741 = permute_240 = None + view_745 = torch.ops.aten.view.default(mm_152, [2, 8192, 14336]); mm_152 = None + mul_175 = torch.ops.aten.mul.Tensor(convert_element_type_720, view_745); convert_element_type_720 = view_745 = None + convert_element_type_724 = 
torch.ops.prims.convert_element_type.default(primals_201, torch.bfloat16) + all_gather_into_tensor_198 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_724, 64, '0'); convert_element_type_724 = None + wait_tensor_198 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_198); all_gather_into_tensor_198 = None + permute_241 = torch.ops.aten.permute.default(wait_tensor_198, [1, 0]); wait_tensor_198 = None + view_747 = torch.ops.aten.view.default(mul_175, [16384, 14336]); mul_175 = None + mm_153 = torch.ops.aten.mm.default(view_747, permute_241); view_747 = permute_241 = None + view_748 = torch.ops.aten.view.default(mm_153, [2, 8192, 4096]); mm_153 = None + add_87 = torch.ops.aten.add.Tensor(add_85, view_748); add_85 = view_748 = None + convert_element_type_727 = torch.ops.prims.convert_element_type.default(primals_202, torch.bfloat16) + all_gather_into_tensor_199 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_727, 64, '0'); convert_element_type_727 = None + wait_tensor_199 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_199); all_gather_into_tensor_199 = None + convert_element_type_728 = torch.ops.prims.convert_element_type.default(add_87, torch.float32) + pow_45 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_728, 2) + mean_44 = torch.ops.aten.mean.dim(pow_45, [2], True); pow_45 = None + add_88 = torch.ops.aten.add.Scalar(mean_44, 1e-05); mean_44 = None + rsqrt_44 = torch.ops.aten.rsqrt.default(add_88); add_88 = None + mul_176 = torch.ops.aten.mul.Tensor(convert_element_type_728, rsqrt_44); convert_element_type_728 = rsqrt_44 = None + mul_177 = torch.ops.aten.mul.Tensor(mul_176, wait_tensor_199); mul_176 = wait_tensor_199 = None + convert_element_type_729 = torch.ops.prims.convert_element_type.default(mul_177, torch.bfloat16); mul_177 = None + convert_element_type_730 = torch.ops.prims.convert_element_type.default(primals_203, torch.bfloat16) + 
all_gather_into_tensor_200 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_730, 64, '0'); convert_element_type_730 = None + wait_tensor_200 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_200); all_gather_into_tensor_200 = None + permute_242 = torch.ops.aten.permute.default(wait_tensor_200, [1, 0]); wait_tensor_200 = None + view_751 = torch.ops.aten.view.default(convert_element_type_729, [16384, 4096]); convert_element_type_729 = None + mm_154 = torch.ops.aten.mm.default(view_751, permute_242); permute_242 = None + view_752 = torch.ops.aten.view.default(mm_154, [2, 8192, 4096]) + convert_element_type_733 = torch.ops.prims.convert_element_type.default(primals_204, torch.bfloat16) + all_gather_into_tensor_201 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_733, 64, '0'); convert_element_type_733 = None + wait_tensor_201 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_201); all_gather_into_tensor_201 = None + permute_243 = torch.ops.aten.permute.default(wait_tensor_201, [1, 0]); wait_tensor_201 = None + mm_155 = torch.ops.aten.mm.default(view_751, permute_243); permute_243 = None + view_755 = torch.ops.aten.view.default(mm_155, [2, 8192, 1024]); mm_155 = None + convert_element_type_736 = torch.ops.prims.convert_element_type.default(primals_205, torch.bfloat16) + all_gather_into_tensor_202 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_736, 64, '0'); convert_element_type_736 = None + wait_tensor_202 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_202); all_gather_into_tensor_202 = None + permute_244 = torch.ops.aten.permute.default(wait_tensor_202, [1, 0]); wait_tensor_202 = None + mm_156 = torch.ops.aten.mm.default(view_751, permute_244); view_751 = permute_244 = None + view_758 = torch.ops.aten.view.default(mm_156, [2, 8192, 1024]) + view_759 = torch.ops.aten.view.default(view_752, [2, 8192, -1, 
128]); view_752 = None + view_760 = torch.ops.aten.view.default(view_755, [2, 8192, -1, 128]); view_755 = None + view_761 = torch.ops.aten.view.default(view_758, [2, 8192, -1, 128]); view_758 = None + convert_element_type_739 = torch.ops.prims.convert_element_type.default(view_759, torch.float32); view_759 = None + view_762 = torch.ops.aten.view.default(convert_element_type_739, [2, 8192, 32, -1, 2]); convert_element_type_739 = None + view_as_complex_44 = torch.ops.aten.view_as_complex.default(view_762); view_762 = None + convert_element_type_740 = torch.ops.prims.convert_element_type.default(view_760, torch.float32); view_760 = None + view_763 = torch.ops.aten.view.default(convert_element_type_740, [2, 8192, 8, -1, 2]); convert_element_type_740 = None + view_as_complex_45 = torch.ops.aten.view_as_complex.default(view_763); view_763 = None + mul_178 = torch.ops.aten.mul.Tensor(view_as_complex_44, view_16); view_as_complex_44 = None + view_as_real_44 = torch.ops.aten.view_as_real.default(mul_178); mul_178 = None + view_765 = torch.ops.aten.view.default(view_as_real_44, [2, 8192, 32, 128]); view_as_real_44 = None + mul_179 = torch.ops.aten.mul.Tensor(view_as_complex_45, view_16); view_as_complex_45 = None + view_as_real_45 = torch.ops.aten.view_as_real.default(mul_179); mul_179 = None + view_766 = torch.ops.aten.view.default(view_as_real_45, [2, 8192, 8, 128]); view_as_real_45 = None + convert_element_type_741 = torch.ops.prims.convert_element_type.default(view_765, torch.bfloat16); view_765 = None + convert_element_type_742 = torch.ops.prims.convert_element_type.default(view_766, torch.bfloat16); view_766 = None + unsqueeze_44 = torch.ops.aten.unsqueeze.default(convert_element_type_742, 3); convert_element_type_742 = None + expand_44 = torch.ops.aten.expand.default(unsqueeze_44, [2, 8192, 8, 4, 128]); unsqueeze_44 = None + clone_44 = torch.ops.aten.clone.default(expand_44, memory_format = torch.contiguous_format); expand_44 = None + view_767 = 
torch.ops.aten.view.default(clone_44, [2, 8192, 32, 128]); clone_44 = None + unsqueeze_45 = torch.ops.aten.unsqueeze.default(view_761, 3); view_761 = None + expand_45 = torch.ops.aten.expand.default(unsqueeze_45, [2, 8192, 8, 4, 128]); unsqueeze_45 = None + clone_45 = torch.ops.aten.clone.default(expand_45, memory_format = torch.contiguous_format); expand_45 = None + view_768 = torch.ops.aten.view.default(clone_45, [2, 8192, 32, 128]); clone_45 = None + permute_245 = torch.ops.aten.permute.default(convert_element_type_741, [0, 2, 1, 3]); convert_element_type_741 = None + permute_246 = torch.ops.aten.permute.default(view_767, [0, 2, 1, 3]); view_767 = None + permute_247 = torch.ops.aten.permute.default(view_768, [0, 2, 1, 3]); view_768 = None + _scaled_dot_product_cudnn_attention_22 = torch.ops.aten._scaled_dot_product_cudnn_attention.default(permute_245, permute_246, permute_247, None, True, 0.0, True); permute_245 = permute_246 = permute_247 = None + getitem_198 = _scaled_dot_product_cudnn_attention_22[0] + getitem_199 = _scaled_dot_product_cudnn_attention_22[1] + getitem_204 = _scaled_dot_product_cudnn_attention_22[6] + getitem_205 = _scaled_dot_product_cudnn_attention_22[7]; _scaled_dot_product_cudnn_attention_22 = None + permute_248 = torch.ops.aten.permute.default(getitem_198, [0, 2, 1, 3]) + view_769 = torch.ops.aten.view.default(permute_248, [2, 8192, -1]); permute_248 = None + convert_element_type_743 = torch.ops.prims.convert_element_type.default(primals_206, torch.bfloat16) + all_gather_into_tensor_203 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_743, 64, '0'); convert_element_type_743 = None + wait_tensor_203 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_203); all_gather_into_tensor_203 = None + permute_249 = torch.ops.aten.permute.default(wait_tensor_203, [1, 0]); wait_tensor_203 = None + view_771 = torch.ops.aten.view.default(view_769, [16384, 4096]); view_769 = None + mm_157 = 
torch.ops.aten.mm.default(view_771, permute_249); view_771 = permute_249 = None + view_772 = torch.ops.aten.view.default(mm_157, [2, 8192, 4096]); mm_157 = None + add_89 = torch.ops.aten.add.Tensor(add_87, view_772); view_772 = None + convert_element_type_746 = torch.ops.prims.convert_element_type.default(primals_207, torch.bfloat16) + all_gather_into_tensor_204 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_746, 64, '0'); convert_element_type_746 = None + wait_tensor_204 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_204); all_gather_into_tensor_204 = None + convert_element_type_747 = torch.ops.prims.convert_element_type.default(add_89, torch.float32) + pow_46 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_747, 2) + mean_45 = torch.ops.aten.mean.dim(pow_46, [2], True); pow_46 = None + add_90 = torch.ops.aten.add.Scalar(mean_45, 1e-05); mean_45 = None + rsqrt_45 = torch.ops.aten.rsqrt.default(add_90); add_90 = None + mul_180 = torch.ops.aten.mul.Tensor(convert_element_type_747, rsqrt_45); convert_element_type_747 = rsqrt_45 = None + mul_181 = torch.ops.aten.mul.Tensor(mul_180, wait_tensor_204); mul_180 = wait_tensor_204 = None + convert_element_type_748 = torch.ops.prims.convert_element_type.default(mul_181, torch.bfloat16); mul_181 = None + convert_element_type_749 = torch.ops.prims.convert_element_type.default(primals_208, torch.bfloat16) + all_gather_into_tensor_205 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_749, 64, '0'); convert_element_type_749 = None + wait_tensor_205 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_205); all_gather_into_tensor_205 = None + permute_250 = torch.ops.aten.permute.default(wait_tensor_205, [1, 0]); wait_tensor_205 = None + view_775 = torch.ops.aten.view.default(convert_element_type_748, [16384, 4096]); convert_element_type_748 = None + mm_158 = torch.ops.aten.mm.default(view_775, permute_250); 
permute_250 = None + view_776 = torch.ops.aten.view.default(mm_158, [2, 8192, 14336]) + convert_element_type_752 = torch.ops.prims.convert_element_type.default(view_776, torch.float32); view_776 = None + sigmoid_22 = torch.ops.aten.sigmoid.default(convert_element_type_752) + mul_182 = torch.ops.aten.mul.Tensor(convert_element_type_752, sigmoid_22); convert_element_type_752 = sigmoid_22 = None + convert_element_type_753 = torch.ops.prims.convert_element_type.default(mul_182, torch.bfloat16); mul_182 = None + convert_element_type_754 = torch.ops.prims.convert_element_type.default(primals_209, torch.bfloat16) + all_gather_into_tensor_206 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_754, 64, '0'); convert_element_type_754 = None + wait_tensor_206 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_206); all_gather_into_tensor_206 = None + permute_251 = torch.ops.aten.permute.default(wait_tensor_206, [1, 0]); wait_tensor_206 = None + mm_159 = torch.ops.aten.mm.default(view_775, permute_251); view_775 = permute_251 = None + view_779 = torch.ops.aten.view.default(mm_159, [2, 8192, 14336]); mm_159 = None + mul_183 = torch.ops.aten.mul.Tensor(convert_element_type_753, view_779); convert_element_type_753 = view_779 = None + convert_element_type_757 = torch.ops.prims.convert_element_type.default(primals_210, torch.bfloat16) + all_gather_into_tensor_207 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_757, 64, '0'); convert_element_type_757 = None + wait_tensor_207 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_207); all_gather_into_tensor_207 = None + permute_252 = torch.ops.aten.permute.default(wait_tensor_207, [1, 0]); wait_tensor_207 = None + view_781 = torch.ops.aten.view.default(mul_183, [16384, 14336]); mul_183 = None + mm_160 = torch.ops.aten.mm.default(view_781, permute_252); view_781 = permute_252 = None + view_782 = torch.ops.aten.view.default(mm_160, [2, 
8192, 4096]); mm_160 = None + add_91 = torch.ops.aten.add.Tensor(add_89, view_782); add_89 = view_782 = None + convert_element_type_760 = torch.ops.prims.convert_element_type.default(primals_211, torch.bfloat16) + all_gather_into_tensor_208 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_760, 64, '0'); convert_element_type_760 = None + wait_tensor_208 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_208); all_gather_into_tensor_208 = None + convert_element_type_761 = torch.ops.prims.convert_element_type.default(add_91, torch.float32) + pow_47 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_761, 2) + mean_46 = torch.ops.aten.mean.dim(pow_47, [2], True); pow_47 = None + add_92 = torch.ops.aten.add.Scalar(mean_46, 1e-05); mean_46 = None + rsqrt_46 = torch.ops.aten.rsqrt.default(add_92); add_92 = None + mul_184 = torch.ops.aten.mul.Tensor(convert_element_type_761, rsqrt_46); convert_element_type_761 = rsqrt_46 = None + mul_185 = torch.ops.aten.mul.Tensor(mul_184, wait_tensor_208); mul_184 = wait_tensor_208 = None + convert_element_type_762 = torch.ops.prims.convert_element_type.default(mul_185, torch.bfloat16); mul_185 = None + convert_element_type_763 = torch.ops.prims.convert_element_type.default(primals_212, torch.bfloat16) + all_gather_into_tensor_209 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_763, 64, '0'); convert_element_type_763 = None + wait_tensor_209 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_209); all_gather_into_tensor_209 = None + permute_253 = torch.ops.aten.permute.default(wait_tensor_209, [1, 0]); wait_tensor_209 = None + view_785 = torch.ops.aten.view.default(convert_element_type_762, [16384, 4096]); convert_element_type_762 = None + mm_161 = torch.ops.aten.mm.default(view_785, permute_253); permute_253 = None + view_786 = torch.ops.aten.view.default(mm_161, [2, 8192, 4096]) + convert_element_type_766 = 
torch.ops.prims.convert_element_type.default(primals_213, torch.bfloat16) + all_gather_into_tensor_210 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_766, 64, '0'); convert_element_type_766 = None + wait_tensor_210 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_210); all_gather_into_tensor_210 = None + permute_254 = torch.ops.aten.permute.default(wait_tensor_210, [1, 0]); wait_tensor_210 = None + mm_162 = torch.ops.aten.mm.default(view_785, permute_254); permute_254 = None + view_789 = torch.ops.aten.view.default(mm_162, [2, 8192, 1024]); mm_162 = None + convert_element_type_769 = torch.ops.prims.convert_element_type.default(primals_214, torch.bfloat16) + all_gather_into_tensor_211 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_769, 64, '0'); convert_element_type_769 = None + wait_tensor_211 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_211); all_gather_into_tensor_211 = None + permute_255 = torch.ops.aten.permute.default(wait_tensor_211, [1, 0]); wait_tensor_211 = None + mm_163 = torch.ops.aten.mm.default(view_785, permute_255); view_785 = permute_255 = None + view_792 = torch.ops.aten.view.default(mm_163, [2, 8192, 1024]) + view_793 = torch.ops.aten.view.default(view_786, [2, 8192, -1, 128]); view_786 = None + view_794 = torch.ops.aten.view.default(view_789, [2, 8192, -1, 128]); view_789 = None + view_795 = torch.ops.aten.view.default(view_792, [2, 8192, -1, 128]); view_792 = None + convert_element_type_772 = torch.ops.prims.convert_element_type.default(view_793, torch.float32); view_793 = None + view_796 = torch.ops.aten.view.default(convert_element_type_772, [2, 8192, 32, -1, 2]); convert_element_type_772 = None + view_as_complex_46 = torch.ops.aten.view_as_complex.default(view_796); view_796 = None + convert_element_type_773 = torch.ops.prims.convert_element_type.default(view_794, torch.float32); view_794 = None + view_797 = 
torch.ops.aten.view.default(convert_element_type_773, [2, 8192, 8, -1, 2]); convert_element_type_773 = None + view_as_complex_47 = torch.ops.aten.view_as_complex.default(view_797); view_797 = None + mul_186 = torch.ops.aten.mul.Tensor(view_as_complex_46, view_16); view_as_complex_46 = None + view_as_real_46 = torch.ops.aten.view_as_real.default(mul_186); mul_186 = None + view_799 = torch.ops.aten.view.default(view_as_real_46, [2, 8192, 32, 128]); view_as_real_46 = None + mul_187 = torch.ops.aten.mul.Tensor(view_as_complex_47, view_16); view_as_complex_47 = None + view_as_real_47 = torch.ops.aten.view_as_real.default(mul_187); mul_187 = None + view_800 = torch.ops.aten.view.default(view_as_real_47, [2, 8192, 8, 128]); view_as_real_47 = None + convert_element_type_774 = torch.ops.prims.convert_element_type.default(view_799, torch.bfloat16); view_799 = None + convert_element_type_775 = torch.ops.prims.convert_element_type.default(view_800, torch.bfloat16); view_800 = None + unsqueeze_46 = torch.ops.aten.unsqueeze.default(convert_element_type_775, 3); convert_element_type_775 = None + expand_46 = torch.ops.aten.expand.default(unsqueeze_46, [2, 8192, 8, 4, 128]); unsqueeze_46 = None + clone_46 = torch.ops.aten.clone.default(expand_46, memory_format = torch.contiguous_format); expand_46 = None + view_801 = torch.ops.aten.view.default(clone_46, [2, 8192, 32, 128]); clone_46 = None + unsqueeze_47 = torch.ops.aten.unsqueeze.default(view_795, 3); view_795 = None + expand_47 = torch.ops.aten.expand.default(unsqueeze_47, [2, 8192, 8, 4, 128]); unsqueeze_47 = None + clone_47 = torch.ops.aten.clone.default(expand_47, memory_format = torch.contiguous_format); expand_47 = None + view_802 = torch.ops.aten.view.default(clone_47, [2, 8192, 32, 128]); clone_47 = None + permute_256 = torch.ops.aten.permute.default(convert_element_type_774, [0, 2, 1, 3]); convert_element_type_774 = None + permute_257 = torch.ops.aten.permute.default(view_801, [0, 2, 1, 3]); view_801 = None + permute_258 
= torch.ops.aten.permute.default(view_802, [0, 2, 1, 3]); view_802 = None + _scaled_dot_product_cudnn_attention_23 = torch.ops.aten._scaled_dot_product_cudnn_attention.default(permute_256, permute_257, permute_258, None, True, 0.0, True); permute_256 = permute_257 = permute_258 = None + getitem_207 = _scaled_dot_product_cudnn_attention_23[0] + getitem_208 = _scaled_dot_product_cudnn_attention_23[1] + getitem_213 = _scaled_dot_product_cudnn_attention_23[6] + getitem_214 = _scaled_dot_product_cudnn_attention_23[7]; _scaled_dot_product_cudnn_attention_23 = None + permute_259 = torch.ops.aten.permute.default(getitem_207, [0, 2, 1, 3]) + view_803 = torch.ops.aten.view.default(permute_259, [2, 8192, -1]); permute_259 = None + convert_element_type_776 = torch.ops.prims.convert_element_type.default(primals_215, torch.bfloat16) + all_gather_into_tensor_212 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_776, 64, '0'); convert_element_type_776 = None + wait_tensor_212 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_212); all_gather_into_tensor_212 = None + permute_260 = torch.ops.aten.permute.default(wait_tensor_212, [1, 0]); wait_tensor_212 = None + view_805 = torch.ops.aten.view.default(view_803, [16384, 4096]); view_803 = None + mm_164 = torch.ops.aten.mm.default(view_805, permute_260); view_805 = permute_260 = None + view_806 = torch.ops.aten.view.default(mm_164, [2, 8192, 4096]); mm_164 = None + add_93 = torch.ops.aten.add.Tensor(add_91, view_806); view_806 = None + convert_element_type_779 = torch.ops.prims.convert_element_type.default(primals_216, torch.bfloat16) + all_gather_into_tensor_213 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_779, 64, '0'); convert_element_type_779 = None + wait_tensor_213 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_213); all_gather_into_tensor_213 = None + convert_element_type_780 = 
torch.ops.prims.convert_element_type.default(add_93, torch.float32) + pow_48 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_780, 2) + mean_47 = torch.ops.aten.mean.dim(pow_48, [2], True); pow_48 = None + add_94 = torch.ops.aten.add.Scalar(mean_47, 1e-05); mean_47 = None + rsqrt_47 = torch.ops.aten.rsqrt.default(add_94); add_94 = None + mul_188 = torch.ops.aten.mul.Tensor(convert_element_type_780, rsqrt_47); convert_element_type_780 = rsqrt_47 = None + mul_189 = torch.ops.aten.mul.Tensor(mul_188, wait_tensor_213); mul_188 = wait_tensor_213 = None + convert_element_type_781 = torch.ops.prims.convert_element_type.default(mul_189, torch.bfloat16); mul_189 = None + convert_element_type_782 = torch.ops.prims.convert_element_type.default(primals_217, torch.bfloat16) + all_gather_into_tensor_214 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_782, 64, '0'); convert_element_type_782 = None + wait_tensor_214 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_214); all_gather_into_tensor_214 = None + permute_261 = torch.ops.aten.permute.default(wait_tensor_214, [1, 0]); wait_tensor_214 = None + view_809 = torch.ops.aten.view.default(convert_element_type_781, [16384, 4096]); convert_element_type_781 = None + mm_165 = torch.ops.aten.mm.default(view_809, permute_261); permute_261 = None + view_810 = torch.ops.aten.view.default(mm_165, [2, 8192, 14336]) + convert_element_type_785 = torch.ops.prims.convert_element_type.default(view_810, torch.float32); view_810 = None + sigmoid_23 = torch.ops.aten.sigmoid.default(convert_element_type_785) + mul_190 = torch.ops.aten.mul.Tensor(convert_element_type_785, sigmoid_23); convert_element_type_785 = sigmoid_23 = None + convert_element_type_786 = torch.ops.prims.convert_element_type.default(mul_190, torch.bfloat16); mul_190 = None + convert_element_type_787 = torch.ops.prims.convert_element_type.default(primals_218, torch.bfloat16) + all_gather_into_tensor_215 = 
torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_787, 64, '0'); convert_element_type_787 = None + wait_tensor_215 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_215); all_gather_into_tensor_215 = None + permute_262 = torch.ops.aten.permute.default(wait_tensor_215, [1, 0]); wait_tensor_215 = None + mm_166 = torch.ops.aten.mm.default(view_809, permute_262); view_809 = permute_262 = None + view_813 = torch.ops.aten.view.default(mm_166, [2, 8192, 14336]); mm_166 = None + mul_191 = torch.ops.aten.mul.Tensor(convert_element_type_786, view_813); convert_element_type_786 = view_813 = None + convert_element_type_790 = torch.ops.prims.convert_element_type.default(primals_219, torch.bfloat16) + all_gather_into_tensor_216 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_790, 64, '0'); convert_element_type_790 = None + wait_tensor_216 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_216); all_gather_into_tensor_216 = None + permute_263 = torch.ops.aten.permute.default(wait_tensor_216, [1, 0]); wait_tensor_216 = None + view_815 = torch.ops.aten.view.default(mul_191, [16384, 14336]); mul_191 = None + mm_167 = torch.ops.aten.mm.default(view_815, permute_263); view_815 = permute_263 = None + view_816 = torch.ops.aten.view.default(mm_167, [2, 8192, 4096]); mm_167 = None + add_95 = torch.ops.aten.add.Tensor(add_93, view_816); add_93 = view_816 = None + convert_element_type_793 = torch.ops.prims.convert_element_type.default(primals_220, torch.bfloat16) + all_gather_into_tensor_217 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_793, 64, '0'); convert_element_type_793 = None + wait_tensor_217 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_217); all_gather_into_tensor_217 = None + convert_element_type_794 = torch.ops.prims.convert_element_type.default(add_95, torch.float32) + pow_49 = 
torch.ops.aten.pow.Tensor_Scalar(convert_element_type_794, 2) + mean_48 = torch.ops.aten.mean.dim(pow_49, [2], True); pow_49 = None + add_96 = torch.ops.aten.add.Scalar(mean_48, 1e-05); mean_48 = None + rsqrt_48 = torch.ops.aten.rsqrt.default(add_96); add_96 = None + mul_192 = torch.ops.aten.mul.Tensor(convert_element_type_794, rsqrt_48); convert_element_type_794 = rsqrt_48 = None + mul_193 = torch.ops.aten.mul.Tensor(mul_192, wait_tensor_217); mul_192 = wait_tensor_217 = None + convert_element_type_795 = torch.ops.prims.convert_element_type.default(mul_193, torch.bfloat16); mul_193 = None + convert_element_type_796 = torch.ops.prims.convert_element_type.default(primals_221, torch.bfloat16) + all_gather_into_tensor_218 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_796, 64, '0'); convert_element_type_796 = None + wait_tensor_218 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_218); all_gather_into_tensor_218 = None + permute_264 = torch.ops.aten.permute.default(wait_tensor_218, [1, 0]); wait_tensor_218 = None + view_819 = torch.ops.aten.view.default(convert_element_type_795, [16384, 4096]); convert_element_type_795 = None + mm_168 = torch.ops.aten.mm.default(view_819, permute_264); permute_264 = None + view_820 = torch.ops.aten.view.default(mm_168, [2, 8192, 4096]) + convert_element_type_799 = torch.ops.prims.convert_element_type.default(primals_222, torch.bfloat16) + all_gather_into_tensor_219 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_799, 64, '0'); convert_element_type_799 = None + wait_tensor_219 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_219); all_gather_into_tensor_219 = None + permute_265 = torch.ops.aten.permute.default(wait_tensor_219, [1, 0]); wait_tensor_219 = None + mm_169 = torch.ops.aten.mm.default(view_819, permute_265); permute_265 = None + view_823 = torch.ops.aten.view.default(mm_169, [2, 8192, 1024]); mm_169 = None + 
convert_element_type_802 = torch.ops.prims.convert_element_type.default(primals_223, torch.bfloat16) + all_gather_into_tensor_220 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_802, 64, '0'); convert_element_type_802 = None + wait_tensor_220 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_220); all_gather_into_tensor_220 = None + permute_266 = torch.ops.aten.permute.default(wait_tensor_220, [1, 0]); wait_tensor_220 = None + mm_170 = torch.ops.aten.mm.default(view_819, permute_266); view_819 = permute_266 = None + view_826 = torch.ops.aten.view.default(mm_170, [2, 8192, 1024]) + view_827 = torch.ops.aten.view.default(view_820, [2, 8192, -1, 128]); view_820 = None + view_828 = torch.ops.aten.view.default(view_823, [2, 8192, -1, 128]); view_823 = None + view_829 = torch.ops.aten.view.default(view_826, [2, 8192, -1, 128]); view_826 = None + convert_element_type_805 = torch.ops.prims.convert_element_type.default(view_827, torch.float32); view_827 = None + view_830 = torch.ops.aten.view.default(convert_element_type_805, [2, 8192, 32, -1, 2]); convert_element_type_805 = None + view_as_complex_48 = torch.ops.aten.view_as_complex.default(view_830); view_830 = None + convert_element_type_806 = torch.ops.prims.convert_element_type.default(view_828, torch.float32); view_828 = None + view_831 = torch.ops.aten.view.default(convert_element_type_806, [2, 8192, 8, -1, 2]); convert_element_type_806 = None + view_as_complex_49 = torch.ops.aten.view_as_complex.default(view_831); view_831 = None + mul_194 = torch.ops.aten.mul.Tensor(view_as_complex_48, view_16); view_as_complex_48 = None + view_as_real_48 = torch.ops.aten.view_as_real.default(mul_194); mul_194 = None + view_833 = torch.ops.aten.view.default(view_as_real_48, [2, 8192, 32, 128]); view_as_real_48 = None + mul_195 = torch.ops.aten.mul.Tensor(view_as_complex_49, view_16); view_as_complex_49 = None + view_as_real_49 = torch.ops.aten.view_as_real.default(mul_195); 
mul_195 = None + view_834 = torch.ops.aten.view.default(view_as_real_49, [2, 8192, 8, 128]); view_as_real_49 = None + convert_element_type_807 = torch.ops.prims.convert_element_type.default(view_833, torch.bfloat16); view_833 = None + convert_element_type_808 = torch.ops.prims.convert_element_type.default(view_834, torch.bfloat16); view_834 = None + unsqueeze_48 = torch.ops.aten.unsqueeze.default(convert_element_type_808, 3); convert_element_type_808 = None + expand_48 = torch.ops.aten.expand.default(unsqueeze_48, [2, 8192, 8, 4, 128]); unsqueeze_48 = None + clone_48 = torch.ops.aten.clone.default(expand_48, memory_format = torch.contiguous_format); expand_48 = None + view_835 = torch.ops.aten.view.default(clone_48, [2, 8192, 32, 128]); clone_48 = None + unsqueeze_49 = torch.ops.aten.unsqueeze.default(view_829, 3); view_829 = None + expand_49 = torch.ops.aten.expand.default(unsqueeze_49, [2, 8192, 8, 4, 128]); unsqueeze_49 = None + clone_49 = torch.ops.aten.clone.default(expand_49, memory_format = torch.contiguous_format); expand_49 = None + view_836 = torch.ops.aten.view.default(clone_49, [2, 8192, 32, 128]); clone_49 = None + permute_267 = torch.ops.aten.permute.default(convert_element_type_807, [0, 2, 1, 3]); convert_element_type_807 = None + permute_268 = torch.ops.aten.permute.default(view_835, [0, 2, 1, 3]); view_835 = None + permute_269 = torch.ops.aten.permute.default(view_836, [0, 2, 1, 3]); view_836 = None + _scaled_dot_product_cudnn_attention_24 = torch.ops.aten._scaled_dot_product_cudnn_attention.default(permute_267, permute_268, permute_269, None, True, 0.0, True); permute_267 = permute_268 = permute_269 = None + getitem_216 = _scaled_dot_product_cudnn_attention_24[0] + getitem_217 = _scaled_dot_product_cudnn_attention_24[1] + getitem_222 = _scaled_dot_product_cudnn_attention_24[6] + getitem_223 = _scaled_dot_product_cudnn_attention_24[7]; _scaled_dot_product_cudnn_attention_24 = None + permute_270 = torch.ops.aten.permute.default(getitem_216, [0, 2, 
1, 3]) + view_837 = torch.ops.aten.view.default(permute_270, [2, 8192, -1]); permute_270 = None + convert_element_type_809 = torch.ops.prims.convert_element_type.default(primals_224, torch.bfloat16) + all_gather_into_tensor_221 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_809, 64, '0'); convert_element_type_809 = None + wait_tensor_221 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_221); all_gather_into_tensor_221 = None + permute_271 = torch.ops.aten.permute.default(wait_tensor_221, [1, 0]); wait_tensor_221 = None + view_839 = torch.ops.aten.view.default(view_837, [16384, 4096]); view_837 = None + mm_171 = torch.ops.aten.mm.default(view_839, permute_271); view_839 = permute_271 = None + view_840 = torch.ops.aten.view.default(mm_171, [2, 8192, 4096]); mm_171 = None + add_97 = torch.ops.aten.add.Tensor(add_95, view_840); view_840 = None + convert_element_type_812 = torch.ops.prims.convert_element_type.default(primals_225, torch.bfloat16) + all_gather_into_tensor_222 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_812, 64, '0'); convert_element_type_812 = None + wait_tensor_222 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_222); all_gather_into_tensor_222 = None + convert_element_type_813 = torch.ops.prims.convert_element_type.default(add_97, torch.float32) + pow_50 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_813, 2) + mean_49 = torch.ops.aten.mean.dim(pow_50, [2], True); pow_50 = None + add_98 = torch.ops.aten.add.Scalar(mean_49, 1e-05); mean_49 = None + rsqrt_49 = torch.ops.aten.rsqrt.default(add_98); add_98 = None + mul_196 = torch.ops.aten.mul.Tensor(convert_element_type_813, rsqrt_49); convert_element_type_813 = rsqrt_49 = None + mul_197 = torch.ops.aten.mul.Tensor(mul_196, wait_tensor_222); mul_196 = wait_tensor_222 = None + convert_element_type_814 = torch.ops.prims.convert_element_type.default(mul_197, torch.bfloat16); mul_197 = 
None + convert_element_type_815 = torch.ops.prims.convert_element_type.default(primals_226, torch.bfloat16) + all_gather_into_tensor_223 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_815, 64, '0'); convert_element_type_815 = None + wait_tensor_223 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_223); all_gather_into_tensor_223 = None + permute_272 = torch.ops.aten.permute.default(wait_tensor_223, [1, 0]); wait_tensor_223 = None + view_843 = torch.ops.aten.view.default(convert_element_type_814, [16384, 4096]); convert_element_type_814 = None + mm_172 = torch.ops.aten.mm.default(view_843, permute_272); permute_272 = None + view_844 = torch.ops.aten.view.default(mm_172, [2, 8192, 14336]) + convert_element_type_818 = torch.ops.prims.convert_element_type.default(view_844, torch.float32); view_844 = None + sigmoid_24 = torch.ops.aten.sigmoid.default(convert_element_type_818) + mul_198 = torch.ops.aten.mul.Tensor(convert_element_type_818, sigmoid_24); convert_element_type_818 = sigmoid_24 = None + convert_element_type_819 = torch.ops.prims.convert_element_type.default(mul_198, torch.bfloat16); mul_198 = None + convert_element_type_820 = torch.ops.prims.convert_element_type.default(primals_227, torch.bfloat16) + all_gather_into_tensor_224 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_820, 64, '0'); convert_element_type_820 = None + wait_tensor_224 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_224); all_gather_into_tensor_224 = None + permute_273 = torch.ops.aten.permute.default(wait_tensor_224, [1, 0]); wait_tensor_224 = None + mm_173 = torch.ops.aten.mm.default(view_843, permute_273); view_843 = permute_273 = None + view_847 = torch.ops.aten.view.default(mm_173, [2, 8192, 14336]); mm_173 = None + mul_199 = torch.ops.aten.mul.Tensor(convert_element_type_819, view_847); convert_element_type_819 = view_847 = None + convert_element_type_823 = 
torch.ops.prims.convert_element_type.default(primals_228, torch.bfloat16) + all_gather_into_tensor_225 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_823, 64, '0'); convert_element_type_823 = None + wait_tensor_225 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_225); all_gather_into_tensor_225 = None + permute_274 = torch.ops.aten.permute.default(wait_tensor_225, [1, 0]); wait_tensor_225 = None + view_849 = torch.ops.aten.view.default(mul_199, [16384, 14336]); mul_199 = None + mm_174 = torch.ops.aten.mm.default(view_849, permute_274); view_849 = permute_274 = None + view_850 = torch.ops.aten.view.default(mm_174, [2, 8192, 4096]); mm_174 = None + add_99 = torch.ops.aten.add.Tensor(add_97, view_850); add_97 = view_850 = None + convert_element_type_826 = torch.ops.prims.convert_element_type.default(primals_229, torch.bfloat16) + all_gather_into_tensor_226 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_826, 64, '0'); convert_element_type_826 = None + wait_tensor_226 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_226); all_gather_into_tensor_226 = None + convert_element_type_827 = torch.ops.prims.convert_element_type.default(add_99, torch.float32) + pow_51 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_827, 2) + mean_50 = torch.ops.aten.mean.dim(pow_51, [2], True); pow_51 = None + add_100 = torch.ops.aten.add.Scalar(mean_50, 1e-05); mean_50 = None + rsqrt_50 = torch.ops.aten.rsqrt.default(add_100); add_100 = None + mul_200 = torch.ops.aten.mul.Tensor(convert_element_type_827, rsqrt_50); convert_element_type_827 = rsqrt_50 = None + mul_201 = torch.ops.aten.mul.Tensor(mul_200, wait_tensor_226); mul_200 = wait_tensor_226 = None + convert_element_type_828 = torch.ops.prims.convert_element_type.default(mul_201, torch.bfloat16); mul_201 = None + convert_element_type_829 = torch.ops.prims.convert_element_type.default(primals_230, torch.bfloat16) + 
all_gather_into_tensor_227 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_829, 64, '0'); convert_element_type_829 = None + wait_tensor_227 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_227); all_gather_into_tensor_227 = None + permute_275 = torch.ops.aten.permute.default(wait_tensor_227, [1, 0]); wait_tensor_227 = None + view_853 = torch.ops.aten.view.default(convert_element_type_828, [16384, 4096]); convert_element_type_828 = None + mm_175 = torch.ops.aten.mm.default(view_853, permute_275); permute_275 = None + view_854 = torch.ops.aten.view.default(mm_175, [2, 8192, 4096]) + convert_element_type_832 = torch.ops.prims.convert_element_type.default(primals_231, torch.bfloat16) + all_gather_into_tensor_228 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_832, 64, '0'); convert_element_type_832 = None + wait_tensor_228 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_228); all_gather_into_tensor_228 = None + permute_276 = torch.ops.aten.permute.default(wait_tensor_228, [1, 0]); wait_tensor_228 = None + mm_176 = torch.ops.aten.mm.default(view_853, permute_276); permute_276 = None + view_857 = torch.ops.aten.view.default(mm_176, [2, 8192, 1024]); mm_176 = None + convert_element_type_835 = torch.ops.prims.convert_element_type.default(primals_232, torch.bfloat16) + all_gather_into_tensor_229 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_835, 64, '0'); convert_element_type_835 = None + wait_tensor_229 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_229); all_gather_into_tensor_229 = None + permute_277 = torch.ops.aten.permute.default(wait_tensor_229, [1, 0]); wait_tensor_229 = None + mm_177 = torch.ops.aten.mm.default(view_853, permute_277); view_853 = permute_277 = None + view_860 = torch.ops.aten.view.default(mm_177, [2, 8192, 1024]) + view_861 = torch.ops.aten.view.default(view_854, [2, 8192, -1, 
128]); view_854 = None + view_862 = torch.ops.aten.view.default(view_857, [2, 8192, -1, 128]); view_857 = None + view_863 = torch.ops.aten.view.default(view_860, [2, 8192, -1, 128]); view_860 = None + convert_element_type_838 = torch.ops.prims.convert_element_type.default(view_861, torch.float32); view_861 = None + view_864 = torch.ops.aten.view.default(convert_element_type_838, [2, 8192, 32, -1, 2]); convert_element_type_838 = None + view_as_complex_50 = torch.ops.aten.view_as_complex.default(view_864); view_864 = None + convert_element_type_839 = torch.ops.prims.convert_element_type.default(view_862, torch.float32); view_862 = None + view_865 = torch.ops.aten.view.default(convert_element_type_839, [2, 8192, 8, -1, 2]); convert_element_type_839 = None + view_as_complex_51 = torch.ops.aten.view_as_complex.default(view_865); view_865 = None + mul_202 = torch.ops.aten.mul.Tensor(view_as_complex_50, view_16); view_as_complex_50 = None + view_as_real_50 = torch.ops.aten.view_as_real.default(mul_202); mul_202 = None + view_867 = torch.ops.aten.view.default(view_as_real_50, [2, 8192, 32, 128]); view_as_real_50 = None + mul_203 = torch.ops.aten.mul.Tensor(view_as_complex_51, view_16); view_as_complex_51 = None + view_as_real_51 = torch.ops.aten.view_as_real.default(mul_203); mul_203 = None + view_868 = torch.ops.aten.view.default(view_as_real_51, [2, 8192, 8, 128]); view_as_real_51 = None + convert_element_type_840 = torch.ops.prims.convert_element_type.default(view_867, torch.bfloat16); view_867 = None + convert_element_type_841 = torch.ops.prims.convert_element_type.default(view_868, torch.bfloat16); view_868 = None + unsqueeze_50 = torch.ops.aten.unsqueeze.default(convert_element_type_841, 3); convert_element_type_841 = None + expand_50 = torch.ops.aten.expand.default(unsqueeze_50, [2, 8192, 8, 4, 128]); unsqueeze_50 = None + clone_50 = torch.ops.aten.clone.default(expand_50, memory_format = torch.contiguous_format); expand_50 = None + view_869 = 
torch.ops.aten.view.default(clone_50, [2, 8192, 32, 128]); clone_50 = None + unsqueeze_51 = torch.ops.aten.unsqueeze.default(view_863, 3); view_863 = None + expand_51 = torch.ops.aten.expand.default(unsqueeze_51, [2, 8192, 8, 4, 128]); unsqueeze_51 = None + clone_51 = torch.ops.aten.clone.default(expand_51, memory_format = torch.contiguous_format); expand_51 = None + view_870 = torch.ops.aten.view.default(clone_51, [2, 8192, 32, 128]); clone_51 = None + permute_278 = torch.ops.aten.permute.default(convert_element_type_840, [0, 2, 1, 3]); convert_element_type_840 = None + permute_279 = torch.ops.aten.permute.default(view_869, [0, 2, 1, 3]); view_869 = None + permute_280 = torch.ops.aten.permute.default(view_870, [0, 2, 1, 3]); view_870 = None + _scaled_dot_product_cudnn_attention_25 = torch.ops.aten._scaled_dot_product_cudnn_attention.default(permute_278, permute_279, permute_280, None, True, 0.0, True); permute_278 = permute_279 = permute_280 = None + getitem_225 = _scaled_dot_product_cudnn_attention_25[0] + getitem_226 = _scaled_dot_product_cudnn_attention_25[1] + getitem_231 = _scaled_dot_product_cudnn_attention_25[6] + getitem_232 = _scaled_dot_product_cudnn_attention_25[7]; _scaled_dot_product_cudnn_attention_25 = None + permute_281 = torch.ops.aten.permute.default(getitem_225, [0, 2, 1, 3]) + view_871 = torch.ops.aten.view.default(permute_281, [2, 8192, -1]); permute_281 = None + convert_element_type_842 = torch.ops.prims.convert_element_type.default(primals_233, torch.bfloat16) + all_gather_into_tensor_230 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_842, 64, '0'); convert_element_type_842 = None + wait_tensor_230 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_230); all_gather_into_tensor_230 = None + permute_282 = torch.ops.aten.permute.default(wait_tensor_230, [1, 0]); wait_tensor_230 = None + view_873 = torch.ops.aten.view.default(view_871, [16384, 4096]); view_871 = None + mm_178 = 
torch.ops.aten.mm.default(view_873, permute_282); view_873 = permute_282 = None + view_874 = torch.ops.aten.view.default(mm_178, [2, 8192, 4096]); mm_178 = None + add_101 = torch.ops.aten.add.Tensor(add_99, view_874); view_874 = None + convert_element_type_845 = torch.ops.prims.convert_element_type.default(primals_234, torch.bfloat16) + all_gather_into_tensor_231 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_845, 64, '0'); convert_element_type_845 = None + wait_tensor_231 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_231); all_gather_into_tensor_231 = None + convert_element_type_846 = torch.ops.prims.convert_element_type.default(add_101, torch.float32) + pow_52 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_846, 2) + mean_51 = torch.ops.aten.mean.dim(pow_52, [2], True); pow_52 = None + add_102 = torch.ops.aten.add.Scalar(mean_51, 1e-05); mean_51 = None + rsqrt_51 = torch.ops.aten.rsqrt.default(add_102); add_102 = None + mul_204 = torch.ops.aten.mul.Tensor(convert_element_type_846, rsqrt_51); convert_element_type_846 = rsqrt_51 = None + mul_205 = torch.ops.aten.mul.Tensor(mul_204, wait_tensor_231); mul_204 = wait_tensor_231 = None + convert_element_type_847 = torch.ops.prims.convert_element_type.default(mul_205, torch.bfloat16); mul_205 = None + convert_element_type_848 = torch.ops.prims.convert_element_type.default(primals_235, torch.bfloat16) + all_gather_into_tensor_232 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_848, 64, '0'); convert_element_type_848 = None + wait_tensor_232 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_232); all_gather_into_tensor_232 = None + permute_283 = torch.ops.aten.permute.default(wait_tensor_232, [1, 0]); wait_tensor_232 = None + view_877 = torch.ops.aten.view.default(convert_element_type_847, [16384, 4096]); convert_element_type_847 = None + mm_179 = torch.ops.aten.mm.default(view_877, permute_283); 
permute_283 = None + view_878 = torch.ops.aten.view.default(mm_179, [2, 8192, 14336]) + convert_element_type_851 = torch.ops.prims.convert_element_type.default(view_878, torch.float32); view_878 = None + sigmoid_25 = torch.ops.aten.sigmoid.default(convert_element_type_851) + mul_206 = torch.ops.aten.mul.Tensor(convert_element_type_851, sigmoid_25); convert_element_type_851 = sigmoid_25 = None + convert_element_type_852 = torch.ops.prims.convert_element_type.default(mul_206, torch.bfloat16); mul_206 = None + convert_element_type_853 = torch.ops.prims.convert_element_type.default(primals_236, torch.bfloat16) + all_gather_into_tensor_233 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_853, 64, '0'); convert_element_type_853 = None + wait_tensor_233 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_233); all_gather_into_tensor_233 = None + permute_284 = torch.ops.aten.permute.default(wait_tensor_233, [1, 0]); wait_tensor_233 = None + mm_180 = torch.ops.aten.mm.default(view_877, permute_284); view_877 = permute_284 = None + view_881 = torch.ops.aten.view.default(mm_180, [2, 8192, 14336]); mm_180 = None + mul_207 = torch.ops.aten.mul.Tensor(convert_element_type_852, view_881); convert_element_type_852 = view_881 = None + convert_element_type_856 = torch.ops.prims.convert_element_type.default(primals_237, torch.bfloat16) + all_gather_into_tensor_234 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_856, 64, '0'); convert_element_type_856 = None + wait_tensor_234 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_234); all_gather_into_tensor_234 = None + permute_285 = torch.ops.aten.permute.default(wait_tensor_234, [1, 0]); wait_tensor_234 = None + view_883 = torch.ops.aten.view.default(mul_207, [16384, 14336]); mul_207 = None + mm_181 = torch.ops.aten.mm.default(view_883, permute_285); view_883 = permute_285 = None + view_884 = torch.ops.aten.view.default(mm_181, [2, 
8192, 4096]); mm_181 = None + add_103 = torch.ops.aten.add.Tensor(add_101, view_884); add_101 = view_884 = None + convert_element_type_859 = torch.ops.prims.convert_element_type.default(primals_238, torch.bfloat16) + all_gather_into_tensor_235 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_859, 64, '0'); convert_element_type_859 = None + wait_tensor_235 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_235); all_gather_into_tensor_235 = None + convert_element_type_860 = torch.ops.prims.convert_element_type.default(add_103, torch.float32) + pow_53 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_860, 2) + mean_52 = torch.ops.aten.mean.dim(pow_53, [2], True); pow_53 = None + add_104 = torch.ops.aten.add.Scalar(mean_52, 1e-05); mean_52 = None + rsqrt_52 = torch.ops.aten.rsqrt.default(add_104); add_104 = None + mul_208 = torch.ops.aten.mul.Tensor(convert_element_type_860, rsqrt_52); convert_element_type_860 = rsqrt_52 = None + mul_209 = torch.ops.aten.mul.Tensor(mul_208, wait_tensor_235); mul_208 = wait_tensor_235 = None + convert_element_type_861 = torch.ops.prims.convert_element_type.default(mul_209, torch.bfloat16); mul_209 = None + convert_element_type_862 = torch.ops.prims.convert_element_type.default(primals_239, torch.bfloat16) + all_gather_into_tensor_236 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_862, 64, '0'); convert_element_type_862 = None + wait_tensor_236 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_236); all_gather_into_tensor_236 = None + permute_286 = torch.ops.aten.permute.default(wait_tensor_236, [1, 0]); wait_tensor_236 = None + view_887 = torch.ops.aten.view.default(convert_element_type_861, [16384, 4096]); convert_element_type_861 = None + mm_182 = torch.ops.aten.mm.default(view_887, permute_286); permute_286 = None + view_888 = torch.ops.aten.view.default(mm_182, [2, 8192, 4096]) + convert_element_type_865 = 
torch.ops.prims.convert_element_type.default(primals_240, torch.bfloat16) + all_gather_into_tensor_237 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_865, 64, '0'); convert_element_type_865 = None + wait_tensor_237 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_237); all_gather_into_tensor_237 = None + permute_287 = torch.ops.aten.permute.default(wait_tensor_237, [1, 0]); wait_tensor_237 = None + mm_183 = torch.ops.aten.mm.default(view_887, permute_287); permute_287 = None + view_891 = torch.ops.aten.view.default(mm_183, [2, 8192, 1024]); mm_183 = None + convert_element_type_868 = torch.ops.prims.convert_element_type.default(primals_241, torch.bfloat16) + all_gather_into_tensor_238 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_868, 64, '0'); convert_element_type_868 = None + wait_tensor_238 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_238); all_gather_into_tensor_238 = None + permute_288 = torch.ops.aten.permute.default(wait_tensor_238, [1, 0]); wait_tensor_238 = None + mm_184 = torch.ops.aten.mm.default(view_887, permute_288); view_887 = permute_288 = None + view_894 = torch.ops.aten.view.default(mm_184, [2, 8192, 1024]) + view_895 = torch.ops.aten.view.default(view_888, [2, 8192, -1, 128]); view_888 = None + view_896 = torch.ops.aten.view.default(view_891, [2, 8192, -1, 128]); view_891 = None + view_897 = torch.ops.aten.view.default(view_894, [2, 8192, -1, 128]); view_894 = None + convert_element_type_871 = torch.ops.prims.convert_element_type.default(view_895, torch.float32); view_895 = None + view_898 = torch.ops.aten.view.default(convert_element_type_871, [2, 8192, 32, -1, 2]); convert_element_type_871 = None + view_as_complex_52 = torch.ops.aten.view_as_complex.default(view_898); view_898 = None + convert_element_type_872 = torch.ops.prims.convert_element_type.default(view_896, torch.float32); view_896 = None + view_899 = 
torch.ops.aten.view.default(convert_element_type_872, [2, 8192, 8, -1, 2]); convert_element_type_872 = None + view_as_complex_53 = torch.ops.aten.view_as_complex.default(view_899); view_899 = None + mul_210 = torch.ops.aten.mul.Tensor(view_as_complex_52, view_16); view_as_complex_52 = None + view_as_real_52 = torch.ops.aten.view_as_real.default(mul_210); mul_210 = None + view_901 = torch.ops.aten.view.default(view_as_real_52, [2, 8192, 32, 128]); view_as_real_52 = None + mul_211 = torch.ops.aten.mul.Tensor(view_as_complex_53, view_16); view_as_complex_53 = None + view_as_real_53 = torch.ops.aten.view_as_real.default(mul_211); mul_211 = None + view_902 = torch.ops.aten.view.default(view_as_real_53, [2, 8192, 8, 128]); view_as_real_53 = None + convert_element_type_873 = torch.ops.prims.convert_element_type.default(view_901, torch.bfloat16); view_901 = None + convert_element_type_874 = torch.ops.prims.convert_element_type.default(view_902, torch.bfloat16); view_902 = None + unsqueeze_52 = torch.ops.aten.unsqueeze.default(convert_element_type_874, 3); convert_element_type_874 = None + expand_52 = torch.ops.aten.expand.default(unsqueeze_52, [2, 8192, 8, 4, 128]); unsqueeze_52 = None + clone_52 = torch.ops.aten.clone.default(expand_52, memory_format = torch.contiguous_format); expand_52 = None + view_903 = torch.ops.aten.view.default(clone_52, [2, 8192, 32, 128]); clone_52 = None + unsqueeze_53 = torch.ops.aten.unsqueeze.default(view_897, 3); view_897 = None + expand_53 = torch.ops.aten.expand.default(unsqueeze_53, [2, 8192, 8, 4, 128]); unsqueeze_53 = None + clone_53 = torch.ops.aten.clone.default(expand_53, memory_format = torch.contiguous_format); expand_53 = None + view_904 = torch.ops.aten.view.default(clone_53, [2, 8192, 32, 128]); clone_53 = None + permute_289 = torch.ops.aten.permute.default(convert_element_type_873, [0, 2, 1, 3]); convert_element_type_873 = None + permute_290 = torch.ops.aten.permute.default(view_903, [0, 2, 1, 3]); view_903 = None + permute_291 
= torch.ops.aten.permute.default(view_904, [0, 2, 1, 3]); view_904 = None + _scaled_dot_product_cudnn_attention_26 = torch.ops.aten._scaled_dot_product_cudnn_attention.default(permute_289, permute_290, permute_291, None, True, 0.0, True); permute_289 = permute_290 = permute_291 = None + getitem_234 = _scaled_dot_product_cudnn_attention_26[0] + getitem_235 = _scaled_dot_product_cudnn_attention_26[1] + getitem_240 = _scaled_dot_product_cudnn_attention_26[6] + getitem_241 = _scaled_dot_product_cudnn_attention_26[7]; _scaled_dot_product_cudnn_attention_26 = None + permute_292 = torch.ops.aten.permute.default(getitem_234, [0, 2, 1, 3]) + view_905 = torch.ops.aten.view.default(permute_292, [2, 8192, -1]); permute_292 = None + convert_element_type_875 = torch.ops.prims.convert_element_type.default(primals_242, torch.bfloat16) + all_gather_into_tensor_239 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_875, 64, '0'); convert_element_type_875 = None + wait_tensor_239 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_239); all_gather_into_tensor_239 = None + permute_293 = torch.ops.aten.permute.default(wait_tensor_239, [1, 0]); wait_tensor_239 = None + view_907 = torch.ops.aten.view.default(view_905, [16384, 4096]); view_905 = None + mm_185 = torch.ops.aten.mm.default(view_907, permute_293); view_907 = permute_293 = None + view_908 = torch.ops.aten.view.default(mm_185, [2, 8192, 4096]); mm_185 = None + add_105 = torch.ops.aten.add.Tensor(add_103, view_908); view_908 = None + convert_element_type_878 = torch.ops.prims.convert_element_type.default(primals_243, torch.bfloat16) + all_gather_into_tensor_240 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_878, 64, '0'); convert_element_type_878 = None + wait_tensor_240 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_240); all_gather_into_tensor_240 = None + convert_element_type_879 = 
torch.ops.prims.convert_element_type.default(add_105, torch.float32) + pow_54 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_879, 2) + mean_53 = torch.ops.aten.mean.dim(pow_54, [2], True); pow_54 = None + add_106 = torch.ops.aten.add.Scalar(mean_53, 1e-05); mean_53 = None + rsqrt_53 = torch.ops.aten.rsqrt.default(add_106); add_106 = None + mul_212 = torch.ops.aten.mul.Tensor(convert_element_type_879, rsqrt_53); convert_element_type_879 = rsqrt_53 = None + mul_213 = torch.ops.aten.mul.Tensor(mul_212, wait_tensor_240); mul_212 = wait_tensor_240 = None + convert_element_type_880 = torch.ops.prims.convert_element_type.default(mul_213, torch.bfloat16); mul_213 = None + convert_element_type_881 = torch.ops.prims.convert_element_type.default(primals_244, torch.bfloat16) + all_gather_into_tensor_241 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_881, 64, '0'); convert_element_type_881 = None + wait_tensor_241 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_241); all_gather_into_tensor_241 = None + permute_294 = torch.ops.aten.permute.default(wait_tensor_241, [1, 0]); wait_tensor_241 = None + view_911 = torch.ops.aten.view.default(convert_element_type_880, [16384, 4096]); convert_element_type_880 = None + mm_186 = torch.ops.aten.mm.default(view_911, permute_294); permute_294 = None + view_912 = torch.ops.aten.view.default(mm_186, [2, 8192, 14336]) + convert_element_type_884 = torch.ops.prims.convert_element_type.default(view_912, torch.float32); view_912 = None + sigmoid_26 = torch.ops.aten.sigmoid.default(convert_element_type_884) + mul_214 = torch.ops.aten.mul.Tensor(convert_element_type_884, sigmoid_26); convert_element_type_884 = sigmoid_26 = None + convert_element_type_885 = torch.ops.prims.convert_element_type.default(mul_214, torch.bfloat16); mul_214 = None + convert_element_type_886 = torch.ops.prims.convert_element_type.default(primals_245, torch.bfloat16) + all_gather_into_tensor_242 = 
torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_886, 64, '0'); convert_element_type_886 = None + wait_tensor_242 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_242); all_gather_into_tensor_242 = None + permute_295 = torch.ops.aten.permute.default(wait_tensor_242, [1, 0]); wait_tensor_242 = None + mm_187 = torch.ops.aten.mm.default(view_911, permute_295); view_911 = permute_295 = None + view_915 = torch.ops.aten.view.default(mm_187, [2, 8192, 14336]); mm_187 = None + mul_215 = torch.ops.aten.mul.Tensor(convert_element_type_885, view_915); convert_element_type_885 = view_915 = None + convert_element_type_889 = torch.ops.prims.convert_element_type.default(primals_246, torch.bfloat16) + all_gather_into_tensor_243 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_889, 64, '0'); convert_element_type_889 = None + wait_tensor_243 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_243); all_gather_into_tensor_243 = None + permute_296 = torch.ops.aten.permute.default(wait_tensor_243, [1, 0]); wait_tensor_243 = None + view_917 = torch.ops.aten.view.default(mul_215, [16384, 14336]); mul_215 = None + mm_188 = torch.ops.aten.mm.default(view_917, permute_296); view_917 = permute_296 = None + view_918 = torch.ops.aten.view.default(mm_188, [2, 8192, 4096]); mm_188 = None + add_107 = torch.ops.aten.add.Tensor(add_105, view_918); add_105 = view_918 = None + convert_element_type_892 = torch.ops.prims.convert_element_type.default(primals_247, torch.bfloat16) + all_gather_into_tensor_244 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_892, 64, '0'); convert_element_type_892 = None + wait_tensor_244 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_244); all_gather_into_tensor_244 = None + convert_element_type_893 = torch.ops.prims.convert_element_type.default(add_107, torch.float32) + pow_55 = 
torch.ops.aten.pow.Tensor_Scalar(convert_element_type_893, 2) + mean_54 = torch.ops.aten.mean.dim(pow_55, [2], True); pow_55 = None + add_108 = torch.ops.aten.add.Scalar(mean_54, 1e-05); mean_54 = None + rsqrt_54 = torch.ops.aten.rsqrt.default(add_108); add_108 = None + mul_216 = torch.ops.aten.mul.Tensor(convert_element_type_893, rsqrt_54); convert_element_type_893 = rsqrt_54 = None + mul_217 = torch.ops.aten.mul.Tensor(mul_216, wait_tensor_244); mul_216 = wait_tensor_244 = None + convert_element_type_894 = torch.ops.prims.convert_element_type.default(mul_217, torch.bfloat16); mul_217 = None + convert_element_type_895 = torch.ops.prims.convert_element_type.default(primals_248, torch.bfloat16) + all_gather_into_tensor_245 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_895, 64, '0'); convert_element_type_895 = None + wait_tensor_245 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_245); all_gather_into_tensor_245 = None + permute_297 = torch.ops.aten.permute.default(wait_tensor_245, [1, 0]); wait_tensor_245 = None + view_921 = torch.ops.aten.view.default(convert_element_type_894, [16384, 4096]); convert_element_type_894 = None + mm_189 = torch.ops.aten.mm.default(view_921, permute_297); permute_297 = None + view_922 = torch.ops.aten.view.default(mm_189, [2, 8192, 4096]) + convert_element_type_898 = torch.ops.prims.convert_element_type.default(primals_249, torch.bfloat16) + all_gather_into_tensor_246 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_898, 64, '0'); convert_element_type_898 = None + wait_tensor_246 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_246); all_gather_into_tensor_246 = None + permute_298 = torch.ops.aten.permute.default(wait_tensor_246, [1, 0]); wait_tensor_246 = None + mm_190 = torch.ops.aten.mm.default(view_921, permute_298); permute_298 = None + view_925 = torch.ops.aten.view.default(mm_190, [2, 8192, 1024]); mm_190 = None + 
convert_element_type_901 = torch.ops.prims.convert_element_type.default(primals_250, torch.bfloat16) + all_gather_into_tensor_247 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_901, 64, '0'); convert_element_type_901 = None + wait_tensor_247 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_247); all_gather_into_tensor_247 = None + permute_299 = torch.ops.aten.permute.default(wait_tensor_247, [1, 0]); wait_tensor_247 = None + mm_191 = torch.ops.aten.mm.default(view_921, permute_299); view_921 = permute_299 = None + view_928 = torch.ops.aten.view.default(mm_191, [2, 8192, 1024]) + view_929 = torch.ops.aten.view.default(view_922, [2, 8192, -1, 128]); view_922 = None + view_930 = torch.ops.aten.view.default(view_925, [2, 8192, -1, 128]); view_925 = None + view_931 = torch.ops.aten.view.default(view_928, [2, 8192, -1, 128]); view_928 = None + convert_element_type_904 = torch.ops.prims.convert_element_type.default(view_929, torch.float32); view_929 = None + view_932 = torch.ops.aten.view.default(convert_element_type_904, [2, 8192, 32, -1, 2]); convert_element_type_904 = None + view_as_complex_54 = torch.ops.aten.view_as_complex.default(view_932); view_932 = None + convert_element_type_905 = torch.ops.prims.convert_element_type.default(view_930, torch.float32); view_930 = None + view_933 = torch.ops.aten.view.default(convert_element_type_905, [2, 8192, 8, -1, 2]); convert_element_type_905 = None + view_as_complex_55 = torch.ops.aten.view_as_complex.default(view_933); view_933 = None + mul_218 = torch.ops.aten.mul.Tensor(view_as_complex_54, view_16); view_as_complex_54 = None + view_as_real_54 = torch.ops.aten.view_as_real.default(mul_218); mul_218 = None + view_935 = torch.ops.aten.view.default(view_as_real_54, [2, 8192, 32, 128]); view_as_real_54 = None + mul_219 = torch.ops.aten.mul.Tensor(view_as_complex_55, view_16); view_as_complex_55 = None + view_as_real_55 = torch.ops.aten.view_as_real.default(mul_219); 
mul_219 = None + view_936 = torch.ops.aten.view.default(view_as_real_55, [2, 8192, 8, 128]); view_as_real_55 = None + convert_element_type_906 = torch.ops.prims.convert_element_type.default(view_935, torch.bfloat16); view_935 = None + convert_element_type_907 = torch.ops.prims.convert_element_type.default(view_936, torch.bfloat16); view_936 = None + unsqueeze_54 = torch.ops.aten.unsqueeze.default(convert_element_type_907, 3); convert_element_type_907 = None + expand_54 = torch.ops.aten.expand.default(unsqueeze_54, [2, 8192, 8, 4, 128]); unsqueeze_54 = None + clone_54 = torch.ops.aten.clone.default(expand_54, memory_format = torch.contiguous_format); expand_54 = None + view_937 = torch.ops.aten.view.default(clone_54, [2, 8192, 32, 128]); clone_54 = None + unsqueeze_55 = torch.ops.aten.unsqueeze.default(view_931, 3); view_931 = None + expand_55 = torch.ops.aten.expand.default(unsqueeze_55, [2, 8192, 8, 4, 128]); unsqueeze_55 = None + clone_55 = torch.ops.aten.clone.default(expand_55, memory_format = torch.contiguous_format); expand_55 = None + view_938 = torch.ops.aten.view.default(clone_55, [2, 8192, 32, 128]); clone_55 = None + permute_300 = torch.ops.aten.permute.default(convert_element_type_906, [0, 2, 1, 3]); convert_element_type_906 = None + permute_301 = torch.ops.aten.permute.default(view_937, [0, 2, 1, 3]); view_937 = None + permute_302 = torch.ops.aten.permute.default(view_938, [0, 2, 1, 3]); view_938 = None + _scaled_dot_product_cudnn_attention_27 = torch.ops.aten._scaled_dot_product_cudnn_attention.default(permute_300, permute_301, permute_302, None, True, 0.0, True); permute_300 = permute_301 = permute_302 = None + getitem_243 = _scaled_dot_product_cudnn_attention_27[0] + getitem_244 = _scaled_dot_product_cudnn_attention_27[1] + getitem_249 = _scaled_dot_product_cudnn_attention_27[6] + getitem_250 = _scaled_dot_product_cudnn_attention_27[7]; _scaled_dot_product_cudnn_attention_27 = None + permute_303 = torch.ops.aten.permute.default(getitem_243, [0, 2, 
1, 3]) + view_939 = torch.ops.aten.view.default(permute_303, [2, 8192, -1]); permute_303 = None + convert_element_type_908 = torch.ops.prims.convert_element_type.default(primals_251, torch.bfloat16) + all_gather_into_tensor_248 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_908, 64, '0'); convert_element_type_908 = None + wait_tensor_248 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_248); all_gather_into_tensor_248 = None + permute_304 = torch.ops.aten.permute.default(wait_tensor_248, [1, 0]); wait_tensor_248 = None + view_941 = torch.ops.aten.view.default(view_939, [16384, 4096]); view_939 = None + mm_192 = torch.ops.aten.mm.default(view_941, permute_304); view_941 = permute_304 = None + view_942 = torch.ops.aten.view.default(mm_192, [2, 8192, 4096]); mm_192 = None + add_109 = torch.ops.aten.add.Tensor(add_107, view_942); view_942 = None + convert_element_type_911 = torch.ops.prims.convert_element_type.default(primals_252, torch.bfloat16) + all_gather_into_tensor_249 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_911, 64, '0'); convert_element_type_911 = None + wait_tensor_249 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_249); all_gather_into_tensor_249 = None + convert_element_type_912 = torch.ops.prims.convert_element_type.default(add_109, torch.float32) + pow_56 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_912, 2) + mean_55 = torch.ops.aten.mean.dim(pow_56, [2], True); pow_56 = None + add_110 = torch.ops.aten.add.Scalar(mean_55, 1e-05); mean_55 = None + rsqrt_55 = torch.ops.aten.rsqrt.default(add_110); add_110 = None + mul_220 = torch.ops.aten.mul.Tensor(convert_element_type_912, rsqrt_55); convert_element_type_912 = rsqrt_55 = None + mul_221 = torch.ops.aten.mul.Tensor(mul_220, wait_tensor_249); mul_220 = wait_tensor_249 = None + convert_element_type_913 = torch.ops.prims.convert_element_type.default(mul_221, torch.bfloat16); 
mul_221 = None + convert_element_type_914 = torch.ops.prims.convert_element_type.default(primals_253, torch.bfloat16) + all_gather_into_tensor_250 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_914, 64, '0'); convert_element_type_914 = None + wait_tensor_250 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_250); all_gather_into_tensor_250 = None + permute_305 = torch.ops.aten.permute.default(wait_tensor_250, [1, 0]); wait_tensor_250 = None + view_945 = torch.ops.aten.view.default(convert_element_type_913, [16384, 4096]); convert_element_type_913 = None + mm_193 = torch.ops.aten.mm.default(view_945, permute_305); permute_305 = None + view_946 = torch.ops.aten.view.default(mm_193, [2, 8192, 14336]) + convert_element_type_917 = torch.ops.prims.convert_element_type.default(view_946, torch.float32); view_946 = None + sigmoid_27 = torch.ops.aten.sigmoid.default(convert_element_type_917) + mul_222 = torch.ops.aten.mul.Tensor(convert_element_type_917, sigmoid_27); convert_element_type_917 = sigmoid_27 = None + convert_element_type_918 = torch.ops.prims.convert_element_type.default(mul_222, torch.bfloat16); mul_222 = None + convert_element_type_919 = torch.ops.prims.convert_element_type.default(primals_254, torch.bfloat16) + all_gather_into_tensor_251 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_919, 64, '0'); convert_element_type_919 = None + wait_tensor_251 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_251); all_gather_into_tensor_251 = None + permute_306 = torch.ops.aten.permute.default(wait_tensor_251, [1, 0]); wait_tensor_251 = None + mm_194 = torch.ops.aten.mm.default(view_945, permute_306); view_945 = permute_306 = None + view_949 = torch.ops.aten.view.default(mm_194, [2, 8192, 14336]); mm_194 = None + mul_223 = torch.ops.aten.mul.Tensor(convert_element_type_918, view_949); convert_element_type_918 = view_949 = None + convert_element_type_922 = 
torch.ops.prims.convert_element_type.default(primals_255, torch.bfloat16) + all_gather_into_tensor_252 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_922, 64, '0'); convert_element_type_922 = None + wait_tensor_252 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_252); all_gather_into_tensor_252 = None + permute_307 = torch.ops.aten.permute.default(wait_tensor_252, [1, 0]); wait_tensor_252 = None + view_951 = torch.ops.aten.view.default(mul_223, [16384, 14336]); mul_223 = None + mm_195 = torch.ops.aten.mm.default(view_951, permute_307); view_951 = permute_307 = None + view_952 = torch.ops.aten.view.default(mm_195, [2, 8192, 4096]); mm_195 = None + add_111 = torch.ops.aten.add.Tensor(add_109, view_952); add_109 = view_952 = None + convert_element_type_925 = torch.ops.prims.convert_element_type.default(primals_256, torch.bfloat16) + all_gather_into_tensor_253 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_925, 64, '0'); convert_element_type_925 = None + wait_tensor_253 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_253); all_gather_into_tensor_253 = None + convert_element_type_926 = torch.ops.prims.convert_element_type.default(add_111, torch.float32) + pow_57 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_926, 2) + mean_56 = torch.ops.aten.mean.dim(pow_57, [2], True); pow_57 = None + add_112 = torch.ops.aten.add.Scalar(mean_56, 1e-05); mean_56 = None + rsqrt_56 = torch.ops.aten.rsqrt.default(add_112); add_112 = None + mul_224 = torch.ops.aten.mul.Tensor(convert_element_type_926, rsqrt_56); convert_element_type_926 = rsqrt_56 = None + mul_225 = torch.ops.aten.mul.Tensor(mul_224, wait_tensor_253); mul_224 = wait_tensor_253 = None + convert_element_type_927 = torch.ops.prims.convert_element_type.default(mul_225, torch.bfloat16); mul_225 = None + convert_element_type_928 = torch.ops.prims.convert_element_type.default(primals_257, torch.bfloat16) + 
all_gather_into_tensor_254 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_928, 64, '0'); convert_element_type_928 = None + wait_tensor_254 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_254); all_gather_into_tensor_254 = None + permute_308 = torch.ops.aten.permute.default(wait_tensor_254, [1, 0]); wait_tensor_254 = None + view_955 = torch.ops.aten.view.default(convert_element_type_927, [16384, 4096]); convert_element_type_927 = None + mm_196 = torch.ops.aten.mm.default(view_955, permute_308); permute_308 = None + view_956 = torch.ops.aten.view.default(mm_196, [2, 8192, 4096]) + convert_element_type_931 = torch.ops.prims.convert_element_type.default(primals_258, torch.bfloat16) + all_gather_into_tensor_255 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_931, 64, '0'); convert_element_type_931 = None + wait_tensor_255 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_255); all_gather_into_tensor_255 = None + permute_309 = torch.ops.aten.permute.default(wait_tensor_255, [1, 0]); wait_tensor_255 = None + mm_197 = torch.ops.aten.mm.default(view_955, permute_309); permute_309 = None + view_959 = torch.ops.aten.view.default(mm_197, [2, 8192, 1024]); mm_197 = None + convert_element_type_934 = torch.ops.prims.convert_element_type.default(primals_259, torch.bfloat16) + all_gather_into_tensor_256 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_934, 64, '0'); convert_element_type_934 = None + wait_tensor_256 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_256); all_gather_into_tensor_256 = None + permute_310 = torch.ops.aten.permute.default(wait_tensor_256, [1, 0]); wait_tensor_256 = None + mm_198 = torch.ops.aten.mm.default(view_955, permute_310); view_955 = permute_310 = None + view_962 = torch.ops.aten.view.default(mm_198, [2, 8192, 1024]) + view_963 = torch.ops.aten.view.default(view_956, [2, 8192, -1, 
128]); view_956 = None + view_964 = torch.ops.aten.view.default(view_959, [2, 8192, -1, 128]); view_959 = None + view_965 = torch.ops.aten.view.default(view_962, [2, 8192, -1, 128]); view_962 = None + convert_element_type_937 = torch.ops.prims.convert_element_type.default(view_963, torch.float32); view_963 = None + view_966 = torch.ops.aten.view.default(convert_element_type_937, [2, 8192, 32, -1, 2]); convert_element_type_937 = None + view_as_complex_56 = torch.ops.aten.view_as_complex.default(view_966); view_966 = None + convert_element_type_938 = torch.ops.prims.convert_element_type.default(view_964, torch.float32); view_964 = None + view_967 = torch.ops.aten.view.default(convert_element_type_938, [2, 8192, 8, -1, 2]); convert_element_type_938 = None + view_as_complex_57 = torch.ops.aten.view_as_complex.default(view_967); view_967 = None + mul_226 = torch.ops.aten.mul.Tensor(view_as_complex_56, view_16); view_as_complex_56 = None + view_as_real_56 = torch.ops.aten.view_as_real.default(mul_226); mul_226 = None + view_969 = torch.ops.aten.view.default(view_as_real_56, [2, 8192, 32, 128]); view_as_real_56 = None + mul_227 = torch.ops.aten.mul.Tensor(view_as_complex_57, view_16); view_as_complex_57 = None + view_as_real_57 = torch.ops.aten.view_as_real.default(mul_227); mul_227 = None + view_970 = torch.ops.aten.view.default(view_as_real_57, [2, 8192, 8, 128]); view_as_real_57 = None + convert_element_type_939 = torch.ops.prims.convert_element_type.default(view_969, torch.bfloat16); view_969 = None + convert_element_type_940 = torch.ops.prims.convert_element_type.default(view_970, torch.bfloat16); view_970 = None + unsqueeze_56 = torch.ops.aten.unsqueeze.default(convert_element_type_940, 3); convert_element_type_940 = None + expand_56 = torch.ops.aten.expand.default(unsqueeze_56, [2, 8192, 8, 4, 128]); unsqueeze_56 = None + clone_56 = torch.ops.aten.clone.default(expand_56, memory_format = torch.contiguous_format); expand_56 = None + view_971 = 
torch.ops.aten.view.default(clone_56, [2, 8192, 32, 128]); clone_56 = None + unsqueeze_57 = torch.ops.aten.unsqueeze.default(view_965, 3); view_965 = None + expand_57 = torch.ops.aten.expand.default(unsqueeze_57, [2, 8192, 8, 4, 128]); unsqueeze_57 = None + clone_57 = torch.ops.aten.clone.default(expand_57, memory_format = torch.contiguous_format); expand_57 = None + view_972 = torch.ops.aten.view.default(clone_57, [2, 8192, 32, 128]); clone_57 = None + permute_311 = torch.ops.aten.permute.default(convert_element_type_939, [0, 2, 1, 3]); convert_element_type_939 = None + permute_312 = torch.ops.aten.permute.default(view_971, [0, 2, 1, 3]); view_971 = None + permute_313 = torch.ops.aten.permute.default(view_972, [0, 2, 1, 3]); view_972 = None + _scaled_dot_product_cudnn_attention_28 = torch.ops.aten._scaled_dot_product_cudnn_attention.default(permute_311, permute_312, permute_313, None, True, 0.0, True); permute_311 = permute_312 = permute_313 = None + getitem_252 = _scaled_dot_product_cudnn_attention_28[0] + getitem_253 = _scaled_dot_product_cudnn_attention_28[1] + getitem_258 = _scaled_dot_product_cudnn_attention_28[6] + getitem_259 = _scaled_dot_product_cudnn_attention_28[7]; _scaled_dot_product_cudnn_attention_28 = None + permute_314 = torch.ops.aten.permute.default(getitem_252, [0, 2, 1, 3]) + view_973 = torch.ops.aten.view.default(permute_314, [2, 8192, -1]); permute_314 = None + convert_element_type_941 = torch.ops.prims.convert_element_type.default(primals_260, torch.bfloat16) + all_gather_into_tensor_257 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_941, 64, '0'); convert_element_type_941 = None + wait_tensor_257 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_257); all_gather_into_tensor_257 = None + permute_315 = torch.ops.aten.permute.default(wait_tensor_257, [1, 0]); wait_tensor_257 = None + view_975 = torch.ops.aten.view.default(view_973, [16384, 4096]); view_973 = None + mm_199 = 
torch.ops.aten.mm.default(view_975, permute_315); view_975 = permute_315 = None + view_976 = torch.ops.aten.view.default(mm_199, [2, 8192, 4096]); mm_199 = None + add_113 = torch.ops.aten.add.Tensor(add_111, view_976); view_976 = None + convert_element_type_944 = torch.ops.prims.convert_element_type.default(primals_261, torch.bfloat16) + all_gather_into_tensor_258 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_944, 64, '0'); convert_element_type_944 = None + wait_tensor_258 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_258); all_gather_into_tensor_258 = None + convert_element_type_945 = torch.ops.prims.convert_element_type.default(add_113, torch.float32) + pow_58 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_945, 2) + mean_57 = torch.ops.aten.mean.dim(pow_58, [2], True); pow_58 = None + add_114 = torch.ops.aten.add.Scalar(mean_57, 1e-05); mean_57 = None + rsqrt_57 = torch.ops.aten.rsqrt.default(add_114); add_114 = None + mul_228 = torch.ops.aten.mul.Tensor(convert_element_type_945, rsqrt_57); convert_element_type_945 = rsqrt_57 = None + mul_229 = torch.ops.aten.mul.Tensor(mul_228, wait_tensor_258); mul_228 = wait_tensor_258 = None + convert_element_type_946 = torch.ops.prims.convert_element_type.default(mul_229, torch.bfloat16); mul_229 = None + convert_element_type_947 = torch.ops.prims.convert_element_type.default(primals_262, torch.bfloat16) + all_gather_into_tensor_259 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_947, 64, '0'); convert_element_type_947 = None + wait_tensor_259 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_259); all_gather_into_tensor_259 = None + permute_316 = torch.ops.aten.permute.default(wait_tensor_259, [1, 0]); wait_tensor_259 = None + view_979 = torch.ops.aten.view.default(convert_element_type_946, [16384, 4096]); convert_element_type_946 = None + mm_200 = torch.ops.aten.mm.default(view_979, permute_316); 
permute_316 = None + view_980 = torch.ops.aten.view.default(mm_200, [2, 8192, 14336]) + convert_element_type_950 = torch.ops.prims.convert_element_type.default(view_980, torch.float32); view_980 = None + sigmoid_28 = torch.ops.aten.sigmoid.default(convert_element_type_950) + mul_230 = torch.ops.aten.mul.Tensor(convert_element_type_950, sigmoid_28); convert_element_type_950 = sigmoid_28 = None + convert_element_type_951 = torch.ops.prims.convert_element_type.default(mul_230, torch.bfloat16); mul_230 = None + convert_element_type_952 = torch.ops.prims.convert_element_type.default(primals_263, torch.bfloat16) + all_gather_into_tensor_260 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_952, 64, '0'); convert_element_type_952 = None + wait_tensor_260 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_260); all_gather_into_tensor_260 = None + permute_317 = torch.ops.aten.permute.default(wait_tensor_260, [1, 0]); wait_tensor_260 = None + mm_201 = torch.ops.aten.mm.default(view_979, permute_317); view_979 = permute_317 = None + view_983 = torch.ops.aten.view.default(mm_201, [2, 8192, 14336]); mm_201 = None + mul_231 = torch.ops.aten.mul.Tensor(convert_element_type_951, view_983); convert_element_type_951 = view_983 = None + convert_element_type_955 = torch.ops.prims.convert_element_type.default(primals_264, torch.bfloat16) + all_gather_into_tensor_261 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_955, 64, '0'); convert_element_type_955 = None + wait_tensor_261 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_261); all_gather_into_tensor_261 = None + permute_318 = torch.ops.aten.permute.default(wait_tensor_261, [1, 0]); wait_tensor_261 = None + view_985 = torch.ops.aten.view.default(mul_231, [16384, 14336]); mul_231 = None + mm_202 = torch.ops.aten.mm.default(view_985, permute_318); view_985 = permute_318 = None + view_986 = torch.ops.aten.view.default(mm_202, [2, 
8192, 4096]); mm_202 = None + add_115 = torch.ops.aten.add.Tensor(add_113, view_986); add_113 = view_986 = None + convert_element_type_958 = torch.ops.prims.convert_element_type.default(primals_265, torch.bfloat16) + all_gather_into_tensor_262 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_958, 64, '0'); convert_element_type_958 = None + wait_tensor_262 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_262); all_gather_into_tensor_262 = None + convert_element_type_959 = torch.ops.prims.convert_element_type.default(add_115, torch.float32) + pow_59 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_959, 2) + mean_58 = torch.ops.aten.mean.dim(pow_59, [2], True); pow_59 = None + add_116 = torch.ops.aten.add.Scalar(mean_58, 1e-05); mean_58 = None + rsqrt_58 = torch.ops.aten.rsqrt.default(add_116); add_116 = None + mul_232 = torch.ops.aten.mul.Tensor(convert_element_type_959, rsqrt_58); convert_element_type_959 = rsqrt_58 = None + mul_233 = torch.ops.aten.mul.Tensor(mul_232, wait_tensor_262); mul_232 = wait_tensor_262 = None + convert_element_type_960 = torch.ops.prims.convert_element_type.default(mul_233, torch.bfloat16); mul_233 = None + convert_element_type_961 = torch.ops.prims.convert_element_type.default(primals_266, torch.bfloat16) + all_gather_into_tensor_263 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_961, 64, '0'); convert_element_type_961 = None + wait_tensor_263 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_263); all_gather_into_tensor_263 = None + permute_319 = torch.ops.aten.permute.default(wait_tensor_263, [1, 0]); wait_tensor_263 = None + view_989 = torch.ops.aten.view.default(convert_element_type_960, [16384, 4096]); convert_element_type_960 = None + mm_203 = torch.ops.aten.mm.default(view_989, permute_319); permute_319 = None + view_990 = torch.ops.aten.view.default(mm_203, [2, 8192, 4096]) + convert_element_type_964 = 
torch.ops.prims.convert_element_type.default(primals_267, torch.bfloat16) + all_gather_into_tensor_264 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_964, 64, '0'); convert_element_type_964 = None + wait_tensor_264 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_264); all_gather_into_tensor_264 = None + permute_320 = torch.ops.aten.permute.default(wait_tensor_264, [1, 0]); wait_tensor_264 = None + mm_204 = torch.ops.aten.mm.default(view_989, permute_320); permute_320 = None + view_993 = torch.ops.aten.view.default(mm_204, [2, 8192, 1024]); mm_204 = None + convert_element_type_967 = torch.ops.prims.convert_element_type.default(primals_268, torch.bfloat16) + all_gather_into_tensor_265 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_967, 64, '0'); convert_element_type_967 = None + wait_tensor_265 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_265); all_gather_into_tensor_265 = None + permute_321 = torch.ops.aten.permute.default(wait_tensor_265, [1, 0]); wait_tensor_265 = None + mm_205 = torch.ops.aten.mm.default(view_989, permute_321); view_989 = permute_321 = None + view_996 = torch.ops.aten.view.default(mm_205, [2, 8192, 1024]) + view_997 = torch.ops.aten.view.default(view_990, [2, 8192, -1, 128]); view_990 = None + view_998 = torch.ops.aten.view.default(view_993, [2, 8192, -1, 128]); view_993 = None + view_999 = torch.ops.aten.view.default(view_996, [2, 8192, -1, 128]); view_996 = None + convert_element_type_970 = torch.ops.prims.convert_element_type.default(view_997, torch.float32); view_997 = None + view_1000 = torch.ops.aten.view.default(convert_element_type_970, [2, 8192, 32, -1, 2]); convert_element_type_970 = None + view_as_complex_58 = torch.ops.aten.view_as_complex.default(view_1000); view_1000 = None + convert_element_type_971 = torch.ops.prims.convert_element_type.default(view_998, torch.float32); view_998 = None + view_1001 = 
torch.ops.aten.view.default(convert_element_type_971, [2, 8192, 8, -1, 2]); convert_element_type_971 = None + view_as_complex_59 = torch.ops.aten.view_as_complex.default(view_1001); view_1001 = None + mul_234 = torch.ops.aten.mul.Tensor(view_as_complex_58, view_16); view_as_complex_58 = None + view_as_real_58 = torch.ops.aten.view_as_real.default(mul_234); mul_234 = None + view_1003 = torch.ops.aten.view.default(view_as_real_58, [2, 8192, 32, 128]); view_as_real_58 = None + mul_235 = torch.ops.aten.mul.Tensor(view_as_complex_59, view_16); view_as_complex_59 = None + view_as_real_59 = torch.ops.aten.view_as_real.default(mul_235); mul_235 = None + view_1004 = torch.ops.aten.view.default(view_as_real_59, [2, 8192, 8, 128]); view_as_real_59 = None + convert_element_type_972 = torch.ops.prims.convert_element_type.default(view_1003, torch.bfloat16); view_1003 = None + convert_element_type_973 = torch.ops.prims.convert_element_type.default(view_1004, torch.bfloat16); view_1004 = None + unsqueeze_58 = torch.ops.aten.unsqueeze.default(convert_element_type_973, 3); convert_element_type_973 = None + expand_58 = torch.ops.aten.expand.default(unsqueeze_58, [2, 8192, 8, 4, 128]); unsqueeze_58 = None + clone_58 = torch.ops.aten.clone.default(expand_58, memory_format = torch.contiguous_format); expand_58 = None + view_1005 = torch.ops.aten.view.default(clone_58, [2, 8192, 32, 128]); clone_58 = None + unsqueeze_59 = torch.ops.aten.unsqueeze.default(view_999, 3); view_999 = None + expand_59 = torch.ops.aten.expand.default(unsqueeze_59, [2, 8192, 8, 4, 128]); unsqueeze_59 = None + clone_59 = torch.ops.aten.clone.default(expand_59, memory_format = torch.contiguous_format); expand_59 = None + view_1006 = torch.ops.aten.view.default(clone_59, [2, 8192, 32, 128]); clone_59 = None + permute_322 = torch.ops.aten.permute.default(convert_element_type_972, [0, 2, 1, 3]); convert_element_type_972 = None + permute_323 = torch.ops.aten.permute.default(view_1005, [0, 2, 1, 3]); view_1005 = None + 
permute_324 = torch.ops.aten.permute.default(view_1006, [0, 2, 1, 3]); view_1006 = None + _scaled_dot_product_cudnn_attention_29 = torch.ops.aten._scaled_dot_product_cudnn_attention.default(permute_322, permute_323, permute_324, None, True, 0.0, True); permute_322 = permute_323 = permute_324 = None + getitem_261 = _scaled_dot_product_cudnn_attention_29[0] + getitem_262 = _scaled_dot_product_cudnn_attention_29[1] + getitem_267 = _scaled_dot_product_cudnn_attention_29[6] + getitem_268 = _scaled_dot_product_cudnn_attention_29[7]; _scaled_dot_product_cudnn_attention_29 = None + permute_325 = torch.ops.aten.permute.default(getitem_261, [0, 2, 1, 3]) + view_1007 = torch.ops.aten.view.default(permute_325, [2, 8192, -1]); permute_325 = None + convert_element_type_974 = torch.ops.prims.convert_element_type.default(primals_269, torch.bfloat16) + all_gather_into_tensor_266 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_974, 64, '0'); convert_element_type_974 = None + wait_tensor_266 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_266); all_gather_into_tensor_266 = None + permute_326 = torch.ops.aten.permute.default(wait_tensor_266, [1, 0]); wait_tensor_266 = None + view_1009 = torch.ops.aten.view.default(view_1007, [16384, 4096]); view_1007 = None + mm_206 = torch.ops.aten.mm.default(view_1009, permute_326); view_1009 = permute_326 = None + view_1010 = torch.ops.aten.view.default(mm_206, [2, 8192, 4096]); mm_206 = None + add_117 = torch.ops.aten.add.Tensor(add_115, view_1010); view_1010 = None + convert_element_type_977 = torch.ops.prims.convert_element_type.default(primals_270, torch.bfloat16) + all_gather_into_tensor_267 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_977, 64, '0'); convert_element_type_977 = None + wait_tensor_267 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_267); all_gather_into_tensor_267 = None + convert_element_type_978 = 
torch.ops.prims.convert_element_type.default(add_117, torch.float32) + pow_60 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_978, 2) + mean_59 = torch.ops.aten.mean.dim(pow_60, [2], True); pow_60 = None + add_118 = torch.ops.aten.add.Scalar(mean_59, 1e-05); mean_59 = None + rsqrt_59 = torch.ops.aten.rsqrt.default(add_118); add_118 = None + mul_236 = torch.ops.aten.mul.Tensor(convert_element_type_978, rsqrt_59); convert_element_type_978 = rsqrt_59 = None + mul_237 = torch.ops.aten.mul.Tensor(mul_236, wait_tensor_267); mul_236 = wait_tensor_267 = None + convert_element_type_979 = torch.ops.prims.convert_element_type.default(mul_237, torch.bfloat16); mul_237 = None + convert_element_type_980 = torch.ops.prims.convert_element_type.default(primals_271, torch.bfloat16) + all_gather_into_tensor_268 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_980, 64, '0'); convert_element_type_980 = None + wait_tensor_268 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_268); all_gather_into_tensor_268 = None + permute_327 = torch.ops.aten.permute.default(wait_tensor_268, [1, 0]); wait_tensor_268 = None + view_1013 = torch.ops.aten.view.default(convert_element_type_979, [16384, 4096]); convert_element_type_979 = None + mm_207 = torch.ops.aten.mm.default(view_1013, permute_327); permute_327 = None + view_1014 = torch.ops.aten.view.default(mm_207, [2, 8192, 14336]) + convert_element_type_983 = torch.ops.prims.convert_element_type.default(view_1014, torch.float32); view_1014 = None + sigmoid_29 = torch.ops.aten.sigmoid.default(convert_element_type_983) + mul_238 = torch.ops.aten.mul.Tensor(convert_element_type_983, sigmoid_29); convert_element_type_983 = sigmoid_29 = None + convert_element_type_984 = torch.ops.prims.convert_element_type.default(mul_238, torch.bfloat16); mul_238 = None + convert_element_type_985 = torch.ops.prims.convert_element_type.default(primals_272, torch.bfloat16) + all_gather_into_tensor_269 = 
torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_985, 64, '0'); convert_element_type_985 = None + wait_tensor_269 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_269); all_gather_into_tensor_269 = None + permute_328 = torch.ops.aten.permute.default(wait_tensor_269, [1, 0]); wait_tensor_269 = None + mm_208 = torch.ops.aten.mm.default(view_1013, permute_328); view_1013 = permute_328 = None + view_1017 = torch.ops.aten.view.default(mm_208, [2, 8192, 14336]); mm_208 = None + mul_239 = torch.ops.aten.mul.Tensor(convert_element_type_984, view_1017); convert_element_type_984 = view_1017 = None + convert_element_type_988 = torch.ops.prims.convert_element_type.default(primals_273, torch.bfloat16) + all_gather_into_tensor_270 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_988, 64, '0'); convert_element_type_988 = None + wait_tensor_270 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_270); all_gather_into_tensor_270 = None + permute_329 = torch.ops.aten.permute.default(wait_tensor_270, [1, 0]); wait_tensor_270 = None + view_1019 = torch.ops.aten.view.default(mul_239, [16384, 14336]); mul_239 = None + mm_209 = torch.ops.aten.mm.default(view_1019, permute_329); view_1019 = permute_329 = None + view_1020 = torch.ops.aten.view.default(mm_209, [2, 8192, 4096]); mm_209 = None + add_119 = torch.ops.aten.add.Tensor(add_117, view_1020); add_117 = view_1020 = None + convert_element_type_991 = torch.ops.prims.convert_element_type.default(primals_274, torch.bfloat16) + all_gather_into_tensor_271 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_991, 64, '0'); convert_element_type_991 = None + wait_tensor_271 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_271); all_gather_into_tensor_271 = None + convert_element_type_992 = torch.ops.prims.convert_element_type.default(add_119, torch.float32) + pow_61 = 
torch.ops.aten.pow.Tensor_Scalar(convert_element_type_992, 2) + mean_60 = torch.ops.aten.mean.dim(pow_61, [2], True); pow_61 = None + add_120 = torch.ops.aten.add.Scalar(mean_60, 1e-05); mean_60 = None + rsqrt_60 = torch.ops.aten.rsqrt.default(add_120); add_120 = None + mul_240 = torch.ops.aten.mul.Tensor(convert_element_type_992, rsqrt_60); convert_element_type_992 = rsqrt_60 = None + mul_241 = torch.ops.aten.mul.Tensor(mul_240, wait_tensor_271); mul_240 = wait_tensor_271 = None + convert_element_type_993 = torch.ops.prims.convert_element_type.default(mul_241, torch.bfloat16); mul_241 = None + convert_element_type_994 = torch.ops.prims.convert_element_type.default(primals_275, torch.bfloat16) + all_gather_into_tensor_272 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_994, 64, '0'); convert_element_type_994 = None + wait_tensor_272 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_272); all_gather_into_tensor_272 = None + permute_330 = torch.ops.aten.permute.default(wait_tensor_272, [1, 0]); wait_tensor_272 = None + view_1023 = torch.ops.aten.view.default(convert_element_type_993, [16384, 4096]); convert_element_type_993 = None + mm_210 = torch.ops.aten.mm.default(view_1023, permute_330); permute_330 = None + view_1024 = torch.ops.aten.view.default(mm_210, [2, 8192, 4096]) + convert_element_type_997 = torch.ops.prims.convert_element_type.default(primals_276, torch.bfloat16) + all_gather_into_tensor_273 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_997, 64, '0'); convert_element_type_997 = None + wait_tensor_273 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_273); all_gather_into_tensor_273 = None + permute_331 = torch.ops.aten.permute.default(wait_tensor_273, [1, 0]); wait_tensor_273 = None + mm_211 = torch.ops.aten.mm.default(view_1023, permute_331); permute_331 = None + view_1027 = torch.ops.aten.view.default(mm_211, [2, 8192, 1024]); mm_211 = None 
+ convert_element_type_1000 = torch.ops.prims.convert_element_type.default(primals_277, torch.bfloat16) + all_gather_into_tensor_274 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1000, 64, '0'); convert_element_type_1000 = None + wait_tensor_274 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_274); all_gather_into_tensor_274 = None + permute_332 = torch.ops.aten.permute.default(wait_tensor_274, [1, 0]); wait_tensor_274 = None + mm_212 = torch.ops.aten.mm.default(view_1023, permute_332); view_1023 = permute_332 = None + view_1030 = torch.ops.aten.view.default(mm_212, [2, 8192, 1024]) + view_1031 = torch.ops.aten.view.default(view_1024, [2, 8192, -1, 128]); view_1024 = None + view_1032 = torch.ops.aten.view.default(view_1027, [2, 8192, -1, 128]); view_1027 = None + view_1033 = torch.ops.aten.view.default(view_1030, [2, 8192, -1, 128]); view_1030 = None + convert_element_type_1003 = torch.ops.prims.convert_element_type.default(view_1031, torch.float32); view_1031 = None + view_1034 = torch.ops.aten.view.default(convert_element_type_1003, [2, 8192, 32, -1, 2]); convert_element_type_1003 = None + view_as_complex_60 = torch.ops.aten.view_as_complex.default(view_1034); view_1034 = None + convert_element_type_1004 = torch.ops.prims.convert_element_type.default(view_1032, torch.float32); view_1032 = None + view_1035 = torch.ops.aten.view.default(convert_element_type_1004, [2, 8192, 8, -1, 2]); convert_element_type_1004 = None + view_as_complex_61 = torch.ops.aten.view_as_complex.default(view_1035); view_1035 = None + mul_242 = torch.ops.aten.mul.Tensor(view_as_complex_60, view_16); view_as_complex_60 = None + view_as_real_60 = torch.ops.aten.view_as_real.default(mul_242); mul_242 = None + view_1037 = torch.ops.aten.view.default(view_as_real_60, [2, 8192, 32, 128]); view_as_real_60 = None + mul_243 = torch.ops.aten.mul.Tensor(view_as_complex_61, view_16); view_as_complex_61 = None + view_as_real_61 = 
torch.ops.aten.view_as_real.default(mul_243); mul_243 = None + view_1038 = torch.ops.aten.view.default(view_as_real_61, [2, 8192, 8, 128]); view_as_real_61 = None + convert_element_type_1005 = torch.ops.prims.convert_element_type.default(view_1037, torch.bfloat16); view_1037 = None + convert_element_type_1006 = torch.ops.prims.convert_element_type.default(view_1038, torch.bfloat16); view_1038 = None + unsqueeze_60 = torch.ops.aten.unsqueeze.default(convert_element_type_1006, 3); convert_element_type_1006 = None + expand_60 = torch.ops.aten.expand.default(unsqueeze_60, [2, 8192, 8, 4, 128]); unsqueeze_60 = None + clone_60 = torch.ops.aten.clone.default(expand_60, memory_format = torch.contiguous_format); expand_60 = None + view_1039 = torch.ops.aten.view.default(clone_60, [2, 8192, 32, 128]); clone_60 = None + unsqueeze_61 = torch.ops.aten.unsqueeze.default(view_1033, 3); view_1033 = None + expand_61 = torch.ops.aten.expand.default(unsqueeze_61, [2, 8192, 8, 4, 128]); unsqueeze_61 = None + clone_61 = torch.ops.aten.clone.default(expand_61, memory_format = torch.contiguous_format); expand_61 = None + view_1040 = torch.ops.aten.view.default(clone_61, [2, 8192, 32, 128]); clone_61 = None + permute_333 = torch.ops.aten.permute.default(convert_element_type_1005, [0, 2, 1, 3]); convert_element_type_1005 = None + permute_334 = torch.ops.aten.permute.default(view_1039, [0, 2, 1, 3]); view_1039 = None + permute_335 = torch.ops.aten.permute.default(view_1040, [0, 2, 1, 3]); view_1040 = None + _scaled_dot_product_cudnn_attention_30 = torch.ops.aten._scaled_dot_product_cudnn_attention.default(permute_333, permute_334, permute_335, None, True, 0.0, True); permute_333 = permute_334 = permute_335 = None + getitem_270 = _scaled_dot_product_cudnn_attention_30[0] + getitem_271 = _scaled_dot_product_cudnn_attention_30[1] + getitem_276 = _scaled_dot_product_cudnn_attention_30[6] + getitem_277 = _scaled_dot_product_cudnn_attention_30[7]; _scaled_dot_product_cudnn_attention_30 = None + 
permute_336 = torch.ops.aten.permute.default(getitem_270, [0, 2, 1, 3]) + view_1041 = torch.ops.aten.view.default(permute_336, [2, 8192, -1]); permute_336 = None + convert_element_type_1007 = torch.ops.prims.convert_element_type.default(primals_278, torch.bfloat16) + all_gather_into_tensor_275 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1007, 64, '0'); convert_element_type_1007 = None + wait_tensor_275 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_275); all_gather_into_tensor_275 = None + permute_337 = torch.ops.aten.permute.default(wait_tensor_275, [1, 0]); wait_tensor_275 = None + view_1043 = torch.ops.aten.view.default(view_1041, [16384, 4096]); view_1041 = None + mm_213 = torch.ops.aten.mm.default(view_1043, permute_337); view_1043 = permute_337 = None + view_1044 = torch.ops.aten.view.default(mm_213, [2, 8192, 4096]); mm_213 = None + add_121 = torch.ops.aten.add.Tensor(add_119, view_1044); view_1044 = None + convert_element_type_1010 = torch.ops.prims.convert_element_type.default(primals_279, torch.bfloat16) + all_gather_into_tensor_276 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1010, 64, '0'); convert_element_type_1010 = None + wait_tensor_276 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_276); all_gather_into_tensor_276 = None + convert_element_type_1011 = torch.ops.prims.convert_element_type.default(add_121, torch.float32) + pow_62 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_1011, 2) + mean_61 = torch.ops.aten.mean.dim(pow_62, [2], True); pow_62 = None + add_122 = torch.ops.aten.add.Scalar(mean_61, 1e-05); mean_61 = None + rsqrt_61 = torch.ops.aten.rsqrt.default(add_122); add_122 = None + mul_244 = torch.ops.aten.mul.Tensor(convert_element_type_1011, rsqrt_61); convert_element_type_1011 = rsqrt_61 = None + mul_245 = torch.ops.aten.mul.Tensor(mul_244, wait_tensor_276); mul_244 = wait_tensor_276 = None + 
convert_element_type_1012 = torch.ops.prims.convert_element_type.default(mul_245, torch.bfloat16); mul_245 = None + convert_element_type_1013 = torch.ops.prims.convert_element_type.default(primals_280, torch.bfloat16) + all_gather_into_tensor_277 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1013, 64, '0'); convert_element_type_1013 = None + wait_tensor_277 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_277); all_gather_into_tensor_277 = None + permute_338 = torch.ops.aten.permute.default(wait_tensor_277, [1, 0]); wait_tensor_277 = None + view_1047 = torch.ops.aten.view.default(convert_element_type_1012, [16384, 4096]); convert_element_type_1012 = None + mm_214 = torch.ops.aten.mm.default(view_1047, permute_338); permute_338 = None + view_1048 = torch.ops.aten.view.default(mm_214, [2, 8192, 14336]) + convert_element_type_1016 = torch.ops.prims.convert_element_type.default(view_1048, torch.float32); view_1048 = None + sigmoid_30 = torch.ops.aten.sigmoid.default(convert_element_type_1016) + mul_246 = torch.ops.aten.mul.Tensor(convert_element_type_1016, sigmoid_30); convert_element_type_1016 = sigmoid_30 = None + convert_element_type_1017 = torch.ops.prims.convert_element_type.default(mul_246, torch.bfloat16); mul_246 = None + convert_element_type_1018 = torch.ops.prims.convert_element_type.default(primals_281, torch.bfloat16) + all_gather_into_tensor_278 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1018, 64, '0'); convert_element_type_1018 = None + wait_tensor_278 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_278); all_gather_into_tensor_278 = None + permute_339 = torch.ops.aten.permute.default(wait_tensor_278, [1, 0]); wait_tensor_278 = None + mm_215 = torch.ops.aten.mm.default(view_1047, permute_339); view_1047 = permute_339 = None + view_1051 = torch.ops.aten.view.default(mm_215, [2, 8192, 14336]); mm_215 = None + mul_247 = 
torch.ops.aten.mul.Tensor(convert_element_type_1017, view_1051); convert_element_type_1017 = view_1051 = None + convert_element_type_1021 = torch.ops.prims.convert_element_type.default(primals_282, torch.bfloat16) + all_gather_into_tensor_279 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1021, 64, '0'); convert_element_type_1021 = None + wait_tensor_279 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_279); all_gather_into_tensor_279 = None + permute_340 = torch.ops.aten.permute.default(wait_tensor_279, [1, 0]); wait_tensor_279 = None + view_1053 = torch.ops.aten.view.default(mul_247, [16384, 14336]); mul_247 = None + mm_216 = torch.ops.aten.mm.default(view_1053, permute_340); view_1053 = permute_340 = None + view_1054 = torch.ops.aten.view.default(mm_216, [2, 8192, 4096]); mm_216 = None + add_123 = torch.ops.aten.add.Tensor(add_121, view_1054); add_121 = view_1054 = None + convert_element_type_1024 = torch.ops.prims.convert_element_type.default(primals_283, torch.bfloat16) + all_gather_into_tensor_280 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1024, 64, '0'); convert_element_type_1024 = None + wait_tensor_280 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_280); all_gather_into_tensor_280 = None + convert_element_type_1025 = torch.ops.prims.convert_element_type.default(add_123, torch.float32) + pow_63 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_1025, 2) + mean_62 = torch.ops.aten.mean.dim(pow_63, [2], True); pow_63 = None + add_124 = torch.ops.aten.add.Scalar(mean_62, 1e-05); mean_62 = None + rsqrt_62 = torch.ops.aten.rsqrt.default(add_124); add_124 = None + mul_248 = torch.ops.aten.mul.Tensor(convert_element_type_1025, rsqrt_62); convert_element_type_1025 = rsqrt_62 = None + mul_249 = torch.ops.aten.mul.Tensor(mul_248, wait_tensor_280); mul_248 = wait_tensor_280 = None + convert_element_type_1026 = 
torch.ops.prims.convert_element_type.default(mul_249, torch.bfloat16); mul_249 = None + convert_element_type_1027 = torch.ops.prims.convert_element_type.default(primals_284, torch.bfloat16) + all_gather_into_tensor_281 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1027, 64, '0'); convert_element_type_1027 = None + wait_tensor_281 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_281); all_gather_into_tensor_281 = None + permute_341 = torch.ops.aten.permute.default(wait_tensor_281, [1, 0]); wait_tensor_281 = None + view_1057 = torch.ops.aten.view.default(convert_element_type_1026, [16384, 4096]); convert_element_type_1026 = None + mm_217 = torch.ops.aten.mm.default(view_1057, permute_341); permute_341 = None + view_1058 = torch.ops.aten.view.default(mm_217, [2, 8192, 4096]) + convert_element_type_1030 = torch.ops.prims.convert_element_type.default(primals_285, torch.bfloat16) + all_gather_into_tensor_282 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1030, 64, '0'); convert_element_type_1030 = None + wait_tensor_282 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_282); all_gather_into_tensor_282 = None + permute_342 = torch.ops.aten.permute.default(wait_tensor_282, [1, 0]); wait_tensor_282 = None + mm_218 = torch.ops.aten.mm.default(view_1057, permute_342); permute_342 = None + view_1061 = torch.ops.aten.view.default(mm_218, [2, 8192, 1024]); mm_218 = None + convert_element_type_1033 = torch.ops.prims.convert_element_type.default(primals_286, torch.bfloat16) + all_gather_into_tensor_283 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1033, 64, '0'); convert_element_type_1033 = None + wait_tensor_283 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_283); all_gather_into_tensor_283 = None + permute_343 = torch.ops.aten.permute.default(wait_tensor_283, [1, 0]); wait_tensor_283 = None + mm_219 = 
torch.ops.aten.mm.default(view_1057, permute_343); view_1057 = permute_343 = None + view_1064 = torch.ops.aten.view.default(mm_219, [2, 8192, 1024]) + view_1065 = torch.ops.aten.view.default(view_1058, [2, 8192, -1, 128]); view_1058 = None + view_1066 = torch.ops.aten.view.default(view_1061, [2, 8192, -1, 128]); view_1061 = None + view_1067 = torch.ops.aten.view.default(view_1064, [2, 8192, -1, 128]); view_1064 = None + convert_element_type_1036 = torch.ops.prims.convert_element_type.default(view_1065, torch.float32); view_1065 = None + view_1068 = torch.ops.aten.view.default(convert_element_type_1036, [2, 8192, 32, -1, 2]); convert_element_type_1036 = None + view_as_complex_62 = torch.ops.aten.view_as_complex.default(view_1068); view_1068 = None + convert_element_type_1037 = torch.ops.prims.convert_element_type.default(view_1066, torch.float32); view_1066 = None + view_1069 = torch.ops.aten.view.default(convert_element_type_1037, [2, 8192, 8, -1, 2]); convert_element_type_1037 = None + view_as_complex_63 = torch.ops.aten.view_as_complex.default(view_1069); view_1069 = None + mul_250 = torch.ops.aten.mul.Tensor(view_as_complex_62, view_16); view_as_complex_62 = None + view_as_real_62 = torch.ops.aten.view_as_real.default(mul_250); mul_250 = None + view_1071 = torch.ops.aten.view.default(view_as_real_62, [2, 8192, 32, 128]); view_as_real_62 = None + mul_251 = torch.ops.aten.mul.Tensor(view_as_complex_63, view_16); view_as_complex_63 = view_16 = None + view_as_real_63 = torch.ops.aten.view_as_real.default(mul_251); mul_251 = None + view_1072 = torch.ops.aten.view.default(view_as_real_63, [2, 8192, 8, 128]); view_as_real_63 = None + convert_element_type_1038 = torch.ops.prims.convert_element_type.default(view_1071, torch.bfloat16); view_1071 = None + convert_element_type_1039 = torch.ops.prims.convert_element_type.default(view_1072, torch.bfloat16); view_1072 = None + unsqueeze_62 = torch.ops.aten.unsqueeze.default(convert_element_type_1039, 3); 
convert_element_type_1039 = None + expand_62 = torch.ops.aten.expand.default(unsqueeze_62, [2, 8192, 8, 4, 128]); unsqueeze_62 = None + clone_62 = torch.ops.aten.clone.default(expand_62, memory_format = torch.contiguous_format); expand_62 = None + view_1073 = torch.ops.aten.view.default(clone_62, [2, 8192, 32, 128]); clone_62 = None + unsqueeze_63 = torch.ops.aten.unsqueeze.default(view_1067, 3); view_1067 = None + expand_63 = torch.ops.aten.expand.default(unsqueeze_63, [2, 8192, 8, 4, 128]); unsqueeze_63 = None + clone_63 = torch.ops.aten.clone.default(expand_63, memory_format = torch.contiguous_format); expand_63 = None + view_1074 = torch.ops.aten.view.default(clone_63, [2, 8192, 32, 128]); clone_63 = None + permute_344 = torch.ops.aten.permute.default(convert_element_type_1038, [0, 2, 1, 3]); convert_element_type_1038 = None + permute_345 = torch.ops.aten.permute.default(view_1073, [0, 2, 1, 3]); view_1073 = None + permute_346 = torch.ops.aten.permute.default(view_1074, [0, 2, 1, 3]); view_1074 = None + _scaled_dot_product_cudnn_attention_31 = torch.ops.aten._scaled_dot_product_cudnn_attention.default(permute_344, permute_345, permute_346, None, True, 0.0, True); permute_344 = permute_345 = permute_346 = None + getitem_279 = _scaled_dot_product_cudnn_attention_31[0] + getitem_280 = _scaled_dot_product_cudnn_attention_31[1] + getitem_285 = _scaled_dot_product_cudnn_attention_31[6] + getitem_286 = _scaled_dot_product_cudnn_attention_31[7]; _scaled_dot_product_cudnn_attention_31 = None + permute_347 = torch.ops.aten.permute.default(getitem_279, [0, 2, 1, 3]) + view_1075 = torch.ops.aten.view.default(permute_347, [2, 8192, -1]); permute_347 = None + convert_element_type_1040 = torch.ops.prims.convert_element_type.default(primals_287, torch.bfloat16) + all_gather_into_tensor_284 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1040, 64, '0'); convert_element_type_1040 = None + wait_tensor_284 = 
torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_284); all_gather_into_tensor_284 = None + permute_348 = torch.ops.aten.permute.default(wait_tensor_284, [1, 0]); wait_tensor_284 = None + view_1077 = torch.ops.aten.view.default(view_1075, [16384, 4096]); view_1075 = None + mm_220 = torch.ops.aten.mm.default(view_1077, permute_348); view_1077 = permute_348 = None + view_1078 = torch.ops.aten.view.default(mm_220, [2, 8192, 4096]); mm_220 = None + add_125 = torch.ops.aten.add.Tensor(add_123, view_1078); view_1078 = None + convert_element_type_1043 = torch.ops.prims.convert_element_type.default(primals_288, torch.bfloat16) + all_gather_into_tensor_285 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1043, 64, '0'); convert_element_type_1043 = None + wait_tensor_285 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_285); all_gather_into_tensor_285 = None + convert_element_type_1044 = torch.ops.prims.convert_element_type.default(add_125, torch.float32) + pow_64 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_1044, 2) + mean_63 = torch.ops.aten.mean.dim(pow_64, [2], True); pow_64 = None + add_126 = torch.ops.aten.add.Scalar(mean_63, 1e-05); mean_63 = None + rsqrt_63 = torch.ops.aten.rsqrt.default(add_126); add_126 = None + mul_252 = torch.ops.aten.mul.Tensor(convert_element_type_1044, rsqrt_63); convert_element_type_1044 = rsqrt_63 = None + mul_253 = torch.ops.aten.mul.Tensor(mul_252, wait_tensor_285); mul_252 = wait_tensor_285 = None + convert_element_type_1045 = torch.ops.prims.convert_element_type.default(mul_253, torch.bfloat16); mul_253 = None + convert_element_type_1046 = torch.ops.prims.convert_element_type.default(primals_289, torch.bfloat16) + all_gather_into_tensor_286 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1046, 64, '0'); convert_element_type_1046 = None + wait_tensor_286 = 
torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_286); all_gather_into_tensor_286 = None + permute_349 = torch.ops.aten.permute.default(wait_tensor_286, [1, 0]); wait_tensor_286 = None + view_1081 = torch.ops.aten.view.default(convert_element_type_1045, [16384, 4096]); convert_element_type_1045 = None + mm_221 = torch.ops.aten.mm.default(view_1081, permute_349); permute_349 = None + view_1082 = torch.ops.aten.view.default(mm_221, [2, 8192, 14336]) + convert_element_type_1049 = torch.ops.prims.convert_element_type.default(view_1082, torch.float32); view_1082 = None + sigmoid_31 = torch.ops.aten.sigmoid.default(convert_element_type_1049) + mul_254 = torch.ops.aten.mul.Tensor(convert_element_type_1049, sigmoid_31); convert_element_type_1049 = sigmoid_31 = None + convert_element_type_1050 = torch.ops.prims.convert_element_type.default(mul_254, torch.bfloat16); mul_254 = None + convert_element_type_1051 = torch.ops.prims.convert_element_type.default(primals_290, torch.bfloat16) + all_gather_into_tensor_287 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1051, 64, '0'); convert_element_type_1051 = None + wait_tensor_287 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_287); all_gather_into_tensor_287 = None + permute_350 = torch.ops.aten.permute.default(wait_tensor_287, [1, 0]); wait_tensor_287 = None + mm_222 = torch.ops.aten.mm.default(view_1081, permute_350); view_1081 = permute_350 = None + view_1085 = torch.ops.aten.view.default(mm_222, [2, 8192, 14336]); mm_222 = None + mul_255 = torch.ops.aten.mul.Tensor(convert_element_type_1050, view_1085); convert_element_type_1050 = view_1085 = None + convert_element_type_1054 = torch.ops.prims.convert_element_type.default(primals_291, torch.bfloat16) + all_gather_into_tensor_288 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1054, 64, '0'); convert_element_type_1054 = None + wait_tensor_288 = 
torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_288); all_gather_into_tensor_288 = None + permute_351 = torch.ops.aten.permute.default(wait_tensor_288, [1, 0]); wait_tensor_288 = None + view_1087 = torch.ops.aten.view.default(mul_255, [16384, 14336]); mul_255 = None + mm_223 = torch.ops.aten.mm.default(view_1087, permute_351); view_1087 = permute_351 = None + view_1088 = torch.ops.aten.view.default(mm_223, [2, 8192, 4096]) + add_127 = torch.ops.aten.add.Tensor(add_125, view_1088); add_125 = view_1088 = None + convert_element_type_1057 = torch.ops.prims.convert_element_type.default(primals_292, torch.bfloat16) + all_gather_into_tensor_289 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1057, 64, '0'); convert_element_type_1057 = None + wait_tensor_289 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_289); all_gather_into_tensor_289 = None + convert_element_type_1058 = torch.ops.prims.convert_element_type.default(add_127, torch.float32); add_127 = None + pow_65 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_1058, 2) + mean_64 = torch.ops.aten.mean.dim(pow_65, [2], True); pow_65 = None + add_128 = torch.ops.aten.add.Scalar(mean_64, 1e-05); mean_64 = None + rsqrt_64 = torch.ops.aten.rsqrt.default(add_128); add_128 = None + mul_256 = torch.ops.aten.mul.Tensor(convert_element_type_1058, rsqrt_64); convert_element_type_1058 = None + mul_257 = torch.ops.aten.mul.Tensor(mul_256, wait_tensor_289); mul_256 = wait_tensor_289 = None + convert_element_type_1059 = torch.ops.prims.convert_element_type.default(mul_257, torch.bfloat16); mul_257 = None + convert_element_type_1060 = torch.ops.prims.convert_element_type.default(primals_293, torch.bfloat16) + all_gather_into_tensor_290 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1060, 64, '0'); convert_element_type_1060 = None + wait_tensor_290 = 
torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_290); all_gather_into_tensor_290 = None + permute_352 = torch.ops.aten.permute.default(wait_tensor_290, [1, 0]); wait_tensor_290 = None + view_1091 = torch.ops.aten.view.default(convert_element_type_1059, [16384, 4096]); convert_element_type_1059 = None + mm_224 = torch.ops.aten.mm.default(view_1091, permute_352); permute_352 = None + view_1092 = torch.ops.aten.view.default(mm_224, [2, 8192, 128256]); mm_224 = None + return (view_1092, primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8, primals_9, primals_10, primals_11, primals_12, primals_13, primals_14, primals_15, primals_16, primals_17, primals_18, primals_19, primals_20, primals_21, primals_22, primals_23, primals_24, primals_25, primals_26, primals_27, primals_28, primals_29, primals_30, primals_31, primals_32, primals_33, primals_34, primals_35, primals_36, primals_37, primals_38, primals_39, primals_40, primals_41, primals_42, primals_43, primals_44, primals_45, primals_46, primals_47, primals_48, primals_49, primals_50, primals_51, primals_52, primals_53, primals_54, primals_55, primals_56, primals_57, primals_58, primals_59, primals_60, primals_61, primals_62, primals_63, primals_64, primals_65, primals_66, primals_67, primals_68, primals_69, primals_70, primals_71, primals_72, primals_73, primals_74, primals_75, primals_76, primals_77, primals_78, primals_79, primals_80, primals_81, primals_82, primals_83, primals_84, primals_85, primals_86, primals_87, primals_88, primals_89, primals_90, primals_91, primals_92, primals_93, primals_94, primals_95, primals_96, primals_97, primals_98, primals_99, primals_100, primals_101, primals_102, primals_103, primals_104, primals_105, primals_106, primals_107, primals_108, primals_109, primals_110, primals_111, primals_112, primals_113, primals_114, primals_115, primals_116, primals_117, primals_118, primals_119, primals_120, primals_121, primals_122, 
primals_123, primals_124, primals_125, primals_126, primals_127, primals_128, primals_129, primals_130, primals_131, primals_132, primals_133, primals_134, primals_135, primals_136, primals_137, primals_138, primals_139, primals_140, primals_141, primals_142, primals_143, primals_144, primals_145, primals_146, primals_147, primals_148, primals_149, primals_150, primals_151, primals_152, primals_153, primals_154, primals_155, primals_156, primals_157, primals_158, primals_159, primals_160, primals_161, primals_162, primals_163, primals_164, primals_165, primals_166, primals_167, primals_168, primals_169, primals_170, primals_171, primals_172, primals_173, primals_174, primals_175, primals_176, primals_177, primals_178, primals_179, primals_180, primals_181, primals_182, primals_183, primals_184, primals_185, primals_186, primals_187, primals_188, primals_189, primals_190, primals_191, primals_192, primals_193, primals_194, primals_195, primals_196, primals_197, primals_198, primals_199, primals_200, primals_201, primals_202, primals_203, primals_204, primals_205, primals_206, primals_207, primals_208, primals_209, primals_210, primals_211, primals_212, primals_213, primals_214, primals_215, primals_216, primals_217, primals_218, primals_219, primals_220, primals_221, primals_222, primals_223, primals_224, primals_225, primals_226, primals_227, primals_228, primals_229, primals_230, primals_231, primals_232, primals_233, primals_234, primals_235, primals_236, primals_237, primals_238, primals_239, primals_240, primals_241, primals_242, primals_243, primals_244, primals_245, primals_246, primals_247, primals_248, primals_249, primals_250, primals_251, primals_252, primals_253, primals_254, primals_255, primals_256, primals_257, primals_258, primals_259, primals_260, primals_261, primals_262, primals_263, primals_264, primals_265, primals_266, primals_267, primals_268, primals_269, primals_270, primals_271, primals_272, primals_273, primals_274, primals_275, 
primals_276, primals_277, primals_278, primals_279, primals_280, primals_281, primals_282, primals_283, primals_284, primals_285, primals_286, primals_287, primals_288, primals_289, primals_290, primals_291, primals_292, primals_293, embedding, mm, mm_2, getitem, getitem_1, getitem_6, getitem_7, mm_4, add_3, mm_7, mm_9, getitem_9, getitem_10, getitem_15, getitem_16, mm_11, add_7, mm_14, mm_16, getitem_18, getitem_19, getitem_24, getitem_25, mm_18, add_11, mm_21, mm_23, getitem_27, getitem_28, getitem_33, getitem_34, mm_25, add_15, mm_28, mm_30, getitem_36, getitem_37, getitem_42, getitem_43, mm_32, add_19, mm_35, mm_37, getitem_45, getitem_46, getitem_51, getitem_52, mm_39, add_23, mm_42, mm_44, getitem_54, getitem_55, getitem_60, getitem_61, mm_46, add_27, mm_49, mm_51, getitem_63, getitem_64, getitem_69, getitem_70, mm_53, add_31, mm_56, mm_58, getitem_72, getitem_73, getitem_78, getitem_79, mm_60, add_35, mm_63, mm_65, getitem_81, getitem_82, getitem_87, getitem_88, mm_67, add_39, mm_70, mm_72, getitem_90, getitem_91, getitem_96, getitem_97, mm_74, add_43, mm_77, mm_79, getitem_99, getitem_100, getitem_105, getitem_106, mm_81, add_47, mm_84, mm_86, getitem_108, getitem_109, getitem_114, getitem_115, mm_88, add_51, mm_91, mm_93, getitem_117, getitem_118, getitem_123, getitem_124, mm_95, add_55, mm_98, mm_100, getitem_126, getitem_127, getitem_132, getitem_133, mm_102, add_59, mm_105, mm_107, getitem_135, getitem_136, getitem_141, getitem_142, mm_109, add_63, mm_112, mm_114, getitem_144, getitem_145, getitem_150, getitem_151, mm_116, add_67, mm_119, mm_121, getitem_153, getitem_154, getitem_159, getitem_160, mm_123, add_71, mm_126, mm_128, getitem_162, getitem_163, getitem_168, getitem_169, mm_130, add_75, mm_133, mm_135, getitem_171, getitem_172, getitem_177, getitem_178, mm_137, add_79, mm_140, mm_142, getitem_180, getitem_181, getitem_186, getitem_187, mm_144, add_83, mm_147, mm_149, getitem_189, getitem_190, getitem_195, getitem_196, mm_151, add_87, mm_154, 
mm_156, getitem_198, getitem_199, getitem_204, getitem_205, mm_158, add_91, mm_161, mm_163, getitem_207, getitem_208, getitem_213, getitem_214, mm_165, add_95, mm_168, mm_170, getitem_216, getitem_217, getitem_222, getitem_223, mm_172, add_99, mm_175, mm_177, getitem_225, getitem_226, getitem_231, getitem_232, mm_179, add_103, mm_182, mm_184, getitem_234, getitem_235, getitem_240, getitem_241, mm_186, add_107, mm_189, mm_191, getitem_243, getitem_244, getitem_249, getitem_250, mm_193, add_111, mm_196, mm_198, getitem_252, getitem_253, getitem_258, getitem_259, mm_200, add_115, mm_203, mm_205, getitem_261, getitem_262, getitem_267, getitem_268, mm_207, add_119, mm_210, mm_212, getitem_270, getitem_271, getitem_276, getitem_277, mm_214, add_123, mm_217, mm_219, getitem_279, getitem_280, getitem_285, getitem_286, mm_221, mm_223, rsqrt_64, view_1091) + +def load_args(reader): + buf0 = reader.storage(None, 32833536, device=device(type='cuda', index=0)) + reader.tensor(buf0, (2004, 4096), is_leaf=True) # primals_1 + buf1 = reader.storage(None, 131072, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf1, (2, 8192), dtype=torch.int64, is_leaf=True) # primals_2 + buf2 = reader.storage(None, 4194304, device=device(type='cuda', index=0), dtype_hint=torch.complex64) + reader.tensor(buf2, (8192, 64), dtype=torch.complex64, is_leaf=True) # primals_3 + buf3 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf3, (64,), is_leaf=True) # primals_4 + buf4 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf4, (64, 4096), is_leaf=True) # primals_5 + buf5 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf5, (16, 4096), is_leaf=True) # primals_6 + buf6 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf6, (16, 4096), is_leaf=True) # primals_7 + buf7 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + 
reader.tensor(buf7, (64, 4096), is_leaf=True) # primals_8 + buf8 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf8, (64,), is_leaf=True) # primals_9 + buf9 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf9, (224, 4096), is_leaf=True) # primals_10 + buf10 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf10, (224, 4096), is_leaf=True) # primals_11 + buf11 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf11, (64, 14336), is_leaf=True) # primals_12 + buf12 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf12, (64,), is_leaf=True) # primals_13 + buf13 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf13, (64, 4096), is_leaf=True) # primals_14 + buf14 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf14, (16, 4096), is_leaf=True) # primals_15 + buf15 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf15, (16, 4096), is_leaf=True) # primals_16 + buf16 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf16, (64, 4096), is_leaf=True) # primals_17 + buf17 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf17, (64,), is_leaf=True) # primals_18 + buf18 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf18, (224, 4096), is_leaf=True) # primals_19 + buf19 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf19, (224, 4096), is_leaf=True) # primals_20 + buf20 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf20, (64, 14336), is_leaf=True) # primals_21 + buf21 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf21, (64,), is_leaf=True) # primals_22 + buf22 = reader.storage(None, 
1048576, device=device(type='cuda', index=0)) + reader.tensor(buf22, (64, 4096), is_leaf=True) # primals_23 + buf23 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf23, (16, 4096), is_leaf=True) # primals_24 + buf24 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf24, (16, 4096), is_leaf=True) # primals_25 + buf25 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf25, (64, 4096), is_leaf=True) # primals_26 + buf26 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf26, (64,), is_leaf=True) # primals_27 + buf27 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf27, (224, 4096), is_leaf=True) # primals_28 + buf28 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf28, (224, 4096), is_leaf=True) # primals_29 + buf29 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf29, (64, 14336), is_leaf=True) # primals_30 + buf30 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf30, (64,), is_leaf=True) # primals_31 + buf31 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf31, (64, 4096), is_leaf=True) # primals_32 + buf32 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf32, (16, 4096), is_leaf=True) # primals_33 + buf33 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf33, (16, 4096), is_leaf=True) # primals_34 + buf34 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf34, (64, 4096), is_leaf=True) # primals_35 + buf35 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf35, (64,), is_leaf=True) # primals_36 + buf36 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf36, (224, 4096), 
is_leaf=True) # primals_37 + buf37 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf37, (224, 4096), is_leaf=True) # primals_38 + buf38 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf38, (64, 14336), is_leaf=True) # primals_39 + buf39 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf39, (64,), is_leaf=True) # primals_40 + buf40 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf40, (64, 4096), is_leaf=True) # primals_41 + buf41 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf41, (16, 4096), is_leaf=True) # primals_42 + buf42 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf42, (16, 4096), is_leaf=True) # primals_43 + buf43 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf43, (64, 4096), is_leaf=True) # primals_44 + buf44 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf44, (64,), is_leaf=True) # primals_45 + buf45 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf45, (224, 4096), is_leaf=True) # primals_46 + buf46 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf46, (224, 4096), is_leaf=True) # primals_47 + buf47 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf47, (64, 14336), is_leaf=True) # primals_48 + buf48 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf48, (64,), is_leaf=True) # primals_49 + buf49 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf49, (64, 4096), is_leaf=True) # primals_50 + buf50 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf50, (16, 4096), is_leaf=True) # primals_51 + buf51 = reader.storage(None, 262144, 
device=device(type='cuda', index=0)) + reader.tensor(buf51, (16, 4096), is_leaf=True) # primals_52 + buf52 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf52, (64, 4096), is_leaf=True) # primals_53 + buf53 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf53, (64,), is_leaf=True) # primals_54 + buf54 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf54, (224, 4096), is_leaf=True) # primals_55 + buf55 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf55, (224, 4096), is_leaf=True) # primals_56 + buf56 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf56, (64, 14336), is_leaf=True) # primals_57 + buf57 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf57, (64,), is_leaf=True) # primals_58 + buf58 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf58, (64, 4096), is_leaf=True) # primals_59 + buf59 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf59, (16, 4096), is_leaf=True) # primals_60 + buf60 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf60, (16, 4096), is_leaf=True) # primals_61 + buf61 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf61, (64, 4096), is_leaf=True) # primals_62 + buf62 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf62, (64,), is_leaf=True) # primals_63 + buf63 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf63, (224, 4096), is_leaf=True) # primals_64 + buf64 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf64, (224, 4096), is_leaf=True) # primals_65 + buf65 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf65, (64, 14336), 
is_leaf=True) # primals_66 + buf66 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf66, (64,), is_leaf=True) # primals_67 + buf67 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf67, (64, 4096), is_leaf=True) # primals_68 + buf68 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf68, (16, 4096), is_leaf=True) # primals_69 + buf69 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf69, (16, 4096), is_leaf=True) # primals_70 + buf70 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf70, (64, 4096), is_leaf=True) # primals_71 + buf71 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf71, (64,), is_leaf=True) # primals_72 + buf72 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf72, (224, 4096), is_leaf=True) # primals_73 + buf73 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf73, (224, 4096), is_leaf=True) # primals_74 + buf74 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf74, (64, 14336), is_leaf=True) # primals_75 + buf75 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf75, (64,), is_leaf=True) # primals_76 + buf76 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf76, (64, 4096), is_leaf=True) # primals_77 + buf77 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf77, (16, 4096), is_leaf=True) # primals_78 + buf78 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf78, (16, 4096), is_leaf=True) # primals_79 + buf79 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf79, (64, 4096), is_leaf=True) # primals_80 + buf80 = reader.storage(None, 256, 
device=device(type='cuda', index=0)) + reader.tensor(buf80, (64,), is_leaf=True) # primals_81 + buf81 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf81, (224, 4096), is_leaf=True) # primals_82 + buf82 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf82, (224, 4096), is_leaf=True) # primals_83 + buf83 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf83, (64, 14336), is_leaf=True) # primals_84 + buf84 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf84, (64,), is_leaf=True) # primals_85 + buf85 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf85, (64, 4096), is_leaf=True) # primals_86 + buf86 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf86, (16, 4096), is_leaf=True) # primals_87 + buf87 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf87, (16, 4096), is_leaf=True) # primals_88 + buf88 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf88, (64, 4096), is_leaf=True) # primals_89 + buf89 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf89, (64,), is_leaf=True) # primals_90 + buf90 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf90, (224, 4096), is_leaf=True) # primals_91 + buf91 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf91, (224, 4096), is_leaf=True) # primals_92 + buf92 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf92, (64, 14336), is_leaf=True) # primals_93 + buf93 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf93, (64,), is_leaf=True) # primals_94 + buf94 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf94, (64, 4096), 
is_leaf=True) # primals_95 + buf95 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf95, (16, 4096), is_leaf=True) # primals_96 + buf96 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf96, (16, 4096), is_leaf=True) # primals_97 + buf97 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf97, (64, 4096), is_leaf=True) # primals_98 + buf98 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf98, (64,), is_leaf=True) # primals_99 + buf99 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf99, (224, 4096), is_leaf=True) # primals_100 + buf100 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf100, (224, 4096), is_leaf=True) # primals_101 + buf101 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf101, (64, 14336), is_leaf=True) # primals_102 + buf102 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf102, (64,), is_leaf=True) # primals_103 + buf103 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf103, (64, 4096), is_leaf=True) # primals_104 + buf104 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf104, (16, 4096), is_leaf=True) # primals_105 + buf105 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf105, (16, 4096), is_leaf=True) # primals_106 + buf106 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf106, (64, 4096), is_leaf=True) # primals_107 + buf107 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf107, (64,), is_leaf=True) # primals_108 + buf108 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf108, (224, 4096), is_leaf=True) # primals_109 + buf109 = 
reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf109, (224, 4096), is_leaf=True) # primals_110 + buf110 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf110, (64, 14336), is_leaf=True) # primals_111 + buf111 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf111, (64,), is_leaf=True) # primals_112 + buf112 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf112, (64, 4096), is_leaf=True) # primals_113 + buf113 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf113, (16, 4096), is_leaf=True) # primals_114 + buf114 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf114, (16, 4096), is_leaf=True) # primals_115 + buf115 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf115, (64, 4096), is_leaf=True) # primals_116 + buf116 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf116, (64,), is_leaf=True) # primals_117 + buf117 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf117, (224, 4096), is_leaf=True) # primals_118 + buf118 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf118, (224, 4096), is_leaf=True) # primals_119 + buf119 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf119, (64, 14336), is_leaf=True) # primals_120 + buf120 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf120, (64,), is_leaf=True) # primals_121 + buf121 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf121, (64, 4096), is_leaf=True) # primals_122 + buf122 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf122, (16, 4096), is_leaf=True) # primals_123 + buf123 = reader.storage(None, 262144, 
device=device(type='cuda', index=0)) + reader.tensor(buf123, (16, 4096), is_leaf=True) # primals_124 + buf124 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf124, (64, 4096), is_leaf=True) # primals_125 + buf125 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf125, (64,), is_leaf=True) # primals_126 + buf126 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf126, (224, 4096), is_leaf=True) # primals_127 + buf127 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf127, (224, 4096), is_leaf=True) # primals_128 + buf128 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf128, (64, 14336), is_leaf=True) # primals_129 + buf129 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf129, (64,), is_leaf=True) # primals_130 + buf130 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf130, (64, 4096), is_leaf=True) # primals_131 + buf131 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf131, (16, 4096), is_leaf=True) # primals_132 + buf132 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf132, (16, 4096), is_leaf=True) # primals_133 + buf133 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf133, (64, 4096), is_leaf=True) # primals_134 + buf134 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf134, (64,), is_leaf=True) # primals_135 + buf135 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf135, (224, 4096), is_leaf=True) # primals_136 + buf136 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf136, (224, 4096), is_leaf=True) # primals_137 + buf137 = reader.storage(None, 3670016, device=device(type='cuda', 
index=0)) + reader.tensor(buf137, (64, 14336), is_leaf=True) # primals_138 + buf138 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf138, (64,), is_leaf=True) # primals_139 + buf139 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf139, (64, 4096), is_leaf=True) # primals_140 + buf140 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf140, (16, 4096), is_leaf=True) # primals_141 + buf141 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf141, (16, 4096), is_leaf=True) # primals_142 + buf142 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf142, (64, 4096), is_leaf=True) # primals_143 + buf143 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf143, (64,), is_leaf=True) # primals_144 + buf144 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf144, (224, 4096), is_leaf=True) # primals_145 + buf145 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf145, (224, 4096), is_leaf=True) # primals_146 + buf146 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf146, (64, 14336), is_leaf=True) # primals_147 + buf147 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf147, (64,), is_leaf=True) # primals_148 + buf148 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf148, (64, 4096), is_leaf=True) # primals_149 + buf149 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf149, (16, 4096), is_leaf=True) # primals_150 + buf150 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf150, (16, 4096), is_leaf=True) # primals_151 + buf151 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf151, (64, 
4096), is_leaf=True) # primals_152 + buf152 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf152, (64,), is_leaf=True) # primals_153 + buf153 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf153, (224, 4096), is_leaf=True) # primals_154 + buf154 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf154, (224, 4096), is_leaf=True) # primals_155 + buf155 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf155, (64, 14336), is_leaf=True) # primals_156 + buf156 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf156, (64,), is_leaf=True) # primals_157 + buf157 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf157, (64, 4096), is_leaf=True) # primals_158 + buf158 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf158, (16, 4096), is_leaf=True) # primals_159 + buf159 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf159, (16, 4096), is_leaf=True) # primals_160 + buf160 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf160, (64, 4096), is_leaf=True) # primals_161 + buf161 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf161, (64,), is_leaf=True) # primals_162 + buf162 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf162, (224, 4096), is_leaf=True) # primals_163 + buf163 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf163, (224, 4096), is_leaf=True) # primals_164 + buf164 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf164, (64, 14336), is_leaf=True) # primals_165 + buf165 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf165, (64,), is_leaf=True) # primals_166 + buf166 
= reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf166, (64, 4096), is_leaf=True) # primals_167 + buf167 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf167, (16, 4096), is_leaf=True) # primals_168 + buf168 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf168, (16, 4096), is_leaf=True) # primals_169 + buf169 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf169, (64, 4096), is_leaf=True) # primals_170 + buf170 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf170, (64,), is_leaf=True) # primals_171 + buf171 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf171, (224, 4096), is_leaf=True) # primals_172 + buf172 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf172, (224, 4096), is_leaf=True) # primals_173 + buf173 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf173, (64, 14336), is_leaf=True) # primals_174 + buf174 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf174, (64,), is_leaf=True) # primals_175 + buf175 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf175, (64, 4096), is_leaf=True) # primals_176 + buf176 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf176, (16, 4096), is_leaf=True) # primals_177 + buf177 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf177, (16, 4096), is_leaf=True) # primals_178 + buf178 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf178, (64, 4096), is_leaf=True) # primals_179 + buf179 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf179, (64,), is_leaf=True) # primals_180 + buf180 = reader.storage(None, 3670016, 
device=device(type='cuda', index=0)) + reader.tensor(buf180, (224, 4096), is_leaf=True) # primals_181 + buf181 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf181, (224, 4096), is_leaf=True) # primals_182 + buf182 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf182, (64, 14336), is_leaf=True) # primals_183 + buf183 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf183, (64,), is_leaf=True) # primals_184 + buf184 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf184, (64, 4096), is_leaf=True) # primals_185 + buf185 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf185, (16, 4096), is_leaf=True) # primals_186 + buf186 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf186, (16, 4096), is_leaf=True) # primals_187 + buf187 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf187, (64, 4096), is_leaf=True) # primals_188 + buf188 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf188, (64,), is_leaf=True) # primals_189 + buf189 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf189, (224, 4096), is_leaf=True) # primals_190 + buf190 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf190, (224, 4096), is_leaf=True) # primals_191 + buf191 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf191, (64, 14336), is_leaf=True) # primals_192 + buf192 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf192, (64,), is_leaf=True) # primals_193 + buf193 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf193, (64, 4096), is_leaf=True) # primals_194 + buf194 = reader.storage(None, 262144, device=device(type='cuda', 
index=0)) + reader.tensor(buf194, (16, 4096), is_leaf=True) # primals_195 + buf195 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf195, (16, 4096), is_leaf=True) # primals_196 + buf196 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf196, (64, 4096), is_leaf=True) # primals_197 + buf197 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf197, (64,), is_leaf=True) # primals_198 + buf198 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf198, (224, 4096), is_leaf=True) # primals_199 + buf199 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf199, (224, 4096), is_leaf=True) # primals_200 + buf200 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf200, (64, 14336), is_leaf=True) # primals_201 + buf201 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf201, (64,), is_leaf=True) # primals_202 + buf202 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf202, (64, 4096), is_leaf=True) # primals_203 + buf203 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf203, (16, 4096), is_leaf=True) # primals_204 + buf204 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf204, (16, 4096), is_leaf=True) # primals_205 + buf205 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf205, (64, 4096), is_leaf=True) # primals_206 + buf206 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf206, (64,), is_leaf=True) # primals_207 + buf207 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf207, (224, 4096), is_leaf=True) # primals_208 + buf208 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf208, 
(224, 4096), is_leaf=True) # primals_209 + buf209 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf209, (64, 14336), is_leaf=True) # primals_210 + buf210 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf210, (64,), is_leaf=True) # primals_211 + buf211 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf211, (64, 4096), is_leaf=True) # primals_212 + buf212 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf212, (16, 4096), is_leaf=True) # primals_213 + buf213 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf213, (16, 4096), is_leaf=True) # primals_214 + buf214 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf214, (64, 4096), is_leaf=True) # primals_215 + buf215 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf215, (64,), is_leaf=True) # primals_216 + buf216 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf216, (224, 4096), is_leaf=True) # primals_217 + buf217 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf217, (224, 4096), is_leaf=True) # primals_218 + buf218 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf218, (64, 14336), is_leaf=True) # primals_219 + buf219 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf219, (64,), is_leaf=True) # primals_220 + buf220 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf220, (64, 4096), is_leaf=True) # primals_221 + buf221 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf221, (16, 4096), is_leaf=True) # primals_222 + buf222 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf222, (16, 4096), is_leaf=True) # 
primals_223 + buf223 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf223, (64, 4096), is_leaf=True) # primals_224 + buf224 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf224, (64,), is_leaf=True) # primals_225 + buf225 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf225, (224, 4096), is_leaf=True) # primals_226 + buf226 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf226, (224, 4096), is_leaf=True) # primals_227 + buf227 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf227, (64, 14336), is_leaf=True) # primals_228 + buf228 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf228, (64,), is_leaf=True) # primals_229 + buf229 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf229, (64, 4096), is_leaf=True) # primals_230 + buf230 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf230, (16, 4096), is_leaf=True) # primals_231 + buf231 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf231, (16, 4096), is_leaf=True) # primals_232 + buf232 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf232, (64, 4096), is_leaf=True) # primals_233 + buf233 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf233, (64,), is_leaf=True) # primals_234 + buf234 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf234, (224, 4096), is_leaf=True) # primals_235 + buf235 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf235, (224, 4096), is_leaf=True) # primals_236 + buf236 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf236, (64, 14336), is_leaf=True) # primals_237 + buf237 = 
reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf237, (64,), is_leaf=True) # primals_238 + buf238 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf238, (64, 4096), is_leaf=True) # primals_239 + buf239 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf239, (16, 4096), is_leaf=True) # primals_240 + buf240 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf240, (16, 4096), is_leaf=True) # primals_241 + buf241 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf241, (64, 4096), is_leaf=True) # primals_242 + buf242 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf242, (64,), is_leaf=True) # primals_243 + buf243 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf243, (224, 4096), is_leaf=True) # primals_244 + buf244 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf244, (224, 4096), is_leaf=True) # primals_245 + buf245 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf245, (64, 14336), is_leaf=True) # primals_246 + buf246 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf246, (64,), is_leaf=True) # primals_247 + buf247 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf247, (64, 4096), is_leaf=True) # primals_248 + buf248 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf248, (16, 4096), is_leaf=True) # primals_249 + buf249 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf249, (16, 4096), is_leaf=True) # primals_250 + buf250 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf250, (64, 4096), is_leaf=True) # primals_251 + buf251 = reader.storage(None, 256, 
device=device(type='cuda', index=0)) + reader.tensor(buf251, (64,), is_leaf=True) # primals_252 + buf252 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf252, (224, 4096), is_leaf=True) # primals_253 + buf253 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf253, (224, 4096), is_leaf=True) # primals_254 + buf254 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf254, (64, 14336), is_leaf=True) # primals_255 + buf255 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf255, (64,), is_leaf=True) # primals_256 + buf256 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf256, (64, 4096), is_leaf=True) # primals_257 + buf257 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf257, (16, 4096), is_leaf=True) # primals_258 + buf258 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf258, (16, 4096), is_leaf=True) # primals_259 + buf259 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf259, (64, 4096), is_leaf=True) # primals_260 + buf260 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf260, (64,), is_leaf=True) # primals_261 + buf261 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf261, (224, 4096), is_leaf=True) # primals_262 + buf262 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf262, (224, 4096), is_leaf=True) # primals_263 + buf263 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf263, (64, 14336), is_leaf=True) # primals_264 + buf264 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf264, (64,), is_leaf=True) # primals_265 + buf265 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + 
reader.tensor(buf265, (64, 4096), is_leaf=True) # primals_266 + buf266 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf266, (16, 4096), is_leaf=True) # primals_267 + buf267 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf267, (16, 4096), is_leaf=True) # primals_268 + buf268 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf268, (64, 4096), is_leaf=True) # primals_269 + buf269 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf269, (64,), is_leaf=True) # primals_270 + buf270 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf270, (224, 4096), is_leaf=True) # primals_271 + buf271 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf271, (224, 4096), is_leaf=True) # primals_272 + buf272 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf272, (64, 14336), is_leaf=True) # primals_273 + buf273 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf273, (64,), is_leaf=True) # primals_274 + buf274 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf274, (64, 4096), is_leaf=True) # primals_275 + buf275 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf275, (16, 4096), is_leaf=True) # primals_276 + buf276 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf276, (16, 4096), is_leaf=True) # primals_277 + buf277 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf277, (64, 4096), is_leaf=True) # primals_278 + buf278 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf278, (64,), is_leaf=True) # primals_279 + buf279 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf279, (224, 4096), 
is_leaf=True) # primals_280 + buf280 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf280, (224, 4096), is_leaf=True) # primals_281 + buf281 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf281, (64, 14336), is_leaf=True) # primals_282 + buf282 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf282, (64,), is_leaf=True) # primals_283 + buf283 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf283, (64, 4096), is_leaf=True) # primals_284 + buf284 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf284, (16, 4096), is_leaf=True) # primals_285 + buf285 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf285, (16, 4096), is_leaf=True) # primals_286 + buf286 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf286, (64, 4096), is_leaf=True) # primals_287 + buf287 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf287, (64,), is_leaf=True) # primals_288 + buf288 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf288, (224, 4096), is_leaf=True) # primals_289 + buf289 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf289, (224, 4096), is_leaf=True) # primals_290 + buf290 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf290, (64, 14336), is_leaf=True) # primals_291 + buf291 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf291, (64,), is_leaf=True) # primals_292 + buf292 = reader.storage(None, 32833536, device=device(type='cuda', index=0)) + reader.tensor(buf292, (2004, 4096), is_leaf=True) # primals_293 + +load_args._version = 0 + +def get_mesh_sizes(): + return 64, + +def get_colls_estimations_file(): + return "colls8_8.table" + +def 
get_pg_names(): + return "0", diff --git a/autoparallel/tools/overlap_simulator/repro_llama3_8b_fw_64_2d_32layers.py b/autoparallel/tools/overlap_simulator/repro_llama3_8b_fw_64_2d_32layers.py new file mode 100644 index 00000000..9b307613 --- /dev/null +++ b/autoparallel/tools/overlap_simulator/repro_llama3_8b_fw_64_2d_32layers.py @@ -0,0 +1,5657 @@ +import torch +from torch.nn import * +from torch import tensor, device + + +class Repro(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + + def forward(self, primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8, primals_9, primals_10, primals_11, primals_12, primals_13, primals_14, primals_15, primals_16, primals_17, primals_18, primals_19, primals_20, primals_21, primals_22, primals_23, primals_24, primals_25, primals_26, primals_27, primals_28, primals_29, primals_30, primals_31, primals_32, primals_33, primals_34, primals_35, primals_36, primals_37, primals_38, primals_39, primals_40, primals_41, primals_42, primals_43, primals_44, primals_45, primals_46, primals_47, primals_48, primals_49, primals_50, primals_51, primals_52, primals_53, primals_54, primals_55, primals_56, primals_57, primals_58, primals_59, primals_60, primals_61, primals_62, primals_63, primals_64, primals_65, primals_66, primals_67, primals_68, primals_69, primals_70, primals_71, primals_72, primals_73, primals_74, primals_75, primals_76, primals_77, primals_78, primals_79, primals_80, primals_81, primals_82, primals_83, primals_84, primals_85, primals_86, primals_87, primals_88, primals_89, primals_90, primals_91, primals_92, primals_93, primals_94, primals_95, primals_96, primals_97, primals_98, primals_99, primals_100, primals_101, primals_102, primals_103, primals_104, primals_105, primals_106, primals_107, primals_108, primals_109, primals_110, primals_111, primals_112, primals_113, primals_114, primals_115, primals_116, primals_117, primals_118, primals_119, primals_120, primals_121, 
primals_122, primals_123, primals_124, primals_125, primals_126, primals_127, primals_128, primals_129, primals_130, primals_131, primals_132, primals_133, primals_134, primals_135, primals_136, primals_137, primals_138, primals_139, primals_140, primals_141, primals_142, primals_143, primals_144, primals_145, primals_146, primals_147, primals_148, primals_149, primals_150, primals_151, primals_152, primals_153, primals_154, primals_155, primals_156, primals_157, primals_158, primals_159, primals_160, primals_161, primals_162, primals_163, primals_164, primals_165, primals_166, primals_167, primals_168, primals_169, primals_170, primals_171, primals_172, primals_173, primals_174, primals_175, primals_176, primals_177, primals_178, primals_179, primals_180, primals_181, primals_182, primals_183, primals_184, primals_185, primals_186, primals_187, primals_188, primals_189, primals_190, primals_191, primals_192, primals_193, primals_194, primals_195, primals_196, primals_197, primals_198, primals_199, primals_200, primals_201, primals_202, primals_203, primals_204, primals_205, primals_206, primals_207, primals_208, primals_209, primals_210, primals_211, primals_212, primals_213, primals_214, primals_215, primals_216, primals_217, primals_218, primals_219, primals_220, primals_221, primals_222, primals_223, primals_224, primals_225, primals_226, primals_227, primals_228, primals_229, primals_230, primals_231, primals_232, primals_233, primals_234, primals_235, primals_236, primals_237, primals_238, primals_239, primals_240, primals_241, primals_242, primals_243, primals_244, primals_245, primals_246, primals_247, primals_248, primals_249, primals_250, primals_251, primals_252, primals_253, primals_254, primals_255, primals_256, primals_257, primals_258, primals_259, primals_260, primals_261, primals_262, primals_263, primals_264, primals_265, primals_266, primals_267, primals_268, primals_269, primals_270, primals_271, primals_272, primals_273, primals_274, 
primals_275, primals_276, primals_277, primals_278, primals_279, primals_280, primals_281, primals_282, primals_283, primals_284, primals_285, primals_286, primals_287, primals_288, primals_289, primals_290, primals_291, primals_292, primals_293): + convert_element_type = torch.ops.prims.convert_element_type.default(primals_2, torch.bfloat16) + all_gather_into_tensor = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type, 8, '0'); convert_element_type = None + wait_tensor = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor); all_gather_into_tensor = None + lt = torch.ops.aten.lt.Scalar(primals_1, 0) + ge = torch.ops.aten.ge.Scalar(primals_1, 16032) + bitwise_or = torch.ops.aten.bitwise_or.Tensor(lt, ge); lt = ge = None + sub = torch.ops.aten.sub.Tensor(primals_1, 0) + full_default = torch.ops.aten.full.default([], 0, dtype = torch.int64, layout = torch.strided, device = device(type='cpu'), pin_memory = False) + index_put = torch.ops.aten.index_put.default(sub, [bitwise_or], full_default); sub = full_default = None + embedding = torch.ops.aten.embedding.default(wait_tensor, index_put); wait_tensor = index_put = None + full_default_1 = torch.ops.aten.full.default([], 0.0, dtype = torch.bfloat16, layout = torch.strided, device = device(type='cpu'), pin_memory = False) + index_put_1 = torch.ops.aten.index_put.default(embedding, [bitwise_or], full_default_1); embedding = bitwise_or = full_default_1 = None + split_1 = torch.ops.aten.split.Tensor(index_put_1, 1024, 1); index_put_1 = None + getitem_8 = split_1[0] + getitem_17 = split_1[1] + getitem_26 = split_1[2] + getitem_35 = split_1[3] + getitem_44 = split_1[4] + getitem_53 = split_1[5] + getitem_62 = split_1[6] + getitem_71 = split_1[7]; split_1 = None + cat = torch.ops.aten.cat.default([getitem_8, getitem_17, getitem_26, getitem_35, getitem_44, getitem_53, getitem_62, getitem_71]); getitem_8 = getitem_17 = getitem_26 = getitem_35 = getitem_44 = getitem_53 = getitem_62 = 
getitem_71 = None + reduce_scatter_tensor = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat, 'sum', 8, '1'); cat = None + wait_tensor_1 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor); reduce_scatter_tensor = None + convert_element_type_1 = torch.ops.prims.convert_element_type.default(primals_4, torch.bfloat16) + all_gather_into_tensor_1 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1, 8, '0'); convert_element_type_1 = None + wait_tensor_2 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_1); all_gather_into_tensor_1 = None + convert_element_type_2 = torch.ops.prims.convert_element_type.default(wait_tensor_1, torch.float32) + pow_1 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_2, 2) + mean = torch.ops.aten.mean.dim(pow_1, [2], True); pow_1 = None + add = torch.ops.aten.add.Scalar(mean, 1e-05); mean = None + rsqrt = torch.ops.aten.rsqrt.default(add); add = None + mul = torch.ops.aten.mul.Tensor(convert_element_type_2, rsqrt); convert_element_type_2 = rsqrt = None + mul_1 = torch.ops.aten.mul.Tensor(mul, wait_tensor_2); mul = wait_tensor_2 = None + convert_element_type_3 = torch.ops.prims.convert_element_type.default(mul_1, torch.bfloat16); mul_1 = None + all_gather_into_tensor_2 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_3, 8, '1'); convert_element_type_3 = None + wait_tensor_3 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_2); all_gather_into_tensor_2 = None + split_9 = torch.ops.aten.split.Tensor(wait_tensor_3, 2); wait_tensor_3 = None + getitem_72 = split_9[0] + getitem_73 = split_9[1] + getitem_74 = split_9[2] + getitem_75 = split_9[3] + getitem_76 = split_9[4] + getitem_77 = split_9[5] + getitem_78 = split_9[6] + getitem_79 = split_9[7]; split_9 = None + cat_1 = torch.ops.aten.cat.default([getitem_72, getitem_73, getitem_74, getitem_75, getitem_76, getitem_77, getitem_78, getitem_79], 1); 
getitem_72 = getitem_73 = getitem_74 = getitem_75 = getitem_76 = getitem_77 = getitem_78 = getitem_79 = None + convert_element_type_4 = torch.ops.prims.convert_element_type.default(primals_5, torch.bfloat16) + all_gather_into_tensor_3 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_4, 8, '0'); convert_element_type_4 = None + wait_tensor_4 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_3); all_gather_into_tensor_3 = None + permute = torch.ops.aten.permute.default(wait_tensor_4, [1, 0]); wait_tensor_4 = None + view_15 = torch.ops.aten.view.default(cat_1, [16384, 4096]); cat_1 = None + mm = torch.ops.aten.mm.default(view_15, permute); permute = None + view_16 = torch.ops.aten.view.default(mm, [2, 8192, 512]) + convert_element_type_7 = torch.ops.prims.convert_element_type.default(primals_6, torch.bfloat16) + all_gather_into_tensor_4 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_7, 8, '0'); convert_element_type_7 = None + wait_tensor_5 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_4); all_gather_into_tensor_4 = None + permute_1 = torch.ops.aten.permute.default(wait_tensor_5, [1, 0]); wait_tensor_5 = None + mm_1 = torch.ops.aten.mm.default(view_15, permute_1); permute_1 = None + view_23 = torch.ops.aten.view.default(mm_1, [2, 8192, 128]); mm_1 = None + convert_element_type_10 = torch.ops.prims.convert_element_type.default(primals_7, torch.bfloat16) + all_gather_into_tensor_5 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_10, 8, '0'); convert_element_type_10 = None + wait_tensor_6 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_5); all_gather_into_tensor_5 = None + permute_2 = torch.ops.aten.permute.default(wait_tensor_6, [1, 0]); wait_tensor_6 = None + mm_2 = torch.ops.aten.mm.default(view_15, permute_2); view_15 = permute_2 = None + view_30 = torch.ops.aten.view.default(mm_2, [2, 8192, 128]) + 
view_32 = torch.ops.aten.view.default(view_16, [2, 8192, -1, 128]); view_16 = None + view_33 = torch.ops.aten.view.default(view_23, [2, 8192, -1, 128]); view_23 = None + view_34 = torch.ops.aten.view.default(view_30, [2, 8192, -1, 128]); view_30 = None + convert_element_type_13 = torch.ops.prims.convert_element_type.default(view_32, torch.float32); view_32 = None + view_35 = torch.ops.aten.view.default(convert_element_type_13, [2, 8192, 4, -1, 2]); convert_element_type_13 = None + view_as_complex = torch.ops.aten.view_as_complex.default(view_35); view_35 = None + convert_element_type_14 = torch.ops.prims.convert_element_type.default(view_33, torch.float32); view_33 = None + view_36 = torch.ops.aten.view.default(convert_element_type_14, [2, 8192, 1, -1, 2]); convert_element_type_14 = None + view_as_complex_1 = torch.ops.aten.view_as_complex.default(view_36); view_36 = None + view_37 = torch.ops.aten.view.default(primals_3, [1, 8192, 1, 64]) + mul_2 = torch.ops.aten.mul.Tensor(view_as_complex, view_37); view_as_complex = None + view_as_real = torch.ops.aten.view_as_real.default(mul_2); mul_2 = None + view_38 = torch.ops.aten.view.default(view_as_real, [2, 8192, 4, 128]); view_as_real = None + mul_3 = torch.ops.aten.mul.Tensor(view_as_complex_1, view_37); view_as_complex_1 = None + view_as_real_1 = torch.ops.aten.view_as_real.default(mul_3); mul_3 = None + view_39 = torch.ops.aten.view.default(view_as_real_1, [2, 8192, 1, 128]); view_as_real_1 = None + convert_element_type_15 = torch.ops.prims.convert_element_type.default(view_38, torch.bfloat16); view_38 = None + convert_element_type_16 = torch.ops.prims.convert_element_type.default(view_39, torch.bfloat16); view_39 = None + unsqueeze = torch.ops.aten.unsqueeze.default(convert_element_type_16, 3); convert_element_type_16 = None + expand = torch.ops.aten.expand.default(unsqueeze, [2, 8192, 1, 4, 128]); unsqueeze = None + view_40 = torch.ops.aten.view.default(expand, [2, 8192, 4, 128]); expand = None + unsqueeze_1 = 
torch.ops.aten.unsqueeze.default(view_34, 3); view_34 = None + expand_1 = torch.ops.aten.expand.default(unsqueeze_1, [2, 8192, 1, 4, 128]); unsqueeze_1 = None + view_41 = torch.ops.aten.view.default(expand_1, [2, 8192, 4, 128]); expand_1 = None + permute_3 = torch.ops.aten.permute.default(convert_element_type_15, [0, 2, 1, 3]); convert_element_type_15 = None + permute_4 = torch.ops.aten.permute.default(view_40, [0, 2, 1, 3]); view_40 = None + permute_5 = torch.ops.aten.permute.default(view_41, [0, 2, 1, 3]); view_41 = None + _scaled_dot_product_cudnn_attention = torch.ops.aten._scaled_dot_product_cudnn_attention.default(permute_3, permute_4, permute_5, None, True, 0.0, True); permute_3 = permute_4 = permute_5 = None + getitem_80 = _scaled_dot_product_cudnn_attention[0] + getitem_81 = _scaled_dot_product_cudnn_attention[1] + getitem_86 = _scaled_dot_product_cudnn_attention[6] + getitem_87 = _scaled_dot_product_cudnn_attention[7]; _scaled_dot_product_cudnn_attention = None + permute_6 = torch.ops.aten.permute.default(getitem_80, [0, 2, 1, 3]) + view_42 = torch.ops.aten.view.default(permute_6, [2, 8192, -1]); permute_6 = None + convert_element_type_17 = torch.ops.prims.convert_element_type.default(primals_8, torch.bfloat16) + all_gather_into_tensor_6 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_17, 8, '0'); convert_element_type_17 = None + wait_tensor_7 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_6); all_gather_into_tensor_6 = None + permute_7 = torch.ops.aten.permute.default(wait_tensor_7, [1, 0]); wait_tensor_7 = None + view_48 = torch.ops.aten.view.default(view_42, [16384, 512]); view_42 = None + mm_3 = torch.ops.aten.mm.default(view_48, permute_7); view_48 = permute_7 = None + view_49 = torch.ops.aten.view.default(mm_3, [2, 8192, 4096]); mm_3 = None + split_10 = torch.ops.aten.split.Tensor(view_49, 1024, 1); view_49 = None + getitem_89 = split_10[0] + getitem_90 = split_10[1] + getitem_91 = 
split_10[2] + getitem_92 = split_10[3] + getitem_93 = split_10[4] + getitem_94 = split_10[5] + getitem_95 = split_10[6] + getitem_96 = split_10[7]; split_10 = None + cat_2 = torch.ops.aten.cat.default([getitem_89, getitem_90, getitem_91, getitem_92, getitem_93, getitem_94, getitem_95, getitem_96]); getitem_89 = getitem_90 = getitem_91 = getitem_92 = getitem_93 = getitem_94 = getitem_95 = getitem_96 = None + reduce_scatter_tensor_1 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_2, 'sum', 8, '1'); cat_2 = None + wait_tensor_8 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_1) + add_1 = torch.ops.aten.add.Tensor(wait_tensor_1, wait_tensor_8); wait_tensor_8 = None + convert_element_type_20 = torch.ops.prims.convert_element_type.default(primals_9, torch.bfloat16) + all_gather_into_tensor_7 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_20, 8, '0'); convert_element_type_20 = None + wait_tensor_9 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_7); all_gather_into_tensor_7 = None + convert_element_type_21 = torch.ops.prims.convert_element_type.default(add_1, torch.float32) + pow_2 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_21, 2) + mean_1 = torch.ops.aten.mean.dim(pow_2, [2], True); pow_2 = None + add_2 = torch.ops.aten.add.Scalar(mean_1, 1e-05); mean_1 = None + rsqrt_1 = torch.ops.aten.rsqrt.default(add_2); add_2 = None + mul_4 = torch.ops.aten.mul.Tensor(convert_element_type_21, rsqrt_1); convert_element_type_21 = rsqrt_1 = None + mul_5 = torch.ops.aten.mul.Tensor(mul_4, wait_tensor_9); mul_4 = wait_tensor_9 = None + convert_element_type_22 = torch.ops.prims.convert_element_type.default(mul_5, torch.bfloat16); mul_5 = None + all_gather_into_tensor_8 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_22, 8, '1'); convert_element_type_22 = None + wait_tensor_10 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_8); 
all_gather_into_tensor_8 = None + split_11 = torch.ops.aten.split.Tensor(wait_tensor_10, 2); wait_tensor_10 = None + getitem_97 = split_11[0] + getitem_98 = split_11[1] + getitem_99 = split_11[2] + getitem_100 = split_11[3] + getitem_101 = split_11[4] + getitem_102 = split_11[5] + getitem_103 = split_11[6] + getitem_104 = split_11[7]; split_11 = None + cat_3 = torch.ops.aten.cat.default([getitem_97, getitem_98, getitem_99, getitem_100, getitem_101, getitem_102, getitem_103, getitem_104], 1); getitem_97 = getitem_98 = getitem_99 = getitem_100 = getitem_101 = getitem_102 = getitem_103 = getitem_104 = None + convert_element_type_23 = torch.ops.prims.convert_element_type.default(primals_10, torch.bfloat16) + all_gather_into_tensor_9 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_23, 8, '0'); convert_element_type_23 = None + wait_tensor_11 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_9); all_gather_into_tensor_9 = None + permute_8 = torch.ops.aten.permute.default(wait_tensor_11, [1, 0]); wait_tensor_11 = None + view_60 = torch.ops.aten.view.default(cat_3, [16384, 4096]); cat_3 = None + mm_4 = torch.ops.aten.mm.default(view_60, permute_8); permute_8 = None + view_61 = torch.ops.aten.view.default(mm_4, [2, 8192, 1792]) + convert_element_type_26 = torch.ops.prims.convert_element_type.default(view_61, torch.float32); view_61 = None + sigmoid = torch.ops.aten.sigmoid.default(convert_element_type_26) + mul_6 = torch.ops.aten.mul.Tensor(convert_element_type_26, sigmoid); convert_element_type_26 = sigmoid = None + convert_element_type_27 = torch.ops.prims.convert_element_type.default(mul_6, torch.bfloat16); mul_6 = None + convert_element_type_28 = torch.ops.prims.convert_element_type.default(primals_11, torch.bfloat16) + all_gather_into_tensor_10 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_28, 8, '0'); convert_element_type_28 = None + wait_tensor_12 = 
torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_10); all_gather_into_tensor_10 = None + permute_9 = torch.ops.aten.permute.default(wait_tensor_12, [1, 0]); wait_tensor_12 = None + mm_5 = torch.ops.aten.mm.default(view_60, permute_9); view_60 = permute_9 = None + view_68 = torch.ops.aten.view.default(mm_5, [2, 8192, 1792]); mm_5 = None + mul_7 = torch.ops.aten.mul.Tensor(convert_element_type_27, view_68); convert_element_type_27 = view_68 = None + convert_element_type_31 = torch.ops.prims.convert_element_type.default(primals_12, torch.bfloat16) + all_gather_into_tensor_11 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_31, 8, '0'); convert_element_type_31 = None + wait_tensor_13 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_11); all_gather_into_tensor_11 = None + permute_10 = torch.ops.aten.permute.default(wait_tensor_13, [1, 0]); wait_tensor_13 = None + view_75 = torch.ops.aten.view.default(mul_7, [16384, 1792]); mul_7 = None + mm_6 = torch.ops.aten.mm.default(view_75, permute_10); view_75 = permute_10 = None + view_76 = torch.ops.aten.view.default(mm_6, [2, 8192, 4096]); mm_6 = None + split_12 = torch.ops.aten.split.Tensor(view_76, 1024, 1); view_76 = None + getitem_105 = split_12[0] + getitem_106 = split_12[1] + getitem_107 = split_12[2] + getitem_108 = split_12[3] + getitem_109 = split_12[4] + getitem_110 = split_12[5] + getitem_111 = split_12[6] + getitem_112 = split_12[7]; split_12 = None + cat_4 = torch.ops.aten.cat.default([getitem_105, getitem_106, getitem_107, getitem_108, getitem_109, getitem_110, getitem_111, getitem_112]); getitem_105 = getitem_106 = getitem_107 = getitem_108 = getitem_109 = getitem_110 = getitem_111 = getitem_112 = None + reduce_scatter_tensor_2 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_4, 'sum', 8, '1'); cat_4 = None + wait_tensor_14 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_2); reduce_scatter_tensor_2 = 
None + add_3 = torch.ops.aten.add.Tensor(add_1, wait_tensor_14); add_1 = wait_tensor_14 = None + convert_element_type_34 = torch.ops.prims.convert_element_type.default(primals_13, torch.bfloat16) + all_gather_into_tensor_12 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_34, 8, '0'); convert_element_type_34 = None + wait_tensor_15 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_12); all_gather_into_tensor_12 = None + convert_element_type_35 = torch.ops.prims.convert_element_type.default(add_3, torch.float32) + pow_3 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_35, 2) + mean_2 = torch.ops.aten.mean.dim(pow_3, [2], True); pow_3 = None + add_4 = torch.ops.aten.add.Scalar(mean_2, 1e-05); mean_2 = None + rsqrt_2 = torch.ops.aten.rsqrt.default(add_4); add_4 = None + mul_8 = torch.ops.aten.mul.Tensor(convert_element_type_35, rsqrt_2); convert_element_type_35 = rsqrt_2 = None + mul_9 = torch.ops.aten.mul.Tensor(mul_8, wait_tensor_15); mul_8 = wait_tensor_15 = None + convert_element_type_36 = torch.ops.prims.convert_element_type.default(mul_9, torch.bfloat16); mul_9 = None + all_gather_into_tensor_13 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_36, 8, '1'); convert_element_type_36 = None + wait_tensor_16 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_13); all_gather_into_tensor_13 = None + split_13 = torch.ops.aten.split.Tensor(wait_tensor_16, 2); wait_tensor_16 = None + getitem_113 = split_13[0] + getitem_114 = split_13[1] + getitem_115 = split_13[2] + getitem_116 = split_13[3] + getitem_117 = split_13[4] + getitem_118 = split_13[5] + getitem_119 = split_13[6] + getitem_120 = split_13[7]; split_13 = None + cat_5 = torch.ops.aten.cat.default([getitem_113, getitem_114, getitem_115, getitem_116, getitem_117, getitem_118, getitem_119, getitem_120], 1); getitem_113 = getitem_114 = getitem_115 = getitem_116 = getitem_117 = getitem_118 = getitem_119 = 
getitem_120 = None + convert_element_type_37 = torch.ops.prims.convert_element_type.default(primals_14, torch.bfloat16) + all_gather_into_tensor_14 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_37, 8, '0'); convert_element_type_37 = None + wait_tensor_17 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_14); all_gather_into_tensor_14 = None + permute_11 = torch.ops.aten.permute.default(wait_tensor_17, [1, 0]); wait_tensor_17 = None + view_87 = torch.ops.aten.view.default(cat_5, [16384, 4096]); cat_5 = None + mm_7 = torch.ops.aten.mm.default(view_87, permute_11); permute_11 = None + view_88 = torch.ops.aten.view.default(mm_7, [2, 8192, 512]) + convert_element_type_40 = torch.ops.prims.convert_element_type.default(primals_15, torch.bfloat16) + all_gather_into_tensor_15 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_40, 8, '0'); convert_element_type_40 = None + wait_tensor_18 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_15); all_gather_into_tensor_15 = None + permute_12 = torch.ops.aten.permute.default(wait_tensor_18, [1, 0]); wait_tensor_18 = None + mm_8 = torch.ops.aten.mm.default(view_87, permute_12); permute_12 = None + view_95 = torch.ops.aten.view.default(mm_8, [2, 8192, 128]); mm_8 = None + convert_element_type_43 = torch.ops.prims.convert_element_type.default(primals_16, torch.bfloat16) + all_gather_into_tensor_16 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_43, 8, '0'); convert_element_type_43 = None + wait_tensor_19 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_16); all_gather_into_tensor_16 = None + permute_13 = torch.ops.aten.permute.default(wait_tensor_19, [1, 0]); wait_tensor_19 = None + mm_9 = torch.ops.aten.mm.default(view_87, permute_13); view_87 = permute_13 = None + view_102 = torch.ops.aten.view.default(mm_9, [2, 8192, 128]) + view_104 = 
torch.ops.aten.view.default(view_88, [2, 8192, -1, 128]); view_88 = None + view_105 = torch.ops.aten.view.default(view_95, [2, 8192, -1, 128]); view_95 = None + view_106 = torch.ops.aten.view.default(view_102, [2, 8192, -1, 128]); view_102 = None + convert_element_type_46 = torch.ops.prims.convert_element_type.default(view_104, torch.float32); view_104 = None + view_107 = torch.ops.aten.view.default(convert_element_type_46, [2, 8192, 4, -1, 2]); convert_element_type_46 = None + view_as_complex_2 = torch.ops.aten.view_as_complex.default(view_107); view_107 = None + convert_element_type_47 = torch.ops.prims.convert_element_type.default(view_105, torch.float32); view_105 = None + view_108 = torch.ops.aten.view.default(convert_element_type_47, [2, 8192, 1, -1, 2]); convert_element_type_47 = None + view_as_complex_3 = torch.ops.aten.view_as_complex.default(view_108); view_108 = None + mul_10 = torch.ops.aten.mul.Tensor(view_as_complex_2, view_37); view_as_complex_2 = None + view_as_real_2 = torch.ops.aten.view_as_real.default(mul_10); mul_10 = None + view_110 = torch.ops.aten.view.default(view_as_real_2, [2, 8192, 4, 128]); view_as_real_2 = None + mul_11 = torch.ops.aten.mul.Tensor(view_as_complex_3, view_37); view_as_complex_3 = None + view_as_real_3 = torch.ops.aten.view_as_real.default(mul_11); mul_11 = None + view_111 = torch.ops.aten.view.default(view_as_real_3, [2, 8192, 1, 128]); view_as_real_3 = None + convert_element_type_48 = torch.ops.prims.convert_element_type.default(view_110, torch.bfloat16); view_110 = None + convert_element_type_49 = torch.ops.prims.convert_element_type.default(view_111, torch.bfloat16); view_111 = None + unsqueeze_2 = torch.ops.aten.unsqueeze.default(convert_element_type_49, 3); convert_element_type_49 = None + expand_2 = torch.ops.aten.expand.default(unsqueeze_2, [2, 8192, 1, 4, 128]); unsqueeze_2 = None + view_112 = torch.ops.aten.view.default(expand_2, [2, 8192, 4, 128]); expand_2 = None + unsqueeze_3 = 
torch.ops.aten.unsqueeze.default(view_106, 3); view_106 = None + expand_3 = torch.ops.aten.expand.default(unsqueeze_3, [2, 8192, 1, 4, 128]); unsqueeze_3 = None + view_113 = torch.ops.aten.view.default(expand_3, [2, 8192, 4, 128]); expand_3 = None + permute_14 = torch.ops.aten.permute.default(convert_element_type_48, [0, 2, 1, 3]); convert_element_type_48 = None + permute_15 = torch.ops.aten.permute.default(view_112, [0, 2, 1, 3]); view_112 = None + permute_16 = torch.ops.aten.permute.default(view_113, [0, 2, 1, 3]); view_113 = None + _scaled_dot_product_cudnn_attention_1 = torch.ops.aten._scaled_dot_product_cudnn_attention.default(permute_14, permute_15, permute_16, None, True, 0.0, True); permute_14 = permute_15 = permute_16 = None + getitem_121 = _scaled_dot_product_cudnn_attention_1[0] + getitem_122 = _scaled_dot_product_cudnn_attention_1[1] + getitem_127 = _scaled_dot_product_cudnn_attention_1[6] + getitem_128 = _scaled_dot_product_cudnn_attention_1[7]; _scaled_dot_product_cudnn_attention_1 = None + permute_17 = torch.ops.aten.permute.default(getitem_121, [0, 2, 1, 3]) + view_114 = torch.ops.aten.view.default(permute_17, [2, 8192, -1]); permute_17 = None + convert_element_type_50 = torch.ops.prims.convert_element_type.default(primals_17, torch.bfloat16) + all_gather_into_tensor_17 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_50, 8, '0'); convert_element_type_50 = None + wait_tensor_20 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_17); all_gather_into_tensor_17 = None + permute_18 = torch.ops.aten.permute.default(wait_tensor_20, [1, 0]); wait_tensor_20 = None + view_120 = torch.ops.aten.view.default(view_114, [16384, 512]); view_114 = None + mm_10 = torch.ops.aten.mm.default(view_120, permute_18); view_120 = permute_18 = None + view_121 = torch.ops.aten.view.default(mm_10, [2, 8192, 4096]); mm_10 = None + split_14 = torch.ops.aten.split.Tensor(view_121, 1024, 1); view_121 = None + getitem_130 = 
split_14[0] + getitem_131 = split_14[1] + getitem_132 = split_14[2] + getitem_133 = split_14[3] + getitem_134 = split_14[4] + getitem_135 = split_14[5] + getitem_136 = split_14[6] + getitem_137 = split_14[7]; split_14 = None + cat_6 = torch.ops.aten.cat.default([getitem_130, getitem_131, getitem_132, getitem_133, getitem_134, getitem_135, getitem_136, getitem_137]); getitem_130 = getitem_131 = getitem_132 = getitem_133 = getitem_134 = getitem_135 = getitem_136 = getitem_137 = None + reduce_scatter_tensor_3 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_6, 'sum', 8, '1'); cat_6 = None + wait_tensor_21 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_3) + add_5 = torch.ops.aten.add.Tensor(add_3, wait_tensor_21); wait_tensor_21 = None + convert_element_type_53 = torch.ops.prims.convert_element_type.default(primals_18, torch.bfloat16) + all_gather_into_tensor_18 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_53, 8, '0'); convert_element_type_53 = None + wait_tensor_22 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_18); all_gather_into_tensor_18 = None + convert_element_type_54 = torch.ops.prims.convert_element_type.default(add_5, torch.float32) + pow_4 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_54, 2) + mean_3 = torch.ops.aten.mean.dim(pow_4, [2], True); pow_4 = None + add_6 = torch.ops.aten.add.Scalar(mean_3, 1e-05); mean_3 = None + rsqrt_3 = torch.ops.aten.rsqrt.default(add_6); add_6 = None + mul_12 = torch.ops.aten.mul.Tensor(convert_element_type_54, rsqrt_3); convert_element_type_54 = rsqrt_3 = None + mul_13 = torch.ops.aten.mul.Tensor(mul_12, wait_tensor_22); mul_12 = wait_tensor_22 = None + convert_element_type_55 = torch.ops.prims.convert_element_type.default(mul_13, torch.bfloat16); mul_13 = None + all_gather_into_tensor_19 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_55, 8, '1'); convert_element_type_55 = None + 
wait_tensor_23 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_19); all_gather_into_tensor_19 = None + split_15 = torch.ops.aten.split.Tensor(wait_tensor_23, 2); wait_tensor_23 = None + getitem_138 = split_15[0] + getitem_139 = split_15[1] + getitem_140 = split_15[2] + getitem_141 = split_15[3] + getitem_142 = split_15[4] + getitem_143 = split_15[5] + getitem_144 = split_15[6] + getitem_145 = split_15[7]; split_15 = None + cat_7 = torch.ops.aten.cat.default([getitem_138, getitem_139, getitem_140, getitem_141, getitem_142, getitem_143, getitem_144, getitem_145], 1); getitem_138 = getitem_139 = getitem_140 = getitem_141 = getitem_142 = getitem_143 = getitem_144 = getitem_145 = None + convert_element_type_56 = torch.ops.prims.convert_element_type.default(primals_19, torch.bfloat16) + all_gather_into_tensor_20 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_56, 8, '0'); convert_element_type_56 = None + wait_tensor_24 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_20); all_gather_into_tensor_20 = None + permute_19 = torch.ops.aten.permute.default(wait_tensor_24, [1, 0]); wait_tensor_24 = None + view_132 = torch.ops.aten.view.default(cat_7, [16384, 4096]); cat_7 = None + mm_11 = torch.ops.aten.mm.default(view_132, permute_19); permute_19 = None + view_133 = torch.ops.aten.view.default(mm_11, [2, 8192, 1792]) + convert_element_type_59 = torch.ops.prims.convert_element_type.default(view_133, torch.float32); view_133 = None + sigmoid_1 = torch.ops.aten.sigmoid.default(convert_element_type_59) + mul_14 = torch.ops.aten.mul.Tensor(convert_element_type_59, sigmoid_1); convert_element_type_59 = sigmoid_1 = None + convert_element_type_60 = torch.ops.prims.convert_element_type.default(mul_14, torch.bfloat16); mul_14 = None + convert_element_type_61 = torch.ops.prims.convert_element_type.default(primals_20, torch.bfloat16) + all_gather_into_tensor_21 = 
torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_61, 8, '0'); convert_element_type_61 = None + wait_tensor_25 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_21); all_gather_into_tensor_21 = None + permute_20 = torch.ops.aten.permute.default(wait_tensor_25, [1, 0]); wait_tensor_25 = None + mm_12 = torch.ops.aten.mm.default(view_132, permute_20); view_132 = permute_20 = None + view_140 = torch.ops.aten.view.default(mm_12, [2, 8192, 1792]); mm_12 = None + mul_15 = torch.ops.aten.mul.Tensor(convert_element_type_60, view_140); convert_element_type_60 = view_140 = None + convert_element_type_64 = torch.ops.prims.convert_element_type.default(primals_21, torch.bfloat16) + all_gather_into_tensor_22 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_64, 8, '0'); convert_element_type_64 = None + wait_tensor_26 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_22); all_gather_into_tensor_22 = None + permute_21 = torch.ops.aten.permute.default(wait_tensor_26, [1, 0]); wait_tensor_26 = None + view_147 = torch.ops.aten.view.default(mul_15, [16384, 1792]); mul_15 = None + mm_13 = torch.ops.aten.mm.default(view_147, permute_21); view_147 = permute_21 = None + view_148 = torch.ops.aten.view.default(mm_13, [2, 8192, 4096]); mm_13 = None + split_16 = torch.ops.aten.split.Tensor(view_148, 1024, 1); view_148 = None + getitem_146 = split_16[0] + getitem_147 = split_16[1] + getitem_148 = split_16[2] + getitem_149 = split_16[3] + getitem_150 = split_16[4] + getitem_151 = split_16[5] + getitem_152 = split_16[6] + getitem_153 = split_16[7]; split_16 = None + cat_8 = torch.ops.aten.cat.default([getitem_146, getitem_147, getitem_148, getitem_149, getitem_150, getitem_151, getitem_152, getitem_153]); getitem_146 = getitem_147 = getitem_148 = getitem_149 = getitem_150 = getitem_151 = getitem_152 = getitem_153 = None + reduce_scatter_tensor_4 = 
torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_8, 'sum', 8, '1'); cat_8 = None + wait_tensor_27 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_4); reduce_scatter_tensor_4 = None + add_7 = torch.ops.aten.add.Tensor(add_5, wait_tensor_27); add_5 = wait_tensor_27 = None + convert_element_type_67 = torch.ops.prims.convert_element_type.default(primals_22, torch.bfloat16) + all_gather_into_tensor_23 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_67, 8, '0'); convert_element_type_67 = None + wait_tensor_28 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_23); all_gather_into_tensor_23 = None + convert_element_type_68 = torch.ops.prims.convert_element_type.default(add_7, torch.float32) + pow_5 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_68, 2) + mean_4 = torch.ops.aten.mean.dim(pow_5, [2], True); pow_5 = None + add_8 = torch.ops.aten.add.Scalar(mean_4, 1e-05); mean_4 = None + rsqrt_4 = torch.ops.aten.rsqrt.default(add_8); add_8 = None + mul_16 = torch.ops.aten.mul.Tensor(convert_element_type_68, rsqrt_4); convert_element_type_68 = rsqrt_4 = None + mul_17 = torch.ops.aten.mul.Tensor(mul_16, wait_tensor_28); mul_16 = wait_tensor_28 = None + convert_element_type_69 = torch.ops.prims.convert_element_type.default(mul_17, torch.bfloat16); mul_17 = None + all_gather_into_tensor_24 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_69, 8, '1'); convert_element_type_69 = None + wait_tensor_29 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_24); all_gather_into_tensor_24 = None + split_17 = torch.ops.aten.split.Tensor(wait_tensor_29, 2); wait_tensor_29 = None + getitem_154 = split_17[0] + getitem_155 = split_17[1] + getitem_156 = split_17[2] + getitem_157 = split_17[3] + getitem_158 = split_17[4] + getitem_159 = split_17[5] + getitem_160 = split_17[6] + getitem_161 = split_17[7]; split_17 = None + cat_9 = 
torch.ops.aten.cat.default([getitem_154, getitem_155, getitem_156, getitem_157, getitem_158, getitem_159, getitem_160, getitem_161], 1); getitem_154 = getitem_155 = getitem_156 = getitem_157 = getitem_158 = getitem_159 = getitem_160 = getitem_161 = None + convert_element_type_70 = torch.ops.prims.convert_element_type.default(primals_23, torch.bfloat16) + all_gather_into_tensor_25 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_70, 8, '0'); convert_element_type_70 = None + wait_tensor_30 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_25); all_gather_into_tensor_25 = None + permute_22 = torch.ops.aten.permute.default(wait_tensor_30, [1, 0]); wait_tensor_30 = None + view_159 = torch.ops.aten.view.default(cat_9, [16384, 4096]); cat_9 = None + mm_14 = torch.ops.aten.mm.default(view_159, permute_22); permute_22 = None + view_160 = torch.ops.aten.view.default(mm_14, [2, 8192, 512]) + convert_element_type_73 = torch.ops.prims.convert_element_type.default(primals_24, torch.bfloat16) + all_gather_into_tensor_26 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_73, 8, '0'); convert_element_type_73 = None + wait_tensor_31 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_26); all_gather_into_tensor_26 = None + permute_23 = torch.ops.aten.permute.default(wait_tensor_31, [1, 0]); wait_tensor_31 = None + mm_15 = torch.ops.aten.mm.default(view_159, permute_23); permute_23 = None + view_167 = torch.ops.aten.view.default(mm_15, [2, 8192, 128]); mm_15 = None + convert_element_type_76 = torch.ops.prims.convert_element_type.default(primals_25, torch.bfloat16) + all_gather_into_tensor_27 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_76, 8, '0'); convert_element_type_76 = None + wait_tensor_32 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_27); all_gather_into_tensor_27 = None + permute_24 = 
torch.ops.aten.permute.default(wait_tensor_32, [1, 0]); wait_tensor_32 = None + mm_16 = torch.ops.aten.mm.default(view_159, permute_24); view_159 = permute_24 = None + view_174 = torch.ops.aten.view.default(mm_16, [2, 8192, 128]) + view_176 = torch.ops.aten.view.default(view_160, [2, 8192, -1, 128]); view_160 = None + view_177 = torch.ops.aten.view.default(view_167, [2, 8192, -1, 128]); view_167 = None + view_178 = torch.ops.aten.view.default(view_174, [2, 8192, -1, 128]); view_174 = None + convert_element_type_79 = torch.ops.prims.convert_element_type.default(view_176, torch.float32); view_176 = None + view_179 = torch.ops.aten.view.default(convert_element_type_79, [2, 8192, 4, -1, 2]); convert_element_type_79 = None + view_as_complex_4 = torch.ops.aten.view_as_complex.default(view_179); view_179 = None + convert_element_type_80 = torch.ops.prims.convert_element_type.default(view_177, torch.float32); view_177 = None + view_180 = torch.ops.aten.view.default(convert_element_type_80, [2, 8192, 1, -1, 2]); convert_element_type_80 = None + view_as_complex_5 = torch.ops.aten.view_as_complex.default(view_180); view_180 = None + mul_18 = torch.ops.aten.mul.Tensor(view_as_complex_4, view_37); view_as_complex_4 = None + view_as_real_4 = torch.ops.aten.view_as_real.default(mul_18); mul_18 = None + view_182 = torch.ops.aten.view.default(view_as_real_4, [2, 8192, 4, 128]); view_as_real_4 = None + mul_19 = torch.ops.aten.mul.Tensor(view_as_complex_5, view_37); view_as_complex_5 = None + view_as_real_5 = torch.ops.aten.view_as_real.default(mul_19); mul_19 = None + view_183 = torch.ops.aten.view.default(view_as_real_5, [2, 8192, 1, 128]); view_as_real_5 = None + convert_element_type_81 = torch.ops.prims.convert_element_type.default(view_182, torch.bfloat16); view_182 = None + convert_element_type_82 = torch.ops.prims.convert_element_type.default(view_183, torch.bfloat16); view_183 = None + unsqueeze_4 = torch.ops.aten.unsqueeze.default(convert_element_type_82, 3); 
convert_element_type_82 = None + expand_4 = torch.ops.aten.expand.default(unsqueeze_4, [2, 8192, 1, 4, 128]); unsqueeze_4 = None + view_184 = torch.ops.aten.view.default(expand_4, [2, 8192, 4, 128]); expand_4 = None + unsqueeze_5 = torch.ops.aten.unsqueeze.default(view_178, 3); view_178 = None + expand_5 = torch.ops.aten.expand.default(unsqueeze_5, [2, 8192, 1, 4, 128]); unsqueeze_5 = None + view_185 = torch.ops.aten.view.default(expand_5, [2, 8192, 4, 128]); expand_5 = None + permute_25 = torch.ops.aten.permute.default(convert_element_type_81, [0, 2, 1, 3]); convert_element_type_81 = None + permute_26 = torch.ops.aten.permute.default(view_184, [0, 2, 1, 3]); view_184 = None + permute_27 = torch.ops.aten.permute.default(view_185, [0, 2, 1, 3]); view_185 = None + _scaled_dot_product_cudnn_attention_2 = torch.ops.aten._scaled_dot_product_cudnn_attention.default(permute_25, permute_26, permute_27, None, True, 0.0, True); permute_25 = permute_26 = permute_27 = None + getitem_162 = _scaled_dot_product_cudnn_attention_2[0] + getitem_163 = _scaled_dot_product_cudnn_attention_2[1] + getitem_168 = _scaled_dot_product_cudnn_attention_2[6] + getitem_169 = _scaled_dot_product_cudnn_attention_2[7]; _scaled_dot_product_cudnn_attention_2 = None + permute_28 = torch.ops.aten.permute.default(getitem_162, [0, 2, 1, 3]) + view_186 = torch.ops.aten.view.default(permute_28, [2, 8192, -1]); permute_28 = None + convert_element_type_83 = torch.ops.prims.convert_element_type.default(primals_26, torch.bfloat16) + all_gather_into_tensor_28 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_83, 8, '0'); convert_element_type_83 = None + wait_tensor_33 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_28); all_gather_into_tensor_28 = None + permute_29 = torch.ops.aten.permute.default(wait_tensor_33, [1, 0]); wait_tensor_33 = None + view_192 = torch.ops.aten.view.default(view_186, [16384, 512]); view_186 = None + mm_17 = 
torch.ops.aten.mm.default(view_192, permute_29); view_192 = permute_29 = None + view_193 = torch.ops.aten.view.default(mm_17, [2, 8192, 4096]); mm_17 = None + split_18 = torch.ops.aten.split.Tensor(view_193, 1024, 1); view_193 = None + getitem_171 = split_18[0] + getitem_172 = split_18[1] + getitem_173 = split_18[2] + getitem_174 = split_18[3] + getitem_175 = split_18[4] + getitem_176 = split_18[5] + getitem_177 = split_18[6] + getitem_178 = split_18[7]; split_18 = None + cat_10 = torch.ops.aten.cat.default([getitem_171, getitem_172, getitem_173, getitem_174, getitem_175, getitem_176, getitem_177, getitem_178]); getitem_171 = getitem_172 = getitem_173 = getitem_174 = getitem_175 = getitem_176 = getitem_177 = getitem_178 = None + reduce_scatter_tensor_5 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_10, 'sum', 8, '1'); cat_10 = None + wait_tensor_34 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_5) + add_9 = torch.ops.aten.add.Tensor(add_7, wait_tensor_34); wait_tensor_34 = None + convert_element_type_86 = torch.ops.prims.convert_element_type.default(primals_27, torch.bfloat16) + all_gather_into_tensor_29 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_86, 8, '0'); convert_element_type_86 = None + wait_tensor_35 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_29); all_gather_into_tensor_29 = None + convert_element_type_87 = torch.ops.prims.convert_element_type.default(add_9, torch.float32) + pow_6 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_87, 2) + mean_5 = torch.ops.aten.mean.dim(pow_6, [2], True); pow_6 = None + add_10 = torch.ops.aten.add.Scalar(mean_5, 1e-05); mean_5 = None + rsqrt_5 = torch.ops.aten.rsqrt.default(add_10); add_10 = None + mul_20 = torch.ops.aten.mul.Tensor(convert_element_type_87, rsqrt_5); convert_element_type_87 = rsqrt_5 = None + mul_21 = torch.ops.aten.mul.Tensor(mul_20, wait_tensor_35); mul_20 = wait_tensor_35 = None + 
convert_element_type_88 = torch.ops.prims.convert_element_type.default(mul_21, torch.bfloat16); mul_21 = None + all_gather_into_tensor_30 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_88, 8, '1'); convert_element_type_88 = None + wait_tensor_36 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_30); all_gather_into_tensor_30 = None + split_19 = torch.ops.aten.split.Tensor(wait_tensor_36, 2); wait_tensor_36 = None + getitem_179 = split_19[0] + getitem_180 = split_19[1] + getitem_181 = split_19[2] + getitem_182 = split_19[3] + getitem_183 = split_19[4] + getitem_184 = split_19[5] + getitem_185 = split_19[6] + getitem_186 = split_19[7]; split_19 = None + cat_11 = torch.ops.aten.cat.default([getitem_179, getitem_180, getitem_181, getitem_182, getitem_183, getitem_184, getitem_185, getitem_186], 1); getitem_179 = getitem_180 = getitem_181 = getitem_182 = getitem_183 = getitem_184 = getitem_185 = getitem_186 = None + convert_element_type_89 = torch.ops.prims.convert_element_type.default(primals_28, torch.bfloat16) + all_gather_into_tensor_31 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_89, 8, '0'); convert_element_type_89 = None + wait_tensor_37 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_31); all_gather_into_tensor_31 = None + permute_30 = torch.ops.aten.permute.default(wait_tensor_37, [1, 0]); wait_tensor_37 = None + view_204 = torch.ops.aten.view.default(cat_11, [16384, 4096]); cat_11 = None + mm_18 = torch.ops.aten.mm.default(view_204, permute_30); permute_30 = None + view_205 = torch.ops.aten.view.default(mm_18, [2, 8192, 1792]) + convert_element_type_92 = torch.ops.prims.convert_element_type.default(view_205, torch.float32); view_205 = None + sigmoid_2 = torch.ops.aten.sigmoid.default(convert_element_type_92) + mul_22 = torch.ops.aten.mul.Tensor(convert_element_type_92, sigmoid_2); convert_element_type_92 = sigmoid_2 = None + 
convert_element_type_93 = torch.ops.prims.convert_element_type.default(mul_22, torch.bfloat16); mul_22 = None + convert_element_type_94 = torch.ops.prims.convert_element_type.default(primals_29, torch.bfloat16) + all_gather_into_tensor_32 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_94, 8, '0'); convert_element_type_94 = None + wait_tensor_38 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_32); all_gather_into_tensor_32 = None + permute_31 = torch.ops.aten.permute.default(wait_tensor_38, [1, 0]); wait_tensor_38 = None + mm_19 = torch.ops.aten.mm.default(view_204, permute_31); view_204 = permute_31 = None + view_212 = torch.ops.aten.view.default(mm_19, [2, 8192, 1792]); mm_19 = None + mul_23 = torch.ops.aten.mul.Tensor(convert_element_type_93, view_212); convert_element_type_93 = view_212 = None + convert_element_type_97 = torch.ops.prims.convert_element_type.default(primals_30, torch.bfloat16) + all_gather_into_tensor_33 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_97, 8, '0'); convert_element_type_97 = None + wait_tensor_39 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_33); all_gather_into_tensor_33 = None + permute_32 = torch.ops.aten.permute.default(wait_tensor_39, [1, 0]); wait_tensor_39 = None + view_219 = torch.ops.aten.view.default(mul_23, [16384, 1792]); mul_23 = None + mm_20 = torch.ops.aten.mm.default(view_219, permute_32); view_219 = permute_32 = None + view_220 = torch.ops.aten.view.default(mm_20, [2, 8192, 4096]); mm_20 = None + split_20 = torch.ops.aten.split.Tensor(view_220, 1024, 1); view_220 = None + getitem_187 = split_20[0] + getitem_188 = split_20[1] + getitem_189 = split_20[2] + getitem_190 = split_20[3] + getitem_191 = split_20[4] + getitem_192 = split_20[5] + getitem_193 = split_20[6] + getitem_194 = split_20[7]; split_20 = None + cat_12 = torch.ops.aten.cat.default([getitem_187, getitem_188, getitem_189, getitem_190, 
getitem_191, getitem_192, getitem_193, getitem_194]); getitem_187 = getitem_188 = getitem_189 = getitem_190 = getitem_191 = getitem_192 = getitem_193 = getitem_194 = None + reduce_scatter_tensor_6 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_12, 'sum', 8, '1'); cat_12 = None + wait_tensor_40 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_6); reduce_scatter_tensor_6 = None + add_11 = torch.ops.aten.add.Tensor(add_9, wait_tensor_40); add_9 = wait_tensor_40 = None + convert_element_type_100 = torch.ops.prims.convert_element_type.default(primals_31, torch.bfloat16) + all_gather_into_tensor_34 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_100, 8, '0'); convert_element_type_100 = None + wait_tensor_41 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_34); all_gather_into_tensor_34 = None + convert_element_type_101 = torch.ops.prims.convert_element_type.default(add_11, torch.float32) + pow_7 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_101, 2) + mean_6 = torch.ops.aten.mean.dim(pow_7, [2], True); pow_7 = None + add_12 = torch.ops.aten.add.Scalar(mean_6, 1e-05); mean_6 = None + rsqrt_6 = torch.ops.aten.rsqrt.default(add_12); add_12 = None + mul_24 = torch.ops.aten.mul.Tensor(convert_element_type_101, rsqrt_6); convert_element_type_101 = rsqrt_6 = None + mul_25 = torch.ops.aten.mul.Tensor(mul_24, wait_tensor_41); mul_24 = wait_tensor_41 = None + convert_element_type_102 = torch.ops.prims.convert_element_type.default(mul_25, torch.bfloat16); mul_25 = None + all_gather_into_tensor_35 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_102, 8, '1'); convert_element_type_102 = None + wait_tensor_42 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_35); all_gather_into_tensor_35 = None + split_21 = torch.ops.aten.split.Tensor(wait_tensor_42, 2); wait_tensor_42 = None + getitem_195 = split_21[0] + getitem_196 = 
split_21[1] + getitem_197 = split_21[2] + getitem_198 = split_21[3] + getitem_199 = split_21[4] + getitem_200 = split_21[5] + getitem_201 = split_21[6] + getitem_202 = split_21[7]; split_21 = None + cat_13 = torch.ops.aten.cat.default([getitem_195, getitem_196, getitem_197, getitem_198, getitem_199, getitem_200, getitem_201, getitem_202], 1); getitem_195 = getitem_196 = getitem_197 = getitem_198 = getitem_199 = getitem_200 = getitem_201 = getitem_202 = None + convert_element_type_103 = torch.ops.prims.convert_element_type.default(primals_32, torch.bfloat16) + all_gather_into_tensor_36 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_103, 8, '0'); convert_element_type_103 = None + wait_tensor_43 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_36); all_gather_into_tensor_36 = None + permute_33 = torch.ops.aten.permute.default(wait_tensor_43, [1, 0]); wait_tensor_43 = None + view_231 = torch.ops.aten.view.default(cat_13, [16384, 4096]); cat_13 = None + mm_21 = torch.ops.aten.mm.default(view_231, permute_33); permute_33 = None + view_232 = torch.ops.aten.view.default(mm_21, [2, 8192, 512]) + convert_element_type_106 = torch.ops.prims.convert_element_type.default(primals_33, torch.bfloat16) + all_gather_into_tensor_37 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_106, 8, '0'); convert_element_type_106 = None + wait_tensor_44 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_37); all_gather_into_tensor_37 = None + permute_34 = torch.ops.aten.permute.default(wait_tensor_44, [1, 0]); wait_tensor_44 = None + mm_22 = torch.ops.aten.mm.default(view_231, permute_34); permute_34 = None + view_239 = torch.ops.aten.view.default(mm_22, [2, 8192, 128]); mm_22 = None + convert_element_type_109 = torch.ops.prims.convert_element_type.default(primals_34, torch.bfloat16) + all_gather_into_tensor_38 = 
torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_109, 8, '0'); convert_element_type_109 = None + wait_tensor_45 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_38); all_gather_into_tensor_38 = None + permute_35 = torch.ops.aten.permute.default(wait_tensor_45, [1, 0]); wait_tensor_45 = None + mm_23 = torch.ops.aten.mm.default(view_231, permute_35); view_231 = permute_35 = None + view_246 = torch.ops.aten.view.default(mm_23, [2, 8192, 128]) + view_248 = torch.ops.aten.view.default(view_232, [2, 8192, -1, 128]); view_232 = None + view_249 = torch.ops.aten.view.default(view_239, [2, 8192, -1, 128]); view_239 = None + view_250 = torch.ops.aten.view.default(view_246, [2, 8192, -1, 128]); view_246 = None + convert_element_type_112 = torch.ops.prims.convert_element_type.default(view_248, torch.float32); view_248 = None + view_251 = torch.ops.aten.view.default(convert_element_type_112, [2, 8192, 4, -1, 2]); convert_element_type_112 = None + view_as_complex_6 = torch.ops.aten.view_as_complex.default(view_251); view_251 = None + convert_element_type_113 = torch.ops.prims.convert_element_type.default(view_249, torch.float32); view_249 = None + view_252 = torch.ops.aten.view.default(convert_element_type_113, [2, 8192, 1, -1, 2]); convert_element_type_113 = None + view_as_complex_7 = torch.ops.aten.view_as_complex.default(view_252); view_252 = None + mul_26 = torch.ops.aten.mul.Tensor(view_as_complex_6, view_37); view_as_complex_6 = None + view_as_real_6 = torch.ops.aten.view_as_real.default(mul_26); mul_26 = None + view_254 = torch.ops.aten.view.default(view_as_real_6, [2, 8192, 4, 128]); view_as_real_6 = None + mul_27 = torch.ops.aten.mul.Tensor(view_as_complex_7, view_37); view_as_complex_7 = None + view_as_real_7 = torch.ops.aten.view_as_real.default(mul_27); mul_27 = None + view_255 = torch.ops.aten.view.default(view_as_real_7, [2, 8192, 1, 128]); view_as_real_7 = None + convert_element_type_114 = 
torch.ops.prims.convert_element_type.default(view_254, torch.bfloat16); view_254 = None + convert_element_type_115 = torch.ops.prims.convert_element_type.default(view_255, torch.bfloat16); view_255 = None + unsqueeze_6 = torch.ops.aten.unsqueeze.default(convert_element_type_115, 3); convert_element_type_115 = None + expand_6 = torch.ops.aten.expand.default(unsqueeze_6, [2, 8192, 1, 4, 128]); unsqueeze_6 = None + view_256 = torch.ops.aten.view.default(expand_6, [2, 8192, 4, 128]); expand_6 = None + unsqueeze_7 = torch.ops.aten.unsqueeze.default(view_250, 3); view_250 = None + expand_7 = torch.ops.aten.expand.default(unsqueeze_7, [2, 8192, 1, 4, 128]); unsqueeze_7 = None + view_257 = torch.ops.aten.view.default(expand_7, [2, 8192, 4, 128]); expand_7 = None + permute_36 = torch.ops.aten.permute.default(convert_element_type_114, [0, 2, 1, 3]); convert_element_type_114 = None + permute_37 = torch.ops.aten.permute.default(view_256, [0, 2, 1, 3]); view_256 = None + permute_38 = torch.ops.aten.permute.default(view_257, [0, 2, 1, 3]); view_257 = None + _scaled_dot_product_cudnn_attention_3 = torch.ops.aten._scaled_dot_product_cudnn_attention.default(permute_36, permute_37, permute_38, None, True, 0.0, True); permute_36 = permute_37 = permute_38 = None + getitem_203 = _scaled_dot_product_cudnn_attention_3[0] + getitem_204 = _scaled_dot_product_cudnn_attention_3[1] + getitem_209 = _scaled_dot_product_cudnn_attention_3[6] + getitem_210 = _scaled_dot_product_cudnn_attention_3[7]; _scaled_dot_product_cudnn_attention_3 = None + permute_39 = torch.ops.aten.permute.default(getitem_203, [0, 2, 1, 3]) + view_258 = torch.ops.aten.view.default(permute_39, [2, 8192, -1]); permute_39 = None + convert_element_type_116 = torch.ops.prims.convert_element_type.default(primals_35, torch.bfloat16) + all_gather_into_tensor_39 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_116, 8, '0'); convert_element_type_116 = None + wait_tensor_46 = 
torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_39); all_gather_into_tensor_39 = None + permute_40 = torch.ops.aten.permute.default(wait_tensor_46, [1, 0]); wait_tensor_46 = None + view_264 = torch.ops.aten.view.default(view_258, [16384, 512]); view_258 = None + mm_24 = torch.ops.aten.mm.default(view_264, permute_40); view_264 = permute_40 = None + view_265 = torch.ops.aten.view.default(mm_24, [2, 8192, 4096]); mm_24 = None + split_22 = torch.ops.aten.split.Tensor(view_265, 1024, 1); view_265 = None + getitem_212 = split_22[0] + getitem_213 = split_22[1] + getitem_214 = split_22[2] + getitem_215 = split_22[3] + getitem_216 = split_22[4] + getitem_217 = split_22[5] + getitem_218 = split_22[6] + getitem_219 = split_22[7]; split_22 = None + cat_14 = torch.ops.aten.cat.default([getitem_212, getitem_213, getitem_214, getitem_215, getitem_216, getitem_217, getitem_218, getitem_219]); getitem_212 = getitem_213 = getitem_214 = getitem_215 = getitem_216 = getitem_217 = getitem_218 = getitem_219 = None + reduce_scatter_tensor_7 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_14, 'sum', 8, '1'); cat_14 = None + wait_tensor_47 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_7) + add_13 = torch.ops.aten.add.Tensor(add_11, wait_tensor_47); wait_tensor_47 = None + convert_element_type_119 = torch.ops.prims.convert_element_type.default(primals_36, torch.bfloat16) + all_gather_into_tensor_40 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_119, 8, '0'); convert_element_type_119 = None + wait_tensor_48 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_40); all_gather_into_tensor_40 = None + convert_element_type_120 = torch.ops.prims.convert_element_type.default(add_13, torch.float32) + pow_8 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_120, 2) + mean_7 = torch.ops.aten.mean.dim(pow_8, [2], True); pow_8 = None + add_14 = torch.ops.aten.add.Scalar(mean_7, 
1e-05); mean_7 = None + rsqrt_7 = torch.ops.aten.rsqrt.default(add_14); add_14 = None + mul_28 = torch.ops.aten.mul.Tensor(convert_element_type_120, rsqrt_7); convert_element_type_120 = rsqrt_7 = None + mul_29 = torch.ops.aten.mul.Tensor(mul_28, wait_tensor_48); mul_28 = wait_tensor_48 = None + convert_element_type_121 = torch.ops.prims.convert_element_type.default(mul_29, torch.bfloat16); mul_29 = None + all_gather_into_tensor_41 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_121, 8, '1'); convert_element_type_121 = None + wait_tensor_49 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_41); all_gather_into_tensor_41 = None + split_23 = torch.ops.aten.split.Tensor(wait_tensor_49, 2); wait_tensor_49 = None + getitem_220 = split_23[0] + getitem_221 = split_23[1] + getitem_222 = split_23[2] + getitem_223 = split_23[3] + getitem_224 = split_23[4] + getitem_225 = split_23[5] + getitem_226 = split_23[6] + getitem_227 = split_23[7]; split_23 = None + cat_15 = torch.ops.aten.cat.default([getitem_220, getitem_221, getitem_222, getitem_223, getitem_224, getitem_225, getitem_226, getitem_227], 1); getitem_220 = getitem_221 = getitem_222 = getitem_223 = getitem_224 = getitem_225 = getitem_226 = getitem_227 = None + convert_element_type_122 = torch.ops.prims.convert_element_type.default(primals_37, torch.bfloat16) + all_gather_into_tensor_42 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_122, 8, '0'); convert_element_type_122 = None + wait_tensor_50 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_42); all_gather_into_tensor_42 = None + permute_41 = torch.ops.aten.permute.default(wait_tensor_50, [1, 0]); wait_tensor_50 = None + view_276 = torch.ops.aten.view.default(cat_15, [16384, 4096]); cat_15 = None + mm_25 = torch.ops.aten.mm.default(view_276, permute_41); permute_41 = None + view_277 = torch.ops.aten.view.default(mm_25, [2, 8192, 1792]) + 
convert_element_type_125 = torch.ops.prims.convert_element_type.default(view_277, torch.float32); view_277 = None + sigmoid_3 = torch.ops.aten.sigmoid.default(convert_element_type_125) + mul_30 = torch.ops.aten.mul.Tensor(convert_element_type_125, sigmoid_3); convert_element_type_125 = sigmoid_3 = None + convert_element_type_126 = torch.ops.prims.convert_element_type.default(mul_30, torch.bfloat16); mul_30 = None + convert_element_type_127 = torch.ops.prims.convert_element_type.default(primals_38, torch.bfloat16) + all_gather_into_tensor_43 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_127, 8, '0'); convert_element_type_127 = None + wait_tensor_51 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_43); all_gather_into_tensor_43 = None + permute_42 = torch.ops.aten.permute.default(wait_tensor_51, [1, 0]); wait_tensor_51 = None + mm_26 = torch.ops.aten.mm.default(view_276, permute_42); view_276 = permute_42 = None + view_284 = torch.ops.aten.view.default(mm_26, [2, 8192, 1792]); mm_26 = None + mul_31 = torch.ops.aten.mul.Tensor(convert_element_type_126, view_284); convert_element_type_126 = view_284 = None + convert_element_type_130 = torch.ops.prims.convert_element_type.default(primals_39, torch.bfloat16) + all_gather_into_tensor_44 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_130, 8, '0'); convert_element_type_130 = None + wait_tensor_52 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_44); all_gather_into_tensor_44 = None + permute_43 = torch.ops.aten.permute.default(wait_tensor_52, [1, 0]); wait_tensor_52 = None + view_291 = torch.ops.aten.view.default(mul_31, [16384, 1792]); mul_31 = None + mm_27 = torch.ops.aten.mm.default(view_291, permute_43); view_291 = permute_43 = None + view_292 = torch.ops.aten.view.default(mm_27, [2, 8192, 4096]); mm_27 = None + split_24 = torch.ops.aten.split.Tensor(view_292, 1024, 1); view_292 = None + getitem_228 = 
split_24[0] + getitem_229 = split_24[1] + getitem_230 = split_24[2] + getitem_231 = split_24[3] + getitem_232 = split_24[4] + getitem_233 = split_24[5] + getitem_234 = split_24[6] + getitem_235 = split_24[7]; split_24 = None + cat_16 = torch.ops.aten.cat.default([getitem_228, getitem_229, getitem_230, getitem_231, getitem_232, getitem_233, getitem_234, getitem_235]); getitem_228 = getitem_229 = getitem_230 = getitem_231 = getitem_232 = getitem_233 = getitem_234 = getitem_235 = None + reduce_scatter_tensor_8 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_16, 'sum', 8, '1'); cat_16 = None + wait_tensor_53 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_8); reduce_scatter_tensor_8 = None + add_15 = torch.ops.aten.add.Tensor(add_13, wait_tensor_53); add_13 = wait_tensor_53 = None + convert_element_type_133 = torch.ops.prims.convert_element_type.default(primals_40, torch.bfloat16) + all_gather_into_tensor_45 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_133, 8, '0'); convert_element_type_133 = None + wait_tensor_54 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_45); all_gather_into_tensor_45 = None + convert_element_type_134 = torch.ops.prims.convert_element_type.default(add_15, torch.float32) + pow_9 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_134, 2) + mean_8 = torch.ops.aten.mean.dim(pow_9, [2], True); pow_9 = None + add_16 = torch.ops.aten.add.Scalar(mean_8, 1e-05); mean_8 = None + rsqrt_8 = torch.ops.aten.rsqrt.default(add_16); add_16 = None + mul_32 = torch.ops.aten.mul.Tensor(convert_element_type_134, rsqrt_8); convert_element_type_134 = rsqrt_8 = None + mul_33 = torch.ops.aten.mul.Tensor(mul_32, wait_tensor_54); mul_32 = wait_tensor_54 = None + convert_element_type_135 = torch.ops.prims.convert_element_type.default(mul_33, torch.bfloat16); mul_33 = None + all_gather_into_tensor_46 = 
torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_135, 8, '1'); convert_element_type_135 = None + wait_tensor_55 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_46); all_gather_into_tensor_46 = None + split_25 = torch.ops.aten.split.Tensor(wait_tensor_55, 2); wait_tensor_55 = None + getitem_236 = split_25[0] + getitem_237 = split_25[1] + getitem_238 = split_25[2] + getitem_239 = split_25[3] + getitem_240 = split_25[4] + getitem_241 = split_25[5] + getitem_242 = split_25[6] + getitem_243 = split_25[7]; split_25 = None + cat_17 = torch.ops.aten.cat.default([getitem_236, getitem_237, getitem_238, getitem_239, getitem_240, getitem_241, getitem_242, getitem_243], 1); getitem_236 = getitem_237 = getitem_238 = getitem_239 = getitem_240 = getitem_241 = getitem_242 = getitem_243 = None + convert_element_type_136 = torch.ops.prims.convert_element_type.default(primals_41, torch.bfloat16) + all_gather_into_tensor_47 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_136, 8, '0'); convert_element_type_136 = None + wait_tensor_56 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_47); all_gather_into_tensor_47 = None + permute_44 = torch.ops.aten.permute.default(wait_tensor_56, [1, 0]); wait_tensor_56 = None + view_303 = torch.ops.aten.view.default(cat_17, [16384, 4096]); cat_17 = None + mm_28 = torch.ops.aten.mm.default(view_303, permute_44); permute_44 = None + view_304 = torch.ops.aten.view.default(mm_28, [2, 8192, 512]) + convert_element_type_139 = torch.ops.prims.convert_element_type.default(primals_42, torch.bfloat16) + all_gather_into_tensor_48 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_139, 8, '0'); convert_element_type_139 = None + wait_tensor_57 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_48); all_gather_into_tensor_48 = None + permute_45 = torch.ops.aten.permute.default(wait_tensor_57, [1, 0]); 
wait_tensor_57 = None + mm_29 = torch.ops.aten.mm.default(view_303, permute_45); permute_45 = None + view_311 = torch.ops.aten.view.default(mm_29, [2, 8192, 128]); mm_29 = None + convert_element_type_142 = torch.ops.prims.convert_element_type.default(primals_43, torch.bfloat16) + all_gather_into_tensor_49 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_142, 8, '0'); convert_element_type_142 = None + wait_tensor_58 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_49); all_gather_into_tensor_49 = None + permute_46 = torch.ops.aten.permute.default(wait_tensor_58, [1, 0]); wait_tensor_58 = None + mm_30 = torch.ops.aten.mm.default(view_303, permute_46); view_303 = permute_46 = None + view_318 = torch.ops.aten.view.default(mm_30, [2, 8192, 128]) + view_320 = torch.ops.aten.view.default(view_304, [2, 8192, -1, 128]); view_304 = None + view_321 = torch.ops.aten.view.default(view_311, [2, 8192, -1, 128]); view_311 = None + view_322 = torch.ops.aten.view.default(view_318, [2, 8192, -1, 128]); view_318 = None + convert_element_type_145 = torch.ops.prims.convert_element_type.default(view_320, torch.float32); view_320 = None + view_323 = torch.ops.aten.view.default(convert_element_type_145, [2, 8192, 4, -1, 2]); convert_element_type_145 = None + view_as_complex_8 = torch.ops.aten.view_as_complex.default(view_323); view_323 = None + convert_element_type_146 = torch.ops.prims.convert_element_type.default(view_321, torch.float32); view_321 = None + view_324 = torch.ops.aten.view.default(convert_element_type_146, [2, 8192, 1, -1, 2]); convert_element_type_146 = None + view_as_complex_9 = torch.ops.aten.view_as_complex.default(view_324); view_324 = None + mul_34 = torch.ops.aten.mul.Tensor(view_as_complex_8, view_37); view_as_complex_8 = None + view_as_real_8 = torch.ops.aten.view_as_real.default(mul_34); mul_34 = None + view_326 = torch.ops.aten.view.default(view_as_real_8, [2, 8192, 4, 128]); view_as_real_8 = None + mul_35 = 
torch.ops.aten.mul.Tensor(view_as_complex_9, view_37); view_as_complex_9 = None + view_as_real_9 = torch.ops.aten.view_as_real.default(mul_35); mul_35 = None + view_327 = torch.ops.aten.view.default(view_as_real_9, [2, 8192, 1, 128]); view_as_real_9 = None + convert_element_type_147 = torch.ops.prims.convert_element_type.default(view_326, torch.bfloat16); view_326 = None + convert_element_type_148 = torch.ops.prims.convert_element_type.default(view_327, torch.bfloat16); view_327 = None + unsqueeze_8 = torch.ops.aten.unsqueeze.default(convert_element_type_148, 3); convert_element_type_148 = None + expand_8 = torch.ops.aten.expand.default(unsqueeze_8, [2, 8192, 1, 4, 128]); unsqueeze_8 = None + view_328 = torch.ops.aten.view.default(expand_8, [2, 8192, 4, 128]); expand_8 = None + unsqueeze_9 = torch.ops.aten.unsqueeze.default(view_322, 3); view_322 = None + expand_9 = torch.ops.aten.expand.default(unsqueeze_9, [2, 8192, 1, 4, 128]); unsqueeze_9 = None + view_329 = torch.ops.aten.view.default(expand_9, [2, 8192, 4, 128]); expand_9 = None + permute_47 = torch.ops.aten.permute.default(convert_element_type_147, [0, 2, 1, 3]); convert_element_type_147 = None + permute_48 = torch.ops.aten.permute.default(view_328, [0, 2, 1, 3]); view_328 = None + permute_49 = torch.ops.aten.permute.default(view_329, [0, 2, 1, 3]); view_329 = None + _scaled_dot_product_cudnn_attention_4 = torch.ops.aten._scaled_dot_product_cudnn_attention.default(permute_47, permute_48, permute_49, None, True, 0.0, True); permute_47 = permute_48 = permute_49 = None + getitem_244 = _scaled_dot_product_cudnn_attention_4[0] + getitem_245 = _scaled_dot_product_cudnn_attention_4[1] + getitem_250 = _scaled_dot_product_cudnn_attention_4[6] + getitem_251 = _scaled_dot_product_cudnn_attention_4[7]; _scaled_dot_product_cudnn_attention_4 = None + permute_50 = torch.ops.aten.permute.default(getitem_244, [0, 2, 1, 3]) + view_330 = torch.ops.aten.view.default(permute_50, [2, 8192, -1]); permute_50 = None + 
convert_element_type_149 = torch.ops.prims.convert_element_type.default(primals_44, torch.bfloat16) + all_gather_into_tensor_50 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_149, 8, '0'); convert_element_type_149 = None + wait_tensor_59 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_50); all_gather_into_tensor_50 = None + permute_51 = torch.ops.aten.permute.default(wait_tensor_59, [1, 0]); wait_tensor_59 = None + view_336 = torch.ops.aten.view.default(view_330, [16384, 512]); view_330 = None + mm_31 = torch.ops.aten.mm.default(view_336, permute_51); view_336 = permute_51 = None + view_337 = torch.ops.aten.view.default(mm_31, [2, 8192, 4096]); mm_31 = None + split_26 = torch.ops.aten.split.Tensor(view_337, 1024, 1); view_337 = None + getitem_253 = split_26[0] + getitem_254 = split_26[1] + getitem_255 = split_26[2] + getitem_256 = split_26[3] + getitem_257 = split_26[4] + getitem_258 = split_26[5] + getitem_259 = split_26[6] + getitem_260 = split_26[7]; split_26 = None + cat_18 = torch.ops.aten.cat.default([getitem_253, getitem_254, getitem_255, getitem_256, getitem_257, getitem_258, getitem_259, getitem_260]); getitem_253 = getitem_254 = getitem_255 = getitem_256 = getitem_257 = getitem_258 = getitem_259 = getitem_260 = None + reduce_scatter_tensor_9 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_18, 'sum', 8, '1'); cat_18 = None + wait_tensor_60 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_9) + add_17 = torch.ops.aten.add.Tensor(add_15, wait_tensor_60); wait_tensor_60 = None + convert_element_type_152 = torch.ops.prims.convert_element_type.default(primals_45, torch.bfloat16) + all_gather_into_tensor_51 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_152, 8, '0'); convert_element_type_152 = None + wait_tensor_61 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_51); all_gather_into_tensor_51 = None + 
convert_element_type_153 = torch.ops.prims.convert_element_type.default(add_17, torch.float32) + pow_10 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_153, 2) + mean_9 = torch.ops.aten.mean.dim(pow_10, [2], True); pow_10 = None + add_18 = torch.ops.aten.add.Scalar(mean_9, 1e-05); mean_9 = None + rsqrt_9 = torch.ops.aten.rsqrt.default(add_18); add_18 = None + mul_36 = torch.ops.aten.mul.Tensor(convert_element_type_153, rsqrt_9); convert_element_type_153 = rsqrt_9 = None + mul_37 = torch.ops.aten.mul.Tensor(mul_36, wait_tensor_61); mul_36 = wait_tensor_61 = None + convert_element_type_154 = torch.ops.prims.convert_element_type.default(mul_37, torch.bfloat16); mul_37 = None + all_gather_into_tensor_52 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_154, 8, '1'); convert_element_type_154 = None + wait_tensor_62 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_52); all_gather_into_tensor_52 = None + split_27 = torch.ops.aten.split.Tensor(wait_tensor_62, 2); wait_tensor_62 = None + getitem_261 = split_27[0] + getitem_262 = split_27[1] + getitem_263 = split_27[2] + getitem_264 = split_27[3] + getitem_265 = split_27[4] + getitem_266 = split_27[5] + getitem_267 = split_27[6] + getitem_268 = split_27[7]; split_27 = None + cat_19 = torch.ops.aten.cat.default([getitem_261, getitem_262, getitem_263, getitem_264, getitem_265, getitem_266, getitem_267, getitem_268], 1); getitem_261 = getitem_262 = getitem_263 = getitem_264 = getitem_265 = getitem_266 = getitem_267 = getitem_268 = None + convert_element_type_155 = torch.ops.prims.convert_element_type.default(primals_46, torch.bfloat16) + all_gather_into_tensor_53 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_155, 8, '0'); convert_element_type_155 = None + wait_tensor_63 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_53); all_gather_into_tensor_53 = None + permute_52 = 
torch.ops.aten.permute.default(wait_tensor_63, [1, 0]); wait_tensor_63 = None + view_348 = torch.ops.aten.view.default(cat_19, [16384, 4096]); cat_19 = None + mm_32 = torch.ops.aten.mm.default(view_348, permute_52); permute_52 = None + view_349 = torch.ops.aten.view.default(mm_32, [2, 8192, 1792]) + convert_element_type_158 = torch.ops.prims.convert_element_type.default(view_349, torch.float32); view_349 = None + sigmoid_4 = torch.ops.aten.sigmoid.default(convert_element_type_158) + mul_38 = torch.ops.aten.mul.Tensor(convert_element_type_158, sigmoid_4); convert_element_type_158 = sigmoid_4 = None + convert_element_type_159 = torch.ops.prims.convert_element_type.default(mul_38, torch.bfloat16); mul_38 = None + convert_element_type_160 = torch.ops.prims.convert_element_type.default(primals_47, torch.bfloat16) + all_gather_into_tensor_54 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_160, 8, '0'); convert_element_type_160 = None + wait_tensor_64 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_54); all_gather_into_tensor_54 = None + permute_53 = torch.ops.aten.permute.default(wait_tensor_64, [1, 0]); wait_tensor_64 = None + mm_33 = torch.ops.aten.mm.default(view_348, permute_53); view_348 = permute_53 = None + view_356 = torch.ops.aten.view.default(mm_33, [2, 8192, 1792]); mm_33 = None + mul_39 = torch.ops.aten.mul.Tensor(convert_element_type_159, view_356); convert_element_type_159 = view_356 = None + convert_element_type_163 = torch.ops.prims.convert_element_type.default(primals_48, torch.bfloat16) + all_gather_into_tensor_55 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_163, 8, '0'); convert_element_type_163 = None + wait_tensor_65 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_55); all_gather_into_tensor_55 = None + permute_54 = torch.ops.aten.permute.default(wait_tensor_65, [1, 0]); wait_tensor_65 = None + view_363 = 
torch.ops.aten.view.default(mul_39, [16384, 1792]); mul_39 = None + mm_34 = torch.ops.aten.mm.default(view_363, permute_54); view_363 = permute_54 = None + view_364 = torch.ops.aten.view.default(mm_34, [2, 8192, 4096]); mm_34 = None + split_28 = torch.ops.aten.split.Tensor(view_364, 1024, 1); view_364 = None + getitem_269 = split_28[0] + getitem_270 = split_28[1] + getitem_271 = split_28[2] + getitem_272 = split_28[3] + getitem_273 = split_28[4] + getitem_274 = split_28[5] + getitem_275 = split_28[6] + getitem_276 = split_28[7]; split_28 = None + cat_20 = torch.ops.aten.cat.default([getitem_269, getitem_270, getitem_271, getitem_272, getitem_273, getitem_274, getitem_275, getitem_276]); getitem_269 = getitem_270 = getitem_271 = getitem_272 = getitem_273 = getitem_274 = getitem_275 = getitem_276 = None + reduce_scatter_tensor_10 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_20, 'sum', 8, '1'); cat_20 = None + wait_tensor_66 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_10); reduce_scatter_tensor_10 = None + add_19 = torch.ops.aten.add.Tensor(add_17, wait_tensor_66); add_17 = wait_tensor_66 = None + convert_element_type_166 = torch.ops.prims.convert_element_type.default(primals_49, torch.bfloat16) + all_gather_into_tensor_56 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_166, 8, '0'); convert_element_type_166 = None + wait_tensor_67 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_56); all_gather_into_tensor_56 = None + convert_element_type_167 = torch.ops.prims.convert_element_type.default(add_19, torch.float32) + pow_11 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_167, 2) + mean_10 = torch.ops.aten.mean.dim(pow_11, [2], True); pow_11 = None + add_20 = torch.ops.aten.add.Scalar(mean_10, 1e-05); mean_10 = None + rsqrt_10 = torch.ops.aten.rsqrt.default(add_20); add_20 = None + mul_40 = torch.ops.aten.mul.Tensor(convert_element_type_167, rsqrt_10); 
convert_element_type_167 = rsqrt_10 = None + mul_41 = torch.ops.aten.mul.Tensor(mul_40, wait_tensor_67); mul_40 = wait_tensor_67 = None + convert_element_type_168 = torch.ops.prims.convert_element_type.default(mul_41, torch.bfloat16); mul_41 = None + all_gather_into_tensor_57 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_168, 8, '1'); convert_element_type_168 = None + wait_tensor_68 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_57); all_gather_into_tensor_57 = None + split_29 = torch.ops.aten.split.Tensor(wait_tensor_68, 2); wait_tensor_68 = None + getitem_277 = split_29[0] + getitem_278 = split_29[1] + getitem_279 = split_29[2] + getitem_280 = split_29[3] + getitem_281 = split_29[4] + getitem_282 = split_29[5] + getitem_283 = split_29[6] + getitem_284 = split_29[7]; split_29 = None + cat_21 = torch.ops.aten.cat.default([getitem_277, getitem_278, getitem_279, getitem_280, getitem_281, getitem_282, getitem_283, getitem_284], 1); getitem_277 = getitem_278 = getitem_279 = getitem_280 = getitem_281 = getitem_282 = getitem_283 = getitem_284 = None + convert_element_type_169 = torch.ops.prims.convert_element_type.default(primals_50, torch.bfloat16) + all_gather_into_tensor_58 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_169, 8, '0'); convert_element_type_169 = None + wait_tensor_69 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_58); all_gather_into_tensor_58 = None + permute_55 = torch.ops.aten.permute.default(wait_tensor_69, [1, 0]); wait_tensor_69 = None + view_375 = torch.ops.aten.view.default(cat_21, [16384, 4096]); cat_21 = None + mm_35 = torch.ops.aten.mm.default(view_375, permute_55); permute_55 = None + view_376 = torch.ops.aten.view.default(mm_35, [2, 8192, 512]) + convert_element_type_172 = torch.ops.prims.convert_element_type.default(primals_51, torch.bfloat16) + all_gather_into_tensor_59 = 
torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_172, 8, '0'); convert_element_type_172 = None + wait_tensor_70 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_59); all_gather_into_tensor_59 = None + permute_56 = torch.ops.aten.permute.default(wait_tensor_70, [1, 0]); wait_tensor_70 = None + mm_36 = torch.ops.aten.mm.default(view_375, permute_56); permute_56 = None + view_383 = torch.ops.aten.view.default(mm_36, [2, 8192, 128]); mm_36 = None + convert_element_type_175 = torch.ops.prims.convert_element_type.default(primals_52, torch.bfloat16) + all_gather_into_tensor_60 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_175, 8, '0'); convert_element_type_175 = None + wait_tensor_71 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_60); all_gather_into_tensor_60 = None + permute_57 = torch.ops.aten.permute.default(wait_tensor_71, [1, 0]); wait_tensor_71 = None + mm_37 = torch.ops.aten.mm.default(view_375, permute_57); view_375 = permute_57 = None + view_390 = torch.ops.aten.view.default(mm_37, [2, 8192, 128]) + view_392 = torch.ops.aten.view.default(view_376, [2, 8192, -1, 128]); view_376 = None + view_393 = torch.ops.aten.view.default(view_383, [2, 8192, -1, 128]); view_383 = None + view_394 = torch.ops.aten.view.default(view_390, [2, 8192, -1, 128]); view_390 = None + convert_element_type_178 = torch.ops.prims.convert_element_type.default(view_392, torch.float32); view_392 = None + view_395 = torch.ops.aten.view.default(convert_element_type_178, [2, 8192, 4, -1, 2]); convert_element_type_178 = None + view_as_complex_10 = torch.ops.aten.view_as_complex.default(view_395); view_395 = None + convert_element_type_179 = torch.ops.prims.convert_element_type.default(view_393, torch.float32); view_393 = None + view_396 = torch.ops.aten.view.default(convert_element_type_179, [2, 8192, 1, -1, 2]); convert_element_type_179 = None + view_as_complex_11 = 
torch.ops.aten.view_as_complex.default(view_396); view_396 = None + mul_42 = torch.ops.aten.mul.Tensor(view_as_complex_10, view_37); view_as_complex_10 = None + view_as_real_10 = torch.ops.aten.view_as_real.default(mul_42); mul_42 = None + view_398 = torch.ops.aten.view.default(view_as_real_10, [2, 8192, 4, 128]); view_as_real_10 = None + mul_43 = torch.ops.aten.mul.Tensor(view_as_complex_11, view_37); view_as_complex_11 = None + view_as_real_11 = torch.ops.aten.view_as_real.default(mul_43); mul_43 = None + view_399 = torch.ops.aten.view.default(view_as_real_11, [2, 8192, 1, 128]); view_as_real_11 = None + convert_element_type_180 = torch.ops.prims.convert_element_type.default(view_398, torch.bfloat16); view_398 = None + convert_element_type_181 = torch.ops.prims.convert_element_type.default(view_399, torch.bfloat16); view_399 = None + unsqueeze_10 = torch.ops.aten.unsqueeze.default(convert_element_type_181, 3); convert_element_type_181 = None + expand_10 = torch.ops.aten.expand.default(unsqueeze_10, [2, 8192, 1, 4, 128]); unsqueeze_10 = None + view_400 = torch.ops.aten.view.default(expand_10, [2, 8192, 4, 128]); expand_10 = None + unsqueeze_11 = torch.ops.aten.unsqueeze.default(view_394, 3); view_394 = None + expand_11 = torch.ops.aten.expand.default(unsqueeze_11, [2, 8192, 1, 4, 128]); unsqueeze_11 = None + view_401 = torch.ops.aten.view.default(expand_11, [2, 8192, 4, 128]); expand_11 = None + permute_58 = torch.ops.aten.permute.default(convert_element_type_180, [0, 2, 1, 3]); convert_element_type_180 = None + permute_59 = torch.ops.aten.permute.default(view_400, [0, 2, 1, 3]); view_400 = None + permute_60 = torch.ops.aten.permute.default(view_401, [0, 2, 1, 3]); view_401 = None + _scaled_dot_product_cudnn_attention_5 = torch.ops.aten._scaled_dot_product_cudnn_attention.default(permute_58, permute_59, permute_60, None, True, 0.0, True); permute_58 = permute_59 = permute_60 = None + getitem_285 = _scaled_dot_product_cudnn_attention_5[0] + getitem_286 = 
_scaled_dot_product_cudnn_attention_5[1] + getitem_291 = _scaled_dot_product_cudnn_attention_5[6] + getitem_292 = _scaled_dot_product_cudnn_attention_5[7]; _scaled_dot_product_cudnn_attention_5 = None + permute_61 = torch.ops.aten.permute.default(getitem_285, [0, 2, 1, 3]) + view_402 = torch.ops.aten.view.default(permute_61, [2, 8192, -1]); permute_61 = None + convert_element_type_182 = torch.ops.prims.convert_element_type.default(primals_53, torch.bfloat16) + all_gather_into_tensor_61 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_182, 8, '0'); convert_element_type_182 = None + wait_tensor_72 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_61); all_gather_into_tensor_61 = None + permute_62 = torch.ops.aten.permute.default(wait_tensor_72, [1, 0]); wait_tensor_72 = None + view_408 = torch.ops.aten.view.default(view_402, [16384, 512]); view_402 = None + mm_38 = torch.ops.aten.mm.default(view_408, permute_62); view_408 = permute_62 = None + view_409 = torch.ops.aten.view.default(mm_38, [2, 8192, 4096]); mm_38 = None + split_30 = torch.ops.aten.split.Tensor(view_409, 1024, 1); view_409 = None + getitem_294 = split_30[0] + getitem_295 = split_30[1] + getitem_296 = split_30[2] + getitem_297 = split_30[3] + getitem_298 = split_30[4] + getitem_299 = split_30[5] + getitem_300 = split_30[6] + getitem_301 = split_30[7]; split_30 = None + cat_22 = torch.ops.aten.cat.default([getitem_294, getitem_295, getitem_296, getitem_297, getitem_298, getitem_299, getitem_300, getitem_301]); getitem_294 = getitem_295 = getitem_296 = getitem_297 = getitem_298 = getitem_299 = getitem_300 = getitem_301 = None + reduce_scatter_tensor_11 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_22, 'sum', 8, '1'); cat_22 = None + wait_tensor_73 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_11) + add_21 = torch.ops.aten.add.Tensor(add_19, wait_tensor_73); wait_tensor_73 = None + convert_element_type_185 = 
torch.ops.prims.convert_element_type.default(primals_54, torch.bfloat16) + all_gather_into_tensor_62 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_185, 8, '0'); convert_element_type_185 = None + wait_tensor_74 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_62); all_gather_into_tensor_62 = None + convert_element_type_186 = torch.ops.prims.convert_element_type.default(add_21, torch.float32) + pow_12 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_186, 2) + mean_11 = torch.ops.aten.mean.dim(pow_12, [2], True); pow_12 = None + add_22 = torch.ops.aten.add.Scalar(mean_11, 1e-05); mean_11 = None + rsqrt_11 = torch.ops.aten.rsqrt.default(add_22); add_22 = None + mul_44 = torch.ops.aten.mul.Tensor(convert_element_type_186, rsqrt_11); convert_element_type_186 = rsqrt_11 = None + mul_45 = torch.ops.aten.mul.Tensor(mul_44, wait_tensor_74); mul_44 = wait_tensor_74 = None + convert_element_type_187 = torch.ops.prims.convert_element_type.default(mul_45, torch.bfloat16); mul_45 = None + all_gather_into_tensor_63 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_187, 8, '1'); convert_element_type_187 = None + wait_tensor_75 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_63); all_gather_into_tensor_63 = None + split_31 = torch.ops.aten.split.Tensor(wait_tensor_75, 2); wait_tensor_75 = None + getitem_302 = split_31[0] + getitem_303 = split_31[1] + getitem_304 = split_31[2] + getitem_305 = split_31[3] + getitem_306 = split_31[4] + getitem_307 = split_31[5] + getitem_308 = split_31[6] + getitem_309 = split_31[7]; split_31 = None + cat_23 = torch.ops.aten.cat.default([getitem_302, getitem_303, getitem_304, getitem_305, getitem_306, getitem_307, getitem_308, getitem_309], 1); getitem_302 = getitem_303 = getitem_304 = getitem_305 = getitem_306 = getitem_307 = getitem_308 = getitem_309 = None + convert_element_type_188 = 
torch.ops.prims.convert_element_type.default(primals_55, torch.bfloat16) + all_gather_into_tensor_64 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_188, 8, '0'); convert_element_type_188 = None + wait_tensor_76 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_64); all_gather_into_tensor_64 = None + permute_63 = torch.ops.aten.permute.default(wait_tensor_76, [1, 0]); wait_tensor_76 = None + view_420 = torch.ops.aten.view.default(cat_23, [16384, 4096]); cat_23 = None + mm_39 = torch.ops.aten.mm.default(view_420, permute_63); permute_63 = None + view_421 = torch.ops.aten.view.default(mm_39, [2, 8192, 1792]) + convert_element_type_191 = torch.ops.prims.convert_element_type.default(view_421, torch.float32); view_421 = None + sigmoid_5 = torch.ops.aten.sigmoid.default(convert_element_type_191) + mul_46 = torch.ops.aten.mul.Tensor(convert_element_type_191, sigmoid_5); convert_element_type_191 = sigmoid_5 = None + convert_element_type_192 = torch.ops.prims.convert_element_type.default(mul_46, torch.bfloat16); mul_46 = None + convert_element_type_193 = torch.ops.prims.convert_element_type.default(primals_56, torch.bfloat16) + all_gather_into_tensor_65 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_193, 8, '0'); convert_element_type_193 = None + wait_tensor_77 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_65); all_gather_into_tensor_65 = None + permute_64 = torch.ops.aten.permute.default(wait_tensor_77, [1, 0]); wait_tensor_77 = None + mm_40 = torch.ops.aten.mm.default(view_420, permute_64); view_420 = permute_64 = None + view_428 = torch.ops.aten.view.default(mm_40, [2, 8192, 1792]); mm_40 = None + mul_47 = torch.ops.aten.mul.Tensor(convert_element_type_192, view_428); convert_element_type_192 = view_428 = None + convert_element_type_196 = torch.ops.prims.convert_element_type.default(primals_57, torch.bfloat16) + all_gather_into_tensor_66 = 
torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_196, 8, '0'); convert_element_type_196 = None + wait_tensor_78 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_66); all_gather_into_tensor_66 = None + permute_65 = torch.ops.aten.permute.default(wait_tensor_78, [1, 0]); wait_tensor_78 = None + view_435 = torch.ops.aten.view.default(mul_47, [16384, 1792]); mul_47 = None + mm_41 = torch.ops.aten.mm.default(view_435, permute_65); view_435 = permute_65 = None + view_436 = torch.ops.aten.view.default(mm_41, [2, 8192, 4096]); mm_41 = None + split_32 = torch.ops.aten.split.Tensor(view_436, 1024, 1); view_436 = None + getitem_310 = split_32[0] + getitem_311 = split_32[1] + getitem_312 = split_32[2] + getitem_313 = split_32[3] + getitem_314 = split_32[4] + getitem_315 = split_32[5] + getitem_316 = split_32[6] + getitem_317 = split_32[7]; split_32 = None + cat_24 = torch.ops.aten.cat.default([getitem_310, getitem_311, getitem_312, getitem_313, getitem_314, getitem_315, getitem_316, getitem_317]); getitem_310 = getitem_311 = getitem_312 = getitem_313 = getitem_314 = getitem_315 = getitem_316 = getitem_317 = None + reduce_scatter_tensor_12 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_24, 'sum', 8, '1'); cat_24 = None + wait_tensor_79 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_12); reduce_scatter_tensor_12 = None + add_23 = torch.ops.aten.add.Tensor(add_21, wait_tensor_79); add_21 = wait_tensor_79 = None + convert_element_type_199 = torch.ops.prims.convert_element_type.default(primals_58, torch.bfloat16) + all_gather_into_tensor_67 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_199, 8, '0'); convert_element_type_199 = None + wait_tensor_80 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_67); all_gather_into_tensor_67 = None + convert_element_type_200 = torch.ops.prims.convert_element_type.default(add_23, torch.float32) + 
pow_13 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_200, 2) + mean_12 = torch.ops.aten.mean.dim(pow_13, [2], True); pow_13 = None + add_24 = torch.ops.aten.add.Scalar(mean_12, 1e-05); mean_12 = None + rsqrt_12 = torch.ops.aten.rsqrt.default(add_24); add_24 = None + mul_48 = torch.ops.aten.mul.Tensor(convert_element_type_200, rsqrt_12); convert_element_type_200 = rsqrt_12 = None + mul_49 = torch.ops.aten.mul.Tensor(mul_48, wait_tensor_80); mul_48 = wait_tensor_80 = None + convert_element_type_201 = torch.ops.prims.convert_element_type.default(mul_49, torch.bfloat16); mul_49 = None + all_gather_into_tensor_68 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_201, 8, '1'); convert_element_type_201 = None + wait_tensor_81 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_68); all_gather_into_tensor_68 = None + split_33 = torch.ops.aten.split.Tensor(wait_tensor_81, 2); wait_tensor_81 = None + getitem_318 = split_33[0] + getitem_319 = split_33[1] + getitem_320 = split_33[2] + getitem_321 = split_33[3] + getitem_322 = split_33[4] + getitem_323 = split_33[5] + getitem_324 = split_33[6] + getitem_325 = split_33[7]; split_33 = None + cat_25 = torch.ops.aten.cat.default([getitem_318, getitem_319, getitem_320, getitem_321, getitem_322, getitem_323, getitem_324, getitem_325], 1); getitem_318 = getitem_319 = getitem_320 = getitem_321 = getitem_322 = getitem_323 = getitem_324 = getitem_325 = None + convert_element_type_202 = torch.ops.prims.convert_element_type.default(primals_59, torch.bfloat16) + all_gather_into_tensor_69 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_202, 8, '0'); convert_element_type_202 = None + wait_tensor_82 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_69); all_gather_into_tensor_69 = None + permute_66 = torch.ops.aten.permute.default(wait_tensor_82, [1, 0]); wait_tensor_82 = None + view_447 = torch.ops.aten.view.default(cat_25, 
[16384, 4096]); cat_25 = None + mm_42 = torch.ops.aten.mm.default(view_447, permute_66); permute_66 = None + view_448 = torch.ops.aten.view.default(mm_42, [2, 8192, 512]) + convert_element_type_205 = torch.ops.prims.convert_element_type.default(primals_60, torch.bfloat16) + all_gather_into_tensor_70 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_205, 8, '0'); convert_element_type_205 = None + wait_tensor_83 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_70); all_gather_into_tensor_70 = None + permute_67 = torch.ops.aten.permute.default(wait_tensor_83, [1, 0]); wait_tensor_83 = None + mm_43 = torch.ops.aten.mm.default(view_447, permute_67); permute_67 = None + view_455 = torch.ops.aten.view.default(mm_43, [2, 8192, 128]); mm_43 = None + convert_element_type_208 = torch.ops.prims.convert_element_type.default(primals_61, torch.bfloat16) + all_gather_into_tensor_71 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_208, 8, '0'); convert_element_type_208 = None + wait_tensor_84 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_71); all_gather_into_tensor_71 = None + permute_68 = torch.ops.aten.permute.default(wait_tensor_84, [1, 0]); wait_tensor_84 = None + mm_44 = torch.ops.aten.mm.default(view_447, permute_68); view_447 = permute_68 = None + view_462 = torch.ops.aten.view.default(mm_44, [2, 8192, 128]) + view_464 = torch.ops.aten.view.default(view_448, [2, 8192, -1, 128]); view_448 = None + view_465 = torch.ops.aten.view.default(view_455, [2, 8192, -1, 128]); view_455 = None + view_466 = torch.ops.aten.view.default(view_462, [2, 8192, -1, 128]); view_462 = None + convert_element_type_211 = torch.ops.prims.convert_element_type.default(view_464, torch.float32); view_464 = None + view_467 = torch.ops.aten.view.default(convert_element_type_211, [2, 8192, 4, -1, 2]); convert_element_type_211 = None + view_as_complex_12 = 
torch.ops.aten.view_as_complex.default(view_467); view_467 = None + convert_element_type_212 = torch.ops.prims.convert_element_type.default(view_465, torch.float32); view_465 = None + view_468 = torch.ops.aten.view.default(convert_element_type_212, [2, 8192, 1, -1, 2]); convert_element_type_212 = None + view_as_complex_13 = torch.ops.aten.view_as_complex.default(view_468); view_468 = None + mul_50 = torch.ops.aten.mul.Tensor(view_as_complex_12, view_37); view_as_complex_12 = None + view_as_real_12 = torch.ops.aten.view_as_real.default(mul_50); mul_50 = None + view_470 = torch.ops.aten.view.default(view_as_real_12, [2, 8192, 4, 128]); view_as_real_12 = None + mul_51 = torch.ops.aten.mul.Tensor(view_as_complex_13, view_37); view_as_complex_13 = None + view_as_real_13 = torch.ops.aten.view_as_real.default(mul_51); mul_51 = None + view_471 = torch.ops.aten.view.default(view_as_real_13, [2, 8192, 1, 128]); view_as_real_13 = None + convert_element_type_213 = torch.ops.prims.convert_element_type.default(view_470, torch.bfloat16); view_470 = None + convert_element_type_214 = torch.ops.prims.convert_element_type.default(view_471, torch.bfloat16); view_471 = None + unsqueeze_12 = torch.ops.aten.unsqueeze.default(convert_element_type_214, 3); convert_element_type_214 = None + expand_12 = torch.ops.aten.expand.default(unsqueeze_12, [2, 8192, 1, 4, 128]); unsqueeze_12 = None + view_472 = torch.ops.aten.view.default(expand_12, [2, 8192, 4, 128]); expand_12 = None + unsqueeze_13 = torch.ops.aten.unsqueeze.default(view_466, 3); view_466 = None + expand_13 = torch.ops.aten.expand.default(unsqueeze_13, [2, 8192, 1, 4, 128]); unsqueeze_13 = None + view_473 = torch.ops.aten.view.default(expand_13, [2, 8192, 4, 128]); expand_13 = None + permute_69 = torch.ops.aten.permute.default(convert_element_type_213, [0, 2, 1, 3]); convert_element_type_213 = None + permute_70 = torch.ops.aten.permute.default(view_472, [0, 2, 1, 3]); view_472 = None + permute_71 = 
torch.ops.aten.permute.default(view_473, [0, 2, 1, 3]); view_473 = None + _scaled_dot_product_cudnn_attention_6 = torch.ops.aten._scaled_dot_product_cudnn_attention.default(permute_69, permute_70, permute_71, None, True, 0.0, True); permute_69 = permute_70 = permute_71 = None + getitem_326 = _scaled_dot_product_cudnn_attention_6[0] + getitem_327 = _scaled_dot_product_cudnn_attention_6[1] + getitem_332 = _scaled_dot_product_cudnn_attention_6[6] + getitem_333 = _scaled_dot_product_cudnn_attention_6[7]; _scaled_dot_product_cudnn_attention_6 = None + permute_72 = torch.ops.aten.permute.default(getitem_326, [0, 2, 1, 3]) + view_474 = torch.ops.aten.view.default(permute_72, [2, 8192, -1]); permute_72 = None + convert_element_type_215 = torch.ops.prims.convert_element_type.default(primals_62, torch.bfloat16) + all_gather_into_tensor_72 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_215, 8, '0'); convert_element_type_215 = None + wait_tensor_85 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_72); all_gather_into_tensor_72 = None + permute_73 = torch.ops.aten.permute.default(wait_tensor_85, [1, 0]); wait_tensor_85 = None + view_480 = torch.ops.aten.view.default(view_474, [16384, 512]); view_474 = None + mm_45 = torch.ops.aten.mm.default(view_480, permute_73); view_480 = permute_73 = None + view_481 = torch.ops.aten.view.default(mm_45, [2, 8192, 4096]); mm_45 = None + split_34 = torch.ops.aten.split.Tensor(view_481, 1024, 1); view_481 = None + getitem_335 = split_34[0] + getitem_336 = split_34[1] + getitem_337 = split_34[2] + getitem_338 = split_34[3] + getitem_339 = split_34[4] + getitem_340 = split_34[5] + getitem_341 = split_34[6] + getitem_342 = split_34[7]; split_34 = None + cat_26 = torch.ops.aten.cat.default([getitem_335, getitem_336, getitem_337, getitem_338, getitem_339, getitem_340, getitem_341, getitem_342]); getitem_335 = getitem_336 = getitem_337 = getitem_338 = getitem_339 = getitem_340 = getitem_341 = 
getitem_342 = None + reduce_scatter_tensor_13 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_26, 'sum', 8, '1'); cat_26 = None + wait_tensor_86 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_13) + add_25 = torch.ops.aten.add.Tensor(add_23, wait_tensor_86); wait_tensor_86 = None + convert_element_type_218 = torch.ops.prims.convert_element_type.default(primals_63, torch.bfloat16) + all_gather_into_tensor_73 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_218, 8, '0'); convert_element_type_218 = None + wait_tensor_87 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_73); all_gather_into_tensor_73 = None + convert_element_type_219 = torch.ops.prims.convert_element_type.default(add_25, torch.float32) + pow_14 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_219, 2) + mean_13 = torch.ops.aten.mean.dim(pow_14, [2], True); pow_14 = None + add_26 = torch.ops.aten.add.Scalar(mean_13, 1e-05); mean_13 = None + rsqrt_13 = torch.ops.aten.rsqrt.default(add_26); add_26 = None + mul_52 = torch.ops.aten.mul.Tensor(convert_element_type_219, rsqrt_13); convert_element_type_219 = rsqrt_13 = None + mul_53 = torch.ops.aten.mul.Tensor(mul_52, wait_tensor_87); mul_52 = wait_tensor_87 = None + convert_element_type_220 = torch.ops.prims.convert_element_type.default(mul_53, torch.bfloat16); mul_53 = None + all_gather_into_tensor_74 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_220, 8, '1'); convert_element_type_220 = None + wait_tensor_88 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_74); all_gather_into_tensor_74 = None + split_35 = torch.ops.aten.split.Tensor(wait_tensor_88, 2); wait_tensor_88 = None + getitem_343 = split_35[0] + getitem_344 = split_35[1] + getitem_345 = split_35[2] + getitem_346 = split_35[3] + getitem_347 = split_35[4] + getitem_348 = split_35[5] + getitem_349 = split_35[6] + getitem_350 = split_35[7]; split_35 
= None + cat_27 = torch.ops.aten.cat.default([getitem_343, getitem_344, getitem_345, getitem_346, getitem_347, getitem_348, getitem_349, getitem_350], 1); getitem_343 = getitem_344 = getitem_345 = getitem_346 = getitem_347 = getitem_348 = getitem_349 = getitem_350 = None + convert_element_type_221 = torch.ops.prims.convert_element_type.default(primals_64, torch.bfloat16) + all_gather_into_tensor_75 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_221, 8, '0'); convert_element_type_221 = None + wait_tensor_89 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_75); all_gather_into_tensor_75 = None + permute_74 = torch.ops.aten.permute.default(wait_tensor_89, [1, 0]); wait_tensor_89 = None + view_492 = torch.ops.aten.view.default(cat_27, [16384, 4096]); cat_27 = None + mm_46 = torch.ops.aten.mm.default(view_492, permute_74); permute_74 = None + view_493 = torch.ops.aten.view.default(mm_46, [2, 8192, 1792]) + convert_element_type_224 = torch.ops.prims.convert_element_type.default(view_493, torch.float32); view_493 = None + sigmoid_6 = torch.ops.aten.sigmoid.default(convert_element_type_224) + mul_54 = torch.ops.aten.mul.Tensor(convert_element_type_224, sigmoid_6); convert_element_type_224 = sigmoid_6 = None + convert_element_type_225 = torch.ops.prims.convert_element_type.default(mul_54, torch.bfloat16); mul_54 = None + convert_element_type_226 = torch.ops.prims.convert_element_type.default(primals_65, torch.bfloat16) + all_gather_into_tensor_76 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_226, 8, '0'); convert_element_type_226 = None + wait_tensor_90 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_76); all_gather_into_tensor_76 = None + permute_75 = torch.ops.aten.permute.default(wait_tensor_90, [1, 0]); wait_tensor_90 = None + mm_47 = torch.ops.aten.mm.default(view_492, permute_75); view_492 = permute_75 = None + view_500 = 
torch.ops.aten.view.default(mm_47, [2, 8192, 1792]); mm_47 = None + mul_55 = torch.ops.aten.mul.Tensor(convert_element_type_225, view_500); convert_element_type_225 = view_500 = None + convert_element_type_229 = torch.ops.prims.convert_element_type.default(primals_66, torch.bfloat16) + all_gather_into_tensor_77 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_229, 8, '0'); convert_element_type_229 = None + wait_tensor_91 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_77); all_gather_into_tensor_77 = None + permute_76 = torch.ops.aten.permute.default(wait_tensor_91, [1, 0]); wait_tensor_91 = None + view_507 = torch.ops.aten.view.default(mul_55, [16384, 1792]); mul_55 = None + mm_48 = torch.ops.aten.mm.default(view_507, permute_76); view_507 = permute_76 = None + view_508 = torch.ops.aten.view.default(mm_48, [2, 8192, 4096]); mm_48 = None + split_36 = torch.ops.aten.split.Tensor(view_508, 1024, 1); view_508 = None + getitem_351 = split_36[0] + getitem_352 = split_36[1] + getitem_353 = split_36[2] + getitem_354 = split_36[3] + getitem_355 = split_36[4] + getitem_356 = split_36[5] + getitem_357 = split_36[6] + getitem_358 = split_36[7]; split_36 = None + cat_28 = torch.ops.aten.cat.default([getitem_351, getitem_352, getitem_353, getitem_354, getitem_355, getitem_356, getitem_357, getitem_358]); getitem_351 = getitem_352 = getitem_353 = getitem_354 = getitem_355 = getitem_356 = getitem_357 = getitem_358 = None + reduce_scatter_tensor_14 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_28, 'sum', 8, '1'); cat_28 = None + wait_tensor_92 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_14); reduce_scatter_tensor_14 = None + add_27 = torch.ops.aten.add.Tensor(add_25, wait_tensor_92); add_25 = wait_tensor_92 = None + convert_element_type_232 = torch.ops.prims.convert_element_type.default(primals_67, torch.bfloat16) + all_gather_into_tensor_78 = 
torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_232, 8, '0'); convert_element_type_232 = None + wait_tensor_93 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_78); all_gather_into_tensor_78 = None + convert_element_type_233 = torch.ops.prims.convert_element_type.default(add_27, torch.float32) + pow_15 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_233, 2) + mean_14 = torch.ops.aten.mean.dim(pow_15, [2], True); pow_15 = None + add_28 = torch.ops.aten.add.Scalar(mean_14, 1e-05); mean_14 = None + rsqrt_14 = torch.ops.aten.rsqrt.default(add_28); add_28 = None + mul_56 = torch.ops.aten.mul.Tensor(convert_element_type_233, rsqrt_14); convert_element_type_233 = rsqrt_14 = None + mul_57 = torch.ops.aten.mul.Tensor(mul_56, wait_tensor_93); mul_56 = wait_tensor_93 = None + convert_element_type_234 = torch.ops.prims.convert_element_type.default(mul_57, torch.bfloat16); mul_57 = None + all_gather_into_tensor_79 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_234, 8, '1'); convert_element_type_234 = None + wait_tensor_94 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_79); all_gather_into_tensor_79 = None + split_37 = torch.ops.aten.split.Tensor(wait_tensor_94, 2); wait_tensor_94 = None + getitem_359 = split_37[0] + getitem_360 = split_37[1] + getitem_361 = split_37[2] + getitem_362 = split_37[3] + getitem_363 = split_37[4] + getitem_364 = split_37[5] + getitem_365 = split_37[6] + getitem_366 = split_37[7]; split_37 = None + cat_29 = torch.ops.aten.cat.default([getitem_359, getitem_360, getitem_361, getitem_362, getitem_363, getitem_364, getitem_365, getitem_366], 1); getitem_359 = getitem_360 = getitem_361 = getitem_362 = getitem_363 = getitem_364 = getitem_365 = getitem_366 = None + convert_element_type_235 = torch.ops.prims.convert_element_type.default(primals_68, torch.bfloat16) + all_gather_into_tensor_80 = 
torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_235, 8, '0'); convert_element_type_235 = None + wait_tensor_95 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_80); all_gather_into_tensor_80 = None + permute_77 = torch.ops.aten.permute.default(wait_tensor_95, [1, 0]); wait_tensor_95 = None + view_519 = torch.ops.aten.view.default(cat_29, [16384, 4096]); cat_29 = None + mm_49 = torch.ops.aten.mm.default(view_519, permute_77); permute_77 = None + view_520 = torch.ops.aten.view.default(mm_49, [2, 8192, 512]) + convert_element_type_238 = torch.ops.prims.convert_element_type.default(primals_69, torch.bfloat16) + all_gather_into_tensor_81 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_238, 8, '0'); convert_element_type_238 = None + wait_tensor_96 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_81); all_gather_into_tensor_81 = None + permute_78 = torch.ops.aten.permute.default(wait_tensor_96, [1, 0]); wait_tensor_96 = None + mm_50 = torch.ops.aten.mm.default(view_519, permute_78); permute_78 = None + view_527 = torch.ops.aten.view.default(mm_50, [2, 8192, 128]); mm_50 = None + convert_element_type_241 = torch.ops.prims.convert_element_type.default(primals_70, torch.bfloat16) + all_gather_into_tensor_82 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_241, 8, '0'); convert_element_type_241 = None + wait_tensor_97 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_82); all_gather_into_tensor_82 = None + permute_79 = torch.ops.aten.permute.default(wait_tensor_97, [1, 0]); wait_tensor_97 = None + mm_51 = torch.ops.aten.mm.default(view_519, permute_79); view_519 = permute_79 = None + view_534 = torch.ops.aten.view.default(mm_51, [2, 8192, 128]) + view_536 = torch.ops.aten.view.default(view_520, [2, 8192, -1, 128]); view_520 = None + view_537 = torch.ops.aten.view.default(view_527, [2, 8192, -1, 128]); view_527 = 
None + view_538 = torch.ops.aten.view.default(view_534, [2, 8192, -1, 128]); view_534 = None + convert_element_type_244 = torch.ops.prims.convert_element_type.default(view_536, torch.float32); view_536 = None + view_539 = torch.ops.aten.view.default(convert_element_type_244, [2, 8192, 4, -1, 2]); convert_element_type_244 = None + view_as_complex_14 = torch.ops.aten.view_as_complex.default(view_539); view_539 = None + convert_element_type_245 = torch.ops.prims.convert_element_type.default(view_537, torch.float32); view_537 = None + view_540 = torch.ops.aten.view.default(convert_element_type_245, [2, 8192, 1, -1, 2]); convert_element_type_245 = None + view_as_complex_15 = torch.ops.aten.view_as_complex.default(view_540); view_540 = None + mul_58 = torch.ops.aten.mul.Tensor(view_as_complex_14, view_37); view_as_complex_14 = None + view_as_real_14 = torch.ops.aten.view_as_real.default(mul_58); mul_58 = None + view_542 = torch.ops.aten.view.default(view_as_real_14, [2, 8192, 4, 128]); view_as_real_14 = None + mul_59 = torch.ops.aten.mul.Tensor(view_as_complex_15, view_37); view_as_complex_15 = None + view_as_real_15 = torch.ops.aten.view_as_real.default(mul_59); mul_59 = None + view_543 = torch.ops.aten.view.default(view_as_real_15, [2, 8192, 1, 128]); view_as_real_15 = None + convert_element_type_246 = torch.ops.prims.convert_element_type.default(view_542, torch.bfloat16); view_542 = None + convert_element_type_247 = torch.ops.prims.convert_element_type.default(view_543, torch.bfloat16); view_543 = None + unsqueeze_14 = torch.ops.aten.unsqueeze.default(convert_element_type_247, 3); convert_element_type_247 = None + expand_14 = torch.ops.aten.expand.default(unsqueeze_14, [2, 8192, 1, 4, 128]); unsqueeze_14 = None + view_544 = torch.ops.aten.view.default(expand_14, [2, 8192, 4, 128]); expand_14 = None + unsqueeze_15 = torch.ops.aten.unsqueeze.default(view_538, 3); view_538 = None + expand_15 = torch.ops.aten.expand.default(unsqueeze_15, [2, 8192, 1, 4, 128]); 
unsqueeze_15 = None + view_545 = torch.ops.aten.view.default(expand_15, [2, 8192, 4, 128]); expand_15 = None + permute_80 = torch.ops.aten.permute.default(convert_element_type_246, [0, 2, 1, 3]); convert_element_type_246 = None + permute_81 = torch.ops.aten.permute.default(view_544, [0, 2, 1, 3]); view_544 = None + permute_82 = torch.ops.aten.permute.default(view_545, [0, 2, 1, 3]); view_545 = None + _scaled_dot_product_cudnn_attention_7 = torch.ops.aten._scaled_dot_product_cudnn_attention.default(permute_80, permute_81, permute_82, None, True, 0.0, True); permute_80 = permute_81 = permute_82 = None + getitem_367 = _scaled_dot_product_cudnn_attention_7[0] + getitem_368 = _scaled_dot_product_cudnn_attention_7[1] + getitem_373 = _scaled_dot_product_cudnn_attention_7[6] + getitem_374 = _scaled_dot_product_cudnn_attention_7[7]; _scaled_dot_product_cudnn_attention_7 = None + permute_83 = torch.ops.aten.permute.default(getitem_367, [0, 2, 1, 3]) + view_546 = torch.ops.aten.view.default(permute_83, [2, 8192, -1]); permute_83 = None + convert_element_type_248 = torch.ops.prims.convert_element_type.default(primals_71, torch.bfloat16) + all_gather_into_tensor_83 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_248, 8, '0'); convert_element_type_248 = None + wait_tensor_98 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_83); all_gather_into_tensor_83 = None + permute_84 = torch.ops.aten.permute.default(wait_tensor_98, [1, 0]); wait_tensor_98 = None + view_552 = torch.ops.aten.view.default(view_546, [16384, 512]); view_546 = None + mm_52 = torch.ops.aten.mm.default(view_552, permute_84); view_552 = permute_84 = None + view_553 = torch.ops.aten.view.default(mm_52, [2, 8192, 4096]); mm_52 = None + split_38 = torch.ops.aten.split.Tensor(view_553, 1024, 1); view_553 = None + getitem_376 = split_38[0] + getitem_377 = split_38[1] + getitem_378 = split_38[2] + getitem_379 = split_38[3] + getitem_380 = split_38[4] + getitem_381 = 
split_38[5] + getitem_382 = split_38[6] + getitem_383 = split_38[7]; split_38 = None + cat_30 = torch.ops.aten.cat.default([getitem_376, getitem_377, getitem_378, getitem_379, getitem_380, getitem_381, getitem_382, getitem_383]); getitem_376 = getitem_377 = getitem_378 = getitem_379 = getitem_380 = getitem_381 = getitem_382 = getitem_383 = None + reduce_scatter_tensor_15 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_30, 'sum', 8, '1'); cat_30 = None + wait_tensor_99 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_15) + add_29 = torch.ops.aten.add.Tensor(add_27, wait_tensor_99); wait_tensor_99 = None + convert_element_type_251 = torch.ops.prims.convert_element_type.default(primals_72, torch.bfloat16) + all_gather_into_tensor_84 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_251, 8, '0'); convert_element_type_251 = None + wait_tensor_100 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_84); all_gather_into_tensor_84 = None + convert_element_type_252 = torch.ops.prims.convert_element_type.default(add_29, torch.float32) + pow_16 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_252, 2) + mean_15 = torch.ops.aten.mean.dim(pow_16, [2], True); pow_16 = None + add_30 = torch.ops.aten.add.Scalar(mean_15, 1e-05); mean_15 = None + rsqrt_15 = torch.ops.aten.rsqrt.default(add_30); add_30 = None + mul_60 = torch.ops.aten.mul.Tensor(convert_element_type_252, rsqrt_15); convert_element_type_252 = rsqrt_15 = None + mul_61 = torch.ops.aten.mul.Tensor(mul_60, wait_tensor_100); mul_60 = wait_tensor_100 = None + convert_element_type_253 = torch.ops.prims.convert_element_type.default(mul_61, torch.bfloat16); mul_61 = None + all_gather_into_tensor_85 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_253, 8, '1'); convert_element_type_253 = None + wait_tensor_101 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_85); 
all_gather_into_tensor_85 = None + split_39 = torch.ops.aten.split.Tensor(wait_tensor_101, 2); wait_tensor_101 = None + getitem_384 = split_39[0] + getitem_385 = split_39[1] + getitem_386 = split_39[2] + getitem_387 = split_39[3] + getitem_388 = split_39[4] + getitem_389 = split_39[5] + getitem_390 = split_39[6] + getitem_391 = split_39[7]; split_39 = None + cat_31 = torch.ops.aten.cat.default([getitem_384, getitem_385, getitem_386, getitem_387, getitem_388, getitem_389, getitem_390, getitem_391], 1); getitem_384 = getitem_385 = getitem_386 = getitem_387 = getitem_388 = getitem_389 = getitem_390 = getitem_391 = None + convert_element_type_254 = torch.ops.prims.convert_element_type.default(primals_73, torch.bfloat16) + all_gather_into_tensor_86 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_254, 8, '0'); convert_element_type_254 = None + wait_tensor_102 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_86); all_gather_into_tensor_86 = None + permute_85 = torch.ops.aten.permute.default(wait_tensor_102, [1, 0]); wait_tensor_102 = None + view_564 = torch.ops.aten.view.default(cat_31, [16384, 4096]); cat_31 = None + mm_53 = torch.ops.aten.mm.default(view_564, permute_85); permute_85 = None + view_565 = torch.ops.aten.view.default(mm_53, [2, 8192, 1792]) + convert_element_type_257 = torch.ops.prims.convert_element_type.default(view_565, torch.float32); view_565 = None + sigmoid_7 = torch.ops.aten.sigmoid.default(convert_element_type_257) + mul_62 = torch.ops.aten.mul.Tensor(convert_element_type_257, sigmoid_7); convert_element_type_257 = sigmoid_7 = None + convert_element_type_258 = torch.ops.prims.convert_element_type.default(mul_62, torch.bfloat16); mul_62 = None + convert_element_type_259 = torch.ops.prims.convert_element_type.default(primals_74, torch.bfloat16) + all_gather_into_tensor_87 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_259, 8, '0'); convert_element_type_259 = 
None + wait_tensor_103 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_87); all_gather_into_tensor_87 = None + permute_86 = torch.ops.aten.permute.default(wait_tensor_103, [1, 0]); wait_tensor_103 = None + mm_54 = torch.ops.aten.mm.default(view_564, permute_86); view_564 = permute_86 = None + view_572 = torch.ops.aten.view.default(mm_54, [2, 8192, 1792]); mm_54 = None + mul_63 = torch.ops.aten.mul.Tensor(convert_element_type_258, view_572); convert_element_type_258 = view_572 = None + convert_element_type_262 = torch.ops.prims.convert_element_type.default(primals_75, torch.bfloat16) + all_gather_into_tensor_88 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_262, 8, '0'); convert_element_type_262 = None + wait_tensor_104 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_88); all_gather_into_tensor_88 = None + permute_87 = torch.ops.aten.permute.default(wait_tensor_104, [1, 0]); wait_tensor_104 = None + view_579 = torch.ops.aten.view.default(mul_63, [16384, 1792]); mul_63 = None + mm_55 = torch.ops.aten.mm.default(view_579, permute_87); view_579 = permute_87 = None + view_580 = torch.ops.aten.view.default(mm_55, [2, 8192, 4096]); mm_55 = None + split_40 = torch.ops.aten.split.Tensor(view_580, 1024, 1); view_580 = None + getitem_392 = split_40[0] + getitem_393 = split_40[1] + getitem_394 = split_40[2] + getitem_395 = split_40[3] + getitem_396 = split_40[4] + getitem_397 = split_40[5] + getitem_398 = split_40[6] + getitem_399 = split_40[7]; split_40 = None + cat_32 = torch.ops.aten.cat.default([getitem_392, getitem_393, getitem_394, getitem_395, getitem_396, getitem_397, getitem_398, getitem_399]); getitem_392 = getitem_393 = getitem_394 = getitem_395 = getitem_396 = getitem_397 = getitem_398 = getitem_399 = None + reduce_scatter_tensor_16 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_32, 'sum', 8, '1'); cat_32 = None + wait_tensor_105 = 
torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_16); reduce_scatter_tensor_16 = None + add_31 = torch.ops.aten.add.Tensor(add_29, wait_tensor_105); add_29 = wait_tensor_105 = None + convert_element_type_265 = torch.ops.prims.convert_element_type.default(primals_76, torch.bfloat16) + all_gather_into_tensor_89 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_265, 8, '0'); convert_element_type_265 = None + wait_tensor_106 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_89); all_gather_into_tensor_89 = None + convert_element_type_266 = torch.ops.prims.convert_element_type.default(add_31, torch.float32) + pow_17 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_266, 2) + mean_16 = torch.ops.aten.mean.dim(pow_17, [2], True); pow_17 = None + add_32 = torch.ops.aten.add.Scalar(mean_16, 1e-05); mean_16 = None + rsqrt_16 = torch.ops.aten.rsqrt.default(add_32); add_32 = None + mul_64 = torch.ops.aten.mul.Tensor(convert_element_type_266, rsqrt_16); convert_element_type_266 = rsqrt_16 = None + mul_65 = torch.ops.aten.mul.Tensor(mul_64, wait_tensor_106); mul_64 = wait_tensor_106 = None + convert_element_type_267 = torch.ops.prims.convert_element_type.default(mul_65, torch.bfloat16); mul_65 = None + all_gather_into_tensor_90 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_267, 8, '1'); convert_element_type_267 = None + wait_tensor_107 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_90); all_gather_into_tensor_90 = None + split_41 = torch.ops.aten.split.Tensor(wait_tensor_107, 2); wait_tensor_107 = None + getitem_400 = split_41[0] + getitem_401 = split_41[1] + getitem_402 = split_41[2] + getitem_403 = split_41[3] + getitem_404 = split_41[4] + getitem_405 = split_41[5] + getitem_406 = split_41[6] + getitem_407 = split_41[7]; split_41 = None + cat_33 = torch.ops.aten.cat.default([getitem_400, getitem_401, getitem_402, getitem_403, getitem_404, 
getitem_405, getitem_406, getitem_407], 1); getitem_400 = getitem_401 = getitem_402 = getitem_403 = getitem_404 = getitem_405 = getitem_406 = getitem_407 = None + convert_element_type_268 = torch.ops.prims.convert_element_type.default(primals_77, torch.bfloat16) + all_gather_into_tensor_91 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_268, 8, '0'); convert_element_type_268 = None + wait_tensor_108 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_91); all_gather_into_tensor_91 = None + permute_88 = torch.ops.aten.permute.default(wait_tensor_108, [1, 0]); wait_tensor_108 = None + view_591 = torch.ops.aten.view.default(cat_33, [16384, 4096]); cat_33 = None + mm_56 = torch.ops.aten.mm.default(view_591, permute_88); permute_88 = None + view_592 = torch.ops.aten.view.default(mm_56, [2, 8192, 512]) + convert_element_type_271 = torch.ops.prims.convert_element_type.default(primals_78, torch.bfloat16) + all_gather_into_tensor_92 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_271, 8, '0'); convert_element_type_271 = None + wait_tensor_109 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_92); all_gather_into_tensor_92 = None + permute_89 = torch.ops.aten.permute.default(wait_tensor_109, [1, 0]); wait_tensor_109 = None + mm_57 = torch.ops.aten.mm.default(view_591, permute_89); permute_89 = None + view_599 = torch.ops.aten.view.default(mm_57, [2, 8192, 128]); mm_57 = None + convert_element_type_274 = torch.ops.prims.convert_element_type.default(primals_79, torch.bfloat16) + all_gather_into_tensor_93 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_274, 8, '0'); convert_element_type_274 = None + wait_tensor_110 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_93); all_gather_into_tensor_93 = None + permute_90 = torch.ops.aten.permute.default(wait_tensor_110, [1, 0]); wait_tensor_110 = None + mm_58 = 
torch.ops.aten.mm.default(view_591, permute_90); view_591 = permute_90 = None + view_606 = torch.ops.aten.view.default(mm_58, [2, 8192, 128]) + view_608 = torch.ops.aten.view.default(view_592, [2, 8192, -1, 128]); view_592 = None + view_609 = torch.ops.aten.view.default(view_599, [2, 8192, -1, 128]); view_599 = None + view_610 = torch.ops.aten.view.default(view_606, [2, 8192, -1, 128]); view_606 = None + convert_element_type_277 = torch.ops.prims.convert_element_type.default(view_608, torch.float32); view_608 = None + view_611 = torch.ops.aten.view.default(convert_element_type_277, [2, 8192, 4, -1, 2]); convert_element_type_277 = None + view_as_complex_16 = torch.ops.aten.view_as_complex.default(view_611); view_611 = None + convert_element_type_278 = torch.ops.prims.convert_element_type.default(view_609, torch.float32); view_609 = None + view_612 = torch.ops.aten.view.default(convert_element_type_278, [2, 8192, 1, -1, 2]); convert_element_type_278 = None + view_as_complex_17 = torch.ops.aten.view_as_complex.default(view_612); view_612 = None + mul_66 = torch.ops.aten.mul.Tensor(view_as_complex_16, view_37); view_as_complex_16 = None + view_as_real_16 = torch.ops.aten.view_as_real.default(mul_66); mul_66 = None + view_614 = torch.ops.aten.view.default(view_as_real_16, [2, 8192, 4, 128]); view_as_real_16 = None + mul_67 = torch.ops.aten.mul.Tensor(view_as_complex_17, view_37); view_as_complex_17 = None + view_as_real_17 = torch.ops.aten.view_as_real.default(mul_67); mul_67 = None + view_615 = torch.ops.aten.view.default(view_as_real_17, [2, 8192, 1, 128]); view_as_real_17 = None + convert_element_type_279 = torch.ops.prims.convert_element_type.default(view_614, torch.bfloat16); view_614 = None + convert_element_type_280 = torch.ops.prims.convert_element_type.default(view_615, torch.bfloat16); view_615 = None + unsqueeze_16 = torch.ops.aten.unsqueeze.default(convert_element_type_280, 3); convert_element_type_280 = None + expand_16 = 
torch.ops.aten.expand.default(unsqueeze_16, [2, 8192, 1, 4, 128]); unsqueeze_16 = None + view_616 = torch.ops.aten.view.default(expand_16, [2, 8192, 4, 128]); expand_16 = None + unsqueeze_17 = torch.ops.aten.unsqueeze.default(view_610, 3); view_610 = None + expand_17 = torch.ops.aten.expand.default(unsqueeze_17, [2, 8192, 1, 4, 128]); unsqueeze_17 = None + view_617 = torch.ops.aten.view.default(expand_17, [2, 8192, 4, 128]); expand_17 = None + permute_91 = torch.ops.aten.permute.default(convert_element_type_279, [0, 2, 1, 3]); convert_element_type_279 = None + permute_92 = torch.ops.aten.permute.default(view_616, [0, 2, 1, 3]); view_616 = None + permute_93 = torch.ops.aten.permute.default(view_617, [0, 2, 1, 3]); view_617 = None + _scaled_dot_product_cudnn_attention_8 = torch.ops.aten._scaled_dot_product_cudnn_attention.default(permute_91, permute_92, permute_93, None, True, 0.0, True); permute_91 = permute_92 = permute_93 = None + getitem_408 = _scaled_dot_product_cudnn_attention_8[0] + getitem_409 = _scaled_dot_product_cudnn_attention_8[1] + getitem_414 = _scaled_dot_product_cudnn_attention_8[6] + getitem_415 = _scaled_dot_product_cudnn_attention_8[7]; _scaled_dot_product_cudnn_attention_8 = None + permute_94 = torch.ops.aten.permute.default(getitem_408, [0, 2, 1, 3]) + view_618 = torch.ops.aten.view.default(permute_94, [2, 8192, -1]); permute_94 = None + convert_element_type_281 = torch.ops.prims.convert_element_type.default(primals_80, torch.bfloat16) + all_gather_into_tensor_94 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_281, 8, '0'); convert_element_type_281 = None + wait_tensor_111 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_94); all_gather_into_tensor_94 = None + permute_95 = torch.ops.aten.permute.default(wait_tensor_111, [1, 0]); wait_tensor_111 = None + view_624 = torch.ops.aten.view.default(view_618, [16384, 512]); view_618 = None + mm_59 = torch.ops.aten.mm.default(view_624, permute_95); 
view_624 = permute_95 = None + view_625 = torch.ops.aten.view.default(mm_59, [2, 8192, 4096]); mm_59 = None + split_42 = torch.ops.aten.split.Tensor(view_625, 1024, 1); view_625 = None + getitem_417 = split_42[0] + getitem_418 = split_42[1] + getitem_419 = split_42[2] + getitem_420 = split_42[3] + getitem_421 = split_42[4] + getitem_422 = split_42[5] + getitem_423 = split_42[6] + getitem_424 = split_42[7]; split_42 = None + cat_34 = torch.ops.aten.cat.default([getitem_417, getitem_418, getitem_419, getitem_420, getitem_421, getitem_422, getitem_423, getitem_424]); getitem_417 = getitem_418 = getitem_419 = getitem_420 = getitem_421 = getitem_422 = getitem_423 = getitem_424 = None + reduce_scatter_tensor_17 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_34, 'sum', 8, '1'); cat_34 = None + wait_tensor_112 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_17) + add_33 = torch.ops.aten.add.Tensor(add_31, wait_tensor_112); wait_tensor_112 = None + convert_element_type_284 = torch.ops.prims.convert_element_type.default(primals_81, torch.bfloat16) + all_gather_into_tensor_95 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_284, 8, '0'); convert_element_type_284 = None + wait_tensor_113 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_95); all_gather_into_tensor_95 = None + convert_element_type_285 = torch.ops.prims.convert_element_type.default(add_33, torch.float32) + pow_18 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_285, 2) + mean_17 = torch.ops.aten.mean.dim(pow_18, [2], True); pow_18 = None + add_34 = torch.ops.aten.add.Scalar(mean_17, 1e-05); mean_17 = None + rsqrt_17 = torch.ops.aten.rsqrt.default(add_34); add_34 = None + mul_68 = torch.ops.aten.mul.Tensor(convert_element_type_285, rsqrt_17); convert_element_type_285 = rsqrt_17 = None + mul_69 = torch.ops.aten.mul.Tensor(mul_68, wait_tensor_113); mul_68 = wait_tensor_113 = None + convert_element_type_286 = 
torch.ops.prims.convert_element_type.default(mul_69, torch.bfloat16); mul_69 = None + all_gather_into_tensor_96 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_286, 8, '1'); convert_element_type_286 = None + wait_tensor_114 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_96); all_gather_into_tensor_96 = None + split_43 = torch.ops.aten.split.Tensor(wait_tensor_114, 2); wait_tensor_114 = None + getitem_425 = split_43[0] + getitem_426 = split_43[1] + getitem_427 = split_43[2] + getitem_428 = split_43[3] + getitem_429 = split_43[4] + getitem_430 = split_43[5] + getitem_431 = split_43[6] + getitem_432 = split_43[7]; split_43 = None + cat_35 = torch.ops.aten.cat.default([getitem_425, getitem_426, getitem_427, getitem_428, getitem_429, getitem_430, getitem_431, getitem_432], 1); getitem_425 = getitem_426 = getitem_427 = getitem_428 = getitem_429 = getitem_430 = getitem_431 = getitem_432 = None + convert_element_type_287 = torch.ops.prims.convert_element_type.default(primals_82, torch.bfloat16) + all_gather_into_tensor_97 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_287, 8, '0'); convert_element_type_287 = None + wait_tensor_115 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_97); all_gather_into_tensor_97 = None + permute_96 = torch.ops.aten.permute.default(wait_tensor_115, [1, 0]); wait_tensor_115 = None + view_636 = torch.ops.aten.view.default(cat_35, [16384, 4096]); cat_35 = None + mm_60 = torch.ops.aten.mm.default(view_636, permute_96); permute_96 = None + view_637 = torch.ops.aten.view.default(mm_60, [2, 8192, 1792]) + convert_element_type_290 = torch.ops.prims.convert_element_type.default(view_637, torch.float32); view_637 = None + sigmoid_8 = torch.ops.aten.sigmoid.default(convert_element_type_290) + mul_70 = torch.ops.aten.mul.Tensor(convert_element_type_290, sigmoid_8); convert_element_type_290 = sigmoid_8 = None + convert_element_type_291 = 
torch.ops.prims.convert_element_type.default(mul_70, torch.bfloat16); mul_70 = None + convert_element_type_292 = torch.ops.prims.convert_element_type.default(primals_83, torch.bfloat16) + all_gather_into_tensor_98 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_292, 8, '0'); convert_element_type_292 = None + wait_tensor_116 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_98); all_gather_into_tensor_98 = None + permute_97 = torch.ops.aten.permute.default(wait_tensor_116, [1, 0]); wait_tensor_116 = None + mm_61 = torch.ops.aten.mm.default(view_636, permute_97); view_636 = permute_97 = None + view_644 = torch.ops.aten.view.default(mm_61, [2, 8192, 1792]); mm_61 = None + mul_71 = torch.ops.aten.mul.Tensor(convert_element_type_291, view_644); convert_element_type_291 = view_644 = None + convert_element_type_295 = torch.ops.prims.convert_element_type.default(primals_84, torch.bfloat16) + all_gather_into_tensor_99 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_295, 8, '0'); convert_element_type_295 = None + wait_tensor_117 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_99); all_gather_into_tensor_99 = None + permute_98 = torch.ops.aten.permute.default(wait_tensor_117, [1, 0]); wait_tensor_117 = None + view_651 = torch.ops.aten.view.default(mul_71, [16384, 1792]); mul_71 = None + mm_62 = torch.ops.aten.mm.default(view_651, permute_98); view_651 = permute_98 = None + view_652 = torch.ops.aten.view.default(mm_62, [2, 8192, 4096]); mm_62 = None + split_44 = torch.ops.aten.split.Tensor(view_652, 1024, 1); view_652 = None + getitem_433 = split_44[0] + getitem_434 = split_44[1] + getitem_435 = split_44[2] + getitem_436 = split_44[3] + getitem_437 = split_44[4] + getitem_438 = split_44[5] + getitem_439 = split_44[6] + getitem_440 = split_44[7]; split_44 = None + cat_36 = torch.ops.aten.cat.default([getitem_433, getitem_434, getitem_435, getitem_436, getitem_437, 
getitem_438, getitem_439, getitem_440]); getitem_433 = getitem_434 = getitem_435 = getitem_436 = getitem_437 = getitem_438 = getitem_439 = getitem_440 = None + reduce_scatter_tensor_18 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_36, 'sum', 8, '1'); cat_36 = None + wait_tensor_118 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_18); reduce_scatter_tensor_18 = None + add_35 = torch.ops.aten.add.Tensor(add_33, wait_tensor_118); add_33 = wait_tensor_118 = None + convert_element_type_298 = torch.ops.prims.convert_element_type.default(primals_85, torch.bfloat16) + all_gather_into_tensor_100 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_298, 8, '0'); convert_element_type_298 = None + wait_tensor_119 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_100); all_gather_into_tensor_100 = None + convert_element_type_299 = torch.ops.prims.convert_element_type.default(add_35, torch.float32) + pow_19 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_299, 2) + mean_18 = torch.ops.aten.mean.dim(pow_19, [2], True); pow_19 = None + add_36 = torch.ops.aten.add.Scalar(mean_18, 1e-05); mean_18 = None + rsqrt_18 = torch.ops.aten.rsqrt.default(add_36); add_36 = None + mul_72 = torch.ops.aten.mul.Tensor(convert_element_type_299, rsqrt_18); convert_element_type_299 = rsqrt_18 = None + mul_73 = torch.ops.aten.mul.Tensor(mul_72, wait_tensor_119); mul_72 = wait_tensor_119 = None + convert_element_type_300 = torch.ops.prims.convert_element_type.default(mul_73, torch.bfloat16); mul_73 = None + all_gather_into_tensor_101 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_300, 8, '1'); convert_element_type_300 = None + wait_tensor_120 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_101); all_gather_into_tensor_101 = None + split_45 = torch.ops.aten.split.Tensor(wait_tensor_120, 2); wait_tensor_120 = None + getitem_441 = split_45[0] + 
getitem_442 = split_45[1] + getitem_443 = split_45[2] + getitem_444 = split_45[3] + getitem_445 = split_45[4] + getitem_446 = split_45[5] + getitem_447 = split_45[6] + getitem_448 = split_45[7]; split_45 = None + cat_37 = torch.ops.aten.cat.default([getitem_441, getitem_442, getitem_443, getitem_444, getitem_445, getitem_446, getitem_447, getitem_448], 1); getitem_441 = getitem_442 = getitem_443 = getitem_444 = getitem_445 = getitem_446 = getitem_447 = getitem_448 = None + convert_element_type_301 = torch.ops.prims.convert_element_type.default(primals_86, torch.bfloat16) + all_gather_into_tensor_102 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_301, 8, '0'); convert_element_type_301 = None + wait_tensor_121 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_102); all_gather_into_tensor_102 = None + permute_99 = torch.ops.aten.permute.default(wait_tensor_121, [1, 0]); wait_tensor_121 = None + view_663 = torch.ops.aten.view.default(cat_37, [16384, 4096]); cat_37 = None + mm_63 = torch.ops.aten.mm.default(view_663, permute_99); permute_99 = None + view_664 = torch.ops.aten.view.default(mm_63, [2, 8192, 512]) + convert_element_type_304 = torch.ops.prims.convert_element_type.default(primals_87, torch.bfloat16) + all_gather_into_tensor_103 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_304, 8, '0'); convert_element_type_304 = None + wait_tensor_122 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_103); all_gather_into_tensor_103 = None + permute_100 = torch.ops.aten.permute.default(wait_tensor_122, [1, 0]); wait_tensor_122 = None + mm_64 = torch.ops.aten.mm.default(view_663, permute_100); permute_100 = None + view_671 = torch.ops.aten.view.default(mm_64, [2, 8192, 128]); mm_64 = None + convert_element_type_307 = torch.ops.prims.convert_element_type.default(primals_88, torch.bfloat16) + all_gather_into_tensor_104 = 
torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_307, 8, '0'); convert_element_type_307 = None + wait_tensor_123 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_104); all_gather_into_tensor_104 = None + permute_101 = torch.ops.aten.permute.default(wait_tensor_123, [1, 0]); wait_tensor_123 = None + mm_65 = torch.ops.aten.mm.default(view_663, permute_101); view_663 = permute_101 = None + view_678 = torch.ops.aten.view.default(mm_65, [2, 8192, 128]) + view_680 = torch.ops.aten.view.default(view_664, [2, 8192, -1, 128]); view_664 = None + view_681 = torch.ops.aten.view.default(view_671, [2, 8192, -1, 128]); view_671 = None + view_682 = torch.ops.aten.view.default(view_678, [2, 8192, -1, 128]); view_678 = None + convert_element_type_310 = torch.ops.prims.convert_element_type.default(view_680, torch.float32); view_680 = None + view_683 = torch.ops.aten.view.default(convert_element_type_310, [2, 8192, 4, -1, 2]); convert_element_type_310 = None + view_as_complex_18 = torch.ops.aten.view_as_complex.default(view_683); view_683 = None + convert_element_type_311 = torch.ops.prims.convert_element_type.default(view_681, torch.float32); view_681 = None + view_684 = torch.ops.aten.view.default(convert_element_type_311, [2, 8192, 1, -1, 2]); convert_element_type_311 = None + view_as_complex_19 = torch.ops.aten.view_as_complex.default(view_684); view_684 = None + mul_74 = torch.ops.aten.mul.Tensor(view_as_complex_18, view_37); view_as_complex_18 = None + view_as_real_18 = torch.ops.aten.view_as_real.default(mul_74); mul_74 = None + view_686 = torch.ops.aten.view.default(view_as_real_18, [2, 8192, 4, 128]); view_as_real_18 = None + mul_75 = torch.ops.aten.mul.Tensor(view_as_complex_19, view_37); view_as_complex_19 = None + view_as_real_19 = torch.ops.aten.view_as_real.default(mul_75); mul_75 = None + view_687 = torch.ops.aten.view.default(view_as_real_19, [2, 8192, 1, 128]); view_as_real_19 = None + convert_element_type_312 = 
torch.ops.prims.convert_element_type.default(view_686, torch.bfloat16); view_686 = None + convert_element_type_313 = torch.ops.prims.convert_element_type.default(view_687, torch.bfloat16); view_687 = None + unsqueeze_18 = torch.ops.aten.unsqueeze.default(convert_element_type_313, 3); convert_element_type_313 = None + expand_18 = torch.ops.aten.expand.default(unsqueeze_18, [2, 8192, 1, 4, 128]); unsqueeze_18 = None + view_688 = torch.ops.aten.view.default(expand_18, [2, 8192, 4, 128]); expand_18 = None + unsqueeze_19 = torch.ops.aten.unsqueeze.default(view_682, 3); view_682 = None + expand_19 = torch.ops.aten.expand.default(unsqueeze_19, [2, 8192, 1, 4, 128]); unsqueeze_19 = None + view_689 = torch.ops.aten.view.default(expand_19, [2, 8192, 4, 128]); expand_19 = None + permute_102 = torch.ops.aten.permute.default(convert_element_type_312, [0, 2, 1, 3]); convert_element_type_312 = None + permute_103 = torch.ops.aten.permute.default(view_688, [0, 2, 1, 3]); view_688 = None + permute_104 = torch.ops.aten.permute.default(view_689, [0, 2, 1, 3]); view_689 = None + _scaled_dot_product_cudnn_attention_9 = torch.ops.aten._scaled_dot_product_cudnn_attention.default(permute_102, permute_103, permute_104, None, True, 0.0, True); permute_102 = permute_103 = permute_104 = None + getitem_449 = _scaled_dot_product_cudnn_attention_9[0] + getitem_450 = _scaled_dot_product_cudnn_attention_9[1] + getitem_455 = _scaled_dot_product_cudnn_attention_9[6] + getitem_456 = _scaled_dot_product_cudnn_attention_9[7]; _scaled_dot_product_cudnn_attention_9 = None + permute_105 = torch.ops.aten.permute.default(getitem_449, [0, 2, 1, 3]) + view_690 = torch.ops.aten.view.default(permute_105, [2, 8192, -1]); permute_105 = None + convert_element_type_314 = torch.ops.prims.convert_element_type.default(primals_89, torch.bfloat16) + all_gather_into_tensor_105 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_314, 8, '0'); convert_element_type_314 = None + wait_tensor_124 = 
torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_105); all_gather_into_tensor_105 = None + permute_106 = torch.ops.aten.permute.default(wait_tensor_124, [1, 0]); wait_tensor_124 = None + view_696 = torch.ops.aten.view.default(view_690, [16384, 512]); view_690 = None + mm_66 = torch.ops.aten.mm.default(view_696, permute_106); view_696 = permute_106 = None + view_697 = torch.ops.aten.view.default(mm_66, [2, 8192, 4096]); mm_66 = None + split_46 = torch.ops.aten.split.Tensor(view_697, 1024, 1); view_697 = None + getitem_458 = split_46[0] + getitem_459 = split_46[1] + getitem_460 = split_46[2] + getitem_461 = split_46[3] + getitem_462 = split_46[4] + getitem_463 = split_46[5] + getitem_464 = split_46[6] + getitem_465 = split_46[7]; split_46 = None + cat_38 = torch.ops.aten.cat.default([getitem_458, getitem_459, getitem_460, getitem_461, getitem_462, getitem_463, getitem_464, getitem_465]); getitem_458 = getitem_459 = getitem_460 = getitem_461 = getitem_462 = getitem_463 = getitem_464 = getitem_465 = None + reduce_scatter_tensor_19 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_38, 'sum', 8, '1'); cat_38 = None + wait_tensor_125 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_19) + add_37 = torch.ops.aten.add.Tensor(add_35, wait_tensor_125); wait_tensor_125 = None + convert_element_type_317 = torch.ops.prims.convert_element_type.default(primals_90, torch.bfloat16) + all_gather_into_tensor_106 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_317, 8, '0'); convert_element_type_317 = None + wait_tensor_126 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_106); all_gather_into_tensor_106 = None + convert_element_type_318 = torch.ops.prims.convert_element_type.default(add_37, torch.float32) + pow_20 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_318, 2) + mean_19 = torch.ops.aten.mean.dim(pow_20, [2], True); pow_20 = None + add_38 = 
torch.ops.aten.add.Scalar(mean_19, 1e-05); mean_19 = None + rsqrt_19 = torch.ops.aten.rsqrt.default(add_38); add_38 = None + mul_76 = torch.ops.aten.mul.Tensor(convert_element_type_318, rsqrt_19); convert_element_type_318 = rsqrt_19 = None + mul_77 = torch.ops.aten.mul.Tensor(mul_76, wait_tensor_126); mul_76 = wait_tensor_126 = None + convert_element_type_319 = torch.ops.prims.convert_element_type.default(mul_77, torch.bfloat16); mul_77 = None + all_gather_into_tensor_107 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_319, 8, '1'); convert_element_type_319 = None + wait_tensor_127 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_107); all_gather_into_tensor_107 = None + split_47 = torch.ops.aten.split.Tensor(wait_tensor_127, 2); wait_tensor_127 = None + getitem_466 = split_47[0] + getitem_467 = split_47[1] + getitem_468 = split_47[2] + getitem_469 = split_47[3] + getitem_470 = split_47[4] + getitem_471 = split_47[5] + getitem_472 = split_47[6] + getitem_473 = split_47[7]; split_47 = None + cat_39 = torch.ops.aten.cat.default([getitem_466, getitem_467, getitem_468, getitem_469, getitem_470, getitem_471, getitem_472, getitem_473], 1); getitem_466 = getitem_467 = getitem_468 = getitem_469 = getitem_470 = getitem_471 = getitem_472 = getitem_473 = None + convert_element_type_320 = torch.ops.prims.convert_element_type.default(primals_91, torch.bfloat16) + all_gather_into_tensor_108 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_320, 8, '0'); convert_element_type_320 = None + wait_tensor_128 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_108); all_gather_into_tensor_108 = None + permute_107 = torch.ops.aten.permute.default(wait_tensor_128, [1, 0]); wait_tensor_128 = None + view_708 = torch.ops.aten.view.default(cat_39, [16384, 4096]); cat_39 = None + mm_67 = torch.ops.aten.mm.default(view_708, permute_107); permute_107 = None + view_709 = 
torch.ops.aten.view.default(mm_67, [2, 8192, 1792]) + convert_element_type_323 = torch.ops.prims.convert_element_type.default(view_709, torch.float32); view_709 = None + sigmoid_9 = torch.ops.aten.sigmoid.default(convert_element_type_323) + mul_78 = torch.ops.aten.mul.Tensor(convert_element_type_323, sigmoid_9); convert_element_type_323 = sigmoid_9 = None + convert_element_type_324 = torch.ops.prims.convert_element_type.default(mul_78, torch.bfloat16); mul_78 = None + convert_element_type_325 = torch.ops.prims.convert_element_type.default(primals_92, torch.bfloat16) + all_gather_into_tensor_109 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_325, 8, '0'); convert_element_type_325 = None + wait_tensor_129 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_109); all_gather_into_tensor_109 = None + permute_108 = torch.ops.aten.permute.default(wait_tensor_129, [1, 0]); wait_tensor_129 = None + mm_68 = torch.ops.aten.mm.default(view_708, permute_108); view_708 = permute_108 = None + view_716 = torch.ops.aten.view.default(mm_68, [2, 8192, 1792]); mm_68 = None + mul_79 = torch.ops.aten.mul.Tensor(convert_element_type_324, view_716); convert_element_type_324 = view_716 = None + convert_element_type_328 = torch.ops.prims.convert_element_type.default(primals_93, torch.bfloat16) + all_gather_into_tensor_110 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_328, 8, '0'); convert_element_type_328 = None + wait_tensor_130 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_110); all_gather_into_tensor_110 = None + permute_109 = torch.ops.aten.permute.default(wait_tensor_130, [1, 0]); wait_tensor_130 = None + view_723 = torch.ops.aten.view.default(mul_79, [16384, 1792]); mul_79 = None + mm_69 = torch.ops.aten.mm.default(view_723, permute_109); view_723 = permute_109 = None + view_724 = torch.ops.aten.view.default(mm_69, [2, 8192, 4096]); mm_69 = None + split_48 = 
torch.ops.aten.split.Tensor(view_724, 1024, 1); view_724 = None + getitem_474 = split_48[0] + getitem_475 = split_48[1] + getitem_476 = split_48[2] + getitem_477 = split_48[3] + getitem_478 = split_48[4] + getitem_479 = split_48[5] + getitem_480 = split_48[6] + getitem_481 = split_48[7]; split_48 = None + cat_40 = torch.ops.aten.cat.default([getitem_474, getitem_475, getitem_476, getitem_477, getitem_478, getitem_479, getitem_480, getitem_481]); getitem_474 = getitem_475 = getitem_476 = getitem_477 = getitem_478 = getitem_479 = getitem_480 = getitem_481 = None + reduce_scatter_tensor_20 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_40, 'sum', 8, '1'); cat_40 = None + wait_tensor_131 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_20); reduce_scatter_tensor_20 = None + add_39 = torch.ops.aten.add.Tensor(add_37, wait_tensor_131); add_37 = wait_tensor_131 = None + convert_element_type_331 = torch.ops.prims.convert_element_type.default(primals_94, torch.bfloat16) + all_gather_into_tensor_111 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_331, 8, '0'); convert_element_type_331 = None + wait_tensor_132 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_111); all_gather_into_tensor_111 = None + convert_element_type_332 = torch.ops.prims.convert_element_type.default(add_39, torch.float32) + pow_21 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_332, 2) + mean_20 = torch.ops.aten.mean.dim(pow_21, [2], True); pow_21 = None + add_40 = torch.ops.aten.add.Scalar(mean_20, 1e-05); mean_20 = None + rsqrt_20 = torch.ops.aten.rsqrt.default(add_40); add_40 = None + mul_80 = torch.ops.aten.mul.Tensor(convert_element_type_332, rsqrt_20); convert_element_type_332 = rsqrt_20 = None + mul_81 = torch.ops.aten.mul.Tensor(mul_80, wait_tensor_132); mul_80 = wait_tensor_132 = None + convert_element_type_333 = torch.ops.prims.convert_element_type.default(mul_81, torch.bfloat16); mul_81 = None 
+ all_gather_into_tensor_112 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_333, 8, '1'); convert_element_type_333 = None + wait_tensor_133 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_112); all_gather_into_tensor_112 = None + split_49 = torch.ops.aten.split.Tensor(wait_tensor_133, 2); wait_tensor_133 = None + getitem_482 = split_49[0] + getitem_483 = split_49[1] + getitem_484 = split_49[2] + getitem_485 = split_49[3] + getitem_486 = split_49[4] + getitem_487 = split_49[5] + getitem_488 = split_49[6] + getitem_489 = split_49[7]; split_49 = None + cat_41 = torch.ops.aten.cat.default([getitem_482, getitem_483, getitem_484, getitem_485, getitem_486, getitem_487, getitem_488, getitem_489], 1); getitem_482 = getitem_483 = getitem_484 = getitem_485 = getitem_486 = getitem_487 = getitem_488 = getitem_489 = None + convert_element_type_334 = torch.ops.prims.convert_element_type.default(primals_95, torch.bfloat16) + all_gather_into_tensor_113 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_334, 8, '0'); convert_element_type_334 = None + wait_tensor_134 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_113); all_gather_into_tensor_113 = None + permute_110 = torch.ops.aten.permute.default(wait_tensor_134, [1, 0]); wait_tensor_134 = None + view_735 = torch.ops.aten.view.default(cat_41, [16384, 4096]); cat_41 = None + mm_70 = torch.ops.aten.mm.default(view_735, permute_110); permute_110 = None + view_736 = torch.ops.aten.view.default(mm_70, [2, 8192, 512]) + convert_element_type_337 = torch.ops.prims.convert_element_type.default(primals_96, torch.bfloat16) + all_gather_into_tensor_114 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_337, 8, '0'); convert_element_type_337 = None + wait_tensor_135 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_114); all_gather_into_tensor_114 = None + permute_111 = 
torch.ops.aten.permute.default(wait_tensor_135, [1, 0]); wait_tensor_135 = None + mm_71 = torch.ops.aten.mm.default(view_735, permute_111); permute_111 = None + view_743 = torch.ops.aten.view.default(mm_71, [2, 8192, 128]); mm_71 = None + convert_element_type_340 = torch.ops.prims.convert_element_type.default(primals_97, torch.bfloat16) + all_gather_into_tensor_115 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_340, 8, '0'); convert_element_type_340 = None + wait_tensor_136 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_115); all_gather_into_tensor_115 = None + permute_112 = torch.ops.aten.permute.default(wait_tensor_136, [1, 0]); wait_tensor_136 = None + mm_72 = torch.ops.aten.mm.default(view_735, permute_112); view_735 = permute_112 = None + view_750 = torch.ops.aten.view.default(mm_72, [2, 8192, 128]) + view_752 = torch.ops.aten.view.default(view_736, [2, 8192, -1, 128]); view_736 = None + view_753 = torch.ops.aten.view.default(view_743, [2, 8192, -1, 128]); view_743 = None + view_754 = torch.ops.aten.view.default(view_750, [2, 8192, -1, 128]); view_750 = None + convert_element_type_343 = torch.ops.prims.convert_element_type.default(view_752, torch.float32); view_752 = None + view_755 = torch.ops.aten.view.default(convert_element_type_343, [2, 8192, 4, -1, 2]); convert_element_type_343 = None + view_as_complex_20 = torch.ops.aten.view_as_complex.default(view_755); view_755 = None + convert_element_type_344 = torch.ops.prims.convert_element_type.default(view_753, torch.float32); view_753 = None + view_756 = torch.ops.aten.view.default(convert_element_type_344, [2, 8192, 1, -1, 2]); convert_element_type_344 = None + view_as_complex_21 = torch.ops.aten.view_as_complex.default(view_756); view_756 = None + mul_82 = torch.ops.aten.mul.Tensor(view_as_complex_20, view_37); view_as_complex_20 = None + view_as_real_20 = torch.ops.aten.view_as_real.default(mul_82); mul_82 = None + view_758 = 
torch.ops.aten.view.default(view_as_real_20, [2, 8192, 4, 128]); view_as_real_20 = None + mul_83 = torch.ops.aten.mul.Tensor(view_as_complex_21, view_37); view_as_complex_21 = None + view_as_real_21 = torch.ops.aten.view_as_real.default(mul_83); mul_83 = None + view_759 = torch.ops.aten.view.default(view_as_real_21, [2, 8192, 1, 128]); view_as_real_21 = None + convert_element_type_345 = torch.ops.prims.convert_element_type.default(view_758, torch.bfloat16); view_758 = None + convert_element_type_346 = torch.ops.prims.convert_element_type.default(view_759, torch.bfloat16); view_759 = None + unsqueeze_20 = torch.ops.aten.unsqueeze.default(convert_element_type_346, 3); convert_element_type_346 = None + expand_20 = torch.ops.aten.expand.default(unsqueeze_20, [2, 8192, 1, 4, 128]); unsqueeze_20 = None + view_760 = torch.ops.aten.view.default(expand_20, [2, 8192, 4, 128]); expand_20 = None + unsqueeze_21 = torch.ops.aten.unsqueeze.default(view_754, 3); view_754 = None + expand_21 = torch.ops.aten.expand.default(unsqueeze_21, [2, 8192, 1, 4, 128]); unsqueeze_21 = None + view_761 = torch.ops.aten.view.default(expand_21, [2, 8192, 4, 128]); expand_21 = None + permute_113 = torch.ops.aten.permute.default(convert_element_type_345, [0, 2, 1, 3]); convert_element_type_345 = None + permute_114 = torch.ops.aten.permute.default(view_760, [0, 2, 1, 3]); view_760 = None + permute_115 = torch.ops.aten.permute.default(view_761, [0, 2, 1, 3]); view_761 = None + _scaled_dot_product_cudnn_attention_10 = torch.ops.aten._scaled_dot_product_cudnn_attention.default(permute_113, permute_114, permute_115, None, True, 0.0, True); permute_113 = permute_114 = permute_115 = None + getitem_490 = _scaled_dot_product_cudnn_attention_10[0] + getitem_491 = _scaled_dot_product_cudnn_attention_10[1] + getitem_496 = _scaled_dot_product_cudnn_attention_10[6] + getitem_497 = _scaled_dot_product_cudnn_attention_10[7]; _scaled_dot_product_cudnn_attention_10 = None + permute_116 = 
torch.ops.aten.permute.default(getitem_490, [0, 2, 1, 3]) + view_762 = torch.ops.aten.view.default(permute_116, [2, 8192, -1]); permute_116 = None + convert_element_type_347 = torch.ops.prims.convert_element_type.default(primals_98, torch.bfloat16) + all_gather_into_tensor_116 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_347, 8, '0'); convert_element_type_347 = None + wait_tensor_137 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_116); all_gather_into_tensor_116 = None + permute_117 = torch.ops.aten.permute.default(wait_tensor_137, [1, 0]); wait_tensor_137 = None + view_768 = torch.ops.aten.view.default(view_762, [16384, 512]); view_762 = None + mm_73 = torch.ops.aten.mm.default(view_768, permute_117); view_768 = permute_117 = None + view_769 = torch.ops.aten.view.default(mm_73, [2, 8192, 4096]); mm_73 = None + split_50 = torch.ops.aten.split.Tensor(view_769, 1024, 1); view_769 = None + getitem_499 = split_50[0] + getitem_500 = split_50[1] + getitem_501 = split_50[2] + getitem_502 = split_50[3] + getitem_503 = split_50[4] + getitem_504 = split_50[5] + getitem_505 = split_50[6] + getitem_506 = split_50[7]; split_50 = None + cat_42 = torch.ops.aten.cat.default([getitem_499, getitem_500, getitem_501, getitem_502, getitem_503, getitem_504, getitem_505, getitem_506]); getitem_499 = getitem_500 = getitem_501 = getitem_502 = getitem_503 = getitem_504 = getitem_505 = getitem_506 = None + reduce_scatter_tensor_21 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_42, 'sum', 8, '1'); cat_42 = None + wait_tensor_138 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_21) + add_41 = torch.ops.aten.add.Tensor(add_39, wait_tensor_138); wait_tensor_138 = None + convert_element_type_350 = torch.ops.prims.convert_element_type.default(primals_99, torch.bfloat16) + all_gather_into_tensor_117 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_350, 8, '0'); 
convert_element_type_350 = None + wait_tensor_139 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_117); all_gather_into_tensor_117 = None + convert_element_type_351 = torch.ops.prims.convert_element_type.default(add_41, torch.float32) + pow_22 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_351, 2) + mean_21 = torch.ops.aten.mean.dim(pow_22, [2], True); pow_22 = None + add_42 = torch.ops.aten.add.Scalar(mean_21, 1e-05); mean_21 = None + rsqrt_21 = torch.ops.aten.rsqrt.default(add_42); add_42 = None + mul_84 = torch.ops.aten.mul.Tensor(convert_element_type_351, rsqrt_21); convert_element_type_351 = rsqrt_21 = None + mul_85 = torch.ops.aten.mul.Tensor(mul_84, wait_tensor_139); mul_84 = wait_tensor_139 = None + convert_element_type_352 = torch.ops.prims.convert_element_type.default(mul_85, torch.bfloat16); mul_85 = None + all_gather_into_tensor_118 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_352, 8, '1'); convert_element_type_352 = None + wait_tensor_140 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_118); all_gather_into_tensor_118 = None + split_51 = torch.ops.aten.split.Tensor(wait_tensor_140, 2); wait_tensor_140 = None + getitem_507 = split_51[0] + getitem_508 = split_51[1] + getitem_509 = split_51[2] + getitem_510 = split_51[3] + getitem_511 = split_51[4] + getitem_512 = split_51[5] + getitem_513 = split_51[6] + getitem_514 = split_51[7]; split_51 = None + cat_43 = torch.ops.aten.cat.default([getitem_507, getitem_508, getitem_509, getitem_510, getitem_511, getitem_512, getitem_513, getitem_514], 1); getitem_507 = getitem_508 = getitem_509 = getitem_510 = getitem_511 = getitem_512 = getitem_513 = getitem_514 = None + convert_element_type_353 = torch.ops.prims.convert_element_type.default(primals_100, torch.bfloat16) + all_gather_into_tensor_119 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_353, 8, '0'); convert_element_type_353 = None + 
wait_tensor_141 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_119); all_gather_into_tensor_119 = None + permute_118 = torch.ops.aten.permute.default(wait_tensor_141, [1, 0]); wait_tensor_141 = None + view_780 = torch.ops.aten.view.default(cat_43, [16384, 4096]); cat_43 = None + mm_74 = torch.ops.aten.mm.default(view_780, permute_118); permute_118 = None + view_781 = torch.ops.aten.view.default(mm_74, [2, 8192, 1792]) + convert_element_type_356 = torch.ops.prims.convert_element_type.default(view_781, torch.float32); view_781 = None + sigmoid_10 = torch.ops.aten.sigmoid.default(convert_element_type_356) + mul_86 = torch.ops.aten.mul.Tensor(convert_element_type_356, sigmoid_10); convert_element_type_356 = sigmoid_10 = None + convert_element_type_357 = torch.ops.prims.convert_element_type.default(mul_86, torch.bfloat16); mul_86 = None + convert_element_type_358 = torch.ops.prims.convert_element_type.default(primals_101, torch.bfloat16) + all_gather_into_tensor_120 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_358, 8, '0'); convert_element_type_358 = None + wait_tensor_142 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_120); all_gather_into_tensor_120 = None + permute_119 = torch.ops.aten.permute.default(wait_tensor_142, [1, 0]); wait_tensor_142 = None + mm_75 = torch.ops.aten.mm.default(view_780, permute_119); view_780 = permute_119 = None + view_788 = torch.ops.aten.view.default(mm_75, [2, 8192, 1792]); mm_75 = None + mul_87 = torch.ops.aten.mul.Tensor(convert_element_type_357, view_788); convert_element_type_357 = view_788 = None + convert_element_type_361 = torch.ops.prims.convert_element_type.default(primals_102, torch.bfloat16) + all_gather_into_tensor_121 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_361, 8, '0'); convert_element_type_361 = None + wait_tensor_143 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_121); 
all_gather_into_tensor_121 = None + permute_120 = torch.ops.aten.permute.default(wait_tensor_143, [1, 0]); wait_tensor_143 = None + view_795 = torch.ops.aten.view.default(mul_87, [16384, 1792]); mul_87 = None + mm_76 = torch.ops.aten.mm.default(view_795, permute_120); view_795 = permute_120 = None + view_796 = torch.ops.aten.view.default(mm_76, [2, 8192, 4096]); mm_76 = None + split_52 = torch.ops.aten.split.Tensor(view_796, 1024, 1); view_796 = None + getitem_515 = split_52[0] + getitem_516 = split_52[1] + getitem_517 = split_52[2] + getitem_518 = split_52[3] + getitem_519 = split_52[4] + getitem_520 = split_52[5] + getitem_521 = split_52[6] + getitem_522 = split_52[7]; split_52 = None + cat_44 = torch.ops.aten.cat.default([getitem_515, getitem_516, getitem_517, getitem_518, getitem_519, getitem_520, getitem_521, getitem_522]); getitem_515 = getitem_516 = getitem_517 = getitem_518 = getitem_519 = getitem_520 = getitem_521 = getitem_522 = None + reduce_scatter_tensor_22 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_44, 'sum', 8, '1'); cat_44 = None + wait_tensor_144 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_22); reduce_scatter_tensor_22 = None + add_43 = torch.ops.aten.add.Tensor(add_41, wait_tensor_144); add_41 = wait_tensor_144 = None + convert_element_type_364 = torch.ops.prims.convert_element_type.default(primals_103, torch.bfloat16) + all_gather_into_tensor_122 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_364, 8, '0'); convert_element_type_364 = None + wait_tensor_145 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_122); all_gather_into_tensor_122 = None + convert_element_type_365 = torch.ops.prims.convert_element_type.default(add_43, torch.float32) + pow_23 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_365, 2) + mean_22 = torch.ops.aten.mean.dim(pow_23, [2], True); pow_23 = None + add_44 = torch.ops.aten.add.Scalar(mean_22, 1e-05); mean_22 = 
None + rsqrt_22 = torch.ops.aten.rsqrt.default(add_44); add_44 = None + mul_88 = torch.ops.aten.mul.Tensor(convert_element_type_365, rsqrt_22); convert_element_type_365 = rsqrt_22 = None + mul_89 = torch.ops.aten.mul.Tensor(mul_88, wait_tensor_145); mul_88 = wait_tensor_145 = None + convert_element_type_366 = torch.ops.prims.convert_element_type.default(mul_89, torch.bfloat16); mul_89 = None + all_gather_into_tensor_123 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_366, 8, '1'); convert_element_type_366 = None + wait_tensor_146 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_123); all_gather_into_tensor_123 = None + split_53 = torch.ops.aten.split.Tensor(wait_tensor_146, 2); wait_tensor_146 = None + getitem_523 = split_53[0] + getitem_524 = split_53[1] + getitem_525 = split_53[2] + getitem_526 = split_53[3] + getitem_527 = split_53[4] + getitem_528 = split_53[5] + getitem_529 = split_53[6] + getitem_530 = split_53[7]; split_53 = None + cat_45 = torch.ops.aten.cat.default([getitem_523, getitem_524, getitem_525, getitem_526, getitem_527, getitem_528, getitem_529, getitem_530], 1); getitem_523 = getitem_524 = getitem_525 = getitem_526 = getitem_527 = getitem_528 = getitem_529 = getitem_530 = None + convert_element_type_367 = torch.ops.prims.convert_element_type.default(primals_104, torch.bfloat16) + all_gather_into_tensor_124 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_367, 8, '0'); convert_element_type_367 = None + wait_tensor_147 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_124); all_gather_into_tensor_124 = None + permute_121 = torch.ops.aten.permute.default(wait_tensor_147, [1, 0]); wait_tensor_147 = None + view_807 = torch.ops.aten.view.default(cat_45, [16384, 4096]); cat_45 = None + mm_77 = torch.ops.aten.mm.default(view_807, permute_121); permute_121 = None + view_808 = torch.ops.aten.view.default(mm_77, [2, 8192, 512]) + 
convert_element_type_370 = torch.ops.prims.convert_element_type.default(primals_105, torch.bfloat16) + all_gather_into_tensor_125 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_370, 8, '0'); convert_element_type_370 = None + wait_tensor_148 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_125); all_gather_into_tensor_125 = None + permute_122 = torch.ops.aten.permute.default(wait_tensor_148, [1, 0]); wait_tensor_148 = None + mm_78 = torch.ops.aten.mm.default(view_807, permute_122); permute_122 = None + view_815 = torch.ops.aten.view.default(mm_78, [2, 8192, 128]); mm_78 = None + convert_element_type_373 = torch.ops.prims.convert_element_type.default(primals_106, torch.bfloat16) + all_gather_into_tensor_126 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_373, 8, '0'); convert_element_type_373 = None + wait_tensor_149 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_126); all_gather_into_tensor_126 = None + permute_123 = torch.ops.aten.permute.default(wait_tensor_149, [1, 0]); wait_tensor_149 = None + mm_79 = torch.ops.aten.mm.default(view_807, permute_123); view_807 = permute_123 = None + view_822 = torch.ops.aten.view.default(mm_79, [2, 8192, 128]) + view_824 = torch.ops.aten.view.default(view_808, [2, 8192, -1, 128]); view_808 = None + view_825 = torch.ops.aten.view.default(view_815, [2, 8192, -1, 128]); view_815 = None + view_826 = torch.ops.aten.view.default(view_822, [2, 8192, -1, 128]); view_822 = None + convert_element_type_376 = torch.ops.prims.convert_element_type.default(view_824, torch.float32); view_824 = None + view_827 = torch.ops.aten.view.default(convert_element_type_376, [2, 8192, 4, -1, 2]); convert_element_type_376 = None + view_as_complex_22 = torch.ops.aten.view_as_complex.default(view_827); view_827 = None + convert_element_type_377 = torch.ops.prims.convert_element_type.default(view_825, torch.float32); view_825 = None + view_828 = 
torch.ops.aten.view.default(convert_element_type_377, [2, 8192, 1, -1, 2]); convert_element_type_377 = None + view_as_complex_23 = torch.ops.aten.view_as_complex.default(view_828); view_828 = None + mul_90 = torch.ops.aten.mul.Tensor(view_as_complex_22, view_37); view_as_complex_22 = None + view_as_real_22 = torch.ops.aten.view_as_real.default(mul_90); mul_90 = None + view_830 = torch.ops.aten.view.default(view_as_real_22, [2, 8192, 4, 128]); view_as_real_22 = None + mul_91 = torch.ops.aten.mul.Tensor(view_as_complex_23, view_37); view_as_complex_23 = None + view_as_real_23 = torch.ops.aten.view_as_real.default(mul_91); mul_91 = None + view_831 = torch.ops.aten.view.default(view_as_real_23, [2, 8192, 1, 128]); view_as_real_23 = None + convert_element_type_378 = torch.ops.prims.convert_element_type.default(view_830, torch.bfloat16); view_830 = None + convert_element_type_379 = torch.ops.prims.convert_element_type.default(view_831, torch.bfloat16); view_831 = None + unsqueeze_22 = torch.ops.aten.unsqueeze.default(convert_element_type_379, 3); convert_element_type_379 = None + expand_22 = torch.ops.aten.expand.default(unsqueeze_22, [2, 8192, 1, 4, 128]); unsqueeze_22 = None + view_832 = torch.ops.aten.view.default(expand_22, [2, 8192, 4, 128]); expand_22 = None + unsqueeze_23 = torch.ops.aten.unsqueeze.default(view_826, 3); view_826 = None + expand_23 = torch.ops.aten.expand.default(unsqueeze_23, [2, 8192, 1, 4, 128]); unsqueeze_23 = None + view_833 = torch.ops.aten.view.default(expand_23, [2, 8192, 4, 128]); expand_23 = None + permute_124 = torch.ops.aten.permute.default(convert_element_type_378, [0, 2, 1, 3]); convert_element_type_378 = None + permute_125 = torch.ops.aten.permute.default(view_832, [0, 2, 1, 3]); view_832 = None + permute_126 = torch.ops.aten.permute.default(view_833, [0, 2, 1, 3]); view_833 = None + _scaled_dot_product_cudnn_attention_11 = torch.ops.aten._scaled_dot_product_cudnn_attention.default(permute_124, permute_125, permute_126, None, True, 
0.0, True); permute_124 = permute_125 = permute_126 = None + getitem_531 = _scaled_dot_product_cudnn_attention_11[0] + getitem_532 = _scaled_dot_product_cudnn_attention_11[1] + getitem_537 = _scaled_dot_product_cudnn_attention_11[6] + getitem_538 = _scaled_dot_product_cudnn_attention_11[7]; _scaled_dot_product_cudnn_attention_11 = None + permute_127 = torch.ops.aten.permute.default(getitem_531, [0, 2, 1, 3]) + view_834 = torch.ops.aten.view.default(permute_127, [2, 8192, -1]); permute_127 = None + convert_element_type_380 = torch.ops.prims.convert_element_type.default(primals_107, torch.bfloat16) + all_gather_into_tensor_127 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_380, 8, '0'); convert_element_type_380 = None + wait_tensor_150 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_127); all_gather_into_tensor_127 = None + permute_128 = torch.ops.aten.permute.default(wait_tensor_150, [1, 0]); wait_tensor_150 = None + view_840 = torch.ops.aten.view.default(view_834, [16384, 512]); view_834 = None + mm_80 = torch.ops.aten.mm.default(view_840, permute_128); view_840 = permute_128 = None + view_841 = torch.ops.aten.view.default(mm_80, [2, 8192, 4096]); mm_80 = None + split_54 = torch.ops.aten.split.Tensor(view_841, 1024, 1); view_841 = None + getitem_540 = split_54[0] + getitem_541 = split_54[1] + getitem_542 = split_54[2] + getitem_543 = split_54[3] + getitem_544 = split_54[4] + getitem_545 = split_54[5] + getitem_546 = split_54[6] + getitem_547 = split_54[7]; split_54 = None + cat_46 = torch.ops.aten.cat.default([getitem_540, getitem_541, getitem_542, getitem_543, getitem_544, getitem_545, getitem_546, getitem_547]); getitem_540 = getitem_541 = getitem_542 = getitem_543 = getitem_544 = getitem_545 = getitem_546 = getitem_547 = None + reduce_scatter_tensor_23 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_46, 'sum', 8, '1'); cat_46 = None + wait_tensor_151 = 
torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_23) + add_45 = torch.ops.aten.add.Tensor(add_43, wait_tensor_151); wait_tensor_151 = None + convert_element_type_383 = torch.ops.prims.convert_element_type.default(primals_108, torch.bfloat16) + all_gather_into_tensor_128 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_383, 8, '0'); convert_element_type_383 = None + wait_tensor_152 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_128); all_gather_into_tensor_128 = None + convert_element_type_384 = torch.ops.prims.convert_element_type.default(add_45, torch.float32) + pow_24 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_384, 2) + mean_23 = torch.ops.aten.mean.dim(pow_24, [2], True); pow_24 = None + add_46 = torch.ops.aten.add.Scalar(mean_23, 1e-05); mean_23 = None + rsqrt_23 = torch.ops.aten.rsqrt.default(add_46); add_46 = None + mul_92 = torch.ops.aten.mul.Tensor(convert_element_type_384, rsqrt_23); convert_element_type_384 = rsqrt_23 = None + mul_93 = torch.ops.aten.mul.Tensor(mul_92, wait_tensor_152); mul_92 = wait_tensor_152 = None + convert_element_type_385 = torch.ops.prims.convert_element_type.default(mul_93, torch.bfloat16); mul_93 = None + all_gather_into_tensor_129 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_385, 8, '1'); convert_element_type_385 = None + wait_tensor_153 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_129); all_gather_into_tensor_129 = None + split_55 = torch.ops.aten.split.Tensor(wait_tensor_153, 2); wait_tensor_153 = None + getitem_548 = split_55[0] + getitem_549 = split_55[1] + getitem_550 = split_55[2] + getitem_551 = split_55[3] + getitem_552 = split_55[4] + getitem_553 = split_55[5] + getitem_554 = split_55[6] + getitem_555 = split_55[7]; split_55 = None + cat_47 = torch.ops.aten.cat.default([getitem_548, getitem_549, getitem_550, getitem_551, getitem_552, getitem_553, getitem_554, 
getitem_555], 1); getitem_548 = getitem_549 = getitem_550 = getitem_551 = getitem_552 = getitem_553 = getitem_554 = getitem_555 = None + convert_element_type_386 = torch.ops.prims.convert_element_type.default(primals_109, torch.bfloat16) + all_gather_into_tensor_130 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_386, 8, '0'); convert_element_type_386 = None + wait_tensor_154 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_130); all_gather_into_tensor_130 = None + permute_129 = torch.ops.aten.permute.default(wait_tensor_154, [1, 0]); wait_tensor_154 = None + view_852 = torch.ops.aten.view.default(cat_47, [16384, 4096]); cat_47 = None + mm_81 = torch.ops.aten.mm.default(view_852, permute_129); permute_129 = None + view_853 = torch.ops.aten.view.default(mm_81, [2, 8192, 1792]) + convert_element_type_389 = torch.ops.prims.convert_element_type.default(view_853, torch.float32); view_853 = None + sigmoid_11 = torch.ops.aten.sigmoid.default(convert_element_type_389) + mul_94 = torch.ops.aten.mul.Tensor(convert_element_type_389, sigmoid_11); convert_element_type_389 = sigmoid_11 = None + convert_element_type_390 = torch.ops.prims.convert_element_type.default(mul_94, torch.bfloat16); mul_94 = None + convert_element_type_391 = torch.ops.prims.convert_element_type.default(primals_110, torch.bfloat16) + all_gather_into_tensor_131 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_391, 8, '0'); convert_element_type_391 = None + wait_tensor_155 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_131); all_gather_into_tensor_131 = None + permute_130 = torch.ops.aten.permute.default(wait_tensor_155, [1, 0]); wait_tensor_155 = None + mm_82 = torch.ops.aten.mm.default(view_852, permute_130); view_852 = permute_130 = None + view_860 = torch.ops.aten.view.default(mm_82, [2, 8192, 1792]); mm_82 = None + mul_95 = torch.ops.aten.mul.Tensor(convert_element_type_390, view_860); 
convert_element_type_390 = view_860 = None + convert_element_type_394 = torch.ops.prims.convert_element_type.default(primals_111, torch.bfloat16) + all_gather_into_tensor_132 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_394, 8, '0'); convert_element_type_394 = None + wait_tensor_156 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_132); all_gather_into_tensor_132 = None + permute_131 = torch.ops.aten.permute.default(wait_tensor_156, [1, 0]); wait_tensor_156 = None + view_867 = torch.ops.aten.view.default(mul_95, [16384, 1792]); mul_95 = None + mm_83 = torch.ops.aten.mm.default(view_867, permute_131); view_867 = permute_131 = None + view_868 = torch.ops.aten.view.default(mm_83, [2, 8192, 4096]); mm_83 = None + split_56 = torch.ops.aten.split.Tensor(view_868, 1024, 1); view_868 = None + getitem_556 = split_56[0] + getitem_557 = split_56[1] + getitem_558 = split_56[2] + getitem_559 = split_56[3] + getitem_560 = split_56[4] + getitem_561 = split_56[5] + getitem_562 = split_56[6] + getitem_563 = split_56[7]; split_56 = None + cat_48 = torch.ops.aten.cat.default([getitem_556, getitem_557, getitem_558, getitem_559, getitem_560, getitem_561, getitem_562, getitem_563]); getitem_556 = getitem_557 = getitem_558 = getitem_559 = getitem_560 = getitem_561 = getitem_562 = getitem_563 = None + reduce_scatter_tensor_24 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_48, 'sum', 8, '1'); cat_48 = None + wait_tensor_157 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_24); reduce_scatter_tensor_24 = None + add_47 = torch.ops.aten.add.Tensor(add_45, wait_tensor_157); add_45 = wait_tensor_157 = None + convert_element_type_397 = torch.ops.prims.convert_element_type.default(primals_112, torch.bfloat16) + all_gather_into_tensor_133 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_397, 8, '0'); convert_element_type_397 = None + wait_tensor_158 = 
torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_133); all_gather_into_tensor_133 = None + convert_element_type_398 = torch.ops.prims.convert_element_type.default(add_47, torch.float32) + pow_25 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_398, 2) + mean_24 = torch.ops.aten.mean.dim(pow_25, [2], True); pow_25 = None + add_48 = torch.ops.aten.add.Scalar(mean_24, 1e-05); mean_24 = None + rsqrt_24 = torch.ops.aten.rsqrt.default(add_48); add_48 = None + mul_96 = torch.ops.aten.mul.Tensor(convert_element_type_398, rsqrt_24); convert_element_type_398 = rsqrt_24 = None + mul_97 = torch.ops.aten.mul.Tensor(mul_96, wait_tensor_158); mul_96 = wait_tensor_158 = None + convert_element_type_399 = torch.ops.prims.convert_element_type.default(mul_97, torch.bfloat16); mul_97 = None + all_gather_into_tensor_134 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_399, 8, '1'); convert_element_type_399 = None + wait_tensor_159 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_134); all_gather_into_tensor_134 = None + split_57 = torch.ops.aten.split.Tensor(wait_tensor_159, 2); wait_tensor_159 = None + getitem_564 = split_57[0] + getitem_565 = split_57[1] + getitem_566 = split_57[2] + getitem_567 = split_57[3] + getitem_568 = split_57[4] + getitem_569 = split_57[5] + getitem_570 = split_57[6] + getitem_571 = split_57[7]; split_57 = None + cat_49 = torch.ops.aten.cat.default([getitem_564, getitem_565, getitem_566, getitem_567, getitem_568, getitem_569, getitem_570, getitem_571], 1); getitem_564 = getitem_565 = getitem_566 = getitem_567 = getitem_568 = getitem_569 = getitem_570 = getitem_571 = None + convert_element_type_400 = torch.ops.prims.convert_element_type.default(primals_113, torch.bfloat16) + all_gather_into_tensor_135 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_400, 8, '0'); convert_element_type_400 = None + wait_tensor_160 = 
torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_135); all_gather_into_tensor_135 = None + permute_132 = torch.ops.aten.permute.default(wait_tensor_160, [1, 0]); wait_tensor_160 = None + view_879 = torch.ops.aten.view.default(cat_49, [16384, 4096]); cat_49 = None + mm_84 = torch.ops.aten.mm.default(view_879, permute_132); permute_132 = None + view_880 = torch.ops.aten.view.default(mm_84, [2, 8192, 512]) + convert_element_type_403 = torch.ops.prims.convert_element_type.default(primals_114, torch.bfloat16) + all_gather_into_tensor_136 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_403, 8, '0'); convert_element_type_403 = None + wait_tensor_161 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_136); all_gather_into_tensor_136 = None + permute_133 = torch.ops.aten.permute.default(wait_tensor_161, [1, 0]); wait_tensor_161 = None + mm_85 = torch.ops.aten.mm.default(view_879, permute_133); permute_133 = None + view_887 = torch.ops.aten.view.default(mm_85, [2, 8192, 128]); mm_85 = None + convert_element_type_406 = torch.ops.prims.convert_element_type.default(primals_115, torch.bfloat16) + all_gather_into_tensor_137 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_406, 8, '0'); convert_element_type_406 = None + wait_tensor_162 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_137); all_gather_into_tensor_137 = None + permute_134 = torch.ops.aten.permute.default(wait_tensor_162, [1, 0]); wait_tensor_162 = None + mm_86 = torch.ops.aten.mm.default(view_879, permute_134); view_879 = permute_134 = None + view_894 = torch.ops.aten.view.default(mm_86, [2, 8192, 128]) + view_896 = torch.ops.aten.view.default(view_880, [2, 8192, -1, 128]); view_880 = None + view_897 = torch.ops.aten.view.default(view_887, [2, 8192, -1, 128]); view_887 = None + view_898 = torch.ops.aten.view.default(view_894, [2, 8192, -1, 128]); view_894 = None + convert_element_type_409 
= torch.ops.prims.convert_element_type.default(view_896, torch.float32); view_896 = None + view_899 = torch.ops.aten.view.default(convert_element_type_409, [2, 8192, 4, -1, 2]); convert_element_type_409 = None + view_as_complex_24 = torch.ops.aten.view_as_complex.default(view_899); view_899 = None + convert_element_type_410 = torch.ops.prims.convert_element_type.default(view_897, torch.float32); view_897 = None + view_900 = torch.ops.aten.view.default(convert_element_type_410, [2, 8192, 1, -1, 2]); convert_element_type_410 = None + view_as_complex_25 = torch.ops.aten.view_as_complex.default(view_900); view_900 = None + mul_98 = torch.ops.aten.mul.Tensor(view_as_complex_24, view_37); view_as_complex_24 = None + view_as_real_24 = torch.ops.aten.view_as_real.default(mul_98); mul_98 = None + view_902 = torch.ops.aten.view.default(view_as_real_24, [2, 8192, 4, 128]); view_as_real_24 = None + mul_99 = torch.ops.aten.mul.Tensor(view_as_complex_25, view_37); view_as_complex_25 = None + view_as_real_25 = torch.ops.aten.view_as_real.default(mul_99); mul_99 = None + view_903 = torch.ops.aten.view.default(view_as_real_25, [2, 8192, 1, 128]); view_as_real_25 = None + convert_element_type_411 = torch.ops.prims.convert_element_type.default(view_902, torch.bfloat16); view_902 = None + convert_element_type_412 = torch.ops.prims.convert_element_type.default(view_903, torch.bfloat16); view_903 = None + unsqueeze_24 = torch.ops.aten.unsqueeze.default(convert_element_type_412, 3); convert_element_type_412 = None + expand_24 = torch.ops.aten.expand.default(unsqueeze_24, [2, 8192, 1, 4, 128]); unsqueeze_24 = None + view_904 = torch.ops.aten.view.default(expand_24, [2, 8192, 4, 128]); expand_24 = None + unsqueeze_25 = torch.ops.aten.unsqueeze.default(view_898, 3); view_898 = None + expand_25 = torch.ops.aten.expand.default(unsqueeze_25, [2, 8192, 1, 4, 128]); unsqueeze_25 = None + view_905 = torch.ops.aten.view.default(expand_25, [2, 8192, 4, 128]); expand_25 = None + permute_135 = 
torch.ops.aten.permute.default(convert_element_type_411, [0, 2, 1, 3]); convert_element_type_411 = None + permute_136 = torch.ops.aten.permute.default(view_904, [0, 2, 1, 3]); view_904 = None + permute_137 = torch.ops.aten.permute.default(view_905, [0, 2, 1, 3]); view_905 = None + _scaled_dot_product_cudnn_attention_12 = torch.ops.aten._scaled_dot_product_cudnn_attention.default(permute_135, permute_136, permute_137, None, True, 0.0, True); permute_135 = permute_136 = permute_137 = None + getitem_572 = _scaled_dot_product_cudnn_attention_12[0] + getitem_573 = _scaled_dot_product_cudnn_attention_12[1] + getitem_578 = _scaled_dot_product_cudnn_attention_12[6] + getitem_579 = _scaled_dot_product_cudnn_attention_12[7]; _scaled_dot_product_cudnn_attention_12 = None + permute_138 = torch.ops.aten.permute.default(getitem_572, [0, 2, 1, 3]) + view_906 = torch.ops.aten.view.default(permute_138, [2, 8192, -1]); permute_138 = None + convert_element_type_413 = torch.ops.prims.convert_element_type.default(primals_116, torch.bfloat16) + all_gather_into_tensor_138 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_413, 8, '0'); convert_element_type_413 = None + wait_tensor_163 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_138); all_gather_into_tensor_138 = None + permute_139 = torch.ops.aten.permute.default(wait_tensor_163, [1, 0]); wait_tensor_163 = None + view_912 = torch.ops.aten.view.default(view_906, [16384, 512]); view_906 = None + mm_87 = torch.ops.aten.mm.default(view_912, permute_139); view_912 = permute_139 = None + view_913 = torch.ops.aten.view.default(mm_87, [2, 8192, 4096]); mm_87 = None + split_58 = torch.ops.aten.split.Tensor(view_913, 1024, 1); view_913 = None + getitem_581 = split_58[0] + getitem_582 = split_58[1] + getitem_583 = split_58[2] + getitem_584 = split_58[3] + getitem_585 = split_58[4] + getitem_586 = split_58[5] + getitem_587 = split_58[6] + getitem_588 = split_58[7]; split_58 = None + cat_50 = 
torch.ops.aten.cat.default([getitem_581, getitem_582, getitem_583, getitem_584, getitem_585, getitem_586, getitem_587, getitem_588]); getitem_581 = getitem_582 = getitem_583 = getitem_584 = getitem_585 = getitem_586 = getitem_587 = getitem_588 = None + reduce_scatter_tensor_25 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_50, 'sum', 8, '1'); cat_50 = None + wait_tensor_164 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_25) + add_49 = torch.ops.aten.add.Tensor(add_47, wait_tensor_164); wait_tensor_164 = None + convert_element_type_416 = torch.ops.prims.convert_element_type.default(primals_117, torch.bfloat16) + all_gather_into_tensor_139 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_416, 8, '0'); convert_element_type_416 = None + wait_tensor_165 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_139); all_gather_into_tensor_139 = None + convert_element_type_417 = torch.ops.prims.convert_element_type.default(add_49, torch.float32) + pow_26 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_417, 2) + mean_25 = torch.ops.aten.mean.dim(pow_26, [2], True); pow_26 = None + add_50 = torch.ops.aten.add.Scalar(mean_25, 1e-05); mean_25 = None + rsqrt_25 = torch.ops.aten.rsqrt.default(add_50); add_50 = None + mul_100 = torch.ops.aten.mul.Tensor(convert_element_type_417, rsqrt_25); convert_element_type_417 = rsqrt_25 = None + mul_101 = torch.ops.aten.mul.Tensor(mul_100, wait_tensor_165); mul_100 = wait_tensor_165 = None + convert_element_type_418 = torch.ops.prims.convert_element_type.default(mul_101, torch.bfloat16); mul_101 = None + all_gather_into_tensor_140 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_418, 8, '1'); convert_element_type_418 = None + wait_tensor_166 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_140); all_gather_into_tensor_140 = None + split_59 = torch.ops.aten.split.Tensor(wait_tensor_166, 2); 
wait_tensor_166 = None + getitem_589 = split_59[0] + getitem_590 = split_59[1] + getitem_591 = split_59[2] + getitem_592 = split_59[3] + getitem_593 = split_59[4] + getitem_594 = split_59[5] + getitem_595 = split_59[6] + getitem_596 = split_59[7]; split_59 = None + cat_51 = torch.ops.aten.cat.default([getitem_589, getitem_590, getitem_591, getitem_592, getitem_593, getitem_594, getitem_595, getitem_596], 1); getitem_589 = getitem_590 = getitem_591 = getitem_592 = getitem_593 = getitem_594 = getitem_595 = getitem_596 = None + convert_element_type_419 = torch.ops.prims.convert_element_type.default(primals_118, torch.bfloat16) + all_gather_into_tensor_141 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_419, 8, '0'); convert_element_type_419 = None + wait_tensor_167 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_141); all_gather_into_tensor_141 = None + permute_140 = torch.ops.aten.permute.default(wait_tensor_167, [1, 0]); wait_tensor_167 = None + view_924 = torch.ops.aten.view.default(cat_51, [16384, 4096]); cat_51 = None + mm_88 = torch.ops.aten.mm.default(view_924, permute_140); permute_140 = None + view_925 = torch.ops.aten.view.default(mm_88, [2, 8192, 1792]) + convert_element_type_422 = torch.ops.prims.convert_element_type.default(view_925, torch.float32); view_925 = None + sigmoid_12 = torch.ops.aten.sigmoid.default(convert_element_type_422) + mul_102 = torch.ops.aten.mul.Tensor(convert_element_type_422, sigmoid_12); convert_element_type_422 = sigmoid_12 = None + convert_element_type_423 = torch.ops.prims.convert_element_type.default(mul_102, torch.bfloat16); mul_102 = None + convert_element_type_424 = torch.ops.prims.convert_element_type.default(primals_119, torch.bfloat16) + all_gather_into_tensor_142 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_424, 8, '0'); convert_element_type_424 = None + wait_tensor_168 = 
torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_142); all_gather_into_tensor_142 = None + permute_141 = torch.ops.aten.permute.default(wait_tensor_168, [1, 0]); wait_tensor_168 = None + mm_89 = torch.ops.aten.mm.default(view_924, permute_141); view_924 = permute_141 = None + view_932 = torch.ops.aten.view.default(mm_89, [2, 8192, 1792]); mm_89 = None + mul_103 = torch.ops.aten.mul.Tensor(convert_element_type_423, view_932); convert_element_type_423 = view_932 = None + convert_element_type_427 = torch.ops.prims.convert_element_type.default(primals_120, torch.bfloat16) + all_gather_into_tensor_143 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_427, 8, '0'); convert_element_type_427 = None + wait_tensor_169 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_143); all_gather_into_tensor_143 = None + permute_142 = torch.ops.aten.permute.default(wait_tensor_169, [1, 0]); wait_tensor_169 = None + view_939 = torch.ops.aten.view.default(mul_103, [16384, 1792]); mul_103 = None + mm_90 = torch.ops.aten.mm.default(view_939, permute_142); view_939 = permute_142 = None + view_940 = torch.ops.aten.view.default(mm_90, [2, 8192, 4096]); mm_90 = None + split_60 = torch.ops.aten.split.Tensor(view_940, 1024, 1); view_940 = None + getitem_597 = split_60[0] + getitem_598 = split_60[1] + getitem_599 = split_60[2] + getitem_600 = split_60[3] + getitem_601 = split_60[4] + getitem_602 = split_60[5] + getitem_603 = split_60[6] + getitem_604 = split_60[7]; split_60 = None + cat_52 = torch.ops.aten.cat.default([getitem_597, getitem_598, getitem_599, getitem_600, getitem_601, getitem_602, getitem_603, getitem_604]); getitem_597 = getitem_598 = getitem_599 = getitem_600 = getitem_601 = getitem_602 = getitem_603 = getitem_604 = None + reduce_scatter_tensor_26 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_52, 'sum', 8, '1'); cat_52 = None + wait_tensor_170 = 
torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_26); reduce_scatter_tensor_26 = None + add_51 = torch.ops.aten.add.Tensor(add_49, wait_tensor_170); add_49 = wait_tensor_170 = None + convert_element_type_430 = torch.ops.prims.convert_element_type.default(primals_121, torch.bfloat16) + all_gather_into_tensor_144 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_430, 8, '0'); convert_element_type_430 = None + wait_tensor_171 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_144); all_gather_into_tensor_144 = None + convert_element_type_431 = torch.ops.prims.convert_element_type.default(add_51, torch.float32) + pow_27 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_431, 2) + mean_26 = torch.ops.aten.mean.dim(pow_27, [2], True); pow_27 = None + add_52 = torch.ops.aten.add.Scalar(mean_26, 1e-05); mean_26 = None + rsqrt_26 = torch.ops.aten.rsqrt.default(add_52); add_52 = None + mul_104 = torch.ops.aten.mul.Tensor(convert_element_type_431, rsqrt_26); convert_element_type_431 = rsqrt_26 = None + mul_105 = torch.ops.aten.mul.Tensor(mul_104, wait_tensor_171); mul_104 = wait_tensor_171 = None + convert_element_type_432 = torch.ops.prims.convert_element_type.default(mul_105, torch.bfloat16); mul_105 = None + all_gather_into_tensor_145 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_432, 8, '1'); convert_element_type_432 = None + wait_tensor_172 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_145); all_gather_into_tensor_145 = None + split_61 = torch.ops.aten.split.Tensor(wait_tensor_172, 2); wait_tensor_172 = None + getitem_605 = split_61[0] + getitem_606 = split_61[1] + getitem_607 = split_61[2] + getitem_608 = split_61[3] + getitem_609 = split_61[4] + getitem_610 = split_61[5] + getitem_611 = split_61[6] + getitem_612 = split_61[7]; split_61 = None + cat_53 = torch.ops.aten.cat.default([getitem_605, getitem_606, getitem_607, getitem_608, 
getitem_609, getitem_610, getitem_611, getitem_612], 1); getitem_605 = getitem_606 = getitem_607 = getitem_608 = getitem_609 = getitem_610 = getitem_611 = getitem_612 = None + convert_element_type_433 = torch.ops.prims.convert_element_type.default(primals_122, torch.bfloat16) + all_gather_into_tensor_146 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_433, 8, '0'); convert_element_type_433 = None + wait_tensor_173 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_146); all_gather_into_tensor_146 = None + permute_143 = torch.ops.aten.permute.default(wait_tensor_173, [1, 0]); wait_tensor_173 = None + view_951 = torch.ops.aten.view.default(cat_53, [16384, 4096]); cat_53 = None + mm_91 = torch.ops.aten.mm.default(view_951, permute_143); permute_143 = None + view_952 = torch.ops.aten.view.default(mm_91, [2, 8192, 512]) + convert_element_type_436 = torch.ops.prims.convert_element_type.default(primals_123, torch.bfloat16) + all_gather_into_tensor_147 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_436, 8, '0'); convert_element_type_436 = None + wait_tensor_174 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_147); all_gather_into_tensor_147 = None + permute_144 = torch.ops.aten.permute.default(wait_tensor_174, [1, 0]); wait_tensor_174 = None + mm_92 = torch.ops.aten.mm.default(view_951, permute_144); permute_144 = None + view_959 = torch.ops.aten.view.default(mm_92, [2, 8192, 128]); mm_92 = None + convert_element_type_439 = torch.ops.prims.convert_element_type.default(primals_124, torch.bfloat16) + all_gather_into_tensor_148 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_439, 8, '0'); convert_element_type_439 = None + wait_tensor_175 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_148); all_gather_into_tensor_148 = None + permute_145 = torch.ops.aten.permute.default(wait_tensor_175, [1, 0]); 
wait_tensor_175 = None + mm_93 = torch.ops.aten.mm.default(view_951, permute_145); view_951 = permute_145 = None + view_966 = torch.ops.aten.view.default(mm_93, [2, 8192, 128]) + view_968 = torch.ops.aten.view.default(view_952, [2, 8192, -1, 128]); view_952 = None + view_969 = torch.ops.aten.view.default(view_959, [2, 8192, -1, 128]); view_959 = None + view_970 = torch.ops.aten.view.default(view_966, [2, 8192, -1, 128]); view_966 = None + convert_element_type_442 = torch.ops.prims.convert_element_type.default(view_968, torch.float32); view_968 = None + view_971 = torch.ops.aten.view.default(convert_element_type_442, [2, 8192, 4, -1, 2]); convert_element_type_442 = None + view_as_complex_26 = torch.ops.aten.view_as_complex.default(view_971); view_971 = None + convert_element_type_443 = torch.ops.prims.convert_element_type.default(view_969, torch.float32); view_969 = None + view_972 = torch.ops.aten.view.default(convert_element_type_443, [2, 8192, 1, -1, 2]); convert_element_type_443 = None + view_as_complex_27 = torch.ops.aten.view_as_complex.default(view_972); view_972 = None + mul_106 = torch.ops.aten.mul.Tensor(view_as_complex_26, view_37); view_as_complex_26 = None + view_as_real_26 = torch.ops.aten.view_as_real.default(mul_106); mul_106 = None + view_974 = torch.ops.aten.view.default(view_as_real_26, [2, 8192, 4, 128]); view_as_real_26 = None + mul_107 = torch.ops.aten.mul.Tensor(view_as_complex_27, view_37); view_as_complex_27 = None + view_as_real_27 = torch.ops.aten.view_as_real.default(mul_107); mul_107 = None + view_975 = torch.ops.aten.view.default(view_as_real_27, [2, 8192, 1, 128]); view_as_real_27 = None + convert_element_type_444 = torch.ops.prims.convert_element_type.default(view_974, torch.bfloat16); view_974 = None + convert_element_type_445 = torch.ops.prims.convert_element_type.default(view_975, torch.bfloat16); view_975 = None + unsqueeze_26 = torch.ops.aten.unsqueeze.default(convert_element_type_445, 3); convert_element_type_445 = None + 
expand_26 = torch.ops.aten.expand.default(unsqueeze_26, [2, 8192, 1, 4, 128]); unsqueeze_26 = None + view_976 = torch.ops.aten.view.default(expand_26, [2, 8192, 4, 128]); expand_26 = None + unsqueeze_27 = torch.ops.aten.unsqueeze.default(view_970, 3); view_970 = None + expand_27 = torch.ops.aten.expand.default(unsqueeze_27, [2, 8192, 1, 4, 128]); unsqueeze_27 = None + view_977 = torch.ops.aten.view.default(expand_27, [2, 8192, 4, 128]); expand_27 = None + permute_146 = torch.ops.aten.permute.default(convert_element_type_444, [0, 2, 1, 3]); convert_element_type_444 = None + permute_147 = torch.ops.aten.permute.default(view_976, [0, 2, 1, 3]); view_976 = None + permute_148 = torch.ops.aten.permute.default(view_977, [0, 2, 1, 3]); view_977 = None + _scaled_dot_product_cudnn_attention_13 = torch.ops.aten._scaled_dot_product_cudnn_attention.default(permute_146, permute_147, permute_148, None, True, 0.0, True); permute_146 = permute_147 = permute_148 = None + getitem_613 = _scaled_dot_product_cudnn_attention_13[0] + getitem_614 = _scaled_dot_product_cudnn_attention_13[1] + getitem_619 = _scaled_dot_product_cudnn_attention_13[6] + getitem_620 = _scaled_dot_product_cudnn_attention_13[7]; _scaled_dot_product_cudnn_attention_13 = None + permute_149 = torch.ops.aten.permute.default(getitem_613, [0, 2, 1, 3]) + view_978 = torch.ops.aten.view.default(permute_149, [2, 8192, -1]); permute_149 = None + convert_element_type_446 = torch.ops.prims.convert_element_type.default(primals_125, torch.bfloat16) + all_gather_into_tensor_149 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_446, 8, '0'); convert_element_type_446 = None + wait_tensor_176 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_149); all_gather_into_tensor_149 = None + permute_150 = torch.ops.aten.permute.default(wait_tensor_176, [1, 0]); wait_tensor_176 = None + view_984 = torch.ops.aten.view.default(view_978, [16384, 512]); view_978 = None + mm_94 = 
torch.ops.aten.mm.default(view_984, permute_150); view_984 = permute_150 = None + view_985 = torch.ops.aten.view.default(mm_94, [2, 8192, 4096]); mm_94 = None + split_62 = torch.ops.aten.split.Tensor(view_985, 1024, 1); view_985 = None + getitem_622 = split_62[0] + getitem_623 = split_62[1] + getitem_624 = split_62[2] + getitem_625 = split_62[3] + getitem_626 = split_62[4] + getitem_627 = split_62[5] + getitem_628 = split_62[6] + getitem_629 = split_62[7]; split_62 = None + cat_54 = torch.ops.aten.cat.default([getitem_622, getitem_623, getitem_624, getitem_625, getitem_626, getitem_627, getitem_628, getitem_629]); getitem_622 = getitem_623 = getitem_624 = getitem_625 = getitem_626 = getitem_627 = getitem_628 = getitem_629 = None + reduce_scatter_tensor_27 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_54, 'sum', 8, '1'); cat_54 = None + wait_tensor_177 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_27) + add_53 = torch.ops.aten.add.Tensor(add_51, wait_tensor_177); wait_tensor_177 = None + convert_element_type_449 = torch.ops.prims.convert_element_type.default(primals_126, torch.bfloat16) + all_gather_into_tensor_150 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_449, 8, '0'); convert_element_type_449 = None + wait_tensor_178 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_150); all_gather_into_tensor_150 = None + convert_element_type_450 = torch.ops.prims.convert_element_type.default(add_53, torch.float32) + pow_28 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_450, 2) + mean_27 = torch.ops.aten.mean.dim(pow_28, [2], True); pow_28 = None + add_54 = torch.ops.aten.add.Scalar(mean_27, 1e-05); mean_27 = None + rsqrt_27 = torch.ops.aten.rsqrt.default(add_54); add_54 = None + mul_108 = torch.ops.aten.mul.Tensor(convert_element_type_450, rsqrt_27); convert_element_type_450 = rsqrt_27 = None + mul_109 = torch.ops.aten.mul.Tensor(mul_108, wait_tensor_178); mul_108 = 
wait_tensor_178 = None + convert_element_type_451 = torch.ops.prims.convert_element_type.default(mul_109, torch.bfloat16); mul_109 = None + all_gather_into_tensor_151 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_451, 8, '1'); convert_element_type_451 = None + wait_tensor_179 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_151); all_gather_into_tensor_151 = None + split_63 = torch.ops.aten.split.Tensor(wait_tensor_179, 2); wait_tensor_179 = None + getitem_630 = split_63[0] + getitem_631 = split_63[1] + getitem_632 = split_63[2] + getitem_633 = split_63[3] + getitem_634 = split_63[4] + getitem_635 = split_63[5] + getitem_636 = split_63[6] + getitem_637 = split_63[7]; split_63 = None + cat_55 = torch.ops.aten.cat.default([getitem_630, getitem_631, getitem_632, getitem_633, getitem_634, getitem_635, getitem_636, getitem_637], 1); getitem_630 = getitem_631 = getitem_632 = getitem_633 = getitem_634 = getitem_635 = getitem_636 = getitem_637 = None + convert_element_type_452 = torch.ops.prims.convert_element_type.default(primals_127, torch.bfloat16) + all_gather_into_tensor_152 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_452, 8, '0'); convert_element_type_452 = None + wait_tensor_180 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_152); all_gather_into_tensor_152 = None + permute_151 = torch.ops.aten.permute.default(wait_tensor_180, [1, 0]); wait_tensor_180 = None + view_996 = torch.ops.aten.view.default(cat_55, [16384, 4096]); cat_55 = None + mm_95 = torch.ops.aten.mm.default(view_996, permute_151); permute_151 = None + view_997 = torch.ops.aten.view.default(mm_95, [2, 8192, 1792]) + convert_element_type_455 = torch.ops.prims.convert_element_type.default(view_997, torch.float32); view_997 = None + sigmoid_13 = torch.ops.aten.sigmoid.default(convert_element_type_455) + mul_110 = torch.ops.aten.mul.Tensor(convert_element_type_455, sigmoid_13); 
convert_element_type_455 = sigmoid_13 = None + convert_element_type_456 = torch.ops.prims.convert_element_type.default(mul_110, torch.bfloat16); mul_110 = None + convert_element_type_457 = torch.ops.prims.convert_element_type.default(primals_128, torch.bfloat16) + all_gather_into_tensor_153 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_457, 8, '0'); convert_element_type_457 = None + wait_tensor_181 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_153); all_gather_into_tensor_153 = None + permute_152 = torch.ops.aten.permute.default(wait_tensor_181, [1, 0]); wait_tensor_181 = None + mm_96 = torch.ops.aten.mm.default(view_996, permute_152); view_996 = permute_152 = None + view_1004 = torch.ops.aten.view.default(mm_96, [2, 8192, 1792]); mm_96 = None + mul_111 = torch.ops.aten.mul.Tensor(convert_element_type_456, view_1004); convert_element_type_456 = view_1004 = None + convert_element_type_460 = torch.ops.prims.convert_element_type.default(primals_129, torch.bfloat16) + all_gather_into_tensor_154 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_460, 8, '0'); convert_element_type_460 = None + wait_tensor_182 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_154); all_gather_into_tensor_154 = None + permute_153 = torch.ops.aten.permute.default(wait_tensor_182, [1, 0]); wait_tensor_182 = None + view_1011 = torch.ops.aten.view.default(mul_111, [16384, 1792]); mul_111 = None + mm_97 = torch.ops.aten.mm.default(view_1011, permute_153); view_1011 = permute_153 = None + view_1012 = torch.ops.aten.view.default(mm_97, [2, 8192, 4096]); mm_97 = None + split_64 = torch.ops.aten.split.Tensor(view_1012, 1024, 1); view_1012 = None + getitem_638 = split_64[0] + getitem_639 = split_64[1] + getitem_640 = split_64[2] + getitem_641 = split_64[3] + getitem_642 = split_64[4] + getitem_643 = split_64[5] + getitem_644 = split_64[6] + getitem_645 = split_64[7]; split_64 = None + 
cat_56 = torch.ops.aten.cat.default([getitem_638, getitem_639, getitem_640, getitem_641, getitem_642, getitem_643, getitem_644, getitem_645]); getitem_638 = getitem_639 = getitem_640 = getitem_641 = getitem_642 = getitem_643 = getitem_644 = getitem_645 = None + reduce_scatter_tensor_28 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_56, 'sum', 8, '1'); cat_56 = None + wait_tensor_183 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_28); reduce_scatter_tensor_28 = None + add_55 = torch.ops.aten.add.Tensor(add_53, wait_tensor_183); add_53 = wait_tensor_183 = None + convert_element_type_463 = torch.ops.prims.convert_element_type.default(primals_130, torch.bfloat16) + all_gather_into_tensor_155 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_463, 8, '0'); convert_element_type_463 = None + wait_tensor_184 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_155); all_gather_into_tensor_155 = None + convert_element_type_464 = torch.ops.prims.convert_element_type.default(add_55, torch.float32) + pow_29 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_464, 2) + mean_28 = torch.ops.aten.mean.dim(pow_29, [2], True); pow_29 = None + add_56 = torch.ops.aten.add.Scalar(mean_28, 1e-05); mean_28 = None + rsqrt_28 = torch.ops.aten.rsqrt.default(add_56); add_56 = None + mul_112 = torch.ops.aten.mul.Tensor(convert_element_type_464, rsqrt_28); convert_element_type_464 = rsqrt_28 = None + mul_113 = torch.ops.aten.mul.Tensor(mul_112, wait_tensor_184); mul_112 = wait_tensor_184 = None + convert_element_type_465 = torch.ops.prims.convert_element_type.default(mul_113, torch.bfloat16); mul_113 = None + all_gather_into_tensor_156 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_465, 8, '1'); convert_element_type_465 = None + wait_tensor_185 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_156); all_gather_into_tensor_156 = None + split_65 = 
torch.ops.aten.split.Tensor(wait_tensor_185, 2); wait_tensor_185 = None + getitem_646 = split_65[0] + getitem_647 = split_65[1] + getitem_648 = split_65[2] + getitem_649 = split_65[3] + getitem_650 = split_65[4] + getitem_651 = split_65[5] + getitem_652 = split_65[6] + getitem_653 = split_65[7]; split_65 = None + cat_57 = torch.ops.aten.cat.default([getitem_646, getitem_647, getitem_648, getitem_649, getitem_650, getitem_651, getitem_652, getitem_653], 1); getitem_646 = getitem_647 = getitem_648 = getitem_649 = getitem_650 = getitem_651 = getitem_652 = getitem_653 = None + convert_element_type_466 = torch.ops.prims.convert_element_type.default(primals_131, torch.bfloat16) + all_gather_into_tensor_157 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_466, 8, '0'); convert_element_type_466 = None + wait_tensor_186 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_157); all_gather_into_tensor_157 = None + permute_154 = torch.ops.aten.permute.default(wait_tensor_186, [1, 0]); wait_tensor_186 = None + view_1023 = torch.ops.aten.view.default(cat_57, [16384, 4096]); cat_57 = None + mm_98 = torch.ops.aten.mm.default(view_1023, permute_154); permute_154 = None + view_1024 = torch.ops.aten.view.default(mm_98, [2, 8192, 512]) + convert_element_type_469 = torch.ops.prims.convert_element_type.default(primals_132, torch.bfloat16) + all_gather_into_tensor_158 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_469, 8, '0'); convert_element_type_469 = None + wait_tensor_187 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_158); all_gather_into_tensor_158 = None + permute_155 = torch.ops.aten.permute.default(wait_tensor_187, [1, 0]); wait_tensor_187 = None + mm_99 = torch.ops.aten.mm.default(view_1023, permute_155); permute_155 = None + view_1031 = torch.ops.aten.view.default(mm_99, [2, 8192, 128]); mm_99 = None + convert_element_type_472 = 
torch.ops.prims.convert_element_type.default(primals_133, torch.bfloat16) + all_gather_into_tensor_159 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_472, 8, '0'); convert_element_type_472 = None + wait_tensor_188 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_159); all_gather_into_tensor_159 = None + permute_156 = torch.ops.aten.permute.default(wait_tensor_188, [1, 0]); wait_tensor_188 = None + mm_100 = torch.ops.aten.mm.default(view_1023, permute_156); view_1023 = permute_156 = None + view_1038 = torch.ops.aten.view.default(mm_100, [2, 8192, 128]) + view_1040 = torch.ops.aten.view.default(view_1024, [2, 8192, -1, 128]); view_1024 = None + view_1041 = torch.ops.aten.view.default(view_1031, [2, 8192, -1, 128]); view_1031 = None + view_1042 = torch.ops.aten.view.default(view_1038, [2, 8192, -1, 128]); view_1038 = None + convert_element_type_475 = torch.ops.prims.convert_element_type.default(view_1040, torch.float32); view_1040 = None + view_1043 = torch.ops.aten.view.default(convert_element_type_475, [2, 8192, 4, -1, 2]); convert_element_type_475 = None + view_as_complex_28 = torch.ops.aten.view_as_complex.default(view_1043); view_1043 = None + convert_element_type_476 = torch.ops.prims.convert_element_type.default(view_1041, torch.float32); view_1041 = None + view_1044 = torch.ops.aten.view.default(convert_element_type_476, [2, 8192, 1, -1, 2]); convert_element_type_476 = None + view_as_complex_29 = torch.ops.aten.view_as_complex.default(view_1044); view_1044 = None + mul_114 = torch.ops.aten.mul.Tensor(view_as_complex_28, view_37); view_as_complex_28 = None + view_as_real_28 = torch.ops.aten.view_as_real.default(mul_114); mul_114 = None + view_1046 = torch.ops.aten.view.default(view_as_real_28, [2, 8192, 4, 128]); view_as_real_28 = None + mul_115 = torch.ops.aten.mul.Tensor(view_as_complex_29, view_37); view_as_complex_29 = None + view_as_real_29 = torch.ops.aten.view_as_real.default(mul_115); mul_115 = 
None + view_1047 = torch.ops.aten.view.default(view_as_real_29, [2, 8192, 1, 128]); view_as_real_29 = None + convert_element_type_477 = torch.ops.prims.convert_element_type.default(view_1046, torch.bfloat16); view_1046 = None + convert_element_type_478 = torch.ops.prims.convert_element_type.default(view_1047, torch.bfloat16); view_1047 = None + unsqueeze_28 = torch.ops.aten.unsqueeze.default(convert_element_type_478, 3); convert_element_type_478 = None + expand_28 = torch.ops.aten.expand.default(unsqueeze_28, [2, 8192, 1, 4, 128]); unsqueeze_28 = None + view_1048 = torch.ops.aten.view.default(expand_28, [2, 8192, 4, 128]); expand_28 = None + unsqueeze_29 = torch.ops.aten.unsqueeze.default(view_1042, 3); view_1042 = None + expand_29 = torch.ops.aten.expand.default(unsqueeze_29, [2, 8192, 1, 4, 128]); unsqueeze_29 = None + view_1049 = torch.ops.aten.view.default(expand_29, [2, 8192, 4, 128]); expand_29 = None + permute_157 = torch.ops.aten.permute.default(convert_element_type_477, [0, 2, 1, 3]); convert_element_type_477 = None + permute_158 = torch.ops.aten.permute.default(view_1048, [0, 2, 1, 3]); view_1048 = None + permute_159 = torch.ops.aten.permute.default(view_1049, [0, 2, 1, 3]); view_1049 = None + _scaled_dot_product_cudnn_attention_14 = torch.ops.aten._scaled_dot_product_cudnn_attention.default(permute_157, permute_158, permute_159, None, True, 0.0, True); permute_157 = permute_158 = permute_159 = None + getitem_654 = _scaled_dot_product_cudnn_attention_14[0] + getitem_655 = _scaled_dot_product_cudnn_attention_14[1] + getitem_660 = _scaled_dot_product_cudnn_attention_14[6] + getitem_661 = _scaled_dot_product_cudnn_attention_14[7]; _scaled_dot_product_cudnn_attention_14 = None + permute_160 = torch.ops.aten.permute.default(getitem_654, [0, 2, 1, 3]) + view_1050 = torch.ops.aten.view.default(permute_160, [2, 8192, -1]); permute_160 = None + convert_element_type_479 = torch.ops.prims.convert_element_type.default(primals_134, torch.bfloat16) + 
all_gather_into_tensor_160 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_479, 8, '0'); convert_element_type_479 = None + wait_tensor_189 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_160); all_gather_into_tensor_160 = None + permute_161 = torch.ops.aten.permute.default(wait_tensor_189, [1, 0]); wait_tensor_189 = None + view_1056 = torch.ops.aten.view.default(view_1050, [16384, 512]); view_1050 = None + mm_101 = torch.ops.aten.mm.default(view_1056, permute_161); view_1056 = permute_161 = None + view_1057 = torch.ops.aten.view.default(mm_101, [2, 8192, 4096]); mm_101 = None + split_66 = torch.ops.aten.split.Tensor(view_1057, 1024, 1); view_1057 = None + getitem_663 = split_66[0] + getitem_664 = split_66[1] + getitem_665 = split_66[2] + getitem_666 = split_66[3] + getitem_667 = split_66[4] + getitem_668 = split_66[5] + getitem_669 = split_66[6] + getitem_670 = split_66[7]; split_66 = None + cat_58 = torch.ops.aten.cat.default([getitem_663, getitem_664, getitem_665, getitem_666, getitem_667, getitem_668, getitem_669, getitem_670]); getitem_663 = getitem_664 = getitem_665 = getitem_666 = getitem_667 = getitem_668 = getitem_669 = getitem_670 = None + reduce_scatter_tensor_29 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_58, 'sum', 8, '1'); cat_58 = None + wait_tensor_190 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_29) + add_57 = torch.ops.aten.add.Tensor(add_55, wait_tensor_190); wait_tensor_190 = None + convert_element_type_482 = torch.ops.prims.convert_element_type.default(primals_135, torch.bfloat16) + all_gather_into_tensor_161 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_482, 8, '0'); convert_element_type_482 = None + wait_tensor_191 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_161); all_gather_into_tensor_161 = None + convert_element_type_483 = torch.ops.prims.convert_element_type.default(add_57, 
torch.float32) + pow_30 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_483, 2) + mean_29 = torch.ops.aten.mean.dim(pow_30, [2], True); pow_30 = None + add_58 = torch.ops.aten.add.Scalar(mean_29, 1e-05); mean_29 = None + rsqrt_29 = torch.ops.aten.rsqrt.default(add_58); add_58 = None + mul_116 = torch.ops.aten.mul.Tensor(convert_element_type_483, rsqrt_29); convert_element_type_483 = rsqrt_29 = None + mul_117 = torch.ops.aten.mul.Tensor(mul_116, wait_tensor_191); mul_116 = wait_tensor_191 = None + convert_element_type_484 = torch.ops.prims.convert_element_type.default(mul_117, torch.bfloat16); mul_117 = None + all_gather_into_tensor_162 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_484, 8, '1'); convert_element_type_484 = None + wait_tensor_192 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_162); all_gather_into_tensor_162 = None + split_67 = torch.ops.aten.split.Tensor(wait_tensor_192, 2); wait_tensor_192 = None + getitem_671 = split_67[0] + getitem_672 = split_67[1] + getitem_673 = split_67[2] + getitem_674 = split_67[3] + getitem_675 = split_67[4] + getitem_676 = split_67[5] + getitem_677 = split_67[6] + getitem_678 = split_67[7]; split_67 = None + cat_59 = torch.ops.aten.cat.default([getitem_671, getitem_672, getitem_673, getitem_674, getitem_675, getitem_676, getitem_677, getitem_678], 1); getitem_671 = getitem_672 = getitem_673 = getitem_674 = getitem_675 = getitem_676 = getitem_677 = getitem_678 = None + convert_element_type_485 = torch.ops.prims.convert_element_type.default(primals_136, torch.bfloat16) + all_gather_into_tensor_163 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_485, 8, '0'); convert_element_type_485 = None + wait_tensor_193 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_163); all_gather_into_tensor_163 = None + permute_162 = torch.ops.aten.permute.default(wait_tensor_193, [1, 0]); wait_tensor_193 = None + view_1068 = 
torch.ops.aten.view.default(cat_59, [16384, 4096]); cat_59 = None + mm_102 = torch.ops.aten.mm.default(view_1068, permute_162); permute_162 = None + view_1069 = torch.ops.aten.view.default(mm_102, [2, 8192, 1792]) + convert_element_type_488 = torch.ops.prims.convert_element_type.default(view_1069, torch.float32); view_1069 = None + sigmoid_14 = torch.ops.aten.sigmoid.default(convert_element_type_488) + mul_118 = torch.ops.aten.mul.Tensor(convert_element_type_488, sigmoid_14); convert_element_type_488 = sigmoid_14 = None + convert_element_type_489 = torch.ops.prims.convert_element_type.default(mul_118, torch.bfloat16); mul_118 = None + convert_element_type_490 = torch.ops.prims.convert_element_type.default(primals_137, torch.bfloat16) + all_gather_into_tensor_164 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_490, 8, '0'); convert_element_type_490 = None + wait_tensor_194 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_164); all_gather_into_tensor_164 = None + permute_163 = torch.ops.aten.permute.default(wait_tensor_194, [1, 0]); wait_tensor_194 = None + mm_103 = torch.ops.aten.mm.default(view_1068, permute_163); view_1068 = permute_163 = None + view_1076 = torch.ops.aten.view.default(mm_103, [2, 8192, 1792]); mm_103 = None + mul_119 = torch.ops.aten.mul.Tensor(convert_element_type_489, view_1076); convert_element_type_489 = view_1076 = None + convert_element_type_493 = torch.ops.prims.convert_element_type.default(primals_138, torch.bfloat16) + all_gather_into_tensor_165 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_493, 8, '0'); convert_element_type_493 = None + wait_tensor_195 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_165); all_gather_into_tensor_165 = None + permute_164 = torch.ops.aten.permute.default(wait_tensor_195, [1, 0]); wait_tensor_195 = None + view_1083 = torch.ops.aten.view.default(mul_119, [16384, 1792]); mul_119 = None + mm_104 = 
torch.ops.aten.mm.default(view_1083, permute_164); view_1083 = permute_164 = None + view_1084 = torch.ops.aten.view.default(mm_104, [2, 8192, 4096]); mm_104 = None + split_68 = torch.ops.aten.split.Tensor(view_1084, 1024, 1); view_1084 = None + getitem_679 = split_68[0] + getitem_680 = split_68[1] + getitem_681 = split_68[2] + getitem_682 = split_68[3] + getitem_683 = split_68[4] + getitem_684 = split_68[5] + getitem_685 = split_68[6] + getitem_686 = split_68[7]; split_68 = None + cat_60 = torch.ops.aten.cat.default([getitem_679, getitem_680, getitem_681, getitem_682, getitem_683, getitem_684, getitem_685, getitem_686]); getitem_679 = getitem_680 = getitem_681 = getitem_682 = getitem_683 = getitem_684 = getitem_685 = getitem_686 = None + reduce_scatter_tensor_30 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_60, 'sum', 8, '1'); cat_60 = None + wait_tensor_196 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_30); reduce_scatter_tensor_30 = None + add_59 = torch.ops.aten.add.Tensor(add_57, wait_tensor_196); add_57 = wait_tensor_196 = None + convert_element_type_496 = torch.ops.prims.convert_element_type.default(primals_139, torch.bfloat16) + all_gather_into_tensor_166 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_496, 8, '0'); convert_element_type_496 = None + wait_tensor_197 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_166); all_gather_into_tensor_166 = None + convert_element_type_497 = torch.ops.prims.convert_element_type.default(add_59, torch.float32) + pow_31 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_497, 2) + mean_30 = torch.ops.aten.mean.dim(pow_31, [2], True); pow_31 = None + add_60 = torch.ops.aten.add.Scalar(mean_30, 1e-05); mean_30 = None + rsqrt_30 = torch.ops.aten.rsqrt.default(add_60); add_60 = None + mul_120 = torch.ops.aten.mul.Tensor(convert_element_type_497, rsqrt_30); convert_element_type_497 = rsqrt_30 = None + mul_121 = 
torch.ops.aten.mul.Tensor(mul_120, wait_tensor_197); mul_120 = wait_tensor_197 = None + convert_element_type_498 = torch.ops.prims.convert_element_type.default(mul_121, torch.bfloat16); mul_121 = None + all_gather_into_tensor_167 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_498, 8, '1'); convert_element_type_498 = None + wait_tensor_198 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_167); all_gather_into_tensor_167 = None + split_69 = torch.ops.aten.split.Tensor(wait_tensor_198, 2); wait_tensor_198 = None + getitem_687 = split_69[0] + getitem_688 = split_69[1] + getitem_689 = split_69[2] + getitem_690 = split_69[3] + getitem_691 = split_69[4] + getitem_692 = split_69[5] + getitem_693 = split_69[6] + getitem_694 = split_69[7]; split_69 = None + cat_61 = torch.ops.aten.cat.default([getitem_687, getitem_688, getitem_689, getitem_690, getitem_691, getitem_692, getitem_693, getitem_694], 1); getitem_687 = getitem_688 = getitem_689 = getitem_690 = getitem_691 = getitem_692 = getitem_693 = getitem_694 = None + convert_element_type_499 = torch.ops.prims.convert_element_type.default(primals_140, torch.bfloat16) + all_gather_into_tensor_168 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_499, 8, '0'); convert_element_type_499 = None + wait_tensor_199 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_168); all_gather_into_tensor_168 = None + permute_165 = torch.ops.aten.permute.default(wait_tensor_199, [1, 0]); wait_tensor_199 = None + view_1095 = torch.ops.aten.view.default(cat_61, [16384, 4096]); cat_61 = None + mm_105 = torch.ops.aten.mm.default(view_1095, permute_165); permute_165 = None + view_1096 = torch.ops.aten.view.default(mm_105, [2, 8192, 512]) + convert_element_type_502 = torch.ops.prims.convert_element_type.default(primals_141, torch.bfloat16) + all_gather_into_tensor_169 = 
torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_502, 8, '0'); convert_element_type_502 = None + wait_tensor_200 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_169); all_gather_into_tensor_169 = None + permute_166 = torch.ops.aten.permute.default(wait_tensor_200, [1, 0]); wait_tensor_200 = None + mm_106 = torch.ops.aten.mm.default(view_1095, permute_166); permute_166 = None + view_1103 = torch.ops.aten.view.default(mm_106, [2, 8192, 128]); mm_106 = None + convert_element_type_505 = torch.ops.prims.convert_element_type.default(primals_142, torch.bfloat16) + all_gather_into_tensor_170 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_505, 8, '0'); convert_element_type_505 = None + wait_tensor_201 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_170); all_gather_into_tensor_170 = None + permute_167 = torch.ops.aten.permute.default(wait_tensor_201, [1, 0]); wait_tensor_201 = None + mm_107 = torch.ops.aten.mm.default(view_1095, permute_167); view_1095 = permute_167 = None + view_1110 = torch.ops.aten.view.default(mm_107, [2, 8192, 128]) + view_1112 = torch.ops.aten.view.default(view_1096, [2, 8192, -1, 128]); view_1096 = None + view_1113 = torch.ops.aten.view.default(view_1103, [2, 8192, -1, 128]); view_1103 = None + view_1114 = torch.ops.aten.view.default(view_1110, [2, 8192, -1, 128]); view_1110 = None + convert_element_type_508 = torch.ops.prims.convert_element_type.default(view_1112, torch.float32); view_1112 = None + view_1115 = torch.ops.aten.view.default(convert_element_type_508, [2, 8192, 4, -1, 2]); convert_element_type_508 = None + view_as_complex_30 = torch.ops.aten.view_as_complex.default(view_1115); view_1115 = None + convert_element_type_509 = torch.ops.prims.convert_element_type.default(view_1113, torch.float32); view_1113 = None + view_1116 = torch.ops.aten.view.default(convert_element_type_509, [2, 8192, 1, -1, 2]); convert_element_type_509 = 
None + view_as_complex_31 = torch.ops.aten.view_as_complex.default(view_1116); view_1116 = None + mul_122 = torch.ops.aten.mul.Tensor(view_as_complex_30, view_37); view_as_complex_30 = None + view_as_real_30 = torch.ops.aten.view_as_real.default(mul_122); mul_122 = None + view_1118 = torch.ops.aten.view.default(view_as_real_30, [2, 8192, 4, 128]); view_as_real_30 = None + mul_123 = torch.ops.aten.mul.Tensor(view_as_complex_31, view_37); view_as_complex_31 = None + view_as_real_31 = torch.ops.aten.view_as_real.default(mul_123); mul_123 = None + view_1119 = torch.ops.aten.view.default(view_as_real_31, [2, 8192, 1, 128]); view_as_real_31 = None + convert_element_type_510 = torch.ops.prims.convert_element_type.default(view_1118, torch.bfloat16); view_1118 = None + convert_element_type_511 = torch.ops.prims.convert_element_type.default(view_1119, torch.bfloat16); view_1119 = None + unsqueeze_30 = torch.ops.aten.unsqueeze.default(convert_element_type_511, 3); convert_element_type_511 = None + expand_30 = torch.ops.aten.expand.default(unsqueeze_30, [2, 8192, 1, 4, 128]); unsqueeze_30 = None + view_1120 = torch.ops.aten.view.default(expand_30, [2, 8192, 4, 128]); expand_30 = None + unsqueeze_31 = torch.ops.aten.unsqueeze.default(view_1114, 3); view_1114 = None + expand_31 = torch.ops.aten.expand.default(unsqueeze_31, [2, 8192, 1, 4, 128]); unsqueeze_31 = None + view_1121 = torch.ops.aten.view.default(expand_31, [2, 8192, 4, 128]); expand_31 = None + permute_168 = torch.ops.aten.permute.default(convert_element_type_510, [0, 2, 1, 3]); convert_element_type_510 = None + permute_169 = torch.ops.aten.permute.default(view_1120, [0, 2, 1, 3]); view_1120 = None + permute_170 = torch.ops.aten.permute.default(view_1121, [0, 2, 1, 3]); view_1121 = None + _scaled_dot_product_cudnn_attention_15 = torch.ops.aten._scaled_dot_product_cudnn_attention.default(permute_168, permute_169, permute_170, None, True, 0.0, True); permute_168 = permute_169 = permute_170 = None + getitem_695 = 
_scaled_dot_product_cudnn_attention_15[0] + getitem_696 = _scaled_dot_product_cudnn_attention_15[1] + getitem_701 = _scaled_dot_product_cudnn_attention_15[6] + getitem_702 = _scaled_dot_product_cudnn_attention_15[7]; _scaled_dot_product_cudnn_attention_15 = None + permute_171 = torch.ops.aten.permute.default(getitem_695, [0, 2, 1, 3]) + view_1122 = torch.ops.aten.view.default(permute_171, [2, 8192, -1]); permute_171 = None + convert_element_type_512 = torch.ops.prims.convert_element_type.default(primals_143, torch.bfloat16) + all_gather_into_tensor_171 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_512, 8, '0'); convert_element_type_512 = None + wait_tensor_202 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_171); all_gather_into_tensor_171 = None + permute_172 = torch.ops.aten.permute.default(wait_tensor_202, [1, 0]); wait_tensor_202 = None + view_1128 = torch.ops.aten.view.default(view_1122, [16384, 512]); view_1122 = None + mm_108 = torch.ops.aten.mm.default(view_1128, permute_172); view_1128 = permute_172 = None + view_1129 = torch.ops.aten.view.default(mm_108, [2, 8192, 4096]); mm_108 = None + split_70 = torch.ops.aten.split.Tensor(view_1129, 1024, 1); view_1129 = None + getitem_704 = split_70[0] + getitem_705 = split_70[1] + getitem_706 = split_70[2] + getitem_707 = split_70[3] + getitem_708 = split_70[4] + getitem_709 = split_70[5] + getitem_710 = split_70[6] + getitem_711 = split_70[7]; split_70 = None + cat_62 = torch.ops.aten.cat.default([getitem_704, getitem_705, getitem_706, getitem_707, getitem_708, getitem_709, getitem_710, getitem_711]); getitem_704 = getitem_705 = getitem_706 = getitem_707 = getitem_708 = getitem_709 = getitem_710 = getitem_711 = None + reduce_scatter_tensor_31 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_62, 'sum', 8, '1'); cat_62 = None + wait_tensor_203 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_31) + add_61 = 
torch.ops.aten.add.Tensor(add_59, wait_tensor_203); wait_tensor_203 = None + convert_element_type_515 = torch.ops.prims.convert_element_type.default(primals_144, torch.bfloat16) + all_gather_into_tensor_172 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_515, 8, '0'); convert_element_type_515 = None + wait_tensor_204 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_172); all_gather_into_tensor_172 = None + convert_element_type_516 = torch.ops.prims.convert_element_type.default(add_61, torch.float32) + pow_32 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_516, 2) + mean_31 = torch.ops.aten.mean.dim(pow_32, [2], True); pow_32 = None + add_62 = torch.ops.aten.add.Scalar(mean_31, 1e-05); mean_31 = None + rsqrt_31 = torch.ops.aten.rsqrt.default(add_62); add_62 = None + mul_124 = torch.ops.aten.mul.Tensor(convert_element_type_516, rsqrt_31); convert_element_type_516 = rsqrt_31 = None + mul_125 = torch.ops.aten.mul.Tensor(mul_124, wait_tensor_204); mul_124 = wait_tensor_204 = None + convert_element_type_517 = torch.ops.prims.convert_element_type.default(mul_125, torch.bfloat16); mul_125 = None + all_gather_into_tensor_173 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_517, 8, '1'); convert_element_type_517 = None + wait_tensor_205 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_173); all_gather_into_tensor_173 = None + split_71 = torch.ops.aten.split.Tensor(wait_tensor_205, 2); wait_tensor_205 = None + getitem_712 = split_71[0] + getitem_713 = split_71[1] + getitem_714 = split_71[2] + getitem_715 = split_71[3] + getitem_716 = split_71[4] + getitem_717 = split_71[5] + getitem_718 = split_71[6] + getitem_719 = split_71[7]; split_71 = None + cat_63 = torch.ops.aten.cat.default([getitem_712, getitem_713, getitem_714, getitem_715, getitem_716, getitem_717, getitem_718, getitem_719], 1); getitem_712 = getitem_713 = getitem_714 = getitem_715 = getitem_716 = 
getitem_717 = getitem_718 = getitem_719 = None + convert_element_type_518 = torch.ops.prims.convert_element_type.default(primals_145, torch.bfloat16) + all_gather_into_tensor_174 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_518, 8, '0'); convert_element_type_518 = None + wait_tensor_206 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_174); all_gather_into_tensor_174 = None + permute_173 = torch.ops.aten.permute.default(wait_tensor_206, [1, 0]); wait_tensor_206 = None + view_1140 = torch.ops.aten.view.default(cat_63, [16384, 4096]); cat_63 = None + mm_109 = torch.ops.aten.mm.default(view_1140, permute_173); permute_173 = None + view_1141 = torch.ops.aten.view.default(mm_109, [2, 8192, 1792]) + convert_element_type_521 = torch.ops.prims.convert_element_type.default(view_1141, torch.float32); view_1141 = None + sigmoid_15 = torch.ops.aten.sigmoid.default(convert_element_type_521) + mul_126 = torch.ops.aten.mul.Tensor(convert_element_type_521, sigmoid_15); convert_element_type_521 = sigmoid_15 = None + convert_element_type_522 = torch.ops.prims.convert_element_type.default(mul_126, torch.bfloat16); mul_126 = None + convert_element_type_523 = torch.ops.prims.convert_element_type.default(primals_146, torch.bfloat16) + all_gather_into_tensor_175 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_523, 8, '0'); convert_element_type_523 = None + wait_tensor_207 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_175); all_gather_into_tensor_175 = None + permute_174 = torch.ops.aten.permute.default(wait_tensor_207, [1, 0]); wait_tensor_207 = None + mm_110 = torch.ops.aten.mm.default(view_1140, permute_174); view_1140 = permute_174 = None + view_1148 = torch.ops.aten.view.default(mm_110, [2, 8192, 1792]); mm_110 = None + mul_127 = torch.ops.aten.mul.Tensor(convert_element_type_522, view_1148); convert_element_type_522 = view_1148 = None + convert_element_type_526 = 
torch.ops.prims.convert_element_type.default(primals_147, torch.bfloat16) + all_gather_into_tensor_176 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_526, 8, '0'); convert_element_type_526 = None + wait_tensor_208 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_176); all_gather_into_tensor_176 = None + permute_175 = torch.ops.aten.permute.default(wait_tensor_208, [1, 0]); wait_tensor_208 = None + view_1155 = torch.ops.aten.view.default(mul_127, [16384, 1792]); mul_127 = None + mm_111 = torch.ops.aten.mm.default(view_1155, permute_175); view_1155 = permute_175 = None + view_1156 = torch.ops.aten.view.default(mm_111, [2, 8192, 4096]); mm_111 = None + split_72 = torch.ops.aten.split.Tensor(view_1156, 1024, 1); view_1156 = None + getitem_720 = split_72[0] + getitem_721 = split_72[1] + getitem_722 = split_72[2] + getitem_723 = split_72[3] + getitem_724 = split_72[4] + getitem_725 = split_72[5] + getitem_726 = split_72[6] + getitem_727 = split_72[7]; split_72 = None + cat_64 = torch.ops.aten.cat.default([getitem_720, getitem_721, getitem_722, getitem_723, getitem_724, getitem_725, getitem_726, getitem_727]); getitem_720 = getitem_721 = getitem_722 = getitem_723 = getitem_724 = getitem_725 = getitem_726 = getitem_727 = None + reduce_scatter_tensor_32 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_64, 'sum', 8, '1'); cat_64 = None + wait_tensor_209 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_32); reduce_scatter_tensor_32 = None + add_63 = torch.ops.aten.add.Tensor(add_61, wait_tensor_209); add_61 = wait_tensor_209 = None + convert_element_type_529 = torch.ops.prims.convert_element_type.default(primals_148, torch.bfloat16) + all_gather_into_tensor_177 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_529, 8, '0'); convert_element_type_529 = None + wait_tensor_210 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_177); 
all_gather_into_tensor_177 = None + convert_element_type_530 = torch.ops.prims.convert_element_type.default(add_63, torch.float32) + pow_33 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_530, 2) + mean_32 = torch.ops.aten.mean.dim(pow_33, [2], True); pow_33 = None + add_64 = torch.ops.aten.add.Scalar(mean_32, 1e-05); mean_32 = None + rsqrt_32 = torch.ops.aten.rsqrt.default(add_64); add_64 = None + mul_128 = torch.ops.aten.mul.Tensor(convert_element_type_530, rsqrt_32); convert_element_type_530 = rsqrt_32 = None + mul_129 = torch.ops.aten.mul.Tensor(mul_128, wait_tensor_210); mul_128 = wait_tensor_210 = None + convert_element_type_531 = torch.ops.prims.convert_element_type.default(mul_129, torch.bfloat16); mul_129 = None + all_gather_into_tensor_178 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_531, 8, '1'); convert_element_type_531 = None + wait_tensor_211 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_178); all_gather_into_tensor_178 = None + split_73 = torch.ops.aten.split.Tensor(wait_tensor_211, 2); wait_tensor_211 = None + getitem_728 = split_73[0] + getitem_729 = split_73[1] + getitem_730 = split_73[2] + getitem_731 = split_73[3] + getitem_732 = split_73[4] + getitem_733 = split_73[5] + getitem_734 = split_73[6] + getitem_735 = split_73[7]; split_73 = None + cat_65 = torch.ops.aten.cat.default([getitem_728, getitem_729, getitem_730, getitem_731, getitem_732, getitem_733, getitem_734, getitem_735], 1); getitem_728 = getitem_729 = getitem_730 = getitem_731 = getitem_732 = getitem_733 = getitem_734 = getitem_735 = None + convert_element_type_532 = torch.ops.prims.convert_element_type.default(primals_149, torch.bfloat16) + all_gather_into_tensor_179 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_532, 8, '0'); convert_element_type_532 = None + wait_tensor_212 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_179); all_gather_into_tensor_179 = 
None + permute_176 = torch.ops.aten.permute.default(wait_tensor_212, [1, 0]); wait_tensor_212 = None + view_1167 = torch.ops.aten.view.default(cat_65, [16384, 4096]); cat_65 = None + mm_112 = torch.ops.aten.mm.default(view_1167, permute_176); permute_176 = None + view_1168 = torch.ops.aten.view.default(mm_112, [2, 8192, 512]) + convert_element_type_535 = torch.ops.prims.convert_element_type.default(primals_150, torch.bfloat16) + all_gather_into_tensor_180 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_535, 8, '0'); convert_element_type_535 = None + wait_tensor_213 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_180); all_gather_into_tensor_180 = None + permute_177 = torch.ops.aten.permute.default(wait_tensor_213, [1, 0]); wait_tensor_213 = None + mm_113 = torch.ops.aten.mm.default(view_1167, permute_177); permute_177 = None + view_1175 = torch.ops.aten.view.default(mm_113, [2, 8192, 128]); mm_113 = None + convert_element_type_538 = torch.ops.prims.convert_element_type.default(primals_151, torch.bfloat16) + all_gather_into_tensor_181 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_538, 8, '0'); convert_element_type_538 = None + wait_tensor_214 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_181); all_gather_into_tensor_181 = None + permute_178 = torch.ops.aten.permute.default(wait_tensor_214, [1, 0]); wait_tensor_214 = None + mm_114 = torch.ops.aten.mm.default(view_1167, permute_178); view_1167 = permute_178 = None + view_1182 = torch.ops.aten.view.default(mm_114, [2, 8192, 128]) + view_1184 = torch.ops.aten.view.default(view_1168, [2, 8192, -1, 128]); view_1168 = None + view_1185 = torch.ops.aten.view.default(view_1175, [2, 8192, -1, 128]); view_1175 = None + view_1186 = torch.ops.aten.view.default(view_1182, [2, 8192, -1, 128]); view_1182 = None + convert_element_type_541 = torch.ops.prims.convert_element_type.default(view_1184, torch.float32); 
view_1184 = None + view_1187 = torch.ops.aten.view.default(convert_element_type_541, [2, 8192, 4, -1, 2]); convert_element_type_541 = None + view_as_complex_32 = torch.ops.aten.view_as_complex.default(view_1187); view_1187 = None + convert_element_type_542 = torch.ops.prims.convert_element_type.default(view_1185, torch.float32); view_1185 = None + view_1188 = torch.ops.aten.view.default(convert_element_type_542, [2, 8192, 1, -1, 2]); convert_element_type_542 = None + view_as_complex_33 = torch.ops.aten.view_as_complex.default(view_1188); view_1188 = None + mul_130 = torch.ops.aten.mul.Tensor(view_as_complex_32, view_37); view_as_complex_32 = None + view_as_real_32 = torch.ops.aten.view_as_real.default(mul_130); mul_130 = None + view_1190 = torch.ops.aten.view.default(view_as_real_32, [2, 8192, 4, 128]); view_as_real_32 = None + mul_131 = torch.ops.aten.mul.Tensor(view_as_complex_33, view_37); view_as_complex_33 = None + view_as_real_33 = torch.ops.aten.view_as_real.default(mul_131); mul_131 = None + view_1191 = torch.ops.aten.view.default(view_as_real_33, [2, 8192, 1, 128]); view_as_real_33 = None + convert_element_type_543 = torch.ops.prims.convert_element_type.default(view_1190, torch.bfloat16); view_1190 = None + convert_element_type_544 = torch.ops.prims.convert_element_type.default(view_1191, torch.bfloat16); view_1191 = None + unsqueeze_32 = torch.ops.aten.unsqueeze.default(convert_element_type_544, 3); convert_element_type_544 = None + expand_32 = torch.ops.aten.expand.default(unsqueeze_32, [2, 8192, 1, 4, 128]); unsqueeze_32 = None + view_1192 = torch.ops.aten.view.default(expand_32, [2, 8192, 4, 128]); expand_32 = None + unsqueeze_33 = torch.ops.aten.unsqueeze.default(view_1186, 3); view_1186 = None + expand_33 = torch.ops.aten.expand.default(unsqueeze_33, [2, 8192, 1, 4, 128]); unsqueeze_33 = None + view_1193 = torch.ops.aten.view.default(expand_33, [2, 8192, 4, 128]); expand_33 = None + permute_179 = 
torch.ops.aten.permute.default(convert_element_type_543, [0, 2, 1, 3]); convert_element_type_543 = None + permute_180 = torch.ops.aten.permute.default(view_1192, [0, 2, 1, 3]); view_1192 = None + permute_181 = torch.ops.aten.permute.default(view_1193, [0, 2, 1, 3]); view_1193 = None + _scaled_dot_product_cudnn_attention_16 = torch.ops.aten._scaled_dot_product_cudnn_attention.default(permute_179, permute_180, permute_181, None, True, 0.0, True); permute_179 = permute_180 = permute_181 = None + getitem_736 = _scaled_dot_product_cudnn_attention_16[0] + getitem_737 = _scaled_dot_product_cudnn_attention_16[1] + getitem_742 = _scaled_dot_product_cudnn_attention_16[6] + getitem_743 = _scaled_dot_product_cudnn_attention_16[7]; _scaled_dot_product_cudnn_attention_16 = None + permute_182 = torch.ops.aten.permute.default(getitem_736, [0, 2, 1, 3]) + view_1194 = torch.ops.aten.view.default(permute_182, [2, 8192, -1]); permute_182 = None + convert_element_type_545 = torch.ops.prims.convert_element_type.default(primals_152, torch.bfloat16) + all_gather_into_tensor_182 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_545, 8, '0'); convert_element_type_545 = None + wait_tensor_215 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_182); all_gather_into_tensor_182 = None + permute_183 = torch.ops.aten.permute.default(wait_tensor_215, [1, 0]); wait_tensor_215 = None + view_1200 = torch.ops.aten.view.default(view_1194, [16384, 512]); view_1194 = None + mm_115 = torch.ops.aten.mm.default(view_1200, permute_183); view_1200 = permute_183 = None + view_1201 = torch.ops.aten.view.default(mm_115, [2, 8192, 4096]); mm_115 = None + split_74 = torch.ops.aten.split.Tensor(view_1201, 1024, 1); view_1201 = None + getitem_745 = split_74[0] + getitem_746 = split_74[1] + getitem_747 = split_74[2] + getitem_748 = split_74[3] + getitem_749 = split_74[4] + getitem_750 = split_74[5] + getitem_751 = split_74[6] + getitem_752 = split_74[7]; split_74 = 
None + cat_66 = torch.ops.aten.cat.default([getitem_745, getitem_746, getitem_747, getitem_748, getitem_749, getitem_750, getitem_751, getitem_752]); getitem_745 = getitem_746 = getitem_747 = getitem_748 = getitem_749 = getitem_750 = getitem_751 = getitem_752 = None + reduce_scatter_tensor_33 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_66, 'sum', 8, '1'); cat_66 = None + wait_tensor_216 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_33) + add_65 = torch.ops.aten.add.Tensor(add_63, wait_tensor_216); wait_tensor_216 = None + convert_element_type_548 = torch.ops.prims.convert_element_type.default(primals_153, torch.bfloat16) + all_gather_into_tensor_183 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_548, 8, '0'); convert_element_type_548 = None + wait_tensor_217 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_183); all_gather_into_tensor_183 = None + convert_element_type_549 = torch.ops.prims.convert_element_type.default(add_65, torch.float32) + pow_34 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_549, 2) + mean_33 = torch.ops.aten.mean.dim(pow_34, [2], True); pow_34 = None + add_66 = torch.ops.aten.add.Scalar(mean_33, 1e-05); mean_33 = None + rsqrt_33 = torch.ops.aten.rsqrt.default(add_66); add_66 = None + mul_132 = torch.ops.aten.mul.Tensor(convert_element_type_549, rsqrt_33); convert_element_type_549 = rsqrt_33 = None + mul_133 = torch.ops.aten.mul.Tensor(mul_132, wait_tensor_217); mul_132 = wait_tensor_217 = None + convert_element_type_550 = torch.ops.prims.convert_element_type.default(mul_133, torch.bfloat16); mul_133 = None + all_gather_into_tensor_184 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_550, 8, '1'); convert_element_type_550 = None + wait_tensor_218 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_184); all_gather_into_tensor_184 = None + split_75 = 
torch.ops.aten.split.Tensor(wait_tensor_218, 2); wait_tensor_218 = None + getitem_753 = split_75[0] + getitem_754 = split_75[1] + getitem_755 = split_75[2] + getitem_756 = split_75[3] + getitem_757 = split_75[4] + getitem_758 = split_75[5] + getitem_759 = split_75[6] + getitem_760 = split_75[7]; split_75 = None + cat_67 = torch.ops.aten.cat.default([getitem_753, getitem_754, getitem_755, getitem_756, getitem_757, getitem_758, getitem_759, getitem_760], 1); getitem_753 = getitem_754 = getitem_755 = getitem_756 = getitem_757 = getitem_758 = getitem_759 = getitem_760 = None + convert_element_type_551 = torch.ops.prims.convert_element_type.default(primals_154, torch.bfloat16) + all_gather_into_tensor_185 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_551, 8, '0'); convert_element_type_551 = None + wait_tensor_219 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_185); all_gather_into_tensor_185 = None + permute_184 = torch.ops.aten.permute.default(wait_tensor_219, [1, 0]); wait_tensor_219 = None + view_1212 = torch.ops.aten.view.default(cat_67, [16384, 4096]); cat_67 = None + mm_116 = torch.ops.aten.mm.default(view_1212, permute_184); permute_184 = None + view_1213 = torch.ops.aten.view.default(mm_116, [2, 8192, 1792]) + convert_element_type_554 = torch.ops.prims.convert_element_type.default(view_1213, torch.float32); view_1213 = None + sigmoid_16 = torch.ops.aten.sigmoid.default(convert_element_type_554) + mul_134 = torch.ops.aten.mul.Tensor(convert_element_type_554, sigmoid_16); convert_element_type_554 = sigmoid_16 = None + convert_element_type_555 = torch.ops.prims.convert_element_type.default(mul_134, torch.bfloat16); mul_134 = None + convert_element_type_556 = torch.ops.prims.convert_element_type.default(primals_155, torch.bfloat16) + all_gather_into_tensor_186 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_556, 8, '0'); convert_element_type_556 = None + wait_tensor_220 = 
torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_186); all_gather_into_tensor_186 = None + permute_185 = torch.ops.aten.permute.default(wait_tensor_220, [1, 0]); wait_tensor_220 = None + mm_117 = torch.ops.aten.mm.default(view_1212, permute_185); view_1212 = permute_185 = None + view_1220 = torch.ops.aten.view.default(mm_117, [2, 8192, 1792]); mm_117 = None + mul_135 = torch.ops.aten.mul.Tensor(convert_element_type_555, view_1220); convert_element_type_555 = view_1220 = None + convert_element_type_559 = torch.ops.prims.convert_element_type.default(primals_156, torch.bfloat16) + all_gather_into_tensor_187 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_559, 8, '0'); convert_element_type_559 = None + wait_tensor_221 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_187); all_gather_into_tensor_187 = None + permute_186 = torch.ops.aten.permute.default(wait_tensor_221, [1, 0]); wait_tensor_221 = None + view_1227 = torch.ops.aten.view.default(mul_135, [16384, 1792]); mul_135 = None + mm_118 = torch.ops.aten.mm.default(view_1227, permute_186); view_1227 = permute_186 = None + view_1228 = torch.ops.aten.view.default(mm_118, [2, 8192, 4096]); mm_118 = None + split_76 = torch.ops.aten.split.Tensor(view_1228, 1024, 1); view_1228 = None + getitem_761 = split_76[0] + getitem_762 = split_76[1] + getitem_763 = split_76[2] + getitem_764 = split_76[3] + getitem_765 = split_76[4] + getitem_766 = split_76[5] + getitem_767 = split_76[6] + getitem_768 = split_76[7]; split_76 = None + cat_68 = torch.ops.aten.cat.default([getitem_761, getitem_762, getitem_763, getitem_764, getitem_765, getitem_766, getitem_767, getitem_768]); getitem_761 = getitem_762 = getitem_763 = getitem_764 = getitem_765 = getitem_766 = getitem_767 = getitem_768 = None + reduce_scatter_tensor_34 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_68, 'sum', 8, '1'); cat_68 = None + wait_tensor_222 = 
torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_34); reduce_scatter_tensor_34 = None + add_67 = torch.ops.aten.add.Tensor(add_65, wait_tensor_222); add_65 = wait_tensor_222 = None + convert_element_type_562 = torch.ops.prims.convert_element_type.default(primals_157, torch.bfloat16) + all_gather_into_tensor_188 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_562, 8, '0'); convert_element_type_562 = None + wait_tensor_223 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_188); all_gather_into_tensor_188 = None + convert_element_type_563 = torch.ops.prims.convert_element_type.default(add_67, torch.float32) + pow_35 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_563, 2) + mean_34 = torch.ops.aten.mean.dim(pow_35, [2], True); pow_35 = None + add_68 = torch.ops.aten.add.Scalar(mean_34, 1e-05); mean_34 = None + rsqrt_34 = torch.ops.aten.rsqrt.default(add_68); add_68 = None + mul_136 = torch.ops.aten.mul.Tensor(convert_element_type_563, rsqrt_34); convert_element_type_563 = rsqrt_34 = None + mul_137 = torch.ops.aten.mul.Tensor(mul_136, wait_tensor_223); mul_136 = wait_tensor_223 = None + convert_element_type_564 = torch.ops.prims.convert_element_type.default(mul_137, torch.bfloat16); mul_137 = None + all_gather_into_tensor_189 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_564, 8, '1'); convert_element_type_564 = None + wait_tensor_224 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_189); all_gather_into_tensor_189 = None + split_77 = torch.ops.aten.split.Tensor(wait_tensor_224, 2); wait_tensor_224 = None + getitem_769 = split_77[0] + getitem_770 = split_77[1] + getitem_771 = split_77[2] + getitem_772 = split_77[3] + getitem_773 = split_77[4] + getitem_774 = split_77[5] + getitem_775 = split_77[6] + getitem_776 = split_77[7]; split_77 = None + cat_69 = torch.ops.aten.cat.default([getitem_769, getitem_770, getitem_771, getitem_772, 
getitem_773, getitem_774, getitem_775, getitem_776], 1); getitem_769 = getitem_770 = getitem_771 = getitem_772 = getitem_773 = getitem_774 = getitem_775 = getitem_776 = None + convert_element_type_565 = torch.ops.prims.convert_element_type.default(primals_158, torch.bfloat16) + all_gather_into_tensor_190 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_565, 8, '0'); convert_element_type_565 = None + wait_tensor_225 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_190); all_gather_into_tensor_190 = None + permute_187 = torch.ops.aten.permute.default(wait_tensor_225, [1, 0]); wait_tensor_225 = None + view_1239 = torch.ops.aten.view.default(cat_69, [16384, 4096]); cat_69 = None + mm_119 = torch.ops.aten.mm.default(view_1239, permute_187); permute_187 = None + view_1240 = torch.ops.aten.view.default(mm_119, [2, 8192, 512]) + convert_element_type_568 = torch.ops.prims.convert_element_type.default(primals_159, torch.bfloat16) + all_gather_into_tensor_191 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_568, 8, '0'); convert_element_type_568 = None + wait_tensor_226 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_191); all_gather_into_tensor_191 = None + permute_188 = torch.ops.aten.permute.default(wait_tensor_226, [1, 0]); wait_tensor_226 = None + mm_120 = torch.ops.aten.mm.default(view_1239, permute_188); permute_188 = None + view_1247 = torch.ops.aten.view.default(mm_120, [2, 8192, 128]); mm_120 = None + convert_element_type_571 = torch.ops.prims.convert_element_type.default(primals_160, torch.bfloat16) + all_gather_into_tensor_192 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_571, 8, '0'); convert_element_type_571 = None + wait_tensor_227 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_192); all_gather_into_tensor_192 = None + permute_189 = torch.ops.aten.permute.default(wait_tensor_227, [1, 0]); 
wait_tensor_227 = None + mm_121 = torch.ops.aten.mm.default(view_1239, permute_189); view_1239 = permute_189 = None + view_1254 = torch.ops.aten.view.default(mm_121, [2, 8192, 128]) + view_1256 = torch.ops.aten.view.default(view_1240, [2, 8192, -1, 128]); view_1240 = None + view_1257 = torch.ops.aten.view.default(view_1247, [2, 8192, -1, 128]); view_1247 = None + view_1258 = torch.ops.aten.view.default(view_1254, [2, 8192, -1, 128]); view_1254 = None + convert_element_type_574 = torch.ops.prims.convert_element_type.default(view_1256, torch.float32); view_1256 = None + view_1259 = torch.ops.aten.view.default(convert_element_type_574, [2, 8192, 4, -1, 2]); convert_element_type_574 = None + view_as_complex_34 = torch.ops.aten.view_as_complex.default(view_1259); view_1259 = None + convert_element_type_575 = torch.ops.prims.convert_element_type.default(view_1257, torch.float32); view_1257 = None + view_1260 = torch.ops.aten.view.default(convert_element_type_575, [2, 8192, 1, -1, 2]); convert_element_type_575 = None + view_as_complex_35 = torch.ops.aten.view_as_complex.default(view_1260); view_1260 = None + mul_138 = torch.ops.aten.mul.Tensor(view_as_complex_34, view_37); view_as_complex_34 = None + view_as_real_34 = torch.ops.aten.view_as_real.default(mul_138); mul_138 = None + view_1262 = torch.ops.aten.view.default(view_as_real_34, [2, 8192, 4, 128]); view_as_real_34 = None + mul_139 = torch.ops.aten.mul.Tensor(view_as_complex_35, view_37); view_as_complex_35 = None + view_as_real_35 = torch.ops.aten.view_as_real.default(mul_139); mul_139 = None + view_1263 = torch.ops.aten.view.default(view_as_real_35, [2, 8192, 1, 128]); view_as_real_35 = None + convert_element_type_576 = torch.ops.prims.convert_element_type.default(view_1262, torch.bfloat16); view_1262 = None + convert_element_type_577 = torch.ops.prims.convert_element_type.default(view_1263, torch.bfloat16); view_1263 = None + unsqueeze_34 = torch.ops.aten.unsqueeze.default(convert_element_type_577, 3); 
convert_element_type_577 = None + expand_34 = torch.ops.aten.expand.default(unsqueeze_34, [2, 8192, 1, 4, 128]); unsqueeze_34 = None + view_1264 = torch.ops.aten.view.default(expand_34, [2, 8192, 4, 128]); expand_34 = None + unsqueeze_35 = torch.ops.aten.unsqueeze.default(view_1258, 3); view_1258 = None + expand_35 = torch.ops.aten.expand.default(unsqueeze_35, [2, 8192, 1, 4, 128]); unsqueeze_35 = None + view_1265 = torch.ops.aten.view.default(expand_35, [2, 8192, 4, 128]); expand_35 = None + permute_190 = torch.ops.aten.permute.default(convert_element_type_576, [0, 2, 1, 3]); convert_element_type_576 = None + permute_191 = torch.ops.aten.permute.default(view_1264, [0, 2, 1, 3]); view_1264 = None + permute_192 = torch.ops.aten.permute.default(view_1265, [0, 2, 1, 3]); view_1265 = None + _scaled_dot_product_cudnn_attention_17 = torch.ops.aten._scaled_dot_product_cudnn_attention.default(permute_190, permute_191, permute_192, None, True, 0.0, True); permute_190 = permute_191 = permute_192 = None + getitem_777 = _scaled_dot_product_cudnn_attention_17[0] + getitem_778 = _scaled_dot_product_cudnn_attention_17[1] + getitem_783 = _scaled_dot_product_cudnn_attention_17[6] + getitem_784 = _scaled_dot_product_cudnn_attention_17[7]; _scaled_dot_product_cudnn_attention_17 = None + permute_193 = torch.ops.aten.permute.default(getitem_777, [0, 2, 1, 3]) + view_1266 = torch.ops.aten.view.default(permute_193, [2, 8192, -1]); permute_193 = None + convert_element_type_578 = torch.ops.prims.convert_element_type.default(primals_161, torch.bfloat16) + all_gather_into_tensor_193 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_578, 8, '0'); convert_element_type_578 = None + wait_tensor_228 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_193); all_gather_into_tensor_193 = None + permute_194 = torch.ops.aten.permute.default(wait_tensor_228, [1, 0]); wait_tensor_228 = None + view_1272 = torch.ops.aten.view.default(view_1266, [16384, 
512]); view_1266 = None + mm_122 = torch.ops.aten.mm.default(view_1272, permute_194); view_1272 = permute_194 = None + view_1273 = torch.ops.aten.view.default(mm_122, [2, 8192, 4096]); mm_122 = None + split_78 = torch.ops.aten.split.Tensor(view_1273, 1024, 1); view_1273 = None + getitem_786 = split_78[0] + getitem_787 = split_78[1] + getitem_788 = split_78[2] + getitem_789 = split_78[3] + getitem_790 = split_78[4] + getitem_791 = split_78[5] + getitem_792 = split_78[6] + getitem_793 = split_78[7]; split_78 = None + cat_70 = torch.ops.aten.cat.default([getitem_786, getitem_787, getitem_788, getitem_789, getitem_790, getitem_791, getitem_792, getitem_793]); getitem_786 = getitem_787 = getitem_788 = getitem_789 = getitem_790 = getitem_791 = getitem_792 = getitem_793 = None + reduce_scatter_tensor_35 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_70, 'sum', 8, '1'); cat_70 = None + wait_tensor_229 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_35) + add_69 = torch.ops.aten.add.Tensor(add_67, wait_tensor_229); wait_tensor_229 = None + convert_element_type_581 = torch.ops.prims.convert_element_type.default(primals_162, torch.bfloat16) + all_gather_into_tensor_194 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_581, 8, '0'); convert_element_type_581 = None + wait_tensor_230 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_194); all_gather_into_tensor_194 = None + convert_element_type_582 = torch.ops.prims.convert_element_type.default(add_69, torch.float32) + pow_36 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_582, 2) + mean_35 = torch.ops.aten.mean.dim(pow_36, [2], True); pow_36 = None + add_70 = torch.ops.aten.add.Scalar(mean_35, 1e-05); mean_35 = None + rsqrt_35 = torch.ops.aten.rsqrt.default(add_70); add_70 = None + mul_140 = torch.ops.aten.mul.Tensor(convert_element_type_582, rsqrt_35); convert_element_type_582 = rsqrt_35 = None + mul_141 = 
torch.ops.aten.mul.Tensor(mul_140, wait_tensor_230); mul_140 = wait_tensor_230 = None + convert_element_type_583 = torch.ops.prims.convert_element_type.default(mul_141, torch.bfloat16); mul_141 = None + all_gather_into_tensor_195 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_583, 8, '1'); convert_element_type_583 = None + wait_tensor_231 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_195); all_gather_into_tensor_195 = None + split_79 = torch.ops.aten.split.Tensor(wait_tensor_231, 2); wait_tensor_231 = None + getitem_794 = split_79[0] + getitem_795 = split_79[1] + getitem_796 = split_79[2] + getitem_797 = split_79[3] + getitem_798 = split_79[4] + getitem_799 = split_79[5] + getitem_800 = split_79[6] + getitem_801 = split_79[7]; split_79 = None + cat_71 = torch.ops.aten.cat.default([getitem_794, getitem_795, getitem_796, getitem_797, getitem_798, getitem_799, getitem_800, getitem_801], 1); getitem_794 = getitem_795 = getitem_796 = getitem_797 = getitem_798 = getitem_799 = getitem_800 = getitem_801 = None + convert_element_type_584 = torch.ops.prims.convert_element_type.default(primals_163, torch.bfloat16) + all_gather_into_tensor_196 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_584, 8, '0'); convert_element_type_584 = None + wait_tensor_232 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_196); all_gather_into_tensor_196 = None + permute_195 = torch.ops.aten.permute.default(wait_tensor_232, [1, 0]); wait_tensor_232 = None + view_1284 = torch.ops.aten.view.default(cat_71, [16384, 4096]); cat_71 = None + mm_123 = torch.ops.aten.mm.default(view_1284, permute_195); permute_195 = None + view_1285 = torch.ops.aten.view.default(mm_123, [2, 8192, 1792]) + convert_element_type_587 = torch.ops.prims.convert_element_type.default(view_1285, torch.float32); view_1285 = None + sigmoid_17 = torch.ops.aten.sigmoid.default(convert_element_type_587) + mul_142 = 
torch.ops.aten.mul.Tensor(convert_element_type_587, sigmoid_17); convert_element_type_587 = sigmoid_17 = None + convert_element_type_588 = torch.ops.prims.convert_element_type.default(mul_142, torch.bfloat16); mul_142 = None + convert_element_type_589 = torch.ops.prims.convert_element_type.default(primals_164, torch.bfloat16) + all_gather_into_tensor_197 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_589, 8, '0'); convert_element_type_589 = None + wait_tensor_233 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_197); all_gather_into_tensor_197 = None + permute_196 = torch.ops.aten.permute.default(wait_tensor_233, [1, 0]); wait_tensor_233 = None + mm_124 = torch.ops.aten.mm.default(view_1284, permute_196); view_1284 = permute_196 = None + view_1292 = torch.ops.aten.view.default(mm_124, [2, 8192, 1792]); mm_124 = None + mul_143 = torch.ops.aten.mul.Tensor(convert_element_type_588, view_1292); convert_element_type_588 = view_1292 = None + convert_element_type_592 = torch.ops.prims.convert_element_type.default(primals_165, torch.bfloat16) + all_gather_into_tensor_198 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_592, 8, '0'); convert_element_type_592 = None + wait_tensor_234 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_198); all_gather_into_tensor_198 = None + permute_197 = torch.ops.aten.permute.default(wait_tensor_234, [1, 0]); wait_tensor_234 = None + view_1299 = torch.ops.aten.view.default(mul_143, [16384, 1792]); mul_143 = None + mm_125 = torch.ops.aten.mm.default(view_1299, permute_197); view_1299 = permute_197 = None + view_1300 = torch.ops.aten.view.default(mm_125, [2, 8192, 4096]); mm_125 = None + split_80 = torch.ops.aten.split.Tensor(view_1300, 1024, 1); view_1300 = None + getitem_802 = split_80[0] + getitem_803 = split_80[1] + getitem_804 = split_80[2] + getitem_805 = split_80[3] + getitem_806 = split_80[4] + getitem_807 = split_80[5] + 
getitem_808 = split_80[6] + getitem_809 = split_80[7]; split_80 = None + cat_72 = torch.ops.aten.cat.default([getitem_802, getitem_803, getitem_804, getitem_805, getitem_806, getitem_807, getitem_808, getitem_809]); getitem_802 = getitem_803 = getitem_804 = getitem_805 = getitem_806 = getitem_807 = getitem_808 = getitem_809 = None + reduce_scatter_tensor_36 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_72, 'sum', 8, '1'); cat_72 = None + wait_tensor_235 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_36); reduce_scatter_tensor_36 = None + add_71 = torch.ops.aten.add.Tensor(add_69, wait_tensor_235); add_69 = wait_tensor_235 = None + convert_element_type_595 = torch.ops.prims.convert_element_type.default(primals_166, torch.bfloat16) + all_gather_into_tensor_199 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_595, 8, '0'); convert_element_type_595 = None + wait_tensor_236 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_199); all_gather_into_tensor_199 = None + convert_element_type_596 = torch.ops.prims.convert_element_type.default(add_71, torch.float32) + pow_37 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_596, 2) + mean_36 = torch.ops.aten.mean.dim(pow_37, [2], True); pow_37 = None + add_72 = torch.ops.aten.add.Scalar(mean_36, 1e-05); mean_36 = None + rsqrt_36 = torch.ops.aten.rsqrt.default(add_72); add_72 = None + mul_144 = torch.ops.aten.mul.Tensor(convert_element_type_596, rsqrt_36); convert_element_type_596 = rsqrt_36 = None + mul_145 = torch.ops.aten.mul.Tensor(mul_144, wait_tensor_236); mul_144 = wait_tensor_236 = None + convert_element_type_597 = torch.ops.prims.convert_element_type.default(mul_145, torch.bfloat16); mul_145 = None + all_gather_into_tensor_200 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_597, 8, '1'); convert_element_type_597 = None + wait_tensor_237 = 
torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_200); all_gather_into_tensor_200 = None + split_81 = torch.ops.aten.split.Tensor(wait_tensor_237, 2); wait_tensor_237 = None + getitem_810 = split_81[0] + getitem_811 = split_81[1] + getitem_812 = split_81[2] + getitem_813 = split_81[3] + getitem_814 = split_81[4] + getitem_815 = split_81[5] + getitem_816 = split_81[6] + getitem_817 = split_81[7]; split_81 = None + cat_73 = torch.ops.aten.cat.default([getitem_810, getitem_811, getitem_812, getitem_813, getitem_814, getitem_815, getitem_816, getitem_817], 1); getitem_810 = getitem_811 = getitem_812 = getitem_813 = getitem_814 = getitem_815 = getitem_816 = getitem_817 = None + convert_element_type_598 = torch.ops.prims.convert_element_type.default(primals_167, torch.bfloat16) + all_gather_into_tensor_201 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_598, 8, '0'); convert_element_type_598 = None + wait_tensor_238 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_201); all_gather_into_tensor_201 = None + permute_198 = torch.ops.aten.permute.default(wait_tensor_238, [1, 0]); wait_tensor_238 = None + view_1311 = torch.ops.aten.view.default(cat_73, [16384, 4096]); cat_73 = None + mm_126 = torch.ops.aten.mm.default(view_1311, permute_198); permute_198 = None + view_1312 = torch.ops.aten.view.default(mm_126, [2, 8192, 512]) + convert_element_type_601 = torch.ops.prims.convert_element_type.default(primals_168, torch.bfloat16) + all_gather_into_tensor_202 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_601, 8, '0'); convert_element_type_601 = None + wait_tensor_239 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_202); all_gather_into_tensor_202 = None + permute_199 = torch.ops.aten.permute.default(wait_tensor_239, [1, 0]); wait_tensor_239 = None + mm_127 = torch.ops.aten.mm.default(view_1311, permute_199); permute_199 = None + view_1319 = 
torch.ops.aten.view.default(mm_127, [2, 8192, 128]); mm_127 = None + convert_element_type_604 = torch.ops.prims.convert_element_type.default(primals_169, torch.bfloat16) + all_gather_into_tensor_203 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_604, 8, '0'); convert_element_type_604 = None + wait_tensor_240 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_203); all_gather_into_tensor_203 = None + permute_200 = torch.ops.aten.permute.default(wait_tensor_240, [1, 0]); wait_tensor_240 = None + mm_128 = torch.ops.aten.mm.default(view_1311, permute_200); view_1311 = permute_200 = None + view_1326 = torch.ops.aten.view.default(mm_128, [2, 8192, 128]) + view_1328 = torch.ops.aten.view.default(view_1312, [2, 8192, -1, 128]); view_1312 = None + view_1329 = torch.ops.aten.view.default(view_1319, [2, 8192, -1, 128]); view_1319 = None + view_1330 = torch.ops.aten.view.default(view_1326, [2, 8192, -1, 128]); view_1326 = None + convert_element_type_607 = torch.ops.prims.convert_element_type.default(view_1328, torch.float32); view_1328 = None + view_1331 = torch.ops.aten.view.default(convert_element_type_607, [2, 8192, 4, -1, 2]); convert_element_type_607 = None + view_as_complex_36 = torch.ops.aten.view_as_complex.default(view_1331); view_1331 = None + convert_element_type_608 = torch.ops.prims.convert_element_type.default(view_1329, torch.float32); view_1329 = None + view_1332 = torch.ops.aten.view.default(convert_element_type_608, [2, 8192, 1, -1, 2]); convert_element_type_608 = None + view_as_complex_37 = torch.ops.aten.view_as_complex.default(view_1332); view_1332 = None + mul_146 = torch.ops.aten.mul.Tensor(view_as_complex_36, view_37); view_as_complex_36 = None + view_as_real_36 = torch.ops.aten.view_as_real.default(mul_146); mul_146 = None + view_1334 = torch.ops.aten.view.default(view_as_real_36, [2, 8192, 4, 128]); view_as_real_36 = None + mul_147 = torch.ops.aten.mul.Tensor(view_as_complex_37, view_37); 
view_as_complex_37 = None + view_as_real_37 = torch.ops.aten.view_as_real.default(mul_147); mul_147 = None + view_1335 = torch.ops.aten.view.default(view_as_real_37, [2, 8192, 1, 128]); view_as_real_37 = None + convert_element_type_609 = torch.ops.prims.convert_element_type.default(view_1334, torch.bfloat16); view_1334 = None + convert_element_type_610 = torch.ops.prims.convert_element_type.default(view_1335, torch.bfloat16); view_1335 = None + unsqueeze_36 = torch.ops.aten.unsqueeze.default(convert_element_type_610, 3); convert_element_type_610 = None + expand_36 = torch.ops.aten.expand.default(unsqueeze_36, [2, 8192, 1, 4, 128]); unsqueeze_36 = None + view_1336 = torch.ops.aten.view.default(expand_36, [2, 8192, 4, 128]); expand_36 = None + unsqueeze_37 = torch.ops.aten.unsqueeze.default(view_1330, 3); view_1330 = None + expand_37 = torch.ops.aten.expand.default(unsqueeze_37, [2, 8192, 1, 4, 128]); unsqueeze_37 = None + view_1337 = torch.ops.aten.view.default(expand_37, [2, 8192, 4, 128]); expand_37 = None + permute_201 = torch.ops.aten.permute.default(convert_element_type_609, [0, 2, 1, 3]); convert_element_type_609 = None + permute_202 = torch.ops.aten.permute.default(view_1336, [0, 2, 1, 3]); view_1336 = None + permute_203 = torch.ops.aten.permute.default(view_1337, [0, 2, 1, 3]); view_1337 = None + _scaled_dot_product_cudnn_attention_18 = torch.ops.aten._scaled_dot_product_cudnn_attention.default(permute_201, permute_202, permute_203, None, True, 0.0, True); permute_201 = permute_202 = permute_203 = None + getitem_818 = _scaled_dot_product_cudnn_attention_18[0] + getitem_819 = _scaled_dot_product_cudnn_attention_18[1] + getitem_824 = _scaled_dot_product_cudnn_attention_18[6] + getitem_825 = _scaled_dot_product_cudnn_attention_18[7]; _scaled_dot_product_cudnn_attention_18 = None + permute_204 = torch.ops.aten.permute.default(getitem_818, [0, 2, 1, 3]) + view_1338 = torch.ops.aten.view.default(permute_204, [2, 8192, -1]); permute_204 = None + 
convert_element_type_611 = torch.ops.prims.convert_element_type.default(primals_170, torch.bfloat16) + all_gather_into_tensor_204 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_611, 8, '0'); convert_element_type_611 = None + wait_tensor_241 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_204); all_gather_into_tensor_204 = None + permute_205 = torch.ops.aten.permute.default(wait_tensor_241, [1, 0]); wait_tensor_241 = None + view_1344 = torch.ops.aten.view.default(view_1338, [16384, 512]); view_1338 = None + mm_129 = torch.ops.aten.mm.default(view_1344, permute_205); view_1344 = permute_205 = None + view_1345 = torch.ops.aten.view.default(mm_129, [2, 8192, 4096]); mm_129 = None + split_82 = torch.ops.aten.split.Tensor(view_1345, 1024, 1); view_1345 = None + getitem_827 = split_82[0] + getitem_828 = split_82[1] + getitem_829 = split_82[2] + getitem_830 = split_82[3] + getitem_831 = split_82[4] + getitem_832 = split_82[5] + getitem_833 = split_82[6] + getitem_834 = split_82[7]; split_82 = None + cat_74 = torch.ops.aten.cat.default([getitem_827, getitem_828, getitem_829, getitem_830, getitem_831, getitem_832, getitem_833, getitem_834]); getitem_827 = getitem_828 = getitem_829 = getitem_830 = getitem_831 = getitem_832 = getitem_833 = getitem_834 = None + reduce_scatter_tensor_37 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_74, 'sum', 8, '1'); cat_74 = None + wait_tensor_242 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_37) + add_73 = torch.ops.aten.add.Tensor(add_71, wait_tensor_242); wait_tensor_242 = None + convert_element_type_614 = torch.ops.prims.convert_element_type.default(primals_171, torch.bfloat16) + all_gather_into_tensor_205 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_614, 8, '0'); convert_element_type_614 = None + wait_tensor_243 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_205); 
all_gather_into_tensor_205 = None + convert_element_type_615 = torch.ops.prims.convert_element_type.default(add_73, torch.float32) + pow_38 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_615, 2) + mean_37 = torch.ops.aten.mean.dim(pow_38, [2], True); pow_38 = None + add_74 = torch.ops.aten.add.Scalar(mean_37, 1e-05); mean_37 = None + rsqrt_37 = torch.ops.aten.rsqrt.default(add_74); add_74 = None + mul_148 = torch.ops.aten.mul.Tensor(convert_element_type_615, rsqrt_37); convert_element_type_615 = rsqrt_37 = None + mul_149 = torch.ops.aten.mul.Tensor(mul_148, wait_tensor_243); mul_148 = wait_tensor_243 = None + convert_element_type_616 = torch.ops.prims.convert_element_type.default(mul_149, torch.bfloat16); mul_149 = None + all_gather_into_tensor_206 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_616, 8, '1'); convert_element_type_616 = None + wait_tensor_244 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_206); all_gather_into_tensor_206 = None + split_83 = torch.ops.aten.split.Tensor(wait_tensor_244, 2); wait_tensor_244 = None + getitem_835 = split_83[0] + getitem_836 = split_83[1] + getitem_837 = split_83[2] + getitem_838 = split_83[3] + getitem_839 = split_83[4] + getitem_840 = split_83[5] + getitem_841 = split_83[6] + getitem_842 = split_83[7]; split_83 = None + cat_75 = torch.ops.aten.cat.default([getitem_835, getitem_836, getitem_837, getitem_838, getitem_839, getitem_840, getitem_841, getitem_842], 1); getitem_835 = getitem_836 = getitem_837 = getitem_838 = getitem_839 = getitem_840 = getitem_841 = getitem_842 = None + convert_element_type_617 = torch.ops.prims.convert_element_type.default(primals_172, torch.bfloat16) + all_gather_into_tensor_207 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_617, 8, '0'); convert_element_type_617 = None + wait_tensor_245 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_207); all_gather_into_tensor_207 = 
None + permute_206 = torch.ops.aten.permute.default(wait_tensor_245, [1, 0]); wait_tensor_245 = None + view_1356 = torch.ops.aten.view.default(cat_75, [16384, 4096]); cat_75 = None + mm_130 = torch.ops.aten.mm.default(view_1356, permute_206); permute_206 = None + view_1357 = torch.ops.aten.view.default(mm_130, [2, 8192, 1792]) + convert_element_type_620 = torch.ops.prims.convert_element_type.default(view_1357, torch.float32); view_1357 = None + sigmoid_18 = torch.ops.aten.sigmoid.default(convert_element_type_620) + mul_150 = torch.ops.aten.mul.Tensor(convert_element_type_620, sigmoid_18); convert_element_type_620 = sigmoid_18 = None + convert_element_type_621 = torch.ops.prims.convert_element_type.default(mul_150, torch.bfloat16); mul_150 = None + convert_element_type_622 = torch.ops.prims.convert_element_type.default(primals_173, torch.bfloat16) + all_gather_into_tensor_208 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_622, 8, '0'); convert_element_type_622 = None + wait_tensor_246 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_208); all_gather_into_tensor_208 = None + permute_207 = torch.ops.aten.permute.default(wait_tensor_246, [1, 0]); wait_tensor_246 = None + mm_131 = torch.ops.aten.mm.default(view_1356, permute_207); view_1356 = permute_207 = None + view_1364 = torch.ops.aten.view.default(mm_131, [2, 8192, 1792]); mm_131 = None + mul_151 = torch.ops.aten.mul.Tensor(convert_element_type_621, view_1364); convert_element_type_621 = view_1364 = None + convert_element_type_625 = torch.ops.prims.convert_element_type.default(primals_174, torch.bfloat16) + all_gather_into_tensor_209 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_625, 8, '0'); convert_element_type_625 = None + wait_tensor_247 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_209); all_gather_into_tensor_209 = None + permute_208 = torch.ops.aten.permute.default(wait_tensor_247, [1, 0]); 
wait_tensor_247 = None + view_1371 = torch.ops.aten.view.default(mul_151, [16384, 1792]); mul_151 = None + mm_132 = torch.ops.aten.mm.default(view_1371, permute_208); view_1371 = permute_208 = None + view_1372 = torch.ops.aten.view.default(mm_132, [2, 8192, 4096]); mm_132 = None + split_84 = torch.ops.aten.split.Tensor(view_1372, 1024, 1); view_1372 = None + getitem_843 = split_84[0] + getitem_844 = split_84[1] + getitem_845 = split_84[2] + getitem_846 = split_84[3] + getitem_847 = split_84[4] + getitem_848 = split_84[5] + getitem_849 = split_84[6] + getitem_850 = split_84[7]; split_84 = None + cat_76 = torch.ops.aten.cat.default([getitem_843, getitem_844, getitem_845, getitem_846, getitem_847, getitem_848, getitem_849, getitem_850]); getitem_843 = getitem_844 = getitem_845 = getitem_846 = getitem_847 = getitem_848 = getitem_849 = getitem_850 = None + reduce_scatter_tensor_38 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_76, 'sum', 8, '1'); cat_76 = None + wait_tensor_248 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_38); reduce_scatter_tensor_38 = None + add_75 = torch.ops.aten.add.Tensor(add_73, wait_tensor_248); add_73 = wait_tensor_248 = None + convert_element_type_628 = torch.ops.prims.convert_element_type.default(primals_175, torch.bfloat16) + all_gather_into_tensor_210 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_628, 8, '0'); convert_element_type_628 = None + wait_tensor_249 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_210); all_gather_into_tensor_210 = None + convert_element_type_629 = torch.ops.prims.convert_element_type.default(add_75, torch.float32) + pow_39 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_629, 2) + mean_38 = torch.ops.aten.mean.dim(pow_39, [2], True); pow_39 = None + add_76 = torch.ops.aten.add.Scalar(mean_38, 1e-05); mean_38 = None + rsqrt_38 = torch.ops.aten.rsqrt.default(add_76); add_76 = None + mul_152 = 
torch.ops.aten.mul.Tensor(convert_element_type_629, rsqrt_38); convert_element_type_629 = rsqrt_38 = None + mul_153 = torch.ops.aten.mul.Tensor(mul_152, wait_tensor_249); mul_152 = wait_tensor_249 = None + convert_element_type_630 = torch.ops.prims.convert_element_type.default(mul_153, torch.bfloat16); mul_153 = None + all_gather_into_tensor_211 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_630, 8, '1'); convert_element_type_630 = None + wait_tensor_250 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_211); all_gather_into_tensor_211 = None + split_85 = torch.ops.aten.split.Tensor(wait_tensor_250, 2); wait_tensor_250 = None + getitem_851 = split_85[0] + getitem_852 = split_85[1] + getitem_853 = split_85[2] + getitem_854 = split_85[3] + getitem_855 = split_85[4] + getitem_856 = split_85[5] + getitem_857 = split_85[6] + getitem_858 = split_85[7]; split_85 = None + cat_77 = torch.ops.aten.cat.default([getitem_851, getitem_852, getitem_853, getitem_854, getitem_855, getitem_856, getitem_857, getitem_858], 1); getitem_851 = getitem_852 = getitem_853 = getitem_854 = getitem_855 = getitem_856 = getitem_857 = getitem_858 = None + convert_element_type_631 = torch.ops.prims.convert_element_type.default(primals_176, torch.bfloat16) + all_gather_into_tensor_212 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_631, 8, '0'); convert_element_type_631 = None + wait_tensor_251 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_212); all_gather_into_tensor_212 = None + permute_209 = torch.ops.aten.permute.default(wait_tensor_251, [1, 0]); wait_tensor_251 = None + view_1383 = torch.ops.aten.view.default(cat_77, [16384, 4096]); cat_77 = None + mm_133 = torch.ops.aten.mm.default(view_1383, permute_209); permute_209 = None + view_1384 = torch.ops.aten.view.default(mm_133, [2, 8192, 512]) + convert_element_type_634 = torch.ops.prims.convert_element_type.default(primals_177, 
torch.bfloat16) + all_gather_into_tensor_213 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_634, 8, '0'); convert_element_type_634 = None + wait_tensor_252 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_213); all_gather_into_tensor_213 = None + permute_210 = torch.ops.aten.permute.default(wait_tensor_252, [1, 0]); wait_tensor_252 = None + mm_134 = torch.ops.aten.mm.default(view_1383, permute_210); permute_210 = None + view_1391 = torch.ops.aten.view.default(mm_134, [2, 8192, 128]); mm_134 = None + convert_element_type_637 = torch.ops.prims.convert_element_type.default(primals_178, torch.bfloat16) + all_gather_into_tensor_214 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_637, 8, '0'); convert_element_type_637 = None + wait_tensor_253 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_214); all_gather_into_tensor_214 = None + permute_211 = torch.ops.aten.permute.default(wait_tensor_253, [1, 0]); wait_tensor_253 = None + mm_135 = torch.ops.aten.mm.default(view_1383, permute_211); view_1383 = permute_211 = None + view_1398 = torch.ops.aten.view.default(mm_135, [2, 8192, 128]) + view_1400 = torch.ops.aten.view.default(view_1384, [2, 8192, -1, 128]); view_1384 = None + view_1401 = torch.ops.aten.view.default(view_1391, [2, 8192, -1, 128]); view_1391 = None + view_1402 = torch.ops.aten.view.default(view_1398, [2, 8192, -1, 128]); view_1398 = None + convert_element_type_640 = torch.ops.prims.convert_element_type.default(view_1400, torch.float32); view_1400 = None + view_1403 = torch.ops.aten.view.default(convert_element_type_640, [2, 8192, 4, -1, 2]); convert_element_type_640 = None + view_as_complex_38 = torch.ops.aten.view_as_complex.default(view_1403); view_1403 = None + convert_element_type_641 = torch.ops.prims.convert_element_type.default(view_1401, torch.float32); view_1401 = None + view_1404 = torch.ops.aten.view.default(convert_element_type_641, [2, 
8192, 1, -1, 2]); convert_element_type_641 = None + view_as_complex_39 = torch.ops.aten.view_as_complex.default(view_1404); view_1404 = None + mul_154 = torch.ops.aten.mul.Tensor(view_as_complex_38, view_37); view_as_complex_38 = None + view_as_real_38 = torch.ops.aten.view_as_real.default(mul_154); mul_154 = None + view_1406 = torch.ops.aten.view.default(view_as_real_38, [2, 8192, 4, 128]); view_as_real_38 = None + mul_155 = torch.ops.aten.mul.Tensor(view_as_complex_39, view_37); view_as_complex_39 = None + view_as_real_39 = torch.ops.aten.view_as_real.default(mul_155); mul_155 = None + view_1407 = torch.ops.aten.view.default(view_as_real_39, [2, 8192, 1, 128]); view_as_real_39 = None + convert_element_type_642 = torch.ops.prims.convert_element_type.default(view_1406, torch.bfloat16); view_1406 = None + convert_element_type_643 = torch.ops.prims.convert_element_type.default(view_1407, torch.bfloat16); view_1407 = None + unsqueeze_38 = torch.ops.aten.unsqueeze.default(convert_element_type_643, 3); convert_element_type_643 = None + expand_38 = torch.ops.aten.expand.default(unsqueeze_38, [2, 8192, 1, 4, 128]); unsqueeze_38 = None + view_1408 = torch.ops.aten.view.default(expand_38, [2, 8192, 4, 128]); expand_38 = None + unsqueeze_39 = torch.ops.aten.unsqueeze.default(view_1402, 3); view_1402 = None + expand_39 = torch.ops.aten.expand.default(unsqueeze_39, [2, 8192, 1, 4, 128]); unsqueeze_39 = None + view_1409 = torch.ops.aten.view.default(expand_39, [2, 8192, 4, 128]); expand_39 = None + permute_212 = torch.ops.aten.permute.default(convert_element_type_642, [0, 2, 1, 3]); convert_element_type_642 = None + permute_213 = torch.ops.aten.permute.default(view_1408, [0, 2, 1, 3]); view_1408 = None + permute_214 = torch.ops.aten.permute.default(view_1409, [0, 2, 1, 3]); view_1409 = None + _scaled_dot_product_cudnn_attention_19 = torch.ops.aten._scaled_dot_product_cudnn_attention.default(permute_212, permute_213, permute_214, None, True, 0.0, True); permute_212 = permute_213 
= permute_214 = None + getitem_859 = _scaled_dot_product_cudnn_attention_19[0] + getitem_860 = _scaled_dot_product_cudnn_attention_19[1] + getitem_865 = _scaled_dot_product_cudnn_attention_19[6] + getitem_866 = _scaled_dot_product_cudnn_attention_19[7]; _scaled_dot_product_cudnn_attention_19 = None + permute_215 = torch.ops.aten.permute.default(getitem_859, [0, 2, 1, 3]) + view_1410 = torch.ops.aten.view.default(permute_215, [2, 8192, -1]); permute_215 = None + convert_element_type_644 = torch.ops.prims.convert_element_type.default(primals_179, torch.bfloat16) + all_gather_into_tensor_215 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_644, 8, '0'); convert_element_type_644 = None + wait_tensor_254 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_215); all_gather_into_tensor_215 = None + permute_216 = torch.ops.aten.permute.default(wait_tensor_254, [1, 0]); wait_tensor_254 = None + view_1416 = torch.ops.aten.view.default(view_1410, [16384, 512]); view_1410 = None + mm_136 = torch.ops.aten.mm.default(view_1416, permute_216); view_1416 = permute_216 = None + view_1417 = torch.ops.aten.view.default(mm_136, [2, 8192, 4096]); mm_136 = None + split_86 = torch.ops.aten.split.Tensor(view_1417, 1024, 1); view_1417 = None + getitem_868 = split_86[0] + getitem_869 = split_86[1] + getitem_870 = split_86[2] + getitem_871 = split_86[3] + getitem_872 = split_86[4] + getitem_873 = split_86[5] + getitem_874 = split_86[6] + getitem_875 = split_86[7]; split_86 = None + cat_78 = torch.ops.aten.cat.default([getitem_868, getitem_869, getitem_870, getitem_871, getitem_872, getitem_873, getitem_874, getitem_875]); getitem_868 = getitem_869 = getitem_870 = getitem_871 = getitem_872 = getitem_873 = getitem_874 = getitem_875 = None + reduce_scatter_tensor_39 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_78, 'sum', 8, '1'); cat_78 = None + wait_tensor_255 = 
torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_39) + add_77 = torch.ops.aten.add.Tensor(add_75, wait_tensor_255); wait_tensor_255 = None + convert_element_type_647 = torch.ops.prims.convert_element_type.default(primals_180, torch.bfloat16) + all_gather_into_tensor_216 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_647, 8, '0'); convert_element_type_647 = None + wait_tensor_256 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_216); all_gather_into_tensor_216 = None + convert_element_type_648 = torch.ops.prims.convert_element_type.default(add_77, torch.float32) + pow_40 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_648, 2) + mean_39 = torch.ops.aten.mean.dim(pow_40, [2], True); pow_40 = None + add_78 = torch.ops.aten.add.Scalar(mean_39, 1e-05); mean_39 = None + rsqrt_39 = torch.ops.aten.rsqrt.default(add_78); add_78 = None + mul_156 = torch.ops.aten.mul.Tensor(convert_element_type_648, rsqrt_39); convert_element_type_648 = rsqrt_39 = None + mul_157 = torch.ops.aten.mul.Tensor(mul_156, wait_tensor_256); mul_156 = wait_tensor_256 = None + convert_element_type_649 = torch.ops.prims.convert_element_type.default(mul_157, torch.bfloat16); mul_157 = None + all_gather_into_tensor_217 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_649, 8, '1'); convert_element_type_649 = None + wait_tensor_257 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_217); all_gather_into_tensor_217 = None + split_87 = torch.ops.aten.split.Tensor(wait_tensor_257, 2); wait_tensor_257 = None + getitem_876 = split_87[0] + getitem_877 = split_87[1] + getitem_878 = split_87[2] + getitem_879 = split_87[3] + getitem_880 = split_87[4] + getitem_881 = split_87[5] + getitem_882 = split_87[6] + getitem_883 = split_87[7]; split_87 = None + cat_79 = torch.ops.aten.cat.default([getitem_876, getitem_877, getitem_878, getitem_879, getitem_880, getitem_881, getitem_882, 
getitem_883], 1); getitem_876 = getitem_877 = getitem_878 = getitem_879 = getitem_880 = getitem_881 = getitem_882 = getitem_883 = None + convert_element_type_650 = torch.ops.prims.convert_element_type.default(primals_181, torch.bfloat16) + all_gather_into_tensor_218 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_650, 8, '0'); convert_element_type_650 = None + wait_tensor_258 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_218); all_gather_into_tensor_218 = None + permute_217 = torch.ops.aten.permute.default(wait_tensor_258, [1, 0]); wait_tensor_258 = None + view_1428 = torch.ops.aten.view.default(cat_79, [16384, 4096]); cat_79 = None + mm_137 = torch.ops.aten.mm.default(view_1428, permute_217); permute_217 = None + view_1429 = torch.ops.aten.view.default(mm_137, [2, 8192, 1792]) + convert_element_type_653 = torch.ops.prims.convert_element_type.default(view_1429, torch.float32); view_1429 = None + sigmoid_19 = torch.ops.aten.sigmoid.default(convert_element_type_653) + mul_158 = torch.ops.aten.mul.Tensor(convert_element_type_653, sigmoid_19); convert_element_type_653 = sigmoid_19 = None + convert_element_type_654 = torch.ops.prims.convert_element_type.default(mul_158, torch.bfloat16); mul_158 = None + convert_element_type_655 = torch.ops.prims.convert_element_type.default(primals_182, torch.bfloat16) + all_gather_into_tensor_219 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_655, 8, '0'); convert_element_type_655 = None + wait_tensor_259 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_219); all_gather_into_tensor_219 = None + permute_218 = torch.ops.aten.permute.default(wait_tensor_259, [1, 0]); wait_tensor_259 = None + mm_138 = torch.ops.aten.mm.default(view_1428, permute_218); view_1428 = permute_218 = None + view_1436 = torch.ops.aten.view.default(mm_138, [2, 8192, 1792]); mm_138 = None + mul_159 = torch.ops.aten.mul.Tensor(convert_element_type_654, 
view_1436); convert_element_type_654 = view_1436 = None + convert_element_type_658 = torch.ops.prims.convert_element_type.default(primals_183, torch.bfloat16) + all_gather_into_tensor_220 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_658, 8, '0'); convert_element_type_658 = None + wait_tensor_260 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_220); all_gather_into_tensor_220 = None + permute_219 = torch.ops.aten.permute.default(wait_tensor_260, [1, 0]); wait_tensor_260 = None + view_1443 = torch.ops.aten.view.default(mul_159, [16384, 1792]); mul_159 = None + mm_139 = torch.ops.aten.mm.default(view_1443, permute_219); view_1443 = permute_219 = None + view_1444 = torch.ops.aten.view.default(mm_139, [2, 8192, 4096]); mm_139 = None + split_88 = torch.ops.aten.split.Tensor(view_1444, 1024, 1); view_1444 = None + getitem_884 = split_88[0] + getitem_885 = split_88[1] + getitem_886 = split_88[2] + getitem_887 = split_88[3] + getitem_888 = split_88[4] + getitem_889 = split_88[5] + getitem_890 = split_88[6] + getitem_891 = split_88[7]; split_88 = None + cat_80 = torch.ops.aten.cat.default([getitem_884, getitem_885, getitem_886, getitem_887, getitem_888, getitem_889, getitem_890, getitem_891]); getitem_884 = getitem_885 = getitem_886 = getitem_887 = getitem_888 = getitem_889 = getitem_890 = getitem_891 = None + reduce_scatter_tensor_40 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_80, 'sum', 8, '1'); cat_80 = None + wait_tensor_261 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_40); reduce_scatter_tensor_40 = None + add_79 = torch.ops.aten.add.Tensor(add_77, wait_tensor_261); add_77 = wait_tensor_261 = None + convert_element_type_661 = torch.ops.prims.convert_element_type.default(primals_184, torch.bfloat16) + all_gather_into_tensor_221 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_661, 8, '0'); convert_element_type_661 = None + 
wait_tensor_262 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_221); all_gather_into_tensor_221 = None + convert_element_type_662 = torch.ops.prims.convert_element_type.default(add_79, torch.float32) + pow_41 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_662, 2) + mean_40 = torch.ops.aten.mean.dim(pow_41, [2], True); pow_41 = None + add_80 = torch.ops.aten.add.Scalar(mean_40, 1e-05); mean_40 = None + rsqrt_40 = torch.ops.aten.rsqrt.default(add_80); add_80 = None + mul_160 = torch.ops.aten.mul.Tensor(convert_element_type_662, rsqrt_40); convert_element_type_662 = rsqrt_40 = None + mul_161 = torch.ops.aten.mul.Tensor(mul_160, wait_tensor_262); mul_160 = wait_tensor_262 = None + convert_element_type_663 = torch.ops.prims.convert_element_type.default(mul_161, torch.bfloat16); mul_161 = None + all_gather_into_tensor_222 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_663, 8, '1'); convert_element_type_663 = None + wait_tensor_263 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_222); all_gather_into_tensor_222 = None + split_89 = torch.ops.aten.split.Tensor(wait_tensor_263, 2); wait_tensor_263 = None + getitem_892 = split_89[0] + getitem_893 = split_89[1] + getitem_894 = split_89[2] + getitem_895 = split_89[3] + getitem_896 = split_89[4] + getitem_897 = split_89[5] + getitem_898 = split_89[6] + getitem_899 = split_89[7]; split_89 = None + cat_81 = torch.ops.aten.cat.default([getitem_892, getitem_893, getitem_894, getitem_895, getitem_896, getitem_897, getitem_898, getitem_899], 1); getitem_892 = getitem_893 = getitem_894 = getitem_895 = getitem_896 = getitem_897 = getitem_898 = getitem_899 = None + convert_element_type_664 = torch.ops.prims.convert_element_type.default(primals_185, torch.bfloat16) + all_gather_into_tensor_223 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_664, 8, '0'); convert_element_type_664 = None + wait_tensor_264 = 
torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_223); all_gather_into_tensor_223 = None + permute_220 = torch.ops.aten.permute.default(wait_tensor_264, [1, 0]); wait_tensor_264 = None + view_1455 = torch.ops.aten.view.default(cat_81, [16384, 4096]); cat_81 = None + mm_140 = torch.ops.aten.mm.default(view_1455, permute_220); permute_220 = None + view_1456 = torch.ops.aten.view.default(mm_140, [2, 8192, 512]) + convert_element_type_667 = torch.ops.prims.convert_element_type.default(primals_186, torch.bfloat16) + all_gather_into_tensor_224 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_667, 8, '0'); convert_element_type_667 = None + wait_tensor_265 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_224); all_gather_into_tensor_224 = None + permute_221 = torch.ops.aten.permute.default(wait_tensor_265, [1, 0]); wait_tensor_265 = None + mm_141 = torch.ops.aten.mm.default(view_1455, permute_221); permute_221 = None + view_1463 = torch.ops.aten.view.default(mm_141, [2, 8192, 128]); mm_141 = None + convert_element_type_670 = torch.ops.prims.convert_element_type.default(primals_187, torch.bfloat16) + all_gather_into_tensor_225 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_670, 8, '0'); convert_element_type_670 = None + wait_tensor_266 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_225); all_gather_into_tensor_225 = None + permute_222 = torch.ops.aten.permute.default(wait_tensor_266, [1, 0]); wait_tensor_266 = None + mm_142 = torch.ops.aten.mm.default(view_1455, permute_222); view_1455 = permute_222 = None + view_1470 = torch.ops.aten.view.default(mm_142, [2, 8192, 128]) + view_1472 = torch.ops.aten.view.default(view_1456, [2, 8192, -1, 128]); view_1456 = None + view_1473 = torch.ops.aten.view.default(view_1463, [2, 8192, -1, 128]); view_1463 = None + view_1474 = torch.ops.aten.view.default(view_1470, [2, 8192, -1, 128]); view_1470 = None + 
convert_element_type_673 = torch.ops.prims.convert_element_type.default(view_1472, torch.float32); view_1472 = None + view_1475 = torch.ops.aten.view.default(convert_element_type_673, [2, 8192, 4, -1, 2]); convert_element_type_673 = None + view_as_complex_40 = torch.ops.aten.view_as_complex.default(view_1475); view_1475 = None + convert_element_type_674 = torch.ops.prims.convert_element_type.default(view_1473, torch.float32); view_1473 = None + view_1476 = torch.ops.aten.view.default(convert_element_type_674, [2, 8192, 1, -1, 2]); convert_element_type_674 = None + view_as_complex_41 = torch.ops.aten.view_as_complex.default(view_1476); view_1476 = None + mul_162 = torch.ops.aten.mul.Tensor(view_as_complex_40, view_37); view_as_complex_40 = None + view_as_real_40 = torch.ops.aten.view_as_real.default(mul_162); mul_162 = None + view_1478 = torch.ops.aten.view.default(view_as_real_40, [2, 8192, 4, 128]); view_as_real_40 = None + mul_163 = torch.ops.aten.mul.Tensor(view_as_complex_41, view_37); view_as_complex_41 = None + view_as_real_41 = torch.ops.aten.view_as_real.default(mul_163); mul_163 = None + view_1479 = torch.ops.aten.view.default(view_as_real_41, [2, 8192, 1, 128]); view_as_real_41 = None + convert_element_type_675 = torch.ops.prims.convert_element_type.default(view_1478, torch.bfloat16); view_1478 = None + convert_element_type_676 = torch.ops.prims.convert_element_type.default(view_1479, torch.bfloat16); view_1479 = None + unsqueeze_40 = torch.ops.aten.unsqueeze.default(convert_element_type_676, 3); convert_element_type_676 = None + expand_40 = torch.ops.aten.expand.default(unsqueeze_40, [2, 8192, 1, 4, 128]); unsqueeze_40 = None + view_1480 = torch.ops.aten.view.default(expand_40, [2, 8192, 4, 128]); expand_40 = None + unsqueeze_41 = torch.ops.aten.unsqueeze.default(view_1474, 3); view_1474 = None + expand_41 = torch.ops.aten.expand.default(unsqueeze_41, [2, 8192, 1, 4, 128]); unsqueeze_41 = None + view_1481 = torch.ops.aten.view.default(expand_41, [2, 
8192, 4, 128]); expand_41 = None + permute_223 = torch.ops.aten.permute.default(convert_element_type_675, [0, 2, 1, 3]); convert_element_type_675 = None + permute_224 = torch.ops.aten.permute.default(view_1480, [0, 2, 1, 3]); view_1480 = None + permute_225 = torch.ops.aten.permute.default(view_1481, [0, 2, 1, 3]); view_1481 = None + _scaled_dot_product_cudnn_attention_20 = torch.ops.aten._scaled_dot_product_cudnn_attention.default(permute_223, permute_224, permute_225, None, True, 0.0, True); permute_223 = permute_224 = permute_225 = None + getitem_900 = _scaled_dot_product_cudnn_attention_20[0] + getitem_901 = _scaled_dot_product_cudnn_attention_20[1] + getitem_906 = _scaled_dot_product_cudnn_attention_20[6] + getitem_907 = _scaled_dot_product_cudnn_attention_20[7]; _scaled_dot_product_cudnn_attention_20 = None + permute_226 = torch.ops.aten.permute.default(getitem_900, [0, 2, 1, 3]) + view_1482 = torch.ops.aten.view.default(permute_226, [2, 8192, -1]); permute_226 = None + convert_element_type_677 = torch.ops.prims.convert_element_type.default(primals_188, torch.bfloat16) + all_gather_into_tensor_226 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_677, 8, '0'); convert_element_type_677 = None + wait_tensor_267 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_226); all_gather_into_tensor_226 = None + permute_227 = torch.ops.aten.permute.default(wait_tensor_267, [1, 0]); wait_tensor_267 = None + view_1488 = torch.ops.aten.view.default(view_1482, [16384, 512]); view_1482 = None + mm_143 = torch.ops.aten.mm.default(view_1488, permute_227); view_1488 = permute_227 = None + view_1489 = torch.ops.aten.view.default(mm_143, [2, 8192, 4096]); mm_143 = None + split_90 = torch.ops.aten.split.Tensor(view_1489, 1024, 1); view_1489 = None + getitem_909 = split_90[0] + getitem_910 = split_90[1] + getitem_911 = split_90[2] + getitem_912 = split_90[3] + getitem_913 = split_90[4] + getitem_914 = split_90[5] + getitem_915 = 
split_90[6] + getitem_916 = split_90[7]; split_90 = None + cat_82 = torch.ops.aten.cat.default([getitem_909, getitem_910, getitem_911, getitem_912, getitem_913, getitem_914, getitem_915, getitem_916]); getitem_909 = getitem_910 = getitem_911 = getitem_912 = getitem_913 = getitem_914 = getitem_915 = getitem_916 = None + reduce_scatter_tensor_41 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_82, 'sum', 8, '1'); cat_82 = None + wait_tensor_268 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_41) + add_81 = torch.ops.aten.add.Tensor(add_79, wait_tensor_268); wait_tensor_268 = None + convert_element_type_680 = torch.ops.prims.convert_element_type.default(primals_189, torch.bfloat16) + all_gather_into_tensor_227 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_680, 8, '0'); convert_element_type_680 = None + wait_tensor_269 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_227); all_gather_into_tensor_227 = None + convert_element_type_681 = torch.ops.prims.convert_element_type.default(add_81, torch.float32) + pow_42 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_681, 2) + mean_41 = torch.ops.aten.mean.dim(pow_42, [2], True); pow_42 = None + add_82 = torch.ops.aten.add.Scalar(mean_41, 1e-05); mean_41 = None + rsqrt_41 = torch.ops.aten.rsqrt.default(add_82); add_82 = None + mul_164 = torch.ops.aten.mul.Tensor(convert_element_type_681, rsqrt_41); convert_element_type_681 = rsqrt_41 = None + mul_165 = torch.ops.aten.mul.Tensor(mul_164, wait_tensor_269); mul_164 = wait_tensor_269 = None + convert_element_type_682 = torch.ops.prims.convert_element_type.default(mul_165, torch.bfloat16); mul_165 = None + all_gather_into_tensor_228 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_682, 8, '1'); convert_element_type_682 = None + wait_tensor_270 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_228); all_gather_into_tensor_228 = 
None + split_91 = torch.ops.aten.split.Tensor(wait_tensor_270, 2); wait_tensor_270 = None + getitem_917 = split_91[0] + getitem_918 = split_91[1] + getitem_919 = split_91[2] + getitem_920 = split_91[3] + getitem_921 = split_91[4] + getitem_922 = split_91[5] + getitem_923 = split_91[6] + getitem_924 = split_91[7]; split_91 = None + cat_83 = torch.ops.aten.cat.default([getitem_917, getitem_918, getitem_919, getitem_920, getitem_921, getitem_922, getitem_923, getitem_924], 1); getitem_917 = getitem_918 = getitem_919 = getitem_920 = getitem_921 = getitem_922 = getitem_923 = getitem_924 = None + convert_element_type_683 = torch.ops.prims.convert_element_type.default(primals_190, torch.bfloat16) + all_gather_into_tensor_229 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_683, 8, '0'); convert_element_type_683 = None + wait_tensor_271 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_229); all_gather_into_tensor_229 = None + permute_228 = torch.ops.aten.permute.default(wait_tensor_271, [1, 0]); wait_tensor_271 = None + view_1500 = torch.ops.aten.view.default(cat_83, [16384, 4096]); cat_83 = None + mm_144 = torch.ops.aten.mm.default(view_1500, permute_228); permute_228 = None + view_1501 = torch.ops.aten.view.default(mm_144, [2, 8192, 1792]) + convert_element_type_686 = torch.ops.prims.convert_element_type.default(view_1501, torch.float32); view_1501 = None + sigmoid_20 = torch.ops.aten.sigmoid.default(convert_element_type_686) + mul_166 = torch.ops.aten.mul.Tensor(convert_element_type_686, sigmoid_20); convert_element_type_686 = sigmoid_20 = None + convert_element_type_687 = torch.ops.prims.convert_element_type.default(mul_166, torch.bfloat16); mul_166 = None + convert_element_type_688 = torch.ops.prims.convert_element_type.default(primals_191, torch.bfloat16) + all_gather_into_tensor_230 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_688, 8, '0'); convert_element_type_688 = None + 
wait_tensor_272 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_230); all_gather_into_tensor_230 = None + permute_229 = torch.ops.aten.permute.default(wait_tensor_272, [1, 0]); wait_tensor_272 = None + mm_145 = torch.ops.aten.mm.default(view_1500, permute_229); view_1500 = permute_229 = None + view_1508 = torch.ops.aten.view.default(mm_145, [2, 8192, 1792]); mm_145 = None + mul_167 = torch.ops.aten.mul.Tensor(convert_element_type_687, view_1508); convert_element_type_687 = view_1508 = None + convert_element_type_691 = torch.ops.prims.convert_element_type.default(primals_192, torch.bfloat16) + all_gather_into_tensor_231 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_691, 8, '0'); convert_element_type_691 = None + wait_tensor_273 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_231); all_gather_into_tensor_231 = None + permute_230 = torch.ops.aten.permute.default(wait_tensor_273, [1, 0]); wait_tensor_273 = None + view_1515 = torch.ops.aten.view.default(mul_167, [16384, 1792]); mul_167 = None + mm_146 = torch.ops.aten.mm.default(view_1515, permute_230); view_1515 = permute_230 = None + view_1516 = torch.ops.aten.view.default(mm_146, [2, 8192, 4096]); mm_146 = None + split_92 = torch.ops.aten.split.Tensor(view_1516, 1024, 1); view_1516 = None + getitem_925 = split_92[0] + getitem_926 = split_92[1] + getitem_927 = split_92[2] + getitem_928 = split_92[3] + getitem_929 = split_92[4] + getitem_930 = split_92[5] + getitem_931 = split_92[6] + getitem_932 = split_92[7]; split_92 = None + cat_84 = torch.ops.aten.cat.default([getitem_925, getitem_926, getitem_927, getitem_928, getitem_929, getitem_930, getitem_931, getitem_932]); getitem_925 = getitem_926 = getitem_927 = getitem_928 = getitem_929 = getitem_930 = getitem_931 = getitem_932 = None + reduce_scatter_tensor_42 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_84, 'sum', 8, '1'); cat_84 = None + wait_tensor_274 = 
torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_42); reduce_scatter_tensor_42 = None + add_83 = torch.ops.aten.add.Tensor(add_81, wait_tensor_274); add_81 = wait_tensor_274 = None + convert_element_type_694 = torch.ops.prims.convert_element_type.default(primals_193, torch.bfloat16) + all_gather_into_tensor_232 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_694, 8, '0'); convert_element_type_694 = None + wait_tensor_275 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_232); all_gather_into_tensor_232 = None + convert_element_type_695 = torch.ops.prims.convert_element_type.default(add_83, torch.float32) + pow_43 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_695, 2) + mean_42 = torch.ops.aten.mean.dim(pow_43, [2], True); pow_43 = None + add_84 = torch.ops.aten.add.Scalar(mean_42, 1e-05); mean_42 = None + rsqrt_42 = torch.ops.aten.rsqrt.default(add_84); add_84 = None + mul_168 = torch.ops.aten.mul.Tensor(convert_element_type_695, rsqrt_42); convert_element_type_695 = rsqrt_42 = None + mul_169 = torch.ops.aten.mul.Tensor(mul_168, wait_tensor_275); mul_168 = wait_tensor_275 = None + convert_element_type_696 = torch.ops.prims.convert_element_type.default(mul_169, torch.bfloat16); mul_169 = None + all_gather_into_tensor_233 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_696, 8, '1'); convert_element_type_696 = None + wait_tensor_276 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_233); all_gather_into_tensor_233 = None + split_93 = torch.ops.aten.split.Tensor(wait_tensor_276, 2); wait_tensor_276 = None + getitem_933 = split_93[0] + getitem_934 = split_93[1] + getitem_935 = split_93[2] + getitem_936 = split_93[3] + getitem_937 = split_93[4] + getitem_938 = split_93[5] + getitem_939 = split_93[6] + getitem_940 = split_93[7]; split_93 = None + cat_85 = torch.ops.aten.cat.default([getitem_933, getitem_934, getitem_935, getitem_936, 
getitem_937, getitem_938, getitem_939, getitem_940], 1); getitem_933 = getitem_934 = getitem_935 = getitem_936 = getitem_937 = getitem_938 = getitem_939 = getitem_940 = None + convert_element_type_697 = torch.ops.prims.convert_element_type.default(primals_194, torch.bfloat16) + all_gather_into_tensor_234 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_697, 8, '0'); convert_element_type_697 = None + wait_tensor_277 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_234); all_gather_into_tensor_234 = None + permute_231 = torch.ops.aten.permute.default(wait_tensor_277, [1, 0]); wait_tensor_277 = None + view_1527 = torch.ops.aten.view.default(cat_85, [16384, 4096]); cat_85 = None + mm_147 = torch.ops.aten.mm.default(view_1527, permute_231); permute_231 = None + view_1528 = torch.ops.aten.view.default(mm_147, [2, 8192, 512]) + convert_element_type_700 = torch.ops.prims.convert_element_type.default(primals_195, torch.bfloat16) + all_gather_into_tensor_235 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_700, 8, '0'); convert_element_type_700 = None + wait_tensor_278 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_235); all_gather_into_tensor_235 = None + permute_232 = torch.ops.aten.permute.default(wait_tensor_278, [1, 0]); wait_tensor_278 = None + mm_148 = torch.ops.aten.mm.default(view_1527, permute_232); permute_232 = None + view_1535 = torch.ops.aten.view.default(mm_148, [2, 8192, 128]); mm_148 = None + convert_element_type_703 = torch.ops.prims.convert_element_type.default(primals_196, torch.bfloat16) + all_gather_into_tensor_236 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_703, 8, '0'); convert_element_type_703 = None + wait_tensor_279 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_236); all_gather_into_tensor_236 = None + permute_233 = torch.ops.aten.permute.default(wait_tensor_279, [1, 0]); 
wait_tensor_279 = None + mm_149 = torch.ops.aten.mm.default(view_1527, permute_233); view_1527 = permute_233 = None + view_1542 = torch.ops.aten.view.default(mm_149, [2, 8192, 128]) + view_1544 = torch.ops.aten.view.default(view_1528, [2, 8192, -1, 128]); view_1528 = None + view_1545 = torch.ops.aten.view.default(view_1535, [2, 8192, -1, 128]); view_1535 = None + view_1546 = torch.ops.aten.view.default(view_1542, [2, 8192, -1, 128]); view_1542 = None + convert_element_type_706 = torch.ops.prims.convert_element_type.default(view_1544, torch.float32); view_1544 = None + view_1547 = torch.ops.aten.view.default(convert_element_type_706, [2, 8192, 4, -1, 2]); convert_element_type_706 = None + view_as_complex_42 = torch.ops.aten.view_as_complex.default(view_1547); view_1547 = None + convert_element_type_707 = torch.ops.prims.convert_element_type.default(view_1545, torch.float32); view_1545 = None + view_1548 = torch.ops.aten.view.default(convert_element_type_707, [2, 8192, 1, -1, 2]); convert_element_type_707 = None + view_as_complex_43 = torch.ops.aten.view_as_complex.default(view_1548); view_1548 = None + mul_170 = torch.ops.aten.mul.Tensor(view_as_complex_42, view_37); view_as_complex_42 = None + view_as_real_42 = torch.ops.aten.view_as_real.default(mul_170); mul_170 = None + view_1550 = torch.ops.aten.view.default(view_as_real_42, [2, 8192, 4, 128]); view_as_real_42 = None + mul_171 = torch.ops.aten.mul.Tensor(view_as_complex_43, view_37); view_as_complex_43 = None + view_as_real_43 = torch.ops.aten.view_as_real.default(mul_171); mul_171 = None + view_1551 = torch.ops.aten.view.default(view_as_real_43, [2, 8192, 1, 128]); view_as_real_43 = None + convert_element_type_708 = torch.ops.prims.convert_element_type.default(view_1550, torch.bfloat16); view_1550 = None + convert_element_type_709 = torch.ops.prims.convert_element_type.default(view_1551, torch.bfloat16); view_1551 = None + unsqueeze_42 = torch.ops.aten.unsqueeze.default(convert_element_type_709, 3); 
convert_element_type_709 = None + expand_42 = torch.ops.aten.expand.default(unsqueeze_42, [2, 8192, 1, 4, 128]); unsqueeze_42 = None + view_1552 = torch.ops.aten.view.default(expand_42, [2, 8192, 4, 128]); expand_42 = None + unsqueeze_43 = torch.ops.aten.unsqueeze.default(view_1546, 3); view_1546 = None + expand_43 = torch.ops.aten.expand.default(unsqueeze_43, [2, 8192, 1, 4, 128]); unsqueeze_43 = None + view_1553 = torch.ops.aten.view.default(expand_43, [2, 8192, 4, 128]); expand_43 = None + permute_234 = torch.ops.aten.permute.default(convert_element_type_708, [0, 2, 1, 3]); convert_element_type_708 = None + permute_235 = torch.ops.aten.permute.default(view_1552, [0, 2, 1, 3]); view_1552 = None + permute_236 = torch.ops.aten.permute.default(view_1553, [0, 2, 1, 3]); view_1553 = None + _scaled_dot_product_cudnn_attention_21 = torch.ops.aten._scaled_dot_product_cudnn_attention.default(permute_234, permute_235, permute_236, None, True, 0.0, True); permute_234 = permute_235 = permute_236 = None + getitem_941 = _scaled_dot_product_cudnn_attention_21[0] + getitem_942 = _scaled_dot_product_cudnn_attention_21[1] + getitem_947 = _scaled_dot_product_cudnn_attention_21[6] + getitem_948 = _scaled_dot_product_cudnn_attention_21[7]; _scaled_dot_product_cudnn_attention_21 = None + permute_237 = torch.ops.aten.permute.default(getitem_941, [0, 2, 1, 3]) + view_1554 = torch.ops.aten.view.default(permute_237, [2, 8192, -1]); permute_237 = None + convert_element_type_710 = torch.ops.prims.convert_element_type.default(primals_197, torch.bfloat16) + all_gather_into_tensor_237 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_710, 8, '0'); convert_element_type_710 = None + wait_tensor_280 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_237); all_gather_into_tensor_237 = None + permute_238 = torch.ops.aten.permute.default(wait_tensor_280, [1, 0]); wait_tensor_280 = None + view_1560 = torch.ops.aten.view.default(view_1554, [16384, 
512]); view_1554 = None + mm_150 = torch.ops.aten.mm.default(view_1560, permute_238); view_1560 = permute_238 = None + view_1561 = torch.ops.aten.view.default(mm_150, [2, 8192, 4096]); mm_150 = None + split_94 = torch.ops.aten.split.Tensor(view_1561, 1024, 1); view_1561 = None + getitem_950 = split_94[0] + getitem_951 = split_94[1] + getitem_952 = split_94[2] + getitem_953 = split_94[3] + getitem_954 = split_94[4] + getitem_955 = split_94[5] + getitem_956 = split_94[6] + getitem_957 = split_94[7]; split_94 = None + cat_86 = torch.ops.aten.cat.default([getitem_950, getitem_951, getitem_952, getitem_953, getitem_954, getitem_955, getitem_956, getitem_957]); getitem_950 = getitem_951 = getitem_952 = getitem_953 = getitem_954 = getitem_955 = getitem_956 = getitem_957 = None + reduce_scatter_tensor_43 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_86, 'sum', 8, '1'); cat_86 = None + wait_tensor_281 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_43) + add_85 = torch.ops.aten.add.Tensor(add_83, wait_tensor_281); wait_tensor_281 = None + convert_element_type_713 = torch.ops.prims.convert_element_type.default(primals_198, torch.bfloat16) + all_gather_into_tensor_238 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_713, 8, '0'); convert_element_type_713 = None + wait_tensor_282 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_238); all_gather_into_tensor_238 = None + convert_element_type_714 = torch.ops.prims.convert_element_type.default(add_85, torch.float32) + pow_44 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_714, 2) + mean_43 = torch.ops.aten.mean.dim(pow_44, [2], True); pow_44 = None + add_86 = torch.ops.aten.add.Scalar(mean_43, 1e-05); mean_43 = None + rsqrt_43 = torch.ops.aten.rsqrt.default(add_86); add_86 = None + mul_172 = torch.ops.aten.mul.Tensor(convert_element_type_714, rsqrt_43); convert_element_type_714 = rsqrt_43 = None + mul_173 = 
torch.ops.aten.mul.Tensor(mul_172, wait_tensor_282); mul_172 = wait_tensor_282 = None + convert_element_type_715 = torch.ops.prims.convert_element_type.default(mul_173, torch.bfloat16); mul_173 = None + all_gather_into_tensor_239 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_715, 8, '1'); convert_element_type_715 = None + wait_tensor_283 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_239); all_gather_into_tensor_239 = None + split_95 = torch.ops.aten.split.Tensor(wait_tensor_283, 2); wait_tensor_283 = None + getitem_958 = split_95[0] + getitem_959 = split_95[1] + getitem_960 = split_95[2] + getitem_961 = split_95[3] + getitem_962 = split_95[4] + getitem_963 = split_95[5] + getitem_964 = split_95[6] + getitem_965 = split_95[7]; split_95 = None + cat_87 = torch.ops.aten.cat.default([getitem_958, getitem_959, getitem_960, getitem_961, getitem_962, getitem_963, getitem_964, getitem_965], 1); getitem_958 = getitem_959 = getitem_960 = getitem_961 = getitem_962 = getitem_963 = getitem_964 = getitem_965 = None + convert_element_type_716 = torch.ops.prims.convert_element_type.default(primals_199, torch.bfloat16) + all_gather_into_tensor_240 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_716, 8, '0'); convert_element_type_716 = None + wait_tensor_284 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_240); all_gather_into_tensor_240 = None + permute_239 = torch.ops.aten.permute.default(wait_tensor_284, [1, 0]); wait_tensor_284 = None + view_1572 = torch.ops.aten.view.default(cat_87, [16384, 4096]); cat_87 = None + mm_151 = torch.ops.aten.mm.default(view_1572, permute_239); permute_239 = None + view_1573 = torch.ops.aten.view.default(mm_151, [2, 8192, 1792]) + convert_element_type_719 = torch.ops.prims.convert_element_type.default(view_1573, torch.float32); view_1573 = None + sigmoid_21 = torch.ops.aten.sigmoid.default(convert_element_type_719) + mul_174 = 
torch.ops.aten.mul.Tensor(convert_element_type_719, sigmoid_21); convert_element_type_719 = sigmoid_21 = None + convert_element_type_720 = torch.ops.prims.convert_element_type.default(mul_174, torch.bfloat16); mul_174 = None + convert_element_type_721 = torch.ops.prims.convert_element_type.default(primals_200, torch.bfloat16) + all_gather_into_tensor_241 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_721, 8, '0'); convert_element_type_721 = None + wait_tensor_285 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_241); all_gather_into_tensor_241 = None + permute_240 = torch.ops.aten.permute.default(wait_tensor_285, [1, 0]); wait_tensor_285 = None + mm_152 = torch.ops.aten.mm.default(view_1572, permute_240); view_1572 = permute_240 = None + view_1580 = torch.ops.aten.view.default(mm_152, [2, 8192, 1792]); mm_152 = None + mul_175 = torch.ops.aten.mul.Tensor(convert_element_type_720, view_1580); convert_element_type_720 = view_1580 = None + convert_element_type_724 = torch.ops.prims.convert_element_type.default(primals_201, torch.bfloat16) + all_gather_into_tensor_242 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_724, 8, '0'); convert_element_type_724 = None + wait_tensor_286 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_242); all_gather_into_tensor_242 = None + permute_241 = torch.ops.aten.permute.default(wait_tensor_286, [1, 0]); wait_tensor_286 = None + view_1587 = torch.ops.aten.view.default(mul_175, [16384, 1792]); mul_175 = None + mm_153 = torch.ops.aten.mm.default(view_1587, permute_241); view_1587 = permute_241 = None + view_1588 = torch.ops.aten.view.default(mm_153, [2, 8192, 4096]); mm_153 = None + split_96 = torch.ops.aten.split.Tensor(view_1588, 1024, 1); view_1588 = None + getitem_966 = split_96[0] + getitem_967 = split_96[1] + getitem_968 = split_96[2] + getitem_969 = split_96[3] + getitem_970 = split_96[4] + getitem_971 = split_96[5] + 
getitem_972 = split_96[6] + getitem_973 = split_96[7]; split_96 = None + cat_88 = torch.ops.aten.cat.default([getitem_966, getitem_967, getitem_968, getitem_969, getitem_970, getitem_971, getitem_972, getitem_973]); getitem_966 = getitem_967 = getitem_968 = getitem_969 = getitem_970 = getitem_971 = getitem_972 = getitem_973 = None + reduce_scatter_tensor_44 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_88, 'sum', 8, '1'); cat_88 = None + wait_tensor_287 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_44); reduce_scatter_tensor_44 = None + add_87 = torch.ops.aten.add.Tensor(add_85, wait_tensor_287); add_85 = wait_tensor_287 = None + convert_element_type_727 = torch.ops.prims.convert_element_type.default(primals_202, torch.bfloat16) + all_gather_into_tensor_243 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_727, 8, '0'); convert_element_type_727 = None + wait_tensor_288 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_243); all_gather_into_tensor_243 = None + convert_element_type_728 = torch.ops.prims.convert_element_type.default(add_87, torch.float32) + pow_45 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_728, 2) + mean_44 = torch.ops.aten.mean.dim(pow_45, [2], True); pow_45 = None + add_88 = torch.ops.aten.add.Scalar(mean_44, 1e-05); mean_44 = None + rsqrt_44 = torch.ops.aten.rsqrt.default(add_88); add_88 = None + mul_176 = torch.ops.aten.mul.Tensor(convert_element_type_728, rsqrt_44); convert_element_type_728 = rsqrt_44 = None + mul_177 = torch.ops.aten.mul.Tensor(mul_176, wait_tensor_288); mul_176 = wait_tensor_288 = None + convert_element_type_729 = torch.ops.prims.convert_element_type.default(mul_177, torch.bfloat16); mul_177 = None + all_gather_into_tensor_244 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_729, 8, '1'); convert_element_type_729 = None + wait_tensor_289 = 
torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_244); all_gather_into_tensor_244 = None + split_97 = torch.ops.aten.split.Tensor(wait_tensor_289, 2); wait_tensor_289 = None + getitem_974 = split_97[0] + getitem_975 = split_97[1] + getitem_976 = split_97[2] + getitem_977 = split_97[3] + getitem_978 = split_97[4] + getitem_979 = split_97[5] + getitem_980 = split_97[6] + getitem_981 = split_97[7]; split_97 = None + cat_89 = torch.ops.aten.cat.default([getitem_974, getitem_975, getitem_976, getitem_977, getitem_978, getitem_979, getitem_980, getitem_981], 1); getitem_974 = getitem_975 = getitem_976 = getitem_977 = getitem_978 = getitem_979 = getitem_980 = getitem_981 = None + convert_element_type_730 = torch.ops.prims.convert_element_type.default(primals_203, torch.bfloat16) + all_gather_into_tensor_245 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_730, 8, '0'); convert_element_type_730 = None + wait_tensor_290 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_245); all_gather_into_tensor_245 = None + permute_242 = torch.ops.aten.permute.default(wait_tensor_290, [1, 0]); wait_tensor_290 = None + view_1599 = torch.ops.aten.view.default(cat_89, [16384, 4096]); cat_89 = None + mm_154 = torch.ops.aten.mm.default(view_1599, permute_242); permute_242 = None + view_1600 = torch.ops.aten.view.default(mm_154, [2, 8192, 512]) + convert_element_type_733 = torch.ops.prims.convert_element_type.default(primals_204, torch.bfloat16) + all_gather_into_tensor_246 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_733, 8, '0'); convert_element_type_733 = None + wait_tensor_291 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_246); all_gather_into_tensor_246 = None + permute_243 = torch.ops.aten.permute.default(wait_tensor_291, [1, 0]); wait_tensor_291 = None + mm_155 = torch.ops.aten.mm.default(view_1599, permute_243); permute_243 = None + view_1607 = 
torch.ops.aten.view.default(mm_155, [2, 8192, 128]); mm_155 = None + convert_element_type_736 = torch.ops.prims.convert_element_type.default(primals_205, torch.bfloat16) + all_gather_into_tensor_247 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_736, 8, '0'); convert_element_type_736 = None + wait_tensor_292 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_247); all_gather_into_tensor_247 = None + permute_244 = torch.ops.aten.permute.default(wait_tensor_292, [1, 0]); wait_tensor_292 = None + mm_156 = torch.ops.aten.mm.default(view_1599, permute_244); view_1599 = permute_244 = None + view_1614 = torch.ops.aten.view.default(mm_156, [2, 8192, 128]) + view_1616 = torch.ops.aten.view.default(view_1600, [2, 8192, -1, 128]); view_1600 = None + view_1617 = torch.ops.aten.view.default(view_1607, [2, 8192, -1, 128]); view_1607 = None + view_1618 = torch.ops.aten.view.default(view_1614, [2, 8192, -1, 128]); view_1614 = None + convert_element_type_739 = torch.ops.prims.convert_element_type.default(view_1616, torch.float32); view_1616 = None + view_1619 = torch.ops.aten.view.default(convert_element_type_739, [2, 8192, 4, -1, 2]); convert_element_type_739 = None + view_as_complex_44 = torch.ops.aten.view_as_complex.default(view_1619); view_1619 = None + convert_element_type_740 = torch.ops.prims.convert_element_type.default(view_1617, torch.float32); view_1617 = None + view_1620 = torch.ops.aten.view.default(convert_element_type_740, [2, 8192, 1, -1, 2]); convert_element_type_740 = None + view_as_complex_45 = torch.ops.aten.view_as_complex.default(view_1620); view_1620 = None + mul_178 = torch.ops.aten.mul.Tensor(view_as_complex_44, view_37); view_as_complex_44 = None + view_as_real_44 = torch.ops.aten.view_as_real.default(mul_178); mul_178 = None + view_1622 = torch.ops.aten.view.default(view_as_real_44, [2, 8192, 4, 128]); view_as_real_44 = None + mul_179 = torch.ops.aten.mul.Tensor(view_as_complex_45, view_37); 
view_as_complex_45 = None + view_as_real_45 = torch.ops.aten.view_as_real.default(mul_179); mul_179 = None + view_1623 = torch.ops.aten.view.default(view_as_real_45, [2, 8192, 1, 128]); view_as_real_45 = None + convert_element_type_741 = torch.ops.prims.convert_element_type.default(view_1622, torch.bfloat16); view_1622 = None + convert_element_type_742 = torch.ops.prims.convert_element_type.default(view_1623, torch.bfloat16); view_1623 = None + unsqueeze_44 = torch.ops.aten.unsqueeze.default(convert_element_type_742, 3); convert_element_type_742 = None + expand_44 = torch.ops.aten.expand.default(unsqueeze_44, [2, 8192, 1, 4, 128]); unsqueeze_44 = None + view_1624 = torch.ops.aten.view.default(expand_44, [2, 8192, 4, 128]); expand_44 = None + unsqueeze_45 = torch.ops.aten.unsqueeze.default(view_1618, 3); view_1618 = None + expand_45 = torch.ops.aten.expand.default(unsqueeze_45, [2, 8192, 1, 4, 128]); unsqueeze_45 = None + view_1625 = torch.ops.aten.view.default(expand_45, [2, 8192, 4, 128]); expand_45 = None + permute_245 = torch.ops.aten.permute.default(convert_element_type_741, [0, 2, 1, 3]); convert_element_type_741 = None + permute_246 = torch.ops.aten.permute.default(view_1624, [0, 2, 1, 3]); view_1624 = None + permute_247 = torch.ops.aten.permute.default(view_1625, [0, 2, 1, 3]); view_1625 = None + _scaled_dot_product_cudnn_attention_22 = torch.ops.aten._scaled_dot_product_cudnn_attention.default(permute_245, permute_246, permute_247, None, True, 0.0, True); permute_245 = permute_246 = permute_247 = None + getitem_982 = _scaled_dot_product_cudnn_attention_22[0] + getitem_983 = _scaled_dot_product_cudnn_attention_22[1] + getitem_988 = _scaled_dot_product_cudnn_attention_22[6] + getitem_989 = _scaled_dot_product_cudnn_attention_22[7]; _scaled_dot_product_cudnn_attention_22 = None + permute_248 = torch.ops.aten.permute.default(getitem_982, [0, 2, 1, 3]) + view_1626 = torch.ops.aten.view.default(permute_248, [2, 8192, -1]); permute_248 = None + 
convert_element_type_743 = torch.ops.prims.convert_element_type.default(primals_206, torch.bfloat16) + all_gather_into_tensor_248 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_743, 8, '0'); convert_element_type_743 = None + wait_tensor_293 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_248); all_gather_into_tensor_248 = None + permute_249 = torch.ops.aten.permute.default(wait_tensor_293, [1, 0]); wait_tensor_293 = None + view_1632 = torch.ops.aten.view.default(view_1626, [16384, 512]); view_1626 = None + mm_157 = torch.ops.aten.mm.default(view_1632, permute_249); view_1632 = permute_249 = None + view_1633 = torch.ops.aten.view.default(mm_157, [2, 8192, 4096]); mm_157 = None + split_98 = torch.ops.aten.split.Tensor(view_1633, 1024, 1); view_1633 = None + getitem_991 = split_98[0] + getitem_992 = split_98[1] + getitem_993 = split_98[2] + getitem_994 = split_98[3] + getitem_995 = split_98[4] + getitem_996 = split_98[5] + getitem_997 = split_98[6] + getitem_998 = split_98[7]; split_98 = None + cat_90 = torch.ops.aten.cat.default([getitem_991, getitem_992, getitem_993, getitem_994, getitem_995, getitem_996, getitem_997, getitem_998]); getitem_991 = getitem_992 = getitem_993 = getitem_994 = getitem_995 = getitem_996 = getitem_997 = getitem_998 = None + reduce_scatter_tensor_45 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_90, 'sum', 8, '1'); cat_90 = None + wait_tensor_294 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_45) + add_89 = torch.ops.aten.add.Tensor(add_87, wait_tensor_294); wait_tensor_294 = None + convert_element_type_746 = torch.ops.prims.convert_element_type.default(primals_207, torch.bfloat16) + all_gather_into_tensor_249 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_746, 8, '0'); convert_element_type_746 = None + wait_tensor_295 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_249); 
all_gather_into_tensor_249 = None + convert_element_type_747 = torch.ops.prims.convert_element_type.default(add_89, torch.float32) + pow_46 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_747, 2) + mean_45 = torch.ops.aten.mean.dim(pow_46, [2], True); pow_46 = None + add_90 = torch.ops.aten.add.Scalar(mean_45, 1e-05); mean_45 = None + rsqrt_45 = torch.ops.aten.rsqrt.default(add_90); add_90 = None + mul_180 = torch.ops.aten.mul.Tensor(convert_element_type_747, rsqrt_45); convert_element_type_747 = rsqrt_45 = None + mul_181 = torch.ops.aten.mul.Tensor(mul_180, wait_tensor_295); mul_180 = wait_tensor_295 = None + convert_element_type_748 = torch.ops.prims.convert_element_type.default(mul_181, torch.bfloat16); mul_181 = None + all_gather_into_tensor_250 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_748, 8, '1'); convert_element_type_748 = None + wait_tensor_296 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_250); all_gather_into_tensor_250 = None + split_99 = torch.ops.aten.split.Tensor(wait_tensor_296, 2); wait_tensor_296 = None + getitem_999 = split_99[0] + getitem_1000 = split_99[1] + getitem_1001 = split_99[2] + getitem_1002 = split_99[3] + getitem_1003 = split_99[4] + getitem_1004 = split_99[5] + getitem_1005 = split_99[6] + getitem_1006 = split_99[7]; split_99 = None + cat_91 = torch.ops.aten.cat.default([getitem_999, getitem_1000, getitem_1001, getitem_1002, getitem_1003, getitem_1004, getitem_1005, getitem_1006], 1); getitem_999 = getitem_1000 = getitem_1001 = getitem_1002 = getitem_1003 = getitem_1004 = getitem_1005 = getitem_1006 = None + convert_element_type_749 = torch.ops.prims.convert_element_type.default(primals_208, torch.bfloat16) + all_gather_into_tensor_251 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_749, 8, '0'); convert_element_type_749 = None + wait_tensor_297 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_251); 
all_gather_into_tensor_251 = None + permute_250 = torch.ops.aten.permute.default(wait_tensor_297, [1, 0]); wait_tensor_297 = None + view_1644 = torch.ops.aten.view.default(cat_91, [16384, 4096]); cat_91 = None + mm_158 = torch.ops.aten.mm.default(view_1644, permute_250); permute_250 = None + view_1645 = torch.ops.aten.view.default(mm_158, [2, 8192, 1792]) + convert_element_type_752 = torch.ops.prims.convert_element_type.default(view_1645, torch.float32); view_1645 = None + sigmoid_22 = torch.ops.aten.sigmoid.default(convert_element_type_752) + mul_182 = torch.ops.aten.mul.Tensor(convert_element_type_752, sigmoid_22); convert_element_type_752 = sigmoid_22 = None + convert_element_type_753 = torch.ops.prims.convert_element_type.default(mul_182, torch.bfloat16); mul_182 = None + convert_element_type_754 = torch.ops.prims.convert_element_type.default(primals_209, torch.bfloat16) + all_gather_into_tensor_252 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_754, 8, '0'); convert_element_type_754 = None + wait_tensor_298 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_252); all_gather_into_tensor_252 = None + permute_251 = torch.ops.aten.permute.default(wait_tensor_298, [1, 0]); wait_tensor_298 = None + mm_159 = torch.ops.aten.mm.default(view_1644, permute_251); view_1644 = permute_251 = None + view_1652 = torch.ops.aten.view.default(mm_159, [2, 8192, 1792]); mm_159 = None + mul_183 = torch.ops.aten.mul.Tensor(convert_element_type_753, view_1652); convert_element_type_753 = view_1652 = None + convert_element_type_757 = torch.ops.prims.convert_element_type.default(primals_210, torch.bfloat16) + all_gather_into_tensor_253 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_757, 8, '0'); convert_element_type_757 = None + wait_tensor_299 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_253); all_gather_into_tensor_253 = None + permute_252 = 
torch.ops.aten.permute.default(wait_tensor_299, [1, 0]); wait_tensor_299 = None + view_1659 = torch.ops.aten.view.default(mul_183, [16384, 1792]); mul_183 = None + mm_160 = torch.ops.aten.mm.default(view_1659, permute_252); view_1659 = permute_252 = None + view_1660 = torch.ops.aten.view.default(mm_160, [2, 8192, 4096]); mm_160 = None + split_100 = torch.ops.aten.split.Tensor(view_1660, 1024, 1); view_1660 = None + getitem_1007 = split_100[0] + getitem_1008 = split_100[1] + getitem_1009 = split_100[2] + getitem_1010 = split_100[3] + getitem_1011 = split_100[4] + getitem_1012 = split_100[5] + getitem_1013 = split_100[6] + getitem_1014 = split_100[7]; split_100 = None + cat_92 = torch.ops.aten.cat.default([getitem_1007, getitem_1008, getitem_1009, getitem_1010, getitem_1011, getitem_1012, getitem_1013, getitem_1014]); getitem_1007 = getitem_1008 = getitem_1009 = getitem_1010 = getitem_1011 = getitem_1012 = getitem_1013 = getitem_1014 = None + reduce_scatter_tensor_46 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_92, 'sum', 8, '1'); cat_92 = None + wait_tensor_300 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_46); reduce_scatter_tensor_46 = None + add_91 = torch.ops.aten.add.Tensor(add_89, wait_tensor_300); add_89 = wait_tensor_300 = None + convert_element_type_760 = torch.ops.prims.convert_element_type.default(primals_211, torch.bfloat16) + all_gather_into_tensor_254 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_760, 8, '0'); convert_element_type_760 = None + wait_tensor_301 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_254); all_gather_into_tensor_254 = None + convert_element_type_761 = torch.ops.prims.convert_element_type.default(add_91, torch.float32) + pow_47 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_761, 2) + mean_46 = torch.ops.aten.mean.dim(pow_47, [2], True); pow_47 = None + add_92 = torch.ops.aten.add.Scalar(mean_46, 1e-05); mean_46 = None + 
rsqrt_46 = torch.ops.aten.rsqrt.default(add_92); add_92 = None + mul_184 = torch.ops.aten.mul.Tensor(convert_element_type_761, rsqrt_46); convert_element_type_761 = rsqrt_46 = None + mul_185 = torch.ops.aten.mul.Tensor(mul_184, wait_tensor_301); mul_184 = wait_tensor_301 = None + convert_element_type_762 = torch.ops.prims.convert_element_type.default(mul_185, torch.bfloat16); mul_185 = None + all_gather_into_tensor_255 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_762, 8, '1'); convert_element_type_762 = None + wait_tensor_302 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_255); all_gather_into_tensor_255 = None + split_101 = torch.ops.aten.split.Tensor(wait_tensor_302, 2); wait_tensor_302 = None + getitem_1015 = split_101[0] + getitem_1016 = split_101[1] + getitem_1017 = split_101[2] + getitem_1018 = split_101[3] + getitem_1019 = split_101[4] + getitem_1020 = split_101[5] + getitem_1021 = split_101[6] + getitem_1022 = split_101[7]; split_101 = None + cat_93 = torch.ops.aten.cat.default([getitem_1015, getitem_1016, getitem_1017, getitem_1018, getitem_1019, getitem_1020, getitem_1021, getitem_1022], 1); getitem_1015 = getitem_1016 = getitem_1017 = getitem_1018 = getitem_1019 = getitem_1020 = getitem_1021 = getitem_1022 = None + convert_element_type_763 = torch.ops.prims.convert_element_type.default(primals_212, torch.bfloat16) + all_gather_into_tensor_256 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_763, 8, '0'); convert_element_type_763 = None + wait_tensor_303 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_256); all_gather_into_tensor_256 = None + permute_253 = torch.ops.aten.permute.default(wait_tensor_303, [1, 0]); wait_tensor_303 = None + view_1671 = torch.ops.aten.view.default(cat_93, [16384, 4096]); cat_93 = None + mm_161 = torch.ops.aten.mm.default(view_1671, permute_253); permute_253 = None + view_1672 = 
torch.ops.aten.view.default(mm_161, [2, 8192, 512]) + convert_element_type_766 = torch.ops.prims.convert_element_type.default(primals_213, torch.bfloat16) + all_gather_into_tensor_257 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_766, 8, '0'); convert_element_type_766 = None + wait_tensor_304 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_257); all_gather_into_tensor_257 = None + permute_254 = torch.ops.aten.permute.default(wait_tensor_304, [1, 0]); wait_tensor_304 = None + mm_162 = torch.ops.aten.mm.default(view_1671, permute_254); permute_254 = None + view_1679 = torch.ops.aten.view.default(mm_162, [2, 8192, 128]); mm_162 = None + convert_element_type_769 = torch.ops.prims.convert_element_type.default(primals_214, torch.bfloat16) + all_gather_into_tensor_258 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_769, 8, '0'); convert_element_type_769 = None + wait_tensor_305 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_258); all_gather_into_tensor_258 = None + permute_255 = torch.ops.aten.permute.default(wait_tensor_305, [1, 0]); wait_tensor_305 = None + mm_163 = torch.ops.aten.mm.default(view_1671, permute_255); view_1671 = permute_255 = None + view_1686 = torch.ops.aten.view.default(mm_163, [2, 8192, 128]) + view_1688 = torch.ops.aten.view.default(view_1672, [2, 8192, -1, 128]); view_1672 = None + view_1689 = torch.ops.aten.view.default(view_1679, [2, 8192, -1, 128]); view_1679 = None + view_1690 = torch.ops.aten.view.default(view_1686, [2, 8192, -1, 128]); view_1686 = None + convert_element_type_772 = torch.ops.prims.convert_element_type.default(view_1688, torch.float32); view_1688 = None + view_1691 = torch.ops.aten.view.default(convert_element_type_772, [2, 8192, 4, -1, 2]); convert_element_type_772 = None + view_as_complex_46 = torch.ops.aten.view_as_complex.default(view_1691); view_1691 = None + convert_element_type_773 = 
torch.ops.prims.convert_element_type.default(view_1689, torch.float32); view_1689 = None + view_1692 = torch.ops.aten.view.default(convert_element_type_773, [2, 8192, 1, -1, 2]); convert_element_type_773 = None + view_as_complex_47 = torch.ops.aten.view_as_complex.default(view_1692); view_1692 = None + mul_186 = torch.ops.aten.mul.Tensor(view_as_complex_46, view_37); view_as_complex_46 = None + view_as_real_46 = torch.ops.aten.view_as_real.default(mul_186); mul_186 = None + view_1694 = torch.ops.aten.view.default(view_as_real_46, [2, 8192, 4, 128]); view_as_real_46 = None + mul_187 = torch.ops.aten.mul.Tensor(view_as_complex_47, view_37); view_as_complex_47 = None + view_as_real_47 = torch.ops.aten.view_as_real.default(mul_187); mul_187 = None + view_1695 = torch.ops.aten.view.default(view_as_real_47, [2, 8192, 1, 128]); view_as_real_47 = None + convert_element_type_774 = torch.ops.prims.convert_element_type.default(view_1694, torch.bfloat16); view_1694 = None + convert_element_type_775 = torch.ops.prims.convert_element_type.default(view_1695, torch.bfloat16); view_1695 = None + unsqueeze_46 = torch.ops.aten.unsqueeze.default(convert_element_type_775, 3); convert_element_type_775 = None + expand_46 = torch.ops.aten.expand.default(unsqueeze_46, [2, 8192, 1, 4, 128]); unsqueeze_46 = None + view_1696 = torch.ops.aten.view.default(expand_46, [2, 8192, 4, 128]); expand_46 = None + unsqueeze_47 = torch.ops.aten.unsqueeze.default(view_1690, 3); view_1690 = None + expand_47 = torch.ops.aten.expand.default(unsqueeze_47, [2, 8192, 1, 4, 128]); unsqueeze_47 = None + view_1697 = torch.ops.aten.view.default(expand_47, [2, 8192, 4, 128]); expand_47 = None + permute_256 = torch.ops.aten.permute.default(convert_element_type_774, [0, 2, 1, 3]); convert_element_type_774 = None + permute_257 = torch.ops.aten.permute.default(view_1696, [0, 2, 1, 3]); view_1696 = None + permute_258 = torch.ops.aten.permute.default(view_1697, [0, 2, 1, 3]); view_1697 = None + 
_scaled_dot_product_cudnn_attention_23 = torch.ops.aten._scaled_dot_product_cudnn_attention.default(permute_256, permute_257, permute_258, None, True, 0.0, True); permute_256 = permute_257 = permute_258 = None + getitem_1023 = _scaled_dot_product_cudnn_attention_23[0] + getitem_1024 = _scaled_dot_product_cudnn_attention_23[1] + getitem_1029 = _scaled_dot_product_cudnn_attention_23[6] + getitem_1030 = _scaled_dot_product_cudnn_attention_23[7]; _scaled_dot_product_cudnn_attention_23 = None + permute_259 = torch.ops.aten.permute.default(getitem_1023, [0, 2, 1, 3]) + view_1698 = torch.ops.aten.view.default(permute_259, [2, 8192, -1]); permute_259 = None + convert_element_type_776 = torch.ops.prims.convert_element_type.default(primals_215, torch.bfloat16) + all_gather_into_tensor_259 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_776, 8, '0'); convert_element_type_776 = None + wait_tensor_306 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_259); all_gather_into_tensor_259 = None + permute_260 = torch.ops.aten.permute.default(wait_tensor_306, [1, 0]); wait_tensor_306 = None + view_1704 = torch.ops.aten.view.default(view_1698, [16384, 512]); view_1698 = None + mm_164 = torch.ops.aten.mm.default(view_1704, permute_260); view_1704 = permute_260 = None + view_1705 = torch.ops.aten.view.default(mm_164, [2, 8192, 4096]); mm_164 = None + split_102 = torch.ops.aten.split.Tensor(view_1705, 1024, 1); view_1705 = None + getitem_1032 = split_102[0] + getitem_1033 = split_102[1] + getitem_1034 = split_102[2] + getitem_1035 = split_102[3] + getitem_1036 = split_102[4] + getitem_1037 = split_102[5] + getitem_1038 = split_102[6] + getitem_1039 = split_102[7]; split_102 = None + cat_94 = torch.ops.aten.cat.default([getitem_1032, getitem_1033, getitem_1034, getitem_1035, getitem_1036, getitem_1037, getitem_1038, getitem_1039]); getitem_1032 = getitem_1033 = getitem_1034 = getitem_1035 = getitem_1036 = getitem_1037 = getitem_1038 = 
getitem_1039 = None + reduce_scatter_tensor_47 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_94, 'sum', 8, '1'); cat_94 = None + wait_tensor_307 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_47) + add_93 = torch.ops.aten.add.Tensor(add_91, wait_tensor_307); wait_tensor_307 = None + convert_element_type_779 = torch.ops.prims.convert_element_type.default(primals_216, torch.bfloat16) + all_gather_into_tensor_260 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_779, 8, '0'); convert_element_type_779 = None + wait_tensor_308 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_260); all_gather_into_tensor_260 = None + convert_element_type_780 = torch.ops.prims.convert_element_type.default(add_93, torch.float32) + pow_48 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_780, 2) + mean_47 = torch.ops.aten.mean.dim(pow_48, [2], True); pow_48 = None + add_94 = torch.ops.aten.add.Scalar(mean_47, 1e-05); mean_47 = None + rsqrt_47 = torch.ops.aten.rsqrt.default(add_94); add_94 = None + mul_188 = torch.ops.aten.mul.Tensor(convert_element_type_780, rsqrt_47); convert_element_type_780 = rsqrt_47 = None + mul_189 = torch.ops.aten.mul.Tensor(mul_188, wait_tensor_308); mul_188 = wait_tensor_308 = None + convert_element_type_781 = torch.ops.prims.convert_element_type.default(mul_189, torch.bfloat16); mul_189 = None + all_gather_into_tensor_261 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_781, 8, '1'); convert_element_type_781 = None + wait_tensor_309 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_261); all_gather_into_tensor_261 = None + split_103 = torch.ops.aten.split.Tensor(wait_tensor_309, 2); wait_tensor_309 = None + getitem_1040 = split_103[0] + getitem_1041 = split_103[1] + getitem_1042 = split_103[2] + getitem_1043 = split_103[3] + getitem_1044 = split_103[4] + getitem_1045 = split_103[5] + getitem_1046 = split_103[6] 
+ getitem_1047 = split_103[7]; split_103 = None + cat_95 = torch.ops.aten.cat.default([getitem_1040, getitem_1041, getitem_1042, getitem_1043, getitem_1044, getitem_1045, getitem_1046, getitem_1047], 1); getitem_1040 = getitem_1041 = getitem_1042 = getitem_1043 = getitem_1044 = getitem_1045 = getitem_1046 = getitem_1047 = None + convert_element_type_782 = torch.ops.prims.convert_element_type.default(primals_217, torch.bfloat16) + all_gather_into_tensor_262 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_782, 8, '0'); convert_element_type_782 = None + wait_tensor_310 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_262); all_gather_into_tensor_262 = None + permute_261 = torch.ops.aten.permute.default(wait_tensor_310, [1, 0]); wait_tensor_310 = None + view_1716 = torch.ops.aten.view.default(cat_95, [16384, 4096]); cat_95 = None + mm_165 = torch.ops.aten.mm.default(view_1716, permute_261); permute_261 = None + view_1717 = torch.ops.aten.view.default(mm_165, [2, 8192, 1792]) + convert_element_type_785 = torch.ops.prims.convert_element_type.default(view_1717, torch.float32); view_1717 = None + sigmoid_23 = torch.ops.aten.sigmoid.default(convert_element_type_785) + mul_190 = torch.ops.aten.mul.Tensor(convert_element_type_785, sigmoid_23); convert_element_type_785 = sigmoid_23 = None + convert_element_type_786 = torch.ops.prims.convert_element_type.default(mul_190, torch.bfloat16); mul_190 = None + convert_element_type_787 = torch.ops.prims.convert_element_type.default(primals_218, torch.bfloat16) + all_gather_into_tensor_263 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_787, 8, '0'); convert_element_type_787 = None + wait_tensor_311 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_263); all_gather_into_tensor_263 = None + permute_262 = torch.ops.aten.permute.default(wait_tensor_311, [1, 0]); wait_tensor_311 = None + mm_166 = 
torch.ops.aten.mm.default(view_1716, permute_262); view_1716 = permute_262 = None + view_1724 = torch.ops.aten.view.default(mm_166, [2, 8192, 1792]); mm_166 = None + mul_191 = torch.ops.aten.mul.Tensor(convert_element_type_786, view_1724); convert_element_type_786 = view_1724 = None + convert_element_type_790 = torch.ops.prims.convert_element_type.default(primals_219, torch.bfloat16) + all_gather_into_tensor_264 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_790, 8, '0'); convert_element_type_790 = None + wait_tensor_312 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_264); all_gather_into_tensor_264 = None + permute_263 = torch.ops.aten.permute.default(wait_tensor_312, [1, 0]); wait_tensor_312 = None + view_1731 = torch.ops.aten.view.default(mul_191, [16384, 1792]); mul_191 = None + mm_167 = torch.ops.aten.mm.default(view_1731, permute_263); view_1731 = permute_263 = None + view_1732 = torch.ops.aten.view.default(mm_167, [2, 8192, 4096]); mm_167 = None + split_104 = torch.ops.aten.split.Tensor(view_1732, 1024, 1); view_1732 = None + getitem_1048 = split_104[0] + getitem_1049 = split_104[1] + getitem_1050 = split_104[2] + getitem_1051 = split_104[3] + getitem_1052 = split_104[4] + getitem_1053 = split_104[5] + getitem_1054 = split_104[6] + getitem_1055 = split_104[7]; split_104 = None + cat_96 = torch.ops.aten.cat.default([getitem_1048, getitem_1049, getitem_1050, getitem_1051, getitem_1052, getitem_1053, getitem_1054, getitem_1055]); getitem_1048 = getitem_1049 = getitem_1050 = getitem_1051 = getitem_1052 = getitem_1053 = getitem_1054 = getitem_1055 = None + reduce_scatter_tensor_48 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_96, 'sum', 8, '1'); cat_96 = None + wait_tensor_313 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_48); reduce_scatter_tensor_48 = None + add_95 = torch.ops.aten.add.Tensor(add_93, wait_tensor_313); add_93 = wait_tensor_313 = None + 
convert_element_type_793 = torch.ops.prims.convert_element_type.default(primals_220, torch.bfloat16) + all_gather_into_tensor_265 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_793, 8, '0'); convert_element_type_793 = None + wait_tensor_314 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_265); all_gather_into_tensor_265 = None + convert_element_type_794 = torch.ops.prims.convert_element_type.default(add_95, torch.float32) + pow_49 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_794, 2) + mean_48 = torch.ops.aten.mean.dim(pow_49, [2], True); pow_49 = None + add_96 = torch.ops.aten.add.Scalar(mean_48, 1e-05); mean_48 = None + rsqrt_48 = torch.ops.aten.rsqrt.default(add_96); add_96 = None + mul_192 = torch.ops.aten.mul.Tensor(convert_element_type_794, rsqrt_48); convert_element_type_794 = rsqrt_48 = None + mul_193 = torch.ops.aten.mul.Tensor(mul_192, wait_tensor_314); mul_192 = wait_tensor_314 = None + convert_element_type_795 = torch.ops.prims.convert_element_type.default(mul_193, torch.bfloat16); mul_193 = None + all_gather_into_tensor_266 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_795, 8, '1'); convert_element_type_795 = None + wait_tensor_315 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_266); all_gather_into_tensor_266 = None + split_105 = torch.ops.aten.split.Tensor(wait_tensor_315, 2); wait_tensor_315 = None + getitem_1056 = split_105[0] + getitem_1057 = split_105[1] + getitem_1058 = split_105[2] + getitem_1059 = split_105[3] + getitem_1060 = split_105[4] + getitem_1061 = split_105[5] + getitem_1062 = split_105[6] + getitem_1063 = split_105[7]; split_105 = None + cat_97 = torch.ops.aten.cat.default([getitem_1056, getitem_1057, getitem_1058, getitem_1059, getitem_1060, getitem_1061, getitem_1062, getitem_1063], 1); getitem_1056 = getitem_1057 = getitem_1058 = getitem_1059 = getitem_1060 = getitem_1061 = getitem_1062 = getitem_1063 = 
None + convert_element_type_796 = torch.ops.prims.convert_element_type.default(primals_221, torch.bfloat16) + all_gather_into_tensor_267 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_796, 8, '0'); convert_element_type_796 = None + wait_tensor_316 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_267); all_gather_into_tensor_267 = None + permute_264 = torch.ops.aten.permute.default(wait_tensor_316, [1, 0]); wait_tensor_316 = None + view_1743 = torch.ops.aten.view.default(cat_97, [16384, 4096]); cat_97 = None + mm_168 = torch.ops.aten.mm.default(view_1743, permute_264); permute_264 = None + view_1744 = torch.ops.aten.view.default(mm_168, [2, 8192, 512]) + convert_element_type_799 = torch.ops.prims.convert_element_type.default(primals_222, torch.bfloat16) + all_gather_into_tensor_268 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_799, 8, '0'); convert_element_type_799 = None + wait_tensor_317 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_268); all_gather_into_tensor_268 = None + permute_265 = torch.ops.aten.permute.default(wait_tensor_317, [1, 0]); wait_tensor_317 = None + mm_169 = torch.ops.aten.mm.default(view_1743, permute_265); permute_265 = None + view_1751 = torch.ops.aten.view.default(mm_169, [2, 8192, 128]); mm_169 = None + convert_element_type_802 = torch.ops.prims.convert_element_type.default(primals_223, torch.bfloat16) + all_gather_into_tensor_269 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_802, 8, '0'); convert_element_type_802 = None + wait_tensor_318 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_269); all_gather_into_tensor_269 = None + permute_266 = torch.ops.aten.permute.default(wait_tensor_318, [1, 0]); wait_tensor_318 = None + mm_170 = torch.ops.aten.mm.default(view_1743, permute_266); view_1743 = permute_266 = None + view_1758 = torch.ops.aten.view.default(mm_170, [2, 
8192, 128]) + view_1760 = torch.ops.aten.view.default(view_1744, [2, 8192, -1, 128]); view_1744 = None + view_1761 = torch.ops.aten.view.default(view_1751, [2, 8192, -1, 128]); view_1751 = None + view_1762 = torch.ops.aten.view.default(view_1758, [2, 8192, -1, 128]); view_1758 = None + convert_element_type_805 = torch.ops.prims.convert_element_type.default(view_1760, torch.float32); view_1760 = None + view_1763 = torch.ops.aten.view.default(convert_element_type_805, [2, 8192, 4, -1, 2]); convert_element_type_805 = None + view_as_complex_48 = torch.ops.aten.view_as_complex.default(view_1763); view_1763 = None + convert_element_type_806 = torch.ops.prims.convert_element_type.default(view_1761, torch.float32); view_1761 = None + view_1764 = torch.ops.aten.view.default(convert_element_type_806, [2, 8192, 1, -1, 2]); convert_element_type_806 = None + view_as_complex_49 = torch.ops.aten.view_as_complex.default(view_1764); view_1764 = None + mul_194 = torch.ops.aten.mul.Tensor(view_as_complex_48, view_37); view_as_complex_48 = None + view_as_real_48 = torch.ops.aten.view_as_real.default(mul_194); mul_194 = None + view_1766 = torch.ops.aten.view.default(view_as_real_48, [2, 8192, 4, 128]); view_as_real_48 = None + mul_195 = torch.ops.aten.mul.Tensor(view_as_complex_49, view_37); view_as_complex_49 = None + view_as_real_49 = torch.ops.aten.view_as_real.default(mul_195); mul_195 = None + view_1767 = torch.ops.aten.view.default(view_as_real_49, [2, 8192, 1, 128]); view_as_real_49 = None + convert_element_type_807 = torch.ops.prims.convert_element_type.default(view_1766, torch.bfloat16); view_1766 = None + convert_element_type_808 = torch.ops.prims.convert_element_type.default(view_1767, torch.bfloat16); view_1767 = None + unsqueeze_48 = torch.ops.aten.unsqueeze.default(convert_element_type_808, 3); convert_element_type_808 = None + expand_48 = torch.ops.aten.expand.default(unsqueeze_48, [2, 8192, 1, 4, 128]); unsqueeze_48 = None + view_1768 = 
torch.ops.aten.view.default(expand_48, [2, 8192, 4, 128]); expand_48 = None + unsqueeze_49 = torch.ops.aten.unsqueeze.default(view_1762, 3); view_1762 = None + expand_49 = torch.ops.aten.expand.default(unsqueeze_49, [2, 8192, 1, 4, 128]); unsqueeze_49 = None + view_1769 = torch.ops.aten.view.default(expand_49, [2, 8192, 4, 128]); expand_49 = None + permute_267 = torch.ops.aten.permute.default(convert_element_type_807, [0, 2, 1, 3]); convert_element_type_807 = None + permute_268 = torch.ops.aten.permute.default(view_1768, [0, 2, 1, 3]); view_1768 = None + permute_269 = torch.ops.aten.permute.default(view_1769, [0, 2, 1, 3]); view_1769 = None + _scaled_dot_product_cudnn_attention_24 = torch.ops.aten._scaled_dot_product_cudnn_attention.default(permute_267, permute_268, permute_269, None, True, 0.0, True); permute_267 = permute_268 = permute_269 = None + getitem_1064 = _scaled_dot_product_cudnn_attention_24[0] + getitem_1065 = _scaled_dot_product_cudnn_attention_24[1] + getitem_1070 = _scaled_dot_product_cudnn_attention_24[6] + getitem_1071 = _scaled_dot_product_cudnn_attention_24[7]; _scaled_dot_product_cudnn_attention_24 = None + permute_270 = torch.ops.aten.permute.default(getitem_1064, [0, 2, 1, 3]) + view_1770 = torch.ops.aten.view.default(permute_270, [2, 8192, -1]); permute_270 = None + convert_element_type_809 = torch.ops.prims.convert_element_type.default(primals_224, torch.bfloat16) + all_gather_into_tensor_270 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_809, 8, '0'); convert_element_type_809 = None + wait_tensor_319 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_270); all_gather_into_tensor_270 = None + permute_271 = torch.ops.aten.permute.default(wait_tensor_319, [1, 0]); wait_tensor_319 = None + view_1776 = torch.ops.aten.view.default(view_1770, [16384, 512]); view_1770 = None + mm_171 = torch.ops.aten.mm.default(view_1776, permute_271); view_1776 = permute_271 = None + view_1777 = 
torch.ops.aten.view.default(mm_171, [2, 8192, 4096]); mm_171 = None + split_106 = torch.ops.aten.split.Tensor(view_1777, 1024, 1); view_1777 = None + getitem_1073 = split_106[0] + getitem_1074 = split_106[1] + getitem_1075 = split_106[2] + getitem_1076 = split_106[3] + getitem_1077 = split_106[4] + getitem_1078 = split_106[5] + getitem_1079 = split_106[6] + getitem_1080 = split_106[7]; split_106 = None + cat_98 = torch.ops.aten.cat.default([getitem_1073, getitem_1074, getitem_1075, getitem_1076, getitem_1077, getitem_1078, getitem_1079, getitem_1080]); getitem_1073 = getitem_1074 = getitem_1075 = getitem_1076 = getitem_1077 = getitem_1078 = getitem_1079 = getitem_1080 = None + reduce_scatter_tensor_49 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_98, 'sum', 8, '1'); cat_98 = None + wait_tensor_320 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_49) + add_97 = torch.ops.aten.add.Tensor(add_95, wait_tensor_320); wait_tensor_320 = None + convert_element_type_812 = torch.ops.prims.convert_element_type.default(primals_225, torch.bfloat16) + all_gather_into_tensor_271 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_812, 8, '0'); convert_element_type_812 = None + wait_tensor_321 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_271); all_gather_into_tensor_271 = None + convert_element_type_813 = torch.ops.prims.convert_element_type.default(add_97, torch.float32) + pow_50 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_813, 2) + mean_49 = torch.ops.aten.mean.dim(pow_50, [2], True); pow_50 = None + add_98 = torch.ops.aten.add.Scalar(mean_49, 1e-05); mean_49 = None + rsqrt_49 = torch.ops.aten.rsqrt.default(add_98); add_98 = None + mul_196 = torch.ops.aten.mul.Tensor(convert_element_type_813, rsqrt_49); convert_element_type_813 = rsqrt_49 = None + mul_197 = torch.ops.aten.mul.Tensor(mul_196, wait_tensor_321); mul_196 = wait_tensor_321 = None + convert_element_type_814 = 
torch.ops.prims.convert_element_type.default(mul_197, torch.bfloat16); mul_197 = None + all_gather_into_tensor_272 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_814, 8, '1'); convert_element_type_814 = None + wait_tensor_322 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_272); all_gather_into_tensor_272 = None + split_107 = torch.ops.aten.split.Tensor(wait_tensor_322, 2); wait_tensor_322 = None + getitem_1081 = split_107[0] + getitem_1082 = split_107[1] + getitem_1083 = split_107[2] + getitem_1084 = split_107[3] + getitem_1085 = split_107[4] + getitem_1086 = split_107[5] + getitem_1087 = split_107[6] + getitem_1088 = split_107[7]; split_107 = None + cat_99 = torch.ops.aten.cat.default([getitem_1081, getitem_1082, getitem_1083, getitem_1084, getitem_1085, getitem_1086, getitem_1087, getitem_1088], 1); getitem_1081 = getitem_1082 = getitem_1083 = getitem_1084 = getitem_1085 = getitem_1086 = getitem_1087 = getitem_1088 = None + convert_element_type_815 = torch.ops.prims.convert_element_type.default(primals_226, torch.bfloat16) + all_gather_into_tensor_273 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_815, 8, '0'); convert_element_type_815 = None + wait_tensor_323 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_273); all_gather_into_tensor_273 = None + permute_272 = torch.ops.aten.permute.default(wait_tensor_323, [1, 0]); wait_tensor_323 = None + view_1788 = torch.ops.aten.view.default(cat_99, [16384, 4096]); cat_99 = None + mm_172 = torch.ops.aten.mm.default(view_1788, permute_272); permute_272 = None + view_1789 = torch.ops.aten.view.default(mm_172, [2, 8192, 1792]) + convert_element_type_818 = torch.ops.prims.convert_element_type.default(view_1789, torch.float32); view_1789 = None + sigmoid_24 = torch.ops.aten.sigmoid.default(convert_element_type_818) + mul_198 = torch.ops.aten.mul.Tensor(convert_element_type_818, sigmoid_24); 
convert_element_type_818 = sigmoid_24 = None + convert_element_type_819 = torch.ops.prims.convert_element_type.default(mul_198, torch.bfloat16); mul_198 = None + convert_element_type_820 = torch.ops.prims.convert_element_type.default(primals_227, torch.bfloat16) + all_gather_into_tensor_274 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_820, 8, '0'); convert_element_type_820 = None + wait_tensor_324 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_274); all_gather_into_tensor_274 = None + permute_273 = torch.ops.aten.permute.default(wait_tensor_324, [1, 0]); wait_tensor_324 = None + mm_173 = torch.ops.aten.mm.default(view_1788, permute_273); view_1788 = permute_273 = None + view_1796 = torch.ops.aten.view.default(mm_173, [2, 8192, 1792]); mm_173 = None + mul_199 = torch.ops.aten.mul.Tensor(convert_element_type_819, view_1796); convert_element_type_819 = view_1796 = None + convert_element_type_823 = torch.ops.prims.convert_element_type.default(primals_228, torch.bfloat16) + all_gather_into_tensor_275 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_823, 8, '0'); convert_element_type_823 = None + wait_tensor_325 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_275); all_gather_into_tensor_275 = None + permute_274 = torch.ops.aten.permute.default(wait_tensor_325, [1, 0]); wait_tensor_325 = None + view_1803 = torch.ops.aten.view.default(mul_199, [16384, 1792]); mul_199 = None + mm_174 = torch.ops.aten.mm.default(view_1803, permute_274); view_1803 = permute_274 = None + view_1804 = torch.ops.aten.view.default(mm_174, [2, 8192, 4096]); mm_174 = None + split_108 = torch.ops.aten.split.Tensor(view_1804, 1024, 1); view_1804 = None + getitem_1089 = split_108[0] + getitem_1090 = split_108[1] + getitem_1091 = split_108[2] + getitem_1092 = split_108[3] + getitem_1093 = split_108[4] + getitem_1094 = split_108[5] + getitem_1095 = split_108[6] + getitem_1096 = 
split_108[7]; split_108 = None + cat_100 = torch.ops.aten.cat.default([getitem_1089, getitem_1090, getitem_1091, getitem_1092, getitem_1093, getitem_1094, getitem_1095, getitem_1096]); getitem_1089 = getitem_1090 = getitem_1091 = getitem_1092 = getitem_1093 = getitem_1094 = getitem_1095 = getitem_1096 = None + reduce_scatter_tensor_50 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_100, 'sum', 8, '1'); cat_100 = None + wait_tensor_326 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_50); reduce_scatter_tensor_50 = None + add_99 = torch.ops.aten.add.Tensor(add_97, wait_tensor_326); add_97 = wait_tensor_326 = None + convert_element_type_826 = torch.ops.prims.convert_element_type.default(primals_229, torch.bfloat16) + all_gather_into_tensor_276 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_826, 8, '0'); convert_element_type_826 = None + wait_tensor_327 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_276); all_gather_into_tensor_276 = None + convert_element_type_827 = torch.ops.prims.convert_element_type.default(add_99, torch.float32) + pow_51 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_827, 2) + mean_50 = torch.ops.aten.mean.dim(pow_51, [2], True); pow_51 = None + add_100 = torch.ops.aten.add.Scalar(mean_50, 1e-05); mean_50 = None + rsqrt_50 = torch.ops.aten.rsqrt.default(add_100); add_100 = None + mul_200 = torch.ops.aten.mul.Tensor(convert_element_type_827, rsqrt_50); convert_element_type_827 = rsqrt_50 = None + mul_201 = torch.ops.aten.mul.Tensor(mul_200, wait_tensor_327); mul_200 = wait_tensor_327 = None + convert_element_type_828 = torch.ops.prims.convert_element_type.default(mul_201, torch.bfloat16); mul_201 = None + all_gather_into_tensor_277 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_828, 8, '1'); convert_element_type_828 = None + wait_tensor_328 = 
torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_277); all_gather_into_tensor_277 = None + split_109 = torch.ops.aten.split.Tensor(wait_tensor_328, 2); wait_tensor_328 = None + getitem_1097 = split_109[0] + getitem_1098 = split_109[1] + getitem_1099 = split_109[2] + getitem_1100 = split_109[3] + getitem_1101 = split_109[4] + getitem_1102 = split_109[5] + getitem_1103 = split_109[6] + getitem_1104 = split_109[7]; split_109 = None + cat_101 = torch.ops.aten.cat.default([getitem_1097, getitem_1098, getitem_1099, getitem_1100, getitem_1101, getitem_1102, getitem_1103, getitem_1104], 1); getitem_1097 = getitem_1098 = getitem_1099 = getitem_1100 = getitem_1101 = getitem_1102 = getitem_1103 = getitem_1104 = None + convert_element_type_829 = torch.ops.prims.convert_element_type.default(primals_230, torch.bfloat16) + all_gather_into_tensor_278 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_829, 8, '0'); convert_element_type_829 = None + wait_tensor_329 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_278); all_gather_into_tensor_278 = None + permute_275 = torch.ops.aten.permute.default(wait_tensor_329, [1, 0]); wait_tensor_329 = None + view_1815 = torch.ops.aten.view.default(cat_101, [16384, 4096]); cat_101 = None + mm_175 = torch.ops.aten.mm.default(view_1815, permute_275); permute_275 = None + view_1816 = torch.ops.aten.view.default(mm_175, [2, 8192, 512]) + convert_element_type_832 = torch.ops.prims.convert_element_type.default(primals_231, torch.bfloat16) + all_gather_into_tensor_279 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_832, 8, '0'); convert_element_type_832 = None + wait_tensor_330 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_279); all_gather_into_tensor_279 = None + permute_276 = torch.ops.aten.permute.default(wait_tensor_330, [1, 0]); wait_tensor_330 = None + mm_176 = torch.ops.aten.mm.default(view_1815, permute_276); 
permute_276 = None + view_1823 = torch.ops.aten.view.default(mm_176, [2, 8192, 128]); mm_176 = None + convert_element_type_835 = torch.ops.prims.convert_element_type.default(primals_232, torch.bfloat16) + all_gather_into_tensor_280 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_835, 8, '0'); convert_element_type_835 = None + wait_tensor_331 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_280); all_gather_into_tensor_280 = None + permute_277 = torch.ops.aten.permute.default(wait_tensor_331, [1, 0]); wait_tensor_331 = None + mm_177 = torch.ops.aten.mm.default(view_1815, permute_277); view_1815 = permute_277 = None + view_1830 = torch.ops.aten.view.default(mm_177, [2, 8192, 128]) + view_1832 = torch.ops.aten.view.default(view_1816, [2, 8192, -1, 128]); view_1816 = None + view_1833 = torch.ops.aten.view.default(view_1823, [2, 8192, -1, 128]); view_1823 = None + view_1834 = torch.ops.aten.view.default(view_1830, [2, 8192, -1, 128]); view_1830 = None + convert_element_type_838 = torch.ops.prims.convert_element_type.default(view_1832, torch.float32); view_1832 = None + view_1835 = torch.ops.aten.view.default(convert_element_type_838, [2, 8192, 4, -1, 2]); convert_element_type_838 = None + view_as_complex_50 = torch.ops.aten.view_as_complex.default(view_1835); view_1835 = None + convert_element_type_839 = torch.ops.prims.convert_element_type.default(view_1833, torch.float32); view_1833 = None + view_1836 = torch.ops.aten.view.default(convert_element_type_839, [2, 8192, 1, -1, 2]); convert_element_type_839 = None + view_as_complex_51 = torch.ops.aten.view_as_complex.default(view_1836); view_1836 = None + mul_202 = torch.ops.aten.mul.Tensor(view_as_complex_50, view_37); view_as_complex_50 = None + view_as_real_50 = torch.ops.aten.view_as_real.default(mul_202); mul_202 = None + view_1838 = torch.ops.aten.view.default(view_as_real_50, [2, 8192, 4, 128]); view_as_real_50 = None + mul_203 = 
torch.ops.aten.mul.Tensor(view_as_complex_51, view_37); view_as_complex_51 = None + view_as_real_51 = torch.ops.aten.view_as_real.default(mul_203); mul_203 = None + view_1839 = torch.ops.aten.view.default(view_as_real_51, [2, 8192, 1, 128]); view_as_real_51 = None + convert_element_type_840 = torch.ops.prims.convert_element_type.default(view_1838, torch.bfloat16); view_1838 = None + convert_element_type_841 = torch.ops.prims.convert_element_type.default(view_1839, torch.bfloat16); view_1839 = None + unsqueeze_50 = torch.ops.aten.unsqueeze.default(convert_element_type_841, 3); convert_element_type_841 = None + expand_50 = torch.ops.aten.expand.default(unsqueeze_50, [2, 8192, 1, 4, 128]); unsqueeze_50 = None + view_1840 = torch.ops.aten.view.default(expand_50, [2, 8192, 4, 128]); expand_50 = None + unsqueeze_51 = torch.ops.aten.unsqueeze.default(view_1834, 3); view_1834 = None + expand_51 = torch.ops.aten.expand.default(unsqueeze_51, [2, 8192, 1, 4, 128]); unsqueeze_51 = None + view_1841 = torch.ops.aten.view.default(expand_51, [2, 8192, 4, 128]); expand_51 = None + permute_278 = torch.ops.aten.permute.default(convert_element_type_840, [0, 2, 1, 3]); convert_element_type_840 = None + permute_279 = torch.ops.aten.permute.default(view_1840, [0, 2, 1, 3]); view_1840 = None + permute_280 = torch.ops.aten.permute.default(view_1841, [0, 2, 1, 3]); view_1841 = None + _scaled_dot_product_cudnn_attention_25 = torch.ops.aten._scaled_dot_product_cudnn_attention.default(permute_278, permute_279, permute_280, None, True, 0.0, True); permute_278 = permute_279 = permute_280 = None + getitem_1105 = _scaled_dot_product_cudnn_attention_25[0] + getitem_1106 = _scaled_dot_product_cudnn_attention_25[1] + getitem_1111 = _scaled_dot_product_cudnn_attention_25[6] + getitem_1112 = _scaled_dot_product_cudnn_attention_25[7]; _scaled_dot_product_cudnn_attention_25 = None + permute_281 = torch.ops.aten.permute.default(getitem_1105, [0, 2, 1, 3]) + view_1842 = 
torch.ops.aten.view.default(permute_281, [2, 8192, -1]); permute_281 = None + convert_element_type_842 = torch.ops.prims.convert_element_type.default(primals_233, torch.bfloat16) + all_gather_into_tensor_281 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_842, 8, '0'); convert_element_type_842 = None + wait_tensor_332 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_281); all_gather_into_tensor_281 = None + permute_282 = torch.ops.aten.permute.default(wait_tensor_332, [1, 0]); wait_tensor_332 = None + view_1848 = torch.ops.aten.view.default(view_1842, [16384, 512]); view_1842 = None + mm_178 = torch.ops.aten.mm.default(view_1848, permute_282); view_1848 = permute_282 = None + view_1849 = torch.ops.aten.view.default(mm_178, [2, 8192, 4096]); mm_178 = None + split_110 = torch.ops.aten.split.Tensor(view_1849, 1024, 1); view_1849 = None + getitem_1114 = split_110[0] + getitem_1115 = split_110[1] + getitem_1116 = split_110[2] + getitem_1117 = split_110[3] + getitem_1118 = split_110[4] + getitem_1119 = split_110[5] + getitem_1120 = split_110[6] + getitem_1121 = split_110[7]; split_110 = None + cat_102 = torch.ops.aten.cat.default([getitem_1114, getitem_1115, getitem_1116, getitem_1117, getitem_1118, getitem_1119, getitem_1120, getitem_1121]); getitem_1114 = getitem_1115 = getitem_1116 = getitem_1117 = getitem_1118 = getitem_1119 = getitem_1120 = getitem_1121 = None + reduce_scatter_tensor_51 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_102, 'sum', 8, '1'); cat_102 = None + wait_tensor_333 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_51) + add_101 = torch.ops.aten.add.Tensor(add_99, wait_tensor_333); wait_tensor_333 = None + convert_element_type_845 = torch.ops.prims.convert_element_type.default(primals_234, torch.bfloat16) + all_gather_into_tensor_282 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_845, 8, '0'); convert_element_type_845 = 
None + wait_tensor_334 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_282); all_gather_into_tensor_282 = None + convert_element_type_846 = torch.ops.prims.convert_element_type.default(add_101, torch.float32) + pow_52 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_846, 2) + mean_51 = torch.ops.aten.mean.dim(pow_52, [2], True); pow_52 = None + add_102 = torch.ops.aten.add.Scalar(mean_51, 1e-05); mean_51 = None + rsqrt_51 = torch.ops.aten.rsqrt.default(add_102); add_102 = None + mul_204 = torch.ops.aten.mul.Tensor(convert_element_type_846, rsqrt_51); convert_element_type_846 = rsqrt_51 = None + mul_205 = torch.ops.aten.mul.Tensor(mul_204, wait_tensor_334); mul_204 = wait_tensor_334 = None + convert_element_type_847 = torch.ops.prims.convert_element_type.default(mul_205, torch.bfloat16); mul_205 = None + all_gather_into_tensor_283 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_847, 8, '1'); convert_element_type_847 = None + wait_tensor_335 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_283); all_gather_into_tensor_283 = None + split_111 = torch.ops.aten.split.Tensor(wait_tensor_335, 2); wait_tensor_335 = None + getitem_1122 = split_111[0] + getitem_1123 = split_111[1] + getitem_1124 = split_111[2] + getitem_1125 = split_111[3] + getitem_1126 = split_111[4] + getitem_1127 = split_111[5] + getitem_1128 = split_111[6] + getitem_1129 = split_111[7]; split_111 = None + cat_103 = torch.ops.aten.cat.default([getitem_1122, getitem_1123, getitem_1124, getitem_1125, getitem_1126, getitem_1127, getitem_1128, getitem_1129], 1); getitem_1122 = getitem_1123 = getitem_1124 = getitem_1125 = getitem_1126 = getitem_1127 = getitem_1128 = getitem_1129 = None + convert_element_type_848 = torch.ops.prims.convert_element_type.default(primals_235, torch.bfloat16) + all_gather_into_tensor_284 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_848, 8, '0'); 
convert_element_type_848 = None + wait_tensor_336 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_284); all_gather_into_tensor_284 = None + permute_283 = torch.ops.aten.permute.default(wait_tensor_336, [1, 0]); wait_tensor_336 = None + view_1860 = torch.ops.aten.view.default(cat_103, [16384, 4096]); cat_103 = None + mm_179 = torch.ops.aten.mm.default(view_1860, permute_283); permute_283 = None + view_1861 = torch.ops.aten.view.default(mm_179, [2, 8192, 1792]) + convert_element_type_851 = torch.ops.prims.convert_element_type.default(view_1861, torch.float32); view_1861 = None + sigmoid_25 = torch.ops.aten.sigmoid.default(convert_element_type_851) + mul_206 = torch.ops.aten.mul.Tensor(convert_element_type_851, sigmoid_25); convert_element_type_851 = sigmoid_25 = None + convert_element_type_852 = torch.ops.prims.convert_element_type.default(mul_206, torch.bfloat16); mul_206 = None + convert_element_type_853 = torch.ops.prims.convert_element_type.default(primals_236, torch.bfloat16) + all_gather_into_tensor_285 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_853, 8, '0'); convert_element_type_853 = None + wait_tensor_337 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_285); all_gather_into_tensor_285 = None + permute_284 = torch.ops.aten.permute.default(wait_tensor_337, [1, 0]); wait_tensor_337 = None + mm_180 = torch.ops.aten.mm.default(view_1860, permute_284); view_1860 = permute_284 = None + view_1868 = torch.ops.aten.view.default(mm_180, [2, 8192, 1792]); mm_180 = None + mul_207 = torch.ops.aten.mul.Tensor(convert_element_type_852, view_1868); convert_element_type_852 = view_1868 = None + convert_element_type_856 = torch.ops.prims.convert_element_type.default(primals_237, torch.bfloat16) + all_gather_into_tensor_286 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_856, 8, '0'); convert_element_type_856 = None + wait_tensor_338 = 
torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_286); all_gather_into_tensor_286 = None + permute_285 = torch.ops.aten.permute.default(wait_tensor_338, [1, 0]); wait_tensor_338 = None + view_1875 = torch.ops.aten.view.default(mul_207, [16384, 1792]); mul_207 = None + mm_181 = torch.ops.aten.mm.default(view_1875, permute_285); view_1875 = permute_285 = None + view_1876 = torch.ops.aten.view.default(mm_181, [2, 8192, 4096]); mm_181 = None + split_112 = torch.ops.aten.split.Tensor(view_1876, 1024, 1); view_1876 = None + getitem_1130 = split_112[0] + getitem_1131 = split_112[1] + getitem_1132 = split_112[2] + getitem_1133 = split_112[3] + getitem_1134 = split_112[4] + getitem_1135 = split_112[5] + getitem_1136 = split_112[6] + getitem_1137 = split_112[7]; split_112 = None + cat_104 = torch.ops.aten.cat.default([getitem_1130, getitem_1131, getitem_1132, getitem_1133, getitem_1134, getitem_1135, getitem_1136, getitem_1137]); getitem_1130 = getitem_1131 = getitem_1132 = getitem_1133 = getitem_1134 = getitem_1135 = getitem_1136 = getitem_1137 = None + reduce_scatter_tensor_52 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_104, 'sum', 8, '1'); cat_104 = None + wait_tensor_339 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_52); reduce_scatter_tensor_52 = None + add_103 = torch.ops.aten.add.Tensor(add_101, wait_tensor_339); add_101 = wait_tensor_339 = None + convert_element_type_859 = torch.ops.prims.convert_element_type.default(primals_238, torch.bfloat16) + all_gather_into_tensor_287 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_859, 8, '0'); convert_element_type_859 = None + wait_tensor_340 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_287); all_gather_into_tensor_287 = None + convert_element_type_860 = torch.ops.prims.convert_element_type.default(add_103, torch.float32) + pow_53 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_860, 2) + mean_52 
= torch.ops.aten.mean.dim(pow_53, [2], True); pow_53 = None + add_104 = torch.ops.aten.add.Scalar(mean_52, 1e-05); mean_52 = None + rsqrt_52 = torch.ops.aten.rsqrt.default(add_104); add_104 = None + mul_208 = torch.ops.aten.mul.Tensor(convert_element_type_860, rsqrt_52); convert_element_type_860 = rsqrt_52 = None + mul_209 = torch.ops.aten.mul.Tensor(mul_208, wait_tensor_340); mul_208 = wait_tensor_340 = None + convert_element_type_861 = torch.ops.prims.convert_element_type.default(mul_209, torch.bfloat16); mul_209 = None + all_gather_into_tensor_288 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_861, 8, '1'); convert_element_type_861 = None + wait_tensor_341 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_288); all_gather_into_tensor_288 = None + split_113 = torch.ops.aten.split.Tensor(wait_tensor_341, 2); wait_tensor_341 = None + getitem_1138 = split_113[0] + getitem_1139 = split_113[1] + getitem_1140 = split_113[2] + getitem_1141 = split_113[3] + getitem_1142 = split_113[4] + getitem_1143 = split_113[5] + getitem_1144 = split_113[6] + getitem_1145 = split_113[7]; split_113 = None + cat_105 = torch.ops.aten.cat.default([getitem_1138, getitem_1139, getitem_1140, getitem_1141, getitem_1142, getitem_1143, getitem_1144, getitem_1145], 1); getitem_1138 = getitem_1139 = getitem_1140 = getitem_1141 = getitem_1142 = getitem_1143 = getitem_1144 = getitem_1145 = None + convert_element_type_862 = torch.ops.prims.convert_element_type.default(primals_239, torch.bfloat16) + all_gather_into_tensor_289 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_862, 8, '0'); convert_element_type_862 = None + wait_tensor_342 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_289); all_gather_into_tensor_289 = None + permute_286 = torch.ops.aten.permute.default(wait_tensor_342, [1, 0]); wait_tensor_342 = None + view_1887 = torch.ops.aten.view.default(cat_105, [16384, 4096]); cat_105 
= None + mm_182 = torch.ops.aten.mm.default(view_1887, permute_286); permute_286 = None + view_1888 = torch.ops.aten.view.default(mm_182, [2, 8192, 512]) + convert_element_type_865 = torch.ops.prims.convert_element_type.default(primals_240, torch.bfloat16) + all_gather_into_tensor_290 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_865, 8, '0'); convert_element_type_865 = None + wait_tensor_343 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_290); all_gather_into_tensor_290 = None + permute_287 = torch.ops.aten.permute.default(wait_tensor_343, [1, 0]); wait_tensor_343 = None + mm_183 = torch.ops.aten.mm.default(view_1887, permute_287); permute_287 = None + view_1895 = torch.ops.aten.view.default(mm_183, [2, 8192, 128]); mm_183 = None + convert_element_type_868 = torch.ops.prims.convert_element_type.default(primals_241, torch.bfloat16) + all_gather_into_tensor_291 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_868, 8, '0'); convert_element_type_868 = None + wait_tensor_344 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_291); all_gather_into_tensor_291 = None + permute_288 = torch.ops.aten.permute.default(wait_tensor_344, [1, 0]); wait_tensor_344 = None + mm_184 = torch.ops.aten.mm.default(view_1887, permute_288); view_1887 = permute_288 = None + view_1902 = torch.ops.aten.view.default(mm_184, [2, 8192, 128]) + view_1904 = torch.ops.aten.view.default(view_1888, [2, 8192, -1, 128]); view_1888 = None + view_1905 = torch.ops.aten.view.default(view_1895, [2, 8192, -1, 128]); view_1895 = None + view_1906 = torch.ops.aten.view.default(view_1902, [2, 8192, -1, 128]); view_1902 = None + convert_element_type_871 = torch.ops.prims.convert_element_type.default(view_1904, torch.float32); view_1904 = None + view_1907 = torch.ops.aten.view.default(convert_element_type_871, [2, 8192, 4, -1, 2]); convert_element_type_871 = None + view_as_complex_52 = 
torch.ops.aten.view_as_complex.default(view_1907); view_1907 = None + convert_element_type_872 = torch.ops.prims.convert_element_type.default(view_1905, torch.float32); view_1905 = None + view_1908 = torch.ops.aten.view.default(convert_element_type_872, [2, 8192, 1, -1, 2]); convert_element_type_872 = None + view_as_complex_53 = torch.ops.aten.view_as_complex.default(view_1908); view_1908 = None + mul_210 = torch.ops.aten.mul.Tensor(view_as_complex_52, view_37); view_as_complex_52 = None + view_as_real_52 = torch.ops.aten.view_as_real.default(mul_210); mul_210 = None + view_1910 = torch.ops.aten.view.default(view_as_real_52, [2, 8192, 4, 128]); view_as_real_52 = None + mul_211 = torch.ops.aten.mul.Tensor(view_as_complex_53, view_37); view_as_complex_53 = None + view_as_real_53 = torch.ops.aten.view_as_real.default(mul_211); mul_211 = None + view_1911 = torch.ops.aten.view.default(view_as_real_53, [2, 8192, 1, 128]); view_as_real_53 = None + convert_element_type_873 = torch.ops.prims.convert_element_type.default(view_1910, torch.bfloat16); view_1910 = None + convert_element_type_874 = torch.ops.prims.convert_element_type.default(view_1911, torch.bfloat16); view_1911 = None + unsqueeze_52 = torch.ops.aten.unsqueeze.default(convert_element_type_874, 3); convert_element_type_874 = None + expand_52 = torch.ops.aten.expand.default(unsqueeze_52, [2, 8192, 1, 4, 128]); unsqueeze_52 = None + view_1912 = torch.ops.aten.view.default(expand_52, [2, 8192, 4, 128]); expand_52 = None + unsqueeze_53 = torch.ops.aten.unsqueeze.default(view_1906, 3); view_1906 = None + expand_53 = torch.ops.aten.expand.default(unsqueeze_53, [2, 8192, 1, 4, 128]); unsqueeze_53 = None + view_1913 = torch.ops.aten.view.default(expand_53, [2, 8192, 4, 128]); expand_53 = None + permute_289 = torch.ops.aten.permute.default(convert_element_type_873, [0, 2, 1, 3]); convert_element_type_873 = None + permute_290 = torch.ops.aten.permute.default(view_1912, [0, 2, 1, 3]); view_1912 = None + permute_291 = 
torch.ops.aten.permute.default(view_1913, [0, 2, 1, 3]); view_1913 = None + _scaled_dot_product_cudnn_attention_26 = torch.ops.aten._scaled_dot_product_cudnn_attention.default(permute_289, permute_290, permute_291, None, True, 0.0, True); permute_289 = permute_290 = permute_291 = None + getitem_1146 = _scaled_dot_product_cudnn_attention_26[0] + getitem_1147 = _scaled_dot_product_cudnn_attention_26[1] + getitem_1152 = _scaled_dot_product_cudnn_attention_26[6] + getitem_1153 = _scaled_dot_product_cudnn_attention_26[7]; _scaled_dot_product_cudnn_attention_26 = None + permute_292 = torch.ops.aten.permute.default(getitem_1146, [0, 2, 1, 3]) + view_1914 = torch.ops.aten.view.default(permute_292, [2, 8192, -1]); permute_292 = None + convert_element_type_875 = torch.ops.prims.convert_element_type.default(primals_242, torch.bfloat16) + all_gather_into_tensor_292 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_875, 8, '0'); convert_element_type_875 = None + wait_tensor_345 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_292); all_gather_into_tensor_292 = None + permute_293 = torch.ops.aten.permute.default(wait_tensor_345, [1, 0]); wait_tensor_345 = None + view_1920 = torch.ops.aten.view.default(view_1914, [16384, 512]); view_1914 = None + mm_185 = torch.ops.aten.mm.default(view_1920, permute_293); view_1920 = permute_293 = None + view_1921 = torch.ops.aten.view.default(mm_185, [2, 8192, 4096]); mm_185 = None + split_114 = torch.ops.aten.split.Tensor(view_1921, 1024, 1); view_1921 = None + getitem_1155 = split_114[0] + getitem_1156 = split_114[1] + getitem_1157 = split_114[2] + getitem_1158 = split_114[3] + getitem_1159 = split_114[4] + getitem_1160 = split_114[5] + getitem_1161 = split_114[6] + getitem_1162 = split_114[7]; split_114 = None + cat_106 = torch.ops.aten.cat.default([getitem_1155, getitem_1156, getitem_1157, getitem_1158, getitem_1159, getitem_1160, getitem_1161, getitem_1162]); getitem_1155 = getitem_1156 
= getitem_1157 = getitem_1158 = getitem_1159 = getitem_1160 = getitem_1161 = getitem_1162 = None + reduce_scatter_tensor_53 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_106, 'sum', 8, '1'); cat_106 = None + wait_tensor_346 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_53) + add_105 = torch.ops.aten.add.Tensor(add_103, wait_tensor_346); wait_tensor_346 = None + convert_element_type_878 = torch.ops.prims.convert_element_type.default(primals_243, torch.bfloat16) + all_gather_into_tensor_293 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_878, 8, '0'); convert_element_type_878 = None + wait_tensor_347 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_293); all_gather_into_tensor_293 = None + convert_element_type_879 = torch.ops.prims.convert_element_type.default(add_105, torch.float32) + pow_54 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_879, 2) + mean_53 = torch.ops.aten.mean.dim(pow_54, [2], True); pow_54 = None + add_106 = torch.ops.aten.add.Scalar(mean_53, 1e-05); mean_53 = None + rsqrt_53 = torch.ops.aten.rsqrt.default(add_106); add_106 = None + mul_212 = torch.ops.aten.mul.Tensor(convert_element_type_879, rsqrt_53); convert_element_type_879 = rsqrt_53 = None + mul_213 = torch.ops.aten.mul.Tensor(mul_212, wait_tensor_347); mul_212 = wait_tensor_347 = None + convert_element_type_880 = torch.ops.prims.convert_element_type.default(mul_213, torch.bfloat16); mul_213 = None + all_gather_into_tensor_294 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_880, 8, '1'); convert_element_type_880 = None + wait_tensor_348 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_294); all_gather_into_tensor_294 = None + split_115 = torch.ops.aten.split.Tensor(wait_tensor_348, 2); wait_tensor_348 = None + getitem_1163 = split_115[0] + getitem_1164 = split_115[1] + getitem_1165 = split_115[2] + getitem_1166 = split_115[3] + 
getitem_1167 = split_115[4] + getitem_1168 = split_115[5] + getitem_1169 = split_115[6] + getitem_1170 = split_115[7]; split_115 = None + cat_107 = torch.ops.aten.cat.default([getitem_1163, getitem_1164, getitem_1165, getitem_1166, getitem_1167, getitem_1168, getitem_1169, getitem_1170], 1); getitem_1163 = getitem_1164 = getitem_1165 = getitem_1166 = getitem_1167 = getitem_1168 = getitem_1169 = getitem_1170 = None + convert_element_type_881 = torch.ops.prims.convert_element_type.default(primals_244, torch.bfloat16) + all_gather_into_tensor_295 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_881, 8, '0'); convert_element_type_881 = None + wait_tensor_349 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_295); all_gather_into_tensor_295 = None + permute_294 = torch.ops.aten.permute.default(wait_tensor_349, [1, 0]); wait_tensor_349 = None + view_1932 = torch.ops.aten.view.default(cat_107, [16384, 4096]); cat_107 = None + mm_186 = torch.ops.aten.mm.default(view_1932, permute_294); permute_294 = None + view_1933 = torch.ops.aten.view.default(mm_186, [2, 8192, 1792]) + convert_element_type_884 = torch.ops.prims.convert_element_type.default(view_1933, torch.float32); view_1933 = None + sigmoid_26 = torch.ops.aten.sigmoid.default(convert_element_type_884) + mul_214 = torch.ops.aten.mul.Tensor(convert_element_type_884, sigmoid_26); convert_element_type_884 = sigmoid_26 = None + convert_element_type_885 = torch.ops.prims.convert_element_type.default(mul_214, torch.bfloat16); mul_214 = None + convert_element_type_886 = torch.ops.prims.convert_element_type.default(primals_245, torch.bfloat16) + all_gather_into_tensor_296 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_886, 8, '0'); convert_element_type_886 = None + wait_tensor_350 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_296); all_gather_into_tensor_296 = None + permute_295 = 
torch.ops.aten.permute.default(wait_tensor_350, [1, 0]); wait_tensor_350 = None + mm_187 = torch.ops.aten.mm.default(view_1932, permute_295); view_1932 = permute_295 = None + view_1940 = torch.ops.aten.view.default(mm_187, [2, 8192, 1792]); mm_187 = None + mul_215 = torch.ops.aten.mul.Tensor(convert_element_type_885, view_1940); convert_element_type_885 = view_1940 = None + convert_element_type_889 = torch.ops.prims.convert_element_type.default(primals_246, torch.bfloat16) + all_gather_into_tensor_297 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_889, 8, '0'); convert_element_type_889 = None + wait_tensor_351 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_297); all_gather_into_tensor_297 = None + permute_296 = torch.ops.aten.permute.default(wait_tensor_351, [1, 0]); wait_tensor_351 = None + view_1947 = torch.ops.aten.view.default(mul_215, [16384, 1792]); mul_215 = None + mm_188 = torch.ops.aten.mm.default(view_1947, permute_296); view_1947 = permute_296 = None + view_1948 = torch.ops.aten.view.default(mm_188, [2, 8192, 4096]); mm_188 = None + split_116 = torch.ops.aten.split.Tensor(view_1948, 1024, 1); view_1948 = None + getitem_1171 = split_116[0] + getitem_1172 = split_116[1] + getitem_1173 = split_116[2] + getitem_1174 = split_116[3] + getitem_1175 = split_116[4] + getitem_1176 = split_116[5] + getitem_1177 = split_116[6] + getitem_1178 = split_116[7]; split_116 = None + cat_108 = torch.ops.aten.cat.default([getitem_1171, getitem_1172, getitem_1173, getitem_1174, getitem_1175, getitem_1176, getitem_1177, getitem_1178]); getitem_1171 = getitem_1172 = getitem_1173 = getitem_1174 = getitem_1175 = getitem_1176 = getitem_1177 = getitem_1178 = None + reduce_scatter_tensor_54 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_108, 'sum', 8, '1'); cat_108 = None + wait_tensor_352 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_54); reduce_scatter_tensor_54 = None + add_107 = 
torch.ops.aten.add.Tensor(add_105, wait_tensor_352); add_105 = wait_tensor_352 = None + convert_element_type_892 = torch.ops.prims.convert_element_type.default(primals_247, torch.bfloat16) + all_gather_into_tensor_298 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_892, 8, '0'); convert_element_type_892 = None + wait_tensor_353 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_298); all_gather_into_tensor_298 = None + convert_element_type_893 = torch.ops.prims.convert_element_type.default(add_107, torch.float32) + pow_55 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_893, 2) + mean_54 = torch.ops.aten.mean.dim(pow_55, [2], True); pow_55 = None + add_108 = torch.ops.aten.add.Scalar(mean_54, 1e-05); mean_54 = None + rsqrt_54 = torch.ops.aten.rsqrt.default(add_108); add_108 = None + mul_216 = torch.ops.aten.mul.Tensor(convert_element_type_893, rsqrt_54); convert_element_type_893 = rsqrt_54 = None + mul_217 = torch.ops.aten.mul.Tensor(mul_216, wait_tensor_353); mul_216 = wait_tensor_353 = None + convert_element_type_894 = torch.ops.prims.convert_element_type.default(mul_217, torch.bfloat16); mul_217 = None + all_gather_into_tensor_299 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_894, 8, '1'); convert_element_type_894 = None + wait_tensor_354 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_299); all_gather_into_tensor_299 = None + split_117 = torch.ops.aten.split.Tensor(wait_tensor_354, 2); wait_tensor_354 = None + getitem_1179 = split_117[0] + getitem_1180 = split_117[1] + getitem_1181 = split_117[2] + getitem_1182 = split_117[3] + getitem_1183 = split_117[4] + getitem_1184 = split_117[5] + getitem_1185 = split_117[6] + getitem_1186 = split_117[7]; split_117 = None + cat_109 = torch.ops.aten.cat.default([getitem_1179, getitem_1180, getitem_1181, getitem_1182, getitem_1183, getitem_1184, getitem_1185, getitem_1186], 1); getitem_1179 = getitem_1180 
= getitem_1181 = getitem_1182 = getitem_1183 = getitem_1184 = getitem_1185 = getitem_1186 = None + convert_element_type_895 = torch.ops.prims.convert_element_type.default(primals_248, torch.bfloat16) + all_gather_into_tensor_300 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_895, 8, '0'); convert_element_type_895 = None + wait_tensor_355 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_300); all_gather_into_tensor_300 = None + permute_297 = torch.ops.aten.permute.default(wait_tensor_355, [1, 0]); wait_tensor_355 = None + view_1959 = torch.ops.aten.view.default(cat_109, [16384, 4096]); cat_109 = None + mm_189 = torch.ops.aten.mm.default(view_1959, permute_297); permute_297 = None + view_1960 = torch.ops.aten.view.default(mm_189, [2, 8192, 512]) + convert_element_type_898 = torch.ops.prims.convert_element_type.default(primals_249, torch.bfloat16) + all_gather_into_tensor_301 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_898, 8, '0'); convert_element_type_898 = None + wait_tensor_356 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_301); all_gather_into_tensor_301 = None + permute_298 = torch.ops.aten.permute.default(wait_tensor_356, [1, 0]); wait_tensor_356 = None + mm_190 = torch.ops.aten.mm.default(view_1959, permute_298); permute_298 = None + view_1967 = torch.ops.aten.view.default(mm_190, [2, 8192, 128]); mm_190 = None + convert_element_type_901 = torch.ops.prims.convert_element_type.default(primals_250, torch.bfloat16) + all_gather_into_tensor_302 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_901, 8, '0'); convert_element_type_901 = None + wait_tensor_357 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_302); all_gather_into_tensor_302 = None + permute_299 = torch.ops.aten.permute.default(wait_tensor_357, [1, 0]); wait_tensor_357 = None + mm_191 = torch.ops.aten.mm.default(view_1959, 
permute_299); view_1959 = permute_299 = None + view_1974 = torch.ops.aten.view.default(mm_191, [2, 8192, 128]) + view_1976 = torch.ops.aten.view.default(view_1960, [2, 8192, -1, 128]); view_1960 = None + view_1977 = torch.ops.aten.view.default(view_1967, [2, 8192, -1, 128]); view_1967 = None + view_1978 = torch.ops.aten.view.default(view_1974, [2, 8192, -1, 128]); view_1974 = None + convert_element_type_904 = torch.ops.prims.convert_element_type.default(view_1976, torch.float32); view_1976 = None + view_1979 = torch.ops.aten.view.default(convert_element_type_904, [2, 8192, 4, -1, 2]); convert_element_type_904 = None + view_as_complex_54 = torch.ops.aten.view_as_complex.default(view_1979); view_1979 = None + convert_element_type_905 = torch.ops.prims.convert_element_type.default(view_1977, torch.float32); view_1977 = None + view_1980 = torch.ops.aten.view.default(convert_element_type_905, [2, 8192, 1, -1, 2]); convert_element_type_905 = None + view_as_complex_55 = torch.ops.aten.view_as_complex.default(view_1980); view_1980 = None + mul_218 = torch.ops.aten.mul.Tensor(view_as_complex_54, view_37); view_as_complex_54 = None + view_as_real_54 = torch.ops.aten.view_as_real.default(mul_218); mul_218 = None + view_1982 = torch.ops.aten.view.default(view_as_real_54, [2, 8192, 4, 128]); view_as_real_54 = None + mul_219 = torch.ops.aten.mul.Tensor(view_as_complex_55, view_37); view_as_complex_55 = None + view_as_real_55 = torch.ops.aten.view_as_real.default(mul_219); mul_219 = None + view_1983 = torch.ops.aten.view.default(view_as_real_55, [2, 8192, 1, 128]); view_as_real_55 = None + convert_element_type_906 = torch.ops.prims.convert_element_type.default(view_1982, torch.bfloat16); view_1982 = None + convert_element_type_907 = torch.ops.prims.convert_element_type.default(view_1983, torch.bfloat16); view_1983 = None + unsqueeze_54 = torch.ops.aten.unsqueeze.default(convert_element_type_907, 3); convert_element_type_907 = None + expand_54 = 
torch.ops.aten.expand.default(unsqueeze_54, [2, 8192, 1, 4, 128]); unsqueeze_54 = None + view_1984 = torch.ops.aten.view.default(expand_54, [2, 8192, 4, 128]); expand_54 = None + unsqueeze_55 = torch.ops.aten.unsqueeze.default(view_1978, 3); view_1978 = None + expand_55 = torch.ops.aten.expand.default(unsqueeze_55, [2, 8192, 1, 4, 128]); unsqueeze_55 = None + view_1985 = torch.ops.aten.view.default(expand_55, [2, 8192, 4, 128]); expand_55 = None + permute_300 = torch.ops.aten.permute.default(convert_element_type_906, [0, 2, 1, 3]); convert_element_type_906 = None + permute_301 = torch.ops.aten.permute.default(view_1984, [0, 2, 1, 3]); view_1984 = None + permute_302 = torch.ops.aten.permute.default(view_1985, [0, 2, 1, 3]); view_1985 = None + _scaled_dot_product_cudnn_attention_27 = torch.ops.aten._scaled_dot_product_cudnn_attention.default(permute_300, permute_301, permute_302, None, True, 0.0, True); permute_300 = permute_301 = permute_302 = None + getitem_1187 = _scaled_dot_product_cudnn_attention_27[0] + getitem_1188 = _scaled_dot_product_cudnn_attention_27[1] + getitem_1193 = _scaled_dot_product_cudnn_attention_27[6] + getitem_1194 = _scaled_dot_product_cudnn_attention_27[7]; _scaled_dot_product_cudnn_attention_27 = None + permute_303 = torch.ops.aten.permute.default(getitem_1187, [0, 2, 1, 3]) + view_1986 = torch.ops.aten.view.default(permute_303, [2, 8192, -1]); permute_303 = None + convert_element_type_908 = torch.ops.prims.convert_element_type.default(primals_251, torch.bfloat16) + all_gather_into_tensor_303 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_908, 8, '0'); convert_element_type_908 = None + wait_tensor_358 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_303); all_gather_into_tensor_303 = None + permute_304 = torch.ops.aten.permute.default(wait_tensor_358, [1, 0]); wait_tensor_358 = None + view_1992 = torch.ops.aten.view.default(view_1986, [16384, 512]); view_1986 = None + mm_192 = 
torch.ops.aten.mm.default(view_1992, permute_304); view_1992 = permute_304 = None + view_1993 = torch.ops.aten.view.default(mm_192, [2, 8192, 4096]); mm_192 = None + split_118 = torch.ops.aten.split.Tensor(view_1993, 1024, 1); view_1993 = None + getitem_1196 = split_118[0] + getitem_1197 = split_118[1] + getitem_1198 = split_118[2] + getitem_1199 = split_118[3] + getitem_1200 = split_118[4] + getitem_1201 = split_118[5] + getitem_1202 = split_118[6] + getitem_1203 = split_118[7]; split_118 = None + cat_110 = torch.ops.aten.cat.default([getitem_1196, getitem_1197, getitem_1198, getitem_1199, getitem_1200, getitem_1201, getitem_1202, getitem_1203]); getitem_1196 = getitem_1197 = getitem_1198 = getitem_1199 = getitem_1200 = getitem_1201 = getitem_1202 = getitem_1203 = None + reduce_scatter_tensor_55 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_110, 'sum', 8, '1'); cat_110 = None + wait_tensor_359 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_55) + add_109 = torch.ops.aten.add.Tensor(add_107, wait_tensor_359); wait_tensor_359 = None + convert_element_type_911 = torch.ops.prims.convert_element_type.default(primals_252, torch.bfloat16) + all_gather_into_tensor_304 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_911, 8, '0'); convert_element_type_911 = None + wait_tensor_360 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_304); all_gather_into_tensor_304 = None + convert_element_type_912 = torch.ops.prims.convert_element_type.default(add_109, torch.float32) + pow_56 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_912, 2) + mean_55 = torch.ops.aten.mean.dim(pow_56, [2], True); pow_56 = None + add_110 = torch.ops.aten.add.Scalar(mean_55, 1e-05); mean_55 = None + rsqrt_55 = torch.ops.aten.rsqrt.default(add_110); add_110 = None + mul_220 = torch.ops.aten.mul.Tensor(convert_element_type_912, rsqrt_55); convert_element_type_912 = rsqrt_55 = None + mul_221 = 
torch.ops.aten.mul.Tensor(mul_220, wait_tensor_360); mul_220 = wait_tensor_360 = None + convert_element_type_913 = torch.ops.prims.convert_element_type.default(mul_221, torch.bfloat16); mul_221 = None + all_gather_into_tensor_305 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_913, 8, '1'); convert_element_type_913 = None + wait_tensor_361 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_305); all_gather_into_tensor_305 = None + split_119 = torch.ops.aten.split.Tensor(wait_tensor_361, 2); wait_tensor_361 = None + getitem_1204 = split_119[0] + getitem_1205 = split_119[1] + getitem_1206 = split_119[2] + getitem_1207 = split_119[3] + getitem_1208 = split_119[4] + getitem_1209 = split_119[5] + getitem_1210 = split_119[6] + getitem_1211 = split_119[7]; split_119 = None + cat_111 = torch.ops.aten.cat.default([getitem_1204, getitem_1205, getitem_1206, getitem_1207, getitem_1208, getitem_1209, getitem_1210, getitem_1211], 1); getitem_1204 = getitem_1205 = getitem_1206 = getitem_1207 = getitem_1208 = getitem_1209 = getitem_1210 = getitem_1211 = None + convert_element_type_914 = torch.ops.prims.convert_element_type.default(primals_253, torch.bfloat16) + all_gather_into_tensor_306 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_914, 8, '0'); convert_element_type_914 = None + wait_tensor_362 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_306); all_gather_into_tensor_306 = None + permute_305 = torch.ops.aten.permute.default(wait_tensor_362, [1, 0]); wait_tensor_362 = None + view_2004 = torch.ops.aten.view.default(cat_111, [16384, 4096]); cat_111 = None + mm_193 = torch.ops.aten.mm.default(view_2004, permute_305); permute_305 = None + view_2005 = torch.ops.aten.view.default(mm_193, [2, 8192, 1792]) + convert_element_type_917 = torch.ops.prims.convert_element_type.default(view_2005, torch.float32); view_2005 = None + sigmoid_27 = 
torch.ops.aten.sigmoid.default(convert_element_type_917) + mul_222 = torch.ops.aten.mul.Tensor(convert_element_type_917, sigmoid_27); convert_element_type_917 = sigmoid_27 = None + convert_element_type_918 = torch.ops.prims.convert_element_type.default(mul_222, torch.bfloat16); mul_222 = None + convert_element_type_919 = torch.ops.prims.convert_element_type.default(primals_254, torch.bfloat16) + all_gather_into_tensor_307 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_919, 8, '0'); convert_element_type_919 = None + wait_tensor_363 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_307); all_gather_into_tensor_307 = None + permute_306 = torch.ops.aten.permute.default(wait_tensor_363, [1, 0]); wait_tensor_363 = None + mm_194 = torch.ops.aten.mm.default(view_2004, permute_306); view_2004 = permute_306 = None + view_2012 = torch.ops.aten.view.default(mm_194, [2, 8192, 1792]); mm_194 = None + mul_223 = torch.ops.aten.mul.Tensor(convert_element_type_918, view_2012); convert_element_type_918 = view_2012 = None + convert_element_type_922 = torch.ops.prims.convert_element_type.default(primals_255, torch.bfloat16) + all_gather_into_tensor_308 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_922, 8, '0'); convert_element_type_922 = None + wait_tensor_364 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_308); all_gather_into_tensor_308 = None + permute_307 = torch.ops.aten.permute.default(wait_tensor_364, [1, 0]); wait_tensor_364 = None + view_2019 = torch.ops.aten.view.default(mul_223, [16384, 1792]); mul_223 = None + mm_195 = torch.ops.aten.mm.default(view_2019, permute_307); view_2019 = permute_307 = None + view_2020 = torch.ops.aten.view.default(mm_195, [2, 8192, 4096]); mm_195 = None + split_120 = torch.ops.aten.split.Tensor(view_2020, 1024, 1); view_2020 = None + getitem_1212 = split_120[0] + getitem_1213 = split_120[1] + getitem_1214 = split_120[2] + 
getitem_1215 = split_120[3] + getitem_1216 = split_120[4] + getitem_1217 = split_120[5] + getitem_1218 = split_120[6] + getitem_1219 = split_120[7]; split_120 = None + cat_112 = torch.ops.aten.cat.default([getitem_1212, getitem_1213, getitem_1214, getitem_1215, getitem_1216, getitem_1217, getitem_1218, getitem_1219]); getitem_1212 = getitem_1213 = getitem_1214 = getitem_1215 = getitem_1216 = getitem_1217 = getitem_1218 = getitem_1219 = None + reduce_scatter_tensor_56 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_112, 'sum', 8, '1'); cat_112 = None + wait_tensor_365 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_56); reduce_scatter_tensor_56 = None + add_111 = torch.ops.aten.add.Tensor(add_109, wait_tensor_365); add_109 = wait_tensor_365 = None + convert_element_type_925 = torch.ops.prims.convert_element_type.default(primals_256, torch.bfloat16) + all_gather_into_tensor_309 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_925, 8, '0'); convert_element_type_925 = None + wait_tensor_366 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_309); all_gather_into_tensor_309 = None + convert_element_type_926 = torch.ops.prims.convert_element_type.default(add_111, torch.float32) + pow_57 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_926, 2) + mean_56 = torch.ops.aten.mean.dim(pow_57, [2], True); pow_57 = None + add_112 = torch.ops.aten.add.Scalar(mean_56, 1e-05); mean_56 = None + rsqrt_56 = torch.ops.aten.rsqrt.default(add_112); add_112 = None + mul_224 = torch.ops.aten.mul.Tensor(convert_element_type_926, rsqrt_56); convert_element_type_926 = rsqrt_56 = None + mul_225 = torch.ops.aten.mul.Tensor(mul_224, wait_tensor_366); mul_224 = wait_tensor_366 = None + convert_element_type_927 = torch.ops.prims.convert_element_type.default(mul_225, torch.bfloat16); mul_225 = None + all_gather_into_tensor_310 = 
torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_927, 8, '1'); convert_element_type_927 = None + wait_tensor_367 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_310); all_gather_into_tensor_310 = None + split_121 = torch.ops.aten.split.Tensor(wait_tensor_367, 2); wait_tensor_367 = None + getitem_1220 = split_121[0] + getitem_1221 = split_121[1] + getitem_1222 = split_121[2] + getitem_1223 = split_121[3] + getitem_1224 = split_121[4] + getitem_1225 = split_121[5] + getitem_1226 = split_121[6] + getitem_1227 = split_121[7]; split_121 = None + cat_113 = torch.ops.aten.cat.default([getitem_1220, getitem_1221, getitem_1222, getitem_1223, getitem_1224, getitem_1225, getitem_1226, getitem_1227], 1); getitem_1220 = getitem_1221 = getitem_1222 = getitem_1223 = getitem_1224 = getitem_1225 = getitem_1226 = getitem_1227 = None + convert_element_type_928 = torch.ops.prims.convert_element_type.default(primals_257, torch.bfloat16) + all_gather_into_tensor_311 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_928, 8, '0'); convert_element_type_928 = None + wait_tensor_368 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_311); all_gather_into_tensor_311 = None + permute_308 = torch.ops.aten.permute.default(wait_tensor_368, [1, 0]); wait_tensor_368 = None + view_2031 = torch.ops.aten.view.default(cat_113, [16384, 4096]); cat_113 = None + mm_196 = torch.ops.aten.mm.default(view_2031, permute_308); permute_308 = None + view_2032 = torch.ops.aten.view.default(mm_196, [2, 8192, 512]) + convert_element_type_931 = torch.ops.prims.convert_element_type.default(primals_258, torch.bfloat16) + all_gather_into_tensor_312 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_931, 8, '0'); convert_element_type_931 = None + wait_tensor_369 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_312); all_gather_into_tensor_312 = None + permute_309 = 
torch.ops.aten.permute.default(wait_tensor_369, [1, 0]); wait_tensor_369 = None + mm_197 = torch.ops.aten.mm.default(view_2031, permute_309); permute_309 = None + view_2039 = torch.ops.aten.view.default(mm_197, [2, 8192, 128]); mm_197 = None + convert_element_type_934 = torch.ops.prims.convert_element_type.default(primals_259, torch.bfloat16) + all_gather_into_tensor_313 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_934, 8, '0'); convert_element_type_934 = None + wait_tensor_370 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_313); all_gather_into_tensor_313 = None + permute_310 = torch.ops.aten.permute.default(wait_tensor_370, [1, 0]); wait_tensor_370 = None + mm_198 = torch.ops.aten.mm.default(view_2031, permute_310); view_2031 = permute_310 = None + view_2046 = torch.ops.aten.view.default(mm_198, [2, 8192, 128]) + view_2048 = torch.ops.aten.view.default(view_2032, [2, 8192, -1, 128]); view_2032 = None + view_2049 = torch.ops.aten.view.default(view_2039, [2, 8192, -1, 128]); view_2039 = None + view_2050 = torch.ops.aten.view.default(view_2046, [2, 8192, -1, 128]); view_2046 = None + convert_element_type_937 = torch.ops.prims.convert_element_type.default(view_2048, torch.float32); view_2048 = None + view_2051 = torch.ops.aten.view.default(convert_element_type_937, [2, 8192, 4, -1, 2]); convert_element_type_937 = None + view_as_complex_56 = torch.ops.aten.view_as_complex.default(view_2051); view_2051 = None + convert_element_type_938 = torch.ops.prims.convert_element_type.default(view_2049, torch.float32); view_2049 = None + view_2052 = torch.ops.aten.view.default(convert_element_type_938, [2, 8192, 1, -1, 2]); convert_element_type_938 = None + view_as_complex_57 = torch.ops.aten.view_as_complex.default(view_2052); view_2052 = None + mul_226 = torch.ops.aten.mul.Tensor(view_as_complex_56, view_37); view_as_complex_56 = None + view_as_real_56 = torch.ops.aten.view_as_real.default(mul_226); mul_226 = None + 
view_2054 = torch.ops.aten.view.default(view_as_real_56, [2, 8192, 4, 128]); view_as_real_56 = None + mul_227 = torch.ops.aten.mul.Tensor(view_as_complex_57, view_37); view_as_complex_57 = None + view_as_real_57 = torch.ops.aten.view_as_real.default(mul_227); mul_227 = None + view_2055 = torch.ops.aten.view.default(view_as_real_57, [2, 8192, 1, 128]); view_as_real_57 = None + convert_element_type_939 = torch.ops.prims.convert_element_type.default(view_2054, torch.bfloat16); view_2054 = None + convert_element_type_940 = torch.ops.prims.convert_element_type.default(view_2055, torch.bfloat16); view_2055 = None + unsqueeze_56 = torch.ops.aten.unsqueeze.default(convert_element_type_940, 3); convert_element_type_940 = None + expand_56 = torch.ops.aten.expand.default(unsqueeze_56, [2, 8192, 1, 4, 128]); unsqueeze_56 = None + view_2056 = torch.ops.aten.view.default(expand_56, [2, 8192, 4, 128]); expand_56 = None + unsqueeze_57 = torch.ops.aten.unsqueeze.default(view_2050, 3); view_2050 = None + expand_57 = torch.ops.aten.expand.default(unsqueeze_57, [2, 8192, 1, 4, 128]); unsqueeze_57 = None + view_2057 = torch.ops.aten.view.default(expand_57, [2, 8192, 4, 128]); expand_57 = None + permute_311 = torch.ops.aten.permute.default(convert_element_type_939, [0, 2, 1, 3]); convert_element_type_939 = None + permute_312 = torch.ops.aten.permute.default(view_2056, [0, 2, 1, 3]); view_2056 = None + permute_313 = torch.ops.aten.permute.default(view_2057, [0, 2, 1, 3]); view_2057 = None + _scaled_dot_product_cudnn_attention_28 = torch.ops.aten._scaled_dot_product_cudnn_attention.default(permute_311, permute_312, permute_313, None, True, 0.0, True); permute_311 = permute_312 = permute_313 = None + getitem_1228 = _scaled_dot_product_cudnn_attention_28[0] + getitem_1229 = _scaled_dot_product_cudnn_attention_28[1] + getitem_1234 = _scaled_dot_product_cudnn_attention_28[6] + getitem_1235 = _scaled_dot_product_cudnn_attention_28[7]; _scaled_dot_product_cudnn_attention_28 = None + permute_314 
= torch.ops.aten.permute.default(getitem_1228, [0, 2, 1, 3]) + view_2058 = torch.ops.aten.view.default(permute_314, [2, 8192, -1]); permute_314 = None + convert_element_type_941 = torch.ops.prims.convert_element_type.default(primals_260, torch.bfloat16) + all_gather_into_tensor_314 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_941, 8, '0'); convert_element_type_941 = None + wait_tensor_371 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_314); all_gather_into_tensor_314 = None + permute_315 = torch.ops.aten.permute.default(wait_tensor_371, [1, 0]); wait_tensor_371 = None + view_2064 = torch.ops.aten.view.default(view_2058, [16384, 512]); view_2058 = None + mm_199 = torch.ops.aten.mm.default(view_2064, permute_315); view_2064 = permute_315 = None + view_2065 = torch.ops.aten.view.default(mm_199, [2, 8192, 4096]); mm_199 = None + split_122 = torch.ops.aten.split.Tensor(view_2065, 1024, 1); view_2065 = None + getitem_1237 = split_122[0] + getitem_1238 = split_122[1] + getitem_1239 = split_122[2] + getitem_1240 = split_122[3] + getitem_1241 = split_122[4] + getitem_1242 = split_122[5] + getitem_1243 = split_122[6] + getitem_1244 = split_122[7]; split_122 = None + cat_114 = torch.ops.aten.cat.default([getitem_1237, getitem_1238, getitem_1239, getitem_1240, getitem_1241, getitem_1242, getitem_1243, getitem_1244]); getitem_1237 = getitem_1238 = getitem_1239 = getitem_1240 = getitem_1241 = getitem_1242 = getitem_1243 = getitem_1244 = None + reduce_scatter_tensor_57 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_114, 'sum', 8, '1'); cat_114 = None + wait_tensor_372 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_57) + add_113 = torch.ops.aten.add.Tensor(add_111, wait_tensor_372); wait_tensor_372 = None + convert_element_type_944 = torch.ops.prims.convert_element_type.default(primals_261, torch.bfloat16) + all_gather_into_tensor_315 = 
torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_944, 8, '0'); convert_element_type_944 = None + wait_tensor_373 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_315); all_gather_into_tensor_315 = None + convert_element_type_945 = torch.ops.prims.convert_element_type.default(add_113, torch.float32) + pow_58 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_945, 2) + mean_57 = torch.ops.aten.mean.dim(pow_58, [2], True); pow_58 = None + add_114 = torch.ops.aten.add.Scalar(mean_57, 1e-05); mean_57 = None + rsqrt_57 = torch.ops.aten.rsqrt.default(add_114); add_114 = None + mul_228 = torch.ops.aten.mul.Tensor(convert_element_type_945, rsqrt_57); convert_element_type_945 = rsqrt_57 = None + mul_229 = torch.ops.aten.mul.Tensor(mul_228, wait_tensor_373); mul_228 = wait_tensor_373 = None + convert_element_type_946 = torch.ops.prims.convert_element_type.default(mul_229, torch.bfloat16); mul_229 = None + all_gather_into_tensor_316 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_946, 8, '1'); convert_element_type_946 = None + wait_tensor_374 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_316); all_gather_into_tensor_316 = None + split_123 = torch.ops.aten.split.Tensor(wait_tensor_374, 2); wait_tensor_374 = None + getitem_1245 = split_123[0] + getitem_1246 = split_123[1] + getitem_1247 = split_123[2] + getitem_1248 = split_123[3] + getitem_1249 = split_123[4] + getitem_1250 = split_123[5] + getitem_1251 = split_123[6] + getitem_1252 = split_123[7]; split_123 = None + cat_115 = torch.ops.aten.cat.default([getitem_1245, getitem_1246, getitem_1247, getitem_1248, getitem_1249, getitem_1250, getitem_1251, getitem_1252], 1); getitem_1245 = getitem_1246 = getitem_1247 = getitem_1248 = getitem_1249 = getitem_1250 = getitem_1251 = getitem_1252 = None + convert_element_type_947 = torch.ops.prims.convert_element_type.default(primals_262, torch.bfloat16) + 
all_gather_into_tensor_317 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_947, 8, '0'); convert_element_type_947 = None + wait_tensor_375 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_317); all_gather_into_tensor_317 = None + permute_316 = torch.ops.aten.permute.default(wait_tensor_375, [1, 0]); wait_tensor_375 = None + view_2076 = torch.ops.aten.view.default(cat_115, [16384, 4096]); cat_115 = None + mm_200 = torch.ops.aten.mm.default(view_2076, permute_316); permute_316 = None + view_2077 = torch.ops.aten.view.default(mm_200, [2, 8192, 1792]) + convert_element_type_950 = torch.ops.prims.convert_element_type.default(view_2077, torch.float32); view_2077 = None + sigmoid_28 = torch.ops.aten.sigmoid.default(convert_element_type_950) + mul_230 = torch.ops.aten.mul.Tensor(convert_element_type_950, sigmoid_28); convert_element_type_950 = sigmoid_28 = None + convert_element_type_951 = torch.ops.prims.convert_element_type.default(mul_230, torch.bfloat16); mul_230 = None + convert_element_type_952 = torch.ops.prims.convert_element_type.default(primals_263, torch.bfloat16) + all_gather_into_tensor_318 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_952, 8, '0'); convert_element_type_952 = None + wait_tensor_376 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_318); all_gather_into_tensor_318 = None + permute_317 = torch.ops.aten.permute.default(wait_tensor_376, [1, 0]); wait_tensor_376 = None + mm_201 = torch.ops.aten.mm.default(view_2076, permute_317); view_2076 = permute_317 = None + view_2084 = torch.ops.aten.view.default(mm_201, [2, 8192, 1792]); mm_201 = None + mul_231 = torch.ops.aten.mul.Tensor(convert_element_type_951, view_2084); convert_element_type_951 = view_2084 = None + convert_element_type_955 = torch.ops.prims.convert_element_type.default(primals_264, torch.bfloat16) + all_gather_into_tensor_319 = 
torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_955, 8, '0'); convert_element_type_955 = None + wait_tensor_377 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_319); all_gather_into_tensor_319 = None + permute_318 = torch.ops.aten.permute.default(wait_tensor_377, [1, 0]); wait_tensor_377 = None + view_2091 = torch.ops.aten.view.default(mul_231, [16384, 1792]); mul_231 = None + mm_202 = torch.ops.aten.mm.default(view_2091, permute_318); view_2091 = permute_318 = None + view_2092 = torch.ops.aten.view.default(mm_202, [2, 8192, 4096]); mm_202 = None + split_124 = torch.ops.aten.split.Tensor(view_2092, 1024, 1); view_2092 = None + getitem_1253 = split_124[0] + getitem_1254 = split_124[1] + getitem_1255 = split_124[2] + getitem_1256 = split_124[3] + getitem_1257 = split_124[4] + getitem_1258 = split_124[5] + getitem_1259 = split_124[6] + getitem_1260 = split_124[7]; split_124 = None + cat_116 = torch.ops.aten.cat.default([getitem_1253, getitem_1254, getitem_1255, getitem_1256, getitem_1257, getitem_1258, getitem_1259, getitem_1260]); getitem_1253 = getitem_1254 = getitem_1255 = getitem_1256 = getitem_1257 = getitem_1258 = getitem_1259 = getitem_1260 = None + reduce_scatter_tensor_58 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_116, 'sum', 8, '1'); cat_116 = None + wait_tensor_378 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_58); reduce_scatter_tensor_58 = None + add_115 = torch.ops.aten.add.Tensor(add_113, wait_tensor_378); add_113 = wait_tensor_378 = None + convert_element_type_958 = torch.ops.prims.convert_element_type.default(primals_265, torch.bfloat16) + all_gather_into_tensor_320 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_958, 8, '0'); convert_element_type_958 = None + wait_tensor_379 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_320); all_gather_into_tensor_320 = None + convert_element_type_959 = 
torch.ops.prims.convert_element_type.default(add_115, torch.float32) + pow_59 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_959, 2) + mean_58 = torch.ops.aten.mean.dim(pow_59, [2], True); pow_59 = None + add_116 = torch.ops.aten.add.Scalar(mean_58, 1e-05); mean_58 = None + rsqrt_58 = torch.ops.aten.rsqrt.default(add_116); add_116 = None + mul_232 = torch.ops.aten.mul.Tensor(convert_element_type_959, rsqrt_58); convert_element_type_959 = rsqrt_58 = None + mul_233 = torch.ops.aten.mul.Tensor(mul_232, wait_tensor_379); mul_232 = wait_tensor_379 = None + convert_element_type_960 = torch.ops.prims.convert_element_type.default(mul_233, torch.bfloat16); mul_233 = None + all_gather_into_tensor_321 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_960, 8, '1'); convert_element_type_960 = None + wait_tensor_380 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_321); all_gather_into_tensor_321 = None + split_125 = torch.ops.aten.split.Tensor(wait_tensor_380, 2); wait_tensor_380 = None + getitem_1261 = split_125[0] + getitem_1262 = split_125[1] + getitem_1263 = split_125[2] + getitem_1264 = split_125[3] + getitem_1265 = split_125[4] + getitem_1266 = split_125[5] + getitem_1267 = split_125[6] + getitem_1268 = split_125[7]; split_125 = None + cat_117 = torch.ops.aten.cat.default([getitem_1261, getitem_1262, getitem_1263, getitem_1264, getitem_1265, getitem_1266, getitem_1267, getitem_1268], 1); getitem_1261 = getitem_1262 = getitem_1263 = getitem_1264 = getitem_1265 = getitem_1266 = getitem_1267 = getitem_1268 = None + convert_element_type_961 = torch.ops.prims.convert_element_type.default(primals_266, torch.bfloat16) + all_gather_into_tensor_322 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_961, 8, '0'); convert_element_type_961 = None + wait_tensor_381 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_322); all_gather_into_tensor_322 = None + permute_319 = 
torch.ops.aten.permute.default(wait_tensor_381, [1, 0]); wait_tensor_381 = None + view_2103 = torch.ops.aten.view.default(cat_117, [16384, 4096]); cat_117 = None + mm_203 = torch.ops.aten.mm.default(view_2103, permute_319); permute_319 = None + view_2104 = torch.ops.aten.view.default(mm_203, [2, 8192, 512]) + convert_element_type_964 = torch.ops.prims.convert_element_type.default(primals_267, torch.bfloat16) + all_gather_into_tensor_323 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_964, 8, '0'); convert_element_type_964 = None + wait_tensor_382 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_323); all_gather_into_tensor_323 = None + permute_320 = torch.ops.aten.permute.default(wait_tensor_382, [1, 0]); wait_tensor_382 = None + mm_204 = torch.ops.aten.mm.default(view_2103, permute_320); permute_320 = None + view_2111 = torch.ops.aten.view.default(mm_204, [2, 8192, 128]); mm_204 = None + convert_element_type_967 = torch.ops.prims.convert_element_type.default(primals_268, torch.bfloat16) + all_gather_into_tensor_324 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_967, 8, '0'); convert_element_type_967 = None + wait_tensor_383 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_324); all_gather_into_tensor_324 = None + permute_321 = torch.ops.aten.permute.default(wait_tensor_383, [1, 0]); wait_tensor_383 = None + mm_205 = torch.ops.aten.mm.default(view_2103, permute_321); view_2103 = permute_321 = None + view_2118 = torch.ops.aten.view.default(mm_205, [2, 8192, 128]) + view_2120 = torch.ops.aten.view.default(view_2104, [2, 8192, -1, 128]); view_2104 = None + view_2121 = torch.ops.aten.view.default(view_2111, [2, 8192, -1, 128]); view_2111 = None + view_2122 = torch.ops.aten.view.default(view_2118, [2, 8192, -1, 128]); view_2118 = None + convert_element_type_970 = torch.ops.prims.convert_element_type.default(view_2120, torch.float32); view_2120 = None + 
view_2123 = torch.ops.aten.view.default(convert_element_type_970, [2, 8192, 4, -1, 2]); convert_element_type_970 = None + view_as_complex_58 = torch.ops.aten.view_as_complex.default(view_2123); view_2123 = None + convert_element_type_971 = torch.ops.prims.convert_element_type.default(view_2121, torch.float32); view_2121 = None + view_2124 = torch.ops.aten.view.default(convert_element_type_971, [2, 8192, 1, -1, 2]); convert_element_type_971 = None + view_as_complex_59 = torch.ops.aten.view_as_complex.default(view_2124); view_2124 = None + mul_234 = torch.ops.aten.mul.Tensor(view_as_complex_58, view_37); view_as_complex_58 = None + view_as_real_58 = torch.ops.aten.view_as_real.default(mul_234); mul_234 = None + view_2126 = torch.ops.aten.view.default(view_as_real_58, [2, 8192, 4, 128]); view_as_real_58 = None + mul_235 = torch.ops.aten.mul.Tensor(view_as_complex_59, view_37); view_as_complex_59 = None + view_as_real_59 = torch.ops.aten.view_as_real.default(mul_235); mul_235 = None + view_2127 = torch.ops.aten.view.default(view_as_real_59, [2, 8192, 1, 128]); view_as_real_59 = None + convert_element_type_972 = torch.ops.prims.convert_element_type.default(view_2126, torch.bfloat16); view_2126 = None + convert_element_type_973 = torch.ops.prims.convert_element_type.default(view_2127, torch.bfloat16); view_2127 = None + unsqueeze_58 = torch.ops.aten.unsqueeze.default(convert_element_type_973, 3); convert_element_type_973 = None + expand_58 = torch.ops.aten.expand.default(unsqueeze_58, [2, 8192, 1, 4, 128]); unsqueeze_58 = None + view_2128 = torch.ops.aten.view.default(expand_58, [2, 8192, 4, 128]); expand_58 = None + unsqueeze_59 = torch.ops.aten.unsqueeze.default(view_2122, 3); view_2122 = None + expand_59 = torch.ops.aten.expand.default(unsqueeze_59, [2, 8192, 1, 4, 128]); unsqueeze_59 = None + view_2129 = torch.ops.aten.view.default(expand_59, [2, 8192, 4, 128]); expand_59 = None + permute_322 = torch.ops.aten.permute.default(convert_element_type_972, [0, 2, 1, 3]); 
convert_element_type_972 = None + permute_323 = torch.ops.aten.permute.default(view_2128, [0, 2, 1, 3]); view_2128 = None + permute_324 = torch.ops.aten.permute.default(view_2129, [0, 2, 1, 3]); view_2129 = None + _scaled_dot_product_cudnn_attention_29 = torch.ops.aten._scaled_dot_product_cudnn_attention.default(permute_322, permute_323, permute_324, None, True, 0.0, True); permute_322 = permute_323 = permute_324 = None + getitem_1269 = _scaled_dot_product_cudnn_attention_29[0] + getitem_1270 = _scaled_dot_product_cudnn_attention_29[1] + getitem_1275 = _scaled_dot_product_cudnn_attention_29[6] + getitem_1276 = _scaled_dot_product_cudnn_attention_29[7]; _scaled_dot_product_cudnn_attention_29 = None + permute_325 = torch.ops.aten.permute.default(getitem_1269, [0, 2, 1, 3]) + view_2130 = torch.ops.aten.view.default(permute_325, [2, 8192, -1]); permute_325 = None + convert_element_type_974 = torch.ops.prims.convert_element_type.default(primals_269, torch.bfloat16) + all_gather_into_tensor_325 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_974, 8, '0'); convert_element_type_974 = None + wait_tensor_384 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_325); all_gather_into_tensor_325 = None + permute_326 = torch.ops.aten.permute.default(wait_tensor_384, [1, 0]); wait_tensor_384 = None + view_2136 = torch.ops.aten.view.default(view_2130, [16384, 512]); view_2130 = None + mm_206 = torch.ops.aten.mm.default(view_2136, permute_326); view_2136 = permute_326 = None + view_2137 = torch.ops.aten.view.default(mm_206, [2, 8192, 4096]); mm_206 = None + split_126 = torch.ops.aten.split.Tensor(view_2137, 1024, 1); view_2137 = None + getitem_1278 = split_126[0] + getitem_1279 = split_126[1] + getitem_1280 = split_126[2] + getitem_1281 = split_126[3] + getitem_1282 = split_126[4] + getitem_1283 = split_126[5] + getitem_1284 = split_126[6] + getitem_1285 = split_126[7]; split_126 = None + cat_118 = 
torch.ops.aten.cat.default([getitem_1278, getitem_1279, getitem_1280, getitem_1281, getitem_1282, getitem_1283, getitem_1284, getitem_1285]); getitem_1278 = getitem_1279 = getitem_1280 = getitem_1281 = getitem_1282 = getitem_1283 = getitem_1284 = getitem_1285 = None + reduce_scatter_tensor_59 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_118, 'sum', 8, '1'); cat_118 = None + wait_tensor_385 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_59) + add_117 = torch.ops.aten.add.Tensor(add_115, wait_tensor_385); wait_tensor_385 = None + convert_element_type_977 = torch.ops.prims.convert_element_type.default(primals_270, torch.bfloat16) + all_gather_into_tensor_326 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_977, 8, '0'); convert_element_type_977 = None + wait_tensor_386 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_326); all_gather_into_tensor_326 = None + convert_element_type_978 = torch.ops.prims.convert_element_type.default(add_117, torch.float32) + pow_60 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_978, 2) + mean_59 = torch.ops.aten.mean.dim(pow_60, [2], True); pow_60 = None + add_118 = torch.ops.aten.add.Scalar(mean_59, 1e-05); mean_59 = None + rsqrt_59 = torch.ops.aten.rsqrt.default(add_118); add_118 = None + mul_236 = torch.ops.aten.mul.Tensor(convert_element_type_978, rsqrt_59); convert_element_type_978 = rsqrt_59 = None + mul_237 = torch.ops.aten.mul.Tensor(mul_236, wait_tensor_386); mul_236 = wait_tensor_386 = None + convert_element_type_979 = torch.ops.prims.convert_element_type.default(mul_237, torch.bfloat16); mul_237 = None + all_gather_into_tensor_327 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_979, 8, '1'); convert_element_type_979 = None + wait_tensor_387 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_327); all_gather_into_tensor_327 = None + split_127 = 
torch.ops.aten.split.Tensor(wait_tensor_387, 2); wait_tensor_387 = None + getitem_1286 = split_127[0] + getitem_1287 = split_127[1] + getitem_1288 = split_127[2] + getitem_1289 = split_127[3] + getitem_1290 = split_127[4] + getitem_1291 = split_127[5] + getitem_1292 = split_127[6] + getitem_1293 = split_127[7]; split_127 = None + cat_119 = torch.ops.aten.cat.default([getitem_1286, getitem_1287, getitem_1288, getitem_1289, getitem_1290, getitem_1291, getitem_1292, getitem_1293], 1); getitem_1286 = getitem_1287 = getitem_1288 = getitem_1289 = getitem_1290 = getitem_1291 = getitem_1292 = getitem_1293 = None + convert_element_type_980 = torch.ops.prims.convert_element_type.default(primals_271, torch.bfloat16) + all_gather_into_tensor_328 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_980, 8, '0'); convert_element_type_980 = None + wait_tensor_388 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_328); all_gather_into_tensor_328 = None + permute_327 = torch.ops.aten.permute.default(wait_tensor_388, [1, 0]); wait_tensor_388 = None + view_2148 = torch.ops.aten.view.default(cat_119, [16384, 4096]); cat_119 = None + mm_207 = torch.ops.aten.mm.default(view_2148, permute_327); permute_327 = None + view_2149 = torch.ops.aten.view.default(mm_207, [2, 8192, 1792]) + convert_element_type_983 = torch.ops.prims.convert_element_type.default(view_2149, torch.float32); view_2149 = None + sigmoid_29 = torch.ops.aten.sigmoid.default(convert_element_type_983) + mul_238 = torch.ops.aten.mul.Tensor(convert_element_type_983, sigmoid_29); convert_element_type_983 = sigmoid_29 = None + convert_element_type_984 = torch.ops.prims.convert_element_type.default(mul_238, torch.bfloat16); mul_238 = None + convert_element_type_985 = torch.ops.prims.convert_element_type.default(primals_272, torch.bfloat16) + all_gather_into_tensor_329 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_985, 8, '0'); 
convert_element_type_985 = None + wait_tensor_389 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_329); all_gather_into_tensor_329 = None + permute_328 = torch.ops.aten.permute.default(wait_tensor_389, [1, 0]); wait_tensor_389 = None + mm_208 = torch.ops.aten.mm.default(view_2148, permute_328); view_2148 = permute_328 = None + view_2156 = torch.ops.aten.view.default(mm_208, [2, 8192, 1792]); mm_208 = None + mul_239 = torch.ops.aten.mul.Tensor(convert_element_type_984, view_2156); convert_element_type_984 = view_2156 = None + convert_element_type_988 = torch.ops.prims.convert_element_type.default(primals_273, torch.bfloat16) + all_gather_into_tensor_330 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_988, 8, '0'); convert_element_type_988 = None + wait_tensor_390 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_330); all_gather_into_tensor_330 = None + permute_329 = torch.ops.aten.permute.default(wait_tensor_390, [1, 0]); wait_tensor_390 = None + view_2163 = torch.ops.aten.view.default(mul_239, [16384, 1792]); mul_239 = None + mm_209 = torch.ops.aten.mm.default(view_2163, permute_329); view_2163 = permute_329 = None + view_2164 = torch.ops.aten.view.default(mm_209, [2, 8192, 4096]); mm_209 = None + split_128 = torch.ops.aten.split.Tensor(view_2164, 1024, 1); view_2164 = None + getitem_1294 = split_128[0] + getitem_1295 = split_128[1] + getitem_1296 = split_128[2] + getitem_1297 = split_128[3] + getitem_1298 = split_128[4] + getitem_1299 = split_128[5] + getitem_1300 = split_128[6] + getitem_1301 = split_128[7]; split_128 = None + cat_120 = torch.ops.aten.cat.default([getitem_1294, getitem_1295, getitem_1296, getitem_1297, getitem_1298, getitem_1299, getitem_1300, getitem_1301]); getitem_1294 = getitem_1295 = getitem_1296 = getitem_1297 = getitem_1298 = getitem_1299 = getitem_1300 = getitem_1301 = None + reduce_scatter_tensor_60 = 
torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_120, 'sum', 8, '1'); cat_120 = None + wait_tensor_391 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_60); reduce_scatter_tensor_60 = None + add_119 = torch.ops.aten.add.Tensor(add_117, wait_tensor_391); add_117 = wait_tensor_391 = None + convert_element_type_991 = torch.ops.prims.convert_element_type.default(primals_274, torch.bfloat16) + all_gather_into_tensor_331 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_991, 8, '0'); convert_element_type_991 = None + wait_tensor_392 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_331); all_gather_into_tensor_331 = None + convert_element_type_992 = torch.ops.prims.convert_element_type.default(add_119, torch.float32) + pow_61 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_992, 2) + mean_60 = torch.ops.aten.mean.dim(pow_61, [2], True); pow_61 = None + add_120 = torch.ops.aten.add.Scalar(mean_60, 1e-05); mean_60 = None + rsqrt_60 = torch.ops.aten.rsqrt.default(add_120); add_120 = None + mul_240 = torch.ops.aten.mul.Tensor(convert_element_type_992, rsqrt_60); convert_element_type_992 = rsqrt_60 = None + mul_241 = torch.ops.aten.mul.Tensor(mul_240, wait_tensor_392); mul_240 = wait_tensor_392 = None + convert_element_type_993 = torch.ops.prims.convert_element_type.default(mul_241, torch.bfloat16); mul_241 = None + all_gather_into_tensor_332 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_993, 8, '1'); convert_element_type_993 = None + wait_tensor_393 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_332); all_gather_into_tensor_332 = None + split_129 = torch.ops.aten.split.Tensor(wait_tensor_393, 2); wait_tensor_393 = None + getitem_1302 = split_129[0] + getitem_1303 = split_129[1] + getitem_1304 = split_129[2] + getitem_1305 = split_129[3] + getitem_1306 = split_129[4] + getitem_1307 = split_129[5] + getitem_1308 = 
split_129[6] + getitem_1309 = split_129[7]; split_129 = None + cat_121 = torch.ops.aten.cat.default([getitem_1302, getitem_1303, getitem_1304, getitem_1305, getitem_1306, getitem_1307, getitem_1308, getitem_1309], 1); getitem_1302 = getitem_1303 = getitem_1304 = getitem_1305 = getitem_1306 = getitem_1307 = getitem_1308 = getitem_1309 = None + convert_element_type_994 = torch.ops.prims.convert_element_type.default(primals_275, torch.bfloat16) + all_gather_into_tensor_333 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_994, 8, '0'); convert_element_type_994 = None + wait_tensor_394 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_333); all_gather_into_tensor_333 = None + permute_330 = torch.ops.aten.permute.default(wait_tensor_394, [1, 0]); wait_tensor_394 = None + view_2175 = torch.ops.aten.view.default(cat_121, [16384, 4096]); cat_121 = None + mm_210 = torch.ops.aten.mm.default(view_2175, permute_330); permute_330 = None + view_2176 = torch.ops.aten.view.default(mm_210, [2, 8192, 512]) + convert_element_type_997 = torch.ops.prims.convert_element_type.default(primals_276, torch.bfloat16) + all_gather_into_tensor_334 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_997, 8, '0'); convert_element_type_997 = None + wait_tensor_395 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_334); all_gather_into_tensor_334 = None + permute_331 = torch.ops.aten.permute.default(wait_tensor_395, [1, 0]); wait_tensor_395 = None + mm_211 = torch.ops.aten.mm.default(view_2175, permute_331); permute_331 = None + view_2183 = torch.ops.aten.view.default(mm_211, [2, 8192, 128]); mm_211 = None + convert_element_type_1000 = torch.ops.prims.convert_element_type.default(primals_277, torch.bfloat16) + all_gather_into_tensor_335 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1000, 8, '0'); convert_element_type_1000 = None + wait_tensor_396 = 
torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_335); all_gather_into_tensor_335 = None + permute_332 = torch.ops.aten.permute.default(wait_tensor_396, [1, 0]); wait_tensor_396 = None + mm_212 = torch.ops.aten.mm.default(view_2175, permute_332); view_2175 = permute_332 = None + view_2190 = torch.ops.aten.view.default(mm_212, [2, 8192, 128]) + view_2192 = torch.ops.aten.view.default(view_2176, [2, 8192, -1, 128]); view_2176 = None + view_2193 = torch.ops.aten.view.default(view_2183, [2, 8192, -1, 128]); view_2183 = None + view_2194 = torch.ops.aten.view.default(view_2190, [2, 8192, -1, 128]); view_2190 = None + convert_element_type_1003 = torch.ops.prims.convert_element_type.default(view_2192, torch.float32); view_2192 = None + view_2195 = torch.ops.aten.view.default(convert_element_type_1003, [2, 8192, 4, -1, 2]); convert_element_type_1003 = None + view_as_complex_60 = torch.ops.aten.view_as_complex.default(view_2195); view_2195 = None + convert_element_type_1004 = torch.ops.prims.convert_element_type.default(view_2193, torch.float32); view_2193 = None + view_2196 = torch.ops.aten.view.default(convert_element_type_1004, [2, 8192, 1, -1, 2]); convert_element_type_1004 = None + view_as_complex_61 = torch.ops.aten.view_as_complex.default(view_2196); view_2196 = None + mul_242 = torch.ops.aten.mul.Tensor(view_as_complex_60, view_37); view_as_complex_60 = None + view_as_real_60 = torch.ops.aten.view_as_real.default(mul_242); mul_242 = None + view_2198 = torch.ops.aten.view.default(view_as_real_60, [2, 8192, 4, 128]); view_as_real_60 = None + mul_243 = torch.ops.aten.mul.Tensor(view_as_complex_61, view_37); view_as_complex_61 = None + view_as_real_61 = torch.ops.aten.view_as_real.default(mul_243); mul_243 = None + view_2199 = torch.ops.aten.view.default(view_as_real_61, [2, 8192, 1, 128]); view_as_real_61 = None + convert_element_type_1005 = torch.ops.prims.convert_element_type.default(view_2198, torch.bfloat16); view_2198 = None + 
convert_element_type_1006 = torch.ops.prims.convert_element_type.default(view_2199, torch.bfloat16); view_2199 = None + unsqueeze_60 = torch.ops.aten.unsqueeze.default(convert_element_type_1006, 3); convert_element_type_1006 = None + expand_60 = torch.ops.aten.expand.default(unsqueeze_60, [2, 8192, 1, 4, 128]); unsqueeze_60 = None + view_2200 = torch.ops.aten.view.default(expand_60, [2, 8192, 4, 128]); expand_60 = None + unsqueeze_61 = torch.ops.aten.unsqueeze.default(view_2194, 3); view_2194 = None + expand_61 = torch.ops.aten.expand.default(unsqueeze_61, [2, 8192, 1, 4, 128]); unsqueeze_61 = None + view_2201 = torch.ops.aten.view.default(expand_61, [2, 8192, 4, 128]); expand_61 = None + permute_333 = torch.ops.aten.permute.default(convert_element_type_1005, [0, 2, 1, 3]); convert_element_type_1005 = None + permute_334 = torch.ops.aten.permute.default(view_2200, [0, 2, 1, 3]); view_2200 = None + permute_335 = torch.ops.aten.permute.default(view_2201, [0, 2, 1, 3]); view_2201 = None + _scaled_dot_product_cudnn_attention_30 = torch.ops.aten._scaled_dot_product_cudnn_attention.default(permute_333, permute_334, permute_335, None, True, 0.0, True); permute_333 = permute_334 = permute_335 = None + getitem_1310 = _scaled_dot_product_cudnn_attention_30[0] + getitem_1311 = _scaled_dot_product_cudnn_attention_30[1] + getitem_1316 = _scaled_dot_product_cudnn_attention_30[6] + getitem_1317 = _scaled_dot_product_cudnn_attention_30[7]; _scaled_dot_product_cudnn_attention_30 = None + permute_336 = torch.ops.aten.permute.default(getitem_1310, [0, 2, 1, 3]) + view_2202 = torch.ops.aten.view.default(permute_336, [2, 8192, -1]); permute_336 = None + convert_element_type_1007 = torch.ops.prims.convert_element_type.default(primals_278, torch.bfloat16) + all_gather_into_tensor_336 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1007, 8, '0'); convert_element_type_1007 = None + wait_tensor_397 = 
torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_336); all_gather_into_tensor_336 = None + permute_337 = torch.ops.aten.permute.default(wait_tensor_397, [1, 0]); wait_tensor_397 = None + view_2208 = torch.ops.aten.view.default(view_2202, [16384, 512]); view_2202 = None + mm_213 = torch.ops.aten.mm.default(view_2208, permute_337); view_2208 = permute_337 = None + view_2209 = torch.ops.aten.view.default(mm_213, [2, 8192, 4096]); mm_213 = None + split_130 = torch.ops.aten.split.Tensor(view_2209, 1024, 1); view_2209 = None + getitem_1319 = split_130[0] + getitem_1320 = split_130[1] + getitem_1321 = split_130[2] + getitem_1322 = split_130[3] + getitem_1323 = split_130[4] + getitem_1324 = split_130[5] + getitem_1325 = split_130[6] + getitem_1326 = split_130[7]; split_130 = None + cat_122 = torch.ops.aten.cat.default([getitem_1319, getitem_1320, getitem_1321, getitem_1322, getitem_1323, getitem_1324, getitem_1325, getitem_1326]); getitem_1319 = getitem_1320 = getitem_1321 = getitem_1322 = getitem_1323 = getitem_1324 = getitem_1325 = getitem_1326 = None + reduce_scatter_tensor_61 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_122, 'sum', 8, '1'); cat_122 = None + wait_tensor_398 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_61) + add_121 = torch.ops.aten.add.Tensor(add_119, wait_tensor_398); wait_tensor_398 = None + convert_element_type_1010 = torch.ops.prims.convert_element_type.default(primals_279, torch.bfloat16) + all_gather_into_tensor_337 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1010, 8, '0'); convert_element_type_1010 = None + wait_tensor_399 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_337); all_gather_into_tensor_337 = None + convert_element_type_1011 = torch.ops.prims.convert_element_type.default(add_121, torch.float32) + pow_62 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_1011, 2) + mean_61 = torch.ops.aten.mean.dim(pow_62, 
[2], True); pow_62 = None + add_122 = torch.ops.aten.add.Scalar(mean_61, 1e-05); mean_61 = None + rsqrt_61 = torch.ops.aten.rsqrt.default(add_122); add_122 = None + mul_244 = torch.ops.aten.mul.Tensor(convert_element_type_1011, rsqrt_61); convert_element_type_1011 = rsqrt_61 = None + mul_245 = torch.ops.aten.mul.Tensor(mul_244, wait_tensor_399); mul_244 = wait_tensor_399 = None + convert_element_type_1012 = torch.ops.prims.convert_element_type.default(mul_245, torch.bfloat16); mul_245 = None + all_gather_into_tensor_338 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1012, 8, '1'); convert_element_type_1012 = None + wait_tensor_400 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_338); all_gather_into_tensor_338 = None + split_131 = torch.ops.aten.split.Tensor(wait_tensor_400, 2); wait_tensor_400 = None + getitem_1327 = split_131[0] + getitem_1328 = split_131[1] + getitem_1329 = split_131[2] + getitem_1330 = split_131[3] + getitem_1331 = split_131[4] + getitem_1332 = split_131[5] + getitem_1333 = split_131[6] + getitem_1334 = split_131[7]; split_131 = None + cat_123 = torch.ops.aten.cat.default([getitem_1327, getitem_1328, getitem_1329, getitem_1330, getitem_1331, getitem_1332, getitem_1333, getitem_1334], 1); getitem_1327 = getitem_1328 = getitem_1329 = getitem_1330 = getitem_1331 = getitem_1332 = getitem_1333 = getitem_1334 = None + convert_element_type_1013 = torch.ops.prims.convert_element_type.default(primals_280, torch.bfloat16) + all_gather_into_tensor_339 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1013, 8, '0'); convert_element_type_1013 = None + wait_tensor_401 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_339); all_gather_into_tensor_339 = None + permute_338 = torch.ops.aten.permute.default(wait_tensor_401, [1, 0]); wait_tensor_401 = None + view_2220 = torch.ops.aten.view.default(cat_123, [16384, 4096]); cat_123 = None + mm_214 = 
torch.ops.aten.mm.default(view_2220, permute_338); permute_338 = None + view_2221 = torch.ops.aten.view.default(mm_214, [2, 8192, 1792]) + convert_element_type_1016 = torch.ops.prims.convert_element_type.default(view_2221, torch.float32); view_2221 = None + sigmoid_30 = torch.ops.aten.sigmoid.default(convert_element_type_1016) + mul_246 = torch.ops.aten.mul.Tensor(convert_element_type_1016, sigmoid_30); convert_element_type_1016 = sigmoid_30 = None + convert_element_type_1017 = torch.ops.prims.convert_element_type.default(mul_246, torch.bfloat16); mul_246 = None + convert_element_type_1018 = torch.ops.prims.convert_element_type.default(primals_281, torch.bfloat16) + all_gather_into_tensor_340 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1018, 8, '0'); convert_element_type_1018 = None + wait_tensor_402 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_340); all_gather_into_tensor_340 = None + permute_339 = torch.ops.aten.permute.default(wait_tensor_402, [1, 0]); wait_tensor_402 = None + mm_215 = torch.ops.aten.mm.default(view_2220, permute_339); view_2220 = permute_339 = None + view_2228 = torch.ops.aten.view.default(mm_215, [2, 8192, 1792]); mm_215 = None + mul_247 = torch.ops.aten.mul.Tensor(convert_element_type_1017, view_2228); convert_element_type_1017 = view_2228 = None + convert_element_type_1021 = torch.ops.prims.convert_element_type.default(primals_282, torch.bfloat16) + all_gather_into_tensor_341 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1021, 8, '0'); convert_element_type_1021 = None + wait_tensor_403 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_341); all_gather_into_tensor_341 = None + permute_340 = torch.ops.aten.permute.default(wait_tensor_403, [1, 0]); wait_tensor_403 = None + view_2235 = torch.ops.aten.view.default(mul_247, [16384, 1792]); mul_247 = None + mm_216 = torch.ops.aten.mm.default(view_2235, permute_340); view_2235 = 
permute_340 = None + view_2236 = torch.ops.aten.view.default(mm_216, [2, 8192, 4096]); mm_216 = None + split_132 = torch.ops.aten.split.Tensor(view_2236, 1024, 1); view_2236 = None + getitem_1335 = split_132[0] + getitem_1336 = split_132[1] + getitem_1337 = split_132[2] + getitem_1338 = split_132[3] + getitem_1339 = split_132[4] + getitem_1340 = split_132[5] + getitem_1341 = split_132[6] + getitem_1342 = split_132[7]; split_132 = None + cat_124 = torch.ops.aten.cat.default([getitem_1335, getitem_1336, getitem_1337, getitem_1338, getitem_1339, getitem_1340, getitem_1341, getitem_1342]); getitem_1335 = getitem_1336 = getitem_1337 = getitem_1338 = getitem_1339 = getitem_1340 = getitem_1341 = getitem_1342 = None + reduce_scatter_tensor_62 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_124, 'sum', 8, '1'); cat_124 = None + wait_tensor_404 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_62); reduce_scatter_tensor_62 = None + add_123 = torch.ops.aten.add.Tensor(add_121, wait_tensor_404); add_121 = wait_tensor_404 = None + convert_element_type_1024 = torch.ops.prims.convert_element_type.default(primals_283, torch.bfloat16) + all_gather_into_tensor_342 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1024, 8, '0'); convert_element_type_1024 = None + wait_tensor_405 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_342); all_gather_into_tensor_342 = None + convert_element_type_1025 = torch.ops.prims.convert_element_type.default(add_123, torch.float32) + pow_63 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_1025, 2) + mean_62 = torch.ops.aten.mean.dim(pow_63, [2], True); pow_63 = None + add_124 = torch.ops.aten.add.Scalar(mean_62, 1e-05); mean_62 = None + rsqrt_62 = torch.ops.aten.rsqrt.default(add_124); add_124 = None + mul_248 = torch.ops.aten.mul.Tensor(convert_element_type_1025, rsqrt_62); convert_element_type_1025 = rsqrt_62 = None + mul_249 = 
torch.ops.aten.mul.Tensor(mul_248, wait_tensor_405); mul_248 = wait_tensor_405 = None + convert_element_type_1026 = torch.ops.prims.convert_element_type.default(mul_249, torch.bfloat16); mul_249 = None + all_gather_into_tensor_343 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1026, 8, '1'); convert_element_type_1026 = None + wait_tensor_406 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_343); all_gather_into_tensor_343 = None + split_133 = torch.ops.aten.split.Tensor(wait_tensor_406, 2); wait_tensor_406 = None + getitem_1343 = split_133[0] + getitem_1344 = split_133[1] + getitem_1345 = split_133[2] + getitem_1346 = split_133[3] + getitem_1347 = split_133[4] + getitem_1348 = split_133[5] + getitem_1349 = split_133[6] + getitem_1350 = split_133[7]; split_133 = None + cat_125 = torch.ops.aten.cat.default([getitem_1343, getitem_1344, getitem_1345, getitem_1346, getitem_1347, getitem_1348, getitem_1349, getitem_1350], 1); getitem_1343 = getitem_1344 = getitem_1345 = getitem_1346 = getitem_1347 = getitem_1348 = getitem_1349 = getitem_1350 = None + convert_element_type_1027 = torch.ops.prims.convert_element_type.default(primals_284, torch.bfloat16) + all_gather_into_tensor_344 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1027, 8, '0'); convert_element_type_1027 = None + wait_tensor_407 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_344); all_gather_into_tensor_344 = None + permute_341 = torch.ops.aten.permute.default(wait_tensor_407, [1, 0]); wait_tensor_407 = None + view_2247 = torch.ops.aten.view.default(cat_125, [16384, 4096]); cat_125 = None + mm_217 = torch.ops.aten.mm.default(view_2247, permute_341); permute_341 = None + view_2248 = torch.ops.aten.view.default(mm_217, [2, 8192, 512]) + convert_element_type_1030 = torch.ops.prims.convert_element_type.default(primals_285, torch.bfloat16) + all_gather_into_tensor_345 = 
torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1030, 8, '0'); convert_element_type_1030 = None + wait_tensor_408 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_345); all_gather_into_tensor_345 = None + permute_342 = torch.ops.aten.permute.default(wait_tensor_408, [1, 0]); wait_tensor_408 = None + mm_218 = torch.ops.aten.mm.default(view_2247, permute_342); permute_342 = None + view_2255 = torch.ops.aten.view.default(mm_218, [2, 8192, 128]); mm_218 = None + convert_element_type_1033 = torch.ops.prims.convert_element_type.default(primals_286, torch.bfloat16) + all_gather_into_tensor_346 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1033, 8, '0'); convert_element_type_1033 = None + wait_tensor_409 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_346); all_gather_into_tensor_346 = None + permute_343 = torch.ops.aten.permute.default(wait_tensor_409, [1, 0]); wait_tensor_409 = None + mm_219 = torch.ops.aten.mm.default(view_2247, permute_343); view_2247 = permute_343 = None + view_2262 = torch.ops.aten.view.default(mm_219, [2, 8192, 128]) + view_2264 = torch.ops.aten.view.default(view_2248, [2, 8192, -1, 128]); view_2248 = None + view_2265 = torch.ops.aten.view.default(view_2255, [2, 8192, -1, 128]); view_2255 = None + view_2266 = torch.ops.aten.view.default(view_2262, [2, 8192, -1, 128]); view_2262 = None + convert_element_type_1036 = torch.ops.prims.convert_element_type.default(view_2264, torch.float32); view_2264 = None + view_2267 = torch.ops.aten.view.default(convert_element_type_1036, [2, 8192, 4, -1, 2]); convert_element_type_1036 = None + view_as_complex_62 = torch.ops.aten.view_as_complex.default(view_2267); view_2267 = None + convert_element_type_1037 = torch.ops.prims.convert_element_type.default(view_2265, torch.float32); view_2265 = None + view_2268 = torch.ops.aten.view.default(convert_element_type_1037, [2, 8192, 1, -1, 2]); 
convert_element_type_1037 = None + view_as_complex_63 = torch.ops.aten.view_as_complex.default(view_2268); view_2268 = None + mul_250 = torch.ops.aten.mul.Tensor(view_as_complex_62, view_37); view_as_complex_62 = None + view_as_real_62 = torch.ops.aten.view_as_real.default(mul_250); mul_250 = None + view_2270 = torch.ops.aten.view.default(view_as_real_62, [2, 8192, 4, 128]); view_as_real_62 = None + mul_251 = torch.ops.aten.mul.Tensor(view_as_complex_63, view_37); view_as_complex_63 = view_37 = None + view_as_real_63 = torch.ops.aten.view_as_real.default(mul_251); mul_251 = None + view_2271 = torch.ops.aten.view.default(view_as_real_63, [2, 8192, 1, 128]); view_as_real_63 = None + convert_element_type_1038 = torch.ops.prims.convert_element_type.default(view_2270, torch.bfloat16); view_2270 = None + convert_element_type_1039 = torch.ops.prims.convert_element_type.default(view_2271, torch.bfloat16); view_2271 = None + unsqueeze_62 = torch.ops.aten.unsqueeze.default(convert_element_type_1039, 3); convert_element_type_1039 = None + expand_62 = torch.ops.aten.expand.default(unsqueeze_62, [2, 8192, 1, 4, 128]); unsqueeze_62 = None + view_2272 = torch.ops.aten.view.default(expand_62, [2, 8192, 4, 128]); expand_62 = None + unsqueeze_63 = torch.ops.aten.unsqueeze.default(view_2266, 3); view_2266 = None + expand_63 = torch.ops.aten.expand.default(unsqueeze_63, [2, 8192, 1, 4, 128]); unsqueeze_63 = None + view_2273 = torch.ops.aten.view.default(expand_63, [2, 8192, 4, 128]); expand_63 = None + permute_344 = torch.ops.aten.permute.default(convert_element_type_1038, [0, 2, 1, 3]); convert_element_type_1038 = None + permute_345 = torch.ops.aten.permute.default(view_2272, [0, 2, 1, 3]); view_2272 = None + permute_346 = torch.ops.aten.permute.default(view_2273, [0, 2, 1, 3]); view_2273 = None + _scaled_dot_product_cudnn_attention_31 = torch.ops.aten._scaled_dot_product_cudnn_attention.default(permute_344, permute_345, permute_346, None, True, 0.0, True); permute_344 = permute_345 
= permute_346 = None + getitem_1351 = _scaled_dot_product_cudnn_attention_31[0] + getitem_1352 = _scaled_dot_product_cudnn_attention_31[1] + getitem_1357 = _scaled_dot_product_cudnn_attention_31[6] + getitem_1358 = _scaled_dot_product_cudnn_attention_31[7]; _scaled_dot_product_cudnn_attention_31 = None + permute_347 = torch.ops.aten.permute.default(getitem_1351, [0, 2, 1, 3]) + view_2274 = torch.ops.aten.view.default(permute_347, [2, 8192, -1]); permute_347 = None + convert_element_type_1040 = torch.ops.prims.convert_element_type.default(primals_287, torch.bfloat16) + all_gather_into_tensor_347 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1040, 8, '0'); convert_element_type_1040 = None + wait_tensor_410 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_347); all_gather_into_tensor_347 = None + permute_348 = torch.ops.aten.permute.default(wait_tensor_410, [1, 0]); wait_tensor_410 = None + view_2280 = torch.ops.aten.view.default(view_2274, [16384, 512]); view_2274 = None + mm_220 = torch.ops.aten.mm.default(view_2280, permute_348); view_2280 = permute_348 = None + view_2281 = torch.ops.aten.view.default(mm_220, [2, 8192, 4096]); mm_220 = None + split_134 = torch.ops.aten.split.Tensor(view_2281, 1024, 1); view_2281 = None + getitem_1360 = split_134[0] + getitem_1361 = split_134[1] + getitem_1362 = split_134[2] + getitem_1363 = split_134[3] + getitem_1364 = split_134[4] + getitem_1365 = split_134[5] + getitem_1366 = split_134[6] + getitem_1367 = split_134[7]; split_134 = None + cat_126 = torch.ops.aten.cat.default([getitem_1360, getitem_1361, getitem_1362, getitem_1363, getitem_1364, getitem_1365, getitem_1366, getitem_1367]); getitem_1360 = getitem_1361 = getitem_1362 = getitem_1363 = getitem_1364 = getitem_1365 = getitem_1366 = getitem_1367 = None + reduce_scatter_tensor_63 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_126, 'sum', 8, '1'); cat_126 = None + wait_tensor_411 = 
torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_63) + add_125 = torch.ops.aten.add.Tensor(add_123, wait_tensor_411); wait_tensor_411 = None + convert_element_type_1043 = torch.ops.prims.convert_element_type.default(primals_288, torch.bfloat16) + all_gather_into_tensor_348 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1043, 8, '0'); convert_element_type_1043 = None + wait_tensor_412 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_348); all_gather_into_tensor_348 = None + convert_element_type_1044 = torch.ops.prims.convert_element_type.default(add_125, torch.float32) + pow_64 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_1044, 2) + mean_63 = torch.ops.aten.mean.dim(pow_64, [2], True); pow_64 = None + add_126 = torch.ops.aten.add.Scalar(mean_63, 1e-05); mean_63 = None + rsqrt_63 = torch.ops.aten.rsqrt.default(add_126); add_126 = None + mul_252 = torch.ops.aten.mul.Tensor(convert_element_type_1044, rsqrt_63); convert_element_type_1044 = rsqrt_63 = None + mul_253 = torch.ops.aten.mul.Tensor(mul_252, wait_tensor_412); mul_252 = wait_tensor_412 = None + convert_element_type_1045 = torch.ops.prims.convert_element_type.default(mul_253, torch.bfloat16); mul_253 = None + all_gather_into_tensor_349 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1045, 8, '1'); convert_element_type_1045 = None + wait_tensor_413 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_349); all_gather_into_tensor_349 = None + split_135 = torch.ops.aten.split.Tensor(wait_tensor_413, 2); wait_tensor_413 = None + getitem_1368 = split_135[0] + getitem_1369 = split_135[1] + getitem_1370 = split_135[2] + getitem_1371 = split_135[3] + getitem_1372 = split_135[4] + getitem_1373 = split_135[5] + getitem_1374 = split_135[6] + getitem_1375 = split_135[7]; split_135 = None + cat_127 = torch.ops.aten.cat.default([getitem_1368, getitem_1369, getitem_1370, getitem_1371, 
getitem_1372, getitem_1373, getitem_1374, getitem_1375], 1); getitem_1368 = getitem_1369 = getitem_1370 = getitem_1371 = getitem_1372 = getitem_1373 = getitem_1374 = getitem_1375 = None + convert_element_type_1046 = torch.ops.prims.convert_element_type.default(primals_289, torch.bfloat16) + all_gather_into_tensor_350 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1046, 8, '0'); convert_element_type_1046 = None + wait_tensor_414 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_350); all_gather_into_tensor_350 = None + permute_349 = torch.ops.aten.permute.default(wait_tensor_414, [1, 0]); wait_tensor_414 = None + view_2292 = torch.ops.aten.view.default(cat_127, [16384, 4096]); cat_127 = None + mm_221 = torch.ops.aten.mm.default(view_2292, permute_349); permute_349 = None + view_2293 = torch.ops.aten.view.default(mm_221, [2, 8192, 1792]) + convert_element_type_1049 = torch.ops.prims.convert_element_type.default(view_2293, torch.float32); view_2293 = None + sigmoid_31 = torch.ops.aten.sigmoid.default(convert_element_type_1049) + mul_254 = torch.ops.aten.mul.Tensor(convert_element_type_1049, sigmoid_31); convert_element_type_1049 = sigmoid_31 = None + convert_element_type_1050 = torch.ops.prims.convert_element_type.default(mul_254, torch.bfloat16); mul_254 = None + convert_element_type_1051 = torch.ops.prims.convert_element_type.default(primals_290, torch.bfloat16) + all_gather_into_tensor_351 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1051, 8, '0'); convert_element_type_1051 = None + wait_tensor_415 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_351); all_gather_into_tensor_351 = None + permute_350 = torch.ops.aten.permute.default(wait_tensor_415, [1, 0]); wait_tensor_415 = None + mm_222 = torch.ops.aten.mm.default(view_2292, permute_350); view_2292 = permute_350 = None + view_2300 = torch.ops.aten.view.default(mm_222, [2, 8192, 1792]); mm_222 = None 
+ mul_255 = torch.ops.aten.mul.Tensor(convert_element_type_1050, view_2300); convert_element_type_1050 = view_2300 = None + convert_element_type_1054 = torch.ops.prims.convert_element_type.default(primals_291, torch.bfloat16) + all_gather_into_tensor_352 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1054, 8, '0'); convert_element_type_1054 = None + wait_tensor_416 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_352); all_gather_into_tensor_352 = None + permute_351 = torch.ops.aten.permute.default(wait_tensor_416, [1, 0]); wait_tensor_416 = None + view_2307 = torch.ops.aten.view.default(mul_255, [16384, 1792]); mul_255 = None + mm_223 = torch.ops.aten.mm.default(view_2307, permute_351); view_2307 = permute_351 = None + view_2308 = torch.ops.aten.view.default(mm_223, [2, 8192, 4096]); mm_223 = None + split_136 = torch.ops.aten.split.Tensor(view_2308, 1024, 1); view_2308 = None + getitem_1376 = split_136[0] + getitem_1377 = split_136[1] + getitem_1378 = split_136[2] + getitem_1379 = split_136[3] + getitem_1380 = split_136[4] + getitem_1381 = split_136[5] + getitem_1382 = split_136[6] + getitem_1383 = split_136[7]; split_136 = None + cat_128 = torch.ops.aten.cat.default([getitem_1376, getitem_1377, getitem_1378, getitem_1379, getitem_1380, getitem_1381, getitem_1382, getitem_1383]); getitem_1376 = getitem_1377 = getitem_1378 = getitem_1379 = getitem_1380 = getitem_1381 = getitem_1382 = getitem_1383 = None + reduce_scatter_tensor_64 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_128, 'sum', 8, '1'); cat_128 = None + wait_tensor_417 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_64) + add_127 = torch.ops.aten.add.Tensor(add_125, wait_tensor_417); add_125 = wait_tensor_417 = None + convert_element_type_1057 = torch.ops.prims.convert_element_type.default(primals_292, torch.bfloat16) + all_gather_into_tensor_353 = 
torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1057, 8, '0'); convert_element_type_1057 = None + wait_tensor_418 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_353); all_gather_into_tensor_353 = None + convert_element_type_1058 = torch.ops.prims.convert_element_type.default(add_127, torch.float32); add_127 = None + pow_65 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_1058, 2) + mean_64 = torch.ops.aten.mean.dim(pow_65, [2], True); pow_65 = None + add_128 = torch.ops.aten.add.Scalar(mean_64, 1e-05); mean_64 = None + rsqrt_64 = torch.ops.aten.rsqrt.default(add_128); add_128 = None + mul_256 = torch.ops.aten.mul.Tensor(convert_element_type_1058, rsqrt_64); convert_element_type_1058 = None + mul_257 = torch.ops.aten.mul.Tensor(mul_256, wait_tensor_418); mul_256 = wait_tensor_418 = None + convert_element_type_1059 = torch.ops.prims.convert_element_type.default(mul_257, torch.bfloat16); mul_257 = None + all_gather_into_tensor_354 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1059, 8, '1'); convert_element_type_1059 = None + wait_tensor_419 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_354); all_gather_into_tensor_354 = None + split_137 = torch.ops.aten.split.Tensor(wait_tensor_419, 2); wait_tensor_419 = None + getitem_1384 = split_137[0] + getitem_1385 = split_137[1] + getitem_1386 = split_137[2] + getitem_1387 = split_137[3] + getitem_1388 = split_137[4] + getitem_1389 = split_137[5] + getitem_1390 = split_137[6] + getitem_1391 = split_137[7]; split_137 = None + cat_129 = torch.ops.aten.cat.default([getitem_1384, getitem_1385, getitem_1386, getitem_1387, getitem_1388, getitem_1389, getitem_1390, getitem_1391], 1); getitem_1384 = getitem_1385 = getitem_1386 = getitem_1387 = getitem_1388 = getitem_1389 = getitem_1390 = getitem_1391 = None + convert_element_type_1060 = torch.ops.prims.convert_element_type.default(primals_293, torch.bfloat16) + 
all_gather_into_tensor_355 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1060, 8, '0'); convert_element_type_1060 = None + wait_tensor_420 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_355); all_gather_into_tensor_355 = None + permute_352 = torch.ops.aten.permute.default(wait_tensor_420, [1, 0]); wait_tensor_420 = None + view_2319 = torch.ops.aten.view.default(cat_129, [16384, 4096]); cat_129 = None + mm_224 = torch.ops.aten.mm.default(view_2319, permute_352); permute_352 = None + view_2320 = torch.ops.aten.view.default(mm_224, [2, 8192, 16032]); mm_224 = None + return (view_2320, primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8, primals_9, primals_10, primals_11, primals_12, primals_13, primals_14, primals_15, primals_16, primals_17, primals_18, primals_19, primals_20, primals_21, primals_22, primals_23, primals_24, primals_25, primals_26, primals_27, primals_28, primals_29, primals_30, primals_31, primals_32, primals_33, primals_34, primals_35, primals_36, primals_37, primals_38, primals_39, primals_40, primals_41, primals_42, primals_43, primals_44, primals_45, primals_46, primals_47, primals_48, primals_49, primals_50, primals_51, primals_52, primals_53, primals_54, primals_55, primals_56, primals_57, primals_58, primals_59, primals_60, primals_61, primals_62, primals_63, primals_64, primals_65, primals_66, primals_67, primals_68, primals_69, primals_70, primals_71, primals_72, primals_73, primals_74, primals_75, primals_76, primals_77, primals_78, primals_79, primals_80, primals_81, primals_82, primals_83, primals_84, primals_85, primals_86, primals_87, primals_88, primals_89, primals_90, primals_91, primals_92, primals_93, primals_94, primals_95, primals_96, primals_97, primals_98, primals_99, primals_100, primals_101, primals_102, primals_103, primals_104, primals_105, primals_106, primals_107, primals_108, primals_109, primals_110, primals_111, primals_112, 
primals_113, primals_114, primals_115, primals_116, primals_117, primals_118, primals_119, primals_120, primals_121, primals_122, primals_123, primals_124, primals_125, primals_126, primals_127, primals_128, primals_129, primals_130, primals_131, primals_132, primals_133, primals_134, primals_135, primals_136, primals_137, primals_138, primals_139, primals_140, primals_141, primals_142, primals_143, primals_144, primals_145, primals_146, primals_147, primals_148, primals_149, primals_150, primals_151, primals_152, primals_153, primals_154, primals_155, primals_156, primals_157, primals_158, primals_159, primals_160, primals_161, primals_162, primals_163, primals_164, primals_165, primals_166, primals_167, primals_168, primals_169, primals_170, primals_171, primals_172, primals_173, primals_174, primals_175, primals_176, primals_177, primals_178, primals_179, primals_180, primals_181, primals_182, primals_183, primals_184, primals_185, primals_186, primals_187, primals_188, primals_189, primals_190, primals_191, primals_192, primals_193, primals_194, primals_195, primals_196, primals_197, primals_198, primals_199, primals_200, primals_201, primals_202, primals_203, primals_204, primals_205, primals_206, primals_207, primals_208, primals_209, primals_210, primals_211, primals_212, primals_213, primals_214, primals_215, primals_216, primals_217, primals_218, primals_219, primals_220, primals_221, primals_222, primals_223, primals_224, primals_225, primals_226, primals_227, primals_228, primals_229, primals_230, primals_231, primals_232, primals_233, primals_234, primals_235, primals_236, primals_237, primals_238, primals_239, primals_240, primals_241, primals_242, primals_243, primals_244, primals_245, primals_246, primals_247, primals_248, primals_249, primals_250, primals_251, primals_252, primals_253, primals_254, primals_255, primals_256, primals_257, primals_258, primals_259, primals_260, primals_261, primals_262, primals_263, primals_264, primals_265, 
primals_266, primals_267, primals_268, primals_269, primals_270, primals_271, primals_272, primals_273, primals_274, primals_275, primals_276, primals_277, primals_278, primals_279, primals_280, primals_281, primals_282, primals_283, primals_284, primals_285, primals_286, primals_287, primals_288, primals_289, primals_290, primals_291, primals_292, primals_293, wait_tensor_1, mm, mm_2, getitem_80, getitem_81, getitem_86, getitem_87, reduce_scatter_tensor_1, mm_4, add_3, mm_7, mm_9, getitem_121, getitem_122, getitem_127, getitem_128, reduce_scatter_tensor_3, mm_11, add_7, mm_14, mm_16, getitem_162, getitem_163, getitem_168, getitem_169, reduce_scatter_tensor_5, mm_18, add_11, mm_21, mm_23, getitem_203, getitem_204, getitem_209, getitem_210, reduce_scatter_tensor_7, mm_25, add_15, mm_28, mm_30, getitem_244, getitem_245, getitem_250, getitem_251, reduce_scatter_tensor_9, mm_32, add_19, mm_35, mm_37, getitem_285, getitem_286, getitem_291, getitem_292, reduce_scatter_tensor_11, mm_39, add_23, mm_42, mm_44, getitem_326, getitem_327, getitem_332, getitem_333, reduce_scatter_tensor_13, mm_46, add_27, mm_49, mm_51, getitem_367, getitem_368, getitem_373, getitem_374, reduce_scatter_tensor_15, mm_53, add_31, mm_56, mm_58, getitem_408, getitem_409, getitem_414, getitem_415, reduce_scatter_tensor_17, mm_60, add_35, mm_63, mm_65, getitem_449, getitem_450, getitem_455, getitem_456, reduce_scatter_tensor_19, mm_67, add_39, mm_70, mm_72, getitem_490, getitem_491, getitem_496, getitem_497, reduce_scatter_tensor_21, mm_74, add_43, mm_77, mm_79, getitem_531, getitem_532, getitem_537, getitem_538, reduce_scatter_tensor_23, mm_81, add_47, mm_84, mm_86, getitem_572, getitem_573, getitem_578, getitem_579, reduce_scatter_tensor_25, mm_88, add_51, mm_91, mm_93, getitem_613, getitem_614, getitem_619, getitem_620, reduce_scatter_tensor_27, mm_95, add_55, mm_98, mm_100, getitem_654, getitem_655, getitem_660, getitem_661, reduce_scatter_tensor_29, mm_102, add_59, mm_105, mm_107, getitem_695, 
getitem_696, getitem_701, getitem_702, reduce_scatter_tensor_31, mm_109, add_63, mm_112, mm_114, getitem_736, getitem_737, getitem_742, getitem_743, reduce_scatter_tensor_33, mm_116, add_67, mm_119, mm_121, getitem_777, getitem_778, getitem_783, getitem_784, reduce_scatter_tensor_35, mm_123, add_71, mm_126, mm_128, getitem_818, getitem_819, getitem_824, getitem_825, reduce_scatter_tensor_37, mm_130, add_75, mm_133, mm_135, getitem_859, getitem_860, getitem_865, getitem_866, reduce_scatter_tensor_39, mm_137, add_79, mm_140, mm_142, getitem_900, getitem_901, getitem_906, getitem_907, reduce_scatter_tensor_41, mm_144, add_83, mm_147, mm_149, getitem_941, getitem_942, getitem_947, getitem_948, reduce_scatter_tensor_43, mm_151, add_87, mm_154, mm_156, getitem_982, getitem_983, getitem_988, getitem_989, reduce_scatter_tensor_45, mm_158, add_91, mm_161, mm_163, getitem_1023, getitem_1024, getitem_1029, getitem_1030, reduce_scatter_tensor_47, mm_165, add_95, mm_168, mm_170, getitem_1064, getitem_1065, getitem_1070, getitem_1071, reduce_scatter_tensor_49, mm_172, add_99, mm_175, mm_177, getitem_1105, getitem_1106, getitem_1111, getitem_1112, reduce_scatter_tensor_51, mm_179, add_103, mm_182, mm_184, getitem_1146, getitem_1147, getitem_1152, getitem_1153, reduce_scatter_tensor_53, mm_186, add_107, mm_189, mm_191, getitem_1187, getitem_1188, getitem_1193, getitem_1194, reduce_scatter_tensor_55, mm_193, add_111, mm_196, mm_198, getitem_1228, getitem_1229, getitem_1234, getitem_1235, reduce_scatter_tensor_57, mm_200, add_115, mm_203, mm_205, getitem_1269, getitem_1270, getitem_1275, getitem_1276, reduce_scatter_tensor_59, mm_207, add_119, mm_210, mm_212, getitem_1310, getitem_1311, getitem_1316, getitem_1317, reduce_scatter_tensor_61, mm_214, add_123, mm_217, mm_219, getitem_1351, getitem_1352, getitem_1357, getitem_1358, reduce_scatter_tensor_63, mm_221, reduce_scatter_tensor_64, rsqrt_64, view_2319) + +def load_args(reader): + buf0 = reader.storage(None, 131072, 
device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf0, (2, 8192), dtype=torch.int64, is_leaf=True) # primals_1 + buf1 = reader.storage(None, 32833536, device=device(type='cuda', index=0)) + reader.tensor(buf1, (2004, 4096), is_leaf=True) # primals_2 + buf2 = reader.storage(None, 4194304, device=device(type='cuda', index=0), dtype_hint=torch.complex64) + reader.tensor(buf2, (8192, 64), dtype=torch.complex64, is_leaf=True) # primals_3 + buf3 = reader.storage(None, 2048, device=device(type='cuda', index=0)) + reader.tensor(buf3, (512,), is_leaf=True) # primals_4 + buf4 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf4, (64, 4096), is_leaf=True) # primals_5 + buf5 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf5, (16, 4096), is_leaf=True) # primals_6 + buf6 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf6, (16, 4096), is_leaf=True) # primals_7 + buf7 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf7, (512, 512), is_leaf=True) # primals_8 + buf8 = reader.storage(None, 2048, device=device(type='cuda', index=0)) + reader.tensor(buf8, (512,), is_leaf=True) # primals_9 + buf9 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf9, (224, 4096), is_leaf=True) # primals_10 + buf10 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf10, (224, 4096), is_leaf=True) # primals_11 + buf11 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf11, (512, 1792), is_leaf=True) # primals_12 + buf12 = reader.storage(None, 2048, device=device(type='cuda', index=0)) + reader.tensor(buf12, (512,), is_leaf=True) # primals_13 + buf13 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf13, (64, 4096), is_leaf=True) # primals_14 + buf14 = reader.storage(None, 262144, 
device=device(type='cuda', index=0)) + reader.tensor(buf14, (16, 4096), is_leaf=True) # primals_15 + buf15 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf15, (16, 4096), is_leaf=True) # primals_16 + buf16 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf16, (512, 512), is_leaf=True) # primals_17 + buf17 = reader.storage(None, 2048, device=device(type='cuda', index=0)) + reader.tensor(buf17, (512,), is_leaf=True) # primals_18 + buf18 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf18, (224, 4096), is_leaf=True) # primals_19 + buf19 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf19, (224, 4096), is_leaf=True) # primals_20 + buf20 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf20, (512, 1792), is_leaf=True) # primals_21 + buf21 = reader.storage(None, 2048, device=device(type='cuda', index=0)) + reader.tensor(buf21, (512,), is_leaf=True) # primals_22 + buf22 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf22, (64, 4096), is_leaf=True) # primals_23 + buf23 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf23, (16, 4096), is_leaf=True) # primals_24 + buf24 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf24, (16, 4096), is_leaf=True) # primals_25 + buf25 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf25, (512, 512), is_leaf=True) # primals_26 + buf26 = reader.storage(None, 2048, device=device(type='cuda', index=0)) + reader.tensor(buf26, (512,), is_leaf=True) # primals_27 + buf27 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf27, (224, 4096), is_leaf=True) # primals_28 + buf28 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf28, (224, 4096), 
is_leaf=True) # primals_29 + buf29 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf29, (512, 1792), is_leaf=True) # primals_30 + buf30 = reader.storage(None, 2048, device=device(type='cuda', index=0)) + reader.tensor(buf30, (512,), is_leaf=True) # primals_31 + buf31 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf31, (64, 4096), is_leaf=True) # primals_32 + buf32 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf32, (16, 4096), is_leaf=True) # primals_33 + buf33 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf33, (16, 4096), is_leaf=True) # primals_34 + buf34 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf34, (512, 512), is_leaf=True) # primals_35 + buf35 = reader.storage(None, 2048, device=device(type='cuda', index=0)) + reader.tensor(buf35, (512,), is_leaf=True) # primals_36 + buf36 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf36, (224, 4096), is_leaf=True) # primals_37 + buf37 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf37, (224, 4096), is_leaf=True) # primals_38 + buf38 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf38, (512, 1792), is_leaf=True) # primals_39 + buf39 = reader.storage(None, 2048, device=device(type='cuda', index=0)) + reader.tensor(buf39, (512,), is_leaf=True) # primals_40 + buf40 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf40, (64, 4096), is_leaf=True) # primals_41 + buf41 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf41, (16, 4096), is_leaf=True) # primals_42 + buf42 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf42, (16, 4096), is_leaf=True) # primals_43 + buf43 = reader.storage(None, 1048576, 
device=device(type='cuda', index=0)) + reader.tensor(buf43, (512, 512), is_leaf=True) # primals_44 + buf44 = reader.storage(None, 2048, device=device(type='cuda', index=0)) + reader.tensor(buf44, (512,), is_leaf=True) # primals_45 + buf45 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf45, (224, 4096), is_leaf=True) # primals_46 + buf46 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf46, (224, 4096), is_leaf=True) # primals_47 + buf47 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf47, (512, 1792), is_leaf=True) # primals_48 + buf48 = reader.storage(None, 2048, device=device(type='cuda', index=0)) + reader.tensor(buf48, (512,), is_leaf=True) # primals_49 + buf49 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf49, (64, 4096), is_leaf=True) # primals_50 + buf50 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf50, (16, 4096), is_leaf=True) # primals_51 + buf51 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf51, (16, 4096), is_leaf=True) # primals_52 + buf52 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf52, (512, 512), is_leaf=True) # primals_53 + buf53 = reader.storage(None, 2048, device=device(type='cuda', index=0)) + reader.tensor(buf53, (512,), is_leaf=True) # primals_54 + buf54 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf54, (224, 4096), is_leaf=True) # primals_55 + buf55 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf55, (224, 4096), is_leaf=True) # primals_56 + buf56 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf56, (512, 1792), is_leaf=True) # primals_57 + buf57 = reader.storage(None, 2048, device=device(type='cuda', index=0)) + reader.tensor(buf57, (512,), 
is_leaf=True) # primals_58 + buf58 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf58, (64, 4096), is_leaf=True) # primals_59 + buf59 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf59, (16, 4096), is_leaf=True) # primals_60 + buf60 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf60, (16, 4096), is_leaf=True) # primals_61 + buf61 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf61, (512, 512), is_leaf=True) # primals_62 + buf62 = reader.storage(None, 2048, device=device(type='cuda', index=0)) + reader.tensor(buf62, (512,), is_leaf=True) # primals_63 + buf63 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf63, (224, 4096), is_leaf=True) # primals_64 + buf64 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf64, (224, 4096), is_leaf=True) # primals_65 + buf65 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf65, (512, 1792), is_leaf=True) # primals_66 + buf66 = reader.storage(None, 2048, device=device(type='cuda', index=0)) + reader.tensor(buf66, (512,), is_leaf=True) # primals_67 + buf67 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf67, (64, 4096), is_leaf=True) # primals_68 + buf68 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf68, (16, 4096), is_leaf=True) # primals_69 + buf69 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf69, (16, 4096), is_leaf=True) # primals_70 + buf70 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf70, (512, 512), is_leaf=True) # primals_71 + buf71 = reader.storage(None, 2048, device=device(type='cuda', index=0)) + reader.tensor(buf71, (512,), is_leaf=True) # primals_72 + buf72 = reader.storage(None, 3670016, 
device=device(type='cuda', index=0)) + reader.tensor(buf72, (224, 4096), is_leaf=True) # primals_73 + buf73 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf73, (224, 4096), is_leaf=True) # primals_74 + buf74 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf74, (512, 1792), is_leaf=True) # primals_75 + buf75 = reader.storage(None, 2048, device=device(type='cuda', index=0)) + reader.tensor(buf75, (512,), is_leaf=True) # primals_76 + buf76 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf76, (64, 4096), is_leaf=True) # primals_77 + buf77 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf77, (16, 4096), is_leaf=True) # primals_78 + buf78 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf78, (16, 4096), is_leaf=True) # primals_79 + buf79 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf79, (512, 512), is_leaf=True) # primals_80 + buf80 = reader.storage(None, 2048, device=device(type='cuda', index=0)) + reader.tensor(buf80, (512,), is_leaf=True) # primals_81 + buf81 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf81, (224, 4096), is_leaf=True) # primals_82 + buf82 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf82, (224, 4096), is_leaf=True) # primals_83 + buf83 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf83, (512, 1792), is_leaf=True) # primals_84 + buf84 = reader.storage(None, 2048, device=device(type='cuda', index=0)) + reader.tensor(buf84, (512,), is_leaf=True) # primals_85 + buf85 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf85, (64, 4096), is_leaf=True) # primals_86 + buf86 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf86, (16, 4096), 
is_leaf=True) # primals_87 + buf87 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf87, (16, 4096), is_leaf=True) # primals_88 + buf88 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf88, (512, 512), is_leaf=True) # primals_89 + buf89 = reader.storage(None, 2048, device=device(type='cuda', index=0)) + reader.tensor(buf89, (512,), is_leaf=True) # primals_90 + buf90 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf90, (224, 4096), is_leaf=True) # primals_91 + buf91 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf91, (224, 4096), is_leaf=True) # primals_92 + buf92 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf92, (512, 1792), is_leaf=True) # primals_93 + buf93 = reader.storage(None, 2048, device=device(type='cuda', index=0)) + reader.tensor(buf93, (512,), is_leaf=True) # primals_94 + buf94 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf94, (64, 4096), is_leaf=True) # primals_95 + buf95 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf95, (16, 4096), is_leaf=True) # primals_96 + buf96 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf96, (16, 4096), is_leaf=True) # primals_97 + buf97 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf97, (512, 512), is_leaf=True) # primals_98 + buf98 = reader.storage(None, 2048, device=device(type='cuda', index=0)) + reader.tensor(buf98, (512,), is_leaf=True) # primals_99 + buf99 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf99, (224, 4096), is_leaf=True) # primals_100 + buf100 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf100, (224, 4096), is_leaf=True) # primals_101 + buf101 = reader.storage(None, 3670016, 
device=device(type='cuda', index=0)) + reader.tensor(buf101, (512, 1792), is_leaf=True) # primals_102 + buf102 = reader.storage(None, 2048, device=device(type='cuda', index=0)) + reader.tensor(buf102, (512,), is_leaf=True) # primals_103 + buf103 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf103, (64, 4096), is_leaf=True) # primals_104 + buf104 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf104, (16, 4096), is_leaf=True) # primals_105 + buf105 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf105, (16, 4096), is_leaf=True) # primals_106 + buf106 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf106, (512, 512), is_leaf=True) # primals_107 + buf107 = reader.storage(None, 2048, device=device(type='cuda', index=0)) + reader.tensor(buf107, (512,), is_leaf=True) # primals_108 + buf108 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf108, (224, 4096), is_leaf=True) # primals_109 + buf109 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf109, (224, 4096), is_leaf=True) # primals_110 + buf110 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf110, (512, 1792), is_leaf=True) # primals_111 + buf111 = reader.storage(None, 2048, device=device(type='cuda', index=0)) + reader.tensor(buf111, (512,), is_leaf=True) # primals_112 + buf112 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf112, (64, 4096), is_leaf=True) # primals_113 + buf113 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf113, (16, 4096), is_leaf=True) # primals_114 + buf114 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf114, (16, 4096), is_leaf=True) # primals_115 + buf115 = reader.storage(None, 1048576, device=device(type='cuda', 
index=0)) + reader.tensor(buf115, (512, 512), is_leaf=True) # primals_116 + buf116 = reader.storage(None, 2048, device=device(type='cuda', index=0)) + reader.tensor(buf116, (512,), is_leaf=True) # primals_117 + buf117 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf117, (224, 4096), is_leaf=True) # primals_118 + buf118 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf118, (224, 4096), is_leaf=True) # primals_119 + buf119 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf119, (512, 1792), is_leaf=True) # primals_120 + buf120 = reader.storage(None, 2048, device=device(type='cuda', index=0)) + reader.tensor(buf120, (512,), is_leaf=True) # primals_121 + buf121 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf121, (64, 4096), is_leaf=True) # primals_122 + buf122 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf122, (16, 4096), is_leaf=True) # primals_123 + buf123 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf123, (16, 4096), is_leaf=True) # primals_124 + buf124 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf124, (512, 512), is_leaf=True) # primals_125 + buf125 = reader.storage(None, 2048, device=device(type='cuda', index=0)) + reader.tensor(buf125, (512,), is_leaf=True) # primals_126 + buf126 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf126, (224, 4096), is_leaf=True) # primals_127 + buf127 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf127, (224, 4096), is_leaf=True) # primals_128 + buf128 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf128, (512, 1792), is_leaf=True) # primals_129 + buf129 = reader.storage(None, 2048, device=device(type='cuda', index=0)) + 
reader.tensor(buf129, (512,), is_leaf=True) # primals_130 + buf130 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf130, (64, 4096), is_leaf=True) # primals_131 + buf131 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf131, (16, 4096), is_leaf=True) # primals_132 + buf132 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf132, (16, 4096), is_leaf=True) # primals_133 + buf133 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf133, (512, 512), is_leaf=True) # primals_134 + buf134 = reader.storage(None, 2048, device=device(type='cuda', index=0)) + reader.tensor(buf134, (512,), is_leaf=True) # primals_135 + buf135 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf135, (224, 4096), is_leaf=True) # primals_136 + buf136 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf136, (224, 4096), is_leaf=True) # primals_137 + buf137 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf137, (512, 1792), is_leaf=True) # primals_138 + buf138 = reader.storage(None, 2048, device=device(type='cuda', index=0)) + reader.tensor(buf138, (512,), is_leaf=True) # primals_139 + buf139 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf139, (64, 4096), is_leaf=True) # primals_140 + buf140 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf140, (16, 4096), is_leaf=True) # primals_141 + buf141 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf141, (16, 4096), is_leaf=True) # primals_142 + buf142 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf142, (512, 512), is_leaf=True) # primals_143 + buf143 = reader.storage(None, 2048, device=device(type='cuda', index=0)) + reader.tensor(buf143, (512,), 
is_leaf=True) # primals_144 + buf144 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf144, (224, 4096), is_leaf=True) # primals_145 + buf145 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf145, (224, 4096), is_leaf=True) # primals_146 + buf146 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf146, (512, 1792), is_leaf=True) # primals_147 + buf147 = reader.storage(None, 2048, device=device(type='cuda', index=0)) + reader.tensor(buf147, (512,), is_leaf=True) # primals_148 + buf148 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf148, (64, 4096), is_leaf=True) # primals_149 + buf149 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf149, (16, 4096), is_leaf=True) # primals_150 + buf150 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf150, (16, 4096), is_leaf=True) # primals_151 + buf151 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf151, (512, 512), is_leaf=True) # primals_152 + buf152 = reader.storage(None, 2048, device=device(type='cuda', index=0)) + reader.tensor(buf152, (512,), is_leaf=True) # primals_153 + buf153 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf153, (224, 4096), is_leaf=True) # primals_154 + buf154 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf154, (224, 4096), is_leaf=True) # primals_155 + buf155 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf155, (512, 1792), is_leaf=True) # primals_156 + buf156 = reader.storage(None, 2048, device=device(type='cuda', index=0)) + reader.tensor(buf156, (512,), is_leaf=True) # primals_157 + buf157 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf157, (64, 4096), is_leaf=True) # primals_158 
+ buf158 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf158, (16, 4096), is_leaf=True) # primals_159 + buf159 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf159, (16, 4096), is_leaf=True) # primals_160 + buf160 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf160, (512, 512), is_leaf=True) # primals_161 + buf161 = reader.storage(None, 2048, device=device(type='cuda', index=0)) + reader.tensor(buf161, (512,), is_leaf=True) # primals_162 + buf162 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf162, (224, 4096), is_leaf=True) # primals_163 + buf163 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf163, (224, 4096), is_leaf=True) # primals_164 + buf164 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf164, (512, 1792), is_leaf=True) # primals_165 + buf165 = reader.storage(None, 2048, device=device(type='cuda', index=0)) + reader.tensor(buf165, (512,), is_leaf=True) # primals_166 + buf166 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf166, (64, 4096), is_leaf=True) # primals_167 + buf167 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf167, (16, 4096), is_leaf=True) # primals_168 + buf168 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf168, (16, 4096), is_leaf=True) # primals_169 + buf169 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf169, (512, 512), is_leaf=True) # primals_170 + buf170 = reader.storage(None, 2048, device=device(type='cuda', index=0)) + reader.tensor(buf170, (512,), is_leaf=True) # primals_171 + buf171 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf171, (224, 4096), is_leaf=True) # primals_172 + buf172 = reader.storage(None, 
3670016, device=device(type='cuda', index=0)) + reader.tensor(buf172, (224, 4096), is_leaf=True) # primals_173 + buf173 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf173, (512, 1792), is_leaf=True) # primals_174 + buf174 = reader.storage(None, 2048, device=device(type='cuda', index=0)) + reader.tensor(buf174, (512,), is_leaf=True) # primals_175 + buf175 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf175, (64, 4096), is_leaf=True) # primals_176 + buf176 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf176, (16, 4096), is_leaf=True) # primals_177 + buf177 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf177, (16, 4096), is_leaf=True) # primals_178 + buf178 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf178, (512, 512), is_leaf=True) # primals_179 + buf179 = reader.storage(None, 2048, device=device(type='cuda', index=0)) + reader.tensor(buf179, (512,), is_leaf=True) # primals_180 + buf180 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf180, (224, 4096), is_leaf=True) # primals_181 + buf181 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf181, (224, 4096), is_leaf=True) # primals_182 + buf182 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf182, (512, 1792), is_leaf=True) # primals_183 + buf183 = reader.storage(None, 2048, device=device(type='cuda', index=0)) + reader.tensor(buf183, (512,), is_leaf=True) # primals_184 + buf184 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf184, (64, 4096), is_leaf=True) # primals_185 + buf185 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf185, (16, 4096), is_leaf=True) # primals_186 + buf186 = reader.storage(None, 262144, 
device=device(type='cuda', index=0)) + reader.tensor(buf186, (16, 4096), is_leaf=True) # primals_187 + buf187 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf187, (512, 512), is_leaf=True) # primals_188 + buf188 = reader.storage(None, 2048, device=device(type='cuda', index=0)) + reader.tensor(buf188, (512,), is_leaf=True) # primals_189 + buf189 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf189, (224, 4096), is_leaf=True) # primals_190 + buf190 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf190, (224, 4096), is_leaf=True) # primals_191 + buf191 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf191, (512, 1792), is_leaf=True) # primals_192 + buf192 = reader.storage(None, 2048, device=device(type='cuda', index=0)) + reader.tensor(buf192, (512,), is_leaf=True) # primals_193 + buf193 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf193, (64, 4096), is_leaf=True) # primals_194 + buf194 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf194, (16, 4096), is_leaf=True) # primals_195 + buf195 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf195, (16, 4096), is_leaf=True) # primals_196 + buf196 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf196, (512, 512), is_leaf=True) # primals_197 + buf197 = reader.storage(None, 2048, device=device(type='cuda', index=0)) + reader.tensor(buf197, (512,), is_leaf=True) # primals_198 + buf198 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf198, (224, 4096), is_leaf=True) # primals_199 + buf199 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf199, (224, 4096), is_leaf=True) # primals_200 + buf200 = reader.storage(None, 3670016, device=device(type='cuda', 
index=0)) + reader.tensor(buf200, (512, 1792), is_leaf=True) # primals_201 + buf201 = reader.storage(None, 2048, device=device(type='cuda', index=0)) + reader.tensor(buf201, (512,), is_leaf=True) # primals_202 + buf202 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf202, (64, 4096), is_leaf=True) # primals_203 + buf203 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf203, (16, 4096), is_leaf=True) # primals_204 + buf204 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf204, (16, 4096), is_leaf=True) # primals_205 + buf205 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf205, (512, 512), is_leaf=True) # primals_206 + buf206 = reader.storage(None, 2048, device=device(type='cuda', index=0)) + reader.tensor(buf206, (512,), is_leaf=True) # primals_207 + buf207 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf207, (224, 4096), is_leaf=True) # primals_208 + buf208 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf208, (224, 4096), is_leaf=True) # primals_209 + buf209 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf209, (512, 1792), is_leaf=True) # primals_210 + buf210 = reader.storage(None, 2048, device=device(type='cuda', index=0)) + reader.tensor(buf210, (512,), is_leaf=True) # primals_211 + buf211 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf211, (64, 4096), is_leaf=True) # primals_212 + buf212 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf212, (16, 4096), is_leaf=True) # primals_213 + buf213 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf213, (16, 4096), is_leaf=True) # primals_214 + buf214 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + 
reader.tensor(buf214, (512, 512), is_leaf=True) # primals_215 + buf215 = reader.storage(None, 2048, device=device(type='cuda', index=0)) + reader.tensor(buf215, (512,), is_leaf=True) # primals_216 + buf216 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf216, (224, 4096), is_leaf=True) # primals_217 + buf217 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf217, (224, 4096), is_leaf=True) # primals_218 + buf218 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf218, (512, 1792), is_leaf=True) # primals_219 + buf219 = reader.storage(None, 2048, device=device(type='cuda', index=0)) + reader.tensor(buf219, (512,), is_leaf=True) # primals_220 + buf220 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf220, (64, 4096), is_leaf=True) # primals_221 + buf221 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf221, (16, 4096), is_leaf=True) # primals_222 + buf222 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf222, (16, 4096), is_leaf=True) # primals_223 + buf223 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf223, (512, 512), is_leaf=True) # primals_224 + buf224 = reader.storage(None, 2048, device=device(type='cuda', index=0)) + reader.tensor(buf224, (512,), is_leaf=True) # primals_225 + buf225 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf225, (224, 4096), is_leaf=True) # primals_226 + buf226 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf226, (224, 4096), is_leaf=True) # primals_227 + buf227 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf227, (512, 1792), is_leaf=True) # primals_228 + buf228 = reader.storage(None, 2048, device=device(type='cuda', index=0)) + reader.tensor(buf228, (512,), 
is_leaf=True) # primals_229 + buf229 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf229, (64, 4096), is_leaf=True) # primals_230 + buf230 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf230, (16, 4096), is_leaf=True) # primals_231 + buf231 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf231, (16, 4096), is_leaf=True) # primals_232 + buf232 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf232, (512, 512), is_leaf=True) # primals_233 + buf233 = reader.storage(None, 2048, device=device(type='cuda', index=0)) + reader.tensor(buf233, (512,), is_leaf=True) # primals_234 + buf234 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf234, (224, 4096), is_leaf=True) # primals_235 + buf235 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf235, (224, 4096), is_leaf=True) # primals_236 + buf236 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf236, (512, 1792), is_leaf=True) # primals_237 + buf237 = reader.storage(None, 2048, device=device(type='cuda', index=0)) + reader.tensor(buf237, (512,), is_leaf=True) # primals_238 + buf238 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf238, (64, 4096), is_leaf=True) # primals_239 + buf239 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf239, (16, 4096), is_leaf=True) # primals_240 + buf240 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf240, (16, 4096), is_leaf=True) # primals_241 + buf241 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf241, (512, 512), is_leaf=True) # primals_242 + buf242 = reader.storage(None, 2048, device=device(type='cuda', index=0)) + reader.tensor(buf242, (512,), is_leaf=True) # primals_243 + 
buf243 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf243, (224, 4096), is_leaf=True) # primals_244 + buf244 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf244, (224, 4096), is_leaf=True) # primals_245 + buf245 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf245, (512, 1792), is_leaf=True) # primals_246 + buf246 = reader.storage(None, 2048, device=device(type='cuda', index=0)) + reader.tensor(buf246, (512,), is_leaf=True) # primals_247 + buf247 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf247, (64, 4096), is_leaf=True) # primals_248 + buf248 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf248, (16, 4096), is_leaf=True) # primals_249 + buf249 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf249, (16, 4096), is_leaf=True) # primals_250 + buf250 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf250, (512, 512), is_leaf=True) # primals_251 + buf251 = reader.storage(None, 2048, device=device(type='cuda', index=0)) + reader.tensor(buf251, (512,), is_leaf=True) # primals_252 + buf252 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf252, (224, 4096), is_leaf=True) # primals_253 + buf253 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf253, (224, 4096), is_leaf=True) # primals_254 + buf254 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf254, (512, 1792), is_leaf=True) # primals_255 + buf255 = reader.storage(None, 2048, device=device(type='cuda', index=0)) + reader.tensor(buf255, (512,), is_leaf=True) # primals_256 + buf256 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf256, (64, 4096), is_leaf=True) # primals_257 + buf257 = 
reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf257, (16, 4096), is_leaf=True) # primals_258 + buf258 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf258, (16, 4096), is_leaf=True) # primals_259 + buf259 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf259, (512, 512), is_leaf=True) # primals_260 + buf260 = reader.storage(None, 2048, device=device(type='cuda', index=0)) + reader.tensor(buf260, (512,), is_leaf=True) # primals_261 + buf261 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf261, (224, 4096), is_leaf=True) # primals_262 + buf262 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf262, (224, 4096), is_leaf=True) # primals_263 + buf263 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf263, (512, 1792), is_leaf=True) # primals_264 + buf264 = reader.storage(None, 2048, device=device(type='cuda', index=0)) + reader.tensor(buf264, (512,), is_leaf=True) # primals_265 + buf265 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf265, (64, 4096), is_leaf=True) # primals_266 + buf266 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf266, (16, 4096), is_leaf=True) # primals_267 + buf267 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf267, (16, 4096), is_leaf=True) # primals_268 + buf268 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf268, (512, 512), is_leaf=True) # primals_269 + buf269 = reader.storage(None, 2048, device=device(type='cuda', index=0)) + reader.tensor(buf269, (512,), is_leaf=True) # primals_270 + buf270 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf270, (224, 4096), is_leaf=True) # primals_271 + buf271 = reader.storage(None, 3670016, 
device=device(type='cuda', index=0)) + reader.tensor(buf271, (224, 4096), is_leaf=True) # primals_272 + buf272 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf272, (512, 1792), is_leaf=True) # primals_273 + buf273 = reader.storage(None, 2048, device=device(type='cuda', index=0)) + reader.tensor(buf273, (512,), is_leaf=True) # primals_274 + buf274 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf274, (64, 4096), is_leaf=True) # primals_275 + buf275 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf275, (16, 4096), is_leaf=True) # primals_276 + buf276 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf276, (16, 4096), is_leaf=True) # primals_277 + buf277 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf277, (512, 512), is_leaf=True) # primals_278 + buf278 = reader.storage(None, 2048, device=device(type='cuda', index=0)) + reader.tensor(buf278, (512,), is_leaf=True) # primals_279 + buf279 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf279, (224, 4096), is_leaf=True) # primals_280 + buf280 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf280, (224, 4096), is_leaf=True) # primals_281 + buf281 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf281, (512, 1792), is_leaf=True) # primals_282 + buf282 = reader.storage(None, 2048, device=device(type='cuda', index=0)) + reader.tensor(buf282, (512,), is_leaf=True) # primals_283 + buf283 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf283, (64, 4096), is_leaf=True) # primals_284 + buf284 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf284, (16, 4096), is_leaf=True) # primals_285 + buf285 = reader.storage(None, 262144, device=device(type='cuda', 
def get_mesh_sizes():
    """Return the pair of mesh sizes hard-coded for this repro: ``(8, 8)``."""
    # NOTE(review): presumably the two device-mesh dimensions of the 64-GPU
    # 2D variant -- confirm against the caller in run.py.
    mesh_sizes = (8, 8)
    return mesh_sizes


def get_colls_estimations_file():
    """Return the file name of the collective-estimation table for this repro."""
    table_file = "colls8_8.table"
    return table_file


def get_pg_names():
    """Return the group-name string hard-coded for this repro (``"0"``)."""
    # NOTE(review): "pg" presumably means process group; the value is a single
    # string, not a list -- confirm what the consumer expects.
    pg_names = "0"
    return pg_names
+""" + +# Standard library imports +import argparse +import copy +import dataclasses +import json +import logging +import os +from math import inf +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +# Third-party imports +import torch +import torch.distributed as dist +import torch.fx as fx +from torch import device, tensor +from torch._dynamo.testing import rand_strided +from torch._inductor.fx_passes.bucketing import is_wait_tensor +from torch.fx.operator_schemas import normalize_function +from torch.testing._internal.distributed.fake_pg import FakeStore +from torch.utils._dtype_abbrs import dtype_abbrs + +# Local imports +import torch._dynamo.config +import torch._functorch.config +import torch._inductor.config +import torch._inductor.inductor_prims +import torch.fx.experimental._config + +# Constants +DEFAULT_VARIANT = "llama3_8b_bw_256_2d_32" + +# Launch overhead in microseconds +LAUNCH_OVERHEAD_US = 1 +BYTES_PER_MB = 1024 * 1024 +MS_TO_US_MULTIPLIER = 1000 + +# Logging configuration +logging.basicConfig( + level=logging.INFO, + format='%(asctime)s - %(levelname)s - %(message)s' +) +logger = logging.getLogger(__name__) + + +def configure_torch() -> None: + """Configure torch settings for overlap scheduling experiments.""" + torch._inductor.config.allow_buffer_reuse = False + torch._inductor.config.reorder_for_compute_comm_overlap = False + torch._inductor.config.reorder_for_peak_memory = False + torch._inductor.config.max_autotune = False + torch._inductor.config.coordinate_descent_tuning = False + torch._inductor.config.deterministic = False + torch._inductor.config.aten_distributed_optimizations.collective_bucketing = True + torch._inductor.config.triton.store_cubin = False + torch._inductor.config.test_configs.runtime_triton_dtype_assert = False + torch._functorch.config.functionalize_rng_ops = False + torch._functorch.config.fake_tensor_allow_unsafe_data_ptr_access = True + torch._functorch.config.unlift_effect_tokens = False + 
class CollectiveEstimationParser:
    """Parse collective-runtime estimation tables and look up timings in them.

    The table is whitespace-separated text: a header row whose ``<N>MB``
    tokens name the message-size columns, one separator row, and then one row
    per ``(group, group size, collective)`` with a time in milliseconds for
    every size column.
    """

    @staticmethod
    def parse_table(file_path: str) -> Dict[Tuple[str, int, str], Dict[int, float]]:
        """
        Parse the collectives estimations table file.

        Parsing is best-effort: an unreadable file or a malformed row is
        logged and skipped rather than raised, so the caller always gets a
        (possibly empty) dict back.

        Args:
            file_path: Path to the table file

        Returns:
            Dict mapping (group_name, group_size, collective_name) -> {size_mb: time_ms}
        """
        result: Dict[Tuple[str, int, str], Dict[int, float]] = {}

        try:
            with open(file_path, "r", encoding="utf-8") as f:
                lines = f.readlines()
        except FileNotFoundError:
            logger.error(f"Collective estimation file not found: {file_path}")
            return result
        except (OSError, ValueError) as e:
            # OSError: permission/IO problems; ValueError: undecodable bytes
            # (UnicodeDecodeError) or an invalid path string.
            logger.error(f"Error reading collective estimation file {file_path}: {e}")
            return result

        if len(lines) < 2:
            logger.warning(f"Collective estimation file {file_path} has insufficient data")
            return result

        # Header tokens look like "1MB", "2MB", ...; the "(ms)" unit tokens and
        # the Group/Size/Collective captions are ignored.
        size_columns: List[int] = []
        for token in lines[0].split():
            if not token.endswith("MB"):
                continue
            try:
                size_columns.append(int(token[:-2]))
            except ValueError:
                continue

        # lines[1] is the dashed separator row; data starts on file line 3.
        for line_num, raw_line in enumerate(lines[2:], start=3):
            parts = raw_line.split()
            if not parts:
                continue

            if len(parts) < 3 + len(size_columns):
                logger.warning(f"Insufficient data in line {line_num} of {file_path}")
                continue

            try:
                group_size = int(parts[1])
                size_to_time: Dict[int, float] = {
                    size_mb: float(cell)
                    for size_mb, cell in zip(size_columns, parts[3:])
                }
            except ValueError as e:
                logger.warning(f"Error parsing line {line_num} in {file_path}: {e}")
                continue

            # Only store the row once every cell has parsed successfully.
            result[(parts[0], group_size, parts[2])] = size_to_time

        logger.info(f"Parsed {len(result)} collective entries from {file_path}")
        return result

    @staticmethod
    def interpolate_time(size_to_time: Dict[int, float], size_mb: float) -> float:
        """
        Interpolate or extrapolate time for a given size in MB.

        Sizes below 1MB are clamped to the 1MB entry (or the smallest entry
        when there is no 1MB column). A size that equals a table key returns
        that entry. Any other size is linearly interpolated between its two
        neighbouring keys, or linearly extrapolated from the two nearest keys
        when it lies outside the table; extrapolated values are clamped at 0.

        Fractional sizes are never truncated onto a table key: 1.5MB is
        interpolated between the 1MB and 2MB entries rather than reported as
        the 1MB time.

        Args:
            size_to_time: Mapping of size (MB) to time (ms)
            size_mb: Target size in MB

        Returns:
            Estimated time in milliseconds (0.0 for an empty table)
        """
        if not size_to_time:
            return 0.0

        sorted_sizes = sorted(size_to_time)

        # For sizes less than 1MB, use 1MB value
        if size_mb < 1.0:
            return size_to_time.get(1, size_to_time[sorted_sizes[0]])

        # Nearest table keys on either side of the target (inclusive).
        lower_size = max((s for s in sorted_sizes if s <= size_mb), default=None)
        upper_size = min((s for s in sorted_sizes if s >= size_mb), default=None)

        if lower_size is None:
            # Below minimum - extend the line through the first two points.
            if len(sorted_sizes) < 2:
                return size_to_time[sorted_sizes[0]]
            s1, s2 = sorted_sizes[0], sorted_sizes[1]
            t1, t2 = size_to_time[s1], size_to_time[s2]
            slope = (t2 - t1) / (s2 - s1)
            return max(0.0, t1 + slope * (size_mb - s1))

        if upper_size is None:
            # Above maximum - extend the line through the last two points.
            if len(sorted_sizes) < 2:
                return size_to_time[sorted_sizes[-1]]
            s1, s2 = sorted_sizes[-2], sorted_sizes[-1]
            t1, t2 = size_to_time[s1], size_to_time[s2]
            slope = (t2 - t1) / (s2 - s1)
            return max(0.0, t2 + slope * (size_mb - s2))

        if lower_size == upper_size:
            # True exact match: the target coincides with a table key.
            return size_to_time[lower_size]

        t1, t2 = size_to_time[lower_size], size_to_time[upper_size]
        fraction = (size_mb - lower_size) / (upper_size - lower_size)
        return t1 + fraction * (t2 - t1)
+ + Returns: + (collective_name, group_size, tensor_bytes, group_name) or None + """ + if node.op != "call_function": + return None + + target_str = str(node.target) + collective_name = None + group_size = None + + # Determine collective type and extract group_size + if "all_gather_into_tensor" in target_str: + collective_name = "all_gather_into_tensor" + if len(node.args) >= 2: + group_size = self.get_hint(node.args[1]) if hasattr(node.args[1], 'node') else node.args[1] + if isinstance(node.args[1], int): + group_size = node.args[1] + + elif "reduce_scatter_tensor" in target_str: + collective_name = "reduce_scatter_tensor" + if len(node.args) >= 3: + group_size = self.get_hint(node.args[2]) if hasattr(node.args[2], 'node') else node.args[2] + if isinstance(node.args[2], int): + group_size = node.args[2] + + elif "all_reduce" in target_str: + collective_name = "all_reduce" + # No explicit group_size in args for all_reduce + + else: + return None + + # Get tensor bytes from input tensor + input_node = node.args[0] if node.args else None + if not isinstance(input_node, fx.Node): + return None + + tensor_bytes = self.get_tensor_bytes(input_node) + if tensor_bytes is None: + return None + + # Extract group_name + try: + group_name = get_group_name(node) + except Exception as e: + logger.warning(f"Failed to extract group name from node {node.name}: {e}") + group_name = "" + + return (collective_name, group_size, tensor_bytes, group_name) + + def estimate(self, node: fx.Node) -> float: + """ + Estimate execution time for a node in milliseconds. + + Args: + node: FX node to estimate + + Returns: + Estimated runtime in milliseconds + """ + # Check pre-computed estimation (matmul, etc.) 
+ if node.name in self.node_names_ests: + return self.node_names_ests[node.name] + + # Check if this is a collective operation + coll_info = self.get_collective_info(node) + if coll_info is not None: + collective_name, group_size, tensor_bytes, node_group_name = coll_info + size_mb = tensor_bytes / BYTES_PER_MB + + # Look up in the table + if group_size is not None: + for (table_group, gs, cn), size_to_time in self.collective_table.items(): + if gs == group_size and cn == collective_name and table_group in node_group_name: + return CollectiveEstimationParser.interpolate_time(size_to_time, size_mb) + else: + # For all_reduce without explicit group_size + for (table_group, gs, cn), size_to_time in self.collective_table.items(): + if cn == collective_name and table_group in node_group_name: + return CollectiveEstimationParser.interpolate_time(size_to_time, size_mb) + + return 0.0 + + +class TraceGenerator: + """Generates execution traces for visualization.""" + + @staticmethod + def is_communication_node(node: fx.Node) -> bool: + """Check if node is a communication operation.""" + return (node.op == "call_function" and + isinstance(node.target, torch._ops.OpOverload) and + node.target.namespace == "_c10d_functional") + + @staticmethod + def get_tid(node: fx.Node) -> Union[int, str]: + """Get thread ID for trace visualization.""" + if TraceGenerator.is_communication_node(node): + if node.target == torch.ops._c10d_functional.wait_tensor.default: + return 0 + return f"group-{node.args[-1]}" + return 0 + + @staticmethod + def get_repr(arg: Any, mode: str = "full") -> Any: + """Get representation of argument for trace.""" + def get_dtype_repr(dtype): + return dtype_abbrs[dtype] + + if isinstance(arg, torch.Tensor): + return { + "shape": tuple(arg.shape), + "dtype": get_dtype_repr(arg.dtype) + } + + if isinstance(arg, (int, float, str)): + return arg + + if isinstance(arg, torch.dtype): + return get_dtype_repr(arg) + + if isinstance(arg, torch.fx.Node): + if mode == 
"name_only" or "val" not in arg.meta: + return f"fx node {arg.name}" + elif mode == "full": + return {"name": arg.name, "data": TraceGenerator.get_repr(arg.meta["val"])} + elif mode == "content_only": + return TraceGenerator.get_repr(arg.meta["val"]) + else: + raise ValueError(f"Unknown mode {mode}") + + if isinstance(arg, (list, tuple)): + return [TraceGenerator.get_repr(x, mode="name_only") for x in arg] + + if isinstance(arg, dict): + return {k: TraceGenerator.get_repr(v, mode="name_only") for k, v in arg.items()} + + return f"arg {type(arg)}" + + @classmethod + def generate_trace(cls, + gm: fx.GraphModule, + runtime_estimator: Callable[[fx.Node], float], + name: str) -> Dict[str, Any]: + """ + Generate execution trace for visualization. + + Args: + gm: Graph module to trace + runtime_estimator: Function to estimate node runtime + name: Name for the trace + + Returns: + Trace dictionary for perfetto visualization + """ + trace_events = [] + curr_time = {0: 0} + global_time: Dict[fx.Node, int] = {} + + for node_idx, node in enumerate(gm.graph.nodes): + dur_ms = runtime_estimator(node) + dur = dur_ms * MS_TO_US_MULTIPLIER # Convert to microseconds + tid = cls.get_tid(node) + + if tid not in curr_time: + curr_time[tid] = curr_time[0] + + event = { + "ph": "X", + "cat": "kernel", + "name": str(node), + "pid": 0, + "tid": tid + } + + if cls.is_communication_node(node): + if tid == 0 and is_wait_tensor(node) and node.args[0].op != "placeholder": + # Sync with compute stream for wait tensor + comm_end_time = global_time.pop(node.args[0]) + curr_time[tid] = max(curr_time[tid], comm_end_time) + else: + curr_time[tid] = max(curr_time[0], curr_time[tid]) + + event["ts"] = curr_time[tid] + event["dur"] = dur + curr_time[tid] += dur + LAUNCH_OVERHEAD_US + + if tid != 0: + curr_time[0] += LAUNCH_OVERHEAD_US + global_time[node] = curr_time[tid] + + # Add metadata + args = { + "order": node_idx, + "output": cls.get_repr(node, mode="content_only"), + "inputs": [cls.get_repr(arg) 
for arg in node.args] + } + event["args"] = args + + if dur > 0.0: + trace_events.append(event) + + return { + "traceEvents": trace_events, + "traceName": f"{name}_trace.json" + } + + +class ExperimentRunner: + """Main experiment runner for overlap scheduling.""" + + def __init__(self, variant_config: VariantConfig): + self.variant_config = variant_config + self.setup_process_groups() + + def setup_process_groups(self) -> None: + """Set up fake process groups for simulation.""" + from torch.distributed.device_mesh import DeviceMesh, init_device_mesh + + store = FakeStore() + mesh_sizes = self.variant_config.get_mesh_sizes_func() + world_size = 1 + for size in mesh_sizes: + world_size *= size + + self.pg = dist.init_process_group( + backend="fake", + rank=0, + world_size=world_size, + store=store + ) + + mesh = DeviceMesh("fake", torch.arange(world_size).view(*mesh_sizes)) + pgs = [] + pg_names = self.variant_config.get_pg_names_func() + + for i, size in enumerate(mesh_sizes): + pg = mesh.get_group(i) + pgs.append(pg) + + torch._C._distributed_c10d._unregister_all_process_groups() + for pg, pg_name in zip(pgs, pg_names): + torch._C._distributed_c10d._register_process_group(pg_name, pg) + + def run_experiment(self, variant_name: str) -> Tuple[Stats, Stats]: + """ + Run overlap scheduling experiment. 
+ + Args: + variant_name: Name of the variant being tested + + Returns: + Tuple of (stats_before, stats_after) optimization + """ + try: + # Setup model and graph + mod = self.variant_config.repro_class() + + with torch.no_grad(): + from torch.fx.experimental.proxy_tensor import make_fx + from torch._subclasses.fake_tensor import FakeTensorMode + from torch._dynamo.debug_utils import InputReader + + mode = FakeTensorMode() + reader = InputReader() + self.variant_config.load_args_func(reader) + args = reader.args + + gm = make_fx(mod, tracing_mode="fake")(*args) + + # Import scheduling functions + from torch._inductor.fx_passes.overlap_scheduling import ( + schedule_overlap_bucketing_with_estimations + ) + + gm_before = copy.deepcopy(gm) + + # Run optimization to get estimations + gm_copy = copy.deepcopy(gm) + gm_after, nodes_estimations_dict_before = schedule_overlap_bucketing_with_estimations( + gm_copy, + collective_bucketing=True, + insert_overlap_deps=False, + max_memory_increase_ratio=0.0, + collective_estimator="analytical", + ) + + # Create estimators + colls_file_path = resolve_colls_file_path( + self.variant_config.get_colls_file_func() + ) + estimator_before = self._create_estimator( + nodes_estimations_dict_before, colls_file_path + ) + + # Get stats before optimization + stats_before = self.calculate_stats( + gm_before, estimator_before, f"{variant_name}_before" + ) + + # Run optimization again for after stats + _, nodes_estimations_dict_after = schedule_overlap_bucketing_with_estimations( + gm, + collective_bucketing=True, + insert_overlap_deps=False, + max_memory_increase_ratio=0.0, + collective_estimator="analytical", + ) + + estimator_after = self._create_estimator( + nodes_estimations_dict_after, colls_file_path + ) + + stats_after = self.calculate_stats( + gm_after, estimator_after, f"{variant_name}_after" + ) + + return stats_before, stats_after + + except Exception as e: + logger.error(f"Error running experiment for {variant_name}: {e}") + raise 
+ + def _create_estimator(self, + nodes_estimations_dict: Dict[fx.Node, float], + colls_file_path: str) -> NodeEstimator: + """Create a node estimator with collective table.""" + collective_table = CollectiveEstimationParser.parse_table(colls_file_path) + return NodeEstimator(nodes_estimations_dict, collective_table) + + def calculate_stats(self, + gm: fx.GraphModule, + estimator: NodeEstimator, + name: str) -> Stats: + """Calculate statistics for a graph module.""" + num_ag = num_rs = num_ar = 0 + + for node in gm.graph.nodes: + if node.op == 'call_function': + target_str = str(node.target) + if 'all_gather_into_tensor' in target_str: + num_ag += 1 + elif 'reduce_scatter_tensor' in target_str: + num_rs += 1 + elif 'all_reduce' in target_str: + num_ar += 1 + + trace = TraceGenerator.generate_trace(gm, estimator.estimate, name) + + # Calculate total runtime + max_end_time = 0.0 + for event in trace.get("traceEvents", []): + ts = event.get("ts", 0) + dur = event.get("dur", 0) + end_time = ts + dur + max_end_time = max(max_end_time, end_time) + + runtime_ms = max_end_time / MS_TO_US_MULTIPLIER # Convert back to ms + + return Stats(num_ag=num_ag, num_rs=num_rs, num_ar=num_ar, runtime=runtime_ms) + + def cleanup(self) -> None: + """Clean up process groups.""" + dist.destroy_process_group() + + +def get_group_name(n: fx.Node) -> str: + """Extract the group name from a collective operation node.""" + opt_args_kwargs = normalize_function( + n.target, # type: ignore[arg-type] + args=n.args, + kwargs=n.kwargs, + normalize_to_only_use_kwargs=True, + ) + assert opt_args_kwargs is not None + _, kwargs = opt_args_kwargs + return kwargs["group_name"] + + +def resolve_colls_file_path(filename: str) -> str: + """Resolve collective estimations filename to full path relative to run.py.""" + script_dir = os.path.dirname(os.path.abspath(__file__)) + return os.path.join(script_dir, filename) + + +def get_variant_configs() -> Dict[str, VariantConfig]: + """Get all available variant 
configurations.""" + # Import all repro modules + from repro_llama3_8b_bw_256_2d_32layers import ( + Repro as Repro_bw_256_2d_32, load_args as load_args_bw_256_2d_32, + get_mesh_sizes as get_mesh_sizes_bw_256_2d_32, + get_colls_estimations_file as get_colls_file_bw_256_2d_32, + get_pg_names as get_pg_names_bw_256_2d_32 + ) + from repro_llama3_8b_bw_256_1d_32layers import ( + Repro as Repro_bw_256_1d_32, load_args as load_args_bw_256_1d_32, + get_mesh_sizes as get_mesh_sizes_bw_256_1d_32, + get_colls_estimations_file as get_colls_file_bw_256_1d_32, + get_pg_names as get_pg_names_bw_256_1d_32 + ) + from repro_llama3_8b_bw_64_2d_32layers import ( + Repro as Repro_bw_64_2d_32, load_args as load_args_bw_64_2d_32, + get_mesh_sizes as get_mesh_sizes_bw_64_2d_32, + get_colls_estimations_file as get_colls_file_bw_64_2d_32, + get_pg_names as get_pg_names_bw_64_2d_32 + ) + from repro_llama3_8b_bw_64_1d_32layers import ( + Repro as Repro_bw_64_1d_32, load_args as load_args_bw_64_1d_32, + get_mesh_sizes as get_mesh_sizes_bw_64_1d_32, + get_colls_estimations_file as get_colls_file_bw_64_1d_32, + get_pg_names as get_pg_names_bw_64_1d_32 + ) + from repro_llama3_8b_fw_256_2d_32layers import ( + Repro as Repro_fw_256_2d_32, load_args as load_args_fw_256_2d_32, + get_mesh_sizes as get_mesh_sizes_fw_256_2d_32, + get_colls_estimations_file as get_colls_file_fw_256_2d_32, + get_pg_names as get_pg_names_fw_256_2d_32 + ) + from repro_llama3_8b_fw_256_1d_32layers import ( + Repro as Repro_fw_256_1d_32, load_args as load_args_fw_256_1d_32, + get_mesh_sizes as get_mesh_sizes_fw_256_1d_32, + get_colls_estimations_file as get_colls_file_fw_256_1d_32, + get_pg_names as get_pg_names_fw_256_1d_32 + ) + from repro_llama3_8b_fw_64_2d_32layers import ( + Repro as Repro_fw_64_2d_32, load_args as load_args_fw_64_2d_32, + get_mesh_sizes as get_mesh_sizes_fw_64_2d_32, + get_colls_estimations_file as get_colls_file_fw_64_2d_32, + get_pg_names as get_pg_names_fw_64_2d_32 + ) + from 
repro_llama3_8b_fw_64_1d_32layers import ( + Repro as Repro_fw_64_1d_32, load_args as load_args_fw_64_1d_32, + get_mesh_sizes as get_mesh_sizes_fw_64_1d_32, + get_colls_estimations_file as get_colls_file_fw_64_1d_32, + get_pg_names as get_pg_names_fw_64_1d_32 + ) + + return { + "llama3_8b_bw_256_2d_32": VariantConfig( + Repro_bw_256_2d_32, load_args_bw_256_2d_32, get_mesh_sizes_bw_256_2d_32, + get_colls_file_bw_256_2d_32, get_pg_names_bw_256_2d_32 + ), + "llama3_8b_bw_256_1d_32": VariantConfig( + Repro_bw_256_1d_32, load_args_bw_256_1d_32, get_mesh_sizes_bw_256_1d_32, + get_colls_file_bw_256_1d_32, get_pg_names_bw_256_1d_32 + ), + "llama3_8b_bw_64_2d_32": VariantConfig( + Repro_bw_64_2d_32, load_args_bw_64_2d_32, get_mesh_sizes_bw_64_2d_32, + get_colls_file_bw_64_2d_32, get_pg_names_bw_64_2d_32 + ), + "llama3_8b_bw_64_1d_32": VariantConfig( + Repro_bw_64_1d_32, load_args_bw_64_1d_32, get_mesh_sizes_bw_64_1d_32, + get_colls_file_bw_64_1d_32, get_pg_names_bw_64_1d_32 + ), + "llama3_8b_fw_256_2d_32": VariantConfig( + Repro_fw_256_2d_32, load_args_fw_256_2d_32, get_mesh_sizes_fw_256_2d_32, + get_colls_file_fw_256_2d_32, get_pg_names_fw_256_2d_32 + ), + "llama3_8b_fw_256_1d_32": VariantConfig( + Repro_fw_256_1d_32, load_args_fw_256_1d_32, get_mesh_sizes_fw_256_1d_32, + get_colls_file_fw_256_1d_32, get_pg_names_fw_256_1d_32 + ), + "llama3_8b_fw_64_2d_32": VariantConfig( + Repro_fw_64_2d_32, load_args_fw_64_2d_32, get_mesh_sizes_fw_64_2d_32, + get_colls_file_fw_64_2d_32, get_pg_names_fw_64_2d_32 + ), + "llama3_8b_fw_64_1d_32": VariantConfig( + Repro_fw_64_1d_32, load_args_fw_64_1d_32, get_mesh_sizes_fw_64_1d_32, + get_colls_file_fw_64_1d_32, get_pg_names_fw_64_1d_32 + ), + } + + +def create_argument_parser() -> argparse.ArgumentParser: + """Create and configure the argument parser.""" + parser = argparse.ArgumentParser( + description="Run overlap scheduling experiments", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=""" +Examples: + python 
run_improved.py --variant llama3_8b_bw_256_2d_32 + python run_improved.py --variant llama3_8b_fw_256_2d_32 + python run_improved.py --variant llama3_8b_fw_64_1d_32 +""" + ) + + variant_choices = list(get_variant_configs().keys()) + parser.add_argument( + "--variant", + type=str, + default=DEFAULT_VARIANT, + choices=variant_choices, + help=f"Model variant (default: {DEFAULT_VARIANT})" + ) + + return parser + + +def main() -> None: + """Main entry point.""" + # Configure torch before any experiments + configure_torch() + + # Parse arguments + parser = create_argument_parser() + args = parser.parse_args() + + # Get variant configuration + variant_configs = get_variant_configs() + if args.variant not in variant_configs: + logger.error(f"Unknown variant: {args.variant}") + return + + variant_config = variant_configs[args.variant] + + # Run experiment + try: + logger.info(f"Running overlap scheduling experiment for variant: {args.variant}") + + runner = ExperimentRunner(variant_config) + stats_before, stats_after = runner.run_experiment(args.variant) + + # Print results + logger.info("Experiment completed successfully") + print(f"\nResults for {args.variant}:") + print(f"BEFORE: {stats_before}") + print(f"AFTER: {stats_after}") + + # Calculate improvement + if stats_before.runtime > 0: + improvement = ((stats_before.runtime - stats_after.runtime) / stats_before.runtime) * 100 + print(f"Runtime improvement: {improvement:.2f}%") + + runner.cleanup() + + except Exception as e: + logger.error(f"Experiment failed: {e}") + raise + + +if __name__ == '__main__': + main() \ No newline at end of file