
Have fp8 GEMMs go to hipblaslt only #3990

Merged
causten merged 15 commits into develop from fp8_hipblaslt_only on May 16, 2025

Conversation

@CharlieL7 (Collaborator) commented May 5, 2025

  • With this PR, fp8 GEMMs will always use hipblaslt.
  • FP8 will be eliminated if hipblaslt is not available.
  • Tested on MI300 with the different flags and environment variables.
  • Depends on Remove rocblas fp8 #3985

@CharlieL7 CharlieL7 self-assigned this May 5, 2025
@CharlieL7 CharlieL7 requested a review from causten as a code owner May 5, 2025 21:25
@CharlieL7 CharlieL7 added the FP8 (issues related to FP8 implementation) label May 5, 2025
@CharlieL7 CharlieL7 requested a review from pfultz2 May 5, 2025 21:25
@CharlieL7 CharlieL7 linked an issue May 5, 2025 that may be closed by this pull request
@CharlieL7 CharlieL7 changed the base branch from develop to rocblas_fp8_remove May 5, 2025 21:29
@CharlieL7 CharlieL7 requested a review from a team as a code owner May 5, 2025 21:29
@CharlieL7 CharlieL7 changed the base branch from rocblas_fp8_remove to develop May 5, 2025 21:29
@CharlieL7 CharlieL7 removed the request for review from a team May 5, 2025 21:29
@spolifroni-amd (Contributor) left a comment

Nothing to review for docs

if(string_value_of(MIGRAPHX_SET_GEMM_PROVIDER{}) == "rocblas" or gpu::gfx_default_rocblas())

// disable dot & quant_dot if no hipblaslt
if(not hipblaslt_supported())
@ahsan-ca (Contributor) commented May 8, 2025

Would we still not want to check gpu::gfx_default_rocblas() for the case where hipblaslt is supported for the arch but we default to using rocblas?

@CharlieL7 (Collaborator, Author) replied:

With the next ROCm release, rocblas is removing all fp8 support, so we only need to check whether we have hipblaslt for fp8.

@ahsan-ca (Contributor) replied:

I was wondering about the case of gfx90a, for which hipblaslt is supported but we default to using rocblas. So hipblaslt_supported() will return true, but we may be using rocblas for it (if the default has not been overridden).

@CharlieL7 (Collaborator, Author) replied:

Lowering has been changed in this PR to move all FP8 GEMM instructions to hipblaslt.

@ahsan-ca (Contributor) replied:

> I was wondering about the case of gfx90a, for which hipblaslt is supported but we default to using rocblas. So hipblaslt_supported() will return true, but we may be using rocblas for it (if the default has not been overridden).

For posterity: for gfx90a, even though it would default to rocblas, fp8 would still use hipblaslt. That is intentional.

@ahsan-ca (Contributor) left a comment

LGTM. Thanks for making the change.

@migraphx-bot (Collaborator) commented:

Test  Batch  Rate new (cf02e0)  Rate old (03922b)  Diff
torchvision-resnet50 64 3,256.91 3,257.40 -0.01%
torchvision-resnet50_fp16 64 6,930.25 6,937.63 -0.11%
torchvision-densenet121 32 2,457.23 2,454.67 0.10%
torchvision-densenet121_fp16 32 4,229.96 4,234.91 -0.12%
torchvision-inceptionv3 32 1,630.19 1,627.74 0.15%
torchvision-inceptionv3_fp16 32 2,717.58 2,720.93 -0.12%
cadene-inceptionv4 16 761.68 761.38 0.04%
cadene-resnext64x4 16 819.42 819.11 0.04%
slim-mobilenet 64 7,476.15 7,472.34 0.05%
slim-nasnetalarge 64 217.23 217.12 0.05%
slim-resnet50v2 64 3,352.51 3,349.86 0.08%
bert-mrpc-onnx 8 1,153.45 1,151.17 0.20%
bert-mrpc-tf 1 457.20 455.79 0.31%
pytorch-examples-wlang-gru 1 365.97 344.28 6.30% 🔆
pytorch-examples-wlang-lstm 1 473.55 481.34 -1.62%
torchvision-resnet50_1 1 817.66 814.48 0.39%
cadene-dpn92_1 1 432.35 431.59 0.18%
cadene-resnext101_1 1 393.69 393.45 0.06%
onnx-taau-downsample 1 395.87 395.60 0.07%
dlrm-criteoterabyte 1 32.32 32.32 -0.00%
dlrm-criteoterabyte_fp16 1 51.23 51.26 -0.06%
agentmodel 1 10,309.83 10,714.81 -3.78% 🔴
unet_fp16 2 59.56 59.52 0.06%
resnet50v1_fp16 1 1,081.21 1,076.03 0.48%
resnet50v1_int8 1 1,058.48 1,049.50 0.86%
bert_base_cased_fp16 64 1,170.78 1,170.45 0.03%
bert_large_uncased_fp16 32 357.87 357.91 -0.01%
bert_large_fp16 1 200.06 200.20 -0.07%
distilgpt2_fp16 16 2,239.62 2,240.98 -0.06%
yolov5s 1 542.37 542.66 -0.05%
tinyllama 1 43.89 43.87 0.05%
vicuna-fastchat 1 45.10 45.26 -0.35%
whisper-tiny-encoder 1 420.76 421.44 -0.16%
whisper-tiny-decoder 1 411.76 413.06 -0.31%
llama2_7b 1 nan nan nan%
qwen1.5-7b 1 23.55 23.55 -0.01%
phi3-3.8b 1 26.63 26.64 -0.06%
mask-rcnn 1 12.81 12.79 0.12%
llama3-8b 1 21.77 21.77 -0.01%
whisper-large-encoder 1 10.22 10.22 -0.02%
whisper-large-decoder 1 101.11 100.83 0.28%
mistral-7b 1 23.76 23.78 -0.11%
FLUX.1-schnell 1 913.10 906.91 0.68%

This build is not recommended to merge 🔴

@migraphx-bot (Collaborator) commented:

     ✅ bert-mrpc-onnx: PASSED: MIGraphX meets tolerance

❌ bert-mrpc-tf: ERROR - check error output
2025-05-14 14:32:50.791326: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: SSE3 SSE4.1 SSE4.2 AVX AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1747251176.263869 163426 gpu_device.cc:2022] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 62973 MB memory: -> device: 0, name: AMD Instinct MI250X/MI250, pci bus id: 0000:b3:00.0
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1747251177.099417 163426 mlir_graph_optimization_pass.cc:401] MLIR V1 optimization pass is not enabled
2025-05-14 14:33:06.724724: E external/local_xla/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc:250] bitcode module is required by this HLO module but was not found at ./opencl.bc
[previous line repeated 7 more times]
error: Failure when generating HSACO
[previous line repeated 7 more times]
2025-05-14 14:33:06.726065: E tensorflow/compiler/mlir/tools/kernel_gen/tf_framework_c_interface.cc:228] INTERNAL: Generating device code failed.
2025-05-14 14:33:06.727112: W tensorflow/core/framework/op_kernel.cc:1829] UNKNOWN: JIT compilation failed.
2025-05-14 14:33:06.727130: I tensorflow/core/framework/local_rendezvous.cc:405] Local rendezvous is aborting with status: UNKNOWN: JIT compilation failed.
[[{{node import/bert/embeddings/LayerNorm/moments/SquaredDifference}}]]
2025-05-14 14:33:06.727141: I tensorflow/core/framework/local_rendezvous.cc:405] Local rendezvous is aborting with status: UNKNOWN: JIT compilation failed.
[[{{node import/bert/embeddings/LayerNorm/moments/SquaredDifference}}]]
[[import/loss/output/_21]]
2025-05-14 14:33:06.727178: I tensorflow/core/framework/local_rendezvous.cc:424] Local rendezvous recv item cancelled. Key hash: 11217777527359497193
Traceback (most recent call last):
File "/usr/local/lib/python3.10/dist-packages/tensorflow/python/client/session.py", line 1407, in _do_call
return fn(*args)
File "/usr/local/lib/python3.10/dist-packages/tensorflow/python/client/session.py", line 1390, in _run_fn
return self._call_tf_sessionrun(options, feed_dict, fetch_list,
File "/usr/local/lib/python3.10/dist-packages/tensorflow/python/client/session.py", line 1483, in _call_tf_sessionrun
return tf_session.TF_SessionRun_wrapper(self._session, options, feed_dict,
tensorflow.python.framework.errors_impl.UnknownError: 2 root error(s) found.
(0) UNKNOWN: JIT compilation failed.
[[{{node import/bert/embeddings/LayerNorm/moments/SquaredDifference}}]]
[[import/loss/output/_21]]
(1) UNKNOWN: JIT compilation failed.
[[{{node import/bert/embeddings/LayerNorm/moments/SquaredDifference}}]]
0 successful operations.
0 derived errors ignored.

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
File "/src/AMDMIGraphX/tools/accuracy/accuracy_checker.py", line 340, in
main()
File "/src/AMDMIGraphX/tools/accuracy/accuracy_checker.py", line 324, in main
y_out = sess.run(y, feed_dict=tf_dict)
File "/usr/local/lib/python3.10/dist-packages/tensorflow/python/client/session.py", line 977, in run
result = self._run(None, fetches, feed_dict, options_ptr,
File "/usr/local/lib/python3.10/dist-packages/tensorflow/python/client/session.py", line 1220, in _run
results = self._do_run(handle, final_targets, final_fetches,
File "/usr/local/lib/python3.10/dist-packages/tensorflow/python/client/session.py", line 1400, in _do_run
return self._do_call(_run_fn, feeds, fetches, targets, options,
File "/usr/local/lib/python3.10/dist-packages/tensorflow/python/client/session.py", line 1426, in _do_call
raise type(e)(node_def, op, message) # pylint: disable=no-value-for-parameter
tensorflow.python.framework.errors_impl.UnknownError: Graph execution error:

Detected at node 'import/bert/embeddings/LayerNorm/moments/SquaredDifference' defined at (most recent call last):
Node: 'import/bert/embeddings/LayerNorm/moments/SquaredDifference'
Detected at node 'import/bert/embeddings/LayerNorm/moments/SquaredDifference' defined at (most recent call last):
Node: 'import/bert/embeddings/LayerNorm/moments/SquaredDifference'
2 root error(s) found.
(0) UNKNOWN: JIT compilation failed.
[[{{node import/bert/embeddings/LayerNorm/moments/SquaredDifference}}]]
[[import/loss/output/_21]]
(1) UNKNOWN: JIT compilation failed.
[[{{node import/bert/embeddings/LayerNorm/moments/SquaredDifference}}]]
0 successful operations.
0 derived errors ignored.

Original stack trace for 'import/bert/embeddings/LayerNorm/moments/SquaredDifference':


     ✅ pytorch-examples-wlang-gru: PASSED: MIGraphX meets tolerance

     ✅ pytorch-examples-wlang-lstm: PASSED: MIGraphX meets tolerance

     ✅ dlrm-criteoterabyte: PASSED: MIGraphX meets tolerance

     ✅ agentmodel: PASSED: MIGraphX meets tolerance

🔴 unet: FAILED: MIGraphX is not within tolerance - check verbose output


     ✅ resnet50v1: PASSED: MIGraphX meets tolerance

     ✅ bert_base_cased_fp16: PASSED: MIGraphX meets tolerance

🔴 bert_large_uncased_fp16: FAILED: MIGraphX is not within tolerance - check verbose output


     ✅ bert_large: PASSED: MIGraphX meets tolerance

     ✅ yolov5s: PASSED: MIGraphX meets tolerance

     ✅ tinyllama: PASSED: MIGraphX meets tolerance

     ✅ vicuna-fastchat: PASSED: MIGraphX meets tolerance

     ✅ whisper-tiny-encoder: PASSED: MIGraphX meets tolerance

     ✅ whisper-tiny-decoder: PASSED: MIGraphX meets tolerance

     ✅ distilgpt2_fp16: PASSED: MIGraphX meets tolerance

❌ llama2_7b: ERROR - check error output
Traceback (most recent call last):
File "/src/AMDMIGraphX/tools/accuracy/accuracy_checker.py", line 340, in
main()
File "/src/AMDMIGraphX/tools/accuracy/accuracy_checker.py", line 205, in main
model = migraphx.parse_onnx(model_name, default_dim_value=batch)
RuntimeError: /src/AMDMIGraphX/src/onnx/onnx_parser.cpp:265: parse_from: PARSE_FROM: Failed reading onnx file: /new-saved-models/llama2_7b/decoder_model.onnx


❌ qwen1.5-7b: ERROR - check error output
usage: accuracy_checker.py [-h] [--onnx ONNX] [--tf TF] [--provider PROVIDER]
[--batch BATCH] [--fill1] [--fill0] [--fp16]
[--argmax] [--verbose] [--tolerance TOLERANCE]
[--input-dim INPUT_DIM] [--target TARGET]
[--ort-run] [--ort-logging]
[--disable-offload-copy] [--disable-fast-math]
[--exhaustive_tune]
accuracy_checker.py: error: unrecognized arguments: input_ids attention_mask position_ids 1 256 @attention_mask 1 256 @position_ids 1 256


❌ phi3-3.8b: ERROR - check error output
usage: accuracy_checker.py [-h] [--onnx ONNX] [--tf TF] [--provider PROVIDER]
[--batch BATCH] [--fill1] [--fill0] [--fp16]
[--argmax] [--verbose] [--tolerance TOLERANCE]
[--input-dim INPUT_DIM] [--target TARGET]
[--ort-run] [--ort-logging]
[--disable-offload-copy] [--disable-fast-math]
[--exhaustive_tune]
accuracy_checker.py: error: unrecognized arguments: input_ids attention_mask position_ids 1 256 @attention_mask 1 256 @position_ids 1 256


❌ mask-rcnn: ERROR - check error output
usage: accuracy_checker.py [-h] [--onnx ONNX] [--tf TF] [--provider PROVIDER]
[--batch BATCH] [--fill1] [--fill0] [--fp16]
[--argmax] [--verbose] [--tolerance TOLERANCE]
[--input-dim INPUT_DIM] [--target TARGET]
[--ort-run] [--ort-logging]
[--disable-offload-copy] [--disable-fast-math]
[--exhaustive_tune]
accuracy_checker.py: error: unrecognized arguments: 3 800 800


     ✅ llama3-8b: PASSED: MIGraphX meets tolerance

❌ whisper-large-encoder: ERROR - check error output
Traceback (most recent call last):
File "/src/AMDMIGraphX/tools/accuracy/accuracy_checker.py", line 340, in
main()
File "/src/AMDMIGraphX/tools/accuracy/accuracy_checker.py", line 205, in main
model = migraphx.parse_onnx(model_name, default_dim_value=batch)
RuntimeError: /src/AMDMIGraphX/src/include/migraphx/op/convolution.hpp:100: normalize_compute_shape: CONVOLUTION: mismatched channel numbers


     ✅ whisper-large-decoder: PASSED: MIGraphX meets tolerance

     ✅ mistral-7b: PASSED: MIGraphX meets tolerance

     ✅ FLUX.1-schnell: PASSED: MIGraphX meets tolerance

@causten causten merged commit 1568dfe into develop May 16, 2025
47 of 49 checks passed
@causten causten deleted the fp8_hipblaslt_only branch May 16, 2025 14:52
causten pushed a commit that referenced this pull request May 26, 2025
Will always use hipblaslt for fp8 GEMMs with this PR.
FP8 will be eliminated if hipblaslt is not available.

Labels

FP8 (issues related to FP8 implementation)

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Fallback to hipblaslt for fp8 GEMMs

7 participants