Enable fp8 on sm89 #3624

Merged
merged 4 commits into from
Dec 21, 2024

Conversation

jjsjann123
Collaborator

fp8 support has been extended to sm89 as of PTX ISA 8.1, per https://docs.nvidia.com/cuda/parallel-thread-execution/

@jjsjann123
Collaborator Author

!test

Collaborator

@crcrpar crcrpar left a comment

Thank you! Now I can see the same traces as in Lightning-AI/lightning-thunder#1551 in my environment with an RTX 6000 Ada, with a diff in thunder.

@jacobhinkle
Collaborator

fp8 support has been extended to sm89 as of PTX ISA 8.1, per https://docs.nvidia.com/cuda/parallel-thread-execution/

Does that technically mean we only support CUDA 12+ for this feature?

@jjsjann123
Collaborator Author

fp8 support has been extended to sm89 as of PTX ISA 8.1, per https://docs.nvidia.com/cuda/parallel-thread-execution/

Does that technically mean we only support CUDA 12+ for this feature?

Good call. I think I should relax this check conditionally, depending on the build-time CUDA version.
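
A minimal sketch of what gating on the build-time CUDA version could look like. It assumes, as stated elsewhere in this thread, that CUDA 12.1 (PTX ISA 8.1) is the first toolkit to accept fp8 on sm_89 and that older toolkits need sm_90. `minFp8Capability` and `fp8Allowed` are hypothetical names, not nvfuser's actual code; a real build would read the version from the `CUDA_VERSION` macro in `cuda.h` instead of taking it as an argument.

```cpp
#include <utility>

// Hypothetical helper: lowest compute capability {major, minor} on which fp8
// is accepted, given the toolkit version the binary was built against.
// cuda_version uses the CUDA_VERSION encoding from cuda.h:
// major * 1000 + minor * 10, so CUDA 12.1 is 12010.
inline std::pair<int, int> minFp8Capability(int cuda_version) {
  // Assumption from this thread: CUDA 12.1 is the first toolkit that
  // accepts fp8 on sm_89; anything older falls back to sm_90.
  return cuda_version >= 12010 ? std::pair<int, int>{8, 9}
                               : std::pair<int, int>{9, 0};
}

// True if a device with the given capability clears the fp8 gate.
inline bool fp8Allowed(int cuda_version, int dev_major, int dev_minor) {
  return std::pair<int, int>{dev_major, dev_minor} >=
         minFp8Capability(cuda_version);
}
```

With this shape, an sm_89 device passes the gate on a CUDA 12.6 build (`fp8Allowed(12060, 8, 9)`) but not on a CUDA 11.8 build (`fp8Allowed(11080, 8, 9)`).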

@jjsjann123
Collaborator Author

!test

@jjsjann123
Collaborator Author

!test

@jjsjann123 jjsjann123 merged commit 410e48f into main Dec 21, 2024
48 checks passed
@jjsjann123 jjsjann123 deleted the fp8_enable_on_sm89 branch December 21, 2024 13:58
jjsjann123 added a commit that referenced this pull request Dec 24, 2024
Fixing a version check for fp8 support.

Bump the nvfuser version for PR #3624. Framework integrations need to guard
on the nvfuser version in order to decide whether to send fp8 operations to
nvfuser.
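
A framework-side guard of this kind could be sketched as a plain version comparison. The `Version` struct, `atLeast`, and `sendFp8ToNvfuser` are hypothetical names, and the nvfuser version that carries this change is not stated here, so the caller supplies the threshold.

```cpp
#include <tuple>

// Hypothetical version triple reported by the installed nvfuser.
struct Version {
  int major_num;
  int minor_num;
  int patch_num;
};

// Lexicographic comparison: true if a is at least b.
inline bool atLeast(const Version& a, const Version& b) {
  return std::tie(a.major_num, a.minor_num, a.patch_num) >=
         std::tie(b.major_num, b.minor_num, b.patch_num);
}

// Decide whether the framework should hand fp8 ops to nvfuser on sm_89.
// first_with_fix is a placeholder for the version bumped for PR #3624.
inline bool sendFp8ToNvfuser(const Version& installed,
                             const Version& first_with_fix) {
  return atLeast(installed, first_with_fix);
}
```
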
@naoyam
Collaborator

naoyam commented Dec 31, 2024

@jjsjann123 I'm seeing an error on RTX 6000 (sm_89):

[ RUN      ] NVFuserTest.FusionFp8CastOps_CUDA
unknown file: Failure
C++ exception with description " INTERNAL ASSERT FAILED at "/home/nmaruyama/nvfuser/debug3/csrc/runtime/executor_utils.cpp":859, please report a bug with repro script to NVFuser at https://github.com/NVIDIA/Fuser/issues.
__global__ void nvfuser_none_f0_c0_r0_g0(Tensor<__bfloat, 2, 2> T0, Tensor<__bfloat, 2, 2> T2) {
  __e4m3 T1[1LL];
  T1[0LL]
     = __bfloat2e4m3(T0[((T0.alloc_stride[1LL] * ((nvfuser_index_t)threadIdx.x)) + (T0.alloc_stride[0LL] * ((nvfuser_index_t)blockIdx.x)))]);
  T2[(((nvfuser_index_t)threadIdx.x) + (T0.logical_size[1LL] * ((nvfuser_index_t)blockIdx.x)))]
     = __e4m32bfloat(T1[0LL]);
}
}

CUDA NVRTC compile error: ptxas application ptx input, line 47; error   : Feature 'cvt with .f16.bf16' requires .target sm_90 or higher
ptxas application ptx input, line 58; error   : Feature 'cvt with .bf16.f16' requires .target sm_90 or higher
ptxas fatal   : Ptx assembly aborted due to errors

Exception raised from invoke at /home/nmaruyama/nvfuser/debug3/csrc/runtime/executor_utils.cpp:859 (most recent call first):

@jjsjann123
Collaborator Author

Did I get the cuda TK check wrong?! I thought CUDA TK version would determine PTX ISA version...

Are you running this in a container? I'm curious what the setup is like.
sm_89 should have fp8 support since cuda 12.1.

@naoyam
Collaborator

naoyam commented Dec 31, 2024

This is on my own container with 12.6.

@jjsjann123
Collaborator Author

wait, it's not complaining about fp8 though..

cvt with .bf16.f16

@jjsjann123
Collaborator Author

Looks like cvt between bf16 and f16 does require sm_90. https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-cvt

I wonder why our check is only requiring sm_80.

if (val->dtype() == DataType::BFloat16) {
  ensureVersion(
      {8, 0},
      "Fusion contains BFloat16 values which was introduced in Ampere (8.0)");
}

Looks like this is just a test thing. I'll update that along with the checks. Thanks for raising the issue @naoyam

@jjsjann123
Collaborator Author

.relu modifier and {.f16x2, .bf16, .bf16x2, .tf32} destination formats require sm_80 or higher.
cvt.bf16.{u8/s8/u16/s16/u32/s32/u64/s64/f16/f64/bf16}, cvt.{u8/s8/u16/s16/u32/s32/u64/s64/f16/f64}.bf16, and cvt.tf32.f32.{relu}.{rn/rz} require sm_90 or higher.
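
The two rules quoted above can be encoded as a small lookup, sketched here for the bf16 cases only. The `PtxType` enum, `inSm90List`, and `minSmForBf16Cvt` are hypothetical names for illustration, not part of nvfuser. The sketch ignores modifiers such as `.relu`, treats an f32 source with a bf16 destination as the sm_80 case (bf16 is an sm_80 destination format and the sm_90 list does not name f32), and returns 0 for pairs the quoted rules do not cover.

```cpp
// Hypothetical subset of PTX scalar types, for illustration only.
enum class PtxType { U8, S8, U16, S16, U32, S32, U64, S64, F16, F32, F64, BF16 };

// Types that appear next to bf16 in the quoted sm_90 list:
// all integer types, f16, and f64.
inline bool inSm90List(PtxType t) {
  return t != PtxType::F32 && t != PtxType::BF16;
}

// Minimum sm version (e.g. 90 for sm_90) for cvt.<dst>.<src> per the quoted
// rules. Returns 0 when the quoted rules say nothing about the pair.
inline int minSmForBf16Cvt(PtxType dst, PtxType src) {
  if (dst == PtxType::BF16 && (inSm90List(src) || src == PtxType::BF16)) {
    return 90;  // cvt.bf16.{u8/.../f16/f64/bf16}
  }
  if (src == PtxType::BF16 && inSm90List(dst)) {
    return 90;  // cvt.{u8/.../f16/f64}.bf16
  }
  if (dst == PtxType::BF16) {
    return 80;  // remaining case: f32 source, bf16 destination format
  }
  return 0;  // not covered by the quoted rules
}
```

Both conversions in the failing kernel, `cvt` with `.f16.bf16` and with `.bf16.f16`, fall into the sm_90 branches, which matches the ptxas errors reported on sm_89.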
