Update gpu user annotation breakdown (#217)
Summary:
## Summary

Fixes #180. Also includes additional improvements to the earlier change #209.

1. Add a `use_gpu_annotation=` option so the feature can aggregate either GPU user annotations or CPU user annotations (see the usage sketch below).
2. Improve the visualization and API by removing options that are not needed here. For example, `kernel_type` does not make sense since user annotation names are user-provided.
3. Add unit tests.
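
A minimal usage sketch of the updated API, assuming traces have already been collected into a local directory (the path below is a placeholder):

```python
from hta.trace_analysis import TraceAnalysis

# Placeholder path; point this at a directory of collected traces.
analyzer = TraceAnalysis(trace_dir="/path/to/traces")

# Aggregate GPU-side user annotations (the default). Returns None if the
# trace contains no gpu_user_annotation events.
gpu_df = analyzer.get_gpu_user_annotation_breakdown(
    use_gpu_annotation=True, visualize=False
)

# Or aggregate the CPU-side user annotations instead.
cpu_df = analyzer.get_gpu_user_annotation_breakdown(
    use_gpu_annotation=False, visualize=False
)
```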

## Before submitting

- [y] Was this discussed/approved via a GitHub issue? (not needed for typos or doc improvements)
  - [ ] N/A
- [y] Did you write any new necessary tests?
  - [ ] N/A
- [ ] Did you make sure to update the docs?
  - [y] N/A
- [ ] Did you update the [changelog](https://github.com/facebookresearch/HolisticTraceAnalysis/blob/main/CHANGELOG.md)?
  - [y] N/A

Testplan:
## Run feature

![Screenshot 2025-02-10 at 4 23 44 PM](https://github.com/user-attachments/assets/8ed6bb84-529d-4d8b-8091-87a8bf9726fd)
![Screenshot 2025-02-10 at 4 27 23 PM](https://github.com/user-attachments/assets/24110454-e57b-42cf-ab50-041e7de40b1f)

## Unit tests

Pull Request resolved: #217

Reviewed By: fengxizhou

Differential Revision: D69430645

Pulled By: briancoutinho

fbshipit-source-id: c386d66431917fefcd3717335080ca169afec399
briancoutinho authored and facebook-github-bot committed Feb 12, 2025
1 parent 62e28ac commit 267ea8c
Showing 4 changed files with 160 additions and 157 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -20,6 +20,7 @@ Versioning](https://semver.org/spec/v2.0.0.html).
- Add nccl collective fields to parser config
- Queue length analysis: Add feature to compute time blocked on a stream hitting max queue length.
- Add `kernel_backend` to parser config for Triton / torch.compile() support.
- Add analyses features for GPU user annotation attribution at trace and kernel level.

#### Changed
- Change test data path in unittests from relative path to real path to support running test within IDEs.
Expand Down
225 changes: 91 additions & 134 deletions hta/analyzers/breakdown_analysis.py
@@ -346,17 +346,41 @@ def get_gpu_kernels_with_user_annotations(
def get_gpu_user_annotation_breakdown(
cls,
t: "Trace",
use_gpu_annotation: bool = True,
visualize: bool = True,
duration_ratio: float = 0.8,
num_kernels: int = 10,
include_memory_kernels: bool = False,
image_renderer="notebook",
) -> Tuple[pd.DataFrame, pd.DataFrame]:
num_kernels: int = 1000,
image_renderer: Optional[str] = None,
) -> Optional[pd.DataFrame]:
"""
GPU kernel breakdown implementation. See `get_gpu_kernel_breakdown` in `trace_analysis.py` for details.
Summarizes the time spent by each GPU user annotation. Outputs the following graphs:
1. Pie charts showing the most time consuming user annotations for each rank.
2. Bar graphs showing the average duration for the most time consuming user annotations for each rank.
Args:
use_gpu_annotation (boolean): Use the time on GPU for each user annotation; if False, use the time on CPU instead. Default = True.
visualize (boolean): Set to True to display the graphs. Default = True.
duration_ratio (float): Floating point value between 0 and 1 specifying the ratio of time taken
by top user annotations. Default = 0.8.
num_kernels (int): Maximum number of user annotations to show. Default = 1000. The rest are grouped into "others".
image_renderer (str): Set to ``notebook`` when using jupyter and ``jupyterlab`` when using jupyter-lab.
To see all available options execute: ``import plotly; plotly.io.renderers`` in a python shell.
Returns:
Optional[pd.DataFrame]
Returns a dataframe that shows the min, max, mean, standard deviation, total time taken by each
user annotation on each rank. This dataframe will be summarized based on values of ``duration_ratio``
and ``num_kernels``. If both ``duration_ratio`` and ``num_kernels`` are specified,
``num_kernels`` takes precedence.
If user annotations are not present on the CPU or GPU (depending on the ``use_gpu_annotation`` flag), returns None.
"""
sym_table = t.symbol_table.get_sym_table()
idx = sym_table.index("gpu_user_annotation")
annotation = "gpu_user_annotation" if use_gpu_annotation else "user_annotation"
image_renderer = image_renderer or ""

if (idx := t.symbol_table.sym_index.get(annotation, None)) is None:
logger.warning(f"Trace does not contain any {annotation}")
return None

all_kernel_df = pd.DataFrame(
{
@@ -366,81 +390,33 @@ def get_gpu_user_annotation_breakdown(
"min": pd.Series(dtype="int"),
"std": pd.Series(dtype="float"),
"mean": pd.Series(dtype="int"),
"kernel_type": pd.Series(dtype="str"),
"rank": pd.Series(dtype="int"),
}
)
kernel_type_df = pd.DataFrame(
{
"kernel_type": pd.Series(dtype="str"),
"sum": pd.Series(dtype="int"),
}
)

kernel_type_to_analysis: List[str] = [
KernelType.COMPUTATION.name,
KernelType.COMMUNICATION.name,
]
if include_memory_kernels:
kernel_type_to_analysis.append(KernelType.MEMORY.name)
kernel_per_rank: Dict[int, pd.DataFrame] = {}

kernel_per_rank: Dict[str, Dict] = defaultdict(dict)
for rank, trace_df in t.traces.items():
gpu_user_annotation_kernels = trace_df[trace_df["cat"].eq(idx)].copy()
gpu_user_annotation_kernels["kernel_type"] = gpu_user_annotation_kernels[
["name"]
].apply(lambda x: get_kernel_type(sym_table[x["name"]]), axis=1)
gpu_user_annotation_kernels["name"] = gpu_user_annotation_kernels[
"name"
].apply(lambda x: sym_table[x])
t.symbol_table.add_symbols_to_trace_df(gpu_user_annotation_kernels, "name")
logger.info(
f"rank = {rank}, num {annotation}s = {len(gpu_user_annotation_kernels)}"
)

# Create kernel type dataframe
kernel_type_df = pd.concat(
[
kernel_type_df,
cls._get_gpu_kernel_type_time(
gpu_user_annotation_kernels, kernel_type_to_analysis
),
],
ignore_index=True,
gpu_kernel_time = cls._aggr_gpu_kernel_time(
gpu_user_annotation_kernels,
duration_ratio=duration_ratio,
num_kernels=num_kernels,
)
gpu_kernel_time["rank"] = int(rank)
kernel_per_rank[rank] = gpu_kernel_time

# Create all kernel info dataframe
for kernel_type in kernel_type_to_analysis:
gpu_kernel_time = gpu_user_annotation_kernels[
gpu_user_annotation_kernels["kernel_type"] == kernel_type
]

if kernel_type not in kernel_per_rank:
kernel_per_rank[kernel_type] = {}

gpu_kernel_time = cls._aggr_gpu_kernel_time(
gpu_kernel_time,
duration_ratio=duration_ratio,
num_kernels=num_kernels,
)

kernel_per_rank[kernel_type][rank] = gpu_kernel_time

gpu_kernel_time["kernel_type"] = kernel_type
gpu_kernel_time["rank"] = int(rank)
all_kernel_df = pd.concat(
[all_kernel_df, gpu_kernel_time], ignore_index=True
)

kernel_type_df = kernel_type_df.groupby(by=["kernel_type"])["sum"].agg(["sum"])
kernel_type_df.reset_index(inplace=True)
kernel_type_df.sort_values(
by=["sum"], ignore_index=True, inplace=True, ascending=False
)
kernel_type_df["percentage"] = (
kernel_type_df["sum"] / kernel_type_df["sum"].sum()
) * 100
kernel_type_df = kernel_type_df.round({"percentage": 1})
all_kernel_df = pd.concat(
[all_kernel_df, gpu_kernel_time], ignore_index=True
)

all_kernel_df.sort_values(
by=["kernel_type", "name", "rank"], ignore_index=True, inplace=True
)
all_kernel_df.sort_values(by=["rank", "name"], ignore_index=True, inplace=True)
all_kernel_df.rename(
columns={
"sum": "sum (us)",
@@ -453,80 +429,61 @@ def get_gpu_user_annotation_breakdown(
)

if visualize: # pragma: no cover
non_zero_kernel_df = kernel_type_df[(kernel_type_df["percentage"] > 0)]

fig = px.pie(
non_zero_kernel_df,
values="percentage",
names="kernel_type",
height=500,
title="Kernel Type Percentage Across All Ranks",
specs = []
for count, rank in enumerate(kernel_per_rank):
if count % 2 == 0:
specs.append([{"type": "domain"}, {"type": "domain"}])
fig = make_subplots(
rows=int((len(kernel_per_rank) + 1) / 2),
cols=2,
specs=specs,
)
for rank in kernel_per_rank:
fig.add_trace(
go.Pie(
labels=kernel_per_rank[rank]["name"],
values=kernel_per_rank[rank]["sum"],
title=f"Rank {rank}",
automargin=False,
),
int(rank / 2) + 1,
int(rank % 2) + 1,
)
image_size_multiplier = 1 + (len(t.traces.keys())) / 2
fig.update_layout(
title_text="User annotation distribution on each rank",
margin=dict(l=50, r=50, b=50, t=50),
showlegend=True,
legend=dict(yanchor="bottom", y=-0.4, xanchor="left", x=0),
height=400 * image_size_multiplier,
legend=dict(yanchor="bottom", y=-0.1, xanchor="left", x=0),
)
fig.show(renderer=image_renderer)

for kernel in kernel_per_rank:
specs = []
for count, rank in enumerate(kernel_per_rank[kernel]):
if count % 2 == 0:
specs.append([{"type": "domain"}, {"type": "domain"}])
fig = make_subplots(
rows=int((len(kernel_per_rank[kernel]) + 1) / 2),
cols=2,
specs=specs,
kernel_name = all_kernel_df["name"].unique()
for name in kernel_name:
if name == "others":
continue
kernel_name_df = all_kernel_df[all_kernel_df["name"].eq(name)]
fig = px.bar(
kernel_name_df,
x="rank",
y="mean (us)",
title=name,
labels={
"rank": "Rank",
"mean (us)": "Mean Duration (us)",
},
error_y=kernel_name_df["max (us)"] - kernel_name_df["mean (us)"],
error_y_minus=kernel_name_df["mean (us)"]
- kernel_name_df["min (us)"],
)
for rank in kernel_per_rank[kernel]:
fig.add_trace(
go.Pie(
labels=kernel_per_rank[kernel][rank]["name"],
values=kernel_per_rank[kernel][rank]["sum"],
title=f"Rank {rank}",
automargin=False,
),
int(rank / 2) + 1,
int(rank % 2) + 1,
)
image_size_multiplier = 1 + (len(t.traces.keys())) / 2
fig.update_layout(
title_text=f'Kernel type "{kernel}" - kernel distribution on each rank',
margin=dict(l=50, r=50, b=50, t=50),
showlegend=True,
height=400 * image_size_multiplier,
legend=dict(yanchor="bottom", y=-0.1, xanchor="left", x=0),
title_text=f"User annotation = {name}",
xaxis=dict(tickmode="linear", tick0=0, dtick=1),
)
fig.show(renderer=image_renderer)

kernel_df = all_kernel_df[all_kernel_df["kernel_type"].eq(kernel)]

kernel_name = kernel_df["name"].unique()
for name in kernel_name:
if name != "others":
kernel_name_df = kernel_df[kernel_df["name"].eq(name)]
fig = px.bar(
kernel_name_df,
x="rank",
y="mean (us)",
title=name,
labels={
"rank": "Rank",
"mean (us)": "Mean Duration (us)",
},
error_y=kernel_name_df["max (us)"]
- kernel_name_df["mean (us)"],
error_y_minus=kernel_name_df["mean (us)"]
- kernel_name_df["min (us)"],
)
fig.update_layout(
title_text=f'Kernel type "{kernel}" - {name}',
xaxis=dict(tickmode="linear", tick0=0, dtick=1),
)
fig.show(renderer=image_renderer)

return kernel_type_df, all_kernel_df
return all_kernel_df

@classmethod
def _get_gpu_kernel_type_time(
Expand Down Expand Up @@ -611,7 +568,7 @@ def _aggr_gpu_kernel_time(
gpu_kernel_time = gpu_kernel_time.sort_values(
by=["sum"], ascending=False, ignore_index=True
)
gpu_kernel_time["std"].fillna(0, inplace=True)
gpu_kernel_time.fillna({"std": 0}, inplace=True)

# if there are more than num_kernels kernels, starting to aggregate kernels
if gpu_kernel_time.shape[0] > num_kernels:
@@ -628,7 +585,7 @@
["sum", "max", "min", "mean", "std"]
)
gpu_kernel_time.reset_index(inplace=True)
gpu_kernel_time["std"].fillna(0, inplace=True)
gpu_kernel_time.fillna({"std": 0}, inplace=True)

return gpu_kernel_time

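Side note on the `fillna` change above: recent pandas versions warn that a column-level in-place fill like `df["std"].fillna(0, inplace=True)` operates on a possibly-copied column and may not write back to the parent frame, while the dict form fills the column on the frame itself. A small sketch of the two idioms:

```python
import pandas as pd

df = pd.DataFrame({"name": ["a", "b"], "std": [1.5, None]})

# Old idiom: fills a column object in place; newer pandas emits a
# FutureWarning because the write-back to `df` is not guaranteed.
# df["std"].fillna(0, inplace=True)

# New idiom: a {column: fill_value} mapping applied on the frame itself.
df.fillna({"std": 0}, inplace=True)
```
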
33 changes: 16 additions & 17 deletions hta/trace_analysis.py
@@ -119,7 +119,7 @@ def get_gpu_kernel_breakdown(
duration_ratio: float = 0.8,
num_kernels: int = 10,
include_memory_kernels: bool = True,
image_renderer: str = "notebook",
image_renderer: str = "",
) -> Tuple[pd.DataFrame, pd.DataFrame]:
r"""
Summarizes the time spent by each kernel and by kernel type. Outputs the following graphs:
@@ -187,43 +187,42 @@ def get_gpu_kernels_with_user_annotations(

def get_gpu_user_annotation_breakdown(
self,
use_gpu_annotation: bool = True,
visualize: bool = True,
duration_ratio: float = 0.8,
num_kernels: int = 10,
include_memory_kernels: bool = True,
image_renderer: str = "notebook",
) -> Tuple[pd.DataFrame, pd.DataFrame]:
num_kernels: int = 1000,
image_renderer: Optional[str] = None,
) -> Optional[pd.DataFrame]:
r"""
Summarizes the time spent by each kernel and by kernel type. Outputs the following graphs:
Summarizes the time spent by each GPU user annotation. Outputs the following graphs:
1. Pie chart indicating the percentage of time taken by each kernel type.
2. Pie charts showing the most time consuming kernels for each rank for each kernel type.
3. Bar graphs showing the average duration for the most time consuming kernels for each rank and each kernel type.
1. Pie charts showing the most time consuming user annotations for each rank.
2. Bar graphs showing the average duration for the most time consuming user annotations for each rank.
Args:
use_gpu_annotation (boolean): Use the time on GPU for each user annotation; if False, use the time on CPU instead. Default = True.
visualize (boolean): Set to True to display the graphs. Default = True.
duration_ratio (float): Floating point value between 0 and 1 specifying the ratio of time taken
by top COMM/COMP/MEMORY kernels. Default = 0.8.
num_kernels (int): Maximum number of COMM/COMP/MEMORY kernels to show. Default = 10.
include_memory_kernels (bool): Whether to include MEMORY kernels in the analysis. Default = True.
by top user annotations. Default = 0.8.
num_kernels (int): Maximum number of user annotations to show. Default = 1000. The rest are grouped into "others".
image_renderer (str): Set to ``notebook`` when using jupyter and ``jupyterlab`` when using jupyter-lab.
To see all available options execute: ``import plotly; plotly.io.renderers`` in a python shell.
Returns:
Tuple[pd.DataFrame, pd.DataFrame]
Returns two dataframes. The first dataframe shows the percentage of time spent by kernel type.
The second dataframe shows the min, max, mean, standard deviation, total time taken by each
kernel on each rank. This dataframe will be summarized based on values of ``duration_ratio``
Optional[pd.DataFrame]
Returns a dataframe that shows the min, max, mean, standard deviation, total time taken by each
user annotation on each rank. This dataframe will be summarized based on values of ``duration_ratio``
and ``num_kernels``. If both ``duration_ratio`` and ``num_kernels`` are specified,
``num_kernels`` takes precedence.
If user annotations are not present on the CPU or GPU (depending on the ``use_gpu_annotation`` flag), returns None.
"""

return BreakdownAnalysis.get_gpu_user_annotation_breakdown(
self.t,
use_gpu_annotation,
visualize,
duration_ratio,
num_kernels,
include_memory_kernels,
image_renderer,
)

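For context, the per-rank pie grid in the new visualization follows the standard plotly subplot pattern: pie traces require "domain"-type subplot cells, laid out two per row. A self-contained sketch with toy data standing in for the per-rank annotation sums:

```python
import plotly.graph_objects as go
from plotly.subplots import make_subplots

# Toy stand-in for the per-rank user annotation sums.
per_rank = {
    0: {"labels": ["forward", "backward"], "values": [60, 40]},
    1: {"labels": ["forward", "backward"], "values": [55, 45]},
}

rows = (len(per_rank) + 1) // 2
# Pie charts can only be placed in "domain"-type subplot cells.
specs = [[{"type": "domain"}, {"type": "domain"}] for _ in range(rows)]
fig = make_subplots(rows=rows, cols=2, specs=specs)

for count, (rank, data) in enumerate(per_rank.items()):
    fig.add_trace(
        go.Pie(labels=data["labels"], values=data["values"], title=f"Rank {rank}"),
        row=count // 2 + 1,
        col=count % 2 + 1,
    )

fig.update_layout(title_text="User annotation distribution on each rank")
fig.show()
```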