
Jax/Flax models 2x slower on Sapphire Rapids (c7i) than Ice Lake (c6i) instances | x86 #23296

Answered by mdfaijul
Rohanjames1997 asked this question in Q&A

@Rohanjames1997 @penpornk We also measured the Huggingface bert-base-uncased model and observed improved performance with the bfloat16 numeric type. Example code is added at the end.

c7i.4xlarge (Sapphire Rapids)
Run commands for the attached code:

export XLA_FLAGS=--xla_cpu_use_thunk_runtime=false

1. python bert.py
2. DNNL_DEFAULT_FPMATH_MODE=BF16 python bert.py
3. python bert.py --precision bfloat16

Configuration                               Throughput (examples/sec)
1. No setting of fpmath-mode                35.898
2. DNNL_DEFAULT_FPMATH_MODE=BF16            71.607
3. Using jax.numpy.bfloat16 directly        85.289

c6i.4xlarge (Ice Lake)
Run commands for the attached code:

export XLA_FLAGS=--xla_cpu_use_thunk_runtime=false
python bert.py
Configuration T…
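
The attached bert.py script is not reproduced here. Below is a minimal sketch of a comparable benchmark, assuming a Flax bert-base-uncased forward pass via Hugging Face transformers; the batch size, sequence length, step count, and the exact handling of the --precision flag are assumptions, not the attached script itself.

```python
import argparse
import time

import jax
import jax.numpy as jnp
from transformers import BertTokenizerFast, FlaxBertModel

parser = argparse.ArgumentParser()
parser.add_argument("--precision", default="float32", choices=["float32", "bfloat16"])
parser.add_argument("--batch_size", type=int, default=32)  # assumed value
parser.add_argument("--steps", type=int, default=50)       # assumed value
args = parser.parse_args()

dtype = jnp.bfloat16 if args.precision == "bfloat16" else jnp.float32

tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
model = FlaxBertModel.from_pretrained("bert-base-uncased", dtype=dtype)
if args.precision == "bfloat16":
    # Cast the weights as well; dtype above only sets the computation dtype.
    model.params = model.to_bf16(model.params)

# Fixed-shape dummy batch so every step reuses the same compiled XLA executable.
sentences = ["A short benchmark sentence for BERT on CPU."] * args.batch_size
inputs = tokenizer(sentences, padding="max_length", max_length=128, return_tensors="np")

@jax.jit
def forward(input_ids, attention_mask):
    return model(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state

# Warm-up call triggers compilation; exclude it from the timing.
forward(inputs["input_ids"], inputs["attention_mask"]).block_until_ready()

start = time.perf_counter()
for _ in range(args.steps):
    forward(inputs["input_ids"], inputs["attention_mask"]).block_until_ready()
elapsed = time.perf_counter() - start

print(f"Throughput: {args.steps * args.batch_size / elapsed:.2f} examples/sec")
```

Run the same script under the three configurations listed above (no fpmath-mode setting, DNNL_DEFAULT_FPMATH_MODE=BF16, and --precision bfloat16) to reproduce the comparison.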
