diff --git a/hta/common/trace_parser.py b/hta/common/trace_parser.py index 147df25..2cea3f4 100644 --- a/hta/common/trace_parser.py +++ b/hta/common/trace_parser.py @@ -23,7 +23,12 @@ from hta.configs.config import logger from hta.configs.default_values import ValueType -from hta.configs.parser_config import AttributeSpec, ParserBackend, ParserConfig +from hta.configs.parser_config import ( + AttributeSpec, + DEFAULT_PARSE_VERSION, + ParserBackend, + ParserConfig, +) from hta.utils.utils import normalize_gpu_stream_numbers # from memory_profiler import profile @@ -260,7 +265,16 @@ def _compress_df( ) # args_to_keep = args_to_keep.union(counter_names) cfg.add_args( - [AttributeSpec(name, name, ValueType.Int, -1) for name in counter_names] + [ + AttributeSpec( + name, + name, + ValueType.Int, + -1, + min_supported_version=DEFAULT_PARSE_VERSION, + ) + for name in counter_names + ] ) logger.info(f"counter_names={counter_names}") logger.info(f"args={cfg.get_args()}") diff --git a/hta/configs/default_values.py b/hta/configs/default_values.py index 1600c45..3b483d8 100644 --- a/hta/configs/default_values.py +++ b/hta/configs/default_values.py @@ -2,6 +2,7 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +import re from enum import Enum from typing import Dict, List, NamedTuple, Union @@ -17,6 +18,24 @@ MAX_NUM_PROCESSES: int = 32 +class YamlVersion(NamedTuple): + major: int + minor: int + patch: int + + def get_version_str(self) -> str: + return f"{self.major}.{self.minor}.{self.patch}" + + @staticmethod + def from_string(version_str: str) -> "YamlVersion": + pattern = r"^(\d+)\.(\d+)\.(\d+)$" + match = re.match(pattern, version_str) + if not match: + raise ValueError(f"Invalid version string: {version_str}") + major, minor, patch = map(int, match.groups()) + return YamlVersion(major, minor, patch) + + class ValueType(Enum): """ValueType enumerates the possible data types for the attribute values.""" @@ -40,6 +59,9 @@ class AttributeSpec(NamedTuple): raw_name: str value_type: ValueType default_value: Union[int, float, str, object] + # This property allows for addition of fields that do + # not break backwards compatibility with other fields (minor version bumps). + min_supported_version: YamlVersion class EventArgs(NamedTuple): diff --git a/hta/configs/event_args_formats/event_args_1.0.0.yaml b/hta/configs/event_args_formats/event_args_1.0.0.yaml index 5c2c14b..216bb54 100644 --- a/hta/configs/event_args_formats/event_args_1.0.0.yaml +++ b/hta/configs/event_args_formats/event_args_1.0.0.yaml @@ -46,6 +46,12 @@ AVAILABLE_ARGS: raw_name: kernel_backend value_type: String default_value: "" + cpu_op::kernel_hash: + name: kernel_hash + raw_name: kernel_hash + value_type: String + default_value: "" + min_supported_version: "1.1.0" correlation::cbid: name: cbid raw_name: cbid @@ -166,6 +172,16 @@ AVAILABLE_ARGS: raw_name: Out msg nelems value_type: Int default_value: 0 + nccl::in_msg_size: + name: in_msg_size + raw_name: In msg size + value_type: Int + default_value: 0 + nccl::out_msg_size: + name: out_msg_size + raw_name: Out msg size + value_type: Int + default_value: 0 nccl::group_size: name: group_size raw_name: Group size diff --git a/hta/configs/event_args_yaml_parser.py b/hta/configs/event_args_yaml_parser.py index 4552b6b..2a902cf 100644 --- a/hta/configs/event_args_yaml_parser.py +++ b/hta/configs/event_args_yaml_parser.py @@ -3,32 +3,13 @@ # pyre-strict -import re from functools import lru_cache from pathlib import Path -from typing import Callable, Dict, List, NamedTuple +from typing import Callable, Dict, List import yaml -from hta.configs.default_values import AttributeSpec, EventArgs, ValueType - - -class YamlVersion(NamedTuple): - major: int - minor: int - patch: int - - def get_version_str(self) -> str: - return f"{self.major}.{self.minor}.{self.patch}" - - @staticmethod - def from_string(version_str: str) -> "YamlVersion": - pattern = r"^(\d+)\.(\d+)\.(\d+)$" - match = re.match(pattern, version_str) - if not match: - raise ValueError(f"Invalid version string: {version_str}") - major, minor, patch = map(int, match.groups()) - return YamlVersion(major, minor, patch) +from hta.configs.default_values import AttributeSpec, EventArgs, ValueType, YamlVersion # Yaml version will be mapped to the yaml files defined under the "event_args_formats" folder @@ -110,6 +91,11 @@ def parse_value_type(value: str) -> ValueType: raw_name=value["raw_name"], value_type=parse_value_type(value["value_type"]), default_value=value["default_value"], + min_supported_version=( + YamlVersion.from_string(value["min_supported_version"]) + if "min_supported_version" in value + else version + ), ) for k, value in yaml_content["AVAILABLE_ARGS"].items() } diff --git a/hta/configs/parser_config.py b/hta/configs/parser_config.py index b9b3212..dbcf0dd 100644 --- a/hta/configs/parser_config.py +++ b/hta/configs/parser_config.py @@ -149,6 +149,13 @@ def set_global_parser_config_version(version: YamlVersion) -> None: ).ARGS_COMMUNICATION ParserConfig.ARGS_DEFAULT = parse_event_args_yaml(version).ARGS_DEFAULT + @staticmethod + def show_available_args(): + from pprint import pprint + + global AVAILABLE_ARGS + pprint(AVAILABLE_ARGS) + # Define a global ParserConfig variable for internal use. To access this variable, # Clients should use ParserConfig.get_default_cfg and ParserConfig.set_default_cfg. diff --git a/tests/data/triton_example/triton_example.json.gz b/tests/data/triton_example/triton_example.json.gz index 0935c73..36347a1 100644 Binary files a/tests/data/triton_example/triton_example.json.gz and b/tests/data/triton_example/triton_example.json.gz differ diff --git a/tests/test_trace_parse.py b/tests/test_trace_parse.py index e8a915e..ef73bda 100644 --- a/tests/test_trace_parse.py +++ b/tests/test_trace_parse.py @@ -356,6 +356,7 @@ def test_triton_trace(self) -> None: trace_df = self.triton_t.get_trace(0) self.assertGreater(len(trace_df), 0) self.assertTrue("kernel_backend" in trace_df.columns) + self.assertTrue("kernel_hash" in trace_df.columns) triton_cpu_ops = trace_df[trace_df.kernel_backend.ne("")] # We have one triton cpu op @@ -365,6 +366,10 @@ def test_triton_trace(self) -> None: self.assertEqual(triton_op["s_name"], "triton_poi_fused_add_cos_sin_0") self.assertEqual(triton_op["s_cat"], "cpu_op") self.assertEqual(triton_op["kernel_backend"], "triton") + self.assertEqual( + triton_op["kernel_hash"], + "cqaokwf2bph4egogzevc22vluasiyuui4i54zpemp6knbsggfbuu", + ) if __name__ == "__main__": # pragma: no cover