add QuantManagedCollisionEmbeddingBagCollection
tiankongdeguiji committed Feb 5, 2025
1 parent 9269e73 commit 1ed6a6a
Showing 3 changed files with 345 additions and 7 deletions.
9 changes: 8 additions & 1 deletion torchrec/distributed/test_utils/infer_utils.py
@@ -81,14 +81,18 @@
from torchrec.modules.embedding_modules import EmbeddingBagCollection
from torchrec.modules.feature_processor_ import PositionWeightedModuleCollection
from torchrec.modules.fp_embedding_modules import FeatureProcessedEmbeddingBagCollection
-from torchrec.modules.mc_embedding_modules import ManagedCollisionEmbeddingCollection
+from torchrec.modules.mc_embedding_modules import (
+    ManagedCollisionEmbeddingBagCollection,
+    ManagedCollisionEmbeddingCollection,
+)
from torchrec.quant.embedding_modules import (
EmbeddingCollection as QuantEmbeddingCollection,
FeatureProcessedEmbeddingBagCollection as QuantFeatureProcessedEmbeddingBagCollection,
MODULE_ATTR_QUANT_STATE_DICT_SPLIT_SCALE_BIAS,
MODULE_ATTR_REGISTER_TBES_BOOL,
quant_prep_enable_quant_state_dict_split_scale_bias_for_types,
quant_prep_enable_register_tbes,
+    QuantManagedCollisionEmbeddingBagCollection,
QuantManagedCollisionEmbeddingCollection,
)

@@ -333,6 +337,7 @@ def quantize(
module_types: List[Type[torch.nn.Module]] = [
torchrec.modules.embedding_modules.EmbeddingBagCollection,
torchrec.modules.embedding_modules.EmbeddingCollection,
+        torchrec.modules.mc_embedding_modules.ManagedCollisionEmbeddingBagCollection,
torchrec.modules.mc_embedding_modules.ManagedCollisionEmbeddingCollection,
]
if register_tbes:
@@ -359,11 +364,13 @@ def quantize(
qconfig_spec={
EmbeddingBagCollection: qconfig,
EmbeddingCollection: qconfig,
+            ManagedCollisionEmbeddingBagCollection: qconfig,
ManagedCollisionEmbeddingCollection: qconfig,
},
mapping={
EmbeddingBagCollection: QuantEmbeddingBagCollection,
EmbeddingCollection: QuantEmbeddingCollection,
+            ManagedCollisionEmbeddingBagCollection: QuantManagedCollisionEmbeddingBagCollection,
ManagedCollisionEmbeddingCollection: QuantManagedCollisionEmbeddingCollection,
},
inplace=inplace,
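The mapping change above is the heart of the commit: during dynamic quantization, a float `ManagedCollisionEmbeddingBagCollection` is now swapped for `QuantManagedCollisionEmbeddingBagCollection`, just as the EC variant already was. A minimal sketch of that swap outside the test utility is below; the `QConfig`/`PlaceholderObserver` construction follows the usual TorchRec inference pattern and is an assumption, not copied from this diff.

```python
# Sketch only: swapping a float ManagedCollisionEmbeddingBagCollection for its
# quantized counterpart via torch.quantization.quantize_dynamic, mirroring the
# qconfig_spec/mapping entries added above.
import torch
import torch.quantization as quant

from torchrec.modules.mc_embedding_modules import ManagedCollisionEmbeddingBagCollection
from torchrec.quant.embedding_modules import QuantManagedCollisionEmbeddingBagCollection


def quantize_mc_ebc(
    model: torch.nn.Module, weight_dtype: torch.dtype = torch.qint8
) -> torch.nn.Module:
    qconfig = quant.QConfig(
        activation=quant.PlaceholderObserver.with_args(dtype=torch.float),
        weight=quant.PlaceholderObserver.with_args(dtype=weight_dtype),
    )
    # quantize_dynamic walks the module tree and, for each mapped module type it
    # finds, builds the replacement module from the float one (from_float).
    return quant.quantize_dynamic(
        model,
        qconfig_spec={ManagedCollisionEmbeddingBagCollection: qconfig},
        mapping={
            ManagedCollisionEmbeddingBagCollection: QuantManagedCollisionEmbeddingBagCollection
        },
        inplace=False,
    )
```

The new test below exercises the same swap through the test utility's `quantize()` helper, then row-wise shards and scripts the result.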
171 changes: 167 additions & 4 deletions torchrec/distributed/tests/test_infer_shardings.py
@@ -75,7 +75,10 @@
PositionWeightedModuleCollection,
)
from torchrec.modules.fp_embedding_modules import FeatureProcessedEmbeddingBagCollection
-from torchrec.modules.mc_embedding_modules import ManagedCollisionEmbeddingCollection
+from torchrec.modules.mc_embedding_modules import (
+    ManagedCollisionEmbeddingBagCollection,
+    ManagedCollisionEmbeddingCollection,
+)
from torchrec.modules.mc_modules import (
DistanceLFU_EvictionPolicy,
ManagedCollisionCollection,
@@ -2088,6 +2091,162 @@ def test_sharded_quant_fp_ebc_tw(
gm_script_output = gm_script(*inputs[0])
assert_close(sharded_output, gm_script_output)

@unittest.skipIf(
torch.cuda.device_count() <= 1,
"Not enough GPUs available",
)
# pyre-ignore
@given(
weight_dtype=st.sampled_from([torch.qint8]),
device_type=st.sampled_from(["cpu", "cuda"]),
)
@settings(max_examples=2, deadline=None)
def test_sharded_quant_mc_ebc_rw(
self, weight_dtype: torch.dtype, device_type: str
) -> None:
num_embeddings = 10
emb_dim = 16
world_size = 2
batch_size = 2
local_device = torch.device(f"{device_type}:0")

topology: Topology = Topology(world_size=world_size, compute_device=device_type)
mi = TestModelInfo(
dense_device=local_device,
sparse_device=local_device,
num_features=1,
num_float_features=10,
num_weighted_features=0,
topology=topology,
)
mi.planner = EmbeddingShardingPlanner(
topology=topology,
batch_size=batch_size,
enumerator=EmbeddingEnumerator(
topology=topology,
batch_size=batch_size,
estimator=[
EmbeddingPerfEstimator(topology=topology, is_inference=True),
EmbeddingStorageEstimator(topology=topology),
],
),
)

mi.tables = [
EmbeddingBagConfig(
num_embeddings=num_embeddings,
embedding_dim=emb_dim,
name=f"table_{i}",
feature_names=[f"feature_{i}"],
)
for i in range(mi.num_features)
]

mi.model = KJTInputWrapper(
module_kjt_input=torch.nn.Sequential(
ManagedCollisionEmbeddingBagCollection(
EmbeddingBagCollection(
tables=mi.tables,
device=mi.sparse_device,
),
ManagedCollisionCollection(
managed_collision_modules={
"table_0": MCHManagedCollisionModule(
zch_size=num_embeddings,
input_hash_size=4000,
device=mi.sparse_device,
eviction_interval=2,
eviction_policy=DistanceLFU_EvictionPolicy(),
)
},
embedding_configs=mi.tables,
),
)
)
)
model_inputs: List[ModelInput] = prep_inputs(
mi, world_size, batch_size, long_indices=True
)
inputs = []
for model_input in model_inputs:
kjt = model_input.idlist_features
assert isinstance(kjt, KeyedJaggedTensor)
kjt = kjt.to(local_device)
weights = None
inputs.append(
(
kjt._keys,
kjt._values,
weights,
kjt._lengths,
kjt._offsets,
)
)

mi.model(*inputs[0])
print(f"model:\n{mi.model}")
assert mi.model.training is True
mi.quant_model = quantize(
module=mi.model,
inplace=False,
register_tbes=False,
quant_state_dict_split_scale_bias=True,
weight_dtype=weight_dtype,
)
quant_model = mi.quant_model
assert quant_model.training is False
non_sharded_output, _ = mi.quant_model(*inputs[0])

topology: Topology = Topology(world_size=world_size, compute_device=device_type)
mi.planner = EmbeddingShardingPlanner(
topology=topology,
batch_size=batch_size,
enumerator=EmbeddingEnumerator(
topology=topology,
batch_size=batch_size,
estimator=[
EmbeddingPerfEstimator(topology=topology, is_inference=True),
EmbeddingStorageEstimator(topology=topology),
],
),
)
sharder = QuantEmbeddingCollectionSharder()
# pyre-ignore
plan = mi.planner.plan(
mi.quant_model,
[sharder],
)

sharded_model = shard_qec(
mi,
sharding_type=ShardingType.ROW_WISE,
device=local_device,
expected_shards=None,
plan=plan,
)

print(f"sharded_model:\n{sharded_model}")
for n, m in sharded_model.named_modules():
print(f"sharded_model.MODULE[{n}]:{type(m)}")

sharded_model.load_state_dict(quant_model.state_dict())
sharded_output, _ = sharded_model(*inputs[0])

assert_close(non_sharded_output, sharded_output)
gm: torch.fx.GraphModule = symbolic_trace(
sharded_model,
leaf_modules=[
"IntNBitTableBatchedEmbeddingBagsCodegen",
"ComputeJTDictToKJT",
],
)

print(f"fx.graph:\n{gm.graph}")
gm_script = torch.jit.script(gm)
print(f"gm_script:\n{gm_script}")
gm_script_output, _ = gm_script(*inputs[0])
assert_close(sharded_output, gm_script_output)

@unittest.skipIf(
torch.cuda.device_count() <= 1,
"Not enough GPUs available",
@@ -2192,7 +2351,7 @@ def test_sharded_quant_mc_ec_rw(
)
quant_model = mi.quant_model
assert quant_model.training is False
-        non_sharded_output = mi.quant_model(*inputs[0])
+        non_sharded_output, _ = mi.quant_model(*inputs[0])

topology: Topology = Topology(world_size=world_size, compute_device=device_type)
mi.planner = EmbeddingShardingPlanner(
@@ -2227,7 +2386,7 @@ def test_sharded_quant_mc_ec_rw(
print(f"sharded_model.MODULE[{n}]:{type(m)}")

sharded_model.load_state_dict(quant_model.state_dict())
-        sharded_output = sharded_model(*inputs[0])
+        sharded_output, _ = sharded_model(*inputs[0])

assert_close(non_sharded_output, sharded_output)
gm: torch.fx.GraphModule = symbolic_trace(
@@ -2241,7 +2400,7 @@ def test_sharded_quant_mc_ec_rw(
print(f"fx.graph:\n{gm.graph}")
gm_script = torch.jit.script(gm)
print(f"gm_script:\n{gm_script}")
-        gm_script_output = gm_script(*inputs[0])
+        gm_script_output, _ = gm_script(*inputs[0])
assert_close(sharded_output, gm_script_output)

@unittest.skipIf(
@@ -2409,3 +2568,7 @@ def test_sharded_quant_fp_ebc_tw_meta(self, compute_device: str) -> None:
gm_script = torch.jit.script(gm)
print(f"gm_script:\n{gm_script}")
gm_script(*inputs)


if __name__ == "__main__":
unittest.main()
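Two practical notes on the test changes above, both hedged since only the test-side diff is visible here: the existing `test_sharded_quant_mc_ec_rw` now unpacks `output, _` from the quantized, sharded, and scripted models, which indicates the quantized managed-collision modules return a pair (embeddings plus a second element, presumably the remapped features) rather than a bare output; and like its EC counterpart, the new `test_sharded_quant_mc_ebc_rw` skips unless more than one GPU is visible, so on a suitable host something like `python -m pytest torchrec/distributed/tests/test_infer_shardings.py -k test_sharded_quant_mc_ebc_rw` should exercise it directly (the exact invocation is a suggestion, not taken from this commit).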