Commit
Cache embedding_weights_by_table for EmbeddingFusedOptimizer
Summary: The `split_embedding_weights()` method of `emb_module` is expensive, and it is currently called in the constructor of `EmbeddingFusedOptimizer`, so it runs every time an instance is created. Because `_gen_named_parameters_by_table_fused` creates `EmbeddingFusedOptimizer` instances **thousands of times in a loop**, a significant amount of time is spent on these repeated calls. This change calls the method once, outside the loop, and passes the result to the constructor as a parameter, which acts as a cache. With this caching, **CREATE_TRAIN_MODULE.SHARD_MODEL** drops from approximately **22 seconds** to around **15 seconds**.

Reviewed By: dstaay-fb

Differential Revision: D68578829
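The pattern described in the summary can be sketched as follows. This is a minimal illustration, not the actual diff: `FakeEmbModule`, `FakeFusedOptimizer`, and `build_optimizers` are hypothetical stand-ins, the weights are plain lists rather than tensors, and the optional `embedding_weights_by_table` constructor argument is assumed from the commit title.

```python
from typing import List, Optional


class FakeEmbModule:
    """Hypothetical stand-in for emb_module; counts the expensive calls."""

    def __init__(self, num_tables: int) -> None:
        self.num_tables = num_tables
        self.split_calls = 0

    def split_embedding_weights(self) -> List[List[float]]:
        # Stands in for the expensive split; the real method returns tensors.
        self.split_calls += 1
        return [[float(i)] for i in range(self.num_tables)]


class FakeFusedOptimizer:
    """Hypothetical optimizer that accepts precomputed per-table weights."""

    def __init__(
        self,
        emb_module: FakeEmbModule,
        table_idx: int,
        embedding_weights_by_table: Optional[List[List[float]]] = None,
    ) -> None:
        # Fall back to the old behavior (split inside the constructor)
        # only when the caller did not supply a cached result.
        if embedding_weights_by_table is None:
            embedding_weights_by_table = emb_module.split_embedding_weights()
        self.weights = embedding_weights_by_table[table_idx]


def build_optimizers(
    emb_module: FakeEmbModule, use_cache: bool
) -> List[FakeFusedOptimizer]:
    # With caching, the split runs once before the loop and is reused.
    cached = emb_module.split_embedding_weights() if use_cache else None
    return [
        FakeFusedOptimizer(emb_module, i, cached)
        for i in range(emb_module.num_tables)
    ]
```

In this sketch, building the optimizers without the cache calls the split once per table, while the cached path calls it exactly once regardless of the number of tables. Keeping the argument optional preserves the old constructor behavior for callers that do not pass it.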