JAX/Flax models 2x slower on Sapphire Rapids (c7i) than Ice Lake (c6i) instances | x86 #23296
-
Problem
Flax models run up to 2x slower on the latest c7i EC2 instances (Sapphire Rapids) than on c6i instances (Ice Lake).

Steps to repro:
Run the Flax MLP script (see Code below) on a c7i and a c6i instance and compare latencies.

Result
The latency of the script was up to 2x higher on c7i than on c6i. Similar results were seen using Flax models such as bert-base-uncased from Hugging Face.

Questions
PyTorch has a blog post claiming that AMX is auto-picked if available, and that it improves performance. Does XLA:CPU pick AMX automatically on Sapphire Rapids as well?

Code
Flax MLP
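A minimal sketch of such a Flax MLP benchmark, under stated assumptions (the layer widths, batch size, and iteration count are illustrative, not the reporter's exact values):

```python
import time

import jax
import jax.numpy as jnp
import flax.linen as nn


class MLP(nn.Module):
    """Simple MLP; the layer widths here are illustrative assumptions."""
    features: tuple = (1024, 1024, 1024)

    @nn.compact
    def __call__(self, x):
        for feat in self.features[:-1]:
            x = nn.relu(nn.Dense(feat)(x))
        return nn.Dense(self.features[-1])(x)


model = MLP()
x = jnp.ones((256, 1024), dtype=jnp.float32)
params = model.init(jax.random.PRNGKey(0), x)

apply_fn = jax.jit(model.apply)
apply_fn(params, x).block_until_ready()  # compile once, outside the timed loop

iters = 100
start = time.perf_counter()
for _ in range(iters):
    apply_fn(params, x).block_until_ready()
print(f"mean latency: {(time.perf_counter() - start) / iters * 1e3:.2f} ms")
```

The loop calls block_until_ready so that JAX's asynchronous dispatch does not hide the actual compute latency.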
-
XLA:CPU supports AMX in contraction ops through custom calls to oneDNN. We have recently transitioned to a new runtime which doesn't support these oneDNN custom calls yet (support coming soon in 1-2 weeks). In the meantime, the old runtime supports these custom calls and can use AMX; setting the XLA_FLAGS environment variable to opt back into the old runtime should restore AMX usage.
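A sketch of opting back into the old runtime from Python; the exact flag name below is an assumption to verify against your XLA version, and XLA_FLAGS must be set before jax is first imported:

```python
import os

# Assumed flag for disabling the new thunk runtime; verify the exact name
# against your XLA version. Must be set before jax is imported.
os.environ["XLA_FLAGS"] = "--xla_cpu_use_thunk_runtime=false"

import jax.numpy as jnp

# With the old runtime, contractions like this can dispatch to the oneDNN
# custom calls (and AMX on Sapphire Rapids).
a = jnp.ones((1024, 1024), dtype=jnp.bfloat16)
print((a @ a).block_until_ready().dtype)
```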
-
cc: @agramesh1 @TensorFlow-MKL (Intel oneDNN-XLA integration team)
-
@Rohanjames1997 @penpornk We have tested the code on both c6i.4xlarge and c7i.4xlarge EC2 instances, with the XLA_FLAGS environment variable set as described above.

c7i.4xlarge (Sapphire Rapids)

c6i.4xlarge (Ice Lake)

The performance difference between Sapphire Rapids and Ice Lake for float32 numerics can be attributed to the higher frequency of Ice Lake. Code using bfloat16 numerics benefits from AMX on Sapphire Rapids; see the bert-base-uncased measurements below.
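As a hedged illustration of that distinction (float32 throughput tracks core frequency, while bfloat16 contractions can use AMX), a small matmul microbenchmark; the shapes and iteration count are arbitrary choices, not the values measured above:

```python
import time

import jax
import jax.numpy as jnp


def bench(dtype, n=4096, iters=20):
    # Illustrative square-matmul benchmark, not the exact code measured above.
    a = jax.random.normal(jax.random.PRNGKey(0), (n, n)).astype(dtype)
    f = jax.jit(lambda x: x @ x)
    f(a).block_until_ready()  # compile once, outside the timed loop
    start = time.perf_counter()
    for _ in range(iters):
        f(a).block_until_ready()
    return (time.perf_counter() - start) / iters


for dtype in (jnp.float32, jnp.bfloat16):
    print(dtype.__name__, f"{bench(dtype) * 1e3:.1f} ms")
```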
-
@Rohanjames1997 @penpornk We also measured the Hugging Face bert-base-uncased model and observed improved performance with bfloat16 numerics, using jax.numpy.bfloat16 directly. Example code is added at the end.

c7i.4xlarge (Sapphire Rapids)

c6i.4xlarge (Ice Lake)
-
@Rohanjames1997 the recommended way to use AMX on Sapphire Rapids in JAX/Flax is to use the bfloat16 datatype, as @mdfaijul has shown. You can also set DNNL_DEFAULT_FPMATH_MODE=BF16, but it will not give you the full benefit of AMX on Sapphire Rapids.
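A short sketch contrasting the two options; DNNL_DEFAULT_FPMATH_MODE is the oneDNN variable named above, and whether it takes effect depends on the contraction actually dispatching to oneDNN:

```python
import os

# Option 1 (partial benefit): let oneDNN downconvert float32 contractions to
# bfloat16 internally. Must be set before the CPU backend initializes.
os.environ["DNNL_DEFAULT_FPMATH_MODE"] = "BF16"

import jax.numpy as jnp

# Option 2 (recommended): compute in bfloat16 directly so that matmuls are
# bf16 x bf16 and can map onto the AMX tile instructions.
a = jnp.ones((2048, 2048), dtype=jnp.bfloat16)
b = jnp.ones((2048, 2048), dtype=jnp.bfloat16)
print((a @ b).dtype)  # bfloat16
```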
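The example code referenced above ("added at the end") would look roughly like the sketch below, using the transformers Flax classes; the batch size, sequence length, and timing loop are assumptions, not the exact attached script:

```python
import time

import jax
import jax.numpy as jnp
from transformers import AutoTokenizer, FlaxBertModel

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
# dtype=jnp.bfloat16 runs the forward computation in bfloat16; to_bf16 also
# casts the weights, so matmuls are bf16 x bf16 and can use AMX.
model = FlaxBertModel.from_pretrained("bert-base-uncased", dtype=jnp.bfloat16)
model.params = model.to_bf16(model.params)

inputs = tokenizer(
    ["a short benchmark sentence"] * 8,  # batch size is an assumption
    padding="max_length",
    max_length=128,
    return_tensors="np",
)


@jax.jit
def forward(params, input_ids, attention_mask):
    return model(
        input_ids=input_ids, attention_mask=attention_mask, params=params
    ).last_hidden_state


# Compile once, outside the timed loop.
forward(model.params, inputs["input_ids"], inputs["attention_mask"]).block_until_ready()

iters = 50
start = time.perf_counter()
for _ in range(iters):
    forward(
        model.params, inputs["input_ids"], inputs["attention_mask"]
    ).block_until_ready()
print(f"mean latency: {(time.perf_counter() - start) / iters * 1e3:.2f} ms")
```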