From e316e499b8a348a1a5a79a56571cf072fa0753f4 Mon Sep 17 00:00:00 2001 From: Brian Coutinho Date: Mon, 4 Nov 2024 14:36:33 -0800 Subject: [PATCH] Add kernel hash parsing and suport local minor updates for fields (#199) Summary: Pull Request resolved: https://github.com/facebookresearch/HolisticTraceAnalysis/pull/199 Enable support for minor version updates to the trace config (parser). This makes it easy to not have to create new files in case only fields are added to the trace as well as parsing logic. Add new field for Triton kernel hash and test it https://github.com/pytorch/pytorch/pull/139531 Differential Revision: D65434312 --- hta/common/trace_parser.py | 18 +++++++++-- hta/configs/default_values.py | 22 ++++++++++++++ .../event_args_formats/event_args_1.0.0.yaml | 16 ++++++++++ hta/configs/event_args_yaml_parser.py | 28 +++++------------- hta/configs/parser_config.py | 7 +++++ .../triton_example/triton_example.json.gz | Bin 2138 -> 2162 bytes tests/test_trace_parse.py | 5 ++++ 7 files changed, 73 insertions(+), 23 deletions(-) diff --git a/hta/common/trace_parser.py b/hta/common/trace_parser.py index 147df25..2cea3f4 100644 --- a/hta/common/trace_parser.py +++ b/hta/common/trace_parser.py @@ -23,7 +23,12 @@ from hta.configs.config import logger from hta.configs.default_values import ValueType -from hta.configs.parser_config import AttributeSpec, ParserBackend, ParserConfig +from hta.configs.parser_config import ( + AttributeSpec, + DEFAULT_PARSE_VERSION, + ParserBackend, + ParserConfig, +) from hta.utils.utils import normalize_gpu_stream_numbers # from memory_profiler import profile @@ -260,7 +265,16 @@ def _compress_df( ) # args_to_keep = args_to_keep.union(counter_names) cfg.add_args( - [AttributeSpec(name, name, ValueType.Int, -1) for name in counter_names] + [ + AttributeSpec( + name, + name, + ValueType.Int, + -1, + min_supported_version=DEFAULT_PARSE_VERSION, + ) + for name in counter_names + ] ) logger.info(f"counter_names={counter_names}") logger.info(f"args={cfg.get_args()}") diff --git a/hta/configs/default_values.py b/hta/configs/default_values.py index 1600c45..3b483d8 100644 --- a/hta/configs/default_values.py +++ b/hta/configs/default_values.py @@ -2,6 +2,7 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +import re from enum import Enum from typing import Dict, List, NamedTuple, Union @@ -17,6 +18,24 @@ MAX_NUM_PROCESSES: int = 32 +class YamlVersion(NamedTuple): + major: int + minor: int + patch: int + + def get_version_str(self) -> str: + return f"{self.major}.{self.minor}.{self.patch}" + + @staticmethod + def from_string(version_str: str) -> "YamlVersion": + pattern = r"^(\d+)\.(\d+)\.(\d+)$" + match = re.match(pattern, version_str) + if not match: + raise ValueError(f"Invalid version string: {version_str}") + major, minor, patch = map(int, match.groups()) + return YamlVersion(major, minor, patch) + + class ValueType(Enum): """ValueType enumerates the possible data types for the attribute values.""" @@ -40,6 +59,9 @@ class AttributeSpec(NamedTuple): raw_name: str value_type: ValueType default_value: Union[int, float, str, object] + # This property allows for addition of fields that do + # not break backwards compatibility with other fields (minor version bumps). + min_supported_version: YamlVersion class EventArgs(NamedTuple): diff --git a/hta/configs/event_args_formats/event_args_1.0.0.yaml b/hta/configs/event_args_formats/event_args_1.0.0.yaml index 5c2c14b..216bb54 100644 --- a/hta/configs/event_args_formats/event_args_1.0.0.yaml +++ b/hta/configs/event_args_formats/event_args_1.0.0.yaml @@ -46,6 +46,12 @@ AVAILABLE_ARGS: raw_name: kernel_backend value_type: String default_value: "" + cpu_op::kernel_hash: + name: kernel_hash + raw_name: kernel_hash + value_type: String + default_value: "" + min_supported_version: "1.1.0" correlation::cbid: name: cbid raw_name: cbid @@ -166,6 +172,16 @@ AVAILABLE_ARGS: raw_name: Out msg nelems value_type: Int default_value: 0 + nccl::in_msg_size: + name: in_msg_size + raw_name: In msg size + value_type: Int + default_value: 0 + nccl::out_msg_size: + name: out_msg_size + raw_name: Out msg size + value_type: Int + default_value: 0 nccl::group_size: name: group_size raw_name: Group size diff --git a/hta/configs/event_args_yaml_parser.py b/hta/configs/event_args_yaml_parser.py index 4552b6b..2a902cf 100644 --- a/hta/configs/event_args_yaml_parser.py +++ b/hta/configs/event_args_yaml_parser.py @@ -3,32 +3,13 @@ # pyre-strict -import re from functools import lru_cache from pathlib import Path -from typing import Callable, Dict, List, NamedTuple +from typing import Callable, Dict, List import yaml -from hta.configs.default_values import AttributeSpec, EventArgs, ValueType - - -class YamlVersion(NamedTuple): - major: int - minor: int - patch: int - - def get_version_str(self) -> str: - return f"{self.major}.{self.minor}.{self.patch}" - - @staticmethod - def from_string(version_str: str) -> "YamlVersion": - pattern = r"^(\d+)\.(\d+)\.(\d+)$" - match = re.match(pattern, version_str) - if not match: - raise ValueError(f"Invalid version string: {version_str}") - major, minor, patch = map(int, match.groups()) - return YamlVersion(major, minor, patch) +from hta.configs.default_values import AttributeSpec, EventArgs, ValueType, YamlVersion # Yaml version will be mapped to the yaml files defined under the "event_args_formats" folder @@ -110,6 +91,11 @@ def parse_value_type(value: str) -> ValueType: raw_name=value["raw_name"], value_type=parse_value_type(value["value_type"]), default_value=value["default_value"], + min_supported_version=( + YamlVersion.from_string(value["min_supported_version"]) + if "min_supported_version" in value + else version + ), ) for k, value in yaml_content["AVAILABLE_ARGS"].items() } diff --git a/hta/configs/parser_config.py b/hta/configs/parser_config.py index b9b3212..dbcf0dd 100644 --- a/hta/configs/parser_config.py +++ b/hta/configs/parser_config.py @@ -149,6 +149,13 @@ def set_global_parser_config_version(version: YamlVersion) -> None: ).ARGS_COMMUNICATION ParserConfig.ARGS_DEFAULT = parse_event_args_yaml(version).ARGS_DEFAULT + @staticmethod + def show_available_args(): + from pprint import pprint + + global AVAILABLE_ARGS + pprint(AVAILABLE_ARGS) + # Define a global ParserConfig variable for internal use. To access this variable, # Clients should use ParserConfig.get_default_cfg and ParserConfig.set_default_cfg. diff --git a/tests/data/triton_example/triton_example.json.gz b/tests/data/triton_example/triton_example.json.gz index 0935c73dc9f18d1e66886f22cf39cfc4a83c487e..36347a15ee2e267c2a951c41cdac38ab1fa18157 100644 GIT binary patch literal 2162 zcmV-&2#xn2iwFqd7%67}18sRfx6h$xejR4++{td}!vI z`G&)X=?{N>KuAg4IbD!HXq@;gRE7xmM6E}czDq|ji)hS!nn1&Ms7L58`5Sup9@MEl zq=aNa;ZeEyknlFI(>RS<*UiRmUW+B-4bBu0NHrklhV@5FSI81tQ& zZW3d?6Vpv%>~~`R|HQ!W2f4>_Q;;!Uu6I6zyxJq5w`#1MA=|)v+Mv`q*%wvzH3@=w{z^aNd{1Gts+@ z+|SRP#vut%zOhigN&C<=4P6elqjD$sd5NyPXHh6UEr-FehMrwODEk5R&^I~*4IuDb zJ=IhL_ie{0O~tyY^&QJ-WUSW2SbVavJoY&Y#}V_#(=?&pn0VgUWyv`4!|{m*+q!A= zRdY}p?3n!l)*6|s_moE{(qKF%iNx>zAnfsJs!gJ~L1%3ClP+CNTLvjf{8gIzhHn}_ zBf5yJ$8eI&X46TU3f!|;3Y338LbUaM5w!<>nQ%x!i@b2bLN}(IqN@<{nH(&Nc)R!y z;*;CoW4NOU&8%4|0BrDg9`rtTj!D`$$7?&75eH zAQ4|hR05p_4A_SH{fAqZ1SFQwvzKl`ikBU&e8*rphNg;#hxHqm{$iuMT6Wucl5dpp zsWptx%*s>ZQ!OLYu8v8=Fd)Hb>p|GJEwy1t=G8w^n$kj9b~=?8To!Wr%w=vWW}hOS zY}p>iG$1@n)4Ex7f+@#OKv#ky8Y3=eK{>ceDRM?Z3%X$ODpQs@#o8v3g0 zXr@!Niz#utS-0U|(VP#C?*JEP9B+t~4ccI*i6*>{An*Hh35b{rQx^%{75Ya_vYiT& zGDn(~bRDEBlahx9akbfIU=-y^BuZ4N&wcADb5o9+{!Snn%zrCG@05h+IDJ?Jl)+`# z5M6`gqRX9RWwk0`Uhr>H<=;c-4JJR7+R@v7#g@76SErd@o#Xvl06p5TMO(gf9LsXA0ClbGsWNR#>lk5#489lMAEGND1z4yIf-NWB={H-iJKYd(hgmOi%wX+wVbQX^o;GpBZn zvwAdZw!Vz+Q9~%F@>LIKiz!o;KP@qZJH=Eznr)_V%OmR)rmD9vuV<=TdGMGFD45dP zW$MW6moe2=_R~AXR6Uw)ru251s)tj-l-?3kMyHsnN3+e8(JoW8Yf8nnce(C@kG9@m=i9AAiUx3Y_P#A2 zQ#1gy&lYa&oT6g~6-S>`w0%@@%2@?&>#PEI>#(8$oE^4WKCWm0sDiDw&MR=Y4lEkL z*EId=dJY(vv5r>{Djw<=A;G5UB=TsF*bDz^~{We&`6Dx;OW z^@v6>6PMXyrMK;u-FEa5JbA#ZuMu)t8Va1vwlz(>y=mM_mNd(TI>=L# zf^Ip9`4H_psX*(CN16O*$~HVn)TXCpMWv`}YY%|S?=k)dOkz*ybLcezvjJm2m~BV1 zA(+MbYkfcY>gL~H6>+IhJPb96&@-%_X#)t?p=u3vwTRncG1e}`jpk1-_m@6j73THb zC>mnb96SJ2G70=}hQu@}E+mWRC!f#Jf;^*(`_bJMnxxYy#8~40q_E~6T>lEO_7(`7 zf&PHLmofZLbmO`~Q3Pm95~RSm05qq3F0SM;6sI9_=YHU!sF3wWXdLS$G&hV3-JyWf zj{aS5$gbZ=IHTS1NfoBnJ%cgU9?)GLoQW9y;=vD*mAsh-d#&Ng4+2C(PiU}T(14H7 zpsT>{D9f-y2Syh#;IfhRRw{I-R6#wY*>X&0Up+|BN37Z!nyffYUKqYd>2u5nlqWtp z!z4rWHRWe)dKQZafpCGPpZsb@Ag^0#X0iR_8 zL5GY#Xt@vRb3|Ps@8bUQ6Qc496!}{+UnQJH5j0a02djUHmbWn&GP9aR2~BdtsW*Hx z*&;x8<{?;zAitRR5+$1iWp)40L|j4_-8?^u057Wq4t2aBO8RYIARH0PS0CZ`(K${@!0f2soe% zI0{L<2RIzMNv{ZQlLk$)yFd{bvc$w{Wl1Mdj-AE+_Zw1{{3b`TW8w56L7PZ?nRz(R za7d2n4}X0?NXfi0o4CKTGz;QLIYVS1N_{pBJa&`D2}^mvGN||teTp%KjeI5JTRH=_fJtotSUEd#iaW>3?XowFofsJj`Qf*DIOkfe!F^yIR&<4s=Dxc)v zj}T7=zX%vR&607L@Dphj(e6Bsyp(Zb=mHAD%bSW*X1tU_Achr;oW7&L-aTSf7SI{XwVh zXco`C(ag+fYUXM(SC*PauJ@Bg{?h28PJ%>yo-hgYF^qvIDBpg#_1w@+CG?`xElBZ; zRK|A#h9FQ?^YF0jaS=>bn&Yb3#*^$(xfX5dTE!@9Tx;n=xbhyW-ftXDR}G?7enVBN z{&gA8s2h{ zDuO^AuqhylJugq($eW{oG(^WoK~iBzqn56nR25PdKHz%vY1ISD_faC=n9>}da^D^L zX)tB!svIx>&4qbC{!is&2LvSt6V)`yGAg>A<{Z8b~ zwr*$OG>>=yz1A%3yBBh{+|I$!X&i|?&_RYK*l<*;tt|84DOKABLpvLmCZ@4%X{Vi~ zoxLqxLEqcbW!++zHk9v!64aHlsmih}VyT+78{@9HS81B_aU5wX!Ah+;HW7vHh=o}< zm=>}@>iNxmJ2r9bunQDep!&Lpj;|S5PS4=gdS!arN90Lzws(v5;cBZsmG1ZJ_ohhwb zrkdf@Fr{_Gl-?_*n$fH?rFY9zGn^Wx^p2P^dc{;Tnsug(ZkcL^Q^QpK;q2{9^=rSr z&rV%1Wp>Ndp1bsdsjhbERvJah({UvpOkfYlBQKl z%Ys@_^`qSZZgGY24;WJXj6O$x8!#&{wiMcgTDGcD5et_WfiJKB{Z$dS3B|)1#Z*7R z_6cYH^;PY~Ls+S~Z1Y`XHpt_?lh@OyGB*l$l@%$X{ zF`BqDbbfzxcZo*%<0JT37Q8Un@!MCwg0CF|fkT}K^u6-of1+#83(Hr4Hg&@s7#D!X zjE}_?y%?o=guHPO`Y0)79SMz2Uqf>pbD=vDaJtdIX${%*n;Si`{@A1rQ|Fc;*u)R$ zZsDAnnEm1*h>)2bO@p!2@En98Vv#R2SVz)8EtqUKMT8YP2s(!eSIn#&DKYY z+lf#CEA(Xqz8v_7%{W@Q1(+4xgmnb`xDW_Br2j$9eZ*!7^Mt(f`-@MA$!|a`ddYZ| z@i|Cgval9y1^DZGSbi-9x$>{zlbm};Br5S%MeP_BEKQY-@Udmq<}g=aD! zE-si-$#Mmq)f7+YUE~2lwZ41oWVFEUo1VuB93inwyopx Q=fmIs1D*fVe0x6t0Osl=9RL6T diff --git a/tests/test_trace_parse.py b/tests/test_trace_parse.py index e8a915e..ef73bda 100644 --- a/tests/test_trace_parse.py +++ b/tests/test_trace_parse.py @@ -356,6 +356,7 @@ def test_triton_trace(self) -> None: trace_df = self.triton_t.get_trace(0) self.assertGreater(len(trace_df), 0) self.assertTrue("kernel_backend" in trace_df.columns) + self.assertTrue("kernel_hash" in trace_df.columns) triton_cpu_ops = trace_df[trace_df.kernel_backend.ne("")] # We have one triton cpu op @@ -365,6 +366,10 @@ def test_triton_trace(self) -> None: self.assertEqual(triton_op["s_name"], "triton_poi_fused_add_cos_sin_0") self.assertEqual(triton_op["s_cat"], "cpu_op") self.assertEqual(triton_op["kernel_backend"], "triton") + self.assertEqual( + triton_op["kernel_hash"], + "cqaokwf2bph4egogzevc22vluasiyuui4i54zpemp6knbsggfbuu", + ) if __name__ == "__main__": # pragma: no cover