diff --git a/fastdeploy/utils.py b/fastdeploy/utils.py index 2796b931c2f..3fd12555cef 100644 --- a/fastdeploy/utils.py +++ b/fastdeploy/utils.py @@ -1285,3 +1285,61 @@ def decorator(func): register_op = do_nothing register_custom_python_op = register_op + +from functools import wraps + +try: + import nvtx +except ImportError: + nvtx = None + + +def nvtx_annotate(message, color="blue"): + """A decorator to add NVTX annotations for profiling.""" + + def decorator(func): + if nvtx is None: + return func + + @wraps(func) + def wrapper(*args, **kwargs): + with nvtx.annotate(message, color=color): + return func(*args, **kwargs) + + return wrapper + + return decorator + + +def nvtx_class_annotate(color="blue"): + """ + A class decorator that automatically adds NVTX annotations to all public + methods of a class. The annotation message will be in the format + 'ClassName.method_name'. + """ + + def class_decorator(cls): + if nvtx is None: + return cls + + class_name = cls.__name__ + # Iterate over all attributes of the class + for attr_name in dir(cls): + # Filter out private (starting with '_') and special methods + if attr_name.startswith("_"): + continue + attr = getattr(cls, attr_name) + # Check if it is a callable method + if callable(attr): + # Generate annotation message + message = f"{class_name}.{attr_name}" + + # Apply NVTX decorator + decorated_method = nvtx_annotate(message, color=color)(attr) + + # Set the decorated method back to the class + setattr(cls, attr_name, decorated_method) + + return cls + + return class_decorator