Skip to content

Commit

Permalink
Add kernel hash parsing and suport local minor updates for fields (fa…
Browse files Browse the repository at this point in the history
…cebookresearch#199)

Summary:
Pull Request resolved: facebookresearch#199

Enable support for minor version updates to the trace config (parser). This makes it easy to not have to create new files in case only fields are added to the trace as well as parsing logic.

Add new field for Triton kernel hash and test it
pytorch/pytorch#139531

Differential Revision: D65434312
  • Loading branch information
briancoutinho authored and facebook-github-bot committed Nov 4, 2024
1 parent 69c1637 commit 4854d7f
Show file tree
Hide file tree
Showing 6 changed files with 57 additions and 21 deletions.
22 changes: 22 additions & 0 deletions hta/configs/default_values.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import re
from enum import Enum
from typing import Dict, List, NamedTuple, Union

Expand All @@ -17,6 +18,24 @@
MAX_NUM_PROCESSES: int = 32


class YamlVersion(NamedTuple):
major: int
minor: int
patch: int

def get_version_str(self) -> str:
return f"{self.major}.{self.minor}.{self.patch}"

@staticmethod
def from_string(version_str: str) -> "YamlVersion":
pattern = r"^(\d+)\.(\d+)\.(\d+)$"
match = re.match(pattern, version_str)
if not match:
raise ValueError(f"Invalid version string: {version_str}")
major, minor, patch = map(int, match.groups())
return YamlVersion(major, minor, patch)


class ValueType(Enum):
"""ValueType enumerates the possible data types for the attribute values."""

Expand All @@ -40,6 +59,9 @@ class AttributeSpec(NamedTuple):
raw_name: str
value_type: ValueType
default_value: Union[int, float, str, object]
# This property allows for addition of fields that do
# not break backwards compatibility with other fields (minor version bumps).
min_supported_version: YamlVersion


class EventArgs(NamedTuple):
Expand Down
16 changes: 16 additions & 0 deletions hta/configs/event_args_formats/event_args_1.0.0.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,12 @@ AVAILABLE_ARGS:
raw_name: kernel_backend
value_type: String
default_value: ""
cpu_op::kernel_hash:
name: kernel_hash
raw_name: kernel_hash
value_type: String
default_value: ""
min_supported_version: "1.1.0"
correlation::cbid:
name: cbid
raw_name: cbid
Expand Down Expand Up @@ -166,6 +172,16 @@ AVAILABLE_ARGS:
raw_name: Out msg nelems
value_type: Int
default_value: 0
nccl::in_msg_size:
name: in_msg_size
raw_name: In msg size
value_type: Int
default_value: 0
nccl::out_msg_size:
name: out_msg_size
raw_name: Out msg size
value_type: Int
default_value: 0
nccl::group_size:
name: group_size
raw_name: Group size
Expand Down
28 changes: 7 additions & 21 deletions hta/configs/event_args_yaml_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,32 +3,13 @@
# pyre-strict


import re
from functools import lru_cache
from pathlib import Path
from typing import Callable, Dict, List, NamedTuple
from typing import Callable, Dict, List

import yaml

from hta.configs.default_values import AttributeSpec, EventArgs, ValueType


class YamlVersion(NamedTuple):
major: int
minor: int
patch: int

def get_version_str(self) -> str:
return f"{self.major}.{self.minor}.{self.patch}"

@staticmethod
def from_string(version_str: str) -> "YamlVersion":
pattern = r"^(\d+)\.(\d+)\.(\d+)$"
match = re.match(pattern, version_str)
if not match:
raise ValueError(f"Invalid version string: {version_str}")
major, minor, patch = map(int, match.groups())
return YamlVersion(major, minor, patch)
from hta.configs.default_values import AttributeSpec, EventArgs, ValueType, YamlVersion


# Yaml version will be mapped to the yaml files defined under the "event_args_formats" folder
Expand Down Expand Up @@ -110,6 +91,11 @@ def parse_value_type(value: str) -> ValueType:
raw_name=value["raw_name"],
value_type=parse_value_type(value["value_type"]),
default_value=value["default_value"],
min_supported_version=(
YamlVersion.from_string(value["min_supported_version"])
if "min_supported_version" in value
else version
),
)
for k, value in yaml_content["AVAILABLE_ARGS"].items()
}
Expand Down
7 changes: 7 additions & 0 deletions hta/configs/parser_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,13 @@ def set_global_parser_config_version(version: YamlVersion) -> None:
).ARGS_COMMUNICATION
ParserConfig.ARGS_DEFAULT = parse_event_args_yaml(version).ARGS_DEFAULT

@staticmethod
def show_available_args():
from pprint import pprint

global AVAILABLE_ARGS
pprint(AVAILABLE_ARGS)


# Define a global ParserConfig variable for internal use. To access this variable,
# Clients should use ParserConfig.get_default_cfg and ParserConfig.set_default_cfg.
Expand Down
Binary file modified tests/data/triton_example/triton_example.json.gz
Binary file not shown.
5 changes: 5 additions & 0 deletions tests/test_trace_parse.py
Original file line number Diff line number Diff line change
Expand Up @@ -356,6 +356,7 @@ def test_triton_trace(self) -> None:
trace_df = self.triton_t.get_trace(0)
self.assertGreater(len(trace_df), 0)
self.assertTrue("kernel_backend" in trace_df.columns)
self.assertTrue("kernel_hash" in trace_df.columns)

triton_cpu_ops = trace_df[trace_df.kernel_backend.ne("")]
# We have one triton cpu op
Expand All @@ -365,6 +366,10 @@ def test_triton_trace(self) -> None:
self.assertEqual(triton_op["s_name"], "triton_poi_fused_add_cos_sin_0")
self.assertEqual(triton_op["s_cat"], "cpu_op")
self.assertEqual(triton_op["kernel_backend"], "triton")
self.assertEqual(
triton_op["kernel_hash"],
"cqaokwf2bph4egogzevc22vluasiyuui4i54zpemp6knbsggfbuu",
)


if __name__ == "__main__": # pragma: no cover
Expand Down

0 comments on commit 4854d7f

Please sign in to comment.