Could anyone please give feedback or suggestions on how to optimize training with the torch backend? Referring to this page, one possible approach is to use CUDA Graphs. However, I assume the underlying torch.nn.Module object is wrapped inside the keras.models.Model. Does Keras 3 provide any API to access the underlying torch model so that such optimizations are possible?
-
Hello Jeff,
To the best of my knowledge, unless PyTorch is paired with highly optimized linear-algebra compilers and CUDA-based libraries, it is very likely to be slower. Published PyTorch benchmarks for various LLMs often rely on compiler-aware code, and that is what actually produces the speedup. Most of the time the JAX backend is the fastest for inference, and if using …
Best Regards,
-
According to the Keras 3 benchmark page, the PyTorch backend is significantly slower than the TensorFlow backend for training. The following results, extracted from that page, compare training speed.
The speed is measured in ms/step. Lower is better.
There is a reason we looked into this benchmark result. For our own DL model, training also became significantly slower (50% to 60% slower) after we migrated from TensorFlow 2 (with GPU/CUDA support) to Keras 3 (with PyTorch + CUDA) on Windows. On native Windows, TensorFlow 2.11 and later no longer support GPU/CUDA (2.10 was the last release that did), which is one of the reasons we wanted to switch to Keras 3 with the PyTorch + CUDA backend.
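Because "X% slower" can mean either more time per step or fewer steps per second, here is a small sketch that turns two ms/step figures into both numbers. `compare_step_times` is a hypothetical helper, and the 100 ms and 150 ms inputs are illustration values only, not figures from the benchmark page or from our model.

```python
def compare_step_times(baseline_ms, candidate_ms):
    """Compare two step times given in ms/step (lower is better).

    Returns (extra time per step in %, drop in steps per second in %).
    """
    if baseline_ms <= 0 or candidate_ms <= 0:
        raise ValueError("step times must be positive")
    # Relative increase in wall-clock time per step.
    time_increase_pct = (candidate_ms - baseline_ms) / baseline_ms * 100.0
    # Convert ms/step to steps/second to express the throughput loss.
    baseline_steps_per_s = 1000.0 / baseline_ms
    candidate_steps_per_s = 1000.0 / candidate_ms
    throughput_drop_pct = (
        (baseline_steps_per_s - candidate_steps_per_s) / baseline_steps_per_s * 100.0
    )
    return time_increase_pct, throughput_drop_pct


# Hypothetical inputs for illustration only.
increase, drop = compare_step_times(baseline_ms=100.0, candidate_ms=150.0)
print(f"{increase:.1f}% more time per step, {drop:.1f}% fewer steps per second")
# prints "50.0% more time per step, 33.3% fewer steps per second"
```

So a step that takes 50% longer corresponds to only about a one-third drop in throughput, which is worth keeping in mind when comparing numbers quoted in different ways.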
On the benchmark page, I also noticed the footnote directly below the results (Table 2):
My question is: in which future Keras 3 release will the PyTorch slowness be addressed? Would that fix also resolve the training speed issue for the PyTorch backend in general?
With the current Keras 3, is there any way to improve training speed with the PyTorch backend?