20 changes: 5 additions & 15 deletions .gitignore
@@ -1,16 +1,6 @@
-.DS_Store
-.vscode
-__pycache__
-
-# build results
-build/
-dist/
-*.egg-info
-*.so
-
-# checkpoints
+venv/
+__pycache__/
+.vscode/
 checkpoints/
-
-# outputs
-output*/
-*.mp4
+output/
+*.egg-info/
172 changes: 172 additions & 0 deletions README_AMD_WINDOWS.md
@@ -0,0 +1,172 @@
# TurboDiffusion - AMD ROCm on Windows Setup Guide

This guide explains how to build and run TurboDiffusion on Windows with AMD GPUs using ROCm.

> **Note:** These steps should also work on Linux with minor modifications (use bash commands instead of PowerShell, `source venv/bin/activate` instead of `.\venv\Scripts\Activate.ps1`, and skip the Visual Studio environment setup). However, Linux support has not been tested yet and may have issues.

## Supported Hardware

TurboDiffusion on Windows has been tested with RDNA3/RDNA3.5 GPUs (gfx1100, gfx1101, gfx1102, gfx1151).
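
If you are unsure which gfx architecture your GPU reports, you can query it once PyTorch is installed (see step 1 below). This check is not part of the original guide; it assumes the `gcnArchName` property that ROCm builds of PyTorch expose:

```powershell
python -c "import torch; print(torch.cuda.get_device_properties(0).gcnArchName)"
```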

## Prerequisites

- Windows 10/11
- Python 3.11, 3.12, or 3.13
- Visual Studio 2022 with C++ build tools
- AMD Software: Adrenalin Edition driver (latest version recommended)

## Installation

### 1. Install ROCm and PyTorch from TheRock

Follow the instructions at [ROCm/TheRock RELEASES.md](https://github.com/ROCm/TheRock/blob/main/RELEASES.md) to install ROCm and PyTorch wheels for your GPU architecture.

#### Create a Virtual Environment

```powershell
python -m venv venv
.\venv\Scripts\Activate.ps1
```

#### Install PyTorch (includes ROCm SDK as dependency)

For **gfx1151** (AMD Strix Halo iGPU):
```powershell
pip install --index-url https://rocm.nightlies.amd.com/v2/gfx1151/ --pre torch torchaudio torchvision
```

For **gfx110X** (RX 7900 XTX, RX 7800 XT, RX 7700S, Radeon 780M):
```powershell
pip install --index-url https://rocm.nightlies.amd.com/v2/gfx110X-all/ --pre torch torchaudio torchvision
```

For **gfx120X** (RX 9060, RX 9070):
```powershell
pip install --index-url https://rocm.nightlies.amd.com/v2/gfx120X-all/ --pre torch torchaudio torchvision
```

#### Initialize ROCm SDK

```powershell
rocm-sdk init
```
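
To verify that the ROCm-enabled PyTorch build can see your GPU, a quick sanity check (`torch.version.hip` is only set on ROCm builds of PyTorch):

```powershell
python -c "import torch; print(torch.version.hip, torch.cuda.is_available(), torch.cuda.get_device_name(0))"
```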

#### Install Triton with AMD Windows Support

```powershell
pip install triton-windows
```
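
As an optional sanity check that Triton can compile and launch kernels on your GPU, you can run a minimal vector-add (a sketch, not part of TurboDiffusion; note that ROCm devices appear under the `cuda` device type in PyTorch):

```powershell
@'
import torch, triton
import triton.language as tl

@triton.jit
def add_kernel(x_ptr, y_ptr, out_ptr, n, BLOCK: tl.constexpr):
    # each program instance handles one BLOCK-sized chunk of the vectors
    pid = tl.program_id(axis=0)
    offs = pid * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n
    x = tl.load(x_ptr + offs, mask=mask)
    y = tl.load(y_ptr + offs, mask=mask)
    tl.store(out_ptr + offs, x + y, mask=mask)

n = 4096
x = torch.rand(n, device="cuda")  # "cuda" maps to the ROCm device on AMD builds
y = torch.rand(n, device="cuda")
out = torch.empty_like(x)
add_kernel[(triton.cdiv(n, 1024),)](x, y, out, n, BLOCK=1024)
assert torch.allclose(out, x + y)
print("Triton OK")
'@ | python -
```

The first run compiles the kernel (see Known Issues below), so expect a short delay before `Triton OK` prints.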

### 2. Set Environment Variables

Open a PowerShell terminal and run:

```powershell
# Activate Visual Studio environment
cmd /c '"C:\Program Files\Microsoft Visual Studio\2022\Community\VC\Auxiliary\Build\vcvars64.bat" >nul 2>&1 && set' | ForEach-Object { if ($_ -match '^([^=]+)=(.*)$') { [System.Environment]::SetEnvironmentVariable($matches[1], $matches[2], 'Process') } }

# Activate the virtual environment
.\venv\Scripts\Activate.ps1

# Set ROCm paths using rocm-sdk
$ROCM_ROOT = (rocm-sdk path --root).Trim()
$ROCM_BIN = (rocm-sdk path --bin).Trim()
$env:ROCM_HOME = $ROCM_ROOT
$env:PATH = "$ROCM_ROOT\lib\llvm\bin;$ROCM_BIN;$env:PATH"

# Set compiler and build settings
$env:CC = "clang-cl"
$env:CXX = "clang-cl"
$env:DISTUTILS_USE_SDK = "1"

# Enable experimental features
$env:FLASH_ATTENTION_TRITON_AMD_ENABLE = "TRUE"
$env:TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL = "1"

# Set PYTHONPATH for TurboDiffusion
$env:PYTHONPATH = "turbodiffusion"
```
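
Before building, you can confirm the toolchain is reachable (a quick check; the PATH notes in the table below say these binaries are required):

```powershell
Get-Command clang-cl, hipcc, lld-link | Format-Table Name, Source
```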

### 3. Build and Install TurboDiffusion

```powershell
cd <path_to_turbodiffusion>
pip install --no-build-isolation -e .
```
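
If the build succeeds, the compiled extension should import cleanly (`turbo_diffusion_ops` is the extension name defined in this PR's `setup.py`; `torch` is imported first so its DLLs are loaded):

```powershell
python -c "import torch; import turbo_diffusion_ops; print('turbo_diffusion_ops loaded')"
```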

### 4. Install SpargeAttn (Optional, for sparse attention)

If you want to use sparse attention with TurboDiffusion, clone and install the AMD Windows fork:

```powershell
git clone --branch jam/amd_windows https://github.com/jammm/SpargeAttn.git
cd SpargeAttn
pip install --no-build-isolation -v .
```

## Running Inference

### Text-to-Video with Wan2.1

```powershell
# Make sure environment variables are set (see step 2)

python turbodiffusion/inference/wan2.1_t2v_infer.py `
--model Wan2.1-1.3B `
--dit_path checkpoints/TurboWan2.1-T2V-1.3B-480P-quant.pth `
--resolution 480p `
--prompt "A stylish woman walks down a Tokyo street filled with warm glowing neon and animated city signage." `
--num_samples 1 `
--num_steps 4 `
--quant_linear `
--attention_type sagesla `
--sla_topk 0.1
```

### Available Attention Types

- `sdpa` - PyTorch Scaled Dot Product Attention
- `sagesla` - SageAttention with Sparse Linear Attention (requires SpargeAttn)
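
If SpargeAttn is not installed, the example above can be run with plain SDPA. This variant is a sketch (not from the original guide): it swaps `--attention_type` to `sdpa` and drops `--sla_topk`, which appears to apply only to `sagesla`:

```powershell
python turbodiffusion/inference/wan2.1_t2v_infer.py `
    --model Wan2.1-1.3B `
    --dit_path checkpoints/TurboWan2.1-T2V-1.3B-480P-quant.pth `
    --resolution 480p `
    --prompt "A stylish woman walks down a Tokyo street filled with warm glowing neon and animated city signage." `
    --num_samples 1 `
    --num_steps 4 `
    --quant_linear `
    --attention_type sdpa
```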

## Environment Variable Summary

| Variable | Value | Description |
|----------|-------|-------------|
| `CC` | `clang-cl` | C compiler |
| `CXX` | `clang-cl` | C++ compiler |
| `DISTUTILS_USE_SDK` | `1` | Use the MSVC environment set by vcvars64 instead of auto-detecting it |
| `ROCM_HOME` | `<rocm-sdk path --root>` | ROCm SDK root path |
| `PATH` | Include LLVM and ROCm bin | Required for hipcc, clang, lld-link |
| `FLASH_ATTENTION_TRITON_AMD_ENABLE` | `TRUE` | Enable Triton Flash Attention on AMD |
| `TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL` | `1` | Enable experimental aotriton kernels |
| `PYTHONPATH` | `turbodiffusion` | Include turbodiffusion module |

## Known Issues

1. **Triton compiler warnings** - You may see `clang-cl: warning: unknown argument ignored` warnings during first run. These are harmless.

2. **First run is slow** - Triton and MIOpen kernels are compiled on first use and cached. Subsequent runs will be faster.

3. **No FP8 support on RDNA3** - RDNA3 GPUs don't support FP8, so FP16/BF16 kernels are used.

## Troubleshooting

### "LoadLibrary failed" or "cannot find amdhip64.dll"

Make sure you ran `rocm-sdk init` after installing the ROCm SDK packages.

### "LINK : fatal error LNK1104: cannot open file 'python312.lib'"

Ensure Visual Studio environment is activated before building:
```powershell
cmd /c '"C:\Program Files\Microsoft Visual Studio\2022\Community\VC\Auxiliary\Build\vcvars64.bat" >nul 2>&1 && set' | ForEach-Object { if ($_ -match '^([^=]+)=(.*)$') { [System.Environment]::SetEnvironmentVariable($matches[1], $matches[2], 'Process') } }
```

### "PermissionError" when compiling Triton kernels

This is a known Windows issue with temp file handling. Make sure you're using the latest `triton-windows` package (`pip install --upgrade triton-windows`).

### "flash_attn is not installed" warning

This warning is expected. The upstream `flash-attn` package does not provide AMD Windows builds; a Triton-based attention implementation is used instead when `FLASH_ATTENTION_TRITON_AMD_ENABLE=TRUE` is set.

Binary file added build_ext_log.txt
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -20,7 +20,7 @@ classifiers = [
 dependencies = [
     "torch>=2.7.0",
     "torchvision",
-    "triton>=3.3.0",
+    "triton-windows>=3.3.0",
     "flash-attn",
     "einops",
     "numpy",
1 change: 1 addition & 0 deletions rocwmma_lib
Submodule rocwmma_lib added at c360d5
155 changes: 106 additions & 49 deletions setup.py
@@ -1,7 +1,7 @@
"""
Copyright (c) 2025 by TurboDiffusion team.

Licensed under the Apache License, Version 2.0 (the "License");
Licensed under the Apache License, Version 2.0 (the "License")

Citation (please cite if you use this code):

@@ -16,60 +16,117 @@
 from pathlib import Path
 from setuptools import setup, find_packages
 from torch.utils.cpp_extension import BuildExtension, CUDAExtension
+import os
+import sys
+
+import torch
+
+is_rocm = torch.version.hip is not None
+
+# On Windows, deduplicate INCLUDE/LIB/LIBPATH to avoid "command line too long" errors
+if sys.platform == 'win32':
+    for var in ['INCLUDE', 'LIB', 'LIBPATH']:
+        val = os.environ.get(var, '')
+        if val:
+            unique = []
+            seen = set()
+            for p in val.split(';'):
+                if p.lower() not in seen and p:
+                    seen.add(p.lower())
+                    unique.append(p)
+            os.environ[var] = ';'.join(unique)
 
 ops_dir = Path(__file__).parent / "turbodiffusion" / "ops"
 cutlass_dir = ops_dir / "cutlass"
+rocwmma_dir = Path(__file__).parent / "rocwmma_lib" / "projects" / "rocwmma" / "library" / "include"
 
-nvcc_flags = [
-    "-O3",
-    "-std=c++17",
-    "-U__CUDA_NO_HALF_OPERATORS__",
-    "-U__CUDA_NO_HALF_CONVERSIONS__",
-    "-U__CUDA_NO_BFLOAT16_OPERATORS__",
-    "-U__CUDA_NO_BFLOAT16_CONVERSIONS__",
-    "-U__CUDA_NO_BFLOAT162_OPERATORS__",
-    "-U__CUDA_NO_BFLOAT162_CONVERSIONS__",
-    "--expt-relaxed-constexpr",
-    "--expt-extended-lambda",
-    "--use_fast_math",
-    "--ptxas-options=--verbose,--warn-on-local-memory-usage",
-    "-lineinfo",
-    "-DCUTLASS_DEBUG_TRACE_LEVEL=0",
-    "-DNDEBUG",
-    "-Xcompiler",
-    "-fPIC"
-]
+if is_rocm:
+    # HIP/ROCm build with rocWMMA
+    hip_flags = [
+        "-O3",
+        "-std=c++17",
+        "-D__HIP_PLATFORM_AMD__",
+        "-DNDEBUG",
+        # Undefine PyTorch's half conversion restrictions - rocWMMA needs these
+        "-U__HIP_NO_HALF_OPERATORS__",
+        "-U__HIP_NO_HALF_CONVERSIONS__",
+    ]
+
+    # Windows-specific: add C/C++ runtime libraries for clang-cl
+    extra_libraries = []
+    extra_link_args = []
+    if sys.platform == 'win32':
+        extra_libraries = ["msvcrt", "vcruntime", "ucrt"]
+        # Force linking with MSVC C++ runtime
+        extra_link_args = ["/DEFAULTLIB:msvcprt"]
+
+    ext_modules = [
+        CUDAExtension(
+            name="turbo_diffusion_ops",
+            sources=[
+                "turbodiffusion/ops/bindings.cpp",
+                "turbodiffusion/ops/quant/quant.hip",
+                "turbodiffusion/ops/norm/rmsnorm.hip",
+                "turbodiffusion/ops/norm/layernorm.hip",
+                "turbodiffusion/ops/gemm/gemm_rocwmma.hip",
+            ],
+            extra_compile_args={
+                "cxx": ["-O3", "-std=c++17", "-D__HIP_PLATFORM_AMD__"],
+                "nvcc": hip_flags,
+            },
+            include_dirs=[
+                str(rocwmma_dir),
+                str(ops_dir),
+            ],
+            libraries=extra_libraries,
+            extra_link_args=extra_link_args,
+        )
+    ]
+else:
+    # CUDA build with CUTLASS
+    nvcc_flags = [
+        "-O3",
+        "-std=c++17",
+        "-U__CUDA_NO_HALF_OPERATORS__",
+        "-U__CUDA_NO_HALF_CONVERSIONS__",
+        "-U__CUDA_NO_BFLOAT16_OPERATORS__",
+        "-U__CUDA_NO_BFLOAT16_CONVERSIONS__",
+        "-U__CUDA_NO_BFLOAT162_OPERATORS__",
+        "-U__CUDA_NO_BFLOAT162_CONVERSIONS__",
+        "-lineinfo",
+        "-DNDEBUG",
+    ]
 
-cc_flag = [
-    "-gencode", "arch=compute_120a,code=sm_120a",
-    "-gencode", "arch=compute_100,code=sm_100",
-    "-gencode", "arch=compute_90,code=sm_90",
-    "-gencode", "arch=compute_89,code=sm_89",
-    "-gencode", "arch=compute_80,code=sm_80"
-]
+    cc_flag = [
+        "-gencode", "arch=compute_120a,code=sm_120a",
+        "-gencode", "arch=compute_100,code=sm_100",
+        "-gencode", "arch=compute_90,code=sm_90",
+        "-gencode", "arch=compute_89,code=sm_89",
+        "-gencode", "arch=compute_80,code=sm_80"
+    ]
 
-ext_modules = [
-    CUDAExtension(
-        name="turbo_diffusion_ops",
-        sources=[
-            "turbodiffusion/ops/bindings.cpp",
-            "turbodiffusion/ops/quant/quant.cu",
-            "turbodiffusion/ops/norm/rmsnorm.cu",
-            "turbodiffusion/ops/norm/layernorm.cu",
-            "turbodiffusion/ops/gemm/gemm.cu"
-        ],
-        extra_compile_args={
-            "cxx": ["-O3", "-std=c++17"],
-            "nvcc": nvcc_flags + ["-DEXECMODE=0"] + cc_flag + ["--threads", "4"],
-        },
-        include_dirs=[
-            cutlass_dir / "include",
-            cutlass_dir / "tools" / "util" / "include",
-            ops_dir
-        ],
-        libraries=["cuda"],
-    )
-]
+    ext_modules = [
+        CUDAExtension(
+            name="turbo_diffusion_ops",
+            sources=[
+                "turbodiffusion/ops/bindings.cpp",
+                "turbodiffusion/ops/quant/quant.cu",
+                "turbodiffusion/ops/norm/rmsnorm.cu",
+                "turbodiffusion/ops/norm/layernorm.cu",
+                "turbodiffusion/ops/gemm/gemm.cu"
+            ],
+            extra_compile_args={
+                "cxx": ["-O3", "-std=c++17"],
+                "nvcc": nvcc_flags + ["-DEXECMODE=0"] + cc_flag,
+            },
+            include_dirs=[
+                str(cutlass_dir / "include"),
+                str(cutlass_dir / "tools" / "util" / "include"),
+                str(ops_dir),
+            ],
+            libraries=["cuda"],
+        )
+    ]
 
 setup(
     packages=find_packages(