diff --git a/ENVs.md b/ENVs.md
index bb196c68b0..67153b16bc 100644
--- a/ENVs.md
+++ b/ENVs.md
@@ -7,3 +7,7 @@
 | `FLA_USE_FAST_OPS` | `0` | `0` or `1` | Enable faster, but potentially less accurate, operations when set to `1`. |
 | `FLA_CACHE_RESULTS` | `1` | `0` or `1` | Whether to cache autotune timings to disk. Defaults to `1` (enabled). |
 | `FLA_TRIL_PRECISION` | `ieee` | `ieee`, `tf32`, `tf32x3` | Controls the precision for triangular operations. `tf32x3` is only available on NV GPUs. |
+| `FLA_CONFIG_DIR` | - | Any path | Override the default config directory for Triton kernel configs. When set, configs are loaded from `$FLA_CONFIG_DIR/{GPU}/` instead of `fla/configs/{GPU}/`. |
+| `FLA_DISABLE_CACHE` | `1` | `0` or `1` | When set to `1` (the default), cached Triton kernel configurations are not loaded and autotune runs as usual; set to `0` to enable cached config loading. Disabling the lookup is useful for debugging or when the cache may be outdated. |
+| `FLA_GPU_NAME` | - | Any string | Override the detected GPU name used for config directory naming. When set, configs are stored in `fla/configs/{FLA_GPU_NAME}/` instead of auto-detecting the name from hardware (CUDA/ROCm). Useful for custom or unsupported devices. |
+| `TRITON_CACHE_DIR` | - | Any path | Override Triton's default cache directory. When set, Triton uses this directory for its kernel compilation cache instead of `~/.triton/cache`. |
diff --git a/fla/ops/utils/cache.py b/fla/ops/utils/cache.py
new file mode 100644
index 0000000000..97c5a4914b
--- /dev/null
+++ b/fla/ops/utils/cache.py
@@ -0,0 +1,223 @@
+import json
+import os
+import warnings
+from functools import lru_cache
+from pathlib import Path
+from typing import Any
+
+import torch
+import triton
+from packaging import version
+from triton.runtime.autotuner import Autotuner
+
+TRITON_ABOVE_3_5_1 = version.parse(triton.__version__) >= version.parse("3.5.1")
+TRITON_ABOVE_3_4_0 = version.parse(triton.__version__) >= version.parse("3.4.0")
+FLA_ALWAYS_CHECK_CACHE = os.environ.get("FLA_ALWAYS_CHECK_CACHE") == "1"
+FLA_DISABLE_CACHE = os.environ.get("FLA_DISABLE_CACHE", "1") == "1"
+
+
+@lru_cache(maxsize=1)
+def get_gpu_info():
+    """Get GPU model information.
+
+    This function detects the GPU model and returns a sanitized string identifier.
+    It prioritizes the FLA_GPU_NAME environment variable if set, then detects from
+    available hardware (CUDA, ROCm, Intel GPU, or CPU).
+    """
+    # Check if the GPU name is overridden via environment variable
+    gpu_name = None
+    if "FLA_GPU_NAME" in os.environ:
+        gpu_name = os.environ["FLA_GPU_NAME"]
+    # Otherwise try to get the device name based on availability
+    elif torch.cuda.is_available():
+        # Works for both NVIDIA and AMD GPUs (ROCm)
+        gpu_name = torch.cuda.get_device_name(0)
+    elif hasattr(torch, 'xpu') and torch.xpu.is_available():
+        gpu_name = torch.xpu.get_device_name(0)
+
+    if gpu_name:
+        return gpu_name.replace(" ", "_").replace("(", "_").replace(")", "_").replace("-", "_")
+
+    # Default to CPU if no GPU is available
+    return "cpu"
+
+
+def get_fla_config_dir() -> Path:
+    """Get FLA's configs directory.
+
+    The directory can be overridden by setting the FLA_CONFIG_DIR environment variable.
+    If set, configs will be loaded from $FLA_CONFIG_DIR/{GPU}/ instead of the default
+    fla/configs/{GPU}/ in the project.
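+
+    Example (illustrative; the path and GPU name are hypothetical, and it assumes
+    get_gpu_info() has not been called yet since its result is cached)::
+
+        >>> os.environ["FLA_CONFIG_DIR"] = "/tmp/fla_configs"
+        >>> os.environ["FLA_GPU_NAME"] = "NVIDIA H100"
+        >>> get_fla_config_dir()
+        PosixPath('/tmp/fla_configs/NVIDIA_H100')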
+ """ + # Check if custom config dir is set via environment variable + if "FLA_CONFIG_DIR" in os.environ: + base_dir = Path(os.environ["FLA_CONFIG_DIR"]) + else: + # Default: project_dir/fla/configs/ + project_dir = Path(__file__).parent.parent.parent + base_dir = project_dir / "configs" + + gpu_name = get_gpu_info() + config_dir = base_dir / gpu_name + return config_dir + + +def load_cached_config(kernel_name: str) -> dict[str, Any] | None: + """ + Load cached best config for a kernel from FLA configs directory. + + This function loads the cached best configuration for a given kernel name + from fla/configs/{GPU}/{kernel_name}.json. The file should contain only the + best_config dictionary. + + If the config file is not found or cannot be loaded, a warning is printed + and None is returned, allowing fallback to Triton's autotune. + + The lookup can be disabled by setting the FLA_DISABLE_CACHE environment variable. + + Args: + kernel_name: Name of the kernel (e.g., "causal_conv1d_fwd_kernel") + + Returns: + Best config dictionary or None if not found or disabled + """ + # Check if cache is disabled via environment variable + if os.environ.get("FLA_DISABLE_CACHE") == "1": + return None + + config_dir = get_fla_config_dir() + config_file = config_dir / f"{kernel_name}.json" + + if not config_file.exists(): + return None + + try: + with open(config_file) as f: + config = json.load(f) + return config + except Exception as e: + warnings.warn(f"Error reading config file {config_file}: {e}") + return None + + +class CachedAutotuner(Autotuner): + """ + A modified autotuner that loads best config from FLA's config directory. + + This class extends Triton's Autotuner but overrides the run method to + try loading cached configuration first before falling back to autotune. + """ + + def __init__(self, fn, arg_names, configs, key, reset_to_zero, restore_value, **kwargs): + super().__init__(fn, arg_names, configs, key, reset_to_zero, restore_value, **kwargs) + self.kernel_name = fn.fn.__name__ if hasattr(fn, 'fn') else fn.__name__ + self._fla_cache_checked = bool(FLA_DISABLE_CACHE) + + def run(self, *args, **kwargs): + if not self._fla_cache_checked: + self.first_run_hook() + self._fla_cache_checked = bool(not FLA_ALWAYS_CHECK_CACHE) + return super().run(*args, **kwargs) + + def first_run_hook(self): + best_config = load_cached_config(self.kernel_name) + + if best_config is not None: + kw = best_config.get("kwargs", {}) + num_warps = best_config.get("num_warps", 4) + num_stages = best_config.get("num_stages", 2) + + if TRITON_ABOVE_3_5_1: + cfg = triton.Config( + kw, + num_warps=num_warps, + num_stages=num_stages, + num_ctas=best_config.get("num_ctas", 1), + maxnreg=best_config.get("maxnreg", None), + pre_hook=None, + ir_override=best_config.get("ir_override", None), + ) + else: + cfg = triton.Config( + kw, + num_warps=num_warps, + num_stages=num_stages, + ) + + self.configs = [cfg] + else: + # No cached config found. + warnings.warn( + f"No cached config found for kernel '{self.kernel_name}', " + "falling back to Triton autotune", + stacklevel=2 + ) + + +def cache_autotune(configs, key=None, prune_configs_by=None, reset_to_zero=None, restore_value=None, + pre_hook=None, post_hook=None, warmup=None, rep=None, use_cuda_graph=False, + do_bench=None, cache_results=False): + """ + Decorator for auto-tuning a :code:`triton.jit`'d function with FLA config support. + + This decorator extends Triton's autotune to support loading best configurations + from FLA's config directory (fla/configs/{GPU}/). 
+    It searches for cached configs by kernel name from files named {kernel_name}.json.
+
+    If a cached best config is found, it will be used directly and autotuning will be
+    skipped. If no cache is found, a warning is issued and the decorator falls back
+    to normal autotuning.
+
+    .. highlight:: python
+    .. code-block:: python
+
+        @fla.autotune(configs=[
+            triton.Config(kwargs={'BLOCK_SIZE': 128}, num_warps=4),
+            triton.Config(kwargs={'BLOCK_SIZE': 1024}, num_warps=8),
+          ],
+          key=['x_size']  # key is used for fallback autotune
+        )
+        @triton.jit
+        def kernel(x_ptr, x_size, BLOCK_SIZE: tl.constexpr):
+            ...
+
+    :param configs: a list of :code:`triton.Config` objects (used for fallback autotune)
+    :type configs: list[triton.Config]
+    :param key: a list of argument names whose change in value will trigger autotune
+    :type key: list[str] or None
+    :param prune_configs_by: a dict of functions that are used to prune configs
+    :param reset_to_zero: a list of argument names whose value will be reset to zero
+    :type reset_to_zero: list[str]
+    :param restore_value: a list of argument names whose value will be restored after running
+    :type restore_value: list[str]
+    :param pre_hook: a function to call before the kernel
+    :type pre_hook: lambda args, reset_only
+    :param post_hook: a function to call after the kernel
+    :type post_hook: lambda args, exception
+    :param warmup: warmup time for benchmarking (deprecated)
+    :type warmup: int
+    :param rep: repetition time for benchmarking (deprecated)
+    :type rep: int
+    :param do_bench: a benchmark function
+    :type do_bench: lambda fn, quantiles
+    :param cache_results: whether to cache autotune timings to disk (passed to Triton)
+    :type cache_results: bool
+    """
+    # key can be None when we want to use the cache only (no fallback autotune)
+    if key is None:
+        key = []
+
+    def decorator(fn):
+        kwargs = {}
+        if TRITON_ABOVE_3_4_0:
+            kwargs = {"cache_results": cache_results}
+
+        return CachedAutotuner(fn, fn.arg_names, configs, key, reset_to_zero, restore_value,
+                               pre_hook=pre_hook, post_hook=post_hook,
+                               prune_configs_by=prune_configs_by, warmup=warmup, rep=rep,
+                               use_cuda_graph=use_cuda_graph, do_bench=do_bench,
+                               **kwargs,
+                               )
+
+    return decorator
diff --git a/scripts/demo_extract_configs.py b/scripts/demo_extract_configs.py
new file mode 100644
index 0000000000..6c456e4b53
--- /dev/null
+++ b/scripts/demo_extract_configs.py
@@ -0,0 +1,399 @@
+#!/usr/bin/env python
+"""
+Extract best configs from Triton's autotune cache.
+
+This script searches Triton's cache directory (~/.triton/cache) for .autotune.json files,
+extracts the best configuration for each kernel, and saves them to a human-readable format.
+
+Usage:
+    python scripts/demo_extract_configs.py [--output-dir DIR]
+
+The output files are saved to fla/configs/{GPU}/{kernel_name}.json.
+Each file contains only the best_config for direct cache lookup.
+"""
+
+import argparse
+import json
+import os
+import shutil
+import sys
+from pathlib import Path
+from typing import Any
+
+os.environ['FLA_DISABLE_CACHE'] = '1'
+
+
+def get_gpu_info():
+    """Get GPU model information.
+
+    This function detects the GPU model and returns a sanitized string identifier.
+    It prioritizes the FLA_GPU_NAME environment variable if set, then detects from
+    available hardware (CUDA, ROCm, or CPU).
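+
+    Example (illustrative): a detected device named "NVIDIA GeForce RTX 4090"
+    is sanitized to "NVIDIA_GeForce_RTX_4090".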
+ """ + import torch + + # Check if GPU name is overridden via environment variable + if "FLA_GPU_NAME" in os.environ: + gpu_name = os.environ["FLA_GPU_NAME"] + return gpu_name.replace(" ", "_").replace("(", "_").replace(")", "_").replace("-", "_") + + # Try to get device name based on availability + if torch.cuda.is_available(): + # Works for both NVIDIA and AMD GPUs (ROCm) + gpu_name = torch.cuda.get_device_name(0) + gpu_name = gpu_name.replace(" ", "_").replace("(", "_").replace(")", "_").replace("-", "_") + return gpu_name + + # Default to CPU if no GPU available + return "cpu" + + +def get_triton_cache_dir() -> Path: + """Get Triton's cache directory.""" + cache_dir = os.environ.get("TRITON_CACHE_DIR", "~/.triton") + return Path(cache_dir).expanduser() / "cache" + + +def get_fla_config_dir() -> Path: + """Get FLA's configs directory. + + The directory can be overridden by setting the FLA_CONFIG_DIR environment variable. + If set, configs will be saved to $FLA_CONFIG_DIR/{GPU}/ instead of the default + fla/configs/{GPU}/ in the project. + """ + # Check if custom config dir is set via environment variable + if "FLA_CONFIG_DIR" in os.environ: + base_dir = Path(os.environ["FLA_CONFIG_DIR"]) + else: + # Default: project_dir/fla/configs/ + project_dir = Path(__file__).parent.parent + base_dir = project_dir / "fla" / "configs" + + gpu_name = get_gpu_info() + config_dir = base_dir / gpu_name + config_dir.mkdir(parents=True, exist_ok=True) + return config_dir + + +def process_autotune_file(autotune_file: Path) -> dict[str, Any]: + """ + Process a single Triton autotune.json file and extract best config. + + Returns: + Dictionary with kernel info and best config, or None if invalid + """ + try: + with open(autotune_file) as f: + data = json.load(f) + + if not isinstance(data, dict) or "configs_timings" not in data: + return None + + # Find the best config (minimum timing) + configs_timings = data["configs_timings"] + if not configs_timings: + return None + + # configs_timings is a list of [config_dict, timing] + best_entry = min(configs_timings, key=lambda x: x[1] if isinstance(x[1], (int, float)) else x[1][0]) + best_config_dict = best_entry[0] + best_timing = best_entry[1] + + # Extract kernel info from the file path or content + # Example path: ~/.triton/cache/a1b2c3d4/fused_recurrent_fwd_jit_functionn_12345.autotune.json + parts = autotune_file.stem.split('.') + if len(parts) >= 2: + kernel_name = parts[0] + else: + kernel_name = "unknown_kernel" + + # Build output data structure + result = { + "kernel_name": kernel_name, + "source_file": str(autotune_file), + "cache_key": data.get("key", "unknown"), + "timestamp": data.get("timestamp", 0), + "best_config": best_config_dict, + "best_timing": best_timing, + "total_configs_tested": len(configs_timings), + } + + return result + + except Exception as e: + print(f"Error processing {autotune_file}: {e}") + return None + + +def sync_hopper_configs(updated_dir: Path): + """ + Sync Hopper architecture GPU configs (NVIDIA_H100, NVIDIA_H200, NVIDIA_H20). + + When one of these GPU configs is updated, copy its contents to the other directories. + Creates target directories if they don't exist. 
+
+    Args:
+        updated_dir: The directory that was just updated (source directory)
+    """
+    # Define the GPU group that shares the same configs (with the NVIDIA_ prefix)
+    hopper_gpus = ["NVIDIA_H100", "NVIDIA_H800", "NVIDIA_H20"]
+
+    # Get the parent directory (e.g., fla/configs)
+    parent_dir = updated_dir.parent
+
+    # Check if the updated_dir name matches any of the Hopper GPUs
+    updated_gpu = updated_dir.name
+    if updated_gpu not in hopper_gpus:
+        return  # Not a Hopper GPU, skip sync
+
+    print(f"\nSyncing Hopper configs from {updated_gpu} to other GPUs: {hopper_gpus}")
+    print("-" * 60)
+
+    # Copy to the other Hopper GPU directories
+    sync_count = 0
+    for target_gpu in hopper_gpus:
+        if target_gpu == updated_gpu:
+            continue  # Skip self
+
+        target_dir = parent_dir / target_gpu
+
+        try:
+            # Create the target directory if it doesn't exist
+            target_dir.mkdir(parents=True, exist_ok=True)
+
+            # Copy all files from updated_dir to target_dir
+            for config_file in updated_dir.glob("*.json"):
+                if config_file.is_file():
+                    shutil.copy2(config_file, target_dir / config_file.name)
+
+            sync_count += 1
+            print(f"  ✓ Synced to {target_gpu} ({target_dir})")
+
+        except Exception as e:
+            print(f"  ✗ Failed to sync to {target_gpu}: {e}")
+
+    print(f"\nSuccessfully synced configs to {sync_count} GPU directories")
+    print("=" * 60)
+
+
+def extract_configs(triton_cache_dir: Path, output_dir: Path):
+    """
+    Extract all autotune configs from the Triton cache.
+
+    Args:
+        triton_cache_dir: Triton's cache directory
+        output_dir: Output directory for extracted configs
+    """
+    if not triton_cache_dir.exists():
+        print(f"Triton cache directory not found: {triton_cache_dir}")
+        return
+
+    # Find all .autotune.json files
+    autotune_files = list(triton_cache_dir.rglob("*.autotune.json"))
+
+    if not autotune_files:
+        print(f"No .autotune.json files found in {triton_cache_dir}")
+        return
+
+    print(f"Found {len(autotune_files)} autotune cache files")
+    print(f"GPU: {get_gpu_info()}")
+    print(f"Output directory: {output_dir}")
+    print("-" * 60)
+
+    # Process each file
+    exported_count = 0
+    for autotune_file in autotune_files:
+        result = process_autotune_file(autotune_file)
+        if result is None:
+            continue
+
+        # Save to the output directory
+        kernel_name = result["kernel_name"]
+        output_file = output_dir / f"{kernel_name}.json"
+
+        try:
+            with open(output_file, 'w') as f:
+                # Only save the best_config for direct lookup
+                output_data = result["best_config"]
+                json.dump(output_data, f, indent=2)
+
+            exported_count += 1
+
+            print(f"\n[{exported_count}] {kernel_name}")
+            print(f"  Source: {autotune_file}")
+            print(f"  Output: {output_file}")
+            print(f"  Best config: {result['best_config']}")
+            print(f"  Timing: {result['best_timing']}")
+
+        except Exception as e:
+            print(f"Error saving {output_file}: {e}")
+
+    print("\n" + "=" * 60)
+    print(f"Successfully exported {exported_count} configs to {output_dir}")
+    print("=" * 60)
+    sync_hopper_configs(output_dir)
+
+    # Remove the Triton cache directory that was just processed
+    try:
+        shutil.rmtree(triton_cache_dir)
+        print(f"Removed Triton cache directory: {triton_cache_dir}")
+    except Exception as e:
+        print(f"Warning: Failed to remove {triton_cache_dir}: {e}")
+
+
+def main():
+    parser = argparse.ArgumentParser(description='Extract Triton autotune configs')
+    parser.add_argument(
+        '--output-dir', '-o',
+        type=str,
+        help='Output directory (default: fla/configs/{GPU})'
+    )
+    parser.add_argument(
+        '--triton-cache-dir',
+        type=str,
+        help='Triton cache directory (default: ~/.triton/cache)'
+    )
+    parser.add_argument(
+        '--list-only', '-l',
+        action='store_true',
+        help='Only list the cache files without extracting'
+    )
+    parser.add_argument(
+        '--generate-cache', '-g',
+        action='store_true',
+        help='Generate new cache with custom temporary directory'
+    )
+
+    args = parser.parse_args()
+
+    # Determine directories
+    if args.generate_cache:
+        # Generate cache with temporary directory
+        triton_cache_dir = Path(generate_triton_cache())
+    else:
+        triton_cache_dir = Path(args.triton_cache_dir) if args.triton_cache_dir else get_triton_cache_dir()
+
+    output_dir = Path(args.output_dir) if args.output_dir else get_fla_config_dir()
+
+    if args.list_only:
+        # Just list the files
+        if not triton_cache_dir.exists():
+            print(f"Triton cache directory not found: {triton_cache_dir}")
+            return
+
+        autotune_files = list(triton_cache_dir.rglob("*.autotune.json"))
+        print(f"Found {len(autotune_files)} .autotune.json files in {triton_cache_dir}:\n")
+
+        for i, file in enumerate(autotune_files, 1):
+            print(f"{i}. {file}")
+        return
+
+    # Extract configs
+    extract_configs(triton_cache_dir, output_dir)
+
+
+def generate_triton_cache():
+    """Generate Triton cache with custom directory."""
+    import torch
+
+    from fla.ops.kda import chunk_kda
+    from fla.utils import device
+
+    # Create a custom directory in the project for Triton cache
+    project_dir = Path(__file__).parent.parent
+    custom_cache_dir = project_dir / "tmp_triton_cache"
+    cache_path = custom_cache_dir / "triton" / "cache"
+
+    # Clear and create the directory
+    if custom_cache_dir.exists():
+        shutil.rmtree(custom_cache_dir)
+    cache_path.mkdir(parents=True, exist_ok=True)
+
+    print(f"Using custom Triton cache directory: {cache_path}")
+    os.environ["TRITON_CACHE_DIR"] = str(cache_path)
+
+    # Generate cache by running the kernels
+    torch.manual_seed(42)
+    dtype = torch.bfloat16
+    # Just for DEMO.
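+    # The calls below simply exercise chunk_kda (fwd/bwd, with and without the in-kernel
+    # gate) and causal_conv1d so that Triton autotunes them and writes .autotune.json
+    # files into the custom cache directory set above. The shapes are illustrative.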
+    B, T, H, D = 1, 8192, 32, 128
+
+    q = torch.rand(B, T, H, D, dtype=dtype)
+    k = torch.rand(B, T, H, D, dtype=dtype)
+    v = torch.rand(B, T, H, D, dtype=dtype)
+    g = torch.randn(B, T, H, D, dtype=dtype)
+    A_log = torch.randn(H, dtype=torch.float)
+    dt_bias = torch.randn(H * D, dtype=torch.float)
+    beta = torch.randn(B, T, H, dtype=dtype).sigmoid()
+    h0 = torch.randn(B, H, D, D, dtype=torch.float32)
+    A_log, dt_bias = map(lambda x: x.to(device).requires_grad_(True), (A_log, dt_bias))
+    q, k, v, g, beta, h0 = map(lambda x: x.to(device).requires_grad_(True), (q, k, v, g, beta, h0))
+
+    do = torch.randn_like(v)
+    dht = torch.randn_like(h0)
+
+    tri, tri_ht = chunk_kda(
+        q=q.clone(),
+        k=k.clone(),
+        v=v.clone(),
+        g=g.clone().float(),
+        beta=beta.clone(),
+        A_log=A_log.clone(),
+        dt_bias=dt_bias.clone(),
+        scale=None,
+        initial_state=h0.clone(),
+        output_final_state=True,
+        use_qk_l2norm_in_kernel=True,
+        use_gate_in_kernel=False,
+        safe_gate=True,
+        lower_bound=-5,
+    )
+    ((tri * do).sum() + (tri_ht * dht).sum()).backward(retain_graph=True)
+    tri0, tri_ht0 = chunk_kda(
+        q=q.clone(),
+        k=k.clone(),
+        v=v.clone(),
+        g=g.clone(),
+        beta=beta.clone(),
+        A_log=A_log.clone(),
+        dt_bias=dt_bias.clone(),
+        scale=None,
+        initial_state=h0.clone(),
+        output_final_state=True,
+        use_qk_l2norm_in_kernel=True,
+        use_gate_in_kernel=True,
+        safe_gate=True,
+        lower_bound=-5,
+    )
+    ((tri0 * do).sum() + (tri_ht0 * dht).sum()).backward()
+
+    W = 4
+    x = torch.randn(B, T, H*D).to(device, dtype).requires_grad_(True)
+    weight = torch.randn(H*D, W).to(device, dtype).requires_grad_(True)
+    bias = None
+
+    dy = torch.randn(B, T, H*D).to(device, dtype)
+
+    from fla.modules.convolution import causal_conv1d
+    tri, _ = causal_conv1d(x, weight, bias, residual=None, activation="silu")
+    tri.backward(dy)
+
+    return str(cache_path)
+
+
+if __name__ == "__main__":
+    # Check if we should extract to fla/configs
+    if "--extract-to-fla-configs" in sys.argv:
+        # Generate cache with temporary directory
+        triton_cache_dir = Path(generate_triton_cache())
+
+        # Compute output directory in fla/configs (relative to project root)
+        project_dir = Path(__file__).parent.parent
+        gpu_name = get_gpu_info()
+        output_dir = project_dir / "fla" / "configs" / gpu_name
+        output_dir.mkdir(parents=True, exist_ok=True)
+
+        print(f"\nExtracting configs to: {output_dir}")
+        extract_configs(triton_cache_dir, output_dir)
+    else:
+        main()
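+
+
+# Example invocations (illustrative; the output path is hypothetical):
+#   python scripts/demo_extract_configs.py --list-only
+#   python scripts/demo_extract_configs.py --generate-cache --output-dir /tmp/fla_configs
+#   python scripts/demo_extract_configs.py --extract-to-fla-configs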