diff --git a/torchrec/distributed/batched_embedding_kernel.py b/torchrec/distributed/batched_embedding_kernel.py
index c24b912d8..b4db21da4 100644
--- a/torchrec/distributed/batched_embedding_kernel.py
+++ b/torchrec/distributed/batched_embedding_kernel.py
@@ -13,7 +13,6 @@
 import itertools
 import logging
 import tempfile
-from collections import OrderedDict
 from dataclasses import dataclass
 from typing import (
     Any,
@@ -216,6 +215,7 @@ def __init__(  # noqa C901
         pg: Optional[dist.ProcessGroup] = None,
         create_for_table: Optional[str] = None,
         param_weight_for_table: Optional[nn.Parameter] = None,
+        embedding_weights_by_table: Optional[List[torch.Tensor]] = None,
     ) -> None:
         """
         Implementation of a FusedOptimizer. Designed as a base class Embedding kernels
@@ -391,7 +391,9 @@ def get_optimizer_pointwise_shard_metadata_and_global_metadata(
         # that state_dict look identical to no-fused version.
         table_to_shard_params: Dict[str, ShardParams] = {}
 
-        embedding_weights_by_table = emb_module.split_embedding_weights()
+        embedding_weights_by_table = (
+            embedding_weights_by_table or emb_module.split_embedding_weights()
+        )
 
         all_optimizer_states = emb_module.get_optimizer_state()
         optimizer_states_keys_by_table: Dict[str, List[torch.Tensor]] = {}
@@ -674,6 +676,8 @@ def _gen_named_parameters_by_table_fused(
     pg: Optional[dist.ProcessGroup] = None,
 ) -> Iterator[Tuple[str, TableBatchedEmbeddingSlice]]:
     # TODO: move logic to FBGEMM to avoid accessing fbgemm internals
+    # Cache embedding_weights_by_table
+    embedding_weights_by_table = emb_module.split_embedding_weights()
     for t_idx, (rows, dim, location, _) in enumerate(emb_module.embedding_specs):
         table_name = config.embedding_tables[t_idx].name
         if table_name not in table_name_to_count:
@@ -709,6 +713,7 @@ def _gen_named_parameters_by_table_fused(
                 pg=pg,
                 create_for_table=table_name,
                 param_weight_for_table=weight,
+                embedding_weights_by_table=embedding_weights_by_table,
             )
         ]
         yield (table_name, weight)