Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 58 additions & 0 deletions fastdeploy/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1285,3 +1285,61 @@ def decorator(func):
register_op = do_nothing

register_custom_python_op = register_op

from functools import wraps

try:
import nvtx
except ImportError:
nvtx = None


def nvtx_annotate(message, color="blue"):
"""A decorator to add NVTX annotations for profiling."""

def decorator(func):
if nvtx is None:
return func

@wraps(func)
def wrapper(*args, **kwargs):
with nvtx.annotate(message, color=color):
return func(*args, **kwargs)

return wrapper

return decorator


def nvtx_class_annotate(color="blue"):
"""
A class decorator that automatically adds NVTX annotations to all public
methods of a class. The annotation message will be in the format
'ClassName.method_name'.
"""

def class_decorator(cls):
if nvtx is None:
return cls

class_name = cls.__name__
# Iterate over all attributes of the class
for attr_name in dir(cls):
# Filter out private (starting with '_') and special methods
if attr_name.startswith("_"):
continue
attr = getattr(cls, attr_name)
# Check if it is a callable method
if callable(attr):
# Generate annotation message
message = f"{class_name}.{attr_name}"

# Apply NVTX decorator
decorated_method = nvtx_annotate(message, color=color)(attr)

# Set the decorated method back to the class
setattr(cls, attr_name, decorated_method)

return cls

return class_decorator